ray.rllib.core.rl_module.rl_module.RLModule
- class ray.rllib.core.rl_module.rl_module.RLModule(config=-1, *, observation_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, inference_only: bool | None = None, learner_only: bool = False, model_config: dict | DefaultModelConfig | None = None, catalog_class=None, **kwargs)
Bases: Checkpointable, ABC
Base class for RLlib modules.
Subclasses should call
super().__init__(observation_space=.., action_space=.., inference_only=.., learner_only=.., model_config={..})
in their __init__ methods.
The following examples show how the forward methods are called.
Example for creating an (inference-only) sampling loop:
```python
import gymnasium as gym
import numpy as np
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

env = gym.make("CartPole-v1")

# Create an instance of the default RLModule used by PPO.
module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)
# Use the distribution class that matches forward_exploration().
action_dist_class = module.get_exploration_action_dist_cls()

obs, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    # Add a batch dimension of size 1.
    fwd_ins = {"obs": torch.from_numpy(np.array([obs]))}
    fwd_outputs = module.forward_exploration(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(fwd_outputs["action_dist_inputs"])
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
```
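In the sampling loop above, fwd_outputs["action_dist_inputs"] holds the inputs from which the action distribution is built. For a discrete action space such as CartPole's, these are assumed here to be unnormalized logits of a categorical distribution. The following pure-Python sketch, with hypothetical helpers softmax and sample_action (no RLlib or torch involved), illustrates what stochastic sampling from such logits means.

```python
import math
import random
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(logits: Sequence[float], rng: random.Random) -> int:
    """Draw one action index i with probability softmax(logits)[i]."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)  # seeded so the run is reproducible
logits = [0.2, 1.5]     # e.g. one row of action_dist_inputs for a 2-action space
counts = [0, 0]
for _ in range(10_000):
    counts[sample_action(logits, rng)] += 1

print([round(p, 3) for p in softmax(logits)])  # [0.214, 0.786]
print(counts)  # roughly a 21% / 79% split; exact numbers depend on the seed
```

Sampling (rather than always taking the most likely action) is what lets the policy keep exploring actions it currently rates lower.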
Example for training:
```python
import gymnasium as gym
import numpy as np
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

env = gym.make("CartPole-v1")

# Create an instance of the default RLModule used by PPO.
module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)

obs, info = env.reset()
# Add a batch dimension of size 1.
fwd_ins = {"obs": torch.from_numpy(np.array([obs]))}
fwd_outputs = module.forward_train(fwd_ins)
# compute_loss and update_params are placeholders for your own training logic.
# loss = compute_loss(fwd_outputs, fwd_ins)
# update_params(module, loss)
```
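The training example leaves compute_loss and update_params as commented-out placeholders. As a hedged illustration only, here is a hypothetical compute_loss implementing a vanilla policy-gradient (REINFORCE) loss, which is not the clipped surrogate objective PPO actually optimizes. The batch keys "actions" and "advantages" are assumptions, and plain Python lists stand in for tensors.

```python
import math
from typing import Dict, List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # log(softmax(x)) computed stably with the log-sum-exp trick.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def compute_loss(fwd_outputs: Dict[str, List[List[float]]],
                 batch: Dict[str, list]) -> float:
    """Hypothetical REINFORCE loss: -mean_t(log pi(a_t | s_t) * A_t)."""
    logits_batch = fwd_outputs["action_dist_inputs"]
    actions = batch["actions"]        # assumed key: chosen action indices
    advantages = batch["advantages"]  # assumed key: per-step advantage estimates
    terms = [
        log_softmax(logits)[a] * adv
        for logits, a, adv in zip(logits_batch, actions, advantages)
    ]
    return -sum(terms) / len(terms)


fwd_outputs = {"action_dist_inputs": [[0.0, 0.0], [1.0, -1.0]]}
batch = {"actions": [0, 1], "advantages": [1.0, 2.0]}
print(round(compute_loss(fwd_outputs, batch), 4))  # 2.4735
```

Minimizing this loss raises the log-probability of actions with positive advantage and lowers it for actions with negative advantage; an update_params step would then apply gradients of it to the module's parameters.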
Example for inference:
```python
import gymnasium as gym
import numpy as np
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

env = gym.make("CartPole-v1")

# Create an instance of the default RLModule used by PPO.
module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)
action_dist_class = module.get_inference_action_dist_cls()

obs, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    # Add a batch dimension of size 1.
    fwd_ins = {"obs": torch.from_numpy(np.array([obs]))}
    fwd_outputs = module.forward_inference(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(fwd_outputs["action_dist_inputs"])
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
```
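The comment in the inference example notes that the distribution can be deterministic or stochastic. A common deterministic choice for a categorical policy is to act greedily, picking the action with the largest logit. This pure-Python sketch (helper names are hypothetical) computes greedy actions for a small batch of logits and checks that the result does not change after a softmax, because softmax is monotonic.

```python
import math
from typing import List, Sequence


def argmax(values: Sequence[float]) -> int:
    # max() keeps the first maximum it sees, so ties go to the lowest index.
    return max(range(len(values)), key=lambda i: values[i])


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


batch_logits = [[0.2, 1.5], [2.0, -0.3], [0.0, 0.0]]
greedy = [argmax(row) for row in batch_logits]
print(greedy)  # [1, 0, 0]

# Softmax is monotonic, so acting greedily on probabilities gives the same actions.
print(greedy == [argmax(softmax(row)) for row in batch_logits])  # True
```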
- Parameters:
config – The config for the RLModule.
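In the signature, every parameter after * is keyword-only, which is why subclasses are asked to pass observation_space, action_space, inference_only, learner_only, and model_config to super().__init__ by name. The sketch below uses a hypothetical ToyRLModuleBase in place of RLModule (it omits config, catalog_class, and **kwargs) to show the forwarding pattern and that positional arguments are rejected.

```python
from typing import Any, List, Optional


class ToyRLModuleBase:
    """Hypothetical stand-in for RLModule's keyword-only constructor arguments."""

    def __init__(
        self,
        *,
        observation_space: Optional[Any] = None,
        action_space: Optional[Any] = None,
        inference_only: Optional[bool] = None,
        learner_only: bool = False,
        model_config: Optional[dict] = None,
    ) -> None:
        self.observation_space = observation_space
        self.action_space = action_space
        self.inference_only = inference_only
        self.learner_only = learner_only
        # Copy so that later changes to the caller's dict do not leak in.
        self.model_config = dict(model_config or {})


class MyModule(ToyRLModuleBase):
    def __init__(self, *, observation_space: Any, action_space: Any,
                 model_config: Optional[dict] = None) -> None:
        # Forward every argument by name, as the docstring asks subclasses to do.
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            inference_only=False,
            learner_only=False,
            model_config=model_config,
        )
        # Hypothetical default layer sizes, used only if none are configured.
        self.hidden: List[int] = self.model_config.get("fcnet_hiddens", [64, 64])


module = MyModule(
    observation_space="Box(4,)",  # placeholder strings instead of gymnasium spaces
    action_space="Discrete(2)",
    model_config={"fcnet_hiddens": [128, 128]},
)
print(module.hidden)  # [128, 128]

try:
    ToyRLModuleBase("Box(4,)", "Discrete(2)")  # positional arguments are rejected
except TypeError as err:
    print("TypeError:", err)
```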
- Abstract Methods:
_forward_train – Forward pass during training.
_forward_exploration – Forward pass for exploration, used while sampling during training.
_forward_inference – Forward pass during inference.
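In the methods list, the public forward_exploration, forward_inference, and forward_train are marked DO NOT OVERRIDE, while the abstract hooks above are what subclasses implement: the template-method pattern. The following simplified, hypothetical ToyModule (not RLlib's code) sketches the idea: the public methods run a shared check and then delegate to the private hooks. Because this toy base is an ABC, a subclass cannot be instantiated until all three hooks exist.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List

Batch = Dict[str, Any]


class ToyModule(ABC):
    """Simplified illustration of public forward_* methods delegating to hooks."""

    @staticmethod
    def _check(batch: Batch) -> None:
        # Shared logic every subclass gets for free (hypothetical check).
        if "obs" not in batch:
            raise ValueError("batch must contain an 'obs' key")

    # Public entry points: callers use these; subclasses should not override them.
    def forward_train(self, batch: Batch) -> Batch:
        self._check(batch)
        return self._forward_train(batch)

    def forward_exploration(self, batch: Batch) -> Batch:
        self._check(batch)
        return self._forward_exploration(batch)

    def forward_inference(self, batch: Batch) -> Batch:
        self._check(batch)
        return self._forward_inference(batch)

    # Hooks that concrete subclasses must implement.
    @abstractmethod
    def _forward_train(self, batch: Batch) -> Batch: ...

    @abstractmethod
    def _forward_exploration(self, batch: Batch) -> Batch: ...

    @abstractmethod
    def _forward_inference(self, batch: Batch) -> Batch: ...


class ConstantLogits(ToyModule):
    """Returns the same logits for every observation in the batch."""

    def __init__(self, logits: List[float]) -> None:
        self.logits = list(logits)

    def _forward_train(self, batch: Batch) -> Batch:
        return {"action_dist_inputs": [list(self.logits) for _ in batch["obs"]]}

    def _forward_exploration(self, batch: Batch) -> Batch:
        return self._forward_train(batch)

    def _forward_inference(self, batch: Batch) -> Batch:
        return self._forward_train(batch)


module = ConstantLogits([0.1, 0.9])
out = module.forward_inference({"obs": [[0.0] * 4, [1.0] * 4]})
print(out)  # {'action_dist_inputs': [[0.1, 0.9], [0.1, 0.9]]}

try:
    ToyModule()  # the abstract hooks are not implemented
except TypeError as err:
    print("TypeError:", err)
```

Keeping the shared logic in the public methods means every subclass gets it without having to remember to call it.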
PublicAPI (beta): This API is in beta and may change before becoming stable.
Methods
Returns a multi-agent wrapper around this module.
DO NOT OVERRIDE! Forward pass during exploration, called from the sampler.
DO NOT OVERRIDE! Forward pass during evaluation, called from the sampler.
DO NOT OVERRIDE! Forward pass during training, called from the learner.
Creates a new Checkpointable instance from the given location and returns it.
Returns the implementing class's own Checkpointable subcomponents.
Returns the action distribution class for this RLModule used for exploration.
Returns the action distribution class for this RLModule used for inference.
Returns the initial state of the RLModule, in case this is a stateful module.
Returns JSON writable metadata further describing the implementing class.
Returns the state dict of the module.
Returns the action distribution class for this RLModule used for training.
Returns the input specs of the forward_exploration method.
Returns the input specs of the forward_inference method.
Returns the input specs of the forward_train method.
By default, returns False if the initial state is an empty dict (or None).
Returns the output specs of the forward_train method.
Restores the state of the implementing class from the given path.
Saves the state of the implementing class (or state) to path.
Sets up the components of the module.
Returns the underlying module if this module is a wrapper.
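One of the entries above documents the default statefulness check: a module counts as stateless when its initial state is an empty dict (or None). A minimal sketch of that rule with hypothetical classes: a stateless module that keeps the default empty initial state, and an LSTM-like module whose initial state holds zero-initialized h and c vectors.

```python
from typing import Any, Dict, List, Optional


class ToyStatelessModule:
    """Hypothetical mirror of the documented is_stateful() default."""

    def get_initial_state(self) -> Optional[Dict[str, Any]]:
        # Stateless by default: there is no recurrent state to carry over.
        return {}

    def is_stateful(self) -> bool:
        # Documented default: False if the initial state is an empty dict (or None).
        state = self.get_initial_state()
        return not (state is None or state == {})


class ToyLSTMModule(ToyStatelessModule):
    def __init__(self, cell_size: int = 4) -> None:
        self.cell_size = cell_size

    def get_initial_state(self) -> Dict[str, List[float]]:
        # Zero-initialized hidden (h) and cell (c) state vectors.
        return {"h": [0.0] * self.cell_size, "c": [0.0] * self.cell_size}


print(ToyStatelessModule().is_stateful())  # False
print(ToyLSTMModule().is_stateful())       # True
```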
Attributes