"""Source code for ray.train.backend."""
import logging
from contextlib import nullcontext
from typing import TypeVar
from ray.train._internal.utils import Singleton
from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import DeveloperAPI
from ray.widgets import make_table_html_repr
EncodedData = TypeVar("EncodedData")
logger = logging.getLogger(__name__)
@DeveloperAPI
class BackendConfig:
"""Parent class for configurations of training backend."""
@property
def backend_cls(self):
return Backend
@property
def train_func_context(self):
return nullcontext
def _repr_html_(self) -> str:
return make_table_html_repr(obj=self, title=type(self).__name__)
[docs]
@DeveloperAPI
class Backend(metaclass=Singleton):
"""Singleton for distributed communication backend.
Attributes:
share_cuda_visible_devices: If True, each worker
process will have CUDA_VISIBLE_DEVICES set as the visible device
IDs of all workers on the same node for this training instance.
If False, each worker will have CUDA_VISIBLE_DEVICES set to the
device IDs allocated by Ray for that worker.
"""
share_cuda_visible_devices: bool = False
[docs]
def on_start(self, worker_group: WorkerGroup, backend_config: BackendConfig):
"""Logic for starting this backend."""
pass
[docs]
def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
"""Logic for shutting down the backend."""
pass
[docs]
def on_training_start(
self, worker_group: WorkerGroup, backend_config: BackendConfig
):
"""Logic ran right before training is started.
Session API is available at this point."""
pass