Note
Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The team is currently transitioning algorithms, example scripts, and documentation to the new code base throughout the subsequent minor releases leading up to Ray 3.0.
See here for more details on how to activate and use the new API stack.
SingleAgentEpisode API
rllib.env.single_agent_episode.SingleAgentEpisode
- class ray.rllib.env.single_agent_episode.SingleAgentEpisode(id_: str | None = None, *, observations: List[gymnasium.core.ObsType] | InfiniteLookbackBuffer | None = None, observation_space: gymnasium.Space | None = None, infos: List[Dict] | InfiniteLookbackBuffer | None = None, actions: List[gymnasium.core.ActType] | InfiniteLookbackBuffer | None = None, action_space: gymnasium.Space | None = None, rewards: List[SupportsFloat] | InfiniteLookbackBuffer | None = None, terminated: bool = False, truncated: bool = False, extra_model_outputs: Dict[str, Any] | None = None, t_started: int | None = None, len_lookback_buffer: int | str = 'auto', agent_id: Any | None = None, module_id: str | None = None, multi_agent_episode_id: int | None = None)
A class representing RL environment episodes for individual agents.
SingleAgentEpisode stores observations, info dicts, actions, rewards, and all module outputs (e.g. state outs, action logp, etc.) for an individual agent within a single-agent or multi-agent environment. The two main APIs for adding data to an ongoing episode are the `add_env_reset()` and `add_env_step()` methods, which should be called with the outputs of the respective gym.Env API calls: `env.reset()` and `env.step()`.
A SingleAgentEpisode might also represent only a chunk of an episode. This is useful when partial (incomplete-episode) sampling is performed and the collected episode data has to be returned before the actual gym.Env episode has finished (see `SingleAgentEpisode.cut()`). To maintain visibility into past experiences within such a “cut” episode, SingleAgentEpisode instances can have a “lookback buffer” of n timesteps at their beginning (left side). This buffer exists solely for compiling extra data (e.g. “prev. reward”) and is not considered part of the finished/packaged episode, because the data in the lookback buffer already belongs to a previous episode chunk.
Powerful getter methods, such as `get_observations()`, help collect different types of data from the episode at individual time indices or time ranges, including the “lookback buffer” range described above. For example, to extract the last 4 rewards of an ongoing episode, one can call `self.get_rewards(slice(-4, None))` or `self.rewards[-4:]`. This works even if the ongoing SingleAgentEpisode is a continuation chunk of a much earlier started episode, as long as its lookback buffer is large enough.
Examples:
```python
import gymnasium as gym
import numpy as np

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

# Construct a new episode (without any data in it yet).
episode = SingleAgentEpisode()
assert len(episode) == 0

# Fill the episode with some data (5 timesteps).
env = gym.make("CartPole-v1")
obs, infos = env.reset()
episode.add_env_reset(obs, infos)

# Even with the initial obs/infos, the episode is still considered len=0.
assert len(episode) == 0
for _ in range(5):
    action = env.action_space.sample()
    obs, reward, term, trunc, infos = env.step(action)
    episode.add_env_step(
        observation=obs,
        action=action,
        reward=reward,
        terminated=term,
        truncated=trunc,
        infos=infos,
    )
assert len(episode) == 5

# We can now access information from the episode via the getter APIs.

# Get the last 3 rewards (in a batch of size 3).
episode.get_rewards(slice(-3, None))  # same as `episode.rewards[-3:]`

# Get the most recent action (single item, not batched).
# This works regardless of the action space or whether the episode has
# been finalized or not (see below).
episode.get_actions(-1)  # same as `episode.actions[-1]`

# Looking back from ts=1, get the previous 4 rewards (ts=-3 to ts=0) AND fill
# with 0.0 in case we go over the beginning (ts=0). So we would expect
# [0.0, 0.0, 0.0, r0] to be returned here, where r0 is the very first received
# reward in the episode:
episode.get_rewards(slice(-3, 1), neg_index_as_lookback=True, fill=0.0)

# Note the use of fill=0.0 here (fill everything that's out of range with this
# value) AND the argument `neg_index_as_lookback=True`, which interprets
# negative indices as being left of ts=0 (e.g. -1 being the timestep before
# ts=0).

# Assuming we had a complex action space (nested gym.spaces.Dict) with one or
# more elements being Discrete or MultiDiscrete spaces:
# 1) The `fill=...` argument would still work, filling all spaces (Boxes,
#    Discrete) with that provided value.
# 2) Setting the flag `one_hot_discrete=True` would convert those discrete
#    sub-components automatically into one-hot (or multi-one-hot) tensors.
# This simplifies the task of having to provide the previous 4 (nested and
# partially discrete/multi-discrete) actions for each timestep within a training
# batch, thereby filling timesteps before the episode started with 0.0s and
# one-hot'ing the discrete/multi-discrete components in these actions:
episode = SingleAgentEpisode(action_space=gym.spaces.Dict({
    "a": gym.spaces.Discrete(3),
    "b": gym.spaces.MultiDiscrete([2, 3]),
    "c": gym.spaces.Box(-1.0, 1.0, (2,)),
}))

# ... fill episode with data ...
episode.add_env_reset(observation=0)
# ... from a few steps.
episode.add_env_step(
    observation=1,
    action={"a": 0, "b": np.array([1, 2]), "c": np.array([.5, -.5], np.float32)},
    reward=1.0,
)

# In your connector
prev_4_a = []
# Note here that len(episode) does NOT include the lookback buffer.
for ts in range(len(episode)):
    prev_4_a.append(
        episode.get_actions(
            indices=slice(ts - 4, ts),
            # Make sure negative indices are interpreted as
            # "into lookback buffer"
            neg_index_as_lookback=True,
            # Zero-out everything even further before the lookback buffer.
            fill=0.0,
            # Take care of discrete components (get ready as NN input).
            one_hot_discrete=True,
        )
    )

# Finally, convert from list of batch items to a struct (same as action space)
# of batched (numpy) arrays, in which all leaves have B==len(prev_4_a).
from ray.rllib.utils.spaces.space_utils import batch

prev_4_actions_col = batch(prev_4_a)
```
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
- add_env_reset(observation: gymnasium.core.ObsType, infos: Dict | None = None) → None
Adds the initial data (after an `env.reset()`) to the episode.
This data consists of the initial observation and the initial info dict.
- Parameters:
observation – The initial observation returned by `env.reset()`.
infos – An (optional) info dict returned by `env.reset()`.
- add_env_step(observation: gymnasium.core.ObsType, action: gymnasium.core.ActType, reward: SupportsFloat, infos: Dict[str, Any] | None = None, *, terminated: bool = False, truncated: bool = False, extra_model_outputs: Dict[str, Any] | None = None) → None
Adds the results of an `env.step()` call (including the action) to this episode.
This data consists of an observation and info dict, an action, a reward, terminated/truncated flags, and extra model outputs (e.g. action probabilities or RNN internal state outputs).
- Parameters:
observation – The next observation received from the environment after taking `action`.
action – The last action used by the agent during the call to `env.step()`.
reward – The last reward received by the agent after taking `action`.
infos – The last info dict received from the environment after taking `action`.
terminated – Whether the environment has been terminated (after taking `action`).
truncated – Whether the environment has been truncated (after taking `action`).
extra_model_outputs – The last timestep’s specific model outputs. These are normally outputs of an RLModule that were computed along with `action`, e.g. `action_logp` or `action_dist_inputs`.
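To make the bookkeeping of `add_env_reset()` and `add_env_step()` concrete, here is a minimal pure-Python sketch. `ToyEpisode` is a hypothetical stand-in, not RLlib's class. It assumes, consistent with the examples on this page (e.g. `observations=[0, 1, 2, 3], actions=[1, 2, 3]`), that each step appends one observation, info dict, action, and reward, so an episode of length T holds T + 1 observations.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class ToyEpisode:
    """Hypothetical stand-in for SingleAgentEpisode (not RLlib code)."""

    observations: List[Any] = field(default_factory=list)
    infos: List[Dict] = field(default_factory=list)
    actions: List[Any] = field(default_factory=list)
    rewards: List[float] = field(default_factory=list)
    is_terminated: bool = False
    is_truncated: bool = False

    def add_env_reset(self, observation: Any, infos: Optional[Dict] = None) -> None:
        # The reset observation is stored, but it does not count as a step.
        assert not self.observations, "reset must come first"
        self.observations.append(observation)
        self.infos.append(infos or {})

    def add_env_step(self, observation: Any, action: Any, reward: float,
                     infos: Optional[Dict] = None, *,
                     terminated: bool = False, truncated: bool = False) -> None:
        assert self.observations, "call add_env_reset() first"
        assert not (self.is_terminated or self.is_truncated), "episode is done"
        self.observations.append(observation)
        self.infos.append(infos or {})
        self.actions.append(action)
        self.rewards.append(float(reward))
        self.is_terminated = terminated
        self.is_truncated = truncated

    def __len__(self) -> int:
        # Length counts env steps (actions), not observations.
        return len(self.actions)


ep = ToyEpisode()
ep.add_env_reset(observation=0)
assert len(ep) == 0  # like the real class: reset alone gives len=0
for t in range(3):
    ep.add_env_step(observation=t + 1, action=t, reward=1.0, terminated=(t == 2))
assert len(ep.observations) == len(ep.actions) + 1
```

The assertion on a done episode mirrors the rule stated under `is_done` that a terminated or truncated episode cannot be continued.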
- validate() → None
Validates the episode’s data.
This function ensures that the data stored in a `SingleAgentEpisode` is in order (e.g. that the correct number of observations, actions, and rewards are present).
- property is_done: bool
Whether the episode is actually done (terminated or truncated).
A done episode cannot be continued via `self.add_env_step()`, concatenated on its right side with another episode chunk, or succeeded via `self.cut()`.
- finalize() → SingleAgentEpisode
Converts this Episode’s list attributes to numpy arrays.
In particular, this episode’s lists of (possibly complex) data (e.g. if we have a dict obs space) are converted to (possibly complex) structs whose leaves are numpy arrays. Each of these leaf numpy arrays has the same length (batch dimension) as the original list.
Note that Columns.INFOS are NEVER numpy’ized and remain a list (normally, a list of the original, env-returned dicts). This is due to the heterogeneous nature of the infos returned by envs, which would make it unwieldy to convert this information to numpy arrays.
After calling this method, no further data may be added to this episode via the `self.add_env_step()` method.
Examples:
```python
import numpy as np

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    observations=[0, 1, 2, 3],
    actions=[1, 2, 3],
    rewards=[1, 2, 3],
    # Note: terminated/truncated have nothing to do with an episode
    # being `finalized` or not (via the `self.finalize()` method)!
    terminated=False,
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
)
# Episode has not been finalized (numpy'ized) yet.
assert not episode.is_finalized
# We are still operating on lists.
assert episode.get_observations([1]) == [1]
assert episode.get_observations(slice(None, 2)) == [0, 1]
# We can still add data (and even add the terminated=True flag).
episode.add_env_step(
    observation=4,
    action=4,
    reward=4,
    terminated=True,
)
# Still NOT finalized.
assert not episode.is_finalized

# Let's finalize the episode.
episode.finalize()
assert episode.is_finalized

# We cannot add data anymore. The following would crash.
# episode.add_env_step(observation=5, action=5, reward=5)

# Everything is now numpy arrays (with 0-axis of size
# B=[len of requested slice]).
assert isinstance(episode.get_observations([1]), np.ndarray)  # B=1
assert isinstance(episode.actions[0:2], np.ndarray)  # B=2
assert isinstance(episode.rewards[1:4], np.ndarray)  # B=3
```
- Returns:
This `SingleAgentEpisode` object with the converted numpy data.
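The list-to-struct conversion described for `finalize()` can be pictured with the following sketch. It assumes a dict observation space and uses a hypothetical helper, `stack_structs()` (not part of RLlib), that recursively stacks a list of structs into one struct of the same shape whose leaves are numpy arrays with a leading batch axis. Infos would simply be left as a plain list, as described above.

```python
from typing import Any, List

import numpy as np


def stack_structs(items: List[Any]) -> Any:
    """Hypothetical helper (not RLlib code): list of (possibly nested dict)
    items -> one struct whose leaves have a leading axis of len(items)."""
    first = items[0]
    if isinstance(first, dict):
        # Recurse per key so nested dict spaces keep their structure.
        return {k: stack_structs([item[k] for item in items]) for k in first}
    # Leaf: stack individual values (scalars or arrays) along a new axis 0.
    return np.stack([np.asarray(item) for item in items])


obs_list = [
    {"pos": np.array([0.0, 1.0]), "flag": 0},
    {"pos": np.array([0.5, 1.5]), "flag": 1},
    {"pos": np.array([1.0, 2.0]), "flag": 0},
]
batched = stack_structs(obs_list)
print(batched["pos"].shape)   # (3, 2)
print(batched["flag"].shape)  # (3,)
```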
- concat_episode(other: SingleAgentEpisode) → None
Adds the given `other` SingleAgentEpisode to the right side of `self`.
In order for this to work, both chunks (`self` and `other`) must fit together. This is checked via the IDs (must be identical), the timestep counters (`self.t` must equal `other.t_started`), and the observations/infos at the concatenation boundary. Also, `self.is_done` must not be True, meaning `self.is_terminated` and `self.is_truncated` must both be False.
- Parameters:
other – The other `SingleAgentEpisode` to be concatenated to this one.
- Returns:
None. `self` is modified in place and afterwards contains the concatenated data from both episodes (`self` and `other`).
- cut(len_lookback_buffer: int = 0) → SingleAgentEpisode
Returns a successor episode chunk (of len=0) continuing from this Episode.
The successor will have the same ID as `self`. If no lookback buffer is requested (`len_lookback_buffer=0`), the successor’s observations will be the last observation(s) of `self` and its length will therefore be 0 (no further steps taken yet). If `len_lookback_buffer` > 0, the returned successor will have `len_lookback_buffer` observations (and actions, rewards, etc.) taken from the right side (end) of `self`. For example, if `len_lookback_buffer=2`, the returned successor’s lookback buffer actions will be identical to `self.actions[-2:]`.
This method is useful if you would like to stop building an episode chunk (because you have to return it from somewhere), but would like to have a new episode instance to continue building the actual gym.Env episode at a later time. Via the `len_lookback_buffer` argument, the continuing chunk (successor) will still be able to “look back” into this predecessor episode’s data (at least to some extent, depending on the value of `len_lookback_buffer`).
- Parameters:
len_lookback_buffer – The number of timesteps to take along into the new chunk as “lookback buffer”. A lookback buffer is additional data on the left side of the actual episode data for visibility purposes (but without actually being part of the new chunk). For example, if `self` ends in actions 5, 6, 7, and 8, and we call `self.cut(len_lookback_buffer=2)`, the returned chunk will have actions 7 and 8 already in it, but still `t_started == t == 8` (not 7!) and a length of 0. If there is not enough data in `self` yet to fulfill the `len_lookback_buffer` request, the value of `len_lookback_buffer` is automatically adjusted (lowered).
- Returns:
The successor Episode chunk of this one, with the same ID and state, whose only non-lookback observation is the last observation in `self`.
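Here is a plain-list sketch of what `cut()` hands to the successor, reproducing the actions 5, 6, 7, 8 example above. `toy_cut()` is hypothetical, not RLlib code, and it assumes (not confirmed by this page) that a lookback of n timesteps means n lookback observations plus the current ts=0 observation.

```python
from typing import Any, Dict, List


def toy_cut(observations: List[Any], actions: List[Any], t: int,
            len_lookback_buffer: int = 0) -> Dict[str, Any]:
    """Hypothetical model of cut() (not RLlib code). `observations` holds
    len(actions) + 1 items; `t` is the env timestep of the last one."""
    # Lower the request if `self` does not hold enough steps yet.
    n = min(len_lookback_buffer, len(actions))
    return {
        # n lookback observations plus the current (ts=0) observation.
        "observations": observations[len(observations) - n - 1:],
        # The last n actions all land in the successor's lookback buffer.
        "actions": actions[len(actions) - n:],
        "len_lookback_buffer": n,
        # The successor starts at self's current timestep, so its length is 0.
        "t_started": t,
    }


obs = ["o4", "o5", "o6", "o7", "o8"]  # observations at env timesteps 4..8
acts = [5, 6, 7, 8]
succ = toy_cut(obs, acts, t=8, len_lookback_buffer=2)
print(succ["actions"], succ["t_started"])  # [7, 8] 8
```

Slicing with `len(actions) - n` rather than `-n` keeps the n=0 case correct, since `actions[-0:]` would return the whole list.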
- get_observations(indices: int | slice | List[int] | None = None, *, neg_index_as_lookback: bool = False, fill: Any | None = None, one_hot_discrete: bool = False) → Any
Returns individual observations or batched ranges thereof from this episode.
- Parameters:
indices – A single int is interpreted as an index, from which to return the individual observation stored at this index. A list of ints is interpreted as a list of indices from which to gather individual observations in a batch of size len(indices). A slice object is interpreted as a range of observations to be returned. Negative indices are by default interpreted as “before the end” unless the `neg_index_as_lookback=True` option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. If None, returns all observations (from ts=0 to the end).
neg_index_as_lookback – If True, negative values in `indices` are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with observations [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (the ts=0 item is 7), will respond to `get_observations(-1, neg_index_as_lookback=True)` with `6` and to `get_observations(slice(-2, 1), neg_index_as_lookback=True)` with `[5, 6, 7]`.
fill – An optional value to use for filling up the returned results at the boundaries. This filling only happens if the requested index range’s start/stop boundaries exceed the episode’s boundaries (including the lookback buffer on the left side). This is handy if users don’t want to worry about reaching such boundaries and want to zero-pad. For example, an episode with observations [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning observations `10` and `11` are part of the lookback buffer) will respond to `get_observations(slice(-7, -2), fill=0.0)` with `[0.0, 0.0, 10, 11, 12]`.
one_hot_discrete – If True, returns one-hot vectors (instead of int values) for those sub-components of a (possibly complex) observation space that are Discrete or MultiDiscrete. Note that if `fill=0` and the requested `indices` are out of the range of the data, the returned one-hot vectors will actually be zero-hot (all slots zero).
Examples:
```python
import gymnasium as gym

from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.test_utils import check

episode = SingleAgentEpisode(
    # Discrete(4) observations (ints between 0 and 4 (excl.))
    observation_space=gym.spaces.Discrete(4),
    observations=[0, 1, 2, 3],
    actions=[1, 2, 3],
    rewards=[1, 2, 3],  # <- not relevant for this demo
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
)
# Plain usage (`indices` arg only).
check(episode.get_observations(-1), 3)
check(episode.get_observations(0), 0)
check(episode.get_observations([0, 2]), [0, 2])
check(episode.get_observations([-1, 0]), [3, 0])
check(episode.get_observations(slice(None, 2)), [0, 1])
check(episode.get_observations(slice(-2, None)), [2, 3])
# Using `fill=...` (requesting slices beyond the boundaries).
check(episode.get_observations(slice(-6, -2), fill=-9), [-9, -9, 0, 1])
check(episode.get_observations(slice(2, 5), fill=-7), [2, 3, -7])
# Using `one_hot_discrete=True`.
check(episode.get_observations(2, one_hot_discrete=True), [0, 0, 1, 0])
check(episode.get_observations(3, one_hot_discrete=True), [0, 0, 0, 1])
check(episode.get_observations(
    slice(0, 3),
    one_hot_discrete=True,
), [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]])
# Special case: Using `fill=0.0` AND `one_hot_discrete=True`.
check(episode.get_observations(
    -1,
    neg_index_as_lookback=True,  # -1 means one left of ts=0
    fill=0.0,
    one_hot_discrete=True,
), [0, 0, 0, 0])  # <- all 0s one-hot tensor (note difference to [1 0 0 0]!)
```
- Returns:
The collected observations: as a 0-axis batch if `indices` is a list (even one containing exactly one index) or a slice object; as a single item (no additional 0-axis) if `indices` is a single int.
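The index rules for `indices`, `neg_index_as_lookback`, and `fill` can be modelled on a plain list whose first `lookback` items form the lookback buffer. `toy_get_slice()` below is a hypothetical sketch, not RLlib's implementation, and handles integer slice bounds only; it reproduces the two examples from the parameter descriptions above.

```python
from typing import Any, List, Optional


def toy_get_slice(data: List[Any], lookback: int, start: int, stop: int, *,
                  fill: Optional[Any] = None,
                  neg_index_as_lookback: bool = False) -> List[Any]:
    """Hypothetical model of the slice rules (not RLlib code). `data` holds
    the lookback buffer followed by the episode data; ts=0 is data[lookback]."""

    def to_internal(i: int) -> int:
        if i < 0 and not neg_index_as_lookback:
            return len(data) + i  # "before the end" (may reach into lookback)
        return lookback + i  # relative to ts=0; negatives go into lookback

    lo, hi = to_internal(start), to_internal(stop)
    # The part of [lo, hi) that overlaps the stored data.
    inner = data[max(lo, 0):max(min(hi, len(data)), 0)]
    if fill is None:
        return inner
    # Pad positions left of index 0 and right of the last stored item.
    n_left = max(0, min(hi, 0) - lo)
    n_right = max(0, hi - max(lo, len(data)))
    return [fill] * n_left + inner + [fill] * n_right


obs = [4, 5, 6, 7, 8, 9]  # [4, 5, 6] is the lookback buffer
print(toy_get_slice(obs, 3, -2, 1, neg_index_as_lookback=True))  # [5, 6, 7]
obs2 = [10, 11, 12, 13, 14]  # [10, 11] is the lookback buffer
print(toy_get_slice(obs2, 2, -7, -2, fill=0.0))  # [0.0, 0.0, 10, 11, 12]
```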
- get_infos(indices: int | slice | List[int] | None = None, *, neg_index_as_lookback: bool = False, fill: Any | None = None) → Any
Returns individual info dicts or lists (ranges) thereof from this episode.
- Parameters:
indices – A single int is interpreted as an index, from which to return the individual info dict stored at this index. A list of ints is interpreted as a list of indices from which to gather individual info dicts in a list of size len(indices). A slice object is interpreted as a range of info dicts to be returned. Negative indices are by default interpreted as “before the end” unless the `neg_index_as_lookback=True` option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. If None, returns all infos (from ts=0 to the end).
neg_index_as_lookback – If True, negative values in `indices` are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with infos [{"l":4}, {"l":5}, {"l":6}, {"a":7}, {"b":8}, {"c":9}], where the first 3 items are the lookback buffer (the ts=0 item is {"a":7}), will respond to `get_infos(-1, neg_index_as_lookback=True)` with `{"l":6}` and to `get_infos(slice(-2, 1), neg_index_as_lookback=True)` with `[{"l":5}, {"l":6}, {"a":7}]`.
fill – An optional value to use for filling up the returned results at the boundaries. This filling only happens if the requested index range’s start/stop boundaries exceed the episode’s boundaries (including the lookback buffer on the left side). This is handy if users don’t want to worry about reaching such boundaries and want to auto-fill. For example, an episode with infos [{"l":10}, {"l":11}, {"a":12}, {"b":13}, {"c":14}] and a lookback buffer size of 2 (meaning infos {"l":10} and {"l":11} are part of the lookback buffer) will respond to `get_infos(slice(-7, -2), fill={"o": 0.0})` with `[{"o":0.0}, {"o":0.0}, {"l":10}, {"l":11}, {"a":12}]`.
Examples:
```python
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    infos=[{"a": 0}, {"b": 1}, {"c": 2}, {"d": 3}],
    # The following is needed, but not relevant for this demo.
    observations=[0, 1, 2, 3],
    actions=[1, 2, 3],
    rewards=[1, 2, 3],
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
)
# Plain usage (`indices` arg only).
episode.get_infos(-1)  # {"d":3}
episode.get_infos(0)  # {"a":0}
episode.get_infos([0, 2])  # [{"a":0},{"c":2}]
episode.get_infos([-1, 0])  # [{"d":3},{"a":0}]
episode.get_infos(slice(None, 2))  # [{"a":0},{"b":1}]
episode.get_infos(slice(-2, None))  # [{"c":2},{"d":3}]
# Using `fill=...` (requesting slices beyond the boundaries).
# TODO (sven): This would require a space being provided. Maybe we can
#  skip this check for infos, which don't have a space anyways.
# episode.get_infos(slice(-5, -3), fill={"o":-1})  # [{"o":-1},{"a":0}]
# episode.get_infos(slice(3, 5), fill={"o":-2})  # [{"d":3},{"o":-2}]
```
- Returns:
The collected info dicts: as a list (0-axis batch) if `indices` is a list (even one containing exactly one index) or a slice object; as a single info dict if `indices` is a single int.
- get_actions(indices: int | slice | List[int] | None = None, *, neg_index_as_lookback: bool = False, fill: Any | None = None, one_hot_discrete: bool = False) → Any
Returns individual actions or batched ranges thereof from this episode.
- Parameters:
indices – A single int is interpreted as an index, from which to return the individual action stored at this index. A list of ints is interpreted as a list of indices from which to gather individual actions in a batch of size len(indices). A slice object is interpreted as a range of actions to be returned. Negative indices are by default interpreted as “before the end” unless the `neg_index_as_lookback=True` option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. If None, returns all actions (from ts=0 to the end).
neg_index_as_lookback – If True, negative values in `indices` are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with actions [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (the ts=0 item is 7), will respond to `get_actions(-1, neg_index_as_lookback=True)` with `6` and to `get_actions(slice(-2, 1), neg_index_as_lookback=True)` with `[5, 6, 7]`.
fill – An optional value to use for filling up the returned results at the boundaries. This filling only happens if the requested index range’s start/stop boundaries exceed the episode’s boundaries (including the lookback buffer on the left side). This is handy if users don’t want to worry about reaching such boundaries and want to zero-pad. For example, an episode with actions [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning actions `10` and `11` are part of the lookback buffer) will respond to `get_actions(slice(-7, -2), fill=0.0)` with `[0.0, 0.0, 10, 11, 12]`.
one_hot_discrete – If True, returns one-hot vectors (instead of int values) for those sub-components of a (possibly complex) action space that are Discrete or MultiDiscrete. Note that if `fill=0` and the requested `indices` are out of the range of the data, the returned one-hot vectors will actually be zero-hot (all slots zero).
Examples:
```python
import gymnasium as gym

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    # Discrete(4) actions (ints between 0 and 4 (excl.))
    action_space=gym.spaces.Discrete(4),
    actions=[1, 2, 3],
    observations=[0, 1, 2, 3],
    rewards=[1, 2, 3],  # <- not relevant here
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
)
# Plain usage (`indices` arg only).
episode.get_actions(-1)  # 3
episode.get_actions(0)  # 1
episode.get_actions([0, 2])  # [1, 3]
episode.get_actions([-1, 0])  # [3, 1]
episode.get_actions(slice(None, 2))  # [1, 2]
episode.get_actions(slice(-2, None))  # [2, 3]
# Using `fill=...` (requesting slices beyond the boundaries).
episode.get_actions(slice(-5, -2), fill=-9)  # [-9, -9, 1]
episode.get_actions(slice(1, 5), fill=-7)  # [2, 3, -7, -7]
# Using `one_hot_discrete=True`.
episode.get_actions(1, one_hot_discrete=True)  # [0 0 1 0] (action=2)
episode.get_actions(2, one_hot_discrete=True)  # [0 0 0 1] (action=3)
episode.get_actions(
    slice(0, 2),
    one_hot_discrete=True,
)  # [[0 1 0 0], [0 0 1 0]] (actions=1 and 2)
# Special case: Using `fill=0.0` AND `one_hot_discrete=True`.
episode.get_actions(
    -1,
    neg_index_as_lookback=True,  # -1 means one left of ts=0
    fill=0.0,
    one_hot_discrete=True,
)  # [0 0 0 0] <- all 0s one-hot tensor (note difference to [1 0 0 0]!)
```
- Returns:
The collected actions: as a 0-axis batch if `indices` is a list (even one containing exactly one index) or a slice object; as a single item (no additional 0-axis) if `indices` is a single int.
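The zero-hot behavior of `fill=0` combined with `one_hot_discrete=True` is easy to miss. The hypothetical helper below (not RLlib code) builds the previous `k` actions of a Discrete(n) action space the way the connector example at the top of this page does, turning timesteps before the start of the data into all-zero vectors instead of one-hot(0).

```python
from typing import List


def toy_prev_actions_one_hot(actions: List[int], n: int, ts: int,
                             k: int) -> List[List[int]]:
    """Hypothetical helper (not RLlib code): actions[ts-k:ts] of a
    Discrete(n) space, one-hot encoded; out-of-range timesteps become
    all-zero ("zero-hot") vectors, like fill=0 with one_hot_discrete=True."""
    out = []
    for i in range(ts - k, ts):
        vec = [0] * n
        if 0 <= i < len(actions):
            vec[actions[i]] = 1
        out.append(vec)
    return out


actions = [1, 2, 3]
print(toy_prev_actions_one_hot(actions, n=4, ts=2, k=3))
# [[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]
```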
- get_rewards(indices: int | slice | List[int] | None = None, *, neg_index_as_lookback: bool = False, fill: float | None = None) → Any
Returns individual rewards or batched ranges thereof from this episode.
- Parameters:
indices – A single int is interpreted as an index, from which to return the individual reward stored at this index. A list of ints is interpreted as a list of indices from which to gather individual rewards in a batch of size len(indices). A slice object is interpreted as a range of rewards to be returned. Negative indices are by default interpreted as “before the end” unless the `neg_index_as_lookback=True` option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. If None, returns all rewards (from ts=0 to the end).
neg_index_as_lookback – If True, negative values in `indices` are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with rewards [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (the ts=0 item is 7), will respond to `get_rewards(-1, neg_index_as_lookback=True)` with `6` and to `get_rewards(slice(-2, 1), neg_index_as_lookback=True)` with `[5, 6, 7]`.
fill – An optional float value to use for filling up the returned results at the boundaries. This filling only happens if the requested index range’s start/stop boundaries exceed the episode’s boundaries (including the lookback buffer on the left side). This is handy if users don’t want to worry about reaching such boundaries and want to zero-pad. For example, an episode with rewards [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning rewards `10` and `11` are part of the lookback buffer) will respond to `get_rewards(slice(-7, -2), fill=0.0)` with `[0.0, 0.0, 10, 11, 12]`.
Examples:
```python
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    rewards=[1.0, 2.0, 3.0],
    observations=[0, 1, 2, 3],
    actions=[1, 2, 3],  # <- not relevant here
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
)
# Plain usage (`indices` arg only).
episode.get_rewards(-1)  # 3.0
episode.get_rewards(0)  # 1.0
episode.get_rewards([0, 2])  # [1.0, 3.0]
episode.get_rewards([-1, 0])  # [3.0, 1.0]
episode.get_rewards(slice(None, 2))  # [1.0, 2.0]
episode.get_rewards(slice(-2, None))  # [2.0, 3.0]
# Using `fill=...` (requesting slices beyond the boundaries).
episode.get_rewards(slice(-5, -2), fill=0.0)  # [0.0, 0.0, 1.0]
episode.get_rewards(slice(1, 5), fill=0.0)  # [2.0, 3.0, 0.0, 0.0]
```
- Returns:
The collected rewards: as a 0-axis batch if `indices` is a list (even one containing exactly one index) or a slice object; as a single item (no additional 0-axis) if `indices` is a single int.
- get_extra_model_outputs(key: str, indices: int | slice | List[int] | None = None, *, neg_index_as_lookback: bool = False, fill: Any | None = None) → Any
Returns extra model outputs (under the given key) from this episode.
- Parameters:
key – The `key` within `self.extra_model_outputs` to extract data for.
indices – A single int is interpreted as an index, from which to return the individual extra model output stored under `key` at this index. A list of ints is interpreted as a list of indices from which to gather individual extra model outputs in a batch of size len(indices). A slice object is interpreted as a range of extra model outputs to be returned. Negative indices are by default interpreted as “before the end” unless the `neg_index_as_lookback=True` option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. If None, returns all extra model outputs (from ts=0 to the end).
neg_index_as_lookback – If True, negative values in `indices` are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with extra_model_outputs["a"] = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (the ts=0 item is 7), will respond to `get_extra_model_outputs("a", -1, neg_index_as_lookback=True)` with `6` and to `get_extra_model_outputs("a", slice(-2, 1), neg_index_as_lookback=True)` with `[5, 6, 7]`.
fill – An optional value to use for filling up the returned results at the boundaries. This filling only happens if the requested index range’s start/stop boundaries exceed the episode’s boundaries (including the lookback buffer on the left side). This is handy if users don’t want to worry about reaching such boundaries and want to zero-pad. For example, an episode with extra_model_outputs["b"] = [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning `10` and `11` are part of the lookback buffer) will respond to `get_extra_model_outputs("b", slice(-7, -2), fill=0.0)` with `[0.0, 0.0, 10, 11, 12]`. TODO (sven): This would require a space being provided. Maybe we can automatically infer the space from existing data?
Examples:
```python
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    extra_model_outputs={"mo": [1, 2, 3]},
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
    # The following is needed, but not relevant for this demo.
    observations=[0, 1, 2, 3],
    actions=[1, 2, 3],
    rewards=[1, 2, 3],
)
# Plain usage (`indices` arg only).
episode.get_extra_model_outputs("mo", -1)  # 3
episode.get_extra_model_outputs("mo", 1)  # 2
episode.get_extra_model_outputs("mo", [0, 2])  # [1, 3]
episode.get_extra_model_outputs("mo", [-1, 0])  # [3, 1]
episode.get_extra_model_outputs("mo", slice(None, 2))  # [1, 2]
episode.get_extra_model_outputs("mo", slice(-2, None))  # [2, 3]
# Using `fill=...` (requesting slices beyond the boundaries).
# TODO (sven): This would require a space being provided. Maybe we can
#  automatically infer the space from existing data?
# episode.get_extra_model_outputs("mo", slice(-5, -2), fill=0)  # [0, 0, 1]
# episode.get_extra_model_outputs("mo", slice(2, 5), fill=-1)  # [3, -1, -1]
```
- Returns:
The collected extra_model_outputs[`key`]: as a 0-axis batch if `indices` is a list (even one containing exactly one index) or a slice object; as a single item (no additional 0-axis) if `indices` is a single int.
- set_observations(*, new_data, at_indices: int | slice | List[int] | None = None, neg_index_as_lookback: bool = False) → None
Overwrites all or some of this Episode’s observations with the provided data.
Note that an episode’s observation data cannot be written to directly, as it is managed by an `InfiniteLookbackBuffer` object. Normally, individual, current observations are added to the episode either by calling `self.add_env_step` or, more directly (and manually), via `self.observations.append|extend()`. However, for certain postprocessing steps, the entirety (or a slice) of an episode’s observations might have to be rewritten, which is when `self.set_observations()` should be used.
- Parameters:
new_data – The new observation data to overwrite existing data with. This may be a list of individual observation(s) if this episode is not yet finalized. If this episode has already been finalized, this should be a (possibly complex) struct matching the observation space, with a batch size of its leaves exactly the size of the to-be-overwritten slice or segment (provided by `at_indices`).
at_indices – A single int is interpreted as one index to overwrite with `new_data` (which is expected to be a single observation). A list of ints is interpreted as a list of indices, all of which are overwritten with `new_data` (which is expected to be of the same size as `len(at_indices)`). A slice object is interpreted as a range of indices to be overwritten with `new_data` (which is expected to be of the same size as the provided slice). Negative indices are by default interpreted as “before the end” unless the `neg_index_as_lookback=True` option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer.
neg_index_as_lookback – If True, negative values in `at_indices` are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with observations = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (the ts=0 item is 7), will handle a call to `set_observations(new_data=individual_observation, at_indices=-1, neg_index_as_lookback=True)` by overwriting the value of 6 in the observations buffer with the provided `individual_observation`.
- Raises:
IndexError – If the provided
at_indices
do not match the size ofnew_data
.
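The two interpretations of negative indices can be made concrete with a minimal pure-Python sketch. It maps an episode-relative index to a position in a raw buffer whose first items form the lookback. The helper resolve_buffer_position is hypothetical and not part of RLlib. It only mirrors the neg_index_as_lookback semantics described above, using the same [4, 5, 6, 7, 8, 9] example.

```python
def resolve_buffer_position(
    idx: int, len_lookback: int, buffer_len: int, neg_index_as_lookback: bool = False
) -> int:
    """Hypothetical helper (not RLlib API): map an episode-relative index to a
    position in a raw buffer whose first `len_lookback` items are the lookback."""
    if idx >= 0 or neg_index_as_lookback:
        # ts=0 sits right after the lookback items, so with
        # neg_index_as_lookback=True, -1 lands on the last lookback item.
        pos = len_lookback + idx
    else:
        # Default: negative indices count back from the end of the buffer.
        pos = buffer_len + idx
    if not 0 <= pos < buffer_len:
        raise IndexError(f"Index {idx} is out of range for this buffer.")
    return pos


# Observations [4, 5, 6, 7, 8, 9]; [4, 5, 6] is the lookback and ts=0 is 7.
buffer = [4, 5, 6, 7, 8, 9]
print(buffer[resolve_buffer_position(-1, 3, len(buffer), neg_index_as_lookback=True)])  # 6
print(buffer[resolve_buffer_position(-1, 3, len(buffer))])  # 9
print(buffer[resolve_buffer_position(0, 3, len(buffer))])  # 7
```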
- set_actions(*, new_data, at_indices: int | slice | List[int] | None = None, neg_index_as_lookback: bool = False) → None
Overwrites all or some of this Episode’s actions with the provided data.
Note that an episode’s action data cannot be written to directly, as it is managed by an InfiniteLookbackBuffer object. Normally, individual, current actions are added to the episode either by calling self.add_env_step() or more directly (and manually) via self.actions.append|extend(). However, for certain postprocessing steps, the entirety (or a slice) of an episode’s actions might have to be rewritten, which is when self.set_actions() should be used.
- Parameters:
new_data – The new action data to overwrite existing data with. This may be a list of individual action(s) in case this episode is not finalized yet. In case this episode has already been finalized, this should be a (possibly complex) struct matching the action space, with a batch size of its leaves exactly the size of the to-be-overwritten slice or segment (provided by at_indices).
at_indices – A single int is interpreted as one index to overwrite with new_data (which is expected to be a single action). A list of ints is interpreted as a list of indices, all of which are overwritten with new_data (which is expected to be of the same size as len(at_indices)). A slice object is interpreted as a range of indices to be overwritten with new_data (which is expected to be of the same size as the provided slice). By default, negative indices are interpreted as “before the end” unless the neg_index_as_lookback=True option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer.
neg_index_as_lookback – If True, negative values in at_indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with actions = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will handle a call to set_actions(new_data=individual_action, at_indices=-1, neg_index_as_lookback=True) by overwriting the value of 6 in the actions buffer with the provided “individual_action”.
- Raises:
IndexError – If the provided at_indices do not match the size of new_data.
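The at_indices rules shared by these set_*() methods can be sketched on a plain Python list. The rules cover a single int, a list of ints, a slice, and a size check that raises IndexError. overwrite() below is a hypothetical helper, not part of RLlib. Treating at_indices=None as "the whole episode" is an assumption based on the "all or some" wording. The lookback buffer is simplified to a fixed prefix of the list.

```python
from typing import Any, List, Optional, Union


def overwrite(
    buffer: List[Any],
    new_data: Any,
    at_indices: Optional[Union[int, slice, List[int]]] = None,
    len_lookback: int = 0,
) -> None:
    """Hypothetical helper (not RLlib API) mimicking the documented at_indices
    rules on a list whose first `len_lookback` items are the lookback buffer."""
    episode_len = len(buffer) - len_lookback
    if at_indices is None:
        # Assumption: None means "overwrite the entire episode" (not the lookback).
        positions = list(range(episode_len))
    elif isinstance(at_indices, int):
        # A single int expects a single item.
        positions, new_data = [at_indices], [new_data]
    elif isinstance(at_indices, slice):
        # slice.indices() resolves negative bounds as "before the end".
        positions = list(range(*at_indices.indices(episode_len)))
    else:
        positions = list(at_indices)
    if len(positions) != len(new_data):
        raise IndexError(
            f"at_indices selects {len(positions)} item(s), but new_data has "
            f"{len(new_data)}."
        )
    for pos, value in zip(positions, new_data):
        # Default mode: negative ints count back from the end of the buffer.
        buffer[len_lookback + pos if pos >= 0 else len(buffer) + pos] = value


actions = [1, 2, 3, 4, 5]  # [1] is the lookback; the episode's actions are 2..5
overwrite(actions, 20, at_indices=0, len_lookback=1)
overwrite(actions, [40, 50], at_indices=slice(-2, None), len_lookback=1)
print(actions)  # [1, 20, 3, 40, 50]
```

Passing new_data=[7] with at_indices=[1, 2] would raise IndexError in this sketch, because the two sizes differ.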
- set_rewards(*, new_data, at_indices: int | slice | List[int] | None = None, neg_index_as_lookback: bool = False) → None
Overwrites all or some of this Episode’s rewards with the provided data.
Note that an episode’s reward data cannot be written to directly, as it is managed by an InfiniteLookbackBuffer object. Normally, individual, current rewards are added to the episode either by calling self.add_env_step() or more directly (and manually) via self.rewards.append|extend(). However, for certain postprocessing steps, the entirety (or a slice) of an episode’s rewards might have to be rewritten, which is when self.set_rewards() should be used.
- Parameters:
new_data – The new reward data to overwrite existing data with. This may be a list of individual reward(s) in case this episode is not finalized yet. In case this episode has already been finalized, this should be an np.ndarray with a length of exactly the size of the to-be-overwritten slice or segment (provided by at_indices).
at_indices – A single int is interpreted as one index to overwrite with new_data (which is expected to be a single reward). A list of ints is interpreted as a list of indices, all of which are overwritten with new_data (which is expected to be of the same size as len(at_indices)). A slice object is interpreted as a range of indices to be overwritten with new_data (which is expected to be of the same size as the provided slice). By default, negative indices are interpreted as “before the end” unless the neg_index_as_lookback=True option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer.
neg_index_as_lookback – If True, negative values in at_indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with rewards = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will handle a call to set_rewards(new_data=individual_reward, at_indices=-1, neg_index_as_lookback=True) by overwriting the value of 6 in the rewards buffer with the provided “individual_reward”.
- Raises:
IndexError – If the provided at_indices do not match the size of new_data.
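A typical postprocessing step that rewrites rewards is clipping. The following minimal NumPy sketch shows the two constraints described above: lookback rewards stay untouched, and new_data has exactly as many entries as the overwritten segment. The commented set_rewards() call is only an assumption about how this would map onto a real, finalized episode. It is not executed here; the sketch writes back into a plain array instead.

```python
import numpy as np

# Rewards of a finalized episode; the first `len_lookback` entries belong to
# the lookback buffer (i.e., to a previous chunk) and should stay untouched.
len_lookback = 2
rewards = np.array([5.0, -3.0, 0.5, 2.0, -4.0, 1.0])

# Postprocessing step: clip only this episode's own rewards to [-1, 1].
episode_rewards = rewards[len_lookback:]
clipped = np.clip(episode_rewards, -1.0, 1.0)

# new_data must have exactly as many entries as the overwritten segment.
assert clipped.shape == episode_rewards.shape

# On a real episode, something like the following call would go here
# (assumption, not executed in this sketch):
# episode.set_rewards(new_data=clipped, at_indices=slice(0, len(clipped)))
rewards[len_lookback:] = clipped
print(rewards.tolist())  # [5.0, -3.0, 0.5, 1.0, -1.0, 1.0]
```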
- set_extra_model_outputs(*, key, new_data, at_indices: int | slice | List[int] | None = None, neg_index_as_lookback: bool = False) → None
Overwrites all or some of this Episode’s extra model outputs with new_data.
Note that an episode’s extra_model_outputs data cannot be written to directly, as it is managed by an InfiniteLookbackBuffer object. Normally, individual, current extra_model_output values are added to the episode either by calling self.add_env_step() or more directly (and manually) via self.extra_model_outputs[key].append|extend(). However, for certain postprocessing steps, the entirety (or a slice) of an episode’s extra_model_outputs might have to be rewritten, or a new key (a new type of extra_model_outputs) must be inserted, which is when self.set_extra_model_outputs() should be used.
- Parameters:
key – The key within self.extra_model_outputs to override data on or to insert as a new key into self.extra_model_outputs.
new_data – The new data to overwrite existing data with. This may be a list of individual value(s) in case this episode is not finalized yet. In case this episode has already been finalized, this should be an np.ndarray with a length of exactly the size of the to-be-overwritten slice or segment (provided by at_indices).
at_indices – A single int is interpreted as one index to overwrite with new_data (which is expected to be a single value). A list of ints is interpreted as a list of indices, all of which are overwritten with new_data (which is expected to be of the same size as len(at_indices)). A slice object is interpreted as a range of indices to be overwritten with new_data (which is expected to be of the same size as the provided slice). By default, negative indices are interpreted as “before the end” unless the neg_index_as_lookback=True option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer.
neg_index_as_lookback – If True, negative values in at_indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode whose extra_model_outputs[key] = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will handle a call to set_extra_model_outputs(key=key, new_data=individual_value, at_indices=-1, neg_index_as_lookback=True) by overwriting the value of 6 in that key’s buffer with the provided “individual_value”.
- Raises:
IndexError – If the provided at_indices do not match the size of new_data.
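The two uses described above, overwriting an existing key and inserting a new kind of extra model output, can be sketched with a plain dict of per-timestep lists. set_extra_model_output() is a hypothetical helper, and the key names action_logp and vf_preds are illustrative only. The rule that each key must hold exactly one value per timestep is an assumption of this sketch, not something stated above.

```python
from typing import Any, Dict, List


def set_extra_model_output(
    outputs: Dict[str, List[Any]], key: str, new_data: List[Any], num_timesteps: int
) -> None:
    """Hypothetical helper (not RLlib API): overwrite or insert one key.

    Assumes each key holds exactly one value per timestep of the episode.
    """
    if len(new_data) != num_timesteps:
        raise IndexError(
            f"Expected {num_timesteps} value(s) for {key!r}, got {len(new_data)}."
        )
    # Existing keys are overwritten; unknown keys are inserted as a new kind
    # of extra model output.
    outputs[key] = list(new_data)


outputs = {"action_logp": [-0.1, -0.2, -0.3]}
# Overwrite an existing key ...
set_extra_model_output(outputs, "action_logp", [-0.5, -0.6, -0.7], num_timesteps=3)
# ... or insert a new one during postprocessing.
set_extra_model_output(outputs, "vf_preds", [1.0, 0.9, 0.8], num_timesteps=3)
print(sorted(outputs))  # ['action_logp', 'vf_preds']
```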
- add_temporary_timestep_data(key: str, data: Any) → None
Temporarily adds per-timestep data to self (until finalize() is called).
The given data is appended to a list (self._temporary_timestep_data), which is cleared upon calling self.finalize(). To get the thus-far accumulated temporary timestep data for a certain key, use the get_temporary_timestep_data API. Note that the size of the per-timestep list is NOT checked or validated against the other, non-temporary data in this episode (like observations).
- Parameters:
key – The key under which to find the list to append data to. If data is the first data to be added for this key, a new list is started.
data – The data item (representing a single timestep) to be stored.
- get_temporary_timestep_data(key: str) → List[Any]
Returns all temporarily stored data items (list) under the given key.
Note that all temporary timestep data is erased/cleared when calling self.finalize().
- Returns:
The current list storing temporary timestep data under key.
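The lifecycle of temporary timestep data can be modeled with a small stand-in class. Items are appended per key, no length validation is performed, and everything is cleared on finalize(). TemporaryTimestepDataSketch is hypothetical and not RLlib's class. Returning an empty list for an unknown key is an assumption of this sketch; the real method's behavior for missing keys is not described above.

```python
from typing import Any, Dict, List


class TemporaryTimestepDataSketch:
    """Hypothetical stand-in (not RLlib's class) for the temporary-data API."""

    def __init__(self) -> None:
        # One list of per-timestep items per key.
        self._temporary_timestep_data: Dict[str, List[Any]] = {}

    def add_temporary_timestep_data(self, key: str, data: Any) -> None:
        # Start a new list for an unseen key, then append one timestep's item.
        # The list length is not checked against observations or actions.
        self._temporary_timestep_data.setdefault(key, []).append(data)

    def get_temporary_timestep_data(self, key: str) -> List[Any]:
        # Assumption: unknown keys yield an empty list in this sketch.
        return list(self._temporary_timestep_data.get(key, []))

    def finalize(self) -> None:
        # Finalizing clears all temporary timestep data.
        self._temporary_timestep_data.clear()


episode = TemporaryTimestepDataSketch()
episode.add_temporary_timestep_data("q_values", 0.5)
episode.add_temporary_timestep_data("q_values", 0.7)
print(episode.get_temporary_timestep_data("q_values"))  # [0.5, 0.7]
episode.finalize()
print(episode.get_temporary_timestep_data("q_values"))  # []
```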
- slice(slice_: slice, *, len_lookback_buffer: int | None = None) → SingleAgentEpisode
Returns a slice of this episode with the given slice object.
For example, if self contains o0 (the reset observation), o1, o2, o3, and o4 and the actions a1, a2, a3, and a4 (len of self is 4), then a call to self.slice(slice(1, 3)) would return a new SingleAgentEpisode with observations o1, o2, and o3, and actions a2 and a3. Note that an episode always contains one more observation than it has actions (and rewards and extra model outputs), due to the initial observation received after an env reset.
```python
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.test_utils import check

# Generate a simple single-agent episode.
observations = [0, 1, 2, 3, 4, 5]
actions = [1, 2, 3, 4, 5]
rewards = [0.1, 0.2, 0.3, 0.4, 0.5]
episode = SingleAgentEpisode(
    observations=observations,
    actions=actions,
    rewards=rewards,
    len_lookback_buffer=0,  # all given data is part of the episode
)

# Square-bracket slicing (episode[...]) returns a slice via episode.slice().
slice_1 = episode[:1]
check(slice_1.observations, [0, 1])
check(slice_1.actions, [1])
check(slice_1.rewards, [0.1])

slice_2 = episode[-2:]
check(slice_2.observations, [3, 4, 5])
check(slice_2.actions, [4, 5])
check(slice_2.rewards, [0.4, 0.5])
```
- Parameters:
slice_ – The slice object to use for slicing. This should exclude the lookback buffer, which is automatically prepended to the returned slice.
len_lookback_buffer – If not None, forces the returned slice to try to have this number of timesteps in its lookback buffer (if available). If None (default), tries to make the returned slice’s lookback as large as the current lookback buffer of this episode (self).
- Returns:
The new SingleAgentEpisode representing the requested slice.
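The off-by-one relationship between observations and actions in a slice, and the optional lookback, can be illustrated with plain lists. slice_episode() is a hypothetical sketch, not RLlib's implementation. It assumes the source episode has no lookback buffer of its own and supports only contiguous slices. It reproduces the o0…o4 / a1…a4 example from the description above.

```python
from typing import List, Tuple


def slice_episode(
    observations: List[str], actions: List[str], slice_: slice, len_lookback: int = 0
) -> Tuple[List[str], List[str], int]:
    """Hypothetical sketch (not RLlib's implementation) of episode slicing.

    Assumes the source episode has no lookback buffer of its own and that
    len(observations) == len(actions) + 1 (the extra reset observation).
    """
    start, stop, step = slice_.indices(len(actions))
    if step != 1:
        raise ValueError("Only contiguous slices are supported in this sketch.")
    # Prepend up to `len_lookback` earlier timesteps, if available.
    lookback = min(len_lookback, start)
    # One more observation than actions: include the observation at `stop`.
    sliced_obs = observations[start - lookback : stop + 1]
    sliced_actions = actions[start - lookback : stop]
    return sliced_obs, sliced_actions, lookback


obs = ["o0", "o1", "o2", "o3", "o4"]
acts = ["a1", "a2", "a3", "a4"]
print(slice_episode(obs, acts, slice(1, 3)))
# (['o1', 'o2', 'o3'], ['a2', 'a3'], 0)
print(slice_episode(obs, acts, slice(1, 3), len_lookback=1))
# (['o0', 'o1', 'o2', 'o3'], ['a1', 'a2', 'a3'], 1)
```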
- get_data_dict()
Converts a SingleAgentEpisode into a data dict mapping str keys to data.
The keys used are: Columns.EPS_ID, T, OBS, INFOS, ACTIONS, REWARDS, TERMINATEDS, TRUNCATEDS, and those in self.extra_model_outputs.
- Returns:
A data dict mapping str keys to data records.
- get_sample_batch() → SampleBatch
Converts this SingleAgentEpisode into a SampleBatch.
- Returns:
A SampleBatch containing all of this episode’s data.
- get_return() → float
Calculates an episode’s return, excluding the lookback buffer’s rewards.
The return is computed as a simple sum, ignoring the discount factor. Note that if self is a continuation chunk (resulting from a call to self.cut()), the previous chunk’s rewards are NOT counted and thus NOT part of the returned reward sum.
- Returns:
The sum of rewards collected during this episode, excluding possible data inside the lookback buffer and excluding possible data in a predecessor chunk.
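Because the return excludes both the lookback buffer and any predecessor chunk, the return of a full episode that was cut into chunks has to be summed across chunks. chunk_return() is a hypothetical helper that reproduces only the documented arithmetic: an undiscounted sum that skips the lookback entries. The assumption that a continuation chunk's lookback holds the previous chunk's last reward is for illustration only.

```python
from typing import Sequence


def chunk_return(rewards: Sequence[float], len_lookback: int) -> float:
    """Hypothetical helper: undiscounted reward sum of one episode chunk,
    skipping the first `len_lookback` (lookback buffer) entries."""
    return float(sum(rewards[len_lookback:]))


# An episode cut into two chunks. Chunk 2 carries one lookback reward (1.0)
# that already belongs to chunk 1 and must not be counted twice.
chunk_1 = [1.0, 1.0]
chunk_2 = [1.0, 2.0, 3.0]  # lookback [1.0] + own rewards [2.0, 3.0]
print(chunk_return(chunk_1, len_lookback=0))  # 2.0
print(chunk_return(chunk_2, len_lookback=1))  # 5.0
# The full episode's return is the sum over all chunks.
print(chunk_return(chunk_1, 0) + chunk_return(chunk_2, 1))  # 7.0
```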
- env_steps() → int
Returns the number of environment steps.
Note that this episode instance could be only a chunk of an actual episode.
- Returns:
An integer that counts the number of environment steps this episode instance has seen.
- agent_steps() → int
Returns the number of agent steps.
Note that for a single-agent episode, these are identical to the environment steps.
- Returns:
An integer counting the number of agent steps executed during the time this episode instance records.