ray.rllib.core.rl_module.rl_module.RLModule

class ray.rllib.core.rl_module.rl_module.RLModule(config: RLModuleConfig)

Bases: Checkpointable, ABC

Base class for RLlib modules.

Subclasses should call super().__init__(config) in their __init__ method. The examples below show how the forward methods are called when sampling, when training, and when running inference:

Example for creating a sampling loop:

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single-agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"fcnet_hiddens": [128, 128]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()
# Use the exploration distribution to match forward_exploration().
action_dist_class = module.get_exploration_action_dist_cls()
obs, info = env.reset()
terminated = truncated = False

# Step until the episode ends, by termination or by time-limit truncation.
while not (terminated or truncated):
    # Add a batch dimension: shape (obs_dim,) -> (1, obs_dim).
    fwd_ins = {"obs": torch.from_numpy(obs).unsqueeze(0)}
    fwd_outputs = module.forward_exploration(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
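
For a discrete action space such as CartPole's, the distribution built from "action_dist_inputs" is typically categorical. As a rough, framework-free sketch (this is not RLlib's TorchCategorical; the helper names softmax, sample_stochastic and sample_deterministic are hypothetical), from_logits followed by sample() amounts to a softmax over the logits and a random draw, while a deterministic variant picks the most likely action:

```python
import math
import random


def softmax(logits):
    """Turn raw logits into probabilities (shifted by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_stochastic(logits, rng=random):
    """Hypothetical helper: draw an action index with probability softmax(logits)[i]."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def sample_deterministic(logits):
    """Hypothetical helper: pick the action with the highest logit (the mode)."""
    return max(range(len(logits)), key=lambda i: logits[i])


# Hypothetical logits for CartPole's two actions (push left, push right).
logits = [0.2, 1.5]
print(softmax(logits))
print(sample_deterministic(logits))  # -> 1
print(sample_stochastic(logits, random.Random(0)))
```

Exploration usually wants the stochastic draw so the agent keeps trying both actions; a greedy, deterministic choice is a common option at inference time.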

Example for training:

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single-agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"fcnet_hiddens": [128, 128]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()

obs, info = env.reset()
# Add a batch dimension: shape (obs_dim,) -> (1, obs_dim).
fwd_ins = {"obs": torch.from_numpy(obs).unsqueeze(0)}
fwd_outputs = module.forward_train(fwd_ins)
# compute_loss and update_params are placeholders for the algorithm's
# own loss computation and optimizer step.
# loss = compute_loss(fwd_outputs, fwd_ins)
# update_params(module, loss)

Example for inference:

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single-agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"fcnet_hiddens": [128, 128]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()
action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = truncated = False

# Step until the episode ends, by termination or by time-limit truncation.
while not (terminated or truncated):
    # Add a batch dimension: shape (obs_dim,) -> (1, obs_dim).
    fwd_ins = {"obs": torch.from_numpy(obs).unsqueeze(0)}
    fwd_outputs = module.forward_inference(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)

Parameters:

config – The config for the RLModule.

Abstract Methods:

_forward_train: Forward pass during training.

_forward_exploration: Forward pass for exploration, used when sampling during training.

_forward_inference: Forward pass during inference.
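
The split between the public forward_* methods and the abstract _forward_* hooks can be pictured with the following simplified, framework-free stand-in. It is not RLlib's implementation: the class names MiniRLModule and ConstantLogits, the key-set spec format, and the input_keys_train/output_keys_train attributes are assumptions made for illustration. The idea is that a subclass only implements the private hook, while the public method checks inputs and outputs against the module's specs and delegates:

```python
from abc import ABC, abstractmethod


class MiniRLModule(ABC):
    """Hypothetical stand-in: public forward_train wraps the abstract _forward_train hook."""

    # Assumed spec format for this sketch: the set of keys a batch must contain.
    input_keys_train = {"obs"}
    output_keys_train = {"action_dist_inputs"}

    def __init__(self, config):
        self.config = config

    @staticmethod
    def _check_keys(batch, required, where):
        missing = required.difference(batch)  # Iterating a dict yields its keys.
        if missing:
            raise ValueError(f"{where} is missing keys: {sorted(missing)}")

    def forward_train(self, batch):
        # Validate inputs, run the subclass hook, then validate outputs.
        self._check_keys(batch, self.input_keys_train, "forward_train input")
        out = self._forward_train(batch)
        self._check_keys(out, self.output_keys_train, "forward_train output")
        return out

    @abstractmethod
    def _forward_train(self, batch):
        """Subclasses implement the actual computation here."""


class ConstantLogits(MiniRLModule):
    """Toy subclass (hypothetical): returns the same logits for every observation."""

    def __init__(self, config):
        super().__init__(config)  # Mirrors the "call super().__init__(config)" rule.
        self.logits = config["logits"]

    def _forward_train(self, batch):
        return {"action_dist_inputs": [list(self.logits) for _ in batch["obs"]]}


module = ConstantLogits({"logits": [0.0, 1.0]})
print(module.forward_train({"obs": [[0.1, 0.2], [0.3, 0.4]]}))
# -> {'action_dist_inputs': [[0.0, 1.0], [0.0, 1.0]]}
```

Keeping validation in the public method means every subclass gets the same checks without repeating them in its own _forward_* code.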

Note

The specs are deliberately not defined as abstract properties, because torch overrides __getattr__ and __setattr__. If the specs were properties, an AttributeError raised anywhere inside a property's body would be treated as a failed attribute lookup, so Python would fall back to torch's __getattr__, which reports a confusing "attribute not found" error for the property itself instead of the real cause. More details: pytorch/pytorch#49726.
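
The pitfall can be reproduced without torch. The sketch below uses a hypothetical ModuleLike class whose __getattr__ imitates, in simplified form, how torch.nn.Module reports missing attributes; the class, the property, and the obs_key attribute are assumptions for illustration. A bug inside the property raises AttributeError, Python then falls back to __getattr__, and the final message names the property instead of the real cause:

```python
class ModuleLike:
    """Hypothetical, simplified stand-in for a class that overrides __getattr__."""

    def __getattr__(self, name):
        # Called only after normal lookup (including a property getter) raised AttributeError.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    @property
    def input_specs_train(self):
        # Bug: self.obs_key was never set, so this line raises AttributeError.
        return [self.obs_key]


def error_message():
    try:
        ModuleLike().input_specs_train
    except AttributeError as e:
        return str(e)


print(error_message())
# -> 'ModuleLike' object has no attribute 'input_specs_train'
```

The message never mentions obs_key, which is exactly the confusing behavior described in the note.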

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods

as_multi_agent

Returns a multi-agent wrapper around this module.

forward_exploration

Forward pass during exploration, called from the sampler.

forward_inference

Forward pass during evaluation, called from the sampler.

forward_train

Forward pass during training, called from the learner.

from_checkpoint

Creates a new Checkpointable instance from the given location and returns it.

get_checkpointable_components

Returns the implementing class's own Checkpointable subcomponents.

get_exploration_action_dist_cls

Returns the action distribution class for this RLModule used for exploration.

get_inference_action_dist_cls

Returns the action distribution class for this RLModule used for inference.

get_initial_state

Returns the initial state of the RLModule.

get_metadata

Returns JSON writable metadata further describing the implementing class.

get_state

Returns the state dict of the module.

get_train_action_dist_cls

Returns the action distribution class for this RLModule used for training.

input_specs_exploration

Returns the input specs of the forward_exploration method.

input_specs_inference

Returns the input specs of the forward_inference method.

input_specs_train

Returns the input specs of the forward_train method.

is_stateful

Returns False if the initial state is an empty dict (or None), and True otherwise.

output_specs_exploration

Returns the output specs of the forward_exploration method.

output_specs_inference

Returns the output specs of the forward_inference method.

output_specs_train

Returns the output specs of the forward_train method.

restore_from_path

Restores the state of the implementing class from the given path.

save_to_path

Saves the state of the implementing class (or an explicitly provided state) to path.

setup

Sets up the components of the module.

unwrapped

Returns the underlying module if this module is a wrapper.

update_default_view_requirements

Updates default view requirements with the view requirements of this module.
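
The relationship between get_initial_state and is_stateful can be sketched as follows. This is a simplified, framework-free illustration that assumes the initial state is a (possibly empty) dict; the functions get_initial_state_lstm and get_initial_state_mlp are hypothetical and not RLlib code. A recurrent module would typically return initial hidden and cell states, while a feed-forward module returns an empty dict:

```python
def get_initial_state_lstm(cell_size=4):
    """Hypothetical recurrent module: zero-filled hidden and cell states."""
    return {"h": [0.0] * cell_size, "c": [0.0] * cell_size}


def get_initial_state_mlp():
    """Hypothetical feed-forward module: no recurrent state."""
    return {}


def is_stateful(initial_state):
    """False if the initial state is an empty dict (or None), True otherwise."""
    return bool(initial_state)


print(is_stateful(get_initial_state_lstm()))  # -> True
print(is_stateful(get_initial_state_mlp()))   # -> False
print(is_stateful(None))                      # -> False
```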

Attributes

CLASS_AND_CTOR_ARGS_FILE_NAME

METADATA_FILE_NAME

STATE_FILE_NAME

framework
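
The checkpointing methods (get_state, save_to_path, restore_from_path, from_checkpoint) together with the *_FILE_NAME attributes describe a save and restore round trip. The following is only a rough stand-in for that protocol using pickle: the class MiniCheckpointable, the STATE_FILE_NAME value, the set_state method, the method signatures, and the on-disk format are all assumptions, not RLlib's actual checkpoint layout:

```python
import os
import pickle
import tempfile


class MiniCheckpointable:
    """Hypothetical stand-in for the save_to_path/restore_from_path round trip."""

    STATE_FILE_NAME = "state.pkl"  # Assumed value; RLlib defines its own constant.

    def __init__(self):
        self.weights = [0.0, 0.0]

    def get_state(self):
        return {"weights": list(self.weights)}

    def set_state(self, state):
        self.weights = list(state["weights"])

    def save_to_path(self, path, state=None):
        # Save the provided state if given, otherwise the object's current state.
        state = self.get_state() if state is None else state
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, self.STATE_FILE_NAME), "wb") as f:
            pickle.dump(state, f)
        return path

    def restore_from_path(self, path):
        with open(os.path.join(path, self.STATE_FILE_NAME), "rb") as f:
            self.set_state(pickle.load(f))

    @classmethod
    def from_checkpoint(cls, path):
        # Build a fresh instance, then load the saved state into it.
        obj = cls()
        obj.restore_from_path(path)
        return obj


with tempfile.TemporaryDirectory() as tmp:
    original = MiniCheckpointable()
    original.weights = [0.5, -1.25]
    original.save_to_path(tmp)
    restored = MiniCheckpointable.from_checkpoint(tmp)
    print(restored.weights)  # -> [0.5, -1.25]
```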