Batch Training with Ray Datasets

Introduction

Batch training and tuning are common tasks in simple machine learning use cases such as time series forecasting. They require fitting simple models on many data batches, each corresponding to a location, a product, and so on.

In the context of this notebook, batch training is understood as creating the same model(s) for different and separate datasets or subsets of a dataset. This notebook showcases how to conduct batch training using Ray Dataset.

Batch training diagram

For the data, we will use the NYC Taxi dataset. This popular tabular dataset contains historical taxi pickups by timestamp and location in NYC. To demonstrate batch training, we will simplify the data to a regression problem to predict trip_duration and use scikit-learn.

To demonstrate how batch training can be parallelized, we will train a separate model for each dropoff location. This means we can use the dropoff_location_id column in the dataset to group the dataset into data batches. Then we will fit a separate model for each batch and evaluate it.
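At a high level, the parallel batch-training pattern we build up to looks roughly like the sketch below; fit_and_evaluate_group is a placeholder name for the per-group training logic we define later in this notebook.

# Rough sketch of the pattern built up in this notebook (placeholder names):
#
#   ds = ray.data.read_parquet(files)                     # load the data in parallel
#   grouped = ds.groupby("dropoff_location_id")           # one group per drop-off location
#   results = grouped.map_groups(fit_and_evaluate_group)  # fit and score a model per group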

Walkthrough

Let us start by importing a few required libraries, including open-source Ray itself!

import os
print(f'Number of CPUs in this system: {os.cpu_count()}')
from typing import Tuple, List, Union, Optional, Callable
import time
import pandas as pd
import numpy as np
import pyarrow.dataset as pds
from pyarrow import fs
from pyarrow import parquet as pq
from ray.data import Dataset
Number of CPUs in this system: 8
import ray

if ray.is_initialized():
    ray.shutdown()
ray.init()
2022-10-31 13:42:00,434	INFO worker.py:1509 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8266 

Ray

Python version: 3.8.13
Ray version: 2.0.0
Dashboard: http://127.0.0.1:8266
# For benchmarking purposes, we can print the times of various operations.
# In order to reduce clutter in the output, this is set to False by default.
PRINT_TIMES = False

def print_time(msg: str):
    if PRINT_TIMES:
        print(msg)
# To speed things up, we'll only use a small subset of the full dataset: the last file, covering June 2019.
# You can choose to use the full dataset for 2018-2019 by setting the SMOKE_TEST variable to False.

SMOKE_TEST = True

Creating a Ray Dataset

Ray Datasets are the standard way to load and exchange data in Ray libraries and applications. We will use the Ray Dataset APIs to read the data and quickly inspect it.

First, we will define some global variables we will use throughout the notebook, such as the list of S3 links to the files making up the dataset and the possible location IDs.

# Define some global variables.
target = "trip_duration"
s3_partitions = pds.dataset(
    "s3://anonymous@air-example-data/ursa-labs-taxi-data/by_year/",
    partitioning=["year", "month"],
)
s3_files = [f"s3://anonymous@{file}" for file in s3_partitions.files]

# Obtain all location IDs
location_ids = (
    pq.read_table(s3_files[0], columns=["dropoff_location_id"])["dropoff_location_id"]
    .unique()
    .to_pylist()
)

# Use smoke testing or not.
starting_idx = -1 if SMOKE_TEST else 0
sample_locations = [145, 166, 152] if SMOKE_TEST else location_ids

# Display what data will be used.
s3_files = s3_files[starting_idx:]
print(f"NYC Taxi using {len(s3_files)} file(s)!")
print(f"s3_files: {s3_files}")
print(f"Locations: {sample_locations}")
NYC Taxi using 1 file(s)!
s3_files: ['s3://anonymous@air-example-data/ursa-labs-taxi-data/by_year/2019/06/data.parquet/ab5b9d2b8cc94be19346e260b543ec35_000000.parquet']
Locations: [145, 166, 152]

The easiest way to create a Ray Dataset is to use ray.data.read_parquet to read Parquet files in parallel onto the Ray cluster.

Uncomment the cell below if you want to try it out.

# # This cell commented out because it can take a long time!
# # In the next section "Filtering read" we make it faster.

# # Read everything in the files list into a ray dataset.
# ds = ray.data.read_parquet(s3_files)
# ds

Filtering a Ray Dataset on Read

Normally there is some last-mile data processing required before training. Let’s assume we know the data processing steps are:

  • Drop negative trip distances, 0 fares, 0 passengers.

  • Drop 2 unknown zones: ['264', '265'].

  • Calculate trip duration and add it as a new column.

  • Drop trip durations shorter than 1 minute or longer than 24 hours.

Instead of blindly reading all the data, it would be better to read only the data we need. This is similar in concept to a SQL query that SELECTs only the rows and columns you need, rather than SELECT *.

Tip

Best practice is to filter as much as you can directly in the Ray Dataset read_parquet().

Note that Ray Datasets’ Parquet reader supports projection (column selection) and row filter pushdown, where we can push the above column selection and the row-based filter to the Parquet read. If we specify column selection at Parquet read time, the unselected columns won’t even be read from disk. This can save a lot of memory, especially with big datasets, and allow us to avoid OOM issues.

The row-based filter is specified via Arrow’s dataset field expressions.

def pushdown_read_data(files_list: list, sample_ids: list) -> Dataset:
    start = time.time()

    filter_expr = (
        (pds.field("passenger_count") > 0)
        & (pds.field("trip_distance") > 0)
        & (pds.field("fare_amount") > 0)
        & (~pds.field("pickup_location_id").isin([264, 265]))
        & (~pds.field("dropoff_location_id").isin([264, 265]))
        & (pds.field("dropoff_location_id").isin(sample_ids))
    )

    dataset = ray.data.read_parquet(
        files_list,
        columns=[
            "pickup_at",
            "dropoff_at",
            "pickup_location_id",
            "dropoff_location_id",
            "passenger_count",
            "trip_distance",
            "fare_amount",
        ],
        filter=filter_expr,
    )

    data_loading_time = time.time() - start
    print(f"Data loading time: {data_loading_time:.2f} seconds")
    
    return dataset
# Test the pushdown_read_data function
ds_raw = pushdown_read_data(s3_files, sample_locations)
2022-10-31 13:42:15,071	WARNING read_api.py:291 -- ⚠️  The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.
Data loading time: 1.01 seconds

Inspecting a Ray Dataset

Let’s get some basic statistics about our newly created Ray Dataset.

As our Ray Dataset is backed by Parquet, we can obtain the number of rows from the metadata without triggering a full data read.

print(f"Number of rows: {ds_raw.count()}")
Number of rows: 6941024

Similarly, we can obtain the Dataset size (in bytes) from the metadata.

print(f"Size bytes (from parquet metadata): {ds_raw.size_bytes()}")
Size bytes (from parquet metadata): 925892280

Let’s fetch and inspect the schema of the underlying Parquet files.

print("\nSchema data types:")
data_types = list(zip(ds_raw.schema().names, ds_raw.schema().types))
for s in data_types:
    print(f"{s[0]}: {s[1]}")
Schema data types:
pickup_at: timestamp[us]
dropoff_at: timestamp[us]
pickup_location_id: int32
dropoff_location_id: int32
passenger_count: int8
trip_distance: float
fare_amount: float

Transforming a Ray Dataset in parallel using custom functions

Ray Datasets lets you specify custom data transform functions. These user-defined functions (UDFs) can be called using Dataset.map_batches(my_UDF). The transformation is applied to each data batch in parallel.

Tip

You may need to call Dataset.repartition(n) first to split the Dataset into more blocks internally. By default, each block corresponds to one file. The upper bound of parallelism is the number of blocks.

You can specify the data format you are using in the batch_format parameter. The dataset will be divided into batches and those batches converted into the specified format. Available data formats you can specify in the batch_format parameter include "pandas", "pyarrow", and "numpy". Tabular data will be passed into your UDF by default as a pandas DataFrame. Tensor data will be passed into your UDF as a numpy array.
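For illustration, a UDF written against batch_format="pyarrow" would receive a pyarrow.Table instead of a pandas DataFrame. Below is a minimal sketch of that variant (illustrative only, not used later in this notebook):

import pyarrow as pa
import pyarrow.compute as pc

def transform_batch_arrow(batch: pa.Table) -> pa.Table:
    # Illustrative only: keep rows with a positive trip distance.
    return batch.filter(pc.greater(batch["trip_distance"], 0.0))

# ds.map_batches(transform_batch_arrow, batch_format="pyarrow")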

Here, we will use batch_format="pandas" explicitly for clarity.

# A pandas DataFrame UDF for transforming the Dataset in parallel.
def transform_batch(the_df: pd.DataFrame) -> pd.DataFrame:
    df = the_df.copy()
    
    df["trip_duration"] = (df["dropoff_at"] - df["pickup_at"]).dt.seconds
    df = df[df["trip_duration"] > 60]
    df = df[df["trip_duration"] < 24 * 60 * 60] 
    df.drop(
        ["dropoff_at", "pickup_at", "pickup_location_id", "fare_amount"],
        axis=1,
        inplace=True,
    )
    df["dropoff_location_id"] = df["dropoff_location_id"].fillna(-1)
    return df
%%time 

# Test the transform UDF.
print(f"Number of rows before transformation: {ds_raw.count()}")

# Repartition the dataset to allow for higher parallelism.
ds = ds_raw.repartition(14) 

# .map_batches applies a UDF to each partition of the data in parallel.
ds = ds.map_batches(transform_batch, batch_format="pandas")

# Verify row count.
print(f"Number of rows after transformation: {ds.count()}")
Number of rows before transformation: 6941024
Read: 100%|██████████| 1/1 [02:07<00:00, 127.41s/it]
Repartition: 100%|██████████| 14/14 [00:00<00:00, 17.27it/s]
Map_Batches: 100%|██████████| 14/14 [00:00<00:00, 61.09it/s]
Number of rows after transformation: 82704
CPU times: user 933 ms, sys: 411 ms, total: 1.34 s
Wall time: 2min 8s

Batch training with Ray Datasets

Now that we have learned more about our data and written a pandas UDF to transform our data, we are ready to train a model on batches of this data in parallel.

  1. We will use the dropoff_location_id column in the dataset to group the dataset into data batches.

  2. Then we will fit a separate model for each batch to predict trip_duration.

# import standard sklearn libraries
import sklearn
from sklearn.base import BaseEstimator
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_absolute_error
print(f"sklearn: {sklearn.__version__}")

# set global random seed for sklearn models
np.random.seed(415)
sklearn: 1.1.2

Define training functions

We want to fit a linear regression model to the trip duration for each drop-off location. For scoring, we will calculate mean absolute error on the validation set, and report that as model error per drop-off location.

The fit_and_score_sklearn function contains the logic necessary to fit a scikit-learn model and evaluate it using mean absolute error.

def fit_and_score_sklearn(
    train_df: pd.DataFrame, test_df: pd.DataFrame, model: BaseEstimator
) -> pd.DataFrame:
    
    # Assemble train/test pandas dfs
    train_X = train_df[["passenger_count", "trip_distance"]]
    train_y = train_df.trip_duration
    test_X = test_df[["passenger_count", "trip_distance"]]
    test_y = test_df.trip_duration

    # Start training.
    model = model.fit(train_X, train_y)
    pred_y = model.predict(test_X)
    
    # Evaluate.
    error = sklearn.metrics.mean_absolute_error(test_y, pred_y)
    if error is None:
        error = 10000.0
    
    # Assemble return as a pandas dataframe.
    return_df = pd.DataFrame({'model': [model], 'error': [error]})

    # return model, error
    return return_df
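As a quick, optional sanity check (a sketch with made-up numbers, not part of the workflow that follows), you can exercise the helper on a tiny synthetic DataFrame; it returns a one-row DataFrame holding the fitted model and its error.

# Hypothetical sanity check on a tiny synthetic DataFrame.
toy = pd.DataFrame({
    "passenger_count": [1, 2, 1, 3, 2, 1],
    "trip_distance": [1.0, 2.5, 0.8, 4.2, 3.1, 1.7],
    "trip_duration": [300, 700, 250, 1200, 900, 450],
})
print(fit_and_score_sklearn(toy.iloc[:4], toy.iloc[4:], LinearRegression()))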

The train_and_evaluate function contains the logic for train-test splitting and fitting of a model using the fit_and_score_sklearn function.

As input, this function takes a pandas DataFrame. When we call Dataset.map_batches or Dataset.groupby().map_groups(), the Dataset is divided into multiple pandas DataFrames and this function is run on each one in parallel. We return the model and its error; those results are collected back into a Ray Dataset automatically.

def train_and_evaluate(
    df: pd.DataFrame, 
    models: List[BaseEstimator], 
    location_id: int
) -> pd.DataFrame:
    
    # We need at least 4 rows to create a train / test split.
    if len(df) < 4:
        print_time(f"Data batch for LocID {location_id} is empty or smaller than 4 rows")
        return None

    start = time.time()

    # Train / test split
    # Randomly split the data into 80/20 train/test.
    train_df, test_df = train_test_split(df, test_size=0.2, shuffle=True)

    # Launch a fit and score task for each model.
    # results is a list of pandas dataframes, one per model
    results = [fit_and_score_sklearn(train_df, test_df, model) for model in models]

    # Assemble location_id, name of model, and metrics in a pandas DataFrame
    # results_df = pd.concat(results) 
    results_df = pd.concat(results, axis=0, join='inner', ignore_index=True)
    results_df.insert(0, column='location_id', value=location_id)

    training_time = time.time() - start
    print_time(f"Training time for LocID {location_id}: {training_time:.2f} seconds")

    return results_df

We can also provide multiple scikit-learn models to train_and_evaluate and then keep the best one for each batch. A common use case here would be to define several models of the same type with different hyperparameters (a hypothetical sketch of this follows the next cell).

We will use Linear Regression and Decision Tree Regression.

MODELS = [LinearRegression(), DecisionTreeRegressor(max_depth=4)]
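If you instead wanted to sweep hyperparameters of a single model type, a hypothetical candidate list could be built programmatically, for example:

# Hypothetical alternative (not used in the rest of this notebook): several
# models of the same type with different hyperparameters.
# MODELS = [DecisionTreeRegressor(max_depth=depth) for depth in (2, 4, 8)]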

Recall how we wrote the transform_batch data transform UDF? It was called with the pattern:

  • Dataset.map_batches(transform_batch, batch_format="pandas")

Similarly, we can write a custom group-by aggregation function agg_func which will run for each Ray Dataset group-by group in parallel. The usage pattern is:

  • Dataset.groupby(column).map_groups(agg_func, batch_format="pandas").

In the cell below, we define our custom agg_func.

# A Pandas DataFrame aggregation function for processing grouped batches of Ray Dataset data.
def agg_func(df: pd.DataFrame) -> pd.DataFrame:
    location_id = df["dropoff_location_id"].iloc[0]

    # Handle errors in data groups
    try:
        # Train and evaluate the models on this group of data.
        results_df = train_and_evaluate(df, MODELS, location_id)
        assert results_df is not None
    except Exception:
        # assemble a null entry
        print(f"Failed on LocID {location_id}!")
        results_df = pd.DataFrame(
            [[location_id, None, 10000.0]],
            columns=["location_id", "model", "error"],
        )

    return results_df

Run batch training using map_groups

Finally, the main "driver code" runs batch training on the Ray Dataset ds we prepared above (each Parquet file corresponds to one month of NYC taxi data; uncomment the read line below to re-create ds from scratch). We use the Ray Dataset group-by to map each group into a batch of data and run agg_func on each group in parallel by calling ds.groupby("dropoff_location_id").map_groups(agg_func, batch_format="pandas").

# Driver code to run this.

start = time.time()

# Read data into Ray Dataset
# ds = pushdown_read_data(s3_files, sample_locations).repartition(14)

# Use Ray Dataset groupby.map_groups() to process each group in parallel and return a Ray Dataset.
results = ds.groupby("dropoff_location_id")\
            .map_groups(agg_func, batch_format="pandas")

total_time_taken = time.time() - start
print(f"Total number of models: {results.count()}")
print(f"TOTAL TIME TAKEN: {total_time_taken:.2f} seconds")
Sort Sample: 100%|██████████| 14/14 [00:00<00:00, 1490.93it/s]
Shuffle Map: 100%|██████████| 14/14 [00:00<00:00, 537.55it/s]
Shuffle Reduce: 100%|██████████| 14/14 [00:00<00:00, 459.78it/s]
Map_Batches: 100%|██████████| 14/14 [00:02<00:00,  6.06it/s]
Total number of models: 6
TOTAL TIME TAKEN: 2.46 seconds

Finally, we can inspect the models we have trained and their errors.

results
Dataset(num_blocks=14, num_rows=6, schema={location_id: int32, model: object, error: float64})
# sort values by location id
results_df = results.to_pandas()
results_df.sort_values(by=["location_id"], ascending=True, inplace=True)
results_df
   location_id                                model       error
0          145                   LinearRegression()  575.762413
1          145  DecisionTreeRegressor(max_depth=4)  571.611435
2          152                   LinearRegression()  577.747356
3          152  DecisionTreeRegressor(max_depth=4)  609.994895
4          166                   LinearRegression()  525.381878
5          166  DecisionTreeRegressor(max_depth=4)  521.116511
results_df.dtypes
location_id      int32
model           object
error          float64
dtype: object
# Keep only 1 model per location_id with minimum error
final_df = results_df.loc[results_df.groupby('location_id')['error'].idxmin()].copy()
final_df.reset_index(inplace=True, drop=True)
final_df
   location_id                                model       error
0          145  DecisionTreeRegressor(max_depth=4)  571.611435
1          152                   LinearRegression()  577.747356
2          166  DecisionTreeRegressor(max_depth=4)  521.116511
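As a final, optional sketch (not part of the original workflow), the winning model for a location can be pulled out of final_df and used for prediction on new feature values; the feature columns must match the ones used for training.

# Hypothetical inference example: predict trip duration (in seconds) for a new
# ride dropping off at location 145, with 2 passengers and a 3.5-mile trip.
best_model_145 = final_df.set_index("location_id").loc[145, "model"]
new_rides = pd.DataFrame({"passenger_count": [2], "trip_distance": [3.5]})
print(best_model_145.predict(new_rides))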