class ray.train.tensorflow.TensorflowTrainer(*args, **kwargs)

Bases: ray.train.data_parallel_trainer.DataParallelTrainer

A Trainer for data parallel TensorFlow training.

This Trainer runs the function train_loop_per_worker on multiple Ray Actors. These actors already have the environment needed for distributed TensorFlow training configured.
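
As background, TensorFlow's multi-worker strategies (such as tf.distribute.MultiWorkerMirroredStrategy) discover their peers through the TF_CONFIG environment variable. This is a JSON document that lists every worker address and names the current worker's index. The sketch below builds such a value for three workers in plain Python. The helper make_tf_config and the addresses are hypothetical. It is also an assumption here that the trainer prepares a value of this shape on each actor before train_loop_per_worker starts.

```python
import json
from typing import List


def make_tf_config(worker_addresses: List[str], index: int) -> str:
    """Build a TF_CONFIG JSON string for one worker (hypothetical helper).

    worker_addresses: "host:port" strings, one per worker, in rank order.
    index: this worker's position in worker_addresses.
    """
    return json.dumps({
        "cluster": {"worker": list(worker_addresses)},
        "task": {"type": "worker", "index": index},
    })


# Made-up addresses for a 3-worker job; each worker gets the same cluster
# list but a different task index.
addresses = ["10.0.0.1:12345", "10.0.0.2:12345", "10.0.0.3:12345"]
for rank in range(len(addresses)):
    print(make_tf_config(addresses, rank))
```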

The train_loop_per_worker function is expected to take in either 0 or 1 arguments:

from typing import Dict

def train_loop_per_worker():
    ...

def train_loop_per_worker(config: Dict):
    ...

If train_loop_per_worker accepts an argument, then train_loop_config will be passed in as the argument. This is useful if you want to tune the values in train_loop_config as hyperparameters.
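
The zero-or-one-argument contract can be sketched as a dispatcher that inspects the function's signature. call_train_loop is a hypothetical helper written for illustration, not Ray's internal code. It assumes that config is forwarded unchanged when the function takes one parameter and that any other arity is rejected.

```python
import inspect
from typing import Callable, Dict, Optional


def call_train_loop(train_loop: Callable, config: Optional[Dict] = None):
    """Hypothetical dispatcher: pass `config` only if train_loop takes an argument."""
    num_params = len(inspect.signature(train_loop).parameters)
    if num_params == 0:
        return train_loop()
    if num_params == 1:
        return train_loop(config)
    raise ValueError("train_loop_per_worker must take 0 or 1 arguments")


def loop_without_config():
    return "no config"


def loop_with_config(config: Dict):
    # In a tuning run, config["lr"] could be a sampled hyperparameter.
    return f"lr={config['lr']}"


print(call_train_loop(loop_without_config))             # no config
print(call_train_loop(loop_with_config, {"lr": 0.01}))  # lr=0.01
```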

If the datasets dict contains a training dataset (denoted by the “train” key), it is split into one shard per worker, and each worker accesses its shard with session.get_dataset_shard("train") inside train_loop_per_worker. All other datasets are not split, and session.get_dataset_shard(...) returns the entire Dataset.
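
The sharding rule can be modeled with plain lists. shard_datasets is hypothetical: it splits only the "train" entry into contiguous, near-equal chunks (one per worker) and gives every worker the same unsplit copy of each other dataset. Ray's real split operates on dataset blocks and may balance rows differently, so treat the exact chunk boundaries as an assumption.

```python
from typing import Dict, List


def shard_datasets(datasets: Dict[str, List], num_workers: int) -> List[Dict[str, List]]:
    """Hypothetical illustration: split only "train"; other datasets stay whole."""
    per_worker: List[Dict[str, List]] = [{} for _ in range(num_workers)]
    for name, rows in datasets.items():
        if name == "train":
            # Contiguous chunks; the first len(rows) % num_workers workers get one extra row.
            base, extra = divmod(len(rows), num_workers)
            start = 0
            for rank in range(num_workers):
                size = base + (1 if rank < extra else 0)
                per_worker[rank][name] = rows[start:start + size]
                start += size
        else:
            for rank in range(num_workers):
                per_worker[rank][name] = rows  # same, unsplit dataset for every worker
    return per_worker


shards = shard_datasets({"train": list(range(10)), "valid": [100, 101]}, num_workers=3)
for rank, shard in enumerate(shards):
    print(rank, shard["train"], shard["valid"])
# 0 [0, 1, 2, 3] [100, 101]
# 1 [4, 5, 6] [100, 101]
# 2 [7, 8, 9] [100, 101]
```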

Inside the train_loop_per_worker function, you can use any of the Ray AIR session methods.


Ray does not automatically set any environment variables or configuration related to local parallelism or threading, aside from “OMP_NUM_THREADS”. If you want finer control over TensorFlow threading, use the tf.config.threading module (e.g., tf.config.threading.set_inter_op_parallelism_threads(num_cpus)) at the beginning of your train_loop_per_worker function.

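from ray.air import session

def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    session.report(...)

    # Returns the last saved checkpoint, or None if there is none.
    session.get_checkpoint()

    # Returns the Ray Dataset shard for the given key.
    session.get_dataset_shard("train")

    # Returns the total number of workers executing training.
    session.get_world_size()

    # Returns the rank of this worker.
    session.get_world_rank()

    # Returns the rank of the worker on the current node.
    session.get_local_rank()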

Any value returned from train_loop_per_worker is discarded and is not used or persisted anywhere.

To save a model for use with the TensorflowPredictor, you must store it under the “model” key of the Checkpoint passed to session.report().
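
Because the example below stores the checkpoint as a dict with "epoch" and "model" entries, the same layout can be read back to resume training. The helpers make_checkpoint_dict and resume_state are hypothetical. In a real train_loop_per_worker the dict would presumably come from session.get_checkpoint().to_dict() and the weights would be restored with model.set_weights(...); this pure-Python sketch assumes that wiring rather than exercising it.

```python
from typing import Any, Dict, List, Optional, Tuple


def make_checkpoint_dict(epoch: int, weights: List[Any]) -> Dict[str, Any]:
    # Hypothetical helper: model weights live under the "model" key.
    return {"epoch": epoch, "model": weights}


def resume_state(ckpt: Optional[Dict[str, Any]]) -> Tuple[int, Optional[List[Any]]]:
    """Return (first epoch to run, weights to restore), or a fresh start if no checkpoint."""
    if ckpt is None:
        return 0, None
    if "model" not in ckpt:
        raise KeyError('checkpoint has no "model" entry')
    # Assumes the checkpoint was written at the end of ckpt["epoch"].
    return ckpt["epoch"] + 1, ckpt["model"]


print(resume_state(None))                                     # (0, None)
print(resume_state(make_checkpoint_dict(1, [[0.5], [0.1]])))  # (2, [[0.5], [0.1]])
```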


import tensorflow as tf

import ray
from ray.air import session, Checkpoint
from ray.air.config import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

# If using GPUs, set this to True.
use_gpu = False

def build_model():
    # Toy neural network: a single linear Dense layer.
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(1, activation="linear", input_shape=(1,))]
    )

def train_loop_per_worker(config):
    dataset_shard = session.get_dataset_shard("train")
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    with strategy.scope():
        # Build and compile inside the strategy scope so the model's
        # variables are mirrored across workers.
        model = build_model()
        model.compile(
            optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

    # Convert this worker's shard into a tf.data.Dataset of (x, y) batches.
    tf_dataset = dataset_shard.to_tf(
        feature_columns="x", label_columns="y", batch_size=1)
    for epoch in range(config["num_epochs"]):
        model.fit(tf_dataset)
        # You can also use ray.air.integrations.keras.Callback
        # for reporting and checkpointing instead of reporting manually.
        session.report(
            {},
            checkpoint=Checkpoint.from_dict(
                dict(epoch=epoch, model=model.get_weights())
            ),
        )

train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=3, use_gpu=use_gpu),
    datasets={"train": train_dataset},
    train_loop_config={"num_epochs": 2},
)
result = trainer.fit()

Parameters:

  • train_loop_per_worker – The training function to execute. This can either take in no arguments or a config dict.

  • train_loop_config – Configurations to pass into train_loop_per_worker if it accepts an argument.

  • tensorflow_config – Configuration for setting up the TensorFlow backend. If None, the default configuration is used. This replaces the backend_config arg of DataParallelTrainer.

  • scaling_config – Configuration for how to scale data parallel training.

  • dataset_config – Configuration for dataset ingest.

  • run_config – Configuration for the execution of the training run.

  • datasets – Any Ray Datasets to use for training. Use the key “train” to denote which dataset is the training dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided.

  • preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A checkpoint to resume training from.
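
The ordering described for datasets and preprocessor (fit on “train” if not already fit, then transform every dataset) can be modeled in plain Python. MeanCenterer and preprocess are hypothetical stand-ins written for illustration; real code would pass a ray.data preprocessor (for example, ray.data.preprocessors.StandardScaler) through the preprocessor argument instead.

```python
from typing import Dict, List, Optional


class MeanCenterer:
    """Hypothetical preprocessor: subtracts the mean learned during fit()."""

    def __init__(self) -> None:
        self.mean: Optional[float] = None

    def fit(self, rows: List[float]) -> "MeanCenterer":
        self.mean = sum(rows) / len(rows)
        return self

    def transform(self, rows: List[float]) -> List[float]:
        if self.mean is None:
            raise RuntimeError("fit() must be called before transform()")
        return [x - self.mean for x in rows]


def preprocess(datasets: Dict[str, List[float]], preprocessor: MeanCenterer) -> Dict[str, List[float]]:
    # Fit only on the training split, and only if not already fit,
    # then transform every dataset with the same fitted state.
    if preprocessor.mean is None:
        preprocessor.fit(datasets["train"])
    return {name: preprocessor.transform(rows) for name, rows in datasets.items()}


out = preprocess({"train": [1.0, 2.0, 3.0], "valid": [4.0]}, MeanCenterer())
print(out)  # {'train': [-1.0, 0.0, 1.0], 'valid': [2.0]}
```

Fitting on “train” alone keeps statistics from evaluation data out of the fitted state, which is why the other datasets are only transformed.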

PublicAPI (beta): This API is in beta and may change before becoming stable.