Sample Batches#
SampleBatch (ray.rllib.policy.sample_batch.SampleBatch)#
Whether running in a single process or on a large cluster, all data interchange in RLlib happens in the form of SampleBatch objects. A RolloutWorker collects batches of size rollout_fragment_length, and RLlib then concatenates one or more of these batches (across different RolloutWorkers and subsequent sampling steps) into a batch of size train_batch_size, which then serves as the input to a Policy's learn_on_batch() method.
A typical sample batch looks something like the following when summarized. Since all values are kept in arrays, this allows for efficient encoding and transmission across the network:
{ 'action_logp': np.ndarray((200,), dtype=float32, min=-0.701, max=-0.685, mean=-0.694),
'actions': np.ndarray((200,), dtype=int64, min=0.0, max=1.0, mean=0.495),
'dones': np.ndarray((200,), dtype=bool, min=0.0, max=1.0, mean=0.055),
'infos': np.ndarray((200,), dtype=object, head={}),
'new_obs': np.ndarray((200, 4), dtype=float32, min=-2.46, max=2.259, mean=0.018),
'obs': np.ndarray((200, 4), dtype=float32, min=-2.46, max=2.259, mean=0.016),
'rewards': np.ndarray((200,), dtype=float32, min=1.0, max=1.0, mean=1.0),
't': np.ndarray((200,), dtype=int64, min=0.0, max=34.0, mean=9.14)}
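As a minimal, hedged sketch of how such a batch can be built and inspected by hand (the column names below are illustrative only, not a required schema):
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> # A batch is a dict of equal-length columns.
>>> batch = SampleBatch({
...     "obs": np.zeros((3, 4), dtype=np.float32),
...     "actions": np.array([0, 1, 0]),
...     "rewards": np.ones((3,), dtype=np.float32),
... })
>>> len(batch)  # number of timesteps stored in the batch
3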
- class ray.rllib.policy.sample_batch.SampleBatch(*args, **kwargs)[source]#
Wrapper around a dictionary with string keys and array-like values.
For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three samples, each with an "obs" and "reward" attribute.
- agent_steps() int [source]#
Returns the same as len(self) (number of steps in this batch).
This makes it compatible with MultiAgentBatch.agent_steps().
- env_steps() int [source]#
Returns the same as len(self) (number of steps in this batch).
This makes it compatible with MultiAgentBatch.env_steps().
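For a plain single-agent SampleBatch, both counters simply equal the batch length; a small sketch:
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3]})
>>> # Agent steps and env steps coincide for a single-agent batch.
>>> len(batch) == batch.agent_steps() == batch.env_steps()
True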
- is_terminated_or_truncated() bool [source]#
Returns True if self is either terminated or truncated at idx -1.
- is_single_trajectory() bool [source]#
Returns True if this SampleBatch only contains one trajectory.
This is determined by checking all timesteps (except for the last) for being not terminated AND (if applicable) not truncated.
- concat(other: ray.rllib.policy.sample_batch.SampleBatch) ray.rllib.policy.sample_batch.SampleBatch [source]#
Concatenates other to this one and returns a new SampleBatch.
- Parameters
other – The other SampleBatch object to concat to this one.
- Returns
The new SampleBatch, resulting from concatenating other to self.
Examples
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> b1 = SampleBatch({"a": np.array([1, 2])})
>>> b2 = SampleBatch({"a": np.array([3, 4, 5])})
>>> print(b1.concat(b2))
{"a": np.array([1, 2, 3, 4, 5])}
- copy(shallow: bool = False) ray.rllib.policy.sample_batch.SampleBatch [source]#
Creates a deep or shallow copy of this SampleBatch and returns it.
- Parameters
shallow – Whether the copying should be done shallowly.
- Returns
A deep or shallow copy of this SampleBatch object.
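A possible usage sketch; a shallow copy shares the underlying data buffers with the original, while the default is an independent deep copy:
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": np.array([1, 2, 3])})
>>> deep = batch.copy()                 # independent data buffers
>>> shallow = batch.copy(shallow=True)  # shares the underlying arrays with `batch`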
- rows() Iterator[Dict[str, Union[numpy.array, tensorflow.python.framework.ops.Tensor, torch.Tensor]]] [source]#
Returns an iterator over data rows, i.e. dicts with column values.
Note that if seq_lens is set in self, we set it to 1 in the rows.
- Yields
The column values of the row in this iteration.
Examples
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({
...     "a": [1, 2, 3],
...     "b": [4, 5, 6],
...     "seq_lens": [1, 2]
... })
>>> for row in batch.rows():
...     print(row)
{"a": 1, "b": 4, "seq_lens": 1}
{"a": 2, "b": 5, "seq_lens": 1}
{"a": 3, "b": 6, "seq_lens": 1}
- columns(keys: List[str]) List[any] [source]#
Returns a list of the batch-data in the specified columns.
- Parameters
keys – List of column names for which to return the data.
- Returns
The list of data items ordered by the order of column names in keys.
Examples
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
>>> print(batch.columns(["a", "b"]))
[[1], [2]]
- shuffle() ray.rllib.policy.sample_batch.SampleBatch [source]#
Shuffles the rows of this batch in-place.
- Returns
This very (now shuffled) SampleBatch.
- Raises
ValueError – If self[SampleBatch.SEQ_LENS] is defined.
Examples
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3, 4]})
>>> print(batch.shuffle())
{"a": [4, 1, 3, 2]}
- split_by_episode(key: Optional[str] = None) List[ray.rllib.policy.sample_batch.SampleBatch] [source]#
Splits by the eps_id column and returns a list of new batches. If eps_id is not present, splits by dones instead.
- Parameters
key – If specified, overwrite default and use key to split.
- Returns
List of batches, one per distinct episode.
- Raises
KeyError – If the eps_id AND dones columns are not present.
Examples
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> # "eps_id" is present
>>> batch = SampleBatch(
...     {"a": [1, 2, 3], "eps_id": [0, 0, 1]})
>>> print(batch.split_by_episode())
[{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
>>>
>>> # "eps_id" not present, split by "dones" instead
>>> batch = SampleBatch(
...     {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 1]})
>>> print(batch.split_by_episode())
[{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 1]}]
>>>
>>> # The last episode is appended even if it does not end with done
>>> batch = SampleBatch(
...     {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 0]})
>>> print(batch.split_by_episode())
[{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 0]}]
>>> batch = SampleBatch(
...     {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]})
>>> print(batch.split_by_episode())
[{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}]
- slice(start: int, end: int, state_start=None, state_end=None) ray.rllib.policy.sample_batch.SampleBatch [source]#
Returns a slice of the row data of this batch (w/o copying).
- Parameters
start – Starting index. If < 0, will left-zero-pad.
end – Ending index.
- Returns
A new SampleBatch, which has a slice of this batch’s data.
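A minimal sketch, assuming standard Python-style (end-exclusive) indexing on the row axis; the printed form is shorthand for the resulting batch contents:
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3, 4, 5]})
>>> # Rows 1 and 2 (end index is exclusive).
>>> print(batch.slice(1, 3))
{"a": [2, 3]}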
- timeslices(size: Optional[int] = None, num_slices: Optional[int] = None, k: Optional[int] = None) List[ray.rllib.policy.sample_batch.SampleBatch] [source]#
Returns SampleBatches, each one representing a k-slice of this one.
Will start from timestep 0 and produce slices of size=k.
- Parameters
size – The size (in timesteps) of each returned SampleBatch.
num_slices – The number of slices to produce.
k – Deprecated: Use size or num_slices instead. The size (in timesteps) of each returned SampleBatch.
- Returns
The list of num_slices (new) SampleBatches or n (new) SampleBatches, each one of size size.
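A sketch of the size-based variant; when the batch length is not divisible by size, the last slice may be shorter:
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3, 4, 5]})
>>> slices = batch.timeslices(size=2)
>>> [len(s) for s in slices]  # 5 timesteps cut into slices of (at most) 2
[2, 2, 1]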
- right_zero_pad(max_seq_len: int, exclude_states: bool = True)[source]#
Right (adding zeros at end) zero-pads this SampleBatch in-place.
This will set the self.zero_padded flag to True and self.max_seq_len to the given max_seq_len value.
- Parameters
max_seq_len – The max (total) length to zero pad to.
exclude_states – If False, also right-zero-pad all state_in_x data. If True, leave state_in_x keys as-is.
- Returns
This very (now right-zero-padded) SampleBatch.
- Raises
ValueError – If self[SampleBatch.SEQ_LENS] is None (not defined).
Examples
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch(
...     {"a": [1, 2, 3], "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=4))
{"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}
>>> batch = SampleBatch({"a": [1, 2, 3],
...     "state_in_0": [1.0, 3.0],
...     "seq_lens": [1, 2]})
>>> print(batch.right_zero_pad(max_seq_len=5))
{"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
 "state_in_0": [1.0, 3.0],  # <- all state-ins remain as-is
 "seq_lens": [1, 2]}
- to_device(device, framework='torch')[source]#
TODO: transfer batch to given device as framework tensor.
- size_bytes() int [source]#
Returns sum over number of bytes of all data buffers.
For numpy arrays, we use .nbytes. For all other value types, we use sys.getsizeof(...).
- Returns
The overall size in bytes of the data buffer (all columns).
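As a rough sketch, a single float32 column of 100 values should account for roughly 100 * 4 bytes via .nbytes:
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": np.zeros((100,), dtype=np.float32)})
>>> batch.size_bytes()  # roughly 400 bytes for the float32 buffer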
- as_multi_agent() ray.rllib.policy.sample_batch.MultiAgentBatch [source]#
Returns the respective MultiAgentBatch using DEFAULT_POLICY_ID.
- Returns
The MultiAgentBatch (using DEFAULT_POLICY_ID) corresponding to this SampleBatch.
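A possible usage sketch; DEFAULT_POLICY_ID resolves to the string "default_policy":
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"a": [1, 2, 3]})
>>> ma_batch = batch.as_multi_agent()
>>> list(ma_batch.policy_batches.keys())
['default_policy']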
- compress(bulk: bool = False, columns: Set[str] = frozenset({'new_obs', 'obs'})) ray.rllib.policy.sample_batch.SampleBatch [source]#
Compresses the data buffers (by column) in place.
- Parameters
bulk – Whether to compress across the batch dimension (0) as well. If False will compress n separate list items, where n is the batch size.
columns – The columns to compress. Default: Only compress the obs and new_obs columns.
- Returns
This very (now compressed) SampleBatch.
- decompress_if_needed(columns: Set[str] = frozenset({'new_obs', 'obs'})) ray.rllib.policy.sample_batch.SampleBatch [source]#
Decompresses data buffers (per column, if compressed) in place.
- Parameters
columns – The columns to decompress. Default: Only decompress the obs and new_obs columns.
- Returns
This very (now uncompressed) SampleBatch.
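A round-trip sketch for compress() and decompress_if_needed(); this assumes the lz4 compression library is available in the environment:
>>> import numpy as np
>>> from ray.rllib.policy.sample_batch import SampleBatch
>>> batch = SampleBatch({"obs": np.random.rand(10, 4).astype(np.float32)})
>>> _ = batch.compress(bulk=True)     # the "obs" column is now a compressed blob
>>> _ = batch.decompress_if_needed()  # restores the original numpy data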
- get_single_step_input_dict(view_requirements: Dict[str, ViewRequirement], index: Union[str, int] = 'last') SampleBatch [source]#
Creates a single-timestep SampleBatch at the given index from self.
For usage as input-dict for model (action or value function) calls.
- Parameters
view_requirements – A view requirements dict from the model for which to produce the input_dict.
index – An integer index value indicating the position in the trajectory for which to generate the compute_actions input dict. Set to “last” to generate the dict at the very end of the trajectory (e.g. for value estimation). Note that “last” is different from -1, as “last” will use the final NEXT_OBS as observation input.
- Returns
The (single-timestep) input dict for ModelV2 calls.
MultiAgentBatch (ray.rllib.policy.sample_batch.MultiAgentBatch)#
In multi-agent mode, several sample batches may be collected separately for each individual policy and are placed in a container object of type MultiAgentBatch:
- class ray.rllib.policy.sample_batch.MultiAgentBatch(policy_batches: Dict[str, ray.rllib.policy.sample_batch.SampleBatch], env_steps: int)[source]#
A batch of experiences from multiple agents in the environment.
- policy_batches#
Mapping from policy ids to SampleBatches of experiences.
- Type
Dict[PolicyID, SampleBatch]
- count#
The number of env steps in this batch.
- env_steps() int [source]#
The number of env steps (there are >= 1 agent steps per env step).
- Returns
The number of environment steps contained in this batch.
- agent_steps() int [source]#
The number of agent steps (there are >= 1 agent steps per env step).
- Returns
The number of agent steps total in this batch.
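A small sketch with two policies, each contributing two agent steps over the same two env steps:
>>> from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
>>> ma = MultiAgentBatch(
...     {"pol1": SampleBatch({"a": [1, 2]}),
...      "pol2": SampleBatch({"a": [3, 4]})},
...     env_steps=2)
>>> ma.env_steps()    # two environment steps were taken ...
2
>>> ma.agent_steps()  # ... and both agents stepped at each of them
4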
- timeslices(k: int) List[ray.rllib.policy.sample_batch.MultiAgentBatch] [source]#
Returns k-step batches holding data for each agent at those steps.
For example, suppose we have agent1 observations [a1t1, a1t2, a1t3], for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.
Calling timeslices(1) would return three MultiAgentBatches containing [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].
Calling timeslices(2) would return two MultiAgentBatches containing [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].
This method is used to implement “lockstep” replay mode. Note that this method does not guarantee each batch contains only data from a single unroll. Batches might contain data from multiple different envs.
- static wrap_as_needed(policy_batches: Dict[str, ray.rllib.policy.sample_batch.SampleBatch], env_steps: int) Union[ray.rllib.policy.sample_batch.SampleBatch, ray.rllib.policy.sample_batch.MultiAgentBatch] [source]#
Returns SampleBatch or MultiAgentBatch, depending on given policies. If policy_batches is empty (i.e. {}) it returns an empty MultiAgentBatch.
- Parameters
policy_batches – Mapping from policy ids to SampleBatch.
env_steps – Number of env steps in the batch.
- Returns
The single default policy’s SampleBatch or a MultiAgentBatch (more than one policy).
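A sketch of both outcomes, assuming the default policy id string "default_policy":
>>> from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
>>> # A single batch under the default policy id unwraps to a plain SampleBatch ...
>>> single = MultiAgentBatch.wrap_as_needed(
...     {"default_policy": SampleBatch({"a": [1, 2]})}, env_steps=2)
>>> type(single).__name__
'SampleBatch'
>>> # ... while multiple policies stay wrapped in a MultiAgentBatch.
>>> multi = MultiAgentBatch.wrap_as_needed(
...     {"pol1": SampleBatch({"a": [1, 2]}),
...      "pol2": SampleBatch({"a": [3, 4]})}, env_steps=2)
>>> type(multi).__name__
'MultiAgentBatch'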
- copy() ray.rllib.policy.sample_batch.MultiAgentBatch [source]#
Deep-copies self into a new MultiAgentBatch.
- Returns
The copy of self with deep-copied data.
- compress(bulk: bool = False, columns: Set[str] = frozenset({'new_obs', 'obs'})) None [source]#
Compresses each policy batch (per column) in place.
- Parameters
bulk – Whether to compress across the batch dimension (0) as well. If False will compress n separate list items, where n is the batch size.
columns – Set of column names to compress.
- decompress_if_needed(columns: Set[str] = frozenset({'new_obs', 'obs'})) ray.rllib.policy.sample_batch.MultiAgentBatch [source]#
Decompresses each policy batch (per column), if already compressed.
- Parameters
columns – Set of column names to decompress.
- Returns
Self.
- as_multi_agent() ray.rllib.policy.sample_batch.MultiAgentBatch [source]#
Simply returns self (already a MultiAgentBatch).
- Returns
This very instance of MultiAgentBatch.