Source code for ray.train.torch.xla.config
import logging
import os
import re
import shutil
import uuid
from dataclasses import dataclass

import ray
from ray.train._internal.utils import get_address_and_port
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend
from ray.train.torch import TorchConfig
from ray.util import PublicAPI

logger = logging.getLogger(__name__)

@PublicAPI(stability="alpha")
@dataclass
class TorchXLAConfig(TorchConfig):
"""
Configuration for torch XLA setup.
See https://pytorch.org/xla/release/1.13/index.html for more info.
Currently, only "neuron_cores" accelerator (AwsNeuronXLABackend)
is supported with xrt runtime.
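
    Example (illustrative sketch): launch a TorchTrainer with this config on
    Trainium (neuron_cores) workers. ``train_func`` and the scaling values are
    placeholders, not part of this module.

    .. code-block:: python

        from ray.train import ScalingConfig
        from ray.train.torch import TorchTrainer
        from ray.train.torch.xla import TorchXLAConfig

        trainer = TorchTrainer(
            train_loop_per_worker=train_func,
            torch_config=TorchXLAConfig(),
            scaling_config=ScalingConfig(
                num_workers=2, resources_per_worker={"neuron_cores": 1}
            ),
        )
        result = trainer.fit()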
"""
    # If True, run training in Neuron parallel-compile (graph extraction) mode.
    # Model outputs from such a run are invalid; the extracted graphs are
    # compiled at the end of training (see _neuron_compile_extracted_graphs).
    neuron_parallel_compile: bool = False

    @property
    def backend_cls(self):
        return _TorchAwsNeuronXLABackend

def _kill_xrt_server():
    import subprocess

    subprocess.call(["pkill", "-f", "xrt_run_server"])

def _set_xla_env_vars():
    # https://pytorch.org/docs/1.13/elastic/run.html#environment-variables
    context = ray.train.get_context()

    os.environ["LOCAL_RANK"] = str(context.get_local_rank())
    os.environ["RANK"] = str(context.get_world_rank())
    os.environ["LOCAL_WORLD_SIZE"] = str(context.get_local_world_size())
    os.environ["WORLD_SIZE"] = str(context.get_world_size())
    os.environ["GROUP_RANK"] = str(context.get_node_rank())
    # Number of worker groups (nodes); integer division keeps the value an int.
    os.environ["GROUP_WORLD_SIZE"] = str(
        context.get_world_size() // context.get_local_world_size()
    )
    os.environ["ROLE_RANK"] = str(context.get_world_rank())
    os.environ["ROLE_WORLD_RANK"] = str(context.get_world_rank())
    os.environ["ROLE_WORLD_SIZE"] = str(context.get_world_size())

    # EFA and XLA setup
    # https://github.com/aws/libfabric/blob/master/prov/efa/src/rxr/rxr_init.c
    # https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128.sh # noqa
    os.environ["FI_PROVIDER"] = "efa"
    os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1"
    os.environ["FI_EFA_FORK_SAFE"] = "1"
    os.environ["XLA_TRANSFER_SEED_ASYNC"] = "1"
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"

def _setup_xla_torch_process_group():
    try:
        import torch.distributed as dist
        import torch_xla.core.xla_model as xm  # noqa F401
        import torch_xla.distributed.xla_backend  # noqa F401

        dist.init_process_group("xla")
    except ImportError:
        raise ImportError("torch_xla must be installed to use torch_xla backend.")

# The following env vars enable Neuron graph extraction for parallel compilation.
# Note: model outputs are invalid and should be ignored while these env vars are set.
def _set_neuron_parallel_compile_env_vars():
    os.environ["NEURON_PARALLEL_COMPILE"] = "1"
    os.environ["NEURON_EXTRACT_GRAPHS_ONLY"] = "1"
    os.environ["NEURON_FALL_BACK_TO_NULL_NEFF"] = "1"

# Compile previously extracted Neuron graphs.
def _neuron_compile_extracted_graphs():
    try:
        from libneuronxla.neuron_cc_cache import CacheUrl
        from libneuronxla.neuron_parallel_compile import parallel_compile
    except ImportError:
        raise ImportError(
            "libneuronxla must be installed to use Neuron parallel compilation."
        )

    # Only 1 worker per node should run parallel_compile().
    if os.environ.get("LOCAL_RANK") == "0":
        logger.info("Compiling extracted graphs on local rank 0 worker")
        parallel_compile_workdir = (
            f"/tmp/{os.environ.get('USER', 'no-user')}/parallel_compile_workdir/"
        )
        if os.path.exists(parallel_compile_workdir):
            shutil.rmtree(parallel_compile_workdir)
        os.makedirs(parallel_compile_workdir, exist_ok=True)

        # Users can set the cache directory using --cache_dir in NEURON_CC_FLAGS
        # or by using NEURON_COMPILE_CACHE_URL. --cache_dir takes precedence.
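        # e.g. NEURON_CC_FLAGS="--cache_dir=/home/ubuntu/neuron_cache"
        # (the path above is illustrative, not a default)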
        explicit_cache_dir = None
        if neuron_cc_flags := os.environ.get("NEURON_CC_FLAGS"):
            if s := re.search(r"--cache_dir[= ](\S+)", neuron_cc_flags):
                explicit_cache_dir = s.group(1)

        parallel_compile(
            parallel_compile_workdir,
            CacheUrl.get_cache_url(explicit_cache_dir),
        )

class _TorchAwsNeuronXLABackend(Backend):
    unique_run_id: str = str(uuid.uuid4())

    def on_start(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig):
        """Logic run right before training is started."""

        # On a previous worker failure, we don't run graceful shutdown on the
        # workers, which would leak any running xrt server.
        worker_group.execute(_kill_xrt_server)

        # Get the master address and port from the first worker.
        master_addr, master_port = worker_group.execute_single(0, get_address_and_port)

        def set_env_vars(addr, port):
            os.environ["MASTER_ADDR"] = addr
            os.environ["MASTER_PORT"] = str(port)
            # To trigger the xrt server.
            os.environ["TORCHELASTIC_RUN_ID"] = self.unique_run_id

        # Set the env vars on all workers.
        worker_group.execute(set_env_vars, addr=master_addr, port=master_port)

        # Set up env vars for Neuron parallel compilation graph extraction.
        if backend_config.neuron_parallel_compile:
            logger.info("Extracting graphs for Neuron parallel compilation")
            worker_group.execute(_set_neuron_parallel_compile_env_vars)

    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: TorchXLAConfig
    ):
        """
        Configure the environment variables for the worker group
        and initialize the XLA distributed process group.

        TODO: The current setup only supports a homogeneous cluster with
        the neuron_cores accelerator and the XRT runtime.
        """
        worker_group.execute(_set_xla_env_vars)
        worker_group.execute(_setup_xla_torch_process_group)

    def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchXLAConfig):
        """
        Logic run right after training is finished.

        This is a sanity cleanup to kill the xrt server and, optionally,
        to run Neuron parallel graph compilation.
        """
        worker_group.execute(_kill_xrt_server)

        # Compile the extracted graphs. This must run at the end of training.
        if backend_config.neuron_parallel_compile:
            worker_group.execute(_neuron_compile_extracted_graphs)
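
# Illustrative sketch (not executed): a typical two-pass workflow for Neuron
# parallel compilation with this backend. ``train_func`` and the scaling values
# below are placeholders, not part of this module.
#
#     from ray.train import ScalingConfig
#     from ray.train.torch import TorchTrainer
#     from ray.train.torch.xla import TorchXLAConfig
#
#     scaling = ScalingConfig(num_workers=2, resources_per_worker={"neuron_cores": 1})
#
#     # Pass 1: extract and pre-compile Neuron graphs. Model outputs are
#     # invalid during this run and should be ignored.
#     TorchTrainer(
#         train_func,
#         torch_config=TorchXLAConfig(neuron_parallel_compile=True),
#         scaling_config=scaling,
#     ).fit()
#
#     # Pass 2: train normally, reusing the populated Neuron compilation cache.
#     TorchTrainer(
#         train_func,
#         torch_config=TorchXLAConfig(),
#         scaling_config=scaling,
#     ).fit()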