ray.rllib.policy.policy.Policy

class ray.rllib.policy.policy.Policy(observation_space: gymnasium.Space, action_space: gymnasium.Space, config: dict)

RLlib’s base class for all Policy implementations.

Policy is the abstract superclass for all DL-framework-specific subclasses (e.g. TFPolicy or TorchPolicy). It exposes APIs to (see the sketch after this list):

  1. Compute actions from observation (and possibly other) inputs.

  2. Manage the Policy’s NN model(s), like exporting and loading their weights.

  3. Postprocess a given trajectory from the environment or other input via the postprocess_trajectory method.

  4. Compute losses from a train batch.

  5. Perform updates from a train batch on the NN-models (this normally includes loss calculations), either:

    1. in one monolithic step (learn_on_batch), or

    2. via batch pre-loading, then n steps of actual loss computations and updates (load_batch_into_buffer + learn_on_loaded_batch).
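
A minimal usage sketch of these APIs (assuming ray[rllib] and gymnasium are installed; the PPO and CartPole-v1 choices are purely illustrative, not prescribed by this class):

    import gymnasium as gym
    from ray.rllib.algorithms.ppo import PPOConfig

    # Build an Algorithm; its Policy object(s) are instances of this class.
    algo = PPOConfig().environment("CartPole-v1").build()
    policy = algo.get_policy()  # Policy for the default policy ID.

    env = gym.make("CartPole-v1")
    obs, _ = env.reset()
    # compute_single_action returns (action, rnn_state_outs, extra_fetches).
    action, state_out, extras = policy.compute_single_action(obs)
    obs, reward, terminated, truncated, info = env.step(action)

The method examples further down on this page reuse algo, policy, env, and obs from this sketch.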

Methods

__init__

Initializes a Policy instance.

apply

Calls the given function with this Policy instance.

apply_gradients

Applies the (previously) computed gradients.

compute_actions

Computes actions for the current policy.
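
For batched inference, compute_actions takes a batch of observations and returns a triple of (actions, RNN state-outs, extra fetches). A small sketch, reusing policy and env from the overview above:

    import numpy as np

    # Batch of B=4 observations; state batches are only needed for RNN models.
    obs_batch = np.stack([env.observation_space.sample() for _ in range(4)])
    actions, state_outs, extra_fetches = policy.compute_actions(obs_batch)
    assert len(actions) == 4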

compute_actions_from_input_dict

Computes actions from collected samples (across multiple agents).

compute_gradients

Computes gradients given a batch of experiences.

compute_log_likelihoods

Computes the log-prob/likelihood for a given action and observation.

compute_single_action

Computes and returns a single (B=1) action value.
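
compute_single_action is the unbatched convenience wrapper around compute_actions; it also accepts an explore flag to toggle the exploration component. A sketch reusing policy and obs from the overview:

    # Greedy (exploration switched off) vs. stochastic action for one observation.
    greedy_action, _, _ = policy.compute_single_action(obs, explore=False)
    sampled_action, _, _ = policy.compute_single_action(obs, explore=True)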

export_checkpoint

Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.

export_model

Exports the Policy's Model to a local directory for serving.

from_checkpoint

Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.
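
export_checkpoint and from_checkpoint form a natural round trip: write a Policy checkpoint to disk, then rebuild Policy instance(s) from it. A sketch (the /tmp path is illustrative):

    from ray.rllib.policy.policy import Policy

    policy.export_checkpoint("/tmp/my_policy_ckpt")
    # For a Policy checkpoint this returns a single Policy; for an Algorithm
    # checkpoint it returns a dict mapping policy IDs to Policy instances.
    restored = Policy.from_checkpoint("/tmp/my_policy_ckpt")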

from_state

Recovers a Policy from a state object.

get_connector_metrics

Gets timing metrics from the connectors.

get_exploration_state

Returns the state of this Policy's exploration component.

get_host

Returns the computer's network name.

get_initial_state

Returns initial RNN state for the current policy.
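
For recurrent models, the per-episode RNN state returned by get_initial_state must be threaded through successive action computations (for feed-forward models it is simply an empty list). A sketch of the canonical stepping loop, reusing policy and env:

    state = policy.get_initial_state()
    obs, _ = env.reset()
    terminated = truncated = False
    while not (terminated or truncated):
        # The state-outs of one step become the state-ins of the next.
        action, state, _ = policy.compute_single_action(obs, state)
        obs, reward, terminated, truncated, _ = env.step(action)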

get_num_samples_loaded_into_buffer

Returns the number of currently loaded samples in the given buffer.

get_session

Returns the tf.Session object to use for computing actions, or None.

get_state

Returns the entire current state of this Policy.

get_weights

Returns model weights.
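
get_weights pairs with set_weights (below) to synchronize model weights between structurally identical Policy copies, e.g. from a learner to rollout workers. A sketch assuming two compatible instances policy_a and policy_b:

    # Framework-specific mapping of weight names to numpy arrays.
    weights = policy_a.get_weights()
    policy_b.set_weights(weights)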

import_model_from_h5

Imports the Policy's Model from a local h5 file.

init_view_requirements

Initializes the maximal view requirements dict for learn_on_batch() and compute_actions() calls.

is_recurrent

Whether this Policy holds a recurrent Model.

learn_on_batch

Performs one learning update, given samples.
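
learn_on_batch is the monolithic update path from the class docstring: one call computes the loss(es) on the given train batch and applies the resulting gradients, returning a stats dict. A sketch that gathers a properly postprocessed batch through RLlib's own sampling helper (reusing algo and policy from the overview; the helper and algo.workers usage reflect the classic worker-based stack):

    from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
    from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID

    train_batch = synchronous_parallel_sample(worker_set=algo.workers)
    # Multi-agent batches hold one SampleBatch per policy ID.
    policy_batch = train_batch.as_multi_agent()[DEFAULT_POLICY_ID]
    stats = policy.learn_on_batch(policy_batch)  # e.g. loss statistics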

learn_on_batch_from_replay_buffer

Samples a batch from the given replay actor and performs an update.

learn_on_loaded_batch

Runs a single step of SGD on data already loaded into a buffer.

load_batch_into_buffer

Bulk-loads the given SampleBatch into the devices' memories.
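
Together with learn_on_loaded_batch above, this implements the two-phase update path from the class docstring: copy a batch onto the device(s) once, then run several SGD passes on it without re-loading. A sketch reusing policy_batch from the learn_on_batch example:

    # Phase 1: bulk-load the batch into the device buffer once.
    num_loaded = policy.load_batch_into_buffer(policy_batch)
    # Phase 2: several cheap SGD steps on the already-loaded data.
    for _ in range(3):
        stats = policy.learn_on_loaded_batch()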

loss

Loss function for this Policy.

make_rl_module

Returns the RLModule (only when the RLModule API is enabled).

maybe_add_time_dimension

Adds a time dimension for recurrent RLModules.

maybe_remove_time_dimension

Removes a time dimension for recurrent RLModules.

num_state_tensors

The number of internal states needed by the RNN-Model of the Policy.

on_global_var_update

Called on an update to global vars.

postprocess_trajectory

Implements algorithm-specific trajectory postprocessing.
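
postprocess_trajectory is the standard override point for algorithm-specific trajectory transforms such as advantage estimation. A sketch of a subclass adding a hypothetical, purely illustrative reward-to-go column:

    import numpy as np
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2

    class MyPolicy(TorchPolicyV2):
        def postprocess_trajectory(self, sample_batch, other_agent_batches=None, episode=None):
            # Hypothetical extra column: undiscounted reward-to-go per timestep.
            rews = sample_batch[SampleBatch.REWARDS]
            sample_batch["reward_to_go"] = np.cumsum(rews[::-1])[::-1].astype(np.float32)
            return super().postprocess_trajectory(sample_batch, other_agent_batches, episode)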

reset_connectors

Resets action- and agent-connectors for this policy.

restore_connectors

Restores agent and action connectors if their configs are available.

set_state

Restores the entire current state of this Policy from state.

set_weights

Sets this Policy's model's weights.