import pathlib
from collections import defaultdict, Counter
import copy
from functools import partial
import itertools
from typing import (
Any,
Callable,
Collection,
Dict,
List,
Optional,
Set,
Type,
TYPE_CHECKING,
Union,
)
import tree # pip install dm_tree
import ray
from ray import ObjectRef
from ray.rllib.core import COMPONENT_LEARNER, COMPONENT_RL_MODULE
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.actor_manager import (
FaultTolerantActorManager,
RemoteCallResults,
ResultOrError,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
deprecation_warning,
)
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.minibatch_utils import (
ShardBatchIterator,
ShardEpisodesIterator,
ShardObjectRefIterator,
)
from ray.rllib.utils.typing import (
EpisodeType,
ModuleID,
RLModuleSpecType,
ShouldModuleBeUpdatedFn,
StateDict,
T,
)
from ray.train._internal.backend_executor import BackendExecutor
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
def _get_backend_config(learner_class: Type[Learner]) -> str:
if learner_class.framework == "torch":
from ray.train.torch import TorchConfig
backend_config = TorchConfig()
elif learner_class.framework == "tf2":
from ray.train.tensorflow import TensorflowConfig
backend_config = TensorflowConfig()
else:
raise ValueError(
"`learner_class.framework` must be either 'torch' or 'tf2' (but is "
f"{learner_class.framework}!"
)
return backend_config
[docs]@PublicAPI(stability="alpha")
class LearnerGroup(Checkpointable):
    """Coordinator of n (possibly remote) Learner workers.

    Each Learner worker has a copy of the RLModule, the loss function(s), and
    one or more optimizers.

    Depending on `config.num_learners`, the group either holds a single Learner in
    the local process (`num_learners == 0`) or manages `num_learners` remote Learner
    actors, amongst which the data passed to the update methods is sharded.
    """
def __init__(
    self,
    *,
    config: "AlgorithmConfig",
    # TODO (sven): Rename into `rl_module_spec`.
    module_spec: Optional[RLModuleSpecType] = None,
):
    """Initializes a LearnerGroup instance.

    Args:
        config: The AlgorithmConfig object to use to configure this LearnerGroup.
            Call the `learners(num_learners=...)` method on your config to
            specify the number of learner workers to use.
            Call the same method with arguments `num_cpus_per_learner` and/or
            `num_gpus_per_learner` to configure the compute used by each
            Learner worker in this LearnerGroup.
            Call the `training(learner_class=...)` method on your config to
            specify, which exact Learner class to use.
            Call the `rl_module(rl_module_spec=...)` method on your config to set
            up the specifics for your RLModule to be used in each Learner.
        module_spec: If not already specified in `config`, a separate overriding
            RLModuleSpec may be provided via this argument.
    """
    # Keep a private (unfrozen) copy of the config and remember the spec exactly
    # as it was passed in (needed to re-create this object from its ctor args).
    self.config = config.copy(copy_frozen=False)
    self._module_spec = module_spec

    learner_class = self.config.learner_class
    # Fall back to the spec derived from the config, if none was given.
    module_spec = module_spec or self.config.get_multi_rl_module_spec()

    self._learner = None
    self._workers = None
    # Set to True, once the user (or `__del__`) has called `self.shutdown()`.
    # Prevents the backend executor from being shut down a second time, which
    # would otherwise make ray train log a warning to the console.
    self._is_shut_down = False
    # Number of timesteps that had to be dropped due to a full input queue.
    self._ts_dropped = 0

    if self.is_local:
        # A single Learner living in this process.
        self._learner = learner_class(config=config, module_spec=module_spec)
        self._learner.build()
        self._worker_manager = None
    else:
        # N remote Learner workers.
        backend_config = _get_backend_config(learner_class)

        # TODO (sven): Can't set both `num_cpus_per_learner`>1 and
        #  `num_gpus_per_learner`>0! Users must set one or the other due
        #  to issues with placement group fragmentation. See
        #  https://github.com/ray-project/ray/issues/35409 for more details.
        num_gpus_per_learner = self.config.num_gpus_per_learner
        num_cpus_per_learner = (
            0 if self.config.num_gpus_per_learner
            else self.config.num_cpus_per_learner
        )

        backend_executor = BackendExecutor(
            backend_config=backend_config,
            num_workers=self.config.num_learners,
            resources_per_worker={
                "CPU": num_cpus_per_learner,
                "GPU": num_gpus_per_learner,
            },
            max_retries=0,
        )
        backend_executor.start(
            train_cls=learner_class,
            train_cls_kwargs={"config": config, "module_spec": module_spec},
        )
        self._backend_executor = backend_executor

        self._workers = [
            worker.actor for worker in backend_executor.worker_group.workers
        ]
        # Run the neural network building code on all remote workers and wait
        # until every one of them is done.
        build_refs = [worker.build.remote() for worker in self._workers]
        ray.get(build_refs)

        self._worker_manager = FaultTolerantActorManager(
            self._workers,
            # TODO (sven): This probably works even without any restriction
            #  (allowing for any arbitrary number of requests in-flight). Test
            #  with 3 first, then with unlimited, and if both show the same
            #  behavior on an async algo, remove this restriction entirely.
            max_remote_requests_in_flight_per_actor=3,
        )
        # Bookkeeping for in-flight asynchronous update requests. Used for
        # keeping track of and grouping together the results of requests that
        # were sent to the workers at the same time.
        self._update_request_tags = Counter()
        self._update_request_tag = 0
        self._update_request_results = {}

    # A special MetricsLogger object (not exposed to the user) for reducing
    # the n results dicts returned by our n Learner workers in case we are on
    # the old or hybrid API stack. Stays None on the new API stack.
    self._metrics_logger_old_and_hybrid_stack: Optional[MetricsLogger] = (
        None if self.config.enable_env_runner_and_connector_v2 else MetricsLogger()
    )
