Using Ray with Pytorch Lightning#
Note
For an overview of Ray’s distributed training library, see Ray Train.
PyTorch Lightning is a framework that brings structure to training PyTorch models. It aims to eliminate boilerplate code, so you don't have to rewrite the same training loops every time you build a new model.

Using Ray with Pytorch Lightning allows you to easily distribute training and also run distributed hyperparameter tuning experiments all from a single Python script. You can use the same code to run Pytorch Lightning in a single process on your laptop, parallelize across the cores of your laptop, or parallelize across a large multi-node cluster.
Ray provides two integration points with Pytorch Lightning:
Ray Lightning Library for distributed Pytorch Lightning training with Ray
Ray Tune with Pytorch Lightning for distributed hyperparameter tuning of your PTL models.
Distributed Training with Ray Lightning#
The Ray Lightning Library provides plugins for distributed training with Ray.
These PyTorch Lightning plugins on Ray enable quick and easy parallel training while still leveraging all the benefits of PyTorch Lightning. The library offers the following plugins:
PyTorch Distributed Data Parallel (DDP)
Horovod
Fairscale for model parallel training
Once you add your plugin to the PyTorch Lightning Trainer, you can parallelize training to all the cores in your laptop, or across a massive multi-node, multi-GPU cluster with no additional code changes.
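For example, the only difference between a local Trainer and a distributed one is the strategy argument. The snippet below is a minimal sketch (the worker count is arbitrary and assumes ray_lightning is installed):

import pytorch_lightning as pl
from ray_lightning import RayStrategy

# Local, single-process training:
# trainer = pl.Trainer(max_steps=10)

# The same Trainer, distributed across 4 Ray workers with no other code changes:
trainer = pl.Trainer(max_steps=10, strategy=RayStrategy(num_workers=4, use_gpu=False))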
Install the Ray Lightning Library with the following commands:
pip install ray_lightning

# To install from master
pip install git+https://github.com/ray-project/ray_lightning#ray_lightning
To use, simply pass the plugin into your Pytorch Lightning Trainer. For full details, check out the README here.
Here is an example of using RayStrategy for Distributed Data Parallel training on a Ray cluster:
First, let’s define our PyTorch Lightning module.
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, lr=1e-1):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)
        )
        self.lr = lr

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
Then, we create a PyTorch Lightning Trainer, passing in RayStrategy. We can also configure the number of training workers we want to use and whether to use GPUs.
import os
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
from ray_lightning import RayStrategy
num_workers = 2
use_gpu = False
max_steps = 10
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])
autoencoder = LitAutoEncoder()
trainer = pl.Trainer(
    strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
    max_steps=max_steps,
)
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
With this strategy, PyTorch DDP is used as the distributed training communication protocol, while Ray launches and manages the training worker processes.
Multi-node Distributed Training#
Using the same examples above, you can run distributed training on a multi-node cluster with just a couple of simple steps.
First, use Ray’s Cluster Launcher to start a Ray cluster:
ray up my_cluster_config.yaml
Then, run your Ray script using one of the following options:
on the head node of the cluster (python train_script.py)
via ray job submit (docs) from your laptop (ray job submit -- python train.py)
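If you run the script directly on the head node, you can connect it to the running cluster explicitly before constructing the Trainer, so the training workers are scheduled across the cluster. A minimal sketch (not a complete script; the worker count and GPU setting are placeholders):

import ray
import pytorch_lightning as pl
from ray_lightning import RayStrategy

# Connect to the existing Ray cluster started by `ray up`.
ray.init(address="auto")

trainer = pl.Trainer(
    strategy=RayStrategy(num_workers=8, use_gpu=True),
    max_steps=10,
)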
Distributed Hyperparameter Optimization with Ray Tune#
You can also use Ray Tune with Pytorch Lightning to tune the hyperparameters of your model. With this integration, you can run multiple training runs in parallel, with each run having a different set of hyperparameters for your Pytorch Lightning model.
Hyperparameter Tuning with non-distributed training#
If you only want distributed hyperparameter tuning, but each training run doesn’t need to be distributed, you can use the ready-to-use Pytorch Lightning callbacks that Ray Tune provides.
We first wrap our training code into a function. To report metrics back to Tune after each training batch, we make sure to add the TuneReportCallback to the PyTorch Lightning Trainer. The learning rate is read from the provided config argument.
import os
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from ray.tune.integration.pytorch_lightning import TuneReportCallback
def train(config):
    max_steps = 10
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train, val = random_split(dataset, [55000, 5000])
    metrics = {"loss": "train_loss"}
    autoencoder = LitAutoEncoder(lr=config["lr"])
    trainer = pl.Trainer(
        callbacks=[TuneReportCallback(metrics, on="batch_end")],
        max_steps=max_steps,
    )
    trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
Then, we use the Ray Tune Tuner to run our hyperparameter tuning experiment. We define a hyperparameter search space, and in this case we will try out different learning rate values. These hyperparameters get passed in as the config argument to the training function that we defined earlier.
from ray import tune
param_space = {"lr": tune.loguniform(1e-4, 1e-1)}
num_samples = 1
tuner = tune.Tuner(
    train,
    tune_config=tune.TuneConfig(metric="loss", mode="min", num_samples=num_samples),
    param_space=param_space,
)
results = tuner.fit()
print("Best hyperparameters found were: ", results.get_best_result().config)
And if you want to add periodic checkpointing as well, you can use the TuneReportCheckpointCallback instead.
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
callback = TuneReportCheckpointCallback(
    metrics={"loss": "val_loss", "mean_accuracy": "val_accuracy"},
    filename="checkpoint",
    on="validation_end",
)
Check out the Pytorch Lightning with Ray Tune tutorial for a full example on how you can use these callbacks and run a tuning experiment for your Pytorch Lightning model.
Hyperparameter Tuning with distributed training#
These integrations also support the case where you want a distributed hyperparameter tuning experiment, but each trial (training run) needs to be distributed as well. In this case, you want to use the Ray Lightning Library’s integration with Ray Tune.
With this integration, you can run multiple PyTorch Lightning training runs in parallel, each with a different hyperparameter configuration, and each training run also parallelized. All you have to do is move your training code to a function, pass the function to Tuner(), and make sure to add the appropriate callback (either TuneReportCallback or TuneReportCheckpointCallback) to your PyTorch Lightning Trainer.
Warning
Make sure to use the callbacks from the Ray Lightning library and not the ones from the Tune library, i.e. use ray_lightning.tune.TuneReportCallback and not ray.tune.integration.pytorch_lightning.TuneReportCallback.
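In other words, for distributed trials the import should look like this:

# Ray-Lightning-aware callback, used when the trial runs with RayStrategy:
from ray_lightning.tune import TuneReportCallback

# Not this one, which is for the non-distributed case shown earlier:
# from ray.tune.integration.pytorch_lightning import TuneReportCallback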
As before, we first define our training function, this time making sure we specify RayStrategy and use the TuneReportCallback from the ray_lightning library.
import os
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from ray_lightning import RayStrategy
from ray_lightning.tune import TuneReportCallback
num_workers = 1
use_gpu = False
max_steps = 10
def train_distributed(config):
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train, val = random_split(dataset, [55000, 5000])
    metrics = {"loss": "train_loss"}
    autoencoder = LitAutoEncoder(lr=config["lr"])
    trainer = pl.Trainer(
        callbacks=[TuneReportCallback(metrics, on="batch_end")],
        strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
        max_steps=max_steps,
    )
    trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
Then, we use the Tuner to run our hyperparameter tuning experiment. We have to make sure we wrap our training function with tune.with_resources to tell Tune that each of the trials will also be distributed.
from ray import tune
from ray_lightning.tune import get_tune_resources
param_space = {"lr": tune.loguniform(1e-4, 1e-1)}
num_samples = 1
tuner = tune.Tuner(
    tune.with_resources(
        train_distributed,
        get_tune_resources(num_workers=num_workers, use_gpu=use_gpu),
    ),
    tune_config=tune.TuneConfig(
        metric="loss",
        mode="min",
        num_samples=num_samples,
    ),
    param_space=param_space,
)
results = tuner.fit()
print("Best hyperparameters found were: ", results.get_best_result().config)