Key Concepts

Learn about Dataset and the capabilities it provides.

This guide provides a lightweight introduction to loading, transforming, consuming, and saving data with Ray Data.

Ray Data’s main abstraction is a Dataset, which is a distributed data collection. Datasets are designed for machine learning, and they can represent data collections that exceed a single machine’s memory.

Loading data

Create datasets from on-disk files, Python objects, and cloud storage services like S3. Ray Data can read from any filesystem supported by Arrow.

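import ray

# Read a CSV file from S3 into a Dataset.
ds = ray.data.read_csv("s3://anonymous@air-example-data/iris.csv")

# Print the first record.
ds.show(limit=1)
# {'sepal length (cm)': 5.1, 'sepal width (cm)': 3.5, 'petal length (cm)': 1.4, 'petal width (cm)': 0.2, 'target': 0}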

To learn more about creating datasets, read Loading data.

Transforming data

Apply user-defined functions (UDFs) to transform datasets. Ray executes transformations in parallel for performance.

from typing import Dict
import numpy as np

# Compute a "petal area" attribute.
def transform_batch(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    vec_a = batch["petal length (cm)"]
    vec_b = batch["petal width (cm)"]
    batch["petal area (cm^2)"] = vec_a * vec_b
    return batch

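transformed_ds = ds.map_batches(transform_batch)

# Execute the transformation and print the resulting dataset, including its schema.
print(transformed_ds.materialize())
# The printed schema includes the new column:
#       sepal length (cm): double,
#       sepal width (cm): double,
#       petal length (cm): double,
#       petal width (cm): double,
#       target: int64,
#       petal area (cm^2): double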
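A batch UDF like transform_batch is an ordinary Python function that takes and returns a dictionary of NumPy arrays, so one way to sanity-check it is to call it directly on a small hand-built batch before passing it to map_batches(). The sketch below needs only NumPy; the sample values are copied from the first three Iris rows shown in this guide, and the variable names sample and result are illustrative.

```python
from typing import Dict

import numpy as np


def transform_batch(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    # Same logic as the UDF above: element-wise product of two columns.
    batch["petal area (cm^2)"] = batch["petal length (cm)"] * batch["petal width (cm)"]
    return batch


# A hand-built batch: one NumPy array per column, one element per row.
sample = {
    "petal length (cm)": np.array([1.4, 1.4, 1.3]),
    "petal width (cm)": np.array([0.2, 0.2, 0.2]),
}

result = transform_batch(sample)
print(result["petal area (cm^2)"])  # [0.28 0.28 0.26]
```

Because the function never touches Ray, the same check can run in a plain unit test; Ray only adds the parallel execution across blocks.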

To learn more about transforming datasets, read Transforming data.

Consuming data

Pass datasets to Ray Tasks or Actors, and access records with methods like take_batch() and iter_batches().

print(transformed_ds.take_batch(batch_size=3))
# {'sepal length (cm)': array([5.1, 4.9, 4.7]),
#  'sepal width (cm)': array([3.5, 3. , 3.2]),
#  'petal length (cm)': array([1.4, 1.4, 1.3]),
#  'petal width (cm)': array([0.2, 0.2, 0.2]),
#  'target': array([0, 0, 0]),
#  'petal area (cm^2)': array([0.28, 0.28, 0.26])}
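The batch returned by take_batch() is column-oriented: each key is a column name and each value is an array holding that column's values for every row in the batch. To make that layout concrete, the following sketch converts such a batch back into a list of per-row dictionaries. The helper batch_to_rows is hypothetical and only illustrates the format; when you want rows from a Dataset directly, Ray Data's take() method returns them.

```python
from typing import Any, Dict, List

import numpy as np


def batch_to_rows(batch: Dict[str, np.ndarray]) -> List[Dict[str, Any]]:
    """Hypothetical helper: turn a column-oriented batch into row dictionaries."""
    if not batch:
        return []
    # Every column array has one entry per row, so any column gives the row count.
    num_rows = len(next(iter(batch.values())))
    return [{col: values[i] for col, values in batch.items()} for i in range(num_rows)]


# A subset of the columns from the take_batch() output above.
batch = {
    "sepal length (cm)": np.array([5.1, 4.9, 4.7]),
    "target": np.array([0, 0, 0]),
}

rows = batch_to_rows(batch)
print(len(rows))  # 3
print(rows[0]["sepal length (cm)"])  # 5.1
```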
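# Pass a dataset to a Ray task and iterate over it in batches.
@ray.remote
def consume(ds: ray.data.Dataset) -> int:
    num_batches = 0
    for batch in ds.iter_batches(batch_size=8):
        num_batches += 1
    return num_batches

ray.get(consume.remote(transformed_ds))

# Split a dataset across Ray actors, one shard per actor.
@ray.remote
class Worker:

    def train(self, data_iterator):
        for batch in data_iterator.iter_batches(batch_size=8):
            pass  # Replace with training logic.

workers = [Worker.remote() for _ in range(4)]
shards = transformed_ds.streaming_split(n=4, equal=True)
ray.get([w.train.remote(s) for w, s in zip(workers, shards)])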
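The batch counts in these examples follow from simple arithmetic. With the default drop_last=False, iter_batches() yields every full batch plus one final partial batch for any leftover rows. With equal=True, streaming_split() gives each shard the same number of rows, dropping rows if the total does not divide evenly. The sketch below models only that arithmetic. It assumes the CSV holds the standard 150-row Iris data, and it uses hypothetical helper names; it is a simplified model, not Ray's splitting logic, and Ray may choose which rows to drop differently.

```python
def batches_per_pass(num_rows: int, batch_size: int, drop_last: bool = False) -> int:
    """Hypothetical helper: number of batches one pass over num_rows yields."""
    full, remainder = divmod(num_rows, batch_size)
    # A trailing partial batch counts unless drop_last discards it.
    return full + (1 if remainder and not drop_last else 0)


def equal_shard_sizes(num_rows: int, n: int) -> list:
    """Hypothetical helper: rows per shard under an equal split (leftovers dropped)."""
    per_shard = num_rows // n
    return [per_shard] * n


NUM_ROWS = 150  # Assumption: the standard Iris dataset size.

print(batches_per_pass(NUM_ROWS, 8))  # 19 (18 full batches + 1 batch of 6 rows)
shards = equal_shard_sizes(NUM_ROWS, 4)
print(shards)  # [37, 37, 37, 37]
print(NUM_ROWS - sum(shards))  # 2 rows dropped
print([batches_per_pass(s, 8) for s in shards])  # [5, 5, 5, 5]
```

Under these assumptions, consume() would return 19, and each of the four workers would process 5 batches.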

To learn more about consuming datasets, see Iterating over Data and Saving Data.

Saving data

Call methods like write_parquet() to save dataset contents to local or remote filesystems.

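import os

# Write the dataset as Parquet files to a local directory (example path).
transformed_ds.write_parquet("/tmp/iris")

# List the files Ray Data wrote.
print(os.listdir("/tmp/iris"))
# ['..._000000.parquet', '..._000001.parquet']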

To learn more about saving dataset contents, see Saving data.