# TODO (sven): Replace this with call to `self.metrics.peek()`?
# Currently LearnerGroup does not have a metrics object.
[docs] def get_stats(self) -> Dict[str, Any]:
"""Returns the current stats for the input queue for this learner group."""
return {
"learner_group_ts_dropped": self._ts_dropped,
"actor_manager_num_outstanding_async_reqs": (
0
if self.is_local
else self._worker_manager.num_outstanding_async_reqs()
),
}
@property
def is_remote(self) -> bool:
return self.config.num_learners > 0
@property
def is_local(self) -> bool:
return not self.is_remote
def update_from_batch(
    self,
    batch: MultiAgentBatch,
    *,
    timesteps: Optional[Dict[str, Any]] = None,
    async_update: bool = False,
    return_state: bool = False,
    # TODO (sven): Deprecate the following args. They should be extracted from the
    #  self.config of those specific algorithms that actually require these
    #  settings.
    minibatch_size: Optional[int] = None,
    num_iters: int = 1,
    # Already deprecated args.
    reduce_fn=DEPRECATED_VALUE,
    # User kwargs.
    **kwargs,
) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
    """Performs gradient based update(s) on the Learner(s), based on given batch.

    Args:
        batch: A data batch to use for the update. If there are more
            than one Learner workers, the batch is split amongst these and one
            shard is sent to each Learner.
        timesteps: Optional dict that is forwarded as-is to the update call(s)
            on the Learner(s).
        async_update: Whether the update request(s) to the Learner workers should
            be sent asynchronously. If True, will return NOT the results from the
            update on the given data, but all results from prior asynchronous
            update requests that have not been returned thus far.
        return_state: Whether to include one of the Learner worker's state from
            after the update step in the returned results dict (under the
            `_rl_module_state_after_update` key). Note that after an update, all
            Learner workers' states should be identical, so we use the first
            Learner's state here. Useful for avoiding an extra `get_weights()`
            call, e.g. for synchronizing EnvRunner weights.
        minibatch_size: The minibatch size to use for the update.
        num_iters: The number of complete passes over all the sub-batches in the
            input multi-agent batch.
        reduce_fn: Deprecated. Must be left at its default value.
        **kwargs: Additional user kwargs, forwarded to the underlying update.

    Returns:
        If `async_update` is False, a dictionary with the reduced results of the
        updates from the Learner(s) or a list of dictionaries of results from the
        updates from the Learner(s).
        If `async_update` is True, a list of list of dictionaries of results,
        where the outer list corresponds to separate previous calls to this
        method, and the inner list corresponds to the results from each
        Learner(s). Or if the results are reduced, a list of dictionaries of the
        reduced results from each call to async_update that is ready.
    """
    # `reduce_fn` has been retired: Anything but the sentinel default is reported
    # through `deprecation_warning(..., error=True)`.
    reduce_fn_was_passed = reduce_fn != DEPRECATED_VALUE
    if reduce_fn_was_passed:
        new_api = (
            "Learner.metrics.[log_value|log_dict|log_time](key=..., value=..., "
            "reduce=[mean|min|max|sum], window=..., ema_coeff=...)"
        )
        help_msg = (
            "Use the new ray.rllib.utils.metrics.metrics_logger::MetricsLogger"
            " API in your custom Learner methods for logging and time-reducing any "
            "custom metrics. The central `MetricsLogger` instance is available "
            "under `self.metrics` within your custom Learner."
        )
        deprecation_warning(
            old="LearnerGroup.update_from_batch(reduce_fn=..)",
            new=new_api,
            help=help_msg,
            error=True,
        )

    # All the actual work (sharding, sync/async dispatch) happens in `_update`.
    return self._update(
        batch=batch,
        timesteps=timesteps,
        minibatch_size=minibatch_size,
        num_iters=num_iters,
        async_update=async_update,
        return_state=return_state,
        **kwargs,
    )
