ray.data.DataIterator.iter_torch_batches

DataIterator.iter_torch_batches(*, prefetch_batches: int = 1, batch_size: int | None = 256, dtypes: torch.dtype | Dict[str, torch.dtype] | None = None, device: str = 'auto', collate_fn: Callable[[Dict[str, numpy.ndarray]], CollatedData] | CollateFn | None = None, drop_last: bool = False, local_shuffle_buffer_size: int | None = None, local_shuffle_seed: int | None = None) → Iterable[TorchBatchType]

Return a batched iterable of Torch Tensors over the dataset.

This iterable yields batches as dictionaries of column tensors. If you need more flexibility in the tensor conversion (for example, casting dtypes) or in the batch format, use iter_batches() directly.

Examples

>>> import ray
>>> for batch in ray.data.range(
...     12,
... ).iterator().iter_torch_batches(batch_size=4):
...     print(batch)
{'id': tensor([0, 1, 2, 3])}
{'id': tensor([4, 5, 6, 7])}
{'id': tensor([ 8,  9, 10, 11])}

Use the collate_fn to customize how the tensor batch is created.

>>> from typing import Any, Dict
>>> import torch
>>> import numpy as np
>>> import ray
>>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any:
...     return torch.stack(
...         [torch.as_tensor(array) for array in batch.values()],
...         dim=1,
...     )
>>> iterator = ray.data.from_items([
...     {"col_1": 1, "col_2": 2},
...     {"col_1": 3, "col_2": 4}]).iterator()
>>> for batch in iterator.iter_torch_batches(collate_fn=collate_fn):
...     print(batch)
tensor([[1, 2],
        [3, 4]])
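
You can also provide a callable class that subclasses one of the batch-specific collate classes, such as ArrowBatchCollateFn for pyarrow.Table inputs (see the collate_fn parameter below). The following is a minimal sketch; the import path and exact call signature are assumptions that may vary across Ray versions.

>>> import pyarrow as pa
>>> from ray.data.collate_fn import ArrowBatchCollateFn  # import path is an assumption
>>> class StackCollateFn(ArrowBatchCollateFn):
...     def __call__(self, batch: pa.Table) -> torch.Tensor:
...         # Convert each Arrow column to a tensor and stack them column-wise.
...         columns = [torch.as_tensor(col.to_numpy()) for col in batch.columns]
...         return torch.stack(columns, dim=1)
>>> iterator = ray.data.from_items([
...     {"col_1": 1, "col_2": 2},
...     {"col_1": 3, "col_2": 4}]).iterator()
>>> for batch in iterator.iter_torch_batches(collate_fn=StackCollateFn()):
...     _ = batch  # under this sketch's assumptions, a (num_rows, num_cols) tensor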

Time complexity: O(1)

Parameters:
  • prefetch_batches – The number of batches to fetch ahead of the current batch. If set to greater than 0, a separate threadpool fetches the objects to the local node, formats the batches, and applies the collate_fn. Defaults to 1. (A combined sketch of these parameters follows this list.)

  • batch_size – The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different numbers of rows). The final batch may include fewer than batch_size rows if drop_last is False. Defaults to 256.

  • dtypes – The Torch dtype(s) for the created tensor(s); if None, the dtype will be inferred from the tensor data. You can’t use this parameter with collate_fn.

  • device – The device on which the tensor should be placed. Defaults to “auto”, which moves the tensors to the appropriate device when the Dataset is passed to Ray Train and collate_fn is not provided; otherwise, defaults to CPU. You can’t use this parameter with collate_fn.

  • collate_fn

    [Alpha] A function to customize how data batches are collated before being passed to the model. This is useful for last-mile data formatting such as padding, masking, or packaging tensors into custom data structures. If not provided, iter_torch_batches automatically converts batches to torch.Tensors and moves them to the device assigned to the current worker. The input to collate_fn may be:

    1. pyarrow.Table, where you should provide a callable class that subclasses ArrowBatchCollateFn (recommended for best performance). Use the utility function arrow_batch_to_tensors to convert the pyarrow.Table to a dictionary of non-contiguous tensor batches.

    2. Dict[str, np.ndarray], where you should provide a callable class that subclasses NumpyBatchCollateFn

    3. pd.DataFrame, where you should provide a callable class that subclasses PandasBatchCollateFn

    The output can be any type. If the output is a TensorBatchType, it is automatically moved to the current worker’s device. For other types, you must handle device transfer manually in your training loop (a sketch follows the Returns section below). Note: This function is called in a multi-threaded context; avoid thread-unsafe code.

  • drop_last – Whether to drop the last batch if it’s incomplete.

  • local_shuffle_buffer_size – If non-None, the data is randomly shuffled using a local in-memory shuffle buffer, and this value serves as the minimum number of rows that must be in the buffer before a batch is yielded. When there are no more rows to add, the remaining rows in the buffer are drained. This buffer size must be greater than or equal to batch_size, so batch_size must also be specified when using local shuffling.

  • local_shuffle_seed – The seed to use for the local random shuffle.
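
The following combined sketch exercises several of the parameters above (the values are arbitrary; dtypes and device are shown without collate_fn since they can’t be used together with it):

>>> it = ray.data.range(1000).iterator()
>>> batches = it.iter_torch_batches(
...     prefetch_batches=2,             # fetch ahead using a background threadpool
...     batch_size=128,                 # rows per batch; None would yield whole blocks
...     dtypes=torch.float32,           # cast all columns to float32
...     device="cpu",                   # explicit placement instead of "auto"
...     drop_last=True,                 # discard the final incomplete batch
...     local_shuffle_buffer_size=512,  # must be >= batch_size
...     local_shuffle_seed=42,
... )
>>> for batch in batches:
...     assert batch["id"].dtype == torch.float32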

Returns:

An iterable over Torch Tensor batches.
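
As noted under collate_fn, an output that isn’t a TensorBatchType isn’t moved to a device for you. A minimal sketch of handling the transfer in the training loop (the Batch wrapper class here is hypothetical):

>>> class Batch:
...     def __init__(self, features: torch.Tensor):
...         self.features = features
>>> def wrap_batch(batch: Dict[str, np.ndarray]) -> Batch:
...     # Returns a custom object, so iter_torch_batches doesn’t move it to a device.
...     return Batch(torch.as_tensor(batch["id"]))
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> it = ray.data.range(8).iterator()
>>> for batch in it.iter_torch_batches(batch_size=4, collate_fn=wrap_batch):
...     features = batch.features.to(device)  # manual device transfer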