Note

This doc is related to the RLModule API and therefore experimental.

Note

Interacting with Catalogs mainly covers advanced use cases.

Catalog (Alpha)#

Catalogs are where RLModules primarily get their models and action distributions from. Each RLModule has its own default Catalog. For example, PPOTorchRLModule has the PPOCatalog. You can override Catalogs’ methods to alter the behavior of existing RLModules, which makes Catalogs a means of configuring RLModules. You interact with Catalogs when making deeper customizations to the Models and Distributions that RLlib creates by default.

Note

If you simply want to modify a Model by changing its default values, have a look at the model config dict:

``MODEL_DEFAULTS`` dict

This dict (or an overriding sub-set) is part of AlgorithmConfig and therefore also part of any algorithm-specific config. You can override its values and pass it to an AlgorithmConfig to change the behavior of RLlib’s default models.

MODEL_DEFAULTS: ModelConfigDict = {
    # Experimental flag.
    # If True, no preprocessor is created and observations arrive in the model
    # as they are returned by the env.
    "_disable_preprocessor_api": False,
    # Experimental flag.
    # If True, RLlib will no longer flatten the policy-computed actions into
    # a single tensor (for storage in SampleCollectors/output files/etc..),
    # but leave (possibly nested) actions as-is. Disabling flattening affects:
    # - SampleCollectors: Have to store possibly nested action structs.
    # - Models that have the previous action(s) as part of their input.
    # - Algorithms reading from offline files (incl. action information).
    "_disable_action_flattening": False,

    # === Built-in options ===
    # FullyConnectedNetwork (tf and torch): rllib.models.tf|torch.fcnet.py
    # These are used if no custom model is specified and the input space is 1D.
    # Dense layer sizes of the fully connected network.
    "fcnet_hiddens": [256, 256],
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu", which is the same),
    # "linear" (or None).
    "fcnet_activation": "tanh",

    # VisionNetwork (tf and torch): rllib.models.tf|torch.visionnet.py
    # These are used if no custom model is specified and the input space is 2D.
    # Filter config: List of [out_channels, kernel, stride] for each filter.
    # Use None to make RLlib try to find a default filter setup given the
    # observation space.
    "conv_filters": None,
    # Activation function descriptor.
    # Supported values are: "tanh", "relu", "swish" (or "silu", which is the same),
    # "linear" (or None).
    "conv_activation": "relu",

    # Some default models support a final FC stack of n Dense layers with given
    # activation:
    # - Complex observation spaces: Image components are fed through
    #   VisionNets, flat Boxes are left as-is, Discrete spaces are one-hot
    #   encoded, then everything is concatenated and pushed through this
    #   final FC stack.
    # - VisionNets (CNNs): After the CNN stack, there may be additional
    #   Dense layers.
    # - FullyConnectedNetworks will have this additional FC stack as well
    #   (that's why it's empty by default).
    "post_fcnet_hiddens": [],
    "post_fcnet_activation": "relu",

    # For DiagGaussian action distributions, make the second half of the model
    # outputs floating bias variables instead of state-dependent. This only
    # has an effect when using the default fully connected net.
    "free_log_std": False,
    # Whether to skip the final linear layer used to resize the hidden layer
    # outputs to size `num_outputs`. If True, then the last hidden layer
    # should already match num_outputs.
    "no_final_linear": False,
    # Whether layers should be shared for the value function.
    "vf_share_layers": True,

    # == LSTM ==
    # Whether to wrap the model with an LSTM.
    "use_lstm": False,
    # Max seq len for training the LSTM, defaults to 20.
    "max_seq_len": 20,
    # Size of the LSTM cell.
    "lstm_cell_size": 256,
    # Whether to feed a_{t-1} to LSTM (one-hot encoded if discrete).
    "lstm_use_prev_action": False,
    # Whether to feed r_{t-1} to LSTM.
    "lstm_use_prev_reward": False,
    # Whether the LSTM is time-major (TxBx..) or batch-major (BxTx..).
    "_time_major": False,

    # == Attention Nets (experimental: torch-version is untested) ==
    # Whether to use a GTrXL ("Gru transformer XL"; attention net) as the
    # wrapper Model around the default Model.
    "use_attention": False,
    # The number of transformer units within GTrXL.
    # A transformer unit in GTrXL consists of a) MultiHeadAttention module and
    # b) a position-wise MLP.
    "attention_num_transformer_units": 1,
    # The input and output size of each transformer unit.
    "attention_dim": 64,
    # The number of attention heads within the MultiHeadAttention units.
    "attention_num_heads": 1,
    # The dim of a single head (within the MultiHeadAttention units).
    "attention_head_dim": 32,
    # The memory sizes for inference and training.
    "attention_memory_inference": 50,
    "attention_memory_training": 50,
    # The output dim of the position-wise MLP.
    "attention_position_wise_mlp_dim": 32,
    # The initial bias values for the 2 GRU gates within a transformer unit.
    "attention_init_gru_gate_bias": 2.0,
    # Whether to feed a_{t-n:t-1} to GTrXL (one-hot encoded if discrete).
    "attention_use_n_prev_actions": 0,
    # Whether to feed r_{t-n:t-1} to GTrXL.
    "attention_use_n_prev_rewards": 0,

    # == Atari ==
    # Set to True to enable 4x stacking behavior.
    "framestack": True,
    # Final resized frame dimension
    "dim": 84,
    # (deprecated) Converts ATARI frame to 1 Channel Grayscale image
    "grayscale": False,
    # (deprecated) Changes frame to range from [-1, 1] if true
    "zero_mean": True,

    # === Options for custom models ===
    # Name of a custom model to use
    "custom_model": None,
    # Extra options to pass to the custom classes. These will be available to
    # the Model's constructor in the model_config field. Also, they will be
    # attempted to be passed as **kwargs to ModelV2 models. For an example,
    # see rllib/models/[tf|torch]/attention_net.py.
    "custom_model_config": {},
    # Name of a custom action distribution to use.
    "custom_action_dist": None,
    # Custom preprocessors are deprecated. Please use a wrapper class around
    # your environment instead to preprocess observations.
    "custom_preprocessor": None,

    # === Options for ModelConfigs in RLModules ===
    # The latent dimension to encode into.
    # Since most RLModules have an encoder and heads, this establishes an agreement
    # on the dimensionality of the latent space they share.
    # This has no effect for models outside RLModule.
    # If None, the model_config['fcnet_hiddens'][-1] value will be used to guarantee
    # backward compatibility with old configs. This yields different models than past
    # versions of RLlib.
    "encoder_latent_dim": None,
    # Whether to always check the inputs and outputs of RLlib's default models for
    # their specifications. Input specifications are checked on failed forward passes
    # of the models regardless of this flag. If this flag is set to `True`, inputs and
    # outputs are checked on every call. This leads to a slow-down and should only be
    # used for debugging. Note that this flag is only relevant for instances of
    # RLlib's Model class. These are commonly generated from ModelConfigs in RLModules.
    "always_check_shapes": False,

    # Deprecated keys:
    # Use `lstm_use_prev_action` or `lstm_use_prev_reward` instead.
    "lstm_use_prev_action_reward": DEPRECATED_VALUE,
    # Deprecated in anticipation of RLModules API
    "_use_default_native_models": DEPRECATED_VALUE,

}
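For example, to change some of these defaults without touching Catalogs at all, you can pass the overriding keys to AlgorithmConfig.training() (shown here for PPO; keys you don’t specify keep their default values):

from ray.rllib.algorithms.ppo import PPOConfig

# Override a sub-set of MODEL_DEFAULTS; all other keys keep their defaults.
config = PPOConfig().training(
    model={
        "fcnet_hiddens": [64, 64],
        "fcnet_activation": "relu",
    }
)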

While Catalogs have a base class Catalog, you mostly interact with Algorithm-specific Catalogs. Therefore, this doc also includes examples around PPO, from which you can extrapolate to other algorithms. A prerequisite for this user guide is a rough understanding of RLModules. This user guide covers the following topics:

  • Basic usage

  • What are Catalogs

  • Inject your custom models into RLModules

  • Inject your custom action distributions into RLModules

Catalog and AlgorithmConfig#

Since Catalogs effectively control what models and distributions RLlib uses under the hood, they are also part of RLlib’s configurations. As the primary entry point for configuring RLlib, AlgorithmConfig is the place where you can configure the Catalogs of the RLModules that are created. You set the catalog class by going through the SingleAgentRLModuleSpec or MultiAgentRLModuleSpec of an AlgorithmConfig. For example, in heterogeneous multi-agent cases, you modify the MultiAgentRLModuleSpec.

../_images/catalog_rlmspecs_diagram.svg

The following example shows how to configure the Catalog of an RLModule created by PPO.

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


class MyPPOCatalog(PPOCatalog):
    def __init__(self, *args, **kwargs):
        print("Hi from within PPORLModule!")
        super().__init__(*args, **kwargs)


config = (
    PPOConfig()
    .environment("CartPole-v1")
    .framework("torch")
    .rl_module(_enable_rl_module_api=True)
    .training(_enable_learner_api=True)
)

# Specify the catalog to use for the PPORLModule.
config = config.rl_module(
    rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyPPOCatalog)
)
# This is how RLlib constructs a PPORLModule
# It will say "Hi from within PPORLModule!".
ppo = config.build()
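In multi-agent setups, you can configure a Catalog per module through the MultiAgentRLModuleSpec instead. The following is a minimal sketch that reuses MyPPOCatalog from above; the module IDs "agent_1" and "agent_2" are hypothetical and assume a multi-agent environment with a matching policy mapping, so this snippet is not meant to run against the CartPole config above as-is.

from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

# Hypothetical module IDs; each module can use its own Catalog class.
multi_agent_spec = MultiAgentRLModuleSpec(
    module_specs={
        "agent_1": SingleAgentRLModuleSpec(catalog_class=MyPPOCatalog),
        "agent_2": SingleAgentRLModuleSpec(catalog_class=MyPPOCatalog),
    }
)
config = config.rl_module(rl_module_spec=multi_agent_spec)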

Basic usage#

The following three examples illustrate three basic usage patterns of Catalogs.

The first example showcases the general API for interacting with Catalogs.

import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

env = gym.make("CartPole-v1")

catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")

The second example showcases how to use the PPOCatalog to create a model and an action distribution. This is more similar to what RLlib does internally.

Use catalog-generated models
import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT, ACTOR
from ray.rllib.policy.sample_batch import SampleBatch

env = gym.make("CartPole-v1")

catalog = PPOCatalog(env.observation_space, env.action_space, model_config_dict={})
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_actor_critic_encoder(framework="torch")
policy_head = catalog.build_pi_head(framework="torch")
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")

# Now we are ready to interact with the environment
obs, info = env.reset()
# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {
    SampleBatch.OBS: torch.Tensor([obs]),
    STATE_IN: None,
    SampleBatch.SEQ_LENS: None,
}
# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT][ACTOR]
action_dist_inputs = policy_head(encoding)
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()
env.step(actions[0])

The third example showcases how to use the base Catalog to create an encoder and an action distribution. In addition, we build a head network by hand that fits these two, to show how you can combine RLlib’s ModelConfig API with Catalogs. Extending the Catalog so that it builds this head itself is the intended way to extend Catalogs, which we cover later in this guide.

Customize a policy head
import gymnasium as gym
import torch

from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.policy.sample_batch import SampleBatch

env = gym.make("CartPole-v1")

catalog = Catalog(env.observation_space, env.action_space, model_config_dict={})
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")
# Therefore, we need `env.action_space.n` action distribution inputs.
expected_action_dist_input_dims = (env.action_space.n,)
# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_encoder(framework="torch")
# Build a suitable head model for the action distribution.
head_config = MLPHeadConfig(
    input_dims=catalog.latent_dims, hidden_layer_dims=expected_action_dist_input_dims
)
head = head_config.build(framework="torch")
# Now we are ready to interact with the environment
obs, info = env.reset()
# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {
    SampleBatch.OBS: torch.Tensor([obs]),
    STATE_IN: None,
    SampleBatch.SEQ_LENS: None,
}
# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT]
action_dist_inputs = head(encoding)
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()
env.step(actions[0])

What are Catalogs#

Catalogs have two primary roles: choosing the right Model and choosing the right Distribution. By default, all Catalogs implement decision trees that decide the model architecture based on a combination of input configurations: mainly the observation and action space of the RLModule, the model config dict, and the deep learning framework backend.
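For example, the base Catalog chooses an MLP-style encoder for CartPole’s 1D Box observation space, whereas an image observation space would lead it to a CNN-style encoder. Here is a small sketch using only the public Catalog API (the exact class name printed may vary between RLlib versions):

import gymnasium as gym

from ray.rllib.core.models.catalog import Catalog

env = gym.make("CartPole-v1")
# The 1D Box observation space leads the decision tree to an MLP encoder.
catalog = Catalog(env.observation_space, env.action_space, model_config_dict={})
encoder = catalog.build_encoder(framework="torch")
print(type(encoder).__name__)  # An MLP-based torch encoder class.
print(catalog.latent_dims)  # Dimensionality of the latent space it encodes into.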

The following diagram shows the breakdown of the information flow towards models and distributions within an RLModule. An RLModule creates an instance of the Catalog class it receives as part of its constructor. It then creates its internal models and distributions with the help of this Catalog.

Note

You can also modify Model or Distribution in an RLModule directly by overriding the RLModule’s constructor!

../_images/catalog_and_rlm_diagram.svg

The following diagram shows a concrete case in more detail.

Example of catalog in a PPORLModule

The PPOCatalog is fed an observation space, an action space, a model config dict, and the view requirements of the RLModule. The model config dict and the view requirements are only of interest in special cases, such as recurrent or attention networks. A PPORLModule has four components that the PPOCatalog creates: encoder, value function head, policy head, and action distribution.

../_images/ppo_catalog_and_rlm_diagram.svg

Inject your custom models or action distributions into RLModules#

You can make a Catalog build custom models by overriding the Catalog’s methods used by RLModules to build models. Have a look at these lines from the constructor of the PPOTorchRLModule to see how Catalogs are being used by an RLModule:

        catalog = self.config.get_catalog()

        # Build models from catalog
        self.encoder = catalog.build_actor_critic_encoder(framework=self.framework)
        self.pi = catalog.build_pi_head(framework=self.framework)
        self.vf = catalog.build_vf_head(framework=self.framework)

        self.action_dist_cls = catalog.get_action_dist_cls(framework=self.framework)

Consequently, in order to build a custom Model compatible with a PPORLModule, you can override these methods by inheriting from PPOCatalog, or write a Catalog that implements them from scratch. The following showcases such modifications.

This example shows two modifications:

  • How to write a custom Distribution

  • How to inject a custom action distribution into a Catalog

import torch
import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.models.distributions import Distribution
from ray.rllib.models.torch.torch_distributions import TorchDeterministic


# Define a simple categorical distribution that can be used for PPO
class CustomTorchCategorical(Distribution):
    def __init__(self, logits):
        self.torch_dist = torch.distributions.categorical.Categorical(logits=logits)

    def sample(self, sample_shape=torch.Size()):
        return self.torch_dist.sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        return self.torch_dist.rsample(sample_shape)

    def logp(self, value):
        return self.torch_dist.log_prob(value)

    def entropy(self):
        return self.torch_dist.entropy()

    def kl(self, other):
        return torch.distributions.kl.kl_divergence(self.torch_dist, other.torch_dist)

    @staticmethod
    def required_input_dim(space):
        return int(space.n)

    # This method is used to create distributions from logits inside RLModules.
    # You can use this to inject arguments into the constructor of this distribution
    # that are not the logits themselves.
    @classmethod
    def from_logits(cls, logits):
        return CustomTorchCategorical(logits=logits)

    # This method is used to create a deterministic distribution for the
    # PPORLModule.forward_inference.
    def to_deterministic(self):
        return TorchDeterministic(loc=torch.argmax(self.torch_dist.logits, dim=-1))


# See if we can create this distribution and sample from it to interact with our
# target environment.
env = gym.make("CartPole-v1")
dummy_logits = torch.randn([env.action_space.n])
dummy_dist = CustomTorchCategorical.from_logits(dummy_logits)
action = dummy_dist.sample()
env.reset()
env.step(action.numpy())


# Define a simple catalog that returns our custom distribution when
# get_action_dist_cls is called
class CustomPPOCatalog(PPOCatalog):
    def get_action_dist_cls(self, framework):
        # The distribution we wrote will only work with torch
        assert framework == "torch"
        return CustomTorchCategorical


# Train with our custom action distribution
algo = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(rl_module_spec=SingleAgentRLModuleSpec(catalog_class=CustomPPOCatalog))
    .build()
)
results = algo.train()
print(results)
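Injecting a custom model works analogously: override the Catalog method that builds the component you want to replace, for example build_pi_head for the policy head. The following is a minimal sketch of the override mechanism only; it delegates to the default implementation instead of constructing a genuinely new network. Whatever you return must follow RLlib’s Model API, accept inputs of size catalog.latent_dims, and output as many values as the action distribution requires.

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec


# Sketch of a Catalog that takes over building the policy head.
class CustomModelPPOCatalog(PPOCatalog):
    def build_pi_head(self, framework):
        # This is where you would construct and return your own policy head.
        # Here we simply delegate to the default implementation to keep the
        # sketch runnable.
        assert framework == "torch"
        pi_head = super().build_pi_head(framework=framework)
        print(f"Policy head built by CustomModelPPOCatalog: {type(pi_head).__name__}")
        return pi_head


# Building the algorithm constructs the PPORLModule and therefore calls
# build_pi_head above.
algo = (
    PPOConfig()
    .environment("CartPole-v1")
    .rl_module(
        rl_module_spec=SingleAgentRLModuleSpec(catalog_class=CustomModelPPOCatalog)
    )
    .build()
)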

Notable TODOs#

  • Add cross references to Model and Distribution API docs

  • Add example that shows how to inject own model

  • Add more instructions on how to write a catalog from scratch

  • Add section “Extend RLlib’s selection of Models and Distributions with your own”

  • Add section “Write a Catalog from scratch”