def update_from_episodes(
    self,
    episodes: List[EpisodeType],
    *,
    timesteps: Optional[Dict[str, Any]] = None,
    async_update: bool = False,
    return_state: bool = False,
    # TODO (sven): Deprecate the following args. They should be extracted from the
    #  self.config of those specific algorithms that actually require these
    #  settings.
    minibatch_size: Optional[int] = None,
    num_iters: int = 1,
    # Already deprecated args.
    reduce_fn=DEPRECATED_VALUE,
    # User kwargs.
    **kwargs,
) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
    """Performs gradient based update(s) on the Learner(s), based on given episodes.

    Args:
        episodes: A list of Episodes to process and perform the update
            for. If there are more than one Learner workers, the list of episodes
            is split amongst these and one list shard is sent to each Learner.
        timesteps: Optional dict that is forwarded to the underlying update.
        async_update: Whether the update request(s) to the Learner workers should
            be sent asynchronously. If True, will return NOT the results from the
            update on the given data, but all results from prior asynchronous
            update requests that have not been returned thus far.
        return_state: Whether to include one of the Learner worker's state from
            after the update step in the returned results dict (under the
            `_rl_module_state_after_update` key). Note that after an update, all
            Learner workers' states should be identical, so we use the first
            Learner's state here. Useful for avoiding an extra `get_weights()`
            call, e.g. for synchronizing EnvRunner weights.
        minibatch_size: The minibatch size to use for the update.
        num_iters: The number of complete passes over the data generated from
            the given episodes.
        reduce_fn: Deprecated. Must be left at its default value.
        **kwargs: Additional user kwargs, forwarded to the underlying update.

    Returns:
        If async_update is False, a dictionary with the reduced results of the
        updates from the Learner(s) or a list of dictionaries of results from the
        updates from the Learner(s).
        If async_update is True, a list of list of dictionaries of results, where
        the outer list corresponds to separate previous calls to this method, and
        the inner list corresponds to the results from each Learner(s). Or if the
        results are reduced, a list of dictionaries of the reduced results from
        each call to async_update that is ready.
    """
    # `reduce_fn` has been retired: Anything but the sentinel default is reported
    # through `deprecation_warning(..., error=True)`.
    reduce_fn_was_passed = reduce_fn != DEPRECATED_VALUE
    if reduce_fn_was_passed:
        new_api = (
            "Learner.metrics.[log_value|log_dict|log_time](key=..., value=..., "
            "reduce=[mean|min|max|sum], window=..., ema_coeff=...)"
        )
        help_msg = (
            "Use the new ray.rllib.utils.metrics.metrics_logger::MetricsLogger"
            " API in your custom Learner methods for logging and time-reducing any "
            "custom metrics. The central `MetricsLogger` instance is available "
            "under `self.metrics` within your custom Learner."
        )
        deprecation_warning(
            old="LearnerGroup.update_from_episodes(reduce_fn=..)",
            new=new_api,
            help=help_msg,
            error=True,
        )

    # All the actual work (sharding, sync/async dispatch) happens in `_update`.
    return self._update(
        episodes=episodes,
        timesteps=timesteps,
        minibatch_size=minibatch_size,
        num_iters=num_iters,
        async_update=async_update,
        return_state=return_state,
        **kwargs,
    )
