ray.rllib.core.learner.learner.Learner

class ray.rllib.core.learner.learner.Learner(*, config: AlgorithmConfig, module_spec: RLModuleSpec | MultiRLModuleSpec | None = None, module: RLModule | None = None)

Bases: Checkpointable

Base class for Learners.

This class trains RLModules. It defines the loss function and updates the neural network weights that it owns. In a multi-agent scenario, it also lets you add modules to, or remove modules from, the underlying MultiRLModule in the middle of training, which is useful for league-based training.

The TF- and Torch-specific implementations of this class fill in the framework-specific details for distributed training and for computing and applying gradients. Users should not need to subclass this class directly; instead, inherit from the TF- or Torch-specific subclass to implement algorithm-specific update logic.
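The division of labor described above can be pictured with a framework-free toy. The sketch below is not RLlib code: `ToyLearner`, its one scalar weight per module, the quadratic loss, and the `update` helper are all hypothetical stand-ins. Only the method names mirror entries from the Methods list further down, and the order in which they are chained here is an assumption made for illustration.

```python
# Toy illustration only (no RLlib, no torch): a hand-rolled "learner" that owns
# one scalar weight per module and minimizes (w - target)^2 by gradient descent.
# Every name in this block is hypothetical.


class ToyLearner:
    def __init__(self, targets, lr=0.1, grad_clip=1.0):
        self.targets = dict(targets)
        # The learner owns the weights it updates.
        self.weights = {module_id: 0.0 for module_id in self.targets}
        self.lr = lr
        self.grad_clip = grad_clip

    def compute_losses(self):
        # One loss value per module ID.
        return {m: (self.weights[m] - t) ** 2 for m, t in self.targets.items()}

    def compute_gradients(self):
        # d/dw (w - t)^2 = 2 * (w - t)
        return {m: 2.0 * (self.weights[m] - t) for m, t in self.targets.items()}

    def postprocess_gradients(self, grads):
        # Example postprocessing step: clip each gradient to [-c, c].
        c = self.grad_clip
        return {m: max(-c, min(c, g)) for m, g in grads.items()}

    def apply_gradients(self, grads):
        for m, g in grads.items():
            self.weights[m] -= self.lr * g

    def update(self):
        losses = self.compute_losses()
        grads = self.postprocess_gradients(self.compute_gradients())
        self.apply_gradients(grads)
        return losses


learner = ToyLearner({"module_0": 1.0})
first = learner.update()["module_0"]
for _ in range(200):
    last = learner.update()["module_0"]
```

In this toy the loss reported by the first update is the loss before any weight change, and repeated updates drive the weight toward its target.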

Parameters:
  • config – The AlgorithmConfig object from which to derive most of the settings needed to build the Learner.

  • module_spec – The module specification for the RLModule that is being trained. If the module is a single-agent module, it is converted to a multi-agent module with a default key after it is built. Can be None if the module is provided directly via the module argument. Refer to ray.rllib.core.rl_module.RLModuleSpec or ray.rllib.core.rl_module.MultiRLModuleSpec for more info.

  • module – If the Learner is used stand-alone, the RLModule can optionally be passed in directly instead of through module_spec.
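The conversion of a single-agent module into a multi-agent module "with a default key" can be pictured as wrapping it in a one-entry mapping. The following is a minimal sketch of that idea using plain dictionaries; `DEFAULT_KEY` and `as_multi_module` are hypothetical names chosen for illustration and are not RLlib identifiers.

```python
# Toy illustration only: plain dicts stand in for multi-agent modules.
# DEFAULT_KEY and as_multi_module are hypothetical, not part of RLlib.

DEFAULT_KEY = "default_module"


def as_multi_module(module_or_mapping):
    """Return a mapping of module ID to module.

    A mapping is copied as-is; a single module is wrapped under DEFAULT_KEY.
    """
    if isinstance(module_or_mapping, dict):
        return dict(module_or_mapping)
    return {DEFAULT_KEY: module_or_mapping}


single = as_multi_module("policy_net")
multi = as_multi_module({"player_a": "net_a", "player_b": "net_b"})
```

Code that consumes the result can then always iterate over module IDs, regardless of whether it started from one module or several.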

Note: This example uses PPO and torch because many of the showcased components need concrete implementations in order to work together. However, the same pattern applies more generally.

import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOTorchRLModule
)
from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

env = gym.make("CartPole-v1")

# Create a PPO config object first.
config = (
    PPOConfig()
    .framework("torch")
    .training(model={"fcnet_hiddens": [128, 128]})
)

# Create a learner instance directly from our config. All we need as
# extra information here is the env to be able to extract space information
# (needed to construct the RLModule inside the Learner).
learner = config.build_learner(env=env)

# Take one gradient update on the module and report the results.
# results = learner.update(...)

# Add a new module, perhaps for league based training.
learner.add_module(
    module_id="new_player",
    module_spec=RLModuleSpec(
        module_class=PPOTorchRLModule,
        observation_space=env.observation_space,
        action_space=env.action_space,
        model_config_dict={"fcnet_hiddens": [64, 64]},
        catalog_class=PPOCatalog,
    )
)

# Take another gradient update with both previous and new modules.
# results = learner.update(...)

# Remove a module.
learner.remove_module("new_player")

# Will train previous modules only.
# results = learner.update(...)

# Get the state of the learner.
state = learner.get_state()

# Set the state of the learner.
learner.set_state(state)

# Get the weights of the underlying MultiRLModule.
weights = learner.get_state(components=COMPONENT_RL_MODULE)

# Set the weights of the underlying MultiRLModule.
learner.set_state({COMPONENT_RL_MODULE: weights})

Extension pattern:

from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.learner.torch.torch_learner import TorchLearner

class MyLearner(TorchLearner):

    def compute_losses(self, fwd_out, batch):
        # Compute the losses per module based on `batch` and the output of the
        # forward pass (`fwd_out`). To access the (algorithm) config for a
        # specific RLModule, do:
        # `self.config.get_config_for_module([moduleID])`.
        # `module_loss` is a placeholder for the loss you compute here.
        return {DEFAULT_MODULE_ID: module_loss}
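To make the return shape of the extension pattern concrete, here is a framework-free sketch of a per-module loss computation. It is an assumption-laden toy, not the RLlib implementation: the nested dictionary layout and the `"predictions"` and `"targets"` keys are hypothetical, and a mean squared error over plain Python floats stands in for a real algorithm loss.

```python
# Toy illustration only: plain floats instead of tensors, and a hypothetical
# layout in which `fwd_out` and `batch` are both keyed by module ID.


def compute_losses(fwd_out, batch):
    """Return {module_id: loss} with one mean-squared-error value per module."""
    losses = {}
    for module_id, module_out in fwd_out.items():
        predictions = module_out["predictions"]
        targets = batch[module_id]["targets"]
        squared_errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
        losses[module_id] = sum(squared_errors) / len(squared_errors)
    return losses


losses = compute_losses(
    fwd_out={
        "player_a": {"predictions": [1.0, 2.0]},
        "player_b": {"predictions": [0.0]},
    },
    batch={
        "player_a": {"targets": [1.0, 4.0]},
        "player_b": {"targets": [3.0]},
    },
)
```

The point of the sketch is the shape of the result: one entry per module ID, so each module's loss can be handled separately downstream.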

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods

add_module

Adds a module to the underlying MultiRLModule.

after_gradient_based_update

Called after gradient-based updates are completed.

apply_gradients

Applies the gradients to the MultiRLModule parameters.

before_gradient_based_update

Called before gradient-based updates are started.

build

Builds the Learner.

compute_gradients

Computes the gradients based on the given losses.

compute_loss_for_module

Computes the loss for a single module.

compute_losses

Computes the loss(es) for the module being optimized.

configure_optimizers

Configures, creates, and registers the optimizers for this Learner.

configure_optimizers_for_module

Configures an optimizer for the given module_id.

filter_param_dict_for_optimizer

Reduces the given ParamDict to contain only parameters for given optimizer.

from_checkpoint

Creates a new Checkpointable instance from the given location and returns it.

get_metadata

Returns JSON writable metadata further describing the implementing class.

get_optimizer

Returns the optimizer object, configured under the given module_id and name.

get_optimizers_for_module

Returns a list of (optimizer_name, optimizer instance)-tuples for module_id.

get_param_ref

Returns a hashable reference to a trainable parameter.

get_parameters

Returns the list of parameters of a module.

postprocess_gradients

Applies potential postprocessing operations on the gradients.

postprocess_gradients_for_module

Applies postprocessing operations on the gradients of the given module.

register_optimizer

Registers an optimizer with a ModuleID, name, param list and lr-scheduler.

remove_module

Removes a module from the Learner.

restore_from_path

Restores the state of the implementing class from the given path.

save_to_path

Saves the state of the implementing class (or state) to path.

should_module_be_updated

Returns whether a module should be updated or not based on self.config.

update_from_batch

Performs num_iters minibatch updates given a train batch.

update_from_episodes

Performs num_iters minibatch updates given a list of episodes.
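The phrase "num_iters minibatch updates" can be read as drawing a fixed number of minibatches from one train batch. This page does not say how RLlib slices or shuffles the batch, so the sketch below shows just one plausible reading: consecutive slices that wrap around at the end of the batch. The function `cyclic_minibatches` and its parameters are hypothetical names for illustration.

```python
# Toy illustration only: yields `num_iters` minibatches of `minibatch_size`
# items each, wrapping around to the start of the batch when it runs out.
# This is an assumed slicing scheme, not a description of RLlib internals.


def cyclic_minibatches(batch, minibatch_size, num_iters):
    n = len(batch)
    if n == 0:
        raise ValueError("batch must not be empty")
    start = 0
    for _ in range(num_iters):
        indices = [(start + i) % n for i in range(minibatch_size)]
        yield [batch[i] for i in indices]
        start = (start + minibatch_size) % n


minibatches = list(cyclic_minibatches([0, 1, 2, 3, 4], minibatch_size=2, num_iters=4))
```

Each yielded minibatch would feed one gradient update, so the number of updates is fixed by `num_iters` rather than by the size of the batch.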

Attributes

CLASS_AND_CTOR_ARGS_FILE_NAME

METADATA_FILE_NAME

STATE_FILE_NAME

TOTAL_LOSS_KEY

distributed

Whether the learner is running in distributed mode.

framework

module

The MultiRLModule that is being trained.