Using Keras & TensorFlow with Tune

Example
import argparse
import os
from filelock import FileLock
from tensorflow.keras.datasets import mnist
import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.integration.keras import TuneReportCallback
def train_mnist(config):
    """Train a small dense classifier on MNIST and report accuracy to Tune.

    Args:
        config: Trial configuration mapping. This function reads the keys
            ``"hidden"`` (width of the hidden Dense layer), ``"lr"`` (SGD
            learning rate) and ``"momentum"`` (SGD momentum).
    """
    # TensorFlow is imported inside the trainable instead of at module level;
    # see https://github.com/tensorflow/tensorflow/issues/32159
    import tensorflow as tf

    batch_size = 128
    num_classes = 10
    epochs = 12

    # Hold a file lock while loading the dataset so that several trials
    # running at once do not touch the dataset files simultaneously.
    with FileLock(os.path.expanduser("~/.data.lock")):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values from [0, 255] to [0, 1].
    x_train, x_test = x_train / 255.0, x_test / 255.0

    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(config["hidden"], activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )
    model.compile(
        loss="sparse_categorical_crossentropy",
        # `learning_rate` is the supported keyword; the old `lr` alias is
        # deprecated and is not accepted by newer Keras optimizers.
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=config["lr"], momentum=config["momentum"]
        ),
        metrics=["accuracy"],
    )
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
        validation_data=(x_test, y_test),
        # Forward the Keras "accuracy" metric to Tune as "mean_accuracy".
        callbacks=[TuneReportCallback({"mean_accuracy": "accuracy"})],
    )
def tune_mnist(num_training_iterations):
    """Run a Tune experiment over ``train_mnist`` and print the best config.

    Args:
        num_training_iterations: Value used for the ``training_iteration``
            stopping criterion passed to ``tune.run``.
    """
    # Hyperparameters sampled per trial; "threads" is a fixed value.
    search_space = {
        "threads": 2,
        "lr": tune.uniform(0.001, 0.1),
        "momentum": tune.uniform(0.1, 0.9),
        "hidden": tune.randint(32, 512),
    }
    stop_criteria = {
        "mean_accuracy": 0.99,
        "training_iteration": num_training_iterations,
    }
    scheduler = AsyncHyperBandScheduler(
        time_attr="training_iteration",
        max_t=400,
        grace_period=20,
    )

    results = tune.run(
        train_mnist,
        name="exp",
        scheduler=scheduler,
        metric="mean_accuracy",
        mode="max",
        stop=stop_criteria,
        num_samples=10,
        resources_per_trial={"cpu": 2, "gpu": 0},
        config=search_space,
    )
    print("Best hyperparameters found were: ", results.best_config)
if __name__ == "__main__":
    # Command-line entry point: parse flags, optionally initialise Ray,
    # then launch the tuning run.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--smoke-test",
        action="store_true",
        help="Finish quickly for testing",
    )
    arg_parser.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using Ray Client.",
    )
    # Unrecognised arguments are tolerated and discarded.
    cli_args, _unknown = arg_parser.parse_known_args()

    # A smoke test runs only a handful of iterations.
    if cli_args.smoke_test:
        iterations = 5
    else:
        iterations = 300

    if cli_args.smoke_test:
        ray.init(num_cpus=4)
    elif cli_args.server_address:
        ray.init(f"ray://{cli_args.server_address}")

    tune_mnist(num_training_iterations=iterations)
More Keras and TensorFlow Examples
Memory NN Example: Example of training a Memory NN on bAbI with Keras using PBT.
TensorFlow MNIST Example: Converts the Advanced TF2.0 MNIST example to use Tune with the Trainable. This uses tf.function. Original code from TensorFlow: https://www.tensorflow.org/tutorials/quickstart/advanced
Keras Cifar10 Example: A contributed example of tuning a Keras model on CIFAR10 with the PopulationBasedTraining scheduler.