Get Started with PyTorch Lightning#

This tutorial walks through the process of converting an existing PyTorch Lightning script to use Ray Train.

Learn how to:

  1. Configure the Lightning Trainer so that it runs distributed with Ray and on the correct CPU or GPU device.

  2. Configure your training function to report metrics and save checkpoints.

  3. Configure scaling and CPU or GPU resource requirements for a training job.

  4. Launch a distributed training job with a TorchTrainer.

Quickstart#

For reference, the final code is as follows:

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig

def train_func(config):
    # Your PyTorch Lightning training code here.

scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()
  1. train_func is the Python code that executes on each distributed training worker.

  2. ScalingConfig defines the number of distributed training workers and whether to use GPUs.

  3. TorchTrainer launches the distributed training job.

Compare a PyTorch Lightning training script with and without Ray Train.

import torch
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
import lightning.pytorch as pl

# Model, Loss, Optimizer
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)

# Data
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

# Training
model = ImageClassifier()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_dataloader)
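
For comparison, the following sketch shows the same script after applying the Ray Train changes described in the rest of this tutorial. It combines the conversions below into one runnable example; the model and data code is unchanged, and only the Ray-specific pieces are new.

import torch
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
import lightning.pytorch as pl

import ray.train.lightning
from ray.train import ScalingConfig
from ray.train.lightning import RayTrainReportCallback
from ray.train.torch import TorchTrainer

# Model, Loss, Optimizer (unchanged from the original script)
class ImageClassifier(pl.LightningModule):
    def __init__(self):
        super(ImageClassifier, self).__init__()
        self.model = resnet18(num_classes=10)
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self.forward(x)
        loss = self.criterion(outputs, y)
        self.log("loss", loss, on_step=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.001)

def train_func(config):
    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    train_data = FashionMNIST(root='./data', train=True, download=True, transform=transform)
    train_dataloader = DataLoader(train_data, batch_size=128, shuffle=True)

    # Training
    model = ImageClassifier()
    # [1] Run distributed with Ray on the devices Ray Train assigns to this worker.
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="auto",
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        # [2] Report metrics and checkpoints back to Ray Train.
        callbacks=[RayTrainReportCallback()],
    )
    # [3] Validate the Trainer configuration.
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=train_dataloader)

# [4] Configure scaling and launch the distributed training job.
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()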

Set up a training function#

First, update your training code to support distributed training. Begin by wrapping your code in a training function:

def train_func(config):
    # Your PyTorch Lightning training code here.

Each distributed training worker executes this function.
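
The config argument of this function receives whatever dictionary you pass as the TorchTrainer's train_loop_config, which is a convenient way to forward hyperparameters to every worker. A minimal sketch, with illustrative hyperparameter names:

def train_func(config):
    lr = config["lr"]                  # Forwarded from the driver process.
    max_epochs = config["max_epochs"]
    ...

trainer = TorchTrainer(
    train_func,
    train_loop_config={"lr": 1e-3, "max_epochs": 10},
    scaling_config=scaling_config,
)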

Ray Train sets up your distributed process group on each worker. You only need to make a few changes to your Lightning Trainer definition.

 import lightning.pytorch as pl
-from lightning.pytorch.strategies import DDPStrategy
-from lightning.pytorch.plugins.environments import LightningEnvironment
+import ray.train.lightning

 def train_func(config):
     ...
     model = MyLightningModule(...)
     datamodule = MyLightningDataModule(...)

     trainer = pl.Trainer(
-        devices=[0,1,2,3],
-        strategy=DDPStrategy(),
-        plugins=[LightningEnvironment()],
+        devices="auto",
+        accelerator="auto",
+        strategy=ray.train.lightning.RayDDPStrategy(),
+        plugins=[ray.train.lightning.RayLightningEnvironment()]
     )
+    trainer = ray.train.lightning.prepare_trainer(trainer)

     trainer.fit(model, datamodule=datamodule)

The following sections discuss each change.

Configure the distributed strategy#

Ray Train offers several subclassed distributed strategies for Lightning. These strategies retain the same argument list as their base strategy classes. Internally, they also configure the root device and the distributed sampler arguments.

 import lightning.pytorch as pl
-from lightning.pytorch.strategies import DDPStrategy
+import ray.train.lightning

 def train_func(config):
     ...
     trainer = pl.Trainer(
         ...
-        strategy=DDPStrategy(),
+        strategy=ray.train.lightning.RayDDPStrategy(),
         ...
     )
     ...
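
Because RayDDPStrategy subclasses Lightning's DDPStrategy, any keyword argument you would normally pass to DDPStrategy works here as well (Ray Train also provides RayFSDPStrategy and RayDeepSpeedStrategy for FSDP and DeepSpeed). For example, with an illustrative DDP option:

strategy = ray.train.lightning.RayDDPStrategy(find_unused_parameters=True)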

Configure the Ray cluster environment plugin#

Ray Train also provides a RayLightningEnvironment class as a cluster environment specification for Ray. This utility class configures the worker's local rank, global rank, node rank, and world size.

 import lightning.pytorch as pl
-from lightning.pytorch.plugins.environments import LightningEnvironment
+import ray.train.lightning

 def train_func(config):
     ...
     trainer = pl.Trainer(
         ...
-        plugins=[LightningEnvironment()],
+        plugins=[ray.train.lightning.RayLightningEnvironment()],
         ...
     )
     ...
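
If you need these values inside your training function, for example for rank-dependent logging, you can also query them through Ray Train's context. A minimal sketch:

from ray import train

def train_func(config):
    context = train.get_context()
    world_size = context.get_world_size()    # Total number of training workers.
    global_rank = context.get_world_rank()   # This worker's rank across the cluster.
    local_rank = context.get_local_rank()    # This worker's rank on its node.
    ...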

Configure parallel devices#

In addition, Ray TorchTrainer has already configured the correct CUDA_VISIBLE_DEVICES for you. Always use all available GPUs by setting devices="auto" and accelerator="auto".

 import lightning.pytorch as pl

 def train_func(config):
     ...
     trainer = pl.Trainer(
         ...
-        devices=[0,1,2,3],
+        devices="auto",
+        accelerator="auto",
         ...
     )
     ...

Report checkpoints and metrics#

To persist your checkpoints and monitor training progress, add a ray.train.lightning.RayTrainReportCallback utility callback to your Trainer.

 import lightning.pytorch as pl
 from ray.train.lightning import RayTrainReportCallback

 def train_func(config):
     ...
     trainer = pl.Trainer(
         ...
-        callbacks=[...],
+        callbacks=[..., RayTrainReportCallback()],
     )
     ...

Reporting metrics and checkpoints to Ray Train enables fault-tolerant training and hyperparameter optimization. Note that the ray.train.lightning.RayTrainReportCallback class only provides a simple implementation and can be further customized.
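
For example, if you need different checkpointing behavior, you can write your own Lightning callback that saves a checkpoint and reports it with ray.train.report. The following is a simplified sketch of that pattern, not the actual RayTrainReportCallback implementation:

import os
import tempfile

import lightning.pytorch as pl
from ray import train
from ray.train import Checkpoint

class MyReportCallback(pl.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # Collect the metrics that Lightning has logged so far.
        metrics = {k: v.item() for k, v in trainer.callback_metrics.items()}

        # Save a Lightning checkpoint and report it to Ray Train.
        with tempfile.TemporaryDirectory() as tmpdir:
            trainer.save_checkpoint(os.path.join(tmpdir, "checkpoint.ckpt"))
            train.report(metrics, checkpoint=Checkpoint.from_directory(tmpdir))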

Prepare your Lightning Trainer#

Finally, pass your Lightning Trainer into prepare_trainer() to validate your configurations.

 import lightning.pytorch as pl
 import ray.train.lightning

 def train_func(config):
     ...
     trainer = pl.Trainer(...)
+    trainer = ray.train.lightning.prepare_trainer(trainer)
     ...

Configure scale and GPUs#

Outside of your training function, create a ScalingConfig object to configure:

  1. num_workers - The number of distributed training worker processes.

  2. use_gpu - Whether each worker should use a GPU (or CPU).

from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)
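
ScalingConfig also accepts a resources_per_worker argument if you want to request specific resources for each worker; the values below are illustrative:

scaling_config = ScalingConfig(
    num_workers=4,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 1},
)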

For more details, see Configuring Scale and GPUs.

Launch a training job#

Tying this all together, you can now launch a distributed training job with a TorchTrainer.

from ray.train.torch import TorchTrainer

trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()

See Run Configuration in Train (RunConfig) for more configuration options for TorchTrainer.
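
For example, you can attach a RunConfig to name the run and choose where results and checkpoints are stored; the name and path below are illustrative:

from ray.train import RunConfig

run_config = RunConfig(
    name="lightning-fashion-mnist",   # Illustrative experiment name.
    storage_path="/tmp/ray_results",  # Illustrative storage location.
)
trainer = TorchTrainer(train_func, scaling_config=scaling_config, run_config=run_config)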

Access training results#

After training completes, Ray Train returns a Result object, which contains information about the training run, including the metrics and checkpoints reported during training.

result.metrics     # The metrics reported during training.
result.checkpoint  # The latest checkpoint reported during training.
result.path        # The path where logs are stored.
result.error       # The exception that was raised, if training failed.
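
For example, checkpoints reported by RayTrainReportCallback contain a single Lightning checkpoint file (checkpoint.ckpt by default), so you can restore the trained model roughly as follows:

import os

# Assumes the checkpoint was saved by RayTrainReportCallback, whose default
# checkpoint file name is "checkpoint.ckpt".
with result.checkpoint.as_directory() as checkpoint_dir:
    model = ImageClassifier.load_from_checkpoint(
        os.path.join(checkpoint_dir, "checkpoint.ckpt")
    )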

Next steps#

After you have converted your PyTorch Lightning training script to use Ray Train:

  • See User Guides to learn more about how to perform specific tasks.

  • Browse the Examples for end-to-end examples of how to use Ray Train.

  • Consult the API Reference for more details on the classes and methods from this tutorial.

Version Compatibility#

Ray Train is tested with pytorch_lightning versions 1.6.5 and 2.0.4. For full compatibility, use pytorch_lightning>=1.6.5. Earlier versions aren't prohibited but may result in unexpected issues. If you run into any compatibility issues, consider upgrading your PyTorch Lightning version or filing an issue.

Note

If you are using Lightning 2.x, please use the import path lightning.pytorch.xxx instead of pytorch_lightning.xxx.

LightningTrainer Migration Guide#

Ray 2.4 introduced the LightningTrainer and exposed a LightningConfigBuilder to define configurations for pl.LightningModule and pl.Trainer.

It then instantiated the model and trainer objects and ran a pre-defined training function as a black box.

This version of the LightningTrainer API was restrictive and limited your ability to control the training loop.

Ray 2.7 introduced the newly unified TorchTrainer API, which offers enhanced transparency, flexibility, and simplicity. This API is more aligned with standard PyTorch Lightning scripts, ensuring users have better control over their native Lightning code.

from pytorch_lightning.loggers import CSVLogger

from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.lightning import LightningConfigBuilder, LightningTrainer

config_builder = LightningConfigBuilder()
# [1] Collect model configs
config_builder.module(cls=MNISTClassifier, lr=1e-3, feature_dim=128)

# [2] Collect checkpointing configs
config_builder.checkpointing(monitor="val_accuracy", mode="max", save_top_k=3)

# [3] Collect pl.Trainer configs
config_builder.trainer(
    max_epochs=10,
    accelerator="gpu",
    log_every_n_steps=100,
    logger=CSVLogger("./logs"),
)

# [4] Build datasets on the head node
datamodule = MNISTDataModule(batch_size=32)
config_builder.fit_params(datamodule=datamodule)

# [5] Execute the internal training function in a black box
ray_trainer = LightningTrainer(
    lightning_config=config_builder.build(),
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        ),
    )
)
ray_trainer.fit()
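
For comparison, the sketch below shows roughly how the same configuration maps onto the unified TorchTrainer API described earlier in this tutorial. MNISTClassifier, MNISTDataModule, and the hyperparameter values are carried over from the snippet above as assumptions:

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger

import ray.train.lightning
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.lightning import RayTrainReportCallback
from ray.train.torch import TorchTrainer

def train_func(config):
    # [1] Define the model and datasets directly inside the training function.
    model = MNISTClassifier(lr=1e-3, feature_dim=128)
    datamodule = MNISTDataModule(batch_size=32)

    # [2] Define the pl.Trainer yourself instead of through a config builder.
    trainer = pl.Trainer(
        max_epochs=10,
        devices="auto",
        accelerator="auto",
        log_every_n_steps=100,
        logger=CSVLogger("./logs"),
        strategy=ray.train.lightning.RayDDPStrategy(),
        plugins=[ray.train.lightning.RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
    )
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(model, datamodule=datamodule)

# [3] Checkpointing options move into RunConfig; top-k selection is driven by
#     the metrics reported through RayTrainReportCallback.
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
    run_config=RunConfig(
        checkpoint_config=CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="val_accuracy",
            checkpoint_score_order="max",
        ),
    ),
)
ray_trainer.fit()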