ray.rllib.core.rl_module.rl_module.RLModule

class ray.rllib.core.rl_module.rl_module.RLModule(config: RLModuleConfig)

Bases: ABC

Base class for RLlib modules.

Subclasses should call super().__init__(config) in their __init__ method. The following examples show how the forward methods are called:

Example for creating a sampling loop:

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"hidden": [128, 128]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()
# forward_exploration() is paired with the exploration distribution class.
action_dist_class = module.get_exploration_action_dist_cls()
obs, info = env.reset()
terminated = truncated = False

# Stop on either termination or truncation (time limit).
while not (terminated or truncated):
    # Add a batch dimension of size 1 to the single observation.
    fwd_ins = {"obs": torch.from_numpy(obs).unsqueeze(0)}
    fwd_outputs = module.forward_exploration(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    # Take the first (only) batch entry before passing it to the env.
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
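For CartPole's discrete action space, turning fwd_outputs["action_dist_inputs"] into an action amounts to treating each row as the logits of a categorical distribution. The pure-Python sketch below illustrates that step, assuming (as is typical for discrete action spaces) that the distribution class is categorical; it is not RLlib's distribution code, and softmax, sample_action, and greedy_action are hypothetical helpers. The greedy variant corresponds to the "deterministic" case mentioned in the comment above.

```python
import math
import random
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit before exponentiating, for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(logits: Sequence[float], rng: random.Random) -> int:
    # Stochastic choice: action i is drawn with probability softmax(logits)[i].
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def greedy_action(logits: Sequence[float]) -> int:
    # Deterministic choice: the action with the largest logit.
    return max(range(len(logits)), key=lambda i: logits[i])


rng = random.Random(0)
logits = [2.0, 0.0]  # e.g. one row of fwd_outputs["action_dist_inputs"]
samples = [sample_action(logits, rng) for _ in range(10_000)]
print(round(softmax(logits)[0], 4))  # 0.8808
print(greedy_action(logits))  # 0
```

With these logits, roughly 88% of sampled actions should be action 0, while the greedy choice always returns 0.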

Example for training:

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"hidden": [128, 128]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()

obs, info = env.reset()
# Add a batch dimension of size 1 to the single observation.
fwd_ins = {"obs": torch.from_numpy(obs).unsqueeze(0)}
fwd_outputs = module.forward_train(fwd_ins)
# `compute_loss` and `update_params` are placeholders for your own
# loss computation and optimizer step.
# loss = compute_loss(fwd_outputs, fwd_ins)
# update_params(module, loss)

Example for inference:

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"hidden": [128, 128]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()
action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = truncated = False

while not (terminated or truncated):
    # Add a batch dimension of size 1 to the single observation.
    fwd_ins = {"obs": torch.from_numpy(obs).unsqueeze(0)}
    fwd_outputs = module.forward_inference(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)

Parameters:

config – The config for the RLModule.

Abstract Methods:

_forward_train(): Forward pass during training.

_forward_exploration(): Forward pass used for exploration, i.e. when collecting samples for training.

_forward_inference(): Forward pass during inference.
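The split between the public forward_* methods and the abstract _forward_* hooks can be sketched without any framework. The toy below is hypothetical (ToyConfig, ToyRLModule, UniformPolicyModule, and the key-checking logic are illustrative assumptions, not RLlib's implementation): each public method validates the batch against a list of required keys, then delegates to the subclass hook, and the subclass calls super().__init__(config) so the base class can store the config and run setup().

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List

Batch = Dict[str, Any]


@dataclass
class ToyConfig:
    # Hypothetical stand-in for RLModuleConfig (not RLlib's class).
    obs_dim: int
    num_actions: int


class ToyRLModule(ABC):
    """Toy mirror of the public-method / abstract-hook split (not RLlib code)."""

    def __init__(self, config: ToyConfig):
        self.config = config
        self.setup()  # In this toy, components are built here.

    def setup(self) -> None:
        """Override to build the module's components."""

    def input_specs(self) -> List[str]:
        # A single required-keys list for all three passes, for brevity.
        return ["obs"]

    def _check_inputs(self, batch: Batch) -> None:
        missing = [k for k in self.input_specs() if k not in batch]
        if missing:
            raise ValueError(f"missing input keys: {missing}")

    # Public entry points: validate the batch, then delegate to a hook.
    def forward_train(self, batch: Batch) -> Batch:
        self._check_inputs(batch)
        return self._forward_train(batch)

    def forward_exploration(self, batch: Batch) -> Batch:
        self._check_inputs(batch)
        return self._forward_exploration(batch)

    def forward_inference(self, batch: Batch) -> Batch:
        self._check_inputs(batch)
        return self._forward_inference(batch)

    # Hooks that every concrete subclass must implement.
    @abstractmethod
    def _forward_train(self, batch: Batch) -> Batch: ...

    @abstractmethod
    def _forward_exploration(self, batch: Batch) -> Batch: ...

    @abstractmethod
    def _forward_inference(self, batch: Batch) -> Batch: ...


class UniformPolicyModule(ToyRLModule):
    def __init__(self, config: ToyConfig):
        # Required: the toy base __init__ stores the config and calls setup().
        super().__init__(config)

    def setup(self) -> None:
        # All-zero logits describe a uniform distribution over actions.
        self.logits = [0.0] * self.config.num_actions

    def _forward_train(self, batch: Batch) -> Batch:
        # One row of logits per observation in the batch.
        return {"action_dist_inputs": [list(self.logits) for _ in batch["obs"]]}

    def _forward_exploration(self, batch: Batch) -> Batch:
        return self._forward_train(batch)

    def _forward_inference(self, batch: Batch) -> Batch:
        return self._forward_train(batch)


module = UniformPolicyModule(ToyConfig(obs_dim=4, num_actions=2))
out = module.forward_exploration({"obs": [[0.0] * 4, [1.0] * 4]})
print(out["action_dist_inputs"])  # [[0.0, 0.0], [0.0, 0.0]]
```

Keeping validation in the public methods means subclasses only implement the three hooks and never have to repeat the input checks.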

Note

The specs are deliberately not written as abstract properties, because torch overrides __getattr__ and __setattr__. If the specs were properties, any AttributeError raised inside a property getter would be treated as a failure to find the property itself. Python would then fall back to torch's __getattr__, which reports a confusing "attribute not found" error for the spec's name instead of the real problem. More details here: pytorch/pytorch#49726.
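The failure mode described in this note can be reproduced without torch. In the sketch below, FallbackGetattr is a hypothetical stand-in that mimics only the relevant part of torch.nn.Module (a __getattr__ that raises for unknown names); the class and attribute names are illustrative. The same bug, reading an attribute that was never set, produces a misleading message when it sits inside a property and an accurate one when it sits inside a plain method.

```python
class FallbackGetattr:
    """Mimics torch.nn.Module: __getattr__ runs whenever normal lookup fails."""

    def __getattr__(self, name):
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )


class SpecsAsProperty(FallbackGetattr):
    @property
    def input_specs_train(self):
        # Bug: `self.obs_key` was never set.
        return [self.obs_key]


class SpecsAsMethod(FallbackGetattr):
    def input_specs_train(self):
        # The same bug, but inside a plain method.
        return [self.obs_key]


def error_message(fn):
    # Return the AttributeError message raised by fn, or None if it succeeds.
    try:
        fn()
    except AttributeError as e:
        return str(e)
    return None


prop_msg = error_message(lambda: SpecsAsProperty().input_specs_train)
meth_msg = error_message(lambda: SpecsAsMethod().input_specs_train())

print(prop_msg)  # 'SpecsAsProperty' object has no attribute 'input_specs_train'
print(meth_msg)  # 'SpecsAsMethod' object has no attribute 'obs_key'
```

The property version hides the real cause (`obs_key`) because the AttributeError escaping the getter triggers the __getattr__ fallback for `input_specs_train` itself.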

Methods

as_multi_agent

Returns a multi-agent wrapper around this module.

forward_exploration

Forward-pass during exploration, called from the sampler.

forward_inference

Forward-pass during evaluation, called from the sampler.

forward_train

Forward-pass during training, called from the learner.

from_checkpoint

Loads the module from a checkpoint directory.

get_exploration_action_dist_cls

Returns the action distribution class for this RLModule used for exploration.

get_inference_action_dist_cls

Returns the action distribution class for this RLModule used for inference.

get_initial_state

Returns the initial state of the RLModule.

get_state

Returns the state dict of the module.

get_train_action_dist_cls

Returns the action distribution class for this RLModule used for training.

input_specs_exploration

Returns the input specs of the forward_exploration method.

input_specs_inference

Returns the input specs of the forward_inference method.

input_specs_train

Returns the input specs of the forward_train method.

is_stateful

Returns False if the initial state is an empty dict (or None), and True otherwise.

load_state

Loads the weights of an RLModule from the directory dir.

output_specs_exploration

Returns the output specs of the forward_exploration() method.

output_specs_inference

Returns the output specs of the forward_inference() method.

output_specs_train

Returns the output specs of the forward_train method.

save_state

Saves the weights of this RLModule to the directory dir.

save_to_checkpoint

Saves the module to a checkpoint directory.

set_state

Sets the state dict of the module.

setup

Sets up the components of the module.

unwrapped

Returns the underlying module if this module is a wrapper.

update_default_view_requirements

Updates default view requirements with the view requirements of this module.
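The state-related methods above form a round trip: get_state and set_state work on an in-memory dict, save_state and load_state persist that state to a directory, and is_stateful depends only on get_initial_state(). The framework-free sketch below mirrors those documented semantics under stated assumptions: ToyStatefulModule, the JSON file name, and the list-of-floats weights are hypothetical, and a torch-based RLModule would hold tensors and may use a different on-disk format.

```python
import json
import os
import tempfile
from typing import Any, Dict, List, Optional


class ToyStatefulModule:
    """Hypothetical module mirroring the documented semantics (not RLlib code)."""

    STATE_FILE = "module_state.json"  # Assumed file name, illustration only.

    def __init__(self, weights: List[float], hidden_size: int = 0):
        self.weights = list(weights)
        self.hidden_size = hidden_size

    def get_initial_state(self) -> Optional[Dict[str, Any]]:
        # Recurrent toy: a zero hidden state; feed-forward toy: an empty dict.
        if self.hidden_size:
            return {"h": [0.0] * self.hidden_size}
        return {}

    def is_stateful(self) -> bool:
        # False when the initial state is an empty dict or None.
        return bool(self.get_initial_state())

    def get_state(self) -> Dict[str, Any]:
        # In-memory snapshot of the weights (copied, not aliased).
        return {"weights": list(self.weights)}

    def set_state(self, state: Dict[str, Any]) -> None:
        self.weights = list(state["weights"])

    def save_state(self, checkpoint_dir: str) -> None:
        # Persist the same dict that get_state() returns.
        with open(os.path.join(checkpoint_dir, self.STATE_FILE), "w") as f:
            json.dump(self.get_state(), f)

    def load_state(self, checkpoint_dir: str) -> None:
        with open(os.path.join(checkpoint_dir, self.STATE_FILE)) as f:
            self.set_state(json.load(f))


src = ToyStatefulModule([0.5, -1.0])
dst = ToyStatefulModule([0.0, 0.0], hidden_size=8)
with tempfile.TemporaryDirectory() as tmp_dir:
    src.save_state(tmp_dir)
    dst.load_state(tmp_dir)

print(dst.weights)  # [0.5, -1.0]
print(src.is_stateful(), dst.is_stateful())  # False True
```

Note that "stateful" here refers to recurrent state returned by get_initial_state(), not to the weights handled by get_state(); a feed-forward module still has weights to save.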

Attributes

framework