Sample Batches

SampleBatch (ray.rllib.policy.sample_batch.SampleBatch)

Whether running in a single process or on a large cluster, all data interchange in RLlib happens in the form of sample batches. Each RolloutWorker collects batches of size rollout_fragment_length, and RLlib then concatenates one or more of these batches (collected across different RolloutWorkers in subsequent sampling steps) into a batch of size train_batch_size, which then serves as the input to a Policy’s learn_on_batch() method.

A typical sample batch looks something like the following when summarized. Since all values are kept in arrays, this allows for efficient encoding and transmission across the network:

{ 'action_logp': np.ndarray((200,), dtype=float32, min=-0.701, max=-0.685, mean=-0.694),
  'actions': np.ndarray((200,), dtype=int64, min=0.0, max=1.0, mean=0.495),
  'dones': np.ndarray((200,), dtype=bool, min=0.0, max=1.0, mean=0.055),
  'infos': np.ndarray((200,), dtype=object, head={}),
  'new_obs': np.ndarray((200, 4), dtype=float32, min=-2.46, max=2.259, mean=0.018),
  'obs': np.ndarray((200, 4), dtype=float32, min=-2.46, max=2.259, mean=0.016),
  'rewards': np.ndarray((200,), dtype=float32, min=1.0, max=1.0, mean=1.0),
  't': np.ndarray((200,), dtype=int64, min=0.0, max=34.0, mean=9.14)}
class ray.rllib.policy.sample_batch.SampleBatch(*args, **kwargs)[source]

Wrapper around a dictionary with string keys and array-like values.

For example, {“obs”: [1, 2, 3], “reward”: [0, -1, 1]} is a batch of three samples, each with an “obs” and “reward” attribute.
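The column-oriented layout can be sketched with plain Python and NumPy (no RLlib needed); the helper name `row` below is illustrative only. Sample i of the batch is simply the i-th element of every column:

```python
import numpy as np

# A minimal sketch of SampleBatch's column-oriented layout:
# each key maps to an array whose first axis is the batch dimension.
batch = {"obs": np.array([1, 2, 3]), "reward": np.array([0, -1, 1])}

def row(batch, i):
    """Return sample i as a dict of scalar column values."""
    return {k: v[i] for k, v in batch.items()}

print(row(batch, 0))  # {'obs': 1, 'reward': 0}
```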

agent_steps() int[source]

Returns the same as len(self) (number of steps in this batch).

This makes it compatible with MultiAgentBatch.agent_steps().

static concat_samples(samples: Union[List[ray.rllib.policy.sample_batch.SampleBatch], List[ray.rllib.policy.sample_batch.MultiAgentBatch]]) Union[ray.rllib.policy.sample_batch.SampleBatch, ray.rllib.policy.sample_batch.MultiAgentBatch][source]

Concatenates n SampleBatches or MultiAgentBatches.


Parameters:
   samples – List of SampleBatches or MultiAgentBatches to be concatenated.


Returns:
   A new (concatenated) SampleBatch or MultiAgentBatch.


>>> b1 = SampleBatch({"a": np.array([1, 2]),
...                   "b": np.array([10, 11])})
>>> b2 = SampleBatch({"a": np.array([3]),
...                   "b": np.array([12])})
>>> print(SampleBatch.concat_samples([b1, b2]))
{"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
concat(other: ray.rllib.policy.sample_batch.SampleBatch) ray.rllib.policy.sample_batch.SampleBatch[source]

Concatenates other to this one and returns a new SampleBatch.


Parameters:
   other – The other SampleBatch object to concat to this one.


Returns:
   The new SampleBatch, resulting from concatenating other to self.


>>> b1 = SampleBatch({"a": np.array([1, 2])})
>>> b2 = SampleBatch({"a": np.array([3, 4, 5])})
>>> print(b1.concat(b2))
{"a": np.array([1, 2, 3, 4, 5])}
copy(shallow: bool = False) ray.rllib.policy.sample_batch.SampleBatch[source]

Creates a deep or shallow copy of this SampleBatch and returns it.


Parameters:
   shallow – Whether the copying should be done shallowly.


Returns:
   A deep or shallow copy of this SampleBatch object.
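The practical difference between the two copy modes can be sketched with plain Python dicts and NumPy (no RLlib needed): a shallow copy creates a new container but shares the underlying data buffers, so in-place mutations remain visible through it.

```python
import copy
import numpy as np

data = {"obs": np.array([1.0, 2.0, 3.0])}

shallow = copy.copy(data)    # new dict, but the same underlying array
deep = copy.deepcopy(data)   # new dict AND new array buffers

data["obs"][0] = 99.0
print(shallow["obs"][0])  # 99.0 -> mutation is visible through the shallow copy
print(deep["obs"][0])     # 1.0  -> the deep copy is unaffected
```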

rows() Iterator[Dict[str, Any]][source]

Returns an iterator over data rows, i.e. dicts with column values.

Note that if seq_lens is set in self, we set it to [1] in the rows.


Yields:
   The column values of the row in this iteration.


>>> batch = SampleBatch({
...    "a": [1, 2, 3],
...    "b": [4, 5, 6],
...    "seq_lens": [1, 2]
... })
>>> for row in batch.rows():
...     print(row)
{"a": 1, "b": 4, "seq_lens": [1]}
{"a": 2, "b": 5, "seq_lens": [1]}
{"a": 3, "b": 6, "seq_lens": [1]}
columns(keys: List[str]) List[Any][source]

Returns a list of the batch-data in the specified columns.


Parameters:
   keys – List of column names for which to return the data.


Returns:
   The list of data items, ordered by the order of column names in keys.


>>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
>>> print(batch.columns(["a", "b"]))
[[1], [2]]
shuffle() ray.rllib.policy.sample_batch.SampleBatch[source]

Shuffles the rows of this batch in-place.


Returns:
   This very (now shuffled) SampleBatch.


Raises:
   ValueError – If self[SampleBatch.SEQ_LENS] is defined.


>>> batch = SampleBatch({"a": [1, 2, 3, 4]})
>>> print(batch.shuffle())
{"a": [4, 1, 3, 2]}
split_by_episode() List[ray.rllib.policy.sample_batch.SampleBatch][source]

Splits by eps_id column and returns list of new batches.


Returns:
   List of batches, one per distinct episode.


Raises:
   KeyError – If neither the eps_id nor the dones column is present.


>>> batch = SampleBatch({"a": [1, 2, 3], "eps_id": [0, 0, 1]})
>>> print(batch.split_by_episode())
[{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
timeslices(size: Optional[int] = None, num_slices: Optional[int] = None, k: Optional[int] = None) List[ray.rllib.policy.sample_batch.SampleBatch][source]

Returns SampleBatches, each one representing a k-slice of this one.

Slicing starts at timestep 0 and produces slices of the given size.

Parameters:

  • size – The size (in timesteps) of each returned SampleBatch.

  • num_slices – The number of slices to produce.

  • k – Deprecated: Use size or num_slices instead. The size (in timesteps) of each returned SampleBatch.


Returns:
   The list of num_slices new SampleBatches, or a list of new SampleBatches, each of size size.
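The slicing semantics can be sketched with a plain dict of arrays (no RLlib needed); `timeslices_sketch` is an illustrative stand-in, not the actual implementation:

```python
import numpy as np

def timeslices_sketch(batch, size):
    """Split a dict-of-arrays batch into consecutive slices of `size` timesteps
    (the last slice may be shorter) -- a rough sketch of SampleBatch.timeslices()."""
    n = len(next(iter(batch.values())))
    return [
        {k: v[start:start + size] for k, v in batch.items()}
        for start in range(0, n, size)
    ]

batch = {"a": np.arange(5), "rewards": np.ones(5)}
slices = timeslices_sketch(batch, size=2)
print([s["a"].tolist() for s in slices])  # [[0, 1], [2, 3], [4]]
```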

right_zero_pad(max_seq_len: int, exclude_states: bool = True)[source]

Right-zero-pads this SampleBatch in place (i.e., appends zeros at the end of each sequence).

This will set the self.zero_padded flag to True and self.max_seq_len to the given max_seq_len value.

Parameters:

  • max_seq_len – The max (total) length to zero-pad to.

  • exclude_states – If False, also right-zero-pad all state_in_x data. If True, leave state_in_x keys as-is.


Returns:
   This very (now right-zero-padded) SampleBatch.


Raises:
   ValueError – If self[SampleBatch.SEQ_LENS] is None (not defined).


>>> batch = SampleBatch({"a": [1, 2, 3], "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=4))
{"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}
>>> batch = SampleBatch({"a": [1, 2, 3],
...                      "state_in_0": [1.0, 3.0],
...                      "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=5))
{"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
 "state_in_0": [1.0, 3.0],  # <- all state-ins remain as-is
 "seq_lens": [1, 2]}
to_device(device, framework='torch')[source]

TODO: transfer batch to given device as framework tensor.

size_bytes() int[source]

Returns sum over number of bytes of all data buffers.

For numpy arrays, we use .nbytes. For all other value types, we use sys.getsizeof(…).


Returns:
   The overall size in bytes of the data buffer (all columns).
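The sizing rule described above (.nbytes for numpy arrays, sys.getsizeof() otherwise) can be sketched as follows; `size_bytes_sketch` is an illustrative stand-in, not the actual implementation:

```python
import sys
import numpy as np

def size_bytes_sketch(batch):
    """Sum buffer sizes per the rule above: .nbytes for numpy arrays,
    sys.getsizeof() for everything else."""
    return sum(
        v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
        for v in batch.values()
    )

# 4 * 3 float32 values -> 48 bytes for "obs", plus list overhead for "infos".
batch = {"obs": np.zeros((4, 3), dtype=np.float32), "infos": [{}] * 4}
print(size_bytes_sketch(batch))
```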

get(key, default=None)[source]

Return the value for key if key is in the dictionary, else default.

as_multi_agent() ray.rllib.policy.sample_batch.MultiAgentBatch[source]

Returns the respective MultiAgentBatch using DEFAULT_POLICY_ID.


Returns:
   The MultiAgentBatch (using DEFAULT_POLICY_ID) corresponding to this SampleBatch.

compress(bulk: bool = False, columns: Set[str] = frozenset({'new_obs', 'obs'})) ray.rllib.policy.sample_batch.SampleBatch[source]

Compresses the data buffers (by column) in place.

Parameters:

  • bulk – Whether to compress across the batch dimension (0) as well. If False, compresses n separate list items, where n is the batch size.

  • columns – The columns to compress. Default: Only compress the obs and new_obs columns.


Returns:
   This very (now compressed) SampleBatch.
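RLlib performs this compression with its own pack/unpack utilities; the general idea (serialize and compress only the selected columns, in place) can be sketched with the standard library's zlib. The helper names here are illustrative only:

```python
import pickle
import zlib
import numpy as np

def compress_columns(batch, columns=("obs", "new_obs")):
    """Replace selected columns with compressed byte blobs, in place."""
    for col in columns:
        if col in batch:
            batch[col] = zlib.compress(pickle.dumps(batch[col]))
    return batch

def decompress_columns(batch, columns=("obs", "new_obs")):
    """Restore selected columns from compressed byte blobs, in place."""
    for col in columns:
        if col in batch and isinstance(batch[col], bytes):
            batch[col] = pickle.loads(zlib.decompress(batch[col]))
    return batch

batch = {"obs": np.zeros((100, 4)), "rewards": np.ones(100)}
compress_columns(batch)
print(type(batch["obs"]))   # <class 'bytes'>  (rewards stays uncompressed)
decompress_columns(batch)
print(batch["obs"].shape)   # (100, 4)
```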

decompress_if_needed(columns: Set[str] = frozenset({'new_obs', 'obs'})) ray.rllib.policy.sample_batch.SampleBatch[source]

Decompresses data buffers (per column, if compressed) in place.


Parameters:
   columns – The columns to decompress. Default: Only decompress the obs and new_obs columns.


Returns:
   This very (now uncompressed) SampleBatch.

get_single_step_input_dict(view_requirements: Dict[str, ViewRequirement], index: Union[str, int] = 'last') SampleBatch[source]

Creates a single-timestep SampleBatch at the given index from self.

For usage as input-dict for model (action or value function) calls.

Parameters:

  • view_requirements – A view requirements dict from the model for which to produce the input_dict.

  • index – An integer index value indicating the position in the trajectory for which to generate the compute_actions input dict. Set to “last” to generate the dict at the very end of the trajectory (e.g. for value estimation). Note that “last” is different from -1, as “last” will use the final NEXT_OBS as observation input.


Returns:
   The (single-timestep) input dict for ModelV2 calls.
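The "last" vs. -1 distinction described above can be sketched with plain Python lists (no RLlib needed): "last" builds the input dict from the final NEXT_OBS, while -1 indexes the final OBS.

```python
# Toy trajectory: new_obs[i] is the observation AFTER step i.
batch = {"obs": [0, 1, 2], "new_obs": [1, 2, 3]}

input_dict_last = {"obs": batch["new_obs"][-1]}  # index="last" -> obs after the final step
input_dict_m1 = {"obs": batch["obs"][-1]}        # index=-1    -> obs at the final step
print(input_dict_last, input_dict_m1)  # {'obs': 3} {'obs': 2}
```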

MultiAgentBatch (ray.rllib.policy.sample_batch.MultiAgentBatch)

In multi-agent mode, several sample batches may be collected separately for each individual policy and are placed in a container object of type MultiAgentBatch:

class ray.rllib.policy.sample_batch.MultiAgentBatch(policy_batches: Dict[str, ray.rllib.policy.sample_batch.SampleBatch], env_steps: int)[source]

A batch of experiences from multiple agents in the environment.


Attributes:
   policy_batches – Mapping from policy ids to SampleBatches of experiences (Dict[PolicyID, SampleBatch]).

   count – The number of env steps in this batch (int).



env_steps() int[source]

The number of env steps (there are >= 1 agent steps per env step).


Returns:
   The number of environment steps contained in this batch.

agent_steps() int[source]

The number of agent steps (there are >= 1 agent steps per env step).


Returns:
   The number of agent steps total in this batch.
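The env-step vs. agent-step distinction can be sketched with plain Python dicts (no RLlib needed): one env step can contain steps from several agents, so agent_steps() >= env_steps().

```python
# Two agents, each acting on every one of 10 env steps.
policy_batches = {
    "policy_1": {"obs": list(range(10))},  # 10 agent steps
    "policy_2": {"obs": list(range(10))},  # 10 agent steps
}
env_steps = 10  # the environment was stepped 10 times
agent_steps = sum(len(b["obs"]) for b in policy_batches.values())
print(env_steps, agent_steps)  # 10 20
```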

timeslices(k: int) List[ray.rllib.policy.sample_batch.MultiAgentBatch][source]

Returns k-step batches holding data for each agent at those steps.

For example, suppose we have agent1 observations [a1t1, a1t2, a1t3]; for agent2, [a2t1, a2t3]; and for agent3, [a3t3] only.

Calling timeslices(1) would return three MultiAgentBatches containing [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].

Calling timeslices(2) would return two MultiAgentBatches containing [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].

This method is used to implement “lockstep” replay mode. Note that this method does not guarantee each batch contains only data from a single unroll. Batches might contain data from multiple different envs.
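The lockstep grouping from the example above can be sketched with plain Python (no RLlib needed); `lockstep_timeslices` is an illustrative stand-in that groups each agent's (timestep, value) pairs into windows of k timesteps:

```python
def lockstep_timeslices(agent_data, k):
    """agent_data: dict agent_id -> list of (t, value) pairs.
    Returns a list of slices; slice i holds all agent samples whose
    timestep falls into the i-th window of k timesteps."""
    ts = [t for samples in agent_data.values() for t, _ in samples]
    lo, hi = min(ts), max(ts)
    slices = [[] for _ in range((hi - lo) // k + 1)]
    for agent, samples in agent_data.items():
        for t, value in samples:
            slices[(t - lo) // k].append(value)
    return slices

data = {
    "agent1": [(1, "a1t1"), (2, "a1t2"), (3, "a1t3")],
    "agent2": [(1, "a2t1"), (3, "a2t3")],
    "agent3": [(3, "a3t3")],
}
print(lockstep_timeslices(data, 1))
# [['a1t1', 'a2t1'], ['a1t2'], ['a1t3', 'a2t3', 'a3t3']]
```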

static wrap_as_needed(policy_batches: Dict[str, ray.rllib.policy.sample_batch.SampleBatch], env_steps: int) Union[ray.rllib.policy.sample_batch.SampleBatch, ray.rllib.policy.sample_batch.MultiAgentBatch][source]

Returns SampleBatch or MultiAgentBatch, depending on given policies.

Parameters:

  • policy_batches – Mapping from policy ids to SampleBatch.

  • env_steps – Number of env steps in the batch.


Returns:
   The single default policy’s SampleBatch, or a MultiAgentBatch (if there is more than one policy).

static concat_samples(samples: List[ray.rllib.policy.sample_batch.MultiAgentBatch]) ray.rllib.policy.sample_batch.MultiAgentBatch[source]

Concatenates a list of MultiAgentBatches into a new MultiAgentBatch.


Parameters:
   samples – List of MultiAgentBatch objects to concatenate.


Returns:
   A new MultiAgentBatch consisting of the concatenated inputs.

copy() ray.rllib.policy.sample_batch.MultiAgentBatch[source]

Deep-copies self into a new MultiAgentBatch.


Returns:
   The copy of self with deep-copied data.

size_bytes() int[source]

Returns:
   The overall size in bytes of all policy batches (all columns).

compress(bulk: bool = False, columns: Set[str] = frozenset({'new_obs', 'obs'})) None[source]

Compresses each policy batch (per column) in place.

Parameters:

  • bulk – Whether to compress across the batch dimension (0) as well. If False, compresses n separate list items, where n is the batch size.

  • columns – Set of column names to compress.

decompress_if_needed(columns: Set[str] = frozenset({'new_obs', 'obs'})) ray.rllib.policy.sample_batch.MultiAgentBatch[source]

Decompresses each policy batch (per column), if already compressed.


Parameters:
   columns – Set of column names to decompress.



as_multi_agent() ray.rllib.policy.sample_batch.MultiAgentBatch[source]

Simply returns self (already a MultiAgentBatch).


Returns:
   This very instance of MultiAgentBatch.