import glob
import warnings
from abc import ABCMeta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from ray.tune.utils.util import _atomic_save, _load_newest_checkpoint
from ray.util.annotations import DeveloperAPI, PublicAPI
if TYPE_CHECKING:
from ray.train import Checkpoint
from ray.tune.experiment import Trial
from ray.tune.stopper import Stopper
class _CallbackMeta(ABCMeta):
"""A helper metaclass to ensure container classes (e.g. CallbackList) have
implemented all the callback methods (e.g. `on_*`).
"""
def __new__(mcs, name: str, bases: Tuple[type], attrs: Dict[str, Any]) -> type:
cls = super().__new__(mcs, name, bases, attrs)
if mcs.need_check(cls, name, bases, attrs):
mcs.check(cls, name, bases, attrs)
return cls
@classmethod
def need_check(
mcs, cls: type, name: str, bases: Tuple[type], attrs: Dict[str, Any]
) -> bool:
return attrs.get("IS_CALLBACK_CONTAINER", False)
@classmethod
def check(
mcs, cls: type, name: str, bases: Tuple[type], attrs: Dict[str, Any]
) -> None:
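        # Collect every overridable hook defined on the base classes
        # (the ``on_*`` methods and ``setup``), then require that the
        # container class overrides each of them in its own body.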
methods = set()
for base in bases:
methods.update(
attr_name
for attr_name, attr in vars(base).items()
if mcs.need_override_by_subclass(attr_name, attr)
)
overridden = {
attr_name
for attr_name, attr in attrs.items()
if mcs.need_override_by_subclass(attr_name, attr)
}
missing = methods.difference(overridden)
if missing:
raise TypeError(
f"Found missing callback method: {missing} "
f"in class {cls.__module__}.{cls.__qualname__}."
)
@classmethod
def need_override_by_subclass(mcs, attr_name: str, attr: Any) -> bool:
return (
(
attr_name.startswith("on_")
and not attr_name.startswith("on_trainer_init")
)
or attr_name == "setup"
) and callable(attr)
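# For illustration only (hypothetical subclass): a container class that sets
# ``IS_CALLBACK_CONTAINER = True`` but does not override all of the hooks
# fails at class-definition time:
#
#     class BrokenCallbackList(Callback):
#         IS_CALLBACK_CONTAINER = True
#     # TypeError: Found missing callback method: {'on_step_begin', ...}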
@PublicAPI(stability="beta")
class Callback(metaclass=_CallbackMeta):
"""Tune base callback that can be extended and passed to a ``TrialRunner``
Tune callbacks are called from within the ``TrialRunner`` class. There are
several hooks that can be used, all of which are found in the submethod
definitions of this base class.
The parameters passed to the ``**info`` dict vary between hooks. The
parameters passed are described in the docstrings of the methods.
This example will print a metric each time a result is received:
.. testcode::
from ray import train, tune
from ray.tune import Callback
class MyCallback(Callback):
def on_trial_result(self, iteration, trials, trial, result,
**info):
print(f"Got result: {result['metric']}")
def train_func(config):
for i in range(10):
tune.report(metric=i)
tuner = tune.Tuner(
train_func,
run_config=train.RunConfig(
callbacks=[MyCallback()]
)
)
tuner.fit()
.. testoutput::
:hide:
...
"""
# File templates for any artifacts written by this callback
# These files should live in the `trial.local_path` for each trial.
# TODO(ml-team): Make this more visible to users to override. Internal use for now.
_SAVED_FILE_TEMPLATES = []
# arguments here match Experiment.public_spec
def setup(
self,
stop: Optional["Stopper"] = None,
num_samples: Optional[int] = None,
total_num_samples: Optional[int] = None,
**info,
):
"""Called once at the very beginning of training.
Any Callback setup should be added here (setting environment
variables, etc.)
Arguments:
stop: Stopping criteria.
If ``time_budget_s`` was passed to ``train.RunConfig``, a
``TimeoutStopper`` will be passed here, either by itself
or as a part of a ``CombinedStopper``.
num_samples: Number of times to sample from the
hyperparameter space. Defaults to 1. If `grid_search` is
provided as an argument, the grid will be repeated
`num_samples` of times. If this is -1, (virtually) infinite
samples are generated until a stopping condition is met.
total_num_samples: Total number of samples factoring
in grid search samplers.
**info: Kwargs dict for forward compatibility.
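
        For example, a callback could log the configured sample budget during
        setup. A minimal sketch (``PrintSetupCallback`` is a hypothetical
        example class):

        .. testcode::

            from ray.tune import Callback

            class PrintSetupCallback(Callback):
                def setup(self, stop=None, num_samples=None,
                          total_num_samples=None, **info):
                    print(f"Sampling the search space {num_samples} time(s)")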
"""
pass
def on_step_begin(self, iteration: int, trials: List["Trial"], **info):
"""Called at the start of each tuning loop step.
Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
**info: Kwargs dict for forward compatibility.
"""
pass
def on_step_end(self, iteration: int, trials: List["Trial"], **info):
"""Called at the end of each tuning loop step.
The iteration counter is increased before this hook is called.
Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
**info: Kwargs dict for forward compatibility.
"""
pass
def on_trial_start(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
"""Called after starting a trial instance.
Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
trial: Trial that just has been started.
**info: Kwargs dict for forward compatibility.
"""
pass
def on_trial_restore(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
"""Called after restoring a trial instance.
Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
trial: Trial that just has been restored.
**info: Kwargs dict for forward compatibility.
"""
pass
def on_trial_save(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
"""Called after receiving a checkpoint from a trial.
Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
trial: Trial that just saved a checkpoint.
**info: Kwargs dict for forward compatibility.
"""
pass
def on_trial_result(
self,
iteration: int,
trials: List["Trial"],
trial: "Trial",
result: Dict,
**info,
):
"""Called after receiving a result from a trial.
The search algorithm and scheduler are notified before this
hook is called.
Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
trial: Trial that just sent a result.
result: Result that the trial sent.
**info: Kwargs dict for forward compatibility.
"""
pass
def on_trial_complete(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
"""Called after a trial instance completed.
The search algorithm and scheduler are notified before this
hook is called.
Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
trial: Trial that just has been completed.
**info: Kwargs dict for forward compatibility.
"""
pass
def on_trial_recover(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
"""Called after a trial instance failed (errored) but the trial is scheduled
for retry.
The search algorithm and scheduler are not notified.
Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
trial: Trial that just has errored.
**info: Kwargs dict for forward compatibility.
"""
pass
def on_trial_error(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
"""Called after a trial instance failed (errored).
The search algorithm and scheduler are notified before this
hook is called.
Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
trial: Trial that just has errored.
**info: Kwargs dict for forward compatibility.
"""
pass
def on_checkpoint(
self,
iteration: int,
trials: List["Trial"],
trial: "Trial",
checkpoint: "Checkpoint",
**info,
):
"""Called after a trial saved a checkpoint with Tune.
Arguments:
iteration: Number of iterations of the tuning loop.
trials: List of trials.
trial: Trial that just has errored.
checkpoint: Checkpoint object that has been saved
by the trial.
**info: Kwargs dict for forward compatibility.
"""
pass
def on_experiment_end(self, trials: List["Trial"], **info):
"""Called after experiment is over and all trials have concluded.
Arguments:
trials: List of trials.
**info: Kwargs dict for forward compatibility.
"""
pass
def get_state(self) -> Optional[Dict]:
"""Get the state of the callback.
This method should be implemented by subclasses to return a dictionary
representation of the object's current state.
This is called automatically by Tune to periodically checkpoint callback state.
Upon :ref:`Tune experiment restoration <tune-experiment-level-fault-tolerance>`,
callback state will be restored via :meth:`~ray.tune.Callback.set_state`.
.. testcode::
from typing import Dict, List, Optional
from ray.tune import Callback
from ray.tune.experiment import Trial
class MyCallback(Callback):
def __init__(self):
self._trial_ids = set()
def on_trial_start(
self, iteration: int, trials: List["Trial"], trial: "Trial", **info
):
self._trial_ids.add(trial.trial_id)
def get_state(self) -> Optional[Dict]:
return {"trial_ids": self._trial_ids.copy()}
def set_state(self, state: Dict) -> Optional[Dict]:
self._trial_ids = state["trial_ids"]
Returns:
dict: State of the callback. Should be `None` if the callback does not
have any state to save (this is the default).
"""
return None
def set_state(self, state: Dict):
"""Set the state of the callback.
This method should be implemented by subclasses to restore the callback's
state based on the given dict state.
This is used automatically by Tune to restore checkpoint callback state
on :ref:`Tune experiment restoration <tune-experiment-level-fault-tolerance>`.
See :meth:`~ray.tune.Callback.get_state` for an example implementation.
Args:
state: State of the callback.
"""
pass
@DeveloperAPI
class CallbackList(Callback):
"""Call multiple callbacks at once."""
IS_CALLBACK_CONTAINER = True
CKPT_FILE_TMPL = "callback-states-{}.pkl"
def __init__(self, callbacks: List[Callback]):
self._callbacks = callbacks
def setup(self, **info):
for callback in self._callbacks:
try:
callback.setup(**info)
except TypeError as e:
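                # Backwards compatibility: an older callback may define
                # ``setup()`` without accepting the newer keyword arguments,
                # which surfaces here as a TypeError about unexpected
                # arguments. Warn and retry the call without arguments.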
if "argument" in str(e):
warnings.warn(
"Please update `setup` method in callback "
f"`{callback.__class__}` to match the method signature"
" in `ray.tune.callback.Callback`.",
FutureWarning,
)
callback.setup()
else:
raise e
def on_step_begin(self, **info):
for callback in self._callbacks:
callback.on_step_begin(**info)
def on_step_end(self, **info):
for callback in self._callbacks:
callback.on_step_end(**info)
def on_trial_start(self, **info):
for callback in self._callbacks:
callback.on_trial_start(**info)
def on_trial_restore(self, **info):
for callback in self._callbacks:
callback.on_trial_restore(**info)
def on_trial_save(self, **info):
for callback in self._callbacks:
callback.on_trial_save(**info)
def on_trial_result(self, **info):
for callback in self._callbacks:
callback.on_trial_result(**info)
def on_trial_complete(self, **info):
for callback in self._callbacks:
callback.on_trial_complete(**info)
def on_trial_recover(self, **info):
for callback in self._callbacks:
callback.on_trial_recover(**info)
def on_trial_error(self, **info):
for callback in self._callbacks:
callback.on_trial_error(**info)
def on_checkpoint(self, **info):
for callback in self._callbacks:
callback.on_checkpoint(**info)
def on_experiment_end(self, **info):
for callback in self._callbacks:
callback.on_experiment_end(**info)
def get_state(self) -> Optional[Dict]:
"""Gets the state of all callbacks contained within this list.
If there are no stateful callbacks, then None will be returned in order
to avoid saving an unnecessary callback checkpoint file."""
state = {}
any_stateful_callbacks = False
for i, callback in enumerate(self._callbacks):
callback_state = callback.get_state()
if callback_state:
any_stateful_callbacks = True
state[i] = callback_state
if not any_stateful_callbacks:
return None
return state
def set_state(self, state: Dict):
"""Sets the state for all callbacks contained within this list.
Skips setting state for all stateless callbacks where `get_state`
returned None."""
for i, callback in enumerate(self._callbacks):
callback_state = state.get(i, None)
if callback_state:
callback.set_state(callback_state)
def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"):
"""Save the state of the callback list to the checkpoint_dir.
Args:
checkpoint_dir: directory where the checkpoint is stored.
session_str: Unique identifier of the current run session (ex: timestamp).
"""
state_dict = self.get_state()
if state_dict:
file_name = self.CKPT_FILE_TMPL.format(session_str)
tmp_file_name = f".tmp-{file_name}"
_atomic_save(
state=state_dict,
checkpoint_dir=checkpoint_dir,
file_name=file_name,
tmp_file_name=tmp_file_name,
)
def restore_from_dir(self, checkpoint_dir: str):
"""Restore the state of the list of callbacks from the checkpoint_dir.
You should check if it's possible to restore with `can_restore`
before calling this method.
Args:
checkpoint_dir: directory where the checkpoint is stored.
Raises:
RuntimeError: if unable to find checkpoint.
NotImplementedError: if the `set_state` method is not implemented.
"""
state_dict = _load_newest_checkpoint(
checkpoint_dir, self.CKPT_FILE_TMPL.format("*")
)
if not state_dict:
            raise RuntimeError(
                f"Unable to find checkpoint in {checkpoint_dir}."
            )
self.set_state(state_dict)
def can_restore(self, checkpoint_dir: str) -> bool:
"""Check if the checkpoint_dir contains the saved state for this callback list.
Returns:
can_restore: True if the checkpoint_dir contains a file of the
format `CKPT_FILE_TMPL`. False otherwise.
"""
return any(
glob.iglob(Path(checkpoint_dir, self.CKPT_FILE_TMPL.format("*")).as_posix())
)
def __len__(self) -> int:
return len(self._callbacks)
def __getitem__(self, i: int) -> "Callback":
return self._callbacks[i]
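# Usage sketch (illustrative only; the callback class and directory below are
# hypothetical):
#
#     callbacks = CallbackList([MyCallback()])
#     callbacks.save_to_dir("/tmp/my_experiment", session_str="2024-01-01")
#     if callbacks.can_restore("/tmp/my_experiment"):
#         callbacks.restore_from_dir("/tmp/my_experiment")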