import inspect
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union
import ray
from ray._private.thirdparty.tabulate.tabulate import tabulate
from ray.air.config import RunConfig, ScalingConfig
from ray.train import BackendConfig, Checkpoint, TrainingIterator
from ray.train._internal import session
from ray.train._internal.backend_executor import BackendExecutor, TrialInfo
from ray.train._internal.data_config import DataConfig
from ray.train._internal.session import _TrainingResult, get_session
from ray.train._internal.utils import construct_train_func
from ray.train.trainer import BaseTrainer, GenDataset
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.widgets import Template
from ray.widgets.util import repr_with_fallback
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
logger = logging.getLogger(__name__)
@DeveloperAPI
class DataParallelTrainer(BaseTrainer):
"""A Trainer for data parallel training.
You should subclass this Trainer if your Trainer follows the SPMD (single program,
multiple data) programming paradigm: you want multiple processes to run the same
function, but on different data.
This Trainer runs the function ``train_loop_per_worker`` on multiple Ray
Actors.
The ``train_loop_per_worker`` function is expected to take in either 0 or 1
arguments:
.. testcode::
def train_loop_per_worker():
...
.. testcode::
def train_loop_per_worker(config: Dict):
...
If ``train_loop_per_worker`` accepts an argument, then
``train_loop_config`` will be passed in as the argument. This is useful if you
want to tune the values in ``train_loop_config`` as hyperparameters.
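For example, the loop could read a hypothetical ``lr`` entry from
``train_loop_config`` (a minimal sketch; the key and value are purely illustrative):

.. testcode::

    def train_loop_per_worker(config: dict):
        # ``config`` is the ``train_loop_config`` dict passed to the trainer,
        # e.g. {"lr": 0.01}.
        lr = config["lr"]
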
If the ``datasets`` dict contains a training dataset (denoted by
the "train" key), then it will be split into multiple dataset
shards that can then be accessed by ``train.get_dataset_shard("train")`` inside
``train_loop_per_worker``. All the other datasets will not be split and
``train.get_dataset_shard(...)`` will return the entire Dataset.
Inside the ``train_loop_per_worker`` function, you can use any of the
:ref:`Ray Train loop methods <train-loop-api>`.
.. testcode::
from ray import train
def train_loop_per_worker():
# Report intermediate results for callbacks or logging and
# checkpoint data.
train.report(...)
# Returns the latest saved Checkpoint, if any.
train.get_checkpoint()
# Returns the Dataset shard for the given key.
train.get_dataset_shard("my_dataset")
# Returns the total number of workers executing training.
train.get_context().get_world_size()
# Returns the rank of this worker.
train.get_context().get_world_rank()
# Returns the rank of the worker on the current node.
train.get_context().get_local_rank()
Any return values from ``train_loop_per_worker`` are discarded; they are not
used or persisted anywhere.
**How do I use DataParallelTrainer or any of its subclasses?**
Example:
.. testcode::
import ray
from ray import train
from ray.train import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer
def train_loop_for_worker():
dataset_shard_for_this_worker = train.get_dataset_shard("train")
# 3 items for 3 workers, each worker gets 1 item
batches = list(dataset_shard_for_this_worker.iter_batches(batch_size=1))
assert len(batches) == 1
train_dataset = ray.data.from_items([1, 2, 3])
assert train_dataset.count() == 3
trainer = DataParallelTrainer(
train_loop_for_worker,
scaling_config=ScalingConfig(num_workers=3),
datasets={"train": train_dataset},
)
result = trainer.fit()
.. testoutput::
:hide:
...
**How do I develop on top of DataParallelTrainer?**
In many cases, using DataParallelTrainer directly is sufficient to execute
functions on multiple actors.
However, you may want to subclass ``DataParallelTrainer`` and create a custom
Trainer for the following 2 use cases:
- **Use Case 1:** You want to do data parallel training, but want to have
a predefined ``train_loop_per_worker``.
- **Use Case 2:** You want to implement a custom
:py:class:`~ray.train.backend.Backend` that automatically handles
additional setup or teardown logic on each actor, so that the users of this
new trainer do not have to implement this logic. For example, a
``TensorflowTrainer`` can be built on top of ``DataParallelTrainer``
that automatically handles setting the proper environment variables for
distributed Tensorflow on each actor.
For use case 1, you can set a predefined training loop in ``__init__``:
.. testcode::
from ray.train.data_parallel_trainer import DataParallelTrainer
class MyDataParallelTrainer(DataParallelTrainer):
def __init__(self, *args, **kwargs):
predefined_train_loop_per_worker = lambda: 1
super().__init__(predefined_train_loop_per_worker, *args, **kwargs)
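Users of this subclass then construct it without passing a training loop
(a minimal sketch, reusing the class defined above):

.. testcode::

    from ray.train import ScalingConfig

    my_trainer = MyDataParallelTrainer(
        scaling_config=ScalingConfig(num_workers=1),
    )
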
For use case 2, you can implement the ``ray.train.Backend`` and
``ray.train.BackendConfig`` interfaces:
.. testcode::
from dataclasses import dataclass
from ray.train.backend import Backend, BackendConfig
class MyBackend(Backend):
def on_start(self, worker_group, backend_config):
def set_env_var(env_var_value):
import os
os.environ["MY_ENV_VAR"] = env_var_value
worker_group.execute(set_env_var, backend_config.env_var)
@dataclass
class MyBackendConfig(BackendConfig):
env_var: str = "default_value"
def backend_cls(self):
return MyBackend
class MyTrainer(DataParallelTrainer):
def __init__(self, train_loop_per_worker, my_backend_config:
MyBackendConfig, **kwargs):
super().__init__(
train_loop_per_worker,
backend_config=my_backend_config, **kwargs)
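A trainer assembled this way sets the environment variable on every worker
before the user's training loop runs. Constructing it looks like this
(a minimal sketch, reusing the classes defined above; the loop body is
illustrative and only runs once ``trainer.fit()`` is called):

.. testcode::

    from ray.train import ScalingConfig

    def check_env_var():
        import os
        assert os.environ["MY_ENV_VAR"] == "custom_value"

    trainer = MyTrainer(
        check_env_var,
        my_backend_config=MyBackendConfig(env_var="custom_value"),
        scaling_config=ScalingConfig(num_workers=1),
    )
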
Args:
train_loop_per_worker: The training function to execute.
This can either take in no arguments or a ``config`` dict.
train_loop_config: Configurations to pass into
``train_loop_per_worker`` if it accepts an argument.
backend_config: Configuration for setting up a Backend (e.g. Torch,
Tensorflow, Horovod) on each worker to enable distributed
communication. If no Backend should be set up, then set this to None.
scaling_config: Configuration for how to scale data parallel training.
dataset_config: Configuration for dataset ingest. This is merged with the
default dataset config for the given trainer (`cls._dataset_config`).
run_config: Configuration for the execution of the training run.
datasets: Any Datasets to use for training. Use
the key "train" to denote which dataset is the training
dataset. If a ``preprocessor`` is provided and has not already been fit,
it will be fit on the training dataset. All datasets will be transformed
by the ``preprocessor`` if one is provided.
metadata: Dict that should be made available via
`train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
for checkpoints saved from this Trainer. Must be JSON-serializable.
resume_from_checkpoint: A checkpoint to resume training from.
"""
# Exposed here for testing purposes. Should never need
# to be overridden.
_backend_executor_cls: Type[BackendExecutor] = BackendExecutor
_training_iterator_cls: Type[TrainingIterator] = TrainingIterator
_scaling_config_allowed_keys = BaseTrainer._scaling_config_allowed_keys + [
"num_workers",
"resources_per_worker",
"use_gpu",
"placement_strategy",
]
# For backwards compatibility with the legacy dataset config API.
_dataset_config = None
_fields_for_tuner_param_space = BaseTrainer._fields_for_tuner_param_space + [
"train_loop_config"
]
def __init__(
self,
train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
*,
train_loop_config: Optional[Dict] = None,
backend_config: Optional[BackendConfig] = None,
scaling_config: Optional[ScalingConfig] = None,
dataset_config: Optional[DataConfig] = None,
run_config: Optional[RunConfig] = None,
datasets: Optional[Dict[str, GenDataset]] = None,
metadata: Optional[Dict[str, Any]] = None,
resume_from_checkpoint: Optional[Checkpoint] = None,
# Deprecated.
preprocessor: Optional["Preprocessor"] = None,
):
self._train_loop_per_worker = train_loop_per_worker
self._train_loop_config = train_loop_config
if dataset_config is None:
dataset_config = DataConfig()
if not isinstance(dataset_config, DataConfig):
raise ValueError(
"`dataset_config` must be an instance of ray.train.DataConfig, "
f"was: {dataset_config}"
)
self._data_config = dataset_config
backend_config = (
backend_config if backend_config is not None else BackendConfig()
)
self._backend_config = backend_config
super(DataParallelTrainer, self).__init__(
scaling_config=scaling_config,
run_config=run_config,
datasets=datasets,
metadata=metadata,
preprocessor=preprocessor,
resume_from_checkpoint=resume_from_checkpoint,
)
@PublicAPI(stability="beta")
@classmethod
def restore(
cls: Type["DataParallelTrainer"],
path: str,
train_loop_per_worker: Optional[
Union[Callable[[], None], Callable[[Dict], None]]
] = None,
train_loop_config: Optional[Dict] = None,
**kwargs,
) -> "DataParallelTrainer":
"""Restores a DataParallelTrainer from a previously interrupted/failed run.
Args:
train_loop_per_worker: Optionally re-specified train loop function.
This should be used to re-specify a function that is not
restorable in a new Ray cluster (e.g., it holds onto outdated
object references). This should be the same training loop
that was passed to the original trainer constructor.
train_loop_config: Optionally re-specified train config.
This should similarly be used if the original `train_loop_config`
contained outdated object references, and it should not be modified
from what was originally passed in.
See :meth:`BaseTrainer.restore() <ray.train.trainer.BaseTrainer.restore>`
for descriptions of the other arguments.
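For example, an interrupted run stored under ``~/ray_results/my_experiment``
(a hypothetical path) could be resumed as follows (sketch, not executed):

.. code-block:: python

    from ray.train.data_parallel_trainer import DataParallelTrainer

    restored_trainer = DataParallelTrainer.restore(
        path="~/ray_results/my_experiment",
        # Re-specify the training loop only if the original one cannot be
        # restored as-is (e.g., it captured outdated object references).
        train_loop_per_worker=train_loop_per_worker,
    )
    result = restored_trainer.fit()
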
Returns:
DataParallelTrainer: A restored instance of the `DataParallelTrainer`
subclass that is calling this method.
"""
return super(DataParallelTrainer, cls).restore(
path=path,
train_loop_per_worker=train_loop_per_worker,
train_loop_config=train_loop_config,
**kwargs,
)
def _validate_attributes(self):
super()._validate_attributes()
self._validate_train_loop_per_worker(
self._train_loop_per_worker, "train_loop_per_worker"
)
def preprocess_datasets(self) -> None:
# Evaluate all datasets.
self.datasets = {k: d() if callable(d) else d for k, d in self.datasets.items()}
self.datasets = self._data_config._legacy_preprocessing(
self.datasets, self.preprocessor
)
def _validate_train_loop_per_worker(
self, train_loop_per_worker: Callable, fn_name: str
) -> None:
num_params = len(inspect.signature(train_loop_per_worker).parameters)
if num_params > 1:
raise ValueError(
f"{fn_name} should take in 0 or 1 arguments, "
f"but it accepts {num_params} arguments instead."
)
@classmethod
def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfig:
scaling_config = super(DataParallelTrainer, cls)._validate_scaling_config(
scaling_config
)
# This validation happens after the scaling config is updated from
# its specification in the Tuner `param_space`
if not scaling_config.use_gpu and "GPU" in ray.available_resources():
logger.info(
"GPUs are detected in your Ray cluster, but GPU "
"training is not enabled for this trainer. To enable "
"GPU training, make sure to set `use_gpu` to True "
"in your scaling config."
)
if scaling_config.num_workers is None:
raise ValueError(
"You must specify the 'num_workers' in `scaling_config` as either an "
f"argument of `{cls.__name__}` or through the `param_space` of a "
"`Tuner` (if performing hyperparameter tuning)."
)
if scaling_config.num_workers <= 0:
raise ValueError(
"'num_workers' in `scaling_config` must be a positive "
f"integer. Received {scaling_config.num_workers}"
)
return scaling_config
def _report(self, training_iterator: TrainingIterator) -> None:
for results in training_iterator:
# TODO(ml-team): add ability to report results from multiple workers.
first_worker_result = results[0]
assert all(isinstance(result, _TrainingResult) for result in results)
tune_session = get_session()
# Check if any workers reported a checkpoint.
# If so, report a checkpoint pointing to the persisted location
# to Tune for book-keeping.
# NOTE: This removes the requirement that any particular worker
# (e.g., the global rank 0 worker) must report a checkpoint.
# All workers reported a checkpoint to the same fs path, so there's
# no need to report multiple checkpoints to Tune.
worker_checkpoints = [
result.checkpoint for result in results if result.checkpoint is not None
]
at_least_one_reported_checkpoint = len(worker_checkpoints) > 0
if at_least_one_reported_checkpoint:
# Update the coordinator's checkpoint index to the latest.
# This is what keeps the checkpoint index in line with the workers.
tune_session.storage._update_checkpoint_index(
first_worker_result.metrics
)
# Make sure that all workers uploaded to the same location.
assert all(
checkpoint.path == tune_session.storage.checkpoint_fs_path
for checkpoint in worker_checkpoints
)
checkpoint = (
Checkpoint(
filesystem=tune_session.storage.storage_filesystem,
path=tune_session.storage.checkpoint_fs_path,
)
if at_least_one_reported_checkpoint
else None
)
tracked_training_result = _TrainingResult(
checkpoint=checkpoint,
metrics=first_worker_result.metrics,
)
logger.debug(
"Report (metrics, checkpoint) to the Tune session:\n"
f" metrics={tracked_training_result.metrics}\n"
f" checkpoint={tracked_training_result.checkpoint}"
)
# Report the metrics and checkpoint to Tune.
tune_session._report_training_result(tracked_training_result)
def training_loop(self) -> None:
scaling_config = self._validate_scaling_config(self.scaling_config)
train_loop_per_worker = construct_train_func(
self._train_loop_per_worker,
self._train_loop_config,
fn_arg_name="train_loop_per_worker",
discard_returns=True,
)
additional_resources_per_worker = scaling_config.additional_resources_per_worker
trial_info = TrialInfo(
name=session.get_trial_name(),
id=session.get_trial_id(),
resources=session.get_trial_resources(),
logdir=session.get_trial_dir(),
driver_ip=ray.util.get_node_ip_address(),
experiment_name=session.get_experiment_name(),
)
backend_executor = self._backend_executor_cls(
backend_config=self._backend_config,
trial_info=trial_info,
num_workers=scaling_config.num_workers,
num_cpus_per_worker=scaling_config.num_cpus_per_worker,
num_gpus_per_worker=scaling_config.num_gpus_per_worker,
additional_resources_per_worker=additional_resources_per_worker,
max_retries=0,
)
# Start the remote actors.
backend_executor.start()
training_iterator = self._training_iterator_cls(
backend_executor=backend_executor,
backend_config=self._backend_config,
train_func=train_loop_per_worker,
datasets=self.datasets,
metadata=self.metadata,
data_config=self._data_config,
checkpoint=self.starting_checkpoint,
)
self._report(training_iterator)
# Shutdown workers.
backend_executor.shutdown()
def get_dataset_config(self) -> DataConfig:
"""Returns a copy of this Trainer's final dataset configs.
Returns:
The merged default + user-supplied dataset config.
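Example (a minimal sketch; the trivial loop and one-item dataset are
only illustrative):

.. testcode::

    import ray
    from ray.train import DataConfig, ScalingConfig
    from ray.train.data_parallel_trainer import DataParallelTrainer

    def train_loop_per_worker():
        pass

    trainer = DataParallelTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=1),
        datasets={"train": ray.data.from_items([1])},
        dataset_config=DataConfig(datasets_to_split=["train"]),
    )

    # The returned object is the ray.train.DataConfig in effect for this run.
    assert isinstance(trainer.get_dataset_config(), DataConfig)
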
"""
return self._data_config
@repr_with_fallback(["ipywidgets", "8"])
def _repr_mimebundle_(self, **kwargs):
"""Returns a mimebundle with an ipywidget repr and a simple text repr.
Depending on the frontend where the data is being displayed,
different mimetypes will be used from this bundle.
See https://ipython.readthedocs.io/en/stable/config/integrating.html
for information about this method, and
https://ipywidgets.readthedocs.io/en/latest/embedding.html
for more information about the jupyter widget mimetype.
Returns:
A mimebundle containing an ipywidget repr and a simple text repr.
"""
from ipywidgets import HTML, Layout, Tab, VBox
title = HTML(f"<h2>{self.__class__.__name__}</h2>")
children = []
titles = []
if self.datasets:
children.append(self._datasets_repr_())
titles.append("Datasets")
children.append(HTML(self._data_config_repr_html_()))
titles.append("Data Config")
if self._train_loop_config:
children.append(HTML(self._train_loop_config_repr_html_()))
titles.append("Train Loop Config")
if self.scaling_config:
children.append(HTML(self.scaling_config._repr_html_()))
titles.append("Scaling Config")
if self.run_config:
children.append(HTML(self.run_config._repr_html_()))
titles.append("Run Config")
if self._backend_config:
children.append(HTML(self._backend_config._repr_html_()))
titles.append("Backend Config")
tab = Tab(children, titles=titles)
widget = VBox([title, tab], layout=Layout(width="100%"))
bundle = widget._repr_mimebundle_(**kwargs)
bundle.update(
{
"text/plain": repr(self),
}
)
return bundle
def _train_loop_config_repr_html_(self) -> str:
if self._train_loop_config:
table_data = {}
for k, v in self._train_loop_config.items():
if isinstance(v, str) or str(v).isnumeric():
table_data[k] = v
elif hasattr(v, "_repr_html_"):
table_data[k] = v._repr_html_()
else:
table_data[k] = str(v)
return Template("title_data.html.j2").render(
title="Train Loop Config",
data=Template("scrollableTable.html.j2").render(
table=tabulate(
table_data.items(),
headers=["Setting", "Value"],
showindex=False,
tablefmt="unsafehtml",
),
max_height="none",
),
)
else:
return ""
def _data_config_repr_html_(self) -> str:
# TODO make this rendering nicer.
content = [str(self._data_config)]
return Template("rendered_html_common.html.j2").render(content=content)
def _datasets_repr_(self) -> str:
from ipywidgets import HTML, Layout, VBox
content = []
if self.datasets:
for name, config in self.datasets.items():
tab = config._tab_repr_()
if tab:
content.append(
HTML(
Template("title_data.html.j2").render(
title=f"Dataset - <code>{name}</code>", data=None
)
)
)
content.append(tab)
return VBox(content, layout=Layout(width="100%"))