Advanced: Scaling out expensive collate functions#
By default, the collate function executes on the training worker when you call ray.data.DataIterator.iter_torch_batches(). This approach has two main drawbacks:
- Low scalability: The collate function runs sequentially on each training worker, limiting parallelism.
- Resource competition: The collate function consumes CPU and memory resources from the training worker, potentially slowing down model training.
Moving the collate function into Ray Data lets you scale collation across multiple CPU nodes independently of the training workers, improving overall pipeline throughput, especially with heavy collate functions.
This optimization is particularly effective when the collate function is computationally expensive (such as tokenization, image augmentation, or complex feature engineering) and you have additional CPU resources available for data preprocessing.
Moving the collate function to Ray Data#
The following example shows a typical collate function that runs on the training worker:
train_dataset = ray.data.read_parquet(...).map(...)

def train_func():
    # collate_fn runs on the training worker for every batch.
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=collate_fn,
        batch_size=BATCH_SIZE,
    ):
        # Training logic here
        pass

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
If the collate function is time or compute intensive and you'd like to scale it out, you should:
- Create a custom collate function that runs in Ray Data and use ray.data.Dataset.map_batches() to scale it out.
- Use ray.data.Dataset.repartition() to ensure batch size alignment.
Creating a custom collate function that runs in Ray Data#
To scale out, you’ll want to move the collate_fn into a Ray Data map_batches operation:
def collate_fn(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    return batch

train_dataset = train_dataset.map_batches(collate_fn, batch_size=BATCH_SIZE)

def train_func():
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=None,
        batch_size=BATCH_SIZE,
    ):
        # Training logic here
        pass

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
A couple of things to note:
- The collate_fn returns a dictionary of NumPy arrays, which is a standard Ray Data batch format.
- The iter_torch_batches method uses collate_fn=None, which reduces the amount of work done on the training worker process.
Ensuring batch size alignment#
Typically, collate functions are used to create complete batches of data with a target batch size.
However, if you move the collate function to Ray Data using ray.data.Dataset.map_batches(), Ray Data doesn't guarantee the batch size for each function call by default.
There are two common problems that you may encounter:
- The collate function requires a certain number of rows as input to work properly.
- You want to avoid any reformatting or rebatching of the data on the training worker process.
To solve these problems, use ray.data.Dataset.repartition() with target_num_rows_per_block to ensure batch size alignment.
By calling repartition before map_batches, you ensure that the input blocks contain the desired number of rows.
# Note: If you only use map_batches(batch_size=BATCH_SIZE), the input batches
# are not guaranteed to contain the desired number of rows.
dataset = dataset.repartition(target_num_rows_per_block=BATCH_SIZE).map_batches(
    collate_fn, batch_size=BATCH_SIZE
)
By calling repartition after map_batches, you ensure that the output blocks contain the desired number of rows. This avoids any reformatting / rebatching of the data on the training worker process.
dataset = dataset.map_batches(collate_fn, batch_size=BATCH_SIZE).repartition(
    target_num_rows_per_block=BATCH_SIZE
)
def train_func():
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=None,
        batch_size=BATCH_SIZE,
    ):
        # Training logic here
        pass

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
Putting things together#
Throughout this guide, we use a mock text dataset to demonstrate the optimization. You can find the implementation of the mock dataset in Random Text Generator.
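The examples below import create_mock_ray_text_dataset from a mock_dataset module. If you just want something runnable, the following is a minimal sketch of what such a generator could look like; the function name and signature mirror the import used below, but the body is an illustrative stand-in, not the actual Random Text Generator implementation.

import random
import string
from typing import Dict

import numpy as np
import ray


def create_mock_ray_text_dataset(
    dataset_size: int, min_len: int, max_len: int
) -> ray.data.Dataset:
    """Illustrative stand-in: random text rows with binary labels."""

    def generate_rows(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Assumes ray.data.range produces an "id" column (recent Ray versions).
        rng = random.Random()
        texts, labels = [], []
        for _ in batch["id"]:
            length = rng.randint(min_len, max_len)
            texts.append("".join(rng.choices(string.ascii_lowercase + " ", k=length)))
            labels.append(rng.randint(0, 1))
        return {"text": np.array(texts), "label": np.array(labels)}

    return ray.data.range(dataset_size).map_batches(generate_rows)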
The following example shows a typical collate function that runs on the training worker:
import ray
import torch
import numpy as np
from typing import Dict
from transformers import AutoTokenizer
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from mock_dataset import create_mock_ray_text_dataset

BATCH_SIZE = 10000

def vanilla_collate_fn(tokenizer: AutoTokenizer, batch: Dict[str, np.ndarray]) -> Dict[str, torch.Tensor]:
    outputs = tokenizer(
        list(batch["text"]),
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )
    outputs["labels"] = torch.LongTensor(batch["label"])
    return outputs

def train_func():
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    collate_fn = lambda x: vanilla_collate_fn(tokenizer, x)

    # Collate function runs on the training worker
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=collate_fn,
        batch_size=BATCH_SIZE,
    ):
        # Training logic here
        pass

train_dataset = create_mock_ray_text_dataset(
    dataset_size=1000000,
    min_len=1000,
    max_len=3000,
)

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
The following example moves the collate function to Ray Data preprocessing:
import ray
import numpy as np
from typing import Dict
from transformers import AutoTokenizer
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from mock_dataset import create_mock_ray_text_dataset
import pyarrow as pa

BATCH_SIZE = 10000

class CollateFnRayData:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def __call__(self, batch: pa.Table) -> Dict[str, np.ndarray]:
        results = self.tokenizer(
            batch["text"].to_pylist(),
            truncation=True,
            padding="longest",
            return_tensors="np",
        )
        results["labels"] = np.array(batch["label"])
        return results

def train_func():
    # Collate function already ran in Ray Data
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=None,
        batch_size=BATCH_SIZE,
    ):
        # Training logic here
        pass

# Apply preprocessing in Ray Data
train_dataset = (
    create_mock_ray_text_dataset(
        dataset_size=1000000,
        min_len=1000,
        max_len=3000,
    )
    .map_batches(
        CollateFnRayData,
        batch_size=BATCH_SIZE,
        batch_format="pyarrow",
    )
    .repartition(target_num_rows_per_block=BATCH_SIZE)  # Ensure batch size alignment
)

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
The optimized implementation makes these changes:
- Preprocessing in Ray Data: The tokenization logic moves from train_func to CollateFnRayData, which runs in map_batches.
- NumPy output: The collate function returns Dict[str, np.ndarray] instead of PyTorch tensors, a format that Ray Data natively supports.
- Batch alignment: repartition(target_num_rows_per_block=BATCH_SIZE) after map_batches ensures that the output blocks align with the batch size, so the training worker doesn't need to rebatch the data.
- No collate_fn in iterator: iter_torch_batches uses collate_fn=None because preprocessing already happened in Ray Data.
Benchmark results#
The following benchmarks demonstrate the performance improvement from scaling out the collate function. The test uses text tokenization with a batch size of 10,000 on a dataset of 1 million rows with text lengths between 1,000 and 3,000 characters.
Single node (g4dn.12xlarge: 48 vCPU, 4 NVIDIA T4 GPUs, 192 GiB memory)
| Configuration | Throughput |
|---|---|
| Collate in iterator (baseline) | 1,588 rows/s |
| Collate in Ray Data | 3,437 rows/s |
With 2 additional CPU nodes (m5.8xlarge: 32 vCPU, 128 GiB memory each)
| Configuration | Throughput |
|---|---|
| Collate in iterator (baseline) | 1,659 rows/s |
| Collate in Ray Data | 10,717 rows/s |
The results show that scaling out the collate function to Ray Data provides about a 2x speedup on a single node and more than a 6x speedup when adding CPU-only nodes for preprocessing.
Advanced: Handling custom data types#
The optimized implementation above returns Dict[str, np.ndarray], which Ray Data natively supports. However, if your collate function needs to return PyTorch tensors or other custom data types that ray.data.Dataset.map_batches() doesn’t directly support, you need to serialize them.
Tensor serialization utility#
The following utility serializes PyTorch tensors into PyArrow format. It flattens all tensors in a batch into a single binary buffer, stores metadata about tensor shapes and dtypes, and packs everything into a single-row PyArrow table. On the training side, it deserializes the table back into the original tensor structure.
The serialization and deserialization operations are typically lightweight compared to the actual collate function work (such as tokenization or image processing), so the overhead is minimal relative to the performance gains from scaling the collate function.
You can use Collate Utilities as a reference implementation and adapt it to your needs.
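As a rough illustration of the idea, a minimal sketch of such utilities could look like the following. The function names match the collate_utils imports used in the example below, but this is an assumption-based sketch, not the reference implementation.

import io
import pickle
from typing import Dict

import numpy as np
import pyarrow as pa
import torch


def serialize_tensors_to_table(tensors: Dict[str, torch.Tensor]) -> pa.Table:
    """Pack a dict of tensors into a single-row PyArrow table.

    All tensors are flattened into one binary buffer; names, shapes, dtypes,
    and offsets are stored as metadata so the tensors can be rebuilt later.
    """
    buffer = io.BytesIO()
    metadata = []
    for name, tensor in tensors.items():
        array = np.ascontiguousarray(tensor.detach().cpu().numpy())
        metadata.append(
            {
                "name": name,
                "shape": array.shape,
                "dtype": str(array.dtype),
                "offset": buffer.tell(),
                "nbytes": array.nbytes,
            }
        )
        buffer.write(array.tobytes())
    return pa.table(
        {
            "data": pa.array([buffer.getvalue()], type=pa.binary()),
            "metadata": pa.array([pickle.dumps(metadata)], type=pa.binary()),
        }
    )


def deserialize_table_to_tensors(
    table: pa.Table, pin_memory: bool = False
) -> Dict[str, torch.Tensor]:
    """Rebuild the original tensors from a single-row PyArrow table."""
    data = table["data"][0].as_py()
    metadata = pickle.loads(table["metadata"][0].as_py())
    tensors = {}
    for meta in metadata:
        start, end = meta["offset"], meta["offset"] + meta["nbytes"]
        array = np.frombuffer(data[start:end], dtype=np.dtype(meta["dtype"]))
        tensor = torch.from_numpy(array.copy()).reshape(meta["shape"])
        if pin_memory:
            tensor = tensor.pin_memory()
        tensors[meta["name"]] = tensor
    return tensors

This sketch uses pickle for the metadata to keep it short; the reference implementation may store metadata differently, but it follows the same pattern of one binary buffer plus per-tensor metadata packed into a single-row table.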
Example with tensor serialization#
The following example demonstrates using tensor serialization when your collate function must return PyTorch tensors. This approach requires repartition before map_batches because the collate function changes the number of output rows (each batch becomes a single serialized row).
import ray
import torch
from typing import Dict
from transformers import AutoTokenizer
import pyarrow as pa
from ray.data.collate_fn import ArrowBatchCollateFn
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
from collate_utils import serialize_tensors_to_table, deserialize_table_to_tensors
from mock_dataset import create_mock_ray_text_dataset

BATCH_SIZE = 10000

class TextTokenizerCollateFn:
    """Collate function that runs in Ray Data preprocessing."""

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    def __call__(self, batch: pa.Table) -> pa.Table:
        # Tokenize the batch
        outputs = self.tokenizer(
            batch["text"].to_pylist(),
            truncation=True,
            padding="longest",
            return_tensors="pt",
        )
        outputs["labels"] = torch.LongTensor(batch["label"].to_numpy())
        # Serialize to single-row table using the utility
        return serialize_tensors_to_table(outputs)

class IteratorCollateFn(ArrowBatchCollateFn):
    """Collate function for iter_torch_batches that deserializes the batch."""

    def __init__(self, pin_memory=False):
        self._pin_memory = pin_memory

    def __call__(self, batch: pa.Table) -> Dict[str, torch.Tensor]:
        # Deserialize from single-row table using the utility
        return deserialize_table_to_tensors(batch, pin_memory=self._pin_memory)

def train_func():
    collate_fn = IteratorCollateFn()

    # Collate function only deserializes on the training worker
    for batch in ray.train.get_dataset_shard("train").iter_torch_batches(
        collate_fn=collate_fn,
        batch_size=1,  # Each "row" is actually a full batch
    ):
        # Training logic here
        pass

# Apply preprocessing in Ray Data
# Use repartition BEFORE map_batches because output row count changes
train_dataset = (
    create_mock_ray_text_dataset(
        dataset_size=1000000,
        min_len=1000,
        max_len=3000,
    )
    .repartition(target_num_rows_per_block=BATCH_SIZE)
    .map_batches(
        TextTokenizerCollateFn,
        batch_size=BATCH_SIZE,
        batch_format="pyarrow",
    )
)

trainer = TorchTrainer(
    train_func,
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()