RLModule APIs#
Note
Ray 2.40 uses RLlib’s new API stack by default. The Ray team has mostly completed transitioning algorithms, example scripts, and documentation to the new code base.
If you’re still using the old API stack, see New API stack migration guide for details on how to migrate.
RLModule specifications and configurations#
Single RLModuleSpec#
- RLModuleSpec: Utility spec class to make constructing RLModules (in the single-agent case) easier.
- RLModuleSpec.build(): Builds the RLModule from this spec.
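For example, a minimal construction sketch, assuming RLlib's default PPO Torch module (adjust the import paths to your Ray version if needed):

```python
import gymnasium as gym

from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

env = gym.make("CartPole-v1")

# Describe the module through a spec, then build the actual RLModule from it.
spec = RLModuleSpec(
    module_class=DefaultPPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[32]),
)
rl_module = spec.build()
```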
MultiRLModuleSpec#
- MultiRLModuleSpec: A utility spec class to make constructing MultiRLModules easier.
- MultiRLModuleSpec.build(): Builds either the MultiRLModule or a (single) sub-RLModule under this spec.
- MultiRLModuleSpec.multi_rl_module_class: Type[MultiRLModule] = <class 'ray.rllib.core.rl_module.multi_rl_module.MultiRLModule'>
  The class of the MultiRLModule to construct. By default, this is the base MultiRLModule class.
- MultiRLModuleSpec.observation_space: gymnasium.Space | None = None
  Optional global observation space for the MultiRLModule. Useful for shared network components that live only inside the MultiRLModule and don't have their own ModuleID and own RLModule within self._rl_modules.
- MultiRLModuleSpec.action_space: gymnasium.Space | None = None
  Optional global action space for the MultiRLModule. Useful for shared network components that live only inside the MultiRLModule and don't have their own ModuleID and own RLModule within self._rl_modules.
- MultiRLModuleSpec.inference_only: bool | None = None
  An optional global inference_only flag. If not set (None by default), RLlib considers the MultiRLModule to be inference_only=True only if all its submodules also have their inference_only flags set to True.
- MultiRLModuleSpec.model_config: dict | None = None
  An optional global model_config dict. Useful to configure shared network components that only live inside the MultiRLModule and don't have their own ModuleID and own RLModule within self._rl_modules.
- MultiRLModuleSpec.rl_module_specs: RLModuleSpec | Dict[str, RLModuleSpec] = None
  The module specs for each individual module: either a single RLModuleSpec used for all module IDs, or a dictionary mapping module IDs to RLModuleSpecs.
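As a sketch, a two-policy MultiRLModuleSpec could look like this, again assuming the default PPO Torch module:

```python
import gymnasium as gym

from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

env = gym.make("CartPole-v1")

# One RLModuleSpec per module ID; spec.build() returns the MultiRLModule.
spec = MultiRLModuleSpec(
    rl_module_specs={
        "policy_1": RLModuleSpec(
            module_class=DefaultPPOTorchRLModule,
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config=DefaultModelConfig(fcnet_hiddens=[32]),
        ),
        "policy_2": RLModuleSpec(
            module_class=DefaultPPOTorchRLModule,
            observation_space=env.observation_space,
            action_space=env.action_space,
            model_config=DefaultModelConfig(fcnet_hiddens=[64, 64]),
        ),
    },
)
multi_rl_module = spec.build()
```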
DefaultModelConfig#
- DefaultModelConfig: Dataclass to configure all default RLlib RLModules.
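For instance, a sketch of tweaking the default network of an algorithm through its config (using PPO here):

```python
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    # Configure the default RLModule's network through DefaultModelConfig.
    .rl_module(
        model_config=DefaultModelConfig(
            fcnet_hiddens=[128, 128],
            fcnet_activation="relu",
        ),
    )
)
```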
RLModule API#
Construction and setup#
- RLModule: Base class for RLlib modules.
- RLModule.setup(): Sets up the components of the module.
- RLModule.as_multi_rl_module(): Returns a multi-agent wrapper around this module.
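As a sketch of the construction flow, a hypothetical TorchRLModule subclass (TinyTorchModule is not part of RLlib) overrides setup() to build its sub-networks from the spaces and model_config handed over by the spec; the forward methods it still needs are covered in the next section:

```python
import gymnasium as gym
import torch

from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.torch import TorchRLModule


class TinyTorchModule(TorchRLModule):
    """Hypothetical example module, not part of RLlib."""

    def setup(self):
        # Called once at construction time; build all sub-networks here,
        # using the spaces and model_config passed in through the spec.
        hidden = self.model_config.get("hidden_dim", 64)
        self._net = torch.nn.Sequential(
            torch.nn.Linear(self.observation_space.shape[0], hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, self.action_space.n),
        )


spec = RLModuleSpec(
    module_class=TinyTorchModule,
    observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
    action_space=gym.spaces.Discrete(2),
    model_config={"hidden_dim": 64},
)
module = spec.build()

# Wrap the single-agent module in a MultiRLModule under the default module ID.
multi_module = module.as_multi_rl_module()
```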
Forward methods#
Use the following three forward methods when you use RLModule from inside other classes
and components. Do NOT override them; leave them as-is in your custom subclasses.
To define your own forward behavior, override the private method _forward (generic forward behavior for all phases) or, for more granularity, the phase-specific methods _forward_exploration, _forward_inference, and _forward_train.
- RLModule.forward_exploration(): DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler.
- RLModule.forward_inference(): DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler.
- RLModule.forward_train(): DO NOT OVERRIDE! Forward-pass during training, called from the learner.
Override these private methods to define your custom model's forward behavior:
- _forward: generic forward behavior for all phases
- _forward_exploration: for training sample collection
- _forward_inference: for production deployments, greedy acting
- _forward_train: for computing loss function inputs
- RLModule._forward(): Generic forward pass method, used in all phases of training and evaluation.
- RLModule._forward_exploration(): Forward-pass used for action computation with exploration behavior.
- RLModule._forward_inference(): Forward-pass used for action computation without exploration behavior.
- RLModule._forward_train(): Forward-pass used before the loss computation (training).
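A sketch of how the hypothetical TinyTorchModule from above could define its forward behavior; _forward covers all phases, while _forward_inference shows a phase-specific override:

```python
import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule


class TinyTorchModule(TorchRLModule):
    # Continuing the sketch from above; setup() builds self._net.
    ...

    def _forward(self, batch, **kwargs):
        # Generic path: used for exploration, inference, and training alike,
        # unless one of the phase-specific methods below overrides it.
        return {Columns.ACTION_DIST_INPUTS: self._net(batch[Columns.OBS])}

    def _forward_inference(self, batch, **kwargs):
        # Phase-specific override: same computation here, but without
        # gradients and, for example, without exploration-only components.
        with torch.no_grad():
            return self._forward(batch, **kwargs)
```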
Saving and restoring#
- RLModule.save_to_path(): Saves the state of the implementing class (or state) to the given path.
- RLModule.restore_from_path(): Restores the state of the implementing class from the given path.
- RLModule.from_checkpoint(): Creates a new Checkpointable instance from the given location and returns it.
- RLModule.get_state(): Returns the state dict of the module.
- RLModule.set_state(): Sets the implementing class' state to the given state dict.
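A usage sketch of the checkpointing methods, continuing the TinyTorchModule example above (the path is a placeholder):

```python
from ray.rllib.core.rl_module.rl_module import RLModule

# Save the module's state to a checkpoint directory ...
module.save_to_path("/tmp/tiny_module_ckpt")

# ... and later re-create an equivalent module from that location.
restored = RLModule.from_checkpoint("/tmp/tiny_module_ckpt")

# Alternatively, move state between two compatible modules in memory.
restored.set_state(module.get_state())
```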
MultiRLModule API#
Constructor#
- MultiRLModule: Base class for an RLModule that contains n sub-RLModules.
- MultiRLModule.setup(): Sets up the underlying, individual RLModules.
- MultiRLModule.as_multi_rl_module(): Returns self in order to match RLModule.as_multi_rl_module() behavior.
Modifying the underlying RLModules#
- MultiRLModule.add_module(): Adds a module at runtime to the multi-agent module.
- MultiRLModule.remove_module(): Removes a module at runtime from the multi-agent module.
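Continuing the sketches above, adding and removing sub-modules at runtime could look like this:

```python
# Add a new sub-RLModule, e.g. for an agent that appears mid-experiment.
multi_module.add_module(module_id="new_policy", module=spec.build())

# And remove it again once it's no longer needed.
multi_module.remove_module(module_id="new_policy")
```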
Saving and restoring#
- MultiRLModule.save_to_path(): Saves the state of the implementing class (or state) to the given path.
- MultiRLModule.restore_from_path(): Restores the state of the implementing class from the given path.
- MultiRLModule.from_checkpoint(): Creates a new Checkpointable instance from the given location and returns it.
- MultiRLModule.get_state(): Returns the state dict of the module.
- MultiRLModule.set_state(): Sets the state of the multi-agent module.
Additional RLModule APIs#
InferenceOnlyAPI#
- class ray.rllib.core.rl_module.apis.inference_only_api.InferenceOnlyAPI[source]#
An API to be implemented by RLModules that have an inference-only mode.
Only the get_non_inference_attributes method needs to be implemented for an RLModule to gain the following functionality:
- On EnvRunners (or when self.inference_only=True), RLlib removes those parts of the model not required for action computation.
- An RLModule on a Learner (where self.inference_only=False) returns only those weights from get_state() that are part of its inference-only version, thus possibly saving network traffic/time.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- abstract get_non_inference_attributes() List[str] [source]#
Returns a list of attribute names (str) of components NOT used for inference.
RLlib uses this information to remove those attributes/components from an RLModule whose config.inference_only is set to True, activating the so-called "inference-only setup". Normally, all RLModules located on EnvRunners are constructed this way, because they are only used for computing actions. Similarly, when deploying into a production environment, consider building your RLModules with this flag set to True as well. For example:
```python
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

spec = RLModuleSpec(module_class=..., inference_only=True)
```
If an RLModule has the following setup() implementation:

```python
class MyRLModule(RLModule):
    def setup(self):
        self._policy_head = ...          # some NN component
        self._value_function_head = ...  # some NN component
        self._encoder = ...  # some NN component with sub-attributes `pol`
                             # and `vf` (policy- and value function encoders)
```
Then its get_non_inference_attributes() should return ["_value_function_head", "_encoder.vf"].
Note the "." notation, which separates attributes from their sub-attributes, in case you need more fine-grained control over which exact sub-attributes to exclude from the inference-only setup.
- Returns:
  A list of names (str) of those attributes (or sub-attributes) that should be excluded (deleted) from this RLModule in case it's set up in inference_only mode.
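Putting it together, a sketch of the implementation for the MyRLModule example above:

```python
from ray.rllib.core.rl_module.apis.inference_only_api import InferenceOnlyAPI
from ray.rllib.core.rl_module.rl_module import RLModule


class MyRLModule(RLModule, InferenceOnlyAPI):
    # setup() as shown above ...

    def get_non_inference_attributes(self):
        # Exclude the value head entirely, plus only the `vf` sub-attribute
        # of the encoder; everything else is kept for action computation.
        return ["_value_function_head", "_encoder.vf"]
```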
QNetAPI#
- class ray.rllib.core.rl_module.apis.q_net_api.QNetAPI[source]#
An API to be implemented by RLModules used for (distributional) Q-learning.
RLModules implementing this API must override the compute_q_values and the compute_advantage_distribution methods.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- abstract compute_q_values(batch: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]) Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] [source]#
Computes Q-values, given an encoder, a Q-net, and (optionally) an advantage net.
Note that these can be accompanied by logits and probabilities in case of distributional Q-learning, i.e. when self.num_atoms > 1.
- Parameters:
  batch – The batch received in the forward pass.
- Returns:
  A dictionary containing the Q-value predictions ("qf_preds") and, in case of distributional Q-learning, additionally the support atoms ("atoms"), the Q-logits ("qf_logits"), and the probabilities ("qf_probs").
- compute_advantage_distribution(batch: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]) Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor] [source]#
Computes the advantage distribution.
Note that this distribution is identical to the Q-distribution if no dueling architecture is used.
- Parameters:
  batch – A dictionary containing a tensor with the outputs of the forward pass of the Q-head or advantage stream head.
- Returns:
  A dict containing the support of the discrete distribution ("atoms") for either Q-values or advantages (in case of a dueling architecture), the logits per action and atom, and the probabilities of the discrete distribution (per action and atom of the support).
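A minimal sketch of the non-distributional case; the _encoder and _q_head attributes are assumptions here, built in the module's setup():

```python
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.q_net_api import QNetAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyDQNModule(TorchRLModule, QNetAPI):  # hypothetical example module
    def compute_q_values(self, batch):
        # Plain Q-learning (num_atoms == 1): one Q-value per action.
        embeddings = self._encoder(batch[Columns.OBS])
        return {"qf_preds": self._q_head(embeddings)}
```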
SelfSupervisedLossAPI#
- class ray.rllib.core.rl_module.apis.self_supervised_loss_api.SelfSupervisedLossAPI[source]#
An API to be implemented by RLModules that bring their own self-supervised loss.
Learners call the model's compute_self_supervised_loss() method instead of the Learner's own compute_loss_for_module() method. The call signature is identical to that of compute_loss_for_module(), except for an additional mandatory learner kwarg.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- abstract compute_self_supervised_loss(*, learner: Learner, module_id: str, config: AlgorithmConfig, batch: Dict[str, Any], fwd_out: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]) numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor [source]#
Computes the loss for a single module.
Think of this as computing the loss for a single agent. For multi-agent use cases that require more complicated loss computation, consider overriding the compute_losses method instead.
- Parameters:
  learner – The Learner calling this loss method on the RLModule.
  module_id – The ID of the RLModule (within a MultiRLModule).
  config – The AlgorithmConfig specific to the given module_id.
  batch – The sample batch for this particular RLModule.
  fwd_out – The output of the forward pass for this particular RLModule.
- Returns:
  A single total loss tensor. If you have more than one optimizer on the provided module_id and would like to compute gradients separately using these different optimizers, simply add up the individual loss terms for each optimizer and return the sum. Also, for recording/logging any individual loss terms, you can use the Learner.metrics.log_value(key=..., value=...) or Learner.metrics.log_dict() APIs. See MetricsLogger for more information.
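A sketch of an intrinsic-curiosity-style loss; the fwd_out keys used here are illustrative assumptions, not RLlib conventions:

```python
import torch

from ray.rllib.core.rl_module.apis.self_supervised_loss_api import (
    SelfSupervisedLossAPI,
)
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyCuriosityModule(TorchRLModule, SelfSupervisedLossAPI):  # hypothetical
    def compute_self_supervised_loss(
        self, *, learner, module_id, config, batch, fwd_out
    ):
        # Predict the next observation embedding and penalize the error.
        prediction = fwd_out["predicted_next_embedding"]  # assumed key
        target = fwd_out["next_embedding"].detach()       # assumed key
        loss = torch.mean((prediction - target) ** 2)
        # Log the individual term through the Learner's MetricsLogger.
        learner.metrics.log_value(key="dynamics_loss", value=loss.item())
        return loss
```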
TargetNetworkAPI#
- class ray.rllib.core.rl_module.apis.target_network_api.TargetNetworkAPI[source]#
An API to be implemented by RLModules for handling target networks.
RLModules implementing this API must override the make_target_networks, get_target_network_pairs, and forward_target methods.
Note that the respective Learner that owns the implementing RLModule handles all target syncing logic.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- abstract make_target_networks() None [source]#
Creates the required target nets for this RLModule.
Use the convenience ray.rllib.core.learner.utils.make_target_network() utility when implementing this method. Pass in an already existing, corresponding "main" net (for which you need a target net). This utility already takes care of the initialization (from the "main" net).
- abstract get_target_network_pairs() List[Tuple[torch.nn.Module | tf.keras.Model, torch.nn.Module | tf.keras.Model]] [source]#
Returns a list of 2-tuples of (main_net, target_net).
For example, if your RLModule has a property self.q_net and this network has a corresponding target net self.target_q_net, return [(self.q_net, self.target_q_net)] from this (overridden) method.
Note that you need to create all target nets in your overridden make_target_networks method and store the target nets in any property of your choice.
- Returns:
  A list of 2-tuples of (main_net, target_net).
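A sketch of the three overrides; the main net self._q_net is an assumption here, built in the module's setup():

```python
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyQModule(TorchRLModule, TargetNetworkAPI):  # hypothetical example
    def make_target_networks(self):
        # Create (and initialize) the target net from the main net.
        self._target_q_net = make_target_network(self._q_net)

    def get_target_network_pairs(self):
        return [(self._q_net, self._target_q_net)]

    def forward_target(self, batch):
        # Compute target Q-values; the owning Learner handles all syncing.
        return {"target_q_preds": self._target_q_net(batch[Columns.OBS])}
```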
ValueFunctionAPI#
- class ray.rllib.core.rl_module.apis.value_function_api.ValueFunctionAPI[source]#
An API to be implemented by RLModules for handling value function-based learning.
RLModules implementing this API must override the compute_values method.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- abstract compute_values(batch: Dict[str, Any], embeddings: Any | None = None) numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor [source]#
Computes the value estimates given batch.
- Parameters:
  batch – The batch to compute value function estimates for.
  embeddings – Optional embeddings already computed from the batch (by another forward pass through the model's encoder, or another subcomponent that computes an embedding). For example, the caller of this method should provide embeddings, if available, to avoid duplicate passes through a shared encoder.
- Returns:
  A tensor of shape (B,) or (B, T) (in case the input batch has a time dimension). Note that the last value dimension should already be squeezed out (not 1!).
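A sketch of a typical override with a shared encoder; the _encoder and _value_head attributes are assumptions here, built in the module's setup():

```python
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyPGModule(TorchRLModule, ValueFunctionAPI):  # hypothetical example
    def compute_values(self, batch, embeddings=None):
        # Reuse embeddings if the caller already ran the shared encoder.
        if embeddings is None:
            embeddings = self._encoder(batch[Columns.OBS])
        # Squeeze out the last dim: return shape (B,), not (B, 1).
        return self._value_head(embeddings).squeeze(-1)
```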