Note
Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The team is currently transitioning algorithms, example scripts, and documentation to the new code base throughout the subsequent minor releases leading up to Ray 3.0.
See here for more details on how to activate and use the new API stack.
Catalog (Alpha)#
Catalog is a utility abstraction that modularizes the construction of components for RLModules.
It includes information such as how input observation spaces should be encoded, what action distributions should be used, and so on.
Each RLModule has its own Catalog. For example, PPOTorchRLModule has the PPOCatalog.
To customize existing RLModules, either change the RLModule directly by inheriting from the class and changing its setup() method, or extend the Catalog class attributed to that RLModule.
Use Catalogs only if your customizations fit the abstractions that Catalog provides.
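For example, the first route (changing the RLModule directly) can look like the following minimal sketch. The subclass and the replacement head are hypothetical; the attributes self.catalog, self.pi, and catalog.latent_dims are the same ones the PPOTorchRLModule snippet later in this guide uses.

import torch.nn as nn

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule


class MyCustomPPOModule(PPOTorchRLModule):
    def setup(self):
        # Let the default setup build the encoder and heads through the Catalog ...
        super().setup()
        # ... then swap the policy head for a custom network (sizes made up here).
        self.pi = nn.Sequential(
            nn.Linear(self.catalog.latent_dims[-1], 64),
            nn.ReLU(),
            nn.Linear(64, self.catalog.action_space.n),
        )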
Note
Modifying Catalogs signifies advanced use cases, so you should only consider this if modifying an RLModule or writing one does not cover your use case.
We recommend modifying Catalogs only when making deeper customizations to the decision trees that determine which Model and Distribution RLlib creates by default.
Note
If you simply want to modify a Model by changing its default values, have a look at the model config dict.
While Catalogs have a base class Catalog, you mostly interact with Algorithm-specific Catalogs. Therefore, this doc also includes examples around PPO from which you can extrapolate to other algorithms. A prerequisite for this user guide is a rough understanding of RLModules. This user guide covers the following topics:
What are Catalogs
Catalog design and ideas
Catalog and AlgorithmConfig
Basic usage
Inject your custom models into RLModules
Inject your custom action distributions into RLModules
Write a Catalog from scratch
What are Catalogs#
Catalogs have two primary roles: choosing the right Model and choosing the right Distribution.
By default, all Catalogs implement decision trees that decide the model architecture based on a combination of input configurations.
These mainly include the observation space and action space of the RLModule, the model config dict, and the deep learning framework backend.
The following diagram shows the breakdown of the information flow towards models and distributions within an RLModule.
An RLModule creates an instance of the Catalog class it receives as part of its constructor.
It then creates its internal models and distributions with the help of this Catalog.
Note
You can also modify Model or Distribution in an RLModule directly by overriding the RLModule’s constructor!
The following diagram shows a concrete case in more detail.
Catalog design and ideas#
Since the main use cases for this component involve deep modifications of it, we explain the design and ideas behind Catalogs in this section.
What problems do Catalogs solve?#
RL algorithms need neural network models and distributions.
Within an algorithm, many different architectures for such sub-components are valid.
Moreover, models and distributions vary with environments.
However, most algorithms require models that have similarities.
The problem is finding sensible sub-components for a wide range of use cases while sharing this functionality across algorithms.
How do Catalogs solve this?#
As stated above, Catalogs implement decision trees for sub-components of RLModules.
Models and distributions from a Catalog object are meant to fit together.
Since we mostly build RLModules out of Encoders, Heads, and Distributions, Catalogs generally reflect this.
For example, the PPOCatalog will output Encoders that output a latent vector and two Heads that take this latent vector as input.
(That's why Catalogs have a latent_dims attribute.) Heads and distributions behave accordingly.
Whenever you create a Catalog, the decision tree is executed to find suitable configs for models and classes for distributions.
By default this happens in _get_encoder_config() and _get_dist_cls_from_action_space().
Whenever you build a model, the config is turned into a model.
Distributions are instantiated per forward pass of an RLModule and are therefore not built.
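The following sketch illustrates this lifecycle with PPO's Catalog; the random logits stand in for the output of an encoder and policy head and are only for illustration.

import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

env = gym.make("CartPole-v1")

# Creating the Catalog executes the decision tree and stores configs.
catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# Building turns the stored configs into actual models.
encoder = catalog.build_actor_critic_encoder(framework="torch")
# Distributions are only handed out as classes ...
action_dist_cls = catalog.get_action_dist_cls(framework="torch")
# ... and instantiated from model outputs on each forward pass.
dummy_logits = torch.randn(1, env.action_space.n)
action_dist = action_dist_cls.from_logits(dummy_logits)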
API philosophy#
Catalogs attempt to encapsulate most of the complexity around models inside the Encoder.
This means that recurrency, attention, and other special cases are fully handled inside the Encoder and are transparent to other components.
Encoders are the only components that the Catalog base class builds.
This is because many algorithms require custom heads and distributions, but most of them can use the same encoders.
The Catalog API is designed such that interaction usually happens in two stages:
Instantiate a Catalog. This executes the decision tree.
Generate an arbitrary number of decided components through Catalog methods.
The two default methods to access components on the base class are build_encoder() and get_action_dist_cls().
You can override these to quickly hack what models RLModules build.
Other methods are private and should only be overridden to make deep changes to the decision tree to enhance the capabilities of Catalogs.
Additionally, get_tokenizer_config() is a method that can be used when tokenization is required.
Tokenization means single-step embedding. Encoding also means embedding but can span multiple timesteps.
In fact, RLlib's tokenizers used in its recurrent Encoders (e.g. TorchLSTMEncoder) are instances of non-recurrent Encoder classes.
Catalog and AlgorithmConfig#
Since Catalogs effectively control what models and distributions RLlib uses under the hood, they are also part of RLlib's configurations.
As the primary entry point for configuring RLlib, AlgorithmConfig is the place where you can configure the Catalogs of the RLModules that are created.
You set the catalog class by going through the RLModuleSpec or MultiRLModuleSpec of an AlgorithmConfig.
For example, in heterogeneous multi-agent cases, you modify the MultiRLModuleSpec.
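For instance, a heterogeneous multi-agent setup with a different Catalog per module could look roughly like the following sketch; the module IDs, the policy mapping, and the rl_module_specs argument of MultiRLModuleSpec are assumptions made for this illustration.

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec


class LargeHeadsPPOCatalog(PPOCatalog):
    # Hypothetical Catalog customization for one group of agents.
    ...


config = (
    PPOConfig()
    .environment("CartPole-v1")  # stands in for a real multi-agent env
    .multi_agent(
        policies={"default_policy", "big_policy"},
        policy_mapping_fn=lambda agent_id, episode, **kw: (
            "big_policy" if agent_id == "agent_0" else "default_policy"
        ),
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs={
                # Each module can get its own Catalog class.
                "default_policy": RLModuleSpec(catalog_class=PPOCatalog),
                "big_policy": RLModuleSpec(catalog_class=LargeHeadsPPOCatalog),
            }
        ),
    )
)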
The following example shows how to configure the Catalog of an RLModule created by PPO.
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec


class MyPPOCatalog(PPOCatalog):
    def __init__(self, *args, **kwargs):
        print("Hi from within PPORLModule!")
        super().__init__(*args, **kwargs)


config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .framework("torch")
)

# Specify the catalog to use for the PPORLModule.
config = config.rl_module(rl_module_spec=RLModuleSpec(catalog_class=MyPPOCatalog))
# This is how RLlib constructs a PPORLModule.
# It will say "Hi from within PPORLModule!".
ppo = config.build()
Basic usage#
In the following three examples, we play with Catalogs to illustrate their API.
High-level API#
The first example showcases the general API for interacting with Catalogs.
import gymnasium as gym
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
env = gym.make("CartPole-v1")
catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")
Creating models and distributions#
The second example showcases how to use the base Catalog to create a model and an action distribution.
Besides these, we create a head network by hand that fits these two.
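A minimal sketch of this, assuming the base Catalog is constructed the same way as the PPOCatalog above; the plain torch.nn.Linear stands in for the hand-written head.

import gymnasium as gym
import torch

from ray.rllib.core.models.catalog import Catalog

env = gym.make("CartPole-v1")

catalog = Catalog(env.observation_space, env.action_space, model_config_dict={})
# Build the default encoder for CartPole's observation space.
encoder = catalog.build_encoder(framework="torch")
# Get the action distribution class that fits CartPole's action space.
action_dist_cls = catalog.get_action_dist_cls(framework="torch")

# Write a head by hand that maps the encoder's latent space to action logits,
# which the distribution class can then consume.
pi_head = torch.nn.Linear(catalog.latent_dims[-1], env.action_space.n)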
Creating models and distributions for PPO#
The third example showcases how to use the PPOCatalog to create an encoder and an action distribution.
This is more similar to what RLlib does internally.
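A minimal sketch of this, mirroring the calls that the PPOTorchRLModule performs internally (shown further below): instead of a single encoder and a hand-made head, the PPOCatalog builds an actor-critic encoder plus both heads.

import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

env = gym.make("CartPole-v1")

catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# The actor-critic encoder feeds both the policy- and the value-branch.
actor_critic_encoder = catalog.build_actor_critic_encoder(framework="torch")
# Both heads consume the encoder's latent output.
pi_head = catalog.build_pi_head(framework="torch")
vf_head = catalog.build_vf_head(framework="torch")
# The action distribution class creates distributions from pi_head outputs.
action_dist_cls = catalog.get_action_dist_cls(framework="torch")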
Note that the above two examples illustrate in principle what it takes to implement a Catalog.
In this case, we see the difference between Catalog and PPOCatalog.
In most cases, we can reuse the capabilities of the Catalog base class and only need to add methods to build head networks that we can then use in the appropriate RLModule.
Inject your custom model or action distributions into Catalogs#
You can make a Catalog build custom models by overriding the Catalog's methods that RLModules use to build models.
Have a look at these lines from the constructor of the PPOTorchRLModule to see how an RLModule uses its Catalog:
# If we have a stateful model, states for the critic need to be collected
# during sampling and `inference-only` needs to be `False`. Note, at this
# point the encoder is not built, yet and therefore `is_stateful()` does
# not work.
is_stateful = isinstance(
    self.catalog.actor_critic_encoder_config.base_encoder_config,
    RecurrentEncoderConfig,
)
if is_stateful:
    self.inference_only = False
# If this is an `inference_only` Module, we'll have to pass this information
# to the encoder config as well.
if self.inference_only and self.framework == "torch":
    self.catalog.actor_critic_encoder_config.inference_only = True

# Build models from catalog.
self.encoder = self.catalog.build_actor_critic_encoder(framework=self.framework)
self.pi = self.catalog.build_pi_head(framework=self.framework)
self.vf = self.catalog.build_vf_head(framework=self.framework)
Note that what happens inside the constructor of PPOTorchRLModule is similar to the earlier example Creating models and distributions for PPO.
Consequently, in order to build a custom Model compatible with a PPORLModule, you can override methods by inheriting from PPOCatalog, or write a Catalog that implements them from scratch.
The following examples showcase such modifications:
This example shows two modifications:
How to write a custom Encoder
How to inject the custom Encoder into a Catalog
Note that, if you only want to inject your Encoder into a single RLModule, the recommended workflow is to inherit from an existing RLModule and place the Encoder there.
import gymnasium as gym
import numpy as np

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples._old_api_stack.models.mobilenet_v2_encoder import (
    MobileNetV2EncoderConfig,
    MOBILENET_INPUT_SHAPE,
)
from ray.rllib.examples.envs.classes.random_env import RandomEnv


# Define a PPO Catalog that we can use to inject our MobileNetV2 Encoder into RLlib's
# decision tree of what model to choose
class MobileNetEnhancedPPOCatalog(PPOCatalog):
    @classmethod
    def _get_encoder_config(
        cls,
        observation_space: gym.Space,
        **kwargs,
    ):
        if (
            isinstance(observation_space, gym.spaces.Box)
            and observation_space.shape == MOBILENET_INPUT_SHAPE
        ):
            # Inject our custom encoder here, only if the observation space fits it
            return MobileNetV2EncoderConfig()
        else:
            return super()._get_encoder_config(observation_space, **kwargs)


# Create a generic config with our enhanced Catalog
ppo_config = (
    PPOConfig()
    .rl_module(rl_module_spec=RLModuleSpec(catalog_class=MobileNetEnhancedPPOCatalog))
    .env_runners(num_env_runners=0)
    # The following training settings make it so that a training iteration is very
    # quick. This is just for the sake of this example. PPO will not learn properly
    # with these settings!
    .training(train_batch_size_per_learner=32, minibatch_size=16, num_epochs=1)
)

# CartPole's observation space is not compatible with our MobileNetV2 Encoder, so
# this will use the default behaviour of Catalogs
ppo_config.environment("CartPole-v1")
results = ppo_config.build().train()
print(results)

# For this training, we use a RandomEnv with observations of shape
# MOBILENET_INPUT_SHAPE. This will use our custom Encoder.
ppo_config.environment(
    RandomEnv,
    env_config={
        "action_space": gym.spaces.Discrete(2),
        # Test a simple Image observation space.
        "observation_space": gym.spaces.Box(
            0.0,
            1.0,
            shape=MOBILENET_INPUT_SHAPE,
            dtype=np.float32,
        ),
    },
)
results = ppo_config.build().train()
print(results)
This example shows two modifications:
How to write a custom Distribution
How to inject the custom action distribution into a Catalog
import torch
import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.models.distributions import Distribution
from ray.rllib.models.torch.torch_distributions import TorchDeterministic


# Define a simple categorical distribution that can be used for PPO
class CustomTorchCategorical(Distribution):
    def __init__(self, logits):
        self.torch_dist = torch.distributions.categorical.Categorical(logits=logits)

    def sample(self, sample_shape=torch.Size(), **kwargs):
        return self.torch_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size(), **kwargs):
        return self.torch_dist.rsample(sample_shape)

    def logp(self, value, **kwargs):
        return self.torch_dist.log_prob(value)

    def entropy(self):
        return self.torch_dist.entropy()

    def kl(self, other, **kwargs):
        return torch.distributions.kl.kl_divergence(self.torch_dist, other.torch_dist)

    @staticmethod
    def required_input_dim(space, **kwargs):
        return int(space.n)

    @classmethod
    # This method is used to create distributions from logits inside RLModules.
    # You can use this to inject arguments into the constructor of this distribution
    # that are not the logits themselves.
    def from_logits(cls, logits):
        return CustomTorchCategorical(logits=logits)

    # This method is used to create a deterministic distribution for the
    # PPORLModule.forward_inference.
    def to_deterministic(self):
        return TorchDeterministic(loc=torch.argmax(self.torch_dist.logits, dim=-1))


# See if we can create this distribution and sample from it to interact with our
# target environment
env = gym.make("CartPole-v1")
dummy_logits = torch.randn([env.action_space.n])
dummy_dist = CustomTorchCategorical.from_logits(dummy_logits)
action = dummy_dist.sample()
env.reset()
env.step(action.numpy())
# Define a simple catalog that returns our custom distribution when
# get_action_dist_cls is called
class CustomPPOCatalog(PPOCatalog):
    def get_action_dist_cls(self, framework):
        # The distribution we wrote will only work with torch
        assert framework == "torch"
        return CustomTorchCategorical


# Train with our custom action distribution
algo = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(rl_module_spec=RLModuleSpec(catalog_class=CustomPPOCatalog))
    .build()
)
results = algo.train()
print(results)
These examples target PPO, but the workflows apply to all RLlib algorithms.
Note that PPO adds the ActorCriticEncoder (from ray.rllib.core.models.base) and two heads (a policy head and a value head) to the base class.
You can override these similarly to the above.
Other algorithms may add different sub-components or override default ones.
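For example, a deeper value function head could be injected by overriding build_vf_head() and reusing the MLPHeadConfig that the PPOCatalog implementation below also uses; the layer sizes in this sketch are made up for illustration.

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.models.configs import MLPHeadConfig


class DeepVfPPOCatalog(PPOCatalog):
    def build_vf_head(self, framework: str):
        # Build a deeper value head than the default, still taking the
        # encoder's latent output as input and producing a single value.
        vf_head_config = MLPHeadConfig(
            input_dims=self.latent_dims,
            hidden_layer_dims=[256, 256, 256],
            hidden_layer_activation="relu",
            output_layer_activation="linear",
            output_layer_dim=1,
        )
        return vf_head_config.build(framework=framework)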
Write a Catalog from scratch#
You only need this when you want to write a new Algorithm under RLlib. Note that writing an Algorithm does not strictly require writing a new Catalog, but you can use Catalogs as a tool to create the fitting default sub-components, such as models or distributions. The following are typical requirements and steps for writing a new Catalog (a bare-bones skeleton sketch follows this list):
Does the Algorithm need a special Encoder? Override _get_encoder_config().
Does the Algorithm need an additional network? Write a method to build it. You can use RLlib's model configurations to build models from dimensions.
Does the Algorithm need a custom distribution? Override _get_dist_cls_from_action_space().
Does the Algorithm need a special tokenizer? Override get_tokenizer_config().
Does the Algorithm not need an Encoder at all? Override _determine_components_hook().
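A hypothetical, bare-bones skeleton following these steps might look like the sketch below; the method signatures other than _get_encoder_config() are assumptions, so check the base Catalog for the exact ones.

from ray.rllib.core.models.catalog import Catalog


class MyAlgoCatalog(Catalog):
    # Step 1: special Encoder -> return a custom encoder config here.
    @classmethod
    def _get_encoder_config(cls, observation_space, **kwargs):
        return super()._get_encoder_config(observation_space, **kwargs)

    # Step 2: additional network -> add a build method for it, typically using
    # one of RLlib's model configs together with self.latent_dims.
    def build_my_extra_head(self, framework: str):
        ...

    # Step 3: custom distribution -> return its class from the corresponding hook
    # (signature assumed here). Steps 4 and 5 work analogously via
    # get_tokenizer_config() and _determine_components_hook().
    @classmethod
    def _get_dist_cls_from_action_space(cls, action_space, **kwargs):
        return super()._get_dist_cls_from_action_space(action_space, **kwargs)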
The following example shows the implementation of a Catalog for the PPO algorithm, based on the preceding steps:
Catalog for PPORLModules
import gymnasium as gym

from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.models.configs import (
    ActorCriticEncoderConfig,
    MLPHeadConfig,
    FreeLogStdMLPHeadConfig,
)
from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model
from ray.rllib.utils import override
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic


def _check_if_diag_gaussian(action_distribution_cls, framework, no_error=False):
    if framework == "torch":
        from ray.rllib.models.torch.torch_distributions import TorchDiagGaussian

        is_diag_gaussian = issubclass(action_distribution_cls, TorchDiagGaussian)
        if no_error:
            return is_diag_gaussian
        else:
            assert is_diag_gaussian, (
                f"free_log_std is only supported for DiagGaussian action "
                f"distributions. Found action distribution: {action_distribution_cls}."
            )
    elif framework == "tf2":
        from ray.rllib.models.tf.tf_distributions import TfDiagGaussian

        is_diag_gaussian = issubclass(action_distribution_cls, TfDiagGaussian)
        if no_error:
            return is_diag_gaussian
        else:
            assert is_diag_gaussian, (
                "free_log_std is only supported for DiagGaussian action distributions. "
                "Found action distribution: {}.".format(action_distribution_cls)
            )
    else:
        raise ValueError(f"Framework {framework} not supported for free_log_std.")
class PPOCatalog(Catalog):
    """The Catalog class used to build models for PPO.

    PPOCatalog provides the following models:
        - ActorCriticEncoder: The encoder used to encode the observations.
        - Pi Head: The head used to compute the policy logits.
        - Value Function Head: The head used to compute the value function.

    The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
    for the policy and value function. See implementations of PPORLModule for
    more details.

    Any custom ActorCriticEncoder can be built by overriding the
    build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig
    at PPOCatalog.actor_critic_encoder_config can be overridden to build a custom
    ActorCriticEncoder during RLModule runtime.

    Any custom head can be built by overriding the build_pi_head() and build_vf_head()
    methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to
    build custom heads during RLModule runtime.

    Any module built for exploration or inference is built with the flag
    `inference_only=True` and does not contain a value network. This flag can be set
    in the `SingleAgentModuleSpec` through the `inference_only` boolean flag.
    In case that the actor-critic-encoder is not shared between the policy and value
    function, the inference-only module will contain only the actor encoder network.
    """
    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
    ):
        """Initializes the PPOCatalog.

        Args:
            observation_space: The observation space of the Encoder.
            action_space: The action space for the Pi Head.
            model_config_dict: The model config to use.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config_dict=model_config_dict,
        )

        # Replace EncoderConfig by ActorCriticEncoderConfig
        self.actor_critic_encoder_config = ActorCriticEncoderConfig(
            base_encoder_config=self._encoder_config,
            shared=self._model_config_dict["vf_share_layers"],
        )

        self.pi_and_vf_head_hiddens = self._model_config_dict["head_fcnet_hiddens"]
        self.pi_and_vf_head_activation = self._model_config_dict[
            "head_fcnet_activation"
        ]

        # We don't have the exact (framework specific) action dist class yet and thus
        # cannot determine the exact number of output nodes (action space) required.
        # -> Build pi config only in the `self.build_pi_head` method.
        self.pi_head_config = None

        self.vf_head_config = MLPHeadConfig(
            input_dims=self.latent_dims,
            hidden_layer_dims=self.pi_and_vf_head_hiddens,
            hidden_layer_activation=self.pi_and_vf_head_activation,
            output_layer_activation="linear",
            output_layer_dim=1,
        )
    @OverrideToImplementCustomLogic
    def build_actor_critic_encoder(self, framework: str) -> ActorCriticEncoder:
        """Builds the ActorCriticEncoder.

        The default behavior is to build the encoder from the encoder_config.
        This can be overridden to build a custom ActorCriticEncoder as a means of
        configuring the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The ActorCriticEncoder.
        """
        return self.actor_critic_encoder_config.build(framework=framework)

    @override(Catalog)
    def build_encoder(self, framework: str) -> Encoder:
        """Builds the encoder.

        Since PPO uses an ActorCriticEncoder, this method should not be implemented.
        """
        raise NotImplementedError(
            "Use PPOCatalog.build_actor_critic_encoder() instead for PPO."
        )
    @OverrideToImplementCustomLogic
    def build_pi_head(self, framework: str) -> Model:
        """Builds the policy head.

        The default behavior is to build the head from the pi_head_config.
        This can be overridden to build a custom policy head as a means of configuring
        the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The policy head.
        """
        # Get action_distribution_cls to find out about the output dimension for pi_head
        action_distribution_cls = self.get_action_dist_cls(framework=framework)
        if self._model_config_dict["free_log_std"]:
            _check_if_diag_gaussian(
                action_distribution_cls=action_distribution_cls, framework=framework
            )
            is_diag_gaussian = True
        else:
            is_diag_gaussian = _check_if_diag_gaussian(
                action_distribution_cls=action_distribution_cls,
                framework=framework,
                no_error=True,
            )
        required_output_dim = action_distribution_cls.required_input_dim(
            space=self.action_space, model_config=self._model_config_dict
        )
        # Now that we have the action dist class and number of outputs, we can define
        # our pi-config and build the pi head.
        pi_head_config_class = (
            FreeLogStdMLPHeadConfig
            if self._model_config_dict["free_log_std"]
            else MLPHeadConfig
        )
        self.pi_head_config = pi_head_config_class(
            input_dims=self.latent_dims,
            hidden_layer_dims=self.pi_and_vf_head_hiddens,
            hidden_layer_activation=self.pi_and_vf_head_activation,
            output_layer_dim=required_output_dim,
            output_layer_activation="linear",
            clip_log_std=is_diag_gaussian,
            log_std_clip_param=self._model_config_dict.get("log_std_clip_param", 20),
        )

        return self.pi_head_config.build(framework=framework)
    @OverrideToImplementCustomLogic
    def build_vf_head(self, framework: str) -> Model:
        """Builds the value function head.

        The default behavior is to build the head from the vf_head_config.
        This can be overridden to build a custom value function head as a means of
        configuring the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The value function head.
        """
        return self.vf_head_config.build(framework=framework)