Source code for ray.tune.integration.lightgbm

import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Optional

from lightgbm import Booster

import ray.tune
from ray.train.lightgbm._lightgbm_utils import RayReportCallback
from ray.tune import Checkpoint
from ray.util.annotations import Deprecated, PublicAPI

@PublicAPI(stability="beta")
class TuneReportCheckpointCallback(RayReportCallback):
"""Creates a callback that reports metrics and checkpoints model.
Args:
metrics: Metrics to report. If this is a list,
each item should be a metric key reported by LightGBM,
and it will be reported to Ray Train/Tune under the same name.
This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
which can be used to rename LightGBM default metrics.
filename: Customize the saved checkpoint file type by passing
a filename. Defaults to "model.txt".
frequency: How often to save checkpoints, in terms of iterations.
Defaults to 0 (no checkpoints are saved during training).
checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
results_postprocessing_fn: An optional Callable that takes in
the metrics dict that will be reported (after it has been flattened)
and returns a modified dict.
Examples
--------
Reporting checkpoints and metrics to Ray Tune when running many
independent LightGBM trials (without data parallelism within a trial).
.. testcode::
:skipif: True
import lightgbm
from ray.tune.integration.lightgbm import TuneReportCheckpointCallback
config = {
# ...
"metric": ["binary_logloss", "binary_error"],
}
# Report only log loss to Tune after each validation epoch.
bst = lightgbm.train(
...,
callbacks=[
TuneReportCheckpointCallback(
metrics={"loss": "eval-binary_logloss"}, frequency=1
)
],
)
"""

    @contextmanager
    def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
        # Save the Booster into a temporary directory and yield it as a
        # directory-backed checkpoint for reporting.
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
            yield Checkpoint.from_directory(temp_checkpoint_dir)

    def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
        # Report metrics to Tune together with a freshly saved model checkpoint.
        with self._get_checkpoint(model=model) as checkpoint:
            ray.tune.report(report_dict, checkpoint=checkpoint)

    def _report_metrics(self, report_dict: Dict):
        # Report metrics to Tune without attaching a checkpoint.
        ray.tune.report(report_dict)
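

# --- Illustrative helper (a sketch, not part of the original Ray API) ---
# Shows one way to load a checkpoint produced by `TuneReportCheckpointCallback`
# back into a LightGBM Booster. The helper name and its default `filename`
# argument mirror `_get_checkpoint` above; this function is an assumption added
# for illustration and is not exported by Ray.
def load_booster_from_checkpoint(
    checkpoint: Checkpoint, filename: str = "model.txt"
) -> Booster:
    # Materialize the (possibly remote) checkpoint into a local directory and
    # load the model file written by `save_model`.
    with checkpoint.as_directory() as checkpoint_dir:
        return Booster(model_file=Path(checkpoint_dir, filename).as_posix())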


@Deprecated
class TuneReportCallback:
    def __new__(cls: type, *args, **kwargs):
        # TODO(justinvyu): [code_removal] Remove in 2.11.
        raise DeprecationWarning(
            "`TuneReportCallback` is deprecated. "
            "Use `ray.tune.integration.lightgbm.TuneReportCheckpointCallback` instead."
        )
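

# --- Usage sketch (illustrative; not part of the original module) ---
# End-to-end example of reporting metrics and checkpoints to Ray Tune with
# `TuneReportCheckpointCallback`, then restoring the best Booster. The trainable
# name, the search space, and the use of sklearn's breast_cancer dataset are
# assumptions made for this sketch; only the callback and the Ray Tune / LightGBM
# calls themselves come from the APIs shown above.
if __name__ == "__main__":
    import lightgbm
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split

    def train_breast_cancer(config: Dict):
        data, target = load_breast_cancer(return_X_y=True)
        train_x, test_x, train_y, test_y = train_test_split(
            data, target, test_size=0.25
        )
        train_set = lightgbm.Dataset(train_x, label=train_y)
        test_set = lightgbm.Dataset(test_x, label=test_y)
        # Report the renamed "loss" metric and save a checkpoint every iteration.
        lightgbm.train(
            config,
            train_set,
            valid_sets=[test_set],
            valid_names=["eval"],
            callbacks=[
                TuneReportCheckpointCallback(
                    metrics={"loss": "eval-binary_logloss"}, frequency=1
                )
            ],
        )

    tuner = ray.tune.Tuner(
        train_breast_cancer,
        param_space={
            "objective": "binary",
            "metric": ["binary_logloss", "binary_error"],
            "learning_rate": ray.tune.loguniform(1e-4, 1e-1),
        },
        tune_config=ray.tune.TuneConfig(metric="loss", mode="min", num_samples=2),
    )
    best_result = tuner.fit().get_best_result()
    # Load the best model back with the helper sketched above.
    best_booster = load_booster_from_checkpoint(best_result.checkpoint)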