Getting Started with Distributed Model Training in Ray Train

Ray Train offers multiple Trainers that implement scalable model training for different machine learning frameworks. Below are examples for some of the most commonly used trainers: XGBoost, LightGBM, PyTorch, and TensorFlow/Keras.

In this example we will train a model using distributed XGBoost.

First, we load the dataset from S3 using Ray Datasets and split it into training and validation datasets.

import ray

# Load data.
dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")

# Split data into train and validation.
train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)

In the ScalingConfig, we configure the number of distributed workers and whether to use GPUs:

from ray.air.config import ScalingConfig

scaling_config = ScalingConfig(
    # Number of workers to use for data parallelism.
    num_workers=2,
    # Whether to use GPU acceleration.
    use_gpu=False,
)
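
If GPUs are available, the same configuration extends to GPU training by flipping use_gpu. A minimal sketch (the gpu_scaling_config variable is only for illustration and is not used in the rest of this example):

# GPU variant of the config above; with `use_gpu=True`, each worker
# is assigned one GPU by default.
gpu_scaling_config = ScalingConfig(num_workers=2, use_gpu=True)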

We then instantiate our XGBoostTrainer by passing in:

  • The ScalingConfig defined above.

  • label_column: the name of the column in the Ray Dataset that contains the labels.

  • num_boost_round: the number of boosting rounds to run.

  • params: XGBoost-specific training parameters.

  • datasets: the Ray Datasets to use for training and evaluation.

from ray.train.xgboost import XGBoostTrainer

trainer = XGBoostTrainer(
    scaling_config=scaling_config,
    label_column="target",
    num_boost_round=20,
    params={
        # XGBoost specific params
        "objective": "binary:logistic",
        # "tree_method": "gpu_hist",  # uncomment this to use GPU for training
        "eval_metric": ["logloss", "error"],
    },
    datasets={"train": train_dataset, "valid": valid_dataset},
)

Lastly, we call trainer.fit() to kick off training and obtain the results.

result = trainer.fit()
print(result.metrics)
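
Besides the metrics dict, the returned Result also carries the final checkpoint. A minimal sketch of inspecting it (assuming the Result API of this Ray version; the exact metric keys, such as "valid-logloss", depend on your eval_metric and dataset names):

# Reported metrics are keyed by dataset name and metric name,
# e.g. "valid-logloss" with the configuration above.
print(result.metrics.get("valid-logloss"))

# The trained model is stored as a checkpoint on the Result.
print(result.checkpoint)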

In this example we will train a model using distributed LightGBM.

First, we load the dataset from S3 using Ray Datasets and split it into training and validation datasets.

import ray

# Load data.
dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")

# Split data into train and validation.
train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)

In the ScalingConfig, we configure the number of distributed workers and whether to use GPUs:

from ray.air.config import ScalingConfig

scaling_config = ScalingConfig(
    # Number of workers to use for data parallelism.
    num_workers=2,
    # Whether to use GPU acceleration.
    use_gpu=False,
)

We then instantiate our LightGBMTrainer by passing in:

  • The ScalingConfig defined above.

  • label_column: the name of the column in the Ray Dataset that contains the labels.

  • num_boost_round: the number of boosting rounds to run.

  • params: core LightGBM training parameters.

  • datasets: the Ray Datasets to use for training and evaluation.

from ray.train.lightgbm import LightGBMTrainer

trainer = LightGBMTrainer(
    scaling_config=scaling_config,
    label_column="target",
    num_boost_round=20,
    params={
        # LightGBM specific params
        "objective": "binary",
        "metric": ["binary_logloss", "binary_error"],
    },
    datasets={"train": train_dataset, "valid": valid_dataset},
)

Lastly, we call trainer.fit() to kick off training and obtain the results.

result = trainer.fit()
print(result.metrics)

This example shows how you can use Ray Train with PyTorch.

First, set up your dataset and model.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

def get_dataset():
    return datasets.FashionMNIST(
        root="/tmp/data",
        train=True,
        download=True,
        transform=ToTensor(),
    )

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, inputs):
        inputs = self.flatten(inputs)
        logits = self.linear_relu_stack(inputs)
        return logits

Now define your single-worker PyTorch training function.

def train_func():
    num_epochs = 3
    batch_size = 64

    dataset = get_dataset()
    dataloader = DataLoader(dataset, batch_size=batch_size)

    model = NeuralNetwork()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(num_epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            pred = model(inputs)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()
        print(f"epoch: {epoch}, loss: {loss.item()}")

This training function can be executed with:

train_func()

Now let’s convert this to a distributed multi-worker training function!

All you have to do is use the ray.train.torch.prepare_model and ray.train.torch.prepare_data_loader utility functions to set up your model and data for distributed training. This automatically wraps your model in DistributedDataParallel, places it on the right device, and adds a DistributedSampler to your DataLoaders.

from ray import train

def train_func_distributed():
    num_epochs = 3
    batch_size = 64

    dataset = get_dataset()
    dataloader = DataLoader(dataset, batch_size=batch_size)
    dataloader = train.torch.prepare_data_loader(dataloader)

    model = NeuralNetwork()
    model = train.torch.prepare_model(model)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(num_epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            pred = model(inputs)
            loss = criterion(pred, labels)
            loss.backward()
            optimizer.step()
        print(f"epoch: {epoch}, loss: {loss.item()}")

Then, instantiate a TorchTrainer with 4 workers, and use it to run the new training function!

from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

# For GPU Training, set `use_gpu` to True.
use_gpu = False

trainer = TorchTrainer(
    train_func_distributed,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=use_gpu)
)

results = trainer.fit()

See Porting code from PyTorch, TensorFlow, or Horovod to Ray Train for a more comprehensive example.

This example shows how you can use Ray Train to set up multi-worker training with Keras.

First, set up your dataset and model.

import numpy as np
import tensorflow as tf

def mnist_dataset(batch_size):
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # The `x` arrays are in uint8 and have values in the [0, 255] range.
    # You need to convert them to float32 with values in the [0, 1] range.
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
    return train_dataset


def build_and_compile_cnn_model():
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(28, 28)),
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=['accuracy'])
    return model

Now define your single-worker TensorFlow training function.

def train_func():
    batch_size = 64
    single_worker_dataset = mnist_dataset(batch_size)
    single_worker_model = build_and_compile_cnn_model()
    single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)

This training function can be executed with:

train_func()

Now let’s convert this to a distributed multi-worker training function! All you need to do is:

  1. Set the per-worker batch size: each worker processes a batch of the same size as in the single-worker code.

  2. Choose your TensorFlow distributed training strategy. In this example we use the MultiWorkerMirroredStrategy.

import json
import os

def train_func_distributed():
    per_worker_batch_size = 64
    # This environment variable will be set by Ray Train.
    tf_config = json.loads(os.environ['TF_CONFIG'])
    num_workers = len(tf_config['cluster']['worker'])

    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    global_batch_size = per_worker_batch_size * num_workers
    multi_worker_dataset = mnist_dataset(global_batch_size)

    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_and_compile_cnn_model()

    multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)

Then, instantiate a TensorflowTrainer with 4 workers, and use it to run the new training function!

from ray.train.tensorflow import TensorflowTrainer
from ray.air.config import ScalingConfig

# For GPU Training, set `use_gpu` to True.
use_gpu = False

trainer = TensorflowTrainer(
    train_func_distributed,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=use_gpu)
)

trainer.fit()

See Porting code from PyTorch, TensorFlow, or Horovod to Ray Train for a more comprehensive example.

Next Steps

  • To check how your application is doing, you can use the Ray dashboard (see the sketch below).
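
If you start Ray yourself, ray.init() prints the dashboard URL on startup. A minimal sketch (assuming the RayContext returned by ray.init() exposes dashboard_url, as in recent Ray versions):

import ray

# Start (or connect to) Ray; the dashboard URL is printed on startup
# and is also available on the returned context object.
context = ray.init()
print(context.dashboard_url)  # e.g. "127.0.0.1:8265"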