Handling Failures and Node Preemption#

Automatically Recover from Train Worker Failures#

Ray Train has built-in fault tolerance to recover from worker failures (i.e., RayActorErrors). When a failure is detected, the workers are shut down and new workers are started in their place.

The training function will be restarted, but progress from the previous execution can be resumed through checkpointing.

Tip

To retain progress across a recovery, your training function must implement logic for both saving and loading checkpoints.
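As a framework-free illustration of that requirement, the sketch below uses only the standard library. The names train_fn, STATE_FILE, and fail_at_epoch are hypothetical and not part of Ray Train; in a real training function the load step would go through train.get_checkpoint() and the save step through train.report(). The point it tries to show is that the function is rerun from the top after a failure, so it has to look for saved state before it starts looping.

```python
import json
import os
import tempfile

STATE_FILE = "state.json"  # hypothetical file name, for illustration only


def train_fn(checkpoint_dir, num_epochs=5, fail_at_epoch=None):
    """Run (or resume) a toy training loop that checkpoints every epoch."""
    state_path = os.path.join(checkpoint_dir, STATE_FILE)

    # Checkpoint loading: pick up where the previous execution stopped.
    start_epoch, total = 0, 0
    if os.path.exists(state_path):
        with open(state_path) as f:
            state = json.load(f)
        start_epoch, total = state["epoch"] + 1, state["total"]

    for epoch in range(start_epoch, num_epochs):
        if fail_at_epoch is not None and epoch == fail_at_epoch:
            raise RuntimeError(f"simulated worker failure at epoch {epoch}")
        total += epoch  # stand-in for a real training step

        # Checkpoint saving: persist everything needed to resume.
        with open(state_path, "w") as f:
            json.dump({"epoch": epoch, "total": total}, f)

    return start_epoch, total


with tempfile.TemporaryDirectory() as ckpt_dir:
    try:
        train_fn(ckpt_dir, fail_at_epoch=3)  # first execution dies at epoch 3
    except RuntimeError:
        pass
    # The restarted execution finds the epoch-2 state and continues at epoch 3.
    resumed_from, final_total = train_fn(ckpt_dir)
```

Without the loading branch, the second call would start again at epoch 0 and repeat the work that was already done.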

Each instance of recovery from a worker failure is considered a retry. The number of retries is configurable through the max_failures attribute of the FailureConfig argument set in the RunConfig passed to the Trainer:

from ray.train import RunConfig, FailureConfig


# Tries to recover a run up to this many times.
run_config = RunConfig(failure_config=FailureConfig(max_failures=2))

# No limit on the number of retries.
run_config = RunConfig(failure_config=FailureConfig(max_failures=-1))
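The retry budget can be pictured with the plain-Python sketch below. It is not Ray Train's implementation; run_with_retries and make_flaky are hypothetical names, and the sketch only mirrors the meaning of max_failures described above: a non-negative value caps the number of recoveries, and -1 removes the cap.

```python
def run_with_retries(fn, max_failures):
    """Call fn until it succeeds or the retry budget is exhausted.

    Returns a tuple of (result, number_of_retries_used).
    """
    failures = 0
    while True:
        try:
            return fn(), failures
        except RuntimeError:
            failures += 1
            # -1 means "no limit"; otherwise give up once the number of
            # failures exceeds the configured budget.
            if max_failures != -1 and failures > max_failures:
                raise


def make_flaky(num_failures):
    """Build a function that fails num_failures times, then succeeds."""
    remaining = {"n": num_failures}

    def flaky():
        if remaining["n"] > 0:
            remaining["n"] -= 1
            raise RuntimeError("simulated worker failure")
        return "done"

    return flaky


# Two failures fit within a budget of two retries, so the run completes.
result, retries = run_with_retries(make_flaky(2), max_failures=2)
```

Under this reading, a third failure with max_failures=2 would surface the error to the caller instead of triggering another recovery.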

Which checkpoint will be restored?#

Ray Train will automatically resume training from the latest available checkpoint reported to Ray Train.

This will be the last checkpoint passed to train.report().

Restore a Ray Train Experiment#

At the experiment level, Trainer restoration allows you to resume a previously interrupted experiment from where it left off.

A Train experiment may be interrupted due to one of the following reasons:

  • The experiment was manually interrupted (e.g., Ctrl+C, or pre-empted head node instance).

  • The head node crashed (e.g., OOM or some other runtime error).

  • The entire cluster went down (e.g., network error affecting all nodes).

Trainer restoration is possible for all of Ray Train’s built-in trainers, but we use TorchTrainer in the examples for demonstration. We also use <Framework>Trainer to refer to methods that are shared across all built-in trainers.

Let’s say your initial Train experiment is configured as follows. The actual training loop is just for demonstration purposes: the important detail is that saving and loading checkpoints has been implemented.

import os
import tempfile
from typing import Dict, Optional

import torch

import ray
from ray import train
from ray.train import Checkpoint
from ray.train.torch import TorchTrainer


def get_datasets() -> Dict[str, ray.data.Dataset]:
    return {"train": ray.data.from_items([{"x": i, "y": 2 * i} for i in range(10)])}


def train_loop_per_worker(config: dict):
    from torchvision.models import resnet18

    model = resnet18()

    # Checkpoint loading
    checkpoint: Optional[Checkpoint] = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
            model.load_state_dict(model_state_dict)

    model = train.torch.prepare_model(model)

    train_ds = train.get_dataset_shard("train")

    for epoch in range(5):
        # Do some training...

        # Checkpoint saving
        with tempfile.TemporaryDirectory() as tmpdir:
            torch.save(model.module.state_dict(), os.path.join(tmpdir, "model.pt"))
            train.report({"epoch": epoch}, checkpoint=Checkpoint.from_directory(tmpdir))


trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    datasets=get_datasets(),
    scaling_config=train.ScalingConfig(num_workers=2),
    run_config=train.RunConfig(
        name="dl_trainer_restore", storage_path=os.path.expanduser("~/ray_results")
    ),
)
result = trainer.fit()

The results and checkpoints of the experiment are saved to the path configured by RunConfig. If the experiment has been interrupted due to one of the reasons listed above, use this path to resume:

from ray.train.torch import TorchTrainer

restored_trainer = TorchTrainer.restore(
    path=os.path.expanduser("~/ray_results/dl_trainer_restore"),
    datasets=get_datasets(),
)

Tip

You can also restore from a remote path (e.g., from an experiment directory stored in an S3 bucket).

original_trainer = TorchTrainer(
    # ...
    run_config=train.RunConfig(
        # Configure cloud storage
        storage_path="s3://results-bucket",
        name="dl_trainer_restore",
    ),
)
result = original_trainer.fit()

restored_trainer = TorchTrainer.restore(
    "s3://results-bucket/dl_trainer_restore",
    datasets=get_datasets(),
)
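In the examples here, the path passed to restore is the storage_path joined with the experiment name from the original RunConfig. The helper below is a hypothetical convenience, not a Ray Train API; it assumes that layout and uses posixpath so that the same join works for local paths and for URI-style remote paths.

```python
import os
import posixpath


def experiment_path(storage_path: str, name: str) -> str:
    """Join a storage path and an experiment name into a restore path."""
    # Expand "~" for local paths; URIs such as "s3://..." are left unchanged.
    return posixpath.join(os.path.expanduser(storage_path), name)


remote_path = experiment_path("s3://results-bucket", "dl_trainer_restore")
local_path = experiment_path("~/ray_results", "dl_trainer_restore")
```

Deriving the path from the same two values that were given to RunConfig avoids a mismatch between where the experiment was written and where restore looks for it.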

Note

Different trainers may allow more parameters to be optionally re-specified on restore. Only datasets must be re-specified on restore, and only if they were supplied originally.

TorchTrainer.restore, TensorflowTrainer.restore, and HorovodTrainer.restore can take in the same parameters as their parent class’s DataParallelTrainer.restore.

Unless otherwise specified, other trainers will accept the same parameters as BaseTrainer.restore.

Auto-resume#

Adding the branching logic below lets you run the same script after an interruption, picking up training from where you left off on the previous run. Notice that we use the <Framework>Trainer.can_restore utility method to determine the existence and validity of the given experiment directory.

experiment_path = os.path.expanduser("~/ray_results/dl_restore_autoresume")
if TorchTrainer.can_restore(experiment_path):
    trainer = TorchTrainer.restore(experiment_path, datasets=get_datasets())
else:
    trainer = TorchTrainer(
        train_loop_per_worker=train_loop_per_worker,
        datasets=get_datasets(),
        scaling_config=train.ScalingConfig(num_workers=2),
        run_config=train.RunConfig(
            storage_path=os.path.expanduser("~/ray_results"),
            name="dl_restore_autoresume",
        ),
    )
result = trainer.fit()
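Since can_restore and restore are shared across the built-in trainers, the branching can be factored into a small helper. restore_or_create below is a hypothetical function, not part of Ray Train, and FakeTrainer is a stand-in used only so the sketch runs without a cluster; with Ray installed you would pass TorchTrainer (or another <Framework>Trainer) instead.

```python
def restore_or_create(trainer_cls, experiment_path, restore_kwargs, create_kwargs):
    """Return a restored trainer if the experiment exists, else a new one."""
    if trainer_cls.can_restore(experiment_path):
        return trainer_cls.restore(experiment_path, **restore_kwargs)
    return trainer_cls(**create_kwargs)


class FakeTrainer:
    """Stand-in exposing the same can_restore/restore surface as a Trainer."""

    existing = set()  # experiment paths that "exist" for this demo

    def __init__(self, restored=False, **kwargs):
        self.restored = restored
        self.kwargs = kwargs

    @classmethod
    def can_restore(cls, path):
        return path in cls.existing

    @classmethod
    def restore(cls, path, **kwargs):
        return cls(restored=True, **kwargs)


# First run: nothing to restore, so a fresh trainer is constructed.
first = restore_or_create(FakeTrainer, "exp", {"datasets": {}}, {"datasets": {}})
FakeTrainer.existing.add("exp")  # pretend the first run wrote its results
# Second run: the experiment directory is found, so the trainer is restored.
second = restore_or_create(FakeTrainer, "exp", {"datasets": {}}, {"datasets": {}})
```

Either way the caller ends up with a single trainer object and calls fit() on it once.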

See also

See the BaseTrainer.restore docstring for a full example.

Note

<Framework>Trainer.restore is different from <Framework>Trainer(..., resume_from_checkpoint=...). resume_from_checkpoint is meant for starting a new Train experiment, which writes results to a new directory and starts over from iteration 0.

<Framework>Trainer.restore is used to continue an existing experiment, where new results will continue to be appended to existing logs.