{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "47de02e1",
"metadata": {},
"source": [
"# Running Tune experiments with AxSearch\n",
"In this tutorial we introduce Ax, while running a simple Ray Tune experiment. Tune’s Search Algorithms integrate with Ax and, as a result, allow you to seamlessly scale up a Ax optimization process - without sacrificing performance.\n",
"\n",
"Ax is a platform for optimizing any kind of experiment, including machine learning experiments, A/B tests, and simulations. Ax can optimize discrete configurations (e.g., variants of an A/B test) using multi-armed bandit optimization, and continuous/ordered configurations (e.g. float/int parameters) using Bayesian optimization. Results of A/B tests and simulations with reinforcement learning agents often exhibit high amounts of noise. Ax supports state-of-the-art algorithms which work better than traditional Bayesian optimization in high-noise settings. Ax also supports multi-objective and constrained optimization which are common to real-world problems (e.g. improving load time without increasing data use). Ax belongs to the domain of \"derivative-free\" and \"black-box\" optimization.\n",
"\n",
"In this example we minimize a simple objective to briefly demonstrate the usage of `AxSearch` with Ray Tune. It's useful to keep in mind that despite the emphasis on machine learning experiments, Ray Tune optimizes any implicit or explicit objective. Here we assume the `ax-platform==0.2.4` library is installed with Python version >= 3.7. To learn more, please refer to the [Ax website](https://ax.dev/)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "297d8b18",
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: ax-platform==0.2.4 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (0.2.4)\n",
"Requirement already satisfied: botorch==0.6.2 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from ax-platform==0.2.4) (0.6.2)\n",
"Requirement already satisfied: jinja2 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from ax-platform==0.2.4) (3.0.3)\n",
"Requirement already satisfied: pandas in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from ax-platform==0.2.4) (1.3.5)\n",
"Requirement already satisfied: scipy in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from ax-platform==0.2.4) (1.4.1)\n",
"Requirement already satisfied: plotly in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from ax-platform==0.2.4) (5.6.0)\n",
"Requirement already satisfied: scikit-learn in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from ax-platform==0.2.4) (0.24.2)\n",
"Requirement already satisfied: typeguard in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from ax-platform==0.2.4) (2.13.3)\n",
"Requirement already satisfied: gpytorch>=1.6 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from botorch==0.6.2->ax-platform==0.2.4) (1.6.0)\n",
"Requirement already satisfied: torch>=1.9 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from botorch==0.6.2->ax-platform==0.2.4) (1.9.0)\n",
"Requirement already satisfied: multipledispatch in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from botorch==0.6.2->ax-platform==0.2.4) (0.6.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from jinja2->ax-platform==0.2.4) (2.0.1)\n",
"Requirement already satisfied: pytz>=2017.3 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from pandas->ax-platform==0.2.4) (2022.1)\n",
"Requirement already satisfied: numpy>=1.17.3 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from pandas->ax-platform==0.2.4) (1.21.6)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from pandas->ax-platform==0.2.4) (2.8.2)\n",
"Requirement already satisfied: tenacity>=6.2.0 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from plotly->ax-platform==0.2.4) (8.0.1)\n",
"Requirement already satisfied: six in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from plotly->ax-platform==0.2.4) (1.16.0)\n",
"Requirement already satisfied: joblib>=0.11 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from scikit-learn->ax-platform==0.2.4) (1.1.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from scikit-learn->ax-platform==0.2.4) (3.0.0)\n",
"Requirement already satisfied: typing-extensions in /Users/kai/.pyenv/versions/3.7.7/lib/python3.7/site-packages (from torch>=1.9->botorch==0.6.2->ax-platform==0.2.4) (4.1.1)\n",
"\u001b[33mWARNING: There was an error checking the latest version of pip.\u001b[0m\u001b[33m\n",
"\u001b[0m"
]
}
],
"source": [
"# !pip install ray[tune]\n",
"!pip install ax-platform==0.2.4"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "59b1e0d1",
"metadata": {},
"source": [
"Click below to see all the imports we need for this example.\n",
"You can also launch directly into a Binder instance to run this notebook yourself.\n",
"Just click on the rocket symbol at the top of the navigation."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cbae6dbe",
"metadata": {
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"import numpy as np\n",
"import time\n",
"\n",
"import ray\n",
"from ray import train, tune\n",
"from ray.tune.search.ax import AxSearch"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7b2b6af7",
"metadata": {},
"source": [
"Let's start by defining a classic benchmark for global optimization.\n",
"The form here is explicit for demonstration, although in practice such an objective is typically a black box.\n",
"We artificially sleep for a bit (`0.02` seconds) to simulate a long-running ML experiment.\n",
"This setup assumes that we're running multiple `step`s of an experiment while tuning the six dimensions of the `x` hyperparameter."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0f7fbe0f",
"metadata": {},
"outputs": [],
"source": [
"def landscape(x):\n",
" \"\"\"\n",
" Hartmann 6D function containing 6 local minima.\n",
" It is a classic benchmark for developing global optimization algorithms.\n",
" \"\"\"\n",
" alpha = np.array([1.0, 1.2, 3.0, 3.2])\n",
" A = np.array(\n",
" [\n",
" [10, 3, 17, 3.5, 1.7, 8],\n",
" [0.05, 10, 17, 0.1, 8, 14],\n",
" [3, 3.5, 1.7, 10, 17, 8],\n",
" [17, 8, 0.05, 10, 0.1, 14],\n",
" ]\n",
" )\n",
" P = 10 ** (-4) * np.array(\n",
" [\n",
" [1312, 1696, 5569, 124, 8283, 5886],\n",
" [2329, 4135, 8307, 3736, 1004, 9991],\n",
" [2348, 1451, 3522, 2883, 3047, 6650],\n",
" [4047, 8828, 8732, 5743, 1091, 381],\n",
" ]\n",
" )\n",
" y = 0.0\n",
" for j, alpha_j in enumerate(alpha):\n",
" t = 0\n",
" for k in range(6):\n",
" t += A[j, k] * ((x[k] - P[j, k]) ** 2)\n",
" y -= alpha_j * np.exp(-t)\n",
" return y"
]
},
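{
"attachments": {},
"cell_type": "markdown",
"id": "hartmann-sanity-check",
"metadata": {},
"source": [
"As a quick sanity check (not part of the original tutorial), we can evaluate this benchmark at the widely documented global minimizer of the Hartmann 6D function, x* ≈ (0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573), where the minimum value is approximately `-3.32237`. The sketch below re-implements `landscape` in an equivalent vectorized form:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def landscape(x):\n",
"    # Same Hartmann 6D definition as above, vectorized over the 4 outer terms.\n",
"    alpha = np.array([1.0, 1.2, 3.0, 3.2])\n",
"    A = np.array([\n",
"        [10, 3, 17, 3.5, 1.7, 8],\n",
"        [0.05, 10, 17, 0.1, 8, 14],\n",
"        [3, 3.5, 1.7, 10, 17, 8],\n",
"        [17, 8, 0.05, 10, 0.1, 14],\n",
"    ])\n",
"    P = 1e-4 * np.array([\n",
"        [1312, 1696, 5569, 124, 8283, 5886],\n",
"        [2329, 4135, 8307, 3736, 1004, 9991],\n",
"        [2348, 1451, 3522, 2883, 3047, 6650],\n",
"        [4047, 8828, 8732, 5743, 1091, 381],\n",
"    ])\n",
"    # Broadcasting: (x - P) has shape (4, 6); summing over axis 1 gives\n",
"    # the four inner sums that the explicit double loop computes.\n",
"    inner = np.sum(A * (x - P) ** 2, axis=1)\n",
"    return -np.sum(alpha * np.exp(-inner))\n",
"\n",
"# Known global minimizer of the Hartmann 6D function.\n",
"x_star = np.array([0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573])\n",
"print(landscape(x_star))  # approximately -3.32237\n",
"```"
]
},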
{
"attachments": {},
"cell_type": "markdown",
"id": "0b1ae9df",
"metadata": {},
"source": [
"Next, our `objective` function takes a Tune `config`, evaluates the `landscape` of our experiment in a training loop,\n",
"and uses `train.report` to report the `landscape` value (along with an `l2norm` metric) back to Tune."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8c3f252e",
"metadata": {},
"outputs": [],
"source": [
"def objective(config):\n",
"    # Collect x1..x6 from the config into the 6-dimensional input vector\n",
"    # (it is constant across iterations, so compute it once).\n",
"    x = np.array([config.get(f\"x{j + 1}\") for j in range(6)])\n",
"    for i in range(config[\"iterations\"]):\n",
"        train.report(\n",
"            {\"timesteps_total\": i, \"landscape\": landscape(x), \"l2norm\": np.sqrt((x ** 2).sum())}\n",
"        )\n",
"        time.sleep(0.02)"
]
},
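{
"attachments": {},
"cell_type": "markdown",
"id": "config-vector-demo",
"metadata": {},
"source": [
"To make the mapping from a Tune `config` to the input vector explicit, here is a standalone sketch using a hypothetical config (the values below are made up for illustration, not produced by Tune's search):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# A hypothetical config, shaped like what Tune passes to `objective`.\n",
"config = {\"iterations\": 100, \"x1\": 0.1, \"x2\": 0.2, \"x3\": 0.3,\n",
"          \"x4\": 0.4, \"x5\": 0.5, \"x6\": 0.6}\n",
"\n",
"# Collect x1..x6 into the 6-dimensional input vector, as `objective` does.\n",
"x = np.array([config[f\"x{i + 1}\"] for i in range(6)])\n",
"l2norm = np.sqrt((x ** 2).sum())\n",
"print(x, l2norm)\n",
"```"
]
},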
{
"attachments": {},
"cell_type": "markdown",
"id": "d9982d95",
"metadata": {},
"source": [
"Next we define a search space. The critical assumption is that the optimal hyperparameters live within this space. Yet, if the space is very large, those hyperparameters may be difficult to find in a short amount of time."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "30f75f5a",
"metadata": {},
"outputs": [],
"source": [
"search_space = {\n",
" \"iterations\": 100,\n",
" \"x1\": tune.uniform(0.0, 1.0),\n",
" \"x2\": tune.uniform(0.0, 1.0),\n",
" \"x3\": tune.uniform(0.0, 1.0),\n",
" \"x4\": tune.uniform(0.0, 1.0),\n",
" \"x5\": tune.uniform(0.0, 1.0),\n",
" \"x6\": tune.uniform(0.0, 1.0)\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "106d8578",
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"ray.init(configure_logging=False)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "932f74e6",
"metadata": {},
"source": [
"Now we define the search algorithm built from `AxSearch`. If you want to constrain your parameters or even the space of outcomes, that can easily be done by passing the arguments as shown below."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "34dd5c95",
"metadata": {},
"outputs": [],
"source": [
"algo = AxSearch(\n",
" parameter_constraints=[\"x1 + x2 <= 2.0\"],\n",
" outcome_constraints=[\"l2norm <= 1.25\"],\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "f6d18a99",
"metadata": {},
"source": [
"We also use `ConcurrencyLimiter` to limit the number of concurrent trials to 4."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dcd905ef",
"metadata": {},
"outputs": [],
"source": [
"algo = tune.search.ConcurrencyLimiter(algo, max_concurrent=4)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "10fd5427",
"metadata": {},
"source": [
"The number of samples is the number of hyperparameter combinations that will be tried out. This Tune run is set to `100` samples.\n",
"You can decrease this if it takes too long on your machine, or you can set a stopping criterion via the `stop` argument of `train.RunConfig()`, as we will show here."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c53349a5",
"metadata": {},
"outputs": [],
"source": [
"num_samples = 100\n",
"stop_timesteps = 200"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "6c661045",
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [],
"source": [
"# Reducing samples for smoke tests\n",
"num_samples = 10"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "91076c5a",
"metadata": {},
"source": [
"Finally, we run the experiment to find the global minimum of the provided landscape (which contains 5 false minima). The `metric` argument, `\"landscape\"`, corresponds to the key reported by the `objective` function via `train.report`. The experiment minimizes (`\"min\"`) the `landscape` by searching within `search_space` via `algo`, for `num_samples` trials or until `\"timesteps_total\"` reaches `stop_timesteps`. The previous sentence fully characterizes the search problem we aim to solve. With this in mind, notice how efficient it is to execute `tuner.fit()`."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2f519d63",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[INFO 07-22 15:04:18] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.\n",
"[INFO 07-22 15:04:18] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[FixedParameter(name='iterations', parameter_type=INT, value=100), RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x4', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x5', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x6', parameter_type=FLOAT, range=[0.0, 1.0])], parameter_constraints=[ParameterConstraint(1.0*x1 + 1.0*x2 <= 2.0)]).\n",
"[INFO 07-22 15:04:18] ax.modelbridge.dispatch_utils: Using Bayesian optimization since there are more ordered parameters than there are categories for the unordered categorical parameters.\n",
"[INFO 07-22 15:04:18] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 12 trials, GPEI for subsequent trials]). Iterations after 12 will take longer to generate due to model-fitting.\n",
"Detected sequential enforcement. Be sure to use a ConcurrencyLimiter.\n"
]
},
{
"data": {
"text/html": [
"== Status ==\n",
"Current time: 2022-07-22 15:04:35 (running for 00:00:16.56)\n",
"Memory usage on this node: 9.9/16.0 GiB\n",
"Using FIFO scheduling algorithm.\n",
"Resources requested: 0/16 CPUs, 0/0 GPUs, 0.0/5.13 GiB heap, 0.0/2.0 GiB objects\n",
"Current best trial: 34b7abda with landscape=-1.6624439263544026 and parameters={'iterations': 100, 'x1': 0.26526361983269453, 'x2': 0.9248840995132923, 'x3': 0.15171580761671066, 'x4': 0.43602637108415365, 'x5': 0.8573104059323668, 'x6': 0.08981018699705601}\n",
"Result logdir: /Users/kai/ray_results/ax\n",
"Number of trials: 10/10 (10 TERMINATED)\n",
"Trial name | status | loc | iterations | x1 | x2 | x3 | x4 | x5 | x6 | iter | total time (s) | ts | landscape | l2norm |\n",
"---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"objective_2dfbe86a | TERMINATED | 127.0.0.1:44721 | 100 | 0.0558336 | 0.0896192 | 0.958956 | 0.234474 | 0.174516 | 0.970311 | 100 | 2.57372 | 99 | -0.805233 | 1.39917 |\n",
"objective_2fa776c0 | TERMINATED | 127.0.0.1:44726 | 100 | 0.744772 | 0.754537 | 0.0950125 | 0.273877 | 0.0966829 | 0.368943 | 100 | 2.6361 | 99 | -0.11286 | 1.16341 |\n",
"objective_2fabaa1a | TERMINATED | 127.0.0.1:44727 | 100 | 0.405704 | 0.374626 | 0.935628 | 0.222185 | 0.787212 | 0.00812439 | 100 | 2.62393 | 99 | -0.11348 | 1.35995 |\n",
"objective_2faee7c0 | TERMINATED | 127.0.0.1:44728 | 100 | 0.664728 | 0.207519 | 0.359514 | 0.704578 | 0.755882 | 0.812402 | 100 | 2.62069 | 99 | -0.0119837 | 1.53035 |\n",
"objective_313d3d3a | TERMINATED | 127.0.0.1:44747 | 100 | 0.0418746 | 0.992783 | 0.906027 | 0.594429 | 0.825393 | 0.646362 | 100 | 3.16233 | 99 | -0.00677976 | 1.80573 |\n",
"objective_32c9acd8 | TERMINATED | 127.0.0.1:44726 | 100 | 0.126064 | 0.703408 | 0.344681 | 0.337363 | 0.401396 | 0.679202 | 100 | 3.12119 | 99 | -0.904622 | 1.16864 |\n",
"objective_32cf8ca2 | TERMINATED | 127.0.0.1:44756 | 100 | 0.0910936 | 0.304138 | 0.869848 | 0.405435 | 0.567922 | 0.228608 | 100 | 2.70791 | 99 | -0.146532 | 1.18178 |\n",
"objective_32d8dd20 | TERMINATED | 127.0.0.1:44758 | 100 | 0.603178 | 0.409057 | 0.729056 | 0.0825984 | 0.572948 | 0.508304 | 100 | 2.64158 | 99 | -0.247223 | 1.28691 |\n",
"objective_34adf04a | TERMINATED | 127.0.0.1:44768 | 100 | 0.454189 | 0.271772 | 0.530871 | 0.991841 | 0.691843 | 0.472366 | 100 | 2.70327 | 99 | -0.0132915 | 1.49917 |\n",
"objective_34b7abda | TERMINATED | 127.0.0.1:44771 | 100 | 0.265264 | 0.924884 | 0.151716 | 0.436026 | 0.85731 | 0.0898102 | 100 | 2.68521 | 99 | -1.66244 | 1.37185 |"