abstract Learner.compute_loss_for_module(*, module_id: str, hps: LearnerHyperparameters, batch: NestedDict, fwd_out: Mapping[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor]) → numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor

Computes the loss for a single module.

Think of this as computing the loss for a single agent. For multi-agent use cases that need a more involved loss computation, consider overriding the compute_loss method instead.

Parameters:

  • module_id – The id of the module.

  • hps – The LearnerHyperparameters specific to the given module_id.

  • batch – The sample batch for this particular module.

  • fwd_out – The output of the forward pass for this particular module.


Returns:

A single total loss tensor. If the module identified by module_id has more than one optimizer and you want each optimizer to compute its gradients separately, add up the individual loss terms (one per optimizer) and return the sum. To track the individual loss terms, use the Learner.register_metric(s) APIs.
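
The summing rule described above can be sketched as follows. This is a minimal illustration only, not the RLlib implementation: SketchLearner and SketchHps are hypothetical stand-ins, the register_metric signature shown here is an assumption, the dictionary keys ("logp", "advantages", "vf_preds", "value_targets") are made up for the example, and plain Python floats stand in for framework tensors so the sketch runs without any dependencies.

```python
from dataclasses import dataclass
from typing import Dict, Mapping, Sequence


@dataclass
class SketchHps:
    """Hypothetical stand-in for LearnerHyperparameters."""

    vf_loss_coeff: float = 1.0


def _mean(values: Sequence[float]) -> float:
    """Arithmetic mean of a non-empty sequence of floats."""
    return sum(values) / len(values)


class SketchLearner:
    """Stand-in that mimics the method's shape; NOT the RLlib Learner class."""

    def __init__(self) -> None:
        # module_id -> {metric name -> value}; a toy metrics store.
        self.metrics: Dict[str, Dict[str, float]] = {}

    def register_metric(self, module_id: str, key: str, value: float) -> None:
        # Assumed signature, for illustration only.
        self.metrics.setdefault(module_id, {})[key] = value

    def compute_loss_for_module(
        self,
        *,
        module_id: str,
        hps: SketchHps,
        batch: Mapping[str, Sequence[float]],
        fwd_out: Mapping[str, Sequence[float]],
    ) -> float:
        # First loss term (e.g. the one a policy optimizer would care about).
        policy_loss = -_mean(
            [lp * adv for lp, adv in zip(fwd_out["logp"], batch["advantages"])]
        )
        # Second loss term (e.g. the one a value-function optimizer would care about).
        vf_loss = _mean(
            [(v - t) ** 2 for v, t in zip(fwd_out["vf_preds"], batch["value_targets"])]
        )
        # Record the individual terms so they can be inspected separately.
        self.register_metric(module_id, "policy_loss", policy_loss)
        self.register_metric(module_id, "vf_loss", vf_loss)
        # Return a single total loss: the (weighted) sum of the terms.
        return policy_loss + hps.vf_loss_coeff * vf_loss


learner = SketchLearner()
loss = learner.compute_loss_for_module(
    module_id="default_policy",
    hps=SketchHps(vf_loss_coeff=1.0),
    batch={"advantages": [1.0, -1.0], "value_targets": [0.0, 2.0]},
    fwd_out={"logp": [-0.5, -1.0], "vf_preds": [1.0, 2.0]},
)
```

In a real Learner subclass the terms would be framework tensors (for example torch.Tensor) rather than floats, but the structure stays the same: compute each term, record it as a metric, and return one summed loss.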