Source code for ray.tune.stopper.experiment_plateau

import numpy as np

from ray.util.annotations import PublicAPI
from ray.tune.stopper.stopper import Stopper


@PublicAPI
class ExperimentPlateauStopper(Stopper):
    """Early stop the experiment when a metric plateaued across trials.

    Stops the entire experiment once the standard deviation of the ``top``
    best values of ``metric`` (collected over all results of all trials)
    is at most ``std`` for enough consecutive results.

    The plateau counter is incremented *before* it is compared against
    ``patience`` with ``>=``, so ``patience=0`` and ``patience=1`` both stop
    on the first plateaued result; ``patience=n`` (n >= 1) stops on the n-th
    consecutive plateaued result.

    Args:
        metric: The metric to be monitored. It is read from every result
            dict passed to ``__call__``.
        std: The standard deviation threshold; the top values are considered
            plateaued when their standard deviation is at most this value.
        top: The number of best models to consider.
        mode: The mode to select the top results.
            Can either be "min" or "max".
        patience: Number of consecutive plateaued results to wait for
            before stopping.

    Raises:
        ValueError: If the mode parameter is not "min" nor "max".
        ValueError: If the top parameter is not an integer
            greater than 1.
        ValueError: If the standard deviation parameter is not
            a strictly positive number.
        ValueError: If the patience parameter is not
            a non-negative integer.
    """

    def __init__(
        self,
        metric: str,
        std: float = 0.001,
        top: int = 10,
        mode: str = "min",
        patience: int = 0,
    ):
        if mode not in ("min", "max"):
            raise ValueError("The mode parameter can only be either min or max.")
        if not isinstance(top, int) or top <= 1:
            raise ValueError(
                "Top results to consider must be"
                " a positive integer greater than one."
            )
        # bool is a subclass of int, so it has to be rejected explicitly.
        if (
            isinstance(patience, bool)
            or not isinstance(patience, int)
            or patience < 0
        ):
            # Zero is accepted (it is the default), hence "non-negative".
            raise ValueError("Patience must be a non-negative integer.")
        # Accept any real scalar (Python or NumPy), not only ``float``.
        # ``not std > 0`` also rejects NaN, which would otherwise make the
        # plateau comparison permanently false.
        if (
            isinstance(std, bool)
            or not isinstance(std, (int, float, np.integer, np.floating))
            or not std > 0
        ):
            raise ValueError(
                "The standard deviation must be a strictly positive float number."
            )
        self._mode = mode
        self._metric = metric
        self._patience = patience
        # Number of consecutive calls for which the metric was plateaued.
        self._iterations = 0
        self._std = std
        self._top = top
        # The (at most ``top``) best metric values seen so far, sorted.
        self._top_values = []

    def __call__(self, trial_id, result) -> bool:
        """Return a boolean representing if the tuning has to stop.

        Args:
            trial_id: Identifier of the reporting trial. It is not used: the
                values of all trials share one pool.
            result: Result mapping; ``result[metric]`` is added to the pool.

        Returns:
            The value of :meth:`stop_all` after the pool and the plateau
            counter have been updated.

        Raises:
            KeyError: If ``result`` does not contain the monitored metric.
        """
        self._top_values.append(result[self._metric])
        ranked = sorted(self._top_values)
        if self._mode == "min":
            # Smallest values are the best ones.
            self._top_values = ranked[: self._top]
        else:
            # Largest values are the best ones.
            self._top_values = ranked[-self._top :]

        if self.has_plateaued():
            # One more consecutive plateaued result.
            self._iterations += 1
        else:
            # The streak is broken, start counting again.
            self._iterations = 0

        # Re-run the checks, this time including the patience counter.
        return self.stop_all()

    def has_plateaued(self) -> bool:
        """Return whether the top values are currently plateaued.

        This is the case when ``top`` values have been collected and their
        standard deviation is at most ``std``.
        """
        # bool() so callers always get a plain bool instead of numpy.bool_.
        return bool(
            len(self._top_values) == self._top
            and np.std(self._top_values) <= self._std
        )

    def stop_all(self) -> bool:
        """Return whether to stop and prevent trials from starting."""
        return self.has_plateaued() and self._iterations >= self._patience