Batch Training with Ray Core


We strongly recommend using Ray Datasets and AIR Trainers to develop batch training. They let you build it faster and more easily, and provide built-in benefits such as an auto-scaling actor pool. If you think your use case cannot be supported by Ray Datasets or AIR, we’d love to get your feedback, e.g. through a Ray GitHub issue.

Batch training and tuning are common tasks in simple machine learning use-cases such as time series forecasting. They require fitting of simple models on multiple data batches corresponding to locations, products, etc. This notebook showcases how to conduct batch training on the NYC Taxi Dataset using only Ray Core and stateless Ray tasks.

Batch training in the context of this notebook is understood as creating the same model(s) for different and separate datasets or subsets of a dataset. This task is naively parallelizable and can be easily scaled with Ray.

Batch training diagram


Our task is to create separate time series models for each pickup location. We can use the pickup_location_id column in the dataset to group the dataset into data batches. We will then fit models for each batch and choose the best one.
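As a rough illustration of the idea, before any Ray is involved, the sketch below groups a few made-up rows by pickup location and fits one independent least-squares line per group. The toy rows and the `fit_line` helper are assumptions introduced for this example only; the notebook itself uses scikit-learn models and Ray tasks.

```python
from collections import defaultdict
from statistics import mean

# Hypothetical toy rows: (pickup_location_id, dropoff_location_id, trip_duration_seconds).
rows = [
    (1, 10, 300.0), (1, 20, 500.0), (1, 30, 700.0),
    (2, 10, 900.0), (2, 20, 800.0), (2, 30, 700.0),
]


def fit_line(xs, ys):
    """Ordinary least squares for y = slope * x + intercept."""
    x_bar, y_bar = mean(xs), mean(ys)
    sxx = sum((x - x_bar) ** 2 for x in xs)
    sxy = sum((x - x_bar) * (y - y_bar) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, y_bar - slope * x_bar


# 1. Group rows into batches keyed by pickup location.
batches = defaultdict(list)
for pickup_id, dropoff_id, duration in rows:
    batches[pickup_id].append((dropoff_id, duration))

# 2. Fit one independent model per batch.
models = {
    pickup_id: fit_line([x for x, _ in batch], [y for _, y in batch])
    for pickup_id, batch in batches.items()
}
print(models)
```

Because no batch depends on any other, step 2 is the part that can be handed out to parallel workers.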

Let’s start by importing Ray and initializing a local Ray cluster.

from typing import Callable, Optional, List, Union, Tuple, Iterable
import time
import numpy as np
import pandas as pd

from sklearn.base import BaseEstimator
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error

import pyarrow as pa
from pyarrow import fs
from pyarrow import dataset as ds
from pyarrow import parquet as pq
import pyarrow.compute as pc
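import ray

# Start a local Ray cluster (or reuse one that is already running).
ray.init(ignore_reinit_error=True)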


For benchmarking purposes, we can print the times of various operations. In order to reduce clutter in the output, this is set to False by default.


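PRINT_TIMES = False


def print_time(msg: str):
    # Only print timing messages when benchmarking is enabled.
    if PRINT_TIMES:
        print(msg)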
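The functions later in this notebook repeat the same pattern: record a start time, do the work, then report the difference. One possible way to package that pattern is a small context manager. This is a sketch only: `timed` is a hypothetical helper, not part of the notebook, and `PRINT_TIMES` is an assumed name for the flag described above.

```python
import time
from contextlib import contextmanager

PRINT_TIMES = False  # assumed flag name; off by default to keep output short


def print_time(msg: str) -> None:
    if PRINT_TIMES:
        print(msg)


@contextmanager
def timed(label: str):
    """Hypothetical helper: report how long the enclosed block took."""
    start = time.perf_counter()
    try:
        yield
    finally:
        print_time(f"{label}: {time.perf_counter() - start:.3f}s")


with timed("Summing a range"):
    total = sum(range(1_000_000))
```

With `PRINT_TIMES` left at `False` the block runs silently; setting it to `True` prints one line per timed section.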

To speed things up, we’ll only use a small subset of the full dataset: its last two monthly files. You can choose to use the full dataset for 2018-2019 by setting the SMOKE_TEST variable to False.

SMOKE_TEST = True

As we will be using the NYC Taxi dataset, we define a simple batch transformation function to set correct data types, calculate the trip duration and fill missing values.

# A Pandas DataFrame UDF for transforming the underlying blocks of a Dataset in parallel.
def transform_batch(df: pd.DataFrame) -> pd.DataFrame:
    df["pickup_at"] = pd.to_datetime(df["pickup_at"], format="%Y-%m-%d %H:%M:%S")
    df["dropoff_at"] = pd.to_datetime(df["dropoff_at"], format="%Y-%m-%d %H:%M:%S")
    df["trip_duration"] = (df["dropoff_at"] - df["pickup_at"]).dt.seconds
    df["pickup_location_id"] = df["pickup_location_id"].fillna(-1)
    df["dropoff_location_id"] = df["dropoff_location_id"].fillna(-1)
    return df
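One detail of `transform_batch` is worth keeping in mind: pandas documents `.dt.seconds` as the seconds component of a timedelta (at least 0 and less than one day), not the total duration. The standard-library `datetime.timedelta` has the same distinction, so the sketch below uses it to show the effect without needing pandas. The timestamps are made up for illustration.

```python
from datetime import datetime, timedelta

pickup = datetime(2019, 6, 1, 23, 50, 0)

# A 20-minute trip: both views agree.
short_trip = datetime(2019, 6, 2, 0, 10, 0) - pickup
print(short_trip.seconds, short_trip.total_seconds())  # 1200 1200.0

# A trip recorded as lasting 1 day and 20 minutes: `.seconds` drops the day part.
long_trip = datetime(2019, 6, 3, 0, 10, 0) - pickup
print(long_trip.seconds, long_trip.total_seconds())  # 1200 87600.0

# A drop-off recorded 10 seconds before the pickup: `.seconds` wraps around.
bad_record = timedelta(seconds=-10)
print(bad_record.seconds, bad_record.total_seconds())  # 86390 -10.0
```

For ordinary taxi trips the two agree, so the tutorial's choice is fine here. If multi-day or negative durations matter in your data, `.dt.total_seconds()` is the pandas counterpart of `total_seconds()`.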

We will be fitting scikit-learn models on data batches. We define a Ray task fit_and_score_sklearn that fits the model and calculates mean absolute error on the validation set. We will be treating this as a simple regression problem where we want to predict the trip duration from the drop-off location.

# Ray task to fit and score a scikit-learn model.
@ray.remote
def fit_and_score_sklearn(
    train: pd.DataFrame, test: pd.DataFrame, model: BaseEstimator
) -> Tuple[BaseEstimator, float]:
    train_X = train[["dropoff_location_id"]]
    train_y = train["trip_duration"]
    test_X = test[["dropoff_location_id"]]
    test_y = test["trip_duration"]

    # Start training.
    model = model.fit(train_X, train_y)
    pred_y = model.predict(test_X)
    error = mean_absolute_error(test_y, pred_y)
    return model, error

The train_and_evaluate function contains the logic for train-test splitting and fitting of multiple models in parallel on each data batch, for purposes of comparison. Thanks to this, we can evaluate several models and choose the best one for each data batch.

def train_and_evaluate(
    df: pd.DataFrame, models: List[BaseEstimator], i: int = 0
) -> List[Tuple[BaseEstimator, float]]:
    # We need at least 4 rows to create a train / test split.
    if len(df) < 4:
        print_time(f"Dataframe for LocID: {i} is empty or smaller than 4")
        return None

    start = time.time()

    # Train / test split.
    train, test = train_test_split(df)

    # We put the train & test dataframes into Ray object store
    # so that they can be reused by all models fitted here.
    train_ref = ray.put(train)
    test_ref = ray.put(test)

    # Launch a fit and score task for each model.
    results = ray.get(
        [fit_and_score_sklearn.remote(train_ref, test_ref, model) for model in models]
    )
    results.sort(key=lambda x: x[1])  # sort by error

    time_taken = time.time() - start
    print_time(f"Training time for LocID: {i}: {time_taken}")
    return results
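The minimum-size guard in train_and_evaluate can be reasoned about with a little arithmetic. The sketch below assumes scikit-learn's documented default of a 25% test share when neither size is passed to train_test_split, and further assumes that the test count is rounded up with the remainder going to training. `split_sizes` is a hypothetical helper written for this illustration, not a scikit-learn function.

```python
import math


def split_sizes(n_rows: int, test_fraction: float = 0.25):
    """Approximate (train, test) row counts for a fractional split.

    Assumes the test share is rounded up and the remainder goes to train.
    """
    n_test = math.ceil(test_fraction * n_rows)
    return n_rows - n_test, n_test


for n in (1, 2, 4, 10):
    print(n, split_sizes(n))
```

Under these assumptions a 4-row batch leaves 3 rows for training and 1 for scoring, so the threshold of 4 is a conservative floor. An error computed on a single test row is still very noisy, and a higher threshold may suit real workloads better.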

The read_data function reads a Parquet file and uses a push-down predicate to extract the data batch we want to fit a model on using the provided index to group the rows. By having each task read the data and extract batches separately, we ensure that memory utilization is minimal - as opposed to requiring each task to load the entire partition into memory first.

We are using PyArrow to read the file, as it supports push-down predicates to be applied during file reading. This lets us avoid having to load an entire file into memory, which could cause an OOM error with a large dataset. After the dataset is loaded, we convert it to pandas so that it can be used for training with scikit-learn.

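def read_data(file: str, i: int) -> pd.DataFrame:
    return pq.read_table(
        file,
        filters=[("pickup_location_id", "=", i)],
        columns=["pickup_at", "dropoff_at", "pickup_location_id", "dropoff_location_id"],
    ).to_pandas()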

The Ray task named task contains all the logic necessary to load a data batch, transform it, and fit and evaluate models on it.

You may notice that we have previously defined fit_and_score_sklearn as a Ray task as well and set it to be executed from inside task. This allows us to dynamically create a tree of tasks, ensuring that the cluster resources are fully utilized. Without this pattern, each task would need to be assigned several CPU cores for the model fitting, meaning that if certain models finish faster, those CPU cores would still stay occupied. Thankfully, Ray is able to deal with nested parallelism in tasks without the need for any extra logic, allowing us to simplify the code.

@ray.remote
def task(
    data: Union[str, pd.DataFrame],
    file_name: str,
    i: int,
    models: List[BaseEstimator],
    load_data_func: Optional[Callable] = None,
) -> List[Tuple[BaseEstimator, float]]:
    if load_data_func:
        start_time = time.time()
        data = load_data_func(data, i)
        data_loading_time = time.time() - start_time
        print_time(f"Data loading time for LocID: {i}: {data_loading_time}")

    # Cast PyArrow scalar to Python if needed.
    try:
        i = i.as_py()
    except Exception:
        pass

    # Perform transformation
    start_time = time.time()
    data = transform_batch(data)
    transform_time = time.time() - start_time
    print_time(f"Data transform time for LocID: {i}: {transform_time}")

    return file_name, i, train_and_evaluate(data, models, i)

The task_generator generator dispatches tasks and yields references to them. Each task will be run in parallel on a separate batch as determined by the pickup_location_id column in the provided file. Ray will handle scheduling automatically.

def task_generator(files: List[str], models: List[BaseEstimator]) -> ray.ObjectRef:
    for file in files:
        try:
            locdf = pq.read_table(file, columns=["pickup_location_id"])
        except Exception:
            # Skip files that cannot be read.
            continue
        loc_list = locdf["pickup_location_id"].unique()

        for i in loc_list:
            yield task.remote(file, file, i, models, read_data)

Finally, the run driver function generates tasks for each Parquet file it receives (with each file corresponding to one month). We define the function to take in a list of models, so that we can evaluate them all and choose the best one for each batch. The function blocks when it reaches ray.get() and waits for tasks to return their results.

def run(files: List[str], task_generator: Callable, models: List[BaseEstimator]):
    print("Starting run...")
    start = time.time()

    task_refs = list(task_generator(files, models))
    results = ray.get(task_refs)

    taken = time.time() - start
    count = len(results)
    results_not_none = [x for x in results if x is not None]
    count_not_none = len(results_not_none)

    # Sleep a moment for nicer output
    time.sleep(0.5)
    print("", flush=True)
    print(f"Total number of models (all tasks): {count_not_none} ({count})")
    print(f"TOTAL TIME TAKEN: {taken:.2f} seconds")
    return results

We obtain the partitions of the dataset from an S3 bucket so that we can pass them to run.

# Obtain the dataset. Each month is a separate file.
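dataset = ds.dataset(
    # "anonymous@" requests unauthenticated access to the public bucket.
    "s3://anonymous@air-example-data/ursa-labs-taxi-data/by_year/",
    partitioning=["year", "month"],
)
starting_idx = -2 if SMOKE_TEST else 0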
files = [f"s3://{file}" for file in dataset.files][starting_idx:]
print(f"Obtained {len(files)} files!")
Obtained 2 files!

We can now run our script. The output is a list of tuples in the following format: (file name, partition id, list of models and their MAE scores). For brevity, we will print out the first 10 tuples.

from sklearn.linear_model import LinearRegression

results = run(files, task_generator, models=[LinearRegression()])
print(results[:10])
Starting run...

Total number of models (all tasks): 522 (522)
TOTAL TIME TAKEN: 21.19 seconds
[('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 145, [(LinearRegression(), 851.3091289442241)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 161, [(LinearRegression(), 763.587971487081)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 163, [(LinearRegression(), 742.3122613593824)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 193, [(LinearRegression(), 899.5440269877245)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 260, [(LinearRegression(), 741.1232150739363)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 56, [(LinearRegression(), 860.3183412585847)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 79, [(LinearRegression(), 728.9143263092787)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 90, [(LinearRegression(), 649.3464235594931)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 162, [(LinearRegression(), 723.9509168205005)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 50, [(LinearRegression(), 671.7616933026658)])]

We can also provide multiple scikit-learn models to our run function and the best one will be chosen for each batch. A common use-case here would be to define several models of the same type with different hyperparameters.

from sklearn.tree import DecisionTreeRegressor

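results = run(
    files,
    task_generator,
    models=[LinearRegression(), DecisionTreeRegressor(), DecisionTreeRegressor(splitter="random")],
)
print(results[:10])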
Starting run...

Total number of models (all tasks): 522 (522)
TOTAL TIME TAKEN: 18.51 seconds
[('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 145, [(DecisionTreeRegressor(), 619.9080145718), (DecisionTreeRegressor(splitter='random'), 620.9351656841662), (LinearRegression(), 894.9093613150645)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 161, [(DecisionTreeRegressor(), 585.1303154215874), (DecisionTreeRegressor(splitter='random'), 585.1334584269538), (LinearRegression(), 746.3996639952683)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 163, [(DecisionTreeRegressor(), 590.8829340940193), (DecisionTreeRegressor(splitter='random'), 591.0654550332006), (LinearRegression(), 758.3602607590221)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 193, [(DecisionTreeRegressor(), 739.1724549207835), (DecisionTreeRegressor(splitter='random'), 739.5002953972328), (LinearRegression(), 906.5242773055481)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 260, [(DecisionTreeRegressor(), 593.1233945510796), (DecisionTreeRegressor(splitter='random'), 593.1233945510796), (LinearRegression(), 709.558440515228)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 56, [(DecisionTreeRegressor(splitter='random'), 1302.8135501217532), (DecisionTreeRegressor(), 1308.5687584550865), (LinearRegression(), 1400.7256036944598)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 79, [(DecisionTreeRegressor(), 573.3767209185635), (DecisionTreeRegressor(splitter='random'), 573.3853566498115), (LinearRegression(), 711.9296171689957)]), 
('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 90, [(DecisionTreeRegressor(splitter='random'), 483.88298667156215), (DecisionTreeRegressor(), 484.1489956504658), (LinearRegression(), 638.507610810801)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 162, [(DecisionTreeRegressor(splitter='random'), 546.0548872824131), (DecisionTreeRegressor(), 546.0673940127546), (LinearRegression(), 687.9393358281769)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 50, [(DecisionTreeRegressor(splitter='random'), 529.9439816747014), (DecisionTreeRegressor(), 530.0687930367063), (LinearRegression(), 681.5231361774709)])]
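
Each tuple in the output above lists its models sorted by error, so the best model for a batch is the first entry of that list. A minimal sketch of collecting the winner per batch follows. The `sample_results` list is a hypothetical stand-in for `results`: strings replace the fitted estimators, `file_a.parquet` is a placeholder file name, and the MAE values are rounded from the output above. `best_per_batch` is a helper introduced for this illustration.

```python
from typing import Any, List, Optional, Tuple

# Hypothetical stand-in for `results` (strings instead of fitted estimators).
sample_results: List[Tuple[str, int, Optional[List[Tuple[Any, float]]]]] = [
    ("file_a.parquet", 145, [("DecisionTreeRegressor()", 619.91), ("LinearRegression()", 894.91)]),
    ("file_a.parquet", 56, [("DecisionTreeRegressor(splitter='random')", 1302.81), ("DecisionTreeRegressor()", 1308.57)]),
    ("file_a.parquet", 999, None),  # batch skipped because it had fewer than 4 rows
]


def best_per_batch(results):
    """Map (file name, location id) to the lowest-MAE (model, error) pair."""
    best = {}
    for file_name, loc_id, scored in results:
        if not scored:  # None (skipped batch) or an empty list
            continue
        best[(file_name, loc_id)] = min(scored, key=lambda pair: pair[1])
    return best


for key, (model, mae) in best_per_batch(sample_results).items():
    print(key, model, mae)
```

Using `min` rather than indexing the first element keeps the helper correct even if the per-batch list is not sorted. The check on `scored` matters because train_and_evaluate returns None for batches that are too small.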

Loading data once into Ray object store

In order to ensure that the data can always fit in memory, each task reads the files independently and extracts the desired data batch. This, however, negatively impacts the runtime. If we have sufficient memory in our Ray cluster, we can instead load each partition once, extract the batches, and save them in the Ray object store, reducing time required dramatically at a cost of higher memory usage.

Notice that we do not call ray.get() on the data references created inside read_into_object_store. Instead, we pass each reference itself as the argument to the task.remote dispatch, allowing the data to stay in the object store until it is actually needed. This avoids a situation where all the data would be loaded into the memory of the process calling ray.get().

You can use the Ray Dashboard to compare the memory usage between the previous approach and this one.

from ray.util.placement_group import placement_group, remove_placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@ray.remote
def read_into_object_store(file: str) -> List[ray.ObjectRef]:
    print(f"Loading {file}")
    # Read the entire file into memory.
    try:
        locdf = pq.read_table(
            file,
            columns=["pickup_at", "dropoff_at", "pickup_location_id", "dropoff_location_id"],
        )
    except Exception:
        return []

    loc_list = locdf["pickup_location_id"].unique()

    group_refs = []
    for i in loc_list:
        # Put each data batch as a separate dataframe into Ray object store.
        group_refs.append(
            (i, ray.put(locdf.filter(pc.field("pickup_location_id") == i).to_pandas()))
        )

    return group_refs

def task_generator_with_object_store(
    files: List[str], models: List[BaseEstimator]
) -> ray.ObjectRef:
    # Use a placement group with a SPREAD strategy to load each
    # file on a separate node as an OOM safeguard.
    # This is not foolproof though! We can also specify a resource
    # requirement for memory, if we know what is the maximum
    # memory requirement for a single file.
    pg = placement_group([{"CPU": 1}] * len(files), strategy="SPREAD")

    read_into_object_store_pg = read_into_object_store.options(
        scheduling_strategy=PlacementGroupSchedulingStrategy(placement_group=pg)
    )
    load_tasks = [read_into_object_store_pg.remote(file) for file in files]
    group_refs = {}
    for i, refs in enumerate(ray.get(load_tasks)):
        group_refs[files[i]] = refs
    # Loading is done, so release the CPUs reserved by the placement group.
    remove_placement_group(pg)

    for file, refs in group_refs.items():
        for i, ref in refs:
            yield task.remote(ref, file, i, models)


results = run(files, task_generator_with_object_store, models=[LinearRegression()])
print(results[:10])
Starting run...
(read_into_object_store pid=3170, ip= Loading s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet
(read_into_object_store pid=3077, ip= Loading s3://air-example-data/ursa-labs-taxi-data/by_year/2019/06/data.parquet/ab5b9d2b8cc94be19346e260b543ec35_000000.parquet
(scheduler +59s) Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.
(scheduler +59s) Warning: The following resource request cannot be scheduled right now: {'CPU': 1.0}. This is likely due to all cluster resources being claimed by actors. Consider creating fewer actors or adding more nodes to this Ray cluster.

Total number of models (all tasks): 522 (522)
TOTAL TIME TAKEN: 15.78 seconds
[('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 145, [(LinearRegression(), 852.9429209323498)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 161, [(LinearRegression(), 756.4310964446844)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 163, [(LinearRegression(), 759.0581689980796)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 193, [(LinearRegression(), 811.8705198797737)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 260, [(LinearRegression(), 669.7161874214457)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 56, [(LinearRegression(), 1388.4215954337024)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 79, [(LinearRegression(), 715.368673359218)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 90, [(LinearRegression(), 644.6049120675258)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 162, [(LinearRegression(), 695.9343158694874)]), ('s3://air-example-data/ursa-labs-taxi-data/by_year/2019/05/data.parquet/359c21b3e28f40328e68cf66f7ba40e2_000000.parquet', 50, [(LinearRegression(), 717.3705726378896)])]

We can see that this approach finished training faster (15.78 seconds versus 21.19 seconds for the same LinearRegression run above), but it would not have been possible if the dataset were too large to fit into our cluster memory. Therefore, this pattern is only recommended if the data you are working with is small. Otherwise, it is recommended to load the data inside the tasks right before it is used.