Note

From Ray 2.6.0 onwards, RLlib adopts a new stack for training and model customization, gradually replacing the ModelV2 API and some convoluted parts of the Policy API with the RLModule API. See the RLModule documentation for details.

Getting Started with RLlib#

At a high level, RLlib provides you with an Algorithm class which holds a policy for environment interaction. Through the algorithm’s interface, you can train the policy, compute actions, or store your algorithm. In multi-agent training, the algorithm manages the querying and optimization of multiple policies at once.

[Figure: RLlib API overview (rllib-api.svg)]

In this guide, we will first walk you through running your first experiments with the RLlib CLI, and then discuss our Python API in more detail.

Using the RLlib CLI#

The quickest way to run your first RLlib algorithm is to use the command line interface. You can train DQN with the following commands:

pip install "ray[rllib]" tensorflow
rllib train --algo DQN --env CartPole-v1 --stop '{"training_iteration": 30}'

Note

The rllib train command (same as the train.py script in the repo) has a number of options you can show by running rllib train --help.

Note that you can choose any supported RLlib algorithm (--algo) and environment (--env). RLlib supports any Farama-Foundation Gymnasium environment, as well as a number of other environments (see Environments). It also supports a large number of algorithms (see Algorithms) to choose from.

Running the above command trains for 30 iterations and returns one of the checkpoints generated during training, as well as a command that you can use to evaluate the trained algorithm. You can evaluate the trained algorithm with the following command (assuming the checkpoint path is called checkpoint):

rllib evaluate checkpoint --algo DQN --env CartPole-v1

Note

By default, the results are logged to a subdirectory of ~/ray_results. This subdirectory contains a file params.json with the hyperparameters, a file result.json with a training summary for each iteration, and a TensorBoard event file that can be used to visualize the training process by running

tensorboard --logdir=~/ray_results

For more advanced evaluation functionality, refer to Customized Evaluation During Training.

Note

Each algorithm has specific hyperparameters that can be set with --config; see the algorithms documentation for more information. For instance, you can train the A2C algorithm on 8 workers by specifying num_workers: 8 in a JSON string passed to --config:

rllib train --env=PongDeterministic-v4 --run=A2C --config '{"num_workers": 8}'

Running Tuned Examples#

Some good hyperparameters and settings are available in the RLlib repository (some of them are tuned to run on GPUs).

Note

If you find better settings or tune an algorithm on a different domain, consider submitting a Pull Request!

You can run these with the rllib train file command as follows:

rllib train file /path/to/tuned/example.yaml

Note that this works with any local YAML file in the correct format, or with remote URLs pointing to such files. If you want to learn more about the RLlib CLI, please check out the RLlib CLI user guide.

Using the Python API#

The Python API provides the needed flexibility for applying RLlib to new problems. For instance, you will need to use this API if you wish to use custom environments, preprocessors, or models with RLlib.

Here is an example of the basic usage. We first create a PPOConfig and add properties to it, like the environment we want to use, or the resources we want to leverage for training. After we build the algo from its configuration, we can train it for a number of training iterations (here 10) and save the resulting policy periodically (here every 5 iterations).

from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print


algo = (
    PPOConfig()
    .rollouts(num_rollout_workers=1)
    .resources(num_gpus=0)
    .environment(env="CartPole-v1")
    .build()
)

for i in range(10):
    result = algo.train()
    print(pretty_print(result))

    if i % 5 == 0:
        checkpoint_dir = algo.save().checkpoint.path
        print(f"Checkpoint saved in directory {checkpoint_dir}")

All RLlib algorithms are compatible with the Tune API. This enables them to be easily used in experiments with Ray Tune. For example, the following code performs a simple hyper-parameter sweep of PPO.

import ray
from ray import train, tune
from ray.rllib.algorithms.ppo import PPOConfig

ray.init()

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(lr=tune.grid_search([0.01, 0.001, 0.0001]))
)

tuner = tune.Tuner(
    "PPO",
    run_config=train.RunConfig(
        stop={"episode_reward_mean": 150},
    ),
    param_space=config,
)

tuner.fit()

Tune will schedule the trials to run in parallel on your Ray cluster:

== Status ==
Using FIFO scheduling algorithm.
Resources requested: 4/4 CPUs, 0/0 GPUs
Result logdir: ~/ray_results/my_experiment
PENDING trials:
 - PPO_CartPole-v1_2_lr=0.0001:     PENDING
RUNNING trials:
 - PPO_CartPole-v1_0_lr=0.01:       RUNNING [pid=21940], 16 s, 4013 ts, 22 rew
 - PPO_CartPole-v1_1_lr=0.001:      RUNNING [pid=21942], 27 s, 8111 ts, 54.7 rew

Tuner.fit() returns a ResultGrid object that allows further analysis of the training results and retrieval of the checkpoint(s) of the trained agent.

# ``Tuner.fit()`` allows setting a custom log directory (other than ``~/ray_results``);
# here, we also configure a checkpoint to be created at the end of training.
tuner = ray.tune.Tuner(
    "PPO",
    param_space=config,
    run_config=train.RunConfig(
        stop={"episode_reward_mean": 150},
        checkpoint_config=train.CheckpointConfig(checkpoint_at_end=True),
    ),
)

results = tuner.fit()

# Get the best result based on a particular metric.
best_result = results.get_best_result(metric="episode_reward_mean", mode="max")

# Get the best checkpoint corresponding to the best result.
best_checkpoint = best_result.checkpoint

Note

You can find your checkpoint’s version by looking into the rllib_checkpoint.json file inside your checkpoint directory.

Loading and restoring a trained algorithm from a checkpoint is simple. Let’s assume you have a local checkpoint directory called checkpoint_path. To load newer RLlib checkpoints (version >= 1.0), use the following code:

from ray.rllib.algorithms.algorithm import Algorithm
algo = Algorithm.from_checkpoint(checkpoint_path)

For older RLlib checkpoint versions (version < 1.0), you can restore an algorithm via:

from ray.rllib.algorithms.ppo import PPO
algo = PPO(config=config, env=env_class)
algo.restore(checkpoint_path)

Computing Actions#

The simplest way to programmatically compute actions from a trained agent is to use Algorithm.compute_single_action(). This method preprocesses and filters the observation before passing it to the agent policy. Here is a simple example of testing a trained agent for one episode:

# Note: `gymnasium` (not `gym`) will be **the** API supported by RLlib from Ray 2.3 on.
try:
    import gymnasium as gym

    gymnasium = True
except Exception:
    import gym

    gymnasium = False

from ray.rllib.algorithms.ppo import PPOConfig

env_name = "CartPole-v1"
env = gym.make(env_name)
algo = PPOConfig().environment(env_name).build()

episode_reward = 0
terminated = truncated = False

if gymnasium:
    obs, info = env.reset()
else:
    obs = env.reset()

while not terminated and not truncated:
    action = algo.compute_single_action(obs)
    if gymnasium:
        obs, reward, terminated, truncated, info = env.step(action)
    else:
        obs, reward, terminated, info = env.step(action)
    episode_reward += reward

For more advanced usage on computing actions and other functionality, you can consult the RLlib Algorithm API documentation.

Accessing Policy State#

It is common to need to access an algorithm’s internal state, for instance to set or get model weights.

In RLlib, algorithm state is replicated across multiple rollout workers (Ray actors) in the cluster. However, you can easily get and update this state between calls to train() via Algorithm.workers.foreach_worker() or Algorithm.workers.foreach_worker_with_id(). These functions take a lambda that is applied to each worker, and they return the per-worker results as a list.

You can also access just the “master” copy of the algorithm state through Algorithm.get_policy() or Algorithm.workers.local_worker(), but note that updates here may not be immediately reflected in your remote rollout workers (if you have configured num_rollout_workers > 0). Here’s a quick example of how to access the state of a model:

from ray.rllib.algorithms.dqn import DQNConfig

algo = DQNConfig().environment(env="CartPole-v1").build()

# Get weights of the default local policy
algo.get_policy().get_weights()

# Same as above
algo.workers.local_worker().policy_map["default_policy"].get_weights()

# Get list of weights of each worker, including remote replicas
algo.workers.foreach_worker(lambda worker: worker.get_policy().get_weights())

# Same as above, but with index.
algo.workers.foreach_worker_with_id(
    lambda _id, worker: worker.get_policy().get_weights()
)
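
If you modify the local weights (for example, after loading pre-trained values), you typically want to push them back out to the remote workers. Below is a minimal sketch using the same APIs as above; it assumes the algo object built in the previous snippet.

# Broadcast the (possibly modified) local policy weights to all remote workers.
weights = algo.get_policy().get_weights()
algo.workers.foreach_worker(
    lambda worker: worker.get_policy().set_weights(weights)
)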

Accessing Model State#

Similar to accessing policy state, you may want to get a reference to the underlying neural network model being trained. For example, you may want to pre-train it separately, or otherwise update its weights outside of RLlib. This can be done by accessing the model of the policy.

Note

To run these examples, you need to install a few extra dependencies, namely pip install "gym[atari]" "gym[accept-rom-license]" atari_py.

Below you find three explicit examples showing how to access the model state of an algorithm.

Example: Preprocessing observations for feeding into a model


try:
    import gymnasium as gym

    env = gym.make("ALE/Pong-v5")
    obs, infos = env.reset()
except Exception:
    import gym

    env = gym.make("PongNoFrameskip-v4")
    obs = env.reset()

# RLlib uses preprocessors to implement transforms such as one-hot encoding
# and flattening of tuple and dict observations.
from ray.rllib.models.preprocessors import get_preprocessor

prep = get_preprocessor(env.observation_space)(env.observation_space)
# <ray.rllib.models.preprocessors.GenericPixelPreprocessor object at 0x7fc4d049de80>

# Observations should be preprocessed prior to feeding into a model
obs.shape
# (210, 160, 3)
prep.transform(obs).shape
# (84, 84, 3)
Example: Querying a policy’s action distribution
# Get a reference to the policy
import numpy as np
from ray.rllib.algorithms.dqn import DQNConfig

algo = (
    DQNConfig()
    .environment("CartPole-v1")
    .framework("tf2")
    .rollouts(num_rollout_workers=0)
    .build()
)
# <ray.rllib.algorithms.dqn.DQN object at 0x7fd020186384>

policy = algo.get_policy()
# <ray.rllib.policy.eager_tf_policy.DQNTFPolicy_eager object at 0x7fd020165470>

# Run a forward pass to get model output logits. Note that complex observations
# must be preprocessed as in the above code block.
logits, _ = policy.model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
# (<tf.Tensor: id=1274, shape=(1, 2), dtype=float32, numpy=...>, [])

# Compute action distribution given logits
policy.dist_class
# <class_object 'ray.rllib.models.tf.tf_action_dist.Categorical'>
dist = policy.dist_class(logits, policy.model)
# <ray.rllib.models.tf.tf_action_dist.Categorical object at 0x7fd02301d710>

# Query the distribution for samples, sample logps
dist.sample()
# <tf.Tensor: id=661, shape=(1,), dtype=int64, numpy=..>
dist.logp([1])
# <tf.Tensor: id=1298, shape=(1,), dtype=float32, numpy=...>

# Get the estimated values for the most recent forward pass
policy.model.value_function()
# <tf.Tensor: id=670, shape=(1,), dtype=float32, numpy=...>

policy.model.base_model.summary()
"""
Model: "model"
_____________________________________________________________________
Layer (type)               Output Shape  Param #  Connected to
=====================================================================
observations (InputLayer)  [(None, 4)]   0
_____________________________________________________________________
fc_1 (Dense)               (None, 256)   1280     observations[0][0]
_____________________________________________________________________
fc_value_1 (Dense)         (None, 256)   1280     observations[0][0]
_____________________________________________________________________
fc_2 (Dense)               (None, 256)   65792    fc_1[0][0]
_____________________________________________________________________
fc_value_2 (Dense)         (None, 256)   65792    fc_value_1[0][0]
_____________________________________________________________________
fc_out (Dense)             (None, 2)     514      fc_2[0][0]
_____________________________________________________________________
value_out (Dense)          (None, 1)     257      fc_value_2[0][0]
=====================================================================
Total params: 134,915
Trainable params: 134,915
Non-trainable params: 0
_____________________________________________________________________
"""
Example: Getting Q values from a DQN model
# Get a reference to the model through the policy
import numpy as np
from ray.rllib.algorithms.dqn import DQNConfig

algo = DQNConfig().environment("CartPole-v1").framework("tf2").build()
model = algo.get_policy().model
# <ray.rllib.models.catalog.FullyConnectedNetwork_as_DistributionalQModel ...>

# List of all model variables
model.variables()

# Run a forward pass to get base model output. Note that complex observations
# must be preprocessed. An example of preprocessing is examples/saving_experiences.py
model_out = model({"obs": np.array([[0.1, 0.2, 0.3, 0.4]])})
# (<tf.Tensor: id=832, shape=(1, 256), dtype=float32, numpy=...)

# Access the base Keras models (all default models have a base)
model.base_model.summary()
"""
Model: "model"
_______________________________________________________________________
Layer (type)                Output Shape    Param #  Connected to
=======================================================================
observations (InputLayer)   [(None, 4)]     0
_______________________________________________________________________
fc_1 (Dense)                (None, 256)     1280     observations[0][0]
_______________________________________________________________________
fc_out (Dense)              (None, 256)     65792    fc_1[0][0]
_______________________________________________________________________
value_out (Dense)           (None, 1)       257      fc_1[0][0]
=======================================================================
Total params: 67,329
Trainable params: 67,329
Non-trainable params: 0
______________________________________________________________________________
"""

# Access the Q value model (specific to DQN)
print(model.get_q_value_distributions(model_out)[0])
# tf.Tensor([[ 0.13023682 -0.36805138]], shape=(1, 2), dtype=float32)
# ^ exact numbers may differ due to randomness

model.q_value_head.summary()

# Access the state value model (specific to DQN)
print(model.get_state_value(model_out))
# tf.Tensor([[0.09381643]], shape=(1, 1), dtype=float32)
# ^ exact number may differ due to randomness

model.state_value_head.summary()

This is especially useful when used with custom model classes.

Configuring RLlib Algorithms#

You can configure RLlib algorithms in a modular fashion by working with so-called AlgorithmConfig objects. In essence, you first create a config = AlgorithmConfig() object and then call methods on it to set the desired configuration options. Each RLlib algorithm has its own config class that inherits from AlgorithmConfig. For instance, to create a PPO algorithm, you start with a PPOConfig object, to work with a DQN algorithm, you start with a DQNConfig object, etc.

Note

Each algorithm has its specific settings, but most configuration options are shared. We discuss the common options below, and refer to the RLlib algorithms guide for algorithm-specific properties. Algorithms differ mostly in their training settings.

Below you find the basic signature of the AlgorithmConfig class, as well as some advanced usage examples:

class ray.rllib.algorithms.algorithm_config.AlgorithmConfig(algo_class=None)[source]

An RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration.

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.callbacks import MemoryTrackingCallbacks
# Construct a generic config object, specifying values within different
# sub-categories, e.g. "training".
config = (
    PPOConfig()
    .training(gamma=0.9, lr=0.01)
    .environment(env="CartPole-v1")
    .resources(num_gpus=0)
    .rollouts(num_rollout_workers=0)
    .callbacks(MemoryTrackingCallbacks)
)
# A config object can be used to construct the respective Algorithm.
rllib_algo = config.build()

from ray.rllib.algorithms.ppo import PPOConfig
from ray import tune
# In combination with a tune.grid_search:
config = PPOConfig()
config.training(lr=tune.grid_search([0.01, 0.001]))
# Use `to_dict()` method to get the legacy plain python config dict
# for usage with `tune.Tuner().fit()`.
tune.Tuner("PPO", param_space=config.to_dict())

As RLlib algorithms are fairly complex, they come with many configuration options. To make things easier, the common properties of algorithms are grouped into the categories covered by the sections below.

Let’s discuss each category one by one, starting with training options.

Specifying Training Options#

Note

For instance, a DQNConfig takes a double_q training argument to specify whether to use a double-Q DQN, whereas in a PPOConfig this does not make sense.

For individual algorithms, this is probably the most relevant configuration group, as this is where all the algorithm-specific options go. But the base configuration for training of an AlgorithmConfig is actually quite small:

AlgorithmConfig.training(*, gamma: float | None = <ray.rllib.utils.from_config._NotProvided object>, lr: float | ~typing.List[~typing.List[int | float]] | None = <ray.rllib.utils.from_config._NotProvided object>, grad_clip: float | None = <ray.rllib.utils.from_config._NotProvided object>, grad_clip_by: str | None = <ray.rllib.utils.from_config._NotProvided object>, train_batch_size: int | None = <ray.rllib.utils.from_config._NotProvided object>, model: dict | None = <ray.rllib.utils.from_config._NotProvided object>, optimizer: dict | None = <ray.rllib.utils.from_config._NotProvided object>, max_requests_in_flight_per_sampler_worker: int | None = <ray.rllib.utils.from_config._NotProvided object>, learner_class: ~typing.Type[Learner] | None = <ray.rllib.utils.from_config._NotProvided object>, _enable_learner_api: bool | None = <ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[source]

Sets the training related configuration.

Parameters:
  • gamma – Float specifying the discount factor of the Markov Decision process.

  • lr – The learning rate (float) or learning rate schedule in the format of [[timestep, lr-value], [timestep, lr-value], …] In case of a schedule, intermediary timesteps will be assigned to linearly interpolated learning rate values. A schedule config’s first entry must start with timestep 0, i.e.: [[0, initial_value], […]]. Note: If you require a) more than one optimizer (per RLModule), b) optimizer types that are not Adam, c) a learning rate schedule that is not a linearly interpolated, piecewise schedule as described above, or d) specifying c’tor arguments of the optimizer that are not the learning rate (e.g. Adam’s epsilon), then you must override your Learner’s configure_optimizer_for_module() method and handle lr-scheduling yourself.

  • grad_clip – If None, no gradient clipping will be applied. Otherwise, depending on the setting of grad_clip_by, the (float) value of grad_clip will have the following effect: If grad_clip_by=value: Will clip all computed gradients individually inside the interval [-grad_clip, +grad_clip]. If grad_clip_by=norm, will compute the L2-norm of each weight/bias gradient tensor individually and then clip all gradients such that these L2-norms do not exceed grad_clip. The L2-norm of a tensor is computed via: sqrt(SUM(w0^2, w1^2, ..., wn^2)) where w[i] are the elements of the tensor (no matter what the shape of this tensor is). If grad_clip_by=global_norm, will compute the square of the L2-norm of each weight/bias gradient tensor individually, sum up all these squared L2-norms across all given gradient tensors (e.g. the entire module to be updated), square root that overall sum, and then clip all gradients such that this global L2-norm does not exceed the given value. The global L2-norm over a list of tensors (e.g. W and V) is computed via: sqrt[SUM(w0^2, w1^2, ..., wn^2) + SUM(v0^2, v1^2, ..., vm^2)], where w[i] and v[j] are the elements of the tensors W and V (no matter what the shapes of these tensors are).

  • grad_clip_by – See grad_clip for the effect of this setting on gradient clipping. Allowed values are value, norm, and global_norm.

  • train_batch_size – Training batch size, if applicable.

  • model – Arguments passed into the policy model. See models/catalog.py for a full list of the available model options. TODO: Provide ModelConfig objects instead of dicts.

  • optimizer – Arguments to pass to the policy optimizer. This setting is not used when _enable_new_api_stack=True.

  • max_requests_in_flight_per_sampler_worker – Max number of inflight requests to each sampling worker. See the FaultTolerantActorManager class for more details. Tuning these values is important when running experiments with large sample batches, where there is the risk that the object store may fill up, causing spilling of objects to disk. This can cause any asynchronous requests to become very slow, making your experiment run slow as well. You can inspect the object store during your experiment via a call to ray memory on your head node, and by using the Ray dashboard. If you’re seeing that the object store is filling up, turn down the number of remote requests in flight, or enable compression in your experiment.

  • learner_class – The Learner class to use for (distributed) updating of the RLModule. Only used when _enable_new_api_stack=True.

Returns:

This updated AlgorithmConfig object.
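
As a rough sketch (the values are illustrative only, not tuned hyperparameters), setting a few of these base training options on a PPOConfig could look like this:

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().training(
    gamma=0.99,                      # discount factor
    lr=[[0, 5e-4], [100000, 5e-5]],  # piecewise lr schedule: [[timestep, lr-value], ...]
    grad_clip=40.0,
    grad_clip_by="global_norm",
    train_batch_size=4000,
)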

Specifying Environments#

AlgorithmConfig.environment(env: str | ~typing.Any | gymnasium.Env | None = <ray.rllib.utils.from_config._NotProvided object>, *, env_config: dict | None = <ray.rllib.utils.from_config._NotProvided object>, observation_space: gymnasium.spaces.Space | None = <ray.rllib.utils.from_config._NotProvided object>, action_space: gymnasium.spaces.Space | None = <ray.rllib.utils.from_config._NotProvided object>, env_task_fn: ~typing.Callable[[dict, ~typing.Any | gymnasium.Env, ~ray.rllib.env.env_context.EnvContext], ~typing.Any] | None = <ray.rllib.utils.from_config._NotProvided object>, render_env: bool | None = <ray.rllib.utils.from_config._NotProvided object>, clip_rewards: bool | float | None = <ray.rllib.utils.from_config._NotProvided object>, normalize_actions: bool | None = <ray.rllib.utils.from_config._NotProvided object>, clip_actions: bool | None = <ray.rllib.utils.from_config._NotProvided object>, disable_env_checking: bool | None = <ray.rllib.utils.from_config._NotProvided object>, is_atari: bool | None = <ray.rllib.utils.from_config._NotProvided object>, auto_wrap_old_gym_envs: bool | None = <ray.rllib.utils.from_config._NotProvided object>, action_mask_key: str | None = <ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[source]

Sets the config’s RL-environment settings.

Parameters:
  • env – The environment specifier. This can either be a tune-registered env, via tune.register_env([name], lambda env_ctx: [env object]), or a string specifier of an RLlib supported type. In the latter case, RLlib will try to interpret the specifier as either an Farama-Foundation gymnasium env, a PyBullet env, a ViZDoomGym env, or a fully qualified classpath to an Env class, e.g. “ray.rllib.examples.env.random_env.RandomEnv”.

  • env_config – Arguments dict passed to the env creator as an EnvContext object (which is a dict plus the properties: num_rollout_workers, worker_index, vector_index, and remote).

  • observation_space – The observation space for the Policies of this Algorithm.

  • action_space – The action space for the Policies of this Algorithm.

  • env_task_fn – A callable taking the last train results, the base env and the env context as args and returning a new task to set the env to. The env must be a TaskSettableEnv sub-class for this to work. See examples/curriculum_learning.py for an example.

  • render_env – If True, try to render the environment on the local worker or on worker 1 (if num_rollout_workers > 0). For vectorized envs, this usually means that only the first sub-environment will be rendered. In order for this to work, your env will have to implement the render() method which either: a) handles window generation and rendering itself (returning True) or b) returns a numpy uint8 image of shape [height x width x 3 (RGB)].

  • clip_rewards – Whether to clip rewards during Policy’s postprocessing. None (default): Clip for Atari only (r=sign(r)). True: r=sign(r): Fixed rewards -1.0, 1.0, or 0.0. False: Never clip. [float value]: Clip at -value and + value. Tuple[value1, value2]: Clip at value1 and value2.

  • normalize_actions – If True, RLlib will learn entirely inside a normalized action space (0.0 centered with small stddev; only affecting Box components). We will unsquash actions (and clip, just in case) to the bounds of the env’s action space before sending actions back to the env.

  • clip_actions – If True, RLlib will clip actions according to the env’s bounds before sending them back to the env. TODO: (sven) This option should be deprecated and always be False.

  • disable_env_checking – If True, disable the environment pre-checking module.

  • is_atari – This config can be used to explicitly specify whether the env is an Atari env or not. If not specified, RLlib will try to auto-detect this.

  • auto_wrap_old_gym_envs – Whether to auto-wrap old gym environments (using the pre 0.24 gym APIs, e.g. reset() returning single obs and no info dict). If True, RLlib will automatically wrap the given gym env class with the gym-provided compatibility wrapper (gym.wrappers.EnvCompatibility). If False, RLlib will produce a descriptive error on which steps to perform to upgrade to gymnasium (or to switch this flag to True).

  • action_mask_key – If observation is a dictionary, expect the value by the key action_mask_key to contain a valid actions mask (numpy.int8 array of zeros and ones). Defaults to “action_mask”.

Returns:

This updated AlgorithmConfig object.
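
For illustration, here is a sketch of registering a custom environment creator with Tune and pointing the config at it (the name "my_cartpole" and the empty env_config are placeholders):

import gymnasium as gym

from ray.tune.registry import register_env
from ray.rllib.algorithms.ppo import PPOConfig

# Register an env creator under a name; env_ctx is an EnvContext (a dict plus metadata).
register_env("my_cartpole", lambda env_ctx: gym.make("CartPole-v1"))

config = PPOConfig().environment(
    env="my_cartpole",
    env_config={},       # passed through to the env creator as an EnvContext
    clip_rewards=None,   # default: clip rewards for Atari only
)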

Specifying Framework Options#

AlgorithmConfig.framework(framework: str | None = <ray.rllib.utils.from_config._NotProvided object>, *, eager_tracing: bool | None = <ray.rllib.utils.from_config._NotProvided object>, eager_max_retraces: int | None = <ray.rllib.utils.from_config._NotProvided object>, tf_session_args: ~typing.Dict[str, ~typing.Any] | None = <ray.rllib.utils.from_config._NotProvided object>, local_tf_session_args: ~typing.Dict[str, ~typing.Any] | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_learner: bool | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_learner_what_to_compile: str | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_learner_dynamo_mode: str | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_learner_dynamo_backend: str | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_worker: bool | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_worker_dynamo_backend: str | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_worker_dynamo_mode: str | None = <ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[source]

Sets the config’s DL framework settings.

Parameters:
  • framework – torch: PyTorch; tf2: TensorFlow 2.x (eager execution or traced if eager_tracing=True); tf: TensorFlow (static-graph);

  • eager_tracing – Enable tracing in eager mode. This greatly improves performance (speedup ~2x), but makes it slightly harder to debug since Python code won’t be evaluated after the initial eager pass. Only possible if framework=tf2.

  • eager_max_retraces – Maximum number of tf.function re-traces before a runtime error is raised. This is to prevent unnoticed retraces of methods inside the _eager_traced Policy, which could slow down execution by a factor of 4, without the user noticing what the root cause for this slowdown could be. Only necessary for framework=tf2. Set to None to ignore the re-trace count and never throw an error.

  • tf_session_args – Configures TF for single-process operation by default.

  • local_tf_session_args – Override the following tf session args on the local worker

  • torch_compile_learner – If True, forward_train methods on TorchRLModules on the learner are compiled. If not specified, the default is to compile forward train on the learner.

  • torch_compile_learner_what_to_compile – A TorchCompileWhatToCompile mode specifying what to compile on the learner side if torch_compile_learner is True. See TorchCompileWhatToCompile for details and advice on its usage.

  • torch_compile_learner_dynamo_backend – The torch dynamo backend to use on the learner.

  • torch_compile_learner_dynamo_mode – The torch dynamo mode to use on the learner.

  • torch_compile_worker – If True, forward exploration and inference methods on TorchRLModules on the workers are compiled. If not specified, the default is to not compile forward methods on the workers because retracing can be expensive.

  • torch_compile_worker_dynamo_backend – The torch dynamo backend to use on the workers.

  • torch_compile_worker_dynamo_mode – The torch dynamo mode to use on the workers.

Returns:

This updated AlgorithmConfig object.
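
A short sketch of the two most common framework choices:

from ray.rllib.algorithms.ppo import PPOConfig

# PyTorch:
config = PPOConfig().framework(framework="torch")

# TensorFlow 2.x with eager tracing for better performance:
config = PPOConfig().framework(framework="tf2", eager_tracing=True)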

Specifying Rollout Workers#

AlgorithmConfig.rollouts(*, env_runner_cls: type | None = <ray.rllib.utils.from_config._NotProvided object>, num_rollout_workers: int | None = <ray.rllib.utils.from_config._NotProvided object>, num_envs_per_worker: int | None = <ray.rllib.utils.from_config._NotProvided object>, create_env_on_local_worker: bool | None = <ray.rllib.utils.from_config._NotProvided object>, sample_collector: ~typing.Type[~ray.rllib.evaluation.collectors.sample_collector.SampleCollector] | None = <ray.rllib.utils.from_config._NotProvided object>, enable_connectors: bool | None = <ray.rllib.utils.from_config._NotProvided object>, use_worker_filter_stats: bool | None = <ray.rllib.utils.from_config._NotProvided object>, update_worker_filter_stats: bool | None = <ray.rllib.utils.from_config._NotProvided object>, rollout_fragment_length: int | str | None = <ray.rllib.utils.from_config._NotProvided object>, batch_mode: str | None = <ray.rllib.utils.from_config._NotProvided object>, remote_worker_envs: bool | None = <ray.rllib.utils.from_config._NotProvided object>, remote_env_batch_wait_ms: float | None = <ray.rllib.utils.from_config._NotProvided object>, validate_workers_after_construction: bool | None = <ray.rllib.utils.from_config._NotProvided object>, preprocessor_pref: str | None = <ray.rllib.utils.from_config._NotProvided object>, observation_filter: str | None = <ray.rllib.utils.from_config._NotProvided object>, compress_observations: bool | None = <ray.rllib.utils.from_config._NotProvided object>, enable_tf1_exec_eagerly: bool | None = <ray.rllib.utils.from_config._NotProvided object>, sampler_perf_stats_ema_coef: float | None = <ray.rllib.utils.from_config._NotProvided object>, ignore_worker_failures=-1, recreate_failed_workers=-1, restart_failed_sub_environments=-1, num_consecutive_worker_failures_tolerance=-1, worker_health_probe_timeout_s=-1, worker_restore_timeout_s=-1, synchronize_filter=-1, sample_async=-1) AlgorithmConfig[source]

Sets the rollout worker configuration.

Parameters:
  • env_runner_cls – The EnvRunner class to use for environment rollouts (data collection).

  • num_rollout_workers – Number of rollout worker actors to create for parallel sampling. Setting this to 0 will force rollouts to be done in the local worker (driver process or the Algorithm’s actor when using Tune).

  • num_envs_per_worker – Number of environments to evaluate vector-wise per worker. This enables model inference batching, which can improve performance for inference bottlenecked workloads.

  • sample_collector – The SampleCollector class to be used to collect and retrieve environment-, model-, and sampler data. Override the SampleCollector base class to implement your own collection/buffering/retrieval logic.

  • create_env_on_local_worker – When num_rollout_workers > 0, the driver (local_worker; worker-idx=0) does not need an environment. This is because it doesn’t have to sample (done by remote_workers; worker_indices > 0) nor evaluate (done by evaluation workers; see below).

  • enable_connectors – Use connector based environment runner, so that all preprocessing of obs and postprocessing of actions are done in agent and action connectors.

  • use_worker_filter_stats – Whether to use the workers in the WorkerSet to update the central filters (held by the local worker). If False, stats from the workers will not be used and discarded.

  • update_worker_filter_stats – Whether to push filter updates from the central filters (held by the local worker) to the remote workers’ filters. Setting this to True might be useful within the evaluation config in order to disable the usage of evaluation trajectories for synching the central filter (used for training).

  • rollout_fragment_length – Divide episodes into fragments of this many steps each during rollouts. Trajectories of this size are collected from rollout workers and combined into a larger batch of train_batch_size for learning. For example, given rollout_fragment_length=100 and train_batch_size=1000: 1. RLlib collects 10 fragments of 100 steps each from rollout workers. 2. These fragments are concatenated and we perform an epoch of SGD. When using multiple envs per worker, the fragment size is multiplied by num_envs_per_worker. This is since we are collecting steps from multiple envs in parallel. For example, if num_envs_per_worker=5, then rollout workers will return experiences in chunks of 5*100 = 500 steps. The dataflow here can vary per algorithm. For example, PPO further divides the train batch into minibatches for multi-epoch SGD. Set to “auto” to have RLlib compute an exact rollout_fragment_length to match the given batch size.

  • batch_mode – How to build individual batches with the EnvRunner(s). Batches coming from distributed EnvRunners are usually concat’d to form the train batch. Note that “steps” below can mean different things (either env- or agent-steps) and depends on the count_steps_by setting, adjustable via AlgorithmConfig.multi_agent(count_steps_by=..): 1) “truncate_episodes”: Each call to EnvRunner.sample() will return a batch of at most rollout_fragment_length * num_envs_per_worker in size. The batch will be exactly rollout_fragment_length * num_envs in size if postprocessing does not change batch sizes. Episodes may be truncated in order to meet this size requirement. This mode guarantees evenly sized batches, but increases variance as the future return must now be estimated at truncation boundaries. 2) “complete_episodes”: Each call to EnvRunner.sample() will return a batch of at least rollout_fragment_length * num_envs_per_worker in size. Episodes will not be truncated, but multiple episodes may be packed within one batch to meet the (minimum) batch size. Note that when num_envs_per_worker > 1, episode steps will be buffered until the episode completes, and hence batches may contain significant amounts of off-policy data.

  • remote_worker_envs – If using num_envs_per_worker > 1, whether to create those new envs in remote processes instead of in the same worker. This adds overheads, but can make sense if your envs can take much time to step / reset (e.g., for StarCraft). Use this cautiously; overheads are significant.

  • remote_env_batch_wait_ms – Timeout that remote workers are waiting when polling environments. 0 (continue when at least one env is ready) is a reasonable default, but optimal value could be obtained by measuring your environment step / reset and model inference perf.

  • validate_workers_after_construction – Whether to validate that each created remote worker is healthy after its construction process.

  • preprocessor_pref – Whether to use “rllib” or “deepmind” preprocessors by default. Set to None for using no preprocessor. In this case, the model will have to handle possibly complex observations from the environment.

  • observation_filter – Element-wise observation filter, either “NoFilter” or “MeanStdFilter”.

  • compress_observations – Whether to LZ4 compress individual observations in the SampleBatches collected during rollouts.

  • enable_tf1_exec_eagerly – Explicitly tells the rollout worker to enable TF eager execution. This is useful for example when framework is “torch”, but a TF2 policy needs to be restored for evaluation or league-based purposes.

  • sampler_perf_stats_ema_coef – If specified, perf stats are in EMAs. This is the coeff of how much new data points contribute to the averages. Default is None, which uses simple global average instead. The EMA update rule is: updated = (1 - ema_coef) * old + ema_coef * new

Returns:

This updated AlgorithmConfig object.
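
For example, a small rollout setup might be sketched as follows (values are illustrative):

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().rollouts(
    num_rollout_workers=2,           # two remote sampling actors
    num_envs_per_worker=1,
    rollout_fragment_length="auto",  # let RLlib match the train batch size exactly
    batch_mode="truncate_episodes",
)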

Specifying Evaluation Options#

AlgorithmConfig.evaluation(*, evaluation_interval: int | None = <ray.rllib.utils.from_config._NotProvided object>, evaluation_duration: int | str | None = <ray.rllib.utils.from_config._NotProvided object>, evaluation_duration_unit: str | None = <ray.rllib.utils.from_config._NotProvided object>, evaluation_sample_timeout_s: float | None = <ray.rllib.utils.from_config._NotProvided object>, evaluation_parallel_to_training: bool | None = <ray.rllib.utils.from_config._NotProvided object>, evaluation_config: ~ray.rllib.algorithms.algorithm_config.AlgorithmConfig | dict | None = <ray.rllib.utils.from_config._NotProvided object>, off_policy_estimation_methods: ~typing.Dict | None = <ray.rllib.utils.from_config._NotProvided object>, ope_split_batch_by_episode: bool | None = <ray.rllib.utils.from_config._NotProvided object>, evaluation_num_workers: int | None = <ray.rllib.utils.from_config._NotProvided object>, custom_evaluation_function: ~typing.Callable | None = <ray.rllib.utils.from_config._NotProvided object>, always_attach_evaluation_results: bool | None = <ray.rllib.utils.from_config._NotProvided object>, enable_async_evaluation: bool | None = <ray.rllib.utils.from_config._NotProvided object>, evaluation_num_episodes=-1) AlgorithmConfig[source]

Sets the config’s evaluation settings.

Parameters:
  • evaluation_interval – Evaluate with every evaluation_interval training iterations. The evaluation stats will be reported under the “evaluation” metric key. Note that for Ape-X metrics are already only reported for the lowest epsilon workers (least random workers). Set to None (or 0) for no evaluation.

  • evaluation_duration – Duration for which to run evaluation each evaluation_interval. The unit for the duration can be set via evaluation_duration_unit to either “episodes” (default) or “timesteps”. If using multiple evaluation workers (evaluation_num_workers > 1), the load to run will be split amongst these. If the value is “auto”: - For evaluation_parallel_to_training=True: Will run as many episodes/timesteps that fit into the (parallel) training step. - For evaluation_parallel_to_training=False: Error.

  • evaluation_duration_unit – The unit, with which to count the evaluation duration. Either “episodes” (default) or “timesteps”.

  • evaluation_sample_timeout_s – The timeout (in seconds) for the ray.get call to the remote evaluation worker(s) sample() method. After this time, the user will receive a warning and instructions on how to fix the issue. This could be either to make sure the episode ends, increasing the timeout, or switching to evaluation_duration_unit=timesteps.

  • evaluation_parallel_to_training – Whether to run evaluation in parallel to a Algorithm.train() call using threading. Default=False. E.g. evaluation_interval=2 -> For every other training iteration, the Algorithm.train() and Algorithm.evaluate() calls run in parallel. Note: This is experimental. Possible pitfalls could be race conditions for weight synching at the beginning of the evaluation loop.

  • evaluation_config – Typical usage is to pass extra args to evaluation env creator and to disable exploration by computing deterministic actions. IMPORTANT NOTE: Policy gradient algorithms are able to find the optimal policy, even if this is a stochastic one. Setting “explore=False” here will result in the evaluation workers not using this optimal policy!

  • off_policy_estimation_methods – Specify how to evaluate the current policy, along with any optional config parameters. This only has an effect when reading offline experiences (“input” is not “sampler”). Available keys: {ope_method_name: {“type”: ope_type, …}} where ope_method_name is a user-defined string to save the OPE results under, and ope_type can be any subclass of OffPolicyEstimator, e.g. ray.rllib.offline.estimators.is::ImportanceSampling or your own custom subclass, or the full class path to the subclass. You can also add additional config arguments to be passed to the OffPolicyEstimator in the dict, e.g. {“qreg_dr”: {“type”: DoublyRobust, “q_model_type”: “qreg”, “k”: 5}}

  • ope_split_batch_by_episode – Whether to use SampleBatch.split_by_episode() to split the input batch to episodes before estimating the ope metrics. In case of bandits you should make this False to see improvements in ope evaluation speed. In case of bandits, it is ok to not split by episode, since each record is one timestep already. The default is True.

  • evaluation_num_workers – Number of parallel workers to use for evaluation. Note that this is set to zero by default, which means evaluation will be run in the algorithm process (only if evaluation_interval is not None). If you increase this, it will increase the Ray resource usage of the algorithm since evaluation workers are created separately from rollout workers (used to sample data for training).

  • custom_evaluation_function – Customize the evaluation method. This must be a function of signature (algo: Algorithm, eval_workers: WorkerSet) -> metrics: dict. See the Algorithm.evaluate() method to see the default implementation. The Algorithm guarantees all eval workers have the latest policy state before this function is called.

  • always_attach_evaluation_results – Make sure the latest available evaluation results are always attached to a step result dict. This may be useful if Tune or some other meta controller needs access to evaluation metrics all the time.

  • enable_async_evaluation – If True, use an AsyncRequestsManager for the evaluation workers and use this manager to send sample() requests to the evaluation workers. This way, the Algorithm becomes more robust against long running episodes and/or failing (and restarting) workers.

Returns:

This updated AlgorithmConfig object.
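
A sketch of a typical evaluation setup (note the caveat above about explore=False and stochastic optimal policies):

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().evaluation(
    evaluation_interval=2,                 # evaluate every other training iteration
    evaluation_duration=10,
    evaluation_duration_unit="episodes",
    evaluation_num_workers=1,
    evaluation_config={"explore": False},  # deterministic actions during evaluation
)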

Specifying Exploration Options#

AlgorithmConfig.exploration(*, explore: bool | None = <ray.rllib.utils.from_config._NotProvided object>, exploration_config: dict | None = <ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[source]

Sets the config’s exploration settings.

Parameters:
  • explore – Default exploration behavior, iff explore=None is passed into compute_action(s). Set to False for no exploration behavior (e.g., for evaluation).

  • exploration_config – A dict specifying the Exploration object’s config.

Returns:

This updated AlgorithmConfig object.
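
For instance, here is a sketch of configuring epsilon-greedy exploration for DQN (the schedule values are illustrative):

from ray.rllib.algorithms.dqn import DQNConfig

config = DQNConfig().exploration(
    explore=True,
    exploration_config={
        "type": "EpsilonGreedy",
        "initial_epsilon": 1.0,
        "final_epsilon": 0.02,
        "epsilon_timesteps": 10000,  # timesteps over which epsilon is annealed
    },
)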

Specifying Offline Data Options#

AlgorithmConfig.offline_data(*, input_=<ray.rllib.utils.from_config._NotProvided object>, input_config=<ray.rllib.utils.from_config._NotProvided object>, actions_in_input_normalized=<ray.rllib.utils.from_config._NotProvided object>, input_evaluation=<ray.rllib.utils.from_config._NotProvided object>, postprocess_inputs=<ray.rllib.utils.from_config._NotProvided object>, shuffle_buffer_size=<ray.rllib.utils.from_config._NotProvided object>, output=<ray.rllib.utils.from_config._NotProvided object>, output_config=<ray.rllib.utils.from_config._NotProvided object>, output_compress_columns=<ray.rllib.utils.from_config._NotProvided object>, output_max_file_size=<ray.rllib.utils.from_config._NotProvided object>, offline_sampling=<ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[source]

Sets the config’s offline data settings.

Parameters:
  • input – Specify how to generate experiences: - “sampler”: Generate experiences via online (env) simulation (default). - A local directory or file glob expression (e.g., “/tmp/*.json”). - A list of individual file paths/URIs (e.g., [“/tmp/1.json”, “s3://bucket/2.json”]). - A dict with string keys and sampling probabilities as values (e.g., {“sampler”: 0.4, “/tmp/*.json”: 0.4, “s3://bucket/expert.json”: 0.2}). - A callable that takes an IOContext object as only arg and returns a ray.rllib.offline.InputReader. - A string key that indexes a callable with tune.registry.register_input

  • input_config – Arguments that describe the settings for reading the input. If input is “sampler”, this will be environment configuration, e.g. env_name and env_config, etc. See EnvContext for more info. If the input is “dataset”, this will be e.g. format, path.

  • actions_in_input_normalized – True, if the actions in a given offline “input” are already normalized (between -1.0 and 1.0). This is usually the case when the offline file has been generated by another RLlib algorithm (e.g. PPO or SAC), while “normalize_actions” was set to True.

  • postprocess_inputs – Whether to run postprocess_trajectory() on the trajectory fragments from offline inputs. Note that postprocessing will be done using the current policy, not the behavior policy, which is typically undesirable for on-policy algorithms.

  • shuffle_buffer_size – If positive, input batches will be shuffled via a sliding window buffer of this number of batches. Use this if the input data is not in random enough order. Input is delayed until the shuffle buffer is filled.

  • output – Specify where experiences should be saved: - None: don’t save any experiences - “logdir” to save to the agent log dir - a path/URI to save to a custom output directory (e.g., “s3://bckt/”) - a function that returns a rllib.offline.OutputWriter

  • output_config – Arguments accessible from the IOContext for configuring custom output.

  • output_compress_columns – What sample batch columns to LZ4 compress in the output data.

  • output_max_file_size – Max output file size (in bytes) before rolling over to a new file.

  • offline_sampling – Whether sampling for the Algorithm happens via reading from offline data. If True, EnvRunners will NOT limit the number of collected batches within the same sample() call based on the number of sub-environments within the worker (no sub-environments present).

Returns:

This updated AlgorithmConfig object.
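
Two short sketches: training from previously recorded experiences, and recording new experiences while sampling online (the input path is a placeholder):

from ray.rllib.algorithms.dqn import DQNConfig

# Train from previously recorded JSON experiences instead of a live environment:
config = DQNConfig().offline_data(input_="/tmp/my-offline-data")

# Or keep sampling online, but additionally write experiences to the agent log dir:
config = DQNConfig().offline_data(input_="sampler", output="logdir")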

Specifying Multi-Agent Options#

AlgorithmConfig.multi_agent(*, policies=<ray.rllib.utils.from_config._NotProvided object>, algorithm_config_overrides_per_module: ~typing.Dict[str, dict] | None = <ray.rllib.utils.from_config._NotProvided object>, policy_map_capacity: int | None = <ray.rllib.utils.from_config._NotProvided object>, policy_mapping_fn: ~typing.Callable[[~typing.Any, OldEpisode], str] | None = <ray.rllib.utils.from_config._NotProvided object>, policies_to_train: ~typing.Container[str] | ~typing.Callable[[str, SampleBatch | MultiAgentBatch], bool] | None = <ray.rllib.utils.from_config._NotProvided object>, policy_states_are_swappable: bool | None = <ray.rllib.utils.from_config._NotProvided object>, observation_fn: ~typing.Callable | None = <ray.rllib.utils.from_config._NotProvided object>, count_steps_by: str | None = <ray.rllib.utils.from_config._NotProvided object>, replay_mode=-1, policy_map_cache=-1) AlgorithmConfig[source]

Sets the config’s multi-agent settings.

Validates the new multi-agent settings and translates everything into a unified multi-agent setup format. For example a policies list or set of IDs is properly converted into a dict mapping these IDs to PolicySpecs.

Parameters:
  • policies – Map of type MultiAgentPolicyConfigDict from policy ids to either 4-tuples of (policy_cls, obs_space, act_space, config) or PolicySpecs. These tuples or PolicySpecs define the class of the policy, the observation- and action spaces of the policies, and any extra config.

  • algorithm_config_overrides_per_module – Only used if _enable_new_api_stack=True. A mapping from ModuleIDs to per-module AlgorithmConfig override dicts, which apply certain settings, e.g. the learning rate, from the main AlgorithmConfig only to this particular module (within a MultiAgentRLModule). You can create override dicts by using the AlgorithmConfig.overrides utility. For example, to override your learning rate and (PPO) lambda setting just for a single RLModule with your MultiAgentRLModule, do: config.multi_agent(algorithm_config_overrides_per_module={ “module_1”: PPOConfig.overrides(lr=0.0002, lambda_=0.75), })

  • policy_map_capacity – Keep this many policies in the “policy_map” (before writing least-recently used ones to disk/S3).

  • policy_mapping_fn – Function mapping agent ids to policy ids. The signature is: (agent_id, episode, worker, **kwargs) -> PolicyID.

  • policies_to_train – Determines those policies that should be updated. Options are: - None, for training all policies. - An iterable of PolicyIDs that should be trained. - A callable, taking a PolicyID and a SampleBatch or MultiAgentBatch and returning a bool (indicating whether the given policy is trainable or not, given the particular batch). This allows you to have a policy trained only on certain data (e.g. when playing against a certain opponent).

  • policy_states_are_swappable – Whether all Policy objects in this map can be “swapped out” via a simple state = A.get_state(); B.set_state(state), where A and B are policy instances in this map. You should set this to True for significantly speeding up the PolicyMap’s cache lookup times, iff your policies all share the same neural network architecture and optimizer types. If True, the PolicyMap will not have to garbage collect old, least recently used policies, but instead keep them in memory and simply override their state with the state of the most recently accessed one. For example, in a league-based training setup, you might have 100s of the same policies in your map (playing against each other in various combinations), but all of them share the same state structure (are “swappable”).

  • observation_fn – Optional function that can be used to enhance the local agent observations to include more state. See rllib/evaluation/observation_function.py for more info.

  • count_steps_by – Which metric to use as the “batch size” when building a MultiAgentBatch. The two supported values are: “env_steps”: Count each time the env is “stepped” (no matter how many multi-agent actions are passed/how many multi-agent observations have been returned in the previous step). “agent_steps”: Count each individual agent step as one step.

Returns:

This updated AlgorithmConfig object.
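
A minimal sketch of a two-policy setup (the agent and policy IDs are placeholders for whatever your multi-agent env produces):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec

config = PPOConfig().multi_agent(
    policies={
        "policy_a": PolicySpec(),  # spaces and configs are inferred from the env if omitted
        "policy_b": PolicySpec(),
    },
    policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: (
        "policy_a" if agent_id == "agent_0" else "policy_b"
    ),
    policies_to_train=["policy_a"],  # only update policy_a; policy_b stays frozen
)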

Specifying Reporting Options#

AlgorithmConfig.reporting(*, keep_per_episode_custom_metrics: bool | None = <ray.rllib.utils.from_config._NotProvided object>, metrics_episode_collection_timeout_s: float | None = <ray.rllib.utils.from_config._NotProvided object>, metrics_num_episodes_for_smoothing: int | None = <ray.rllib.utils.from_config._NotProvided object>, min_time_s_per_iteration: int | None = <ray.rllib.utils.from_config._NotProvided object>, min_train_timesteps_per_iteration: int | None = <ray.rllib.utils.from_config._NotProvided object>, min_sample_timesteps_per_iteration: int | None = <ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[source]

Sets the config’s reporting settings.

Parameters:
  • keep_per_episode_custom_metrics – Store raw custom metrics without calculating max, min, mean

  • metrics_episode_collection_timeout_s – Wait for metric batches for at most this many seconds. Those that have not returned in time will be collected in the next train iteration.

  • metrics_num_episodes_for_smoothing – Smooth rollout metrics over this many episodes, if possible. In case rollouts (sample collection) just started, there may be fewer than this many episodes in the buffer and we’ll compute metrics over this smaller number of available episodes. In case there are more than this many episodes collected in a single training iteration, use all of these episodes for metrics computation, meaning don’t ever cut any “excess” episodes. Set this to 1 to disable smoothing and to always report only the most recently collected episode’s return.

  • min_time_s_per_iteration – Minimum time to accumulate within a single train() call. This value does not affect learning, only the number of times Algorithm.training_step() is called by Algorithm.train(). If - after one such step attempt, the time taken has not reached min_time_s_per_iteration, will perform n more training_step() calls until the minimum time has been consumed. Set to 0 or None for no minimum time.

  • min_train_timesteps_per_iteration – Minimum training timesteps to accumulate within a single train() call. This value does not affect learning, only the number of times Algorithm.training_step() is called by Algorithm.train(). If - after one such step attempt, the training timestep count has not been reached, will perform n more training_step() calls until the minimum timesteps have been executed. Set to 0 or None for no minimum timesteps.

  • min_sample_timesteps_per_iteration – Minimum env sampling timesteps to accumulate within a single train() call. This value does not affect learning, only the number of times Algorithm.training_step() is called by Algorithm.train(). If - after one such step attempt, the env sampling timestep count has not been reached, will perform n more training_step() calls until the minimum timesteps have been executed. Set to 0 or None for no minimum timesteps.

Returns:

This updated AlgorithmConfig object.
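
For example, a sketch that smooths metrics over more episodes and enforces a minimum amount of sampling per iteration (values are illustrative):

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().reporting(
    metrics_num_episodes_for_smoothing=100,
    min_sample_timesteps_per_iteration=1000,
    min_time_s_per_iteration=None,  # no minimum wall-clock time per train() call
)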

Specifying Checkpointing Options#

AlgorithmConfig.checkpointing(export_native_model_files: bool | None = <ray.rllib.utils.from_config._NotProvided object>, checkpoint_trainable_policies_only: bool | None = <ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[source]

Sets the config’s checkpointing settings.

Parameters:
  • export_native_model_files – Whether an individual Policy- or the Algorithm’s checkpoints also contain (tf or torch) native model files. These could be used to restore just the NN models from these files w/o requiring RLlib. These files are generated by calling the tf- or torch- built-in saving utility methods on the actual models.

  • checkpoint_trainable_policies_only – Whether to only add Policies to the Algorithm checkpoint (in sub-directory “policies/”) that are trainable according to the is_trainable_policy callable of the local worker.

Returns:

This updated AlgorithmConfig object.
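
For instance, to also export the native (tf or torch) model files with each checkpoint, a sketch could look like this:

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().checkpointing(export_native_model_files=True)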

Specifying Debugging Options#

AlgorithmConfig.debugging(*, logger_creator: ~typing.Callable[[], ~ray.tune.logger.logger.Logger] | None = <ray.rllib.utils.from_config._NotProvided object>, logger_config: dict | None = <ray.rllib.utils.from_config._NotProvided object>, log_level: str | None = <ray.rllib.utils.from_config._NotProvided object>, log_sys_usage: bool | None = <ray.rllib.utils.from_config._NotProvided object>, fake_sampler: bool | None = <ray.rllib.utils.from_config._NotProvided object>, seed: int | None = <ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[source]

Sets the config’s debugging settings.

Parameters:
  • logger_creator – Callable that creates a ray.tune.Logger object. If unspecified, a default logger is created.

  • logger_config – Define logger-specific configuration to be used inside the Logger. The default value None allows overwriting with nested dicts.

  • log_level – Set the ray.rllib.* log level for the agent process and its workers. Should be one of DEBUG, INFO, WARN, or ERROR. The DEBUG level will also periodically print out summaries of relevant internal dataflow (this is also printed out once at startup at the INFO level). When using the rllib train command, you can also use the -v and -vv flags as shorthand for INFO and DEBUG.

  • log_sys_usage – Log system resource metrics to results. This requires psutil to be installed for sys stats, and gputil for GPU metrics.

  • fake_sampler – Use fake (infinite speed) sampler. For testing only.

  • seed – This argument, in conjunction with worker_index, sets the random seed of each worker, so that identically configured trials will have identical results. This makes experiments reproducible.

Returns:

This updated AlgorithmConfig object.
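
For instance, a sketch of a reproducible, more verbose setup:

from ray.rllib.algorithms.ppo import PPOConfig

# Fix the random seed for reproducibility, raise the log level to INFO,
# and log system resource metrics to the results.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .debugging(log_level="INFO", seed=42, log_sys_usage=True)
)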

Specifying Callback Options#

AlgorithmConfig.callbacks(callbacks_class) AlgorithmConfig[source]

Sets the callbacks configuration.

Parameters:

callbacks_class – Callbacks class, whose methods will be run during various phases of training and environment sample collection. See the DefaultCallbacks class and examples/custom_metrics_and_callbacks.py for more usage information.

Returns:

This updated AlgorithmConfig object.
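
A minimal sketch of a custom callback (the metric name is made up for illustration):

from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.algorithms.ppo import PPOConfig


class MyCallbacks(DefaultCallbacks):
    def on_episode_end(self, *, episode, **kwargs):
        # Record an illustrative custom metric at the end of each episode.
        episode.custom_metrics["episode_len"] = episode.length


config = (
    PPOConfig()
    .environment("CartPole-v1")
    .callbacks(MyCallbacks)
)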

Specifying Resources#

You can control the degree of parallelism used by setting the num_workers hyperparameter for most algorithms. The Algorithm will construct that many “remote worker” instances (see RolloutWorker class), which are ray.remote actors, plus exactly one “local worker”, a RolloutWorker object that is not a ray actor but lives directly inside the Algorithm. For most algorithms, learning updates are performed on the local worker and sample collection from one or more environments is performed by the remote workers (in parallel). For example, setting num_workers=0 will only create the local worker, in which case both sample collection and training will be done by the local worker. On the other hand, setting num_workers=5 will create the local worker (responsible for training updates) and 5 remote workers (responsible for sample collection).

Since learning is usually done on the local worker, it may help to provide one or more GPUs to that worker via the num_gpus setting. Similarly, the resource allocation to remote workers can be controlled via num_cpus_per_worker, num_gpus_per_worker, and custom_resources_per_worker.

The number of GPUs can be fractional (e.g., 0.5) to allocate only a fraction of a GPU. For example, with DQN you can pack five algorithms onto one GPU by setting num_gpus: 0.2. Also check out this fractional GPU example here, which demonstrates how environments (running on the remote workers) that require a GPU can benefit from the num_gpus_per_worker setting.
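
For instance, the following sketch (the trial-packing numbers are illustrative) uses Ray Tune to run five DQN trials that share a single GPU:

from ray import tune
from ray.rllib.algorithms.dqn import DQNConfig

# Each trial claims only 20% of a GPU, so five trials can be packed onto one GPU.
config = (
    DQNConfig()
    .environment("CartPole-v1")
    .resources(num_gpus=0.2)
)

tuner = tune.Tuner(
    "DQN",
    param_space=config.to_dict(),
    tune_config=tune.TuneConfig(num_samples=5),
)
tuner.fit()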

For synchronous algorithms like PPO and A2C, the driver and workers can make use of the same GPU. To do this for n GPUs:

gpu_count = n  # total number of GPUs available
num_gpus = 0.0001  # tiny fraction reserved for the driver (local worker)
num_gpus_per_worker = (gpu_count - num_gpus) / num_workers  # share the rest among the workers
../_images/rllib-config.svg

If you specify num_gpus and your machine does not have the required number of GPUs available, a RuntimeError will be thrown by the respective worker. On the other hand, if you set num_gpus=0, your policies will be built solely on the CPU, even if GPUs are available on the machine.

AlgorithmConfig.resources(*, num_gpus: int | float | None = <ray.rllib.utils.from_config._NotProvided object>, _fake_gpus: bool | None = <ray.rllib.utils.from_config._NotProvided object>, num_cpus_per_worker: int | float | None = <ray.rllib.utils.from_config._NotProvided object>, num_gpus_per_worker: int | float | None = <ray.rllib.utils.from_config._NotProvided object>, num_cpus_for_local_worker: int | None = <ray.rllib.utils.from_config._NotProvided object>, num_learner_workers: int | None = <ray.rllib.utils.from_config._NotProvided object>, num_cpus_per_learner_worker: int | float | None = <ray.rllib.utils.from_config._NotProvided object>, num_gpus_per_learner_worker: int | float | None = <ray.rllib.utils.from_config._NotProvided object>, local_gpu_idx: int | None = <ray.rllib.utils.from_config._NotProvided object>, custom_resources_per_worker: dict | None = <ray.rllib.utils.from_config._NotProvided object>, placement_strategy: str | None = <ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[source]

Specifies resources allocated for an Algorithm and its ray actors/workers.

Parameters:
  • num_gpus – Number of GPUs to allocate to the algorithm process. Note that not all algorithms can take advantage of GPUs. Support for multi-GPU is currently only available for tf-[PPO/IMPALA/DQN/PG]. This can be fractional (e.g., 0.3 GPUs).

  • _fake_gpus – Set to True for debugging (multi-)GPU functionality on a CPU machine. GPU towers will be simulated by graphs located on CPUs in this case. Use num_gpus to test for different numbers of fake GPUs.

  • num_cpus_per_worker – Number of CPUs to allocate per worker.

  • num_gpus_per_worker – Number of GPUs to allocate per worker. This can be fractional. This is usually needed only if your env itself requires a GPU (i.e., it is a GPU-intensive video game), or model inference is unusually expensive.

  • num_learner_workers – Number of workers used for training. A value of 0 means training takes place on a local worker on head-node CPUs or 1 GPU (determined by num_gpus_per_learner_worker). For multi-GPU training, set the number of learner workers greater than 1 and set num_gpus_per_learner_worker accordingly (e.g., 4 GPUs total and the model needs 2 GPUs: num_learner_workers = 2 and num_gpus_per_learner_worker = 2).

  • num_cpus_per_learner_worker – Number of CPUs allocated per Learner worker. Only necessary for custom processing pipeline inside each Learner requiring multiple CPU cores. Ignored if num_learner_workers = 0.

  • num_gpus_per_learner_worker – Number of GPUs allocated per Learner worker. If num_learner_workers = 0, any value greater than 0 will run the training on a single GPU on the head node, while a value of 0 will run the training on head node CPU cores. If num_gpus_per_learner_worker is set, then num_cpus_per_learner_worker cannot be set.

  • local_gpu_idx – If num_gpus_per_worker > 0 and num_workers < 2, then this GPU index will be used for training. This is an index into the available CUDA devices. For example, if os.environ["CUDA_VISIBLE_DEVICES"] = "1", then a local_gpu_idx of 0 will use the GPU with ID 1 on the node.

  • custom_resources_per_worker – Any custom Ray resources to allocate per worker.

  • num_cpus_for_local_worker – Number of CPUs to allocate for the algorithm. Note: this only takes effect when running in Tune. Otherwise, the algorithm runs in the main program (driver).

  • placement_strategy – The strategy for the placement group factory returned by Algorithm.default_resource_request(). A PlacementGroup defines which devices (resources) should always be co-located on the same node. For example, an Algorithm with 2 rollout workers, running with num_gpus=1, will request a placement group with the bundles [{"gpu": 1, "cpu": 1}, {"cpu": 1}, {"cpu": 1}], where the first bundle is for the driver and the other 2 bundles are for the two workers. These bundles can then be "placed" on the same or different nodes depending on the value of placement_strategy: "PACK" packs bundles into as few nodes as possible; "SPREAD" places bundles across distinct nodes as evenly as possible; "STRICT_PACK" packs bundles into one node (the group is not allowed to span multiple nodes); "STRICT_SPREAD" packs bundles across distinct nodes.

Returns:

This updated AlgorithmConfig object.
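
For example, a sketch that gives the local (learning) worker one GPU and each rollout worker one CPU:

from ray.rllib.algorithms.ppo import PPOConfig

# One GPU for the learning process, one CPU per rollout worker.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .resources(
        num_gpus=1,
        num_cpus_per_worker=1,
    )
)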

Specifying Experimental Features#

AlgorithmConfig.experimental(*, _enable_new_api_stack: bool | None = <ray.rllib.utils.from_config._NotProvided object>, _tf_policy_handles_more_than_one_loss: bool | None = <ray.rllib.utils.from_config._NotProvided object>, _disable_preprocessor_api: bool | None = <ray.rllib.utils.from_config._NotProvided object>, _disable_action_flattening: bool | None = <ray.rllib.utils.from_config._NotProvided object>, _disable_execution_plan_api: bool | None = <ray.rllib.utils.from_config._NotProvided object>, _disable_initialize_loss_from_dummy_batch: bool | None = <ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[source]

Sets the config’s experimental settings.

Parameters:
  • _enable_new_api_stack – Enables the new API stack, which will use RLModule (instead of ModelV2) as well as the multi-GPU capable Learner API (instead of using Policy to compute loss and update the model).

  • _tf_policy_handles_more_than_one_loss – Experimental flag. If True, TFPolicy will handle more than one loss/optimizer. Set this to True, if you would like to return more than one loss term from your loss_fn and an equal number of optimizers from your optimizer_fn. In the future, the default for this will be True.

  • _disable_preprocessor_api – Experimental flag. If True, no (observation) preprocessor will be created and observations will arrive at the model exactly as returned by the env. In the future, the default for this will be True.

  • _disable_action_flattening – Experimental flag. If True, RLlib will no longer flatten the policy-computed actions into a single tensor (for storage in SampleCollectors, output files, etc.), but leave (possibly nested) actions as-is. Disabling flattening affects: SampleCollectors (which have to store possibly nested action structs), models that take the previous action(s) as part of their input, and algorithms reading from offline files (incl. action information).

  • _disable_execution_plan_api – Experimental flag. If True, the execution plan API will not be used. Instead, an Algorithm’s training_iteration method will be called as-is each training iteration.

Returns:

This updated AlgorithmConfig object.
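
As a sketch, opting into the new RLModule/Learner-based stack:

from ray.rllib.algorithms.ppo import PPOConfig

# Use RLModule and the Learner API instead of ModelV2/Policy-based training.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .experimental(_enable_new_api_stack=True)
)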

RLlib Scaling Guide#

Here are some rules of thumb for scaling training with RLlib.

  1. If the environment is slow and cannot be replicated (e.g., since it requires interaction with physical systems), then you should use a sample-efficient off-policy algorithm such as DQN or SAC. These algorithms default to num_workers: 0 for single-process operation. Make sure to set num_gpus: 1 if you want to use a GPU. Consider also batch RL training with the offline data API.

  2. If the environment is fast and the model is small (most models for RL are), use time-efficient algorithms such as PPO or IMPALA. These can be scaled by increasing num_workers to add rollout workers. It may also make sense to enable vectorization for inference. Make sure to set num_gpus: 1 if you want to use a GPU. If the learner becomes a bottleneck, multiple GPUs can be used for learning by setting num_gpus > 1.

  3. If the model is compute intensive (e.g., a large deep residual network) and inference is the bottleneck, consider allocating GPUs to workers by setting num_gpus_per_worker: 1. If you only have a single GPU, consider num_workers: 0 to use the learner GPU for inference. For efficient use of GPU time, use a small number of GPU workers and a large number of envs per worker.

  4. Finally, if both model and environment are compute intensive, then enable remote worker envs with async batching by setting remote_worker_envs: True and optionally remote_env_batch_wait_ms. This batches inference on GPUs in the rollout workers while letting envs run asynchronously in separate actors, similar to the SEED architecture (see the sketch following this list). The number of workers and number of envs per worker should be tuned to maximize GPU utilization.
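
As a sketch of item 4 (parameter names here follow AlgorithmConfig.rollouts(), which is an assumption of this example; the list above uses the equivalent plain config keys):

from ray.rllib.algorithms.ppo import PPOConfig

# SEED-style setup: sub-environments run in their own actors, and inference
# on each rollout worker is batched across them (waiting up to 10 ms).
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(
        num_rollout_workers=2,
        num_envs_per_worker=8,
        remote_worker_envs=True,
        remote_env_batch_wait_ms=10,
    )
    .resources(num_gpus_per_worker=1)
)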

If you are using many workers (num_workers >> 10) and you observe worker failures for whatever reason, which would normally interrupt your RLlib training runs, consider using the config settings ignore_worker_failures=True, recreate_failed_workers=True, or restart_failed_sub_environments=True:

  • ignore_worker_failures: When set to True, your Algorithm will not crash due to a single worker error but continue for as long as there is at least one functional worker remaining.

  • recreate_failed_workers: When set to True, your Algorithm will attempt to replace/recreate any failed worker(s) with newly created one(s). This way, your number of workers will never decrease, even if some of them fail from time to time.

  • restart_failed_sub_environments: When set to True and there is a failure in one of the vectorized sub-environments in one of your workers, the worker will try to recreate only the failed sub-environment and re-integrate the newly created one into your vectorized env stack on that worker.

Note that only one of ignore_worker_failures or recreate_failed_workers may be set to True (they are mutually exclusive settings). However, you can combine each of these with the restart_failed_sub_environments=True setting. Using these options will make your training runs much more stable and more robust against occasional OOM or other similar “once in a while” errors on your workers themselves or inside your environments.
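
A minimal sketch, assuming these settings are exposed through AlgorithmConfig.fault_tolerance() in your Ray version (otherwise, pass them as plain config keys):

from ray.rllib.algorithms.ppo import PPOConfig

# Recreate crashed rollout workers and restart only failed sub-environments
# rather than aborting the whole training run.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .fault_tolerance(
        recreate_failed_workers=True,
        restart_failed_sub_environments=True,
    )
)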

Debugging RLlib Experiments#

Gym Monitor#

The "monitor": true config can be used to save Gym episode videos to the result dir. For example:

rllib train --env=PongDeterministic-v4 \
    --run=A2C --config '{"num_workers": 2, "monitor": true}'

# videos will be saved in the ~/ray_results/<experiment> dir, for example
openaigym.video.0.31401.video000000.meta.json
openaigym.video.0.31401.video000000.mp4
openaigym.video.0.31403.video000000.meta.json
openaigym.video.0.31403.video000000.mp4

Eager Mode#

Policies built with build_tf_policy (most of the reference algorithms are) can be run in eager mode by setting the "framework": "tf2" and "eager_tracing": true config options, or by using rllib train --config '{"framework": "tf2"}' [--trace]. This will tell RLlib to execute the model forward pass, action distribution, loss, and stats functions in eager mode.

Eager mode makes debugging much easier, since you can now use line-by-line debugging with breakpoints or Python print() to inspect intermediate tensor values. However, eager can be slower than graph mode unless tracing is enabled.
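
In the Python API, a sketch of the equivalent configuration:

from ray.rllib.algorithms.ppo import PPOConfig

# Run the TF2 policy eagerly; tracing recovers most of graph-mode speed.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .framework("tf2", eager_tracing=True)
)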

Using PyTorch#

Algorithms that have an implemented TorchPolicy will allow you to run rllib train with the command-line --framework=torch flag. Algorithms that do not yet have a torch version will complain with an error in this case.

Episode Traces#

You can use the data output API to save episode traces for debugging. For example, the following command will run PPO while saving episode traces to /tmp/debug.

rllib train --run=PPO --env=CartPole-v1 \
    --config='{"output": "/tmp/debug", "output_compress_columns": []}'

# episode traces will be saved in /tmp/debug, for example
output-2019-02-23_12-02-03_worker-2_0.json
output-2019-02-23_12-02-04_worker-1_0.json
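
The same output settings can also be sketched with the Python API (assuming the offline_data() method on AlgorithmConfig):

from ray.rllib.algorithms.ppo import PPOConfig

# Write uncompressed episode traces to /tmp/debug while training.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .offline_data(output="/tmp/debug", output_compress_columns=[])
)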

Log Verbosity#

You can control the log level via the "log_level" config key. Valid values are "DEBUG", "INFO", "WARN" (default), and "ERROR". This can be used to increase or decrease the verbosity of internal logging. You can also use the -v and -vv flags. For example, the following two commands are roughly equivalent:

rllib train --env=PongDeterministic-v4 \
    --run=A2C --config '{"num_workers": 2, "log_level": "DEBUG"}'

rllib train --env=PongDeterministic-v4 \
    --run=A2C --config '{"num_workers": 2}' -vv

The default log level is WARN. We strongly recommend using at least INFO level logging for development.

Stack Traces#

You can use the ray stack command to dump the stack traces of all the Python workers on a single node. This can be useful for debugging unexpected hangs or performance issues.

Next Steps#

  • To check how your application is doing, you can use the Ray dashboard.