# Original Code here:
# https://github.com/pytorch/examples/blob/master/mnist/main.py
from __future__ import print_function
import argparse
import os
import torch
import torch.optim as optim
import ray
from ray import train, tune
from ray.tune.examples.mnist_pytorch import (
    ConvNet,
    get_data_loaders,
    test_func,
    train_func,
)
from ray.tune.schedulers import ASHAScheduler

# Change these values if you want the training to run quicker or slower.
EPOCH_SIZE = 512
TEST_SIZE = 256

# Training settings
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument(
    "--use-gpu", action="store_true", default=False, help="enables CUDA training"
)
parser.add_argument("--ray-address", type=str, help="The Redis address of the cluster.")
parser.add_argument(
    "--smoke-test", action="store_true", help="Finish quickly for testing"
)
# Below comments are for documentation purposes only.
# fmt: off
# __trainable_example_begin__
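# Tune creates one instance of this class per trial: setup() receives the
# sampled hyperparameters, step() is called repeatedly until a stopping
# condition is met, and save_checkpoint()/load_checkpoint() persist and
# restore the model weights.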
class TrainMNIST(tune.Trainable):
    def setup(self, config):
        use_cuda = config.get("use_gpu") and torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet().to(self.device)
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            momentum=config.get("momentum", 0.9))
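
    # Each step() call runs the imported train_func once and then evaluates with
    # test_func; the returned dict is reported to Tune as this iteration's result.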
    def step(self):
        train_func(
            self.model, self.optimizer, self.train_loader, device=self.device)
        acc = test_func(self.model, self.test_loader, self.device)
        return {"mean_accuracy": acc}
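
    # Tune supplies the checkpoint directory; only the model weights need to be
    # written there (CheckpointConfig below controls how often this is called).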
    def save_checkpoint(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
    def load_checkpoint(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))
# __trainable_example_end__
# fmt: on
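

# Driver: parse CLI flags, connect to (or start) Ray, and run the ASHA-scheduled
# hyperparameter search over learning rate and momentum.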
if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(address=args.ray_address, num_cpus=6 if args.smoke_test else None)
    sched = ASHAScheduler()
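
    # ASHA terminates poorly performing trials early. Each trial gets 3 CPUs
    # (plus a GPU when --use-gpu is passed), and 20 configurations are sampled
    # unless this is a smoke test.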
    tuner = tune.Tuner(
        tune.with_resources(TrainMNIST, resources={"cpu": 3, "gpu": int(args.use_gpu)}),
        run_config=train.RunConfig(
            stop={
                "mean_accuracy": 0.95,
                "training_iteration": 3 if args.smoke_test else 20,
            },
            checkpoint_config=train.CheckpointConfig(
                checkpoint_at_end=True, checkpoint_frequency=3
            ),
        ),
        tune_config=tune.TuneConfig(
            metric="mean_accuracy",
            mode="max",
            scheduler=sched,
            num_samples=1 if args.smoke_test else 20,
        ),
        param_space={
            # Pass the flag that setup() actually reads via config.get("use_gpu").
            "use_gpu": args.use_gpu,
            "lr": tune.uniform(0.001, 0.1),
            "momentum": tune.uniform(0.1, 0.9),
        },
    )
    results = tuner.fit()
    print("Best config is:", results.get_best_result().config)