ray.tune.Trainable

class ray.tune.Trainable(config: Dict[str, Any] = None, logger_creator: Callable[[Dict[str, Any]], Logger] = None, storage: StorageContext | None = None)

Abstract class for trainable models, functions, etc.

A call to train() on a trainable executes one logical iteration of training. As a rule of thumb, one train() call should run long enough to amortize its overhead (i.e., more than a few seconds) but short enough to report progress periodically (i.e., at most a few minutes).

Calling save() should save the training state of a trainable to disk, and restore(path) should restore a trainable to the given state.

Generally you only need to implement setup, step, save_checkpoint, and load_checkpoint when subclassing Trainable.

Other implementation methods that may be helpful to override are log_result, reset_config, cleanup, and _export_model.

Tune will convert this class into a Ray actor, which runs in a separate process. By default, Tune also changes the current working directory of this process to its corresponding trial-level log directory, self.logdir. This ensures that different trials running on the same physical node don't accidentally write to the same location and overwrite each other's files.

This behavior can be disabled by setting the environment variable RAY_CHDIR_TO_TRIAL_DIR=0. Relative paths then resolve against the original working directory, so use them for reading only, because trials that write through the same relative path would overwrite each other's files. If you run on multiple machines, you must also make sure that this directory is synced to all nodes.

The TUNE_ORIG_WORKING_DIR environment variable was the original workaround for accessing paths relative to the original working directory. This environment variable is deprecated, and the RAY_CHDIR_TO_TRIAL_DIR environment variable described above should be used instead.

This class supports checkpointing to and restoring from remote storage.

Methods

__init__

Initialize a Trainable.

cleanup

Subclasses should override this for any cleanup on stop.

default_resource_request

Provides a static resource requirement for the given configuration.

export_model

Exports model based on export_formats.

get_auto_filled_metrics

Return a dict with metrics auto-filled by the trainable.

get_config

Returns configuration passed in by Tune.

load_checkpoint

Subclasses should override this to implement restore().

log_result

Subclasses can optionally override this to customize logging.

reset

Resets trial for use with new config.

reset_config

Resets configuration without restarting the trial.

resource_help

Returns a help string for configuring this trainable's resources.

restore

Restores training state from a given model checkpoint.

save

Saves the current model state to a checkpoint.

save_checkpoint

Subclasses should override this to implement save().

setup

Subclasses should override this for custom initialization.

step

Subclasses should override this to implement train().

stop

Releases all resources used by this trainable.

train

Runs one logical iteration of training.

train_buffered

Runs multiple iterations of training.

Attributes

iteration

Current training iteration.

logdir

Directory of the results and checkpoints for this Trainable.

training_iteration

Current training iteration (same as self.iteration).

trial_id

Trial ID for the corresponding trial of this Trainable.

trial_name

Trial name for the corresponding trial of this Trainable.

trial_resources

Resources currently assigned to the trial of this Trainable.