Note

From Ray 2.6.0 onwards, RLlib is adopting a new stack for training and model customization, gradually replacing the ModelV2 API and some convoluted parts of the Policy API with the RLModule API. See the RLModule API documentation for details.

Policy API#

The Policy class contains functionality to compute actions for decision making in an environment, to compute loss(es) and gradients, to update a neural network model, and to postprocess a collected environment trajectory. One or more Policy objects sit inside a RolloutWorker's PolicyMap and are, if there is more than one, selected based on a multi-agent policy_mapping_fn, which maps agent IDs to a policy ID.
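For example, a policy_mapping_fn can be wired into an AlgorithmConfig as sketched below; the environment ID, agent-ID prefix, and policy IDs are placeholders for illustration, not RLlib built-ins.

```python
from ray.rllib.algorithms.ppo import PPOConfig

def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    # Map every traffic-light agent to one shared policy, all other
    # agents to a second one (agent-ID prefix is made up here).
    if agent_id.startswith("traffic_light_"):
        return "traffic_light_policy"
    return "default_policy"

config = (
    PPOConfig()
    .environment("my_multi_agent_env")  # placeholder env ID
    .multi_agent(
        policies={"traffic_light_policy", "default_policy"},
        policy_mapping_fn=policy_mapping_fn,
    )
)
```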

Figure (policy_classes_overview.svg): RLlib's Policy class hierarchy. Policies are deep-learning-framework specific, as they hold functionality to handle a computation graph (e.g., a TensorFlow 1.x graph in a session). You can define custom policy behavior by sub-classing any of the available built-in classes, depending on your needs.

Building Custom Policy Classes#

Warning

As of Ray >= 1.9, it is no longer recommended to use the build_policy_class() or build_tf_policy() utility functions for creating custom Policy sub-classes. Instead, follow the simple guidelines here for directly sub-classing from either one of the built-in types: EagerTFPolicyV2 or TorchPolicyV2.

In order to create a custom Policy, sub-class Policy (for a generic, framework-agnostic policy), TorchPolicyV2 (for a PyTorch-specific policy), or EagerTFPolicyV2 (for a TensorFlow-specific policy) and override one or more of their methods, in particular those listed in the sections below.

See here for an example of how to override TorchPolicy.
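Below is a minimal sketch of such a sub-class: a TorchPolicyV2 that overrides only loss(). The toy REINFORCE-style loss and the class name are made up for illustration; this is not the implementation of any built-in algorithm.

```python
import torch
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyTorchPolicy(TorchPolicyV2):
    """Illustrative sub-class overriding only the loss."""

    def loss(self, model, dist_class, train_batch):
        # Forward pass through the policy's model.
        logits, _ = model(train_batch)
        action_dist = dist_class(logits, model)
        # Toy REINFORCE-style loss: -E[logp(a) * r].
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        return -torch.mean(log_probs * train_batch[SampleBatch.REWARDS])
```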

Base Policy classes#

Policy

RLlib's base class for all Policy implementations.

EagerTFPolicyV2

A TF-eager / TF2-based TensorFlow policy.

TorchPolicyV2

PyTorch-specific Policy class to use with RLlib.

Making models#

Base Policy#

make_rl_module

Returns the RL Module (only when the RLModule API is enabled).

Torch Policy#

make_model

Create model.

make_model_and_action_dist

Create model and action distribution function.
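A hedged sketch of overriding make_model_and_action_dist() in a TorchPolicyV2 sub-class: here it just rebuilds the defaults via the ModelCatalog, where a hand-written custom ModelV2 could be returned instead.

```python
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyModelPolicy(TorchPolicyV2):
    def make_model_and_action_dist(self):
        # Pick the action distribution class and the number of model
        # outputs required for this action space.
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            self.action_space, self.config["model"], framework="torch"
        )
        # Build a default ModelV2; a custom model could be returned
        # here instead (illustrative only).
        model = ModelCatalog.get_model_v2(
            obs_space=self.observation_space,
            action_space=self.action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework="torch",
        )
        return model, dist_class
```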

Tensorflow Policy#

make_model

Build underlying model for this Policy.

Inference#

Base Policy#

compute_actions

Computes actions for the current policy.

compute_actions_from_input_dict

Computes actions from collected samples (across multiple agents).

compute_single_action

Computes and returns a single (B=1) action value.
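A small usage sketch for the action-computation methods above; the algorithm and environment choice are illustrative.

```python
import numpy as np
from ray.rllib.algorithms.ppo import PPOConfig

# Build an algorithm and grab its (single) default policy.
algo = PPOConfig().environment("CartPole-v1").build()
policy = algo.get_policy()

# Single observation -> single (B=1) action.
obs = np.array([0.0, 0.1, 0.0, -0.1], dtype=np.float32)
action, state_outs, extra_fetches = policy.compute_single_action(obs)

# Batched observations -> batched actions.
obs_batch = np.stack([obs, obs])
actions, state_outs, extra_fetches = policy.compute_actions(obs_batch)
```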

Torch Policy#

action_sampler_fn

Custom function for sampling new actions given policy.

action_distribution_fn

Action distribution function for this Policy.

extra_action_out

Returns dict of extra info to include in experience batch.

Tensorflow Policy#

action_sampler_fn

Custom function for sampling new actions given policy.

action_distribution_fn

Action distribution function for this Policy.

extra_action_out_fn

Extra values to fetch and return from compute_actions().

Computing, processing, and applying gradients#

Base Policy#

compute_gradients

Computes gradients given a batch of experiences.

apply_gradients

Applies the (previously) computed gradients.
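A hedged sketch of the compute_gradients() / apply_gradients() cycle, assuming the batch was collected (and thus already postprocessed) by a RolloutWorker, so it contains all fields the algorithm's loss expects.

```python
from ray.rllib.algorithms.ppo import PPOConfig

algo = PPOConfig().environment("CartPole-v1").build()
policy = algo.get_policy()

# Collect an already postprocessed batch from the local worker.
batch = algo.workers.local_worker().sample()

# Compute gradients for that batch, then apply them to the model.
grads, info = policy.compute_gradients(batch)
policy.apply_gradients(grads)
```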

Torch Policy#

extra_compute_grad_fetches

Extra values to fetch and return from compute_gradients().

extra_grad_process

Called after each optimizer.zero_grad() + loss.backward() call.

Tensorflow Policy#

grad_stats_fn

Gradient stats function (returns stats about the computed gradients, for logging).

compute_gradients_fn

Gradients computing function (from loss tensor, using local optimizer).

apply_gradients_fn

Gradients applying function (applies previously computed gradients, using the local optimizer).

extra_learn_fetches_fn

Extra stats to be reported after gradient computation.

Updating the Policy’s model#

Base Policy#

learn_on_batch

Perform one learning update, given samples (see the sketch at the end of this block).

load_batch_into_buffer

Bulk-loads the given SampleBatch into the devices' memories.

learn_on_loaded_batch

Runs a single step of SGD on already loaded data in a buffer.

learn_on_batch_from_replay_buffer

Samples a batch from given replay actor and performs an update.

get_num_samples_loaded_into_buffer

Returns the number of currently loaded samples in the given buffer.
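A sketch of the two update paths above: a one-shot learn_on_batch() call versus pre-loading data into device memory and stepping on it with learn_on_loaded_batch(). Algorithm and environment are illustrative.

```python
from ray.rllib.algorithms.ppo import PPOConfig

algo = PPOConfig().environment("CartPole-v1").build()
policy = algo.get_policy()
batch = algo.workers.local_worker().sample()

# One-shot update on the batch.
results = policy.learn_on_batch(batch)

# Alternative two-step path (used e.g. for multi-GPU training):
# pre-load the batch onto the device(s), then step on it.
num_loaded = policy.load_batch_into_buffer(batch, buffer_index=0)
results = policy.learn_on_loaded_batch(offset=0, buffer_index=0)
```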

Loss, logging, optimizers, and trajectory processing#

Base Policy#

loss

Loss function for this Policy.

compute_log_likelihoods

Computes the log-prob/likelihood for a given action and observation.

on_global_var_update

Called on an update to global vars.

postprocess_trajectory

Implements algorithm-specific trajectory postprocessing.
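A minimal sketch of overriding postprocess_trajectory() in a TorchPolicyV2 sub-class; the discounted-returns computation, the gamma value, and the extra batch key are made up for illustration.

```python
import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class DiscountingPolicy(TorchPolicyV2):
    """Illustrative: adds discounted returns during postprocessing."""

    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        # Walk the trajectory backwards, accumulating discounted returns.
        gamma = 0.99  # made-up discount factor
        rewards = sample_batch[SampleBatch.REWARDS]
        returns = np.zeros_like(rewards)
        running = 0.0
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            returns[t] = running
        # Store under a custom (hypothetical) key for use in the loss.
        sample_batch["discounted_returns"] = returns
        return sample_batch
```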

Torch Policy#

optimizer

Customizes the local PyTorch optimizer(s) to use.

get_tower_stats

Returns list of per-tower stats, copied to this Policy's device.

Tensorflow Policy#

optimizer

TF optimizer to use for policy optimization.

stats_fn

Stats function. Returns a dict of values to log.

Saving and restoring#

Base Policy#

from_checkpoint

Creates new Policy instance(s) from a given Policy or Algorithm checkpoint (see the sketch at the end of this block).

export_checkpoint

Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.

export_model

Exports the Policy's Model to local directory for serving.

from_state

Recovers a Policy from a state object.

get_weights

Returns model weights.

set_weights

Sets this Policy's model's weights.

get_state

Returns the entire current state of this Policy.

set_state

Restores the entire current state of this Policy from state.

import_model_from_h5

Imports Policy from local file.
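A usage sketch for checkpointing and weight transfer; the checkpoint path is illustrative.

```python
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import Policy

policy = PPOConfig().environment("CartPole-v1").build().get_policy()

# Write a policy checkpoint to disk (path is illustrative) ...
policy.export_checkpoint("/tmp/my_policy_ckpt")

# ... and restore a brand-new Policy instance from it.
restored = Policy.from_checkpoint("/tmp/my_policy_ckpt")

# Weights and full state can also be moved between live policies.
restored.set_weights(policy.get_weights())
restored.set_state(policy.get_state())
```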

Connectors#

Base Policy#

reset_connectors

Reset action- and agent-connectors for this policy.

restore_connectors

Restore agent and action connectors if configs are available.

get_connector_metrics

Get metrics on timing from connectors.

Recurrent Policies#

Base Policy#

Policy.get_initial_state

Returns initial RNN state for the current policy.

Policy.num_state_tensors

The number of internal states needed by the RNN model of the Policy.

Policy.is_recurrent

Whether this Policy holds a recurrent Model.
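A sketch of the recurrent-policy helpers above, assuming the default model is wrapped in an LSTM; the model settings shown are illustrative.

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Wrap the default model in an LSTM so the policy becomes recurrent.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(model={"use_lstm": True})
)
policy = config.build().get_policy()

assert policy.is_recurrent()
state = policy.get_initial_state()
assert len(state) == policy.num_state_tensors()

# Recurrent policies need the state threaded through action computation.
obs = policy.observation_space.sample()
action, state_out, _ = policy.compute_single_action(obs, state=state)
```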

Miscellaneous#

Base Policy#

apply

Calls the given function with this Policy instance (see the sketch at the end of this block).

get_session

Returns the tf.Session object to use for computing actions, or None if not applicable.

init_view_requirements

Initializes this Policy's maximal view requirements dict, used by learn_on_batch() and compute_actions() calls.

get_host

Returns the computer's network name.

get_exploration_state

Returns the state of this Policy's exploration component.
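As a small sketch of apply() from the list above: counting model parameters is a made-up use case, chosen only to show that the given function receives the Policy as its first argument.

```python
from ray.rllib.algorithms.ppo import PPOConfig

policy = PPOConfig().environment("CartPole-v1").build().get_policy()

# apply() calls the given function with the policy as its first argument.
num_params = policy.apply(
    lambda p: sum(int(w.size) for w in p.get_weights().values())
)
print(num_params)
```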

Torch Policy#

get_batch_divisibility_req

Returns the batch divisibility requirement.

Tensorflow Policy#

variables

Return the list of all savable variables for this policy.

get_batch_divisibility_req

Returns the batch divisibility requirement.