# Source code for ray.tune.stopper.trial_plateau
from collections import defaultdict, deque
from typing import Dict, Optional
import numpy as np
from ray.tune.stopper.stopper import Stopper
from ray.util.annotations import PublicAPI
# [docs]
@PublicAPI
class TrialPlateauStopper(Stopper):
    """Early stop single trials when they reached a plateau.

    When the standard deviation of the `metric` result of a trial is
    below a threshold `std`, the trial plateaued and will be stopped
    early.

    Args:
        metric: Metric to check for convergence.
        std: Maximum metric standard deviation to decide if a
            trial plateaued. Defaults to 0.01.
        num_results: Number of results to consider for stdev
            calculation.
        grace_period: Minimum number of timesteps before a trial
            can be early stopped.
        metric_threshold: Minimum or maximum value the result has to
            exceed before it can be stopped early.
        mode: If a `metric_threshold` argument has been
            passed, this must be one of [min, max]. Specifies if we optimize
            for a large metric (max) or a small metric (min). If max, the
            `metric_threshold` has to be exceeded, if min the value has to
            be lower than `metric_threshold` in order to early stop.

    Raises:
        ValueError: If `metric_threshold` is given and `mode` is not one
            of "min" or "max".
    """

    def __init__(
        self,
        metric: str,
        std: float = 0.01,
        num_results: int = 4,
        grace_period: int = 4,
        metric_threshold: Optional[float] = None,
        mode: Optional[str] = None,
    ):
        self._metric = metric
        self._mode = mode

        self._std = std
        self._num_results = num_results
        self._grace_period = grace_period
        self._metric_threshold = metric_threshold

        # Compare against None rather than relying on truthiness: a threshold
        # of 0 / 0.0 is a legitimate value and must still require a valid
        # `mode`, otherwise the threshold would be silently ignored later.
        if self._metric_threshold is not None and mode not in ("min", "max"):
            raise ValueError(
                "When specifying a `metric_threshold`, the `mode` "
                "argument has to be one of [min, max]. "
                f"Got: {mode}"
            )

        # Number of results seen so far, keyed by trial id.
        self._iter = defaultdict(int)
        # Sliding window holding the last `num_results` metric values,
        # keyed by trial id.
        self._trial_results = defaultdict(lambda: deque(maxlen=self._num_results))

    def __call__(self, trial_id: str, result: Dict) -> bool:
        """Record the latest result of a trial and decide whether to stop it.

        Args:
            trial_id: Identifier of the trial that reported `result`.
            result: Result dict; the value stored under the configured
                metric key is appended to the trial's sliding window.

        Returns:
            True if the trial should be stopped, False otherwise.
        """
        metric_result = result.get(self._metric)
        self._trial_results[trial_id].append(metric_result)
        self._iter[trial_id] += 1

        # If still in grace period, do not stop yet
        if self._iter[trial_id] < self._grace_period:
            return False

        # If not enough results yet, do not stop yet
        if len(self._trial_results[trial_id]) < self._num_results:
            return False

        # If metric threshold value not reached, do not stop yet
        if self._metric_threshold is not None:
            # A result without the metric cannot have reached the threshold;
            # comparing None against a number would raise a TypeError.
            if metric_result is None:
                return False
            if self._mode == "min" and metric_result > self._metric_threshold:
                return False
            if self._mode == "max" and metric_result < self._metric_threshold:
                return False

        # Calculate stdev of last `num_results` results
        try:
            current_std = np.std(self._trial_results[trial_id])
        except Exception:
            # Best effort: a window that cannot be reduced to a standard
            # deviation (e.g. it holds a None) is treated as "not plateaued".
            current_std = float("inf")

        # If stdev is lower than threshold, stop early.
        return bool(current_std < self._std)

    def stop_all(self) -> bool:
        """Return False; this stopper never requests stopping all trials."""
        return False