{ "cells": [ { "cell_type": "markdown", "id": "c6962854", "metadata": {}, "source": [ "# Training a Torch Image Classifier\n", "\n", "This tutorial shows you how to train an image classifier using the [Ray AI Runtime](air) (AIR).\n", "\n", "You should be familiar with [PyTorch](https://pytorch.org/) before starting the tutorial. If you need a refresher, read PyTorch's [training a classifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) tutorial.\n", "\n", "## Before you begin\n", "\n", "* Install the [Ray AI Runtime](air). You need Ray 2.0 or later to run this example." ] }, { "cell_type": "code", "execution_count": 1, "id": "d806ba6b", "metadata": {}, "outputs": [], "source": [ "!pip install 'ray[air]'" ] }, { "cell_type": "markdown", "id": "6d588ce2", "metadata": {}, "source": [ "* Install `requests`, `torch`, and `torchvision`." ] }, { "cell_type": "code", "execution_count": 2, "id": "77a70a7a", "metadata": {}, "outputs": [], "source": [ "!pip install requests torch torchvision" ] }, { "cell_type": "markdown", "id": "f18ec14f", "metadata": {}, "source": [ "## Load and normalize CIFAR-10\n", "\n", "We'll train our classifier on a popular image dataset called [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html).\n", "\n", "First, let's load CIFAR-10 into a Ray Dataset." ] }, { "cell_type": "code", "execution_count": 3, "id": "d3f2e890", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 170498071/170498071 [00:21<00:00, 7792736.24it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Extracting data/cifar-10-python.tar.gz to data\n", "Files already downloaded and verified\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2022-10-23 10:33:48,403\tINFO worker.py:1518 -- Started a local Ray instance. 
View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n" ] } ], "source": [ "import ray\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "\n", "train_dataset = torchvision.datasets.CIFAR10(\"data\", download=True, train=True)\n", "test_dataset = torchvision.datasets.CIFAR10(\"data\", download=True, train=False)\n", "\n", "train_dataset: ray.data.Dataset = ray.data.from_torch(train_dataset)\n", "test_dataset: ray.data.Dataset = ray.data.from_torch(test_dataset)" ] }, { "cell_type": "code", "execution_count": 4, "id": "a2e7db56", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5d97a30cd75b40208a984ffa63cfecff", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HTML(value='
| Trial name | status | loc | iter | total time (s) | running_loss | _timestamp | _time_this_iter_s |
|---|---|---|---|---|---|---|---|
| TorchTrainer_6799a_00000 | TERMINATED | 127.0.0.1:3978 | 2 | 43.7121 | 595.445 | 1661898697 | 20.8503 |