Training a Torch Image Classifier#

This tutorial shows you how to train an image classifier using the Ray AI Runtime (AIR).

You should be familiar with PyTorch before starting the tutorial. If you need a refresher, read PyTorch’s Training a Classifier tutorial.

Before you begin#

  • Install the Ray AI Runtime. You need Ray 2.0 or later to run this example.

!pip install 'ray[air]'
  • Install requests, torch, and torchvision.

!pip install requests torch torchvision

Load and normalize CIFAR-10#

We’ll train our classifier on a popular image dataset called CIFAR-10.

First, let’s load CIFAR-10 into a Ray Dataset.

import ray
import torchvision
import torchvision.transforms as transforms

train_dataset = torchvision.datasets.CIFAR10("data", download=True, train=True)
test_dataset = torchvision.datasets.CIFAR10("data", download=True, train=False)

train_dataset: ray.data.Dataset = ray.data.from_torch(train_dataset)
test_dataset: ray.data.Dataset = ray.data.from_torch(test_dataset)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 170498071/170498071 [00:21<00:00, 7792736.24it/s]
Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
2022-10-23 10:33:48,403	INFO worker.py:1518 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
train_dataset

from_torch doesn’t parallelize reads, so you shouldn’t use it with larger datasets.
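
For larger datasets, prefer one of Ray Data's parallel readers. As an illustrative sketch (assuming a newer Ray release where ray.data.read_images is available; the S3 path is hypothetical):

import ray

# read_images distributes file reads across the cluster instead of funneling
# everything through a single process like from_torch does.
larger_dataset = ray.data.read_images("s3://my-bucket/images/")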

Next, let’s represent our data using a dictionary of ndarrays instead of tuples. This lets us call Dataset.iter_torch_batches later in the tutorial.

from typing import Dict, List, Tuple
import numpy as np
from PIL.Image import Image


def convert_batch_to_numpy(batch: List[Tuple[Image, int]]) -> Dict[str, np.ndarray]:
    images = np.stack([np.array(image) for image, _ in batch])
    labels = np.array([label for _, label in batch])
    return {"image": images, "label": labels}


train_dataset = train_dataset.map_batches(convert_batch_to_numpy).cache()
test_dataset = test_dataset.map_batches(convert_batch_to_numpy).cache()
Read->Map_Batches:   0%|          | 0/1 [00:00<?, ?it/s]
(_map_block_nosplit pid=3958) Files already downloaded and verified
Read->Map_Batches: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:04<00:00,  4.27s/it]
Read->Map_Batches:   0%|          | 0/1 [00:00<?, ?it/s]
(_map_block_nosplit pid=3958) Files already downloaded and verified
Read->Map_Batches: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:01<00:00,  1.40s/it]
train_dataset
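
As a quick sanity check (an illustrative snippet, not part of the original pipeline), pull one batch with iter_torch_batches and confirm the shapes:

import torch

# Images are channels-last (N, 32, 32, 3) uint8 tensors at this point;
# the preprocessor defined later converts them to normalized CHW floats.
batch = next(iter(train_dataset.iter_torch_batches(batch_size=4)))
print(batch["image"].shape)  # torch.Size([4, 32, 32, 3])
print(batch["label"].shape)  # torch.Size([4])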

Train a convolutional neural network#

Now that we’ve created our datasets, let’s define the training logic.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
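
To confirm the architecture wires up correctly (an illustrative check), you can run a random batch through the network:

# A batch of four 3x32x32 images should produce a (4, 10) tensor of class scores.
net = Net()
dummy_input = torch.randn(4, 3, 32, 32)
print(net(dummy_input).shape)  # torch.Size([4, 10])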

We define our training logic in a function called train_loop_per_worker. This function contains regular PyTorch code with a few notable exceptions:

  • We wrap the model with train.torch.prepare_model, which moves it to the correct device and prepares it for distributed training.
  • Instead of using a Torch DataLoader, we call session.get_dataset_shard and iter_torch_batches to iterate over this worker’s shard of the training data.
  • We report metrics and save the model state with session.report and a TorchCheckpoint.

from ray import train
from ray.air import session
from ray.train.torch import TorchCheckpoint
import torch.nn as nn
import torch.optim as optim


def train_loop_per_worker(config):
    model = train.torch.prepare_model(Net())

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    train_dataset_shard = session.get_dataset_shard("train")

    for epoch in range(2):  # loop over the dataset multiple times
        running_loss = 0.0
        train_dataset_batches = train_dataset_shard.iter_torch_batches(
            batch_size=config["batch_size"],
        )
        for i, batch in enumerate(train_dataset_batches):
            # get the inputs and labels
            inputs, labels = batch["image"], batch["label"]

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0

        metrics = dict(running_loss=running_loss)
        checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())
        session.report(metrics, checkpoint=checkpoint)

To improve our model’s accuracy, we’ll also define a Preprocessor to normalize the images.

from ray.data.preprocessors import TorchVisionPreprocessor

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
preprocessor = TorchVisionPreprocessor(columns=["image"], transform=transform)
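
ToTensor scales pixel values to [0, 1], and Normalize then maps each channel value v to (v - 0.5) / 0.5, shifting the data into the [-1, 1] range. An illustrative check:

import numpy as np

# An all-white uint8 image maps to 1.0 everywhere after normalization.
white = np.full((32, 32, 3), 255, dtype=np.uint8)
print(transform(white).min().item(), transform(white).max().item())  # 1.0 1.0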

Finally, we can train our model. This should take a few minutes to run.

from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"batch_size": 2},
    datasets={"train": train_dataset},
    scaling_config=ScalingConfig(num_workers=2),
    preprocessor=preprocessor,
)
result = trainer.fit()
latest_checkpoint = result.checkpoint
== Status ==
Current time: 2022-08-30 15:31:37 (running for 00:00:45.17)
Memory usage on this node: 16.9/32.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/10 CPUs, 0/0 GPUs, 0.0/14.83 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/bveeramani/ray_results/TorchTrainer_2022-08-30_15-30-52
Number of trials: 1/1 (1 TERMINATED)
+--------------------------+------------+----------------+--------+------------------+----------------+--------------+---------------------+
| Trial name               | status     | loc            |   iter |   total time (s) |   running_loss |   _timestamp |   _time_this_iter_s |
|--------------------------+------------+----------------+--------+------------------+----------------+--------------+---------------------|
| TorchTrainer_6799a_00000 | TERMINATED | 127.0.0.1:3978 |      2 |          43.7121 |        595.445 |   1661898697 |             20.8503 |
+--------------------------+------------+----------------+--------+------------------+----------------+--------------+---------------------+


(RayTrainWorker pid=3979) 2022-08-30 15:30:54,566	INFO config.py:71 -- Setting up process group for: env:// [rank=0, world_size=2]
(RayTrainWorker pid=3979) 2022-08-30 15:30:55,727	INFO train_loop_utils.py:300 -- Moving model to device: cpu
(RayTrainWorker pid=3979) 2022-08-30 15:30:55,728	INFO train_loop_utils.py:347 -- Wrapping provided model in DDP.
(RayTrainWorker pid=3980) [1,  2000] loss: 2.276
(RayTrainWorker pid=3979) [1,  2000] loss: 2.270
(RayTrainWorker pid=3980) [1,  4000] loss: 1.964
(RayTrainWorker pid=3979) [1,  4000] loss: 1.936
(RayTrainWorker pid=3980) [1,  6000] loss: 1.753
(RayTrainWorker pid=3979) [1,  6000] loss: 1.754
(RayTrainWorker pid=3980) [1,  8000] loss: 1.638
(RayTrainWorker pid=3979) [1,  8000] loss: 1.661
(RayTrainWorker pid=3980) [1, 10000] loss: 1.586
(RayTrainWorker pid=3979) [1, 10000] loss: 1.547
(RayTrainWorker pid=3980) [1, 12000] loss: 1.489
(RayTrainWorker pid=3979) [1, 12000] loss: 1.476
Result for TorchTrainer_6799a_00000:
  _time_this_iter_s: 20.542800188064575
  _timestamp: 1661898676
  _training_iteration: 1
  date: 2022-08-30_15-31-16
  done: false
  experiment_id: c25700542bc348dbbeaf54e46f1fc84c
  hostname: MBP.local.meter
  iterations_since_restore: 1
  node_ip: 127.0.0.1
  pid: 3978
  running_loss: 687.5853321105242
  should_checkpoint: true
  time_since_restore: 22.880314111709595
  time_this_iter_s: 22.880314111709595
  time_total_s: 22.880314111709595
  timestamp: 1661898676
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 6799a_00000
  warmup_time: 0.0025300979614257812
  
(RayTrainWorker pid=3980) [2,  2000] loss: 1.417
(RayTrainWorker pid=3979) [2,  2000] loss: 1.431
(RayTrainWorker pid=3980) [2,  4000] loss: 1.403
(RayTrainWorker pid=3979) [2,  4000] loss: 1.404
(RayTrainWorker pid=3980) [2,  6000] loss: 1.394
(RayTrainWorker pid=3979) [2,  6000] loss: 1.368
(RayTrainWorker pid=3980) [2,  8000] loss: 1.343
(RayTrainWorker pid=3979) [2,  8000] loss: 1.363
(RayTrainWorker pid=3980) [2, 10000] loss: 1.340
(RayTrainWorker pid=3979) [2, 10000] loss: 1.297
(RayTrainWorker pid=3980) [2, 12000] loss: 1.253
(RayTrainWorker pid=3979) [2, 12000] loss: 1.276
Result for TorchTrainer_6799a_00000:
  _time_this_iter_s: 20.850306034088135
  _timestamp: 1661898697
  _training_iteration: 2
  date: 2022-08-30_15-31-37
  done: false
  experiment_id: c25700542bc348dbbeaf54e46f1fc84c
  hostname: MBP.local.meter
  iterations_since_restore: 2
  node_ip: 127.0.0.1
  pid: 3978
  running_loss: 595.4451928250492
  should_checkpoint: true
  time_since_restore: 43.71214985847473
  time_this_iter_s: 20.831835746765137
  time_total_s: 43.71214985847473
  timestamp: 1661898697
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: 6799a_00000
  warmup_time: 0.0025300979614257812
  
Result for TorchTrainer_6799a_00000:
  _time_this_iter_s: 20.850306034088135
  _timestamp: 1661898697
  _training_iteration: 2
  date: 2022-08-30_15-31-37
  done: true
  experiment_id: c25700542bc348dbbeaf54e46f1fc84c
  experiment_tag: '0'
  hostname: MBP.local.meter
  iterations_since_restore: 2
  node_ip: 127.0.0.1
  pid: 3978
  running_loss: 595.4451928250492
  should_checkpoint: true
  time_since_restore: 43.71214985847473
  time_this_iter_s: 20.831835746765137
  time_total_s: 43.71214985847473
  timestamp: 1661898697
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: 6799a_00000
  warmup_time: 0.0025300979614257812
  
2022-08-30 15:31:37,386	INFO tune.py:758 -- Total run time: 45.32 seconds (45.16 seconds for the tuning loop).

To scale your training script, create a Ray Cluster and increase the number of workers. If your cluster contains GPUs, add use_gpu=True to your scaling config.

scaling_config=ScalingConfig(num_workers=8, use_gpu=True)

Test the network on the test data#

Let’s see how our model performs.

To classify images in the test dataset, we’ll need to create a Predictor.

Predictors load data from checkpoints and efficiently perform inference. In contrast to TorchPredictor, which performs inference on a single batch, BatchPredictor performs inference on an entire dataset. Because we want to classify all of the images in the test dataset, we’ll use a BatchPredictor.

from ray.train.torch import TorchPredictor
from ray.train.batch_predictor import BatchPredictor

batch_predictor = BatchPredictor.from_checkpoint(
    checkpoint=latest_checkpoint,
    predictor_cls=TorchPredictor,
    model=Net(),
)

outputs: ray.data.Dataset = batch_predictor.predict(
    data=test_dataset,
    dtype=torch.float,
    feature_columns=["image"],
    keep_columns=["label"],
)
Map Progress (1 actors 1 pending): 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:01<00:00,  1.59s/it]

Our model outputs an energy for each of the ten classes. To classify an image, we choose the class with the highest energy.

import numpy as np


def convert_logits_to_classes(df):
    best_class = df["predictions"].map(lambda x: x.argmax())
    df["prediction"] = best_class
    return df[["prediction", "label"]]


predictions = outputs.map_batches(convert_logits_to_classes)

predictions.show(1)
Map_Batches: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:00<00:00, 59.42it/s]
{'prediction': 3, 'label': 3}
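
CIFAR-10 labels are integers from 0 through 9. To make predictions easier to read, you can map them through the standard CIFAR-10 class list (an illustrative snippet):

# Standard CIFAR-10 class names, indexed by label.
classes = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

for row in predictions.take(3):
    print(f"predicted: {classes[row['prediction']]}, actual: {classes[row['label']]}")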

Now that we’ve classified all of the images, let’s figure out which ones were classified correctly. Because we kept the "label" column during batch prediction, the predictions dataset contains both the predicted labels and the true labels, so we can compare them directly.

def calculate_prediction_scores(df):
    df["correct"] = df["prediction"] == df["label"]
    return df


scores = predictions.map_batches(calculate_prediction_scores)

scores.show(1)
Map_Batches: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:00<00:00, 132.06it/s]
{'prediction': 3, 'label': 3, 'correct': True}

To compute our test accuracy, we’ll count how many images the model classified correctly and divide that number by the total number of test images.

scores.sum(on="correct") / scores.count()
Shuffle Map: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:00<00:00, 152.00it/s]
Shuffle Reduce: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [00:00<00:00, 219.54it/s]
0.557

Deploy the network and make a prediction#

Our model seems to perform decently, so let’s deploy it to an endpoint. This lets us make predictions over HTTP.

from ray import serve
from ray.serve import PredictorDeployment
from ray.serve.http_adapters import json_to_ndarray


serve.run(
    PredictorDeployment.bind(
        TorchPredictor,
        latest_checkpoint,
        model=Net(),
        http_adapter=json_to_ndarray,
    )
)
(ServeController pid=3987) INFO 2022-08-30 15:31:39,948 controller 3987 http_state.py:129 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-4b114e48c80d3549aa5da89fa16707e0334a0bafde984fd8b8618e47' on node '4b114e48c80d3549aa5da89fa16707e0334a0bafde984fd8b8618e47' listening on '127.0.0.1:8000'
(HTTPProxyActor pid=3988) INFO:     Started server process [3988]
(ServeController pid=3987) INFO 2022-08-30 15:31:40,567 controller 3987 deployment_state.py:1232 - Adding 1 replica to deployment 'PredictorDeployment'.
RayServeSyncHandle(deployment='PredictorDeployment')

Let’s classify a test image.

image = test_dataset.take(1)[0]["image"]

You can perform inference against a deployed model by posting a dictionary with an "array" key. To learn more about the default input schema, read the NdArray documentation.

import requests

payload = {"array": image.tolist(), "dtype": "float32"}
response = requests.post("http://localhost:8000/", json=payload)
response.json()
[-1.1342155933380127,
 -1.854529857635498,
 1.2062205076217651,
 2.6219608783721924,
 0.5199968218803406,
 2.2016565799713135,
 0.9447429180145264,
 -0.5387609004974365,
 -1.9515650272369385,
 -1.676588773727417]
(HTTPProxyActor pid=3988) INFO 2022-08-30 15:31:41,713 http_proxy 127.0.0.1 http_proxy.py:315 - POST / 200 12.9ms
(ServeReplica:PredictorDeployment pid=3995) INFO 2022-08-30 15:31:41,712 PredictorDeployment PredictorDeployment#pTPSPE replica.py:482 - HANDLE __call__ OK 9.9ms
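
The response is a list of ten class energies, so the predicted class is the index of the largest value; here that’s index 3, which corresponds to "cat". When you’re finished, you can tear down the deployment with serve.shutdown():

import numpy as np

# The largest energy in the response above is at index 3.
print(np.argmax(response.json()))  # 3

# Tear down all Serve deployments and the Serve controller.
serve.shutdown()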