PyTorch Lightning with RaySGD


RaySGD includes an integration with PyTorch Lightning’s LightningModule. You can take your existing LightningModule and use it with Ray SGD’s TorchTrainer to take advantage of Ray SGD’s distributed training features with minimal code changes.


This LightningModule integration is currently under active development. If you encounter any bugs, please raise an issue on GitHub!


Not all PyTorch Lightning features are supported. A full list of unsupported model hooks is given below. Please post any feature requests on GitHub and we will get to them shortly!

Quick Start

Step 1: Define your LightningModule just as you would with PyTorch Lightning.

from pytorch_lightning.core.lightning import LightningModule

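class MyLightningModule(LightningModule):
    # Define __init__, forward, training_step, configure_optimizers,
    # train_dataloader, etc., as in any PyTorch Lightning project.
    ...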

Step 2: Use the TrainingOperator.from_ptl method to convert the LightningModule into a LightningOperator that is compatible with Ray SGD.

from ray.util.sgd.torch import TrainingOperator

MyLightningOperator = TrainingOperator.from_ptl(MyLightningModule)
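Conceptually, from_ptl takes your LightningModule class (not an instance) and returns a new operator class that the trainer can instantiate on each worker. The pure-Python sketch below illustrates that adapter idea with hypothetical names (BaseOperator, from_module, ToyModule). It is a simplification for intuition, not RaySGD's implementation, whose LightningOperator wires up many more hooks.

```python
class BaseOperator:
    """Hypothetical stand-in for a trainer-facing operator interface."""

    def __init__(self, config):
        self.config = config
        self.setup(config)

    def setup(self, config):
        raise NotImplementedError

    def train_batch(self, batch, batch_idx):
        raise NotImplementedError


def from_module(module_cls):
    """Build an operator subclass that delegates to an instance of module_cls.

    Hypothetical illustration of the adapter idea, not RaySGD's code.
    """

    class ModuleOperator(BaseOperator):
        def setup(self, config):
            # Instantiate the user's module with the trainer-supplied config.
            self.module = module_cls(config)

        def train_batch(self, batch, batch_idx):
            # Delegate per-batch work to the module's training_step hook.
            return self.module.training_step(batch, batch_idx)

    ModuleOperator.__name__ = module_cls.__name__ + "Operator"
    return ModuleOperator


class ToyModule:
    def __init__(self, config):
        self.scale = config["scale"]

    def training_step(self, batch, batch_idx):
        # Toy "loss": the scaled sum of the batch.
        return self.scale * sum(batch)


ToyOperator = from_module(ToyModule)   # convert the class, as with from_ptl
op = ToyOperator({"scale": 2})         # the trainer would do this on each worker
print(ToyOperator.__name__, op.train_batch([1, 2, 3], 0))  # ToyModuleOperator 12
```

Returning a class rather than an object is what lets the trainer construct a fresh, independently configured operator inside every worker process.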

Step 3: Use the operator with Ray SGD’s TorchTrainer, just as you normally would. See Distributed PyTorch for a more complete guide to TorchTrainer.

import ray
from ray.util.sgd.torch import TorchTrainer

ray.init()  # or ray.init(address="auto") to connect to an existing Ray cluster
trainer = TorchTrainer(
    training_operator_cls=MyLightningOperator, num_workers=4, use_gpu=True)
train_stats = trainer.train()
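With num_workers=4, TorchTrainer trains data-parallel: each worker computes gradients on its own shard of the data, and the gradients are averaged across workers before each update so the model replicas stay in sync. The pure-Python sketch below simulates one such step on a toy 1-D linear regression (grad_mse, shard, and data_parallel_step are hypothetical helpers). It shows that, for equal-size shards, the averaged per-worker gradient matches the full-batch gradient. Real training performs this averaging with PyTorch's distributed backend, not a Python loop.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def shard(data, num_workers, rank):
    # Strided split: worker `rank` sees every num_workers-th sample.
    return data[rank::num_workers]


def data_parallel_step(w, xs, ys, num_workers, lr):
    # Each simulated worker computes a gradient on its own shard only.
    grads = [
        grad_mse(w, shard(xs, num_workers, r), shard(ys, num_workers, r))
        for r in range(num_workers)
    ]
    # "Allreduce": average the per-worker gradients so every replica
    # applies the identical update and the weights stay synchronized.
    avg_grad = sum(grads) / num_workers
    return w - lr * avg_grad, avg_grad


xs = [float(i) for i in range(1, 9)]  # 8 samples -> 2 per worker with 4 workers
ys = [3.0 * x for x in xs]            # the true weight is 3.0

w, avg_grad = data_parallel_step(0.0, xs, ys, num_workers=4, lr=0.01)
full_grad = grad_mse(0.0, xs, ys)
print(avg_grad, full_grad)  # -153.0 -153.0
```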

And that’s it! For a more comprehensive guide, see the MNIST tutorial below.

MNIST Tutorial

In this walkthrough, we will train an MNIST classifier with PyTorch Lightning’s LightningModule and Ray SGD.

We will follow this tutorial from the PyTorch Lightning documentation for specifying our MNIST LightningModule.

Setup / Imports

Let’s start with some basic imports:

import os

# Pytorch imports
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import MNIST

# Ray imports
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch import TrainingOperator

# PTL imports
from pytorch_lightning.core.lightning import LightningModule

Most of these imports are needed for building our PyTorch model and training components. Only a few additional imports are needed for Ray and PyTorch Lightning.

MNIST LightningModule

We now define our PyTorch Lightning LightningModule:

class LitMNIST(LightningModule):
    # We take in an additional config parameter here, but this is not required.
    def __init__(self, config):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

        self.config = config

    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)

        x = torch.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.config["lr"])

    def setup(self, stage):
        # transforms for images
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        # download MNIST and load the training set
        mnist_train = MNIST(
            os.getcwd(), train=True, download=True, transform=transform)

        self.mnist_train, self.mnist_val = random_split(
            mnist_train, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train, batch_size=self.config["batch_size"])

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.config["batch_size"])

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        _, predicted = torch.max(logits.data, 1)
        num_correct = (predicted == y).sum().item()
        num_samples = y.size(0)
        return {"val_loss": loss.item(), "val_acc": num_correct / num_samples}

This is the same code that would normally be used in PyTorch Lightning, and is taken directly from this PTL guide. The only difference is that the __init__ method can optionally take a config argument as a way to pass hyperparameters to your model, optimizer, or schedulers. The config is passed in directly from the TorchTrainer or, if you use Ray SGD in conjunction with Tune (RaySGD Hyperparameter Tuning), from the config you pass to Tune.
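The constants baked into LitMNIST can be checked by hand. The pure-Python sketch below (no PyTorch required; all helper names are illustrative) computes the parameter count of the three Linear layers, the value range produced by the MNIST Normalize transform (assuming a ToTensor() step first scales pixels to [0, 1]), and the loss and accuracy arithmetic that validation_step performs with log_softmax outputs, F.nll_loss, and torch.max.

```python
import math

# LitMNIST layer widths: 28*28 inputs -> 128 -> 256 -> 10 classes.
LAYER_SIZES = [28 * 28, 128, 256, 10]


def linear_params(n_in, n_out):
    # torch.nn.Linear(n_in, n_out) holds an n_out x n_in weight plus an n_out bias.
    return n_in * n_out + n_out


total_params = sum(
    linear_params(n_in, n_out) for n_in, n_out in zip(LAYER_SIZES, LAYER_SIZES[1:]))

# ToTensor() maps a 0..255 pixel to [0, 1]; Normalize then computes (x - mean) / std.
MEAN, STD = 0.1307, 0.3081


def normalize_pixel(p):
    return (p / 255.0 - MEAN) / STD


def log_softmax(row):
    # Numerically stable log-softmax over one row of scores.
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def nll_and_accuracy(score_rows, targets):
    log_probs = [log_softmax(row) for row in score_rows]
    # Like F.nll_loss with the default "mean" reduction: average of -log p(target).
    loss = -sum(lp[t] for lp, t in zip(log_probs, targets)) / len(targets)
    # Like torch.max(..., 1): take the highest-scoring class in each row.
    predicted = [max(range(len(lp)), key=lp.__getitem__) for lp in log_probs]
    num_correct = sum(p == t for p, t in zip(predicted, targets))
    return loss, num_correct / len(targets)


print(total_params)  # 136074
print(round(normalize_pixel(0), 4), round(normalize_pixel(255), 4))  # -0.4242 2.8215
loss, acc = nll_and_accuracy([[2.0, 0.0, -1.0], [0.5, 0.1, 3.0]], [0, 1])
print(acc)  # 0.5
```

Because log_softmax is monotonic within a row, taking the argmax of LitMNIST's log-probabilities picks the same class as taking it over the raw scores.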

Training with Ray SGD

We can now define our training function using our LitMNIST module and Ray SGD.

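def train_mnist(num_workers=1, use_gpu=False, num_epochs=5):
    Operator = TrainingOperator.from_ptl(LitMNIST)
    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=num_workers,
        use_gpu=use_gpu,
        # This dict is passed to LitMNIST.__init__ as `config`.
        config={"lr": 1e-3, "batch_size": 64})
    for i in range(num_epochs):
        stats = trainer.train()
        print(stats)

    # Save to a file path; the file name below is illustrative.
    checkpoint_path = os.path.join(".", "mnist_checkpoint.pt")
    print("Saving model checkpoint to", checkpoint_path)
    trainer.save(checkpoint_path)
    print("Model Checkpointed!")
    trainer.shutdown()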

With just a single from_ptl call, we can convert our LightningModule to a TrainingOperator class that’s compatible with Ray SGD. Now we can take full advantage of Ray SGD’s distributed training features without having to rewrite our existing LightningModule.
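In train_mnist, the hyperparameter dict given to TorchTrainer is what LitMNIST.__init__ receives as config. The sketch below assumes, for illustration, that every worker builds its own copy of the module from that same dict, and shows how a hyperparameter search supplies a different dict per trial. LitMNISTStandIn, build_replicas, and grid_configs are hypothetical helpers, not Ray or Tune APIs.

```python
import itertools


class LitMNISTStandIn:
    """Hypothetical stand-in for LitMNIST that only records its hyperparameters."""

    def __init__(self, config):
        self.lr = config["lr"]                  # read by configure_optimizers
        self.batch_size = config["batch_size"]  # read by the dataloaders


def build_replicas(module_cls, config, num_workers):
    # Assumption: each worker constructs its own module from the same config dict.
    return [module_cls(config) for _ in range(num_workers)]


def grid_configs(grid):
    # Expand {"lr": [a, b], "batch_size": [c]} into one config dict per trial.
    keys = sorted(grid)
    return [dict(zip(keys, values))
            for values in itertools.product(*(grid[k] for k in keys))]


replicas = build_replicas(LitMNISTStandIn, {"lr": 1e-3, "batch_size": 64}, num_workers=4)
print({(r.lr, r.batch_size) for r in replicas})  # {(0.001, 64)}

trials = grid_configs({"lr": [1e-2, 1e-3], "batch_size": [64]})
print(trials)
```

Keeping every hyperparameter in one dict means the same LightningModule code serves both a single training run and a search, where only the dict changes between trials.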

The last thing to do is initialize Ray and run our training function!

import ray

# Use ray.init(address="auto") if running on a Ray cluster.
ray.init()
train_mnist(num_workers=32, use_gpu=True, num_epochs=5)
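Scaling to num_workers=32 changes how the data is divided. The arithmetic below is a sketch under two assumptions: that batch_size in the config is per worker, and that each worker reads a distinct shard of the 55,000 training images, padded the way PyTorch's DistributedSampler does with drop_last=False. Under those assumptions, each worker runs far fewer steps per epoch, while each synchronized step covers 32 * 64 = 2048 samples. The function name is hypothetical.

```python
import math


def per_worker_schedule(num_samples, num_workers, batch_size):
    # Sharding without drop_last: each worker gets ceil(num_samples / num_workers)
    # samples, with the last shards padded to equal length.
    samples_per_worker = math.ceil(num_samples / num_workers)
    # A DataLoader with drop_last=False still yields the partial final batch.
    steps_per_epoch = math.ceil(samples_per_worker / batch_size)
    # One synchronized step processes one batch on every worker.
    global_batch = batch_size * num_workers
    return samples_per_worker, steps_per_epoch, global_batch


print(per_worker_schedule(55000, 32, 64))  # (1719, 27, 2048)
print(per_worker_schedule(55000, 1, 64))   # (55000, 860, 64)
```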

Unsupported Features

This integration is currently under active development, so not all PyTorch Lightning features are supported. Please post any feature requests on GitHub and we will get to them shortly!

A list of unsupported model hooks (as of v1.0.0) is as follows: test_dataloader, on_test_batch_start, on_test_epoch_start, on_test_batch_end, on_test_epoch_end, get_progress_bar_dict, on_fit_end, on_pretrain_routine_end, manual_backward, tbptt_split_batch.
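Because LightningModule ships default implementations of most of its hooks, a quick way to audit your own module against the list above is to check which of those hooks your subclass actually redefines. The sketch below does that with plain attribute identity. FakeLightningModule and find_overridden are hypothetical stand-ins so the example runs without PyTorch Lightning installed. With Lightning installed, passing your module, LightningModule, and the hook names above should work the same way, but treat that as an untested assumption.

```python
def find_overridden(module_cls, base_cls, hook_names):
    """List the hooks in hook_names that module_cls redefines relative to base_cls."""
    overridden = []
    for name in hook_names:
        sub_attr = getattr(module_cls, name, None)
        base_attr = getattr(base_cls, name, None)
        # Looked up on the class, an inherited method is the very same function
        # object as on the base, so identity reveals whether it was redefined.
        if sub_attr is not None and sub_attr is not base_attr:
            overridden.append(name)
    return overridden


class FakeLightningModule:
    """Hypothetical stand-in for LightningModule with default no-op hooks."""

    def test_dataloader(self):
        return None

    def on_fit_end(self):
        pass

    def manual_backward(self, loss):
        pass


class MyModule(FakeLightningModule):
    def test_dataloader(self):  # redefined, so the check flags it
        return []


HOOKS_TO_CHECK = ["test_dataloader", "on_fit_end", "manual_backward"]
print(find_overridden(MyModule, FakeLightningModule, HOOKS_TO_CHECK))  # ['test_dataloader']
```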