# Source code for ray.train.v2.xgboost.xgboost_trainer

import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

import ray.train
from ray.train import Checkpoint
from ray.train.trainer import GenDataset
from ray.train.v2.api.config import RunConfig, ScalingConfig
from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
from ray.util.annotations import Deprecated

# Import only for type annotations; avoids a circular import at runtime
# (XGBoostConfig is imported lazily inside XGBoostTrainer.__init__).
if TYPE_CHECKING:
    from ray.train.xgboost import XGBoostConfig

# Module-level logger, namespaced to this module per logging convention.
logger = logging.getLogger(__name__)


class XGBoostTrainer(DataParallelTrainer):
    """A Trainer for distributed data-parallel XGBoost training.

    Example
    -------

    .. testcode::

        import xgboost

        import ray.data
        import ray.train
        from ray.train.xgboost import RayTrainReportCallback
        from ray.train.xgboost import XGBoostTrainer

        def train_fn_per_worker(config: dict):
            # (Optional) Add logic to resume training state from a checkpoint.
            # ray.train.get_checkpoint()

            # 1. Get the dataset shard for the worker and convert to a `xgboost.DMatrix`
            train_ds_iter, eval_ds_iter = (
                ray.train.get_dataset_shard("train"),
                ray.train.get_dataset_shard("validation"),
            )
            train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()

            train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
            train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
            eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]

            dtrain = xgboost.DMatrix(train_X, label=train_y)
            deval = xgboost.DMatrix(eval_X, label=eval_y)

            params = {
                "tree_method": "approx",
                "objective": "reg:squarederror",
                "eta": 1e-4,
                "subsample": 0.5,
                "max_depth": 2,
            }

            # 2. Do distributed data-parallel training.
            # Ray Train sets up the necessary coordinator processes and
            # environment variables for your workers to communicate with each other.
            bst = xgboost.train(
                params,
                dtrain=dtrain,
                evals=[(deval, "validation")],
                num_boost_round=10,
                callbacks=[RayTrainReportCallback()],
            )

        train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
        eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
        trainer = XGBoostTrainer(
            train_fn_per_worker,
            datasets={"train": train_ds, "validation": eval_ds},
            scaling_config=ray.train.ScalingConfig(num_workers=4),
        )
        result = trainer.fit()
        booster = RayTrainReportCallback.get_model(result.checkpoint)

    .. testoutput::
        :hide:

        ...

    Args:
        train_loop_per_worker: The training function to execute on each worker.
            This function can either take in zero arguments or a single ``Dict``
            argument which is set by defining ``train_loop_config``.
            Within this function you can use any of the
            :ref:`Ray Train Loop utilities <train-loop-api>`.
        train_loop_config: A configuration ``Dict`` to pass in as an argument to
            ``train_loop_per_worker``.
            This is typically used for specifying hyperparameters.
        xgboost_config: The configuration for setting up the distributed xgboost
            backend. Defaults to using the "rabit" backend.
            See :class:`~ray.train.xgboost.XGBoostConfig` for more info.
        scaling_config: The configuration for how to scale data parallel training.
            ``num_workers`` determines how many Python processes are used for training,
            and ``use_gpu`` determines whether or not each process should use GPUs.
            See :class:`~ray.train.ScalingConfig` for more info.
        run_config: The configuration for the execution of the training run.
            See :class:`~ray.train.RunConfig` for more info.
        datasets: The Ray Datasets to ingest for training.
            Datasets are keyed by name (``{name: dataset}``).
            Each dataset can be accessed from within the ``train_loop_per_worker``
            by calling ``ray.train.get_dataset_shard(name)``.
            Sharding and additional configuration can be done by
            passing in a ``dataset_config``.
        dataset_config: The configuration for ingesting the input ``datasets``.
            By default, all the Ray Dataset are split equally across workers.
            See :class:`~ray.train.DataConfig` for more details.
        resume_from_checkpoint: A checkpoint to resume training from.
            This checkpoint can be accessed from within ``train_loop_per_worker``
            by calling ``ray.train.get_checkpoint()``.
        metadata: Dict that should be made available via
            `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
            for checkpoints saved from this Trainer. Must be JSON-serializable.
    """

    def __init__(
        self,
        train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
        *,
        train_loop_config: Optional[Dict] = None,
        xgboost_config: Optional["XGBoostConfig"] = None,
        scaling_config: Optional[ScalingConfig] = None,
        run_config: Optional[RunConfig] = None,
        datasets: Optional[Dict[str, GenDataset]] = None,
        dataset_config: Optional[ray.train.DataConfig] = None,
        # TODO: [Deprecated]
        metadata: Optional[Dict[str, Any]] = None,
        resume_from_checkpoint: Optional[Checkpoint] = None,
        # TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
        label_column: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        num_boost_round: Optional[int] = None,
    ):
        # Reject any use of the removed legacy API up front so users get a
        # clear migration message instead of a confusing downstream failure.
        if (
            label_column is not None
            or params is not None
            or num_boost_round is not None
        ):
            raise DeprecationWarning(
                "The legacy XGBoostTrainer API is deprecated. "
                "Please switch to passing in a custom `train_loop_per_worker` "
                "function instead. "
                "See this issue for more context: "
                "https://github.com/ray-project/ray/issues/50042"
            )

        # Imported lazily (mirrors the TYPE_CHECKING-only import at module
        # level) to avoid a circular import at module load time.
        from ray.train.xgboost import XGBoostConfig

        super().__init__(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config=train_loop_config,
            backend_config=xgboost_config or XGBoostConfig(),
            scaling_config=scaling_config,
            dataset_config=dataset_config,
            run_config=run_config,
            datasets=datasets,
            resume_from_checkpoint=resume_from_checkpoint,
            metadata=metadata,
        )

    @classmethod
    @Deprecated
    def get_model(cls, checkpoint: Checkpoint):
        """[Deprecated] Retrieve the XGBoost model stored in this checkpoint."""
        raise DeprecationWarning(
            "`XGBoostTrainer.get_model` is deprecated. "
            "Use `RayTrainReportCallback.get_model` instead."
        )