{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "16a75f79", "metadata": {}, "source": [ "(convert-torch-to-train)=\n", "\n", "# Convert existing PyTorch code to Ray Train\n", "\n", "If you already have working PyTorch code, you don't have to start from scratch to get the benefits of Ray Train. Instead, you can keep using your existing code and incrementally add Ray Train components as needed.\n", "\n", "Some of the benefits you get by using Ray Train with your existing PyTorch training code:\n", "\n", "- Easy distributed data-parallel training on a cluster\n", "- Automatic checkpointing/fault tolerance and result tracking\n", "- Parallel data preprocessing\n", "- Seamless integration with hyperparameter tuning\n", "\n", "This tutorial shows you how to start from your existing PyTorch training code and use Ray Train to **distribute your training**.\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9a4855cf", "metadata": {}, "source": [ "## The example code\n", "\n", "The example code comes from the [PyTorch quickstart tutorial](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html). This code trains a neural network classifier on the FashionMNIST dataset.\n", "\n", "You can find the code we used for this tutorial [here on GitHub](https://github.com/pytorch/tutorials/blob/8dddccc4c69116ca724aa82bd5f4596ef7ad119c/beginner_source/basics/quickstart_tutorial.py)." ] }, { "attachments": {}, "cell_type": "markdown", "id": "a42faedb", "metadata": {}, "source": [ "## Unmodified\n", "Let's start with the unmodified code from the example. 
A thorough explanation of each part is given in the full tutorial; here we just focus on the code.\n", "\n", "We start with some imports:" ] }, { "cell_type": "code", "execution_count": 1, "id": "01af2222", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision.transforms import ToTensor\n", "\n", "import os\n", "from tempfile import TemporaryDirectory" ] }, { "attachments": {}, "cell_type": "markdown", "id": "db36ae56", "metadata": {}, "source": [ "Then we download the data.\n", "\n", "This tutorial assumes that your existing code uses the `torch.utils.data.Dataset` native to PyTorch. It continues to use `torch.utils.data.Dataset` so that you need to make as few code changes as possible. **This tutorial also runs with Ray Data, which gives you the benefits of efficient parallel preprocessing.** For more details on using Ray Data for images, see the {doc}`Working with Images ` Ray Data user guide."
] }, { "cell_type": "code", "execution_count": null, "id": "28126be5", "metadata": {}, "outputs": [], "source": [ "# Download training data from open datasets.\n", "training_data = datasets.FashionMNIST(\n", " root=\"data\",\n", " train=True,\n", " download=True,\n", " transform=ToTensor(),\n", ")\n", "\n", "# Download test data from open datasets.\n", "test_data = datasets.FashionMNIST(\n", " root=\"data\",\n", " train=False,\n", " download=True,\n", " transform=ToTensor(),\n", ")\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "9795c146", "metadata": {}, "source": [ "We can now define the dataloaders:" ] }, { "cell_type": "code", "execution_count": 3, "id": "b99cac23", "metadata": {}, "outputs": [], "source": [ "batch_size = 64\n", "\n", "# Create data loaders.\n", "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n", "test_dataloader = DataLoader(test_data, batch_size=batch_size)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ae11399e", "metadata": {}, "source": [ "We can then define and instantiate the neural network:" ] }, { "cell_type": "code", "execution_count": 4, "id": "3b027562", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cpu device\n", "NeuralNetwork(\n", " (flatten): Flatten(start_dim=1, end_dim=-1)\n", " (linear_relu_stack): Sequential(\n", " (0): Linear(in_features=784, out_features=512, bias=True)\n", " (1): ReLU()\n", " (2): Linear(in_features=512, out_features=512, bias=True)\n", " (3): ReLU()\n", " (4): Linear(in_features=512, out_features=10, bias=True)\n", " )\n", ")\n" ] } ], "source": [ "# Get cpu or gpu device for training.\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(f\"Using {device} device\")\n", "\n", "# Define model\n", "class NeuralNetwork(nn.Module):\n", " def __init__(self):\n", " super(NeuralNetwork, self).__init__()\n", " self.flatten = nn.Flatten()\n", " self.linear_relu_stack = nn.Sequential(\n", " 
nn.Linear(28*28, 512),\n", " nn.ReLU(),\n", " nn.Linear(512, 512),\n", " nn.ReLU(),\n", " nn.Linear(512, 10)\n", " )\n", "\n", " def forward(self, x):\n", " x = self.flatten(x)\n", " logits = self.linear_relu_stack(x)\n", " return logits\n", "\n", "model = NeuralNetwork().to(device)\n", "print(model)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b692d06a", "metadata": {}, "source": [ "Define our optimizer and loss:" ] }, { "cell_type": "code", "execution_count": 5, "id": "efe92797", "metadata": {}, "outputs": [], "source": [ "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "681d5798", "metadata": {}, "source": [ "And finally our training loop. Note that we renamed the function from `train` to `train_epoch` to avoid conflicts with the Ray Train module later (which is also called `train`):" ] }, { "cell_type": "code", "execution_count": 6, "id": "2ce258ed", "metadata": {}, "outputs": [], "source": [ "def train_epoch(dataloader, model, loss_fn, optimizer):\n", " size = len(dataloader.dataset)\n", " model.train()\n", " for batch, (X, y) in enumerate(dataloader):\n", " X, y = X.to(device), y.to(device)\n", "\n", " # Compute prediction error\n", " pred = model(X)\n", " loss = loss_fn(pred, y)\n", "\n", " # Backpropagation\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " if batch % 100 == 0:\n", " loss, current = loss.item(), batch * len(X)\n", " print(f\"loss: {loss:>7f} [{current:>5d}/{size:>5d}]\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6621cffa", "metadata": {}, "source": [ "And while we're at it, here is our validation loop (note that we sneaked in a `return test_loss` statement and also renamed the function):" ] }, { "cell_type": "code", "execution_count": 7, "id": "bbefec77", "metadata": {}, "outputs": [], "source": [ "def test_epoch(dataloader, model, loss_fn):\n", " size = 
len(dataloader.dataset)\n", " num_batches = len(dataloader)\n", " model.eval()\n", " test_loss, correct = 0, 0\n", " with torch.no_grad():\n", " for X, y in dataloader:\n", " X, y = X.to(device), y.to(device)\n", " pred = model(X)\n", " test_loss += loss_fn(pred, y).item()\n", " correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n", " test_loss /= num_batches\n", " correct /= size\n", " print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")\n", " return test_loss" ] }, { "attachments": {}, "cell_type": "markdown", "id": "d915d788", "metadata": {}, "source": [ "Now we can trigger training and save a model:" ] }, { "cell_type": "code", "execution_count": 8, "id": "27f80fc7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1\n", "-------------------------------\n", "loss: 2.295566 [ 0/60000]\n", "loss: 2.291762 [ 6400/60000]\n", "loss: 2.268867 [12800/60000]\n", "loss: 2.262820 [19200/60000]\n", "loss: 2.256001 [25600/60000]\n", "loss: 2.204572 [32000/60000]\n", "loss: 2.225075 [38400/60000]\n", "loss: 2.184233 [44800/60000]\n", "loss: 2.182663 [51200/60000]\n", "loss: 2.154192 [57600/60000]\n", "Test Error: \n", " Accuracy: 36.5%, Avg loss: 2.146461 \n", "\n", "Epoch 2\n", "-------------------------------\n", "loss: 2.150961 [ 0/60000]\n", "loss: 2.147769 [ 6400/60000]\n", "loss: 2.085719 [12800/60000]\n", "loss: 2.107859 [19200/60000]\n", "loss: 2.066872 [25600/60000]\n", "loss: 1.978430 [32000/60000]\n", "loss: 2.029306 [38400/60000]\n", "loss: 1.939256 [44800/60000]\n", "loss: 1.951516 [51200/60000]\n", "loss: 1.881199 [57600/60000]\n", "Test Error: \n", " Accuracy: 55.0%, Avg loss: 1.879711 \n", "\n", "Epoch 3\n", "-------------------------------\n", "loss: 1.907144 [ 0/60000]\n", "loss: 1.879325 [ 6400/60000]\n", "loss: 1.765395 [12800/60000]\n", "loss: 1.815291 [19200/60000]\n", "loss: 1.708041 [25600/60000]\n", "loss: 1.641765 [32000/60000]\n", "loss: 1.687605 
[38400/60000]\n", "loss: 1.581743 [44800/60000]\n", "loss: 1.615951 [51200/60000]\n", "loss: 1.507691 [57600/60000]\n", "Test Error: \n", " Accuracy: 62.3%, Avg loss: 1.523205 \n", "\n", "Epoch 4\n", "-------------------------------\n", "loss: 1.589735 [ 0/60000]\n", "loss: 1.549950 [ 6400/60000]\n", "loss: 1.404985 [12800/60000]\n", "loss: 1.479113 [19200/60000]\n", "loss: 1.362190 [25600/60000]\n", "loss: 1.348071 [32000/60000]\n", "loss: 1.376365 [38400/60000]\n", "loss: 1.297325 [44800/60000]\n", "loss: 1.336892 [51200/60000]\n", "loss: 1.234042 [57600/60000]\n", "Test Error: \n", " Accuracy: 63.8%, Avg loss: 1.255606 \n", "\n", "Epoch 5\n", "-------------------------------\n", "loss: 1.334560 [ 0/60000]\n", "loss: 1.311746 [ 6400/60000]\n", "loss: 1.151140 [12800/60000]\n", "loss: 1.254679 [19200/60000]\n", "loss: 1.132061 [25600/60000]\n", "loss: 1.149663 [32000/60000]\n", "loss: 1.179779 [38400/60000]\n", "loss: 1.117024 [44800/60000]\n", "loss: 1.159811 [51200/60000]\n", "loss: 1.072276 [57600/60000]\n", "Test Error: \n", " Accuracy: 65.0%, Avg loss: 1.088372 \n", "\n", "Done!\n" ] } ], "source": [ "epochs = 5\n", "for t in range(epochs):\n", " print(f\"Epoch {t+1}\\n-------------------------------\")\n", " train_epoch(train_dataloader, model, loss_fn, optimizer)\n", " test_epoch(test_dataloader, model, loss_fn)\n", "print(\"Done!\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "e62fc82b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Saved PyTorch Model State to model.pth\n" ] } ], "source": [ "torch.save(model.state_dict(), \"model.pth\")\n", "print(\"Saved PyTorch Model State to model.pth\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6655d903", "metadata": {}, "source": [ "We'll cover the rest of the tutorial (loading the model and doing batch prediction) later!" 
] }, { "attachments": {}, "cell_type": "markdown", "id": "d0b98b1c", "metadata": {}, "source": [ "## Introducing a wrapper function (no Ray Train, yet!)\n", "The notebook style is great for tutorials, but in production code you have probably wrapped the actual training logic in a function. So let's do the same here.\n", "\n", "Note that we don't add or alter any code here (apart from the variable definitions); we just take the loose bits of code from the tutorial and put them into one function." ] }, { "cell_type": "code", "execution_count": 10, "id": "aacdf4a6", "metadata": {}, "outputs": [], "source": [ "def train_func():\n", "    batch_size = 64\n", "    lr = 1e-3\n", "    epochs = 5\n", "\n", "    # Create data loaders.\n", "    train_dataloader = DataLoader(training_data, batch_size=batch_size)\n", "    test_dataloader = DataLoader(test_data, batch_size=batch_size)\n", "\n", "    # Get cpu or gpu device for training.\n", "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "    print(f\"Using {device} device\")\n", "\n", "    model = NeuralNetwork().to(device)\n", "    print(model)\n", "\n", "    loss_fn = nn.CrossEntropyLoss()\n", "    optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n", "\n", "    for t in range(epochs):\n", "        print(f\"Epoch {t+1}\\n-------------------------------\")\n", "        train_epoch(train_dataloader, model, loss_fn, optimizer)\n", "        test_epoch(test_dataloader, model, loss_fn)\n", "\n", "    print(\"Done!\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "60f7341a", "metadata": {}, "source": [ "Let's see it in action again:" ] }, { "cell_type": "code", "execution_count": 11, "id": "7130c361", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cpu device\n", "NeuralNetwork(\n", "  (flatten): Flatten(start_dim=1, end_dim=-1)\n", "  (linear_relu_stack): Sequential(\n", "    (0): Linear(in_features=784, out_features=512, bias=True)\n", "    (1): ReLU()\n", "    (2): Linear(in_features=512, 
out_features=512, bias=True)\n", " (3): ReLU()\n", " (4): Linear(in_features=512, out_features=10, bias=True)\n", " )\n", ")\n", "Epoch 1\n", "-------------------------------\n", "loss: 2.311088 [ 0/60000]\n", "loss: 2.295296 [ 6400/60000]\n", "loss: 2.271576 [12800/60000]\n", "loss: 2.258537 [19200/60000]\n", "loss: 2.250895 [25600/60000]\n", "loss: 2.216462 [32000/60000]\n", "loss: 2.222296 [38400/60000]\n", "loss: 2.189997 [44800/60000]\n", "loss: 2.188647 [51200/60000]\n", "loss: 2.145895 [57600/60000]\n", "Test Error: \n", " Accuracy: 44.8%, Avg loss: 2.144711 \n", "\n", "Epoch 2\n", "-------------------------------\n", "loss: 2.164661 [ 0/60000]\n", "loss: 2.150512 [ 6400/60000]\n", "loss: 2.085597 [12800/60000]\n", "loss: 2.099732 [19200/60000]\n", "loss: 2.047274 [25600/60000]\n", "loss: 1.980986 [32000/60000]\n", "loss: 2.014364 [38400/60000]\n", "loss: 1.930184 [44800/60000]\n", "loss: 1.941903 [51200/60000]\n", "loss: 1.856329 [57600/60000]\n", "Test Error: \n", " Accuracy: 56.2%, Avg loss: 1.857978 \n", "\n", "Epoch 3\n", "-------------------------------\n", "loss: 1.901466 [ 0/60000]\n", "loss: 1.867397 [ 6400/60000]\n", "loss: 1.739829 [12800/60000]\n", "loss: 1.784509 [19200/60000]\n", "loss: 1.677714 [25600/60000]\n", "loss: 1.621924 [32000/60000]\n", "loss: 1.652736 [38400/60000]\n", "loss: 1.549752 [44800/60000]\n", "loss: 1.583215 [51200/60000]\n", "loss: 1.469457 [57600/60000]\n", "Test Error: \n", " Accuracy: 62.0%, Avg loss: 1.491323 \n", "\n", "Epoch 4\n", "-------------------------------\n", "loss: 1.564052 [ 0/60000]\n", "loss: 1.533092 [ 6400/60000]\n", "loss: 1.374619 [12800/60000]\n", "loss: 1.450151 [19200/60000]\n", "loss: 1.340597 [25600/60000]\n", "loss: 1.326336 [32000/60000]\n", "loss: 1.345804 [38400/60000]\n", "loss: 1.269192 [44800/60000]\n", "loss: 1.307673 [51200/60000]\n", "loss: 1.200916 [57600/60000]\n", "Test Error: \n", " Accuracy: 63.8%, Avg loss: 1.232803 \n", "\n", "Epoch 5\n", "-------------------------------\n", 
"loss: 1.311137 [    0/60000]\n", "loss: 1.301159 [ 6400/60000]\n", "loss: 1.127901 [12800/60000]\n", "loss: 1.233908 [19200/60000]\n", "loss: 1.118969 [25600/60000]\n", "loss: 1.134692 [32000/60000]\n", "loss: 1.157277 [38400/60000]\n", "loss: 1.094546 [44800/60000]\n", "loss: 1.135308 [51200/60000]\n", "loss: 1.043909 [57600/60000]\n", "Test Error: \n", " Accuracy: 65.0%, Avg loss: 1.072193 \n", "\n", "Done!\n" ] } ], "source": [ "train_func()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b3df2581", "metadata": {}, "source": [ "The output should look very similar to the previous output." ] }, { "attachments": {}, "cell_type": "markdown", "id": "abe8e708", "metadata": {}, "source": [ "## Starting with Ray Train: Distribute the training\n", "\n", "As a first step, we want to distribute the training across multiple workers. For this, we want to:\n", "\n", "1. Use data-parallel training by sharding the training data\n", "2. Set up the model to communicate gradient updates across machines\n", "3. Report the results back to Ray Train\n", "\n", "To facilitate this, we only need a few changes to the code:\n", "\n", "1. We import Ray Train:\n", "\n", "   ```python\n", "   import ray.train as train\n", "   ```\n", "\n", "2. We use a `config` dict to configure some hyperparameters (this isn't strictly necessary, but it's good practice, especially if you want to do hyperparameter tuning later):\n", "\n", "   ```python\n", "   def train_func(config: dict):\n", "       batch_size = config[\"batch_size\"]\n", "       lr = config[\"lr\"]\n", "       epochs = config[\"epochs\"]\n", "   ```\n", "\n", "3. We dynamically adjust the worker batch size according to the number of workers:\n", "\n", "   ```python\n", "   batch_size_per_worker = batch_size // train.get_context().get_world_size()\n", "   ```\n", "\n", "4. 
We prepare the data loader for distributed data sharding:\n", "\n", "   ```python\n", "   train_dataloader = train.torch.prepare_data_loader(train_dataloader)\n", "   test_dataloader = train.torch.prepare_data_loader(test_dataloader)\n", "   ```\n", "\n", "5. We prepare the model for distributed gradient updates:\n", "\n", "   ```python\n", "   model = train.torch.prepare_model(model)\n", "   ```\n", "   :::{note}\n", "   `train.torch.prepare_model()` also moves the model to the correct device (for example, a GPU), so we can remove the manual device setup from our current code.\n", "   :::\n", "\n", "6. We capture the validation loss and report it to Ray Train:\n", "\n", "   ```python\n", "   test_loss = test_epoch(test_dataloader, model, loss_fn)\n", "   train.report(dict(loss=test_loss))\n", "   ```\n", "\n", "7. In the `test_epoch()` function, we divide `size` by the world size, because each worker only evaluates its own shard of the data:\n", "\n", "   ```python\n", "   # Divide by world size\n", "   size = len(dataloader.dataset) // train.get_context().get_world_size()\n", "   ```\n", "\n", "8. In the `train_epoch()` function, we can get rid of the device mapping, because the data loader prepared by `train.torch.prepare_data_loader()` already moves each batch to the correct device:\n", "\n", "   ```python\n", "   # We don't need this anymore! Ray Train does this automatically:\n", "   # X, y = X.to(device), y.to(device)\n", "   ```\n", "\n", "That's it: you need fewer than 10 lines of Ray Train-specific code and can otherwise keep using your original code.\n", "\n", "Let's take a look at the resulting code. First, the `train_epoch()` function. It now takes the epoch number and, when training with multiple workers, calls `set_epoch()` on the distributed sampler so that the data is shuffled differently in each epoch. We also dropped the device mapping and commented out the print statement:" ] }, { "cell_type": "code", "execution_count": 12, "id": "50b2f602", "metadata": {}, "outputs": [], "source": [ "import ray.train\n", "\n", "def train_epoch(epoch, dataloader, model, loss_fn, optimizer):\n", "    # With multiple workers, prepare_data_loader() adds a DistributedSampler;\n", "    # set_epoch() makes it shuffle differently in every epoch.\n", "    if ray.train.get_context().get_world_size() > 1:\n", "        dataloader.sampler.set_epoch(epoch)\n", "\n", "    model.train()\n", "    for batch, (X, y) in enumerate(dataloader):\n", "        # We don't need this anymore! 
Ray Train does this automatically:\n", "        # X, y = X.to(device), y.to(device)\n", "\n", "        # Compute prediction error\n", "        pred = model(X)\n", "        loss = loss_fn(pred, y)\n", "\n", "        # Backpropagation\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        if batch % 100 == 0:\n", "            loss, current = loss.item(), batch * len(X)\n", "            # print(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6e260f44", "metadata": {}, "source": [ "Then the `test_epoch()` function. We divide `size` by the world size, remove the device mapping (the prepared data loader already places each batch on the correct device), and comment out the print statement:" ] }, { "cell_type": "code", "execution_count": 13, "id": "72aa3e48", "metadata": {}, "outputs": [], "source": [ "def test_epoch(dataloader, model, loss_fn):\n", "    # Divide the dataset size by the world size to get the per-worker dataset size.\n", "    size = len(dataloader.dataset) // ray.train.get_context().get_world_size()\n", "    num_batches = len(dataloader)\n", "    model.eval()\n", "    test_loss, correct = 0, 0\n", "    with torch.no_grad():\n", "        for X, y in dataloader:\n", "            # Not needed: the prepared data loader already moves each batch to the right device.\n", "            # X, y = X.to(device), y.to(device)\n", "            pred = model(X)\n", "            test_loss += loss_fn(pred, y).item()\n", "            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n", "    test_loss /= num_batches\n", "    correct /= size\n", "    # print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")\n", "    return test_loss" ] }, { "attachments": {}, "cell_type": "markdown", "id": "cf280e6a", "metadata": {}, "source": [ "And lastly, the wrapping `train_func()`. Apart from reading the hyperparameters from the `config` dict, it computes the per-worker batch size, prepares the data loaders and the model with Ray Train, and reports the test loss (plus a checkpoint from the rank 0 worker) at the end of every epoch:" ] }, { "cell_type": "code", "execution_count": 14, "id": "3f79c731", "metadata": {}, "outputs": [], "source": [ "import ray.train as train\n", "from ray.train import Checkpoint\n", "\n", "def train_func(config: dict):\n", "    batch_size = config[\"batch_size\"]\n", "    lr = config[\"lr\"]\n", "    epochs = config[\"epochs\"]\n", "\n", "    batch_size_per_worker = 
batch_size // train.get_context().get_world_size()\n", "\n", "    # Create data loaders.\n", "    train_dataloader = DataLoader(\n", "        training_data, batch_size=batch_size_per_worker, shuffle=True\n", "    )\n", "    test_dataloader = DataLoader(test_data, batch_size=batch_size_per_worker)\n", "\n", "    train_dataloader = train.torch.prepare_data_loader(train_dataloader)\n", "    test_dataloader = train.torch.prepare_data_loader(test_dataloader)\n", "\n", "    model = NeuralNetwork()\n", "    model = train.torch.prepare_model(model)\n", "\n", "    loss_fn = nn.CrossEntropyLoss()\n", "    optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n", "\n", "    for epoch in range(epochs):\n", "        train_epoch(epoch, train_dataloader, model, loss_fn, optimizer)\n", "        test_loss = test_epoch(test_dataloader, model, loss_fn)\n", "\n", "        with TemporaryDirectory() as tmpdir:\n", "            # Only the rank 0 worker saves a checkpoint; all workers report the loss.\n", "            if train.get_context().get_world_rank() == 0:\n", "                state_dict = dict(epoch=epoch, model=model.state_dict())\n", "                torch.save(state_dict, os.path.join(tmpdir, \"checkpoint.bin\"))\n", "                checkpoint = Checkpoint.from_directory(tmpdir)\n", "            else:\n", "                checkpoint = None\n", "            train.report(dict(loss=test_loss), checkpoint=checkpoint)\n", "\n", "    print(\"Done!\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "0fc52cc7", "metadata": {}, "source": [ "Now we'll use Ray Train's `TorchTrainer` to kick off the training. Note that we can set the hyperparameters here: the `train_loop_config` dict is passed to `train_func()` as its `config` argument. In the `scaling_config`, we can also configure how many parallel workers to use and whether to enable GPU training."
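, "\n", "\n", "As a sketch, assuming a cluster with at least four GPUs available (the worker count is only illustrative), a GPU-enabled scaling configuration could look like this:\n", "\n", "```python\n", "from ray.train import ScalingConfig\n", "\n", "# Illustrative values: 4 training workers, each reserving 1 GPU.\n", "scaling_config = ScalingConfig(num_workers=4, use_gpu=True)\n", "```\n", "\n", "You would then pass `scaling_config=scaling_config` to `TorchTrainer`. Because `train.torch.prepare_model()` and `train.torch.prepare_data_loader()` take care of device placement, `train_func()` itself doesn't need to change."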
] }, { "cell_type": "code", "execution_count": null, "id": "939e767f", "metadata": {}, "outputs": [], "source": [ "from ray.train.torch import TorchTrainer\n", "from ray.train import ScalingConfig\n", "\n", "\n", "trainer = TorchTrainer(\n", "    train_loop_per_worker=train_func,\n", "    train_loop_config={\"lr\": 1e-3, \"batch_size\": 64, \"epochs\": 4},\n", "    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),\n", ")\n", "result = trainer.fit()\n", "print(f\"Last result: {result.metrics}\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "341f4fd8", "metadata": {}, "source": [ "Great, this works! You're now training your model in parallel, and you could scale this up to more nodes and workers on your Ray cluster.\n", "\n", "But there are a few improvements we can make to the code to get the most out of the system. First, **checkpointing** gives us access to the trained model after training. Second, the **data loading** should take place within the workers." ] }, { "attachments": {}, "cell_type": "markdown", "id": "3bbe06f3", "metadata": {}, "source": [ "### Enabling checkpointing to retrieve the model\n", "Enabling checkpointing is straightforward: at the end of each epoch, we pass a `Checkpoint` object with the model state to the `ray.train.report()` API.\n", "\n", "```python\n", "from ray import train\n", "from ray.train import Checkpoint\n", "\n", "# At the end of each epoch in train_func():\n", "with TemporaryDirectory() as tmpdir:\n", "    torch.save(\n", "        {\n", "            \"epoch\": epoch,\n", "            # With more than one worker, prepare_model() wraps the model in\n", "            # DistributedDataParallel; `.module` is the original model.\n", "            \"model\": model.module.state_dict()\n", "        },\n", "        os.path.join(tmpdir, \"checkpoint.pt\")\n", "    )\n", "    train.report(dict(loss=test_loss), checkpoint=Checkpoint.from_directory(tmpdir))\n", "```\n", "\n", "### Move the data loader to the training function\n", "\n", "You may have noticed a warning: `Warning: The actor TrainTrainable is very large (52 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. 
Tip: use ray.put() to put large objects in the Ray object store.`.\n", "\n", "This warning appears because we load the data outside the training function. Ray then serializes the data so that the remote workers (which may run on other nodes) can access it. This isn't too bad with just 52 MB of data, but with a full image dataset you wouldn't want to ship the data around the cluster unnecessarily. Instead, move the dataset loading into `train_func()`. The data is then downloaded *to disk* once per machine, which makes data loading much more efficient.\n", "\n", "Make sure to wrap the data downloading logic in a `FileLock`; otherwise, you may run into data corruption issues if multiple workers download the data to the same location simultaneously.\n", "\n", "The result looks like this:" ] }, { "cell_type": "code", "execution_count": 16, "id": "059953f8", "metadata": {}, "outputs": [], "source": [ "from filelock import FileLock\n", "\n", "from ray.train import Checkpoint\n", "\n", "def load_data():\n", "    # Hold the lock while downloading so that workers on the same machine\n", "    # don't write to ./data at the same time.\n", "    with FileLock(\"./data.lock\"):\n", "        # Download training data from open datasets.\n", "        training_data = datasets.FashionMNIST(\n", "            root=\"data\",\n", "            train=True,\n", "            download=True,\n", "            transform=ToTensor(),\n", "        )\n", "\n", "        # Download test data from open datasets.\n", "        test_data = datasets.FashionMNIST(\n", "            root=\"data\",\n", "            train=False,\n", "            download=True,\n", "            transform=ToTensor(),\n", "        )\n", "    return training_data, test_data\n", "\n", "\n", "def train_func(config: dict):\n", "    batch_size = config[\"batch_size\"]\n", "    lr = config[\"lr\"]\n", "    epochs = config[\"epochs\"]\n", "\n", "    batch_size_per_worker = batch_size // train.get_context().get_world_size()\n", "\n", "    training_data, test_data = load_data()  # <- this is new!\n", "\n", "    # Create data loaders.\n", "    train_dataloader = DataLoader(\n", "        training_data, batch_size=batch_size_per_worker, shuffle=True\n", "    )\n", "    test_dataloader = DataLoader(test_data, 
batch_size=batch_size_per_worker)\n", "\n", " train_dataloader = train.torch.prepare_data_loader(train_dataloader)\n", " test_dataloader = train.torch.prepare_data_loader(test_dataloader)\n", "\n", " model = NeuralNetwork()\n", " model = train.torch.prepare_model(model)\n", "\n", " loss_fn = nn.CrossEntropyLoss()\n", " optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n", "\n", " for epoch in range(epochs):\n", " train_epoch(epoch, train_dataloader, model, loss_fn, optimizer)\n", " test_loss = test_epoch(test_dataloader, model, loss_fn)\n", " with TemporaryDirectory() as tmpdir:\n", " torch.save(\n", " {\n", " \"epoch\": epoch,\n", " \"model\": model.module.state_dict()\n", " },\n", " os.path.join(tmpdir, \"checkpoint.pt\")\n", " )\n", " train.report(dict(loss=test_loss), checkpoint=Checkpoint.from_directory(tmpdir))\n", "\n", " print(\"Done!\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "d2af219d", "metadata": {}, "source": [ "Let's train again:" ] }, { "cell_type": "code", "execution_count": null, "id": "de249d61", "metadata": {}, "outputs": [], "source": [ "trainer = TorchTrainer(\n", " train_loop_per_worker=train_func,\n", " train_loop_config={\"lr\": 1e-3, \"batch_size\": 64, \"epochs\": 4},\n", " scaling_config=ScalingConfig(num_workers=2, use_gpu=False),\n", ")\n", "result = trainer.fit()\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "534ed4df", "metadata": {}, "source": [ "We can see our results here:" ] }, { "cell_type": "code", "execution_count": 18, "id": "b81ca48f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Last result: {'loss': 1.215654496934004, '_timestamp': 1657734050, '_time_this_iter_s': 10.695234060287476, '_training_iteration': 4, 'time_this_iter_s': 10.697366952896118, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 4, 'trial_id': 'b43fc_00000', 'experiment_id': '3b3c6e36d57a4e7993aacdbe6cd4c8ed', 'date': 
'2022-07-13_10-40-50', 'timestamp': 1657734050, 'time_total_s': 96.68163204193115, 'pid': 65706, 'hostname': 'Jiaos-MacBook-Pro-16-inch-2019', 'node_ip': '127.0.0.1', 'config': {}, 'time_since_restore': 96.68163204193115, 'timesteps_since_restore': 0, 'iterations_since_restore': 4, 'warmup_time': 0.0036132335662841797, 'experiment_tag': '0'}\n", "Checkpoint: \n" ] } ], "source": [ "print(f\"Last result: {result.metrics}\")\n", "print(f\"Checkpoint: {result.checkpoint}\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2963e1f7", "metadata": {}, "source": [ "## Summary\n", "\n", "This tutorial demonstrated how to turn your existing PyTorch code into code you can use with Ray Train.\n", "\n", "We learned how to:\n", "- enable distributed training using Ray Train abstractions\n", "- save and retrieve model checkpoints via Ray Train\n", "\n", "In our {ref}`other examples `, you can learn how to do more with the Ray libraries, such as **serving your model with Ray Serve** or **tuning your hyperparameters with Ray Tune**. You can also learn how to perform {ref}`offline batch inference ` with Ray Data.\n", "\n", "We hope this tutorial gave you a good starting point to leverage Ray Train. If you have any questions or suggestions, or run into any problems, please reach out on [Discuss](https://discuss.ray.io/) or [GitHub](https://github.com/ray-project/ray)!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" }, "orphan": true, "vscode": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } } }, "nbformat": 4, "nbformat_minor": 5 }