Note
Ray 2.10.0 introduces the alpha stage of RLlib’s “new API stack”. The team is currently transitioning algorithms, example scripts, and documentation to the new code base throughout the subsequent minor releases leading up to Ray 3.0.
See here for more details on how to activate and use the new API stack.
Policy API#
The Policy class contains functionality to compute actions for decision making in an environment, to compute loss(es) and gradients, to update a neural network model, and to postprocess a collected environment trajectory. One or more Policy objects sit inside a RolloutWorker's PolicyMap and, if there is more than one, are selected based on a multi-agent policy_mapping_fn, which maps agent IDs to a policy ID.
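For example, a minimal sketch of such a mapping function in a multi-agent setup might look as follows (the environment name, policy IDs, and agent-ID naming scheme are assumptions for illustration, not part of RLlib):

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    # Hypothetical, already-registered multi-agent env name.
    .environment("my_multi_agent_env")
    .multi_agent(
        # Two policy IDs; RLlib creates one Policy object per ID.
        policies={"policy_even", "policy_odd"},
        # Map agent IDs like "agent_0", "agent_1", ... to a policy ID.
        policy_mapping_fn=lambda agent_id, *args, **kwargs: (
            "policy_even" if int(agent_id.split("_")[-1]) % 2 == 0 else "policy_odd"
        ),
    )
)
```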
Building Custom Policy Classes#
Warning
As of Ray >= 1.9, it is no longer recommended to use the build_policy_class() or build_tf_policy() utility functions to create custom Policy sub-classes. Instead, follow the simple guidelines here for directly sub-classing one of the built-in types: EagerTFPolicyV2 or TorchPolicyV2.
In order to create a custom Policy, sub-class Policy (for a generic, framework-agnostic policy), TorchPolicyV2 (for a PyTorch-specific policy), or EagerTFPolicyV2 (for a TensorFlow-specific policy) and override one or more of their methods. The sections below list these methods by topic; a minimal PyTorch example follows.
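For illustration, a minimal, hedged sketch of a custom PyTorch policy that only overrides loss() might look like this; the simplistic policy-gradient-style loss and the class name are assumptions for the example, not a recommended objective:

```python
import torch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyTorchPolicy(TorchPolicyV2):
    def loss(self, model, dist_class, train_batch):
        # Forward pass to get action-distribution inputs for the batch.
        logits, _ = model(train_batch)
        action_dist = dist_class(logits, model)
        # Log-likelihood of the actions actually taken in the batch.
        log_probs = action_dist.logp(train_batch["actions"])
        # Weight by precomputed advantages if present, else fall back to 1s.
        advantages = train_batch.get("advantages", torch.ones_like(log_probs))
        return -(log_probs * advantages).mean()
```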
Base Policy classes#
Policy: RLlib's base class for all Policy implementations.
EagerTFPolicyV2: A TF-eager / TF2 based tensorflow policy.
TorchPolicyV2: PyTorch specific Policy class to use with RLlib.
Making models#
Torch Policy#
Create model.
Create model and action distribution function.
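As a hedged sketch of how a TorchPolicyV2 subclass might plug in its own model via the make_model() hook (the class name and the assumption of a Discrete action space are illustrative only):

```python
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyModelPolicy(TorchPolicyV2):
    def make_model(self):
        # Use RLlib's built-in fully connected torch network as the model.
        # num_outputs assumes a Discrete action space.
        return FullyConnectedNetwork(
            self.observation_space,
            self.action_space,
            num_outputs=self.action_space.n,
            model_config=self.config["model"],
            name="my_fc_model",
        )
```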
Tensorflow Policy#
Build underlying model for this Policy.
Inference#
Base Policy#
Computes actions for the current policy.
Computes actions from collected samples (across multiple-agents).
Computes and returns a single (B=1) action value.
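A hedged usage sketch of these inference calls (the PPO algorithm and CartPole environment are arbitrary example choices):

```python
import gymnasium as gym
import numpy as np
from ray.rllib.algorithms.ppo import PPOConfig

# Build an Algorithm just to obtain a ready-to-use Policy instance.
algo = PPOConfig().environment("CartPole-v1").build()
policy = algo.get_policy()

env = gym.make("CartPole-v1")
obs, _ = env.reset()

# Single (B=1) action for one observation.
action, state_outs, extra_fetches = policy.compute_single_action(obs)

# Batched action computation for several observations at once.
actions, state_outs, extra_fetches = policy.compute_actions(np.stack([obs, obs]))
```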
Torch Policy#
Custom function for sampling new actions given policy.
Action distribution function for this Policy.
Returns dict of extra info to include in experience batch.
Tensorflow Policy#
Custom function for sampling new actions given policy.
Action distribution function for this Policy.
Extra values to fetch and return from compute_actions().
Computing, processing, and applying gradients#
Base Policy#
Computes gradients given a batch of experiences.
Applies the (previously) computed gradients.
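A hedged sketch of how these two calls pair up; `worker_policy`, `local_policy`, and `sample_batch` are assumed to already exist:

```python
# Compute gradients (plus stats) from a batch of experiences on one Policy ...
grads, grad_info = worker_policy.compute_gradients(sample_batch)

# ... and apply those previously computed gradients on another Policy copy,
# e.g. a learner's local copy, to keep the two models in sync.
local_policy.apply_gradients(grads)
```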
Torch Policy#
Extra values to fetch and return from compute_gradients().
Called after each optimizer.zero_grad() + loss.backward() call.
Tensorflow Policy#
Gradient stats function.
Gradients computing function (from loss tensor, using local optimizer).
Gradients computing function (from loss tensor, using local optimizer).
Extra stats to be reported after gradient computation.
Updating the Policy’s model#
Base Policy#
Perform one learning update on a given train batch.
Bulk-loads the given SampleBatch into the devices' memories.
Runs a single step of SGD on already loaded data in a buffer.
Samples a batch from the given replay actor and performs an update.
Returns the number of currently loaded samples in the given buffer.
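A hedged sketch of the two update paths listed above (a direct learn_on_batch() call vs. bulk-loading a batch onto the device first); `policy` and `train_batch` are assumed to already exist:

```python
# Path 1: one learning update directly from a train batch.
learn_stats = policy.learn_on_batch(train_batch)

# Path 2: bulk-load the batch into the device memory once, then run one or
# more SGD steps on the already-loaded data.
policy.load_batch_into_buffer(train_batch, buffer_index=0)
print(policy.get_num_samples_loaded_into_buffer(buffer_index=0))
learn_stats = policy.learn_on_loaded_batch(offset=0, buffer_index=0)
```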
Loss, Logging, optimizers, and trajectory processing#
Base Policy#
Loss function for this Policy.
Computes the log-prob/likelihood for a given action and observation.
Called on an update to global vars.
Implements algorithm-specific trajectory postprocessing.
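A hedged sketch of algorithm-specific trajectory postprocessing: a TorchPolicyV2 subclass that appends a simple discounted returns-to-go column to each collected trajectory (the column name, discount value, and class name are illustrative assumptions):

```python
import numpy as np
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyPostprocessingPolicy(TorchPolicyV2):
    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        gamma = 0.99
        rewards = sample_batch["rewards"]
        returns = np.zeros_like(rewards)
        running = 0.0
        # Accumulate discounted rewards backwards through the trajectory.
        for t in reversed(range(len(rewards))):
            running = rewards[t] + gamma * running
            returns[t] = running
        # Add the new column; it becomes available in the train batch later.
        sample_batch["returns_to_go"] = returns
        return sample_batch
```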
Torch Policy#
Custom local PyTorch optimizer(s) to use.
Returns list of per-tower stats, copied to this Policy's device.
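A hedged sketch of supplying custom local PyTorch optimizer(s) by overriding the optimizer() hook of a TorchPolicyV2 subclass; the single-Adam setup and reliance on the "lr" config key are assumptions for the example:

```python
import torch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyOptimizerPolicy(TorchPolicyV2):
    def optimizer(self):
        # One Adam optimizer over all model parameters, using the policy
        # config's learning rate.
        return [torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])]
```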
Tensorflow Policy#
TF optimizer to use for policy optimization.
Stats function.
Saving and restoring#
Base Policy#
Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.
Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.
Exports the Policy's Model to local directory for serving.
Recovers a Policy from a state object.
Returns model weights.
Sets this Policy's model's weights.
Returns the entire current state of this Policy.
Restores the entire current state of this Policy from a state object.
Imports Policy from local file.
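A hedged sketch of common save/restore flows; `algo` is assumed to be an already built Algorithm and "/tmp/my_policy_ckpt" an example directory:

```python
from ray.rllib.policy.policy import Policy

policy = algo.get_policy()

# Export a Policy checkpoint to disk, then re-create a Policy from it.
policy.export_checkpoint("/tmp/my_policy_ckpt")
restored = Policy.from_checkpoint("/tmp/my_policy_ckpt")

# Weights can also be copied between Policy instances directly.
restored.set_weights(policy.get_weights())

# Or round-trip the entire Policy state (weights plus internal state).
restored.set_state(policy.get_state())
```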
Recurrent Policies#
Base Policy#
Returns initial RNN state for the current policy.
The number of internal states needed by the RNN-Model of the Policy.
Whether this Policy holds a recurrent Model.
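A hedged sketch of acting with a recurrent (e.g. LSTM-based) policy: when calling the Policy directly, the RNN state must be carried across timesteps by the caller; `policy` and `obs` are assumed to exist:

```python
# Start from the policy's initial internal (RNN) state, if it has one.
state = policy.get_initial_state() if policy.is_recurrent() else []

# Pass the state in on every call and keep the returned state for the next step.
action, state, _ = policy.compute_single_action(obs, state=state)
```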
Miscellaneous#
Base Policy#
Calls the given function with this Policy instance.
Returns tf.Session object to use for computing actions or None.
Maximal view requirements dict for learning and action computations.
Returns the computer's network name.
Returns the state of this Policy's exploration component.
Torch Policy#
Get batch divisibility request.
Tensorflow Policy#
Return the list of all savable variables for this policy.
Get batch divisibility request.