import logging
import queue
import time
from abc import ABCMeta, abstractmethod
from collections import defaultdict, namedtuple
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
Type,
Union,
)
import numpy as np
import tree # pip install dm_tree
from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv, convert_to_base_env
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
from ray.rllib.evaluation.env_runner_v2 import (
EnvRunnerV2,
_fetch_atari_metrics,
_get_or_raise,
_PerfStats,
)
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.offline import InputReader
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.utils.annotations import OldAPIStack, override
from ray.rllib.utils.debug import summarize
from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import convert_to_numpy, make_action_immutable
from ray.rllib.utils.spaces.space_utils import clip_action, unbatch, unsquash_action
from ray.rllib.utils.typing import (
AgentID,
EnvActionType,
EnvID,
EnvInfoDict,
EnvObsType,
MultiEnvDict,
PolicyID,
SampleBatchType,
TensorStructType,
)
from ray.util.debug import log_once
if TYPE_CHECKING:
from gymnasium.envs.classic_control.rendering import SimpleImageViewer
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.evaluation.observation_function import ObservationFunction
from ray.rllib.evaluation.rollout_worker import RolloutWorker
tf1, tf, _ = try_import_tf()
logger = logging.getLogger(__name__)
_PolicyEvalData = namedtuple(
"_PolicyEvalData",
["env_id", "agent_id", "obs", "info", "rnn_state", "prev_action", "prev_reward"],
)
# A batch of RNN states with dimensions [state_index, batch, state_object].
StateBatch = List[List[Any]]
class _NewEpisodeDefaultDict(defaultdict):
def __missing__(self, env_id):
if self.default_factory is None:
raise KeyError(env_id)
else:
ret = self[env_id] = self.default_factory(env_id)
return ret
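# NOTE: `SyncSampler` below subclasses `SamplerInput`, whose definition is not
# included in this excerpt. The following is a minimal reconstruction, inferred
# from the `@override(SamplerInput)` methods further down; the concrete `next()`
# body is an assumption.
@OldAPIStack
class SamplerInput(InputReader, metaclass=ABCMeta):
    """Reads input experiences from an existing sampler."""

    @override(InputReader)
    def next(self) -> SampleBatchType:
        # Return the next collected batch, merged with any extra batches.
        batches = [self.get_data()]
        batches.extend(self.get_extra_batches())
        if len(batches) > 1:
            return concat_samples(batches)
        else:
            return batches[0]

    @abstractmethod
    def get_data(self) -> SampleBatchType:
        """Returns the next batch of experiences collected by this sampler."""
        raise NotImplementedError

    @abstractmethod
    def get_metrics(self) -> List[RolloutMetrics]:
        """Returns metrics of episodes completed since the last call."""
        raise NotImplementedError

    @abstractmethod
    def get_extra_batches(self) -> List[SampleBatchType]:
        """Returns extra batches accumulated since the last call."""
        raise NotImplementedError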
@OldAPIStack
class SyncSampler(SamplerInput):
"""Sync SamplerInput that collects experiences when `get_data()` is called."""
    def __init__(
self,
*,
worker: "RolloutWorker",
env: BaseEnv,
clip_rewards: Union[bool, float],
rollout_fragment_length: int,
count_steps_by: str = "env_steps",
callbacks: "DefaultCallbacks",
multiple_episodes_in_batch: bool = False,
normalize_actions: bool = True,
clip_actions: bool = False,
observation_fn: Optional["ObservationFunction"] = None,
sample_collector_class: Optional[Type[SampleCollector]] = None,
render: bool = False,
# Obsolete.
policies=None,
policy_mapping_fn=None,
preprocessors=None,
obs_filters=None,
tf_sess=None,
horizon=DEPRECATED_VALUE,
soft_horizon=DEPRECATED_VALUE,
no_done_at_end=DEPRECATED_VALUE,
):
"""Initializes a SyncSampler instance.
Args:
worker: The RolloutWorker that will use this Sampler for sampling.
env: Any Env object. Will be converted into an RLlib BaseEnv.
clip_rewards: True for +/-1.0 clipping,
actual float value for +/- value clipping. False for no
clipping.
rollout_fragment_length: The length of a fragment to collect
before building a SampleBatch from the data and resetting
the SampleBatchBuilder object.
count_steps_by: One of "env_steps" (default) or "agent_steps".
Use "agent_steps", if you want rollout lengths to be counted
by individual agent steps. In a multi-agent env,
a single env_step contains one or more agent_steps, depending
on how many agents are present at any given time in the
ongoing episode.
callbacks: The Callbacks object to use when episode
events happen during rollout.
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be
exactly `rollout_fragment_length` in size.
normalize_actions: Whether to normalize actions to the
action space's bounds.
clip_actions: Whether to clip actions according to the
given action_space's bounds.
observation_fn: Optional multi-agent observation func to use for
preprocessing observations.
sample_collector_class: An optional SampleCollector sub-class to
use to collect, store, and retrieve environment-, model-,
and sampler data.
render: Whether to try to render the environment after each step.
"""
# All of the following arguments are deprecated. They will instead be
# provided via the passed in `worker` arg, e.g. `worker.policy_map`.
if log_once("deprecated_sync_sampler_args"):
if policies is not None:
deprecation_warning(old="policies")
if policy_mapping_fn is not None:
deprecation_warning(old="policy_mapping_fn")
if preprocessors is not None:
deprecation_warning(old="preprocessors")
if obs_filters is not None:
deprecation_warning(old="obs_filters")
if tf_sess is not None:
deprecation_warning(old="tf_sess")
if horizon != DEPRECATED_VALUE:
deprecation_warning(old="horizon", error=True)
if soft_horizon != DEPRECATED_VALUE:
deprecation_warning(old="soft_horizon", error=True)
if no_done_at_end != DEPRECATED_VALUE:
deprecation_warning(old="no_done_at_end", error=True)
self.base_env = convert_to_base_env(env)
self.rollout_fragment_length = rollout_fragment_length
self.extra_batches = queue.Queue()
self.perf_stats = _PerfStats(
ema_coef=worker.config.sampler_perf_stats_ema_coef,
)
if not sample_collector_class:
sample_collector_class = SimpleListCollector
self.sample_collector = sample_collector_class(
worker.policy_map,
clip_rewards,
callbacks,
multiple_episodes_in_batch,
rollout_fragment_length,
count_steps_by=count_steps_by,
)
self.render = render
if worker.config.enable_connectors:
# Keep a reference to the underlying EnvRunnerV2 instance for
# unit testing purpose.
self._env_runner_obj = EnvRunnerV2(
worker=worker,
base_env=self.base_env,
multiple_episodes_in_batch=multiple_episodes_in_batch,
callbacks=callbacks,
perf_stats=self.perf_stats,
rollout_fragment_length=rollout_fragment_length,
count_steps_by=count_steps_by,
render=self.render,
)
self._env_runner = self._env_runner_obj.run()
else:
# Create the rollout generator to use for calls to `get_data()`.
self._env_runner = _env_runner(
worker,
self.base_env,
self.extra_batches.put,
normalize_actions,
clip_actions,
multiple_episodes_in_batch,
callbacks,
self.perf_stats,
observation_fn,
self.sample_collector,
self.render,
)
self.metrics_queue = queue.Queue()
@override(SamplerInput)
def get_data(self) -> SampleBatchType:
while True:
item = next(self._env_runner)
if isinstance(item, RolloutMetrics):
self.metrics_queue.put(item)
else:
return item
@override(SamplerInput)
def get_metrics(self) -> List[RolloutMetrics]:
completed = []
while True:
try:
completed.append(
self.metrics_queue.get_nowait()._replace(
perf_stats=self.perf_stats.get()
)
)
except queue.Empty:
break
return completed
@override(SamplerInput)
def get_extra_batches(self) -> List[SampleBatchType]:
extra = []
while True:
try:
extra.append(self.extra_batches.get_nowait())
except queue.Empty:
break
return extra
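# Example usage (sketch only; `worker` and `env` are assumed to be an already
# built RolloutWorker and its environment, and `worker.callbacks` is assumed to
# hold the configured DefaultCallbacks object):
#
#     sampler = SyncSampler(
#         worker=worker,
#         env=env,
#         clip_rewards=False,
#         rollout_fragment_length=200,
#         callbacks=worker.callbacks,
#     )
#     batch = sampler.get_data()        # Blocks until one fragment is ready.
#     metrics = sampler.get_metrics()   # RolloutMetrics of completed episodes.
#     extra = sampler.get_extra_batches()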
@OldAPIStack
def _env_runner(
worker: "RolloutWorker",
base_env: BaseEnv,
extra_batch_callback: Callable[[SampleBatchType], None],
normalize_actions: bool,
clip_actions: bool,
multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
perf_stats: _PerfStats,
observation_fn: "ObservationFunction",
sample_collector: Optional[SampleCollector] = None,
    render: Optional[bool] = None,
) -> Iterator[SampleBatchType]:
"""This implements the common experience collection logic.
Args:
worker: Reference to the current rollout worker.
base_env: Env implementing BaseEnv.
        extra_batch_callback: Function to send extra batch data to.
        normalize_actions: Whether to normalize actions to the action
            space's bounds.
        clip_actions: Whether to clip actions to the space range.
        multiple_episodes_in_batch: Whether to pack multiple
            episodes into each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
callbacks: User callbacks to run on episode events.
perf_stats: Record perf stats into this object.
observation_fn: Optional multi-agent
observation func to use for preprocessing observations.
sample_collector: An optional
SampleCollector object to use.
render: Whether to try to render the environment after each
step.
Yields:
        Batches of collected experiences (SampleBatchType) and RolloutMetrics
        objects for completed episodes.
"""
    # May be populated with a SimpleImageViewer, which is used for image rendering.
simple_image_viewer: Optional["SimpleImageViewer"] = None
def _new_episode(env_id):
episode = Episode(
worker.policy_map,
worker.policy_mapping_fn,
# SimpleListCollector will find or create a
# simple_list_collector._PolicyCollector as batch_builder
# for this episode later. Here we simply provide a None factory.
lambda: None, # batch_builder_factory
extra_batch_callback,
env_id=env_id,
worker=worker,
)
return episode
active_episodes: Dict[EnvID, Episode] = _NewEpisodeDefaultDict(_new_episode)
    # Before the very first poll (this will reset all vector sub-environments):
    # Create an Episode object for each sub-environment and call the
    # `on_episode_created` callbacks.
for env_id, sub_env in base_env.get_sub_environments(as_dict=True).items():
_create_episode(active_episodes, env_id, callbacks, worker, base_env)
while True:
perf_stats.incr("iters", 1)
t0 = time.time()
# Get observations from all ready agents.
# types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
(
unfiltered_obs,
rewards,
terminateds,
truncateds,
infos,
off_policy_actions,
) = base_env.poll()
env_poll_time = time.time() - t0
if log_once("env_returns"):
logger.info("Raw obs from env: {}".format(summarize(unfiltered_obs)))
logger.info("Info return from env: {}".format(summarize(infos)))
# Process observations and prepare for policy evaluation.
t1 = time.time()
# types: Set[EnvID], Dict[PolicyID, List[_PolicyEvalData]],
# List[Union[RolloutMetrics, SampleBatchType]]
active_envs, to_eval, outputs = _process_observations(
worker=worker,
base_env=base_env,
active_episodes=active_episodes,
unfiltered_obs=unfiltered_obs,
rewards=rewards,
terminateds=terminateds,
truncateds=truncateds,
infos=infos,
multiple_episodes_in_batch=multiple_episodes_in_batch,
callbacks=callbacks,
observation_fn=observation_fn,
sample_collector=sample_collector,
)
perf_stats.incr("raw_obs_processing_time", time.time() - t1)
for o in outputs:
yield o
        # Do batched policy eval (across vectorized envs).
t2 = time.time()
# types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
eval_results = _do_policy_eval(
to_eval=to_eval,
policies=worker.policy_map,
sample_collector=sample_collector,
active_episodes=active_episodes,
)
perf_stats.incr("inference_time", time.time() - t2)
# Process results and update episode state.
t3 = time.time()
actions_to_send: Dict[
EnvID, Dict[AgentID, EnvActionType]
] = _process_policy_eval_results(
to_eval=to_eval,
eval_results=eval_results,
active_episodes=active_episodes,
active_envs=active_envs,
off_policy_actions=off_policy_actions,
policies=worker.policy_map,
normalize_actions=normalize_actions,
clip_actions=clip_actions,
)
perf_stats.incr("action_processing_time", time.time() - t3)
# Return computed actions to ready envs. We also send to envs that have
# taken off-policy actions; those envs are free to ignore the action.
t4 = time.time()
base_env.send_actions(actions_to_send)
perf_stats.incr("env_wait_time", env_poll_time + time.time() - t4)
# Try to render the env, if required.
if render:
t5 = time.time()
# Render can either return an RGB image (uint8 [w x h x 3] numpy
# array) or take care of rendering itself (returning True).
rendered = base_env.try_render()
# Rendering returned an image -> Display it in a SimpleImageViewer.
if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
# ImageViewer not defined yet, try to create one.
if simple_image_viewer is None:
try:
from gymnasium.envs.classic_control.rendering import (
SimpleImageViewer,
)
simple_image_viewer = SimpleImageViewer()
except (ImportError, ModuleNotFoundError):
render = False # disable rendering
logger.warning(
"Could not import gymnasium.envs.classic_control."
"rendering! Try `pip install gymnasium[all]`."
)
if simple_image_viewer:
simple_image_viewer.imshow(rendered)
elif rendered not in [True, False, None]:
raise ValueError(
f"The env's ({base_env}) `try_render()` method returned an"
" unsupported value! Make sure you either return a "
"uint8/w x h x 3 (RGB) image or handle rendering in a "
"window and then return `True`."
)
perf_stats.incr("env_render_time", time.time() - t5)
@OldAPIStack
def _process_observations(
*,
worker: "RolloutWorker",
base_env: BaseEnv,
active_episodes: Dict[EnvID, Episode],
unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
rewards: Dict[EnvID, Dict[AgentID, float]],
terminateds: Dict[EnvID, Dict[AgentID, bool]],
truncateds: Dict[EnvID, Dict[AgentID, bool]],
infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
multiple_episodes_in_batch: bool,
callbacks: "DefaultCallbacks",
observation_fn: "ObservationFunction",
sample_collector: SampleCollector,
) -> Tuple[
Set[EnvID],
Dict[PolicyID, List[_PolicyEvalData]],
List[Union[RolloutMetrics, SampleBatchType]],
]:
"""Record new data from the environment and prepare for policy evaluation.
Args:
worker: Reference to the current rollout worker.
base_env: Env implementing BaseEnv.
active_episodes: Mapping from
episode ID to currently ongoing Episode object.
unfiltered_obs: Doubly keyed dict of env-ids -> agent ids
-> unfiltered observation tensor, returned by a `BaseEnv.poll()`
call.
rewards: Doubly keyed dict of env-ids -> agent ids ->
rewards tensor, returned by a `BaseEnv.poll()` call.
terminateds: Doubly keyed dict of env-ids -> agent ids ->
boolean `terminated` flags, returned by a `BaseEnv.poll()` call.
truncateds: Doubly keyed dict of env-ids -> agent ids ->
boolean `truncated` flags, returned by a `BaseEnv.poll()` call.
infos: Doubly keyed dict of env-ids -> agent ids ->
info dicts, returned by a `BaseEnv.poll()` call.
multiple_episodes_in_batch: Whether to pack multiple
episodes into each batch. This guarantees batches will be exactly
`rollout_fragment_length` in size.
callbacks: User callbacks to run on episode events.
observation_fn: Optional multi-agent
observation func to use for preprocessing observations.
sample_collector: The SampleCollector object
used to store and retrieve environment samples.
Returns:
Tuple consisting of 1) active_envs: Set of non-terminated env ids.
2) to_eval: Map of policy_id to list of agent _PolicyEvalData.
3) outputs: List of metrics and samples to return from the sampler.
"""
# Output objects.
active_envs: Set[EnvID] = set()
to_eval: Dict[PolicyID, List[_PolicyEvalData]] = defaultdict(list)
outputs: List[Union[RolloutMetrics, SampleBatchType]] = []
# For each (vectorized) sub-environment.
# types: EnvID, Dict[AgentID, EnvObsType]
for env_id, all_agents_obs in unfiltered_obs.items():
episode: Episode = active_episodes[env_id]
# Check for env_id having returned an error instead of a multi-agent obs dict.
# This is how our BaseEnv can tell the caller to `poll()` that one of its
# sub-environments is faulty and should be restarted (and the ongoing episode
# should not be used for training).
if isinstance(all_agents_obs, Exception):
episode.is_faulty = True
assert terminateds[env_id]["__all__"] is True, (
f"ERROR: When a sub-environment (env-id {env_id}) returns an error as "
"observation, the terminateds[__all__] flag must also be set to True!"
)
# This will be filled with dummy observations below.
all_agents_obs = {}
        # Record the agents' latest raw observations and infos (returned by the
        # `poll()` call above) in the episode.
for aid, obs in all_agents_obs.items():
episode._set_last_raw_obs(aid, obs)
common_infos = infos[env_id].get("__common__", {})
episode._set_last_info("__common__", common_infos)
for aid, info in infos[env_id].items():
episode._set_last_info(aid, info)
# Episode is brand new.
if episode.started is False:
# Call the episode start callback(s).
_call_on_episode_start(episode, env_id, callbacks, worker, base_env)
else:
sample_collector.episode_step(episode)
episode._add_agent_rewards(rewards[env_id])
# Check episode termination conditions.
if terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"]:
all_agents_done = True
# Check whether we have to create a fake-last observation
# for some agents (the environment is not required to do so if
# terminateds[__all__]=True or truncateds[__all__]=True).
for ag_id in episode.get_agents():
if (
not episode.last_terminated_for(ag_id)
and not episode.last_truncated_for(ag_id)
and ag_id not in all_agents_obs
):
# Create a fake (all-0s) observation.
obs_sp = worker.policy_map[
episode.policy_for(ag_id)
].observation_space
obs_sp = getattr(obs_sp, "original_space", obs_sp)
all_agents_obs[ag_id] = tree.map_structure(
np.zeros_like, obs_sp.sample()
)
else:
all_agents_done = False
active_envs.add(env_id)
# Custom observation function is applied before preprocessing.
if observation_fn:
all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn(
agent_obs=all_agents_obs,
worker=worker,
base_env=base_env,
policies=worker.policy_map,
episode=episode,
)
if not isinstance(all_agents_obs, dict):
raise ValueError("observe() must return a dict of agent observations")
# For each agent in the environment.
# types: AgentID, EnvObsType
for agent_id, raw_obs in all_agents_obs.items():
assert agent_id != "__all__"
last_observation: EnvObsType = episode.last_observation_for(agent_id)
agent_terminated = bool(
terminateds[env_id]["__all__"] or terminateds[env_id].get(agent_id)
)
agent_truncated = bool(
truncateds[env_id]["__all__"] or truncateds[env_id].get(agent_id, False)
)
# A new agent (initial obs) is already done -> Skip entirely.
if last_observation is None and (agent_terminated or agent_truncated):
continue
policy_id: PolicyID = episode.policy_for(agent_id)
preprocessor = _get_or_raise(worker.preprocessors, policy_id)
prep_obs: EnvObsType = raw_obs
if preprocessor is not None:
prep_obs = preprocessor.transform(raw_obs)
if log_once("prep_obs"):
logger.info("Preprocessed obs: {}".format(summarize(prep_obs)))
filtered_obs: EnvObsType = _get_or_raise(worker.filters, policy_id)(
prep_obs
)
if log_once("filtered_obs"):
logger.info("Filtered obs: {}".format(summarize(filtered_obs)))
episode._set_last_observation(agent_id, filtered_obs)
episode._set_last_terminated(agent_id, agent_terminated)
episode._set_last_truncated(agent_id, agent_truncated)
agent_infos = infos[env_id].get(agent_id, {})
# Record transition info if applicable.
if last_observation is None:
sample_collector.add_init_obs(
episode=episode,
agent_id=agent_id,
env_id=env_id,
policy_id=policy_id,
init_obs=filtered_obs,
init_infos=agent_infos,
t=episode.length - 1,
)
else:
# Add actions, rewards, next-obs to collectors.
values_dict = {
SampleBatch.T: episode.length - 1,
SampleBatch.ENV_ID: env_id,
SampleBatch.AGENT_INDEX: episode._agent_index(agent_id),
# Action (slot 0) taken at timestep t.
SampleBatch.ACTIONS: episode.last_action_for(agent_id),
                    # Reward received after taking action a at timestep t.
SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
# After taking action=a, did we terminate the episode?
SampleBatch.TERMINATEDS: agent_terminated,
# Was the episode truncated artificially
# (e.g. b/c of some time limit)?
SampleBatch.TRUNCATEDS: agent_truncated,
# Next observation.
SampleBatch.NEXT_OBS: filtered_obs,
}
# Add extra-action-fetches (policy-inference infos) to
# collectors.
pol = worker.policy_map[policy_id]
for key, value in episode.last_extra_action_outs_for(agent_id).items():
if key in pol.view_requirements:
values_dict[key] = value
# Env infos for this agent.
if SampleBatch.INFOS in pol.view_requirements:
values_dict[SampleBatch.INFOS] = agent_infos
sample_collector.add_action_reward_next_obs(
episode.episode_id,
agent_id,
env_id,
policy_id,
agent_terminated or agent_truncated,
values_dict,
)
if not agent_terminated and not agent_truncated:
item = _PolicyEvalData(
env_id,
agent_id,
filtered_obs,
agent_infos,
None
if last_observation is None
else episode.rnn_state_for(agent_id),
None
if last_observation is None
else episode.last_action_for(agent_id),
rewards[env_id].get(agent_id, 0.0),
)
to_eval[policy_id].append(item)
# Invoke the `on_episode_step` callback after the step is logged
# to the episode.
# Exception: The very first env.poll() call causes the env to get reset
# (no step taken yet, just a single starting observation logged).
# We need to skip this callback in this case.
if not episode.is_faulty and episode.length > 0:
callbacks.on_episode_step(
worker=worker,
base_env=base_env,
policies=worker.policy_map,
episode=episode,
env_index=env_id,
)
# Episode is terminated for all agents (terminateds[__all__] == True or
# truncateds[__all__] == True).
if all_agents_done:
            # If we are not allowed to pack the next episode into the same
# SampleBatch (batch_mode=complete_episodes) -> Build the
# MultiAgentBatch from a single episode and add it to "outputs".
# Otherwise, just postprocess and continue collecting across
# episodes.
# If an episode was marked faulty, perform regular postprocessing
# (to e.g. properly flush and clean up the SampleCollector's buffers),
# but then discard the entire batch and don't return it.
ma_sample_batch = None
if not episode.is_faulty or episode.length > 0:
ma_sample_batch = sample_collector.postprocess_episode(
episode,
is_done=True,
check_dones=True,
build=episode.is_faulty or not multiple_episodes_in_batch,
)
if not episode.is_faulty:
# Call each (in-memory) policy's Exploration.on_episode_end
# method.
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been recently used
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here as that
# would counter the purpose of the LRU policy caching.
for p in worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_end(
policy=p,
environment=base_env,
episode=episode,
tf_sess=p.get_session(),
)
# Call custom on_episode_end callback.
callbacks.on_episode_end(
worker=worker,
base_env=base_env,
policies=worker.policy_map,
episode=episode,
env_index=env_id,
)
# Now that all callbacks are done and users had the chance to add custom
# metrics based on the last observation in the episode, finish up metrics
# object and append to `outputs`.
atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(base_env)
if not episode.is_faulty:
if atari_metrics is not None:
for m in atari_metrics:
outputs.append(
m._replace(
custom_metrics=episode.custom_metrics,
hist_data=episode.hist_data,
)
)
else:
outputs.append(
RolloutMetrics(
episode.length,
episode.total_reward,
dict(episode.agent_rewards),
episode.custom_metrics,
{},
episode.hist_data,
episode.media,
)
)
else:
# Add metrics about a faulty episode.
outputs.append(RolloutMetrics(episode_faulty=True))
# Only after the RolloutMetrics were appended, append the collected sample
# batch, if any.
if not episode.is_faulty and ma_sample_batch:
outputs.append(ma_sample_batch)
# Terminated: Try to reset the sub environment.
# Clean up old finished episode.
del active_episodes[env_id]
# Create a new episode and call `on_episode_created` callback(s).
_create_episode(active_episodes, env_id, callbacks, worker, base_env)
# The sub environment at index `env_id` might throw an exception
# during the following `try_reset()` attempt. If configured with
# `restart_failed_sub_environments=True`, the BaseEnv will restart
# the affected sub environment (create a new one using its c'tor) and
# must reset the recreated sub env right after that.
# Should the sub environment fail indefinitely during these
# repeated reset attempts, the entire worker will be blocked.
# This would be ok, b/c the alternative would be the worker crashing
# entirely.
while True:
resetted_obs, resetted_infos = base_env.try_reset(env_id)
if resetted_obs is None or not isinstance(
resetted_obs[env_id], Exception
):
break
else:
# Failed to reset, add metrics about a faulty episode.
outputs.append(RolloutMetrics(episode_faulty=True))
            # Create a new episode, if this was not an async reset return.
            # If the reset is async, we will get its result in some future poll.
if resetted_obs is not None and resetted_obs != ASYNC_RESET_RETURN:
new_episode: Episode = active_episodes[env_id]
resetted_obs = resetted_obs[env_id]
resetted_infos = resetted_infos[env_id]
# Add init obs and infos (from the call to `reset/try_reset`) to
# episode.
for aid, obs in resetted_obs.items():
new_episode._set_last_raw_obs(aid, obs)
common_infos = resetted_infos.get("__common__", {})
new_episode._set_last_info("__common__", common_infos)
for aid, info in resetted_infos.items():
new_episode._set_last_info(aid, info)
_call_on_episode_start(new_episode, env_id, callbacks, worker, base_env)
_assert_episode_not_faulty(new_episode)
if observation_fn:
resetted_obs: Dict[AgentID, EnvObsType] = observation_fn(
agent_obs=resetted_obs,
worker=worker,
base_env=base_env,
policies=worker.policy_map,
episode=new_episode,
)
# types: AgentID, EnvObsType
for agent_id, raw_obs in resetted_obs.items():
policy_id: PolicyID = new_episode.policy_for(agent_id)
                    preprocessor = _get_or_raise(worker.preprocessors, policy_id)
                    prep_obs: EnvObsType = raw_obs
                    if preprocessor is not None:
                        prep_obs = preprocessor.transform(raw_obs)
filtered_obs: EnvObsType = _get_or_raise(worker.filters, policy_id)(
prep_obs
)
new_episode._set_last_observation(agent_id, filtered_obs)
# Add initial obs to buffer.
sample_collector.add_init_obs(
episode=new_episode,
agent_id=agent_id,
env_id=env_id,
policy_id=policy_id,
init_obs=filtered_obs,
init_infos=resetted_infos,
t=new_episode.length - 1,
)
item = _PolicyEvalData(
env_id,
agent_id,
filtered_obs,
new_episode.last_info_for(agent_id) or {},
new_episode.rnn_state_for(agent_id),
None,
0.0,
)
to_eval[policy_id].append(item)
    # Try to build a (truncated-episodes) multi-agent batch, if allowed.
if multiple_episodes_in_batch:
sample_batches = (
sample_collector.try_build_truncated_episode_multi_agent_batch()
)
if sample_batches:
outputs.extend(sample_batches)
return active_envs, to_eval, outputs
@OldAPIStack
def _do_policy_eval(
*,
to_eval: Dict[PolicyID, List[_PolicyEvalData]],
policies: PolicyMap,
sample_collector: SampleCollector,
active_episodes: Dict[EnvID, Episode],
) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]:
"""Call compute_actions on collected episode/model data to get next action.
Args:
to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects
(items in these lists will be the batch's items for the model
forward pass).
policies: Mapping from policy ID to Policy obj.
sample_collector: The SampleCollector object to use.
active_episodes: Mapping of EnvID to its currently active episode.
Returns:
Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
"""
eval_results: Dict[PolicyID, TensorStructType] = {}
if log_once("compute_actions_input"):
logger.info("Inputs to compute_actions():\n\n{}\n".format(summarize(to_eval)))
for policy_id, eval_data in to_eval.items():
        # In case the policy ID has been removed from this worker, we need to
        # re-assign `policy_id` and re-fetch the Policy object to use.
try:
policy: Policy = _get_or_raise(policies, policy_id)
except ValueError:
            # Important: Get the policy_mapping_fn from the active
            # Episode as the policy_mapping_fn from the worker may
            # have already been changed (the mapping fn stays constant
            # within one episode).
episode = active_episodes[eval_data[0].env_id]
_assert_episode_not_faulty(episode)
policy_id = episode.policy_mapping_fn(
eval_data[0].agent_id, episode, worker=episode.worker
)
policy: Policy = _get_or_raise(policies, policy_id)
input_dict = sample_collector.get_inference_input_dict(policy_id)
eval_results[policy_id] = policy.compute_actions_from_input_dict(
input_dict,
timestep=policy.global_timestep,
episodes=[active_episodes[t.env_id] for t in eval_data],
)
if log_once("compute_actions_result"):
logger.info(
"Outputs of compute_actions():\n\n{}\n".format(summarize(eval_results))
)
return eval_results
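# Illustration only (hypothetical policy ID, keys, and shapes): each entry of the
# dict returned above mirrors the output of
# `Policy.compute_actions_from_input_dict()`, e.g.
#     eval_results["default_policy"] == (
#         actions,             # batched (possibly nested) action array(s)
#         [state_out_0, ...],  # list of RNN state-out batches (may be empty)
#         {"action_logp": ..., "action_dist_inputs": ...},  # extra fetches
#     )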
@OldAPIStack
def _process_policy_eval_results(
*,
to_eval: Dict[PolicyID, List[_PolicyEvalData]],
eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]],
active_episodes: Dict[EnvID, Episode],
active_envs: Set[int],
off_policy_actions: MultiEnvDict,
policies: Dict[PolicyID, Policy],
normalize_actions: bool,
clip_actions: bool,
) -> Dict[EnvID, Dict[AgentID, EnvActionType]]:
"""Process the output of policy neural network evaluation.
Records policy evaluation results into the given episode objects and
returns replies to send back to agents in the env.
Args:
to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects.
eval_results: Mapping of policy IDs to list of
actions, rnn-out states, extra-action-fetches dicts.
active_episodes: Mapping from episode ID to currently ongoing
Episode object.
active_envs: Set of non-terminated env ids.
off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
off-policy-action, returned by a `BaseEnv.poll()` call.
policies: Mapping from policy ID to Policy.
normalize_actions: Whether to normalize actions to the action
space's bounds.
clip_actions: Whether to clip actions to the action space's bounds.
Returns:
Nested dict of env id -> agent id -> actions to be sent to
Env (np.ndarrays).
"""
actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)
# types: int
for env_id in active_envs:
actions_to_send[env_id] = {} # at minimum send empty dict
# types: PolicyID, List[_PolicyEvalData]
for policy_id, eval_data in to_eval.items():
actions: TensorStructType = eval_results[policy_id][0]
actions = convert_to_numpy(actions)
rnn_out_cols: StateBatch = eval_results[policy_id][1]
extra_action_out_cols: dict = eval_results[policy_id][2]
# In case actions is a list (representing the 0th dim of a batch of
# primitive actions), try converting it first.
if isinstance(actions, list):
actions = np.array(actions)
# Store RNN state ins/outs and extra-action fetches to episode.
for f_i, column in enumerate(rnn_out_cols):
extra_action_out_cols["state_out_{}".format(f_i)] = column
policy: Policy = _get_or_raise(policies, policy_id)
# Split action-component batches into single action rows.
actions: List[EnvActionType] = unbatch(actions)
# types: int, EnvActionType
for i, action in enumerate(actions):
# Normalize, if necessary.
if normalize_actions:
action_to_send = unsquash_action(action, policy.action_space_struct)
# Clip, if necessary.
elif clip_actions:
action_to_send = clip_action(action, policy.action_space_struct)
else:
action_to_send = action
env_id: int = eval_data[i].env_id
agent_id: AgentID = eval_data[i].agent_id
episode: Episode = active_episodes[env_id]
_assert_episode_not_faulty(episode)
episode._set_rnn_state(
agent_id, tree.map_structure(lambda x: x[i], rnn_out_cols)
)
episode._set_last_extra_action_outs(
agent_id, tree.map_structure(lambda x: x[i], extra_action_out_cols)
)
if env_id in off_policy_actions and agent_id in off_policy_actions[env_id]:
episode._set_last_action(agent_id, off_policy_actions[env_id][agent_id])
else:
episode._set_last_action(agent_id, action)
assert agent_id not in actions_to_send[env_id]
            # Flag actions as immutable to notify the user when they try to
            # change them and to avoid hard-to-trace errors.
tree.traverse(make_action_immutable, action_to_send, top_down=False)
actions_to_send[env_id][agent_id] = action_to_send
return actions_to_send
@OldAPIStack
def _create_episode(active_episodes, env_id, callbacks, worker, base_env):
# Make sure we are really creating a new episode here.
assert env_id not in active_episodes
# Create a new episode under the given `env_id` and call the
# `on_episode_created` callbacks.
new_episode = active_episodes[env_id]
# Call `on_episode_created()` callback.
callbacks.on_episode_created(
worker=worker,
base_env=base_env,
policies=worker.policy_map,
env_index=env_id,
episode=new_episode,
)
return new_episode
@OldAPIStack
def _call_on_episode_start(episode, env_id, callbacks, worker, base_env):
# Call each policy's Exploration.on_episode_start method.
# Note: This may break the exploration (e.g. ParameterNoise) of
# policies in the `policy_map` that have not been recently used
# (and are therefore stashed to disk). However, we certainly do not
# want to loop through all (even stashed) policies here as that
# would counter the purpose of the LRU policy caching.
for p in worker.policy_map.cache.values():
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_start(
policy=p,
environment=base_env,
episode=episode,
tf_sess=p.get_session(),
)
callbacks.on_episode_start(
worker=worker,
base_env=base_env,
policies=worker.policy_map,
episode=episode,
env_index=env_id,
)
episode.started = True
def _to_column_format(rnn_state_rows: List[List[Any]]) -> StateBatch:
num_cols = len(rnn_state_rows[0])
return [[row[i] for row in rnn_state_rows] for i in range(num_cols)]
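# Illustration (hypothetical tensors): `_to_column_format` transposes per-row RNN
# states into the [state_index, batch] layout described by `StateBatch`, e.g.
#     _to_column_format([[h_a, c_a], [h_b, c_b]]) == [[h_a, h_b], [c_a, c_b]]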
def _assert_episode_not_faulty(episode):
if episode.is_faulty:
raise AssertionError(
"Episodes marked as `faulty` should not be kept in the "
f"`active_episodes` map! Episode ID={episode.episode_id}."
)