def _update(
    self,
    *,
    batch: Optional[MultiAgentBatch] = None,
    episodes: Optional[List[EpisodeType]] = None,
    timesteps: Optional[Dict[str, Any]] = None,
    async_update: bool = False,
    return_state: bool = False,
    minibatch_size: Optional[int] = None,
    num_iters: int = 1,
    **kwargs,
) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
    """Runs an update on the local Learner or on all remote Learner workers.

    Shared implementation behind `update_from_batch()` and
    `update_from_episodes()`. In local mode, the data is passed as-is to the single
    Learner. In remote mode, the data is sharded, one shard per Learner worker,
    and the update calls are sent either synchronously or asynchronously.

    Args:
        batch: The batch to update from (or a list of `ray.data.DataIterator`s,
            one per Learner). If None, `episodes` is used instead.
        episodes: The episodes to update from (a list of episodes, of Ray
            ObjectRefs, or of `DataIterator`s). Only used if `batch` is None.
        timesteps: Optional dict forwarded as-is to each Learner's update call.
        async_update: Whether to send the requests asynchronously and return only
            the results of previous async requests that are ready by now. Not
            supported in local mode.
        return_state: Whether to add the RLModule state after the update to the
            results under the `_rl_module_state_after_update` key (in remote
            mode, only the first Learner's result carries the state).
        minibatch_size: The minibatch size forwarded to the Learner(s).
        num_iters: The number of iterations forwarded to the Learner(s).
        **kwargs: User kwargs forwarded to each Learner's update call.

    Returns:
        The list of results (one per Learner) for synchronous updates, a list of
        such lists for asynchronous updates, or - on the old/hybrid API stack -
        a single already reduced results dict.

    Raises:
        ValueError: If `async_update=True` in local mode.
    """

    # Define function to be called on all Learner actors (or the local learner).
    def _learner_update(
        _learner: Learner,
        *,
        _batch_shard=None,
        _episodes_shard=None,
        _timesteps=None,
        _return_state=False,
        _min_total_mini_batches=0,
        **_kwargs,
    ):
        # If the batch shard is an `DataIterator` we have an offline
        # multi-learner setup and `update_from_iterator` needs to
        # handle updating.
        if isinstance(_batch_shard, ray.data.DataIterator):
            result = _learner.update_from_iterator(
                iterator=_batch_shard,
                timesteps=_timesteps,
                minibatch_size=minibatch_size,
                num_iters=num_iters,
                **_kwargs,
            )
        elif _batch_shard is not None:
            result = _learner.update_from_batch(
                batch=_batch_shard,
                timesteps=_timesteps,
                minibatch_size=minibatch_size,
                num_iters=num_iters,
                **_kwargs,
            )
        else:
            result = _learner.update_from_episodes(
                episodes=_episodes_shard,
                timesteps=_timesteps,
                minibatch_size=minibatch_size,
                num_iters=num_iters,
                min_total_mini_batches=_min_total_mini_batches,
                **_kwargs,
            )
        if _return_state:
            result["_rl_module_state_after_update"] = _learner.get_state(
                components=COMPONENT_RL_MODULE, inference_only=True
            )[COMPONENT_RL_MODULE]

        return result

    # Local Learner worker: Don't shard batch/episodes, just run data as-is through
    # this Learner.
    if self.is_local:
        if async_update:
            raise ValueError(
                "Cannot call `update_from_batch(async_update=True)` when running in"
                " local mode! Try setting `config.num_learners > 0`."
            )

        results = [
            _learner_update(
                _learner=self._learner,
                _batch_shard=batch,
                _episodes_shard=episodes,
                _timesteps=timesteps,
                _return_state=return_state,
                **kwargs,
            )
        ]
    # One or more remote Learners: Shard batch/episodes into equal pieces (roughly
    # equal if multi-agent AND episodes) and send each Learner worker one of these
    # shards.
    else:
        # MultiAgentBatch: Shard into equal pieces.
        # TODO (sven): The sharder used here destroys - for multi-agent only -
        #  the relationship of the different agents' timesteps to each other.
        #  Thus, in case the algorithm requires agent-synchronized data (aka.
        #  "lockstep"), the `ShardBatchIterator` should not be used.
        #  Then again, we might move into a world where Learner always
        #  receives Episodes, never batches.
        if isinstance(batch, list) and isinstance(batch[0], ray.data.DataIterator):
            partials = [
                partial(
                    _learner_update,
                    _batch_shard=iterator,
                    _return_state=(return_state and i == 0),
                    _timesteps=timesteps,
                    **kwargs,
                )
                # Note, `OfflineData` defines exactly as many iterators as there
                # are learners.
                for i, iterator in enumerate(batch)
            ]
        elif batch is not None:
            partials = [
                partial(
                    _learner_update,
                    _batch_shard=batch_shard,
                    _return_state=(return_state and i == 0),
                    _timesteps=timesteps,
                    **kwargs,
                )
                for i, batch_shard in enumerate(
                    ShardBatchIterator(batch, len(self._workers))
                )
            ]
        elif isinstance(episodes, list) and isinstance(episodes[0], ObjectRef):
            partials = [
                partial(
                    _learner_update,
                    _episodes_shard=episodes_shard,
                    _timesteps=timesteps,
                    _return_state=(return_state and i == 0),
                    **kwargs,
                )
                for i, episodes_shard in enumerate(
                    ShardObjectRefIterator(episodes, len(self._workers))
                )
            ]
        # Single- or MultiAgentEpisodes: Shard into equal pieces (only roughly equal
        # in case of multi-agent).
        else:
            from ray.data.iterator import DataIterator

            if isinstance(episodes[0], DataIterator):
                min_total_mini_batches = 0
                # Forward `timesteps`, `return_state` (first Learner only) and
                # the user kwargs, exactly like the other sharding branches do.
                partials = [
                    partial(
                        _learner_update,
                        _episodes_shard=episodes_shard,
                        _timesteps=timesteps,
                        _return_state=(return_state and i == 0),
                        _min_total_mini_batches=min_total_mini_batches,
                        **kwargs,
                    )
                    for i, episodes_shard in enumerate(episodes)
                ]
            else:
                eps_shards = list(
                    ShardEpisodesIterator(episodes, len(self._workers))
                )
                # In the multi-agent case AND `minibatch_size` AND num_workers
                # > 1, we compute a max iteration counter such that the different
                # Learners will not go through a different number of iterations.
                min_total_mini_batches = 0
                if (
                    isinstance(episodes[0], MultiAgentEpisode)
                    and minibatch_size
                    and len(self._workers) > 1
                ):
                    # Find episode w/ the largest single-agent episode in it, then
                    # compute this single-agent episode's total number of mini
                    # batches (if we iterated over it num_sgd_iter times with the
                    # mini batch size).
                    longest_ts = 0
                    per_mod_ts = defaultdict(int)
                    for i, shard in enumerate(eps_shards):
                        for ma_episode in shard:
                            for sa_episode in ma_episode.agent_episodes.values():
                                key = (i, sa_episode.module_id)
                                per_mod_ts[key] += len(sa_episode)
                                if per_mod_ts[key] > longest_ts:
                                    longest_ts = per_mod_ts[key]
                    min_total_mini_batches = self._compute_num_total_mini_batches(
                        batch_size=longest_ts,
                        mini_batch_size=minibatch_size,
                        num_iters=num_iters,
                    )
                # Forward `timesteps`, `return_state` (first Learner only) and
                # the user kwargs, exactly like the other sharding branches do.
                partials = [
                    partial(
                        _learner_update,
                        _episodes_shard=eps_shard,
                        _timesteps=timesteps,
                        _return_state=(return_state and i == 0),
                        _min_total_mini_batches=min_total_mini_batches,
                        **kwargs,
                    )
                    for i, eps_shard in enumerate(eps_shards)
                ]

        if async_update:
            # Retrieve all ready results (kicked off by prior calls to this method).
            tags_to_get = []
            for tag in self._update_request_tags.keys():
                result = self._worker_manager.fetch_ready_async_reqs(
                    tags=[str(tag)], timeout_seconds=0.0
                )
                if tag not in self._update_request_results:
                    self._update_request_results[tag] = result
                else:
                    for r in result:
                        self._update_request_results[tag].add_result(
                            r.actor_id, r.result_or_error, tag
                        )

                # Still not done with this `tag` -> skip out early.
                if (
                    self._update_request_tags[tag]
                    > len(self._update_request_results[tag].result_or_errors)
                    > 0
                ):
                    break
                tags_to_get.append(tag)

            # Send out new request(s), if there is still capacity on the actors.
            update_tag = self._update_request_tag
            self._update_request_tag += 1
            num_sent_requests = self._worker_manager.foreach_actor_async(
                partials, tag=str(update_tag)
            )
            if num_sent_requests:
                self._update_request_tags[update_tag] = num_sent_requests

            # Some requests were dropped, record lost ts/data.
            if num_sent_requests != len(self._workers):
                # assert num_sent_requests == 0, num_sent_requests
                factor = 1 - (num_sent_requests / len(self._workers))
                # Batch: Measure its length.
                if episodes is None:
                    dropped = len(batch)
                # List of Ray ObjectRefs (each object ref is a list of episodes of
                # total len=`rollout_fragment_length * num_envs_per_env_runner`)
                elif isinstance(episodes[0], ObjectRef):
                    dropped = (
                        len(episodes)
                        * self.config.get_rollout_fragment_length()
                        * self.config.num_envs_per_env_runner
                    )
                else:
                    dropped = sum(len(e) for e in episodes)

                self._ts_dropped += factor * dropped

            # NOTE: There is a strong assumption here that the requests launched to
            # learner workers will return at the same time, since they have a
            # barrier inside for gradient aggregation. Therefore, results should be
            # a list of lists where each inner list should be the length of the
            # number of learner workers, if results from an non-blocking update are
            # ready.
            results = self._get_async_results(tags_to_get)

        else:
            results = self._get_results(
                self._worker_manager.foreach_actor(partials)
            )

    # If we are on the hybrid API stacks (no EnvRunners), we need to emulate
    # the old behavior of returning an already reduced dict (as if we had a
    # reduce_fn).
    if not self.config.enable_env_runner_and_connector_v2:
        # If we are doing an async update, we operate on a list (different async
        # requests that now have results ready) of lists (n Learner workers) here.
        if async_update:
            results = tree.flatten_up_to(
                [[None] * len(r) for r in results], results
            )
        self._metrics_logger_old_and_hybrid_stack.merge_and_log_n_dicts(results)
        results = self._metrics_logger_old_and_hybrid_stack.reduce(
            # We are returning to a client (Algorithm) that does NOT make any
            # use of MetricsLogger (or Stats) -> Convert all values to non-Stats
            # primitives.
            return_stats_obj=False
        )

    return results
