# Source code for ray.rllib.core.learner.learner_group

import pathlib
from collections import defaultdict, Counter
import copy
from functools import partial
import itertools
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    List,
    Optional,
    Set,
    Type,
    TYPE_CHECKING,
    Union,
)

import tree  # pip install dm_tree

import ray
from ray import ObjectRef
from ray.rllib.core import COMPONENT_LEARNER, COMPONENT_RL_MODULE
from ray.rllib.core.learner.learner import Learner
from ray.rllib.core.rl_module import validate_module_id
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.actor_manager import (
    FaultTolerantActorManager,
    RemoteCallResults,
    ResultOrError,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.deprecation import (
    Deprecated,
    DEPRECATED_VALUE,
    deprecation_warning,
)
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.minibatch_utils import (
    ShardBatchIterator,
    ShardEpisodesIterator,
    ShardObjectRefIterator,
)
from ray.rllib.utils.typing import (
    EpisodeType,
    ModuleID,
    RLModuleSpecType,
    ShouldModuleBeUpdatedFn,
    StateDict,
    T,
)
from ray.train._internal.backend_executor import BackendExecutor
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig


def _get_backend_config(learner_class: Type[Learner]) -> str:
    if learner_class.framework == "torch":
        from ray.train.torch import TorchConfig

        backend_config = TorchConfig()
    elif learner_class.framework == "tf2":
        from ray.train.tensorflow import TensorflowConfig

        backend_config = TensorflowConfig()
    else:
        raise ValueError(
            "`learner_class.framework` must be either 'torch' or 'tf2' (but is "
            f"{learner_class.framework}!"
        )

    return backend_config


@PublicAPI(stability="alpha")
class LearnerGroup(Checkpointable):
    """Coordinator of n (possibly remote) Learner workers.

    Each Learner worker has a copy of the RLModule, the loss function(s), and
    one or more optimizers.
    """

    def __init__(
        self,
        *,
        config: "AlgorithmConfig",
        # TODO (sven): Rename into `rl_module_spec`.
        module_spec: Optional[RLModuleSpecType] = None,
    ):
        """Initializes a LearnerGroup instance.

        Args:
            config: The AlgorithmConfig object to use to configure this
                LearnerGroup. Call the `learners(num_learners=...)` method on your
                config to specify the number of learner workers to use.
                Call the same method with arguments `num_cpus_per_learner` and/or
                `num_gpus_per_learner` to configure the compute used by each
                Learner worker in this LearnerGroup.
                Call the `training(learner_class=...)` method on your config to
                specify, which exact Learner class to use.
                Call the `rl_module(rl_module_spec=...)` method on your config to
                set up the specifics for your RLModule to be used in each Learner.
            module_spec: If not already specified in `config`, a separate
                overriding RLModuleSpec may be provided via this argument.
        """
        self.config = config.copy(copy_frozen=False)
        # Keep the user-provided spec (possibly None) for
        # `get_ctor_args_and_kwargs()`.
        self._module_spec = module_spec

        learner_class = self.config.learner_class
        module_spec = module_spec or self.config.get_multi_rl_module_spec()

        self._learner = None
        self._workers = None
        # If a user calls self.shutdown() on their own then this flag is set to
        # true. When del is called the backend executor isn't shutdown twice if
        # this flag is true. the backend executor would otherwise log a warning to
        # the console from ray train.
        self._is_shut_down = False

        # How many timesteps had to be dropped due to a full input queue?
        self._ts_dropped = 0

        # A single local Learner.
        if not self.is_remote:
            self._learner = learner_class(config=config, module_spec=module_spec)
            self._learner.build()
            self._worker_manager = None
        # N remote Learner workers.
        else:
            backend_config = _get_backend_config(learner_class)

            # TODO (sven): Can't set both `num_cpus_per_learner`>1 and
            #  `num_gpus_per_learner`>0! Users must set one or the other due
            #  to issues with placement group fragmentation. See
            #  https://github.com/ray-project/ray/issues/35409 for more details.
            num_cpus_per_learner = (
                self.config.num_cpus_per_learner
                if not self.config.num_gpus_per_learner
                else 0
            )
            num_gpus_per_learner = self.config.num_gpus_per_learner
            resources_per_learner = {
                "CPU": num_cpus_per_learner,
                "GPU": num_gpus_per_learner,
            }

            backend_executor = BackendExecutor(
                backend_config=backend_config,
                num_workers=self.config.num_learners,
                resources_per_worker=resources_per_learner,
                max_retries=0,
            )
            backend_executor.start(
                train_cls=learner_class,
                train_cls_kwargs={
                    "config": config,
                    "module_spec": module_spec,
                },
            )
            self._backend_executor = backend_executor

            self._workers = [w.actor for w in backend_executor.worker_group.workers]

            # Run the neural network building code on remote workers.
            ray.get([w.build.remote() for w in self._workers])

            self._worker_manager = FaultTolerantActorManager(
                self._workers,
                # TODO (sven): This probably works even without any restriction
                #  (allowing for any arbitrary number of requests in-flight). Test
                #  with 3 first, then with unlimited, and if both show the same
                #  behavior on an async algo, remove this restriction entirely.
                max_remote_requests_in_flight_per_actor=3,
            )
            # Counters for the tags for asynchronous update requests that are
            # in-flight. Used for keeping track of and grouping together the
            # results of requests that were sent to the workers at the same time.
            self._update_request_tags = Counter()
            self._update_request_tag = 0
            self._update_request_results = {}

        # A special MetricsLogger object (not exposed to the user) for reducing
        # the n results dicts returned by our n Learner workers in case we are on
        # the old or hybrid API stack.
        self._metrics_logger_old_and_hybrid_stack: Optional[MetricsLogger] = None
        if not self.config.enable_env_runner_and_connector_v2:
            self._metrics_logger_old_and_hybrid_stack = MetricsLogger()

    # TODO (sven): Replace this with call to `self.metrics.peek()`?
    #  Currently LearnerGroup does not have a metrics object.
    def get_stats(self) -> Dict[str, Any]:
        """Returns the current stats for the input queue for this learner group."""
        return {
            "learner_group_ts_dropped": self._ts_dropped,
            "actor_manager_num_outstanding_async_reqs": (
                0
                if self.is_local
                else self._worker_manager.num_outstanding_async_reqs()
            ),
        }

    @property
    def is_remote(self) -> bool:
        """True if `config.num_learners > 0` (remote Learner actors are used)."""
        return self.config.num_learners > 0

    @property
    def is_local(self) -> bool:
        """True if this group runs a single, local Learner (not remote)."""
        return not self.is_remote

    def update_from_batch(
        self,
        batch: MultiAgentBatch,
        *,
        timesteps: Optional[Dict[str, Any]] = None,
        async_update: bool = False,
        return_state: bool = False,
        # TODO (sven): Deprecate the following args. They should be extracted from
        #  the self.config of those specific algorithms that actually require these
        #  settings.
        minibatch_size: Optional[int] = None,
        num_iters: int = 1,
        # Already deprecated args.
        reduce_fn=DEPRECATED_VALUE,
        # User kwargs.
        **kwargs,
    ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Performs gradient based update(s) on the Learner(s), based on given batch.

        Args:
            batch: A data batch to use for the update. If there are more than one
                Learner workers, the batch is split amongst these and one shard is
                sent to each Learner.
            timesteps: Optional dict that is forwarded unchanged to each Learner's
                update call (as its `timesteps` argument).
            async_update: Whether the update request(s) to the Learner workers
                should be sent asynchronously. If True, will return NOT the results
                from the update on the given data, but all results from prior
                asynchronous update requests that have not been returned thus far.
            return_state: Whether to include one of the Learner worker's state from
                after the update step in the returned results dict (under the
                `_rl_module_state_after_update` key). Note that after an update,
                all Learner workers' states should be identical, so we use the
                first Learner's state here. Useful for avoiding an extra
                `get_weights()` call, e.g. for synchronizing EnvRunner weights.
            minibatch_size: The minibatch size to use for the update.
            num_iters: The number of complete passes over all the sub-batches in
                the input multi-agent batch.
            reduce_fn: Deprecated. Passing any value raises a deprecation error.
            **kwargs: Additional keyword args forwarded to each Learner's update
                call.

        Returns:
            If `async_update` is False, a dictionary with the reduced results of
            the updates from the Learner(s) or a list of dictionaries of results
            from the updates from the Learner(s).
            If `async_update` is True, a list of list of dictionaries of results,
            where the outer list corresponds to separate previous calls to this
            method, and the inner list corresponds to the results from each
            Learner(s). Or if the results are reduced, a list of dictionaries of
            the reduced results from each call to async_update that is ready.
        """
        if reduce_fn != DEPRECATED_VALUE:
            deprecation_warning(
                old="LearnerGroup.update_from_batch(reduce_fn=..)",
                new="Learner.metrics.[log_value|log_dict|log_time](key=..., value=..., "
                "reduce=[mean|min|max|sum], window=..., ema_coeff=...)",
                help="Use the new ray.rllib.utils.metrics.metrics_logger::MetricsLogger"
                " API in your custom Learner methods for logging and time-reducing any "
                "custom metrics. The central `MetricsLogger` instance is available "
                "under `self.metrics` within your custom Learner.",
                error=True,
            )
        return self._update(
            batch=batch,
            timesteps=timesteps,
            async_update=async_update,
            return_state=return_state,
            minibatch_size=minibatch_size,
            num_iters=num_iters,
            **kwargs,
        )

    def update_from_episodes(
        self,
        episodes: List[EpisodeType],
        *,
        timesteps: Optional[Dict[str, Any]] = None,
        async_update: bool = False,
        return_state: bool = False,
        # TODO (sven): Deprecate the following args. They should be extracted from
        #  the self.config of those specific algorithms that actually require these
        #  settings.
        minibatch_size: Optional[int] = None,
        num_iters: int = 1,
        # Already deprecated args.
        reduce_fn=DEPRECATED_VALUE,
        # User kwargs.
        **kwargs,
    ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Performs gradient based update(s) on the Learner(s), based on given episodes.

        Args:
            episodes: A list of Episodes to process and perform the update
                for. If there are more than one Learner workers, the list of
                episodes is split amongst these and one list shard is sent to each
                Learner.
            timesteps: Optional dict that is forwarded unchanged to each Learner's
                update call (as its `timesteps` argument).
            async_update: Whether the update request(s) to the Learner workers
                should be sent asynchronously. If True, will return NOT the results
                from the update on the given data, but all results from prior
                asynchronous update requests that have not been returned thus far.
            return_state: Whether to include one of the Learner worker's state from
                after the update step in the returned results dict (under the
                `_rl_module_state_after_update` key). Note that after an update,
                all Learner workers' states should be identical, so we use the
                first Learner's state here. Useful for avoiding an extra
                `get_weights()` call, e.g. for synchronizing EnvRunner weights.
            minibatch_size: The minibatch size to use for the update.
            num_iters: The number of complete passes over all the sub-batches in
                the input multi-agent batch.
            reduce_fn: Deprecated. Passing any value raises a deprecation error.
            **kwargs: Additional keyword args forwarded to each Learner's update
                call.

        Returns:
            If async_update is False, a dictionary with the reduced results of the
            updates from the Learner(s) or a list of dictionaries of results from
            the updates from the Learner(s).
            If async_update is True, a list of list of dictionaries of results,
            where the outer list corresponds to separate previous calls to this
            method, and the inner list corresponds to the results from each
            Learner(s). Or if the results are reduced, a list of dictionaries of
            the reduced results from each call to async_update that is ready.
        """
        if reduce_fn != DEPRECATED_VALUE:
            deprecation_warning(
                old="LearnerGroup.update_from_episodes(reduce_fn=..)",
                new="Learner.metrics.[log_value|log_dict|log_time](key=..., value=..., "
                "reduce=[mean|min|max|sum], window=..., ema_coeff=...)",
                help="Use the new ray.rllib.utils.metrics.metrics_logger::MetricsLogger"
                " API in your custom Learner methods for logging and time-reducing any "
                "custom metrics. The central `MetricsLogger` instance is available "
                "under `self.metrics` within your custom Learner.",
                error=True,
            )
        return self._update(
            episodes=episodes,
            timesteps=timesteps,
            async_update=async_update,
            return_state=return_state,
            minibatch_size=minibatch_size,
            num_iters=num_iters,
            **kwargs,
        )

    def _update(
        self,
        *,
        batch: Optional[MultiAgentBatch] = None,
        episodes: Optional[List[EpisodeType]] = None,
        timesteps: Optional[Dict[str, Any]] = None,
        async_update: bool = False,
        return_state: bool = False,
        minibatch_size: Optional[int] = None,
        num_iters: int = 1,
        **kwargs,
    ) -> Union[Dict[str, Any], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Shared implementation of `update_from_batch` and `update_from_episodes`.

        In local mode, runs the data as-is through the single Learner. In remote
        mode, shards `batch` or `episodes` across the Learner actors and either
        waits for all results (sync) or fires the requests and returns whatever
        earlier async results are ready (`async_update=True`).

        Raises:
            ValueError: If `async_update=True` is used with a local Learner.
        """

        # Define function to be called on all Learner actors (or the local learner).
        def _learner_update(
            _learner: Learner,
            *,
            _batch_shard=None,
            _episodes_shard=None,
            _timesteps=None,
            _return_state=False,
            _min_total_mini_batches=0,
            **_kwargs,
        ):
            # If the batch shard is an `DataIterator` we have an offline
            # multi-learner setup and `update_from_iterator` needs to
            # handle updating.
            if isinstance(_batch_shard, ray.data.DataIterator):
                result = _learner.update_from_iterator(
                    iterator=_batch_shard,
                    timesteps=_timesteps,
                    minibatch_size=minibatch_size,
                    num_iters=num_iters,
                    **_kwargs,
                )
            elif _batch_shard is not None:
                result = _learner.update_from_batch(
                    batch=_batch_shard,
                    timesteps=_timesteps,
                    minibatch_size=minibatch_size,
                    num_iters=num_iters,
                    **_kwargs,
                )
            else:
                result = _learner.update_from_episodes(
                    episodes=_episodes_shard,
                    timesteps=_timesteps,
                    minibatch_size=minibatch_size,
                    num_iters=num_iters,
                    min_total_mini_batches=_min_total_mini_batches,
                    **_kwargs,
                )
            if _return_state:
                result["_rl_module_state_after_update"] = _learner.get_state(
                    components=COMPONENT_RL_MODULE, inference_only=True
                )[COMPONENT_RL_MODULE]

            return result

        # Local Learner worker: Don't shard batch/episodes, just run data as-is
        # through this Learner.
        if self.is_local:
            if async_update:
                raise ValueError(
                    "Cannot call `update_from_batch(async_update=True)` when running in"
                    " local mode! Try setting `config.num_learners > 0`."
                )

            results = [
                _learner_update(
                    _learner=self._learner,
                    _batch_shard=batch,
                    _episodes_shard=episodes,
                    _timesteps=timesteps,
                    _return_state=return_state,
                    **kwargs,
                )
            ]
        # One or more remote Learners: Shard batch/episodes into equal pieces
        # (roughly equal if multi-agent AND episodes) and send each Learner worker
        # one of these shards.
        else:
            # MultiAgentBatch: Shard into equal pieces.
            # TODO (sven): The sharder used here destroys - for multi-agent only -
            #  the relationship of the different agents' timesteps to each other.
            #  Thus, in case the algorithm requires agent-synchronized data (aka.
            #  "lockstep"), the `ShardBatchIterator` should not be used.
            #  Then again, we might move into a world where Learner always
            #  receives Episodes, never batches.
            if isinstance(batch, list) and isinstance(
                batch[0], ray.data.DataIterator
            ):
                partials = [
                    partial(
                        _learner_update,
                        _batch_shard=iterator,
                        _return_state=(return_state and i == 0),
                        _timesteps=timesteps,
                        **kwargs,
                    )
                    # Note, `OfflineData` defines exactly as many iterators as
                    # there are learners.
                    for i, iterator in enumerate(batch)
                ]
            elif batch is not None:
                partials = [
                    partial(
                        _learner_update,
                        _batch_shard=batch_shard,
                        _return_state=(return_state and i == 0),
                        _timesteps=timesteps,
                        **kwargs,
                    )
                    for i, batch_shard in enumerate(
                        ShardBatchIterator(batch, len(self._workers))
                    )
                ]
            elif isinstance(episodes, list) and isinstance(episodes[0], ObjectRef):
                partials = [
                    partial(
                        _learner_update,
                        _episodes_shard=episodes_shard,
                        _timesteps=timesteps,
                        _return_state=(return_state and i == 0),
                        **kwargs,
                    )
                    for i, episodes_shard in enumerate(
                        ShardObjectRefIterator(episodes, len(self._workers))
                    )
                ]
            # Single- or MultiAgentEpisodes: Shard into equal pieces (only roughly
            # equal in case of multi-agent).
            else:
                from ray.data.iterator import DataIterator

                if isinstance(episodes[0], DataIterator):
                    min_total_mini_batches = 0
                    # Forward `timesteps`, `return_state` (first Learner only) and
                    # the user kwargs, consistent with all other sharding branches.
                    partials = [
                        partial(
                            _learner_update,
                            _episodes_shard=episodes_shard,
                            _timesteps=timesteps,
                            _return_state=(return_state and i == 0),
                            _min_total_mini_batches=min_total_mini_batches,
                            **kwargs,
                        )
                        for i, episodes_shard in enumerate(episodes)
                    ]
                else:
                    eps_shards = list(
                        ShardEpisodesIterator(episodes, len(self._workers))
                    )
                    # In the multi-agent case AND `minibatch_size` AND num_workers
                    # > 1, we compute a max iteration counter such that the
                    # different Learners will not go through a different number of
                    # iterations.
                    min_total_mini_batches = 0
                    if (
                        isinstance(episodes[0], MultiAgentEpisode)
                        and minibatch_size
                        and len(self._workers) > 1
                    ):
                        # Find episode w/ the largest single-agent episode in it,
                        # then compute this single-agent episode's total number of
                        # mini batches (if we iterated over it num_sgd_iter times
                        # with the mini batch size).
                        longest_ts = 0
                        per_mod_ts = defaultdict(int)
                        for i, shard in enumerate(eps_shards):
                            for ma_episode in shard:
                                for sa_episode in ma_episode.agent_episodes.values():
                                    key = (i, sa_episode.module_id)
                                    per_mod_ts[key] += len(sa_episode)
                                    if per_mod_ts[key] > longest_ts:
                                        longest_ts = per_mod_ts[key]
                        min_total_mini_batches = (
                            self._compute_num_total_mini_batches(
                                batch_size=longest_ts,
                                mini_batch_size=minibatch_size,
                                num_iters=num_iters,
                            )
                        )
                    # Forward `timesteps`, `return_state` (first Learner only) and
                    # the user kwargs, consistent with all other sharding branches.
                    partials = [
                        partial(
                            _learner_update,
                            _episodes_shard=eps_shard,
                            _timesteps=timesteps,
                            _return_state=(return_state and i == 0),
                            _min_total_mini_batches=min_total_mini_batches,
                            **kwargs,
                        )
                        for i, eps_shard in enumerate(eps_shards)
                    ]

            if async_update:
                # Retrieve all ready results (kicked off by prior calls to this
                # method).
                tags_to_get = []
                for tag in self._update_request_tags.keys():
                    result = self._worker_manager.fetch_ready_async_reqs(
                        tags=[str(tag)], timeout_seconds=0.0
                    )
                    if tag not in self._update_request_results:
                        self._update_request_results[tag] = result
                    else:
                        for r in result:
                            self._update_request_results[tag].add_result(
                                r.actor_id, r.result_or_error, tag
                            )

                    # Still not done with this `tag` -> skip out early.
                    if (
                        self._update_request_tags[tag]
                        > len(self._update_request_results[tag].result_or_errors)
                        > 0
                    ):
                        break
                    tags_to_get.append(tag)

                # Send out new request(s), if there is still capacity on the actors.
                update_tag = self._update_request_tag
                self._update_request_tag += 1
                num_sent_requests = self._worker_manager.foreach_actor_async(
                    partials, tag=str(update_tag)
                )
                if num_sent_requests:
                    self._update_request_tags[update_tag] = num_sent_requests

                # Some requests were dropped, record lost ts/data.
                if num_sent_requests != len(self._workers):
                    # assert num_sent_requests == 0, num_sent_requests
                    factor = 1 - (num_sent_requests / len(self._workers))
                    # Batch: Measure its length.
                    if episodes is None:
                        dropped = len(batch)
                    # List of Ray ObjectRefs (each object ref is a list of episodes
                    # of total len=`rollout_fragment_length *
                    # num_envs_per_env_runner`)
                    elif isinstance(episodes[0], ObjectRef):
                        dropped = (
                            len(episodes)
                            * self.config.get_rollout_fragment_length()
                            * self.config.num_envs_per_env_runner
                        )
                    else:
                        dropped = sum(len(e) for e in episodes)

                    self._ts_dropped += factor * dropped

                # NOTE: There is a strong assumption here that the requests
                # launched to learner workers will return at the same time, since
                # they have a barrier inside for gradient aggregation. Therefore,
                # results should be a list of lists where each inner list should be
                # the length of the number of learner workers, if results from an
                # non-blocking update are ready.
                results = self._get_async_results(tags_to_get)

            else:
                results = self._get_results(
                    self._worker_manager.foreach_actor(partials)
                )

        # If we are on the hybrid API stacks (no EnvRunners), we need to emulate
        # the old behavior of returning an already reduced dict (as if we had a
        # reduce_fn).
        if not self.config.enable_env_runner_and_connector_v2:
            # If we are doing an async update, we operate on a list (different
            # async requests that now have results ready) of lists (n Learner
            # workers) here.
            if async_update:
                results = tree.flatten_up_to(
                    [[None] * len(r) for r in results], results
                )
            self._metrics_logger_old_and_hybrid_stack.merge_and_log_n_dicts(results)
            results = self._metrics_logger_old_and_hybrid_stack.reduce(
                # We are returning to a client (Algorithm) that does NOT make any
                # use of MetricsLogger (or Stats) -> Convert all values to
                # non-Stats primitives.
                return_stats_obj=False
            )

        return results

    # TODO (sven): Move this into FaultTolerantActorManager?
    def _get_results(self, results):
        """Unwraps the given call results into a list, re-raising the first error.

        Args:
            results: An iterable of call results, each offering `get()` and `ok`.

        Returns:
            The list of unwrapped values (one per result).
        """
        processed_results = []
        for result in results:
            result_or_error = result.get()
            if result.ok:
                processed_results.append(result_or_error)
            else:
                raise result_or_error
        return processed_results

    def _get_async_results(self, tags_to_get):
        """Get results from the worker manager and group them by tag.

        Also decrements the per-tag counters of outstanding requests and drops
        all bookkeeping for a tag once every one of its requests has returned.

        Args:
            tags_to_get: The update-request tags whose stored results should be
                collected.

        Returns:
            A list of lists of results, where each inner list contains all results
            for same tags.
        """
        unprocessed_results = defaultdict(list)
        for tag in tags_to_get:
            results = self._update_request_results[tag]
            for result in results:
                result_or_error = result.get()
                if result.ok:
                    if result.tag is None:
                        raise RuntimeError(
                            "Cannot call `LearnerGroup._get_async_results()` on "
                            "untagged async requests!"
                        )
                    result_tag = int(result.tag)
                    unprocessed_results[result_tag].append(result_or_error)

                    if result_tag in self._update_request_tags:
                        self._update_request_tags[result_tag] -= 1
                        if self._update_request_tags[result_tag] == 0:
                            del self._update_request_tags[result_tag]
                            del self._update_request_results[result_tag]
                    else:
                        assert False, (
                            f"Received a result for tag {result_tag}, which has no "
                            "outstanding update requests!"
                        )

                else:
                    raise result_or_error

        return list(unprocessed_results.values())

    def add_module(
        self,
        *,
        module_id: ModuleID,
        module_spec: RLModuleSpec,
        config_overrides: Optional[Dict] = None,
        new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
    ) -> MultiRLModuleSpec:
        """Adds a module to the underlying MultiRLModule.

        Changes this Learner's config in order to make this architectural change
        permanent wrt. to checkpointing.

        Args:
            module_id: The ModuleID of the module to be added.
            module_spec: The ModuleSpec of the module to be added.
            config_overrides: The `AlgorithmConfig` overrides that should apply to
                the new Module, if any.
            new_should_module_be_updated: An optional sequence of ModuleIDs or a
                callable taking ModuleID and SampleBatchType and returning whether
                the ModuleID should be updated (trained).
                If None, will keep the existing setup in place. RLModules,
                whose IDs are not in the list (or for which the callable
                returns False) will not be updated.

        Returns:
            The new MultiRLModuleSpec (after the change has been performed).
        """
        validate_module_id(module_id, error=True)

        # Force-set inference-only = False (on a copy, so the caller's spec is
        # left untouched).
        module_spec = copy.deepcopy(module_spec)
        module_spec.inference_only = False

        results = self.foreach_learner(
            func=lambda _learner: _learner.add_module(
                module_id=module_id,
                module_spec=module_spec,
                config_overrides=config_overrides,
                new_should_module_be_updated=new_should_module_be_updated,
            ),
        )
        marl_spec = self._get_results(results)[0]

        # Change our config (AlgorithmConfig) to contain the new Module.
        # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
        #  but we'll deprecate config.policies soon anyway.
        self.config.policies[module_id] = PolicySpec()
        if config_overrides is not None:
            self.config.multi_agent(
                algorithm_config_overrides_per_module={module_id: config_overrides}
            )
        self.config.rl_module(rl_module_spec=marl_spec)
        if new_should_module_be_updated is not None:
            self.config.multi_agent(policies_to_train=new_should_module_be_updated)

        return marl_spec

    def remove_module(
        self,
        module_id: ModuleID,
        *,
        new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None,
    ) -> MultiRLModuleSpec:
        """Removes a module from the Learner.

        Args:
            module_id: The ModuleID of the module to be removed.
            new_should_module_be_updated: An optional sequence of ModuleIDs or a
                callable taking ModuleID and SampleBatchType and returning whether
                the ModuleID should be updated (trained).
                If None, will keep the existing setup in place. RLModules,
                whose IDs are not in the list (or for which the callable
                returns False) will not be updated.

        Returns:
            The new MultiRLModuleSpec (after the change has been performed).
        """
        # Remove all stats from the module from our metrics logger (hybrid API
        # stack only), so we don't report results from this module again.
        if (
            not self.config.enable_env_runner_and_connector_v2
            and module_id in self._metrics_logger_old_and_hybrid_stack.stats
        ):
            del self._metrics_logger_old_and_hybrid_stack.stats[module_id]

        results = self.foreach_learner(
            func=lambda _learner: _learner.remove_module(
                module_id=module_id,
                new_should_module_be_updated=new_should_module_be_updated,
            ),
        )
        marl_spec = self._get_results(results)[0]

        # Change self.config to reflect the new architecture.
        # TODO (sven): This is a hack to manipulate the AlgorithmConfig directly,
        #  but we'll deprecate config.policies soon anyway.
        del self.config.policies[module_id]
        self.config.algorithm_config_overrides_per_module.pop(module_id, None)
        if new_should_module_be_updated is not None:
            self.config.multi_agent(policies_to_train=new_should_module_be_updated)
        self.config.rl_module(rl_module_spec=marl_spec)

        return marl_spec

    @override(Checkpointable)
    def get_state(
        self,
        components: Optional[Union[str, Collection[str]]] = None,
        *,
        not_components: Optional[Union[str, Collection[str]]] = None,
        **kwargs,
    ) -> StateDict:
        """Returns the state dict, holding the Learner state under `COMPONENT_LEARNER`.

        In remote mode, the state is fetched from the first healthy Learner actor
        only.
        """
        state = {}

        if self._check_component(COMPONENT_LEARNER, components, not_components):
            if self.is_local:
                state[COMPONENT_LEARNER] = self._learner.get_state(
                    components=self._get_subcomponents(COMPONENT_LEARNER, components),
                    not_components=self._get_subcomponents(
                        COMPONENT_LEARNER, not_components
                    ),
                    **kwargs,
                )
            else:
                worker = self._worker_manager.healthy_actor_ids()[0]
                assert len(self._workers) == self._worker_manager.num_healthy_actors()
                _comps = self._get_subcomponents(COMPONENT_LEARNER, components)
                _not_comps = self._get_subcomponents(
                    COMPONENT_LEARNER, not_components
                )
                results = self._worker_manager.foreach_actor(
                    lambda w: w.get_state(
                        _comps, not_components=_not_comps, **kwargs
                    ),
                    remote_actor_ids=[worker],
                )
                state[COMPONENT_LEARNER] = self._get_results(results)[0]

        return state

    @override(Checkpointable)
    def set_state(self, state: StateDict) -> None:
        """Sets the Learner state found under `COMPONENT_LEARNER` (if any)."""
        if COMPONENT_LEARNER in state:
            if self.is_local:
                self._learner.set_state(state[COMPONENT_LEARNER])
            else:
                # Put the state into the object store once and hand every Learner
                # the same reference.
                state_ref = ray.put(state[COMPONENT_LEARNER])
                self.foreach_learner(
                    lambda _learner, _ref=state_ref: _learner.set_state(
                        ray.get(_ref)
                    )
                )

    def get_weights(
        self, module_ids: Optional[Collection[ModuleID]] = None
    ) -> StateDict:
        """Convenience method instead of self.get_state(components=...).

        Args:
            module_ids: An optional collection of ModuleIDs for which to return
                weights. If None (default), return weights of all RLModules.

        Returns:
            The results of
            `self.get_state(components='learner/rl_module')['learner']['rl_module']`.
        """
        rl_module_path = COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE
        # Return the entire RLModule state (all possible single-agent RLModules).
        if module_ids is None:
            components = rl_module_path
        # Return a subset of the single-agent RLModules.
        else:
            components = [
                rl_module_path + "/" + module_id for module_id in module_ids
            ]

        return self.get_state(components)[COMPONENT_LEARNER][COMPONENT_RL_MODULE]

    def set_weights(self, weights) -> None:
        """Convenience method instead of self.set_state({'learner': {'rl_module': ..}}).

        Args:
            weights: The weights dict of the MultiRLModule of a Learner inside this
                LearnerGroup.
        """
        self.set_state({COMPONENT_LEARNER: {COMPONENT_RL_MODULE: weights}})

    @override(Checkpointable)
    def get_ctor_args_and_kwargs(self):
        """Returns the (args, kwargs) needed to re-construct this LearnerGroup."""
        return (
            (),  # *args
            {
                "config": self.config,
                "module_spec": self._module_spec,
            },  # **kwargs
        )

    @override(Checkpointable)
    def get_checkpointable_components(self):
        """Returns the single (name, component) pair to be checkpointed."""
        # Return the entire ActorManager, if remote. Otherwise, return the
        # local worker. The Learner is the only component in this LearnerGroup
        # to be saved and is stored under the `COMPONENT_LEARNER` name.
        return [
            (
                COMPONENT_LEARNER,
                self._learner if self.is_local else self._worker_manager,
            )
        ]

    def foreach_learner(
        self,
        func: Callable[[Learner, Optional[Any]], T],
        *,
        healthy_only: bool = True,
        remote_actor_ids: Optional[List[int]] = None,
        timeout_seconds: Optional[float] = None,
        return_obj_refs: bool = False,
        mark_healthy: bool = True,
        **kwargs,
    ) -> RemoteCallResults:
        r"""Calls the given function on each Learner L with the args: (L, \*\*kwargs).

        Args:
            func: The function to call on each Learner L with args: (L, \*\*kwargs).
            healthy_only: If True, applies `func` only to Learner actors currently
                tagged "healthy", otherwise to all actors. If `healthy_only=False`
                and `mark_healthy=True`, will send `func` to all actors and mark
                those actors "healthy" that respond to the request within
                `timeout_seconds` and are currently tagged as "unhealthy".
            remote_actor_ids: Apply func on a selected set of remote actors. Use
                None (default) for all actors.
            timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0
                for fire-and-forget. Set this to None (default) to wait infinitely
                (i.e. for synchronous execution).
            return_obj_refs: whether to return ObjectRef instead of actual results.
                Note, for fault tolerance reasons, these returned ObjectRefs should
                never be resolved with ray.get() outside of the context of this
                manager.
            mark_healthy: Whether to mark all those actors healthy again that are
                currently marked unhealthy AND that returned results from the
                remote call (within the given `timeout_seconds`).
                Note that actors are NOT set unhealthy, if they simply time out
                (only if they return a RayActorError).
                Also note that this setting is ignored if `healthy_only=True` (b/c
                this setting only affects actors that are currently tagged as
                unhealthy).
            **kwargs: Keyword args passed to `func` on every call.

        Returns:
            A list of size len(Learners) with the return values of all calls to
            `func`.
        """
        if self.is_local:
            # Wrap the local call's return value so that callers get the same
            # result container type as in the remote case.
            results = RemoteCallResults()
            results.add_result(
                None,
                ResultOrError(result=func(self._learner, **kwargs)),
                None,
            )
            return results

        return self._worker_manager.foreach_actor(
            func=partial(func, **kwargs),
            healthy_only=healthy_only,
            remote_actor_ids=remote_actor_ids,
            timeout_seconds=timeout_seconds,
            return_obj_refs=return_obj_refs,
            mark_healthy=mark_healthy,
        )

    def shutdown(self):
        """Shuts down the LearnerGroup."""
        if self.is_remote and hasattr(self, "_backend_executor"):
            self._backend_executor.shutdown()
        self._is_shut_down = True

    def __del__(self):
        # `_is_shut_down` is missing if `__init__` raised before setting it. In
        # that case no backend executor was created either, so there is nothing
        # to shut down.
        if not getattr(self, "_is_shut_down", True):
            self.shutdown()

    @staticmethod
    def _compute_num_total_mini_batches(batch_size, mini_batch_size, num_iters):
        """Counts the mini-batches needed for `num_iters` passes over a batch.

        A mini-batch that would run past the end of the batch is not counted in
        its own pass; its remainder is carried over into the next pass, and a
        final leftover after the last pass counts as one more mini-batch.

        Args:
            batch_size: The size of the batch to iterate over.
            mini_batch_size: The size of each mini-batch. Must be > 0 (a value
                <= 0 would never terminate the inner loop).
            num_iters: The number of passes over the batch.

        Returns:
            The total number of mini-batches.
        """
        num_total_mini_batches = 0
        rest_size = 0
        for _ in range(num_iters):
            # Start each pass with whatever the previous pass left over.
            eaten_batch = -rest_size
            while eaten_batch < batch_size:
                eaten_batch += mini_batch_size
                num_total_mini_batches += 1
            rest_size = mini_batch_size - (eaten_batch - batch_size)
            if rest_size:
                num_total_mini_batches -= 1
        if rest_size:
            num_total_mini_batches += 1
        return num_total_mini_batches

    @Deprecated(new="LearnerGroup.update_from_batch(async=False)", error=False)
    def update(self, *args, **kwargs):
        # Just in case, we would like to revert this API retirement, we can do so
        # easily.
        return self._update(*args, **kwargs, async_update=False)

    @Deprecated(new="LearnerGroup.update_from_batch(async=True)", error=False)
    def async_update(self, *args, **kwargs):
        # Just in case, we would like to revert this API retirement, we can do so
        # easily.
        return self._update(*args, **kwargs, async_update=True)

    @Deprecated(new="LearnerGroup.save_to_path(...)", error=True)
    def save_state(self, *args, **kwargs):
        pass

    @Deprecated(new="LearnerGroup.restore_from_path(...)", error=True)
    def load_state(self, *args, **kwargs):
        pass

    @Deprecated(new="LearnerGroup.load_from_path(path=..., component=...)", error=False)
    def load_module_state(
        self,
        *,
        multi_rl_module_ckpt_dir: Optional[str] = None,
        modules_to_load: Optional[Set[str]] = None,
        rl_module_ckpt_dirs: Optional[Dict[ModuleID, str]] = None,
    ) -> None:
        """Load the checkpoints of the modules being trained by this LearnerGroup.

        `load_module_state` can be used 3 ways:
            1. Load a checkpoint for the MultiRLModule being trained by this
                LearnerGroup. Limit the modules that are loaded from the checkpoint
                by specifying the `modules_to_load` argument.
            2. Load the checkpoint(s) for single agent RLModules that
                are in the MultiRLModule being trained by this LearnerGroup.
            3. Load a checkpoint for the MultiRLModule being trained by this
                LearnerGroup and load the checkpoint(s) for single agent RLModules
                that are in the MultiRLModule. The checkpoints for the single
                agent RLModules take precedence over the module states in the
                MultiRLModule checkpoint.

        NOTE: At least one of multi_rl_module_ckpt_dir or rl_module_ckpt_dirs
            must be specified. modules_to_load can only be specified if
            multi_rl_module_ckpt_dir is specified.

        Args:
            multi_rl_module_ckpt_dir: The path to the checkpoint for the
                MultiRLModule.
            modules_to_load: A set of module ids to load from the checkpoint.
            rl_module_ckpt_dirs: A mapping from module ids to the path to a
                checkpoint for a single agent RLModule.

        Raises:
            ValueError: If neither `multi_rl_module_ckpt_dir` nor
                `rl_module_ckpt_dirs` is provided.
        """
        if not (multi_rl_module_ckpt_dir or rl_module_ckpt_dirs):
            raise ValueError(
                "At least one of `multi_rl_module_ckpt_dir` or "
                "`rl_module_ckpt_dirs` must be provided!"
            )
        if multi_rl_module_ckpt_dir:
            multi_rl_module_ckpt_dir = pathlib.Path(multi_rl_module_ckpt_dir)
        if rl_module_ckpt_dirs:
            # Build a new mapping rather than rewriting the caller's dict
            # in place.
            rl_module_ckpt_dirs = {
                module_id: pathlib.Path(path)
                for module_id, path in rl_module_ckpt_dirs.items()
            }

        rl_module_component = COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE

        # MultiRLModule checkpoint is provided.
        if multi_rl_module_ckpt_dir:
            # Restore the entire MultiRLModule state.
            if modules_to_load is None:
                self.restore_from_path(
                    multi_rl_module_ckpt_dir,
                    component=rl_module_component,
                )
            # Restore individual module IDs.
            else:
                for module_id in modules_to_load:
                    self.restore_from_path(
                        multi_rl_module_ckpt_dir / module_id,
                        component=rl_module_component + "/" + module_id,
                    )
        # Single-agent RLModule checkpoints are restored last, so they take
        # precedence over the states from the MultiRLModule checkpoint.
        if rl_module_ckpt_dirs:
            for module_id, path in rl_module_ckpt_dirs.items():
                self.restore_from_path(
                    path,
                    component=rl_module_component + "/" + module_id,
                )