# MNIST PyTorch Trainable Example

# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
from __future__ import print_function

import argparse
import os

import torch
import torch.optim as optim

import ray
from ray import train, tune
from ray.tune.examples.mnist_pytorch import (
    ConvNet,
    get_data_loaders,
    test_func,
    train_func,
)
from ray.tune.schedulers import ASHAScheduler

# Change these values if you want the training to run quicker or slower.
# NOTE(review): nothing in this file reads EPOCH_SIZE / TEST_SIZE; the imported
# get_data_loaders/train_func/test_func presumably use their own copies — confirm
# before relying on these to change run time.
EPOCH_SIZE = 512
TEST_SIZE = 256

# Training settings (command-line flags for the __main__ entry point).
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
# Requests one GPU per trial and lets the trainable use CUDA when available.
parser.add_argument(
    "--use-gpu", action="store_true", default=False, help="enables CUDA training"
)
# Passed straight to ray.init(address=...); None starts a local instance.
parser.add_argument(
    "--ray-address",
    type=str,
    help="Address of an existing Ray cluster to connect to.",
)
# Shrinks the search (fewer samples and iterations) for quick CI runs.
parser.add_argument(
    "--smoke-test", action="store_true", help="Finish quickly for testing"
)


# Below comments are for documentation purposes only.
# fmt: off
# __trainable_example_begin__
class TrainMNIST(tune.Trainable):
    """Class-based Tune trainable that fits a ``ConvNet`` on MNIST.

    Each ``step`` calls ``train_func`` once, then evaluates with ``test_func``
    and reports the result under the ``mean_accuracy`` key.
    """

    def setup(self, config):
        """Create the device, data loaders, model and SGD optimizer.

        Args:
            config: Trial config dict. Reads ``use_gpu`` (truthy to use CUDA
                when it is available), ``lr`` (default 0.01) and
                ``momentum`` (default 0.9).
        """
        use_cuda = bool(config.get("use_gpu")) and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet().to(self.device)
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            momentum=config.get("momentum", 0.9))

    def step(self):
        """Run one training call, then return ``{"mean_accuracy": acc}``."""
        train_func(
            self.model, self.optimizer, self.train_loader, device=self.device)
        acc = test_func(self.model, self.test_loader, self.device)
        return {"mean_accuracy": acc}

    def save_checkpoint(self, checkpoint_dir):
        """Write model and optimizer state into ``checkpoint_dir``."""
        torch.save(self.model.state_dict(),
                   os.path.join(checkpoint_dir, "model.pth"))
        # SGD momentum buffers live in the optimizer state; saving only the
        # weights would make a restored trial diverge from an uninterrupted one.
        torch.save(self.optimizer.state_dict(),
                   os.path.join(checkpoint_dir, "optimizer.pth"))

    def load_checkpoint(self, checkpoint_dir):
        """Restore state previously written by ``save_checkpoint``.

        ``map_location`` remaps tensors onto this trial's device, so a
        checkpoint taken on a GPU trial can be restored on a CPU-only one.
        ``optimizer.pth`` is optional so checkpoints containing only
        ``model.pth`` still load.
        """
        model_path = os.path.join(checkpoint_dir, "model.pth")
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device))
        optimizer_path = os.path.join(checkpoint_dir, "optimizer.pth")
        if os.path.exists(optimizer_path):
            self.optimizer.load_state_dict(
                torch.load(optimizer_path, map_location=self.device))


# __trainable_example_end__
# fmt: on

if __name__ == "__main__":
    args = parser.parse_args()
    # ray.init refuses num_cpus when connecting to an existing cluster, so only
    # cap CPUs for a local smoke-test instance.
    num_cpus = 6 if args.smoke_test and not args.ray_address else None
    ray.init(address=args.ray_address, num_cpus=num_cpus)
    sched = ASHAScheduler()

    tuner = tune.Tuner(
        tune.with_resources(TrainMNIST, resources={"cpu": 3, "gpu": int(args.use_gpu)}),
        run_config=train.RunConfig(
            stop={
                "mean_accuracy": 0.95,
                "training_iteration": 3 if args.smoke_test else 20,
            },
            checkpoint_config=train.CheckpointConfig(
                checkpoint_at_end=True, checkpoint_frequency=3
            ),
        ),
        tune_config=tune.TuneConfig(
            metric="mean_accuracy",
            mode="max",
            scheduler=sched,
            num_samples=1 if args.smoke_test else 20,
        ),
        param_space={
            "args": args,
            # TrainMNIST.setup reads config["use_gpu"]; without this key every
            # trial trains on CPU even though a GPU was reserved above.
            "use_gpu": args.use_gpu,
            "lr": tune.uniform(0.001, 0.1),
            "momentum": tune.uniform(0.1, 0.9),
        },
    )
    results = tuner.fit()

    print("Best config is:", results.get_best_result().config)