Note
Ray 2.40 uses RLlib’s new API stack by default. The Ray team has mostly completed transitioning algorithms, example scripts, and documentation to the new code base.
If you’re still using the old API stack, see New API stack migration guide for details on how to migrate.
RLlib’s callback APIs#
Overview#
RLlib’s callbacks are the easiest way for you to inject code into your experiments.
In a nutshell, you define the code you want to execute at certain events and pass it to your
AlgorithmConfig
.
Here is an example of defining a simple lambda that prints out an episode’s return after the episode has terminated:
from ray.rllib.algorithms.ppo import PPOConfig

ppo = (
    PPOConfig()
    .environment("CartPole-v1")
    .callbacks(
        on_episode_end=(
            lambda episode, **kw: print(f"Episode done. R={episode.get_return()}")
        )
    )
    .build()
)
ppo.train()
Callback lambdas vs stateful RLlibCallback#
There are two ways to define custom code and have it executed during the various callback events.
Callback lambdas#
If your injected code is rather simple and doesn’t need to store temporary information
for reuse in succeeding event calls, you can use a lambda
and pass it to the callbacks()
method as previously shown.
See here for a complete list of supported callback events.
The names of the events always match the argument names for the
callbacks()
method.
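For instance, here is a small sketch passing lambdas for two different events. The keyword arguments picked out in each lambda (episode, result) are only a subset of what RLlib passes to that event; the rest is absorbed by **kw:

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .callbacks(
        # The argument name matches the event name.
        on_episode_end=lambda episode, **kw: print(f"R={episode.get_return()}"),
        on_train_result=lambda result, **kw: print(f"iter={result['training_iteration']}"),
    )
)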
Stateful RLlibCallback#
If you need your code to be stateful and able to temporarily store results for reuse
in succeeding calls, triggered by the same or a different event, you
need to subclass the RLlibCallback
API and then implement one or more of its methods, for example on_algorithm_init().
Here is the same example, printing out a terminated episode's return, but using
a subclass of RLlibCallback:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.callbacks.callbacks import RLlibCallback

class EpisodeReturn(RLlibCallback):
    def __init__(self):
        super().__init__()
        # Keep some global state in between individual callback events.
        self.overall_sum_of_rewards = 0.0

    def on_episode_end(self, *, episode, **kwargs):
        self.overall_sum_of_rewards += episode.get_return()
        print(f"Episode done. R={episode.get_return()} Global SUM={self.overall_sum_of_rewards}")

ppo = (
    PPOConfig()
    .environment("CartPole-v1")
    .callbacks(EpisodeReturn)
    .build()
)
ppo.train()
Overview of all callback events#
During a training iteration, the Algorithm normally walks through the following event tree. Note
that some of the events in the tree happen simultaneously, on different processes through Ray actors.
For example, an EnvRunner actor may trigger its on_episode_start
event while, at the same time, another
EnvRunner actor triggers its on_sample_end
event and the main Algorithm process triggers
on_train_result
.
Note
Currently, RLlib only invokes callbacks in Algorithm
and EnvRunner
actors.
The Ray team is considering expanding callbacks onto Learner
actors and possibly RLModule
instances as well.
Here is a high-level overview of all supported events in RLlib’s callbacks system.
Algorithm
    .__init__()
        `on_algorithm_init` - After algorithm construction and setup.
    .train()
        `on_train_result` - After a training iteration.
    .evaluate()
        `on_evaluate_start` - Before evaluation starts using the eval ``EnvRunnerGroup``.
        `on_evaluate_end` - After evaluation is finished.
    .restore_from_path()
        `on_checkpoint_loaded` - After a checkpoint's new state has been loaded.

EnvRunner
    .__init__()
        `on_environment_created` - After the RL environment has been created.
    .sample()
        `on_episode_created` - After a new episode object has been created.
        `on_episode_start` - After an episode object has started (after ``env.reset()``).
        `on_episode_step` - After an episode object has stepped (after ``env.step()``).
        `on_episode_end` - After an episode object has terminated (or truncated).
        `on_sample_end` - At the end of the ``EnvRunner.sample()`` call.
Algorithm-bound methods of RLlibCallback:
    `on_algorithm_init` - Callback run when a new Algorithm instance has finished setup.
    `on_evaluate_start` - Callback run before evaluation starts.
    `on_evaluate_end` - Callback run when the evaluation is done.
    `on_env_runners_recreated` - Callback run after one or more EnvRunner actors have been recreated.
    `on_checkpoint_loaded` - Callback run when an Algorithm has loaded a new state from a checkpoint.

EnvRunner-bound methods of RLlibCallback:
    `on_environment_created` - Callback run when a new environment object has been created.
    `on_episode_created` - Callback run when a new episode is created (but has not started yet!).
    `on_episode_start` - Callback run right after an Episode has been started.
    `on_episode_step` - Callback run on each episode step (after the actions have been logged).
    `on_episode_end` - Callback run when an episode is done (after terminated/truncated have been logged).
    `on_sample_end` - Callback run at the end of the ``EnvRunner.sample()`` call.
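To make this split concrete, here is a minimal sketch of a single RLlibCallback subclass that overrides one Algorithm-bound and one EnvRunner-bound event. The class name and print statements are purely illustrative:

from ray.rllib.callbacks.callbacks import RLlibCallback

class WhereDoIRun(RLlibCallback):
    def on_algorithm_init(self, *, algorithm, **kwargs):
        # Fires once, on the main Algorithm process, right after setup.
        print(f"{type(algorithm).__name__} finished setup.")

    def on_episode_start(self, *, episode, **kwargs):
        # Fires on each (possibly remote) EnvRunner actor, right after `env.reset()`.
        print(f"Episode {episode.id_} started.")

Passing this class to the callbacks() method registers it on the Algorithm as well as on all its EnvRunner actors; each method only fires in the process that owns the corresponding event.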
Chaining callbacks#
You can define more than one RLlibCallback
class and send them in a list to the
callbacks()
method.
You can also send lists of callables, instead of a single callable, to the different
arguments of that method.
For example, assume you already have a subclass of RLlibCallback
written and would like to reuse it in different experiments. However, one of your experiments
requires some debug callback code that you would like to inject only temporarily for a couple of runs.
In that case, you can chain the reusable subclass with the temporary debug callables, as sketched below.
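A minimal sketch, reusing the EpisodeReturn class defined earlier; the debug lambda and its print output are illustrative assumptions:

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .callbacks(
        # Reusable, stateful callback subclasses (a list is allowed).
        callbacks_class=[EpisodeReturn],
        # Temporary debug code, chained in only for this experiment.
        on_episode_end=[
            lambda episode, **kw: print(f"debug: episode return={episode.get_return()}"),
        ],
    )
)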
Resolution order of chained callbacks#
RLlib resolves all available callback methods and callables for a given event as follows:
Subclasses of RLlibCallback
take precedence
over individual or lists of callables provided through the various arguments of
the callbacks()
method.
For example, assume the callback event is on_train_result
, which fires at the end of
a training iteration and inside the algorithm’s process.
1. RLlib loops through the list of all given RLlibCallback subclasses and calls their on_train_result method, keeping the exact order the user provided in the list.
2. RLlib then loops through the list of all defined on_train_result callables. The user configured these by calling the callbacks() method and defining the on_train_result argument in this call.
class MyCallbacks(RLlibCallback):
    def on_train_result(self, *, algorithm, metrics_logger, result, **kwargs):
        print("RLlibCallback subclass")

class MyDebugCallbacks(RLlibCallback):
    def on_train_result(self, *, algorithm, metrics_logger, result, **kwargs):
        print("debug subclass")

# Define the callbacks order through the config.
# Subclasses first, then individual `on_train_result` (or other events) callables:
config.callbacks(
    callbacks_class=[MyDebugCallbacks, MyCallbacks],  # <- note: debug class first
    on_train_result=[
        lambda algorithm, **kw: print('lambda 1'),
        lambda algorithm, **kw: print('lambda 2'),
    ],
)

# When training the algorithm, after each training iteration, you should see
# something like:
# > debug subclass
# > RLlibCallback subclass
# > lambda 1
# > lambda 2
Examples#
Here are two examples showing you how to set up custom callbacks on the Algorithm process as well as on the EnvRunner processes.
Example 1: on_train_result#
The following example demonstrates how to implement a simple custom function that writes the replay buffer contents to disk from time to time.
You normally don't want to write the contents of buffers along with your Algorithm checkpoints, so doing this less frequently, in a more controlled fashion through a custom callback, can be a good compromise.
import ormsgpack
from ray.rllib.algorithms.dqn import DQNConfig

def _write_buffer_if_necessary(algorithm, metrics_logger, result):
    # Write the buffer contents only every ith iteration.
    if algorithm.training_iteration % 2 == 0:
        # python dict
        buffer_contents = algorithm.local_replay_buffer.get_state()

        # binary
        msgpacked = ormsgpack.packb(
            buffer_contents,
            option=ormsgpack.OPT_SERIALIZE_NUMPY,
        )

        # Open some file and write the buffer contents into it using `ormsgpack`.
        with open("replay_buffer_contents.msgpack", "wb") as f:
            f.write(msgpacked)

config = (
    DQNConfig()
    .environment("CartPole-v1")
    .callbacks(
        on_train_result=_write_buffer_if_necessary,
    )
)
dqn = config.build()

# Train n times. Expect buffer to be written every ith iteration.
for _ in range(4):
    print(dqn.train())
Tip
See here for the exact call signatures and expected argument types of all available callbacks.
Example 2: on_episode_step and on_episode_end#
The following example demonstrates how to implement a custom RLlibCallback class
that computes the average "first-joint angle" of the
Acrobot-v1 RL environment.
The example utilizes RLlib's MetricsLogger
API to log the custom computations happening in the injected code to your Algorithm's main results system.
Also take a look at this more complex example on how to generate and log a PacMan heatmap (image) to WandB.
import math
import numpy as np
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.callbacks.callbacks import RLlibCallback

class LogAcrobotAngle(RLlibCallback):
    def on_episode_step(self, *, episode, env, **kwargs):
        # Get the first-joint angle from the env. Note that `env` is a vector env and
        # that Acrobot's internal `state` holds [theta1, theta2, theta1_dot, theta2_dot]
        # in radians.
        # See https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/acrobot.py
        # for the env's source code.
        theta1 = env.envs[0].unwrapped.state[0]
        # Normalize through atan2(sin, cos) and convert into degrees.
        deg_theta1 = math.degrees(math.atan2(math.sin(theta1), math.cos(theta1)))

        # Log the theta1 degree value in the episode object, temporarily.
        episode.add_temporary_timestep_data("theta1", deg_theta1)

    def on_episode_end(self, *, episode, metrics_logger, **kwargs):
        # Get all the logged theta1 degree values and average them.
        theta1s = episode.get_temporary_timestep_data("theta1")
        avg_theta1 = np.mean(theta1s)

        # Log the final result - per episode - to the MetricsLogger.
        # Report with a sliding/smoothing window of 50.
        metrics_logger.log_value("theta1_mean", avg_theta1, reduce="mean", window=50)

config = (
    PPOConfig()
    .environment("Acrobot-v1")
    .callbacks(
        callbacks_class=LogAcrobotAngle,
    )
)
ppo = config.build()

# Train n times. Expect `theta1_mean` to be found in the results under:
# `env_runners/theta1_mean`
for i in range(10):
    results = ppo.train()
    print(
        f"iter={i} "
        f"theta1_mean={results['env_runners']['theta1_mean']} "
        f"R={results['env_runners']['episode_return_mean']}"
    )
Tip
You can base your custom logic on whether the calling EnvRunner is a regular “training”
EnvRunner, used to collect training samples, or an evaluation EnvRunner, used to play
through episodes for evaluation only.
Access the env_runner.config.in_evaluation
boolean flag, which is True on
evaluation EnvRunner
actors and False on EnvRunner
actors used to collect
training data.
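For example, a minimal sketch, assuming the env_runner keyword argument is available in on_episode_end as in the new API stack; the class name and print statements are illustrative:

from ray.rllib.callbacks.callbacks import RLlibCallback

class EvalAwareLogging(RLlibCallback):
    def on_episode_end(self, *, episode, env_runner, **kwargs):
        # `in_evaluation` is True only on evaluation EnvRunner actors.
        if env_runner.config.in_evaluation:
            print(f"eval episode done. R={episode.get_return()}")
        else:
            print(f"training episode done. R={episode.get_return()}")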
Tip
See here for the exact call signatures and expected argument types of all available callbacks.