Training in Tune (tune.Trainable, session.report)

Training can be done with either the Function API (session.report) or the Class API (tune.Trainable).

For the sake of example, let’s maximize this objective function:

def objective(x, a, b):
    return a * (x ** 0.5) + b
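Because a = 2 is positive and x ** 0.5 grows with x, this objective increases over the 20 steps used in the examples below, so the best score within range(20) is reached at x = 19. A quick check in plain Python, with no Ray involved:

```python
import math


def objective(x, a, b):
    return a * (x ** 0.5) + b


# Scores for the hyperparameters used throughout this page: a=2, b=4.
scores = [objective(x, 2, 4) for x in range(20)]

# The objective is non-decreasing in x when a > 0 ...
assert all(s1 <= s2 for s1, s2 in zip(scores, scores[1:]))

# ... so the last step gives the highest score: 2 * sqrt(19) + 4.
best = max(scores)
print(round(best, 4))  # 12.7178
```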

Tune’s Function API

The Function API allows you to define a custom training function that Tune will run in parallel Ray actor processes, one for each Tune trial.

The config argument of the function is a dictionary that Ray Tune populates automatically with the hyperparameters selected for the trial from the search space.

With the Function API, you can report intermediate metrics by simply calling session.report within the function.

from ray import tune
from ray.air import session


def trainable(config: dict):
    intermediate_score = 0
    for x in range(20):
        intermediate_score = objective(x, config["a"], config["b"])
        session.report({"score": intermediate_score})  # This sends the score to Tune.


tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4})
results = tuner.fit()
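To make the reporting flow concrete, here is a plain-Python sketch that swaps session for a hypothetical RecordingSession stand-in (not part of Ray). It only records each reported dictionary and numbers it, mirroring how training_iteration increments once per session.report call in the Function API.

```python
from typing import Dict, List


def objective(x, a, b):
    return a * (x ** 0.5) + b


class RecordingSession:
    """Hypothetical stand-in for ray.air.session, for illustration only."""

    def __init__(self):
        self.results: List[Dict] = []

    def report(self, metrics: Dict) -> None:
        # Tune adds bookkeeping keys such as training_iteration; mimic that here.
        result = dict(metrics, training_iteration=len(self.results) + 1)
        self.results.append(result)


session = RecordingSession()


def trainable(config: dict):
    for x in range(20):
        session.report({"score": objective(x, config["a"], config["b"])})


trainable({"a": 2, "b": 4})
print(len(session.results), session.results[-1]["training_iteration"])  # 20 20
```

Reporting once at the end, as in the next example, would leave a single recorded result with training_iteration equal to 1.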

Tip

Do not use session.report within a Trainable class.

In the previous example, we reported on every step, but this reporting frequency is configurable. For example, we could instead report only once, at the end, with the final score:

from ray import tune
from ray.air import session


def trainable(config: dict):
    final_score = 0
    for x in range(20):
        final_score = objective(x, config["a"], config["b"])

    session.report({"score": final_score})  # This sends the score to Tune.


tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4})
results = tuner.fit()

It’s also possible to return a final set of metrics to Tune by returning them from your function:

def trainable(config: dict):
    final_score = 0
    for x in range(20):
        final_score = objective(x, config["a"], config["b"])

    return {"score": final_score}  # This sends the score to Tune.

You’ll notice that Ray Tune outputs extra values in addition to the user-reported metrics, such as iterations_since_restore. See "How to use log metrics in Tune?" for an explanation and glossary of these values.

Function API Checkpointing

Many Tune features rely on checkpointing, including certain trial schedulers and fault tolerance. You can save and load checkpoints in Ray Tune as follows:

from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint


def train_func(config):
    epochs = config.get("epochs", 2)
    start = 0
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        last_step = loaded_checkpoint.to_dict()["step"]
        start = last_step + 1

    for step in range(start, epochs):
        # Model training here
        # ...

        # Report metrics and save a checkpoint
        metrics = {"metric": "my_metric"}
        checkpoint = Checkpoint.from_dict({"step": step})
        session.report(metrics, checkpoint=checkpoint)


tuner = tune.Tuner(train_func)
results = tuner.fit()
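The resume logic above hinges on start = last_step + 1: a checkpoint records the last completed step, and the restarted function skips everything up to and including it. The sketch below models that with a plain dictionary standing in for Checkpoint.to_dict(); run_trial and its arguments are hypothetical names used only for illustration, and no Ray code is involved.

```python
from typing import Dict, List, Optional, Tuple


def run_trial(epochs: int, loaded: Optional[Dict],
              crash_after: Optional[int] = None) -> Tuple[List[int], Optional[Dict]]:
    """Return (steps executed, last saved checkpoint) for one simulated run.

    `loaded` plays the role of session.get_checkpoint().to_dict(); the returned
    checkpoint plays the role of Checkpoint.from_dict({"step": step}).
    """
    start = 0
    if loaded:
        start = loaded["step"] + 1
    executed: List[int] = []
    checkpoint = loaded
    for step in range(start, epochs):
        executed.append(step)  # model training would happen here
        checkpoint = {"step": step}  # saved alongside the reported metrics
        if crash_after is not None and step == crash_after:
            break  # simulate the trial failing right after this step
    return executed, checkpoint


# The first attempt fails after step 2; the retry resumes from the saved checkpoint.
first, ckpt = run_trial(epochs=5, loaded=None, crash_after=2)
second, _ = run_trial(epochs=5, loaded=ckpt)
print(first, second)  # [0, 1, 2] [3, 4]
```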

Note

The CheckpointConfig options checkpoint_frequency and checkpoint_at_end will not work with Function API checkpointing; a checkpoint is saved only when you pass one to session.report.

In this example, a checkpoint is saved on every training iteration to <local_dir>/<exp_name>/<trial_name>/checkpoint_<step>.

Tune may also copy or move checkpoints during the course of tuning, so your checkpoint saving and loading code should not depend on absolute paths.
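One way to avoid absolute paths is to build every file path from the checkpoint directory you are handed, never from a fixed location. The standard-library sketch below writes state relative to a directory, moves the whole directory (as Tune might), and loads it again from the new location; save_state and load_state are hypothetical helper names, not Tune APIs.

```python
import json
import os
import shutil
import tempfile


def save_state(checkpoint_dir: str, state: dict) -> None:
    # The file path is derived from the directory we were given, not hard-coded.
    with open(os.path.join(checkpoint_dir, "state.json"), "w") as f:
        json.dump(state, f)


def load_state(checkpoint_dir: str) -> dict:
    with open(os.path.join(checkpoint_dir, "state.json")) as f:
        return json.load(f)


root = tempfile.mkdtemp()
original_dir = os.path.join(root, "checkpoint_a")
os.makedirs(original_dir)
save_state(original_dir, {"step": 3})

# Simulate the checkpoint being relocated after it was written.
moved_dir = shutil.move(original_dir, os.path.join(root, "checkpoint_b"))
restored = load_state(moved_dir)
print(restored)  # {'step': 3}

shutil.rmtree(root)  # clean up the temporary files
```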

See here for more information on creating checkpoints. If you use framework-specific trainers from Ray AIR, see here for references to framework-specific checkpoints such as TensorflowCheckpoint.

Tune’s Trainable Class API

Caution

Do not use session.report within a Trainable class.

The Trainable class API requires you to subclass ray.tune.Trainable. Here’s a simple example of this API:

from ray import air, tune


class Trainable(tune.Trainable):
    def setup(self, config: dict):
        # config (dict): A dict of hyperparameters
        self.x = 0
        self.a = config["a"]
        self.b = config["b"]

    def step(self):  # This is called iteratively.
        score = objective(self.x, self.a, self.b)
        self.x += 1
        return {"score": score}


tuner = tune.Tuner(
    Trainable,
    run_config=air.RunConfig(
        # Train for 20 steps
        stop={"training_iteration": 20},
        checkpoint_config=air.CheckpointConfig(
            # We haven't implemented checkpointing yet. See below!
            checkpoint_at_end=False
        ),
    ),
    param_space={"a": 2, "b": 4},
)
results = tuner.fit()

For a subclass of tune.Trainable, Tune creates a Trainable object in a separate process (using the Ray Actor API) and calls its methods in this order:

  1. setup is invoked once, when training starts.

  2. step is invoked multiple times. Each call executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training.

  3. cleanup is invoked when training is finished.
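The call order in the list above can be modeled with a toy driver. run_lifecycle below is a simplified, hypothetical sketch: Tune's real driver runs the object inside a Ray actor and also handles scheduling, checkpointing, and failures. ToyTrainable mirrors the earlier example without subclassing tune.Trainable, so it runs without Ray; here it calls setup from its own constructor.

```python
from typing import Dict, List


def objective(x, a, b):
    return a * (x ** 0.5) + b


class ToyTrainable:
    """Mirrors the setup/step/cleanup hooks of the Trainable example above."""

    def __init__(self, config: Dict):
        self.calls: List[str] = []
        self.setup(config)

    def setup(self, config: Dict):
        self.calls.append("setup")
        self.x, self.a, self.b = 0, config["a"], config["b"]

    def step(self) -> Dict:
        self.calls.append("step")
        score = objective(self.x, self.a, self.b)
        self.x += 1
        return {"score": score}

    def cleanup(self):
        self.calls.append("cleanup")


def run_lifecycle(config: Dict, stop_iteration: int):
    """Hypothetical driver: setup once, step until the stop criterion, then cleanup."""
    trainable = ToyTrainable(config)
    results = []
    for iteration in range(1, stop_iteration + 1):
        result = trainable.step()
        result["training_iteration"] = iteration
        results.append(result)
    trainable.cleanup()
    return trainable.calls, results


calls, results = run_lifecycle({"a": 2, "b": 4}, stop_iteration=20)
print(calls[0], calls.count("step"), calls[-1])  # setup 20 cleanup
```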

The config argument in the setup method is a dictionary populated automatically by Tune and corresponding to the hyperparameters selected for the trial from the search space.

Tip

As a rule of thumb, the execution time of step should be long enough to keep per-iteration overhead small (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).
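Since one step may wrap several iterations of actual training, one way to follow this rule of thumb is to size that inner loop from a measured time per inner iteration. The helper below is a hypothetical sketch, not a Tune setting; the 10-second default is an assumption standing in for "more than a few seconds".

```python
import math


def inner_iters_per_step(seconds_per_iter: float, min_seconds: float = 10.0) -> int:
    """Fewest inner training iterations whose combined time reaches min_seconds.

    Always returns at least 1. If a single iteration already takes longer than
    a few minutes, that iteration itself should be split instead.
    """
    if seconds_per_iter <= 0:
        raise ValueError("seconds_per_iter must be positive")
    return max(1, math.ceil(min_seconds / seconds_per_iter))


# Fast iterations get batched together; slow ones run one per step.
print(inner_iters_per_step(0.25))  # 40, i.e. about 10 s per step
print(inner_iters_per_step(30.0))  # 1, i.e. about 30 s per step
```

Inside step, you would then run that many inner iterations before returning the result dictionary.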

You’ll notice that Ray Tune outputs extra values in addition to the user-reported metrics, such as iterations_since_restore. See "How to use log metrics in Tune?" for an explanation and glossary of these values.

Class API Checkpointing

You can also implement checkpoint/restore using the Trainable Class API:

import os
import torch
from torch import nn

from ray import air, tune


class MyTrainableClass(tune.Trainable):
    def setup(self, config):
        self.model = nn.Sequential(
            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
        )

    def step(self):
        return {}

    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))


tuner = tune.Tuner(
    MyTrainableClass,
    param_space={"input_size": 64},
    run_config=air.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2),
    ),
)
tuner.fit()

You can checkpoint with three different mechanisms: manually, periodically, and at termination.

Manual Checkpointing: A custom Trainable can manually trigger checkpointing by returning should_checkpoint: True (or tune.result.SHOULD_CHECKPOINT: True) in the result dictionary of step. This can be especially helpful on spot instances:

def step(self):
    # Training code that computes `accuracy` goes here.
    accuracy = ...
    result = {"mean_accuracy": accuracy}
    # detect_instance_preemption() is a placeholder for your own preemption check.
    if detect_instance_preemption():
        result.update(should_checkpoint=True)
    return result

Periodic Checkpointing: Periodic checkpointing can be used to provide fault tolerance for experiments. Enable it by setting checkpoint_frequency=N and max_failures=M to checkpoint trials every N iterations and recover from up to M crashes per trial, e.g.:

tuner = tune.Tuner(
    my_trainable,
    run_config=air.RunConfig(
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10),
        failure_config=air.FailureConfig(max_failures=5))
)
results = tuner.fit()

Checkpointing at Termination: The checkpoint_frequency may not coincide with the exact end of an experiment. If you want a checkpoint to be created at the end of a trial, you can additionally set checkpoint_at_end=True:

tuner = tune.Tuner(
    my_trainable,
    run_config=air.RunConfig(
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10, checkpoint_at_end=True),
        failure_config=air.FailureConfig(max_failures=5))
)
results = tuner.fit()
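To see why checkpoint_at_end matters, the hypothetical helper below models which iterations end up with a checkpoint under the three mechanisms. It is a simplified sketch, assuming a periodic checkpoint after every iteration divisible by checkpoint_frequency, one whenever a result sets should_checkpoint, and one at the final iteration when checkpoint_at_end is set; it ignores failures and restores.

```python
from typing import Iterable, List


def checkpointed_iterations(stop_iteration: int,
                            checkpoint_frequency: int = 0,
                            checkpoint_at_end: bool = False,
                            manual: Iterable[int] = ()) -> List[int]:
    """Return the training iterations at which a checkpoint would be saved."""
    saved = set(manual)  # iterations whose step() returned should_checkpoint=True
    if checkpoint_frequency > 0:
        saved.update(range(checkpoint_frequency, stop_iteration + 1,
                           checkpoint_frequency))
    if checkpoint_at_end:
        saved.add(stop_iteration)
    return sorted(i for i in saved if 1 <= i <= stop_iteration)


# Stopping at iteration 25 with checkpoint_frequency=10 leaves the last 5 unsaved ...
print(checkpointed_iterations(25, checkpoint_frequency=10))  # [10, 20]
# ... unless checkpoint_at_end=True is also set.
print(checkpointed_iterations(25, checkpoint_frequency=10,
                              checkpoint_at_end=True))  # [10, 20, 25]
# A preemption signal at iteration 7 adds a manual checkpoint.
print(checkpointed_iterations(25, 10, True, manual=[7]))  # [7, 10, 20, 25]
```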

Use validate_save_restore to catch save_checkpoint/load_checkpoint errors before execution.

from ray.tune.utils import validate_save_restore

# Both calls should return without raising an error.
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)
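A check of this kind boils down to a save/restore round trip: train one instance, save it, restore the checkpoint into a fresh instance, and confirm both continue from the same state. The sketch below performs that round trip by hand on CounterTrainable, a hypothetical class that uses the same save_checkpoint/load_checkpoint hook names as MyTrainableClass but stores a counter as JSON instead of a PyTorch model, so it runs without Ray or torch.

```python
import json
import os
import tempfile


class CounterTrainable:
    """Hypothetical class using the same hook names as a tune.Trainable subclass."""

    def __init__(self):
        self.iteration = 0

    def step(self):
        self.iteration += 1
        return {"iteration": self.iteration}

    def save_checkpoint(self, tmp_checkpoint_dir):
        with open(os.path.join(tmp_checkpoint_dir, "state.json"), "w") as f:
            json.dump({"iteration": self.iteration}, f)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        with open(os.path.join(tmp_checkpoint_dir, "state.json")) as f:
            self.iteration = json.load(f)["iteration"]


def round_trip(cls, steps=3):
    """Train one instance, save it, restore into a fresh instance, step both."""
    original = cls()
    for _ in range(steps):
        original.step()
    with tempfile.TemporaryDirectory() as tmp:
        path = original.save_checkpoint(tmp)
        restored = cls()
        restored.load_checkpoint(path)
    # Both instances should now continue from the same state.
    return original.step(), restored.step()


print(round_trip(CounterTrainable))  # ({'iteration': 4}, {'iteration': 4})
```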

Advanced: Reusing Actors in Tune

Note

This feature is only for the Trainable Class API.

Your Trainable can often take a long time to start. To avoid paying this cost for every trial, pass tune.TuneConfig(reuse_actors=True) to the Tuner (through its tune_config argument) to reuse the same Trainable Python process and object for multiple hyperparameter configurations.

This requires you to implement Trainable.reset_config, which receives the new set of hyperparameters. It is up to you to apply them correctly to your trainable; return True if the reset succeeded.

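from torch import optim

from ray import tune

# get_data_loaders() and ConvNet are assumed to be defined elsewhere,
# e.g. as in Tune's PyTorch MNIST example.


class PytorchTrainable(tune.Trainable):
    """Train a Pytorch ConvNet."""

    def setup(self, config):
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            momentum=config.get("momentum", 0.9))

    def reset_config(self, new_config):
        # Start the new trial from a freshly initialized model, and rebuild the
        # optimizer so that it tracks the new model's parameters.
        self.model = ConvNet()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=new_config.get("lr", 0.01),
            momentum=new_config.get("momentum", 0.9))
        self.config = new_config
        # Returning True tells Tune the reset succeeded.
        return True

    # step(), save_checkpoint() and load_checkpoint() are omitted for brevity.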
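With reuse_actors=True, the expensive work in setup runs once per actor, and each later trial on that actor goes through reset_config instead. The plain-Python sketch below counts those calls; ReusableSketch and run_trials_on_one_actor are hypothetical names, scheduling is simplified to one actor running trials back to back, and the fallback to a fresh object when reset_config returns False is this sketch's own choice.

```python
from typing import Dict, List


class ReusableSketch:
    """Tracks how often the costly setup runs versus the cheap reset_config."""

    def __init__(self, config: Dict):
        self.setup_calls = 0
        self.reset_calls = 0
        self.setup(config)

    def setup(self, config: Dict):
        self.setup_calls += 1  # e.g. loading data, building the model
        self.config = config

    def reset_config(self, new_config: Dict) -> bool:
        self.reset_calls += 1
        self.config = new_config
        return True  # tell the caller the reset succeeded


def run_trials_on_one_actor(configs: List[Dict]) -> ReusableSketch:
    actor = ReusableSketch(configs[0])
    for config in configs[1:]:
        if not actor.reset_config(config):
            actor = ReusableSketch(config)  # fall back to a fresh object
    return actor


actor = run_trials_on_one_actor([{"lr": 0.1}, {"lr": 0.01}, {"lr": 0.001}])
print(actor.setup_calls, actor.reset_calls, actor.config)  # 1 2 {'lr': 0.001}
```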

Comparing Tune’s Function API and Class API

Here are a few key concepts and what they look like in the Function and Class APIs.

| Concept | Function API | Class API |
|---|---|---|
| Training iteration | Increments on each session.report call | Increments on each Trainable.step call |
| Report metrics | session.report(metrics) | Return metrics from Trainable.step |
| Saving a checkpoint | session.report(..., checkpoint=checkpoint) | Trainable.save_checkpoint |
| Loading a checkpoint | session.get_checkpoint() | Trainable.load_checkpoint |
| Accessing config | Passed as an argument: def train_func(config): | Passed through Trainable.setup |

Advanced Resource Allocation

Trainables can themselves be distributed. If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will want to add more bundles to the PlacementGroupFactory to reserve extra resource slots. For example, if a trainable class requires 1 GPU itself, but also launches 4 actors, each using another GPU, then you should use tune.with_resources like this:

tuner = tune.Tuner(
    tune.with_resources(my_trainable, tune.PlacementGroupFactory([
        {"CPU": 1, "GPU": 1},
        {"GPU": 1},
        {"GPU": 1},
        {"GPU": 1},
        {"GPU": 1}
    ])),
    run_config=air.RunConfig(name="my_trainable")
)

The Trainable also provides the default_resource_request interface to automatically declare the resources per trial based on the given configuration.

It is also possible to specify memory ("memory", in bytes) and custom resource requirements.
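A placement group reserves every bundle in the list, so a trial's total reservation is the sum of its bundles. The plain-Python sketch below adds up the bundles from the example above and converts a memory request to bytes; total_resources is a hypothetical helper, not a Tune API, and 2 GiB is an arbitrary example value.

```python
from collections import Counter
from typing import Dict, List


def total_resources(bundles: List[Dict[str, float]]) -> Dict[str, float]:
    """Sum resource amounts across all placement-group bundles."""
    total: Counter = Counter()
    for bundle in bundles:
        total.update(bundle)  # Counter.update adds amounts key by key
    return dict(total)


bundles = [
    {"CPU": 1, "GPU": 1},  # the trainable itself
    {"GPU": 1}, {"GPU": 1}, {"GPU": 1}, {"GPU": 1},  # four GPU actors it launches
]
print(total_resources(bundles))  # {'CPU': 1, 'GPU': 5}

# Memory is requested in bytes, so 2 GiB is written as:
two_gib = 2 * 1024 ** 3
print(two_gib)  # 2147483648
```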

session (Function API)

air.session.report(metrics, *[, checkpoint])

Report metrics and optionally save a checkpoint.

air.session.get_checkpoint()

Access the session's last checkpoint to resume from if applicable.

air.session.get_trial_name()

Trial name for the corresponding trial.

air.session.get_trial_id()

Trial id for the corresponding trial.

air.session.get_trial_resources()

Trial resources for the corresponding trial.

air.session.get_trial_dir()

Log directory corresponding to the trial directory for a Tune session.

Trainable (Class API)

Constructor

Trainable([config, logger_creator, ...])

Abstract class for trainable models, functions, etc.

Trainable Methods to Implement

setup(config)

Subclasses should override this for custom initialization.

save_checkpoint(checkpoint_dir)

Subclasses should override this to implement save().

load_checkpoint(checkpoint)

Subclasses should override this to implement restore().

step()

Subclasses should override this to implement train().

reset_config(new_config)

Resets configuration without restarting the trial.

cleanup()

Subclasses should override this for any cleanup on stop.

default_resource_request(config)

Provides a static resource requirement for the given configuration.

Tune Trainable Utilities

Tune Data Ingestion Utilities

tune.with_parameters(trainable, **kwargs)

Wrapper for trainables to pass arbitrary large data objects.

Tune Resource Assignment Utilities

tune.with_resources(trainable, resources)

Wrapper for trainables to specify resource requests.

PlacementGroupFactory(bundles[, strategy])

Wrapper class that creates placement groups for trials.

tune.utils.wait_for_gpu([gpu_id, ...])

Checks if a given GPU has freed memory.

Tune Trainable Debugging Utilities

tune.utils.diagnose_serialization(trainable)

Utility for detecting why your trainable function isn't serializing.

tune.utils.validate_save_restore(trainable_cls)

Helper method to check if your Trainable class will resume correctly.