Source code for ray.rllib.offline.offline_data

import logging
from pathlib import Path
import pyarrow.fs
import ray
import time
import types

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.core import COMPONENT_RL_MODULE
from ray.rllib.env import INPUT_ENV_SPACES
from ray.rllib.offline.offline_prelearner import OfflinePreLearner
from ray.rllib.utils.annotations import (
    OverrideToImplementCustomLogic,
    OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.util.annotations import PublicAPI

logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
class OfflineData:
    """Loads offline data through `ray.data` and serves training batches.

    The constructor reads the dataset with the `ray.data` read method named by
    `config.input_read_method`. `sample()` maps the dataset once through the
    configured prelearner class and then returns either single batches or
    batch iterators.

    `learner_handles`, `locality_hints` and `module_spec` are initialized to
    `None` in the constructor, and `sample()` additionally reads `self.spaces`,
    which the constructor does not set. NOTE(review): presumably the owning
    algorithm assigns these before calling `sample()` - confirm with callers.
    """

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def __init__(self, config: AlgorithmConfig):
        """Sets up the read configuration and tries to load the dataset.

        A failure while loading is logged (with traceback) but not raised, so
        construction itself never fails because of unreadable data. In that
        case `self.data` is `None` and `sample()` raises a `RuntimeError`.

        Args:
            config: The algorithm config holding the offline input settings
                (`input_`, `input_read_method`, `input_read_method_kwargs`,
                `input_read_batch_size`, `input_filesystem`,
                `input_filesystem_kwargs`, `materialize_data`,
                `materialize_mapped_data`, `map_batches_kwargs`,
                `iter_batches_kwargs`, `prelearner_class`).

        Raises:
            ValueError: If `config.input_filesystem` is neither `None`, one of
                "gcs", "s3", "abs", nor a `pyarrow.fs.FileSystem` instance.
        """
        self.config = config
        self.is_multi_agent = self.config.is_multi_agent
        # A list of paths is passed through as-is; anything else is wrapped
        # into a `Path`.
        self.path = (
            self.config.input_
            if isinstance(config.input_, list)
            else Path(config.input_)
        )
        # Use `read_parquet` as default data read method.
        self.data_read_method = self.config.input_read_method
        # Override default arguments for the data read method. Work on a copy:
        # the filesystem object may be added below and must not leak back
        # into the (shared) config dictionary.
        self.data_read_method_kwargs = dict(
            self.config.input_read_method_kwargs or {}
        )
        # In case `EpisodeType` or `BatchType` batches are read the size
        # could differ from the final `train_batch_size_per_learner`.
        self.data_read_batch_size = self.config.input_read_batch_size

        # If data should be materialized.
        self.materialize_data = config.materialize_data
        # If mapped data should be materialized.
        self.materialize_mapped_data = config.materialize_mapped_data
        # Flag to identify, if data has already been mapped with the
        # `OfflinePreLearner`.
        self.data_is_mapped = False

        # Set the filesystem.
        self.filesystem = self.config.input_filesystem
        self.filesystem_kwargs = self.config.input_filesystem_kwargs
        self.filesystem_object = None

        # If a specific filesystem is given, set it up. Note, this could
        # be `gcsfs` for GCS, `pyarrow` for S3 or `adlfs` for Azure Blob
        # Storage. This filesystem is specifically needed, if a session has
        # to be created with the cloud provider.
        if self.filesystem == "gcs":
            import gcsfs

            self.filesystem_object = gcsfs.GCSFileSystem(**self.filesystem_kwargs)
        elif self.filesystem == "s3":
            self.filesystem_object = pyarrow.fs.S3FileSystem(
                **self.filesystem_kwargs
            )
        elif self.filesystem == "abs":
            import adlfs

            self.filesystem_object = adlfs.AzureBlobFileSystem(
                **self.filesystem_kwargs
            )
        elif isinstance(self.filesystem, pyarrow.fs.FileSystem):
            self.filesystem_object = self.filesystem
        elif self.filesystem is not None:
            raise ValueError(
                f"Unknown `config.input_filesystem` {self.filesystem}! Filesystems "
                "can be None for local, any instance of `pyarrow.fs.FileSystem`, "
                "'gcs' for GCS, 's3' for S3, or 'abs' for adlfs.AzureBlobFileSystem."
            )
        # Add the filesystem object to the read method kwargs.
        if self.filesystem_object:
            self.data_read_method_kwargs["filesystem"] = self.filesystem_object

        try:
            # Load the dataset.
            start_time = time.perf_counter()
            self.data = getattr(ray.data, self.data_read_method)(
                self.path, **self.data_read_method_kwargs
            )
            if self.materialize_data:
                self.data = self.data.materialize()
            stop_time = time.perf_counter()
            logger.debug(
                "===> [OfflineData] - Time for loading dataset: "
                f"{stop_time - start_time}s."
            )
            logger.info("Reading data from {}".format(self.path))
        except Exception:
            # Best effort: do not fail construction, but keep the traceback
            # in the log and make sure the attribute exists so that
            # `sample()` can report the problem clearly.
            logger.exception(
                f"===> [OfflineData] - Could not load dataset from {self.path}."
            )
            if not hasattr(self, "data"):
                self.data = None

        # Avoids reinstantiating the batch iterator each time we sample.
        self.batch_iterators = None
        # User-provided kwargs take precedence over the defaults.
        self.map_batches_kwargs = (
            self.default_map_batches_kwargs | self.config.map_batches_kwargs
        )
        self.iter_batches_kwargs = (
            self.default_iter_batches_kwargs | self.config.iter_batches_kwargs
        )
        self.returned_streaming_split = False
        # Defines the prelearner class. Note, this could be user-defined.
        self.prelearner_class = self.config.prelearner_class or OfflinePreLearner
        # For remote learner setups.
        self.locality_hints = None
        self.learner_handles = None
        self.module_spec = None

    @OverrideToImplementCustomLogic
    def sample(
        self,
        num_samples: int,
        return_iterator: bool = False,
        num_shards: int = 1,
    ):
        """Returns a single batch or the batch iterator(s) over the mapped data.

        On the first call the dataset is mapped through `self.prelearner_class`
        via `map_batches` (and materialized, if
        `config.materialize_mapped_data` is set). Later calls reuse the
        mapped dataset.

        Args:
            num_samples: Batch size used for the `map_batches` call when
                `config.input_read_batch_size` is not set. Only has an effect
                on the call that maps the data.
            return_iterator: If True, return the iterator(s) instead of a
                single batch.
            num_shards: If greater than 1, the mapped data is split with
                `streaming_split(n=num_shards, equal=True, ...)` and the
                prelearner is constructed from `module_spec` plus the module
                state pulled from the first learner instead of from a learner
                handle. NOTE(review): a single batch is drawn with `next()`,
                so `num_shards > 1` looks intended to be combined with
                `return_iterator=True` - confirm with callers.

        Returns:
            `self.batch_iterators` if `return_iterator` is True, otherwise
            the first element under the "batch" key of the next item of the
            batch iterator.

        Raises:
            RuntimeError: If no dataset could be loaded in the constructor or
                if a freshly created batch iterator yields no batch at all.
        """
        if getattr(self, "data", None) is None:
            raise RuntimeError(
                "[OfflineData] No dataset available to sample from. Loading "
                f"the data from {self.path} failed; see the logged error."
            )

        # Materialize the mapped data, if necessary. This runs for all the
        # data the `OfflinePreLearner` logic and maps them to
        # `MultiAgentBatch`es.
        # TODO (simon, sven): This would never update the module nor the
        #   connectors. If this is needed we have to check, if we give
        #   (a) only an iterator and let the learner and OfflinePreLearner
        #       communicate through the object storage. This only works when
        #       not materializing.
        #   (b) Rematerialize the data every couple of iterations. This is
        #       costly.
        if not self.data_is_mapped:
            # Constructor `kwargs` for the `OfflinePreLearner`.
            fn_constructor_kwargs = {
                "config": self.config,
                "learner": self.learner_handles[0],
                "spaces": self.spaces[INPUT_ENV_SPACES],
            }
            # If we have multiple learners, add to the constructor `kwargs`.
            if num_shards > 1:
                # Call here the learner to get an up-to-date module state.
                # TODO (simon): This is a workaround as long as learners
                #   cannot receive any calls from another actor.
                module_state = ray.get(
                    self.learner_handles[0].get_state.remote(
                        component=COMPONENT_RL_MODULE
                    )
                )
                # Add constructor `kwargs` when using remote learners.
                fn_constructor_kwargs.update(
                    {
                        "learner": None,
                        "module_spec": self.module_spec,
                        "module_state": module_state,
                    }
                )

            self.data = self.data.map_batches(
                self.prelearner_class,
                fn_constructor_kwargs=fn_constructor_kwargs,
                batch_size=self.data_read_batch_size or num_samples,
                **self.map_batches_kwargs,
            )
            # Set the flag to `True`.
            self.data_is_mapped = True
            # If the user wants to materialize the data in memory.
            if self.materialize_mapped_data:
                self.data = self.data.materialize()

        # Build an iterator, if necessary. Note, in case that an iterator
        # should be returned now and we have already generated from the
        # iterator, i.e. `self.batch_iterators` is a generator, we need to
        # create a new iterator here.
        if not self.batch_iterators or (
            return_iterator and isinstance(self.batch_iterators, types.GeneratorType)
        ):
            self._build_batch_iterators(return_iterator, num_shards)

        # Do we want to return an iterator or a single batch?
        if return_iterator:
            return self.batch_iterators

        # Return a single batch from the iterator.
        try:
            return next(self.batch_iterators)["batch"][0]
        except StopIteration:
            # If the batch iterator is exhausted, reinitiate a new one.
            logger.debug(
                "===> [OfflineData]: Batch iterator exhausted. Reinitiating ..."
            )

        # Rebuild exactly once. A fresh iterator that is exhausted right away
        # cannot produce data on further attempts either, so fail explicitly
        # instead of retrying without bound.
        self.batch_iterators = None
        self._build_batch_iterators(return_iterator, num_shards)
        try:
            return next(self.batch_iterators)["batch"][0]
        except StopIteration:
            raise RuntimeError(
                "[OfflineData] A newly created batch iterator returned no "
                "batches; the (mapped) dataset appears to be empty."
            ) from None

    def _build_batch_iterators(self, return_iterator: bool, num_shards: int) -> None:
        """Creates `self.batch_iterators` from the current `self.data`."""
        # If we have more than one learner create an iterator for each of
        # them by splitting the data stream.
        if num_shards > 1:
            logger.debug("===> [OfflineData]: Return streaming_split ... ")
            # In case of multiple shards, we return multiple
            # `StreamingSplitIterator` instances.
            self.batch_iterators = self.data.streaming_split(
                n=num_shards,
                # Note, `equal` must be `True`, i.e. the batch size must
                # be the same for all batches b/c otherwise remote learners
                # could block each others.
                equal=True,
                locality_hints=self.locality_hints,
            )
            return

        # Otherwise we create a simple iterator and - if necessary -
        # initialize it here.
        self.batch_iterators = self.data.iter_batches(
            # This is important. The batch size is now 1, because the data
            # is already run through the `OfflinePreLearner` and a single
            # instance is a single `MultiAgentBatch` of size `num_samples`.
            batch_size=1,
            **self.iter_batches_kwargs,
        )
        # If single batches should be drawn, turn the iterable into an
        # iterator once, so that repeated `next()` calls advance through it.
        if not return_iterator:
            self.batch_iterators = iter(self.batch_iterators)

    @property
    def default_map_batches_kwargs(self) -> dict:
        """Default kwargs for `map_batches`; overridden by the config's."""
        return {
            "concurrency": max(2, self.config.num_learners),
            "zero_copy_batch": True,
        }

    @property
    def default_iter_batches_kwargs(self) -> dict:
        """Default kwargs for `iter_batches`; overridden by the config's."""
        return {
            "prefetch_batches": 2,
        }