ray.rllib.core.learner.learner.Learner.postprocess_gradients_for_module

Learner.postprocess_gradients_for_module(*, module_id: str, hps: LearnerHyperparameters, module_gradients_dict: Dict[Hashable, torch.Tensor | tf.Variable]) -> Dict[Hashable, torch.Tensor | tf.Variable]

Applies postprocessing operations on the gradients of the given module.

Parameters:
  • module_id – The ID of the module whose computed gradients are to be postprocessed. Note that module_gradients_dict carries only the gradient tensors that belong to this module_id. The gradients of other modules are not available in this call.

  • hps – The LearnerHyperparameters specific to the given module_id.

  • module_gradients_dict – A dictionary of gradients in the same (flat) format as self._params, mapping gradient refs to the gradient tensors that are to be postprocessed. You may alter these tensors in place, or create new ones and return them in a new dict.

Returns:

A dictionary with the updated gradients and the exact same (flat) structure as the incoming module_gradients_dict arg.
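
Example:

The sketch below illustrates the contract described above (a flat dict in, a dict with the same keys out) using clipping by global L2 norm as the postprocessing step. It is not RLlib's implementation. Plain Python lists stand in for the torch.Tensor / tf.Variable values so that the snippet runs without a deep learning framework, and both the helper name clip_by_global_norm and the max_norm argument are hypothetical. An override of postprocess_gradients_for_module could apply the same idea to real tensors, taking its clipping threshold from whatever setting the module's hps actually provides.

```python
import math
from typing import Dict, Hashable, List

# Assumption: plain lists of floats stand in for gradient tensors.
Gradient = List[float]


def clip_by_global_norm(
    module_gradients_dict: Dict[Hashable, Gradient],
    max_norm: float,
) -> Dict[Hashable, Gradient]:
    """Scale all gradients so that their joint L2 norm is at most max_norm.

    Hypothetical helper, not part of RLlib.
    """
    # Joint L2 norm over every entry of every gradient in the flat dict.
    global_norm = math.sqrt(
        sum(g * g for grad in module_gradients_dict.values() for g in grad)
    )
    # Leave the gradients unscaled if they are already within the limit.
    scale = 1.0 if global_norm <= max_norm else max_norm / global_norm
    # Return a new dict with exactly the same keys (same flat structure).
    return {
        ref: [g * scale for g in grad]
        for ref, grad in module_gradients_dict.items()
    }


# Usage: two gradient refs with a joint norm of sqrt(3^2 + 4^2) = 5.
grads = {"w": [3.0, 0.0], "b": [4.0]}
clipped = clip_by_global_norm(grads, max_norm=1.0)
```

This version returns a new dict rather than mutating its input, one of the two options the description of module_gradients_dict allows. Every incoming key is preserved, which the returned structure requires.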