ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2

class ray.rllib.policy.eager_tf_policy_v2.EagerTFPolicyV2(observation_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, config: dict, **kwargs)

Bases: Policy

A TF-eager / TF2-based TensorFlow policy.

This class is intended to be used and extended via sub-classing.
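The real class requires Ray and TensorFlow to be installed. As a rough sketch of the intended subclassing pattern, the stand-in below mirrors only the hook names (`loss`, `make_model`, and `stats_fn` are real method names on this class; everything else here is a simplified, hypothetical stand-in, not RLlib's actual implementation):

```python
# Simplified stand-in sketch of the EagerTFPolicyV2 subclassing pattern.
# The real base class lives in ray.rllib.policy.eager_tf_policy_v2 and
# requires Ray + TensorFlow; only the hook names below are real.

class StandInEagerTFPolicyV2:
    """Stand-in base: builds a model and exposes overridable hooks."""

    def __init__(self, observation_space, action_space, config):
        self.config = config
        # The real base calls make_model() to build the policy's model.
        self.model = self.make_model()

    def make_model(self):
        # Real hook returns a ModelV2; here a dict stands in for it.
        return {"weights": [0.0]}

    def loss(self, model, dist_class, train_batch):
        # Subclasses must override this with their training loss.
        raise NotImplementedError

    def stats_fn(self, train_batch):
        # Subclasses may override this to report extra stats.
        return {}


class MyPolicy(StandInEagerTFPolicyV2):
    def loss(self, model, dist_class, train_batch):
        # Toy mean-squared loss over the batch's "rewards" key.
        rewards = train_batch["rewards"]
        return sum(r * r for r in rewards) / len(rewards)

    def stats_fn(self, train_batch):
        return {"batch_size": len(train_batch["rewards"])}


policy = MyPolicy(observation_space=None, action_space=None, config={})
print(policy.loss(policy.model, None, {"rewards": [1.0, 2.0]}))  # 2.5
```

In the real class, `loss` receives the policy's model, an action-distribution class, and a SampleBatch, and the returned tensor drives gradient computation.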

Methods

action_distribution_fn

Action distribution function for this Policy.

action_sampler_fn

Custom function for sampling new actions, given the policy.

apply

Calls the given function with this Policy instance.

apply_gradients_fn

Gradients applying function (applies the computed gradients using the local optimizer).

compute_gradients_fn

Gradients computing function (from loss tensor, using local optimizer).

compute_single_action

Computes and returns a single (B=1) action value.

export_checkpoint

Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.

extra_action_out_fn

Extra values to fetch and return from compute_actions().

extra_learn_fetches_fn

Extra stats to be reported after gradient computation.

from_checkpoint

Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.

from_state

Recovers a Policy from a state object.

get_batch_divisibility_req

Gets the batch divisibility requirement.

get_connector_metrics

Gets timing metrics from this policy's connectors.

get_host

Returns the computer's network name.

get_num_samples_loaded_into_buffer

Returns the number of currently loaded samples in the given buffer.

get_session

Returns the tf.Session object to use for computing actions, or None if not applicable.

grad_stats_fn

Gradient stats function, returning a dict of stats given the train batch and computed gradients.

import_model_from_h5

Imports a Policy model from a local h5 file.

init_view_requirements

Initializes the maximal view requirements dict for learn_on_batch() and compute_actions() calls.

learn_on_batch_from_replay_buffer

Samples a batch from the given replay actor and performs an update.

learn_on_loaded_batch

Runs a single step of SGD on already-loaded data in a buffer.

load_batch_into_buffer

Bulk-loads the given SampleBatch into the devices' memories.

loss

Computes the loss for this policy using model, dist_class, and a train_batch.

make_model

Builds the underlying model for this Policy.

maybe_remove_time_dimension

Removes a time dimension for recurrent RLModules.

on_global_var_update

Called on an update to global vars.

optimizer

TF optimizer to use for policy optimization.

postprocess_trajectory

Post-processes a trajectory, given in SampleBatch format.

reset_connectors

Resets action- and agent-connectors for this policy.

restore_connectors

Restores agent- and action-connectors if configs are available.

stats_fn

Stats function, returning a dict of stats for the given train batch.

variables

Returns the list of all savable variables for this policy.
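Among the methods above, compute_single_action conceptually wraps one observation into a batch of size one (B=1), runs the policy's batched action computation, and unwraps the single result. A minimal stand-in sketch of that B=1 pattern (the function names echo the real API, but the bodies here are simplified assumptions, not RLlib's actual logic):

```python
# Stand-in sketch of the B=1 pattern behind compute_single_action:
# batch the single observation, run the batched path, unbatch the result.

def compute_actions(obs_batch):
    # Stand-in batched policy: pick the argmax index of each obs vector.
    return [max(range(len(obs)), key=lambda i: obs[i]) for obs in obs_batch]

def compute_single_action(obs):
    # Wrap the single observation into a batch of one (B=1) ...
    actions = compute_actions([obs])
    # ... and return the first (and only) action.
    return actions[0]

print(compute_single_action([0.1, 0.9, 0.3]))  # 1
```

The real method additionally handles RNN state, exploration, and extra action outputs, but the batch/unbatch shape handling is the core of what "a single (B=1) action" means.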