Training a Torch Classifier

This tutorial demonstrates how to train an image classifier using the Ray AI Runtime (AIR).

You should be familiar with PyTorch before starting the tutorial. If you need a refresher, read PyTorch’s training a classifier tutorial.

Before you begin

  • Install the Ray AI Runtime. You’ll need Ray 1.13 or later to run this example.

!pip install 'ray[air]'
  • Install requests, torch, and torchvision.

!pip install requests torch torchvision

Load and normalize CIFAR-10

We’ll train our classifier on a popular image dataset called CIFAR-10.

First, let’s load CIFAR-10 into a Ray Dataset.

import ray
from ray.data.datasource import SimpleTorchDatasource
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

def train_dataset_factory():
    return torchvision.datasets.CIFAR10(root="./data", download=True, train=True, transform=transform)

def test_dataset_factory():
    return torchvision.datasets.CIFAR10(root="./data", download=True, train=False, transform=transform)

train_dataset: ray.data.Dataset = ray.data.read_datasource(
    SimpleTorchDatasource(), dataset_factory=train_dataset_factory
)
test_dataset: ray.data.Dataset = ray.data.read_datasource(
    SimpleTorchDatasource(), dataset_factory=test_dataset_factory
)
2022-05-26 14:49:27,034	INFO services.py:1477 -- View the Ray dashboard at http://127.0.0.1:8265
2022-05-26 14:49:29,044	WARNING read_api.py:253 -- The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.
(_prepare_read pid=13653) 2022-05-26 14:49:29,041	WARNING torch_datasource.py:55 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.
(_execute_read_task pid=13653) Files already downloaded and verified
2022-05-26 14:49:46,308	WARNING read_api.py:253 -- The number of blocks in this dataset (1) limits its parallelism to 1 concurrent tasks. This is much less than the number of available CPU slots in the cluster. Use `.repartition(n)` to increase the number of dataset blocks.
(_prepare_read pid=13653) 2022-05-26 14:49:46,305	WARNING torch_datasource.py:55 -- `SimpleTorchDatasource` doesn't support parallel reads. The `parallelism` argument will be ignored.
(_execute_read_task pid=13653) Files already downloaded and verified
train_dataset
Dataset(num_blocks=1, num_rows=50000, schema=<class 'tuple'>)

Note that SimpleTorchDatasource loads all data into memory, so you shouldn’t use it with larger datasets.
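
The read warnings above also point out that the loaded dataset has only one block, which limits parallelism. If you have spare CPU slots, you can follow the warning’s suggestion and repartition the dataset; here’s a minimal sketch (the block count of 8 is an arbitrary assumption):

# Optional: split the single block into several blocks so that downstream
# operations like map_batches can run in parallel. 8 is an assumed value;
# adjust it to the number of CPU slots in your cluster.
train_dataset = train_dataset.repartition(8)
test_dataset = test_dataset.repartition(8)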

Next, let’s represent our data using pandas dataframes instead of tuples. This lets us call methods like Dataset.to_torch later in the tutorial.

from typing import Tuple
import pandas as pd
from ray.data.extensions import TensorArray
import torch


def convert_batch_to_pandas(batch: Tuple[torch.Tensor, int]) -> pd.DataFrame:
    # Each batch element is an (image, label) tuple. Wrap each image in a
    # TensorArray so the multi-dimensional data can be stored in a pandas column.
    images = [TensorArray(image.numpy()) for image, _ in batch]
    labels = [label for _, label in batch]

    df = pd.DataFrame({"image": images, "label": labels})

    return df


train_dataset = train_dataset.map_batches(convert_batch_to_pandas)
test_dataset = test_dataset.map_batches(convert_batch_to_pandas)
[dataset]: Run `pip install tqdm` to enable progress reporting.
(_map_block_nosplit pid=13653) Files already downloaded and verified
(_map_block_nosplit pid=13653) Files already downloaded and verified
train_dataset
Dataset(num_blocks=1, num_rows=50000, schema={image: object, label: int64})

Train a convolutional neural network

Now that we’ve created our datasets, let’s define the training logic.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

We define our training logic in a function called train_loop_per_worker.

train_loop_per_worker contains regular PyTorch code with a few notable exceptions:

  • We wrap our model with train.torch.prepare_model.
  • We call train.get_dataset_shard and Dataset.to_torch to convert a shard of the training data to a Torch dataset.
  • We save model state with train.save_checkpoint.

from ray import train
import torch.optim as optim


def train_loop_per_worker(config):
    model = train.torch.prepare_model(Net())

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    train_dataset_shard: torch.utils.data.Dataset = train.get_dataset_shard("train").to_torch(
        feature_columns=["image"],
        label_column="label",
        batch_size=config["batch_size"],
        unsqueeze_feature_tensors=False,
        unsqueeze_label_tensor=False
    )

    for epoch in range(2):
        running_loss = 0.0
        for i, data in enumerate(train_dataset_shard):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0

        train.save_checkpoint(model=model.module.state_dict())

Finally, we can train our model. This should take a few minutes to run.

from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"batch_size": 2},
    datasets={"train": train_dataset},
    scaling_config={"num_workers": 2}
)
result = trainer.fit()
latest_checkpoint = result.checkpoint
== Status ==
Current time: 2022-05-26 14:52:09 (running for 00:02:01.90)
Memory usage on this node: 16.6/64.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/44.98 GiB heap, 0.0/2.0 GiB objects
Result logdir: /ray_results/TorchTrainer_2022-05-26_14-50-07
Number of trials: 1/1 (1 TERMINATED)
Trial name                 status       loc
TorchTrainer_cf234_00000   TERMINATED   127.0.0.1:13741


(BaseWorkerMixin pid=13750) 2022-05-26 14:50:12,654	INFO torch.py:346 -- Setting up process group for: env:// [rank=1, world_size=2]
(BaseWorkerMixin pid=13750) [W ProcessGroupGloo.cpp:715] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
(BaseWorkerMixin pid=13749) 2022-05-26 14:50:12,652	INFO torch.py:346 -- Setting up process group for: env:// [rank=0, world_size=2]
(BaseWorkerMixin pid=13749) [W ProcessGroupGloo.cpp:715] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
(BaseWorkerMixin pid=13750) 2022-05-26 14:50:16,045	INFO torch.py:98 -- Moving model to device: cpu
(BaseWorkerMixin pid=13750) 2022-05-26 14:50:16,045	INFO torch.py:132 -- Wrapping provided model in DDP.
(BaseWorkerMixin pid=13749) 2022-05-26 14:50:16,045	INFO torch.py:98 -- Moving model to device: cpu
(BaseWorkerMixin pid=13749) 2022-05-26 14:50:16,045	INFO torch.py:132 -- Wrapping provided model in DDP.
(BaseWorkerMixin pid=13750) /GitHub/ray/python/ray/ml/utils/torch_utils.py:64: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /Users/distiller/project/pytorch/torch/csrc/utils/tensor_numpy.cpp:178.)
(BaseWorkerMixin pid=13750)   return torch.as_tensor(vals, dtype=dtype)
(BaseWorkerMixin pid=13749) /GitHub/ray/python/ray/ml/utils/torch_utils.py:64: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /Users/distiller/project/pytorch/torch/csrc/utils/tensor_numpy.cpp:178.)
(BaseWorkerMixin pid=13749)   return torch.as_tensor(vals, dtype=dtype)
(BaseWorkerMixin pid=13750) [1,  2000] loss: 2.208
(BaseWorkerMixin pid=13749) [1,  2000] loss: 2.198
(BaseWorkerMixin pid=13750) [1,  4000] loss: 1.906
(BaseWorkerMixin pid=13749) [1,  4000] loss: 1.876
(BaseWorkerMixin pid=13750) [1,  6000] loss: 1.718
(BaseWorkerMixin pid=13749) [1,  6000] loss: 1.736
(BaseWorkerMixin pid=13750) [1,  8000] loss: 1.641
(BaseWorkerMixin pid=13749) [1,  8000] loss: 1.658
(BaseWorkerMixin pid=13750) [1, 10000] loss: 1.586
(BaseWorkerMixin pid=13749) [1, 10000] loss: 1.547
(BaseWorkerMixin pid=13750) [1, 12000] loss: 1.488
(BaseWorkerMixin pid=13749) [1, 12000] loss: 1.494
(BaseWorkerMixin pid=13750) [2,  2000] loss: 1.417
(BaseWorkerMixin pid=13749) [2,  2000] loss: 1.452
(BaseWorkerMixin pid=13750) [2,  4000] loss: 1.413
(BaseWorkerMixin pid=13749) [2,  4000] loss: 1.409
(BaseWorkerMixin pid=13750) [2,  6000] loss: 1.397
(BaseWorkerMixin pid=13749) [2,  6000] loss: 1.372
(BaseWorkerMixin pid=13750) [2,  8000] loss: 1.361
(BaseWorkerMixin pid=13749) [2,  8000] loss: 1.382
(BaseWorkerMixin pid=13750) [2, 10000] loss: 1.339
(BaseWorkerMixin pid=13749) [2, 10000] loss: 1.309
(BaseWorkerMixin pid=13750) [2, 12000] loss: 1.276
(BaseWorkerMixin pid=13749) [2, 12000] loss: 1.285
2022-05-26 14:52:09,873	ERROR checkpoint_manager.py:189 -- Result dict has no key: training_iteration. checkpoint_score_attr must be set to a key of the result dict. Valid keys are ['trial_id', 'experiment_id', 'date', 'timestamp', 'pid', 'hostname', 'node_ip', 'config', 'done']
Trial TorchTrainer_cf234_00000 completed. Last result: 
2022-05-26 14:52:09,986	INFO tune.py:752 -- Total run time: 122.04 seconds (121.90 seconds for the tuning loop).

To scale your training script, create a Ray Cluster and increase the number of workers. If your cluster contains GPUs, add "use_gpu": True to your scaling config.

scaling_config={"num_workers": 8, "use_gpu": True}

Test the network on the test data

Let’s see how our model performs.

To classify images in the test dataset, we’ll need to create a Predictor.

Predictors load models from checkpoints and efficiently perform inference. In contrast to TorchPredictor, which performs inference on a single batch, BatchPredictor performs inference on an entire dataset. Because we want to classify all of the images in the test dataset, we’ll use a BatchPredictor.

from ray.train.torch import TorchPredictor
from ray.train.batch_predictor import BatchPredictor

batch_predictor = BatchPredictor.from_checkpoint(
    checkpoint=latest_checkpoint,
    predictor_cls=TorchPredictor,
    model=Net(),
)

outputs: ray.data.Dataset = batch_predictor.predict(
    data=test_dataset, feature_columns=["image"], unsqueeze=False
)
(BlockWorker pid=13962) /GitHub/ray/python/ray/ml/utils/torch_utils.py:64: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /Users/distiller/project/pytorch/torch/csrc/utils/tensor_numpy.cpp:178.)
(BlockWorker pid=13962)   return torch.as_tensor(vals, dtype=dtype)

Our model outputs a list of energies for each class. To classify an image, we choose the class that has the highest energy.

import numpy as np

def convert_logits_to_classes(df):
    best_class = df["predictions"].map(lambda x: x.argmax())
    df["prediction"] = best_class
    return df[["prediction"]]

predictions = outputs.map_batches(
    convert_logits_to_classes, batch_format="pandas"
)

predictions.show(1)
{'prediction': 3}

Now that we’ve classified all of the images, let’s figure out which images were classified correctly. The predictions dataset contains predicted labels and the test_dataset contains the true labels. To determine whether an image was classified correctly, we join the two datasets and check if the predicted labels are the same as the actual labels.

def calculate_prediction_scores(df):
    df["correct"] = df["prediction"] == df["label"]
    return df[["prediction", "label", "correct"]]

scores = test_dataset.zip(predictions).map_batches(calculate_prediction_scores)

scores.show(1)
{'prediction': 3, 'label': 3, 'correct': True}

To compute our test accuracy, we’ll count how many images the model classified correctly and divide that number by the total number of test images.

scores.sum(on="correct") / scores.count()
0.5531

Deploy the network and make a prediction

Our model seems to perform decently, so let’s deploy the model to an endpoint. This’ll allow us to make predictions over the Internet.

from ray import serve
from ray.serve.model_wrappers import ModelWrapperDeployment

serve.start(detached=True)
deployment = ModelWrapperDeployment.options(name="my-deployment")
deployment.deploy(TorchPredictor, latest_checkpoint, batching_params=False, model=Net())
(ServeController pid=13967) INFO 2022-05-26 14:52:14,630 controller 13967 checkpoint_path.py:17 - Using RayInternalKVStore for controller checkpoint and recovery.
(ServeController pid=13967) INFO 2022-05-26 14:52:14,633 controller 13967 http_state.py:112 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-node:127.0.0.1-0' on node 'node:127.0.0.1-0' listening on '127.0.0.1:8000'
(HTTPProxyActor pid=13969) INFO:     Started server process [13969]
(ServeController pid=13967) INFO 2022-05-26 14:52:16,241 controller 13967 deployment_state.py:1218 - Adding 1 replicas to deployment 'my-deployment'.

Let’s classify a test image.

batch = test_dataset.take(1)
array = np.expand_dims(np.array(batch[0]["image"]), axis=0)
array.shape
(1, 3, 32, 32)

You can perform inference against a deployed model by posting a dictionary with an "array" key. To learn more about the default input schema, read the NdArray documentation.

import requests

payload = {"array": array.tolist()}
response = requests.post(deployment.url, json=payload)
response.json()
{'predictions': {'0': [-1.1721627712249756,
   -1.2344744205474854,
   -0.0395149365067482,
   2.5982346534729004,
   -0.7517635822296143,
   1.6971060037612915,
   -0.27467942237854004,
   -0.8857517242431641,
   1.4102720022201538,
   -1.8619050979614258]}}
(HTTPProxyActor pid=13969) INFO 2022-05-26 14:52:18,593 http_proxy 127.0.0.1 http_proxy.py:315 - POST /my-deployment 307 4.9ms
(HTTPProxyActor pid=13969) INFO 2022-05-26 14:52:18,616 http_proxy 127.0.0.1 http_proxy.py:315 - POST /my-deployment 200 20.6ms
(my-deployment pid=13971) INFO 2022-05-26 14:52:18,591 my-deployment my-deployment#HdSekn replica.py:478 - HANDLE __call__ OK 0.3ms
(my-deployment pid=13971) INFO 2022-05-26 14:52:18,615 my-deployment my-deployment#HdSekn replica.py:478 - HANDLE __call__ OK 17.5ms
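
As before, the class with the highest energy is the prediction. Here’s a small sketch that extracts it from the JSON response shown above (assuming the response keeps the {"predictions": {"0": [...]}} format):

# The response maps each input row ("0" here) to its list of class energies.
energies = np.array(response.json()["predictions"]["0"])
predicted_class = int(energies.argmax())
predicted_class
# 3, matching the prediction we computed earlier with BatchPredictor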