Learner.compile_results(*, batch: MultiAgentBatch, fwd_out: Dict[str, Any], loss_per_module: Dict[str, numpy.array | jnp.ndarray | tf.Tensor | torch.Tensor], metrics_per_module: DefaultDict[str, Dict[str, Any]]) → Dict[str, Any]

Compile results from the update in a numpy-friendly format.

Parameters:

  • batch – The batch that was used for the update.

  • fwd_out – The output of the forward train pass.

  • loss_per_module – A dict mapping module IDs (including ALL_MODULES) to the individual loss tensors as returned by calls to compute_loss_for_module(module_id=...).

  • metrics_per_module – The collected metrics defaultdict mapping ModuleIDs to metrics dicts. These metrics are collected during loss- and gradient computation, gradient postprocessing, and gradient application.


Returns:

A dictionary of results sub-dicts, one per module (including ALL_MODULES).
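The following is a minimal, framework-agnostic sketch of the kind of conversion described above: per-module loss tensors and metrics are turned into plain Python/NumPy values and grouped into one sub-dict per module ID. It is not RLlib's implementation. The helper names `to_numpy_friendly` and `compile_results_sketch`, the `"total_loss"` result key, and the string value used for `ALL_MODULES` are all hypothetical, chosen only for illustration; `batch` and `fwd_out` are omitted because the sketch does not use them.

```python
from collections import defaultdict
from typing import Any, DefaultDict, Dict

# Hypothetical stand-in for RLlib's ALL_MODULES key (assumed value).
ALL_MODULES = "__all_modules__"


def to_numpy_friendly(value: Any) -> Any:
    """Convert a framework tensor into a plain Python/NumPy value where possible."""
    # torch.Tensor: detach from the autograd graph and move to CPU first.
    if hasattr(value, "detach"):
        value = value.detach().cpu()
    # torch.Tensor and eager tf.Tensor both expose .numpy().
    if hasattr(value, "numpy"):
        value = value.numpy()
    # 0-d arrays (NumPy or JAX) and NumPy scalars: unwrap into a Python scalar.
    if hasattr(value, "item") and getattr(value, "ndim", 0) == 0:
        value = value.item()
    return value


def compile_results_sketch(
    loss_per_module: Dict[str, Any],
    metrics_per_module: DefaultDict[str, Dict[str, Any]],
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical illustration: build one results sub-dict per module ID."""
    results: Dict[str, Dict[str, Any]] = {}
    # Cover every module that has a loss or metrics (including ALL_MODULES).
    module_ids = list(loss_per_module)
    module_ids += [m for m in metrics_per_module if m not in loss_per_module]
    for module_id in module_ids:
        # Use .get() so that reading does not insert new keys into the defaultdict.
        module_metrics = metrics_per_module.get(module_id, {})
        module_results = {k: to_numpy_friendly(v) for k, v in module_metrics.items()}
        if module_id in loss_per_module:
            module_results["total_loss"] = to_numpy_friendly(loss_per_module[module_id])
        results[module_id] = module_results
    return results


# Usage with plain floats standing in for framework tensors.
metrics = defaultdict(dict)
metrics["p0"]["grad_norm"] = 0.25
losses = {"p0": 1.5, ALL_MODULES: 1.5}
print(compile_results_sketch(losses, metrics))
```

Reading `metrics_per_module` through `.get()` rather than indexing is a deliberate choice in this sketch: indexing a `defaultdict` with a missing key would silently add an empty entry to the caller's metrics.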