Source code for ray.train.v2.api.data_parallel_trainer

import logging
from typing import Any, Callable, Dict, List, Optional, Union

import ray
from ray._private.ray_constants import env_bool
from ray._private.usage import usage_lib
from ray.air._internal.usage import tag_train_v2_trainer
from ray.train import (
    BackendConfig,
    Checkpoint,
    DataConfig,
    ScalingConfig,
    RunConfig,
    Result,
)
from ray.train.base_trainer import (
    _RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING,
    _TRAINER_RESTORE_DEPRECATION_WARNING,
)
from ray.train.constants import RAY_CHDIR_TO_TRIAL_DIR, RAY_TRAIN_ENABLE_STATE_TRACKING
from ray.train.context import _GET_METADATA_DEPRECATION_MESSAGE
from ray.train.v2._internal.callbacks import (
    AcceleratorSetupCallback,
    BackendSetupCallback,
    DatasetsSetupCallback,
    WorkingDirectorySetupCallback,
)
from ray.train.v2._internal.callbacks.datasets import GenDataset
from ray.train.v2._internal.callbacks.metrics import (
    ControllerMetricsCallback,
    WorkerMetricsCallback,
)
from ray.train.v2._internal.callbacks.state_manager import StateManagerCallback
from ray.train.v2._internal.callbacks.user_callback import UserCallbackHandler
from ray.train.v2._internal.constants import (
    DEFAULT_RUN_CONTROLLER_AS_ACTOR,
    METRICS_ENABLED_ENV_VAR,
    RUN_CONTROLLER_AS_ACTOR_ENV_VAR,
    get_env_vars_to_propagate,
)
from ray.train.v2._internal.execution.callback import RayTrainCallback
from ray.train.v2._internal.execution.context import TrainRunContext
from ray.train.v2._internal.execution.controller import TrainController
from ray.train.v2._internal.execution.failure_handling import create_failure_policy
from ray.train.v2._internal.execution.scaling_policy import create_scaling_policy
from ray.train.v2._internal.util import construct_train_func
from ray.train.v2.api.callback import UserCallback
from ray.util.annotations import Deprecated, DeveloperAPI
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

logger = logging.getLogger(__name__)


@DeveloperAPI
class DataParallelTrainer:
    """Base class for distributed data parallel training on Ray.

    This class supports the SPMD parallelization pattern, where a single
    training function is executed in parallel across multiple workers,
    and different shards of data are processed by each worker.
    """

    def __init__(
        self,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        *,
        train_loop_config: Optional[Dict] = None,
        backend_config: Optional[BackendConfig] = None,
        scaling_config: Optional[ScalingConfig] = None,
        run_config: Optional[RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        dataset_config: Optional[DataConfig] = None,
        # TODO: [Deprecated] Remove in future release
        resume_from_checkpoint: Optional[Checkpoint] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.run_config = run_config or RunConfig()
        self.train_loop_per_worker = train_loop_per_worker
        self.train_loop_config = train_loop_config
        self.scaling_config = scaling_config or ScalingConfig()
        self.backend_config = backend_config or BackendConfig()
        self.datasets = datasets or {}
        self.data_config = dataset_config or DataConfig()
        self.train_run_context = TrainRunContext(self.run_config)

        if resume_from_checkpoint is not None:
            raise DeprecationWarning(_RESUME_FROM_CHECKPOINT_DEPRECATION_WARNING)

        if metadata is not None:
            raise DeprecationWarning(_GET_METADATA_DEPRECATION_MESSAGE)

        usage_lib.record_library_usage("train")
        tag_train_v2_trainer(self)

    def fit(self) -> Result:
        """Launches the Ray Train controller to run training on workers.

        Returns:
            A Result object containing the training result.

        Raises:
            ray.train.v2.api.exceptions.TrainingFailedError: If any failures occur
                during training and the number of retries configured in
                `FailureConfig` is exhausted.
        """
        train_fn = construct_train_func(
            self.train_loop_per_worker,
            config=self.train_loop_config,
            train_func_context=self.backend_config.train_func_context,
            fn_arg_name="train_loop_per_worker",
        )

        result = self._initialize_and_run_controller(
            train_fn=train_fn,
            scaling_policy=create_scaling_policy(self.scaling_config),
            failure_policy=create_failure_policy(self.run_config.failure_config),
            train_run_context=self.train_run_context,
            callbacks=self._create_default_callbacks(),
        )

        if result.error:
            # NOTE: If the training run errored out, raise the error back to
            # the user's driver script. For example, if a worker fails and the
            # Train `FailurePolicy` runs out of retries, the controller exits
            # and the error is re-raised here.
            raise result.error

        return result

    def _create_default_callbacks(self) -> List[RayTrainCallback]:
        accelerator_setup_callback = AcceleratorSetupCallback(
            self.backend_config, self.scaling_config
        )
        backend_setup_callback = BackendSetupCallback(self.backend_config)
        datasets_setup_callback = DatasetsSetupCallback(
            datasets=self.datasets,
            data_config=self.data_config,
            scaling_config=self.scaling_config,
        )
        callbacks = [
            accelerator_setup_callback,
            backend_setup_callback,
            datasets_setup_callback,
        ]

        if env_bool(RAY_CHDIR_TO_TRIAL_DIR, True):
            working_directory_setup_callback = WorkingDirectorySetupCallback()
            callbacks.append(working_directory_setup_callback)

        if env_bool(METRICS_ENABLED_ENV_VAR, True):
            callbacks.append(ControllerMetricsCallback(self.train_run_context))
            callbacks.append(WorkerMetricsCallback(self.train_run_context))

        if env_bool(RAY_TRAIN_ENABLE_STATE_TRACKING, False):
            callbacks.append(StateManagerCallback(self.train_run_context))

        # Add internal callback that invokes all user-defined callbacks.
        user_callbacks = [
            cb for cb in self.run_config.callbacks if isinstance(cb, UserCallback)
        ]
        callbacks.append(
            UserCallbackHandler(
                user_callbacks=user_callbacks,
                train_run_context=self.train_run_context,
            )
        )

        # Append all other callbacks to the full list. This allows custom
        # workarounds built on top of internal callbacks to work.
        callbacks.extend(
            [
                cb
                for cb in self.run_config.callbacks
                if not isinstance(cb, UserCallback)
            ]
        )

        return callbacks

    def _initialize_and_run_controller(self, **controller_init_kwargs) -> Result:
        run_controller_as_actor = env_bool(
            RUN_CONTROLLER_AS_ACTOR_ENV_VAR, DEFAULT_RUN_CONTROLLER_AS_ACTOR
        )
        if run_controller_as_actor:
            # Attach the controller to the node running the driver script.
            controller_actor_cls = ray.remote(
                num_cpus=0,
                scheduling_strategy=NodeAffinitySchedulingStrategy(
                    node_id=ray.get_runtime_context().get_node_id(), soft=False
                ),
                runtime_env={"env_vars": get_env_vars_to_propagate()},
            )(TrainController)
            controller = controller_actor_cls.remote(**controller_init_kwargs)
            ray.get(controller.run.remote())
            return ray.get(controller.get_result.remote())
        else:
            controller = TrainController(**controller_init_kwargs)
            controller.run()
            return controller.get_result()

    @classmethod
    @Deprecated
    def restore(cls, *args, **kwargs):
        """[Deprecated] Restores a Train experiment from a previously
        interrupted/failed run.

        This method is deprecated and will be removed in a future release.
        """
        raise DeprecationWarning(_TRAINER_RESTORE_DEPRECATION_WARNING)

    @classmethod
    @Deprecated
    def can_restore(cls, *args, **kwargs):
        """[Deprecated] Checks if a Train experiment can be restored from
        a previously interrupted/failed run.

        This method is deprecated and will be removed in a future release.
        """
        raise DeprecationWarning(_TRAINER_RESTORE_DEPRECATION_WARNING)
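

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): the node-affinity pattern
# used by `_initialize_and_run_controller` above. A zero-CPU actor is pinned
# to the node running the driver script via `NodeAffinitySchedulingStrategy`,
# which keeps the controller colocated with the driver. The `_Pinned` class
# and the `_node_affinity_sketch` helper are hypothetical names introduced
# only for this sketch; a Ray cluster must be initialized before calling it.
# ---------------------------------------------------------------------------
def _node_affinity_sketch() -> None:
    class _Pinned:
        def node_id(self) -> str:
            # Reports the ID of the node this actor was scheduled on.
            return ray.get_runtime_context().get_node_id()

    # Wrap the plain class as a zero-CPU actor pinned to the driver's node.
    pinned_cls = ray.remote(
        num_cpus=0,
        scheduling_strategy=NodeAffinitySchedulingStrategy(
            node_id=ray.get_runtime_context().get_node_id(),
            soft=False,  # Fail scheduling instead of falling back to another node.
        ),
    )(_Pinned)

    actor = pinned_cls.remote()
    # The actor runs on the same node as the driver.
    assert ray.get(actor.node_id.remote()) == ray.get_runtime_context().get_node_id()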
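

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only, not part of the library). Assumes
# a Ray cluster reachable via `ray.init()`. In practice, users typically
# construct a framework-specific subclass such as `ray.train.torch.TorchTrainer`
# rather than this `@DeveloperAPI` base class directly. The training function,
# config keys, and run name below are made up for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ray.init()

    def _example_train_loop(config: Dict) -> None:
        # Each of the `num_workers` workers runs this function (SPMD).
        for step in range(config["num_steps"]):
            ray.train.report({"step": step, "loss": 1.0 / (step + 1)})

    trainer = DataParallelTrainer(
        _example_train_loop,
        train_loop_config={"num_steps": 10},
        scaling_config=ScalingConfig(num_workers=2),
        run_config=RunConfig(name="data_parallel_example"),
    )

    # `fit()` blocks until training finishes. If the `FailureConfig` retry
    # budget is exhausted, the stored error is re-raised here.
    result = trainer.fit()
    print(result.metrics)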