ray.tune.schedulers.pb2.PB2
- class ray.tune.schedulers.pb2.PB2(time_attr: str = 'time_total_s', metric: str | None = None, mode: str | None = None, perturbation_interval: float = 60.0, hyperparam_bounds: Dict[str, dict | list | tuple] = None, quantile_fraction: float = 0.25, log_config: bool = True, require_attrs: bool = True, synch: bool = False, custom_explore_fn: Callable[[dict], dict] | None = None)
Bases: PopulationBasedTraining
Implements the Population Based Bandit (PB2) algorithm.
PB2 trains a group of models (or agents) in parallel. Periodically, poorly performing models clone the state of the top performers, and the hyperparameters are re-selected using GP-bandit optimization. The GP model is trained to predict the improvement in the next training period.
Like PBT, PB2 adapts hyperparameters during training time. This enables very fast hyperparameter discovery and also automatically discovers schedules.
This Tune PB2 implementation is built on top of Tune's PBT implementation. It considers all trials added as part of the PB2 population. If the number of trials exceeds the cluster capacity, they will be time-multiplexed so as to balance training progress across the population. To run multiple trials, use tune.TuneConfig(num_samples=<int>).

In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in pb2_global.txt and individual policy perturbations are recorded in pb2_policy_{i}.txt. On each perturbation step, Tune logs: [target trial tag, clone trial tag, target trial iteration, clone trial iteration, old config, new config].
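As an illustration, the perturbation history could be replayed from the global log. This is a minimal sketch, assuming each line of pb2_global.txt is a JSON-encoded list of the fields above (as in Tune's PBT logging); the experiment path is illustrative:

    import json
    import os

    # Hedged sketch: assumes each line of pb2_global.txt is a JSON-encoded list
    # [target tag, clone tag, target iter, clone iter, old config, new config].
    # The experiment directory below is illustrative.
    log_path = os.path.expanduser("~/ray_results/my_experiment/pb2_global.txt")
    with open(log_path) as f:
        for line in f:
            target_tag, clone_tag, target_it, clone_it, old_conf, new_conf = json.loads(line)
            print(f"{clone_tag} (iter {clone_it}) cloned {target_tag} "
                  f"(iter {target_it}): {old_conf} -> {new_conf}")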
Parameters:
- time_attr – The training result attr to use for comparing time. Note that you can pass in something non-temporal such as training_iteration as a measure of progress; the only requirement is that the attribute should increase monotonically.
- metric – The training result objective value attribute. Stopping procedures will use this attribute.
- mode – One of {min, max}. Determines whether objective is minimizing or maximizing the metric attribute.
- perturbation_interval – Models will be considered for perturbation at this interval of time_attr. Note that perturbation incurs checkpoint overhead, so you shouldn't set this to be too frequent.
- hyperparam_bounds – Hyperparameters to mutate. The format is as follows: for each key, enter a list of the form [min, max] representing the minimum and maximum possible hyperparam values. A key can also hold a dict for nested hyperparameters. Tune will sample uniformly between the bounds provided by hyperparam_bounds for the initial hyperparameter values if the corresponding hyperparameters are not present in a trial's initial config.
- quantile_fraction – Parameters are transferred from the top quantile_fraction fraction of trials to the bottom quantile_fraction fraction. Needs to be between 0 and 0.5. Setting it to 0 essentially implies doing no exploitation at all.
- custom_explore_fn – You can also specify a custom exploration function. This function is invoked as f(config), where the input is the new config generated by Bayesian Optimization. This function should return the config updated as needed; see the sketch after this list.
- log_config – Whether to log the ray config of each model to local_dir at each exploit. Allows config schedule to be reconstructed.
- require_attrs – Whether to require time_attr and metric to appear in the result for every iteration. If True, an error will be raised if these values are not present in the trial result.
- synch – If False, will use the asynchronous implementation of PBT. Trial perturbations occur every perturbation_interval for each trial independently. If True, will use the synchronous implementation of PBT. Perturbations will occur only after all trials are synced at the same time_attr every perturbation_interval. Defaults to False. See Appendix A.1 here: https://arxiv.org/pdf/1711.09846.pdf.
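A minimal sketch of a custom_explore_fn, which post-processes the config suggested by the GP-bandit step. The "lr" and "batch_size" keys are hypothetical and must match your own hyperparam_bounds:

    from ray.tune.schedulers.pb2 import PB2

    # Hedged sketch: clamp and round the continuous suggestions produced by
    # PB2. Key names below are hypothetical illustrations.
    def my_explore_fn(config):
        # Round a continuous suggestion to an integer-valued hyperparameter.
        if "batch_size" in config:
            config["batch_size"] = int(round(config["batch_size"]))
        # Clamp the learning rate into a safe range.
        if "lr" in config:
            config["lr"] = min(max(config["lr"], 1e-5), 1e-1)
        return config

    pb2 = PB2(
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=20,
        hyperparam_bounds={"lr": [0.0001, 0.1], "batch_size": [8, 128]},
        custom_explore_fn=my_explore_fn,
    )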
Example
    from ray import tune
    from ray.tune.schedulers.pb2 import PB2
    from ray.tune.examples.pbt_function import pbt_function

    # run "pip install gpy" to use PB2
    pb2 = PB2(
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=20,
        hyperparam_bounds={"lr": [0.0001, 0.1]},
    )
    tuner = tune.Tuner(
        pbt_function,
        tune_config=tune.TuneConfig(
            scheduler=pb2,
            num_samples=8,
        ),
        param_space={"lr": 0.0001},
    )
    tuner.fit()
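Since hyperparam_bounds also accepts a dict for nested hyperparameters, a nested search space could be expressed as sketched below; the "optimizer" grouping and key names are hypothetical, not part of the PB2 API:

    from ray.tune.schedulers.pb2 import PB2

    # Hedged sketch of the nested-bounds format; keys are illustrative.
    pb2 = PB2(
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=20,
        hyperparam_bounds={
            "optimizer": {
                "lr": [0.0001, 0.1],
                "momentum": [0.8, 0.99],
            },
        },
    )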
Methods
- choose_trial_to_run – Ensures all trials get fair share of time (as defined by time_attr).
- restore – Restore trial scheduler from checkpoint.
- save – Save trial scheduler to a checkpoint.
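A minimal sketch of scheduler checkpointing, assuming save and restore take a file path as in other Tune trial schedulers; the path shown is illustrative:

    # Hedged sketch: persist and later restore the scheduler state.
    # The checkpoint path below is illustrative.
    pb2.save("/tmp/pb2_scheduler.ckpt")
    pb2.restore("/tmp/pb2_scheduler.ckpt")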
Attributes
- CONTINUE – Status for continuing trial execution
- PAUSE – Status for pausing trial execution
- STOP – Status for stopping trial execution