Source code for ray.train.xgboost.xgboost_checkpoint

import os
import tempfile
from typing import TYPE_CHECKING, Optional

import xgboost

from ray.air._internal.checkpointing import save_preprocessor_to_dir
from ray.air.checkpoint import Checkpoint
from ray.air.constants import MODEL_KEY
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="beta")
class XGBoostCheckpoint(Checkpoint):
    """A :py:class:`~ray.air.checkpoint.Checkpoint` with XGBoost-specific
    functionality.

    Create this from a generic :py:class:`~ray.air.checkpoint.Checkpoint` by calling
    ``XGBoostCheckpoint.from_checkpoint(ckpt)``.
    """

    @classmethod
    def from_model(
        cls,
        booster: xgboost.Booster,
        *,
        preprocessor: Optional["Preprocessor"] = None,
    ) -> "XGBoostCheckpoint":
        """Create a :py:class:`~ray.air.checkpoint.Checkpoint` that stores an
        XGBoost model.

        Args:
            booster: The XGBoost model to store in the checkpoint.
            preprocessor: A fitted preprocessor to be applied before inference.

        Returns:
            An :py:class:`XGBoostCheckpoint` containing the specified ``booster``.

        Examples:
            >>> import numpy as np
            >>> import ray
            >>> from ray.train.xgboost import XGBoostCheckpoint
            >>> import xgboost
            >>>
            >>> train_X = np.array([[1, 2], [3, 4]])
            >>> train_y = np.array([0, 1])
            >>>
            >>> model = xgboost.XGBClassifier().fit(train_X, train_y)
            >>> checkpoint = XGBoostCheckpoint.from_model(model.get_booster())

            You can use a :py:class:`XGBoostCheckpoint` to create an
            :py:class:`~ray.train.xgboost.XGBoostPredictor` and perform inference.

            >>> from ray.train.xgboost import XGBoostPredictor
            >>>
            >>> predictor = XGBoostPredictor.from_checkpoint(checkpoint)
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            # Serialize the model (and optionally the preprocessor) into the
            # temporary directory under the well-known MODEL_KEY filename.
            booster.save_model(os.path.join(tmpdirname, MODEL_KEY))
            if preprocessor:
                save_preprocessor_to_dir(preprocessor, tmpdirname)

            # Convert the directory-backed checkpoint into an in-memory (dict)
            # checkpoint *before* the temporary directory is deleted on exit
            # from this `with` block; otherwise the returned checkpoint would
            # reference a path that no longer exists.
            checkpoint = cls.from_directory(tmpdirname)
            ckpt_dict = checkpoint.to_dict()

            return cls.from_dict(ckpt_dict)

    def get_model(self) -> xgboost.Booster:
        """Retrieve the XGBoost model stored in this checkpoint."""
        with self.as_directory() as checkpoint_path:
            booster = xgboost.Booster()
            booster.load_model(os.path.join(checkpoint_path, MODEL_KEY))
            return booster