ray.data.Dataset.iter_jax_batches

Dataset.iter_jax_batches(*, prefetch_batches: int = 1, batch_size: int = 256, dtypes: jax.typing.DTypeLike | Dict[str, jax.typing.DTypeLike] | None = None, collate_fn: CollateFn | None = None, drop_last: bool = False, local_shuffle_buffer_size: int | None = None, local_shuffle_seed: int | None = None, synchronize_batches: bool = False, paddings: int | float | bool | Dict[str, int | float | bool] | None = None) → Iterable[Any]

Return an iterable over batches of data represented as JAX arrays.

This iterable yields batches of type Union["jax.Array", Dict[str, "jax.Array"]]. Each yielded batch is the global view of a 1D data-parallel JAX array (or dict of arrays), sharded along the batch dimension and placed across all JAX devices. Data types are inferred from the underlying NumPy arrays unless specified via dtypes. For more flexibility, call iter_batches() and manually convert your data to JAX arrays.
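The 1D data-parallel layout described above can be illustrated with a numpy-only sketch; in the real API each batch is a jax.Array sharded across devices, while here plain numpy arrays stand in for device shards, and shard_batch/global_view are hypothetical helpers, not part of Ray's API:

```python
import numpy as np

def shard_batch(batch: np.ndarray, num_devices: int) -> list:
    """Split a global batch into per-device shards along the batch axis."""
    # Mirrors the batch_size requirement: rows must divide evenly across devices.
    assert batch.shape[0] % num_devices == 0, "batch_size must be divisible by device count"
    return np.split(batch, num_devices, axis=0)

def global_view(shards: list) -> np.ndarray:
    """Reassemble the global batch from per-device shards."""
    return np.concatenate(shards, axis=0)

batch = np.arange(8.0).reshape(8, 1)          # global batch of 8 rows
shards = shard_batch(batch, num_devices=4)    # each device holds 2 rows
assert all(s.shape == (2, 1) for s in shards)
assert np.array_equal(global_view(shards), batch)
```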

Note

The returned JAX Arrays are sharded using an internal 1D mesh created by Ray Data. If you are using these arrays within a jax.set_mesh context that defines a different mesh (e.g., a multi-dimensional mesh or a different device ordering), JAX may perform an implicit resharding (communication) when the arrays are first used in a JAX operation. To minimize this overhead, ensure your training loop’s device ordering aligns with the one produced by jax.experimental.mesh_utils.create_device_mesh.

Note

This operation will trigger execution of the lazy transformations performed on this dataset.

Examples

import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

jax_dataset = ds.iter_jax_batches(batch_size=2)
for batch in jax_dataset:
    print(batch["sepal length (cm)"], batch["target"])
    break
[5.1 4.9] [0 0]
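When dtypes is a dict, each named column is cast independently before conversion to JAX arrays, while unlisted columns keep their inferred dtype. A numpy-only sketch of that casting step (cast_columns is a hypothetical illustration, not part of the Ray API):

```python
import numpy as np

def cast_columns(batch: dict, dtypes: dict) -> dict:
    """Cast each column named in `dtypes`; leave other columns unchanged."""
    return {
        name: col.astype(dtypes[name]) if name in dtypes else col
        for name, col in batch.items()
    }

batch = {"sepal length (cm)": np.array([5.1, 4.9]), "target": np.array([0, 0])}
out = cast_columns(batch, {"target": np.int32})
assert out["target"].dtype == np.int32
assert out["sepal length (cm)"].dtype == np.float64  # inferred dtype preserved
```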
Parameters:
  • prefetch_batches – The number of batches to fetch ahead. Defaults to 1.

  • batch_size – The number of rows in each batch. Must be divisible by the number of local devices. Defaults to 256.

  • dtypes – The JAX dtype(s) for the created array(s); if None, the dtype will be inferred from the NumPy ndarray data.

  • collate_fn

    [Alpha] A function to customize how data batches are collated before being passed to the model. This is useful for last-mile data formatting such as padding, masking, or packaging tensors into custom data structures. The input to collate_fn may be:

    1. pyarrow.Table, where you should provide a callable class that subclasses ArrowBatchCollateFn (recommended for best performance).

    2. Dict[str, np.ndarray], where you should provide a callable class that subclasses NumpyBatchCollateFn.

    3. pd.DataFrame, where you should provide a callable class that subclasses PandasBatchCollateFn.

    The output must be a np.ndarray or Dict[str, np.ndarray], and will be automatically sharded across JAX-addressable devices. Note: This function is called in a multi-threaded context; avoid using thread-unsafe code.

  • drop_last – Whether to drop the last batch if it’s incomplete. Defaults to False.

  • local_shuffle_buffer_size – If not None, the data is randomly shuffled using a local in-memory shuffle buffer, and this value serves as the minimum number of rows that must be in the local in-memory shuffle buffer in order to yield a batch. When there are no more rows to add to the buffer, the remaining rows in the buffer are drained. batch_size must also be specified when using local shuffling.

  • local_shuffle_seed – The seed to use for the local random shuffle.

  • synchronize_batches – Whether to synchronize batch shapes across all hosts. Setting this to False can improve performance if you guarantee that all hosts produce identical batch shapes and counts beforehand. Setting this to True can help catch bugs where different hosts produce different batch shapes.

  • paddings – The value to use for padding the last batch to batch_size. If a dictionary is provided, it must map column names to padding values. If not None, uneven batches will be padded with this value. Must be castable to the user-provided dtypes.
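The paddings behavior described above can be sketched in plain numpy: an uneven final batch is filled along the batch dimension up to batch_size with the per-column padding value. pad_batch is a hypothetical illustration, not Ray's internal implementation:

```python
import numpy as np

def pad_batch(batch: dict, batch_size: int, paddings: dict) -> dict:
    """Pad each column of an uneven batch to `batch_size` rows."""
    padded = {}
    for name, col in batch.items():
        short = batch_size - col.shape[0]
        if short > 0:
            # Fill value is cast to the column's dtype, as with user-provided dtypes.
            fill = np.full((short,) + col.shape[1:], paddings[name], dtype=col.dtype)
            col = np.concatenate([col, fill], axis=0)
        padded[name] = col
    return padded

last = {"x": np.array([1.0, 2.0, 3.0]), "y": np.array([7, 8, 9])}
out = pad_batch(last, batch_size=4, paddings={"x": 0.0, "y": -1})
assert out["x"].shape == (4,) and out["x"][-1] == 0.0
assert out["y"].shape == (4,) and out["y"][-1] == -1
```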

Returns:

An iterable over JAX Array batches.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.