# Source code for ray.tune.schedulers.trial_scheduler
from typing import TYPE_CHECKING, Dict, Optional
from ray.air._internal.usage import tag_scheduler
from ray.tune.experiment import Trial
from ray.tune.result import DEFAULT_METRIC
from ray.util.annotations import DeveloperAPI, PublicAPI
if TYPE_CHECKING:
    from ray.tune.execution.tune_controller import TuneController
@DeveloperAPI
class TrialScheduler:
    """Interface for implementing a Trial Scheduler class.

    Note to Tune developers: If a new scheduler is added, please update
    `air/_internal/usage.py`.
    """

    CONTINUE = "CONTINUE"  #: Status for continuing trial execution
    PAUSE = "PAUSE"  #: Status for pausing trial execution
    STOP = "STOP"  #: Status for stopping trial execution
    # Caution: Temporary and anti-pattern! This means Scheduler calls
    # into Executor directly without going through TrialRunner.
    # TODO(xwjiang): Deprecate this after we control the interaction
    #  between schedulers and executor.
    NOOP = "NOOP"

    # Name of the metric this scheduler reacts to. Stays None until a
    # subclass or `set_search_properties` assigns one.
    _metric: Optional[str] = None
    # Class-level default returned by the `supports_buffered_results`
    # property; subclasses may override it.
    _supports_buffered_results: bool = True

    def __init__(self):
        # Record the concrete scheduler type via usage tagging (see the
        # developer note in the class docstring about `usage.py`).
        tag_scheduler(self)

    @property
    def metric(self) -> Optional[str]:
        """Name of the metric this scheduler uses, or None if not set yet."""
        return self._metric

    @property
    def supports_buffered_results(self) -> bool:
        """Return ``_supports_buffered_results`` (``True`` unless overridden).

        NOTE(review): presumably tells the caller whether results may be
        delivered to this scheduler in buffered batches -- confirm with the
        controller code that reads it.
        """
        return self._supports_buffered_results

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], **spec
    ) -> bool:
        """Pass search properties to scheduler.

        This method acts as an alternative to instantiating schedulers
        that react to metrics with their own `metric` and `mode` parameters.

        Args:
            metric: Metric to optimize.
            mode: One of ["min", "max"]. Direction to optimize. This base
                implementation accepts it but does not store it.
            **spec: Any kwargs for forward compatibility.
                Info like Experiment.PUBLIC_KEYS is provided through here.

        Returns:
            ``False`` if this scheduler already has a metric and a new
            ``metric`` was passed; the existing metric is kept. ``True``
            otherwise. If no metric is known after the call, the anonymous
            ``DEFAULT_METRIC`` is used.
        """
        if self._metric and metric:
            # Do not silently override a metric that was already configured;
            # the False return value signals the conflict to the caller.
            return False
        if metric:
            self._metric = metric
        if self._metric is None:
            # Per default, use anonymous metric
            self._metric = DEFAULT_METRIC
        return True

    def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
        """Called when a new trial is added to the trial runner."""
        raise NotImplementedError

    def on_trial_error(self, tune_controller: "TuneController", trial: Trial):
        """Notification for the error of trial.

        This will only be called when the trial is in the RUNNING state.
        """
        raise NotImplementedError

    def on_trial_result(
        self, tune_controller: "TuneController", trial: Trial, result: Dict
    ) -> str:
        """Called on each intermediate result returned by a trial.

        At this point, the trial scheduler can make a decision by returning
        one of CONTINUE, PAUSE, and STOP. This will only be called when the
        trial is in the RUNNING state.
        """
        raise NotImplementedError

    def on_trial_complete(
        self, tune_controller: "TuneController", trial: Trial, result: Dict
    ):
        """Notification for the completion of trial.

        This will only be called when the trial is in the RUNNING state and
        either completes naturally or by manual termination.
        """
        raise NotImplementedError

    def on_trial_remove(self, tune_controller: "TuneController", trial: Trial):
        """Called to remove trial.

        This is called when the trial is in PAUSED or PENDING state. Otherwise,
        call `on_trial_complete`.
        """
        raise NotImplementedError

    def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
        """Called to choose a new trial to run.

        This should return one of the trials in tune_controller that is in
        the PENDING or PAUSED state. This function must be idempotent.

        If no trial is ready, return None.
        """
        raise NotImplementedError

    def debug_string(self) -> str:
        """Returns a human readable message for printing to the console."""
        raise NotImplementedError

    def save(self, checkpoint_path: str):
        """Save trial scheduler to a checkpoint."""
        raise NotImplementedError

    def restore(self, checkpoint_path: str):
        """Restore trial scheduler from checkpoint."""
        raise NotImplementedError
 
@PublicAPI
class FIFOScheduler(TrialScheduler):
    """Simple scheduler that just runs trials in submission order.

    Every reported result is answered with ``CONTINUE``, so this scheduler
    never pauses or stops a trial early. When picking the next trial to run,
    PENDING trials are preferred over PAUSED ones. Within each status, the
    order of ``tune_controller.get_trials()`` decides; presumably that is
    submission order.
    """

    def __init__(self):
        super().__init__()

    def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
        """No-op: FIFO scheduling keeps no per-trial state."""
        pass

    def on_trial_error(self, tune_controller: "TuneController", trial: Trial):
        """No-op: errors do not affect FIFO ordering."""
        pass

    def on_trial_result(
        self, tune_controller: "TuneController", trial: Trial, result: Dict
    ) -> str:
        """Always let the trial keep running."""
        return TrialScheduler.CONTINUE

    def on_trial_complete(
        self, tune_controller: "TuneController", trial: Trial, result: Dict
    ):
        """No-op: completion does not affect FIFO ordering."""
        pass

    def on_trial_remove(self, tune_controller: "TuneController", trial: Trial):
        """No-op: removal does not affect FIFO ordering."""
        pass

    def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
        """Return the first PENDING trial, else the first PAUSED one, else None.

        The trial list is scanned once per status, in that priority order.
        """
        for wanted_status in (Trial.PENDING, Trial.PAUSED):
            for trial in tune_controller.get_trials():
                if trial.status == wanted_status:
                    return trial
        return None

    def debug_string(self) -> str:
        """Describe the scheduling algorithm for console output."""
        return "Using FIFO scheduling algorithm."