#!/usr/bin/env python
import argparse
import json
import os
import tempfile
import numpy as np
import ray
from ray import train, tune
from ray.train import Checkpoint
from ray.tune.schedulers import HyperBandScheduler
def train_func(config):
    """Toy trainable that reports ``height * tanh(timestep / width)``.

    Runs timesteps up to (but not including) 100, reporting one
    ``episode_reward_mean`` value per timestep. If ``train.get_checkpoint()``
    returns a checkpoint, the loop resumes at the timestep following the one
    stored in its ``checkpoint.json``.

    Args:
        config: Mapping of hyperparameters. ``"width"`` and ``"height"`` are
            read with a default of 1 when absent.
    """
    first_step = 0
    restored = train.get_checkpoint()
    if restored:
        with restored.as_directory() as restored_dir:
            state_path = os.path.join(restored_dir, "checkpoint.json")
            with open(state_path) as state_file:
                # The stored timestep was already reported; continue after it.
                first_step = json.load(state_file)["timestep"] + 1

    for timestep in range(first_step, 100):
        reward = np.tanh(float(timestep) / config.get("width", 1))
        reward = reward * config.get("height", 1)

        # Checkpoint the state of the training every 3 steps.
        # Note that this is only required for certain schedulers.
        with tempfile.TemporaryDirectory() as scratch_dir:
            if timestep % 3 == 0:
                snapshot_path = os.path.join(scratch_dir, "checkpoint.json")
                with open(snapshot_path, "w") as snapshot_file:
                    json.dump({"timestep": timestep}, snapshot_file)
                snapshot = Checkpoint.from_directory(scratch_dir)
            else:
                snapshot = None

            # Report inside the `with` block: the temporary directory backing
            # the checkpoint is removed as soon as the block exits.
            # `episode_reward_mean` is the metric used here, but other
            # objectives such as loss or accuracy can be reported as well.
            train.report({"episode_reward_mean": reward}, checkpoint=snapshot)
if __name__ == "__main__":
    # Command line: a single optional flag; unrecognised arguments are
    # accepted by `parse_known_args` and discarded.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    options, _ = cli.parse_known_args()
    smoke_test = options.smoke_test

    ray.init(num_cpus=4 if smoke_test else None)

    # HyperBand early stopping. The objective (`episode_reward_mean`) is
    # supplied through the TuneConfig below, and `training_iteration`, which
    # Tune fills in automatically, serves as the time unit.
    scheduler = HyperBandScheduler(max_t=200)

    # A smoke test stops every trial after 10 iterations; otherwise the limit
    # is set high enough that it is never the reason a trial stops.
    iteration_limit = 10 if smoke_test else 99999
    run_settings = train.RunConfig(
        name="hyperband_test",
        stop={"training_iteration": iteration_limit},
        failure_config=train.FailureConfig(
            fail_fast=True,
        ),
    )

    # 20 samples, maximising the reported reward under the scheduler above.
    search_settings = tune.TuneConfig(
        num_samples=20,
        metric="episode_reward_mean",
        mode="max",
        scheduler=scheduler,
    )

    # Only `height` is searched; `width` falls back to the trainable's default.
    search_space = {"height": tune.uniform(0, 100)}

    tuner = tune.Tuner(
        train_func,
        run_config=run_settings,
        tune_config=search_settings,
        param_space=search_space,
    )
    outcome = tuner.fit()

    best_trial = outcome.get_best_result()
    print("Best hyperparameters found were: ", best_trial.config)