Source code for ray.rllib.env.multi_agent_env_runner

from collections import defaultdict
from functools import partial
import logging
from typing import Collection, DefaultDict, Dict, List, Optional, Union

import gymnasium as gym

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.callbacks.utils import make_callback
from ray.rllib.core import (
    COMPONENT_ENV_TO_MODULE_CONNECTOR,
    COMPONENT_MODULE_TO_ENV_CONNECTOR,
    COMPONENT_RL_MODULE,
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleSpec
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.env_runner import EnvRunner, ENV_STEP_FAILURE
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
from ray.rllib.env.utils import _gym_env_creator
from ray.rllib.utils import force_list
from ray.rllib.utils.annotations import override
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.deprecation import Deprecated
from ray.rllib.utils.framework import get_device, try_import_torch
from ray.rllib.utils.metrics import (
    EPISODE_DURATION_SEC_MEAN,
    EPISODE_LEN_MAX,
    EPISODE_LEN_MEAN,
    EPISODE_LEN_MIN,
    EPISODE_RETURN_MAX,
    EPISODE_RETURN_MEAN,
    EPISODE_RETURN_MIN,
    NUM_AGENT_STEPS_SAMPLED,
    NUM_AGENT_STEPS_SAMPLED_LIFETIME,
    NUM_ENV_STEPS_SAMPLED,
    NUM_ENV_STEPS_SAMPLED_LIFETIME,
    NUM_EPISODES,
    NUM_EPISODES_LIFETIME,
    NUM_MODULE_STEPS_SAMPLED,
    NUM_MODULE_STEPS_SAMPLED_LIFETIME,
    WEIGHTS_SEQ_NO,
)
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.pre_checks.env import check_multiagent_environments
from ray.rllib.utils.typing import EpisodeID, ModelWeights, ResultDict, StateDict
from ray.tune.registry import ENV_CREATOR, _global_registry
from ray.util.annotations import PublicAPI

torch, _ = try_import_torch()
logger = logging.getLogger("ray.rllib")


# TODO (sven): As soon as RolloutWorker is no longer supported, make `EnvRunner` itself
#  a Checkpointable. Currently, only some of its subclasses are Checkpointables.
@PublicAPI(stability="alpha")
class MultiAgentEnvRunner(EnvRunner, Checkpointable):
    """The generic environment runner for the multi-agent case."""

    @override(EnvRunner)
    def __init__(self, config: AlgorithmConfig, **kwargs):
        """Initializes a MultiAgentEnvRunner instance.

        Args:
            config: An `AlgorithmConfig` object containing all settings needed to
                build this `EnvRunner` class.
        """
        super().__init__(config=config)

        # Raise an Error, if the provided config is not a multi-agent one.
        if not self.config.is_multi_agent:
            raise ValueError(
                f"Cannot use this EnvRunner class ({type(self).__name__}), if your "
                "setup is not multi-agent! Try adding multi-agent information to your "
                "AlgorithmConfig via calling the `config.multi_agent(policies=..., "
                "policy_mapping_fn=...)`."
            )

        # Get the worker index on which this instance is running.
        self.worker_index: int = kwargs.get("worker_index")
        self.tune_trial_id: str = kwargs.get("tune_trial_id")

        # Set up all metrics-related structures and counters.
        self.metrics: Optional[MetricsLogger] = None
        self._setup_metrics()

        # Create our callbacks object.
        self._callbacks = [cls() for cls in force_list(self.config.callbacks_class)]

        # Set device.
        self._device = get_device(
            self.config,
            0 if not self.worker_index else self.config.num_gpus_per_env_runner,
        )

        # Create the vectorized gymnasium env.
        self.env: Optional[gym.Wrapper] = None
        self.num_envs: int = 0
        self.make_env()

        # Create the env-to-module connector pipeline.
        self._env_to_module = self.config.build_env_to_module_connector(
            self.env.unwrapped, device=self._device
        )
        # Cached env-to-module results taken at the end of a `_sample_timesteps()`
        # call to make sure the final observation (before an episode cut) gets properly
        # processed (and maybe postprocessed and re-stored into the episode).
        # For example, if we had a connector that normalizes observations and directly
        # re-inserts these new obs back into the episode, the last observation in each
        # sample call would NOT be processed, which could be very harmful in cases,
        # in which value function bootstrapping of those (truncation) observations is
        # required in the learning step.
        self._cached_to_module = None

        # Construct the MultiRLModule.
        self.module: Optional[MultiRLModule] = None
        self.make_module()

        # Create the module-to-env connector pipeline.
        self._module_to_env = self.config.build_module_to_env_connector(
            self.env.unwrapped
        )

        self._needs_initial_reset: bool = True
        self._episode: Optional[MultiAgentEpisode] = None
        self._shared_data = None

        self._weights_seq_no: int = 0

    @override(EnvRunner)
    def sample(
        self,
        *,
        num_timesteps: int = None,
        num_episodes: int = None,
        explore: bool = None,
        random_actions: bool = False,
        force_reset: bool = False,
    ) -> List[MultiAgentEpisode]:
        """Runs and returns a sample (n timesteps or m episodes) on the env(s).

        Args:
            num_timesteps: The number of timesteps to sample during this call.
                Note that only one of `num_timesteps` or `num_episodes` may be
                provided.
            num_episodes: The number of episodes to sample during this call.
                Note that only one of `num_timesteps` or `num_episodes` may be
                provided.
            explore: If True, will use the RLModule's `forward_exploration()`
                method to compute actions. If False, will use the RLModule's
                `forward_inference()` method. If None (default), will use the
                `explore` boolean setting from `self.config` passed into this
                EnvRunner's constructor. You can change this setting in your config
                via `config.env_runners(explore=True|False)`.
            random_actions: If True, actions will be sampled randomly (from the
                action space of the environment). If False (default), actions or
                action distribution parameters are computed by the RLModule.
            force_reset: Whether to force-reset all (vector) environments before
                sampling. Useful if you would like to collect a clean slate of new
                episodes via this call. Note that when sampling n episodes
                (`num_episodes != None`), this is fixed to True.

        Returns:
            A list of `MultiAgentEpisode` instances, carrying the sampled data.
        """
        assert not (num_timesteps is not None and num_episodes is not None)

        # If no execution details are provided, use the config to try to infer the
        # desired timesteps/episodes to sample and the exploration behavior.
        if explore is None:
            explore = self.config.explore
        if num_timesteps is None and num_episodes is None:
            if self.config.batch_mode == "truncate_episodes":
                num_timesteps = self.config.get_rollout_fragment_length(
                    worker_index=self.worker_index,
                )
            else:
                num_episodes = 1

        # Sample n timesteps.
        if num_timesteps is not None:
            samples = self._sample_timesteps(
                num_timesteps=num_timesteps,
                explore=explore,
                random_actions=random_actions,
                force_reset=force_reset,
            )
        # Sample m episodes.
        else:
            samples = self._sample_episodes(
                num_episodes=num_episodes,
                explore=explore,
                random_actions=random_actions,
            )

        # Make the `on_sample_end` callback.
        make_callback(
            "on_sample_end",
            callbacks_objects=self._callbacks,
            callbacks_functions=self.config.callbacks_on_sample_end,
            kwargs=dict(
                env_runner=self,
                metrics_logger=self.metrics,
                samples=samples,
            ),
        )

        return samples
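
    # Example (illustrative only, not part of the original module): a built
    # `MultiAgentEnvRunner` is normally driven by an Algorithm, but `sample()`
    # can also be called directly, e.g.:
    #
    #     episodes = env_runner.sample(num_timesteps=200, explore=True)
    #     # ... or, alternatively, collect exactly two full episodes:
    #     episodes = env_runner.sample(num_episodes=2)
    #     total_agent_steps = sum(eps.agent_steps() for eps in episodes)
    #
    # `MultiAgentEpisode.agent_steps()` is assumed here to return the number of
    # individual agent steps in an episode; see `MultiAgentEpisode` for details.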

    def _sample_timesteps(
        self,
        num_timesteps: int,
        explore: bool,
        random_actions: bool = False,
        force_reset: bool = False,
    ) -> List[MultiAgentEpisode]:
        """Helper method to sample n timesteps.

        Args:
            num_timesteps: int. Number of timesteps to sample during rollout.
            explore: boolean. If in exploration or inference mode. Exploration
                mode might for some algorithms provide extra model outputs that
                are redundant in inference mode.
            random_actions: boolean. If actions should be sampled from the action
                space. In default mode (i.e. `False`) we sample actions from the
                policy.

        Returns:
            A list of `MultiAgentEpisode` instances, carrying the collected sample
            data.
        """
        done_episodes_to_return: List[MultiAgentEpisode] = []

        # Have to reset the env.
        if force_reset or self._needs_initial_reset:
            # Create n new episodes and make the `on_episode_created` callbacks.
            self._episode = self._new_episode()
            self._make_on_episode_callback("on_episode_created")

            # Erase all cached ongoing episodes (these will never be completed and
            # would thus never be returned/cleaned by `get_metrics` and cause a
            # memory leak).
            self._ongoing_episodes_for_metrics.clear()

            # Try resetting the environment.
            # TODO (simon): Check, if we need here the seed from the config.
            obs, infos = self._try_env_reset()
            self._cached_to_module = None

            # Call `on_episode_start()` callbacks.
            self._make_on_episode_callback("on_episode_start")

            # We just reset the env. Don't have to force this again in the next
            # call to `self._sample_timesteps()`.
            self._needs_initial_reset = False

            # Set the initial observations in the episodes.
            self._episode.add_env_reset(observations=obs, infos=infos)

            self._shared_data = {
                "agent_to_module_mapping_fn": self.config.policy_mapping_fn,
            }

        # Loop through timesteps.
        ts = 0
        while ts < num_timesteps:
            # Act randomly.
            if random_actions:
                # Only act (randomly) for those agents that had an observation.
                to_env = {
                    Columns.ACTIONS: [
                        {
                            aid: self.env.unwrapped.get_action_space(aid).sample()
                            for aid in self._episode.get_agents_to_act()
                        }
                    ]
                }
            # Compute an action using the RLModule.
            else:
                # Env-to-module connector.
                to_module = self._cached_to_module or self._env_to_module(
                    rl_module=self.module,
                    episodes=[self._episode],
                    explore=explore,
                    shared_data=self._shared_data,
                )
                self._cached_to_module = None

                # MultiRLModule forward pass: Explore or not.
                if explore:
                    env_steps_lifetime = (
                        self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
                        + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0)
                    ) * (self.config.num_env_runners or 1)
                    to_env = self.module.forward_exploration(
                        to_module, t=env_steps_lifetime
                    )
                else:
                    to_env = self.module.forward_inference(to_module)

                # Module-to-env connector.
                to_env = self._module_to_env(
                    rl_module=self.module,
                    batch=to_env,
                    episodes=[self._episode],
                    explore=explore,
                    shared_data=self._shared_data,
                )

            # Extract the (vectorized) actions (to be sent to the env) from the
            # module/connector output. Note that these actions are fully ready
            # (e.g. already unsquashed/clipped) to be sent to the environment and
            # might not be identical to the actions produced by the
            # RLModule/distribution, which are the ones stored permanently in the
            # episode objects.
            actions = to_env.pop(Columns.ACTIONS)
            actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)

            # Try stepping the environment.
            # TODO (sven): [0] = actions is vectorized, but env is NOT a vector
            #  Env. Support vectorized multi-agent envs.
            results = self._try_env_step(actions_for_env[0])
            # If any failure occurs during stepping -> Throw away all data
            # collected thus far and restart sampling procedure.
            if results == ENV_STEP_FAILURE:
                return self._sample_timesteps(
                    num_timesteps=num_timesteps,
                    explore=explore,
                    random_actions=random_actions,
                    force_reset=True,
                )
            obs, rewards, terminateds, truncateds, infos = results

            # TODO (sven): This simple approach to re-map `to_env` from a
            #  dict[col, List[MADict]] to a dict[agentID, MADict] would not work
            #  for a vectorized env.
            extra_model_outputs = defaultdict(dict)
            for col, ma_dict_list in to_env.items():
                # TODO (sven): Support vectorized MA env.
                ma_dict = ma_dict_list[0]
                for agent_id, val in ma_dict.items():
                    extra_model_outputs[agent_id][col] = val
                    extra_model_outputs[agent_id][WEIGHTS_SEQ_NO] = self._weights_seq_no
            extra_model_outputs = dict(extra_model_outputs)

            # Record the timestep in the episode instance.
            self._episode.add_env_step(
                obs,
                actions[0],
                rewards,
                infos=infos,
                terminateds=terminateds,
                truncateds=truncateds,
                extra_model_outputs=extra_model_outputs,
            )
            ts += self._increase_sampled_metrics(self.num_envs, obs, self._episode)

            # Make the `on_episode_step` callback (before finalizing the episode
            # object).
            self._make_on_episode_callback("on_episode_step")

            # Episode is done for all agents. Wrap up the old one and create a new
            # one (and reset it) to continue.
            if self._episode.is_done:
                # We have to perform an extra env-to-module pass here, just in case
                # the user's connector pipeline performs (permanent) transforms
                # on each observation (including this final one here). Without such
                # a call and in case the structure of the observations change
                # sufficiently, the following `finalize()` call on the episode will
                # fail.
                if self.module is not None:
                    self._env_to_module(
                        episodes=[self._episode],
                        explore=explore,
                        rl_module=self.module,
                        shared_data=self._shared_data,
                    )

                # Make the `on_episode_end` callback (before finalizing the
                # episode, but after(!) the last env-to-module connector call has
                # been made.
                # -> All obs (even the terminal one) should have been processed
                # now (by the connector, if applicable).
                self._make_on_episode_callback("on_episode_end")

                # Finalize (numpy'ize) the episode.
                self._episode.finalize(drop_zero_len_single_agent_episodes=True)
                done_episodes_to_return.append(self._episode)

                # Create a new episode instance.
                self._episode = self._new_episode()
                self._make_on_episode_callback("on_episode_created")

                # Reset the environment.
                obs, infos = self._try_env_reset()
                # Add initial observations and infos.
                self._episode.add_env_reset(observations=obs, infos=infos)

                # Make the `on_episode_start` callback.
                self._make_on_episode_callback("on_episode_start")

        # Already perform env-to-module connector call for next call to
        # `_sample_timesteps()`. See comment in c'tor for `self._cached_to_module`.
        if self.module is not None:
            self._cached_to_module = self._env_to_module(
                rl_module=self.module,
                episodes=[self._episode],
                explore=explore,
                shared_data=self._shared_data,
            )

        # Store done episodes for metrics.
        self._done_episodes_for_metrics.extend(done_episodes_to_return)

        # Also, make sure we start new episode chunks (continuing the ongoing
        # episodes from the to-be-returned chunks).
        ongoing_episode_continuation = self._episode.cut(
            len_lookback_buffer=self.config.episode_lookback_horizon
        )

        ongoing_episodes_to_return = []
        # Just started Episodes do not have to be returned. There is no data
        # in them anyway.
        if self._episode.env_t > 0:
            self._episode.validate()
            self._ongoing_episodes_for_metrics[self._episode.id_].append(
                self._episode
            )
            # Return finalized (numpy'ized) Episodes.
            ongoing_episodes_to_return.append(
                self._episode.finalize(drop_zero_len_single_agent_episodes=True)
            )

        # Continue collecting into the cut Episode chunk.
        self._episode = ongoing_episode_continuation

        # Return collected episode data.
        return done_episodes_to_return + ongoing_episodes_to_return

    def _sample_episodes(
        self,
        num_episodes: int,
        explore: bool,
        random_actions: bool = False,
    ) -> List[MultiAgentEpisode]:
        """Helper method to run n episodes.

        See docstring of `self.sample()` for more details.
        """
        # If user calls sample(num_timesteps=..) after this, we must reset again
        # at the beginning.
        self._needs_initial_reset = True

        done_episodes_to_return: List[MultiAgentEpisode] = []

        # Create a new multi-agent episode.
        _episode = self._new_episode()
        self._make_on_episode_callback("on_episode_created", _episode)
        _shared_data = {
            "agent_to_module_mapping_fn": self.config.policy_mapping_fn,
        }

        # Try resetting the environment.
        # TODO (simon): Check, if we need here the seed from the config.
        obs, infos = self._try_env_reset()

        # Set initial obs and infos in the episodes.
        _episode.add_env_reset(observations=obs, infos=infos)
        self._make_on_episode_callback("on_episode_start", _episode)

        # Loop over episodes.
        eps = 0
        ts = 0
        while eps < num_episodes:
            # Act randomly.
            if random_actions:
                # Only act (randomly) for those agents that had an observation.
                to_env = {
                    Columns.ACTIONS: [
                        {
                            aid: self.env.unwrapped.get_action_space(aid).sample()
                            for aid in _episode.get_agents_to_act()
                        }
                    ]
                }
            # Compute an action using the RLModule.
            else:
                # Env-to-module connector.
                to_module = self._env_to_module(
                    rl_module=self.module,
                    episodes=[_episode],
                    explore=explore,
                    shared_data=_shared_data,
                )

                # MultiRLModule forward pass: Explore or not.
                if explore:
                    env_steps_lifetime = (
                        self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
                        + self.metrics.peek(NUM_ENV_STEPS_SAMPLED, default=0)
                    ) * (self.config.num_env_runners or 1)
                    to_env = self.module.forward_exploration(
                        to_module, t=env_steps_lifetime
                    )
                else:
                    to_env = self.module.forward_inference(to_module)

                # Module-to-env connector.
                to_env = self._module_to_env(
                    rl_module=self.module,
                    batch=to_env,
                    episodes=[_episode],
                    explore=explore,
                    shared_data=_shared_data,
                )

            # Extract the (vectorized) actions (to be sent to the env) from the
            # module/connector output. Note that these actions are fully ready
            # (e.g. already unsquashed/clipped) to be sent to the environment and
            # might not be identical to the actions produced by the
            # RLModule/distribution, which are the ones stored permanently in the
            # episode objects.
            actions = to_env.pop(Columns.ACTIONS)
            actions_for_env = to_env.pop(Columns.ACTIONS_FOR_ENV, actions)

            # Try stepping the environment.
            # TODO (sven): [0] = actions is vectorized, but env is NOT a vector
            #  Env. Support vectorized multi-agent envs.
            results = self._try_env_step(actions_for_env[0])
            # If any failure occurs during stepping -> Throw away all data
            # collected thus far and restart sampling procedure.
            if results == ENV_STEP_FAILURE:
                return self._sample_episodes(
                    num_episodes=num_episodes,
                    explore=explore,
                    random_actions=random_actions,
                )
            obs, rewards, terminateds, truncateds, infos = results

            # TODO (sven): This simple approach to re-map `to_env` from a
            #  dict[col, List[MADict]] to a dict[agentID, MADict] would not work
            #  for a vectorized env.
            extra_model_outputs = defaultdict(dict)
            for col, ma_dict_list in to_env.items():
                # TODO (sven): Support vectorized MA env.
                ma_dict = ma_dict_list[0]
                for agent_id, val in ma_dict.items():
                    extra_model_outputs[agent_id][col] = val
                    extra_model_outputs[agent_id][WEIGHTS_SEQ_NO] = self._weights_seq_no
            extra_model_outputs = dict(extra_model_outputs)

            # Record the timestep in the episode instance.
            _episode.add_env_step(
                obs,
                actions[0],
                rewards,
                infos=infos,
                terminateds=terminateds,
                truncateds=truncateds,
                extra_model_outputs=extra_model_outputs,
            )
            ts += self._increase_sampled_metrics(self.num_envs, obs, _episode)

            # Make `on_episode_step` callback before finalizing the episode.
            self._make_on_episode_callback("on_episode_step", _episode)

            # TODO (sven, simon): We have to check, if we need this elaborate
            #  function here or if the `MultiAgentEnv` defines the cases that
            #  can happen.
            # Right now we have:
            #   1. Most times only agents that step get `terminated`, `truncated`
            #      i.e. the rest we have to check in the episode.
            #   2. There are edge cases like, some agents terminated, all others
            #      truncated and vice versa.
            # See also `MultiAgentEpisode` for handling the `__all__`.
            if _episode.is_done:
                # Increase episode count.
                eps += 1

                # We have to perform an extra env-to-module pass here, just in case
                # the user's connector pipeline performs (permanent) transforms
                # on each observation (including this final one here). Without such
                # a call and in case the structure of the observations change
                # sufficiently, the following `finalize()` call on the episode will
                # fail.
                if self.module is not None:
                    self._env_to_module(
                        episodes=[_episode],
                        explore=explore,
                        rl_module=self.module,
                        shared_data=_shared_data,
                    )

                # Make the `on_episode_end` callback (before finalizing the
                # episode, but after(!) the last env-to-module connector call has
                # been made.
                # -> All obs (even the terminal one) should have been processed
                # now (by the connector, if applicable).
                self._make_on_episode_callback("on_episode_end", _episode)

                # Finish the episode.
                done_episodes_to_return.append(
                    _episode.finalize(drop_zero_len_single_agent_episodes=True)
                )

                # Also early-out if we reach the number of episodes within this
                # loop.
                if eps == num_episodes:
                    break

                # Create a new episode instance.
                _episode = self._new_episode()
                self._make_on_episode_callback("on_episode_created", _episode)

                # Try resetting the environment.
                obs, infos = self._try_env_reset()
                # Add initial observations and infos.
                _episode.add_env_reset(observations=obs, infos=infos)

                # Make `on_episode_start` callback.
                self._make_on_episode_callback("on_episode_start", _episode)

        self._done_episodes_for_metrics.extend(done_episodes_to_return)

        return done_episodes_to_return

    @override(EnvRunner)
    def get_spaces(self):
        # Return the already agent-to-module translated spaces from our connector
        # pipeline.
        return {
            **{
                mid: (o, self._env_to_module.action_space[mid])
                for mid, o in self._env_to_module.observation_space.spaces.items()
            },
        }

    @override(EnvRunner)
    def get_metrics(self) -> ResultDict:
        # Compute per-episode metrics (only on already completed episodes).
        for eps in self._done_episodes_for_metrics:
            assert eps.is_done
            episode_length = len(eps)
            agent_steps = defaultdict(
                int,
                {str(aid): len(sa_eps) for aid, sa_eps in eps.agent_episodes.items()},
            )
            episode_return = eps.get_return()
            episode_duration_s = eps.get_duration_s()

            agent_episode_returns = defaultdict(
                float,
                {
                    str(sa_eps.agent_id): sa_eps.get_return()
                    for sa_eps in eps.agent_episodes.values()
                },
            )
            module_episode_returns = defaultdict(
                float,
                {
                    sa_eps.module_id: sa_eps.get_return()
                    for sa_eps in eps.agent_episodes.values()
                },
            )

            # Don't forget about the already returned chunks of this episode.
            if eps.id_ in self._ongoing_episodes_for_metrics:
                for eps2 in self._ongoing_episodes_for_metrics[eps.id_]:
                    return_eps2 = eps2.get_return()
                    episode_length += len(eps2)
                    episode_return += return_eps2
                    episode_duration_s += eps2.get_duration_s()
                    for sa_eps in eps2.agent_episodes.values():
                        return_sa = sa_eps.get_return()
                        agent_steps[str(sa_eps.agent_id)] += len(sa_eps)
                        agent_episode_returns[str(sa_eps.agent_id)] += return_sa
                        module_episode_returns[sa_eps.module_id] += return_sa
                del self._ongoing_episodes_for_metrics[eps.id_]

            self._log_episode_metrics(
                episode_length,
                episode_return,
                episode_duration_s,
                agent_episode_returns,
                module_episode_returns,
                dict(agent_steps),
            )

        # Now that we have logged everything, clear cache of done episodes.
        self._done_episodes_for_metrics.clear()

        # Return reduced metrics.
        return self.metrics.reduce()

    @override(Checkpointable)
    def get_state(
        self,
        components: Optional[Union[str, Collection[str]]] = None,
        *,
        not_components: Optional[Union[str, Collection[str]]] = None,
        **kwargs,
    ) -> StateDict:
        # Basic state dict.
        state = {
            NUM_ENV_STEPS_SAMPLED_LIFETIME: (
                self.metrics.peek(NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0)
            ),
        }

        # RLModule (MultiRLModule) component.
        if self._check_component(COMPONENT_RL_MODULE, components, not_components):
            state[COMPONENT_RL_MODULE] = self.module.get_state(
                components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
                not_components=self._get_subcomponents(
                    COMPONENT_RL_MODULE, not_components
                ),
                **kwargs,
            )
            state[WEIGHTS_SEQ_NO] = self._weights_seq_no

        # Env-to-module connector.
        if self._check_component(
            COMPONENT_ENV_TO_MODULE_CONNECTOR, components, not_components
        ):
            state[COMPONENT_ENV_TO_MODULE_CONNECTOR] = self._env_to_module.get_state()
        # Module-to-env connector.
        if self._check_component(
            COMPONENT_MODULE_TO_ENV_CONNECTOR, components, not_components
        ):
            state[COMPONENT_MODULE_TO_ENV_CONNECTOR] = self._module_to_env.get_state()

        return state

    @override(Checkpointable)
    def set_state(self, state: StateDict) -> None:
        if COMPONENT_ENV_TO_MODULE_CONNECTOR in state:
            self._env_to_module.set_state(state[COMPONENT_ENV_TO_MODULE_CONNECTOR])
        if COMPONENT_MODULE_TO_ENV_CONNECTOR in state:
            self._module_to_env.set_state(state[COMPONENT_MODULE_TO_ENV_CONNECTOR])

        # Update RLModule state.
        if COMPONENT_RL_MODULE in state:
            # A missing value for WEIGHTS_SEQ_NO or a value of 0 means: Force the
            # update.
            weights_seq_no = state.get(WEIGHTS_SEQ_NO, 0)

            # Only update the weights if this is the first synchronization or
            # if the weights of this `EnvRunner` lag behind the actual ones.
            if weights_seq_no == 0 or self._weights_seq_no < weights_seq_no:
                self.module.set_state(state[COMPONENT_RL_MODULE])

            # Update weights_seq_no, if the new one is > 0.
            if weights_seq_no > 0:
                self._weights_seq_no = weights_seq_no

        # Update lifetime counters.
        if NUM_ENV_STEPS_SAMPLED_LIFETIME in state:
            self.metrics.set_value(
                key=NUM_ENV_STEPS_SAMPLED_LIFETIME,
                value=state[NUM_ENV_STEPS_SAMPLED_LIFETIME],
                reduce="sum",
                with_throughput=True,
            )

    @override(Checkpointable)
    def get_ctor_args_and_kwargs(self):
        return (
            (),  # *args
            {"config": self.config},  # **kwargs
        )

    @override(Checkpointable)
    def get_metadata(self):
        metadata = Checkpointable.get_metadata(self)
        metadata.update(
            {
                # TODO (sven): Maybe add serialized (JSON-writable) config here?
            }
        )
        return metadata

    @override(Checkpointable)
    def get_checkpointable_components(self):
        return [
            (COMPONENT_RL_MODULE, self.module),
            (COMPONENT_ENV_TO_MODULE_CONNECTOR, self._env_to_module),
            (COMPONENT_MODULE_TO_ENV_CONNECTOR, self._module_to_env),
        ]

    @override(EnvRunner)
    def assert_healthy(self):
        """Checks that self.__init__() has been completed properly.

        Ensures that the instance has a `MultiRLModule` and an environment
        defined.

        Raises:
            AssertionError: If the EnvRunner Actor has NOT been properly
                initialized.
        """
        # Make sure we have built our env and RLModule properly.
        assert self.env and self.module
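
    # Example (illustrative only): `get_state()`/`set_state()` are what the
    # Algorithm uses to broadcast new weights to its EnvRunner actors. A manual
    # weights sync from a Learner-side MultiRLModule could look roughly like:
    #
    #     state = {
    #         COMPONENT_RL_MODULE: multi_rl_module.get_state(),
    #         WEIGHTS_SEQ_NO: seq_no,  # monotonically increasing update counter
    #     }
    #     env_runner.set_state(state)
    #
    # `multi_rl_module` and the `seq_no` bookkeeping are assumptions for
    # illustration; `set_state()` only applies the module state if the given
    # WEIGHTS_SEQ_NO is 0 or newer than this runner's current one.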

    @override(EnvRunner)
    def make_env(self):
        # If an env already exists, try closing it first (to allow it to properly
        # cleanup).
        if self.env is not None:
            try:
                self.env.close()
            except Exception as e:
                logger.warning(
                    "Tried closing the existing env (multi-agent), but failed with "
                    f"error: {e.args[0]}"
                )
            del self.env

        env_ctx = self.config.env_config
        if not isinstance(env_ctx, EnvContext):
            env_ctx = EnvContext(
                env_ctx,
                worker_index=self.worker_index,
                num_workers=self.config.num_env_runners,
                remote=self.config.remote_worker_envs,
            )

        # No env provided -> Error.
        if not self.config.env:
            raise ValueError(
                "`config.env` is not provided! You should provide a valid environment "
                "to your config through `config.environment([env descriptor e.g. "
                "'CartPole-v1'])`."
            )
        # Register env for the local context.
        # Note, `gym.register` has to be called on each worker.
        elif isinstance(self.config.env, str) and _global_registry.contains(
            ENV_CREATOR, self.config.env
        ):
            entry_point = partial(
                _global_registry.get(ENV_CREATOR, self.config.env),
                env_ctx,
            )
        else:
            entry_point = partial(
                _gym_env_creator,
                env_descriptor=self.config.env,
                env_context=env_ctx,
            )

        gym.register(
            "rllib-multi-agent-env-v0",
            entry_point=entry_point,
            disable_env_checker=True,
        )

        # Perform the actual gym.make call.
        self.env: MultiAgentEnv = gym.make("rllib-multi-agent-env-v0")

        self.num_envs = 1

        # If required, check the created MultiAgentEnv.
        if not self.config.disable_env_checking:
            try:
                check_multiagent_environments(self.env.unwrapped)
            except Exception as e:
                logger.exception(e.args[0])
        # If not required, still check the type (must be MultiAgentEnv).
        else:
            assert isinstance(self.env.unwrapped, MultiAgentEnv), (
                "ERROR: When using the `MultiAgentEnvRunner` the environment needs "
                "to inherit from `ray.rllib.env.multi_agent_env.MultiAgentEnv`."
            )

        # Set the flag to reset all envs upon the next `sample()` call.
        self._needs_initial_reset = True

        # Call the `on_environment_created` callback.
        make_callback(
            "on_environment_created",
            callbacks_objects=self._callbacks,
            callbacks_functions=self.config.callbacks_on_environment_created,
            kwargs=dict(
                env_runner=self,
                metrics_logger=self.metrics,
                env=self.env.unwrapped,
                env_context=env_ctx,
            ),
        )
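
    # Example (illustrative only): for `make_env()` to resolve a string env
    # descriptor via Ray Tune's registry, the multi-agent env has to be
    # registered on every worker process, e.g.:
    #
    #     from ray.tune.registry import register_env
    #     register_env("my_ma_env", lambda ctx: MyMultiAgentEnv(ctx))
    #
    # `MyMultiAgentEnv` is a placeholder for a user-defined `MultiAgentEnv`
    # subclass; the creator lambda receives the `EnvContext` built from
    # `config.env_config`.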

    @override(EnvRunner)
    def make_module(self):
        try:
            module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec(
                env=self.env.unwrapped,
                spaces=self.get_spaces(),
                inference_only=True,
            )
            # Build the module from its spec.
            self.module = module_spec.build()
            # Move the RLModule to our device.
            # TODO (sven): In order to make this framework-agnostic, we should
            #  maybe make the MultiRLModule.build() method accept a device OR
            #  create an additional `(Multi)RLModule.to()` override.
            if torch:
                self.module.foreach_module(
                    lambda mid, mod: (
                        mod.to(self._device)
                        if isinstance(mod, torch.nn.Module)
                        else mod
                    )
                )
        # If `AlgorithmConfig.get_rl_module_spec()` is not implemented, this env
        # runner will not have an RLModule, but might still be usable with random
        # actions.
        except NotImplementedError:
            self.module = None

    @override(EnvRunner)
    def stop(self):
        # Note, `MultiAgentEnv` inherits the `close()`-method from `gym.Env`.
        self.env.close()

    def _setup_metrics(self):
        self.metrics = MetricsLogger()

        self._done_episodes_for_metrics: List[MultiAgentEpisode] = []
        self._ongoing_episodes_for_metrics: DefaultDict[
            EpisodeID, List[MultiAgentEpisode]
        ] = defaultdict(list)

    def _new_episode(self):
        return MultiAgentEpisode(
            observation_space={
                aid: self.env.unwrapped.get_observation_space(aid)
                for aid in self.env.unwrapped.possible_agents
            },
            action_space={
                aid: self.env.unwrapped.get_action_space(aid)
                for aid in self.env.unwrapped.possible_agents
            },
            agent_to_module_mapping_fn=self.config.policy_mapping_fn,
        )

    def _make_on_episode_callback(self, which: str, episode=None):
        episode = episode if episode is not None else self._episode
        make_callback(
            which,
            callbacks_objects=self._callbacks,
            callbacks_functions=getattr(self.config, f"callbacks_{which}"),
            kwargs=dict(
                episode=episode,
                env_runner=self,
                metrics_logger=self.metrics,
                env=self.env.unwrapped,
                rl_module=self.module,
                env_index=0,
            ),
        )

    def _increase_sampled_metrics(self, num_steps, next_obs, episode):
        # Env steps.
        self.metrics.log_value(
            NUM_ENV_STEPS_SAMPLED, num_steps, reduce="sum", clear_on_reduce=True
        )
        self.metrics.log_value(
            NUM_ENV_STEPS_SAMPLED_LIFETIME,
            num_steps,
            reduce="sum",
            with_throughput=True,
        )
        # Completed episodes.
        if episode.is_done:
            self.metrics.log_value(NUM_EPISODES, 1, reduce="sum", clear_on_reduce=True)
            self.metrics.log_value(NUM_EPISODES_LIFETIME, 1, reduce="sum")

        # TODO (sven): obs is not-vectorized. Support vectorized MA envs.
        for aid in next_obs:
            self.metrics.log_value(
                (NUM_AGENT_STEPS_SAMPLED, str(aid)),
                1,
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_AGENT_STEPS_SAMPLED_LIFETIME, str(aid)),
                1,
                reduce="sum",
            )
            self.metrics.log_value(
                (NUM_MODULE_STEPS_SAMPLED, episode.module_for(aid)),
                1,
                reduce="sum",
                clear_on_reduce=True,
            )
            self.metrics.log_value(
                (NUM_MODULE_STEPS_SAMPLED_LIFETIME, episode.module_for(aid)),
                1,
                reduce="sum",
            )
        return num_steps

    def _log_episode_metrics(
        self,
        length,
        ret,
        sec,
        agents=None,
        modules=None,
        agent_steps=None,
    ):
        # Log general episode metrics.
        self.metrics.log_dict(
            {
                EPISODE_LEN_MEAN: length,
                EPISODE_RETURN_MEAN: ret,
                EPISODE_DURATION_SEC_MEAN: sec,
                **(
                    {
                        # Per-agent returns.
                        "agent_episode_returns_mean": agents,
                        # Per-RLModule returns.
                        "module_episode_returns_mean": modules,
                        "agent_steps": agent_steps,
                    }
                    if agents is not None
                    else {}
                ),
            },
            # To mimic the old API stack behavior, we'll use `window` here for
            # these particular stats (instead of the default EMA).
            window=self.config.metrics_num_episodes_for_smoothing,
        )
        # For some metrics, log min/max as well.
        self.metrics.log_dict(
            {
                EPISODE_LEN_MIN: length,
                EPISODE_RETURN_MIN: ret,
            },
            reduce="min",
            window=self.config.metrics_num_episodes_for_smoothing,
        )
        self.metrics.log_dict(
            {
                EPISODE_LEN_MAX: length,
                EPISODE_RETURN_MAX: ret,
            },
            reduce="max",
            window=self.config.metrics_num_episodes_for_smoothing,
        )

    @Deprecated(
        new="MultiAgentEnvRunner.get_state(components='rl_module')",
        error=False,
    )
    def get_weights(self, modules=None):
        rl_module_state = self.get_state(components=COMPONENT_RL_MODULE)[
            COMPONENT_RL_MODULE
        ]
        return rl_module_state

    @Deprecated(new="MultiAgentEnvRunner.set_state()", error=False)
    def set_weights(
        self,
        weights: ModelWeights,
        global_vars: Optional[Dict] = None,
        weights_seq_no: int = 0,
    ) -> None:
        assert global_vars is None
        return self.set_state(
            {
                COMPONENT_RL_MODULE: weights,
                WEIGHTS_SEQ_NO: weights_seq_no,
            }
        )
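

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows how a
# `MultiAgentEnvRunner` might be constructed and used standalone. The PPO config
# and the `MultiAgentCartPole` example env are assumptions for illustration and
# their import paths may differ between Ray versions. On older Ray versions, the
# new API stack may also need to be enabled explicitly on the config.
if __name__ == "__main__":
    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole

    config = (
        PPOConfig()
        .environment(MultiAgentCartPole, env_config={"num_agents": 2})
        .multi_agent(
            policies={"p0", "p1"},
            # Map agent 0 -> "p0", agent 1 -> "p1".
            policy_mapping_fn=lambda agent_id, episode, **kw: f"p{agent_id}",
        )
    )
    runner = MultiAgentEnvRunner(config=config)
    episodes = runner.sample(num_timesteps=100)
    print(f"Sampled {len(episodes)} episode chunk(s).")
    print(runner.get_metrics())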