Source code for ray.data.datasource.file_based_datasource

import itertools
import logging
import pathlib
import posixpath
import sys
import urllib.parse
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import numpy as np

from ray.air._internal.remote_storage import _is_local_windows_path
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data._internal.execution.interfaces import TaskContext
from ray.data._internal.output_buffer import BlockOutputBuffer
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import _check_pyarrow_version, _resolve_custom_scheme
from ray.data.block import Block, BlockAccessor
from ray.data.context import DataContext
from ray.data.datasource.datasource import Datasource, Reader, ReadTask, WriteResult
from ray.data.datasource.file_meta_provider import (
    BaseFileMetadataProvider,
    DefaultFileMetadataProvider,
)
from ray.data.datasource.partitioning import (
    Partitioning,
    PathPartitionFilter,
    PathPartitionParser,
)

from ray.types import ObjectRef
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray._private.utils import _add_creatable_buckets_param_if_s3_uri

if TYPE_CHECKING:
    import pandas as pd
    import pyarrow


logger = logging.getLogger(__name__)


# File size fetch operations are parallelized once the number of paths exceeds
# this threshold.
FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD = 16

# 16 file size fetches from S3 take ~1.5 seconds with Arrow's S3FileSystem.
PATHS_PER_FILE_SIZE_FETCH_TASK = 16


@DeveloperAPI
class BlockWritePathProvider:
    """Abstract callable that provides concrete output paths when writing
    dataset blocks.

    Current subclasses:
        DefaultBlockWritePathProvider
    """

    def _get_write_path_for_block(
        self,
        base_path: str,
        *,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        dataset_uuid: Optional[str] = None,
        block: Optional[Block] = None,
        block_index: Optional[int] = None,
        file_format: Optional[str] = None,
    ) -> str:
        """
        Resolves and returns the write path for the given dataset block. When
        implementing this method, care should be taken to ensure that a unique
        path is provided for every dataset block.

        Args:
            base_path: The base path to write the dataset block out to. This is
                expected to be the same for all blocks in the dataset, and may
                point to either a directory or file prefix.
            filesystem: The filesystem implementation that will be used to
                write a file out to the write path returned.
            dataset_uuid: Unique identifier for the dataset that this block
                belongs to.
            block: The block to write.
            block_index: Ordered index of the block to write within its parent
                dataset.
            file_format: File format string for the block that can be used as
                the file extension in the write path returned.

        Returns:
            The dataset block write path.
        """
        raise NotImplementedError

    def __call__(
        self,
        base_path: str,
        *,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        dataset_uuid: Optional[str] = None,
        block: Optional[Block] = None,
        block_index: Optional[int] = None,
        file_format: Optional[str] = None,
    ) -> str:
        return self._get_write_path_for_block(
            base_path,
            filesystem=filesystem,
            dataset_uuid=dataset_uuid,
            block=block,
            block_index=block_index,
            file_format=file_format,
        )


@DeveloperAPI
class DefaultBlockWritePathProvider(BlockWritePathProvider):
    """Default block write path provider implementation that writes each
    dataset block out to a file of the form:
    {base_path}/{dataset_uuid}_{block_index}.{file_format}
    """

    def _get_write_path_for_block(
        self,
        base_path: str,
        *,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        dataset_uuid: Optional[str] = None,
        block: Optional[ObjectRef[Block]] = None,
        block_index: Optional[int] = None,
        file_format: Optional[str] = None,
    ) -> str:
        suffix = f"{dataset_uuid}_{block_index:06}.{file_format}"
        # Uses POSIX path for cross-filesystem compatibility, since PyArrow
        # FileSystem paths are always forward slash separated, see:
        # https://arrow.apache.org/docs/python/filesystems.html
        return posixpath.join(base_path, suffix)
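
# Illustrative example (editor-added sketch, not part of the Ray module): shows
# the kind of path the default provider produces. The bucket, dataset UUID, and
# block index below are hypothetical placeholder values.
def _example_default_block_write_path() -> str:
    # Produces "s3://bucket/out/abc123_000007.csv".
    return DefaultBlockWritePathProvider()(
        "s3://bucket/out",
        dataset_uuid="abc123",
        block_index=7,
        file_format="csv",
    )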


@PublicAPI(stability="beta")
class FileExtensionFilter(PathPartitionFilter):
    """A file-extension-based path filter that filters files that don't end
    with the provided extension(s).

    Attributes:
        file_extensions: File extension(s) of files to be included in reading.
        allow_if_no_extension: If this is True, files without any extensions
            will be included in reading.

    """

    def __init__(
        self,
        file_extensions: Union[str, List[str]],
        allow_if_no_extension: bool = False,
    ):
        if isinstance(file_extensions, str):
            file_extensions = [file_extensions]

        self.extensions = [f".{ext.lower()}" for ext in file_extensions]
        self.allow_if_no_extension = allow_if_no_extension

    def _file_has_extension(self, path: str):
        suffixes = [suffix.lower() for suffix in pathlib.Path(path).suffixes]
        if not suffixes:
            return self.allow_if_no_extension
        return any(ext in suffixes for ext in self.extensions)

    def __call__(self, paths: List[str]) -> List[str]:
        return [path for path in paths if self._file_has_extension(path)]

    def __str__(self):
        return (
            f"{type(self).__name__}(extensions={self.extensions}, "
            f"allow_if_no_extensions={self.allow_if_no_extension})"
        )

    def __repr__(self):
        return str(self)
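

# Illustrative example (editor-added sketch, not part of the Ray module): the
# file names are hypothetical. FileExtensionFilter("csv") keeps any path whose
# suffixes include ".csv" (including "data/c.csv.gz") and drops the rest.
def _example_file_extension_filter() -> List[str]:
    # Returns ["data/a.csv", "data/c.csv.gz"].
    return FileExtensionFilter("csv")(["data/a.csv", "data/b.json", "data/c.csv.gz"])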


@DeveloperAPI
class FileBasedDatasource(Datasource):
    """File-based datasource, for reading and writing files.

    This class should not be used directly, and should instead be subclassed
    and tailored to particular file formats. Classes deriving from this class
    must implement _read_file().

    If _FILE_EXTENSION is defined, only files with this extension will be read
    by default. If it is None, no default filter is used.

    Current subclasses:
        JSONDatasource, CSVDatasource, NumpyDatasource, BinaryDatasource
    """

    _FILE_EXTENSION: Optional[Union[str, List[str]]] = None

    def _open_input_source(
        self,
        filesystem: "pyarrow.fs.FileSystem",
        path: str,
        **open_args,
    ) -> "pyarrow.NativeFile":
        """Opens a source path for reading and returns the associated Arrow
        NativeFile.

        The default implementation opens the source path as a sequential input
        stream, using ctx.streaming_read_buffer_size as the buffer size if none
        is given by the caller.

        Implementations that do not support streaming reads (e.g. that require
        random access) should override this method.
        """
        buffer_size = open_args.pop("buffer_size", None)
        if buffer_size is None:
            ctx = DataContext.get_current()
            buffer_size = ctx.streaming_read_buffer_size
        return filesystem.open_input_stream(path, buffer_size=buffer_size, **open_args)

    def create_reader(self, **kwargs):
        return _FileBasedDatasourceReader(self, **kwargs)

    def _rows_per_file(self):
        """Returns the number of rows per file, or None if unknown."""
        return None

    def _read_stream(
        self, f: "pyarrow.NativeFile", path: str, **reader_args
    ) -> Iterator[Block]:
        """Streaming read a single file, passing all kwargs to the reader.

        By default, delegates to self._read_file().
        """
        yield self._read_file(f, path, **reader_args)

    def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args) -> Block:
        """Reads a single file, passing all kwargs to the reader.

        This method should be implemented by subclasses.
        """
        raise NotImplementedError(
            "Subclasses of FileBasedDatasource must implement _read_file()."
        )

    def _convert_block_to_tabular_block(
        self, block: Block, column_name: Optional[str] = None
    ) -> Union["pyarrow.Table", "pd.DataFrame"]:
        """Convert a block returned by `_read_file` or `_read_stream` to a
        tabular block.

        If your `_read_file` or `_read_stream` implementation returns a list,
        then you need to implement this method. Otherwise, `FileBasedDatasource`
        won't be able to include partition data.
        """
        import pandas as pd
        import pyarrow as pa

        if isinstance(block, (pd.DataFrame, pa.Table)):
            return block

        raise NotImplementedError(
            "If your `_read_file` or `_read_stream` implementation returns a "
            "list, then you need to implement `_convert_block_to_tabular_block`."
        )

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
        path: str,
        dataset_uuid: str,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        try_create_dir: bool = True,
        open_stream_args: Optional[Dict[str, Any]] = None,
        block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
        write_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
        _block_udf: Optional[Callable[[Block], Block]] = None,
        **write_args,
    ) -> WriteResult:
        """Write blocks for a file-based datasource."""
        path, filesystem = _resolve_paths_and_filesystem(path, filesystem)
        path = path[0]
        if try_create_dir:
            # Arrow's S3FileSystem doesn't allow creating buckets by default, so we
            # add a query arg enabling bucket creation if an S3 URI is provided.
            tmp = _add_creatable_buckets_param_if_s3_uri(path)
            filesystem.create_dir(tmp, recursive=True)
        filesystem = _wrap_s3_serialization_workaround(filesystem)

        _write_block_to_file = self._write_block

        if open_stream_args is None:
            open_stream_args = {}

        def write_block(write_path: str, block: Block):
            logger.debug(f"Writing {write_path} file.")
            fs = _unwrap_s3_serialization_workaround(filesystem)
            if _block_udf is not None:
                block = _block_udf(block)

            with fs.open_output_stream(write_path, **open_stream_args) as f:
                _write_block_to_file(
                    f,
                    BlockAccessor.for_block(block),
                    writer_args_fn=write_args_fn,
                    **write_args,
                )
            # TODO: decide if we want to return richer object when the task
            # succeeds.
            return "ok"

        file_format = self._FILE_EXTENSION
        if isinstance(file_format, list):
            file_format = file_format[0]

        builder = DelegatingBlockBuilder()
        for block in blocks:
            builder.add_block(block)
        block = builder.build()

        if not block_path_provider:
            block_path_provider = DefaultBlockWritePathProvider()
        write_path = block_path_provider(
            path,
            filesystem=filesystem,
            dataset_uuid=dataset_uuid,
            block=block,
            block_index=ctx.task_idx,
            file_format=file_format,
        )
        return write_block(write_path, block)

    def _write_block(
        self,
        f: "pyarrow.NativeFile",
        block: BlockAccessor,
        writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
        **writer_args,
    ):
        """Writes a block to a single file, passing all kwargs to the writer.

        This method should be implemented by subclasses.
        """
        raise NotImplementedError(
            "Subclasses of FileBasedDatasource must implement _write_block()."
        )

    @classmethod
    def file_extension_filter(cls) -> Optional[PathPartitionFilter]:
        if cls._FILE_EXTENSION is None:
            return None
        return FileExtensionFilter(cls._FILE_EXTENSION)
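

# Illustrative sketch (editor-added, not part of the Ray module): a minimal
# FileBasedDatasource subclass that treats each file as a single row of raw
# bytes, showing the _read_file()/_write_block() contract described above. The
# class name and the "bin" extension are hypothetical; Ray's BinaryDatasource
# is the real analogue of this example.
class _ExampleRawBytesDatasource(FileBasedDatasource):
    _FILE_EXTENSION = "bin"

    def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args) -> Block:
        # Read the whole file into a single-row Arrow table.
        import pyarrow as pa

        return pa.table({"bytes": [f.readall()]})

    def _write_block(
        self,
        f: "pyarrow.NativeFile",
        block: BlockAccessor,
        writer_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
        **writer_args,
    ):
        # Concatenate the "bytes" column of the block into one output file.
        for value in block.to_arrow()["bytes"]:
            f.write(value.as_py())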


class _FileBasedDatasourceReader(Reader):
    def __init__(
        self,
        delegate: FileBasedDatasource,
        paths: Union[str, List[str]],
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
        open_stream_args: Optional[Dict[str, Any]] = None,
        meta_provider: BaseFileMetadataProvider = DefaultFileMetadataProvider(),
        partition_filter: PathPartitionFilter = None,
        partitioning: Partitioning = None,
        # TODO(ekl) deprecate this once read fusion is available.
        _block_udf: Optional[Callable[[Block], Block]] = None,
        ignore_missing_paths: bool = False,
        **reader_args,
    ):
        _check_pyarrow_version()
        self._delegate = delegate
        self._schema = schema
        self._open_stream_args = open_stream_args
        self._meta_provider = meta_provider
        self._partition_filter = partition_filter
        self._partitioning = partitioning
        self._block_udf = _block_udf
        self._ignore_missing_paths = ignore_missing_paths
        self._reader_args = reader_args
        paths, self._filesystem = _resolve_paths_and_filesystem(paths, filesystem)
        self._paths, self._file_sizes = map(
            list,
            zip(
                *meta_provider.expand_paths(
                    paths,
                    self._filesystem,
                    partitioning,
                    ignore_missing_paths=ignore_missing_paths,
                )
            ),
        )
        if ignore_missing_paths and len(self._paths) == 0:
            raise ValueError(
                "None of the provided paths exist. "
                "The 'ignore_missing_paths' field is set to True."
            )

        if self._partition_filter is not None:
            # Use partition filter to skip files which are not needed.
            path_to_size = dict(zip(self._paths, self._file_sizes))
            self._paths = self._partition_filter(self._paths)
            self._file_sizes = [path_to_size[p] for p in self._paths]
            if len(self._paths) == 0:
                raise ValueError(
                    "No input files found to read. Please double check that "
                    "'partition_filter' field is set properly."
                )

    def estimate_inmemory_data_size(self) -> Optional[int]:
        total_size = 0
        for sz in self._file_sizes:
            if sz is not None:
                total_size += sz
        return total_size

    def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
        import numpy as np

        ctx = DataContext.get_current()
        open_stream_args = self._open_stream_args
        reader_args = self._reader_args
        partitioning = self._partitioning
        _block_udf = self._block_udf

        paths, file_sizes = self._paths, self._file_sizes
        read_stream = self._delegate._read_stream
        convert_block_to_tabular_block = self._delegate._convert_block_to_tabular_block
        column_name = reader_args.get("column_name", None)
        filesystem = _wrap_s3_serialization_workaround(self._filesystem)

        if open_stream_args is None:
            open_stream_args = {}

        open_input_source = self._delegate._open_input_source

        def read_files(
            read_paths: List[str],
            fs: Union["pyarrow.fs.FileSystem", _S3FileSystemWrapper],
        ) -> Iterable[Block]:
            DataContext._set_current(ctx)
            logger.debug(f"Reading {len(read_paths)} files.")
            fs = _unwrap_s3_serialization_workaround(filesystem)
            output_buffer = BlockOutputBuffer(
                block_udf=_block_udf, target_max_block_size=ctx.target_max_block_size
            )
            for read_path in read_paths:
                compression = open_stream_args.pop("compression", None)
                if compression is None:
                    import pyarrow as pa

                    try:
                        # If no compression manually given, try to detect
                        # compression codec from path.
                        compression = pa.Codec.detect(read_path).name
                    except (ValueError, TypeError):
                        # Arrow's compression inference on the file path
                        # doesn't work for Snappy, so we double-check ourselves.
                        import pathlib

                        suffix = pathlib.Path(read_path).suffix
                        if suffix and suffix[1:] == "snappy":
                            compression = "snappy"
                        else:
                            compression = None
                if compression == "snappy":
                    # Pass Snappy compression as a reader arg, so datasource
                    # subclasses can manually handle streaming decompression in
                    # self._delegate._read_stream().
                    reader_args["compression"] = compression
                    reader_args["filesystem"] = fs
                elif compression is not None:
                    # Non-Snappy compression, pass as open_input_stream() arg so
                    # Arrow can take care of streaming decompression for us.
                    open_stream_args["compression"] = compression

                partitions: Dict[str, str] = {}
                if partitioning is not None:
                    parse = PathPartitionParser(partitioning)
                    partitions = parse(read_path)

                with open_input_source(fs, read_path, **open_stream_args) as f:
                    for data in read_stream(f, read_path, **reader_args):
                        if partitions:
                            data = convert_block_to_tabular_block(data, column_name)
                            data = _add_partitions(data, partitions)
                        output_buffer.add_block(data)
                        if output_buffer.has_next():
                            yield output_buffer.next()
            output_buffer.finalize()
            if output_buffer.has_next():
                yield output_buffer.next()

        # fix https://github.com/ray-project/ray/issues/24296
        parallelism = min(parallelism, len(paths))

        read_tasks = []
        for read_paths, file_sizes in zip(
            np.array_split(paths, parallelism), np.array_split(file_sizes, parallelism)
        ):
            if len(read_paths) <= 0:
                continue

            meta = self._meta_provider(
                read_paths,
                self._schema,
                rows_per_file=self._delegate._rows_per_file(),
                file_sizes=file_sizes,
            )
            read_task = ReadTask(
                lambda read_paths=read_paths: read_files(read_paths, filesystem), meta
            )
            read_tasks.append(read_task)

        return read_tasks


def _add_partitions(
    data: Union["pyarrow.Table", "pd.DataFrame"], partitions: Dict[str, Any]
) -> Union["pyarrow.Table", "pd.DataFrame"]:
    import pandas as pd
    import pyarrow as pa

    assert isinstance(data, (pa.Table, pd.DataFrame))
    if isinstance(data, pa.Table):
        return _add_partitions_to_table(data, partitions)
    if isinstance(data, pd.DataFrame):
        return _add_partitions_to_dataframe(data, partitions)


def _add_partitions_to_table(
    table: "pyarrow.Table", partitions: Dict[str, Any]
) -> "pyarrow.Table":
    import pyarrow as pa
    import pyarrow.compute as pc

    column_names = set(table.column_names)
    for field, value in partitions.items():
        column = pa.array([value] * len(table))
        if field in column_names:
            # TODO: Handle cast error.
            column_type = table.schema.field(field).type
            column = column.cast(column_type)

            values_are_equal = pc.all(pc.equal(column, table[field]))
            values_are_equal = values_are_equal.as_py()
            if not values_are_equal:
                raise ValueError(
                    f"Partition column {field} exists in table data, but partition "
                    f"value '{value}' is different from in-data values: "
                    f"{table[field].unique().to_pylist()}."
                )

            i = table.schema.get_field_index(field)
            table = table.set_column(i, field, column)
        else:
            table = table.append_column(field, column)
    return table


def _add_partitions_to_dataframe(
    df: "pd.DataFrame", partitions: Dict[str, Any]
) -> "pd.DataFrame":
    import pandas as pd

    for field, value in partitions.items():
        column = pd.Series(data=[value] * len(df), name=field)

        if field in df:
            column = column.astype(df[field].dtype)
            mask = df[field].notna()
            if not df[field][mask].equals(column[mask]):
                raise ValueError(
                    f"Partition column {field} exists in table data, but partition "
                    f"value '{value}' is different from in-data values: "
                    f"{list(df[field].unique())}."
                )

        df[field] = column
    return df


# TODO(Clark): Add unit test coverage of _resolve_paths_and_filesystem and
# _expand_paths.
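

# Illustrative example (editor-added sketch, not part of the Ray module): shows
# how _add_partitions() injects path-partition values as constant columns. The
# column names and values are hypothetical.
def _example_add_partitions() -> "pyarrow.Table":
    import pyarrow as pa

    table = pa.table({"x": [1, 2, 3]})
    # Returns a table with columns "x" and "year", where "year" holds the
    # constant partition value "2024" for every row.
    return _add_partitions(table, {"year": "2024"})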


def _resolve_paths_and_filesystem(
    paths: Union[str, List[str]],
    filesystem: "pyarrow.fs.FileSystem" = None,
) -> Tuple[List[str], "pyarrow.fs.FileSystem"]:
    """
    Resolves and normalizes all provided paths, infers a filesystem from the
    paths, and ensures that all paths use the same filesystem.

    Args:
        paths: A single file/directory path or a list of file/directory paths.
            A list of paths can contain both files and directories.
        filesystem: The filesystem implementation that should be used for
            reading these files. If None, a filesystem will be inferred. If not
            None, the provided filesystem will still be validated against all
            filesystems inferred from the provided paths to ensure
            compatibility.
    """
    import pyarrow as pa
    from pyarrow.fs import (
        FileSystem,
        FSSpecHandler,
        PyFileSystem,
        _resolve_filesystem_and_path,
    )

    if isinstance(paths, str):
        paths = [paths]
    if isinstance(paths, pathlib.Path):
        paths = [str(paths)]
    elif not isinstance(paths, list) or any(not isinstance(p, str) for p in paths):
        raise ValueError("paths must be a path string or a list of path strings.")
    elif len(paths) == 0:
        raise ValueError("Must provide at least one path.")

    need_unwrap_path_protocol = True
    if filesystem and not isinstance(filesystem, FileSystem):
        err_msg = (
            f"The filesystem passed must either conform to "
            f"pyarrow.fs.FileSystem, or "
            f"fsspec.spec.AbstractFileSystem. The provided "
            f"filesystem was: {filesystem}"
        )
        try:
            import fsspec
            from fsspec.implementations.http import HTTPFileSystem
        except ModuleNotFoundError:
            # If filesystem is not a pyarrow filesystem and fsspec isn't
            # installed, then filesystem is neither a pyarrow filesystem nor
            # an fsspec filesystem, so we raise a TypeError.
            raise TypeError(err_msg) from None
        if not isinstance(filesystem, fsspec.spec.AbstractFileSystem):
            raise TypeError(err_msg) from None
        if isinstance(filesystem, HTTPFileSystem):
            # If filesystem is an fsspec HTTPFileSystem, the protocol/scheme of
            # paths should not be unwrapped/removed, because HTTPFileSystem
            # expects full file paths including protocol/scheme. This is
            # different behavior compared to filesystem implementations in
            # pyarrow.fs.FileSystem.
            need_unwrap_path_protocol = False

        filesystem = PyFileSystem(FSSpecHandler(filesystem))

    resolved_paths = []
    for path in paths:
        path = _resolve_custom_scheme(path)
        try:
            resolved_filesystem, resolved_path = _resolve_filesystem_and_path(
                path, filesystem
            )
        except pa.lib.ArrowInvalid as e:
            if "Cannot parse URI" in str(e):
                resolved_filesystem, resolved_path = _resolve_filesystem_and_path(
                    _encode_url(path), filesystem
                )
                resolved_path = _decode_url(resolved_path)
            elif "Unrecognized filesystem type in URI" in str(e):
                scheme = urllib.parse.urlparse(path, allow_fragments=False).scheme
                if scheme in ["http", "https"]:
                    # If the scheme of the path is HTTP and the filesystem is not
                    # resolved, try to use fsspec HTTPFileSystem. This expects
                    # fsspec to be installed.
                    try:
                        from fsspec.implementations.http import HTTPFileSystem
                    except ModuleNotFoundError:
                        raise ImportError(
                            "Please install fsspec to read files from HTTP."
                        ) from None

                    resolved_filesystem = PyFileSystem(FSSpecHandler(HTTPFileSystem()))
                    resolved_path = path
                    need_unwrap_path_protocol = False
                else:
                    raise
            else:
                raise

        if filesystem is None:
            filesystem = resolved_filesystem
        elif need_unwrap_path_protocol:
            resolved_path = _unwrap_protocol(resolved_path)
        resolved_path = filesystem.normalize_path(resolved_path)
        resolved_paths.append(resolved_path)

    return resolved_paths, filesystem


def _is_url(path) -> bool:
    return urllib.parse.urlparse(path).scheme != ""


def _encode_url(path):
    return urllib.parse.quote(path, safe="/:")


def _decode_url(path):
    return urllib.parse.unquote(path)


def _unwrap_protocol(path):
    """
    Slice off any protocol prefixes on path.
    """
    if sys.platform == "win32" and _is_local_windows_path(path):
        # Represent as posix path such that downstream functions properly handle it.
        # This is executed when 'file://' is NOT included in the path.
        return pathlib.Path(path).as_posix()

    parsed = urllib.parse.urlparse(path, allow_fragments=False)  # support '#' in path
    query = "?" + parsed.query if parsed.query else ""  # support '?' in path
    netloc = parsed.netloc
    if parsed.scheme == "s3" and "@" in parsed.netloc:
        # If the path contains an @, it is assumed to be an anonymous
        # credentialed path, and we need to strip off the credentials.
        netloc = parsed.netloc.split("@")[-1]

    parsed_path = parsed.path
    # urlparse prepends the path with a '/'. This does not work on Windows,
    # so if this is the case strip the leading slash.
    if (
        sys.platform == "win32"
        and not netloc
        and len(parsed_path) >= 3
        and parsed_path[0] == "/"  # The problematic leading slash
        and parsed_path[1].isalpha()  # Ensure it is a drive letter.
        and parsed_path[2:4] in (":", ":/")
    ):
        parsed_path = parsed_path[1:]

    return netloc + parsed_path + query


def _wrap_s3_serialization_workaround(filesystem: "pyarrow.fs.FileSystem"):
    # This is needed because pa.fs.S3FileSystem assumes pa.fs is already
    # imported before deserialization. See #17085.
    import pyarrow as pa
    import pyarrow.fs

    if isinstance(filesystem, pa.fs.S3FileSystem):
        return _S3FileSystemWrapper(filesystem)
    return filesystem


def _unwrap_s3_serialization_workaround(
    filesystem: Union["pyarrow.fs.FileSystem", "_S3FileSystemWrapper"]
):
    if isinstance(filesystem, _S3FileSystemWrapper):
        return filesystem.unwrap()
    else:
        return filesystem


class _S3FileSystemWrapper:
    def __init__(self, fs: "pyarrow.fs.S3FileSystem"):
        self._fs = fs

    def unwrap(self):
        return self._fs

    @classmethod
    def _reconstruct(cls, fs_reconstruct, fs_args):
        # Implicitly trigger S3 subsystem initialization by importing
        # pyarrow.fs.
        import pyarrow.fs  # noqa: F401

        return cls(fs_reconstruct(*fs_args))

    def __reduce__(self):
        return _S3FileSystemWrapper._reconstruct, self._fs.__reduce__()


def _wrap_arrow_serialization_workaround(kwargs: dict) -> dict:
    if "filesystem" in kwargs:
        kwargs["filesystem"] = _wrap_s3_serialization_workaround(kwargs["filesystem"])
    return kwargs


def _unwrap_arrow_serialization_workaround(kwargs: dict) -> dict:
    if isinstance(kwargs.get("filesystem"), _S3FileSystemWrapper):
        kwargs["filesystem"] = kwargs["filesystem"].unwrap()
    return kwargs


def _resolve_kwargs(
    kwargs_fn: Callable[[], Dict[str, Any]], **kwargs
) -> Dict[str, Any]:
    if kwargs_fn:
        kwarg_overrides = kwargs_fn()
        kwargs.update(kwarg_overrides)
    return kwargs


Uri = TypeVar("Uri")
Meta = TypeVar("Meta")


def _fetch_metadata_parallel(
    uris: List[Uri],
    fetch_func: Callable[[List[Uri]], List[Meta]],
    desired_uris_per_task: int,
    **ray_remote_args,
) -> Iterator[Meta]:
    """Fetch file metadata in parallel using Ray tasks."""
    remote_fetch_func = cached_remote_fn(fetch_func, num_cpus=0.5)
    if ray_remote_args:
        remote_fetch_func = remote_fetch_func.options(**ray_remote_args)
    # Choose a parallelism that results in a # of metadata fetches per task that
    # dominates the Ray task overhead while ensuring good parallelism.
    # Always launch at least 2 parallel fetch tasks.
    parallelism = max(len(uris) // desired_uris_per_task, 2)
    metadata_fetch_bar = ProgressBar("Metadata Fetch Progress", total=parallelism)
    fetch_tasks = []
    for uri_chunk in np.array_split(uris, parallelism):
        if len(uri_chunk) == 0:
            continue
        fetch_tasks.append(remote_fetch_func.remote(uri_chunk))
    results = metadata_fetch_bar.fetch_until_complete(fetch_tasks)
    yield from itertools.chain.from_iterable(results)
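

# Illustrative example (editor-added sketch, not part of the Ray module): shows
# what _unwrap_protocol() returns for a hypothetical S3 URI; the scheme is
# stripped while the bucket, key, and query string are preserved.
def _example_unwrap_protocol() -> str:
    # Returns "bucket/path/to/file.csv".
    return _unwrap_protocol("s3://bucket/path/to/file.csv")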