ray.train.gbdt_trainer.GBDTTrainer

class ray.train.gbdt_trainer.GBDTTrainer(*args, **kwargs)

Bases: ray.train.base_trainer.BaseTrainer

Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.

Subclassed by XGBoostTrainer and LightGBMTrainer.
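Because the class is abstract, the general idea is that framework-agnostic logic sits in the base class and each subclass plugs in one GBDT framework. The toy sketch below illustrates that pattern in plain Python. It is not Ray's implementation: ToyGBDTTrainer, ToyXGBoostTrainer, ToyLightGBMTrainer, and the _train hook are all hypothetical names chosen for illustration.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class ToyGBDTTrainer(ABC):
    """Hypothetical abstract base: shared entry point, framework-specific hook."""

    def __init__(self, datasets: Dict[str, Any], label_column: str,
                 params: Dict[str, Any], **train_kwargs: Any) -> None:
        self.datasets = datasets
        self.label_column = label_column
        self.params = params
        # Extra kwargs are kept so they can be forwarded to the framework's train().
        self.train_kwargs = train_kwargs

    def fit(self) -> Dict[str, Any]:
        # Framework-agnostic orchestration would live here; this toy
        # version just delegates to the subclass hook.
        return self._train(self.params, **self.train_kwargs)

    @abstractmethod
    def _train(self, params: Dict[str, Any], **train_kwargs: Any) -> Dict[str, Any]:
        """Run one framework's training routine (hypothetical hook)."""


class ToyXGBoostTrainer(ToyGBDTTrainer):
    def _train(self, params, **train_kwargs):
        return {"framework": "xgboost", "params": params, **train_kwargs}


class ToyLightGBMTrainer(ToyGBDTTrainer):
    def _train(self, params, **train_kwargs):
        return {"framework": "lightgbm", "params": params, **train_kwargs}


result = ToyXGBoostTrainer({"train": []}, "y", {"max_depth": 3}, num_boost_round=10).fit()
print(result)  # {'framework': 'xgboost', 'params': {'max_depth': 3}, 'num_boost_round': 10}
```

Instantiating the abstract base directly raises TypeError, which mirrors why GBDTTrainer itself is only used through its subclasses.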

Parameters
  • datasets – Dict mapping names to Ray Datasets used for training and validation. Must include a “train” key denoting the training dataset. If a preprocessor is provided and has not already been fit, it is fit on the training dataset, and all datasets are then transformed by it. Every non-training dataset is used as a separate validation set and reports its own metrics.

  • label_column – Name of the label column. A column with this name must be present in the training dataset.

  • params – Framework-specific training parameters.

  • dmatrix_params – Dict mapping each dataset name to a dict of kwargs passed to the corresponding xgboost_ray.RayDMatrix initialization.

  • scaling_config – Configuration for how to scale data-parallel training.

  • run_config – Configuration for the execution of the training run.

  • preprocessor – A ray.data.Preprocessor to preprocess the provided datasets.

  • resume_from_checkpoint – A checkpoint to resume training from.

  • **train_kwargs – Additional kwargs passed to the framework's train() function.
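The rules for datasets, label_column, and preprocessor can be modeled in a few lines of plain Python. This is a minimal sketch under stated assumptions, not Ray's code: datasets are stand-in lists of row dicts rather than Ray Datasets, and ToyMeanCenterer and prepare_datasets are hypothetical names. It checks for the required "train" key and label column, fits the preprocessor on the training data only if it has not been fit yet, transforms every dataset, and treats each non-training dataset as its own validation set.

```python
from typing import Any, Dict, List, Optional, Tuple

Row = Dict[str, Any]
Dataset = List[Row]  # hypothetical stand-in for a Ray Dataset: a list of row dicts


class ToyMeanCenterer:
    """Hypothetical preprocessor that subtracts the training-set mean of one column."""

    def __init__(self, column: str) -> None:
        self.column = column
        self.mean: Optional[float] = None  # None means "not fit yet"

    def fit(self, ds: Dataset) -> None:
        self.mean = sum(row[self.column] for row in ds) / len(ds)

    def transform(self, ds: Dataset) -> Dataset:
        return [{**row, self.column: row[self.column] - self.mean} for row in ds]


def prepare_datasets(
    datasets: Dict[str, Dataset],
    label_column: str,
    preprocessor: Optional[ToyMeanCenterer] = None,
) -> Tuple[Dataset, Dict[str, Dataset]]:
    """Return the training set and a dict of named validation sets."""
    if "train" not in datasets:
        raise ValueError('datasets must include a "train" key')
    if any(label_column not in row for row in datasets["train"]):
        raise ValueError(f"label column {label_column!r} missing from the training dataset")
    if preprocessor is not None:
        # Fit on the training data only, and only if not already fit.
        if preprocessor.mean is None:
            preprocessor.fit(datasets["train"])
        # Every dataset, training and validation alike, is transformed.
        datasets = {name: preprocessor.transform(ds) for name, ds in datasets.items()}
    # Each non-training dataset becomes its own validation set.
    validation = {name: ds for name, ds in datasets.items() if name != "train"}
    return datasets["train"], validation


data = {
    "train": [{"x": 1.0, "y": 0}, {"x": 3.0, "y": 1}],
    "valid": [{"x": 2.0, "y": 1}],
}
scaler = ToyMeanCenterer("x")
train, validation = prepare_datasets(data, "y", scaler)
print(train)       # [{'x': -1.0, 'y': 0}, {'x': 1.0, 'y': 1}]
print(validation)  # {'valid': [{'x': 0.0, 'y': 1}]}
```

Fitting only on "train" keeps validation statistics out of the preprocessing step, so the validation metrics are not computed on data the preprocessor has already seen.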

DeveloperAPI: This API may change across minor Ray releases.

Methods

  • as_trainable() – Converts this trainer to a tune.Trainable class.

  • fit() – Runs training.

  • setup() – Called during fit() to perform initial setup on the Trainer.