# MLflow PyTorch Lightning Example

"""An example showing how to use PyTorch Lightning training, Ray Tune
hyperparameter optimization (HPO), and MLflow autologging all together."""

import os
import tempfile

import lightning.pytorch as pl
import mlflow

from ray import tune
from ray.air.integrations.mlflow import setup_mlflow
from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier, MNISTDataModule
from ray.tune.integration.pytorch_lightning import TuneReportCallback


def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
    """Train a single MNIST classifier trial under Ray Tune with MLflow logging.

    Args:
        config: Trial config sampled by Tune. Must contain ``"batch_size"``.
            Optional ``"experiment_name"`` and ``"tracking_uri"`` entries are
            forwarded to ``setup_mlflow``. The whole dict is also handed to
            ``setup_mlflow`` and ``LightningMNISTClassifier``.
        data_dir: Directory for the MNIST data. It is passed to both the model
            and the data module.
        num_epochs: Maximum number of training epochs.
        num_gpus: GPUs allotted to this trial. ``0`` trains on CPU.
    """
    setup_mlflow(
        config,
        experiment_name=config.get("experiment_name", None),
        tracking_uri=config.get("tracking_uri", None),
    )

    model = LightningMNISTClassifier(config, data_dir)
    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"]
    )
    # Tune metric name -> Lightning logged metric key. The right-hand keys are
    # presumably logged by LightningMNISTClassifier; verify against that module.
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    mlflow.pytorch.autolog()

    # Lightning 2.x removed the ``gpus=`` and ``progress_bar_refresh_rate=``
    # Trainer arguments, and passing them raises TypeError. Choose the device
    # through ``accelerator``/``devices`` instead. Tune may allot a fractional
    # GPU (e.g. 0.5), which is still a single device from Lightning's view.
    if num_gpus > 0:
        device_kwargs = {"accelerator": "gpu", "devices": max(1, int(num_gpus))}
    else:
        device_kwargs = {"accelerator": "cpu", "devices": 1}

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        enable_progress_bar=False,
        callbacks=[TuneReportCallback(metrics, on="validation_end")],
        **device_kwargs,
    )
    trainer.fit(model, dm)


def tune_mnist(
    num_samples=10,
    num_epochs=10,
    gpus_per_trial=0,
    tracking_uri=None,
    experiment_name="ptl_autologging_example",
):
    """Run a Ray Tune search over the MNIST classifier with MLflow tracking.

    Args:
        num_samples: Number of trials to sample from the search space.
        num_epochs: Maximum training epochs per trial.
        gpus_per_trial: GPUs reserved for each trial. Each trial also
            reserves one CPU.
        tracking_uri: MLflow tracking URI, passed as-is to
            ``mlflow.set_tracking_uri``.
        experiment_name: MLflow experiment to log into. It is created if it
            does not exist.

    Returns:
        The result grid from ``Tuner.fit()``, so callers can inspect every
        trial and not just the best configuration printed here.
    """
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    # Download the data once on the driver, before any trial is launched.
    MNISTDataModule(data_dir=data_dir, batch_size=32).prepare_data()

    # Set the MLflow experiment, or create it if it does not exist.
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)

    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
        # Read back by setup_mlflow inside each trial. The resolved URI is
        # stored so that trials log to the same store the driver configured.
        "experiment_name": experiment_name,
        "tracking_uri": mlflow.get_tracking_uri(),
        # Reuse the path computed above rather than rebuilding it, so the
        # recorded value cannot drift from the directory trials actually use.
        "data_dir": data_dir,
        "num_epochs": num_epochs,
    }

    trainable = tune.with_parameters(
        train_mnist_tune,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )

    tuner = tune.Tuner(
        tune.with_resources(trainable, resources={"cpu": 1, "gpu": gpus_per_trial}),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            num_samples=num_samples,
        ),
        run_config=tune.RunConfig(
            name="tune_mnist",
        ),
        param_space=config,
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)
    return results


if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    # Unrecognized flags are tolerated and ignored.
    parsed, _unknown = cli.parse_known_args()

    # Smoke test: a single one-epoch trial that logs to an MLflow store under
    # the system temp dir. Full run: ten ten-epoch trials, with tracking_uri
    # left at tune_mnist's default.
    if parsed.smoke_test:
        run_kwargs = dict(
            num_samples=1,
            num_epochs=1,
            gpus_per_trial=0,
            tracking_uri=os.path.join(tempfile.gettempdir(), "mlruns"),
        )
    else:
        run_kwargs = dict(num_samples=10, num_epochs=10, gpus_per_trial=0)

    tune_mnist(**run_kwargs)