Source code for ray.tune.stopper.trial_plateau

from collections import defaultdict, deque
from typing import Dict, Optional

import numpy as np

from ray.tune.stopper.stopper import Stopper
from ray.util.annotations import PublicAPI


@PublicAPI
class TrialPlateauStopper(Stopper):
    """Early stop single trials when they reach a plateau.

    When the standard deviation of the `metric` result of a trial is
    below a threshold `std`, the trial plateaued and will be stopped
    early.

    Args:
        metric: Metric to check for convergence.
        std: Maximum metric standard deviation to decide if a
            trial plateaued. Defaults to 0.01.
        num_results: Number of results to consider for stdev
            calculation.
        grace_period: Minimum number of timesteps before a trial
            can be early stopped.
        metric_threshold: Minimum or maximum value the result has to
            exceed before it can be stopped early. Any non-None value,
            including ``0``, counts as a threshold.
        mode: If a `metric_threshold` argument has been passed, this
            must be one of [min, max]. Specifies if we optimize for
            a large metric (max) or a small metric (min). If max, the
            `metric_threshold` has to be exceeded, if min the value has to
            be lower than `metric_threshold` in order to early stop.

    Raises:
        ValueError: If `metric_threshold` is given and `mode` is not one
            of ``"min"`` or ``"max"``.
    """

    def __init__(
        self,
        metric: str,
        std: float = 0.01,
        num_results: int = 4,
        grace_period: int = 4,
        metric_threshold: Optional[float] = None,
        mode: Optional[str] = None,
    ):
        self._metric = metric
        self._mode = mode

        self._std = std
        self._num_results = num_results
        self._grace_period = grace_period
        self._metric_threshold = metric_threshold

        # Compare against None explicitly: a falsy threshold such as 0.0 is
        # still a threshold and needs a valid mode. Otherwise __call__ would
        # silently ignore it, because neither mode branch there would match.
        if self._metric_threshold is not None:
            if mode not in ["min", "max"]:
                raise ValueError(
                    f"When specifying a `metric_threshold`, the `mode` "
                    f"argument has to be one of [min, max]. "
                    f"Got: {mode}"
                )

        # Total number of results seen per trial (not capped by num_results).
        self._iter = defaultdict(int)
        # Per-trial window holding only the most recent `num_results` values.
        self._trial_results = defaultdict(
            lambda: deque(maxlen=self._num_results)
        )

    def __call__(self, trial_id: str, result: Dict) -> bool:
        """Record ``result`` for ``trial_id`` and decide whether to stop it.

        Args:
            trial_id: Identifier of the trial that produced ``result``.
            result: Result dict; the value under ``metric`` is recorded
                (``None`` if the key is missing).

        Returns:
            True if the trial has plateaued and should be stopped,
            False otherwise.
        """
        metric_result = result.get(self._metric)
        self._trial_results[trial_id].append(metric_result)
        self._iter[trial_id] += 1

        # If still in grace period, do not stop yet
        if self._iter[trial_id] < self._grace_period:
            return False

        # If not enough results yet, do not stop yet
        if len(self._trial_results[trial_id]) < self._num_results:
            return False

        # If metric threshold value not reached, do not stop yet
        if self._metric_threshold is not None:
            if self._mode == "min" and metric_result > self._metric_threshold:
                return False
            elif self._mode == "max" and metric_result < self._metric_threshold:
                return False

        # Calculate stdev of last `num_results` results
        try:
            current_std = np.std(self._trial_results[trial_id])
        except Exception:
            # Best effort: if np.std cannot process the window (e.g. it
            # contains None from a missing metric), treat it as "no plateau".
            current_std = float("inf")

        # If stdev is lower than threshold, stop early. Cast so callers get
        # a plain bool rather than a numpy.bool_.
        return bool(current_std < self._std)

    def stop_all(self):
        """Never stop the whole experiment; this stopper only acts per trial."""
        return False