
Algorithms¶
Tip
Check out the environments page to learn more about different environment types.
Available Algorithms - Overview¶
RLlib’s algorithms are characterized along the following dimensions: supported frameworks (tf, torch), discrete-action support (including parametric action spaces), continuous-action support, multi-agent support, model support, and multi-GPU support. In addition to the general-purpose algorithms, there are multi-agent-only methods as well as exploration-based plug-ins that can be combined with any algorithm. See the per-algorithm sections below for details.
High-throughput architectures¶
Distributed Prioritized Experience Replay (Ape-X)¶
[paper]
[implementation]
Ape-X variations of DQN and DDPG (APEX_DQN, APEX_DDPG) use a single GPU learner and many CPU workers for experience collection. Experience collection can scale to hundreds of CPU workers due to the distributed prioritization of experience prior to storage in replay buffers.
Ape-X architecture¶
Tuned examples: PongNoFrameskip-v4, Pendulum-v1, MountainCarContinuous-v0, {BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4.
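As a minimal usage sketch (separate from the tuned examples above, with illustrative rather than tuned values), APEX-DQN can be launched through Ray Tune under its registered name "APEX":
# Minimal sketch: many CPU rollout workers feeding a single GPU learner.
# Environment, worker count, and stopping criterion are illustrative values.
from ray import tune

tune.run(
    "APEX",
    stop={"timesteps_total": 10_000_000},
    config={
        "env": "PongNoFrameskip-v4",
        "num_workers": 32,  # CPU workers collecting (and prioritizing) experience
        "num_gpus": 1,      # single learner GPU
    },
)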
Atari results @10M steps: more details
Atari env | RLlib Ape-X 8-workers | Mnih et al Async DQN 16-workers
---|---|---
BeamRider | 6134 | ~6000
Breakout | 123 | ~50
Qbert | 15302 | ~1200
SpaceInvaders | 686 | ~600
Scalability:
Atari env | RLlib Ape-X 8-workers @1 hour | Mnih et al Async DQN 16-workers @1 hour
---|---|---
BeamRider | 4873 | ~1000
Breakout | 77 | ~10
Qbert | 4083 | ~500
SpaceInvaders | 646 | ~300

Ape-X using 32 workers in RLlib vs vanilla DQN (orange) and A3C (blue) on PongNoFrameskip-v4.¶
Ape-X specific configs (see also common configs):
# APEX-DQN settings overriding DQN ones:
# .training()
self.optimizer = merge_dicts(
DQNConfig().optimizer, {
"max_weight_sync_delay": 400,
"num_replay_buffer_shards": 4,
"debug": False
})
self.n_step = 3
self.train_batch_size = 512
self.target_network_update_freq = 500000
self.training_intensity = 1
# Max number of in-flight requests to each sampling worker;
# see the AsyncRequestsManager class for more details.
# Tuning these values is important when running experiments with large sample
# batches: if the sample batches are large, there is the risk that the object
# store fills up, causing it to spill objects to disk. This can make any
# asynchronous requests very slow, slowing down your experiment. You can
# inspect the object store during your experiment via a call to `ray memory`
# on your head node, and by using the Ray dashboard. If you see the object
# store filling up, turn down the number of remote requests in flight, or
# enable compression of the collected timesteps in your experiment.
self.max_requests_in_flight_per_sampler_worker = 2
self.max_requests_in_flight_per_replay_worker = float("inf")
self.timeout_s_sampler_manager = 0.0
self.timeout_s_replay_manager = 0.0
# APEX-DQN is using a distributed (non local) replay buffer.
self.replay_buffer_config = {
"no_local_replay_buffer": True,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 2000000,
# Alpha parameter for prioritized replay buffer.
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
"learning_starts": 50000,
# Whether all shards of the replay buffer must be co-located
# with the learner process (running the execution plan).
# This is preferred b/c the learner process should have quick
# access to the data from the buffer shards, avoiding network
# traffic each time samples from the buffer(s) are drawn.
# Set this to False for relaxing this constraint and allowing
# replay shards to be created on node(s) other than the one
# on which the learner is located.
"replay_buffer_shards_colocated_with_driver": True,
"worker_side_prioritization": True,
# Deprecated key.
"prioritized_replay": DEPRECATED_VALUE,
}
# .rollouts()
self.num_workers = 32
self.rollout_fragment_length = 50
self.exploration_config = {
"type": "PerWorkerEpsilonGreedy",
}
# .resources()
self.num_gpus = 1
# .reporting()
self.min_time_s_per_iteration = 30
self.min_sample_timesteps_per_iteration = 25000
# fmt: on
Importance Weighted Actor-Learner Architecture (IMPALA)¶
[paper]
[implementation]
In IMPALA, a central learner runs SGD in a tight loop while asynchronously pulling sample batches from many actor processes. RLlib’s IMPALA implementation uses DeepMind’s reference V-trace code. Note that we do not provide a deep residual network out of the box, but one can be plugged in as a custom model. Multiple learner GPUs and experience replay are also supported.
IMPALA architecture¶
Tuned examples: PongNoFrameskip-v4, vectorized configuration, multi-gpu configuration, {BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4
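Since IMPALA supports multiple learner GPUs, a minimal sketch of a multi-GPU setup could look as follows (the config-class method names follow recent RLlib releases and may differ slightly between versions; all values are illustrative):
# Minimal sketch: IMPALA with many actor processes and a two-GPU learner.
from ray.rllib.algorithms.impala import ImpalaConfig

config = (
    ImpalaConfig()
    .environment("BreakoutNoFrameskip-v4")
    .rollouts(num_rollout_workers=128)  # actor processes producing sample batches
    .resources(num_gpus=2)              # GPUs used by the central learner
)
algo = config.build()
print(algo.train())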
Atari results @10M steps: more details
Atari env | RLlib IMPALA 32-workers | Mnih et al A3C 16-workers
---|---|---
BeamRider | 2071 | ~3000
Breakout | 385 | ~150
Qbert | 4068 | ~1000
SpaceInvaders | 719 | ~600
Scalability:
Atari env | RLlib IMPALA 32-workers @1 hour | Mnih et al A3C 16-workers @1 hour
---|---|---
BeamRider | 3181 | ~1000
Breakout | 538 | ~10
Qbert | 10850 | ~500
SpaceInvaders | 843 | ~300

Multi-GPU IMPALA scales up to solve PongNoFrameskip-v4 in ~3 minutes using a pair of V100 GPUs and 128 CPU workers. The maximum training throughput reached is ~30k transitions per second (~120k environment frames per second).¶
IMPALA-specific configs (see also common configs):
# IMPALA specific settings:
self.vtrace = True
self.vtrace_clip_rho_threshold = 1.0
self.vtrace_clip_pg_rho_threshold = 1.0
self.vtrace_drop_last_ts = True
self.num_multi_gpu_tower_stacks = 1
self.minibatch_buffer_size = 1
self.num_sgd_iter = 1
self.replay_proportion = 0.0
self.replay_ratio = (
    (1 / self.replay_proportion) if self.replay_proportion > 0 else 0.0
)
self.replay_buffer_num_slots = 0
self.learner_queue_size = 16
self.learner_queue_timeout = 300
self.max_requests_in_flight_per_sampler_worker = 2
self.max_requests_in_flight_per_aggregator_worker = 2
self.timeout_s_sampler_manager = 0.0
self.timeout_s_aggregator_manager = 0.0
self.broadcast_interval = 1
self.num_aggregation_workers = 0
self.grad_clip = 40.0
self.opt_type = "adam"
self.lr_schedule = None
self.decay = 0.99
self.momentum = 0.0
self.epsilon = 0.1
self.vf_loss_coeff = 0.5
self.entropy_coeff = 0.01
self.entropy_coeff_schedule = None
self._separate_vf_optimizer = False
self._lr_vf = 0.0005
self.after_train_step = None
# Override some of AlgorithmConfig's default values with IMPALA-specific values.
self.rollout_fragment_length = 50
self.train_batch_size = 500
self.num_workers = 2
self.num_gpus = 1
self.lr = 0.0005
self.min_time_s_per_iteration = 10
Asynchronous Proximal Policy Optimization (APPO)¶
[paper]
[implementation]
We include an asynchronous variant of Proximal Policy Optimization (PPO) based on the IMPALA architecture. This is similar to IMPALA but using a surrogate policy loss with clipping. Compared to synchronous PPO, APPO is more efficient in wall-clock time due to its use of asynchronous sampling. Using a clipped loss also allows for multiple SGD passes, and therefore the potential for better sample efficiency compared to IMPALA. V-trace can also be enabled to correct for off-policy samples.
Tip
APPO is not always more efficient; it is often better to use standard PPO or IMPALA.
APPO architecture (same as IMPALA)¶
Tuned examples: PongNoFrameskip-v4
APPO-specific configs (see also common configs):
# APPO specific settings:
self.vtrace = True
self.use_critic = True
self.use_gae = True
self.lambda_ = 1.0
self.clip_param = 0.4
self.use_kl_loss = False
self.kl_coeff = 1.0
self.kl_target = 0.01
# Override some of ImpalaConfig's default values with APPO-specific values.
self.rollout_fragment_length = 50
self.train_batch_size = 500
self.min_time_s_per_iteration = 10
self.num_workers = 2
self.num_gpus = 0
self.num_multi_gpu_tower_stacks = 1
self.minibatch_buffer_size = 1
self.num_sgd_iter = 1
self.replay_proportion = 0.0
self.replay_buffer_num_slots = 100
self.learner_queue_size = 16
self.learner_queue_timeout = 300
self.max_sample_requests_in_flight_per_worker = 2
self.broadcast_interval = 1
self.grad_clip = 40.0
self.opt_type = "adam"
self.lr = 0.0005
self.lr_schedule = None
self.decay = 0.99
self.momentum = 0.0
self.epsilon = 0.1
self.vf_loss_coeff = 0.5
self.entropy_coeff = 0.01
self.entropy_coeff_schedule = None
Decentralized Distributed Proximal Policy Optimization (DD-PPO)¶
[paper]
[implementation]
Unlike APPO or PPO, with DD-PPO policy improvement is no longer done centrally in the Algorithm process. Instead, gradients are computed remotely on each rollout worker and all-reduced at each mini-batch using torch distributed. This allows each worker’s GPU to be used both for sampling and for training.
Tip
DD-PPO is best for envs that require GPUs to function, or if you need to scale out SGD to multiple nodes. If you don’t meet these requirements, standard PPO will be more efficient.
DD-PPO architecture (both sampling and learning are done on worker GPUs)¶
Tuned examples: CartPole-v0, BreakoutNoFrameskip-v4
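A minimal sketch matching the description above (sampling and SGD both run on the workers’ GPUs, no GPU for the driver); all values are illustrative:
# Minimal sketch: DD-PPO gives each rollout worker its own GPU and keeps the
# driver GPU-free. Values are illustrative, not tuned.
from ray import tune

tune.run(
    "DDPPO",
    config={
        "env": "BreakoutNoFrameskip-v4",
        "framework": "torch",      # DD-PPO requires PyTorch distributed
        "num_workers": 4,          # each worker samples *and* trains
        "num_gpus_per_worker": 1,  # one GPU per rollout worker
        "num_gpus": 0,             # learning no longer happens on the driver
    },
)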
DDPPO-specific configs (see also common configs):
# DD-PPO specific settings:
self.keep_local_weights_in_sync = True
self.torch_distributed_backend = "gloo"
# Override some of PPO/Algorithm's default values with DDPPO-specific values.
# During the sampling phase, each rollout worker will collect a batch of
# `rollout_fragment_length * num_envs_per_worker` steps in size.
self.rollout_fragment_length = 100
# Vectorize the env (enabled by default, since each worker has a GPU).
self.num_envs_per_worker = 5
# During the SGD phase, workers iterate over minibatches of this size.
# The effective minibatch size will be:
# `sgd_minibatch_size * num_workers`.
self.sgd_minibatch_size = 50
# Number of SGD epochs per optimization round.
self.num_sgd_iter = 10
# *** WARNING: configs below are DDPPO overrides over PPO; you
# shouldn't need to adjust them. ***
# DDPPO requires PyTorch distributed.
self.framework_str = "torch"
# Learning is no longer done on the driver process, so
# giving GPUs to the driver does not make sense!
self.num_gpus = 0
# Each rollout worker gets a GPU.
self.num_gpus_per_worker = 1
# This is auto set based on sample batch size.
self.train_batch_size = -1
# The KL divergence penalty should be fixed to 0 in DDPPO, because using
# it as a penalty would require un-decentralizing DDPPO.
self.kl_coeff = 0.0
self.kl_target = 0.0
Gradient-based¶
Advantage Actor-Critic (A2C)¶
[paper] [implementation]
A2C scales to 16-32+ worker processes depending on the environment. It also supports microbatching (i.e., gradient accumulation), which can be enabled by setting the microbatch_size config. Microbatching allows training with a train_batch_size much larger than what fits into GPU memory.
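For example, a minimal sketch of enabling microbatching (illustrative values): gradients are accumulated over 100-step microbatches and applied once a full 1000-step train batch has been processed:
# Minimal sketch: A2C with gradient accumulation over microbatches.
from ray import tune

tune.run(
    "A2C",
    config={
        "env": "CartPole-v1",
        "train_batch_size": 1000,
        "microbatch_size": 100,  # gradient-accumulation chunk (typically a divisor of train_batch_size)
    },
)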
A2C architecture¶
Tuned examples: Atari environments
Tip
Consider using IMPALA for faster training with similar timestep efficiency.
Atari results @10M steps: more details
Atari env | RLlib A2C 5-workers | Mnih et al A3C 16-workers
---|---|---
BeamRider | 1401 | ~3000
Breakout | 374 | ~150
Qbert | 3620 | ~1000
SpaceInvaders | 692 | ~600
A2C-specific configs (see also common configs):
# A2C specific settings:
self.microbatch_size = None
# Override some of A3CConfig's default values with A2C-specific values.
self.rollout_fragment_length = 20
self.sample_async = False
self.min_time_s_per_iteration = 10
Asynchronous Advantage Actor-Critic (A3C)¶
[paper] [implementation]
A3C is the asynchronous version of A2C, where gradients are computed on the workers directly after trajectory rollouts,
and only then shipped to a central learner to accumulate these gradients on the central model. After the central model update, parameters are broadcast back to
all workers.
Similar to A2C, A3C scales to 16-32+ worker processes depending on the environment.
Tuned examples: PongDeterministic-v4
Tip
Consider using IMPALA for faster training with similar timestep efficiency.
A3C-specific configs (see also common configs):
# A3C specific settings:
self.use_critic = True
self.use_gae = True
self.lambda_ = 1.0
self.grad_clip = 40.0
self.lr_schedule = None
self.vf_loss_coeff = 0.5
self.entropy_coeff = 0.01
self.entropy_coeff_schedule = None
self.sample_async = True
# Override some of AlgorithmConfig's default values with A3C-specific values.
self.rollout_fragment_length = 10
self.lr = 0.0001
# Min time (in seconds) per reporting.
# This causes not every call to `training_iteration` to be reported,
# but to wait until n seconds have passed and then to summarize the
# thus far collected results.
self.min_time_s_per_iteration = 5
Deep Deterministic Policy Gradients (DDPG)¶
[paper]
[implementation]
DDPG is implemented similarly to DQN (below). The algorithm can be scaled by increasing the number of workers or using Ape-X.
The improvements from TD3 are available as TD3.
DDPG architecture (same as DQN)¶
Tuned examples: Pendulum-v1, MountainCarContinuous-v0, HalfCheetah-v2.
DDPG-specific configs (see also common configs):
# DDPG-specific settings.
self.twin_q = False
self.policy_delay = 1
self.smooth_target_policy = False
self.target_noise = 0.2
self.target_noise_clip = 0.5
self.use_state_preprocessor = False
self.actor_hiddens = [400, 300]
self.actor_hidden_activation = "relu"
self.critic_hiddens = [400, 300]
self.critic_hidden_activation = "relu"
self.n_step = 1
self.training_intensity = None
self.critic_lr = 1e-3
self.actor_lr = 1e-3
self.tau = 0.002
self.use_huber = False
self.huber_threshold = 1.0
self.l2_reg = 1e-6
# Override some of SimpleQ's default values with DDPG-specific values.
# .exploration()
self.exploration_config = {
# DDPG uses OrnsteinUhlenbeck (stateful) noise to be added to NN-output
# actions (after a possible pure random phase of n timesteps).
"type": "OrnsteinUhlenbeckNoise",
# For how many timesteps should we return completely random actions,
# before we start adding (scaled) noise?
"random_timesteps": 1000,
# The OU-base scaling factor to always apply to action-added noise.
"ou_base_scale": 0.1,
# The OU theta param.
"ou_theta": 0.15,
# The OU sigma param.
"ou_sigma": 0.2,
# The initial noise scaling factor.
"initial_scale": 1.0,
# The final noise scaling factor.
"final_scale": 0.02,
# Timesteps over which to anneal scale (from initial to final values).
"scale_timesteps": 10000,
}
# Common DDPG buffer parameters.
self.replay_buffer_config = {
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 50000,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Alpha parameter for prioritized replay buffer.
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
}
# .training()
self.grad_clip = None
self.train_batch_size = 256
self.target_network_update_freq = 0
# .rollouts()
self.rollout_fragment_length = 1
self.compress_observations = False
Twin Delayed DDPG (TD3)¶
[paper]
[implementation]
TD3 represents an improvement over DDPG. Its implementation is available in RLlib as TD3.
Tuned examples: TD3 Pendulum-v1, TD3 InvertedPendulum-v2, TD3 Mujoco suite (Ant-v2, HalfCheetah-v2, Hopper-v2, Walker2d-v2).
TD3-specific configs (see also common configs):
# Override some of DDPG/SimpleQ/Algorithm's default values with TD3-specific
# values.
# .training()
# largest changes: twin Q functions, delayed policy updates, target
# smoothing, no l2-regularization.
self.twin_q = True
self.policy_delay = 2
self.smooth_target_policy = True
self.l2_reg = 0.0
# Different tau (affecting target network update).
self.tau = 5e-3
# Different batch size.
self.train_batch_size = 100
# No prioritized replay by default (we may want to change this at some
# point).
self.replay_buffer_config = {
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": 1000000,
"learning_starts": 10000,
"worker_side_prioritization": False,
}
# .exploration()
# TD3 uses Gaussian Noise by default.
self.exploration_config = {
# TD3 uses simple Gaussian noise on top of deterministic NN-output
# actions (after a possible pure random phase of n timesteps).
"type": "GaussianNoise",
# For how many timesteps should we return completely random
# actions, before we start adding (scaled) noise?
"random_timesteps": 10000,
# Gaussian stddev of action noise for exploration.
"stddev": 0.1,
# Scaling settings by which the Gaussian noise is scaled before
# being added to the actions. NOTE: The scale timesteps start only
# after(!) any random steps have been finished.
# By default, do not anneal over time (fixed 1.0).
"initial_scale": 1.0,
"final_scale": 1.0,
"scale_timesteps": 1,
}
Deep Q Networks (DQN, Rainbow, Parametric DQN)¶
[paper] [implementation]
DQN can be scaled by increasing the number of workers or using Ape-X. Memory usage is reduced by compressing samples in the replay buffer with LZ4. All of the DQN improvements evaluated in Rainbow are available, though not all are enabled by default. See also how to use parametric-actions in DQN.
DQN architecture¶
Tuned examples: PongDeterministic-v4, Rainbow configuration, {BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4, with Dueling and Double-Q, with Distributional DQN.
Tip
Consider using Ape-X for faster training with similar timestep efficiency.
Hint
For a complete Rainbow setup, make the following changes to the default DQN config:
"n_step": [between 1 and 10],
"noisy": True,
"num_atoms": [more than 1],
"v_min": -10.0,
"v_max": 10.0
(set v_min and v_max according to your expected range of returns).
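For example, a minimal sketch of such Rainbow-style overrides (the concrete values are illustrative, not tuned):
# Illustrative Rainbow-style overrides on top of the default DQN config.
rainbow_overrides = {
    "n_step": 3,      # multi-step returns (between 1 and 10)
    "noisy": True,    # NoisyNet-based exploration instead of epsilon-greedy
    "num_atoms": 51,  # distributional DQN (more than 1 atom)
    "v_min": -10.0,   # lower bound of the expected return distribution
    "v_max": 10.0,    # upper bound of the expected return distribution
}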
Atari results @10M steps: more details
Atari env | RLlib DQN | RLlib Dueling DDQN | RLlib Dist. DQN | Hessel et al. DQN
---|---|---|---|---
BeamRider | 2869 | 1910 | 4447 | ~2000
Breakout | 287 | 312 | 410 | ~150
Qbert | 3921 | 7968 | 15780 | ~4000
SpaceInvaders | 650 | 1001 | 1025 | ~500
DQN-specific configs (see also common configs):
self.num_atoms = 1
self.v_min = -10.0
self.v_max = 10.0
self.noisy = False
self.sigma0 = 0.5
self.dueling = True
self.hiddens = [256]
self.double_q = True
self.n_step = 1
self.before_learn_on_batch = None
self.training_intensity = None
# Changes to SimpleQConfig's default:
self.replay_buffer_config = {
"type": "MultiAgentPrioritizedReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Size of the replay buffer. Note that if async_updates is set,
# then each worker will have a replay buffer of this size.
"capacity": 50000,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# The number of continuous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
}
# fmt: on
Recurrent Replay Distributed DQN (R2D2)¶
[paper] [implementation]
R2D2 can be scaled by increasing the number of workers. All of the DQN improvements evaluated in Rainbow are available, though not all are enabled by default.
Tuned examples: CartPole-v0
Policy Gradients¶
[paper]
[implementation]
We include a vanilla policy gradients implementation as an example algorithm.
Policy gradients architecture (same as A2C)¶
Tuned examples: CartPole-v0
PG-specific configs (see also common configs):
# Override some of AlgorithmConfig's default values with PG-specific values.
self.num_workers = 0
self.lr = 0.0004
self._disable_preprocessor_api = True
Proximal Policy Optimization (PPO)¶
[paper]
[implementation]
PPO’s clipped objective supports multiple SGD passes over the same batch of experiences. RLlib’s multi-GPU optimizer pins that data in GPU memory to avoid unnecessary transfers from host memory, substantially improving performance over a naive implementation. PPO scales out using multiple workers for experience collection, and also to multiple GPUs for SGD.
Tip
If you need to scale out with GPUs on multiple nodes, consider using decentralized PPO.
PPO architecture¶
Tuned examples: Unity3D Soccer (multi-agent: Strikers vs Goalie), Humanoid-v1, Hopper-v1, Pendulum-v1, PongDeterministic-v4, Walker2d-v1, HalfCheetah-v2, {BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4
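A minimal sketch of scaling PPO out to several rollout workers and multiple SGD GPUs (config-class method names follow recent RLlib releases; all values are illustrative):
# Minimal sketch: workers for experience collection, two GPUs for SGD.
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("HalfCheetah-v2")
    .rollouts(num_rollout_workers=16)  # parallel experience collection
    .resources(num_gpus=2)             # multi-GPU SGD on the pinned batch
    .training(train_batch_size=32000, sgd_minibatch_size=4096, num_sgd_iter=20)
)
algo = config.build()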
Atari results: more details
Atari env | RLlib PPO @10M | RLlib PPO @25M | Baselines PPO @10M
---|---|---|---
BeamRider | 2807 | 4480 | ~1800
Breakout | 104 | 201 | ~250
Qbert | 11085 | 14247 | ~14000
SpaceInvaders | 671 | 944 | ~800
Scalability: more details
MuJoCo env | RLlib PPO 16-workers @ 1h | Fan et al PPO 16-workers @ 1h
---|---|---
HalfCheetah | 9664 | ~7700

RLlib’s multi-GPU PPO scales to multiple GPUs and hundreds of CPUs on solving the Humanoid-v1 task. Here we compare against a reference MPI-based implementation.¶
PPO-specific configs (see also common configs):
# PPO specific settings:
self.lr_schedule = None
self.use_critic = True
self.use_gae = True
self.lambda_ = 1.0
self.kl_coeff = 0.2
self.sgd_minibatch_size = 128
self.num_sgd_iter = 30
self.shuffle_sequences = True
self.vf_loss_coeff = 1.0
self.entropy_coeff = 0.0
self.entropy_coeff_schedule = None
self.clip_param = 0.3
self.vf_clip_param = 10.0
self.grad_clip = None
self.kl_target = 0.01
# Override some of AlgorithmConfig's default values with PPO-specific values.
self.rollout_fragment_length = 200
self.train_batch_size = 4000
self.lr = 5e-5
self.model["vf_share_layers"] = False
Soft Actor Critic (SAC)¶
[original paper], [follow up paper], [discrete actions paper]
[implementation]
SAC architecture (same as DQN)¶
RLlib’s Soft Actor-Critic implementation is ported from the official SAC repo to better integrate with RLlib APIs.
Note that SAC has two fields for configuring custom models: policy_model_config and q_model_config; the model field of the config will be ignored.
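For instance, a custom Q-network could be supplied through q_model_config while keeping a default-style policy network; "my_q_model" below is a hypothetical model name that would have to be registered via ModelCatalog first:
# Sketch: custom Q-model for SAC; the generic `model` key is ignored by SAC.
sac_config = {
    "env": "Pendulum-v1",
    "q_model_config": {
        "custom_model": "my_q_model",              # hypothetical, user-registered model
        "custom_model_config": {"num_layers": 2},  # passed to the model's constructor
    },
    "policy_model_config": {
        "fcnet_hiddens": [256, 256],               # default-style policy net
        "fcnet_activation": "relu",
    },
}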
Tuned examples (continuous actions): Pendulum-v1, HalfCheetah-v3, Tuned examples (discrete actions): CartPole-v0
MuJoCo results @3M steps: more details
MuJoCo env | RLlib SAC | Haarnoja et al SAC
---|---|---
HalfCheetah | 13000 | ~15000
SAC-specific configs (see also common configs):
# SAC-specific config settings.
self.twin_q = True
self.q_model_config = {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"custom_model": None, # Use this to define custom Q-model(s).
"custom_model_config": {},
}
self.policy_model_config = {
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"custom_model": None, # Use this to define a custom policy model.
"custom_model_config": {},
}
self.clip_actions = False
self.tau = 5e-3
self.initial_alpha = 1.0
self.target_entropy = "auto"
self.n_step = 1
self.replay_buffer_config = {
"_enable_replay_buffer_api": True,
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": int(1e6),
# How many steps of the model to sample before learning starts.
"learning_starts": 1500,
# If True prioritized replay buffer will be used.
"prioritized_replay": False,
"prioritized_replay_alpha": 0.6,
"prioritized_replay_beta": 0.4,
"prioritized_replay_eps": 1e-6,
# Whether to compute priorities already on the remote worker side.
"worker_side_prioritization": False,
}
self.store_buffer_in_checkpoints = False
self.training_intensity = None
self.optimization = {
"actor_learning_rate": 3e-4,
"critic_learning_rate": 3e-4,
"entropy_learning_rate": 3e-4,
}
self.grad_clip = None
self.target_network_update_freq = 0
# .rollout()
self.rollout_fragment_length = 1
self.compress_observations = False
# .training()
self.train_batch_size = 256
# .reporting()
self.min_time_s_per_iteration = 1
self.min_sample_timesteps_per_iteration = 100
Model-Agnostic Meta-Learning (MAML)¶
RLlib’s MAML implementation is a meta-learning method for learning and quickly adapting to different continuous-control tasks. The code here is adapted from https://github.com/jonasrothfuss, which outperforms vanilla MAML and avoids computing higher-order gradients during the meta-update step. MAML is evaluated on custom environments that are described in greater detail here.
MAML uses additional metrics to measure performance; episode_reward_mean
measures the agent’s returns before adaptation, episode_reward_mean_adapt_N
measures the agent’s returns after N gradient steps of inner adaptation, and adaptation_delta
measures the difference in performance before and after adaptation. Examples can be seen here.
Tuned examples: HalfCheetahRandDirecEnv (Env, Config), AntRandGoalEnv (Env, Config), PendulumMassEnv (Env, Config)
MAML-specific configs (see also common configs):
# MAML-specific config settings.
self.use_gae = True
self.lambda_ = 1.0
self.kl_coeff = 0.0005
self.vf_loss_coeff = 0.5
self.entropy_coeff = 0.0
self.clip_param = 0.3
self.vf_clip_param = 10.0
self.grad_clip = None
self.kl_target = 0.01
self.inner_adaptation_steps = 1
self.maml_optimizer_steps = 5
self.inner_lr = 0.1
self.use_meta_env = True
# Override some of AlgorithmConfig's default values with MAML-specific values.
self.rollout_fragment_length = 200
self.create_env_on_local_worker = True
self.lr = 1e-3
# Whether layers should be shared for the value function.
self.model.update({
"vf_share_layers": False,
})
self.batch_mode = "complete_episodes"
self._disable_execution_plan_api = False
Model-Based Meta-Policy-Optimization (MB-MPO)¶
RLlib’s MBMPO implementation is a Dyna-style model-based RL method that learns from the predictions of an ensemble of transition-dynamics models. Similar to MAML, MBMPO meta-learns an optimal policy by treating each dynamics model as a different task. The code here is adapted from https://github.com/jonasrothfuss/model_ensemble_meta_learning. As in the original paper, MBMPO is evaluated on MuJoCo, with the horizon set to 200 instead of the default 1000.
Additional statistics are logged in MBMPO. Each MBMPO iteration corresponds to multiple MAML iterations, and MAMLIter$i$_DynaTrajInner_$j$_episode_reward_mean
measures the agent’s returns across the dynamics models at iteration i
of MAML and step j
of inner adaptation. Examples can be seen here.
Tuned examples (continuous actions): Pendulum-v1, HalfCheetah, Hopper, Tuned examples (discrete actions): CartPole-v0
MuJoCo results @100K steps: more details
MuJoCo env | RLlib MBMPO | Clavera et al MBMPO
---|---|---
HalfCheetah | 520 | ~550
Hopper | 620 | ~650
MBMPO-specific configs (see also common configs):
# MBMPO specific config settings:
# If true, use the Generalized Advantage Estimator (GAE)
# with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
self.use_gae = True
# GAE(lambda) parameter.
self.lambda_ = 1.0
# Initial coefficient for KL divergence.
self.kl_coeff = 0.0005
# Coefficient of the value function loss.
self.vf_loss_coeff = 0.5
# Coefficient of the entropy regularizer.
self.entropy_coeff = 0.0
# PPO clip parameter.
self.clip_param = 0.5
# Clip param for the value function. Note that this is sensitive to the
# scale of the rewards. If your expected V is large, increase this.
self.vf_clip_param = 10.0
# If specified, clip the global norm of gradients by this amount.
self.grad_clip = None
# Target value for KL divergence.
self.kl_target = 0.01
# Number of Inner adaptation steps for the MAML algorithm.
self.inner_adaptation_steps = 1
# Number of MAML steps per meta-update iteration (PPO steps).
self.maml_optimizer_steps = 8
# Inner adaptation step size.
self.inner_lr = 1e-3
# Horizon of the environment (200 in MB-MPO paper).
self.horizon = 200
# Dynamics ensemble hyperparameters.
self.dynamics_model = {
"custom_model": DynamicsEnsembleCustomModel,
# Number of Transition-Dynamics (TD) models in the ensemble.
"ensemble_size": 5,
# Hidden layers for each model in the TD-model ensemble.
"fcnet_hiddens": [512, 512, 512],
# Model learning rate.
"lr": 1e-3,
# Max number of training epochs per MBMPO iter.
"train_epochs": 500,
# Model batch size.
"batch_size": 500,
# Training/validation split.
"valid_split_ratio": 0.2,
# Normalize data (obs, action, and deltas).
"normalize_data": True,
}
# Workers sample from dynamics models, not from actual envs.
self.custom_vector_env = model_vector_env
# How many iterations through MAML per MBMPO iteration.
self.num_maml_steps = 10
# Override some of AlgorithmConfig's default values with MBMPO-specific
# values.
self.batch_mode = "complete_episodes"
# Size of batches collected from each worker.
self.rollout_fragment_length = 200
# Do create an actual env on the local worker (worker-idx=0).
self.create_env_on_local_worker = True
# Step size of SGD.
self.lr = 1e-3
# Exploration for MB-MPO is based on StochasticSampling, but uses 8000
# random timesteps up-front for worker=0.
self.exploration_config = {
"type": MBMPOExploration,
"random_timesteps": 8000,
}
Dreamer¶
Dreamer is an image-only model-based RL method that learns by imagining trajectories in the future and is evaluated on the DeepMind Control Suite environments. RLlib’s Dreamer is adapted from the official Google research repo.
To visualize learning, RLlib Dreamer’s imagined trajectories are logged as gifs in TensorBoard. Examples of such can be seen here.
Tuned examples: Deepmind Control Environments
Deepmind Control results @1M steps: more details
DMC env | RLlib Dreamer | Danijar et al Dreamer
---|---|---
Walker-Walk | 920 | ~930
Cheetah-Run | 640 | ~800
Dreamer-specific configs (see also common configs):
# Dreamer specific settings:
self.td_model_lr = 6e-4
self.actor_lr = 8e-5
self.critic_lr = 8e-5
self.grad_clip = 100.0
self.lambda_ = 0.95
self.dreamer_train_iters = 100
self.batch_size = 50
self.batch_length = 50
self.imagine_horizon = 15
self.free_nats = 3.0
self.kl_coeff = 1.0
self.prefill_timesteps = 5000
self.explore_noise = 0.3
self.dreamer_model = {
"custom_model": DreamerModel,
# RSSM/PlaNET parameters
"deter_size": 200,
"stoch_size": 30,
# CNN Decoder Encoder
"depth_size": 32,
# General Network Parameters
"hidden_size": 400,
# Action STD
"action_init_std": 5.0,
}
# Override some of AlgorithmConfig's default values with Dreamer-specific values.
# .rollouts()
self.num_workers = 0
self.num_envs_per_worker = 1
self.horizon = 1000
self.batch_mode = "complete_episodes"
self.clip_actions = False
# .training()
self.gamma = 0.99
# .environment()
self.env_config = {
# Repeats action send by policy for frame_skip times in env
"frame_skip": 2,
}
SlateQ¶
SlateQ is a model-free RL method that builds on top of DQN and generates recommendation slates for recommender-system environments. Since these types of environments come with large, combinatorial action spaces, SlateQ mitigates this by decomposing the Q-value into single-item Q-values and solves the decomposed objective via a mix of integer programming and deep-learning optimization. SlateQ can be evaluated on Google’s RecSim environment. An RLlib wrapper for RecSim can be found here.
RecSim environment wrapper: Google RecSim
SlateQ-specific configs (see also common configs):
# SlateQ specific settings:
self.fcnet_hiddens_per_candidate = [256, 32]
self.target_network_update_freq = 3200
self.tau = 1.0
self.use_huber = False
self.huber_threshold = 1.0
self.training_intensity = None
self.lr_schedule = None
self.lr_choice_model = 1e-3
self.rmsprop_epsilon = 1e-5
self.grad_clip = None
self.n_step = 1
self.replay_buffer_config = {
"type": "MultiAgentPrioritizedReplayBuffer",
"capacity": 100000,
"prioritized_replay_alpha": 0.6,
# Beta parameter for sampling from prioritized replay buffer.
"prioritized_replay_beta": 0.4,
# Epsilon to add to the TD errors when updating priorities.
"prioritized_replay_eps": 1e-6,
# The number of continuous environment steps to replay at once. This may
# be set to greater than 1 to support recurrent models.
"replay_sequence_length": 1,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# How many steps of the model to sample before learning starts.
"learning_starts": 20000,
}
# Override some of AlgorithmConfig's default values with SlateQ-specific values.
self.exploration_config = {
# The Exploration class to use.
# Must be SlateEpsilonGreedy or SlateSoftQ to handle the problem that
# the action space of the policy is different from the space used inside
# the exploration component.
# E.g.: action_space=MultiDiscrete([5, 5]) <- slate-size=2, num-docs=5,
# but action distribution is Categorical(5*4) -> all possible unique slates.
"type": "SlateEpsilonGreedy",
"warmup_timesteps": 20000,
"epsilon_timesteps": 250000,
"final_epsilon": 0.01,
}
# Switch to greedy actions in evaluation workers.
self.evaluation_config = {"explore": False}
self.num_workers = 0
self.rollout_fragment_length = 4
self.train_batch_size = 32
self.lr = 0.00025
self.min_sample_timesteps_per_iteration = 1000
self.min_time_s_per_iteration = 1
self.compress_observations = False
self._disable_preprocessor_api = True
Conservative Q-Learning (CQL)¶
In offline RL, the algorithm has no access to an environment, but can only sample from a fixed dataset of pre-collected state-action-reward tuples. In particular, CQL (Conservative Q-Learning) is an offline RL algorithm that mitigates the overestimation of Q-values outside the dataset distribution via conservative critic estimates. It does so by adding a simple Q regularizer loss to the standard Bellman update loss. This ensures that the critic does not output overly-optimistic Q-values. This conservative correction term can be added on top of any off-policy Q-learning algorithm (here, we provide this for SAC).
RLlib’s CQL is evaluated against the Behavior Cloning (BC) benchmark at 500K gradient steps over the dataset. The only difference between the BC- and CQL configs is the bc_iters
parameter in CQL, indicating how many gradient steps we perform over the BC loss. CQL is evaluated on the D4RL benchmark, which has pre-collected offline datasets for many types of environments.
Tuned examples: HalfCheetah Random, Hopper Random
CQL-specific configs (see also common configs):
# CQL-specific config settings:
self.bc_iters = 20000
self.temperature = 1.0
self.num_actions = 10
self.lagrangian = False
self.lagrangian_thresh = 5.0
self.min_q_weight = 5.0
# Changes to Trainer's/SACConfig's default:
# .offline_data()
self.off_policy_estimation_methods = {}
# .reporting()
self.min_sample_timesteps_per_iteration = 0
self.min_train_timesteps_per_iteration = 100
# fmt: on
Critic Regularized Regression (CRR)¶
CRR is another offline RL algorithm based on Q-learning that can learn from an offline experience replay. The challenge in applying existing Q-learning algorithms to offline RL lies in the overestimation of the Q-function, as well as the lack of exploration beyond the observed data. The latter becomes increasingly important during bootstrapping in the Bellman equation, where the Q-function queried for the next state’s Q-value(s) does not have support in the observed data. To mitigate these issues, CRR implements the simple yet powerful idea of “value-filtered regression”: a learned critic is used to filter out non-promising transitions from the replay dataset. For more details, please refer to the paper (see link above).
Tuned examples: CartPole-v0, Pendulum-v1
# CRR-specific settings.
self.weight_type = "bin"
self.temperature = 1.0
self.max_weight = 20.0
self.advantage_type = "mean"
self.n_action_sample = 4
self.twin_q = True
self.target_update_grad_intervals = 100
Derivative-free¶
Augmented Random Search (ARS)¶
[paper] [implementation]
ARS is a random search method for training linear policies for continuous control problems. Code here is adapted from https://github.com/modestyachts/ARS to integrate with RLlib APIs.
Tuned examples: CartPole-v0, Swimmer-v2
ARS-specific configs (see also common configs):
# ARS specific settings:
self.action_noise_std = 0.0
self.noise_stdev = 0.02
self.num_rollouts = 32
self.rollouts_used = 32
self.sgd_stepsize = 0.01
self.noise_size = 250000000
self.eval_prob = 0.03
self.report_length = 10
self.offset = 0
# Override some of AlgorithmConfig's default values with ARS-specific values.
self.num_workers = 2
self.observation_filter = "MeanStdFilter"
# ARS will use Algorithm's evaluation WorkerSet (if evaluation_interval > 0).
# Therefore, we must be careful not to use more than 1 env per eval worker
# (would break ARSPolicy's compute_single_action method) and to not do
# obs-filtering.
self.evaluation_config["num_envs_per_worker"] = 1
self.evaluation_config["observation_filter"] = "NoFilter"
Evolution Strategies¶
[paper] [implementation]
Code here is adapted from https://github.com/openai/evolution-strategies-starter to execute in the distributed setting with Ray.
Tuned examples: Humanoid-v1
Scalability:

RLlib’s ES implementation scales further and is faster than a reference Redis implementation on solving the Humanoid-v1 task.¶
ES-specific configs (see also common configs):
# ES specific settings:
self.action_noise_std = 0.01
self.l2_coeff = 0.005
self.noise_stdev = 0.02
self.episodes_per_batch = 1000
self.eval_prob = 0.03
# self.return_proc_mode = "centered_rank" # only supported return_proc_mode
self.stepsize = 0.01
self.noise_size = 250000000
self.report_length = 10
# Override some of AlgorithmConfig's default values with ES-specific values.
self.train_batch_size = 10000
self.num_workers = 10
self.observation_filter = "MeanStdFilter"
# ES will use Algorithm's evaluation WorkerSet (if evaluation_interval > 0).
# Therefore, we must be careful not to use more than 1 env per eval worker
# (would break ARSPolicy's compute_single_action method) and to not do
# obs-filtering.
self.evaluation_config["num_envs_per_worker"] = 1
self.evaluation_config["observation_filter"] = "NoFilter"
Monotonic Advantage Re-Weighted Imitation Learning (MARWIL)¶
MARWIL is a hybrid imitation learning and policy gradient algorithm suitable for training on batched historical data.
When the beta
hyperparameter is set to zero, the MARWIL objective reduces to vanilla imitation learning (see BC).
MARWIL requires the offline datasets API to be used.
Tuned examples: CartPole-v0
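Because MARWIL trains from offline data, the input setting has to point at previously collected experiences; the path below is a hypothetical directory of JSON sample files written by RLlib’s output writer:
# Sketch: MARWIL on an offline dataset. With beta > 0, advantages re-weight
# the imitation objective; beta = 0 would reduce this to plain behavior cloning.
marwil_config = {
    "env": "CartPole-v0",
    "input": "/tmp/cartpole-out",  # hypothetical offline-data directory
    "beta": 1.0,
}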
MARWIL-specific configs (see also common configs):
# MARWIL specific settings:
self.beta = 1.0
self.bc_logstd_coeff = 0.0
self.moving_average_sqd_adv_norm_update_rate = 1e-8
self.moving_average_sqd_adv_norm_start = 100.0
self.replay_buffer_config = {
"type": "MultiAgentPrioritizedReplayBuffer",
# Size of the replay buffer in (single and independent) timesteps.
# The buffer gets filled by reading from the input files line-by-line
# and adding all timesteps on one line at once. We then sample
# uniformly from the buffer (`train_batch_size` samples) for
# each training step.
"capacity": 10000,
# Specify prioritized replay by supplying a buffer type that supports
# prioritization
"prioritized_replay": DEPRECATED_VALUE,
# Number of steps to read before learning starts.
"learning_starts": 0,
"replay_sequence_length": 1
}
self.use_gae = True
self.vf_coeff = 1.0
self.grad_clip = None
# Override some of AlgorithmConfig's default values with MARWIL-specific values.
# You should override input_ to point to an offline dataset
# (see trainer.py and trainer_config.py).
# The dataset may have an arbitrary number of timesteps
# (and even episodes) per line.
# However, each line must only contain consecutive timesteps in
# order for MARWIL to be able to calculate accumulated
# discounted returns. It is ok, though, to have multiple episodes in
# the same line.
self.input_ = "sampler"
# Use importance sampling estimators for reward.
self.off_policy_estimation_methods = {
"is": {"type": ImportanceSampling},
"wis": {"type": WeightedImportanceSampling},
}
self.postprocess_inputs = True
self.lr = 1e-4
self.train_batch_size = 2000
self.num_workers = 0
Behavior Cloning (BC; derived from MARWIL implementation)¶
Our behavioral cloning implementation is directly derived from our MARWIL implementation,
with the only difference being the beta
parameter force-set to 0.0. This makes
BC try to match the behavior policy, which generated the offline data, disregarding any resulting rewards.
BC requires the offline datasets API to be used.
Tuned examples: CartPole-v0
BC-specific configs (see also common configs):
# No need to calculate advantages (or do anything else with the rewards).
self.beta = 0.0
# Advantages (calculated during postprocessing)
# not important for behavioral cloning.
self.postprocess_inputs = False
# No reward estimation.
self.off_policy_estimation_methods = {}
Contextual Bandits¶
The multi-armed bandit (MAB) problem provides a simplified RL setting that involves learning to act in one situation only, i.e. the context (observation/state) and arms (actions/items-to-select) are both fixed. The contextual bandit is an extension of the MAB problem, where at each round the agent has access not only to a set of bandit arms/actions but also to a context (state) associated with this iteration. The context changes with each iteration but is not affected by the action that the agent takes. The objective of the agent is to maximize the cumulative rewards by collecting enough information about how the context and the rewards of the arms are related to each other. The agent does this by balancing the trade-off between exploration and exploitation.
Contextual bandit algorithms typically consist of an action-value model (Q model) and an exploration strategy (epsilon-greedy, LinUCB, Thompson Sampling, etc.).
RLlib supports the following online contextual bandit algorithms, named after the exploration strategies that they employ:
Linear Upper Confidence Bound (BanditLinUCB)¶
[paper] [implementation]
LinUCB assumes a linear dependency between the expected reward of an action and
its context. It estimates the Q value of each action using ridge regression.
It constructs a confidence region around the weights of the linear
regression model and uses this confidence ellipsoid to estimate the
uncertainty of action values.
Tuned examples: SimpleContextualBandit, UCB Bandit on RecSim. ParametricItemRecoEnv.
LinUCB-specific configs (see also common configs):
# Override some of AlgorithmConfig's default values with bandit-specific values.
self.framework_str = "torch"
self.num_workers = 0
self.rollout_fragment_length = 1
self.train_batch_size = 1
# Make sure, a `train()` call performs at least 100 env sampling
# timesteps, before reporting results. Not setting this (default is 0)
# would significantly slow down the Bandit Algorithm.
self.min_sample_timesteps_per_iteration = 100
Linear Thompson Sampling (BanditLinTS)¶
[paper]
[implementation]
Like LinUCB, LinTS also assumes a linear dependency between the expected
reward of an action and its context and uses online ridge regression to
estimate the Q values of actions given the context. It assumes a Gaussian
prior on the weights and a Gaussian likelihood function. For deciding which
action to take, the agent samples weights for each arm, using
the posterior distributions, and plays the arm that produces the highest reward.
Tuned examples: SimpleContextualBandit, WheelBandit.
LinTS-specific configs (see also common configs):
# Override some of AlgorithmConfig's default values with bandit-specific values.
self.framework_str = "torch"
self.num_workers = 0
self.rollout_fragment_length = 1
self.train_batch_size = 1
# Make sure, a `train()` call performs at least 100 env sampling
# timesteps, before reporting results. Not setting this (default is 0)
# would significantly slow down the Bandit Algorithm.
self.min_sample_timesteps_per_iteration = 100
Single-Player Alpha Zero (AlphaZero)¶
[paper] [implementation] AlphaZero is an RL agent originally designed for two-player games. This version adapts it to handle single player games. The code can be scaled to any number of workers. It also implements the ranked rewards (R2) strategy to enable self-play even in the one-player setting. The code is mainly purposed to be used for combinatorial optimization.
Tuned examples: Sparse reward CartPole
AlphaZero-specific configs (see also common configs):
# AlphaZero specific config settings:
self.sgd_minibatch_size = 128
self.shuffle_sequences = True
self.num_sgd_iter = 30
self.learning_starts = 1000
self.replay_buffer_config = {
"type": "ReplayBuffer",
# Size of the replay buffer in batches (not timesteps!).
"capacity": 1000,
# When to start returning samples (in batches, not timesteps!).
"learning_starts": 500,
# Choosing `fragments` here makes it so that the buffer stores entire
# batches, instead of sequences, episodes or timesteps.
"storage_unit": "fragments",
}
self.lr_schedule = None
self.vf_share_layers = False
self.mcts_config = {
"puct_coefficient": 1.0,
"num_simulations": 30,
"temperature": 1.5,
"dirichlet_epsilon": 0.25,
"dirichlet_noise": 0.03,
"argmax_tree_policy": False,
"add_dirichlet_noise": True,
}
self.ranked_rewards = {
"enable": True,
"percentile": 75,
"buffer_max_length": 1000,
# add rewards obtained from random policy to
# "warm start" the buffer
"initialize_buffer": True,
"num_init_rewards": 100,
}
# Override some of AlgorithmConfig's default values with AlphaZero-specific
# values.
self.framework_str = "torch"
self.callbacks_class = AlphaZeroDefaultCallbacks
self.lr = 5e-5
self.rollout_fragment_length = 200
self.train_batch_size = 4000
self.batch_mode = "complete_episodes"
# Extra configuration that disables exploration.
self.evaluation_config = {
"mcts_config": {
"argmax_tree_policy": True,
"add_dirichlet_noise": False,
},
}
Multi-Agent Methods¶
QMIX Monotonic Value Factorisation (QMIX, VDN, IQN)¶
[paper] [implementation] Q-Mix is a specialized multi-agent algorithm. Code here is adapted from https://github.com/oxwhirl/pymarl_alpha to integrate with RLlib multi-agent APIs. To use Q-Mix, you must specify an agent grouping in the environment (see the two-step game example). Currently, all agents in the group must be homogeneous. The algorithm can be scaled by increasing the number of workers or using Ape-X.
Tuned examples: Two-step game
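Agent grouping is specified on the environment itself via MultiAgentEnv.with_agent_groups(); a hedged sketch follows (the env class, agent IDs, and spaces are illustrative placeholders):
# Sketch: grouping two homogeneous agents into a single group for QMIX.
from gym.spaces import Discrete, Tuple

grouped_env = MyMultiAgentEnv().with_agent_groups(  # MyMultiAgentEnv: any MultiAgentEnv subclass
    groups={"group_1": ["agent_1", "agent_2"]},
    obs_space=Tuple([Discrete(6), Discrete(6)]),
    act_space=Tuple([Discrete(2), Discrete(2)]),
)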
QMIX-specific configs (see also common configs):
# QMix specific settings:
self.mixer = "qmix"
self.mixing_embed_dim = 32
self.double_q = True
self.optim_alpha = 0.99
self.optim_eps = 0.00001
self.grad_clip = 10
# Override some of AlgorithmConfig's default values with QMix-specific values.
# .training()
self.lr = 0.0005
self.train_batch_size = 32
self.target_network_update_freq = 500
self.replay_buffer_config = {
"type": "ReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
# Size of the replay buffer in batches
"capacity": 1000,
# Choosing `fragments` here makes it so that the buffer stores entire
# batches, instead of sequences, episodes or timesteps.
"storage_unit": "fragments",
"learning_starts": 1000,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
}
self.model = {
"lstm_cell_size": 64,
"max_seq_len": 999999,
}
# .framework()
self.framework_str = "torch"
# .rollouts()
self.num_workers = 0
self.rollout_fragment_length = 4
self.batch_mode = "complete_episodes"
# .reporting()
self.min_time_s_per_iteration = 1
self.min_sample_timesteps_per_iteration = 1000
# .exploration()
self.exploration_config = {
# The Exploration class to use.
"type": "EpsilonGreedy",
# Config for the Exploration class' constructor:
"initial_epsilon": 1.0,
"final_epsilon": 0.01,
# Timesteps over which to anneal epsilon.
"epsilon_timesteps": 40000,
# For soft_q, use:
# "exploration_config" = {
# "type": "SoftQ"
# "temperature": [float, e.g. 1.0]
# }
}
# .evaluation()
# Evaluate with epsilon=0 every `evaluation_interval` training iterations.
# The evaluation stats will be reported under the "evaluation" metric key.
# Note that evaluation is currently not parallelized, and that for Ape-X
# metrics are already only reported for the lowest epsilon workers.
self.evaluation_interval = None
self.evaluation_duration = 10
self.evaluation_config = {
"explore": False,
}
Multi-Agent Deep Deterministic Policy Gradient (MADDPG)¶
[paper] [implementation] MADDPG is a DDPG centralized/shared critic algorithm. Code here is adapted from https://github.com/openai/maddpg to integrate with RLlib multi-agent APIs. Please check justinkterry/maddpg-rllib for examples and more information. Note that the implementation here is based on OpenAI’s, and is intended for use with the discrete MPE environments. Please also note that people typically find this method difficult to get to work, even with all applicable optimizations for their environment applied. This method should be viewed as for research purposes, and for reproducing the results of the paper introducing it.
Tuned examples: Multi-Agent Particle Environment, Two-step game
MADDPG-specific configs (see also common configs):
# MADDPG specific config settings:
self.agent_id = None
self.use_local_critic = False
self.use_state_preprocessor = False
self.actor_hiddens = [64, 64]
self.actor_hidden_activation = "relu"
self.critic_hiddens = [64, 64]
self.critic_hidden_activation = "relu"
self.n_step = 1
self.good_policy = "maddpg"
self.adv_policy = "maddpg"
self.replay_buffer_config = {
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": int(1e6),
# How many steps of the model to sample before learning starts.
"learning_starts": 1024 * 25,
# Force lockstep replay mode for MADDPG.
"replay_mode": "lockstep",
}
self.training_intensity = None
self.critic_lr = 1e-2
self.actor_lr = 1e-2
self.target_network_update_freq = 0
self.tau = 0.01
self.actor_feature_reg = 0.001
self.grad_norm_clipping = 0.5
# Changes to Algorithm's default:
self.rollout_fragment_length = 100
self.train_batch_size = 1024
self.num_workers = 1
self.min_time_s_per_iteration = 0
# fmt: on
Parameter Sharing¶
[paper], [paper] and [instructions]. Parameter sharing refers to a class of methods that take a base single-agent method and use it to learn a single policy for all agents. This simple approach has been shown to achieve state-of-the-art performance in cooperative games, and is usually how you should start when tackling a multi-agent problem.
Tuned examples: PettingZoo, waterworld, rock-paper-scissors, multi-agent cartpole
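In RLlib’s multi-agent config, parameter sharing amounts to mapping every agent ID to one shared policy; a minimal sketch (policy and agent names are illustrative):
# Sketch: all agents share the parameters of a single policy.
shared_policy_config = {
    "multiagent": {
        "policies": {"shared_policy"},  # a set of policy IDs; specs are inferred from the env
        "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: "shared_policy",
    },
}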
Fully Independent Learning¶
[instructions] Fully independent learning involves a collection of agents learning independently of each other via single-agent methods. This typically works, but can be less effective than dedicated multi-agent RL methods, since single-agent methods do not account for the non-stationarity of the multi-agent environment.
Tuned examples: waterworld, multiagent-cartpole
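By contrast, fully independent learning gives every agent its own policy; a minimal sketch for two agents with illustrative IDs "agent_1" and "agent_2":
# Sketch: one independently trained policy per agent.
independent_config = {
    "multiagent": {
        "policies": {"policy_agent_1", "policy_agent_2"},  # one policy ID per agent
        "policy_mapping_fn": lambda agent_id, episode, worker, **kwargs: f"policy_{agent_id}",
    },
}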
Exploration-based plug-ins (can be combined with any algo)¶
Curiosity (ICM: Intrinsic Curiosity Module)¶
Tuned examples:
Pyramids (Unity3D) (use the --env Pyramids command line option)
Test case with MiniGrid example (unit-test case: test_curiosity_on_partially_observable_domain)
Activating Curiosity
The Curiosity plugin can be activated by specifying it as the Exploration class to be used in the main Algorithm config. Most of its parameters usually do not have to be specified, as the module uses the values from the paper by default. For example:
config = ppo.DEFAULT_CONFIG.copy()
config["num_workers"] = 0
config["exploration_config"] = {
    "type": "Curiosity",  # <- Use the Curiosity module for exploring.
    "eta": 1.0,  # Weight for intrinsic rewards before being added to extrinsic ones.
    "lr": 0.001,  # Learning rate of the curiosity (ICM) module.
    "feature_dim": 288,  # Dimensionality of the generated feature vectors.
    # Setup of the feature net (used to encode observations into feature (latent) vectors).
    "feature_net_config": {
        "fcnet_hiddens": [],
        "fcnet_activation": "relu",
    },
    "inverse_net_hiddens": [256],  # Hidden layers of the "inverse" model.
    "inverse_net_activation": "relu",  # Activation of the "inverse" model.
    "forward_net_hiddens": [256],  # Hidden layers of the "forward" model.
    "forward_net_activation": "relu",  # Activation of the "forward" model.
    "beta": 0.2,  # Weight for the "forward" loss (beta) over the "inverse" loss (1.0 - beta).
    # Specify, which exploration sub-type to use (usually, the algo's "default"
    # exploration, e.g. EpsilonGreedy for DQN, StochasticSampling for PG/SAC).
    "sub_exploration": {
        "type": "StochasticSampling",
    },
}
Functionality
RLlib’s Curiosity is based on the “ICM” (intrinsic curiosity module) described in this paper.
It allows agents to learn in sparse-reward- or even no-reward environments by
calculating so-called “intrinsic rewards”, purely based on the information content that is incoming via the observation channel.
Sparse-reward environments are envs where almost all reward signals are 0.0, such as these [MiniGrid env examples here].
In such environments, agents have to navigate (and change the underlying state of the environment) over long periods of time, without receiving much (or any) feedback.
For example, the task could be to find a key in some room, pick it up, find a matching door (matching the color of the key), and eventually unlock this door with the key to reach a goal state,
all the while not seeing any rewards.
Such problems are impossible to solve with standard RL exploration methods like epsilon-greedy or stochastic sampling.
The Curiosity module - when configured as the Exploration class to use via the Algorithm’s config (see above on how to do this) - automatically adds three simple models to the Policy’s self.model
:
a) a latent-space learning (“feature”) model, taking an environment observation and outputting a latent vector that represents this observation,
b) a “forward” model, predicting the next latent vector given the current latent vector and the action to take next, and
c) a so-called “inverse” net, only used to train the “feature” net. The inverse net tries to predict the action taken between two latent vectors (obs and next obs).
All the above extra Models are trained inside the modified Exploration.postprocess_trajectory()
call.
Using the (ever changing) “forward” model, our Curiosity module calculates an artificial (intrinsic) reward signal, weights it via the eta
parameter, and then adds it to the environment’s (extrinsic) reward.
Intrinsic rewards for each env step are calculated by taking the Euclidean distance between the latent-space encoding of the next observation (“feature” model) and the predicted latent-space encoding for the next observation (“forward” model).
This drives the agent to explore areas of the environment where the “forward” model still performs poorly (areas that are not “understood” yet); exploration of these areas tapers down once the agent has visited them often: the “forward” model eventually gets better at predicting these next latent vectors, which in turn diminishes the intrinsic rewards (decreases the Euclidean distance between predicted and actual vectors).
RE3 (Random Encoders for Efficient Exploration)¶
Examples:
LunarLanderContinuous-v2 (use the --env LunarLanderContinuous-v2 command line option)
Test case with Pendulum-v1 example
Activating RE3
The RE3 plugin can be activated by specifying it as the Exploration class to be used in the main Algorithm config and inheriting the RE3UpdateCallbacks as shown in this example. Most of its parameters usually do not have to be specified, as the module uses the values from the paper by default. For example:
config = sac.DEFAULT_CONFIG.copy()
config["env"] = "Pendulum-v1"
config["seed"] = 12345
config["callbacks"] = RE3Callbacks
config["exploration_config"] = {
    "type": "RE3",
    # The dimensionality of the observation embedding vectors in latent space.
    "embeds_dim": 128,
    "rho": 0.1,  # Beta decay factor, used for on-policy algorithms.
    "k_nn": 50,  # Number of neighbours to set for K-NN entropy estimation.
    # Configuration for the encoder network, producing embedding vectors from observations.
    # This can be used to configure fcnet- or conv_net setups to properly process any
    # observation space. By default uses the Policy model configuration.
    "encoder_net_config": {
        "fcnet_hiddens": [],
        "fcnet_activation": "relu",
    },
    # Hyperparameter to choose between exploration and exploitation. A higher value of beta adds
    # more importance to the intrinsic reward, as per the following equation
    # `reward = r + beta * intrinsic_reward`
    "beta": 0.2,
    # Schedule to use for beta decay, one of "constant" or "linear_decay".
    "beta_schedule": "constant",
    # Specify, which exploration sub-type to use (usually, the algo's "default"
    # exploration, e.g. EpsilonGreedy for DQN, StochasticSampling for PG/SAC).
    "sub_exploration": {
        "type": "StochasticSampling",
    },
}
Functionality
RLlib’s RE3 is based on “Random Encoders for Efficient Exploration”, described in this paper. RE3 quantifies exploration based on state entropy. The entropy of a state is calculated from its distance to the K nearest-neighbor states present in the replay buffer in latent space (in this implementation, KNN uses training samples from the same batch). The state entropy is treated as an intrinsic reward and, for policy optimization, added to the extrinsic reward when available. If the extrinsic reward is not available, the state entropy is used as the sole “intrinsic reward” for unsupervised pre-training of the RL agent. RE3 thus also allows agents to learn in sparse-reward or even no-reward environments by using the state entropy as intrinsic reward.
This exploration objective can be used with both model-free and model-based RL algorithms. RE3 uses a randomly initialized encoder to get the state’s latent representation, thus taking away the complexity of training the representation learning method. The encoder weights are fixed during the entire duration of the training process.