# Source code for ray.train.torch.config

import logging
import os
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional

import torch
import torch.distributed as dist

import ray
from ray._private.accelerators.hpu import HPU_PACKAGE_AVAILABLE
from ray.train._internal.utils import get_address_and_port
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend, BackendConfig
from ray.util import PublicAPI

logger = logging.getLogger(__name__)

class TorchConfigContextManager:
    """Context manager that sets the default CUDA device on entry.

    On exit it does no cleanup and never suppresses exceptions.
    """

    def __enter__(self):
        """Make the worker's assigned device the default CUDA device.

        Does nothing when CUDA is unavailable or when the device returned by
        ``ray.train.torch.get_device()`` is not a CUDA device.
        """
        # Set default cuda device
        if torch.cuda.is_available():
            device = ray.train.torch.get_device()
            if device.type == "cuda":
                torch.cuda.set_device(device)

    def __exit__(self, type, value, traceback):
        """Return False so any exception raised inside the block propagates."""
        # Propagate exceptions if any
        return False

@PublicAPI(stability="stable")
@dataclass
class TorchConfig(BackendConfig):
    """Configuration for torch process group setup.

    See https://pytorch.org/docs/stable/distributed.html for more info.

    Args:
        backend: The backend to use for training.
            See ``torch.distributed.init_process_group`` for more info and
            valid values.
            If set to None, nccl will be used if GPUs are requested, else gloo
            will be used.
        init_method: The initialization method to use. Either "env"
            for environment variable initialization or "tcp" for TCP
            initialization. Defaults to "env".
        timeout_s: Seconds for process group operations to timeout.
    """

    backend: Optional[str] = None
    init_method: str = "env"
    timeout_s: int = 1800

    @property
    def backend_cls(self):
        """Return the ``Backend`` subclass that implements this config."""
        return _TorchBackend

    @property
    def train_func_context(self):
        """Return the context manager class associated with this config."""
        return TorchConfigContextManager
def _setup_torch_process_group(
    backend: str,
    world_rank: int,
    world_size: int,
    init_method: str,
    timeout_s: int = 1800,
):
    """Connects the distributed PyTorch backend.

    Args:
        backend: The backend (nccl, gloo, etc.) to use for training.
        world_rank: Rank of the current worker.
        world_size: Number of workers participating in the job.
        init_method: URL specifying how to initialize the process group.
        timeout_s: Seconds for process group operations to timeout.
    """
    setup_message = (
        f"Setting up process group for: {init_method} [rank={world_rank}, "
        f"world_size={world_size}]"
    )
    # Rank 0 logs at INFO; every other rank logs the same message at DEBUG.
    if world_rank == 0:
        logger.info(setup_message)
    else:
        logger.debug(setup_message)
    logger.debug(f"using {backend}")

    # See the `timeout` arg in https://pytorch.org/docs/stable/
    # distributed.html#torch.distributed.init_process_group for description of
    # NCCL_ASYNC_ERROR_HANDLING. We do not use NCCL_BLOCKING_WAIT due to performance
    # overhead.
    if (
        backend == "nccl"
        and "NCCL_ASYNC_ERROR_HANDLING" not in os.environ
        and "NCCL_BLOCKING_WAIT" not in os.environ
    ):
        logger.debug(
            "Setting NCCL_ASYNC_ERROR_HANDLING to fail if NCCL collective "
            "communication operations are timing out. "
            "To override this behavior, you can set NCCL_ASYNC_ERROR_HANDLING=0."
        )
        os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
    elif backend == "hccl" and HPU_PACKAGE_AVAILABLE:
        # These modules are not referenced by name (hence the noqa markers);
        # presumably importing them registers the HCCL backend with
        # torch.distributed -- confirm against the Habana documentation.
        import habana_frameworks.torch.core as htcore  # noqa: F401
        import habana_frameworks.torch.distributed.hccl as hpu_dist  # noqa: F401

    dist.init_process_group(
        backend=backend,
        init_method=init_method,
        rank=world_rank,
        world_size=world_size,
        timeout=timedelta(seconds=timeout_s),
    )


