Note
Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The Ray Team plans to transition algorithms, example scripts, and documentation to the new code base thereby incrementally replacing the “old API stack” (e.g., ModelV2, Policy, RolloutWorker) throughout the subsequent minor releases leading up to Ray 3.0.
Note, however, that so far only PPO (single- and multi-agent) and SAC (single-agent only) support the “new API stack” and continue to run by default with the old APIs. You can continue to use the existing custom (old stack) classes.
See here for more details on how to use the new API stack.
Algorithms#
Overview#
The following table is an overview of all available algorithms in RLlib. Note that all of them support multi-GPU training on a single (GPU) node in Ray (open-source), as well as multi-GPU training on multi-node (GPU) clusters when using the Anyscale platform.
Algorithm | Single- and Multi-agent | Multi-GPU (multi-node) | Action Spaces
On-Policy
Off-Policy
High-throughput on- and off-policy
Model-based RL
Offline RL and Imitation Learning
Algorithm Extensions and Plugins
On-policy#
Proximal Policy Optimization (PPO)#
Tuned examples: Pong-v5, CartPole-v1, Pendulum-v1.
PPO-specific configs (see also common configs):
- class ray.rllib.algorithms.ppo.ppo.PPOConfig(algo_class=None)[source]#
Defines a configuration class from which a PPO Algorithm can be built.
from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig()
# Activate new API stack.
config.api_stack(
    enable_rl_module_and_learner=True,
    enable_env_runner_and_connector_v2=True,
)
config.environment("CartPole-v1")
config.env_runners(num_env_runners=1)
config.training(
    gamma=0.9, lr=0.01, kl_coeff=0.3, train_batch_size_per_learner=256
)

# Build an Algorithm object from the config and run 1 training iteration.
algo = config.build()
algo.train()

from ray.rllib.algorithms.ppo import PPOConfig
from ray import air
from ray import tune

config = (
    PPOConfig()
    # Activate new API stack.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    # Set the config object's env.
    .environment(env="CartPole-v1")
    # Update the config object's training parameters.
    .training(
        lr=0.001, clip_param=0.2
    )
)

tune.Tuner(
    "PPO",
    run_config=air.RunConfig(stop={"training_iteration": 1}),
    param_space=config,
).fit()
- training(*, use_critic: bool | None = <ray.rllib.utils.from_config._NotProvided object>, use_gae: bool | None = <ray.rllib.utils.from_config._NotProvided object>, lambda_: float | None = <ray.rllib.utils.from_config._NotProvided object>, use_kl_loss: bool | None = <ray.rllib.utils.from_config._NotProvided object>, kl_coeff: float | None = <ray.rllib.utils.from_config._NotProvided object>, kl_target: float | None = <ray.rllib.utils.from_config._NotProvided object>, mini_batch_size_per_learner: int | None = <ray.rllib.utils.from_config._NotProvided object>, sgd_minibatch_size: int | None = <ray.rllib.utils.from_config._NotProvided object>, num_sgd_iter: int | None = <ray.rllib.utils.from_config._NotProvided object>, shuffle_sequences: bool | None = <ray.rllib.utils.from_config._NotProvided object>, vf_loss_coeff: float | None = <ray.rllib.utils.from_config._NotProvided object>, entropy_coeff: float | None = <ray.rllib.utils.from_config._NotProvided object>, entropy_coeff_schedule: ~typing.List[~typing.List[int | float]] | None = <ray.rllib.utils.from_config._NotProvided object>, clip_param: float | None = <ray.rllib.utils.from_config._NotProvided object>, vf_clip_param: float | None = <ray.rllib.utils.from_config._NotProvided object>, grad_clip: float | None = <ray.rllib.utils.from_config._NotProvided object>, lr_schedule: ~typing.List[~typing.List[int | float]] | None = <ray.rllib.utils.from_config._NotProvided object>, vf_share_layers=-1, **kwargs) PPOConfig [source]#
Sets the training related configuration.
- Parameters:
use_critic – Should use a critic as a baseline (otherwise don’t use value baseline; required for using GAE).
use_gae – If true, use the Generalized Advantage Estimator (GAE) with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
lambda_ – The lambda parameter for Generalized Advantage Estimation (GAE). Defines the exponential weight used between actually measured rewards vs. value function estimates over multiple time steps. Specifically, lambda_ balances short-term, low-variance estimates with longer-term, high-variance returns. A lambda_ of 0.0 makes the GAE rely only on immediate rewards (and vf predictions from there on, reducing variance, but increasing bias), while a lambda_ of 1.0 only incorporates vf predictions at the truncation points of the given episodes or episode chunks (reducing bias but increasing variance).
use_kl_loss – Whether to use the KL-term in the loss function.
kl_coeff – Initial coefficient for KL divergence.
kl_target – Target value for KL divergence.
mini_batch_size_per_learner – Only use if new API stack is enabled. The mini batch size per Learner worker. This is the batch size that each Learner worker’s training batch (whose size is self.train_batch_size_per_learner) will be split into. For example, if the train batch size per Learner worker is 4000 and the mini batch size per Learner worker is 400, the train batch will be split into 10 equal sized chunks (or “mini batches”). Each such mini batch will be used for one SGD update. Overall, the train batch on each Learner worker will be traversed self.num_sgd_iter times. In the above example, if self.num_sgd_iter is 5, we will altogether perform 50 (10x5) SGD updates per Learner update step.
sgd_minibatch_size – Total SGD batch size across all devices for SGD. This defines the minibatch size within each epoch. Deprecated on the new API stack (use mini_batch_size_per_learner instead).
num_sgd_iter – Number of SGD iterations in each outer loop (i.e., number of epochs to execute per train batch).
shuffle_sequences – Whether to shuffle sequences in the batch when training (recommended).
vf_loss_coeff – Coefficient of the value function loss. IMPORTANT: you must tune this if you set vf_share_layers=True inside your model’s config.
entropy_coeff – The entropy coefficient (float) or entropy coefficient schedule in the format of [[timestep, coeff-value], [timestep, coeff-value], …]. In case of a schedule, intermediary timesteps will be assigned to linearly interpolated coefficient values. A schedule config’s first entry must start with timestep 0, i.e.: [[0, initial_value], […]].
clip_param – The PPO clip parameter.
vf_clip_param – Clip param for the value function. Note that this is sensitive to the scale of the rewards. If your expected V is large, increase this.
grad_clip – If specified, clip the global norm of gradients by this amount.
- Returns:
This updated AlgorithmConfig object.
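The effect of lambda_ described in the parameter list can be made concrete with a small, self-contained sketch of GAE over one episode chunk. This is plain Python written for illustration, not RLlib's implementation; the function name `gae_advantages` and the sample rewards and value estimates are hypothetical.

```python
from typing import List


def gae_advantages(
    rewards: List[float],
    values: List[float],
    bootstrap_value: float,
    gamma: float = 0.99,
    lambda_: float = 0.95,
) -> List[float]:
    """Compute GAE advantages for a single episode chunk.

    `values[t]` is the value estimate for the state at step t.
    `bootstrap_value` is the estimate for the state after the last step
    (0.0 if the episode terminated there).
    """
    advantages = [0.0] * len(rewards)
    next_value = bootstrap_value
    running = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error at time t.
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of this and all later TD errors.
        running = delta + gamma * lambda_ * running
        advantages[t] = running
        next_value = values[t]
    return advantages


# Hypothetical data: three steps, episode terminates afterwards.
rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.4, 0.3]

# lambda_=0.0: each advantage is only the one-step TD error.
one_step = gae_advantages(rewards, values, 0.0, gamma=0.9, lambda_=0.0)
# lambda_=1.0: each advantage is the discounted return minus the value estimate.
monte_carlo = gae_advantages(rewards, values, 0.0, gamma=0.9, lambda_=1.0)
print(one_step)
print(monte_carlo)
```

Values of lambda_ between the two extremes interpolate between these estimators, which is the bias/variance trade-off the parameter description refers to.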
Off-Policy#
Deep Q Networks (DQN, Rainbow, Parametric DQN)#
All of the DQN improvements evaluated in Rainbow are available, though not all are enabled by default. See also how to use parametric-actions in DQN.
Tuned examples: PongDeterministic-v4, Rainbow configuration, {BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4, with Dueling and Double-Q, with Distributional DQN.
Hint
For a complete Rainbow setup, make the following changes to the default DQN config:
"n_step": [between 1 and 10],
"noisy": True,
"num_atoms": [more than 1],
"v_min": -10.0,
"v_max": 10.0
(set v_min and v_max according to your expected range of returns).
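As a rough illustration of the hint, the overrides could be collected in a plain dict and sanity-checked before use. The concrete values for n_step (3) and num_atoms (51) are assumptions picked from the allowed ranges, and `check_rainbow_overrides` is a hypothetical helper, not part of RLlib.

```python
# Illustrative Rainbow-style overrides. The values for n_step and num_atoms
# are assumptions chosen from the ranges given in the hint.
rainbow_overrides = {
    "n_step": 3,       # anywhere between 1 and 10
    "noisy": True,
    "num_atoms": 51,   # any value greater than 1
    "v_min": -10.0,    # adjust to the expected range of returns
    "v_max": 10.0,
}


def check_rainbow_overrides(overrides: dict) -> None:
    """Raise ValueError if the overrides violate the ranges from the hint."""
    if not 1 <= overrides["n_step"] <= 10:
        raise ValueError("n_step should be between 1 and 10")
    if overrides["num_atoms"] <= 1:
        raise ValueError("num_atoms should be greater than 1")
    if overrides["v_min"] >= overrides["v_max"]:
        raise ValueError("v_min should be smaller than v_max")
    if overrides["noisy"] is not True:
        raise ValueError("noisy should be True")


check_rainbow_overrides(rainbow_overrides)

# With Ray installed, the dict could presumably be applied like this
# (all five keys appear in the DQNConfig.training() signature):
#   from ray.rllib.algorithms.dqn.dqn import DQNConfig
#   config = DQNConfig().training(**rainbow_overrides)
```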
DQN-specific configs (see also common configs):
- class ray.rllib.algorithms.dqn.dqn.DQNConfig(algo_class=None)[source]#
Defines a configuration class from which a DQN Algorithm can be built.
from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig

config = DQNConfig()

replay_config = {
    "type": "MultiAgentPrioritizedReplayBuffer",
    "capacity": 60000,
    "prioritized_replay_alpha": 0.5,
    "prioritized_replay_beta": 0.5,
    "prioritized_replay_eps": 3e-6,
}

config = config.training(replay_buffer_config=replay_config)
config = config.resources(num_gpus=0)
config = config.env_runners(num_env_runners=1)
config = config.environment("CartPole-v1")

# Construct the Algorithm directly from the config (DQN is imported above).
algo = DQN(config=config)
algo.train()
del algo

from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray import air
from ray import tune

config = DQNConfig()
config = config.training(
    num_atoms=tune.grid_search([1,]))
config = config.environment(env="CartPole-v1")

tune.Tuner(
    "DQN",
    run_config=air.RunConfig(stop={"training_iteration": 1}),
    param_space=config.to_dict()
).fit()
- training(*, target_network_update_freq: int | None = <ray.rllib.utils.from_config._NotProvided object>, replay_buffer_config: dict | None = <ray.rllib.utils.from_config._NotProvided object>, store_buffer_in_checkpoints: bool | None = <ray.rllib.utils.from_config._NotProvided object>, lr_schedule: ~typing.List[~typing.List[int | float]] | None = <ray.rllib.utils.from_config._NotProvided object>, epsilon: float | ~typing.List[~typing.List[int | float]] | None = <ray.rllib.utils.from_config._NotProvided object>, adam_epsilon: float | None = <ray.rllib.utils.from_config._NotProvided object>, grad_clip: int | None = <ray.rllib.utils.from_config._NotProvided object>, num_steps_sampled_before_learning_starts: int | None = <ray.rllib.utils.from_config._NotProvided object>, tau: float | None = <ray.rllib.utils.from_config._NotProvided object>, num_atoms: int | None = <ray.rllib.utils.from_config._NotProvided object>, v_min: float | None = <ray.rllib.utils.from_config._NotProvided object>, v_max: float | None = <ray.rllib.utils.from_config._NotProvided object>, noisy: bool | None = <ray.rllib.utils.from_config._NotProvided object>, sigma0: float | None = <ray.rllib.utils.from_config._NotProvided object>, dueling: bool | None = <ray.rllib.utils.from_config._NotProvided object>, hiddens: int | None = <ray.rllib.utils.from_config._NotProvided object>, double_q: bool | None = <ray.rllib.utils.from_config._NotProvided object>, n_step: int | ~typing.Tuple[int, int] | None = <ray.rllib.utils.from_config._NotProvided object>, before_learn_on_batch: ~typing.Callable[[~typing.Type[~ray.rllib.policy.sample_batch.MultiAgentBatch], ~typing.List[~typing.Type[~ray.rllib.policy.policy.Policy]], ~typing.Type[int]], ~typing.Type[~ray.rllib.policy.sample_batch.MultiAgentBatch]] = <ray.rllib.utils.from_config._NotProvided object>, training_intensity: float | None = <ray.rllib.utils.from_config._NotProvided object>, td_error_loss_fn: str | None = <ray.rllib.utils.from_config._NotProvided 
object>, categorical_distribution_temperature: float | None = <ray.rllib.utils.from_config._NotProvided object>, **kwargs) DQNConfig [source]#
Sets the training related configuration.
- Parameters:
target_network_update_freq – Update the target network every target_network_update_freq sample steps.
replay_buffer_config – Replay buffer config. Examples: { "_enable_replay_buffer_api": True, "type": "MultiAgentReplayBuffer", "capacity": 50000, "replay_sequence_length": 1, } - OR - { "_enable_replay_buffer_api": True, "type": "MultiAgentPrioritizedReplayBuffer", "capacity": 50000, "prioritized_replay_alpha": 0.6, "prioritized_replay_beta": 0.4, "prioritized_replay_eps": 1e-6, "replay_sequence_length": 1, } - Where -
prioritized_replay_alpha: Alpha parameter controls the degree of prioritization in the buffer. In other words, when a buffer sample has a higher temporal-difference error, with how much more probability should it be drawn to update the parametrized Q-network. 0.0 corresponds to uniform probability. Setting it much above 1.0 can quickly make the sampling distribution heavily “pointy”, with low entropy.
prioritized_replay_beta: Beta parameter controls the degree of importance sampling which suppresses the influence of gradient updates from samples that have higher probability of being sampled via alpha parameter and the temporal-difference error.
prioritized_replay_eps: Epsilon parameter sets the baseline probability for sampling so that when the temporal-difference error of a sample is zero, there is still a chance of drawing the sample.
store_buffer_in_checkpoints – Set this to True, if you want the contents of your buffer(s) to be stored in any saved checkpoints as well. Warnings will be created if: - This is True AND restoring from a checkpoint that contains no buffer data. - This is False AND restoring from a checkpoint that does contain buffer data.
epsilon – Epsilon exploration schedule. In the format of [[timestep, value], [timestep, value], …]. A schedule must start from timestep 0.
adam_epsilon – Adam optimizer’s epsilon hyper parameter.
grad_clip – If not None, clip gradients during optimization at this value.
num_steps_sampled_before_learning_starts – Number of timesteps to collect from rollout workers before we start sampling from replay buffers for learning. Whether we count this in agent steps or environment steps depends on config.multi_agent(count_steps_by=..).
tau – Update the target by tau * policy + (1 - tau) * target_policy.
num_atoms – Number of atoms for representing the distribution of return. When this is greater than 1, distributional Q-learning is used.
v_min – Minimum value estimation
v_max – Maximum value estimation
noisy – Whether to use noisy network to aid exploration. This adds parametric noise to the model weights.
sigma0 – Control the initial parameter noise for noisy nets.
dueling – Whether to use dueling DQN.
hiddens – Dense-layer setup for each of the advantage branch and the value branch.
double_q – Whether to use double DQN.
n_step – N-step target updates. If >1, sars’ tuples in trajectories will be postprocessed to become sa[discounted sum of R][s t+n] tuples. An integer will be interpreted as a fixed n-step value. If a tuple of 2 ints is provided here, the n-step value will be drawn for each sample(!) in the train batch from a uniform distribution over the closed interval defined by [n_step[0], n_step[1]].
before_learn_on_batch – Callback to run before learning on a multi-agent batch of experiences.
training_intensity – The intensity with which to update the model (vs collecting samples from the env). If None, uses “natural” values of: train_batch_size / (rollout_fragment_length x num_env_runners x num_envs_per_env_runner). If not None, will make sure that the ratio between timesteps inserted into and sampled from the buffer matches the given values. Example: training_intensity=1000.0, train_batch_size=250, rollout_fragment_length=1, num_env_runners=1 (or 0), num_envs_per_env_runner=1 -> natural value = 250 / 1 = 250.0 -> will make sure that replay+train op will be executed 4x as often as rollout+insert op (4 * 250 = 1000). See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further details.
td_error_loss_fn – “huber” or “mse”. Loss function for calculating TD error when num_atoms is 1. Note that if num_atoms is > 1, this parameter is simply ignored, and softmax cross entropy loss will be used.
categorical_distribution_temperature – Set the temperature parameter used by Categorical action distribution. A valid temperature is in the range of [0, 1]. Note that this mostly affects evaluation since TD error uses argmax for return calculation.
- Returns:
This updated AlgorithmConfig object.
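To illustrate how prioritized_replay_alpha, prioritized_replay_beta, and prioritized_replay_eps interact, here is a minimal sketch of proportional prioritized replay as it is commonly formulated in the literature. It is an assumption that RLlib's buffer follows the same formulas in detail; the function name and the sample TD errors are hypothetical.

```python
from typing import List, Tuple


def per_probabilities_and_weights(
    td_errors: List[float],
    alpha: float = 0.6,
    beta: float = 0.4,
    eps: float = 1e-6,
) -> Tuple[List[float], List[float]]:
    """Return sampling probabilities and importance-sampling weights."""
    # eps keeps samples with zero TD error drawable; alpha sharpens
    # (alpha > 0) or flattens (alpha == 0) the distribution.
    priorities = [(abs(e) + eps) ** alpha for e in td_errors]
    total = sum(priorities)
    probs = [p / total for p in priorities]

    # Importance-sampling weights down-weight frequently drawn samples.
    # They are normalized by the largest weight so that all are <= 1.0.
    n = len(td_errors)
    raw = [(n * p) ** (-beta) for p in probs]
    max_raw = max(raw)
    weights = [w / max_raw for w in raw]
    return probs, weights


# Hypothetical TD errors for four stored samples.
probs, weights = per_probabilities_and_weights([0.0, 1.0, 2.0, 4.0])
print(probs)    # larger TD error -> larger sampling probability
print(weights)  # larger sampling probability -> smaller weight

# alpha=0.0 recovers uniform sampling.
uniform_probs, _ = per_probabilities_and_weights([0.0, 1.0, 2.0, 4.0], alpha=0.0)
print(uniform_probs)
```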
Soft Actor Critic (SAC)#
[original paper], [follow up paper], [implementation].
Tuned examples: Pendulum-v1, HalfCheetah-v3
SAC-specific configs (see also common configs):
- class ray.rllib.algorithms.sac.sac.SACConfig(algo_class=None)[source]#
Defines a configuration class from which an SAC Algorithm can be built.
from ray.rllib.algorithms.sac import SACConfig

config = SACConfig().training(gamma=0.9, lr=0.01, train_batch_size=32)
config = config.resources(num_gpus=0)
config = config.env_runners(num_env_runners=1)

# Build an Algorithm object from the config and run 1 training iteration.
algo = config.build(env="CartPole-v1")
algo.train()
- training(*, twin_q: bool | None = <ray.rllib.utils.from_config._NotProvided object>, q_model_config: ~typing.Dict[str, ~typing.Any] | None = <ray.rllib.utils.from_config._NotProvided object>, policy_model_config: ~typing.Dict[str, ~typing.Any] | None = <ray.rllib.utils.from_config._NotProvided object>, tau: float | None = <ray.rllib.utils.from_config._NotProvided object>, initial_alpha: float | None = <ray.rllib.utils.from_config._NotProvided object>, target_entropy: str | float | None = <ray.rllib.utils.from_config._NotProvided object>, n_step: int | ~typing.Tuple[int, int] | None = <ray.rllib.utils.from_config._NotProvided object>, store_buffer_in_checkpoints: bool | None = <ray.rllib.utils.from_config._NotProvided object>, replay_buffer_config: ~typing.Dict[str, ~typing.Any] | None = <ray.rllib.utils.from_config._NotProvided object>, training_intensity: float | None = <ray.rllib.utils.from_config._NotProvided object>, clip_actions: bool | None = <ray.rllib.utils.from_config._NotProvided object>, grad_clip: float | None = <ray.rllib.utils.from_config._NotProvided object>, optimization_config: ~typing.Dict[str, ~typing.Any] | None = <ray.rllib.utils.from_config._NotProvided object>, actor_lr: float | ~typing.List[~typing.List[int | float]] | None = <ray.rllib.utils.from_config._NotProvided object>, critic_lr: float | ~typing.List[~typing.List[int | float]] | None = <ray.rllib.utils.from_config._NotProvided object>, alpha_lr: float | ~typing.List[~typing.List[int | float]] | None = <ray.rllib.utils.from_config._NotProvided object>, target_network_update_freq: int | None = <ray.rllib.utils.from_config._NotProvided object>, _deterministic_loss: bool | None = <ray.rllib.utils.from_config._NotProvided object>, _use_beta_distribution: bool | None = <ray.rllib.utils.from_config._NotProvided object>, num_steps_sampled_before_learning_starts: int | None = <ray.rllib.utils.from_config._NotProvided object>, **kwargs) SACConfig [source]#
Sets the training related configuration.
- Parameters:
twin_q – Use two Q-networks (instead of one) for action-value estimation. Note: Each Q-network will have its own target network.
q_model_config – Model configs for the Q network(s). These will override MODEL_DEFAULTS. This is treated just as the top-level model dict in setting up the Q-network(s) (2 if twin_q=True). That means, you can do for different observation spaces:
obs=Box(1D) -> Tuple(Box(1D) + Action) -> concat -> post_fcnet
obs=Box(3D) -> Tuple(Box(3D) + Action) -> vision-net -> concat w/ action -> post_fcnet
obs=Tuple(Box(1D), Box(3D)) -> Tuple(Box(1D), Box(3D), Action) -> vision-net -> concat w/ Box(1D) and action -> post_fcnet
You can also have SAC use your custom_model as Q-model(s), by simply specifying the custom_model sub-key in below dict (just like you would do in the top-level model dict).
policy_model_config – Model options for the policy function (see q_model_config above for details). The difference to q_model_config above is that no action concat’ing is performed before the post_fcnet stack.
tau – Update the target by tau * policy + (1 - tau) * target_policy.
initial_alpha – Initial value to use for the entropy weight alpha.
target_entropy – Target entropy lower bound. If “auto”, will be set to -|A| (e.g. -2.0 for Discrete(2), -3.0 for Box(shape=(3,))). This is the inverse of reward scale, and will be optimized automatically.
n_step – N-step target updates. If >1, sars’ tuples in trajectories will be postprocessed to become sa[discounted sum of R][s t+n] tuples. An integer will be interpreted as a fixed n-step value. If a tuple of 2 ints is provided here, the n-step value will be drawn for each sample(!) in the train batch from a uniform distribution over the closed interval defined by [n_step[0], n_step[1]].
store_buffer_in_checkpoints – Set this to True, if you want the contents of your buffer(s) to be stored in any saved checkpoints as well. Warnings will be created if: - This is True AND restoring from a checkpoint that contains no buffer data. - This is False AND restoring from a checkpoint that does contain buffer data.
replay_buffer_config – Replay buffer config. Examples: { "_enable_replay_buffer_api": True, "type": "MultiAgentReplayBuffer", "capacity": 50000, "replay_batch_size": 32, "replay_sequence_length": 1, } - OR - { "_enable_replay_buffer_api": True, "type": "MultiAgentPrioritizedReplayBuffer", "capacity": 50000, "prioritized_replay_alpha": 0.6, "prioritized_replay_beta": 0.4, "prioritized_replay_eps": 1e-6, "replay_sequence_length": 1, } - Where -
prioritized_replay_alpha: Alpha parameter controls the degree of prioritization in the buffer. In other words, when a buffer sample has a higher temporal-difference error, with how much more probability should it be drawn to update the parametrized Q-network. 0.0 corresponds to uniform probability. Setting it much above 1.0 can quickly make the sampling distribution heavily “pointy”, with low entropy.
prioritized_replay_beta: Beta parameter controls the degree of importance sampling which suppresses the influence of gradient updates from samples that have higher probability of being sampled via alpha parameter and the temporal-difference error.
prioritized_replay_eps: Epsilon parameter sets the baseline probability for sampling so that when the temporal-difference error of a sample is zero, there is still a chance of drawing the sample.
training_intensity – The intensity with which to update the model (vs collecting samples from the env). If None, uses “natural” values of: train_batch_size / (rollout_fragment_length x num_env_runners x num_envs_per_env_runner). If not None, will make sure that the ratio between timesteps inserted into and sampled from the buffer matches the given values. Example: training_intensity=1000.0, train_batch_size=250, rollout_fragment_length=1, num_env_runners=1 (or 0), num_envs_per_env_runner=1 -> natural value = 250 / 1 = 250.0 -> will make sure that replay+train op will be executed 4x as often as rollout+insert op (4 * 250 = 1000). See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further details.
clip_actions – Whether to clip actions. If actions are already normalized, this should be set to False.
grad_clip – If not None, clip gradients during optimization at this value.
optimization_config – Config dict for optimization. Set the supported keys actor_learning_rate, critic_learning_rate, and entropy_learning_rate in here.
actor_lr – The learning rate (float) or learning rate schedule for the policy in the format of [[timestep, lr-value], [timestep, lr-value], …]. In case of a schedule, intermediary timesteps will be assigned to linearly interpolated learning rate values. A schedule config’s first entry must start with timestep 0, i.e.: [[0, initial_value], […]]. Note: It is common practice (two-timescale approach) to use a smaller learning rate for the policy than for the critic to ensure that the critic gives adequate values for improving the policy. Note: If you require a) more than one optimizer (per RLModule), b) optimizer types that are not Adam, c) a learning rate schedule that is not a linearly interpolated, piecewise schedule as described above, or d) specifying c’tor arguments of the optimizer that are not the learning rate (e.g. Adam’s epsilon), then you must override your Learner’s configure_optimizer_for_module() method and handle lr-scheduling yourself. The default value is 3e-5, one order of magnitude lower than the respective learning rate of the critic (see critic_lr).
critic_lr – The learning rate (float) or learning rate schedule for the critic in the format of [[timestep, lr-value], [timestep, lr-value], …]. In case of a schedule, intermediary timesteps will be assigned to linearly interpolated learning rate values. A schedule config’s first entry must start with timestep 0, i.e.: [[0, initial_value], […]]. Note: It is common practice (two-timescale approach) to use a smaller learning rate for the policy than for the critic to ensure that the critic gives adequate values for improving the policy. Note: If you require a) more than one optimizer (per RLModule), b) optimizer types that are not Adam, c) a learning rate schedule that is not a linearly interpolated, piecewise schedule as described above, or d) specifying c’tor arguments of the optimizer that are not the learning rate (e.g. Adam’s epsilon), then you must override your Learner’s configure_optimizer_for_module() method and handle lr-scheduling yourself. The default value is 3e-4, one order of magnitude higher than the respective learning rate of the actor (policy) (see actor_lr).
alpha_lr – The learning rate (float) or learning rate schedule for the hyperparameter alpha in the format of [[timestep, lr-value], [timestep, lr-value], …]. In case of a schedule, intermediary timesteps will be assigned to linearly interpolated learning rate values. A schedule config’s first entry must start with timestep 0, i.e.: [[0, initial_value], […]]. Note: If you require a) more than one optimizer (per RLModule), b) optimizer types that are not Adam, c) a learning rate schedule that is not a linearly interpolated, piecewise schedule as described above, or d) specifying c’tor arguments of the optimizer that are not the learning rate (e.g. Adam’s epsilon), then you must override your Learner’s configure_optimizer_for_module() method and handle lr-scheduling yourself. The default value is 3e-4, identical to the critic learning rate (lr).
target_network_update_freq – Update the target network every target_network_update_freq steps.
_deterministic_loss – Whether the loss should be calculated deterministically (w/o the stochastic action sampling step). True only useful for continuous actions and for debugging.
_use_beta_distribution – Use a Beta-distribution instead of a SquashedGaussian for bounded, continuous action spaces (not recommended; for debugging only).
- Returns:
This updated AlgorithmConfig object.
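The [[timestep, value], …] schedule format used by actor_lr, critic_lr, and alpha_lr can be read as a piecewise-linear function of the timestep. The sketch below shows one way to evaluate such a schedule; `piecewise_linear` is a hypothetical helper, not an RLlib function, and holding the last value after the final entry is an assumption.

```python
from typing import Sequence


def piecewise_linear(schedule: Sequence[Sequence[float]], timestep: int) -> float:
    """Evaluate a [[timestep, value], ...] schedule at `timestep`."""
    if schedule[0][0] != 0:
        raise ValueError("A schedule's first entry must start with timestep 0.")
    for (t0, v0), (t1, v1) in zip(schedule[:-1], schedule[1:]):
        if t0 <= timestep < t1:
            # Linear interpolation between the two surrounding entries.
            fraction = (timestep - t0) / (t1 - t0)
            return v0 + fraction * (v1 - v0)
    # Assumption: past the last entry, the final value is held constant.
    return schedule[-1][1]


# Hypothetical schedule: decay from 3e-5 to 3e-6 over 10000 timesteps.
actor_lr_schedule = [[0, 3e-5], [10000, 3e-6]]
for ts in (0, 5000, 10000, 20000):
    print(ts, piecewise_linear(actor_lr_schedule, ts))
```

A schedule like `actor_lr_schedule` would be passed in place of a float, for example as the actor_lr argument.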
High-Throughput On- and Off-Policy#
Importance Weighted Actor-Learner Architecture (IMPALA)#
Tuned examples: PongNoFrameskip-v4, vectorized configuration, multi-gpu configuration, {BeamRider,Breakout,Qbert,SpaceInvaders}NoFrameskip-v4.
IMPALA-specific configs (see also common configs):
- ray.rllib.algorithms.impala.impala.ImpalaConfig#
alias of IMPALAConfig
Asynchronous Proximal Policy Optimization (APPO)#
Tip
APPO isn’t always more efficient; it’s often better to use standard PPO or IMPALA.
Tuned examples: PongNoFrameskip-v4
APPO-specific configs (see also common configs):
- class ray.rllib.algorithms.appo.appo.APPOConfig(algo_class=None)[source]#
Defines a configuration class from which an APPO Algorithm can be built.
from ray.rllib.algorithms.appo import APPOConfig

config = APPOConfig().training(lr=0.01, grad_clip=30.0, train_batch_size=50)
config = config.resources(num_gpus=0)
config = config.env_runners(num_env_runners=1)
config = config.environment("CartPole-v1")

# Build an Algorithm object from the config and run 1 training iteration.
algo = config.build()
algo.train()
del algo

from ray.rllib.algorithms.appo import APPOConfig
from ray import air
from ray import tune

config = APPOConfig()
# Update the config object.
config = config.training(lr=tune.grid_search([0.001,]))
# Set the config object's env.
config = config.environment(env="CartPole-v1")

# Use to_dict() to get the old-style python config dict
# when running with tune.
tune.Tuner(
    "APPO",
    run_config=air.RunConfig(stop={"training_iteration": 1}, verbose=0),
    param_space=config.to_dict(),
).fit()
- training(*, vtrace: bool | None = <ray.rllib.utils.from_config._NotProvided object>, use_critic: bool | None = <ray.rllib.utils.from_config._NotProvided object>, use_gae: bool | None = <ray.rllib.utils.from_config._NotProvided object>, lambda_: float | None = <ray.rllib.utils.from_config._NotProvided object>, clip_param: float | None = <ray.rllib.utils.from_config._NotProvided object>, use_kl_loss: bool | None = <ray.rllib.utils.from_config._NotProvided object>, kl_coeff: float | None = <ray.rllib.utils.from_config._NotProvided object>, kl_target: float | None = <ray.rllib.utils.from_config._NotProvided object>, tau: float | None = <ray.rllib.utils.from_config._NotProvided object>, target_network_update_freq: int | None = <ray.rllib.utils.from_config._NotProvided object>, target_update_frequency=-1, **kwargs) APPOConfig [source]#
Sets the training related configuration.
- Parameters:
vtrace – Whether to use V-trace weighted advantages. If false, PPO GAE advantages will be used instead.
use_critic – Should use a critic as a baseline (otherwise don’t use value baseline; required for using GAE). Only applies if vtrace=False.
use_gae – If true, use the Generalized Advantage Estimator (GAE) with a value function, see https://arxiv.org/pdf/1506.02438.pdf. Only applies if vtrace=False.
lambda_ – GAE (lambda) parameter.
clip_param – PPO surrogate clipping parameter.
use_kl_loss – Whether to use the KL-term in the loss function.
kl_coeff – Coefficient for weighting the KL-loss term.
kl_target – Target term for the KL-term to reach (via adjusting the kl_coeff automatically).
tau – The factor by which to update the target policy network towards the current policy network. Can range between 0 and 1, e.g. updated_param = tau * current_param + (1 - tau) * target_param.
target_network_update_freq – The frequency to update the target policy and tune the kl loss coefficients that are used during training. After setting this parameter, the algorithm waits for at least target_network_update_freq * minibatch_size * num_sgd_iter number of samples to be trained on by the learner group before updating the target networks and tuning the kl loss coefficients that are used during training. NOTE: This parameter is only applicable when using the Learner API (enable_rl_module_and_learner=True).
- Returns:
This updated AlgorithmConfig object.
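The tau update rule quoted in the parameter list is a standard soft (Polyak) update. A minimal sketch on plain Python lists follows; `soft_update` is a hypothetical helper for illustration, whereas real code would apply the same rule to each network tensor.

```python
from typing import List


def soft_update(
    target_params: List[float],
    current_params: List[float],
    tau: float,
) -> List[float]:
    """Move target parameters towards the current parameters by factor tau."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must be in [0, 1]")
    # updated_param = tau * current_param + (1 - tau) * target_param
    return [
        tau * current + (1.0 - tau) * target
        for current, target in zip(current_params, target_params)
    ]


target = [0.0, 0.0]
current = [1.0, 2.0]
print(soft_update(target, current, tau=1.0))  # hard copy of current
print(soft_update(target, current, tau=0.0))  # target stays unchanged
print(soft_update(target, current, tau=0.5))  # halfway between the two
```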
Model-based RL#
DreamerV3#
Tuned examples: Atari 100k, Atari 200M, DeepMind Control Suite
Pong-v5 results (1, 2, and 4 GPUs):
Atari 100k results (1 vs 4 GPUs):
DeepMind Control Suite (vision) results (1 vs 4 GPUs):
Offline RL and Imitation Learning#
Behavior Cloning (BC)#
Tuned examples: CartPole-v1, Pendulum-v1
BC-specific configs (see also common configs):
- class ray.rllib.algorithms.bc.bc.BCConfig(algo_class=None)[source]#
Defines a configuration class from which a new BC Algorithm can be built
from ray.rllib.algorithms.bc import BCConfig

# Run this from the ray directory root.
config = BCConfig().training(lr=0.00001, gamma=0.99)
config = config.offline_data(
    input_="./rllib/tests/data/cartpole/large.json")

# Build an Algorithm object from the config and run 1 training iteration.
algo = config.build()
algo.train()

from ray.rllib.algorithms.bc import BCConfig
from ray import tune

config = BCConfig()
# Print out some default values.
print(config.beta)
# Update the config object.
config.training(
    lr=tune.grid_search([0.001, 0.0001]), beta=0.75
)
# Set the config object's data path.
# Run this from the ray directory root.
config.offline_data(
    input_="./rllib/tests/data/cartpole/large.json"
)
# Set the config object's env, used for evaluation.
config.environment(env="CartPole-v1")

# Use to_dict() to get the old-style python config dict
# when running with tune.
tune.Tuner(
    "BC",
    param_space=config.to_dict(),
).fit()
- training(*, beta: float | None = <ray.rllib.utils.from_config._NotProvided object>, bc_logstd_coeff: float | None = <ray.rllib.utils.from_config._NotProvided object>, moving_average_sqd_adv_norm_update_rate: float | None = <ray.rllib.utils.from_config._NotProvided object>, moving_average_sqd_adv_norm_start: float | None = <ray.rllib.utils.from_config._NotProvided object>, vf_coeff: float | None = <ray.rllib.utils.from_config._NotProvided object>, grad_clip: float | None = <ray.rllib.utils.from_config._NotProvided object>, **kwargs) MARWILConfig #
Sets the training related configuration.
- Parameters:
beta – Scaling of advantages in exponential terms. When beta is 0.0, MARWIL is reduced to behavior cloning (imitation learning); see bc.py algorithm in this same directory.
bc_logstd_coeff – A coefficient to encourage higher action distribution entropy for exploration.
moving_average_sqd_adv_norm_start – Starting value for the squared moving average advantage norm (c^2).
vf_coeff – Balancing value estimation loss and policy optimization loss.
moving_average_sqd_adv_norm_update_rate – Update rate for the squared moving average advantage norm (c^2).
grad_clip – If specified, clip the global norm of gradients by this amount.
- Returns:
This updated AlgorithmConfig object.
Monotonic Advantage Re-Weighted Imitation Learning (MARWIL)#
Tuned examples: CartPole-v1
MARWIL-specific configs (see also common configs):
- class ray.rllib.algorithms.marwil.marwil.MARWILConfig(algo_class=None)[source]#
Defines a configuration class from which a MARWIL Algorithm can be built.
Example
>>> from ray.rllib.algorithms.marwil import MARWILConfig
>>> # Run this from the ray directory root.
>>> config = MARWILConfig()
>>> config = config.training(beta=1.0, lr=0.00001, gamma=0.99)
>>> config = config.offline_data(
...     input_=["./rllib/tests/data/cartpole/large.json"])
>>> print(config.to_dict())
...
>>> # Build an Algorithm object from the config and run 1 training iteration.
>>> algo = config.build()
>>> algo.train()
Example
>>> from ray.rllib.algorithms.marwil import MARWILConfig
>>> from ray import tune
>>> config = MARWILConfig()
>>> # Print out some default values.
>>> print(config.beta)
>>> # Update the config object.
>>> config.training(lr=tune.grid_search(
...     [0.001, 0.0001]), beta=0.75)
>>> # Set the config object's data path.
>>> # Run this from the ray directory root.
>>> config.offline_data(
...     input_=["./rllib/tests/data/cartpole/large.json"])
>>> # Set the config object's env, used for evaluation.
>>> config.environment(env="CartPole-v1")
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.Tuner(
...     "MARWIL",
...     param_space=config.to_dict(),
... ).fit()
- training(*, beta: float | None = <ray.rllib.utils.from_config._NotProvided object>, bc_logstd_coeff: float | None = <ray.rllib.utils.from_config._NotProvided object>, moving_average_sqd_adv_norm_update_rate: float | None = <ray.rllib.utils.from_config._NotProvided object>, moving_average_sqd_adv_norm_start: float | None = <ray.rllib.utils.from_config._NotProvided object>, vf_coeff: float | None = <ray.rllib.utils.from_config._NotProvided object>, grad_clip: float | None = <ray.rllib.utils.from_config._NotProvided object>, **kwargs) MARWILConfig [source]#
Sets the training related configuration.
- Parameters:
beta – Scaling of advantages in exponential terms. When beta is 0.0, MARWIL is reduced to behavior cloning (imitation learning); see bc.py algorithm in this same directory.
bc_logstd_coeff – A coefficient to encourage higher action distribution entropy for exploration.
moving_average_sqd_adv_norm_start – Starting value for the squared moving average advantage norm (c^2).
vf_coeff – Balancing value estimation loss and policy optimization loss.
moving_average_sqd_adv_norm_update_rate – Update rate for the squared moving average advantage norm (c^2).
grad_clip – If specified, clip the global norm of gradients by this amount.
- Returns:
This updated AlgorithmConfig object.
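The roles of beta and the squared moving average advantage norm (c^2) can be sketched in a few lines. This is an illustrative reading of the parameter descriptions, not RLlib's loss code: the function name, the order of the norm update and the weight computation, and the small stabilizing constant are all assumptions.

```python
import math
from typing import List, Tuple


def marwil_weights(
    advantages: List[float],
    beta: float,
    sqd_adv_norm: float,
    update_rate: float = 1e-8,
) -> Tuple[List[float], float]:
    """Return per-sample imitation weights and the updated c^2 estimate."""
    # Move the running estimate of c^2 towards the batch mean of A^2.
    mean_sqd_adv = sum(a * a for a in advantages) / len(advantages)
    new_norm = sqd_adv_norm + update_rate * (mean_sqd_adv - sqd_adv_norm)
    # Exponential advantage weighting: exp(beta * A / c).
    # The 1e-8 guards against division by zero (an assumption).
    c = math.sqrt(new_norm)
    weights = [math.exp(beta * a / (1e-8 + c)) for a in advantages]
    return weights, new_norm


advantages = [-1.0, 0.0, 2.0]  # hypothetical advantage estimates

# beta=0.0: all weights are 1.0, i.e. plain behavior cloning.
bc_weights, _ = marwil_weights(advantages, beta=0.0, sqd_adv_norm=100.0)
# beta=1.0: samples with higher advantage get larger weights.
mw_weights, _ = marwil_weights(advantages, beta=1.0, sqd_adv_norm=100.0)
print(bc_weights)
print(mw_weights)
```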
Algorithm Extensions and Plugins#
Curiosity-driven Exploration by Self-supervised Prediction#
Tuned examples: 12x12 FrozenLake-v1