Note

Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The Ray Team plans to transition algorithms, example scripts, and documentation to the new code base, thereby incrementally replacing the “old API stack” (e.g., ModelV2, Policy, RolloutWorker) throughout the subsequent minor releases leading up to Ray 3.0.

Note, however, that so far only PPO (single- and multi-agent) and SAC (single-agent only) support the “new API stack”, and that even these continue to run with the old APIs by default. You can continue to use the existing custom (old stack) classes.

See here for more details on how to use the new API stack.

SingleAgentEpisode API

rllib.env.single_agent_episode.SingleAgentEpisode

class ray.rllib.env.single_agent_episode.SingleAgentEpisode(id_: str | None = None, *, observations: List[gymnasium.core.ObsType] | InfiniteLookbackBuffer | None = None, observation_space: gymnasium.Space | None = None, infos: List[Dict] | InfiniteLookbackBuffer | None = None, actions: List[gymnasium.core.ActType] | InfiniteLookbackBuffer | None = None, action_space: gymnasium.Space | None = None, rewards: List[SupportsFloat] | InfiniteLookbackBuffer | None = None, terminated: bool = False, truncated: bool = False, extra_model_outputs: Dict[str, Any] | None = None, t_started: int | None = None, len_lookback_buffer: int | str = 'auto', agent_id: Any | None = None, module_id: str | None = None, multi_agent_episode_id: int | None = None)

A class representing RL environment episodes for individual agents.

SingleAgentEpisode stores observations, info dicts, actions, rewards, and all module outputs (e.g. state outs, action logp, etc.) for an individual agent within some single-agent or multi-agent environment. The two main APIs to add data to an ongoing episode are the add_env_reset() and add_env_step() methods, which should be called with the outputs of the respective gym.Env API calls: env.reset() and env.step().

A SingleAgentEpisode might also represent only a chunk of an episode. This is useful when partial (non-complete episode) sampling is performed and the collected episode data has to be returned before the actual gym.Env episode has finished (see SingleAgentEpisode.cut()). To maintain visibility into past experiences within such a “cut” episode, SingleAgentEpisode instances can have a “lookback buffer” of n timesteps at their beginning (left side). This buffer exists solely for compiling extra data (e.g. “prev. reward”) and is not considered part of the finished/packaged episode, because the data in the lookback buffer already belongs to a previous episode chunk.

Getter methods, such as get_observations(), collect different types of data from the episode at individual time indices or time ranges, including the “lookback buffer” range described above. For example, to extract the last 4 rewards of an ongoing episode, call self.get_rewards(slice(-4, None)) or self.rewards[-4:]. This works even if the ongoing SingleAgentEpisode is a continuation chunk of a much earlier started episode, as long as its lookback buffer is large enough.
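To make these indexing rules concrete, here is a minimal pure-Python sketch of how an index is resolved against stored data whose first items form the lookback buffer. ToyLookbackBuffer and its get() method are hypothetical stand-ins written for illustration. They are not RLlib's InfiniteLookbackBuffer. The clipping behavior when no fill value is given is an assumption modeled on plain Python slicing.

```python
from typing import Any, List, Optional, Union

Index = Union[int, slice, List[int], None]


class ToyLookbackBuffer:
    """Hypothetical stand-in for one data column of an episode."""

    def __init__(self, data: List[Any], lookback: int = 0):
        self.data = list(data)  # lookback items first, then episode items
        self.lookback = lookback

    def __len__(self) -> int:
        # Like len(episode): lookback items are not counted.
        return len(self.data) - self.lookback

    def _pos(self, i: int, neg_index_as_lookback: bool) -> int:
        # Translate an episode index into a position within self.data.
        if i < 0 and not neg_index_as_lookback:
            return len(self.data) + i  # count back from the end
        return self.lookback + i  # ts=0 sits right after the lookback items

    def _at(self, pos: int, fill: Optional[Any]) -> Any:
        if 0 <= pos < len(self.data):
            return self.data[pos]
        if fill is None:
            raise IndexError(f"position {pos} is out of range")
        return fill

    def get(self, indices: Index = None, *, neg_index_as_lookback: bool = False,
            fill: Optional[Any] = None) -> Any:
        if indices is None:
            return self.data[self.lookback:]  # ts=0 to the end
        if isinstance(indices, int):
            return self._at(self._pos(indices, neg_index_as_lookback), fill)
        if isinstance(indices, list):
            return [self._at(self._pos(i, neg_index_as_lookback), fill)
                    for i in indices]
        # Slice: resolve both boundaries, then walk the positions in between.
        if indices.start is None:
            start = self.lookback
        else:
            start = self._pos(indices.start, neg_index_as_lookback)
        if indices.stop is None:
            stop = len(self.data)
        else:
            stop = self._pos(indices.stop, neg_index_as_lookback)
        positions = list(range(start, stop))
        if fill is None:
            # Without fill, this toy clips to the stored data (assumption).
            positions = [p for p in positions if 0 <= p < len(self.data)]
        return [self._at(p, fill) for p in positions]
```

Using the data from the parameter descriptions below ([4, 5, 6, 7, 8, 9] with a 3-item lookback), get(-1) returns 9, while get(-1, neg_index_as_lookback=True) returns 6.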

Examples:

import gymnasium as gym
import numpy as np

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

# Construct a new episode (without any data in it yet).
episode = SingleAgentEpisode()
assert len(episode) == 0

# Fill the episode with some data (5 timesteps).
env = gym.make("CartPole-v1")
obs, infos = env.reset()
episode.add_env_reset(obs, infos)

# Even with the initial obs/infos, the episode is still considered len=0.
assert len(episode) == 0
for _ in range(5):
    action = env.action_space.sample()
    obs, reward, term, trunc, infos = env.step(action)
    episode.add_env_step(
        observation=obs,
        action=action,
        reward=reward,
        terminated=term,
        truncated=trunc,
        infos=infos,
    )
assert len(episode) == 5

# We can now access information from the episode via the getter APIs.

# Get the last 3 rewards (in a batch of size 3).
episode.get_rewards(slice(-3, None))  # same as `episode.rewards[-3:]`

# Get the most recent action (single item, not batched).
# This works regardless of the action space or whether the episode has
# been finalized or not (see below).
episode.get_actions(-1)  # same as episode.actions[-1]

# Looking back from ts=1, get the previous 4 rewards AND fill with 0.0
# in case we go over the beginning (ts=0). So we would expect
# [0.0, 0.0, 0.0, r0] to be returned here, where r0 is the very first received
# reward in the episode:
episode.get_rewards(slice(-3, 1), neg_index_as_lookback=True, fill=0.0)

# Note the use of fill=0.0 here (fill everything that's out of range with this
# value) AND the argument `neg_index_as_lookback=True`, which interprets
# negative indices as being left of ts=0 (e.g. -1 being the timestep before
# ts=0).

# Assuming we had a complex action space (nested gym.spaces.Dict) with one or
# more elements being Discrete or MultiDiscrete spaces:
# 1) The `fill=...` argument would still work, filling all spaces (Boxes,
# Discrete) with that provided value.
# 2) Setting the flag `one_hot_discrete=True` would convert those discrete
# sub-components automatically into one-hot (or multi-one-hot) tensors.
# This simplifies the task of having to provide the previous 4 (nested and
# partially discrete/multi-discrete) actions for each timestep within a training
# batch, thereby filling timesteps before the episode started with 0.0s and
# one-hot'ing the discrete/multi-discrete components in these actions:
episode = SingleAgentEpisode(action_space=gym.spaces.Dict({
    "a": gym.spaces.Discrete(3),
    "b": gym.spaces.MultiDiscrete([2, 3]),
    "c": gym.spaces.Box(-1.0, 1.0, (2,)),
}))

# ... fill episode with data ...
episode.add_env_reset(observation=0)
# ... from a few steps.
episode.add_env_step(
    observation=1,
    action={"a":0, "b":np.array([1, 2]), "c":np.array([.5, -.5], np.float32)},
    reward=1.0,
)

# In your connector
prev_4_a = []
# Note here that len(episode) does NOT include the lookback buffer.
for ts in range(len(episode)):
    prev_4_a.append(
        episode.get_actions(
            indices=slice(ts - 4, ts),
            # Make sure negative indices are interpreted as
            # "into lookback buffer"
            neg_index_as_lookback=True,
            # Zero-out everything even further before the lookback buffer.
            fill=0.0,
            # Take care of discrete components (get ready as NN input).
            one_hot_discrete=True,
        )
    )

# Finally, convert from list of batch items to a struct (same as action space)
# of batched (numpy) arrays, in which all leaves have B==len(prev_4_a).
from ray.rllib.utils.spaces.space_utils import batch

prev_4_actions_col = batch(prev_4_a)

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

add_env_reset(observation: gymnasium.core.ObsType, infos: Dict | None = None) -> None

Adds the initial data (after an env.reset()) to the episode.

This data consists of initial observations and initial infos.

Parameters:
  • observation – The initial observation returned by env.reset().

  • infos – An (optional) info dict returned by env.reset().

add_env_step(observation: gymnasium.core.ObsType, action: gymnasium.core.ActType, reward: SupportsFloat, infos: Dict[str, Any] | None = None, *, terminated: bool = False, truncated: bool = False, extra_model_outputs: Dict[str, Any] | None = None) -> None

Adds results of an env.step() call (including the action) to this episode.

This data consists of an observation and info dict, an action, a reward, terminated/truncated flags, and extra model outputs (e.g. action probabilities or RNN internal state outputs).

Parameters:
  • observation – The next observation received from the environment after taking action.

  • action – The last action used by the agent during the call to env.step().

  • reward – The last reward received by the agent after taking action.

  • infos – The last info received from the environment after taking action.

  • terminated – A boolean indicating whether the environment has been terminated (after taking action).

  • truncated – A boolean indicating whether the environment has been truncated (after taking action).

  • extra_model_outputs – The last timestep’s specific model outputs. These are normally outputs of an RLModule that were computed along with action, e.g. action_logp or action_dist_inputs.

validate() -> None

Validates the episode’s data.

This function ensures that the data stored to a SingleAgentEpisode is in order (e.g. that the correct number of observations, actions, rewards are there).
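As a rough illustration of this kind of bookkeeping, the sketch below checks the counts implied by the examples on this page. It expects one info dict per observation, one reward per action, and one more observation than actions, where the extra observation comes from env.reset(). toy_validate() is a hypothetical helper. RLlib's actual checks may differ or be more extensive.

```python
from typing import Any, Dict, List


def toy_validate(observations: List[Any], infos: List[Dict],
                 actions: List[Any], rewards: List[float]) -> None:
    """Hypothetical consistency check in the spirit of validate()."""
    if len(observations) != len(infos):
        raise ValueError("expected exactly one info dict per observation")
    if len(actions) != len(rewards):
        raise ValueError("expected exactly one reward per action")
    if not observations:
        # Nothing was reset yet, so no steps may have been recorded.
        if actions:
            raise ValueError("steps recorded without an initial observation")
        return
    # env.reset() contributes one observation before the first step.
    if len(observations) != len(actions) + 1:
        raise ValueError("expected len(observations) == len(actions) + 1")
```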

property is_finalized: bool

True if the data in this episode is already stored as numpy arrays.

property is_done: bool

Whether the episode is actually done (terminated or truncated).

A done episode cannot be continued via self.add_env_step(), concatenated on its right side with another episode chunk, or succeeded via self.cut().

finalize() -> SingleAgentEpisode

Converts this Episode’s list attributes to numpy arrays.

In particular, this means that this episode’s lists of (possibly complex) data (e.g. if we have a dict obs space) are converted to (possibly complex) structs whose leaves are numpy arrays. Each of these leaf numpy arrays has the same length (batch dimension) as the original list.

Note that Columns.INFOS are NEVER numpy’ized and remain a list (normally, a list of the original, env-returned dicts). This is due to the heterogeneous nature of the infos returned by envs, which would make it unwieldy to convert this information to numpy arrays.

After calling this method, no further data may be added to this episode via the self.add_env_step() method.
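The following pure-Python sketch illustrates the list-to-struct conversion described above. Plain lists stand in for the numpy leaf arrays that finalize() actually produces. toy_batch() is a hypothetical helper, similar in spirit to the batch() utility used in the class example above. The infos are deliberately left as a list of dicts.

```python
from typing import Any, List


def toy_batch(items: List[Any]) -> Any:
    """Turn a list of (possibly nested-dict) items into one struct of lists.

    Hypothetical helper: the real finalize() produces numpy leaves instead.
    """
    if items and isinstance(items[0], dict):
        # Recurse per key so that every leaf holds one entry per timestep.
        return {k: toy_batch([item[k] for item in items]) for k in items[0]}
    return list(items)


# Dict observations, one per timestep (before "finalizing").
observations = [{"pos": 0.1, "id": 3}, {"pos": 0.2, "id": 1}, {"pos": 0.3, "id": 2}]
# Heterogeneous infos: left untouched as a list.
infos = [{"a": 1}, {}, {"b": "x"}]

finalized = {"obs": toy_batch(observations), "infos": infos}
```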

Examples:

import numpy as np

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    observations=[0, 1, 2, 3],
    actions=[1, 2, 3],
    rewards=[1, 2, 3],
    # Note: terminated/truncated have nothing to do with an episode
    # being `finalized` or not (via the `self.finalize()` method)!
    terminated=False,
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
)
# Episode has not been finalized (numpy'ized) yet.
assert not episode.is_finalized
# We are still operating on lists.
assert episode.get_observations([1]) == [1]
assert episode.get_observations(slice(None, 2)) == [0, 1]
# We can still add data (and even add the terminated=True flag).
episode.add_env_step(
    observation=4,
    action=4,
    reward=4,
    terminated=True,
)
# Still NOT finalized.
assert not episode.is_finalized

# Let's finalize the episode.
episode.finalize()
assert episode.is_finalized

# We cannot add data anymore. The following would crash.
# episode.add_env_step(observation=5, action=5, reward=5)

# Everything is now numpy arrays (with 0-axis of size
# B=[len of requested slice]).
assert isinstance(episode.get_observations([1]), np.ndarray)  # B=1
assert isinstance(episode.actions[0:2], np.ndarray)  # B=2
assert isinstance(episode.rewards[1:4], np.ndarray)  # B=3
Returns:

This SingleAgentEpisode object with the converted numpy data.

concat_episode(other: SingleAgentEpisode) -> None

Adds the given other SingleAgentEpisode to the right side of self.

In order for this to work, both chunks (self and other) must fit together. This is checked via the IDs (must be identical), the timestep counters (self.t must equal other.t_started), and the observations/infos at the concatenation boundary. Also, self.is_done must not be True, meaning self.is_terminated and self.is_truncated must both be False.
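Below is a minimal sketch of these compatibility checks. It is written against a hypothetical ToyChunk dataclass rather than the real SingleAgentEpisode. It assumes that other's first observation duplicates self's last observation, so that observation is dropped on concatenation. This mirrors the boundary check described above but is not taken from RLlib's source.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class ToyChunk:
    """Hypothetical, stripped-down episode chunk (not RLlib's class)."""

    id_: str
    t_started: int
    observations: List[Any]  # always len(actions) + 1 entries
    actions: List[Any] = field(default_factory=list)
    rewards: List[float] = field(default_factory=list)
    is_terminated: bool = False
    is_truncated: bool = False

    @property
    def t(self) -> int:
        # Global timestep reached at the end of this chunk.
        return self.t_started + len(self.actions)

    @property
    def is_done(self) -> bool:
        return self.is_terminated or self.is_truncated

    def concat(self, other: "ToyChunk") -> None:
        # The checks described above, in order.
        if other.id_ != self.id_:
            raise ValueError("chunks belong to different episodes")
        if self.is_done:
            raise ValueError("cannot extend a done chunk")
        if self.t != other.t_started:
            raise ValueError("timesteps do not line up")
        if self.observations[-1] != other.observations[0]:
            raise ValueError("boundary observations differ")
        # other's first observation duplicates self's last one; skip it.
        self.observations.extend(other.observations[1:])
        self.actions.extend(other.actions)
        self.rewards.extend(other.rewards)
        self.is_terminated = other.is_terminated
        self.is_truncated = other.is_truncated
```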

Parameters:

other – The other SingleAgentEpisode to be concatenated to this one.

Returns:

None. The concatenation happens in place: afterward, self contains the data from both episodes (self and other).

cut(len_lookback_buffer: int = 0) -> SingleAgentEpisode

Returns a successor episode chunk (of len=0) continuing from this Episode.

The successor will have the same ID as self. If no lookback buffer is requested (len_lookback_buffer=0), the successor’s observations will be the last observation(s) of self, and its length will therefore be 0 (no further steps taken yet). If len_lookback_buffer > 0, the returned successor will have len_lookback_buffer observations (and actions, rewards, etc.) taken from the right side (end) of self. For example, if len_lookback_buffer=2, the returned successor’s lookback buffer actions will be identical to self.actions[-2:].

This method is useful if you would like to stop building an episode chunk (because you have to return it from somewhere) but still want a new episode instance to continue building the actual gym.Env episode at a later time. Via the len_lookback_buffer argument, the continuing chunk (successor) can still “look back” into this predecessor episode’s data (at least to some extent, depending on the value of len_lookback_buffer).

Parameters:

len_lookback_buffer – The number of timesteps to carry over into the new chunk as a “lookback buffer”. A lookback buffer is additional data on the left side of the actual episode data, kept for visibility purposes but not actually part of the new chunk. For example, if self ends in actions 5, 6, 7, and 8, and we call self.cut(len_lookback_buffer=2), the returned chunk will have actions 7 and 8 already in it, but still t_started == t == 8 (not 7!) and a length of 0. If self does not hold enough data yet to fulfill the len_lookback_buffer request, the value of len_lookback_buffer is automatically adjusted (lowered).
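To illustrate the lookback carry-over and the automatic lowering of len_lookback_buffer, here is a sketch that computes what a successor chunk would start out with. toy_cut() and its returned dict are hypothetical. The sketch also assumes that the successor carries n + 1 observations: n in the lookback buffer plus the current one at ts=0.

```python
from typing import Any, Dict, List


def toy_cut(observations: List[Any], actions: List[Any], t: int,
            len_lookback_buffer: int = 0) -> Dict[str, Any]:
    """Hypothetical sketch of the data a successor chunk starts out with."""
    # Lower the request if self does not hold enough steps yet.
    n = min(len_lookback_buffer, len(actions))
    return {
        # The last n actions become the successor's lookback actions.
        # (Slicing from len(...) - n avoids the actions[-0:] pitfall.)
        "lookback_actions": actions[len(actions) - n:],
        # n lookback observations plus the last observation (the new ts=0).
        "observations": observations[len(observations) - (n + 1):],
        "len_lookback_buffer": n,
        # No new steps have been taken yet: t_started == t, length 0.
        "t_started": t,
        "t": t,
    }
```

With actions [5, 6, 7, 8] and len_lookback_buffer=2, the successor starts with lookback actions [7, 8], matching the example in the parameter description above.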

Returns:

The successor Episode chunk of this one with the same ID and state and the only observation being the last observation in self.

get_observations(indices: int | slice | List[int] | None = None, *, neg_index_as_lookback: bool = False, fill: Any | None = None, one_hot_discrete: bool = False) -> Any

Returns individual observations or batched ranges thereof from this episode.

Parameters:
  • indices – A single int is interpreted as an index, from which to return the individual observation stored at this index. A list of ints is interpreted as a list of indices from which to gather individual observations in a batch of size len(indices). A slice object is interpreted as a range of observations to be returned. By default, negative indices are interpreted as “before the end”, unless the neg_index_as_lookback=True option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. If None, all observations (from ts=0 to the end) are returned.

  • neg_index_as_lookback – If True, negative values in indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with observations [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will respond to get_observations(-1, neg_index_as_lookback=True) with 6 and to get_observations(slice(-2, 1), neg_index_as_lookback=True) with [5, 6, 7].

  • fill – An optional value to use for filling up the returned results at the boundaries. This filling only happens if the requested index range’s start/stop boundaries exceed the episode’s boundaries (including the lookback buffer on the left side). This is handy if users don’t want to worry about reaching such boundaries and want to zero-pad. For example, an episode with observations [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning observations 10 and 11 are part of the lookback buffer) will respond to get_observations(slice(-7, -2), fill=0.0) with [0.0, 0.0, 10, 11, 12].

  • one_hot_discrete – If True, will return one-hot vectors (instead of int-values) for those sub-components of a (possibly complex) observation space that are Discrete or MultiDiscrete. Note that if fill=0 and the requested indices are out of the range of our data, the returned one-hot vectors will actually be zero-hot (all slots zero).
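The difference between a filled slot and a real value of 0 is easy to miss. The hypothetical helper one_hot_window() below is not an RLlib function. Its positions index the stored data directly, with no lookback buffer. It shows why a zero-filled, out-of-range index produces [0, 0, 0, 0], while an actual observation of 0 produces [1, 0, 0, 0].

```python
from typing import List


def one_hot_window(values: List[int], n: int, start: int, stop: int) -> List[List[int]]:
    """One-hot encode the Discrete(n) values at positions start..stop-1.

    Positions outside the stored data are treated as if fill=0 had been
    requested, which yields an all-zero ("zero-hot") row.
    """
    rows = []
    for i in range(start, stop):
        row = [0] * n
        if 0 <= i < len(values):
            row[values[i]] = 1  # regular one-hot row for a stored value
        rows.append(row)  # out-of-range rows stay all zeros
    return rows
```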

Examples:

import gymnasium as gym

from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.test_utils import check

episode = SingleAgentEpisode(
    # Discrete(4) observations (ints between 0 and 4 (excl.))
    observation_space=gym.spaces.Discrete(4),
    observations=[0, 1, 2, 3],
    actions=[1, 2, 3], rewards=[1, 2, 3],  # <- not relevant for this demo
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
)
# Plain usage (`indices` arg only).
check(episode.get_observations(-1), 3)
check(episode.get_observations(0), 0)
check(episode.get_observations([0, 2]), [0, 2])
check(episode.get_observations([-1, 0]), [3, 0])
check(episode.get_observations(slice(None, 2)), [0, 1])
check(episode.get_observations(slice(-2, None)), [2, 3])
# Using `fill=...` (requesting slices beyond the boundaries).
check(episode.get_observations(slice(-6, -2), fill=-9), [-9, -9, 0, 1])
check(episode.get_observations(slice(2, 5), fill=-7), [2, 3, -7])
# Using `one_hot_discrete=True`.
check(episode.get_observations(2, one_hot_discrete=True), [0, 0, 1, 0])
check(episode.get_observations(3, one_hot_discrete=True), [0, 0, 0, 1])
check(episode.get_observations(
    slice(0, 3),
    one_hot_discrete=True,
), [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]])
# Special case: Using `fill=0.0` AND `one_hot_discrete=True`.
check(episode.get_observations(
    -1,
    neg_index_as_lookback=True,  # -1 means one left of ts=0
    fill=0.0,
    one_hot_discrete=True,
), [0, 0, 0, 0])  # <- all 0s one-hot tensor (note difference to [1 0 0 0]!)
Returns:

The collected observations: a batch along the 0-axis if indices is a list (even a list containing exactly one index) or a slice object; a single item (no additional 0-axis) if indices is a single int.

get_infos(indices: int | slice | List[int] | None = None, *, neg_index_as_lookback: bool = False, fill: Any | None = None) -> Any

Returns individual info dicts or list (ranges) thereof from this episode.

Parameters:
  • indices – A single int is interpreted as an index, from which to return the individual info dict stored at this index. A list of ints is interpreted as a list of indices from which to gather individual info dicts in a list of size len(indices). A slice object is interpreted as a range of info dicts to be returned. By default, negative indices are interpreted as “before the end”, unless the neg_index_as_lookback=True option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. If None, all infos (from ts=0 to the end) are returned.

  • neg_index_as_lookback – If True, negative values in indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with infos [{"l":4}, {"l":5}, {"l":6}, {"a":7}, {"b":8}, {"c":9}], where the first 3 items are the lookback buffer (ts=0 item is {"a":7}), will respond to get_infos(-1, neg_index_as_lookback=True) with {"l":6} and to get_infos(slice(-2, 1), neg_index_as_lookback=True) with [{"l":5}, {"l":6}, {"a":7}].

  • fill – An optional value to use for filling up the returned results at the boundaries. This filling only happens if the requested index range’s start/stop boundaries exceed the episode’s boundaries (including the lookback buffer on the left side). This is handy if users don’t want to worry about reaching such boundaries and want to auto-fill. For example, an episode with infos [{"l":10}, {"l":11}, {"a":12}, {"b":13}, {"c":14}] and a lookback buffer size of 2 (meaning infos {"l":10} and {"l":11} are part of the lookback buffer) will respond to get_infos(slice(-7, -2), fill={"o": 0.0}) with [{"o":0.0}, {"o":0.0}, {"l":10}, {"l":11}, {"a":12}].

Examples:

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    infos=[{"a":0}, {"b":1}, {"c":2}, {"d":3}],
    # The following is needed, but not relevant for this demo.
    observations=[0, 1, 2, 3], actions=[1, 2, 3], rewards=[1, 2, 3],
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
)
# Plain usage (`indices` arg only).
episode.get_infos(-1)  # {"d":3}
episode.get_infos(0)  # {"a":0}
episode.get_infos([0, 2])  # [{"a":0},{"c":2}]
episode.get_infos([-1, 0])  # [{"d":3},{"a":0}]
episode.get_infos(slice(None, 2))  # [{"a":0},{"b":1}]
episode.get_infos(slice(-2, None))  # [{"c":2},{"d":3}]
# Using `fill=...` (requesting slices beyond the boundaries).
# TODO (sven): This would require a space being provided. Maybe we can
#  skip this check for infos, which don't have a space anyways.
# episode.get_infos(slice(-5, -3), fill={"o":-1})  # [{"o":-1},{"a":0}]
# episode.get_infos(slice(3, 5), fill={"o":-2})  # [{"d":3},{"o":-2}]
Returns:

The collected info dicts: a batch along the 0-axis if indices is a list (even a list containing exactly one index) or a slice object; a single item (no additional 0-axis) if indices is a single int.

get_actions(indices: int | slice | List[int] | None = None, *, neg_index_as_lookback: bool = False, fill: Any | None = None, one_hot_discrete: bool = False) -> Any

Returns individual actions or batched ranges thereof from this episode.

Parameters:
  • indices – A single int is interpreted as an index, from which to return the individual action stored at this index. A list of ints is interpreted as a list of indices from which to gather individual actions in a batch of size len(indices). A slice object is interpreted as a range of actions to be returned. By default, negative indices are interpreted as “before the end”, unless the neg_index_as_lookback=True option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. If None, all actions (from ts=0 to the end) are returned.

  • neg_index_as_lookback – If True, negative values in indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with actions [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will respond to get_actions(-1, neg_index_as_lookback=True) with 6 and to get_actions(slice(-2, 1), neg_index_as_lookback=True) with [5, 6, 7].

  • fill – An optional value to use for filling up the returned results at the boundaries. This filling only happens if the requested index range’s start/stop boundaries exceed the episode’s boundaries (including the lookback buffer on the left side). This is handy if users don’t want to worry about reaching such boundaries and want to zero-pad. For example, an episode with actions [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning actions 10 and 11 are part of the lookback buffer) will respond to get_actions(slice(-7, -2), fill=0.0) with [0.0, 0.0, 10, 11, 12].

  • one_hot_discrete – If True, will return one-hot vectors (instead of int-values) for those sub-components of a (possibly complex) action space that are Discrete or MultiDiscrete. Note that if fill=0 and the requested indices are out of the range of our data, the returned one-hot vectors will actually be zero-hot (all slots zero).

Examples:

import gymnasium as gym
from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    # Discrete(4) actions (ints between 0 and 4 (excl.))
    action_space=gym.spaces.Discrete(4),
    actions=[1, 2, 3],
    observations=[0, 1, 2, 3], rewards=[1, 2, 3],  # <- not relevant here
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
)
# Plain usage (`indices` arg only).
episode.get_actions(-1)  # 3
episode.get_actions(0)  # 1
episode.get_actions([0, 2])  # [1, 3]
episode.get_actions([-1, 0])  # [3, 1]
episode.get_actions(slice(None, 2))  # [1, 2]
episode.get_actions(slice(-2, None))  # [2, 3]
# Using `fill=...` (requesting slices beyond the boundaries).
episode.get_actions(slice(-5, -2), fill=-9)  # [-9, -9, 1]
episode.get_actions(slice(1, 5), fill=-7)  # [2, 3, -7, -7]
# Using `one_hot_discrete=True`.
episode.get_actions(1, one_hot_discrete=True)  # [0 0 1 0] (action=2)
episode.get_actions(2, one_hot_discrete=True)  # [0 0 0 1] (action=3)
episode.get_actions(
    slice(0, 2),
    one_hot_discrete=True,
)  # [[0 1 0 0], [0 0 1 0]] (actions=1 and 2)
# Special case: Using `fill=0.0` AND `one_hot_discrete=True`.
episode.get_actions(
    -1,
    neg_index_as_lookback=True,  # -1 means one left of ts=0
    fill=0.0,
    one_hot_discrete=True,
)  # [0 0 0 0]  <- all 0s one-hot tensor (note difference to [1 0 0 0]!)
Returns:

The collected actions: a batch along the 0-axis if indices is a list (even a list containing exactly one index) or a slice object; a single item (no additional 0-axis) if indices is a single int.

get_rewards(indices: int | slice | List[int] | None = None, *, neg_index_as_lookback: bool = False, fill: float | None = None) -> Any

Returns individual rewards or batched ranges thereof from this episode.

Parameters:
  • indices – A single int is interpreted as an index, from which to return the individual reward stored at this index. A list of ints is interpreted as a list of indices from which to gather individual rewards in a batch of size len(indices). A slice object is interpreted as a range of rewards to be returned. By default, negative indices are interpreted as “before the end”, unless the neg_index_as_lookback=True option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. If None, all rewards (from ts=0 to the end) are returned.

  • neg_index_as_lookback – If True, negative values in indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with rewards [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will respond to get_rewards(-1, neg_index_as_lookback=True) with 6 and to get_rewards(slice(-2, 1), neg_index_as_lookback=True) with [5, 6, 7].

  • fill – An optional float value to use for filling up the returned results at the boundaries. This filling only happens if the requested index range’s start/stop boundaries exceed the episode’s boundaries (including the lookback buffer on the left side). This is handy if users don’t want to worry about reaching such boundaries and want to zero-pad. For example, an episode with rewards [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning rewards 10 and 11 are part of the lookback buffer) will respond to get_rewards(slice(-7, -2), fill=0.0) with [0.0, 0.0, 10, 11, 12].

Examples:

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    rewards=[1.0, 2.0, 3.0],
    observations=[0, 1, 2, 3], actions=[1, 2, 3],  # <- not relevant here
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
)
# Plain usage (`indices` arg only).
episode.get_rewards(-1)  # 3.0
episode.get_rewards(0)  # 1.0
episode.get_rewards([0, 2])  # [1.0, 3.0]
episode.get_rewards([-1, 0])  # [3.0, 1.0]
episode.get_rewards(slice(None, 2))  # [1.0, 2.0]
episode.get_rewards(slice(-2, None))  # [2.0, 3.0]
# Using `fill=...` (requesting slices beyond the boundaries).
episode.get_rewards(slice(-5, -2), fill=0.0)  # [0.0, 0.0, 1.0]
episode.get_rewards(slice(1, 5), fill=0.0)  # [2.0, 3.0, 0.0, 0.0]
Returns:

The collected rewards: a batch along the 0-axis if indices is a list (even a list containing exactly one index) or a slice object; a single item (no additional 0-axis) if indices is a single int.

get_extra_model_outputs(key: str, indices: int | slice | List[int] | None = None, *, neg_index_as_lookback: bool = False, fill: Any | None = None) -> Any

Returns extra model outputs (under given key) from this episode.

Parameters:
  • key – The key within self.extra_model_outputs to extract data for.

  • indices – A single int is interpreted as an index, from which to return an individual extra model output stored under key at index. A list of ints is interpreted as a list of indices from which to gather individual extra model outputs in a batch of size len(indices). A slice object is interpreted as a range of extra model outputs to be returned. By default, negative indices are interpreted as “before the end”, unless the neg_index_as_lookback=True option is used, in which case negative indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. If None, all extra model outputs (from ts=0 to the end) are returned.

  • neg_index_as_lookback – If True, negative values in indices are interpreted as “before ts=0”, meaning going back into the lookback buffer. For example, an episode with extra_model_outputs["a"] = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will respond to get_extra_model_outputs("a", -1, neg_index_as_lookback=True) with 6 and to get_extra_model_outputs("a", slice(-2, 1), neg_index_as_lookback=True) with [5, 6, 7].

  • fill – An optional value to use for filling up the returned results at the boundaries. This filling only happens if the requested index range’s start/stop boundaries exceed the episode’s boundaries (including the lookback buffer on the left side). This is handy if users don’t want to worry about reaching such boundaries and want to zero-pad. For example, an episode with extra_model_outputs["b"] = [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning 10 and 11 are part of the lookback buffer) will respond to get_extra_model_outputs("b", slice(-7, -2), fill=0.0) with [0.0, 0.0, 10, 11, 12]. TODO (sven): This would require a space being provided. Maybe we can automatically infer the space from existing data?

Examples:

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

episode = SingleAgentEpisode(
    extra_model_outputs={"mo": [1, 2, 3]},
    len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
    # The following is needed, but not relevant for this demo.
    observations=[0, 1, 2, 3], actions=[1, 2, 3], rewards=[1, 2, 3],
)

# Plain usage (`indices` arg only).
episode.get_extra_model_outputs("mo", -1)  # 3
episode.get_extra_model_outputs("mo", 1)  # 2
episode.get_extra_model_outputs("mo", [0, 2])  # [1, 3]
episode.get_extra_model_outputs("mo", [-1, 0])  # [3, 1]
episode.get_extra_model_outputs("mo", slice(None, 2))  # [1, 2]
episode.get_extra_model_outputs("mo", slice(-2, None))  # [2, 3]
# Using `fill=...` (requesting slices beyond the boundaries).
# TODO (sven): This would require a space being provided. Maybe we can
#  automatically infer the space from existing data?
# episode.get_extra_model_outputs("mo", slice(-5, -2), fill=0)  # [0, 0, 1]
# episode.get_extra_model_outputs("mo", slice(2, 5), fill=-1)  # [3, -1, -1]
Returns:

The collected extra_model_outputs[key]: a batch along the 0-axis if indices is a list (even a list containing exactly one index) or a slice object; a single item (no additional 0-axis) if indices is a single int.

set_observations(*, new_data, at_indices: int | slice | List[int] | None = None, neg_index_as_lookback: bool = False) -> None

Overwrites all or some of this Episode’s observations with the provided data.

Note that an episode’s observation data cannot be written to directly, as it is managed by an InfiniteLookbackBuffer object. Normally, individual, current observations are added to the episode either by calling self.add_env_step or, more directly (and manually), via self.observations.append|extend(). However, for certain postprocessing steps, the entirety (or a slice) of an episode’s observations might have to be rewritten, which is when self.set_observations() should be used.

Parameters:
  • new_data – The new observation data to overwrite existing data with. If this episode is not yet finalized, this may be a list of individual observation(s). If this episode has already been finalized, this should be a (possibly complex) struct matching the observation space, whose leaves have a batch size exactly equal to the size of the to-be-overwritten slice or segment (given by at_indices).

  • at_indices – A single int is interpreted as one index to overwrite with new_data (which is then expected to be a single observation). A list of ints is interpreted as a list of indices, all of which are overwritten with new_data (which is then expected to have length len(at_indices)). A slice object is interpreted as a range of indices to overwrite with new_data (which is then expected to have the same size as the slice). By default, negative indices are interpreted as “before the end”, unless neg_index_as_lookback=True is set, in which case they are interpreted as “before ts=0”, meaning they reach back into the lookback buffer.

  • neg_index_as_lookback – If True, negative values in at_indices are interpreted as “before ts=0”, meaning they reach back into the lookback buffer. For example, given an episode with observations = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (the ts=0 item is 7), the call set_observations(new_data=individual_observation, at_indices=-1, neg_index_as_lookback=True) overwrites the value 6 in the observations buffer with individual_observation.

Raises:

IndexError – If the provided at_indices do not match the size of new_data.
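The index rules above (non-negative indices count from ts=0, negative ones count either from the end or back into the lookback buffer) can be illustrated with a small, self-contained sketch. The helper resolve_positions is hypothetical and not part of RLlib; it assumes the raw buffer stores the lookback items first, followed by the episode proper, and it reproduces the [4, 5, 6, 7, 8, 9] example from the neg_index_as_lookback description.

```python
from typing import List, Union

Index = Union[int, slice, List[int]]


def resolve_positions(
    at_indices: Index,
    buffer_len: int,
    len_lookback: int,
    neg_index_as_lookback: bool = False,
) -> List[int]:
    """Maps episode indices to absolute positions in a raw lookback buffer.

    Hypothetical helper: raw positions 0..len_lookback-1 hold the lookback
    data and raw position len_lookback holds the ts=0 item.
    """
    if isinstance(at_indices, slice):
        # Slices are resolved against the episode proper (lookback excluded),
        # so they never reach into the lookback buffer in this sketch.
        episode_len = buffer_len - len_lookback
        indices = list(range(episode_len))[at_indices]
    elif isinstance(at_indices, int):
        indices = [at_indices]
    else:
        indices = list(at_indices)

    positions = []
    for i in indices:
        if i < 0 and not neg_index_as_lookback:
            pos = buffer_len + i  # -1 -> last item ("before the end")
        else:
            pos = len_lookback + i  # 0 -> ts=0; -1 -> last lookback item
        # Only lookback mode may address the lookback region.
        lowest = 0 if neg_index_as_lookback else len_lookback
        if not lowest <= pos < buffer_len:
            raise IndexError(f"Index {i} is out of range.")
        positions.append(pos)
    return positions


# The example from the docs: [4, 5, 6] is lookback, ts=0 holds 7.
buffer = [4, 5, 6, 7, 8, 9]
for pos in resolve_positions(-1, len(buffer), 3, neg_index_as_lookback=True):
    buffer[pos] = 60  # overwrites the 6, not the 9
```

Without neg_index_as_lookback=True, the same index -1 would resolve to the last item (the 9) instead.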

set_actions(*, new_data, at_indices: int | slice | List[int] | None = None, neg_index_as_lookback: bool = False) None[source]#

Overwrites all or some of this Episode’s actions with the provided data.

Note that an episode’s action data cannot be written to directly, as it is managed by an InfiniteLookbackBuffer object. Normally, individual, current actions are added to the episode either by calling self.add_env_step() or, more directly (and manually), via self.actions.append() or self.actions.extend(). However, certain postprocessing steps might require rewriting all (or a slice) of an episode’s actions, which is when self.set_actions() should be used.

Parameters:
  • new_data – The new action data to overwrite existing data with. If this episode is not yet finalized, this may be a list of individual action(s). If this episode has already been finalized, this should be a (possibly complex) struct matching the action space, whose leaves have a batch size exactly equal to the size of the to-be-overwritten slice or segment (given by at_indices).

  • at_indices – A single int is interpreted as one index to overwrite with new_data (which is then expected to be a single action). A list of ints is interpreted as a list of indices, all of which are overwritten with new_data (which is then expected to have length len(at_indices)). A slice object is interpreted as a range of indices to overwrite with new_data (which is then expected to have the same size as the slice). By default, negative indices are interpreted as “before the end”, unless neg_index_as_lookback=True is set, in which case they are interpreted as “before ts=0”, meaning they reach back into the lookback buffer.

  • neg_index_as_lookback – If True, negative values in at_indices are interpreted as “before ts=0”, meaning they reach back into the lookback buffer. For example, given an episode with actions = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (the ts=0 item is 7), the call set_actions(new_data=individual_action, at_indices=-1, neg_index_as_lookback=True) overwrites the value 6 in the actions buffer with individual_action.

Raises:

IndexError – If the provided at_indices do not match the size of new_data.
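The size-matching rule behind the IndexError can be sketched on a plain Python list. The function overwrite is hypothetical and deliberately ignores the lookback buffer; it also assumes that at_indices=None means "overwrite everything", which is one reading of "all or some" above. It checks the number of targeted indices against len(new_data) before writing anything.

```python
from typing import Any, List, Optional, Union


def overwrite(
    data: List[Any],
    new_data: Any,
    at_indices: Optional[Union[int, slice, List[int]]] = None,
) -> None:
    """Overwrites items of `data` in place (hypothetical, lookback-free sketch)."""
    if at_indices is None:
        # Assumption: no indices means "overwrite everything".
        positions = list(range(len(data)))
    elif isinstance(at_indices, int):
        # A single int expects a single item; wrap both for uniform handling.
        positions, new_data = [at_indices], [new_data]
    elif isinstance(at_indices, slice):
        positions = list(range(len(data)))[at_indices]
    else:
        positions = list(at_indices)

    # The size of new_data must match the number of targeted indices.
    if len(positions) != len(new_data):
        raise IndexError(
            f"{len(positions)} indices given, but new_data has {len(new_data)} items."
        )
    for pos, item in zip(positions, new_data):
        data[pos] = item  # negative positions count from the end


actions = [1, 2, 3, 4]
overwrite(actions, [20, 30], at_indices=slice(1, 3))
overwrite(actions, 40, at_indices=-1)
```

Validating before the write loop means a mismatched call leaves the data untouched rather than partially overwritten.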

set_rewards(*, new_data, at_indices: int | slice | List[int] | None = None, neg_index_as_lookback: bool = False) None[source]#

Overwrites all or some of this Episode’s rewards with the provided data.

Note that an episode’s reward data cannot be written to directly, as it is managed by an InfiniteLookbackBuffer object. Normally, individual, current rewards are added to the episode either by calling self.add_env_step() or, more directly (and manually), via self.rewards.append() or self.rewards.extend(). However, certain postprocessing steps might require rewriting all (or a slice) of an episode’s rewards, which is when self.set_rewards() should be used.

Parameters:
  • new_data – The new reward data to overwrite existing data with. If this episode is not yet finalized, this may be a list of individual reward(s). If this episode has already been finalized, this should be a np.ndarray whose length exactly equals the size of the to-be-overwritten slice or segment (given by at_indices).

  • at_indices – A single int is interpreted as one index to overwrite with new_data (which is then expected to be a single reward). A list of ints is interpreted as a list of indices, all of which are overwritten with new_data (which is then expected to have length len(at_indices)). A slice object is interpreted as a range of indices to overwrite with new_data (which is then expected to have the same size as the slice). By default, negative indices are interpreted as “before the end”, unless neg_index_as_lookback=True is set, in which case they are interpreted as “before ts=0”, meaning they reach back into the lookback buffer.

  • neg_index_as_lookback – If True, negative values in at_indices are interpreted as “before ts=0”, meaning they reach back into the lookback buffer. For example, given an episode with rewards = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (the ts=0 item is 7), the call set_rewards(new_data=individual_reward, at_indices=-1, neg_index_as_lookback=True) overwrites the value 6 in the rewards buffer with individual_reward.

Raises:

IndexError – If the provided at_indices do not match the size of new_data.

set_extra_model_outputs(*, key, new_data, at_indices: int | slice | List[int] | None = None, neg_index_as_lookback: bool = False) None[source]#

Overwrites all or some of this Episode’s extra model outputs with new_data.

Note that an episode’s extra_model_outputs data cannot be written to directly, as it is managed by InfiniteLookbackBuffer objects. Normally, individual, current extra model output values are added to the episode either by calling self.add_env_step() or, more directly (and manually), via self.extra_model_outputs[key].append() or self.extra_model_outputs[key].extend(). However, certain postprocessing steps might require rewriting all (or a slice) of an episode’s extra_model_outputs, or inserting a new key (a new type of extra model output), which is when self.set_extra_model_outputs() should be used.

Parameters:
  • key – The key within self.extra_model_outputs whose data to overwrite, or a new key to insert into self.extra_model_outputs.

  • new_data – The new data to overwrite existing data with. If this episode is not yet finalized, this may be a list of individual extra model output value(s). If this episode has already been finalized, this should be a np.ndarray whose length exactly equals the size of the to-be-overwritten slice or segment (given by at_indices).

  • at_indices – A single int is interpreted as one index to overwrite with new_data (which is then expected to be a single value). A list of ints is interpreted as a list of indices, all of which are overwritten with new_data (which is then expected to have length len(at_indices)). A slice object is interpreted as a range of indices to overwrite with new_data (which is then expected to have the same size as the slice). By default, negative indices are interpreted as “before the end”, unless neg_index_as_lookback=True is set, in which case they are interpreted as “before ts=0”, meaning they reach back into the lookback buffer.

  • neg_index_as_lookback – If True, negative values in at_indices are interpreted as “before ts=0”, meaning they reach back into the lookback buffer. For example, given an episode with extra_model_outputs[key] = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range (the ts=0 item is 7), the call set_extra_model_outputs(key=key, new_data=individual_value, at_indices=-1, neg_index_as_lookback=True) overwrites the value 6 in the extra_model_outputs[key] buffer with individual_value.

Raises:

IndexError – If the provided at_indices do not match the size of new_data.
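Inserting a brand-new key is the case that distinguishes set_extra_model_outputs() from the other setters. The sketch below models extra_model_outputs as a plain dict of lists; set_extra_model_output and the key "mo_doubled" are hypothetical. It assumes one value per timestep (one per action) and always writes a whole column, ignoring at_indices.

```python
from typing import Any, Dict, List


def set_extra_model_output(
    extra_model_outputs: Dict[str, List[Any]],
    key: str,
    new_data: List[Any],
    num_timesteps: int,
) -> None:
    """Overwrites or inserts one full extra-model-output column.

    Hypothetical sketch: assumes one value per timestep (one per action) and
    ignores at_indices, i.e. it always writes the whole column.
    """
    if len(new_data) != num_timesteps:
        raise IndexError(
            f"Key {key!r} needs {num_timesteps} values, got {len(new_data)}."
        )
    # An unknown key becomes a new column; an existing key is replaced.
    extra_model_outputs[key] = list(new_data)


outputs = {"mo": [1, 2, 3]}
# Insert a new (hypothetical) key derived from existing data.
set_extra_model_output(outputs, "mo_doubled", [2 * x for x in outputs["mo"]], 3)
```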

add_temporary_timestep_data(key: str, data: Any) None[source]#

Temporarily adds per-timestep data to self (until self.finalize() is called).

The given data is appended to a list (stored in self._temporary_timestep_data), which is cleared upon calling self.finalize(). To retrieve the temporary timestep data accumulated so far for a given key, use the get_temporary_timestep_data API. Note that the size of the per-timestep list is NOT checked or validated against the other, non-temporary data in this episode (such as observations).

Parameters:
  • key – The key under which to find the list to append data to. If data is the first item added under this key, a new list is started.

  • data – The data item (representing a single timestep) to be stored.

get_temporary_timestep_data(key: str) List[Any][source]#

Returns all temporarily stored data items (list) under the given key.

Note that all temporary timestep data is erased/cleared when calling self.finalize().

Returns:

The current list storing temporary timestep data under key.
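The add/get/clear lifecycle described for add_temporary_timestep_data() and get_temporary_timestep_data() can be sketched with a tiny stand-in class. TemporaryTimestepStore and the key "my_key" are hypothetical; how the real episode handles a key that was never added is not specified above, so this sketch simply raises KeyError in that case.

```python
from typing import Any, Dict, List


class TemporaryTimestepStore:
    """Hypothetical stand-in for an episode's temporary per-timestep data."""

    def __init__(self) -> None:
        self._temporary_timestep_data: Dict[str, List[Any]] = {}

    def add_temporary_timestep_data(self, key: str, data: Any) -> None:
        # The first item under a key starts a new list. Lengths are NOT
        # validated against the episode's other (non-temporary) data.
        self._temporary_timestep_data.setdefault(key, []).append(data)

    def get_temporary_timestep_data(self, key: str) -> List[Any]:
        # This sketch raises KeyError for keys that were never added.
        return self._temporary_timestep_data[key]

    def finalize(self) -> None:
        # Finalizing discards all temporary data.
        self._temporary_timestep_data.clear()


store = TemporaryTimestepStore()
for t in range(3):
    store.add_temporary_timestep_data("my_key", t * 10)
collected = list(store.get_temporary_timestep_data("my_key"))
store.finalize()
```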

slice(slice_: slice, *, len_lookback_buffer: int | None = None) SingleAgentEpisode[source]#

Returns a slice of this episode with the given slice object.

For example, if self contains the observations o0 (the reset observation), o1, o2, o3, and o4, and the actions a1, a2, a3, and a4 (so len(self) is 4), then calling self.slice(slice(1, 3)) returns a new SingleAgentEpisode with observations o1, o2, and o3, and actions a2 and a3. Note that an episode always contains one more observation than it has actions (and rewards and extra model outputs), due to the initial observation received after an env reset.

from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.test_utils import check

# Generate a simple single-agent episode.
observations = [0, 1, 2, 3, 4, 5]
actions = [1, 2, 3, 4, 5]
rewards = [0.1, 0.2, 0.3, 0.4, 0.5]
episode = SingleAgentEpisode(
    observations=observations,
    actions=actions,
    rewards=rewards,
    len_lookback_buffer=0,  # all given data is part of the episode
)
slice_1 = episode[:1]
check(slice_1.observations, [0, 1])
check(slice_1.actions, [1])
check(slice_1.rewards, [0.1])

slice_2 = episode[-2:]
check(slice_2.observations, [3, 4, 5])
check(slice_2.actions, [4, 5])
check(slice_2.rewards, [0.4, 0.5])
Parameters:
  • slice_ – The slice object to use for slicing. This should exclude the lookback buffer, which is automatically prepended to the returned slice.

  • len_lookback_buffer – If not None, forces the returned slice to try to have this number of timesteps in its lookback buffer (if available). If None (default), tries to make the returned slice’s lookback as large as the current lookback buffer of this episode (self).

Returns:

The new SingleAgentEpisode representing the requested slice.
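The "one more observation than actions" rule from the slice() description can be checked with a small sketch on plain lists. slice_episode is hypothetical and ignores the lookback buffer entirely; it uses the standard slice.indices() method to resolve open-ended and negative slices against the number of timesteps, and it reproduces both the o0...o4 example and the episode[:1] / episode[-2:] example above.

```python
from typing import Any, List, Tuple


def slice_episode(
    observations: List[Any], actions: List[Any], s: slice
) -> Tuple[List[Any], List[Any]]:
    """Slices a lookback-free episode by timestep (hypothetical sketch).

    Assumes len(observations) == len(actions) + 1, where observations[t] is
    the observation from which actions[t] was computed.
    """
    # Resolve the slice against the number of timesteps (= len(actions)).
    start, stop, step = s.indices(len(actions))
    if step != 1:
        raise ValueError("Only contiguous slices are supported in this sketch.")
    # The slice also keeps the observation that follows its last action.
    return observations[start : stop + 1], actions[start:stop]


obs = ["o0", "o1", "o2", "o3", "o4"]
acts = ["a1", "a2", "a3", "a4"]
sliced_obs, sliced_acts = slice_episode(obs, acts, slice(1, 3))
```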

get_data_dict()[source]#

Converts a SingleAgentEpisode into a data dict mapping str keys to data.

The keys used are: Columns.EPS_ID, Columns.T, Columns.OBS, Columns.INFOS, Columns.ACTIONS, Columns.REWARDS, Columns.TERMINATEDS, Columns.TRUNCATEDS, and all keys in self.extra_model_outputs.

Returns:

A data dict mapping str keys to data records.

get_sample_batch() SampleBatch[source]#

Converts this SingleAgentEpisode into a SampleBatch.

Returns:

A SampleBatch containing all of this episode’s data.

get_return() float[source]#

Calculates an episode’s return, excluding the lookback buffer’s rewards.

The return is computed by a simple sum, neglecting the discount factor. Note that if self is a continuation chunk (resulting from a call to self.cut()), the previous chunk’s rewards are NOT counted and thus NOT part of the returned reward sum.

Returns:

The sum of rewards collected during this episode, excluding possible data inside the lookback buffer and excluding possible data in a predecessor chunk.
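The return computation described for get_return() (an undiscounted sum that skips lookback rewards) can be sketched in a few lines. episode_return is hypothetical; the usage below assumes an episode with rewards 1 to 5 that was cut into two chunks, where the second chunk carries one reward from the first chunk in its lookback buffer. Because lookback rewards are excluded, the chunk returns add up to the full episode return.

```python
from typing import List


def episode_return(rewards: List[float], len_lookback: int) -> float:
    """Undiscounted return of one episode chunk (hypothetical sketch).

    The first len_lookback entries of the raw reward buffer belong to the
    lookback buffer (i.e. to a predecessor chunk) and are excluded.
    """
    return float(sum(rewards[len_lookback:]))


# A full episode with rewards 1..5, cut into two chunks. The second chunk
# is assumed to carry one reward of the first chunk in its lookback buffer.
first_chunk = episode_return([1, 2, 3], len_lookback=0)
second_chunk = episode_return([3, 4, 5], len_lookback=1)
```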

get_duration_s() float[source]#

Returns the duration of this Episode (chunk) in seconds.

env_steps() int[source]#

Returns the number of environment steps.

Note that this episode instance may only be a chunk of an actual episode.

Returns:

An integer that counts the number of environment steps this episode instance has seen.

agent_steps() int[source]#

Returns the number of agent steps.

Note that for a single-agent episode, these are identical to the environment steps.

Returns:

An integer counting the number of agent steps executed during the time this episode instance records.
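The relationship between env_steps() and agent_steps() for a possibly-chunked single-agent episode can be sketched as follows. ChunkStepCounter is hypothetical: t_started appears in the constructor signature above, but treating env_steps() as "current timestep minus t_started" is an assumption of this sketch, not a statement about RLlib's implementation.

```python
from dataclasses import dataclass


@dataclass
class ChunkStepCounter:
    """Hypothetical sketch of step counting for one episode chunk.

    t_started: global timestep at which this chunk starts.
    t: current global timestep of the (possibly multi-chunk) episode.
    """

    t_started: int
    t: int

    def env_steps(self) -> int:
        # Only steps taken within this chunk are counted.
        return self.t - self.t_started

    def agent_steps(self) -> int:
        # Single agent: every env step is also exactly one agent step.
        return self.env_steps()


# E.g. a continuation chunk that picked up at ts=100 and is now at ts=150.
chunk = ChunkStepCounter(t_started=100, t=150)
```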

get_state() Dict[str, Any][source]#

Returns the picklable state of an episode.

The data in the episode is stored in a dictionary. Note that episodes can also be generated from states (see SingleAgentEpisode.from_state()).

Returns:

A dict containing all the data from the episode.

static from_state(state: Dict[str, Any]) SingleAgentEpisode[source]#

Creates a new SingleAgentEpisode instance from a state dict.

Parameters:

state – The state dict, as returned by self.get_state().

Returns:

A new SingleAgentEpisode instance with the data from the state dict.
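The get_state()/from_state() pair follows a common round-trip pattern: export plain, picklable data, then rebuild an equivalent object from it. ToyEpisode below is a hypothetical stand-in with only a few fields; the real state dict of SingleAgentEpisode contains more entries, and its exact keys are not shown here.

```python
import pickle
from typing import Any, Dict, List


class ToyEpisode:
    """Hypothetical stand-in illustrating the get_state()/from_state() pattern."""

    def __init__(
        self, id_: str, observations: List[Any], actions: List[Any], rewards: List[float]
    ) -> None:
        self.id_ = id_
        self.observations = observations
        self.actions = actions
        self.rewards = rewards

    def get_state(self) -> Dict[str, Any]:
        # Plain dicts and lists only, so the state can be pickled or shipped.
        return {
            "id_": self.id_,
            "observations": list(self.observations),
            "actions": list(self.actions),
            "rewards": list(self.rewards),
        }

    @staticmethod
    def from_state(state: Dict[str, Any]) -> "ToyEpisode":
        # Rebuild a new, independent instance from the state dict.
        return ToyEpisode(**state)


episode = ToyEpisode("ep_0", [0, 1, 2], [1, 2], [0.5, 1.0])
restored = ToyEpisode.from_state(pickle.loads(pickle.dumps(episode.get_state())))
```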