ray.train.lightgbm.LightGBMTrainer
- class ray.train.lightgbm.LightGBMTrainer(train_loop_per_worker: Callable[[], None] | Callable[[Dict], None], *, train_loop_config: Dict | None = None, lightgbm_config: LightGBMConfig | None = None, scaling_config: ScalingConfig | None = None, run_config: RunConfig | None = None, datasets: Dict[str, Dataset | Callable[[], Dataset]] | None = None, dataset_config: DataConfig | None = None, metadata: Dict[str, Any] | None = None, resume_from_checkpoint: Checkpoint | None = None, label_column: str | None = None, params: Dict[str, Any] | None = None, num_boost_round: int | None = None)
- Bases: `DataParallelTrainer`

A Trainer for distributed data-parallel LightGBM training.

Example

```python
import lightgbm as lgb

import ray.data
import ray.train
from ray.train.lightgbm import RayTrainReportCallback
from ray.train.lightgbm import LightGBMTrainer


def train_fn_per_worker(config: dict):
    # (Optional) Add logic to resume training state from a checkpoint.
    # ray.train.get_checkpoint()

    # 1. Get the dataset shard for the worker and convert to a `lgb.Dataset`.
    train_ds_iter, eval_ds_iter = (
        ray.train.get_dataset_shard("train"),
        ray.train.get_dataset_shard("validation"),
    )
    train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()
    train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
    train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
    eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]
    train_set = lgb.Dataset(train_X, label=train_y)
    eval_set = lgb.Dataset(eval_X, label=eval_y)

    # 2. Run distributed data-parallel training.
    # `get_network_params` sets up the necessary configurations for LightGBM
    # to set up the data parallel training worker group on your Ray cluster.
    params = {
        "objective": "regression",
        # Adding the line below is the only change needed
        # for your `lgb.train` call!
        **ray.train.lightgbm.get_network_params(),
    }
    lgb.train(
        params,
        train_set,
        valid_sets=[eval_set],
        valid_names=["eval"],
        num_boost_round=1,
        # To access the checkpoint from the trainer, you need this callback.
        callbacks=[RayTrainReportCallback()],
    )


train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
eval_ds = ray.data.from_items(
    [{"x": x, "y": x + 1} for x in range(32, 32 + 16)]
)

trainer = LightGBMTrainer(
    train_fn_per_worker,
    datasets={"train": train_ds, "validation": eval_ds},
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)
```

- Parameters:
- train_loop_per_worker – The training function to execute on each worker. This function can either take in zero arguments or a single `Dict` argument, which is set by defining `train_loop_config`. Within this function you can use any of the Ray Train Loop utilities.
- train_loop_config – A configuration `Dict` to pass in as an argument to `train_loop_per_worker`. This is typically used for specifying hyperparameters (see the sketch after this list).
- lightgbm_config – The configuration for setting up the distributed LightGBM backend. See `LightGBMConfig` for more info.
- scaling_config – The configuration for how to scale data parallel training. `num_workers` determines how many Python processes are used for training, and `use_gpu` determines whether or not each process should use GPUs. See `ScalingConfig` for more info.
- run_config – The configuration for the execution of the training run. See `RunConfig` for more info.
- datasets – The Ray Datasets to ingest for training. Datasets are keyed by name (`{name: dataset}`). Each dataset can be accessed from within the `train_loop_per_worker` by calling `ray.train.get_dataset_shard(name)`. Sharding and additional configuration can be done by passing in a `dataset_config`.
- dataset_config – The configuration for ingesting the input `datasets`. By default, all the Ray Datasets are split equally across workers. See `DataConfig` for more details.
- resume_from_checkpoint – A checkpoint to resume training from. This checkpoint can be accessed from within `train_loop_per_worker` by calling `ray.train.get_checkpoint()`.
- metadata – Dict that should be made available via `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()` for checkpoints saved from this Trainer. Must be JSON-serializable.
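A minimal sketch of how these parameters fit together at runtime, assuming an illustrative `"train"` dataset key, a `y` label column, and a `num_boost_round` entry in `train_loop_config` (all hypothetical choices for this example, not part of the API):

```python
import lightgbm as lgb

import ray.data
import ray.train
from ray.train.lightgbm import LightGBMTrainer, RayTrainReportCallback


def train_fn_per_worker(config: dict):
    # Hyperparameters arrive through `train_loop_config`.
    num_boost_round = config["num_boost_round"]

    # If `resume_from_checkpoint` was passed to the trainer, the checkpoint is
    # visible here and can seed `init_model` to continue boosting (assumption:
    # the checkpoint was written by `RayTrainReportCallback` in a previous run).
    init_model = None
    checkpoint = ray.train.get_checkpoint()
    if checkpoint is not None:
        init_model = RayTrainReportCallback.get_model(checkpoint)

    # Datasets passed via `datasets={...}` are fetched by name inside the worker.
    train_df = ray.train.get_dataset_shard("train").materialize().to_pandas()
    train_set = lgb.Dataset(train_df.drop("y", axis=1), label=train_df["y"])

    params = {
        "objective": "regression",
        **ray.train.lightgbm.get_network_params(),
    }
    lgb.train(
        params,
        train_set,
        num_boost_round=num_boost_round,
        init_model=init_model,
        callbacks=[RayTrainReportCallback()],
    )


train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = LightGBMTrainer(
    train_fn_per_worker,
    train_loop_config={"num_boost_round": 10},
    datasets={"train": train_ds},
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
```

Here `init_model` is ordinary LightGBM usage for continuing from an existing Booster; whether resuming this way is appropriate depends on what the earlier run checkpointed.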
 
- Methods
  - `can_restore`: [Deprecated] Checks if a Train experiment can be restored from a previously interrupted/failed run.
  - `fit`: Launches the Ray Train controller to run training on workers.
  - `get_model`: Retrieve the LightGBM model stored in this checkpoint.
  - `restore`: [Deprecated] Restores a Train experiment from a previously interrupted/failed run.
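A short sketch of how the non-deprecated methods above are typically combined, assuming `trainer` was built as in the class example; the classmethod form of `get_model` shown here is inferred from the description above, so treat the exact call as an assumption:

```python
# Launches the Ray Train controller and blocks until training finishes or fails.
result = trainer.fit()

# Retrieve the LightGBM Booster stored in the resulting checkpoint
# (the class example uses RayTrainReportCallback.get_model for the same purpose).
booster = LightGBMTrainer.get_model(result.checkpoint)
print(booster.num_trees())

# Any `metadata` passed to the trainer is attached to checkpoints it saves.
print(result.checkpoint.get_metadata())
```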