Source code for ray.tune.schedulers
import inspect
from ray._private.utils import get_function_args
from ray.tune.schedulers.async_hyperband import ASHAScheduler, AsyncHyperBandScheduler
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.schedulers.hyperband import HyperBandScheduler
from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
from ray.tune.schedulers.pbt import (
PopulationBasedTraining,
PopulationBasedTrainingReplay,
)
from ray.tune.schedulers.resource_changing_scheduler import ResourceChangingScheduler
from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
from ray.util import PublicAPI
def _pb2_importer():
    """Import the ``PB2`` scheduler class on demand and return it.

    ``PB2`` brings in a GPy dependency that can be expensive to import, so
    the import is deferred until this function is actually called rather
    than happening when this package is imported.

    Returns:
        The ``PB2`` class from ``ray.tune.schedulers.pb2``.
    """
    # Deferred (function-local) import keeps package import cheap.
    from ray.tune.schedulers.pb2 import PB2 as pb2_cls

    return pb2_cls
# Each value in this dictionary is one of two kinds:
#   - the scheduler class itself, which is instantiated directly, or
#   - a plain wrapper function that lazily imports and returns the scheduler
#     class (used for schedulers with expensive dependencies, e.g. PB2).
# Keys must be lower-case: ``create_scheduler`` lower-cases its input before
# looking it up here.
SCHEDULER_IMPORT = {
    "fifo": FIFOScheduler,
    "async_hyperband": AsyncHyperBandScheduler,
    "asynchyperband": AsyncHyperBandScheduler,
    # ``ASHAScheduler`` is imported and exported by this module but was not
    # reachable by name; register it under its common short name.
    "asha": ASHAScheduler,
    "median_stopping_rule": MedianStoppingRule,
    "medianstopping": MedianStoppingRule,
    "hyperband": HyperBandScheduler,
    "hb_bohb": HyperBandForBOHB,
    "pbt": PopulationBasedTraining,
    "pbt_replay": PopulationBasedTrainingReplay,
    "pb2": _pb2_importer,
    "resource_changing": ResourceChangingScheduler,
}
[docs]
@PublicAPI(stability="beta")
def create_scheduler(
    scheduler,
    **kwargs,
):
    """Build a trial scheduler from its registered string name.

    Handy for switching schedulers through configuration: the name is
    lower-cased, looked up in ``SCHEDULER_IMPORT``, and the matching
    scheduler class is instantiated.

    Args:
        scheduler: Registered name of the scheduler, e.g. ``"pbt"`` or
            ``"hyperband"``. Matching is case-insensitive.
        **kwargs: Candidate constructor arguments for the chosen scheduler.
            Only keys contained in ``get_function_args(<scheduler class>)``
            are forwarded; every other key is dropped without any warning.

    Returns:
        ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler.

    Raises:
        ValueError: If the lower-cased ``scheduler`` is not a registered name.

    Example:
        >>> from ray import tune
        >>> pbt_kwargs = {}
        >>> scheduler = tune.create_scheduler('pbt', **pbt_kwargs) # doctest: +SKIP
    """
    name = scheduler.lower()
    if name not in SCHEDULER_IMPORT:
        raise ValueError(
            f"The `scheduler` argument must be one of "
            f"{list(SCHEDULER_IMPORT)}. "
            f"Got: {name}"
        )

    entry = SCHEDULER_IMPORT[name]
    # Lazily-imported schedulers are registered as plain functions that do
    # the import and hand back the class; call them to resolve the class.
    scheduler_cls = entry() if inspect.isfunction(entry) else entry

    # Keep only the kwargs the scheduler's constructor is reported to accept.
    accepted = get_function_args(scheduler_cls)
    init_kwargs = {
        key: value for key, value in kwargs.items() if key in accepted
    }
    return scheduler_cls(**init_kwargs)
# Public names exported by ``from ray.tune.schedulers import *``.
# ``create_scheduler`` is part of the public API (it is decorated with
# ``@PublicAPI``), so it is listed here too. ``PB2`` is not listed because it
# is never imported at module level; it is only reachable lazily through
# ``_pb2_importer``, and listing it would make a star-import fail.
__all__ = [
    "TrialScheduler",
    "HyperBandScheduler",
    "AsyncHyperBandScheduler",
    "ASHAScheduler",
    "MedianStoppingRule",
    "FIFOScheduler",
    "PopulationBasedTraining",
    "PopulationBasedTrainingReplay",
    "HyperBandForBOHB",
    "ResourceChangingScheduler",
    "create_scheduler",
]