Note

Ray 2.40 uses RLlib’s new API stack by default. The Ray team has mostly completed transitioning algorithms, example scripts, and documentation to the new code base.

If you’re still using the old API stack, see New API stack migration guide for details on how to migrate.

Replay Buffers#

Quick Intro to Replay Buffers in RL#

When we talk about replay buffers in reinforcement learning, we generally mean a buffer that stores and replays experiences collected from interactions of our agent(s) with the environment. In Python, a simple buffer can be implemented as a list to which elements are added and from which they are later sampled. Such buffers are used mostly in off-policy learning algorithms. This makes sense intuitively because these algorithms can learn from experiences that are stored in the buffer but were produced by a previous version of the policy (or even a completely different “behavior policy”).
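As a minimal sketch of the list-based idea, the following toy class appends experiences to a Python list and samples them uniformly with replacement. The name `ListReplayBuffer` and the tuple layout are assumptions made for illustration; this is not RLlib's implementation.

```python
import random


class ListReplayBuffer:
    """Toy replay buffer backed by a plain Python list (illustration only)."""

    def __init__(self):
        self._storage = []

    def add(self, experience):
        # Store one experience, e.g. an (obs, action, reward, next_obs, done) tuple.
        self._storage.append(experience)

    def sample(self, num_items):
        # Draw uniformly at random, with replacement.
        return [random.choice(self._storage) for _ in range(num_items)]

    def __len__(self):
        return len(self._storage)


buffer = ListReplayBuffer()
for step in range(5):
    buffer.add((step, 0, 1.0, step + 1, False))
minibatch = buffer.sample(3)
```

Because sampling is with replacement, the same experience can appear more than once in `minibatch`.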

Sampling Strategy#

When sampling from a replay buffer, we choose which experiences to train our agent with. A straightforward strategy that has proven effective for many algorithms is to pick these samples uniformly at random. A more advanced strategy (proven better in many cases) is Prioritized Experience Replay (PER). In PER, each item in the buffer is assigned a scalar priority value, which denotes its significance, or in simpler terms, how much we expect to learn from it. Experiences with a higher priority are more likely to be sampled.
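One common way to turn priorities into sampling probabilities is proportional prioritization, where item i is drawn with probability proportional to its priority raised to an exponent alpha. The sketch below assumes this variant; the function name `sample_prioritized` and the example priorities are hypothetical and not taken from RLlib.

```python
import random


def sample_prioritized(priorities, num_items, alpha=0.6, rng=None):
    """Return indices drawn with probability proportional to priority ** alpha."""
    rng = rng or random.Random()
    weights = [p ** alpha for p in priorities]
    return rng.choices(range(len(priorities)), weights=weights, k=num_items)


priorities = [0.1, 0.1, 5.0, 0.1]
indices = sample_prioritized(priorities, num_items=1000, rng=random.Random(0))
# Count how often each index was drawn; index 2 has by far the highest priority.
counts = [indices.count(i) for i in range(len(priorities))]
```

With `alpha=0`, all weights become 1 and the scheme falls back to uniform sampling; larger values of `alpha` favor high-priority items more strongly.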

Eviction Strategy#

A buffer is naturally limited in its capacity to hold experiences. In the course of running an algorithm, a buffer eventually reaches its capacity, and to make room for new experiences, we need to delete (evict) older ones. This is generally done on a first-in-first-out basis. For your algorithms, this means that buffers with a high capacity give the opportunity to learn from older samples, while smaller buffers make the learning process more on-policy. An exception to this strategy is made in buffers that implement reservoir sampling.
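The two eviction strategies can be contrasted in a few lines. This is a sketch under simple assumptions: experiences are plain integers, the FIFO buffer is a `collections.deque`, and the reservoir variant uses the classic "Algorithm R". The helper `reservoir_add` is a hypothetical name, not an RLlib function.

```python
import random
from collections import deque

# FIFO eviction: a deque with maxlen drops the oldest item once it is full.
fifo = deque(maxlen=3)
for experience in range(5):
    fifo.append(experience)
# fifo now holds the three most recent experiences: 2, 3, 4


def reservoir_add(storage, item, num_seen, capacity, rng):
    """Offer one item to a reservoir; num_seen counts the items offered before it."""
    if len(storage) < capacity:
        storage.append(item)
    else:
        # Keep the new item with probability capacity / (num_seen + 1),
        # overwriting a uniformly chosen slot.
        slot = rng.randint(0, num_seen)
        if slot < capacity:
            storage[slot] = item


rng = random.Random(0)
reservoir = []
for num_seen, experience in enumerate(range(100)):
    reservoir_add(reservoir, experience, num_seen, capacity=3, rng=rng)
```

With reservoir sampling, every experience seen so far has the same chance of being in the buffer, so old samples are not systematically evicted first.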

Replay Buffers in RLlib#

RLlib comes with a set of extendable replay buffers built in. All of them support the two basic methods add() and sample(). We provide a base ReplayBuffer class from which you can build your own buffer. Most algorithms require MultiAgentReplayBuffers because we want them to generalize to the multi-agent case. Therefore, these buffers’ add() and sample() methods require a policy_id to handle experiences per policy. Have a look at the MultiAgentReplayBuffer to get a sense of how it extends our base class. You can find the buffer types and the arguments that modify their behaviour in RLlib’s default parameters, where they are part of the replay_buffer_config.
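To illustrate the per-policy idea only, here is a toy sketch that keeps one list per policy ID. The class `ToyMultiAgentBuffer` and its method signatures are assumptions made for this example; RLlib's MultiAgentReplayBuffer is considerably more involved and its exact signatures may differ.

```python
import random
from collections import defaultdict


class ToyMultiAgentBuffer:
    """Keeps one list-based buffer per policy ID (illustration only)."""

    def __init__(self):
        self._buffers = defaultdict(list)

    def add(self, experience, policy_id):
        # Route the experience to the buffer of the policy that produced it.
        self._buffers[policy_id].append(experience)

    def sample(self, num_items, policy_id):
        # Only experiences collected for this policy are candidates.
        return random.choices(self._buffers[policy_id], k=num_items)


buffer = ToyMultiAgentBuffer()
buffer.add({"obs": 1}, policy_id="policy_a")
buffer.add({"obs": 2}, policy_id="policy_b")
batch = buffer.sample(2, policy_id="policy_a")
```

Here `batch` only contains experiences of `policy_a`, even though the buffer also holds data for `policy_b`.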

Basic Usage#

You will rarely have to define your own replay buffer subclass when running an experiment; more often, you configure an existing buffer. The following script is from RLlib’s examples section and runs DQN with an LSTM model and PER. The replay_buffer_config dictionary is where PER is configured.

Executable example script
"""Simple example of how to modify replay buffer behaviour.

We modify DQN to utilize prioritized replay by supplying it with the
MultiAgentPrioritizedReplayBuffer instead of the standard MultiAgentReplayBuffer.
This is possible because DQN uses the DQN training iteration function,
which includes a priority update, given that a fitting buffer is provided.
"""

import argparse

import ray
from ray import air, tune
from ray.air.constants import TRAINING_ITERATION
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.metrics import NUM_ENV_STEPS_SAMPLED_LIFETIME
from ray.rllib.utils.replay_buffers.replay_buffer import StorageUnit

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()

parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--stop-iters", type=int, default=50, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    # This is where we add prioritized experience replay.
    # The training iteration function that is used by DQN already includes a priority
    # update step.
    replay_buffer_config = {
        "type": "MultiAgentPrioritizedReplayBuffer",
        # Although not necessary, we can modify the default constructor args of
        # the replay buffer here
        "prioritized_replay_alpha": 0.5,
        "storage_unit": StorageUnit.SEQUENCES,
        "replay_burn_in": 20,
        "zero_init_states": True,
    }

    config = (
        DQNConfig()
        .environment("CartPole-v1")
        .framework(framework=args.framework)
        .env_runners(num_env_runners=4)
        .training(
            model=dict(use_lstm=True, lstm_cell_size=64, max_seq_len=20),
            replay_buffer_config=replay_buffer_config,
        )
    )

    stop_config = {
        NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps,
        TRAINING_ITERATION: args.stop_iters,
    }

    results = tune.Tuner(
        config.algo_class,
        param_space=config,
        run_config=air.RunConfig(stop=stop_config),
    ).fit()

    ray.shutdown()

Tip

Because of its prevalence, most Q-learning algorithms support PER. The priority update step that is needed is embedded into their training iteration functions.
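A priority update step typically sets the priority of each sampled item from its latest TD error. The following sketch assumes the common choice of the absolute TD error plus a small constant; the function `update_priorities` and the `eps` value are hypothetical and do not mirror RLlib's internal code.

```python
def update_priorities(priorities, sampled_indices, td_errors, eps=1e-6):
    """Overwrite the priorities of the sampled items in place."""
    for idx, td_error in zip(sampled_indices, td_errors):
        # A larger error suggests there is more to learn from this item.
        # eps keeps every priority strictly positive.
        priorities[idx] = abs(td_error) + eps


priorities = [1.0, 1.0, 1.0]
update_priorities(priorities, sampled_indices=[0, 2], td_errors=[0.5, -2.0])
```

Items that were not sampled (index 1 here) keep their previous priority.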

Warning

If your custom buffer requires extra interaction, you will have to change the training iteration function, too!

Specifying a buffer type works the same way as specifying an exploration type. Here are three ways of specifying a type:

Changing a replay buffer configuration
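from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.utils.replay_buffers import ReplayBuffer
from ray.rllib.utils.replay_buffers.utils import validate_buffer_config

config = (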
    DQNConfig()
    .api_stack(
        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
    )
    .training(replay_buffer_config={"type": ReplayBuffer})
)

another_config = (
    DQNConfig()
    .api_stack(
        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
    )
    .training(replay_buffer_config={"type": "ReplayBuffer"})
)


yet_another_config = (
    DQNConfig()
    .api_stack(
        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
    )
    .training(
        replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"}
    )
)

validate_buffer_config(config)
validate_buffer_config(another_config)
validate_buffer_config(yet_another_config)

# After validation, all three configs yield the same effective config
assert (
    config.replay_buffer_config
    == another_config.replay_buffer_config
    == yet_another_config.replay_buffer_config
)

Apart from the type, you can also specify the capacity and other parameters. These parameters are mostly constructor arguments for the buffer. The following categories exist:

  1. Parameters that define how algorithms interact with replay buffers.

    e.g. worker_side_prioritization to decide where to compute priorities

  2. Constructor arguments to instantiate the replay buffer.

    e.g. capacity to limit the buffer’s size

  3. Call arguments for underlying replay buffer methods.

    e.g. prioritized_replay_beta is used by the MultiAgentPrioritizedReplayBuffer to call the sample() method of every underlying PrioritizedReplayBuffer

Tip

Most of the time, only 1. and 2. are of interest. 3. is an advanced feature that supports use cases where a MultiAgentReplayBuffer instantiates underlying buffers that need constructor or default call arguments.
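As a sketch, a single replay_buffer_config can mix all three categories. The keys are the ones named in the list above; the values are illustrative assumptions, not recommended settings or RLlib defaults.

```python
replay_buffer_config = {
    "type": "MultiAgentPrioritizedReplayBuffer",
    # 1. How the algorithm interacts with the buffer.
    "worker_side_prioritization": False,
    # 2. Constructor argument of the buffer.
    "capacity": 50_000,
    # 3. Call argument forwarded to sample() of the underlying buffers.
    "prioritized_replay_beta": 0.4,
}
```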

ReplayBuffer Base Class#

The base ReplayBuffer class only supports storing and replaying experiences in different StorageUnits. You can add data to the buffer’s storage with the add() method and replay it with the sample() method. Advanced buffer types add functionality while trying to retain compatibility through inheritance. The following is an example of the most basic scheme of interaction with a ReplayBuffer.

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.replay_buffers import ReplayBuffer, StorageUnit

# We choose fragments because this unit doesn't restrict the batches we can add
buffer = ReplayBuffer(capacity=2, storage_unit=StorageUnit.FRAGMENTS)
dummy_batch = SampleBatch({"a": [1], "b": [2]})
buffer.add(dummy_batch)
buffer.sample(2)
# Because elements can be sampled multiple times, we receive a concatenated version
# of dummy_batch: `{a: [1, 1], b: [2, 2]}`.

Building your own ReplayBuffer#

Here is a toy example of how to implement your own ReplayBuffer class and make DQN use it:

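import random
from typing import Optional

import numpy as np
from ray import air, tune
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.replay_buffers import ReplayBuffer
from ray.rllib.utils.typing import SampleBatchType


class LessSampledReplayBuffer(ReplayBuffer):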
    @override(ReplayBuffer)
    def sample(
        self, num_items: int, evict_sampled_more_then: int = 30, **kwargs
    ) -> Optional[SampleBatchType]:
        """Evicts experiences that have been sampled > evict_sampled_more_then times."""
        idxes = [random.randint(0, len(self) - 1) for _ in range(num_items)]
        often_sampled_idxes = list(
            filter(lambda x: self._hit_count[x] >= evict_sampled_more_then, set(idxes))
        )

        sample = self._encode_sample(idxes)
        self._num_timesteps_sampled += sample.count

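        # Delete from the highest index down so that earlier deletions don't
        # shift the positions of indices that are still to be removed.
        for idx in sorted(often_sampled_idxes, reverse=True):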
            del self._storage[idx]
            self._hit_count = np.append(
                self._hit_count[:idx], self._hit_count[idx + 1 :]
            )

        return sample


config = (
    DQNConfig()
    .api_stack(
        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
    )
    .environment(env="CartPole-v1")
    .training(replay_buffer_config={"type": LessSampledReplayBuffer})
)

tune.Tuner(
    "DQN",
    param_space=config,
    run_config=air.RunConfig(
        stop={"training_iteration": 1},
    ),
).fit()

For a full implementation, you should consider other methods like get_state() and set_state(). A more extensive example is our implementation of reservoir sampling, the ReservoirReplayBuffer.
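The point of get_state() and set_state() is that any extra bookkeeping a custom buffer keeps has to be saved and restored along with the stored data. The following toy sketch illustrates this with a per-item hit counter. The class `CountingBuffer` and its state layout are assumptions for illustration and don't reproduce the state format of RLlib's ReplayBuffer.

```python
import copy


class CountingBuffer:
    """Toy buffer that tracks how often each item was sampled (illustration only)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self._storage = []
        self._hit_count = []

    def add(self, item):
        if len(self._storage) >= self.capacity:
            # FIFO eviction of the oldest item and its counter.
            self._storage.pop(0)
            self._hit_count.pop(0)
        self._storage.append(item)
        self._hit_count.append(0)

    def get_state(self):
        # Copy everything needed to rebuild the buffer, including custom counters.
        return {
            "capacity": self.capacity,
            "storage": copy.deepcopy(self._storage),
            "hit_count": list(self._hit_count),
        }

    def set_state(self, state):
        self.capacity = state["capacity"]
        self._storage = copy.deepcopy(state["storage"])
        self._hit_count = list(state["hit_count"])


original = CountingBuffer(capacity=2)
for item in ("a", "b", "c"):
    original.add(item)
original._hit_count[0] = 5  # pretend "b" was sampled five times

restored = CountingBuffer(capacity=1)
restored.set_state(original.get_state())
```

After `set_state()`, `restored` has the same capacity, items, and hit counts as `original`, so eviction decisions continue where they left off.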

Advanced Usage#

In RLlib, all replay buffers implement the ReplayBuffer interface. Therefore, they support, whenever possible, different StorageUnits. The storage_unit constructor argument of a replay buffer defines how experiences are stored, and therefore the unit in which they are sampled. When later calling the sample() method, num_items will relate to said storage_unit.
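To make the effect of the storage unit concrete, the following rough sketch splits one batch of timesteps into the items a buffer would store under three units. The string names stand in for StorageUnit members, and `split_into_units` and the `eps_id`/`done` keys are hypothetical. RLlib's actual slicing logic may differ in detail.

```python
from itertools import groupby


def split_into_units(steps, storage_unit):
    """Split a list of timestep dicts into the items a buffer would store."""
    if storage_unit == "timesteps":
        return [[step] for step in steps]
    if storage_unit == "fragments":
        # The batch is stored as one item, exactly as it was added.
        return [list(steps)]
    if storage_unit == "episodes":
        episodes = [list(g) for _, g in groupby(steps, key=lambda s: s["eps_id"])]
        # Keep only episodes that are complete, i.e. whose last step is terminal.
        return [ep for ep in episodes if ep[-1]["done"]]
    raise ValueError(f"Unknown storage unit: {storage_unit}")


steps = [
    {"eps_id": 0, "t": 0, "done": False},
    {"eps_id": 0, "t": 1, "done": False},
    {"eps_id": 0, "t": 2, "done": True},
    {"eps_id": 1, "t": 0, "done": False},
    {"eps_id": 1, "t": 1, "done": False},
]
units_per_mode = {
    mode: len(split_into_units(steps, mode))
    for mode in ("timesteps", "fragments", "episodes")
}
```

The same five timesteps become five items, one item, or one item (the single complete episode) depending on the unit, which is why `num_items` in sample() has to be read relative to the storage unit.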

Here is a full example of how to modify the storage_unit and interact with a custom buffer:

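from ray.rllib.examples.envs.classes.random_env import RandomEnv
from ray.rllib.policy.sample_batch import SampleBatch, concat_samples
from ray.rllib.utils.replay_buffers import StorageUnit

# This line makes our buffer store only complete episodes found in a batch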
config.training(replay_buffer_config={"storage_unit": StorageUnit.EPISODES})

less_sampled_buffer = LessSampledReplayBuffer(**config.replay_buffer_config)

# Gather some random experiences
env = RandomEnv()
terminated = truncated = False
batch = SampleBatch({})
t = 0
while not terminated and not truncated:
    obs, reward, terminated, truncated, info = env.step([0, 0])
    # Note that in order for RLlib to find out about start and end of an episode,
    # "t" and "terminateds" have to properly mark an episode's trajectory
    one_step_batch = SampleBatch(
        {
            "obs": [obs],
            "t": [t],
            "reward": [reward],
            "terminateds": [terminated],
            "truncateds": [truncated],
        }
    )
    batch = concat_samples([batch, one_step_batch])
    t += 1

less_sampled_buffer.add(batch)
for i in range(10):
    assert len(less_sampled_buffer._storage) == 1
    less_sampled_buffer.sample(num_items=1, evict_sampled_more_then=9)

assert len(less_sampled_buffer._storage) == 0

As noted above, RLlib’s MultiAgentReplayBuffers support modification of underlying replay buffers. Under the hood, the MultiAgentReplayBuffer stores experiences per policy in separate underlying replay buffers. You can modify their behaviour by specifying an underlying replay_buffer_config that works the same way as the parent’s config.

Here is an example of how to create a MultiAgentReplayBuffer with an alternative underlying ReplayBuffer. The MultiAgentReplayBuffer can stay the same. We only need to specify our own buffer along with a default call argument:

config = (
    DQNConfig()
    .api_stack(
        enable_env_runner_and_connector_v2=False, enable_rl_module_and_learner=False
    )
    .training(
        replay_buffer_config={
            "type": "MultiAgentReplayBuffer",
            "underlying_replay_buffer_config": {
                "type": LessSampledReplayBuffer,
                # We can specify the default call argument for the sample()
                # method of the underlying buffer here.
                "evict_sampled_more_then": 20,
            },
        }
    )
    .environment(env="CartPole-v1")
)

tune.Tuner(
    "DQN",
    param_space=config.to_dict(),
    run_config=air.RunConfig(
        stop={"env_runners/episode_return_mean": 40, "training_iteration": 7},
    ),
).fit()