ray.rllib.core.learner.learner.Learner.compute_gradients
- abstract Learner.compute_gradients(loss_per_module: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], **kwargs) → Dict[Hashable, torch.Tensor | tf.Variable]
Computes the gradients based on the given losses.
- Parameters:
loss_per_module – Dict mapping module IDs to their individual total loss terms, computed by the respective compute_loss_for_module() calls. The overall total loss (the sum of the loss terms over all modules) is stored under loss_per_module[ALL_MODULES].
**kwargs – Forward compatibility kwargs.
- Returns:
The gradients in the same (flat) format as self._params. Note that top-level structures, such as module IDs, are no longer present in the returned dict; it merely maps parameter tensor references to their respective gradient tensors.
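For illustration, here is a minimal sketch of what a PyTorch-based override of this method might look like. It assumes that self._params maps hashable parameter references to the parameter tensors (as stated above), that the learner tracks its optimizers in an internal self._optimizer_parameters registry, and that ALL_MODULES is importable from ray.rllib.utils.metrics; these are implementation details that may differ across Ray versions, so treat this as a sketch rather than the canonical implementation:

```python
from typing import Dict, Hashable

import torch

from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.utils.metrics import ALL_MODULES
from ray.rllib.utils.typing import TensorType


class MyTorchLearner(TorchLearner):
    def compute_gradients(
        self, loss_per_module: Dict[str, TensorType], **kwargs
    ) -> Dict[Hashable, torch.Tensor]:
        # Reset any gradients left over from the previous update step.
        # NOTE: `self._optimizer_parameters` is an internal registry and
        # may vary across Ray versions.
        for optimizer in self._optimizer_parameters:
            optimizer.zero_grad(set_to_none=True)
        # Backpropagate from the overall total loss (the sum of the loss
        # terms over all modules), stored under the ALL_MODULES key.
        loss_per_module[ALL_MODULES].backward()
        # Return a flat dict in the same format as `self._params`:
        # parameter references mapped to their gradient tensors. Module IDs
        # and other top-level structures are intentionally dropped here.
        return {pid: p.grad for pid, p in self._params.items()}
```

The flat return format is what allows downstream steps such as postprocess_gradients() and apply_gradients() to operate on gradients uniformly, without knowing which module each parameter belongs to.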