- class ray.train.tensorflow.tensorflow_trainer.TensorflowTrainer(*args, **kwargs)
A Trainer for data parallel TensorFlow training.

This Trainer runs the function `train_loop_per_worker` on multiple Ray Actors. These actors already have the necessary TensorFlow process group configured for distributed TensorFlow training.

The `train_loop_per_worker` function is expected to take in either 0 or 1 arguments:

```python
from typing import Dict

def train_loop_per_worker(): ...

def train_loop_per_worker(config: Dict): ...
```
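The 0-or-1-argument convention can be sketched in plain Python. The helper `call_train_loop` below is hypothetical and is not Ray's internal implementation; it assumes the dispatch is based on the number of declared parameters (checked with `inspect.signature`) and passes `train_loop_config` only to the one-argument form. Like Ray Train, it discards the function's return value.

```python
import inspect
from typing import Any, Callable, Dict, List, Optional


def call_train_loop(
    train_loop_per_worker: Callable[..., Any],
    train_loop_config: Optional[Dict[str, Any]] = None,
) -> None:
    """Hypothetical dispatcher mimicking the 0-or-1-argument convention.

    Whatever the training function returns is discarded.
    """
    num_params = len(inspect.signature(train_loop_per_worker).parameters)
    if num_params == 0:
        # Zero-argument form: the config is not passed in.
        train_loop_per_worker()
    elif num_params == 1:
        # One-argument form: train_loop_config is passed as the argument.
        train_loop_per_worker(train_loop_config)
    else:
        raise ValueError(
            f"train_loop_per_worker must take 0 or 1 arguments, got {num_params}"
        )


calls: List[str] = []


def loop_without_config():
    calls.append("no config")


def loop_with_config(config: Dict):
    calls.append(f"lr={config['lr']}")


call_train_loop(loop_without_config, {"lr": 0.1})
call_train_loop(loop_with_config, {"lr": 0.1})
print(calls)  # ['no config', 'lr=0.1']
```

Because the config only reaches the one-argument form, a hyperparameter search over `train_loop_config` has no effect on a zero-argument training function.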
If `train_loop_per_worker` accepts an argument, then `train_loop_config` will be passed in as the argument. This is useful if you want to tune the values in `train_loop_config` as hyperparameters.

If the `datasets` dict contains a training dataset (denoted by the "train" key), then it will be split into multiple dataset shards that can then be accessed by `ray.train.get_dataset_shard("train")`. All the other datasets will not be split, and `ray.train.get_dataset_shard(...)` will return the entire Dataset.

Inside the `train_loop_per_worker` function, you can use any of the Ray Train loop methods.

Warning
Ray will not automatically set any environment variables or configuration related to local parallelism or threading aside from "OMP_NUM_THREADS". If you desire greater control over TensorFlow threading, use the `tf.config.threading` module (e.g. `tf.config.threading.set_inter_op_parallelism_threads(num_cpus)`) at the beginning of your `train_loop_per_worker` function.

```python
from ray import train


def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    train.report(...)

    # Returns the last saved checkpoint (a ray.train.Checkpoint), or None.
    train.get_checkpoint()

    # Returns the Dataset shard for the given key.
    train.get_dataset_shard("my_dataset")

    # Returns the total number of workers executing training.
    train.get_context().get_world_size()

    # Returns the rank of this worker.
    train.get_context().get_world_rank()

    # Returns the rank of the worker on the current node.
    train.get_context().get_local_rank()
```
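The dataset-splitting rule (only the "train" dataset is sharded; every other dataset is returned whole by `get_dataset_shard`) can be modeled without Ray. This is a simplified, hypothetical sketch: `assign_shards` is an assumed name, and it uses a contiguous near-equal split that does not reproduce Ray's actual splitting logic.

```python
from typing import Dict, List, Sequence


def assign_shards(
    datasets: Dict[str, Sequence], world_size: int
) -> List[Dict[str, Sequence]]:
    """Hypothetical model of per-worker dataset assignment.

    Only the "train" dataset is split; every other dataset is given to
    each worker in full.
    """
    per_worker: List[Dict[str, Sequence]] = [{} for _ in range(world_size)]
    for key, data in datasets.items():
        if key == "train":
            # Contiguous, near-equal split: the first len(data) % world_size
            # workers receive one extra row.
            base, extra = divmod(len(data), world_size)
            start = 0
            for rank in range(world_size):
                size = base + (1 if rank < extra else 0)
                per_worker[rank][key] = data[start : start + size]
                start += size
        else:
            # Non-train datasets are not split.
            for rank in range(world_size):
                per_worker[rank][key] = data
    return per_worker


train_rows = [{"x": x, "y": x + 1} for x in range(32)]
valid_rows = [{"x": x, "y": x + 1} for x in range(32, 40)]
shards = assign_shards({"train": train_rows, "valid": valid_rows}, world_size=3)
print([len(s["train"]) for s in shards])  # [11, 11, 10]
print([len(s["valid"]) for s in shards])  # [8, 8, 8]
```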
Any returns from `train_loop_per_worker` will be discarded and not used or persisted anywhere.

To save a model for later use, you must save it under the "model" kwarg in the `Checkpoint` passed to `train.report()`.
Example:

```python
import os
import tempfile

import tensorflow as tf

import ray
from ray import train
from ray.train import Checkpoint, ScalingConfig
from ray.train.tensorflow import TensorflowTrainer


def build_model():
    # Toy neural network: a single linear Dense layer.
    return tf.keras.Sequential(
        [tf.keras.layers.Dense(1, activation="linear", input_shape=(1,))]
    )


def train_loop_per_worker(config):
    dataset_shard = train.get_dataset_shard("train")
    strategy = tf.distribute.MultiWorkerMirroredStrategy()
    with strategy.scope():
        model = build_model()
        model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

    tf_dataset = dataset_shard.to_tf(
        feature_columns="x", label_columns="y", batch_size=1
    )
    for epoch in range(config["num_epochs"]):
        model.fit(tf_dataset)

        # Create a checkpoint. The ".weights.h5" suffix is accepted by
        # save_weights in both Keras 2 and Keras 3.
        checkpoint_dir = tempfile.mkdtemp()
        model.save_weights(os.path.join(checkpoint_dir, "model.weights.h5"))
        checkpoint = Checkpoint.from_directory(checkpoint_dir)
        train.report({}, checkpoint=checkpoint)


train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = TensorflowTrainer(
    train_loop_per_worker=train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=3, use_gpu=True),
    datasets={"train": train_dataset},
    train_loop_config={"num_epochs": 2},
)
result = trainer.fit()
```
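The toy dataset in the example encodes y = x + 1, so the single linear `Dense(1)` unit can in principle represent it exactly with weight 1 and bias 1. A quick ordinary-least-squares check in plain Python (independent of Ray and TensorFlow; `fit_line` is a hypothetical helper) confirms that target. It says nothing about how close two epochs of Adam actually get.

```python
from typing import List, Tuple


def fit_line(xs: List[float], ys: List[float]) -> Tuple[float, float]:
    """Ordinary least squares for y = w * x + b (hypothetical helper)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    w = cov / var
    b = mean_y - w * mean_x
    return w, b


# Same rows as ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)]).
rows = [{"x": x, "y": x + 1} for x in range(32)]
w, b = fit_line([r["x"] for r in rows], [r["y"] for r in rows])
print(w, b)  # 1.0 1.0
```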
- Parameters:
  - train_loop_per_worker – The training function to execute. This can either take in no arguments or a `config` dict.
  - train_loop_config – Configurations to pass into `train_loop_per_worker` if it accepts an argument.
  - tensorflow_config – Configuration for setting up the TensorFlow backend. If set to None, use the default configuration. This replaces the `backend_config` arg of `DataParallelTrainer`.
  - scaling_config – Configuration for how to scale data parallel training.
  - dataset_config – Configuration for dataset ingest.
  - run_config – Configuration for the execution of the training run.
  - datasets – Any Datasets to use for training. Use the key "train" to denote which dataset is the training dataset.
  - resume_from_checkpoint – A checkpoint to resume training from.
  - metadata – Dict that should be made available via `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()` for checkpoints saved from this Trainer. Must be JSON-serializable.
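Since `metadata` must be JSON-serializable, one way to catch mistakes before constructing the Trainer is to round-trip the dict through the standard-library `json` module. This is a minimal sketch; `check_metadata` is a hypothetical helper, not part of Ray.

```python
import json
from typing import Any, Dict


def check_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: fail early if metadata is not JSON-serializable."""
    try:
        # Round-trip through JSON; tuples come back as lists, for example.
        return json.loads(json.dumps(metadata))
    except (TypeError, ValueError) as err:
        raise ValueError(f"metadata must be JSON-serializable: {err}") from err


print(check_metadata({"run": "demo", "features": ("x",)}))
# {'run': 'demo', 'features': ['x']}

try:
    check_metadata({"created": object()})
except ValueError as err:
    print("rejected:", type(err).__name__)  # rejected: ValueError
```

The round-trip also shows what a reader of the metadata would get back, for example lists where tuples were passed in.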
PublicAPI (beta): This API is in beta and may change before becoming stable.
- Methods:
  - as_trainable – Converts self to a `tune.Trainable` class.
  - can_restore – Checks whether a given directory contains a restorable Train experiment.
  - fit – Runs training.
  - get_dataset_config – Returns a copy of this Trainer's final dataset configs.
  - restore – Restores a DataParallelTrainer from a previously interrupted/failed run.
  - setup – Called during fit() to perform initial setup on the Trainer.