class ray.train.lightning.LightningTrainer(*args, **kwargs)

Bases: ray.train.torch.torch_trainer.TorchTrainer

A Trainer for data parallel PyTorch Lightning training.

This Trainer runs the pytorch_lightning.Trainer.fit() method on multiple Ray Actors. The training is carried out in a distributed fashion through PyTorch DDP. These actors already have the necessary Torch process group configured for distributed data parallel training. We will support more distributed training strategies in the future.
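Conceptually, DDP keeps a full model replica on every worker. Each worker computes gradients on its own shard of the batch, and the gradients are averaged with an all-reduce so that every replica applies the same update. The pure-Python sketch below models that averaging for a one-parameter linear model. The helper names (`mse_grad`, `all_reduce_mean`, `ddp_step`) are hypothetical. Real PyTorch DDP performs the all-reduce over a communication backend such as Gloo or NCCL, not in Python.

```python
# Pure-Python model of data-parallel gradient averaging (illustrative only).

def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w over one shard."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def all_reduce_mean(values):
    """Stand-in for an all-reduce: every worker ends up holding the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def ddp_step(w, shards, lr):
    # Each worker computes a local gradient on its own shard of the batch.
    local_grads = [mse_grad(w, xs, ys) for xs, ys in shards]
    # Synchronize gradients so all replicas apply the same update.
    synced = all_reduce_mean(local_grads)
    return [w - lr * g for g in synced]


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [3.0 * x for x in xs]  # ground truth: y = 3x
num_workers = 4
# Split the global batch into equal contiguous shards, one per worker.
size = len(xs) // num_workers
shards = [(xs[i * size:(i + 1) * size], ys[i * size:(i + 1) * size])
          for i in range(num_workers)]

new_weights = ddp_step(0.0, shards, lr=0.01)       # one weight per replica
full_batch_weight = 0.0 - 0.01 * mse_grad(0.0, xs, ys)  # single-process update
print(new_weights, full_batch_weight)
```

With equal-sized shards, the mean of the per-shard gradients equals the full-batch gradient. Every replica therefore applies the same update it would have computed on the whole batch alone.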

The training function run on every Actor first initializes an instance of the user-provided lightning_module class, a subclass of pytorch_lightning.LightningModule, using the arguments provided in LightningConfigBuilder.module().
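The following is a minimal sketch of this step. It assumes the builder simply records the class and its keyword arguments. `ModuleSpec`, `TinyClassifier`, and `worker_train_func` are hypothetical names, and no torch or Ray is required. The point is that each worker constructs its own module instance from the same class and arguments.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Type


@dataclass
class ModuleSpec:
    """Hypothetical record of what a .module() call captures:
    the module class plus the keyword arguments for its constructor."""
    cls: Type[Any]
    init_kwargs: Dict[str, Any] = field(default_factory=dict)


class TinyClassifier:
    """Stand-in for a pl.LightningModule subclass (no torch needed here)."""
    def __init__(self, lr, feature_dim):
        self.lr = lr
        self.feature_dim = feature_dim


def worker_train_func(spec: ModuleSpec):
    # Each worker builds its own instance from the recorded class and kwargs.
    return spec.cls(**spec.init_kwargs)


spec = ModuleSpec(cls=TinyClassifier, init_kwargs={"lr": 1e-3, "feature_dim": 128})
models = [worker_train_func(spec) for _ in range(4)]  # one per simulated worker
```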

For data ingestion, the LightningTrainer will then either convert the Ray Dataset shards to a pytorch_lightning.LightningDataModule, or directly use the datamodule or dataloaders if provided by users.
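The branch between the two ingestion paths can be sketched as follows. This is a hypothetical model, not Ray's implementation. `resolve_train_data` and `shard` are invented names, and the round-robin split stands in for Ray Data's own sharding. The sketch also assumes the two paths are used exclusively and checks `datasets` first.

```python
from typing import Any, Dict, List, Optional


def shard(records: List[Any], num_workers: int, rank: int) -> List[Any]:
    """Hypothetical round-robin split; Ray Data's real sharding differs."""
    return records[rank::num_workers]


def resolve_train_data(rank: int, num_workers: int,
                       datasets: Optional[Dict[str, List[Any]]],
                       fit_params: Dict[str, Any]):
    if datasets and "train" in datasets:
        # Path 1: each worker gets its own shard of the "train" dataset
        # (a plain list here instead of a LightningDataModule).
        return "ray_dataset_shard", shard(datasets["train"], num_workers, rank)
    # Path 2: use what the user passed for Trainer.fit().
    if "datamodule" in fit_params:
        return "datamodule", fit_params["datamodule"]
    return "dataloaders", fit_params.get("train_dataloaders")


records = list(range(10))
kind, data = resolve_train_data(1, 4, {"train": records}, {})
kind2, data2 = resolve_train_data(0, 4, None, {"train_dataloaders": "train_loader"})
```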

The trainer also creates a ModelCheckpoint callback based on the configuration provided in LightningConfigBuilder.checkpointing(). In addition to checkpointing, this callback also calls session.report() to report the latest metrics along with the checkpoint to the AIR session.
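The interplay between checkpointing and reporting can be modelled with the standard library alone. In this hypothetical sketch, `ReportingCheckpointCallback` saves a checkpoint only when the monitored metric improves, roughly like `ModelCheckpoint(save_top_k=1)`. Every save passes the metrics and the checkpoint to the same `report` call, which stands in for `session.report()`.

```python
from typing import Callable, Dict, List, Optional, Tuple


class ReportingCheckpointCallback:
    """Hypothetical, simplified model: report metrics together with each
    checkpoint that is saved. Not Ray's or Lightning's implementation."""

    def __init__(self, monitor: str, mode: str,
                 report: Callable[[Dict[str, float], str], None]):
        self.monitor = monitor
        self.mode = mode
        self.report = report
        self.best: Optional[float] = None

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        return value < self.best if self.mode == "min" else value > self.best

    def on_validation_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        value = metrics[self.monitor]
        if self._improved(value):
            self.best = value
            checkpoint = f"checkpoint_epoch={epoch}"  # stand-in for a saved file
            # Metrics and the checkpoint are reported in the same call.
            self.report(dict(metrics), checkpoint)


reported: List[Tuple[Dict[str, float], str]] = []
cb = ReportingCheckpointCallback("ptl/val_loss", "min",
                                 lambda m, c: reported.append((m, c)))
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.5]):
    cb.on_validation_end(epoch, {"ptl/val_loss": loss})
```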

For logging, users can continue to use Lightning’s native loggers, such as WandbLogger, TensorBoardLogger, etc. LightningTrainer also logs the latest metrics to the trial directory whenever a new checkpoint is saved.
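The sketch below illustrates metrics being written alongside checkpoints. It appends one JSON line per saved checkpoint to a file in a trial directory. The function name and the `metrics.jsonl` file name are hypothetical and do not describe the files Ray actually writes.

```python
import json
import os
import tempfile


def log_metrics_on_checkpoint(trial_dir: str, metrics: dict,
                              filename: str = "metrics.jsonl") -> str:
    """Append one JSON line per saved checkpoint (hypothetical file name)."""
    path = os.path.join(trial_dir, filename)
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(metrics, sort_keys=True) + "\n")
    return path


with tempfile.TemporaryDirectory() as trial_dir:
    # Simulate two checkpoint saves, each triggering one metrics line.
    for epoch, loss in [(0, 0.9), (1, 0.7)]:
        path = log_metrics_on_checkpoint(trial_dir, {"epoch": epoch, "ptl/val_loss": loss})
    with open(path, encoding="utf-8") as f:
        rows = [json.loads(line) for line in f]
```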

Then, the training function initializes an instance of pl.Trainer using the arguments provided in LightningConfigBuilder.trainer(), and runs pytorch_lightning.Trainer.fit() with the arguments provided in LightningConfigBuilder.fit_params().
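In the usage example that follows, `.trainer()` receives `pl.Trainer` constructor arguments (`max_epochs`, `accelerator`), while `.fit_params()` receives `fit()` arguments (`train_dataloaders`, `val_dataloaders`). This two-step flow can be modelled with a hypothetical `RecordingTrainer` stand-in for `pl.Trainer` and an invented `run_worker` helper. The sketch leaves out the DDP strategy and checkpoint callback that LightningTrainer sets up itself.

```python
from typing import Any, Dict


class RecordingTrainer:
    """Hypothetical stand-in for pl.Trainer that records how it was called."""
    def __init__(self, **init_kwargs: Any):
        self.init_kwargs = init_kwargs
        self.model: Any = None
        self.fit_kwargs: Dict[str, Any] = {}

    def fit(self, model: Any, **fit_kwargs: Any) -> None:
        self.model = model
        self.fit_kwargs = fit_kwargs


def run_worker(module: Any, trainer_init_config: Dict[str, Any],
               fit_params: Dict[str, Any]) -> RecordingTrainer:
    # Constructor arguments and fit() arguments are applied in two steps.
    trainer = RecordingTrainer(**trainer_init_config)
    trainer.fit(module, **fit_params)
    return trainer


t = run_worker("my_module",
               {"max_epochs": 3, "accelerator": "cpu"},
               {"train_dataloaders": "train_loader", "val_dataloaders": "val_loader"})
```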

Example:

```python
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from ray.air.config import ScalingConfig
from ray.train.lightning import LightningTrainer, LightningConfigBuilder

class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr, feature_dim):
        super().__init__()
        self.fc1 = torch.nn.Linear(28 * 28, feature_dim)
        self.fc2 = torch.nn.Linear(feature_dim, 10)
        self.lr = lr
        # Recent torchmetrics releases require the task and number of classes.
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        # forward() returns raw logits, so use cross_entropy; nll_loss
        # expects log-probabilities.
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    # validation_epoch_end exists in PyTorch Lightning 1.x (removed in 2.0).
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# Prepare MNIST Datasets
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
mnist_train = MNIST(
    './data', train=True, download=True, transform=transform
)
mnist_val = MNIST(
    './data', train=False, download=True, transform=transform
)

# Take small subsets for smoke test
# Please remove these two lines if you want to train the full dataset
mnist_train = Subset(mnist_train, range(1000))
mnist_val = Subset(mnist_val, range(500))

train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=128, shuffle=False)

# Module constructor args, pl.Trainer args, and Trainer.fit() args
lightning_config = (
    LightningConfigBuilder()
    .module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)
    .trainer(max_epochs=3, accelerator="cpu")
    .fit_params(train_dataloaders=train_loader, val_dataloaders=val_loader)
    .build()
)

# Four CPU-only workers, one CPU each
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=False, resources_per_worker={"CPU": 1}
)
trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
)
result = trainer.fit()
```
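The chained `LightningConfigBuilder` calls in the example follow a fluent-builder pattern. Each call records its keyword arguments and returns the builder, and `build()` returns the finished dictionary. `MiniConfigBuilder` below is a hypothetical, much-simplified model of that pattern. Its method names mirror the example, but its dictionary keys are invented and do not match what Ray's `build()` produces.

```python
from typing import Any, Dict


class MiniConfigBuilder:
    """Hypothetical fluent builder: each method stores kwargs and returns
    self so calls can be chained; build() returns a plain dict."""

    def __init__(self) -> None:
        self._config: Dict[str, Any] = {
            "module_class": None,
            "module_init_config": {},
            "trainer_init_config": {},
            "fit_params": {},
        }

    def module(self, cls: type, **kwargs: Any) -> "MiniConfigBuilder":
        self._config["module_class"] = cls
        self._config["module_init_config"].update(kwargs)
        return self

    def trainer(self, **kwargs: Any) -> "MiniConfigBuilder":
        self._config["trainer_init_config"].update(kwargs)
        return self

    def fit_params(self, **kwargs: Any) -> "MiniConfigBuilder":
        self._config["fit_params"].update(kwargs)
        return self

    def build(self) -> Dict[str, Any]:
        # Copy the nested dicts so reusing the builder cannot mutate a built
        # config; the values inside them (e.g. dataloaders) stay shared.
        return {key: dict(value) if isinstance(value, dict) else value
                for key, value in self._config.items()}


class DummyModule:
    pass


builder = MiniConfigBuilder()
config = (
    builder.module(cls=DummyModule, lr=1e-3, feature_dim=128)
    .trainer(max_epochs=3, accelerator="cpu")
    .fit_params(train_dataloaders="train_loader", val_dataloaders="val_loader")
    .build()
)
builder.trainer(max_epochs=10)  # a later change does not affect `config`
```

Copying only the nested dictionaries keeps a built config stable if the builder is reused. Heavyweight values such as dataloaders are shared rather than duplicated.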

Parameters:

  • lightning_config – Configuration for setting up the PyTorch Lightning Trainer. You can set up the configuration with LightningConfigBuilder and generate this config dictionary through LightningConfigBuilder.build().

  • torch_config – Configuration for setting up the PyTorch backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer. Same as in TorchTrainer.

  • scaling_config – Configuration for how to scale data parallel training.

  • dataset_config – Configuration for dataset ingest.

  • run_config – Configuration for the execution of the training run.

  • datasets – A dictionary of Ray Datasets to use for training. Use the key “train” to denote which dataset is the training dataset and (optionally) key “val” to denote the validation dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided.

  • datasets_iter_config – Configurations for iterating over input Ray Datasets. This configuration is only valid when the datasets argument is provided to the LightningTrainer. Otherwise, LightningTrainer uses the datamodule or dataloaders passed through LightningConfigBuilder.fit_params(). For valid arguments, refer to Dataset.iter_torch_batches.

  • preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A checkpoint to resume training from.
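The `datasets` entry describes a fit-on-train, transform-everything rule for preprocessors. The sketch below models that rule with a hypothetical `StandardScalerSketch`, which is not one of Ray's `ray.data.preprocessors` classes, and uses plain lists instead of Ray Datasets. `prepare_datasets` is likewise an invented helper.

```python
from statistics import mean, pstdev
from typing import Dict, List, Optional


class StandardScalerSketch:
    """Hypothetical fit/transform preprocessor: statistics come only from
    the data passed to fit()."""

    def __init__(self) -> None:
        self.mean_: Optional[float] = None
        self.std_: Optional[float] = None

    def fit(self, values: List[float]) -> "StandardScalerSketch":
        self.mean_ = mean(values)
        self.std_ = pstdev(values) or 1.0  # guard against zero variance
        return self

    def transform(self, values: List[float]) -> List[float]:
        return [(v - self.mean_) / self.std_ for v in values]


def prepare_datasets(datasets: Dict[str, List[float]],
                     preprocessor: StandardScalerSketch) -> Dict[str, List[float]]:
    # Fit on the "train" split only if not already fitted, then apply the
    # same fitted transform to every dataset, including "val".
    if preprocessor.mean_ is None:
        preprocessor.fit(datasets["train"])
    return {name: preprocessor.transform(ds) for name, ds in datasets.items()}


prepared = prepare_datasets({"train": [2.0, 4.0, 6.0], "val": [4.0, 8.0]},
                            StandardScalerSketch())
```

The statistics come only from "train", so the validation values are scaled with the training mean and standard deviation. This is why the validation value 8.0, which lies outside the training range, maps beyond the largest transformed training value.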

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods:

  • as_trainable() – Convert self to a tune.Trainable class.

  • can_restore(path) – Checks whether a given directory contains a restorable Train experiment.

  • fit() – Runs training.

  • get_dataset_config() – Return a copy of this Trainer's final dataset configs.

  • restore(path[, datasets, preprocessor, ...]) – Restores a LightningTrainer from a previously interrupted/failed run.

  • setup() – Called during fit() to perform initial setup on the Trainer.