Source code for ray.rllib.core.rl_module.apis.target_network_api
import abc
from typing import Any, Dict, List, Tuple
from ray.rllib.utils.typing import NetworkType
from ray.util.annotations import PublicAPI
# Module-level string constant.
# NOTE(review): judging by its name, presumably the dict key under which a
# target network's action-distribution inputs are stored - confirm against callers.
TARGET_NETWORK_ACTION_DIST_INPUTS = "target_network_action_dist_inputs"
[docs]
@PublicAPI(stability="alpha")
class TargetNetworkAPI(abc.ABC):
    """Interface for RLModules that maintain one or more target networks.

    An RLModule adopting this API has to provide its own versions of
    `make_target_networks`, `get_target_network_pairs`, and `forward_target`.

    Syncing the target nets is not the module's job: the Learner that owns the
    implementing RLModule takes care of all target syncing logic.
    """

    @abc.abstractmethod
    def make_target_networks(self) -> None:
        """Builds every target net this RLModule requires.

        Implementations should rely on the convenience utility
        `ray.rllib.core.learner.utils.make_target_network()`: hand it the already
        existing, corresponding "main" net for which a target net is needed.
        That utility already handles initialization (from the "main" net).
        """

    @abc.abstractmethod
    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        """Lists the (main_net, target_net) pairs of this RLModule.

        Example: for an RLModule with a property `self.q_net` whose corresponding
        target net is `self.target_q_net`, the overriding method returns
        `[(self.q_net, self.target_q_net)]`.

        All target nets must be created inside the overridden
        `make_target_networks` method; the property they are stored under is
        the implementer's choice.

        Returns:
            A list of 2-tuples, each of the form (main_net, target_net).
        """

    @abc.abstractmethod
    def forward_target(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Runs the forward pass through the target net(s).

        Args:
            batch: The batch to feed into the forward pass.

        Returns:
            The output of the forward pass(es) through the target net(s).
        """