ray.train.lightgbm.LightGBMTrainer

class ray.train.lightgbm.LightGBMTrainer(train_loop_per_worker: Callable[[], None] | Callable[[Dict], None], *, train_loop_config: Dict | None = None, lightgbm_config: LightGBMConfig | None = None, scaling_config: ScalingConfig | None = None, run_config: RunConfig | None = None, datasets: Dict[str, Dataset | Callable[[], Dataset]] | None = None, dataset_config: DataConfig | None = None, metadata: Dict[str, Any] | None = None, resume_from_checkpoint: Checkpoint | None = None)

Bases: DataParallelTrainer

A Trainer for distributed data-parallel LightGBM training.

Example

import lightgbm as lgb

import ray.data
import ray.train
from ray.train.lightgbm import LightGBMTrainer, RayTrainReportCallback

def train_fn_per_worker(config: dict):
    # (Optional) Add logic to resume training state from a checkpoint.
    # ray.train.get_checkpoint()

    # 1. Get the dataset shard for this worker and convert it to an `lgb.Dataset`.
    train_ds_iter, eval_ds_iter = (
        ray.train.get_dataset_shard("train"),
        ray.train.get_dataset_shard("validation"),
    )
    train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()
    train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
    train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
    eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]

    train_set = lgb.Dataset(train_X, label=train_y)
    eval_set = lgb.Dataset(eval_X, label=eval_y)

    # 2. Run distributed data-parallel training.
    # `get_network_params` provides the network configuration that LightGBM
    # needs to form the data-parallel training worker group on your Ray cluster.
    params = {
        "objective": "regression",
        # Adding the line below is the only change needed
        # for your `lgb.train` call!
        **ray.train.lightgbm.get_network_params(),
    }
    lgb.train(
        params,
        train_set,
        valid_sets=[eval_set],
        valid_names=["eval"],
        # This callback reports checkpoints to Ray Train, making the trained
        # model accessible via `result.checkpoint` after training.
        callbacks=[RayTrainReportCallback()],
    )

train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
eval_ds = ray.data.from_items(
    [{"x": x, "y": x + 1} for x in range(32, 32 + 16)]
)
trainer = LightGBMTrainer(
    train_fn_per_worker,
    datasets={"train": train_ds, "validation": eval_ds},
    scaling_config=ray.train.ScalingConfig(num_workers=4),
)
result = trainer.fit()
booster = RayTrainReportCallback.get_model(result.checkpoint)
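
`RayTrainReportCallback.get_model` returns a plain `lgb.Booster`, so standard LightGBM inference applies to it. A minimal sketch using the toy evaluation data defined above:

# Run inference with the trained booster on the held-out features.
eval_X = eval_ds.to_pandas().drop("y", axis=1)
predictions = booster.predict(eval_X)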
Parameters:
  • train_loop_per_worker – The training function to execute on each worker. This function can take either zero arguments or a single Dict argument, which is set by defining train_loop_config. Within this function you can use any of the Ray Train Loop utilities.

  • train_loop_config – A configuration Dict to pass in as an argument to train_loop_per_worker. This is typically used for specifying hyperparameters (see the sketch after this parameter list).

  • lightgbm_config – The configuration for setting up the distributed LightGBM backend. See LightGBMConfig for more info.

  • scaling_config – The configuration for how to scale data parallel training. num_workers determines how many Python processes are used for training, and use_gpu determines whether or not each process should use GPUs. See ScalingConfig for more info.

  • run_config – The configuration for the execution of the training run. See RunConfig for more info.

  • datasets – The Ray Datasets to ingest for training. Datasets are keyed by name ({name: dataset}). Each dataset can be accessed from within the train_loop_per_worker by calling ray.train.get_dataset_shard(name). Sharding and additional configuration can be done by passing in a dataset_config.

  • dataset_config – The configuration for ingesting the input datasets. By default, all Ray Datasets are split equally across workers. See DataConfig for more details, and the sketch after this parameter list.

  • resume_from_checkpoint – A checkpoint to resume training from. This checkpoint can be accessed from within train_loop_per_worker by calling ray.train.get_checkpoint().

  • metadata – A dict that is made available via ray.train.get_context().get_metadata() and in checkpoint.get_metadata() for checkpoints saved from this Trainer. Must be JSON-serializable.
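
To make train_loop_config and dataset_config concrete, here is a minimal sketch that reuses the pattern from the example above; the config keys ("objective", "num_boost_round") and the two-worker scaling are illustrative choices, not part of the API:

import lightgbm as lgb

import ray.data
import ray.train
from ray.train.lightgbm import LightGBMTrainer, RayTrainReportCallback

def train_fn_per_worker(config: dict):
    # `config` is exactly the `train_loop_config` dict passed to the trainer.
    train_df = ray.train.get_dataset_shard("train").materialize().to_pandas()
    train_set = lgb.Dataset(train_df.drop("y", axis=1), label=train_df["y"])
    params = {
        "objective": config["objective"],
        **ray.train.lightgbm.get_network_params(),
    }
    lgb.train(
        params,
        train_set,
        num_boost_round=config["num_boost_round"],
        callbacks=[RayTrainReportCallback()],
    )

train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
trainer = LightGBMTrainer(
    train_fn_per_worker,
    train_loop_config={"objective": "regression", "num_boost_round": 10},
    datasets={"train": train_ds},
    # Split only the "train" dataset across workers; any other datasets
    # would be replicated to every worker.
    dataset_config=ray.train.DataConfig(datasets_to_split=["train"]),
    scaling_config=ray.train.ScalingConfig(num_workers=2),
)
result = trainer.fit()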

Methods

can_restore – [Deprecated] Checks if a Train experiment can be restored from a previously interrupted/failed run.

fit – Launches the Ray Train controller to run training on workers.

get_model – Retrieves the LightGBM model stored in a checkpoint.

restore – [Deprecated] Restores a Train experiment from a previously interrupted/failed run.
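
The example above retrieves the booster through `RayTrainReportCallback.get_model`. Per the method table, the trainer class also exposes get_model; assuming it accepts a checkpoint saved by RayTrainReportCallback (mirroring the callback helper), usage would look like:

# Assumed equivalent of the callback helper used in the example above.
booster = LightGBMTrainer.get_model(result.checkpoint)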