ray.train.xgboost.xgboost_trainer.XGBoostTrainer
- class ray.train.xgboost.xgboost_trainer.XGBoostTrainer(*args, **kwargs)
Bases: XGBoostTrainer
A Trainer for distributed data-parallel XGBoost training.
Example
import xgboost

import ray.data
import ray.train
from ray.train.xgboost import RayTrainReportCallback, XGBoostTrainer


def train_fn_per_worker(config: dict):
    # (Optional) Add logic to resume training state from a checkpoint.
    # ray.train.get_checkpoint()

    # 1. Get the dataset shard for the worker and convert to a `xgboost.DMatrix`
    train_ds_iter, eval_ds_iter = (
        ray.train.get_dataset_shard("train"),
        ray.train.get_dataset_shard("validation"),
    )
    train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()

    train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
    train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
    eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]

    dtrain = xgboost.DMatrix(train_X, label=train_y)
    deval = xgboost.DMatrix(eval_X, label=eval_y)

    params = {
        "tree_method": "approx",
        "objective": "reg:squarederror",
        "eta": 1e-4,
        "subsample": 0.5,
        "max_depth": 2,
    }

    # 2. Do distributed data-parallel training.
    # Ray Train sets up the necessary coordinator processes and
    # environment variables for your workers to communicate with each other.
    bst = xgboost.train(
        params,
        dtrain=dtrain,
        evals=[(deval, "validation")],
        num_boost_round=10,
        callbacks=[RayTrainReportCallback()],
    )


train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
trainer = XGBoostTrainer(
    train_fn_per_worker,
    datasets={"train": train_ds, "validation": eval_ds},
    scaling_config=ray.train.ScalingConfig(num_workers=4),
)
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)
- Parameters:
  - train_loop_per_worker – The training function to execute on each worker. This function can either take in zero arguments or a single Dict argument, which is set by defining train_loop_config. Within this function you can use any of the Ray Train Loop utilities.
  - train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters (see the first sketch after this list).
  - xgboost_config – The configuration for setting up the distributed xgboost backend. Defaults to using the "rabit" backend. See XGBoostConfig for more info.
  - datasets – The Ray Datasets to use for training and validation.
  - dataset_config – The configuration for ingesting the input datasets. By default, all the Ray Datasets are split equally across workers. See DataConfig for more details.
  - scaling_config – The configuration for how to scale data parallel training. num_workers determines how many Python processes are used for training, and use_gpu determines whether or not each process should use GPUs. See ScalingConfig for more info.
  - run_config – The configuration for the execution of the training run. See RunConfig for more info.
  - resume_from_checkpoint – A checkpoint to resume training from. This checkpoint can be accessed from within train_loop_per_worker by calling ray.train.get_checkpoint() (see the second sketch after this list).
  - metadata – Dict that should be made available via ray.train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable (see the third sketch after this list).
  - label_column – [Deprecated] Name of the label column. A column with this name must be present in the training dataset.
  - params – [Deprecated] XGBoost training parameters. Refer to the XGBoost documentation for a list of possible parameters.
  - num_boost_round – [Deprecated] Target number of boosting iterations (trees in the model). Note that unlike in xgboost.train, this is the target number of trees, meaning that if you set num_boost_round=10 and pass a model that has already been trained for 5 iterations, it will be trained for 5 more iterations, not 10 more.
  - **train_kwargs – [Deprecated] Additional kwargs passed to the xgboost.train() function.
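The following is a minimal sketch of how train_loop_config reaches the per-worker function. The hyperparameter keys ("eta", "max_depth") and the tiny dataset are illustrative assumptions, not part of the API.

import ray.data
import ray.train
from ray.train.xgboost import XGBoostTrainer


def train_fn_per_worker(config: dict):
    # `config` is exactly the dict passed as `train_loop_config` below.
    eta = config["eta"]
    max_depth = config["max_depth"]
    # ... build DMatrices from ray.train.get_dataset_shard(...) and call
    # xgboost.train({"eta": eta, "max_depth": max_depth, ...}, ...) as in the
    # full example above.


train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = XGBoostTrainer(
    train_fn_per_worker,
    train_loop_config={"eta": 0.1, "max_depth": 4},
    datasets={"train": train_ds},
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
result = trainer.fit()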
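Next, a sketch of picking up training state inside the loop, assuming checkpoints were written by RayTrainReportCallback as in the example above. Continuing boosting via xgb_model is standard xgboost.train usage shown here for illustration, not something the Trainer does automatically.

import ray.train
from ray.train.xgboost import RayTrainReportCallback


def train_fn_per_worker(config: dict):
    # Populated when the run is restarted or `resume_from_checkpoint` is set.
    checkpoint = ray.train.get_checkpoint()
    starting_model = None
    if checkpoint:
        # Load the booster that RayTrainReportCallback saved on a previous run.
        starting_model = RayTrainReportCallback.get_model(checkpoint)
    # ... then continue boosting from it, e.g.:
    # xgboost.train(params, dtrain=dtrain, xgb_model=starting_model, ...)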
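Finally, a sketch of round-tripping metadata; the "preprocessor_version" key and the small dataset are hypothetical, used only for illustration.

import ray.data
import ray.train
from ray.train.xgboost import XGBoostTrainer


def train_fn_per_worker(config: dict):
    # The Trainer's `metadata` dict is available on every worker.
    metadata = ray.train.get_context().get_metadata()
    version = metadata["preprocessor_version"]  # hypothetical key
    # ...


train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = XGBoostTrainer(
    train_fn_per_worker,
    metadata={"preprocessor_version": 2},  # must be JSON-serializable
    datasets={"train": train_ds},
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
# After training, checkpoints saved from this Trainer also expose the same
# dict via checkpoint.get_metadata().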
PublicAPI (beta): This API is in beta and may change before becoming stable.
Methods
- as_trainable – Converts self to a tune.Trainable class.
- can_restore – Checks whether a given directory contains a restorable Train experiment.
- fit – Runs training.
- get_dataset_config – Returns a copy of this Trainer's final dataset configs.
- get_model – Retrieve the XGBoost model stored in this checkpoint.
- preprocess_datasets – Deprecated.
- restore – Restores a DataParallelTrainer from a previously interrupted/failed run (see the sketch after this list).
- setup – Called during fit() to perform initial setup on the Trainer.
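A minimal sketch of restoring an interrupted run; the experiment path is hypothetical, and `train_ds`/`eval_ds` are assumed to be the Ray Datasets from the example above, re-supplied here because datasets are not captured in the saved run state.

from ray.train.xgboost import XGBoostTrainer

# Hypothetical path to the interrupted run's storage location.
experiment_path = "~/ray_results/my_xgb_run"

if XGBoostTrainer.can_restore(experiment_path):
    trainer = XGBoostTrainer.restore(
        experiment_path,
        datasets={"train": train_ds, "validation": eval_ds},
    )
    result = trainer.fit()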