ray.rllib.core.learner.learner.Learner._update
- abstract Learner._update(batch: Dict[str, Any], **kwargs) -> Tuple[Any, Any, Any]
Contains all logic for an in-graph/traceable update step.
Framework-specific subclasses must implement this method. The implementation should call the RLModule's forward_train() method, then this Learner's compute_loss(), compute_gradients(), postprocess_gradients(), and apply_gradients() methods, and return a tuple with all the individual results.
- Parameters:
batch – The train batch already converted to a Dict mapping str to (possibly nested) tensors.
kwargs – Forward compatibility kwargs.
- Returns:
A tuple consisting of:
1) the forward_train() output of the RLModule,
2) the loss_per_module dictionary mapping module IDs to individual loss tensors, and
3) a metrics dict mapping module IDs to metrics key/value pairs.
- Return type:
Tuple[Any, Any, Any]
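The order of calls an _update() implementation is expected to make can be illustrated with a small, framework-free sketch. Everything in it is hypothetical: ToyModule and ToyLearner are not RLlib classes, their method signatures are simplified stand-ins for the methods named in this docstring, the batch is assumed to be keyed by module ID, and gradients are computed analytically because plain Python has no autodiff (a real framework-specific subclass would derive them from the loss tensors instead). The sketch only shows the sequence forward pass, loss, gradients, postprocessing, apply, and the shape of the returned three-element tuple.

```python
from typing import Any, Dict, List, Tuple


# Hypothetical stand-in for an RLModule: a one-parameter linear model y = w * x.
class ToyModule:
    def __init__(self, w: float = 0.0) -> None:
        self.w = w

    def forward_train(self, batch: Dict[str, List[float]]) -> Dict[str, List[float]]:
        # Predict one value per observation in this module's sub-batch.
        return {"preds": [self.w * x for x in batch["obs"]]}


# Hypothetical, framework-free learner; NOT RLlib's Learner class or its signatures.
class ToyLearner:
    def __init__(self, modules: Dict[str, ToyModule], lr: float = 0.1,
                 max_grad: float = 1.0) -> None:
        self.modules = modules
        self.lr = lr
        self.max_grad = max_grad

    def compute_loss(self, fwd_out, batch) -> Dict[str, float]:
        # Mean squared error between predictions and targets, per module ID.
        losses = {}
        for mid, out in fwd_out.items():
            targets = batch[mid]["targets"]
            errors = [(p - t) ** 2 for p, t in zip(out["preds"], targets)]
            losses[mid] = sum(errors) / len(errors)
        return losses

    def compute_gradients(self, fwd_out, batch) -> Dict[str, float]:
        # Analytic dMSE/dw = mean(2 * (w*x - t) * x); a real subclass would
        # use its framework's autodiff on the loss tensors instead.
        grads = {}
        for mid, out in fwd_out.items():
            sub = batch[mid]
            terms = [2 * (p - t) * x
                     for p, t, x in zip(out["preds"], sub["targets"], sub["obs"])]
            grads[mid] = sum(terms) / len(terms)
        return grads

    def postprocess_gradients(self, grads: Dict[str, float]) -> Dict[str, float]:
        # Example postprocessing: clip each gradient to [-max_grad, max_grad].
        return {mid: max(-self.max_grad, min(self.max_grad, g))
                for mid, g in grads.items()}

    def apply_gradients(self, grads: Dict[str, float]) -> None:
        # Plain SGD step on each module's single parameter.
        for mid, g in grads.items():
            self.modules[mid].w -= self.lr * g

    def _update(self, batch: Dict[str, Any], **kwargs) -> Tuple[Any, Any, Any]:
        # 1) Forward pass for every module ID in the batch.
        fwd_out = {mid: self.modules[mid].forward_train(batch[mid]) for mid in batch}
        # 2) Losses, 3) gradients, 4) postprocessing, 5) parameter update.
        loss_per_module = self.compute_loss(fwd_out, batch)
        grads = self.postprocess_gradients(self.compute_gradients(fwd_out, batch))
        self.apply_gradients(grads)
        # Metrics dict mapping module IDs to key/value pairs.
        metrics = {mid: {"loss": loss_per_module[mid], "clipped_grad": grads[mid]}
                   for mid in batch}
        return fwd_out, loss_per_module, metrics


# Usage: two hypothetical module IDs, each with its own sub-batch.
batch = {
    "policy_a": {"obs": [1.0, 2.0], "targets": [2.0, 4.0]},
    "policy_b": {"obs": [1.0], "targets": [-1.0]},
}
learner = ToyLearner({"policy_a": ToyModule(), "policy_b": ToyModule()})
fwd_out, loss_per_module, metrics = learner._update(batch)
print(loss_per_module)  # {'policy_a': 10.0, 'policy_b': 1.0}
```

In this sketch the reported losses are computed from the parameters as they were before the update, and the three returned values line up with the forward_train() output, loss_per_module, and metrics entries described in the Returns section.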