Training (tune.Trainable, session.report)

Training can be done with either a Function API (session.report) or Class API (tune.Trainable).

For the sake of example, let’s maximize this objective function:

def objective(x, a, b):
    return a * (x ** 0.5) + b
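As a quick sanity check on this objective (plain Python, using the a=2, b=4 values from the examples below): for positive a the square root grows monotonically, so the score rises with x and the best value within range(20) is reached at the last step.

```python
def objective(x, a, b):
    return a * (x ** 0.5) + b

# Scores for the 20 steps used in the training loops below.
scores = [objective(x, a=2, b=4) for x in range(20)]

best_x = max(range(20), key=lambda x: scores[x])
print(best_x)                # 19
print(scores[0])             # 4.0, since sqrt(0) = 0
print(round(scores[19], 4))  # 12.7178, i.e. 2 * sqrt(19) + 4
```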

Function API

The Function API allows you to define a custom training function that Tune will run in parallel Ray actor processes, one for each Tune trial.

With the Function API, you can report intermediate metrics by simply calling session.report() within the function.

from ray import tune
from ray.air import session

def trainable(config: dict):
    intermediate_score = 0
    for x in range(20):
        intermediate_score = objective(x, config["a"], config["b"])
        session.report({"score": intermediate_score})  # This sends the score to Tune.

tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4})
results = tuner.fit()


Do not use session.report() within a Trainable class.

In the previous example, we reported on every step, but this metric reporting frequency is configurable. For example, we could also report only a single time at the end with the final score:

from ray import tune
from ray.air import session

def trainable(config: dict):
    final_score = 0
    for x in range(20):
        final_score = objective(x, config["a"], config["b"])

    session.report({"score": final_score})  # This sends the score to Tune.

tuner = tune.Tuner(trainable, param_space={"a": 2, "b": 4})
results = tuner.fit()

It’s also possible to return a final set of metrics to Tune by returning them from your function:

def trainable(config: dict):
    final_score = 0
    for x in range(20):
        final_score = objective(x, config["a"], config["b"])

    return {"score": final_score}  # This sends the score to Tune.

You’ll notice that Ray Tune will output extra values in addition to the user reported metrics, such as iterations_since_restore. See How to use log metrics in Tune? for an explanation/glossary of these values.

Function API Checkpointing

Many Tune features rely on checkpointing, including the usage of certain Trial Schedulers and fault tolerance. You can save and load checkpoints in Ray Tune in the following manner:

from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint

def train_func(config):
    epochs = config.get("epochs", 2)
    start = 0
    loaded_checkpoint = session.get_checkpoint()
    if loaded_checkpoint:
        last_step = loaded_checkpoint.to_dict()["step"]
        start = last_step + 1

    for step in range(start, epochs):
        # Model training here
        # ...

        # Report metrics and save a checkpoint
        metrics = {"metric": "my_metric"}
        checkpoint = Checkpoint.from_dict({"step": step})
        session.report(metrics, checkpoint=checkpoint)

tuner = tune.Tuner(train_func)
results = tuner.fit()


checkpoint_frequency and checkpoint_at_end will not work with Function API checkpointing.

In this example, checkpoints will be saved by training iteration to <local_dir>/<exp_name>/<trial_name>/checkpoint_<step>.
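The resume logic in train_func boils down to “continue from the step after the last checkpointed one”. Below is a plain-Python model of it that runs without Ray: a dict stands in for the checkpoint’s to_dict() contents, and run_trial with its crash_after argument is a hypothetical helper used only to simulate a trial that fails and is restarted.

```python
from typing import List, Optional, Tuple


def run_trial(epochs: int, checkpoint: Optional[dict],
              crash_after: Optional[int] = None) -> Tuple[List[int], Optional[dict]]:
    """Hypothetical stand-in for train_func; returns (steps_run, last_checkpoint)."""
    start = 0
    if checkpoint:
        # Same rule as train_func: resume right after the last saved step.
        start = checkpoint["step"] + 1

    steps_run: List[int] = []
    for step in range(start, epochs):
        steps_run.append(step)
        checkpoint = {"step": step}  # plays the role of Checkpoint.from_dict({"step": step})
        if crash_after is not None and step == crash_after:
            break  # simulate the trial dying right after this report
    return steps_run, checkpoint


# The first attempt dies after step 2; the restarted attempt resumes at step 3.
first_steps, ckpt = run_trial(epochs=5, checkpoint=None, crash_after=2)
second_steps, ckpt = run_trial(epochs=5, checkpoint=ckpt)
print(first_steps)   # [0, 1, 2]
print(second_steps)  # [3, 4]
```

No step is repeated or skipped across the restart, which is the property the start = last_step + 1 line is there to guarantee.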

Tune may also copy or move checkpoints during the course of tuning. For this reason, it is important not to depend on absolute paths in the implementation of save.
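To see why relative paths matter, consider a checkpoint directory that is moved after saving. This stdlib-only sketch (save_state and load_state are hypothetical helpers, not Tune APIs) writes state relative to whatever directory it is handed and reads it back from the new location:

```python
import json
import os
import shutil
import tempfile


def save_state(checkpoint_dir: str, state: dict) -> str:
    # Only the file name is fixed; the directory is whatever we are handed.
    with open(os.path.join(checkpoint_dir, "state.json"), "w") as f:
        json.dump(state, f)
    return checkpoint_dir


def load_state(checkpoint_dir: str) -> dict:
    # Resolve the file relative to the directory passed in at restore time.
    with open(os.path.join(checkpoint_dir, "state.json")) as f:
        return json.load(f)


root = tempfile.mkdtemp()
original_dir = os.path.join(root, "checkpoint_000003")
os.makedirs(original_dir)
save_state(original_dir, {"step": 3})

# Simulate the checkpoint being relocated after it was written.
os.makedirs(os.path.join(root, "moved"))
moved_dir = shutil.move(original_dir, os.path.join(root, "moved", "checkpoint_000003"))
restored = load_state(moved_dir)
print(restored)  # {'step': 3}
shutil.rmtree(root)
```

Had save_state recorded the absolute path of state.json somewhere and load_state reused it, the restore would fail after the move.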

See here for more information on creating checkpoints. If using framework-specific trainers from Ray AIR, see here for references to framework-specific checkpoints such as TensorflowCheckpoint.

Trainable Class API


Do not use session.report() within a Trainable class.

The Trainable class API will require users to subclass ray.tune.Trainable. Here’s a naive example of this API:

from ray import air, tune

class Trainable(tune.Trainable):
    def setup(self, config: dict):
        # config (dict): A dict of hyperparameters
        self.x = 0
        self.a = config["a"]
        self.b = config["b"]

    def step(self):  # This is called iteratively.
        score = objective(self.x, self.a, self.b)
        self.x += 1
        return {"score": score}

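tuner = tune.Tuner(
    Trainable,
    run_config=air.RunConfig(
        # Train for 20 steps
        stop={"training_iteration": 20},
        checkpoint_config=air.CheckpointConfig(
            # We haven't implemented checkpointing yet. See below!
            checkpoint_at_end=False
        ),
    ),
    param_space={"a": 2, "b": 4},
)
results = tuner.fit()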

When you subclass tune.Trainable, Tune creates a Trainable object in a separate process (using the Ray Actor API):

  1. setup function is invoked once training starts.

  2. step is invoked multiple times. Each time, the Trainable object executes one logical iteration of training in the tuning process, which may include one or more iterations of actual training.

  3. cleanup is invoked when training is finished.
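The call order can be pictured with a toy driver. This is a simplified plain-Python model, not Tune’s actor machinery: ToyTrainable and drive are hypothetical names, and the real Trainable is driven remotely through train() rather than by calling step() directly.

```python
class ToyTrainable:
    """Mimics the hook names of tune.Trainable to show the call order."""

    def __init__(self, config: dict):
        self.calls = []
        self.setup(config)

    def setup(self, config: dict):
        self.calls.append("setup")
        self.x = 0

    def step(self) -> dict:
        self.calls.append("step")
        self.x += 1
        return {"score": self.x}

    def cleanup(self):
        self.calls.append("cleanup")


def drive(trainable_cls, config: dict, num_iterations: int) -> tuple:
    trainable = trainable_cls(config)  # setup runs once
    results = [trainable.step() for _ in range(num_iterations)]  # step runs repeatedly
    trainable.cleanup()  # cleanup runs once at the end
    return trainable.calls, results


calls, results = drive(ToyTrainable, {}, num_iterations=3)
print(calls)  # ['setup', 'step', 'step', 'step', 'cleanup']
```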


As a rule of thumb, the execution time of step should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).

You’ll notice that Ray Tune will output extra values in addition to the user reported metrics, such as iterations_since_restore. See How to use log metrics in Tune? for an explanation/glossary of these values.

Class API Checkpointing

You can also implement checkpoint/restore using the Trainable Class API:

import os
import torch
from torch import nn

from ray import air, tune

class MyTrainableClass(tune.Trainable):
    def setup(self, config):
        self.model = nn.Sequential(
            nn.Linear(config.get("input_size", 32), 32), nn.ReLU(), nn.Linear(32, 10)
        )

    def step(self):
        return {}

    def save_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return tmp_checkpoint_dir

    def load_checkpoint(self, tmp_checkpoint_dir):
        checkpoint_path = os.path.join(tmp_checkpoint_dir, "model.pth")
        self.model.load_state_dict(torch.load(checkpoint_path))

tuner = tune.Tuner(
    MyTrainableClass,
    param_space={"input_size": 64},
    run_config=air.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=2),
    ),
)
tuner.fit()

You can checkpoint with three different mechanisms: manually, periodically, and at termination.

Manual Checkpointing: A custom Trainable can manually trigger checkpointing by returning should_checkpoint: True (or tune.result.SHOULD_CHECKPOINT: True) in the result dictionary of step. This can be especially helpful in spot instances:

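def step(self):
    # training code
    result = {"mean_accuracy": accuracy}
    # detect_instance_preemption() is a placeholder for your own preemption check.
    if detect_instance_preemption():
        result.update(should_checkpoint=True)
    return result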
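On the driver side, the effect of should_checkpoint can be modeled as “save after any step whose result carries the flag”. The sketch below is a hypothetical plain-Python stand-in for Tune’s loop, not its implementation; the preempted_at set replaces the undefined detect_instance_preemption() from the snippet above.

```python
def checkpoints_requested(num_steps: int, preempted_at: set) -> list:
    """Return the iterations after which a checkpoint would be saved."""
    saved = []
    for iteration in range(1, num_steps + 1):
        result = {"mean_accuracy": iteration / num_steps}
        if iteration in preempted_at:  # stands in for detect_instance_preemption()
            result.update(should_checkpoint=True)
        # The toy driver checkpoints only when the step asked for it.
        if result.get("should_checkpoint", False):
            saved.append(iteration)
    return saved


print(checkpoints_requested(num_steps=5, preempted_at={2, 4}))  # [2, 4]
```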

Periodic Checkpointing: Periodic checkpointing can be used to provide fault tolerance for experiments. It can be enabled by setting checkpoint_frequency=N and max_failures=M to checkpoint trials every N iterations and recover from up to M crashes per trial, e.g.:

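tuner = tune.Tuner(
    MyTrainableClass,
    run_config=air.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10),
        failure_config=air.FailureConfig(max_failures=5),
    ),
)
results = tuner.fit()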

Checkpointing at Termination: The checkpoint_frequency may not coincide with the exact end of an experiment. If you want a checkpoint to be created at the end of a trial, you can additionally set checkpoint_at_end=True:

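tuner = tune.Tuner(
    MyTrainableClass,
    run_config=air.RunConfig(
        stop={"training_iteration": 2},
        checkpoint_config=air.CheckpointConfig(checkpoint_frequency=10, checkpoint_at_end=True),
        failure_config=air.FailureConfig(max_failures=5),
    ),
)
results = tuner.fit()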
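Which iterations end up checkpointed follows from simple arithmetic. The helper below is a simplified model, not a Tune API, and it assumes a checkpoint is taken whenever the iteration is a multiple of checkpoint_frequency, plus one final checkpoint when checkpoint_at_end=True and the last iteration is not already covered.

```python
def checkpointed_iterations(last_iteration: int, checkpoint_frequency: int,
                            checkpoint_at_end: bool) -> list:
    # Periodic checkpoints: every multiple of the frequency (0 disables them).
    iterations = [
        i for i in range(1, last_iteration + 1)
        if checkpoint_frequency > 0 and i % checkpoint_frequency == 0
    ]
    # Final checkpoint, unless the last iteration was already checkpointed.
    if checkpoint_at_end and last_iteration not in iterations:
        iterations.append(last_iteration)
    return iterations


print(checkpointed_iterations(25, 10, checkpoint_at_end=False))  # [10, 20]
print(checkpointed_iterations(25, 10, checkpoint_at_end=True))   # [10, 20, 25]
print(checkpointed_iterations(2, 10, checkpoint_at_end=True))    # [2]
```

Under this model, a trial stopped at training_iteration 2 with checkpoint_frequency=10 only gets a checkpoint because of checkpoint_at_end=True.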

Use validate_save_restore to catch save_checkpoint/load_checkpoint errors before execution.

from ray.tune.utils import validate_save_restore

# Both of these should run without raising an error.
validate_save_restore(MyTrainableClass)
validate_save_restore(MyTrainableClass, use_object_store=True)

Advanced: Reusing Actors


This feature is only for the Trainable Class API.

Your Trainable can often take a long time to start. To avoid this overhead, you can set tune.TuneConfig(reuse_actors=True) (passed to the Tuner) to reuse the same Trainable Python process and object for multiple hyperparameter configurations.

This requires you to implement Trainable.reset_config, which receives a new set of hyperparameters. It is up to you to correctly update the hyperparameters of your trainable.

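import torch.optim as optim

from ray import tune
# ConvNet and get_data_loaders come from Ray's MNIST PyTorch example.
from ray.tune.examples.mnist_pytorch import ConvNet, get_data_loaders


class PytorchTrainable(tune.Trainable):
    """Train a PyTorch ConvNet."""

    def setup(self, config):
        self.train_loader, self.test_loader = get_data_loaders()
        self.model = ConvNet()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config.get("lr", 0.01),
            momentum=config.get("momentum", 0.9))

    def reset_config(self, new_config):
        # Keep the previous value of any hyperparameter missing from new_config.
        old_group = self.optimizer.param_groups[0]
        # Start the new trial from a freshly initialized model, and rebuild the
        # optimizer so that it tracks the new model's parameters.
        self.model = ConvNet()
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=new_config.get("lr", old_group["lr"]),
            momentum=new_config.get("momentum", old_group["momentum"]))
        self.config = new_config
        return True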
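The reuse mechanism can be pictured as keeping one live object and calling reset_config for each new trial, falling back to a fresh instance when reset_config returns False. This is a hypothetical plain-Python model for illustration (ReusableTrainable and run_trials are made-up names); Tune’s actual scheduling logic is more involved.

```python
class ReusableTrainable:
    instances_created = 0

    def __init__(self, config: dict):
        ReusableTrainable.instances_created += 1  # stands in for expensive startup
        self.config = config

    def reset_config(self, new_config: dict) -> bool:
        if "lr" not in new_config:
            return False  # this toy cannot apply such a config in place
        self.config = new_config
        return True


def run_trials(configs: list) -> list:
    actor = None
    seen = []
    for config in configs:
        if actor is None or not actor.reset_config(config):
            actor = ReusableTrainable(config)  # (re)start only when reuse fails
        seen.append(actor.config)
    return seen


configs = [{"lr": 0.1}, {"lr": 0.01}, {"momentum": 0.9}, {"lr": 0.001}]
seen = run_trials(configs)
print(ReusableTrainable.instances_created)  # 2
```

Four trials ran, but the expensive constructor ran only twice: once at the start and once when reset_config refused a config.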

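Comparing the Function API and Class API

Here are a few key concepts and what they look like for the Function and Class APIs.

  • Training iteration – Function API: increments on each session.report call. Class API: increments on each Trainable.step call.

  • Report metrics – Function API: session.report(metrics). Class API: return metrics from Trainable.step.

  • Saving a checkpoint – Function API: session.report(..., checkpoint=checkpoint). Class API: Trainable.save_checkpoint.

  • Loading a checkpoint – Function API: session.get_checkpoint(). Class API: Trainable.load_checkpoint.

  • Accessing config – Function API: passed as an argument, as in def train_func(config). Class API: passed through Trainable.setup.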

Advanced Resource Allocation

Trainables can themselves be distributed. If your trainable function / class creates further Ray actors or tasks that also consume CPU / GPU resources, you will want to add more bundles to the PlacementGroupFactory to reserve extra resource slots. For example, if a trainable class requires 1 GPU itself, but also launches 4 actors, each using another GPU, then you should use this:

tuner = tune.Tuner(
    tune.with_resources(my_trainable, tune.PlacementGroupFactory([
        {"CPU": 1, "GPU": 1},
        {"GPU": 1},
        {"GPU": 1},
        {"GPU": 1},
        {"GPU": 1}
    ])),
)

The Trainable also provides the default_resource_request interface to automatically declare the resources per trial based on the given configuration.

It is also possible to specify memory ("memory", in bytes) and custom resource requirements.
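The bundle list in the example above is one entry per resource slot, so it can be derived from the configuration. The sketch below is plain Python with no Ray import; bundles_for, total_resources and the num_workers config key are hypothetical, and the returned list is the kind of value you would wrap in tune.PlacementGroupFactory (for example, inside a default_resource_request override).

```python
from collections import Counter
from typing import Dict, List


def bundles_for(config: Dict) -> List[Dict[str, float]]:
    # First bundle: the trainable itself (1 CPU + 1 GPU).
    bundles = [{"CPU": 1, "GPU": 1}]
    # One extra GPU bundle per actor the trainable launches.
    bundles += [{"GPU": 1} for _ in range(config.get("num_workers", 4))]
    return bundles


def total_resources(bundles: List[Dict[str, float]]) -> Dict[str, float]:
    totals = Counter()
    for bundle in bundles:
        totals.update(bundle)  # adds the counts from each bundle
    return dict(totals)


bundles = bundles_for({"num_workers": 4})
print(total_resources(bundles))  # {'CPU': 1, 'GPU': 5}
```

Summing the bundles shows how much a single trial reserves in total: the trial cannot start until 1 CPU and 5 GPUs are free.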

session (Function API)

ray.air.session.report(metrics: Dict, *, checkpoint: Optional[ray.air.checkpoint.Checkpoint] = None) -> None

Report metrics and optionally save a checkpoint.

Each invocation of this method automatically increments the underlying iteration number. The physical meaning of this “iteration” is defined by the user (or, more specifically, by the way they call report). It does not necessarily map to one epoch.

This API is the canonical way to report metrics from Tune and Train, and replaces legacy reporting and checkpointing calls such as with tune.checkpoint_dir and train.save_checkpoint.

Note on directory checkpoints: AIR will take ownership of checkpoints passed to report() by moving them to a new path. The original directory will no longer be accessible to the caller after the report call.


Parameters:

  • metrics – The metrics you want to report.

  • checkpoint – The optional checkpoint you want to report.
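To make the point about iterations concrete, here is a toy recorder that counts report calls the way the description above says Tune does. ToySession is a hypothetical stand-in, not part of ray.air; reporting twice per epoch yields twice as many iterations as epochs.

```python
class ToySession:
    def __init__(self):
        self.iteration = 0
        self.history = []

    def report(self, metrics: dict, checkpoint=None):
        # Every call advances the iteration counter by exactly one.
        self.iteration += 1
        self.history.append({**metrics, "training_iteration": self.iteration})


toy_session = ToySession()
for epoch in range(3):
    toy_session.report({"epoch": epoch, "phase": "train"})
    toy_session.report({"epoch": epoch, "phase": "validation"})

print(toy_session.iteration)  # 6, although only 3 epochs ran
```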

ray.air.session.get_checkpoint() -> Optional[ray.air.checkpoint.Checkpoint]

Access the session’s last checkpoint to resume from if applicable.


Returns: Checkpoint object if the session is currently being resumed. Otherwise, returns None.

######## Using it in the *per worker* train loop (TrainSession) ######
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.air.config import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

def train_func():
    ckpt = session.get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as loaded_checkpoint_dir:
            import tensorflow as tf

            model = tf.keras.models.load_model(loaded_checkpoint_dir)
    else:
        # build_model() is a placeholder for your own model-building code.
        model = build_model()

    model.save("my_model", overwrite=True)
    session.report(
        metrics={"iter": 1},
        checkpoint=Checkpoint.from_directory("my_model"),
    )

scaling_config = ScalingConfig(num_workers=2)
trainer = TensorflowTrainer(
    train_loop_per_worker=train_func, scaling_config=scaling_config
)
result = trainer.fit()

# trainer2 will pick up from the checkpoint saved by trainer.
trainer2 = TensorflowTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
    # this is ultimately what is accessed through
    # ``session.get_checkpoint()``
    resume_from_checkpoint=result.checkpoint,
)
result2 = trainer2.fit()

ray.air.session.get_trial_name() -> str

Trial name for the corresponding trial.

ray.air.session.get_trial_id() -> str

Trial id for the corresponding trial.

ray.air.session.get_trial_resources() -> PlacementGroupFactory

Trial resources for the corresponding trial.

ray.air.session.get_trial_dir() -> str

Log directory corresponding to the trial directory for a Tune session. If calling from a Train session, this returns the trial directory of its parent Tune session.

from ray import tune
from ray.air import session

def train_func(config):
    # Example:
    # >>> session.get_trial_dir()
    # ~/ray_results/<exp-name>/<trial-dir>
    print(session.get_trial_dir())

tuner = tune.Tuner(train_func)
tuner.fit()

tune.Trainable (Class API)

class ray.tune.Trainable(config: Dict[str, Any] = None, logger_creator: Callable[[Dict[str, Any]], Logger] = None, remote_checkpoint_dir: Optional[str] = None, custom_syncer: Optional[ray.tune.syncer.Syncer] = None, sync_timeout: Optional[int] = None)

Abstract class for trainable models, functions, etc.

A call to train() on a trainable will execute one logical iteration of training. As a rule of thumb, the execution time of one train call should be large enough to avoid overheads (i.e. more than a few seconds), but short enough to report progress periodically (i.e. at most a few minutes).

Calling save() should save the training state of a trainable to disk, and restore(path) should restore a trainable to the given state.

Generally you only need to implement setup, step, save_checkpoint, and load_checkpoint when subclassing Trainable.

Other implementation methods that may be helpful to override are log_result, reset_config, cleanup, and _export_model.

Tune will convert this class into a Ray actor, which runs in a separate process. By default, Tune will also change the current working directory of this process to its corresponding trial-level log directory self.logdir. This is designed so that different trials that run on the same physical node won’t accidentally write to the same location and overwrite each other’s files.

The behavior of changing the working directory can be disabled by setting the flag chdir_to_trial_dir=False in tune.TuneConfig. This allows access to files in the original working directory, but relative paths should be used for read-only purposes, and you must make sure that the directory is synced on all nodes if running on multiple machines.

The TUNE_ORIG_WORKING_DIR environment variable was the original workaround for accessing paths relative to the original working directory. This environment variable is deprecated, and the chdir_to_trial_dir flag described above should be used instead.
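The reason for switching working directories is that relative writes then land in per-trial folders. A stdlib-only illustration (the temporary experiment directory and the trial names are made up): two “trials” both write results.txt, and each file ends up in its own directory.

```python
import os
import tempfile

experiment_dir = tempfile.mkdtemp()  # stands in for ~/ray_results/<exp-name>
original_cwd = os.getcwd()
try:
    for trial_name in ["trial_a", "trial_b"]:
        trial_dir = os.path.join(experiment_dir, trial_name)
        os.makedirs(trial_dir)
        # Roughly what Tune does per trial process by default (chdir_to_trial_dir=True).
        os.chdir(trial_dir)
        with open("results.txt", "w") as f:  # identical relative path in both trials
            f.write(trial_name)
finally:
    os.chdir(original_cwd)  # restore the caller's working directory

contents = {}
for trial_name in ["trial_a", "trial_b"]:
    with open(os.path.join(experiment_dir, trial_name, "results.txt")) as f:
        contents[trial_name] = f.read()
print(contents)  # {'trial_a': 'trial_a', 'trial_b': 'trial_b'}
```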

This class supports checkpointing to and restoring from remote storage.

PublicAPI: This API is stable across Ray releases.


alias of ray.air.checkpoint.Checkpoint


Close stdout and stderr logfiles.

_create_logger(config: Dict[str, Any], logger_creator: Callable[[Dict[str, Any]], Logger] = None)

Create logger from logger creator.

Sets _logdir and _result_logger.

_logdir is the per trial directory for the Trainable.

_export_model(export_formats: List[str], export_dir: str)

Subclasses should override this to export the model.

Parameters:

  • export_formats – List of formats that should be exported.

  • export_dir – Directory to place exported models.

Returns: A dict that maps ExportFormats to successfully exported models.

_open_logfiles(stdout_file, stderr_file)

Create loggers. Open stdout and stderr logfiles.


Converts a local_path to be based off of self.remote_checkpoint_dir.


Subclasses should override this for any cleanup on stop.

If any Ray actors are launched in the Trainable (e.g., with an RLlib trainer), be sure to kill the Ray actor process here.

This process should be lightweight.

You can kill a Ray actor by calling ray.kill(actor) on the actor, or by removing all references to it and waiting for garbage collection.

New in version 0.8.7.

classmethod default_resource_request(config: Dict[str, Any]) -> Optional[Union[ray.tune.resources.Resources, ray.tune.execution.placement_groups.PlacementGroupFactory]]

Provides a static resource requirement for the given configuration.

This can be overridden by sub-classes to set the correct trial resource allocation, so the user does not need to.

Example:

from ray.tune import PlacementGroupFactory

@classmethod
def default_resource_request(cls, config):
    return PlacementGroupFactory([{"CPU": 1}, {"CPU": 1}])

Parameters:

  • config (Dict[str, Any]) – The Trainable’s config dict.

Returns: A Resources object or PlacementGroupFactory consumed by Tune for queueing.

Return type

Union[Resources, PlacementGroupFactory]

delete_checkpoint(checkpoint_path: Union[str, ray.air.checkpoint.Checkpoint])

Deletes local copy of checkpoint.


Parameters: checkpoint_path – Path to checkpoint.

export_model(export_formats: Union[List[str], str], export_dir: Optional[str] = None)

Exports model based on export_formats.

Subclasses should override _export_model() to actually export model to local directory.

Parameters:

  • export_formats – Format or list of (str) formats that should be exported.

  • export_dir – Optional dir to place the exported model. Defaults to self.logdir.

Returns: A dict that maps ExportFormats to successfully exported models.

get_auto_filled_metrics(now: Optional[datetime.datetime] = None, time_this_iter: Optional[float] = None, debug_metrics_only: bool = False) -> dict

Return a dict with metrics auto-filled by the trainable.

If debug_metrics_only is True, only metrics that don’t require at least one iteration will be returned (ray.tune.result.DEBUG_METRICS).


Returns configuration passed in by Tune.

load_checkpoint(checkpoint: Union[Dict, str])

Subclasses should override this to implement restore().


In this method, do not rely on absolute paths. The absolute path of the checkpoint_dir used in Trainable.save_checkpoint may be changed.

If Trainable.save_checkpoint returned a prefixed string, the prefix of the checkpoint string returned by Trainable.save_checkpoint may be changed. This is because trial pausing depends on temporary directories.

The directory structure under the checkpoint_dir provided to Trainable.save_checkpoint is preserved.

See the examples below.


>>> import os
>>> from ray.tune.trainable import Trainable
>>> class Example(Trainable):
...    def save_checkpoint(self, checkpoint_path):
...        my_checkpoint_path = os.path.join(checkpoint_path, "my/path")
...        return my_checkpoint_path
...    def load_checkpoint(self, my_checkpoint_path):
...        print(my_checkpoint_path)
>>> trainer = Example()
>>> # This is used when PAUSED.
>>> obj = trainer.save_to_object() 
>>> # Note the different prefix.
>>> trainer.restore_from_object(obj) 
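One way to think about “the prefix may change but the structure underneath is preserved” is as rebasing a path: strip the old checkpoint_dir and join the remainder onto the new one. The helper below is a hypothetical stdlib sketch of that idea, not Tune’s code; the temporary directory name is made up.

```python
import os


def rebase(saved_path: str, old_checkpoint_dir: str, new_checkpoint_dir: str) -> str:
    # The part under checkpoint_dir ("my/path" here) is what stays stable.
    relative = os.path.relpath(saved_path, old_checkpoint_dir)
    return os.path.join(new_checkpoint_dir, relative)


saved = os.path.join("/tmp", "checkpoint_1", "my", "path")
restored = rebase(saved, os.path.join("/tmp", "checkpoint_1"), os.path.join("/tmp", "tmpabc123"))
print(restored)  # /tmp/tmpabc123/my/path on POSIX systems
```

This is why load_checkpoint should work from the path it receives rather than from any absolute path remembered at save time.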

If Trainable.save_checkpoint returned a dict, then Tune will directly pass the dict data as the argument to this method.


>>> from ray.tune.trainable import Trainable
>>> class Example(Trainable):
...    def save_checkpoint(self, checkpoint_path):
...        return {"my_data": 1}
...    def load_checkpoint(self, checkpoint_dict):
...        print(checkpoint_dict["my_data"])

New in version 0.8.7.


Parameters: checkpoint – If dict, the return value is as returned by save_checkpoint. If a string, then it is a checkpoint path that may have a different prefix than that returned by save_checkpoint. The directory structure underneath the checkpoint_dir from save_checkpoint is preserved.

log_result(result: Dict)

Subclasses can optionally override this to customize logging.

The logging here is done on the worker process rather than the driver.

New in version 0.8.7.


Parameters: result – Training result returned by step().

reset(new_config, logger_creator=None)

Resets trial for use with new config.

Subclasses should override reset_config() to actually reset actor behavior for the new config.

reset_config(new_config: Dict)

Resets configuration without restarting the trial.

This method is optional, but can be implemented to speed up algorithms such as PBT, and to allow performance optimizations such as running experiments with reuse_actors=True.


Parameters: new_config – Updated hyperparameter configuration for the trainable.

Returns: True if the reset was successful, else False.

classmethod resource_help(config: Dict)

Returns a help string for configuring this trainable’s resources.

Parameters: config – The Trainable’s config dict.

restore(checkpoint_path: Union[str, ray.air.checkpoint.Checkpoint], checkpoint_node_ip: Optional[str] = None, fallback_to_latest: bool = False)

Restores training state from a given model checkpoint.

These checkpoints are returned from calls to save().

Subclasses should override load_checkpoint() instead to restore state. This method restores additional metadata saved with the checkpoint.

checkpoint_path should match with the return from save().

checkpoint_path can be /ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint or /ray_results/exp/MyTrainable_abc/checkpoint_00000.

self.logdir should generally correspond to checkpoint_path, for example, /ray_results/exp/MyTrainable_abc.

In this case, self.remote_checkpoint_dir is something like REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc.

Parameters:

  • checkpoint_path – Path to restore checkpoint from. If this path does not exist on the local node, it will be fetched from external (cloud) storage if available, or restored from a remote node.

  • checkpoint_node_ip – If given, try to restore checkpoint from this node if it doesn’t exist locally or on cloud storage.

  • fallback_to_latest – If True, will try to recover the latest available checkpoint if the given checkpoint_path could not be found.


Restores training state from a checkpoint object.

These checkpoints are returned from calls to save_to_object().

save(checkpoint_dir: Optional[str] = None, prevent_upload: bool = False) -> str

Saves the current model state to a checkpoint.

Subclasses should override save_checkpoint() instead to save state. This method dumps additional metadata alongside the saved path.

If a remote checkpoint dir is given, this will also sync up to remote storage.

Parameters:

  • checkpoint_dir – Optional dir to place the checkpoint.

  • prevent_upload – If True, will not upload the saved checkpoint to cloud.

Returns: The given or created checkpoint directory.

Note the return path should match up with what is expected of restore().

save_checkpoint(checkpoint_dir: str) -> Optional[Union[str, Dict]]

Subclasses should override this to implement save().


Do not rely on absolute paths in the implementation of Trainable.save_checkpoint and Trainable.load_checkpoint.

Use validate_save_restore to catch Trainable.save_checkpoint/ Trainable.load_checkpoint errors before execution.

>>> from ray.tune.utils import validate_save_restore
>>> MyTrainableClass = ... 
>>> validate_save_restore(MyTrainableClass) 
>>> validate_save_restore( 
...     MyTrainableClass, use_object_store=True)

New in version 0.8.7.


Parameters: checkpoint_dir – The directory where the checkpoint file must be stored. In a Tune run, if the trial is paused, the provided path may be temporary and moved.

Returns: A dict or string. If string, the return value is expected to be prefixed by checkpoint_dir. If dict, the return value will be automatically serialized by Tune. In both cases, the return value is exactly what will be passed to Trainable.load_checkpoint() upon restore.


>>> trainable, trainable1, trainable2 = ... 
>>> print(trainable1.save_checkpoint("/tmp/checkpoint_1")) 
>>> print(trainable2.save_checkpoint("/tmp/checkpoint_2")) 
{"some": "data"}
>>> trainable.save_checkpoint("/tmp/bad_example") 
"/tmp/NEW_CHECKPOINT_PATH/my_checkpoint_file" # This will error.

Saves the current model state to a Python object.

It also saves to disk but does not return the checkpoint path. It does not save the checkpoint to cloud storage.


Returns: Object holding checkpoint data.

setup(config: Dict)

Subclasses should override this for custom initialization.

New in version 0.8.7.


Parameters: config – Hyperparameters and other configs given. Copy of self.config.


Subclasses should override this to implement train().

The return value will be automatically passed to the loggers. Users can also return tune.result.DONE or tune.result.SHOULD_CHECKPOINT as a key to manually trigger termination or checkpointing of this trial. Note that manual checkpointing only works when subclassing Trainables.

New in version 0.8.7.


Returns: A dict that describes training progress.


Releases all resources used by this trainable.

Calls Trainable.cleanup internally. Subclasses should override Trainable.cleanup for custom cleanup procedures.


Runs one logical iteration of training.

Calls step() internally. Subclasses should override step() instead to return results. This method automatically fills the following fields in the result:

done (bool): training is terminated. Filled only if not provided.

time_this_iter_s (float): Time in seconds this iteration took to run. This may be overridden in order to override the system-computed time difference.

time_total_s (float): Accumulated time in seconds for this entire experiment.

experiment_id (str): Unique string identifier for this experiment. This id is preserved across checkpoint / restore calls.

training_iteration (int): The index of this training iteration, e.g. call to train(). This is incremented after step() is called.

pid (str): The pid of the training process.

date (str): A formatted date of when the result was processed.

timestamp (str): A UNIX timestamp of when the result was processed.

hostname (str): Hostname of the machine hosting the training process.

node_ip (str): Node ip of the machine hosting the training process.


Returns: A dict that describes training progress.
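A simplified model of what train() adds around step(): it times the call, increments the iteration, and fills in done only when step() did not set it. The sketch covers just a few of the fields listed above; fill_result is a hypothetical name, not a Trainable method.

```python
import time


def fill_result(step_fn, iteration: int) -> dict:
    start = time.time()
    result = dict(step_fn())
    result.setdefault("done", False)  # filled only if not provided
    result.setdefault("time_this_iter_s", time.time() - start)  # a user value wins
    result["training_iteration"] = iteration + 1  # incremented after step() runs
    return result


auto = fill_result(lambda: {"score": 1.0}, iteration=0)
manual = fill_result(lambda: {"score": 2.0, "done": True, "time_this_iter_s": 5.0}, iteration=1)
print(auto["done"], auto["training_iteration"])  # False 1
print(manual["done"], manual["time_this_iter_s"], manual["training_iteration"])  # True 5.0 2
```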

train_buffered(buffer_time_s: float, max_buffer_length: int = 1000)

Runs multiple iterations of training.

Calls train() internally. Collects and combines multiple results. This function will run self.train() repeatedly until one of the following conditions is met: 1) the maximum buffer length is reached, 2) the maximum buffer time is reached, or 3) a checkpoint was created. Even if the maximum time is reached, it will always block until at least one result is received.

Parameters:

  • buffer_time_s – Maximum time to buffer. The next result received after this amount of time has passed will return the whole buffer.

  • max_buffer_length – Maximum number of results to buffer.
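The three stop conditions of train_buffered can be expressed as a small loop. This is a hypothetical plain-Python model, not the actual implementation: buffered and the fake train callables are illustrative, and signaling “a checkpoint was created” with a should_checkpoint key is an assumption.

```python
import time
from typing import Callable, Dict, List


def buffered(train: Callable[[], Dict], buffer_time_s: float,
             max_buffer_length: int = 1000) -> List[Dict]:
    results: List[Dict] = []
    deadline = time.monotonic() + buffer_time_s
    while True:
        results.append(train())  # always collect at least one result
        if len(results) >= max_buffer_length:  # 1) buffer full
            break
        if time.monotonic() >= deadline:  # 2) time budget used up
            break
        if results[-1].get("should_checkpoint"):  # 3) a checkpoint was created
            break
    return results


counter = iter(range(100))
out = buffered(lambda: {"i": next(counter)}, buffer_time_s=60.0, max_buffer_length=3)
print([r["i"] for r in out])  # [0, 1, 2]

# Even with a zero time budget, one result is always returned.
print(len(buffered(lambda: {}, buffer_time_s=0.0)))  # 1
```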

property iteration

Current training iteration.

This value is automatically incremented every time train() is called and is automatically inserted into the training result dict.

property logdir

Directory of the results and checkpoints for this Trainable.

Tune will automatically sync this folder with the driver if execution is distributed.

Note that the current working directory will also be changed to this.

property training_iteration

Current training iteration (same as self.iteration).

This value is automatically incremented every time train() is called and is automatically inserted into the training result dict.

property trial_id

Trial ID for the corresponding trial of this Trainable.

This is not set if not using Tune.

trial_id = self.trial_id

property trial_name

Trial name for the corresponding trial of this Trainable.

This is not set if not using Tune.

name = self.trial_name

property trial_resources: Union[ray.tune.resources.Resources, ray.tune.execution.placement_groups.PlacementGroupFactory]

Resources currently assigned to the trial of this Trainable.

This is not set if not using Tune.

trial_resources = self.trial_resources


ray.tune.utils.wait_for_gpu(gpu_id: Optional[Union[int, str]] = None, target_util: float = 0.01, retry: int = 20, delay_s: int = 5, gpu_memory_limit: Optional[float] = None)

Checks if a given GPU has freed memory.

Requires gputil to be installed: pip install gputil.

Parameters:

  • gpu_id – GPU id or uuid to check. Must be found within GPUtil.getGPUs(). If None, resorts to the first item returned from ray.get_gpu_ids().

  • target_util – The utilization threshold to reach to unblock. Set this to 0 to block until the GPU is completely free.

  • retry – Number of times to check GPU limit. Sleeps delay_s seconds between checks.

  • delay_s – Seconds to wait before check.

Returns: True if free.

Return type



Raises: RuntimeError – If GPUtil is not found, if no GPUs are detected, or if the check fails.


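Example:

from ray import tune
from ray.tune.utils import wait_for_gpu

def tune_func(config):
    wait_for_gpu()  # block until this trial's GPU has freed its memory
    train()  # placeholder for your own training code

tuner = tune.Tuner(
    tune.with_resources(tune_func, resources={"gpu": 1}),
    tune_config=tune.TuneConfig(num_samples=10))
tuner.fit()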

PublicAPI (beta): This API is in beta and may change before becoming stable.
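Conceptually, wait_for_gpu is a bounded polling loop. The sketch below reproduces that shape with an injected utilization reader so it runs without GPUtil or a GPU; poll_until_free and read_util are hypothetical names, and details of the real function (how it picks the GPU, gpu_memory_limit) are not modeled.

```python
import time
from typing import Callable


def poll_until_free(read_util: Callable[[], float], target_util: float = 0.01,
                    retry: int = 20, delay_s: float = 5.0) -> bool:
    for _ in range(retry):
        if read_util() <= target_util:
            return True  # utilization dropped to the threshold
        time.sleep(delay_s)  # wait before checking again
    raise RuntimeError("GPU memory was not freed within the retry budget.")


readings = iter([0.80, 0.35, 0.0])
print(poll_until_free(lambda: next(readings), retry=5, delay_s=0.0))  # True
```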

ray.tune.utils.diagnose_serialization(trainable: Callable)

Utility for detecting why your trainable function isn’t serializing.


Parameters: trainable – The trainable object passed to tune.Tuner(trainable). Currently only supports Function API.

Returns: bool | set of unserializable objects.


Example:

import threading

from ray.tune.utils import diagnose_serialization

# this is not serializable
e = threading.Event()

def test():
    print(e)

diagnose_serialization(test)
# should help identify that 'e' should be moved into
# the `test` scope.

# correct implementation
def test():
    e = threading.Event()
    print(e)

assert diagnose_serialization(test) is True

DeveloperAPI: This API may change across minor Ray releases.
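To see what such a check involves, here is a rough stdlib approximation: look up the module-level names a function refers to and try to pickle each one. This is not how diagnose_serialization is implemented (Ray uses its own serializer), and find_unpicklable_globals is a hypothetical helper.

```python
import pickle
import threading
import types


def find_unpicklable_globals(fn) -> set:
    """Return names of module-level objects used by fn that pickle rejects."""
    bad = set()
    for name in fn.__code__.co_names:
        if name not in fn.__globals__:
            continue  # builtins and attribute names such as is_set
        value = fn.__globals__[name]
        if isinstance(value, types.ModuleType):
            continue  # modules are referenced by name rather than copied
        try:
            pickle.dumps(value)
        except Exception:
            bad.add(name)
    return bad


event = threading.Event()  # holds a lock, so it cannot be pickled


def uses_global():
    return event.is_set()


def builds_locally():
    local_event = threading.Event()  # created inside the function instead
    return local_event.is_set()


print(find_unpicklable_globals(uses_global))     # {'event'}
print(find_unpicklable_globals(builds_locally))  # set()
```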

ray.tune.utils.validate_save_restore(trainable_cls: Type, config: Optional[Dict] = None, num_gpus: int = 0, use_object_store: bool = False)

Helper method to check if your Trainable class will resume correctly.

Parameters:

  • trainable_cls – Trainable class for evaluation.

  • config – Config to pass to Trainable when testing.

  • num_gpus – GPU resources to allocate when testing.

  • use_object_store – Whether to save and restore to Ray’s object store. Recommended to set this to True if planning to use algorithms that pause training (e.g., PBT, HyperBand).

DeveloperAPI: This API may change across minor Ray releases.


ray.tune.with_parameters(trainable: Union[Type[Trainable], Callable], **kwargs)

Wrapper for trainables to pass arbitrary large data objects.

This wrapper function will store all passed parameters in the Ray object store and retrieve them when calling the function. It can thus be used to pass arbitrary data, even datasets, to Tune trainables.

This can also be used as an alternative to functools.partial to pass default arguments to trainables.

When used with the function API, the trainable function is called with the passed parameters as keyword arguments. When used with the class API, the Trainable.setup() method is called with the respective kwargs.

If the data already exists in the object store (i.e., the objects are instances of ObjectRef), using tune.with_parameters() is not necessary. You can instead pass the object refs to the training function via the config or use Python partials.

Parameters:

  • trainable – Trainable to wrap.

  • **kwargs – Parameters to store in the object store.
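As noted above, functools.partial can also bind extra arguments. The stdlib sketch below shows what such a binding does to the call signature; train and the short list standing in for a dataset are hypothetical. Unlike tune.with_parameters, a partial does not place the data in the Ray object store, so a large object would typically be serialized together with the function.

```python
from functools import partial


def train(config: dict, data=None) -> dict:
    # Tune would call this with the trial's config; data is bound ahead of time.
    return {"num_samples": len(data), "lr": config["lr"]}


dataset = [0.1, 0.2, 0.3]  # stand-in for a large dataset
trainable = partial(train, data=dataset)

print(trainable({"lr": 0.01}))  # {'num_samples': 3, 'lr': 0.01}
```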

Function API example:

from ray import tune
from ray.air import session

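def train(config, data=None):
    for sample in data:
        # update_model() is a placeholder for your own training step.
        loss = update_model(sample)
        session.report({"loss": loss})

# HugeDataset is a placeholder for your own (large) dataset object.
data = HugeDataset(download=True)

tuner = tune.Tuner(
    tune.with_parameters(train, data=data),
    # ...
)
tuner.fit()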

Class API example:

from ray import tune

class MyTrainable(tune.Trainable):
    def setup(self, config, data=None):
        self.data = data
        self.iter = iter(self.data)
        self.next_sample = next(self.iter)

    def step(self):
        # update_model() is a placeholder for your own training step.
        loss = update_model(self.next_sample)
        try:
            self.next_sample = next(self.iter)
        except StopIteration:
            return {"loss": loss, "done": True}
        return {"loss": loss}

# HugeDataset is a placeholder for your own (large) dataset object.
data = HugeDataset(download=True)

tuner = tune.Tuner(
    tune.with_parameters(MyTrainable, data=data),
    # ...
)
tuner.fit()


When restoring a Tune experiment, you need to re-specify the trainable wrapped with tune.with_parameters. The reasoning behind this is as follows:

1. tune.with_parameters stores parameters in the object store and attaches object references to the trainable, but the objects they point to may not exist anymore upon restore.

2. The attached objects could be arbitrarily large, so Tune does not save the object data along with the trainable.

To restore, Tune allows the trainable to be re-specified in Tuner.restore(overwrite_trainable=...). Continuing from the previous examples, here’s an example of restoration:

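from ray import tune
from ray.tune import Tuner

data = HugeDataset(download=True)

tuner = Tuner.restore(
    "/path/to/experiment/",  # placeholder: directory of the experiment to restore
    overwrite_trainable=tune.with_parameters(MyTrainable, data=data),
)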

PublicAPI (beta): This API is in beta and may change before becoming stable.