Source code for ray.tune.integration.xgboost
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from xgboost.core import Booster
import ray.tune
from ray.train.xgboost._xgboost_utils import RayReportCallback
from ray.tune import Checkpoint
from ray.util.annotations import Deprecated, PublicAPI
[docs]
@PublicAPI(stability="beta")
class TuneReportCheckpointCallback(RayReportCallback):
    """XGBoost callback that reports metrics and checkpoints to Ray Tune.

    Args:
        metrics: Metrics to report. When a list is given, each item is the
            metric key as reported to XGBoost, and it is reported to Tune
            under that same name. A dict of
            ``{<key-to-report>: <xgboost-metric-key>}`` may be passed instead
            to rename the default XGBoost metrics.
        filename: Name of the checkpoint file; passing a different filename
            customizes the saved checkpoint file type.
            Defaults to "model.ubj".
        frequency: How often to save checkpoints, in terms of iterations.
            Defaults to 0 (no checkpoints are saved during training).
        checkpoint_at_end: Whether or not to save a checkpoint at the end of
            training. Defaults to True.
        results_postprocessing_fn: Optional callable that receives the
            (already flattened) metrics dict about to be reported and returns
            a modified dict. For example, this can be used to average results
            across CV folds when using ``xgboost.cv``.

    Examples
    --------

    Reporting checkpoints and metrics to Ray Tune when running many
    independent xgboost trials (without data parallelism within a trial).

    .. testcode::
        :skipif: True

        import xgboost

        from ray.tune import Tuner
        from ray.tune.integration.xgboost import TuneReportCheckpointCallback

        def train_fn(config):
            # Report log loss to Ray Tune after each validation epoch.
            bst = xgboost.train(
                ...,
                callbacks=[
                    TuneReportCheckpointCallback(
                        metrics={"loss": "eval-logloss"}, frequency=1
                    )
                ],
            )

        tuner = Tuner(train_fn)
        results = tuner.fit()

    """

    def __init__(
        self,
        metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
        filename: str = RayReportCallback.CHECKPOINT_NAME,
        frequency: int = 0,
        checkpoint_at_end: bool = True,
        results_postprocessing_fn: Optional[
            Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
        ] = None,
    ):
        # All configuration is handled by the base callback; this subclass
        # only overrides the reporting/checkpointing hooks below.
        super().__init__(
            metrics=metrics,
            filename=filename,
            frequency=frequency,
            checkpoint_at_end=checkpoint_at_end,
            results_postprocessing_fn=results_postprocessing_fn,
        )

    @contextmanager
    def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
        """Save ``model`` to a temporary directory and yield a checkpoint for it.

        Because of ``@contextmanager`` this must be used in a ``with``
        statement. The temporary directory only exists while the ``with``
        body runs.

        Args:
            model: The booster to serialize under ``self._filename``.

        Yields:
            A ``Checkpoint`` constructed from the temporary directory.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = Path(tmp_dir) / self._filename
            model.save_model(model_path.as_posix())
            yield Checkpoint(tmp_dir)

    def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
        """Report ``report_dict`` to Ray Tune together with a checkpoint of ``model``."""
        # The report call happens inside the context so the temporary
        # checkpoint directory still exists while it is being reported.
        with self._get_checkpoint(model=model) as ckpt:
            ray.tune.report(report_dict, checkpoint=ckpt)

    def _report_metrics(self, report_dict: Dict):
        """Report ``report_dict`` to Ray Tune without attaching a checkpoint."""
        ray.tune.report(report_dict)


@Deprecated
class TuneReportCallback:
    """Deprecated callback; any attempt to instantiate it raises.

    Use ``ray.tune.integration.xgboost.TuneReportCheckpointCallback`` instead.
    """

    def __new__(cls: type, *args, **kwargs):
        # TODO(justinvyu): [code_removal] Remove in 2.11.
        # Raising from ``__new__`` means no instance is ever created,
        # regardless of the arguments passed.
        message = (
            "`TuneReportCallback` is deprecated. "
            "Use `ray.tune.integration.xgboost.TuneReportCheckpointCallback` instead."
        )
        raise DeprecationWarning(message)