class ray.tune.Trainable(config: Dict[str, Any] = None, logger_creator: Callable[[Dict[str, Any]], Logger] = None, remote_checkpoint_dir: Optional[str] = None, sync_config: Optional[ray.train.SyncConfig] = None, storage: Optional[ray.train._internal.storage.StorageContext] = None)

Bases: object

Abstract class for trainable models, functions, etc.

A call to train() on a trainable executes one logical iteration of training. As a rule of thumb, one train() call should run long enough to amortize overheads (more than a few seconds) but short enough to report progress periodically (at most a few minutes).
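One way to follow this rule of thumb is to group several inner updates into a single iteration. The sketch below is illustrative only: the `batches_per_step` config key and the `_train_one_batch` helper are hypothetical names, and the `ImportError` fallback exists only so the snippet can be loaded where Ray is not installed. It overrides step, the method that implements train().

```python
try:
    from ray.tune import Trainable
except ImportError:  # Assumption: fallback so the sketch loads without Ray
    Trainable = object


class BatchedTrainable(Trainable):
    """Toy trainable that runs a fixed number of batches per iteration."""

    def setup(self, config):
        # Hypothetical config key: tune it so one iteration takes
        # somewhere between a few seconds and a few minutes.
        self.batches_per_step = config.get("batches_per_step", 100)
        self.batches_seen = 0

    def _train_one_batch(self):
        # Placeholder for real work on one batch (hypothetical helper).
        self.batches_seen += 1

    def step(self):
        # One logical iteration = several inner batches.
        for _ in range(self.batches_per_step):
            self._train_one_batch()
        return {"batches_seen": self.batches_seen}
```

Raising `batches_per_step` lengthens each iteration and reduces per-iteration overhead, at the cost of less frequent progress reports.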

Calling save() should save the training state of a trainable to disk, and restore(path) should restore a trainable to the given state.

Generally you only need to implement setup, step, save_checkpoint, and load_checkpoint when subclassing Trainable.
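A minimal sketch of such a subclass is shown here. It is not taken from the Ray sources: the `increment` config key and the `state.json` file name are hypothetical, the `ImportError` fallback exists only so the snippet loads without Ray, and it assumes load_checkpoint is handed the directory that save_checkpoint wrote to, which may differ between Ray versions.

```python
import json
import os

try:
    from ray.tune import Trainable
except ImportError:  # Assumption: fallback so the sketch loads without Ray
    Trainable = object


class CounterTrainable(Trainable):
    """Toy trainable that counts up by a configurable increment."""

    def setup(self, config):
        # Custom initialization; `increment` is a hypothetical config key.
        self.increment = config.get("increment", 1)
        self.total = 0

    def step(self):
        # One logical training iteration; returns a dict of metrics.
        self.total += self.increment
        return {"total": self.total}

    def save_checkpoint(self, checkpoint_dir):
        # Write everything needed to resume into the given directory.
        with open(os.path.join(checkpoint_dir, "state.json"), "w") as f:
            json.dump({"total": self.total}, f)

    def load_checkpoint(self, checkpoint_dir):
        # Assumption: receives the directory written by save_checkpoint.
        with open(os.path.join(checkpoint_dir, "state.json")) as f:
            self.total = json.load(f)["total"]
```

In normal use Tune constructs the class and calls these methods itself; user code typically passes the class to Tune rather than invoking setup or step directly.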

Other implementation methods that may be helpful to override are log_result, reset_config, cleanup, and _export_model.
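As a rough illustration of two of these optional hooks, the sketch below overrides reset_config and cleanup. The `lr` config key and the `scratch` temporary file are hypothetical, and the `ImportError` fallback exists only so the snippet loads without Ray. It assumes reset_config returns True to signal that the new configuration was applied in place.

```python
import tempfile

try:
    from ray.tune import Trainable
except ImportError:  # Assumption: fallback so the sketch loads without Ray
    Trainable = object


class ReusableTrainable(Trainable):
    """Toy trainable that can take a new config without a restart."""

    def setup(self, config):
        self.lr = config.get("lr", 0.01)  # hypothetical hyperparameter
        # Hypothetical resource that must be released on stop.
        self.scratch = tempfile.TemporaryFile()

    def step(self):
        return {"lr": self.lr}

    def reset_config(self, new_config):
        # Apply the new hyperparameters in place.
        self.lr = new_config.get("lr", self.lr)
        return True  # assumption: True reports a successful reset

    def cleanup(self):
        # Release resources opened in setup.
        self.scratch.close()
```

A reset_config implementation is mainly useful when trials reuse an existing actor instead of starting a fresh one for each configuration.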

Tune will convert this class into a Ray actor, which runs in a separate process. By default, Tune will also change the current working directory of this process to its corresponding trial-level log directory self.logdir. This is designed so that different trials running on the same physical node won’t accidentally write to the same location and overwrite each other's files.

Setting the RAY_CHDIR_TO_TRIAL_DIR=0 environment variable disables this change of working directory. Files in the original working directory then remain accessible, but relative paths should be used for read-only purposes only, and the directory must be synced to all nodes when running on multiple machines.
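A minimal sketch of this setup follows. It assumes the variable is set in the driver script before Tune launches any trials; the `read_lines` helper is a hypothetical name used only for illustration.

```python
import os

# Assumption: set before Tune starts trials so the setting takes effect.
os.environ["RAY_CHDIR_TO_TRIAL_DIR"] = "0"


def read_lines(relative_path):
    """Read an input file via a relative path, strictly read-only."""
    with open(relative_path, "r") as f:
        return f.read().splitlines()
```

With the directory change disabled, trial outputs should still go under the trial's own directory (self.logdir) rather than to relative paths, so that concurrent trials do not write to the same location.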

The TUNE_ORIG_WORKING_DIR environment variable was the original workaround for accessing paths relative to the original working directory. This environment variable is deprecated, and the RAY_CHDIR_TO_TRIAL_DIR environment variable described above should be used instead.

This class supports checkpointing to and restoring from remote storage.

PublicAPI: This API is stable across Ray releases.


__init__([config, logger_creator, ...])

Initialize a Trainable.

cleanup()

Subclasses should override this for any cleanup on stop.

default_resource_request(config)

Provides a static resource requirement for the given configuration.

delete_checkpoint(checkpoint_path)

Deletes local copy of checkpoint.

export_model(export_formats[, export_dir])

Exports model based on export_formats.

get_auto_filled_metrics([now, ...])

Return a dict with metrics auto-filled by the trainable.

get_config()

Returns configuration passed in by Tune.

load_checkpoint(checkpoint)

Subclasses should override this to implement restore().

log_result(result)

Subclasses can optionally override this to customize logging.

reset(new_config[, logger_creator, ...])

Resets trial for use with new config.

reset_config(new_config)

Resets configuration without restarting the trial.

resource_help(config)

Returns a help string for configuring this trainable's resources.

restore(checkpoint_path[, ...])

Restores training state from a given model checkpoint.

save([checkpoint_dir, prevent_upload])

Saves the current model state to a checkpoint.

save_checkpoint(checkpoint_dir)

Subclasses should override this to implement save().

setup(config)

Subclasses should override this for custom initialization.

step()

Subclasses should override this to implement train().

stop()

Releases all resources used by this trainable.

train()

Runs one logical iteration of training.

train_buffered(buffer_time_s[, ...])

Runs multiple iterations of training.



iteration

Current training iteration.

logdir

Directory of the results and checkpoints for this Trainable.

training_iteration

Current training iteration (same as self.iteration).

trial_id

Trial ID for the corresponding trial of this Trainable.

trial_name

Trial name for the corresponding trial of this Trainable.

trial_resources

Resources currently assigned to the trial of this Trainable.