# MLflow PyTorch Lightning Example

"""An example showing how to use PyTorch Lightning training, Ray Tune
HPO, and MLflow autologging all together."""
import os
import tempfile

import pytorch_lightning as pl

import mlflow

from ray import train, tune
from ray.air.integrations.mlflow import setup_mlflow
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier, MNISTDataModule


def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
    """Fit the MNIST Lightning classifier for one trial with MLflow autologging.

    Args:
        config: Trial configuration dict. ``batch_size`` is read directly;
            ``experiment_name`` and ``tracking_uri`` are looked up optionally
            and forwarded to ``setup_mlflow``. The whole dict is also passed
            to ``setup_mlflow`` and ``LightningMNISTClassifier``.
        data_dir: Dataset directory handed to both the model and the
            data module.
        num_epochs: Used as the trainer's ``max_epochs``.
        num_gpus: Used as the trainer's ``gpus`` argument.
    """
    # MLflow is configured first so that everything below runs with the
    # experiment / tracking URI carried in the trial config.
    setup_mlflow(
        config,
        experiment_name=config.get("experiment_name"),
        tracking_uri=config.get("tracking_uri"),
    )

    classifier = LightningMNISTClassifier(config, data_dir)
    datamodule = MNISTDataModule(
        data_dir=data_dir,
        num_workers=1,
        batch_size=config["batch_size"],
    )

    mlflow.pytorch.autolog()

    # Keys are the names handed to Tune ("loss" is what the tuner optimizes);
    # values are the metric names the callback reads from Lightning.
    report_callback = TuneReportCallback(
        {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"},
        on="validation_end",
    )
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        gpus=num_gpus,
        progress_bar_refresh_rate=0,
        callbacks=[report_callback],
    )
    trainer.fit(classifier, datamodule)


def tune_mnist(
    num_samples=10,
    num_epochs=10,
    gpus_per_trial=0,
    tracking_uri=None,
    experiment_name="ptl_autologging_example",
):
    """Run a Ray Tune search over the MNIST classifier and print the best config.

    Args:
        num_samples: Passed to ``tune.TuneConfig`` as ``num_samples``.
        num_epochs: Forwarded to each trial and stored in the search space.
        gpus_per_trial: GPU resource request per trial; also forwarded to
            the training function as ``num_gpus``.
        tracking_uri: Value given to ``mlflow.set_tracking_uri``.
        experiment_name: MLflow experiment selected before tuning starts and
            stored in the search space for the trials.
    """
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")

    # Fetch the dataset once up front, before any trial is launched.
    MNISTDataModule(data_dir=data_dir, batch_size=32).prepare_data()

    # Select the MLflow experiment (created if it does not exist yet).
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)

    search_space = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
        "experiment_name": experiment_name,
        # Read back the URI MLflow is actually using so trials receive it.
        "tracking_uri": mlflow.get_tracking_uri(),
        "data_dir": data_dir,
        "num_epochs": num_epochs,
    }

    # Bind the fixed arguments, then attach the per-trial resource request.
    bound_trainable = tune.with_parameters(
        train_mnist_tune,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )
    resourced_trainable = tune.with_resources(
        bound_trainable,
        resources={"cpu": 1, "gpu": gpus_per_trial},
    )

    search_settings = tune.TuneConfig(
        metric="loss",
        mode="min",
        num_samples=num_samples,
    )
    run_settings = train.RunConfig(name="tune_mnist")

    tuner = tune.Tuner(
        resourced_trainable,
        tune_config=search_settings,
        run_config=run_settings,
        param_space=search_space,
    )
    results = tuner.fit()

    best_result = results.get_best_result()
    print("Best hyperparameters found were: ", best_result.config)


if __name__ == "__main__":
    import argparse

    # Command-line entry point: --smoke-test runs a single one-epoch trial.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    # Unrecognized arguments are tolerated rather than treated as errors.
    args, _ = cli.parse_known_args()

    if args.smoke_test:
        # The smoke test writes MLflow data under the system temp directory.
        run_kwargs = {
            "num_samples": 1,
            "num_epochs": 1,
            "gpus_per_trial": 0,
            "tracking_uri": os.path.join(tempfile.gettempdir(), "mlruns"),
        }
    else:
        run_kwargs = {
            "num_samples": 10,
            "num_epochs": 10,
            "gpus_per_trial": 0,
        }

    tune_mnist(**run_kwargs)