Source code for ray.tune.schedulers.median_stopping_rule
import collections
import logging
from typing import TYPE_CHECKING, Dict, List, Optional
import numpy as np
from ray.tune.experiment import Trial
from ray.tune.result import DEFAULT_METRIC
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
    from ray.tune.execution.tune_controller import TuneController
logger = logging.getLogger(__name__)
@PublicAPI
class MedianStoppingRule(FIFOScheduler):
"""Implements the median stopping rule as described in the Vizier paper:
https://research.google.com/pubs/pub46180.html
Args:
time_attr: The training result attr to use for comparing time.
Note that you can pass in something non-temporal such as
`training_iteration` as a measure of progress, the only requirement
is that the attribute should increase monotonically.
metric: The training result objective value attribute. Stopping
procedures will use this attribute. If None but a mode was passed,
the `ray.tune.result.DEFAULT_METRIC` will be used per default.
mode: One of {min, max}. Determines whether objective is
minimizing or maximizing the metric attribute.
grace_period: Only stop trials at least this old in time.
The mean will only be computed from this time onwards. The units
are the same as the attribute named by `time_attr`.
min_samples_required: Minimum number of trials to compute median
over.
min_time_slice: Each trial runs at least this long before
yielding (assuming it isn't stopped). Note: trials ONLY yield if
there are not enough samples to evaluate performance for the
current result AND there are other trials waiting to run.
The units are the same as the attribute named by `time_attr`.
hard_stop: If False, pauses trials instead of stopping
them. When all other trials are complete, paused trials will be
resumed and allowed to run FIFO.
"""

    def __init__(
        self,
        time_attr: str = "time_total_s",
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        grace_period: float = 60.0,
        min_samples_required: int = 3,
        min_time_slice: int = 0,
        hard_stop: bool = True,
    ):
        super().__init__()
        self._stopped_trials = set()
        self._grace_period = grace_period
        self._min_samples_required = min_samples_required
        self._min_time_slice = min_time_slice
        self._metric = metric
        self._worst = None
        self._compare_op = None

        self._mode = mode
        if mode:
            assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
            self._worst = float("-inf") if self._mode == "max" else float("inf")
            self._compare_op = max if self._mode == "max" else min

        self._time_attr = time_attr
        self._hard_stop = hard_stop
        self._trial_state = {}
        self._last_pause = collections.defaultdict(lambda: float("-inf"))
        self._results = collections.defaultdict(list)

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], **spec
    ) -> bool:
        if self._metric and metric:
            return False
        if self._mode and mode:
            return False

        if metric:
            self._metric = metric
        if mode:
            self._mode = mode
            self._worst = float("-inf") if self._mode == "max" else float("inf")
            self._compare_op = max if self._mode == "max" else min

        if self._metric is None and self._mode:
            # If only a mode was passed, use anonymous metric
            self._metric = DEFAULT_METRIC

        return True

    def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
        if not self._metric or not self._worst or not self._compare_op:
            raise ValueError(
                "{} has been instantiated without a valid `metric` ({}) or "
                "`mode` ({}) parameter. Either pass these parameters when "
                "instantiating the scheduler, or pass them as parameters "
                "to `tune.TuneConfig()`".format(
                    self.__class__.__name__, self._metric, self._mode
                )
            )

        super(MedianStoppingRule, self).on_trial_add(tune_controller, trial)

    def on_trial_result(
        self, tune_controller: "TuneController", trial: Trial, result: Dict
    ) -> str:
        """Callback for early stopping.

        This stopping rule stops a running trial if the trial's best objective
        value by step `t` is strictly worse than the median of the running
        averages of all completed trials' objectives reported up to step `t`.
        """
        if self._time_attr not in result or self._metric not in result:
            return TrialScheduler.CONTINUE

        if trial in self._stopped_trials:
            assert not self._hard_stop
            # Fall back to FIFO
            return TrialScheduler.CONTINUE

        time = result[self._time_attr]
        self._results[trial].append(result)

        if time < self._grace_period:
            return TrialScheduler.CONTINUE

        trials = self._trials_beyond_time(time)
        trials.remove(trial)

        if len(trials) < self._min_samples_required:
            action = self._on_insufficient_samples(tune_controller, trial, time)
            if action == TrialScheduler.PAUSE:
                self._last_pause[trial] = time
                action_str = "Yielding time to other trials."
            else:
                action_str = "Continuing anyway."
            logger.debug(
                "MedianStoppingRule: insufficient samples={} to evaluate "
                "trial {} at t={}. {}".format(
                    len(trials), trial.trial_id, time, action_str
                )
            )
            return action

        median_result = self._median_result(trials, time)
        best_result = self._best_result(trial)
        logger.debug(
            "Trial {} best res={} vs median res={} at t={}".format(
                trial, best_result, median_result, time
            )
        )

        if self._compare_op(median_result, best_result) != best_result:
            logger.debug("MedianStoppingRule: early stopping {}".format(trial))
            self._stopped_trials.add(trial)
            if self._hard_stop:
                return TrialScheduler.STOP
            else:
                return TrialScheduler.PAUSE
        else:
            return TrialScheduler.CONTINUE
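
    # Worked illustration (editor's addition; the numbers are hypothetical): with
    # mode="min", suppose the other trials' running means at time t are
    # [0.4, 0.6, 0.9], so the median is 0.6. A trial whose best reported metric
    # so far is 0.7 gets stopped, because min(0.6, 0.7) == 0.6 != 0.7; a trial
    # whose best is 0.5 (or exactly 0.6) continues.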

    def on_trial_complete(
        self, tune_controller: "TuneController", trial: Trial, result: Dict
    ):
        self._results[trial].append(result)

    def debug_string(self) -> str:
        return "Using MedianStoppingRule: num_stopped={}.".format(
            len(self._stopped_trials)
        )

    def _on_insufficient_samples(
        self, tune_controller: "TuneController", trial: Trial, time: float
    ) -> str:
        pause = time - self._last_pause[trial] > self._min_time_slice
        pause = pause and [
            t
            for t in tune_controller.get_live_trials()
            if t.status in (Trial.PENDING, Trial.PAUSED)
        ]
        return TrialScheduler.PAUSE if pause else TrialScheduler.CONTINUE

    def _trials_beyond_time(self, time: float) -> List[Trial]:
        trials = [
            trial
            for trial in self._results
            if self._results[trial][-1][self._time_attr] >= time
        ]
        return trials

    def _median_result(self, trials: List[Trial], time: float):
        return np.median([self._running_mean(trial, time) for trial in trials])

    def _running_mean(self, trial: Trial, time: float) -> np.ndarray:
        results = self._results[trial]
        # TODO(ekl) we could do interpolation to be more precise, but for now
        # assume len(results) is large and the time diffs are roughly equal
        scoped_results = [
            r for r in results if self._grace_period <= r[self._time_attr] <= time
        ]
        return np.mean([r[self._metric] for r in scoped_results])
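
    # Editor's note (illustrative sketch, not part of the Ray source): the TODO in
    # `_running_mean` mentions interpolation. One possible way to weight unevenly
    # spaced reports by time would be a trapezoidal average over the scoped
    # results, e.g.:
    #
    #     ts = np.array([r[self._time_attr] for r in scoped_results])
    #     ys = np.array([r[self._metric] for r in scoped_results])
    #     mean = np.trapz(ys, ts) / (ts[-1] - ts[0]) if len(ts) > 1 else ys.mean()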

    def _best_result(self, trial):
        results = self._results[trial]
        return self._compare_op([r[self._metric] for r in results])