Note

Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The Ray Team plans to transition algorithms, example scripts, and documentation to the new code base thereby incrementally replacing the “old API stack” (e.g., ModelV2, Policy, RolloutWorker) throughout the subsequent minor releases leading up to Ray 3.0.

Note, however, that so far only PPO (single- and multi-agent) and SAC (single-agent only) support the “new API stack” and continue to run by default with the old APIs. You can continue to use the existing custom (old stack) classes.

See here for more details on how to use the new API stack.

Note

This doc is related to RLlib’s new API stack and therefore experimental.

Catalog (Alpha)#

Catalog is a utility abstraction that modularizes the construction of components for RLModules. It includes information such as how input observation spaces should be encoded, what action distributions should be used, and so on. Each RLModule has its own Catalog. For example, PPOTorchRLModule has the PPOCatalog. To customize an existing RLModule, either change the RLModule directly by inheriting the class and changing its setup() method or, alternatively, extend the Catalog class attributed to that RLModule. Use Catalogs only if your customization fits the abstractions provided by Catalog.

Note

Modifying Catalogs is an advanced use case, so only consider it if modifying an RLModule or writing a new one does not cover your needs. We recommend modifying Catalogs only when making deeper customizations to the decision trees that determine what Model and Distribution RLlib creates by default.

Note

If you simply want to modify a Model by changing its default values, have a look at the model config dict:

MODEL_DEFAULTS

This dict (or an overriding sub-set) is part of AlgorithmConfig and therefore also part of any algorithm-specific config. To change the behavior of RLlib's default models, override it and pass it to an AlgorithmConfig.

MODEL_DEFAULTS: ModelConfigDict = {
    # Experimental flag.
    # If True, user specified no preprocessor to be created
    # (via config._disable_preprocessor_api=True). If True, observations
    # will arrive in model as they are returned by the env.
    "_disable_preprocessor_api": False,
    # Experimental flag.
    # If True, RLlib will no longer flatten the policy-computed actions into
    # a single tensor (for storage in SampleCollectors/output files/etc..),
    # but leave (possibly nested) actions as-is. Disabling flattening affects:
    # - SampleCollectors: Have to store possibly nested action structs.
    # - Models that have the previous action(s) as part of their input.
    # - Algorithms reading from offline files (incl. action information).
    "_disable_action_flattening": False,

    # === Built-in options ===
    # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
    # These are used if no custom model is specified and the input space is 1D.
    # Number of hidden layers to be used.
    "fcnet_hiddens": [256, 256],
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu", which is the same),
    # "linear" (or None).
    "fcnet_activation": "tanh",
    # Initializer function or class descriptor for encoder weights.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "fcnet_weights_initializer": None,
    # Initializer configuration for encoder weights.
    # This configuration is passed to the initializer defined in
    # `fcnet_weights_initializer`.
    "fcnet_weights_initializer_config": None,
    # Initializer function or class descriptor for encoder bias.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "fcnet_bias_initializer": None,
    # Initializer configuration for encoder bias.
    # This configuration is passed to the initializer defined in
    # `fcnet_bias_initializer`.
    "fcnet_bias_initializer_config": None,

    # VisionNetwork (tf and torch): rllib.models.tf|torch.visionnet.py
    # These are used if no custom model is specified and the input space is 2D.
    # Filter config: List of [out_channels, kernel, stride] for each filter.
    # Example: [[16, [8, 8], 4], [32, [4, 4], 2], [256, [11, 11], 1]] for 84x84 inputs.
    # Use None for making RLlib try to find a default filter setup given the
    # observation space.
    "conv_filters": None,
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu", which is the same),
    # "linear" (or None).
    "conv_activation": "relu",
    # Initializer function or class descriptor for CNN encoder kernel.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "conv_kernel_initializer": None,
    # Initializer configuration for CNN encoder kernel.
    # This configuration is passed to the initializer defined in
    # `conv_kernel_initializer`.
    "conv_kernel_initializer_config": None,
    # Initializer function or class descriptor for CNN encoder bias.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "conv_bias_initializer": None,
    # Initializer configuration for CNN encoder bias.
    # This configuration is passed to the initializer defined in
    # `conv_bias_initializer`.
    "conv_bias_initializer_config": None,
    # Initializer function or class descriptor for CNN head (pi, Q, V) kernel.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "conv_transpose_kernel_initializer": None,
    # Initializer configuration for CNN head (pi, Q, V) kernel.
    # This configuration is passed to the initializer defined in
    # `conv_transpose_kernel_initializer`.
    "conv_transpose_kernel_initializer_config": None,
    # Initializer function or class descriptor for CNN head (pi, Q, V) bias.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "conv_transpose_bias_initializer": None,
    # Initializer configuration for CNN head (pi, Q, V) bias.
    # This configuration is passed to the initializer defined in
    # `conv_transpose_bias_initializer`.
    "conv_transpose_bias_initializer_config": None,

    # Some default models support a final FC stack of n Dense layers with given
    # activation:
    # - Complex observation spaces: Image components are fed through
    #   VisionNets, flat Boxes are left as-is, Discrete are one-hot'd, then
    #   everything is concatenated and pushed through this final FC stack.
    # - VisionNets (CNNs), e.g. after the CNN stack, there may be
    #   additional Dense layers.
    # - FullyConnectedNetworks will have this additional FCStack as well
    # (that's why it's empty by default).
    "post_fcnet_hiddens": [],
    "post_fcnet_activation": "relu",
    # Initializer function or class descriptor for head (pi, Q, V) weights.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "post_fcnet_weights_initializer": None,
    # Initializer configuration for head (pi, Q, V) weights.
    # This configuration is passed to the initializer defined in
    # `post_fcnet_weights_initializer`.
    "post_fcnet_weights_initializer_config": None,
    # Initializer function or class descriptor for head (pi, Q, V) bias.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "post_fcnet_bias_initializer": None,
    # Initializer configuration for head (pi, Q, V) bias.
    # This configuration is passed to the initializer defined in
    # `post_fcnet_bias_initializer`.
    "post_fcnet_bias_initializer_config": None,

    # For DiagGaussian action distributions, make the second half of the model
    # outputs floating bias variables instead of state-dependent. This only
    # has an effect when using the default fully connected net.
    "free_log_std": False,
    # Whether to skip the final linear layer used to resize the hidden layer
    # outputs to size `num_outputs`. If True, then the last hidden layer
    # should already match num_outputs.
    "no_final_linear": False,
    # Whether layers should be shared for the value function.
    "vf_share_layers": True,

    # == LSTM ==
    # Whether to wrap the model with an LSTM.
    "use_lstm": False,
    # Max seq len for training the LSTM, defaults to 20.
    "max_seq_len": 20,
    # Size of the LSTM cell.
    "lstm_cell_size": 256,
    # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
    "lstm_use_prev_action": False,
    # Whether to feed r_{t-1} to LSTM.
    "lstm_use_prev_reward": False,
    # Initializer function or class descriptor for LSTM weights.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "lstm_weights_initializer": None,
    # Initializer configuration for LSTM weights.
    # This configuration is passed to the initializer defined in
    # `lstm_weights_initializer`.
    "lstm_weights_initializer_config": None,
    # Initializer function or class descriptor for LSTM bias.
    # Supported values are the initializer names (str), classes or functions listed
    # by the frameworks (`tf2`, `torch`). See
    # https://pytorch.org/docs/stable/nn.init.html for `torch` and
    # https://www.tensorflow.org/api_docs/python/tf/keras/initializers for `tf2`.
    # Note, if `None`, the default initializer defined by `torch` or `tf2` is used.
    "lstm_bias_initializer": None,
    # Initializer configuration for LSTM bias.
    # This configuration is passed to the initializer defined in
    # `lstm_bias_initializer`.
    "lstm_bias_initializer_config": None,
    # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
    "_time_major": False,

    # == Attention Nets (experimental: torch-version is untested) ==
    # Whether to use a GTrXL ("Gru transformer XL"; attention net) as the
    # wrapper Model around the default Model.
    "use_attention": False,
    # The number of transformer units within GTrXL.
    # A transformer unit in GTrXL consists of a) MultiHeadAttention module and
    # b) a position-wise MLP.
    "attention_num_transformer_units": 1,
    # The input and output size of each transformer unit.
    "attention_dim": 64,
    # The number of attention heads within the MultiHeadAttention units.
    "attention_num_heads": 1,
    # The dim of a single head (within the MultiHeadAttention units).
    "attention_head_dim": 32,
    # The memory sizes for inference and training.
    "attention_memory_inference": 50,
    "attention_memory_training": 50,
    # The output dim of the position-wise MLP.
    "attention_position_wise_mlp_dim": 32,
    # The initial bias values for the 2 GRU gates within a transformer unit.
    "attention_init_gru_gate_bias": 2.0,
    # Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete).
    "attention_use_n_prev_actions": 0,
    # Whether to feed r_{t-n:t-1} to GTrXL.
    "attention_use_n_prev_rewards": 0,

    # == Atari ==
    # Set to True to enable 4x stacking behavior.
    "framestack": True,
    # Final resized frame dimension
    "dim": 84,
    # (deprecated) Converts ATARI frame to 1 Channel Grayscale image
    "grayscale": False,
    # (deprecated) Changes frame to range from [-1, 1] if true
    "zero_mean": True,

    # === Options for custom models ===
    # Name of a custom model to use
    "custom_model": None,
    # Extra options to pass to the custom classes. These will be available to
    # the Model's constructor in the model_config field. Also, they will be
    # attempted to be passed as **kwargs to ModelV2 models. For an example,
    # see rllib/models/[tf|torch]/attention_net.py.
    "custom_model_config": {},
    # Name of a custom action distribution to use.
    "custom_action_dist": None,
    # Custom preprocessors are deprecated. Please use a wrapper class around
    # your environment instead to preprocess observations.
    "custom_preprocessor": None,

    # === Options for ModelConfigs in RLModules ===
    # The latent dimension to encode into.
    # Since most RLModules have an encoder and heads, this establishes an agreement
    # on the dimensionality of the latent space they share.
    # This has no effect for models outside RLModule.
    # If None, model_config['fcnet_hiddens'][-1] value will be used to guarantee
    # backward compatibility to old configs. This yields different models than past
    # versions of RLlib.
    "encoder_latent_dim": None,
    # Whether to always check the inputs and outputs of RLlib's default models for
    # their specifications. Input specifications are checked on failed forward passes
    # of the models regardless of this flag. If this flag is set to `True`, inputs and
    # outputs are checked on every call. This leads to a slow-down and should only be
    # used for debugging. Note that this flag is only relevant for instances of
    # RLlib's Model class. These are commonly generated from ModelConfigs in RLModules.
    "always_check_shapes": False,

    # Deprecated keys:
    # Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
    "lstm_use_prev_action_reward": DEPRECATED_VALUE,
    # Deprecated in anticipation of RLModules API
    "_use_default_native_models": DEPRECATED_VALUE,

}
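
For example, a minimal sketch of overriding a few of these defaults by passing a partial model dict to an AlgorithmConfig through its training() method (the values here are arbitrary):

from ray.rllib.algorithms.ppo import PPOConfig

# Only the specified keys are overridden; all other keys keep their
# MODEL_DEFAULTS values.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(model={"fcnet_hiddens": [64, 64], "fcnet_activation": "relu"})
)
algo = config.build()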

While Catalogs have a base class Catalog, you mostly interact with Algorithm-specific Catalogs. Therefore, this doc also includes examples around PPO from which you can extrapolate to other algorithms. A prerequisite for this user guide is a rough understanding of RLModules. This user guide covers the following topics:

  • What are Catalogs

  • Catalog design and ideas

  • Catalog and AlgorithmConfig

  • Basic usage

  • Inject your custom models into RLModules

  • Inject your custom action distributions into RLModules

  • Write a Catalog from scratch

What are Catalogs#

Catalogs have two primary roles: Choosing the right Model and choosing the right Distribution. By default, all catalogs implement decision trees that decide model architecture based on a combination of input configurations. These mainly include the observation space and action space of the RLModule, the model config dict and the deep learning framework backend.
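
For example, a minimal sketch (spaces and config values chosen only for illustration) of how the same base Catalog class settles on different encoders depending on the observation space it is constructed with:

import gymnasium as gym
import numpy as np

from ray.rllib.core.models.catalog import Catalog

# A flat Box observation space: the decision tree settles on an MLP-style encoder.
flat_catalog = Catalog(
    observation_space=gym.spaces.Box(-1.0, 1.0, shape=(4,), dtype=np.float32),
    action_space=gym.spaces.Discrete(2),
    model_config_dict={},
)
print(type(flat_catalog.build_encoder(framework="torch")))

# An image-like Box observation space: the decision tree settles on a CNN-style
# encoder instead (conv_filters=None makes RLlib look up default filters).
image_catalog = Catalog(
    observation_space=gym.spaces.Box(0.0, 1.0, shape=(84, 84, 3), dtype=np.float32),
    action_space=gym.spaces.Discrete(2),
    model_config_dict={"conv_filters": None},
)
print(type(image_catalog.build_encoder(framework="torch")))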

The following diagram shows the breakdown of the information flow towards models and distributions within an RLModule. An RLModule creates an instance of the Catalog class it receives as part of its constructor. It then creates its internal models and distributions with the help of this Catalog.

Note

You can also modify Model or Distribution in an RLModule directly by overriding the RLModule’s constructor!
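
For example, a minimal sketch of this alternative (the class name and layer sizes are made up; it assumes a discrete action space and that PPOTorchRLModule builds its policy head as self.pi during setup(), as shown later in this doc):

import torch

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule


class TinyPiHeadPPORLModule(PPOTorchRLModule):
    def setup(self):
        # Let the default setup build the encoder, heads, and action dist class.
        super().setup()
        # Then replace the policy head with a hand-written one. The encoder's
        # output size comes from the Catalog's latent_dims.
        latent_dim = self.config.get_catalog().latent_dims[0]
        self.pi = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, self.config.action_space.n),
        )

You would then pass this class to RLlib through SingleAgentRLModuleSpec(module_class=TinyPiHeadPPORLModule) instead of touching the Catalog.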

../_images/catalog_and_rlm_diagram.svg

The following diagram shows a concrete case in more detail.

Example of catalog in a PPORLModule

The PPOCatalog is fed an observation space, an action space, a model config dict, and the view requirements of the RLModule. The model config dict and the view requirements are only of interest in special cases, such as recurrent networks or attention networks. A PPORLModule has four components that are created by the PPOCatalog: Encoder, value function head, policy head, and action distribution.

../_images/ppo_catalog_and_rlm_diagram.svg

Catalog design and ideas#

Since the main use cases for this component involve deep modifications of it, we explain the design and ideas behind Catalogs in this section.

What problems do Catalogs solve?#

RL algorithms need neural network models and distributions. Within an algorithm, many different architectures for such sub-components are valid. Moreover, models and distributions vary with environments. However, most algorithms require models that share structural similarities. The problem is finding sensible sub-components for a wide range of use cases while sharing this functionality across algorithms.

How do Catalogs solve this?#

As stated above, Catalogs implement decision trees for sub-components of RLModules. Models and distributions from a Catalog object are meant to fit together. Since we mostly build RLModules out of Encoders, Heads, and Distributions, Catalogs also generally reflect this. For example, the PPOCatalog will output Encoders that output a latent vector and two Heads that take this latent vector as input. (That's why Catalogs have a latent_dims attribute.) Heads and distributions behave accordingly. Whenever you create a Catalog, the decision tree is executed to find suitable configs for models and classes for distributions. By default this happens in get_encoder_config() and _get_dist_cls_from_action_space(). Whenever you build a model, the config is turned into a model. Distributions are instantiated per forward pass of an RLModule and are therefore not built.

API philosophy#

Catalogs attempt to encapsulate most of the complexity around models inside the Encoder. This means that recurrency, attention, and other special cases are handled entirely inside the Encoder and are transparent to other components. Encoders are the only components that the Catalog base class builds. This is because many algorithms require custom heads and distributions, but most of them can use the same encoders. The Catalog API is designed such that interaction usually happens in two stages:

  • Instantiate a Catalog. This executes the decision tree.

  • Generate any number of the decided components through Catalog methods.

The two default methods to access components on the base class are build_encoder() and get_action_dist_cls().

You can override these to quickly hack what models RLModules build. Other methods are private and should only be overridden to make deep changes to the decision tree that enhance the capabilities of Catalogs. Additionally, get_tokenizer_config() is a method that can be used when tokenization is required. Tokenization means single-step embedding, while encoding also means embedding but can span multiple timesteps. In fact, RLlib's tokenizers used in its recurrent Encoders (e.g. TorchLSTMEncoder) are instances of non-recurrent Encoder classes.

Catalog and AlgorithmConfig#

Since Catalogs effectively control what models and distributions RLlib uses under the hood, they are also part of RLlib’s configurations. As the primary entry point for configuring RLlib, AlgorithmConfig is the place where you can configure the Catalogs of the RLModules that are created. You set the catalog class by going through the SingleAgentRLModuleSpec or MultiAgentRLModuleSpec of an AlgorithmConfig. For example, in heterogeneous multi-agent cases, you modify the MultiAgentRLModuleSpec.

../_images/catalog_rlmspecs_diagram.svg

The following example shows how to configure the Catalog of an RLModule created by PPO.

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


class MyPPOCatalog(PPOCatalog):
    def __init__(self, *args, **kwargs):
        print("Hi from within PPORLModule!")
        super().__init__(*args, **kwargs)


config = (
    PPOConfig()
    .experimental(_enable_new_api_stack=True)
    .environment("CartPole-v1")
    .framework("torch")
)

# Specify the catalog to use for the PPORLModule.
config = config.rl_module(
    rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyPPOCatalog)
)
# This is how RLlib constructs a PPORLModule
# It will say "Hi from within PPORLModule!".
ppo = config.build()
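
For heterogeneous multi-agent setups, you go through the MultiAgentRLModuleSpec instead. A minimal sketch (module IDs are made up; the required multi-agent environment and policy-mapping setup is omitted), reusing MyPPOCatalog and the imports from the example above:

from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec

# Assign a (possibly different) catalog class per module ID.
marl_spec = MultiAgentRLModuleSpec(
    module_specs={
        "agent_a": SingleAgentRLModuleSpec(catalog_class=MyPPOCatalog),
        "agent_b": SingleAgentRLModuleSpec(catalog_class=PPOCatalog),
    }
)
config = config.rl_module(rl_module_spec=marl_spec)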

Basic usage#

In the following three examples, we play with Catalogs to illustrate their API.

High-level API#

The first example showcases the general API for interacting with Catalogs.

import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

env = gym.make("CartPole-v1")

catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")

Creating models and distributions#

The second example showcases how to use the base Catalog to create an encoder and an action distribution. Besides these, we create a head network by hand that fits the two together.

Customize a policy head
import gymnasium as gym
import torch

# ENCODER_OUT is a constant we use to enumerate Encoder I/O.
from ray.rllib.core.columns import Columns
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.models.catalog import Catalog

env = gym.make("CartPole-v1")

catalog = Catalog(env.observation_space, env.action_space, model_config_dict={})
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")

# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_encoder(framework="torch")
# Build a suitable head model for the action distribution.
# We need `env.action_space.n` action distribution inputs.
head = torch.nn.Linear(catalog.latent_dims[0], env.action_space.n)
# Now we are ready to interact with the environment
obs, info = env.reset()
# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {Columns.OBS: torch.Tensor([obs])}
# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT]
action_dist_inputs = head(encoding)
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()
env.step(actions[0])

Creating models and distributions for PPO#

The third example showcases how to use the PPOCatalog to create an encoder and an action distribution. This is closer to what RLlib does internally.

Use catalog-generated models
import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

# ENCODER_OUT and ACTOR are constants we use to enumerate Encoder I/O.
from ray.rllib.core.models.base import ENCODER_OUT, ACTOR
from ray.rllib.policy.sample_batch import SampleBatch

env = gym.make("CartPole-v1")

catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")

# Now we are ready to interact with the environment
obs, info = env.reset()
# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {SampleBatch.OBS: torch.Tensor([obs])}
# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT][ACTOR]
action_dist_inputs = policy_head(encoding)
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()
env.step(actions[0])

Note that the above two examples illustrate in principle what it takes to implement a Catalog. In this case, we see the difference between Catalog and PPOCatalog. In most cases, we can reuse the capabilities of the base Catalog class and only need to add methods to build head networks that we can then use in the appropriate RLModule.

Inject your custom model or action distributions into Catalogs#

You can make a Catalog build custom models by overriding the Catalog’s methods used by RLModules to build models. Have a look at these lines from the constructor of the PPOTorchRLModule to see how Catalogs are being used by an RLModule:

        catalog = self.config.get_catalog()

        # Build models from catalog
        self.encoder = catalog.build_actor_critic_encoder(framework=self.framework)
        self.pi = catalog.build_pi_head(framework=self.framework)
        self.vf = catalog.build_vf_head(framework=self.framework)

        self.action_dist_cls = catalog.get_action_dist_cls(framework=self.framework)

Note that what happens inside the constructor of PPOTorchRLModule is similar to the earlier example Creating models and distributions for PPO.

Consequently, in order to build a custom Model compatible with a PPORLModule, you can override methods by inheriting from PPOCatalog or write a Catalog that implements them from scratch. The following examples showcase such modifications:

This example shows two modifications:

  • How to write a custom Encoder

  • How to inject the custom Encoder into a Catalog

Note that, if you only want to inject your Encoder into a single RLModule, the recommended workflow is to inherit from an existing RLModule and place the Encoder there.

import gymnasium as gym
import numpy as np

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.examples._old_api_stack.models.mobilenet_v2_encoder import (
    MobileNetV2EncoderConfig,
    MOBILENET_INPUT_SHAPE,
)
from ray.rllib.examples.envs.classes.random_env import RandomEnv


# Define a PPO Catalog that we can use to inject our MobileNetV2 Encoder into RLlib's
# decision tree of what model to choose
class MobileNetEnhancedPPOCatalog(PPOCatalog):
    @classmethod
    def _get_encoder_config(
        cls,
        observation_space: gym.Space,
        **kwargs,
    ):
        if (
            isinstance(observation_space, gym.spaces.Box)
            and observation_space.shape == MOBILENET_INPUT_SHAPE
        ):
            # Inject our custom encoder here, only if the observation space fits it
            return MobileNetV2EncoderConfig()
        else:
            return super()._get_encoder_config(observation_space, **kwargs)


# Create a generic config with our enhanced Catalog
ppo_config = (
    PPOConfig()
    .experimental(_enable_new_api_stack=True)
    .rl_module(
        rl_module_spec=SingleAgentRLModuleSpec(
            catalog_class=MobileNetEnhancedPPOCatalog
        )
    )
    .rollouts(num_rollout_workers=0)
    # The following training settings make it so that a training iteration is very
    # quick. This is just for the sake of this example. PPO will not learn properly
    # with these settings!
    .training(train_batch_size=32, sgd_minibatch_size=16, num_sgd_iter=1)
)

# CartPole's observation space is not compatible with our MobileNetV2 Encoder, so
# this will use the default behaviour of Catalogs
ppo_config.environment("CartPole-v1")
results = ppo_config.build().train()
print(results)

# For this training, we use a RandomEnv with observations of shape
# MOBILENET_INPUT_SHAPE. This will use our custom Encoder.
ppo_config.environment(
    RandomEnv,
    env_config={
        "action_space": gym.spaces.Discrete(2),
        # Test a simple Image observation space.
        "observation_space": gym.spaces.Box(
            0.0,
            1.0,
            shape=MOBILENET_INPUT_SHAPE,
            dtype=np.float32,
        ),
    },
)
results = ppo_config.build().train()
print(results)

This example shows two modifications:

  • How to write a custom Distribution

  • How to inject the custom action distribution into a Catalog

import torch
import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.models.distributions import Distribution
from ray.rllib.models.torch.torch_distributions import TorchDeterministic


# Define a simple categorical distribution that can be used for PPO
class CustomTorchCategorical(Distribution):
    def __init__(self, logits):
        self.torch_dist = torch.distributions.categorical.Categorical(logits=logits)

    def sample(self, sample_shape=torch.Size(), **kwargs):
        return self.torch_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size(), **kwargs):
        return self.torch_dist.rsample(sample_shape)

    def logp(self, value, **kwargs):
        return self.torch_dist.log_prob(value)

    def entropy(self):
        return self.torch_dist.entropy()

    def kl(self, other, **kwargs):
        return torch.distributions.kl.kl_divergence(self.torch_dist, other.torch_dist)

    @staticmethod
    def required_input_dim(space, **kwargs):
        return int(space.n)

    @classmethod
    # This method is used to create distributions from logits inside RLModules.
    # You can use this to inject arguments into the constructor of this distribution
    # that are not the logits themselves.
    def from_logits(cls, logits):
        return CustomTorchCategorical(logits=logits)

    # This method is used to create a deterministic distribution for the
    # PPORLModule.forward_inference.
    def to_deterministic(self):
        return TorchDeterministic(loc=torch.argmax(self.torch_dist.logits, dim=-1))


# See if we can create this distribution and sample from it to interact with our
# target environment
env = gym.make("CartPole-v1")
dummy_logits = torch.randn([env.action_space.n])
dummy_dist = CustomTorchCategorical.from_logits(dummy_logits)
action = dummy_dist.sample()
env = gym.make("CartPole-v1")
env.reset()
env.step(action.numpy())


# Define a simple catalog that returns our custom distribution when
# get_action_dist_cls is called
class CustomPPOCatalog(PPOCatalog):
    def get_action_dist_cls(self, framework):
        # The distribution we wrote will only work with torch
        assert framework == "torch"
        return CustomTorchCategorical


# Train with our custom action distribution
algo = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(rl_module_spec=SingleAgentRLModuleSpec(catalog_class=CustomPPOCatalog))
    .build()
)
results = algo.train()
print(results)

These examples target PPO, but the workflows apply to all RLlib algorithms. Note that PPO adds the ActorCriticEncoder and two heads (a policy head and a value function head) to the base class. You can override these similarly to the above. Other algorithms may add different sub-components or override the default ones.
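
As a sketch of such an override (the class name and hidden sizes are made up), the value function head of PPO can be swapped for a wider MLP by rebuilding its config before building it, mirroring what the default PPOCatalog does:

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.models.configs import MLPHeadConfig


class WideVfHeadPPOCatalog(PPOCatalog):
    def build_vf_head(self, framework: str):
        # Same construction as the default PPOCatalog, but with wider hidden layers.
        self.vf_head_config = MLPHeadConfig(
            input_dims=self.latent_dims,
            hidden_layer_dims=[512, 512],
            hidden_layer_activation="relu",
            output_layer_activation="linear",
            output_layer_dim=1,
        )
        return self.vf_head_config.build(framework=framework)

As before, you pass this class to RLlib through SingleAgentRLModuleSpec(catalog_class=WideVfHeadPPOCatalog).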

Write a Catalog from scratch#

You only need this when you want to write a new Algorithm with RLlib. Note that writing an Algorithm does not strictly require writing a new Catalog, but you can use Catalogs as a tool to create the fitting default sub-components, such as models or distributions. Writing a new Catalog typically comes down to deciding which sub-components (encoder, heads, distributions) the Algorithm needs and adding or overriding the corresponding Catalog methods and configs.

The following example shows RLlib's implementation of the Catalog for PPO, which follows this pattern:

Catalog for PPORLModules
import gymnasium as gym

from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.models.configs import (
    ActorCriticEncoderConfig,
    MLPHeadConfig,
    FreeLogStdMLPHeadConfig,
)
from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model
from ray.rllib.utils import override
from ray.rllib.utils.annotations import OverrideToImplementCustomLogic


def _check_if_diag_gaussian(action_distribution_cls, framework):
    if framework == "torch":
        from ray.rllib.models.torch.torch_distributions import TorchDiagGaussian

        assert issubclass(action_distribution_cls, TorchDiagGaussian), (
            f"free_log_std is only supported for DiagGaussian action distributions. "
            f"Found action distribution: {action_distribution_cls}."
        )
    elif framework == "tf2":
        from ray.rllib.models.tf.tf_distributions import TfDiagGaussian

        assert issubclass(action_distribution_cls, TfDiagGaussian), (
            "free_log_std is only supported for DiagGaussian action distributions. "
            "Found action distribution: {}.".format(action_distribution_cls)
        )
    else:
        raise ValueError(f"Framework {framework} not supported for free_log_std.")


class PPOCatalog(Catalog):
    """The Catalog class used to build models for PPO.

    PPOCatalog provides the following models:
        - ActorCriticEncoder: The encoder used to encode the observations.
        - Pi Head: The head used to compute the policy logits.
        - Value Function Head: The head used to compute the value function.

    The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
    for the policy and value function. See implementations of PPORLModule for
    more details.

    Any custom ActorCriticEncoder can be built by overriding the
    build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig
    at PPOCatalog.actor_critic_encoder_config can be overridden to build a custom
    ActorCriticEncoder during RLModule runtime.

    Any custom head can be built by overriding the build_pi_head() and build_vf_head()
    methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to
    build custom heads during RLModule runtime.
    """

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
    ):
        """Initializes the PPOCatalog.

        Args:
            observation_space: The observation space of the Encoder.
            action_space: The action space for the Pi Head.
            model_config_dict: The model config to use.
        """
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config_dict=model_config_dict,
        )

        # Replace EncoderConfig by ActorCriticEncoderConfig
        self.actor_critic_encoder_config = ActorCriticEncoderConfig(
            base_encoder_config=self._encoder_config,
            shared=self._model_config_dict["vf_share_layers"],
        )

        self.pi_and_vf_head_hiddens = self._model_config_dict["post_fcnet_hiddens"]
        self.pi_and_vf_head_activation = self._model_config_dict[
            "post_fcnet_activation"
        ]

        # We don't have the exact (framework specific) action dist class yet and thus
        # cannot determine the exact number of output nodes (action space) required.
        # -> Build pi config only in the `self.build_pi_head` method.
        self.pi_head_config = None

        self.vf_head_config = MLPHeadConfig(
            input_dims=self.latent_dims,
            hidden_layer_dims=self.pi_and_vf_head_hiddens,
            hidden_layer_activation=self.pi_and_vf_head_activation,
            output_layer_activation="linear",
            output_layer_dim=1,
        )

    @OverrideToImplementCustomLogic
    def build_actor_critic_encoder(self, framework: str) -> ActorCriticEncoder:
        """Builds the ActorCriticEncoder.

        The default behavior is to build the encoder from the encoder_config.
        This can be overridden to build a custom ActorCriticEncoder as a means of
        configuring the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The ActorCriticEncoder.
        """
        return self.actor_critic_encoder_config.build(framework=framework)

    @override(Catalog)
    def build_encoder(self, framework: str) -> Encoder:
        """Builds the encoder.

        Since PPO uses an ActorCriticEncoder, this method should not be implemented.
        """
        raise NotImplementedError(
            "Use PPOCatalog.build_actor_critic_encoder() instead for PPO."
        )

    @OverrideToImplementCustomLogic
    def build_pi_head(self, framework: str) -> Model:
        """Builds the policy head.

        The default behavior is to build the head from the pi_head_config.
        This can be overridden to build a custom policy head as a means of configuring
        the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The policy head.
        """
        # Get action_distribution_cls to find out about the output dimension for pi_head
        action_distribution_cls = self.get_action_dist_cls(framework=framework)
        if self._model_config_dict["free_log_std"]:
            _check_if_diag_gaussian(
                action_distribution_cls=action_distribution_cls, framework=framework
            )
        required_output_dim = action_distribution_cls.required_input_dim(
            space=self.action_space, model_config=self._model_config_dict
        )
        # Now that we have the action dist class and number of outputs, we can define
        # our pi-config and build the pi head.
        pi_head_config_class = (
            FreeLogStdMLPHeadConfig
            if self._model_config_dict["free_log_std"]
            else MLPHeadConfig
        )
        self.pi_head_config = pi_head_config_class(
            input_dims=self.latent_dims,
            hidden_layer_dims=self.pi_and_vf_head_hiddens,
            hidden_layer_activation=self.pi_and_vf_head_activation,
            output_layer_dim=required_output_dim,
            output_layer_activation="linear",
        )

        return self.pi_head_config.build(framework=framework)

    @OverrideToImplementCustomLogic
    def build_vf_head(self, framework: str) -> Model:
        """Builds the value function head.

        The default behavior is to build the head from the vf_head_config.
        This can be overridden to build a custom value function head as a means of
        configuring the behavior of a PPORLModule implementation.

        Args:
            framework: The framework to use. Either "torch" or "tf2".

        Returns:
            The value function head.
        """
        return self.vf_head_config.build(framework=framework)