ray.train.data_parallel_trainer.DataParallelTrainer

class ray.train.data_parallel_trainer.DataParallelTrainer(*args, **kwargs)

Bases: BaseTrainer
A Trainer for data parallel training.
You should subclass this Trainer if your workload follows the SPMD (single program, multiple data) programming paradigm: you want multiple processes to run the same function, but on different data.
This Trainer runs the function train_loop_per_worker on multiple Ray Actors.

The train_loop_per_worker function is expected to take in either 0 or 1 arguments:

def train_loop_per_worker(): ...

def train_loop_per_worker(config: Dict): ...
If train_loop_per_worker accepts an argument, then train_loop_config will be passed in as the argument. This is useful if you want to tune the values in train_loop_config as hyperparameters.
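For example, a minimal sketch of passing train_loop_config (the "lr" key, its value, and the worker count are illustrative, not part of the API):

from ray.train import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_per_worker(config):
    # The dict passed as train_loop_config arrives here as `config`.
    lr = config["lr"]
    assert lr == 0.01

trainer = DataParallelTrainer(
    train_loop_per_worker,
    train_loop_config={"lr": 0.01},
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()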
If the datasets dict contains a training dataset (denoted by the "train" key), then it will be split into multiple dataset shards that can then be accessed by train.get_dataset_shard("train") inside train_loop_per_worker. All the other datasets will not be split, and train.get_dataset_shard(...) will return the entire Dataset.
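For instance, in the following sketch (the "side_input" dataset name and the row counts are illustrative only), only the "train" dataset is divided among workers:

import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_per_worker():
    # The "train" dataset is split into per-worker shards (see above).
    train_shard = train.get_dataset_shard("train")
    # Per the paragraph above, datasets under other keys are not split,
    # so every worker sees the full "side_input" dataset.
    side_input = train.get_dataset_shard("side_input")

trainer = DataParallelTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    datasets={
        "train": ray.data.from_items([1, 2, 3, 4]),
        "side_input": ray.data.from_items([10, 20, 30]),
    },
)
result = trainer.fit()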
Inside the train_loop_per_worker function, you can use any of the Ray Train loop methods:

from ray import train

def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    train.report(...)

    # Returns the last saved checkpoint, if any.
    train.get_checkpoint()

    # Returns the Dataset shard for the given key.
    train.get_dataset_shard("my_dataset")

    # Returns the total number of workers executing training.
    train.get_context().get_world_size()

    # Returns the rank of this worker.
    train.get_context().get_world_rank()

    # Returns the rank of the worker on the current node.
    train.get_context().get_local_rank()
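A minimal sketch of reporting a metric from the loop (the loss values and worker count are illustrative only):

from ray import train
from ray.train import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_per_worker():
    for epoch in range(3):
        # In a real loop the loss would come from your model/optimizer step.
        loss = 1.0 / (epoch + 1)
        train.report({"epoch": epoch, "loss": loss})

trainer = DataParallelTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()
print(result.metrics)  # Metrics from the last train.report call.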
Any value returned from train_loop_per_worker will be discarded and not used or persisted anywhere.

How do I use DataParallelTrainer or any of its subclasses?
Example:
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_for_worker():
    dataset_shard_for_this_worker = train.get_dataset_shard("train")

    # 3 items for 3 workers, each worker gets 1 item
    batches = list(dataset_shard_for_this_worker.iter_batches(batch_size=1))
    assert len(batches) == 1

train_dataset = ray.data.from_items([1, 2, 3])
assert train_dataset.count() == 3

trainer = DataParallelTrainer(
    train_loop_for_worker,
    scaling_config=ScalingConfig(num_workers=3),
    datasets={"train": train_dataset},
)
result = trainer.fit()
How do I develop on top of DataParallelTrainer?
In many cases, using DataParallelTrainer directly is sufficient to execute functions on multiple actors.
However, you may want to subclass DataParallelTrainer and create a custom Trainer for the following two use cases:

Use Case 1: You want to do data parallel training, but want to have a predefined train_loop_per_worker.

Use Case 2: You want to implement a custom Backend that automatically handles additional setup or teardown logic on each actor, so that users of this new trainer do not have to implement this logic themselves. For example, a TensorflowTrainer can be built on top of DataParallelTrainer that automatically handles setting the proper environment variables for distributed TensorFlow on each actor.
For use case 1, you can set a predefined training loop in __init__:

from ray.train.data_parallel_trainer import DataParallelTrainer

class MyDataParallelTrainer(DataParallelTrainer):
    def __init__(self, *args, **kwargs):
        predefined_train_loop_per_worker = lambda: 1
        super().__init__(predefined_train_loop_per_worker, *args, **kwargs)
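Users of such a subclass then construct it without supplying a training loop. A minimal sketch, reusing the MyDataParallelTrainer class above (the worker count is illustrative only):

from ray.train import ScalingConfig

# No train loop is passed; the predefined loop from __init__ is always used.
trainer = MyDataParallelTrainer(scaling_config=ScalingConfig(num_workers=2))
result = trainer.fit()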
For use case 2, you can implement the ray.train.Backend and ray.train.BackendConfig interfaces:

from dataclasses import dataclass

from ray.train.backend import Backend, BackendConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

class MyBackend(Backend):
    def on_start(self, worker_group, backend_config):
        def set_env_var(env_var_value):
            import os

            os.environ["MY_ENV_VAR"] = env_var_value

        worker_group.execute(set_env_var, backend_config.env_var)

@dataclass
class MyBackendConfig(BackendConfig):
    env_var: str = "default_value"

    def backend_cls(self):
        return MyBackend

class MyTrainer(DataParallelTrainer):
    def __init__(self, train_loop_per_worker, my_backend_config: MyBackendConfig, **kwargs):
        super().__init__(
            train_loop_per_worker,
            backend_config=my_backend_config,
            **kwargs,
        )
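A user of this trainer then only supplies a training loop and the backend config. A minimal sketch, reusing the MyTrainer and MyBackendConfig classes above (the environment-variable value and worker count are illustrative only):

import os

from ray.train import ScalingConfig

def train_loop_per_worker():
    # MyBackend.on_start has already run on every worker, so the
    # environment variable is set before user code executes.
    assert os.environ["MY_ENV_VAR"] == "my_value"

trainer = MyTrainer(
    train_loop_per_worker,
    my_backend_config=MyBackendConfig(env_var="my_value"),
    scaling_config=ScalingConfig(num_workers=2),
)
result = trainer.fit()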
Parameters:

- train_loop_per_worker – The training function to execute. This can either take in no arguments or a config dict.
- train_loop_config – Configurations to pass into train_loop_per_worker if it accepts an argument.
- backend_config – Configuration for setting up a Backend (e.g. Torch, Tensorflow, Horovod) on each worker to enable distributed communication. If no Backend should be set up, then set this to None.
- scaling_config – Configuration for how to scale data parallel training.
- dataset_config – Configuration for dataset ingest. This is merged with the default dataset config for the given trainer (cls._dataset_config).
- run_config – Configuration for the execution of the training run.
- datasets – Ray Datasets to use for training and evaluation. This is a dict where the key is the name of the dataset, which can be accessed from within the train_loop_per_worker by calling train.get_dataset_shard(dataset_key). By default, all datasets are sharded equally across workers. This can be configured via dataset_config.
- metadata – Dict that should be made available via train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.
- resume_from_checkpoint – A checkpoint to resume training from.
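A minimal sketch combining several of these parameters (the run name, metadata values, metric names, and worker count are illustrative only):

import ray
from ray import train
from ray.train import RunConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_per_worker(config):
    # train_loop_config is passed in here because the loop accepts an argument.
    for step in range(config["num_steps"]):
        train.report({"step": step})

trainer = DataParallelTrainer(
    train_loop_per_worker,
    train_loop_config={"num_steps": 3},
    scaling_config=ScalingConfig(num_workers=2),
    run_config=RunConfig(name="my_data_parallel_run"),
    datasets={"train": ray.data.from_items([1, 2, 3, 4])},
    metadata={"data_version": 1},
)
result = trainer.fit()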
DeveloperAPI: This API may change across minor Ray releases.
Methods

- as_trainable – Converts self to a tune.Trainable class.
- can_restore – Checks whether a given directory contains a restorable Train experiment.
- fit – Runs training.
- get_dataset_config – Returns a copy of this Trainer's final dataset configs.
- preprocess_datasets – Deprecated.
- restore – Restores a DataParallelTrainer from a previously interrupted/failed run.
- setup – Called during fit() to perform initial setup on the Trainer.
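For example, can_restore and restore can be combined to resume an interrupted run. A minimal sketch (the experiment path is illustrative only):

from ray.train.data_parallel_trainer import DataParallelTrainer

experiment_path = "~/ray_results/my_data_parallel_run"  # Illustrative path.

if DataParallelTrainer.can_restore(experiment_path):
    # Resume the interrupted/failed run from its saved state.
    trainer = DataParallelTrainer.restore(experiment_path)
    result = trainer.fit()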