# Source code for ray.tune.stopper.experiment_plateau
import numpy as np
from ray.tune.stopper.stopper import Stopper
from ray.util.annotations import PublicAPI
# [docs]
@PublicAPI
class ExperimentPlateauStopper(Stopper):
    """Early stop the experiment when a metric plateaued across trials.

    Every reported result contributes its metric value to a pool that is
    trimmed to the ``top`` best values (smallest for ``mode="min"``,
    largest for ``mode="max"``). The experiment is considered plateaued
    once that pool is full and its standard deviation is at most ``std``.

    Args:
        metric: The metric to be monitored.
        std: The minimal standard deviation after which
            the tuning process has to stop.
        top: The number of best models to consider.
        mode: The mode to select the top results.
            Can either be "min" or "max".
        patience: Number of consecutive plateaued results required
            before stopping. The counter is incremented before it is
            compared, so ``0`` and ``1`` both stop on the first
            plateaued result.

    Raises:
        ValueError: If the mode parameter is not "min" nor "max".
        ValueError: If the top parameter is not an integer
            greater than 1.
        ValueError: If the standard deviation parameter is not
            a strictly positive number.
        ValueError: If the patience parameter is not
            a non-negative integer.
    """

    def __init__(
        self,
        metric: str,
        std: float = 0.001,
        top: int = 10,
        mode: str = "min",
        patience: int = 0,
    ):
        if mode not in ("min", "max"):
            raise ValueError("The mode parameter can only be either min or max.")
        if not isinstance(top, int) or top <= 1:
            raise ValueError(
                "Top results to consider must be"
                " a positive integer greater than one."
            )
        if not isinstance(patience, int) or patience < 0:
            # Zero is valid (and the default), so the message must not
            # claim that a strictly positive value is required.
            raise ValueError("Patience must be a non-negative integer.")
        # Accept integer thresholds such as ``std=1`` as well as floats;
        # bool is excluded explicitly because it is a subclass of int.
        std_is_number = isinstance(std, (int, float)) and not isinstance(std, bool)
        if not std_is_number or std <= 0:
            raise ValueError(
                "The standard deviation must be a strictly positive float number."
            )
        self._mode = mode
        self._metric = metric
        self._patience = patience
        # Count of consecutive calls during which the plateau held.
        self._iterations = 0
        self._std = std
        self._top = top
        # Best ``top`` metric values seen so far, kept in ascending order.
        self._top_values = []

    def __call__(self, trial_id: str, result: dict) -> bool:
        """Return a boolean representing if the tuning has to stop.

        Args:
            trial_id: Identifier of the reporting trial (not used here).
            result: Result mapping; must contain the monitored metric.

        Returns:
            True if the experiment should stop, False otherwise.
        """
        self._top_values.append(result[self._metric])
        ranked = sorted(self._top_values)
        if self._mode == "min":
            # Smallest values are the best ones.
            self._top_values = ranked[: self._top]
        else:
            # Largest values are the best ones.
            self._top_values = ranked[-self._top :]

        if self.has_plateaued():
            # One more consecutive result with the plateau in place.
            self._iterations += 1
        else:
            # The plateau was broken (or not yet reached): start over.
            self._iterations = 0

        # Delegate the final decision so that the patience check lives in
        # a single place.
        return self.stop_all()

    def has_plateaued(self) -> bool:
        """Return whether the pool of top values is full and tightly clustered."""
        if len(self._top_values) != self._top:
            return False
        # bool() turns numpy's comparison result into a plain Python bool.
        return bool(np.std(self._top_values) <= self._std)

    def stop_all(self) -> bool:
        """Return whether to stop and prevent trials from starting."""
        return bool(self.has_plateaued() and self._iterations >= self._patience)