import io
import logging
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
Optional,
Union,
)
import numpy as np
import ray
from ray.data._internal.util import (
RetryingContextManager,
RetryingPyFileSystem,
_check_pyarrow_version,
_is_local_scheme,
infer_compression,
iterate_with_retry,
make_async_gen,
)
from ray.data.block import Block, BlockAccessor
from ray.data.context import DataContext
from ray.data.datasource.datasource import Datasource, ReadTask
from ray.data.datasource.file_meta_provider import (
BaseFileMetadataProvider,
DefaultFileMetadataProvider,
)
from ray.data.datasource.partitioning import (
Partitioning,
PathPartitionFilter,
PathPartitionParser,
)
from ray.data.datasource.path_util import (
_has_file_extension,
_resolve_paths_and_filesystem,
)
from ray.util.annotations import DeveloperAPI
if TYPE_CHECKING:
import pandas as pd
import pyarrow
logger = logging.getLogger(__name__)
# We should parallelize file size fetch operations beyond this threshold.
FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD = 16
# 16 file size fetches from S3 take ~1.5 seconds with Arrow's S3FileSystem.
PATHS_PER_FILE_SIZE_FETCH_TASK = 16
@DeveloperAPI
@dataclass
class FileShuffleConfig:
"""Configuration for file shuffling.
This configuration object controls how files are shuffled while reading file-based
datasets.
.. note::
Even if you provide a seed, you might still observe a non-deterministic row
order. This is because tasks are executed in parallel and their completion
order might vary. If you need to preserve the order of rows, set
`DataContext.get_current().execution_options.preserve_order`.
Args:
seed: An optional integer seed for the file shuffler. If provided, Ray Data
shuffles files deterministically based on this seed.
Example:
>>> import ray
>>> from ray.data import FileShuffleConfig
>>> shuffle = FileShuffleConfig(seed=42)
>>> ds = ray.data.read_images("s3://anonymous@ray-example-data/batoidea", shuffle=shuffle)
""" # noqa: E501
seed: Optional[int] = None
def __post_init__(self):
"""Ensure that the seed is either None or an integer."""
if self.seed is not None and not isinstance(self.seed, int):
raise ValueError("Seed must be an integer or None.")
@DeveloperAPI
class FileBasedDatasource(Datasource):
"""File-based datasource for reading files.
Don't use this class directly. Instead, subclass it and implement `_read_stream()`.
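A minimal sketch of a subclass (the class name, file extension, and column
name below are illustrative, not part of the Ray API):

.. code-block:: python

    class TextLineDatasource(FileBasedDatasource):
        _FILE_EXTENSIONS = ["txt"]

        def _read_stream(self, f: "pyarrow.NativeFile", path: str):
            import pyarrow as pa

            # Read the whole file and emit one block with a single column.
            text = f.readall().decode("utf-8")
            lines = [line for line in text.splitlines() if line]
            yield pa.table({"text": lines})

    # Hypothetical usage; any path readable by the configured filesystem works.
    ds = ray.data.read_datasource(TextLineDatasource("example.txt"))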
"""
# If `_WRITE_FILE_PER_ROW` is `True`, this datasource calls `_write_row` and writes
# each row to a file. Otherwise, this datasource calls `_write_block` and writes
# each block to a file.
_WRITE_FILE_PER_ROW = False
_FILE_EXTENSIONS: Optional[Union[str, List[str]]] = None
# Number of threads for concurrent reading within each read task.
# If zero or negative, reading will be performed in the main thread.
_NUM_THREADS_PER_TASK = 0
def __init__(
self,
paths: Union[str, List[str]],
*,
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
schema: Optional[Union[type, "pyarrow.lib.Schema"]] = None,
open_stream_args: Optional[Dict[str, Any]] = None,
meta_provider: BaseFileMetadataProvider = DefaultFileMetadataProvider(),
partition_filter: PathPartitionFilter = None,
partitioning: Partitioning = None,
ignore_missing_paths: bool = False,
shuffle: Optional[Union[Literal["files"], FileShuffleConfig]] = None,
include_paths: bool = False,
file_extensions: Optional[List[str]] = None,
):
super().__init__()
_check_pyarrow_version()
self._supports_distributed_reads = not _is_local_scheme(paths)
if not self._supports_distributed_reads and ray.util.client.ray.is_connected():
raise ValueError(
"Because you're using Ray Client, read tasks scheduled on the Ray "
"cluster can't access your local files. To fix this issue, store "
"files in cloud storage or a distributed filesystem like NFS."
)
self._schema = schema
self._data_context = DataContext.get_current()
self._open_stream_args = open_stream_args
self._meta_provider = meta_provider
self._partition_filter = partition_filter
self._partitioning = partitioning
self._ignore_missing_paths = ignore_missing_paths
self._include_paths = include_paths
# Need this property for lineage tracking
self._source_paths = paths
paths, self._filesystem = _resolve_paths_and_filesystem(paths, filesystem)
self._filesystem = RetryingPyFileSystem.wrap(
self._filesystem, retryable_errors=self._data_context.retried_io_errors
)
paths, file_sizes = map(
list,
zip(
*meta_provider.expand_paths(
paths,
self._filesystem,
partitioning,
ignore_missing_paths=ignore_missing_paths,
)
),
)
if ignore_missing_paths and len(paths) == 0:
raise ValueError(
"None of the provided paths exist. "
"The 'ignore_missing_paths' field is set to True."
)
if self._partition_filter is not None:
# Use partition filter to skip files which are not needed.
path_to_size = dict(zip(paths, file_sizes))
paths = self._partition_filter(paths)
file_sizes = [path_to_size[p] for p in paths]
if len(paths) == 0:
raise ValueError(
"No input files found to read. Please double check that "
"'partition_filter' field is set properly."
)
if file_extensions is not None:
path_to_size = dict(zip(paths, file_sizes))
paths = [p for p in paths if _has_file_extension(p, file_extensions)]
file_sizes = [path_to_size[p] for p in paths]
if len(paths) == 0:
raise ValueError(
"No input files found to read with the following file extensions: "
f"{file_extensions}. Please double check that "
"'file_extensions' field is set properly."
)
_validate_shuffle_arg(shuffle)
self._file_metadata_shuffler = None
if shuffle == "files":
self._file_metadata_shuffler = np.random.default_rng()
elif isinstance(shuffle, FileShuffleConfig):
# Create a NumPy random generator with a fixed seed if provided
self._file_metadata_shuffler = np.random.default_rng(shuffle.seed)
# Read tasks serialize `FileBasedDatasource` instances, and the list of paths
# can be large. To avoid slow serialization speeds, we store a reference to
# the paths rather than the paths themselves.
self._paths_ref = ray.put(paths)
self._file_sizes_ref = ray.put(file_sizes)
def _paths(self) -> List[str]:
return ray.get(self._paths_ref)
def _file_sizes(self) -> List[float]:
return ray.get(self._file_sizes_ref)
def estimate_inmemory_data_size(self) -> Optional[int]:
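# Estimate based on the known on-disk file sizes; entries with unknown size
# are skipped, so this is a lower bound rather than a true in-memory figure.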
total_size = 0
for sz in self._file_sizes():
if sz is not None:
total_size += sz
return total_size
def get_read_tasks(
self, parallelism: int, per_task_row_limit: Optional[int] = None
) -> List[ReadTask]:
import numpy as np
open_stream_args = self._open_stream_args
partitioning = self._partitioning
paths = self._paths()
file_sizes = self._file_sizes()
if self._file_metadata_shuffler is not None:
files_metadata = list(zip(paths, file_sizes))
shuffled_files_metadata = [
files_metadata[i]
for i in self._file_metadata_shuffler.permutation(len(files_metadata))
]
paths, file_sizes = list(map(list, zip(*shuffled_files_metadata)))
filesystem = _wrap_s3_serialization_workaround(self._filesystem)
if open_stream_args is None:
open_stream_args = {}
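# Generator executed inside each read task: for every assigned path, parse
# partition values from the path (if partitioning is configured), open the
# file with retries, and yield the blocks produced by `_read_stream`.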
def read_files(
read_paths: Iterable[str],
) -> Iterable[Block]:
nonlocal filesystem, open_stream_args, partitioning
fs = _unwrap_s3_serialization_workaround(filesystem)
for read_path in read_paths:
partitions: Dict[str, str] = {}
if partitioning is not None:
parse = PathPartitionParser(partitioning)
partitions = parse(read_path)
with RetryingContextManager(
self._open_input_source(fs, read_path, **open_stream_args),
context=self._data_context,
) as f:
for block in iterate_with_retry(
lambda: self._read_stream(f, read_path),
description="read stream iteratively",
match=self._data_context.retried_io_errors,
):
if partitions:
block = _add_partitions(block, partitions)
if self._include_paths:
block_accessor = BlockAccessor.for_block(block)
block = block_accessor.fill_column("path", read_path)
yield block
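# Build the callable that a single read task executes. When threaded reads are
# enabled (and row order doesn't need to be preserved), paths are read
# concurrently via `make_async_gen`; otherwise they are read sequentially.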
def create_read_task_fn(read_paths, num_threads):
def read_task_fn():
nonlocal num_threads, read_paths
# TODO: We should refactor the code so that we can get the results in
# order even when using multiple threads.
if self._data_context.execution_options.preserve_order:
num_threads = 0
if num_threads > 0:
num_threads = min(num_threads, len(read_paths))
logger.debug(
f"Reading {len(read_paths)} files with {num_threads} threads."
)
yield from make_async_gen(
iter(read_paths),
read_files,
num_workers=num_threads,
preserve_ordering=True,
)
else:
logger.debug(f"Reading {len(read_paths)} files.")
yield from read_files(read_paths)
return read_task_fn
# fix https://github.com/ray-project/ray/issues/24296
parallelism = min(parallelism, len(paths))
read_tasks = []
split_paths = np.array_split(paths, parallelism)
split_file_sizes = np.array_split(file_sizes, parallelism)
for read_paths, file_sizes in zip(split_paths, split_file_sizes):
if len(read_paths) <= 0:
continue
meta = self._meta_provider(
read_paths,
rows_per_file=self._rows_per_file(),
file_sizes=file_sizes,
)
read_task_fn = create_read_task_fn(read_paths, self._NUM_THREADS_PER_TASK)
read_task = ReadTask(
read_task_fn, meta, per_task_row_limit=per_task_row_limit
)
read_tasks.append(read_task)
return read_tasks
def resolve_compression(
self, path: str, open_args: Dict[str, Any]
) -> Optional[str]:
"""Resolves the compression format for a stream.
Args:
path: The file path to resolve compression for.
open_args: kwargs passed to
`pyarrow.fs.FileSystem.open_input_stream <https://arrow.apache.org/docs/python/generated/pyarrow.fs.FileSystem.html#pyarrow.fs.FileSystem.open_input_stream>`_
when opening input files to read.
Returns:
The compression format (e.g., "gzip", "snappy", "bz2") or None if
no compression is detected or specified.
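For example, for a path like "data/file.csv.gz" with no explicit
"compression" entry in `open_args`, the inferred format is typically "gzip".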
"""
compression = open_args.get("compression", None)
if compression is None:
compression = infer_compression(path)
return compression
def _resolve_buffer_size(self, open_args: Dict[str, Any]) -> Optional[int]:
buffer_size = open_args.pop("buffer_size", None)
if buffer_size is None:
buffer_size = self._data_context.streaming_read_buffer_size
return buffer_size
def _file_to_snappy_stream(
self,
file: "pyarrow.NativeFile",
filesystem: "RetryingPyFileSystem",
) -> "pyarrow.PythonFile":
import pyarrow as pa
import snappy
from pyarrow.fs import HadoopFileSystem
stream = io.BytesIO()
if isinstance(filesystem.unwrap(), HadoopFileSystem):
snappy.hadoop_snappy.stream_decompress(src=file, dst=stream)
else:
snappy.stream_decompress(src=file, dst=stream)
stream.seek(0)
return pa.PythonFile(stream, mode="r")
def _open_input_source(
self,
filesystem: "RetryingPyFileSystem",
path: str,
**open_args,
) -> "pyarrow.NativeFile":
"""Opens a source path for reading and returns the associated Arrow NativeFile.
The default implementation opens the source path as a sequential input stream,
using self._data_context.streaming_read_buffer_size as the buffer size if none
is given by the caller.
Implementations that do not support streaming reads (e.g. that require random
access) should override this method.
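A sketch of such an override (assuming the wrapped filesystem exposes
`open_input_file`, as pyarrow filesystems do):

.. code-block:: python

    def _open_input_source(self, filesystem, path, **open_args):
        # Random-access formats (e.g. ones that store a footer) need a
        # seekable file rather than a sequential stream.
        return filesystem.open_input_file(path, **open_args)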
"""
compression = self.resolve_compression(path, open_args)
buffer_size = self._resolve_buffer_size(open_args)
if compression == "snappy":
# Arrow doesn't support streaming Snappy decompression since the canonical
# C++ Snappy library doesn't natively support streaming decompression. We
# work around this by manually decompressing the file with python-snappy.
open_args["compression"] = None
file = filesystem.open_input_stream(
path, buffer_size=buffer_size, **open_args
)
return self._file_to_snappy_stream(file, filesystem)
open_args["compression"] = compression
return filesystem.open_input_stream(path, buffer_size=buffer_size, **open_args)
def _rows_per_file(self):
"""Returns the number of rows per file, or None if unknown."""
return None
def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
"""Streaming read a single file.
This method should be implemented by subclasses.
"""
raise NotImplementedError(
"Subclasses of FileBasedDatasource must implement _read_stream()."
)
@property
def supports_distributed_reads(self) -> bool:
return self._supports_distributed_reads
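# Module-level helper used by `read_files` above: append each parsed partition
# key/value pair (e.g. {"year": "2024"} from a ".../year=2024/..." path) to the
# block as a constant column, or validate it against an existing column of the
# same name.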
def _add_partitions(
data: Union["pyarrow.Table", "pd.DataFrame"], partitions: Dict[str, Any]
) -> Union["pyarrow.Table", "pd.DataFrame"]:
import pandas as pd
import pyarrow as pa
assert isinstance(data, (pa.Table, pd.DataFrame))
if isinstance(data, pa.Table):
return _add_partitions_to_table(data, partitions)
if isinstance(data, pd.DataFrame):
return _add_partitions_to_dataframe(data, partitions)
def _add_partitions_to_table(
table: "pyarrow.Table", partitions: Dict[str, Any]
) -> "pyarrow.Table":
import pyarrow as pa
import pyarrow.compute as pc
column_names = set(table.column_names)
for field, value in partitions.items():
column = pa.array([value] * len(table))
if field in column_names:
# TODO: Handle cast error.
column_type = table.schema.field(field).type
column = column.cast(column_type)
values_are_equal = pc.all(pc.equal(column, table[field]))
values_are_equal = values_are_equal.as_py()
if not values_are_equal:
raise ValueError(
f"Partition column {field} exists in table data, but partition "
f"value '{value}' is different from in-data values: "
f"{table[field].unique().to_pylist()}."
)
i = table.schema.get_field_index(field)
table = table.set_column(i, field, column)
else:
table = table.append_column(field, column)
return table
def _add_partitions_to_dataframe(
df: "pd.DataFrame", partitions: Dict[str, Any]
) -> "pd.DataFrame":
import pandas as pd
for field, value in partitions.items():
column = pd.Series(data=[value] * len(df), name=field)
if field in df:
column = column.astype(df[field].dtype)
mask = df[field].notna()
if not df[field][mask].equals(column[mask]):
raise ValueError(
f"Partition column {field} exists in table data, but partition "
f"value '{value}' is different from in-data values: "
f"{list(df[field].unique())}."
)
df[field] = column
return df
def _wrap_s3_serialization_workaround(filesystem: "pyarrow.fs.FileSystem"):
# This is needed because pa.fs.S3FileSystem assumes pa.fs is already
# imported before deserialization. See #17085.
import pyarrow as pa
import pyarrow.fs
base_fs = filesystem
if isinstance(filesystem, RetryingPyFileSystem):
base_fs = filesystem.unwrap()
if isinstance(base_fs, pa.fs.S3FileSystem):
return _S3FileSystemWrapper(filesystem)
return filesystem
def _unwrap_s3_serialization_workaround(
filesystem: Union["pyarrow.fs.FileSystem", "_S3FileSystemWrapper"],
):
if isinstance(filesystem, _S3FileSystemWrapper):
filesystem = filesystem.unwrap()
return filesystem
class _S3FileSystemWrapper:
"""pyarrow.fs.S3FileSystem wrapper that can be deserialized safely.
Importing pyarrow.fs during reconstruction triggers the pyarrow
S3 subsystem initialization.
NOTE: This is only needed for pyarrow<14.0.0 and should be removed
once the minimum supported pyarrow version exceeds that.
See https://github.com/apache/arrow/pull/38375 for context.
"""
def __init__(self, fs: "pyarrow.fs.FileSystem"):
self._fs = fs
def unwrap(self):
return self._fs
@classmethod
def _reconstruct(cls, fs_reconstruct, fs_args):
# Implicitly trigger S3 subsystem initialization by importing
# pyarrow.fs.
import pyarrow.fs # noqa: F401
return cls(fs_reconstruct(*fs_args))
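# Pickle the wrapped filesystem with its own __reduce__, but route
# reconstruction through `_reconstruct` so that `pyarrow.fs` is imported on the
# deserializing worker before the S3 filesystem is rebuilt.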
def __reduce__(self):
return _S3FileSystemWrapper._reconstruct, self._fs.__reduce__()
def _resolve_kwargs(
kwargs_fn: Callable[[], Dict[str, Any]], **kwargs
) -> Dict[str, Any]:
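"""Merge keyword overrides produced by ``kwargs_fn`` (if provided) into ``kwargs``."""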
if kwargs_fn:
kwarg_overrides = kwargs_fn()
kwargs.update(kwarg_overrides)
return kwargs
def _validate_shuffle_arg(
shuffle: Union[Literal["files"], FileShuffleConfig, None],
) -> None:
if not (
shuffle is None or shuffle == "files" or isinstance(shuffle, FileShuffleConfig)
):
raise ValueError(
f"Invalid value for 'shuffle': {shuffle}. "
"Valid values are None, 'files', `FileShuffleConfig`."
)