{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "436ead19",
   "metadata": {},
   "source": [
    "# Simple Parallel Model Selection\n",
    "\n",
    "```{tip}\n",
    "For a production-grade implementation of distributed\n",
    "hyperparameter tuning, use [Ray Tune](https://docs.ray.io/en/master/tune.html), a scalable hyperparameter\n",
    "tuning library built using Ray's Actor API.\n",
    "```\n",
    "\n",
    "In this example, we'll demonstrate how to quickly write a hyperparameter\n",
    "tuning script that evaluates a set of hyperparameters in parallel.\n",
    "\n",
    "This script demonstrates two important parts of the Ray API:\n",
    "``ray.remote`` to define remote functions and ``ray.wait`` to wait for\n",
    "their results to be ready.\n",
    "\n",
    "```{image} /ray-core/images/hyperparameter.png\n",
    ":align: center\n",
    "```\n",
    "\n",
    "## Setup: Dependencies\n",
    "\n",
    "First, import some dependencies and define functions to generate\n",
    "random hyperparameters and retrieve data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e992dc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from filelock import FileLock\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import ray\n",
    "\n",
    "ray.init()\n",
    "\n",
    "# The number of sets of random hyperparameters to try.\n",
    "num_evaluations = 10\n",
    "\n",
    "\n",
    "# A function for generating random hyperparameters.\n",
    "def generate_hyperparameters():\n",
    "    return {\n",
    "        \"learning_rate\": 10 ** np.random.uniform(-5, 1),\n",
    "        \"batch_size\": np.random.randint(1, 100),\n",
    "        \"momentum\": np.random.uniform(0, 1),\n",
    "    }\n",
    "\n",
    "\n",
    "def get_data_loaders(batch_size):\n",
    "    mnist_transforms = transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "    )\n",
    "\n",
    "    # We add FileLock here because multiple workers will want to\n",
    "    # download data, and this may cause overwrites since\n",
    "    # DataLoader is not threadsafe.\n",
    "    with FileLock(os.path.expanduser(\"~/data.lock\")):\n",
    "        train_loader = torch.utils.data.DataLoader(\n",
    "            datasets.MNIST(\n",
    "                \"~/data\", train=True, download=True, transform=mnist_transforms\n",
    "            ),\n",
    "            batch_size=batch_size,\n",
    "            shuffle=True,\n",
    "        )\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST(\"~/data\", train=False, transform=mnist_transforms),\n",
    "        batch_size=batch_size,\n",
    "        shuffle=True,\n",
    "    )\n",
    "    return train_loader, test_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3f0d421",
   "metadata": {},
   "source": [
    "## Setup: Defining the Neural Network\n",
    "\n",
    "We define a small neural network to use in training, along with\n",
    "functions to train and test it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c02ed1db",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(nn.Module):\n",
    "    \"\"\"Simple two-layer Convolutional Neural Network.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(ConvNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)\n",
    "        self.fc = nn.Linear(192, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 3))\n",
    "        x = x.view(-1, 192)\n",
    "        x = self.fc(x)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "\n",
    "\n",
    "def train(model, optimizer, train_loader, device=torch.device(\"cpu\")):\n",
    "    \"\"\"Optimize the model with one pass over the data.\n",
    "\n",
    "    Cuts off at 1024 samples to simplify training.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        if batch_idx * len(data) > 1024:\n",
    "            return\n",
    "        data, target = data.to(device), target.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = F.nll_loss(output, target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "\n",
    "def test(model, test_loader, device=torch.device(\"cpu\")):\n",
    "    \"\"\"Checks the validation accuracy of the model.\n",
    "\n",
    "    Cuts off at 512 samples for simplicity.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (data, target) in enumerate(test_loader):\n",
    "            if batch_idx * len(data) > 512:\n",
    "                break\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            outputs = model(data)\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += target.size(0)\n",
    "            correct += (predicted == target).sum().item()\n",
    "\n",
    "    return correct / total"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3a9ed6f",
   "metadata": {},
   "source": [
    "## Evaluating the Hyperparameters\n",
    "\n",
    "For a given configuration, we train the neural network defined above\n",
    "and return its accuracy on the test data. The accuracies of the trained\n",
    "networks are then compared to find the best set of\n",
    "hyperparameters.\n",
    "\n",
    "The ``@ray.remote`` decorator turns the function into a remote function,\n",
    "which Ray runs as a task in a separate worker process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2471f2db",
   "metadata": {},
   "outputs": [],
   "source": [
    "@ray.remote\n",
    "def evaluate_hyperparameters(config):\n",
    "    model = ConvNet()\n",
    "    train_loader, test_loader = get_data_loaders(config[\"batch_size\"])\n",
    "    optimizer = optim.SGD(\n",
    "        model.parameters(), lr=config[\"learning_rate\"], momentum=config[\"momentum\"]\n",
    "    )\n",
    "    train(model, optimizer, train_loader)\n",
    "    return test(model, test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62f180e6",
   "metadata": {},
   "source": [
    "## Synchronous Evaluation of Randomly Generated Hyperparameters\n",
    "\n",
    "We create multiple sets of random hyperparameters for our neural\n",
    "network and evaluate them in parallel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4075165",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep track of the best hyperparameters and the best accuracy.\n",
    "best_hyperparameters = None\n",
    "best_accuracy = 0\n",
    "# A list holding the object refs for all of the experiments that have\n",
    "# been launched but not yet processed.\n",
    "remaining_ids = []\n",
    "# A dictionary mapping an experiment's object ref to the\n",
    "# hyperparameters used for that experiment.\n",
    "hyperparameters_mapping = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dd73456",
   "metadata": {},
   "source": [
    "Launch asynchronous parallel tasks for evaluating different\n",
    "hyperparameters. ``accuracy_id`` is an ObjectRef that acts as a handle to\n",
    "the remote task. It is used later to fetch the result of the task\n",
    "when the task finishes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eda2226a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Randomly generate sets of hyperparameters and launch a task to evaluate each set.\n",
    "for i in range(num_evaluations):\n",
    "    hyperparameters = generate_hyperparameters()\n",
    "    accuracy_id = evaluate_hyperparameters.remote(hyperparameters)\n",
    "    remaining_ids.append(accuracy_id)\n",
    "    hyperparameters_mapping[accuracy_id] = hyperparameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd0e53ec",
   "metadata": {},
   "source": [
    "Process each set of hyperparameters and its corresponding accuracy in the\n",
    "order that the tasks finish, keeping track of the hyperparameters with the best accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d95ca22b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch and print the results of the tasks in the order that they complete.\n",
    "while remaining_ids:\n",
    "    # Use ray.wait to get the object ref of the first task that completes.\n",
    "    done_ids, remaining_ids = ray.wait(remaining_ids)\n",
    "    # By default, ray.wait returns a single completed object ref.\n",
    "    result_id = done_ids[0]\n",
    "\n",
    "    hyperparameters = hyperparameters_mapping[result_id]\n",
    "    accuracy = ray.get(result_id)\n",
    "    print(\n",
    "        \"\"\"We achieve accuracy {:.3}% with\n",
    "        learning_rate: {:.2}\n",
    "        batch_size: {}\n",
    "        momentum: {:.2}\n",
    "      \"\"\".format(\n",
    "            100 * accuracy,\n",
    "            hyperparameters[\"learning_rate\"],\n",
    "            hyperparameters[\"batch_size\"],\n",
    "            hyperparameters[\"momentum\"],\n",
    "        )\n",
    "    )\n",
    "    if accuracy > best_accuracy:\n",
    "        best_hyperparameters = hyperparameters\n",
    "        best_accuracy = accuracy\n",
    "\n",
    "# Print the best-performing set of hyperparameters.\n",
    "print(\n",
    "    \"\"\"Best accuracy over {} trials was {:.3}% with\n",
    "      learning_rate: {:.2}\n",
    "      batch_size: {}\n",
    "      momentum: {:.2}\n",
    "      \"\"\".format(\n",
    "        num_evaluations,\n",
    "        100 * best_accuracy,\n",
    "        best_hyperparameters[\"learning_rate\"],\n",
    "        best_hyperparameters[\"batch_size\"],\n",
    "        best_hyperparameters[\"momentum\"],\n",
    "    )\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}