class ray.data.Dataset(plan: ray.data._internal.plan.ExecutionPlan, epoch: int, lazy: bool = True, logical_plan: Optional[ray.data._internal.logical.interfaces.LogicalPlan] = None)

A Dataset is a distributed data collection for data loading and processing.

Datasets are implemented as a list of ObjectRef[Block], where each block holds an ordered collection of items, representing a shard of the overall data collection. A block can be either a pyarrow.Table or a Python list. The block also determines the unit of parallelism.

Datasets can be created in multiple ways: from synthetic data via the range_*() APIs, from existing in-memory data via the from_*() APIs, or from external storage systems such as local disk, S3, or HDFS via the read_*() APIs. The (potentially processed) Dataset can be saved back to external storage systems via the write_*() APIs.


>>> import ray
>>> # Create dataset from synthetic data.
>>> ds = ray.data.range(1000)
>>> # Create dataset from in-memory data.
>>> ds = ray.data.from_items(
...     [{"col1": i, "col2": i * 2} for i in range(1000)])
>>> # Create dataset from external storage system.
>>> ds = ray.data.read_parquet("s3://bucket/path") 
>>> # Save dataset back to external storage system.
>>> ds.write_csv("s3://bucket/output") 

Datasets have two kinds of operations: transformations, which take in a Dataset and output a new Dataset (e.g. map_batches()); and consumption operations, which produce values (not a Dataset) as output (e.g. iter_batches()).

Datasets support parallel processing at scale: transformations such as map_batches(), aggregations such as min()/max()/mean(), grouping via groupby(), and shuffling operations such as sort(), random_shuffle(), and repartition().


>>> import ray
>>> ds = ray.data.range(1000)
>>> # Transform in parallel with map_batches().
>>> ds.map_batches(lambda batch: [v * 2 for v in batch])
MapBatches(<lambda>)
+- Dataset(num_blocks=17, num_rows=1000, schema=<class 'int'>)
>>> # Compute max.
>>> ds.max()
999
>>> # Group the data.
>>> ds.groupby(lambda x: x % 3).count()
Aggregate
+- Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)
>>> # Shuffle this dataset randomly.
>>> ds.random_shuffle()
RandomShuffle
+- Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)
>>> # Sort it back in order.
>>> ds.sort()
Sort
+- Dataset(num_blocks=..., num_rows=1000, schema=<class 'int'>)

Since Datasets are just lists of Ray object refs, they can be passed between Ray tasks and actors without incurring a copy. Datasets support conversion to/from several more featureful dataframe libraries (e.g., Spark, Dask, Modin, MARS), and are also compatible with distributed TensorFlow / PyTorch.

PublicAPI: This API is stable across Ray releases.

__init__(plan: ray.data._internal.plan.ExecutionPlan, epoch: int, lazy: bool = True, logical_plan: Optional[ray.data._internal.logical.interfaces.LogicalPlan] = None)

Construct a Dataset (internal API).

The constructor is not part of the Dataset API. Use the ray.data.* read methods to construct a dataset.


__init__(plan, epoch[, lazy, logical_plan])

Construct a Dataset (internal API).

add_column(col, fn, *[, compute])

Add the given column to the dataset.


aggregate(*aggs)

Aggregate the entire dataset as one group.



count()

Count the number of records in the dataset.


dataset_format()

The format of the dataset's underlying data blocks.


default_batch_format()

Return this dataset's default batch format.


deserialize_lineage(serialized_ds)

Deserialize the provided lineage-serialized Dataset.

drop_columns(cols, *[, compute])

Drop one or more columns from the dataset.


filter(fn, *[, compute])

Filter out records that do not satisfy the given predicate.

flat_map(fn, *[, compute])

Apply the given function to each record and then flatten results.


fully_executed()

Force full evaluation of the blocks of this dataset.


get_internal_block_refs()

Get a list of references to the underlying blocks of this dataset.


groupby(key)

Group the dataset by the key function or column name.


has_serializable_lineage

Whether this dataset's lineage is able to be serialized for storage and later deserialized, possibly on a different cluster.


input_files()

Return the list of input files for the dataset.


is_fully_executed()

Returns whether this Dataset has been fully executed.

iter_batches(*[, prefetch_blocks, ...])

Return a local batched iterator over the dataset.

