Using Ray with PyTorch Lightning

Note

For an overview of Ray’s distributed training library, see Ray Train.

PyTorch Lightning is a framework that brings structure to training PyTorch models. It aims to eliminate boilerplate code, so you don’t have to write the same training loops over and over again when building a new model.


Using Ray with PyTorch Lightning lets you easily distribute training and run distributed hyperparameter tuning experiments, all from a single Python script. You can use the same code to run PyTorch Lightning in a single process on your laptop, parallelize across your laptop’s cores, or parallelize across a large multi-node cluster.

Ray provides two integration points with PyTorch Lightning:

  1. The Ray Lightning Library for distributed PyTorch Lightning training with Ray.

  2. Ray Tune with PyTorch Lightning for distributed hyperparameter tuning of your PTL models.

Distributed Training with Ray Lightning

The Ray Lightning Library provides plugins for distributed training with Ray.

These PyTorch Lightning plugins on Ray enable quick and easy parallel training while still leveraging all the benefits of PyTorch Lightning. The RayStrategy plugin used throughout this guide provides Distributed Data Parallel training on Ray.

Once you add your plugin to the PyTorch Lightning Trainer, you can parallelize training to all the cores in your laptop, or across a massive multi-node, multi-GPU cluster with no additional code changes.

Install the Ray Lightning Library with the following command:

# To install from master
pip install git+https://github.com/ray-project/ray_lightning#ray_lightning
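
If you prefer a released version over master, the library is also published to PyPI (this assumes the ray_lightning package name):

# To install the latest release
pip install ray_lightning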

To use it, simply pass the plugin to your PyTorch Lightning Trainer. For full details, see the Ray Lightning README.

Here is an example of using the RayStrategy for Distributed Data Parallel training on a Ray cluster:

First, let’s define our PyTorch Lightning module.

import torch
from torch import nn
import torch.nn.functional as F

import pytorch_lightning as pl


class LitAutoEncoder(pl.LightningModule):
    def __init__(self, lr=1e-1):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28)
        )
        self.lr = lr

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


Then, we create a PyTorch Lightning Trainer, passing in RayStrategy. We can also configure the number of training workers and whether to use GPUs.

import os
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

from ray_lightning import RayStrategy

num_workers = 2  # number of Ray training workers
use_gpu = False  # set to True to train on GPUs
max_steps = 10  # keep the example short

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()
trainer = pl.Trainer(
    strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
    max_steps=max_steps,
)
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))

With this strategy, PyTorch DDP is used as the distributed training communication protocol, while Ray launches and manages the training worker processes.
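
The same Trainer scales up by changing only the RayStrategy arguments. Here is a minimal sketch, assuming a machine or cluster with four GPUs available (the worker count and GPU flag are illustrative):

import ray

# Optional: connect to an existing Ray cluster. With no address argument,
# ray.init() starts a local single-node Ray instance instead.
ray.init(address="auto")

trainer = pl.Trainer(
    strategy=RayStrategy(num_workers=4, use_gpu=True),  # assumes 4 GPUs are available
    max_steps=max_steps,
)
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))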

Multi-node Distributed Training

Using the same example as above, you can run distributed training on a multi-node cluster in just a couple of simple steps.

First, use Ray’s Cluster Launcher to start a Ray cluster:

ray up my_cluster_config.yaml

Then, run your Ray script using one of the following options:

  1. on the head node of the cluster (python train_script.py)

  2. via ray job submit from your laptop (ray job submit -- python train_script.py)

  3. via the Ray Client from your laptop (see the sketch below).
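
For the third option, here is a minimal sketch of connecting to the cluster with the Ray Client, assuming the head node exposes the default client server port 10001 (the host below is a placeholder):

import ray

# Connect this local script to the remote Ray cluster via the Ray Client.
# "<head_node_host>" is a placeholder for your head node's address.
ray.init("ray://<head_node_host>:10001")

# The rest of the training script (Trainer with RayStrategy, trainer.fit, ...)
# runs unchanged; the training workers are launched on the remote cluster.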

Distributed Hyperparameter Optimization with Ray Tune

You can also use Ray Tune with PyTorch Lightning to tune the hyperparameters of your model. With this integration, you can run multiple training runs in parallel, each with a different set of hyperparameters for your PyTorch Lightning model.

Hyperparameter Tuning with non-distributed training

If you only want distributed hyperparameter tuning, but each training run doesn’t need to be distributed, you can use the ready-to-use PyTorch Lightning callbacks that Ray Tune provides.

We first wrap our training code in a function. To report metrics back to Tune, we add the TuneReportCallback to the PyTorch Lightning Trainer; in this example it reports the training loss at the end of every batch. The learning rate is read from the provided config argument.

import os
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl

from ray.tune.integration.pytorch_lightning import TuneReportCallback


def train(config):
    max_steps = 10

    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train, val = random_split(dataset, [55000, 5000])

    metrics = {"loss": "train_loss"}
    autoencoder = LitAutoEncoder(lr=config["lr"])
    trainer = pl.Trainer(
        callbacks=[TuneReportCallback(metrics, on="batch_end")],
        max_steps=max_steps,
    )
    trainer.fit(autoencoder, DataLoader(train), DataLoader(val))


Then, we use the Ray Tune Tuner to run our hyperparameter tuning experiment. We define a hyperparameter search space, and in this case we will try out different learning rate values. These hyperparameters get passed in as the config argument to the training function that we defined earlier.

from ray import tune

param_space = {"lr": tune.loguniform(1e-4, 1e-1)}
num_samples = 1

tuner = tune.Tuner(
    train,
    tune_config=tune.TuneConfig(metric="loss", mode="min", num_samples=num_samples),
    param_space=param_space,
)

results = tuner.fit()
print("Best hyperparameters found were: ", results.get_best_result().config)

And if you want to add periodic checkpointing as well, you can use the TuneReportCheckpointCallback instead.

from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

callback = TuneReportCheckpointCallback(
    metrics={"loss": "val_loss", "mean_accuracy": "val_accuracy"},
    filename="checkpoint",
    on="validation_end",
)

Check out the PyTorch Lightning with Ray Tune tutorial for a full example of how you can use these callbacks and run a tuning experiment for your PyTorch Lightning model.

Hyperparameter Tuning with distributed training

These integrations also support the case where you want a distributed hyperparameter tuning experiment, but each trial (training run) needs to be distributed as well. In this case, you want to use the Ray Lightning Library’s integration with Ray Tune.

With this integration, you can run multiple PyTorch Lightning training runs in parallel, each with a different hyperparameter configuration and each itself parallelized. All you have to do is move your training code into a function, pass the function to the Tuner, and make sure to add the appropriate callback (either TuneReportCallback or TuneReportCheckpointCallback) to your PyTorch Lightning Trainer.

Warning

Make sure to use the callbacks from the Ray Lightning library and not the ones from the Tune library, i.e. use ray_lightning.tune.TuneReportCallback and not ray.tune.integration.pytorch_lightning.TuneReportCallback.

As before, we first define our training function, this time making sure to specify RayStrategy and to use the TuneReportCallback from the ray_lightning library.

import os
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl

from ray_lightning import RayStrategy
from ray_lightning.tune import TuneReportCallback

num_workers = 1  # Ray training workers per trial
use_gpu = False  # set to True to train each worker on a GPU
max_steps = 10  # keep the example short


def train_distributed(config):
    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train, val = random_split(dataset, [55000, 5000])

    metrics = {"loss": "train_loss"}
    autoencoder = LitAutoEncoder(lr=config["lr"])
    trainer = pl.Trainer(
        callbacks=[TuneReportCallback(metrics, on="batch_end")],
        strategy=RayStrategy(num_workers=num_workers, use_gpu=use_gpu),
        max_steps=max_steps,
    )
    trainer.fit(autoencoder, DataLoader(train), DataLoader(val))


Then, we use the Tuner to run our hyperparameter tuning experiment. We have to make sure to wrap our training function with tune.with_resources, using get_tune_resources from ray_lightning.tune, so that Tune knows each trial will also be distributed and reserves the right resources for it.

from ray import tune
from ray_lightning.tune import get_tune_resources

param_space = {"lr": tune.loguniform(1e-4, 1e-1)}
num_samples = 1

tuner = tune.Tuner(
    tune.with_resources(
        train_distributed,
        get_tune_resources(num_workers=num_workers, use_gpu=use_gpu),
    ),
    tune_config=tune.TuneConfig(
        metric="loss",
        mode="min",
        num_samples=num_samples,
    ),
    param_space=param_space,
)

results = tuner.fit()

print("Best hyperparameters found were: ", results.get_best_result().config)