Using PyTorch Lightning with Tune#

PyTorch Lightning is a framework which brings structure into training PyTorch models. It aims to avoid boilerplate code, so you don’t have to write the same training loops all over again when building a new model.


The main abstraction of PyTorch Lightning is the LightningModule class, which should be extended by your application. There is a great post on how to transfer your models from vanilla PyTorch to Lightning.

The class structure of PyTorch Lightning makes it very easy to define and tune model parameters. This tutorial will show you how to use Tune with the AIR LightningTrainer to find the best set of parameters for your application, using the example of training an MNIST classifier. Notably, the LightningModule does not have to be altered at all for this - so you can use it plug and play with your existing models, assuming their parameters are configurable!

Note

If you don’t want to use the AIR LightningTrainer and prefer a vanilla Lightning Trainer with a function trainable, please refer to this document: Using vanilla PyTorch Lightning with Tune.

Note

To run this example, you will need to install the following:

$ pip install "ray[tune]" torch torchvision pytorch-lightning

PyTorch Lightning classifier for MNIST#

Let’s first start with the basic PyTorch Lightning implementation of an MNIST classifier. This classifier does not include any tuning code at this point.

First, we run some imports:

import os
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from filelock import FileLock
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms

from ray.train.lightning import LightningTrainer, LightningConfigBuilder

Our example builds on the MNIST example from the blog post we mentioned before. We adapted the original model and dataset definitions into MNISTClassifier and MNISTDataModule.

class MNISTClassifier(pl.LightningModule):
    def __init__(self, config):
        super(MNISTClassifier, self).__init__()
        self.accuracy = Accuracy()
        self.layer_1_size = config["layer_1_size"]
        self.layer_2_size = config["layer_2_size"]
        self.lr = config["lr"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)

        x = self.layer_2(x)
        x = torch.relu(x)

        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)

        return x

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=128):
        super().__init__()
        self.data_dir = os.getcwd()
        self.batch_size = batch_size
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        with FileLock(f"{self.data_dir}.lock"):
            mnist = MNIST(
                self.data_dir, train=True, download=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(mnist, [55000, 5000])

            self.mnist_test = MNIST(
                self.data_dir, train=False, download=True, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=4)

default_config = {
    "layer_1_size": 128,
    "layer_2_size": 256,
    "lr": 1e-3,
}

Tuning the model parameters#

The parameters above should give you a good accuracy of over 90% already. However, we might improve on this simply by changing some of the hyperparameters. For instance, we might reach an even higher accuracy with a smaller learning rate and a larger middle layer.

Instead of manually looping through all the parameter combinations, let’s use Tune to systematically try out parameter combinations and find the best performing set.

First, we need some additional imports:

from pytorch_lightning.loggers import TensorBoardLogger
from ray import air, tune
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining

Configuring the search space#

Now we configure the parameter search space, which we will later pass to the model via LightningConfigBuilder.module(). We would like to choose between three different sizes for each of the two hidden layers. The learning rate should be sampled between 0.0001 and 0.1. The tune.loguniform() function is syntactic sugar for sampling across these different orders of magnitude; in particular, it ensures that small values are sampled as frequently as large ones.

Note

In LightningTrainer, the frequency of metric reporting is the same as the frequency of checkpointing. For example, if you set builder.checkpointing(..., every_n_epochs=2), then every 2 epochs all the latest metrics will be reported to the Ray Tune session along with the latest checkpoint. Please make sure the target metrics (e.g., metrics specified in TuneConfig, schedulers, and searchers) are logged before saving a checkpoint.
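
For instance, to report metrics and save a checkpoint every 2 epochs, the checkpointing step of the builder chain could look like the sketch below. The keyword arguments are forwarded to Lightning's ModelCheckpoint callback (every_n_epochs is the standard ModelCheckpoint argument), and the monitored metric matches what our model logs; the values here are illustrative.

# Sketch: report metrics to Tune and save a checkpoint every 2 epochs.
# Keyword arguments are forwarded to pytorch_lightning.callbacks.ModelCheckpoint.
builder = LightningConfigBuilder().checkpointing(
    monitor="ptl/val_accuracy", mode="max", every_n_epochs=2
)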

Note

Use LightningConfigBuilder.checkpointing() to specify the monitor metric and checkpoint frequency for the Lightning ModelCheckpoint callback. To properly save AIR checkpoints, you must also provide an AIR CheckpointConfig. Otherwise, LightningTrainer will create a default CheckpointConfig, which saves all the reported checkpoints by default.

# The maximum number of training epochs
num_epochs = 5

# Number of samples from the parameter space
num_samples = 10

accelerator = "gpu"

config = {
    "layer_1_size": tune.choice([32, 64, 128]),
    "layer_2_size": tune.choice([64, 128, 256]),
    "lr": tune.loguniform(1e-4, 1e-1),
}

If you have more resources available, you can modify the above parameters accordingly, e.g., train for more epochs or draw more parameter samples.

dm = MNISTDataModule(batch_size=64)
logger = TensorBoardLogger(save_dir=os.getcwd(), name="tune-ptl-example", version=".")

lightning_config = (
    LightningConfigBuilder()
    .module(cls=MNISTClassifier, config=config)
    .trainer(max_epochs=num_epochs, accelerator=accelerator, logger=logger)
    .fit_params(datamodule=dm)
    .checkpointing(monitor="ptl/val_accuracy", save_top_k=2, mode="max")
    .build()
)

# Make sure to also define an AIR CheckpointConfig here
# to properly save checkpoints in AIR format.
run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,
        checkpoint_score_attribute="ptl/val_accuracy",
        checkpoint_score_order="max",
    ),
)

Selecting a scheduler#

In this example, we use an Asynchronous Hyperband scheduler. This scheduler decides at each iteration which trials are likely to perform badly, and stops these trials. This way we don’t waste any resources on bad hyperparameter configurations.

scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

Training with GPUs#

We can specify the number of resources, including GPUs, that Tune should request for each trial.

LightningTrainer takes care of the environment setup for Distributed Data Parallel training; the model and data will automatically be distributed across GPUs. You only need to set the number of GPUs per worker in ScalingConfig and set accelerator="gpu" in LightningConfigBuilder.trainer(). If you only have CPUs available, see the CPU-only sketch after the code below.

scaling_config = ScalingConfig(
    num_workers=3, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)
# Define a base LightningTrainer without hyper-parameters for Tuner
lightning_trainer = LightningTrainer(
    scaling_config=scaling_config,
    run_config=run_config,
)
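
If you don't have GPUs, a CPU-only setup could look like the following sketch. It reuses the same ScalingConfig arguments as above; you would also need to change the accelerator variable defined earlier to "cpu" so that LightningConfigBuilder.trainer() requests CPU training.

# Sketch: a CPU-only setup. Also set accelerator = "cpu" above so that
# LightningConfigBuilder.trainer(accelerator=accelerator, ...) trains on CPU.
scaling_config = ScalingConfig(
    num_workers=3, use_gpu=False, resources_per_worker={"CPU": 1}
)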

Putting it together#

Lastly, we need to create a Tuner() object and start Ray Tune with tuner.fit().

The full code looks like this:

def tune_mnist_asha(num_samples=10):
    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    tuner = tune.Tuner(
        lightning_trainer,
        param_space={"lightning_config": lightning_config},
        tune_config=tune.TuneConfig(
            metric="ptl/val_accuracy",
            mode="max",
            num_samples=num_samples,
            scheduler=scheduler,
        ),
        run_config=air.RunConfig(
            name="tune_mnist_asha",
        ),
    )
    results = tuner.fit()
    best_result = results.get_best_result(metric="ptl/val_accuracy", mode="max")
    return best_result


tune_mnist_asha(num_samples=num_samples)

In the example above, Tune runs 10 trials with different hyperparameter configurations. An example output could look like so:

  +------------------------------+------------+-------------------+----------------+----------------+-------------+----------+-----------------+----------------------+
  | Trial name                   | status     | loc               |   layer_1_size |   layer_2_size |          lr |     loss |   mean_accuracy |   training_iteration |
  |------------------------------+------------+-------------------+----------------+----------------+-------------+----------+-----------------+----------------------|
  | LightningTrainer_9532b_00001 | TERMINATED |  10.0.37.7:448989 |            32  |            64  | 0.00025324  | 0.58146  |       0.866667  |                   1  |
  | LightningTrainer_9532b_00002 | TERMINATED |  10.0.37.7:449722 |            128 |            128 | 0.000166782 | 0.29038  |       0.933333  |                   2  |
  | LightningTrainer_9532b_00003 | TERMINATED |  10.0.37.7:453404 |            64  |            128 | 0.0004948	  | 0.15375  |       0.9       |                   4  |
  | LightningTrainer_9532b_00004 | TERMINATED |  10.0.37.7:457981 |            128 |            128 | 0.000304361 | 0.17622  |       0.966667  |                   4  |
  | LightningTrainer_9532b_00005 | TERMINATED |  10.0.37.7:467478 |            128 |            64  | 0.0344561	  | 0.34665  |       0.866667  |                   1  |
  | LightningTrainer_9532b_00006 | TERMINATED |  10.0.37.7:484401 |            128 |            256 | 0.0262851	  | 0.34981  |       0.866667  |                   1  |
  | LightningTrainer_9532b_00007 | TERMINATED |  10.0.37.7:490670 |            32  |            128 | 0.0550712	  | 0.62575  |       0.766667  |                   1  |
  | LightningTrainer_9532b_00008 | TERMINATED |  10.0.37.7:491159 |            32  |            64  | 0.000489046 | 0.27384  |       0.966667  |                   2  |
  | LightningTrainer_9532b_00009 | TERMINATED |  10.0.37.7:491494 |            64  |            256 | 0.000395127 | 0.09642  |       0.933333  |                   4  |
  +------------------------------+------------+-------------------+----------------+----------------+-------------+----------+-----------------+----------------------+

As you can see in the training_iteration column, trials with a high loss (and low accuracy) have been terminated early. The best performing trial used layer_1_size=32, layer_2_size=64, and lr=0.000489046.
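
If you want to work with the winning trial programmatically, the Result object returned by results.get_best_result() exposes its configuration, latest reported metrics, and checkpoint. A minimal sketch using the Ray AIR Result API, given the ResultGrid returned by tuner.fit() above:

# Sketch: inspect the best trial via the Ray AIR Result API.
best_result = results.get_best_result(metric="ptl/val_accuracy", mode="max")
print(best_result.config["lightning_config"])   # the sampled lightning_config of the best trial
print(best_result.metrics["ptl/val_accuracy"])  # last reported validation accuracy
print(best_result.checkpoint)                   # AIR checkpoint saved for this trial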

Using Population Based Training to find the best parameters#

The ASHAScheduler terminates trials early if they show bad performance. Sometimes, this stops trials that would have improved with more training steps, and which might eventually even have outperformed other configurations.

Another popular method for hyperparameter tuning, called Population Based Training, instead perturbs hyperparameters during the training run. Tune implements PBT, and we only need to make some slight adjustments to our code.

def tune_mnist_pbt(num_samples=10):
    # The range of hyperparameter perturbation.
    mutations_config = (
        LightningConfigBuilder()
        .module(
            config={
                "lr": tune.loguniform(1e-4, 1e-1),
            }
        )
        .build()
    )

    # Create a PBT scheduler
    scheduler = PopulationBasedTraining(
        perturbation_interval=1,
        time_attr="training_iteration",
        hyperparam_mutations={"lightning_config": mutations_config},
    )

    tuner = tune.Tuner(
        lightning_trainer,
        param_space={"lightning_config": lightning_config},
        tune_config=tune.TuneConfig(
            metric="ptl/val_accuracy",
            mode="max",
            num_samples=num_samples,
            scheduler=scheduler,
        ),
        run_config=air.RunConfig(
            name="tune_mnist_pbt",
        ),
    )
    results = tuner.fit()
    best_result = results.get_best_result(metric="ptl/val_accuracy", mode="max")
    return best_result


tune_mnist_pbt(num_samples=num_samples)

An example output of a run could look like this:


 +------------------------------+------------+-------+----------------+----------------+---------------------+-----------+--------------------+----------------------+
 | Trial name                   | status     | loc   |   layer_1_size |   layer_2_size |                  lr |      loss |   ptl/val_accuracy |   training_iteration |
 |------------------------------+------------+-------+----------------+----------------+---------------------+-----------+--------------------+----------------------|
 | LightningTrainer_85489_00000 | TERMINATED |       |            64  |            64  | 0.0030@perturbed... | 0.108734  |        0.984954    |                   5  |
 | LightningTrainer_85489_00001 | TERMINATED |       |            32  |            256 | 0.0010@perturbed... | 0.093577  |        0.983411    |                   5  |
 | LightningTrainer_85489_00002 | TERMINATED |       |            128 |            64  | 0.0233@perturbed... | 0.0922348 |        0.983989    |                   5  |
 | LightningTrainer_85489_00003 | TERMINATED |       |            64  |            128 | 0.0002@perturbed... | 0.124648  |        0.98206	  |                   5  |
 | LightningTrainer_85489_00004 | TERMINATED |       |            128 |            256 | 0.0021              | 0.101717  |        0.993248    |                   5  |
 | LightningTrainer_85489_00005 | TERMINATED |       |            32  |            128 | 0.0003@perturbed... | 0.121467  |        0.984182    |                   5  |
 | LightningTrainer_85489_00006 | TERMINATED |       |            128 |            64  | 0.0020@perturbed... | 0.053446  |        0.984375    |                   5  |
 | LightningTrainer_85489_00007 | TERMINATED |       |            64  |            64  | 0.0063@perturbed... | 0.129804  |        0.98669	  |                   5  |
 | LightningTrainer_85489_00008 | TERMINATED |       |            128 |            256 | 0.0436@perturbed... | 0.363236  |        0.982253    |                   5  |
 | LightningTrainer_85489_00009 | TERMINATED |       |            128 |            256 | 0.001               | 0.150946  |        0.985147    |                   5  |
 +------------------------------+------------+-------+----------------+----------------+---------------------+-----------+--------------------+----------------------+

As you can see, each sample ran for the full 5 iterations. All trials ended with quite good parameter combinations and showed relatively good performance (above 0.98 validation accuracy). In some runs, the parameters were perturbed, and the best configuration even reached a mean validation accuracy of 0.993248!

In summary, the AIR LightningTrainer is easy to use with Tune. It only requires adding a few lines of code to integrate with the Ray Tuner and find well-performing parameter configurations.

More PyTorch Lightning Examples#