Learner.postprocess_gradients(gradients_dict: Dict[Hashable, torch.Tensor | tf.Variable]) → Dict[Hashable, torch.Tensor | tf.Variable]

Applies optional postprocessing operations to the gradients.

This method is called after the gradients have been computed and modifies them before the optimizer(s) apply them to the respective module(s). Typical steps include gradient clipping by value, by norm, or by global norm, as well as other algorithm-specific gradient postprocessing.
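The three clipping variants differ in what gets bounded: individual elements, each gradient's own L2 norm, or the joint L2 norm of all gradients together. The following is a minimal pure-Python sketch, assuming each gradient is a flat list of floats rather than a torch.Tensor or tf.Variable; the helper names are hypothetical and not part of RLlib's API. Each helper returns a dict with the same keys as its input, matching the same-structure contract of this method.

```python
import math
from typing import Dict, Hashable, List

# Hypothetical stand-in for a gradient tensor: a flat list of floats.
Grad = List[float]


def clip_by_value(grads: Dict[Hashable, Grad], clip: float) -> Dict[Hashable, Grad]:
    """Clamp every gradient element into [-clip, clip]."""
    return {k: [max(-clip, min(clip, x)) for x in g] for k, g in grads.items()}


def clip_by_norm(grads: Dict[Hashable, Grad], max_norm: float) -> Dict[Hashable, Grad]:
    """Rescale each gradient separately so its own L2 norm is at most max_norm."""
    out = {}
    for k, g in grads.items():
        norm = math.sqrt(sum(x * x for x in g))
        scale = max_norm / norm if norm > max_norm else 1.0
        out[k] = [x * scale for x in g]
    return out


def clip_by_global_norm(grads: Dict[Hashable, Grad], max_norm: float) -> Dict[Hashable, Grad]:
    """Rescale all gradients by one shared factor so their joint L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(x * x for g in grads.values() for x in g))
    scale = max_norm / global_norm if global_norm > max_norm else 1.0
    return {k: [x * scale for x in g] for k, g in grads.items()}


# Usage: the joint norm here is sqrt(3^2 + 4^2 + 12^2) = 13.
grads = {"w": [3.0, 4.0], "b": [0.0, 12.0]}
print(clip_by_value(grads, 5.0))  # {'w': [3.0, 4.0], 'b': [0.0, 5.0]}
```

Global-norm clipping preserves the relative direction of the full update across all parameters, whereas per-gradient norm clipping can change it. With real tensors, TensorFlow's tf.clip_by_value, tf.clip_by_norm, and tf.clip_by_global_norm implement the same three operations; PyTorch's torch.nn.utils.clip_grad_value_ and clip_grad_norm_ act in place on the parameters' .grad attributes instead of on a dict.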

This default implementation calls self.postprocess_gradients_for_module() on each of the sub-modules of the MultiAgentRLModule (self.module) and returns the per-module results merged into a single gradients dict.
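The per-module dispatch can be sketched in plain Python. This is a minimal illustration, not RLlib's implementation: the real method takes only gradients_dict, and the hypothetical `param_to_module` mapping (which module owns each gradient reference) and `postprocess_for_module` callable stand in for what the Learner derives from self.module and for self.postprocess_gradients_for_module(). Gradients are again assumed to be lists of floats.

```python
from typing import Callable, Dict, Hashable, List

Grad = List[float]
GradDict = Dict[Hashable, Grad]


def postprocess_gradients(
    gradients_dict: GradDict,
    param_to_module: Dict[Hashable, str],  # hypothetical: ref -> owning module ID
    postprocess_for_module: Callable[[str, GradDict], GradDict],  # hypothetical hook
) -> GradDict:
    """Split a flat gradient dict by module, postprocess each slice, then merge."""
    # Group gradient references by the module that owns them.
    per_module: Dict[str, GradDict] = {}
    for ref, grad in gradients_dict.items():
        per_module.setdefault(param_to_module[ref], {})[ref] = grad

    # Postprocess each module's slice and accumulate into one flat dict.
    result: GradDict = {}
    for module_id, module_grads in per_module.items():
        result.update(postprocess_for_module(module_id, module_grads))
    return result


# Usage: two modules, each with its own (hypothetical) scaling rule.
grads = {"a/w": [2.0], "b/w": [2.0]}
owners = {"a/w": "policy_a", "b/w": "policy_b"}
factors = {"policy_a": 0.5, "policy_b": 3.0}
out = postprocess_gradients(
    grads,
    owners,
    lambda mid, g: {k: [x * factors[mid] for x in v] for k, v in g.items()},
)
print(out)  # {'a/w': [1.0], 'b/w': [6.0]}
```

Because each module's output is merged back with dict.update, the result keeps the flat, module-ID-free structure of the input, provided each per-module step returns exactly the keys it received.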


gradients_dict – A dictionary of gradients in the same (flat) format as self._params. Note that top-level structures, such as module IDs, will not be present anymore in this dict. It will merely map gradient tensor references to gradient tensors.


A dictionary with the updated gradients and the exact same (flat) structure as the incoming gradients_dict arg.