ray.rllib.policy.torch_policy_v2.TorchPolicyV2.extra_grad_process
- TorchPolicyV2.extra_grad_process(optimizer: torch.optim.Optimizer, loss: Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]) → Dict[str, Union[numpy.array, jnp.ndarray, tf.Tensor, torch.Tensor]]
Called after each optimizer.zero_grad() + loss.backward() call, once for each pair of optimizer (from self._optimizers) and its corresponding loss value.
This allows gradients to be processed before optimizer.step() is called, for example to apply gradient clipping.
- Parameters
optimizer – A torch optimizer object.
loss – The loss tensor associated with the optimizer.
- Returns
A dict with information on the gradient processing step.
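The following is a minimal sketch, in plain PyTorch rather than RLlib, of where this hook sits in an update step and what a gradient-clipping implementation might look like. The names clip_grad_process, train_step, the grad_clip threshold and the "grad_gnorm" result key are hypothetical choices for illustration. Only torch.nn.utils.clip_grad_norm_ and the standard optimizer methods are real PyTorch APIs.

```python
import torch
import torch.nn as nn


def clip_grad_process(optimizer: torch.optim.Optimizer,
                      loss: torch.Tensor,
                      grad_clip: float = 1.0) -> dict:
    """Hypothetical stand-in for an extra_grad_process() override.

    Clips the global L2 norm of the gradients of all parameters managed by
    `optimizer` to `grad_clip` and reports the norm measured before clipping.
    `loss` is accepted to match the hook's signature but is not used here.
    """
    params = [p for group in optimizer.param_groups
              for p in group["params"] if p.grad is not None]
    if not params:
        return {}  # nothing to clip, e.g. backward() was never called
    # clip_grad_norm_ rescales the .grad tensors in place and returns the
    # total norm computed before rescaling.
    total_norm = nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
    return {"grad_gnorm": total_norm.item()}


def train_step(model, optimizers, loss_fn, batch):
    """Hypothetical update loop following the documented call order."""
    losses = loss_fn(model, batch)  # one loss per optimizer
    stats = []
    for optimizer, loss in zip(optimizers, losses):
        optimizer.zero_grad()
        # retain_graph lets several losses share one forward pass.
        loss.backward(retain_graph=True)
        stats.append(clip_grad_process(optimizer, loss))  # the hook
        optimizer.step()
    return stats


# Usage: one linear model, one optimizer, one MSE loss.
torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(8, 4), 100.0 * torch.randn(8, 1)  # large targets


def mse_loss(m, batch):
    inputs, targets = batch
    return [((m(inputs) - targets) ** 2).mean()]


stats = train_step(model, [optimizer], mse_loss, (x, y))
print(stats)
```

In an actual TorchPolicyV2 subclass, the body of clip_grad_process would instead go in an extra_grad_process(self, optimizer, loss) override, with the clipping threshold read from the policy's configuration rather than passed as an argument.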