ray.rllib.core.learner.learner.Learner.compute_loss
- Learner.compute_loss(*, fwd_out: Dict[str, Any], batch: Dict[str, Any]) -> Dict[str, Any] [source]
Computes the loss for the module being optimized.
This method must be overridden by multi-agent-specific algorithm Learners to define their loss computation logic. If the algorithm is single-agent, compute_loss_for_module() should be overridden instead. fwd_out is the output of the forward_train() method of the underlying MultiRLModule. batch is the data that was used to compute fwd_out. The returned dictionary must contain a key called ALL_MODULES, which will be used to compute gradients. It is recommended not to compute any forward passes within this method and to use the forward_train() outputs of the RLModule(s) to compute the required tensors for loss calculations.
- Parameters:
fwd_out – Output from a call to the forward_train() method of self.module during training (self.update()).
batch – The training batch that was used to compute fwd_out.
- Returns:
A dictionary mapping module IDs to individual loss terms. The dictionary must contain the protected key ALL_MODULES, whose value is used to compute gradients.
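For illustration, below is a minimal sketch of how a multi-agent Learner subclass might override compute_loss(): it delegates each module's loss to compute_loss_for_module() and aggregates the results under the ALL_MODULES key. The import location of ALL_MODULES, the keyword signature of compute_loss_for_module(), and the get_config_for_module() helper are assumptions that may differ across Ray versions.

```python
# Hypothetical sketch of a multi-agent compute_loss() override. Assumes
# ALL_MODULES is importable from ray.rllib.core and that
# compute_loss_for_module() / get_config_for_module() take the arguments
# shown; verify against your installed Ray version.
from typing import Any, Dict

from ray.rllib.core import ALL_MODULES
from ray.rllib.core.learner.learner import Learner


class MyMultiAgentLearner(Learner):
    def compute_loss(
        self,
        *,
        fwd_out: Dict[str, Any],
        batch: Dict[str, Any],
    ) -> Dict[str, Any]:
        loss_per_module: Dict[str, Any] = {}
        total_loss = None
        # Reuse the forward_train() outputs already present in `fwd_out`;
        # no additional forward passes are run here, as the docstring
        # recommends.
        for module_id, module_fwd_out in fwd_out.items():
            loss = self.compute_loss_for_module(
                module_id=module_id,
                config=self.config.get_config_for_module(module_id),
                batch=batch[module_id],
                fwd_out=module_fwd_out,
            )
            loss_per_module[module_id] = loss
            total_loss = loss if total_loss is None else total_loss + loss
        # The protected ALL_MODULES key carries the term that gradients
        # are computed from.
        loss_per_module[ALL_MODULES] = total_loss
        return loss_per_module
```

In this sketch, summing the per-module losses into ALL_MODULES mirrors the documented contract that gradients are computed through that single protected key; an algorithm could instead weight or otherwise combine the per-module terms before storing the aggregate.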