import copy
import gymnasium as gym
import logging
import numpy as np
import tree
import uuid
from typing import Any, Dict, List, Optional, Union, Set, Tuple, TYPE_CHECKING
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec, MultiRLModule
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils import flatten_dict
from ray.rllib.utils.annotations import (
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.compression import unpack_if_needed
from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
from ray.rllib.utils.spaces.space_utils import from_jsonable_if_needed
from ray.rllib.utils.typing import EpisodeType, ModuleID
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
#: This is the default schema used if no `input_read_schema` is set in
#: the config. If a user passes in a schema into `input_read_schema`
#: this user-defined schema has to comply with the keys of `SCHEMA`,
#: while values correspond to the columns in the user's dataset. Note
#: that only the user-defined values will be overridden while all
#: other values from SCHEMA remain as defined here.
SCHEMA = {
    Columns.EPS_ID: Columns.EPS_ID,
    Columns.AGENT_ID: Columns.AGENT_ID,
    Columns.MODULE_ID: Columns.MODULE_ID,
    Columns.OBS: Columns.OBS,
    Columns.ACTIONS: Columns.ACTIONS,
    Columns.REWARDS: Columns.REWARDS,
    Columns.INFOS: Columns.INFOS,
    Columns.NEXT_OBS: Columns.NEXT_OBS,
    Columns.TERMINATEDS: Columns.TERMINATEDS,
    Columns.TRUNCATEDS: Columns.TRUNCATEDS,
    Columns.T: Columns.T,
    # TODO (simon): Add remove as soon as we are new stack only.
    "agent_index": "agent_index",
    "dones": "dones",
    "unroll_id": "unroll_id",
}
logger = logging.getLogger(__name__)
[docs]
@PublicAPI(stability="alpha")
class OfflinePreLearner:
    """Class that coordinates data transformation from dataset to learner.
    This class is an essential part of the new `Offline RL API` of `RLlib`.
    It is a callable class that is run in `ray.data.Dataset.map_batches`
    when iterating over batches for training. It's basic function is to
    convert data in batch from rows to episodes (`SingleAGentEpisode`s
    for now) and to then run the learner connector pipeline to convert
    further to trainable batches. These batches are used directly in the
    `Learner`'s `update` method.
    The main reason to run these transformations inside of `map_batches`
    is for better performance. Batches can be pre-fetched in `ray.data`
    and therefore batch trransformation can be run highly parallelized to
    the `Learner''s `update`.
    This class can be overridden to implement custom logic for transforming
    batches and make them 'Learner'-ready. When deriving from this class
    the `__call__` method and `_map_to_episodes` can be overridden to induce
    custom logic for the complete transformation pipeline (`__call__`) or
    for converting to episodes only ('_map_to_episodes`).
    Custom `OfflinePreLearner` classes can be passed into
    `AlgorithmConfig.offline`'s `prelearner_class`. The `OfflineData` class
    will then use the custom class in its data pipeline.
    """
[docs]
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def __init__(
        self,
        *,
        config: "AlgorithmConfig",
        spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
        module_spec: Optional[MultiRLModuleSpec] = None,
        module_state: Optional[Dict[ModuleID, Any]] = None,
        **kwargs: Dict[str, Any],
    ):
        self.config: AlgorithmConfig = config
        self.input_read_episodes: bool = self.config.input_read_episodes
        self.input_read_sample_batches: bool = self.config.input_read_sample_batches
        # Build the module from spec.
        self._module: MultiRLModule = module_spec.build()
        self._module.set_state(module_state)
        # Map the module to the device, if necessary.
        # TODO (simon): Check here if we already have a list.
        # self._set_device(device_strings)
        # Store the observation and action space if defined, otherwise we
        # set them to `None`. Note, if `None` the `convert_from_jsonable`
        # will not convert the input space samples.
        self.observation_space, self.action_space = spaces or (None, None)
        # Build the learner connector pipeline.
        self._learner_connector = self.config.build_learner_connector(
            input_observation_space=self.observation_space,
            input_action_space=self.action_space,
        )
        # Cache the policies to be trained to update weights only for these.
        self._policies_to_train = self.config.policies_to_train
        self._is_multi_agent: bool = config.is_multi_agent
        # Set the counter to zero.
        self.iter_since_last_module_update: int = 0
        # self._future = None
        # Set up an episode buffer, if the module is stateful or we sample from
        # `SampleBatch` types.
        if (
            self.input_read_sample_batches
            or self._module.is_stateful()
            or self.input_read_episodes
        ):
            # Either the user defined a buffer class or we fall back to the default.
            prelearner_buffer_class = (
                self.config.prelearner_buffer_class
                or self.default_prelearner_buffer_class
            )
            prelearner_buffer_kwargs = (
                self.default_prelearner_buffer_kwargs
                | self.config.prelearner_buffer_kwargs
            )
            # Initialize the buffer.
            self.episode_buffer = prelearner_buffer_class(
                **prelearner_buffer_kwargs,
            ) 
[docs]
    @OverrideToImplementCustomLogic
    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Prepares plain data batches for training with `Learner`'s.
        Args:
            batch: A dictionary of numpy arrays containing either column data
                with `self.config.input_read_schema`, `EpisodeType` data, or
                `BatchType` data.
        Returns:
            A `MultiAgentBatch` that can be passed to `Learner.update` methods.
        """
        # If we directly read in episodes we just convert to list.
        if self.input_read_episodes:
            # Import `msgpack` for decoding.
            import msgpack
            import msgpack_numpy as mnp
            # Read the episodes and decode them.
            episodes: List[SingleAgentEpisode] = [
                SingleAgentEpisode.from_state(
                    msgpack.unpackb(state, object_hook=mnp.decode)
                )
                for state in batch["item"]
            ]
            # Ensure that all episodes are done and no duplicates are in the batch.
            episodes = self._validate_episodes(episodes)
            # Add the episodes to the buffer.
            self.episode_buffer.add(episodes)
            # TODO (simon): Refactor into a single code block for both cases.
            episodes = self.episode_buffer.sample(
                num_items=self.config.train_batch_size_per_learner,
                batch_length_T=self.config.model_config.get("max_seq_len", 0)
                if self._module.is_stateful()
                else None,
                n_step=self.config.get("n_step", 1) or 1,
                # TODO (simon): This can be removed as soon as DreamerV3 has been
                # cleaned up, i.e. can use episode samples for training.
                sample_episodes=True,
                to_numpy=True,
            )
        # Else, if we have old stack `SampleBatch`es.
        elif self.input_read_sample_batches:
            episodes: List[
                SingleAgentEpisode
            ] = OfflinePreLearner._map_sample_batch_to_episode(
                self._is_multi_agent,
                batch,
                to_numpy=True,
                schema=SCHEMA | self.config.input_read_schema,
                input_compress_columns=self.config.input_compress_columns,
            )[
                "episodes"
            ]
            # Ensure that all episodes are done and no duplicates are in the batch.
            episodes = self._validate_episodes(episodes)
            # Add the episodes to the buffer.
            self.episode_buffer.add(episodes)
            # Sample steps from the buffer.
            episodes = self.episode_buffer.sample(
                num_items=self.config.train_batch_size_per_learner,
                batch_length_T=self.config.model_config.get("max_seq_len", 0)
                if self._module.is_stateful()
                else None,
                n_step=self.config.get("n_step", 1) or 1,
                # TODO (simon): This can be removed as soon as DreamerV3 has been
                # cleaned up, i.e. can use episode samples for training.
                sample_episodes=True,
                to_numpy=True,
            )
        # Otherwise we map the batch to episodes.
        else:
            episodes: List[SingleAgentEpisode] = self._map_to_episodes(
                self._is_multi_agent,
                batch,
                schema=SCHEMA | self.config.input_read_schema,
                to_numpy=False,
                input_compress_columns=self.config.input_compress_columns,
                observation_space=self.observation_space,
                action_space=self.action_space,
            )["episodes"]
        # TODO (simon): Make synching work. Right now this becomes blocking or never
        #  receives weights. Learners appear to be non accessi ble via other actors.
        # Increase the counter for updating the module.
        # self.iter_since_last_module_update += 1
        # if self._future:
        #     refs, _ = ray.wait([self._future], timeout=0)
        #     print(f"refs: {refs}")
        #     if refs:
        #         module_state = ray.get(self._future)
        #
        #         self._module.set_state(module_state)
        #         self._future = None
        # # Synch the learner module, if necessary. Note, in case of a local learner
        # # we have a reference to the module and therefore an up-to-date module.
        # if self.learner_is_remote and self.iter_since_last_module_update
        # > self.config.prelearner_module_synch_period:
        #     # Reset the iteration counter.
        #     self.iter_since_last_module_update = 0
        #     # Request the module weights from the remote learner.
        #     self._future =
        # self._learner.get_module_state.remote(inference_only=False)
        #     # module_state =
        # ray.get(self._learner.get_module_state.remote(inference_only=False))
        #     # self._module.set_state(module_state)
        # Run the `Learner`'s connector pipeline.
        batch = self._learner_connector(
            rl_module=self._module,
            batch={},
            episodes=episodes,
            shared_data={},
            # TODO (sven): Add MetricsLogger to non-Learner components that have a
            #  LearnerConnector pipeline.
            metrics=None,
        )
        # Remove all data from modules that should not be trained. We do
        # not want to pass around more data than necessary.
        for module_id in batch:
            if not self._should_module_be_updated(module_id, batch):
                del batch[module_id]
        # Flatten the dictionary to increase serialization performance.
        return flatten_dict(batch) 
    @property
    def default_prelearner_buffer_class(self) -> ReplayBuffer:
        """Sets the default replay buffer."""
        from ray.rllib.utils.replay_buffers.episode_replay_buffer import (
            EpisodeReplayBuffer,
        )
        # Return the buffer.
        return EpisodeReplayBuffer
    @property
    def default_prelearner_buffer_kwargs(self) -> Dict[str, Any]:
        """Sets the default arguments for the replay buffer.
        Note, the `capacity` might vary with the size of the episodes or
        sample batches in the offline dataset.
        """
        return {
            "capacity": self.config.train_batch_size_per_learner * 10,
            "batch_size_B": self.config.train_batch_size_per_learner,
        }
    def _validate_episodes(
        self, episodes: List[SingleAgentEpisode]
    ) -> Set[SingleAgentEpisode]:
        """Validate episodes sampled from the dataset.
        Note, our episode buffers cannot handle either duplicates nor
        non-ordered fragmentations, i.e. fragments from episodes that do
        not arrive in timestep order.
        Args:
            episodes: A list of `SingleAgentEpisode` instances sampled
                from a dataset.
        Returns:
            A set of `SingleAgentEpisode` instances.
        Raises:
            ValueError: If not all episodes are `done`.
        """
        # Ensure that episodes are all done.
        if not all(eps.is_done for eps in episodes):
            raise ValueError(
                "When sampling from episodes (`input_read_episodes=True`) all "
                "recorded episodes must be done (i.e. either `terminated=True`) "
                "or `truncated=True`)."
            )
        # Ensure that episodes do not contain duplicates. Note, this can happen
        # if the dataset is small and pulled batches contain multiple episodes.
        unique_episode_ids = set()
        episodes = {
            eps
            for eps in episodes
            if eps.id_ not in unique_episode_ids
            and not unique_episode_ids.add(eps.id_)
            and eps.id_ not in self.episode_buffer.episode_id_to_index.keys()
        }
        return episodes
[docs]
    def _should_module_be_updated(self, module_id, multi_agent_batch=None) -> bool:
        """Checks which modules in a MultiRLModule should be updated."""
        if not self._policies_to_train:
            # In case of no update information, the module is updated.
            return True
        elif not callable(self._policies_to_train):
            return module_id in set(self._policies_to_train)
        else:
            return self._policies_to_train(module_id, multi_agent_batch) 
[docs]
    @OverrideToImplementCustomLogic
    @staticmethod
    def _map_to_episodes(
        is_multi_agent: bool,
        batch: Dict[str, Union[list, np.ndarray]],
        schema: Dict[str, str] = SCHEMA,
        to_numpy: bool = False,
        input_compress_columns: Optional[List[str]] = None,
        ignore_final_observation: Optional[bool] = False,
        observation_space: gym.Space = None,
        action_space: gym.Space = None,
        **kwargs: Dict[str, Any],
    ) -> Dict[str, List[EpisodeType]]:
        """Maps a batch of data to episodes."""
        # Set to empty list, if `None`.
        input_compress_columns = input_compress_columns or []
        # If spaces are given, we can use the space-specific
        # conversion method to convert space samples.
        if observation_space and action_space:
            convert = from_jsonable_if_needed
        # Otherwise we use an identity function.
        else:
            def convert(sample, space):
                return sample
        episodes = []
        for i, obs in enumerate(batch[schema[Columns.OBS]]):
            # If multi-agent we need to extract the agent ID.
            # TODO (simon): Check, what happens with the module ID.
            if is_multi_agent:
                agent_id = (
                    batch[schema[Columns.AGENT_ID]][i]
                    if Columns.AGENT_ID in batch
                    # The old stack uses "agent_index" instead of "agent_id".
                    # TODO (simon): Remove this as soon as we are new stack only.
                    else (
                        batch[schema["agent_index"]][i]
                        if schema["agent_index"] in batch
                        else None
                    )
                )
            else:
                agent_id = None
            if is_multi_agent:
                # TODO (simon): Add support for multi-agent episodes.
                pass
            else:
                # Unpack observations, if needed.
                unpacked_obs = (
                    convert(unpack_if_needed(obs), observation_space)
                    if Columns.OBS in input_compress_columns
                    else convert(obs, observation_space)
                )
                # Set the next observation.
                if ignore_final_observation:
                    unpacked_next_obs = tree.map_structure(
                        lambda x: 0 * x, copy.deepcopy(unpacked_obs)
                    )
                else:
                    unpacked_next_obs = (
                        convert(
                            unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i]),
                            observation_space,
                        )
                        if Columns.OBS in input_compress_columns
                        else convert(
                            batch[schema[Columns.NEXT_OBS]][i], observation_space
                        )
                    )
                # Build a single-agent episode with a single row of the batch.
                episode = SingleAgentEpisode(
                    id_=str(batch[schema[Columns.EPS_ID]][i])
                    if schema[Columns.EPS_ID] in batch
                    else uuid.uuid4().hex,
                    agent_id=agent_id,
                    # Observations might be (a) serialized and/or (b) converted
                    # to a JSONable (when a composite space was used). We unserialize
                    # and then reconvert from JSONable to space sample.
                    observations=[unpacked_obs, unpacked_next_obs],
                    infos=[
                        {},
                        batch[schema[Columns.INFOS]][i]
                        if schema[Columns.INFOS] in batch
                        else {},
                    ],
                    # Actions might be (a) serialized and/or (b) converted to a JSONable
                    # (when a composite space was used). We unserializer and then
                    # reconvert from JSONable to space sample.
                    actions=[
                        convert(
                            unpack_if_needed(batch[schema[Columns.ACTIONS]][i]),
                            action_space,
                        )
                        if Columns.ACTIONS in input_compress_columns
                        else convert(batch[schema[Columns.ACTIONS]][i], action_space)
                    ],
                    rewards=[batch[schema[Columns.REWARDS]][i]],
                    terminated=batch[
                        schema[Columns.TERMINATEDS]
                        if schema[Columns.TERMINATEDS] in batch
                        else "dones"
                    ][i],
                    truncated=batch[schema[Columns.TRUNCATEDS]][i]
                    if schema[Columns.TRUNCATEDS] in batch
                    else False,
                    # TODO (simon): Results in zero-length episodes in connector.
                    # t_started=batch[Columns.T if Columns.T in batch else
                    # "unroll_id"][i][0],
                    # TODO (simon): Single-dimensional columns are not supported.
                    # Extra model outputs might be serialized. We unserialize them here
                    # if needed.
                    # TODO (simon): Check, if we need here also reconversion from
                    # JSONable in case of composite spaces.
                    extra_model_outputs={
                        k: [
                            unpack_if_needed(v[i])
                            if k in input_compress_columns
                            else v[i]
                        ]
                        for k, v in batch.items()
                        if (
                            k not in schema
                            and k not in schema.values()
                            and k not in ["dones", "agent_index", "type"]
                        )
                    },
                    len_lookback_buffer=0,
                )
                if to_numpy:
                    episode.to_numpy()
                episodes.append(episode)
        # Note, `map_batches` expects a `Dict` as return value.
        return {"episodes": episodes} 
[docs]
    @OverrideToImplementCustomLogic
    @staticmethod
    def _map_sample_batch_to_episode(
        is_multi_agent: bool,
        batch: Dict[str, Union[list, np.ndarray]],
        schema: Dict[str, str] = SCHEMA,
        to_numpy: bool = False,
        input_compress_columns: Optional[List[str]] = None,
    ) -> Dict[str, List[EpisodeType]]:
        """Maps an old stack `SampleBatch` to new stack episodes."""
        # Set `input_compress_columns` to an empty `list` if `None`.
        input_compress_columns = input_compress_columns or []
        # TODO (simon): CHeck, if needed. It could possibly happen that a batch contains
        #   data from different episodes. Merging and resplitting the batch would then
        # be the solution.
        # Check, if batch comes actually from multiple episodes.
        # episode_begin_indices = np.where(np.diff(np.hstack(batch["eps_id"])) != 0) + 1
        # Define a container to collect episodes.
        episodes = []
        # Loop over `SampleBatch`es in the `ray.data` batch (a dict).
        for i, obs in enumerate(batch[schema[Columns.OBS]]):
            # If multi-agent we need to extract the agent ID.
            # TODO (simon): Check, what happens with the module ID.
            if is_multi_agent:
                agent_id = (
                    # The old stack uses "agent_index" instead of "agent_id".
                    batch[schema["agent_index"]][i][0]
                    if schema["agent_index"] in batch
                    else None
                )
            else:
                agent_id = None
            if is_multi_agent:
                # TODO (simon): Add support for multi-agent episodes.
                pass
            else:
                # Unpack observations, if needed. Note, observations could
                # be either compressed by their entirety (the complete batch
                # column) or individually (each column entry).
                if isinstance(obs, str):
                    # Decompress the observations if we have a string, i.e.
                    # observations are compressed in their entirety.
                    obs = unpack_if_needed(obs)
                    # Convert to a list of arrays. This is needed as input by
                    # the `SingleAgentEpisode`.
                    obs = [obs[i, ...] for i in range(obs.shape[0])]
                # Otherwise observations are only compressed inside of the
                # batch column (if at all).
                elif isinstance(obs, np.ndarray):
                    # Unpack observations, if they are compressed otherwise we
                    # simply convert to a list, which is needed by the
                    # `SingleAgentEpisode`.
                    obs = (
                        unpack_if_needed(obs.tolist())
                        if schema[Columns.OBS] in input_compress_columns
                        else obs.tolist()
                    )
                else:
                    raise TypeError(
                        f"Unknown observation type: {type(obs)}. When mapping "
                        "from old recorded `SampleBatches` batched "
                        "observations should be either of type `np.array` "
                        "or - if the column is compressed - of `str` type."
                    )
                if schema[Columns.NEXT_OBS] in batch:
                    # Append the last `new_obs` to get the correct length of
                    # observations.
                    obs.append(
                        unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1])
                        if schema[Columns.OBS] in input_compress_columns
                        else batch[schema[Columns.NEXT_OBS]][i][-1]
                    )
                else:
                    # Otherwise we duplicate the last observation.
                    obs.append(obs[-1])
                # Check, if we have `done`, `truncated`, or `terminated`s in
                # the batch.
                if (
                    schema[Columns.TRUNCATEDS] in batch
                    and schema[Columns.TERMINATEDS] in batch
                ):
                    truncated = batch[schema[Columns.TRUNCATEDS]][i][-1]
                    terminated = batch[schema[Columns.TERMINATEDS]][i][-1]
                elif (
                    schema[Columns.TRUNCATEDS] in batch
                    and schema[Columns.TERMINATEDS] not in batch
                ):
                    truncated = batch[schema[Columns.TRUNCATEDS]][i][-1]
                    terminated = False
                elif (
                    schema[Columns.TRUNCATEDS] not in batch
                    and schema[Columns.TERMINATEDS] in batch
                ):
                    terminated = batch[schema[Columns.TERMINATEDS]][i][-1]
                    truncated = False
                elif "done" in batch:
                    terminated = batch["done"][i][-1]
                    truncated = False
                # Otherwise, if no `terminated`, nor `truncated` nor `done`
                # is given, we consider the episode as terminated.
                else:
                    terminated = True
                    truncated = False
                # Create a `SingleAgentEpisode`.
                episode = SingleAgentEpisode(
                    # If the recorded episode has an ID we use this ID,
                    # otherwise we generate a new one.
                    id_=str(batch[schema[Columns.EPS_ID]][i][0])
                    if schema[Columns.EPS_ID] in batch
                    else uuid.uuid4().hex,
                    agent_id=agent_id,
                    observations=obs,
                    infos=(
                        batch[schema[Columns.INFOS]][i]
                        if schema[Columns.INFOS] in batch
                        else [{}] * len(obs)
                    ),
                    # Actions might be (a) serialized. We unserialize them here.
                    actions=(
                        unpack_if_needed(batch[schema[Columns.ACTIONS]][i])
                        if Columns.ACTIONS in input_compress_columns
                        else batch[schema[Columns.ACTIONS]][i]
                    ),
                    rewards=batch[schema[Columns.REWARDS]][i],
                    terminated=terminated,
                    truncated=truncated,
                    # TODO (simon): Results in zero-length episodes in connector.
                    # t_started=batch[Columns.T if Columns.T in batch else
                    # "unroll_id"][i][0],
                    # TODO (simon): Single-dimensional columns are not supported.
                    # Extra model outputs might be serialized. We unserialize them here
                    # if needed.
                    # TODO (simon): Check, if we need here also reconversion from
                    # JSONable in case of composite spaces.
                    extra_model_outputs={
                        k: unpack_if_needed(v[i])
                        if k in input_compress_columns
                        else v[i]
                        for k, v in batch.items()
                        if (
                            k not in schema
                            and k not in schema.values()
                            and k not in ["dones", "agent_index", "type"]
                        )
                    },
                    len_lookback_buffer=0,
                )
                # Numpy'ized, if necessary.
                # TODO (simon, sven): Check, if we should convert all data to lists
                # before. Right now only obs are lists.
                if to_numpy:
                    episode.to_numpy()
                episodes.append(episode)
        # Note, `map_batches` expects a `Dict` as return value.
        return {"episodes": episodes}