ray.rllib.env.single_agent_episode.SingleAgentEpisode

class ray.rllib.env.single_agent_episode.SingleAgentEpisode(id_: str | None = None, *, observations: List[gymnasium.core.ObsType] | InfiniteLookbackBuffer | None = None, observation_space: gymnasium.Space | None = None, infos: List[Dict] | InfiniteLookbackBuffer | None = None, actions: List[gymnasium.core.ActType] | InfiniteLookbackBuffer | None = None, action_space: gymnasium.Space | None = None, rewards: List[SupportsFloat] | InfiniteLookbackBuffer | None = None, terminated: bool = False, truncated: bool = False, extra_model_outputs: Dict[str, Any] | None = None, t_started: int | None = None, len_lookback_buffer: int | str = 'auto', agent_id: Any | None = None, module_id: str | None = None, multi_agent_episode_id: int | None = None)

A class representing RL environment episodes for individual agents.

SingleAgentEpisode stores observations, info dicts, actions, rewards, and all module outputs (e.g. state outs, action logp, etc.) for an individual agent within some single-agent or multi-agent environment. The two main APIs for adding data to an ongoing episode are the add_env_reset() and add_env_step() methods, which should be called with the outputs of the respective gym.Env API calls, env.reset() and env.step().

A SingleAgentEpisode might also represent only a chunk of an episode. This is useful when partial (non-complete episode) sampling is performed and the collected episode data has to be returned before the actual gym.Env episode has finished (see SingleAgentEpisode.cut()). In order to still maintain visibility onto past experiences within such a “cut” episode, SingleAgentEpisode instances can have a “lookback buffer” of n timesteps at their beginning (left side), which exists solely for the purpose of compiling extra data (e.g. “prev. reward”), but is not considered part of the finished/packaged episode, because the data in the lookback buffer is already part of a previous episode chunk.

Powerful getter methods, such as get_observations(), help collect different types of data from the episode at individual time indices or over time ranges, including the “lookback buffer” range described above. For example, to extract the last 4 rewards of an ongoing episode, one can call self.get_rewards(slice(-4, None)) or self.rewards[-4:]. This works even if the ongoing SingleAgentEpisode is a continuation chunk of a much earlier started episode, as long as it has a lookback buffer of sufficient size.
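
For illustration, here is a minimal sketch of how a continuation chunk produced by cut() retains access to earlier data through its lookback buffer; the len_lookback_buffer argument to cut() and the exact values returned by the final getter call are assumptions made for this sketch, not verified output:

import gymnasium as gym

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

env = gym.make("CartPole-v1")
episode = SingleAgentEpisode()
obs, infos = env.reset()
episode.add_env_reset(obs, infos)
for _ in range(6):
    action = env.action_space.sample()
    obs, reward, term, trunc, infos = env.step(action)
    episode.add_env_step(
        observation=obs,
        action=action,
        reward=reward,
        terminated=term,
        truncated=trunc,
        infos=infos,
    )

# Hand off the data collected so far and continue in a successor chunk that
# keeps the last 4 timesteps in its lookback buffer (argument name assumed).
successor = episode.cut(len_lookback_buffer=4)
# The successor has no "own" data yet (the lookback buffer doesn't count).
assert len(successor) == 0

# Keep stepping through the (still ongoing) env episode via the successor.
action = env.action_space.sample()
obs, reward, term, trunc, infos = env.step(action)
successor.add_env_step(
    observation=obs, action=action, reward=reward,
    terminated=term, truncated=trunc, infos=infos,
)

# Getters can now reach into the lookback buffer: expected to return the last
# two rewards of the predecessor chunk, followed by the successor's first own
# reward.
successor.get_rewards(slice(-2, 1), neg_index_as_lookback=True)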

Examples:

import gymnasium as gym
import numpy as np

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

# Construct a new episode (without any data in it yet).
episode = SingleAgentEpisode()
assert len(episode) == 0

# Fill the episode with some data (10 timesteps).
env = gym.make("CartPole-v1")
obs, infos = env.reset()
episode.add_env_reset(obs, infos)

# Even with the initial obs/infos, the episode is still considered len=0.
assert len(episode) == 0
for _ in range(5):
    action = env.action_space.sample()
    obs, reward, term, trunc, infos = env.step(action)
    episode.add_env_step(
        observation=obs,
        action=action,
        reward=reward,
        terminated=term,
        truncated=trunc,
        infos=infos,
    )
assert len(episode) == 5

# We can now access information from the episode via the getter APIs.

# Get the last 3 rewards (in a batch of size 3).
episode.get_rewards(slice(-3, None))  # same as `episode.rewards[-3:]`

# Get the most recent action (single item, not batched).
# This works regardless of the action space or whether the episode has
# been numpy'ized or not (see below).
episode.get_actions(-1)  # same as episode.actions[-1]

# Looking back from ts=1, get the previous 4 rewards AND fill with 0.0
# in case we go over the beginning (ts=0). We would thus expect
# [0.0, 0.0, 0.0, r0] to be returned here, where r0 is the very first reward
# received in the episode:
episode.get_rewards(slice(-3, 1), neg_index_as_lookback=True, fill=0.0)

# Note the use of fill=0.0 here (fill everything that's out of range with this
# value) AND the argument `neg_index_as_lookback=True`, which interprets
# negative indices as being left of ts=0 (e.g. -1 being the timestep before
# ts=0).

# Assuming we had a complex action space (nested gym.spaces.Dict) with one or
# more elements being Discrete or MultiDiscrete spaces:
# 1) The `fill=...` argument would still work, filling all spaces (Boxes,
# Discrete) with that provided value.
# 2) Setting the flag `one_hot_discrete=True` would convert those discrete
# sub-components automatically into one-hot (or multi-one-hot) tensors.
# This simplifies the task of having to provide the previous 4 (nested and
# partially discrete/multi-discrete) actions for each timestep within a training
# batch, thereby filling timesteps before the episode started with 0.0s and
# one-hot'ing the discrete/multi-discrete components in these actions:
episode = SingleAgentEpisode(action_space=gym.spaces.Dict({
    "a": gym.spaces.Discrete(3),
    "b": gym.spaces.MultiDiscrete([2, 3]),
    "c": gym.spaces.Box(-1.0, 1.0, (2,)),
}))

# ... fill episode with data ...
episode.add_env_reset(observation=0)
# ... from a few steps.
episode.add_env_step(
    observation=1,
    action={"a":0, "b":np.array([1, 2]), "c":np.array([.5, -.5], np.float32)},
    reward=1.0,
)

# In your connector
prev_4_a = []
# Note here that len(episode) does NOT include the lookback buffer.
for ts in range(len(episode)):
    prev_4_a.append(
        episode.get_actions(
            indices=slice(ts - 4, ts),
            # Make sure negative indices are interpreted as
            # "into lookback buffer"
            neg_index_as_lookback=True,
            # Zero-out everything even further before the lookback buffer.
            fill=0.0,
            # Take care of discrete components (get ready as NN input).
            one_hot_discrete=True,
        )
    )

# Finally, convert from a list of batch items into a struct (matching the action
# space) of batched (numpy) arrays, in which all leaves have B == len(prev_4_a).
from ray.rllib.utils.spaces.space_utils import batch

prev_4_actions_col = batch(prev_4_a)
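
Continuing the example above, a short sketch of the numpy'ization step mentioned earlier; that the getters keep working unchanged after to_numpy() is assumed here for illustration:

# Convert this episode's list-based buffers into (batched) numpy arrays,
# e.g. right before handing the chunk over for training.
episode.to_numpy()
assert episode.is_numpy

# The getters keep working on the numpy'ized episode; a single index still
# returns an individual (non-batched) item.
episode.get_actions(-1)
episode.get_rewards(-1)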

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods

__init__

Initializes a SingleAgentEpisode instance.

add_env_reset

Adds the initial data (after an env.reset()) to the episode.

add_env_step

Adds results of an env.step() call (including the action) to this episode.

add_temporary_timestep_data

Temporarily adds (until to_numpy() is called) per-timestep data to self.

agent_steps

Returns the number of agent steps.

concat_episode

Adds the given other SingleAgentEpisode to the right side of self.

cut

Returns a successor episode chunk (of len=0) continuing from this Episode.

env_steps

Returns the number of environment steps.

from_state

Creates a new SingleAgentEpisode instance from a state dict.

get_actions

Returns individual actions or batched ranges thereof from this episode.

get_data_dict

Converts a SingleAgentEpisode into a data dict mapping str keys to data.

get_duration_s

Returns the duration of this Episode (chunk) in seconds.

get_extra_model_outputs

Returns extra model outputs (under given key) from this episode.

get_infos

Returns individual info dicts or list (ranges) thereof from this episode.

get_observations

Returns individual observations or batched ranges thereof from this episode.

get_return

Calculates an episode's return, excluding the lookback buffer's rewards.

get_rewards

Returns individual rewards or batched ranges thereof from this episode.

get_sample_batch

Converts this SingleAgentEpisode into a SampleBatch.

get_state

Returns the picklable state of an episode.

get_temporary_timestep_data

Returns all temporarily stored data items (list) under the given key.

set_actions

Overwrites all or some of this Episode's actions with the provided data.

set_extra_model_outputs

Overwrites all or some of this Episode's extra model outputs with new_data.

set_observations

Overwrites all or some of this Episode's observations with the provided data.

set_rewards

Overwrites all or some of this Episode's rewards with the provided data.

slice

Returns a slice of this episode with the given slice object.

to_numpy

Converts this Episode's list attributes to numpy arrays.

validate

Validates the episode's data.

Attributes

actions

agent_id

extra_model_outputs

id_

infos

is_terminated

is_truncated

module_id

multi_agent_episode_id

observations

rewards

t

t_started

action_space

is_done

Whether the episode is actually done (terminated or truncated).

is_numpy

True if the data in this episode is already stored as numpy arrays.

is_reset

Returns True if self.add_env_reset() has already been called.

observation_space