Using PyTorch Lightning with Tune

PyTorch Lightning is a framework that brings structure to the training of PyTorch models. It aims to remove boilerplate code, so you don’t have to write the same training loops over and over again when building a new model.


The main abstraction of PyTorch Lightning is the LightningModule class, which your application should subclass. There is a great post on how to transfer your models from vanilla PyTorch to Lightning.

The class structure of PyTorch Lightning makes it very easy to define and tune model parameters. This tutorial shows you how to use Tune to find the best set of parameters for your application, using the example of training an MNIST classifier. Notably, the LightningModule does not have to be altered at all for this, so you can use Tune plug-and-play with your existing models, assuming their parameters are configurable!

Note

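To run this example, you will need to install the following. The code uses PyTorch Lightning 1.x APIs, such as validation_epoch_end and the gpus argument of pl.Trainer, which were removed in PyTorch Lightning 2.0, so pin the version:

$ pip install "ray[tune]" torch torchvision filelock "pytorch-lightning<2.0"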

PyTorch Lightning classifier for MNIST

Let’s first start with the basic PyTorch Lightning implementation of an MNIST classifier. This classifier does not include any tuning code at this point.

Our example builds on the MNIST example from the blog post we talked about earlier.

First, we run some imports:

import math
import os

import pytorch_lightning as pl
import torch
from filelock import FileLock
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

Next comes the Lightning model, adapted from the blog post. Compared with the original, we left out the test-set evaluation and made the model parameters configurable through a config dict that is passed in at initialization. We also specify a data_dir where the MNIST data is stored, and we use a FileLock when downloading the data so that the dataset is downloaded only once per node. Lastly, we added a new metric, the validation accuracy, to the logs.

class LightningMNISTClassifier(pl.LightningModule):
    """
    This has been adapted from
    https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09
    """

    def __init__(self, config, data_dir=None):
        super(LightningMNISTClassifier, self).__init__()

        self.data_dir = data_dir or os.getcwd()

        self.layer_1_size = config["layer_1_size"]
        self.layer_2_size = config["layer_2_size"]
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, self.layer_1_size)
        self.layer_2 = torch.nn.Linear(self.layer_1_size, self.layer_2_size)
        self.layer_3 = torch.nn.Linear(self.layer_2_size, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)

        x = self.layer_1(x)
        x = torch.relu(x)

        x = self.layer_2(x)
        x = torch.relu(x)

        x = self.layer_3(x)
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def accuracy(self, logits, labels):
        _, predicted = torch.max(logits.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / len(labels)
        return torch.tensor(accuracy)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)

        self.log("ptl/train_loss", loss)
        self.log("ptl/train_accuracy", accuracy)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        accuracy = self.accuracy(logits, y)
        return {"val_loss": loss, "val_accuracy": accuracy}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
        self.log("ptl/val_loss", avg_loss)
        self.log("ptl/val_accuracy", avg_acc)

    @staticmethod
    def download_data(data_dir):
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        with FileLock(os.path.expanduser("~/.data.lock")):
            return MNIST(data_dir, train=True, download=True, transform=transform)

    def prepare_data(self):
        mnist_train = self.download_data(self.data_dir)

        self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=int(self.batch_size))

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=int(self.batch_size))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


def train_mnist(config):
    model = LightningMNISTClassifier(config)
    trainer = pl.Trainer(max_epochs=10, enable_progress_bar=False)

    trainer.fit(model)

And that’s it! You can now run train_mnist(config) to train the classifier, for example:

def train_mnist_no_tune():
    config = {"layer_1_size": 128, "layer_2_size": 256, "lr": 1e-3, "batch_size": 64}
    train_mnist(config)
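The forward() method ends in log_softmax. Applying F.nll_loss to that output therefore gives the usual cross-entropy loss, which is why the method can be called cross_entropy_loss. The accuracy method counts how often the highest-scoring class matches the label.

The plain-Python sketch below walks through both calculations on two hand-picked rows of logits. Its log_softmax, nll_loss and accuracy helpers are hypothetical stdlib stand-ins, not the PyTorch functions.

```python
import math


def log_softmax(row):
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    log_norm = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_norm for x in row]


def nll_loss(log_probs, labels):
    # Mean negative log-probability of the true class (a "mean" reduction).
    return -sum(row[y] for row, y in zip(log_probs, labels)) / len(labels)


def accuracy(log_probs, labels):
    # Fraction of rows whose highest-scoring class equals the label.
    hits = sum(
        max(range(len(row)), key=row.__getitem__) == y
        for row, y in zip(log_probs, labels)
    )
    return hits / len(labels)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
labels = [0, 1]
log_probs = [log_softmax(row) for row in logits]
print(nll_loss(log_probs, labels))
print(accuracy(log_probs, labels))  # 0.5: row 0 is right, row 1 predicts class 2
```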

Tuning the model parameters

The parameters above should already give you an accuracy of over 90%. However, we might improve on this simply by changing some of the hyperparameters. For instance, we might get an even higher accuracy with a larger batch size.
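Both kinds of hyperparameters have concrete, countable effects:

- The layer sizes set the model size. Each torch.nn.Linear(in, out) holds in * out weights plus out biases.
- batch_size sets how many batches one epoch takes over the 55,000 training and 5,000 validation images produced by random_split.

Here is a small stdlib sketch with hypothetical helper names. For layer sizes 128 and 128, it reproduces the 118 K parameters and the 0.473 MB (4-byte floats) reported in the model summary of the log output further down.

```python
import math


def count_linear_params(layer_1_size, layer_2_size, in_features=28 * 28, num_classes=10):
    # Each Linear(n_in, n_out) layer has n_in * n_out weights and n_out biases.
    sizes = [in_features, layer_1_size, layer_2_size, num_classes]
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


def batches_per_epoch(num_examples, batch_size):
    # DataLoader keeps the last, smaller batch by default (drop_last=False).
    return math.ceil(num_examples / batch_size)


config = {"layer_1_size": 128, "layer_2_size": 256, "lr": 1e-3, "batch_size": 64}
print(count_linear_params(config["layer_1_size"], config["layer_2_size"]))  # 136074
print(batches_per_epoch(55000, config["batch_size"]))  # 860 training batches
print(batches_per_epoch(5000, config["batch_size"]))  # 79 validation batches

summary_params = count_linear_params(128, 128)  # layer sizes of the logged trial
print(summary_params, round(summary_params * 4 / 1e6, 3))  # 118282 0.473
```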

Instead of guessing the parameter values, let’s use Tune to systematically try out parameter combinations and find the best performing set.

First, we need some additional imports:

from pytorch_lightning.loggers import TensorBoardLogger
from ray import train, tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.tune.integration.pytorch_lightning import (
    TuneReportCallback,
    TuneReportCheckpointCallback,
)

Talking to Tune with a PyTorch Lightning callback

PyTorch Lightning provides callbacks, which plug custom functions into the training loop. This way the original LightningModule does not have to be altered at all, and the same callback can be reused across multiple modules.

Ray Tune comes with ready-to-use PyTorch Lightning callbacks. To report metrics back to Tune after each validation epoch, we will use the TuneReportCallback:

TuneReportCallback(
    {"loss": "ptl/val_loss", "mean_accuracy": "ptl/val_accuracy"}, on="validation_end"
)

This callback takes the ptl/val_loss and ptl/val_accuracy values logged by the LightningModule and reports them to Tune as loss and mean_accuracy, respectively.
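In other words, the dictionary maps the name Tune should see (the key) to the name logged with self.log (the value).

Below is a minimal sketch of that renaming. It assumes the callback looks the values up among the metrics the trainer has collected when the hook fires. map_metrics is a hypothetical helper, not Ray's implementation.

```python
def map_metrics(logged_metrics, metric_map):
    # metric_map is {name reported to Tune: key logged in Lightning}.
    # In this sketch, logged keys that are not mentioned (e.g. ptl/train_loss) are dropped.
    return {
        tune_name: float(logged_metrics[ptl_key])
        for tune_name, ptl_key in metric_map.items()
    }


logged = {"ptl/val_loss": 0.1232, "ptl/val_accuracy": 0.96, "ptl/train_loss": 0.2}
reported = map_metrics(logged, {"loss": "ptl/val_loss", "mean_accuracy": "ptl/val_accuracy"})
print(reported)  # {'loss': 0.1232, 'mean_accuracy': 0.96}
```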

Adding the Tune training function

Then we specify our training function. Note that we added data_dir as a parameter here so that each training run does not download the full MNIST dataset. Instead, all trials access a shared data location.

We can also specify the number of epochs to train each model and the number of GPUs we want to use for training. In addition, we create a TensorBoard logger that writes log files directly into Tune’s trial directory. Otherwise, PyTorch Lightning would create its own subdirectories, and each trial would show up twice in TensorBoard: once for Tune’s logs and once for PyTorch Lightning’s logs.

def train_mnist_tune(config, num_epochs=10, num_gpus=0, data_dir="~/data"):
    data_dir = os.path.expanduser(data_dir)
    model = LightningMNISTClassifier(config, data_dir)
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        # If fractional GPUs passed in, convert to int.
        gpus=math.ceil(num_gpus),
        logger=TensorBoardLogger(save_dir=os.getcwd(), name="", version="."),
        enable_progress_bar=False,
        callbacks=[
            TuneReportCallback(
                {"loss": "ptl/val_loss", "mean_accuracy": "ptl/val_accuracy"},
                on="validation_end",
            )
        ],
    )
    trainer.fit(model)
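The name="" and version="." arguments are what keep TensorBoard from nesting directories. By default, the logger's name is "lightning_logs" and its version is "version_N".

The sketch below assumes the logger composes its log directory as save_dir/name/version. Under that assumption, the settings above collapse to save_dir itself, the trial directory returned by os.getcwd(). The helper and the trial path are hypothetical.

```python
import os


def tensorboard_log_dir(save_dir, name, version):
    # Assumed layout: <save_dir>/<name>/<version>; "" and "." collapse away.
    return os.path.normpath(os.path.join(save_dir, name, version))


trial_dir = os.path.join("ray_results", "tune_mnist_asha", "trial_0")  # hypothetical path
print(tensorboard_log_dir(trial_dir, "", "."))  # the trial directory itself
print(tensorboard_log_dir(trial_dir, "lightning_logs", "version_0"))  # a nested subdirectory
```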

Configuring the search space

Now we configure the parameter search space. We would like to choose between three different values each for the two layer sizes and the batch size. The learning rate is sampled log-uniformly between 0.0001 and 0.1, i.e. uniformly in log space, so that each order of magnitude is equally likely. tune.loguniform() is syntactic sugar for this. With plain uniform sampling over the same range, only about 1% of samples would fall between 0.0001 and 0.001.

config = {
    "layer_1_size": tune.choice([32, 64, 128]),
    "layer_2_size": tune.choice([64, 128, 256]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([32, 64, 128]),
}
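To see the difference, here is a stdlib sketch of log-uniform sampling compared with plain uniform sampling over the same range. sample_loguniform and share_below are hypothetical helpers, not Tune's implementation of tune.loguniform.

```python
import math
import random


def sample_loguniform(low, high, rng):
    # Uniform in log space, mapped back with exp().
    return math.exp(rng.uniform(math.log(low), math.log(high)))


def share_below(samples, threshold):
    return sum(s < threshold for s in samples) / len(samples)


rng = random.Random(0)
n = 30_000
log_samples = [sample_loguniform(1e-4, 1e-1, rng) for _ in range(n)]
uniform_samples = [rng.uniform(1e-4, 1e-1) for _ in range(n)]

# Log-uniform: roughly a third of the draws land in each decade.
print(share_below(log_samples, 1e-3), share_below(log_samples, 1e-2))
# Plain uniform over the same range: only about 1% of the draws are below 1e-3.
print(share_below(uniform_samples, 1e-3))
```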

Selecting a scheduler

In this example, we use an Asynchronous Hyperband (ASHA) scheduler. At fixed milestones in training, this scheduler compares the trials that have reached that point and stops those that are likely to perform badly. This way we don’t waste resources on bad hyperparameter configurations.

num_epochs = 10

scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)
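Two details of ASHA help when reading the status output:

- Trials are only compared at milestones (rungs) of grace_period * reduction_factor**k iterations, up to max_t.
- At a rung, a trial keeps running only if its score is roughly in the top 1/reduction_factor of the scores recorded there. With mode="min", the loss is negated so that higher is better. This is why the Bracket lines in the log show negative values.

The following is a simplified sketch of that logic, not Ray's implementation. The helper names and scores are hypothetical.

```python
def rung_milestones(max_t, grace_period, reduction_factor):
    """Milestones grace_period * reduction_factor**k that do not exceed max_t, largest first."""
    milestones = []
    t = grace_period
    while t <= max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones[::-1]


def quantile(values, q):
    # Linear interpolation between neighboring ranks (NumPy's default percentile method).
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def keeps_running(score, scores_at_rung, reduction_factor):
    # Higher is better: with mode="min", the loss is negated before comparing.
    cutoff = quantile(scores_at_rung, 1 - 1 / reduction_factor)
    return score >= cutoff


print(rung_milestones(max_t=10, grace_period=1, reduction_factor=2))  # [8, 4, 2, 1]
print(rung_milestones(max_t=6, grace_period=1, reduction_factor=2))  # [4, 2, 1]

scores = [-0.45, -0.12, -0.72, -0.11]  # made-up negated losses at one rung
print([keeps_running(s, scores, reduction_factor=2) for s in scores])  # [False, True, False, True]
```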

Changing the CLI output

We instantiate a CLIReporter to specify which metrics we would like to see in our output tables in the command line. This is optional, but can be used to make sure our output tables only include information we would like to see.

reporter = CLIReporter(
    parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
    metric_columns=["loss", "mean_accuracy", "training_iteration"],
)

Passing constants to the train function

The data_dir, num_epochs and num_gpus we pass to the training function are constants. To avoid including them as non-configurable parameters in the config specification, we can use tune.with_parameters to wrap the training function.

gpus_per_trial = 0
data_dir = "~/data"

train_fn_with_parameters = tune.with_parameters(
    train_mnist_tune, num_epochs=num_epochs, num_gpus=gpus_per_trial, data_dir=data_dir
)
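From the caller's side, the wrapped function behaves much like functools.partial: Tune still passes only config, and the constants are bound ahead of time. The sketch below shows only that call shape, using a hypothetical stand-in for train_mnist_tune.

One difference from partial: tune.with_parameters places the bound objects in the Ray object store. That is what makes it suitable for large objects such as datasets.

```python
import functools


def fake_train_mnist_tune(config, num_epochs=10, num_gpus=0, data_dir="~/data"):
    # Hypothetical stand-in that just echoes what it received.
    return {"lr": config["lr"], "num_epochs": num_epochs, "num_gpus": num_gpus, "data_dir": data_dir}


bound = functools.partial(fake_train_mnist_tune, num_epochs=6, num_gpus=0, data_dir="/tmp/mnist")
result = bound({"lr": 1e-3})  # only the config is passed at call time
print(result)
```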

Training with GPUs

We can specify how many resources Tune should request for each trial. This also includes GPUs.

PyTorch Lightning takes care of moving the training to the GPUs. We already made sure that our code is compatible with that, so there’s nothing more to do here other than to specify the number of GPUs we would like to use:

resources_per_trial = {"cpu": 1, "gpu": gpus_per_trial}

You can also specify fractional GPUs for Tune, allowing multiple trials to share GPUs and thus increasing concurrency under resource constraints. While the gpus_per_trial passed to Tune can be a fractional value, the gpus argument passed to pl.Trainer must still be an integer, which is why train_mnist_tune rounds it up with math.ceil. Please note that if you use fractional GPUs, it is your responsibility to make sure that multiple trials can share the GPUs and that there is enough memory to do so. Ray does not automatically handle this for you.
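Here is a quick sketch of the arithmetic. It assumes GPUs are the only limiting resource. In practice, the "cpu": 1 request and GPU memory also cap concurrency.

With 2 GPUs, the share per trial decides how many trials can run at once. The Trainer still receives a whole number via math.ceil, as in train_mnist_tune. The helper names are hypothetical.

```python
import math


def concurrent_trials(total_gpus, gpus_per_trial):
    # Trials that fit at once when GPUs are the only constraint.
    return math.floor(total_gpus / gpus_per_trial)


def trainer_gpus(gpus_per_trial):
    # A fractional share still needs one whole device in pl.Trainer.
    return math.ceil(gpus_per_trial)


for share in (1, 0.5, 0.25):
    print(share, concurrent_trials(total_gpus=2, gpus_per_trial=share), trainer_gpus(share))
# 1 2 1
# 0.5 4 1
# 0.25 8 1
print(trainer_gpus(0))  # 0 for CPU-only trials
```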

If you want to use multiple GPUs per trial, you should check out Getting Started with Lightning and Ray TorchTrainer.

Putting it together

Lastly, we need to create a Tuner() object and start Ray Tune with tuner.fit().

The full code looks like this:

def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0, data_dir="~/data"):
    config = {
        "layer_1_size": tune.choice([32, 64, 128]),
        "layer_2_size": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
    }

    scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

    reporter = CLIReporter(
        parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
        metric_columns=["loss", "mean_accuracy", "training_iteration"],
    )

    train_fn_with_parameters = tune.with_parameters(
        train_mnist_tune,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
        data_dir=data_dir,
    )
    resources_per_trial = {"cpu": 1, "gpu": gpus_per_trial}

    tuner = tune.Tuner(
        tune.with_resources(train_fn_with_parameters, resources=resources_per_trial),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        run_config=train.RunConfig(
            name="tune_mnist_asha",
            progress_reporter=reporter,
        ),
        param_space=config,
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)

In the example above, Tune runs 10 trials with different hyperparameter configurations. An example output could look like so:

  +------------------------------+------------+-------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------+
  | Trial name                   | status     | loc   |   layer_1_size |   layer_2_size |          lr |   batch_size |     loss |   mean_accuracy |   training_iteration |
  |------------------------------+------------+-------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------|
  | train_mnist_tune_63ecc_00000 | TERMINATED |       |            128 |             64 | 0.00121197  |          128 | 0.120173 |       0.972461  |                   10 |
  | train_mnist_tune_63ecc_00001 | TERMINATED |       |             64 |            128 | 0.0301395   |          128 | 0.454836 |       0.868164  |                    4 |
  | train_mnist_tune_63ecc_00002 | TERMINATED |       |             64 |            128 | 0.0432097   |          128 | 0.718396 |       0.718359  |                    1 |
  | train_mnist_tune_63ecc_00003 | TERMINATED |       |             32 |            128 | 0.000294669 |           32 | 0.111475 |       0.965764  |                   10 |
  | train_mnist_tune_63ecc_00004 | TERMINATED |       |             32 |            256 | 0.000386664 |           64 | 0.133538 |       0.960839  |                    8 |
  | train_mnist_tune_63ecc_00005 | TERMINATED |       |            128 |            128 | 0.0837395   |           32 | 2.32628  |       0.0991242 |                    1 |
  | train_mnist_tune_63ecc_00006 | TERMINATED |       |             64 |            128 | 0.000158761 |          128 | 0.134595 |       0.959766  |                   10 |
  | train_mnist_tune_63ecc_00007 | TERMINATED |       |             64 |             64 | 0.000672126 |           64 | 0.118182 |       0.972903  |                   10 |
  | train_mnist_tune_63ecc_00008 | TERMINATED |       |            128 |             64 | 0.000502428 |           32 | 0.11082  |       0.975518  |                   10 |
  | train_mnist_tune_63ecc_00009 | TERMINATED |       |             64 |            256 | 0.00112894  |           32 | 0.13472  |       0.971935  |                    8 |
  +------------------------------+------------+-------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------+

As you can see in the training_iteration column, trials with a high loss (and low accuracy) have been terminated early. The best performing trial used layer_1_size=128, layer_2_size=64, lr=0.000502428 and batch_size=32.
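Selecting the best trial boils down to taking the lowest final loss. That is roughly what results.get_best_result() does with the metric="loss" and mode="min" from TuneConfig. By default, it looks at each trial's last reported result.

The sketch below repeats that selection in plain Python on the numbers from the table. It also lists the trials that ASHA stopped before max_t=10.

```python
# (trial suffix, loss, mean_accuracy, training_iteration), copied from the table above
trials = [
    ("00000", 0.120173, 0.972461, 10),
    ("00001", 0.454836, 0.868164, 4),
    ("00002", 0.718396, 0.718359, 1),
    ("00003", 0.111475, 0.965764, 10),
    ("00004", 0.133538, 0.960839, 8),
    ("00005", 2.32628, 0.0991242, 1),
    ("00006", 0.134595, 0.959766, 10),
    ("00007", 0.118182, 0.972903, 10),
    ("00008", 0.11082, 0.975518, 10),
    ("00009", 0.13472, 0.971935, 8),
]
max_t = 10  # num_epochs passed to the ASHAScheduler

best = min(trials, key=lambda t: t[1])  # metric="loss", mode="min"
stopped_early = [t[0] for t in trials if t[3] < max_t]
print(best[0], best[1])  # 00008 0.11082
print(stopped_early)  # ['00001', '00002', '00004', '00005', '00009']
```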

Using Population Based Training to find the best parameters

The ASHAScheduler terminates poorly performing trials early. Sometimes this stops trials that would have improved with more training steps, and that might eventually even have outperformed other configurations.

Another popular method for hyperparameter tuning, called Population Based Training (PBT), instead perturbs hyperparameters during the training run. Poorly performing trials periodically copy the state of better-performing ones and continue with modified hyperparameters. Tune implements PBT, and we only need to make some slight adjustments to our code.
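The "copy the state of better-performing trials" part is PBT's exploit step. At each perturbation interval, trials are ranked by the metric. Each trial in the bottom quantile clones the checkpoint of a trial in the top quantile.

Below is a simplified sketch that assumes a quantile fraction of 0.25, which Ray's PopulationBasedTraining exposes as quantile_fraction. The pairing here is deterministic for readability, whereas Ray picks a top trial at random. The losses are made up.

```python
def exploit(losses, quantile_fraction=0.25):
    """Map each bottom-quantile trial to the top-quantile trial it would clone."""
    ranked = sorted(losses, key=losses.get)  # best (lowest loss) first
    n = max(1, int(len(ranked) * quantile_fraction))
    top, bottom = ranked[:n], ranked[-n:]
    # Deterministic pairing for readability; trials in the middle are left alone.
    return {loser: top[i % n] for i, loser in enumerate(bottom)}


losses = {"t0": 0.30, "t1": 0.12, "t2": 0.50, "t3": 0.15,
          "t4": 0.90, "t5": 0.20, "t6": 0.25, "t7": 0.40}
print(exploit(losses))  # {'t2': 't1', 't4': 't3'}
```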

Adding checkpoints to the PyTorch Lightning module

First, we need to introduce another callback to save model checkpoints. Since Tune requires a call to train.report() after creating a new checkpoint to register it, we will use a combined reporting and checkpointing callback:

TuneReportCheckpointCallback(
    metrics={"loss": "ptl/val_loss", "mean_accuracy": "ptl/val_accuracy"},
    filename="checkpoint",
    on="validation_end",
)

The filename argument is the name of the checkpoint file within the checkpoint directory.

We also include checkpoint loading in our training function:

def train_mnist_tune_checkpoint(config, num_epochs=10, num_gpus=0, data_dir="~/data"):
    data_dir = os.path.expanduser(data_dir)
    kwargs = {
        "max_epochs": num_epochs,
        # If fractional GPUs passed in, convert to int.
        "gpus": math.ceil(num_gpus),
        "logger": TensorBoardLogger(save_dir=os.getcwd(), name="", version="."),
        "enable_progress_bar": False,
        "callbacks": [
            TuneReportCheckpointCallback(
                metrics={"loss": "ptl/val_loss", "mean_accuracy": "ptl/val_accuracy"},
                filename="checkpoint",
                on="validation_end",
            )
        ],
    }

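    model = LightningMNISTClassifier(config=config, data_dir=data_dir)

    checkpoint = train.get_checkpoint()
    if checkpoint:
        # The checkpoint directory is only guaranteed to exist inside this block,
        # so resume and train while the context manager is still open.
        with checkpoint.as_directory() as checkpoint_dir:
            kwargs["resume_from_checkpoint"] = os.path.join(checkpoint_dir, "checkpoint")
            trainer = pl.Trainer(**kwargs)
            trainer.fit(model)
    else:
        trainer = pl.Trainer(**kwargs)
        trainer.fit(model)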
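The pattern in train_mnist_tune_checkpoint is simple: if Tune hands the function a checkpoint, continue from it; otherwise start fresh. This matters for PBT, which restarts trials from a (possibly cloned) checkpoint with new hyperparameters.

Below is a stdlib-only sketch of that resume logic. The toy_train loop is hypothetical, and a JSON file stands in for the Lightning checkpoint.

```python
import json
import os
import tempfile

CHECKPOINT_FILE = "checkpoint"  # mirrors filename="checkpoint" in the callback


def toy_train(num_epochs, checkpoint_dir=None, save_dir=None):
    """Hypothetical training loop that resumes from a JSON checkpoint when given one."""
    start_epoch = 0
    if checkpoint_dir is not None:
        with open(os.path.join(checkpoint_dir, CHECKPOINT_FILE)) as f:
            start_epoch = json.load(f)["next_epoch"]
    epochs_run = []
    for epoch in range(start_epoch, num_epochs):
        epochs_run.append(epoch)  # stand-in for one epoch of real training
        if save_dir is not None:
            with open(os.path.join(save_dir, CHECKPOINT_FILE), "w") as f:
                json.dump({"next_epoch": epoch + 1}, f)
    return epochs_run


with tempfile.TemporaryDirectory() as ckpt_dir:
    first = toy_train(num_epochs=3, save_dir=ckpt_dir)  # no checkpoint: starts at epoch 0
    resumed = toy_train(num_epochs=5, checkpoint_dir=ckpt_dir)  # continues at epoch 3
print(first, resumed)  # [0, 1, 2] [3, 4]
```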

Configuring and running Population Based Training

We need to call Tune slightly differently:

def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0, data_dir="~/data"):
    config = {
        "layer_1_size": tune.choice([32, 64, 128]),
        "layer_2_size": tune.choice([64, 128, 256]),
        "lr": 1e-3,
        "batch_size": 64,
    }

    scheduler = PopulationBasedTraining(
        perturbation_interval=4,
        hyperparam_mutations={
            "lr": tune.loguniform(1e-4, 1e-1),
            "batch_size": [32, 64, 128],
        },
    )

    reporter = CLIReporter(
        parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
        metric_columns=["loss", "mean_accuracy", "training_iteration"],
    )

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(
                train_mnist_tune_checkpoint,
                num_epochs=num_epochs,
                num_gpus=gpus_per_trial,
                data_dir=data_dir,
            ),
            resources={"cpu": 1, "gpu": gpus_per_trial},
        ),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            scheduler=scheduler,
            num_samples=num_samples,
        ),
        run_config=train.RunConfig(
            name="tune_mnist_pbt",
            progress_reporter=reporter,
        ),
        param_space=config,
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)

Instead of sampling lr and batch_size in the config dict, we start them at fixed values; other parameters, like the layer sizes, can still be sampled. Additionally, the hyperparam_mutations argument tells PBT how to perturb lr and batch_size. Note that the layer sizes are not part of the mutations. This is because we cannot simply change layer sizes during a training run, which is what would happen in PBT: the checkpointed weights would no longer fit the new layer shapes.
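The hyperparam_mutations dict drives PBT's explore step. As we understand Ray's documented behavior, it works like this:

- With probability resample_probability (default 0.25), a value is resampled from its distribution.
- Otherwise, a continuous value is multiplied by one of perturbation_factors (default (1.2, 0.8)).
- A categorical value that is not resampled moves to an adjacent entry in its list.

Below is a simplified sketch under those assumptions. Representing tune.loguniform as a (low, high) tuple is a stand-in used only here. The layer sizes have no mutation entry, so they pass through unchanged.

```python
import math
import random


def explore(config, mutations, rng, resample_probability=0.25, factors=(1.2, 0.8)):
    """Return a perturbed copy of config; keys without a mutation are kept as-is."""
    new_config = dict(config)
    for key, spec in mutations.items():
        resample = rng.random() < resample_probability
        if isinstance(spec, list):
            # Categorical: resample from the list, or step to a neighboring entry.
            if resample:
                new_config[key] = rng.choice(spec)
            else:
                idx = spec.index(config[key]) + rng.choice([-1, 1])
                new_config[key] = spec[min(max(idx, 0), len(spec) - 1)]
        else:
            # Continuous: (low, high) stands in for tune.loguniform(low, high).
            low, high = spec
            if resample:
                new_config[key] = math.exp(rng.uniform(math.log(low), math.log(high)))
            else:
                new_config[key] = config[key] * rng.choice(factors)
    return new_config


config = {"layer_1_size": 64, "layer_2_size": 128, "lr": 1e-3, "batch_size": 64}
mutations = {"lr": (1e-4, 1e-1), "batch_size": [32, 64, 128]}
rng = random.Random(0)
print(explore(config, mutations, rng))
```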

To test both of our main functions (tune_mnist_asha and tune_mnist_pbt), all you have to do is specify a data_dir folder and call them with reasonable parameters:

data_dir = "~/data/"

tune_mnist_asha(num_samples=1, num_epochs=6, gpus_per_trial=0, data_dir=data_dir)
tune_mnist_pbt(num_samples=1, num_epochs=6, gpus_per_trial=0, data_dir=data_dir)
== Status ==
Current time: 2022-07-22 16:24:58 (running for 00:00:00.16)
Memory usage on this node: 11.2/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 1.0/16 CPUs, 0/0 GPUs, 0.0/4.72 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/kai/ray_results/tune_mnist_asha
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+
| Trial name                   | status   | loc             |   layer_1_size |   layer_2_size |         lr |   batch_size |
|------------------------------+----------+-----------------+----------------+----------------+------------+--------------|
| train_mnist_tune_727f7_00000 | RUNNING  | 127.0.0.1:52355 |            128 |            128 | 0.00165008 |           64 |
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+
(train_mnist_tune pid=52355) GPU available: False, used: False
(train_mnist_tune pid=52355) TPU available: False, using: 0 TPU cores
(train_mnist_tune pid=52355) IPU available: False, using: 0 IPUs
(train_mnist_tune pid=52355) HPU available: False, using: 0 HPUs
(train_mnist_tune pid=52355) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:336: LightningDeprecationWarning: The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7. Please use the `on_exception` callback hook instead.
(train_mnist_tune pid=52355)   "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
(train_mnist_tune pid=52355) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:348: LightningDeprecationWarning: The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8.
(train_mnist_tune pid=52355)   "The `on_init_start` callback hook was deprecated in v1.6 and will be removed in v1.8."
(train_mnist_tune pid=52355) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:351: LightningDeprecationWarning: The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8.
(train_mnist_tune pid=52355)   rank_zero_deprecation("The `on_init_end` callback hook was deprecated in v1.6 and will be removed in v1.8.")
(train_mnist_tune pid=52355) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:377: LightningDeprecationWarning: The `Callback.on_batch_start` hook was deprecated in v1.6 and will be removed in v1.8. Please use `Callback.on_train_batch_start` instead.
(train_mnist_tune pid=52355)   f"The `Callback.{hook}` hook was deprecated in v1.6 and"
(train_mnist_tune pid=52355) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:377: LightningDeprecationWarning: The `Callback.on_batch_end` hook was deprecated in v1.6 and will be removed in v1.8. Please use `Callback.on_train_batch_end` instead.
(train_mnist_tune pid=52355)   f"The `Callback.{hook}` hook was deprecated in v1.6 and"
(train_mnist_tune pid=52355) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:386: LightningDeprecationWarning: The `Callback.on_epoch_start` hook was deprecated in v1.6 and will be removed in v1.8. Please use `Callback.on_<train/validation/test>_epoch_start` instead.
(train_mnist_tune pid=52355)   f"The `Callback.{hook}` hook was deprecated in v1.6 and"
(train_mnist_tune pid=52355) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/pytorch_lightning/trainer/configuration_validator.py:386: LightningDeprecationWarning: The `Callback.on_epoch_end` hook was deprecated in v1.6 and will be removed in v1.8. Please use `Callback.on_<train/validation/test>_epoch_end` instead.
(train_mnist_tune pid=52355)   f"The `Callback.{hook}` hook was deprecated in v1.6 and"
(train_mnist_tune pid=52355) 
(train_mnist_tune pid=52355)   | Name    | Type   | Params
(train_mnist_tune pid=52355) -----------------------------------
(train_mnist_tune pid=52355) 0 | layer_1 | Linear | 100 K 
(train_mnist_tune pid=52355) 1 | layer_2 | Linear | 16.5 K
(train_mnist_tune pid=52355) 2 | layer_3 | Linear | 1.3 K 
(train_mnist_tune pid=52355) -----------------------------------
(train_mnist_tune pid=52355) 118 K     Trainable params
(train_mnist_tune pid=52355) 0         Non-trainable params
(train_mnist_tune pid=52355) 118 K     Total params
(train_mnist_tune pid=52355) 0.473     Total estimated model params size (MB)
(train_mnist_tune pid=52355) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:245: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
(train_mnist_tune pid=52355)   category=PossibleUserWarning,
(train_mnist_tune pid=52355) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:245: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
(train_mnist_tune pid=52355)   category=PossibleUserWarning,
== Status ==
Current time: 2022-07-22 16:25:13 (running for 00:00:15.04)
Memory usage on this node: 11.7/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 1.0/16 CPUs, 0/0 GPUs, 0.0/4.72 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/kai/ray_results/tune_mnist_asha
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+
| Trial name                   | status   | loc             |   layer_1_size |   layer_2_size |         lr |   batch_size |
|------------------------------+----------+-----------------+----------------+----------------+------------+--------------|
| train_mnist_tune_727f7_00000 | RUNNING  | 127.0.0.1:52355 |            128 |            128 | 0.00165008 |           64 |
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+


== Status ==
Current time: 2022-07-22 16:25:18 (running for 00:00:20.06)
Memory usage on this node: 11.7/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 1.0/16 CPUs, 0/0 GPUs, 0.0/4.72 GiB heap, 0.0/2.0 GiB objects
Result logdir: /Users/kai/ray_results/tune_mnist_asha
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+
| Trial name                   | status   | loc             |   layer_1_size |   layer_2_size |         lr |   batch_size |
|------------------------------+----------+-----------------+----------------+----------------+------------+--------------|
| train_mnist_tune_727f7_00000 | RUNNING  | 127.0.0.1:52355 |            128 |            128 | 0.00165008 |           64 |
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+


Result for train_mnist_tune_727f7_00000:
  date: 2022-07-22_16-25-19
  done: false
  experiment_id: d137534eb136478c9e9c4514a538f9da
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 1
  loss: 0.12323953211307526
  mean_accuracy: 0.9600474834442139
  node_ip: 127.0.0.1
  pid: 52355
  time_since_restore: 11.140714168548584
  time_this_iter_s: 11.140714168548584
  time_total_s: 11.140714168548584
  timestamp: 1658503519
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: 727f7_00000
  warmup_time: 0.0038437843322753906
  
== Status ==
Current time: 2022-07-22 16:25:24 (running for 00:00:26.19)
Memory usage on this node: 12.0/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4.000: None | Iter 2.000: None | Iter 1.000: -0.12323953211307526
Resources requested: 1.0/16 CPUs, 0/0 GPUs, 0.0/4.72 GiB heap, 0.0/2.0 GiB objects
Current best trial: 727f7_00000 with loss=0.12323953211307526 and parameters={'layer_1_size': 128, 'layer_2_size': 128, 'lr': 0.001650077499050015, 'batch_size': 64}
Result logdir: /Users/kai/ray_results/tune_mnist_asha
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+---------+-----------------+----------------------+
| Trial name                   | status   | loc             |   layer_1_size |   layer_2_size |         lr |   batch_size |    loss |   mean_accuracy |   training_iteration |
|------------------------------+----------+-----------------+----------------+----------------+------------+--------------+---------+-----------------+----------------------|
| train_mnist_tune_727f7_00000 | RUNNING  | 127.0.0.1:52355 |            128 |            128 | 0.00165008 |           64 | 0.12324 |        0.960047 |                    1 |
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+---------+-----------------+----------------------+


== Status ==
Current time: 2022-07-22 16:25:29 (running for 00:00:31.21)
Memory usage on this node: 11.9/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4.000: None | Iter 2.000: None | Iter 1.000: -0.12323953211307526
Resources requested: 1.0/16 CPUs, 0/0 GPUs, 0.0/4.72 GiB heap, 0.0/2.0 GiB objects
Current best trial: 727f7_00000 with loss=0.12323953211307526 and parameters={'layer_1_size': 128, 'layer_2_size': 128, 'lr': 0.001650077499050015, 'batch_size': 64}
Result logdir: /Users/kai/ray_results/tune_mnist_asha
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+---------+-----------------+----------------------+
| Trial name                   | status   | loc             |   layer_1_size |   layer_2_size |         lr |   batch_size |    loss |   mean_accuracy |   training_iteration |
|------------------------------+----------+-----------------+----------------+----------------+------------+--------------+---------+-----------------+----------------------|
| train_mnist_tune_727f7_00000 | RUNNING  | 127.0.0.1:52355 |            128 |            128 | 0.00165008 |           64 | 0.12324 |        0.960047 |                    1 |
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+---------+-----------------+----------------------+


Result for train_mnist_tune_727f7_00000:
  date: 2022-07-22_16-25-31
  done: false
  experiment_id: d137534eb136478c9e9c4514a538f9da
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 2
  loss: 0.09032993763685226
  mean_accuracy: 0.9731012582778931
  node_ip: 127.0.0.1
  pid: 52355
  time_since_restore: 22.593159914016724
  time_this_iter_s: 11.45244574546814
  time_total_s: 22.593159914016724
  timestamp: 1658503531
  timesteps_since_restore: 0
  training_iteration: 2
  trial_id: 727f7_00000
  warmup_time: 0.0038437843322753906
  
== Status ==
Current time: 2022-07-22 16:25:36 (running for 00:00:37.64)
Memory usage on this node: 12.1/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4.000: None | Iter 2.000: -0.09032993763685226 | Iter 1.000: -0.12323953211307526
Resources requested: 1.0/16 CPUs, 0/0 GPUs, 0.0/4.72 GiB heap, 0.0/2.0 GiB objects
Current best trial: 727f7_00000 with loss=0.09032993763685226 and parameters={'layer_1_size': 128, 'layer_2_size': 128, 'lr': 0.001650077499050015, 'batch_size': 64}
Result logdir: /Users/kai/ray_results/tune_mnist_asha
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+-----------+-----------------+----------------------+
| Trial name                   | status   | loc             |   layer_1_size |   layer_2_size |         lr |   batch_size |      loss |   mean_accuracy |   training_iteration |
|------------------------------+----------+-----------------+----------------+----------------+------------+--------------+-----------+-----------------+----------------------|
| train_mnist_tune_727f7_00000 | RUNNING  | 127.0.0.1:52355 |            128 |            128 | 0.00165008 |           64 | 0.0903299 |        0.973101 |                    2 |
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+-----------+-----------------+----------------------+


Result for train_mnist_tune_727f7_00000:
  date: 2022-07-22_16-25-42
  done: false
  experiment_id: d137534eb136478c9e9c4514a538f9da
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 3
  loss: 0.09614239633083344
  mean_accuracy: 0.9754746556282043
  node_ip: 127.0.0.1
  pid: 52355
  time_since_restore: 33.65132713317871
  time_this_iter_s: 11.058167219161987
  time_total_s: 33.65132713317871
  timestamp: 1658503542
  timesteps_since_restore: 0
  training_iteration: 3
  trial_id: 727f7_00000
  warmup_time: 0.0038437843322753906
  
== Status ==
Current time: 2022-07-22 16:25:47 (running for 00:00:48.70)
Memory usage on this node: 12.1/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4.000: None | Iter 2.000: -0.09032993763685226 | Iter 1.000: -0.12323953211307526
Resources requested: 1.0/16 CPUs, 0/0 GPUs, 0.0/4.72 GiB heap, 0.0/2.0 GiB objects
Current best trial: 727f7_00000 with loss=0.09614239633083344 and parameters={'layer_1_size': 128, 'layer_2_size': 128, 'lr': 0.001650077499050015, 'batch_size': 64}
Result logdir: /Users/kai/ray_results/tune_mnist_asha
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+-----------+-----------------+----------------------+
| Trial name                   | status   | loc             |   layer_1_size |   layer_2_size |         lr |   batch_size |      loss |   mean_accuracy |   training_iteration |
|------------------------------+----------+-----------------+----------------+----------------+------------+--------------+-----------+-----------------+----------------------|
| train_mnist_tune_727f7_00000 | RUNNING  | 127.0.0.1:52355 |            128 |            128 | 0.00165008 |           64 | 0.0961424 |        0.975475 |                    3 |
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+-----------+-----------------+----------------------+


Result for train_mnist_tune_727f7_00000:
  date: 2022-07-22_16-25-53
  done: false
  experiment_id: d137534eb136478c9e9c4514a538f9da
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 4
  loss: 0.09530177712440491
  mean_accuracy: 0.9760680198669434
  node_ip: 127.0.0.1
  pid: 52355
  time_since_restore: 44.58630990982056
  time_this_iter_s: 10.934982776641846
  time_total_s: 44.58630990982056
  timestamp: 1658503553
  timesteps_since_restore: 0
  training_iteration: 4
  trial_id: 727f7_00000
  warmup_time: 0.0038437843322753906
  
== Status ==
Current time: 2022-07-22 16:25:58 (running for 00:00:59.63)
Memory usage on this node: 11.9/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4.000: -0.09530177712440491 | Iter 2.000: -0.09032993763685226 | Iter 1.000: -0.12323953211307526
Resources requested: 1.0/16 CPUs, 0/0 GPUs, 0.0/4.72 GiB heap, 0.0/2.0 GiB objects
Current best trial: 727f7_00000 with loss=0.09530177712440491 and parameters={'layer_1_size': 128, 'layer_2_size': 128, 'lr': 0.001650077499050015, 'batch_size': 64}
Result logdir: /Users/kai/ray_results/tune_mnist_asha
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+-----------+-----------------+----------------------+
| Trial name                   | status   | loc             |   layer_1_size |   layer_2_size |         lr |   batch_size |      loss |   mean_accuracy |   training_iteration |
|------------------------------+----------+-----------------+----------------+----------------+------------+--------------+-----------+-----------------+----------------------|
| train_mnist_tune_727f7_00000 | RUNNING  | 127.0.0.1:52355 |            128 |            128 | 0.00165008 |           64 | 0.0953018 |        0.976068 |                    4 |
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+-----------+-----------------+----------------------+


Result for train_mnist_tune_727f7_00000:
  date: 2022-07-22_16-26-04
  done: false
  experiment_id: d137534eb136478c9e9c4514a538f9da
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 5
  loss: 0.10016436874866486
  mean_accuracy: 0.9750791192054749
  node_ip: 127.0.0.1
  pid: 52355
  time_since_restore: 55.61101007461548
  time_this_iter_s: 11.024700164794922
  time_total_s: 55.61101007461548
  timestamp: 1658503564
  timesteps_since_restore: 0
  training_iteration: 5
  trial_id: 727f7_00000
  warmup_time: 0.0038437843322753906
  
== Status ==
Current time: 2022-07-22 16:26:09 (running for 00:01:10.66)
Memory usage on this node: 12.0/16.0 GiB
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 4.000: -0.09530177712440491 | Iter 2.000: -0.09032993763685226 | Iter 1.000: -0.12323953211307526
Resources requested: 1.0/16 CPUs, 0/0 GPUs, 0.0/4.72 GiB heap, 0.0/2.0 GiB objects
Current best trial: 727f7_00000 with loss=0.10016436874866486 and parameters={'layer_1_size': 128, 'layer_2_size': 128, 'lr': 0.001650077499050015, 'batch_size': 64}
Result logdir: /Users/kai/ray_results/tune_mnist_asha
Number of trials: 1/1 (1 RUNNING)
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+----------+-----------------+----------------------+
| Trial name                   | status   | loc             |   layer_1_size |   layer_2_size |         lr |   batch_size |     loss |   mean_accuracy |   training_iteration |
|------------------------------+----------+-----------------+----------------+----------------+------------+--------------+----------+-----------------+----------------------|
| train_mnist_tune_727f7_00000 | RUNNING  | 127.0.0.1:52355 |            128 |            128 | 0.00165008 |           64 | 0.100164 |        0.975079 |                    5 |
+------------------------------+----------+-----------------+----------------+----------------+------------+--------------+----------+-----------------+----------------------+


2022-07-22 16:26:15,433	INFO tune.py:738 -- Total run time: 76.74 seconds (76.61 seconds for the tuning loop).
Result for train_mnist_tune_727f7_00000:
  date: 2022-07-22_16-26-15
  done: true
  experiment_id: d137534eb136478c9e9c4514a538f9da
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 6
  loss: 0.10947871953248978
  mean_accuracy: 0.9756724834442139
  node_ip: 127.0.0.1
  pid: 52355
  time_since_restore: 66.58598804473877
  time_this_iter_s: 10.974977970123291
  time_total_s: 66.58598804473877
  timestamp: 1658503575
  timesteps_since_restore: 0
  training_iteration: 6
  trial_id: 727f7_00000
  warmup_time: 0.0038437843322753906
  
== Status ==
Current time: 2022-07-22 16:26:15 (running for 00:01:16.62)
Memory usage on this node: 12.1/16.0 GiB
Using AsyncHyperBand: num_stopped=1
Bracket: Iter 4.000: -0.09530177712440491 | Iter 2.000: -0.09032993763685226 | Iter 1.000: -0.12323953211307526
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/4.72 GiB heap, 0.0/2.0 GiB objects
Current best trial: 727f7_00000 with loss=0.10947871953248978 and parameters={'layer_1_size': 128, 'layer_2_size': 128, 'lr': 0.001650077499050015, 'batch_size': 64}
Result logdir: /Users/kai/ray_results/tune_mnist_asha
Number of trials: 1/1 (1 TERMINATED)
+------------------------------+------------+-----------------+----------------+----------------+------------+--------------+----------+-----------------+----------------------+
| Trial name                   | status     | loc             |   layer_1_size |   layer_2_size |         lr |   batch_size |     loss |   mean_accuracy |   training_iteration |
|------------------------------+------------+-----------------+----------------+----------------+------------+--------------+----------+-----------------+----------------------|
| train_mnist_tune_727f7_00000 | TERMINATED | 127.0.0.1:52355 |            128 |            128 | 0.00165008 |           64 | 0.109479 |        0.975672 |                    6 |
+------------------------------+------------+-----------------+----------------+----------------+------------+--------------+----------+-----------------+----------------------+


Best hyperparameters found were:  {'layer_1_size': 128, 'layer_2_size': 128, 'lr': 0.001650077499050015, 'batch_size': 64}
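The Bracket lines in the status output record ASHA's cutoff at each rung. Here the rungs sit at 1, 2 and 4 iterations, which is consistent with a grace period of 1 and a reduction factor of 2. The values are negative because, with mode="min", the scheduler appears to negate the loss internally so that a higher score is always better. The following stdlib-only sketch models the rule as we understand it; it is a simplification, not Tune's implementation, and percentile, rung_cutoff and reach_rung are hypothetical names. When a trial reaches a rung, it is stopped if its score falls below the (1 - 1/reduction_factor) percentile of the scores already recorded there. With a single trial there is nothing to compare against, so it is never stopped early, which matches the run above.

```python
import math


def percentile(values, q):
    """Percentile with linear interpolation (NumPy's default rule)."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def rung_cutoff(recorded, reduction_factor=2):
    """Score a trial must reach to continue past this rung (None if empty)."""
    if not recorded:
        return None
    return percentile(recorded, (1 - 1 / reduction_factor) * 100)


def reach_rung(recorded, loss, reduction_factor=2):
    """Return True if a trial reaching this rung with `loss` should stop.

    Mirrors mode="min" by negating the loss, so higher scores are better.
    The cutoff uses only scores recorded before this trial arrived;
    the trial's own score is recorded afterwards.
    """
    score = -loss
    cutoff = rung_cutoff(recorded, reduction_factor)
    stop = cutoff is not None and score < cutoff
    recorded.append(score)
    return stop


# Replay the single trial from the run above at the 1-iteration rung.
rung_1 = []
print(reach_rung(rung_1, 0.12323953211307526), rung_cutoff(rung_1))
# prints: False -0.12323953211307526

# A hypothetical second trial with a worse loss would be stopped here.
print(reach_rung(rung_1, 0.2))  # prints: True
```

Keeping only the scores recorded before a trial arrives means the first trial at each rung always continues, and later trials are judged against everyone who got there before them.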
The call to tune_mnist_pbt fails in this run with TypeError: __init__() got an unexpected keyword argument 'tune_mnist_pbt'. The progress reporter is passed to air.RunConfig as tune_mnist_pbt=reporter, but air.RunConfig has no such parameter; the keyword should be progress_reporter=reporter. The run name "tune_mnist_asha" in the same call was most likely meant to be "tune_mnist_pbt", so that the PBT results do not end up in the ASHA experiment directory.

If you have more resources available (e.g. a GPU), you can adjust the arguments of the calls above accordingly, for example by raising gpus_per_trial and num_samples.
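As a rough guide to choosing these values: Tune only starts a trial once the resources it requests are free, so the number of trials running at the same time is bounded by the available CPUs and GPUs divided by the per-trial request (the status lines above show the trial requesting 1 CPU out of 16). The helper below, max_concurrent_trials, is a hypothetical back-of-the-envelope function, not part of Tune, and it ignores memory limits and other workloads on the cluster. Ray also accepts fractional GPU requests, so gpus_per_trial=0.5 lets two trials share one GPU.

```python
import math


def max_concurrent_trials(total_cpus, total_gpus, cpus_per_trial,
                          gpus_per_trial, num_samples):
    """Hypothetical helper: upper bound on trials running at once.

    Assumes cpus_per_trial > 0 and only looks at CPU/GPU counts.
    """
    limits = [num_samples, math.floor(total_cpus / cpus_per_trial)]
    if gpus_per_trial > 0:
        # Fractional requests (e.g. 0.5) let several trials share a GPU.
        limits.append(math.floor(total_gpus / gpus_per_trial))
    return min(limits)


# 16 CPUs, no GPUs, 1 CPU per trial, 10 samples.
print(max_concurrent_trials(16, 0, 1, 0, num_samples=10))    # 10
# Same machine with 2 GPUs and one GPU per trial.
print(max_concurrent_trials(16, 2, 1, 1, num_samples=10))    # 2
# Two trials sharing each GPU.
print(max_concurrent_trials(16, 2, 1, 0.5, num_samples=10))  # 4
```

Trials beyond this bound are not lost; they wait in the PENDING state until resources free up.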

An example output of a Population Based Training run with 10 samples of 10 iterations each could look like this:

+-----------------------------------------+------------+-------+----------------+----------------+-----------+--------------+-----------+-----------------+----------------------+
| Trial name                              | status     | loc   |   layer_1_size |   layer_2_size |        lr |   batch_size |      loss |   mean_accuracy |   training_iteration |
|-----------------------------------------+------------+-------+----------------+----------------+-----------+--------------+-----------+-----------------+----------------------|
| train_mnist_tune_checkpoint_85489_00000 | TERMINATED |       |            128 |            128 | 0.001     |           64 | 0.108734  |        0.973101 |                   10 |
| train_mnist_tune_checkpoint_85489_00001 | TERMINATED |       |            128 |            128 | 0.001     |           64 | 0.093577  |        0.978639 |                   10 |
| train_mnist_tune_checkpoint_85489_00002 | TERMINATED |       |            128 |            256 | 0.0008    |           32 | 0.0922348 |        0.979299 |                   10 |
| train_mnist_tune_checkpoint_85489_00003 | TERMINATED |       |             64 |            256 | 0.001     |           64 | 0.124648  |        0.973892 |                   10 |
| train_mnist_tune_checkpoint_85489_00004 | TERMINATED |       |            128 |             64 | 0.001     |           64 | 0.101717  |        0.975079 |                   10 |
| train_mnist_tune_checkpoint_85489_00005 | TERMINATED |       |             64 |             64 | 0.001     |           64 | 0.121467  |        0.969146 |                   10 |
| train_mnist_tune_checkpoint_85489_00006 | TERMINATED |       |            128 |            256 | 0.00064   |           32 | 0.053446  |        0.987062 |                   10 |
| train_mnist_tune_checkpoint_85489_00007 | TERMINATED |       |            128 |            256 | 0.001     |           64 | 0.129804  |        0.973497 |                   10 |
| train_mnist_tune_checkpoint_85489_00008 | TERMINATED |       |             64 |            256 | 0.0285125 |          128 | 0.363236  |        0.913867 |                   10 |
| train_mnist_tune_checkpoint_85489_00009 | TERMINATED |       |             32 |            256 | 0.001     |           64 | 0.150946  |        0.964201 |                   10 |
+-----------------------------------------+------------+-------+----------------+----------------+-----------+--------------+-----------+-----------------+----------------------+

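As the table shows, every trial ran the full 10 iterations: PBT does not terminate trials early, it adapts their hyperparameters while they train. In several runs the parameters were perturbed, which shows in learning rates such as 0.0008, 0.00064 and 0.0285125 and batch sizes of 32 and 128 next to the lr 0.001 / batch size 64 combination that most trials share. Most trials reached a mean validation accuracy between 0.964 and 0.987; the exception is trial 00008 (lr 0.0285125, batch size 128), which reached only 0.914. The best configuration, trial 00006 (layer sizes 128 and 256, lr 0.00064, batch size 32), reached the lowest loss, 0.053446, and a mean validation accuracy of 0.987062.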
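The perturbations in the table can be illustrated with a simplified sketch of PBT's explore step. In PBT, a poorly performing trial periodically takes over the weights of a better one (exploit) and then perturbs the copied hyperparameters (explore); the sketch covers only the second half. We assume what we understand to be Tune's defaults, multiplying continuous values by 1.2 or 0.8 and resampling with probability 0.25, and for list-valued parameters we step to a neighbouring entry. This is an illustration, not Tune's code; explore and the mutation space below are hypothetical. Under this rule, two downward perturbations of 0.001 give 0.001 × 0.8 × 0.8 = 0.00064, one of the learning rates in the table.

```python
import math
import random


def explore(config, mutations, rng, resample_probability=0.25,
            perturbation_factors=(1.2, 0.8)):
    """Simplified PBT explore step (illustrative, not Tune's implementation).

    `mutations` maps a hyperparameter name to either a list of allowed
    values or a zero-argument function that samples a fresh value.
    """
    new_config = dict(config)  # leave the caller's config untouched
    for key, space in mutations.items():
        if isinstance(space, list):
            if rng.random() < resample_probability or config[key] not in space:
                new_config[key] = rng.choice(space)
            else:
                # Step to a neighbouring list entry, clamped at the ends.
                idx = space.index(config[key]) + rng.choice((-1, 1))
                new_config[key] = space[min(max(idx, 0), len(space) - 1)]
        elif rng.random() < resample_probability:
            new_config[key] = space()  # resample from the search space
        else:
            new_config[key] = config[key] * rng.choice(perturbation_factors)
    return new_config


rng = random.Random(0)
# Hypothetical mutation space for this example.
mutations = {
    "lr": lambda: 10 ** rng.uniform(-4, -1),  # log-uniform in [1e-4, 1e-1]
    "batch_size": [32, 64, 128],
}
config = {"layer_1_size": 128, "layer_2_size": 256, "lr": 0.001, "batch_size": 64}
print(explore(config, mutations, rng))
```

Hyperparameters that are not listed in mutations, such as the layer sizes here, are copied unchanged; occasional resampling is what allows a large jump like the 0.0285125 learning rate of trial 00008.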

In summary, PyTorch Lightning modules are easy to extend for use with Tune: it took only one or two imported callbacks and a small wrapper function to find well-performing parameter configurations, without changing the LightningModule itself.

More PyTorch Lightning Examples#