Using LightGBM with Tune

LightGBM Logo

Example

import lightgbm as lgb
import numpy as np
import sklearn.datasets
import sklearn.metrics
from sklearn.model_selection import train_test_split

from ray import tune
from ray.air import session
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.lightgbm import TuneReportCheckpointCallback


def train_breast_cancer(config):
    """Train a LightGBM model on the breast cancer dataset as a Tune trainable.

    Args:
        config: LightGBM parameter dict for this trial, passed straight
            through to ``lgb.train``.

    The function returns nothing; results are delivered to Tune through the
    training callback and a final ``session.report`` call.
    """
    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    # Hold out 25% of the samples for evaluation. No random_state is set, so
    # every trial draws its own split.
    train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.25)
    train_set = lgb.Dataset(train_x, label=train_y)
    test_set = lgb.Dataset(test_x, label=test_y)
    # NOTE: the deprecated ``verbose_eval`` argument is intentionally not
    # passed. It triggers a UserWarning on LightGBM 3.3 and is no longer
    # accepted by newer releases; without a ``log_evaluation`` callback no
    # per-iteration evaluation output is printed anyway.
    gbm = lgb.train(
        config,
        train_set,
        valid_sets=[test_set],
        valid_names=["eval"],
        callbacks=[
            # Maps Tune metric names (keys) to LightGBM evaluation results,
            # which are named "<valid_name>-<metric>".
            TuneReportCheckpointCallback(
                {
                    "binary_error": "eval-binary_error",
                    "binary_logloss": "eval-binary_logloss",
                }
            )
        ],
    )
    preds = gbm.predict(test_x)
    # Round the predicted scores to the nearest integer to get class labels
    # (assumes a binary objective that outputs probabilities in [0, 1]).
    pred_labels = np.rint(preds)
    # Final report: hold-out accuracy, with "done" marking the trial finished.
    session.report({
        "mean_accuracy": sklearn.metrics.accuracy_score(test_y, pred_labels), "done": True
    })


if __name__ == "__main__":
    import argparse

    # Command-line handling: the only option is an optional Ray Client address.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--server-address",
        type=str,
        default=None,
        required=False,
        help="The address of server to connect to if using Ray Client.",
    )
    # parse_known_args() is used so that unrecognized arguments are ignored
    # instead of aborting the script.
    cli_args, _ = cli.parse_known_args()

    # Connect through Ray Client only when an address was supplied.
    if cli_args.server_address:
        import ray

        ray.init(f"ray://{cli_args.server_address}")

    # LightGBM parameters handed to every trial; the tune.* entries are the
    # dimensions that Tune searches over.
    config = dict(
        objective="binary",
        metric=["binary_error", "binary_logloss"],
        verbose=-1,
        boosting_type=tune.grid_search(["gbdt", "dart"]),
        num_leaves=tune.randint(10, 1000),
        learning_rate=tune.loguniform(1e-8, 1e-1),
    )

    # Minimize the reported binary_error, using the ASHA scheduler.
    search_settings = tune.TuneConfig(
        metric="binary_error",
        mode="min",
        scheduler=ASHAScheduler(),
        num_samples=2,
    )
    tuner = tune.Tuner(
        train_breast_cancer,
        tune_config=search_settings,
        param_space=config,
    )
    results = tuner.fit()

    best_result = results.get_best_result()
    print("Best hyperparameters found were: ", best_result.config)
2022-07-22 15:30:02,623	INFO services.py:1483 -- View the Ray dashboard at http://127.0.0.1:8265
2022-07-22 15:30:05,042	WARNING function_trainable.py:619 -- Function checkpointing is disabled. This may result in unexpected behavior when using checkpointing features or certain schedulers. To enable, set the train function arguments to be `func(config, checkpoint_dir=None)`.
== Status ==
Current time: 2022-07-22 15:30:18 (running for 00:00:12.88)
Memory usage on this node: 10.1/16.0 GiB
Using AsyncHyperBand: num_stopped=4 Bracket: Iter 64.000: -0.32867132867132864 | Iter 16.000: -0.32867132867132864 | Iter 4.000: -0.32867132867132864 | Iter 1.000: -0.35664335664335667
Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/5.3 GiB heap, 0.0/2.0 GiB objects
Current best trial: c7534_00003 with binary_error=0.3146853146853147 and parameters={'objective': 'binary', 'metric': ['binary_error', 'binary_logloss'], 'verbose': -1, 'boosting_type': 'dart', 'num_leaves': 702, 'learning_rate': 4.858514533326432e-08}
Result logdir: /Users/kai/ray_results/train_breast_cancer_2022-07-22_15-29-59
Number of trials: 4/4 (4 TERMINATED)
| Trial name | status | loc | boosting_type | learning_rate | num_leaves | iter | total time (s) | binary_error | binary_logloss |
|---|---|---|---|---|---|---|---|---|---|
| train_breast_cancer_c7534_00000 | TERMINATED | 127.0.0.1:46947 | gbdt | 1.09528e-08 | 926 | 100 | 4.04621 | 0.370629 | 0.659303 |
| train_breast_cancer_c7534_00001 | TERMINATED | 127.0.0.1:46965 | dart | 9.07058e-05 | 512 | 1 | 0.0379331 | 0.391608 | 0.670769 |
| train_breast_cancer_c7534_00002 | TERMINATED | 127.0.0.1:46987 | gbdt | 0.00110605 | 186 | 1 | 0.0196211 | 0.405594 | 0.678443 |
| train_breast_cancer_c7534_00003 | TERMINATED | 127.0.0.1:46988 | dart | 4.85851e-08 | 702 | 100 | 0.417179 | 0.314685 | 0.655626 |


