Source code for ray.train.backend
import logging
from typing import Type, TypeVar, Dict
from ray.air.checkpoint import Checkpoint
from ray.train._internal.utils import Singleton
from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import Deprecated, DeveloperAPI
from ray.widgets import make_table_html_repr
EncodedData = TypeVar("EncodedData")

logger = logging.getLogger(__name__)

# This message is reused in several places below to warn about the deprecation.
_encode_decode_deprecation_message = (
"``encode_data`` and ``decode_data`` are deprecated in favor of "
"framework-specific ``ray.air.Checkpoint`` subclasses (reported "
"using ``ray.air.session.report()``) which can implement "
"encoding and decoding logic."
)
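

# A minimal sketch (hypothetical helper, not part of this module's API) of the
# replacement recommended by the deprecation message above: instead of
# overriding ``encode_data``/``decode_data`` on a ``Backend``, the training
# function reports a ``ray.air.Checkpoint`` (ideally a framework-specific
# subclass such as ``ray.train.torch.TorchCheckpoint``, which can carry its
# own encoding/decoding logic) via ``ray.air.session.report()``.
def _example_report_checkpoint(model_state: Dict) -> None:
    from ray.air import session

    # A generic dict-backed checkpoint is used here purely for illustration;
    # framework-specific subclasses are preferred in real trainers.
    checkpoint = Checkpoint.from_dict({"model": model_state})
    session.report(metrics={"loss": 0.0}, checkpoint=checkpoint)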


def _warn_about_bad_checkpoint_type(expected_checkpoint_cls: Type[Checkpoint]):
return
# Do not print warnings in 2.1 yet.
# TODO(ml-team): Change this once we have full API parity with framework
# checkpoints. Also turn on test_torch_trainer::test_torch_bad_checkpoint_warning
# warnings.warn(
# f"You have reported a checkpoint with the `{Checkpoint}` "
# "type, but the intended checkpoint type for the Trainer "
# f"you are using is `{expected_checkpoint_cls}`. "
# "Not using the intended checkpoint type may cause "
# "exceptions or other issues, especially during "
# "serialization and deserialization. The checkpoint "
# "type will be changed automatically. "
# "This behavior may change in the future."
# )


@DeveloperAPI
class BackendConfig:
"""Parent class for configurations of training backend."""
@property
def backend_cls(self):
        return Backend

    def _repr_html_(self) -> str:
return make_table_html_repr(obj=self, title=type(self).__name__)


@DeveloperAPI
class Backend(metaclass=Singleton):
    """Singleton for distributed communication backend.

    Attributes:
share_cuda_visible_devices: If True, each worker
process will have CUDA_VISIBLE_DEVICES set as the visible device
IDs of all workers on the same node for this training instance.
If False, each worker will have CUDA_VISIBLE_DEVICES set to the
device IDs allocated by Ray for that worker.
"""
share_cuda_visible_devices: bool = False
[docs] def on_start(self, worker_group: WorkerGroup, backend_config: BackendConfig):
"""Logic for starting this backend."""
        pass

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: BackendConfig):
"""Logic for shutting down the backend."""
        pass

    def on_training_start(
self, worker_group: WorkerGroup, backend_config: BackendConfig
):
"""Logic ran right before training is started.
Session API is available at this point."""
pass
@classmethod
def _encode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
"""Temporary method until ``encode_data`` is deprecated."""
if cls.encode_data != Backend.encode_data:
raise DeprecationWarning(_encode_decode_deprecation_message)
        return checkpoint

    @classmethod
def _decode_data(cls, checkpoint: Checkpoint) -> Checkpoint:
"""Temporary method until ``decode_data`` is deprecated."""
if cls.decode_data != Backend.decode_data:
raise DeprecationWarning(_encode_decode_deprecation_message)
        return checkpoint

    @Deprecated(message=_encode_decode_deprecation_message)
@staticmethod
def encode_data(data_dict: Dict) -> EncodedData:
"""Logic to encode a data dict before sending to the driver.
This function will be called on the workers for any data that is
sent to the driver via ``session.report()``.
"""
        return data_dict

    @Deprecated(message=_encode_decode_deprecation_message)
@staticmethod
def decode_data(encoded_data: EncodedData) -> Dict:
"""Logic to decode an encoded data dict.
This function will be called on the driver after receiving the
encoded data dict from the worker.
"""
return encoded_data
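

# A minimal sketch (hypothetical names, not part of Ray) of how the two classes
# above are meant to be extended: a ``BackendConfig`` subclass carries
# backend-specific options and points ``backend_cls`` at a ``Backend`` subclass
# that implements the lifecycle hooks.
class _ExampleBackend(Backend):
    share_cuda_visible_devices = True

    def on_start(
        self, worker_group: WorkerGroup, backend_config: "_ExampleBackendConfig"
    ):
        # Driver-side setup hook; real backends typically use
        # ``worker_group.execute(...)`` here to run setup on every worker.
        logger.info("Starting example backend (option=%s).", backend_config.example_option)

    def on_shutdown(
        self, worker_group: WorkerGroup, backend_config: "_ExampleBackendConfig"
    ):
        logger.info("Shutting down example backend.")


class _ExampleBackendConfig(BackendConfig):
    example_option: str = "default"

    @property
    def backend_cls(self):
        return _ExampleBackend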