Get Started with Distributed Training using TensorFlow/Keras#

Ray Train’s TensorFlow integration enables you to scale your TensorFlow and Keras training functions to many machines and GPUs.

On a technical level, Ray Train schedules your training workers and sets the TF_CONFIG environment variable for you, so you can run your MultiWorkerMirroredStrategy training script unchanged. See Distributed training with TensorFlow for more information.
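For reference, TensorFlow's multi-worker strategies read the cluster layout from TF_CONFIG. The snippet below is only an illustrative sketch of that variable's shape for a two-worker job with made-up addresses; Ray Train generates and sets the real value on each worker, so you never write it yourself.

# Illustrative only: Ray Train sets TF_CONFIG on each worker automatically.
# For a two-worker job, worker 0 sees a value roughly shaped like this:
tf_config_worker_0 = {
    "cluster": {"worker": ["10.0.0.1:12345", "10.0.0.2:12345"]},
    "task": {"type": "worker", "index": 0},
}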

Most of the examples in this guide use TensorFlow with Keras, but Ray Train also works with vanilla TensorFlow.

Quickstart#

import ray
import tensorflow as tf

from ray import train
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
from ray.train.tensorflow.keras import ReportCheckpointCallback


# If using GPUs, set this to True.
use_gpu = False

a = 5
b = 10
size = 100


def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=()),
            # Add feature dimension, expanding (batch_size,) to (batch_size, 1).
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Dense(1),
        ]
    )
    return model


def train_func(config: dict):
    batch_size = config.get("batch_size", 64)
    epochs = config.get("epochs", 3)

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_model()
        multi_worker_model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=config.get("lr", 1e-3)),
            loss=tf.keras.losses.mean_squared_error,
            metrics=[tf.keras.metrics.mean_squared_error],
        )

    dataset = train.get_dataset_shard("train")

    results = []
    for _ in range(epochs):
        tf_dataset = dataset.to_tf(
            feature_columns="x", label_columns="y", batch_size=batch_size
        )
        history = multi_worker_model.fit(
            tf_dataset, callbacks=[ReportCheckpointCallback()]
        )
        results.append(history.history)
    return results


config = {"lr": 1e-3, "batch_size": 32, "epochs": 4}

train_dataset = ray.data.from_items(
    [{"x": x / 200, "y": 2 * x / 200} for x in range(200)]
)
scaling_config = ScalingConfig(num_workers=2, use_gpu=use_gpu)
trainer = TensorflowTrainer(
    train_loop_per_worker=train_func,
    train_loop_config=config,
    scaling_config=scaling_config,
    datasets={"train": train_dataset},
)
result = trainer.fit()
print(result.metrics)

Update your training function#

First, update your training function to support distributed training.

Note

The current TensorFlow implementation supports MultiWorkerMirroredStrategy (and MirroredStrategy). If there are other strategies you wish to see supported by Ray Train, submit a feature request on GitHub.

These instructions closely follow TensorFlow’s Multi-worker training with Keras tutorial. One key difference is that Ray Train handles the environment variable setup for you.

Step 1: Wrap your model in MultiWorkerMirroredStrategy.

The MultiWorkerMirroredStrategy enables synchronous distributed training. You must build and compile the Model within the scope of the strategy.

with tf.distribute.MultiWorkerMirroredStrategy().scope():
    model = ... # build model
    model.compile()

Step 2: Update your Dataset batch size to the global batch size.

Set batch_size to the global batch size, because Keras splits each global batch evenly across the worker processes.

-batch_size = worker_batch_size
+batch_size = worker_batch_size * train.get_context().get_world_size()
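A minimal sketch of what that looks like inside the training function (the worker_batch_size config key is just an illustrative name):

from ray import train

def train_func(config: dict):
    # Per-worker batch size; the key name here is illustrative.
    worker_batch_size = config.get("worker_batch_size", 32)
    # Scale up so each worker still processes `worker_batch_size` examples
    # after the strategy splits the global batch across workers.
    global_batch_size = worker_batch_size * train.get_context().get_world_size()
    # ... build the tf.data.Dataset with `global_batch_size` ...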

Warning

Ray doesn’t automatically set any environment variables or configuration related to local parallelism or threading, aside from OMP_NUM_THREADS. If you want greater control over TensorFlow threading, use the tf.config.threading module (for example, tf.config.threading.set_inter_op_parallelism_threads(num_cpus)) at the beginning of your train_loop_per_worker function.
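For example, a minimal sketch of pinning TensorFlow's thread pools at the top of the training function (the thread counts are placeholders; match them to the CPUs you allocate per worker):

import tensorflow as tf

def train_func(config: dict):
    # Placeholder values: set these to the number of CPUs reserved per worker.
    tf.config.threading.set_inter_op_parallelism_threads(2)
    tf.config.threading.set_intra_op_parallelism_threads(2)
    # ... rest of the training function ...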

Create a TensorflowTrainer#

Trainers are the primary Ray Train classes for managing state and executing training. For distributed TensorFlow, use a TensorflowTrainer that you can set up like this:

from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer
# For GPU Training, set `use_gpu` to True.
use_gpu = False
trainer = TensorflowTrainer(
    train_func,
    scaling_config=ScalingConfig(use_gpu=use_gpu, num_workers=2)
)

To customize the backend setup, you can pass a TensorflowConfig:

from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer, TensorflowConfig

trainer = TensorflowTrainer(
    train_func,
    tensorflow_backend=TensorflowConfig(...),
    scaling_config=ScalingConfig(num_workers=2),
)

For more configurability, see the DataParallelTrainer API.

Run a training function#

With a distributed training function and a Ray Train Trainer, you are now ready to start training.

trainer.fit()

Load and preprocess data#

TensorFlow by default uses its own internal dataset sharding policy, as described in TensorFlow’s distributed input guide. If your TensorFlow dataset is compatible with distributed loading, you don’t need to change anything.

If you require more advanced preprocessing, you may want to consider using Ray Data for distributed data ingest. See Ray Data with Ray Train.

The main difference is that you may want to convert your Ray Data dataset shard to a TensorFlow dataset in your training function so that you can use the Keras API for model training.
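As shown in the Quickstart above, the dataset shard that Ray Train hands to each worker can be converted directly with its to_tf method, for example:

# Inside the training function. The "x"/"y" column names match the
# Quickstart dataset; use the column names of your own dataset instead.
dataset_shard = train.get_dataset_shard("train")
tf_dataset = dataset_shard.to_tf(
    feature_columns="x", label_columns="y", batch_size=batch_size
)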

See this example for distributed data loading. The relevant parts are:

import tensorflow as tf
from ray import train
from ray.train.tensorflow import prepare_dataset_shard

def train_func(config: dict):
    # ...

    # Get dataset shard from Ray Train
    dataset_shard = train.get_dataset_shard("train")

    # Define a helper function to build a TensorFlow dataset
    def to_tf_dataset(dataset, batch_size):
        def to_tensor_iterator():
            for batch in dataset.iter_tf_batches(
                batch_size=batch_size, dtypes=tf.float32
            ):
                yield batch["image"], batch["label"]

        output_signature = (
            tf.TensorSpec(shape=(None, 784), dtype=tf.float32),
            # Labels are scalar per example, so the label spec has shape (None,).
            tf.TensorSpec(shape=(None,), dtype=tf.float32),
        )
        tf_dataset = tf.data.Dataset.from_generator(
            to_tensor_iterator, output_signature=output_signature
        )
        # Call prepare_dataset_shard to disable automatic sharding
        # (since the dataset is already sharded)
        return prepare_dataset_shard(tf_dataset)

    for epoch in range(epochs):
        # Call our helper function to build the dataset
        tf_dataset = to_tf_dataset(
            dataset=dataset_shard,
            batch_size=64,
        )
        history = multi_worker_model.fit(tf_dataset)

Report results#

During training, the training loop should report intermediate results and checkpoints to Ray Train. Reporting logs the results to the console output, appends them to local log files, and triggers checkpoint bookkeeping.

The easiest way to report your results with Keras is by using the ReportCheckpointCallback:

from ray.train.tensorflow.keras import ReportCheckpointCallback

def train_func(config: dict):
    # ...
    for epoch in range(epochs):
        model.fit(dataset, callbacks=[ReportCheckpointCallback()])

This callback automatically forwards all results and checkpoints from the Keras training function to Ray Train.

Aggregate results#

TensorFlow Keras automatically aggregates metrics from all workers. If you wish to have more control over that, consider implementing a custom training loop.
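For instance, here is a simplified sketch of such a custom loop, combining TensorFlow's custom-training-loop pattern with train.report. The helper name and loss handling are illustrative; a production loop would also scale losses as described in TensorFlow's distributed training tutorials.

import tensorflow as tf
from ray import train

def run_custom_loop(strategy, model, optimizer, loss_fn, dist_dataset, epochs):
    @tf.function
    def train_step(inputs):
        features, labels = inputs
        with tf.GradientTape() as tape:
            predictions = model(features, training=True)
            loss = loss_fn(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        return loss

    for epoch in range(epochs):
        total_loss, num_batches = 0.0, 0
        for batch in dist_dataset:
            per_replica_loss = strategy.run(train_step, args=(batch,))
            # Reduce across replicas yourself instead of relying on Keras.
            total_loss += strategy.reduce(
                tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None
            )
            num_batches += 1
        # Report this worker's view of the epoch loss to Ray Train.
        train.report({"epoch": epoch, "loss": float(total_loss / num_batches)})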

Save and load checkpoints#

You can save checkpoints by calling train.report(metrics, checkpoint=Checkpoint(...)) in the training function. This call saves the checkpoint state from the distributed workers to the Trainer, on the node where you executed your Python script.

You can access the latest saved checkpoint through the checkpoint attribute of the Result, and access the best saved checkpoints with the best_checkpoints attribute.
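For example, to keep and rank several checkpoints, you can configure the run's CheckpointConfig; a minimal sketch, assuming you report a metric named "loss" as in the example below:

from ray.train import CheckpointConfig, RunConfig

run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=2,                      # retain the two best checkpoints
        checkpoint_score_attribute="loss",  # rank by the reported "loss" metric
        checkpoint_score_order="min",       # lower loss is better
    )
)
# Pass run_config=run_config to the TensorflowTrainer, then after fitting:
# for checkpoint, metrics in result.best_checkpoints:
#     print(metrics["loss"], checkpoint)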

The following examples demonstrate how to save and load checkpoints with Ray Train during distributed training.

import json
import os
import tempfile

from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

import numpy as np

def train_func(config):
    import tensorflow as tf
    n = 100
    # create a toy dataset
    # data   : X - dim = (n, 4)
    # target : Y - dim = (n, 1)
    X = np.random.normal(0, 1, size=(n, 4))
    Y = np.random.uniform(0, 1, size=(n, 1))

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # toy neural network : 1-layer
        model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))])
        model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

    for epoch in range(config["num_epochs"]):
        history = model.fit(X, Y, batch_size=20)

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            model.save(os.path.join(temp_checkpoint_dir, "model.keras"))
            checkpoint_dict = os.path.join(temp_checkpoint_dir, "checkpoint.json")
            with open(checkpoint_dict, "w") as f:
                json.dump({"epoch": epoch}, f)
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            train.report({"loss": history.history["loss"][0]}, checkpoint=checkpoint)

trainer = TensorflowTrainer(
    train_func,
    train_loop_config={"num_epochs": 5},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.checkpoint)

By default, checkpoints persist to local disk in the log directory of each run.
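To persist checkpoints somewhere else, such as shared or cloud storage, set a storage path on the run; a minimal sketch with placeholder paths:

from ray.train import RunConfig

# Placeholder URI: any cloud bucket or shared filesystem path reachable from every node works.
run_config = RunConfig(storage_path="s3://my-bucket/ray-train-results", name="tf_experiment")

trainer = TensorflowTrainer(
    train_func,
    train_loop_config={"num_epochs": 5},
    scaling_config=ScalingConfig(num_workers=2),
    run_config=run_config,
)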

Load checkpoints#

import json
import os
import tempfile

from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

import numpy as np

def train_func(config):
    import tensorflow as tf
    n = 100
    # create a toy dataset
    # data   : X - dim = (n, 4)
    # target : Y - dim = (n, 1)
    X = np.random.normal(0, 1, size=(n, 4))
    Y = np.random.uniform(0, 1, size=(n, 1))

    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # toy neural network : 1-layer
        checkpoint = train.get_checkpoint()
        if checkpoint:
            with checkpoint.as_directory() as checkpoint_dir:
                model = tf.keras.models.load_model(
                    os.path.join(checkpoint_dir, "model.keras")
                )
        else:
            model = tf.keras.Sequential(
                [tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))]
            )
        model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

    for epoch in range(config["num_epochs"]):
        history = model.fit(X, Y, batch_size=20)

        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            model.save(os.path.join(temp_checkpoint_dir, "model.keras"))
            extra_json = os.path.join(temp_checkpoint_dir, "checkpoint.json")
            with open(extra_json, "w") as f:
                json.dump({"epoch": epoch}, f)
            checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)

            train.report({"loss": history.history["loss"][0]}, checkpoint=checkpoint)

trainer = TensorflowTrainer(
    train_func,
    train_loop_config={"num_epochs": 5},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.checkpoint)

# Start a new run from a loaded checkpoint
trainer = TensorflowTrainer(
    train_func,
    train_loop_config={"num_epochs": 5},
    scaling_config=ScalingConfig(num_workers=2),
    resume_from_checkpoint=result.checkpoint,
)
result = trainer.fit()

Further reading#

See the Ray Train User Guides to explore more topics.