# TODO (sven): Move this into FaultTolerantActorManager?
def _get_results(self, results):
processed_results = []
for result in results:
result_or_error = result.get()
if result.ok:
processed_results.append(result_or_error)
else:
raise result_or_error
return processed_results
def _get_async_results(self, tags_to_get): # results):
"""Get results from the worker manager and group them by tag.
Returns:
A list of lists of results, where each inner list contains all results
for same tags.
"""
# if results is None:
# return []
unprocessed_results = defaultdict(list)
for tag in tags_to_get:
results = self._update_request_results[tag]
for result in results:
result_or_error = result.get()
if result.ok:
if result.tag is None:
raise RuntimeError(
"Cannot call `LearnerGroup._get_async_results()` on "
"untagged async requests!"
)
tag = int(result.tag)
unprocessed_results[tag].append(result_or_error)
if tag in self._update_request_tags:
self._update_request_tags[tag] -= 1
if self._update_request_tags[tag] == 0:
del self._update_request_tags[tag]
del self._update_request_results[tag]
else:
assert False
else:
raise result_or_error
return list(unprocessed_results.values())
def add_module(
    self,
    *,
    module_id: ModuleID,
    module_spec: RLModuleSpec,
    config_overrides: Optional[Dict] = None,
    new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
) -> MultiRLModuleSpec:
    """Adds a module to the underlying MultiRLModule.

    Changes this Learner's config in order to make this architectural change
    permanent wrt. to checkpointing.

    Args:
        module_id: The ModuleID of the module to be added.
        module_spec: The ModuleSpec of the module to be added.
        config_overrides: The `AlgorithmConfig` overrides that should apply to
            the new Module, if any.
        new_should_module_be_updated: An optional sequence of ModuleIDs or a
            callable taking ModuleID and SampleBatchType and returning whether the
            ModuleID should be updated (trained).
            If None, will keep the existing setup in place. RLModules,
            whose IDs are not in the list (or for which the callable
            returns False) will not be updated.

    Returns:
        The new MultiRLModuleSpec (after the change has been performed).
    """
    validate_module_id(module_id, error=True)

    # Work on a deep copy of the given spec (the caller's object stays untouched)
    # and force-set inference-only = False on it.
    trainable_spec = copy.deepcopy(module_spec)
    trainable_spec.inference_only = False

    def _add_to_learner(_learner):
        return _learner.add_module(
            module_id=module_id,
            module_spec=trainable_spec,
            config_overrides=config_overrides,
            new_should_module_be_updated=new_should_module_be_updated,
        )

    call_results = self.foreach_learner(func=_add_to_learner)
    # Use the first Learner's returned spec.
    marl_spec = self._get_results(call_results)[0]

    # Change our config (AlgorithmConfig) to contain the new Module.
    # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
    #  but we'll deprecate config.policies soon anyway.
    self.config.policies[module_id] = PolicySpec()
    if config_overrides is not None:
        overrides_per_module = {module_id: config_overrides}
        self.config.multi_agent(
            algorithm_config_overrides_per_module=overrides_per_module
        )
    self.config.rl_module(rl_module_spec=marl_spec)
    if new_should_module_be_updated is not None:
        self.config.multi_agent(policies_to_train=new_should_module_be_updated)

    return marl_spec
