ray.rllib.core.rl_module.rl_module.RLModule
- class ray.rllib.core.rl_module.rl_module.RLModule(config=-1, *, observation_space: gymnasium.Space | None = None, action_space: gymnasium.Space | None = None, inference_only: bool | None = None, learner_only: bool = False, model_config: dict | DefaultModelConfig | None = None, catalog_class=None, **kwargs)[source]
Bases: Checkpointable, ABC
Base class for RLlib modules.
Subclasses should call
super().__init__(observation_space=.., action_space=.., inference_only=.., learner_only=.., model_config={..})
in their __init__ methods. Here is the pseudocode for how the forward methods are called:
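The exact dispatch lives inside RLlib's EnvRunner and Learner code; the following is only a rough sketch of it, grounded in the method descriptions further below (the public forward_*() methods should not be overridden, and each one delegates to the corresponding overridable _forward_*() method):

# During sampling on an EnvRunner (exploration enabled):
fwd_outputs = module.forward_exploration(batch)  # delegates to module._forward_exploration(batch)

# During evaluation, called from the sampler:
fwd_outputs = module.forward_inference(batch)  # delegates to module._forward_inference(batch)

# During a training update, called from the Learner:
fwd_outputs = module.forward_train(batch)  # delegates to module._forward_train(batch)
# The loss itself is computed by the Learner, not by the RLModule.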
Example for creating an (inference-only) sampling loop:
from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create an instance of the default RLModule used by PPO.
module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)
action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = False

while not terminated:
    fwd_ins = {"obs": torch.Tensor([obs])}
    fwd_outputs = module.forward_exploration(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
Example for training:
import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

env = gym.make("CartPole-v1")

# Create an instance of the default RLModule used by PPO.
module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)

obs, info = env.reset()
fwd_ins = {"obs": torch.Tensor([obs])}
fwd_outputs = module.forward_train(fwd_ins)
# loss = compute_loss(fwd_outputs, fwd_ins)
# update_params(module, loss)
Example for inference:
import gymnasium as gym
import torch

from ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module import (
    DefaultPPOTorchRLModule,
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

env = gym.make("CartPole-v1")

# Create an instance of the default RLModule used by PPO.
module = DefaultPPOTorchRLModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config=DefaultModelConfig(fcnet_hiddens=[128, 128]),
    catalog_class=PPOCatalog,
)
action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = False

while not terminated:
    fwd_ins = {"obs": torch.Tensor([obs])}
    fwd_outputs = module.forward_inference(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
- Parameters:
observation_space – The observation space of the model. Note that in multi-agent setups, this is typically the observation space of an agent that maps to this RLModule.
action_space – The action space of the model. Note that in multi-agent setups, this is typically the action space of an agent that maps to this RLModule.
inference_only – If True, this RLModule should construct itself in an inference-only fashion. This is done automatically if the user implements the InferenceOnlyAPI with their custom RLModule subclass. False by default.
learner_only – If True, RLlib won't build this RLModule on EnvRunner actors. False by default.
model_config – A config dict to specify features of this RLModule.
- Abstract Methods:
_forward_train: Forward pass during training.
_forward_exploration: Forward pass during exploration (sample collection).
_forward_inference: Forward pass during inference.
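To tie the constructor contract and the abstract methods together, here is a minimal, hedged sketch of a custom Torch-based RLModule for a discrete-action environment: the keyword arguments from the Parameters section are passed straight to the base constructor, the network is built in setup() rather than in an overridden __init__(), and the three abstract forward methods are implemented. The class name MyTinyModule, the "hidden_dim" model_config key, and the _common_forward helper are invented for this example and are not part of RLlib:

import gymnasium as gym
import torch

from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class MyTinyModule(TorchRLModule):
    def setup(self):
        # Called automatically at the end of RLModule.__init__();
        # build sub-networks here instead of overriding __init__().
        in_size = self.observation_space.shape[0]
        hidden = self.model_config.get("hidden_dim", 64)  # "hidden_dim" is a made-up key
        self._net = torch.nn.Sequential(
            torch.nn.Linear(in_size, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, self.action_space.n),
        )

    def _common_forward(self, batch):
        # Invented helper shared by all three forward passes in this sketch.
        return {"action_dist_inputs": self._net(batch["obs"])}

    def _forward_inference(self, batch, **kwargs):
        return self._common_forward(batch)

    def _forward_exploration(self, batch, **kwargs):
        return self._common_forward(batch)

    def _forward_train(self, batch, **kwargs):
        return self._common_forward(batch)


env = gym.make("CartPole-v1")
module = MyTinyModule(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config={"hidden_dim": 64},
)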
PublicAPI (beta): This API is in beta and may change before becoming stable.
Methods
Returns a multi-agent wrapper around this module.
DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler.
DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler.
DO NOT OVERRIDE! Forward-pass during training called from the learner.
Creates a new Checkpointable instance from the given location and returns it.
Returns the implementing class's own Checkpointable subcomponents.
Returns the action distribution class for this RLModule used for exploration.
Returns the action distribution class for this RLModule used for inference.
Returns the initial state of the RLModule, in case this is a stateful module.
Returns JSON writable metadata further describing the implementing class.
Returns the state dict of the module.
Returns the action distribution class for this RLModule used for training.
Returns the input specs of the forward_exploration method.
Returns the input specs of the forward_inference method.
Returns the input specs of the forward_train method.
By default, returns False if the initial state is an empty dict (or None).
Returns the output specs of the forward_train method.
Restores the state of the implementing class from the given path.
Saves the state of the implementing class (or state) to path.
Sets up the components of the module.
Returns the underlying module if this module is a wrapper.
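The checkpointing and state methods listed above can be combined roughly as follows. This is a hedged sketch continuing from one of the PPO examples: the directory path is arbitrary, set_state() is assumed to be the in-memory counterpart of get_state(), and exact keyword arguments may differ between Ray versions:

# Persist the module to disk and rebuild a new instance from the checkpoint.
module.save_to_path("/tmp/my_rl_module")
restored = DefaultPPOTorchRLModule.from_checkpoint("/tmp/my_rl_module")

# Move state around in memory without touching disk.
state = module.get_state()
restored.set_state(state)

# For stateful (e.g., recurrent) modules, fetch the initial state at episode starts.
if module.is_stateful():
    state_in = module.get_initial_state()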
Attributes