class ray.train.xgboost.XGBoostTrainer(*args, **kwargs)

Bases: ray.train.gbdt_trainer.GBDTTrainer

A Trainer for data parallel XGBoost training.

This Trainer runs the XGBoost training loop in a distributed manner using multiple Ray Actors.


XGBoostTrainer does not modify or otherwise alter how the XGBoost distributed training algorithm works. Ray only provides orchestration, data ingest, and fault tolerance. For more information on XGBoost distributed training, refer to the XGBoost documentation.
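To make "data parallel" concrete, the following is a conceptual sketch of dividing rows among workers so that each one trains on its own shard. It is plain Python and does not use Ray; the helper `shard_rows` is hypothetical and is not how Ray or XGBoost actually partitions data. It only illustrates the general idea under the assumption of near-equal contiguous shards.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shard_rows(rows: Sequence[T], num_workers: int) -> List[List[T]]:
    """Hypothetical helper: split rows into near-equal contiguous shards.

    This is an illustration only, not Ray's real sharding logic.
    """
    if num_workers < 1:
        raise ValueError("num_workers must be >= 1")
    # The first `extra` workers receive one additional row each.
    base, extra = divmod(len(rows), num_workers)
    shards: List[List[T]] = []
    start = 0
    for worker in range(num_workers):
        size = base + (1 if worker < extra else 0)
        shards.append(list(rows[start:start + size]))
        start += size
    return shards


# 32 toy rows split across 3 assumed workers.
rows = [{"x": x, "y": x + 1} for x in range(32)]
shards = shard_rows(rows, num_workers=3)
shard_sizes = [len(shard) for shard in shards]
```

Every row lands in exactly one shard, so the workers jointly see the full dataset while each holds only a fraction of it.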


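import ray

from ray.train.xgboost import XGBoostTrainer
from ray.air.config import ScalingConfig

# Toy regression dataset: the label "y" is the feature "x" plus one.
train_dataset = ray.data.from_items(
    [{"x": x, "y": x + 1} for x in range(32)])
trainer = XGBoostTrainer(
    label_column="y",
    params={"objective": "reg:squarederror"},
    # Train data-parallel across three workers.
    scaling_config=ScalingConfig(num_workers=3),
    datasets={"train": train_dataset},
)
result = trainer.fit()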

Parameters:

  • datasets – Ray Datasets to use for training and validation. Must include a “train” key denoting the training dataset. If a preprocessor is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the preprocessor if one is provided. All non-training datasets will be used as separate validation sets, each reporting a separate metric.

  • label_column – Name of the label column. A column with this name must be present in the training dataset.

  • params – XGBoost training parameters. Refer to the XGBoost documentation for a list of possible parameters.

  • dmatrix_params – Dict mapping each dataset name to a dict of kwargs passed to the corresponding xgboost_ray.RayDMatrix initialization, which in turn are passed to the xgboost.DMatrix objects created on each worker. For example, this can be used to add sample weights with the weight parameter.

  • scaling_config – Configuration for how to scale data parallel training.

  • run_config – Configuration for the execution of the training run.

  • preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A checkpoint to resume training from.

  • **train_kwargs – Additional kwargs passed to the xgboost.train() function.
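The nested shape of dmatrix_params (dataset name mapped to a dict of kwargs) can be sketched with plain dictionaries. The helper `build_dmatrix_params` below is hypothetical and not part of Ray; it only checks two constraints implied by the parameter descriptions: a "train" dataset must exist, and every key in dmatrix_params should name one of the provided datasets. The column name "sample_weight" is likewise an assumption for illustration; `weight` is the sample-weight keyword of xgboost.DMatrix, but which value types RayDMatrix accepts for it should be confirmed in the xgboost_ray documentation.

```python
from typing import Any, Dict, Iterable


def build_dmatrix_params(
    dataset_names: Iterable[str],
    per_dataset_kwargs: Dict[str, Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: validate and return a dmatrix_params mapping."""
    names = set(dataset_names)
    if "train" not in names:
        raise ValueError('datasets must include a "train" key')
    unknown = set(per_dataset_kwargs) - names
    if unknown:
        raise ValueError(
            f"dmatrix_params refers to unknown datasets: {sorted(unknown)}"
        )
    # Copy the inner dicts so later changes to the input do not leak through.
    return {name: dict(kwargs) for name, kwargs in per_dataset_kwargs.items()}


# Only the training set gets extra kwargs; "valid" uses the defaults.
dmatrix_params = build_dmatrix_params(
    dataset_names=["train", "valid"],
    per_dataset_kwargs={"train": {"weight": "sample_weight"}},
)
```

The resulting dictionary is what would be passed as the dmatrix_params argument, alongside a datasets dictionary that uses the same keys.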

PublicAPI (beta): This API is in beta and may change before becoming stable.