import abc
from collections import defaultdict
import copy
from functools import partial
import logging
import numpy
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Hashable,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
Union,
)
import tree # pip install dm_tree
import ray
from ray.data.iterator import DataIterator
from ray.rllib.connectors.learner.learner_connector_pipeline import (
LearnerConnectorPipeline,
)
from ray.rllib.core import (
COMPONENT_METRICS_LOGGER,
COMPONENT_OPTIMIZER,
COMPONENT_RL_MODULE,
DEFAULT_MODULE_ID,
)
from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.multi_rl_module import (
MultiRLModule,
MultiRLModuleSpec,
)
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.debug import update_global_seed_if_necessary
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
deprecation_warning,
)
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.metrics import (
ALL_MODULES,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
NUM_ENV_STEPS_TRAINED,
NUM_ENV_STEPS_TRAINED_LIFETIME,
NUM_MODULE_STEPS_TRAINED,
NUM_MODULE_STEPS_TRAINED_LIFETIME,
LEARNER_CONNECTOR_TIMER,
MODULE_TRAIN_BATCH_SIZE_MEAN,
WEIGHTS_SEQ_NO,
)
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.minibatch_utils import (
MiniBatchDummyIterator,
MiniBatchCyclicIterator,
)
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.typing import (
EpisodeType,
LearningRateOrSchedule,
ModuleID,
Optimizer,
Param,
ParamRef,
ParamDict,
ResultDict,
ShouldModuleBeUpdatedFn,
StateDict,
TensorType,
)
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
torch, _ = try_import_torch()
tf1, tf, tfv = try_import_tf()
logger = logging.getLogger(__name__)
DEFAULT_OPTIMIZER = "default_optimizer"
# COMMON LEARNER LOSS_KEYS
POLICY_LOSS_KEY = "policy_loss"
VF_LOSS_KEY = "vf_loss"
ENTROPY_KEY = "entropy"
# Additional update keys
LR_KEY = "learning_rate"
@PublicAPI(stability="alpha")
class Learner(Checkpointable):
"""Base class for Learners.
This class is used to train RLModules. It is responsible for defining the loss
function and for updating the neural network weights that it owns. It also
provides a way to add/remove modules to/from its MultiRLModule in a multi-agent
scenario in the middle of training (this is useful for league-based training).
The TF- and Torch-specific implementations of this class fill in the
framework-specific details for distributed training and for computing and
applying gradients. Users should not sub-class this class directly, but instead
inherit from the TF- or Torch-specific sub-classes to implement their
algorithm-specific update logic.
Args:
config: The AlgorithmConfig object from which to derive most of the settings
needed to build the Learner.
module_spec: The module specification for the RLModule that is being trained.
If the module is a single agent module, after building the module it will
be converted to a multi-agent module with a default key. Can be None if the
module is provided directly via the `module` argument. Refer to
ray.rllib.core.rl_module.RLModuleSpec
or ray.rllib.core.rl_module.MultiRLModuleSpec for more info.
module: If the Learner is being used stand-alone, the RLModule can optionally
be passed in directly instead of through the `module_spec`.
Note: We use PPO and torch as an example here because many of the showcased
components need implementations to come together. However, the same
pattern is generally applicable.
.. testcode::
import gymnasium as gym
from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
PPOTorchRLModule
)
from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
env = gym.make("CartPole-v1")
# Create a PPO config object first.
config = (
PPOConfig()
.framework("torch")
.training(model={"fcnet_hiddens": [128, 128]})
)
# Create a learner instance directly from our config. All we need as
# extra information here is the env to be able to extract space information
# (needed to construct the RLModule inside the Learner).
learner = config.build_learner(env=env)
# Take one gradient update on the module and report the results.
# results = learner.update(...)
# Add a new module, perhaps for league based training.
learner.add_module(
module_id="new_player",
module_spec=RLModuleSpec(
module_class=PPOTorchRLModule,
observation_space=env.observation_space,
action_space=env.action_space,
model_config=DefaultModelConfig(fcnet_hiddens=[64, 64]),
catalog_class=PPOCatalog,
)
)
# Take another gradient update with both previous and new modules.
# results = learner.update(...)
# Remove a module.
learner.remove_module("new_player")
# Will train previous modules only.
# results = learner.update(...)
# Get the state of the learner.
state = learner.get_state()
# Set the state of the learner.
learner.set_state(state)
# Get the weights of the underlying MultiRLModule.
weights = learner.get_state(components=COMPONENT_RL_MODULE)
# Set the weights of the underlying MultiRLModule.
learner.set_state({COMPONENT_RL_MODULE: weights})
Extension pattern:
.. testcode::
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
class MyLearner(TorchLearner):
def compute_losses(self, fwd_out, batch):
# Compute the losses per module based on `batch` and output of the
# forward pass (`fwd_out`). To access the (algorithm) config for a
# specific RLModule, do:
# `self.config.get_config_for_module([moduleID])`.
return {DEFAULT_MODULE_ID: module_loss}
"""
framework: str = None
TOTAL_LOSS_KEY: str = "total_loss"
def __init__(
self,
*,
config: "AlgorithmConfig",
module_spec: Optional[Union[RLModuleSpec, MultiRLModuleSpec]] = None,
module: Optional[RLModule] = None,
):
# TODO (sven): Figure out how to do this
self.config = config.copy(copy_frozen=False)
self._module_spec: Optional[MultiRLModuleSpec] = module_spec
self._module_obj: Optional[MultiRLModule] = module
self._device = None
# Set a seed, if necessary.
if self.config.seed is not None:
update_global_seed_if_necessary(self.framework, self.config.seed)
# Whether self.build has already been called.
self._is_built = False
# These are the attributes that are set during build.
# The actual MultiRLModule used by this Learner.
self._module: Optional[MultiRLModule] = None
self._weights_seq_no = 0
# Our Learner connector pipeline.
self._learner_connector: Optional[LearnerConnectorPipeline] = None
# These are set for properly applying optimizers and adding or removing modules.
self._optimizer_parameters: Dict[Optimizer, List[ParamRef]] = {}
self._named_optimizers: Dict[str, Optimizer] = {}
self._params: ParamDict = {}
# Dict mapping ModuleID to a list of optimizer names. Note that the optimizer
# name includes the ModuleID as a prefix: optimizer_name=`[ModuleID]_[.. rest]`.
self._module_optimizers: Dict[ModuleID, List[str]] = defaultdict(list)
self._optimizer_name_to_module: Dict[str, ModuleID] = {}
# Only manage optimizer's learning rate if user has NOT overridden
# the `configure_optimizers_for_module` method. Otherwise, leave responsibility
# to handle lr-updates entirely in user's hands.
self._optimizer_lr_schedules: Dict[Optimizer, Scheduler] = {}
# The Learner's own MetricsLogger to be used to log RLlib's built-in metrics or
# custom user-defined ones (e.g. custom loss values). When returning from an
# `update_from_...()` method call, the Learner will do a `self.metrics.reduce()`
# and return the resulting (reduced) dict.
self.metrics = MetricsLogger()
# In case of offline learning and multiple learners, each learner receives a
# repeatable iterator that iterates over a split of the streamed data.
self.iterator: DataIterator = None
# TODO (sven): Do we really need this API? It seems like LearnerGroup constructs
# all Learner workers and then immediately builds them anyway? Seems to make
# things more complicated. Unless there is a reason related to Train worker group
# setup.
@OverrideToImplementCustomLogic_CallToSuperRecommended
def build(self) -> None:
"""Builds the Learner.
This method should be called before the learner is used. It is responsible for
setting up the LearnerConnectorPipeline, the RLModule, optimizer(s), and
(optionally) the optimizers' learning rate schedulers.
"""
if self._is_built:
logger.debug("Learner already built. Skipping build.")
return
# Build learner connector pipeline used on this Learner worker.
# TODO (sven): Figure out which space to provide here. For now,
# it doesn't matter, as the default connector piece doesn't use
# this information anyway.
# module_spec = self._module_spec.as_multi_rl_module_spec()
self._learner_connector = self.config.build_learner_connector(
input_observation_space=None,
input_action_space=None,
device=self._device,
)
# Build the module to be trained by this learner.
self._module = self._make_module()
# Configure, construct, and register all optimizers needed to train
# `self.module`.
self.configure_optimizers()
self._is_built = True
@property
def distributed(self) -> bool:
"""Whether the learner is running in distributed mode."""
return self.config.num_learners > 1
@property
def module(self) -> MultiRLModule:
"""The MultiRLModule that is being trained."""
return self._module
def register_optimizer(
self,
*,
module_id: ModuleID = ALL_MODULES,
optimizer_name: str = DEFAULT_OPTIMIZER,
optimizer: Optimizer,
params: Sequence[Param],
lr_or_lr_schedule: Optional[LearningRateOrSchedule] = None,
) -> None:
"""Registers an optimizer with a ModuleID, name, param list and lr-scheduler.
Use this method in your custom implementations of either
`self.configure_optimizers()` or `self.configure_optimizers_for_module()` (you
should only override one of these!). If you register a learning rate Scheduler
setting together with an optimizer, RLlib will automatically keep this
optimizer's learning rate updated throughout the training process.
Alternatively, you can construct your optimizers directly with a learning rate
and manage learning rate scheduling or updating yourself.
Args:
module_id: The `module_id` under which to register the optimizer. If not
provided, will assume ALL_MODULES.
optimizer_name: The name (str) of the optimizer. If not provided, will
assume DEFAULT_OPTIMIZER.
optimizer: The already instantiated optimizer object to register.
params: A list of parameters (framework-specific variables) that will be
trained/updated.
lr_or_lr_schedule: An optional fixed learning rate or learning rate schedule
setup. If provided, RLlib will automatically keep the optimizer's
learning rate updated.
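For example, a minimal sketch of a custom `configure_optimizers_for_module()`
that registers a single Adam optimizer through this method (assuming a
TorchLearner subclass; the optimizer choice and its settings are placeholders):
.. code-block:: python

    import torch

    def configure_optimizers_for_module(self, module_id, config=None):
        module = self.module[module_id]
        params = self.get_parameters(module)
        optimizer = torch.optim.Adam(params)
        self.register_optimizer(
            module_id=module_id,
            optimizer=optimizer,
            params=params,
            # Let RLlib manage this optimizer's learning rate via the config.
            lr_or_lr_schedule=config.lr,
        )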
"""
# Validate optimizer instance and its param list.
self._check_registered_optimizer(optimizer, params)
full_registration_name = module_id + "_" + optimizer_name
# Store the given optimizer under the given `module_id`.
self._module_optimizers[module_id].append(full_registration_name)
self._optimizer_name_to_module[full_registration_name] = module_id
# Store the optimizer instance under its full `module_id`_`optimizer_name`
# key.
self._named_optimizers[full_registration_name] = optimizer
# Store all given parameters under the given optimizer.
self._optimizer_parameters[optimizer] = []
for param in params:
param_ref = self.get_param_ref(param)
self._optimizer_parameters[optimizer].append(param_ref)
self._params[param_ref] = param
# Optionally, store a scheduler object along with this optimizer. If such a
# setting is provided, RLlib will handle updating the optimizer's learning rate
# over time.
if lr_or_lr_schedule is not None:
# Validate the given setting.
Scheduler.validate(
fixed_value_or_schedule=lr_or_lr_schedule,
setting_name="lr_or_lr_schedule",
description="learning rate or schedule",
)
# Create the scheduler object for this optimizer.
scheduler = Scheduler(
fixed_value_or_schedule=lr_or_lr_schedule,
framework=self.framework,
device=self._device,
)
self._optimizer_lr_schedules[optimizer] = scheduler
# Set the optimizer to the current (first) learning rate.
self._set_optimizer_lr(
optimizer=optimizer,
lr=scheduler.get_current_value(),
)
@OverrideToImplementCustomLogic
@abc.abstractmethod
def compute_gradients(
self, loss_per_module: Dict[ModuleID, TensorType], **kwargs
) -> ParamDict:
"""Computes the gradients based on the given losses.
Args:
loss_per_module: Dict mapping module IDs to their individual total loss
terms, computed by the individual `compute_loss_for_module()` calls.
The overall total loss (sum of loss terms over all modules) is stored
under `loss_per_module[ALL_MODULES]`.
**kwargs: Forward compatibility kwargs.
Returns:
The gradients in the same (flat) format as self._params. Note that all
top-level structures, such as module IDs, will not be present anymore in
the returned dict. It will merely map parameter tensor references to their
respective gradient tensors.
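For illustration, a minimal sketch of what a torch-based override could look
like (this is NOT the actual TorchLearner implementation; it assumes that
`self._params` maps param refs to `torch.nn.Parameter` objects):
.. code-block:: python

    from ray.rllib.utils.metrics import ALL_MODULES

    def compute_gradients(self, loss_per_module, **kwargs):
        # Reset the gradients of all registered optimizers.
        for optimizer in self._optimizer_parameters:
            optimizer.zero_grad(set_to_none=True)
        # Backprop through the overall total loss (sum over all modules).
        loss_per_module[ALL_MODULES].backward()
        # Return the flat {param_ref: gradient} dict expected downstream.
        return {ref: param.grad for ref, param in self._params.items()}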
"""
@OverrideToImplementCustomLogic
def postprocess_gradients(self, gradients_dict: ParamDict) -> ParamDict:
"""Applies potential postprocessing operations on the gradients.
This method is called after gradients have been computed and modifies them
before they are applied to the respective module(s) by the optimizer(s).
This might include grad clipping by value, norm, or global-norm, or other
algorithm specific gradient postprocessing steps.
This default implementation calls `self.postprocess_gradients_for_module()`
on each of the sub-modules in our MultiRLModule (`self.module`) and returns
the accumulated gradients dicts.
Args:
gradients_dict: A dictionary of gradients in the same (flat) format as
self._params. Note that top-level structures, such as module IDs,
will not be present anymore in this dict. It will merely map gradient
tensor references to gradient tensors.
Returns:
A dictionary with the updated gradients and the exact same (flat) structure
as the incoming `gradients_dict` arg.
"""
# The flat gradients dict (mapping param refs to params), returned by this
# method.
postprocessed_gradients = {}
for module_id in self.module.keys():
# Send a gradients dict for only this `module_id` to the
# `self.postprocess_gradients_for_module()` method.
module_grads_dict = {}
for optimizer_name, optimizer in self.get_optimizers_for_module(module_id):
module_grads_dict.update(
self.filter_param_dict_for_optimizer(gradients_dict, optimizer)
)
module_grads_dict = self.postprocess_gradients_for_module(
module_id=module_id,
config=self.config.get_config_for_module(module_id),
module_gradients_dict=module_grads_dict,
)
assert isinstance(module_grads_dict, dict)
# Update our return dict.
postprocessed_gradients.update(module_grads_dict)
return postprocessed_gradients
@OverrideToImplementCustomLogic_CallToSuperRecommended
def postprocess_gradients_for_module(
self,
*,
module_id: ModuleID,
config: Optional["AlgorithmConfig"] = None,
module_gradients_dict: ParamDict,
) -> ParamDict:
"""Applies postprocessing operations on the gradients of the given module.
Args:
module_id: The module ID for which we will postprocess computed gradients.
Note that `module_gradients_dict` already only carries those gradient
tensors that belong to this `module_id`. Other `module_id`'s gradients
are not available in this call.
config: The AlgorithmConfig specific to the given `module_id`.
module_gradients_dict: A dictionary of gradients in the same (flat) format
as self._params, mapping gradient refs to gradient tensors, which are to
be postprocessed. You may alter these tensors in place or create new
ones and return these in a new dict.
Returns:
A dictionary with the updated gradients and the exact same (flat) structure
as the incoming `module_gradients_dict` arg.
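For example, a minimal sketch of a custom override that additionally scales
this module's gradients (calling `super()` first keeps the built-in
clipping/logging behavior; the scaling factor is purely illustrative):
.. code-block:: python

    def postprocess_gradients_for_module(
        self, *, module_id, config=None, module_gradients_dict
    ):
        grads = super().postprocess_gradients_for_module(
            module_id=module_id,
            config=config,
            module_gradients_dict=module_gradients_dict,
        )
        # Scale all of this module's gradients by a constant factor.
        return {ref: grad * 0.5 for ref, grad in grads.items()}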
"""
postprocessed_grads = {}
if config.grad_clip is None and not config.log_gradients:
postprocessed_grads.update(module_gradients_dict)
return postprocessed_grads
for optimizer_name, optimizer in self.get_optimizers_for_module(module_id):
grad_dict_to_clip = self.filter_param_dict_for_optimizer(
param_dict=module_gradients_dict,
optimizer=optimizer,
)
if config.grad_clip:
# Perform gradient clipping, if configured.
global_norm = self._get_clip_function()(
grad_dict_to_clip,
grad_clip=config.grad_clip,
grad_clip_by=config.grad_clip_by,
)
if config.grad_clip_by == "global_norm" or config.log_gradients:
# If we want to log gradients, but do not use the global norm
# for clipping, compute it here.
if config.log_gradients and config.grad_clip_by != "global_norm":
# Compute the global norm of gradients.
global_norm = self._get_global_norm_function()(
# Note, `tf.linalg.global_norm` needs a list of tensors.
list(grad_dict_to_clip.values()),
)
self.metrics.log_value(
key=(module_id, f"gradients_{optimizer_name}_global_norm"),
value=global_norm,
window=1,
)
postprocessed_grads.update(grad_dict_to_clip)
# Otherwise, check whether we only want to log gradients.
elif config.log_gradients:
# Compute the global norm of gradients and log it.
global_norm = self._get_global_norm_function()(
# Note, `tf.linalg.global_norm` needs a list of tensors.
list(grad_dict_to_clip.values()),
)
self.metrics.log_value(
key=(module_id, f"gradients_{optimizer_name}_global_norm"),
value=global_norm,
window=1,
)
return postprocessed_grads
@OverrideToImplementCustomLogic
@abc.abstractmethod
def apply_gradients(self, gradients_dict: ParamDict) -> None:
"""Applies the gradients to the MultiRLModule parameters.
Args:
gradients_dict: A dictionary of gradients in the same (flat) format as
self._params. Note that top-level structures, such as module IDs,
will not be present anymore in this dict. It will merely map gradient
tensor references to gradient tensors.
"""
def get_optimizer(
self,
module_id: ModuleID = DEFAULT_MODULE_ID,
optimizer_name: str = DEFAULT_OPTIMIZER,
) -> Optimizer:
"""Returns the optimizer object, configured under the given module_id and name.
If only one optimizer was registered under `module_id` (or ALL_MODULES)
via the `self.register_optimizer` method, `optimizer_name` is assumed to be
DEFAULT_OPTIMIZER.
Args:
module_id: The ModuleID for which to return the configured optimizer.
If not provided, will assume DEFAULT_MODULE_ID.
optimizer_name: The name of the optimizer (registered under `module_id` via
`self.register_optimizer()`) to return. If not provided, will assume
DEFAULT_OPTIMIZER.
Returns:
The optimizer object, configured under the given `module_id` and
`optimizer_name`.
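For example (the module and optimizer names below are placeholders):
.. code-block:: python

    # Single-module setup with one (default) optimizer:
    optim = learner.get_optimizer()
    # Multi-agent setup with a named optimizer registered for "policy_1":
    critic_optim = learner.get_optimizer(
        module_id="policy_1", optimizer_name="critic"
    )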
"""
# `optimizer_name` could possibly be the full optimizer name (including the
# module_id under which it is registered).
if optimizer_name in self._named_optimizers:
return self._named_optimizers[optimizer_name]
# Normally, `optimizer_name` is just the optimizer's name, not including the
# `module_id`.
full_registration_name = module_id + "_" + optimizer_name
if full_registration_name in self._named_optimizers:
return self._named_optimizers[full_registration_name]
# No optimizer found.
raise KeyError(
f"Optimizer not found! module_id={module_id} "
f"optimizer_name={optimizer_name}"
)
def get_optimizers_for_module(
self, module_id: ModuleID = ALL_MODULES
) -> List[Tuple[str, Optimizer]]:
"""Returns a list of (optimizer_name, optimizer instance)-tuples for module_id.
Args:
module_id: The ModuleID for which to return the configured
(optimizer name, optimizer)-pairs. If not provided, will return
optimizers registered under ALL_MODULES.
Returns:
A list of tuples of the format: ([optimizer_name], [optimizer object]),
where optimizer_name is the name under which the optimizer was registered
in `self.register_optimizer`. If only a single optimizer was
configured for `module_id`, [optimizer_name] will be DEFAULT_OPTIMIZER.
"""
named_optimizers = []
for full_registration_name in self._module_optimizers[module_id]:
optimizer = self._named_optimizers[full_registration_name]
# TODO (sven): How can we avoid registering optimizers under this
# constructed `[module_id]_[optim_name]` format?
optim_name = full_registration_name[len(module_id) + 1 :]
named_optimizers.append((optim_name, optimizer))
return named_optimizers
def filter_param_dict_for_optimizer(
self, param_dict: ParamDict, optimizer: Optimizer
) -> ParamDict:
"""Reduces the given ParamDict to contain only parameters for given optimizer.
Args:
param_dict: The ParamDict to reduce/filter down to the given `optimizer`.
The returned dict will be a subset of `param_dict` only containing keys
(param refs) that were registered together with `optimizer` (and thus
that `optimizer` is responsible for applying gradients to).
optimizer: The optimizer object to whose parameter refs the given
`param_dict` should be reduced.
Returns:
A new ParamDict only containing param ref keys that belong to `optimizer`.
"""
# Return a sub-dict only containing those param_ref keys (and their values)
# that belong to the `optimizer`.
return {
ref: param_dict[ref]
for ref in self._optimizer_parameters[optimizer]
if ref in param_dict and param_dict[ref] is not None
}
@abc.abstractmethod
def get_param_ref(self, param: Param) -> Hashable:
"""Returns a hashable reference to a trainable parameter.
This should be overridden in the framework-specific specialization. For example,
in torch it will return the parameter itself, while in tf it returns the `.ref()`
of the variable. The purpose is to retrieve a unique reference to the parameter.
Args:
param: The parameter to get the reference to.
Returns:
A reference to the parameter.
"""
@abc.abstractmethod
def get_parameters(self, module: RLModule) -> Sequence[Param]:
"""Returns the list of parameters of a module.
This should be overridden in the framework-specific Learner. For example, in
torch it will return `.parameters()`, while in tf it returns `.trainable_variables`.
Args:
module: The module to get the parameters from.
Returns:
The parameters of the module.
"""
@abc.abstractmethod
def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch:
"""Converts the elements of a MultiAgentBatch to Tensors on the correct device.
Args:
batch: The MultiAgentBatch object to convert.
Returns:
The resulting MultiAgentBatch with framework-specific tensor values placed
on the correct device.
"""
@OverrideToImplementCustomLogic_CallToSuperRecommended
def add_module(
self,
*,
module_id: ModuleID,
module_spec: RLModuleSpec,
config_overrides: Optional[Dict] = None,
new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
) -> MultiRLModuleSpec:
"""Adds a module to the underlying MultiRLModule.
Changes this Learner's config in order to make this architectural change
permanent with respect to checkpointing.
Args:
module_id: The ModuleID of the module to be added.
module_spec: The ModuleSpec of the module to be added.
config_overrides: The `AlgorithmConfig` overrides that should apply to
the new Module, if any.
new_should_module_be_updated: An optional sequence of ModuleIDs or a
callable taking ModuleID and SampleBatchType and returning whether the
ModuleID should be updated (trained).
If None, will keep the existing setup in place. RLModules whose
IDs are not in the list (or for which the callable returns False)
will not be updated.
Returns:
The new MultiRLModuleSpec (after the RLModule has been added).
"""
validate_module_id(module_id, error=True)
self._check_is_built()
# Force-set inference-only = False.
module_spec = copy.deepcopy(module_spec)
module_spec.inference_only = False
# Build the new RLModule and add it to self.module.
module = module_spec.build()
self.module.add_module(module_id, module)
# Change our config (AlgorithmConfig) to contain the new Module.
# TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
# but we'll deprecate config.policies soon anyway.
self.config.policies[module_id] = PolicySpec()
if config_overrides is not None:
self.config.multi_agent(
algorithm_config_overrides_per_module={module_id: config_overrides}
)
self.config.rl_module(rl_module_spec=MultiRLModuleSpec.from_module(self.module))
self._module_spec = self.config.rl_module_spec
if new_should_module_be_updated is not None:
self.config.multi_agent(policies_to_train=new_should_module_be_updated)
# Allow the user to configure one or more optimizers for this new module.
self.configure_optimizers_for_module(
module_id=module_id,
config=self.config.get_config_for_module(module_id),
)
return self.config.rl_module_spec
@OverrideToImplementCustomLogic_CallToSuperRecommended
def remove_module(
self,
module_id: ModuleID,
*,
new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
) -> MultiRLModuleSpec:
"""Removes a module from the Learner.
Args:
module_id: The ModuleID of the module to be removed.
new_should_module_be_updated: An optional sequence of ModuleIDs or a
callable taking ModuleID and SampleBatchType and returning whether the
ModuleID should be updated (trained).
If None, will keep the existing setup in place. RLModules whose
IDs are not in the list (or for which the callable returns False)
will not be updated.
Returns:
The new MultiRLModuleSpec (after the RLModule has been removed).
"""
self._check_is_built()
module = self.module[module_id]
# Delete the removed module's parameters and optimizers.
if self._is_module_compatible_with_learner(module):
parameters = self.get_parameters(module)
for param in parameters:
param_ref = self.get_param_ref(param)
if param_ref in self._params:
del self._params[param_ref]
for optimizer_name, optimizer in self.get_optimizers_for_module(module_id):
del self._optimizer_parameters[optimizer]
name = module_id + "_" + optimizer_name
del self._named_optimizers[name]
if optimizer in self._optimizer_lr_schedules:
del self._optimizer_lr_schedules[optimizer]
del self._module_optimizers[module_id]
# Remove the module from the MultiRLModule.
self.module.remove_module(module_id)
# Change self.config to reflect the new architecture.
# TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
# but we'll deprecate config.policies soon anyway.
del self.config.policies[module_id]
self.config.algorithm_config_overrides_per_module.pop(module_id, None)
if new_should_module_be_updated is not None:
self.config.multi_agent(policies_to_train=new_should_module_be_updated)
self.config.rl_module(rl_module_spec=MultiRLModuleSpec.from_module(self.module))
# Remove all of this module's stats from our metrics logger, so we don't
# report results from this module again.
if module_id in self.metrics.stats:
del self.metrics.stats[module_id]
return self.config.rl_module_spec
@OverrideToImplementCustomLogic
def should_module_be_updated(self, module_id, multi_agent_batch=None):
"""Returns whether a module should be updated or not based on `self.config`.
Args:
module_id: The ModuleID that we want to query on whether this module
should be updated or not.
multi_agent_batch: An optional MultiAgentBatch to possibly provide further
information on the decision on whether the RLModule should be updated
or not.
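The behavior of this method is normally controlled through the config, for
example (the module IDs below are placeholders):
.. code-block:: python

    # Only train "main_policy"; leave all other modules frozen.
    config.multi_agent(policies_to_train=["main_policy"])
    # Or decide dynamically via a callable.
    config.multi_agent(
        policies_to_train=lambda module_id, batch=None: module_id != "random"
    )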
"""
should_module_be_updated_fn = self.config.policies_to_train
# If None, return True (by default, all modules should be updated).
if should_module_be_updated_fn is None:
return True
# If collection given, return whether `module_id` is in that container.
elif not callable(should_module_be_updated_fn):
return module_id in set(should_module_be_updated_fn)
return should_module_be_updated_fn(module_id, multi_agent_batch)
@OverrideToImplementCustomLogic
def compute_losses(
self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
) -> Dict[str, Any]:
"""Computes the loss(es) for the module being optimized.
This method must be overridden by MultiRLModule-specific Learners in order to
define the specific loss computation logic. If the algorithm is single-agent,
only `compute_loss_for_module()` should be overridden instead. If the algorithm
uses independent multi-agent learning (default behavior for RLlib's multi-agent
setups), also only `compute_loss_for_module()` should be overridden, but it will
be called for each individual RLModule inside the MultiRLModule.
It is recommended to not compute any forward passes within this method, and to
use the `forward_train()` outputs of the RLModule(s) to compute the required
loss tensors.
See here for a custom loss function example script:
https://github.com/ray-project/ray/blob/master/rllib/examples/learners/custom_loss_fn_simple.py # noqa
Args:
fwd_out: Output from a call to the `forward_train()` method of the
underlying MultiRLModule (`self.module`) during training
(`self.update()`).
batch: The train batch that was used to compute `fwd_out`.
Returns:
A dictionary mapping module IDs to individual loss terms.
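For example, a minimal sketch of an override that adds an extra coupling term
between two modules on top of the default per-module losses (the module IDs,
the "embeddings" key, and the penalty weight are all assumptions here):
.. code-block:: python

    def compute_losses(self, *, fwd_out, batch):
        # Standard, independent per-module losses.
        loss_per_module = super().compute_losses(fwd_out=fwd_out, batch=batch)
        # Hypothetical coupling term between the two modules' outputs.
        penalty = (
            (fwd_out["agent_1"]["embeddings"] - fwd_out["agent_2"]["embeddings"])
            ** 2
        ).mean()
        loss_per_module["agent_1"] = loss_per_module["agent_1"] + 0.01 * penalty
        return loss_per_module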
"""
loss_per_module = {}
for module_id in fwd_out:
module_batch = batch[module_id]
module_fwd_out = fwd_out[module_id]
module = self.module[module_id].unwrapped()
if isinstance(module, SelfSupervisedLossAPI):
loss = module.compute_self_supervised_loss(
learner=self,
module_id=module_id,
config=self.config.get_config_for_module(module_id),
batch=module_batch,
fwd_out=module_fwd_out,
)
else:
loss = self.compute_loss_for_module(
module_id=module_id,
config=self.config.get_config_for_module(module_id),
batch=module_batch,
fwd_out=module_fwd_out,
)
loss_per_module[module_id] = loss
return loss_per_module
@OverrideToImplementCustomLogic
@abc.abstractmethod
def compute_loss_for_module(
self,
*,
module_id: ModuleID,
config: "AlgorithmConfig",
batch: Dict[str, Any],
fwd_out: Dict[str, TensorType],
) -> TensorType:
"""Computes the loss for a single module.
Think of this as computing loss for a single agent. For multi-agent use-cases
that require more complicated loss computation, consider overriding the
`compute_losses` method instead.
Args:
module_id: The id of the module.
config: The AlgorithmConfig specific to the given `module_id`.
batch: The train batch for this particular module.
fwd_out: The output of the forward pass for this particular module.
Returns:
A single total loss tensor. If you have more than one optimizer on the
provided `module_id` and would like to compute gradients separately using
these different optimizers, simply add up the individual loss terms for
each optimizer and return the sum. Also, for recording/logging any
individual loss terms, you can use the `Learner.metrics.log_value(
key=..., value=...)` or `Learner.metrics.log_dict()` APIs. See:
:py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger` for more
information.
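For example, a minimal sketch of a behavior-cloning style loss (assuming a
TorchLearner subclass; the "action_dist_inputs" and "actions" keys as well as
the "bc_loss" metric name are assumptions here):
.. code-block:: python

    import torch

    def compute_loss_for_module(self, *, module_id, config, batch, fwd_out):
        logits = fwd_out["action_dist_inputs"]
        actions = batch["actions"].long()
        # Simple cross-entropy (behavior cloning) loss.
        loss = torch.nn.functional.cross_entropy(logits, actions)
        # Log an individual loss term through the MetricsLogger.
        self.metrics.log_value((module_id, "bc_loss"), loss.item(), window=1)
        return loss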
"""
def update_from_batch(
self,
batch: MultiAgentBatch,
*,
# TODO (sven): Make this a more formal structure with its own type.
timesteps: Optional[Dict[str, Any]] = None,
num_epochs: int = 1,
minibatch_size: Optional[int] = None,
shuffle_batch_per_epoch: bool = False,
# Deprecated args.
num_iters=DEPRECATED_VALUE,
**kwargs,
) -> ResultDict:
"""Run `num_epochs` epochs over the given train batch.
You can use this method to take more than one backward pass on the batch.
The same `minibatch_size` and `num_epochs` will be used for all module ids in
MultiRLModule.
Args:
batch: A batch of training data to update from.
timesteps: Timesteps dict, which must have the key
`NUM_ENV_STEPS_SAMPLED_LIFETIME`.
# TODO (sven): Make this a more formal structure with its own type.
num_epochs: The number of complete passes over the entire train batch. Each
pass might be further split into n minibatches (if `minibatch_size`
provided).
minibatch_size: The size of minibatches to use to further split the train
`batch` into sub-batches. The `batch` is then iterated over n times
where n is `len(batch) // minibatch_size`.
shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
If the train batch has a time rank (axis=1), shuffling will only take
place along the batch axis to not disturb any intact (episode)
trajectories. Also, shuffling is always skipped if `minibatch_size` is
None, meaning the entire train batch is processed each epoch, making it
unnecessary to shuffle.
Returns:
A `ResultDict` object produced by a call to `self.metrics.reduce()`. The
returned dict may be arbitrarily nested and must have `Stats` objects at
all of its leaves, allowing components further downstream (i.e. a user of this
Learner) to further reduce these results (for example over n parallel
Learners).
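For example (PPO-style minibatch SGD over a sampled train batch; the batch
object and the timestep count are placeholders):
.. code-block:: python

    from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME

    results = learner.update_from_batch(
        batch=train_batch,
        timesteps={NUM_ENV_STEPS_SAMPLED_LIFETIME: 100_000},
        num_epochs=4,
        minibatch_size=128,
        shuffle_batch_per_epoch=True,
    )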
"""
if num_iters != DEPRECATED_VALUE:
deprecation_warning(
old="Learner.update_from_episodes(num_iters=...)",
new="Learner.update_from_episodes(num_epochs=...)",
error=True,
)
self._update_from_batch_or_episodes(
batch=batch,
timesteps=timesteps,
num_epochs=num_epochs,
minibatch_size=minibatch_size,
shuffle_batch_per_epoch=shuffle_batch_per_epoch,
)
return self.metrics.reduce()
def update_from_episodes(
self,
episodes: List[EpisodeType],
*,
# TODO (sven): Make this a more formal structure with its own type.
timesteps: Optional[Dict[str, Any]] = None,
num_epochs: int = 1,
minibatch_size: Optional[int] = None,
shuffle_batch_per_epoch: bool = False,
num_total_minibatches: int = 0,
# Deprecated args.
num_iters=DEPRECATED_VALUE,
) -> ResultDict:
"""Run `num_epochs` epochs over the train batch generated from `episodes`.
You can use this method to take more than one backward pass on the batch.
The same `minibatch_size` and `num_epochs` will be used for all module ids in
MultiRLModule.
Args:
episodes: A list of episode objects to update from.
timesteps: Timesteps dict, which must have the key
`NUM_ENV_STEPS_SAMPLED_LIFETIME`.
# TODO (sven): Make this a more formal structure with its own type.
num_epochs: The number of complete passes over the entire train batch. Each
pass might be further split into n minibatches (if `minibatch_size`
provided). The train batch is generated from the given `episodes`
through the Learner connector pipeline.
minibatch_size: The size of minibatches to use to further split the train
`batch` into sub-batches. The `batch` is then iterated over n times
where n is `len(batch) // minibatch_size`. The train batch is generated
from the given `episodes` through the Learner connector pipeline.
shuffle_batch_per_epoch: Whether to shuffle the train batch once per epoch.
If the train batch has a time rank (axis=1), shuffling will only take
place along the batch axis to not disturb any intact (episode)
trajectories. Also, shuffling is always skipped if `minibatch_size` is
None, meaning the entire train batch is processed each epoch, making it
unnecessary to shuffle. The train batch is generated from the given
`episodes` through the Learner connector pipeline.
num_total_minibatches: The total number of minibatches to loop through
(over all `num_epochs` epochs). It's only required to set this to != 0
in multi-agent + multi-GPU situations, in which the MultiAgentEpisodes
themselves are roughly sharded equally; however, they might contain
SingleAgentEpisodes with very lopsided length distributions. Thus,
without this fixed, pre-computed value, one Learner might go through a
different number of minibatch passes than others, causing a deadlock.
Returns:
A `ResultDict` object produced by a call to `self.metrics.reduce()`. The
returned dict may be arbitrarily nested and must have `Stats` objects at
all of its leaves, allowing components further downstream (i.e. a user of this
Learner) to further reduce these results (for example over n parallel
Learners).
"""
if num_iters != DEPRECATED_VALUE:
deprecation_warning(
old="Learner.update_from_episodes(num_iters=...)",
new="Learner.update_from_episodes(num_epochs=...)",
error=True,
)
self._update_from_batch_or_episodes(
episodes=episodes,
timesteps=timesteps,
num_epochs=num_epochs,
minibatch_size=minibatch_size,
shuffle_batch_per_epoch=shuffle_batch_per_epoch,
num_total_minibatches=num_total_minibatches,
)
return self.metrics.reduce()
def update_from_iterator(
self,
iterator,
*,
timesteps: Optional[Dict[str, Any]] = None,
minibatch_size: Optional[int] = None,
num_iters: int = None,
**kwargs,
):
if "num_epochs" in kwargs:
raise ValueError(
"`num_epochs` arg NOT supported by Learner.update_from_iterator! Use "
"`num_iters` instead."
)
if not self.iterator:
self.iterator = iterator
self._check_is_built()
# Call `before_gradient_based_update` to allow for non-gradient based
# preparation, logging, and update logic to happen.
self.before_gradient_based_update(timesteps=timesteps or {})
def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]:
# Note, the incoming batch is a dictionary with a numpy array
# holding the `MultiAgentBatch`.
batch = self._convert_batch_type(batch["batch"][0])
return {"batch": self._set_slicing_by_batch_id(batch, value=True)}
i = 0
logger.debug(f"===> [Learner {id(self)}]: Looping through batches ... ")
for batch in self.iterator.iter_batches(
# Note, this needs to be one b/c data is already mapped to
# `MultiAgentBatch`es of `minibatch_size`.
batch_size=1,
_finalize_fn=_finalize_fn,
**kwargs,
):
# Update the iteration counter.
i += 1
# Note, `_finalize_fn` must return a dictionary.
batch = batch["batch"]
logger.debug(
f"===> [Learner {id(self)}]: batch {i} with {batch.env_steps()} rows."
)
# Check whether our MultiRLModule contains all ModuleIDs found in this
# batch. If not, throw an error.
unknown_module_ids = set(batch.policy_batches.keys()) - set(
self.module.keys()
)
if len(unknown_module_ids) > 0:
raise ValueError(
"Batch contains one or more ModuleIDs that are not in this "
f"Learner! Found IDs: {unknown_module_ids}"
)
# Log metrics.
self._log_steps_trained_metrics(batch)
# Make the actual in-graph/traced `_update` call. This should return
# all tensor values (no numpy).
fwd_out, loss_per_module, tensor_metrics = self._update(
batch.policy_batches
)
# Convert logged tensor metrics (logged during tensor-mode of MetricsLogger)
# to actual (numpy) values.
self.metrics.tensors_to_numpy(tensor_metrics)
self._set_slicing_by_batch_id(batch, value=False)
# If `num_iters` is reached break and return.
if num_iters and i == num_iters:
break
logger.debug(
f"===> [Learner {id(self)}] number of iterations run in this epoch: {i}"
)
# Log all individual RLModules' loss terms and their registered optimizers'
# current learning rates.
for mid, loss in convert_to_numpy(loss_per_module).items():
self.metrics.log_value(
key=(mid, self.TOTAL_LOSS_KEY),
value=loss,
window=1,
)
# Call `after_gradient_based_update` to allow for non-gradient based
# cleanups-, logging-, and update logic to happen.
# TODO (simon): Check, if this should stay here, when running multiple
# gradient steps inside the iterator loop above (could be a complete epoch)
# the target networks might need to be updated earlier.
self.after_gradient_based_update(timesteps=timesteps or {})
# Reduce results across all minibatch update steps.
return self.metrics.reduce()
@OverrideToImplementCustomLogic
@abc.abstractmethod
def _update(
self,
batch: Dict[str, Any],
**kwargs,
) -> Tuple[Any, Any, Any]:
"""Contains all logic for an in-graph/traceable update step.
Framework-specific subclasses must implement this method. It should include
calls to the RLModule's `forward_train()` as well as to `compute_losses()`,
`compute_gradients()`, `postprocess_gradients()`, and `apply_gradients()`, and return a tuple
with all the individual results.
Args:
batch: The train batch already converted to a Dict mapping str to (possibly
nested) tensors.
kwargs: Forward compatibility kwargs.
Returns:
A tuple consisting of:
1) The `forward_train()` output of the RLModule,
2) the loss_per_module dictionary mapping module IDs to individual loss
tensors
3) a metrics dict mapping module IDs to metrics key/value pairs.
"""
@override(Checkpointable)
def get_state(
self,
components: Optional[Union[str, Collection[str]]] = None,
*,
not_components: Optional[Union[str, Collection[str]]] = None,
**kwargs,
) -> StateDict:
self._check_is_built()
state = {
"should_module_be_updated": self.config.policies_to_train,
}
if self._check_component(COMPONENT_RL_MODULE, components, not_components):
state[COMPONENT_RL_MODULE] = self.module.get_state(
components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
not_components=self._get_subcomponents(
COMPONENT_RL_MODULE, not_components
),
**kwargs,
)
state[WEIGHTS_SEQ_NO] = self._weights_seq_no
if self._check_component(COMPONENT_OPTIMIZER, components, not_components):
state[COMPONENT_OPTIMIZER] = self._get_optimizer_state()
if self._check_component(COMPONENT_METRICS_LOGGER, components, not_components):
# TODO (sven): Make `MetricsLogger` a Checkpointable.
state[COMPONENT_METRICS_LOGGER] = self.metrics.get_state()
return state
@override(Checkpointable)
def set_state(self, state: StateDict) -> None:
self._check_is_built()
weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)
if COMPONENT_RL_MODULE in state:
if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
self.module.set_state(state[COMPONENT_RL_MODULE])
if COMPONENT_OPTIMIZER in state:
self._set_optimizer_state(state[COMPONENT_OPTIMIZER])
# Update our weights_seq_no, if the new one is > 0.
if weights_seq_no > 0:
self._weights_seq_no = weights_seq_no
# Update our trainable Modules information/function via our config.
# If not provided in state (None), all Modules will be trained by default.
if "should_module_be_updated" in state:
self.config.multi_agent(policies_to_train=state["should_module_be_updated"])
# TODO (sven): Make `MetricsLogger` a Checkpointable.
if COMPONENT_METRICS_LOGGER in state:
self.metrics.set_state(state[COMPONENT_METRICS_LOGGER])
@override(Checkpointable)
def get_ctor_args_and_kwargs(self):
return (
(), # *args,
{
"config": self.config,
"module_spec": self._module_spec,
"module": self._module_obj,
}, # **kwargs
)
@override(Checkpointable)
def get_checkpointable_components(self):
if not self._check_is_built(error=False):
self.build()
return [
(COMPONENT_RL_MODULE, self.module),
]
def _get_optimizer_state(self) -> StateDict:
"""Returns the state of all optimizers currently registered in this Learner.
Returns:
The current state of all optimizers currently registered in this Learner.
"""
raise NotImplementedError
def _set_optimizer_state(self, state: StateDict) -> None:
"""Sets the state of all optimizers currently registered in this Learner.
Args:
state: The state of the optimizers.
"""
raise NotImplementedError
def _update_from_batch_or_episodes(
self,
*,
# TODO (sven): We should allow passing in a single agent batch here
# as well for simplicity.
batch: Optional[MultiAgentBatch] = None,
episodes: Optional[List[EpisodeType]] = None,
# TODO (sven): Make this a more formal structure with its own type.
timesteps: Optional[Dict[str, Any]] = None,
# TODO (sven): Deprecate these in favor of config attributes for only those
# algos that actually need (and know how) to do minibatching.
num_epochs: int = 1,
minibatch_size: Optional[int] = None,
shuffle_batch_per_epoch: bool = False,
num_total_minibatches: int = 0,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
self._check_is_built()
# Call `before_gradient_based_update` to allow for non-gradient based
# preparation, logging, and update logic to happen.
self.before_gradient_based_update(timesteps=timesteps or {})
# Resolve batch/episodes being ray object refs (instead of
# actual batch/episodes objects).
if isinstance(batch, ray.ObjectRef):
batch = ray.get(batch)
if isinstance(episodes, ray.ObjectRef) or (
isinstance(episodes, list) and isinstance(episodes[0], ray.ObjectRef)
):
episodes = ray.get(episodes)
episodes = tree.flatten(episodes)
# Call the learner connector.
if episodes is not None:
# Call the learner connector pipeline.
with self.metrics.log_time((ALL_MODULES, LEARNER_CONNECTOR_TIMER)):
shared_data = {}
batch = self._learner_connector(
rl_module=self.module,
batch=batch if batch is not None else {},
episodes=episodes,
shared_data=shared_data,
)
# Convert to a batch.
# TODO (sven): Try to not require MultiAgentBatch anymore.
batch = MultiAgentBatch(
{
module_id: (
SampleBatch(module_data, _zero_padded=True)
if shared_data.get(f"_zero_padded_for_mid={module_id}")
else SampleBatch(module_data)
)
for module_id, module_data in batch.items()
},
env_steps=sum(len(e) for e in episodes),
)
# Single-agent SampleBatch: Have to convert to MultiAgentBatch.
elif isinstance(batch, SampleBatch):
assert len(self.module) == 1
batch = MultiAgentBatch(
{next(iter(self.module.keys())): batch}, env_steps=len(batch)
)
# Check whether our MultiRLModule contains all ModuleIDs found in this batch.
# If not, throw an error.
unknown_module_ids = set(batch.policy_batches.keys()) - set(self.module.keys())
if len(unknown_module_ids) > 0:
raise ValueError(
"Batch contains one or more ModuleIDs that are not in this Learner! "
f"Found IDs: {unknown_module_ids}"
)
# TODO: Move this into LearnerConnector pipeline?
# Filter out those RLModules from the final train batch that should not be
# updated.
for module_id in list(batch.policy_batches.keys()):
if not self.should_module_be_updated(module_id, batch):
del batch.policy_batches[module_id]
# Log all timesteps (env, agent, modules) based on given episodes/batch.
self._log_steps_trained_metrics(batch)
if minibatch_size:
if self._learner_connector is not None:
batch_iter = partial(
MiniBatchCyclicIterator, _uses_new_env_runners=True
)
else:
batch_iter = MiniBatchCyclicIterator
elif num_epochs > 1:
# `minibatch_size` was not set but `num_epochs` > 1.
# Under the old training stack, users could do multiple epochs
# over a batch without specifying a minibatch size. We enable
# this behavior here by setting the minibatch size to be the size
# of the batch (e.g. 1 minibatch of size batch.count)
minibatch_size = batch.count
# Note that there is no need to shuffle here, b/c we don't have minibatches.
batch_iter = MiniBatchCyclicIterator
else:
# `minibatch_size` and `num_epochs` are not set by the user.
batch_iter = MiniBatchDummyIterator
batch = self._set_slicing_by_batch_id(batch, value=True)
for tensor_minibatch in batch_iter(
batch,
num_epochs=num_epochs,
minibatch_size=minibatch_size,
shuffle_batch_per_epoch=shuffle_batch_per_epoch and (num_epochs > 1),
num_total_minibatches=num_total_minibatches,
):
# Make the actual in-graph/traced `_update` call. This should return
# all tensor values (no numpy).
fwd_out, loss_per_module, tensor_metrics = self._update(
tensor_minibatch.policy_batches
)
# Convert logged tensor metrics (logged during tensor-mode of MetricsLogger)
# to actual (numpy) values.
self.metrics.tensors_to_numpy(tensor_metrics)
# Log all individual RLModules' loss terms and their registered optimizers'
# current learning rates.
for mid, loss in convert_to_numpy(loss_per_module).items():
self.metrics.log_value(
key=(mid, self.TOTAL_LOSS_KEY),
value=loss,
window=1,
)
self._weights_seq_no += 1
self.metrics.log_dict(
{
(mid, WEIGHTS_SEQ_NO): self._weights_seq_no
for mid in batch.policy_batches.keys()
},
window=1,
)
self._set_slicing_by_batch_id(batch, value=False)
# Call `after_gradient_based_update` to allow for non-gradient based
# cleanup, logging, and update logic to happen.
self.after_gradient_based_update(timesteps=timesteps or {})
@OverrideToImplementCustomLogic_CallToSuperRecommended
def before_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
"""Called before gradient-based updates are completed.
Should be overridden to implement custom preparation-, logging-, or
non-gradient-based Learner/RLModule update logic before(!) gradient-based
updates are performed.
Args:
timesteps: Timesteps dict, which must have the key
`NUM_ENV_STEPS_SAMPLED_LIFETIME`.
# TODO (sven): Make this a more formal structure with its own type.
"""
@OverrideToImplementCustomLogic_CallToSuperRecommended
def after_gradient_based_update(self, *, timesteps: Dict[str, Any]) -> None:
"""Called after gradient-based updates are completed.
Should be overridden to implement custom cleanup, logging, or
non-gradient-based Learner/RLModule update logic after(!) gradient-based
updates have been completed.
Args:
timesteps: Timesteps dict, which must have the key
`NUM_ENV_STEPS_SAMPLED_LIFETIME`.
# TODO (sven): Make this a more formal structure with its own type.
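For example, a minimal sketch of an override that syncs (hypothetical) target
networks after each update step; `sync_target_networks()` is an assumed method
on the custom RLModule, not an RLlib API:
.. code-block:: python

    def after_gradient_based_update(self, *, timesteps):
        super().after_gradient_based_update(timesteps=timesteps)
        for module_id in self.module.keys():
            module = self.module[module_id].unwrapped()
            if hasattr(module, "sync_target_networks"):
                module.sync_target_networks()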
"""
# Only update this optimizer's lr, if a scheduler has been registered
# along with it.
for module_id, optimizer_names in self._module_optimizers.items():
for optimizer_name in optimizer_names:
optimizer = self._named_optimizers[optimizer_name]
# Update and log learning rate of this optimizer.
lr_schedule = self._optimizer_lr_schedules.get(optimizer)
if lr_schedule is not None:
new_lr = lr_schedule.update(
timestep=timesteps.get(NUM_ENV_STEPS_SAMPLED_LIFETIME, 0)
)
self._set_optimizer_lr(optimizer, lr=new_lr)
self.metrics.log_value(
# Cut out the module ID from the beginning since it's already part
# of the key sequence: (ModuleID, "[optim name]_lr").
key=(module_id, f"{optimizer_name[len(module_id) + 1:]}_{LR_KEY}"),
value=convert_to_numpy(self._get_optimizer_lr(optimizer)),
window=1,
)
def _set_slicing_by_batch_id(
self, batch: MultiAgentBatch, *, value: bool
) -> MultiAgentBatch:
"""Enables slicing by batch id in the given batch.
If the input batch contains batches of sequences, we need to make sure that
when slicing happens, it is sliced via batch id and not time step. Calling this
method enables the same flag on each SampleBatch within the input
MultiAgentBatch.
Args:
batch: The MultiAgentBatch to enable slicing by batch id on.
value: The value to set the flag to.
Returns:
The input MultiAgentBatch with the indexing flag enabled/disabled on each of
its SampleBatches.
"""
for pid, policy_batch in batch.policy_batches.items():
# We assume that arriving batches for recurrent modules OR batches that
# have a SEQ_LENS column are already zero-padded to the max sequence length
# and have tensors of shape [B, T, ...]. Therefore, we slice sequence
# lengths in B. See SampleBatch for more information.
if (
self.module[pid].is_stateful()
or policy_batch.get("seq_lens") is not None
):
if value:
policy_batch.enable_slicing_by_batch_id()
else:
policy_batch.disable_slicing_by_batch_id()
return batch
@abc.abstractmethod
def _is_module_compatible_with_learner(self, module: RLModule) -> bool:
"""Check whether the module is compatible with the learner.
For example, if there is a random RLModule, it will not be a torch or tf
module, but rather a numpy module. Therefore, we should not consider it
during gradient-based optimization.
Args:
module: The module to check.
Returns:
True if the module is compatible with the learner.
"""
def _make_module(self) -> MultiRLModule:
"""Construct the multi-agent RL module for the learner.
This method uses `self._module_spec` or `self._module_obj` to construct the
module. If the resulting module is a single-agent RLModule, it will be wrapped
into a multi-agent RLModule. Override this method if there are other things
that need to happen for instantiation of the module.
Returns:
A constructed MultiRLModule.
"""
# Module was provided directly through constructor -> Use as-is.
if self._module_obj is not None:
module = self._module_obj
self._module_spec = MultiRLModuleSpec.from_module(module)
# RLModuleSpec was provided directly through constructor -> Use it to build the
# RLModule.
elif self._module_spec is not None:
module = self._module_spec.build()
# Try using our config object. Note that this would only work if the config
# object has all the necessary space information already in it.
else:
module = self.config.get_multi_agent_module_spec().build()
# If not already, convert to MultiRLModule.
module = module.as_multi_rl_module()
return module
def _check_registered_optimizer(
self,
optimizer: Optimizer,
params: Sequence[Param],
) -> None:
"""Checks that the given optimizer and parameters are valid for the framework.
Args:
optimizer: The optimizer object to check.
params: The list of parameters to check.
"""
if not isinstance(params, list):
raise ValueError(
f"`params` ({params}) must be a list of framework-specific parameters "
"(variables)!"
)
def _check_is_built(self, error: bool = True) -> bool:
if self.module is None:
if error:
raise ValueError(
"Learner.build() must be called after constructing a "
"Learner and before calling any methods on it."
)
return False
return True
def _reset(self):
self._params = {}
self._optimizer_parameters = {}
self._named_optimizers = {}
self._module_optimizers = defaultdict(list)
self._optimizer_lr_schedules = {}
self.metrics = MetricsLogger()
self._is_built = False
def apply(self, func, *_args, **_kwargs):
return func(self, *_args, **_kwargs)
@abc.abstractmethod
def _get_tensor_variable(
self,
value: Any,
dtype: Any = None,
trainable: bool = False,
) -> TensorType:
"""Returns a framework-specific tensor variable with the initial given value.
This is a framework-specific method that should be implemented by the
framework-specific sub-classes.
Args:
value: The initial value for the tensor variable.
Returns:
The framework specific tensor variable of the given initial value,
dtype and trainable/requires_grad property.
"""
@staticmethod
@abc.abstractmethod
def _get_optimizer_lr(optimizer: Optimizer) -> float:
"""Returns the current learning rate of the given local optimizer.
Args:
optimizer: The local optimizer to get the current learning rate for.
Returns:
The learning rate value (float) of the given optimizer.
"""
@staticmethod
@abc.abstractmethod
def _set_optimizer_lr(optimizer: Optimizer, lr: float) -> None:
"""Updates the learning rate of the given local optimizer.
Args:
optimizer: The local optimizer to update the learning rate for.
lr: The new learning rate.
"""
@staticmethod
@abc.abstractmethod
def _get_clip_function() -> Callable:
"""Returns the gradient clipping function to use, given the framework."""
@staticmethod
@abc.abstractmethod
def _get_global_norm_function() -> Callable:
"""Returns the global norm function to use, given the framework."""
def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
"""Logs this iteration's steps trained, based on given `batch`."""
for mid, module_batch in batch.policy_batches.items():
module_batch_size = len(module_batch)
# Log average batch size (for each module).
self.metrics.log_value(
key=(mid, MODULE_TRAIN_BATCH_SIZE_MEAN),
value=module_batch_size,
)
# Log module steps (for each module).
self.metrics.log_value(
key=(mid, NUM_MODULE_STEPS_TRAINED),
value=module_batch_size,
reduce="sum",
clear_on_reduce=True,
)
self.metrics.log_value(
key=(mid, NUM_MODULE_STEPS_TRAINED_LIFETIME),
value=module_batch_size,
reduce="sum",
)
# Log module steps (sum of all modules).
self.metrics.log_value(
key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED),
value=module_batch_size,
reduce="sum",
clear_on_reduce=True,
)
self.metrics.log_value(
key=(ALL_MODULES, NUM_MODULE_STEPS_TRAINED_LIFETIME),
value=module_batch_size,
reduce="sum",
)
# Log env steps (all modules).
self.metrics.log_value(
(ALL_MODULES, NUM_ENV_STEPS_TRAINED),
batch.env_steps(),
reduce="sum",
clear_on_reduce=True,
)
self.metrics.log_value(
(ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME),
batch.env_steps(),
reduce="sum",
with_throughput=True,
)
@Deprecated(
new="Learner.before_gradient_based_update("
"timesteps={'num_env_steps_sampled_lifetime': ...}) and/or "
"Learner.after_gradient_based_update("
"timesteps={'num_env_steps_sampled_lifetime': ...})",
error=True,
)
def additional_update_for_module(self, *args, **kwargs):
pass
@Deprecated(new="Learner.save_to_path(...)", error=True)
def save_state(self, *args, **kwargs):
pass
@Deprecated(new="Learner.restore_from_path(...)", error=True)
def load_state(self, *args, **kwargs):
pass
@Deprecated(new="Learner.module.get_state()", error=True)
def get_module_state(self, *args, **kwargs):
pass
@Deprecated(new="Learner.module.set_state()", error=True)
def set_module_state(self, *args, **kwargs):
pass
@Deprecated(new="Learner._get_optimizer_state()", error=True)
def get_optimizer_state(self, *args, **kwargs):
pass
@Deprecated(new="Learner._set_optimizer_state()", error=True)
def set_optimizer_state(self, *args, **kwargs):
pass
@Deprecated(new="Learner.compute_losses(...)", error=False)
def compute_loss(self, *args, **kwargs):
losses_per_module = self.compute_losses(*args, **kwargs)
# To continue supporting the old `compute_loss` behavior (instead of
# the new `compute_losses`), add the ALL_MODULES key here holding the sum
# of all individual loss terms.
if ALL_MODULES not in losses_per_module:
losses_per_module[ALL_MODULES] = sum(losses_per_module.values())
return losses_per_module