import argparse
import sys
import ray
from ray import tune
from ray.train import ScalingConfig
from ray.tune.tune_config import TuneConfig
from ray.tune.tuner import Tuner
if sys.version_info >= (3, 12):
    # Skip this test on Python 3.12+ because TensorFlow is not supported
    # there: exit successfully before any TensorFlow-dependent import runs.
    # ``sys.exit`` is used instead of the bare ``exit`` helper, which is only
    # injected by the ``site`` module for interactive sessions and is not
    # guaranteed to exist (e.g. under ``python -S``).
    sys.exit(0)
else:
    from ray.train.examples.tf.tensorflow_mnist_example import train_func
    from ray.train.tensorflow import TensorflowTrainer
def tune_tensorflow_mnist(
    num_workers: int = 2, num_samples: int = 2, use_gpu: bool = False
):
    """Run a hyperparameter search over the TensorFlow MNIST training loop.

    Builds a ``TensorflowTrainer`` around ``train_func``, wraps it in a
    ``Tuner``, runs the search, and prints the ``"accuracy"`` metric of the
    best result. Nothing is returned.

    Args:
        num_workers: Worker count handed to ``ScalingConfig``.
        num_samples: Sample count handed to ``TuneConfig``.
        use_gpu: Forwarded to ``ScalingConfig(use_gpu=...)``.
    """
    scaling = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)
    mnist_trainer = TensorflowTrainer(
        train_loop_per_worker=train_func,
        scaling_config=scaling,
    )

    # The best result is the one that maximizes the "accuracy" metric.
    search_settings = TuneConfig(
        num_samples=num_samples, metric="accuracy", mode="max"
    )

    # Search space for the trainer's ``train_loop_config``: learning rate is
    # sampled log-uniformly from [1e-4, 1e-1], batch size is picked from
    # {32, 64, 128}, and the epoch count is fixed at 3.
    search_space = {
        "train_loop_config": {
            "lr": tune.loguniform(1e-4, 1e-1),
            "batch_size": tune.choice([32, 64, 128]),
            "epochs": 3,
        }
    }

    tuner = Tuner(
        mnist_trainer,
        tune_config=search_settings,
        param_space=search_space,
    )
    results = tuner.fit()
    top_result = results.get_best_result()
    best_accuracy = top_result.metrics["accuracy"]
    print(f"Best accuracy config: {best_accuracy}")
if __name__ == "__main__":
    # Command-line entry point: parse the options, start (or connect to) Ray,
    # then launch the tuning run.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=2,
        help="Sets number of samples for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )
    args = parser.parse_args()

    if args.smoke_test:
        # The smoke test always runs a fixed, small configuration on a local
        # Ray instance and ignores --num-workers / --num-samples.
        smoke_num_workers = 2
        smoke_num_samples = 2
        # Size the declared GPU count from the worker count that is actually
        # used. Deriving it from --num-workers (as before) could declare fewer
        # GPUs than the two smoke-test workers request, e.g. with
        # ``--smoke-test --use-gpu -n 1``.
        num_gpus = smoke_num_workers if args.use_gpu else 0
        ray.init(num_cpus=8, num_gpus=num_gpus)
        tune_tensorflow_mnist(
            num_workers=smoke_num_workers,
            num_samples=smoke_num_samples,
            use_gpu=args.use_gpu,
        )
    else:
        # Full run: connect using --address (None when the flag is omitted)
        # and honor the user-supplied sizes.
        ray.init(address=args.address)
        tune_tensorflow_mnist(
            num_workers=args.num_workers,
            num_samples=args.num_samples,
            use_gpu=args.use_gpu,
        )