# Source code for ray.train.xgboost.xgboost_trainer

import logging
from functools import partial
from typing import Any, Callable, Dict, Optional, Union

import xgboost
from packaging.version import Version

import ray.train
from ray.train import Checkpoint
from ray.train.constants import TRAIN_DATASET_KEY
from ray.train.trainer import GenDataset
from ray.train.xgboost import RayTrainReportCallback, XGBoostConfig
from ray.train.xgboost.v2 import XGBoostTrainer as SimpleXGBoostTrainer
from ray.util.annotations import PublicAPI

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Deprecation notice for the legacy `XGBoostTrainer` API (the one that takes
# `params`, `label_column`, `num_boost_round`, and extra `xgboost.train` kwargs
# instead of a `train_loop_per_worker` function).
# NOTE(review): in the code visible here this constant is only referenced from
# commented-out `_log_deprecation_warning(...)` calls -- confirm whether other
# modules import it before removing or renaming it.
LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE = (
    "Passing in `xgboost.train` kwargs such as `params`, `num_boost_round`, "
    "`label_column`, etc. to `XGBoostTrainer` is deprecated "
    "in favor of the new API which accepts a ``train_loop_per_worker`` argument, "
    "similar to the other DataParallelTrainer APIs (ex: TorchTrainer). "
    "See this issue for more context: "
    "https://github.com/ray-project/ray/issues/50042"
)


def _xgboost_train_fn_per_worker(
    config: dict,
    label_column: str,
    num_boost_round: int,
    dataset_keys: set,
    xgboost_train_kwargs: dict,
):
    """Run the legacy-API XGBoost training loop on a single worker.

    The worker's shard of every dataset is materialized, converted to a pandas
    DataFrame and wrapped in an ``xgboost.DMatrix``; ``xgboost.train`` is then
    called with the training set plus every other dataset as evaluation sets.
    If ``ray.train.get_checkpoint()`` returns a checkpoint, training resumes
    from the model stored in it.

    Args:
        config: XGBoost parameters, passed as the first positional argument
            to ``xgboost.train``.
        label_column: Name of the label column. It is dropped from the
            features and used as the label for every dataset.
        num_boost_round: Target *total* number of boosting rounds. When
            resuming from a checkpoint, only the rounds still missing are
            trained.
        dataset_keys: Names of all datasets to fetch shards for. The entry
            equal to ``TRAIN_DATASET_KEY`` is the training set; all others
            are used for evaluation only.
        xgboost_train_kwargs: Extra keyword arguments forwarded to
            ``xgboost.train`` (for example ``callbacks``).
    """
    checkpoint = ray.train.get_checkpoint()
    starting_model = None
    remaining_iters = num_boost_round
    if checkpoint:
        starting_model = RayTrainReportCallback.get_model(checkpoint)
        starting_iter = starting_model.num_boosted_rounds()
        # Clamp at zero: a checkpointed model may already have at least as
        # many trees as the target, and a negative round count is meaningless
        # (and produced a confusing log message).
        remaining_iters = max(num_boost_round - starting_iter, 0)
        logger.info(
            f"Model loaded from checkpoint will train for "
            f"additional {remaining_iters} iterations (trees) in order "
            "to achieve the target number of iterations "
            f"({num_boost_round=})."
        )

    train_ds_iter = ray.train.get_dataset_shard(TRAIN_DATASET_KEY)
    train_df = train_ds_iter.materialize().to_pandas()

    # Sort the keys so that the evaluation sets are built in the same order on
    # every run and on every worker. `dataset_keys` is a set, whose iteration
    # order is arbitrary and not guaranteed to match across processes, while
    # the order of `evals` is observable downstream.
    eval_ds_iters = {
        k: ray.train.get_dataset_shard(k)
        for k in sorted(dataset_keys)
        if k != TRAIN_DATASET_KEY
    }
    eval_dfs = {k: d.materialize().to_pandas() for k, d in eval_ds_iters.items()}

    train_X, train_y = train_df.drop(label_column, axis=1), train_df[label_column]
    dtrain = xgboost.DMatrix(train_X, label=train_y)

    # NOTE: Include the training dataset in the evaluation datasets.
    # This allows `train-*` metrics to be calculated and reported.
    evals = [(dtrain, TRAIN_DATASET_KEY)]

    for eval_name, eval_df in eval_dfs.items():
        eval_X, eval_y = eval_df.drop(label_column, axis=1), eval_df[label_column]
        evals.append((xgboost.DMatrix(eval_X, label=eval_y), eval_name))

    # Populated by `xgboost.train`; it is not read afterwards in this function.
    evals_result = {}
    xgboost.train(
        config,
        dtrain=dtrain,
        evals=evals,
        evals_result=evals_result,
        num_boost_round=remaining_iters,
        xgb_model=starting_model,
        **xgboost_train_kwargs,
    )


@PublicAPI(stability="beta")
class XGBoostTrainer(SimpleXGBoostTrainer):
    """A Trainer for distributed data-parallel XGBoost training.

    Example
    -------

    .. testcode::

        import xgboost

        import ray.data
        import ray.train
        from ray.train.xgboost import RayTrainReportCallback, XGBoostTrainer

        def train_fn_per_worker(config: dict):
            # (Optional) Add logic to resume training state from a checkpoint.
            # ray.train.get_checkpoint()

            # 1. Get the dataset shard for the worker and convert to a `xgboost.DMatrix`
            train_ds_iter, eval_ds_iter = (
                ray.train.get_dataset_shard("train"),
                ray.train.get_dataset_shard("validation"),
            )
            train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()

            train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
            train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
            eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]

            dtrain = xgboost.DMatrix(train_X, label=train_y)
            deval = xgboost.DMatrix(eval_X, label=eval_y)

            params = {
                "tree_method": "approx",
                "objective": "reg:squarederror",
                "eta": 1e-4,
                "subsample": 0.5,
                "max_depth": 2,
            }

            # 2. Do distributed data-parallel training.
            # Ray Train sets up the necessary coordinator processes and
            # environment variables for your workers to communicate with each other.
            bst = xgboost.train(
                params,
                dtrain=dtrain,
                evals=[(deval, "validation")],
                num_boost_round=10,
                callbacks=[RayTrainReportCallback()],
            )

        train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
        eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
        trainer = XGBoostTrainer(
            train_fn_per_worker,
            datasets={"train": train_ds, "validation": eval_ds},
            scaling_config=ray.train.ScalingConfig(num_workers=4),
        )
        result = trainer.fit()
        booster = RayTrainReportCallback.get_model(result.checkpoint)

    .. testoutput::
        :hide:

        ...

    Args:
        train_loop_per_worker: The training function to execute on each worker.
            This function can either take in zero arguments or a single ``Dict``
            argument which is set by defining ``train_loop_config``.
            Within this function you can use any of the
            :ref:`Ray Train Loop utilities <train-loop-api>`.
        train_loop_config: A configuration ``Dict`` to pass in as an argument to
            ``train_loop_per_worker``.
            This is typically used for specifying hyperparameters.
        xgboost_config: The configuration for setting up the distributed xgboost
            backend. Defaults to using the "rabit" backend.
            See :class:`~ray.train.xgboost.XGBoostConfig` for more info.
        datasets: The Ray Datasets to use for training and validation.
        dataset_config: The configuration for ingesting the input ``datasets``.
            By default, all the Ray Datasets are split equally across workers.
            See :class:`~ray.train.DataConfig` for more details.
        scaling_config: The configuration for how to scale data parallel training.
            ``num_workers`` determines how many Python processes are used for
            training, and ``use_gpu`` determines whether or not each process
            should use GPUs. See :class:`~ray.train.ScalingConfig` for more info.
        run_config: The configuration for the execution of the training run.
            See :class:`~ray.train.RunConfig` for more info.
        resume_from_checkpoint: A checkpoint to resume training from.
            This checkpoint can be accessed from within ``train_loop_per_worker``
            by calling ``ray.train.get_checkpoint()``.
        metadata: Dict that should be made available via
            `ray.train.get_context().get_metadata()` and in
            `checkpoint.get_metadata()` for checkpoints saved from this Trainer.
            Must be JSON-serializable.
        label_column: [Deprecated] Name of the label column. A column with this
            name must be present in the training dataset.
        params: [Deprecated] XGBoost training parameters.
            Refer to `XGBoost documentation <https://xgboost.readthedocs.io/>`_
            for a list of possible parameters.
        num_boost_round: [Deprecated] Target number of boosting iterations
            (trees in the model).
            Note that unlike in ``xgboost.train``, this is the target number
            of trees, meaning that if you set ``num_boost_round=10`` and pass a
            model that has already been trained for 5 iterations, it will be
            trained for 5 iterations more, instead of 10 more.
        **train_kwargs: [Deprecated] Additional kwargs passed to
            ``xgboost.train()`` function.
    """

    _handles_checkpoint_freq = True
    _handles_checkpoint_at_end = True

    def __init__(
        self,
        train_loop_per_worker: Optional[
            Union[Callable[[], None], Callable[[Dict], None]]
        ] = None,
        *,
        train_loop_config: Optional[Dict] = None,
        xgboost_config: Optional[XGBoostConfig] = None,
        scaling_config: Optional[ray.train.ScalingConfig] = None,
        run_config: Optional[ray.train.RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        dataset_config: Optional[ray.train.DataConfig] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
        metadata: Optional[Dict[str, Any]] = None,
        # TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
        label_column: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        num_boost_round: Optional[int] = None,
        **train_kwargs,
    ):
        if Version(xgboost.__version__) < Version("1.7.0"):
            raise ImportError(
                "`XGBoostTrainer` requires the `xgboost` version to be >= 1.7.0. "
                'Upgrade with: `pip install -U "xgboost>=1.7"`'
            )

        # TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
        # The legacy API is selected purely by the absence of a training
        # function; in that case one is synthesized from the legacy arguments
        # and `params` becomes the config passed to it.
        legacy_api = train_loop_per_worker is None
        if legacy_api:
            train_loop_per_worker = self._get_legacy_train_fn_per_worker(
                xgboost_train_kwargs=train_kwargs,
                run_config=run_config,
                label_column=label_column,
                num_boost_round=num_boost_round,
                datasets=datasets,
            )
            train_loop_config = params or {}
        # TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
        # elif train_kwargs:
        #     _log_deprecation_warning(
        #         "Passing `xgboost.train` kwargs to `XGBoostTrainer` is deprecated. "
        #         "Please pass in a `train_loop_per_worker` function instead, "
        #         "which has full flexibility on the call to `xgboost.train(**kwargs)`. "
        #         f"{LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE}"
        #     )

        super().__init__(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config=train_loop_config,
            xgboost_config=xgboost_config,
            scaling_config=scaling_config,
            run_config=run_config,
            datasets=datasets,
            dataset_config=dataset_config,
            resume_from_checkpoint=resume_from_checkpoint,
            metadata=metadata,
        )

    def _get_legacy_train_fn_per_worker(
        self,
        xgboost_train_kwargs: Dict,
        run_config: Optional[ray.train.RunConfig],
        datasets: Optional[Dict[str, GenDataset]],
        label_column: Optional[str],
        num_boost_round: Optional[int],
    ) -> Callable[[Dict], None]:
        """Get the training function for the legacy XGBoostTrainer API.

        Args:
            xgboost_train_kwargs: Extra keyword arguments to forward to
                ``xgboost.train``. Neither this dict nor its ``callbacks``
                list is modified; copies are used internally.
            run_config: Run configuration; when given, its checkpoint config
                is used to configure the default reporting callback.
            datasets: Datasets passed to the trainer. Must contain a truthy
                entry under ``TRAIN_DATASET_KEY``.
            label_column: Name of the label column. Required.
            num_boost_round: Target number of boosting rounds. A falsy value
                (``None`` or ``0``) falls back to 10.

        Returns:
            A function taking the XGBoost parameter dict as its only argument,
            suitable for use as ``train_loop_per_worker``.

        Raises:
            ValueError: If the training dataset or ``label_column`` is missing.
        """
        datasets = datasets or {}
        if not datasets.get(TRAIN_DATASET_KEY):
            raise ValueError(
                "`datasets` must be provided for the XGBoostTrainer API "
                "if `train_loop_per_worker` is not provided. "
                "This dict must contain the training dataset under the "
                f"key: '{TRAIN_DATASET_KEY}'. "
                f"Got keys: {list(datasets.keys())}"
            )
        if not label_column:
            raise ValueError(
                "`label_column` must be provided for the XGBoostTrainer API "
                "if `train_loop_per_worker` is not provided. "
                "This is the column name of the label in the dataset."
            )

        num_boost_round = num_boost_round or 10

        # TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
        # _log_deprecation_warning(LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE)

        # Work on copies: appending the default callback to the caller's own
        # list would leak it into every other trainer built from that list.
        # `callbacks=None` is treated the same as no callbacks.
        xgboost_train_kwargs = dict(xgboost_train_kwargs)
        callbacks = list(xgboost_train_kwargs.get("callbacks") or [])

        # Initialize a default Ray Train metrics/checkpoint reporting callback if needed
        user_supplied_callback = any(
            isinstance(callback, RayTrainReportCallback) for callback in callbacks
        )
        callback_kwargs = {}
        if run_config:
            checkpoint_frequency = run_config.checkpoint_config.checkpoint_frequency
            checkpoint_at_end = run_config.checkpoint_config.checkpoint_at_end

            callback_kwargs["frequency"] = checkpoint_frequency
            # Default `checkpoint_at_end=True` unless the user explicitly sets it.
            callback_kwargs["checkpoint_at_end"] = (
                checkpoint_at_end if checkpoint_at_end is not None else True
            )

        if not user_supplied_callback:
            callbacks.append(RayTrainReportCallback(**callback_kwargs))
        xgboost_train_kwargs["callbacks"] = callbacks

        train_fn_per_worker = partial(
            _xgboost_train_fn_per_worker,
            label_column=label_column,
            num_boost_round=num_boost_round,
            dataset_keys=set(datasets),
            xgboost_train_kwargs=xgboost_train_kwargs,
        )
        return train_fn_per_worker

    @classmethod
    def get_model(
        cls,
        checkpoint: Checkpoint,
    ) -> xgboost.Booster:
        """Retrieve the XGBoost model stored in this checkpoint."""
        return RayTrainReportCallback.get_model(checkpoint)