ray.rllib.policy.torch_policy_v2.TorchPolicyV2

class ray.rllib.policy.torch_policy_v2.TorchPolicyV2(observation_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, config: dict, *, max_seq_len: int = 20)

Bases: Policy

PyTorch-specific Policy class to use with RLlib.
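Typical usage is to subclass TorchPolicyV2 and override selected hooks such as loss() and optimizer(). The sketch below is illustrative rather than taken from the Ray codebase: the class name MyTorchPolicy and the simple REINFORCE-style loss are assumptions, shown only to make the override points concrete.

import torch

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2


class MyTorchPolicy(TorchPolicyV2):
    # Hypothetical subclass, for illustration only.

    def loss(self, model, dist_class, train_batch):
        # Forward pass: a ModelV2 is callable on a SampleBatch and returns
        # (action-distribution inputs, state-outs).
        logits, _ = model(train_batch)
        action_dist = dist_class(logits, model)
        # Illustrative REINFORCE-style objective (not any built-in
        # algorithm's actual loss): reward-weighted negative log-likelihood.
        log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
        return -torch.mean(log_probs * train_batch[SampleBatch.REWARDS])

    def optimizer(self):
        # Return the local PyTorch optimizer(s) to use for this policy.
        return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])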
Methods
__init__ - Initializes a TorchPolicy instance.
action_distribution_fn - Action distribution function for this Policy.
action_sampler_fn - Custom function for sampling new actions given policy.
apply - Calls the given function with this Policy instance.
compute_single_action - Computes and returns a single (B=1) action value (see the usage sketch after this table).
export_checkpoint - Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.
export_model - Exports the Policy's Model to local directory for serving.
extra_action_out - Returns dict of extra info to include in experience batch.
extra_compute_grad_fetches - Extra values to fetch and return from compute_gradients().
extra_grad_process - Called after each optimizer.zero_grad() + loss.backward() call.
from_checkpoint - Creates new Policy instance(s) from a given Policy or Algorithm checkpoint.
from_state - Recovers a Policy from a state object.
get_batch_divisibility_req - Get batch divisibility request.
get_connector_metrics - Get metrics on timing from connectors.
get_exploration_state - Returns the state of this Policy's exploration component.
get_host - Returns the computer's network name.
get_session - Returns tf.Session object to use for computing actions or None.
get_tower_stats - Returns list of per-tower stats, copied to this Policy's device.
import_model_from_h5 - Imports weights into torch model.
init_view_requirements - Maximal view requirements dict for learn_on_batch() and compute_actions calls.
learn_on_batch_from_replay_buffer - Samples a batch from given replay actor and performs an update.
loss - Constructs the loss function.
make_model - Create model.
make_model_and_action_dist - Create model and action distribution function.
maybe_remove_time_dimension - Removes a time dimension for recurrent RLModules.
on_global_var_update - Called on an update to global vars.
optimizer - Customizes the local PyTorch optimizer(s) to use.
postprocess_trajectory - Postprocesses a trajectory and returns the processed trajectory.
reset_connectors - Reset action- and agent-connectors for this policy.
restore_connectors - Restore agent and action connectors if configs available.
stats_fn - Stats function; returns a dict of statistics.
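The methods above are typically exercised through a built Algorithm's policy. A minimal usage sketch, assuming a working torch installation, the CartPole-v1 environment, and an illustrative checkpoint path:

import numpy as np

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import Policy

# Build a torch-framework PPO Algorithm; its default policy is a
# TorchPolicyV2 subclass.
algo = PPOConfig().environment("CartPole-v1").framework("torch").build()
policy = algo.get_policy()

# compute_single_action: computes and returns a single (B=1) action.
obs = np.zeros(policy.observation_space.shape, dtype=np.float32)
action, state_outs, extra_fetches = policy.compute_single_action(obs)

# export_checkpoint: writes a Policy checkpoint to a local directory.
policy.export_checkpoint("/tmp/torch_policy_ckpt")

# from_checkpoint: recreates a Policy instance from that checkpoint.
restored_policy = Policy.from_checkpoint("/tmp/torch_policy_ckpt")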