def remove_module(
    self,
    module_id: ModuleID,
    *,
    new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
) -> MultiRLModuleSpec:
    """Removes a module from the Learner.

    Args:
        module_id: The ModuleID of the module to be removed.
        new_should_module_be_updated: An optional sequence of ModuleIDs or a
            callable taking ModuleID and SampleBatchType and returning whether the
            ModuleID should be updated (trained).
            If None, will keep the existing setup in place. RLModules,
            whose IDs are not in the list (or for which the callable
            returns False) will not be updated.

    Returns:
        The new MultiRLModuleSpec (after the change has been performed).
    """
    # Remove all stats from the module from our metrics logger (hybrid API stack
    # only), so we don't report results from this module again.
    on_old_or_hybrid_stack = not self.config.enable_env_runner_and_connector_v2
    if on_old_or_hybrid_stack:
        logged_stats = self._metrics_logger_old_and_hybrid_stack.stats
        if module_id in logged_stats:
            del self._metrics_logger_old_and_hybrid_stack.stats[module_id]

    def _remove_from_learner(_learner):
        return _learner.remove_module(
            module_id=module_id,
            new_should_module_be_updated=new_should_module_be_updated,
        )

    call_results = self.foreach_learner(func=_remove_from_learner)
    # Use the first Learner's returned spec.
    marl_spec = self._get_results(call_results)[0]

    # Change self.config to reflect the new architecture.
    # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
    #  but we'll deprecate config.policies soon anyway.
    del self.config.policies[module_id]
    self.config.algorithm_config_overrides_per_module.pop(module_id, None)
    if new_should_module_be_updated is not None:
        self.config.multi_agent(policies_to_train=new_should_module_be_updated)
    self.config.rl_module(rl_module_spec=marl_spec)

    return marl_spec
@override(Checkpointable)
def get_state(
    self,
    components: Optional[Union[str, Collection[str]]] = None,
    *,
    not_components: Optional[Union[str, Collection[str]]] = None,
    **kwargs,
) -> StateDict:
    """Returns the state of this LearnerGroup.

    Args:
        components: Optional component name(s) to include in the returned state.
        not_components: Optional component name(s) to exclude.
        **kwargs: Forwarded to the Learner's `get_state()` call.

    Returns:
        A dict holding the Learner state under the `COMPONENT_LEARNER` key, or an
        empty dict, if the Learner component is not requested. In remote mode,
        the state is taken from the first healthy Learner actor.
    """
    state = {}
    # Nothing to collect, if the Learner component is filtered out.
    if not self._check_component(COMPONENT_LEARNER, components, not_components):
        return state

    if self.is_remote:
        # Query only one worker: the first one listed as healthy.
        worker = self._worker_manager.healthy_actor_ids()[0]
        assert len(self._workers) == self._worker_manager.num_healthy_actors()
        sub_comps = self._get_subcomponents(COMPONENT_LEARNER, components)
        sub_not_comps = self._get_subcomponents(COMPONENT_LEARNER, not_components)
        call_results = self._worker_manager.foreach_actor(
            lambda w: w.get_state(sub_comps, not_components=sub_not_comps, **kwargs),
            remote_actor_ids=[worker],
        )
        state[COMPONENT_LEARNER] = self._get_results(call_results)[0]
    else:
        state[COMPONENT_LEARNER] = self._learner.get_state(
            components=self._get_subcomponents(COMPONENT_LEARNER, components),
            not_components=self._get_subcomponents(
                COMPONENT_LEARNER, not_components
            ),
            **kwargs,
        )

    return state
@override(Checkpointable)
def set_state(self, state: StateDict) -> None:
    """Sets the Learner state on the local Learner or on all remote Learners.

    Args:
        state: The state dict. Only its `COMPONENT_LEARNER` entry is used; if
            this key is missing, the call is a no-op.
    """
    if COMPONENT_LEARNER not in state:
        return

    learner_state = state[COMPONENT_LEARNER]
    if self.is_local:
        self._learner.set_state(learner_state)
        return

    # Remote: Put the state into the object store once and let each Learner
    # resolve the reference on its own.
    state_ref = ray.put(learner_state)

    def _set_state_from_ref(_learner, _ref=state_ref):
        return _learner.set_state(ray.get(_ref))

    self.foreach_learner(_set_state_from_ref)
