Source code for ray.tune.stopper.function_stopper
from typing import Callable, Dict
from ray.tune.stopper.stopper import Stopper
from ray.util.annotations import PublicAPI
[docs]
@PublicAPI
class FunctionStopper(Stopper):
"""Provide a custom function to check if trial should be stopped.
The passed function will be called after each iteration. If it returns
True, the trial will be stopped.
Args:
function: Function that checks if a trial
should be stopped. Must accept the `trial_id` string and `result`
dictionary as arguments. Must return a boolean.
"""
def __init__(self, function: Callable[[str, Dict], bool]):
self._fn = function
def __call__(self, trial_id, result):
return self._fn(trial_id, result)
def stop_all(self):
return False
@classmethod
def is_valid_function(cls, fn):
is_function = callable(fn) and not issubclass(type(fn), Stopper)
if is_function and hasattr(fn, "stop_all"):
raise ValueError(
"Stop object must be ray.tune.Stopper subclass to be detected "
"correctly."
)
return is_function