{
"cells": [
{
"cell_type": "markdown",
"id": "aa1c2614",
"metadata": {},
"source": [
"(tune-rllib-example)=\n",
"\n",
"# Using RLlib with Tune\n",
"\n",
"```{image} /rllib/images/rllib-logo.png\n",
":align: center\n",
":alt: RLlib Logo\n",
":height: 120px\n",
":target: https://docs.ray.io\n",
"```\n",
"\n",
"```{contents}\n",
":backlinks: none\n",
":local: true\n",
"```\n",
"\n",
"## Example\n",
"\n",
"This example uses a Tune scheduler, [Population Based Training](tune-scheduler-pbt) (PBT), to tune the hyperparameters of an RLlib PPO algorithm.\n",
"\n",
"This example sets `num_env_runners=4` and keeps the default of one CPU and no GPUs for training, so each\n",
"PPO trial uses 5 CPUs: 1 for training plus 4 for sample collection (one per EnvRunner).\n",
"The example runs 2 trials, so the cluster needs at least 10 CPUs to run both trials concurrently.\n",
"Otherwise, the PBT scheduler round-robins between the trials, training only one at a time,\n",
"which is less efficient.\n",
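"\n",
"The arithmetic above can be sketched in plain Python. This is an illustration only: `cpus_per_trial` is a hypothetical helper (not an RLlib or Tune API), and it assumes, as the paragraph above does, one CPU for each trial's training process plus one CPU per EnvRunner.\n",
"\n",
"```python\n",
"def cpus_per_trial(num_env_runners: int, training_cpus: int = 1) -> int:\n",
"    # Hypothetical helper: 1 CPU for training + 1 CPU per EnvRunner.\n",
"    return training_cpus + num_env_runners\n",
"\n",
"\n",
"num_trials = 2\n",
"per_trial = cpus_per_trial(num_env_runners=4)\n",
"total = per_trial * num_trials\n",
"print(f\"{per_trial} CPUs per trial, {total} CPUs to run all trials concurrently\")\n",
"```\n",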
"\n",
"If you want to run this example with GPUs, set `num_gpus_per_learner` in the config's `.learners()` section accordingly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4621a1a",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"from ray import tune\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig\n",
"from ray.tune.schedulers import PopulationBasedTraining\n",
"\n",
"if __name__ == \"__main__\":\n",
"    import argparse\n",
"\n",
"    parser = argparse.ArgumentParser()\n",
"    parser.add_argument(\n",
"        \"--smoke-test\", action=\"store_true\", help=\"Finish quickly for testing\"\n",
"    )\n",
"    args, _ = parser.parse_known_args()\n",
"\n",
"    # Postprocess the perturbed config to ensure it's still valid.\n",
"    def explore(config):\n",
"        # Ensure the train batch holds at least two minibatches.\n",
"        if config[\"train_batch_size_per_learner\"] < config[\"minibatch_size\"] * 2:\n",
"            config[\"train_batch_size_per_learner\"] = config[\"minibatch_size\"] * 2\n",
"        # Ensure at least one training epoch runs per iteration.\n",
"        if config[\"num_epochs\"] < 1:\n",
"            config[\"num_epochs\"] = 1\n",
"        return config\n",
"\n",
"    hyperparam_mutations = {\n",
"        \"clip_param\": lambda: random.uniform(0.01, 0.5),\n",
"        \"lr\": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],\n",
"        \"num_epochs\": lambda: random.randint(1, 30),\n",
"        \"minibatch_size\": lambda: random.randint(128, 16384),\n",
"        \"train_batch_size_per_learner\": lambda: random.randint(2000, 160000),\n",
"    }\n",
"\n",
"    pbt = PopulationBasedTraining(\n",
"        time_attr=\"time_total_s\",\n",
"        perturbation_interval=120,\n",
"        resample_probability=0.25,\n",
"        # Specifies the mutations of these hyperparams.\n",
"        hyperparam_mutations=hyperparam_mutations,\n",
"        custom_explore_fn=explore,\n",
"    )\n",
"\n",
"    # Stop after 100 training iterations or once the mean episode return reaches 300.\n",
"    stopping_criteria = {\n",
"        \"training_iteration\": 100,\n",
"        \"env_runners/episode_return_mean\": 300,\n",
"    }\n",
"\n",
"    config = (\n",
"        PPOConfig()\n",
"        .environment(\"Humanoid-v2\")\n",
"        .env_runners(num_env_runners=4)\n",
"        .training(\n",
"            # These params are tuned from a fixed starting value.\n",
"            kl_coeff=1.0,\n",
"            lambda_=0.95,\n",
"            clip_param=0.2,\n",
"            lr=1e-4,\n",
"            # These params start off randomly drawn from a set.\n",
"            num_epochs=tune.choice([10, 20, 30]),\n",
"            minibatch_size=tune.choice([128, 512, 2048]),\n",
"            train_batch_size_per_learner=tune.choice([10000, 20000, 40000]),\n",
"        )\n",
"        .rl_module(\n",
"            model_config=DefaultModelConfig(free_log_std=True),\n",
"        )\n",
"    )\n",
"\n",
"    tuner = tune.Tuner(\n",
"        \"PPO\",\n",
"        tune_config=tune.TuneConfig(\n",
"            metric=\"env_runners/episode_return_mean\",\n",
"            mode=\"max\",\n",
"            scheduler=pbt,\n",
"            num_samples=1 if args.smoke_test else 2,\n",
"        ),\n",
"        param_space=config,\n",
"        run_config=tune.RunConfig(stop=stopping_criteria),\n",
"    )\n",
"    results = tuner.fit()\n"
]
},
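{
"cell_type": "markdown",
"id": "c3f1a7e2",
"metadata": {},
"source": [
"PBT perturbs each hyperparameter in `hyperparam_mutations` independently, so a mutated config can end up inconsistent, for example with a `minibatch_size` larger than the train batch. The `custom_explore_fn` passed to `PopulationBasedTraining` receives the perturbed config and returns a repaired one. The following standalone sketch reproduces that repair logic on a plain dict so you can check it without Ray. It assumes the config keys used in `hyperparam_mutations`, and the sample values are made up for illustration.\n",
"\n",
"```python\n",
"def explore(config):\n",
"    # Ensure the train batch holds at least two minibatches.\n",
"    if config[\"train_batch_size_per_learner\"] < config[\"minibatch_size\"] * 2:\n",
"        config[\"train_batch_size_per_learner\"] = config[\"minibatch_size\"] * 2\n",
"    # Ensure at least one training epoch runs per iteration.\n",
"    if config[\"num_epochs\"] < 1:\n",
"        config[\"num_epochs\"] = 1\n",
"    return config\n",
"\n",
"\n",
"# Hypothetical perturbed config: batch too small for two minibatches, zero epochs.\n",
"perturbed = {\"train_batch_size_per_learner\": 2000, \"minibatch_size\": 4096, \"num_epochs\": 0}\n",
"print(explore(perturbed))\n",
"# {'train_batch_size_per_learner': 8192, 'minibatch_size': 4096, 'num_epochs': 1}\n",
"```"
]
},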
{
"cell_type": "code",
"execution_count": 35,
"id": "8cd3cc70",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best performing trial's final set of hyperparameters:\n",
"\n",
"{'clip_param': 0.2,\n",
" 'lambda': 0.95,\n",
" 'lr': 0.0001,\n",
" 'num_sgd_iter': 30,\n",
" 'sgd_minibatch_size': 2048,\n",
" 'train_batch_size': 20000}\n",
"\n",
"Best performing trial's final reported metrics:\n",
"\n",
"{'episode_len_mean': 61.09146341463415,\n",
" 'episode_reward_max': 567.4424113245353,\n",
" 'episode_reward_mean': 310.36948184391935,\n",
" 'episode_reward_min': 87.74736189944105}\n"
]
}
],
"source": [
"import pprint\n",
"\n",
"best_result = results.get_best_result()\n",
"\n",
"print(\"Best performing trial's final set of hyperparameters:\\n\")\n",
"pprint.pprint(\n",
" {k: v for k, v in best_result.config.items() if k in hyperparam_mutations}\n",
")\n",
"\n",
"print(\"\\nBest performing trial's final reported metrics:\\n\")\n",
"\n",
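"# On RLlib's new API stack, episode metrics are nested under \"env_runners\".\n",
"env_runner_metrics = best_result.metrics[\"env_runners\"]\n",
"metrics_to_print = [\n",
"    \"episode_return_mean\",\n",
"    \"episode_return_max\",\n",
"    \"episode_return_min\",\n",
"    \"episode_len_mean\",\n",
"]\n",
"pprint.pprint({k: v for k, v in env_runner_metrics.items() if k in metrics_to_print})\n"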
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4cc4685",
"metadata": {},
"outputs": [],
"source": [
"from ray.rllib.algorithms.algorithm import Algorithm\n",
"\n",
"loaded_ppo = Algorithm.from_checkpoint(best_result.checkpoint)\n",
"loaded_module = loaded_ppo.get_module()\n",
"\n",
"# See your trained RLModule in action, for example:\n",
"# loaded_module.forward_inference({\"obs\": ...})\n"
]
},
{
"cell_type": "markdown",
"id": "db534c4e",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## More RLlib Examples\n",
"\n",
"- {doc}`/tune/examples/includes/pb2_ppo_example`:\n",
" Example of optimizing a distributed RLlib algorithm (PPO) with the PB2 scheduler.\n",
"  It uses a small population size of 4, so it can train on a laptop."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"orphan": true
},
"nbformat": 4,
"nbformat_minor": 5
}