A Basic Tune Tutorial

This tutorial will walk you through the process of setting up Tune. Specifically, we’ll leverage early stopping and Bayesian Optimization (via HyperOpt) to optimize your PyTorch model.


If you have suggestions as to how to improve this tutorial, please let us know!

To run this example, you will need to install the following:

$ pip install "ray[tune]" torch torchvision hyperopt

PyTorch Model Setup

To start, let’s import some dependencies:

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

from ray import tune
from ray.tune.schedulers import ASHAScheduler

Then, let’s define the PyTorch model that we’ll be training.

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # In this example, we don't change the model architecture
        # due to simplicity.
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)
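The hard-coded 192 in nn.Linear(192, 10) and x.view(-1, 192) follows from the input size. Assuming 28×28 single-channel MNIST images, a 3×3 convolution with no padding and stride 1 yields 26×26 feature maps, and max_pool2d(..., 3) (whose stride defaults to the kernel size) yields 8×8 maps, so three output channels give 3 × 8 × 8 = 192. The plain-Python helpers below (hypothetical names, not part of PyTorch) reproduce that arithmetic with the usual output-size formula floor((n + 2p − k) / s) + 1.

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


def convnet_flat_features(image_size: int = 28, channels_out: int = 3) -> int:
    """Number of features fed to the Linear layer in ConvNet."""
    after_conv = out_size(image_size, kernel=3)            # Conv2d(1, 3, kernel_size=3)
    after_pool = out_size(after_conv, kernel=3, stride=3)  # max_pool2d(..., 3): stride = kernel
    return channels_out * after_pool * after_pool


print(convnet_flat_features())  # 192
```

If you change the image size, kernel size, or number of channels, the Linear layer's input size has to be updated to match.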

Below is some boilerplate code for training and evaluating your model in PyTorch. If you are already familiar with PyTorch, feel free to skip ahead to the Tune usage.

# Change these values if you want the training to run quicker or slower.
EPOCH_SIZE = 512
TEST_SIZE = 256

def train(model, optimizer, train_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        # We set this just for the example to run quickly.
        if batch_idx * len(data) > EPOCH_SIZE:
            return
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        # Backpropagate and update the weights.
        loss.backward()
        optimizer.step()

def test(model, data_loader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(data_loader):
            # We set this just for the example to run quickly.
            if batch_idx * len(data) > TEST_SIZE:
                break
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    return correct / total
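Because the cutoff is checked against batch_idx * len(data) before each batch, the loops process slightly more samples than the cap. A plain-Python sketch of the same loop condition, assuming full batches of 64 samples and a cap of 512 (example values), counts how many batches actually run; the helper name is hypothetical.

```python
def batches_processed(cap: int, batch_size: int, dataset_size: int = 60_000) -> int:
    """Count batches that run under the `batch_idx * len(data) > cap` early exit."""
    num_batches = -(-dataset_size // batch_size)  # ceil division: batches in one pass
    count = 0
    for batch_idx in range(num_batches):
        # Same test as in train()/test(), assuming every batch is full.
        if batch_idx * batch_size > cap:
            break
        count += 1
    return count


n = batches_processed(cap=512, batch_size=64)
print(n, n * 64)  # 9 576
```

So with these values one "epoch" covers 576 training samples, and a cap of 256 for evaluation would cover 5 batches (320 samples).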

Setting up Tune

Below, we define a function that trains the PyTorch model for multiple epochs. Under the hood, this function is executed on a separate Ray Actor (process), so we need to communicate the performance of the model back to Tune (which runs in the main Python process).

To do this, we call tune.report in our training function, which sends the performance value back to Tune.


Since the function is executed on a separate process, make sure that it is serializable by Ray.
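Ray serializes the function with cloudpickle, which handles module-level functions, lambdas, and closures. What usually breaks serialization is state the function captures from its enclosing scope, such as locks, open file handles, or database connections. A rough pre-flight check on such objects can be sketched with the standard library's pickle; note this is stricter than cloudpickle (it rejects lambdas that cloudpickle accepts), and the helper name is hypothetical.

```python
import pickle
import threading


def is_picklable(obj) -> bool:
    """Return True if `obj` survives a stdlib pickle round trip."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except Exception:  # the error type varies by object (TypeError, PicklingError, ...)
        return False


print(is_picklable({"lr": 0.01, "momentum": 0.9}))  # True
print(is_picklable(threading.Lock()))                 # False
```

Running a check like this on anything the training function closes over can surface problems before Tune launches any trials.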

def train_mnist(config):
    # Data Setup
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    train_loader = DataLoader(
        datasets.MNIST("~/data", train=True, download=True, transform=mnist_transforms),
        batch_size=64,
        shuffle=True)
    test_loader = DataLoader(
        datasets.MNIST("~/data", train=False, transform=mnist_transforms),
        batch_size=64,
        shuffle=True)

    # Put the model on the same device that train() and test() move the data to.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ConvNet().to(device)
    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"])
    for i in range(10):
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)

        # Send the current training result back to Tune
        tune.report(mean_accuracy=acc)

        if i % 5 == 0:
            # This saves the model to the trial directory
            torch.save(model.state_dict(), "./model.pth")
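To make the reporting control flow concrete without Ray, here is a toy in-process stand-in in which report is just a callback that appends to a list and save records the epoch (hypothetical names; the real tune.report ships each result from the trial process to the Tune driver). It also shows which epochs the i % 5 == 0 rule checkpoints: only epochs 0 and 5, so model.pth holds the epoch-5 weights even though ten results are reported.

```python
from typing import Callable, Dict, List, Tuple


def toy_trainable(config: Dict[str, float],
                  report: Callable[..., None],
                  save: Callable[[int], None]) -> None:
    """Mirrors the loop structure of train_mnist with fake accuracy values."""
    acc = 0.0
    for i in range(10):
        acc = min(1.0, acc + config["lr"])  # stand-in for train() followed by test()
        report(mean_accuracy=acc)           # one result row per epoch
        if i % 5 == 0:
            save(i)                         # stand-in for torch.save(...)


def run_toy_trial(config: Dict[str, float]) -> Tuple[List[Dict[str, float]], List[int]]:
    results: List[Dict[str, float]] = []
    saved_epochs: List[int] = []
    toy_trainable(config, lambda **metrics: results.append(metrics), saved_epochs.append)
    return results, saved_epochs


results, saved = run_toy_trial({"lr": 0.125})
print(len(results), saved)  # 10 [0, 5]
```

If you want the saved weights to match the last reported result, you could also save on the final epoch.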

Let’s run one trial by calling tune.run, sampling the learning rate log-uniformly between 1e-10 and 1 and the momentum uniformly between 0.1 and 0.9.

search_space = {
    "lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())),
    "momentum": tune.uniform(0.1, 0.9)
}

# Uncomment these lines to enable distributed execution
# import ray
# ray.init(address="auto")

# Download the dataset first
datasets.MNIST("~/data", train=True, download=True)

analysis = tune.run(train_mnist, config=search_space)
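The lr sampler 10**(-10 * np.random.rand()) draws the exponent uniformly, so learning rates are log-uniform on (1e-10, 1] rather than uniform. The NumPy sketch below draws samples the same way and checks that roughly a tenth of them land in each decade; the seed and sample count are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
lrs = 10 ** (-10 * rng.random(100_000))  # same transform as the search space's lr sampler

exponents = np.log10(lrs)                # uniform on (-10, 0]
counts, _ = np.histogram(exponents, bins=10, range=(-10, 0))
print(lrs.min() > 1e-10, lrs.max() <= 1.0)  # True True
print(counts / len(lrs))                     # each decade gets roughly 0.1
```

A plain uniform draw on the same interval would put about 90% of the samples between 0.1 and 1.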

tune.run returns an Analysis object. You can use this to plot the performance of this trial.

dfs = analysis.trial_dataframes
[d.mean_accuracy.plot() for d in dfs.values()]


Tune will automatically run parallel trials across all available cores/GPUs on your machine or cluster. To limit the number of cores that Tune uses, you can call ray.init(num_cpus=<int>, num_gpus=<int>) before tune.run. If you’re using a Search Algorithm like Bayesian Optimization, you’ll want to use the ConcurrencyLimiter to cap how many trials run at the same time.
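Why limit concurrency for a search algorithm? A sequential optimizer such as Bayesian Optimization proposes new points using the results of finished trials; if it hands out every suggestion up front, later proposals are made with no feedback. The toy wrapper below is a hypothetical illustration of the idea, not Tune's ConcurrencyLimiter: it caps the number of in-flight suggestions and returns None until a running trial completes.

```python
from typing import Callable, Dict, Optional, Set


class ToyConcurrencyLimiter:
    """Caps how many suggested trials may be running at once (illustrative only)."""

    def __init__(self, suggest_fn: Callable[[], Dict[str, float]], max_concurrent: int):
        self.suggest_fn = suggest_fn
        self.max_concurrent = max_concurrent
        self.live: Set[str] = set()

    def suggest(self, trial_id: str) -> Optional[Dict[str, float]]:
        if len(self.live) >= self.max_concurrent:
            return None  # the caller must wait for a running trial to finish
        self.live.add(trial_id)
        return self.suggest_fn()

    def on_trial_complete(self, trial_id: str) -> None:
        # A finished trial frees a slot (and, in a real searcher, adds feedback).
        self.live.discard(trial_id)


limiter = ToyConcurrencyLimiter(lambda: {"lr": 0.01}, max_concurrent=2)
print(limiter.suggest("t1"), limiter.suggest("t2"), limiter.suggest("t3"))
# {'lr': 0.01} {'lr': 0.01} None
limiter.on_trial_complete("t1")
print(limiter.suggest("t3"))  # {'lr': 0.01}
```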

Early Stopping with ASHA

Let’s integrate early stopping into our optimization process using ASHA, a scalable algorithm for principled early stopping.

At a high level, ASHA terminates trials that are less promising and allocates more time and resources to more promising trials. As our optimization process becomes more efficient, we can afford to evaluate more configurations from the search space by increasing the parameter num_samples.

ASHA is implemented in Tune as a “Trial Scheduler”. Trial Schedulers can terminate bad trials early, pause trials, clone trials, and alter the hyperparameters of a running trial. See the TrialScheduler documentation for more details on the available schedulers and library integrations.
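ASHA's core rule can be sketched in a few lines. The class below is a simplified, hypothetical illustration loosely modeled on asynchronous successive halving, not Tune's ASHAScheduler source; it borrows the parameter names grace_period, reduction_factor, and max_t. Rungs sit at grace_period × reduction_factor^k iterations, and when a trial reaches a rung it is stopped if its score falls below the (1 − 1/reduction_factor) quantile of scores that earlier trials recorded at that rung. Because each decision uses only results seen so far, no trial waits for the others.

```python
import numpy as np


class ToyASHA:
    """Simplified ASHA-style stopping rule for a metric to maximize (illustrative only)."""

    def __init__(self, grace_period=1, reduction_factor=2, max_t=10):
        self.rf = reduction_factor
        milestones = []
        t = grace_period
        while t < max_t:
            milestones.append(t)
            t *= reduction_factor
        # Highest rung first, so a trial is judged at the furthest rung it has reached.
        self.rungs = {m: {} for m in sorted(milestones, reverse=True)}

    def on_result(self, trial_id, iteration, score):
        for milestone, recorded in self.rungs.items():
            if iteration >= milestone and trial_id not in recorded:
                decision = "CONTINUE"
                if recorded:
                    # Stop trials below the (1 - 1/rf) quantile of earlier arrivals.
                    q = (1 - 1 / self.rf) * 100
                    cutoff = np.percentile(list(recorded.values()), q)
                    if score < cutoff:
                        decision = "STOP"
                recorded[trial_id] = score
                return decision
        return "CONTINUE"


asha = ToyASHA()  # rungs at iterations 1, 2, 4, 8
print(asha.on_result("a", 1, 0.5))  # CONTINUE (first arrival at this rung)
print(asha.on_result("b", 1, 0.3))  # STOP (below the median of [0.5])
print(asha.on_result("c", 1, 0.9))  # CONTINUE (median of [0.5, 0.3] is 0.4)
```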

analysis = tune.run(
    train_mnist,
    num_samples=20,
    scheduler=ASHAScheduler(metric="mean_accuracy", mode="max"),
    config=search_space)

# Obtain a trial dataframe from all run trials of this `tune.run` call.
dfs = analysis.trial_dataframes

You can run the code below in a Jupyter notebook to visualize trial progress.

# Plot by epoch
ax = None  # This plots everything on the same plot
for d in dfs.values():
    ax = d.mean_accuracy.plot(ax=ax, legend=False)

You can also use TensorBoard to visualize results. Point it at the Tune results directory ({logdir} below), which is ~/ray_results by default.

$ tensorboard --logdir {logdir}

Search Algorithms in Tune

In addition to TrialSchedulers, you can further optimize your hyperparameters by using an intelligent search technique like Bayesian Optimization. To do this, you can use a Tune Search Algorithm. Search Algorithms leverage optimization algorithms to intelligently navigate the given hyperparameter space.

Note that each library has a specific way of defining the search space.

from hyperopt import hp
from ray.tune.suggest.hyperopt import HyperOptSearch

space = {
    # hp.loguniform expects its bounds in natural-log space.
    "lr": hp.loguniform("lr", np.log(1e-10), np.log(0.1)),
    "momentum": hp.uniform("momentum", 0.1, 0.9),
}

hyperopt_search = HyperOptSearch(space, metric="mean_accuracy", mode="max")

analysis = tune.run(train_mnist, num_samples=10, search_alg=hyperopt_search)
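HyperOpt's hp.loguniform(label, low, high) returns exp(uniform(low, high)), so low and high are natural-log bounds. The NumPy sketch below mimics that transform (it does not call HyperOpt; the helper name is hypothetical) to show that passing 1e-10 and 0.1 directly would confine lr to roughly [1.0, 1.105], while np.log bounds give the intended range.

```python
import numpy as np

rng = np.random.default_rng(0)


def loguniform_like_hyperopt(low: float, high: float, size: int) -> np.ndarray:
    """Mimics hp.loguniform semantics: exp of a uniform draw between low and high."""
    return np.exp(rng.uniform(low, high, size))


raw = loguniform_like_hyperopt(1e-10, 0.1, 10_000)                    # bounds used as-is
logged = loguniform_like_hyperopt(np.log(1e-10), np.log(0.1), 10_000)  # bounds in log space

print(raw.min() >= 1.0, raw.max() <= np.exp(0.1))  # True True
print(logged.min() >= 1e-10, logged.max() <= 0.1)  # True True
```

Writing the bounds as np.log(1e-10) and np.log(0.1) gives learning rates spread log-uniformly between 1e-10 and 0.1.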


Tune allows you to use some search algorithms in combination with different trial schedulers. See the Tune documentation on search algorithms and schedulers for more details.

Evaluate your model

You can evaluate the best trained model by using the Analysis object to retrieve it:

import os

df = analysis.dataframe()
logdir = analysis.get_best_logdir("mean_accuracy", mode="max")
state_dict = torch.load(os.path.join(logdir, "model.pth"))

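model = ConvNet()
# Load the weights saved by the best trial; the model is now ready for evaluation.
model.load_state_dict(state_dict)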

Next Steps

  • Take a look at the User Guide & Configuring Tune for a more comprehensive overview of Tune’s features.

  • Check out the Tune tutorials for guides on using Tune with your preferred machine learning library.

  • Browse our gallery of examples to see how to use Tune with PyTorch, XGBoost, TensorFlow, etc.

  • Let us know if you run into issues or have any questions by opening an issue on our GitHub.