"""An example showing how to use Pytorch Lightning training, Ray Tune
HPO, and MLflow autologging all together."""
import os
import tempfile
import lightning.pytorch as pl
import mlflow
from ray import tune
from ray.air.integrations.mlflow import setup_mlflow
from ray.tune.examples.mnist_ptl_mini import LightningMNISTClassifier, MNISTDataModule
from ray.tune.integration.pytorch_lightning import TuneReportCallback
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
    """Train one MNIST classifier trial and report validation metrics to Tune.

    MLflow is configured for the trial with ``setup_mlflow`` and
    ``mlflow.pytorch.autolog()``. ``TuneReportCallback`` reports the
    validation metrics back to Ray Tune at the end of each validation pass.

    Args:
        config: Trial hyperparameters sampled by Tune. Must contain
            ``"batch_size"``. May contain ``"experiment_name"`` and
            ``"tracking_uri"``, which are forwarded to ``setup_mlflow``.
            The whole dict is also passed to ``LightningMNISTClassifier``.
        data_dir: Directory holding the MNIST data.
        num_epochs: Maximum number of training epochs.
        num_gpus: GPUs allotted to this trial. ``0`` trains on CPU.
    """
    setup_mlflow(
        config,
        experiment_name=config.get("experiment_name", None),
        tracking_uri=config.get("tracking_uri", None),
    )

    model = LightningMNISTClassifier(config, data_dir)
    dm = MNISTDataModule(
        data_dir=data_dir, num_workers=1, batch_size=config["batch_size"]
    )
    # Tune metric name -> metric key logged by the Lightning module.
    metrics = {"loss": "ptl/val_loss", "acc": "ptl/val_accuracy"}
    mlflow.pytorch.autolog()

    # Lightning 2.x (``lightning.pytorch``) removed the ``gpus`` and
    # ``progress_bar_refresh_rate`` Trainer arguments. Hardware is now chosen
    # with ``accelerator``/``devices``, and the bar is disabled with
    # ``enable_progress_bar``.
    use_gpu = num_gpus > 0
    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator="gpu" if use_gpu else "cpu",
        # Ray can assign fractional GPUs (e.g. 0.5), but Lightning needs a
        # whole device count, so use at least one device.
        devices=max(1, int(num_gpus)) if use_gpu else 1,
        enable_progress_bar=False,
        callbacks=[TuneReportCallback(metrics, on="validation_end")],
    )
    trainer.fit(model, dm)
def tune_mnist(
    num_samples=10,
    num_epochs=10,
    gpus_per_trial=0,
    tracking_uri=None,
    experiment_name="ptl_autologging_example",
):
    """Run a Ray Tune search over MNIST hyperparameters with MLflow logging.

    Args:
        num_samples: Number of trials Tune samples from the search space.
        num_epochs: Training epochs per trial.
        gpus_per_trial: GPUs reserved for each trial, alongside 1 CPU.
        tracking_uri: MLflow tracking URI, passed to
            ``mlflow.set_tracking_uri``.
        experiment_name: MLflow experiment to log to. It is created if it
            does not exist.

    Returns:
        The result grid from ``tuner.fit()``. The best configuration is
        also printed.
    """
    data_dir = os.path.join(tempfile.gettempdir(), "mnist_data_")
    # Download the data once, in the driver, before any trial starts.
    MNISTDataModule(data_dir=data_dir, batch_size=32).prepare_data()

    # Set the MLflow experiment, or create it if it does not exist.
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)

    config = {
        "layer_1": tune.choice([32, 64, 128]),
        "layer_2": tune.choice([64, 128, 256]),
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([32, 64, 128]),
        "experiment_name": experiment_name,
        # Forward the driver's resolved URI so each trial's setup_mlflow
        # uses the same tracking store.
        "tracking_uri": mlflow.get_tracking_uri(),
        # Reuse the single data_dir instead of recomputing the path.
        "data_dir": data_dir,
        "num_epochs": num_epochs,
    }

    trainable = tune.with_parameters(
        train_mnist_tune,
        data_dir=data_dir,
        num_epochs=num_epochs,
        num_gpus=gpus_per_trial,
    )

    tuner = tune.Tuner(
        tune.with_resources(trainable, resources={"cpu": 1, "gpu": gpus_per_trial}),
        tune_config=tune.TuneConfig(
            metric="loss",
            mode="min",
            num_samples=num_samples,
        ),
        run_config=tune.RunConfig(
            name="tune_mnist",
        ),
        param_space=config,
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)
    return results
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    # Unrecognized command-line flags are ignored rather than rejected.
    known_args = cli.parse_known_args()[0]

    # Smoke test: one single-epoch trial, with MLflow runs written under the
    # temp directory. Otherwise: the full 10-trial, 10-epoch sweep.
    if known_args.smoke_test:
        run_kwargs = dict(
            num_samples=1,
            num_epochs=1,
            gpus_per_trial=0,
            tracking_uri=os.path.join(tempfile.gettempdir(), "mlruns"),
        )
    else:
        run_kwargs = dict(num_samples=10, num_epochs=10, gpus_per_trial=0)

    tune_mnist(**run_kwargs)