Source code for ray.tune.stopper.stopper

import abc
from typing import Any, Dict

from ray.util.annotations import PublicAPI


[docs]@PublicAPI class Stopper(abc.ABC): """Base class for implementing a Tune experiment stopper. Allows users to implement experiment-level stopping via ``stop_all``. By default, this class does not stop any trials. Subclasses need to implement ``__call__`` and ``stop_all``. Examples: >>> import time >>> from ray import train, tune >>> from ray.tune import Stopper >>> >>> class TimeStopper(Stopper): ... def __init__(self): ... self._start = time.time() ... self._deadline = 2 # Stop all trials after 2 seconds ... ... def __call__(self, trial_id, result): ... return False ... ... def stop_all(self): ... return time.time() - self._start > self._deadline ... >>> def train_fn(config): ... for i in range(100): ... time.sleep(1) ... train.report({"iter": i}) ... >>> tuner = tune.Tuner( ... train_fn, ... tune_config=tune.TuneConfig(num_samples=2), ... run_config=train.RunConfig(stop=TimeStopper()), ... ) >>> print("[ignore]"); result_grid = tuner.fit() # doctest: +ELLIPSIS [ignore]... """
[docs] def __call__(self, trial_id: str, result: Dict[str, Any]) -> bool: """Returns true if the trial should be terminated given the result.""" raise NotImplementedError
[docs] def stop_all(self) -> bool: """Returns true if the experiment should be terminated.""" raise NotImplementedError
[docs]@PublicAPI class CombinedStopper(Stopper): """Combine several stoppers via 'OR'. Args: *stoppers: Stoppers to be combined. Examples: >>> import numpy as np >>> from ray import train, tune >>> from ray.tune.stopper import ( ... CombinedStopper, ... MaximumIterationStopper, ... TrialPlateauStopper, ... ) >>> >>> stopper = CombinedStopper( ... MaximumIterationStopper(max_iter=10), ... TrialPlateauStopper(metric="my_metric"), ... ) >>> def train_fn(config): ... for i in range(15): ... train.report({"my_metric": np.random.normal(0, 1 - i / 15)}) ... >>> tuner = tune.Tuner( ... train_fn, ... run_config=train.RunConfig(stop=stopper), ... ) >>> print("[ignore]"); result_grid = tuner.fit() # doctest: +ELLIPSIS [ignore]... >>> all(result.metrics["training_iteration"] <= 20 for result in result_grid) True """ def __init__(self, *stoppers: Stopper): self._stoppers = stoppers def __call__(self, trial_id: str, result: Dict[str, Any]) -> bool: return any(s(trial_id, result) for s in self._stoppers) def stop_all(self) -> bool: return any(s.stop_all() for s in self._stoppers)