Convert existing PyTorch code to Ray AIR

If you already have working PyTorch code, you don’t have to start from scratch to utilize the benefits of Ray AIR. Instead, you can continue to use your existing code and incrementally add Ray AIR components as needed.

Some of the benefits you’ll get by using Ray AIR with your existing PyTorch training code:

  • Easy distributed data-parallel training on a cluster

  • Automatic checkpointing/fault tolerance and result tracking

  • Parallel data preprocessing

  • Seamless integration with hyperparameter tuning

  • Scalable batch prediction

  • Scalable model serving

This tutorial will show you how to start with Ray AIR from your existing PyTorch training code. You will learn how to distribute your training and run scalable batch prediction.

The example code

The example code we’ll be using is that of the PyTorch quickstart tutorial. This code trains a neural network classifier on the FashionMNIST dataset.

You can find the code we used for this tutorial here on GitHub.

Unmodified

Let’s start with the unmodified code from the example. A thorough explanation of the parts is given in the full tutorial - we’ll just focus on the code here.

We start with some imports:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

Then we download the data:

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

We can now define the dataloaders:

batch_size = 64

# Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

We can then define and instantiate the neural network:

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)
Using cpu device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

Define our optimizer and loss:

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

And finally our training loop. Note that we renamed the function from train to train_epoch to avoid conflicts with the Ray Train module later (which is also called train):

def train_epoch(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

And while we’re at it, here is our validation loop (note that we sneaked in a return test_loss statement and also renamed the function):

def test_epoch(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss

Now we can trigger training and save a model:

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_epoch(train_dataloader, model, loss_fn, optimizer)
    test_epoch(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.310824  [    0/60000]
loss: 2.294673  [ 6400/60000]
loss: 2.279552  [12800/60000]
loss: 2.271212  [19200/60000]
loss: 2.253037  [25600/60000]
loss: 2.237454  [32000/60000]
loss: 2.229266  [38400/60000]
loss: 2.201499  [44800/60000]
loss: 2.204318  [51200/60000]
loss: 2.165509  [57600/60000]
Test Error: 
 Accuracy: 53.3%, Avg loss: 2.165069 

Epoch 2
-------------------------------
loss: 2.172770  [    0/60000]
loss: 2.162628  [ 6400/60000]
loss: 2.114871  [12800/60000]
loss: 2.129640  [19200/60000]
loss: 2.072733  [25600/60000]
loss: 2.029744  [32000/60000]
loss: 2.044677  [38400/60000]
loss: 1.968758  [44800/60000]
loss: 1.982601  [51200/60000]
loss: 1.903552  [57600/60000]
Test Error: 
 Accuracy: 56.8%, Avg loss: 1.906084 

Epoch 3
-------------------------------
loss: 1.929975  [    0/60000]
loss: 1.905118  [ 6400/60000]
loss: 1.797361  [12800/60000]
loss: 1.840994  [19200/60000]
loss: 1.721110  [25600/60000]
loss: 1.678175  [32000/60000]
loss: 1.691375  [38400/60000]
loss: 1.584185  [44800/60000]
loss: 1.619714  [51200/60000]
loss: 1.506852  [57600/60000]
Test Error: 
 Accuracy: 60.5%, Avg loss: 1.530285 

Epoch 4
-------------------------------
loss: 1.583245  [    0/60000]
loss: 1.556023  [ 6400/60000]
loss: 1.411425  [12800/60000]
loss: 1.488727  [19200/60000]
loss: 1.359579  [25600/60000]
loss: 1.360133  [32000/60000]
loss: 1.366381  [38400/60000]
loss: 1.279213  [44800/60000]
loss: 1.328040  [51200/60000]
loss: 1.223219  [57600/60000]
Test Error: 
 Accuracy: 62.9%, Avg loss: 1.254554 

Epoch 5
-------------------------------
loss: 1.319601  [    0/60000]
loss: 1.307626  [ 6400/60000]
loss: 1.148442  [12800/60000]
loss: 1.258462  [19200/60000]
loss: 1.127337  [25600/60000]
loss: 1.160929  [32000/60000]
loss: 1.174176  [38400/60000]
loss: 1.098028  [44800/60000]
loss: 1.149889  [51200/60000]
loss: 1.065084  [57600/60000]
Test Error: 
 Accuracy: 64.1%, Avg loss: 1.088535 

Done!
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")
Saved PyTorch Model State to model.pth

We’ll cover the rest of the tutorial (loading the model and doing batch prediction) later!

Introducing a wrapper function (no Ray AIR, yet!)

The notebook style of the original tutorial is great for walking through the code, but in your production code you probably wrap the actual training logic in a function. So let's do this here, too.

Note that we do not add or alter any code here (apart from variable definitions) - we just take the loose bits of code in the current tutorial and put them into one function.

def train_func():
    batch_size = 64
    lr = 1e-3
    epochs = 5
    
    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)
    
    # Get cpu or gpu device for training.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")
    
    model = NeuralNetwork().to(device)
    print(model)
    
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_epoch(train_dataloader, model, loss_fn, optimizer)
        test_epoch(test_dataloader, model, loss_fn)

    print("Done!")

Let’s see it in action again:

train_func()
Using cpu device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
Epoch 1
-------------------------------
loss: 2.300413  [    0/60000]
loss: 2.286883  [ 6400/60000]
loss: 2.273655  [12800/60000]
loss: 2.268864  [19200/60000]
loss: 2.244969  [25600/60000]
loss: 2.216548  [32000/60000]
loss: 2.222040  [38400/60000]
loss: 2.182495  [44800/60000]
loss: 2.182766  [51200/60000]
loss: 2.156811  [57600/60000]
Test Error: 
 Accuracy: 50.6%, Avg loss: 2.145070 

Epoch 2
-------------------------------
loss: 2.150665  [    0/60000]
loss: 2.138080  [ 6400/60000]
loss: 2.086293  [12800/60000]
loss: 2.103044  [19200/60000]
loss: 2.041723  [25600/60000]
loss: 1.980555  [32000/60000]
loss: 2.000344  [38400/60000]
loss: 1.915977  [44800/60000]
loss: 1.930699  [51200/60000]
loss: 1.849836  [57600/60000]
Test Error: 
 Accuracy: 57.1%, Avg loss: 1.851586 

Epoch 3
-------------------------------
loss: 1.884133  [    0/60000]
loss: 1.847551  [ 6400/60000]
loss: 1.739998  [12800/60000]
loss: 1.781018  [19200/60000]
loss: 1.660957  [25600/60000]
loss: 1.619163  [32000/60000]
loss: 1.629720  [38400/60000]
loss: 1.535548  [44800/60000]
loss: 1.571056  [51200/60000]
loss: 1.458136  [57600/60000]
Test Error: 
 Accuracy: 60.9%, Avg loss: 1.483232 

Epoch 4
-------------------------------
loss: 1.549974  [    0/60000]
loss: 1.511189  [ 6400/60000]
loss: 1.374695  [12800/60000]
loss: 1.445348  [19200/60000]
loss: 1.323308  [25600/60000]
loss: 1.324354  [32000/60000]
loss: 1.328822  [38400/60000]
loss: 1.257330  [44800/60000]
loss: 1.298783  [51200/60000]
loss: 1.197863  [57600/60000]
Test Error: 
 Accuracy: 63.4%, Avg loss: 1.226258 

Epoch 5
-------------------------------
loss: 1.301627  [    0/60000]
loss: 1.279378  [ 6400/60000]
loss: 1.124873  [12800/60000]
loss: 1.230890  [19200/60000]
loss: 1.104782  [25600/60000]
loss: 1.130063  [32000/60000]
loss: 1.147206  [38400/60000]
loss: 1.083874  [44800/60000]
loss: 1.127730  [51200/60000]
loss: 1.044750  [57600/60000]
Test Error: 
 Accuracy: 64.6%, Avg loss: 1.066819 

Done!

The output should look very similar to the previous output.

Starting with Ray AIR: Distribute the training

As a first step, we want to distribute the training across multiple workers. For this, we want to:

  1. Use data-parallel training by sharding the training data

  2. Set up the model to communicate gradient updates across machines

  3. Report the results back to Ray Train.

To facilitate this, we only need a few changes to the code:

  1. We import Ray Train:

import ray.train as train
  2. We use a config dict to configure some hyperparameters (this is not strictly needed, but it's good practice, especially if you want to do hyperparameter tuning later):

def train_func(config: dict):
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]
  3. We dynamically adjust the worker batch size according to the number of workers:

    batch_size_per_worker = batch_size // train.world_size()
  4. We prepare the data loader for distributed data sharding:

    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)
  5. We prepare the model for distributed gradient updates:

    model = train.torch.prepare_model(model)

Note that train.torch.prepare_model() also automatically takes care of setting up devices (e.g. GPU training) - so we can get rid of those lines in our current code!

  6. We capture the validation loss and report it to Ray Train:

        test_loss = test_epoch(test_dataloader, model, loss_fn)
        train.report(loss=test_loss)
  7. In the train_epoch() and test_epoch() functions we divide the size by the world size:

    size = len(dataloader.dataset) // train.world_size()  # Divide by world size
  8. In the train_epoch() function we can get rid of the device mapping. Ray Train does this for us:

        # We don't need this anymore! Ray Train does this automatically:
        # X, y = X.to(device), y.to(device) 

That’s it - you need fewer than 10 lines of Ray Train-specific code and can otherwise continue to use your original code.

Let’s take a look at the resulting code. First the train_epoch() function (2 lines changed, and we also commented out the print statement):

def train_epoch(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset) // train.world_size()  # Divide by world size
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # We don't need this anymore! Ray Train does this automatically:
        # X, y = X.to(device), y.to(device)  

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            # print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

Then the test_epoch() function (1 line changed, and we also commented out the print statement):

def test_epoch(dataloader, model, loss_fn):
    size = len(dataloader.dataset) // train.world_size()  # Divide by world size
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    # print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return test_loss

And lastly, the wrapping train_func() where we added 4 lines and modified 2 (apart from the config dict):

import ray.train as train


def train_func(config: dict):
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]
    
    batch_size_per_worker = batch_size // train.world_size()
    
    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size_per_worker)
    test_dataloader = DataLoader(test_data, batch_size=batch_size_per_worker)
    
    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)
    
    model = NeuralNetwork()
    model = train.torch.prepare_model(model)
    
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    
    for t in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer)
        test_loss = test_epoch(test_dataloader, model, loss_fn)
        train.report(loss=test_loss)

    print("Done!")

Now we’ll use Ray Train’s TorchTrainer to kick off the training. Note that we can set the hyperparameters here! In the scaling_config we can also configure how many parallel workers to use and whether we want to enable GPU training.

from ray.train.torch import TorchTrainer


trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 4},
    scaling_config={"num_workers": 2, "use_gpu": False},
)
result = trainer.fit()
print(f"Last result: {result.metrics}")
2022-06-22 16:28:31,525	INFO services.py:1477 -- View the Ray dashboard at http://127.0.0.1:8265
== Status ==
Current time: 2022-06-22 16:29:30 (running for 00:00:56.32)
Memory usage on this node: 7.5/31.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/8 CPUs, 0/0 GPUs, 0.0/15.32 GiB heap, 0.0/7.66 GiB objects
Result logdir: /home/ubuntu/ray_results/TorchTrainer_2022-06-22_16-28-33
Number of trials: 1/1 (1 TERMINATED)
Trial name                 status      loc                     iter  total time (s)  loss    _timestamp  _time_this_iter_s
TorchTrainer_5c84a_00000   TERMINATED  172.31.43.110:1481731   4     47.5635         1.2631  1655915369  11.0948


2022-06-22 16:28:38,581	WARNING worker.py:1726 -- Warning: The actor TrainTrainable is very large (52 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.
(BaseWorkerMixin pid=1481763) 2022-06-22 16:28:44,894	INFO config.py:70 -- Setting up process group for: env:// [rank=0, world_size=2]
(BaseWorkerMixin pid=1481764) 2022-06-22 16:28:44,891	INFO config.py:70 -- Setting up process group for: env:// [rank=1, world_size=2]
(BaseWorkerMixin pid=1481763) 2022-06-22 16:28:46,425	INFO train_loop_utils.py:293 -- Moving model to device: cpu
(BaseWorkerMixin pid=1481763) 2022-06-22 16:28:46,425	INFO train_loop_utils.py:331 -- Wrapping provided model in DDP.
(BaseWorkerMixin pid=1481764) 2022-06-22 16:28:46,425	INFO train_loop_utils.py:293 -- Moving model to device: cpu
(BaseWorkerMixin pid=1481764) 2022-06-22 16:28:46,425	INFO train_loop_utils.py:331 -- Wrapping provided model in DDP.
Result for TorchTrainer_5c84a_00000:
  _time_this_iter_s: 10.956670761108398
  _timestamp: 1655915337
  _training_iteration: 1
  date: 2022-06-22_16-28-57
  done: false
  experiment_id: dd2810ff95f74b1a8390f918b6c122fe
  hostname: ip-172-31-43-110
  iterations_since_restore: 1
  loss: 2.1705087840936748
  node_ip: 172.31.43.110
  pid: 1481731
  time_since_restore: 14.807097911834717
  time_this_iter_s: 14.807097911834717
  time_total_s: 14.807097911834717
  timestamp: 1655915337
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 5c84a_00000
  warmup_time: 0.0042934417724609375
  
Result for TorchTrainer_5c84a_00000:
  _time_this_iter_s: 10.683637142181396
  _timestamp: 1655915347
  _training_iteration: 2
  date: 2022-06-22_16-29-07
  done: false
  experiment_id: dd2810ff95f74b1a8390f918b6c122fe
  hostname: ip-172-31-43-110
  iterations_since_restore: 2
  loss: 1.918477459318319
  node_ip: 172.31.43.110
  pid: 1481731
  time_since_restore: 25.498638153076172
  time_this_iter_s: 10.691540241241455
  time_total_s: 25.498638153076172
  timestamp: 1655915347
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: 5c84a_00000
  warmup_time: 0.0042934417724609375
  
Result for TorchTrainer_5c84a_00000:
  _time_this_iter_s: 10.996578216552734
  _timestamp: 1655915358
  _training_iteration: 3
  date: 2022-06-22_16-29-18
  done: false
  experiment_id: dd2810ff95f74b1a8390f918b6c122fe
  hostname: ip-172-31-43-110
  iterations_since_restore: 3
  loss: 1.54556822397147
  node_ip: 172.31.43.110
  pid: 1481731
  time_since_restore: 36.48866558074951
  time_this_iter_s: 10.99002742767334
  time_total_s: 36.48866558074951
  timestamp: 1655915358
  timesteps_since_restore: 0
  training_iteration: 3
  trial_id: 5c84a_00000
  warmup_time: 0.0042934417724609375
  
Result for TorchTrainer_5c84a_00000:
  _time_this_iter_s: 11.09483027458191
  _timestamp: 1655915369
  _training_iteration: 4
  date: 2022-06-22_16-29-29
  done: false
  experiment_id: dd2810ff95f74b1a8390f918b6c122fe
  hostname: ip-172-31-43-110
  iterations_since_restore: 4
  loss: 1.263096342800529
  node_ip: 172.31.43.110
  pid: 1481731
  time_since_restore: 47.56349587440491
  time_this_iter_s: 11.074830293655396
  time_total_s: 47.56349587440491
  timestamp: 1655915369
  timesteps_since_restore: 0
  training_iteration: 4
  trial_id: 5c84a_00000
  warmup_time: 0.0042934417724609375
  
(BaseWorkerMixin pid=1481763) Done!
(BaseWorkerMixin pid=1481764) Done!
Result for TorchTrainer_5c84a_00000:
  _time_this_iter_s: 11.09483027458191
  _timestamp: 1655915369
  _training_iteration: 4
  date: 2022-06-22_16-29-29
  done: true
  experiment_id: dd2810ff95f74b1a8390f918b6c122fe
  experiment_tag: '0'
  hostname: ip-172-31-43-110
  iterations_since_restore: 4
  loss: 1.263096342800529
  node_ip: 172.31.43.110
  pid: 1481731
  time_since_restore: 47.56349587440491
  time_this_iter_s: 11.074830293655396
  time_total_s: 47.56349587440491
  timestamp: 1655915369
  timesteps_since_restore: 0
  training_iteration: 4
  trial_id: 5c84a_00000
  warmup_time: 0.0042934417724609375
  
2022-06-22 16:29:31,024	INFO tune.py:734 -- Total run time: 57.58 seconds (56.31 seconds for the tuning loop).
Last result: {'loss': 1.263096342800529, '_timestamp': 1655915369, '_time_this_iter_s': 11.09483027458191, '_training_iteration': 4, 'time_this_iter_s': 11.074830293655396, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 4, 'trial_id': '5c84a_00000', 'experiment_id': 'dd2810ff95f74b1a8390f918b6c122fe', 'date': '2022-06-22_16-29-29', 'timestamp': 1655915369, 'time_total_s': 47.56349587440491, 'pid': 1481731, 'hostname': 'ip-172-31-43-110', 'node_ip': '172.31.43.110', 'config': {}, 'time_since_restore': 47.56349587440491, 'timesteps_since_restore': 0, 'iterations_since_restore': 4, 'warmup_time': 0.0042934417724609375, 'experiment_tag': '0'}

Great, this works! You’re now training your model in parallel. You could now scale this up to more nodes and workers on your Ray cluster.
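
If you have a larger Ray cluster available, only the scaling configuration needs to change. As a rough sketch (the worker count and GPU flag below are placeholder values, not part of the original tutorial), the same training function could be scaled out to four GPU workers like this:

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 4},
    # Placeholder resources: adjust to what your cluster actually provides.
    scaling_config={"num_workers": 4, "use_gpu": True},
)
# result = trainer.fit()  # only run this on a cluster with enough GPU workers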

But there are a few improvements we can make to the code in order to get the most out of the system. For one, we should enable checkpointing to get access to the trained model afterwards. Additionally, we should optimize the data loading so that it takes place within the workers.

Enabling checkpointing to retrieve the model

Enabling checkpointing is pretty easy - we just need to call the train.save_checkpoint() API and pass the model state to it:

    train.save_checkpoint(epoch=t, model=model.module.state_dict())

Note that the model.module part is needed because the model gets wrapped in torch.nn.DistributedDataParallel by train.torch.prepare_model.
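
To make the naming concrete, here is a tiny illustration that is not part of the original tutorial. It uses nn.DataParallel, which stores the wrapped model under a module attribute just like DistributedDataParallel, but does not need a process group, so you can run it anywhere:

# Any wrapper that stores the model as `self.module` (DataParallel here,
# DistributedDataParallel inside Ray Train) prefixes parameter names with
# "module." - hence the model.module.state_dict() call above.
wrapped = nn.DataParallel(NeuralNetwork())
print(list(wrapped.state_dict())[0])         # module.linear_relu_stack.0.weight
print(list(wrapped.module.state_dict())[0])  # linear_relu_stack.0.weight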

Move the data loader to the training function

You may have noticed this warning in the logs above: Warning: The actor TrainTrainable is very large (52 MiB). Check that its definition is not implicitly capturing a large array or other object in scope. Tip: use ray.put() to put large objects in the Ray object store.

This is because we load the data outside the training function. Ray then serializes it to make it accessible to the remote tasks (which may get executed on a remote node!). This is not too bad with just 52 MiB of data, but imagine this were a full image dataset - you wouldn’t want to ship this around the cluster unnecessarily. Instead, you should move the dataset loading part into the train_func(). This will then download the data to disk once per machine and result in much more efficient data loading.

The result looks like this:

def load_data():
    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )

    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )
    return training_data, test_data


def train_func(config: dict):
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]
    
    batch_size_per_worker = batch_size // train.world_size()
    
    training_data, test_data = load_data()  # <- this is new!
    
    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size_per_worker)
    test_dataloader = DataLoader(test_data, batch_size=batch_size_per_worker)
    
    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)
    
    model = NeuralNetwork()
    model = train.torch.prepare_model(model)
    
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    
    for t in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer)
        test_loss = test_epoch(test_dataloader, model, loss_fn)
        train.save_checkpoint(epoch=t, model=model.module.state_dict())  # <- this is new!
        train.report(loss=test_loss)

    print("Done!")

Let’s train again:

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 4},
    scaling_config={"num_workers": 2, "use_gpu": False},
)
result = trainer.fit()
== Status ==
Current time: 2022-06-22 16:30:41 (running for 00:00:56.46)
Memory usage on this node: 7.2/31.0 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/8 CPUs, 0/0 GPUs, 0.0/15.32 GiB heap, 0.0/7.66 GiB objects
Result logdir: /home/ubuntu/ray_results/TorchTrainer_2022-06-22_16-29-44
Number of trials: 1/1 (1 TERMINATED)
Trial name                 status      loc                     iter  total time (s)  loss     _timestamp  _time_this_iter_s
TorchTrainer_86514_00000   TERMINATED  172.31.43.110:1481879   4     53.1038         1.24844  1655915440  11.4238


(BaseWorkerMixin pid=1481912) 2022-06-22 16:29:50,060	INFO config.py:70 -- Setting up process group for: env:// [rank=1, world_size=2]
(BaseWorkerMixin pid=1481911) 2022-06-22 16:29:50,039	INFO config.py:70 -- Setting up process group for: env:// [rank=0, world_size=2]
(BaseWorkerMixin pid=1481912) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
(BaseWorkerMixin pid=1481911) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
(BaseWorkerMixin pid=1481912) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
(BaseWorkerMixin pid=1481911) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
  0%|          | 0/26421880 [00:00<?, ?it/s]
  0%|          | 0/26421880 [00:00<?, ?it/s]
 13%|█▎        | 3316736/26421880 [00:00<00:02, 8105757.73it/s]
  9%|▊         | 2288640/26421880 [00:00<00:04, 5702264.83it/s]
 54%|█████▎    | 14174208/26421880 [00:01<00:00, 14759412.47it/s]
 50%|████▉     | 13135872/26421880 [00:01<00:00, 14499791.65it/s]
 95%|█████████▍| 24999936/26421880 [00:02<00:00, 15289978.61it/s]
 91%|█████████ | 23961600/26421880 [00:02<00:00, 15340254.66it/s]
26422272it [00:02, 11870579.44it/s]
26422272it [00:02, 11859017.59it/s]
(BaseWorkerMixin pid=1481912) Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
(BaseWorkerMixin pid=1481911) Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
(BaseWorkerMixin pid=1481912) 
(BaseWorkerMixin pid=1481912) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
(BaseWorkerMixin pid=1481911) 
(BaseWorkerMixin pid=1481911) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
(BaseWorkerMixin pid=1481912) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
(BaseWorkerMixin pid=1481911) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
  0%|          | 0/29515 [00:00<?, ?it/s]
  0%|          | 0/29515 [00:00<?, ?it/s]
29696it [00:00, 295064.38it/s]
29696it [00:00, 296748.03it/s]
(BaseWorkerMixin pid=1481912) Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
(BaseWorkerMixin pid=1481912) 
(BaseWorkerMixin pid=1481912) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
(BaseWorkerMixin pid=1481911) Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
(BaseWorkerMixin pid=1481911) 
(BaseWorkerMixin pid=1481911) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
(BaseWorkerMixin pid=1481912) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
(BaseWorkerMixin pid=1481911) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
  0%|          | 0/4422102 [00:00<?, ?it/s]
  0%|          | 0/4422102 [00:00<?, ?it/s]
 18%|█▊        | 779264/4422102 [00:00<00:01, 2033939.42it/s]
 18%|█▊        | 806912/4422102 [00:00<00:01, 2121674.81it/s]
 36%|███▌      | 1573888/4422102 [00:00<00:00, 3955289.29it/s]
 37%|███▋      | 1630208/4422102 [00:00<00:00, 4106422.09it/s]
4422656it [00:00, 5316997.96it/s]
4422656it [00:00, 5349419.81it/s]
(BaseWorkerMixin pid=1481912) Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
(BaseWorkerMixin pid=1481911) Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
(BaseWorkerMixin pid=1481912) 
(BaseWorkerMixin pid=1481912) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
(BaseWorkerMixin pid=1481911) 
(BaseWorkerMixin pid=1481911) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
(BaseWorkerMixin pid=1481912) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
(BaseWorkerMixin pid=1481911) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
6144it [00:00, 49085340.53it/s]
6144it [00:00, 43975774.36it/s]
(BaseWorkerMixin pid=1481912) 2022-06-22 16:29:55,912	INFO train_loop_utils.py:293 -- Moving model to device: cpu
(BaseWorkerMixin pid=1481912) 2022-06-22 16:29:55,913	INFO train_loop_utils.py:331 -- Wrapping provided model in DDP.
(BaseWorkerMixin pid=1481911) 2022-06-22 16:29:55,899	INFO train_loop_utils.py:293 -- Moving model to device: cpu
(BaseWorkerMixin pid=1481911) 2022-06-22 16:29:55,900	INFO train_loop_utils.py:331 -- Wrapping provided model in DDP.
(BaseWorkerMixin pid=1481912) Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
(BaseWorkerMixin pid=1481912) 
(BaseWorkerMixin pid=1481911) Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
(BaseWorkerMixin pid=1481911) 
Result for TorchTrainer_86514_00000:
  _time_this_iter_s: 17.002915143966675
  _timestamp: 1655915407
  _training_iteration: 1
  date: 2022-06-22_16-30-07
  done: false
  experiment_id: 1e7954bef1c6432785374780fb0da29e
  hostname: ip-172-31-43-110
  iterations_since_restore: 1
  loss: 2.1645877740945028
  node_ip: 172.31.43.110
  pid: 1481879
  should_checkpoint: true
  time_since_restore: 19.680341005325317
  time_this_iter_s: 19.680341005325317
  time_total_s: 19.680341005325317
  timestamp: 1655915407
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: '86514_00000'
  warmup_time: 0.004637956619262695
  
Result for TorchTrainer_86514_00000:
  _time_this_iter_s: 10.904694557189941
  _timestamp: 1655915418
  _training_iteration: 2
  date: 2022-06-22_16-30-18
  done: false
  experiment_id: 1e7954bef1c6432785374780fb0da29e
  hostname: ip-172-31-43-110
  iterations_since_restore: 2
  loss: 1.905545388057733
  node_ip: 172.31.43.110
  pid: 1481879
  should_checkpoint: true
  time_since_restore: 30.562681436538696
  time_this_iter_s: 10.882340431213379
  time_total_s: 30.562681436538696
  timestamp: 1655915418
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: '86514_00000'
  warmup_time: 0.004637956619262695
  
Result for TorchTrainer_86514_00000:
  _time_this_iter_s: 11.091916799545288
  _timestamp: 1655915429
  _training_iteration: 3
  date: 2022-06-22_16-30-29
  done: false
  experiment_id: 1e7954bef1c6432785374780fb0da29e
  hostname: ip-172-31-43-110
  iterations_since_restore: 3
  loss: 1.531144731363673
  node_ip: 172.31.43.110
  pid: 1481879
  should_checkpoint: true
  time_since_restore: 41.66515636444092
  time_this_iter_s: 11.102474927902222
  time_total_s: 41.66515636444092
  timestamp: 1655915429
  timesteps_since_restore: 0
  training_iteration: 3
  trial_id: '86514_00000'
  warmup_time: 0.004637956619262695
  
Result for TorchTrainer_86514_00000:
  _time_this_iter_s: 11.423810482025146
  _timestamp: 1655915440
  _training_iteration: 4
  date: 2022-06-22_16-30-40
  done: false
  experiment_id: 1e7954bef1c6432785374780fb0da29e
  hostname: ip-172-31-43-110
  iterations_since_restore: 4
  loss: 1.2484390530616614
  node_ip: 172.31.43.110
  pid: 1481879
  should_checkpoint: true
  time_since_restore: 53.103771924972534
  time_this_iter_s: 11.438615560531616
  time_total_s: 53.103771924972534
  timestamp: 1655915440
  timesteps_since_restore: 0
  training_iteration: 4
  trial_id: '86514_00000'
  warmup_time: 0.004637956619262695
  
(BaseWorkerMixin pid=1481912) Done!
(BaseWorkerMixin pid=1481911) Done!
Result for TorchTrainer_86514_00000:
  _time_this_iter_s: 11.423810482025146
  _timestamp: 1655915440
  _training_iteration: 4
  date: 2022-06-22_16-30-40
  done: true
  experiment_id: 1e7954bef1c6432785374780fb0da29e
  experiment_tag: '0'
  hostname: ip-172-31-43-110
  iterations_since_restore: 4
  loss: 1.2484390530616614
  node_ip: 172.31.43.110
  pid: 1481879
  should_checkpoint: true
  time_since_restore: 53.103771924972534
  time_this_iter_s: 11.438615560531616
  time_total_s: 53.103771924972534
  timestamp: 1655915440
  timesteps_since_restore: 0
  training_iteration: 4
  trial_id: '86514_00000'
  warmup_time: 0.004637956619262695
  
2022-06-22 16:30:41,236	INFO tune.py:734 -- Total run time: 56.58 seconds (56.46 seconds for the tuning loop).

We can see our results here:

print(f"Last result: {result.metrics}")
print(f"Checkpoint: {result.checkpoint}")
Last result: {'loss': 1.2484390530616614, '_timestamp': 1655915440, '_time_this_iter_s': 11.423810482025146, '_training_iteration': 4, 'time_this_iter_s': 11.438615560531616, 'should_checkpoint': True, 'done': True, 'timesteps_total': None, 'episodes_total': None, 'training_iteration': 4, 'trial_id': '86514_00000', 'experiment_id': '1e7954bef1c6432785374780fb0da29e', 'date': '2022-06-22_16-30-40', 'timestamp': 1655915440, 'time_total_s': 53.103771924972534, 'pid': 1481879, 'hostname': 'ip-172-31-43-110', 'node_ip': '172.31.43.110', 'config': {}, 'time_since_restore': 53.103771924972534, 'timesteps_since_restore': 0, 'iterations_since_restore': 4, 'warmup_time': 0.004637956619262695, 'experiment_tag': '0'}
Checkpoint: <ray.air.checkpoint.Checkpoint object at 0x7f15317fc160>

Loading the model for prediction

You may have noticed that we skipped one part of the original tutorial - loading the model and using it for inference. The original code looks like this (we’ve wrapped it in a function):

def predict_from_model(model):
    classes = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]

    model.eval()
    x, y = test_data[0][0], test_data[0][1]
    with torch.no_grad():
        pred = model(x)
        predicted, actual = classes[pred[0].argmax(0)], classes[y]
        print(f'Predicted: "{predicted}", Actual: "{actual}"')

We can use our saved model with the existing code to do prediction:

from ray.train.torch import load_checkpoint

model, _ = load_checkpoint(result.checkpoint, NeuralNetwork())

predict_from_model(model)
Predicted: "Ankle boot", Actual: "Ankle boot"

To predict more than one example, we can use a loop:

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

def predict_from_model(model, data):
    model.eval()
    with torch.no_grad():
        for x, y in data:
            pred = model(x)
            predicted, actual = classes[pred[0].argmax(0)], classes[y]
            print(f'Predicted: "{predicted}", Actual: "{actual}"')
predict_from_model(model, [test_data[i] for i in range(10)])
Predicted: "Ankle boot", Actual: "Ankle boot"
Predicted: "Pullover", Actual: "Pullover"
Predicted: "Trouser", Actual: "Trouser"
Predicted: "Trouser", Actual: "Trouser"
Predicted: "Pullover", Actual: "Shirt"
Predicted: "Trouser", Actual: "Trouser"
Predicted: "Coat", Actual: "Coat"
Predicted: "Coat", Actual: "Shirt"
Predicted: "Sneaker", Actual: "Sandal"
Predicted: "Sneaker", Actual: "Sneaker"

Using Ray AIR for scalable batch prediction

Instead of looping over examples manually, we can also use Ray AIR’s BatchPredictor class to do scalable prediction.

from ray.air import BatchPredictor
from ray.air.predictors.integrations.torch import TorchPredictor

batch_predictor = BatchPredictor.from_checkpoint(result.checkpoint, TorchPredictor, model=NeuralNetwork())

Batch predictors work with Ray Datasets. Here we convert our test dataset into a Ray Dataset - note that this is not very efficient, and you can look at our other tutorials to see more efficient ways to generate a Ray Dataset.

import ray.data

ds = ray.data.from_items([x for x, y in test_data])

We can then trigger prediction with two workers:

results = batch_predictor.predict(ds, min_scoring_workers=2)
Map Progress (2 actors 1 pending): 100%|██████████| 200/200 [00:04<00:00, 44.83it/s]

results is another Ray Dataset. We can use results.show() to see our prediction results:

results.show()
predictions
0 [-1.3625717, -1.7147198, -0.7944063, -1.516942...
1 [0.7415036, -2.1898255, 2.9233487, -0.8718336,...
2 [1.9445083, 4.0967875, -0.07387225, 2.8944397,...
3 [1.460784, 3.2333734, -0.15551251, 2.3267126, ...
4 [0.758382, -0.8887838, 1.1806433, -0.04382074,...
... ...
9995 [-1.6255455, -2.7318435, -0.8888813, -2.205097...
9996 [0.90374756, 2.051165, -0.046540327, 1.4930309...
9997 [0.89324296, -0.14099044, -0.08300409, 0.74801...
9998 [1.4642937, 3.2236817, -0.23001938, 2.5179548,...
9999 [-1.0059572, -1.2176754, -0.36726016, -1.10825...

10000 rows × 1 columns

If we want to convert these predictions into class names (as in the original example), we can use a map function to do this:

predicted_classes = results.map_batches(
    lambda batch: [classes[pred.argmax(0)] for pred in batch["predictions"]], 
    batch_format="pandas")
Map_Batches: 100%|██████████| 200/200 [00:01<00:00, 117.68it/s]

To compare this with the actual labels, let’s create a Ray dataset for these and zip it together with the predicted classes:

real_classes = ray.data.from_items([classes[y] for x, y in test_data])
merged = predicted_classes.zip(real_classes)

Let’s examine our results:

merged.show()
value
0 (Ankle boot, Ankle boot)
1 (Pullover, Pullover)
2 (Trouser, Trouser)
3 (Trouser, Trouser)
4 (Pullover, Shirt)
... ...
9995 (Ankle boot, Ankle boot)
9996 (Trouser, Trouser)
9997 (T-shirt/top, Bag)
9998 (Trouser, Trouser)
9999 (Sneaker, Sandal)

10000 rows × 1 columns
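
As a final sanity check, we can pull the predicted class names back to the driver and compare them with the ground-truth labels in plain Python. This is a small sketch that is not part of the original tutorial; collecting 10,000 short strings on the driver is cheap enough for this dataset:

preds = predicted_classes.take(10000)        # predicted class names
labels = [classes[y] for _, y in test_data]  # ground-truth class names
accuracy = sum(p == l for p, l in zip(preds, labels)) / len(labels)
print(f"Batch prediction accuracy: {accuracy:.1%}")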

Summary

This tutorial demonstrated how to turn your existing PyTorch code into code you can use with Ray AIR.

We learned how to

  • enable distributed training using Ray Train abstractions

  • save and retrieve model checkpoints via Ray AIR

  • load a model for batch prediction

In our other examples you can learn how to do more with the Ray AIR API, such as serving your model with Ray Serve or tuning your hyperparameters with Ray Tune. You can also learn how to construct Ray Datasets to leverage Ray AIR’s preprocessing API.

We hope this tutorial gave you a good starting point to leverage Ray AIR. If you have any questions or suggestions, or if you run into any problems, please reach out on Discuss or GitHub!