{ "cells": [ { "cell_type": "markdown", "id": "57fe8246", "metadata": {}, "source": [ "# Offline reinforcement learning with Ray AIR\n", "In this example, we'll train a reinforcement learning agent using offline training.\n", "\n", "Offline training means that the data from the environment (and the actions performed by the agent) have been stored on disk. In contrast, online training samples experiences live by interacting with the environment." ] }, { "cell_type": "markdown", "id": "edc8d8ac", "metadata": {}, "source": [ "Let's start with installing our dependencies:" ] }, { "cell_type": "code", "execution_count": 1, "id": "0ef2e884", "metadata": {}, "outputs": [], "source": [ "# !pip install -qU \"ray[rllib]\" gymnasium" ] }, { "cell_type": "markdown", "id": "503b1b55", "metadata": {}, "source": [ "Now we can run some imports:" ] }, { "cell_type": "code", "execution_count": 2, "id": "db0a45ff", "metadata": {}, "outputs": [], "source": [ "import argparse\n", "import gymnasium as gym\n", "import os\n", "\n", "import numpy as np\n", "import ray\n", "from ray.air import Checkpoint\n", "from ray.air.config import RunConfig\n", "from ray.train.rl.rl_predictor import RLPredictor\n", "from ray.train.rl.rl_trainer import RLTrainer\n", "from ray.air.config import ScalingConfig\n", "from ray.air.result import Result\n", "from ray.rllib.algorithms.bc import BC\n", "from ray.tune.tuner import Tuner" ] }, { "cell_type": "markdown", "id": "184fe936", "metadata": {}, "source": [ "We will be training on offline data - this means we have full agent trajectories stored somewhere on disk and want to train on these past experiences.\n", "\n", "Usually this data could come from external systems, or a database of historical data. But for this example, we'll generate some offline data ourselves and store it using RLlibs `output_config`." ] }, { "cell_type": "code", "execution_count": 3, "id": "5aeed761", "metadata": {}, "outputs": [], "source": [ "def generate_offline_data(path: str):\n", " print(f\"Generating offline data for training at {path}\")\n", " trainer = RLTrainer(\n", " algorithm=\"PPO\",\n", " run_config=RunConfig(stop={\"timesteps_total\": 5000}),\n", " config={\n", " \"env\": \"CartPole-v1\",\n", " \"output\": \"dataset\",\n", " \"output_config\": {\n", " \"format\": \"json\",\n", " \"path\": path,\n", " \"max_num_samples_per_file\": 1,\n", " },\n", " \"batch_mode\": \"complete_episodes\",\n", " },\n", " )\n", " trainer.fit()" ] }, { "cell_type": "markdown", "id": "8bca906c", "metadata": {}, "source": [ "Here we define the training function. It will create an `RLTrainer` using the `PPO` algorithm and kick off training on the `CartPole-v1` environment. It will use the offline data provided in `path` for this." 
] }, { "cell_type": "code", "execution_count": 4, "id": "f5071ce0", "metadata": {}, "outputs": [], "source": [ "def train_rl_bc_offline(path: str, num_workers: int, use_gpu: bool = False) -> Result:\n", " print(\"Starting offline training\")\n", " dataset = ray.data.read_json(\n", " path, parallelism=num_workers, ray_remote_args={\"num_cpus\": 1}\n", " )\n", "\n", " trainer = RLTrainer(\n", " run_config=RunConfig(stop={\"training_iteration\": 5}),\n", " scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),\n", " datasets={\"train\": dataset},\n", " algorithm=BC,\n", " config={\n", " \"env\": \"CartPole-v1\",\n", " \"framework\": \"tf\",\n", " \"evaluation_num_workers\": 1,\n", " \"evaluation_interval\": 1,\n", " \"evaluation_config\": {\"input\": \"sampler\"},\n", " },\n", " )\n", "\n", " # Todo (krfricke/xwjiang): Enable checkpoint config in RunConfig\n", " # result = trainer.fit()\n", " tuner = Tuner(\n", " trainer,\n", " _tuner_kwargs={\"checkpoint_at_end\": True},\n", " )\n", " result = tuner.fit()[0]\n", " return result" ] }, { "cell_type": "markdown", "id": "d935cdee", "metadata": {}, "source": [ "Once we trained our RL policy, we want to evaluate it on a fresh environment. For this, we will also define a utility function:" ] }, { "cell_type": "code", "execution_count": 5, "id": "2628f3b0", "metadata": {}, "outputs": [], "source": [ "def evaluate_using_checkpoint(checkpoint: Checkpoint, num_episodes) -> list:\n", " predictor = RLPredictor.from_checkpoint(checkpoint)\n", "\n", " env = gym.make(\"CartPole-v1\")\n", "\n", " rewards = []\n", " for i in range(num_episodes):\n", " obs, _ = env.reset()\n", " reward = 0.0\n", " terminated = truncated = False\n", " while not terminated and not truncated:\n", " action = predictor.predict(np.array([obs]))\n", " obs, r, terminated, truncated, _ = env.step(action[0])\n", " reward += r\n", " rewards.append(reward)\n", "\n", " return rewards" ] }, { "cell_type": "markdown", "id": "84f4bebe", "metadata": {}, "source": [ "Let's put it all together. First, we initialize Ray and create the offline data:" ] }, { "cell_type": "code", "execution_count": 6, "id": "cae1337e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-09-26 18:22:15,032\tINFO worker.py:1509 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Generating offline data for training at /tmp/out\n" ] }, { "data": { "text/html": [ "
PPO data-generation run (Tune trial output, abridged to the trial summary):

| Trial name | status | iter | total time (s) | ts | reward | episode_reward_max | episode_reward_min |
|---|---|---|---|---|---|---|---|
| AIRPPO_d229c_00000 | TERMINATED | 2 | 8.77525 | 8528 | 45.76 | 137 | 10 |
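The offline experiences have now been written to `/tmp/out` (the path printed above). As an optional sanity check before training on them, the JSON files can be read back with Ray Datasets, just like `train_rl_bc_offline` does internally. This is a minimal sketch, assuming the `/tmp/out` path from the run above:

```python
# Optional sanity check of the generated offline data (assumes the
# /tmp/out output path printed above). Each row is a recorded sample batch.
import ray

offline_dataset = ray.data.read_json("/tmp/out")
print("Number of sample batches:", offline_dataset.count())
print(offline_dataset.schema())
```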
BC offline-training run (Tune trial output, abridged to the trial summary):

| Trial name | status | iter | total time (s) | ts | reward | evaluation episode_reward_mean |
|---|---|---|---|---|---|---|
| AIRBC_e3afc_00000 | TERMINATED | 5 | 0.996612 | 11084 | nan | 16.9 |
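Note that the BC trial's own `reward` column is `nan`: behavior cloning trains purely from the offline dataset and never samples the environment itself, so episode rewards only come from the evaluation workers we configured (`evaluation_interval: 1`). A minimal sketch for pulling those evaluation metrics and the checkpoint out of the training result, assuming `result` is the `Result` returned by `train_rl_bc_offline`:

```python
# Minimal sketch: inspect the BC training result.
# Assumes `result` is the ray.air.result.Result returned by train_rl_bc_offline.
eval_reward = result.metrics["evaluation"]["episode_reward_mean"]
print("Evaluation episode_reward_mean:", eval_reward)

# The checkpoint produced by the run can later be passed to
# evaluate_using_checkpoint() to roll out the policy on a fresh environment.
checkpoint = result.checkpoint
```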