ray.rllib.policy.torch_policy_v2.TorchPolicyV2#

class ray.rllib.policy.torch_policy_v2.TorchPolicyV2(observation_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, config: dict, *, max_seq_len: int = 20)[source]#

Bases: Policy

PyTorch specific Policy class to use with RLlib.

Methods

__init__

Initializes a TorchPolicyV2 instance.

action_distribution_fn

Action distribution function for this Policy.

action_sampler_fn

Custom function for sampling new actions, given the policy.

apply

Calls the given function with this Policy instance.

compute_single_action

Computes and returns a single (B=1) action value.

export_checkpoint

Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.

export_model

Exports the Policy's Model to local directory for serving.

extra_action_out

Returns dict of extra info to include in experience batch.

extra_compute_grad_fetches

Extra values to fetch and return from compute_gradients().

extra_grad_process

Called after each optimizer.zero_grad() + loss.backward() call.

from_checkpoint

Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.

from_state

Recovers a Policy from a state object.

get_batch_divisibility_req

Gets the batch divisibility requirement.

get_connector_metrics

Gets timing metrics from the connectors.

get_exploration_state

Returns the state of this Policy's exploration component.

get_host

Returns the computer's network name.

get_session

Returns the tf.Session object to use for computing actions, or None.

get_tower_stats

Returns list of per-tower stats, copied to this Policy's device.

import_model_from_h5

Imports weights into torch model.

init_view_requirements

Initializes the maximal view requirements dict for learn_on_batch() and compute_actions() calls.

learn_on_batch_from_replay_buffer

Samples a batch from the given replay actor and performs an update.

loss

Constructs the loss function.

make_model

Creates the model.

make_model_and_action_dist

Creates the model and action distribution function.

maybe_remove_time_dimension

Removes a time dimension for recurrent RLModules.

on_global_var_update

Called on an update to global vars.

optimizer

Returns the custom local PyTorch optimizer(s) to use.

postprocess_trajectory

Postprocesses a trajectory and returns the processed trajectory.

reset_connectors

Resets the action and agent connectors for this policy.

restore_connectors

Restores the agent and action connectors if their configs are available.

stats_fn

Stats function; returns a dict of statistics.