Training a model with distributed LightGBM

In this example we will train a model in Ray Train using distributed LightGBM.

Let’s start by installing our dependencies:

!pip install -qU "ray[data,train]"

Then we need some imports:

from typing import Tuple

import ray
from ray.data import Dataset, Preprocessor
from ray.data.preprocessors import Categorizer, StandardScaler
from ray.train.lightgbm import LightGBMTrainer
from ray.train import Result, ScalingConfig

Next we define a function to load our train, validation, and test datasets. The test dataset is the validation split with the `target` column removed.

def prepare_data() -> Tuple[Dataset, Dataset, Dataset]:
    dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer_with_categorical.csv")
    train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)
    test_dataset = valid_dataset.drop_columns(cols=["target"])
    return train_dataset, valid_dataset, test_dataset
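As far as we understand, Ray Data's `train_test_split` does not shuffle unless asked to, so the split above is positional: the last 30% of rows become the validation set. The plain-Python sketch below imitates the two calls in `prepare_data` on a list of row dicts. The helpers `split_rows` and `drop_keys` are hypothetical stand-ins, and the exact rounding of the test-set size is an assumption, not a statement about Ray's implementation.

```python
from typing import Dict, List, Tuple

Row = Dict[str, float]


def split_rows(rows: List[Row], test_size: float) -> Tuple[List[Row], List[Row]]:
    """Hypothetical stand-in for train_test_split(test_size=...) without shuffling.

    Assumes the test set is the last int(len(rows) * test_size) rows.
    """
    if not 0 < test_size < 1:
        raise ValueError("test_size must be in (0, 1)")
    n_test = int(len(rows) * test_size)
    cut = len(rows) - n_test
    return rows[:cut], rows[cut:]


def drop_keys(rows: List[Row], cols: List[str]) -> List[Row]:
    """Hypothetical stand-in for drop_columns(cols=...): copy rows without `cols`."""
    return [{k: v for k, v in row.items() if k not in cols} for row in rows]


# Ten toy rows with a feature and a binary label.
rows = [{"mean radius": float(i), "target": float(i % 2)} for i in range(10)]
train, valid = split_rows(rows, test_size=0.3)
test = drop_keys(valid, ["target"])  # same rows as `valid`, minus the label
print(len(train), len(valid), test[0])  # 7 3 {'mean radius': 7.0}
```

Deriving the test set from the validation split keeps the two aligned row for row, which is why the predictions printed at the end can be compared against the held-out labels if needed.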

The following function will create a LightGBM trainer, train it, and return the result.

def train_lightgbm(num_workers: int, use_gpu: bool = False) -> Result:
    train_dataset, valid_dataset, _ = prepare_data()

    # Scale a couple of numeric columns, and categorize `categorical_column`
    # so that LightGBM can use its built-in categorical feature support
    scaler = StandardScaler(columns=["mean radius", "mean texture"])
    categorizer = Categorizer(["categorical_column"])

    train_dataset = categorizer.fit_transform(scaler.fit_transform(train_dataset))
    valid_dataset = categorizer.transform(scaler.transform(valid_dataset))

    # LightGBM specific params
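    params = {
        "objective": "binary",
        "metric": ["binary_logloss", "binary_error"],
    }

    trainer = LightGBMTrainer(
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
        label_column="target",
        params=params,
        datasets={"train": train_dataset, "valid": valid_dataset},
        num_boost_round=100,
        # Store the fitted preprocessors with the checkpoint so that
        # inference can apply exactly the same transformations.
        metadata={"scaler_pkl": scaler.serialize(), "categorizer_pkl": categorizer.serialize()},
    )
    result = trainer.fit()
    print(result.metrics)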

    return result
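The important detail in `train_lightgbm` is that both preprocessors are *fit* on the training split only and then reused to *transform* the validation split, so no validation statistics leak into training. The dependency-free sketch below illustrates that pattern. `MiniStandardScaler` and `MiniCategorizer` are hypothetical classes, not Ray's; using the population standard deviation and encoding unseen categories as `None` (similar to a missing value in a pandas `Categorical`) are our assumptions.

```python
import statistics
from typing import Dict, List, Tuple


class MiniStandardScaler:
    """Hypothetical scaler: learn mean/std per column on fit, reuse them on transform."""

    def __init__(self, columns: List[str]):
        self.columns = columns
        self.stats: Dict[str, Tuple[float, float]] = {}

    def fit(self, rows: List[dict]) -> "MiniStandardScaler":
        for col in self.columns:
            values = [row[col] for row in rows]
            std = statistics.pstdev(values) or 1.0  # avoid dividing by zero
            self.stats[col] = (statistics.fmean(values), std)
        return self

    def transform(self, rows: List[dict]) -> List[dict]:
        out = []
        for row in rows:
            new_row = dict(row)
            for col, (mean, std) in self.stats.items():
                new_row[col] = (row[col] - mean) / std
            out.append(new_row)
        return out


class MiniCategorizer:
    """Hypothetical categorizer: learn the category set on fit, encode as integer codes."""

    def __init__(self, columns: List[str]):
        self.columns = columns
        self.categories: Dict[str, List[str]] = {}

    def fit(self, rows: List[dict]) -> "MiniCategorizer":
        for col in self.columns:
            self.categories[col] = sorted({row[col] for row in rows})
        return self

    def transform(self, rows: List[dict]) -> List[dict]:
        out = []
        for row in rows:
            new_row = dict(row)
            for col, cats in self.categories.items():
                # Categories not seen during fit become None (assumption).
                new_row[col] = cats.index(row[col]) if row[col] in cats else None
            out.append(new_row)
        return out


train = [{"mean radius": 10.0, "categorical_column": "A"},
         {"mean radius": 14.0, "categorical_column": "B"}]
valid = [{"mean radius": 12.0, "categorical_column": "C"}]

# Fit on train only, then transform valid with the train statistics.
scaler = MiniStandardScaler(["mean radius"]).fit(train)
categorizer = MiniCategorizer(["categorical_column"]).fit(train)
print(categorizer.transform(scaler.transform(valid)))
# [{'mean radius': 0.0, 'categorical_column': None}]
```

Because the validation row is scaled with the training mean (12.0) and standard deviation (2.0), its value lands at 0.0, and its unseen category `"C"` gets no code of its own.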

Once we have the result, we can do batch inference on the obtained model. Let’s define a predictor class and a utility function for this.

import pandas as pd
from ray.train import Checkpoint

class Predict:

    def __init__(self, checkpoint: Checkpoint):
        self.model = LightGBMTrainer.get_model(checkpoint)
        self.scaler = Preprocessor.deserialize(checkpoint.get_metadata()["scaler_pkl"])
        self.categorizer = Preprocessor.deserialize(checkpoint.get_metadata()["categorizer_pkl"])

    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        # Apply the same preprocessing used at training time, then predict.
        preprocessed_batch = self.categorizer.transform_batch(self.scaler.transform_batch(batch))
        return pd.DataFrame({"predictions": self.model.predict(preprocessed_batch)})

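def predict_lightgbm(result: Result):
    _, _, test_dataset = prepare_data()

    # Run the stateful Predict class on a pool of actors, so each actor
    # loads the model and preprocessors once and reuses them for every batch.
    scores = test_dataset.map_batches(
        Predict,
        fn_constructor_args=[result.checkpoint],
        compute=ray.data.ActorPoolStrategy(),
        batch_format="pandas",
    )

    # Convert probability scores into 0/1 labels with a 0.5 threshold.
    predicted_labels = scores.map_batches(lambda df: (df > 0.5).astype(int), batch_format="pandas")
    print("PREDICTED LABELS")
    predicted_labels.show()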
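Passing the `Predict` class, rather than a function, to `map_batches` means each actor runs `__init__` once, loading the model and preprocessors a single time, and then reuses them in `__call__` for every batch. The plain-Python sketch below imitates that lifecycle with one instance standing in for one actor. `ToyPredict`, its fixed logistic "model", and the batches are invented for illustration. It also applies the same strict `> 0.5` threshold that `predict_lightgbm` uses to turn scores into labels.

```python
import math
from typing import Dict, List


class ToyPredict:
    """Hypothetical stand-in for Predict: expensive setup once, cheap per-batch calls."""

    init_calls = 0  # counts how often the "model" is loaded

    def __init__(self, weight: float, bias: float):
        ToyPredict.init_calls += 1
        self.weight = weight
        self.bias = bias

    def __call__(self, batch: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Logistic score in (0, 1), standing in for LightGBM's binary probabilities.
        scores = [1 / (1 + math.exp(-(self.weight * x + self.bias))) for x in batch["x"]]
        return {"predictions": scores}


def to_labels(batch: Dict[str, List[float]]) -> Dict[str, List[int]]:
    # Same rule as `(df > 0.5).astype(int)`: only scores strictly above 0.5 become 1.
    return {k: [int(v > 0.5) for v in values] for k, values in batch.items()}


predictor = ToyPredict(weight=2.0, bias=0.0)  # one "actor", initialized once
batches = [{"x": [-1.0, 0.0]}, {"x": [0.5, 3.0]}]
labels = [to_labels(predictor(b)) for b in batches]
print(labels, ToyPredict.init_calls)
# [{'predictions': [0, 0]}, {'predictions': [1, 1]}] 1
```

Note that a score of exactly 0.5 (here, `x = 0.0`) maps to label 0 because the comparison is strict.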

Now we can run the training:

result = train_lightgbm(num_workers=2, use_gpu=False)

Tune Status

Current time: 2023-07-07 14:34:34
Running for: 00:00:06.06
Memory: 12.2/64.0 GiB

System Info

Using FIFO scheduling algorithm.
Logical resource usage: 4.0/10 CPUs, 0/0 GPUs

Trial Status

Trial name                   status      loc              iter  total time (s)  train-binary_logloss  train-binary_error  valid-binary_logloss
LightGBMTrainer_0c5ae_00000  TERMINATED  127.0.0.1:10027  101   4.5829          0.000202293           0                   0.130232
(LightGBMTrainer pid=10027) The `preprocessor` arg to Trainer is deprecated. Apply preprocessor transformations ahead of time by calling `preprocessor.transform(ds)`. Support for the preprocessor arg will be dropped in a future release.
(LightGBMTrainer pid=10027) Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(get_pd_value_counts)]
(LightGBMTrainer pid=10027) Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
(LightGBMTrainer pid=10027) Tip: For detailed progress reporting, run ` = True`
(LightGBMTrainer pid=10027) Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
(LightGBMTrainer pid=10027) Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(Categorizer._transform_pandas)] -> AllToAllOperator[Aggregate]
(LightGBMTrainer pid=10027) Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
(LightGBMTrainer pid=10027) Tip: For detailed progress reporting, run ` = True`

(LightGBMTrainer pid=10027) Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[MapBatches(Categorizer._transform_pandas)->MapBatches(StandardScaler._transform_pandas)]
(LightGBMTrainer pid=10027) Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
(LightGBMTrainer pid=10027) Tip: For detailed progress reporting, run ` = True`
(_RemoteRayLightGBMActor pid=10063) [LightGBM] [Info] Trying to bind port 51134...
(_RemoteRayLightGBMActor pid=10063) [LightGBM] [Info] Binding port 51134 succeeded
(_RemoteRayLightGBMActor pid=10063) [LightGBM] [Info] Listening...
(_RemoteRayLightGBMActor pid=10062) [LightGBM] [Warning] Connecting to rank 1 failed, waiting for 200 milliseconds
(_RemoteRayLightGBMActor pid=10063) [LightGBM] [Info] Connected to rank 0
(_RemoteRayLightGBMActor pid=10063) [LightGBM] [Info] Local rank: 1, total number of machines: 2
(_RemoteRayLightGBMActor pid=10063) [LightGBM] [Warning] num_threads is set=2, n_jobs=-1 will be ignored. Current value: num_threads=2
(_RemoteRayLightGBMActor pid=10062) /Users/balaji/Documents/GitHub/ray/.venv/lib/python3.11/site-packages/lightgbm/ UserWarning: Overriding the parameters from Reference Dataset.
(_RemoteRayLightGBMActor pid=10062)   _log_warning('Overriding the parameters from Reference Dataset.')
(_RemoteRayLightGBMActor pid=10062) /Users/balaji/Documents/GitHub/ray/.venv/lib/python3.11/site-packages/lightgbm/ UserWarning: categorical_column in param dict is overridden.
(_RemoteRayLightGBMActor pid=10062)   _log_warning(f'{cat_alias} in param dict is overridden.')
2023-07-07 14:34:34,087	INFO -- Total run time: 7.18 seconds (6.05 seconds for the tuning loop).
{'train-binary_logloss': 0.00020229312743896637, 'train-binary_error': 0.0, 'valid-binary_logloss': 0.13023245107091222, 'valid-binary_error': 0.023529411764705882, 'time_this_iter_s': 0.021785974502563477, 'should_checkpoint': True, 'done': True, 'training_iteration': 101, 'trial_id': '0c5ae_00000', 'date': '2023-07-07_14-34-34', 'timestamp': 1688765674, 'time_total_s': 4.582904100418091, 'pid': 10027, 'hostname': 'Balajis-MacBook-Pro-16', 'node_ip': '', 'config': {}, 'time_since_restore': 4.582904100418091, 'iterations_since_restore': 101, 'experiment_tag': '0'}
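The `binary_logloss` and `binary_error` metrics requested in `params` follow the standard definitions: the mean negative log-likelihood of the labels under the predicted probabilities, and the fraction of rows misclassified when a probability above 0.5 counts as the positive class. The sketch below computes both on a few made-up labels and probabilities (not from this run), then reads the validation numbers out of the metrics dict printed above, with the values copied verbatim. Validation accuracy is simply `1 - valid-binary_error`.

```python
import math
from typing import List


def binary_logloss(labels: List[int], probs: List[float]) -> float:
    # Mean of -[y*log(p) + (1-y)*log(1-p)] over all rows.
    total = 0.0
    for y, p in zip(labels, probs):
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(labels)


def binary_error(labels: List[int], probs: List[float]) -> float:
    # Fraction of rows where the 0.5-thresholded prediction disagrees with the label.
    wrong = sum(int(p > 0.5) != y for y, p in zip(labels, probs))
    return wrong / len(labels)


# Made-up toy data, not from the training run above.
labels = [1, 0, 1, 0]
probs = [0.9, 0.2, 0.4, 0.1]
print(round(binary_logloss(labels, probs), 4), binary_error(labels, probs))  # 0.3375 0.25

# Validation metrics copied from the result printed above.
metrics = {
    "valid-binary_logloss": 0.13023245107091222,
    "valid-binary_error": 0.023529411764705882,
}
valid_accuracy = 1 - metrics["valid-binary_error"]
print(f"validation accuracy: {valid_accuracy:.4f}")  # validation accuracy: 0.9765
```

The near-zero training loss next to a validation loss of about 0.13 suggests the model fits the training split much more tightly than it generalizes, though roughly 97.6% validation accuracy is still high.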

And perform inference on the obtained model:

predict_lightgbm(result)

2023-07-07 14:34:36,769	INFO -- To satisfy the requested parallelism of 20, each read task output will be split into 20 smaller blocks.
2023-07-07 14:34:38,655	WARNING -- Warning: The Ray cluster currently does not have any available CPUs. The Dataset job will hang unless more CPUs are freed up. A common reason is that cluster resources are used by Actors or Tune trials; see the following link for more details:
2023-07-07 14:34:38,668	INFO -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2023-07-07 14:34:38,674	INFO -- Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[MapBatches(<lambda>)->MapBatches(Predict)] -> TaskPoolMapOperator[MapBatches(<lambda>)]
2023-07-07 14:34:38,674	INFO -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
2023-07-07 14:34:38,676	INFO -- Tip: For detailed progress reporting, run ` = True`
2023-07-07 14:34:38,701	INFO -- MapBatches(<lambda>)->MapBatches(Predict): Waiting for 1 pool actors to start...
{'predictions': 1}
{'predictions': 1}
{'predictions': 0}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 0}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 0}
{'predictions': 1}
{'predictions': 1}
{'predictions': 1}
{'predictions': 0}