{ "cells": [ { "cell_type": "markdown", "id": "aa1c2614", "metadata": {}, "source": [ "(tune-rllib-example)=\n", "\n", "# Using RLlib with Tune\n", "\n", "```{image} /rllib/images/rllib-logo.png\n", ":align: center\n", ":alt: RLlib Logo\n", ":height: 120px\n", ":target: https://docs.ray.io\n", "```\n", "\n", "```{contents}\n", ":backlinks: none\n", ":local: true\n", "```\n", "\n", "## Example\n", "\n", "Example of using a Tune scheduler ([Population Based Training](tune-scheduler-pbt)) with RLlib.\n", "\n", "This example specifies `num_workers=4`, `num_cpus=1`, and `num_gpus=0`, which means that each\n", "PPO trial will use 5 CPUs: 1 (for training) + 4 (for sample collection).\n", "This example runs 2 trials, so at least 10 CPUs must be available in the cluster resources\n", "in order to run both trials concurrently. Otherwise, the PBT scheduler will round-robin\n", "between training each trial, which is less efficient.\n", "\n", "If you want to run this example with GPUs, you can set `num_gpus` accordingly." ] }, { "cell_type": "code", "execution_count": null, "id": "f4621a1a", "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "import ray\n", "from ray import train, tune\n", "from ray.tune.schedulers import PopulationBasedTraining\n", "\n", "if __name__ == \"__main__\":\n", " import argparse\n", "\n", " parser = argparse.ArgumentParser()\n", " parser.add_argument(\n", " \"--smoke-test\", action=\"store_true\", help=\"Finish quickly for testing\"\n", " )\n", " args, _ = parser.parse_known_args()\n", "\n", " # Postprocess the perturbed config to ensure it's still valid\n", " def explore(config):\n", " # ensure we collect enough timesteps to do sgd\n", " if config[\"train_batch_size\"] < config[\"sgd_minibatch_size\"] * 2:\n", " config[\"train_batch_size\"] = config[\"sgd_minibatch_size\"] * 2\n", " # ensure we run at least one sgd iter\n", " if config[\"num_sgd_iter\"] < 1:\n", " config[\"num_sgd_iter\"] = 1\n", " return config\n", "\n", " hyperparam_mutations = {\n", " \"lambda\": lambda: random.uniform(0.9, 1.0),\n", " \"clip_param\": lambda: random.uniform(0.01, 0.5),\n", " \"lr\": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],\n", " \"num_sgd_iter\": lambda: random.randint(1, 30),\n", " \"sgd_minibatch_size\": lambda: random.randint(128, 16384),\n", " \"train_batch_size\": lambda: random.randint(2000, 160000),\n", " }\n", "\n", " pbt = PopulationBasedTraining(\n", " time_attr=\"time_total_s\",\n", " perturbation_interval=120,\n", " resample_probability=0.25,\n", " # Specifies the mutations of these hyperparams\n", " hyperparam_mutations=hyperparam_mutations,\n", " custom_explore_fn=explore,\n", " )\n", "\n", " # Stop when we've either reached 100 training iterations or reward=300\n", " stopping_criteria = {\"training_iteration\": 100, \"episode_reward_mean\": 300}\n", "\n", " tuner = tune.Tuner(\n", " \"PPO\",\n", " tune_config=tune.TuneConfig(\n", " metric=\"episode_reward_mean\",\n", " mode=\"max\",\n", " scheduler=pbt,\n", " num_samples=1 if args.smoke_test else 2,\n", " ),\n", " param_space={\n", " \"env\": \"Humanoid-v2\",\n", " \"kl_coeff\": 1.0,\n", " \"num_workers\": 4,\n", " \"num_cpus\": 1, # number of CPUs to use per trial\n", " \"num_gpus\": 0, # number of GPUs to use per trial\n", " \"model\": {\"free_log_std\": True},\n", " # These params are tuned from a fixed starting value.\n", " \"lambda\": 0.95,\n", " \"clip_param\": 0.2,\n", " \"lr\": 1e-4,\n", " # These params start off randomly drawn from a set.\n", " \"num_sgd_iter\": tune.choice([10, 20, 30]),\n", " 
\"sgd_minibatch_size\": tune.choice([128, 512, 2048]),\n", " \"train_batch_size\": tune.choice([10000, 20000, 40000]),\n", " },\n", " run_config=train.RunConfig(stop=stopping_criteria),\n", " )\n", " results = tuner.fit()\n" ] }, { "cell_type": "code", "execution_count": 35, "id": "8cd3cc70", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best performing trial's final set of hyperparameters:\n", "\n", "{'clip_param': 0.2,\n", " 'lambda': 0.95,\n", " 'lr': 0.0001,\n", " 'num_sgd_iter': 30,\n", " 'sgd_minibatch_size': 2048,\n", " 'train_batch_size': 20000}\n", "\n", "Best performing trial's final reported metrics:\n", "\n", "{'episode_len_mean': 61.09146341463415,\n", " 'episode_reward_max': 567.4424113245353,\n", " 'episode_reward_mean': 310.36948184391935,\n", " 'episode_reward_min': 87.74736189944105}\n" ] } ], "source": [ "import pprint\n", "\n", "best_result = results.get_best_result()\n", "\n", "print(\"Best performing trial's final set of hyperparameters:\\n\")\n", "pprint.pprint(\n", " {k: v for k, v in best_result.config.items() if k in hyperparam_mutations}\n", ")\n", "\n", "print(\"\\nBest performing trial's final reported metrics:\\n\")\n", "\n", "metrics_to_print = [\n", " \"episode_reward_mean\",\n", " \"episode_reward_max\",\n", " \"episode_reward_min\",\n", " \"episode_len_mean\",\n", "]\n", "pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print})\n" ] }, { "cell_type": "code", "execution_count": null, "id": "e4cc4685", "metadata": {}, "outputs": [], "source": [ "from ray.rllib.algorithms.algorithm import Algorithm\n", "\n", "loaded_ppo = Algorithm.from_checkpoint(best_result.checkpoint)\n", "loaded_policy = loaded_ppo.get_policy()\n", "\n", "# See your trained policy in action\n", "# loaded_policy.compute_single_action(...)\n" ] }, { "cell_type": "markdown", "id": "db534c4e", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## More RLlib Examples\n", "\n", "- {doc}`/tune/examples/includes/pb2_ppo_example`:\n", " Example of optimizing a distributed RLlib algorithm (PPO) with the PB2 scheduler.\n", " Uses a small population size of 4, so can train on a laptop." ] }, { "cell_type": "code", "execution_count": null, "id": "a3d4fb61", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" }, "orphan": true }, "nbformat": 4, "nbformat_minor": 5 }