def get_weights(
    self, module_ids: Optional[Collection[ModuleID]] = None
) -> StateDict:
    """Convenience method instead of self.get_state(components=...).

    Args:
        module_ids: An optional collection of ModuleIDs for which to return
            weights. If None (default), return weights of all RLModules.

    Returns:
        The results of
        `self.get_state(components='learner/rl_module')['learner']['rl_module']`.
    """
    rl_module_path = COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE

    if module_ids is not None:
        # Return a subset of the single-agent RLModules: One component path per
        # requested ModuleID.
        components = [rl_module_path + "/" + module_id for module_id in module_ids]
    else:
        # Return the entire RLModule state (all possible single-agent RLModules).
        components = rl_module_path

    learner_state = self.get_state(components)[COMPONENT_LEARNER]
    return learner_state[COMPONENT_RL_MODULE]
def set_weights(self, weights) -> None:
    """Convenience method instead of self.set_state({'learner': {'rl_module': ..}}).

    Args:
        weights: The weights dict of the MultiRLModule of a Learner inside this
            LearnerGroup.
    """
    # Wrap the weights into the nested state structure `set_state` reads.
    learner_state = {COMPONENT_RL_MODULE: weights}
    self.set_state({COMPONENT_LEARNER: learner_state})
@override(Checkpointable)
def get_ctor_args_and_kwargs(self):
    """Returns the constructor args and kwargs of this LearnerGroup.

    Returns:
        A 2-tuple of an empty args tuple and a kwargs dict holding this group's
        `config` and the `module_spec` originally passed to the constructor.
    """
    ctor_args = ()
    ctor_kwargs = {
        "config": self.config,
        "module_spec": self._module_spec,
    }
    return ctor_args, ctor_kwargs
@override(Checkpointable)
def get_checkpointable_components(self):
    """Returns the single (name, component) pair to be checkpointed.

    Returns:
        A one-element list with a `(COMPONENT_LEARNER, obj)` tuple, where `obj`
        is the local Learner in local mode and the actor manager otherwise.
    """
    # Return the entire ActorManager, if remote. Otherwise, return the
    # local worker.
    if self.is_local:
        learner_component = self._learner
    else:
        learner_component = self._worker_manager
    return [(COMPONENT_LEARNER, learner_component)]
def foreach_learner(
    self,
    func: Callable[[Learner, Optional[Any]], T],
    *,
    healthy_only: bool = True,
    remote_actor_ids: Optional[List[int]] = None,
    timeout_seconds: Optional[float] = None,
    return_obj_refs: bool = False,
    mark_healthy: bool = True,
    **kwargs,
) -> RemoteCallResults:
    """Calls the given function on each Learner L with the args: `(L, **kwargs)`.

    In local mode, `func` is called directly on the single local Learner and all
    actor-related arguments (`healthy_only`, `remote_actor_ids`,
    `timeout_seconds`, `return_obj_refs`, `mark_healthy`) are ignored.

    Args:
        func: The function to call on each Learner L with args: `(L, **kwargs)`.
        healthy_only: If True, applies `func` only to Learner actors currently
            tagged "healthy", otherwise to all actors. If `healthy_only=False` and
            `mark_healthy=True`, will send `func` to all actors and mark those
            actors "healthy" that respond to the request within `timeout_seconds`
            and are currently tagged as "unhealthy".
        remote_actor_ids: Apply func on a selected set of remote actors. Use None
            (default) for all actors.
        timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0
            for fire-and-forget. Set this to None (default) to wait infinitely
            (i.e. for synchronous execution).
        return_obj_refs: whether to return ObjectRef instead of actual results.
            Note, for fault tolerance reasons, these returned ObjectRefs should
            never be resolved with ray.get() outside of the context of this
            manager.
        mark_healthy: Whether to mark all those actors healthy again that are
            currently marked unhealthy AND that returned results from the remote
            call (within the given `timeout_seconds`).
            Note that actors are NOT set unhealthy, if they simply time out
            (only if they return a RayActorError).
            Also note that this setting is ignored if `healthy_only=True` (b/c
            this setting only affects actors that are currently tagged as
            unhealthy).
        **kwargs: Keyword args passed on to each call of `func`.

    Returns:
        A list of size len(Learners) with the return values of all calls to `func`.
    """
    if self.is_local:
        # Wrap the local call's return value, such that callers get the same
        # result container type as in the remote case.
        results = RemoteCallResults()
        results.add_result(
            None,
            ResultOrError(result=func(self._learner, **kwargs)),
            None,
        )
        return results

    return self._worker_manager.foreach_actor(
        func=partial(func, **kwargs),
        healthy_only=healthy_only,
        remote_actor_ids=remote_actor_ids,
        timeout_seconds=timeout_seconds,
        return_obj_refs=return_obj_refs,
        mark_healthy=mark_healthy,
    )
