Computer vision pattern
This notebook is an end-to-end, real-world computer-vision workflow that runs seamlessly on an Anyscale cluster using Ray Train. You start by pulling a slice of the Food-101 dataset, push it through a lightweight preprocessing pipeline, store it efficiently in Parquet, and then train a ResNet-18 model from scratch in a fault-tolerant, distributed manner. Along the way, you lean on Ray’s helpers to prepare data loaders, coordinate workers, checkpoint automatically, and resume after failure. Afterwards, you run batch inference with Ray Data, all without writing a single line of low-level distributed code.
Learning objectives
Launch distributed training with Ray Train’s TorchTrainer and configure it for multi-GPU, multi-node execution.
Use Ray Train’s built-in utilities (prepare_model, prepare_data_loader, get_checkpoint, train.report) to wrap your existing PyTorch code without modifying your modeling logic.
Save and resume from automatic, fault-tolerant checkpoints across epochs.
Offload batch inference to Ray Data, so you can treat inference as a scalable workload.
Run end-to-end training and evaluation without needing to understand the low-level mechanics of distributed systems.
By the end of the tutorial, you have a working model, clear loss curves, and hands-on experience with how Ray Train simplifies distributed computer-vision workloads.
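What problem are you solving? (image classification with Food-101-Lite)
This notebook trains a neural network to classify food photos using Food-101-Lite, a compact 10% slice of the original Food-101 benchmark.
Inputs
Every sample is a 3-channel Red-Green-Blue (RGB) image, resized and center-cropped to \(224 \times 224\):
\[ x \in \mathbb{R}^{3 \times 224 \times 224} \]
Preprocessing scales each image’s shorter side to 256 pixels and center-crops it to 224 pixels; at load time, each image is converted to a tensor and normalized with ImageNet statistics. You batch the data with a plain PyTorch DataLoader (wrapped by ray.train.torch.prepare_data_loader for distributed training).
Labels
Each image carries one of Food-101’s 101 class labels, for example:
[‘pizza’, ‘hamburger’, ‘sushi’, ‘ramen’, ‘fried_rice’, ‘steak’, ‘hot_dog’, ‘pancakes’, ‘breakfast_burrito’, ‘caesar_salad’]
The label is an integer \(y \in \{0, \dots, 100\}\) used for supervision, which is why the model and the accuracy metric below use num_classes=101.
What does the model learn?
You train a compact CNN (ResNet-18) that maps an image \(x\) to logits \(z(x) \in \mathbb{R}^{K}\), \(K = 101\), and from those to class probabilities:
\[ \hat{p}_\theta(k \mid x) = \frac{\exp\big(z_k(x)\big)}{\sum_{j=0}^{K-1} \exp\big(z_j(x)\big)}, \qquad k = 0, \dots, K-1 \]
Training minimizes the cross-entropy loss over a batch of \(N\) examples,
\[ \mathcal{L}(\theta) = -\frac{1}{N} \sum_{i=1}^{N} \log \hat{p}_\theta(y_i \mid x_i), \]
so the network assigns high likelihood to the correct class.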
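To make the loss concrete, the following plain-Python sketch (no PyTorch; the helper names softmax and cross_entropy are illustrative) turns toy logits for three classes into probabilities and computes the loss for one example. In the training loop, nn.CrossEntropyLoss takes raw logits, applies the equivalent log-softmax internally, and by default averages the per-example losses over the batch.

```python
import math


def softmax(logits):
    # Subtract the largest logit for numerical stability; probabilities are unchanged
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    # Loss for one example: negative log-probability assigned to the true class
    return -math.log(softmax(logits)[label])


logits = [2.0, 0.5, -1.0]  # toy scores for K = 3 classes
label = 0                  # index of the true class
probs = softmax(logits)
loss = cross_entropy(logits, label)
print([round(p, 3) for p in probs], round(loss, 3))
```

Subtracting the largest logit before exponentiating keeps math.exp from overflowing for large scores without changing the result, which is why frameworks compute the loss from logits rather than from precomputed probabilities.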
How to migrate this computer vision workload to a distributed setup using Ray on Anyscale
In this tutorial, you start with a small PyTorch-based image classification task, training a ResNet-18 model on a 10% slice of the Food-101 dataset, and progressively migrate it into a fully distributed, fault-tolerant training job using Ray Train on Anyscale. The goal is to show you exactly how to scale your existing workflow without rewriting it from scratch.
Use the following steps to migrate:
Preprocess data and persist it in a distributed-friendly format
You take raw images from Hugging Face’s food101 dataset, apply torchvision resizing and center-cropping, and serialize them to Parquet using pyarrow. The system writes these Parquet files to the Anyscale cluster’s shared storage volume (/mnt/cluster_storage), so any worker on any node can read them without duplication or sync issues.
Create a lightweight PyTorch Dataset for Parquet ingestion
You implement a custom Food101Dataset that reads directly from the Parquet files. This gives you control over how the system reads rows and row groups. While this isn’t yet fully distributed, it simulates a real-world scenario where a developer starts with something simple before optimizing. Note: this tutorial uses PyTorch-style data loading to demonstrate (1) low-level control in a PyTorch-native environment and (2) how to move pre-existing PyTorch code into a distributed Anyscale environment. Other tutorials in this module use Ray Data, so you can see how the two approaches differ.
Integrate Ray Train into the training loop
You encapsulate your existing PyTorch training logic in a train_loop_per_worker() function, which Ray Train executes on each worker, typically one per GPU. Inside this loop, you:
Wrap the model with prepare_model() to make it compatible with distributed data parallelism.
Wrap the DataLoader with prepare_data_loader() to enable device placement and Ray worker context handling.
Use Ray’s Checkpoint API to save and resume from checkpoints as needed.
Report training and validation metrics with train.report() after each epoch.
Launch training with TorchTrainer on an Anyscale cluster
You instantiate a TorchTrainer that runs:
with num_workers=8 and use_gpu=True, for example across eight A10 or A100 GPUs on Anyscale (this much compute isn’t necessary for this example; the extra resources are for educational purposes only),
with a RunConfig that sets checkpoint retention and automatic recovery (max_failures=3), and
on infrastructure that Anyscale provisions and schedules, with no manual Ray cluster setup required.
Once launched, Ray automatically handles:
Multi-node orchestration
Worker assignment and device pinning
Failure recovery and retry logic
Checkpointing and logging
Validate fault tolerance
You run trainer.fit() a second time. If a failure or manual intervention interrupted the previous training, Ray picks up from the latest checkpoint. This shows real-world robustness without any manual checkpoint management or scripting.
Launch distributed GPU inference tasks
At the end, you build a Ray Data pipeline that loads the best checkpoint, runs it over the validation set, and displays the prediction for a single validation image. This task runs on one GPU from the cluster.
All of this runs inside a managed Anyscale workspace. You don’t need to start or SSH into clusters, worry about node IPs, or configure NCCL. The entire setup is declarative and self-contained in this notebook, and you can re-run it or scale it up by changing a single parameter (num_workers).
This tutorial mirrors how many ML teams operate in practice: starting with a working PyTorch training loop and migrating it to the cloud without rewriting core logic. With Ray Train on Anyscale, the migration is clean, incremental, and production-ready.
1. Imports
Before you start, gather every library you use throughout this notebook. Pull in core Python utilities for file handling and plotting, PyTorch and TorchVision for deep-learning components, Ray Train for distributed orchestration, Hugging Face Datasets for quick data access, and PyArrow plus Pandas for fast Parquet IO. Importing everything up-front keeps the rest of the tutorial clean and predictable.
# 00. Runtime setup
import os
import sys
import subprocess
# Enable the Ray Train V2 API; set this before importing ray.train
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
# Install Python dependencies
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "--no-cache-dir",
    "torch==2.8.0",
    "torchvision==0.23.0",
    "matplotlib==3.10.6",
    "pyarrow==14.0.2",
    "datasets==2.19.2",
    "torchmetrics",  # used for the distributed accuracy metric in train_loop_per_worker
])
# 01. Imports
# ————————————————————————
# Standard Library Utilities
# ————————————————————————
import os, io, tempfile, shutil # file I/O and temp dirs
import json # reading/writing configs
import random, uuid # randomness and unique IDs
# ————————————————————————
# Core Data & Storage Libraries
# ————————————————————————
import pandas as pd # tabular data handling
import numpy as np # numerical ops
import pyarrow as pa # in-memory columnar format
import pyarrow.parquet as pq # reading/writing Parquet files
from tqdm import tqdm # progress bars
# ————————————————————————
# Image Handling & Visualization
# ————————————————————————
from PIL import Image
import matplotlib.pyplot as plt # plotting loss curves, images
# ————————————————————————
# PyTorch + TorchVision Core
# ————————————————————————
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from torchvision.models import resnet18
from torchvision.transforms import Compose, Resize, CenterCrop
# ————————————————————————
# Ray Train: Distributed Training Primitives
# ————————————————————————
import ray
import ray.train as train
from ray.train.torch import (
prepare_model,
prepare_data_loader,
TorchTrainer,
)
from ray.train import (
ScalingConfig,
RunConfig,
FailureConfig,
CheckpointConfig,
Checkpoint,
get_checkpoint,
get_context,
)
# ————————————————————————
# Dataset Access
# ————————————————————————
from datasets import load_dataset # Hugging Face Datasets
from ray.data import DataContext
DataContext.get_current().use_streaming_executor = False
2. Load 10% of Food-101
Next, load the first 10% of the Food-101 training split (about 7,575 of its 75,750 images) with a single call to load_dataset. This trimmed subset trains quickly while still being large enough to demonstrate Ray’s scaling behavior.
# 02. Load 10% of food101 (~7,500 images)
ds = load_dataset("food101", split="train[:10%]")
3. Resize and encode images
Preprocess each image: resize to 256 pixels, center-crop to 224 pixels (the size expected by most ImageNet models), and then convert the result to raw Joint Photographic Experts Group (JPEG) bytes. By storing bytes instead of full Python Imaging Library (PIL) objects, you keep the dataset compact and Parquet-friendly.
# 03. Resize and encode as JPEG bytes
transform = Compose([Resize(256), CenterCrop(224)])
records = []
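for example in tqdm(ds, desc="Preprocessing images", unit="img"):
    try:
        # Convert to RGB first so images in other modes (for example RGBA)
        # encode as 3-channel JPEG instead of failing and being skipped
        img = transform(example["image"].convert("RGB"))
        buf = io.BytesIO()
        img.save(buf, format="JPEG")
        records.append({
            "image_bytes": buf.getvalue(),
            "label": example["label"],
        })
    except Exception:
        # Skip images that still fail to decode or encode
        continue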
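A plain-Python sketch of the geometry behind Resize(256) followed by CenterCrop(224), assuming torchvision’s documented behavior for an integer size (the shorter edge is scaled to 256 and the aspect ratio is kept). The helpers resize_shorter_side and center_crop_box are hypothetical names, and the exact rounding of the longer edge is an assumption of this sketch.

```python
def resize_shorter_side(width, height, size):
    # Scale so the shorter side equals `size`, keeping the aspect ratio
    # (this sketch truncates the longer side to an int)
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop):
    # (left, top, right, bottom) of a crop x crop window centered in the image
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


w, h = resize_shorter_side(512, 384, 256)  # a hypothetical 512x384 photo
box = center_crop_box(w, h, 224)
print((w, h), box)
```

Resizing first and cropping second guarantees a square 224×224 output for any aspect ratio, at the cost of discarding a strip along the longer edge of wide or tall photos.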
4. Visual sanity check
Before committing to a full training run, take nine random samples and plot them with their class names. This quick inspection lets you confirm that labels match their images and that images are correctly resized.
# 04. Visualize the dataset
label_names = ds.features["label"].names # maps int → string
samples = random.sample(records, 9)
fig, axs = plt.subplots(3, 3, figsize=(8, 8))
fig.suptitle("Sample Resized Images from food101-lite", fontsize=16)
for ax, rec in zip(axs.flatten(), samples):
    img = Image.open(io.BytesIO(rec["image_bytes"]))
    label_name = label_names[rec["label"]]
    ax.imshow(img)
    ax.set_title(label_name)
    ax.axis("off")
plt.tight_layout()
plt.show()
5. Persist to Parquet
Write the images and labels to a Parquet file. Because Parquet is columnar, you can read just the columns you need during training, which speeds up IO, especially when multiple workers are reading in parallel under Ray.
# 05. Write Dataset to Parquet
output_dir = "/mnt/cluster_storage/food101_lite/parquet_256"
os.makedirs(output_dir, exist_ok=True)
table = pa.Table.from_pydict({
"image_bytes": [r["image_bytes"] for r in records],
"label": [r["label"] for r in records]
})
pq.write_table(table, os.path.join(output_dir, "shard_0.parquet"))
print(f"Wrote {len(records)} records to {output_dir}")
6. Custom Food101Dataset for Parquet
To feed data into PyTorch, define a custom Dataset. It caches Parquet metadata, maps global row indices to (row group, offset) pairs, and reads only the row group that contains the requested row. Each __getitem__ returns an (image, label) pair that’s immediately ready for further transforms.
# 06. Define PyTorch Dataset that loads from Parquet
class Food101Dataset(Dataset):
    def __init__(self, parquet_path: str, transform=None):
        self.parquet_file = pq.ParquetFile(parquet_path)
        self.transform = transform
        # Precompute a global row index to (row_group_idx, local_idx) map
        self.row_group_map = []
        for rg_idx in range(self.parquet_file.num_row_groups):
            rg_meta = self.parquet_file.metadata.row_group(rg_idx)
            num_rows = rg_meta.num_rows
            self.row_group_map.extend([(rg_idx, i) for i in range(num_rows)])
        # Cache of the most recently read row group. Each DataLoader worker process
        # holds its own copy of the dataset, so each keeps its own cache.
        self._cached_rg_idx = None
        self._cached_rg = None
    def __len__(self):
        return len(self.row_group_map)
    def __getitem__(self, idx):
        row_group_idx, local_idx = self.row_group_map[idx]
        # Read only the relevant row group, and reuse it for later rows from the
        # same group instead of re-reading it from disk for every item
        if row_group_idx != self._cached_rg_idx:
            self._cached_rg = self.parquet_file.read_row_group(
                row_group_idx, columns=["image_bytes", "label"]
            )
            self._cached_rg_idx = row_group_idx
        image_bytes = self._cached_rg.column("image_bytes")[local_idx].as_py()
        label = self._cached_rg.column("label")[local_idx].as_py()
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label
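The row_group_map in Food101Dataset stores one (row_group, local_idx) tuple per row, which is fine for a few thousand images. As a sketch of an alternative (hypothetical helpers, not part of the tutorial code), you can keep only the cumulative row counts and find a row’s group with a binary search; the check at the end confirms both approaches agree on a toy file layout.

```python
import bisect
from itertools import accumulate


def build_row_group_map(row_group_sizes):
    # Explicit map, as in Food101Dataset: one (row_group, local_idx) entry per row
    return [(rg, i) for rg, n in enumerate(row_group_sizes) for i in range(n)]


def locate_row(row_group_sizes, idx):
    # Same lookup without a per-row list: binary-search the cumulative row counts
    ends = list(accumulate(row_group_sizes))  # e.g. [3, 5, 9] for sizes [3, 2, 4]
    if not 0 <= idx < ends[-1]:
        raise IndexError(idx)
    rg = bisect.bisect_right(ends, idx)  # first group whose end lies past idx
    start = ends[rg - 1] if rg > 0 else 0
    return rg, idx - start


sizes = [3, 2, 4]  # rows per row group in a hypothetical file
explicit = build_row_group_map(sizes)
print(all(locate_row(sizes, i) == explicit[i] for i in range(sum(sizes))))
```

In a real dataset class, you would compute the cumulative counts once in __init__ rather than on every lookup, which keeps memory proportional to the number of row groups instead of the number of rows.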
7. Image transform
Create a transform pipeline: ToTensor() followed by normalization with the ImageNet mean and standard deviation. Because the dataset applies this transform, every worker, no matter where it runs, processes images in exactly the same way.
# 07. Define data preprocessing transform
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = T.Compose([
T.ToTensor(),
T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
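For intuition about what this transform does to pixel values: ToTensor() scales 8-bit values from 0 to 255 into floats in [0, 1], and Normalize then subtracts each channel’s mean and divides by its standard deviation. The plain-Python sketch below (normalize_pixel is a hypothetical helper) applies that arithmetic to a single RGB pixel.

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    # ToTensor: uint8 value -> float in [0, 1]; Normalize: (value - mean) / std
    return [
        (v / 255.0 - m) / s
        for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    ]


print([round(x, 3) for x in normalize_pixel((124, 116, 104))])
print([round(x, 3) for x in normalize_pixel((255, 255, 255))])
```

A pixel close to the ImageNet mean color, such as (124, 116, 104), maps to values near zero in every channel, while pure white lands roughly two standard deviations above the mean.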
8. Train/validation split
Shuffle the full Parquet table once (seeded for reproducibility) and then slice off the last 500 rows to construct the validation set. Write the train and validation partitions to their own Parquet files so you can load them independently later.
# 08. Create train/val Parquet splits
full_path = "/mnt/cluster_storage/food101_lite/parquet_256/shard_0.parquet"
df = (
pq.read_table(full_path)
.to_pandas()
.sample(frac=1.0, random_state=42) # shuffle for reproducibility
)
df[:-500].to_parquet("/mnt/cluster_storage/food101_lite/train.parquet") # training
df[-500:].to_parquet("/mnt/cluster_storage/food101_lite/val.parquet") # validation
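The split logic (shuffle once with a fixed seed, then hold out the last 500 rows) can be sketched with the standard library alone. This sketch (split_indices is a hypothetical helper) uses random.Random rather than pandas’ sample(random_state=42), so it produces a different permutation; what carries over is the property that the same seed always yields the same disjoint train and validation sets. It assumes, for illustration, that all 7,575 images survived preprocessing.

```python
import random


def split_indices(n, n_val, seed=42):
    # Shuffle all row indices once with a fixed seed, then hold out the last n_val
    # (n_val must be positive: order[:-0] would be empty)
    order = list(range(n))
    random.Random(seed).shuffle(order)
    return order[:-n_val], order[-n_val:]


train_idx, val_idx = split_indices(7575, 500)
print(len(train_idx), len(val_idx))
```

Holding out a fixed tail after a seeded shuffle keeps the validation set stable across reruns, which matters when you compare val_loss between runs or resume from a checkpoint.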
9. Inspect a DataLoader batch
Before you scale out, build a regular single-process DataLoader, pull one batch, and print its shape. This tiny test reassures you that batching, multiprocessing, and transforms work correctly.
# 09. Observe data shape
loader = DataLoader(
Food101Dataset("/mnt/cluster_storage/food101_lite/train.parquet", transform=transform),
batch_size=16,
shuffle=True,
num_workers=4,
)
for images, labels in loader:
    print(images.shape, labels.shape)  # expect torch.Size([16, 3, 224, 224]) torch.Size([16])
    break
10. Helper: Ray-prepared DataLoaders
Wrap the DataLoader with prepare_data_loader.
Ray adds a DistributedSampler so each worker reads its own shard of the data, moves each batch to the worker’s device, and handles the rest of the worker-rank bookkeeping. You don’t need to construct a DistributedSampler yourself.
# 10. Define helper to create prepared DataLoader
def build_dataloader(parquet_path: str, batch_size: int, shuffle=True):
    dataset = Food101Dataset(parquet_path, transform=transform)
    # Let Ray handle DistributedSampler and device placement via prepare_data_loader.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=2,
    )
    return prepare_data_loader(loader)
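prepare_data_loader sets up the sampler for you, but it helps to know roughly how a DistributedSampler divides work. The stdlib sketch below (shard_indices is hypothetical) mimics its behavior with drop_last=False: shuffle with a seed that includes the epoch, pad the index list so it divides evenly across ranks, and give each rank every world_size-th index. PyTorch’s real sampler shuffles with a torch.Generator, so its actual orderings differ from this sketch.

```python
import random


def shard_indices(n, world_size, rank, epoch=0, seed=0, shuffle=True):
    indices = list(range(n))
    # 1. Shuffle with a seed that changes every epoch
    if shuffle:
        random.Random(seed + epoch).shuffle(indices)
    # 2. Pad by repeating leading indices so every rank gets the same count
    per_rank = -(-n // world_size)  # ceil(n / world_size)
    total = per_rank * world_size
    while len(indices) < total:
        indices += indices[: total - len(indices)]
    # 3. Each rank takes every world_size-th index, starting at its own rank
    return indices[rank:total:world_size]


shards = [shard_indices(7075, 8, r, epoch=0) for r in range(8)]
print([len(s) for s in shards])
```

Assuming 7,075 training rows and 8 workers, each rank gets 885 indices and 5 indices are repeated as padding. The epoch-dependent seed is why train_loop_per_worker calls set_epoch: without it, a DistributedSampler reuses the same shuffle every epoch.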
11. train_loop_per_worker
This function defines the per-worker training logic that Ray Train executes on each distributed worker.
Each worker builds its own model, optimizer, and dataloaders; resumes from the most recent Ray-managed checkpoint (if one exists); and then trains and validates the model across epochs.
Key behaviors to note:
Checkpoints are first written to a temporary local directory on the rank-0 worker, then persisted to the run’s configured storage_path by train.report(), so they survive transient node failures and support retries.
Ray Train collects and stores the reported metrics (train and validation loss), so you don’t need manual file writes or JSON logging. Each worker computes losses on its own data shard, and Ray Train keeps the metrics reported by the rank-0 worker.
Fault tolerance is handled by Ray Train’s checkpointing and retry mechanism, configured through RunConfig and FailureConfig.
Final accuracy is computed with torchmetrics’ MulticlassAccuracy, which synchronizes its state across all workers, so the result is a correct global metric rather than a rank-0-only evaluation.
This design keeps the training loop clean, fault-tolerant, and aligned with Ray Train’s built-in distributed orchestration.
# 11. Define Ray Train train_loop_per_worker (tempdir checkpoints + Ray-managed metrics)
def train_loop_per_worker(config):
    import tempfile
    rank = get_context().get_world_rank()
    # === Model ===
    net = resnet18(num_classes=101)  # randomly initialized, one output per Food-101 class
    model = prepare_model(net)  # moves the model to this worker's device and wraps it in DDP
    # === Optimizer / Loss ===
    optimizer = optim.Adam(model.parameters(), lr=config["lr"])
    criterion = nn.CrossEntropyLoss()
    # === Resume from Checkpoint ===
    start_epoch = 0
    ckpt = get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as ckpt_dir:
            # Loading to CPU is fine: load_state_dict copies the tensors into the
            # parameters that prepare_model already placed on this worker's device.
            model.load_state_dict(torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu"))
            opt_path = os.path.join(ckpt_dir, "optimizer.pt")
            if os.path.exists(opt_path):
                optimizer.load_state_dict(torch.load(opt_path, map_location="cpu"))
            meta_path = os.path.join(ckpt_dir, "meta.pt")
            if os.path.exists(meta_path):
                # Continue from the next epoch after the saved one
                start_epoch = int(torch.load(meta_path).get("epoch", -1)) + 1
        if rank == 0:
            print(f"[Rank {rank}] Resumed from checkpoint; starting at epoch {start_epoch}")
    # === DataLoaders ===
    train_loader = build_dataloader(
        "/mnt/cluster_storage/food101_lite/train.parquet", config["batch_size"], shuffle=True
    )
    val_loader = build_dataloader(
        "/mnt/cluster_storage/food101_lite/val.parquet", config["batch_size"], shuffle=False
    )
    # === Training Loop ===
    for epoch in range(start_epoch, config["epochs"]):
        # DistributedSampler seeds its shuffle with the epoch; set it so each epoch reshuffles
        if hasattr(train_loader, "sampler") and hasattr(train_loader.sampler, "set_epoch"):
            train_loader.sampler.set_epoch(epoch)
        model.train()
        train_loss_total, train_batches = 0.0, 0
        for xb, yb in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            train_loss_total += loss.item()
            train_batches += 1
        train_loss = train_loss_total / max(train_batches, 1)
        # === Validation Loop ===
        model.eval()
        val_loss_total, val_batches = 0.0, 0
        with torch.no_grad():
            for val_xb, val_yb in val_loader:
                val_loss_total += criterion(model(val_xb), val_yb).item()
                val_batches += 1
        val_loss = val_loss_total / max(val_batches, 1)
        metrics = {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss}
        if rank == 0:
            print(metrics)
        # ---- Save checkpoint to a local temp dir; Ray persists it via report() ----
        if rank == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                torch.save(model.state_dict(), os.path.join(tmpdir, "model.pt"))
                torch.save(optimizer.state_dict(), os.path.join(tmpdir, "optimizer.pt"))
                torch.save({"epoch": epoch}, os.path.join(tmpdir, "meta.pt"))
                ckpt_out = Checkpoint.from_directory(tmpdir)
                train.report(metrics, checkpoint=ckpt_out)
        else:
            # Non-zero ranks report metrics only (no checkpoint attachment)
            train.report(metrics)
    # === Final validation accuracy (distributed via TorchMetrics) ===
    from torchmetrics.classification import MulticlassAccuracy
    model.eval()
    device = next(model.parameters()).device
    # sync_on_compute=True gathers metric state from all DDP workers inside compute()
    acc_metric = MulticlassAccuracy(
        num_classes=101, average="micro", sync_on_compute=True
    ).to(device)
    with torch.no_grad():
        for xb, yb in val_loader:
            logits = model(xb)
            preds = torch.argmax(logits, dim=1)
            acc_metric.update(preds, yb)
    dist_val_acc = acc_metric.compute().item()
    if rank == 0:
        print(f"Val Accuracy (distributed): {dist_val_acc:.2%}")
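The resume logic in train_loop_per_worker boils down to: read the last completed epoch from the checkpoint and start at the next one. The sketch below isolates that arithmetic with the standard library; it stores the epoch in a hypothetical meta.json file instead of meta.pt, and the helper names are illustrative.

```python
import json
import os
import tempfile


def save_meta(ckpt_dir, epoch):
    # Record the last completed epoch (train_loop_per_worker saves meta.pt instead)
    with open(os.path.join(ckpt_dir, "meta.json"), "w") as f:
        json.dump({"epoch": epoch}, f)


def resume_start_epoch(ckpt_dir):
    # No checkpoint, or a checkpoint without metadata: start from epoch 0
    if ckpt_dir is None:
        return 0
    meta_path = os.path.join(ckpt_dir, "meta.json")
    if not os.path.exists(meta_path):
        return 0
    with open(meta_path) as f:
        # Resume at the epoch after the last completed one
        return int(json.load(f).get("epoch", -1)) + 1


def epochs_to_run(ckpt_dir, total_epochs):
    return list(range(resume_start_epoch(ckpt_dir), total_epochs))


fresh = epochs_to_run(None, 5)  # no checkpoint yet
with tempfile.TemporaryDirectory() as ckpt_dir:
    save_meta(ckpt_dir, epoch=2)  # pretend epochs 0-2 finished before a crash
    resumed = epochs_to_run(ckpt_dir, 5)
print(fresh, resumed)  # [0, 1, 2, 3, 4] [3, 4]
```

If the saved epoch is the last one (4 here), the list is empty: the training loop runs no epochs and the worker goes straight to the final accuracy computation.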
12. Launch distributed training with TorchTrainer
Instantiate a TorchTrainer. Ask for eight GPU workers, enable up to three automatic retries, and tell Ray to keep the five checkpoints with the lowest validation loss. One call to trainer.fit() kicks off a fault-tolerant job on your Anyscale cluster.
# 12. Run Training with Ray Train
trainer = TorchTrainer(
train_loop_per_worker=train_loop_per_worker,
train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 5},
scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
run_config=RunConfig(
name="food101_ft_resume",
storage_path="/mnt/cluster_storage/food101_lite/results",
checkpoint_config=CheckpointConfig(
num_to_keep=5,
checkpoint_score_attribute="val_loss",
checkpoint_score_order="min"
),
failure_config=FailureConfig(max_failures=3),
),
)
result = trainer.fit()
print("Final metrics:", result.metrics)
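# result.checkpoint is the latest checkpoint, not necessarily the best one;
# pick the retained checkpoint with the lowest val_loss instead
ranked = sorted(result.best_checkpoints or [], key=lambda pair: pair[1]["val_loss"])
best_ckpt = ranked[0][0] if ranked else result.checkpoint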
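A simplified model of what CheckpointConfig(num_to_keep=5, checkpoint_score_attribute="val_loss", checkpoint_score_order="min") retains, using made-up losses. It shows why the most recent checkpoint and the best checkpoint can differ. The helper retained_checkpoints is hypothetical and ignores internal details of Ray’s checkpoint manager.

```python
def retained_checkpoints(history, num_to_keep=5):
    # history: (checkpoint_name, val_loss) pairs in the order they were reported.
    # Keep the num_to_keep pairs with the lowest val_loss (score order "min").
    return sorted(history, key=lambda item: item[1])[:num_to_keep]


# Made-up losses for seven epochs, for illustration only
history = [
    ("epoch_0", 2.10), ("epoch_1", 1.72), ("epoch_2", 1.55), ("epoch_3", 1.61),
    ("epoch_4", 1.58), ("epoch_5", 1.66), ("epoch_6", 1.70),
]
kept = retained_checkpoints(history)
best = kept[0][0]        # lowest val_loss among the retained checkpoints
latest = history[-1][0]  # the most recently reported checkpoint
print(best, latest, [name for name, _ in kept])
```

Here the latest checkpoint (epoch_6) is retained but isn’t the best one (epoch_2), which is why selecting by val_loss matters when you load a model for inference.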
13. Plot training and validation loss curves
After training completes, visualize the recorded metrics directly from Ray Train’s results object. No manual CSV handling required.
result.metrics_dataframe contains the metrics reported during training, including the per-epoch loss values reported by the rank-0 worker.
This plot extracts the training and validation losses, groups them by epoch, and displays the most recent report for each.
By comparing these two curves, you quickly assess convergence behavior and detect overfitting (for example, when training loss continues to decrease while validation loss rises).
Because Ray Train automatically stores all metrics and checkpoints, this visualization reflects the same information used to select the best checkpoint based on validation loss in your RunConfig.
# 13. Plot training / validation loss curves
# Pull the full metrics history Ray stored for this run
df = result.metrics_dataframe.copy()
# Keep only the columns we need (guard against extra columns)
cols = [c for c in ["epoch", "train_loss", "val_loss"] if c in df.columns]
df = df[cols].dropna()
# If multiple rows per epoch exist, keep the last report per epoch
if "epoch" in df.columns:
    df = df.sort_index().groupby("epoch", as_index=False).last()
# Plot
plt.figure(figsize=(8, 5))
if "train_loss" in df.columns:
    plt.plot(df["epoch"], df["train_loss"], marker="o", label="Train Loss")
if "val_loss" in df.columns:
    plt.plot(df["epoch"], df["val_loss"], marker="o", label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Train/Val Loss across Epochs")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
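To go beyond eyeballing the curves, a simple heuristic flags the epoch where validation loss starts rising for a few consecutive epochs while training loss keeps falling. The helper below is hypothetical and only one of many possible rules; you could feed it df["train_loss"].tolist() and df["val_loss"].tolist() from the plotting cell.

```python
def overfitting_onset(train_losses, val_losses, patience=2):
    # Count consecutive epochs where training loss fell but validation loss rose;
    # return the epoch where a streak of length `patience` began, else None
    streak = 0
    for epoch in range(1, len(val_losses)):
        train_falling = train_losses[epoch] < train_losses[epoch - 1]
        val_rising = val_losses[epoch] > val_losses[epoch - 1]
        streak = streak + 1 if (train_falling and val_rising) else 0
        if streak >= patience:
            return epoch - patience + 1
    return None


# Toy curves, for illustration only
train_curve = [2.0, 1.6, 1.3, 1.1, 0.9]
val_curve = [1.9, 1.6, 1.5, 1.6, 1.8]
print(overfitting_onset(train_curve, val_curve))
```

Requiring a streak (patience) rather than a single uptick avoids flagging ordinary epoch-to-epoch noise in the validation loss.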
14. Demonstrate fault-tolerant resumption
To prove that checkpointing works, run trainer.fit() a second time without changing anything. If the earlier run crashed mid-epoch, Ray automatically picks up the latest checkpoint and continues. If it already finished, Ray starts a clean new experiment.
# 14. Run the trainer again to demonstrate resuming from latest checkpoint
result = trainer.fit()
print("Final metrics:", result.metrics)
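Ray Train implements recovery internally; the toy simulation below (all names hypothetical) only models the behavior the training loop relies on. After each simulated crash, training restarts from the last completed epoch instead of from scratch, and the run is abandoned once the number of failures exceeds max_failures, mirroring the intent of FailureConfig(max_failures=3).

```python
def make_flaky_train(total_epochs, crash_epochs):
    # Hypothetical training function that "crashes" once in each listed epoch
    crashed = set()

    def train_fn(state):
        # Start after the last completed epoch, as a checkpoint resume would
        for epoch in range(state["last_completed_epoch"] + 1, total_epochs):
            if epoch in crash_epochs and epoch not in crashed:
                crashed.add(epoch)
                raise RuntimeError(f"simulated worker crash in epoch {epoch}")
            state["last_completed_epoch"] = epoch  # "checkpoint" after each epoch

    return train_fn


def run_with_retries(train_fn, max_failures=3):
    # Retry after each failure; give up once failures exceed max_failures
    state = {"last_completed_epoch": -1}
    failures = 0
    while True:
        try:
            train_fn(state)
            return failures, state
        except RuntimeError:
            failures += 1
            if failures > max_failures:
                raise


failures, state = run_with_retries(make_flaky_train(5, {1, 3}), max_failures=3)
print(failures, state)
```

Two crashes cost only the interrupted epochs, not the whole run, because each retry reads the saved progress before continuing.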
15. Batch inference with Ray Data
Define a stateful, GPU-backed batch inference pipeline using Ray Data.
Each actor loads the model once per GPU, keeps it in memory, and performs inference on incoming batches in parallel.
This pattern scales efficiently across multiple GPUs and avoids redundant model loading for every prediction.
# 15. Batch inference with Ray Data (force GPU actors if available on the cluster)
import ray.data as rdata
class ImageBatchPredictor:
    """Stateful per-actor batch predictor that keeps the model in memory."""
    def __init__(self, checkpoint_path: str):
        # Pick the best available device on the ACTOR (worker), not the driver.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # === Load model & weights once per actor ===
        model = resnet18(num_classes=101)
        checkpoint = Checkpoint.from_directory(checkpoint_path)
        with checkpoint.as_directory() as ckpt_dir:
            state_dict = torch.load(
                os.path.join(ckpt_dir, "model.pt"),
                map_location=self.device,
            )
        # Strip the DDP "module." prefix if present
        prefix = "module."
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        model.load_state_dict(state_dict)
        self.model = model.eval().to(self.device)
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])
    def __call__(self, batch):
        """batch: pandas DataFrame with columns ['image_bytes', 'label']"""
        imgs = []
        for b in batch["image_bytes"]:
            img = Image.open(io.BytesIO(b)).convert("RGB")
            imgs.append(self.transform(img))  # (C, H, W) tensor
        x = torch.stack(imgs, dim=0).to(self.device)  # (N, C, H, W)
        # Grad mode is thread-local, so disable autograd here, where inference runs
        with torch.inference_mode():
            logits = self.model(x)
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        out = batch.copy()
        out["predicted_label"] = preds.astype(int)
        return out[["predicted_label", "label"]]
def build_inference_dataset(
    checkpoint_path: str,
    parquet_path: str,
    *,
    num_actors: int = 1,
    batch_size: int = 64,
    use_gpu_actors: bool = True,  # default to GPU actors on the cluster
):
    """
    Create a Ray Dataset pipeline that performs batch inference using
    stateful per-actor model loading. By default, requests 1 GPU per actor
    so each actor runs on a GPU worker (driver may have no GPU).
    """
    ds = rdata.read_parquet(parquet_path, columns=["image_bytes", "label"])
    pred_ds = ds.map_batches(
        ImageBatchPredictor,  # pass the CLASS (stateful actors)
        fn_constructor_args=(checkpoint_path,),  # ctor args for each actor
        batch_size=batch_size,
        batch_format="pandas",
        concurrency=num_actors,  # number of actor workers
        num_gpus=1 if use_gpu_actors else 0,  # force GPU placement on workers
    )
    return pred_ds
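Why pass a class rather than a function to map_batches? A stdlib mock makes the difference visible: the expensive setup lives in __init__, which runs once per instance, and every batch reuses it. CountingPredictor and batched are hypothetical stand-ins for ImageBatchPredictor and Ray Data’s batching; with concurrency=num_actors, Ray Data would construct one such instance per actor.

```python
class CountingPredictor:
    # Hypothetical stand-in for ImageBatchPredictor
    loads = 0  # how many times the "model" has been loaded

    def __init__(self):
        # Expensive setup (loading weights) happens once per instance
        CountingPredictor.loads += 1
        self.bias = 1  # placeholder for real model weights

    def __call__(self, batch):
        # Cheap per-batch work reuses the already-loaded state
        return [x + self.bias for x in batch]


def batched(rows, batch_size):
    # Hypothetical stand-in for how Ray Data splits rows into batches
    for start in range(0, len(rows), batch_size):
        yield rows[start:start + batch_size]


predictor = CountingPredictor()  # what a single actor does once
outputs = [y for batch in batched(list(range(10)), 4) for y in predictor(batch)]
print(CountingPredictor.loads, outputs)
```

Ten rows in three batches trigger a single load; a plain function that loaded the model on every call would load it three times.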
16. Run and visualize Ray Data inference
Use the best checkpoint to run Ray Data inference on the validation set.
Each GPU actor loads the model once and processes predictions in batches; the notebook then displays one validation image alongside its predicted and ground-truth labels for a quick qualitative check.
# 16. Perform inference with Ray Data using the best checkpoint
checkpoint_root = "/mnt/cluster_storage/food101_lite/results/food101_ft_resume"
checkpoint_dirs = sorted(
[
d for d in os.listdir(checkpoint_root)
if d.startswith("checkpoint_") and os.path.isdir(os.path.join(checkpoint_root, d))
],
reverse=True,
)
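if not checkpoint_dirs:
    raise FileNotFoundError("No checkpoint directories found.")
# Use the retained checkpoint with the lowest val_loss (result.checkpoint is only the latest)
ranked = sorted(result.best_checkpoints or [], key=lambda pair: pair[1]["val_loss"])
best_checkpoint = ranked[0][0] if ranked else result.checkpoint
# The run lives on shared storage (/mnt/cluster_storage), so this path is readable from every node
best_ckpt_path = best_checkpoint.path
print("Best checkpoint contents:", os.listdir(best_ckpt_path))
# Keep Ray Data output rows in input order so row `idx` lines up with dataset[idx] below
DataContext.get_current().execution_options.preserve_order = True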
parquet_path = "/mnt/cluster_storage/food101_lite/val.parquet"
# Which item to visualize
idx = 2
# Build a Ray Data inference pipeline (model is loaded once per GPU actor)
pred_ds = build_inference_dataset(
checkpoint_path=best_ckpt_path,
parquet_path=parquet_path,
num_actors=1, # adjust to scale out
batch_size=64, # adjust for throughput
)
# Materialize predictions up to the desired index and grab the row
import itertools
row_iter = pred_ds.iter_rows()
inference_row = next(itertools.islice(row_iter, idx, idx + 1)) # {"predicted_label": ..., "label": ...}
print(inference_row)
# Load label map from Hugging Face (for pretty titles)
ds_tmp = load_dataset("food101", split="train[:1%]") # just to get label names
label_names = ds_tmp.features["label"].names
# Load the raw image locally for visualization
dataset = Food101Dataset(parquet_path, transform=None)
img, _ = dataset[idx]
# Plot the image with predicted and true labels
plt.imshow(img)
plt.axis("off")
plt.title(
f"Pred: {label_names[int(inference_row['predicted_label'])]}\n"
f"True: {label_names[int(inference_row['label'])]}"
)
plt.show()
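17. Clean up
Finally, tidy up by deleting the Ray Train run directory and other intermediate results. The list also covers paths this notebook doesn’t create (a custom checkpoint folder, a metrics CSV, and other run names); the loop skips any path that doesn’t exist. Clearing out old artifacts frees disk space and leaves your workspace clean for whatever comes next.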
# 17. Cleanup---delete checkpoints and metrics from model training
# Base directory
BASE_DIR = "/mnt/cluster_storage/food101_lite"
# Paths to clean
paths_to_delete = [
os.path.join(BASE_DIR, "tmp_checkpoints"), # custom checkpoints
os.path.join(BASE_DIR, "results", "history.csv"), # metrics history file
os.path.join(BASE_DIR, "results", "food101_ft_resume"), # ray trainer run dir
os.path.join(BASE_DIR, "results", "food101_ft_run"),
os.path.join(BASE_DIR, "results", "food101_single_run"),
]
# Delete each path if it exists
for path in paths_to_delete:
    if os.path.exists(path):
        if os.path.isfile(path):
            os.remove(path)
            print(f"Deleted file: {path}")
        else:
            shutil.rmtree(path)
            print(f"Deleted directory: {path}")
    else:
        print(f"Not found (skipped): {path}")
Wrap up and next steps
You’ve taken a realistic computer-vision workload, from raw images all the way to distributed training and GPU inference, and run it on Ray Train with no boilerplate for GPUs, data parallelism, or fault tolerance. This tutorial demonstrates:
Using Ray Train’s TorchTrainer to scale PyTorch training across multiple GPUs and nodes with minimal code changes.
Wrapping models and data loaders with prepare_model() and prepare_data_loader() to enable Ray-managed device placement and distributed execution.
Sharding data across workers and coordinating training epochs across Ray workers.
Configuring automatic checkpointing and failure recovery using Ray Train’s built-in Checkpoint, RunConfig, and FailureConfig APIs.
Running batch inference with Ray Data to scale model predictions across a Ray cluster.
Next steps
Below are a few directions you might explore to adapt or extend the pattern:
Larger or custom datasets
Swap in the full 75,750-image Food-101 training split, or your own dataset in any storage backend (S3, GCS, Azure Blob).
Add multi-file Parquet sharding and let each worker read a different shard.
Model architectures
Drop in Vision Transformers (vit_b_16, vit_l_32) or ConvNeXt; the prepare helpers work exactly the same.
Experiment with transfer learning versus training from scratch.
Mixed precision and performance tuning
Enable automatic mixed precision (torch.cuda.amp) or bfloat16 to speed up training and save memory.
Profile data-loading throughput and play with num_workers, prefetching, and caching.
Hyperparameter sweeps
Wrap the training loop in Ray Tune to search over learning rates, augmentations, or optimizers.
Use Ray’s integrated reporting to schedule early stopping.
Data augmentation pipelines
Integrate additional transforms inside the dataset class for image augmentation.
Compare CPU versus GPU-side augmentations for throughput.
Distributed validation and metrics
Replace your simple accuracy printout with more advanced metrics (F1, top-5 accuracy, confusion matrices).
Model serving
Adapt the ImageBatchPredictor class into a Ray Serve deployment for low-latency online predictions.
Auto-scale replicas based on request volume.
End-to-end MLOps
Register checkpoints in a model registry (for example, MLflow, Weights & Biases, or Ray’s built-in MLflow integration).
Schedule the notebook as a Ray Job or CI/CD pipeline for regular retraining runs.
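As a starting point for the distributed validation and metrics idea above, here’s a plain-Python sketch of top-k accuracy (top_k_accuracy is a hypothetical helper) on toy scores. In the distributed loop, torchmetrics’ MulticlassAccuracy also accepts a top_k argument (pass it the logits rather than argmax predictions), so you could compute a synchronized top-5 accuracy alongside the top-1 value.

```python
def top_k_accuracy(logits_batch, labels, k=5):
    # Fraction of examples whose true label is among the k highest-scoring classes
    hits = 0
    for logits, label in zip(logits_batch, labels):
        top_k = sorted(range(len(logits)), key=lambda c: logits[c], reverse=True)[:k]
        hits += label in top_k
    return hits / len(labels)


# Toy scores for 3 examples over 4 classes
logits_batch = [
    [0.1, 2.0, 0.3, 0.5],  # ranked classes: 1, 3, 2, 0
    [1.5, 0.2, 1.4, 0.0],  # ranked classes: 0, 2, 1, 3
    [0.0, 0.1, 0.2, 3.0],  # ranked classes: 3, 2, 1, 0
]
labels = [1, 2, 0]
print(top_k_accuracy(logits_batch, labels, k=1), top_k_accuracy(logits_batch, labels, k=2))
```

Top-k accuracy rewards near misses, which is informative for a 101-class food problem where visually similar dishes are easy to confuse.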