class ray.train.trainer.BaseTrainer(*args, **kwargs)

Bases: ABC

Defines interface for distributed training on Ray.

Note: The BaseTrainer base class cannot be instantiated directly; only its subclasses can be used.
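Because BaseTrainer derives from ABC, Python refuses to instantiate it until a subclass implements its abstract methods. The toy sketch below illustrates that mechanism. `ToyBaseTrainer` and `ToyTrainer` are hypothetical names. Marking only `training_loop` as abstract, with `setup` as an optional no-op hook, is an assumption chosen to mirror the "override training_loop, optionally setup" guidance. It is not a copy of Ray's source.

```python
import abc


class ToyBaseTrainer(abc.ABC):
    """Hypothetical stand-in for BaseTrainer; not Ray's implementation."""

    def setup(self) -> None:
        # Optional hook: subclasses may override it, but need not.
        pass

    @abc.abstractmethod
    def training_loop(self) -> None:
        """Subclasses must implement the main training logic."""


class ToyTrainer(ToyBaseTrainer):
    def training_loop(self) -> None:
        # Implementing the abstract method makes the subclass instantiable.
        self.ran = True


try:
    ToyBaseTrainer()  # Raises TypeError: abstract method not implemented.
except TypeError as err:
    print(f"Cannot instantiate the base class: {err}")

trainer = ToyTrainer()
trainer.training_loop()
print(trainer.ran)
```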

Note to developers: If a new trainer is added, please update air/_internal/usage.py.

How does a trainer work?

  • First, initialize the Trainer. The initialization runs locally, so heavyweight setup should not be done in __init__.

  • Then, when you call trainer.fit(), the Trainer is serialized and copied to a remote Ray actor. The following methods are then called in sequence on the remote actor:

      • trainer.setup(): Any heavyweight Trainer setup should be specified here.

      • trainer.training_loop(): Executes the main training logic.

  • Calling trainer.fit() returns a ray.train.Result object, from which you can access the metrics of your training run as well as any checkpoints that were saved.
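The sequence above can be modelled in plain Python to make the ordering concrete. The sketch below is a toy, not Ray's implementation: `ToyTrainer`, `ToyResult`, and the `report()` method are hypothetical, and `copy.deepcopy` stands in for Ray serializing the Trainer and copying it to a remote actor. It shows that `__init__` runs on the local object while `setup()` and `training_loop()` run on the copy.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyResult:
    """Hypothetical stand-in for the Result object returned by fit()."""
    metrics: Dict[str, Any]
    history: List[Dict[str, Any]] = field(default_factory=list)


class ToyTrainer:
    """Hypothetical trainer that mimics the order of calls made by fit()."""

    def __init__(self, num_epochs: int = 3) -> None:
        # Runs locally: keep this lightweight (store configuration only).
        self.num_epochs = num_epochs
        self.calls: List[str] = ["__init__"]
        self._reported: List[Dict[str, Any]] = []

    def setup(self) -> None:
        # Heavyweight setup (e.g. building a model) belongs here.
        self.calls.append("setup")
        self.weight = 0.0

    def training_loop(self) -> None:
        self.calls.append("training_loop")
        for epoch in range(self.num_epochs):
            self.weight += 1.0
            self.report({"epoch": epoch, "weight": self.weight})

    def report(self, metrics: Dict[str, Any]) -> None:
        # Stand-in for train.report(): record intermediate results.
        self._reported.append(dict(metrics))

    def fit(self) -> ToyResult:
        # deepcopy stands in for serializing the trainer to a remote actor.
        remote = copy.deepcopy(self)
        remote.setup()
        remote.training_loop()
        self.remote_calls = remote.calls
        # The last reported metrics become the final result.
        return ToyResult(metrics=remote._reported[-1], history=remote._reported)


trainer = ToyTrainer(num_epochs=3)
result = trainer.fit()
print(trainer.calls)         # ['__init__']
print(trainer.remote_calls)  # ['__init__', 'setup', 'training_loop']
print(result.metrics)        # {'epoch': 2, 'weight': 3.0}
```

Because setup() and training_loop() run on the copy, the local trainer never does the heavyweight work, which is why it belongs in setup() rather than __init__.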

How do I create a new Trainer?

Subclass ray.train.trainer.BaseTrainer, override the training_loop method, and optionally override setup.

import torch

from ray.train.trainer import BaseTrainer
from ray import train

class MyPytorchTrainer(BaseTrainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.datasets["train"] holds the dataset passed to the constructor.
        dataset = self.datasets["train"]

        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            torch_ds = dataset.iter_torch_batches(
                dtypes=torch.float, batch_size=2
            )
            for batch in torch_ds:
                X = torch.unsqueeze(batch["x"], 1)
                y = torch.unsqueeze(batch["y"], 1)
                # Compute prediction error
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Use train.report() to report intermediate results.
            train.report({"loss": loss, "epoch": epoch_idx})

# Initialize the Trainer, then call Trainer.fit().
import ray
train_dataset = ray.data.from_items(
    [{"x": i, "y": i} for i in range(10)])
my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
result = my_trainer.fit()

Parameters:

  • scaling_config – Configuration for how to scale training.

  • run_config – Configuration for the execution of the training run.

  • datasets – Any Datasets to use for training. Use the key “train” to denote which dataset is the training dataset.

  • metadata – Dict that should be made available via train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.

  • resume_from_checkpoint – A checkpoint to resume training from.
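The metadata argument must be JSON-serializable. One way to catch a bad value before launching a run is to round-trip it through Python's json module. `check_metadata` below is a hypothetical helper, not part of Ray, and Ray's own validation may differ.

```python
import json
from typing import Any, Dict


def check_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: fail early if metadata is not JSON-serializable.

    Round-trips the dict through json so the caller sees exactly what a
    JSON encoder preserves (for example, tuples come back as lists).
    """
    try:
        return json.loads(json.dumps(metadata))
    except TypeError as err:
        raise ValueError(f"metadata must be JSON-serializable: {err}") from err


ok = check_metadata({"model": "linear", "lr": 0.1, "layers": (1, 1)})
print(ok)  # {'model': 'linear', 'lr': 0.1, 'layers': [1, 1]}

try:
    check_metadata({"classes": {"cat", "dog"}})  # sets are not JSON types
except ValueError as err:
    print(err)
```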

DeveloperAPI: This API may change across minor Ray releases.

Methods:

  • as_trainable – Converts self to a tune.Trainable class.

  • can_restore – Checks whether a given directory contains a restorable Train experiment.

  • fit – Runs training.

  • restore – Restores a Train experiment from a previously interrupted or failed run.

  • setup – Called during fit() to perform initial setup on the Trainer.

  • training_loop – Loop called by fit() to run training and report results to Tune.