ray.rllib.core.learner.learner.Learner.compute_losses#
- Learner.compute_losses(*, fwd_out: Dict[str, Any], batch: Dict[str, Any]) → Dict[str, Any][source]#
Computes the loss(es) for the module being optimized.
This method must be overridden by MultiRLModule-specific Learners in order to define the specific loss computation logic. If the algorithm is single-agent, only compute_loss_for_module() should be overridden instead. If the algorithm uses independent multi-agent learning (the default behavior for RLlib's multi-agent setups), also only compute_loss_for_module() should be overridden; it is then called once for each individual RLModule inside the MultiRLModule. It is recommended not to compute any forward passes within this method and to use the forward_train() outputs of the RLModule(s) to compute the required loss tensors. See here for a custom loss function example script: ray-project/ray. A minimal override sketch is also shown after the Returns section below.
- Parameters:
fwd_out – Output from a call to the forward_train() method of the underlying MultiRLModule (self.module) during training (self.update()).
batch – The train batch that was used to compute fwd_out.
- Returns:
A dictionary mapping module IDs to individual loss terms.
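The sketch below shows one way to override compute_losses() in a custom torch-based Learner. It assumes the new RLlib API stack (a TorchLearner subclass, RLModules whose forward_train() returns action-distribution inputs, and fwd_out/batch dicts keyed by module ID). The column names (Columns.ACTIONS, Columns.ACTION_DIST_INPUTS) and the behavior-cloning loss are illustrative assumptions, not part of this API:

```python
# Minimal sketch, assuming the new RLlib API stack with a torch-based Learner.
from typing import Any, Dict

import torch

from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.torch.torch_learner import TorchLearner


class BehaviorCloningLearner(TorchLearner):
    """Illustrative Learner computing one loss term per RLModule."""

    def compute_losses(
        self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
    ) -> Dict[str, Any]:
        losses = {}
        # Both `fwd_out` and `batch` are keyed by module ID in the
        # multi-agent case; iterate over all modules that were trained.
        for module_id, module_fwd_out in fwd_out.items():
            module_batch = batch[module_id]
            # Re-use the forward_train() outputs; no extra forward pass here.
            action_dist_cls = self.module[module_id].get_train_action_dist_cls()
            action_dist = action_dist_cls.from_logits(
                module_fwd_out[Columns.ACTION_DIST_INPUTS]
            )
            # Negative log-likelihood of the actions found in the batch
            # (a simple behavior-cloning loss, used here only as an example).
            losses[module_id] = -torch.mean(
                action_dist.logp(module_batch[Columns.ACTIONS])
            )
        return losses
```

Note that for independent multi-agent learning (the default), overriding compute_loss_for_module() is usually simpler, since RLlib then handles the per-module iteration shown above for you.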