2022-07-22 15:30:06,224	INFO plugin_schema_manager.py:52 -- Loading the default runtime env schemas: ['/Users/kai/coding/ray/python/ray/_private/runtime_env/../../runtime_env/schemas/working_dir_schema.json', '/Users/kai/coding/ray/python/ray/_private/runtime_env/../../runtime_env/schemas/pip_schema.json'].
(train_breast_cancer pid=46947) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.
(train_breast_cancer pid=46947)   _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
Result for train_breast_cancer_c7534_00000:
  binary_error: 0.3706293706293706
  binary_logloss: 0.6593043583564255
  date: 2022-07-22_15-30-11
  done: false
  experiment_id: 9fbbf2cd94b24a14aa5ef2d552e78b70
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 1
  node_ip: 127.0.0.1
  pid: 46947
  time_since_restore: 0.10576009750366211
  time_this_iter_s: 0.10576009750366211
  time_total_s: 0.10576009750366211
  timestamp: 1658500211
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: c7534_00000
  warmup_time: 0.0033888816833496094
  
Result for train_breast_cancer_c7534_00001:
  binary_error: 0.3916083916083916
  binary_logloss: 0.670769405026208
  date: 2022-07-22_15-30-14
  done: true
  experiment_id: 10df796f3d2e4627ba7526014b21f426
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 1
  node_ip: 127.0.0.1
  pid: 46965
  time_since_restore: 0.0379331111907959
  time_this_iter_s: 0.0379331111907959
  time_total_s: 0.0379331111907959
  timestamp: 1658500214
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: c7534_00001
  warmup_time: 0.0033578872680664062
  
(train_breast_cancer pid=46965) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.
(train_breast_cancer pid=46965)   _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
Result for train_breast_cancer_c7534_00000:
  binary_error: 0.3706293706293706
  binary_logloss: 0.6593034612409915
  date: 2022-07-22_15-30-15
  done: true
  experiment_id: 9fbbf2cd94b24a14aa5ef2d552e78b70
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 100
  node_ip: 127.0.0.1
  pid: 46947
  time_since_restore: 4.046205043792725
  time_this_iter_s: 0.002338886260986328
  time_total_s: 4.046205043792725
  timestamp: 1658500215
  timesteps_since_restore: 0
  training_iteration: 100
  trial_id: c7534_00000
  warmup_time: 0.0033888816833496094
  
Result for train_breast_cancer_c7534_00003:
  binary_error: 0.3146853146853147
  binary_logloss: 0.635705942279978
  date: 2022-07-22_15-30-18
  done: false
  experiment_id: d370b87343ea4a8e994bcf99a4f6f28d
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 1
  node_ip: 127.0.0.1
  pid: 46988
  time_since_restore: 0.04007911682128906
  time_this_iter_s: 0.04007911682128906
  time_total_s: 0.04007911682128906
  timestamp: 1658500218
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: c7534_00003
  warmup_time: 0.0032351016998291016
  
Result for train_breast_cancer_c7534_00002:
  binary_error: 0.40559440559440557
  binary_logloss: 0.6784426899984863
  date: 2022-07-22_15-30-18
  done: true
  experiment_id: 96e95ab236aa40aea3e9a1218293b562
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 1
  node_ip: 127.0.0.1
  pid: 46987
  time_since_restore: 0.01962113380432129
  time_this_iter_s: 0.01962113380432129
  time_total_s: 0.01962113380432129
  timestamp: 1658500218
  timesteps_since_restore: 0
  training_iteration: 1
  trial_id: c7534_00002
  warmup_time: 0.0026988983154296875
  
(train_breast_cancer pid=46987) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.
(train_breast_cancer pid=46987)   _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
(train_breast_cancer pid=46988) /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.
(train_breast_cancer pid=46988)   _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
Result for train_breast_cancer_c7534_00003:
  binary_error: 0.3146853146853147
  binary_logloss: 0.6556262981958247
  date: 2022-07-22_15-30-18
  done: true
  experiment_id: d370b87343ea4a8e994bcf99a4f6f28d
  hostname: Kais-MacBook-Pro.local
  iterations_since_restore: 100
  node_ip: 127.0.0.1
  pid: 46988
  time_since_restore: 0.4171791076660156
  time_this_iter_s: 0.0024061203002929688
  time_total_s: 0.4171791076660156
  timestamp: 1658500218
  timesteps_since_restore: 0
  training_iteration: 100
  trial_id: c7534_00003
  warmup_time: 0.0032351016998291016
  
2022-07-22 15:30:18,873	INFO tune.py:738 -- Total run time: 13.83 seconds (12.87 seconds for the tuning loop).
Best hyperparameters found were:  {'objective': 'binary', 'metric': ['binary_error', 'binary_logloss'], 'verbose': -1, 'boosting_type': 'dart', 'num_leaves': 702, 'learning_rate': 4.858514533326432e-08}