ray.train.tensorflow.TensorflowTrainer#

class ray.train.tensorflow.TensorflowTrainer(train_loop_per_worker: Callable[[], None] | Callable[[Dict], None], *, train_loop_config: Dict | None = None, tensorflow_config: TensorflowConfig | None = None, scaling_config: ScalingConfig | None = None, dataset_config: DataConfig | None = None, run_config: RunConfig | None = None, datasets: Dict[str, Dataset | Callable[[], Dataset]] | None = None, metadata: Dict[str, Any] | None = None, resume_from_checkpoint: Checkpoint | None = None)[source]#

Bases: DataParallelTrainer

A Trainer for data parallel TensorFlow training.

At a high level, this Trainer does the following:

  1. Launches multiple workers as defined by the scaling_config.

  2. Sets up a distributed TensorFlow environment on these workers as defined by the tensorflow_config.

  3. Ingests the input datasets based on the dataset_config.

  4. Runs the input train_loop_per_worker(train_loop_config) on all workers.

Inside the train_loop_per_worker function, you can use any of the Ray Train loop methods.

Warning

Ray does not automatically set any environment variables or configuration related to local parallelism / threading aside from “OMP_NUM_THREADS”. If you want finer control over TensorFlow threading, use the tf.config.threading module (e.g., tf.config.threading.set_inter_op_parallelism_threads(num_cpus)) at the beginning of your train_loop_per_worker function.
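
For example, a minimal sketch of such a threading setup; the num_cpus_per_worker value is an assumption here and should match the CPUs you actually allocate to each worker (for example via ScalingConfig(resources_per_worker={"CPU": ...})):

import tensorflow as tf

def train_loop_per_worker():
    # Illustrative value only: match this to the CPUs allocated per worker.
    num_cpus_per_worker = 4

    # Configure TensorFlow threading before building the model or running
    # any ops; these settings cannot be changed once TensorFlow initializes.
    tf.config.threading.set_inter_op_parallelism_threads(num_cpus_per_worker)
    tf.config.threading.set_intra_op_parallelism_threads(num_cpus_per_worker)

    # ... rest of the training loop ...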

from ray import train

def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    train.report(...)

    # Returns the latest saved Checkpoint, if any.
    train.get_checkpoint()

    # Returns the Dataset shard for the given key.
    train.get_dataset_shard("my_dataset")

    # Returns the total number of workers executing training.
    train.get_context().get_world_size()

    # Returns the rank of this worker.
    train.get_context().get_world_rank()

    # Returns the rank of the worker on the current node.
    train.get_context().get_local_rank()

Any return values from train_loop_per_worker are discarded; they are not used or persisted anywhere.

Example:

import os
import tempfile
import tensorflow as tf

import ray
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

def build_model():
    # Toy one-layer neural network.
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(
            1, activation="linear", input_shape=(1,))]
    )

def train_loop_per_worker(config):
    dataset_shard = train.get_dataset_shard("train")
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()
        model.compile(
            optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

    tf_dataset = dataset_shard.to_tf(
        feature_columns="x",
        label_columns="y",
        batch_size=1
    )
    for epoch in range(config["num_epochs"]):
        model.fit(tf_dataset)

        # Create checkpoint.
        checkpoint_dir = tempfile.mkdtemp()
        model.save_weights(
            os.path.join(checkpoint_dir, "my_checkpoint")
        )
        checkpoint = Checkpoint.from_directory(checkpoint_dir)

        train.report(
            {},
            checkpoint=checkpoint,
        )

train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=3, use_gpu=True),
    datasets={"train": train_dataset},
    train_loop_config={"num_epochs": 2},
)
result = trainer.fit()
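
trainer.fit() returns a Result object whose metrics and checkpoint fields expose what the loop reported. Continuing the example above, a minimal sketch of reading them back (the "my_checkpoint" prefix matches the save_weights call in the loop; restored_model is an illustrative name):

# Inspect the metrics from the last train.report() call.
print(result.metrics)

# Restore the weights saved by the training loop.
if result.checkpoint:
    with result.checkpoint.as_directory() as checkpoint_dir:
        restored_model = build_model()
        restored_model.load_weights(
            os.path.join(checkpoint_dir, "my_checkpoint")
        )
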
Parameters:
  • train_loop_per_worker – The training function to execute on each worker. This function can either take in zero arguments or a single Dict argument which is set by defining train_loop_config (see the sketch after this parameter list). Within this function you can use any of the Ray Train loop utilities.

  • train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters. Passing large datasets via train_loop_config is not recommended and may introduce large overhead and unknown issues with serialization and deserialization.

  • tensorflow_config – The configuration for setting up the TensorFlow Distributed backend on the workers. If set to None, a default configuration will be used.

  • scaling_config – The configuration for how to scale data parallel training. num_workers determines how many Python processes are used for training, and use_gpu determines whether or not each process should use GPUs. See ScalingConfig for more info.

  • run_config – The configuration for the execution of the training run. See RunConfig for more info.

  • datasets – The Ray Datasets to ingest for training. Datasets are keyed by name ({name: dataset}). Each dataset can be accessed from within the train_loop_per_worker by calling ray.train.get_dataset_shard(name). Sharding and additional configuration can be done by passing in a dataset_config.

  • resume_from_checkpoint – A checkpoint to resume training from.

  • metadata – Dict that should be made available via ray.train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.
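
As noted for train_loop_per_worker above, the training function can take either no arguments or a single dict populated from train_loop_config. A minimal sketch of both forms (the function names and the "lr" key are illustrative):

from ray.train import ScalingConfig
from ray.train.tensorflow import TensorflowTrainer

# Zero-argument form: configuration is captured via closures or constants.
def train_func():
    ...

# Single-dict form: receives the train_loop_config passed to the trainer.
def train_func_with_config(config: dict):
    lr = config["lr"]  # illustrative hyperparameter
    ...

trainer = TensorflowTrainer(
    train_loop_per_worker=train_func_with_config,
    train_loop_config={"lr": 1e-3},
    scaling_config=ScalingConfig(num_workers=2),
)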

PublicAPI (beta): This API is in beta and may change before becoming stable.

Methods

can_restore – [Deprecated] Checks if a Train experiment can be restored from a previously interrupted/failed run.

fit – Launches the Ray Train controller to run training on workers.

restore – [Deprecated] Restores a Train experiment from a previously interrupted/failed run.