Source code for ray.rllib.core.rl_module.apis.value_function_api

import abc
from typing import Any, Dict, Optional

from ray.rllib.utils.typing import TensorType
from ray.util.annotations import PublicAPI


[docs] @PublicAPI(stability="alpha") class ValueFunctionAPI(abc.ABC): """An API to be implemented by RLModules for handling value function-based learning. RLModules implementing this API must override the `compute_values` method. """
[docs] @abc.abstractmethod def compute_values( self, batch: Dict[str, Any], embeddings: Optional[Any] = None, ) -> TensorType: """Computes the value estimates given `batch`. Args: batch: The batch to compute value function estimates for. embeddings: Optional embeddings already computed from the `batch` (by another forward pass through the model's encoder (or other subcomponent that computes an embedding). For example, the caller of thie method should provide `embeddings` - if available - to avoid duplicate passes through a shared encoder. Returns: A tensor of shape (B,) or (B, T) (in case the input `batch` has a time dimension. Note that the last value dimension should already be squeezed out (not 1!). """