{ "cells": [ { "cell_type": "markdown", "id": "596e2463", "metadata": {}, "source": [ "# Simple AutoML for time series with Ray AIR" ] }, { "cell_type": "markdown", "id": "676f0ada", "metadata": {}, "source": [ "AutoML (Automatic Machine Learning) boils down to picking the best model for a given task and dataset. In {doc}`this Ray Core example `, we showed how to build an AutoML system which will chooses the best `statsforecast` model and its corresponding hyperparameters for a time series regression task on the [M5 dataset](https://www.kaggle.com/c/m5-forecasting-accuracy).\n", "\n", "The basic steps were:\n", "\n", "1. Define a set of autoregressive forecasting models to search over. For each model type, we also define a set of model parameters to search over.\n", "2. Perform temporal cross-validation on each model configuration in parallel.\n", "3. Pick the best performing model as the output of the AutoML system.\n", "\n", "We see that these steps fit into the framework of a hyperparameter optimization problem that can be tackled with the [Ray AIR Tuner](air-tuner)!\n", "\n", "In this notebook, we will show how to:\n", "1. **Create an AutoML system with Ray AIR** for time series forecasting.\n", "2. Leverage the higher-level Tuner API to **define the model and hyperparameter search space**, as well as **parallelize cross-validation** of different models.\n", "3. Analyze results to **identify the best-performing model type and model parameters** for the time-series dataset.\n", "\n", "Similar to {doc}`the Ray Core example `, we will be using only one partition of the [M5 dataset](https://www.kaggle.com/c/m5-forecasting-accuracy) for this example." ] }, { "cell_type": "markdown", "id": "9e01bef8", "metadata": {}, "source": [ "## Setup\n", "\n", "Let's first start by installing the `statsforecast` and `ray[air]` packages." ] }, { "cell_type": "code", "execution_count": null, "id": "5e6da0ab", "metadata": {}, "outputs": [], "source": [ "!pip install statsforecast\n", "!pip install ray[air]" ] }, { "cell_type": "markdown", "id": "aff15de3", "metadata": {}, "source": [ "Next, we'll make the necessary imports, then initialize and connect to our Ray cluster!" ] }, { "cell_type": "code", "execution_count": 2, "id": "555f90df", "metadata": {}, "outputs": [], "source": [ "import time\n", "import itertools\n", "import pandas as pd\n", "import numpy as np\n", "from collections import defaultdict\n", "from statsforecast import StatsForecast\n", "from statsforecast.models import ETS, AutoARIMA, _TS\n", "from pyarrow import parquet as pq\n", "from sklearn.metrics import mean_squared_error, mean_absolute_error\n", "\n", "import ray\n", "from ray import air, tune" ] }, { "cell_type": "code", "execution_count": null, "id": "83e21bfe", "metadata": {}, "outputs": [], "source": [ "if ray.is_initialized():\n", " ray.shutdown()\n", "ray.init(runtime_env={\"pip\": [\"statsforecast\"]})" ] }, { "cell_type": "markdown", "id": "d38cba4c-3984-4c2a-9a01-aefb610b92ce", "metadata": {}, "source": [ "```{note}\n", "We may want to run on multiple nodes, and setting the `runtime_env` to include the `statsforecast` module will guarantee that we can access it on each worker, regardless of which node it lives on.\n", "```" ] }, { "cell_type": "markdown", "id": "01ee6bbe", "metadata": {}, "source": [ "## Read a partition of the M5 dataset from S3\n", "\n", "We first obtain the data from an S3 bucket and preprocess it to the format that `statsforecast` expects. 
As the dataset is quite large, we use PyArrow’s push-down predicate as a filter, so that we obtain just the rows we care about without having to load the full dataset into memory (a sketch of such a filtered read follows the output below)." ] }, { "cell_type": "code", "execution_count": 4, "id": "4324f7f8", "metadata": {}, "outputs": [ { "data": { "text/html": [
\n", " | unique_id | \n", "ds | \n", "y | \n", "
---|---|---|---|
0 | \n", "FOODS_1_001_CA_1 | \n", "2011-01-29 | \n", "13.0 | \n", "
1 | \n", "FOODS_1_001_CA_1 | \n", "2011-01-30 | \n", "10.0 | \n", "
2 | \n", "FOODS_1_001_CA_1 | \n", "2011-01-31 | \n", "10.0 | \n", "
3 | \n", "FOODS_1_001_CA_1 | \n", "2011-02-01 | \n", "11.0 | \n", "
4 | \n", "FOODS_1_001_CA_1 | \n", "2011-02-02 | \n", "14.0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
1936 | \n", "FOODS_1_001_CA_1 | \n", "2016-05-18 | \n", "10.0 | \n", "
1937 | \n", "FOODS_1_001_CA_1 | \n", "2016-05-19 | \n", "11.0 | \n", "
1938 | \n", "FOODS_1_001_CA_1 | \n", "2016-05-20 | \n", "10.0 | \n", "
1939 | \n", "FOODS_1_001_CA_1 | \n", "2016-05-21 | \n", "10.0 | \n", "
1940 | \n", "FOODS_1_001_CA_1 | \n", "2016-05-22 | \n", "10.0 | \n", "
1941 rows × 3 columns
\n", "