Note
Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The team is currently transitioning algorithms, example scripts, and documentation to the new code base throughout the subsequent minor releases leading up to Ray 3.0.
See here for more details on how to activate and use the new API stack.
RL Modules (Alpha)#
Note
This is an experimental module that serves as a general replacement for ModelV2, and is subject to change. It will eventually match the functionality of the previous stack. If you only use high-level RLlib APIs such as Algorithm, you should not experience significant changes, except for a few new parameters to the configuration object. If you've used custom models or policies before, you'll need to migrate them to the new modules. Check the Migration guide for more information.
The table below shows the list of migrated algorithms and their currently supported features; it will be updated as the migration progresses.

| Algorithm | Independent MARL | Fully-connected | Image inputs (CNN) | RNN support (LSTM) | Complex observations (ComplexNet) |
|---|---|---|---|---|---|
| PPO | ✓ | ✓ | ✓ | ✓ | ✓ |
| IMPALA | ✓ | ✓ | ✓ | ✓ | ✓ |
| APPO | ✓ | ✓ | ✓ | ✓ | ✓ |
RL Module is a neural network container that implements three public methods: forward_train(), forward_exploration(), and forward_inference(). Each method corresponds to a distinct reinforcement learning phase: forward_exploration() handles acting and data collection, balancing exploration and exploitation; forward_inference() serves the learned model during evaluation and is often less stochastic; forward_train() manages the training phase, handling calculations exclusive to computing losses, such as learning Q-values in a DQN model.
Enabling RL Modules in the Configuration#
Enable RL Modules via the configuration object: AlgorithmConfig.api_stack(enable_rl_module_and_learner=True).
from pprint import pprint

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(enable_rl_module_and_learner=True)
    .framework("torch")
    .environment("CartPole-v1")
)
algorithm = config.build()

# Run for 2 training iterations.
for _ in range(2):
    result = algorithm.train()
    pprint(result)
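Once the new API stack is enabled, the built Algorithm holds RL Modules instead of ModelV2 instances. As a quick check, you can look up the module the algorithm built (a minimal sketch; it assumes the Algorithm.get_module() accessor of the new API stack and the default module ID "default_policy"):
# Sketch: retrieve the RLModule that the algorithm built for the default module ID.
rl_module = algorithm.get_module("default_policy")
print(type(rl_module).__name__)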
Constructing RL Modules#
The RLModule API provides a unified way to define custom reinforcement learning models in RLlib. This API enables you to design and implement your own models to suit specific needs.
To maintain consistency and usability, RLlib offers a standardized approach for defining module objects for both single-agent and multi-agent reinforcement learning environments. This is achieved through the RLModuleSpec
and MultiRLModuleSpec
classes. The built-in RLModules in RLlib follow this consistent design pattern, making it easier for you to understand and utilize these modules.
import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
env = gym.make("CartPole-v1")
spec = RLModuleSpec(
module_class=DiscreteBCTorchModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config={"fcnet_hiddens": [64]},
)
module = spec.build()
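As a quick sanity check, you can call the built module directly on a batch dictionary (a minimal sketch; the exact output keys depend on the module implementation, for example "actions", "action_dist_inputs", or a ready-made action distribution):
import torch

# Run a single CartPole observation through the inference forward pass (sketch).
obs, _ = env.reset()
batch = {"obs": torch.from_numpy(obs).float().unsqueeze(0)}
with torch.no_grad():
    out = module.forward_inference(batch)
print(out)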
import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
spec = MultiRLModuleSpec(
rl_module_specs={
"module_1": RLModuleSpec(
module_class=DiscreteBCTorchModule,
observation_space=gym.spaces.Box(low=-1, high=1, shape=(10,)),
action_space=gym.spaces.Discrete(2),
model_config={"fcnet_hiddens": [32]},
),
"module_2": RLModuleSpec(
module_class=DiscreteBCTorchModule,
observation_space=gym.spaces.Box(low=-1, high=1, shape=(5,)),
action_space=gym.spaces.Discrete(2),
model_config={"fcnet_hiddens": [16]},
),
},
)
multi_rl_module = spec.build()
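The resulting MultiRLModule is a container of RLModules keyed by module ID, so you can retrieve and use each sub-module individually (a sketch, assuming dictionary-style access by module ID):
# Sketch: access one of the sub-modules by its module ID.
module_1 = multi_rl_module["module_1"]
print(type(module_1).__name__)  # -> DiscreteBCTorchModule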
You can pass RL Module specs to the algorithm configuration to be used by the algorithm.
import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
from ray.rllib.core.testing.bc_algorithm import BCConfigTest
config = (
BCConfigTest()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment("CartPole-v1")
.rl_module(
model_config={"fcnet_hiddens": [32, 32]},
rl_module_spec=RLModuleSpec(module_class=DiscreteBCTorchModule),
)
)
algo = config.build()
Note
When passing a custom RL Module spec to the algorithm config, you don't have to fill in every field. RLlib infers missing fields, such as observation_space, action_space, and model_config, from the environment or from other algorithm configuration parameters.
import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
from ray.rllib.core.testing.bc_algorithm import BCConfigTest
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
config = (
BCConfigTest()
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.environment(MultiAgentCartPole, env_config={"num_agents": 2})
.rl_module(
model_config={"fcnet_hiddens": [32, 32]},
rl_module_spec=MultiRLModuleSpec(
rl_module_specs={
"p0": RLModuleSpec(module_class=DiscreteBCTorchModule),
},
),
)
)
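As in the single-agent case, you can then build the algorithm from this configuration (a sketch mirroring the single-agent example above):
algo = config.build()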
Writing Custom Single Agent RL Modules#
For single-agent algorithms (e.g., PPO, DQN) or independent multi-agent algorithms (e.g., multi-agent PPO), use RLModule. For more advanced multi-agent use cases with shared communication between agents, extend the MultiRLModule class.
RLlib treats single-agent modules as a special case of MultiRLModule with only one module. Create the multi-agent representation of any RLModule by calling as_multi_rl_module(). For example:
import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule
env = gym.make("CartPole-v1")
spec = RLModuleSpec(
module_class=DiscreteBCTorchModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config={"fcnet_hiddens": [64]},
)
module = spec.build()
multi_rl_module = module.as_multi_rl_module()
RLlib implements the following abstract framework-specific base classes:
- TorchRLModule: For PyTorch-based RL Modules.
- TfRLModule: For TensorFlow-based RL Modules.
At a minimum, subclasses of RLModule must implement the following methods:
- _forward_train(): Forward pass for training.
- _forward_inference(): Forward pass for inference.
- _forward_exploration(): Forward pass for exploration.
For your custom forward_exploration() and forward_inference() methods, you must return a dictionary that contains the key "actions", the key "action_dist_inputs", or both.

If you return the "actions" key:
- RLlib uses the actions provided as-is.
- If you also return the "action_dist_inputs" key, RLlib additionally creates a Distribution object from the distribution parameters under that key and, in the case of forward_exploration(), automatically computes action probs and logp values for the given actions.

If you don't return the "actions" key:
- You must return the "action_dist_inputs" key instead from your forward_exploration() and forward_inference() methods.
- RLlib creates a Distribution object from the distribution parameters under that key and samples actions from that distribution.
- In the case of forward_exploration(), RLlib also automatically computes action probs and logp values for the sampled actions.
Note
In the case of forward_inference(), RLlib always first makes the generated distributions (from the returned key "action_dist_inputs") deterministic, using the to_deterministic() utility, before a possible action sampling step. Thus, for example, sampling from a Categorical distribution reduces to simply selecting the argmax actions from the distribution's logits/probs.
Commonly used distribution implementations can be found under ray.rllib.models.tf.tf_distributions for TensorFlow and ray.rllib.models.torch.torch_distributions for PyTorch. You can choose to return deterministic actions by creating a deterministic distribution instance.
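For illustration, the following sketch shows roughly what RLlib does with returned "action_dist_inputs" during inference (assuming the TorchCategorical class and its from_logits() constructor from the torch distributions module mentioned above):
import torch

from ray.rllib.models.torch.torch_distributions import TorchCategorical

# Sketch: build a distribution from a batch of logits, make it deterministic,
# then "sample", which reduces to an argmax over the logits.
logits = torch.randn(1, 2)
dist = TorchCategorical.from_logits(logits)
greedy_actions = dist.to_deterministic().sample()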
"""
An RLModule whose forward_exploration/inference methods return the
"actions" key.
"""
class MyRLModule(TorchRLModule):
...
def _forward_inference(self, batch):
...
return {
"actions": ... # actions will be used as-is
}
def _forward_exploration(self, batch):
...
return {
"actions": ... # actions will be used as-is (no sampling step!)
"action_dist_inputs": ... # optional: If provided, will be used to compute action probs and logp.
}
"""
An RLModule whose forward_exploration/inference methods don't return the
"actions" key.
"""
class MyRLModule(TorchRLModule):
...
def _forward_inference(self, batch):
...
return {
# RLlib will:
# - Generate distribution from these parameters.
# - Convert distribution to a deterministic equivalent.
# - "sample" from the deterministic distribution.
"action_dist_inputs": ...
}
def _forward_exploration(self, batch):
...
return {
# RLlib will:
# - Generate distribution from these parameters.
# - "sample" from the (stochastic) distribution.
# - Compute action probs/logs automatically using the sampled
# actions and the generated distribution object.
"action_dist_inputs": ...
}
The RLModule class's constructor also requires a dataclass config object, RLModuleConfig, which contains the following fields:
- observation_space: The observation space of the environment (either processed or raw).
- action_space: The action space of the environment.
- model_config_dict: The model config dictionary of the algorithm. Model hyperparameters such as the number of layers and the type of activation are defined here.
- catalog_class: The Catalog object of the algorithm.
When writing RL Modules, you need to use these fields to construct your model.
from typing import Any, Dict
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModuleConfig
import torch
import torch.nn as nn
class DiscreteBCTorchModule(TorchRLModule):
def __init__(self, config: RLModuleConfig) -> None:
super().__init__(config)
def setup(self):
input_dim = self.observation_space.shape[0]
hidden_dim = self.model_config["fcnet_hiddens"][0]
output_dim = self.action_space.n
self.policy = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim),
)
self.input_dim = input_dim
def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
with torch.no_grad():
return self._forward_train(batch)
def _forward_exploration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
with torch.no_grad():
return self._forward_train(batch)
def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
action_logits = self.policy(batch["obs"])
return {"action_dist": torch.distributions.Categorical(logits=action_logits)}
from typing import Any, Dict

from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule
from ray.rllib.core.rl_module.rl_module import RLModuleConfig
import tensorflow as tf
import tensorflow_probability as tfp
class DiscreteBCTfModule(TfRLModule):
def __init__(self, config: RLModuleConfig) -> None:
super().__init__(config)
def setup(self):
input_dim = self.observation_space.shape[0]
hidden_dim = self.model_config["fcnet_hiddens"][0]
output_dim = self.action_space.n
self.policy = tf.keras.Sequential(
[
tf.keras.layers.Dense(hidden_dim, activation="relu"),
tf.keras.layers.Dense(output_dim),
]
)
self.input_dim = input_dim
def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
return self._forward_train(batch)
def _forward_exploration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
return self._forward_train(batch)
def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
action_logits = self.policy(batch["obs"])
return {"action_dist": tf.distributions.Categorical(logits=action_logits)}
Writing Custom Multi-Agent RL Modules (Advanced)#
For multi-agent modules, RLlib implements MultiRLModule, which is a dictionary of RLModule objects, one for each policy, and possibly some shared modules. The base-class implementation works for most use cases that need to define independent neural networks for sub-groups of agents. For more complex multi-agent use cases, where the agents share some part of their neural networks, inherit from this class and override the default implementation.
The MultiRLModule offers an API for constructing custom models tailored to specific needs. The key method for this customization is setup(), which the following example overrides.
The following example creates a custom multi-agent RL module with underlying modules. The modules share an encoder, which gets applied to the global part of the observation space. The local part passes through a separate encoder, specific to each policy.
from typing import Any, Dict

from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleConfig, MultiRLModule
import torch
import torch.nn as nn
class BCTorchRLModuleWithSharedGlobalEncoder(TorchRLModule):
"""An RLModule with a shared encoder between agents for global observation."""
def setup(self):
self.encoder = self.model_config["encoder"]
self.policy_head = nn.Sequential(
nn.Linear(
self.model_config["hidden_dim"] + self.model_config["local_dim"],
self.model_config["hidden_dim"],
),
nn.ReLU(),
nn.Linear(self.model_config["hidden_dim"], self.model_config["action_dim"]),
)
def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
with torch.no_grad():
return self._common_forward(batch)
def _forward_exploration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
with torch.no_grad():
return self._common_forward(batch)
def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
return self._common_forward(batch)
def _common_forward(self, batch):
obs = batch["obs"]
global_enc = self.encoder(obs["global"])
policy_in = torch.cat([global_enc, obs["local"]], dim=-1)
action_logits = self.policy_head(policy_in)
return {"action_dist": torch.distributions.Categorical(logits=action_logits)}
class BCTorchMultiAgentModuleWithSharedEncoder(MultiRLModule):
def setup(self):
module_specs = self.rl_module_specs
module_spec = next(iter(module_specs.values()))
global_dim = module_spec.observation_space["global"].shape[0]
hidden_dim = module_spec.model_config["fcnet_hiddens"][0]
shared_encoder = nn.Sequential(
nn.Linear(global_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
rl_modules = {}
for module_id, module_spec in module_specs.items():
rl_modules[module_id] = BCTorchRLModuleWithSharedGlobalEncoder(
observation_space=module_spec.observation_space,
action_space=module_spec.action_space,
model_config={
"local_dim": module_spec.observation_space["local"].shape[0],
"hidden_dim": hidden_dim,
"action_dim": module_spec.action_space.n,
"encoder": shared_encoder,
},
)
self._rl_modules = rl_modules
To construct this custom multi-agent RL module, pass the class to the MultiRLModuleSpec constructor. Also pass the RLModuleSpec for each agent, because RLlib requires the observation space, action space, and model hyperparameters for each agent.
import gymnasium as gym
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
spec = MultiRLModuleSpec(
multi_rl_module_class=BCTorchMultiAgentModuleWithSharedEncoder,
rl_module_specs={
"local_2d": RLModuleSpec(
observation_space=gym.spaces.Dict(
{
"global": gym.spaces.Box(low=-1, high=1, shape=(2,)),
"local": gym.spaces.Box(low=-1, high=1, shape=(2,)),
}
),
action_space=gym.spaces.Discrete(2),
model_config={"fcnet_hiddens": [64]},
),
"local_5d": RLModuleSpec(
observation_space=gym.spaces.Dict(
{
"global": gym.spaces.Box(low=-1, high=1, shape=(2,)),
"local": gym.spaces.Box(low=-1, high=1, shape=(5,)),
}
),
action_space=gym.spaces.Discrete(5),
model_config={"fcnet_hiddens": [64]},
),
},
)
module = spec.build()
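The following sketch runs a dict observation that matches the "local_2d" spaces defined above through one of the resulting sub-modules (assuming dictionary-style access on the MultiRLModule by module ID):
import torch

# Build a batch of one dict observation with values in [-1, 1] and run the
# exploration forward pass of the "local_2d" sub-module (sketch).
batch = {
    "obs": {
        "global": torch.rand(1, 2) * 2.0 - 1.0,
        "local": torch.rand(1, 2) * 2.0 - 1.0,
    }
}
with torch.no_grad():
    out = module["local_2d"].forward_exploration(batch)
print(out["action_dist"])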
Extending Existing RLlib RL Modules#
RLlib provides a number of RL Modules for different frameworks (for example, PyTorch and TensorFlow). To customize an existing RLModule, inherit from its class and override setup() or any other methods you need. For example, extend PPOTorchRLModule and augment it with your own customization, then pass the customized class into the appropriate AlgorithmConfig.
There are two possible ways to extend existing RL Modules:
The default way is to inherit from the existing RL Module and override the methods you need to customize, then pass the customized class into the AlgorithmConfig to train your custom RL Module. This is the preferred approach: you define your models explicitly within a given RL Module and don't need to interact with a Catalog, so you don't need to learn about the Catalog API.
class MyPPORLModule(PPORLModule):
def __init__(self, config: RLModuleConfig):
super().__init__(config)
...
# Pass in the custom RL Module class to the spec
algo_config = algo_config.rl_module(
rl_module_spec=RLModuleSpec(module_class=MyPPORLModule)
)
A concrete example: if you want to replace the default encoder that RLlib builds for PyTorch PPO and a given observation space, you can override the setup() method of the DefaultPPOTorchRLModule class to create your custom encoder instead of the default one. The following example does this.
import gymnasium as gym
import numpy as np
from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
DefaultPPOTorchRLModule,
)
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples.envs.classes.random_env import RandomEnv
from ray.rllib.examples._old_api_stack.models.mobilenet_v2_encoder import (
MobileNetV2EncoderConfig,
MOBILENET_INPUT_SHAPE,
)
from ray.rllib.core.models.configs import ActorCriticEncoderConfig
class MobileNetTorchPPORLModule(DefaultPPOTorchRLModule):
"""A DefaultPPORLModule with mobilenet v2 as an encoder.
The idea behind this model is to demonstrate how we can bypass catalog to
take full control over what models and action distribution are being built.
In this example, we do this to modify an existing RLModule with a custom encoder.
"""
def setup(self):
mobilenet_v2_config = MobileNetV2EncoderConfig()
# Since we want to use PPO, which is an actor-critic algorithm, we need to
# use an ActorCriticEncoderConfig to wrap the base encoder config.
actor_critic_encoder_config = ActorCriticEncoderConfig(
base_encoder_config=mobilenet_v2_config
)
self.encoder = actor_critic_encoder_config.build(framework="torch")
mobilenet_v2_output_dims = mobilenet_v2_config.output_dims
pi_config = MLPHeadConfig(
input_dims=mobilenet_v2_output_dims,
output_layer_dim=2,
)
vf_config = MLPHeadConfig(
input_dims=mobilenet_v2_output_dims, output_layer_dim=1
)
self.pi = pi_config.build(framework="torch")
self.vf = vf_config.build(framework="torch")
config = (
PPOConfig()
.rl_module(rl_module_spec=RLModuleSpec(module_class=MobileNetTorchPPORLModule))
.environment(
RandomEnv,
env_config={
"action_space": gym.spaces.Discrete(2),
# Test a simple Image observation space.
"observation_space": gym.spaces.Box(
0.0,
1.0,
shape=MOBILENET_INPUT_SHAPE,
dtype=np.float32,
),
},
)
.env_runners(num_env_runners=0)
# The following training settings make it so that a training iteration is very
# quick. This is just for the sake of this example. PPO will not learn properly
# with these settings!
.training(train_batch_size_per_learner=32, minibatch_size=16, num_epochs=1)
)
config.build().train()
An advanced way to customize your module is by extending its Catalog
.
The Catalog is a component that defines the default models and other sub-components for RL Modules based on factors such as observation_space
, action_space
, etc.
For more information on the Catalog
class, refer to the Catalog user guide.
By modifying the Catalog, you can alter what sub-components are being built for existing RL Modules.
This approach is useful mostly if you want your custom component to integrate with the decision trees that the Catalogs represent.
The following use cases are examples of what may require you to extend the Catalogs:
- Choosing a custom model only for a certain observation space.
- Using a custom action distribution in multiple distinct Algorithms.
- Reusing your custom component in many distinct RL Modules.
For instance, to adapt an existing PPORLModule to a custom graph observation space that RLlib doesn't support out of the box, extend the Catalog class used to create the PPORLModule and override the method responsible for returning the encoder component, so that your custom encoder replaces the default one RLlib provides.
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import RLModuleSpec


class MyAwesomeCatalog(PPOCatalog):
    def build_actor_critic_encoder(self, framework: str):
        # Create your awesome graph encoder here and return it.
        ...


# Pass in the custom catalog class to the spec
algo_config = algo_config.rl_module(
    rl_module_spec=RLModuleSpec(catalog_class=MyAwesomeCatalog)
)
Checkpointing RL Modules#
You can checkpoint RL Modules with the save_to_path() and from_checkpoint() methods.
The following example shows how to use these methods outside of, or in conjunction with, an RLlib Algorithm.
import gymnasium as gym
import shutil
import tempfile
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
config = PPOConfig().environment("CartPole-v1")
env = gym.make("CartPole-v1")
# Create an RL Module that we would like to checkpoint
module_spec = RLModuleSpec(
module_class=PPOTorchRLModule,
observation_space=env.observation_space,
action_space=env.action_space,
# If we want to use this externally created module in the algorithm,
# we need to provide the same config as the algorithm. Any changes to
# the defaults can be given via the right side of the `|` operator.
model_config=config.model_config | {"fcnet_hiddens": [32]},
catalog_class=PPOCatalog,
)
module = module_spec.build()
# Create the checkpoint.
module_ckpt_path = tempfile.mkdtemp()
module.save_to_path(module_ckpt_path)
# Create a new RLModule from the checkpoint.
loaded_module = RLModule.from_checkpoint(module_ckpt_path)
# Create a new Algorithm (with the changed module config: 32 units instead of the
# default 256; otherwise loading the state of `module` will fail due to a shape
# mismatch).
config.rl_module(model_config=config.model_config | {"fcnet_hiddens": [32]})
algo = config.build()
# Now load the saved RLModule state (from the above `module.save_to_path()`) into the
# Algorithm's RLModule(s). Note that all RLModules within the algo get updated, the ones
# in the Learner workers and the ones in the EnvRunners.
algo.restore_from_path(
module_ckpt_path, # <- NOT an Algorithm checkpoint, but single-agent RLModule one.
# We have to provide the exact component-path to the (single) RLModule
# within the algorithm, which is:
component="learner_group/learner/rl_module/default_policy",
)
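As a quick check that the checkpoint round trip worked, the externally loaded module can be used for inference right away (a sketch; the output keys depend on the RLModule implementation):
import torch

# Sketch: run one CartPole observation through the loaded module.
obs, _ = env.reset()
batch = {"obs": torch.from_numpy(obs).float().unsqueeze(0)}
with torch.no_grad():
    out = loaded_module.forward_inference(batch)
print(list(out.keys()))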
Migrating from Custom Policies and Models to RL Modules#
This document is for those who have implemented custom policies and models in RLlib and want to migrate to the new RLModule API. If you have implemented custom policies that extended the EagerTFPolicyV2 or TorchPolicyV2 classes, you likely did so to modify the behavior of constructing models and distributions (by overriding make_model or make_model_and_action_dist), to control the action sampling logic (by overriding action_distribution_fn or action_sampler_fn), or to control the logic for inference (by overriding compute_actions_from_input_dict, compute_actions, or compute_log_likelihoods). These APIs were built with ModelV2 models in mind to enable you to customize the behavior of those functions. However, RLModule is a more general abstraction that reduces the number of functions you need to override.
In the new RLModule API, the construction of the models and of the action distribution class is best defined in the constructor. RLlib constructs that RL Module automatically if you follow the instructions outlined in the sections Enabling RL Modules in the Configuration and Constructing RL Modules. compute_actions and compute_actions_from_input_dict can still be used for sampling actions for inference or exploration by using the explore=True|False parameter. If called with explore=True, these functions invoke RLModule.forward_exploration; if explore=False, they call RLModule.forward_inference.
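The following sketch illustrates that mapping (assuming module is an RLModule, batch is an input dict, and explore is the boolean passed to those functions):
# Sketch of the explore=True|False mapping described above.
if explore:
    out = module.forward_exploration(batch)
else:
    out = module.forward_inference(batch)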
What your customization could have looked like before:
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from ray.util.annotations import DeveloperAPI
class MyCustomModel(TorchModelV2):
"""Code for your previous custom model"""
...
class CustomPolicy(TorchPolicyV2):
@DeveloperAPI
@OverrideToImplementCustomLogic
def make_model(self) -> ModelV2:
"""Create model.
Note: only one of make_model or make_model_and_action_dist
can be overridden.
Returns:
ModelV2 model.
"""
return MyCustomModel(...)
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from ray.util.annotations import DeveloperAPI
class MyCustomModel(TorchModelV2):
"""Code for your previous custom model"""
...
class CustomPolicy(TorchPolicyV2):
@DeveloperAPI
@OverrideToImplementCustomLogic
def make_model_and_action_dist(self):
"""Create model and action distribution function.
Returns:
ModelV2 model.
ActionDistribution class.
"""
my_model = MyCustomModel(...) # construct some ModelV2 instance here
dist_class = ... # Action distribution cls
return my_model, dist_class
from typing import List, Tuple

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
from ray.rllib.utils.typing import TensorType
from ray.util.annotations import DeveloperAPI
class CustomPolicy(TorchPolicyV2):
@DeveloperAPI
@OverrideToImplementCustomLogic
def action_sampler_fn(
self,
model: ModelV2,
*,
obs_batch: TensorType,
state_batches: TensorType,
**kwargs,
) -> Tuple[TensorType, TensorType, TensorType, List[TensorType]]:
"""Custom function for sampling new actions given policy.
Args:
model: Underlying model.
obs_batch: Observation tensor batch.
state_batches: Action sampling state batch.
Returns:
Sampled action
Log-likelihood
Action distribution inputs
Updated state
"""
return None, None, None, None
@DeveloperAPI
@OverrideToImplementCustomLogic
def action_distribution_fn(
self,
model: ModelV2,
*,
obs_batch: TensorType,
state_batches: TensorType,
**kwargs,
) -> Tuple[TensorType, type, List[TensorType]]:
"""Action distribution function for this Policy.
Args:
model: Underlying model.
obs_batch: Observation tensor batch.
state_batches: Action sampling state batch.
Returns:
Distribution input.
ActionDistribution class.
State outs.
"""
return None, None, None
All of the Policy.compute_*** functions expect that forward_exploration() and forward_inference() return a dictionary that contains the key "actions" and/or the key "action_dist_inputs".
See Writing Custom Single Agent RL Modules for more details on how to implement your own custom RL Modules.
"""
No need to override any policy functions. Simply instead implement any custom logic in your custom RL Module
"""
from ray.rllib.models.torch.torch_distributions import YOUR_DIST_CLASS
class MyRLModule(TorchRLModule):
def __init__(self, config: RLConfig):
# construct any custom networks here using config
# specify an action distribution class here
...
def _forward_inference(self, batch):
...
def _forward_exploration(self, batch):
...
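Finally, a sketch of wiring the migrated module into an algorithm configuration (assuming algo_config is your AlgorithmConfig, as in the earlier examples):
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

# Pass the custom RL Module class to the spec (sketch).
algo_config = algo_config.rl_module(
    rl_module_spec=RLModuleSpec(module_class=MyRLModule)
)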