Source code for ray.train.torch.xla.config
import logging
import os
import re
import shutil
import uuid
from dataclasses import dataclass
import ray
from ray.train._internal.base_worker_group import BaseWorkerGroup
from ray.train._internal.utils import get_address_and_port
from ray.train.backend import Backend
from ray.train.torch import TorchConfig
from ray.util import PublicAPI
logger = logging.getLogger(__name__)
[docs]
@PublicAPI(stability="alpha")
@dataclass
class TorchXLAConfig(TorchConfig):
    """
    Configuration for torch XLA setup.
    See https://pytorch.org/xla/release/1.13/index.html for more info.
    Currently, only "neuron_cores" accelerator (AwsNeuronXLABackend)
    is supported with xrt runtime.

    Args:
        neuron_parallel_compile: If True, workers run in Neuron
            graph-extraction mode during training (model outputs are invalid
            and should be ignored while in this mode) and the extracted graphs
            are compiled when the backend shuts down.
    """

    # Opt-in flag for Neuron parallel compilation; disabled by default.
    neuron_parallel_compile: bool = False

    @property
    def backend_cls(self):
        # Backend implementation used for this config.
        return _TorchAwsNeuronXLABackend
def _kill_xrt_server():
    """Send ``pkill -f xrt_run_server`` on the current machine.

    The exit status of ``pkill`` is ignored, so finding no matching process
    is not treated as an error. Nothing is returned.
    """
    import subprocess

    pkill_cmd = ["pkill", "-f", "xrt_run_server"]
    # check=False: a non-zero exit status (e.g. nothing matched) must not raise.
    subprocess.run(pkill_cmd, check=False)
