class ray.train.trainer.BaseTrainer(*args, **kwargs)

Bases: abc.ABC

Defines the interface for distributed training on Ray.

Note: BaseTrainer cannot be instantiated directly; use one of its subclasses instead.
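Because BaseTrainer derives from abc.ABC, Python itself refuses to instantiate a class whose abstract methods are still unimplemented. The plain-Python toy below sketches that mechanism without Ray. ToyBaseTrainer and ToyTrainer are hypothetical names, and the sketch assumes, in line with the subclassing guidance in this page, that training_loop is the method a subclass must supply.

```python
import abc


class ToyBaseTrainer(abc.ABC):
    """Hypothetical stand-in for an abstract trainer base class (not Ray's code)."""

    def __init__(self, datasets=None):
        # Lightweight: only store configuration here.
        self.datasets = datasets or {}

    @abc.abstractmethod
    def training_loop(self):
        """Subclasses must implement the main training logic."""


class ToyTrainer(ToyBaseTrainer):
    def training_loop(self):
        return "trained"


def try_instantiate(cls):
    """Return an instance, or the TypeError abc raises for abstract classes."""
    try:
        return cls()
    except TypeError as err:
        return err


base_attempt = try_instantiate(ToyBaseTrainer)
sub_attempt = try_instantiate(ToyTrainer)
print(type(base_attempt).__name__)  # TypeError
print(sub_attempt.training_loop())  # trained
```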

How does a trainer work?

  • First, initialize the Trainer. The initialization runs locally, so heavyweight setup should not be done in __init__.

  • Then, when you call trainer.fit(), the Trainer is serialized and copied to a remote Ray actor. The following methods are then called in sequence on the remote actor:

      • trainer.setup(): Any heavyweight Trainer setup should be specified here.

      • trainer.preprocess_datasets(): The provided ray.data.Dataset objects are preprocessed with the provided ray.data.Preprocessor.

      • trainer.training_loop(): Executes the main training logic.

  • Calling trainer.fit() returns a ray.air.result.Result object, from which you can access the metrics of your training run as well as any checkpoints that were saved.
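The lifecycle described above can be modeled in plain Python. In this sketch a deep copy stands in for serializing the trainer and shipping it to a remote Ray actor, and a report() method stands in for session.report(). ToyTrainer, its calls/reports attributes, and the dict returned by fit() are all hypothetical; they only illustrate the call order and where state lives, not Ray's actual Result object.

```python
import copy


class ToyTrainer:
    """Hypothetical model of the fit() lifecycle (not Ray's code)."""

    def __init__(self, datasets):
        # Runs locally: only store lightweight, copyable configuration here.
        self.datasets = datasets
        self.calls = []    # order in which the lifecycle hooks ran
        self.reports = []  # every metrics dict passed to report()

    def setup(self):
        self.calls.append("setup")

    def preprocess_datasets(self):
        self.calls.append("preprocess_datasets")

    def training_loop(self):
        self.calls.append("training_loop")
        for epoch in range(3):
            self.report({"epoch": epoch, "num_rows": len(self.datasets["train"])})

    def report(self, metrics):
        # Stand-in for session.report(): record a copy of the metrics dict.
        self.reports.append(dict(metrics))

    def fit(self):
        # A deep copy stands in for serializing the trainer to a remote actor.
        remote = copy.deepcopy(self)
        remote.setup()
        remote.preprocess_datasets()
        remote.training_loop()
        # Hypothetical result: last reported metrics plus the full history.
        return {
            "calls": remote.calls,
            "metrics": remote.reports[-1],
            "history": remote.reports,
        }


trainer = ToyTrainer(datasets={"train": [{"x": i, "y": i} for i in range(3)]})
result = trainer.fit()
print(result["calls"])    # ['setup', 'preprocess_datasets', 'training_loop']
print(result["metrics"])  # {'epoch': 2, 'num_rows': 3}
print(trainer.calls)      # [] (the hooks ran on the copy, not the local object)
```

Because the hooks run on the copy, anything created in setup() never has to be serialized, which is one reason heavyweight state belongs there rather than in __init__.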

How do I create a new Trainer?

Subclass ray.train.trainer.BaseTrainer, override the training_loop method, and optionally override setup.

import torch

from ray.train.trainer import BaseTrainer
from ray.air import session

class MyPytorchTrainer(BaseTrainer):
    def setup(self):
        self.model = torch.nn.Linear(1, 1)
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=0.1)

    def training_loop(self):
        # You can access any Trainer attributes directly in this method.
        # self.datasets["train"] has already been
        # preprocessed by self.preprocessor
        dataset = self.datasets["train"]

        loss_fn = torch.nn.MSELoss()

        for epoch_idx in range(10):
            loss = 0
            num_batches = 0
            # Create a fresh batch iterator each epoch so that every
            # epoch sees the full dataset.
            torch_ds = dataset.iter_torch_batches(dtypes=torch.float)
            for batch in torch_ds:
                # Reshape both columns to (batch_size, 1) to match the model.
                X = torch.unsqueeze(batch["x"], 1)
                y = torch.unsqueeze(batch["y"], 1)
                # Compute prediction error
                pred = self.model(X)
                batch_loss = loss_fn(pred, y)

                # Backpropagation
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                loss += batch_loss.item()
                num_batches += 1
            loss /= num_batches

            # Use session.report to report intermediate
            # results.
            session.report({"loss": loss, "epoch": epoch_idx})
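For readers who want the arithmetic without Torch or Ray, here is a plain-Python sketch of what a loop of this shape computes for a one-feature linear model: a per-batch mean squared error, a plain SGD step on the weight and bias, and an epoch loss averaged over batches. The helper names batches and train_linear and the zero initial weights are hypothetical (torch.nn.Linear initializes randomly), and batch_size=256 is simply assumed large enough that the three example rows form one batch.

```python
def batches(rows, batch_size):
    """Yield lists of rows; a fresh generator is created on every call."""
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


def train_linear(rows, epochs=10, lr=0.1, batch_size=256):
    # Hypothetical deterministic start; torch.nn.Linear would initialize randomly.
    w, b = 0.0, 0.0
    epoch_losses = []
    for _ in range(epochs):
        loss, num_batches = 0.0, 0
        for batch in batches(rows, batch_size):  # new iterator every epoch
            n = len(batch)
            errors = [(w * r["x"] + b) - r["y"] for r in batch]
            batch_loss = sum(e * e for e in errors) / n  # mean squared error
            # Gradients of the mean squared error with respect to w and b.
            grad_w = sum(2 * e * r["x"] for e, r in zip(errors, batch)) / n
            grad_b = sum(2 * e for e in errors) / n
            # Plain SGD update (no momentum).
            w -= lr * grad_w
            b -= lr * grad_b
            loss += batch_loss
            num_batches += 1
        epoch_losses.append(loss / num_batches)  # mean batch loss for the epoch
    return w, b, epoch_losses


rows = [{"x": i, "y": i} for i in range(3)]
w, b, losses = train_linear(rows)
print(round(losses[0], 4))  # 1.6667 (loss before any update, since w = b = 0)
```

As in the loop above, each batch's loss is recorded before the parameter update, so the first epoch reports the loss of the initial weights.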

How do I use an existing Trainer or one of my custom Trainers?

Initialize the Trainer and call Trainer.fit().

import ray
train_dataset = ray.data.from_items(
    [{"x": i, "y": i} for i in range(3)])
my_trainer = MyPytorchTrainer(datasets={"train": train_dataset})
result = my_trainer.fit()

Parameters

  • scaling_config – Configuration for how to scale training.

  • run_config – Configuration for the execution of the training run.

  • datasets – Any Ray Datasets to use for training. Use the key "train" to denote which dataset is the training dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided.

  • preprocessor – A preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A checkpoint to resume training from.
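The datasets and preprocessor parameters interact in a fixed order: the preprocessor is fit on the "train" dataset only if it has not already been fit, and then every dataset is transformed. The dependency-free sketch below illustrates that rule. ToyScaler is a hypothetical standardizing preprocessor used in place of a ray.data.Preprocessor, and lists of dicts stand in for Ray Datasets; neither reproduces Ray's actual classes.

```python
class ToyScaler:
    """Hypothetical preprocessor that standardizes one column (not Ray's API)."""

    def __init__(self, column):
        self.column = column
        self.fitted = False

    def fit(self, rows):
        values = [r[self.column] for r in rows]
        self.mean = sum(values) / len(values)
        var = sum((v - self.mean) ** 2 for v in values) / len(values)
        self.std = var ** 0.5 or 1.0  # guard against a zero-variance column
        self.fitted = True
        return self

    def transform(self, rows):
        return [{**r, self.column: (r[self.column] - self.mean) / self.std} for r in rows]


def preprocess_datasets(datasets, preprocessor):
    """Fit on "train" if not already fit, then transform every dataset."""
    if preprocessor is None:
        return datasets
    if not preprocessor.fitted:
        preprocessor.fit(datasets["train"])
    return {name: preprocessor.transform(rows) for name, rows in datasets.items()}


datasets = {
    "train": [{"x": float(i)} for i in range(3)],  # x = 0, 1, 2
    "valid": [{"x": 4.0}],
}
out = preprocess_datasets(datasets, ToyScaler("x"))
print([r["x"] for r in out["train"]])
print(out["valid"])
```

Note that the "valid" rows are scaled with statistics learned from "train" only, which is the point of fitting on the training dataset alone.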

DeveloperAPI: This API may change across minor Ray releases.

as_trainable()

Convert self to a tune.Trainable class.

fit()

Runs training.

preprocess_datasets()

Called during fit() to preprocess dataset attributes with preprocessor.

setup()

Called during fit() to perform initial setup on the Trainer.

training_loop()

Loop called by fit() to run training and report results to Tune.