PBT Function Example
The following script demonstrates PBT on a toy trainable. With a population of 8 trials, the learning rate schedule that PBT discovers roughly matches the optimal learning rate schedule for this problem.
#!/usr/bin/env python
import argparse
import json
import os
import random
import tempfile
import numpy as np
import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import PopulationBasedTraining

def pbt_function(config):
    """Toy PBT problem for benchmarking adaptive learning rate.

    The goal is to optimize this trainable's accuracy. The accuracy increases
    fastest at the optimal lr, which is a function of the current accuracy.

    The optimal lr schedule for this problem is the triangle wave as follows.
    Note that many lr schedules for real models also follow this shape:

     best lr
      ^
      |    /\
      |   /  \
      |  /    \
      | /      \
      ------------> accuracy

    In this problem, using PBT with a population of 2-4 is sufficient to
    roughly approximate this lr schedule. Higher population sizes will yield
    faster convergence. Training will not converge without PBT.
    """
    lr = config["lr"]
    checkpoint_interval = config.get("checkpoint_interval", 1)

    accuracy = 0.0  # end = 1000
    # NOTE: See below why step is initialized to 1
    step = 1

    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            with open(os.path.join(checkpoint_dir, "checkpoint.json"), "r") as f:
                checkpoint_dict = json.load(f)

        accuracy = checkpoint_dict["acc"]
        last_step = checkpoint_dict["step"]
        # Current step should be 1 more than the last checkpoint step
        step = last_step + 1

    # triangle wave:
    #  - start at 0.001 @ t=0,
    #  - peak at 0.01 @ t=midpoint,
    #  - end at 0.001 @ t=midpoint * 2,
    midpoint = 100  # lr starts decreasing after acc > midpoint
    q_tolerance = 3  # penalize exceeding lr by more than this multiple
    noise_level = 2  # add gaussian noise to the acc increase

    # Let `stop={"done": True}` in the configs below handle trial stopping
    while True:
        if accuracy < midpoint:
            optimal_lr = 0.01 * accuracy / midpoint
        else:
            optimal_lr = 0.01 - 0.01 * (accuracy - midpoint) / midpoint
        optimal_lr = min(0.01, max(0.001, optimal_lr))

        # compute accuracy increase
        q_err = max(lr, optimal_lr) / min(lr, optimal_lr)
        if q_err < q_tolerance:
            accuracy += (1.0 / q_err) * random.random()
        elif lr > optimal_lr:
            accuracy -= (q_err - q_tolerance) * random.random()
        accuracy += noise_level * np.random.normal()
        accuracy = max(0, accuracy)

        metrics = {
            "mean_accuracy": accuracy,
            "cur_lr": lr,
            "optimal_lr": optimal_lr,  # for debugging
            "q_err": q_err,  # for debugging
            "done": accuracy > midpoint * 2,  # this stops the training process
        }

        if step % checkpoint_interval == 0:
            # Checkpoint every `checkpoint_interval` steps
            # NOTE: if we initialized `step=0` above, our checkpointing and perturbing
            # would be out of sync by 1 step.
            # Ex: if `checkpoint_interval` = `perturbation_interval` = 3
            #   step:               0 (checkpoint)  1  2  3 (checkpoint)
            #   training_iteration:                1  2  3 (perturb)    4
            with tempfile.TemporaryDirectory() as tempdir:
                with open(os.path.join(tempdir, "checkpoint.json"), "w") as f:
                    checkpoint_dict = {"acc": accuracy, "step": step}
                    json.dump(checkpoint_dict, f)

                train.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
        else:
            train.report(metrics)

        step += 1

def run_tune_pbt(smoke_test=False):
    perturbation_interval = 5
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=perturbation_interval,
        hyperparam_mutations={
            # distribution for resampling
            "lr": tune.uniform(0.0001, 0.02),
            # allow perturbations within this set of categorical values
            "some_other_factor": [1, 2],
        },
    )

    tuner = tune.Tuner(
        pbt_function,
        run_config=train.RunConfig(
            name="pbt_function_api_example",
            verbose=False,
            stop={
                # Stop when done = True or at some # of train steps
                # (whichever comes first)
                "done": True,
                "training_iteration": 10 if smoke_test else 1000,
            },
            failure_config=train.FailureConfig(
                fail_fast=True,
            ),
            checkpoint_config=train.CheckpointConfig(
                checkpoint_score_attribute="mean_accuracy",
                num_to_keep=2,
            ),
        ),
        tune_config=tune.TuneConfig(
            scheduler=pbt,
            metric="mean_accuracy",
            mode="max",
            num_samples=8,
            reuse_actors=True,
        ),
        param_space={
            "lr": 0.0001,
            # Note: `some_other_factor` is perturbed because it is specified under
            # the PBT scheduler's `hyperparam_mutations` argument, but has no effect
            # on the model training in this example.
            "some_other_factor": 1,
            # Note: `checkpoint_interval` will not be perturbed (since it's not
            # included above), and it will be used to determine how many steps to
            # take between each checkpoint.
            # We recommend matching `perturbation_interval` and `checkpoint_interval`
            # (e.g. checkpoint every 4 steps, and perturb on those same steps)
            # or making `perturbation_interval` a multiple of `checkpoint_interval`
            # (e.g. checkpoint every 2 steps, and perturb every 4 steps).
            # This is to ensure that the latest checkpoints are being used by PBT
            # when trials decide to exploit. If checkpointing and perturbing are not
            # aligned, then PBT may use a stale checkpoint to resume from.
            "checkpoint_interval": perturbation_interval,
        },
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing",
    )
    args, _ = parser.parse_known_args()

    if args.smoke_test:
        ray.init(num_cpus=2)  # force pausing to happen for test

    run_tune_pbt(smoke_test=args.smoke_test)
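
To try the example end to end without waiting for the full 1000 iterations, you can run it in smoke-test mode. The snippet below is a minimal sketch that assumes the script above is saved as `pbt_function.py` (an illustrative filename, not part of the original example); it starts a small local Ray instance and runs the shortened 10-iteration configuration.

```python
# Minimal sketch: run the example above in smoke-test mode.
# Assumes the script is saved as `pbt_function.py` (illustrative filename).
import ray

from pbt_function import run_tune_pbt

ray.init(num_cpus=2)           # small local setup; forces trials to pause and resume
run_tune_pbt(smoke_test=True)  # only 10 training iterations per trial
```

Equivalently, run `python pbt_function.py --smoke-test` from the command line; dropping the flag runs the full version, whose PBT learning rate schedule should roughly track the optimal schedule as described above.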