ray.rllib.policy.torch_policy_v2.TorchPolicyV2.extra_grad_process

TorchPolicyV2.extra_grad_process(optimizer: torch.optim.Optimizer, loss: numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor) → Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]

Called after each optimizer.zero_grad() + loss.backward() call.

Called once for each pair of self._optimizers and its corresponding loss value. Allows for gradient processing before optimizer.step() is called, e.g. for gradient clipping. See the sketch after the Returns section below.

Parameters:
  • optimizer – A torch optimizer object.

  • loss – The loss tensor associated with the optimizer.

Returns:

A dict with information on the gradient processing step.
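
As an illustration, here is a minimal sketch of overriding this hook in a TorchPolicyV2 subclass to apply global-norm gradient clipping before optimizer.step() runs. The subclass name MyTorchPolicy and the clip value 10.0 are illustrative assumptions, not part of RLlib.

    from typing import Dict

    import torch

    from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
    from ray.rllib.utils.typing import TensorType


    class MyTorchPolicy(TorchPolicyV2):
        def extra_grad_process(
            self, optimizer: torch.optim.Optimizer, loss: TensorType
        ) -> Dict[str, TensorType]:
            # Collect the parameters that received gradients from the
            # preceding loss.backward() call for this optimizer.
            params = [
                p
                for group in optimizer.param_groups
                for p in group["params"]
                if p.grad is not None
            ]
            # Clip gradients in place before optimizer.step() is called.
            # The max_norm value of 10.0 is an arbitrary example.
            total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=10.0)
            # Return a dict with information on the gradient processing step.
            return {"grad_gnorm": total_norm}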