abstract Learner._update(batch: Dict[str, Any], **kwargs) → Tuple[Any, Any, Any]

Contains all logic for an in-graph/traceable update step.

Framework-specific subclasses must implement this method. It should call the RLModule's forward_train method and the Learner's compute_loss, compute_gradients, postprocess_gradients, and apply_gradients methods, and return a tuple with the individual results.
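The call order described above can be illustrated with a minimal, framework-free sketch. Everything in it is hypothetical and is not RLlib's implementation: ToyModule and ToyLearner are invented stand-ins, each module has a single scalar weight, the loss is a mean-squared error, and "postprocessing" is simple gradient clipping. A real framework subclass would operate on tensors and obtain gradients through autodiff rather than the analytic formula used here.

```python
from typing import Any, Dict, List, Tuple

ModuleID = str  # hypothetical alias; RLlib keys per-module results by module ID


class ToyModule:
    """Hypothetical one-parameter module that predicts w * obs."""

    def __init__(self, w: float) -> None:
        self.w = w

    def forward_train(self, batch: Dict[str, List[float]]) -> Dict[str, List[float]]:
        return {"pred": [self.w * x for x in batch["obs"]]}


class ToyLearner:
    """Hypothetical learner showing the order of calls inside _update."""

    def __init__(self, modules: Dict[ModuleID, ToyModule], lr: float = 0.1, clip: float = 1.0) -> None:
        self.modules = modules
        self.lr = lr
        self.clip = clip

    def compute_loss(self, fwd_out: Dict[ModuleID, Any], batch: Dict[ModuleID, Any]) -> Dict[ModuleID, float]:
        # Mean-squared error between predictions and targets, per module.
        losses = {}
        for mid, out in fwd_out.items():
            preds, targets = out["pred"], batch[mid]["target"]
            losses[mid] = sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
        return losses

    def compute_gradients(self, batch: Dict[ModuleID, Any]) -> Dict[ModuleID, float]:
        # Analytic d(MSE)/dw = mean(2 * (w * x - t) * x); a real framework would
        # differentiate the loss tensors with autodiff instead.
        grads = {}
        for mid, module in self.modules.items():
            obs, targets = batch[mid]["obs"], batch[mid]["target"]
            grads[mid] = sum(2 * (module.w * x - t) * x for x, t in zip(obs, targets)) / len(obs)
        return grads

    def postprocess_gradients(self, grads: Dict[ModuleID, float]) -> Dict[ModuleID, float]:
        # Clip each gradient to the range [-clip, clip].
        return {mid: max(-self.clip, min(self.clip, g)) for mid, g in grads.items()}

    def apply_gradients(self, grads: Dict[ModuleID, float]) -> None:
        # Plain SGD step on each module's weight.
        for mid, g in grads.items():
            self.modules[mid].w -= self.lr * g

    def _update(self, batch: Dict[ModuleID, Any], **kwargs: Any) -> Tuple[Any, Any, Any]:
        fwd_out = {mid: m.forward_train(batch[mid]) for mid, m in self.modules.items()}
        loss_per_module = self.compute_loss(fwd_out, batch)
        grads = self.postprocess_gradients(self.compute_gradients(batch))
        self.apply_gradients(grads)
        metrics = {mid: {"loss": loss_per_module[mid], "grad": grads[mid]} for mid in self.modules}
        return fwd_out, loss_per_module, metrics


learner = ToyLearner({"a": ToyModule(0.0)})
fwd_out, loss_per_module, metrics = learner._update({"a": {"obs": [1.0, 2.0], "target": [2.0, 4.0]}})
print(loss_per_module, metrics)
```

With w = 0, the loss for module "a" is (4 + 16) / 2 = 10.0 and the raw gradient is -10.0, which clipping reduces to -1.0 before the SGD step moves w to 0.1.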

Parameters:

  • batch – The train batch, already converted to a Dict mapping str to (possibly nested) tensors.

  • kwargs – Forward compatibility kwargs.

Returns:

A tuple consisting of

  1. the forward_train() output of the RLModule,

  2. the loss_per_module dictionary mapping module IDs to individual loss tensors, and

  3. a metrics dict mapping module IDs to metrics key/value pairs.

Return type:

Tuple[Any, Any, Any]