ray.data.DataIterator.iter_jax_batches#

DataIterator.iter_jax_batches(*, prefetch_batches: int = 1, batch_size: int = 256, dtypes: jax.typing.DTypeLike | Dict[str, jax.typing.DTypeLike] | None = None, collate_fn: CollateFn | None = None, drop_last: bool = False, local_shuffle_buffer_size: int | None = None, local_shuffle_seed: int | None = None, synchronize_batches: bool = False, paddings: int | float | bool | Dict[str, int | float | bool] | None = None) → Iterable[Any][source]#

Return a batched iterable of JAX Arrays over the dataset.

This iterator fetches data blocks, converts them to NumPy arrays, and loads them directly onto JAX-addressable devices using Global Data Parallel sharding. Data types are inferred from the underlying NumPy arrays, unless specified via dtypes.

The iterable yields a dictionary mapping column names to arrays, or a single array if the underlying dataset consists of a single unnamed column.

Note

The returned JAX Arrays are sharded using an internal 1D mesh created by Ray Data. If you are using these arrays within a jax.set_mesh context that defines a different mesh (e.g., a multi-dimensional mesh or a different device ordering), JAX may perform an implicit resharding (communication) when the arrays are first used in a JAX operation. To minimize this overhead, ensure your training loop’s device ordering aligns with the one produced by jax.experimental.mesh_utils.create_device_mesh.

Parameters:
  • prefetch_batches – The number of batches to fetch ahead. Defaults to 1.

  • batch_size – The number of rows in each batch for each host. Must be divisible by the number of local devices. Defaults to 256.

  • dtypes – The JAX dtype(s) for the created array(s); if None, the dtype will be inferred from the NumPy ndarray data.

  • collate_fn

    [Alpha] A function to customize how data batches are collated before being passed to the model. This is useful for last-mile data formatting such as padding, masking, or packaging tensors into custom data structures. The input to collate_fn may be:

    1. pyarrow.Table, where you should provide a callable class that subclasses ArrowBatchCollateFn (recommended for best performance).

    2. Dict[str, np.ndarray], where you should provide a callable class that subclasses NumpyBatchCollateFn.

    3. pd.DataFrame, where you should provide a callable class that subclasses PandasBatchCollateFn.

    The output must be a np.ndarray or Dict[str, np.ndarray], and will be automatically sharded across JAX-addressable devices. Note: This function is called in a multi-threaded context; avoid using thread-unsafe code.

  • drop_last – Whether to drop the last batch if it has fewer than batch_size rows.

  • local_shuffle_buffer_size – The minimum number of rows to buffer for the local in-memory shuffle. A larger buffer produces more randomized batches at the cost of additional memory. If None, no local shuffle is performed.

  • local_shuffle_seed – The seed to use for the local random shuffle.

  • synchronize_batches – Whether to synchronize batch shapes across all hosts. Setting this to False can improve performance if you guarantee that all hosts produce identical batch shapes and counts beforehand. Setting this to True can help catch bugs where different hosts produce different batch shapes.

  • paddings – The value to use to pad the last batch up to batch_size when it is incomplete. If a dictionary is provided, it must map column names to padding values. The value(s) must be castable to the dtypes of the created arrays.
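A hedged sketch of the collate contract for the Dict[str, np.ndarray] case: take a batch of column arrays and return processed column arrays. With Ray Data you would wrap this logic in a callable class subclassing NumpyBatchCollateFn; the plain function here only illustrates the transformation, padding each column up to a fixed batch size (much like the paddings parameter does):

```python
import numpy as np

def pad_collate(batch: dict, batch_size: int = 4, pad_value: int = 0) -> dict:
    # Pad every column array along the row axis so the batch has
    # exactly `batch_size` rows, leaving other dimensions untouched.
    out = {}
    for name, col in batch.items():
        n = col.shape[0]
        if n < batch_size:
            pad_width = [(0, batch_size - n)] + [(0, 0)] * (col.ndim - 1)
            col = np.pad(col, pad_width, constant_values=pad_value)
        out[name] = col
    return out

# An uneven final batch of 3 rows is padded to 4.
padded = pad_collate({"x": np.array([1, 2, 3])}, batch_size=4, pad_value=0)
# padded["x"] -> array([1, 2, 3, 0])
```

Because collate_fn runs in a multi-threaded context, a function like this should stay pure: it reads its inputs and returns new arrays without touching shared mutable state.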

Returns:

An iterable over JAX Array batches.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.