Trial Schedulers (tune.schedulers)

In Tune, some hyperparameter optimization algorithms are written as “scheduling algorithms”. These Trial Schedulers can early terminate bad trials, pause trials, clone trials, and alter hyperparameters of a running trial.

All Trial Schedulers take in a metric, which is a value returned in the result dict of your Trainable and is maximized or minimized according to mode.

tune.run( ... , scheduler=Scheduler(metric="accuracy", mode="max"))
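
For example, a minimal function Trainable that reports an "accuracy" value for the scheduler above to act on (a sketch; the training logic here is a stand-in):

import random
from ray import tune

def train_fn(config):
    for step in range(10):
        accuracy = random.random()  # stand-in for a real evaluation
        # The reported key must match the scheduler's `metric`; with
        # mode="max", higher reported values are considered better.
        tune.report(accuracy=accuracy)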

Summary

Tune includes distributed implementations of early stopping algorithms such as Median Stopping Rule, HyperBand, and ASHA. Tune also includes a distributed implementation of Population Based Training (PBT).

Tip

The easiest scheduler to start with is the ASHAScheduler which will aggressively terminate low-performing trials.

When using schedulers, you may face compatibility issues, as shown in the compatibility matrix below. Certain schedulers cannot be used with Search Algorithms, and certain schedulers require checkpointing to be implemented.

TrialScheduler Feature Compatibility Matrix

Scheduler                  | Needs Checkpointing? | SearchAlg Compatible? | Example
---------------------------+----------------------+-----------------------+--------
ASHA                       | No                   | Yes                   | Link
Median Stopping Rule       | No                   | Yes                   | Link
HyperBand                  | Yes                  | Yes                   | Link
BOHB                       | Yes                  | Only TuneBOHB         | Link
Population Based Training  | Yes                  | Not Compatible        | Link

ASHA (tune.schedulers.ASHAScheduler)

The ASHA scheduler can be used by setting the scheduler parameter of tune.run, e.g.

from ray import tune
from ray.tune.schedulers import ASHAScheduler

asha_scheduler = ASHAScheduler(
    time_attr='training_iteration',
    metric='episode_reward_mean',
    mode='max',
    max_t=100,
    grace_period=10,
    reduction_factor=3,
    brackets=1)
tune.run( ... , scheduler=asha_scheduler)

Compared to the original version of HyperBand, this implementation provides better parallelism and avoids straggler issues during eliminations. We recommend using this over the standard HyperBand scheduler. An example of this can be found here: async_hyperband_example.

Even though the original paper mentions a bracket count of 3, discussions with the authors concluded that the value should be left at 1 bracket. This is the default used if no value is provided for the brackets argument.

class ray.tune.schedulers.AsyncHyperBandScheduler(time_attr: str = 'training_iteration', reward_attr: Optional[str] = None, metric: Optional[str] = None, mode: Optional[str] = None, max_t: int = 100, grace_period: int = 1, reduction_factor: float = 4, brackets: int = 1)[source]

Implements the Async Successive Halving (ASHA) algorithm.

This should provide similar theoretical performance to HyperBand but avoid the straggler issues that HyperBand faces. One implementation detail: when using multiple brackets, trials are allocated to brackets randomly, weighted by a softmax probability.

See https://arxiv.org/abs/1810.05934

Parameters
  • time_attr (str) – A training result attr to use for comparing time. Note that you can pass in something non-temporal such as training_iteration as a measure of progress, the only requirement is that the attribute should increase monotonically.

  • metric (str) – The training result objective value attribute. Stopping procedures will use this attribute.

  • mode (str) – One of {min, max}. Determines whether objective is minimizing or maximizing the metric attribute.

  • max_t (float) – max time units per trial. Trials will be stopped after max_t time units (determined by time_attr) have passed.

  • grace_period (float) – Only stop trials at least this old in time. The units are the same as the attribute named by time_attr.

  • reduction_factor (float) – Used to set halving rate and amount. This is simply a unit-less scalar.

  • brackets (int) – Number of brackets. Each bracket has a different halving rate, specified by the reduction factor.
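
To make the interaction of grace_period, reduction_factor, and max_t concrete, here is a standalone sketch (not Tune internals) of the rung milestones ASHA uses for the example configuration above; at each rung, roughly the top 1/reduction_factor of trials are promoted:

def asha_rungs(grace_period=10, reduction_factor=3, max_t=100):
    # Promotion decisions happen at grace_period * reduction_factor**k,
    # for all milestones below max_t.
    rungs, t = [], grace_period
    while t < max_t:
        rungs.append(t)
        t *= reduction_factor
    return rungs

print(asha_rungs())  # [10, 30, 90]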

ray.tune.schedulers.ASHAScheduler

alias of ray.tune.schedulers.async_hyperband.AsyncHyperBandScheduler

HyperBand (tune.schedulers.HyperBandScheduler)

Tune implements the standard version of HyperBand. We recommend using the ASHA Scheduler over the standard HyperBand scheduler.

class ray.tune.schedulers.HyperBandScheduler(time_attr: str = 'training_iteration', reward_attr: Optional[str] = None, metric: Optional[str] = None, mode: Optional[str] = None, max_t: int = 81, reduction_factor: float = 3)[source]

Implements the HyperBand early stopping algorithm.

HyperBandScheduler early stops trials using the HyperBand optimization algorithm. It divides trials into brackets of varying sizes, and periodically early stops low-performing trials within each bracket.

To use this implementation of HyperBand with Tune, all you need to do is specify the max length of time a trial can run (max_t), the time units (time_attr), the name of the reported objective value (metric), and whether the metric is to be maximized or minimized (mode). We automatically determine reasonable values for the other HyperBand parameters based on the given values.

For example, to limit trials to 10 minutes of total training time and early stop based on the episode_reward_mean attr, construct:

HyperBandScheduler(
    time_attr='time_total_s',
    metric='episode_reward_mean',
    mode='max',
    max_t=600)

Note that Tune’s stopping criteria will be applied in conjunction with HyperBand’s early stopping mechanisms.

See also: https://people.eecs.berkeley.edu/~kjamieson/hyperband.html

Parameters
  • time_attr (str) – The training result attr to use for comparing time. Note that you can pass in something non-temporal such as training_iteration as a measure of progress, the only requirement is that the attribute should increase monotonically.

  • metric (str) – The training result objective value attribute. Stopping procedures will use this attribute.

  • mode (str) – One of {min, max}. Determines whether objective is minimizing or maximizing the metric attribute.

  • max_t (int) – max time units per trial. Trials will be stopped after max_t time units (determined by time_attr) have passed. Note that this differs from the semantics of max_t in the original HyperBand paper.

  • reduction_factor (float) – Same as eta. Determines how sharp the difference is between bracket space-time allocation ratios.

HyperBand Implementation Details

Implementation details may deviate slightly from theory but are focused on increasing usability. Note: R, s_max, and eta are parameters of HyperBand given by the paper. See this post for context.

  1. Both s_max (representing the number of brackets minus 1) and eta (representing the downsampling rate) are fixed. In many practical settings, R, which represents some resource unit and often the number of training iterations, can be set reasonably large, like R >= 200. For simplicity, assume eta = 3. Varying R between R = 200 and R = 1000 creates a huge range in the number of trials needed to fill up all brackets.

[Figure hyperband_bracket.png: number of trials needed to fill all brackets as R varies]

On the other hand, holding R constant at R = 300 and varying eta also leads to HyperBand configurations that are not very intuitive:

[Figure hyperband_eta.png: HyperBand configurations as eta varies with R = 300]

The implementation takes the same configuration as the example given in the paper and exposes max_t, which is not a parameter in the paper.

  2. The example in the post for calculating n_0 is actually a little different from the algorithm given in the paper. In this implementation, we implement n_0 according to the paper (which is n in the example below):

[Figure hyperband_allocation.png: per-bracket trial counts n and resource allocations r from the paper]
  3. There are also implementation-specific details, such as how trials are placed into brackets, which are not covered in the paper. This implementation fills the smallest brackets first, meaning that with a low number of trials, there will be less early stopping. A worked illustration of the bracket arithmetic appears below.
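
As a worked illustration of the bracket arithmetic, here is how the per-bracket initial trial count n and initial resource r fall out of R and eta according to Algorithm 1 in the paper (this reproduces the paper's formulas, not Tune internals):

import math

R, eta = 81, 3
s_max = round(math.log(R, eta))  # 4, i.e. number of brackets - 1
B = (s_max + 1) * R              # total budget per HyperBand iteration
for s in range(s_max, -1, -1):
    n = math.ceil(B / R * eta ** s / (s + 1))  # initial trials in bracket s
    r = R * eta ** (-s)                        # initial resource per trial
    print(f"bracket s={s}: n={n}, r={r:g}")
# bracket s=4: n=81, r=1
# bracket s=3: n=34, r=3
# bracket s=2: n=15, r=9
# bracket s=1: n=8, r=27
# bracket s=0: n=5, r=81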

Median Stopping Rule (tune.schedulers.MedianStoppingRule)

The Median Stopping Rule implements the simple strategy of stopping a trial if its performance falls below the median of other trials at similar points in time.
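
A minimal usage sketch (the metric name is illustrative):

from ray import tune
from ray.tune.schedulers import MedianStoppingRule

median_scheduler = MedianStoppingRule(
    time_attr='time_total_s',
    metric='mean_accuracy',
    mode='max',
    grace_period=60.0)
tune.run( ... , scheduler=median_scheduler)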

class ray.tune.schedulers.MedianStoppingRule(time_attr: str = 'time_total_s', reward_attr: Optional[str] = None, metric: Optional[str] = None, mode: Optional[str] = None, grace_period: float = 60.0, min_samples_required: int = 3, min_time_slice: int = 0, hard_stop: bool = True)[source]

Implements the median stopping rule as described in the Vizier paper:

https://research.google.com/pubs/pub46180.html

Parameters
  • time_attr (str) – The training result attr to use for comparing time. Note that you can pass in something non-temporal such as training_iteration as a measure of progress, the only requirement is that the attribute should increase monotonically.

  • metric (str) – The training result objective value attribute. Stopping procedures will use this attribute.

  • mode (str) – One of {min, max}. Determines whether objective is minimizing or maximizing the metric attribute.

  • grace_period (float) – Only stop trials at least this old in time. The mean will only be computed from this time onwards. The units are the same as the attribute named by time_attr.

  • min_samples_required (int) – Minimum number of trials to compute median over.

  • min_time_slice (float) – Each trial runs at least this long before yielding (assuming it isn’t stopped). Note: trials ONLY yield if there are not enough samples to evaluate performance for the current result AND there are other trials waiting to run. The units are the same as the attribute named by time_attr.

  • hard_stop (bool) – If False, pauses trials instead of stopping them. When all other trials are complete, paused trials will be resumed and allowed to run FIFO.

Population Based Training (tune.schedulers.PopulationBasedTraining)

Tune includes a distributed implementation of Population Based Training (PBT). This can be enabled by setting the scheduler parameter of tune.run, e.g.

import random

from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

pbt_scheduler = PopulationBasedTraining(
    time_attr='time_total_s',
    metric='mean_accuracy',
    mode='max',
    perturbation_interval=600.0,
    hyperparam_mutations={
        "lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
        "alpha": lambda: random.uniform(0.0, 1.0),
        # ...
    })
tune.run( ... , scheduler=pbt_scheduler)

When the PBT scheduler is enabled, each trial variant is treated as a member of the population. Periodically, top-performing trials are checkpointed (this requires your Trainable to support save and restore). Low-performing trials clone the checkpoints of top performers and perturb the configurations in the hope of discovering an even better variation.
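
Supporting save and restore means implementing the Trainable checkpoint hooks. A minimal sketch (the "weights" attribute is a stand-in for real model state):

import os
import pickle

from ray import tune

class MyTrainable(tune.Trainable):
    def setup(self, config):
        self.lr = config["lr"]
        self.weights = 0.0  # stand-in for real model state

    def step(self):
        self.weights += self.lr  # stand-in for one training epoch
        return {"mean_accuracy": self.weights}

    def save_checkpoint(self, tmp_checkpoint_dir):
        # PBT clones top performers by copying these checkpoints.
        with open(os.path.join(tmp_checkpoint_dir, "state.pkl"), "wb") as f:
            pickle.dump(self.weights, f)
        return tmp_checkpoint_dir

    def load_checkpoint(self, checkpoint_dir):
        with open(os.path.join(checkpoint_dir, "state.pkl"), "rb") as f:
            self.weights = pickle.load(f)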

You can run this toy PBT example to get an idea of how PBT operates. When training in PBT mode, a single trial may see many different hyperparameters over its lifetime, which is recorded in its result.json file. The following figure, generated by the example, shows PBT optimizing an LR schedule over the course of a single experiment:

[Figure pbt.png: learning rate schedule discovered by PBT over one experiment]
class ray.tune.schedulers.PopulationBasedTraining(time_attr: str = 'time_total_s', reward_attr: Optional[str] = None, metric: Optional[str] = None, mode: Optional[str] = None, perturbation_interval: float = 60.0, hyperparam_mutations: Dict = None, quantile_fraction: float = 0.25, resample_probability: float = 0.25, custom_explore_fn: Optional[Callable] = None, log_config: bool = True, require_attrs: bool = True, synch: bool = False)[source]

Implements the Population Based Training (PBT) algorithm.

https://deepmind.com/blog/population-based-training-neural-networks

PBT trains a group of models (or agents) in parallel. Periodically, poorly performing models clone the state of the top performers, and a random mutation is applied to their hyperparameters in the hopes of outperforming the current top models.

Unlike other hyperparameter search algorithms, PBT mutates hyperparameters during training time. This enables very fast hyperparameter discovery and also automatically discovers good annealing schedules.

This Tune PBT implementation considers all trials added as part of the PBT population. If the number of trials exceeds the cluster capacity, they will be time-multiplexed so as to balance training progress across the population. To run multiple trials, use tune.run(num_samples=<int>).

In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in pbt_global.txt and individual policy perturbations are recorded in pbt_policy_{i}.txt. Tune logs: [target trial tag, clone trial tag, target trial iteration, clone trial iteration, old config, new config] on each perturbation step.

Parameters
  • time_attr (str) – The training result attr to use for comparing time. Note that you can pass in something non-temporal such as training_iteration as a measure of progress, the only requirement is that the attribute should increase monotonically.

  • metric (str) – The training result objective value attribute. Stopping procedures will use this attribute.

  • mode (str) – One of {min, max}. Determines whether objective is minimizing or maximizing the metric attribute.

  • perturbation_interval (float) – Models will be considered for perturbation at this interval of time_attr. Note that perturbation incurs checkpoint overhead, so you shouldn’t set this to be too frequent.

  • hyperparam_mutations (dict) – Hyperparams to mutate. The format is as follows: for each key, either a list, function, or a tune search space object (tune.loguniform, tune.uniform, etc.) can be provided. A list specifies an allowed set of categorical values. A function or tune search space object specifies the distribution of a continuous parameter. You must use tune.choice, tune.uniform, tune.loguniform, etc.; arbitrary tune.sample_from objects are not supported. You must specify at least one of hyperparam_mutations or custom_explore_fn. Tune will use the search space provided by hyperparam_mutations for the initial samples if the corresponding attributes are not present in config.

  • quantile_fraction (float) – Parameters are transferred from the top quantile_fraction fraction of trials to the bottom quantile_fraction fraction. Needs to be between 0 and 0.5. Setting it to 0 essentially implies doing no exploitation at all.

  • resample_probability (float) – The probability of resampling from the original distribution when applying hyperparam_mutations. If not resampled, the value will be perturbed by a factor of 1.2 or 0.8 if continuous, or changed to an adjacent value if discrete.

  • custom_explore_fn (func) – You can also specify a custom exploration function. This function is invoked as f(config) after built-in perturbations from hyperparam_mutations are applied, and should return config updated as needed. You must specify at least one of hyperparam_mutations or custom_explore_fn.

  • log_config (bool) – Whether to log the ray config of each model to local_dir at each exploit. Allows config schedule to be reconstructed.

  • require_attrs (bool) – Whether to require time_attr and metric to appear in result for every iteration. If True, error will be raised if these values are not present in trial result.

  • synch (bool) – If False, will use asynchronous implementation of PBT. Trial perturbations occur every perturbation_interval for each trial independently. If True, will use synchronous implementation of PBT. Perturbations will occur only after all trials are synced at the same time_attr every perturbation_interval. Defaults to False. See Appendix A.1 here https://arxiv.org/pdf/1711.09846.pdf.

Example:

import random
from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="episode_reward_mean",
    mode="max",
    perturbation_interval=10,  # every 10 `time_attr` units
                               # (training_iterations in this case)
    hyperparam_mutations={
        # Perturb factor1 by scaling it by 0.8 or 1.2. Resampling
        # resets it to a value sampled from the lambda function.
        "factor_1": lambda: random.uniform(0.0, 20.0),
        # Alternatively, use tune search space primitives.
        # The search space for factor_1 is equivalent to factor_2.
        "factor_2": tune.uniform(0.0, 20.0),
        # Perturb factor3 by changing it to an adjacent value, e.g.
        # 10 -> 1 or 10 -> 100. Resampling will choose at random.
        "factor_3": [1, 10, 100, 1000, 10000],
        # Using tune.choice is NOT equivalent to the above.
        # factor_4 is treated as a continuous hyperparameter.
        "factor_4": tune.choice([1, 10, 100, 1000, 10000]),
    })
tune.run({...}, num_samples=8, scheduler=pbt)

Population Based Training Replay (tune.schedulers.PopulationBasedTrainingReplay)

Tune includes a utility to replay hyperparameter schedules of Population Based Training runs. You just specify an existing experiment directory and the ID of the trial you would like to replay. The scheduler accepts only one trial, and it will update its config according to the obtained schedule.

replay = PopulationBasedTrainingReplay(
    experiment_dir="~/ray_results/pbt_experiment/",
    trial_id="XXXXX_00001")
tune.run(
    ...,
    scheduler=replay)

See here for an example of how to use the replay utility in practice.

class ray.tune.schedulers.PopulationBasedTrainingReplay(policy_file: str)[source]

Replays a Population Based Training run.

Population Based Training does not return a single hyperparameter configuration, but rather a schedule of configurations. For instance, PBT might discover that a larger learning rate leads to good results in the first training iterations, but that a smaller learning rate is preferable later.

This scheduler enables replaying these parameter schedules from a finished PBT run. This requires that population based training has been run with log_config=True, which is the default setting.

The scheduler will only accept and train a single trial. It will start with the initial config of the existing trial and update the config according to the schedule.

Parameters

policy_file (str) – The PBT policy file. Usually this is stored in ~/ray_results/experiment_name/pbt_policy_xxx.txt where xxx is the trial ID.

Example:

# Replaying a result from ray.tune.examples.pbt_convnet_example
from ray import tune

from ray.tune.examples.pbt_convnet_example import PytorchTrainable
from ray.tune.schedulers import PopulationBasedTrainingReplay

replay = PopulationBasedTrainingReplay(
    "~/ray_results/pbt_test/pbt_policy_XXXXX_00001.txt")

tune.run(
    PytorchTrainable,
    scheduler=replay,
    stop={"training_iteration": 100})

BOHB (tune.schedulers.HyperBandForBOHB)

This class is a variant of HyperBand that enables the BOHB Algorithm. This implementation is true to the original HyperBand implementation and does not implement pipelining or straggler mitigation.

This is to be used in conjunction with the Tune BOHB search algorithm. See TuneBOHB for package requirements, examples, and details.

An example of this in use can be found here: bohb_example.
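
A minimal sketch of pairing the two. Here MyTrainable is a placeholder for a Trainable that implements checkpointing, which this scheduler needs in order to pause and resume trials; TuneBOHB additionally requires the ConfigSpace and hpbandster packages:

from ray import tune
from ray.tune.schedulers import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB

# Search algorithm and scheduler must agree on metric and mode.
algo = TuneBOHB(metric="mean_loss", mode="min")
bohb = HyperBandForBOHB(
    time_attr="training_iteration",
    metric="mean_loss",
    mode="min",
    max_t=100)
tune.run(
    MyTrainable,
    search_alg=algo,
    scheduler=bohb,
    num_samples=10)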

class ray.tune.schedulers.HyperBandForBOHB(time_attr: str = 'training_iteration', reward_attr: Optional[str] = None, metric: Optional[str] = None, mode: Optional[str] = None, max_t: int = 81, reduction_factor: float = 3)[source]

Extends HyperBand early stopping algorithm for BOHB.

This implementation removes the HyperBandScheduler pipelining. This class introduces two key changes:

1. Trials are now placed so that the bracket with the largest size is filled first.

2. Trials will be paused even if the bracket is not filled. This allows BOHB to insert new trials into the training.

See ray.tune.schedulers.HyperBandScheduler for parameter docstring.

FIFOScheduler

class ray.tune.schedulers.FIFOScheduler[source]

Simple scheduler that just runs trials in submission order.

TrialScheduler

class ray.tune.schedulers.TrialScheduler[source]

Interface for implementing a Trial Scheduler class.

CONTINUE = 'CONTINUE'

Status for continuing trial execution

PAUSE = 'PAUSE'

Status for pausing trial execution

STOP = 'STOP'

Status for stopping trial execution

set_search_properties(metric: Optional[str], mode: Optional[str]) → bool[source]

Pass search properties to scheduler.

This method acts as an alternative to instantiating schedulers that react to metrics with their own metric and mode parameters.

Parameters
  • metric (str) – Metric to optimize

  • mode (str) – One of [“min”, “max”]. Direction to optimize.

on_trial_add(trial_runner: ray.tune.trial_runner.TrialRunner, trial: ray.tune.trial.Trial)[source]

Called when a new trial is added to the trial runner.

on_trial_error(trial_runner: ray.tune.trial_runner.TrialRunner, trial: ray.tune.trial.Trial)[source]

Notification for the error of a trial.

This will only be called when the trial is in the RUNNING state.

on_trial_result(trial_runner: ray.tune.trial_runner.TrialRunner, trial: ray.tune.trial.Trial, result: Dict) → str[source]

Called on each intermediate result returned by a trial.

At this point, the trial scheduler can make a decision by returning one of CONTINUE, PAUSE, and STOP. This will only be called when the trial is in the RUNNING state.

on_trial_complete(trial_runner: ray.tune.trial_runner.TrialRunner, trial: ray.tune.trial.Trial, result: Dict)[source]

Notification for the completion of a trial.

This will only be called when the trial is in the RUNNING state and either completes naturally or by manual termination.

on_trial_remove(trial_runner: ray.tune.trial_runner.TrialRunner, trial: ray.tune.trial.Trial)[source]

Called to remove a trial.

This is called when the trial is in PAUSED or PENDING state. Otherwise, call on_trial_complete.

choose_trial_to_run(trial_runner: ray.tune.trial_runner.TrialRunner) → Optional[ray.tune.trial.Trial][source]

Called to choose a new trial to run.

This should return one of the trials in trial_runner that is in the PENDING or PAUSED state. This function must be idempotent.

If no trial is ready, return None.

debug_string() → str[source]

Returns a human readable message for printing to the console.
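
To make the interface concrete, here is a minimal custom scheduler sketch. It subclasses FIFOScheduler (which supplies default implementations of the other TrialScheduler methods) and stops any trial whose reported metric falls below a fixed threshold after a grace period; the metric name and threshold are illustrative:

from ray.tune.schedulers import FIFOScheduler, TrialScheduler

class ThresholdStopper(FIFOScheduler):
    def __init__(self, threshold=0.2, grace_iterations=5):
        super().__init__()
        self._threshold = threshold
        self._grace = grace_iterations

    def on_trial_result(self, trial_runner, trial, result):
        # Give every trial a grace period, then stop underperformers.
        if (result["training_iteration"] > self._grace
                and result["mean_accuracy"] < self._threshold):
            return TrialScheduler.STOP
        return TrialScheduler.CONTINUE

    def debug_string(self):
        return "ThresholdStopper(threshold={})".format(self._threshold)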

Shim Instantiation (tune.create_scheduler)

There is also a shim function that constructs the scheduler based on the provided string. This can be useful if the scheduler you want to use changes often (e.g., specifying the scheduler via a CLI option or config file).

tune.create_scheduler(scheduler, **kwargs)

Instantiate a scheduler based on the given string.

This is useful for swapping between different schedulers.

Parameters
  • scheduler (str) – The scheduler to use.

  • **kwargs – Scheduler parameters. These keyword arguments will be passed to the initialization function of the chosen scheduler.

Returns

The scheduler.

Return type

ray.tune.schedulers.trial_scheduler.TrialScheduler

Example

>>> scheduler = tune.create_scheduler('pbt', **pbt_kwargs)