RL Modules#

Note

Ray 2.40 uses RLlib’s new API stack by default. The Ray team has mostly completed transitioning algorithms, example scripts, and documentation to the new code base.

If you’re still using the old API stack, see New API stack migration guide for details on how to migrate.

The RLModule class in RLlib’s new API stack allows you to write custom models, including highly complex multi-network setups often found in multi-agent or model-based algorithms.

RLModule is the main neural network class and exposes three public methods, each corresponding to a distinct phase of the reinforcement learning cycle:

  • forward_exploration() handles the computation of actions during data collection, balancing exploration and exploitation, if RLlib uses the data for a succeeding training step.

  • forward_inference() computes actions for evaluation and production, which often need to be greedy or less stochastic.

  • forward_train() manages the training phase, performing the calculations required to compute losses, such as Q-values in a DQN model, value function predictions in a PG-style setup, or world-model predictions in model-based algorithms.

../_images/rl_module_overview.svg

RLModule overview: (left) A plain RLModule contains the neural network RLlib uses for computations, for example, a policy network written in PyTorch, and exposes the three forward methods: forward_exploration() for sample collection, forward_inference() for production/deployment, and forward_train() for computing loss function inputs when training. (right) A MultiRLModule may contain one or more sub-RLModules, each identified by a ModuleID, allowing you to implement arbitrarily complex multi-network or multi-agent architectures and algorithms.#

Enabling the RLModule API in the AlgorithmConfig#

In the new API stack, activated by default, RLlib exclusively uses RLModules.

If you’re working with a legacy config or want to migrate ModelV2 or Policy classes to the new API stack, see the new API stack migration guide for more information.

If you configured the Algorithm to use the old API stack, use the api_stack() method to switch back to the new API stack:

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

config = (
    AlgorithmConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
)

Default RLModules#

If you don’t specify module-related settings in the AlgorithmConfig, RLlib uses the respective algorithm’s default RLModule, which is an appropriate choice for initial experimentation and benchmarking. All of the default RLModules support 1D-tensor and image observations ([width] x [height] x [channels]).

Note

For discrete or more complex input observation spaces like dictionaries, use the FlattenObservations connector piece as follows:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.connectors.env_to_module import FlattenObservations

config = (
    PPOConfig()
    # FrozenLake has a discrete observation space, ...
    .environment("FrozenLake-v1")
    # ... which `FlattenObservations` converts to one-hot.
    .env_runners(env_to_module_connector=lambda env: FlattenObservations())
)

Furthermore, all default models offer configurable architecture choices with respect to the number and size of the layers used (Dense or Conv2D), their activations and initializations, and automatic LSTM-wrapping behavior.

Use the DefaultModelConfig dataclass to configure any default model in RLlib. Note that you should only use this class for configuring default models. When writing your own custom RLModules, use plain Python dicts to define the model configurations. For how to write and configure your custom RLModules, see Implementing custom RLModules.

Configuring default MLP nets#

To train a simple multi-layer perceptron (MLP) policy, which only contains dense layers, with PPO and the default RLModule, configure your experiment as follows:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(
        # Use a non-default 32,32-stack with ReLU activations.
        model_config=DefaultModelConfig(
            fcnet_hiddens=[32, 32],
            fcnet_activation="relu",
        )
    )
)

The following is the complete list of all supported fcnet_.. options:

    #: List containing the sizes (number of nodes) of a fully connected (MLP) stack.
    #: Note that in an encoder-based default architecture with a policy head (and
    #: possible value head), this setting only affects the encoder component. To set the
    #: policy (and value) head sizes, use `post_fcnet_hiddens`, instead. For example,
    #: if you set `fcnet_hiddens=[32, 32]` and `post_fcnet_hiddens=[64]`, you would get
    #: an RLModule with a [32, 32] encoder, a [64, act-dim] policy head, and a [64, 1]
    #: value head (if applicable).
    fcnet_hiddens: List[int] = field(default_factory=lambda: [256, 256])
    #: Activation function descriptor for the stack configured by `fcnet_hiddens`.
    #: Supported values are: 'tanh', 'relu', 'swish' (or 'silu', which is the same),
    #: and 'linear' (or None).
    fcnet_activation: str = "tanh"
    #: Initializer function or class descriptor for the weight/kernel matrices in the
    #: stack configured by `fcnet_hiddens`. Supported values are the initializer names
    #: (str), classes or functions listed by the frameworks (`torch`). See
    #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default),
    #: the default initializer defined by `torch` is used.
    fcnet_kernel_initializer: Optional[Union[str, Callable]] = None
    #: Kwargs passed into the initializer function defined through
    #: `fcnet_kernel_initializer`.
    fcnet_kernel_initializer_kwargs: Optional[dict] = None
    #: Initializer function or class descriptor for the bias vectors in the stack
    #: configured by `fcnet_hiddens`. Supported values are the initializer names (str),
    #: classes or functions listed by the frameworks (`torch`). See
    #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default),
    #: the default initializer defined by `torch` is used.
    fcnet_bias_initializer: Optional[Union[str, Callable]] = None
    #: Kwargs passed into the initializer function defined through
    #: `fcnet_bias_initializer`.
    fcnet_bias_initializer_kwargs: Optional[dict] = None
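
For example, the following is a hedged sketch of a config that also sets a custom weight initializer by its torch name; the particular initializer string, "orthogonal_", and its gain kwarg are assumptions based on the torch.nn.init documentation linked above:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(
        model_config=DefaultModelConfig(
            fcnet_hiddens=[64, 64],
            fcnet_activation="relu",
            # Initialize the kernel matrices with `torch.nn.init.orthogonal_` and
            # pass its `gain` kwarg through (assumed initializer name and kwargs).
            fcnet_kernel_initializer="orthogonal_",
            fcnet_kernel_initializer_kwargs={"gain": 1.0},
        )
    )
)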

Configuring default CNN nets#

For image-based environments like Atari, use the conv_.. fields in DefaultModelConfig to configure the convolutional neural network (CNN) stack.

For example:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

config = (
    PPOConfig()
    .environment("ale_py:ALE/Pong-v5")  # `pip install gymnasium[atari]`
    .rl_module(
        model_config=DefaultModelConfig(
            # Use a DreamerV3-style CNN stack.
            conv_filters=[
                [16, 4, 2],  # 1st CNN layer: num_filters, kernel, stride(, padding)?
                [32, 4, 2],  # 2nd CNN layer
                [64, 4, 2],  # etc..
                [128, 4, 2],
            ],
            conv_activation="silu",

            # After the last CNN, the default model flattens, then adds an optional MLP.
            head_fcnet_hiddens=[256],
        )
    )
)

The following is the complete list of all supported conv_.. options:

    #: List of lists of format [num_out_channels, kernel, stride] defining a Conv2D
    #: stack if the input space is 2D. Each item in the outer list represents one Conv2D
    #: layer. `kernel` and `stride` may be single ints (width and height have same
    #: value) or 2-tuples (int, int) specifying width and height dimensions separately.
    #: If None (default) and the input space is 2D, RLlib tries to find a default filter
    #: setup given the exact input dimensions.
    conv_filters: Optional[ConvFilterSpec] = None
    #: Activation function descriptor for the stack configured by `conv_filters`.
    #: Supported values are: 'tanh', 'relu', 'swish' (or 'silu', which is the same), and
    #: 'linear' (or None).
    conv_activation: str = "relu"
    #: Initializer function or class descriptor for the weight/kernel matrices in the
    #: stack configured by `conv_filters`. Supported values are the initializer names
    #: (str), classes or functions listed by the frameworks (`torch`). See
    #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default),
    #: the default initializer defined by `torch` is used.
    conv_kernel_initializer: Optional[Union[str, Callable]] = None
    #: Kwargs passed into the initializer function defined through
    #: `conv_kernel_initializer`.
    conv_kernel_initializer_kwargs: Optional[dict] = None
    #: Initializer function or class descriptor for the bias vectors in the stack
    #: configured by `conv_filters`. Supported values are the initializer names (str),
    #: classes or functions listed by the frameworks (`torch`). See
    #: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default),
    #: the default initializer defined by `torch` is used.
    conv_bias_initializer: Optional[Union[str, Callable]] = None
    #: Kwargs passed into the initializer function defined through
    #: `conv_bias_initializer`.
    conv_bias_initializer_kwargs: Optional[dict] = None

Other default model settings#

For LSTM-based configurations and specific settings for continuous action output layers, see DefaultModelConfig.
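
The following is a brief, hedged sketch of what such an LSTM configuration could look like; the field names use_lstm, lstm_cell_size, max_seq_len, lstm_use_prev_action, and lstm_use_prev_reward are assumptions you should verify against DefaultModelConfig:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(
        model_config=DefaultModelConfig(
            # Wrap the default encoder with an LSTM.
            use_lstm=True,
            lstm_cell_size=256,
            max_seq_len=20,
            # Optionally feed the previous action and reward into the LSTM.
            lstm_use_prev_action=True,
            lstm_use_prev_reward=True,
        )
    )
)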

Constructing RLModule instances#

To maintain consistency and usability, RLlib offers a standardized approach for constructing RLModule instances for both single-module and multi-module use cases. An example of a single-module use case is a single-agent experiment. Examples of multi-module use cases are multi-agent learning or other multi-NN setups.

Construction through the class constructor#

The most direct way to construct your RLModule is through its constructor:

import gymnasium as gym
from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import DefaultBCTorchRLModule

# Create an env object to know the spaces.
env = gym.make("CartPole-v1")

# Construct the actual RLModule object.
rl_module = DefaultBCTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    # A custom dict that's accessible inside your class as `self.model_config`.
    model_config={"fcnet_hiddens": [64]},
)
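
Once constructed, you can call the module's three public forward methods directly on a batch of observations. The following is a minimal sketch that reuses env and rl_module from the preceding snippet; the exact keys in the returned dicts depend on the module implementation, for example Columns.ACTION_DIST_INPUTS for RLlib's default modules:

import numpy as np
import torch

from ray.rllib.core.columns import Columns

# Build a small batch of observations sampled from the env's observation space.
obs_batch = torch.from_numpy(
    np.stack([env.observation_space.sample() for _ in range(2)])
)

# Compute forward passes for the three phases of the RL cycle.
print(rl_module.forward_inference({Columns.OBS: obs_batch}))
print(rl_module.forward_exploration({Columns.OBS: obs_batch}))
print(rl_module.forward_train({Columns.OBS: obs_batch}))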

Note

If you have a checkpoint of an Algorithm or an individual RLModule, see Creating instances with from_checkpoint for how to recreate your RLModule from disk.

Construction through RLModuleSpecs#

Because RLlib is a distributed RL library and needs to create more than one copy of your RLModule, you can use RLModuleSpec objects to define how RLlib should construct each copy during the algorithm’s setup process. The algorithm passes the spec to all subcomponents that need to have a copy of your RLModule.

Creating an RLModuleSpec is straightforward and analogous to the RLModule constructor:

import gymnasium as gym
from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import DefaultBCTorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

# Create an env object to know the spaces.
env = gym.make("CartPole-v1")

# First construct the spec.
spec = RLModuleSpec(
    module_class=DefaultBCTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    # A custom dict that's accessible inside your class as `self.model_config`.
    model_config={"fcnet_hiddens": [64]},
)

# Then, build the RLModule through the spec's `build()` method.
rl_module = spec.build()

To build a MultiRLModule that contains more than one sub-RLModule, for example for multi-agent setups, use a MultiRLModuleSpec instead:

import gymnasium as gym
from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import DefaultBCTorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec

# First construct the MultiRLModuleSpec.
spec = MultiRLModuleSpec(
    rl_module_specs={
        "module_1": RLModuleSpec(
            module_class=DefaultBCTorchRLModule,

            # Define the spaces for only this sub-module.
            observation_space=gym.spaces.Box(low=-1, high=1, shape=(10,)),
            action_space=gym.spaces.Discrete(2),

            # A custom dict that's accessible inside your class as
            # `self.model_config`.
            model_config={"fcnet_hiddens": [32]},
        ),
        "module_2": RLModuleSpec(
            module_class=DefaultBCTorchRLModule,

            # Define the spaces for only this sub-module.
            observation_space=gym.spaces.Box(low=-1, high=1, shape=(5,)),
            action_space=gym.spaces.Discrete(2),

            # A custom dict that's accessible inside your class as
            # `self.model_config`.
            model_config={"fcnet_hiddens": [16]},
        ),
    },
)

# Construct the actual MultiRLModule instance with .build():
multi_rl_module = spec.build()
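
The built MultiRLModule holds the two sub-RLModules under their ModuleIDs. A short usage sketch, assuming MultiRLModule supports dict-style access by ModuleID:

# Access the individual sub-RLModules through their ModuleIDs.
print(multi_rl_module["module_1"])
print(multi_rl_module["module_2"])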

You can pass the RLModuleSpec instances to your AlgorithmConfig to tell RLlib to use the particular module class and constructor arguments:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=MyRLModuleClass,
            model_config={"some_key": "some_setting"},
        ),
    )
)
ppo = config.build()
print(ppo.get_module())

Note

Often, when creating an RLModuleSpec, you don't have to define attributes like observation_space or action_space because RLlib automatically infers these attributes based on the used environment or other configuration parameters.

For example, in a multi-agent setup in which all agents share the same RLModule, pass a single RLModuleSpec as the rl_module_specs argument of the MultiRLModuleSpec:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole

config = (
    PPOConfig()
    .environment(MultiAgentCartPole, env_config={"num_agents": 2})
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            # All agents (0 and 1) use the same (single) RLModule.
            rl_module_specs=RLModuleSpec(
                module_class=MyRLModuleClass,
                model_config={"some_key": "some_setting"},
            )
        ),
    )
)
ppo = config.build()
print(ppo.get_module())

To give each agent its own RLModule, possibly with different configurations, pass a dict mapping ModuleIDs to RLModuleSpec instances instead:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole

config = (
    PPOConfig()
    .environment(MultiAgentCartPole, env_config={"num_agents": 2})
    .multi_agent(
        policies={"p0", "p1"},
        # Agent IDs of `MultiAgentCartPole` are 0 and 1, mapping to
        # "p0" and "p1", respectively.
        policy_mapping_fn=lambda agent_id, episode, **kw: f"p{agent_id}"
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            # Agents (0 and 1) use different (single) RLModules.
            rl_module_specs={
                "p0": RLModuleSpec(
                    module_class=MyRLModuleClass,
                    # Small network.
                    model_config={"fcnet_hiddens": [32, 32]},
                ),
                "p1": RLModuleSpec(
                    module_class=MyRLModuleClass,
                    # Large network.
                    model_config={"fcnet_hiddens": [128, 128]},
                ),
            },
        ),
    )
)
ppo = config.build()
print(ppo.get_module())

Implementing custom RLModules#

To implement your own neural network architecture and computation logic, subclass TorchRLModule for any single-agent learning experiment or for independent multi-agent learning.

For more advanced multi-agent use cases like ones with shared communication between agents, or any multi-model use cases, subclass the MultiRLModule class, instead.

Note

An alternative to subclassing TorchRLModule is to directly subclass your Algorithm's default RLModule. For example, to use PPO, you can subclass DefaultPPOTorchRLModule. In this case, you should carefully study the existing default model to understand how to override the setup() method, the _forward_..() methods, and possibly some algo-specific API methods. See Algorithm-specific RLModule APIs for how to determine which APIs your algorithm requires you to implement.

The setup() method#

You should first implement the setup() method, in which you add needed NN subcomponents and assign these to class attributes of your choice.

Note that you should call super().setup() in your implementation.

You also have access to the following attributes anywhere in the class, including in setup():

  1. self.observation_space

  2. self.action_space

  3. self.inference_only

  4. self.model_config (a dict with any custom config settings)

import torch
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule

class MyTorchPolicy(TorchRLModule):

    def setup(self):
        # You have access here to the following already set attributes:
        # self.observation_space
        # self.action_space
        # self.inference_only
        # self.model_config  # <- a dict with custom settings

        # Call the super's setup().
        super().setup()

        # Use the observation space (if a Box) to infer the input dimension.
        input_dim = self.observation_space.shape[0]

        # Use the model_config dict to extract the hidden dimension.
        hidden_dim = self.model_config["fcnet_hiddens"][0]

        # Use the action space to infer the number of output nodes.
        output_dim = self.action_space.n

        # Build all the layers and subcomponents here you need for the
        # RLModule's forward passes.
        self._pi_head = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )

Forward methods#

To implement the forward computation logic, you can either define a generic forward behavior by overriding the private _forward() method, which RLlib then uses everywhere in the model's lifecycle, or, if you require more granularity, define the following three private methods: _forward_exploration(), _forward_inference(), and _forward_train().

For custom _forward(), _forward_inference(), and _forward_exploration() methods, you must return a dictionary that contains the key actions and/or the key action_dist_inputs.

If you return the actions key from your forward method:

  • RLlib uses the provided actions as-is.

  • In case you also return the action_dist_inputs key, RLlib creates a Distribution instance from the parameters under that key. In the case of forward_exploration(), RLlib also automatically computes action probabilities and log probabilities for the given actions. See Custom action distributions for more information on custom action distribution classes.

If you don’t return the actions key from your forward method:

  • You must return the action_dist_inputs key instead. RLlib creates a Distribution instance from the parameters under that key and samples actions from that distribution.

  • In the case of _forward_exploration(), the sampling is stochastic, and RLlib also automatically computes action probabilities and log probabilities for the sampled actions.

  • In the case of _forward_inference(), the sampling is deterministic. See the following note.

Note

In the case of _forward_inference(), RLlib always first makes the distribution generated from the returned action_dist_inputs key deterministic through the to_deterministic() utility before a possible action sampling step. For example, sampling from a Categorical distribution then reduces to selecting the argmax actions from the distribution's logits or probabilities. If you return the actions key, RLlib skips that sampling step.

from ray.rllib.core import Columns, TorchRLModule

class MyTorchPolicy(TorchRLModule):
    ...

    def _forward_inference(self, batch):
        ...
        return {
            Columns.ACTIONS: ...  # RLlib uses these actions as-is
        }

    def _forward_exploration(self, batch):
        ...
        return {
            Columns.ACTIONS: ...,  # RLlib uses these actions as-is (no sampling step!)
            Columns.ACTION_DIST_INPUTS: ...,  # If provided, RLlib uses these dist inputs to compute probs and logp.
        }

If you return only the action_dist_inputs key, RLlib creates the action distribution from the returned parameters and handles the sampling step for you:

from ray.rllib.core import Columns, TorchRLModule

class MyTorchPolicy(TorchRLModule):
    ...

    def _forward_inference(self, batch):
        ...
        return {
            # RLlib:
            # - Generates distribution from ACTION_DIST_INPUTS parameters.
            # - Converts distribution to a deterministic equivalent.
            # - Samples from the deterministic distribution.
            Columns.ACTION_DIST_INPUTS: ...
        }

    def _forward_exploration(self, batch):
        ...
        return {
            # RLlib:
            # - Generates distribution from ACTION_DIST_INPUTS parameters.
            # - Samples from the stochastic distribution.
            # - Computes action probs and logs automatically using the sampled
            #   actions and the distribution.
            Columns.ACTION_DIST_INPUTS: ...
        }

Never override the constructor (__init__). However, note that the RLModule class's constructor requires the following arguments and also receives these properly when you call a spec's build() method:

  • observation_space: The observation space after having passed all connectors; this observation space is the actual input space for the model after all preprocessing steps.

  • action_space: The action space of the environment.

  • inference_only: Whether RLlib should build the RLModule in inference-only mode, dropping subcomponents that it only needs for learning.

  • model_config: The model config, which is either a custom dictionary for custom RLModules or a DefaultModelConfig dataclass object, which is only for RLlib’s default models. Define model hyper-parameters such as number of layers, type of activation, etc. in this object.

See Construction through the class constructor for more details on how to create an RLModule through the constructor.

Algorithm-specific RLModule APIs#

The algorithm that you choose to use with your RLModule affects, to some extent, the structure of the final custom module. Each Algorithm class has a fixed set of APIs that all RLModules trained by that algorithm need to implement.

To find out which APIs your Algorithm requires, do the following:

# Import the config of the algorithm of your choice.
from ray.rllib.algorithms.sac import SACConfig

# Print out the abstract APIs you need to subclass from and whose
# abstract methods you need to implement, besides the ``setup()`` and ``_forward_..()``
# methods.
print(
    SACConfig()
    .get_default_learner_class()
    .rl_module_required_apis()
)

Note

The VPG example module in the following end-to-end example doesn't need to implement any APIs because you haven't considered training it with any particular algorithm. You can find examples of algorithm-ready PPO custom RLModules in the tiny_atari_cnn_rlm example and in the lstm_containing_rlm example.

End-to-end example#

Putting together the elements of your custom RLModule that you implemented, a working end-to-end example is as follows:

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule


class VPGTorchRLModule(TorchRLModule):
    """A simple VPG (vanilla policy gradient)-style RLModule for testing purposes.

    Use this as a minimum, bare-bones example implementation of a custom TorchRLModule.
    """

    def setup(self):
        # You have access here to the following already set attributes:
        # self.observation_space
        # self.action_space
        # self.inference_only
        # self.model_config  # <- a dict with custom settings
        input_dim = self.observation_space.shape[0]
        hidden_dim = self.model_config["hidden_dim"]
        output_dim = self.action_space.n

        self._policy_net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )

    def _forward(self, batch, **kwargs):
        # Push the observations from the batch through our pi-head.
        action_logits = self._policy_net(batch[Columns.OBS])
        # Return parameters for the (default) action distribution, which is
        # `TorchCategorical` (due to our action space being `gym.spaces.Discrete`).
        return {Columns.ACTION_DIST_INPUTS: action_logits}

        # If you need more granularity between the different forward behaviors during
        # the different phases of the module's lifecycle, implement three different
        # forward methods. Thereby, it is recommended to put the inference and
        # exploration versions inside a `with torch.no_grad()` context for better
        # performance.
        # def _forward_train(self, batch):
        #    ...
        #
        # def _forward_inference(self, batch):
        #    with torch.no_grad():
        #        return self._forward_train(batch)
        #
        # def _forward_exploration(self, batch):
        #    with torch.no_grad():
        #        return self._forward_train(batch)

Custom action distributions#

The preceding examples rely on the RLModule using the correct action distribution class with the ACTION_DIST_INPUTS returned by the forward methods. RLlib picks a default distribution class based on the action space: TorchCategorical for Discrete action spaces and TorchDiagGaussian for Box action spaces.

To use a different distribution class and return parameters for its constructor from your forward methods, override the following methods in your RLModule implementation:

  • get_inference_action_dist_cls()

  • get_exploration_action_dist_cls()

  • get_train_action_dist_cls()

Note

If you only return ACTION_DIST_INPUTS from your forward methods, RLlib automatically calls to_deterministic() on the distribution created from the class that your get_inference_action_dist_cls() returns.

See torch_distributions.py for common distribution implementations.
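
For illustration, the following is a hedged sketch that overrides these methods so that all three phases use the same custom Categorical subclass; MyTorchCategorical is a hypothetical placeholder and the getter names follow the list above:

from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.torch_distributions import TorchCategorical


class MyTorchCategorical(TorchCategorical):
    """Hypothetical custom distribution, for example rescaling logits before sampling."""
    ...


class MyTorchPolicyWithCustomDist(TorchRLModule):
    ...

    # Return the distribution class RLlib should construct from the
    # ACTION_DIST_INPUTS that the forward methods return.
    def get_inference_action_dist_cls(self):
        return MyTorchCategorical

    def get_exploration_action_dist_cls(self):
        return MyTorchCategorical

    def get_train_action_dist_cls(self):
        return MyTorchCategorical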

Implementing custom MultiRLModules#

For multi-module setups, RLlib provides the MultiRLModule class, whose default implementation is a dictionary of individual RLModule objects, one for each submodule, each identified by a ModuleID.

The base-class MultiRLModule implementation works for most of the use cases that need to define independent neural networks. However, for any complex, multi-network or multi-agent use case, where agents share one or more neural networks, you should inherit from this class and override the default implementation.

The following code snippets create a custom multi-agent RL module with two simple “policy head” modules, which share the same encoder, the third network in the MultiRLModule. The encoder receives the raw observations from the env and outputs embedding vectors that then serve as input for the two policy heads to compute the agents’ actions.

from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule

# ModuleID under which this example stores the shared encoder in the MultiRLModule.
SHARED_ENCODER_ID = "shared_encoder"


class VPGMultiRLModuleWithSharedEncoder(MultiRLModule):
    """VPG (vanilla pol. gradient)-style MultiRLModule handling a shared encoder.
    """

    def setup(self):
        # Call the super's setup().
        super().setup()

        # Assert that we have the shared encoder submodule.
        assert (
            SHARED_ENCODER_ID in self._rl_modules
            and isinstance(self._rl_modules[SHARED_ENCODER_ID], SharedEncoder)
            and len(self._rl_modules) > 1
        )
        # Assign the encoder to a convenience attribute.
        self.encoder = self._rl_modules[SHARED_ENCODER_ID]

    def _forward(self, batch, **kwargs):
        # Collect our policies' outputs in this dict.
        outputs = {}

        # Loop through the policy nets (through the given batch's keys).
        for policy_id, policy_batch in batch.items():
            rl_module = self._rl_modules[policy_id]

            # Pass policy's observations through shared encoder to get the features for
            # this policy.
            policy_batch["encoder_embeddings"] = self.encoder._forward(batch[policy_id])

            # Pass the policy's embeddings through the policy net.
            outputs[policy_id] = rl_module._forward(batch[policy_id], **kwargs)

        return outputs


Within the MultiRLModule, you need to have two policy sub-RLModules. They may be of the same class, which you can implement as follows:

class VPGPolicyAfterSharedEncoder(TorchRLModule):
    """A VPG (vanilla pol. gradient)-style RLModule using a shared encoder.
    """

    def setup(self):
        super().setup()

        # Incoming feature dim from the shared encoder.
        embedding_dim = self.model_config["embedding_dim"]
        hidden_dim = self.model_config["hidden_dim"]

        self._pi_head = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, self.action_space.n),
        )

    def _forward(self, batch, **kwargs):
        # Embeddings can be found in the batch under the "encoder_embeddings" key.
        embeddings = batch["encoder_embeddings"]
        logits = self._pi_head(embeddings)
        return {Columns.ACTION_DIST_INPUTS: logits}


Finally, the shared encoder RLModule should look similar to this:

class SharedEncoder(TorchRLModule):
    """A shared encoder that can be used with `VPGMultiRLModuleWithSharedEncoder`."""

    def setup(self):
        super().setup()

        input_dim = self.observation_space.shape[0]
        embedding_dim = self.model_config["embedding_dim"]

        # A very simple encoder network.
        self._net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, embedding_dim),
        )

    def _forward(self, batch, **kwargs):
        # Pass observations through the net and return outputs.
        return {"encoder_embeddings": self._net(batch[Columns.OBS])}


To plug the custom MultiRLModule into your algorithm's config, create a MultiRLModuleSpec with the new class and its constructor settings. Also, create one RLModuleSpec for the shared encoder RLModule and one for each agent, because RLlib requires their observation and action spaces and their model hyper-parameters:

import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import MultiRLModuleSpec, RLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole

single_agent_env = gym.make("CartPole-v1")

EMBEDDING_DIM = 64  # encoder output dim

config = (
    PPOConfig()
    .environment(MultiAgentCartPole, env_config={"num_agents": 2})
    .multi_agent(
        # Declare the two policies trained.
        policies={"p0", "p1"},
        # Agent IDs of `MultiAgentCartPole` are 0 and 1. They are mapped to
        # the two policies with ModuleIDs "p0" and "p1", respectively.
        policy_mapping_fn=lambda agent_id, episode, **kw: f"p{agent_id}"
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs={
                # Shared encoder.
                SHARED_ENCODER_ID: RLModuleSpec(
                    module_class=SharedEncoder,
                    model_config={"embedding_dim": EMBEDDING_DIM},
                    observation_space=single_agent_env.observation_space,
                ),
                # Large policy net.
                "p0": RLModuleSpec(
                    module_class=VPGPolicyAfterSharedEncoder,
                    model_config={
                        "embedding_dim": EMBEDDING_DIM,
                        "hidden_dim": 1024,
                    },
                ),
                # Small policy net.
                "p1": RLModuleSpec(
                    module_class=VPGPolicyAfterSharedEncoder,
                    model_config={
                        "embedding_dim": EMBEDDING_DIM,
                        "hidden_dim": 64,
                    },
                ),
            },
        ),
    )
)
algo = config.build()
print(algo.get_module())

Note

In order to properly learn with the preceding setup, you should write and use a specific multi-agent Learner, capable of handling the shared encoder. This Learner should only have a single optimizer updating all three submodules, which are the encoder and the two policy nets, to stabilize learning. When using the standard “one-optimizer-per-module” Learners, however, the two optimizers for policy 1 and 2 would take turns updating the same shared encoder, which would lead to learning instabilities.

Checkpointing RLModules#

You can checkpoint RLModule instances with their save_to_path() method. If you already have an instantiated RLModule and would like to load new model weights into it from an existing checkpoint, use the restore_from_path() method.

The following examples show how you can use these methods outside of, or in conjunction with, an RLlib Algorithm.

Creating an RLModule checkpoint#

import tempfile

import gymnasium as gym

from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import DefaultPPOTorchRLModule
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

env = gym.make("CartPole-v1")

# Create an RLModule to later checkpoint.
rl_module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[32]),
)

# Finally, write the RLModule checkpoint.
module_ckpt_path = tempfile.mkdtemp()
rl_module.save_to_path(module_ckpt_path)

Creating an RLModule from an (RLModule) checkpoint#

If you have an RLModule checkpoint saved and would like to create a new RLModule directly from it, use the from_checkpoint() method:

from ray.rllib.core.rl_module.rl_module import RLModule

# Create a new RLModule from the checkpoint.
new_module = RLModule.from_checkpoint(module_ckpt_path)

Loading an RLModule checkpoint into a running Algorithm#

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

# Create a new Algorithm (with the changed module config: 32 units instead of the
# default 256; otherwise loading the state of ``module`` fails due to a shape
# mismatch).
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(model_config=DefaultModelConfig(fcnet_hiddens=[32]))
)
ppo = config.build()

Now you can load the saved RLModule state from the preceding rl_module.save_to_path() call directly into the running Algorithm's RLModules. Note that all RLModules within the Algorithm get updated: both the ones on the Learner workers and the ones on the EnvRunners.

ppo.restore_from_path(
    module_ckpt_path,  # <- NOT an Algorithm checkpoint, but single-agent RLModule one.

    # Therefore, we have to provide the exact path (of RLlib components) down
    # to the individual RLModule within the algorithm, which is:
    component="learner_group/learner/rl_module/default_policy",
)