Finetuning a PyTorch Image Classifier with Ray AIR#
This example fine-tunes a pretrained ResNet model with Ray Train.
For this example, the network architecture consists of the intermediate layer output of a pretrained ResNet model, which feeds into a randomly initialized linear layer that outputs classification logits for our new task.
Load and preprocess finetuning dataset#
This example is adapted from PyTorch’s Finetuning Torchvision Models tutorial. We will use hymenoptera_data as the fine-tuning dataset, which contains two classes (bees and ants) and 397 total images across training and validation. This is quite a small dataset, used only for demonstration purposes.
The dataset is publicly available here. Note that it is structured with directory names as the labels. Use torchvision.datasets.ImageFolder()
to load the images and their corresponding labels.
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import numpy as np
# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    "train": transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}
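As a quick sanity check of the preprocessing pipeline (a sketch that is not part of the original tutorial), applying the "val" transform to any RGB image yields a normalized 3 x 224 x 224 tensor:
from PIL import Image

# Sketch: a dummy RGB image stands in for a real photo.
dummy_image = Image.new("RGB", (500, 375))
transformed = data_transforms["val"](dummy_image)
print(transformed.shape)  # torch.Size([3, 224, 224])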
# Download and build torch datasets
def build_datasets():
    os.system(
        "wget https://download.pytorch.org/tutorial/hymenoptera_data.zip >/dev/null 2>&1"
    )
    os.system("unzip hymenoptera_data.zip >/dev/null 2>&1")
    torch_datasets = {}
    for split in ["train", "val"]:
        torch_datasets[split] = datasets.ImageFolder(
            os.path.join("./hymenoptera_data", split), data_transforms[split]
        )
    return torch_datasets
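If you want to verify the label mapping that ImageFolder infers from the directory names, a quick check like the following works (note that build_datasets() downloads the dataset; the printed values in the comments are illustrative):
torch_datasets = build_datasets()
print(torch_datasets["train"].classes)       # e.g. ['ants', 'bees']
print(torch_datasets["train"].class_to_idx)  # e.g. {'ants': 0, 'bees': 1}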
Initialize Model and Fine-tuning configs#
Next, let’s define the training configuration that will be passed into the training loop function later.
train_loop_config = {
    "input_size": 224,  # Input image size (224 x 224)
    "batch_size": 32,  # Batch size for training
    "num_epochs": 10,  # Number of epochs to train for
    "lr": 0.001,  # Learning rate
    "momentum": 0.9,  # SGD optimizer momentum
}
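Note that batch_size acts as the global batch size here: the training loop below divides it by the number of workers, so with 4 workers each worker processes batches of 8. A quick sketch of that split (the worker count is an assumption matching the ScalingConfig used later):
assumed_num_workers = 4  # matches the ScalingConfig defined later
per_worker_batch_size = train_loop_config["batch_size"] // assumed_num_workers
print(per_worker_batch_size)  # 8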
Next, let’s define our model. You can either create a model from pre-trained weights or reload the model checkpoint from a previous run.
from ray.train.torch import TorchCheckpoint
# Option 1: Initialize model with pretrained weights
def initialize_model():
    # Load pretrained model params
    model = models.resnet50(pretrained=True)
    # Replace the original classifier with a new Linear layer
    num_features = model.fc.in_features
    model.fc = nn.Linear(num_features, 2)
    # Ensure all params get updated during finetuning
    for param in model.parameters():
        param.requires_grad = True
    return model
# Option 2: Initialize model with an AIR checkpoint
# Replace this with your own uri
CHECKPOINT_URI = "s3://air-example-data/finetune-resnet-checkpoint/TorchTrainer_4f69f_00000_0_2023-02-14_14-04-09/checkpoint_000000/"
def initialize_model_from_uri(checkpoint_uri):
    checkpoint = TorchCheckpoint.from_uri(checkpoint_uri)
    resnet50 = initialize_model()
    return checkpoint.get_model(model=resnet50)
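Both entry points return a regular nn.Module, so downstream code can use either one. A minimal sketch of choosing between them (the training loop below uses Option 1):
# Sketch: pick one of the two initialization paths.
model = initialize_model()                           # Option 1: ImageNet weights + new head
# model = initialize_model_from_uri(CHECKPOINT_URI)  # Option 2: resume from an AIR checkpoint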
Define the Training Loop#
The train_loop_per_worker function defines the fine-tuning procedure for each worker.
1. Prepare dataloaders for each worker: This tutorial assumes you are using PyTorch’s native torch.utils.data.Dataset for data input. train.torch.prepare_data_loader() prepares your DataLoader for distributed execution. You can also use Ray Data for more efficient preprocessing (see this example).
2. Prepare your model: train.torch.prepare_model() prepares the model for distributed training. Under the hood, it converts your torch model to a DistributedDataParallel model, which synchronizes its weights across all workers.
3. Report metrics and checkpoint: session.report() reports metrics and checkpoints to Ray AIR. Saving checkpoints through session.report(metrics, checkpoint=...) automatically uploads checkpoints to cloud storage (if configured) and allows you to easily enable Ray AIR worker fault tolerance in the future.
import ray.train as train
from ray.air import session
from ray.train.torch import TorchCheckpoint
def evaluate(logits, labels):
    _, preds = torch.max(logits, 1)
    corrects = torch.sum(preds == labels).item()
    return corrects


def train_loop_per_worker(configs):
    import warnings

    warnings.filterwarnings("ignore")

    # Calculate the batch size for a single worker
    worker_batch_size = configs["batch_size"] // session.get_world_size()

    # Build datasets on each worker
    torch_datasets = build_datasets()

    # Prepare dataloader for each worker
    dataloaders = dict()
    dataloaders["train"] = DataLoader(
        torch_datasets["train"], batch_size=worker_batch_size, shuffle=True
    )
    dataloaders["val"] = DataLoader(
        torch_datasets["val"], batch_size=worker_batch_size, shuffle=False
    )

    # Distribute
    dataloaders["train"] = train.torch.prepare_data_loader(dataloaders["train"])
    dataloaders["val"] = train.torch.prepare_data_loader(dataloaders["val"])

    device = train.torch.get_device()

    # Prepare DDP Model, optimizer, and loss function
    model = initialize_model()
    model = train.torch.prepare_model(model)

    optimizer = optim.SGD(
        model.parameters(), lr=configs["lr"], momentum=configs["momentum"]
    )
    criterion = nn.CrossEntropyLoss()

    # Start training loops
    for epoch in range(configs["num_epochs"]):
        # Each epoch has a training and validation phase
        for phase in ["train", "val"]:
            if phase == "train":
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                with torch.set_grad_enabled(phase == "train"):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # calculate statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += evaluate(outputs, labels)

            size = len(torch_datasets[phase]) // session.get_world_size()
            epoch_loss = running_loss / size
            epoch_acc = running_corrects / size

            if session.get_world_rank() == 0:
                print(
                    "Epoch {}-{} Loss: {:.4f} Acc: {:.4f}".format(
                        epoch, phase, epoch_loss, epoch_acc
                    )
                )

            # Report metrics and checkpoint every epoch
            if phase == "val":
                checkpoint = TorchCheckpoint.from_dict(
                    {
                        "epoch": epoch,
                        "model": model.module.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                    }
                )
                session.report(
                    metrics={"loss": epoch_loss, "acc": epoch_acc},
                    checkpoint=checkpoint,
                )
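Note that the checkpoint above stores model.module.state_dict() rather than model.state_dict(): prepare_model() wraps the model in DistributedDataParallel, and .module accesses the underlying model so the saved weights can later be loaded into a plain, unwrapped model. A minimal sketch of this unwrapping pattern, assuming the model may or may not be DDP-wrapped:
def get_unwrapped_state_dict(model):
    # Sketch: unwrap a DistributedDataParallel model before saving its state dict.
    if isinstance(model, nn.parallel.DistributedDataParallel):
        return model.module.state_dict()
    return model.state_dict()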
Next, set up the TorchTrainer:
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.config import ScalingConfig, RunConfig, CheckpointConfig
# Scale out model training across 4 GPUs.
scaling_config = ScalingConfig(
    num_workers=4, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 1}
)
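# If your cluster has fewer GPUs, a smaller ScalingConfig also works. For example,
# on a single-GPU machine you could use (a sketch, not part of the original run):
#   scaling_config = ScalingConfig(num_workers=1, use_gpu=True)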
# Save the latest checkpoint
checkpoint_config = CheckpointConfig(num_to_keep=1)
# Set experiment name and checkpoint configs
run_config = RunConfig(
    name="finetune-resnet",
    storage_path="/tmp/ray_results",
    checkpoint_config=checkpoint_config,
)
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config=run_config,
)
result = trainer.fit()
print(result)
2023-03-01 12:40:15,468 INFO worker.py:1360 -- Connecting to existing Ray cluster at address: 10.0.53.212:6379...
2023-03-01 12:40:15,520 INFO worker.py:1548 -- Connected to Ray cluster. View the dashboard at https://console.anyscale-staging.com/api/v2/sessions/ses_49hwcjc1pzcddc2nf6cg9itj6b/services?redirect_to=dashboard
2023-03-01 12:40:16,841 INFO packaging.py:330 -- Pushing file package 'gcs://_ray_pkg_d6a92d7fa9e73b7fc2276251a1203373.zip' (451.72MiB) to Ray cluster...
2023-03-01 12:40:26,413 INFO packaging.py:343 -- Successfully pushed file package 'gcs://_ray_pkg_d6a92d7fa9e73b7fc2276251a1203373.zip'.
Tune Status
Current time: 2023-03-01 12:41:31
Running for: 00:01:05.01
Memory: 8.3/62.0 GiB
System Info
Using FIFO scheduling algorithm.
Resources requested: 0/64 CPUs, 0/4 GPUs, 0.0/163.97 GiB heap, 0.0/72.85 GiB objects (0.0/4.0 accelerator_type:T4)
Trial Status
Trial name | status | loc | iter | total time (s) | loss | acc |
---|---|---|---|---|---|---|
TorchTrainer_4c393_00000 | TERMINATED | 10.0.62.120:1395 | 10 | 51.9574 | 0.143938 | 0.973684 |
(RayTrainWorker pid=1478, ip=10.0.62.120) 2023-03-01 12:40:37,398 INFO config.py:86 -- Setting up process group for: env:// [rank=0, world_size=4]
(RayTrainWorker pid=89742) 2023-03-01 12:40:39,344 INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=1478, ip=10.0.62.120) 2023-03-01 12:40:39,474 INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=1902, ip=10.0.24.64) 2023-03-01 12:40:39,495 INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=3689, ip=10.0.52.145) 2023-03-01 12:40:39,588 INFO train_loop_utils.py:307 -- Moving model to device: cuda:0
(RayTrainWorker pid=89742) 2023-03-01 12:40:40,888 INFO train_loop_utils.py:367 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=1478, ip=10.0.62.120) 2023-03-01 12:40:41,001 INFO train_loop_utils.py:367 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=1902, ip=10.0.24.64) 2023-03-01 12:40:41,019 INFO train_loop_utils.py:367 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=3689, ip=10.0.52.145) 2023-03-01 12:40:41,123 INFO train_loop_utils.py:367 -- Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 0-train Loss: 0.7398 Acc: 0.4426
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 0-val Loss: 0.5739 Acc: 0.7105
Trial Progress
Trial name | acc | date | done | experiment_tag | hostname | iterations_since_restore | loss | node_ip | pid | should_checkpoint | time_since_restore | time_this_iter_s | time_total_s | timestamp | training_iteration | trial_id |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
TorchTrainer_4c393_00000 | 0.973684 | 2023-03-01_12-41-26 | True | 0 | ip-10-0-62-120 | 10 | 0.143938 | 10.0.62.120 | 1395 | True | 51.9574 | 4.0961 | 51.9574 | 1677703285 | 10 | 4c393_00000 |
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 1-train Loss: 0.5130 Acc: 0.8197
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 1-val Loss: 0.3553 Acc: 0.9737
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 2-train Loss: 0.4676 Acc: 0.7705
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 2-val Loss: 0.2600 Acc: 0.9737
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 3-train Loss: 0.3940 Acc: 0.8525
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 3-val Loss: 0.2136 Acc: 0.9737
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 4-train Loss: 0.3602 Acc: 0.8852
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 4-val Loss: 0.1854 Acc: 1.0000
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 5-train Loss: 0.2871 Acc: 0.8689
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 5-val Loss: 0.1691 Acc: 1.0000
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 6-train Loss: 0.2858 Acc: 0.9344
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 6-val Loss: 0.1459 Acc: 1.0000
2023-03-01 12:41:13,026 WARNING util.py:244 -- The `process_trial_save` operation took 2.925 s, which may be a performance bottleneck.
2023-03-01 12:41:13,027 WARNING trial_runner.py:678 -- Consider turning off forced head-worker trial checkpoint syncs by setting sync_on_checkpoint=False. Note that this may result in faulty trial restoration if a failure occurs while the checkpoint is being synced from the worker to the head node.
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 7-train Loss: 0.1965 Acc: 0.9344
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 7-val Loss: 0.1387 Acc: 1.0000
2023-03-01 12:41:17,101 WARNING util.py:244 -- The `process_trial_save` operation took 2.925 s, which may be a performance bottleneck.
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 8-train Loss: 0.2277 Acc: 0.9344
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 8-val Loss: 0.1500 Acc: 0.9737
2023-03-01 12:41:21,195 WARNING util.py:244 -- The `process_trial_save` operation took 2.936 s, which may be a performance bottleneck.
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 9-train Loss: 0.1884 Acc: 0.9344
(RayTrainWorker pid=1478, ip=10.0.62.120) Epoch 9-val Loss: 0.1439 Acc: 0.9737
2023-03-01 12:41:25,360 WARNING util.py:244 -- The `process_trial_save` operation took 2.947 s, which may be a performance bottleneck.
2023-03-01 12:41:29,205 WARNING util.py:244 -- The `process_trial_save` operation took 2.696 s, which may be a performance bottleneck.
2023-03-01 12:41:33,757 INFO tune.py:825 -- Total run time: 66.96 seconds (65.01 seconds for the tuning loop).
Result(
metrics={'loss': 0.14393797124686994, 'acc': 0.9736842105263158, 'should_checkpoint': True, 'done': True, 'trial_id': '4c393_00000', 'experiment_tag': '0'},
log_dir=PosixPath('/tmp/ray_results/finetune-resnet/TorchTrainer_4c393_00000_0_2023-03-01_12-40-31'),
checkpoint=TorchCheckpoint(local_path=/tmp/ray_results/finetune-resnet/TorchTrainer_4c393_00000_0_2023-03-01_12-40-31/checkpoint_000009)
)
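The returned Result object also exposes the reported metrics and the saved checkpoint programmatically, for example:
# Quick sketch: inspect the final reported metrics and the saved checkpoint.
print(result.metrics["acc"])
print(result.checkpoint)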
Load the checkpoint for prediction:#
The metadata and checkpoints have already been saved in the storage_path specified in the TorchTrainer.
We now need to load the trained model and evaluate it on the validation data. The best model parameters have been saved in log_dir. We can load the resulting checkpoint from our fine-tuning run using the previously defined initialize_model_from_uri() function.
model = initialize_model_from_uri(result.checkpoint.uri)
device = torch.device("cuda")
/home/ray/anaconda3/lib/python3.9/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
warnings.warn(
/home/ray/anaconda3/lib/python3.9/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.
warnings.warn(msg)
Finally, define a simple evaluation loop and check the performance of the checkpoint model.
model = model.to(device)
model.eval()

torch_datasets = build_datasets()
dataloader = DataLoader(torch_datasets["val"], batch_size=32, num_workers=4)

corrects = 0
# Disable gradient tracking during inference
with torch.no_grad():
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        preds = model(inputs)
        corrects += evaluate(preds, labels)

print("Accuracy: ", corrects / len(dataloader.dataset))
Accuracy: 0.934640522875817
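As a final illustration (a sketch that is not part of the original tutorial), you can classify a single validation image with the fine-tuned model; the image is taken from the dataset loaded above:
from PIL import Image

# Sketch: classify one validation image with the fine-tuned model.
img_path, _ = torch_datasets["val"].samples[0]  # pick any image from the val set
img = Image.open(img_path).convert("RGB")
batch = data_transforms["val"](img).unsqueeze(0).to(device)
with torch.no_grad():
    logits = model(batch)
pred_idx = logits.argmax(dim=1).item()
print("Predicted class:", torch_datasets["val"].classes[pred_idx])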