Ray Train: Scalable Model Training
Contents

Ray Train: Scalable Model Training
Tip
Train is currently in beta. Fill out this short form to get involved with Train development!
Ray Train scales model training for popular ML frameworks such as Torch, XGBoost, TensorFlow, and more. It seamlessly integrates with other Ray libraries such as Tune and Predictors:
Intro to Ray Train
Framework support: Train abstracts away the complexity of scaling up training for common machine learning frameworks such as XGBoost, PyTorch, and TensorFlow. There are three broad categories of Trainers that Train offers:
Deep Learning Trainers (PyTorch, TensorFlow, Horovod)
Tree-based Trainers (XGBoost, LightGBM)
Other ML frameworks (HuggingFace, Scikit-Learn, RLlib)
Built for ML practitioners: Train supports standard ML tools and features that practitioners love:
Callbacks for early stopping
Checkpointing
Integration with TensorBoard, Weights & Biases, and MLflow
Jupyter notebooks
Batteries included: Train is part of Ray AIR and seamlessly operates in the Ray ecosystem.
Use Ray Data with Train to load and process datasets both small and large.
Use Ray Tune with Train to sweep parameter grids and leverage cutting-edge hyperparameter search algorithms.
Leverage the Ray cluster launcher to launch autoscaling or spot instance clusters on any cloud.
Quick Start to Distributed Training with Ray Train
import ray
from ray.train.xgboost import XGBoostTrainer
from ray.air.config import ScalingConfig
# Read the breast-cancer example CSV from S3.
dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")

# Hold out 30% of the rows for validation.
train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)

# Two data-parallel workers; flip use_gpu to enable GPU acceleration.
scaling_config = ScalingConfig(num_workers=2, use_gpu=False)

# XGBoost specific params
xgboost_params = {
    "objective": "binary:logistic",
    # "tree_method": "gpu_hist",  # uncomment this to use GPU for training
    "eval_metric": ["logloss", "error"],
}

trainer = XGBoostTrainer(
    scaling_config=scaling_config,
    label_column="target",
    num_boost_round=20,
    params=xgboost_params,
    datasets={"train": train_dataset, "valid": valid_dataset},
)

# Run distributed training and show the metrics it reported.
result = trainer.fit()
print(result.metrics)
import ray
from ray.train.lightgbm import LightGBMTrainer
from ray.air.config import ScalingConfig
# Read the breast-cancer example CSV from S3.
dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")

# Hold out 30% of the rows for validation.
train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)

# Two data-parallel workers; flip use_gpu to enable GPU acceleration.
scaling_config = ScalingConfig(num_workers=2, use_gpu=False)

# LightGBM specific params
lightgbm_params = {
    "objective": "binary",
    "metric": ["binary_logloss", "binary_error"],
}

trainer = LightGBMTrainer(
    scaling_config=scaling_config,
    label_column="target",
    num_boost_round=20,
    params=lightgbm_params,
    datasets={"train": train_dataset, "valid": valid_dataset},
)

# Run distributed training and show the metrics it reported.
result = trainer.fit()
print(result.metrics)
import torch
import torch.nn as nn
import ray
from ray import train
from ray.air import session, Checkpoint
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
# If using GPUs, set this to True.
use_gpu = False

# Layer widths of the toy regression network and the number of epochs to run.
input_size = 1
layer_size = 15
output_size = 1
num_epochs = 3


class NeuralNetwork(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Layer widths are taken from the module-level ``input_size``,
    ``layer_size`` and ``output_size`` constants.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(input_size, layer_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        """Apply both linear layers with a ReLU in between."""
        hidden = self.relu(self.layer1(input))
        return self.layer2(hidden)
def train_loop_per_worker():
    """Train ``NeuralNetwork`` on this worker's shard of the "train" dataset.

    Runs ``num_epochs`` passes over the shard in batches of 32, optimising an
    MSE loss with SGD and printing the loss after every batch. At the end of
    each epoch it reports an empty metrics dict to the AIR session together
    with a checkpoint holding the epoch index and the model ``state_dict``.
    """
    dataset_shard = session.get_dataset_shard("train")
    model = NeuralNetwork()
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Hand the model to Ray Train so it is set up for distributed execution.
    model = train.torch.prepare_model(model)

    for epoch in range(num_epochs):
        for batches in dataset_shard.iter_torch_batches(
            batch_size=32, dtypes=torch.float
        ):
            # Give both columns a trailing feature dimension:
            # (batch,) -> (batch, 1). The model emits (batch, 1); a (batch,)
            # target would be broadcast by MSELoss to (batch, batch) and
            # yield a wrong loss.
            inputs = torch.unsqueeze(batches["x"], 1)
            labels = torch.unsqueeze(batches["y"], 1)

            output = model(inputs)
            loss = loss_fn(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"epoch: {epoch}, loss: {loss.item()}")

        # One report (with checkpoint) per epoch.
        session.report(
            {},
            checkpoint=Checkpoint.from_dict(
                dict(epoch=epoch, model=model.state_dict())
            ),
        )
# Three workers; GPU usage follows the module-level use_gpu flag.
scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu)

# Toy regression data following y = 2x + 1 for x in 0..199.
train_dataset = ray.data.from_items(
    [{"x": x, "y": 2 * x + 1} for x in range(200)]
)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    datasets={"train": train_dataset},
    scaling_config=scaling_config,
)
result = trainer.fit()
import ray
import tensorflow as tf
from ray.air import session
from ray.air.integrations.keras import ReportCheckpointCallback
from ray.train.tensorflow import TensorflowTrainer
from ray.air.config import ScalingConfig
# If using GPUs, set this to True.
use_gpu = False

