ray.rllib.core.learner.learner.Learner.postprocess_gradients_for_module
- Learner.postprocess_gradients_for_module(*, module_id: str, config: AlgorithmConfig | None = None, module_gradients_dict: Dict[Hashable, torch.Tensor | tf.Variable]) → Dict[Hashable, torch.Tensor | tf.Variable]
Applies postprocessing operations on the gradients of the given module.
- Parameters:
module_id – The module ID for which to postprocess computed gradients. Note that module_gradients_dict already carries only those gradient tensors that belong to this module_id. Gradients of other module_ids are not available in this call.
config – The AlgorithmConfig specific to the given module_id.
module_gradients_dict – A dictionary of gradients in the same (flat) format as self._params, mapping gradient refs to gradient tensors, which are to be postprocessed. You may alter these tensors in place or create new ones and return them in a new dict.
- Returns:
A dictionary with the updated gradients and the exact same (flat) structure as the incoming module_gradients_dict arg.
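
The contract described above (a flat dict of one module's gradients in, a dict with the same keys out) can be illustrated with a small standalone sketch. This is not RLlib's implementation: it is a free function rather than a Learner method, NumPy arrays stand in for torch.Tensor / tf.Variable so that it runs without Ray installed, `config` is a plain dict instead of an AlgorithmConfig, and the `grad_clip` key and the `"default_policy"` module ID are names chosen only for this example. In a custom Learner subclass, comparable logic would presumably live in an override of this method.

```python
import math
from typing import Dict, Hashable, Optional

import numpy as np


def postprocess_gradients_for_module(
    *,
    module_id: str,
    config: Optional[dict] = None,
    module_gradients_dict: Dict[Hashable, Optional[np.ndarray]],
) -> Dict[Hashable, Optional[np.ndarray]]:
    """Clip one module's gradients by their global L2 norm.

    Assumptions of this sketch: `config` is a plain dict and `grad_clip`
    is a hypothetical key holding the maximum allowed global norm.
    """
    max_norm = (config or {}).get("grad_clip")
    if max_norm is None:
        # No clipping requested: return the same refs and gradients.
        return dict(module_gradients_dict)

    # Global norm over all gradients of this module (None entries skipped).
    total_sq = sum(
        float(np.sum(np.square(g)))
        for g in module_gradients_dict.values()
        if g is not None
    )
    global_norm = math.sqrt(total_sq)
    # Scale down only if the norm exceeds max_norm; never scale up.
    scale = min(1.0, max_norm / (global_norm + 1e-6))

    # New dict with exactly the same keys (flat structure preserved).
    return {
        ref: (None if g is None else g * scale)
        for ref, g in module_gradients_dict.items()
    }


# Usage: the gradients below have a global norm of 5 and are clipped to ~1.
grads = {"w": np.array([3.0, 0.0]), "b": np.array([4.0])}
clipped = postprocess_gradients_for_module(
    module_id="default_policy",
    config={"grad_clip": 1.0},
    module_gradients_dict=grads,
)
```

The sketch returns a new dict instead of mutating the inputs; per the parameter description, altering the tensors in place and returning them would be equally valid.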