def _set_xla_env_vars(context=None):
    """Export torchelastic-style rank/size and EFA/XLA env vars on this worker.

    Args:
        context: Object exposing ``get_local_rank``, ``get_world_rank``,
            ``get_local_world_size``, ``get_world_size`` and ``get_node_rank``.
            Defaults to ``ray.train.get_context()``, which is what Ray Train
            workers use. Passing one explicitly is mainly useful for testing.
    """
    # https://pytorch.org/docs/1.13/elastic/run.html#environment-variables
    if context is None:
        context = ray.train.get_context()

    world_rank = context.get_world_rank()
    world_size = context.get_world_size()
    local_world_size = context.get_local_world_size()

    os.environ["LOCAL_RANK"] = str(context.get_local_rank())
    os.environ["RANK"] = str(world_rank)
    os.environ["LOCAL_WORLD_SIZE"] = str(local_world_size)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["GROUP_RANK"] = str(context.get_node_rank())
    # Number of worker groups (nodes). Integer division keeps this an integer
    # string ("2", not "2.0" as true division produced). Assumes a homogeneous
    # cluster where every node runs local_world_size workers.
    os.environ["GROUP_WORLD_SIZE"] = str(world_size // local_world_size)
    os.environ["ROLE_RANK"] = str(world_rank)
    # NOTE(review): ROLE_WORLD_RANK does not appear in the torchelastic
    # variable list linked above; presumably kept for consumers expecting it.
    os.environ["ROLE_WORLD_RANK"] = str(world_rank)
    os.environ["ROLE_WORLD_SIZE"] = str(world_size)

    # EFA and XLA setup
    # https://github.com/aws/libfabric/blob/master/prov/efa/src/rxr/rxr_init.c
    # https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128.sh # noqa
    os.environ["FI_PROVIDER"] = "efa"
    os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1"
    os.environ["FI_EFA_FORK_SAFE"] = "1"
    os.environ["XLA_TRANSFER_SEED_ASYNC"] = "1"
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
def _setup_xla_torch_process_group():
    """Initialize the default ``torch.distributed`` process group on "xla".

    The torch_xla modules are imported only for their side effects; the
    ``"xla"`` backend passed to ``init_process_group`` relies on them.

    Raises:
        ImportError: If ``torch.distributed`` or ``torch_xla`` cannot be
            imported. The original import error is chained as ``__cause__``.
    """
    try:
        import torch.distributed as dist
        import torch_xla.core.xla_model as xm  # noqa F401
        import torch_xla.distributed.xla_backend  # noqa F401
    except ImportError as err:
        raise ImportError(
            "torch_xla must be installed to use torch_xla backend."
        ) from err
    # Deliberately outside the try: an ImportError raised while initializing
    # must not be misreported as a missing torch_xla installation.
    dist.init_process_group("xla")
def _set_neuron_parallel_compile_env_vars():
    """Enable Neuron graph extraction for parallel compilation on this worker.

    Note: model outputs are invalid and should be ignored while these env
    vars are set.
    """
    os.environ.update(
        {
            "NEURON_PARALLEL_COMPILE": "1",
            "NEURON_EXTRACT_GRAPHS_ONLY": "1",
            "NEURON_FALL_BACK_TO_NULL_NEFF": "1",
        }
    )
def _neuron_compile_extracted_graphs():
    """Compile previously extracted Neuron graphs.

    Only the worker whose ``LOCAL_RANK`` environment variable is ``"0"``
    compiles, i.e. one worker per node. Every other worker returns right after
    the imports succeed.

    The work directory ``/tmp/<USER>/parallel_compile_workdir/`` (``USER``
    defaults to ``no-user``) is deleted and recreated on every call. The cache
    directory is taken from ``--cache_dir`` in ``NEURON_CC_FLAGS`` when
    present. Otherwise ``None`` is passed to ``CacheUrl.get_cache_url``.

    Raises:
        ImportError: If ``libneuronxla`` cannot be imported. This is raised on
            any worker, not only local rank 0, and the original import error
            is chained as ``__cause__``.
    """
    try:
        from libneuronxla.neuron_cc_cache import CacheUrl
        from libneuronxla.neuron_parallel_compile import parallel_compile
    except ImportError as err:
        raise ImportError(
            "libneuronxla must be installed to use Neuron parallel compilation."
        ) from err

    # Only 1 worker per node should run parallel_compile()
    if os.environ.get("LOCAL_RANK") != "0":
        return

    logger.info("Compiling extracted graphs on local rank0 worker")
    parallel_compile_workdir = (
        f"/tmp/{os.environ.get('USER', 'no-user')}/parallel_compile_workdir/"
    )
    # Always start from an empty work directory.
    if os.path.exists(parallel_compile_workdir):
        shutil.rmtree(parallel_compile_workdir)
    os.makedirs(parallel_compile_workdir, exist_ok=True)

    # Users can set the cache directory using --cache_dir in NEURON_CC_FLAGS or by
    # using NEURON_COMPILE_CACHE_URL. --cache_dir takes precedence.
    explicit_cache_dir = None
    if neuron_cc_flags := os.environ.get("NEURON_CC_FLAGS"):
        # Matches "--cache_dir=<dir>" or "--cache_dir <dir>"; first occurrence wins.
        if match := re.search(r"--cache_dir[= ](\S+)", neuron_cc_flags):
            explicit_cache_dir = match.group(1)

    parallel_compile(
        parallel_compile_workdir,
        CacheUrl.get_cache_url(explicit_cache_dir),
    )
class _TorchAwsNeuronXLABackend(Backend):
    """Ray Train backend for torch XLA on AWS Neuron with the xrt runtime.

    Lifecycle:
      * ``on_start``: kill leaked xrt servers, publish MASTER_ADDR,
        MASTER_PORT and TORCHELASTIC_RUN_ID, and optionally switch workers to
        Neuron graph extraction.
      * ``on_training_start``: export rank/size and EFA/XLA env vars, then
        create the "xla" torch.distributed process group.
      * ``on_shutdown``: kill the xrt server, then optionally compile the
        extracted Neuron graphs.
    """

    # Class attribute: a uuid4 string generated when this module is imported.
    unique_run_id: str = str(uuid.uuid4())

    def on_start(self, worker_group: BaseWorkerGroup, backend_config: TorchXLAConfig):
        """Prepare all workers right before training is started."""
        # Workers that failed earlier are not shut down gracefully, so an xrt
        # server may have leaked; kill it before doing anything else.
        worker_group.execute(_kill_xrt_server)

        # Worker 0 supplies the master address and port for every worker.
        master_addr, master_port = worker_group.execute_single(
            0, get_address_and_port
        )

        def export_rendezvous_env(address, port):
            os.environ["MASTER_ADDR"] = address
            os.environ["MASTER_PORT"] = str(port)
            # Setting TORCHELASTIC_RUN_ID is what triggers the xrt server.
            # NOTE(review): self.unique_run_id is looked up when this runs on
            # the worker, not on the driver. Because it is a class attribute,
            # the value seen there depends on how the backend instance is
            # serialized (a fresh module import would mint a new uuid), so
            # confirm whether all workers are meant to share one run id.
            os.environ["TORCHELASTIC_RUN_ID"] = self.unique_run_id

        worker_group.execute(
            export_rendezvous_env, address=master_addr, port=master_port
        )

        if not backend_config.neuron_parallel_compile:
            return
        # Graph-extraction mode: model outputs are invalid while it is active.
        logger.info("Extracting graphs for Neuron parallel compilation")
        worker_group.execute(_set_neuron_parallel_compile_env_vars)

    def on_training_start(
        self, worker_group: BaseWorkerGroup, backend_config: TorchXLAConfig
    ):
        """Export XLA env vars on every worker, then init the process group.

        TODO: Current setup only supports homogeneous cluster with
         neuron_cores accelerator and xrt runtime.
        """
        # The env vars are exported on all workers before the process group
        # is created.
        for setup_fn in (_set_xla_env_vars, _setup_xla_torch_process_group):
            worker_group.execute(setup_fn)

    def on_shutdown(
        self, worker_group: BaseWorkerGroup, backend_config: TorchXLAConfig
    ):
        """Clean up right after training is finished.

        Kills the xrt server as a sanity cleanup. When
        ``neuron_parallel_compile`` is enabled, it then compiles the graphs
        extracted during training, which must happen at the end of training.
        """
        worker_group.execute(_kill_xrt_server)
        if backend_config.neuron_parallel_compile:
            worker_group.execute(_neuron_compile_extracted_graphs)