{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [
    "# Finetuning a PyTorch Image Classifier with Ray Train\n",
    "This example fine-tunes a pre-trained ResNet model with Ray Train.\n",
    "\n",
    "For this example, the network architecture consists of the intermediate layer output of a pre-trained ResNet model, which feeds into a randomly initialized linear layer that outputs classification logits for the new task.\n",
    "\n",
    "\n"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
    "## Load and preprocess finetuning dataset\n",
    "This example is adapted from PyTorch's [Finetuning Torchvision Models](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html) tutorial.\n",
    "We will use *hymenoptera_data* as the finetuning dataset, which contains two classes (bees and ants) and 397 total images (across training and validation). This is a small dataset, used here only for demonstration purposes. Use `torchvision.datasets.ImageFolder()` to load the images and their corresponding labels."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, models, transforms\n",
    "import numpy as np\n",
    "\n",
    "# Data augmentation and normalization for training\n",
    "# Just normalization for validation\n",
    "data_transforms = {\n",
    "    \"train\": transforms.Compose(\n",
    "        [\n",
    "            transforms.RandomResizedCrop(224),\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "        ]\n",
    "    ),\n",
    "    \"val\": transforms.Compose(\n",
    "        [\n",
    "            transforms.Resize(224),\n",
    "            transforms.CenterCrop(224),\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "        ]\n",
    "    ),\n",
    "}\n",
    "\n",
    "def download_datasets():\n",
    "    os.system(\n",
    "        \"wget https://download.pytorch.org/tutorial/hymenoptera_data.zip >/dev/null 2>&1\"\n",
    "    )\n",
    "    os.system(\"unzip hymenoptera_data.zip >/dev/null 2>&1\")\n",
    "\n",
    "# Build torch datasets from the downloaded images\n",
    "def build_datasets():\n",
    "    torch_datasets = {}\n",
    "    for split in [\"train\", \"val\"]:\n",
    "        torch_datasets[split] = datasets.ImageFolder(\n",
    "            os.path.join(\"./hymenoptera_data\", split), data_transforms[split]\n",
    "        )\n",
    "    return torch_datasets\n"
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [
    "if SMOKE_TEST:\n",
    "    from torch.utils.data import Subset\n",
    "\n",
    "    def build_datasets():\n",
    "        torch_datasets = {}\n",
    "        for split in [\"train\", \"val\"]:\n",
    "            torch_datasets[split] = datasets.ImageFolder(\n",
    "                os.path.join(\"./hymenoptera_data\", split), data_transforms[split]\n",
    "            )\n",
    "\n",
    "        # Only take a subset for smoke test\n",
    "        for split in [\"train\", \"val\"]:\n",
    "            indices = list(range(100))\n",
    "            torch_datasets[split] = Subset(torch_datasets[split], indices)\n",
    "        return torch_datasets\n"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
    "## Initialize Model and Fine-tuning configs"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
    "Next, let's define the training configuration that will be passed into the training loop function later."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
    "train_loop_config = {\n",
    "    \"input_size\": 224,  # Input image size (224 x 224)\n",
    "    \"batch_size\": 32,  # Batch size for training\n",
    "    \"num_epochs\": 10,  # Number of epochs to train for\n",
    "    \"lr\": 0.001,  # Learning Rate\n",
    "    \"momentum\": 0.9,  # SGD optimizer momentum\n",
    "}\n"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
    "Next, let's define our model. You can either create a model from pre-trained weights or reload the model checkpoint from a previous run."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
    "import os\n",
    "import torch\n",
    "from ray.train import Checkpoint\n",
    "\n",
    "# Option 1: Initialize model with pretrained weights\n",
    "def initialize_model():\n",
    "    # Load pretrained model params\n",
    "    model = models.resnet50(pretrained=True)\n",
    "\n",
    "    # Replace the original classifier with a new Linear layer\n",
    "    num_features = model.fc.in_features\n",
    "    model.fc = nn.Linear(num_features, 2)\n",
    "\n",
    "    # Ensure all params get updated during finetuning\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad = True\n",
    "    return model\n",
    "\n",
    "\n",
    "# Option 2: Initialize model with a Ray Train checkpoint\n",
    "# Replace this with your own URI\n",
    "CHECKPOINT_FROM_S3 = Checkpoint(\n",
    "    path=\"s3://air-example-data/finetune-resnet-checkpoint/TorchTrainer_4f69f_00000_0_2023-02-14_14-04-09/checkpoint_000001/\"\n",
    ")\n",
    "\n",
    "\n",
    "def initialize_model_from_checkpoint(checkpoint: Checkpoint):\n",
    "    with checkpoint.as_directory() as tmpdir:\n",
    "        state_dict = torch.load(os.path.join(tmpdir, \"checkpoint.pt\"))\n",
    "    resnet50 = initialize_model()\n",
    "    resnet50.load_state_dict(state_dict[\"model\"])\n",
    "    return resnet50\n"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
    "## Define the Training Loop\n",
    "\n",
    "The `train_loop_per_worker` function defines the fine-tuning procedure for each worker.\n",
    "\n",
    "**1. Prepare dataloaders for each worker**:\n",
    "- This tutorial assumes you are using PyTorch's native `torch.utils.data.Dataset` for data input. {meth}`train.torch.prepare_data_loader() ` prepares your DataLoader for distributed execution. You can also use Ray Data for more efficient preprocessing. For more details on using Ray Data for images, see the {doc}`Working with Images ` Ray Data user guide.\n",
    "\n",
    "**2. Prepare your model**:\n",
    "- {meth}`train.torch.prepare_model() ` prepares the model for distributed training. Under the hood, it wraps your torch model in `DistributedDataParallel`, which keeps its weights synchronized across all workers.\n",
    "\n",
    "**3. Report metrics and checkpoint**:\n",
    "- {meth}`train.report() ` reports metrics and checkpoints to Ray Train.\n",
    "- Saving checkpoints through {meth}`train.report(metrics, checkpoint=...) ` automatically [uploads checkpoints to cloud storage](tune-cloud-checkpointing) (if configured), and lets you easily enable Ray Train worker fault tolerance in the future."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
    "import os\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "import ray.train as train\n",
    "import ray.train.torch  # Ensures the train.torch.* helpers used below are loaded\n",
    "from ray.train import Checkpoint\n",
    "\n",
    "\n",
    "def evaluate(logits, labels):\n",
    "    _, preds = torch.max(logits, 1)\n",
    "    corrects = torch.sum(preds == labels).item()\n",
    "    return corrects\n",
    "\n",
    "\n",
    "def train_loop_per_worker(configs):\n",
    "    import warnings\n",
    "\n",
    "    warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "    # Calculate the batch size for a single worker\n",
    "    worker_batch_size = configs[\"batch_size\"] // train.get_context().get_world_size()\n",
    "\n",
    "    # Download dataset once on local rank 0 worker\n",
    "    if train.get_context().get_local_rank() == 0:\n",
    "        download_datasets()\n",
    "    torch.distributed.barrier()\n",
    "\n",
    "    # Build datasets on each worker\n",
    "    torch_datasets = build_datasets()\n",
    "\n",
    "    # Prepare dataloader for each worker\n",
    "    dataloaders = dict()\n",
    "    dataloaders[\"train\"] = DataLoader(\n",
    "        torch_datasets[\"train\"], batch_size=worker_batch_size, shuffle=True\n",
    "    )\n",
    "    dataloaders[\"val\"] = DataLoader(\n",
    "        torch_datasets[\"val\"], batch_size=worker_batch_size, shuffle=False\n",
    "    )\n",
    "\n",
    "    # Distribute\n",
    "    dataloaders[\"train\"] = train.torch.prepare_data_loader(dataloaders[\"train\"])\n",
    "    dataloaders[\"val\"] = train.torch.prepare_data_loader(dataloaders[\"val\"])\n",
    "\n",
    "    device = train.torch.get_device()\n",
    "\n",
    "    # Prepare DDP Model, optimizer, and loss function\n",
    "    model = initialize_model()\n",
    "    model = train.torch.prepare_model(model)\n",
    "\n",
    "    optimizer = optim.SGD(\n",
    "        model.parameters(), lr=configs[\"lr\"], momentum=configs[\"momentum\"]\n",
    "    )\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    # Start training loops\n",
    "    for epoch in range(configs[\"num_epochs\"]):\n",
    "        # Each epoch has a training and validation phase\n",
    "        for phase in [\"train\", \"val\"]:\n",
    "            if phase == \"train\":\n",
    "                model.train()  # Set model to training mode\n",
    "            else:\n",
    "                model.eval()  # Set model to evaluate mode\n",
    "\n",
    "            running_loss = 0.0\n",
    "            running_corrects = 0\n",
    "\n",
    "            if train.get_context().get_world_size() > 1:\n",
    "                dataloaders[phase].sampler.set_epoch(epoch)\n",
    "\n",
    "            for inputs, labels in dataloaders[phase]:\n",
    "                inputs = inputs.to(device)\n",
    "                labels = labels.to(device)\n",
    "\n",
    "                # zero the parameter gradients\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "                # forward\n",
    "                with torch.set_grad_enabled(phase == \"train\"):\n",
    "                    # Get model outputs and calculate loss\n",
    "                    outputs = model(inputs)\n",
    "                    loss = criterion(outputs, labels)\n",
    "\n",
    "                    # backward + optimize only if in training phase\n",
    "                    if phase == \"train\":\n",
    "                        loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                # calculate statistics\n",
    "                running_loss += loss.item() * inputs.size(0)\n",
    "                running_corrects += evaluate(outputs, labels)\n",
    "\n",
    "            size = len(torch_datasets[phase]) // train.get_context().get_world_size()\n",
    "            epoch_loss = running_loss / size\n",
    "            epoch_acc = running_corrects / size\n",
    "\n",
    "            if train.get_context().get_world_rank() == 0:\n",
    "                print(\n",
    "                    \"Epoch {}-{} Loss: {:.4f} Acc: {:.4f}\".format(\n",
    "                        epoch, phase, epoch_loss, epoch_acc\n",
    "                    )\n",
    "                )\n",
    "\n",
    "            # Report metrics and checkpoint every epoch\n",
    "            if phase == \"val\":\n",
    "                with TemporaryDirectory() as tmpdir:\n",
    "                    state_dict = {\n",
    "                        \"epoch\": epoch,\n",
    "                        \"model\": model.module.state_dict(),\n",
    "                        \"optimizer_state_dict\": optimizer.state_dict(),\n",
    "                    }\n",
    "                    torch.save(state_dict, os.path.join(tmpdir, \"checkpoint.pt\"))\n",
    "                    train.report(\n",
    "                        metrics={\"loss\": epoch_loss, \"acc\": epoch_acc},\n",
    "                        checkpoint=Checkpoint.from_directory(tmpdir),\n",
    "                    )\n"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
    "Next, set up the TorchTrainer:"
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
    "from ray.train.torch import TorchTrainer\n",
    "from ray.train import ScalingConfig, RunConfig, CheckpointConfig\n",
    "\n",
    "# Scale out model training across 4 GPUs.\n",
    "scaling_config = ScalingConfig(\n",
    "    num_workers=4, use_gpu=True, resources_per_worker={\"CPU\": 1, \"GPU\": 1}\n",
    ")\n",
    "\n",
    "# Keep only the latest checkpoint\n",
    "checkpoint_config = CheckpointConfig(num_to_keep=1)\n",
    "\n",
    "# Set experiment name and checkpoint configs\n",
    "run_config = RunConfig(\n",
    "    name=\"finetune-resnet\",\n",
    "    storage_path=\"/tmp/ray_results\",\n",
    "    checkpoint_config=checkpoint_config,\n",
    ")\n"
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "tags": [ "remove-cell" ] }, "outputs": [], "source": [
    "if SMOKE_TEST:\n",
    "    scaling_config = ScalingConfig(\n",
    "        num_workers=8, use_gpu=False, resources_per_worker={\"CPU\": 1}\n",
    "    )\n",
    "    train_loop_config[\"num_epochs\"] = 1\n"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
    "trainer = TorchTrainer(\n",
    "    train_loop_per_worker=train_loop_per_worker,\n",
    "    train_loop_config=train_loop_config,\n",
    "    scaling_config=scaling_config,\n",
    "    run_config=run_config,\n",
    ")\n",
    "\n",
    "result = trainer.fit()\n",
    "print(result)\n"
] }, { "cell_type": "markdown", "metadata": {}, "source": [
    "## Load the checkpoint for prediction\n",
    "\n",
    "The metadata and checkpoints have already been saved in the `storage_path` specified in the `RunConfig` passed to `TorchTrainer`."
] }, { "cell_type": "markdown", "metadata": {}, "source": [
    "We now need to load the trained model and evaluate it on test data. The best model parameters have been saved in `log_dir`. We can load the resulting checkpoint from our fine-tuning run using the previously defined `initialize_model_from_checkpoint()` function.