ray.train.gbdt_trainer.GBDTTrainer
- class ray.train.gbdt_trainer.GBDTTrainer(*args, **kwargs)
Bases: ray.train.base_trainer.BaseTrainer
Abstract class for scaling gradient-boosting decision tree (GBDT) frameworks.
Inherited by XGBoostTrainer and LightGBMTrainer.
- Parameters
datasets – Datasets to use for training and validation. Must include a “train” key denoting the training dataset. All non-training datasets will be used as separate validation sets, each reporting a separate metric.
label_column – Name of the label column. A column with this name must be present in the training dataset.
params – Framework specific training parameters.
dmatrix_params – Dict of dataset name:dict of kwargs passed to respective xgboost_ray.RayDMatrix initializations.
num_boost_round – Target number of boosting iterations (trees in the model).
scaling_config – Configuration for how to scale data parallel training.
run_config – Configuration for the execution of the training run.
resume_from_checkpoint – A checkpoint to resume training from.
metadata – Dict that should be made available in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.
**train_kwargs – Additional kwargs passed to framework train() function.
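The shape of these arguments can be sketched in plain Python. This is an illustration under stated assumptions, not Ray's implementation: `split_datasets` is a hypothetical helper that mirrors the documented rule for the "train" key, the string values stand in for real Ray Datasets, and the column name, objective, and metadata contents are made up for the example.

```python
import json


def split_datasets(datasets):
    """Hypothetical helper: separate the required "train" dataset from the rest."""
    if "train" not in datasets:
        raise KeyError('datasets must include a "train" key')
    # Every other entry is treated as its own validation set.
    validation = {name: ds for name, ds in datasets.items() if name != "train"}
    return datasets["train"], validation


# Placeholder strings stand in for real Ray Datasets (assumption).
datasets = {"train": "train_ds", "valid": "valid_ds", "holdout": "holdout_ds"}
train_ds, validation_sets = split_datasets(datasets)

# Keyword arguments in the shape the parameter list describes.
# All concrete values here are illustrative assumptions.
trainer_kwargs = {
    "datasets": datasets,
    "label_column": "target",
    "params": {"objective": "binary:logistic"},
    # One dict of kwargs per dataset name; left empty in this sketch.
    "dmatrix_params": {"train": {}, "valid": {}},
    "num_boost_round": 10,
    "metadata": {"experiment": "demo"},
}

# metadata must be JSON-serializable; json.dumps raises TypeError otherwise.
json.dumps(trainer_kwargs["metadata"])
```

In real use, a dict like `trainer_kwargs` would be passed to a concrete subclass such as XGBoostTrainer or LightGBMTrainer, together with a `scaling_config`.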
DeveloperAPI: This API may change across minor Ray releases.
Methods
Converts self to a tune.Trainable class.
can_restore(path[, storage_filesystem]) – Checks whether a given directory contains a restorable Train experiment.
fit() – Runs training.
Called during fit() to preprocess dataset attributes with preprocessor.
restore(path[, storage_filesystem, ...]) – Restores a Train experiment from a previously interrupted/failed run.
setup() – Called during fit() to perform initial setup on the Trainer.
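The methods can_restore, restore, and fit are commonly combined into a "resume if possible, otherwise start fresh" pattern. The following is a minimal sketch of that pattern. `fit_or_restore` is a hypothetical helper, `trainer_cls` is assumed to be a concrete subclass such as XGBoostTrainer, and can_restore and restore are assumed to be callable on the class with only a path. The optional arguments elided in the listing above are not used here.

```python
def fit_or_restore(trainer_cls, experiment_path, **trainer_kwargs):
    """Hypothetical helper: resume an interrupted run if one exists, else start fresh."""
    if trainer_cls.can_restore(experiment_path):
        # A restorable experiment was found at this path; pick up where it stopped.
        trainer = trainer_cls.restore(experiment_path)
    else:
        # Nothing to restore; build a new trainer from the supplied kwargs.
        trainer = trainer_cls(**trainer_kwargs)
    return trainer.fit()
```

Taking the trainer class as an argument keeps the helper independent of any one framework, so the same function would work for either subclass.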