{ "cells": [ { "cell_type": "markdown", "id": "ac6a47d6", "metadata": {}, "source": [ "# Parameter Server\n", "\n", "```{tip}\n", "For a production-grade implementation of distributed\n", "training, use [Ray Train](https://docs.ray.io/en/master/train/train.html).\n", "```\n", "\n", "The parameter server is a framework for distributed machine learning training.\n", "\n", "In the parameter server framework, a centralized server (or group of server\n", "nodes) maintains global shared parameters of a machine-learning model\n", "(e.g., a neural network) while the data and computation of calculating\n", "updates (i.e., gradient descent updates) are distributed over worker nodes.\n", "\n", "```{image} /ray-core/images/param_actor.png\n", ":align: center\n", "```\n", "\n", "Parameter servers are a core part of many machine learning applications. This\n", "document walks through how to implement simple synchronous and asynchronous\n", "parameter servers using Ray actors.\n", "\n", "To run the application, first install some dependencies.\n", "\n", "```bash\n", "pip install torch torchvision filelock\n", "```\n", "\n", "Let's first define some helper functions and import some dependencies." ] }, { "cell_type": "code", "execution_count": null, "id": "84677721", "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torchvision import datasets, transforms\n", "from filelock import FileLock\n", "import numpy as np\n", "\n", "import ray\n", "\n", "\n", "def get_data_loader():\n", " \"\"\"Safely downloads data. Returns training/validation set dataloader.\"\"\"\n", " mnist_transforms = transforms.Compose(\n", " [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n", " )\n", "\n", " # We add FileLock here because multiple workers will want to\n", " # download data, and this may cause overwrites since\n", " # DataLoader is not threadsafe.\n", " with FileLock(os.path.expanduser(\"~/data.lock\")):\n", " train_loader = torch.utils.data.DataLoader(\n", " datasets.MNIST(\n", " \"~/data\", train=True, download=True, transform=mnist_transforms\n", " ),\n", " batch_size=128,\n", " shuffle=True,\n", " )\n", " test_loader = torch.utils.data.DataLoader(\n", " datasets.MNIST(\"~/data\", train=False, transform=mnist_transforms),\n", " batch_size=128,\n", " shuffle=True,\n", " )\n", " return train_loader, test_loader\n", "\n", "\n", "def evaluate(model, test_loader):\n", " \"\"\"Evaluates the accuracy of the model on a validation dataset.\"\"\"\n", " model.eval()\n", " correct = 0\n", " total = 0\n", " with torch.no_grad():\n", " for batch_idx, (data, target) in enumerate(test_loader):\n", " # This is only set to finish evaluation faster.\n", " if batch_idx * len(data) > 1024:\n", " break\n", " outputs = model(data)\n", " _, predicted = torch.max(outputs.data, 1)\n", " total += target.size(0)\n", " correct += (predicted == target).sum().item()\n", " return 100.0 * correct / total" ] }, { "cell_type": "markdown", "id": "fa5c9aea", "metadata": {}, "source": [ "## Setup: Defining the Neural Network\n", "\n", "We define a small neural network to use in training. We provide\n", "some helper functions for obtaining data, including getter/setter\n", "methods for gradients and weights." 
] }, { "cell_type": "code", "execution_count": null, "id": "9b064f87", "metadata": {}, "outputs": [], "source": [ "class ConvNet(nn.Module):\n", " \"\"\"Small ConvNet for MNIST.\"\"\"\n", "\n", " def __init__(self):\n", " super(ConvNet, self).__init__()\n", " self.conv1 = nn.Conv2d(1, 3, kernel_size=3)\n", " self.fc = nn.Linear(192, 10)\n", "\n", " def forward(self, x):\n", " x = F.relu(F.max_pool2d(self.conv1(x), 3))\n", " x = x.view(-1, 192)\n", " x = self.fc(x)\n", " return F.log_softmax(x, dim=1)\n", "\n", " def get_weights(self):\n", " return {k: v.cpu() for k, v in self.state_dict().items()}\n", "\n", " def set_weights(self, weights):\n", " self.load_state_dict(weights)\n", "\n", " def get_gradients(self):\n", " grads = []\n", " for p in self.parameters():\n", " grad = None if p.grad is None else p.grad.data.cpu().numpy()\n", " grads.append(grad)\n", " return grads\n", "\n", " def set_gradients(self, gradients):\n", " for g, p in zip(gradients, self.parameters()):\n", " if g is not None:\n", " p.grad = torch.from_numpy(g)" ] }, { "cell_type": "markdown", "id": "81e2bf73", "metadata": {}, "source": [ "## Defining the Parameter Server\n", "\n", "\n", "The parameter server will hold a copy of the model.\n", "During training, it will:\n", "\n", "1. Receive gradients and apply them to its model.\n", "\n", "2. Send the updated model back to the workers.\n", "\n", "The ``@ray.remote`` decorator defines a remote process. It wraps the\n", "ParameterServer class and allows users to instantiate it as a\n", "remote actor." ] }, { "cell_type": "code", "execution_count": null, "id": "77b756a4", "metadata": {}, "outputs": [], "source": [ "@ray.remote\n", "class ParameterServer(object):\n", " def __init__(self, lr):\n", " self.model = ConvNet()\n", " self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)\n", "\n", " def apply_gradients(self, *gradients):\n", " summed_gradients = [\n", " np.stack(gradient_zip).sum(axis=0) for gradient_zip in zip(*gradients)\n", " ]\n", " self.optimizer.zero_grad()\n", " self.model.set_gradients(summed_gradients)\n", " self.optimizer.step()\n", " return self.model.get_weights()\n", "\n", " def get_weights(self):\n", " return self.model.get_weights()" ] }, { "cell_type": "markdown", "id": "ccf86e35", "metadata": {}, "source": [ "## Defining the Worker\n", "\n", "The worker will also hold a copy of the model. During training. it will\n", "continuously evaluate data and send gradients\n", "to the parameter server. The worker will synchronize its model with the\n", "Parameter Server model weights." ] }, { "cell_type": "code", "execution_count": null, "id": "18737d3a", "metadata": {}, "outputs": [], "source": [ "@ray.remote\n", "class DataWorker(object):\n", " def __init__(self):\n", " self.model = ConvNet()\n", " self.data_iterator = iter(get_data_loader()[0])\n", "\n", " def compute_gradients(self, weights):\n", " self.model.set_weights(weights)\n", " try:\n", " data, target = next(self.data_iterator)\n", " except StopIteration: # When the epoch ends, start a new epoch.\n", " self.data_iterator = iter(get_data_loader()[0])\n", " data, target = next(self.data_iterator)\n", " self.model.zero_grad()\n", " output = self.model(data)\n", " loss = F.nll_loss(output, target)\n", " loss.backward()\n", " return self.model.get_gradients()" ] }, { "cell_type": "markdown", "id": "763fc48e", "metadata": {}, "source": [ "## Synchronous Parameter Server Training\n", "\n", "We'll now create a synchronous parameter server training scheme. 
We'll first\n", "instantiate a process for the parameter server, along with multiple\n", "workers." ] }, { "cell_type": "code", "execution_count": null, "id": "e3b47dc6", "metadata": {}, "outputs": [], "source": [ "iterations = 200\n", "num_workers = 2\n", "\n", "ray.init(ignore_reinit_error=True)\n", "ps = ParameterServer.remote(1e-2)\n", "workers = [DataWorker.remote() for i in range(num_workers)]" ] }, { "cell_type": "markdown", "id": "3545f369", "metadata": {}, "source": [ "We'll also instantiate a model on the driver process to evaluate the test\n", "accuracy during training." ] }, { "cell_type": "code", "execution_count": null, "id": "f4db4f83", "metadata": {}, "outputs": [], "source": [ "model = ConvNet()\n", "test_loader = get_data_loader()[1]" ] }, { "cell_type": "markdown", "id": "a7a721b7", "metadata": {}, "source": [ "Training alternates between:\n", "\n", "1. Computing the gradients given the current weights from the server\n", "2. Updating the parameter server's weights with the gradients." ] }, { "cell_type": "code", "execution_count": null, "id": "75da1bf4", "metadata": {}, "outputs": [], "source": [ "print(\"Running synchronous parameter server training.\")\n", "current_weights = ps.get_weights.remote()\n", "for i in range(iterations):\n", " gradients = [worker.compute_gradients.remote(current_weights) for worker in workers]\n", " # Calculate update after all gradients are available.\n", " current_weights = ps.apply_gradients.remote(*gradients)\n", "\n", " if i % 10 == 0:\n", " # Evaluate the current model.\n", " model.set_weights(ray.get(current_weights))\n", " accuracy = evaluate(model, test_loader)\n", " print(\"Iter {}: \\taccuracy is {:.1f}\".format(i, accuracy))\n", "\n", "print(\"Final accuracy is {:.1f}.\".format(accuracy))\n", "# Clean up Ray resources and processes before the next example.\n", "ray.shutdown()" ] }, { "cell_type": "markdown", "id": "11c7f855", "metadata": {}, "source": [ "## Asynchronous Parameter Server Training\n", "\n", "We'll now create a synchronous parameter server training scheme. We'll first\n", "instantiate a process for the parameter server, along with multiple\n", "workers." ] }, { "cell_type": "code", "execution_count": null, "id": "dcc5407d", "metadata": {}, "outputs": [], "source": [ "print(\"Running Asynchronous Parameter Server Training.\")\n", "\n", "ray.init(ignore_reinit_error=True)\n", "ps = ParameterServer.remote(1e-2)\n", "workers = [DataWorker.remote() for i in range(num_workers)]" ] }, { "cell_type": "markdown", "id": "1dc42aa6", "metadata": {}, "source": [ "Here, workers will asynchronously compute the gradients given its\n", "current weights and send these gradients to the parameter server as\n", "soon as they are ready. When the Parameter server finishes applying the\n", "new gradient, the server will send back a copy of the current weights to the\n", "worker. The worker will then update the weights and repeat." 
] }, { "cell_type": "code", "execution_count": null, "id": "4013539b", "metadata": {}, "outputs": [], "source": [ "current_weights = ps.get_weights.remote()\n", "\n", "gradients = {}\n", "for worker in workers:\n", " gradients[worker.compute_gradients.remote(current_weights)] = worker\n", "\n", "for i in range(iterations * num_workers):\n", " ready_gradient_list, _ = ray.wait(list(gradients))\n", " ready_gradient_id = ready_gradient_list[0]\n", " worker = gradients.pop(ready_gradient_id)\n", "\n", " # Compute and apply gradients.\n", " current_weights = ps.apply_gradients.remote(*[ready_gradient_id])\n", " gradients[worker.compute_gradients.remote(current_weights)] = worker\n", "\n", " if i % 10 == 0:\n", " # Evaluate the current model after every 10 updates.\n", " model.set_weights(ray.get(current_weights))\n", " accuracy = evaluate(model, test_loader)\n", " print(\"Iter {}: \\taccuracy is {:.1f}\".format(i, accuracy))\n", "\n", "print(\"Final accuracy is {:.1f}.\".format(accuracy))" ] }, { "cell_type": "markdown", "id": "dfb8532b", "metadata": {}, "source": [ "## Final Thoughts\n", "\n", "This approach is powerful because it enables you to implement a parameter\n", "server with a few lines of code as part of a Python application.\n", "As a result, this simplifies the deployment of applications that use\n", "parameter servers and to modify the behavior of the parameter server.\n", "\n", "For example, sharding the parameter server, changing the update rule,\n", "switching between asynchronous and synchronous updates, ignoring\n", "straggler workers, or any number of other customizations,\n", "will only require a few extra lines of code." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }