class ray.train.trainer.BaseTrainer(*args, **kwargs)

Bases: abc.ABC

Defines interface for distributed training on Ray.

Note: BaseTrainer cannot be instantiated directly. Use one of its subclasses instead.

How does a trainer work?

  • First, initialize the Trainer. The initialization runs locally, so heavyweight setup should not be done in __init__.

  • Then, when you call trainer.fit(), the Trainer is serialized and copied to a remote Ray actor. The following methods are then called in sequence on the remote actor:

  • trainer.setup(): Any heavyweight Trainer setup should be specified here.

  • trainer.preprocess_datasets(): The provided ray.data.Datasets are preprocessed with the provided ray.data.Preprocessor.

  • trainer.train_loop(): Executes the main training logic.

  • Calling trainer.fit() returns a ray.result.Result object, from which you can access metrics from your training run, as well as any checkpoints that were saved.
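As an illustration only (not Ray's actual implementation; ToyTrainer and its methods are hypothetical), the call order described above can be sketched in plain Python:

```python
# Toy sketch of the Trainer lifecycle: __init__ runs locally and stays
# lightweight; fit() triggers setup(), preprocess_datasets(), and
# training_loop() in sequence (on a remote actor, in real Ray).
class ToyTrainer:
    def __init__(self, datasets):
        # Runs locally: only store configuration here.
        self.datasets = datasets
        self.calls = []

    def setup(self):
        # Heavyweight setup (e.g. building a model) belongs here.
        self.calls.append("setup")

    def preprocess_datasets(self):
        self.calls.append("preprocess_datasets")

    def training_loop(self):
        self.calls.append("training_loop")

    def fit(self):
        # In Ray these run on the remote actor, in this order.
        self.setup()
        self.preprocess_datasets()
        self.training_loop()
        return {"calls": self.calls}


result = ToyTrainer(datasets={"train": [1, 2, 3]}).fit()
print(result["calls"])
```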

How do I create a new Trainer?

Subclass ray.train.trainer.BaseTrainer, override the training_loop method, and optionally setup.

import torch

from ray.train.trainer import BaseTrainer
from ray import tune
from ray.air import session

class MyPytorchTrainer(BaseTrainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.datasets["train"] has already been
        # preprocessed by self.preprocessor
        dataset = self.datasets["train"]

        torch_ds = dataset.iter_torch_batches(dtypes=torch.float)
        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            for batch in torch_ds:
                X, y = torch.unsqueeze(batch["x"], 1), batch["y"]
                # Compute prediction error
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Use Tune functions to report intermediate
            # results.
            session.report({"loss": loss, "epoch": epoch_idx})

How do I use an existing Trainer or one of my custom Trainers?

Initialize the Trainer, and call trainer.fit():

import ray
train_dataset = ray.data.from_items(
    [{"x": i, "y": i} for i in range(3)])
my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
result = my_trainer.fit()
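Because fit() serializes the Trainer before copying it to a remote actor, everything stored on the instance must be serializable. A quick local sanity check (illustrative only: this uses pickle directly, while Ray uses its own cloudpickle-based serializer, and TinyTrainer is a hypothetical stand-in):

```python
import pickle

# Hypothetical minimal trainer-like object holding only plain-Python state.
class TinyTrainer:
    def __init__(self, datasets):
        self.datasets = datasets

trainer = TinyTrainer(datasets={"train": [{"x": i, "y": i} for i in range(3)]})

# Round-trip through pickle: if this fails, shipping the trainer to a
# remote actor would fail for the same reason.
restored = pickle.loads(pickle.dumps(trainer))
print(restored.datasets["train"][0])
```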
Parameters

  • scaling_config – Configuration for how to scale training.

  • run_config – Configuration for the execution of the training run.

  • datasets – Any Ray Datasets to use for training. Use the key “train” to denote which dataset is the training dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided.

  • preprocessor – A preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A checkpoint to resume training from.
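The preprocessor semantics described for the datasets parameter (fit on the "train" dataset only if not already fitted, then transform every dataset) can be sketched as follows. This is a hedged illustration: ToyPreprocessor and the standalone preprocess_datasets function below are hypothetical and only mimic the fit/transform contract.

```python
# Hypothetical stand-in for ray.data.Preprocessor: fit() learns a statistic
# from the training data, transform() applies it to any dataset.
class ToyPreprocessor:
    def __init__(self):
        self.fitted = False
        self.mean = None

    def fit(self, dataset):
        xs = [row["x"] for row in dataset]
        self.mean = sum(xs) / len(xs)
        self.fitted = True
        return self

    def transform(self, dataset):
        return [{**row, "x": row["x"] - self.mean} for row in dataset]


def preprocess_datasets(datasets, preprocessor):
    # Fit only on the "train" dataset, and only if not already fitted.
    if preprocessor is not None and not preprocessor.fitted:
        preprocessor.fit(datasets["train"])
    # All datasets are transformed by the preprocessor if one is provided.
    if preprocessor is not None:
        return {k: preprocessor.transform(v) for k, v in datasets.items()}
    return datasets


datasets = {"train": [{"x": 0.0}, {"x": 2.0}], "valid": [{"x": 1.0}]}
out = preprocess_datasets(datasets, ToyPreprocessor())
print(out["train"], out["valid"])
```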

DeveloperAPI: This API may change across minor Ray releases.



Methods

as_trainable()
    Convert self to a tune.Trainable class.

can_restore(path)
    Checks whether a given directory contains a restorable Train experiment.

fit()
    Runs training.

preprocess_datasets()
    Called during fit() to preprocess dataset attributes with preprocessor.

restore(path[, datasets, preprocessor, ...])
    Restores a Train experiment from a previously interrupted/failed run.

setup()
    Called during fit() to perform initial setup on the Trainer.

training_loop()
    Loop called by fit() to run training and report results to Tune.