Source code for ray.train.lightgbm.config
import logging
import threading
from dataclasses import dataclass
from typing import Any, Dict, Optional

import ray
from ray.train._internal.utils import get_address_and_port
from ray.train._internal.worker_group import WorkerGroup
from ray.train.backend import Backend, BackendConfig

logger = logging.getLogger(__name__)

# Global LightGBM distributed network configuration for each worker process.
_lightgbm_network_params: Optional[Dict[str, Any]] = None
_lightgbm_network_params_lock = threading.Lock()


def get_network_params() -> Dict[str, Any]:
    """Returns the network parameters to enable LightGBM distributed training."""
    global _lightgbm_network_params
    with _lightgbm_network_params_lock:
        if not _lightgbm_network_params:
            logger.warning(
                "`ray.train.lightgbm.get_network_params` was called outside "
                "the context of a `ray.train.lightgbm.LightGBMTrainer`. "
                "The current process has no knowledge of the distributed training "
                "worker group, so this method will return an empty dict. "
                "Please call this within the training loop of a "
                "`ray.train.lightgbm.LightGBMTrainer`. "
                "If you are in fact calling this within a `LightGBMTrainer`, "
                "this is unexpected: please file a bug report to the Ray Team."
            )
            return {}
        return _lightgbm_network_params.copy()
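

# Example (not part of the original module): a minimal sketch of the intended
# call site, assuming the usual pattern of a user-defined training loop run by
# a `ray.train.lightgbm.LightGBMTrainer`. The toy dataset and the `objective`
# value are illustrative assumptions, not anything this module prescribes.
def _example_train_loop_per_worker(config: Dict[str, Any]) -> None:
    import lightgbm as lgb
    import numpy as np

    # Merge Ray Train's distributed network parameters into the regular
    # LightGBM params. Outside a `LightGBMTrainer` worker this merges an
    # empty dict, so training silently stays single-process.
    params = {"objective": "regression", **get_network_params()}
    X, y = np.random.rand(100, 4), np.random.rand(100)
    lgb.train(params, lgb.Dataset(X, label=y), num_boost_round=5)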


def _set_network_params(
    num_machines: int,
    local_listen_port: int,
    machines: str,
):
    global _lightgbm_network_params
    with _lightgbm_network_params_lock:
        assert (
            _lightgbm_network_params is None
        ), "LightGBM network params are already initialized."
        _lightgbm_network_params = dict(
            num_machines=num_machines,
            local_listen_port=local_listen_port,
            machines=machines,
        )


@dataclass
class LightGBMConfig(BackendConfig):
    """Configuration for LightGBM distributed data-parallel training setup.

    See the LightGBM docs for more information on the "network parameters"
    that Ray Train sets up for you:
    https://lightgbm.readthedocs.io/en/latest/Parameters.html#network-parameters
    """

    @property
    def backend_cls(self):
        return _LightGBMBackend
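

# Example (not part of the original module): wiring this config into a
# trainer. The `lightgbm_config=` keyword is an assumption about the
# `LightGBMTrainer` constructor; check the LightGBMTrainer API reference
# for the exact argument name.
def _example_build_trainer():
    from ray.train import ScalingConfig
    from ray.train.lightgbm import LightGBMTrainer

    return LightGBMTrainer(
        _example_train_loop_per_worker,
        lightgbm_config=LightGBMConfig(),
        scaling_config=ScalingConfig(num_workers=2),
    )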


class _LightGBMBackend(Backend):
    def on_training_start(
        self, worker_group: WorkerGroup, backend_config: LightGBMConfig
    ):
        # Collect each worker's node IP and a free port, then build the
        # comma-separated machine list that LightGBM expects.
        node_ips_and_ports = worker_group.execute(get_address_and_port)
        ports = [port for _, port in node_ips_and_ports]
        machines = ",".join(
            [f"{node_ip}:{port}" for node_ip, port in node_ips_and_ports]
        )
        num_machines = len(worker_group)
        # Push the shared machine list plus each rank's own listen port into
        # that worker's process-global network params.
        ray.get(
            [
                worker_group.execute_single_async(
                    rank, _set_network_params, num_machines, ports[rank], machines
                )
                for rank in range(len(worker_group))
            ]
        )