RLModule APIs#
RLModule specifications and configurations#
Single RLModuleSpec#
- RLModuleSpec: Utility spec class to make constructing RLModules (in the single-agent case) easier.
- RLModuleSpec.build(): Builds the RLModule from this spec.
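For example (a minimal sketch; MyTorchPolicyModule, the spaces, and the model_config contents are placeholders of this example, not part of the API reference):

import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

# MyTorchPolicyModule is a hypothetical custom RLModule subclass.
spec = RLModuleSpec(
    module_class=MyTorchPolicyModule,
    observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
    action_space=gym.spaces.Discrete(2),
    model_config={"hidden_dims": [64, 64]},  # passed through to the module
)
rl_module = spec.build()  # returns an instance of MyTorchPolicyModule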
MultiRLModuleSpec#
- MultiRLModuleSpec: A utility spec class to make constructing MultiRLModules easier.
- MultiRLModuleSpec.build(): Builds either the MultiRLModule or a (single) sub-RLModule under a given ModuleID.
- MultiRLModuleSpec.multi_rl_module_class: Type[MultiRLModule] = <class 'ray.rllib.core.rl_module.multi_rl_module.MultiRLModule'>
The class of the MultiRLModule to construct. By default, this is the base MultiRLModule class.
- MultiRLModuleSpec.observation_space: gymnasium.Space | None = None
Optional global observation space for the MultiRLModule. Useful for shared network components that live only inside the MultiRLModule and don't have their own ModuleID and own RLModule within self._rl_modules.
- MultiRLModuleSpec.action_space: gymnasium.Space | None = None
Optional global action space for the MultiRLModule. Useful for shared network components that live only inside the MultiRLModule and don't have their own ModuleID and own RLModule within self._rl_modules.
- MultiRLModuleSpec.inference_only: bool | None = None
An optional global inference_only flag. If not set (None by default), the MultiRLModule is considered inference_only=True only if all submodules also have their own inference_only flags set to True.
- MultiRLModuleSpec.model_config: dict | None = None
An optional global model_config dict. Useful to configure shared network components that only live inside the MultiRLModule and don't have their own ModuleID and own RLModule within self._rl_modules.
- MultiRLModuleSpec.rl_module_specs: RLModuleSpec | Dict[str, RLModuleSpec] = None
The module specs for each individual module. It can be either an RLModuleSpec used for all module_ids or a dictionary mapping from module IDs to RLModuleSpecs for each individual module.
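For example, a sketch of a two-policy spec (MyTorchPolicyModule and the spaces are placeholders of this example):

import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec

obs = gym.spaces.Box(-1.0, 1.0, (4,))
act = gym.spaces.Discrete(2)

# One RLModuleSpec per ModuleID.
multi_spec = MultiRLModuleSpec(
    rl_module_specs={
        "policy_1": RLModuleSpec(module_class=MyTorchPolicyModule,
                                 observation_space=obs, action_space=act),
        "policy_2": RLModuleSpec(module_class=MyTorchPolicyModule,
                                 observation_space=obs, action_space=act),
    },
)
multi_module = multi_spec.build()  # builds the MultiRLModule with both submodules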
DefaultModelConfig#
- DefaultModelConfig: Dataclass to configure all default RLlib RLModules.
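A hedged usage sketch; the fcnet_hiddens and fcnet_activation field names are assumptions of this example, so check the dataclass itself for the exact attribute set:

from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

# Assumed fields; DefaultModelConfig exposes more options than shown here.
model_config = DefaultModelConfig(
    fcnet_hiddens=[128, 128],   # sizes of the fully connected layers
    fcnet_activation="relu",    # activation between those layers
)
# Typically handed to the algorithm, e.g. via AlgorithmConfig.rl_module(model_config=model_config).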
RLModule API#
Construction and setup#
- RLModule: Base class for RLlib modules.
- RLModule.setup(): Sets up the components of the module.
- RLModule.as_multi_rl_module(): Returns a multi-agent wrapper around this module.
Forward methods#
Use the following three forward methods when calling an RLModule from inside other classes and components. Do NOT override them; leave them as-is in your custom subclasses.
To define your own forward behavior, override the private method _forward (generic forward behavior for all phases) or, for more granularity, _forward_exploration, _forward_inference, and _forward_train.
- RLModule.forward_exploration(): DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler.
- RLModule.forward_inference(): DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler.
- RLModule.forward_train(): DO NOT OVERRIDE! Forward-pass during training, called from the learner.
Override these private methods to define your custom model’s forward behavior.
- _forward: generic forward behavior for all phases
- _forward_exploration: for training sample collection
- _forward_inference: for production deployments, greedy acting
- _forward_train: for computing loss function inputs
- RLModule._forward(): Generic forward pass method, used in all phases of training and evaluation.
- RLModule._forward_exploration(): Forward-pass used for action computation with exploration behavior.
- RLModule._forward_inference(): Forward-pass used for action computation without exploration behavior.
- RLModule._forward_train(): Forward-pass used before the loss computation (training).
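The following sketch shows the intended override pattern for a Torch-based module; the network layout and sizes are illustrative only:

import torch
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyTorchPolicyModule(TorchRLModule):
    def setup(self):
        # Hypothetical sizes; real code would derive them from the spaces.
        self._net = torch.nn.Sequential(
            torch.nn.Linear(4, 64), torch.nn.ReLU(), torch.nn.Linear(64, 2)
        )

    def _forward(self, batch, **kwargs):
        # Generic behavior shared by exploration, inference, and training.
        logits = self._net(batch[Columns.OBS])
        return {Columns.ACTION_DIST_INPUTS: logits}

Callers then invoke the public forward_inference(), forward_exploration(), or forward_train() methods, which all route to this single _forward() override.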
Saving and restoring#
- RLModule.save_to_path(): Saves the state of the implementing class (or state) to the given path.
- RLModule.restore_from_path(): Restores the state of the implementing class from the given path.
- RLModule.from_checkpoint(): Creates a new Checkpointable instance from the given location and returns it.
- RLModule.get_state(): Returns the state dict of the module.
- RLModule.set_state(): Sets the implementing class' state to the given state dict.
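A sketch of the checkpointing round trip, assuming rl_module was built from a spec as in the examples above and that the target directory is writable:

from ray.rllib.core.rl_module.rl_module import RLModule

# Save the module's weights and metadata to a checkpoint directory.
rl_module.save_to_path("/tmp/my_rl_module_checkpoint")

# Later (or in another process), re-create the module from that location.
restored_module = RLModule.from_checkpoint("/tmp/my_rl_module_checkpoint")

# Alternatively, copy only the state between two live modules.
restored_module.set_state(rl_module.get_state())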
MultiRLModule API#
Constructor#
- MultiRLModule: Base class for an RLModule that contains n sub-RLModules.
- MultiRLModule.setup(): Sets up the underlying, individual RLModules.
- MultiRLModule.as_multi_rl_module(): Returns self, in order to match the RLModule.as_multi_rl_module() API.
Modifying the underlying RLModules#
- MultiRLModule.add_module(): Adds a module at runtime to the multi-agent module.
- MultiRLModule.remove_module(): Removes a module at runtime from the multi-agent module.
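For example (continuing the MultiRLModuleSpec sketch above; obs, act, and MyTorchPolicyModule are the same placeholders):

from ray.rllib.core.rl_module.rl_module import RLModuleSpec

# Add a new sub-RLModule under a fresh ModuleID at runtime.
new_module = RLModuleSpec(module_class=MyTorchPolicyModule,
                          observation_space=obs, action_space=act).build()
multi_module.add_module("policy_3", new_module)

# Remove a sub-RLModule that's no longer needed.
multi_module.remove_module("policy_2")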
Saving and restoring#
- MultiRLModule.save_to_path(): Saves the state of the implementing class (or state) to the given path.
- MultiRLModule.restore_from_path(): Restores the state of the implementing class from the given path.
- MultiRLModule.from_checkpoint(): Creates a new Checkpointable instance from the given location and returns it.
- MultiRLModule.get_state(): Returns the state dict of the module.
- MultiRLModule.set_state(): Sets the state of the multi-agent module.
Additional RLModule APIs#
InferenceOnlyAPI#
- class ray.rllib.core.rl_module.apis.inference_only_api.InferenceOnlyAPI#
An API to be implemented by RLModules that have an inference-only mode.
Only the get_non_inference_attributes method needs to be implemented for an RLModule to get the following functionality:
- On EnvRunners (or when self.inference_only=True), RLlib removes those parts of the model not required for action computation.
- An RLModule on a Learner (where self.inference_only=False) returns only those weights from get_state() that are part of its inference-only version, thus possibly saving network traffic/time.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- abstract get_non_inference_attributes() -> List[str]#
Returns a list of attribute names (str) of components NOT used for inference.
RLlib uses this information to remove those attributes/components from an RLModule whose config.inference_only is set to True, activating the so-called "inference-only setup". Normally, all RLModules located on EnvRunners are constructed this way, because they are only used for computing actions. Similarly, when deploying into a production environment, users should consider building their RLModules with this flag set to True as well.
For example:
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

spec = RLModuleSpec(module_class=..., inference_only=True)
If an RLModule has the following setup() implementation:
class MyRLModule(RLModule):
    def setup(self):
        self._policy_head = [some NN component]
        self._value_function_head = [some NN component]
        self._encoder = [some NN component with attributes: pol and vf
                         (policy- and value func. encoder)]
Then its get_non_inference_attributes() should return: ["_value_function_head", "_encoder.vf"].
Note the "." notation used to separate attributes and their sub-attributes, in case you need more fine-grained control over which exact sub-attributes to exclude in the inference-only setup.
- Returns:
A list of names (str) of those attributes (or sub-attributes) that should be excluded (deleted) from this RLModule in case it's set up in inference_only mode.
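Continuing the setup() example above, a sketch of the matching override could look like this (mixing in the API class this way is an assumption of the example):

from typing import List

from ray.rllib.core.rl_module.apis.inference_only_api import InferenceOnlyAPI
from ray.rllib.core.rl_module.rl_module import RLModule


class MyRLModule(RLModule, InferenceOnlyAPI):
    # setup() as shown above ...

    def get_non_inference_attributes(self) -> List[str]:
        # Drop the value head and the value-function branch of the encoder
        # when this module runs in inference-only mode.
        return ["_value_function_head", "_encoder.vf"]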
QNetAPI#
- class ray.rllib.core.rl_module.apis.q_net_api.QNetAPI#
An API to be implemented by RLModules used for (distributional) Q-learning.
RLModules implementing this API must override the compute_q_values and the compute_advantage_distribution methods.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- abstract compute_q_values(batch: Dict[str, NDArray[Any] | jnp.ndarray | tf.Tensor | torch.Tensor]) -> Dict[str, NDArray[Any] | jnp.ndarray | tf.Tensor | torch.Tensor]#
Computes Q-values, given encoder, q-net and (optionally) advantage net.
Note that these can be accompanied by logits and probabilities in case of distributional Q-learning, i.e. self.num_atoms > 1.
- Parameters:
batch – The batch received in the forward pass.
- Returns:
A dictionary containing the Q-value predictions ("qf_preds") and, in case of distributional Q-learning, additionally the support atoms ("atoms"), the Q-logits ("qf_logits"), and the probabilities ("qf_probs").
- compute_advantage_distribution(batch: Dict[str, NDArray[Any] | jnp.ndarray | tf.Tensor | torch.Tensor]) -> Dict[str, NDArray[Any] | jnp.ndarray | tf.Tensor | torch.Tensor]#
Computes the advantage distribution.
Note that this distribution is identical to the Q-distribution in case no dueling architecture is used.
- Parameters:
batch – A dictionary containing a tensor with the outputs of the forward pass of the Q-head or advantage stream head.
- Returns:
A dict containing the support of the discrete distribution ("atoms") for either Q-values or advantages (in case of a dueling architecture), the logits per action and atom, and the probabilities of the discrete distribution (per action and atom of the support).
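A hedged, non-distributional sketch (self.num_atoms == 1); the _encoder and _qf attributes are assumed to be created in setup():

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.q_net_api import QNetAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyDQNModule(TorchRLModule, QNetAPI):
    def compute_q_values(self, batch):
        # Assumed attributes from setup(): self._encoder and self._qf.
        embeddings = self._encoder(batch[Columns.OBS])
        return {"qf_preds": self._qf(embeddings)}  # one Q-value per action

    # For distributional Q-learning (num_atoms > 1), also override
    # compute_advantage_distribution() to return "atoms", logits, and probs.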
SelfSupervisedLossAPI#
- class ray.rllib.core.rl_module.apis.self_supervised_loss_api.SelfSupervisedLossAPI#
An API to be implemented by RLModules that bring their own self-supervised loss.
Learners call the model's compute_self_supervised_loss() method instead of the Learner's own compute_loss_for_module() method. The call signature is identical to that of the Learner's compute_loss_for_module() method, except for an additional mandatory learner kwarg.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- abstract compute_self_supervised_loss(*, learner: Learner, module_id: str, config: AlgorithmConfig, batch: Dict[str, Any], fwd_out: Dict[str, NDArray[Any] | jnp.ndarray | tf.Tensor | torch.Tensor], **kwargs) -> NDArray[Any] | jnp.ndarray | tf.Tensor | torch.Tensor#
Computes the loss for a single module.
Think of this as computing the loss for a single agent. For multi-agent use-cases that require more complicated loss computation, consider overriding the compute_losses method instead.
- Parameters:
learner – The Learner calling this loss method on the RLModule.
module_id – The ID of the RLModule (within a MultiRLModule).
config – The AlgorithmConfig specific to the given module_id.
batch – The sample batch for this particular RLModule.
fwd_out – The output of the forward pass for this particular RLModule.
- Returns:
A single total loss tensor. If you have more than one optimizer on the provided module_id and would like to compute gradients separately using these different optimizers, simply add up the individual loss terms for each optimizer and return the sum. Also, for recording/logging any individual loss terms, you can use the Learner.metrics.log_value(key=..., value=...) or Learner.metrics.log_dict() APIs. See MetricsLogger for more information.
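A hedged sketch of such an override, using a simple mean-squared reconstruction error as a stand-in loss; the "obs_reconstruction" key in fwd_out is an assumption of this example, not an RLlib convention:

import torch
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.self_supervised_loss_api import SelfSupervisedLossAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyCuriosityModule(TorchRLModule, SelfSupervisedLossAPI):
    def compute_self_supervised_loss(self, *, learner, module_id, config,
                                     batch, fwd_out, **kwargs):
        # Example loss: reconstruct the observation from the model's output.
        loss = torch.mean(
            (fwd_out["obs_reconstruction"] - batch[Columns.OBS]) ** 2
        )
        # Record the individual term via the Learner's MetricsLogger.
        learner.metrics.log_value(key=(module_id, "recon_loss"), value=loss)
        return loss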
TargetNetworkAPI#
- class ray.rllib.core.rl_module.apis.target_network_api.TargetNetworkAPI#
An API to be implemented by RLModules for handling target networks.
RLModules implementing this API must override the make_target_networks, get_target_network_pairs, and forward_target methods.
Note that the respective Learner that owns the implementing RLModule handles all target syncing logic.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- abstract make_target_networks() -> None#
Creates the required target nets for this RLModule.
Use the convenience ray.rllib.core.learner.utils.make_target_network() utility when implementing this method. Pass in an already existing, corresponding "main" net (for which you need a target net). This utility already takes care of initialization (from the "main" net).
- abstract get_target_network_pairs() -> List[Tuple[torch.nn.Module | keras.Model, torch.nn.Module | keras.Model]]#
Returns a list of 2-tuples of (main_net, target_net).
For example, if your RLModule has a property self.q_net and this network has a corresponding target net self.target_q_net, return [(self.q_net, self.target_q_net)] from this (overridden) method.
Note that you need to create all target nets in your overridden make_target_networks method and store the target nets in any property of your choice.
- Returns:
A list of 2-tuples of (main_net, target_net).
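A hedged sketch of the three overrides, assuming the module created self._encoder and self._qf in setup(); the "qf_target_preds" output key is an assumption of this example:

from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyTargetNetModule(TorchRLModule, TargetNetworkAPI):
    def make_target_networks(self):
        # Create and initialize a frozen copy of the main Q-net.
        self._target_qf = make_target_network(self._qf)

    def get_target_network_pairs(self):
        return [(self._qf, self._target_qf)]

    def forward_target(self, batch, **kwargs):
        # Assumed layout: encoder embeddings fed into the target Q-head.
        embeddings = self._encoder(batch[Columns.OBS])
        return {"qf_target_preds": self._target_qf(embeddings)}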
ValueFunctionAPI#
- class ray.rllib.core.rl_module.apis.value_function_api.ValueFunctionAPI#
An API to be implemented by RLModules for handling value function-based learning.
RLModules implementing this API must override the compute_values method.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- abstract compute_values(batch: Dict[str, Any], embeddings: Any | None = None) -> NDArray[Any] | jnp.ndarray | tf.Tensor | torch.Tensor#
Computes the value estimates, given batch.
- Parameters:
batch – The batch to compute value function estimates for.
embeddings – Optional embeddings already computed from the batch (by another forward pass through the model's encoder or other subcomponent that computes an embedding). For example, the caller of this method should provide embeddings, if available, to avoid duplicate passes through a shared encoder.
- Returns:
A tensor of shape (B,) or (B, T) (in case the input batch has a time dimension). Note that the last value dimension should already be squeezed out (not 1!).
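A hedged sketch; self._encoder and self._vf are assumed attributes created in setup(), and the final squeeze(-1) removes the trailing value dimension as required above:

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyPPOStyleModule(TorchRLModule, ValueFunctionAPI):
    def compute_values(self, batch, embeddings=None):
        if embeddings is None:
            # Only run the (possibly shared) encoder if the caller didn't
            # already provide precomputed embeddings.
            embeddings = self._encoder(batch[Columns.OBS])
        # Shape (B, 1) -> (B,): squeeze out the last value dimension.
        return self._vf(embeddings).squeeze(-1)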