ray.tune.schedulers.PopulationBasedTraining

class ray.tune.schedulers.PopulationBasedTraining(time_attr: str = 'time_total_s', metric: str | None = None, mode: str | None = None, perturbation_interval: float = 60.0, burn_in_period: float = 0.0, hyperparam_mutations: Dict[str, dict | list | tuple | Callable | Domain] = None, quantile_fraction: float = 0.25, resample_probability: float = 0.25, perturbation_factors: Tuple[float, float] = (1.2, 0.8), custom_explore_fn: Callable | None = None, log_config: bool = True, require_attrs: bool = True, synch: bool = False)

Bases: FIFOScheduler
Implements the Population Based Training (PBT) algorithm.
https://www.deepmind.com/blog/population-based-training-of-neural-networks
PBT trains a group of models (or agents) in parallel. Periodically, poorly performing models clone the state of the top performers, and a random mutation is applied to their hyperparameters in the hopes of outperforming the current top models.
Unlike other hyperparameter search algorithms, PBT mutates hyperparameters during training time. This enables very fast hyperparameter discovery and also automatically discovers good annealing schedules.
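To make the exploit/explore cycle concrete, here is a minimal, self-contained sketch of one PBT step over an in-memory population. It illustrates the algorithm only and is not Tune's internal implementation; the population representation and the resampling range are assumptions.

import random

def pbt_step(population, quantile_fraction=0.25, resample_probability=0.25,
             perturbation_factors=(1.2, 0.8)):
    # population: list of {"config": {...}, "score": float} dicts (assumed shape).
    population.sort(key=lambda member: member["score"], reverse=True)
    cutoff = max(1, int(len(population) * quantile_fraction))
    top, bottom = population[:cutoff], population[len(population) - cutoff:]
    for member in bottom:
        donor = random.choice(top)
        member["config"] = dict(donor["config"])  # exploit: clone a top performer
        for key in member["config"]:              # explore: mutate each hyperparameter
            if random.random() < resample_probability:
                member["config"][key] = random.uniform(0.0, 20.0)  # resample (example range)
            else:
                member["config"][key] *= random.choice(perturbation_factors)
    return population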
This Tune PBT implementation considers all trials added as part of the PBT population. If the number of trials exceeds the cluster capacity, they will be time-multiplexed so as to balance training progress across the population. To run multiple trials, use tune.TuneConfig(num_samples=<int>).

In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in pbt_global.txt and individual policy perturbations are recorded in pbt_policy_{i}.txt. Tune logs [target trial tag, clone trial tag, target trial iteration, clone trial iteration, old config, new config] on each perturbation step.
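For post-hoc analysis, the perturbation log can be read back. The sketch below assumes each line of pbt_global.txt is a JSON-encoded list in the field order given above; verify the encoding against your Ray version. The experiment path is hypothetical.

import json

log_path = "/path/to/LOG_DIR/MY_EXPERIMENT_NAME/pbt_global.txt"  # hypothetical path
with open(log_path) as f:
    for line in f:
        (target_tag, clone_tag, target_iter,
         clone_iter, old_config, new_config) = json.loads(line)
        # Each record describes one exploit step: `target_tag` cloned the
        # state of `clone_tag` and mutated `old_config` into `new_config`.
        print(f"[iter {target_iter}] {target_tag} <- {clone_tag}: "
              f"{old_config} -> {new_config}")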
Parameters:
- time_attr – The training result attr to use for comparing time. Note that you can pass in something non-temporal such as training_iteration as a measure of progress; the only requirement is that the attribute increases monotonically.
- metric – The training result objective value attribute. Stopping procedures will use this attribute. If None but a mode was passed, ray.tune.result.DEFAULT_METRIC will be used by default.
- mode – One of {min, max}. Determines whether objective is minimizing or maximizing the metric attribute.
- perturbation_interval – Models will be considered for perturbation at this interval of time_attr. Note that perturbation incurs checkpoint overhead, so you shouldn't set this to be too frequent.
- burn_in_period – Models will not be considered for perturbation before this interval of time_attr has passed. This guarantees that models are trained for at least a certain amount of time or timesteps before being perturbed.
- hyperparam_mutations – Hyperparams to mutate. The format is as follows: for each key, either a list, a function, or a tune search space object (tune.loguniform, tune.uniform, etc.) can be provided. A list specifies an allowed set of categorical values. A function or tune search space object specifies the distribution of a continuous parameter. You must use tune.choice, tune.uniform, tune.loguniform, etc.; arbitrary tune.sample_from objects are not supported. A key can also hold a dict for nested hyperparameters. You must specify at least one of hyperparam_mutations or custom_explore_fn. Tune will sample the search space provided by hyperparam_mutations for the initial hyperparameter values if the corresponding hyperparameters are not present in a trial's initial config.
- quantile_fraction – Parameters are transferred from the top quantile_fraction fraction of trials to the bottom quantile_fraction fraction. Needs to be between 0 and 0.5. Setting it to 0 essentially implies doing no exploitation at all.
- resample_probability – The probability of resampling from the original distribution when applying hyperparam_mutations. If not resampled, the value will be perturbed by a factor chosen from perturbation_factors if continuous, or changed to an adjacent value if discrete.
- perturbation_factors – Scaling factors to choose between when mutating a continuous hyperparameter.
- custom_explore_fn – You can also specify a custom exploration function. This function is invoked as f(config) after built-in perturbations from hyperparam_mutations are applied, and should return config updated as needed. You must specify at least one of hyperparam_mutations or custom_explore_fn. (See the sketch after the example below.)
- log_config – Whether to log the ray config of each model to local_dir at each exploit. Allows config schedule to be reconstructed.
- require_attrs – Whether to require time_attr and metric to appear in the result for every iteration. If True, an error will be raised if these values are not present in the trial result.
- synch – If False, will use the asynchronous implementation of PBT. Trial perturbations occur every perturbation_interval for each trial independently. If True, will use the synchronous implementation of PBT: perturbations occur only after all trials are synced at the same time_attr every perturbation_interval. Defaults to False. See Appendix A.1 here: https://arxiv.org/pdf/1711.09846.pdf.
import random

from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="episode_reward_mean",
    mode="max",
    perturbation_interval=10,  # every 10 `time_attr` units
    # (training_iterations in this case)
    hyperparam_mutations={
        # Perturb factor_1 by scaling it by 0.8 or 1.2. Resampling
        # resets it to a value sampled from the lambda function.
        "factor_1": lambda: random.uniform(0.0, 20.0),
        # Alternatively, use tune search space primitives.
        # The search space for factor_1 is equivalent to factor_2.
        "factor_2": tune.uniform(0.0, 20.0),
        # Perturb factor_3 by changing it to an adjacent value, e.g.
        # 10 -> 1 or 10 -> 100. Resampling will choose at random.
        "factor_3": [1, 10, 100, 1000, 10000],
        # Using tune.choice is NOT equivalent to the above.
        # factor_4 is treated as a continuous hyperparameter.
        "factor_4": tune.choice([1, 10, 100, 1000, 10000]),
    },
)

tuner = tune.Tuner(
    trainable,
    tune_config=tune.TuneConfig(
        scheduler=pbt,
        num_samples=8,
    ),
)
tuner.fit()
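As a supplement, here is a hedged sketch of wiring in a custom_explore_fn: it is called as f(config) after the built-in perturbations and can, for example, clamp mutated values back into a valid range. The clamping constraint here is an assumption for illustration.

from ray import tune
from ray.tune.schedulers import PopulationBasedTraining

def explore(config):
    # Hypothetical constraint: keep factor_2 in [0.0, 20.0] even after a
    # 1.2x perturbation pushes it out of range.
    config["factor_2"] = min(max(config["factor_2"], 0.0), 20.0)
    return config

pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="episode_reward_mean",
    mode="max",
    perturbation_interval=10,
    hyperparam_mutations={"factor_2": tune.uniform(0.0, 20.0)},
    custom_explore_fn=explore,
)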
Methods

- choose_trial_to_run – Ensures all trials get a fair share of time (as defined by time_attr).
- restore – Restore trial scheduler from checkpoint.
- save – Save trial scheduler to a checkpoint.
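A brief usage sketch of the checkpointing methods above, assuming they are inherited unchanged from the base scheduler; the path is hypothetical.

# Persist scheduler state, e.g. before stopping the driver, and restore it later.
pbt.save("/tmp/pbt_scheduler.pkl")     # hypothetical checkpoint path
pbt.restore("/tmp/pbt_scheduler.pkl")  # e.g. in a new driver process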
Attributes

- CONTINUE – Status for continuing trial execution.
- PAUSE – Status for pausing trial execution.
- STOP – Status for stopping trial execution.