Pytorch Lightning with RaySGD¶

RaySGD includes an integration with Pytorch Lightning's LightningModule. Easily take your existing LightningModule and use it with Ray SGD's TorchTrainer to take advantage of all of Ray SGD's distributed training features with minimal code changes.
Tip: This LightningModule integration is currently under active development. If you encounter any bugs, please raise an issue on GitHub!
Note: Not all Pytorch Lightning features are supported. A full list of unsupported model hooks is provided below. Please post any feature requests on GitHub and we will get to them shortly!
Quick Start¶
Step 1: Define your LightningModule just as you would with Pytorch Lightning.
from pytorch_lightning.core.lightning import LightningModule

class MyLightningModule(LightningModule):
    ...
Step 2: Use the TrainingOperator.from_ptl method to convert the LightningModule to a Ray SGD-compatible LightningOperator.
from ray.util.sgd.torch import TrainingOperator
MyLightningOperator = TrainingOperator.from_ptl(MyLightningModule)
Step 3: Use the Operator with Ray SGD's TorchTrainer, just as you normally would. See Distributed PyTorch for a more complete guide on TorchTrainer.
import ray
from ray.util.sgd.torch import TorchTrainer
ray.init()
trainer = TorchTrainer(training_operator_cls=MyLightningOperator, num_workers=4, use_gpu=True)
train_stats = trainer.train()
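After training, the same trainer can also run validation, save a checkpoint, and release its workers. These are the same TorchTrainer calls used in the MNIST tutorial below:

print(trainer.validate())   # run validation across the workers
trainer.save("./model.pt")  # checkpoint the trained model to disk
trainer.shutdown()          # release the Ray worker processes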
And that’s it! For a more comprehensive guide, see the MNIST tutorial below.
MNIST Tutorial¶
In this walkthrough we will show how to train an MNIST classifier with Pytorch Lightning's LightningModule and Ray SGD. We will follow this tutorial from the PyTorch Lightning documentation for specifying our MNIST LightningModule.
Setup / Imports¶
Let’s start with some basic imports:
import os

# Pytorch imports
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import MNIST

# Ray imports
import ray
from ray.util.sgd import TorchTrainer
from ray.util.sgd.torch import TrainingOperator

# PTL imports
from pytorch_lightning.core.lightning import LightningModule
Most of these imports are needed for building our Pytorch model and training components. Only a few additional imports are needed for Ray and Pytorch Lightning.
MNIST LightningModule¶
We now define our Pytorch Lightning LightningModule:
class LitMNIST(LightningModule):
    # We take in an additional config parameter here. But this is not required.
    def __init__(self, config):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)
        self.config = config

    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        x = torch.relu(x)
        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)
        return x

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.config["lr"])

    def setup(self, stage):
        # transforms for images
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])

        # prepare transforms standard to MNIST
        mnist_train = MNIST(
            os.path.expanduser("~/data"),
            train=True,
            download=True,
            transform=transform)

        self.mnist_train, self.mnist_val = random_split(
            mnist_train, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train, batch_size=self.config["batch_size"])

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.config["batch_size"])

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        _, predicted = torch.max(logits.data, 1)
        num_correct = (predicted == y).sum().item()
        num_samples = y.size(0)
        return {"val_loss": loss.item(), "val_acc": num_correct / num_samples}
This is the same code that would normally be used in Pytorch Lightning, and is taken directly from this PTL guide. The only difference here is that the __init__ method can optionally take in a config argument as a way to pass hyperparameters to your model, optimizer, or schedulers. The config will be passed in directly from the TorchTrainer, or, if you are using Ray SGD in conjunction with Tune (see RaySGD Hyperparameter Tuning), it will come directly from the config in your tune.run call.
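For reference, here is a minimal sketch of the Tune path. It assumes the TorchTrainer.as_trainable helper described in the RaySGD Hyperparameter Tuning guide and reuses the same from_ptl conversion performed by the training function below; the keys of each trial's config are what LitMNIST.__init__ receives:

from ray import tune

# Convert the LightningModule and wrap the trainer as a Tune Trainable.
Operator = TrainingOperator.from_ptl(LitMNIST)
TorchTrainable = TorchTrainer.as_trainable(
    training_operator_cls=Operator,
    num_workers=2,
    use_gpu=False,
)

# Each sampled config ("lr", "batch_size") reaches LitMNIST via its __init__.
analysis = tune.run(
    TorchTrainable,
    config={
        "lr": tune.grid_search([1e-4, 1e-3]),
        "batch_size": 64,
    },
    stop={"training_iteration": 2},
)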
Training with Ray SGD¶
We can now define our training function using our LitMNIST module and Ray SGD.
def train_mnist(num_workers=1, use_gpu=False, num_epochs=5):
    Operator = TrainingOperator.from_ptl(LitMNIST)
    trainer = TorchTrainer(
        training_operator_cls=Operator,
        num_workers=num_workers,
        config={
            "lr": 1e-3,
            "batch_size": 64
        },
        use_gpu=use_gpu,
        use_tqdm=True,
    )

    for i in range(num_epochs):
        stats = trainer.train()
        print(stats)

    print(trainer.validate())

    print("Saving model checkpoint to ./model.pt")
    trainer.save("./model.pt")
    print("Model Checkpointed!")

    trainer.shutdown()
    print("success!")
With just a single from_ptl call, we can convert our LightningModule to a TrainingOperator class that's compatible with Ray SGD. Now we can take full advantage of all of Ray SGD's distributed training features without having to rewrite our existing LightningModule.
The last thing to do is initialize Ray, and run our training function!
# Use ray.init(address="auto") if running on a Ray cluster.
ray.init()
train_mnist(num_workers=32, use_gpu=True, num_epochs=5)
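If you later want the trained weights back, one possible sketch is to re-create a trainer, restore the checkpoint with TorchTrainer.load, and pull out the model with get_model; the path and config values below simply mirror those used in train_mnist:

# Restore the checkpoint saved by train_mnist and fetch the trained model.
Operator = TrainingOperator.from_ptl(LitMNIST)
trainer = TorchTrainer(
    training_operator_cls=Operator,
    num_workers=1,
    config={"lr": 1e-3, "batch_size": 64},
)
trainer.load("./model.pt")

model = trainer.get_model()  # the trained model, usable locally
trainer.shutdown()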
Unsupported Features¶
This integration is currently under active development, so not all Pytorch Lightning features are supported. Please post any feature requests on GitHub and we will get to them shortly!
A list of unsupported model hooks (as of v1.0.0) is as follows: test_dataloader, on_test_batch_start, on_test_epoch_start, on_test_batch_end, on_test_epoch_end, get_progress_bar_dict, on_fit_end, on_pretrain_routine_end, manual_backward, tbptt_split_batch.