iter_rows(*[, prefetch_blocks])

Return a local row iterator over the dataset.

iter_tf_batches(*[, prefetch_blocks, ...])

Return a local batched iterator of TensorFlow Tensors over the dataset.

iter_torch_batches(*[, prefetch_blocks, ...])

Return a local batched iterator of Torch Tensors over the dataset.


iterator()

Return a DatasetIterator that can be used to repeatedly iterate over the dataset.


lazy()

Enable lazy evaluation.


limit(limit)

Truncate the dataset to the first limit records.

map(fn, *[, compute])

Apply the given function to each record of this dataset.

map_batches(fn, *[, batch_size, compute, ...])

Apply the given function to batches of data.

max([on, ignore_nulls])

Compute maximum over entire dataset.

mean([on, ignore_nulls])

Compute mean over entire dataset.

min([on, ignore_nulls])

Compute minimum over entire dataset.


num_blocks()

Return the number of blocks of this dataset.

random_sample(fraction, *[, seed])

Randomly samples a fraction of the elements of this dataset.

random_shuffle(*[, seed, num_blocks])

Randomly shuffle the elements of this dataset.

randomize_block_order(*[, seed])

Randomly shuffle the blocks of this dataset.

repartition(num_blocks, *[, shuffle])

Repartition the dataset into exactly this number of blocks.


repeat([times])

Convert this into a DatasetPipeline by looping over this dataset.


schema([fetch_if_missing])

Return the schema of the dataset.

select_columns(cols, *[, compute])

Select one or more columns from the dataset.


serialize_lineage()

Serialize this dataset's lineage, not the actual data or the existing data futures, to bytes that can be stored and later deserialized, possibly on a different cluster.


show([limit])

Print up to the given number of records from the dataset.


size_bytes()

Return the in-memory size of the dataset.

sort([key, descending])

Sort the dataset by the specified key column or key function.

split(n, *[, equal, locality_hints])

Split the dataset into n disjoint pieces.


split_at_indices(indices)

Split the dataset at the given indices (like np.split).


split_proportionately(proportions)

Split the dataset using proportions.


stats()

Returns a string containing execution timing information.

std([on, ddof, ignore_nulls])

Compute standard deviation over entire dataset.

sum([on, ignore_nulls])

Compute sum over entire dataset.


take([limit])

Return up to limit records from the dataset.


take_all([limit])

Return all of the records in the dataset.


to_arrow_refs()

Convert this dataset into a distributed set of Arrow tables.


to_dask([meta])

Convert this dataset into a Dask DataFrame.


to_mars()

Convert this dataset into a MARS dataframe.


to_modin()

Convert this dataset into a Modin dataframe.

to_numpy_refs(*[, column])

Convert this dataset into a distributed set of NumPy ndarrays.


to_pandas([limit])

Convert this dataset into a single Pandas DataFrame.


to_pandas_refs()

Convert this dataset into a distributed set of Pandas dataframes.

to_random_access_dataset(key[, num_workers])

Convert this Dataset into a distributed RandomAccessDataset (EXPERIMENTAL).


to_spark(spark)

Convert this dataset into a Spark dataframe.

to_tf(feature_columns, label_columns, *[, ...])

Return a TF Dataset over this dataset.

to_torch(*[, label_column, feature_columns, ...])

Return a Torch IterableDataset over this dataset.

train_test_split(test_size, *[, shuffle, seed])

Split the dataset into train and test subsets.


union(*other)

Combine this dataset with others of the same type.

window(*[, blocks_per_window, bytes_per_window])

Convert this into a DatasetPipeline by windowing over data blocks.

write_csv(path, *[, filesystem, ...])

Write the dataset to CSV files.

write_datasource(datasource, *[, ...])

Write the dataset to a custom datasource.

write_json(path, *[, filesystem, ...])

Write the dataset to JSON files.

write_mongo(uri, database, collection[, ...])

Write the dataset to a MongoDB datasource.

write_numpy(path, *[, column, filesystem, ...])

Write a tensor column of the dataset to npy files.

write_parquet(path, *[, filesystem, ...])

Write the dataset to Parquet files.

write_tfrecords(path, *[, filesystem, ...])

Write the dataset to TFRecord files.


zip(other)

Zip this dataset with the elements of another.