RLModule APIs#

Note

Ray 2.40 uses RLlib’s new API stack by default. The Ray team has mostly completed transitioning algorithms, example scripts, and documentation to the new code base.

If you’re still using the old API stack, see New API stack migration guide for details on how to migrate.

RLModule specifications and configurations#

Single RLModuleSpec#

RLModuleSpec

Utility spec class to make constructing RLModules (in the single-agent case) easier. See the example after the attribute list below.

RLModuleSpec.build

Builds the RLModule from this spec.

RLModuleSpec.module_class

RLModuleSpec.observation_space

RLModuleSpec.action_space

RLModuleSpec.inference_only

RLModuleSpec.learner_only

RLModuleSpec.model_config
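
For example, a minimal, hedged sketch of building a single-agent RLModule from a spec. The module class, spaces, and model config are illustrative; DefaultPPOTorchRLModule stands in for any RLModule subclass, and its import path may differ across Ray versions:

import gymnasium as gym

from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

env = gym.make("CartPole-v1")

# Describe the module through a spec, then build the actual RLModule.
spec = RLModuleSpec(
    module_class=DefaultPPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[64, 64]),
)
rl_module = spec.build()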

MultiRLModuleSpec#

MultiRLModuleSpec

A utility spec class to make constructing MultiRLModules easier.

MultiRLModuleSpec.build

Builds either the MultiRLModule or a (single) sub-RLModule under module_id.

MultiRLModuleSpec.multi_rl_module_class: Type[MultiRLModule] = <class 'ray.rllib.core.rl_module.multi_rl_module.MultiRLModule'>

The class of the MultiRLModule to construct. By default, this is the base MultiRLModule class.

MultiRLModuleSpec.observation_space: gymnasium.Space | None = None

Optional global observation space for the MultiRLModule. Useful for shared network components that live only inside the MultiRLModule and don’t have their own ModuleID and own RLModule within self._rl_modules.

MultiRLModuleSpec.action_space: gymnasium.Space | None = None

Optional global action space for the MultiRLModule. Useful for shared network components that live only inside the MultiRLModule and don’t have their own ModuleID and own RLModule within self._rl_modules.

MultiRLModuleSpec.inference_only: bool | None = None

An optional global inference_only flag. If not set (None by default), RLlib considers the MultiRLModule to be inference_only=True only if all its submodules have their own inference_only flags set to True.

MultiRLModuleSpec.model_config: dict | None = None

An optional global model_config dict. Useful to configure shared network components that only live inside the MultiRLModule and don’t have their own ModuleID and own RLModule within self._rl_modules.

MultiRLModuleSpec.rl_module_specs: RLModuleSpec | Dict[str, RLModuleSpec] = None

The module specs for each individual module. Either a single RLModuleSpec used for all module IDs, or a dictionary mapping module IDs to RLModuleSpecs.
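
For example, a hedged sketch of a two-policy spec, reusing the single-agent spec pattern from above; the module IDs, spaces, and module class are illustrative:

import gymnasium as gym

from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

env = gym.make("CartPole-v1")

# One sub-RLModule per policy; both share the same (illustrative) spaces.
single_spec = RLModuleSpec(
    module_class=DefaultPPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
)
spec = MultiRLModuleSpec(
    rl_module_specs={"policy_1": single_spec, "policy_2": single_spec},
)
multi_rl_module = spec.build()

# Build only a single sub-RLModule under its ModuleID.
policy_1_module = spec.build(module_id="policy_1")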

DefaultModelConfig#

DefaultModelConfig

Dataclass to configure all default RLlib RLModules.
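
For example, a small sketch; the field values are illustrative:

from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

# Two hidden layers of 256 units each for the default fully connected net.
model_config = DefaultModelConfig(
    fcnet_hiddens=[256, 256],
    fcnet_activation="relu",
)

You can pass such a config to an RLModuleSpec (see above) or to AlgorithmConfig.rl_module(model_config=...).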

RLModule API#

Construction and setup#

RLModule

Base class for RLlib modules.

RLModule.observation_space

RLModule.action_space

RLModule.inference_only

RLModule.model_config

RLModule.setup

Sets up the components of the module.

RLModule.as_multi_rl_module

Returns a multi-agent wrapper around this module.
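
Continuing the spec example from above, a brief sketch of wrapping a built single-agent module; DEFAULT_MODULE_ID is RLlib's default ModuleID constant:

from ray.rllib.core import DEFAULT_MODULE_ID

# `rl_module` is the single-agent RLModule built earlier through its spec.
multi_module = rl_module.as_multi_rl_module()

# The wrapped module now lives under the default ModuleID.
default_module = multi_module[DEFAULT_MODULE_ID]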

Forward methods#

Use the following three forward methods when you call an RLModule from inside other classes and components. However, do NOT override them; leave them as-is in your custom subclasses. To define your own forward behavior, override the private method _forward (generic forward behavior for all phases) or, for more granularity, _forward_exploration, _forward_inference, and _forward_train.

forward_exploration

DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler.

forward_inference

DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler.

forward_train

DO NOT OVERRIDE! Forward-pass during training, called from the learner.

Override these private methods to define your custom model's forward behavior (see the sketch after this list):
  • _forward: generic forward behavior for all phases
  • _forward_exploration: for training sample collection
  • _forward_inference: for production deployments, greedy acting
  • _forward_train: for computing loss function inputs

_forward

Generic forward pass method, used in all phases of training and evaluation.

_forward_exploration

Forward-pass used for action computation with exploration behavior.

_forward_inference

Forward-pass used for action computation without exploration behavior.

_forward_train

Forward-pass used before the loss computation (training).
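
For example, a minimal sketch of a custom Torch module defining a single generic _forward for all phases. The layer sizes are illustrative and the module assumes a flat Box observation space and a discrete action space:

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyTorchRLModule(TorchRLModule):
    def setup(self):
        obs_dim = self.observation_space.shape[0]
        num_actions = self.action_space.n
        self._pi = torch.nn.Sequential(
            torch.nn.Linear(obs_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, num_actions),
        )

    def _forward(self, batch, **kwargs):
        # Used by exploration, inference, and training alike, unless the
        # more specific _forward_* methods are also overridden.
        logits = self._pi(batch[Columns.OBS])
        return {Columns.ACTION_DIST_INPUTS: logits}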

Saving and restoring#

save_to_path

Saves the state of the implementing class (or a given state) to path.

restore_from_path

Restores the state of the implementing class from the given path.

from_checkpoint

Creates a new Checkpointable instance from the given location and returns it.

get_state

Returns the state dict of the module.

set_state

Sets the implementing class' state to the given state dict.
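
Continuing the earlier spec example, a hedged sketch of checkpointing a module and restoring it; the temporary path is illustrative:

import tempfile

path = tempfile.mkdtemp()

# Save the module's architecture and weights to `path` ...
rl_module.save_to_path(path)

# ... then restore the state into another, compatible module ...
new_module = spec.build()
new_module.restore_from_path(path)

# ... or re-create the module entirely from the checkpoint.
restored = type(rl_module).from_checkpoint(path)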

MultiRLModule API#

Constructor#

MultiRLModule

Base class for an RLModule that contains n sub-RLModules.

MultiRLModule.setup

Sets up the underlying, individual RLModules.

MultiRLModule.as_multi_rl_module

Returns self in order to match RLModule.as_multi_rl_module() behavior.

Modifying the underlying RLModules#

add_module

Adds a module at runtime to the multi-agent module.

remove_module

Removes a module at runtime from the multi-agent module.
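
For example, continuing the MultiRLModuleSpec sketch from above, with an illustrative new ModuleID:

# Add a new sub-RLModule under a new ModuleID at runtime ...
multi_rl_module.add_module("policy_3", single_spec.build())

# ... and remove it again.
multi_rl_module.remove_module("policy_3")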

Saving and restoring#

save_to_path

Saves the state of the implementing class (or a given state) to path.

restore_from_path

Restores the state of the implementing class from the given path.

from_checkpoint

Creates a new Checkpointable instance from the given location and returns it.

get_state

Returns the state dict of the module.

set_state

Sets the state of the multi-agent module.

Additional RLModule APIs#

InferenceOnlyAPI#

class ray.rllib.core.rl_module.apis.inference_only_api.InferenceOnlyAPI

An API to be implemented by RLModules that have an inference-only mode.

Only the get_non_inference_attributes method needs to be implemented for an RLModule to gain the following functionality:
  • On EnvRunners (or when self.inference_only=True), RLlib removes those parts of the model not required for action computation.
  • An RLModule on a Learner (where self.inference_only=False) returns only those weights from get_state() that are part of its inference-only version, thus possibly saving network traffic/time.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

abstract get_non_inference_attributes() → List[str]

Returns a list of attribute names (str) of components NOT used for inference.

RLlib uses this information to remove those attributes/components from an RLModule whose config.inference_only is set to True, activating the so-called "inference-only setup". Normally, RLlib constructs all RLModules located on EnvRunners this way, because they're only used for computing actions. Similarly, when deploying into a production environment, consider building your RLModules with this flag set to True as well.

For example:

from ray.rllib.core.rl_module.rl_module import RLModuleSpec

spec = RLModuleSpec(module_class=..., inference_only=True)

If an RLModule has the following setup() implementation:

class MyRLModule(RLModule):

    def setup(self):
        self._policy_head = ...          # some NN component
        self._value_function_head = ...  # some NN component
        # Some NN component with sub-attributes `pol` and `vf`
        # (the policy- and the value function encoders).
        self._encoder = ...

Then its get_non_inference_attributes() should return: ["_value_function_head", "_encoder.vf"].

Note the "." notation used to separate attributes and their sub-attributes in case you need more fine-grained control over which exact sub-attributes to exclude from the inference-only setup.

Returns:

A list of names (str) of those attributes (or sub-attributes) that should be excluded (deleted) from this RLModule in case it’s setup in inference_only mode.
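
Putting this together, a sketch of the corresponding override; the attribute names match the illustrative setup() above:

from typing import List

from ray.rllib.core.rl_module.apis import InferenceOnlyAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyRLModule(TorchRLModule, InferenceOnlyAPI):
    # ... setup() as sketched above ...

    def get_non_inference_attributes(self) -> List[str]:
        # Drop the value head entirely, and only the `vf` branch
        # of the shared encoder, in inference-only setups.
        return ["_value_function_head", "_encoder.vf"]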

QNetAPI#

class ray.rllib.core.rl_module.apis.q_net_api.QNetAPI

An API to be implemented by RLModules used for (distributional) Q-learning.

RLModules implementing this API must override the compute_q_values and the compute_advantage_distribution methods.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

abstract compute_q_values(batch: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]) → Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]

Computes Q-values, given the encoder, the Q-net, and (optionally) the advantage net.

Note that these can be accompanied by logits and probabilities in the case of distributional Q-learning, i.e. when self.num_atoms > 1.

Parameters:

batch – The batch received in the forward pass.

Returns:

A dictionary containing the Q-value predictions ("qf_preds") and, in the case of distributional Q-learning, additionally the support atoms ("atoms"), the Q-logits ("qf_logits"), and the probabilities ("qf_probs").

compute_advantage_distribution(batch: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]) → Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]

Computes the advantage distribution.

Note that this distribution is identical to the Q-distribution if no dueling architecture is used.

Parameters:

batch – A dictionary containing a tensor with the outputs of the forward pass of the Q-head or advantage stream head.

Returns:

A dict containing the support atoms of the discrete distribution ("atoms") for either Q-values or advantages (in the case of a dueling architecture), the logits per action and atom, and the probabilities of the discrete distribution (per action and atom of the support).
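
As a rough illustration, the following sketch shows what a distributional advantage stream could return. The "af_preds" input key, the tensor shapes, and the atom settings are assumptions for illustration, not RLlib's fixed interface:

import torch

# Hypothetical distributional settings.
NUM_ACTIONS, NUM_ATOMS, V_MIN, V_MAX = 4, 51, -10.0, 10.0
ATOMS = torch.linspace(V_MIN, V_MAX, NUM_ATOMS)


def compute_advantage_distribution(batch):
    # `batch` is assumed to hold the raw advantage-head outputs of shape
    # [B, NUM_ACTIONS * NUM_ATOMS] under a hypothetical "af_preds" key.
    logits = batch["af_preds"].reshape(-1, NUM_ACTIONS, NUM_ATOMS)
    probs = torch.softmax(logits, dim=-1)  # per action and atom
    return {"atoms": ATOMS, "logits": logits, "probs": probs}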

SelfSupervisedLossAPI#

class ray.rllib.core.rl_module.apis.self_supervised_loss_api.SelfSupervisedLossAPI

An API to be implemented by RLModules that bring their own self-supervised loss.

Learners call the model's compute_self_supervised_loss() method instead of the Learner's own compute_loss_for_module() method. The call signature is identical to that of Learner.compute_loss_for_module(), except for an additional mandatory learner kwarg.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

abstract compute_self_supervised_loss(*, learner: Learner, module_id: str, config: AlgorithmConfig, batch: Dict[str, Any], fwd_out: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor

Computes the loss for a single module.

Think of this as computing the loss for a single agent. For multi-agent use cases that require a more complicated loss computation, consider overriding the compute_losses method instead.

Parameters:
  • learner – The Learner calling this loss method on the RLModule.

  • module_id – The ID of the RLModule (within a MultiRLModule).

  • config – The AlgorithmConfig specific to the given module_id.

  • batch – The sample batch for this particular RLModule.

  • fwd_out – The output of the forward pass for this particular RLModule.

Returns:

A single total loss tensor. If you have more than one optimizer on the provided module_id and would like to compute gradients separately using these different optimizers, simply add up the individual loss terms for each optimizer and return the sum. Also, to record or log any individual loss terms, you can use the Learner.metrics.log_value(key=..., value=...) or Learner.metrics.log_dict() APIs. See MetricsLogger for more information.
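
For instance, a hedged sketch of a module that learns a simple observation reconstruction loss; the network sizes, the "reconstruction" key, and the logged metric name are illustrative:

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyAutoEncoderModule(TorchRLModule, SelfSupervisedLossAPI):
    def setup(self):
        obs_dim = self.observation_space.shape[0]
        self._encoder = torch.nn.Linear(obs_dim, 16)
        self._decoder = torch.nn.Linear(16, obs_dim)

    def _forward(self, batch, **kwargs):
        z = self._encoder(batch[Columns.OBS])
        return {"reconstruction": self._decoder(z)}

    def compute_self_supervised_loss(
        self, *, learner, module_id, config, batch, fwd_out, **kwargs
    ):
        # Mean-squared reconstruction error as the self-supervised loss.
        loss = torch.mean((fwd_out["reconstruction"] - batch[Columns.OBS]) ** 2)
        # Log the individual term through the Learner's MetricsLogger.
        learner.metrics.log_value(key=(module_id, "recon_loss"), value=loss)
        return loss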

TargetNetworkAPI#

class ray.rllib.core.rl_module.apis.target_network_api.TargetNetworkAPI

An API to be implemented by RLModules for handling target networks.

RLModules implementing this API must override the make_target_networks, get_target_network_pairs, and forward_target methods.

Note that the respective Learner that owns the implementing RLModule handles all target syncing logic.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

abstract make_target_networks() → None

Creates the required target nets for this RLModule.

Use the convenience ray.rllib.core.learner.utils.make_target_network() utility when implementing this method. Pass in the already existing, corresponding "main" net (for which you need a target net). The utility takes care of initializing the target net from the "main" net.

abstract get_target_network_pairs() → List[Tuple[torch.nn.Module | tf.keras.Model, torch.nn.Module | tf.keras.Model]]

Returns a list of 2-tuples of (main_net, target_net).

For example, if your RLModule has a property self.q_net with a corresponding target net self.target_q_net, return [(self.q_net, self.target_q_net)] from this (overridden) method.

Note that you need to create all target nets in your overridden make_target_networks method and store them in any property of your choice.

Returns:

A list of 2-tuples of (main_net, target_net)

abstract forward_target(batch: Dict[str, Any]) → Dict[str, Any]

Performs the forward pass through the target net(s).

Parameters:

batch – The batch to use for the forward pass.

Returns:

The results from the forward pass(es) through the target net(s).
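
A compact sketch of the three overrides for a module with a single Q-net; the attribute and output key names are illustrative:

from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.utils import make_target_network
from ray.rllib.core.rl_module.apis import TargetNetworkAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyQModule(TorchRLModule, TargetNetworkAPI):
    # ... setup() creates self._q_net ...

    def make_target_networks(self):
        # Creates and initializes the target net from the "main" net.
        self._target_q_net = make_target_network(self._q_net)

    def get_target_network_pairs(self):
        return [(self._q_net, self._target_q_net)]

    def forward_target(self, batch):
        # Illustrative output key.
        return {"target_q": self._target_q_net(batch[Columns.OBS])}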

ValueFunctionAPI#

class ray.rllib.core.rl_module.apis.value_function_api.ValueFunctionAPI

An API to be implemented by RLModules for handling value function-based learning.

RLModules implementing this API must override the compute_values method.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

abstract compute_values(batch: Dict[str, Any], embeddings: Any | None = None) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor

Computes the value estimates given batch.

Parameters:
  • batch – The batch to compute value function estimates for.

  • embeddings – Optional embeddings already computed from the batch by another forward pass through the model's encoder (or another subcomponent that computes embeddings). For example, the caller of this method should provide embeddings, if available, to avoid duplicate passes through a shared encoder.

Returns:

A tensor of shape (B,) or (B, T), in case the input batch has a time dimension. Note that the last value dimension should already be squeezed out (not 1!).
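
For example, a sketch of a Torch module with a shared encoder that implements this API; the layer sizes, flat Box observation space, and discrete action space are assumptions:

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule


class MySharedEncoderModule(TorchRLModule, ValueFunctionAPI):
    def setup(self):
        obs_dim = self.observation_space.shape[0]
        self._encoder = torch.nn.Linear(obs_dim, 64)
        self._pi_head = torch.nn.Linear(64, self.action_space.n)
        self._vf_head = torch.nn.Linear(64, 1)

    def _forward(self, batch, **kwargs):
        embeddings = self._encoder(batch[Columns.OBS])
        return {
            Columns.ACTION_DIST_INPUTS: self._pi_head(embeddings),
            Columns.EMBEDDINGS: embeddings,
        }

    def compute_values(self, batch, embeddings=None):
        # Reuse embeddings, if the caller already computed them, to avoid
        # a duplicate pass through the shared encoder.
        if embeddings is None:
            embeddings = self._encoder(batch[Columns.OBS])
        # Squeeze out the last dimension -> shape (B,).
        return self._vf_head(embeddings).squeeze(-1)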