{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Finetuning a Pytorch Image Classifier with Ray Train\n",
"\n",
"\n",
"
\n",
"\n",
"
\n",
"\n",
"This example fine tunes a pre-trained ResNet model with Ray Train. \n",
"\n",
"For this example, the network architecture consists of the intermediate layer output of a pre-trained ResNet model, which feeds into a randomly initialized linear layer that outputs classification logits for our new task.\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load and preprocess finetuning dataset\n",
"This example is adapted from Pytorch's [Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) tutorial.\n",
"We will use *hymenoptera_data* as the finetuning dataset, which contains two classes (bees and ants) and 397 total images (across training and validation). This is a quite small dataset and used only for demonstration purposes. "
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"# To run full example, set SMOKE_TEST as False\n",
"SMOKE_TEST = True\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dataset is publicly available [here](https://www.kaggle.com/datasets/ajayrana/hymenoptera-data). Note that it is structured with directory names as the labels. Use `torchvision.datasets.ImageFolder()` to load the images and their corresponding labels."
]
},
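{
"cell_type": "markdown",
"metadata": {},
"source": [
"After downloading and extracting the archive (done in the next cell), the dataset should look roughly like this, with the subdirectory names `ants` and `bees` serving as the class labels (layout shown for illustration; verify against your local copy):\n",
"\n",
"```\n",
"hymenoptera_data/\n",
"├── train/\n",
"│   ├── ants/\n",
"│   └── bees/\n",
"└── val/\n",
"    ├── ants/\n",
"    └── bees/\n",
"```"
]
},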
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets, models, transforms\n",
"import numpy as np\n",
"\n",
"# Data augmentation and normalization for training\n",
"# Just normalization for validation\n",
"data_transforms = {\n",
" \"train\": transforms.Compose(\n",
" [\n",
" transforms.RandomResizedCrop(224),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
" ]\n",
" ),\n",
" \"val\": transforms.Compose(\n",
" [\n",
" transforms.Resize(224),\n",
" transforms.CenterCrop(224),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
" ]\n",
" ),\n",
"}\n",
"\n",
"def download_datasets():\n",
" os.system(\n",
" \"wget https://download.pytorch.org/tutorial/hymenoptera_data.zip >/dev/null 2>&1\"\n",
" )\n",
" os.system(\"unzip hymenoptera_data.zip >/dev/null 2>&1\")\n",
"\n",
"# Download and build torch datasets\n",
"def build_datasets():\n",
" torch_datasets = {}\n",
" for split in [\"train\", \"val\"]:\n",
" torch_datasets[split] = datasets.ImageFolder(\n",
" os.path.join(\"./hymenoptera_data\", split), data_transforms[split]\n",
" )\n",
" return torch_datasets\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"if SMOKE_TEST:\n",
" from torch.utils.data import Subset\n",
"\n",
" def build_datasets():\n",
" torch_datasets = {}\n",
" for split in [\"train\", \"val\"]:\n",
" torch_datasets[split] = datasets.ImageFolder(\n",
" os.path.join(\"./hymenoptera_data\", split), data_transforms[split]\n",
" )\n",
" \n",
" # Only take a subset for smoke test\n",
" for split in [\"train\", \"val\"]:\n",
" indices = list(range(100))\n",
" torch_datasets[split] = Subset(torch_datasets[split], indices)\n",
" return torch_datasets\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialize Model and Fine-tuning configs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's define the training configuration that will be passed into the training loop function later."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train_loop_config = {\n",
" \"input_size\": 224, # Input image size (224 x 224)\n",
" \"batch_size\": 32, # Batch size for training\n",
" \"num_epochs\": 10, # Number of epochs to train for\n",
" \"lr\": 0.001, # Learning Rate\n",
" \"momentum\": 0.9, # SGD optimizer momentum\n",
"}\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's define our model. You can either create a model from pre-trained weights or reload the model checkpoint from a previous run."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"from ray.train import Checkpoint\n",
"\n",
"# Option 1: Initialize model with pretrained weights\n",
"def initialize_model():\n",
" # Load pretrained model params\n",
" model = models.resnet50(pretrained=True)\n",
"\n",
" # Replace the original classifier with a new Linear layer\n",
" num_features = model.fc.in_features\n",
" model.fc = nn.Linear(num_features, 2)\n",
"\n",
" # Ensure all params get updated during finetuning\n",
" for param in model.parameters():\n",
" param.requires_grad = True\n",
" return model\n",
"\n",
"\n",
"# Option 2: Initialize model with an Train checkpoint\n",
"# Replace this with your own uri\n",
"CHECKPOINT_FROM_S3 = Checkpoint(\n",
" path=\"s3://air-example-data/finetune-resnet-checkpoint/TorchTrainer_4f69f_00000_0_2023-02-14_14-04-09/checkpoint_000001/\"\n",
")\n",
"\n",
"\n",
"def initialize_model_from_checkpoint(checkpoint: Checkpoint):\n",
" with checkpoint.as_directory() as tmpdir:\n",
" state_dict = torch.load(os.path.join(tmpdir, \"checkpoint.pt\"))\n",
" resnet50 = initialize_model()\n",
" resnet50.load_state_dict(state_dict[\"model\"])\n",
" return resnet50\n"
]
},
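{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a usage sketch, restoring a model from the example S3 checkpoint above is a single call. It is commented out here because it requires network access to the public bucket:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative usage (requires access to the public S3 bucket above):\n",
"# model = initialize_model_from_checkpoint(CHECKPOINT_FROM_S3)\n"
]
},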
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define the Training Loop\n",
"\n",
"The `train_loop_per_worker` function defines the fine-tuning procedure for each worker.\n",
"\n",
"**1. Prepare dataloaders for each worker**:\n",
"- This tutorial assumes you are using PyTorch's native `torch.utils.data.Dataset` for data input. {meth}`train.torch.prepare_data_loader() ` prepares your dataLoader for distributed execution. You can also use Ray Data for more efficient preprocessing. For more details on using Ray Data for for images, see the {doc}`Working with Images ` Ray Data user guide.\n",
"\n",
"**2. Prepare your model**:\n",
"- {meth}`train.torch.prepare_model() ` prepares the model for distributed training. Under the hood, it converts your torch model to `DistributedDataParallel` model, which synchronize its weights across all workers.\n",
"\n",
"**3. Report metrics and checkpoint**:\n",
"- {meth}`train.report() ` will report metrics and checkpoints to Ray Train.\n",
"- Saving checkpoints through {meth}`train.report(metrics, checkpoint=...) ` will automatically [upload checkpoints to cloud storage](tune-cloud-checkpointing) (if configured), and allow you to easily enable Ray Train worker fault tolerance in the future."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from tempfile import TemporaryDirectory\n",
"\n",
"import ray.train as train\n",
"from ray.train import Checkpoint\n",
"\n",
"\n",
"\n",
"def evaluate(logits, labels):\n",
" _, preds = torch.max(logits, 1)\n",
" corrects = torch.sum(preds == labels).item()\n",
" return corrects\n",
"\n",
"\n",
"def train_loop_per_worker(configs):\n",
" import warnings\n",
"\n",
" warnings.filterwarnings(\"ignore\")\n",
"\n",
" # Calculate the batch size for a single worker\n",
" worker_batch_size = configs[\"batch_size\"] // train.get_context().get_world_size()\n",
"\n",
" # Download dataset once on local rank 0 worker\n",
" if train.get_context().get_local_rank() == 0:\n",
" download_datasets()\n",
" torch.distributed.barrier()\n",
"\n",
" # Build datasets on each worker\n",
" torch_datasets = build_datasets()\n",
"\n",
" # Prepare dataloader for each worker\n",
" dataloaders = dict()\n",
" dataloaders[\"train\"] = DataLoader(\n",
" torch_datasets[\"train\"], batch_size=worker_batch_size, shuffle=True\n",
" )\n",
" dataloaders[\"val\"] = DataLoader(\n",
" torch_datasets[\"val\"], batch_size=worker_batch_size, shuffle=False\n",
" )\n",
"\n",
" # Distribute\n",
" dataloaders[\"train\"] = train.torch.prepare_data_loader(dataloaders[\"train\"])\n",
" dataloaders[\"val\"] = train.torch.prepare_data_loader(dataloaders[\"val\"])\n",
"\n",
" device = train.torch.get_device()\n",
"\n",
" # Prepare DDP Model, optimizer, and loss function\n",
" model = initialize_model()\n",
" model = train.torch.prepare_model(model)\n",
"\n",
" optimizer = optim.SGD(\n",
" model.parameters(), lr=configs[\"lr\"], momentum=configs[\"momentum\"]\n",
" )\n",
" criterion = nn.CrossEntropyLoss()\n",
"\n",
" # Start training loops\n",
" for epoch in range(configs[\"num_epochs\"]):\n",
" # Each epoch has a training and validation phase\n",
" for phase in [\"train\", \"val\"]:\n",
" if phase == \"train\":\n",
" model.train() # Set model to training mode\n",
" else:\n",
" model.eval() # Set model to evaluate mode\n",
"\n",
" running_loss = 0.0\n",
" running_corrects = 0\n",
"\n",
" if train.get_context().get_world_size() > 1:\n",
" dataloaders[phase].sampler.set_epoch(epoch)\n",
"\n",
" for inputs, labels in dataloaders[phase]:\n",
" inputs = inputs.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" # zero the parameter gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # forward\n",
" with torch.set_grad_enabled(phase == \"train\"):\n",
" # Get model outputs and calculate loss\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs, labels)\n",
"\n",
" # backward + optimize only if in training phase\n",
" if phase == \"train\":\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # calculate statistics\n",
" running_loss += loss.item() * inputs.size(0)\n",
" running_corrects += evaluate(outputs, labels)\n",
"\n",
" size = len(torch_datasets[phase]) // train.get_context().get_world_size()\n",
" epoch_loss = running_loss / size\n",
" epoch_acc = running_corrects / size\n",
"\n",
" if train.get_context().get_world_rank() == 0:\n",
" print(\n",
" \"Epoch {}-{} Loss: {:.4f} Acc: {:.4f}\".format(\n",
" epoch, phase, epoch_loss, epoch_acc\n",
" )\n",
" )\n",
"\n",
" # Report metrics and checkpoint every epoch\n",
" if phase == \"val\":\n",
" with TemporaryDirectory() as tmpdir:\n",
" state_dict = {\n",
" \"epoch\": epoch,\n",
" \"model\": model.module.state_dict(),\n",
" \"optimizer_state_dict\": optimizer.state_dict(),\n",
" }\n",
" torch.save(state_dict, os.path.join(tmpdir, \"checkpoint.pt\"))\n",
" train.report(\n",
" metrics={\"loss\": epoch_loss, \"acc\": epoch_acc},\n",
" checkpoint=Checkpoint.from_directory(tmpdir),\n",
" )\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, setup the TorchTrainer:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from ray.train.torch import TorchTrainer\n",
"from ray.train import ScalingConfig, RunConfig, CheckpointConfig\n",
"\n",
"# Scale out model training across 4 GPUs.\n",
"scaling_config = ScalingConfig(\n",
" num_workers=4, use_gpu=True, resources_per_worker={\"CPU\": 1, \"GPU\": 1}\n",
")\n",
"\n",
"# Save the latest checkpoint\n",
"checkpoint_config = CheckpointConfig(num_to_keep=1)\n",
"\n",
"# Set experiment name and checkpoint configs\n",
"run_config = RunConfig(\n",
" name=\"finetune-resnet\",\n",
" storage_path=\"/tmp/ray_results\",\n",
" checkpoint_config=checkpoint_config,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"if SMOKE_TEST:\n",
" scaling_config = ScalingConfig(\n",
" num_workers=8, use_gpu=False, resources_per_worker={\"CPU\": 1}\n",
" )\n",
" train_loop_config[\"num_epochs\"] = 1\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = TorchTrainer(\n",
" train_loop_per_worker=train_loop_per_worker,\n",
" train_loop_config=train_loop_config,\n",
" scaling_config=scaling_config,\n",
" run_config=run_config,\n",
")\n",
"\n",
"result = trainer.fit()\n",
"print(result)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the checkpoint for prediction:\n",
"\n",
" \n",
" The metadata and checkpoints have already been saved in the `storage_path` specified in TorchTrainer:"
]
},
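{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, you can first inspect the `Result` object returned by `trainer.fit()`. This is a minimal sketch; the exact set of attributes may vary slightly across Ray versions:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the training result (illustrative):\n",
"print(result.metrics)     # last reported metrics, e.g. {\"loss\": ..., \"acc\": ...}\n",
"print(result.checkpoint)  # latest reported Checkpoint\n"
]
},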
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now need to load the trained model and evaluate it on test data. The best model parameters have been saved in `log_dir`. We can load the resulting checkpoint from our fine-tuning run using the previously defined `initialize_model_from_checkpoint()` function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = initialize_model_from_checkpoint(result.checkpoint)\n",
"device = torch.device(\"cuda\")\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"if SMOKE_TEST:\n",
" device = torch.device(\"cpu\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, define a simple evaluation loop and check the performance of the checkpoint model."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.934640522875817\n"
]
}
],
"source": [
"model = model.to(device)\n",
"model.eval()\n",
"\n",
"download_datasets()\n",
"torch_datasets = build_datasets()\n",
"dataloader = DataLoader(torch_datasets[\"val\"], batch_size=32, num_workers=4)\n",
"corrects = 0\n",
"for inputs, labels in dataloader:\n",
" inputs = inputs.to(device)\n",
" labels = labels.to(device)\n",
" preds = model(inputs)\n",
" corrects += evaluate(preds, labels)\n",
"\n",
"print(\"Accuracy: \", corrects / len(dataloader.dataset))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orphan": true,
"vscode": {
"interpreter": {
"hash": "a8c1140d108077f4faeb76b2438f85e4ed675f93d004359552883616a1acd54c"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}