ray.rllib.policy.policy.Policy#

class ray.rllib.policy.policy.Policy(observation_space: gymnasium.Space, action_space: gymnasium.Space, config: dict)[source]#

Bases: object

RLlib’s base class for all Policy implementations.

Policy is the abstract superclass for all DL-framework specific sub-classes (e.g. TFPolicy or TorchPolicy). It exposes APIs to

  1. Compute actions from observation (and possibly other) inputs.

  2. Manage the Policy’s NN model(s), like exporting and loading their weights.

  3. Postprocess a given trajectory from the environment or other input via the postprocess_trajectory method.

  4. Compute losses from a train batch.

  5. Perform updates from a train batch on the NN-models (this normally includes loss calculations) either:

    1. in one monolithic step (learn_on_batch), or

    2. via batch pre-loading, then n steps of actual loss computations and updates (load_batch_into_buffer + learn_on_loaded_batch).

Methods

__init__(observation_space, action_space, config)

Initializes a Policy instance.

apply(func, *args, **kwargs)

Calls the given function with this Policy instance.

apply_gradients(gradients)

Applies the (previously) computed gradients.

compute_actions(obs_batch[, state_batches, ...])

Computes actions for the current policy.

compute_actions_from_input_dict(input_dict)

Computes actions from collected samples (across multiple agents).

compute_gradients(postprocessed_batch)

Computes gradients given a batch of experiences.

compute_log_likelihoods(actions, obs_batch)

Computes the log-prob/likelihood for a given action and observation.

compute_single_action([obs, state, ...])

Computes and returns a single (B=1) action value.

export_checkpoint(export_dir[, ...])

Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.

export_model(export_dir[, onnx])

Exports the Policy's Model to local directory for serving.

from_checkpoint(checkpoint[, policy_ids])

Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.

from_state(state)

Recovers a Policy from a state object.

get_connector_metrics()

Get metrics on timing from connectors.

get_exploration_state()

Returns the state of this Policy's exploration component.

get_host()

Returns the computer's network name.

get_initial_state()

Returns initial RNN state for the current policy.

get_num_samples_loaded_into_buffer([...])

Returns the number of currently loaded samples in the given buffer.

get_session()

Returns tf.Session object to use for computing actions or None.

get_state()

Returns the entire current state of this Policy.

get_weights()

Returns model weights.

import_model_from_h5(import_file)

Imports Policy from local file.

init_view_requirements()

Initializes the maximal view requirements dict used by learn_on_batch() and compute_actions calls.

is_recurrent()

Whether this Policy holds a recurrent Model.

learn_on_batch(samples)

Performs one learning update, given samples.

learn_on_batch_from_replay_buffer(...)

Samples a batch from given replay actor and performs an update.

learn_on_loaded_batch([offset, buffer_index])

Runs a single step of SGD on already-loaded data in a buffer.

load_batch_into_buffer(batch[, buffer_index])

Bulk-loads the given SampleBatch into the devices' memories.

loss(model, dist_class, train_batch)

Loss function for this Policy.

make_rl_module()

Returns the RLModule (only when the RLModule API is enabled).

maybe_add_time_dimension(input_dict, seq_lens)

Adds a time dimension for recurrent RLModules.

maybe_remove_time_dimension(input_dict)

Removes a time dimension for recurrent RLModules.

num_state_tensors()

The number of internal states needed by the RNN-Model of the Policy.

on_global_var_update(global_vars)

Called on an update to global vars.

postprocess_trajectory(sample_batch[, ...])

Implements algorithm-specific trajectory postprocessing.

reset_connectors(env_id)

Resets action- and agent-connectors for this policy.

restore_connectors(state)

Restores agent and action connectors if their configs are available.

set_state(state)

Restores the entire current state of this Policy from state.

set_weights(weights)

Sets this Policy's model's weights.