ray.rllib.core.rl_module.rl_module.RLModule#
- class ray.rllib.core.rl_module.rl_module.RLModule(config: RLModuleConfig)[source]#
Bases: ABC
Base class for RLlib modules.
Subclasses should call super().__init__(config) in their __init__ method. The examples below show how the forward methods are called in a sampling loop, during training, and at inference time:
Example for creating a sampling loop:
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single-agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"hidden": [128, 128]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()

action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = False

while not terminated:
    fwd_ins = {"obs": torch.Tensor([obs])}
    fwd_outputs = module.forward_exploration(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
Example for training:
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single-agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"hidden": [128, 128]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()

obs, info = env.reset()
fwd_ins = {"obs": torch.Tensor([obs])}
fwd_outputs = module.forward_train(fwd_ins)
# loss = compute_loss(fwd_outputs, fwd_ins)
# update_params(module, loss)
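The compute_loss and update_params calls above are placeholders for algorithm-specific logic (in RLlib this lives in the Learner). A minimal, hypothetical sketch of what such an update could look like with a plain torch optimizer and a dummy objective, purely for illustration:

# Hypothetical stand-in for the compute_loss / update_params placeholders
# above; this is NOT RLlib's actual Learner update. Torch-based RLModules
# are nn.Modules, so a plain torch optimizer can drive the update.
optimizer = torch.optim.Adam(module.parameters(), lr=1e-3)
loss = fwd_outputs["action_dist_inputs"].mean()  # dummy objective for illustration
optimizer.zero_grad()
loss.backward()
optimizer.step()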
Example for inference:
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
import gymnasium as gym
import torch

env = gym.make("CartPole-v1")

# Create a single-agent RL module spec.
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"hidden": [128, 128]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()

action_dist_class = module.get_inference_action_dist_cls()
obs, info = env.reset()
terminated = False

while not terminated:
    fwd_ins = {"obs": torch.Tensor([obs])}
    fwd_outputs = module.forward_inference(fwd_ins)
    # This can be either a deterministic or a stochastic distribution.
    action_dist = action_dist_class.from_logits(
        fwd_outputs["action_dist_inputs"]
    )
    action = action_dist.sample()[0].numpy()
    obs, reward, terminated, truncated, info = env.step(action)
- Parameters:
config – The config for the RLModule.
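Most workflows build the module through a SingleAgentRLModuleSpec as in the examples above, but the constructor itself only takes an RLModuleConfig. A hedged sketch of constructing the config directly, assuming RLModuleConfig exposes the same fields the spec-based examples use:

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.core.rl_module.rl_module import RLModuleConfig
import gymnasium as gym

env = gym.make("CartPole-v1")

# Assumed field names, mirroring the spec-based examples above.
config = RLModuleConfig(
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"hidden": [128, 128]},
    catalog_class=PPOCatalog,
)
module = PPOTorchRLModule(config)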
- Abstract Methods:
  forward_train(): Forward pass during training.
  forward_exploration(): Forward pass during training for exploration.
  forward_inference(): Forward pass during inference.
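These abstract methods are the override points for custom modules. Below is a minimal sketch of a Torch-based subclass, assuming a Box observation space and a Discrete action space, and assuming (as in RLlib's Torch-based modules) that the private _forward_*() hooks backing the public methods above are what gets overridden; this is illustrative, not a drop-in implementation.

import torch.nn as nn
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class MyTorchRLModule(TorchRLModule):
    def setup(self):
        # Build the networks once at construction time; spaces come from
        # the RLModuleConfig passed to __init__ (assumed Box obs, Discrete actions).
        obs_dim = self.config.observation_space.shape[0]
        num_actions = self.config.action_space.n
        self.policy = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_actions),
        )

    def _forward_inference(self, batch, **kwargs):
        return {"action_dist_inputs": self.policy(batch["obs"])}

    def _forward_exploration(self, batch, **kwargs):
        return {"action_dist_inputs": self.policy(batch["obs"])}

    def _forward_train(self, batch, **kwargs):
        return {"action_dist_inputs": self.policy(batch["obs"])}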
Note
The specs are not written as abstract properties because torch overrides __getattr__ and __setattr__. If the specs were defined as properties, any error raised inside a property would be interpreted as a failure to retrieve the attribute and would invoke __getattr__, which produces a confusing "attribute not found" error. More details here: pytorch/pytorch#49726.
Methods
as_multi_agent(): Returns a multi-agent wrapper around this module.
forward_exploration(): Forward-pass during exploration, called from the sampler.
forward_inference(): Forward-pass during evaluation, called from the sampler.
forward_train(): Forward-pass during training, called from the learner.
from_checkpoint(): Loads the module from a checkpoint directory.
get_exploration_action_dist_cls(): Returns the action distribution class for this RLModule used for exploration.
get_inference_action_dist_cls(): Returns the action distribution class for this RLModule used for inference.
get_initial_state(): Returns the initial state of the module.
get_state(): Returns the state dict of the module.
get_train_action_dist_cls(): Returns the action distribution class for this RLModule used for training.
input_specs_exploration(): Returns the input specs of the forward_exploration method.
input_specs_inference(): Returns the input specs of the forward_inference method.
input_specs_train(): Returns the input specs of the forward_train method.
is_stateful(): Returns False if the initial state is an empty dict (or None).
load_state(): Loads the weights of an RLModule from the directory dir.
output_specs_exploration(): Returns the output specs of the forward_exploration method.
output_specs_inference(): Returns the output specs of the forward_inference method.
output_specs_train(): Returns the output specs of the forward_train method.
save_state(): Saves the weights of this RLModule to the directory dir.
save_to_checkpoint(): Saves the module to a checkpoint directory.
set_state(): Sets the state dict of the module.
setup(): Sets up the components of the module.
unwrapped(): Returns the underlying module if this module is a wrapper.
update_default_view_requirements(): Updates default view requirements with the view requirements of this module.
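As a brief illustration of the checkpointing and state methods listed above, here is a hedged sketch that reuses the module built in the earlier examples and assumes save_to_checkpoint() and from_checkpoint() take a checkpoint directory path, as their summaries indicate:

import tempfile

# Save the built module to a checkpoint directory and restore it
# (directory-path arguments assumed from the method summaries above).
ckpt_dir = tempfile.mkdtemp()
module.save_to_checkpoint(ckpt_dir)
restored = PPOTorchRLModule.from_checkpoint(ckpt_dir)

# get_state() / set_state() operate on the in-memory state dict instead.
state = module.get_state()
restored.set_state(state)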
Attributes