ray.rllib.core.learner.learner.Learner.compute_losses

Learner.compute_losses(*, fwd_out: Dict[str, Any], batch: Dict[str, Any]) -> Dict[str, Any]

Computes the loss(es) for the module being optimized.

MultiRLModule-specific Learners must override this method to define their specific loss computation logic. If the algorithm is single-agent, override only compute_loss_for_module() instead. The same applies if the algorithm uses independent multi-agent learning (the default behavior for RLlib’s multi-agent setups): override only compute_loss_for_module(), which is then called once for each individual RLModule inside the MultiRLModule.

It is recommended not to compute any forward passes within this method. Instead, use the forward_train() outputs of the RLModule(s) to compute the required loss tensors. For a custom loss function example script, see ray-project/ray.
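The per-module dispatch described above can be sketched in plain Python. This is a minimal sketch, not RLlib's source: SketchLearner, MSELearner, and the "predictions"/"targets" keys are hypothetical, losses are plain floats instead of framework tensors, and RLlib's real compute_loss_for_module() may take additional arguments (check the API reference before relying on its signature).

```python
from typing import Any, Dict, List


class SketchLearner:
    """Hypothetical stand-in for a Learner (NOT RLlib's implementation).

    compute_losses() loops over module IDs and delegates to
    compute_loss_for_module() once per module, mimicking independent
    multi-agent learning.
    """

    def compute_losses(
        self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
    ) -> Dict[str, Any]:
        loss_per_module = {}
        for module_id in fwd_out:
            # Each module only sees its own slice of the outputs and batch.
            loss_per_module[module_id] = self.compute_loss_for_module(
                module_id=module_id,
                batch=batch[module_id],
                fwd_out=fwd_out[module_id],
            )
        return loss_per_module

    def compute_loss_for_module(self, *, module_id, batch, fwd_out):
        # Subclasses provide the actual per-module loss.
        raise NotImplementedError


class MSELearner(SketchLearner):
    """Hypothetical per-module loss: mean squared error between the
    module's forward_train() predictions and the batch targets."""

    def compute_loss_for_module(self, *, module_id, batch, fwd_out):
        preds: List[float] = fwd_out["predictions"]
        targets: List[float] = batch["targets"]
        return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


learner = MSELearner()
losses = learner.compute_losses(
    fwd_out={"agent_0": {"predictions": [1.0, 2.0]}, "agent_1": {"predictions": [0.0]}},
    batch={"agent_0": {"targets": [1.0, 4.0]}, "agent_1": {"targets": [3.0]}},
)
print(losses)  # {'agent_0': 2.0, 'agent_1': 9.0}
```

Only compute_loss_for_module() is overridden here; the loop over module IDs stays in the base class, which is why single-agent and independent multi-agent setups never need to touch compute_losses().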

Parameters:
  • fwd_out – Output from a call to the forward_train() method of the underlying MultiRLModule (self.module) during training (self.update()).

  • batch – The training batch that was used to compute fwd_out.

Returns:

A dictionary mapping module IDs to individual loss terms.
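When the loss couples several modules, a MultiRLModule-specific Learner overrides compute_losses() itself. The following is a hypothetical sketch under stated assumptions: CoupledLossLearner, the mse() helper, the cross-module penalty, and the "predictions"/"targets" keys are all illustrative, not RLlib APIs, and losses are plain floats. It computes every term from fwd_out and batch only (no extra forward passes) and returns a dictionary mapping module IDs to loss terms.

```python
from typing import Any, Dict, List


def mse(preds: List[float], targets: List[float]) -> float:
    """Mean squared error over two equal-length lists of floats."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


class CoupledLossLearner:
    """Hypothetical Learner-like class (not RLlib code) that overrides
    compute_losses() directly because its loss spans several modules."""

    def __init__(self, coupling_coeff: float = 0.5):
        # Weight of the hypothetical cross-module penalty.
        self.coupling_coeff = coupling_coeff

    def compute_losses(
        self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any]
    ) -> Dict[str, Any]:
        # Per-module terms, computed only from the given forward outputs.
        losses = {
            mid: mse(fwd_out[mid]["predictions"], batch[mid]["targets"])
            for mid in fwd_out
        }
        # Hypothetical coupling term: penalize disagreement between the
        # modules' mean predictions, then split it equally across modules.
        means = [
            sum(out["predictions"]) / len(out["predictions"])
            for out in fwd_out.values()
        ]
        spread = max(means) - min(means)
        penalty = self.coupling_coeff * spread ** 2 / len(losses)
        return {mid: loss + penalty for mid, loss in losses.items()}


learner = CoupledLossLearner(coupling_coeff=0.5)
losses = learner.compute_losses(
    fwd_out={"agent_0": {"predictions": [1.0, 2.0]}, "agent_1": {"predictions": [0.0]}},
    batch={"agent_0": {"targets": [1.0, 4.0]}, "agent_1": {"targets": [3.0]}},
)
print(losses)  # {'agent_0': 2.5625, 'agent_1': 9.5625}
```

Splitting the shared penalty equally across module IDs is one arbitrary choice; the only requirement shown here is that the return value maps each module ID to its loss term.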