from dataclasses import dataclass
import logging
import os
from datetime import timedelta
from typing import Optional
import ray
from ray.air.checkpoint import Checkpoint
from ray.train.backend import BackendConfig, Backend, _warn_about_bad_checkpoint_type
from ray.train.constants import DEFAULT_NCCL_SOCKET_IFNAME
from ray.train._internal.worker_group import WorkerGroup
from ray.train._internal.utils import get_address_and_port
from ray.train.torch.torch_checkpoint import TorchCheckpoint
from ray.util import PublicAPI
import torch
import torch.distributed as dist

try:
from torch.profiler import profile
except ImportError:
    profile = None

logger = logging.getLogger(__name__)


@PublicAPI(stability="beta")
@dataclass
class TorchConfig(BackendConfig):
"""Configuration for torch process group setup.
See https://pytorch.org/docs/stable/distributed.html for more info.
Args:
backend: The backend to use for training.
See ``torch.distributed.init_process_group`` for more info and
valid values.
If set to None, nccl will be used if GPUs are requested, else gloo
will be used.
init_method: The initialization method to use. Either "env"
for environment variable initialization or "tcp" for TCP
initialization. Defaults to "env".
timeout_s: Seconds for process group operations to timeout.
"""
backend: Optional[str] = None
init_method: str = "env"
    timeout_s: int = 1800

    @property
def backend_cls(self):
return _TorchBackend
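

# Usage sketch for TorchConfig (illustrative only): `train_func` below is a
# placeholder for a user-defined training loop, and the ScalingConfig values
# are arbitrary. TorchConfig is normally passed to TorchTrainer rather than
# used directly.
#
#     from ray.air.config import ScalingConfig
#     from ray.train.torch import TorchTrainer
#
#     trainer = TorchTrainer(
#         train_loop_per_worker=train_func,
#         torch_config=TorchConfig(backend="gloo", timeout_s=600),
#         scaling_config=ScalingConfig(num_workers=2),
#     )
#     result = trainer.fit()

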
def _set_nccl_network_interface():
"""Set the appropriate NCCL network interface to use."""
if "NCCL_SOCKET_IFNAME" not in os.environ:
        logger.debug(
            f"Setting NCCL_SOCKET_IFNAME to {DEFAULT_NCCL_SOCKET_IFNAME} "
            "to prioritize ethernet connection. To override this behavior, set the "
            "`NCCL_SOCKET_IFNAME` environment variable in your Ray runtime "
            "environment: "
            "`ray.init(runtime_env={'env_vars': {'NCCL_SOCKET_IFNAME': 'ens5'}})`"
        )
os.environ["NCCL_SOCKET_IFNAME"] = DEFAULT_NCCL_SOCKET_IFNAME
def _setup_torch_process_group(
backend: str,
world_rank: int,
world_size: int,
init_method: str,
timeout_s: int = 1800,
):
"""Connects the distributed PyTorch backend.
Args:
backend: The backend (nccl, gloo, etc.) to use for training.
world_rank: Rank of the current worker.
world_size: Number of workers participating in the job.
init_method: URL specifying how to initialize the process group.
timeout_s: Seconds for process group operations to timeout.
"""
if world_rank == 0:
logger.info(
f"Setting up process group for: {init_method} [rank={world_rank}, "
f"world_size={world_size}]"
)
else:
logger.debug(
f"Setting up process group for: {init_method} [rank={world_rank}, "
f"world_size={world_size}]"
)
logger.debug(f"using {backend}")
# See the `timeout` arg in https://pytorch.org/docs/master/
# distributed.html#torch.distributed.init_process_group for description of
# NCCL_ASYNC_ERROR_HANDLING. We do not use NCCL_BLOCKING_WAIT due to performance
# overhead.
if (
backend == "nccl"
and "NCCL_ASYNC_ERROR_HANDLING" not in os.environ
and "NCCL_BLOCKING_WAIT" not in os.environ
):
logger.debug(
"Setting NCCL_ASYNC_ERROR_HANDLING to fail if NCCL collective "
"communication operations are timing out. "
"To override this behavior, you can set NCCL_ASYNC_ERROR_HANDLING=0."
)
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
dist.init_process_group(
backend=backend,
init_method=init_method,
rank=world_rank,
world_size=world_size,
timeout=timedelta(seconds=timeout_s),
)
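

# Rough standalone equivalent of what this helper runs on each worker, shown
# for a single local process (the address, port, and backend are illustrative;
# in Ray Train they are derived from the worker group):
#
#     import os
#     from datetime import timedelta
#     import torch.distributed as dist
#
#     os.environ["MASTER_ADDR"] = "127.0.0.1"
#     os.environ["MASTER_PORT"] = "29500"
#     dist.init_process_group(
#         backend="gloo",
#         init_method="env://",
#         rank=0,
#         world_size=1,
#         timeout=timedelta(seconds=1800),
#     )

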
def _shutdown_torch(destroy_process_group=False):
from ray.train.torch.train_loop_utils import get_device
devices = get_device()
if not isinstance(devices, list):
devices = [devices]
if destroy_process_group:
dist.destroy_process_group()
if torch.cuda.is_available():
for device in devices:
with torch.cuda.device(device):
torch.cuda.empty_cache()
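

# Shutdown sketch: on each worker this amounts to tearing down the process
# group (when one was created) and releasing cached GPU memory, e.g. for an
# illustrative device index 0:
#
#     import torch
#     import torch.distributed as dist
#
#     dist.destroy_process_group()
#     if torch.cuda.is_available():
#         with torch.cuda.device(0):
#             torch.cuda.empty_cache()

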
def _set_torch_distributed_env_vars():
# Same env vars as in
# https://pytorch.org/docs/stable/elastic/run.html#environment-variables
from ray.air import session
from ray.train.torch.train_loop_utils import get_device
os.environ["LOCAL_RANK"] = str(session.get_local_rank())
os.environ["RANK"] = str(session.get_world_rank())
os.environ["LOCAL_WORLD_SIZE"] = str(session.get_local_world_size())
os.environ["WORLD_SIZE"] = str(session.get_world_size())
os.environ["NODE_RANK"] = str(session.get_node_rank())
# Makes sure Hugging Face Accelerate uses the correct device
device = get_device()
if isinstance(device, list):
device = device[0]
os.environ["ACCELERATE_TORCH_DEVICE"] = str(device)
class _TorchBackend(Backend):
    share_cuda_visible_devices: bool = True

    def on_start(self, worker_group: WorkerGroup, backend_config: TorchConfig):
if dist.is_available():
# Set the appropriate training backend.
if backend_config.backend is None:
if worker_group.num_gpus_per_worker > 0:
backend = "nccl"
else:
backend = "gloo"
else:
backend = backend_config.backend
if backend == "nccl":
worker_group.execute(_set_nccl_network_interface)
master_addr, master_port = worker_group.execute_single(
0, get_address_and_port
)
if backend_config.init_method == "env":
def set_env_vars(addr, port):
os.environ["MASTER_ADDR"] = addr
os.environ["MASTER_PORT"] = str(port)
worker_group.execute(set_env_vars, addr=master_addr, port=master_port)
url = "env://"
elif backend_config.init_method == "tcp":
url = f"tcp://{master_addr}:{master_port}"
else:
raise ValueError(
f"The provided init_method ("
f"{backend_config.init_method}) is not supported. Must "
f"be either 'env' or 'tcp'."
)
setup_futures = []
for i in range(len(worker_group)):
setup_futures.append(
worker_group.execute_single_async(
i,
_setup_torch_process_group,
backend=backend,
world_rank=i,
world_size=len(worker_group),
init_method=url,
timeout_s=backend_config.timeout_s,
)
)
ray.get(setup_futures)
else:
raise RuntimeError("Distributed torch is not available.")
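
    # Rendezvous URL sketch: the two init_method values above produce URLs of
    # the following shapes (address and port are illustrative):
    #
    #     env://                  workers read MASTER_ADDR / MASTER_PORT
    #     tcp://10.0.0.1:29500    address and port embedded in the URL
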
def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchConfig):
worker_group.execute(
_shutdown_torch,
destroy_process_group=len(worker_group) > 1,
        )

    def on_training_start(
self, worker_group: WorkerGroup, backend_config: BackendConfig
):
        worker_group.execute(_set_torch_distributed_env_vars)

    @classmethod
def _encode_data(cls, checkpoint: Checkpoint):
checkpoint = super()._encode_data(checkpoint)
if type(checkpoint) is Checkpoint:
_warn_about_bad_checkpoint_type(TorchCheckpoint)
checkpoint = TorchCheckpoint.from_checkpoint(checkpoint)
return checkpoint
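
    # Conversion sketch: a plain ``Checkpoint`` holding a dict is promoted to a
    # ``TorchCheckpoint`` so that torch objects are (de)serialized correctly on
    # the workers (the dict contents below are illustrative):
    #
    #     generic = Checkpoint.from_dict({"model_state_dict": {}})
    #     torch_ckpt = TorchCheckpoint.from_checkpoint(generic)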