A Guide To Callbacks & Metrics in Tune#
How to work with Callbacks in Ray Tune?#
Ray Tune supports callbacks that are invoked at various points during the training process.
Callbacks are passed to RunConfig through its callbacks parameter, the RunConfig is passed to the Tuner, and Tune then automatically invokes the hook methods you implement.
This simple callback just prints a metric each time a result is received:
from ray import train, tune
from ray.train import RunConfig
from ray.tune import Callback


class MyCallback(Callback):
    def on_trial_result(self, iteration, trials, trial, result, **info):
        print(f"Got result: {result['metric']}")


def train_fn(config):
    for i in range(10):
        train.report({"metric": i})


tuner = tune.Tuner(
    train_fn,
    run_config=RunConfig(callbacks=[MyCallback()]))
tuner.fit()
For more details and available hooks, please see the API docs for Ray Tune callbacks.
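To illustrate how several hooks can work together, here is a minimal sketch of a callback that tracks the best value of one metric per trial. The class name BestMetricCallback and its metric argument are hypothetical and chosen for illustration; on_trial_result, on_trial_complete and on_experiment_end are hooks of the Callback interface. The ImportError fallback is an assumption added only so the sketch can be imported where Ray is not installed.

```python
try:
    from ray.tune import Callback
except ImportError:
    # Assumption: fall back to a plain base class so the sketch still
    # imports in an environment without Ray.
    Callback = object


class BestMetricCallback(Callback):
    """Hypothetical helper: remembers the best metric value per trial."""

    def __init__(self, metric="metric"):
        super().__init__()
        self.metric = metric
        self.best = {}  # trial ID -> best value seen so far

    def on_trial_result(self, iteration, trials, trial, result, **info):
        value = result.get(self.metric)
        if value is None:
            return  # this result did not contain the tracked metric
        previous = self.best.get(trial.trial_id, value)
        self.best[trial.trial_id] = max(previous, value)

    def on_trial_complete(self, iteration, trials, trial, **info):
        if trial.trial_id in self.best:
            best = self.best[trial.trial_id]
            print(f"{trial.trial_id} finished, best {self.metric}: {best}")

    def on_experiment_end(self, trials, **info):
        print(f"Experiment ended after {len(trials)} trial(s)")
```

It would be registered the same way as MyCallback above, for example RunConfig(callbacks=[BestMetricCallback(metric="metric")]).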
How to log metrics in Tune?#
You can log arbitrary values and metrics in both the Function and Class training APIs:
from ray import train, tune


# Function API: report metrics explicitly with train.report().
def trainable(config):
    for i in range(num_epochs):
        ...
        train.report({"acc": accuracy, "metric_foo": random_metric_1, "bar": metric_2})


# Class API: return the metrics from step().
class Trainable(tune.Trainable):
    def step(self):
        ...
        # don't call report here!
        return dict(acc=accuracy, metric_foo=random_metric_1, bar=metric_2)
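As a more concrete variant of the Function API snippet above, the following sketch reports several metrics per epoch. The helper fake_epoch_metrics, its formula, and the config keys num_epochs and lr are made up for illustration and stand in for a real training and evaluation step.

```python
def fake_epoch_metrics(epoch, lr):
    """Return made-up metrics for one epoch (placeholder for real evaluation)."""
    loss = 1.0 / (1.0 + lr * (epoch + 1))
    return {"acc": 1.0 - loss, "loss": loss, "epoch": epoch}


def trainable(config):
    # Imported lazily so the helper above stays usable without Ray installed.
    from ray import train

    for epoch in range(config["num_epochs"]):
        # Every key in the reported dict becomes a metric of this result.
        train.report(fake_epoch_metrics(epoch, config["lr"]))
```

Assuming the hypothetical config keys above, this could be launched with tune.Tuner(trainable, param_space={"num_epochs": 5, "lr": 0.1}).fit().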
Tip
Note that train.report() is not meant to transfer large amounts of data, like models or datasets.
Doing so can incur large overheads and slow down your Tune run significantly.
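One way to follow this tip is to reduce large per-sample outputs to a few scalars before reporting them. The sketch below assumes classification-style predictions and labels held in plain Python lists; the function name summarize_predictions is hypothetical.

```python
def summarize_predictions(predictions, labels):
    """Reduce per-sample outputs to a small dict of scalars."""
    if not labels or len(predictions) != len(labels):
        raise ValueError("predictions and labels must be non-empty and equal in length")
    correct = sum(p == y for p, y in zip(predictions, labels))
    return {"acc": correct / len(labels), "num_samples": len(labels)}
```

Inside a trainable, the call would then look like train.report(summarize_predictions(predictions, labels)), so that only the summary is sent rather than the raw lists.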
Which Tune metrics get automatically filled in?#
Tune has the concept of auto-filled metrics. During training, Tune automatically logs the metrics below in addition to any user-provided values. All of these can be used as stopping conditions or passed as a parameter to Trial Schedulers/Search Algorithms.
config: The hyperparameter configuration
date: String-formatted date and time when the result was processed
done: True if the trial has been finished, False otherwise
episodes_total: Total number of episodes (for RLlib trainables)
experiment_id: Unique experiment ID
experiment_tag: Unique experiment tag (includes parameter values)
hostname: Hostname of the worker
iterations_since_restore: The number of times train.report() has been called after restoring the worker from a checkpoint
node_ip: Host IP of the worker
pid: Process ID (PID) of the worker process
time_since_restore: Time in seconds since restoring from a checkpoint
time_this_iter_s: Runtime of the current training iteration in seconds (i.e. one call to the trainable function or to _train() in the class API)
time_total_s: Total runtime in seconds
timestamp: Timestamp when the result was processed
timesteps_since_restore: Number of timesteps since restoring from a checkpoint
timesteps_total: Total number of timesteps
training_iteration: The number of times train.report() has been called
trial_id: Unique trial ID
All of these metrics can be seen in the Trial.last_result dictionary.
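Because the auto-filled metrics appear in every result dictionary, they can drive a stopping condition. The sketch below builds a function with the (trial_id, result) signature that Tune accepts as a stop criterion; the factory name make_budget_stopper and its thresholds are hypothetical.

```python
def make_budget_stopper(max_iterations, max_seconds):
    """Build a stop function based on two auto-filled metrics."""

    def stop(trial_id, result):
        # training_iteration and time_total_s are filled in by Tune.
        return (
            result["training_iteration"] >= max_iterations
            or result["time_total_s"] > max_seconds
        )

    return stop
```

Such a function could be passed as RunConfig(stop=make_budget_stopper(20, 600.0)). For a single threshold, the dictionary form RunConfig(stop={"training_iteration": 20}) expresses the same idea more compactly.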