#!/usr/bin/env python
import argparse
import json
import os
import numpy as np
import ray
from ray import air, tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.tune.schedulers import HyperBandScheduler
def train(config, checkpoint_dir=None):
    """Toy Tune trainable that reports a tanh-shaped reward for up to 100 steps.

    Args:
        config: Hyperparameter dict. The optional keys ``width`` and
            ``height`` are read; both default to 1 when absent.
        checkpoint_dir: Legacy checkpoint directory. It is only consulted
            when the session holds no checkpoint, and is then expected to
            contain a JSON file named ``checkpoint`` with a ``timestep`` key.
    """
    step = 0

    # Checkpoints are written below through `session.report`, so they have to
    # be read back through the session as well: a dict checkpoint saved that
    # way is not laid out as the JSON `checkpoint` file the legacy
    # `checkpoint_dir` path reads.
    restored = session.get_checkpoint()
    if restored is not None:
        step = restored.to_dict()["timestep"]
    elif checkpoint_dir:
        # Backward-compatible fallback for callers that still pass a
        # directory produced by the old `tune.checkpoint_dir` API.
        with open(os.path.join(checkpoint_dir, "checkpoint")) as f:
            step = json.load(f)["timestep"]

    for timestep in range(step, 100):
        v = np.tanh(float(timestep) / config.get("width", 1))
        v *= config.get("height", 1)

        # Checkpoint the state of the training every 3 steps.
        # Note that this is only required for certain schedulers.
        checkpoint = None
        if timestep % 3 == 0:
            checkpoint = Checkpoint.from_dict({"timestep": timestep})

        # Here we use `episode_reward_mean`, but you can also report other
        # objectives such as loss or accuracy.
        session.report({"episode_reward_mean": v}, checkpoint=checkpoint)
if __name__ == "__main__":
    # Command-line options; unrecognised arguments are tolerated and ignored.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--smoke-test",
        action="store_true",
        help="Finish quickly for testing",
    )
    cli.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using Ray Client.",
    )
    opts = cli.parse_known_args()[0]

    # Connect through Ray Client when an address is given; otherwise start
    # Ray locally, capped at 4 CPUs for smoke tests.
    if args_address := opts.server_address:
        ray.init(f"ray://{args_address}")
    elif opts.server_address is not None:
        # An empty-string address is still "not None" and must be forwarded.
        ray.init(f"ray://{opts.server_address}")
    else:
        cpu_limit = 4 if opts.smoke_test else None
        ray.init(num_cpus=cpu_limit)

    # HyperBand early stopping. The objective (`episode_reward_mean`) and its
    # mode come from the TuneConfig below; the time unit is
    # `training_iteration`, which Tune fills in automatically.
    scheduler = HyperBandScheduler(max_t=200)

    iteration_cap = 10 if opts.smoke_test else 99999
    run_settings = air.RunConfig(
        name="hyperband_test",
        stop={"training_iteration": iteration_cap},
        failure_config=air.FailureConfig(fail_fast=True),
    )
    search_settings = tune.TuneConfig(
        num_samples=20,
        metric="episode_reward_mean",
        mode="max",
        scheduler=scheduler,
    )
    search_space = {"height": tune.uniform(0, 100)}

    experiment = tune.Tuner(
        train,
        run_config=run_settings,
        tune_config=search_settings,
        param_space=search_space,
    )
    outcome = experiment.fit()

    best_config = outcome.get_best_result().config
    print("Best hyperparameters found were: ", best_config)