Source code for ray.rllib.core.rl_module.apis.q_net_api

import abc
from typing import Dict

from ray.rllib.utils.typing import TensorType
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class QNetAPI(abc.ABC):
    """An API to be implemented by RLModules used for (distributional) Q-learning.

    RLModules implementing this API must override the `compute_q_values` and the
    `compute_advantage_distribution` methods.
    """

    @abc.abstractmethod
    def compute_q_values(
        self,
        batch: Dict[str, TensorType],
    ) -> Dict[str, TensorType]:
        """Computes Q-values, given an encoder, a Q-net, and (optionally) an advantage net.

        Note that these can be accompanied by logits and probabilities in the
        case of distributional Q-learning, i.e. `self.num_atoms > 1`.

        Args:
            batch: The batch received in the forward pass.

        Returns:
            A dictionary containing the Q-value predictions ("qf_preds"). In the
            case of distributional Q-learning, it additionally contains the
            support atoms ("atoms"), the Q-logits ("qf_logits"), and the
            probabilities ("qf_probs").
        """

    def compute_advantage_distribution(
        self,
        batch: Dict[str, TensorType],
    ) -> Dict[str, TensorType]:
        """Computes the advantage distribution.

        Note that this distribution is identical to the Q-distribution if no
        dueling architecture is used.

        Args:
            batch: A dictionary containing a tensor with the outputs of the
                forward pass of the Q-head or advantage stream head.

        Returns:
            A `dict` containing the support atoms of the discrete distribution
            over either Q-values or advantages (in case of a dueling
            architecture) ("atoms"), the logits per action and atom, and the
            probabilities of the discrete distribution (per action and atom
            of the support).
        """
        # Return the Q-distribution by default.
        return self.compute_q_values(batch)
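

# The following is a minimal, illustrative sketch (not part of the original
# module) of how a concrete class might satisfy this API for plain,
# non-distributional Q-learning. The class name `SimpleDQNModule`, the
# `q_net` attribute, the layer sizes, and the "obs" batch key are assumptions
# made for this example only; a real implementation would typically also
# subclass `RLModule` (e.g. a torch-based RLModule) and build its networks in
# the module's setup code.
import torch
from torch import nn


class SimpleDQNModule(QNetAPI):
    def __init__(self, obs_dim: int, num_actions: int):
        # A simple MLP Q-head mapping observations to one Q-value per action.
        self.q_net = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_actions),
        )

    def compute_q_values(
        self,
        batch: Dict[str, TensorType],
    ) -> Dict[str, TensorType]:
        # Return the Q-value predictions under the "qf_preds" key, as described
        # in the docstring above (non-distributional case, so no "atoms",
        # "qf_logits", or "qf_probs" keys are produced here).
        return {"qf_preds": self.q_net(batch["obs"])}


# Example usage: without a dueling architecture or a distributional head,
# `compute_advantage_distribution` falls back to `compute_q_values` by default.
#   module = SimpleDQNModule(obs_dim=4, num_actions=2)
#   out = module.compute_advantage_distribution({"obs": torch.randn(1, 4)})
#   # out["qf_preds"] has shape (1, 2): one Q-value per action.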