Source code for ray.rllib.connectors.env_to_module.observation_preprocessor

import abc
from typing import Any, Dict, List, Optional

import gymnasium as gym

from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import EpisodeType
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class SingleAgentObservationPreprocessor(ConnectorV2, abc.ABC):
    """Env-to-module connector preprocessing the most recent single-agent observation.

    This is a convenience class that simplifies the writing of few-step preprocessor
    connectors.

    Note that this class also works in a multi-agent setup, in which case RLlib calls
    this connector piece separately with each agent's observation and
    `SingleAgentEpisode` object.

    Users must implement the `preprocess()` method, which reduces the usual procedure
    of extracting some data from a list of episodes and adding it to the batch to a
    mere "old-observation --transform--> return new-observation" step.
    """
    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        # Users should override this method only in case the
        # `SingleAgentObservationPreprocessor` changes the observation space of the
        # pipeline. In this case, return the new observation space based on the
        # incoming one (`input_observation_space`).
        return super().recompute_output_observation_space(
            input_observation_space, input_action_space
        )
    @abc.abstractmethod
    def preprocess(self, observation, episode: SingleAgentEpisode):
        """Override to implement the preprocessing logic.

        Args:
            observation: A single (non-batched) observation item for a single agent
                to be preprocessed by this connector.
            episode: The `SingleAgentEpisode` instance from which `observation` was
                taken. You can extract information on the particular AgentID and
                ModuleID through `episode.agent_id` and `episode.module_id`.

        Returns:
            The new observation for the agent after `observation` has been
            preprocessed.
        """
    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        persistent_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        # We process, then replace, the observations inside the episodes directly.
        # Thus, all following connectors only see and operate on the already
        # processed observations (without access to the original observations
        # anymore).
        for sa_episode in self.single_agent_episode_iterator(episodes):
            observation = sa_episode.get_observations(-1)
            # Process the observation and write the new observation back into the
            # episode.
            new_observation = self.preprocess(
                observation=observation,
                episode=sa_episode,
            )
            sa_episode.set_observations(at_indices=-1, new_data=new_observation)
            # Set the episode's observation space to ours so that we can safely
            # write the new value for the last observation (without causing a space
            # mismatch error).
            sa_episode.observation_space = self.observation_space

        # Leave `batch` as is. RLlib's default connector automatically populates
        # the OBS column therein from the episodes' now transformed observations.
        return batch
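# Example (illustrative sketch, not part of this module): a minimal
# `SingleAgentObservationPreprocessor` subclass that clips each incoming,
# non-batched observation into the range [-1.0, 1.0]. The class name
# `ClipObservations` is hypothetical and assumes a `gym.spaces.Box`
# observation space; it only demonstrates where `preprocess()` and
# `recompute_output_observation_space()` plug in.
class ClipObservations(SingleAgentObservationPreprocessor):
    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        # The clipped values fit into a [-1.0, 1.0] Box of the same shape.
        return gym.spaces.Box(
            low=-1.0,
            high=1.0,
            shape=input_observation_space.shape,
            dtype=input_observation_space.dtype,
        )

    def preprocess(self, observation, episode: SingleAgentEpisode):
        # Element-wise clip of the single (non-batched) observation.
        return observation.clip(-1.0, 1.0)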
@PublicAPI(stability="alpha")
class MultiAgentObservationPreprocessor(ConnectorV2, abc.ABC):
    """Env-to-module connector preprocessing the most recent multi-agent observation.

    The observation is always a dict mapping agent IDs to individual agents'
    observations.

    This is a convenience class that simplifies the writing of few-step preprocessor
    connectors.

    Users must implement the `preprocess()` method, which reduces the usual procedure
    of extracting some data from a list of episodes and adding it to the batch to a
    mere "old-observation --transform--> return new-observation" step.
    """
    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        # Users should override this method only in case the
        # `MultiAgentObservationPreprocessor` changes the observation space of the
        # pipeline. In this case, return the new observation space based on the
        # incoming one (`input_observation_space`).
        return super().recompute_output_observation_space(
            input_observation_space, input_action_space
        )
    @abc.abstractmethod
    def preprocess(self, observations, episode: MultiAgentEpisode):
        """Override to implement the preprocessing logic.

        Args:
            observations: An observation dict containing each stepping agent's
                (non-batched) observation to be preprocessed by this connector.
            episode: The `MultiAgentEpisode` instance from which the `observations`
                dict originated.

        Returns:
            The new multi-agent observation dict after `observations` has been
            preprocessed.
        """
    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        persistent_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        # We process, then replace, the observations inside the episodes directly.
        # Thus, all following connectors only see and operate on the already
        # processed observations (without access to the original observations
        # anymore).
        for ma_episode in episodes:
            observations = ma_episode.get_observations(-1)
            # Process the observations and write the new observations back into the
            # episode.
            new_observation = self.preprocess(
                observations=observations,
                episode=ma_episode,
            )
            # TODO (sven): Implement a `set_observations` API for multi-agent
            #  episodes. For now, we hack it through the single-agent APIs.
            # ma_episode.set_observations(at_indices=-1, new_data=new_observation)
            for agent_id, obs in new_observation.items():
                ma_episode.agent_episodes[agent_id].set_observations(
                    at_indices=-1,
                    new_data=obs,
                )
            # Set the episode's observation space to ours so that we can safely
            # write the new value for the last observation (without causing a space
            # mismatch error).
            ma_episode.observation_space = self.observation_space

        # Leave `batch` as is. RLlib's default connector automatically populates
        # the OBS column therein from the episodes' now transformed observations.
        return batch
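# Example (illustrative sketch, not part of this module): a minimal
# `MultiAgentObservationPreprocessor` subclass that rescales every stepping
# agent's observation (assumed to be a uint8 image array) into [0.0, 1.0].
# The class name `RescaleAllAgentsObservations` is hypothetical; it only
# demonstrates the dict-in/dict-out contract of `preprocess()`. If your
# transform changes the observation space (range, dtype, or shape), also
# override `recompute_output_observation_space()` accordingly.
class RescaleAllAgentsObservations(MultiAgentObservationPreprocessor):
    def preprocess(self, observations, episode: MultiAgentEpisode):
        # `observations` maps agent IDs to individual (non-batched) observations.
        # Return a dict with the same keys and the transformed values.
        return {
            agent_id: obs.astype("float32") / 255.0
            for agent_id, obs in observations.items()
        }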
# Backward compatibility.
ObservationPreprocessor = SingleAgentObservationPreprocessor
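# Usage sketch (assumption: the exact hook signature can differ across RLlib
# versions): env-to-module connector pieces like the ones above are typically
# plugged into an `AlgorithmConfig` through the `env_to_module_connector`
# argument of `config.env_runners()`, for example:
#
#   config.env_runners(
#       env_to_module_connector=lambda *args, **kwargs: ClipObservations(),
#   )
#
# `ClipObservations` refers to the illustrative subclass sketched above; the
# wildcard lambda signature is used only because the arguments RLlib passes to
# this callable (such as the env) vary between versions.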