Source code for ray.train.lightgbm._lightgbm_utils

import tempfile
from abc import abstractmethod
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from lightgbm.basic import Booster
from lightgbm.callback import CallbackEnv

import ray.train
from ray.train import Checkpoint
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    import pandas as pd


[docs] @PublicAPI(stability="alpha") def normalize_pandas_for_lightgbm(df: "pd.DataFrame") -> "pd.DataFrame": """Map Arrow-backed pandas dtypes to NumPy-nullable equivalents. LightGBM's pandas input validation rejects Arrow-backed dtypes like ``int64[pyarrow]``. Since Ray Data 2.56, ``Dataset.to_pandas()`` preserves Arrow-backed dtypes when the source was Arrow, so callers passing the resulting frame to ``lightgbm.Dataset`` must normalize first. This helper is a faster alternative to ``df.convert_dtypes(dtype_backend="numpy_nullable")``: - It maps dtypes mechanically rather than scanning every value. - It only touches ``pd.ArrowDtype`` columns. NumPy-backed columns (e.g. from ``ray.data.from_pandas`` shards) keep their original buffers. Only numeric and boolean Arrow dtypes are remapped. Other Arrow dtypes (string, decimal, timestamp) are left as-is; LightGBM doesn't accept them as features anyway. Args: df: The pandas DataFrame to normalize. Returns: A DataFrame with Arrow-backed numeric/boolean columns replaced by NumPy-nullable equivalents. Other columns are returned unchanged. """ import pandas as pd import pyarrow as pa dtype_mapping = {} for column, dtype in df.dtypes.items(): if not isinstance(dtype, pd.ArrowDtype): continue arrow_dtype = dtype.pyarrow_dtype if pa.types.is_signed_integer(arrow_dtype): dtype_mapping[column] = f"Int{arrow_dtype.bit_width}" elif pa.types.is_unsigned_integer(arrow_dtype): dtype_mapping[column] = f"UInt{arrow_dtype.bit_width}" elif pa.types.is_floating(arrow_dtype): dtype_mapping[column] = f"Float{arrow_dtype.bit_width}" elif pa.types.is_boolean(arrow_dtype): dtype_mapping[column] = "boolean" if dtype_mapping: df = df.astype(dtype_mapping, copy=False) return df
class RayReportCallback: CHECKPOINT_NAME = "model.txt" def __init__( self, metrics: Optional[Union[str, List[str], Dict[str, str]]] = None, filename: str = CHECKPOINT_NAME, frequency: int = 0, checkpoint_at_end: bool = True, results_postprocessing_fn: Optional[ Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]] ] = None, ): if isinstance(metrics, str): metrics = [metrics] self._metrics = metrics self._filename = filename self._frequency = frequency self._checkpoint_at_end = checkpoint_at_end self._results_postprocessing_fn = results_postprocessing_fn @classmethod def get_model( cls, checkpoint: Checkpoint, filename: str = CHECKPOINT_NAME ) -> Booster: """Retrieve the model stored in a checkpoint reported by this callback. Args: checkpoint: The checkpoint object returned by a training run. The checkpoint should be saved by an instance of this callback. filename: The filename to load the model from, which should match the filename used when creating the callback. Returns: The model loaded from the checkpoint. """ with checkpoint.as_directory() as checkpoint_path: return Booster(model_file=Path(checkpoint_path, filename).as_posix()) def _get_report_dict(self, evals_log: Dict[str, Dict[str, list]]) -> dict: result_dict = flatten_dict(evals_log, delimiter="-") if not self._metrics: report_dict = result_dict else: report_dict = {} for key in self._metrics: if isinstance(self._metrics, dict): metric = self._metrics[key] else: metric = key report_dict[key] = result_dict[metric] if self._results_postprocessing_fn: report_dict = self._results_postprocessing_fn(report_dict) return report_dict def _get_eval_result(self, env: CallbackEnv) -> dict: eval_result = {} for entry in env.evaluation_result_list: data_name, eval_name, result = entry[0:3] if len(entry) > 4: stdv = entry[4] suffix = "-mean" else: stdv = None suffix = "" if data_name not in eval_result: eval_result[data_name] = {} eval_result[data_name][eval_name + suffix] = result if stdv is not None: eval_result[data_name][eval_name + "-stdv"] = stdv return eval_result @abstractmethod def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]: """Get checkpoint from model. This method needs to be implemented by subclasses. """ raise NotImplementedError @abstractmethod def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster): """Save checkpoint and report metrics corresonding to this checkpoint. This method needs to be implemented by subclasses. """ raise NotImplementedError @abstractmethod def _report_metrics(self, report_dict: Dict): """Report Metrics. This method needs to be implemented by subclasses. """ raise NotImplementedError def __call__(self, env: CallbackEnv) -> None: eval_result = self._get_eval_result(env) report_dict = self._get_report_dict(eval_result) # Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=11, # you will checkpoint at iterations 1, 3, 5, ..., 9, and 10 (checkpoint_at_end) # (iterations count from 0) on_last_iter = env.iteration == env.end_iteration - 1 should_checkpoint_at_end = on_last_iter and self._checkpoint_at_end should_checkpoint_with_frequency = ( self._frequency != 0 and (env.iteration + 1) % self._frequency == 0 ) should_checkpoint = should_checkpoint_at_end or should_checkpoint_with_frequency if should_checkpoint: self._save_and_report_checkpoint(report_dict, env.model) else: self._report_metrics(report_dict)
[docs] @PublicAPI(stability="beta") class RayTrainReportCallback(RayReportCallback): """Creates a callback that reports metrics and checkpoints model. Args: metrics: Metrics to report. If this is a list, each item should be a metric key reported by LightGBM, and it will be reported to Ray Train/Tune under the same name. This can also be a dict of {<key-to-report>: <lightgbm-metric-key>}, which can be used to rename LightGBM default metrics. filename: Customize the saved checkpoint file type by passing a filename. Defaults to "model.txt". frequency: How often to save checkpoints, in terms of iterations. Defaults to 0 (no checkpoints are saved during training). checkpoint_at_end: Whether or not to save a checkpoint at the end of training. results_postprocessing_fn: An optional Callable that takes in the metrics dict that will be reported (after it has been flattened) and returns a modified dict. Examples -------- Reporting checkpoints and metrics to Ray Tune when running many independent LightGBM trials (without data parallelism within a trial). .. testcode:: :skipif: True import lightgbm from ray.train.lightgbm import RayTrainReportCallback config = { # ... "metric": ["binary_logloss", "binary_error"], } # Report only log loss to Tune after each validation epoch. bst = lightgbm.train( ..., callbacks=[ RayTrainReportCallback( metrics={"loss": "eval-binary_logloss"}, frequency=1 ) ], ) Loading a model from a checkpoint reported by this callback. .. testcode:: :skipif: True from ray.train.lightgbm import RayTrainReportCallback # Get a `Checkpoint` object that is saved by the callback during training. result = trainer.fit() booster = RayTrainReportCallback.get_model(result.checkpoint) """ @contextmanager def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]: if ray.train.get_context().get_world_rank() in (0, None): with tempfile.TemporaryDirectory() as temp_checkpoint_dir: model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix()) yield Checkpoint.from_directory(temp_checkpoint_dir) else: yield None def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster): with self._get_checkpoint(model=model) as checkpoint: ray.train.report(report_dict, checkpoint=checkpoint) def _report_metrics(self, report_dict: Dict): ray.train.report(report_dict)