A Basic Tune Tutorial


This tutorial will walk you through the process of setting up a Tune experiment using PyTorch. Specifically, we'll leverage early stopping via ASHA and Bayesian optimization via HyperOpt through the following steps:

  1. Integrating Tune into your workflow

  2. Specifying a TrialScheduler

  3. Adding a SearchAlgorithm

  4. Getting the best model and analyzing results


To run this example, you will need to install the following:

$ pip install ray torch torchvision

We first run some imports:

import numpy as np
import torch
import torch.optim as optim
from torchvision import datasets

from ray import tune
from ray.tune import track
from ray.tune.schedulers import ASHAScheduler
from ray.tune.examples.mnist_pytorch import get_data_loaders, ConvNet, train, test

Below, we have some boilerplate code for a PyTorch training function.

def train_mnist(config):
    model = ConvNet()
    train_loader, test_loader = get_data_loaders()
    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        # Report the metric to Tune; this is what schedulers and
        # search algorithms use to compare trials.
        track.log(mean_accuracy=acc)
        if i % 5 == 0:
            # This saves the model to the trial directory
            torch.save(model, "./model.pth")

Notice that there are a couple of helper functions in the training script above. You can take a look at these functions in the imported module examples/mnist_pytorch; there's no black magic happening. For example, train is simply a for loop over the data loader.


def train(model, optimizer, train_loader):
    for batch_idx, (data, target) in enumerate(train_loader):
        # Only train on a subset of the dataset each epoch.
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

Let’s run 1 trial, randomly sampling from a uniform distribution for learning rate and momentum.

search_space = {
    "lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())),
    "momentum": tune.uniform(0.1, 0.9)
}

# Uncomment this to enable distributed execution
# `ray.init(address=...)`

analysis = tune.run(train_mnist, config=search_space)
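As a sanity check on the search space: the lr lambda above draws a uniform exponent in [-10, 0), so learning rates are sampled log-uniformly between 1e-10 and 1. A quick standalone check (using the stdlib random module in place of np.random.rand, no Tune required):

```python
import random

# Mirror the sampling expression from the search space above:
# a uniform exponent in [-10, 0) yields values in (1e-10, 1].
samples = [10 ** (-10 * random.random()) for _ in range(1000)]

assert min(samples) > 1e-10
assert max(samples) <= 1.0
```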

We can then plot the performance of this trial.

dfs = analysis.trial_dataframes
[d.mean_accuracy.plot() for d in dfs.values()]


Tune will automatically run parallel trials across all available cores/GPUs on your machine or cluster. To limit the number of cores that Tune uses, you can call ray.init(num_cpus=<int>, num_gpus=<int>) before tune.run.

Early Stopping with ASHA

Let's integrate a trial scheduler into our search: ASHA, a scalable algorithm for principled early stopping.

How does it work? At a high level, it terminates trials that are less promising and allocates more time and resources to the more promising trials. See this blog post for more details.
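The core idea can be illustrated without Tune at all. Below is a toy, self-contained sketch of successive halving (the mechanism underlying ASHA): at each rung, only the top fraction of trials survives and receives more training budget. The function names and the random "improvement" are illustrative assumptions, not Tune APIs.

```python
import random

def successive_halving(trials, rungs=3, reduction_factor=2):
    """Toy sketch: at each rung, keep the top 1/reduction_factor of
    trials and give the survivors more training budget."""
    scores = {t: 0.0 for t in trials}
    for rung in range(rungs):
        # Stand-in for "train each surviving trial a bit longer".
        for t in scores:
            scores[t] += random.random()
        # Keep only the most promising fraction; the rest stop early.
        keep = max(1, len(scores) // reduction_factor)
        survivors = sorted(scores, key=scores.get, reverse=True)[:keep]
        scores = {t: scores[t] for t in survivors}
    return list(scores)

# 8 trials halved over 3 rungs: 8 -> 4 -> 2 -> 1 survivor.
best = successive_halving([f"trial-{i}" for i in range(8)])
```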

We can afford to increase the number of trials by 5x by adjusting the num_samples parameter. See Tune Trial Schedulers for more details on available schedulers and library integrations.

analysis = tune.run(
    train_mnist,
    config=search_space,
    num_samples=5,
    scheduler=ASHAScheduler(metric="mean_accuracy", mode="max"))

# Obtain a trial dataframe from all run trials of this `tune.run` call.
dfs = analysis.trial_dataframes

You can run the following in a Jupyter notebook to visualize trial progress.

# Plot by epoch
ax = None  # This plots everything on the same plot
for d in dfs.values():
    ax = d.mean_accuracy.plot(ax=ax, legend=False)

You can also use TensorBoard for visualizing results.

$ tensorboard --logdir {logdir}

Search Algorithms in Tune

With Tune you can combine powerful hyperparameter search libraries such as HyperOpt and Ax with state-of-the-art algorithms such as HyperBand, without modifying any model training code. Tune allows you to use different search algorithms in combination with different trial schedulers. See Tune Search Algorithms for more details on available algorithms and library integrations.

from hyperopt import hp
from ray.tune.suggest.hyperopt import HyperOptSearch

space = {
    # hp.loguniform takes its bounds in log space.
    "lr": hp.loguniform("lr", np.log(1e-10), np.log(0.1)),
    "momentum": hp.uniform("momentum", 0.1, 0.9),
}

hyperopt_search = HyperOptSearch(space, metric="mean_accuracy", mode="max")

analysis = tune.run(train_mnist, num_samples=10, search_alg=hyperopt_search)

Evaluate your model

You can use the Analysis object to retrieve the logdir of the best trial and load its trained model:

import os

df = analysis.dataframe()
logdir = analysis.get_best_logdir("mean_accuracy", mode="max")
model = torch.load(os.path.join(logdir, "model.pth"))

Next Steps

Take a look at the Tune User Guide for a more comprehensive overview of Tune’s features.