ray.train.gbdt_trainer.GBDTTrainer

class ray.train.gbdt_trainer.GBDTTrainer(*args, **kwargs)

Bases: ray.train.base_trainer.BaseTrainer

Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.

Inherited by XGBoostTrainer and LightGBMTrainer.

Parameters
  • datasets – Dict mapping dataset names to the datasets used for training and validation. Must include a “train” key denoting the training dataset. Every other dataset is used as a separate validation set and reports its own metrics.

  • label_column – Name of the label column. A column with this name must be present in the training dataset.

  • params – Framework-specific training parameters.

  • dmatrix_params – Dict mapping each dataset name to a dict of kwargs passed to the corresponding xgboost_ray.RayDMatrix initialization.

  • num_boost_round – Target number of boosting iterations (trees in the model).

  • scaling_config – Configuration for how to scale data-parallel training.

  • run_config – Configuration for the execution of the training run.

  • resume_from_checkpoint – A checkpoint to resume training from.

  • metadata – Dict that should be made available in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.

  • **train_kwargs – Additional kwargs passed to the framework's train() function.
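The argument contract above (a required “train” key, every other dataset used as its own validation set, per-dataset dmatrix_params looked up by name, JSON-serializable metadata) can be illustrated with a small pure-Python sketch. This is not Ray's implementation: the helper plan_gbdt_inputs and the constant TRAIN_DATASET_KEY are hypothetical names, and the sketch uses plain Python objects in place of Ray datasets so it runs without a cluster.

```python
import json
from typing import Any, Dict, List, Tuple

# Hypothetical constant for the key that marks the training dataset.
TRAIN_DATASET_KEY = "train"


def plan_gbdt_inputs(
    datasets: Dict[str, Any],
    dmatrix_params: Dict[str, Dict[str, Any]],
    metadata: Dict[str, Any],
) -> Tuple[Any, List[Tuple[str, Any]], Dict[str, Dict[str, Any]]]:
    """Hypothetical check of the GBDTTrainer argument contract (not Ray code)."""
    if TRAIN_DATASET_KEY not in datasets:
        raise KeyError('`datasets` must include a "train" key')

    # Sanity check added by this sketch: dmatrix_params keys must name datasets.
    unknown = set(dmatrix_params) - set(datasets)
    if unknown:
        raise KeyError(f"dmatrix_params names unknown datasets: {sorted(unknown)}")

    # json.dumps raises TypeError if metadata is not JSON-serializable.
    json.dumps(metadata)

    train = datasets[TRAIN_DATASET_KEY]
    # Every non-training dataset becomes a separate validation set.
    evals = [(name, ds) for name, ds in datasets.items() if name != TRAIN_DATASET_KEY]
    # Each dataset gets its own kwargs dict; datasets without an entry get {}.
    per_dataset_kwargs = {name: dict(dmatrix_params.get(name, {})) for name in datasets}
    return train, evals, per_dataset_kwargs
```

For example, plan_gbdt_inputs({"train": train_ds, "valid": valid_ds}, {"train": {...}}, {"run": "demo"}) returns the training dataset, a single ("valid", valid_ds) validation pair, and an empty kwargs dict for "valid". The real RayDMatrix kwargs depend on your xgboost_ray version.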

DeveloperAPI: This API may change across minor Ray releases.

Methods

as_trainable

Converts self to a tune.Trainable class.

can_restore

Checks whether a given directory contains a restorable Train experiment.

fit

Runs training.

preprocess_datasets

Called during fit() to preprocess dataset attributes with the Trainer's preprocessor.

restore

Restores a Train experiment from a previously interrupted/failed run.

setup

Called during fit() to perform initial setup on the Trainer.
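The can_restore and restore class methods are commonly combined as "resume the run if possible, otherwise start fresh". The following is a minimal sketch of that pattern using a hypothetical helper, resume_or_start. It is written against any class that exposes can_restore(path) and restore(path, **kwargs) class methods, so it runs without a Ray cluster. With Ray installed, trainer_cls would be a concrete subclass such as XGBoostTrainer. Which keyword arguments restore() accepts (for example, re-specified datasets) depends on your Ray version, so check its reference before relying on them.

```python
from typing import Any, Callable, TypeVar

T = TypeVar("T")


def resume_or_start(
    trainer_cls: Any,
    path: str,
    make_fresh: Callable[[], T],
    **restore_kwargs: Any,
) -> T:
    """Hypothetical helper: restore a trainer from `path` if possible, else build a new one.

    `trainer_cls` must expose class methods `can_restore(path)` and
    `restore(path, **kwargs)`. `make_fresh` is a zero-argument factory
    (also hypothetical) that constructs a new trainer.
    """
    if trainer_cls.can_restore(path):
        # `path` holds a restorable experiment from an interrupted/failed run.
        return trainer_cls.restore(path, **restore_kwargs)
    # Nothing to resume: start a new training run instead.
    return make_fresh()
```

In either case, the caller then invokes fit() on the returned trainer. Keeping construction behind make_fresh means the expensive trainer setup only happens when no previous run can be resumed.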