ray.train.tensorflow.TensorflowTrainer
- class ray.train.tensorflow.TensorflowTrainer(*args, **kwargs)
Bases: DataParallelTrainer
A Trainer for data parallel TensorFlow training.

This Trainer runs the function train_loop_per_worker on multiple Ray Actors. These actors already have the necessary TensorFlow process group configured for distributed TensorFlow training.

The train_loop_per_worker function is expected to take in either 0 or 1 arguments:

    def train_loop_per_worker(): ...

    def train_loop_per_worker(config: Dict): ...
If train_loop_per_worker accepts an argument, then train_loop_config will be passed in as the argument. This is useful if you want to tune the values in train_loop_config as hyperparameters.
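For example, a minimal sketch (the config keys num_epochs and lr below are arbitrary illustrations, not part of the API) of how train_loop_config reaches the training function:

    from ray.train import ScalingConfig
    from ray.train.tensorflow import TensorflowTrainer

    def train_loop_per_worker(config):
        # These values come from train_loop_config; the key names are arbitrary.
        num_epochs = config["num_epochs"]
        lr = config["lr"]
        ...

    trainer = TensorflowTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"num_epochs": 2, "lr": 1e-3},
        scaling_config=ScalingConfig(num_workers=2),
    )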
If the datasets dict contains a training dataset (denoted by the "train" key), then it will be split into multiple dataset shards that can then be accessed by ray.train.get_dataset_shard("train") inside train_loop_per_worker. All the other datasets will not be split and ray.train.get_dataset_shard(...) will return the entire Dataset.
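As an illustration, a minimal sketch (the "val" dataset key below is an arbitrary choice, not required by the API) of how sharding differs between the "train" dataset and the other datasets:

    import ray
    from ray import train

    def train_loop_per_worker():
        # Sharded: each worker sees a different slice of the "train" dataset.
        train_shard = train.get_dataset_shard("train")
        # Not sharded: every worker sees the full "val" dataset.
        val_iter = train.get_dataset_shard("val")
        for batch in train_shard.iter_batches(batch_size=32):
            ...

    train_ds = ray.data.from_items([{"x": x, "y": 2 * x} for x in range(128)])
    val_ds = ray.data.from_items([{"x": x, "y": 2 * x} for x in range(16)])
    # These would be passed as TensorflowTrainer(..., datasets={"train": train_ds, "val": val_ds}).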
Inside the train_loop_per_worker function, you can use any of the Ray Train loop methods.

Warning
Ray will not automatically set any environment variables or configuration related to local parallelism / threading aside from "OMP_NUM_THREADS". If you desire greater control over TensorFlow threading, use the tf.config.threading module (e.g. tf.config.threading.set_inter_op_parallelism_threads(num_cpus)) at the beginning of your train_loop_per_worker function.

    from ray import train

    def train_loop_per_worker():
        # Report intermediate results for callbacks or logging and
        # checkpoint data.
        train.report(...)

        # Returns the latest saved Checkpoint.
        train.get_checkpoint()

        # Returns the Dataset shard for the given key.
        train.get_dataset_shard("my_dataset")

        # Returns the total number of workers executing training.
        train.get_context().get_world_size()

        # Returns the rank of this worker.
        train.get_context().get_world_rank()

        # Returns the rank of the worker on the current node.
        train.get_context().get_local_rank()
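For instance, a minimal sketch (the thread counts below are placeholders) of applying the threading configuration described in the warning above:

    import tensorflow as tf

    def train_loop_per_worker():
        # Configure TensorFlow threading before building or running the model.
        tf.config.threading.set_inter_op_parallelism_threads(2)
        tf.config.threading.set_intra_op_parallelism_threads(4)
        ...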
Any returns from the train_loop_per_worker will be discarded and not used or persisted anywhere.

To save a model to use for the TensorflowPredictor, you must save it under the "model" kwarg in the Checkpoint passed to train.report().

Example:
    import os
    import tempfile
    import tensorflow as tf

    import ray
    from ray import train
    from ray.train import Checkpoint, ScalingConfig
    from ray.train.tensorflow import TensorflowTrainer

    def build_model():
        # toy neural network : 1-layer
        return tf.keras.Sequential(
            [tf.keras.layers.Dense(
                1, activation="linear", input_shape=(1,))]
        )

    def train_loop_per_worker(config):
        dataset_shard = train.get_dataset_shard("train")
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
        with strategy.scope():
            model = build_model()
            model.compile(
                optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

        tf_dataset = dataset_shard.to_tf(
            feature_columns="x",
            label_columns="y",
            batch_size=1
        )
        for epoch in range(config["num_epochs"]):
            model.fit(tf_dataset)

            # Create checkpoint.
            checkpoint_dir = tempfile.mkdtemp()
            model.save_weights(
                os.path.join(checkpoint_dir, "my_checkpoint")
            )
            checkpoint = Checkpoint.from_directory(checkpoint_dir)

            train.report(
                {},
                checkpoint=checkpoint,
            )

    train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=3, use_gpu=True),
        datasets={"train": train_dataset},
        train_loop_config={"num_epochs": 2},
    )
    result = trainer.fit()
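After trainer.fit() returns, a minimal sketch (assuming the checkpoint layout produced by the example above) of reading the result and restoring the saved weights:

    # `result` is a ray.train.Result.
    print(result.metrics)           # Metrics from the last train.report() call.
    checkpoint = result.checkpoint  # Latest reported Checkpoint, or None.

    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            model = build_model()
            # Matches the "my_checkpoint" prefix written by save_weights() above.
            model.load_weights(os.path.join(checkpoint_dir, "my_checkpoint"))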
- Parameters:
    - train_loop_per_worker – The training function to execute. This can either take in no arguments or a config dict.
    - train_loop_config – Configurations to pass into train_loop_per_worker if it accepts an argument.
    - tensorflow_config – Configuration for setting up the TensorFlow backend. If set to None, use the default configuration. This replaces the backend_config arg of DataParallelTrainer.
    - scaling_config – Configuration for how to scale data parallel training.
    - dataset_config – Configuration for dataset ingest.
    - run_config – Configuration for the execution of the training run.
    - datasets – Any Datasets to use for training. Use the key "train" to denote which dataset is the training dataset.
    - resume_from_checkpoint – A checkpoint to resume training from.
    - metadata – Dict that should be made available via ray.train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable. See the sketch after this list.
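As an illustration, a minimal sketch (the data_version key below is arbitrary) of how the metadata argument surfaces inside the training loop:

    from ray import train
    from ray.train import ScalingConfig
    from ray.train.tensorflow import TensorflowTrainer

    def train_loop_per_worker():
        # Metadata passed to the Trainer is visible on every worker.
        meta = train.get_context().get_metadata()
        print(meta["data_version"])
        ...

    trainer = TensorflowTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
        metadata={"data_version": 3},  # Must be JSON-serializable.
    )

Checkpoints reported during such a run carry the same dict via checkpoint.get_metadata().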
PublicAPI (beta): This API is in beta and may change before becoming stable.
Methods

- as_trainable – Converts self to a tune.Trainable class.
- can_restore – Checks whether a given directory contains a restorable Train experiment.
- fit – Runs training.
- get_dataset_config – Returns a copy of this Trainer's final dataset configs.
- preprocess_datasets – Deprecated.
- restore – Restores a DataParallelTrainer from a previously interrupted/failed run (see the sketch below).
- setup – Called during fit() to perform initial setup on the Trainer.
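For example, a minimal sketch (experiment_path is a placeholder for the original run's storage location, and train_loop_per_worker / train_dataset are reused from the example above) of resuming an interrupted run with can_restore and restore:

    from ray.train.tensorflow import TensorflowTrainer

    experiment_path = "~/ray_results/tensorflow_trainer_run"  # placeholder path
    if TensorflowTrainer.can_restore(experiment_path):
        trainer = TensorflowTrainer.restore(
            experiment_path,
            # Re-specify objects that cannot be reloaded from disk, if needed.
            train_loop_per_worker=train_loop_per_worker,
            datasets={"train": train_dataset},
        )
        result = trainer.fit()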