Note
Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The team is currently transitioning algorithms, example scripts, and documentation to the new code base throughout the subsequent minor releases leading up to Ray 3.0.
See the new API stack documentation for more details on how to activate and use the new API stack.
Getting Started with RLlib
All RLlib experiments are run using an Algorithm class, which holds a policy for environment interaction. Through the algorithm’s interface, you can train the policy, compute actions, or store your algorithm’s state (checkpointing). In multi-agent training, the algorithm manages the querying and optimization of multiple policies at once.
This guide explains RLlib’s Python API for running learning experiments in detail.
Using the Python API
The Python API provides all the flexibility required for applying RLlib to any type of problem.
Let’s start with an example of the API’s basic usage. We first create a PPOConfig instance and set some properties through the config class’ various methods. For example, we can set the RL environment we want to use by calling the config’s environment method. To scale our algorithm and define how many environment workers (EnvRunners) we want to leverage, we can call the env_runners method. After we build the PPO Algorithm from its configuration, we can train it for a number of iterations (here 10) and save the resulting policy periodically (here whenever the loop index i is a multiple of 5, that is, after the first and the sixth iteration).
```python
from pprint import pprint

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .env_runners(num_env_runners=1)
)

algo = config.build()

for i in range(10):
    result = algo.train()
    # Drop the (large) config entry before printing; no error if it is absent.
    result.pop("config", None)
    pprint(result)

    if i % 5 == 0:
        checkpoint_dir = algo.save_to_path()
        print(f"Checkpoint saved in directory {checkpoint_dir}")
```
All RLlib algorithms are compatible with the Tune API. This enables them to be easily used in experiments with Ray Tune. For example, the following code performs a simple hyper-parameter sweep of PPO.
```python
from ray import train, tune
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .training(
        lr=tune.grid_search([0.01, 0.001, 0.0001]),
    )
)

tuner = tune.Tuner(
    "PPO",
    param_space=config,
    run_config=train.RunConfig(
        stop={"env_runners/episode_return_mean": 150.0},
    ),
)

tuner.fit()
```
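The stop criterion above uses the key "env_runners/episode_return_mean", where "/" addresses a value nested inside the result dict that algo.train() returns (here, result["env_runners"]["episode_return_mean"]). The following pure-Python helper is a minimal sketch of that lookup, handy when inspecting results from the training loop earlier. get_metric is a hypothetical helper, not part of RLlib or Tune, and the result dict in the demo is hand-written, not real output.

```python
from typing import Any, Mapping


def get_metric(result: Mapping[str, Any], key: str, sep: str = "/") -> Any:
    """Look up a '/'-separated metric key in a nested result dict.

    Hypothetical helper for illustration; raises KeyError if any
    path component is missing.
    """
    value: Any = result
    for part in key.split(sep):
        value = value[part]
    return value


# A hand-written, trimmed-down stand-in for a result dict (not real output).
fake_result = {"env_runners": {"episode_return_mean": 123.5}}
print(get_metric(fake_result, "env_runners/episode_return_mean"))  # 123.5
print(get_metric(fake_result, "env_runners/episode_return_mean") >= 150.0)  # False
```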
Tune will schedule the trials to run in parallel on your Ray cluster:
```
== Status ==
Using FIFO scheduling algorithm.
Resources requested: 4/4 CPUs, 0/0 GPUs
Result logdir: ~/ray_results/my_experiment
PENDING trials:
 - PPO_CartPole-v1_2_lr=0.0001: PENDING
RUNNING trials:
 - PPO_CartPole-v1_0_lr=0.01: RUNNING [pid=21940], 16 s, 4013 ts, 22 rew
 - PPO_CartPole-v1_1_lr=0.001: RUNNING [pid=21942], 27 s, 8111 ts, 54.7 rew
```
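The status output shows three trials because tune.grid_search() creates one trial per listed value. When several parameters use grid_search, Tune runs every combination of them (the Cartesian product). The sketch below reproduces that expansion in plain Python with itertools.product; expand_grid is a hypothetical helper, and the gamma grid is an illustrative addition that is not part of the sweep above.

```python
import itertools
from typing import Any, Dict, List


def expand_grid(grid: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Return one config dict per combination of the grid values."""
    keys = list(grid)
    return [dict(zip(keys, combo)) for combo in itertools.product(*grid.values())]


# Learning rates from the sweep above, plus a hypothetical gamma grid.
trials = expand_grid({"lr": [0.01, 0.001, 0.0001], "gamma": [0.95, 0.99]})
print(len(trials))  # 6 trials (3 x 2)
print(trials[0])    # {'lr': 0.01, 'gamma': 0.95}
```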
Tuner.fit() returns a ResultGrid object that allows further analysis of the training results and retrieving the checkpoint(s) of the trained agent.
```python
from ray import train, tune

# RunConfig also accepts a `storage_path` argument if you want results stored
# somewhere other than the default ~/ray_results directory.
tuner = tune.Tuner(
    "PPO",
    param_space=config,
    run_config=train.RunConfig(
        stop={"num_env_steps_sampled_lifetime": 20000},
        checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True),
    ),
)

results = tuner.fit()

# Get the best result based on a particular metric.
best_result = results.get_best_result(
    metric="env_runners/episode_return_mean", mode="max"
)

# Get the best checkpoint corresponding to the best result.
best_checkpoint = best_result.checkpoint
```
Note
You can find your checkpoint’s version by looking into the rllib_checkpoint.json file inside your checkpoint directory.
Loading and restoring a trained algorithm from a checkpoint is simple. Let’s assume you have a local checkpoint directory called checkpoint_path. To load newer RLlib checkpoints (version >= 1.0), use the following code:
```python
from ray.rllib.algorithms.algorithm import Algorithm

algo = Algorithm.from_checkpoint(checkpoint_path)
```
For older RLlib checkpoint versions (version < 1.0), you can restore an algorithm through:
```python
from ray.rllib.algorithms.ppo import PPO

algo = PPO(config=config, env=env_class)
algo.restore(checkpoint_path)
```
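Because the right loading path depends on the checkpoint version stored in rllib_checkpoint.json, you may want to check it programmatically. The sketch below reads that file and reports whether Algorithm.from_checkpoint() applies. It assumes the version is stored under a "checkpoint_version" key as a dotted string such as "1.1"; inspect the file in your own checkpoint before relying on that key name. Both helper functions are hypothetical, and the demo writes a made-up JSON file to a temporary directory.

```python
import json
import pathlib
import tempfile
from typing import Tuple


def read_checkpoint_version(checkpoint_path: str) -> Tuple[int, ...]:
    """Parse the version from <checkpoint_path>/rllib_checkpoint.json.

    Assumes a "checkpoint_version" key holding a dotted string like "1.1".
    """
    info_file = pathlib.Path(checkpoint_path) / "rllib_checkpoint.json"
    with info_file.open() as f:
        info = json.load(f)
    return tuple(int(part) for part in str(info["checkpoint_version"]).split("."))


def uses_from_checkpoint(checkpoint_path: str) -> bool:
    """True if the checkpoint is new enough (>= 1.0) for Algorithm.from_checkpoint()."""
    return read_checkpoint_version(checkpoint_path)[0] >= 1


# Demo with a fake checkpoint directory (the JSON content is made up).
with tempfile.TemporaryDirectory() as tmp:
    (pathlib.Path(tmp) / "rllib_checkpoint.json").write_text(
        json.dumps({"checkpoint_version": "1.1"})
    )
    print(read_checkpoint_version(tmp))  # (1, 1)
    print(uses_from_checkpoint(tmp))     # True
```

With a real checkpoint, you would call Algorithm.from_checkpoint(checkpoint_path) when uses_from_checkpoint(checkpoint_path) is True, and fall back to the restore() route otherwise.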
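Computing Actions
Algorithm.compute_single_action() offers a simple way to programmatically compute actions from a trained agent; it preprocesses and filters the observation before passing it to the agent’s policy.
The following example takes a different route: it restores only the neural network (the RLModule) from the best checkpoint of the Tune run above and calls its forward_inference() method directly to run the trained agent for one episode: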
```python
import pathlib

import gymnasium as gym
import numpy as np
import torch
from ray.rllib.core.rl_module import RLModule

env = gym.make("CartPole-v1")

# Create only the neural network (RLModule) from our checkpoint.
rl_module = RLModule.from_checkpoint(
    pathlib.Path(best_checkpoint.path) / "learner_group" / "learner" / "rl_module"
)["default_policy"]

episode_return = 0
terminated = truncated = False

obs, info = env.reset()

while not terminated and not truncated:
    # Compute the next action from a batch (B=1) of observations.
    torch_obs_batch = torch.from_numpy(np.array([obs]))
    action_logits = rl_module.forward_inference({"obs": torch_obs_batch})[
        "action_dist_inputs"
    ]
    # The default RLModule used here produces action logits (from which
    # we'll have to sample an action or use the max-likelihood one).
    action = torch.argmax(action_logits[0]).numpy()
    obs, reward, terminated, truncated, info = env.step(action)
    episode_return += reward

print(f"Reached episode return of {episode_return}.")
```
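The comment in the loop notes that the RLModule returns action logits, from which you either pick the most likely action (as the code does with torch.argmax) or sample one. The framework-agnostic sketch below shows both options for a discrete action space in plain Python. select_action and softmax are hypothetical helpers; converting the torch logits to a Python list first (for example with action_logits[0].tolist()) is assumed.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the largest logit before exponentiating."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(
    logits: Sequence[float],
    greedy: bool = True,
    rng: Optional[random.Random] = None,
) -> int:
    """Return the most likely action (greedy) or sample one from softmax(logits)."""
    if greedy:
        # Same choice as torch.argmax in the loop above.
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    return rng.choices(range(len(logits)), weights=softmax(logits), k=1)[0]


logits = [0.2, 1.5]  # made-up logits for CartPole's two actions
print(select_action(logits))  # 1 (the larger logit)
print(select_action(logits, greedy=False, rng=random.Random(0)))  # sampled: 0 or 1
```

Greedy selection is deterministic and usually what you want for evaluation, while sampling keeps the stochasticity the policy was trained with.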
For more advanced ways of computing actions and other functionality, consult the RLlib Algorithm API documentation.
Accessing Policy State
It is common to need to access an algorithm’s internal state, for instance to set or get model weights.
In RLlib, algorithm state is replicated across multiple EnvRunners (rollout workers, implemented as Ray actors) in the cluster. However, you can easily get and update this state between calls to train() via Algorithm.env_runner_group.foreach_worker() or Algorithm.env_runner_group.foreach_worker_with_id(). These functions take a function (typically a lambda) that is called with each worker as its argument (foreach_worker_with_id() additionally passes the worker’s ID), and they return the results for all workers as a list.
You can also access just the “master” copy of the algorithm state through Algorithm.get_module() (or Algorithm.get_policy() on the old API stack) or Algorithm.env_runner, but note that updates made there may not be immediately reflected in your remote EnvRunners (if you have configured num_env_runners > 0).
Here’s a quick example of how to access the state of a model:
```python
from ray.rllib.algorithms.ppo import PPOConfig

algo = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .env_runners(num_env_runners=2)
).build()

# Get weights of the algo's RLModule.
algo.get_module().get_state()

# Same as above
algo.env_runner.module.get_state()

# Get list of weights of each EnvRunner, including remote replicas.
algo.env_runner_group.foreach_worker(lambda env_runner: env_runner.module.get_state())

# Same as above, but the function also receives each EnvRunner's ID.
algo.env_runner_group.foreach_worker_with_id(
    lambda _id, env_runner: env_runner.module.get_state()
)
```
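To make the calling convention concrete without starting Ray, here is a pure-Python mock of the pattern: a group holding one local and several remote EnvRunner stand-ins, a foreach_worker() that applies a function to every worker and returns the results as a list, and a foreach_worker_with_id() variant that also passes each worker's ID. MockEnvRunner and MockEnvRunnerGroup are hypothetical classes that only imitate the call shape, not RLlib's remote execution; the IDs here are simply list positions. The demo also shows the caveat from the text: changing only the local copy leaves the other replicas stale until you push the new state to every worker.

```python
from typing import Any, Callable, Dict, List


class MockEnvRunner:
    """Hypothetical stand-in for an EnvRunner holding a copy of the module weights."""

    def __init__(self, weights: Dict[str, float]):
        self.weights = dict(weights)

    def get_state(self) -> Dict[str, float]:
        return dict(self.weights)

    def set_state(self, weights: Dict[str, float]) -> None:
        self.weights = dict(weights)


class MockEnvRunnerGroup:
    """Hypothetical group: position 0 is the local runner, the rest mimic remote replicas."""

    def __init__(self, num_remote: int, weights: Dict[str, float]):
        self.workers = [MockEnvRunner(weights) for _ in range(num_remote + 1)]

    @property
    def local(self) -> MockEnvRunner:
        return self.workers[0]

    def foreach_worker(self, func: Callable[[MockEnvRunner], Any]) -> List[Any]:
        return [func(w) for w in self.workers]

    def foreach_worker_with_id(
        self, func: Callable[[int, MockEnvRunner], Any]
    ) -> List[Any]:
        return [func(i, w) for i, w in enumerate(self.workers)]


group = MockEnvRunnerGroup(num_remote=2, weights={"w": 0.0})
group.local.set_state({"w": 1.0})  # update only the local copy
print(group.foreach_worker(lambda r: r.get_state()["w"]))  # [1.0, 0.0, 0.0]

# Broadcast the local state so every replica matches again.
new_state = group.local.get_state()
group.foreach_worker(lambda r: r.set_state(new_state))
print(group.foreach_worker_with_id(lambda i, r: (i, r.get_state()["w"])))
# [(0, 1.0), (1, 1.0), (2, 1.0)]
```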
Accessing Model State
Similar to accessing policy state, you may want to get a reference to the underlying neural network model being trained. For example, you may want to pre-train it separately, or otherwise update its weights outside of RLlib. On the old API stack, which the following examples use, you can do this by accessing the model of the policy.
Note
To run these examples, you need to install a few extra dependencies, namely: pip install "gym[atari]" "gym[accept-rom-license]" atari_py
Below you find three explicit examples showing how to access the model state of an algorithm.
Example: Preprocessing observations for feeding into a model
```python
try:
    import gymnasium as gym

    env = gym.make("ale_py:ALE/Pong-v5")
    obs, infos = env.reset()
except Exception:
    import gym

    env = gym.make("PongNoFrameskip-v4")
    obs = env.reset()

# RLlib uses preprocessors to implement transforms such as one-hot encoding
# and flattening of tuple and dict observations.
from ray.rllib.models.preprocessors import get_preprocessor

prep = get_preprocessor(env.observation_space)(env.observation_space)
# <ray.rllib.models.preprocessors.GenericPixelPreprocessor object at 0x7fc4d049de80>

# Observations should be preprocessed prior to feeding into a model
obs.shape
# (210, 160, 3)
prep.transform(obs).shape
# (84, 84, 3)
```
Example: Querying a policy’s action distribution
```python
# Get a reference to the policy
import numpy as np
import torch
from ray.rllib.algorithms.ppo import PPOConfig

algo = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
    )
    .framework("torch")
    .environment("CartPole-v1")
    .env_runners(num_env_runners=0)
).build()
# <ray.rllib.algorithms.ppo.PPO object at 0x7fd020186384>

policy = algo.get_policy()
# -> the PPO torch policy object

# Run a forward pass to get model output logits. Note that complex observations
# must be preprocessed as in the above code block. The model expects float32 input.
logits, _ = policy.model(
    {"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]], dtype=np.float32))}
)
# -> logits tensor of shape (1, 2), plus an empty state list

# Compute action distribution given logits
policy.dist_class
# -> the Categorical action-distribution class for discrete actions
dist = policy.dist_class(logits, policy.model)

# Query the distribution for samples, sample logps
dist.sample()
# -> tensor of shape (1,) holding the sampled action
dist.logp(torch.tensor([1]))
# -> tensor of shape (1,) holding the log-probability of action 1

# Get the estimated values for the most recent forward pass
policy.model.value_function()
# -> tensor of shape (1,)

print(policy.model)
# The example output below comes from a TensorFlow (Keras) version of this
# model; the torch model prints its torch module structure instead.
"""
Model: "model"
_____________________________________________________________________
Layer (type) Output Shape Param # Connected to
=====================================================================
observations (InputLayer) [(None, 4)] 0
_____________________________________________________________________
fc_1 (Dense) (None, 256) 1280 observations[0][0]
_____________________________________________________________________
fc_value_1 (Dense) (None, 256) 1280 observations[0][0]
_____________________________________________________________________
fc_2 (Dense) (None, 256) 65792 fc_1[0][0]
_____________________________________________________________________
fc_value_2 (Dense) (None, 256) 65792 fc_value_1[0][0]
_____________________________________________________________________
fc_out (Dense) (None, 2) 514 fc_2[0][0]
_____________________________________________________________________
value_out (Dense) (None, 1) 257 fc_value_2[0][0]
=====================================================================
Total params: 134,915
Trainable params: 134,915
Non-trainable params: 0
_____________________________________________________________________
"""
```
Example: Getting Q values from a DQN model
```python
# Get a reference to the model through the policy
import numpy as np
import torch
from ray.rllib.algorithms.dqn import DQNConfig

algo = (
    DQNConfig()
    .api_stack(
        enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False
    )
    .framework("torch")
    .environment("CartPole-v1")
    .training(
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
        }
    )
).build()
model = algo.get_policy().model
# -> the policy's DQN model

# List of all model variables
list(model.parameters())

# Run a forward pass to get base model output. Note that complex observations
# must be preprocessed. An example of preprocessing is
# examples/offline_rl/saving_experiences.py
model_out = model(
    {"obs": torch.from_numpy(np.array([[0.1, 0.2, 0.3, 0.4]], dtype=np.float32))}
)
# -> (base output tensor of shape (1, 256), state list)

# Print the model's structure. The example output below comes from a
# TensorFlow (Keras) version of this model; the torch model prints its
# torch module structure instead.
print(model)
"""
Model: "model"
_______________________________________________________________________
Layer (type) Output Shape Param # Connected to
=======================================================================
observations (InputLayer) [(None, 4)] 0
_______________________________________________________________________
fc_1 (Dense) (None, 256) 1280 observations[0][0]
_______________________________________________________________________
fc_out (Dense) (None, 256) 65792 fc_1[0][0]
_______________________________________________________________________
value_out (Dense) (None, 1) 257 fc_1[0][0]
=======================================================================
Total params: 67,329
Trainable params: 67,329
Non-trainable params: 0
_______________________________________________________________________
"""

# Access the Q value model (specific to DQN)
print(model.get_q_value_distributions(model_out[0])[0])
# -> tensor of shape (1, 2); exact numbers differ due to randomness
print(model.advantage_module)

# Access the state value model (specific to DQN)
print(model.get_state_value(model_out[0]))
# -> tensor of shape (1, 1); exact number differs due to randomness
print(model.value_module)
```
This is especially useful when used with custom model classes.
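The printed advantage_module and value_module indicate that this DQN model has two streams: one estimating the state value V(s) and one estimating per-action advantages A(s, a). A common way to combine them into Q-values (the "dueling" aggregation) subtracts the mean advantage, as in the plain-Python sketch below. dueling_q_values is a hypothetical helper; RLlib's DQN code does its own combination internally, and its exact form (for example, with distributional Q-learning enabled) may differ.

```python
from typing import List, Sequence


def dueling_q_values(state_value: float, advantages: Sequence[float]) -> List[float]:
    """Combine V(s) and A(s, a) into Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')."""
    mean_adv = sum(advantages) / len(advantages)
    return [state_value + a - mean_adv for a in advantages]


# Made-up numbers for a 2-action env such as CartPole.
q = dueling_q_values(0.5, [0.25, -0.25])
print(q)  # [0.75, 0.25]
```

Subtracting the mean shifts all Q-values by the same constant, so the greedy action (the argmax over Q) is the same as the argmax over the advantages.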
Configuring RLlib Algorithms
You can configure RLlib algorithms in a modular fashion by working with so-called AlgorithmConfig objects. In essence, you first create a config = AlgorithmConfig() object and then call methods on it to set the desired configuration options. Each RLlib algorithm has its own config class that inherits from AlgorithmConfig. For instance, to create a PPO algorithm, you start with a PPOConfig object; to work with a DQN algorithm, you start with a DQNConfig object; and so on.
Note
Each algorithm has its specific settings, but most configuration options are shared. We discuss the common options below, and refer to the RLlib algorithms guide for algorithm-specific properties. Algorithms differ mostly in their training settings.
Below you find the basic signature of the AlgorithmConfig class, as well as some advanced usage examples:
- class ray.rllib.algorithms.algorithm_config.AlgorithmConfig(algo_class: type | None = None)
An RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration.
```python
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.callbacks import MemoryTrackingCallbacks

# Construct a generic config object, specifying values within different
# sub-categories, e.g. "training".
config = (
    PPOConfig()
    .training(gamma=0.9, lr=0.01)
    .environment(env="CartPole-v1")
    .env_runners(num_env_runners=0)
    .callbacks(MemoryTrackingCallbacks)
)
# A config object can be used to construct the respective Algorithm.
rllib_algo = config.build()
```
```python
from ray.rllib.algorithms.ppo import PPOConfig
from ray import tune

# In combination with a tune.grid_search:
config = PPOConfig()
config.training(lr=tune.grid_search([0.01, 0.001]))
# Use `to_dict()` method to get the legacy plain python config dict
# for usage with `tune.Tuner().fit()`.
tune.Tuner("PPO", param_space=config.to_dict())
```
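Each config method returns the updated config object, which is what makes the chained style above possible, and the method signatures below default their arguments to a NotProvided sentinel rather than None. A plausible reason is that None is itself a meaningful value for some options (for example, grad_clip=None disables gradient clipping), so a sentinel lets a method tell "not passed" apart from "explicitly set to None". The sketch below imitates this pattern with a hypothetical MiniConfig class and made-up default values; it is not RLlib's implementation.

```python
from typing import Any, Dict


class _NotProvided:
    """Sentinel type meaning 'argument was not passed'."""

    def __repr__(self) -> str:
        return "NotProvided"


NotProvided = _NotProvided()


class MiniConfig:
    """Hypothetical fluent config imitating the AlgorithmConfig style."""

    def __init__(self) -> None:
        # Made-up defaults for illustration only.
        self.gamma = 0.99
        self.lr = 0.001
        self.grad_clip = 40.0

    def training(
        self,
        *,
        gamma: Any = NotProvided,
        lr: Any = NotProvided,
        grad_clip: Any = NotProvided,
    ) -> "MiniConfig":
        # Only overwrite settings the caller actually passed (None included).
        if gamma is not NotProvided:
            self.gamma = gamma
        if lr is not NotProvided:
            self.lr = lr
        if grad_clip is not NotProvided:
            self.grad_clip = grad_clip
        return self  # returning self enables method chaining

    def to_dict(self) -> Dict[str, Any]:
        return {"gamma": self.gamma, "lr": self.lr, "grad_clip": self.grad_clip}


config = MiniConfig().training(gamma=0.9).training(grad_clip=None)
print(config.to_dict())  # {'gamma': 0.9, 'lr': 0.001, 'grad_clip': None}
```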
As RLlib algorithms are fairly complex, they come with many configuration options. To make things easier, the common properties of algorithms are naturally grouped into the following categories:
Let’s discuss each category one by one, starting with training options.
Specifying Training Options
For individual algorithms, this is probably the most relevant configuration group, as this is where all the algorithm-specific options go.
Note
For instance, a DQNConfig takes a double_q training argument to specify whether to use a double-Q DQN, whereas in a PPOConfig this does not make sense.
The base configuration for training of an AlgorithmConfig, however, is quite small:
- AlgorithmConfig.training(*, gamma: float | None = NotProvided, lr: float | List[List[int | float]] | List[Tuple[int, int | float]] | None = NotProvided, grad_clip: float | None = NotProvided, grad_clip_by: str | None = NotProvided, train_batch_size: int | None = NotProvided, train_batch_size_per_learner: int | None = NotProvided, num_epochs: int | None = NotProvided, minibatch_size: int | None = NotProvided, shuffle_batch_per_epoch: bool | None = NotProvided, model: dict | None = NotProvided, optimizer: dict | None = NotProvided, learner_class: Type[Learner] | None = NotProvided, learner_connector: Callable[[RLModule], ConnectorV2 | List[ConnectorV2]] | None = NotProvided, add_default_connectors_to_learner_pipeline: bool | None = NotProvided, learner_config_dict: Dict[str, Any] | None = NotProvided, num_sgd_iter=-1, max_requests_in_flight_per_sampler_worker=-1) -> AlgorithmConfig
Sets the training related configuration.
- Parameters:
gamma – Float specifying the discount factor of the Markov Decision Process.
lr – The learning rate (float) or learning rate schedule in the format of [[timestep, lr-value], [timestep, lr-value], …]. In case of a schedule, intermediary timesteps are assigned to linearly interpolated learning rate values. A schedule config’s first entry must start with timestep 0, i.e.: [[0, initial_value], […]]. Note: If you require a) more than one optimizer (per RLModule), b) optimizer types that are not Adam, c) a learning rate schedule that is not a linearly interpolated, piecewise schedule as described above, or d) specifying c’tor arguments of the optimizer that are not the learning rate (e.g. Adam’s epsilon), then you must override your Learner’s configure_optimizer_for_module() method and handle lr-scheduling yourself.
grad_clip – If None, no gradient clipping is applied. Otherwise, depending on the setting of grad_clip_by, the (float) value of grad_clip has the following effect: If grad_clip_by=value: Clips all computed gradients individually inside the interval [-grad_clip, +grad_clip]. If grad_clip_by=norm: Computes the L2-norm of each weight/bias gradient tensor individually and then clips all gradients such that these L2-norms do not exceed grad_clip. The L2-norm of a tensor is computed via sqrt(SUM(w0^2, w1^2, ..., wn^2)), where w[i] are the elements of the tensor (no matter what the shape of this tensor is). If grad_clip_by=global_norm: Computes the square of the L2-norm of each weight/bias gradient tensor individually, sums up all these squared L2-norms across all given gradient tensors (e.g. the entire module to be updated), takes the square root of that overall sum, and then clips all gradients such that this global L2-norm does not exceed the given value. The global L2-norm over a list of tensors (e.g. W and V) is computed via sqrt[SUM(w0^2, w1^2, ..., wn^2) + SUM(v0^2, v1^2, ..., vm^2)], where w[i] and v[j] are the elements of the tensors W and V (no matter what the shapes of these tensors are).
grad_clip_by – See grad_clip for the effect of this setting on gradient clipping. Allowed values are value, norm, and global_norm.
train_batch_size_per_learner – Train batch size per individual Learner worker. This setting only applies to the new API stack. The number of Learner workers can be set via config.resources(num_learners=...). The total effective batch size is then num_learners x train_batch_size_per_learner and you can access it with the property AlgorithmConfig.total_train_batch_size.
train_batch_size – Training batch size, if applicable. When on the new API stack, this setting should no longer be used. Instead, use train_batch_size_per_learner (in combination with num_learners).
num_epochs – The number of complete passes over the entire train batch (per Learner). Each pass might be further split into n minibatches (if minibatch_size is provided).
minibatch_size – The size of the minibatches into which the train batch is further split.
shuffle_batch_per_epoch – Whether to shuffle the train batch once per epoch. If the train batch has a time rank (axis=1), shuffling only takes place along the batch axis to not disturb any intact (episode) trajectories.
model – Arguments passed into the policy model. See models/catalog.py for a full list of the available model options.
optimizer – Arguments to pass to the policy optimizer. This setting is not used when enable_rl_module_and_learner=True.
learner_class – The Learner class to use for (distributed) updating of the RLModule. Only used when enable_rl_module_and_learner=True.
learner_connector – A callable taking an env observation space and an env action space as inputs and returning a learner ConnectorV2 (might be a pipeline) object.
add_default_connectors_to_learner_pipeline – If True (default), RLlib’s Learners automatically add the default Learner ConnectorV2 pieces to the LearnerPipeline. These automatically perform: a) adding observations from episodes to the train batch, if this has not already been done by a user-provided connector piece; b) if the RLModule is stateful, adding a time rank to the train batch, zero-padding the data, and adding the correct state inputs, if this has not already been done by a user-provided connector piece; c) adding all other information (actions, rewards, terminateds, etc.) to the train batch, if this has not already been done by a user-provided connector piece. Set this to False only if you know exactly what you are doing. Note that this setting is only relevant if the new API stack is used (including the new EnvRunner classes).
learner_config_dict – A dict to insert any settings accessible from within the Learner instance. This should only be used in connection with custom Learner subclasses and in case the user doesn’t want to write an extra AlgorithmConfig subclass just to add a few settings to the base Algo’s own config class.
- Returns:
This updated AlgorithmConfig object.
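Two of the training options above describe concrete arithmetic: the piecewise-linear lr schedule and the three grad_clip_by modes. The plain-Python sketch below implements both as documented, operating on lists of floats instead of framework tensors. It illustrates the documented behavior and is not RLlib's internal code; lr_at and clip_gradients are hypothetical helpers, and holding the last lr value constant after the final schedule entry is an assumption of this sketch.

```python
import math
from typing import List, Sequence


def lr_at(schedule: Sequence[Sequence[float]], timestep: float) -> float:
    """Piecewise-linear lr schedule [[0, lr0], [t1, lr1], ...].

    Values between entries are linearly interpolated; after the last entry
    the last value is held (an assumption of this sketch).
    """
    if schedule[0][0] != 0:
        raise ValueError("The first schedule entry must start at timestep 0.")
    for (t0, v0), (t1, v1) in zip(schedule, schedule[1:]):
        if t0 <= timestep <= t1:
            return v0 + (timestep - t0) / (t1 - t0) * (v1 - v0)
    return schedule[-1][1]


def l2_norm(values: Sequence[float]) -> float:
    return math.sqrt(sum(v * v for v in values))


def clip_gradients(
    grads: List[List[float]], grad_clip: float, grad_clip_by: str
) -> List[List[float]]:
    """Clip a list of gradient 'tensors' (flattened to lists) per grad_clip_by."""
    if grad_clip_by == "value":
        # Clip every element into [-grad_clip, +grad_clip].
        return [[max(-grad_clip, min(grad_clip, g)) for g in t] for t in grads]
    if grad_clip_by == "norm":
        # Rescale each tensor separately so its own L2-norm <= grad_clip.
        out = []
        for t in grads:
            n = l2_norm(t)
            scale = grad_clip / n if n > grad_clip else 1.0
            out.append([g * scale for g in t])
        return out
    if grad_clip_by == "global_norm":
        # One shared scale so the norm over all tensors together <= grad_clip.
        global_norm = math.sqrt(sum(l2_norm(t) ** 2 for t in grads))
        scale = grad_clip / global_norm if global_norm > grad_clip else 1.0
        return [[g * scale for g in t] for t in grads]
    raise ValueError(f"Unknown grad_clip_by: {grad_clip_by!r}")


print(lr_at([[0, 1e-3], [1000, 1e-4]], 500))  # ~5.5e-4 (halfway between)
grads = [[3.0, 4.0], [0.0, 12.0]]  # per-tensor L2-norms 5 and 12, global norm 13
print(clip_gradients(grads, 1.0, "value"))        # [[1.0, 1.0], [0.0, 1.0]]
print(clip_gradients(grads, 6.5, "global_norm"))  # [[1.5, 2.0], [0.0, 6.0]]
```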
Specifying Environments
- AlgorithmConfig.environment(env: str | Any | gymnasium.Env | None = NotProvided, *, env_config: dict | None = NotProvided, observation_space: gymnasium.spaces.Space | None = NotProvided, action_space: gymnasium.spaces.Space | None = NotProvided, env_task_fn: Callable[[Dict, Any | gymnasium.Env, EnvContext], Any] | None = NotProvided, render_env: bool | None = NotProvided, clip_rewards: bool | float | None = NotProvided, normalize_actions: bool | None = NotProvided, clip_actions: bool | None = NotProvided, disable_env_checking: bool | None = NotProvided, is_atari: bool | None = NotProvided, action_mask_key: str | None = NotProvided, auto_wrap_old_gym_envs=-1) -> AlgorithmConfig
Sets the config’s RL-environment settings.
- Parameters:
env – The environment specifier. This can either be a tune-registered env, via tune.register_env([name], lambda env_ctx: [env object]), or a string specifier of an RLlib supported type. In the latter case, RLlib tries to interpret the specifier as either a Farama-Foundation gymnasium env, a PyBullet env, or a fully qualified classpath to an Env class, e.g. “ray.rllib.examples.envs.classes.random_env.RandomEnv”.
env_config – Arguments dict passed to the env creator as an EnvContext object (which is a dict plus the properties: num_env_runners, worker_index, vector_index, and remote).
observation_space – The observation space for the Policies of this Algorithm.
action_space – The action space for the Policies of this Algorithm.
env_task_fn – A callable taking the last train results, the base env and the env context as args and returning a new task to set the env to. The env must be a TaskSettableEnv sub-class for this to work. See examples/curriculum_learning.py for an example.
render_env – If True, try to render the environment on the local worker or on worker 1 (if num_env_runners > 0). For vectorized envs, this usually means that only the first sub-environment is rendered. In order for this to work, your env has to implement the render() method which either: a) handles window generation and rendering itself (returning True) or b) returns a numpy uint8 image of shape [height x width x 3 (RGB)].
clip_rewards – Whether to clip rewards during Policy’s postprocessing. None (default): Clip for Atari only (r=sign(r)). True: r=sign(r): Fixed rewards -1.0, 1.0, or 0.0. False: Never clip. [float value]: Clip at -value and +value. Tuple[value1, value2]: Clip at value1 and value2.
normalize_actions – If True, RLlib learns entirely inside a normalized action space (0.0 centered with small stddev; only affecting Box components). RLlib unsquashes actions (and clips them, just in case) to the bounds of the env’s action space before sending actions back to the env.
clip_actions – If True, the RLlib default ModuleToEnv connector clips actions according to the env’s bounds (before sending them into the env.step() call).
disable_env_checking – Disable RLlib’s env checks after a gymnasium.Env instance has been constructed in an EnvRunner. Note that the checks include an env.reset() and env.step() (with a random action), which might tinker with your env’s logic and behavior and thus negatively influence sample collection and/or learning behavior.
is_atari – This config can be used to explicitly specify whether the env is an Atari env or not. If not specified, RLlib tries to auto-detect this.
action_mask_key – If the observation is a dictionary, expect the value under the key action_mask_key to contain a valid actions mask (numpy.int8 array of zeros and ones). Defaults to “action_mask”.
- Returns:
This updated AlgorithmConfig object.
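The clip_rewards options map to simple per-reward rules. The sketch below writes them out as a plain-Python function, following the option list above. clip_reward is a hypothetical helper; RLlib applies its clipping during postprocessing rather than through a function like this.

```python
import math
from typing import Tuple, Union

RewardClip = Union[None, bool, float, Tuple[float, float]]


def clip_reward(r: float, clip_rewards: RewardClip, is_atari: bool = False) -> float:
    """Apply the documented clip_rewards semantics to a single reward."""
    if clip_rewards is None:
        clip_rewards = is_atari  # default: sign-clip for Atari envs only
    # Check the bools first: bool is a subclass of int in Python.
    if clip_rewards is True:
        return math.copysign(1.0, r) if r != 0 else 0.0  # r = sign(r)
    if clip_rewards is False:
        return r
    if isinstance(clip_rewards, tuple):
        low, high = clip_rewards
        return max(low, min(high, r))
    # A single number: clip into [-value, +value].
    return max(-clip_rewards, min(clip_rewards, r))


print(clip_reward(7.0, True))        # 1.0
print(clip_reward(-3.0, 2.0))        # -2.0
print(clip_reward(5.0, (0.0, 1.0)))  # 1.0
print(clip_reward(5.0, None))        # 5.0 (not Atari, so no clipping)
```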
Specifying Framework Options
- AlgorithmConfig.framework(framework: str | None = NotProvided, *, eager_tracing: bool | None = NotProvided, eager_max_retraces: int | None = NotProvided, tf_session_args: Dict[str, Any] | None = NotProvided, local_tf_session_args: Dict[str, Any] | None = NotProvided, torch_compile_learner: bool | None = NotProvided, torch_compile_learner_what_to_compile: str | None = NotProvided, torch_compile_learner_dynamo_mode: str | None = NotProvided, torch_compile_learner_dynamo_backend: str | None = NotProvided, torch_compile_worker: bool | None = NotProvided, torch_compile_worker_dynamo_backend: str | None = NotProvided, torch_compile_worker_dynamo_mode: str | None = NotProvided, torch_ddp_kwargs: Dict[str, Any] | None = NotProvided, torch_skip_nan_gradients: bool | None = NotProvided) -> AlgorithmConfig
Sets the config’s DL framework settings.
- Parameters:
framework – torch: PyTorch; tf2: TensorFlow 2.x (eager execution or traced if eager_tracing=True); tf: TensorFlow (static-graph).
eager_tracing – Enable tracing in eager mode. This greatly improves performance (speedup ~2x), but makes it slightly harder to debug since Python code won’t be evaluated after the initial eager pass. Only possible if framework=tf2.
eager_max_retraces – Maximum number of tf.function re-traces before a runtime error is raised. This is to prevent unnoticed retraces of methods inside the ..._eager_traced Policy, which could slow down execution by a factor of 4, without the user noticing what the root cause for this slowdown could be. Only necessary for framework=tf2. Set to None to ignore the re-trace count and never throw an error.
tf_session_args – Configures TF for single-process operation by default.
local_tf_session_args – Overrides the tf session args on the local worker.
torch_compile_learner – If True, forward_train methods on TorchRLModules on the learner are compiled. If not specified, the default is to compile forward train on the learner.
torch_compile_learner_what_to_compile – A TorchCompileWhatToCompile mode specifying what to compile on the learner side if torch_compile_learner is True. See TorchCompileWhatToCompile for details and advice on its usage.
torch_compile_learner_dynamo_backend – The torch dynamo backend to use on the learner.
torch_compile_learner_dynamo_mode – The torch dynamo mode to use on the learner.
torch_compile_worker – If True, forward exploration and inference methods on TorchRLModules on the workers are compiled. If not specified, the default is to not compile forward methods on the workers because retracing can be expensive.
torch_compile_worker_dynamo_backend – The torch dynamo backend to use on the workers.
torch_compile_worker_dynamo_mode – The torch dynamo mode to use on the workers.
torch_ddp_kwargs – The kwargs to pass into torch.nn.parallel.DistributedDataParallel when using num_learners > 1. This is specifically helpful when searching for parameters that are not used in the backward pass. This can give hints for errors in custom models where some parameters do not get touched in the backward pass although they should.
torch_skip_nan_gradients – Whether updates with nan gradients should be skipped entirely. If True, the optimizer skips any update that contains a nan gradient. This can help to avoid biasing moving-average based optimizers, like Adam, in training phases where policy updates can be highly unstable, such as during the early stages of training or with highly exploratory policies. In such phases, many gradients might turn nan, and setting them to zero could corrupt the optimizer’s internal state. The default is False, which turns nan gradients to zero. If many nan gradients are encountered, consider (a) monitoring gradients by setting log_gradients in AlgorithmConfig to True, (b) using proper weight initialization (e.g. Xavier, Kaiming) via the model_config_dict in AlgorithmConfig.rl_module, and/or (c) gradient clipping via grad_clip in AlgorithmConfig.training.
- Returns:
This updated AlgorithmConfig object.
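The two behaviors described for torch_skip_nan_gradients can be sketched in plain Python as follows. apply_nan_policy is a hypothetical helper operating on a flat list of floats; RLlib's own handling works on framework gradient tensors, so treat this only as an illustration of the documented semantics.

```python
import math
from typing import List, Optional


def apply_nan_policy(grads: List[float], skip_nan_gradients: bool) -> Optional[List[float]]:
    """Mimic the documented torch_skip_nan_gradients behavior.

    Returns None (meaning "skip this optimizer step") when skipping is enabled
    and any gradient is nan; otherwise returns the gradients with nan -> 0.0.
    """
    has_nan = any(math.isnan(g) for g in grads)
    if skip_nan_gradients and has_nan:
        return None
    return [0.0 if math.isnan(g) else g for g in grads]


grads = [0.1, float("nan"), -0.3]
print(apply_nan_policy(grads, skip_nan_gradients=False))  # [0.1, 0.0, -0.3]
print(apply_nan_policy(grads, skip_nan_gradients=True))   # None (update skipped)
```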
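Specifying Rollout Workers
- AlgorithmConfig.rollouts(**kwargs)
The examples in this guide configure EnvRunners (rollout workers) through AlgorithmConfig.env_runners() instead, for example env_runners(num_env_runners=1).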
Specifying Evaluation Options
- AlgorithmConfig.evaluation(*, evaluation_interval: int | None = NotProvided, evaluation_duration: int | str | None = NotProvided, evaluation_duration_unit: str | None = NotProvided, evaluation_sample_timeout_s: float | None = NotProvided, evaluation_parallel_to_training: bool | None = NotProvided, evaluation_force_reset_envs_before_iteration: bool | None = NotProvided, evaluation_config: AlgorithmConfig | dict | None = NotProvided, off_policy_estimation_methods: Dict | None = NotProvided, ope_split_batch_by_episode: bool | None = NotProvided, evaluation_num_env_runners: int | None = NotProvided, custom_evaluation_function: Callable | None = NotProvided, always_attach_evaluation_results=-1, evaluation_num_workers=-1) -> AlgorithmConfig
Sets the config’s evaluation settings.
- Parameters:
evaluation_interval – Evaluate with every evaluation_interval training iterations. The evaluation stats are reported under the “evaluation” metric key. Set to None (or 0) for no evaluation.
evaluation_duration – Duration for which to run evaluation each evaluation_interval. The unit for the duration can be set via evaluation_duration_unit to either “episodes” (default) or “timesteps”. If using multiple evaluation workers (EnvRunners) in the evaluation_num_env_runners > 1 setting, the number of episodes/timesteps to run is split amongst these. A special value of “auto” can be used in case evaluation_parallel_to_training=True. This is the recommended way when trying to save as much time on evaluation as possible. The Algorithm then runs as many timesteps via the evaluation workers as possible, while not taking longer than the parallel-running training step, and thus never wastes any idle time on either training or evaluation workers. When using this setting (evaluation_duration="auto"), it is strongly advised to set evaluation_interval=1 and evaluation_force_reset_envs_before_iteration=True at the same time.
evaluation_duration_unit – The unit with which to count the evaluation duration. Either “episodes” (default) or “timesteps”. Note that this setting is ignored if evaluation_duration="auto".
evaluation_sample_timeout_s – The timeout (in seconds) for evaluation workers to sample a complete episode in case your config settings are evaluation_duration != "auto" and evaluation_duration_unit="episodes". After this time, the user receives a warning and instructions on how to fix the issue.
evaluation_parallel_to_training – Whether to run evaluation in parallel to the Algorithm.training_step() call, using threading. Default=False. For example, with evaluation_interval=1, in every call to Algorithm.train(), the Algorithm.training_step() and Algorithm.evaluate() calls run in parallel. Note that this setting, albeit extremely efficient because it wastes no extra time on evaluation, causes the evaluation results to lag one iteration behind the rest of the training results. This is important when picking a good checkpoint. For example, if iteration 42 reports a good evaluation episode_return_mean, be aware that these results were achieved on the weights trained in iteration 41, so you should probably pick the iteration 41 checkpoint instead.
evaluation_force_reset_envs_before_iteration – Whether all environments should be force-reset (even if they are not done yet) right before the evaluation step of the iteration begins. Setting this to True (default) makes sure that the evaluation results aren’t polluted with episode statistics that were actually (at least partially) achieved with an earlier set of weights. Note that this setting is only supported on the new API stack with EnvRunners and ConnectorV2 (config.enable_rl_module_and_learner=True AND config.enable_env_runner_and_connector_v2=True).
evaluation_config – Typical usage is to pass extra args to the evaluation env creator and to disable exploration by computing deterministic actions. IMPORTANT NOTE: Policy gradient algorithms are able to find the optimal policy, even if this is a stochastic one. Setting “explore=False” here results in the evaluation workers not using this optimal policy!
off_policy_estimation_methods – Specify how to evaluate the current policy, along with any optional config parameters. This only has an effect when reading offline experiences (“input” is not “sampler”). Available keys: {ope_method_name: {“type”: ope_type, …}} where ope_method_name is a user-defined string to save the OPE results under, and ope_type can be any subclass of OffPolicyEstimator, e.g. ray.rllib.offline.estimators.is::ImportanceSampling or your own custom subclass, or the full class path to the subclass. You can also add additional config arguments to be passed to the OffPolicyEstimator in the dict, e.g. {“qreg_dr”: {“type”: DoublyRobust, “q_model_type”: “qreg”, “k”: 5}}.
ope_split_batch_by_episode – Whether to use SampleBatch.split_by_episode() to split the input batch into episodes before estimating the OPE metrics. In case of bandits, you should set this to False to see improvements in OPE evaluation speed; it is OK not to split by episode there, since each record is already one timestep. The default is True.
evaluation_num_env_runners – Number of parallel EnvRunners to use for evaluation. Note that this is set to zero by default, which means evaluation is run in the algorithm process (only if evaluation_interval is not 0 or None). Increasing this value also increases the Ray resource usage of the algorithm, since evaluation workers are created separately from the EnvRunners used to sample data for training.
custom_evaluation_function – Customize the evaluation method. This must be a function of signature (algo: Algorithm, eval_workers: EnvRunnerGroup) -> (metrics: dict, env_steps: int, agent_steps: int) (or only metrics: dict if enable_env_runner_and_connector_v2=True), where env_steps and agent_steps define the number of sampled steps during the evaluation iteration. See the Algorithm.evaluate() method for the default implementation. The Algorithm guarantees all eval workers have the latest policy state before this function is called.
- Returns:
This updated AlgorithmConfig object.
Specifying Offline Data Options
- AlgorithmConfig.offline_data(*, input_: str | ~typing.Callable[[~ray.rllib.offline.io_context.IOContext], ~ray.rllib.offline.input_reader.InputReader] | None = <ray.rllib.utils.from_config._NotProvided object>, input_read_method: str | ~typing.Callable | None = <ray.rllib.utils.from_config._NotProvided object>, input_read_method_kwargs: ~typing.Dict | None = <ray.rllib.utils.from_config._NotProvided object>, input_read_schema: ~typing.Dict[str, str] | None = <ray.rllib.utils.from_config._NotProvided object>, input_read_episodes: bool | None = <ray.rllib.utils.from_config._NotProvided object>, input_read_sample_batches: bool | None = <ray.rllib.utils.from_config._NotProvided object>, input_read_batch_size: int | None = <ray.rllib.utils.from_config._NotProvided object>, input_filesystem: str | None = <ray.rllib.utils.from_config._NotProvided object>, input_filesystem_kwargs: ~typing.Dict | None = <ray.rllib.utils.from_config._NotProvided object>, input_compress_columns: ~typing.List[str] | None = <ray.rllib.utils.from_config._NotProvided object>, materialize_data: bool | None = <ray.rllib.utils.from_config._NotProvided object>, materialize_mapped_data: bool | None = <ray.rllib.utils.from_config._NotProvided object>, map_batches_kwargs: ~typing.Dict | None = <ray.rllib.utils.from_config._NotProvided object>, iter_batches_kwargs: ~typing.Dict | None = <ray.rllib.utils.from_config._NotProvided object>, prelearner_class: ~typing.Type | None = <ray.rllib.utils.from_config._NotProvided object>, prelearner_buffer_class: ~typing.Type | None = <ray.rllib.utils.from_config._NotProvided object>, prelearner_buffer_kwargs: ~typing.Dict | None = <ray.rllib.utils.from_config._NotProvided object>, prelearner_module_synch_period: int | None = <ray.rllib.utils.from_config._NotProvided object>, dataset_num_iters_per_learner: int | None = <ray.rllib.utils.from_config._NotProvided object>, input_config: ~typing.Dict | None = <ray.rllib.utils.from_config._NotProvided object>, 
actions_in_input_normalized: bool | None = <ray.rllib.utils.from_config._NotProvided object>, postprocess_inputs: bool | None = <ray.rllib.utils.from_config._NotProvided object>, shuffle_buffer_size: int | None = <ray.rllib.utils.from_config._NotProvided object>, output: str | None = <ray.rllib.utils.from_config._NotProvided object>, output_config: ~typing.Dict | None = <ray.rllib.utils.from_config._NotProvided object>, output_compress_columns: ~typing.List[str] | None = <ray.rllib.utils.from_config._NotProvided object>, output_max_file_size: float | None = <ray.rllib.utils.from_config._NotProvided object>, output_max_rows_per_file: int | None = <ray.rllib.utils.from_config._NotProvided object>, output_write_method: str | None = <ray.rllib.utils.from_config._NotProvided object>, output_write_method_kwargs: ~typing.Dict | None = <ray.rllib.utils.from_config._NotProvided object>, output_filesystem: str | None = <ray.rllib.utils.from_config._NotProvided object>, output_filesystem_kwargs: ~typing.Dict | None = <ray.rllib.utils.from_config._NotProvided object>, output_write_episodes: bool | None = <ray.rllib.utils.from_config._NotProvided object>, offline_sampling: str | None = <ray.rllib.utils.from_config._NotProvided object>) → AlgorithmConfig
Sets the config’s offline data settings.
- Parameters:
input_ – Specify how to generate experiences: - “sampler”: Generate experiences via online (env) simulation (default). - A local directory or file glob expression (e.g., “/tmp/*.json”). - A list of individual file paths/URIs (e.g., [“/tmp/1.json”, “s3://bucket/2.json”]). - A dict with string keys and sampling probabilities as values (e.g., {“sampler”: 0.4, “/tmp/*.json”: 0.4, “s3://bucket/expert.json”: 0.2}). - A callable that takes an IOContext object as its only arg and returns a ray.rllib.offline.InputReader. - A string key that indexes a callable with tune.registry.register_input.
input_read_method – Read method for the ray.data.Dataset to read in the offline data from input_. The default is read_parquet for Parquet files. See https://docs.ray.io/en/latest/data/api/input_output.html for more info about available read methods in ray.data.
input_read_method_kwargs – Keyword args for input_read_method. These are passed into the read method without checking. If no arguments are passed in, the default argument {'override_num_blocks': max(num_learners * 2, 2)} is used. Use these keyword args together with map_batches_kwargs and iter_batches_kwargs to tune the performance of the data pipeline.
input_read_schema – Table schema for converting offline data to episodes. This schema maps the offline data columns to ray.rllib.core.columns.Columns: {Columns.OBS: 'o_t', Columns.ACTIONS: 'a_t', ...}. Columns in the dataset that are not mapped via this schema are sorted into the episodes’ extra_model_outputs. If no schema is passed in, the default schema ray.rllib.offline.offline_data.SCHEMA is used. If your dataset already contains the names in this schema, no input_read_schema is needed.
input_read_episodes – Whether offline data is already stored in RLlib’s EpisodeType format, i.e. ray.rllib.env.SingleAgentEpisode (multi-agent is planned but not supported yet). Reading episodes directly avoids additional transform steps and is usually faster, and is therefore the recommended format when your application remains fully inside of RLlib’s schema. The other format is a columnar format and is agnostic to the RL framework used. Use the latter format if you are unsure when to use the data or in which RL framework. The default is to read column data, i.e. False. input_read_episodes and input_read_sample_batches cannot be True at the same time. See also output_write_episodes to define the output data format when recording.
input_read_sample_batches – Whether offline data is stored in RLlib’s old stack SampleBatch type. This is usually the case for older data recorded with RLlib in JSON line format. Reading in SampleBatch data needs extra transforms and might not concatenate episode chunks contained in different SampleBatches in the data. If possible, avoid reading SampleBatches and instead convert them in a controlled form into RLlib’s EpisodeType (i.e. SingleAgentEpisode or MultiAgentEpisode). The default is False. input_read_episodes and input_read_sample_batches cannot be True at the same time.
input_read_batch_size – Batch size to pull from the dataset. This could differ from the train_batch_size_per_learner if a dataset holds EpisodeType (i.e. SingleAgentEpisode or MultiAgentEpisode) or BatchType (i.e. SampleBatch or MultiAgentBatch) or any other data type that contains multiple timesteps in a single row of the dataset. In such cases, a single batch of size train_batch_size_per_learner potentially pulls a multiple of train_batch_size_per_learner timesteps from the offline dataset. The default is None, in which case train_batch_size_per_learner is pulled.
input_filesystem – A cloud filesystem to handle access to cloud storage when reading experiences. Should be either “gcs” for Google Cloud Storage, “s3” for AWS S3 buckets, or “abs” for Azure Blob Storage.
input_filesystem_kwargs – A dictionary holding the kwargs for the filesystem given by input_filesystem. See gcsfs.GCSFileSystem for GCS, pyarrow.fs.S3FileSystem for S3, and adlfs.AzureBlobFileSystem for ABS filesystem arguments.
input_compress_columns – Which input columns are compressed with LZ4 in the input data. This applies if the data is stored as RLlib’s SingleAgentEpisode (MultiAgentEpisode is not supported yet). Note that providing rllib.core.columns.Columns.OBS also tries to decompress rllib.core.columns.Columns.NEXT_OBS.
materialize_data – Whether the raw data should be materialized in memory. This boosts performance, but requires enough memory to avoid an OOM, so make sure that your cluster has the resources available. For very large data you might want to switch to streaming mode by setting this to False (default). If your algorithm does not need the RLModule in the Learner connector pipeline or all (learner) connectors are stateless, you should consider setting materialize_mapped_data to True instead (and set materialize_data to False). If your data does not fit into memory and your Learner connector pipeline requires an RLModule or is stateful, set both materialize_data and materialize_mapped_data to False.
materialize_mapped_data – Whether the data should be materialized after running it through the Learner connector pipeline (i.e. after running the OfflinePreLearner). This improves performance, but should only be used in case the (learner) connector pipeline does not require an RLModule and the (learner) connector pipeline is stateless. For example, MARWIL’s Learner connector pipeline requires the RLModule for value function predictions, and training batches would become stale after some iterations, causing learning degradation or divergence. Also ensure that your cluster has enough memory available to avoid an OOM. If set to True, make sure that materialize_data is set to False to avoid materialization of two datasets. If your data does not fit into memory and your Learner connector pipeline requires an RLModule or is stateful, set both materialize_data and materialize_mapped_data to False.
map_batches_kwargs – Keyword args for the map_batches method. These are passed into the ray.data.Dataset.map_batches method when sampling, without checking. If no arguments are passed in, the default arguments {'concurrency': max(2, num_learners), 'zero_copy_batch': True} are used. Use these keyword args together with input_read_method_kwargs and iter_batches_kwargs to tune the performance of the data pipeline.
iter_batches_kwargs – Keyword args for the iter_batches method. These are passed into the ray.data.Dataset.iter_batches method when sampling, without checking. If no arguments are passed in, the default arguments {'prefetch_batches': 2, 'local_buffer_shuffle_size': train_batch_size_per_learner x 4} are used. Use these keyword args together with input_read_method_kwargs and map_batches_kwargs to tune the performance of the data pipeline.
prelearner_class – An optional OfflinePreLearner class that is used to transform data batches in ray.data.map_batches used in the OfflineData class to transform data from columns to batches that can be used in the Learner.update...() methods. Override the OfflinePreLearner class and pass your derived class in here if you need to make some further transformations specific to your data or loss. The default is None, which uses the base OfflinePreLearner defined in ray.rllib.offline.offline_prelearner.
prelearner_module_synch_period – The period (number of batches converted) after which the RLModule held by the PreLearner should sync weights. The PreLearner is used to preprocess batches for the learners. The higher this value, the more off-policy the PreLearner’s module is. Values that are too small force the PreLearner to sync more frequently and thus might slow down the data pipeline. The default value chosen by the OfflinePreLearner is 10.
dataset_num_iters_per_learner – Number of updates to run in each learner during a single training iteration. If None, each learner runs a complete epoch over its data block (the dataset is partitioned into at least as many blocks as there are learners). The default is None.
input_config – Arguments that describe the settings for reading the input. If input is “sample”, this is the environment configuration, e.g. env_name and env_config, etc. See EnvContext for more info. If the input is “dataset”, this contains e.g. format, path.
actions_in_input_normalized – True, if the actions in a given offline “input” are already normalized (between -1.0 and 1.0). This is usually the case when the offline file has been generated by another RLlib algorithm (e.g. PPO or SAC), while “normalize_actions” was set to True.
postprocess_inputs – Whether to run postprocess_trajectory() on the trajectory fragments from offline inputs. Note that postprocessing is done using the current policy, not the behavior policy, which is typically undesirable for on-policy algorithms.
shuffle_buffer_size – If positive, input batches are shuffled via a sliding window buffer of this number of batches. Use this if the input data is not in random enough order. Input is delayed until the shuffle buffer is filled.
output – Specify where experiences should be saved: - None: don’t save any experiences - “logdir” to save to the agent log dir - a path/URI to save to a custom output directory (e.g., “s3://bckt/”) - a function that returns a rllib.offline.OutputWriter
output_config – Arguments accessible from the IOContext for configuring custom output.
output_compress_columns – Which sample batch columns to LZ4-compress in the output data. Note that providing rllib.core.columns.Columns.OBS also compresses rllib.core.columns.Columns.NEXT_OBS.
output_max_file_size – Max output file size (in bytes) before rolling over to a new file.
output_max_rows_per_file – Max output row numbers before rolling over to a new file.
output_write_method – Write method for the ray.data.Dataset to write the offline data to output. The default is write_parquet for Parquet files. See https://docs.ray.io/en/latest/data/api/input_output.html for more info about available write methods in ray.data.
output_write_method_kwargs – kwargs for the output_write_method. These are passed into the write method without checking.
output_filesystem – A cloud filesystem to handle access to cloud storage when writing experiences. Should be either “gcs” for Google Cloud Storage, “s3” for AWS S3 buckets, or “abs” for Azure Blob Storage.
output_filesystem_kwargs – A dictionary holding the kwargs for the filesystem given by output_filesystem. See gcsfs.GCSFileSystem for GCS, pyarrow.fs.S3FileSystem for S3, and adlfs.AzureBlobFileSystem for ABS filesystem arguments.
offline_sampling – Whether sampling for the Algorithm happens via reading from offline data. If True, EnvRunners don’t limit the number of collected batches within the same sample() call based on the number of sub-environments within the worker (no sub-environments present).
- Returns:
This updated AlgorithmConfig object.
Specifying Multi-Agent Options
- AlgorithmConfig.multi_agent(*, policies: ~typing.Dict[str, PolicySpec] | ~typing.Collection[str] | None = <ray.rllib.utils.from_config._NotProvided object>, policy_map_capacity: int | None = <ray.rllib.utils.from_config._NotProvided object>, policy_mapping_fn: ~typing.Callable[[~typing.Any, EpisodeType], str] | None = <ray.rllib.utils.from_config._NotProvided object>, policies_to_train: ~typing.Collection[str] | ~typing.Callable[[str, SampleBatch | MultiAgentBatch | ~typing.Dict[str, ~typing.Any]], bool] | None = <ray.rllib.utils.from_config._NotProvided object>, policy_states_are_swappable: bool | None = <ray.rllib.utils.from_config._NotProvided object>, observation_fn: ~typing.Callable | None = <ray.rllib.utils.from_config._NotProvided object>, count_steps_by: str | None = <ray.rllib.utils.from_config._NotProvided object>, algorithm_config_overrides_per_module=-1, replay_mode=-1, policy_map_cache=-1) → AlgorithmConfig
Sets the config’s multi-agent settings.
Validates the new multi-agent settings and translates everything into a unified multi-agent setup format. For example, a policies list or set of IDs is properly converted into a dict mapping these IDs to PolicySpecs.
- Parameters:
policies – Map of type MultiAgentPolicyConfigDict from policy ids to either 4-tuples of (policy_cls, obs_space, act_space, config) or PolicySpecs. These tuples or PolicySpecs define the class of the policy, the observation- and action spaces of the policies, and any extra config.
policy_map_capacity – Keep this many policies in the “policy_map” (before writing least-recently used ones to disk/S3).
policy_mapping_fn – Function mapping agent IDs to policy IDs. The signature is: (agent_id, episode, worker, **kwargs) -> PolicyID.
policies_to_train – Determines those policies that should be updated. Options are: - None, for training all policies. - An iterable of PolicyIDs that should be trained. - A callable, taking a PolicyID and a SampleBatch or MultiAgentBatch and returning a bool (indicating whether the given policy is trainable or not, given the particular batch). This allows you to have a policy trained only on certain data (e.g. when playing against a certain opponent).
policy_states_are_swappable – Whether all Policy objects in this map can be “swapped out” via a simple state = A.get_state(); B.set_state(state), where A and B are policy instances in this map. You should set this to True for significantly speeding up the PolicyMap’s cache lookup times, if and only if your policies all share the same neural network architecture and optimizer types. If True, the PolicyMap doesn’t have to garbage collect old, least recently used policies, but instead keeps them in memory and simply overrides their state with the state of the most recently accessed one. For example, in a league-based training setup, you might have 100s of the same policies in your map (playing against each other in various combinations), but all of them share the same state structure (are “swappable”).
observation_fn – Optional function that can be used to enhance the local agent observations to include more state. See rllib/evaluation/observation_function.py for more info.
count_steps_by – Which metric to use as the “batch size” when building a MultiAgentBatch. The two supported values are: “env_steps”: Count each time the env is “stepped” (no matter how many multi-agent actions are passed/how many multi-agent observations have been returned in the previous step). “agent_steps”: Count each individual agent step as one step.
- Returns:
This updated AlgorithmConfig object.
Specifying Reporting Options
- AlgorithmConfig.reporting(*, keep_per_episode_custom_metrics: bool | None = <ray.rllib.utils.from_config._NotProvided object>, metrics_episode_collection_timeout_s: float | None = <ray.rllib.utils.from_config._NotProvided object>, metrics_num_episodes_for_smoothing: int | None = <ray.rllib.utils.from_config._NotProvided object>, min_time_s_per_iteration: float | None = <ray.rllib.utils.from_config._NotProvided object>, min_train_timesteps_per_iteration: int | None = <ray.rllib.utils.from_config._NotProvided object>, min_sample_timesteps_per_iteration: int | None = <ray.rllib.utils.from_config._NotProvided object>, log_gradients: bool | None = <ray.rllib.utils.from_config._NotProvided object>) → AlgorithmConfig
Sets the config’s reporting settings.
- Parameters:
keep_per_episode_custom_metrics – Store raw custom metrics without calculating max, min, and mean.
metrics_episode_collection_timeout_s – Wait for metric batches for at most this many seconds. Those that have not returned in time are collected in the next train iteration.
metrics_num_episodes_for_smoothing – Smooth rollout metrics over this many episodes, if possible. In case rollouts (sample collection) just started, there may be fewer than this many episodes in the buffer and we’ll compute metrics over this smaller number of available episodes. In case there are more than this many episodes collected in a single training iteration, use all of these episodes for metrics computation, meaning don’t ever cut any “excess” episodes. Set this to 1 to disable smoothing and to always report only the most recently collected episode’s return.
min_time_s_per_iteration – Minimum time (in sec) to accumulate within a single Algorithm.train() call. This value does not affect learning, only the number of times Algorithm.training_step() is called by Algorithm.train(). If, after one such step attempt, the time taken has not reached min_time_s_per_iteration, the Algorithm performs n more Algorithm.training_step() calls until the minimum time has been consumed. Set to 0 or None for no minimum time.
min_train_timesteps_per_iteration – Minimum training timesteps to accumulate within a single train() call. This value does not affect learning, only the number of times Algorithm.training_step() is called by Algorithm.train(). If, after one such step attempt, the training timestep count has not been reached, the Algorithm performs n more training_step() calls until the minimum timesteps have been executed. Set to 0 or None for no minimum timesteps.
min_sample_timesteps_per_iteration – Minimum env sampling timesteps to accumulate within a single train() call. This value does not affect learning, only the number of times Algorithm.training_step() is called by Algorithm.train(). If, after one such step attempt, the env sampling timestep count has not been reached, the Algorithm performs n more training_step() calls until the minimum timesteps have been executed. Set to 0 or None for no minimum timesteps.
log_gradients – Log gradients to results. If this is True, the global norm of the gradients dictionary for each optimizer is logged to results. The default is True.
- Returns:
This updated AlgorithmConfig object.
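The repeat-until-minimum behavior of min_time_s_per_iteration and min_sample_timesteps_per_iteration can be pictured with the following pure-Python sketch. It is a hypothetical simplification, not RLlib’s actual implementation: training_step is assumed to be a callable returning the number of env timesteps it sampled, and min_train_timesteps_per_iteration (omitted here) would be handled analogously.

```python
import time


def run_train_iteration(training_step, min_time_s=0.0, min_sample_timesteps=0):
    """Hypothetical sketch of one Algorithm.train() call.

    Returns (number of training_step() calls, total sampled timesteps).
    """
    start = time.monotonic()
    num_calls = 0
    sampled = 0
    while True:
        sampled += training_step()
        num_calls += 1
        # A value of 0 or None disables the respective minimum.
        time_reached = not min_time_s or time.monotonic() - start >= min_time_s
        steps_reached = not min_sample_timesteps or sampled >= min_sample_timesteps
        if time_reached and steps_reached:
            return num_calls, sampled


# Each simulated step samples 200 timesteps, so 5 calls reach 1000.
print(run_train_iteration(lambda: 200, min_sample_timesteps=1000))  # (5, 1000)
```

In a real config, these thresholds are set via config.reporting(min_time_s_per_iteration=..., min_sample_timesteps_per_iteration=...).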
Specifying Checkpointing Options
- AlgorithmConfig.checkpointing(export_native_model_files: bool | None = <ray.rllib.utils.from_config._NotProvided object>, checkpoint_trainable_policies_only: bool | None = <ray.rllib.utils.from_config._NotProvided object>) → AlgorithmConfig
Sets the config’s checkpointing settings.
- Parameters:
export_native_model_files – Whether an individual Policy’s or the Algorithm’s checkpoints also contain (tf or torch) native model files. These could be used to restore just the NN models from these files without requiring RLlib. These files are generated by calling the tf or torch built-in saving utility methods on the actual models.
checkpoint_trainable_policies_only – Whether to only add Policies to the Algorithm checkpoint (in sub-directory “policies/”) that are trainable according to the is_trainable_policy callable of the local worker.
- Returns:
This updated AlgorithmConfig object.
Specifying Debugging Options
- AlgorithmConfig.debugging(*, logger_creator: ~typing.Callable[[], ~ray.tune.logger.logger.Logger] | None = <ray.rllib.utils.from_config._NotProvided object>, logger_config: dict | None = <ray.rllib.utils.from_config._NotProvided object>, log_level: str | None = <ray.rllib.utils.from_config._NotProvided object>, log_sys_usage: bool | None = <ray.rllib.utils.from_config._NotProvided object>, fake_sampler: bool | None = <ray.rllib.utils.from_config._NotProvided object>, seed: int | None = <ray.rllib.utils.from_config._NotProvided object>, _run_training_always_in_thread: bool | None = <ray.rllib.utils.from_config._NotProvided object>, _evaluation_parallel_to_training_wo_thread: bool | None = <ray.rllib.utils.from_config._NotProvided object>) → AlgorithmConfig
Sets the config’s debugging settings.
- Parameters:
logger_creator – Callable that creates a ray.tune.Logger object. If unspecified, a default logger is created.
logger_config – Define logger-specific configuration to be used inside the Logger. The default value None allows overwriting with nested dicts.
log_level – Set the ray.rllib.* log level for the agent process and its workers. Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level also periodically prints out summaries of relevant internal dataflow (this is also printed out once at startup at the INFO level).
log_sys_usage – Log system resource metrics to results. This requires psutil to be installed for sys stats, and gputil for GPU metrics.
fake_sampler – Use fake (infinite speed) sampler. For testing only.
seed – This argument, in conjunction with worker_index, sets the random seed of each worker, so that identically configured trials have identical results. This makes experiments reproducible.
_run_training_always_in_thread – Runs the n training_step() calls per iteration always in a separate thread (just as we would do with evaluation_parallel_to_training=True, but even without evaluation going on and even without evaluation workers being created in the Algorithm).
_evaluation_parallel_to_training_wo_thread – Only relevant if evaluation_parallel_to_training is True. Then, in order to achieve parallelism, RLlib doesn’t use a thread pool (as it usually does in this situation).
- Returns:
This updated AlgorithmConfig object.
Specifying Callback Options
- AlgorithmConfig.callbacks(callbacks_class) → AlgorithmConfig
Sets the callbacks configuration.
- Parameters:
callbacks_class – Callbacks class, whose methods are called during various phases of training and environment sample collection. See the DefaultCallbacks class and examples/metrics/custom_metrics_and_callbacks.py for more usage information.
- Returns:
This updated AlgorithmConfig object.
Specifying Resources
You can control the degree of parallelism used by setting the num_env_runners hyperparameter for most algorithms. The Algorithm creates that many “remote worker” instances (see RolloutWorker class) as ray.remote actors, plus exactly one “local worker”, an EnvRunner object that isn’t a Ray actor, but lives directly inside the Algorithm.
For most algorithms, learning updates are performed on the local worker and sample collection from one or more environments is performed by the remote workers (in parallel). For example, setting num_env_runners=0 only creates the local worker, in which case both sample collection and training are done by the local worker. On the other hand, setting num_env_runners=5 creates the local worker (responsible for training updates) and 5 remote workers (responsible for sample collection).
Since learning is usually done on the local worker, it may help to provide one or more GPUs to that worker via the num_gpus setting.
Similarly, you can control the resource allocation to remote workers with num_cpus_per_env_runner, num_gpus_per_env_runner, and custom_resources_per_env_runner.
GPU counts can be fractional (for example, 0.5) to allocate only a fraction of a GPU. For example, with DQN you can pack five algorithms onto one GPU by setting num_gpus: 0.2. See also this fractional GPU example, which demonstrates how environments (running on the remote workers) that require a GPU can benefit from the num_gpus_per_env_runner setting.
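For synchronous algorithms like PPO and A2C, the driver and workers can make use of the same GPU. To do this for n GPUs:
gpu_count = 1  # n, the number of GPUs to share (1 is an example value)
num_env_runners = 4  # example value
num_gpus = 0.0001  # Driver GPU: a tiny share so the driver can use the GPU
num_gpus_per_env_runner = (gpu_count - num_gpus) / num_env_runners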
If you specify num_gpus and your machine does not have the required number of GPUs available, the respective worker throws a RuntimeError. On the other hand, if you set num_gpus=0, your policies are built solely on the CPU, even if GPUs are available on the machine.
- AlgorithmConfig.resources(*, num_cpus_for_main_process: int | None = NotProvided, num_gpus: int | float | None = NotProvided, _fake_gpus: bool | None = NotProvided, placement_strategy: str | None = NotProvided, num_cpus_per_worker=-1, num_gpus_per_worker=-1, custom_resources_per_worker=-1, num_learner_workers=-1, num_cpus_per_learner_worker=-1, num_gpus_per_learner_worker=-1, local_gpu_idx=-1, num_cpus_for_local_worker=-1) → AlgorithmConfig
Specifies resources allocated for an Algorithm and its ray actors/workers.
- Parameters:
num_cpus_for_main_process – Number of CPUs to allocate for the main algorithm process that runs Algorithm.training_step(). Note: This is only relevant when running RLlib through Tune. Otherwise, Algorithm.training_step() runs in the main program (driver).
num_gpus – Number of GPUs to allocate to the algorithm process. Note that not all algorithms can take advantage of GPUs. Support for multi-GPU is currently only available for tf-[PPO/IMPALA/DQN/PG]. This can be fractional (e.g., 0.3 GPUs).
_fake_gpus – Set to True for debugging (multi-)GPU functionality on a CPU machine. In this case, GPU towers are simulated by graphs located on CPUs. Use num_gpus to test for different numbers of fake GPUs.
placement_strategy – The strategy for the placement group factory returned by Algorithm.default_resource_request(). A PlacementGroup defines which devices (resources) should always be co-located on the same node. For example, an Algorithm with 2 EnvRunners and 1 Learner (with 1 GPU) requests a placement group with the bundles [{"cpu": 1}, {"gpu": 1, "cpu": 1}, {"cpu": 1}, {"cpu": 1}]. The first bundle is for the local (main Algorithm) process, the second one is for the 1 Learner worker, and the last 2 bundles are for the two EnvRunners. These bundles can then be “placed” on the same or different nodes, depending on the value of placement_strategy:
"PACK": Packs bundles into as few nodes as possible.
"SPREAD": Places bundles across distinct nodes as evenly as possible.
"STRICT_PACK": Packs bundles into one node. The group is not allowed to span multiple nodes.
"STRICT_SPREAD": Places each bundle on a separate node.
- Returns:
This updated AlgorithmConfig object.
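As an illustration of the placement_strategy description, the sketch below rebuilds the described bundle layout: one bundle for the main process, then one per Learner, then one per EnvRunner. It also validates a strategy name against the four options listed. The functions are hypothetical and not RLlib code. They assume one CPU per bundle and use Ray's resource names "CPU" and "GPU".

```python
from typing import Dict, List

VALID_STRATEGIES = {"PACK", "SPREAD", "STRICT_PACK", "STRICT_SPREAD"}


def sketch_bundles(
    num_env_runners: int,
    num_learners: int = 0,
    num_gpus_per_learner: float = 0,
) -> List[Dict[str, float]]:
    """Hypothetical: main process first, then one bundle per Learner, then per EnvRunner."""
    bundles: List[Dict[str, float]] = [{"CPU": 1}]  # local (main Algorithm) process
    for _ in range(num_learners):
        learner: Dict[str, float] = {"CPU": 1}
        if num_gpus_per_learner > 0:
            learner["GPU"] = num_gpus_per_learner
        bundles.append(learner)
    # One separate dict per EnvRunner bundle.
    bundles.extend({"CPU": 1} for _ in range(num_env_runners))
    return bundles


def check_strategy(strategy: str) -> str:
    """Reject anything other than the four documented placement strategies."""
    if strategy not in VALID_STRATEGIES:
        raise ValueError(f"Unknown placement_strategy {strategy!r}")
    return strategy


bundles = sketch_bundles(num_env_runners=2, num_learners=1, num_gpus_per_learner=1)
print(bundles, check_strategy("PACK"))
```

On a running Ray cluster, you could hand a list like this to ray.util.placement_group(bundles, strategy="PACK"). Normally, though, RLlib builds the request for you through Algorithm.default_resource_request().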
Specifying Experimental Features#
- AlgorithmConfig.experimental(*, _torch_grad_scaler_class: Type | None = NotProvided, _torch_lr_scheduler_classes: List[Type] | Dict[str, List[Type]] | None = NotProvided, _tf_policy_handles_more_than_one_loss: bool | None = NotProvided, _disable_preprocessor_api: bool | None = NotProvided, _disable_action_flattening: bool | None = NotProvided, _disable_initialize_loss_from_dummy_batch: bool | None = NotProvided, _enable_new_api_stack=-1) → AlgorithmConfig
Sets the config’s experimental settings.
- Parameters:
_torch_grad_scaler_class – Class to use for torch loss scaling (and gradient unscaling). The class must implement the following methods to be compatible with a TorchLearner. These methods/APIs match exactly those of torch's own torch.amp.GradScaler (see here for more details https://pytorch.org/docs/stable/amp.html#gradient-scaling):
scale([loss]) to scale the loss by some factor.
get_scale() to get the current scale factor value.
step([optimizer]) to unscale the grads (divide by the scale factor) and step the given optimizer.
update() to update the scaler after an optimizer step (for example, to adjust the scale factor).
_torch_lr_scheduler_classes – A list of torch.optim.lr_scheduler.LRScheduler classes (see here for more details https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) or a dictionary mapping module IDs to such a list of respective scheduler classes. Multiple scheduler classes can be applied in sequence and are stepped in the same sequence as defined here. Note that most learning rate schedulers need arguments to be configured. That is, you might have to partially initialize the schedulers in the list(s) using functools.partial.
_tf_policy_handles_more_than_one_loss – Experimental flag. If True, TFPolicy handles more than one loss or optimizer. Set this to True if you would like to return more than one loss term from your loss_fn and an equal number of optimizers from your optimizer_fn.
_disable_preprocessor_api – Experimental flag. If True, no (observation) preprocessor is created, and observations arrive in the model as they are returned by the env.
_disable_action_flattening – Experimental flag. If True, RLlib doesn't flatten the policy-computed actions into a single tensor (for storage in SampleCollectors/output files/etc.) but leaves (possibly nested) actions as-is. Disabling flattening affects:
- SampleCollectors: Have to store possibly nested action structs.
- Models that have the previous action(s) as part of their input.
- Algorithms reading from offline files (incl. action information).
- Returns:
This updated AlgorithmConfig object.
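The functools.partial advice for _torch_lr_scheduler_classes can be sketched as follows. Each entry pre-binds a standard PyTorch scheduler's constructor arguments, so that it only needs the optimizer when called. This assumes that RLlib instantiates each entry by calling it with the optimizer; check the TorchLearner source for the exact call. The gamma and step_size values are arbitrary examples, and the throwaway parameter and optimizer exist only to exercise the list locally.

```python
import functools

import torch

# Each entry is a callable that takes only the optimizer; the remaining
# constructor arguments are bound ahead of time with functools.partial.
scheduler_classes = [
    functools.partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.5),
    functools.partial(torch.optim.lr_scheduler.StepLR, step_size=2, gamma=0.1),
]

# Local check with a throwaway parameter and optimizer (not RLlib code).
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1.0)
schedulers = [make(optimizer) for make in scheduler_classes]

for _ in range(2):
    optimizer.step()
    # Step the schedulers in the same order as they appear in the list.
    for scheduler in schedulers:
        scheduler.step()

print(optimizer.param_groups[0]["lr"])
```

The list itself (not the instantiated schedulers) is what you would pass as _torch_lr_scheduler_classes=scheduler_classes to config.experimental(). Alternatively, you can wrap it in a dict keyed by module ID when different modules need different schedules.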
RLlib Scaling Guide#
Here are some rules of thumb for scaling training with RLlib.
1. If the environment is slow and cannot be replicated (e.g., because it requires interaction with physical systems), use a sample-efficient off-policy algorithm such as DQN or SAC. These algorithms default to num_env_runners: 0 for single-process operation. Make sure to set num_gpus: 1 if you want to use a GPU. Also consider batch RL training with the offline data API.
2. If the environment is fast and the model is small (most models for RL are), use time-efficient algorithms such as PPO or IMPALA. You can scale these by increasing num_env_runners to add rollout workers. It may also make sense to enable vectorization for inference. Make sure to set num_gpus: 1 if you want to use a GPU. If the learner becomes a bottleneck, you can use multiple GPUs for learning by setting num_gpus > 1.
3. If the model is compute intensive (e.g., a large deep residual network) and inference is the bottleneck, consider allocating GPUs to workers by setting num_gpus_per_env_runner: 1. If you only have a single GPU, consider num_env_runners: 0 to use the learner GPU for inference. For efficient use of GPU time, use a small number of GPU workers and a large number of envs per worker.
4. Finally, if both model and environment are compute intensive, enable remote worker envs with async batching by setting remote_worker_envs: True and, optionally, remote_env_batch_wait_ms. This batches inference on GPUs in the rollout workers while letting envs run asynchronously in separate actors, similar to the SEED architecture. Tune the number of workers and the number of envs per worker to maximize GPU utilization.
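One way to read the four rules of thumb above is as a small decision procedure. The function below is a hypothetical encoding, not an RLlib API. It returns only the settings the rules name. The num_env_runners argument is an arbitrary example value for the scaled-out cases. Combinations the rules don't cover (for example, a heavy model without any GPU) fall back to rule 2.

```python
from typing import Any, Dict


def suggest_scaling_settings(
    env_is_slow: bool,
    model_is_heavy: bool,
    env_is_heavy: bool = False,
    num_gpus_available: int = 0,
    num_env_runners: int = 8,  # example value, tune for your cluster
) -> Dict[str, Any]:
    """Hypothetical encoding of the scaling rules of thumb; not an RLlib API."""
    gpu = 1 if num_gpus_available > 0 else 0
    if env_is_slow:
        # Rule 1: sample-efficient off-policy algorithm in a single process.
        return {"algorithm": "DQN or SAC", "num_env_runners": 0, "num_gpus": gpu}
    if model_is_heavy and env_is_heavy:
        # Rule 4: batched inference while envs step asynchronously in separate actors.
        return {"num_env_runners": num_env_runners, "remote_worker_envs": True}
    if model_is_heavy and num_gpus_available == 1:
        # Rule 3, single GPU: share the learner GPU for inference.
        return {"num_env_runners": 0, "num_gpus": 1}
    if model_is_heavy and num_gpus_available > 1:
        # Rule 3: keep num_env_runners small; each EnvRunner asks for a whole GPU.
        return {"num_env_runners": num_env_runners, "num_gpus_per_env_runner": 1}
    # Rule 2: fast env and small model; scale out sampling with more EnvRunners.
    return {"algorithm": "PPO or IMPALA", "num_env_runners": num_env_runners, "num_gpus": gpu}


print(suggest_scaling_settings(env_is_slow=False, model_is_heavy=False, num_gpus_available=1))
```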
If you are using many workers (num_env_runners >> 10) and observe worker failures, which would normally interrupt your RLlib training runs, consider using the config settings ignore_env_runner_failures=True, restart_failed_env_runners=True, or restart_failed_sub_environments=True:
- restart_failed_env_runners: When set to True (default), your Algorithm attempts to restart any failed EnvRunner and replace it with a newly created one. This way, your number of workers never decreases, even if some of them fail from time to time.
- ignore_env_runner_failures: When set to True, your Algorithm doesn't crash due to an EnvRunner error but continues for as long as at least one functional worker remains. This setting is ignored when restart_failed_env_runners=True.
- restart_failed_sub_environments: When set to True and one of the vectorized sub-environments in one of your EnvRunners fails, RLlib tries to recreate only the failed sub-environment and re-integrate the newly created one into your vectorized env stack on that EnvRunner.
Note that only one of ignore_env_runner_failures or restart_failed_env_runners should be set to True (they are mutually exclusive settings). However, you can combine either of them with the restart_failed_sub_environments=True setting.
Using these options makes your training runs more stable and more robust against occasional OOM or similar “once in a while” errors, both on the EnvRunners themselves and inside your custom environments.
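The way these three flags interact can be summarized in a few lines. The function below is a hypothetical restatement of the documented behavior, not RLlib's implementation. It assumes restart_failed_env_runners defaults to True, as stated above, and that the other two flags default to False.

```python
def effective_failure_mode(
    restart_failed_env_runners: bool = True,
    ignore_env_runner_failures: bool = False,
    restart_failed_sub_environments: bool = False,
) -> str:
    """Hypothetical summary of how the documented fault-tolerance flags interact."""
    if restart_failed_env_runners:
        # ignore_env_runner_failures has no effect in this case.
        mode = "restart failed EnvRunners"
    elif ignore_env_runner_failures:
        mode = "continue while at least one EnvRunner works"
    else:
        mode = "crash on EnvRunner failure"
    if restart_failed_sub_environments:
        # Combinable with either of the two EnvRunner-level options.
        mode += " + recreate failed sub-environments"
    return mode


print(effective_failure_mode(restart_failed_env_runners=False, ignore_env_runner_failures=True))
```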
Debugging RLlib Experiments#
Eager Mode#
Policies built with build_tf_policy (as most of the reference algorithms are) can run in eager mode if you set the "framework": "tf2" config option, optionally together with "eager_tracing": True.
This tells RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.
Eager mode makes debugging much easier, since you can use line-by-line debugging with breakpoints or Python print() to inspect intermediate tensor values.
However, eager mode can be slower than graph mode unless tracing is enabled.
Episode Traces#
You can use the data output API to save episode traces for debugging. For example, the following command runs PPO while saving episode traces to /tmp/debug.
cd rllib/tuned_examples/ppo
python cartpole_ppo.py --output /tmp/debug
# episode traces will be saved in /tmp/debug, for example:
# output-2019-02-23_12-02-03_worker-2_0.json
# output-2019-02-23_12-02-04_worker-1_0.json
Log Verbosity#
You can control the log level via the "log_level" flag. Valid values are “DEBUG”, “INFO”, “WARN” (default), and “ERROR”. You can use it to increase or decrease the verbosity of internal logging.
For example:
cd rllib/tuned_examples/ppo
python atari_ppo.py --env ALE/Pong-v5 --log-level INFO
python atari_ppo.py --env ALE/Pong-v5 --log-level DEBUG
The default log level is WARN. We strongly recommend using at least INFO level logging for development.
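In an AlgorithmConfig, you typically set the flag itself through config.debugging(log_level=...). If RLlib runs inside your own script, there is also a complementary option that uses Python's standard logging module instead of the log_level flag. This approach assumes that RLlib's modules log through loggers named after their module path under "ray.rllib", which is the usual stdlib convention. The helper name below is hypothetical.

```python
import logging

# The four levels accepted by RLlib's "log_level" flag, per the text above.
VALID_LOG_LEVELS = ("DEBUG", "INFO", "WARN", "ERROR")


def set_rllib_log_level(level: str = "INFO") -> logging.Logger:
    """Hypothetical helper: apply a log_level-style name to the "ray.rllib" logger."""
    if level not in VALID_LOG_LEVELS:
        raise ValueError(f"log_level must be one of {VALID_LOG_LEVELS}, got {level!r}")
    logger = logging.getLogger("ray.rllib")
    # The standard library accepts "WARN" as an alias of "WARNING".
    logger.setLevel(level)
    return logger


logger = set_rllib_log_level("INFO")
print(logging.getLevelName(logger.level))
```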
Stack Traces#
You can use the ray stack command to dump the stack traces of all the Python workers on a single node. This can be useful for debugging unexpected hangs or performance issues.
Next Steps#
To check how your application is doing, you can use the Ray dashboard.