Using Keras & TensorFlow with Tune

Prerequisites

  • pip install "ray[tune]" tensorflow==2.18.0 filelock

Example

import os

from filelock import FileLock
from tensorflow.keras.datasets import mnist

from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.air.integrations.keras import ReportCheckpointCallback


def train_mnist(config):
    # Import TensorFlow inside the trainable function as a workaround for
    # https://github.com/tensorflow/tensorflow/issues/32159
    import tensorflow as tf

    batch_size = 128
    num_classes = 10
    epochs = 12

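    # Hold a file lock so that concurrent trials do not download MNIST at the same time.
    with FileLock(os.path.expanduser("~/.data.lock")):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Scale pixel values from [0, 255] to [0, 1].
    x_train, x_test = x_train / 255.0, x_test / 255.0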
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(config["hidden"], activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(num_classes, activation="softmax"),
        ]
    )

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=config["learning_rate"], momentum=config["momentum"]
        ),
        metrics=["accuracy"],
    )

    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
        validation_data=(x_test, y_test),
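        # Report the Keras "accuracy" log value (training accuracy) to Tune
        # under the name "accuracy" at the end of every epoch.
        callbacks=[ReportCheckpointCallback(metrics={"accuracy": "accuracy"})],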
    )


def tune_mnist():
    sched = AsyncHyperBandScheduler(
        time_attr="training_iteration", max_t=400, grace_period=20
    )

    tuner = tune.Tuner(
        tune.with_resources(train_mnist, resources={"cpu": 2, "gpu": 0}),
        tune_config=tune.TuneConfig(
            metric="accuracy",
            mode="max",
            scheduler=sched,
            num_samples=10,
        ),
        run_config=tune.RunConfig(
            name="exp",
            stop={"accuracy": 0.99},
        ),
        param_space={
            # Constant entry; train_mnist does not read it.
            "threads": 2,
            "learning_rate": tune.uniform(0.001, 0.1),
            "momentum": tune.uniform(0.1, 0.9),
            "hidden": tune.randint(32, 512),
        },
    )
    results = tuner.fit()
    return results


results = tune_mnist()
best = results.get_best_result()
print(
    f"Best hyperparameters found were: {best.config} | "
    f"Accuracy: {best.metrics['accuracy']}"
)

Tune Status

Current time: 2025-02-13 15:22:41
Running for: 00:00:41.76
Memory: 21.4/36.0 GiB

System Info

Using AsyncHyperBand: num_stopped=0
Bracket: Iter 320.000: None | Iter 80.000: None | Iter 20.000: None
Logical resource usage: 2.0/12 CPUs, 0/0 GPUs

Trial Status

Trial name               status      loc              hidden  learning_rate  momentum  iter  total time (s)  accuracy
train_mnist_533a2_00000  TERMINATED  127.0.0.1:36365  371     0.0799367      0.588387  12    20.8515         0.984583
train_mnist_533a2_00001  TERMINATED  127.0.0.1:36364  266     0.0457424      0.22303   12    19.5277         0.96495
train_mnist_533a2_00002  TERMINATED  127.0.0.1:36368  157     0.0190286      0.537132  12    16.6606         0.95385
train_mnist_533a2_00003  TERMINATED  127.0.0.1:36363  451     0.0433488      0.18925   12    22.0514         0.966283
train_mnist_533a2_00004  TERMINATED  127.0.0.1:36367  276     0.0336728      0.430171  12    20.0884         0.964767
train_mnist_533a2_00005  TERMINATED  127.0.0.1:36366  208     0.071015       0.419166  12    17.933          0.976083
train_mnist_533a2_00006  TERMINATED  127.0.0.1:36475  312     0.00692959     0.714595  12    13.058          0.944017
train_mnist_533a2_00007  TERMINATED  127.0.0.1:36479  169     0.0694114      0.664904  12    10.7991         0.9803
train_mnist_533a2_00008  TERMINATED  127.0.0.1:36486  389     0.0370836      0.665592  12    14.018          0.977833
train_mnist_533a2_00009  TERMINATED  127.0.0.1:36487  389     0.0676138      0.52372   12    14.0043         0.981833
2025-02-13 15:22:41,843	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/rdecal/ray_results/exp' in 0.0048s.
2025-02-13 15:22:41,846	INFO tune.py:1041 -- Total run time: 41.77 seconds (41.75 seconds for the tuning loop).
Best hyperparameters found were: {'threads': 2, 'learning_rate': 0.07993666231835218, 'momentum': 0.5883866709655042, 'hidden': 371} | Accuracy: 0.98458331823349
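The "Bracket: Iter 320.000 | Iter 80.000 | Iter 20.000" line in the status output lists the iterations at which AsyncHyperBandScheduler compares trials against each other. The following is a minimal sketch of where those numbers come from, assuming the scheduler's default reduction factor of 4 and a single bracket (the example sets neither). The helper asha_rungs is hypothetical and written for illustration; it is not a Ray function.

```python
import math


def asha_rungs(grace_period, max_t, reduction_factor=4):
    """Hypothetical helper: iterations at which trials are compared.

    Assumes a single bracket and a reduction factor of 4, which the
    example above does not set explicitly.
    """
    n_rungs = int(math.log(max_t / grace_period) / math.log(reduction_factor) + 1)
    return [grace_period * reduction_factor**k for k in range(n_rungs)]


# Values taken from the AsyncHyperBandScheduler call in the example.
rungs = asha_rungs(grace_period=20, max_t=400)
print(rungs)  # [20, 80, 320]

# Each trial reports once per epoch and trains for 12 epochs.
epochs = 12
reached = [r for r in rungs if r <= epochs]
print(reached)  # []
```

Under these assumptions no trial reaches the first rung at iteration 20, which is consistent with num_stopped=0 and every trial showing iter 12 in the table. For the scheduler to stop trials early in this example, grace_period and max_t would need to be on the scale of the 12 reported epochs.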

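Because the search space is sampled randomly, the exact values differ from run to run. The final line of the output should look something like: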

Best hyperparameters found were:  {'threads': 2, 'learning_rate': 0.07607440973606909, 'momentum': 0.7715363277240616, 'hidden': 452} | Accuracy: 0.98458331823349
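The reported best result can be cross-checked against the trial table in the run output. The sketch below assumes that, with metric="accuracy" and mode="max", the best trial is the one whose last reported accuracy is highest. The helper best_trial is hypothetical and only mimics that selection on values copied from the table; it does not call Ray.

```python
# (trial name, last reported accuracy), copied from the trial table above.
trials = [
    ("train_mnist_533a2_00000", 0.984583),
    ("train_mnist_533a2_00001", 0.96495),
    ("train_mnist_533a2_00002", 0.95385),
    ("train_mnist_533a2_00003", 0.966283),
    ("train_mnist_533a2_00004", 0.964767),
    ("train_mnist_533a2_00005", 0.976083),
    ("train_mnist_533a2_00006", 0.944017),
    ("train_mnist_533a2_00007", 0.9803),
    ("train_mnist_533a2_00008", 0.977833),
    ("train_mnist_533a2_00009", 0.981833),
]


def best_trial(trials, mode="max"):
    """Hypothetical helper: return the (name, metric) pair with the best value."""
    pick = max if mode == "max" else min
    return pick(trials, key=lambda t: t[1])


name, accuracy = best_trial(trials)
print(name, accuracy)  # train_mnist_533a2_00000 0.984583
```

The selected trial is the one with hidden=371, learning_rate=0.0799367, and momentum=0.588387, matching the configuration printed at the end of the run output.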

More Keras and TensorFlow Examples

  • Memory NN Example: Example of training a Memory NN on bAbI with Keras using PBT.

  • TensorFlow MNIST Example: Converts the Advanced TF2.0 MNIST example to use Tune with the Trainable API. This uses tf.function. Original code from TensorFlow: https://www.tensorflow.org/tutorials/quickstart/advanced

  • Keras Cifar10 Example: A contributed example of tuning a Keras model on CIFAR10 with the PopulationBasedTraining scheduler.