How to Save and Load Trial Checkpoints#
Trial checkpoints are one of the three types of data stored by Tune. These are user-defined and are meant to snapshot your training progress!
Trial-level checkpoints are saved via the Tune Trainable API: this is how you define your custom training logic, and it’s also where you’ll define which trial state to checkpoint. In this guide, we will show how to save and load checkpoints for Tune’s Function Trainable and Class Trainable APIs, as well as walk you through configuration options.
Function API Checkpointing#
If you are using Ray Tune’s Function API, you can save and load checkpoints as follows.
To create a checkpoint, use the Checkpoint.from_directory() API.
import os
import tempfile

import torch

from ray import train, tune
from ray.train import Checkpoint


def train_func(config):
    start = 1
    my_model = MyModel()  # MyModel is a placeholder for your own torch.nn.Module

    # Load a checkpoint if one was passed in (e.g., when the trial is restored).
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
            start = checkpoint_dict["epoch"] + 1
            my_model.load_state_dict(checkpoint_dict["model_state"])

    for epoch in range(start, config["epochs"] + 1):
        # Model training here
        # ...

        # Report metrics and save a checkpoint at every epoch.
        metrics = {"metric": 1}
        with tempfile.TemporaryDirectory() as tempdir:
            torch.save(
                {"epoch": epoch, "model_state": my_model.state_dict()},
                os.path.join(tempdir, "checkpoint.pt"),
            )
            train.report(metrics=metrics, checkpoint=Checkpoint.from_directory(tempdir))


tuner = tune.Tuner(train_func, param_space={"epochs": 5})
result_grid = tuner.fit()
In the above code snippet:

- We implement checkpoint saving with train.report(..., checkpoint=checkpoint). Note that every checkpoint must be reported alongside a set of metrics, so that checkpoints can be ordered with respect to a specified metric.
- The checkpoint saved during training iteration epoch is written to the path <storage_path>/<exp_name>/<trial_name>/checkpoint_<epoch> on the node where training happens, and can be further synced to a consolidated storage location depending on the storage configuration.
- We implement checkpoint loading with train.get_checkpoint(). This will be populated with a trial’s latest checkpoint whenever Tune restores a trial, which happens when (1) a trial is configured to retry after encountering a failure, (2) the experiment is being restored, or (3) the trial is being resumed after a pause (e.g., by PBT). A sketch of case (1) follows this list.
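For example, case (1) can be set up with a FailureConfig, so that a failed trial is retried and restored from its latest checkpoint. The following is a minimal, illustrative sketch (not part of the original example); it assumes the train_func defined above, and the retry count and metric name are placeholders:

from ray import train, tune

# Retry failed trials up to 2 times; on a retry, `train.get_checkpoint()`
# returns the trial's latest reported checkpoint so training resumes from it.
tuner = tune.Tuner(
    train_func,
    param_space={"epochs": 5},
    run_config=train.RunConfig(failure_config=train.FailureConfig(max_failures=2)),
)
result_grid = tuner.fit()

# Because every checkpoint is reported together with metrics, the best
# checkpoint can be looked up after the run:
best_result = result_grid.get_best_result(metric="metric", mode="max")
best_checkpoint = best_result.checkpoint  # a ray.train.Checkpoint (or None)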
Note

checkpoint_frequency and checkpoint_at_end will not work with Function API checkpointing. Instead, checkpointing is configured manually inside the Function Trainable. For example, if you want to checkpoint every three epochs, you can do so as follows:
NUM_EPOCHS = 12
# checkpoint every three epochs.
CHECKPOINT_FREQ = 3


def train_func(config):
    for epoch in range(1, config["epochs"] + 1):
        # Model training here
        # ...

        # Report metrics and save a checkpoint
        metrics = {"metric": "my_metric"}
        if epoch % CHECKPOINT_FREQ == 0:
            with tempfile.TemporaryDirectory() as tempdir:
                # Save a checkpoint in tempdir.
                train.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
        else:
            train.report(metrics)


tuner = tune.Tuner(train_func, param_space={"epochs": NUM_EPOCHS})
result_grid = tuner.fit()
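Since checkpoint_at_end is also unavailable for the Function API, a final checkpoint has to be written manually as well. Below is a minimal sketch of one way to do it (the function name is ours and the model-saving step is elided), assuming the imports and constants from the snippets above:

def train_func_with_final_checkpoint(config):
    for epoch in range(1, config["epochs"] + 1):
        # Model training here
        # ...

        metrics = {"metric": "my_metric"}
        # Checkpoint every CHECKPOINT_FREQ epochs and on the last epoch,
        # mimicking `checkpoint_at_end` behavior for the Function API.
        if epoch % CHECKPOINT_FREQ == 0 or epoch == config["epochs"]:
            with tempfile.TemporaryDirectory() as tempdir:
                # Save your model state into `tempdir` here.
                train.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
        else:
            train.report(metrics)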
Class API Checkpointing#
You can also implement checkpoint/restore using the Trainable Class API:
import os

import torch
from torch import nn

from ray import train, tune


class MyTrainableClass(tune.Trainable):
    def setup(self, config):
        self.model = nn.Sequential(
            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
        )

    def step(self):
        return {}

    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))


tuner = tune.Tuner(
    MyTrainableClass,
    param_space={"input_size": 64},
    run_config=train.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=train.CheckpointConfig(checkpoint_frequency=2),
    ),
)
tuner.fit()
You can checkpoint with three different mechanisms: manually, periodically, and at termination.
Manual Checkpointing#
A custom Trainable can manually trigger checkpointing by returning should_checkpoint: True (or tune.result.SHOULD_CHECKPOINT: True) in the result dictionary of step.
This can be especially helpful when training on spot instances:
import random


# To be implemented by the user.
def detect_instance_preemption():
    choice = random.randint(1, 100)
    # Simulating a 1% chance of preemption.
    return choice <= 1


# `step` method of your custom tune.Trainable subclass.
def step(self):
    # Training code here
    # ...
    result = {"mean_accuracy": "my_accuracy"}
    if detect_instance_preemption():
        result.update(should_checkpoint=True)
    return result
In the above example, if detect_instance_preemption returns True, Tune is instructed to save a checkpoint after that training iteration.
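The tune.result.SHOULD_CHECKPOINT constant mentioned above is simply the string "should_checkpoint", so the same result can be written as follows. A minimal sketch; the accuracy value is a placeholder:

from ray.tune.result import SHOULD_CHECKPOINT  # the string "should_checkpoint"


# `step` method of your custom tune.Trainable subclass.
def step(self):
    result = {"mean_accuracy": 0.9}  # placeholder value
    if detect_instance_preemption():
        # Equivalent to result["should_checkpoint"] = True
        result[SHOULD_CHECKPOINT] = True
    return result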
Periodic Checkpointing#
Periodic checkpointing can be enabled by setting checkpoint_frequency=N in CheckpointConfig to checkpoint trials every N training iterations, e.g.:
tuner = tune.Tuner(
    MyTrainableClass,
    run_config=train.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=train.CheckpointConfig(checkpoint_frequency=10),
    ),
)
tuner.fit()
Checkpointing at Termination#
Checkpoints triggered by checkpoint_frequency may not coincide with the exact end of a trial.
If you want a checkpoint to be created at the end of a trial, you can additionally set checkpoint_at_end=True:
tuner = tune.Tuner(
    MyTrainableClass,
    run_config=train.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=train.CheckpointConfig(
            checkpoint_frequency=10, checkpoint_at_end=True
        ),
    ),
)
tuner.fit()
Configurations#
Checkpointing can be configured through CheckpointConfig.
Some of the configurations do not apply to the Function Trainable API, since checkpointing frequency is determined manually within the user-defined training loop. See the compatibility matrix below, followed by an example sketch.
| CheckpointConfig option | Class API | Function API |
|---|---|---|
| num_to_keep | ✅ | ✅ |
| checkpoint_score_attribute | ✅ | ✅ |
| checkpoint_score_order | ✅ | ✅ |
| checkpoint_frequency | ✅ | ❌ |
| checkpoint_at_end | ✅ | ❌ |
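As an illustrative sketch (not from the original page), these options can be combined as follows for the Class API trainable defined earlier. It assumes the trainable reports a metric named "metric"; the numbers are placeholders:

from ray import train, tune

tuner = tune.Tuner(
    MyTrainableClass,
    run_config=train.RunConfig(
        checkpoint_config=train.CheckpointConfig(
            # Keep only the 5 best checkpoints, ranked by the reported
            # "metric" value (higher is better).
            num_to_keep=5,
            checkpoint_score_attribute="metric",
            checkpoint_score_order="max",
            # Class-API-only options (see the table above).
            checkpoint_frequency=10,
            checkpoint_at_end=True,
        ),
    ),
)
result_grid = tuner.fit()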
Summary#
In this user guide, we covered how to save and load trial checkpoints in Tune. Once checkpointing is enabled, move on to the follow-up user guides to learn how to put these checkpoints to use.
Appendix: Types of data stored by Tune#
Experiment Checkpoints#
Experiment-level checkpoints save the experiment state. This includes the state of the searcher, the list of trials and their statuses (e.g., PENDING, RUNNING, TERMINATED, ERROR), and metadata pertaining to each trial (e.g., the hyperparameter configuration and derived trial results such as min, max, and last).
The experiment-level checkpoint is periodically saved by the driver on the head node. By default, the frequency at which it is saved is automatically adjusted so that at most 5% of the time is spent saving experiment checkpoints, and the remaining time is used for handling training results and scheduling. This period can also be adjusted with the TUNE_GLOBAL_CHECKPOINT_S environment variable (in seconds).
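For example, to cap the experiment-state snapshot period at roughly 60 seconds (an illustrative value), set the variable before starting the run:

import os

# Set this before constructing and running the Tuner.
os.environ["TUNE_GLOBAL_CHECKPOINT_S"] = "60"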
Trial Checkpoints#
Trial-level checkpoints capture the per-trial state. This often includes the model and optimizer states. Following are a few uses of trial checkpoints:

- If the trial is interrupted for some reason (e.g., on spot instances), it can be resumed from the last state. No training time is lost.
- Some searchers or schedulers pause trials to free up resources for other trials to train in the meantime. This only makes sense if the trials can then continue training from the latest state.
- The checkpoint can be later used for other downstream tasks like batch inference (see the sketch below).
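For instance, a checkpoint written by the Function API example earlier can be reopened later for batch inference. A minimal sketch, assuming the checkpoint.pt layout used above and a hypothetical checkpoint directory path:

import os

import torch

from ray.train import Checkpoint

# Hypothetical path to a trial checkpoint directory; substitute your own.
checkpoint = Checkpoint.from_directory("/path/to/trial_dir/checkpoint_000004")
with checkpoint.as_directory() as checkpoint_dir:
    checkpoint_dict = torch.load(os.path.join(checkpoint_dir, "checkpoint.pt"))
# `checkpoint_dict["model_state"]` can now be loaded into a model for inference.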
Learn how to save and load trial checkpoints in the guide above.
Trial Results#
Metrics reported by trials are saved and logged to their respective trial directories. This data is stored in CSV, JSON, or TensorBoard (events.out.tfevents.*) formats and can be inspected with TensorBoard and used for post-experiment analysis.