Guide to Population Based Training (PBT)

Tune includes a distributed implementation of Population Based Training (PBT) as a scheduler.


PBT starts by training many neural networks in parallel with random hyperparameters, using information from the rest of the population to refine these hyperparameters and allocate resources to promising models. Let’s walk through how to use this algorithm.

Function API with Population Based Training

PBT takes its inspiration from genetic algorithms where each member of the population can exploit information from the remainder of the population. For example, a worker might copy the model parameters from a better performing worker. It can also explore new hyperparameters by changing the current values randomly.

As the training of the population of neural networks progresses, this process of exploiting and exploring is performed periodically, ensuring that all the workers in the population have a good base level of performance and also that new hyperparameters are consistently explored.

This means that PBT can quickly exploit good hyperparameters, can dedicate more training time to promising models and, crucially, can adapt the hyperparameter values throughout training, leading to automatic learning of the best configurations.
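To make the exploit and explore steps concrete, here is a minimal sketch in plain Python. It is not Tune's implementation: the worker dictionaries, the `exploit_and_explore` name, and the 25% quantile are assumptions chosen for illustration, and the "weights" are stand-in lists rather than real model parameters.

```python
import copy
import random


def exploit_and_explore(population, rng, quantile=0.25):
    """Run one PBT step over a list of worker dicts (hypothetical structure)."""
    ranked = sorted(population, key=lambda w: w["score"], reverse=True)
    cutoff = max(1, int(len(ranked) * quantile))
    top, bottom = ranked[:cutoff], ranked[-cutoff:]
    for worker in bottom:
        donor = rng.choice(top)
        # Exploit: copy the better worker's weights and hyperparameters.
        worker["weights"] = copy.deepcopy(donor["weights"])
        worker["config"] = dict(donor["config"])
        # Explore: randomly scale each hyperparameter up or down.
        for name in worker["config"]:
            worker["config"][name] *= rng.choice([0.8, 1.2])
    return population


rng = random.Random(0)
population = [
    {"score": s, "weights": [s], "config": {"lr": 0.1}}
    for s in (0.9, 0.5, 0.7, 0.2)
]
exploit_and_explore(population, rng)
```

After the call, the worst worker holds a copy of the best worker's weights and a learning rate scaled by 0.8 or 1.2, while the best worker is left untouched.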

First we define a training function that trains a ConvNet model using SGD.

def train_convnet(config, checkpoint_dir=None):
    # Create our data loaders, model, and optimizer.
    step = 0
    train_loader, test_loader = get_data_loaders()
    model = ConvNet()
    optimizer = optim.SGD(
        model.parameters(),
        lr=config.get("lr", 0.01),
        momentum=config.get("momentum", 0.9))

    # If checkpoint_dir is not None, then we are resuming from a checkpoint.
    # Load model state and iteration step from checkpoint.
    if checkpoint_dir:
        print("Loading from checkpoint.")
        path = os.path.join(checkpoint_dir, "checkpoint")
        checkpoint = torch.load(path)
        model.load_state_dict(checkpoint["model_state_dict"])
        step = checkpoint["step"]

    while True:
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        if step % 5 == 0:
            # Every 5 steps, checkpoint our current state.
            # First get the checkpoint directory from tune.
            with tune.checkpoint_dir(step=step) as checkpoint_dir:
                # Then create a checkpoint file in this directory.
                path = os.path.join(checkpoint_dir, "checkpoint")
                # Save state to checkpoint file.
                # No need to save optimizer for SGD.
                torch.save({
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    "mean_accuracy": acc
                }, path)
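        step += 1
        # Report the metric to Tune so the scheduler and stopper can use it.
        tune.report(mean_accuracy=acc)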

The example reuses some of the functions in ray/tune/examples/, and is also a good demo for how to decouple the tuning logic and original training code.

Here, we also need to take in a checkpoint_dir argument, since checkpointing is required for the exploitation process in PBT. We have to both load the checkpoint if one is provided and periodically save our model state in a checkpoint, in this case every 5 iterations. With plain SGD there is no need to checkpoint the optimizer, since its updates do not depend on previous states (when momentum is used, as here, the momentum buffers are simply reset on restore), but checkpointing the optimizer is necessary with other optimizers like Adam.

Then, we define a PBT scheduler:

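    scheduler = PopulationBasedTraining(
        # ... (further arguments such as time_attr and perturbation_interval omitted here)
        hyperparam_mutations={
            # distribution for resampling
            "lr": lambda: np.random.uniform(0.0001, 1),
            # allow perturbations within this set of categorical values
            "momentum": [0.8, 0.9, 0.99],
        })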

Some of the most important parameters are:

  • hyperparam_mutations and custom_explore_fn are used to mutate the hyperparameters. hyperparam_mutations is a dictionary where each key/value pair specifies the candidates or function for a hyperparameter. custom_explore_fn is applied after built-in perturbations from hyperparam_mutations are applied, and should return config updated as needed.

  • resample_probability: The probability of resampling from the original distribution when applying hyperparam_mutations. If not resampled, the value will be perturbed by a factor of 1.2 or 0.8 if continuous, or changed to an adjacent value if discrete. Note that resample_probability is 0.25 by default, so a hyperparameter defined by a distribution is usually perturbed rather than resampled and may therefore move outside the specified range.
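The perturbation rule described in the list above can be sketched in plain Python. This is an illustration under stated assumptions, not Tune's code: the `perturb` function is hypothetical, and clamping at the ends of a categorical list and resampling when the current value is not in the list are choices made here for simplicity.

```python
import random


def perturb(value, spec, rng, resample_probability=0.25):
    """Return a new value for one hyperparameter.

    spec is either a list of allowed values or a zero-argument sampler.
    """
    if isinstance(spec, list):
        if rng.random() < resample_probability or value not in spec:
            return rng.choice(spec)
        # Discrete case: move to an adjacent entry, clamped to the list ends.
        index = spec.index(value) + rng.choice([-1, 1])
        return spec[min(max(index, 0), len(spec) - 1)]
    if rng.random() < resample_probability:
        return spec()
    # Continuous case: scale by 0.8 or 1.2, which may leave the sampled range.
    return value * rng.choice([0.8, 1.2])


rng = random.Random(0)
lr_sampler = lambda: rng.uniform(0.0001, 1)
new_lr = perturb(0.9, lr_sampler, rng, resample_probability=0.0)
new_momentum = perturb(0.9, [0.8, 0.9, 0.99], rng, resample_probability=0.0)
```

With resampling disabled, a learning rate of 0.9 becomes either 0.72 or 1.08. The second outcome lies above the upper bound of the sampling distribution, which is the out-of-range behavior noted above.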

Now we can kick off the tuning process by invoking tune.run:

    class CustomStopper(tune.Stopper):
        def __init__(self):
            self.should_stop = False

        def __call__(self, trial_id, result):
            max_iter = 5 if args.smoke_test else 100
            if not self.should_stop and result["mean_accuracy"] > 0.96:
                self.should_stop = True
            return self.should_stop or result["training_iteration"] >= max_iter

        def stop_all(self):
            return self.should_stop

    stopper = CustomStopper()

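    analysis = tune.run(
        train_convnet,
        name="pbt_test",
        scheduler=scheduler,
        metric="mean_accuracy",
        mode="max",
        stop=stopper,
        num_samples=4,
        # ... (further arguments omitted here)
        config={
            "lr": tune.uniform(0.001, 1),
            "momentum": tune.uniform(0.001, 1),
        })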

During training, we can check the status of the models at any time in the console log:

== Status ==
Memory usage on this node: 11.2/16.0 GiB
PopulationBasedTraining: 12 checkpoints, 5 perturbs
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/4.83 GiB heap, 0.0/1.66 GiB objects
Result logdir: /Users/foo/ray_results/pbt_test
Number of trials: 4 (4 TERMINATED)
| Trial name                | status     | loc   |        lr |   momentum |      acc |   iter |   total time (s) |
| train_convnet_b2732_00000 | TERMINATED |       | 0.221776  |   0.608416 | 0.95625  |     59 |          13.0862 |
| train_convnet_b2732_00001 | TERMINATED |       | 0.0734679 |   0.1484   | 0.934375 |     59 |          13.1084 |
| train_convnet_b2732_00002 | TERMINATED |       | 0.0376862 |   0.8      | 0.971875 |     46 |          10.2909 |
| train_convnet_b2732_00003 | TERMINATED |       | 0.0471078 |   0.8      | 0.95     |     51 |          11.3355 |

In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in pbt_global.txt and individual policy perturbations are recorded in pbt_policy_{i}.txt. Tune logs: [target trial tag, clone trial tag, target trial iteration, clone trial iteration, old config, new config] on each perturbation step.
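As an illustration of how these records could be inspected, the sketch below parses perturbation entries into dictionaries. It assumes each line of the policy file is a JSON-encoded list in the field order given above; the `read_policy_log` helper, the field names, and the sample line are hypothetical and do not come from a real run.

```python
import json

FIELDS = (
    "target_tag", "clone_tag", "target_iter",
    "clone_iter", "old_config", "new_config",
)


def read_policy_log(lines):
    """Parse perturbation records, assuming one JSON list per line."""
    records = []
    for line in lines:
        line = line.strip()
        if line:
            records.append(dict(zip(FIELDS, json.loads(line))))
    return records


# Hypothetical sample line, made up for this example.
sample = ['["00003", "00002", 5, 5, {"lr": 0.05}, {"lr": 0.06}]']
records = read_policy_log(sample)
```

For a real file, the same function could be applied to the lines returned by opening pbt_policy_{i}.txt, provided the format assumption holds for your version of Tune.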

Checking the accuracy:

# Plot test accuracy against training iteration
dfs = analysis.fetch_trial_dataframes()
# This plots everything on the same plot
ax = None
for d in dfs.values():
    ax = d.plot("training_iteration", "mean_accuracy", ax=ax, legend=False)

plt.ylabel("Test Accuracy")

print('best config:', analysis.get_best_config("mean_accuracy"))

Replaying a PBT run

A run of Population Based Training ends with fully trained models. However, sometimes you might like to train the model from scratch, but use the same hyperparameter schedule as obtained from PBT. Ray Tune offers a replay utility for this.

All you need to do is pass the policy log file for the trial you want to replay. This is usually stored in the experiment directory, for instance ~/ray_results/pbt_test/pbt_policy_ba982_00000.txt.

The replay utility reads the original configuration for the trial and updates it at each point where it was originally perturbed. You can (and should) therefore use the same Trainable for the replay run.

from ray import tune

from ray.tune.examples.pbt_convnet_example import PytorchTrainable
from ray.tune.schedulers import PopulationBasedTrainingReplay

replay = PopulationBasedTrainingReplay(
    "~/ray_results/pbt_test/pbt_policy_ba982_00000.txt")

tune.run(
    PytorchTrainable,
    scheduler=replay,
    stop={"training_iteration": 100})
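Conceptually, a replay applies a fixed schedule of configuration changes to a fresh training run. The sketch below illustrates that idea only. The `config_at` helper and the schedule values are hypothetical, and whether a change takes effect at or just after the recorded iteration is an assumption made here.

```python
def config_at(initial_config, perturbations, iteration):
    """Return the config in effect at a given training iteration.

    perturbations is a list of (iteration, new_config) pairs in any order.
    """
    config = dict(initial_config)
    for when, new_config in sorted(perturbations, key=lambda p: p[0]):
        if when <= iteration:
            config = dict(new_config)
    return config


# Hypothetical schedule: the learning rate changes at iterations 5 and 10.
schedule = [(10, {"lr": 0.06}), (5, {"lr": 0.12})]
before = config_at({"lr": 0.1}, schedule, 4)
middle = config_at({"lr": 0.1}, schedule, 7)
after = config_at({"lr": 0.1}, schedule, 99)
```

Here `before` is the original configuration, `middle` uses the first perturbed learning rate, and `after` uses the last one.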


The Generative Adversarial Networks (GAN) (Goodfellow et al., 2014) framework learns generative models via a training paradigm consisting of two competing modules – a generator and a discriminator. GAN training can be remarkably brittle and unstable in the face of suboptimal hyperparameter selection with generators often collapsing to a single mode or diverging entirely.

As presented in Population Based Training (PBT), PBT can help with DCGAN training. We will now walk through how to do this in Tune. The complete code example is available on GitHub.

We define the Generator and Discriminator with the standard PyTorch API:

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Generator Code
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh())

    def forward(self, input):
        return self.main(input)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False), nn.Sigmoid())

    def forward(self, input):
        return self.main(input)

To train the model with PBT, we need to define a metric for the scheduler to evaluate the model candidates. For a GAN, inception score is arguably the most commonly used metric. We trained an MNIST classification model (LeNet) and use it to run inference on the generated images and evaluate their quality.

class Net(nn.Module):
    """LeNet for MNIST classification, used for inception_score."""

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

def inception_score(imgs, mnist_model_ref, batch_size=32, splits=1):
    N = len(imgs)
    dtype = torch.FloatTensor
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
    cm = ray.get(mnist_model_ref)  # Get the mnist model from Ray object store.
    up = nn.Upsample(size=(28, 28), mode="bilinear").type(dtype)

    def get_pred(x):
        x = up(x)
        x = cm(x)
        return F.softmax(x, dim=1).data.cpu().numpy()

    preds = np.zeros((N, 10))
    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]
        preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv)

    # Now compute the mean kl-div
    split_scores = []
    for k in range(splits):
        part = preds[k * (N // splits):(k + 1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
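            scores.append(entropy(pyx, py))
        # The score of a split is the exponential of its mean KL divergence.
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)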
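To build intuition for the metric, the following dependency-free sketch computes the same quantity, the exponential of the mean KL divergence between each per-image class distribution and the marginal class distribution, on toy predictions. The `toy_inception_score` name and the two prediction lists are made up for this illustration.

```python
import math


def toy_inception_score(preds):
    """Return exp(mean KL(p(y|x) || p(y))) over rows of class probabilities."""
    n = len(preds)
    num_classes = len(preds[0])
    marginal = [sum(row[c] for row in preds) / n for c in range(num_classes)]
    kls = []
    for row in preds:
        # Terms with p == 0 contribute nothing to the KL divergence.
        kl = sum(p * math.log(p / q) for p, q in zip(row, marginal) if p > 0)
        kls.append(kl)
    return math.exp(sum(kls) / n)


confident_and_diverse = [[1.0, 0.0], [0.0, 1.0]]
uninformative = [[0.5, 0.5], [0.5, 0.5]]
high = toy_inception_score(confident_and_diverse)
low = toy_inception_score(uninformative)
```

Confident predictions spread evenly over both classes give a score close to 2, the number of classes, while uninformative predictions give the minimum score of 1. This is why a higher inception score is treated as better when tuning.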

We define a training function that includes a Generator and a Discriminator, each with an independent learning rate and optimizer. We make sure to implement checkpointing for our training.

def dcgan_train(config, checkpoint_dir=None):
    step = 0
    use_cuda = config.get("use_gpu") and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    netD = Discriminator().to(device)
    netG = Generator().to(device)
    criterion = nn.BCELoss()
    optimizerD = optim.Adam(
        netD.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999))
    optimizerG = optim.Adam(
        netG.parameters(), lr=config.get("lr", 0.01), betas=(beta1, 0.999))
    with FileLock(os.path.expanduser("~/.data.lock")):
        dataloader = get_data_loader()

    if checkpoint_dir is not None:
        path = os.path.join(checkpoint_dir, "checkpoint")
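        checkpoint = torch.load(path)
        netD.load_state_dict(checkpoint["netDmodel"])
        netG.load_state_dict(checkpoint["netGmodel"])
        optimizerD.load_state_dict(checkpoint["optimD"])
        optimizerG.load_state_dict(checkpoint["optimG"])
        step = checkpoint["step"]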

        if "netD_lr" in config:
            for param_group in optimizerD.param_groups:
                param_group["lr"] = config["netD_lr"]
        if "netG_lr" in config:
            for param_group in optimizerG.param_groups:
                param_group["lr"] = config["netG_lr"]

    while True:
        lossG, lossD, is_score = train(netD, netG, optimizerG, optimizerD,
                                       criterion, dataloader, step, device,
                                       config["mnist_model_ref"])
        step += 1
        with tune.checkpoint_dir(step=step) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            # Adam keeps running statistics, so save both optimizers as well.
            torch.save({
                "netDmodel": netD.state_dict(),
                "netGmodel": netG.state_dict(),
                "optimD": optimizerD.state_dict(),
                "optimG": optimizerG.state_dict(),
                "step": step,
            }, path)
        tune.report(lossg=lossG, lossd=lossD, is_score=is_score)

We specify inception score as the metric and start the tuning:

    # load the pretrained mnist classification model for inception_score
    mnist_cnn = Net()
    # Put the model in Ray object store.
    mnist_model_ref = ray.put(mnist_cnn)

    scheduler = PopulationBasedTraining(
        # ... (further arguments such as perturbation_interval omitted here)
        hyperparam_mutations={
            # distribution for resampling
            "netG_lr": lambda: np.random.uniform(1e-5, 1e-2),
            "netD_lr": lambda: np.random.uniform(1e-5, 1e-2),
        })

    tune_iter = 5 if args.smoke_test else 300
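    analysis = tune.run(
        dcgan_train,
        scheduler=scheduler,
        stop={
            "training_iteration": tune_iter,
        },
        metric="is_score",
        mode="max",
        # ... (further arguments omitted here)
        config={
            "netG_lr": tune.choice([0.0001, 0.0002, 0.0005]),
            "netD_lr": tune.choice([0.0001, 0.0002, 0.0005]),
            "mnist_model_ref": mnist_model_ref
        })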

The trained Generator models can be loaded from the log directory and used to generate images from noise signals.


Below, we visualize the increasing inception score from the training logs.

lossG = [df['is_score'].tolist() for df in list(analysis.trial_dataframes.values())]

plt.title("Inception Score During Training")
for i, lossg in enumerate(lossG):
    plt.plot(lossg, label=i)
plt.legend()
plt.show()

And the Generator loss:

lossG = [df['lossg'].tolist() for df in list(analysis.trial_dataframes.values())]

plt.title("Generator Loss During Training")
for i, lossg in enumerate(lossG):
    plt.plot(lossg, label=i)
plt.legend()
plt.show()

Training the MNIST Generator takes a couple of minutes. The example can easily be adapted to generate images for other datasets, e.g. CIFAR-10 or LSUN.