# Source code for ray.rllib.connectors.connector_pipeline_v2

import contextlib
import logging
from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union

import gymnasium as gym

from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.metrics import TIMERS, CONNECTOR_PIPELINE_TIMER, CONNECTOR_TIMERS
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.metrics.utils import to_snake_case
from ray.rllib.utils.typing import EpisodeType, StateDict
from ray.util.annotations import PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class ConnectorPipelineV2(ConnectorV2):
    """Utility class for quick manipulation of a connector pipeline.

    The pipeline keeps an ordered list of connector pieces in `self.connectors`.
    Calling the pipeline calls each piece in order, feeding the batch returned
    by one piece into the next. Pieces can be added, removed, and looked up by
    class or class name at any time; after each structural change the
    input spaces of all pieces are re-chained via `_fix_spaces()`.
    """

    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        """Re-chains all pieces' input spaces and returns the output obs space.

        Args:
            input_observation_space: The observation space fed into the first
                connector piece.
            input_action_space: The action space fed into the first connector
                piece.

        Returns:
            The pipeline's `observation_space` property after re-chaining.
        """
        self._fix_spaces(input_observation_space, input_action_space)
        return self.observation_space

    @override(ConnectorV2)
    def recompute_output_action_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        """Re-chains all pieces' input spaces and returns the output action space.

        Args:
            input_observation_space: The observation space fed into the first
                connector piece.
            input_action_space: The action space fed into the first connector
                piece.

        Returns:
            The pipeline's `action_space` property after re-chaining.
        """
        self._fix_spaces(input_observation_space, input_action_space)
        return self.action_space

    def __init__(
        self,
        input_observation_space: Optional[gym.Space] = None,
        input_action_space: Optional[gym.Space] = None,
        *,
        connectors: Optional[List[ConnectorV2]] = None,
        **kwargs,
    ):
        """Initializes a ConnectorPipelineV2 instance.

        Args:
            input_observation_space: The (optional) input observation space for
                this connector piece. This is the space coming from a previous
                connector piece in the (env-to-module or learner) pipeline or is
                directly defined within the gym.Env.
            input_action_space: The (optional) input action space for this
                connector piece. This is the space coming from a previous
                connector piece in the (module-to-env) pipeline or is directly
                defined within the gym.Env.
            connectors: A list of individual ConnectorV2 pieces to be added to
                this pipeline during construction. Entries may also be
                `(class, args, kwargs)` 3-tuples, in which case the piece is
                constructed here (used internally when restoring from a
                checkpoint). `None` results in an empty pipeline. Note that you
                can always add (or remove) more ConnectorV2 pieces later on the
                fly.
        """
        self.connectors = []
        # `connectors` defaults to None -> Treat as "no pieces".
        for conn in connectors or []:
            # If we have a `ConnectorV2` instance just append.
            if isinstance(conn, ConnectorV2):
                self.connectors.append(conn)
            # If we have a class with `args` and `kwargs`, build the instance.
            # Note that this way of constructing a pipeline should only be
            # used internally when restoring the pipeline state from a
            # checkpoint.
            elif isinstance(conn, tuple) and len(conn) == 3:
                conn_cls, ctor_args, ctor_kwargs = conn
                self.connectors.append(conn_cls(*ctor_args, **ctor_kwargs))
            # Anything else is not added to the pipeline; make that visible.
            else:
                logger.warning(
                    f"Ignoring connector spec {conn!r}: expected a ConnectorV2 "
                    f"instance or a (class, args, kwargs) tuple."
                )

        # NOTE(review): `self.connectors` is populated before the base c'tor
        # runs; presumably the base c'tor touches the spaces (and thereby
        # `_fix_spaces()`) - confirm before reordering these statements.
        super().__init__(input_observation_space, input_action_space, **kwargs)

    def __len__(self):
        """Returns the number of connector pieces in this pipeline."""
        return len(self.connectors)

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        metrics: Optional[MetricsLogger] = None,
        **kwargs,
    ) -> Any:
        """In a pipeline, we simply call each of our connector pieces after each other.

        Each connector piece receives as input the output of the previous connector
        piece in the pipeline.

        If `metrics` is provided, the entire pipeline call as well as each
        individual piece's call are timed through `metrics.log_time()`. An
        optional `metrics_prefix_key` tuple in `kwargs` is prepended to the
        timer keys.

        Raises:
            ValueError: If a connector piece returns something other than a dict.
        """
        shared_data = shared_data if shared_data is not None else {}
        prefix_key = kwargs.get("metrics_prefix_key", ())

        # Use `with` blocks for all timers, such that they are properly exited
        # even if one of the connector pieces raises.
        pipeline_timer = (
            metrics.log_time(prefix_key + (CONNECTOR_PIPELINE_TIMER,))
            if metrics
            else contextlib.nullcontext()
        )
        with pipeline_timer:
            # Loop through connector pieces and call each one with the output of
            # the previous one. Thereby, time each connector piece's call.
            for connector in self.connectors:
                # TODO (sven): Add MetricsLogger to non-Learner components that
                #  have a LearnerConnector pipeline.
                if metrics:
                    piece_timer = metrics.log_time(
                        prefix_key
                        + (
                            TIMERS,
                            CONNECTOR_TIMERS,
                            to_snake_case(connector.__class__.__name__),
                        )
                    )
                else:
                    piece_timer = contextlib.nullcontext()

                with piece_timer:
                    batch = connector(
                        rl_module=rl_module,
                        batch=batch,
                        episodes=episodes,
                        explore=explore,
                        shared_data=shared_data,
                        metrics=metrics,
                        # Deprecated arg.
                        data=batch,
                        **kwargs,
                    )

                if not isinstance(batch, dict):
                    raise ValueError(
                        f"`data` returned by ConnectorV2 {connector} must be a dict! "
                        f"You returned {batch}. Check your (custom) connectors' "
                        f"`__call__()` method's return value and make sure you return "
                        f"the `batch` arg passed in (either altered or unchanged)."
                    )

        return batch

    def _find_connector_index(self, name_or_class: Union[str, Type]) -> int:
        """Returns the index of the first piece matching `name_or_class`.

        A piece matches if `name_or_class` is a string equal to the piece's
        class name, or a type that is exactly the piece's class.

        Args:
            name_or_class: Class name or class of the connector piece to find.

        Returns:
            The index of the first matching piece, or -1 if there is none.
        """
        for i, c in enumerate(self.connectors):
            if isinstance(name_or_class, str):
                if c.__class__.__name__ == name_or_class:
                    return i
            elif isinstance(name_or_class, type):
                if c.__class__ is name_or_class:
                    return i
        return -1

    def remove(self, name_or_class: Union[str, Type]):
        """Remove a single connector piece in this pipeline by its name or class.

        Only the first matching piece is removed. If no piece matches, a
        warning is logged and the pipeline is left unchanged.

        Args:
            name_or_class: The name of the connector piece to be removed from the
                pipeline.
        """
        idx = self._find_connector_index(name_or_class)
        if idx >= 0:
            del self.connectors[idx]
            self._fix_spaces(self.input_observation_space, self.input_action_space)
            logger.info(
                f"Removed connector {name_or_class} from {self.__class__.__name__}."
            )
        else:
            logger.warning(
                f"Trying to remove a non-existent connector {name_or_class}."
            )

    def insert_before(
        self,
        name_or_class: Union[str, type],
        connector: ConnectorV2,
    ) -> ConnectorV2:
        """Insert a new connector piece before an existing piece (by name or class).

        Args:
            name_or_class: Name or class of the connector piece before which
                `connector` will get inserted.
            connector: The new connector piece to be inserted.

        Returns:
            The ConnectorV2 before which `connector` has been inserted.

        Raises:
            ValueError: If no piece in this pipeline matches `name_or_class`.
        """
        # Use a dedicated lookup (not the loop variable) so that "not found"
        # is reliably detected, also for non-empty pipelines.
        idx = self._find_connector_index(name_or_class)
        if idx < 0:
            raise ValueError(
                f"Can not find connector with name or type '{name_or_class}'!"
            )
        next_connector = self.connectors[idx]

        self.connectors.insert(idx, connector)
        self._fix_spaces(self.input_observation_space, self.input_action_space)

        logger.info(
            f"Inserted {connector.__class__.__name__} before {name_or_class} "
            f"to {self.__class__.__name__}."
        )
        return next_connector

    def insert_after(
        self,
        name_or_class: Union[str, Type],
        connector: ConnectorV2,
    ) -> ConnectorV2:
        """Insert a new connector piece after an existing piece (by name or class).

        Args:
            name_or_class: Name or class of the connector piece after which
                `connector` will get inserted.
            connector: The new connector piece to be inserted.

        Returns:
            The ConnectorV2 after which `connector` has been inserted.

        Raises:
            ValueError: If no piece in this pipeline matches `name_or_class`.
        """
        # Use a dedicated lookup (not the loop variable) so that "not found"
        # is reliably detected, also for non-empty pipelines.
        idx = self._find_connector_index(name_or_class)
        if idx < 0:
            raise ValueError(
                f"Can not find connector with name or type '{name_or_class}'!"
            )
        prev_connector = self.connectors[idx]

        self.connectors.insert(idx + 1, connector)
        self._fix_spaces(self.input_observation_space, self.input_action_space)

        logger.info(
            f"Inserted {connector.__class__.__name__} after {name_or_class} "
            f"to {self.__class__.__name__}."
        )
        return prev_connector

    def prepend(self, connector: ConnectorV2) -> None:
        """Prepend a new connector at the beginning of a connector pipeline.

        Args:
            connector: The new connector piece to be prepended to this pipeline.
        """
        self.connectors.insert(0, connector)
        self._fix_spaces(self.input_observation_space, self.input_action_space)

        logger.info(
            f"Added {connector.__class__.__name__} to the beginning of "
            f"{self.__class__.__name__}."
        )

    def append(self, connector: ConnectorV2) -> None:
        """Append a new connector at the end of a connector pipeline.

        Args:
            connector: The new connector piece to be appended to this pipeline.
        """
        self.connectors.append(connector)
        self._fix_spaces(self.input_observation_space, self.input_action_space)

        logger.info(
            f"Added {connector.__class__.__name__} to the end of "
            f"{self.__class__.__name__}."
        )

    @override(ConnectorV2)
    def get_state(
        self,
        components: Optional[Union[str, Collection[str]]] = None,
        *,
        not_components: Optional[Union[str, Collection[str]]] = None,
        **kwargs,
    ) -> StateDict:
        """Returns the states of all (selected) pieces, keyed by class name.

        Pieces whose `get_state()` returns an empty/falsy value are left out of
        the returned dict.

        Args:
            components: Optional component name(s) to include.
            not_components: Optional component name(s) to exclude.

        Returns:
            A dict mapping each piece's class name to that piece's state.
        """
        state = {}
        for conn in self.connectors:
            conn_name = type(conn).__name__
            if self._check_component(conn_name, components, not_components):
                sts = conn.get_state(
                    components=self._get_subcomponents(conn_name, components),
                    not_components=self._get_subcomponents(conn_name, not_components),
                    **kwargs,
                )
                # Ignore empty dicts.
                if sts:
                    state[conn_name] = sts
        return state

    @override(ConnectorV2)
    def set_state(self, state: Dict[str, Any]) -> None:
        """Sets the state of each piece whose class name is a key in `state`.

        Args:
            state: A dict mapping connector class names to connector states.
        """
        for conn in self.connectors:
            conn_name = type(conn).__name__
            if conn_name in state:
                conn.set_state(state[conn_name])

    @override(Checkpointable)
    def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
        """Returns `(class name, connector piece)` tuples for all pieces."""
        return [(type(conn).__name__, conn) for conn in self.connectors]

    # The `connectors` c'tor kwarg is returned as a list of
    # `(class, args, kwargs)` tuples, one per connector piece. This is the
    # 3-tuple format that `__init__` accepts to re-construct the pieces.
    @override(Checkpointable)
    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
        """Returns the c'tor args and kwargs needed to re-create this pipeline.

        Returns:
            A 2-tuple of positional args (input observation- and action space)
            and kwargs (`connectors` as a list of `(class, args, kwargs)`
            tuples).
        """
        return (
            (self.input_observation_space, self.input_action_space),  # *args
            {
                "connectors": [
                    (type(conn), *conn.get_ctor_args_and_kwargs())
                    for conn in self.connectors
                ]
            },
        )

    @override(ConnectorV2)
    def reset_state(self) -> None:
        """Calls `reset_state()` on every connector piece in this pipeline."""
        for conn in self.connectors:
            conn.reset_state()

    @override(ConnectorV2)
    def merge_states(self, states: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Merges several pipeline states into one, piece by piece.

        The keys of the first state dict define which pieces get merged. For
        each key, the piece with that class name merges the list of per-state
        values through its own `merge_states()`.

        Args:
            states: A list of pipeline state dicts (as returned by
                `get_state()`).

        Returns:
            A dict mapping connector class names to the merged states. Empty if
            `states` is empty.

        Raises:
            ValueError: If a key in the states does not match the class name of
                any connector piece in this pipeline.
        """
        merged_states = {}
        if not states:
            return merged_states

        # Look up pieces by class name (the key that `get_state()` writes)
        # rather than by position: `get_state()` leaves out pieces with empty
        # state, so the key order does not have to line up with the indices in
        # `self.connectors`. If several pieces share a class name, the last
        # one is used, matching `get_state()`, where the last one also wins.
        conns_by_name = {type(conn).__name__: conn for conn in self.connectors}

        for key in states[0]:
            conn = conns_by_name.get(key)
            if conn is None:
                raise ValueError(
                    f"Can not merge states for '{key}': No connector with that "
                    f"class name found in {self.__class__.__name__}!"
                )
            state_list = [state[key] for state in states]
            merged_states[key] = conn.merge_states(state_list)
        return merged_states

    def __repr__(self, indentation: int = 0):
        """Returns this pipeline's class name followed by one line per piece.

        Args:
            indentation: Number of spaces to indent the pipeline's own line by.
                Pieces are rendered with 4 more spaces of indentation.
        """
        return "\n".join(
            [" " * indentation + self.__class__.__name__]
            + [c.__str__(indentation + 4) for c in self.connectors]
        )

    def __getitem__(
        self,
        key: Union[str, int, Type],
    ) -> Union[ConnectorV2, List[ConnectorV2]]:
        """Returns a single ConnectorV2 or list of ConnectorV2s that fit `key`.

        If key is an int, we return a single ConnectorV2 at that index in this
        pipeline. If key is a type, we return a list of all ConnectorV2s in this
        pipeline that are instances of that class or of a subclass of it. If key
        is a string, we return a list of all ConnectorV2s in this pipeline whose
        `name` attribute equals that string.

        Args:
            key: The key to find or to index by.

        Returns:
            A single ConnectorV2 or a list of ConnectorV2s matching `key`.

        Raises:
            NotImplementedError: If `key` is a slice or of any other
                unsupported type.
        """
        # Key is an int -> Index into pipeline and return.
        if isinstance(key, int):
            return self.connectors[key]
        # Key is a class.
        elif isinstance(key, type):
            return [c for c in self.connectors if issubclass(c.__class__, key)]
        # Key is a string -> Find connector(s) by name.
        elif isinstance(key, str):
            return [c for c in self.connectors if c.name == key]
        # Slicing not supported (yet).
        elif isinstance(key, slice):
            raise NotImplementedError(
                "Slicing of ConnectorPipelineV2 is currently not supported!"
            )
        else:
            raise NotImplementedError(
                f"Indexing ConnectorPipelineV2 by {type(key)} is currently not "
                f"supported!"
            )

    @property
    def observation_space(self):
        """The last piece's observation space, or the own one if the pipeline is empty."""
        if len(self) > 0:
            return self.connectors[-1].observation_space
        return self._observation_space

    @property
    def action_space(self):
        """The last piece's action space, or the own one if the pipeline is empty."""
        if len(self) > 0:
            return self.connectors[-1].action_space
        return self._action_space

    def _fix_spaces(self, input_observation_space, input_action_space):
        """Chains the given input spaces through all connector pieces.

        Each piece gets the previous piece's output spaces assigned as its
        input spaces; the first piece gets the given input spaces. Does nothing
        for an empty pipeline.

        Args:
            input_observation_space: The observation space fed into the first
                connector piece.
            input_action_space: The action space fed into the first connector
                piece.
        """
        # Fix each connector's input_observation- and input_action space in
        # the pipeline.
        obs_space = input_observation_space
        act_space = input_action_space
        for con in self.connectors:
            con.input_action_space = act_space
            con.input_observation_space = obs_space
            obs_space = con.observation_space
            act_space = con.action_space