Learner.postprocess_gradients_for_module(*, module_id: str, config: AlgorithmConfig | None = None, module_gradients_dict: Dict[Hashable, torch.Tensor | tf.Variable]) → Dict[Hashable, torch.Tensor | tf.Variable]

Applies postprocessing operations on the gradients of the given module.

Parameters:

  • module_id – The ID of the module whose computed gradients are postprocessed. Note that module_gradients_dict already carries only the gradient tensors that belong to this module_id. Gradients of other module IDs are not available in this call.

  • config – The AlgorithmConfig specific to the given module_id.

  • module_gradients_dict – A dictionary of the gradients to postprocess, in the same (flat) format as self._params, mapping gradient refs to gradient tensors. You may alter these tensors in place, or create new ones and return them in a new dict.
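The contract described by these parameters can be illustrated with a small, framework-free sketch. It is not RLlib code: `SketchConfig` (with its `grad_clip` field) is a hypothetical stand-in for AlgorithmConfig, gradients are modeled as flat lists of floats instead of torch.Tensor or tf.Variable, and global-norm clipping is chosen only as an example of a postprocessing step. Because only this module's gradients are passed in, the norm is computed over this module alone, and the function returns a new dict rather than mutating the input.

```python
import math
from dataclasses import dataclass
from typing import Dict, Hashable, List, Optional


# Hypothetical stand-in for AlgorithmConfig: only the one field this sketch reads.
@dataclass
class SketchConfig:
    grad_clip: Optional[float] = None  # max global L2 norm; None disables clipping


# Gradients modeled as flat lists of floats (stand-in for torch.Tensor / tf.Variable).
GradDict = Dict[Hashable, List[float]]


def postprocess_gradients_for_module(
    *,
    module_id: str,
    config: Optional[SketchConfig] = None,
    module_gradients_dict: GradDict,
) -> GradDict:
    """Clips one module's gradients by their global L2 norm (illustrative sketch).

    `module_id` is unused here and kept only to mirror the documented signature.
    """
    if config is None or config.grad_clip is None:
        # Nothing to do: return a shallow copy with the same flat structure.
        return dict(module_gradients_dict)

    # Global norm over this module's gradients only (other modules are not passed in).
    global_norm = math.sqrt(
        sum(g * g for grad in module_gradients_dict.values() for g in grad)
    )
    scale = 1.0
    if global_norm > config.grad_clip:
        scale = config.grad_clip / global_norm

    # Build new lists under the same refs instead of mutating the caller's data.
    return {ref: [g * scale for g in grad] for ref, grad in module_gradients_dict.items()}


grads = {"w": [3.0, 0.0], "b": [4.0]}  # global norm = 5.0
clipped = postprocess_gradients_for_module(
    module_id="my_module",  # hypothetical module ID
    config=SketchConfig(grad_clip=2.5),
    module_gradients_dict=grads,
)
print(clipped)  # {'w': [1.5, 0.0], 'b': [2.0]}
```

Returning a fresh dict keeps the caller's gradients untouched; modifying the tensors in place would also satisfy the documented contract.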


Returns: A dictionary with the updated gradients, having exactly the same (flat) structure as the incoming module_gradients_dict argument.
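Because the returned dict must mirror the incoming one, a custom override can be checked with a small helper. The sketch below is hypothetical (`check_same_flat_structure` is not part of RLlib) and again models gradients as flat lists of floats, so comparing list lengths stands in for comparing tensor shapes.

```python
from typing import Dict, Hashable, List

# Gradients modeled as flat lists of floats (stand-in for torch.Tensor / tf.Variable).
GradDict = Dict[Hashable, List[float]]


def check_same_flat_structure(incoming: GradDict, returned: GradDict) -> None:
    """Raises ValueError if `returned` does not mirror `incoming`'s flat layout.

    Hypothetical helper: requires the same set of gradient refs, and that each
    entry keeps its size (a stand-in for comparing tensor shapes).
    """
    missing = set(incoming) - set(returned)
    extra = set(returned) - set(incoming)
    if missing or extra:
        raise ValueError(f"ref mismatch: missing={missing}, extra={extra}")
    for ref, grad in incoming.items():
        if len(returned[ref]) != len(grad):
            raise ValueError(
                f"size mismatch for {ref!r}: {len(grad)} -> {len(returned[ref])}"
            )


incoming = {"w": [3.0, 0.0], "b": [4.0]}
check_same_flat_structure(incoming, {"w": [1.5, 0.0], "b": [2.0]})  # passes silently
try:
    check_same_flat_structure(incoming, {"w": [1.5, 0.0]})  # "b" was dropped
except ValueError as err:
    print(err)  # ref mismatch: missing={'b'}, extra=set()
```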