import math
import os
import pytorch_lightning as pl
import torch
from filelock import FileLock
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST
from ray import train, tune
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback
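# Tune the hyperparameters of a PyTorch Lightning MNIST classifier with Ray
# Tune. Each trial trains its own model; validation metrics are reported back
# to Tune via the TuneReportCheckpointCallback configured below.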
PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
class MNISTDataModule(pl.LightningDataModule):
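    """LightningDataModule that downloads MNIST and serves the train/val/test dataloaders."""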
def __init__(self, batch_size: int, data_dir: str = PATH_DATASETS):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
)
self.batch_size = batch_size
self.dims = (1, 28, 28)
self.num_classes = 10
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
self.mnist_test = MNIST(
self.data_dir, train=False, transform=self.transform
)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
class LightningMNISTClassifier(pl.LightningModule):
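    """A three-layer MLP classifier whose layer widths and learning rate are
    drawn from a Ray Tune trial config."""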
def __init__(self, config, data_dir=None):
        super().__init__()
self.data_dir = data_dir or os.getcwd()
self.lr = config["lr"]
layer_1, layer_2 = config["layer_1"], config["layer_2"]
self.batch_size = config["batch_size"]
        # MNIST images are (1, 28, 28): (channels, height, width)
self.layer_1 = torch.nn.Linear(28 * 28, layer_1)
self.layer_2 = torch.nn.Linear(layer_1, layer_2)
self.layer_3 = torch.nn.Linear(layer_2, 10)
self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
def forward(self, x):
        # Flatten (N, 1, 28, 28) images into (N, 784) vectors.
        x = x.view(x.size(0), -1)
x = self.layer_1(x)
x = torch.relu(x)
x = self.layer_2(x)
x = torch.relu(x)
x = self.layer_3(x)
x = torch.log_softmax(x, dim=1)
return x
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.lr)
def training_step(self, train_batch, batch_idx):
x, y = train_batch
logits = self.forward(x)
loss = F.nll_loss(logits, y)
acc = self.accuracy(logits, y)
self.log("ptl/train_loss", loss)
self.log("ptl/train_accuracy", acc)
return loss
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = F.nll_loss(logits, y)
        acc = self.accuracy(logits, y)
        # Logging from `validation_step` aggregates over the epoch by default,
        # which replaces the `validation_epoch_end` hook removed in PyTorch
        # Lightning 2.0 while still producing epoch-level means under these keys.
        self.log("ptl/val_loss", loss)
        self.log("ptl/val_accuracy", acc)
        return loss
def train_mnist_tune(config, num_epochs=10, num_gpus=0):
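    """Per-trial training function executed by Ray Tune.

    `config` holds one sampled hyperparameter set; the callback below forwards
    the logged `ptl/val_loss` and `ptl/val_accuracy` metrics back to Tune.
    """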
data_dir = os.path.abspath("./data")
model = LightningMNISTClassifier(config, data_dir)
    # Hold a file lock while downloading so that only one trial per node
    # fetches the dataset; MNIST(download=True) is a no-op once the files exist.
    with FileLock(os.path.expanduser("~/.data.lock")):
        dm = MNISTDataModule(data_dir=data_dir, batch_size=config["batch_size"])
        dm.prepare_data()
metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
trainer = pl.Trainer(
max_epochs=num_epochs,
        # Fractional GPU allocations from Tune are rounded up to a whole
        # device count (the `gpus=` argument was removed in Lightning 2.0).
        accelerator="gpu" if num_gpus > 0 else "cpu",
        devices=math.ceil(num_gpus) if num_gpus > 0 else 1,
enable_progress_bar=False,
callbacks=[
TuneReportCheckpointCallback(
metrics, on="validation_end", save_checkpoints=False
)
],
)
trainer.fit(model, dm)
def tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0):
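    """Runs `num_samples` Tune trials, each drawing one point from the search space below."""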
config = {
"layer_1": tune.choice([32, 64, 128]),
"layer_2": tune.choice([64, 128, 256]),
"lr": tune.loguniform(1e-4, 1e-1),
"batch_size": tune.choice([32, 64, 128]),
}
trainable = tune.with_parameters(
train_mnist_tune, num_epochs=num_epochs, num_gpus=gpus_per_trial
)
tuner = tune.Tuner(
tune.with_resources(trainable, resources={"cpu": 1, "gpu": gpus_per_trial}),
tune_config=tune.TuneConfig(
metric="loss",
mode="min",
num_samples=num_samples,
),
run_config=train.RunConfig(
name="tune_mnist",
),
param_space=config,
)
results = tuner.fit()
print("Best hyperparameters found were: ", results.get_best_result().config)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing"
)
args, _ = parser.parse_known_args()
if args.smoke_test:
tune_mnist(num_samples=1, num_epochs=1, gpus_per_trial=0)
else:
tune_mnist(num_samples=10, num_epochs=10, gpus_per_trial=0)