def _shutdown_torch(destroy_process_group=False):
    """Tear down torch state on the current worker.

    Args:
        destroy_process_group: If True, destroy the torch.distributed process
            group before clearing the CUDA caches.
    """
    from ray.air._internal.torch_utils import get_devices

    # Resolve the devices before the process group is destroyed.
    devices = get_devices()
    if destroy_process_group:
        dist.destroy_process_group()
    if torch.cuda.is_available():
        for device in devices:
            with torch.cuda.device(device):
                torch.cuda.empty_cache()


def _set_torch_distributed_env_vars():
    """Export rank and world-size info from the Ray Train context to os.environ."""
    # Same env vars as in
    # https://pytorch.org/docs/stable/elastic/run.html#environment-variables
    from ray.train.torch import get_device

    context = ray.train.get_context()
    os.environ["LOCAL_RANK"] = str(context.get_local_rank())
    os.environ["RANK"] = str(context.get_world_rank())
    os.environ["LOCAL_WORLD_SIZE"] = str(context.get_local_world_size())
    os.environ["WORLD_SIZE"] = str(context.get_world_size())
    os.environ["NODE_RANK"] = str(context.get_node_rank())

    # Makes sure Hugging Face Accelerate uses the correct device
    device = get_device()
    os.environ["ACCELERATE_TORCH_DEVICE"] = str(device)


class _TorchBackend(Backend):
    """Backend that sets up and tears down torch.distributed on a worker group."""

    share_cuda_visible_devices: bool = True

    def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig):
        """Initialize the torch process group on every worker.

        Args:
            worker_group: The workers on which to set up the process group.
            backend_config: Supplies ``backend``, ``init_method`` and
                ``timeout_s``.

        Raises:
            RuntimeError: If ``torch.distributed`` is not available.
            ValueError: If ``backend_config.init_method`` is neither "env"
                nor "tcp".
        """
        if not dist.is_available():
            raise RuntimeError("Distributed torch is not available.")

        # Set the appropriate training backend.
        if backend_config.backend is None:
            if worker_group.num_gpus_per_worker > 0:
                backend = "nccl"
            else:
                backend = "gloo"
        else:
            backend = backend_config.backend

        # Worker 0 provides the rendezvous address and port.
        master_addr, master_port = worker_group.execute_single(
            0, get_address_and_port
        )
        if backend_config.init_method == "env":

            def set_env_vars(addr, port):
                os.environ["MASTER_ADDR"] = addr
                os.environ["MASTER_PORT"] = str(port)

            worker_group.execute(set_env_vars, addr=master_addr, port=master_port)
            url = "env://"
        elif backend_config.init_method == "tcp":
            url = f"tcp://{master_addr}:{master_port}"
        else:
            raise ValueError(
                f"The provided init_method ("
                f"{backend_config.init_method}) is not supported. Must "
                f"be either 'env' or 'tcp'."
            )

        world_size = len(worker_group)
        # Launch setup on all workers asynchronously, then wait for all of them.
        setup_futures = [
            worker_group.execute_single_async(
                rank,
                _setup_torch_process_group,
                backend=backend,
                world_rank=rank,
                world_size=world_size,
                init_method=url,
                timeout_s=backend_config.timeout_s,
            )
            for rank in range(world_size)
        ]
        ray.get(setup_futures)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchConfig):
        """Run ``_shutdown_torch`` on every worker.

        The process group is destroyed only when the group has more than one
        worker.
        """
        worker_group.execute(
            _shutdown_torch,
            destroy_process_group=len(worker_group) > 1,
        )

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: BackendConfig
    ):
        """Run ``_set_torch_distributed_env_vars`` on every worker."""
        worker_group.execute(_set_torch_distributed_env_vars)