ray.rllib.policy.torch_policy_v2.TorchPolicyV2

class ray.rllib.policy.torch_policy_v2.TorchPolicyV2(observation_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, config: dict, *, max_seq_len: int = 20)

Bases: ray.rllib.policy.policy.Policy

PyTorch specific Policy class to use with RLlib.
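
One common way to get a concrete TorchPolicyV2 instance for experimentation is through a built Algorithm (a minimal sketch; "CartPole-v1" is just an example environment):

    from ray.rllib.algorithms.ppo import PPOConfig

    # Build a torch-framework PPO Algorithm; its default policy is a
    # TorchPolicyV2 subclass.
    algo = (
        PPOConfig()
        .environment("CartPole-v1")
        .framework("torch")
        .build()
    )
    policy = algo.get_policy()  # TorchPolicyV2 subclass instance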

Methods

__init__(observation_space, action_space, ...)

Initializes a TorchPolicyV2 instance.

action_distribution_fn(model, *, obs_batch, ...)

Action distribution function for this Policy.
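
A hedged sketch of overriding this hook in a subclass, assuming a plain forward pass through the model suffices to produce the distribution inputs:

    def action_distribution_fn(self, model, *, obs_batch, state_batches, **kwargs):
        # Forward pass -> action-distribution inputs (e.g. logits).
        dist_inputs, _ = model({"obs": obs_batch})
        # Reuse the policy's default dist class; no recurrent state-outs.
        return dist_inputs, self.dist_class, []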

action_sampler_fn(model, *, obs_batch, ...)

Custom function for sampling new actions, given the Policy.

apply(func, *args, **kwargs)

Calls the given function with this Policy instance.
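
Usage sketch, assuming `policy` is an already-built TorchPolicyV2 instance:

    # Run an arbitrary function with the policy as its first argument,
    # e.g. to count the model's trainable parameters.
    n_params = policy.apply(
        lambda p: sum(t.numel() for t in p.model.parameters())
    )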

compute_single_action([obs, state, ...])

Computes and returns a single (B=1) action value.
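
Usage sketch, assuming `policy` and a matching gymnasium `env` already exist:

    obs, info = env.reset()
    # Returns a triple: action, RNN state-outs, extra action fetches.
    action, state_outs, extra_fetches = policy.compute_single_action(
        obs=obs, explore=False
    )
    obs, reward, terminated, truncated, info = env.step(action)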

export_checkpoint(export_dir[, ...])

Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.

export_model(export_dir[, onnx])

Exports the Policy's Model to local directory for serving.
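
Usage sketch for both export methods (the paths are placeholders; for export_model(), the `onnx` argument is the ONNX opset version to export with):

    # Write a full Policy checkpoint (weights, state, spec) to disk.
    policy.export_checkpoint("/tmp/my_policy_checkpoint")

    # Export only the model, e.g. in ONNX format for serving.
    policy.export_model("/tmp/my_policy_onnx", onnx=11)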

extra_action_out(input_dict, state_batches, ...)

Returns dict of extra info to include in experience batch.
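
A hedged override sketch, assuming the underlying model implements a value function:

    def extra_action_out(self, input_dict, state_batches, model, action_dist):
        # Record the value-function estimate alongside each sampled action.
        return {"vf_preds": model.value_function()}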

extra_compute_grad_fetches()

Extra values to fetch and return from compute_gradients().

extra_grad_process(optimizer, loss)

Called after each optimizer.zero_grad() + loss.backward() call.
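
Built-in torch policies typically use this hook for gradient clipping; a sketch that reuses RLlib's helper (driven by the "grad_clip" config key):

    from ray.rllib.utils.torch_utils import apply_grad_clipping

    def extra_grad_process(self, optimizer, loss):
        # Clip gradients by global norm; reports e.g. the resulting norm.
        return apply_grad_clipping(self, optimizer, loss)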

from_checkpoint(checkpoint[, policy_ids])

Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.
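
Usage sketch (the path is a placeholder for a checkpoint written earlier, e.g. via export_checkpoint()):

    from ray.rllib.policy.policy import Policy

    # Restore a single Policy from a policy checkpoint directory.
    restored_policy = Policy.from_checkpoint("/tmp/my_policy_checkpoint")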

from_state(state)

Recovers a Policy from a state object.

get_batch_divisibility_req()

Returns the batch divisibility requirement for this Policy.

get_connector_metrics()

Returns timing metrics collected from this Policy's connectors.

get_exploration_state()

Returns the state of this Policy's exploration component.

get_host()

Returns the computer's network name.

get_session()

Returns the tf.Session object to use for computing actions, or None.

get_tower_stats(stats_name)

Returns list of per-tower stats, copied to this Policy's device.

import_model_from_h5(import_file)

Imports weights into the torch model.

init_view_requirements()

Initializes this Policy's maximal view requirements dict, used by learn_on_batch() and compute_actions() calls.

learn_on_batch_from_replay_buffer(...)

Samples a batch from the given replay actor and performs an update.

loss(model, dist_class, train_batch)

Constructs the loss function.
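
A minimal, hypothetical override (policy-gradient-style; assumes an "advantages" column was added in postprocess_trajectory(), as sketched further below):

    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2

    class MyTorchPolicy(TorchPolicyV2):
        def loss(self, model, dist_class, train_batch):
            # Forward pass -> distribution inputs -> action distribution.
            dist_inputs, _ = model(train_batch)
            action_dist = dist_class(dist_inputs, model)
            # Advantage-weighted negative log-likelihood of taken actions.
            logps = action_dist.logp(train_batch[SampleBatch.ACTIONS])
            pg_loss = -(logps * train_batch["advantages"]).mean()
            # Stash per-tower stats for retrieval in stats_fn().
            model.tower_stats["pg_loss"] = pg_loss
            return pg_loss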

make_model()

Creates the model.

make_model_and_action_dist()

Creates the model and the action distribution function.
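
A hedged override sketch that uses RLlib's ModelCatalog to build a default model plus a matching torch action distribution class:

    from ray.rllib.models.catalog import ModelCatalog

    def make_model_and_action_dist(self):
        # Ask the catalog which dist class fits the action space and how
        # many model outputs (logits) that distribution needs.
        dist_class, logit_dim = ModelCatalog.get_action_dist(
            self.action_space, self.config["model"], framework="torch"
        )
        model = ModelCatalog.get_model_v2(
            obs_space=self.observation_space,
            action_space=self.action_space,
            num_outputs=logit_dim,
            model_config=self.config["model"],
            framework="torch",
        )
        return model, dist_class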

make_rl_module()

Returns the RLModule (only used when the RLModule API is enabled).

maybe_add_time_dimension(input_dict, seq_lens)

Adds a time dimension for recurrent RLModules.

on_global_var_update(global_vars)

Called on an update to global vars.

optimizer()

Returns the custom local PyTorch optimizer(s) to use.
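
A minimal override sketch (assumes an "lr" key in the policy config):

    import torch

    def optimizer(self):
        # A single Adam optimizer over all model parameters.
        return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])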

postprocess_trajectory(sample_batch[, ...])

Postprocesses a trajectory and returns the processed trajectory.
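
A hedged sketch that appends discounted returns as an "advantages" column, pairing with the loss() example above:

    from ray.rllib.evaluation.postprocessing import compute_advantages

    def postprocess_trajectory(
        self, sample_batch, other_agent_batches=None, episode=None
    ):
        # Plain discounted returns (no GAE, no critic baseline).
        return compute_advantages(
            sample_batch,
            last_r=0.0,
            gamma=self.config["gamma"],
            use_gae=False,
            use_critic=False,
        )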

reset_connectors(env_id)

Resets the action- and agent-connectors for this policy.

restore_connectors(state)

Restores agent and action connectors if their configs are available.

stats_fn(train_batch)

Returns a dict of learning statistics (e.g., loss terms) computed from the given train batch.
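
A sketch that reports the per-tower loss value stashed in the loss() example above:

    import torch

    def stats_fn(self, train_batch):
        # Average the stat across all (multi-GPU) towers.
        return {
            "pg_loss": torch.mean(torch.stack(self.get_tower_stats("pg_loss"))),
        }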