class ray.train.tensorflow.TensorflowTrainer(*args, **kwargs)

Bases: ray.train.data_parallel_trainer.DataParallelTrainer

A Trainer for data parallel TensorFlow training.

This Trainer runs the function train_loop_per_worker on multiple Ray Actors. These actors already have the necessary TensorFlow process group configured for distributed TensorFlow training.

The train_loop_per_worker function is expected to take in either 0 or 1 arguments:

def train_loop_per_worker():
    ...

def train_loop_per_worker(config: Dict):
    ...

If train_loop_per_worker accepts an argument, then train_loop_config will be passed in as the argument. This is useful if you want to tune the values in train_loop_config as hyperparameters.
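The zero-or-one-argument convention can be illustrated with a small standalone sketch. This is not Ray's implementation: `call_train_loop` is a hypothetical helper that only mimics the rule described above, and passing an empty dict when no config is given is an assumption made for the example.

```python
import inspect
from typing import Any, Callable, Dict, Optional


def call_train_loop(
    train_loop: Callable,
    train_loop_config: Optional[Dict[str, Any]] = None,
):
    """Hypothetical helper: pass the config only if the function declares a parameter."""
    num_params = len(inspect.signature(train_loop).parameters)
    if num_params == 0:
        return train_loop()
    if num_params == 1:
        # Assumption for this sketch: a missing config becomes an empty dict.
        return train_loop(train_loop_config or {})
    raise ValueError("train_loop_per_worker must take 0 or 1 arguments")


def loop_without_config():
    return "no config"


def loop_with_config(config: Dict[str, Any]):
    # Values read from the config are the ones a tuner could vary.
    return config["lr"]
```

With this helper, `call_train_loop(loop_with_config, {"lr": 0.1})` hands the dict to the function, while `call_train_loop(loop_without_config)` calls it with no arguments.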

If the datasets dict contains a training dataset (denoted by the “train” key), then it will be split into multiple dataset shards that can then be accessed by ray.train.get_dataset_shard("train") inside train_loop_per_worker. All the other datasets will not be split, and ray.train.get_dataset_shard(...) will return the entire Dataset.
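As a toy model of these semantics, the sketch below splits only the "train" entry across workers and returns every other entry whole. It uses plain Python lists rather than Ray Datasets. `shards_for_worker` is a hypothetical name, and the round-robin split is an assumption for illustration, since this page does not say how Ray divides the rows.

```python
from typing import Any, Dict, List, Sequence


def shards_for_worker(
    datasets: Dict[str, Sequence[Any]],
    rank: int,
    world_size: int,
) -> Dict[str, List[Any]]:
    """Toy model: split only the "train" entry; pass other entries through whole."""
    result: Dict[str, List[Any]] = {}
    for name, rows in datasets.items():
        if name == "train":
            # Assumed round-robin split; Ray's actual strategy may differ.
            result[name] = list(rows[rank::world_size])
        else:
            result[name] = list(rows)
    return result


datasets = {"train": list(range(6)), "valid": [10, 11]}
per_worker = [shards_for_worker(datasets, rank, 3) for rank in range(3)]
```

In this model each of the three workers sees a different third of "train" but the same full "valid" list.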

Inside the train_loop_per_worker function, you can use any of the Ray Train loop methods.


Ray will not automatically set any environment variables or configuration related to local parallelism / threading aside from “OMP_NUM_THREADS”. If you desire greater control over TensorFlow threading, use the tf.config.threading module (e.g., tf.config.threading.set_inter_op_parallelism_threads(num_cpus)) at the beginning of your train_loop_per_worker function.
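A minimal sketch of that pattern follows. The helper names `threads_from_env` and `configure_tf_threading` are hypothetical, and deriving both thread counts from OMP_NUM_THREADS is only one possible policy, assumed here for illustration. TensorFlow is imported lazily so the environment helper can be used without it.

```python
import os


def threads_from_env(default: int = 1) -> int:
    """Read OMP_NUM_THREADS, falling back to `default` if unset or invalid."""
    raw = os.environ.get("OMP_NUM_THREADS", "")
    try:
        value = int(raw)
    except ValueError:
        return default
    return value if value > 0 else default


def configure_tf_threading(num_threads: int) -> None:
    """Set TensorFlow's inter-op and intra-op thread pool sizes."""
    import tensorflow as tf  # lazy import: only needed on the training workers

    tf.config.threading.set_inter_op_parallelism_threads(num_threads)
    tf.config.threading.set_intra_op_parallelism_threads(num_threads)


def train_loop_per_worker():
    # Configure threading first, before any other TensorFlow work.
    configure_tf_threading(threads_from_env())
    ...  # build the model and train as usual
```

The threading calls come first in the function because TensorFlow expects these settings to be applied before its runtime has been initialized.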

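from ray import train

def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    train.report(...)

    # Returns the last saved checkpoint, if any.
    train.get_checkpoint()

    # Returns the Dataset shard for the given key.
    train.get_dataset_shard("my_dataset")

    # Returns the total number of workers executing training.
    train.get_context().get_world_size()

    # Returns the rank of this worker.
    train.get_context().get_world_rank()

    # Returns the rank of the worker on the current node.
    train.get_context().get_local_rank()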

Any value returned from train_loop_per_worker is discarded; it is not used or persisted anywhere.

To save a model to use for the TensorflowPredictor, you must save it under the “model” kwarg in Checkpoint passed to train.report().


import os
import tempfile
import tensorflow as tf

import ray
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

def build_model():
    # Toy neural network: a single dense layer.
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(
            1, activation="linear", input_shape=(1,))]
    )

def train_loop_per_worker(config):
    dataset_shard = train.get_dataset_shard("train")
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()
        model.compile(
            optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

    # Convert this worker's shard into a tf.data.Dataset.
    tf_dataset = dataset_shard.to_tf(
        feature_columns="x",
        label_columns="y",
        batch_size=1
    )
    for epoch in range(config["num_epochs"]):
        model.fit(tf_dataset)

        # Create checkpoint.
        checkpoint_dir = tempfile.mkdtemp()
        model.save_weights(
            os.path.join(checkpoint_dir, "my_checkpoint")
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        # Report (empty) metrics together with the checkpoint.
        train.report(
            {},
            checkpoint=checkpoint,
        )

train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=3, use_gpu=True),
    datasets={"train": train_dataset},
    train_loop_config={"num_epochs": 2},
)
result = trainer.fit()

Parameters:

  • train_loop_per_worker – The training function to execute. This can either take in no arguments or a config dict.

  • train_loop_config – Configurations to pass into train_loop_per_worker if it accepts an argument.

  • tensorflow_config – Configuration for setting up the TensorFlow backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer.

  • scaling_config – Configuration for how to scale data parallel training.

  • dataset_config – Configuration for dataset ingest.

  • run_config – Configuration for the execution of the training run.

  • datasets – Any Datasets to use for training. Use the key “train” to denote which dataset is the training dataset.

  • resume_from_checkpoint – A checkpoint to resume training from.

  • metadata – Dict that should be made available via ray.train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.
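Because metadata must be JSON-serializable, it can help to validate the dict before constructing the Trainer. The check below is a small standalone sketch; `is_json_serializable` is a hypothetical helper, not part of Ray.

```python
import json
from typing import Any, Dict


def is_json_serializable(metadata: Dict[str, Any]) -> bool:
    """Return True if `metadata` can be encoded with the standard json module."""
    try:
        json.dumps(metadata)
    except (TypeError, ValueError):
        # TypeError: unsupported type (e.g. a set); ValueError: circular reference.
        return False
    return True


good_metadata = {"experiment": "baseline", "num_features": 1}
bad_metadata = {"tags": {"a", "b"}}  # sets are not JSON-serializable
```

Here `is_json_serializable(good_metadata)` is true, while `bad_metadata` fails the check because it contains a set.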

PublicAPI (beta): This API is in beta and may change before becoming stable.



Methods:

  • as_trainable() – Converts self to a tune.Trainable class.

  • can_restore() – Checks whether a given directory contains a restorable Train experiment.

  • fit() – Runs training.

  • get_dataset_config() – Returns a copy of this Trainer's final dataset configs.

  • restore() – Restores a DataParallelTrainer from a previously interrupted/failed run.

  • setup() – Called during fit() to perform initial setup on the Trainer.