import logging
import time
import types
from pathlib import Path

import pyarrow.fs

import ray
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core import COMPONENT_RL_MODULE
from ray.rllib.env import INPUT_ENV_SPACES
from ray.rllib.offline.offline_prelearner import OfflinePreLearner
from ray.rllib.utils.annotations import (
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.util.annotations import PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class OfflineData:
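    """Class for handling offline recorded data for training RLlib modules.

    Loads a dataset via Ray Data using the configured read method, optionally
    materializes it, and serves learner-ready batches through `sample()` by
    mapping the raw data with an `OfflinePreLearner`.
    """
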
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def __init__(self, config: AlgorithmConfig):
        self.config = config
        self.is_multi_agent = self.config.is_multi_agent
        self.path = (
            self.config.input_
            if isinstance(config.input_, list)
            else Path(config.input_)
        )
        # Use `read_parquet` as the default data read method.
        self.data_read_method = self.config.input_read_method
        # Override default arguments for the data read method.
        self.data_read_method_kwargs = self.config.input_read_method_kwargs
        # In case `EpisodeType` or `BatchType` batches are read, their size
        # can differ from the final `train_batch_size_per_learner`.
        self.data_read_batch_size = self.config.input_read_batch_size
        # Whether the raw data should be materialized.
        self.materialize_data = config.materialize_data
        # Whether the mapped data should be materialized.
        self.materialize_mapped_data = config.materialize_mapped_data
        # Flag identifying whether the data has already been mapped with the
        # `OfflinePreLearner`.
        self.data_is_mapped = False
        # Set the filesystem.
        self.filesystem = self.config.input_filesystem
        self.filesystem_kwargs = self.config.input_filesystem_kwargs
        self.filesystem_object = None

        # If a specific filesystem is given, set it up. Note, this could be
        # `gcsfs` for GCS, `pyarrow` for S3, or `adlfs` for Azure Blob
        # Storage. This filesystem is specifically needed if a session has to
        # be created with the cloud provider.
        if self.filesystem == "gcs":
            import gcsfs

            self.filesystem_object = gcsfs.GCSFileSystem(**self.filesystem_kwargs)
        elif self.filesystem == "s3":
            self.filesystem_object = pyarrow.fs.S3FileSystem(**self.filesystem_kwargs)
        elif self.filesystem == "abs":
            import adlfs

            self.filesystem_object = adlfs.AzureBlobFileSystem(
                **self.filesystem_kwargs
            )
        elif isinstance(self.filesystem, pyarrow.fs.FileSystem):
            self.filesystem_object = self.filesystem
        elif self.filesystem is not None:
            raise ValueError(
                f"Unknown `config.input_filesystem` {self.filesystem}! Filesystems "
                "can be None for local, any instance of `pyarrow.fs.FileSystem`, "
                "'gcs' for GCS, 's3' for S3, or 'abs' for adlfs.AzureBlobFileSystem."
            )
        # Add the filesystem object to the read method kwargs.
        if self.filesystem_object:
            self.data_read_method_kwargs.update(
                {
                    "filesystem": self.filesystem_object,
                }
            )
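
        # For example, reading from a private S3 bucket could be configured
        # as follows (a sketch; the bucket name and region are hypothetical
        # and the exact kwargs depend on your provider setup):
        #   config.offline_data(
        #       input_="s3://my-bucket/train-data/",
        #       input_filesystem="s3",
        #       input_filesystem_kwargs={"region": "us-east-1"},
        #   )
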
        try:
            # Load the dataset.
            start_time = time.perf_counter()
            self.data = getattr(ray.data, self.data_read_method)(
                self.path, **self.data_read_method_kwargs
            )
            if self.materialize_data:
                self.data = self.data.materialize()
            stop_time = time.perf_counter()
            logger.debug(
                "===> [OfflineData] - Time for loading dataset: "
                f"{stop_time - start_time}s."
            )
            logger.info(f"Reading data from {self.path}")
        except Exception as e:
            logger.error(e)
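        # Note, if loading fails, the error is only logged and `self.data`
        # remains unset, so any subsequent `sample()` call raises an
        # `AttributeError`.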
        # Avoids reinstantiating the batch iterator each time we sample.
        self.batch_iterators = None
        self.map_batches_kwargs = (
            self.default_map_batches_kwargs | self.config.map_batches_kwargs
        )
        self.iter_batches_kwargs = (
            self.default_iter_batches_kwargs | self.config.iter_batches_kwargs
        )
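        # Note, the dict union (`|`) gives precedence to the user-provided
        # kwargs over the defaults defined in the two properties at the end
        # of this class.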
        self.returned_streaming_split = False
        # Defines the prelearner class. Note, this could be user-defined.
        self.prelearner_class = self.config.prelearner_class or OfflinePreLearner
        # For remote learner setups.
        self.locality_hints = None
        self.learner_handles = None
        self.module_spec = None
        self.spaces = None
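        # Note, in a typical setup the `Algorithm` assigns `learner_handles`,
        # `module_spec`, `spaces`, and `locality_hints` on this instance
        # during its own setup, before `sample()` is called.
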
    @OverrideToImplementCustomLogic
    def sample(
        self,
        num_samples: int,
        return_iterator: bool = False,
        num_shards: int = 1,
    ):
        # Map the data with the `OfflinePreLearner` logic, if necessary. This
        # runs the `OfflinePreLearner` on all the data and maps it to
        # `MultiAgentBatch`es.
        # TODO (simon, sven): This would never update the module nor the
        #  connectors. If this is needed, we have to check if we
        #  (a) hand out only an iterator and let the learner and the
        #  `OfflinePreLearner` communicate through the object store. This
        #  only works when not materializing.
        #  (b) rematerialize the data every couple of iterations. This is
        #  costly.
        if not self.data_is_mapped:
            # Constructor `kwargs` for the `OfflinePreLearner`.
            fn_constructor_kwargs = {
                "config": self.config,
                "learner": self.learner_handles[0],
                "spaces": self.spaces[INPUT_ENV_SPACES],
            }
            # If we have multiple learners, adapt the constructor `kwargs`.
            if num_shards > 1:
                # Call the learner here to get an up-to-date module state.
                # TODO (simon): This is a workaround as long as learners
                #  cannot receive any calls from another actor.
                module_state = ray.get(
                    self.learner_handles[0].get_state.remote(
                        component=COMPONENT_RL_MODULE
                    )
                )
                # Add the constructor `kwargs` needed for remote learners.
                fn_constructor_kwargs.update(
                    {
                        "learner": None,
                        "module_spec": self.module_spec,
                        "module_state": module_state,
                    }
                )

            self.data = self.data.map_batches(
                self.prelearner_class,
                fn_constructor_kwargs=fn_constructor_kwargs,
                batch_size=self.data_read_batch_size or num_samples,
                **self.map_batches_kwargs,
            )
            # Set the flag to `True`.
            self.data_is_mapped = True
            # If the user wants to materialize the mapped data in memory,
            # do so here.
            if self.materialize_mapped_data:
                self.data = self.data.materialize()
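
        # At this point, each dataset row holds - under the "batch" key - a
        # ready-to-learn `MultiAgentBatch` produced by the `OfflinePreLearner`
        # (see the `["batch"][0]` access further below).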
        # Build an iterator, if necessary. Note, if an iterator should be
        # returned now and we have already generated from a previous one, i.e.
        # `isinstance(self.batch_iterators, types.GeneratorType) == True`, we
        # need to create a new iterator here.
        if not self.batch_iterators or (
            return_iterator and isinstance(self.batch_iterators, types.GeneratorType)
        ):
            # If we have more than one learner, create an iterator for each
            # of them by splitting the data stream.
            if num_shards > 1:
                logger.debug("===> [OfflineData]: Return streaming_split ... ")
                # In case of multiple shards, we return multiple
                # `StreamingSplitIterator` instances.
                self.batch_iterators = self.data.streaming_split(
                    n=num_shards,
                    # Note, `equal` must be `True`, i.e. the batch size must
                    # be the same for all batches b/c otherwise remote
                    # learners could block each other.
                    equal=True,
                    locality_hints=self.locality_hints,
                )
            # Otherwise, create a single iterator and - if necessary -
            # initialize it here.
            else:
                # Whether an iterator or a single batch should be returned,
                # we instantiate the batch iterable only once, here.
                self.batch_iterators = self.data.iter_batches(
                    # This is important. The batch size is now 1 because the
                    # data has already run through the `OfflinePreLearner`
                    # and a single row is a single `MultiAgentBatch` of size
                    # `num_samples`.
                    batch_size=1,
                    **self.iter_batches_kwargs,
                )
                # If a single batch should be returned (rather than the
                # iterable itself), turn the iterable into an iterator.
                if not return_iterator:
                    self.batch_iterators = iter(self.batch_iterators)

        # Do we want to return an iterator or a single batch?
        if return_iterator:
            return self.batch_iterators
        else:
            # Return a single batch from the iterator.
            try:
                return next(self.batch_iterators)["batch"][0]
            except StopIteration:
                # The batch iterator is exhausted; create a new one.
                logger.debug(
                    "===> [OfflineData]: Batch iterator exhausted. "
                    "Reinitializing ..."
                )
                self.batch_iterators = None
                return self.sample(
                    num_samples=num_samples,
                    return_iterator=return_iterator,
                    num_shards=num_shards,
                )
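
    # A minimal usage sketch for `sample()` (assumes the algorithm has wired
    # up `learner_handles` and `spaces` beforehand; `256` is a hypothetical
    # batch size):
    #   batch = offline_data.sample(num_samples=256)
    #   # Or, for `n` remote learners:
    #   iterators = offline_data.sample(
    #       num_samples=256, return_iterator=True, num_shards=n
    #   )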
    @property
    def default_map_batches_kwargs(self):
        """Default kwargs for the `map_batches` call in `sample()`."""
        return {
            "concurrency": max(2, self.config.num_learners),
            "zero_copy_batch": True,
        }

    @property
    def default_iter_batches_kwargs(self):
        """Default kwargs for the `iter_batches` call in `sample()`."""
        return {
            "prefetch_batches": 2,
        }
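

# A hedged end-to-end sketch (not part of this class): configuring offline
# data through `AlgorithmConfig.offline_data()` and training on it. The
# dataset path below is hypothetical.
#
#   from ray.rllib.algorithms.bc import BCConfig
#
#   config = (
#       BCConfig()
#       .environment("CartPole-v1")
#       .offline_data(
#           input_="/tmp/cartpole-expert",     # hypothetical local path
#           input_read_method="read_parquet",  # the default
#       )
#   )
#   algo = config.build()  # Builds learners and this `OfflineData` instance.
#   algo.train()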