Source code for ray.tune.schedulers.trial_scheduler
from typing import TYPE_CHECKING, Dict, Optional
from ray.air._internal.usage import tag_scheduler
from ray.tune.experiment import Trial
from ray.tune.result import DEFAULT_METRIC
from ray.util.annotations import DeveloperAPI, PublicAPI
if TYPE_CHECKING:
from ray.tune.execution.tune_controller import TuneController
@DeveloperAPI
class TrialScheduler:
"""Interface for implementing a Trial Scheduler class.
Note to Tune developers: If a new scheduler is added, please update
`air/_internal/usage.py`.
"""
CONTINUE = "CONTINUE" #: Status for continuing trial execution
PAUSE = "PAUSE" #: Status for pausing trial execution
STOP = "STOP" #: Status for stopping trial execution
# Caution: Temporary and anti-pattern! This means Scheduler calls
# into Executor directly without going through TrialRunner.
# TODO(xwjiang): Deprecate this after we control the interaction
# between schedulers and executor.
NOOP = "NOOP"
_metric = None
_supports_buffered_results = True
def __init__(self):
tag_scheduler(self)
@property
def metric(self):
return self._metric
@property
def supports_buffered_results(self):
return self._supports_buffered_results
def set_search_properties(
self, metric: Optional[str], mode: Optional[str], **spec
) -> bool:
"""Pass search properties to scheduler.
This method acts as an alternative to instantiating schedulers
that react to metrics with their own `metric` and `mode` parameters.
Args:
metric: Metric to optimize
mode: One of ["min", "max"]. Direction to optimize.
**spec: Any kwargs for forward compatiblity.
Info like Experiment.PUBLIC_KEYS is provided through here.
"""
if self._metric and metric:
return False
if metric:
self._metric = metric
if self._metric is None:
            # By default, use the anonymous metric
self._metric = DEFAULT_METRIC
return True
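
    # Illustrative sketch, not part of this module: how the resolution above
    # plays out. A scheduler that already has a metric rejects a conflicting
    # experiment-level one, while a scheduler without a metric adopts whatever
    # is passed in (or falls back to the anonymous DEFAULT_METRIC).
    #
    #     scheduler = FIFOScheduler()  # any TrialScheduler subclass
    #     scheduler.set_search_properties(metric="loss", mode="min")  # -> True
    #     scheduler.set_search_properties(metric="acc", mode="max")   # -> False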
def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
"""Called when a new trial is added to the trial runner."""
raise NotImplementedError
def on_trial_error(self, tune_controller: "TuneController", trial: Trial):
"""Notification for the error of trial.
This will only be called when the trial is in the RUNNING state."""
raise NotImplementedError
def on_trial_result(
self, tune_controller: "TuneController", trial: Trial, result: Dict
) -> str:
"""Called on each intermediate result returned by a trial.
At this point, the trial scheduler can make a decision by returning
one of CONTINUE, PAUSE, and STOP. This will only be called when the
trial is in the RUNNING state."""
raise NotImplementedError
def on_trial_complete(
self, tune_controller: "TuneController", trial: Trial, result: Dict
):
"""Notification for the completion of trial.
This will only be called when the trial is in the RUNNING state and
either completes naturally or by manual termination."""
raise NotImplementedError
def on_trial_remove(self, tune_controller: "TuneController", trial: Trial):
"""Called to remove trial.
This is called when the trial is in PAUSED or PENDING state. Otherwise,
call `on_trial_complete`."""
raise NotImplementedError
def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
"""Called to choose a new trial to run.
This should return one of the trials in tune_controller that is in
the PENDING or PAUSED state. This function must be idempotent.
If no trial is ready, return None."""
raise NotImplementedError
def debug_string(self) -> str:
"""Returns a human readable message for printing to the console."""
raise NotImplementedError
def save(self, checkpoint_path: str):
"""Save trial scheduler to a checkpoint"""
raise NotImplementedError
def restore(self, checkpoint_path: str):
"""Restore trial scheduler from checkpoint."""
raise NotImplementedError
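
# The sketch below is illustrative only and not part of ray.tune: a minimal
# custom scheduler built on the interface above. The class name and the
# `max_results` argument are assumptions. It stops each trial after a fixed
# number of reported results and otherwise schedules trials FIFO-style.
class _ExampleCapScheduler(TrialScheduler):
    def __init__(self, max_results: int = 10):
        super().__init__()
        self._max_results = max_results
        # Number of results seen so far, keyed by trial ID.
        self._result_counts: Dict[str, int] = {}

    def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
        self._result_counts.setdefault(trial.trial_id, 0)

    def on_trial_error(self, tune_controller: "TuneController", trial: Trial):
        self._result_counts.pop(trial.trial_id, None)

    def on_trial_result(
        self, tune_controller: "TuneController", trial: Trial, result: Dict
    ) -> str:
        self._result_counts[trial.trial_id] = (
            self._result_counts.get(trial.trial_id, 0) + 1
        )
        if self._result_counts[trial.trial_id] >= self._max_results:
            return TrialScheduler.STOP
        return TrialScheduler.CONTINUE

    def on_trial_complete(
        self, tune_controller: "TuneController", trial: Trial, result: Dict
    ):
        self._result_counts.pop(trial.trial_id, None)

    def on_trial_remove(self, tune_controller: "TuneController", trial: Trial):
        self._result_counts.pop(trial.trial_id, None)

    def choose_trial_to_run(
        self, tune_controller: "TuneController"
    ) -> Optional[Trial]:
        # Prefer PENDING trials, then PAUSED ones, mirroring FIFOScheduler below.
        for status in (Trial.PENDING, Trial.PAUSED):
            for trial in tune_controller.get_trials():
                if trial.status == status:
                    return trial
        return None

    def debug_string(self) -> str:
        return f"Capping each trial at {self._max_results} results."
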
@PublicAPI
class FIFOScheduler(TrialScheduler):
"""Simple scheduler that just runs trials in submission order."""
def __init__(self):
super().__init__()
def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
pass
def on_trial_error(self, tune_controller: "TuneController", trial: Trial):
pass
def on_trial_result(
self, tune_controller: "TuneController", trial: Trial, result: Dict
) -> str:
return TrialScheduler.CONTINUE
def on_trial_complete(
self, tune_controller: "TuneController", trial: Trial, result: Dict
):
pass
def on_trial_remove(self, tune_controller: "TuneController", trial: Trial):
pass
def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
for trial in tune_controller.get_trials():
if trial.status == Trial.PENDING:
return trial
for trial in tune_controller.get_trials():
if trial.status == Trial.PAUSED:
return trial
return None
def debug_string(self) -> str:
return "Using FIFO scheduling algorithm."