Note

Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The Ray Team plans to transition algorithms, example scripts, and documentation to the new code base, thereby incrementally replacing the “old API stack” (e.g., ModelV2, Policy, RolloutWorker) throughout the subsequent minor releases leading up to Ray 3.0.

Note, however, that so far only PPO (single- and multi-agent) and SAC (single-agent only) support the “new API stack” and continue to run by default with the old APIs. You can continue to use the existing custom (old stack) classes.

See here for more details on how to use the new API stack.

Policy API#

The Policy class contains functionality to compute actions for decision making in an environment, to compute losses and gradients, to update a neural network model, and to postprocess a collected environment trajectory. One or more Policy objects sit inside a RolloutWorker’s PolicyMap and, if there is more than one, are selected based on a multi-agent policy_mapping_fn, which maps agent IDs to a policy ID.

Figure (policy_classes_overview.svg): RLlib’s Policy class hierarchy. Policies are deep-learning-framework specific, as they hold functionality to handle a computation graph (e.g., a TensorFlow 1.x graph in a session). You can define custom policy behavior by sub-classing any of the available, built-in classes, depending on your needs.
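
For illustration, below is a minimal sketch (old API stack, single-agent) of pulling an Algorithm’s default Policy out of the PolicyMap and querying it for an action; the PPO setup and the CartPole-v1 environment are assumptions made only for this example:

```python
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig

# Build a single-agent PPO Algorithm; its default Policy lives in each
# RolloutWorker's PolicyMap under the ID "default_policy".
algo = PPOConfig().environment("CartPole-v1").build()
policy = algo.get_policy()  # returns the "default_policy"

# Query the Policy directly for one action.
env = gym.make("CartPole-v1")
obs, info = env.reset()
action, state_outs, extra_fetches = policy.compute_single_action(obs)

# In multi-agent setups, a mapping function such as the (hypothetical) one
# below selects the policy ID to use for each agent ID:
#   config.multi_agent(
#       policies={"pol_0", "pol_1"},
#       policy_mapping_fn=lambda agent_id, episode, worker, **kw: "pol_0",
#   )
```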

Building Custom Policy Classes#

Warning

As of Ray >= 1.9, it is no longer recommended to use the build_policy_class() or build_tf_policy() utility functions for creating custom Policy sub-classes. Instead, follow the simple guidelines here for directly sub-classing from either one of the built-in types: EagerTFPolicyV2 or TorchPolicyV2.

In order to create a custom Policy, sub-class Policy (for a generic, framework-agnostic policy), TorchPolicyV2 (for a PyTorch-specific policy), or EagerTFPolicyV2 (for a TensorFlow-specific policy) and override one or more of their methods, in particular those listed in the sections below.

See here for an example of how to override TorchPolicy.
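
As a minimal sketch, a custom TorchPolicyV2 sub-class could override only the loss() method; the class name and the simple REINFORCE-style loss below are purely illustrative:

```python
import torch

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyTorchPolicy(TorchPolicyV2):
    """Hypothetical custom policy that only overrides the loss method."""

    def loss(self, model, dist_class, train_batch):
        # Forward pass through the Policy's model.
        logits, _ = model(train_batch)
        action_dist = dist_class(logits, model)
        # Simple REINFORCE-style loss: -logp(action) * reward.
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        return -torch.mean(log_probs * train_batch[SampleBatch.REWARDS])
```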

Base Policy classes#

Policy

RLlib's base class for all Policy implementations.

EagerTFPolicyV2

A TF-eager / TF2-based TensorFlow policy.

TorchPolicyV2

PyTorch specific Policy class to use with RLlib.

Making models#

Base Policy#

make_rl_module

Returns the RLModule (only when the RLModule API is enabled).

Torch Policy#

make_model

Create model.

make_model_and_action_dist

Create model and action distribution function.

Tensorflow Policy#

make_model

Build underlying model for this Policy.
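
As a sketch, a TorchPolicyV2 sub-class could override make_model() to construct its own ModelV2, for example through the ModelCatalog; the class name and the Discrete-action assumption below are illustrative only:

```python
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyModelPolicy(TorchPolicyV2):
    """Hypothetical policy that builds its model through the ModelCatalog."""

    def make_model(self):
        return ModelCatalog.get_model_v2(
            obs_space=self.observation_space,
            action_space=self.action_space,
            # Assumes a Discrete action space; a custom ModelV2 could be
            # returned here just as well.
            num_outputs=self.action_space.n,
            model_config=self.config["model"],
            framework="torch",
        )
```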

Inference#

Base Policy#

compute_actions

Computes actions for the current policy.

compute_actions_from_input_dict

Computes actions from collected samples (across multiple agents).

compute_single_action

Computes and returns a single (B=1) action value.
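
Continuing the sketch from above (a policy and env are assumed to exist already), batched inference could look roughly like this:

```python
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# compute_actions() takes a batch of observations ...
obs_batch = np.stack([env.observation_space.sample() for _ in range(4)])
actions, state_outs, extra_fetches = policy.compute_actions(obs_batch)

# ... while compute_actions_from_input_dict() consumes an input dict /
# SampleBatch (e.g., as produced by RLlib's sample collectors).
input_dict = SampleBatch({SampleBatch.OBS: obs_batch})
actions, state_outs, extra_fetches = policy.compute_actions_from_input_dict(input_dict)
```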

Torch Policy#

action_sampler_fn

Custom function for sampling new actions given policy.

action_distribution_fn

Action distribution function for this Policy.

extra_action_out

Returns dict of extra info to include in experience batch.

Tensorflow Policy#

action_sampler_fn

Custom function for sampling new actions given policy.

action_distribution_fn

Action distribution function for this Policy.

extra_action_out_fn

Extra values to fetch and return from compute_actions().

Computing, processing, and applying gradients#

Base Policy#

compute_gradients

Computes gradients given a batch of experiences.

apply_gradients

Applies the (previously) computed gradients.
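
A rough sketch of the gradient workflow, assuming train_batch is a postprocessed SampleBatch that contains all fields the Policy’s loss expects:

```python
# Compute gradients (and accompanying stats) for the loaded model ...
grads, grad_info = policy.compute_gradients(train_batch)

# ... then apply them locally. In a distributed setup, the gradients could
# instead be shipped to other workers and applied there.
policy.apply_gradients(grads)
```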

Torch Policy#

extra_compute_grad_fetches

Extra values to fetch and return from compute_gradients().

extra_grad_process

Called after each optimizer.zero_grad() + loss.backward() call.

Tensorflow Policy#

grad_stats_fn

Gradient stats function.

compute_gradients_fn

Gradients computing function (from loss tensor, using local optimizer).

apply_gradients_fn

Gradients applying function (applies the computed gradients using the local optimizer).

extra_learn_fetches_fn

Extra stats to be reported after gradient computation.

Updating the Policy’s model#

Base Policy#

learn_on_batch

Perform one learning update, given samples.

load_batch_into_buffer

Bulk-loads the given SampleBatch into the devices' memories.

learn_on_loaded_batch

Runs a single step of SGD on already loaded data in a buffer.

learn_on_batch_from_replay_buffer

Samples a batch from given replay actor and performs an update.

get_num_samples_loaded_into_buffer

Returns the number of currently loaded samples in the given buffer.
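
A sketch of both update paths, again assuming a postprocessed train_batch:

```python
# One-call update: computes the loss on the batch and applies gradients.
results = policy.learn_on_batch(train_batch)

# Buffer-based (e.g., multi-GPU) update: load the data onto the device(s)
# first, then run one or more SGD steps on the already loaded data.
policy.load_batch_into_buffer(train_batch, buffer_index=0)
num_loaded = policy.get_num_samples_loaded_into_buffer(buffer_index=0)
results = policy.learn_on_loaded_batch(offset=0, buffer_index=0)
```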

Loss, Logging, optimizers, and trajectory processing#

Base Policy#

loss

Loss function for this Policy.

compute_log_likelihoods

Computes the log-prob/likelihood for a given action and observation.

on_global_var_update

Called on an update to global vars.

postprocess_trajectory

Implements algorithm-specific trajectory postprocessing.
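
As an illustrative sketch, a TorchPolicyV2 sub-class could override postprocess_trajectory() to add a returns column to each collected trajectory before it reaches the loss; the class name and hyperparameters are assumptions:

```python
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyPostprocessingPolicy(TorchPolicyV2):
    """Hypothetical policy adding discounted returns to each trajectory."""

    def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None):
        # Compute plain discounted returns (no GAE, no value function needed)
        # and store them in the batch.
        return compute_advantages(
            sample_batch, last_r=0.0, gamma=0.99, use_gae=False, use_critic=False
        )
```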

Torch Policy#

optimizer

Customizes the local PyTorch optimizer(s) to use.

get_tower_stats

Returns list of per-tower stats, copied to this Policy's device.

Tensorflow Policy#

optimizer

TF optimizer to use for policy optimization.

stats_fn

Stats function.

Saving and restoring#

Base Policy#

from_checkpoint

Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.

export_checkpoint

Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.

export_model

Exports the Policy's Model to local directory for serving.

from_state

Recovers a Policy from a state object.

get_weights

Returns model weights.

set_weights

Sets this Policy's model's weights.

get_state

Returns the entire current state of this Policy.

set_state

Restores the entire current state of this Policy from state.

import_model_from_h5

Imports Policy from local file.
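
A sketch of the checkpointing and weight-syncing round trip (the checkpoint path is hypothetical):

```python
from ray.rllib.policy.policy import Policy

# Export the Policy to a checkpoint directory ...
policy.export_checkpoint("/tmp/my_policy_ckpt")

# ... and restore it later without building a full Algorithm.
restored_policy = Policy.from_checkpoint("/tmp/my_policy_ckpt")

# Weights can also be synced between Policy instances directly.
restored_policy.set_weights(policy.get_weights())
```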

Connectors#

Base Policy#

reset_connectors

Reset action- and agent-connectors for this policy.

restore_connectors

Restore agent and action connectors if configs available.

get_connector_metrics

Get metrics on timing from connectors.

Recurrent Policies#

Base Policy#

Policy.get_initial_state

Returns initial RNN state for the current policy.

Policy.num_state_tensors

The number of internal states needed by the RNN-Model of the Policy.

Policy.is_recurrent

Whether this Policy holds a recurrent Model.
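
A sketch of stateful inference, assuming a policy and env as in the earlier examples; for non-recurrent models, get_initial_state() simply returns an empty list:

```python
# get_initial_state() returns [] for non-recurrent models, or one initial
# tensor per internal RNN state otherwise (see num_state_tensors()).
state = policy.get_initial_state()

obs, info = env.reset()
for _ in range(10):
    # Feed the previous state in; the new state comes back with the action.
    action, state, _ = policy.compute_single_action(obs, state)
    obs, reward, terminated, truncated, info = env.step(action)
```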

Miscellaneous#

Base Policy#

apply

Calls the given function with this Policy instance.

get_session

Returns tf.Session object to use for computing actions or None.

init_view_requirements

Initializes the maximal view requirements dict for learn_on_batch() and compute_actions calls.

get_host

Returns the computer's network name.

get_exploration_state

Returns the state of this Policy's exploration component.

Torch Policy#

get_batch_divisibility_req

Get batch divisibility requirement.

Tensorflow Policy#

variables

Return the list of all savable variables for this policy.

get_batch_divisibility_req

Get batch divisibility requirement.