# NOTE(review): a, b and size are not referenced by the visible example code.
a = 5
b = 10
size = 100


def build_model() -> tf.keras.Model:
    """Return an uncompiled Keras model: scalar input -> Dense(10) -> Dense(1)."""
    layers = [
        tf.keras.layers.InputLayer(input_shape=()),
        # Add feature dimension, expanding (batch_size,) to (batch_size, 1).
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
        tf.keras.layers.Dense(1),
    ]
    return tf.keras.Sequential(layers)
def train_func(config: dict):
    """Per-worker Keras training loop.

    Reads ``batch_size`` (default 64), ``epochs`` (default 3) and ``lr``
    (default 1e-3) from ``config``, fits the model for one Keras epoch per
    loop iteration on this worker's "train" shard, and returns a list with
    one ``history.history`` dict per iteration.
    """
    batch_size = config.get("batch_size", 64)
    epochs = config.get("epochs", 3)
    learning_rate = config.get("lr", 1e-3)

    mirrored = tf.distribute.MultiWorkerMirroredStrategy()
    with mirrored.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        model = build_model()
        model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_squared_error],
        )

    shard = session.get_dataset_shard("train")

    histories = []
    for _ in range(epochs):
        # Convert the shard to a TensorFlow dataset afresh for every pass.
        epoch_data = shard.to_tf(
            feature_columns="x", label_columns="y", batch_size=batch_size
        )
        fit_result = model.fit(
            epoch_data, callbacks=[ReportCheckpointCallback()]
        )
        histories.append(fit_result.history)
    return histories
# Hyperparameters handed to train_func on every worker.
config = {"lr": 1e-3, "batch_size": 32, "epochs": 4}

# Two workers; GPU usage follows the module-level use_gpu flag.
scaling_config = ScalingConfig(num_workers=2, use_gpu=use_gpu)

# Toy regression data following y = 2x, with x in [0, 1).
train_dataset = ray.data.from_items(
    [{"x": x / 200, "y": 2 * x / 200} for x in range(200)]
)

trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=config,
    datasets={"train": train_dataset},
    scaling_config=scaling_config,
)
result = trainer.fit()
print(result.metrics)
import ray
import ray.train as train
import ray.train.torch # Need this to use `train.torch.get_device()`
import horovod.torch as hvd
import torch
import torch.nn as nn
from ray.air import session, Checkpoint
from ray.train.horovod import HorovodTrainer
from ray.air.config import ScalingConfig
# If using GPUs, set this to True.
use_gpu = False

# Layer widths of the toy regression network and the number of epochs to run.
input_size = 1
layer_size = 15
output_size = 1
num_epochs = 3


class NeuralNetwork(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Layer widths are taken from the module-level ``input_size``,
    ``layer_size`` and ``output_size`` constants.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(input_size, layer_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        """Apply both linear layers with a ReLU in between."""
        hidden = self.relu(self.layer1(input))
        return self.layer2(hidden)
def train_loop_per_worker():
    """Train ``NeuralNetwork`` with Horovod on this worker's "train" shard.

    Initialises Horovod, moves the model to the device Ray Train assigned to
    this worker, and runs ``num_epochs`` passes over the shard in batches of
    32 using an SGD optimizer wrapped in ``hvd.DistributedOptimizer`` with
    gradient averaging. The loss is printed after every batch, and a
    checkpoint with the model ``state_dict`` is reported once per epoch.
    """
    hvd.init()
    dataset_shard = session.get_dataset_shard("train")

    model = NeuralNetwork()
    device = train.torch.get_device()
    model.to(device)

    loss_fn = nn.MSELoss()
    lr_scaler = 1
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1 * lr_scaler)

    # Horovod: start every worker from rank 0's weights. Each worker builds
    # its own randomly initialised model, so without this the replicas would
    # apply the averaged gradients to different parameters.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    # Horovod: wrap optimizer with DistributedOptimizer.
    optimizer = hvd.DistributedOptimizer(
        optimizer,
        named_parameters=model.named_parameters(),
        op=hvd.Average,
    )

    for epoch in range(num_epochs):
        model.train()
        # NOTE(review): batches are not explicitly moved to `device`; confirm
        # iter_torch_batches places tensors there when use_gpu is True.
        for batch in dataset_shard.iter_torch_batches(
            batch_size=32, dtypes=torch.float
        ):
            # Give both columns a trailing feature dimension:
            # (batch,) -> (batch, 1). The model emits (batch, 1); a (batch,)
            # target would be broadcast by MSELoss to (batch, batch) and
            # yield a wrong loss.
            inputs = torch.unsqueeze(batch["x"], 1)
            labels = torch.unsqueeze(batch["y"], 1)

            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"epoch: {epoch}, loss: {loss.item()}")

        # One report (with checkpoint) per epoch.
        session.report(
            {},
            checkpoint=Checkpoint.from_dict(dict(model=model.state_dict())),
        )
# Three workers; GPU usage follows the module-level use_gpu flag.
scaling_config = ScalingConfig(num_workers=3, use_gpu=use_gpu)

# Toy regression data following y = x + 1 for x in 0..31.
train_dataset = ray.data.from_items(
    [{"x": x, "y": x + 1} for x in range(32)]
)

trainer = HorovodTrainer(
    train_loop_per_worker=train_loop_per_worker,
    datasets={"train": train_dataset},
    scaling_config=scaling_config,
)
result = trainer.fit()
Training Framework Catalog
Here is a catalog of the framework-specific Trainer, Checkpoint, and Predictor classes that ship out of the box with Train:
Trainer Class |
Checkpoint Class |
Predictor Class |
(Torch/TF Checkpoint) |
(Torch/TF Predictor) |
|
Next steps
