Transforming Datasets

Ray Datasets transformations take in datasets and produce new datasets. For example, map is a transformation that applies a user-defined function (UDF) to each row of the input dataset and returns a new dataset as the result. Datasets transformations are composable: further operations can be applied to the result dataset, forming a chain of transformations that expresses more complex computations. Transformations are the core means of expressing business logic in Datasets.


In general, we have two types of transformations:

  • One-to-one transformations: each input block contributes to only one output block, as in ds.map_batches(). In other systems these are often called narrow transformations.

  • All-to-all transformations: input blocks can contribute to multiple output blocks, as in ds.random_shuffle(). In other systems these are often called wide transformations. Both types are illustrated in the sketch after this list.
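
To make the distinction concrete, here is a minimal sketch (not part of the Iris example below) that applies one transformation of each type. It assumes a running local Ray cluster and uses ray.data.range(), which creates a dataset of integers; note that the exact batch format passed to map_batches() can vary across Ray versions.

import ray

ds = ray.data.range(1000).repartition(4)

# One-to-one: each of the 4 input blocks yields exactly one output block.
mapped = ds.map_batches(lambda batch: [x * 2 for x in batch])

# All-to-all: rows from every input block may land in any output block.
shuffled = ds.random_shuffle()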

Here is a table listing some common transformations supported by Ray Datasets.

Common Ray Datasets transformations.

  Transformation                  Type         Description
  ds.map_batches()                One-to-one   Apply a given function to batches of records of this dataset.
  ds.split()                      One-to-one   Split the dataset into N disjoint pieces.
  ds.repartition(shuffle=False)   One-to-one   Repartition the dataset into N blocks, without shuffling the data.
  ds.repartition(shuffle=True)    All-to-all   Repartition the dataset into N blocks, shuffling the data during repartition.
  ds.random_shuffle()             All-to-all   Randomly shuffle the elements of this dataset.
  ds.sort()                       All-to-all   Sort the dataset by a sortkey.
  ds.groupby()                    All-to-all   Group the dataset by a groupkey.

Datasets also provides the convenience transformation methods, ds.flat_map(), and ds.filter(), which are not vectorized (they are slower than ds.map_batches()) but may be useful for development.
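
As a quick sketch of these row-based APIs (using a small in-memory dataset built with ray.data.from_items(); the values here are illustrative only):

import ray

ds = ray.data.from_items([1, 2, 3, 4])

# filter() keeps the rows matching a predicate, one row at a time.
evens = ds.filter(lambda x: x % 2 == 0)
print(evens.take())  # -> [2, 4]

# flat_map() emits zero or more output rows per input row.
pairs = ds.flat_map(lambda x: [x, -x])
print(pairs.take())  # -> [1, -1, 2, -2, 3, -3, 4, -4]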

The following example uses these transformation APIs to process the Iris dataset.

import ray
import pandas

# Create a dataset from file with Iris data.
# Tip: "example://" is a convenient protocol to access the
# python/ray/data/examples/data directory.
ds ="example://iris.csv")
# Dataset(num_blocks=1, num_rows=150,
#         schema={sepal.length: float64, sepal.width: float64,
#                 petal.length: float64, petal.width: float64, variety: object})
# -> {'sepal.length': 5.1, 'sepal.width': 3.5,
#     'petal.length': 1.4, 'petal.width': 0.2, 'variety': 'Setosa'}
# -> {'sepal.length': 4.9, 'sepal.width': 3.0,
#     'petal.length': 1.4, 'petal.width': 0.2, 'variety': 'Setosa'}
# -> {'sepal.length': 4.7, 'sepal.width': 3.2,
#     'petal.length': 1.3, 'petal.width': 0.2, 'variety': 'Setosa'}

# Repartition the dataset to 5 blocks.
ds = ds.repartition(5)
# Dataset(num_blocks=5, num_rows=150,
#         schema={sepal.length: double, sepal.width: double,
#                 petal.length: double, petal.width: double, variety: string})

# Find rows with sepal.length < 5.5 and petal.length > 3.5.
def transform_batch(df: pandas.DataFrame) -> pandas.DataFrame:
    return df[(df["sepal.length"] < 5.5) & (df["petal.length"] > 3.5)]

# Map processing the dataset.
ds.map_batches(transform_batch).show()
# -> {'sepal.length': 5.2, 'sepal.width': 2.7,
#     'petal.length': 3.9, 'petal.width': 1.4, 'variety': 'Versicolor'}
# -> {'sepal.length': 5.4, 'sepal.width': 3.0,
#     'petal.length': 4.5, 'petal.width': 1.5, 'variety': 'Versicolor'}
# -> {'sepal.length': 4.9, 'sepal.width': 2.5,
#     'petal.length': 4.5, 'petal.width': 1.7, 'variety': 'Virginica'}

# Split the dataset into 2 datasets.
ds.split(2)
# -> [Dataset(num_blocks=3, num_rows=90,
#             schema={sepal.length: double, sepal.width: double,
#                     petal.length: double, petal.width: double, variety: string}),
#     Dataset(num_blocks=2, num_rows=60,
#             schema={sepal.length: double, sepal.width: double,
#                     petal.length: double, petal.width: double, variety: string})]

# Sort the dataset by sepal.length.
ds = ds.sort("sepal.length")
# -> {'sepal.length': 4.3, 'sepal.width': 3.0,
#     'petal.length': 1.1, 'petal.width': 0.1, 'variety': 'Setosa'}
# -> {'sepal.length': 4.4, 'sepal.width': 2.9,
#     'petal.length': 1.4, 'petal.width': 0.2, 'variety': 'Setosa'}
# -> {'sepal.length': 4.4, 'sepal.width': 3.0,
#     'petal.length': 1.3, 'petal.width': 0.2, 'variety': 'Setosa'}

# Shuffle the dataset.
ds = ds.random_shuffle()
# -> {'sepal.length': 6.7, 'sepal.width': 3.1,
#     'petal.length': 4.4, 'petal.width': 1.4, 'variety': 'Versicolor'}
# -> {'sepal.length': 6.7, 'sepal.width': 3.3,
#     'petal.length': 5.7, 'petal.width': 2.1, 'variety': 'Virginica'}
# -> {'sepal.length': 4.5, 'sepal.width': 2.3,
#     'petal.length': 1.3, 'petal.width': 0.3, 'variety': 'Setosa'}

# Group by the variety.
ds.groupby("variety").count().show()
# -> {'variety': 'Setosa', 'count()': 50}
# -> {'variety': 'Versicolor', 'count()': 50}
# -> {'variety': 'Virginica', 'count()': 50}

Compute Strategy

Datasets transformations are executed by either Ray tasks or Ray actors across a Ray cluster. By default, Ray tasks are used (with compute="tasks"). For transformations that require expensive setup, it's preferable to use Ray actors, which are stateful and allow the setup to be reused for efficiency. You can specify compute=ActorPoolStrategy(min, max), and Ray will use an autoscaling actor pool of min to max actors to execute your transforms. For a fixed-size actor pool, just specify ActorPoolStrategy(n, n).

The following example uses both the Ray tasks and Ray actors compute strategies for batch inference:

import ray
import pandas
import numpy
from ray.data import ActorPoolStrategy

# Dummy model to predict Iris variety.
def predict_iris(df: pandas.DataFrame) -> pandas.DataFrame:
    conditions = [
        (df["sepal.length"] < 5.0),
        (df["sepal.length"] >= 5.0) & (df["sepal.length"] < 6.0),
        (df["sepal.length"] >= 6.0)
    values = ["Setosa", "Versicolor", "Virginica"]
    return pandas.DataFrame({"predicted_variety":, values)})

class IrisInferModel:
    def __init__(self):
        self._model = predict_iris 

    def __call__(self, batch: pandas.DataFrame) -> pandas.DataFrame:
        return self._model(batch)

ds ="example://iris.csv").repartition(10)

# Batch inference processing with Ray tasks (the default compute strategy).
predicted = ds.map_batches(predict_iris)

# Batch inference processing with Ray actors. Autoscale the actors between 3 and 10.
predicted = ds.map_batches(
    IrisInferModel, compute=ActorPoolStrategy(3, 10), batch_size=256)
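
For a fixed-size actor pool, pass the same value for min and max; for example, with five actors (an arbitrary size chosen here for illustration):

# Batch inference processing with a fixed-size pool of 5 actors.
predicted = ds.map_batches(
    IrisInferModel, compute=ActorPoolStrategy(5, 5), batch_size=256)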