abstract Learner.compute_gradients(loss_per_module: Mapping[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], **kwargs) → Dict[Hashable, torch.Tensor | tf.Variable]

Computes the gradients based on the given losses.

Parameters:

  • loss_per_module – Mapping from module IDs to each module's total loss term, as computed by the individual compute_loss_for_module() calls. The overall total loss (the sum of the loss terms over all modules) is stored under loss_per_module[ALL_MODULES].

  • **kwargs – Forward compatibility kwargs.
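The shape of loss_per_module can be illustrated with plain floats standing in for framework tensors. This is a minimal sketch: the helper `with_total_loss` is hypothetical, and the local `ALL_MODULES` string is only a stand-in for the real key, not RLlib's definition.

```python
from typing import Dict, Mapping

# Hypothetical stand-in for the ALL_MODULES key used by the Learner.
ALL_MODULES = "__all_modules__"


def with_total_loss(per_module: Mapping[str, float]) -> Dict[str, float]:
    """Hypothetical helper: copy `per_module` and add the overall total under ALL_MODULES."""
    losses = dict(per_module)
    # The overall total is the sum of the individual module loss terms.
    losses[ALL_MODULES] = sum(per_module.values())
    return losses


loss_per_module = with_total_loss({"policy_a": 0.5, "policy_b": 1.25})
print(loss_per_module)  # {'policy_a': 0.5, 'policy_b': 1.25, '__all_modules__': 1.75}
```

In a real Learner the values are framework tensors that still carry their computation graph, which is what allows gradients to be computed from them.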


Returns:

The gradients, in the same flat format as self._params. Top-level structures such as module IDs are no longer present in the returned dict; it simply maps parameter tensor references to their respective gradient tensors.