Training (tune.Trainable, tune.report)

Training can be done with either a Class API (tune.Trainable) or function API (tune.report).

For the sake of example, let’s maximize this objective function:

def objective(x, a, b):
    return a * (x ** 0.5) + b

Function API

Here is a simple example of using the function API. You can report intermediate metrics by simply calling tune.report within the provided function.

from ray import tune

def trainable(config):
    # config (dict): A dict of hyperparameters.

    for x in range(20):
        intermediate_score = objective(x, config["a"], config["b"])

        tune.report(score=intermediate_score)  # This sends the score to Tune.

analysis = tune.run(
    trainable,
    config={"a": 2, "b": 4}
)

print("best config: ", analysis.get_best_config(metric="score", mode="max"))

Tip

Do not use tune.report within a Trainable class.

Tune will run this function on a separate thread in a Ray actor process.

Tip

If you want to leverage multi-node data parallel training with PyTorch while using parallel hyperparameter tuning, check out our PyTorch user guide and Tune's distributed PyTorch integrations.

Function API Checkpointing

Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. To use Tune's checkpointing features, you must expose a checkpoint_dir argument in the function signature, and call tune.checkpoint_dir:

import json
import os
import time
from ray import tune

def train_func(config, checkpoint_dir=None):
    start = 0
    if checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
            state = json.loads(f.read())
            start = state["step"] + 1

    for iter in range(start, 100):
        time.sleep(1)

        with tune.checkpoint_dir(step=iter) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            with open(path, "w") as f:
                f.write(json.dumps({"step": iter}))

        tune.report(hello="world", ray="tune")

tune.run(train_func)

Note

checkpoint_freq and checkpoint_at_end will not work with Function API checkpointing.

In this example, checkpoints will be saved by training iteration to local_dir/exp_name/trial_name/checkpoint_<step>. You can restore a single trial checkpoint by using tune.run(restore=<checkpoint_dir>):

analysis = tune.run(
    train_func,
    config={"max_iter": 5},
)
last_ckpt = analysis.trials[0].checkpoint.value
analysis = tune.run(train_func, config={"max_iter": 10}, restore=last_ckpt)

Tune may also copy or move checkpoints during the course of tuning. For this reason, it is important not to depend on absolute paths in the implementation of save.
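For example, a small helper like the hypothetical save_state below (a sketch; it must be called from inside a function trainable) derives every path from the directory that tune.checkpoint_dir yields instead of hard-coding absolute paths:

import json
import os
from ray import tune

def save_state(step, state):
    # Derive the file path from the directory Tune yields; do not persist
    # checkpoint_dir itself (an absolute path) inside the state, since Tune
    # may later move or copy the directory.
    with tune.checkpoint_dir(step=step) as checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f:
            f.write(json.dumps(state))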

Trainable Class API

Caution

Do not use tune.report within a Trainable class.

The Trainable class API requires users to subclass ray.tune.Trainable. Here's a naive example of this API:

from ray import tune

class Trainable(tune.Trainable):
    def setup(self, config):
        # config (dict): A dict of hyperparameters
        self.x = 0
        self.a = config["a"]
        self.b = config["b"]

    def step(self):  # This is called iteratively.
        score = objective(self.x, self.a, self.b)
        self.x += 1
        return {"score": score}

analysis = tune.run(
    Trainable,
    stop={"training_iteration": 20},
    config={
        "a": 2,
        "b": 4
    })

print('best config: ', analysis.get_best_config(metric="score", mode="max"))

When you subclass tune.Trainable, Tune will create a Trainable object on a separate process (using the Ray Actor API).

  1. The setup function is invoked once, when training starts.

  2. step is invoked multiple times. Each time, the Trainable object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training.

  3. cleanup is invoked when training is finished.

Tip

As a rule of thumb, the execution time of step should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).

Class API Checkpointing

You can also implement checkpoint/restore using the Trainable Class API:

import os

import torch
from ray import tune
from ray.tune import Trainable

class MyTrainableClass(Trainable):
    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))

tune.run(MyTrainableClass, checkpoint_freq=2)

You can checkpoint with three different mechanisms: manually, periodically, and at termination.

Manual Checkpointing: A custom Trainable can manually trigger checkpointing by returning should_checkpoint: True (or tune.result.SHOULD_CHECKPOINT: True) in the result dictionary of step. This can be especially helpful on spot instances:

def step(self):
    # training code
    result = {"mean_accuracy": accuracy}
    if detect_instance_preemption():
        result.update(should_checkpoint=True)
    return result

Periodic Checkpointing: Periodic checkpointing can be used to provide fault tolerance for experiments. It can be enabled by setting checkpoint_freq=<int> and max_failures=<int> to checkpoint trials every N iterations and to recover from up to M crashes per trial, e.g.:

tune.run(
    my_trainable,
    checkpoint_freq=10,
    max_failures=5,
)

Checkpointing at Termination: The checkpoint_freq may not coincide with the exact end of an experiment. If you want a checkpoint to be created at the end of a trial, you can additionally set checkpoint_at_end=True:

tune.run(
    my_trainable,
    checkpoint_freq=10,
    checkpoint_at_end=True,
    max_failures=5,
)

Use validate_save_restore to catch save_checkpoint/load_checkpoint errors before execution.

from ray.tune.utils import validate_save_restore

# both of these should return
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)

Advanced: Reusing Actors

Note

This feature is only for the Trainable Class API.

Your Trainable can often take a long time to start. To avoid this, you can do tune.run(reuse_actors=True) to reuse the same Trainable Python process and object for multiple hyperparameters.

This requires you to implement Trainable.reset_config, which receives a new set of hyperparameters. It is up to you to correctly update your trainable's hyperparameters.

import torch.optim as optim
from ray import tune

class PytorchTrainable(tune.Trainable):
    """Train a PyTorch ConvNet."""

    def setup(self, config):
        # get_data_loaders() and ConvNet() are helpers from the PyTorch MNIST example.
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            momentum=config.get("momentum", 0.9))

    def reset_config(self, new_config):
        # Re-create the model for the new configuration and rebuild the optimizer
        # so it tracks the new model's parameters with the updated hyperparameters.
        self.model = ConvNet()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=new_config.get("lr", 0.01),
            momentum=new_config.get("momentum", 0.9))
        self.config = new_config
        return True
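With reset_config in place, actor reuse can then be enabled when launching the experiment. A minimal sketch using the class above (the search space below is illustrative):

analysis = tune.run(
    PytorchTrainable,
    reuse_actors=True,  # Reuse the same actor process across configurations.
    config={
        "lr": tune.grid_search([0.01, 0.1]),
        "momentum": tune.grid_search([0.9, 0.99]),
    },
    stop={"training_iteration": 10},
)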

Advanced Resource Allocation

Trainables can themselves be distributed. If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will want to set extra_cpu or extra_gpu in resources_per_trial to reserve extra resource slots. For example, if a trainable class requires 1 GPU itself, but also launches 4 actors, each using another GPU, then you should set "gpu": 1, "extra_gpu": 4.

tune.run(
    my_trainable,
    name="my_trainable",
    resources_per_trial={
        "cpu": 1,
        "gpu": 1,
        "extra_gpu": 4
    }
)

The Trainable also provides the default_resource_request interface to automatically declare the resources_per_trial based on the given configuration.
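For example, a Trainable that launches its own workers could declare its resources by overriding default_resource_request. A minimal sketch, assuming the Resources class from ray.tune.resources and hypothetical "num_workers" and "use_gpu" config keys:

from ray import tune
from ray.tune.resources import Resources

class MyDistributedTrainable(tune.Trainable):
    @classmethod
    def default_resource_request(cls, config):
        # Reserve one CPU for the trainable itself, plus extra slots for the
        # workers it launches.
        return Resources(
            cpu=1,
            gpu=0,
            extra_cpu=config["num_workers"],
            extra_gpu=int(config["use_gpu"]) * config["num_workers"])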

tune.report / tune.checkpoint_dir (Function API)

ray.tune.report(**kwargs)[source]

Logs all keyword arguments.

import time
from ray import tune

def run_me(config):
    for iter in range(100):
        time.sleep(1)
        tune.report(hello="world", ray="tune")

analysis = tune.run(run_me)

Parameters

**kwargs – Any key value pair to be logged by Tune. Any of these metrics can be used for early stopping or optimization.

ray.tune.checkpoint_dir(step)[source]

Returns a checkpoint dir inside a context.

Store any files related to restoring state within the provided checkpoint dir.

Parameters

step (int) – Index for the checkpoint. Expected to be a monotonically increasing quantity.

import os
import json
import time
from ray import tune

def func(config, checkpoint_dir=None):
    start = 0
    if checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
            state = json.loads(f.read())
            start = state["step"] + 1

    for iter in range(start, 10):
        time.sleep(1)

        with tune.checkpoint_dir(step=iter) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            with open(path, "w") as f:
                f.write(json.dumps({"step": iter}))

        tune.report(hello="world", ray="tune")

Yields

checkpoint_dir (str) – Directory for checkpointing.

New in version 0.8.7.

ray.tune.get_trial_dir()[source]

Returns the directory where trial results are saved.

For function API use only.

ray.tune.get_trial_name()[source]

Trial name for the corresponding trial.

For function API use only.

ray.tune.get_trial_id()[source]

Trial id for the corresponding trial.

For function API use only.
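A minimal sketch of using these helpers inside a function trainable (for illustration only):

from ray import tune

def trainable(config):
    # These helpers are only valid inside a function trainable launched by tune.run.
    trial_dir = tune.get_trial_dir()
    trial_name = tune.get_trial_name()
    trial_id = tune.get_trial_id()
    tune.report(info="{} ({}) writes to {}".format(trial_name, trial_id, trial_dir))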

tune.Trainable (Class API)

class ray.tune.Trainable(config=None, logger_creator=None)[source]

Abstract class for trainable models, functions, etc.

A call to train() on a trainable will execute one logical iteration of training. As a rule of thumb, the execution time of one train call should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).

Calling save() should save the training state of a trainable to disk, and restore(path) should restore a trainable to the given state.

Generally you only need to implement setup, step, save_checkpoint, and load_checkpoint when subclassing Trainable.

Other implementation methods that may be helpful to override are log_result, reset_config, cleanup, and _export_model.

When using Tune, Tune will convert this class into a Ray actor, which runs on a separate process. Tune will also change the current working directory of this process to self.logdir.

_export_model(export_formats, export_dir)[source]

Subclasses should override this to export model.

Parameters
  • export_formats (list) – List of formats that should be exported.

  • export_dir (str) – Directory to place exported models.

Returns

A dict that maps ExportFormats to successfully exported models.
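A minimal sketch of overriding _export_model, assuming a PyTorch model stored as self.model and an illustrative "model" format string:

import os

import torch
from ray.tune import Trainable

class ExportableTrainable(Trainable):
    def _export_model(self, export_formats, export_dir):
        exported = {}
        if "model" in export_formats:
            path = os.path.join(export_dir, "exported_model.pth")
            # Write the model weights into the provided export directory.
            torch.save(self.model.state_dict(), path)
            exported["model"] = path
        return exported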

_log_result(result)[source]

This method is deprecated. Override ‘log_result’ instead.

Changed in version 0.8.7.

_restore(checkpoint)[source]

This method is deprecated. Override ‘load_checkpoint’ instead.

Changed in version 0.8.7.

_save(tmp_checkpoint_dir)[source]

This method is deprecated. Override ‘save_checkpoint’ instead.

Changed in version 0.8.7.

_setup(config)[source]

This method is deprecated. Override ‘setup’ instead.

Changed in version 0.8.7.

_stop()[source]

This method is deprecated. Override ‘cleanup’ instead.

Changed in version 0.8.7.

_train()[source]

This method is deprecated. Override ‘Trainable.step’ instead.

Changed in version 0.8.7.

cleanup()[source]

Subclasses should override this for any cleanup on stop.

If any Ray actors are launched in the Trainable (e.g., with an RLlib trainer), be sure to kill the Ray actor process here.

You can kill a Ray actor by calling actor.__ray_terminate__.remote() on the actor.

New in version 0.8.7.
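A minimal sketch of a cleanup override, assuming a hypothetical helper actor (SomeHelperActor) launched in setup:

import ray
from ray import tune

class MyTrainable(tune.Trainable):
    def setup(self, config):
        # Hypothetical helper actor launched by this trainable.
        self.worker = SomeHelperActor.remote()

    def step(self):
        return {"score": ray.get(self.worker.do_work.remote())}

    def cleanup(self):
        # Terminate the helper actor so its resources are released when the
        # trial stops.
        self.worker.__ray_terminate__.remote()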

classmethod default_resource_request(config)[source]

Provides a static resource requirement for the given configuration.

This can be overridden by sub-classes to set the correct trial resource allocation, so the user does not need to.

@classmethod
def default_resource_request(cls, config):
    return Resources(
        cpu=0,
        gpu=0,
        extra_cpu=config["workers"],
        extra_gpu=int(config["use_gpu"]) * config["workers"])

Returns

A Resources object consumed by Tune for queueing.

Return type

Resources

delete_checkpoint(checkpoint_path)[source]

Deletes local copy of checkpoint.

Parameters

checkpoint_path (str) – Path to checkpoint.

export_model(export_formats, export_dir=None)[source]

Exports model based on export_formats.

Subclasses should override _export_model() to actually export model to local directory.

Parameters
  • export_formats (Union[list,str]) – Format or list of (str) formats that should be exported.

  • export_dir (str) – Optional dir to place the exported model. Defaults to self.logdir.

Returns

A dict that maps ExportFormats to successfully exported models.

get_config()[source]

Returns configuration passed in by Tune.

load_checkpoint(checkpoint)[source]

Subclasses should override this to implement restore().

Warning

In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in Trainable.save_checkpoint may be changed.

If Trainable.save_checkpoint returned a prefixed string, the prefix of the checkpoint string returned by Trainable.save_checkpoint may be changed. This is because trial pausing depends on temporary directories.

The directory structure under the checkpoint_dir provided to Trainable.save_checkpoint is preserved.

See the example below.

class Example(Trainable):
    def save_checkpoint(self, checkpoint_path):
        print(checkpoint_path)
        return os.path.join(checkpoint_path, "my/check/point")

    def load_checkpoint(self, checkpoint):
        print(checkpoint)

>>> trainer = Example()
>>> obj = trainer.save_to_object()  # This is used when PAUSED.
<logdir>/tmpc8k_c_6hsave_to_object/checkpoint_0/my/check/point
>>> trainer.restore_from_object(obj)  # Note the different prefix.
<logdir>/tmpb87b5axfrestore_from_object/checkpoint_0/my/check/point

New in version 0.8.7.

Parameters

checkpoint (str|dict) – If a dict, the value is as returned by save_checkpoint. If a string, it is a checkpoint path that may have a different prefix than that returned by save_checkpoint. The directory structure underneath the checkpoint_dir provided to save_checkpoint is preserved.

log_result(result)[source]

Subclasses can optionally override this to customize logging.

The logging here is done on the worker process rather than the driver. You may want to turn off driver logging via the loggers parameter in tune.run when overriding this function.

New in version 0.8.7.

Parameters

result (dict) – Training result returned by step().
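A minimal sketch of a log_result override; calling the parent implementation keeps Tune's default worker-side logging (an assumption worth verifying for your version):

from ray.tune import Trainable

class MyTrainableClass(Trainable):
    def log_result(self, result):
        # Custom worker-side logging; replace the print with e.g. a call to an
        # experiment tracking client.
        print("iteration", result["training_iteration"], "result", result)
        # Keep the default logging behavior as well.
        super().log_result(result)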

reset_config(new_config)[source]

Resets configuration without restarting the trial.

This method is optional, but can be implemented to speed up algorithms such as PBT, and to allow performance optimizations such as running experiments with reuse_actors=True. Note that self.config needs to be updated to reflect the latest parameter information in Ray logs.

Parameters

new_config (dict) – Updated hyperparameter configuration for the trainable.

Returns

True if reset was successful else False.

classmethod resource_help(config)[source]

Returns a help string for configuring this trainable’s resources.

Parameters

config (dict) – The Trainable's config dict.

restore(checkpoint_path)[source]

Restores training state from a given model checkpoint.

These checkpoints are returned from calls to save().

Subclasses should override load_checkpoint() instead to restore state. This method restores additional metadata saved with the checkpoint.

restore_from_object(obj)[source]

Restores training state from a checkpoint object.

These checkpoints are returned from calls to save_to_object().

save(checkpoint_dir=None)[source]

Saves the current model state to a checkpoint.

Subclasses should override save_checkpoint() instead to save state. This method dumps additional metadata alongside the saved path.

Parameters

checkpoint_dir (str) – Optional dir to place the checkpoint.

Returns

Checkpoint path or prefix that may be passed to restore().

Return type

str
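A minimal sketch of a save/restore round trip outside of tune.run, assuming MyTrainableClass implements setup, step, save_checkpoint, and load_checkpoint:

trainable = MyTrainableClass(config={"a": 1})
trainable.train()
checkpoint_path = trainable.save()  # Writes the checkpoint plus metadata.
trainable.restore(checkpoint_path)  # Restores state and metadata.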

save_checkpoint(tmp_checkpoint_dir)[source]

Subclasses should override this to implement save().

Warning

Do not rely on absolute paths in the implementation of Trainable.save_checkpoint and Trainable.load_checkpoint.

Use validate_save_restore to catch Trainable.save_checkpoint/ Trainable.load_checkpoint errors before execution.

>>> from ray.tune.utils import validate_save_restore
>>> validate_save_restore(MyTrainableClass)
>>> validate_save_restore(MyTrainableClass, use_object_store=True)

New in version 0.8.7.

Parameters

tmp_checkpoint_dir (str) – The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved.

Returns

A dict or string. If string, the return value is expected to be prefixed by tmp_checkpoint_dir. If dict, the return value will be automatically serialized by Tune and passed to Trainable.load_checkpoint().

Examples

>>> print(trainable1.save_checkpoint("/tmp/checkpoint_1"))
"/tmp/checkpoint_1/my_checkpoint_file"
>>> print(trainable2.save_checkpoint("/tmp/checkpoint_2"))
{"some": "data"}
>>> trainable.save_checkpoint("/tmp/bad_example")
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.
save_to_object()[source]

Saves the current model state to a Python object.

It also saves to disk but does not return the checkpoint path.

Returns

Object holding checkpoint data.

setup(config)[source]

Subclasses should override this for custom initialization.

New in version 0.8.7.

Parameters

config (dict) – Hyperparameters and other configs given. Copy of self.config.

step()[source]

Subclasses should override this to implement train().

The return value will be automatically passed to the loggers. Users can also return tune.result.DONE or tune.result.SHOULD_CHECKPOINT as a key to manually trigger termination or checkpointing of this trial. Note that manual checkpointing only works when subclassing Trainables.

New in version 0.8.7.

Returns

A dict that describes training progress.
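A minimal sketch of manually signaling termination from step, assuming a hypothetical _evaluate helper:

from ray.tune import Trainable
from ray.tune.result import DONE

class MyTrainable(Trainable):
    def step(self):
        score = self._evaluate()  # Hypothetical helper returning a metric.
        result = {"score": score}
        if score > 0.95:
            # Manually signal Tune to terminate this trial.
            result[DONE] = True
        return result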

stop()[source]

Releases all resources used by this trainable.

Calls Trainable.cleanup internally. Subclasses should override Trainable.cleanup for custom cleanup procedures.

train()[source]

Runs one logical iteration of training.

Calls step() internally. Subclasses should override step() instead to return results. This method automatically fills the following fields in the result:

done (bool): training is terminated. Filled only if not provided.

time_this_iter_s (float): Time in seconds this iteration took to run. This may be overridden in order to override the system-computed time difference.

time_total_s (float): Accumulated time in seconds for this entire experiment.

experiment_id (str): Unique string identifier for this experiment. This id is preserved across checkpoint / restore calls.

training_iteration (int): The index of this training iteration, e.g. call to train(). This is incremented after step() is called.

pid (str): The pid of the training process.

date (str): A formatted date of when the result was processed.

timestamp (str): A UNIX timestamp of when the result was processed.

hostname (str): Hostname of the machine hosting the training process.

node_ip (str): Node ip of the machine hosting the training process.

Returns

A dict that describes training progress.

property iteration

Current training iteration.

This value is automatically incremented every time train() is called and is automatically inserted into the training result dict.

property logdir

Directory of the results and checkpoints for this Trainable.

Tune will automatically sync this folder with the driver if execution is distributed.

Note that the current working directory will also be changed to this.

property training_iteration

Current training iteration (same as self.iteration).

This value is automatically incremented every time train() is called and is automatically inserted into the training result dict.

property trial_id

Trial ID for the corresponding trial of this Trainable.

This is not set if not using Tune.

trial_id = self.trial_id

property trial_name

Trial name for the corresponding trial of this Trainable.

This is not set if not using Tune.

name = self.trial_name

Distributed Torch

Ray also offers lightweight integrations to distribute your model training on Ray Tune.

ray.tune.integration.torch.DistributedTrainableCreator(func, use_gpu=False, num_workers=1, num_cpus_per_worker=1, backend='gloo', timeout_s=10)[source]

Creates a class that executes distributed training.

Similar to running torch.distributed.launch.

Note that you typically should not instantiate the object created.

Parameters
  • func (callable) – This function is a Tune trainable function. The function must take two arguments, and the second argument must be checkpoint_dir. For example: func(config, checkpoint_dir=None).

  • use_gpu (bool) – Sets resource allocation for workers to 1 GPU if true. Also automatically sets CUDA_VISIBLE_DEVICES for each training worker.

  • num_workers (int) – Number of training workers to include in world.

  • num_cpus_per_worker (int) – Number of CPU resources to reserve per training worker.

  • backend (str) – One of “gloo”, “nccl”.

  • timeout_s (float) – Seconds before the torch process group times out. Useful when machines are unreliable. Defaults to 10 seconds.

Returns

A trainable class object that can be passed to Tune. Resources are automatically set within the object, so users do not need to set resources_per_trial.

Example:

trainable_cls = DistributedTrainableCreator(
    train_func, num_workers=2)
analysis = tune.run(trainable_cls)

ray.tune.integration.torch.distributed_checkpoint_dir(step, disable=False)[source]

ContextManager for creating a distributed checkpoint.

Only checkpoints a file on the “main” training actor, avoiding redundant work.

Parameters
  • step (int) – Used to label the checkpoint

  • disable (bool) – Disable for prototyping.

Yields

path (str) – A path to a directory. This path will be used again when invoking the training_function.

Example:

def train_func(config, checkpoint_dir):
    if checkpoint_dir:
        path = os.path.join(checkpoint_dir, "checkpoint")
        model_state_dict = torch.load(path)

    if epoch % 3 == 0:
        with distributed_checkpoint_dir(step=epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            torch.save(model.state_dict(), path)

ray.tune.integration.torch.is_distributed_trainable()[source]

Returns True if executing within a DistributedTrainable.
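A minimal sketch of using this check inside a training function (for illustration only):

from ray.tune.integration.torch import is_distributed_trainable

def train_func(config, checkpoint_dir=None):
    if is_distributed_trainable():
        # Running inside a DistributedTrainable created by
        # DistributedTrainableCreator; set up distributed-specific state here,
        # e.g. a DistributedSampler (sketch only).
        pass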

tune.DurableTrainable

class ray.tune.DurableTrainable(remote_checkpoint_dir, *args, **kwargs)[source]

Abstract class for a remote-storage backed fault-tolerant Trainable.

Supports checkpointing to and restoring from remote storage. To use this class, implement the same private methods as ray.tune.Trainable.

Warning

This class is currently experimental and may be subject to change.

Run this with Tune as follows. Setting sync_to_driver=False disables syncing to the driver to avoid keeping redundant checkpoints around, as well as preventing the driver from syncing up the same checkpoint.

See tune/trainable.py.

remote_checkpoint_dir

Upload directory (S3 or GS path).

Type

str

storage_client

Tune-internal interface for interacting with external storage.

>>> tune.run(MyDurableTrainable, sync_to_driver=False)

StatusReporter

class ray.tune.function_runner.StatusReporter(result_queue, continue_semaphore, trial_name=None, trial_id=None, logdir=None)[source]

Object passed into your function that you can report status through.

Example

>>> def trainable_function(config, reporter):
>>>     assert isinstance(reporter, StatusReporter)
>>>     reporter(timesteps_this_iter=1)

__call__(**kwargs)[source]

Report updated training status.

Pass in done=True when the training job is completed.

Parameters

kwargs – Latest training result status.

Example

>>> reporter(mean_accuracy=1, training_iteration=4)
>>> reporter(mean_accuracy=1, training_iteration=4, done=True)

Raises

StopIteration – A StopIteration exception is raised if the trial has been signaled to stop.