ray.rllib.policy.torch_policy_v2.TorchPolicyV2
class ray.rllib.policy.torch_policy_v2.TorchPolicyV2(observation_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, config: dict, *, max_seq_len: int = 20)

Bases: ray.rllib.policy.policy.Policy
PyTorch-specific Policy class to use with RLlib.
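In practice, TorchPolicyV2 is subclassed rather than instantiated directly: you override one or more of the hooks listed under Methods below, most commonly loss(). A minimal sketch (the class name and the behavior-cloning-style loss are illustrative assumptions, not RLlib's own implementation):

```python
import torch

from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils.annotations import override


class MyTorchPolicy(TorchPolicyV2):
    """Illustrative policy with a simple behavior-cloning loss."""

    @override(TorchPolicyV2)
    def loss(self, model, dist_class, train_batch):
        # Forward pass through the model to get action-distribution inputs.
        logits, _ = model(train_batch)
        action_dist = dist_class(logits, model)
        # Negative log-likelihood of the actions actually taken in the batch.
        logp = action_dist.logp(train_batch["actions"])
        return -torch.mean(logp)
```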
Methods
__init__(observation_space, action_space, ...)
    Initializes a TorchPolicy instance.
action_distribution_fn(model, *, obs_batch, ...)
    Action distribution function for this Policy.
action_sampler_fn(model, *, obs_batch, ...)
    Custom function for sampling new actions given policy.
apply(func, *args, **kwargs)
    Calls the given function with this Policy instance.
compute_single_action([obs, state, ...])
    Computes and returns a single (B=1) action value (see the usage sketch after this list).
export_checkpoint(export_dir[, ...])
    Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.
export_model(export_dir[, onnx])
    Exports the Policy's Model to local directory for serving.
extra_action_out(input_dict, state_batches, ...)
    Returns dict of extra info to include in experience batch.
extra_compute_grad_fetches()
    Extra values to fetch and return from compute_gradients().
extra_grad_process(optimizer, loss)
    Called after each optimizer.zero_grad() + loss.backward() call.
from_checkpoint(checkpoint[, policy_ids])
    Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.
from_state(state)
    Recovers a Policy from a state object.
get_batch_divisibility_req()
    Get batch divisibility request.
get_connector_metrics()
    Get metrics on timing from connectors.
get_exploration_state()
    Returns the state of this Policy's exploration component.
get_host()
    Returns the computer's network name.
get_session()
    Returns tf.Session object to use for computing actions or None.
get_tower_stats(stats_name)
    Returns list of per-tower stats, copied to this Policy's device.
import_model_from_h5(import_file)
    Imports weights into torch model.
init_view_requirements()
    Maximal view requirements dict for learn_on_batch() and compute_actions calls.
learn_on_batch_from_replay_buffer(replay_actor, policy_id)
    Samples a batch from the given replay actor and performs an update.
loss(model, dist_class, train_batch)
    Constructs the loss function (see the sketch above).
make_model()
    Create model.
make_model_and_action_dist()
    Create model and action distribution function.
make_rl_module()
    Returns the RL Module (only when the RLModule API is enabled).
maybe_add_time_dimension(input_dict, seq_lens)
    Adds a time dimension for recurrent RLModules.
on_global_var_update(global_vars)
    Called on an update to global vars.
optimizer()
    Customizes the local PyTorch optimizer(s) to use (see the sketch after this list).
postprocess_trajectory(sample_batch[, ...])
    Postprocesses a trajectory and returns the processed trajectory.
reset_connectors(env_id)
    Reset action- and agent-connectors for this policy.
restore_connectors(state)
    Restore agent and action connectors if configs are available.
stats_fn(train_batch)
    Stats function; returns a dict of stats for the given train batch.
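The optimizer() and stats_fn() hooks are overridden the same way. A sketch continuing the hypothetical MyTorchPolicy from above (the "lr" config key lookup and the reported metric are assumptions):

```python
class MyTorchPolicy(TorchPolicyV2):  # continuing the sketch above
    @override(TorchPolicyV2)
    def optimizer(self):
        # Torch optimizer(s) that learn_on_batch() will step.
        return torch.optim.Adam(
            self.model.parameters(), lr=self.config.get("lr", 1e-4)
        )

    @override(TorchPolicyV2)
    def stats_fn(self, train_batch):
        # Anything returned here is reported alongside the learner stats.
        return {"train_batch_size": torch.tensor(float(len(train_batch)))}
```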
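For inference, a built policy is usually fetched from an Algorithm rather than constructed by hand. A usage sketch, using PPO only as a convenient way to obtain a torch policy (the env name and checkpoint path are placeholders):

```python
import gymnasium as gym
from ray.rllib.algorithms.ppo import PPOConfig

algo = PPOConfig().environment("CartPole-v1").framework("torch").build()
policy = algo.get_policy()  # A TorchPolicyV2 subclass under the hood.

env = gym.make("CartPole-v1")
obs, _ = env.reset()
# Compute a single (B=1) action for the current observation.
action, state_out, extra_fetches = policy.compute_single_action(obs=obs)

# Export a checkpoint that Policy.from_checkpoint() can later restore.
policy.export_checkpoint("/tmp/my_policy_checkpoint")
```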