class ray.train.lightning.LightningTrainer(*args, **kwargs)

Bases: ray.train.torch.torch_trainer.TorchTrainer

A Trainer for data parallel PyTorch Lightning training.

This Trainer runs the pytorch_lightning.Trainer.fit() method on multiple Ray Actors. The training is carried out in a distributed fashion through PyTorch DDP. These actors already have the necessary Torch process group configured for distributed data parallel training. We will support more distributed training strategies in the future.

The training function run on every Actor will first initialize an instance of the user-provided lightning_module class, which is a subclass of pytorch_lightning.LightningModule, using the arguments provided in LightningConfigBuilder.module().

For data ingestion, the LightningTrainer will then either convert the Ray Dataset shards into a pytorch_lightning.LightningDataModule, or directly use the datamodule or dataloaders if the user provides them.
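For example, a minimal sketch of the datamodule path; MNISTDataModule here is a hypothetical user-defined pytorch_lightning.LightningDataModule:

from ray.train.lightning import LightningConfigBuilder

# Sketch only: an existing LightningDataModule is passed straight
# through to pytorch_lightning.Trainer.fit(). MNISTDataModule is a
# hypothetical user-defined class.
datamodule = MNISTDataModule(batch_size=128)
builder = LightningConfigBuilder().fit_params(datamodule=datamodule)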

The trainer also creates a ModelCheckpoint callback based on the configuration provided in LightningConfigBuilder.checkpointing(). In addition to checkpointing, this callback also calls session.report() to report the latest metrics along with the checkpoint to the AIR session.
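For example, a minimal sketch of the checkpointing configuration; the monitored metric and save_top_k value below are arbitrary illustrative choices, and the keyword arguments are forwarded to pytorch_lightning.callbacks.ModelCheckpoint:

from ray.train.lightning import LightningConfigBuilder

# Sketch: these kwargs are forwarded to pl.callbacks.ModelCheckpoint.
builder = LightningConfigBuilder().checkpointing(
    monitor="val_loss",  # a metric logged by the LightningModule
    mode="min",
    save_top_k=2,
)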

For logging, users can continue to use Lightning’s native loggers, such as WandbLogger, TensorBoardLogger, etc. LightningTrainer will also log the latest metrics to the trial directory whenever a new checkpoint is saved.
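For example, a minimal sketch of attaching a native logger; the choice of TensorBoardLogger and its save_dir/name values are illustrative:

from pytorch_lightning.loggers import TensorBoardLogger
from ray.train.lightning import LightningConfigBuilder

# Sketch: a native Lightning logger is passed as a regular pl.Trainer
# argument through LightningConfigBuilder.trainer().
builder = LightningConfigBuilder().trainer(
    max_epochs=3,
    logger=TensorBoardLogger(save_dir="logs/", name="mnist"),
)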

Then, the training function will initialize an instance of pl.Trainer using the arguments provided in LightningConfigBuilder.trainer() and run pytorch_lightning.Trainer.fit() with the arguments provided in LightningConfigBuilder.fit_params().


import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from ray.air.config import ScalingConfig
from ray.train.lightning import LightningTrainer, LightningConfigBuilder

class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr, feature_dim):
        super(MNISTClassifier, self).__init__()
        self.fc1 = torch.nn.Linear(28 * 28, feature_dim)
        self.fc2 = torch.nn.Linear(feature_dim, 10)
        self.lr = lr
        self.accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.cross_entropy(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# Prepare MNIST Datasets
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
mnist_train = MNIST(
    './data', train=True, download=True, transform=transform
)
mnist_val = MNIST(
    './data', train=False, download=True, transform=transform
)

# Take small subsets for smoke test
# Please remove these two lines if you want to train the full dataset
mnist_train = Subset(mnist_train, range(1000))
mnist_val = Subset(mnist_val, range(500))

train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=128, shuffle=False)

lightning_config = (
    LightningConfigBuilder()
    .module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)
    .trainer(max_epochs=3, accelerator="cpu")
    .fit_params(train_dataloaders=train_loader, val_dataloaders=val_loader)
    .build()
)

scaling_config = ScalingConfig(
    num_workers=4, use_gpu=False, resources_per_worker={"CPU": 1}
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
)
result = trainer.fit()

Parameters
  • lightning_config – Configuration for setting up the PyTorch Lightning Trainer. You can set up the configuration with LightningConfigBuilder, and generate this config dictionary through LightningConfigBuilder.build().

  • torch_config – Configuration for setting up the PyTorch backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer. Same as in TorchTrainer.

  • scaling_config – Configuration for how to scale data parallel training.

  • dataset_config – Configuration for dataset ingest.

  • run_config – Configuration for the execution of the training run.

  • datasets – A dictionary of Ray Datasets to use for training. Use the key “train” to denote which dataset is the training dataset and (optionally) key “val” to denote the validation dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided.

  • datasets_iter_config – Configuration for iterating over the input Ray datasets. This configuration is only valid when the datasets argument is provided to the LightningTrainer. Otherwise, LightningTrainer will use the datamodule or dataloaders specified in LightningConfigBuilder.fit_params(). For valid arguments to pass, refer to Dataset.iter_torch_batches (see the sketch after this parameter list).

  • preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A checkpoint to resume training from.
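For instance, a minimal sketch of the Ray Datasets ingestion path; the toy dataset below, as well as the lightning_config and scaling_config names, stand in for the real objects built earlier:

import ray

# Toy placeholder dataset; a real one would hold image tensors.
train_ds = ray.data.from_items([{"x": [0.0] * 784, "y": 0}] * 128)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
    datasets={"train": train_ds},
    # Keyword arguments forwarded to Dataset.iter_torch_batches on
    # each worker.
    datasets_iter_config={"batch_size": 32},
)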

PublicAPI (alpha): This API is in alpha and may change before becoming stable.



Methods

as_trainable()
    Convert self to a tune.Trainable class.

can_restore(path)
    Checks whether a given directory contains a restorable Train experiment.

fit()
    Runs training.

get_dataset_config()
    Return a copy of this Trainer's final dataset configs.

restore(path[, datasets, preprocessor, ...])
    Restores a LightningTrainer from a previously interrupted/failed run.

setup()
    Called during fit() to perform initial setup on the Trainer.
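As an illustration, a minimal sketch of resuming an interrupted run with restore(); the experiment path below is a hypothetical placeholder:

from ray.train.lightning import LightningTrainer

# Sketch: restore a previously interrupted/failed run and resume it.
# "~/ray_results/my_experiment" is a hypothetical placeholder path.
trainer = LightningTrainer.restore(path="~/ray_results/my_experiment")
result = trainer.fit()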