How to use Tune with PyTorch
In this walkthrough, we will show you how to integrate Tune into your PyTorch training workflow. We will follow the CIFAR10 image classifier tutorial from the PyTorch documentation.

Hyperparameter tuning can make the difference between an average model and a highly accurate one. Often, simple things like choosing a different learning rate or changing a network layer size can have a dramatic impact on your model's performance. Fortunately, Tune makes it easy to explore these parameter combinations and find the best-performing ones, and it works nicely together with PyTorch.
As you will see, we only need to make a few small modifications. In particular, we need to wrap data loading and training in functions, make some network parameters configurable, add checkpointing (optional), and define the search space for the model tuning.
Note
To run this example, you will need to install the following:
$ pip install "ray[tune]" torch torchvision
Setup / Imports
Let’s start with the imports:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from filelock import FileLock
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import ASHAScheduler
Most of the imports are needed for building the PyTorch model. Only the last four imports are for Ray Tune.
Data loaders
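We wrap the data loaders in their own function, which takes the data directory as an argument. Trials that are given the same absolute directory can share one copy of the dataset, and the file lock ensures that only one process downloads it at a time.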
def load_data(data_dir="./data"):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # We add a FileLock here because multiple trials may try to download
    # the dataset into the same directory at the same time, and
    # concurrent downloads can overwrite each other's files.
    with FileLock(os.path.expanduser("~/.data.lock")):
        trainset = torchvision.datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform)

        testset = torchvision.datasets.CIFAR10(
            root=data_dir, train=False, download=True, transform=transform)

    return trainset, testset
Configurable neural network
We can only tune parameters that are configurable. In this example, the constructor takes the output sizes l1 and l2 of the first two fully connected layers:
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
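The input size of fc1, 16 * 5 * 5, is fixed by the convolutional part of the network, which is why only l1 and l2 are exposed as tunable. The short sketch below needs no PyTorch (the helper names are our own): it traces a 32x32 CIFAR10 image through the layers using the standard output-size formula for convolutions and pooling without padding.

```python
# Output-size arithmetic for the Net defined above, in plain Python.
# For a conv or pooling layer without padding or dilation:
#   out = (in - kernel) // stride + 1


def conv_out(size: int, kernel: int, stride: int = 1) -> int:
    """Spatial size after a 2D conv/pool layer with no padding."""
    return (size - kernel) // stride + 1


def flattened_features(image_size: int = 32) -> int:
    """Trace a square image through conv1 -> pool -> conv2 -> pool."""
    size = conv_out(image_size, kernel=5)      # conv1: 32 -> 28
    size = conv_out(size, kernel=2, stride=2)  # pool:  28 -> 14
    size = conv_out(size, kernel=5)            # conv2: 14 -> 10
    size = conv_out(size, kernel=2, stride=2)  # pool:  10 -> 5
    out_channels = 16                          # conv2 produces 16 channels
    return out_channels * size * size


print(flattened_features())  # 400, i.e. 16 * 5 * 5
```

Changing the kernel sizes or the number of conv2 channels would change this number, so those would need matching changes in fc1 before they could be tuned.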
The train function
Now it gets interesting, because we introduce some changes to the example from the PyTorch documentation.
The full code example looks like this:
def train_cifar(config):
    net = Net(config["l1"], config["l2"])

    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda:0"
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

    # To restore a checkpoint, use `train.get_checkpoint()`.
    loaded_checkpoint = train.get_checkpoint()
    if loaded_checkpoint:
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            model_state, optimizer_state = torch.load(
                os.path.join(loaded_checkpoint_dir, "checkpoint.pt"))
            net.load_state_dict(model_state)
            optimizer.load_state_dict(optimizer_state)

    # By default, Tune runs each trial in its own working directory, so this
    # relative path gives every trial its own copy of the data (see the
    # download logs below). Use a fixed absolute path to share one copy.
    data_dir = os.path.abspath("./data")
    trainset, testset = load_data(data_dir)

    # Hold out 20% of the training set for validation.
    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
        trainset, [test_abs, len(trainset) - test_abs])

    trainloader = torch.utils.data.DataLoader(
        train_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)
    valloader = torch.utils.data.DataLoader(
        val_subset,
        batch_size=int(config["batch_size"]),
        shuffle=True,
        num_workers=8)

    for epoch in range(10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics: average loss over the last 2000 mini-batches
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
                                                running_loss / epoch_steps))
                running_loss = 0.0
                epoch_steps = 0

        # Validation loss
        val_loss = 0.0
        val_steps = 0
        total = 0
        correct = 0
        for i, data in enumerate(valloader, 0):
            with torch.no_grad():
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                loss = criterion(outputs, labels)
                val_loss += loss.item()
                val_steps += 1

        # Here we save a checkpoint. It is automatically registered with
        # Ray Tune, which passes it back through `train.get_checkpoint()`
        # when the trial is restored (for example, after a failure).
        os.makedirs("my_model", exist_ok=True)
        torch.save(
            (net.state_dict(), optimizer.state_dict()), "my_model/checkpoint.pt")
        checkpoint = Checkpoint.from_directory("my_model")
        train.report(
            {"loss": (val_loss / val_steps), "accuracy": correct / total},
            checkpoint=checkpoint)

    print("Finished Training")
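As you can see, most of the code is adapted directly from the example. The Tune-specific changes are that the layer sizes, learning rate, and batch size come from config, the function restores its state from train.get_checkpoint() if a checkpoint exists, 20% of the training set is held out for validation, and every epoch ends with train.report(), which sends the validation loss and accuracy to Tune together with a checkpoint.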
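A quick way to sanity-check the numbers in the training logs: with the 80/20 split and a batch size of 4 (the value both trials happened to sample in the example run), each epoch has 10,000 mini-batches, so the running loss is printed five times per epoch. The helpers below are hypothetical and only reproduce that arithmetic.

```python
from typing import Tuple

CIFAR10_TRAIN_IMAGES = 50_000  # size of the CIFAR10 training set


def split_sizes(n: int, train_fraction: float = 0.8) -> Tuple[int, int]:
    """Train/validation sizes, using the same expression as test_abs in train_cifar."""
    train_size = int(n * train_fraction)
    return train_size, n - train_size


def batches_per_epoch(n_examples: int, batch_size: int) -> int:
    """Mini-batches per epoch; a DataLoader keeps the last partial batch by default."""
    return -(-n_examples // batch_size)  # ceiling division


train_size, val_size = split_sizes(CIFAR10_TRAIN_IMAGES)
print(train_size, val_size)  # 40000 10000
steps = batches_per_epoch(train_size, batch_size=4)
print(steps)  # 10000
# Mini-batch counts at which the loss is printed (i % 2000 == 1999)
print([i + 1 for i in range(steps) if i % 2000 == 1999])  # [2000, 4000, 6000, 8000, 10000]
```

Smaller sampled batch sizes (such as 2) double the number of steps per epoch, which is one reason per-trial run times vary so much.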
Test set accuracy
The performance of a machine learning model is commonly evaluated on a held-out test set, that is, on data that was not used for training the model. We also wrap this evaluation in a function:
def test_best_model(best_result):
    best_trained_model = Net(best_result.config["l1"], best_result.config["l2"])
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    best_trained_model.to(device)

    checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), "checkpoint.pt")

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    model_state, optimizer_state = torch.load(checkpoint_path, map_location=device)
    best_trained_model.load_state_dict(model_state)

    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
        testset, batch_size=4, shuffle=False, num_workers=2)

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = best_trained_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print("Best trial test set accuracy: {}".format(correct / total))
The function runs the test set evaluation on a GPU if one is available and falls back to the CPU otherwise.
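The accuracy computation in test_best_model takes, for each image, the class with the largest output (the indices returned as the second value of torch.max(outputs.data, 1)) and compares it with the label. As a dependency-free illustration, here is the same logic in plain Python on made-up logits; the helper names and numbers are hypothetical.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value in a row of logits."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(logits: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of examples whose highest-scoring class equals the label."""
    correct = sum(argmax(row) == label for row, label in zip(logits, labels))
    return correct / len(labels)


# Hypothetical logits for a batch of 4 images over 3 classes
batch_logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 3.0, 0.5], [1.0, 1.2, 0.9]]
batch_labels = [0, 2, 1, 0]
print(accuracy(batch_logits, batch_labels))  # 0.75: the last image is misclassified
```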
Configuring the search space
Lastly, we need to define Tune’s search space. Here is an example:
config = {
    "l1": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "l2": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16]),
}
The tune.sample_from() function makes it possible to define your own sampling methods to obtain hyperparameters. In this example, the l1 and l2 parameters should be powers of 2 between 4 and 256, so either 4, 8, 16, 32, 64, 128, or 256. The lr (learning rate) is sampled log-uniformly between 0.0001 and 0.1, that is, uniformly on a logarithmic scale, so that each order of magnitude is equally likely. Lastly, the batch size is a choice between 2, 4, 8, and 16.
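To make the sampling behavior concrete, here is a standalone sketch that draws configurations the same way using only Python's random module. It is an illustration under our own assumptions, not how Tune implements its samplers internally; it mainly shows that log-uniform sampling spreads the learning rates evenly across the three decades between 1e-4 and 1e-1.

```python
import math
import random

rng = random.Random(0)  # seeded so the sketch is reproducible


def sample_layer_size() -> int:
    # Mirrors 2**np.random.randint(2, 9): NumPy's upper bound is exclusive,
    # while random.randint includes both endpoints, hence 2..8 here.
    return 2 ** rng.randint(2, 8)


def sample_loguniform(low: float, high: float) -> float:
    # Uniform in log space, so every decade is equally likely.
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def sample_config() -> dict:
    return {
        "l1": sample_layer_size(),
        "l2": sample_layer_size(),
        "lr": sample_loguniform(1e-4, 1e-1),
        "batch_size": rng.choice([2, 4, 8, 16]),
    }


samples = [sample_config() for _ in range(1000)]
decades = [math.floor(math.log10(s["lr"])) for s in samples]
for d in (-4, -3, -2):
    # Roughly a third of the samples fall into each decade.
    print(f"1e{d} to 1e{d + 1}: {decades.count(d)} samples")
```

With plain uniform sampling between 0.0001 and 0.1, about 90% of the draws would land above 0.01, so small learning rates would rarely be tried.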
For each trial, Tune randomly samples a combination of parameters from these search spaces. It then trains a number of models in parallel and finds the best performing one among them. We also use the ASHAScheduler, which terminates badly performing trials early.
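The following toy model is our own simplification of asynchronous successive halving, not Ray's implementation. With the settings used in main (grace_period=1, reduction_factor=2), trials are compared at "rungs" at iterations 1, 2, 4, 8, and so on up to max_t, and at each rung a trial keeps running only if its loss is in the better half of the losses recorded there so far. Feeding it the two validation losses from the example run reproduces the decision in the log, where the second trial is stopped after one epoch, and its negated median matches the "Iter 1.000" value in the status output.

```python
import math
from typing import Dict, List


def rung_milestones(max_t: int, grace_period: int = 1, reduction_factor: int = 2) -> List[int]:
    """Iterations grace_period * reduction_factor**k that do not exceed max_t."""
    milestones, t = [], grace_period
    while t <= max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


def percentile(values: List[float], q: float) -> float:
    """Linear-interpolation percentile with q in [0, 1]."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


class SimpleASHA:
    """Toy asynchronous successive halving for a metric that is minimized."""

    def __init__(self, max_t: int, grace_period: int = 1, reduction_factor: int = 2):
        self.rf = reduction_factor
        self.rungs: Dict[int, List[float]] = {
            t: [] for t in rung_milestones(max_t, grace_period, reduction_factor)
        }

    def on_result(self, iteration: int, loss: float) -> bool:
        """Record a result; return True if the trial should keep training."""
        if iteration not in self.rungs:
            return True  # decisions are only made at rung milestones
        recorded = self.rungs[iteration]
        recorded.append(-loss)  # negate so that higher scores are better
        cutoff = percentile(recorded, 1 - 1 / self.rf)  # keep the top 1/rf
        return -loss >= cutoff


print(rung_milestones(2), rung_milestones(10))  # [1, 2] [1, 2, 4, 8]
asha = SimpleASHA(max_t=2)
print(asha.on_result(1, 1.625945699906349))   # True: first trial to reach rung 1
print(asha.on_result(1, 1.9046219720602036))  # False: worse than the median, stopped
print(percentile(asha.rungs[1], 0.5))         # about -1.7653
```

Because decisions only depend on results already recorded at a rung, no trial has to wait for the others, which is what makes the scheme asynchronous.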
You can specify the number of CPUs per trial, which are then available, for example, to increase the num_workers of the PyTorch DataLoader instances. The selected number of GPUs is made visible to PyTorch in each trial. Trials do not have access to GPUs that haven't been requested for them, so you don't have to worry about two trials using the same set of resources.
Here we can also specify fractional GPUs, so something like gpus_per_trial=0.5 is completely valid. The trials will then share GPUs among each other. You just have to make sure that the models still fit in the GPU memory.
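As a rough capacity estimate that counts only CPUs and GPUs (Ray's scheduler also accounts for memory and other resources, which this sketch ignores), the number of trials that can run at the same time is the smaller of the two per-resource limits. The function name is hypothetical; Fraction is used so that fractional requests such as 0.5 divide exactly.

```python
from fractions import Fraction


def max_concurrent_trials(total_cpus: float, total_gpus: float,
                          cpus_per_trial: float, gpus_per_trial: float) -> int:
    """Trials that fit at once if CPUs and GPUs were the only constraints."""
    limits = [int(Fraction(str(total_cpus)) / Fraction(str(cpus_per_trial)))]
    if gpus_per_trial > 0:  # CPU-only trials are not limited by GPUs
        limits.append(int(Fraction(str(total_gpus)) / Fraction(str(gpus_per_trial))))
    return min(limits)


print(max_concurrent_trials(16, 0, 2, 0))    # 8: CPU-only trials with 2 CPUs each
print(max_concurrent_trials(16, 1, 2, 0.5))  # 2: two trials share one GPU
print(max_concurrent_trials(16, 4, 2, 2))    # 2: gpus_per_trial=2 on a 4-GPU machine
```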
After training the models, we will find the best performing one and load the trained network from the checkpoint file. We then obtain the test set accuracy and report everything by printing.
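Selecting the best trial compares trials by the metric they reported. Assuming the default scope="last" of ResultGrid.get_best_result, only each trial's final reported value counts. The sketch below (hypothetical helper, with the validation losses copied from the example run) shows that selection in plain Python.

```python
from typing import Dict, List

# Validation losses reported after each epoch in the example run
reported_losses: Dict[str, List[float]] = {
    "train_cifar_66098_00000": [1.625945699906349, 1.421571186053753],
    "train_cifar_66098_00001": [1.9046219720602036],
}


def best_trial(history: Dict[str, List[float]], mode: str = "min", scope: str = "last") -> str:
    """Pick a trial by its last reported value ("last") or its best value ("all")."""
    pick = min if mode == "min" else max

    def score(trial: str) -> float:
        values = history[trial]
        return values[-1] if scope == "last" else pick(values)

    return pick(history, key=score)


print(best_trial(reported_losses))  # train_cifar_66098_00000
```

The two scopes can disagree when a trial's loss gets worse in later epochs: "last" rewards the final model, which is the one whose checkpoint test_best_model loads.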
The full main function looks like this:
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    config = {
        "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2, 4, 8, 16])
    }
    scheduler = ASHAScheduler(
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2)

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(train_cifar),
            resources={"cpu": 2, "gpu": gpus_per_trial}
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        param_space=config,
    )
    results = tuner.fit()

    best_result = results.get_best_result("loss", "min")

    print("Best trial config: {}".format(best_result.config))
    print("Best trial final validation loss: {}".format(
        best_result.metrics["loss"]))
    print("Best trial final validation accuracy: {}".format(
        best_result.metrics["accuracy"]))

    test_best_model(best_result)


main(num_samples=2, max_num_epochs=2, gpus_per_trial=0)
2022-07-22 16:38:53,384 INFO services.py:1483 -- View the Ray dashboard at http://127.0.0.1:8273
2022-07-22 16:38:56,785 WARNING function_trainable.py:619 --
Current time: 2022-07-22 16:40:13 (running for 00:01:16.43)
Memory usage on this node: 10.7/16.0 GiB
Using AsyncHyperBand: num_stopped=2 Bracket: Iter 2.000: -1.421571186053753 | Iter 1.000: -1.7652838359832763
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/5.63 GiB heap, 0.0/2.0 GiB objects
Current best trial: 66098_00000 with loss=1.421571186053753 and parameters={'l1': 128, 'l2': 128, 'lr': 0.00046907397024184945, 'batch_size': 4}
Result logdir: /Users/kai/ray_results/train_cifar_2022-07-22_16-38-50
Number of trials: 2/2 (2 TERMINATED)
Trial name | status | loc | batch_size | l1 | l2 | lr | iter | total time (s) | loss | accuracy |
---|---|---|---|---|---|---|---|---|---|---|
train_cifar_66098_00000 | TERMINATED | 127.0.0.1:53065 | 4 | 128 | 128 | 0.000469074 | 2 | 72.6176 | 1.42157 | 0.4877 |
train_cifar_66098_00001 | TERMINATED | 127.0.0.1:53078 | 4 | 128 | 64 | 0.00993903 | 1 | 64.9721 | 1.90462 | 0.2915 |
2022-07-22 16:38:57,794 INFO plugin_schema_manager.py:52 -- Loading the default runtime env schemas: ['/Users/kai/coding/ray/python/ray/_private/runtime_env/../../runtime_env/schemas/working_dir_schema.json', '/Users/kai/coding/ray/python/ray/_private/runtime_env/../../runtime_env/schemas/pip_schema.json'].
(train_cifar pid=53065) Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /Users/kai/ray_results/train_cifar_2022-07-22_16-38-50/train_cifar_66098_00000_0_batch_size=4,l1=128,l2=128,lr=0.0005_2022-07-22_16-38-57/data/cifar-10-python.tar.gz
(train_cifar pid=53065) Extracting /Users/kai/ray_results/train_cifar_2022-07-22_16-38-50/train_cifar_66098_00000_0_batch_size=4,l1=128,l2=128,lr=0.0005_2022-07-22_16-38-57/data/cifar-10-python.tar.gz to /Users/kai/ray_results/train_cifar_2022-07-22_16-38-50/train_cifar_66098_00000_0_batch_size=4,l1=128,l2=128,lr=0.0005_2022-07-22_16-38-57/data
(train_cifar pid=53065) Files already downloaded and verified
(train_cifar pid=53065) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ../c10/core/TensorImpl.h:1156.)
(train_cifar pid=53065) return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
(train_cifar pid=53078) Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /Users/kai/ray_results/train_cifar_2022-07-22_16-38-50/train_cifar_66098_00001_1_batch_size=4,l1=128,l2=64,lr=0.0099_2022-07-22_16-39-01/data/cifar-10-python.tar.gz
(train_cifar pid=53065) [1, 2000] loss: 2.289
(train_cifar pid=53065) [1, 4000] loss: 1.058
(train_cifar pid=53065) [1, 6000] loss: 0.633
(train_cifar pid=53078) Extracting /Users/kai/ray_results/train_cifar_2022-07-22_16-38-50/train_cifar_66098_00001_1_batch_size=4,l1=128,l2=64,lr=0.0099_2022-07-22_16-39-01/data/cifar-10-python.tar.gz to /Users/kai/ray_results/train_cifar_2022-07-22_16-38-50/train_cifar_66098_00001_1_batch_size=4,l1=128,l2=64,lr=0.0099_2022-07-22_16-39-01/data
(train_cifar pid=53065) [1, 8000] loss: 0.434
(train_cifar pid=53078) Files already downloaded and verified
(train_cifar pid=53078) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ../c10/core/TensorImpl.h:1156.)
(train_cifar pid=53078) return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
(train_cifar pid=53065) [1, 10000] loss: 0.325
(train_cifar pid=53078) [1, 2000] loss: 2.117
Result for train_cifar_66098_00000:
accuracy: 0.4004
date: 2022-07-22_16-39-47
done: false
experiment_id: 6512b700fdb64a458c3496f36ea1776c
hostname: Kais-MacBook-Pro.local
iterations_since_restore: 1
loss: 1.625945699906349
node_ip: 127.0.0.1
pid: 53065
should_checkpoint: true
time_since_restore: 45.849108934402466
time_this_iter_s: 45.849108934402466
time_total_s: 45.849108934402466
timestamp: 1658504387
timesteps_since_restore: 0
training_iteration: 1
trial_id: '66098_00000'
warmup_time: 0.003801107406616211
(train_cifar pid=53078) [1, 4000] loss: 0.983
(train_cifar pid=53065) [2, 2000] loss: 1.582
(train_cifar pid=53078) [1, 6000] loss: 0.647
(train_cifar pid=53065) [2, 4000] loss: 0.758
(train_cifar pid=53078) [1, 8000] loss: 0.489
(train_cifar pid=53065) [2, 6000] loss: 0.499
(train_cifar pid=53065) [2, 8000] loss: 0.365
(train_cifar pid=53078) [1, 10000] loss: 0.388
Result for train_cifar_66098_00001:
accuracy: 0.2915
date: 2022-07-22_16-40-09
done: true
experiment_id: 6410c16837024e5e903317c212a4af63
hostname: Kais-MacBook-Pro.local
iterations_since_restore: 1
loss: 1.9046219720602036
node_ip: 127.0.0.1
pid: 53078
should_checkpoint: true
time_since_restore: 64.97207283973694
time_this_iter_s: 64.97207283973694
time_total_s: 64.97207283973694
timestamp: 1658504409
timesteps_since_restore: 0
training_iteration: 1
trial_id: '66098_00001'
warmup_time: 0.0027120113372802734
(train_cifar pid=53065) [2, 10000] loss: 0.285
Result for train_cifar_66098_00000:
accuracy: 0.4877
date: 2022-07-22_16-40-13
done: true
experiment_id: 6512b700fdb64a458c3496f36ea1776c
hostname: Kais-MacBook-Pro.local
iterations_since_restore: 2
loss: 1.421571186053753
node_ip: 127.0.0.1
pid: 53065
should_checkpoint: true
time_since_restore: 72.61763620376587
time_this_iter_s: 26.768527269363403
time_total_s: 72.61763620376587
timestamp: 1658504413
timesteps_since_restore: 0
training_iteration: 2
trial_id: '66098_00000'
warmup_time: 0.003801107406616211
2022-07-22 16:40:14,050 INFO tune.py:738 -- Total run time: 77.27 seconds (76.42 seconds for the tuning loop).
Best trial config: {'l1': 128, 'l2': 128, 'lr': 0.00046907397024184945, 'batch_size': 4}
Best trial final validation loss: 1.421571186053753
Best trial final validation accuracy: 0.4877
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
/Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at ../c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Best trial test set accuracy: 0.4939
If you run the code, an example output could look like this:
Number of trials: 10 (10 TERMINATED)
+-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------+
| Trial name | status | loc | l1 | l2 | lr | batch_size | loss | accuracy | training_iteration |
|-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------|
| train_cifar_87d1f_00000 | TERMINATED | | 64 | 4 | 0.00011629 | 2 | 1.87273 | 0.244 | 2 |
| train_cifar_87d1f_00001 | TERMINATED | | 32 | 64 | 0.000339763 | 8 | 1.23603 | 0.567 | 8 |
| train_cifar_87d1f_00002 | TERMINATED | | 8 | 16 | 0.00276249 | 16 | 1.1815 | 0.5836 | 10 |
| train_cifar_87d1f_00003 | TERMINATED | | 4 | 64 | 0.000648721 | 4 | 1.31131 | 0.5224 | 8 |
| train_cifar_87d1f_00004 | TERMINATED | | 32 | 16 | 0.000340753 | 8 | 1.26454 | 0.5444 | 8 |
| train_cifar_87d1f_00005 | TERMINATED | | 8 | 4 | 0.000699775 | 8 | 1.99594 | 0.1983 | 2 |
| train_cifar_87d1f_00006 | TERMINATED | | 256 | 8 | 0.0839654 | 16 | 2.3119 | 0.0993 | 1 |
| train_cifar_87d1f_00007 | TERMINATED | | 16 | 128 | 0.0758154 | 16 | 2.33575 | 0.1327 | 1 |
| train_cifar_87d1f_00008 | TERMINATED | | 16 | 8 | 0.0763312 | 16 | 2.31129 | 0.1042 | 4 |
| train_cifar_87d1f_00009 | TERMINATED | | 128 | 16 | 0.000124903 | 4 | 2.26917 | 0.1945 | 1 |
+-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------+
Best trial config: {'l1': 8, 'l2': 16, 'lr': 0.0027624906698231976, 'batch_size': 16, 'data_dir': '...'}
Best trial final validation loss: 1.1815014744281769
Best trial final validation accuracy: 0.5836
Best trial test set accuracy: 0.5806
As you can see, the ASHA scheduler stopped most trials early to avoid wasting resources: only one of the ten trials reached 10 training iterations. The best-performing trial achieved a validation accuracy of about 58%, which was confirmed on the test set (58.06%).
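To make the early-stopping observation concrete, here is a small sketch that re-derives it from the example table. It is plain Python over values copied from the table, not a Ray API call. The ASHA settings `grace_period=1`, `reduction_factor=2` and `max_t=10` are assumptions, since they are not stated in this section. They were chosen because the rung milestones they produce (1, 2, 4 and 8) match the `training_iteration` values at which trials stopped. `asha_milestones` is a hypothetical helper: a simplified reconstruction of how ASHA spaces its rungs, not Ray's implementation. It does not model the promotion rule itself.

```python
# Rows copied from the example results table:
# (trial suffix, l1, l2, lr, batch_size, loss, accuracy, training_iteration)
RESULTS = [
    ("00000", 64, 4, 0.00011629, 2, 1.87273, 0.244, 2),
    ("00001", 32, 64, 0.000339763, 8, 1.23603, 0.567, 8),
    ("00002", 8, 16, 0.00276249, 16, 1.1815, 0.5836, 10),
    ("00003", 4, 64, 0.000648721, 4, 1.31131, 0.5224, 8),
    ("00004", 32, 16, 0.000340753, 8, 1.26454, 0.5444, 8),
    ("00005", 8, 4, 0.000699775, 8, 1.99594, 0.1983, 2),
    ("00006", 256, 8, 0.0839654, 16, 2.3119, 0.0993, 1),
    ("00007", 16, 128, 0.0758154, 16, 2.33575, 0.1327, 1),
    ("00008", 16, 8, 0.0763312, 16, 2.31129, 0.1042, 4),
    ("00009", 128, 16, 0.000124903, 4, 2.26917, 0.1945, 1),
]

# Assumption: matches the largest training_iteration in the table.
MAX_EPOCHS = 10


def asha_milestones(grace_period, reduction_factor, max_t):
    """Hypothetical helper: rung milestones grace_period * reduction_factor**k
    that lie strictly below max_t (a simplified view of ASHA's rungs)."""
    milestones = []
    t = grace_period
    while t < max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


# The best trial is the one with the lowest validation loss (column 5).
best = min(RESULTS, key=lambda row: row[5])

# A trial counts as stopped early if it ended before MAX_EPOCHS iterations.
stopped_early = [row[0] for row in RESULTS if row[7] < MAX_EPOCHS]

print("best trial:", best[0], "loss:", best[5], "accuracy:", best[6])
print("stopped early:", len(stopped_early), "of", len(RESULTS))
# Assumed ASHA settings: grace_period=1, reduction_factor=2, max_t=10.
print("rungs:", asha_milestones(1, 2, MAX_EPOCHS))
```

Running it reports trial `00002` as the best trial (loss 1.1815, accuracy 0.5836), counts 9 of the 10 trials as stopped early, and lists the rungs `[1, 2, 4, 8]`. Every early-stopped trial in the table ended at one of those milestones.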
So that’s it! You can now tune the parameters of your PyTorch models.
See More PyTorch Examples
MNIST PyTorch Example: Converts the PyTorch MNIST example to use Tune with the function-based API. Also shows how to easily convert something relying on argparse to use Tune.
PBT ConvNet Example: Trains a ConvNet with checkpointing using the function API.
MNIST PyTorch Trainable Example: Converts the PyTorch MNIST example to use Tune with Trainable API. Also uses the HyperBandScheduler and checkpoints the model at the end.