ray.rllib.core.learner.learner.Learner._update

abstract Learner._update(batch: Dict[str, Any], **kwargs) → Tuple[Any, Any, Any]

Contains all logic for an in-graph/traceable update step.

Framework-specific subclasses must implement this method. The implementation should call the RLModule’s forward_train, then compute_loss, compute_gradients, postprocess_gradients, and apply_gradients, and return a tuple with the individual results.
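The call order described above can be illustrated with a minimal, framework-free sketch. It is not RLlib code: ToyModule, ToyLearner, their method signatures, the scalar "weights", and the clipping threshold are all hypothetical stand-ins chosen so the example runs with plain Python. Only the sequence of steps and the shape of the returned tuple mirror the description.

```python
from typing import Any, Dict, List, Tuple


class ToyModule:
    """Hypothetical stand-in for an RLModule: one scalar weight per module ID."""

    def __init__(self, weights: Dict[str, float]):
        self.weights = dict(weights)

    def forward_train(self, batch: Dict[str, Any]) -> Dict[str, List[float]]:
        # Toy forward pass: prediction = weight * observation.
        return {
            mid: [self.weights[mid] * x for x in data["obs"]]
            for mid, data in batch.items()
        }


class ToyLearner:
    """Hypothetical learner; signatures are illustrative, not RLlib's."""

    def __init__(self, module: ToyModule, lr: float = 0.1, clip: float = 1.0):
        self.module, self.lr, self.clip = module, lr, clip

    def compute_loss(self, fwd_out, batch) -> Dict[str, float]:
        # Mean squared error per module ID.
        return {
            mid: sum((p - t) ** 2 for p, t in zip(preds, batch[mid]["targets"]))
            / len(preds)
            for mid, preds in fwd_out.items()
        }

    def compute_gradients(self, fwd_out, batch) -> Dict[str, float]:
        # Analytic MSE gradient w.r.t. the weight: mean(2 * (p - t) * x).
        grads = {}
        for mid, preds in fwd_out.items():
            obs, targets = batch[mid]["obs"], batch[mid]["targets"]
            grads[mid] = sum(
                2 * (p - t) * x for p, t, x in zip(preds, targets, obs)
            ) / len(preds)
        return grads

    def postprocess_gradients(self, grads: Dict[str, float]) -> Dict[str, float]:
        # Clip each gradient to [-clip, clip].
        return {m: max(-self.clip, min(self.clip, g)) for m, g in grads.items()}

    def apply_gradients(self, grads: Dict[str, float]) -> None:
        # Plain gradient-descent step on each module's weight.
        for mid, g in grads.items():
            self.module.weights[mid] -= self.lr * g

    def _update(self, batch: Dict[str, Any], **kwargs) -> Tuple[Any, Any, Any]:
        fwd_out = self.module.forward_train(batch)
        loss_per_module = self.compute_loss(fwd_out, batch)
        grads = self.compute_gradients(fwd_out, batch)
        grads = self.postprocess_gradients(grads)
        self.apply_gradients(grads)
        metrics = {mid: {"loss": loss} for mid, loss in loss_per_module.items()}
        return fwd_out, loss_per_module, metrics


# Usage with a single hypothetical module ID.
learner = ToyLearner(ToyModule({"policy_0": 0.0}))
batch = {"policy_0": {"obs": [1.0, 2.0], "targets": [2.0, 4.0]}}
fwd_out, loss_per_module, metrics = learner._update(batch)
```

In a real subclass, the batch values would be framework tensors and the whole method body would be traceable by the framework, which is why the step is kept free of Python-side bookkeeping beyond assembling the returned tuple.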

Parameters:
  • batch – The train batch already converted to a Dict mapping str to (possibly nested) tensors.

  • kwargs – Forward compatibility kwargs.

Returns:

A tuple consisting of:

  1. the forward_train() output of the RLModule,

  2. the loss_per_module dictionary mapping module IDs to individual loss tensors, and

  3. a metrics dict mapping module IDs to metrics key/value pairs.

Return type:

Tuple[Any, Any, Any]