{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "586737af",
"metadata": {},
"source": [
"# How to use Tune with PyTorch\n",
"\n",
"\n",
"
\n",
"\n",
"
\n",
"\n",
"(tune-pytorch-cifar-ref)=\n",
"\n",
"In this walkthrough, we will show you how to integrate Tune into your PyTorch\n",
"training workflow. We will follow [this tutorial from the PyTorch documentation](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)\n",
"for training a CIFAR10 image classifier.\n",
"\n",
"```{image} /images/pytorch_logo.png\n",
":align: center\n",
"```\n",
"\n",
"Hyperparameter tuning can make the difference between an average model and a highly\n",
"accurate one. Often simple things like choosing a different learning rate or changing\n",
"a network layer size can have a dramatic impact on your model performance. Fortunately,\n",
"Tune makes exploring these optimal parameter combinations easy - and works nicely\n",
"together with PyTorch.\n",
"\n",
"As you will see, we only need to add some slight modifications. In particular, we\n",
"need to\n",
"\n",
"1. wrap data loading and training in functions,\n",
"2. make some network parameters configurable,\n",
"3. add checkpointing (optional),\n",
"4. and define the search space for the model tuning\n",
"\n",
"```{contents}\n",
":backlinks: none\n",
":local: true\n",
"```"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7e8650d1",
"metadata": {},
"source": [
"## Setup / Imports\n",
"\n",
"First, the requirements (uncomment to install):"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "89728f69",
"metadata": {},
"outputs": [],
"source": [
"#!pip install -Uq \"ray[tune]\" torch torchvision pandas ipywidgets"
]
},
{
"cell_type": "markdown",
"id": "0cda8ce7",
"metadata": {},
"source": [
"Next, the imports:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55529285",
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from filelock import FileLock\n",
"from torch.utils.data import random_split\n",
"\n",
"from ray import train, tune\n",
"from ray.tune.schedulers import ASHAScheduler"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f59e551d",
"metadata": {},
"source": [
"Most of the imports are needed for building the PyTorch model. Only the last three\n",
"imports are for Ray Tune.\n",
"\n",
"## Data loaders\n",
"\n",
"We wrap the data loaders in their own function and pass a global data directory.\n",
"This way we can share a data directory between different trials."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "01471556",
"metadata": {},
"outputs": [],
"source": [
"def load_data(data_dir=\"./data\"):\n",
" \"\"\"Create dataloaders for normalized CIFAR10 training/test subsets.\"\"\"\n",
" transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
" ])\n",
"\n",
" # We add FileLock here because multiple workers will want to\n",
" # download data, and this may cause overwrites since\n",
" # DataLoader is not threadsafe.\n",
" with FileLock(os.path.expanduser(\"~/.data.lock\")):\n",
" trainset = torchvision.datasets.CIFAR10(\n",
" root=data_dir, train=True, download=True, transform=transform)\n",
"\n",
" testset = torchvision.datasets.CIFAR10(\n",
" root=data_dir, train=False, download=True, transform=transform)\n",
"\n",
" return trainset, testset\n",
"\n",
"def create_dataloaders(trainset, batch_size, num_workers=8):\n",
" \"\"\"Create train/val splits and dataloaders.\"\"\"\n",
" train_size = int(len(trainset) * 0.8)\n",
" train_subset, val_subset = random_split(\n",
" trainset, [train_size, len(trainset) - train_size])\n",
"\n",
" train_loader = torch.utils.data.DataLoader(\n",
" train_subset,\n",
" batch_size=batch_size, \n",
" shuffle=True,\n",
" num_workers=num_workers\n",
" )\n",
" val_loader = torch.utils.data.DataLoader(\n",
" val_subset,\n",
" batch_size=batch_size,\n",
" shuffle=False, \n",
" num_workers=num_workers\n",
" )\n",
" return train_loader, val_loader"
]
},
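{
"cell_type": "markdown",
"id": "random-split-sketch",
"metadata": {},
"source": [
"`create_dataloaders` holds out 20% of the training images for validation. As a rough\n",
"illustration (not part of the original tutorial), the sketch below applies the same\n",
"80/20 `random_split` to a toy `TensorDataset` with 50 samples. For the 50,000 CIFAR10\n",
"training images, the same arithmetic gives 40,000 training and 10,000 validation samples.\n",
"The seeded `generator` is an optional addition that makes the split reproducible.\n",
"\n",
"```python\n",
"import torch\n",
"from torch.utils.data import TensorDataset, random_split\n",
"\n",
"# Toy stand-in for the CIFAR10 training set: 50 samples.\n",
"dataset = TensorDataset(torch.arange(50))\n",
"\n",
"# Same 80/20 split as in create_dataloaders.\n",
"train_size = int(len(dataset) * 0.8)\n",
"train_subset, val_subset = random_split(\n",
"    dataset,\n",
"    [train_size, len(dataset) - train_size],\n",
"    generator=torch.Generator().manual_seed(0),  # optional: reproducible split\n",
")\n",
"\n",
"print(len(train_subset), len(val_subset))  # 40 10\n",
"```"
]
},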
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def load_test_data():\n",
" # Load fake data for running a quick smoke-test.\n",
" trainset = torchvision.datasets.FakeData(\n",
" 128, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()\n",
" )\n",
" testset = torchvision.datasets.FakeData(\n",
" 16, (3, 32, 32), num_classes=10, transform=transforms.ToTensor()\n",
" )\n",
" return trainset, testset"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "80958cf3",
"metadata": {},
"source": [
"## Configurable neural network\n",
"\n",
"We can only tune those parameters that are configurable. In this example, we can specify\n",
"the layer sizes of the fully connected layers:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fff6bd0d",
"metadata": {},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
" def __init__(self, l1=120, l2=84):\n",
" super(Net, self).__init__()\n",
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
" self.fc1 = nn.Linear(16 * 5 * 5, l1)\n",
" self.fc2 = nn.Linear(l1, l2)\n",
" self.fc3 = nn.Linear(l2, 10)\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = x.view(-1, 16 * 5 * 5)\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x"
]
},
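{
"cell_type": "markdown",
"id": "conv-shape-sketch",
"metadata": {},
"source": [
"Only `l1` and `l2` are configurable. The input size of `fc1` is fixed at `16 * 5 * 5`\n",
"by the convolutional layers. As a quick sanity check (not part of the original\n",
"tutorial), this sketch passes a dummy 32x32 RGB image through the same convolution\n",
"and pooling layers as `Net` and prints the resulting feature map shape:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Same convolution and pooling layers as in Net, without the linear layers.\n",
"features = nn.Sequential(\n",
"    nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),   # 32x32 -> 28x28 -> 14x14\n",
"    nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),  # 14x14 -> 10x10 -> 5x5\n",
")\n",
"\n",
"out = features(torch.zeros(1, 3, 32, 32))\n",
"print(out.shape)  # torch.Size([1, 16, 5, 5]), i.e. 16 * 5 * 5 = 400 inputs for fc1\n",
"```"
]
},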
{
"attachments": {},
"cell_type": "markdown",
"id": "fb619875",
"metadata": {},
"source": [
"## The train function\n",
"\n",
"Now it gets interesting, because we introduce some changes to the example [from the PyTorch\n",
"documentation](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).\n",
"\n",
"(communicating-with-ray-tune)=\n",
"\n",
"The full code example looks like this:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fa0bdae0",
"metadata": {},
"outputs": [],
"source": [
"def train_cifar(config):\n",
" net = Net(config[\"l1\"], config[\"l2\"])\n",
" device = config[\"device\"]\n",
" if device == \"cuda\":\n",
" net = nn.DataParallel(net)\n",
" net.to(device)\n",
"\n",
" criterion = nn.CrossEntropyLoss()\n",
" optimizer = optim.SGD(net.parameters(), lr=config[\"lr\"], momentum=0.9, weight_decay=5e-5)\n",
"\n",
" # Load existing checkpoint through `get_checkpoint()` API.\n",
" if tune.get_checkpoint():\n",
" loaded_checkpoint = tune.get_checkpoint()\n",
" with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:\n",
" model_state, optimizer_state = torch.load(\n",
" os.path.join(loaded_checkpoint_dir, \"checkpoint.pt\")\n",
" )\n",
" net.load_state_dict(model_state)\n",
" optimizer.load_state_dict(optimizer_state)\n",
"\n",
" # Data setup\n",
" if config[\"smoke_test\"]:\n",
" trainset, _ = load_test_data()\n",
" else:\n",
" trainset, _ = load_data()\n",
" train_loader, val_loader = create_dataloaders(\n",
" trainset, \n",
" config[\"batch_size\"],\n",
" num_workers=0 if config[\"smoke_test\"] else 8\n",
" )\n",
"\n",
" for epoch in range(config[\"max_num_epochs\"]): # loop over the dataset multiple times\n",
" net.train()\n",
" running_loss = 0.0\n",
" for inputs, labels in train_loader:\n",
" inputs, labels = inputs.to(device), labels.to(device)\n",
"\n",
" # forward + backward + optimize\n",
" optimizer.zero_grad() # reset gradients\n",
" outputs = net(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" running_loss += loss.item()\n",
"\n",
" # Validation\n",
" net.eval()\n",
" val_loss = 0.0\n",
" correct = total = 0\n",
" with torch.no_grad():\n",
" for inputs, labels in val_loader:\n",
" inputs, labels = inputs.to(device), labels.to(device)\n",
" outputs = net(inputs)\n",
" val_loss += criterion(outputs, labels).item()\n",
" _, predicted = outputs.max(1)\n",
" total += labels.size(0)\n",
" correct += predicted.eq(labels).sum().item()\n",
"\n",
" # Report metrics\n",
" metrics = {\n",
" \"loss\": val_loss / len(val_loader),\n",
" \"accuracy\": correct / total,\n",
" }\n",
"\n",
" # Here we save a checkpoint. It is automatically registered with\n",
" # Ray Tune and will potentially be accessed through in ``get_checkpoint()``\n",
" # in future iterations.\n",
" # Note to save a file-like checkpoint, you still need to put it under a directory\n",
" # to construct a checkpoint.\n",
" with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n",
" path = os.path.join(temp_checkpoint_dir, \"checkpoint.pt\")\n",
" torch.save(\n",
" (net.state_dict(), optimizer.state_dict()), path\n",
" )\n",
" checkpoint = tune.Checkpoint.from_directory(temp_checkpoint_dir)\n",
" tune.report(metrics, checkpoint=checkpoint)\n",
" print(\"Finished Training!\")"
]
},
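{
"cell_type": "markdown",
"id": "checkpoint-roundtrip-sketch",
"metadata": {},
"source": [
"The checkpoint that `train_cifar` writes is simply a tuple of the model and optimizer\n",
"state dicts saved with `torch.save`. The standalone sketch below is an illustration, not\n",
"part of the original example. It uses a small `nn.Linear` model as a stand-in for `Net`\n",
"and leaves out the Ray Tune calls, to show that this format round-trips through\n",
"`torch.save` and `torch.load`:\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"# Small stand-in for Net; the save/load pattern is the same for any nn.Module.\n",
"model = nn.Linear(4, 2)\n",
"optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
"\n",
"with tempfile.TemporaryDirectory() as checkpoint_dir:\n",
"    path = os.path.join(checkpoint_dir, \"checkpoint.pt\")\n",
"    # Save the same (model_state, optimizer_state) tuple as train_cifar.\n",
"    torch.save((model.state_dict(), optimizer.state_dict()), path)\n",
"\n",
"    # Restore into freshly constructed objects, as after tune.get_checkpoint().\n",
"    restored_model = nn.Linear(4, 2)\n",
"    restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.01, momentum=0.9)\n",
"    model_state, optimizer_state = torch.load(path)\n",
"    restored_model.load_state_dict(model_state)\n",
"    restored_optimizer.load_state_dict(optimizer_state)\n",
"\n",
"print(torch.equal(model.weight, restored_model.weight))  # True\n",
"```"
]
},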
{
"attachments": {},
"cell_type": "markdown",
"id": "918d8baf",
"metadata": {},
"source": [
"As you can see, most of the code is adapted directly from the example.\n",
"\n",
"## Test set accuracy\n",
"\n",
"Commonly the performance of a machine learning model is tested on a hold-out test\n",
"set with data that has not been used for training the model. We also wrap this in a\n",
"function:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "93b5b4af",
"metadata": {},
"outputs": [],
"source": [
"def test_best_model(best_result, smoke_test=False):\n",
" best_trained_model = Net(best_result.config[\"l1\"], best_result.config[\"l2\"])\n",
" device = best_result.config[\"device\"]\n",
" if device == \"cuda\":\n",
" best_trained_model = nn.DataParallel(best_trained_model)\n",
" best_trained_model.to(device)\n",
"\n",
" checkpoint_path = os.path.join(best_result.checkpoint.to_directory(), \"checkpoint.pt\")\n",
"\n",
" model_state, _optimizer_state = torch.load(checkpoint_path)\n",
" best_trained_model.load_state_dict(model_state)\n",
"\n",
" if smoke_test:\n",
" _trainset, testset = load_test_data()\n",
" else:\n",
" _trainset, testset = load_data()\n",
"\n",
" testloader = torch.utils.data.DataLoader(\n",
" testset, batch_size=4, shuffle=False, num_workers=2\n",
" )\n",
"\n",
" correct = 0\n",
" total = 0\n",
" with torch.no_grad():\n",
" for data in testloader:\n",
" images, labels = data\n",
" images, labels = images.to(device), labels.to(device)\n",
" outputs = best_trained_model(images)\n",
" _, predicted = outputs.max(1)\n",
" total += labels.size(0)\n",
" correct += predicted.eq(labels).sum().item()\n",
"\n",
" print(f\"Best trial test set accuracy: {correct / total}\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "85f8230e",
"metadata": {},
"source": [
"As you can see, the function also expects a `device` parameter, so we can do the\n",
"test set validation on a GPU."
]
},
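{
"cell_type": "markdown",
"id": "accuracy-sketch",
"metadata": {},
"source": [
"Both the validation loop in `train_cifar` and `test_best_model` compute accuracy the\n",
"same way: they take the index of the largest output per sample with `outputs.max(1)`\n",
"and count how many of those predictions match the labels. Here is that computation on a\n",
"small batch of made-up logits (illustrative values only):\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"# Made-up logits for 4 samples and 3 classes, plus their true labels.\n",
"logits = torch.tensor([\n",
"    [2.0, 0.5, 0.1],\n",
"    [0.2, 1.5, 0.3],\n",
"    [0.1, 0.2, 3.0],\n",
"    [1.0, 2.0, 0.5],\n",
"])\n",
"labels = torch.tensor([0, 1, 2, 0])\n",
"\n",
"_, predicted = logits.max(1)  # predicted classes: tensor([0, 1, 2, 1])\n",
"correct = predicted.eq(labels).sum().item()\n",
"print(correct / labels.size(0))  # 0.75\n",
"```"
]
},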
{
"cell_type": "markdown",
"id": "980720e3",
"metadata": {},
"source": [
"\n",
"## Configuring the search space\n",
"\n",
"Lastly, we need to define Tune's search space. Here is an example:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "564eab65",
"metadata": {},
"outputs": [],
"source": [
"# Set this to True for a smoke test that runs with a small synthetic dataset.\n",
"SMOKE_TEST = False"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "18d94d85",
"metadata": {
"tags": [
"hide-cell"
]
},
"outputs": [],
"source": [
"# For CI testing:\n",
"SMOKE_TEST = True"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "5416cece",
"metadata": {},
"outputs": [],
"source": [
"config = {\n",
" \"l1\": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),\n",
" \"l2\": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),\n",
" \"lr\": tune.loguniform(1e-4, 1e-1),\n",
" \"batch_size\": tune.choice([2, 4, 8, 16]),\n",
" \"smoke_test\": SMOKE_TEST,\n",
" \"num_trials\": 10 if not SMOKE_TEST else 2,\n",
" \"max_num_epochs\": 10 if not SMOKE_TEST else 2,\n",
" \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
"}"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "20af95cc",
"metadata": {},
"source": [
"The `tune.sample_from()` function makes it possible to define your own sample\n",
"methods to obtain hyperparameters. In this example, the layer sizes `l1` and `l2` \n",
"should be powers of 2 between 4 and 256, so either 4, 8, 16, 32, 64, 128, or 256.\n",
"The `lr` (learning rate) should be uniformly sampled between 0.0001 and 0.1. Lastly,\n",
"the batch size is a choice between 2, 4, 8, and 16.\n",
"\n",
"At each trial, Tune will now randomly sample a combination of parameters from these\n",
"search spaces. It will then train a number of models in parallel and find the best\n",
"performing one among these. We also use the `ASHAScheduler` which will terminate badly\n",
"performing trials early.\n",
"\n",
"You can specify the number of CPUs, which are then available e.g.\n",
"to increase the `num_workers` of the PyTorch `DataLoader` instances. The selected\n",
"number of GPUs are made visible to PyTorch in each trial. Trials do not have access to\n",
"GPUs that haven't been requested for them - so you don't have to care about two trials\n",
"using the same set of resources.\n",
"\n",
"Here we can also specify fractional GPUs, so something like `gpus_per_trial=0.5` is\n",
"completely valid. The trials will then share GPUs among each other.\n",
"You just have to make sure that the models still fit in the GPU memory.\n",
"\n",
"After training the models, we will find the best performing one and load the trained\n",
"network from the checkpoint file. We then obtain the test set accuracy and report\n",
"everything by printing.\n",
"\n",
"The full main function looks like this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91d83380",
"metadata": {
"tags": [
"hide-output"
]
},
"outputs": [],
"source": [
"def main(config, gpus_per_trial=1):\n",
" scheduler = ASHAScheduler(\n",
" time_attr=\"training_iteration\",\n",
" max_t=config[\"max_num_epochs\"],\n",
" grace_period=1,\n",
" reduction_factor=2)\n",
" \n",
" tuner = tune.Tuner(\n",
" tune.with_resources(\n",
" tune.with_parameters(train_cifar),\n",
" resources={\"cpu\": 2, \"gpu\": gpus_per_trial}\n",
" ),\n",
" tune_config=tune.TuneConfig(\n",
" metric=\"loss\",\n",
" mode=\"min\",\n",
" scheduler=scheduler,\n",
" num_samples=config[\"num_trials\"],\n",
" ),\n",
" param_space=config,\n",
" )\n",
" results = tuner.fit()\n",
" \n",
" best_result = results.get_best_result(\"loss\", \"min\")\n",
"\n",
" print(f\"Best trial config: {best_result.config}\")\n",
" print(f\"Best trial final validation loss: {best_result.metrics['loss']}\")\n",
" print(f\"Best trial final validation accuracy: {best_result.metrics['accuracy']}\")\n",
"\n",
" test_best_model(best_result, smoke_test=config[\"smoke_test\"])\n",
"\n",
"main(config, gpus_per_trial=1 if torch.cuda.is_available() else 0)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "b702b4ce",
"metadata": {},
"source": [
"If you run the code, an example output could look like this:\n",
"\n",
"```{code-block} bash\n",
":emphasize-lines: 7\n",
"\n",
" Number of trials: 10 (10 TERMINATED)\n",
" +-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------+\n",
" | Trial name | status | loc | l1 | l2 | lr | batch_size | loss | accuracy | training_iteration |\n",
" |-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------|\n",
" | train_cifar_87d1f_00000 | TERMINATED | | 64 | 4 | 0.00011629 | 2 | 1.87273 | 0.244 | 2 |\n",
" | train_cifar_87d1f_00001 | TERMINATED | | 32 | 64 | 0.000339763 | 8 | 1.23603 | 0.567 | 8 |\n",
" | train_cifar_87d1f_00002 | TERMINATED | | 8 | 16 | 0.00276249 | 16 | 1.1815 | 0.5836 | 10 |\n",
" | train_cifar_87d1f_00003 | TERMINATED | | 4 | 64 | 0.000648721 | 4 | 1.31131 | 0.5224 | 8 |\n",
" | train_cifar_87d1f_00004 | TERMINATED | | 32 | 16 | 0.000340753 | 8 | 1.26454 | 0.5444 | 8 |\n",
" | train_cifar_87d1f_00005 | TERMINATED | | 8 | 4 | 0.000699775 | 8 | 1.99594 | 0.1983 | 2 |\n",
" | train_cifar_87d1f_00006 | TERMINATED | | 256 | 8 | 0.0839654 | 16 | 2.3119 | 0.0993 | 1 |\n",
" | train_cifar_87d1f_00007 | TERMINATED | | 16 | 128 | 0.0758154 | 16 | 2.33575 | 0.1327 | 1 |\n",
" | train_cifar_87d1f_00008 | TERMINATED | | 16 | 8 | 0.0763312 | 16 | 2.31129 | 0.1042 | 4 |\n",
" | train_cifar_87d1f_00009 | TERMINATED | | 128 | 16 | 0.000124903 | 4 | 2.26917 | 0.1945 | 1 |\n",
" +-------------------------+------------+-------+------+------+-------------+--------------+---------+------------+----------------------+\n",
"\n",
"\n",
" Best trial config: {'l1': 8, 'l2': 16, 'lr': 0.0027624906698231976, 'batch_size': 16, 'data_dir': '...'}\n",
" Best trial final validation loss: 1.1815014744281769\n",
" Best trial final validation accuracy: 0.5836\n",
" Best trial test set accuracy: 0.5806\n",
"```\n",
"\n",
"As you can see, most trials have been stopped early in order to avoid wasting resources.\n",
"The best performing trial achieved a validation accuracy of about 58%, which could\n",
"be confirmed on the test set.\n",
"\n",
"So that's it! You can now tune the parameters of your PyTorch models.\n",
"\n",
"## See More PyTorch Examples\n",
"\n",
"- {doc}`/tune/examples/includes/mnist_pytorch`: Converts the PyTorch MNIST example to use Tune with the function-based API.\n",
" Also shows how to easily convert something relying on argparse to use Tune.\n",
"- {doc}`/tune/examples/includes/pbt_convnet_function_example`: Example training a ConvNet with checkpointing in function API.\n",
"- {doc}`/tune/examples/includes/mnist_pytorch_trainable`: Converts the PyTorch MNIST example to use Tune with Trainable API.\n",
" Also uses the HyperBandScheduler and checkpoints the model at the end."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "tune-pytorch-cifar",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
},
"orphan": true
},
"nbformat": 4,
"nbformat_minor": 5
}