ray.train.lightgbm.RayTrainReportCallback

class ray.train.lightgbm.RayTrainReportCallback(metrics: str | List[str] | Dict[str, str] | None = None, filename: str = 'model.txt', frequency: int = 0, checkpoint_at_end: bool = True, results_postprocessing_fn: Callable[[Dict[str, float | List[float]]], Dict[str, float]] | None = None)

Creates a callback that reports metrics to Ray Train/Tune and checkpoints the model.

Parameters:
  • metrics – Metrics to report. If this is a list, each item should be a metric key reported by LightGBM (for example, "eval-binary_logloss"), and it is reported to Ray Train/Tune under the same name. This can also be a dict of {<key-to-report>: <lightgbm-metric-key>}, which can be used to rename LightGBM's default metrics.

  • filename – Name of the file the model is saved to inside each checkpoint. Defaults to "model.txt".

  • frequency – How often to save checkpoints, in terms of iterations. Defaults to 0 (no checkpoints are saved during training).

  • checkpoint_at_end – Whether to save a checkpoint at the end of training. Defaults to True.

  • results_postprocessing_fn – An optional callable that takes the flattened metrics dict that is about to be reported and returns a modified dict, for example to reduce list-valued metrics to single floats.
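The selection and renaming rules for metrics, followed by results_postprocessing_fn, can be illustrated with a small pure-Python sketch. select_metrics and mean_of_lists are hypothetical helpers, not part of Ray. The flattened input keys of the form "<dataset>-<metric>", the behavior for None (report every metric), and the handling of a single string (treated like a one-item list) are assumptions, not documented behavior.

```python
from typing import Callable, Dict, List, Optional, Union

# A flattened metric value: a single float or a list of floats.
Metric = Union[float, List[float]]


def select_metrics(
    flat_results: Dict[str, Metric],
    metrics: Union[str, List[str], Dict[str, str], None] = None,
    postprocess: Optional[Callable[[Dict[str, Metric]], Dict[str, float]]] = None,
) -> Dict[str, Metric]:
    """Hypothetical sketch of how `metrics` selects and renames keys.

    Assumes flat_results maps "<dataset>-<metric>" keys
    (e.g. "eval-binary_logloss") to their values.
    """
    if metrics is None:
        # Assumption: no selection means every flattened metric is reported.
        report = dict(flat_results)
    elif isinstance(metrics, str):
        # Assumption: a single key behaves like a one-item list.
        report = {metrics: flat_results[metrics]}
    elif isinstance(metrics, list):
        # List form: each key is reported under the same name.
        report = {key: flat_results[key] for key in metrics}
    else:
        # Dict form: {<key-to-report>: <lightgbm-metric-key>} renames metrics.
        report = {new: flat_results[old] for new, old in metrics.items()}
    if postprocess is not None:
        # The postprocessing function sees the already-selected dict.
        report = postprocess(report)
    return report


def mean_of_lists(report: Dict[str, Metric]) -> Dict[str, float]:
    """Hypothetical postprocessing fn: average list-valued metrics to floats."""
    return {
        key: sum(value) / len(value) if isinstance(value, list) else float(value)
        for key, value in report.items()
    }


flat = {"eval-binary_logloss": 0.25, "eval-binary_error": [0.25, 0.75]}
print(select_metrics(flat, {"loss": "eval-binary_logloss"}))  # {'loss': 0.25}
print(select_metrics(flat, ["eval-binary_error"], mean_of_lists))  # {'eval-binary_error': 0.5}
```

With the real callback, such a function would be passed as RayTrainReportCallback(results_postprocessing_fn=mean_of_lists).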

Examples

Reporting checkpoints and metrics to Ray Tune when running many independent LightGBM trials (without data parallelism within a trial).

import lightgbm

from ray.train.lightgbm import RayTrainReportCallback

config = {
    # ...
    "metric": ["binary_logloss", "binary_error"],
}

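# Report only the log loss to Tune after each boosting iteration, and save a
# checkpoint every iteration (frequency=1). The "eval-" prefix in the metric
# key assumes the validation set is named "eval".
bst = lightgbm.train(
    config,
    ...,
    valid_names=["eval"],
    callbacks=[
        RayTrainReportCallback(
            metrics={"loss": "eval-binary_logloss"}, frequency=1
        )
    ],
)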
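The example above sets frequency=1, so a checkpoint is saved on every iteration. The hypothetical helper checkpoint_iterations below sketches how frequency and checkpoint_at_end might combine into a checkpoint schedule. It assumes 1-based iteration counts and that the final iteration is saved only once even when both conditions apply; it is not Ray's implementation and does not model early stopping.

```python
from typing import List


def checkpoint_iterations(
    num_boost_round: int, frequency: int = 0, checkpoint_at_end: bool = True
) -> List[int]:
    """Hypothetical helper: 1-based iterations at which a checkpoint is saved.

    Assumes a checkpoint every `frequency` iterations (0 disables periodic
    checkpoints) plus one after the final iteration if `checkpoint_at_end`
    is True, never saving the same iteration twice.
    """
    saved = []
    for iteration in range(1, num_boost_round + 1):
        periodic = frequency > 0 and iteration % frequency == 0
        at_end = checkpoint_at_end and iteration == num_boost_round
        if periodic or at_end:
            saved.append(iteration)
    return saved


print(checkpoint_iterations(10, frequency=4))  # [4, 8, 10]
print(checkpoint_iterations(10))  # [10]
print(checkpoint_iterations(10, frequency=0, checkpoint_at_end=False))  # []
```

Under these assumptions, the defaults (frequency=0, checkpoint_at_end=True) produce a single checkpoint after the last iteration.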

Loading a model from a checkpoint reported by this callback.

from ray.train.lightgbm import RayTrainReportCallback

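# `trainer` is assumed to be a Ray Train trainer (for example, a LightGBMTrainer)
# whose training loop uses this callback.
# Get the latest `Checkpoint` object saved by the callback during training.
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)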

PublicAPI (beta): This API is in beta and may change before becoming stable.

Methods

get_model

Retrieve the model stored in a checkpoint reported by this callback.

Attributes

CHECKPOINT_NAME