Validating checkpoints asynchronously#
During training, you may want to validate the model periodically to monitor training progress. The standard way to do this is to switch between training and validation within the training loop. Instead, Ray Train allows you to asynchronously validate the model in a separate Ray task, which has the following benefits:
Running validation in parallel without blocking the training loop
Running validation on different hardware than training
Leveraging autoscaling to launch user-specified machines only for the duration of the validation
Letting training continue immediately after saving a checkpoint with partial metrics (for example, loss) and then receiving validation metrics (for example, accuracy) as soon as they are available. If the initial and validated metrics share the same key, the validated metrics overwrite the initial metrics, as the sketch after this list shows.
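The following minimal sketch illustrates the merge behavior described in the last point; the metric names and values are hypothetical.

# Hypothetical sketch of how checkpoint metrics evolve (values for illustration only).
training_metrics = {"loss": 0.31, "score": 0.0}  # reported with the checkpoint
validation_metrics = {"score": 0.87}             # returned later by validate_fn

# Ray Train associates the combined metrics with the checkpoint; the validated
# value overwrites the initial value for the shared "score" key.
final_checkpoint_metrics = {**training_metrics, **validation_metrics}
# -> {"loss": 0.31, "score": 0.87}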
Tutorial#
First, define a validate_fn that takes a ray.train.Checkpoint to validate and an optional validate_config dictionary. This dictionary can contain arguments needed for validation, such as the validation dataset. Your function should return a dictionary of metrics from that validation. The following is a simple example for teaching purposes only. It's impractical because the validation task always runs on CPU; for a more realistic example, see Write a distributed validation function.
import os

import torch

import ray.train


def validate_fn(checkpoint: ray.train.Checkpoint, config: dict) -> dict:
    # Load the checkpoint
    model = ...
    with checkpoint.as_directory() as checkpoint_dir:
        model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
        model.load_state_dict(model_state_dict)
        model.eval()

    # Perform validation on the data
    total_accuracy = 0
    dataset = config["dataset"]
    with torch.no_grad():
        for batch in dataset.iter_torch_batches(batch_size=128):
            images, labels = batch["image"], batch["label"]
            outputs = model(images)
            total_accuracy += (outputs.argmax(1) == labels).sum().item()
    return {"score": total_accuracy / dataset.count()}
Warning
Don't pass large objects to the validate_fn because Ray Train runs it as a Ray task and serializes all captured variables. Instead, package large objects in the Checkpoint and access them from shared storage later, as explained in Saving and Loading Checkpoints.
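For example, rather than capturing a large object such as a tokenizer in validate_fn or validate_config, you can write it into the checkpoint directory before reporting and reload it inside validate_fn. The file name and helper below are hypothetical.

import os
import pickle

import ray.train

# Hypothetical sketch: package a large auxiliary object with the checkpoint
# instead of capturing it in validate_fn.
tokenizer = ...             # placeholder for a large object
local_checkpoint_dir = ...  # same directory passed to Checkpoint.from_directory
with open(os.path.join(local_checkpoint_dir, "tokenizer.pkl"), "wb") as f:
    pickle.dump(tokenizer, f)

# Inside validate_fn, reload the object from the checkpoint:
def load_tokenizer(checkpoint: ray.train.Checkpoint):
    with checkpoint.as_directory() as checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "tokenizer.pkl"), "rb") as f:
            return pickle.load(f)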
Next, within your training loop, call ray.train.report() with validate_fn and validate_config as arguments from the rank 0 worker, like the following:
import os
import tempfile

import torch

import ray.data
import ray.train.torch


def train_func(config: dict) -> None:
    ...
    epochs = ...
    model = ...
    rank = ray.train.get_context().get_world_rank()
    for epoch in epochs:
        ...  # training step
        if rank == 0:
            training_metrics = {"loss": ..., "epoch": epoch}
            local_checkpoint_dir = tempfile.mkdtemp()
            torch.save(
                model.module.state_dict(),
                os.path.join(local_checkpoint_dir, "model.pt"),
            )
            ray.train.report(
                training_metrics,
                checkpoint=ray.train.Checkpoint.from_directory(local_checkpoint_dir),
                checkpoint_upload_mode=ray.train.CheckpointUploadMode.ASYNC,
                validate_fn=validate_fn,
                validate_config={
                    "dataset": config["validation_dataset"],
                    "train_run_name": ray.train.get_context().get_experiment_name(),
                    "epoch": epoch,
                },
            )
        else:
            ray.train.report({}, checkpoint=None)


def run_trainer() -> ray.train.Result:
    train_dataset = ray.data.read_parquet(...)
    validation_dataset = ray.data.read_parquet(...)
    trainer = ray.train.torch.TorchTrainer(
        train_func,
        # Pass training dataset in datasets arg to split it across training workers
        datasets={"train": train_dataset},
        # Pass validation dataset in train_loop_config so validate_fn can choose how to use it later
        train_loop_config={"validation_dataset": validation_dataset},
        scaling_config=ray.train.ScalingConfig(
            num_workers=2,
            use_gpu=True,
            # Use powerful GPUs for training
            accelerator_type="A100",
        ),
    )
    return trainer.fit()
Finally, after training is done, you can access your checkpoints and their associated metrics with the ray.train.Result object. See Inspecting Training Results for more details.
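As a minimal sketch, assuming the run_trainer function from the example above, inspecting the returned Result might look like this:

# Sketch of inspecting the Result returned by trainer.fit().
result = run_trainer()

print(result.metrics)  # metrics attached to the latest reported checkpoint
for checkpoint, metrics in result.best_checkpoints:
    # Once validation tasks finish, these metrics include the values
    # returned by validate_fn (for example, "score").
    print(checkpoint.path, metrics)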
Write a distributed validation function#
The validate_fn above runs in a single Ray task, but you can improve its performance by spawning even more Ray tasks or actors. The Ray team recommends doing this with one of the following approaches:
Creating a ray.train.torch.TorchTrainer that only does validation, not training.
Using ray.data.Dataset.map_batches() to calculate metrics on a validation set.
Choose an approach#
You should use TorchTrainer if:
You want to keep your existing validation logic and avoid migrating to Ray Data. The training function API lets you fully customize the validation loop to match your current setup.
Your validation code depends on running within a Torch process group — for example, your metric aggregation logic uses collective communication calls, or your model parallelism setup requires cross-GPU communication during the forward pass.
You should use map_batches if:
You care about validation performance. Preliminary benchmarks show that map_batches is faster.
You prefer Ray Data's native metric aggregation APIs over PyTorch, where you must implement aggregation manually using low-level collective operations or rely on third-party libraries such as torchmetrics; the sketch after this list illustrates the manual approach.
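As an illustration of that last point, manually aggregating a metric across workers inside a Torch process group typically looks like the following sketch (an assumption for illustration, not part of the examples below), whereas Ray Data exposes aggregations such as Dataset.mean() directly.

import torch
import torch.distributed as dist

# Hypothetical manual aggregation with low-level collectives: sum the local
# correct and total counts across all workers, then compute the global accuracy.
def aggregate_accuracy(local_correct: int, local_total: int) -> float:
    stats = torch.tensor([float(local_correct), float(local_total)])
    dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    return (stats[0] / stats[1]).item()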
Example: validation with Ray Train TorchTrainer#
Here is a validate_fn that uses a TorchTrainer to calculate average cross-entropy loss on a validation set. Note the following about this example:
It reports a dummy checkpoint so that the TorchTrainer keeps the metrics.
While you typically use the TorchTrainer for training, you can use it solely for validation, like in this example.
Because training generally has a higher GPU memory requirement than inference, you can set different resource requirements for training and validation, for example, A100 for training and A10G for validation.
import torchmetrics
from torch.nn import CrossEntropyLoss

import ray.train.torch


def eval_only_train_fn(config_dict: dict) -> None:
    # Load the checkpoint
    model = ...
    with config_dict["checkpoint"].as_directory() as checkpoint_dir:
        model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
        model.load_state_dict(model_state_dict)
        model.cuda().eval()

    # Set up metrics and data loaders
    criterion = CrossEntropyLoss()
    mean_valid_loss = torchmetrics.MeanMetric().cuda()
    test_data_shard = ray.train.get_dataset_shard("validation")
    test_dataloader = test_data_shard.iter_torch_batches(batch_size=128)

    # Compute and report metric
    with torch.no_grad():
        for batch in test_dataloader:
            images, labels = batch["image"], batch["label"]
            outputs = model(images)
            loss = criterion(outputs, labels)
            mean_valid_loss(loss)
    ray.train.report(
        metrics={"score": mean_valid_loss.compute().item()},
        checkpoint=ray.train.Checkpoint(
            ray.train.get_context()
            .get_storage()
            .build_checkpoint_path_from_name("placeholder")
        ),
        checkpoint_upload_mode=ray.train.CheckpointUploadMode.NO_UPLOAD,
    )


def validate_fn(checkpoint: ray.train.Checkpoint, config: dict) -> dict:
    trainer = ray.train.torch.TorchTrainer(
        eval_only_train_fn,
        train_loop_config={"checkpoint": checkpoint},
        scaling_config=ray.train.ScalingConfig(
            # Use weaker GPUs for validation
            num_workers=2, use_gpu=True, accelerator_type="A10G"
        ),
        # Name validation run to easily associate it with training run
        run_config=ray.train.RunConfig(
            name=f"{config['train_run_name']}_validation_epoch_{config['epoch']}"
        ),
        datasets={"validation": config["dataset"]},
    )
    result = trainer.fit()
    return result.metrics
Example: validation with Ray Data map_batches#
The following is a validate_fn that uses ray.data.Dataset.map_batches() to calculate average accuracy on a validation set. To learn more about how to use map_batches for batch inference, see End-to-end: Offline Batch Inference.
class Predictor:
    def __init__(self, checkpoint: ray.train.Checkpoint):
        self.model = ...
        with checkpoint.as_directory() as checkpoint_dir:
            model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
            self.model.load_state_dict(model_state_dict)
            self.model.cuda().eval()

    def __call__(self, batch: dict) -> dict:
        image = torch.as_tensor(batch["image"], dtype=torch.float32, device="cuda")
        label = torch.as_tensor(batch["label"], dtype=torch.float32, device="cuda")
        pred = self.model(image)
        return {"res": (pred.argmax(1) == label).cpu().numpy()}


def validate_fn(checkpoint: ray.train.Checkpoint, config: dict) -> dict:
    # Set name to avoid confusion; default name is "Dataset"
    config["dataset"].set_name("validation")
    eval_res = config["dataset"].map_batches(
        Predictor,
        batch_size=128,
        num_gpus=1,
        fn_constructor_kwargs={"checkpoint": checkpoint},
        concurrency=2,
    )
    mean = eval_res.mean(["res"])
    return {
        "score": mean,
    }
Checkpoint metrics lifecycle#
During the training loop, the following happens to your checkpoints and metrics:
You report a checkpoint with some initial metrics, such as training loss, as well as a validate_fn and validate_config.
Ray Train asynchronously runs your validate_fn with that checkpoint and validate_config in a new Ray task.
When that validation task completes, Ray Train associates the metrics returned by your validate_fn with that checkpoint.
After training is done, you can access your checkpoints and their associated metrics with the ray.train.Result object. See Inspecting Training Results for more details.

Figure: How Ray Train populates checkpoint metrics during training and how you access them after training.