class ray.train.data_parallel_trainer.DataParallelTrainer(*args, **kwargs)[source]#

Bases: BaseTrainer

A Trainer for data parallel training.

You should subclass this Trainer if your Trainer follows the SPMD (single program, multiple data) programming paradigm, that is, you want multiple processes to run the same function, but on different data.

This Trainer runs the function train_loop_per_worker on multiple Ray Actors.

The train_loop_per_worker function is expected to take in either 0 or 1 arguments:

def train_loop_per_worker(): ...
def train_loop_per_worker(config: Dict): ...

If train_loop_per_worker accepts an argument, then train_loop_config will be passed in as the argument. This is useful if you want to tune the values in train_loop_config as hyperparameters.

If the datasets dict contains a training dataset (denoted by the “train” key), then it will be split into multiple dataset shards, which can be accessed with train.get_dataset_shard("train") inside train_loop_per_worker. All other datasets will not be split, and train.get_dataset_shard(...) will return the entire Dataset.
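To make the split-versus-shared behavior concrete, here is a minimal sketch that uses plain Python lists in place of Ray Datasets. The helper shard_datasets is hypothetical, and the round-robin split is an assumption chosen for simplicity; Ray's actual splitting strategy may differ.

```python
from typing import Dict, List


def shard_datasets(
    datasets: Dict[str, List], num_workers: int
) -> List[Dict[str, List]]:
    """Hypothetical helper: split "train" across workers, share the rest whole."""
    per_worker = []
    for rank in range(num_workers):
        view = {}
        for name, items in datasets.items():
            if name == "train":
                # Assumed round-robin split: worker `rank` takes every
                # num_workers-th item starting at its own rank.
                view[name] = items[rank::num_workers]
            else:
                # Every other dataset is visible in full on every worker.
                view[name] = items
        per_worker.append(view)
    return per_worker


shards = shard_datasets({"train": [1, 2, 3], "valid": [10, 20]}, num_workers=3)
```

With three items and three workers, each worker's "train" view holds one item, while every worker sees the complete "valid" list.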

Inside the train_loop_per_worker function, you can use any of the Ray Train loop methods.

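from ray import train

def train_loop_per_worker():
    # Report intermediate results for callbacks or logging and
    # checkpoint data.
    train.report({"loss": 0.5})  # placeholder metrics

    # Returns the last saved checkpoint, if any.
    train.get_checkpoint()

    # Returns the Dataset shard for the given key.
    train.get_dataset_shard("train")

    # Returns the total number of workers executing training.
    train.get_context().get_world_size()

    # Returns the rank of this worker.
    train.get_context().get_world_rank()

    # Returns the rank of the worker on the current node.
    train.get_context().get_local_rank()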

Any value returned by train_loop_per_worker is discarded; it is not used or persisted anywhere.

How do I use DataParallelTrainer or any of its subclasses?


import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_for_worker():
    dataset_shard_for_this_worker = train.get_dataset_shard("train")

    # 3 items for 3 workers, each worker gets 1 item
    batches = list(dataset_shard_for_this_worker.iter_batches(batch_size=1))
    assert len(batches) == 1

train_dataset = ray.data.from_items([1, 2, 3])
assert train_dataset.count() == 3
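trainer = DataParallelTrainer(
    train_loop_for_worker,
    # 3 workers, so each of the 3 items lands on its own worker
    scaling_config=ScalingConfig(num_workers=3),
    datasets={"train": train_dataset},
)
result = trainer.fit()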

How do I develop on top of DataParallelTrainer?

In many cases, using DataParallelTrainer directly is sufficient to execute functions on multiple actors.

However, you may want to subclass DataParallelTrainer and create a custom Trainer for the following 2 use cases:

  • Use Case 1: You want to do data parallel training, but want to have a predefined train_loop_per_worker.

  • Use Case 2: You want to implement a custom Backend that automatically handles additional setup or teardown logic on each actor, so that users of the new trainer do not have to implement this logic themselves. For example, a TensorflowTrainer can be built on top of DataParallelTrainer that automatically sets the proper environment variables for distributed TensorFlow on each actor.

For 1, you can set a predefined training loop in __init__:

from ray.train.data_parallel_trainer import DataParallelTrainer

class MyDataParallelTrainer(DataParallelTrainer):
    def __init__(self, *args, **kwargs):
        predefined_train_loop_per_worker = lambda: 1
        super().__init__(predefined_train_loop_per_worker, *args, **kwargs)

For 2, you can implement the ray.train.Backend and ray.train.BackendConfig interfaces:

from dataclasses import dataclass
from ray.train.backend import Backend, BackendConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

class MyBackend(Backend):
    def on_start(self, worker_group, backend_config):
        # Runs once the worker group has started: set an environment
        # variable on every worker.
        def set_env_var(env_var_value):
            import os
            os.environ["MY_ENV_VAR"] = env_var_value

        worker_group.execute(set_env_var, backend_config.env_var)

@dataclass
class MyBackendConfig(BackendConfig):
    env_var: str = "default_value"

    # BackendConfig exposes backend_cls as a property.
    @property
    def backend_cls(self):
        return MyBackend

class MyTrainer(DataParallelTrainer):
    def __init__(self, train_loop_per_worker,
                 my_backend_config: MyBackendConfig, **kwargs):
        super().__init__(
            train_loop_per_worker,
            backend_config=my_backend_config, **kwargs)

Parameters:

  • train_loop_per_worker – The training function to execute. This can either take in no arguments or a config dict.

  • train_loop_config – Configurations to pass into train_loop_per_worker if it accepts an argument.

  • backend_config – Configuration for setting up a Backend (e.g. Torch, Tensorflow, Horovod) on each worker to enable distributed communication. If no Backend should be set up, then set this to None.

  • scaling_config – Configuration for how to scale data parallel training.

  • dataset_config – Configuration for dataset ingest. This is merged with the default dataset config for the given trainer (cls._dataset_config).

  • run_config – Configuration for the execution of the training run.

  • datasets – Any Datasets to use for training. Use the key “train” to denote which dataset is the training dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided.

  • metadata – Dict that should be made available via train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.

  • resume_from_checkpoint – A checkpoint to resume training from.
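
Since metadata must be JSON-serializable, it can help to validate the dict before constructing the Trainer. The following is a minimal sketch using only the standard library; the helper name is_json_serializable is hypothetical and is not part of Ray.

```python
import json
from typing import Any, Dict


def is_json_serializable(metadata: Dict[str, Any]) -> bool:
    """Hypothetical helper: return True if json.dumps accepts the metadata."""
    try:
        json.dumps(metadata)
        return True
    except (TypeError, ValueError):
        # TypeError: unsupported type (for example, a set).
        # ValueError: circular reference.
        return False


ok = is_json_serializable({"model": "resnet", "epochs": 10})
not_ok = is_json_serializable({"classes": {"cat", "dog"}})
```

Here the second dict is rejected because it contains a set; converting the set to a list would make it acceptable.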

DeveloperAPI: This API may change across minor Ray releases.



Methods:

  • Converts self to a tune.Trainable class.

  • Checks whether a given directory contains a restorable Train experiment.

  • Runs training.

  • Returns a copy of this Trainer's final dataset configs.

  • Restores a DataParallelTrainer from a previously interrupted/failed run.

  • Called during fit() to perform initial setup on the Trainer.