# Source code for ray.rllib.core.rl_module.apis.self_supervised_loss_api

import abc
from typing import Any, Dict, TYPE_CHECKING

from ray.rllib.utils.typing import ModuleID, TensorType
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
    from ray.rllib.core.learner.learner import Learner


@PublicAPI(stability="alpha")
class SelfSupervisedLossAPI(abc.ABC):
    """An API to be implemented by RLModules that bring their own self-supervised loss.

    Learners will call these model's `compute_self_supervised_loss()` method
    instead of the Learner's own `compute_loss_for_module()` method.

    The call signature is identical to the Learner's `compute_loss_for_module()`
    method except for an additional mandatory `learner` kwarg.
    """

    @abc.abstractmethod
    def compute_self_supervised_loss(
        self,
        *,
        learner: "Learner",
        module_id: ModuleID,
        config: "AlgorithmConfig",
        batch: Dict[str, Any],
        fwd_out: Dict[str, TensorType],
        **kwargs,
    ) -> TensorType:
        """Computes the loss for a single module.

        Think of this as computing loss for a single agent. For multi-agent
        use-cases that require more complicated computation for loss, consider
        overriding the `compute_losses` method instead.

        Args:
            learner: The Learner calling this loss method on the RLModule.
            module_id: The ID of the RLModule (within a MultiRLModule).
            config: The AlgorithmConfig specific to the given `module_id`.
            batch: The sample batch for this particular RLModule.
            fwd_out: The output of the forward pass for this particular
                RLModule.

        Returns:
            A single total loss tensor. If you have more than one optimizer on
            the provided `module_id` and would like to compute gradients
            separately using these different optimizers, simply add up the
            individual loss terms for each optimizer and return the sum. Also,
            for recording/logging any individual loss terms, you can use the
            `Learner.metrics.log_value(key=..., value=...)` or
            `Learner.metrics.log_dict()` APIs. See:
            :py:class:`~ray.rllib.utils.metrics.metrics_logger.MetricsLogger`
            for more information.
        """