
class ray.train.xgboost.xgboost_trainer.XGBoostTrainer(*args, **kwargs)

A Trainer for distributed data-parallel XGBoost training.


import xgboost

import ray.data
import ray.train
from ray.train.xgboost import RayTrainReportCallback, XGBoostTrainer

def train_fn_per_worker(config: dict):
    # (Optional) Add logic to resume training state from a checkpoint.
    # ray.train.get_checkpoint()

    # 1. Get the dataset shard for the worker and convert to a `xgboost.DMatrix`
    train_ds_iter, eval_ds_iter = (
        ray.train.get_dataset_shard("train"),
        ray.train.get_dataset_shard("validation"),
    )
    train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()

    train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
    train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
    eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]

    dtrain = xgboost.DMatrix(train_X, label=train_y)
    deval = xgboost.DMatrix(eval_X, label=eval_y)

    params = {
        "tree_method": "approx",
        "objective": "reg:squarederror",
        "eta": 1e-4,
        "subsample": 0.5,
        "max_depth": 2,
    }

    # 2. Do distributed data-parallel training.
    # Ray Train sets up the necessary coordinator processes and
    # environment variables for your workers to communicate with each other.
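    bst = xgboost.train(
        params,
        dtrain=dtrain,
        evals=[(deval, "validation")],
        num_boost_round=10,
        # Reports metrics and checkpoints to Ray Train after boosting rounds.
        callbacks=[RayTrainReportCallback()],
    )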

train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
trainer = XGBoostTrainer(
    train_fn_per_worker,
    datasets={"train": train_ds, "validation": eval_ds},
    scaling_config=ray.train.ScalingConfig(num_workers=4),
)
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)

Parameters:
  • train_loop_per_worker – The training function to execute on each worker. The function takes either no arguments or a single Dict argument, which is populated from train_loop_config. Within this function you can use any of the Ray Train Loop utilities.

  • train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters.

  • xgboost_config – The configuration for setting up the distributed xgboost backend. Defaults to using the “rabit” backend. See XGBoostConfig for more info.

  • datasets – The Ray Datasets to use for training and validation.

  • dataset_config – The configuration for ingesting the input datasets. By default, all the Ray Datasets are split equally across workers. See DataConfig for more details.

  • scaling_config – The configuration for how to scale data parallel training. num_workers determines how many Python processes are used for training, and use_gpu determines whether or not each process should use GPUs. See ScalingConfig for more info.

  • run_config – The configuration for the execution of the training run. See RunConfig for more info.

  • resume_from_checkpoint – A checkpoint to resume training from. This checkpoint can be accessed from within train_loop_per_worker by calling ray.train.get_checkpoint().

  • metadata – Dict that should be made available via ray.train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.

  • label_column – [Deprecated] Name of the label column. A column with this name must be present in the training dataset.

  • params – [Deprecated] XGBoost training parameters. Refer to XGBoost documentation for a list of possible parameters.

  • num_boost_round – [Deprecated] Target number of boosting iterations (trees in the model). Unlike in xgboost.train, this is the target total number of trees: if you set num_boost_round=10 and pass a model that has already been trained for 5 iterations, it is trained for 5 more iterations, not 10 more.

  • **train_kwargs – [Deprecated] Additional kwargs passed to xgboost.train() function.
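
The "target number of trees" behavior described for num_boost_round can be sketched as plain arithmetic. The helper below is hypothetical and is not part of Ray or XGBoost; it only illustrates how a target total differs from an increment. Clamping at zero when the model already exceeds the target is an assumption made for this sketch.

```python
def remaining_boost_rounds(target_rounds: int, rounds_already_trained: int) -> int:
    """Return how many more rounds run when num_boost_round is a target total.

    Hypothetical helper for illustration only; not a Ray or XGBoost API.
    """
    # Assumption: a model already past the target trains for zero more rounds.
    return max(target_rounds - rounds_already_trained, 0)


# Fresh model: all 10 rounds are still to be trained.
fresh = remaining_boost_rounds(10, 0)

# Model already trained for 5 rounds: only 5 more, not 10 more.
resumed = remaining_boost_rounds(10, 5)

print(fresh, resumed)
```

With plain xgboost.train, by contrast, passing an existing model and num_boost_round=10 adds 10 further rounds on top of what the model already has.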

PublicAPI (beta): This API is in beta and may change before becoming stable.



Converts self to a tune.Trainable class.


Checks whether a given directory contains a restorable Train experiment.


Runs training.


Returns a copy of this Trainer's final dataset configs.


Retrieve the XGBoost model stored in this checkpoint.




Restores a DataParallelTrainer from a previously interrupted/failed run.


Called during fit() to perform initial setup on the Trainer.