[docs] def shutdown(self):
"""Shuts down the LearnerGroup."""
if self.is_remote and hasattr(self, "_backend_executor"):
self._backend_executor.shutdown()
self._is_shut_down = True
def __del__(self):
if not self._is_shut_down:
self.shutdown()
@staticmethod
def _compute_num_total_mini_batches(batch_size, mini_batch_size, num_iters):
num_total_mini_batches = 0
rest_size = 0
for i in range(num_iters):
eaten_batch = -rest_size
while eaten_batch < batch_size:
eaten_batch += mini_batch_size
num_total_mini_batches += 1
rest_size = mini_batch_size - (eaten_batch - batch_size)
if rest_size:
num_total_mini_batches -= 1
if rest_size:
num_total_mini_batches += 1
return num_total_mini_batches
@Deprecated(
    new="LearnerGroup.update_from_batch(async_update=False)", error=False
)
def update(self, *args, **kwargs):
    """Deprecated: Runs a synchronous update.

    Forwards all given args to `self._update()` with `async_update=False`.
    """
    # Just in case, we would like to revert this API retirement, we can do so
    # easily.
    return self._update(*args, **kwargs, async_update=False)
@Deprecated(
    new="LearnerGroup.update_from_batch(async_update=True)", error=False
)
def async_update(self, *args, **kwargs):
    """Deprecated: Runs an asynchronous update.

    Forwards all given args to `self._update()` with `async_update=True`.
    """
    # Just in case, we would like to revert this API retirement, we can do so
    # easily.
    return self._update(*args, **kwargs, async_update=True)
@Deprecated(new="LearnerGroup.save_to_path(...)", error=True)
def save_state(self, *args, **kwargs):
    """Deprecated no-op stub; the body does nothing and returns None.

    NOTE(review): The `Deprecated(..., error=True)` decorator presumably raises
    before this body is ever reached - confirm against the decorator.
    """
@Deprecated(new="LearnerGroup.restore_from_path(...)", error=True)
def load_state(self, *args, **kwargs):
    """Deprecated no-op stub; the body does nothing and returns None.

    NOTE(review): The `Deprecated(..., error=True)` decorator presumably raises
    before this body is ever reached - confirm against the decorator.
    """
@Deprecated(
    new="LearnerGroup.restore_from_path(path=..., component=...)", error=False
)
def load_module_state(
    self,
    *,
    multi_rl_module_ckpt_dir: Optional[str] = None,
    modules_to_load: Optional[Set[str]] = None,
    rl_module_ckpt_dirs: Optional[Dict[ModuleID, str]] = None,
) -> None:
    """Load the checkpoints of the modules being trained by this LearnerGroup.

    `load_module_state` can be used 3 ways:
        1. Load a checkpoint for the MultiRLModule being trained by this
            LearnerGroup. Limit the modules that are loaded from the checkpoint
            by specifying the `modules_to_load` argument.
        2. Load the checkpoint(s) for single agent RLModules that
            are in the MultiRLModule being trained by this LearnerGroup.
        3. Load a checkpoint for the MultiRLModule being trained by this
            LearnerGroup and load the checkpoint(s) for single agent RLModules
            that are in the MultiRLModule. The checkpoints for the single
            agent RLModules take precedence over the module states in the
            MultiRLModule checkpoint.

    NOTE: At least one of `multi_rl_module_ckpt_dir` or `rl_module_ckpt_dirs`
    must be specified. `modules_to_load` is only used, if
    `multi_rl_module_ckpt_dir` is specified.

    Args:
        multi_rl_module_ckpt_dir: The path to the checkpoint for the
            MultiRLModule.
        modules_to_load: A set of module ids to load from the checkpoint.
        rl_module_ckpt_dirs: A mapping from module ids to the path to a
            checkpoint for a single agent RLModule. The given dict is not
            modified.

    Raises:
        ValueError: If neither `multi_rl_module_ckpt_dir` nor
            `rl_module_ckpt_dirs` is provided.
    """
    if not (multi_rl_module_ckpt_dir or rl_module_ckpt_dirs):
        raise ValueError(
            "At least one of `multi_rl_module_ckpt_dir` or "
            "`rl_module_ckpt_dirs` must be provided!"
        )

    # Convert all paths up front. The per-module paths go into a new dict, so
    # the caller's `rl_module_ckpt_dirs` is left untouched.
    if multi_rl_module_ckpt_dir:
        multi_rl_module_ckpt_dir = pathlib.Path(multi_rl_module_ckpt_dir)
    module_ckpt_paths = {
        module_id: pathlib.Path(path)
        for module_id, path in (rl_module_ckpt_dirs or {}).items()
    }

    rl_module_component = COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE

    # MultiRLModule checkpoint is provided.
    if multi_rl_module_ckpt_dir:
        # Restore the entire MultiRLModule state.
        if modules_to_load is None:
            self.restore_from_path(
                multi_rl_module_ckpt_dir,
                component=rl_module_component,
            )
        # Restore individual module IDs.
        else:
            for module_id in modules_to_load:
                self.restore_from_path(
                    multi_rl_module_ckpt_dir / module_id,
                    component=rl_module_component + "/" + module_id,
                )

    # Single-agent RLModule checkpoints are restored last, so that they take
    # precedence over the states from the MultiRLModule checkpoint.
    for module_id, path in module_ckpt_paths.items():
        self.restore_from_path(
            path,
            component=rl_module_component + "/" + module_id,
        )