ray.train.tensorflow.TensorflowTrainer

class ray.train.tensorflow.TensorflowTrainer(*args, **kwargs)

Bases: DataParallelTrainer

A Trainer for data parallel TensorFlow training.

This Trainer runs the function train_loop_per_worker on multiple Ray Actors. These actors already have the necessary TensorFlow process group configured for distributed TensorFlow training.

The train_loop_per_worker function is expected to take either zero or one argument:

def train_loop_per_worker():
    ...
def train_loop_per_worker(config: Dict):
    ...

If train_loop_per_worker accepts an argument, then train_loop_config will be passed in as the argument. This is useful if you want to tune the values in train_loop_config as hyperparameters.
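A minimal sketch of this pattern (the "lr" key below is illustrative, not part of the API):

def train_loop_per_worker(config):
    lr = config["lr"]  # value comes from train_loop_config
    ...

trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"lr": 1e-3},
)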

If the datasets dict contains a training dataset (denoted by the “train” key), then it will be split into multiple dataset shards that can then be accessed by ray.train.get_dataset_shard("train") inside train_loop_per_worker. All the other datasets will not be split and ray.train.get_dataset_shard(...) will return the entire Dataset.
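For example, a hedged sketch of this splitting behavior (the "val" key and the toy datasets below are illustrative):

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

def train_loop_per_worker():
    # Each worker receives a different shard of the "train" dataset.
    train_shard = train.get_dataset_shard("train")
    # Any other key: every worker receives the entire, unsplit dataset.
    val_dataset = train.get_dataset_shard("val")

trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    datasets={"train": ray.data.range(8), "val": ray.data.range(4)},
    scaling_config=ScalingConfig(num_workers=2),
)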

Inside the train_loop_per_worker function, you can use any of the Ray Train loop methods.

Warning

Ray will not automatically set any environment variables or configuration related to local parallelism / threading aside from “OMP_NUM_THREADS”. If you desire greater control over TensorFlow threading, use the tf.config.threading module (e.g., tf.config.threading.set_inter_op_parallelism_threads(num_cpus)) at the beginning of your train_loop_per_worker function.
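For instance, a minimal sketch of the threading setup described in this warning (the thread counts are illustrative):

import tensorflow as tf

def train_loop_per_worker():
    # Configure TensorFlow threading before any TF ops run on this worker.
    tf.config.threading.set_inter_op_parallelism_threads(2)
    tf.config.threading.set_intra_op_parallelism_threads(4)
    ...

The Ray Train loop methods available inside train_loop_per_worker are summarized below: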

from ray import train

def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    train.report(...)

    # Returns the latest saved Checkpoint, if any.
    train.get_checkpoint()

    # Returns the Dataset shard for the given key.
    train.get_dataset_shard("my_dataset")

    # Returns the total number of workers executing training.
    train.get_context().get_world_size()

    # Returns the rank of this worker.
    train.get_context().get_world_rank()

    # Returns the rank of the worker on the current node.
    train.get_context().get_local_rank()

Any return values from train_loop_per_worker are discarded and not used or persisted anywhere.

To save a model for later use with the TensorflowPredictor, you must save it under the “model” kwarg in the Checkpoint passed to train.report().
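One way to do this, assuming the TensorflowCheckpoint.from_model helper from ray.train.tensorflow is available in your Ray version (a sketch, not the only approach):

from ray import train
from ray.train.tensorflow import TensorflowCheckpoint

def train_loop_per_worker(config):
    model = ...  # a trained tf.keras model
    # from_model packages the model so TensorflowPredictor can load it.
    checkpoint = TensorflowCheckpoint.from_model(model)
    train.report({}, checkpoint=checkpoint)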

Example:

import os
import tempfile
import tensorflow as tf

import ray
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

def build_model():
    # Toy one-layer neural network.
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(
            1, activation="linear", input_shape=(1,))]
    )

def train_loop_per_worker(config):
    dataset_shard = train.get_dataset_shard("train")
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()
        model.compile(
            optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

    tf_dataset = dataset_shard.to_tf(
        feature_columns="x",
        label_columns="y",
        batch_size=1
    )
    for epoch in range(config["num_epochs"]):
        model.fit(tf_dataset)

        # Create checkpoint.
        checkpoint_dir = tempfile.mkdtemp()
        model.save_weights(
            os.path.join(checkpoint_dir, "my_checkpoint")
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        train.report(
            {},
            checkpoint=checkpoint,
        )

train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=3, use_gpu=True),
    datasets={"train": train_dataset},
    train_loop_config={"num_epochs": 2},
)
result = trainer.fit()
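
After fit() completes, the returned Result can be inspected; a brief sketch reusing build_model and the "my_checkpoint" prefix from the example above:

# Metrics from the last train.report() call (an empty dict here).
print(result.metrics)

# Restore the weights from the last reported checkpoint.
with result.checkpoint.as_directory() as checkpoint_dir:
    restored_model = build_model()
    restored_model.load_weights(
        os.path.join(checkpoint_dir, "my_checkpoint")
    )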
Parameters:
  • train_loop_per_worker – The training function to execute. This can either take in no arguments or a config dict.

  • train_loop_config – Configurations to pass into train_loop_per_worker if it accepts an argument.

  • tensorflow_config – Configuration for setting up the TensorFlow backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer.

  • scaling_config – Configuration for how to scale data parallel training.

  • dataset_config – Configuration for dataset ingest.

  • run_config – Configuration for the execution of the training run.

  • datasets – Any Datasets to use for training. Use the key “train” to denote which dataset is the training dataset.

  • resume_from_checkpoint – A checkpoint to resume training from.

  • metadata – Dict that should be made available via ray.train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.
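As an illustration of the metadata round trip (a minimal sketch; the "data_version" key is hypothetical):

import ray
from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

def train_loop_per_worker():
    # Metadata passed to the Trainer is readable inside the training loop.
    meta = ray.train.get_context().get_metadata()
    assert meta["data_version"] == 2  # hypothetical key

trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    metadata={"data_version": 2},
)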

PublicAPI (beta): This API is in beta and may change before becoming stable.

Methods

  • as_trainable – Converts self to a tune.Trainable class.

  • can_restore – Checks whether a given directory contains a restorable Train experiment.

  • fit – Runs training.

  • get_dataset_config – Returns a copy of this Trainer's final dataset configs.

  • restore – Restores a DataParallelTrainer from a previously interrupted/failed run.

  • setup – Called during fit() to perform initial setup on the Trainer.
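
For example, a brief sketch of can_restore and restore (the experiment path is hypothetical):

from ray.train.tensorflow import TensorflowTrainer

experiment_path = "/tmp/ray_results/my_tf_experiment"  # hypothetical path
if TensorflowTrainer.can_restore(experiment_path):
    # Arguments such as the training function can be re-specified on restore.
    trainer = TensorflowTrainer.restore(
        experiment_path,
        train_loop_per_worker=train_loop_per_worker,
    )
    result = trainer.fit()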