# Source code for ray.rllib.offline.offline_prelearner

import gymnasium as gym
import logging
import numpy as np
import uuid

from typing import Any, Dict, List, Optional, Union, Set, Tuple, TYPE_CHECKING

from ray.actor import ActorHandle
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner import Learner
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
from ray.rllib.utils.annotations import (
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.compression import unpack_if_needed
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.utils.spaces.space_utils import from_jsonable_if_needed
from ray.rllib.utils.typing import EpisodeType, ModuleID
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

#: Default schema that is applied when no `input_read_schema` is set in
#: the config. The keys of a user-defined `input_read_schema` have to be
#: keys of `SCHEMA`, while its values name the corresponding columns in
#: the user's dataset. Only the entries defined by the user are
#: overridden; all other entries keep the defaults defined here.
SCHEMA = {
    # New stack columns: by default each column maps onto itself.
    **{
        column: column
        for column in (
            Columns.EPS_ID,
            Columns.AGENT_ID,
            Columns.MODULE_ID,
            Columns.OBS,
            Columns.ACTIONS,
            Columns.REWARDS,
            Columns.INFOS,
            Columns.NEXT_OBS,
            Columns.TERMINATEDS,
            Columns.TRUNCATEDS,
            Columns.T,
        )
    },
    # Old stack columns.
    # TODO (simon): Remove as soon as we are new stack only.
    **{column: column for column in ("agent_index", "dones", "unroll_id")},
}

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class OfflinePreLearner:
    """Class that coordinates data transformation from dataset to learner.

    This class is an essential part of the new `Offline RL API` of `RLlib`.
    It is a callable class that is run in `ray.data.Dataset.map_batches`
    when iterating over batches for training. Its basic function is to
    convert data in batch from rows to episodes (`SingleAgentEpisode`s
    for now) and to then run the learner connector pipeline to convert
    further to trainable batches. These batches are used directly in the
    `Learner`'s `update` method.

    The main reason to run these transformations inside of `map_batches`
    is for better performance. Batches can be pre-fetched in `ray.data`
    and therefore batch transformation can be run highly parallelized to
    the `Learner`'s `update`.

    This class can be overridden to implement custom logic for transforming
    batches and make them 'Learner'-ready. When deriving from this class
    the `__call__` method and `_map_to_episodes` can be overridden to induce
    custom logic for the complete transformation pipeline (`__call__`) or
    for converting to episodes only (`_map_to_episodes`). For an example
    how this class can be used to also compute values and advantages see
    `rllib.algorithm.marwil.marwil_prelearner.MARWILOfflinePreLearner`.

    Custom `OfflinePreLearner` classes can be passed into
    `AlgorithmConfig.offline`'s `prelearner_class`. The `OfflineData` class
    will then use the custom class in its data pipeline.
    """

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def __init__(
        self,
        *,
        config: "AlgorithmConfig",
        learner: Union[Learner, list[ActorHandle]],
        spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
        module_spec: Optional[MultiRLModuleSpec] = None,
        module_state: Optional[Dict[ModuleID, Any]] = None,
        **kwargs: Dict[str, Any],
    ):
        """Initializes an `OfflinePreLearner`.

        Args:
            config: The algorithm config. It provides the input settings, the
                learner connector pipeline and the buffer settings.
            learner: Either a local `Learner` instance (its module is reused)
                or a list of actor handles to remote `Learner`s.
            spaces: An optional `(observation_space, action_space)` tuple. If
                `None`, space samples are not converted from JSONable.
            module_spec: The spec to build the module from. Required if
                `learner` is not a local `Learner` instance.
            module_state: The state set on the module built from
                `module_spec`. Only used for remote learners.
            **kwargs: Additional keyword arguments; unused here.

        Raises:
            ValueError: If `learner` is not a local `Learner` and no
                `module_spec` is provided.
        """
        self.config = config
        self.input_read_episodes = self.config.input_read_episodes
        self.input_read_sample_batches = self.config.input_read_sample_batches

        # We need this learner to run the learner connector pipeline.
        # If it is a `Learner` instance, the `Learner` is local.
        if isinstance(learner, Learner):
            self._learner = learner
            self.learner_is_remote = False
            self._module = self._learner._module
        # Otherwise we have remote `Learner`s.
        else:
            self.learner_is_remote = True
            # Without a spec no module can be built. Fail with a clear message
            # instead of an `AttributeError` on `None`.
            if module_spec is None:
                raise ValueError(
                    "`module_spec` must be provided if `learner` is not a "
                    "local `Learner` instance."
                )
            # Build the module from spec. Note, this will be a MultiRLModule.
            self._module = module_spec.build()
            self._module.set_state(module_state)

        # Store the observation and action space if defined, otherwise we
        # set them to `None`. Note, if `None` the `convert_from_jsonable`
        # will not convert the input space samples.
        self.observation_space, self.action_space = spaces or (None, None)

        # Build the learner connector pipeline.
        self._learner_connector = self.config.build_learner_connector(
            input_observation_space=self.observation_space,
            input_action_space=self.action_space,
        )
        # Cache the policies to be trained to update weights only for these.
        self._policies_to_train = self.config.policies_to_train
        self._is_multi_agent = config.is_multi_agent
        # Set the counter to zero.
        self.iter_since_last_module_update = 0
        # self._future = None

        # Set up an episode buffer, if the module is stateful or we sample from
        # `SampleBatch` types.
        if (
            self.input_read_sample_batches
            or self._module.is_stateful()
            or self.input_read_episodes
        ):
            # Either the user defined a buffer class or we fall back to the
            # default.
            prelearner_buffer_class = (
                self.config.prelearner_buffer_class
                or self.default_prelearner_buffer_class
            )
            # User-defined arguments override the defaults.
            prelearner_buffer_kwargs = (
                self.default_prelearner_buffer_kwargs
                | self.config.prelearner_buffer_kwargs
            )
            # Initialize the buffer.
            self.episode_buffer = prelearner_buffer_class(
                **prelearner_buffer_kwargs,
            )

    @OverrideToImplementCustomLogic
    def __call__(
        self, batch: Dict[str, np.ndarray]
    ) -> Dict[str, List[MultiAgentBatch]]:
        """Prepares plain data batches for training with `Learner`'s.

        Args:
            batch: A dictionary of numpy arrays containing either column data
                with `self.config.input_read_schema`, `EpisodeType` data, or
                `BatchType` data.

        Returns:
            A dictionary with the single key `"batch"` that holds a list with
            one `MultiAgentBatch` that can be passed to `Learner.update`
            methods.
        """
        # If we directly read in episodes we just convert to list.
        if self.input_read_episodes:
            # Import `msgpack` for decoding.
            import msgpack
            import msgpack_numpy as mnp

            # Read the episodes and decode them.
            episodes = [
                SingleAgentEpisode.from_state(
                    msgpack.unpackb(state, object_hook=mnp.decode)
                )
                for state in batch["item"]
            ]
            # Validate, buffer and sample the episodes.
            episodes = self._sample_episodes_from_buffer(episodes)
        # Else, if we have old stack `SampleBatch`es.
        elif self.input_read_sample_batches:
            # Dispatch via `self` such that overrides in subclasses are used.
            episodes = self._map_sample_batch_to_episode(
                self._is_multi_agent,
                batch,
                finalize=True,
                schema=SCHEMA | self.config.input_read_schema,
                input_compress_columns=self.config.input_compress_columns,
            )["episodes"]
            # Validate, buffer and sample the episodes.
            episodes = self._sample_episodes_from_buffer(episodes)
        # Otherwise we map the batch to episodes.
        else:
            episodes = self._map_to_episodes(
                self._is_multi_agent,
                batch,
                schema=SCHEMA | self.config.input_read_schema,
                finalize=True,
                input_compress_columns=self.config.input_compress_columns,
                observation_space=self.observation_space,
                action_space=self.action_space,
            )["episodes"]

        # TODO (simon): Make synching work. Right now this becomes blocking or
        # never receives weights. Learners appear to be non accessible via
        # other actors. The intended logic is to increase
        # `self.iter_since_last_module_update` in each call and - for remote
        # learners only - to request the module state from the learner (in a
        # non-blocking way) whenever the counter exceeds
        # `self.config.prelearner_module_synch_period`. As soon as the state
        # is received it should be set via `self._module.set_state()`. Note,
        # in case of a local learner we have a reference to the module and
        # therefore an up-to-date module.

        # Run the `Learner`'s connector pipeline.
        batch = self._learner_connector(
            rl_module=self._module,
            batch={},
            episodes=episodes,
            shared_data={},
        )
        # Convert to `MultiAgentBatch`.
        batch = MultiAgentBatch(
            {
                module_id: SampleBatch(module_data)
                for module_id, module_data in batch.items()
            },
            # TODO (simon): This can be run once for the batch and the
            # metrics, but we run it twice: here and later in the learner.
            env_steps=sum(e.env_steps() for e in episodes),
        )
        # Remove all data from modules that should not be trained. We do
        # not want to pass around more data than necessary. Note, we iterate
        # over a copy of the keys because we delete during the iteration.
        for module_id in list(batch.policy_batches.keys()):
            if not self._should_module_be_updated(module_id, batch):
                del batch.policy_batches[module_id]

        # TODO (simon): Log steps trained for metrics (how?). At best in learner
        # and not here. But we could precompute metrics here and pass it to the
        # learner for logging. Like this we do not have to pass around episode
        # lists.

        # TODO (simon): episodes are only needed for logging here.
        return {"batch": [batch]}

    def _sample_episodes_from_buffer(self, episodes: List[SingleAgentEpisode]):
        """Validates episodes, adds them to the buffer and samples from it.

        Args:
            episodes: A list of `SingleAgentEpisode`s read from the dataset.

        Returns:
            The result of `self.episode_buffer.sample()` for
            `self.config.train_batch_size_per_learner` items.

        Raises:
            ValueError: If not all episodes are done (see
                `_validate_episodes`).
        """
        # Ensure that all episodes are done and no duplicates are in the batch.
        episodes = self._validate_episodes(episodes)
        # Add the episodes to the buffer.
        self.episode_buffer.add(episodes)
        # Sample steps from the buffer.
        return self.episode_buffer.sample(
            num_items=self.config.train_batch_size_per_learner,
            # Sequences are only needed for stateful modules.
            batch_length_T=(
                self.config.model_config.get("max_seq_len", 0)
                if self._module.is_stateful()
                else None
            ),
            n_step=self.config.get("n_step", 1) or 1,
            # TODO (simon): This can be removed as soon as DreamerV3 has been
            # cleaned up, i.e. can use episode samples for training.
            sample_episodes=True,
            finalize=True,
        )

    @property
    def default_prelearner_buffer_class(self) -> ReplayBuffer:
        """Returns the default replay buffer class (`EpisodeReplayBuffer`)."""
        # Imported locally, i.e. only when the default is requested.
        from ray.rllib.utils.replay_buffers.episode_replay_buffer import (
            EpisodeReplayBuffer,
        )

        # Return the buffer.
        return EpisodeReplayBuffer

    @property
    def default_prelearner_buffer_kwargs(self) -> Dict[str, Any]:
        """Returns the default arguments for the replay buffer.

        Note, the `capacity` might vary with the size of the episodes or
        sample batches in the offline dataset.
        """
        return {
            "capacity": self.config.train_batch_size_per_learner * 10,
            "batch_size_B": self.config.train_batch_size_per_learner,
        }

    def _validate_episodes(
        self, episodes: List[SingleAgentEpisode]
    ) -> Set[SingleAgentEpisode]:
        """Validate episodes sampled from the dataset.

        Note, our episode buffers cannot handle either duplicates nor
        non-ordered fragmentations, i.e. fragments from episodes that do
        not arrive in timestep order.

        Args:
            episodes: A list of `SingleAgentEpisode` instances sampled
                from a dataset.

        Returns:
            A set of `SingleAgentEpisode` instances. It contains the first
            occurrence of each episode ID in `episodes`, excluding episodes
            whose ID is already contained in `self.episode_buffer`.

        Raises:
            ValueError: If not all episodes are `done`.
        """
        # Ensure that episodes are all done.
        if not all(eps.is_done for eps in episodes):
            raise ValueError(
                "When sampling from episodes (`input_read_episodes=True`) all "
                "recorded episodes must be done (i.e. either `terminated=True`) "
                "or `truncated=True`)."
            )
        # Ensure that episodes do not contain duplicates. Note, this can happen
        # if the dataset is small and pulled batches contain multiple episodes.
        seen_episode_ids = set()
        buffered_episode_ids = self.episode_buffer.episode_id_to_index
        unique_episodes = set()
        for eps in episodes:
            # Skip duplicates inside of the pulled batch.
            if eps.id_ in seen_episode_ids:
                continue
            seen_episode_ids.add(eps.id_)
            # Skip episodes that are already stored in the buffer.
            if eps.id_ in buffered_episode_ids:
                continue
            unique_episodes.add(eps)
        return unique_episodes

    def _should_module_be_updated(self, module_id, multi_agent_batch=None) -> bool:
        """Checks which modules in a MultiRLModule should be updated.

        Args:
            module_id: The ID of the module to check.
            multi_agent_batch: The batch that is passed to
                `self._policies_to_train`, if the latter is a callable.

        Returns:
            `True`, if no `policies_to_train` are configured, if `module_id`
            is contained in them, or whatever the `policies_to_train` callable
            returns for `(module_id, multi_agent_batch)`.
        """
        if not self._policies_to_train:
            # In case of no update information, the module is updated.
            return True
        elif not callable(self._policies_to_train):
            return module_id in set(self._policies_to_train)
        else:
            return self._policies_to_train(module_id, multi_agent_batch)

    @OverrideToImplementCustomLogic
    @staticmethod
    def _map_to_episodes(
        is_multi_agent: bool,
        batch: Dict[str, Union[list, np.ndarray]],
        schema: Dict[str, str] = SCHEMA,
        finalize: bool = False,
        input_compress_columns: Optional[List[str]] = None,
        observation_space: gym.Space = None,
        action_space: gym.Space = None,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, List[EpisodeType]]:
        """Maps a batch of data to episodes.

        Each row of the batch is converted into a `SingleAgentEpisode` that
        holds a single timestep, i.e. two observations, one action and one
        reward.

        Args:
            is_multi_agent: If the data is multi-agent. Not supported, yet.
            batch: A dictionary mapping column names to column data.
            schema: A mapping from the keys of `SCHEMA` to the column names
                used in `batch`.
            finalize: If `True`, `finalize()` is called on each episode.
            input_compress_columns: The columns that are compressed and need
                to be unpacked. `None` is treated as no columns.
            observation_space: The observation space used to convert
                observations from JSONable.
            action_space: The action space used to convert actions from
                JSONable. Samples are only converted if both spaces are given.
            **kwargs: Additional keyword arguments; unused here.

        Returns:
            A dictionary with the single key `"episodes"` that holds the list
            of episodes.

        Raises:
            NotImplementedError: If `is_multi_agent` is `True` and the batch
                contains at least one row.
        """
        # Set to empty list, if `None`.
        input_compress_columns = input_compress_columns or []

        # If spaces are given, we can use the space-specific
        # conversion method to convert space samples. Note, we compare against
        # `None` because container spaces could be falsy when they are empty.
        if observation_space is not None and action_space is not None:
            convert = from_jsonable_if_needed
        # Otherwise we use an identity function.
        else:

            def convert(sample, space):
                return sample

        # Everything below does not depend on the row and is computed once.
        obs_compressed = Columns.OBS in input_compress_columns
        actions_compressed = Columns.ACTIONS in input_compress_columns
        obs_column = schema[Columns.OBS]
        next_obs_column = schema[Columns.NEXT_OBS]
        actions_column = schema[Columns.ACTIONS]
        rewards_column = schema[Columns.REWARDS]
        eps_id_column = schema[Columns.EPS_ID]
        infos_column = schema[Columns.INFOS]
        truncateds_column = schema[Columns.TRUNCATEDS]
        has_infos = infos_column in batch
        has_truncateds = truncateds_column in batch
        # The old stack uses "dones" instead of terminateds.
        terminateds_column = (
            schema[Columns.TERMINATEDS]
            if schema[Columns.TERMINATEDS] in batch
            else "dones"
        )
        # All columns that are not part of the schema are considered extra
        # model outputs.
        schema_columns = schema.values()
        extra_columns = [
            k
            for k in batch
            if (
                k not in schema
                and k not in schema_columns
                and k not in ["dones", "agent_index", "type"]
            )
        ]

        def decode_obs(raw_obs):
            # Observations might be (a) serialized and/or (b) converted
            # to a JSONable (when a composite space was used). We unserialize
            # and then reconvert from JSONable to space sample.
            if obs_compressed:
                raw_obs = unpack_if_needed(raw_obs)
            return convert(raw_obs, observation_space)

        def decode_action(raw_action):
            # Actions might be (a) serialized and/or (b) converted to a
            # JSONable (when a composite space was used). We unserialize and
            # then reconvert from JSONable to space sample.
            if actions_compressed:
                raw_action = unpack_if_needed(raw_action)
            return convert(raw_action, action_space)

        episodes = []
        for i, obs in enumerate(batch[obs_column]):
            if is_multi_agent:
                # TODO (simon): Add support for multi-agent episodes. The agent
                # ID needs to be extracted from `schema[Columns.AGENT_ID]` (or
                # from the old stack's `schema["agent_index"]`). Check, what
                # happens with the module ID.
                raise NotImplementedError(
                    "Mapping batches to multi-agent episodes is not "
                    "supported, yet."
                )

            # Build a single-agent episode with a single row of the batch.
            episode = SingleAgentEpisode(
                id_=str(batch[eps_id_column][i]),
                agent_id=None,
                observations=[
                    decode_obs(obs),
                    decode_obs(batch[next_obs_column][i]),
                ],
                infos=[
                    {},
                    batch[infos_column][i] if has_infos else {},
                ],
                actions=[decode_action(batch[actions_column][i])],
                rewards=[batch[rewards_column][i]],
                terminated=batch[terminateds_column][i],
                truncated=batch[truncateds_column][i] if has_truncateds else False,
                # TODO (simon): Results in zero-length episodes in connector.
                # t_started=batch[Columns.T if Columns.T in batch else
                # "unroll_id"][i][0],
                # TODO (simon): Single-dimensional columns are not supported.
                # Extra model outputs might be serialized. We unserialize them
                # here if needed.
                # TODO (simon): Check, if we need here also reconversion from
                # JSONable in case of composite spaces.
                extra_model_outputs={
                    k: [
                        unpack_if_needed(batch[k][i])
                        if k in input_compress_columns
                        else batch[k][i]
                    ]
                    for k in extra_columns
                },
                len_lookback_buffer=0,
            )

            if finalize:
                episode.finalize()
            episodes.append(episode)
        # Note, `map_batches` expects a `Dict` as return value.
        return {"episodes": episodes}

    @OverrideToImplementCustomLogic
    @staticmethod
    def _map_sample_batch_to_episode(
        is_multi_agent: bool,
        batch: Dict[str, Union[list, np.ndarray]],
        schema: Dict[str, str] = SCHEMA,
        finalize: bool = False,
        input_compress_columns: Optional[List[str]] = None,
    ) -> Dict[str, List[EpisodeType]]:
        """Maps an old stack `SampleBatch` to new stack episodes.

        Each entry of the batch's columns is treated as the data of one
        `SampleBatch` and converted into one `SingleAgentEpisode`.

        Args:
            is_multi_agent: If the data is multi-agent. Not supported, yet.
            batch: A dictionary mapping column names to column data.
            schema: A mapping from the keys of `SCHEMA` to the column names
                used in `batch`.
            finalize: If `True`, `finalize()` is called on each episode.
            input_compress_columns: The columns that are compressed and need
                to be unpacked. `None` is treated as no columns.

        Returns:
            A dictionary with the single key `"episodes"` that holds the list
            of episodes.

        Raises:
            NotImplementedError: If `is_multi_agent` is `True` and the batch
                contains at least one entry.
            TypeError: If observations are neither of type `str` nor of type
                `np.ndarray`.
        """
        # Set `input_compress_columns` to an empty `list` if `None`.
        input_compress_columns = input_compress_columns or []

        # TODO (simon): Check, if needed. It could possibly happen that a batch
        # contains data from different episodes. Merging and resplitting the
        # batch would then be the solution.
        # Check, if batch comes actually from multiple episodes.
        # episode_begin_indices = np.where(np.diff(np.hstack(batch["eps_id"])) != 0) + 1

        # Everything below does not depend on the entry and is computed once.
        obs_column = schema[Columns.OBS]
        next_obs_column = schema[Columns.NEXT_OBS]
        eps_id_column = schema[Columns.EPS_ID]
        infos_column = schema[Columns.INFOS]
        actions_column = schema[Columns.ACTIONS]
        rewards_column = schema[Columns.REWARDS]
        truncateds_column = schema[Columns.TRUNCATEDS]
        terminateds_column = schema[Columns.TERMINATEDS]
        # The old stack uses a single "dones" column.
        dones_column = schema.get("dones", "dones")
        obs_compressed = obs_column in input_compress_columns
        actions_compressed = Columns.ACTIONS in input_compress_columns
        has_next_obs = next_obs_column in batch
        has_truncateds = truncateds_column in batch
        has_terminateds = terminateds_column in batch
        # All columns that are not part of the schema are considered extra
        # model outputs.
        schema_columns = schema.values()
        extra_columns = [
            k
            for k in batch
            if (
                k not in schema
                and k not in schema_columns
                and k not in ["dones", "agent_index", "type"]
            )
        ]

        # Define a container to collect episodes.
        episodes = []
        # Loop over `SampleBatch`es in the `ray.data` batch (a dict).
        for i, obs in enumerate(batch[obs_column]):
            if is_multi_agent:
                # TODO (simon): Add support for multi-agent episodes. The agent
                # ID needs to be extracted from the old stack's
                # `schema["agent_index"]` column. Check, what happens with the
                # module ID.
                raise NotImplementedError(
                    "Mapping `SampleBatch`es to multi-agent episodes is not "
                    "supported, yet."
                )

            # Unpack observations, if needed. Note, observations could
            # be either compressed by their entirety (the complete batch
            # column) or individually (each column entry).
            if isinstance(obs, str):
                # Decompress the observations if we have a string, i.e.
                # observations are compressed in their entirety.
                obs = unpack_if_needed(obs)
                # Convert to a list of arrays. This is needed as input by
                # the `SingleAgentEpisode`.
                obs = [obs[t, ...] for t in range(obs.shape[0])]
            # Otherwise observations are only compressed inside of the
            # batch column (if at all).
            elif isinstance(obs, np.ndarray):
                # Unpack observations, if they are compressed otherwise we
                # simply convert to a list, which is needed by the
                # `SingleAgentEpisode`.
                obs = (
                    unpack_if_needed(obs.tolist())
                    if obs_compressed
                    else obs.tolist()
                )
            else:
                raise TypeError(
                    f"Unknown observation type: {type(obs)}. When mapping "
                    "from old recorded `SampleBatches` batched "
                    "observations should be either of type `np.array` "
                    "or - if the column is compressed - of `str` type."
                )

            if has_next_obs:
                # Append the last `new_obs` to get the correct length of
                # observations.
                last_next_obs = batch[next_obs_column][i][-1]
                obs.append(
                    unpack_if_needed(last_next_obs)
                    if obs_compressed
                    else last_next_obs
                )
            else:
                # Otherwise we duplicate the last observation.
                obs.append(obs[-1])

            # Check, if we have `dones`, `truncated`, or `terminated`s in
            # the batch. A missing flag is considered `False` as long as
            # the other one is given.
            if has_truncateds or has_terminateds:
                truncated = (
                    batch[truncateds_column][i][-1] if has_truncateds else False
                )
                terminated = (
                    batch[terminateds_column][i][-1] if has_terminateds else False
                )
            # The old stack records "dones" (see `SCHEMA`).
            elif dones_column in batch:
                terminated = batch[dones_column][i][-1]
                truncated = False
            # Kept for backward compatibility with data using "done".
            elif "done" in batch:
                terminated = batch["done"][i][-1]
                truncated = False
            # Otherwise, if no `terminated`, nor `truncated` nor `done`
            # is given, we consider the episode as terminated.
            else:
                terminated = True
                truncated = False

            # Create a `SingleAgentEpisode`.
            episode = SingleAgentEpisode(
                # If the recorded episode has an ID we use this ID,
                # otherwise we generate a new one.
                id_=(
                    str(batch[eps_id_column][i][0])
                    if eps_id_column in batch
                    else uuid.uuid4().hex
                ),
                agent_id=None,
                observations=obs,
                # Note, each timestep gets its own (distinct) info dict, such
                # that modifying one of them does not change the others.
                infos=(
                    batch[infos_column][i]
                    if infos_column in batch
                    else [{} for _ in range(len(obs))]
                ),
                # Actions might be (a) serialized. We unserialize them here.
                actions=(
                    unpack_if_needed(batch[actions_column][i])
                    if actions_compressed
                    else batch[actions_column][i]
                ),
                rewards=batch[rewards_column][i],
                terminated=terminated,
                truncated=truncated,
                # TODO (simon): Results in zero-length episodes in connector.
                # t_started=batch[Columns.T if Columns.T in batch else
                # "unroll_id"][i][0],
                # TODO (simon): Single-dimensional columns are not supported.
                # Extra model outputs might be serialized. We unserialize them
                # here if needed.
                # TODO (simon): Check, if we need here also reconversion from
                # JSONable in case of composite spaces.
                extra_model_outputs={
                    k: (
                        unpack_if_needed(batch[k][i])
                        if k in input_compress_columns
                        else batch[k][i]
                    )
                    for k in extra_columns
                },
                len_lookback_buffer=0,
            )
            # Finalize, if necessary.
            # TODO (simon, sven): Check, if we should convert all data to lists
            # before. Right now only obs are lists.
            if finalize:
                episode.finalize()
            episodes.append(episode)
        # Note, `map_batches` expects a `Dict` as return value.
        return {"episodes": episodes}