ray.rllib.core.rl_module.rl_module.RLModule
- class ray.rllib.core.rl_module.rl_module.RLModule(config=-1, *, observation_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, inference_only: bool | None = None, learner_only: bool = False, model_config: dict | DefaultModelConfig | None = None, catalog_class=None)
Bases: Checkpointable, ABC
Base class for RLlib modules.
Subclasses should call super().__init__(config) in their __init__ method. The following examples show how the forward methods are called during sampling, training, and inference:
Example for creating a sampling loop:
```python
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
# NOTE: import paths for RLModuleSpec and DefaultModelConfig may differ
# slightly across Ray versions (new API stack assumed here).
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single-agent RLModule spec.
module_spec = RLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)
module = module_spec.build()

action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = False

while not terminated:
    fwd_ins = {"obs": torch.Tensor([obs])}
    fwd_outputs = module.forward_exploration(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
```
Example for training:
```python
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
# NOTE: import paths for RLModuleSpec and DefaultModelConfig may differ
# slightly across Ray versions (new API stack assumed here).
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single-agent RLModule spec.
module_spec = RLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)
module = module_spec.build()

# Get an initial observation to feed through the module.
obs, info = env.reset()

fwd_ins = {"obs": torch.Tensor([obs])}
fwd_outputs = module.forward_train(fwd_ins)
# loss = compute_loss(fwd_outputs, fwd_ins)
# update_params(module, loss)
```
Example for inference:
```python
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
# NOTE: import paths for RLModuleSpec and DefaultModelConfig may differ
# slightly across Ray versions (new API stack assumed here).
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single-agent RLModule spec.
module_spec = RLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)
module = module_spec.build()

# Get the action distribution class and reset the env before the loop.
action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = False

while not terminated:
    fwd_ins = {"obs": torch.Tensor([obs])}
    fwd_outputs = module.forward_inference(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
```
- Parameters:
config – The config for the RLModule.
- Abstract Methods (see the sketch after this list):
  - _forward_train: Forward pass during training.
  - _forward_exploration: Forward pass during training for exploration.
  - _forward_inference: Forward pass during inference.
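As a rough illustration of these abstract methods (a minimal sketch, not taken from the RLlib sources), a custom torch-based module can build its network in setup() and implement the three _forward methods. The sketch assumes the TorchRLModule base class and the Columns constants from Ray's new API stack, a 1-D Box observation space, and a Discrete action space:

```python
import torch.nn as nn

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class MyTorchRLModule(TorchRLModule):
    def setup(self):
        # Called once at construction time; build the networks from the
        # spaces this module was created with.
        in_size = self.observation_space.shape[0]
        out_size = self.action_space.n
        self._policy_net = nn.Sequential(
            nn.Linear(in_size, 128),
            nn.ReLU(),
            nn.Linear(128, out_size),
        )

    def _forward_inference(self, batch, **kwargs):
        # Used by forward_inference() for evaluation.
        return {Columns.ACTION_DIST_INPUTS: self._policy_net(batch[Columns.OBS])}

    def _forward_exploration(self, batch, **kwargs):
        # Used by forward_exploration() while collecting samples.
        return {Columns.ACTION_DIST_INPUTS: self._policy_net(batch[Columns.OBS])}

    def _forward_train(self, batch, **kwargs):
        # Used by forward_train() inside the Learner to produce loss inputs.
        return {Columns.ACTION_DIST_INPUTS: self._policy_net(batch[Columns.OBS])}
```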
Note
There is a reason the specs are not written as abstract properties: torch overrides __getattr__ and __setattr__. If the specs were defined as properties, any error raised inside a property would be interpreted as a failure to retrieve the attribute and would invoke __getattr__, which then reports a confusing "attribute not found" error. More details: pytorch/pytorch#49726.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
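The following toy snippet (illustrative only, not from the RLlib sources) shows the pitfall: an AttributeError raised inside a property on a torch.nn.Module is swallowed by nn.Module.__getattr__ and re-reported as if the property itself did not exist.

```python
import torch.nn as nn


class Broken(nn.Module):
    @property
    def specs(self):
        # A typo'd attribute access raises AttributeError inside the getter ...
        return self.nonexistent_attr


m = Broken()
# ... so Python falls back to nn.Module.__getattr__ for "specs" and reports:
#   AttributeError: 'Broken' object has no attribute 'specs'
m.specs
```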
Methods
- as_multi_rl_module: Returns a multi-agent wrapper around this module.
- forward_exploration: DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler.
- forward_inference: DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler.
- forward_train: DO NOT OVERRIDE! Forward-pass during training, called from the learner.
- from_checkpoint: Creates a new Checkpointable instance from the given location and returns it.
- get_checkpointable_components: Returns the implementing class's own Checkpointable subcomponents.
- get_exploration_action_dist_cls: Returns the action distribution class for this RLModule used for exploration.
- get_inference_action_dist_cls: Returns the action distribution class for this RLModule used for inference.
- get_initial_state: Returns the initial state of the RLModule, in case this is a stateful module.
- get_metadata: Returns JSON writable metadata further describing the implementing class.
- get_state: Returns the state dict of the module.
- get_train_action_dist_cls: Returns the action distribution class for this RLModule used for training.
- input_specs_exploration: Returns the input specs of the forward_exploration method.
- input_specs_inference: Returns the input specs of the forward_inference method.
- input_specs_train: Returns the input specs of the forward_train method.
- is_stateful: By default, returns False if the initial state is an empty dict (or None).
- output_specs_train: Returns the output specs of the forward_train method.
- restore_from_path: Restores the state of the implementing class from the given path.
- save_to_path: Saves the state of the implementing class (or state) to path.
- setup: Sets up the components of the module.
- unwrapped: Returns the underlying module if this module is a wrapper.
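As a quick sketch of the Checkpointable methods listed above (import paths and the generic from_checkpoint usage are assumptions based on Ray's new API stack), a built module can be written to a directory with save_to_path and later reconstructed with from_checkpoint:

```python
import tempfile

import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec

env = gym.make("CartPole-v1")
module = RLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
).build()

# Write the module's state and metadata to a checkpoint directory.
ckpt_dir = tempfile.mkdtemp()
module.save_to_path(ckpt_dir)

# Rebuild an equivalent module from that directory.
restored_module = RLModule.from_checkpoint(ckpt_dir)
```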
Attributes