ray.train.xgboost.RayTrainReportCallback

class ray.train.xgboost.RayTrainReportCallback(*args: Any, **kwargs: Any)

Bases: TuneCallback

XGBoost callback to save checkpoints and report metrics.

Parameters:
  • metrics – Metrics to report. If this is a list, each item is a metric key reported by XGBoost (such as “eval-logloss”), and the metric is reported under the same name. This can also be a dict of {<key-to-report>: <xgboost-metric-key>}, which can be used to rename XGBoost's default metric keys.

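  • filename – Name of the model file saved in each checkpoint. The file extension determines the format XGBoost saves the model in (for example, “.ubj” for UBJSON or “.json” for JSON). Defaults to “model.ubj”.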

  • frequency – How often to save checkpoints, in terms of iterations. Defaults to 0 (no checkpoints are saved during training).

  • checkpoint_at_end – Whether or not to save a checkpoint at the end of training.

  • results_postprocessing_fn – An optional callable that takes the metrics dict about to be reported (after it has been flattened) and returns a modified dict. For example, this can be used to average results across CV folds when using xgboost.cv.
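As a rough illustration of how metrics and results_postprocessing_fn shape what gets reported, here is a pure-Python sketch. It assumes the callback flattens XGBoost's evaluation log ({dataset: {metric: [value per iteration]}}) into “<dataset>-<metric>” keys holding the latest value, and that every key is reported when metrics is not given. The helpers build_report and average_folds, and the “fold0”/“fold1” dataset names, are hypothetical and not part of Ray's API.

```python
from typing import Callable, Dict, List, Optional, Union

# XGBoost evaluation log: {dataset_name: {metric_name: [value per iteration]}}
EvalsLog = Dict[str, Dict[str, List[float]]]
Report = Dict[str, float]


def build_report(
    evals_log: EvalsLog,
    metrics: Optional[Union[List[str], Dict[str, str]]] = None,
    results_postprocessing_fn: Optional[Callable[[Report], Report]] = None,
) -> Report:
    """Hypothetical helper mimicking how the callback's arguments shape a report."""
    # Flatten to "<dataset>-<metric>" keys, keeping only the latest value.
    flat = {
        f"{dataset}-{metric}": values[-1]
        for dataset, per_metric in evals_log.items()
        for metric, values in per_metric.items()
    }
    if metrics is None:
        # Assumption: with no metrics given, every flattened key is reported.
        report = dict(flat)
    elif isinstance(metrics, dict):
        # {<key-to-report>: <xgboost-metric-key>} renames the reported keys.
        report = {key: flat[xgb_key] for key, xgb_key in metrics.items()}
    else:
        # A list reports each XGBoost key under the same name.
        report = {key: flat[key] for key in metrics}
    if results_postprocessing_fn is not None:
        report = results_postprocessing_fn(report)
    return report


def average_folds(report: Report) -> Report:
    """Example postprocessing fn: add the mean over hypothetical per-fold keys."""
    fold_values = [value for key, value in report.items() if key.startswith("fold")]
    return {**report, "loss_mean": sum(fold_values) / len(fold_values)}


log = {
    "train": {"logloss": [0.60, 0.45]},
    "eval": {"logloss": [0.65, 0.52], "error": [0.30, 0.22]},
}
print(build_report(log, metrics={"loss": "eval-logloss"}))  # {'loss': 0.52}
print(build_report(log, metrics=["eval-error"]))  # {'eval-error': 0.22}

# Hypothetical per-fold datasets, averaged by the postprocessing function.
cv_log = {"fold0": {"logloss": [0.50, 0.40]}, "fold1": {"logloss": [0.60, 0.50]}}
print(build_report(cv_log, results_postprocessing_fn=average_folds))
```

The dict form is what the Tune example uses to expose XGBoost's “eval-logloss” under the shorter name “loss”.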

Examples

Reporting checkpoints and metrics to Ray Tune when running many independent xgboost trials (without data parallelism within a trial).

import xgboost

from ray.tune import Tuner
from ray.train.xgboost import RayTrainReportCallback

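def train_fn(config):
    # Report log loss to Ray Tune after each boosting round, and save a
    # checkpoint every round (frequency=1).
    # `...` stands for the usual xgboost.train arguments (params, dtrain,
    # evals, ...). The "eval-logloss" key assumes an evals entry named "eval"
    # and the "logloss" evaluation metric.
    bst = xgboost.train(
        ...,
        callbacks=[
            RayTrainReportCallback(
                metrics={"loss": "eval-logloss"}, frequency=1
            )
        ],
    )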

tuner = Tuner(train_fn)
results = tuner.fit()
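To make frequency and checkpoint_at_end concrete, the following sketch lists the boosting rounds after which a checkpoint would be written. The exact indexing (a checkpoint after 0-indexed round i when i + 1 is a multiple of frequency) and the skipping of a duplicate final checkpoint are assumptions rather than documented behavior; checkpoint_iterations is a hypothetical helper.

```python
from typing import List


def checkpoint_iterations(
    num_boost_round: int, frequency: int, checkpoint_at_end: bool
) -> List[int]:
    """Hypothetical helper: 0-indexed boosting rounds after which a checkpoint is saved.

    Assumptions: a checkpoint follows round i whenever (i + 1) % frequency == 0
    (frequency=0 disables periodic checkpoints), and checkpoint_at_end adds one
    final checkpoint unless the last round was already checkpointed.
    """
    saved: List[int] = []
    if frequency > 0:
        saved = [i for i in range(num_boost_round) if (i + 1) % frequency == 0]
    last_round = num_boost_round - 1
    if checkpoint_at_end and num_boost_round > 0 and (not saved or saved[-1] != last_round):
        saved.append(last_round)
    return saved


print(checkpoint_iterations(10, frequency=0, checkpoint_at_end=True))   # [9]
print(checkpoint_iterations(10, frequency=3, checkpoint_at_end=True))   # [2, 5, 8, 9]
print(checkpoint_iterations(10, frequency=5, checkpoint_at_end=False))  # [4, 9]
```

Under these assumptions, frequency=1 (as in the Tune example) checkpoints every round, so the end-of-training checkpoint adds nothing new.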

Loading a model from a checkpoint reported by this callback.

from ray.train.xgboost import RayTrainReportCallback

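# `trainer` is a Ray Train trainer (for example, an XGBoostTrainer) whose
# training loop uses this callback.
# Get a `Checkpoint` object saved by the callback during training.
result = trainer.fit()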
booster = RayTrainReportCallback.get_model(result.checkpoint)

PublicAPI (beta): This API is in beta and may change before becoming stable.

Methods

  • get_model – Retrieve the model stored in a checkpoint reported by this callback.

Attributes

  • CHECKPOINT_NAME