# HyperBand Function Example

#!/usr/bin/env python

import argparse
import json
import os
import tempfile

import numpy as np

import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import HyperBandScheduler


def train_func(config):
    """Run a toy training loop that reports a synthetic metric to Tune.

    At each timestep the metric is ``tanh(timestep / width) * height``,
    where ``width`` and ``height`` are read from ``config`` and each
    defaults to 1 when absent. The value is reported under the key
    ``episode_reward_mean``.

    If ``train.get_checkpoint()`` returns a checkpoint, the loop resumes at
    the timestep following the one stored in its ``checkpoint.json``;
    otherwise it starts at 0. The loop ends before timestep 100.

    Args:
        config: Mapping of hyperparameters; the keys ``"width"`` and
            ``"height"`` are looked up with ``.get``.
    """
    # Work out where to start: 0 for a fresh run, or one past the saved step.
    first_step = 0
    restored = train.get_checkpoint()
    if restored:
        with restored.as_directory() as restored_dir:
            saved_state_path = os.path.join(restored_dir, "checkpoint.json")
            with open(saved_state_path) as state_file:
                first_step = json.load(state_file)["timestep"] + 1

    for timestep in range(first_step, 100):
        reward = np.tanh(float(timestep) / config.get("width", 1)) * config.get(
            "height", 1
        )

        # A checkpoint is attached on every third timestep (0, 3, 6, ...).
        # Checkpointing is only required for certain schedulers.
        with tempfile.TemporaryDirectory() as scratch_dir:
            if timestep % 3 != 0:
                snapshot = None
            else:
                new_state_path = os.path.join(scratch_dir, "checkpoint.json")
                with open(new_state_path, "w") as state_file:
                    json.dump({"timestep": timestep}, state_file)
                snapshot = Checkpoint.from_directory(scratch_dir)

            # The report is issued while the temporary directory still exists.
            # `episode_reward_mean` is the metric name used here; other
            # objectives such as loss or accuracy could be reported instead.
            train.report({"episode_reward_mean": reward}, checkpoint=snapshot)


if __name__ == "__main__":
    # Command line: a single optional flag that shortens the run.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    # parse_known_args returns unrecognized arguments separately; they are
    # discarded here.
    known_args, _ = cli.parse_known_args()

    # In smoke-test mode Ray is started with 4 CPUs and every trial is
    # stopped after 10 iterations; otherwise no CPU count is passed and the
    # iteration cap is 99999.
    if known_args.smoke_test:
        cpu_count = 4
        iteration_cap = 10
    else:
        cpu_count = None
        iteration_cap = 99999

    ray.init(num_cpus=cpu_count)

    # HyperBand early stopping. The objective (`episode_reward_mean`, to be
    # maximized) is supplied through TuneConfig below.
    scheduler = HyperBandScheduler(max_t=200)

    run_settings = train.RunConfig(
        name="hyperband_test",
        stop={"training_iteration": iteration_cap},
        failure_config=train.FailureConfig(fail_fast=True),
    )
    tune_settings = tune.TuneConfig(
        num_samples=20,
        metric="episode_reward_mean",
        mode="max",
        scheduler=scheduler,
    )
    # Only `height` is searched; `width` is left to train_func's default.
    search_space = {"height": tune.uniform(0, 100)}

    tuner = tune.Tuner(
        train_func,
        run_config=run_settings,
        tune_config=tune_settings,
        param_space=search_space,
    )
    result_grid = tuner.fit()
    best_result = result_grid.get_best_result()
    print("Best hyperparameters found were: ", best_result.config)