ray.train.tensorflow.TensorflowTrainer
- class ray.train.tensorflow.TensorflowTrainer(train_loop_per_worker: Callable[[], None] | Callable[[Dict], None], *, train_loop_config: Dict | None = None, tensorflow_config: TensorflowConfig | None = None, scaling_config: ScalingConfig | None = None, dataset_config: DataConfig | None = None, run_config: RunConfig | None = None, datasets: Dict[str, Dataset | Callable[[], Dataset]] | None = None, metadata: Dict[str, Any] | None = None, resume_from_checkpoint: Checkpoint | None = None)
Bases: DataParallelTrainer

A Trainer for data parallel TensorFlow training.
At a high level, this Trainer does the following:
- Launches multiple workers as defined by the scaling_config (see the configuration sketch after this list).
- Sets up a distributed TensorFlow environment on these workers as defined by the tensorflow_config.
- Ingests the input datasets based on the dataset_config.
- Runs the input train_loop_per_worker(train_loop_config) on all workers.
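For instance, here is a minimal, hedged sketch of constructing the configuration objects from the list above; the worker count and GPU flag are illustrative only, and the placeholder training function is replaced by a real loop in the full example further down:

    from ray.train import ScalingConfig
    from ray.train.tensorflow import TensorflowConfig, TensorflowTrainer

    def train_loop_per_worker():
        # Placeholder loop; the full example below shows a real training loop.
        pass

    # Illustrative scaling: 2 CPU-only workers. Tune num_workers / use_gpu for your cluster.
    scaling_config = ScalingConfig(num_workers=2, use_gpu=False)

    # Default TensorFlow distributed setup.
    tensorflow_config = TensorflowConfig()

    trainer = TensorflowTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=scaling_config,
        tensorflow_config=tensorflow_config,
    )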
Inside the train_loop_per_worker function, you can use any of the Ray Train loop methods.

Warning

Ray will not automatically set any environment variables or configuration related to local parallelism / threading aside from “OMP_NUM_THREADS”. If you desire greater control over TensorFlow threading, use the tf.config.threading module (e.g. tf.config.threading.set_inter_op_parallelism_threads(num_cpus)) at the beginning of your train_loop_per_worker function.

    from ray import train

    def train_loop_per_worker():
        # Report intermediate results for callbacks or logging and
        # checkpoint data.
        train.report(...)

        # Returns dict of last saved checkpoint.
        train.get_checkpoint()

        # Returns the Dataset shard for the given key.
        train.get_dataset_shard("my_dataset")

        # Returns the total number of workers executing training.
        train.get_context().get_world_size()

        # Returns the rank of this worker.
        train.get_context().get_world_rank()

        # Returns the rank of the worker on the current node.
        train.get_context().get_local_rank()
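For example, a sketch of applying that threading guidance at the top of the loop; the thread counts below are placeholders, not recommendations:

    import tensorflow as tf
    from ray import train

    def train_loop_per_worker():
        # Configure TensorFlow's thread pools before any TF work runs on this worker.
        # Values are illustrative; match them to the CPUs allocated per worker.
        tf.config.threading.set_inter_op_parallelism_threads(2)
        tf.config.threading.set_intra_op_parallelism_threads(4)

        # ... rest of the training loop, e.g. train.report({...}).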
Any returns from the train_loop_per_worker will be discarded and not used or persisted anywhere.

Example:
    import os
    import tempfile
    import tensorflow as tf

    import ray
    from ray import train
    from ray.train import Checkpoint, ScalingConfig
    from ray.train.tensorflow import TensorflowTrainer

    def build_model():
        # toy neural network : 1-layer
        return tf.keras.Sequential(
            [tf.keras.layers.Dense(
                1, activation="linear", input_shape=(1,))]
        )

    def train_loop_per_worker(config):
        dataset_shard = train.get_dataset_shard("train")
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
        with strategy.scope():
            model = build_model()
            model.compile(
                optimizer="Adam", loss="mean_squared_error", metrics=["mse"])

        tf_dataset = dataset_shard.to_tf(
            feature_columns="x",
            label_columns="y",
            batch_size=1
        )
        for epoch in range(config["num_epochs"]):
            model.fit(tf_dataset)

            # Create checkpoint.
            checkpoint_dir = tempfile.mkdtemp()
            model.save_weights(
                os.path.join(checkpoint_dir, "my_checkpoint")
            )
            checkpoint = Checkpoint.from_directory(checkpoint_dir)

            train.report(
                {},
                checkpoint=checkpoint,
            )

    train_dataset = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=3, use_gpu=True),
        datasets={"train": train_dataset},
        train_loop_config={"num_epochs": 2},
    )
    result = trainer.fit()
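Continuing from this example, the returned Result can be inspected roughly as follows (a sketch that reuses build_model and the my_checkpoint weight prefix from above):

    import os

    # Metrics from the last train.report() call.
    print(result.metrics)

    # Reload the trained weights from the last reported checkpoint.
    if result.checkpoint:
        with result.checkpoint.as_directory() as checkpoint_dir:
            restored_model = build_model()
            restored_model.load_weights(
                os.path.join(checkpoint_dir, "my_checkpoint")
            )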
- Parameters:
  - train_loop_per_worker – The training function to execute on each worker. This function can either take in zero arguments or a single Dict argument which is set by defining train_loop_config. Within this function you can use any of the Ray Train Loop utilities.
  - train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters. Passing large datasets via train_loop_config is not recommended and may introduce large overhead and unknown issues with serialization and deserialization.
  - tensorflow_config – The configuration for setting up the TensorFlow Distributed backend. If set to None, a default configuration is used.
  - scaling_config – The configuration for how to scale data parallel training. num_workers determines how many Python processes are used for training, and use_gpu determines whether or not each process should use GPUs. See ScalingConfig for more info.
  - run_config – The configuration for the execution of the training run. See RunConfig for more info.
  - datasets – The Ray Datasets to ingest for training. Datasets are keyed by name ({name: dataset}). Each dataset can be accessed from within the train_loop_per_worker by calling ray.train.get_dataset_shard(name). Sharding and additional configuration can be done by passing in a dataset_config.
  - resume_from_checkpoint – A checkpoint to resume training from (see the sketch after this list).
  - metadata – Dict that should be made available via ray.train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.
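As referenced in the list above, a hedged sketch of the metadata and resume_from_checkpoint parameters; it reuses train_dataset and result from the example above, and the metadata key is invented for illustration:

    from ray import train
    from ray.train import ScalingConfig
    from ray.train.tensorflow import TensorflowTrainer

    def train_loop_per_worker(config):
        # Metadata passed to the Trainer is visible inside the training loop
        # and attached to checkpoints reported from it.
        meta = train.get_context().get_metadata()
        print(meta.get("preprocessor_name"))  # "standard_scaler" in this sketch

    trainer = TensorflowTrainer(
        train_loop_per_worker=train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=3, use_gpu=True),
        datasets={"train": train_dataset},
        train_loop_config={"num_epochs": 2},
        metadata={"preprocessor_name": "standard_scaler"},  # illustrative key/value
        resume_from_checkpoint=result.checkpoint,  # e.g. the checkpoint from a previous fit()
    )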
PublicAPI (beta): This API is in beta and may change before becoming stable.
Methods
- can_restore – [Deprecated] Checks if a Train experiment can be restored from a previously interrupted/failed run.
- fit – Launches the Ray Train controller to run training on workers.
- restore – [Deprecated] Restores a Train experiment from a previously interrupted/failed run.