class ray.train.torch.TorchTrainer(*args, **kwargs)

Bases: ray.train.data_parallel_trainer.DataParallelTrainer

A Trainer for data parallel PyTorch training.

At a high level, this Trainer does the following:

  1. Launches multiple workers as defined by the scaling_config.

  2. Sets up a distributed PyTorch environment on these workers as defined by the torch_config.

  3. Ingests the input datasets based on the dataset_config.

  4. Runs the input train_loop_per_worker(train_loop_config) on all workers.

For more details, see the PyTorch User Guide.
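Step 2 amounts to giving every worker the same rendezvous address plus its own rank. The stdlib-only sketch below builds the environment variables that PyTorch's env:// initialization method reads (MASTER_ADDR, MASTER_PORT, WORLD_SIZE, RANK). The helper build_worker_envs, the address, and the port are hypothetical; this illustrates the idea and is not Ray Train's internal setup code.

```python
from typing import Dict, List


def build_worker_envs(num_workers: int, master_addr: str, master_port: int) -> List[Dict[str, str]]:
    """Return one environment mapping per worker (hypothetical helper).

    PyTorch's env:// initialization reads these four variables in each
    process to join the same process group.
    """
    envs = []
    for rank in range(num_workers):
        envs.append({
            "MASTER_ADDR": master_addr,       # address of the rank-0 process
            "MASTER_PORT": str(master_port),  # port the rank-0 process listens on
            "WORLD_SIZE": str(num_workers),   # total number of processes
            "RANK": str(rank),                # this process's global rank
        })
    return envs
```

For example, build_worker_envs(4, "10.0.0.1", 29500) yields four mappings that differ only in RANK ("0" through "3").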

Example:

import os
import tempfile

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

import ray
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

# If using GPUs, set this to True.
use_gpu = False
# Number of processes to run training on.
num_workers = 4

# Define your network structure.
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.layer1 = nn.Linear(1, 32)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(32, 1)

    def forward(self, input):
        return self.layer2(self.relu(self.layer1(input)))

# Training loop.
def train_loop_per_worker(config):

    # Read configurations.
    lr = config["lr"]
    batch_size = config["batch_size"]
    num_epochs = config["num_epochs"]

    # Fetch training dataset.
    train_dataset_shard = ray.train.get_dataset_shard("train")

    # Instantiate and prepare model for training.
    model = NeuralNetwork()
    model = ray.train.torch.prepare_model(model)

    # Define loss and optimizer.
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    # Create data loader.
    dataloader = train_dataset_shard.iter_torch_batches(
        batch_size=batch_size, dtypes=torch.float
    )

    # Train multiple epochs.
    for epoch in range(num_epochs):

        # Train epoch.
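        for batch in dataloader:
            output = model(batch["input"])
            loss = loss_fn(output, batch["label"])
            # Backpropagate and update the model parameters.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()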

        # Create checkpoint.
        base_model = (model.module
            if isinstance(model, DistributedDataParallel) else model)
        checkpoint_dir = tempfile.mkdtemp()
        torch.save(
            {"model_state_dict": base_model.state_dict()},
            os.path.join(checkpoint_dir, "model.pt"),
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        # Report metrics and checkpoint.
        ray.train.report({"loss": loss.item()}, checkpoint=checkpoint)

# Define configurations.
train_loop_config = {"num_epochs": 20, "lr": 0.01, "batch_size": 32}
scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

# Define datasets.
train_dataset = ray.data.from_items(
    [{"input": [x], "label": [2 * x + 1]} for x in range(2000)]
)
datasets = {"train": train_dataset}

# Initialize the Trainer.
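trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
    datasets=datasets,
)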

# Train the model.
result = trainer.fit()

# Inspect the results.
final_loss = result.metrics["loss"]
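For a sense of scale in the example above: 2000 rows split equally across num_workers = 4 gives 500 rows per worker, and batch_size = 32 turns those into 15 full batches plus one partial batch of 20 rows per epoch. The sketch below reproduces that arithmetic in plain Python. The helper names are hypothetical, and it assumes the final partial batch is kept (iter_torch_batches with drop_last=False).

```python
import math


def rows_per_worker(num_rows: int, num_workers: int) -> int:
    # Equal split: every worker gets the same number of rows. For uneven
    # splits this sketch simply drops the remainder (hypothetical policy).
    return num_rows // num_workers


def batches_per_epoch(rows: int, batch_size: int, drop_last: bool = False) -> int:
    # With drop_last=False, a final partial batch is still produced.
    if drop_last:
        return rows // batch_size
    return math.ceil(rows / batch_size)


def target(x: float) -> float:
    # Label function used to build the example dataset.
    return 2 * x + 1


num_rows, num_workers, batch_size, num_epochs = 2000, 4, 32, 20
shard = rows_per_worker(num_rows, num_workers)   # rows seen by each worker per epoch
steps = batches_per_epoch(shard, batch_size)     # optimizer steps per worker per epoch
total_steps = steps * num_epochs                 # optimizer steps per worker overall
```

With these numbers, shard is 500, steps is 16, and total_steps is 320.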

Parameters:

  • train_loop_per_worker – The training function to execute on each worker. This function can take either zero arguments or a single Dict argument, whose value is set by train_loop_config. Within this function you can use any of the Ray Train Loop utilities.

  • train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters.

  • torch_config – The configuration for setting up the PyTorch Distributed backend. If set to None, a default configuration will be used in which GPU training uses NCCL and CPU training uses Gloo.

  • scaling_config – The configuration for how to scale data parallel training. num_workers determines how many Python processes are used for training, and use_gpu determines whether or not each process should use GPUs. See ScalingConfig for more info.

  • run_config – The configuration for the execution of the training run. See RunConfig for more info.

  • datasets – The Ray Datasets to ingest for training. Datasets are keyed by name ({name: dataset}). Each dataset can be accessed from within the train_loop_per_worker by calling ray.train.get_dataset_shard(name). Sharding and additional configuration can be done by passing in a dataset_config.

  • dataset_config – The configuration for ingesting the input datasets. By default, all Ray Datasets are split equally across workers. See DataConfig for more details.

  • resume_from_checkpoint – A checkpoint to resume training from. This checkpoint can be accessed from within train_loop_per_worker by calling ray.train.get_checkpoint().

  • metadata – A Dict to make available via ray.train.get_context().get_metadata() and via checkpoint.get_metadata() for checkpoints saved by this Trainer. Must be JSON-serializable.
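The argument rules above can be illustrated with a few small checks. This is a hypothetical, stdlib-only sketch, not Ray's validation code: call_train_loop dispatches on the number of parameters train_loop_per_worker declares, default_backend mirrors the documented NCCL/Gloo default when torch_config is None, and check_metadata relies on json.dumps raising TypeError for values that are not JSON-serializable. Passing an empty dict when train_loop_config is missing is an assumption of this sketch.

```python
import inspect
import json
from typing import Any, Callable, Dict, Optional


def call_train_loop(fn: Callable[..., Any], config: Optional[Dict[str, Any]]) -> Any:
    """Invoke fn with zero arguments or with the config Dict (hypothetical)."""
    num_params = len(inspect.signature(fn).parameters)
    if num_params == 0:
        return fn()
    if num_params == 1:
        # Assumption: a missing train_loop_config becomes an empty Dict.
        return fn(config if config is not None else {})
    raise ValueError("train_loop_per_worker must take zero or one argument")


def default_backend(use_gpu: bool) -> str:
    # Documented default when torch_config is None: NCCL for GPU, Gloo for CPU.
    return "nccl" if use_gpu else "gloo"


def check_metadata(metadata: Dict[str, Any]) -> None:
    # json.dumps raises TypeError if any value is not JSON-serializable.
    json.dumps(metadata)
```

For example, call_train_loop(lambda cfg: cfg["lr"], {"lr": 0.01}) returns 0.01, while check_metadata({"tags": {1, 2}}) raises TypeError because a set is not JSON-serializable.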

PublicAPI: This API is stable across Ray releases.

Methods:

Converts self to a tune.Trainable class.

can_restore(path[, storage_filesystem])

Checks whether a given directory contains a restorable Train experiment.


Runs training.


Returns a copy of this Trainer's final dataset configs.

restore(path[, train_loop_per_worker, ...])

Restores a DataParallelTrainer from a previously interrupted/failed run.


Called during fit() to perform initial setup on the Trainer.
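The can_restore and restore methods listed above support a "resume if possible, otherwise start fresh" pattern. The sketch below takes the three steps as callables so it runs without Ray; restore_or_create is a hypothetical helper, not part of the Ray API.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def restore_or_create(
    path: str,
    can_restore: Callable[[str], bool],
    restore: Callable[[str], T],
    create: Callable[[], T],
) -> T:
    """Resume from path if it holds a restorable run, else build a new one."""
    if can_restore(path):
        # An interrupted or failed run exists at path: continue it.
        return restore(path)
    # Nothing to resume: start a fresh run.
    return create()
```

In a real script, the callables would presumably be TorchTrainer.can_restore, a lambda calling TorchTrainer.restore(path, train_loop_per_worker=train_loop_per_worker), and a lambda that constructs a new TorchTrainer; treat that wiring as an untested assumption.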