ray.data.Dataset.iter_jax_batches#
- Dataset.iter_jax_batches(*, prefetch_batches: int = 1, batch_size: int = 256, dtypes: jax.typing.DTypeLike | Dict[str, jax.typing.DTypeLike] | None = None, collate_fn: CollateFn | None = None, drop_last: bool = False, local_shuffle_buffer_size: int | None = None, local_shuffle_seed: int | None = None, synchronize_batches: bool = False, paddings: int | float | bool | Dict[str, int | float | bool] | None = None) Iterable[Any][source]#
Return an iterable over batches of data represented as JAX arrays.
This iterable yields batches of type
Union["jax.Array", Dict[str, "jax.Array"]]. Each returned batch is the global view of a 1D data-parallel JAX array (sharded along the batch dimension) placed on all the JAX devices. Data types are inferred from the underlying NumPy arrays, unless specified via dtypes. For more flexibility, call iter_batches() and manually convert your data to JAX arrays.

Note
The returned JAX arrays are sharded using an internal 1D mesh created by Ray Data. If you are using these arrays within a
jax.set_mesh context that defines a different mesh (e.g., a multi-dimensional mesh or a different device ordering), JAX may perform an implicit resharding (communication) when the arrays are first used in a JAX operation. To minimize this overhead, ensure your training loop's device ordering aligns with the one produced by jax.experimental.mesh_utils.create_device_mesh.

Note
This operation will trigger execution of the lazy transformations performed on this dataset.
Examples
import ray

ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")
jax_dataset = ds.iter_jax_batches(batch_size=2)
for batch in jax_dataset:
    print(batch["sepal length (cm)"], batch["target"])
    break
[5.1 4.9] [0 0]
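The note above about sharding can be illustrated in plain Python. This is a conceptual sketch, not Ray or JAX code: it shows why batch_size must be divisible by the number of local devices when each batch is split evenly along the batch dimension.

```python
# Conceptual sketch (not the Ray implementation): splitting a global
# batch evenly across local devices along the batch dimension.

def shard_batch(rows, num_devices):
    """Split a batch of rows into equal per-device shards."""
    if len(rows) % num_devices != 0:
        raise ValueError("batch_size must be divisible by the device count")
    per_device = len(rows) // num_devices
    return [
        rows[i * per_device : (i + 1) * per_device]
        for i in range(num_devices)
    ]

# A batch of 8 rows sharded across 4 devices yields 2 rows per device.
shards = shard_batch(list(range(8)), num_devices=4)
```

Each shard here corresponds to the slice of the global batch that one device holds; the yielded jax.Array is the global view over all such slices.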
- Parameters:
prefetch_batches – The number of batches to fetch ahead. Defaults to 1.
batch_size – The number of rows in each batch. Must be divisible by the number of local devices. Defaults to 256.
dtypes – The JAX dtype(s) for the created array(s); if None, the dtype will be inferred from the NumPy ndarray data.
collate_fn –
[Alpha] A function to customize how data batches are collated before being passed to the model. This is useful for last-mile data formatting such as padding, masking, or packaging tensors into custom data structures. The input to collate_fn may be:
- pyarrow.Table, where you should provide a callable class that subclasses ArrowBatchCollateFn (recommended for best performance).
- Dict[str, np.ndarray], where you should provide a callable class that subclasses NumpyBatchCollateFn.
- pd.DataFrame, where you should provide a callable class that subclasses PandasBatchCollateFn.
The output must be a np.ndarray or Dict[str, np.ndarray], and is automatically sharded across JAX-addressable devices. Note: This function is called in a multi-threaded context; avoid using thread-unsafe code.
drop_last – Whether to drop the last batch if it's incomplete. Defaults to False.
local_shuffle_buffer_size – If not None, the data is randomly shuffled using a local in-memory shuffle buffer, and this value serves as the minimum number of rows that must be in the buffer before a batch is yielded. When there are no more rows to add to the buffer, the remaining rows in the buffer are drained. batch_size must also be specified when using local shuffling.
local_shuffle_seed – The seed to use for the local random shuffle.
synchronize_batches – Whether to synchronize batch shapes across all hosts. Setting this to False can improve performance if you guarantee that all hosts produce identical batch shapes and counts beforehand. Setting this to True can help catch bugs where different hosts produce different batch shapes.
paddings – The value to use for padding the last batch to batch_size. If a dictionary is provided, it must map column names to padding values. If not None, uneven batches are padded with this value, which must be castable to the user-provided dtypes.
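The paddings behavior can be sketched in plain Python. This is a conceptual illustration, not the Ray implementation: pad_last_batch and the sample column names are hypothetical, and real batches are NumPy arrays rather than lists.

```python
# Conceptual sketch (not the Ray implementation): padding the final,
# incomplete batch of a dict-of-columns dataset up to batch_size.
# A scalar padding applies to every column; a dict maps column names
# to per-column padding values.

def pad_last_batch(batch, batch_size, paddings):
    """Pad each column of a dict batch to batch_size rows."""
    padded = {}
    for name, values in batch.items():
        fill = paddings[name] if isinstance(paddings, dict) else paddings
        padded[name] = values + [fill] * (batch_size - len(values))
    return padded

# A final batch of 3 rows padded to batch_size=4 (column names are
# illustrative only).
last = {"feature": [5.1, 4.9, 4.7], "target": [0, 0, 1]}
padded = pad_last_batch(last, batch_size=4, paddings={"feature": 0.0, "target": -1})
```

Padding the last batch keeps every batch shape identical across steps, which avoids JAX recompilation triggered by a differently shaped final batch.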
- Returns:
An iterable over JAX Array batches.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.