class ray.train.torch.TorchTrainer(*args, **kwargs)[source]#

Bases: DataParallelTrainer

A Trainer for data parallel PyTorch training.

At a high level, this Trainer does the following:

  1. Launches multiple workers as defined by the scaling_config.

  2. Sets up a distributed PyTorch environment on these workers as defined by the torch_config.

  3. Ingests the input datasets based on the dataset_config.

  4. Runs the input train_loop_per_worker(train_loop_config) on all workers.

For more details, see:


import os
import tempfile

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

import ray
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

# If using GPUs, set this to True.
use_gpu = False
# Number of processes to run training on.
num_workers = 4

# Define your network structure.
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 32)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(32, 1)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))

# Training loop.
def train_loop_per_worker(config):

    # Read configurations.
    lr = config["lr"]
    batch_size = config["batch_size"]
    num_epochs = config["num_epochs"]

    # Fetch training dataset.
    train_dataset_shard = ray.train.get_dataset_shard("train")

    # Instantiate and prepare model for training.
    model = NeuralNetwork()
    model = ray.train.torch.prepare_model(model)

    # Define loss and optimizer.
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Create data loader.
    dataloader = train_dataset_shard.iter_torch_batches(
        batch_size=batch_size, dtypes=torch.float
    )

    # Train multiple epochs.
    for epoch in range(num_epochs):

        # Train epoch.
        for batch in dataloader:
            output = model(batch["input"])
            loss = loss_fn(output, batch["label"])
            # Backpropagate and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Create checkpoint.
        base_model = (model.module
            if isinstance(model, DistributedDataParallel) else model)
        checkpoint_dir = tempfile.mkdtemp()
        torch.save(
            {"model_state_dict": base_model.state_dict()},
            os.path.join(checkpoint_dir, "model.pt"),
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        # Report metrics and checkpoint.
        ray.train.report({"loss": loss.item()}, checkpoint=checkpoint)

# Define configurations.
train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

# Define datasets.
train_dataset = ray.data.from_items(
    [{"input": [x], "label": [2 * x + 1]} for x in range(2000)]
)
datasets = {"train": train_dataset}

# Initialize the Trainer.
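trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=datasets,
)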

# Train the model.
result = trainer.fit()

# Inspect the results.
final_loss = result.metrics["loss"]

Parameters:

  • train_loop_per_worker – The training function to execute on each worker. This function takes either zero arguments or a single Dict argument, which is set by defining train_loop_config. Within this function you can use any of the Ray Train Loop utilities.

  • train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters. Passing large datasets via train_loop_config is not recommended and may introduce large overhead and unknown issues with serialization and deserialization.

  • torch_config – The configuration for setting up the PyTorch Distributed backend. If set to None, a default configuration will be used in which GPU training uses NCCL and CPU training uses Gloo.

  • scaling_config – The configuration for how to scale data parallel training. num_workers determines how many Python processes are used for training, and use_gpu determines whether or not each process should use GPUs. See ScalingConfig for more info.

  • run_config – The configuration for the execution of the training run. See RunConfig for more info.

  • datasets – The Ray Datasets to ingest for training. Datasets are keyed by name ({name: dataset}). Each dataset can be accessed from within the train_loop_per_worker by calling ray.train.get_dataset_shard(name). Sharding and additional configuration can be done by passing in a dataset_config.

  • dataset_config – The configuration for ingesting the input datasets. By default, all Ray Datasets are split equally across workers. See DataConfig for more details.

  • resume_from_checkpoint – A checkpoint to resume training from. This checkpoint can be accessed from within train_loop_per_worker by calling ray.train.get_checkpoint().

  • metadata – Dict that should be made available via ray.train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.
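
The resume_from_checkpoint entry above implies a common pattern: the training loop checks for an existing checkpoint and, if one is present, continues from the state it recorded. The stand-alone sketch below illustrates that pattern with the standard library only. It is not Ray code: a plain directory holding a JSON file stands in for a Checkpoint, the resume_dir argument stands in for the result of ray.train.get_checkpoint(), and save_checkpoint, run_epochs, and state.json are hypothetical names. It also assumes the epoch number is stored in the checkpoint, which the example earlier on this page does not do.

```python
import json
import os
import tempfile
from typing import List, Optional


def save_checkpoint(epoch: int, loss: float) -> str:
    """Write training state to a fresh directory and return its path."""
    checkpoint_dir = tempfile.mkdtemp()
    with open(os.path.join(checkpoint_dir, "state.json"), "w") as f:
        json.dump({"epoch": epoch, "loss": loss}, f)
    return checkpoint_dir


def run_epochs(num_epochs: int, resume_dir: Optional[str] = None) -> List[int]:
    """Run the epochs that remain after the checkpoint in resume_dir, if any."""
    start_epoch = 0
    if resume_dir is not None:
        # Hypothetical stand-in for reading the checkpoint that
        # ray.train.get_checkpoint() would return inside a worker.
        with open(os.path.join(resume_dir, "state.json")) as f:
            start_epoch = json.load(f)["epoch"] + 1

    completed = []
    for epoch in range(start_epoch, num_epochs):
        # Real training work for this epoch would happen here.
        completed.append(epoch)
    return completed


# Pretend a first run was interrupted after finishing epoch 2.
checkpoint_dir = save_checkpoint(epoch=2, loss=0.5)

# A resumed run executes only the remaining epochs; a fresh run executes all.
resumed = run_epochs(num_epochs=5, resume_dir=checkpoint_dir)
fresh = run_epochs(num_epochs=5)
```

In an actual train_loop_per_worker, the same check would likely sit before the epoch loop, with the saved state read from the checkpoint contents rather than from a hand-written JSON file.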



Converts self to a tune.Trainable class.


Checks whether a given directory contains a restorable Train experiment.


Runs training.


Returns a copy of this Trainer's final dataset configs.




Restores a DataParallelTrainer from a previously interrupted/failed run.


Called during fit() to perform initial setup on the Trainer.