ray.rllib.core.rl_module.rl_module.RLModule

class ray.rllib.core.rl_module.rl_module.RLModule(config=-1, *, observation_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, inference_only: bool | None = None, learner_only: bool = False, model_config: dict | DefaultModelConfig | None = None, catalog_class=None, **kwargs)

Bases: Checkpointable, ABC

Base class for RLlib modules.

Subclasses should call super().__init__(observation_space=.., action_space=.., inference_only=.., learner_only=.., model_config={..}) in their __init__ methods.

The following examples show how the forward methods are called.

Example of an (inference-only) sampling loop:

import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

env = gym.make("CartPole-v1")

# Create an instance of the default RLModule used by PPO.
module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)
action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = truncated = False

# Step until the episode ends, either by termination or by truncation.
while not (terminated or truncated):
    # Add a batch dimension of 1 to the single observation.
    fwd_ins = {"obs": torch.from_numpy(obs).unsqueeze(0)}
    fwd_outputs = module.forward_exploration(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
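The from_logits call in the loop above turns the module's raw action_dist_inputs into a distribution object that is then sampled. As a rough illustration of what that can mean for a discrete action space, the sketch below uses only the Python standard library. It does not reproduce RLlib's distribution classes; softmax, sample_action, and greedy_action are hypothetical helper names, and the logits are made-up values.

```python
import math
import random


def softmax(logits):
    """Convert raw logits into probabilities that sum to 1."""
    shift = max(logits)  # Subtract the max for numerical stability.
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(logits, rng):
    """Stochastic choice: draw an action index in proportion to its probability."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def greedy_action(logits):
    """Deterministic choice: pick the index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


# Hypothetical logits for a two-action environment such as CartPole.
logits = [0.2, 1.5]
rng = random.Random(0)
stochastic = sample_action(logits, rng)
deterministic = greedy_action(logits)
```

The stochastic variant roughly corresponds to sampling from a categorical distribution, while the greedy variant corresponds to a deterministic distribution that always returns the most likely action.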

Example for training:

import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

env = gym.make("CartPole-v1")

# Create an instance of the default RLModule used by PPO.
module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)

# Use a single observation as a batch of size 1.
obs, info = env.reset()
fwd_ins = {"obs": torch.from_numpy(obs).unsqueeze(0)}
fwd_outputs = module.forward_train(fwd_ins)
# loss = compute_loss(fwd_outputs, fwd_ins)
# update_params(module, loss)

Example for inference:

import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

env = gym.make("CartPole-v1")

# Create an instance of the default RLModule used by PPO.
module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)
action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = truncated = False

# Step until the episode ends, either by termination or by truncation.
while not (terminated or truncated):
    # Add a batch dimension of 1 to the single observation.
    fwd_ins = {"obs": torch.from_numpy(obs).unsqueeze(0)}
    fwd_outputs = module.forward_inference(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)

Parameters:

config – The config for the RLModule.

Abstract Methods:

~_forward_train: Forward pass during training.

~_forward_exploration: Forward pass during training for exploration.

~_forward_inference: Forward pass during inference.
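The split between these underscore-prefixed methods and the public forward_* methods (which are marked DO NOT OVERRIDE) resembles the template-method pattern: the public method is the fixed entry point and delegates to the hook that a subclass implements. The sketch below illustrates that relationship with a toy class hierarchy in plain Python. ToyModule, ZeroLogitsModule, and the _checked helper are hypothetical names, and the output check is an assumption made for illustration, not a description of what RLModule does internally.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class ToyModule(ABC):
    """Hypothetical stand-in for a base module; not RLlib's RLModule."""

    # Public entry points: subclasses are expected to leave these alone.
    def forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        return self._checked(self._forward_train(batch))

    def forward_exploration(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        return self._checked(self._forward_exploration(batch))

    def forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        return self._checked(self._forward_inference(batch))

    @staticmethod
    def _checked(outputs: Any) -> Dict[str, Any]:
        # Shared post-processing lives in the base class (illustrative only).
        if not isinstance(outputs, dict):
            raise TypeError("forward methods must return a dict")
        return outputs

    # Hooks: subclasses implement these.
    @abstractmethod
    def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]: ...

    @abstractmethod
    def _forward_exploration(self, batch: Dict[str, Any]) -> Dict[str, Any]: ...

    @abstractmethod
    def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]: ...


class ZeroLogitsModule(ToyModule):
    """Toy subclass that returns all-zero logits for every observation."""

    def __init__(self, num_actions: int):
        self.num_actions = num_actions

    def _logits(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        rows = [[0.0] * self.num_actions for _ in batch["obs"]]
        return {"action_dist_inputs": rows}

    def _forward_train(self, batch):
        return self._logits(batch)

    def _forward_exploration(self, batch):
        return self._logits(batch)

    def _forward_inference(self, batch):
        return self._logits(batch)


module = ZeroLogitsModule(num_actions=2)
out = module.forward_inference({"obs": [[0.1, 0.2], [0.3, 0.4]]})
```

Callers only ever use the public methods, so any shared behavior added to the base class applies to every subclass without changes to the hooks.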

PublicAPI (beta): This API is in beta and may change before becoming stable.

Methods

as_multi_rl_module

Returns a multi-agent wrapper around this module.

forward_exploration

DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler.

forward_inference

DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler.

forward_train

DO NOT OVERRIDE! Forward-pass during training, called from the learner.

from_checkpoint

Creates a new Checkpointable instance from the given location and returns it.

get_checkpointable_components

Returns the implementing class's own Checkpointable subcomponents.

get_exploration_action_dist_cls

Returns the action distribution class for this RLModule used for exploration.

get_inference_action_dist_cls

Returns the action distribution class for this RLModule used for inference.

get_initial_state

Returns the initial state of the RLModule, in case this is a stateful module.

get_metadata

Returns JSON writable metadata further describing the implementing class.

get_state

Returns the state dict of the module.

get_train_action_dist_cls

Returns the action distribution class for this RLModule used for training.

input_specs_exploration

Returns the input specs of the forward_exploration method.

input_specs_inference

Returns the input specs of the forward_inference method.

input_specs_train

Returns the input specs of the forward_train method.

is_stateful

By default, returns False if the initial state is an empty dict (or None).

output_specs_train

Returns the output specs of the forward_train method.

restore_from_path

Restores the state of the implementing class from the given path.

save_to_path

Saves the state of the implementing class (or a provided state) to path.

setup

Sets up the components of the module.

unwrapped

Returns the underlying module if this module is a wrapper.
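The relationship between get_initial_state and is_stateful described in the table above can be sketched in a few lines. This is a toy illustration in plain Python under the assumption that "stateful" simply means "the initial state is non-empty"; ToyStatelessModule and ToyRecurrentModule are hypothetical classes, not part of RLlib.

```python
from typing import Any, Dict


class ToyStatelessModule:
    """Hypothetical module without recurrent state."""

    def get_initial_state(self) -> Dict[str, Any]:
        # An empty dict signals that there is no state to carry between steps.
        return {}

    def is_stateful(self) -> bool:
        # An empty dict (or None) is falsy, so this returns False by default.
        return bool(self.get_initial_state())


class ToyRecurrentModule(ToyStatelessModule):
    """Hypothetical module that carries a hidden state vector."""

    def __init__(self, hidden_size: int):
        self.hidden_size = hidden_size

    def get_initial_state(self) -> Dict[str, Any]:
        # A non-empty initial state makes the inherited is_stateful return True.
        return {"h": [0.0] * self.hidden_size}


stateless = ToyStatelessModule()
recurrent = ToyRecurrentModule(hidden_size=4)
```

In this sketch a subclass only overrides get_initial_state, and the inherited is_stateful picks up the change automatically.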

Attributes

CLASS_AND_CTOR_ARGS_FILE_NAME

METADATA_FILE_NAME

STATE_FILE_NAME

framework