import logging
import ray
from pathlib import Path
from typing import List, Optional
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core.columns import Columns
from ray.rllib.env.env_runner import EnvRunner
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.utils.annotations import (
override,
OverrideToImplementCustomLogic_CallToSuperRecommended,
OverrideToImplementCustomLogic,
)
from ray.rllib.utils.compression import pack_if_needed
from ray.rllib.utils.spaces.space_utils import to_jsonable_if_needed
from ray.rllib.utils.typing import EpisodeType
from ray.util.debug import log_once
from ray.util.annotations import PublicAPI
logger = logging.getLogger(__name__)
# TODO (simon): This class can be agnostic to the episode type as it
# calls only get_state.
@PublicAPI(stability="alpha")
class OfflineSingleAgentEnvRunner(SingleAgentEnvRunner):
"""The environment runner to record the single agent case."""
@override(SingleAgentEnvRunner)
@OverrideToImplementCustomLogic_CallToSuperRecommended
def __init__(self, *, config: AlgorithmConfig, **kwargs):
# Initialize the parent.
super().__init__(config=config, **kwargs)
# Get the data context for this `EnvRunner`.
data_context = ray.data.DataContext.get_current()
# Limit the resources for Ray Data to the CPUs given to this `EnvRunner`.
data_context.execution_options.resource_limits.cpu = (
config.num_cpus_per_env_runner
)
# Set the output write method.
self.output_write_method = self.config.output_write_method
self.output_write_method_kwargs = self.config.output_write_method_kwargs
# Set the filesystem.
self.filesystem = self.config.output_filesystem
self.filesystem_kwargs = self.config.output_filesystem_kwargs
self.filesystem_object = None
# Set the output base path.
self.output_path = self.config.output
# Set the subdir (environment specific).
self.subdir_path = self.config.env.lower()
        # Set the worker-specific path name. Note, this is specifically
        # to enable multiple workers writing into the same directory.
self.worker_path = "run-" + f"{self.worker_index}".zfill(6)
        # If a specific filesystem is given, set it up. Note, this could
        # be `gcsfs` for GCS, `pyarrow` for S3, or `adlfs` for Azure Blob
        # Storage. Such a filesystem is needed specifically when a session
        # has to be created with the cloud provider.
if self.filesystem == "gcs":
import gcsfs
self.filesystem_object = gcsfs.GCSFileSystem(**self.filesystem_kwargs)
elif self.filesystem == "s3":
from pyarrow import fs
self.filesystem_object = fs.S3FileSystem(**self.filesystem_kwargs)
elif self.filesystem == "abs":
import adlfs
self.filesystem_object = adlfs.AzureBlobFileSystem(**self.filesystem_kwargs)
elif self.filesystem is not None:
            raise ValueError(
                f"Unknown filesystem: {self.filesystem}. Filesystems can be "
                "'gcs' for GCS, 's3' for S3, or 'abs' for Azure Blob Storage."
            )
# Add the filesystem object to the write method kwargs.
self.output_write_method_kwargs.update(
{
"filesystem": self.filesystem_object,
}
)
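        # Illustrative (hypothetical) `output_filesystem_kwargs` for the
        # filesystems above; the exact keys depend on the chosen library, e.g.:
        #   "gcs": {"project": "my-project", "token": "key.json"}   (gcsfs)
        #   "s3":  {"region": "us-east-1", "access_key": "...",
        #           "secret_key": "..."}                            (pyarrow.fs.S3FileSystem)
        #   "abs": {"account_name": "...", "account_key": "..."}    (adlfs)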
        # Whether to store `SingleAgentEpisode` objects or column data.
self.output_write_episodes = self.config.output_write_episodes
# Which columns should be compressed in the output data.
self.output_compress_columns = self.config.output_compress_columns
        # Buffer this many rows before writing them to a file.
self.output_max_rows_per_file = self.config.output_max_rows_per_file
# If the user defines a maximum number of rows per file, set the
# event to `False` and check during sampling.
if self.output_max_rows_per_file:
self.write_data_this_iter = False
        # Otherwise, the event is always `True` and we always write sampled
        # data to disk immediately.
else:
self.write_data_this_iter = True
        # Whether the remaining data should be stored. Note, this is only
        # relevant if `output_max_rows_per_file` is defined.
self.write_remaining_data = self.config.output_write_remaining_data
# Counts how often `sample` is called to define the output path for
# each file.
self._sample_counter = 0
# Define the buffer for experiences stored until written to disk.
self._samples = []
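    # Buffering sketch for `sample()` below (hypothetical numbers): with
    # `output_max_rows_per_file=1000` and ~200 rows sampled per call, the
    # buffer accumulates over 5 calls before a file is written; with
    # `output_max_rows_per_file=None`, every call writes its samples to disk
    # immediately.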
@override(SingleAgentEnvRunner)
@OverrideToImplementCustomLogic
def sample(
self,
*,
        num_timesteps: Optional[int] = None,
        num_episodes: Optional[int] = None,
        explore: Optional[bool] = None,
random_actions: bool = False,
force_reset: bool = False,
) -> List[SingleAgentEpisode]:
"""Samples from environments and writes data to disk."""
# Call the super sample method.
samples = super().sample(
num_timesteps=num_timesteps,
num_episodes=num_episodes,
explore=explore,
random_actions=random_actions,
force_reset=force_reset,
)
self._sample_counter += 1
# Add data to the buffers.
if self.output_write_episodes:
import msgpack
import msgpack_numpy as mnp
if log_once("msgpack"):
                logger.info(
                    "Packing episodes with `msgpack` and encoding arrays with "
                    "`msgpack_numpy` for serialization. This is needed for "
                    "recording episodes."
                )
# Note, we serialize episodes with `msgpack` and `msgpack_numpy` to
# ensure version compatibility.
self._samples.extend(
[msgpack.packb(eps.get_state(), default=mnp.encode) for eps in samples]
)
else:
self._map_episodes_to_data(samples)
# If the user defined the maximum number of rows to write.
if self.output_max_rows_per_file:
# Check, if this number is reached.
if len(self._samples) >= self.output_max_rows_per_file:
# Start the recording of data.
self.write_data_this_iter = True
if self.write_data_this_iter:
# If the user wants a maximum number of experiences per file,
# cut the samples to write to disk from the buffer.
if self.output_max_rows_per_file:
# Reset the event.
self.write_data_this_iter = False
# Ensure that all data ready to be written is released from
# the buffer. Note, this is important in case we have many
# episodes sampled and a relatively small `output_max_rows_per_file`.
while len(self._samples) >= self.output_max_rows_per_file:
# Extract the number of samples to be written to disk this
# iteration.
samples_to_write = self._samples[: self.output_max_rows_per_file]
                    # Reset the buffer to the remaining data. This only makes
                    # sense if `rollout_fragment_length` is smaller than, or at
                    # most about twice, `output_max_rows_per_file`.
self._samples = self._samples[self.output_max_rows_per_file :]
samples_ds = ray.data.from_items(samples_to_write)
# Otherwise, write the complete data.
else:
samples_ds = ray.data.from_items(self._samples)
try:
            # Setup the path for writing data. Each run (i.e., each writing
            # event) gets its own file. The path will look like:
            # 'base_path/env-name/run-00000<WorkerID>-00000<RunID>'.
path = (
Path(self.output_path)
.joinpath(self.subdir_path)
.joinpath(self.worker_path + f"-{self._sample_counter}".zfill(6))
)
getattr(samples_ds, self.output_write_method)(
path.as_posix(), **self.output_write_method_kwargs
)
logger.info(f"Wrote samples to storage at {path}.")
except Exception as e:
logger.error(e)
self.metrics.log_value(
key="recording_buffer_size",
value=len(self._samples),
)
# Finally return the samples as usual.
return samples
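    # Reading recorded column data back is a plain Ray Data call, e.g.
    # (assuming the default `write_parquet` output method and a hypothetical
    # local output path):
    #
    #   ds = ray.data.read_parquet("/tmp/cartpole-records/cartpole-v1")
    #   batch = ds.take_batch(10)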
@override(EnvRunner)
@OverrideToImplementCustomLogic
def stop(self) -> None:
"""Writes the reamining samples to disk
Note, if the user defined `max_rows_per_file` the
number of rows for the remaining samples could be
less than the defined maximum row number by the user.
"""
# If there are samples left over we have to write htem to disk. them
# to a dataset.
if self._samples and self.write_remaining_data:
# Convert them to a `ray.data.Dataset`.
samples_ds = ray.data.from_items(self._samples)
# Increase the sample counter for the folder/file name.
self._sample_counter += 1
# Try to write the dataset to disk/cloud storage.
try:
                # Setup the path for writing data. Each run (i.e., each writing
                # event) gets its own file. The path will look like:
                # 'base_path/env-name/run-00000<WorkerID>-00000<RunID>'.
path = (
Path(self.output_path)
.joinpath(self.subdir_path)
.joinpath(self.worker_path + f"-{self._sample_counter}".zfill(6))
)
getattr(samples_ds, self.output_write_method)(
path.as_posix(), **self.output_write_method_kwargs
)
                logger.info(
                    f"Wrote final samples to storage at {path}. Note, these "
                    "final samples could contain fewer rows than "
                    "`output_max_rows_per_file`, if defined."
                )
except Exception as e:
logger.error(e)
logger.debug(f"Experience buffer length: {len(self._samples)}")
@OverrideToImplementCustomLogic
def _map_episodes_to_data(self, samples: List[EpisodeType]) -> None:
"""Converts list of episodes to list of single dict experiences.
Note, this method also appends all sampled experiences to the
buffer.
Args:
samples: List of episodes to be converted.
"""
# Loop through all sampled episodes.
obs_space = self.env.observation_space
action_space = self.env.action_space
for sample in samples:
# Loop through all items of the episode.
for i in range(len(sample)):
sample_data = {
Columns.EPS_ID: sample.id_,
Columns.AGENT_ID: sample.agent_id,
Columns.MODULE_ID: sample.module_id,
# Compress observations, if requested.
Columns.OBS: pack_if_needed(
to_jsonable_if_needed(sample.get_observations(i), obs_space)
)
if Columns.OBS in self.output_compress_columns
else to_jsonable_if_needed(sample.get_observations(i), obs_space),
# Compress actions, if requested.
Columns.ACTIONS: pack_if_needed(
to_jsonable_if_needed(sample.get_actions(i), action_space)
)
if Columns.ACTIONS in self.output_compress_columns
else to_jsonable_if_needed(sample.get_actions(i), action_space),
Columns.REWARDS: sample.get_rewards(i),
# Compress next observations, if requested.
Columns.NEXT_OBS: pack_if_needed(
to_jsonable_if_needed(sample.get_observations(i + 1), obs_space)
)
if Columns.OBS in self.output_compress_columns
else to_jsonable_if_needed(
sample.get_observations(i + 1), obs_space
),
Columns.TERMINATEDS: False
if i < len(sample) - 1
else sample.is_terminated,
Columns.TRUNCATEDS: False
if i < len(sample) - 1
else sample.is_truncated,
**{
# Compress any extra model output, if requested.
k: pack_if_needed(sample.get_extra_model_outputs(k, i))
if k in self.output_compress_columns
else sample.get_extra_model_outputs(k, i)
for k in sample.extra_model_outputs.keys()
},
}
# Finally append to the data buffer.
self._samples.append(sample_data)
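

if __name__ == "__main__":
    # Minimal, hypothetical smoke test for this recorder; the exact
    # `offline_data()` fields used here are assumptions and may differ
    # between RLlib versions.
    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        .environment("CartPole-v1")
        .offline_data(
            # Hypothetical local output path.
            output="/tmp/cartpole-records",
            # Record column data (not packed episodes).
            output_write_episodes=False,
        )
    )
    runner = OfflineSingleAgentEnvRunner(config=config)
    # Sample a couple of episodes; rows are buffered and written to storage.
    episodes = runner.sample(num_episodes=2)
    print(f"Sampled {len(episodes)} episodes.")
    # Flush any remaining buffered rows to disk.
    runner.stop()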