RL Modules#
Note
Ray 2.40 uses RLlib’s new API stack by default. The Ray team has mostly completed transitioning algorithms, example scripts, and documentation to the new code base.
If you’re still using the old API stack, see New API stack migration guide for details on how to migrate.
The RLModule
class in RLlib’s new API stack allows you to write custom
models, including highly complex multi-network setups often found in multi-agent or model-based algorithms.
RLModule
is the main neural network class and exposes
three public methods, each corresponding to a distinct phase in the reinforcement learning cycle (see the usage sketch after this list):
- forward_exploration()
handles action computation during data collection, when RLlib uses the collected data
for a subsequent training step, balancing exploration and exploitation.
- forward_inference()
computes actions for evaluation and production, which often need to be greedy or less stochastic.
- forward_train()
manages the training phase, performing calculations required to
compute losses, such as Q-values in a DQN model, value function predictions in a PG-style setup,
or world-model predictions in model-based algorithms.
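The following is a minimal usage sketch of these three methods. It assumes PPO's default RLModule on CartPole, constructed directly through its constructor (see Constructing RLModule instances below); the exact keys in the returned dicts depend on the algorithm's default module.
import gymnasium as gym
import torch
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import DefaultPPOTorchRLModule
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

# Build PPO's default RLModule for CartPole.
env = gym.make("CartPole-v1")
rl_module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[32]),
)

# A dummy batch containing a single observation.
obs, _ = env.reset()
batch = {Columns.OBS: torch.from_numpy(obs).unsqueeze(0)}

# Data collection: may be stochastic to balance exploration and exploitation.
exploration_out = rl_module.forward_exploration(batch)
# Evaluation/production: typically greedy or less stochastic.
inference_out = rl_module.forward_inference(batch)
# Training: computes the quantities the loss needs, for example action
# distribution inputs and, for PPO, value predictions.
train_out = rl_module.forward_train(batch)
print(list(train_out.keys()))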
Enabling the RLModule API in the AlgorithmConfig#
In the new API stack, activated by default, RLlib exclusively uses RLModules.
If you’re working with a legacy config or want to migrate ModelV2
or Policy
classes to the
new API stack, see the new API stack migration guide for more information.
If you configured the Algorithm
to use the old API stack, use the
api_stack()
method to switch:
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
config = (
AlgorithmConfig()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
)
Default RLModules#
If you don’t specify module-related settings in the
AlgorithmConfig
, RLlib uses the respective algorithm’s default
RLModule, which is an appropriate choice for initial experimentation and benchmarking. All of the default RLModules support 1D-tensor and
image observations ([width] x [height] x [channels]
).
Note
For discrete or more complex input observation spaces like dictionaries, use the
FlattenObservations
connector
piece as follows:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.connectors.env_to_module import FlattenObservations
config = (
PPOConfig()
# FrozenLake has a discrete observation space, ...
.environment("FrozenLake-v1")
# ... which `FlattenObservations` converts to one-hot.
.env_runners(env_to_module_connector=lambda env: FlattenObservations())
)
Furthermore, all default models offer configurable architecture choices with respect to the number
and size of the layers used (Dense
or Conv2D
), their activations and initializations, and automatic LSTM-wrapping behavior.
Use the DefaultModelConfig
dataclass to configure
any default model in RLlib. Note that you should only use this class for configuring default models.
When writing your own custom RLModules, use plain Python dicts to define the model configurations.
For how to write and configure your custom RLModules, see Implementing custom RLModules.
Configuring default MLP nets#
To train a simple multilayer perceptron (MLP) policy, which contains only dense layers, with PPO and the default RLModule, configure your experiment as follows:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
config = (
PPOConfig()
.environment("CartPole-v1")
.rl_module(
# Use a non-default 32,32-stack with ReLU activations.
model_config=DefaultModelConfig(
fcnet_hiddens=[32, 32],
fcnet_activation="relu",
)
)
)
The following is the complete list of all supported fcnet_..
options:
#: List containing the sizes (number of nodes) of a fully connected (MLP) stack.
#: Note that in an encoder-based default architecture with a policy head (and
#: possible value head), this setting only affects the encoder component. To set the
#: policy (and value) head sizes, use `post_fcnet_hiddens`, instead. For example,
#: if you set `fcnet_hiddens=[32, 32]` and `post_fcnet_hiddens=[64]`, you would get
#: an RLModule with a [32, 32] encoder, a [64, act-dim] policy head, and a [64, 1]
#: value head (if applicable).
fcnet_hiddens: List[int] = field(default_factory=lambda: [256, 256])
#: Activation function descriptor for the stack configured by `fcnet_hiddens`.
#: Supported values are: 'tanh', 'relu', 'swish' (or 'silu', which is the same),
#: and 'linear' (or None).
fcnet_activation: str = "tanh"
#: Initializer function or class descriptor for the weight/kernel matrices in the
#: stack configured by `fcnet_hiddens`. Supported values are the initializer names
#: (str), classes or functions listed by the frameworks (`torch`). See
#: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default),
#: the default initializer defined by `torch` is used.
fcnet_kernel_initializer: Optional[Union[str, Callable]] = None
#: Kwargs passed into the initializer function defined through
#: `fcnet_kernel_initializer`.
fcnet_kernel_initializer_kwargs: Optional[dict] = None
#: Initializer function or class descriptor for the bias vectors in the stack
#: configured by `fcnet_hiddens`. Supported values are the initializer names (str),
#: classes or functions listed by the frameworks (`torch`). See
#: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default),
#: the default initializer defined by `torch` is used.
fcnet_bias_initializer: Optional[Union[str, Callable]] = None
#: Kwargs passed into the initializer function defined through
#: `fcnet_bias_initializer`.
fcnet_bias_initializer_kwargs: Optional[dict] = None
Configuring default CNN nets#
For image-based environments like Atari, use the
conv_..
fields in DefaultModelConfig
to configure
the convolutional neural network (CNN) stack.
For example:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
config = (
PPOConfig()
.environment("ale_py:ALE/Pong-v5") # `pip install gymnasium[atari]`
.rl_module(
model_config=DefaultModelConfig(
# Use a DreamerV3-style CNN stack.
conv_filters=[
[16, 4, 2], # 1st CNN layer: num_filters, kernel, stride(, padding)?
[32, 4, 2], # 2nd CNN layer
[64, 4, 2], # etc..
[128, 4, 2],
],
conv_activation="silu",
# After the last CNN, the default model flattens, then adds an optional MLP.
head_fcnet_hiddens=[256],
)
)
)
The following is the complete list of all supported conv_..
options:
#: List of lists of format [num_out_channels, kernel, stride] defining a Conv2D
#: stack if the input space is 2D. Each item in the outer list represents one Conv2D
#: layer. `kernel` and `stride` may be single ints (width and height have same
#: value) or 2-tuples (int, int) specifying width and height dimensions separately.
#: If None (default) and the input space is 2D, RLlib tries to find a default filter
#: setup given the exact input dimensions.
conv_filters: Optional[ConvFilterSpec] = None
#: Activation function descriptor for the stack configured by `conv_filters`.
#: Supported values are: 'tanh', 'relu', 'swish' (or 'silu', which is the same), and
#: 'linear' (or None).
conv_activation: str = "relu"
#: Initializer function or class descriptor for the weight/kernel matrices in the
#: stack configured by `conv_filters`. Supported values are the initializer names
#: (str), classes or functions listed by the frameworks (`torch`). See
#: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default),
#: the default initializer defined by `torch` is used.
conv_kernel_initializer: Optional[Union[str, Callable]] = None
#: Kwargs passed into the initializer function defined through
#: `conv_kernel_initializer`.
conv_kernel_initializer_kwargs: Optional[dict] = None
#: Initializer function or class descriptor for the bias vectors in the stack
#: configured by `conv_filters`. Supported values are the initializer names (str),
#: classes or functions listed by the frameworks (`torch`). See
#: https://pytorch.org/docs/stable/nn.init.html for `torch`. If `None` (default),
#: the default initializer defined by `torch` is used.
conv_bias_initializer: Optional[Union[str, Callable]] = None
#: Kwargs passed into the initializer function defined through
#: `conv_bias_initializer`.
conv_bias_initializer_kwargs: Optional[dict] = None
Other default model settings#
For LSTM-based configurations and specific settings for continuous action output layers,
see DefaultModelConfig
.
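For example, the following is a rough sketch of enabling the automatic LSTM-wrapping behavior. The field names used here (use_lstm, lstm_cell_size, max_seq_len) follow DefaultModelConfig; check the dataclass for the exact options and their defaults:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(
        model_config=DefaultModelConfig(
            # Wrap the default MLP encoder with an LSTM layer.
            use_lstm=True,
            lstm_cell_size=128,
            # Maximum length of the sequences RLlib cuts episodes into for training.
            max_seq_len=20,
        ),
    )
)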
Constructing RLModule instances#
To maintain consistency and usability, RLlib offers a standardized approach for constructing
RLModule
instances for both single-module and multi-module use cases. An example of a single-module use case is a single-agent experiment. Examples of multi-module use cases are
multi-agent learning or other multi-NN setups.
Construction through the class constructor#
The most direct way to construct your RLModule
is through its constructor:
import gymnasium as gym
from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import DefaultBCTorchRLModule
# Create an env object to know the spaces.
env = gym.make("CartPole-v1")
# Construct the actual RLModule object.
rl_module = DefaultBCTorchRLModule(
observation_space=env.observation_space,
action_space=env.action_space,
# A custom dict that's accessible inside your class as `self.model_config`.
model_config={"fcnet_hiddens": [64]},
)
Note
If you have a checkpoint of an Algorithm
or an individual
RLModule
,
see Creating instances with from_checkpoint for how to recreate your
RLModule
from disk.
Construction through RLModuleSpecs#
Because RLlib is a distributed RL library and needs to create more than one copy of
your RLModule
, you can use
RLModuleSpec
objects to define how RLlib should construct
each copy during the algorithm’s setup process. The algorithm passes the spec to all
subcomponents that need to have a copy of your RLModule.
Creating an RLModuleSpec
is straightforward
and analogous to the RLModule
constructor:
import gymnasium as gym
from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import DefaultBCTorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
# Create an env object to know the spaces.
env = gym.make("CartPole-v1")
# First construct the spec.
spec = RLModuleSpec(
module_class=DefaultBCTorchRLModule,
observation_space=env.observation_space,
action_space=env.action_space,
# A custom dict that's accessible inside your class as `self.model_config`.
model_config={"fcnet_hiddens": [64]},
)
# Then, build the RLModule through the spec's `build()` method.
rl_module = spec.build()
import gymnasium as gym
from ray.rllib.algorithms.bc.torch.default_bc_torch_rl_module import DefaultBCTorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
# First construct the MultiRLModuleSpec.
spec = MultiRLModuleSpec(
rl_module_specs={
"module_1": RLModuleSpec(
module_class=DefaultBCTorchRLModule,
# Define the spaces for only this sub-module.
observation_space=gym.spaces.Box(low=-1, high=1, shape=(10,)),
action_space=gym.spaces.Discrete(2),
# A custom dict that's accessible inside your class as
# `self.model_config`.
model_config={"fcnet_hiddens": [32]},
),
"module_2": RLModuleSpec(
module_class=DefaultBCTorchRLModule,
# Define the spaces for only this sub-module.
observation_space=gym.spaces.Box(low=-1, high=1, shape=(5,)),
action_space=gym.spaces.Discrete(2),
# A custom dict that's accessible inside your class as
# `self.model_config`.
model_config={"fcnet_hiddens": [16]},
),
},
)
# Construct the actual MultiRLModule instance with .build():
multi_rl_module = spec.build()
You can pass the RLModuleSpec
instances to your
AlgorithmConfig
to
tell RLlib to use the particular module class and constructor arguments:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
config = (
PPOConfig()
.environment("CartPole-v1")
.rl_module(
rl_module_spec=RLModuleSpec(
module_class=MyRLModuleClass,
model_config={"some_key": "some_setting"},
),
)
)
ppo = config.build()
print(ppo.get_module())
Note
Often when creating an RLModuleSpec
, you don’t have to define attributes
like observation_space
or action_space
because RLlib automatically infers these attributes based on the used
environment or other configuration parameters.
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
config = (
PPOConfig()
.environment(MultiAgentCartPole, env_config={"num_agents": 2})
.rl_module(
rl_module_spec=MultiRLModuleSpec(
# All agents (0 and 1) use the same (single) RLModule.
rl_module_specs=RLModuleSpec(
module_class=MyRLModuleClass,
model_config={"some_key": "some_setting"},
)
),
)
)
ppo = config.build()
print(ppo.get_module())
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
config = (
PPOConfig()
.environment(MultiAgentCartPole, env_config={"num_agents": 2})
.multi_agent(
policies={"p0", "p1"},
# Agent IDs of `MultiAgentCartPole` are 0 and 1, mapping to
# "p0" and "p1", respectively.
policy_mapping_fn=lambda agent_id, episode, **kw: f"p{agent_id}"
)
.rl_module(
rl_module_spec=MultiRLModuleSpec(
# Agents (0 and 1) use different (single) RLModules.
rl_module_specs={
"p0": RLModuleSpec(
module_class=MyRLModuleClass,
# Small network.
model_config={"fcnet_hiddens": [32, 32]},
),
"p1": RLModuleSpec(
module_class=MyRLModuleClass,
# Large network.
model_config={"fcnet_hiddens": [128, 128]},
),
},
),
)
)
ppo = config.build()
print(ppo.get_module())
Implementing custom RLModules#
To implement your own neural network architecture and computation logic, subclass
TorchRLModule
for any single-agent learning experiment
or for independent multi-agent learning.
For more advanced multi-agent use cases like ones with shared communication between agents,
or any multi-model use cases, subclass the MultiRLModule
class, instead.
Note
An alternative to subclassing TorchRLModule
is to
directly subclass your Algorithm’s default RLModule. For example, to use PPO, you can subclass
DefaultPPOTorchRLModule
.
You should carefully study the existing default model in this case to understand how to override
the setup()
method, the
_forward_...()
methods, and possibly some algo-specific API methods.
See Algorithm-specific RLModule APIs for how to determine which APIs your algorithm requires you to implement.
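For example, here's a minimal sketch of the subclass-the-default-module approach for PPO. It only calls the default setup() and prints the resulting architecture, so you can study which subcomponents to inspect or replace; MyTweakedPPOModule is a hypothetical class name:
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import DefaultPPOTorchRLModule

class MyTweakedPPOModule(DefaultPPOTorchRLModule):
    def setup(self):
        # Build everything the default PPO module builds ...
        super().setup()
        # ... then inspect or replace individual subcomponents here, for
        # example to swap out a head or add extra layers.
        print(self)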
The setup() method#
You should first implement the setup()
method,
in which you add needed NN subcomponents and assign these to class attributes of your choice.
Note that you should call super().setup()
in your implementation.
You also have access to the following attributes anywhere in the class, including in setup()
:
self.observation_space
self.action_space
self.inference_only
self.model_config
(a dict with any custom config settings)
import torch
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
class MyTorchPolicy(TorchRLModule):
def setup(self):
# You have access here to the following already set attributes:
# self.observation_space
# self.action_space
# self.inference_only
# self.model_config # <- a dict with custom settings
# Use the observation space (if a Box) to infer the input dimension.
input_dim = self.observation_space.shape[0]
# Use the model_config dict to extract the hidden dimension.
hidden_dim = self.model_config["fcnet_hiddens"][0]
# Use the action space to infer the number of output nodes.
output_dim = self.action_space.n
# Build all the layers and subcomponents here you need for the
# RLModule's forward passes.
self._pi_head = torch.nn.Sequential(
torch.nn.Linear(input_dim, hidden_dim),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dim, output_dim),
)
Forward methods#
To implement the forward computation logic, you can either define a generic forward behavior by overriding the
private _forward()
method, which RLlib then uses everywhere in the model's lifecycle,
or, if you require more granularity, define the following three private methods:
- _forward_exploration(): Forward pass for computing exploration actions for collecting training data.
- _forward_inference(): Forward pass for action inference, like greedy.
- _forward_train(): Forward pass for computing loss function inputs for a training update.
For custom _forward()
,
_forward_inference()
, and
_forward_exploration()
methods, you must return a
dictionary that contains the key actions
and/or the key action_dist_inputs
.
If you return the actions key from your forward method:
- RLlib uses the provided actions as-is.
- If you also return the action_dist_inputs key, RLlib creates a Distribution instance from the parameters under that key. In the case of _forward_exploration(), RLlib also computes action probabilities and log probabilities for the given actions automatically. See Custom action distributions for more information on custom action distribution classes.
If you don't return the actions key from your forward method:
- You must return the action_dist_inputs key from your _forward_exploration() and _forward_inference() methods.
- RLlib creates a Distribution instance from the parameters under that key and samples actions from that distribution. See Custom action distributions for more information on custom action distribution classes.
- For _forward_exploration(), RLlib also computes action probability and log probability values from the sampled actions automatically.
Note
In the case of _forward_inference(),
RLlib always first makes the distribution generated from the returned action_dist_inputs key
deterministic, through the to_deterministic() utility, before any possible action sampling step.
For example, RLlib reduces sampling from a Categorical distribution to selecting the argmax
actions from the distribution's logits or probabilities.
If you return the actions key, RLlib skips that sampling step.
from ray.rllib.core import Columns, TorchRLModule
class MyTorchPolicy(TorchRLModule):
...
def _forward_inference(self, batch):
...
return {
Columns.ACTIONS: ... # RLlib uses these actions as-is
}
def _forward_exploration(self, batch):
...
return {
Columns.ACTIONS: ...,  # RLlib uses these actions as-is (no sampling step!)
Columns.ACTION_DIST_INPUTS: ...,  # If provided, RLlib uses these dist inputs to compute probs and logp.
}
from ray.rllib.core import Columns, TorchRLModule
class MyTorchPolicy(TorchRLModule):
...
def _forward_inference(self, batch):
...
return {
# RLlib:
# - Generates distribution from ACTION_DIST_INPUTS parameters.
# - Converts distribution to a deterministic equivalent.
# - Samples from the deterministic distribution.
Columns.ACTION_DIST_INPUTS: ...
}
def _forward_exploration(self, batch):
...
return {
# RLlib:
# - Generates distribution from ACTION_DIST_INPUTS parameters.
# - Samples from the stochastic distribution.
# - Computes action probs and logs automatically using the sampled
# actions and the distribution.
Columns.ACTION_DIST_INPUTS: ...
}
Never override the constructor (__init__
). However, note that the
RLModule
class's constructor requires the following arguments
and also receives these properly when you call a spec's build()
method:
- observation_space: The observation space after having passed all connectors; this observation space is the actual input space for the model after all preprocessing steps.
- action_space: The action space of the environment.
- inference_only: Whether RLlib should build the RLModule in inference-only mode, dropping subcomponents that it only needs for learning.
- model_config: The model config, which is either a custom dictionary for custom RLModules or a DefaultModelConfig dataclass object, which is only for RLlib's default models. Define model hyper-parameters such as number of layers, type of activation, etc. in this object.
See Construction through the class constructor for more details on how to create an RLModule through the constructor.
Algorithm-specific RLModule APIs#
The algorithm that you choose to use with your RLModule affects to some extent the structure of the final custom module. Each Algorithm class has a fixed set of APIs that all RLModules trained by that algorithm need to implement.
To find out which APIs your algorithm requires, do the following:
# Import the config of the algorithm of your choice.
from ray.rllib.algorithms.sac import SACConfig
# Print out the abstract APIs you need to subclass from and whose
# abstract methods you need to implement, besides the ``setup()`` and ``_forward_..()``
# methods.
print(
SACConfig()
.get_default_learner_class()
.rl_module_required_apis()
)
Note
You don’t need the preceding VPG example module to implement any APIs because
you haven’t considered training it with any particular algorithm.
You can find examples of PPO-ready custom RLModules
in the tiny_atari_cnn_rlm example
and in the lstm_containing_rlm example.
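As an illustration, the following is a rough sketch of what a PPO-compatible custom RLModule might look like. It assumes that PPO's Learner requires the ValueFunctionAPI (verify this with the printout above) and that the import path and compute_values() signature match your Ray version; the class name and "hidden_dim" config key are made up for this example:
import torch
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule

class MyPPOCompatibleModule(TorchRLModule, ValueFunctionAPI):
    def setup(self):
        input_dim = self.observation_space.shape[0]
        hidden_dim = self.model_config["hidden_dim"]
        output_dim = self.action_space.n
        # Shared encoder plus separate policy- and value heads.
        self._encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
        )
        self._pi_head = torch.nn.Linear(hidden_dim, output_dim)
        self._vf_head = torch.nn.Linear(hidden_dim, 1)

    def _forward(self, batch, **kwargs):
        embeddings = self._encoder(batch[Columns.OBS])
        return {Columns.ACTION_DIST_INPUTS: self._pi_head(embeddings)}

    # Required by the ValueFunctionAPI: the loss calls this to get value
    # function predictions for the observations in the batch.
    def compute_values(self, batch, embeddings=None):
        if embeddings is None:
            embeddings = self._encoder(batch[Columns.OBS])
        return self._vf_head(embeddings).squeeze(-1)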
End-to-end example#
Putting together the elements of your custom RLModule
that you implemented above,
a complete, working end-to-end example looks like this:
import torch
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule
class VPGTorchRLModule(TorchRLModule):
"""A simple VPG (vanilla policy gradient)-style RLModule for testing purposes.
Use this as a minimum, bare-bones example implementation of a custom TorchRLModule.
"""
def setup(self):
# You have access here to the following already set attributes:
# self.observation_space
# self.action_space
# self.inference_only
# self.model_config # <- a dict with custom settings
input_dim = self.observation_space.shape[0]
hidden_dim = self.model_config["hidden_dim"]
output_dim = self.action_space.n
self._policy_net = torch.nn.Sequential(
torch.nn.Linear(input_dim, hidden_dim),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dim, output_dim),
)
def _forward(self, batch, **kwargs):
# Push the observations from the batch through our pi-head.
action_logits = self._policy_net(batch[Columns.OBS])
# Return parameters for the (default) action distribution, which is
# `TorchCategorical` (due to our action space being `gym.spaces.Discrete`).
return {Columns.ACTION_DIST_INPUTS: action_logits}
# If you need more granularity between the different forward behaviors during
# the different phases of the module's lifecycle, implement three different
# forward methods. In that case, it's recommended to wrap the inference and
# exploration versions in a `with torch.no_grad()` context for better
# performance.
# def _forward_train(self, batch):
# ...
#
# def _forward_inference(self, batch):
# with torch.no_grad():
# return self._forward_train(batch)
#
# def _forward_exploration(self, batch):
# with torch.no_grad():
# return self._forward_train(batch)
Custom action distributions#
The preceding examples rely on RLModule
using the correct action distribution with the computed
ACTION_DIST_INPUTS
returned by the forward methods. RLlib picks a default distribution class based on
the action space, which is TorchCategorical
for Discrete
action spaces
and TorchDiagGaussian
for Box
action spaces.
To use a different distribution class and return parameters for its constructor from your forward methods,
override the following methods in your RLModule
implementation (see the sketch after the note below):
- get_inference_action_dist_cls()
- get_exploration_action_dist_cls()
- get_train_action_dist_cls()
Note
If you only return ACTION_DIST_INPUTS
from your forward methods, RLlib automatically
uses the to_deterministic()
method of the
distribution returned by your get_inference_action_dist_cls()
.
See torch_distributions.py for common distribution implementations.
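For example, here's a rough sketch of a module that switches all three phases to a squashed Gaussian distribution. It assumes a TorchSquashedGaussian class is available in torch_distributions.py and that your forward methods return matching ACTION_DIST_INPUTS (for example, concatenated mean and log-std):
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.models.torch.torch_distributions import TorchSquashedGaussian

class MySquashedGaussianModule(TorchRLModule):
    # setup() and _forward_...() as usual; the ACTION_DIST_INPUTS returned by
    # the forward methods must match the distribution's expected parameters.

    def get_inference_action_dist_cls(self):
        return TorchSquashedGaussian

    def get_exploration_action_dist_cls(self):
        return TorchSquashedGaussian

    def get_train_action_dist_cls(self):
        return TorchSquashedGaussian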
Implementing custom MultiRLModules#
For multi-module setups, RLlib provides the MultiRLModule
class,
whose default implementation is a dictionary of individual RLModule
objects,
one for each submodule, each identified by a ModuleID
.
The base-class MultiRLModule
implementation works for most of the
use cases that need to define independent neural networks. However, for any complex, multi-network or multi-agent use case, where agents share one or more neural networks,
you should inherit from this class and override the default implementation.
The following code snippets create a custom multi-agent RL module with two simple “policy head” modules, which share the same encoder, the third network in the MultiRLModule. The encoder receives the raw observations from the env and outputs embedding vectors that then serve as input for the two policy heads to compute the agents’ actions.
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule

# ModuleID under which this example stores the shared encoder submodule.
SHARED_ENCODER_ID = "shared_encoder"

class VPGMultiRLModuleWithSharedEncoder(MultiRLModule):
    """VPG (vanilla pol. gradient)-style MultiRLModule handling a shared encoder.
    """
    def setup(self):
        # Call the super's setup().
        super().setup()
        # Assert that the shared encoder submodule is present.
        assert (
            SHARED_ENCODER_ID in self._rl_modules
            and isinstance(self._rl_modules[SHARED_ENCODER_ID], SharedEncoder)
            and len(self._rl_modules) > 1
        )
        # Assign the encoder to a convenience attribute.
        self.encoder = self._rl_modules[SHARED_ENCODER_ID]

    def _forward(self, batch, **kwargs):
        # Collect our policies' outputs in this dict.
        outputs = {}
        # Loop through the policy nets (through the given batch's keys).
        for policy_id, policy_batch in batch.items():
            rl_module = self._rl_modules[policy_id]
            # Pass the policy's observations through the shared encoder to get
            # the embeddings for this policy.
            policy_batch["encoder_embeddings"] = self.encoder._forward(
                batch[policy_id]
            )["encoder_embeddings"]
            # Pass the policy's embeddings through the policy net.
            outputs[policy_id] = rl_module._forward(batch[policy_id], **kwargs)
        return outputs
Within the MultiRLModule, you need to have two policy sub-RLModules. They may be of the same class, which you can implement as follows:
import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule

class VPGPolicyAfterSharedEncoder(TorchRLModule):
"""A VPG (vanilla pol. gradient)-style RLModule using a shared encoder.
"""
def setup(self):
super().setup()
# Incoming feature dim from the shared encoder.
embedding_dim = self.model_config["embedding_dim"]
hidden_dim = self.model_config["hidden_dim"]
self._pi_head = torch.nn.Sequential(
torch.nn.Linear(embedding_dim, hidden_dim),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dim, self.action_space.n),
)
def _forward(self, batch, **kwargs):
# Embeddings can be found in the batch under the "encoder_embeddings" key.
embeddings = batch["encoder_embeddings"]
logits = self._pi_head(embeddings)
return {Columns.ACTION_DIST_INPUTS: logits}
Finally, the shared encoder RLModule should look similar to this:
import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule

class SharedEncoder(TorchRLModule):
"""A shared encoder that can be used with `VPGMultiRLModuleWithSharedEncoder`."""
def setup(self):
super().setup()
input_dim = self.observation_space.shape[0]
embedding_dim = self.model_config["embedding_dim"]
# A very simple encoder network.
self._net = torch.nn.Sequential(
torch.nn.Linear(input_dim, embedding_dim),
)
def _forward(self, batch, **kwargs):
# Pass observations through the net and return outputs.
return {"encoder_embeddings": self._net(batch[Columns.OBS])}
To plug the custom MultiRLModule from the first tab into your algorithm's config, create a
MultiRLModuleSpec
with the new class and its constructor settings. Also, create one RLModuleSpec
for each agent and for the shared encoder RLModule, because RLlib requires their observation and action spaces and their
model hyper-parameters:
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import MultiRLModuleSpec, RLModuleSpec
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
single_agent_env = gym.make("CartPole-v1")
EMBEDDING_DIM = 64 # encoder output dim
config = (
PPOConfig()
.environment(MultiAgentCartPole, env_config={"num_agents": 2})
.multi_agent(
# Declare the two policies trained.
policies={"p0", "p1"},
# Agent IDs of `MultiAgentCartPole` are 0 and 1. They are mapped to
# the two policies with ModuleIDs "p0" and "p1", respectively.
policy_mapping_fn=lambda agent_id, episode, **kw: f"p{agent_id}"
)
.rl_module(
rl_module_spec=MultiRLModuleSpec(
# Use the custom MultiRLModule class from the first tab.
multi_rl_module_class=VPGMultiRLModuleWithSharedEncoder,
rl_module_specs={
# Shared encoder.
SHARED_ENCODER_ID: RLModuleSpec(
module_class=SharedEncoder,
model_config={"embedding_dim": EMBEDDING_DIM},
observation_space=single_agent_env.observation_space,
),
# Large policy net.
"p0": RLModuleSpec(
module_class=VPGPolicyAfterSharedEncoder,
model_config={
"embedding_dim": EMBEDDING_DIM,
"hidden_dim": 1024,
},
),
# Small policy net.
"p1": RLModuleSpec(
module_class=VPGPolicyAfterSharedEncoder,
model_config={
"embedding_dim": EMBEDDING_DIM,
"hidden_dim": 64,
},
),
},
),
)
)
algo = config.build()
print(algo.get_module())
Note
To learn properly with the preceding setup, you should write and use a specific multi-agent
Learner
, capable of handling the shared encoder.
To stabilize learning, this Learner should use a single optimizer that updates all three submodules: the encoder and the two policy nets.
With the standard "one-optimizer-per-module" Learners, however, the two optimizers for policies p0 and p1
would take turns updating the same shared encoder, which would lead to learning instabilities.
Checkpointing RLModules#
You can checkpoint RLModule
instances with their
save_to_path()
method.
If you already have an instantiated RLModule and would like to load new model weights into it from an existing
checkpoint, use the restore_from_path()
method.
The following examples show how you can use these methods outside of, or in conjunction with, an RLlib Algorithm.
Creating an RLModule checkpoint#
import tempfile
import gymnasium as gym
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import DefaultPPOTorchRLModule
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
env = gym.make("CartPole-v1")
# Create an RLModule to later checkpoint.
rl_module = DefaultPPOTorchRLModule(
observation_space=env.observation_space,
action_space=env.action_space,
model_config=DefaultModelConfig(fcnet_hiddens=[32]),
)
# Finally, write the RLModule checkpoint.
module_ckpt_path = tempfile.mkdtemp()
rl_module.save_to_path(module_ckpt_path)
Creating an RLModule from an (RLModule) checkpoint#
If you have an RLModule checkpoint saved and would like to create a new RLModule directly from it,
use the from_checkpoint()
method:
from ray.rllib.core.rl_module.rl_module import RLModule
# Create a new RLModule from the checkpoint.
new_module = RLModule.from_checkpoint(module_ckpt_path)
Loading an RLModule checkpoint into a running Algorithm#
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
# Create a new Algorithm (with the changed module config: 32 units instead of the
# default 256; otherwise loading the state of ``module`` fails due to a shape
# mismatch).
config = (
PPOConfig()
.environment("CartPole-v1")
.rl_module(model_config=DefaultModelConfig(fcnet_hiddens=[32]))
)
ppo = config.build()
Now you can load the saved RLModule state from the preceding rl_module.save_to_path()
call directly
into the RLModules of the running Algorithm. Note that RLlib updates all RLModule copies within the Algorithm: the ones
on the Learner workers and the ones on the EnvRunners.
ppo.restore_from_path(
module_ckpt_path, # <- NOT an Algorithm checkpoint, but single-agent RLModule one.
# Therefore, we have to provide the exact path (of RLlib components) down
# to the individual RLModule within the algorithm, which is:
component="learner_group/learner/rl_module/default_policy",
)