{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "edce67b9", "metadata": {}, "source": [ "# Tuning XGBoost hyperparameters with Ray Tune\n", "\n",
"\n", "(tune-xgboost-ref)=\n", "\n", "This tutorial demonstrates how to optimize XGBoost models using Ray Tune. You'll learn:\n", "- The basics of XGBoost and its key hyperparameters\n", "- How to train a simple XGBoost classifier (without hyperparameter tuning)\n", "- How to use Ray Tune to find optimal hyperparameters\n", "- Advanced techniques like early stopping and GPU acceleration\n", "\n", "XGBoost is currently one of the most popular machine learning algorithms. It performs\n", "very well on a wide range of tasks, and was the key to success in many Kaggle\n", "competitions.\n", "\n", "```{image} /images/xgboost_logo.png\n", ":align: center\n", ":alt: XGBoost\n", ":target: https://xgboost.readthedocs.io/en/latest/\n", ":width: 200px\n", "```\n", "\n", "```{contents}\n", ":depth: 2\n", "```\n", "\n", ":::{note}\n", "To run this tutorial, you will need to install the following:\n", "\n", "```bash\n", "$ pip install -q \"ray[tune]\" scikit-learn xgboost\n", "```\n", ":::\n", "\n", "## What is XGBoost\n", "\n", "\n", "XGBoost (e**X**treme **G**radient **Boost**ing) is a powerful and efficient implementation of gradient boosted [decision trees](https://en.wikipedia.org/wiki/Decision_tree). It has become one of the most popular machine learning algorithms due to its:\n", "\n", "1. Performance: Consistently strong results across many types of problems\n", "2. Speed: Highly optimized implementation that can leverage GPU acceleration\n", "3. Flexibility: Works with many types of prediction problems (classification, regression, ranking)\n", "\n", "Key Concepts:\n", "- Uses an ensemble of simple decision trees\n", "- Trees are built sequentially, with each tree correcting errors from previous trees\n", "- Employs gradient descent to minimize a loss function\n", "- Even though single trees can have high bias, using a boosted ensemble can result in better predictions and reduced bias\n", "\n", "\n", ":::{figure} /images/tune-xgboost-ensemble.svg\n", ":alt: Single vs. ensemble learning\n", "\n", "A single decision tree (left) might be able to get to an accuracy of 70%\n", "for a binary classification task. By combining the output of several small\n", "decision trees, an ensemble learner (right) might end up with a higher accuracy\n", "of 90%.\n", ":::\n", "\n", "Boosting algorithms start with a single small decision tree and evaluate how well\n", "it predicts the given examples. When building the next tree, those samples that have\n", "been misclassified before have a higher chance of being used to generate the tree.\n", "This is useful because it avoids overfitting to samples that can be easily classified\n", "and instead tries to come up with models that are able to classify hard examples, too.\n", "Please see [here for a more thorough introduction to bagging and boosting algorithms](https://towardsdatascience.com/ensemble-methods-bagging-boosting-and-stacking-c9214a10a205).\n", "\n", "There are many boosting algorithms. At their core, they are all very similar. XGBoost\n", "uses second-order derivatives to find splits that maximize the *gain* (the reduction in\n", "the *loss*) - hence the name. In practice, XGBoost usually performs better than\n", "other boosting algorithms, although LightGBM tends to be [faster and more\n", "memory efficient](https://xgboosting.com/xgboost-vs-lightgbm/), especially for large datasets.\n", "\n",
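"To make the boosting idea concrete, here is a toy sketch of the sequential residual-fitting loop,\n", "using plain scikit-learn regression trees and a squared-error fit on the raw labels. It only\n", "illustrates the concept - it is not XGBoost's actual algorithm:\n", "\n", "```python\n", "import numpy as np\n", "from sklearn.datasets import load_breast_cancer\n", "from sklearn.tree import DecisionTreeRegressor\n", "\n", "X, y = load_breast_cancer(return_X_y=True)\n", "eta, trees = 0.3, []\n", "pred = np.zeros_like(y, dtype=float)\n", "\n", "for _ in range(10):\n", "    residual = y - pred  # errors of the current ensemble\n", "    tree = DecisionTreeRegressor(max_depth=3).fit(X, residual)\n", "    trees.append(tree)\n", "    pred += eta * tree.predict(X)  # eta dampens each tree's contribution\n", "\n", "print(f\"Toy boosting accuracy (train): {((pred > 0.5) == y).mean():.3f}\")\n", "```\n", "\n",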
"## Training a simple XGBoost classifier\n", "\n", "Let's first see how a simple XGBoost classifier can be trained. We'll use the\n", "`breast_cancer` dataset included in the `sklearn` dataset collection. This is\n", "a binary classification dataset. Given 30 different input features, our task is to\n", "learn to identify subjects with breast cancer and those without.\n", "\n", "Here is the full code to train a simple XGBoost model:" ] }, { "cell_type": "code", "execution_count": 1, "id": "63611b7f", "metadata": {}, "outputs": [], "source": [ "SMOKE_TEST = False" ] }, { "cell_type": "code", "execution_count": 2, "id": "be0b8321", "metadata": { "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "SMOKE_TEST = True" ] }, { "cell_type": "code", "execution_count": 3, "id": "77b3c71c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.9650\n" ] } ], "source": [ "import sklearn.datasets\n", "import sklearn.metrics\n", "from sklearn.model_selection import train_test_split\n", "import xgboost as xgb\n", "\n", "\n", "def train_breast_cancer(config):\n", "    # Load dataset\n", "    data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)\n", "    # Split into train and test set\n", "    train_x, test_x, train_y, test_y = train_test_split(data, labels, test_size=0.25)\n", "    # Build input matrices for XGBoost\n", "    train_set = xgb.DMatrix(train_x, label=train_y)\n", "    test_set = xgb.DMatrix(test_x, label=test_y)\n", "    # Train the classifier\n", "    results = {}\n", "    bst = xgb.train(\n", "        config,\n", "        train_set,\n", "        evals=[(test_set, \"eval\")],\n", "        evals_result=results,\n", "        verbose_eval=False,\n", "    )\n", "    return results\n", "\n", "\n", "results = train_breast_cancer(\n", "    {\"objective\": \"binary:logistic\", \"eval_metric\": [\"logloss\", \"error\"]}\n", ")\n", "accuracy = 1.0 - results[\"eval\"][\"error\"][-1]\n", "print(f\"Accuracy: {accuracy:.4f}\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ec2a13f8", "metadata": {}, "source": [ "As you can see, the code is quite simple. First, the dataset is loaded and split\n", "into a `test` and `train` set. The XGBoost model is trained with `xgb.train()`.\n", "XGBoost automatically evaluates the metrics we specified on the test set. In our case\n", "it calculates the *logloss* and the prediction *error*, which is the percentage of\n", "misclassified examples. To calculate the accuracy, we just have to subtract the error\n", "from `1.0`. Even in this simple example, most runs result\n", "in a good accuracy of over `0.90`.\n", "\n", "You may have noticed the `config` parameter we pass to the XGBoost algorithm. This\n", "is a {class}`dict` in which you can specify parameters for the XGBoost algorithm. In this\n", "simple example, the only parameters we passed are `objective` and `eval_metric`.\n", "The value `binary:logistic` tells XGBoost that we aim to train a logistic regression model for\n", "a binary classification task. You can find an overview of all valid objectives\n", "[here in the XGBoost documentation](https://xgboost.readthedocs.io/en/latest/parameter.html#learning-task-parameters).\n", "\n", "## Scaling XGBoost Training with Ray Train\n", "\n", "In {doc}`/train/examples/xgboost/distributed-xgboost-lightgbm`, we covered how to scale XGBoost single-model training with *Ray Train*.\n", "For the rest of this tutorial, we will focus on how to optimize the hyperparameters of the XGBoost model using *Ray Tune*.\n", "\n", "## XGBoost Hyperparameters\n", "\n", "Even with the default settings, XGBoost was able to get to a good accuracy on the\n",
"breast cancer dataset. However, as with many machine learning algorithms, there are\n", "many knobs to tune which might lead to even better performance. Let's explore some of\n", "them below.\n", "\n", "### Maximum tree depth\n", "\n", "Remember that XGBoost internally uses many decision tree models to come up with\n", "predictions. When training a decision tree, we need to tell the algorithm how\n", "large the tree may get. The parameter for this is called the tree *depth*.\n", "\n", ":::{figure} /images/tune-xgboost-depth.svg\n", ":align: center\n", ":alt: Decision tree depth\n", "\n", "In this image, the left tree has a depth of 2, and the right tree a depth of 3.\n", "Note that with each level, $2^{(d-1)}$ splits are added, where *d* is the depth\n", "of the tree.\n", ":::\n", "\n", "Tree depth is a property that affects model complexity. If you only allow short\n", "trees, the models are likely not very precise - they underfit the data. If you allow\n", "very large trees, the single models are likely to overfit the data. In practice,\n", "a number between `2` and `6` is often a good starting point for this parameter.\n", "\n", "XGBoost's default value is `6`.\n", "\n", "### Minimum child weight\n", "\n", "When a decision tree creates new leaves, it splits up the remaining data at one node\n", "into two groups. If there are only a few samples in one of these groups, it often\n", "doesn't make sense to split it further. One of the reasons for this is that the\n", "model is harder to train when we have fewer samples.\n", "\n", ":::{figure} /images/tune-xgboost-weight.svg\n", ":align: center\n", ":alt: Minimum child weight\n", "\n", "In this example, we start with 100 examples. At the first node, they are split\n", "into 4 and 96 samples, respectively. In the next step, our model might find\n", "that it doesn't make sense to split the 4 examples more. It thus only continues\n", "to add leaves on the right side.\n", ":::\n", "\n", "The parameter used by the model to decide if it makes sense to split a node is called\n", "the *minimum child weight*. In the case of linear regression, this is just the minimum\n", "number of instances required in each child node. In other objectives, this value is determined\n", "using the weights of the examples, hence the name.\n", "\n", "The larger the value, the more constrained the trees are and the less deep they will be.\n", "This parameter therefore also affects model complexity. For noisy or small datasets,\n", "smaller values are preferred. Values can range between 0 and infinity and are dependent on\n", "the sample size. For our case, with only around 500 examples in the breast cancer dataset, values\n", "between `0` and `10` should be sensible.\n", "\n", "XGBoost's default value is `1`.\n", "\n", "### Subsample size\n", "\n", "Each decision tree we add is trained on a subsample of the total training dataset.\n", "The probabilities for the samples are weighted according to the XGBoost algorithm,\n", "but we can decide which fraction of the samples we want to train each decision\n", "tree on.\n", "\n", "Setting this value to `0.7` would mean that we randomly sample `70%` of the\n", "training dataset before each training iteration. Lower values lead to more\n", "diverse trees and higher values to more similar trees. Lower values also help\n",
"prevent overfitting.\n", "\n", "XGBoost's default value is `1`.\n", "\n", "### Learning rate / Eta\n", "\n", "Remember that XGBoost sequentially trains many decision trees, and that later trees\n", "are more likely trained on data that has been misclassified by prior trees. In effect\n", "this means that earlier trees make decisions for easy samples (i.e. those samples that\n", "can easily be classified) and later trees make decisions for harder samples. It is then\n", "sensible to assume that the later trees are less accurate than earlier trees.\n", "\n", "To address this fact, XGBoost uses a parameter called *Eta*, which is sometimes called\n", "the *learning rate*. Don't confuse this with learning rates from gradient descent!\n", "The original [paper on stochastic gradient boosting](https://www.researchgate.net/publication/222573328_Stochastic_Gradient_Boosting)\n", "introduces this parameter like so:\n", "\n", "$$\n", "F_m(x) = F_{m-1}(x) + \\eta \\cdot \\gamma_{lm} \\textbf{1}(x \\in R_{lm})\n", "$$\n", "\n", "This is just a complicated way to say that when we train a new decision tree,\n", "represented by $\\gamma_{lm} \\textbf{1}(x \\in R_{lm})$, we want to dampen\n", "its effect on the previous prediction $F_{m-1}(x)$ with a factor\n", "$\\eta$.\n", "\n", "Typical values for this parameter are between `0.01` and `0.3`.\n", "\n", "XGBoost's default value is `0.3`.\n", "\n", "### Number of boost rounds\n", "\n", "Lastly, we can decide on how many boosting rounds we perform, which means how\n", "many decision trees we ultimately train. When we do heavy subsampling or use a small\n", "learning rate, it might make sense to increase the number of boosting rounds.\n", "\n", "XGBoost's default value is `10`.\n", "\n", "### Putting it together\n", "\n", "Let's see what this looks like in code! We just need to adjust our `config` dict:" ] }, { "cell_type": "code", "execution_count": 4, "id": "35073e88", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.9231\n" ] } ], "source": [ "config = {\n", "    \"objective\": \"binary:logistic\",\n", "    \"eval_metric\": [\"logloss\", \"error\"],\n", "    \"max_depth\": 2,\n", "    \"min_child_weight\": 0,\n", "    \"subsample\": 0.8,\n", "    \"eta\": 0.2,\n", "}\n", "results = train_breast_cancer(config)\n", "accuracy = 1.0 - results[\"eval\"][\"error\"][-1]\n", "print(f\"Accuracy: {accuracy:.4f}\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "69cf0c13", "metadata": {}, "source": [ "The rest stays the same. Please note that we do not adjust the `num_boost_round` here.\n", "The result should also show a high accuracy of over 90%.\n", "\n", "## Tuning the configuration parameters\n", "\n", "XGBoost's default parameters already lead to a good accuracy, and even our guesses in the\n", "last section should result in accuracies well above 90%. However, our guesses were\n", "just that: guesses. Often we do not know what combination of parameters would actually\n", "lead to the best results on a machine learning task.\n", "\n", "Unfortunately, there are infinitely many combinations of hyperparameters we could try\n", "out. Should we combine `max_depth=3` with `subsample=0.8` or with `subsample=0.9`?\n", "What about the other parameters?\n", "\n", "This is where hyperparameter tuning comes into play. By using tuning libraries such as\n", "Ray Tune we can try out combinations of hyperparameters. Using sophisticated search\n",
"strategies, these parameters can be selected so that they are likely to lead to good\n", "results (avoiding an expensive *exhaustive search*). Also, trials that do not perform\n", "well can be stopped early to avoid wasting computing resources. Lastly, Ray Tune\n", "also takes care of training these runs in parallel, greatly increasing search speed.\n", "\n", "Let's start with a basic example of how to use Tune for this. We just need to make\n", "a few changes to our code:" ] }, { "cell_type": "code", "execution_count": 5, "id": "ff856a82", "metadata": { "tags": [ "hide-output" ] }, "outputs": [ { "data": { "text/html": [
"<pre>\n", "== Tune Status ==\n", "Current time: 2025-02-11 16:13:34 | Running for: 00:00:01.87 | Memory: 22.5/36.0 GiB\n", "Using FIFO scheduling algorithm. Logical resource usage: 1.0/12 CPUs, 0/0 GPUs\n", "\n", "== Trial Status: 10/10 TERMINATED ==\n", "train_breast_cancer_31c9f_00000  TERMINATED  loc=127.0.0.1:89735  eta=0.0434196    max_depth=8  min_child_weight=1  subsample=0.530351  acc=0.909091  iter=1  total time (s)=0.0114911\n", "train_breast_cancer_31c9f_00001  TERMINATED  loc=127.0.0.1:89734  eta=0.0115669    max_depth=6  min_child_weight=2  subsample=0.996519  acc=0.615385  iter=1  total time (s)=0.01138\n", "train_breast_cancer_31c9f_00002  TERMINATED  loc=127.0.0.1:89740  eta=0.00124339   max_depth=7  min_child_weight=3  subsample=0.536078  acc=0.629371  iter=1  total time (s)=0.0096581\n", "train_breast_cancer_31c9f_00003  TERMINATED  loc=127.0.0.1:89742  eta=0.000400434  max_depth=6  min_child_weight=3  subsample=0.90014   acc=0.601399  iter=1  total time (s)=0.0103199\n", "train_breast_cancer_31c9f_00004  TERMINATED  loc=127.0.0.1:89738  eta=0.0121308    max_depth=6  min_child_weight=3  subsample=0.843156  acc=0.629371  iter=1  total time (s)=0.00843\n", "train_breast_cancer_31c9f_00005  TERMINATED  loc=127.0.0.1:89733  eta=0.0344144    max_depth=2  min_child_weight=3  subsample=0.513071  acc=0.895105  iter=1  total time (s)=0.00800109\n", "train_breast_cancer_31c9f_00006  TERMINATED  loc=127.0.0.1:89737  eta=0.0530037    max_depth=7  min_child_weight=2  subsample=0.920801  acc=0.965035  iter=1  total time (s)=0.0117419\n", "train_breast_cancer_31c9f_00007  TERMINATED  loc=127.0.0.1:89741  eta=0.000230442  max_depth=3  min_child_weight=3  subsample=0.946852  acc=0.608392  iter=1  total time (s)=0.00917387\n", "train_breast_cancer_31c9f_00008  TERMINATED  loc=127.0.0.1:89739  eta=0.00166323   max_depth=4  min_child_weight=1  subsample=0.588879  acc=0.636364  iter=1  total time (s)=0.011095\n", "train_breast_cancer_31c9f_00009  TERMINATED  loc=127.0.0.1:89736  eta=0.0753618    max_depth=3  min_child_weight=3  subsample=0.55103   acc=0.909091  iter=1  total time (s)=0.00776482\n", "</pre>
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "2025-02-11 16:13:34,649\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-13-31' in 0.0057s.\n", "2025-02-11 16:13:34,652\tINFO tune.py:1041 -- Total run time: 1.88 seconds (1.86 seconds for the tuning loop).\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[36m(train_breast_cancer pid=90413)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-17-11/train_breast_cancer_b412c_00000_0_eta=0.0200,max_depth=4,min_child_weight=2,subsample=0.7395_2025-02-11_16-17-11/checkpoint_000000)\n", "\u001b[36m(train_breast_cancer pid=90413)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-17-11/train_breast_cancer_b412c_00000_0_eta=0.0200,max_depth=4,min_child_weight=2,subsample=0.7395_2025-02-11_16-17-11/checkpoint_000001)\n", "\u001b[36m(train_breast_cancer pid=90413)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-17-11/train_breast_cancer_b412c_00000_0_eta=0.0200,max_depth=4,min_child_weight=2,subsample=0.7395_2025-02-11_16-17-11/checkpoint_000002)\n", "\u001b[36m(train_breast_cancer pid=90413)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-17-11/train_breast_cancer_b412c_00000_0_eta=0.0200,max_depth=4,min_child_weight=2,subsample=0.7395_2025-02-11_16-17-11/checkpoint_000003)\n", "\u001b[36m(train_breast_cancer pid=90413)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-17-11/train_breast_cancer_b412c_00000_0_eta=0.0200,max_depth=4,min_child_weight=2,subsample=0.7395_2025-02-11_16-17-11/checkpoint_000004)\n", "\u001b[36m(train_breast_cancer pid=90413)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-17-11/train_breast_cancer_b412c_00000_0_eta=0.0200,max_depth=4,min_child_weight=2,subsample=0.7395_2025-02-11_16-17-11/checkpoint_000005)\n", "\u001b[36m(train_breast_cancer pid=90413)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-17-11/train_breast_cancer_b412c_00000_0_eta=0.0200,max_depth=4,min_child_weight=2,subsample=0.7395_2025-02-11_16-17-11/checkpoint_000006)\n", "\u001b[36m(train_breast_cancer pid=90413)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-17-11/train_breast_cancer_b412c_00000_0_eta=0.0200,max_depth=4,min_child_weight=2,subsample=0.7395_2025-02-11_16-17-11/checkpoint_000007)\n", "\u001b[36m(train_breast_cancer pid=90413)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-17-11/train_breast_cancer_b412c_00000_0_eta=0.0200,max_depth=4,min_child_weight=2,subsample=0.7395_2025-02-11_16-17-11/checkpoint_000008)\n", "\u001b[36m(train_breast_cancer pid=90413)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, 
path=/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-17-11/train_breast_cancer_b412c_00000_0_eta=0.0200,max_depth=4,min_child_weight=2,subsample=0.7395_2025-02-11_16-17-11/checkpoint_000009)\n" ] } ], "source": [ "import sklearn.datasets\n", "import sklearn.metrics\n", "\n", "from ray import tune\n", "\n", "\n", "def train_breast_cancer(config):\n", "    # Load dataset\n", "    data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)\n", "    # Split into train and test set\n", "    train_x, test_x, train_y, test_y = train_test_split(data, labels, test_size=0.25)\n", "    # Build input matrices for XGBoost\n", "    train_set = xgb.DMatrix(train_x, label=train_y)\n", "    test_set = xgb.DMatrix(test_x, label=test_y)\n", "    # Train the classifier\n", "    results = {}\n", "    xgb.train(\n", "        config,\n", "        train_set,\n", "        evals=[(test_set, \"eval\")],\n", "        evals_result=results,\n", "        verbose_eval=False,\n", "    )\n", "    # Return prediction accuracy\n", "    accuracy = 1.0 - results[\"eval\"][\"error\"][-1]\n", "    tune.report({\"mean_accuracy\": accuracy, \"done\": True})\n", "\n", "\n", "config = {\n", "    \"objective\": \"binary:logistic\",\n", "    \"eval_metric\": [\"logloss\", \"error\"],\n", "    \"max_depth\": tune.randint(1, 9),\n", "    \"min_child_weight\": tune.choice([1, 2, 3]),\n", "    \"subsample\": tune.uniform(0.5, 1.0),\n", "    \"eta\": tune.loguniform(1e-4, 1e-1),\n", "}\n", "tuner = tune.Tuner(\n", "    train_breast_cancer,\n", "    tune_config=tune.TuneConfig(num_samples=10),\n", "    param_space=config,\n", ")\n", "results = tuner.fit()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4999e858", "metadata": {}, "source": [ "As you can see, the changes in the actual training function are minimal. Instead of\n", "returning the accuracy value, we report it back to Tune using `tune.report()`.\n", "Our `config` dictionary only changed slightly. Instead of passing hard-coded\n", "parameters, we tell Tune to choose values from a range of valid options. There are\n", "a number of options we have here, all of which are explained in\n", "{ref}`the Tune docs `.\n", "\n", "For a brief explanation, this is what they do:\n", "\n", "- `tune.randint(min, max)` chooses a random integer value between *min* and *max*.\n", "  Note that *max* is exclusive, so it will not be sampled.\n", "- `tune.choice([a, b, c])` chooses one of the items of the list at random. Each item\n", "  has the same chance to be sampled.\n", "- `tune.uniform(min, max)` samples a floating point number between *min* and *max*.\n", "  Note that *max* is exclusive here, too.\n", "- `tune.loguniform(min, max)` samples a floating point number between *min* and *max*,\n", "  but applies a logarithmic transformation to these boundaries first. This makes it\n", "  easy to sample values from different orders of magnitude, as the short sketch after\n", "  this list illustrates.\n",
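"\n", "For intuition, here is a small, self-contained sketch of the difference between uniform and\n", "log-uniform sampling over the `eta` range used above. This is plain Python for illustration\n", "only, not how Tune draws samples internally:\n", "\n", "```python\n", "import math\n", "import random\n", "\n", "random.seed(0)\n", "\n", "uniform_draws = [random.uniform(1e-4, 1e-1) for _ in range(5)]\n", "# Log-uniform: sample the exponent uniformly, then exponentiate.\n", "loguniform_draws = [\n", "    math.exp(random.uniform(math.log(1e-4), math.log(1e-1))) for _ in range(5)\n", "]\n", "\n", "print(\"uniform:   \", [f\"{x:.5f}\" for x in uniform_draws])\n", "print(\"loguniform:\", [f\"{x:.5f}\" for x in loguniform_draws])\n", "```\n", "\n", "Roughly 90% of the uniform draws land between `0.01` and `0.1`, while the log-uniform draws\n", "spread evenly across the orders of magnitude in between - which is usually what you want for a\n", "learning rate.\n",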
"\n", "The `num_samples=10` option we pass to the `TuneConfig()` means that we sample 10 different\n", "hyperparameter configurations from this search space.\n", "\n", "The output of our training run could look like this:\n", "\n", "```{code-block} bash\n", ":emphasize-lines: 14\n", "\n", " Number of trials: 10/10 (10 TERMINATED)\n", " +---------------------------------+------------+-------+-------------+-------------+--------------------+-------------+----------+--------+------------------+\n", " | Trial name | status | loc | eta | max_depth | min_child_weight | subsample | acc | iter | total time (s) |\n", " |---------------------------------+------------+-------+-------------+-------------+--------------------+-------------+----------+--------+------------------|\n", " | train_breast_cancer_b63aa_00000 | TERMINATED | | 0.000117625 | 2 | 2 | 0.616347 | 0.916084 | 1 | 0.0306492 |\n", " | train_breast_cancer_b63aa_00001 | TERMINATED | | 0.0382954 | 8 | 2 | 0.581549 | 0.937063 | 1 | 0.0357082 |\n", " | train_breast_cancer_b63aa_00002 | TERMINATED | | 0.000217926 | 1 | 3 | 0.528428 | 0.874126 | 1 | 0.0264609 |\n", " | train_breast_cancer_b63aa_00003 | TERMINATED | | 0.000120929 | 8 | 1 | 0.634508 | 0.958042 | 1 | 0.036406 |\n", " | train_breast_cancer_b63aa_00004 | TERMINATED | | 0.00839715 | 5 | 1 | 0.730624 | 0.958042 | 1 | 0.0389378 |\n", " | train_breast_cancer_b63aa_00005 | TERMINATED | | 0.000732948 | 8 | 2 | 0.915863 | 0.958042 | 1 | 0.0382841 |\n", " | train_breast_cancer_b63aa_00006 | TERMINATED | | 0.000856226 | 4 | 1 | 0.645209 | 0.916084 | 1 | 0.0357089 |\n", " | train_breast_cancer_b63aa_00007 | TERMINATED | | 0.00769908 | 7 | 1 | 0.729443 | 0.909091 | 1 | 0.0390737 |\n", " | train_breast_cancer_b63aa_00008 | TERMINATED | | 0.00186339 | 5 | 3 | 0.595744 | 0.944056 | 1 | 0.0343912 |\n", " | train_breast_cancer_b63aa_00009 | TERMINATED | | 0.000950272 | 3 | 2 | 0.835504 | 0.965035 | 1 | 0.0348201 |\n", " +---------------------------------+------------+-------+-------------+-------------+--------------------+-------------+----------+--------+------------------+\n", "```\n", "\n", "The best configuration we found used `eta=0.000950272`, `max_depth=3`,\n", "`min_child_weight=2`, `subsample=0.835504` and reached an accuracy of\n", "`0.965035`.\n", "\n", "## Early stopping\n", "\n", "Currently, Tune samples 10 different hyperparameter configurations and trains a full\n", "XGBoost model on all of them. In our small example, training is very fast. However,\n", "if training takes longer, a significant amount of computing resources is spent on trials\n", "that will eventually show bad performance, e.g. a low accuracy. It would be good\n", "if we could identify these trials early and stop them, so we don't waste any resources.\n", "\n", "This is where Tune's *Schedulers* shine. A Tune `TrialScheduler` is responsible\n", "for starting and stopping trials. Tune implements a number of different schedulers, each\n", "described {ref}`in the Tune documentation `.\n", "For our example, we will use the `AsyncHyperBandScheduler`, also known as the `ASHAScheduler`.\n", "\n", "The basic idea of this scheduler: We sample a number of hyperparameter configurations.\n", "Each of these configurations is trained for a specific number of iterations.\n", "After these iterations, only the best performing hyperparameters are retained. These\n", "are selected according to some loss metric, usually an evaluation loss. This cycle is\n", "repeated until we end up with the best configuration.\n",
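"\n", "To make the pruning schedule concrete, here is a rough, simplified sketch of the successive\n", "halving arithmetic for the settings used below (`reduction_factor=2`, `grace_period=1`,\n", "`max_t=10`). The real `ASHAScheduler` promotes trials asynchronously, so treat this only as an\n", "illustration of the idea:\n", "\n", "```python\n", "num_trials = 10\n", "reduction_factor = 2\n", "rung = 1      # grace_period: every trial runs at least this many iterations\n", "max_t = 10    # maximum number of training iterations per trial\n", "\n", "while num_trials > 1 and rung < max_t:\n", "    # Only roughly the best 1/reduction_factor of the trials advance.\n", "    num_trials = max(1, num_trials // reduction_factor)\n", "    rung = min(max_t, rung * reduction_factor)\n", "    print(f\"{num_trials} trial(s) continue training up to iteration {rung}\")\n", "```\n", "\n", "With 10 trials this prints three rungs: 5 trials continue to iteration 2, 2 trials to\n", "iteration 4, and a single trial to iteration 8 and beyond.\n",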
"\n", "The `ASHAScheduler` needs to know three things:\n", "\n", "1. Which metric should be used to identify badly performing trials?\n", "2. Should this metric be maximized or minimized?\n", "3. How many iterations does each trial train for?\n", "\n", "There are more parameters, which are explained in the\n", "{ref}`documentation `.\n", "\n", "Lastly, we have to report the loss metric to Tune. We do this with a `Callback` that\n", "XGBoost accepts and calls after each evaluation round. Ray Tune comes\n", "with {ref}`two XGBoost callbacks `\n", "we can use for this. The `TuneReportCallback` just reports the evaluation\n", "metrics back to Tune. The `TuneReportCheckpointCallback` also saves\n", "checkpoints after each evaluation round. We will use the latter in this\n", "example so that we can retrieve the saved model later.\n", "\n", "The metrics from the `eval_metric` configuration setting are then automatically\n", "reported to Tune via the callback. Here, the raw error will be reported, not the accuracy.\n", "To display the best reached accuracy, we will convert it back later (`accuracy = 1 - error`).\n", "\n", "We will also load the best checkpointed model so that we can use it for predictions.\n", "The best model is selected with respect to the `metric` and `mode` parameters we\n", "pass to the `TuneConfig()`." ] }, { "cell_type": "code", "execution_count": 6, "id": "d08b5b0a", "metadata": { "tags": [ "hide-output" ] }, "outputs": [ { "data": { "text/html": [
"<pre>\n", "== Tune Status ==\n", "Current time: 2025-02-11 16:13:35 | Running for: 00:00:01.05 | Memory: 22.5/36.0 GiB\n", "Using AsyncHyperBand: num_stopped=1\n", "Bracket: Iter 8.000: -0.6414526407118444 | Iter 4.000: -0.6439705872452343 | Iter 2.000: -0.6452721030145259 | Iter 1.000: -0.6459394399519567\n", "Logical resource usage: 1.0/12 CPUs, 0/0 GPUs\n", "\n", "== Trial Status: 1/1 TERMINATED ==\n", "train_breast_cancer_32eb5_00000  TERMINATED  loc=127.0.0.1:89763  eta=0.000830475  max_depth=5  min_child_weight=1  subsample=0.675899  iter=10  total time (s)=0.0169384  eval-logloss=0.640195  eval-error=0.342657\n", "</pre>
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "2025-02-11 16:13:35,717\tINFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/rdecal/ray_results/train_breast_cancer_2025-02-11_16-13-34' in 0.0018s.\n", "2025-02-11 16:13:35,719\tINFO tune.py:1041 -- Total run time: 1.05 seconds (1.04 seconds for the tuning loop).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Best model parameters: {'objective': 'binary:logistic', 'eval_metric': ['logloss', 'error'], 'max_depth': 5, 'min_child_weight': 1, 'subsample': 0.675899175238225, 'eta': 0.0008304750981897656}\n", "Best model total accuracy: 0.6573\n" ] } ], "source": [ "import sklearn.datasets\n", "import sklearn.metrics\n", "from ray.tune.schedulers import ASHAScheduler\n", "from sklearn.model_selection import train_test_split\n", "import xgboost as xgb\n", "\n", "from ray import tune\n", "from ray.tune.integration.xgboost import TuneReportCheckpointCallback\n", "\n", "\n", "def train_breast_cancer(config: dict):\n", " # This is a simple training function to be passed into Tune\n", " # Load dataset\n", " data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)\n", " # Split into train and test set\n", " train_x, test_x, train_y, test_y = train_test_split(data, labels, test_size=0.25)\n", " # Build input matrices for XGBoost\n", " train_set = xgb.DMatrix(train_x, label=train_y)\n", " test_set = xgb.DMatrix(test_x, label=test_y)\n", " # Train the classifier, using the Tune callback\n", " xgb.train(\n", " config,\n", " train_set,\n", " evals=[(test_set, \"eval\")],\n", " verbose_eval=False,\n", " # `TuneReportCheckpointCallback` defines the checkpointing frequency and format.\n", " callbacks=[TuneReportCheckpointCallback(frequency=1)],\n", " )\n", "\n", "\n", "def get_best_model_checkpoint(results):\n", " best_result = results.get_best_result()\n", "\n", " # `TuneReportCheckpointCallback` provides a helper method to retrieve the\n", " # model from a checkpoint.\n", " best_bst = TuneReportCheckpointCallback.get_model(best_result.checkpoint)\n", "\n", " accuracy = 1.0 - best_result.metrics[\"eval-error\"]\n", " print(f\"Best model parameters: {best_result.config}\")\n", " print(f\"Best model total accuracy: {accuracy:.4f}\")\n", " return best_bst\n", "\n", "\n", "def tune_xgboost(smoke_test=False):\n", " search_space = {\n", " # You can mix constants with search space objects.\n", " \"objective\": \"binary:logistic\",\n", " \"eval_metric\": [\"logloss\", \"error\"],\n", " \"max_depth\": tune.randint(1, 9),\n", " \"min_child_weight\": tune.choice([1, 2, 3]),\n", " \"subsample\": tune.uniform(0.5, 1.0),\n", " \"eta\": tune.loguniform(1e-4, 1e-1),\n", " }\n", " # This will enable aggressive early stopping of bad trials.\n", " scheduler = ASHAScheduler(\n", " max_t=10, grace_period=1, reduction_factor=2 # 10 training iterations\n", " )\n", "\n", " tuner = tune.Tuner(\n", " train_breast_cancer,\n", " tune_config=tune.TuneConfig(\n", " metric=\"eval-logloss\",\n", " mode=\"min\",\n", " scheduler=scheduler,\n", " num_samples=1 if smoke_test else 10,\n", " ),\n", " param_space=search_space,\n", " )\n", " results = tuner.fit()\n", " return results\n", "\n", "\n", "results = tune_xgboost(smoke_test=SMOKE_TEST)\n", "\n", "# Load the best model checkpoint.\n", "best_bst = get_best_model_checkpoint(results)\n", "\n", "# You could now do further predictions with\n", "# best_bst.predict(...)" ] }, { 
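"attachments": {}, "cell_type": "markdown", "id": "a1f3c9d2", "metadata": {}, "source": [ "As a quick sketch of that last step: `best_bst` is a regular `xgboost.Booster`, so you could build a `DMatrix` and call `predict()` on it. The data loading below simply reuses the dataset from earlier; in practice you would predict on new, held-out data:\n", "\n", "```python\n", "import sklearn.datasets\n", "import xgboost as xgb\n", "\n", "data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)\n", "dmatrix = xgb.DMatrix(data, label=labels)\n", "\n", "# binary:logistic outputs probabilities; threshold at 0.5 to get class labels.\n", "pred_proba = best_bst.predict(dmatrix)\n", "pred_labels = (pred_proba > 0.5).astype(int)\n", "```\n" ] }, {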
"attachments": {}, "cell_type": "markdown", "id": "20732fe4", "metadata": {}, "source": [ "The output of our run could look like this:\n", "\n", "```{code-block} bash\n", ":emphasize-lines: 7\n", "\n", " Number of trials: 10/10 (10 TERMINATED)\n", " +---------------------------------+------------+-------+-------------+-------------+--------------------+-------------+--------+------------------+----------------+--------------+\n", " | Trial name | status | loc | eta | max_depth | min_child_weight | subsample | iter | total time (s) | eval-logloss | eval-error |\n", " |---------------------------------+------------+-------+-------------+-------------+--------------------+-------------+--------+------------------+----------------+--------------|\n", " | train_breast_cancer_ba275_00000 | TERMINATED | | 0.00205087 | 2 | 1 | 0.898391 | 10 | 0.380619 | 0.678039 | 0.090909 |\n", " | train_breast_cancer_ba275_00001 | TERMINATED | | 0.000183834 | 4 | 3 | 0.924939 | 1 | 0.0228798 | 0.693009 | 0.111888 |\n", " | train_breast_cancer_ba275_00002 | TERMINATED | | 0.0242721 | 7 | 2 | 0.501551 | 10 | 0.376154 | 0.54472 | 0.06993 |\n", " | train_breast_cancer_ba275_00003 | TERMINATED | | 0.000449692 | 5 | 3 | 0.890212 | 1 | 0.0234981 | 0.692811 | 0.090909 |\n", " | train_breast_cancer_ba275_00004 | TERMINATED | | 0.000376393 | 7 | 2 | 0.883609 | 1 | 0.0231569 | 0.692847 | 0.062937 |\n", " | train_breast_cancer_ba275_00005 | TERMINATED | | 0.00231942 | 3 | 3 | 0.877464 | 2 | 0.104867 | 0.689541 | 0.083916 |\n", " | train_breast_cancer_ba275_00006 | TERMINATED | | 0.000542326 | 1 | 2 | 0.578584 | 1 | 0.0213971 | 0.692765 | 0.083916 |\n", " | train_breast_cancer_ba275_00007 | TERMINATED | | 0.0016801 | 1 | 2 | 0.975302 | 1 | 0.02226 | 0.691999 | 0.083916 |\n", " | train_breast_cancer_ba275_00008 | TERMINATED | | 0.000595756 | 8 | 3 | 0.58429 | 1 | 0.0221152 | 0.692657 | 0.06993 |\n", " | train_breast_cancer_ba275_00009 | TERMINATED | | 0.000357845 | 8 | 1 | 0.637776 | 1 | 0.022635 | 0.692859 | 0.090909 |\n", " +---------------------------------+------------+-------+-------------+-------------+--------------------+-------------+--------+------------------+----------------+--------------+\n", "\n", "\n", " Best model parameters: {'objective': 'binary:logistic', 'eval_metric': ['logloss', 'error'], 'max_depth': 7, 'min_child_weight': 2, 'subsample': 0.5015513240240503, 'eta': 0.024272050872920895}\n", " Best model total accuracy: 0.9301\n", "```\n", "\n", "As you can see, most trials have been stopped only after a few iterations. Only the\n", "two most promising trials were run for the full 10 iterations.\n", "\n", "You can also ensure that all available resources are being used as the scheduler\n", "terminates trials, freeing them up. This can be done through the\n", "`ResourceChangingScheduler`. An example of this can be found here:\n", "{doc}`/tune/examples/includes/xgboost_dynamic_resources_example`.\n", "\n", "## Using fractional GPUs\n", "\n", "You can often accelerate your training by using GPUs in addition to CPUs. However,\n", "you usually don't have as many GPUs as you have trials to run. For instance, if you\n", "run 10 Tune trials in parallel, you usually don't have access to 10 separate GPUs.\n", "\n", "Tune supports *fractional GPUs*. This means that each task is assigned a fraction\n", "of the GPU memory for training. 
For 10 tasks, this could look like this:" ] }, { "cell_type": "code", "execution_count": null, "id": "7d1b20a3", "metadata": { "tags": [ "hide-output" ] }, "outputs": [], "source": [ "config = {\n", "    \"objective\": \"binary:logistic\",\n", "    \"eval_metric\": [\"logloss\", \"error\"],\n", "    \"tree_method\": \"gpu_hist\",\n", "    \"max_depth\": tune.randint(1, 9),\n", "    \"min_child_weight\": tune.choice([1, 2, 3]),\n", "    \"subsample\": tune.uniform(0.5, 1.0),\n", "    \"eta\": tune.loguniform(1e-4, 1e-1),\n", "}\n", "\n", "tuner = tune.Tuner(\n", "    tune.with_resources(train_breast_cancer, resources={\"cpu\": 1, \"gpu\": 0.1}),\n", "    tune_config=tune.TuneConfig(num_samples=1 if SMOKE_TEST else 10),\n", "    param_space=config,\n", ")\n", "results = tuner.fit()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ee131861", "metadata": {}, "source": [ "Each task thus works with 10% of the available GPU memory. You also have to tell\n", "XGBoost to use the `gpu_hist` tree method, so it knows it should use the GPU.\n", "\n", "## Conclusion\n", "\n", "You should now have a basic understanding of how to train XGBoost models and how\n", "to tune their hyperparameters to yield the best results. In our simple example,\n", "tuning the parameters didn't make a huge difference to the accuracy.\n", "But in larger applications, intelligent hyperparameter tuning can make the\n", "difference between a model that doesn't seem to learn at all, and a model\n", "that outperforms all the others.\n", "\n", "## More XGBoost Examples\n", "\n", "- {doc}`/tune/examples/includes/xgboost_dynamic_resources_example`:\n", "  Trains a basic XGBoost model with Tune with the class-based API and a ResourceChangingScheduler, ensuring all resources are being used at all times.\n", "- {doc}`/train/examples/xgboost/distributed-xgboost-lightgbm`: Shows how to scale XGBoost single-model training with *Ray Train* (as opposed to hyperparameter tuning with Ray Tune).\n", "\n", "## Learn More\n", "\n", "- [XGBoost Hyperparameter Tuning - A Visual Guide](https://kevinvecmanis.io/machine%20learning/hyperparameter%20tuning/dataviz/python/2019/05/11/XGBoost-Tuning-Visual-Guide.html)\n", "- [Notes on XGBoost Parameter Tuning](https://xgboost.readthedocs.io/en/latest/tutorials/param_tuning.html)\n", "- [Doing XGBoost Hyperparameter Tuning the smart way](https://towardsdatascience.com/doing-xgboost-hyper-parameter-tuning-the-smart-way-part-1-of-2-f6d255a45dde)" ] } ], "metadata": { "kernelspec": { "display_name": "xgboost-tune", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" }, "orphan": true }, "nbformat": 4, "nbformat_minor": 5 }