class ray.train.torch.TorchTrainer(*args, **kwargs)

Bases: ray.train.data_parallel_trainer.DataParallelTrainer

A Trainer for data parallel PyTorch training.

This Trainer runs the function train_loop_per_worker on multiple Ray Actors. These actors already have the necessary torch process group configured for distributed PyTorch training.

The train_loop_per_worker function is expected to take in either 0 or 1 arguments:

def train_loop_per_worker():
    ...

from typing import Dict, Any

def train_loop_per_worker(config: Dict[str, Any]):
    ...

If train_loop_per_worker accepts an argument, then train_loop_config will be passed in as the argument. This is useful if you want to tune the values in train_loop_config as hyperparameters.
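To illustrate the 0-or-1-argument contract, here is a minimal sketch of how a trainer could decide whether to pass train_loop_config, assuming it inspects the function's signature. The helper call_train_loop is hypothetical and is not part of Ray's API. Ray's internal dispatch may be implemented differently.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional


def call_train_loop(
    train_loop_per_worker: Callable[..., Any],
    train_loop_config: Optional[Dict[str, Any]] = None,
) -> None:
    """Hypothetical dispatcher: pass the config only if the loop accepts one."""
    num_params = len(inspect.signature(train_loop_per_worker).parameters)
    if num_params == 0:
        train_loop_per_worker()
    elif num_params == 1:
        # Fall back to an empty dict when no config was supplied (an assumption).
        train_loop_per_worker(train_loop_config or {})
    else:
        raise ValueError("train_loop_per_worker must take 0 or 1 arguments.")


# Usage: both accepted forms of the training function.
calls: List[Any] = []


def loop_without_config():
    calls.append("no config")


def loop_with_config(config: Dict[str, Any]):
    calls.append(config["lr"])


call_train_loop(loop_without_config)
call_train_loop(loop_with_config, {"lr": 0.01})
```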

If the datasets dict contains a training dataset (denoted by the "train" key), then it will be split into multiple dataset shards that can then be accessed by session.get_dataset_shard("train") inside train_loop_per_worker. All the other datasets will not be split, and session.get_dataset_shard(...) will return the entire Dataset.
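The splitting rule can be illustrated with plain Python lists. This is a simplified sketch, assuming the "train" data is cut into contiguous, near-equal chunks, one per worker. The helper get_shard is hypothetical and only stands in for session.get_dataset_shard; Ray's actual split strategy may differ.

```python
from typing import Dict, List


def get_shard(
    datasets: Dict[str, List[int]], key: str, rank: int, world_size: int
) -> List[int]:
    """Hypothetical stand-in for session.get_dataset_shard(key)."""
    data = datasets[key]
    if key != "train":
        # Non-training datasets are not split: every worker sees all of it.
        return data
    # Training data: contiguous chunks whose sizes differ by at most one.
    base, extra = divmod(len(data), world_size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return data[start:end]


# Three workers reading a 10-row "train" set and a 2-row "valid" set.
datasets = {"train": list(range(10)), "valid": [100, 101]}
train_shards = [get_shard(datasets, "train", r, 3) for r in range(3)]
valid_shards = [get_shard(datasets, "valid", r, 3) for r in range(3)]
```

Every worker gets a distinct piece of "train", while each one receives the full "valid" list.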

Inside the train_loop_per_worker function, you can use any of the Ray AIR session methods. See full example code below.

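from ray.air import session

def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data (`checkpoint` is an optional keyword argument).
    session.report(...)

    # Get the last saved checkpoint (None if there is none).
    session.get_checkpoint()

    # Session returns the Ray Dataset shard for the given key.
    session.get_dataset_shard("train")

    # Get the total number of workers executing training.
    session.get_world_size()

    # Get the rank of this worker.
    session.get_world_rank()

    # Get the rank of the worker on the current node.
    session.get_local_rank()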
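The three counters relate as follows: the world size counts all workers, the world rank is unique across the whole cluster, and the local rank restarts at 0 on each node. A minimal sketch of that relationship, assuming we know which node each world rank runs on and that local ranks are assigned in world-rank order. Ray computes these values for you; the helper local_ranks and the node names are purely illustrative.

```python
from typing import Dict, List


def local_ranks(node_of_rank: List[str]) -> List[int]:
    """Hypothetical helper: node_of_rank[i] is the node running world rank i.

    Assumes local ranks are handed out in world-rank order on each node, so a
    worker's local rank is the count of lower world ranks on the same node.
    """
    seen_per_node: Dict[str, int] = {}
    result: List[int] = []
    for node in node_of_rank:
        count = seen_per_node.get(node, 0)
        result.append(count)
        seen_per_node[node] = count + 1
    return result


# Five workers spread over two nodes (made-up placement).
placement = ["node-a", "node-a", "node-b", "node-b", "node-b"]
world_size = len(placement)     # what session.get_world_size() would report
ranks = local_ranks(placement)  # local rank of world ranks 0..4
```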

You can also use any of the Torch-specific utility functions, such as ray.train.torch.get_device() and ray.train.torch.prepare_model().

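import torch
from torch.utils.data import DataLoader, TensorDataset
from ray.train.torch import get_device, prepare_data_loader, prepare_model

def train_loop_per_worker():
    # Prepares model for distributed training by wrapping in
    # `DistributedDataParallel` and moving to correct device.
    # (The model and data below are placeholders.)
    model = prepare_model(torch.nn.Linear(1, 1))

    # Configures the dataloader for distributed training by adding a
    # `DistributedSampler`.
    # You should NOT use this if you are doing
    # `session.get_dataset_shard(...).iter_torch_batches(...)`
    data_loader = prepare_data_loader(
        DataLoader(TensorDataset(torch.randn(64, 1)), batch_size=8))

    # Get the current torch device.
    device = get_device()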
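prepare_data_loader adds a DistributedSampler so that each worker reads its own slice of the data. The index assignment can be sketched in plain Python. The function below is an illustration that mirrors the behavior of torch.utils.data.DistributedSampler with shuffle=False and drop_last=False (pad by repeating leading indices, then take every num_replicas-th index starting at the worker's rank); it is not PyTorch's code and assumes a non-empty dataset.

```python
import math
from typing import List


def sampler_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices one worker would read (no shuffling, drop_last=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every worker gets
    # the same number of samples.
    padding = total_size - len(indices)
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Each worker takes every num_replicas-th index, offset by its rank.
    return indices[rank:total_size:num_replicas]


# Ten samples over three workers: each worker reads four indices.
per_worker = [sampler_indices(10, 3, r) for r in range(3)]
```

Padding means a few samples (here 0 and 1) are read twice per epoch so that all workers run the same number of steps.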

Any return values from train_loop_per_worker are discarded; they are not used or persisted anywhere.

To save a model for use with the TorchPredictor, you must save it under the "model" key in the Checkpoint passed to session.report().


When you wrap the model with prepare_model, the keys of its state_dict are prefixed with "module.". For example, "layer1.0.bn1.bias" becomes "module.layer1.0.bn1.bias". However, when the model is saved through session.report(), all "module." prefixes are stripped. As a result, when you load from a saved checkpoint, make sure that you first load the state_dict into the model before calling prepare_model. Otherwise, you will run into errors like Error(s) in loading state_dict for DistributedDataParallel: Missing key(s) in state_dict: "module.conv1.weight", .... See snippet below.

from torchvision.models import resnet18
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.train.torch import prepare_model

def train_func():
    model = resnet18()
    model = prepare_model(model)
    for epoch in range(3):
        ckpt = Checkpoint.from_dict({
            "epoch": epoch,
            "model": model.state_dict(),
            # "model": model.module.state_dict(),
            # ** The above two are equivalent **
        })
        # The checkpoint must be passed as a keyword argument.
        session.report({"foo": "bar"}, checkpoint=ckpt)
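The key rewrite described in the note can be sketched in plain Python. strip_module_prefix is a hypothetical helper written for illustration; PyTorch ships a similar in-place utility, torch.nn.modules.utils.consume_prefix_in_state_dict_if_present.

```python
from typing import Any, Dict

PREFIX = "module."


def strip_module_prefix(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of state_dict with one leading 'module.' removed per key."""
    return {
        (key[len(PREFIX):] if key.startswith(PREFIX) else key): value
        for key, value in state_dict.items()
    }


# Keys as they look on a model wrapped by prepare_model (values are dummies).
wrapped = {"module.layer1.0.bn1.bias": 0.1, "module.conv1.weight": 0.2}
plain = strip_module_prefix(wrapped)
```

Plain keys like these match an unwrapped model, which is why the saved state_dict should be loaded first and prepare_model called only afterwards.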


Example:

import torch
import torch.nn as nn

import ray
from ray import train
from ray.air import session, Checkpoint
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
from ray.air.config import RunConfig
from ray.air.config import CheckpointConfig

# If using GPUs, set this to True.
use_gpu = False

# Define NN layer architecture, epochs, and number of workers
input_size = 1
layer_size = 32
output_size = 1
num_epochs = 200
num_workers = 3

# Define your network structure
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.layer1 = nn.Linear(input_size, layer_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(layer_size, output_size)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))

# Define your train worker loop
def train_loop_per_worker():

    # Fetch training set from the session
    dataset_shard = session.get_dataset_shard("train")
    model = NeuralNetwork()

    # Loss function, optimizer, prepare model for training.
    # prepare_model moves the model to the correct device and
    # prepares it for distributed execution.
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    model = train.torch.prepare_model(model)

    # Iterate over epochs and batches
    for epoch in range(num_epochs):
        for batches in dataset_shard.iter_torch_batches(
                batch_size=32, dtypes=torch.float):

            # Add batch or unsqueeze as an additional dimension [32, x]
            inputs, labels = torch.unsqueeze(batches["x"], 1), batches["y"]
            output = model(inputs)

            # Make output shape the same as the labels
            loss = loss_fn(output.squeeze(), labels)

            # Zero out grads, do backward, and update optimizer
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print what's happening with loss every 20 epochs
            if epoch % 20 == 0:
                print(f"epoch: {epoch}/{num_epochs}, loss: {loss:.3f}")

        # Report and record metrics, checkpoint model at end of each
        # epoch
        session.report(
            {"loss": loss.item(), "epoch": epoch},
            checkpoint=Checkpoint.from_dict(
                dict(epoch=epoch, model=model.state_dict())),
        )

train_dataset = ray.data.from_items(
    [{"x": x, "y": 2 * x + 1} for x in range(200)])

# Define scaling and run configs
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets={"train": train_dataset})

result = trainer.fit()

best_checkpoint_loss = result.metrics['loss']

# Assert loss is less than or equal to 0.09
assert best_checkpoint_loss <= 0.09

Parameters

  • train_loop_per_worker – The training function to execute. This can either take in no arguments or a config dict.

  • train_loop_config – Configurations to pass into train_loop_per_worker if it accepts an argument.

  • torch_config – Configuration for setting up the PyTorch backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer.

  • scaling_config – Configuration for how to scale data parallel training.

  • dataset_config – Configuration for dataset ingest.

  • run_config – Configuration for the execution of the training run.

  • datasets – Any Ray Datasets to use for training. Use the key "train" to denote which dataset is the training dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided.

  • preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A checkpoint to resume training from.
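The preprocessing rule for datasets (fit on "train" only if not already fit, then transform every dataset) can be illustrated without Ray. MeanCenterer and apply_preprocessor are hypothetical stand-ins for this sketch; in real code you pass a ray.data.Preprocessor and the trainer applies it.

```python
from typing import Dict, List, Optional


class MeanCenterer:
    """Hypothetical preprocessor: subtracts the mean seen during fit()."""

    def __init__(self) -> None:
        self.mean: Optional[float] = None

    def fit(self, data: List[float]) -> "MeanCenterer":
        self.mean = sum(data) / len(data)
        return self

    def transform(self, data: List[float]) -> List[float]:
        assert self.mean is not None, "fit() must be called first"
        return [x - self.mean for x in data]


def apply_preprocessor(
    datasets: Dict[str, List[float]], preprocessor: MeanCenterer
) -> Dict[str, List[float]]:
    # Fit only on the training dataset, and only if not already fit.
    if preprocessor.mean is None:
        preprocessor.fit(datasets["train"])
    # Transform every dataset with the statistics learned from "train".
    return {name: preprocessor.transform(data) for name, data in datasets.items()}


processed = apply_preprocessor(
    {"train": [1.0, 2.0, 3.0], "valid": [4.0]}, MeanCenterer()
)
```

Fitting only on "train" keeps statistics from the validation data out of the preprocessing step.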

PublicAPI (beta): This API is in beta and may change before becoming stable.