Note

Ray 2.40 uses RLlib’s new API stack by default. The Ray team has mostly completed transitioning algorithms, example scripts, and documentation to the new code base.

If you’re still using the old API stack, see the New API stack migration guide for details on how to migrate.

Getting Started with RLlib#

All RLlib experiments are run using an Algorithm class which holds a policy for environment interaction. Through the algorithm’s interface, you can train the policy, compute actions, or store your algorithm’s state (checkpointing). In multi-agent training, the algorithm manages the querying and optimization of multiple policies at once.

[Figure: overview of the RLlib API (rllib-api.svg)]

In this guide, we will explain in detail RLlib’s Python API for running learning experiments.

Using the Python API#

The Python API provides all the flexibility required for applying RLlib to any type of problem.

Let’s start with an example of the API’s basic usage. We first create a PPOConfig instance and set some properties through the config class’ various methods. For example, we can set the RL environment we want to use by calling the config’s environment method. To scale the algorithm and define how many environment workers (EnvRunners) we want to leverage, we can call the env_runners method. After we build the PPO Algorithm from its configuration, we can train it for a number of iterations (here 10) and save the resulting policy periodically (here every 5 iterations).

from pprint import pprint

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .env_runners(num_env_runners=1)
)

algo = config.build()

for i in range(10):
    result = algo.train()
    result.pop("config")
    pprint(result)

    if i % 5 == 0:
        checkpoint_dir = algo.save_to_path()
        print(f"Checkpoint saved in directory {checkpoint_dir}")

All RLlib algorithms are compatible with the Tune API, which makes it easy to use them in experiments with Ray Tune. For example, the following code performs a simple hyperparameter sweep of PPO.

from ray import train, tune

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .training(
        lr=tune.grid_search([0.01, 0.001, 0.0001]),
    )
)

tuner = tune.Tuner(
    "PPO",
    param_space=config,
    run_config=train.RunConfig(
        stop={"env_runners/episode_return_mean": 150.0},
    ),
)

tuner.fit()

Tune will schedule the trials to run in parallel on your Ray cluster:

== Status ==
Using FIFO scheduling algorithm.
Resources requested: 4/4 CPUs, 0/0 GPUs
Result logdir: ~/ray_results/my_experiment
PENDING trials:
 - PPO_CartPole-v1_2_lr=0.0001:     PENDING
RUNNING trials:
 - PPO_CartPole-v1_0_lr=0.01:       RUNNING [pid=21940], 16 s, 4013 ts, 22 rew
 - PPO_CartPole-v1_1_lr=0.001:      RUNNING [pid=21942], 27 s, 8111 ts, 54.7 rew

Tuner.fit() returns a ResultGrid object that allows further analysis of the training results and retrieval of the checkpoint(s) of the trained agent.

from ray import train, tune

# Tuner.fit() allows setting a custom log directory (other than ~/ray_results).
tuner = tune.Tuner(
    "PPO",
    param_space=config,
    run_config=train.RunConfig(
        stop={"num_env_steps_sampled_lifetime": 20000},
        checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True),
    ),
)

results = tuner.fit()

# Get the best result based on a particular metric.
best_result = results.get_best_result(
    metric="env_runners/episode_return_mean", mode="max"
)

# Get the best checkpoint corresponding to the best result.
best_checkpoint = best_result.checkpoint

Note

You can find your checkpoint’s version by looking into the rllib_checkpoint.json file inside your checkpoint directory.
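
For example, a quick way to print the version of the checkpoint obtained above (this sketch assumes the metadata file contains a checkpoint_version field, which recent RLlib versions write):

import json
import pathlib

meta = json.loads(
    (pathlib.Path(best_checkpoint.path) / "rllib_checkpoint.json").read_text()
)
print(meta["checkpoint_version"])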

Loading and restoring a trained algorithm from a checkpoint is simple. Let’s assume you have a local checkpoint directory called checkpoint_path. To load newer RLlib checkpoints (version >= 1.0), use the following code:

from ray.rllib.algorithms.algorithm import Algorithm
algo = Algorithm.from_checkpoint(checkpoint_path)

For older RLlib checkpoint versions (version < 1.0), you can restore an algorithm through:

from ray.rllib.algorithms.ppo import PPO
algo = PPO(config=config, env=env_class)
algo.restore(checkpoint_path)
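
In either case, the restored algorithm behaves like one you built and trained yourself, so you can, for example, continue training right where the checkpoint left off:

# Continue training from the restored state.
result = algo.train()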

Computing Actions#

The simplest way to programmatically compute actions from a trained agent is to use Algorithm.compute_single_action(). This method preprocesses and filters the observation before passing it to the agent’s policy.
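
A minimal sketch, assuming algo is a trained or restored Algorithm instance like the ones built earlier in this guide:

import gymnasium as gym

env = gym.make("CartPole-v1")

obs, info = env.reset()
terminated = truncated = False
episode_return = 0.0

while not terminated and not truncated:
    # Let the algorithm preprocess the observation and compute a single action.
    action = algo.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)
    episode_return += reward

print(f"Reached episode return of {episode_return}.")

Alternatively, you can load only the trained RLModule (the policy's neural network) from a checkpoint and query it directly, as in the following example of testing a trained agent for one episode: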

import pathlib
import gymnasium as gym
import numpy as np
import torch
from ray.rllib.core.rl_module import RLModule

env = gym.make("CartPole-v1")

# Create only the neural network (RLModule) from our checkpoint.
rl_module = RLModule.from_checkpoint(
    pathlib.Path(best_checkpoint.path) / "learner_group" / "learner" / "rl_module"
)["default_policy"]

episode_return = 0
terminated = truncated = False

obs, info = env.reset()

while not terminated and not truncated:
    # Compute the next action from a batch (B=1) of observations.
    torch_obs_batch = torch.from_numpy(np.array([obs]))
    action_logits = rl_module.forward_inference({"obs": torch_obs_batch})[
        "action_dist_inputs"
    ]
    # The default RLModule used here produces action logits (from which
    # we'll have to sample an action or use the max-likelihood one).
    action = torch.argmax(action_logits[0]).numpy()
    obs, reward, terminated, truncated, info = env.step(action)
    episode_return += reward

print(f"Reached episode return of {episode_return}.")

For more advanced usage on computing actions and other functionality, you can consult the RLlib Algorithm API documentation.

Accessing Policy State#

It is common to need to access an algorithm’s internal state, for instance to set or get model weights.

In RLlib, algorithm state is replicated across multiple rollout workers (Ray actors) in the cluster. However, you can easily get and update this state between calls to train() via Algorithm.env_runner_group.foreach_worker() or Algorithm.env_runner_group.foreach_worker_with_id(). These methods take a lambda function that is applied to each worker and return the per-worker results as a list.

You can also access just the “master” copy of the algorithm state through Algorithm.get_policy() or Algorithm.env_runner, but note that updates here may not be immediately reflected in your rollout workers (if you have configured num_env_runners > 0). Here’s a quick example of how to access the state of a model:

from ray.rllib.algorithms.ppo import PPOConfig

algo = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .env_runners(num_env_runners=2)
).build()

# Get weights of the algo's RLModule.
algo.get_module().get_state()

# Same as above
algo.env_runner.module.get_state()

# Get list of weights of each EnvRunner, including remote replicas.
algo.env_runner_group.foreach_worker(lambda env_runner: env_runner.module.get_state())

# Same as above, but with index.
algo.env_runner_group.foreach_worker_with_id(
    lambda _id, env_runner: env_runner.module.get_state()
)
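
To modify weights rather than just read them, you can set a new state on the local module and then push it to all EnvRunner replicas. A minimal sketch, assuming you only change values of the existing state dict:

# Grab the current module state, (optionally) modify it, ...
module_state = algo.get_module().get_state()

# ... and broadcast it to all EnvRunner replicas, including remote ones.
algo.env_runner_group.foreach_worker(
    lambda env_runner: env_runner.module.set_state(module_state)
)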

Accessing Model State#

Similar to accessing policy state, you may want to get a reference to the underlying neural network model being trained. For example, you may want to pre-train it separately, or otherwise update its weights outside of RLlib. This can be done by accessing the model of the policy.

Note

To run these examples, you need to install a few extra dependencies, namely pip install "gym[atari]" "gym[accept-rom-license]" atari_py.

Below you find three explicit examples showing how to access the model state of an algorithm.

Example: Preprocessing observations for feeding into a model

The following code demonstrates this:

try:
    # Use the new Gymnasium API (with the ALE Atari plugin), if available.
    import gymnasium as gym

    env = gym.make("ale_py:ALE/Pong-v5")
    obs, infos = env.reset()
except Exception:
    # Fall back to the legacy gym API.
    import gym

    env = gym.make("PongNoFrameskip-v4")
    obs = env.reset()

# RLlib uses preprocessors to implement transforms such as one-hot encoding
# and flattening of tuple and dict observations.
from ray.rllib.models.preprocessors import get_preprocessor

prep = get_preprocessor(env.observation_space)(env.observation_space)
# <ray.rllib.models.preprocessors.GenericPixelPreprocessor object at 0x7fc4d049de80>

# Observations should be preprocessed prior to feeding into a model
obs.shape
# (210, 160, 3)
prep.transform(obs).shape
# (84, 84, 3)
Example: Querying a policy’s action distribution

# Get a reference to the policy
import numpy as np
import torch

from ray.rllib.algorithms.dqn import DQNConfig

algo = (
    DQNConfig()
    .api_stack(
        enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
    )
    .framework("torch")
    .environment("CartPole-v1")
    .env_runners(num_env_runners=0)
    .training(
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
        }
    )
).build()
# <ray.rllib.algorithms.dqn.DQN object at 0x...>

policy = algo.get_policy()
# <DQNTorchPolicy object at 0x...>

# Run a forward pass to get model output logits. Note that complex observations
# must be preprocessed as in the above code block.
logits, _ = policy.model({"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))})
# (tensor([[...]]), [])  (model output tensor plus an empty state list)

# Compute action distribution given logits
policy.dist_class
# <class 'ray.rllib.models.torch.torch_action_dist.TorchCategorical'>
dist = policy.dist_class(logits, policy.model)
# <ray.rllib.models.torch.torch_action_dist.TorchCategorical object at 0x...>

# Query the distribution for samples, sample logps
dist.sample()
# tensor([...])  (shape (1,), dtype int64)
dist.logp(torch.tensor([1]))
# tensor([...])  (shape (1,), dtype float32)

# Get the estimated values for the most recent forward pass
policy.model.value_function()
# tensor([...])  (shape (1,), dtype float32)

# Print the model's architecture. Note: the Keras-style summary below comes from
# the TF version of this model; with framework="torch", you get a torch module repr.
print(policy.model)
"""
Model: "model"
_____________________________________________________________________
Layer (type)               Output Shape  Param #  Connected to
=====================================================================
observations (InputLayer)  [(None, 4)]   0
_____________________________________________________________________
fc_1 (Dense)               (None, 256)   1280     observations[0][0]
_____________________________________________________________________
fc_value_1 (Dense)         (None, 256)   1280     observations[0][0]
_____________________________________________________________________
fc_2 (Dense)               (None, 256)   65792    fc_1[0][0]
_____________________________________________________________________
fc_value_2 (Dense)         (None, 256)   65792    fc_value_1[0][0]
_____________________________________________________________________
fc_out (Dense)             (None, 2)     514      fc_2[0][0]
_____________________________________________________________________
value_out (Dense)          (None, 1)     257      fc_value_2[0][0]
=====================================================================
Total params: 134,915
Trainable params: 134,915
Non-trainable params: 0
_____________________________________________________________________
"""
Example: Getting Q values from a DQN model

# Get a reference to the model through the policy
import numpy as np
import torch

from ray.rllib.algorithms.dqn import DQNConfig

algo = (
    DQNConfig()
    .api_stack(
        enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
    )
    .framework("torch")
    .environment("CartPole-v1")
    .training(
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
        }
    )
).build()
model = algo.get_policy().model
# <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>

# List of all model variables
list(model.parameters())

# Run a forward pass to get base model output. Note that complex observations
# must be preprocessed. An example of preprocessing is
# examples/offline_rl/saving_experiences.py
model_out = model({"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]]))})
# (tensor([[...]]), [])  (base model features of shape (1, 256) plus an empty state list)

# Print the base model (all default models have one). Note: the Keras-style summary
# below comes from the TF version of this model; with framework="torch", you get a
# torch module repr.
print(model)
"""
Model: "model"
_______________________________________________________________________
Layer (type)                Output Shape    Param #  Connected to
=======================================================================
observations (InputLayer)   [(None, 4)]     0
_______________________________________________________________________
fc_1 (Dense)                (None, 256)     1280     observations[0][0]
_______________________________________________________________________
fc_out (Dense)              (None, 256)     65792    fc_1[0][0]
_______________________________________________________________________
value_out (Dense)           (None, 1)       257      fc_1[0][0]
=======================================================================
Total params: 67,329
Trainable params: 67,329
Non-trainable params: 0
______________________________________________________________________________
"""

# Access the Q value model (specific to DQN)
print(model.get_q_value_distributions(model_out[0])[0])
# e.g. tensor([[ 0.1302, -0.3681]])  (shape (1, 2), dtype float32)
# ^ exact numbers may differ due to randomness

print(model.advantage_module)

# Access the state value model (specific to DQN)
print(model.get_state_value(model_out[0]))
# e.g. tensor([[0.0938]])  (shape (1, 1), dtype float32)
# ^ exact number may differ due to randomness

print(model.value_module)

This is especially useful when used with custom model classes.

Debugging RLlib Experiments#

Eager Mode#

Policies built with build_tf_policy (which most of the reference algorithms use) can run in eager mode by setting the "framework": "tf2" and "eager_tracing": True config options. This tells RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.

Eager mode makes debugging much easier, since you can now use line-by-line debugging with breakpoints or Python print() to inspect intermediate tensor values. However, eager can be slower than graph mode unless tracing is enabled.
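
For example, a config sketch that enables eager mode with tracing on the old API stack (where build_tf_policy-based policies live):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .framework("tf2", eager_tracing=True)
    .environment("CartPole-v1")
)
algo = config.build()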

Episode Traces#

You can use the data output API to save episode traces for debugging. For example, the following commands run PPO while saving episode traces to /tmp/debug.

cd rllib/tuned_examples/ppo
python cartpole_ppo.py --output /tmp/debug

# episode traces will be saved in /tmp/debug, for example
output-2019-02-23_12-02-03_worker-2_0.json
output-2019-02-23_12-02-04_worker-1_0.json

Log Verbosity#

You can control the log level via the "log_level" flag. Valid values are “DEBUG”, “INFO”, “WARN” (default), and “ERROR”. This can be used to increase or decrease the verbosity of internal logging. For example:

cd rllib/tuned_examples/ppo

python atari_ppo.py --env ALE/Pong-v5 --log-level INFO
python atari_ppo.py --env ALE/Pong-v5 --log-level DEBUG

The default log level is WARN. We strongly recommend using at least INFO level logging for development.
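
If you build your config in Python instead of using the tuned-example scripts, you can also set the log level directly on the config, for example:

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().environment("CartPole-v1").debugging(log_level="INFO")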

Stack Traces#

You can use the ray stack command to dump the stack traces of all the Python workers on a single node. This can be useful for debugging unexpected hangs or performance issues.
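
Run it directly on the node whose workers you want to inspect:

# Dump the stack traces of all Python workers on this node.
ray stack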

Next Steps#

  • To check how your application is doing, you can use the Ray dashboard.