ray.rllib.policy.sample_batch.SampleBatch.split_by_episode

SampleBatch.split_by_episode(key: str | None = None) → List[SampleBatch]

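Splits the batch by its `eps_id` column and returns a list of new batches. If `eps_id` is not present, the batch is split by its `dones` column instead. When splitting by `eps_id`, a new batch starts wherever the episode ID changes. When splitting by `dones`, each batch ends with (and includes) a row whose `dones` value is set, and any trailing rows without a done flag form a final batch of their own.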

Parameters:

key – If specified, override the default behavior and split by this column instead.

Returns:

List of batches, one per distinct episode.

Raises:

KeyError – If neither the `eps_id` column nor the `dones` column is present.

from ray.rllib.policy.sample_batch import SampleBatch

# Case 1: an "eps_id" column exists, so a new batch starts wherever the
# episode ID changes -> a=[1, 2] and a=[3].
by_eps_id = SampleBatch(
    {"a": [1, 2, 3], "eps_id": [0, 0, 1]}
)
print(by_eps_id.split_by_episode())

# Case 2: no "eps_id" column, so "dones" is used instead; each batch ends
# with (and includes) a row whose done flag is set -> a=[1, 2, 3], a=[4, 5].
by_dones = SampleBatch(
    {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 1]}
)
print(by_dones.split_by_episode())

# Case 3: trailing rows that never reach a done flag are still returned as
# their own (unfinished) final batch -> a=[1, 2, 3], a=[4, 5].
unfinished_tail = SampleBatch(
    {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 0]}
)
print(unfinished_tail.split_by_episode())

# Case 4: no done flag anywhere, so the whole batch comes back as a single
# batch -> a=[1, 2, 3, 4, 5].
never_done = SampleBatch(
    {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}
)
print(never_done.split_by_episode())
Output:

[{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
[{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 1]}]
[{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 0]}]
[{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}]