{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "1dd265b7", "metadata": {}, "source": [ "(tune-analysis-guide)=\n", "\n", "# Analyzing Tune Experiment Results\n", "\n", "In this guide, we'll walk through some common analysis workflows that you might want to perform after running your Tune experiment with `tuner.fit()`:\n", "\n", "1. Loading Tune experiment results from a directory\n", "2. Basic *experiment-level* analysis: get a quick overview of how trials performed\n", "3. Basic *trial-level* analysis: access individual trial hyperparameter configs and last reported metrics\n", "4. Plotting the entire history of reported metrics for a trial\n", "5. Accessing saved checkpoints (assuming that you have enabled checkpointing) and loading them into a model for test inference\n", "\n", "```python\n", "result_grid: ResultGrid = tuner.fit()\n", "best_result: Result = result_grid.get_best_result()\n", "```\n", "\n", "The output of `tuner.fit()` is a [`ResultGrid`](result-grid-docstring), which is a collection of [`Result`](result-docstring) objects. See the linked documentation references for [`ResultGrid`](result-grid-docstring) and [`Result`](result-docstring) for more details on what attributes are available.\n", "\n", "Let's start by performing a hyperparameter search with the MNIST PyTorch example. The training function is defined {doc}`here `, and we pass it into a `Tuner` to start running the trials in parallel." ] }, { "cell_type": "code", "execution_count": 1, "id": "8479d7d2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "
\n", "

Tune Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "
Current time: 2023-08-25 17:42:39
Running for: 00:00:12.43
Memory: 27.0/64.0 GiB
\n", "
\n", "
\n", "
\n", "

System Info

\n", " Using FIFO scheduling algorithm.
Logical resource usage: 1.0/10 CPUs, 0/0 GPUs\n", "
\n", " \n", "
\n", "
\n", "
\n", "

Trial Status

\n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
Trial name               status      loc              lr          momentum  acc       iter  total time (s)
train_mnist_6e465_00000  TERMINATED  127.0.0.1:94903  0.0188636   0.8       0.925     100   8.81282
train_mnist_6e465_00001  TERMINATED  127.0.0.1:94904  0.0104137   0.9       0.9625    100   8.6819
train_mnist_6e465_00002  TERMINATED  127.0.0.1:94905  0.00102317  0.99      0.953125  100   8.67491
train_mnist_6e465_00003  TERMINATED  127.0.0.1:94906  0.0103929   0.8       0.94375   100   8.92996
train_mnist_6e465_00004  TERMINATED  127.0.0.1:94907  0.00808686  0.9       0.95625   100   8.75311
train_mnist_6e465_00005  TERMINATED  127.0.0.1:94908  0.00172525  0.99      0.95625   100   8.76523
train_mnist_6e465_00006  TERMINATED  127.0.0.1:94909  0.0507692   0.8       0.946875  100   8.94565
train_mnist_6e465_00007  TERMINATED  127.0.0.1:94910  0.00978134  0.9       0.965625  100   8.77776
train_mnist_6e465_00008  TERMINATED  127.0.0.1:94911  0.00368709  0.99      0.934375  100   8.8495
\n", "
\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "2023-08-25 17:42:27,603\tWARNING experiment_state.py:371 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94906)\u001b[0m StorageContext on SESSION (rank=None):\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94906)\u001b[0m StorageContext<\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94906)\u001b[0m storage_path=/tmp/ray_results\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94906)\u001b[0m storage_local_path=/Users/justin/ray_results\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94906)\u001b[0m storage_filesystem=\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94906)\u001b[0m storage_fs_path=/tmp/ray_results\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94906)\u001b[0m experiment_dir_name=tune_analyzing_results\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94906)\u001b[0m trial_dir_name=train_mnist_6e465_00003_3_lr=0.0104,momentum=0.8000_2023-08-25_17-42-27\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94906)\u001b[0m current_checkpoint_index=0\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94906)\u001b[0m >\n", "\u001b[2m\u001b[36m(train_mnist pid=94907)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/tmp/ray_results/tune_analyzing_results/train_mnist_6e465_00004_4_lr=0.0081,momentum=0.9000_2023-08-25_17-42-27/checkpoint_000000)\n", "2023-08-25 17:42:30,460\tWARNING experiment_state.py:371 -- Experiment checkpoint syncing has been triggered 
multiple times in the last 30.0 seconds. 
A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.\n", "2023-08-25 17:42:34,768\tWARNING experiment_state.py:371 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94905)\u001b[0m StorageContext on SESSION (rank=None):\u001b[32m [repeated 8x across cluster] (Ray deduplicates logs by default. 
Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)\u001b[0m\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94905)\u001b[0m StorageContext<\u001b[32m [repeated 8x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94905)\u001b[0m storage_path=/tmp/ray_results\u001b[32m [repeated 8x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94905)\u001b[0m storage_local_path=/Users/justin/ray_results\u001b[32m [repeated 8x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94905)\u001b[0m storage_filesystem=\u001b[32m [repeated 8x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94905)\u001b[0m storage_fs_path=/tmp/ray_results\u001b[32m [repeated 8x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94905)\u001b[0m experiment_dir_name=tune_analyzing_results\u001b[32m [repeated 8x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94905)\u001b[0m current_checkpoint_index=0\u001b[32m [repeated 16x across cluster]\u001b[0m\n", "\u001b[2m\u001b[36m(ImplicitFunc pid=94905)\u001b[0m >\u001b[32m [repeated 8x across cluster]\u001b[0m\n", "2023-08-25 17:42:35,127\tWARNING experiment_state.py:371 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. 
You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.\n", "\u001b[2m\u001b[36m(train_mnist pid=94906)\u001b[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/tmp/ray_results/tune_analyzing_results/train_mnist_6e465_00003_3_lr=0.0104,momentum=0.8000_2023-08-25_17-42-27/checkpoint_000050)\u001b[32m [repeated 455x across cluster]\u001b[0m\n", "2023-08-25 17:42:35,508\tWARNING experiment_state.py:371 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.\n", "2023-08-25 17:42:35,899\tWARNING experiment_state.py:371 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.\n", "2023-08-25 17:42:36,277\tWARNING experiment_state.py:371 -- Experiment checkpoint syncing has been triggered multiple times in the last 30.0 seconds. A sync will be triggered whenever a trial has checkpointed more than `num_to_keep` times since last sync or if 300 seconds have passed since last sync. 
If you have set `num_to_keep` in your `CheckpointConfig`, consider increasing the checkpoint frequency or keeping more checkpoints. 
You can supress this warning by changing the `TUNE_WARN_EXCESSIVE_EXPERIMENT_CHECKPOINT_SYNC_THRESHOLD_S` environment variable.\n", "2023-08-25 17:42:39,882\tINFO tune.py:1147 -- Total run time: 12.52 seconds (12.42 seconds for the tuning loop).\n" ] } ], "source": [ "import os\n", "\n", "from ray import train, tune\n", "from ray.tune.examples.mnist_pytorch import train_mnist\n", "from ray.tune import ResultGrid\n", "\n", "storage_path = \"/tmp/ray_results\"\n", "exp_name = \"tune_analyzing_results\"\n", "tuner = tune.Tuner(\n", "    train_mnist,\n", "    param_space={\n", "        \"lr\": tune.loguniform(0.001, 0.1),\n", "        \"momentum\": tune.grid_search([0.8, 0.9, 0.99]),\n", "        \"should_checkpoint\": True,\n", "    },\n", "    run_config=train.RunConfig(\n", "        name=exp_name,\n", "        stop={\"training_iteration\": 100},\n", "        checkpoint_config=train.CheckpointConfig(\n", "            checkpoint_score_attribute=\"mean_accuracy\",\n", "            num_to_keep=5,\n", "        ),\n", "        storage_path=storage_path,\n", "    ),\n", "    tune_config=tune.TuneConfig(mode=\"max\", metric=\"mean_accuracy\", num_samples=3),\n", ")\n", "result_grid: ResultGrid = tuner.fit()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a18a988c", "metadata": {}, "source": [ "## Loading experiment results from a directory\n", "\n", "Although we have the `result_grid` object in memory because we just ran the Tune experiment above, we might be performing this analysis after our initial training script has exited. We can retrieve the `ResultGrid` from a [restored `Tuner`](tune-stopping-guide), passing in the experiment directory, which should look something like `~/ray_results/{exp_name}` (or `{storage_path}/{exp_name}` when a custom `storage_path` is set, as in this example). If you don't specify an experiment `name` in the `RunConfig`, the experiment name will be auto-generated and can be found in the logs of your experiment."
] }, { "cell_type": "code", "execution_count": 2, "id": "92ded070", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading results from /tmp/ray_results/tune_analyzing_results...\n" ] } ], "source": [ "experiment_path = os.path.join(storage_path, exp_name)\n", "print(f\"Loading results from {experiment_path}...\")\n", "\n", "restored_tuner = tune.Tuner.restore(experiment_path, trainable=train_mnist)\n", "result_grid = restored_tuner.get_results()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "ea5085c8", "metadata": {}, "source": [ "## Experiment-level Analysis: Working with `ResultGrid`" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a7182cd1", "metadata": {}, "source": [ "The first thing we might want to check is whether any trials errored." ] }, { "cell_type": "code", "execution_count": 3, "id": "008a8df7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No errors!\n" ] } ], "source": [ "# Check if there have been errors\n", "if result_grid.errors:\n", "    print(\"At least one of the trials failed!\")\n", "else:\n", "    print(\"No errors!\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "c95f6cef", "metadata": {}, "source": [ "Note that `ResultGrid` is an iterable, and we can access its length and index into it to access individual `Result` objects.\n", "\n", "We should have **9** results in this example, since we have 3 samples for each of the 3 grid search values."
] }, { "cell_type": "code", "execution_count": 4, "id": "4ccecf9c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of results: 9\n" ] } ], "source": [ "num_results = len(result_grid)\n", "print(\"Number of results:\", num_results)" ] }, { "cell_type": "code", "execution_count": 5, "id": "5cff1c8d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Trial #0 finished successfully with a mean accuracy metric of: 0.953125\n", "Trial #1 finished successfully with a mean accuracy metric of: 0.9625\n", "Trial #2 finished successfully with a mean accuracy metric of: 0.95625\n", "Trial #3 finished successfully with a mean accuracy metric of: 0.946875\n", "Trial #4 finished successfully with a mean accuracy metric of: 0.925\n", "Trial #5 finished successfully with a mean accuracy metric of: 0.934375\n", "Trial #6 finished successfully with a mean accuracy metric of: 0.965625\n", "Trial #7 finished successfully with a mean accuracy metric of: 0.95625\n", "Trial #8 finished successfully with a mean accuracy metric of: 0.94375\n" ] } ], "source": [ "# Iterate over results\n", "for i, result in enumerate(result_grid):\n", "    if result.error:\n", "        print(f\"Trial #{i} had an error:\", result.error)\n", "        continue\n", "\n", "    print(\n", "        f\"Trial #{i} finished successfully with a mean accuracy metric of:\",\n", "        result.metrics[\"mean_accuracy\"]\n", "    )" ] }, { "attachments": {}, "cell_type": "markdown", "id": "66c7ccc4", "metadata": {}, "source": [ "Above, we printed the **last reported** `mean_accuracy` metric for all trials by looping through the `result_grid`.\n", "We can access the same metrics for all trials in a pandas DataFrame." ] }, { "cell_type": "code", "execution_count": 6, "id": "c3541ea8", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
   training_iteration  mean_accuracy
0                 100       0.953125
1                 100       0.962500
2                 100       0.956250
3                 100       0.946875
4                 100       0.925000
5                 100       0.934375
6                 100       0.965625
7                 100       0.956250
8                 100       0.943750
\n", "
" ], "text/plain": [ "   training_iteration  mean_accuracy\n", "0                 100       0.953125\n", "1                 100       0.962500\n", "2                 100       0.956250\n", "3                 100       0.946875\n", "4                 100       0.925000\n", "5                 100       0.934375\n", "6                 100       0.965625\n", "7                 100       0.956250\n", "8                 100       0.943750" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results_df = result_grid.get_dataframe()\n", "results_df[[\"training_iteration\", \"mean_accuracy\"]]" ] }, { "cell_type": "code", "execution_count": 7, "id": "0117b332", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shortest training time: 8.674914598464966\n", "Longest training time: 8.945653676986694\n" ] } ], "source": [ "print(\"Shortest training time:\", results_df[\"time_total_s\"].min())\n", "print(\"Longest training time:\", results_df[\"time_total_s\"].max())" ] }, { "attachments": {}, "cell_type": "markdown", "id": "184bd3ee", "metadata": {}, "source": [ "The last reported metrics might not contain the best accuracy each trial achieved. If we want the maximum accuracy that each trial reported at any point during training, we can call {meth}`~ray.tune.ResultGrid.get_dataframe` with a metric and mode, which are used to filter each trial's training history." ] }, { "cell_type": "code", "execution_count": 8, "id": "54f2d019", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
   training_iteration  mean_accuracy
0                  50       0.968750
1                  55       0.975000
2                  95       0.975000
3                  71       0.978125
4                  65       0.959375
5                  77       0.965625
6                  82       0.975000
7                  80       0.968750
8                  92       0.975000
\n", "
" ], "text/plain": [ " training_iteration mean_accuracy\n", "0 50 0.968750\n", "1 55 0.975000\n", "2 95 0.975000\n", "3 71 0.978125\n", "4 65 0.959375\n", "5 77 0.965625\n", "6 82 0.975000\n", "7 80 0.968750\n", "8 92 0.975000" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_result_df = result_grid.get_dataframe(\n", " filter_metric=\"mean_accuracy\", filter_mode=\"max\"\n", ")\n", "best_result_df[[\"training_iteration\", \"mean_accuracy\"]]" ] }, { "attachments": {}, "cell_type": "markdown", "id": "a016e288", "metadata": {}, "source": [ "## Trial-level Analysis: Working with an individual `Result`" ] }, { "attachments": {}, "cell_type": "markdown", "id": "59d52e62", "metadata": {}, "source": [ "Let's take a look at the result that ended with the best `mean_accuracy` metric. By default, `get_best_result` will use the same metric and mode as defined in the `TuneConfig` above. However, it's also possible to specify a new metric/order in which results should be ranked." ] }, { "cell_type": "code", "execution_count": 9, "id": "1b59ac25", "metadata": {}, "outputs": [], "source": [ "from ray.train import Result\n", "\n", "# Get the result with the maximum test set `mean_accuracy`\n", "best_result: Result = result_grid.get_best_result()\n", "\n", "# Get the result with the minimum `mean_accuracy`\n", "worst_performing_result: Result = result_grid.get_best_result(\n", " metric=\"mean_accuracy\", mode=\"min\"\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "19d25389", "metadata": {}, "source": [ "We can examine a few of the properties of the best `Result`. See the [API reference](result-docstring) for a list of all accessible properties.\n", "\n", "First, we can access the best result's hyperparameter configuration with `Result.config`." 
] }, { "cell_type": "code", "execution_count": 10, "id": "7ffc3edc", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'lr': 0.009781335971854077, 'momentum': 0.9, 'should_checkpoint': True}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_result.config" ] }, { "attachments": {}, "cell_type": "markdown", "id": "403111f9", "metadata": {}, "source": [ "Next, we can access the trial directory via `Result.path`. The result `path` gives the trial-level directory, which contains checkpoints (if you reported any) and logged metrics (see `result.json`, `progress.csv`) that you can load manually or inspect using a tool like TensorBoard." ] }, { "cell_type": "code", "execution_count": 11, "id": "c90dcc28", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/tmp/ray_results/tune_analyzing_results/train_mnist_6e465_00007_7_lr=0.0098,momentum=0.9000_2023-08-25_17-42-27'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "best_result.path" ] }, { "attachments": {}, "cell_type": "markdown", "id": "44d4080e", "metadata": {}, "source": [ "You can also directly get the latest checkpoint for a specific trial via `Result.checkpoint`." ] }, { "cell_type": "code", "execution_count": 12, "id": "fa4018f1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Checkpoint(filesystem=local, path=/tmp/ray_results/tune_analyzing_results/train_mnist_6e465_00007_7_lr=0.0098,momentum=0.9000_2023-08-25_17-42-27/checkpoint_000099)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the last Checkpoint associated with the best-performing trial\n", "best_result.checkpoint" ] }, { "attachments": {}, "cell_type": "markdown", "id": "79661a56", "metadata": {}, "source": [ "You can also get the last-reported metrics associated with a specific trial via `Result.metrics`."
] }, { "cell_type": "code", "execution_count": 15, "id": "52d4b99c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'mean_accuracy': 0.965625,\n", " 'timestamp': 1693010559,\n", " 'should_checkpoint': True,\n", " 'done': True,\n", " 'training_iteration': 100,\n", " 'trial_id': '6e465_00007',\n", " 'date': '2023-08-25_17-42-39',\n", " 'time_this_iter_s': 0.08028697967529297,\n", " 'time_total_s': 8.77775764465332,\n", " 'pid': 94910,\n", " 'node_ip': '127.0.0.1',\n", " 'config': {'lr': 0.009781335971854077,\n", " 'momentum': 0.9,\n", " 'should_checkpoint': True},\n", " 'time_since_restore': 8.77775764465332,\n", " 'iterations_since_restore': 100,\n", " 'checkpoint_dir_name': 'checkpoint_000099',\n", " 'experiment_tag': '7_lr=0.0098,momentum=0.9000'}" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the last reported set of metrics\n", "best_result.metrics" ] }, { "attachments": {}, "cell_type": "markdown", "id": "00705f44", "metadata": {}, "source": [ "Access the entire history of reported metrics from a `Result` as a pandas DataFrame:" ] }, { "cell_type": "code", "execution_count": 16, "id": "ca87204f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
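<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>training_iteration</th><th>mean_accuracy</th><th>time_total_s</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>1</td><td>0.168750</td><td>0.111393</td></tr>\n",
"    <tr><th>1</th><td>2</td><td>0.609375</td><td>0.195086</td></tr>\n",
"    <tr><th>2</th><td>3</td><td>0.800000</td><td>0.283543</td></tr>\n",
"    <tr><th>3</th><td>4</td><td>0.840625</td><td>0.388538</td></tr>\n",
"    <tr><th>4</th><td>5</td><td>0.840625</td><td>0.479402</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>95</th><td>96</td><td>0.946875</td><td>8.415694</td></tr>\n",
"    <tr><th>96</th><td>97</td><td>0.943750</td><td>8.524299</td></tr>\n",
"    <tr><th>97</th><td>98</td><td>0.956250</td><td>8.606126</td></tr>\n",
"    <tr><th>98</th><td>99</td><td>0.934375</td><td>8.697471</td></tr>\n",
"    <tr><th>99</th><td>100</td><td>0.965625</td><td>8.777758</td></tr>\n",
"  </tbody>\n",
"</table>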
\n", "

100 rows × 3 columns

\n", "
" ], "text/plain": [ " training_iteration mean_accuracy time_total_s\n", "0 1 0.168750 0.111393\n", "1 2 0.609375 0.195086\n", "2 3 0.800000 0.283543\n", "3 4 0.840625 0.388538\n", "4 5 0.840625 0.479402\n", ".. ... ... ...\n", "95 96 0.946875 8.415694\n", "96 97 0.943750 8.524299\n", "97 98 0.956250 8.606126\n", "98 99 0.934375 8.697471\n", "99 100 0.965625 8.777758\n", "\n", "[100 rows x 3 columns]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result_df = best_result.metrics_dataframe\n", "result_df[[\"training_iteration\", \"mean_accuracy\", \"time_total_s\"]]" ] }, { "attachments": {}, "cell_type": "markdown", "id": "20bc50e9", "metadata": {}, "source": [ "## Plotting metrics\n", "\n", "We can use the metrics DataFrame to quickly visualize learning curves. First, let's plot the mean accuracy vs. training iterations for the best result." ] }, { "cell_type": "code", "execution_count": 17, "id": "1ff489ec", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "best_result.metrics_dataframe.plot(\"training_iteration\", \"mean_accuracy\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "4fd2f85b", "metadata": {}, "source": [ "We can also iterate through the entire set of results and create a combined plot of all trials with the hyperparameters as labels." ] }, { "cell_type": "code", "execution_count": 18, "id": "54b78da6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'Mean Test Accuracy')" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ax = None\n", "for result in result_grid:\n", "    label = f\"lr={result.config['lr']:.3f}, momentum={result.config['momentum']}\"\n", "    if ax is None:\n", "        ax = result.metrics_dataframe.plot(\"training_iteration\", \"mean_accuracy\", label=label)\n", "    else:\n", "        result.metrics_dataframe.plot(\"training_iteration\", \"mean_accuracy\", ax=ax, label=label)\n", "ax.set_title(\"Mean Accuracy vs. Training Iteration for All Trials\")\n", "ax.set_ylabel(\"Mean Test Accuracy\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "be02fc7a", "metadata": {}, "source": [ "## Accessing checkpoints and loading for test inference\n", "\n", "We saw earlier that `Result` contains the last checkpoint associated with a trial. Let's see how we can use this checkpoint to load a model for performing inference on some sample MNIST images." ] }, { "cell_type": "code", "execution_count": 19, "id": "50d3acff", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import torch\n", "\n", "from ray.tune.examples.mnist_pytorch import ConvNet, get_data_loaders\n", "\n", "model = ConvNet()\n", "\n", "with best_result.checkpoint.as_directory() as checkpoint_dir:\n", "    # The model state dict was saved under `model.pt` by the training function\n", "    # imported from `ray.tune.examples.mnist_pytorch`\n", "    model.load_state_dict(torch.load(os.path.join(checkpoint_dir, \"model.pt\")))" ] }, { "attachments": {}, "cell_type": "markdown", "id": "2813c45d", "metadata": {}, "source": [ "Refer to the training loop definition {doc}`here ` to see how we are saving the checkpoint in the first place.\n", "\n", "Next, let's test our model with a sample data point and print out the predicted class." 
] }, { "cell_type": "code", "execution_count": 21, "id": "eb8f6942", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predicted Class = 9\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMkAAADICAYAAABCmsWgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAO/UlEQVR4nO3db2xTV5oG8McOsROIYzcwsfGQCO+KGVgxCqM0ST2wFaUWGXYGkSY7C9Jsl/5RUVsHCbKjbtOFICG0ZqEqLDRtP2ybtBqlqaIRYUqrSJUDycAm6ZDSaYE2A9ps8UywgZ2NbQJJnPjshwxeee8NJ07s+Jo+P+l+8Ovj63OAh+N7fH2vTgghQETT0qe7A0Rax5AQSTAkRBIMCZEEQ0IkwZAQSTAkRBIMCZEEQ0IkwZAQSSxI1Y4bGxtx+PBh+P1+lJSU4Pjx4ygvL5e+LhqNYmhoCCaTCTqdLlXdo285IQTC4TDsdjv0eslcIVKgtbVVGAwG8c4774hLly6J5557TlgsFhEIBKSv9fl8AgA3bvOy+Xw+6b9JnRDJP8GxoqICZWVleP311wFMzQ5FRUXYuXMnXn755fu+NhgMwmKxYB3+BguQneyuEQEAJhDBWXyM4eFhmM3m+7ZN+set8fFx9Pf3o76+PlbT6/VwuVzo6elRtB8bG8PY2FjscTgc/nPHsrFAx5BQivx5apjJR/qkH7jfunULk5OTsFqtcXWr1Qq/369o7/F4YDabY1tRUVGyu0Q0J2lf3aqvr0cwGIxtPp8v3V0iipP0j1tLlixBVlYWAoFAXD0QCMBmsynaG41GGI3GZHeDKGmSPpMYDAaUlpbC6/XGatFoFF6vF06nM9lvR5RyKfmepK6uDtu3b8fDDz+M8vJyHD16FCMjI3j66adT8XZEKZWSkGzduhU3b95EQ0MD/H4/1qxZg46ODsXBPFEmSMn3JHMRCoVgNpuxHlu4BEwpMyEiOIOTCAaDyM/Pv2/btK9uEWkdQ0IkwZAQSTAkRBIMCZEEQ0IkwZAQSTAkRBIMCZEEQ0IkwZAQSTAkRBIMCZEEQ0IkwZAQSTAkRBIMCZEEQ0IkwZAQSTAkRBIMCZEEQ0IkwZAQSaTsTlc0v7Isynts3PnR91TbDv3DmKL2cJH6hcqDf5erqE38cSjB3mU2ziREEgwJkQRDQiTBkBBJ8MD9AVHVc0VRezb/9Jz3+9GZPEXtX195UrVtXlvfnN9PiziTEEkwJEQSDAmRBENCJMGQEElwdUvDslb8haJ2/VWDatvOPxUoak/nq59qcm5UeZu9v86ZUG37k4W3FbUPf3FRte21NtVyxuNMQiTBkBBJMCREEgwJkQQP3OdZ1ne+o6hdPWZXbXuk7ANF7ce5d1TbXoqMK2o/+Pc61bam0luKWu8PW1Xbqtm2pFe1fgg/mPE+MglnEiIJhoRIgiEhkmBIiCQSDkl3dzc2b94Mu90OnU6H9vb2uOeFEGhoaMDSpUuRm5sLl8uFK1eUv3UgyhQJr26NjIygpKQEzzzzDKqrqxXPHzp0CMeOHcO7774Lh8OBvXv3orKyEpcvX0ZOTk5SOp3Jrm9doaj98pF/U21bashS1J78r8dV2/peVV4ZZeynyhUvALiUwEoWzSIkmzZtwqZNm1SfE0Lg6NGj2LNnD7Zs2QIAeO+992C1WtHe3o5t27
bNrbdEaZDUY5LBwUH4/X64XK5YzWw2o6KiAj09PaqvGRsbQygUituItCSpIfH7/QAAq9UaV7darbHn/j+PxwOz2RzbioqKktklojlL++pWfX09gsFgbPP51E/vJkqXpJ6WYrPZAACBQABLly6N1QOBANasWaP6GqPRCKPRmMxuaFrVc2cUtRULItO0Vh649/arX7oUTyj38eWGN6bZr/pvUkhdUmcSh8MBm80Gr9cbq4VCIfT19cHpdCbzrYjmTcIzye3bt3H16tXY48HBQXz++ecoKChAcXExdu3ahQMHDmDFihWxJWC73Y6qqqpk9pto3iQckvPnz+Oxxx6LPa6rmzrTdPv27WhubsZLL72EkZER7NixA8PDw1i3bh06Ojr4HQllrIRDsn79egghpn1ep9Nh//792L9//5w6RqQVaV/dItI6/uhqnr175lFFbc/fql99RM2VmjcTeLfUrGK9eP7nqvXl+CIl75dunEmIJBgSIgmGhEiCISGS4IH7PFvkU/6/tPKXbtW2X/99o6L2+8ioatvOO99X1LaavlZt+5BeeUfdROgGFs3p9ZmGMwmRBENCJMGQEEkwJEQSDAmRBFe35pn91f9Q1BYs+65q2w2/eV5Ry70+otpW9F9S1G78rly1bcOSL+/XxTjfb1GuvK04+DvVttEZ7zWzcCYhkmBIiCQYEiIJhoRIggfuGjDxhz+q1nNU6tP9JvTqkUcUtY+XTHe1FKXW28qbCwHAX/7Tp4paNDo54/0+CDiTEEkwJEQSDAmRBENCJMGQEElwdSvT6JXXBwaAL392TKU687/ePadrVOvfiypXt75tOJMQSTAkRBIMCZEEQ0IkwQP3DDP4L+q/ETHqfjvjfTSH7Iraqr2Dqm2/XSegqONMQiTBkBBJMCREEgwJkQRDQiTB1S0NW/Bd5SrUgeqWOe/31ferFbXim8qruNAUziREEgwJkQRDQiTBkBBJ8MBdw4Z/VKSo1Sz6nxm/vn3Eolp3/Oq/FTWefjI9ziREEgwJkQRDQiTBkBBJJBQSj8eDsrIymEwmFBYWoqqqCgMDA3FtRkdH4Xa7sXjxYuTl5aGmpgaBQCCpnSaaTwmtbnV1dcHtdqOsrAwTExN45ZVXsHHjRly+fBmLFk3dtnj37t346KOP0NbWBrPZjNraWlRXV+PcuXMpGcCDYPhJp2q95cBhRS1Ll6fadlIob6Fz7B+3qbbNucQroCQioZB0dHTEPW5ubkZhYSH6+/vx6KOPIhgM4u2330ZLSws2bNgAAGhqasKqVavQ29uLRx5RXtSZSOvmdEwSDAYBAAUFBQCA/v5+RCIRuFyuWJuVK1eiuLgYPT09qvsYGxtDKBSK24i0ZNYhiUaj2LVrF9auXYvVq1cDAPx+PwwGAywWS1xbq9UKv9+vuh+PxwOz2RzbioqUX6ARpdOsQ+J2u3Hx4kW0trbOqQP19fUIBoOxzefzzWl/RMk2q9NSamtrcerUKXR3d2PZsmWxus1mw/j4OIaHh+Nmk0AgAJvNprovo9EIo9E4m248MIZ/qn5H3eULFipqagfoAPD7yKiitvAb9Y+uD+pdclMloZlECIHa2lqcOHECnZ2dcDgccc+XlpYiOzsbXq83VhsYGMC1a9fgdKqv4BBpXUIzidvtRktLC06ePAmTyRQ7zjCbzcjNzYXZbMazzz6Luro6FBQUID8/Hzt37oTT6eTKFmWshELy5ptvAgDWr18fV29qasJTTz0FADhy5Aj0ej1qamowNjaGyspKvPHGzO/dR6Q1CYVEiOlua/l/cnJy0NjYiMbGxll3ikhLeO4WkQR/dDXP/vOgcgHj/NrXpmmtXPXL0qn/v/aLdT9T1KJ/+DqhvpE6ziREEgwJkQRDQiTBkBBJ8MB9npn+6k+KWp5u5qflTHdayuTNW7PuE90fZxIiCYaESIIhIZJgSIgkGBIiCa5upYhwlqjWf71G7cRP5Y+rprPlyk/U3y9yY8b7oMRwJiGSYEiIJB
gSIgmGhEiCB+4pIrJ0qvUN59wz3kckqDxdZdU/X1VvHOVteFKFMwmRBENCJMGQEEkwJEQSDAmRBFe3UkR/9nPVuuPs3PbLNaz5x5mESIIhIZJgSIgkGBIiCYaESIIhIZJgSIgkGBIiCYaESEJz37jfu5vWBCKA/MZaRLMygQiAmd29TXMhCYfDAICz+DjNPaFvg3A4DLPZfN82OjGTKM2jaDSKoaEhmEwmhMNhFBUVwefzIT8/P91dS6pQKMSxpZEQAuFwGHa7HXr9/Y86NDeT6PV6LFu2DACg0039BDY/P1+zf9hzxbGlj2wGuYcH7kQSDAmRhKZDYjQasW/fPhiNM7/JTabg2DKH5g7cibRG0zMJkRYwJEQSDAmRBENCJKHpkDQ2NmL58uXIyclBRUUFPv3003R3KWHd3d3YvHkz7HY7dDod2tvb454XQqChoQFLly5Fbm4uXC4Xrly5kp7OJsDj8aCsrAwmkwmFhYWoqqrCwMBAXJvR0VG43W4sXrwYeXl5qKmpQSAQSFOPZ0+zIfnggw9QV1eHffv24bPPPkNJSQkqKytx40Zm3dFpZGQEJSUlaGxUu8MVcOjQIRw7dgxvvfUW+vr6sGjRIlRWVmJ0dHSee5qYrq4uuN1u9Pb24pNPPkEkEsHGjRsxMjISa7N79258+OGHaGtrQ1dXF4aGhlBdXZ3GXs+S0Kjy8nLhdrtjjycnJ4XdbhcejyeNvZobAOLEiROxx9FoVNhsNnH48OFYbXh4WBiNRvH++++noYezd+PGDQFAdHV1CSGmxpGdnS3a2tpibb766isBQPT09KSrm7OiyZlkfHwc/f39cLlcsZper4fL5UJPT08ae5Zcg4OD8Pv9ceM0m82oqKjIuHEGg0EAQEFBAQCgv78fkUgkbmwrV65EcXFxxo1NkyG5desWJicnYbVa4+pWqxV+vz9NvUq+e2PJ9HFGo1Hs2rULa9euxerVqwFMjc1gMMBiscS1zbSxARo8C5gyj9vtxsWLF3H27Byv4apRmpxJlixZgqysLMVKSCAQgM1mS1Ovku/eWDJ5nLW1tTh16hROnz4d+4kDMDW28fFxDA8Px7XPpLHdo8mQGAwGlJaWwuv1xmrRaBRerxdOpzONPUsuh8MBm80WN85QKIS+vj7Nj1MIgdraWpw4cQKdnZ1wOBxxz5eWliI7OztubAMDA7h27Zrmx6aQ7pWD6bS2tgqj0Siam5vF5cuXxY4dO4TFYhF+vz/dXUtIOBwWFy5cEBcuXBAAxGuvvSYuXLggvvnmGyGEEAcPHhQWi0WcPHlSfPHFF2LLli3C4XCIu3fvprnn9/fCCy8Is9kszpw5I65fvx7b7ty5E2vz/PPPi+LiYtHZ2SnOnz8vnE6ncDqdaez17Gg2JEIIcfz4cVFcXCwMBoMoLy8Xvb296e5Swk6fPi0wdUmLuG379u1CiKll4L179wqr1SqMRqN4/PHHxcDAQHo7PQNqYwIgmpqaYm3u3r0rXnzxRfHQQw+JhQsXiieeeEJcv349fZ2eJZ4qTyShyWMSIi1hSIgkGBIiCYaESIIhIZJgSIgkGBIiCYaESIIhIZJgSIgkGBIiCYaESOJ/ARenxDNLcYJgAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "_, test_loader = get_data_loaders()\n", "test_img = next(iter(test_loader))[0][0]\n", "\n", "predicted_class = torch.argmax(model(test_img)).item()\n", "print(\"Predicted Class =\", predicted_class)\n", "\n", "# Need to reshape to (batch_size, channels, width, height)\n", "test_img = test_img.numpy().reshape((1, 1, 28, 28))\n", "plt.figure(figsize=(2, 2))\n", "plt.imshow(test_img.reshape((28, 28)))\n" ] }, { "cell_type": "code", "execution_count": null, "id": "fce0ae4f", "metadata": {}, "outputs": [], "source": [] }, { "attachments": {}, "cell_type": "markdown", "id": "1699bab7", "metadata": {}, "source": [ "Consider using Ray Data if you want to use a checkpointed model for large scale inference!" ] }, { "attachments": {}, "cell_type": "markdown", "id": "16c25683", "metadata": {}, "source": [ "## Summary\n", "\n", "In this guide, we looked at some common analysis workflows you can perform using the `ResultGrid` output returned by `Tuner.fit`. These included: **loading results from an experiment directory, exploring experiment-level and trial-level results, plotting logged metrics, and accessing trial checkpoints for inference.**\n", "\n", "Take a look at [Tune's experiment tracking integrations](./experiment-tracking) for more analysis tools that you can build into your Tune experiment with a few callbacks!" ] } ], "metadata": { "kernelspec": { "display_name": "ray_dev_py38", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "vscode": { "interpreter": { "hash": "265d195fda5292fe8f69c6e37c435a5634a1ed3b6799724e66a975f68fa21517" } } }, "nbformat": 4, "nbformat_minor": 5 }