Learner.compute_loss(*, fwd_out: Dict[str, Any], batch: Dict[str, Any]) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor | Dict[str, Any]

Computes the loss for the module being optimized.

Multi-agent algorithm learners must override this method to define their loss computation. Single-agent algorithms should override compute_loss_for_module() instead.

fwd_out is the output of the forward_train() method of the underlying MultiAgentRLModule, and batch is the data that was used to compute fwd_out. The returned dictionary must contain the key ALL_MODULES, whose value is the loss used to compute gradients.

Avoid running additional forward passes inside this method. Instead, compute the tensors required for the loss from the forward_train() outputs of the RLModule(s).

Parameters:

  • fwd_out – Output from a call to the forward_train() method of self.module during training (self.update()).

  • batch – The training batch that was used to compute fwd_out.


Returns:

A dictionary mapping module IDs to individual loss terms. The dictionary must also contain the protected key ALL_MODULES, whose value is the loss through which gradients are computed.
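
The following is a minimal, framework-free sketch of the return contract described above, not RLlib's actual implementation. It assumes that fwd_out and batch are both keyed by module ID, and it uses plain Python floats where a real learner would use framework tensors so that gradients can flow. The ALL_MODULES value defined here is a placeholder (real code should import the constant from RLlib), and the mean_squared_error helper and the "preds" and "targets" keys are hypothetical names chosen for illustration.

```python
from typing import Any, Dict, List

# Placeholder for RLlib's ALL_MODULES constant (assumption: the real string
# value may differ, so import the constant from RLlib in real code).
ALL_MODULES = "__all_modules_placeholder__"


def mean_squared_error(preds: List[float], targets: List[float]) -> float:
    """Hypothetical per-module loss: mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def compute_loss(*, fwd_out: Dict[str, Any], batch: Dict[str, Any]) -> Dict[str, Any]:
    """Return one loss per module ID, plus their sum under ALL_MODULES."""
    loss_per_module: Dict[str, Any] = {}
    total = 0.0
    for module_id, module_fwd_out in fwd_out.items():
        # Reuse the forward_train() outputs; no extra forward pass is run here.
        loss = mean_squared_error(
            module_fwd_out["preds"], batch[module_id]["targets"]
        )
        loss_per_module[module_id] = loss
        total += loss
    # The aggregated loss lives under the protected key.
    loss_per_module[ALL_MODULES] = total
    return loss_per_module


# Toy inputs for two modules (illustrative values only).
fwd_out = {"agent_a": {"preds": [1.0, 2.0]}, "agent_b": {"preds": [0.0, 0.0]}}
batch = {"agent_a": {"targets": [1.0, 4.0]}, "agent_b": {"targets": [1.0, 1.0]}}
losses = compute_loss(fwd_out=fwd_out, batch=batch)
```

Summing the per-module losses is only one possible way to aggregate them; an algorithm could weight or otherwise combine the terms, as long as the result is stored under ALL_MODULES.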