ray.train.lightning.LightningTrainer
class ray.train.lightning.LightningTrainer(*args, **kwargs)

Bases: ray.train.torch.torch_trainer.TorchTrainer
A Trainer for data parallel PyTorch Lightning training.

This Trainer runs the pytorch_lightning.Trainer.fit() method on multiple Ray Actors. The training is carried out in a distributed fashion through PyTorch DDP. These actors already have the necessary Torch process group configured for distributed data parallel training.

The training function executed on each actor first initializes an instance of the user-provided lightning_module class, which is a subclass of pytorch_lightning.LightningModule, using the arguments provided in LightningConfigBuilder.module().

For data ingestion, the LightningTrainer then either converts the Ray Dataset shards to a pytorch_lightning.LightningDataModule, or directly uses the datamodule if one is provided by the user.

The trainer also creates a ModelCheckpoint callback based on the configuration provided in model_checkpoint_config. Note that all other ModelCheckpoint callbacks specified in lightning_trainer_config are ignored.

For logging, users can continue to use Lightning's native loggers, such as WandbLogger, TensorBoardLogger, etc. LightningTrainer also logs the latest metrics to the trial directory whenever a new checkpoint is saved. A hedged configuration sketch for checkpointing and logging follows.
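As an illustration of the two paragraphs above, here is a minimal, hedged sketch of configuring checkpointing and a native Lightning logger through LightningConfigBuilder. The .checkpointing() helper and its keyword arguments are assumptions inferred from the model_checkpoint_config description, and MyLightningModule / train_loader are hypothetical placeholders.

from pytorch_lightning.loggers import TensorBoardLogger
from ray.train.lightning import LightningConfigBuilder

# A minimal sketch. Assumptions: MyLightningModule and train_loader are
# placeholders defined elsewhere; .checkpointing() is assumed to forward its
# keyword arguments to pytorch_lightning.callbacks.ModelCheckpoint, mirroring
# the model_checkpoint_config described above.
lightning_config = (
    LightningConfigBuilder()
    .module(cls=MyLightningModule, lr=1e-3)
    .trainer(
        max_epochs=5,
        accelerator="cpu",
        # Native Lightning logger, passed straight through to pl.Trainer.
        logger=TensorBoardLogger(save_dir="logs/"),
    )
    .checkpointing(monitor="val_loss", save_top_k=2, mode="min")
    .fit_params(train_dataloaders=train_loader)
    .build()
)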
Then, the training function initializes an instance of pl.Trainer using the arguments provided in LightningConfigBuilder.fit_params() and runs pytorch_lightning.Trainer.fit().
Example
import torch
import torch.nn.functional as F
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from ray.air.config import ScalingConfig
from ray.train.lightning import LightningTrainer, LightningConfigBuilder

class MNISTClassifier(pl.LightningModule):
    def __init__(self, lr, feature_dim):
        super(MNISTClassifier, self).__init__()
        self.fc1 = torch.nn.Linear(28 * 28, feature_dim)
        self.fc2 = torch.nn.Linear(feature_dim, 10)
        self.lr = lr
        self.accuracy = Accuracy()

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.nn.functional.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": acc}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

# Prepare MNIST Datasets
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
mnist_train = MNIST(
    './data', train=True, download=True, transform=transform
)
mnist_val = MNIST(
    './data', train=False, download=True, transform=transform
)

# Take small subsets for smoke test
# Please remove these two lines if you want to train the full dataset
mnist_train = Subset(mnist_train, range(1000))
mnist_val = Subset(mnist_val, range(500))

train_loader = DataLoader(mnist_train, batch_size=32, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=32, shuffle=False)

lightning_config = (
    LightningConfigBuilder()
    .module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)
    .trainer(max_epochs=3, accelerator="cpu")
    .fit_params(train_dataloaders=train_loader, val_dataloaders=val_loader)
    .build()
)

scaling_config = ScalingConfig(
    num_workers=4, use_gpu=False, resources_per_worker={"CPU": 1}
)

trainer = LightningTrainer(
    lightning_config=lightning_config,
    scaling_config=scaling_config,
)
results = trainer.fit()
- Parameters
lightning_config – Configuration for setting up the PyTorch Lightning Trainer. You can set up the configurations with LightningConfigBuilder, and generate this config dictionary through LightningConfigBuilder.build().
torch_config – Configuration for setting up the PyTorch backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer. Same as in TorchTrainer.
scaling_config – Configuration for how to scale data parallel training.
dataset_config – Configuration for dataset ingest.
run_config – Configuration for the execution of the training run.
datasets – A dictionary of Ray Datasets to use for training. Use the key “train” to denote which dataset is the training dataset and (optionally) key “val” to denote the validation dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided.
datasets_iter_config – Configuration for iterating over the input Ray datasets. This configuration is only valid when the datasets argument is provided to the LightningTrainer. Otherwise, LightningTrainer will use the datamodule or dataloaders specified in LightningConfig.trainer_init_config. For valid arguments to pass, please refer to Dataset.iter_torch_batches. A usage sketch follows below.
preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.
resume_from_checkpoint – A checkpoint to resume training from.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
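To illustrate the datasets, datasets_iter_config, and preprocessor parameters described above, here is a minimal, hedged sketch. The column names, dataset contents, and batch size are assumptions for illustration, and lightning_config is assumed to have been built with LightningConfigBuilder as in the example above.

import ray
from ray.air.config import ScalingConfig
from ray.train.lightning import LightningTrainer

# Hypothetical in-memory datasets; real workloads would typically read files.
train_ds = ray.data.from_items([{"x": float(i), "y": i % 2} for i in range(100)])
val_ds = ray.data.from_items([{"x": float(i), "y": i % 2} for i in range(20)])

trainer = LightningTrainer(
    lightning_config=lightning_config,  # built with LightningConfigBuilder, as in the example above
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": train_ds, "val": val_ds},
    # Forwarded to Dataset.iter_torch_batches on each worker.
    datasets_iter_config={"batch_size": 32},
)
result = trainer.fit()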
classmethod restore(path: str, datasets: Optional[Dict[str, Union[Dataset, Callable[[], Dataset]]]] = None, preprocessor: Optional[Preprocessor] = None, scaling_config: Optional[ray.air.config.ScalingConfig] = None, **kwargs) -> LightningTrainer

Restores a LightningTrainer from a previously interrupted/failed run.

See BaseTrainer.restore() for descriptions of the arguments.

- Returns
A restored instance of LightningTrainer
- Return type
LightningTrainer
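For example, a minimal usage sketch (the experiment path below is a placeholder):

from ray.train.lightning import LightningTrainer

# Placeholder path to the experiment directory of the interrupted run.
trainer = LightningTrainer.restore("/tmp/ray_results/my_lightning_experiment")
result = trainer.fit()  # continue training from the restored state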