ray.data.Dataset.iter_torch_batches

Dataset.iter_torch_batches(*, prefetch_batches: int = 1, batch_size: int | None = 256, dtypes: torch.dtype | Dict[str, torch.dtype] | None = None, device: str = 'auto', collate_fn: Callable[[Dict[str, numpy.ndarray]], CollatedData] | None = None, drop_last: bool = False, local_shuffle_buffer_size: int | None = None, local_shuffle_seed: int | None = None) → Iterable[Dict[str, torch.Tensor] | CollatedData][source]

Return an iterable over batches of data represented as Torch tensors.

This iterable yields batches of type Dict[str, torch.Tensor]. For more flexibility, call iter_batches() and manually convert your data to Torch tensors.
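
For instance, a minimal sketch of the manual route (assuming the "numpy" batch format, so each batch is a Dict[str, np.ndarray]):

>>> import torch
>>> import ray
>>> for batch in ray.data.range(4).iter_batches(batch_format="numpy"):
...     # Convert each NumPy column to a Torch tensor by hand.
...     tensors = {k: torch.as_tensor(v) for k, v in batch.items()}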

Note

This operation will trigger execution of the lazy transformations performed on this dataset.

Examples

>>> import ray
>>> for batch in ray.data.range(
...     12,
... ).iter_torch_batches(batch_size=4):
...     print(batch)
{'id': tensor([0, 1, 2, 3])}
{'id': tensor([4, 5, 6, 7])}
{'id': tensor([ 8,  9, 10, 11])}

Use the collate_fn to customize how the tensor batch is created.

>>> from typing import Any, Dict
>>> import torch
>>> import numpy as np
>>> import ray
>>> def collate_fn(batch: Dict[str, np.ndarray]) -> Any:
...     # Stack the column arrays into one (num_rows, num_columns) tensor.
...     return torch.stack(
...         [torch.as_tensor(array) for array in batch.values()],
...         dim=1,
...     )
>>> dataset = ray.data.from_items([
...     {"col_1": 1, "col_2": 2},
...     {"col_1": 3, "col_2": 4}])
>>> for batch in dataset.iter_torch_batches(collate_fn=collate_fn):
...     print(batch)
tensor([[1, 2],
        [3, 4]])
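
When you pass a collate_fn, you're responsible for moving the collated batch to the target device, because the device parameter can't be combined with collate_fn. A minimal sketch, assuming a CUDA device is available:

>>> for batch in dataset.iter_torch_batches(collate_fn=collate_fn):  # doctest: +SKIP
...     batch = batch.to("cuda")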

Time complexity: O(1)

Parameters:
  • prefetch_batches – The number of batches to fetch ahead of the current batch. If set to greater than 0, a separate threadpool is used to fetch the objects to the local node, format the batches, and apply the collate_fn. Defaults to 1.

  • batch_size – The number of rows in each batch, or None to use entire blocks as batches (blocks may contain different numbers of rows). The final batch may include fewer than batch_size rows if drop_last is False. Defaults to 256.

  • dtypes – The Torch dtype(s) for the created tensor(s); if None, the dtype is inferred from the tensor data (see the sketch after this list). You can’t use this parameter with collate_fn.

  • device – The device on which the tensor should be placed. Defaults to “auto”, which moves the tensors to the appropriate device when the Dataset is passed to Ray Train and collate_fn isn’t provided; otherwise, tensors are placed on the CPU. You can’t use this parameter with collate_fn.

  • collate_fn – A function to convert a NumPy batch to a PyTorch tensor batch. When this parameter is specified, you should manually handle the host-to-device data transfer outside of collate_fn. This is useful for further processing the data after it has been batched, for example to collate along a dimension other than the first, pad sequences of various lengths, or otherwise handle batches of variable-length tensors. If not provided, the default collate function is used, which simply converts the batch of NumPy arrays to a batch of PyTorch tensors. This API is still experimental and is subject to change. You can’t use this parameter in conjunction with dtypes or device.

  • drop_last – Whether to drop the last batch if it’s incomplete.

  • local_shuffle_buffer_size – If not None, the data is randomly shuffled using a local in-memory shuffle buffer, and this value serves as the minimum number of rows that must be in the buffer before a batch is yielded. Once there are no more rows to add to the buffer, the remaining rows in the buffer are drained. batch_size must also be specified when using local shuffling (see the sketch after this list).

  • local_shuffle_seed – The seed to use for the local random shuffle.
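
As a rough illustration of dtypes together with local shuffling (a sketch, not an exhaustive demonstration):

>>> import torch
>>> import ray
>>> for batch in ray.data.range(100).iter_torch_batches(
...     batch_size=16,
...     dtypes=torch.float32,
...     local_shuffle_buffer_size=32,
...     local_shuffle_seed=42,
... ):
...     # Each column is cast to the requested dtype.
...     assert batch["id"].dtype == torch.float32
...     # Batches hold at most batch_size rows; the final one may be smaller.
...     assert batch["id"].shape[0] <= 16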

Returns:

An iterable over Torch tensor batches.

See also

Dataset.iter_batches()

Call this method to manually convert your data to Torch tensors.