import itertools
import logging
import os
import pathlib
import re
from typing import (
TYPE_CHECKING,
Callable,
Iterator,
List,
Optional,
Tuple,
TypeVar,
Union,
)

import numpy as np
import ray
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.util import call_with_retry
from ray.data.block import BlockMetadata
from ray.data.datasource.partitioning import Partitioning
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
import pyarrow

logger = logging.getLogger(__name__)


def _handle_read_os_error(error: OSError, paths: Union[str, List[str]]) -> None:
    """Translate AWS-related read errors into a clearer error message.

    This function never returns normally: it either raises a more descriptive
    ``OSError`` or re-raises the original error.
    """
    # NOTE: This is not comprehensive yet and should be extended as more errors
    # arise.
    # NOTE: The latter patterns are raised by Arrow 10+, while the first is
    # raised by Arrow < 10.
aws_error_pattern = (
r"^(?:(.*)AWS Error \[code \d+\]: No response body\.(.*))|"
r"(?:(.*)AWS Error UNKNOWN \(HTTP status 400\) during HeadObject operation: "
r"No response body\.(.*))|"
r"(?:(.*)AWS Error ACCESS_DENIED during HeadObject operation: No response "
r"body\.(.*))$"
)
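    # For illustration, a message such as
    #   "... AWS Error ACCESS_DENIED during HeadObject operation: No response body."
    # (exact wording varies by Arrow version) is matched by the pattern above.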
if re.match(aws_error_pattern, str(error)):
        # Special-case AWS errors raised while reading files to give users a
        # clearer error message. The real issue is most likely that the AWS S3
        # credentials have not been configured properly yet.
        if isinstance(paths, str):
            # Quote a single file path to highlight it in the error message for
            # better readability. A list of file paths is already rendered as
            # ['foo', 'boo'], so only quote the single-path case here.
            paths = f'"{paths}"'
raise OSError(
(
f"Failing to read AWS S3 file(s): {paths}. "
"Please check that file exists and has properly configured access. "
"You can also run AWS CLI command to get more detailed error message "
"(e.g., aws s3 ls <file-name>). "
"See https://awscli.amazonaws.com/v2/documentation/api/latest/reference/s3/index.html " # noqa
"and https://docs.ray.io/en/latest/data/creating-datasets.html#reading-from-remote-storage " # noqa
"for more information."
)
)
else:
raise error


def _expand_paths(
paths: List[str],
filesystem: "pyarrow.fs.FileSystem",
partitioning: Optional[Partitioning],
ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
"""Get the file sizes for all provided file paths."""
from pyarrow.fs import LocalFileSystem
from ray.data.datasource.file_based_datasource import (
FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD,
)
from ray.data.datasource.path_util import _unwrap_protocol
    # We break path expansion down into a few key cases:
    # 1. If len(paths) < threshold, fetch the file info for the individual
    #    files/paths serially.
    # 2. If all paths are contained under the same parent directory (or base
    #    directory, if using partitioning), fetch all file infos at that prefix and
    #    filter the response down to the provided paths on the client; this should
    #    be a single file info request. See the illustrative example below.
    # 3. If more than the threshold number of requests is required, parallelize
    #    them via Ray tasks.
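    # For example (illustrative paths), ["s3://bucket/dir/a.csv",
    # "s3://bucket/dir/b.csv"] share the parent "s3://bucket/dir", so a single
    # file info request on that prefix covers both paths (case 2).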
# 1. Small # of paths case.
if (
len(paths) < FILE_SIZE_FETCH_PARALLELIZATION_THRESHOLD
# Local file systems are very fast to hit.
or isinstance(filesystem, LocalFileSystem)
):
yield from _get_file_infos_serial(paths, filesystem, ignore_missing_paths)
else:
# 2. Common path prefix case.
# Get longest common path of all paths.
common_path = os.path.commonpath(paths)
# If parent directory (or base directory, if using partitioning) is common to
# all paths, fetch all file infos at that prefix and filter the response to the
# provided paths.
if (
partitioning is not None
and common_path == _unwrap_protocol(partitioning.base_dir)
) or all(str(pathlib.Path(path).parent) == common_path for path in paths):
yield from _get_file_infos_common_path_prefix(
paths, common_path, filesystem, ignore_missing_paths
)
# 3. Parallelization case.
else:
# Parallelize requests via Ray tasks.
yield from _get_file_infos_parallel(paths, filesystem, ignore_missing_paths)


def _get_file_infos_serial(
paths: List[str],
filesystem: "pyarrow.fs.FileSystem",
ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
for path in paths:
yield from _get_file_infos(path, filesystem, ignore_missing_paths)


def _get_file_infos_common_path_prefix(
paths: List[str],
common_path: str,
filesystem: "pyarrow.fs.FileSystem",
ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
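    # Seed every requested path with ``None`` so that paths whose size is not
    # returned under the common prefix (e.g. directory paths) can be detected
    # below.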
path_to_size = {path: None for path in paths}
for path, file_size in _get_file_infos(
common_path, filesystem, ignore_missing_paths
):
if path in path_to_size:
path_to_size[path] = file_size
    # Check whether all `paths` have file size metadata.
    # If any path is missing a file size, fall back to fetching file metadata in
    # parallel. This can happen when a path is a directory rather than a file.
have_missing_path = False
for path in paths:
if path_to_size[path] is None:
logger.debug(
f"Finding path {path} not have file size metadata. "
"Fall back to get files metadata in parallel for all paths."
)
have_missing_path = True
break
if have_missing_path:
# Parallelize requests via Ray tasks.
yield from _get_file_infos_parallel(paths, filesystem, ignore_missing_paths)
else:
        # Iterate over `paths` to yield each path in its original order.
        # NOTE: do not iterate over `path_to_size`, because the dictionary
        # deduplicates paths, while `paths` may contain duplicates if the caller
        # wants to read the same file multiple times.
for path in paths:
yield path, path_to_size[path]


def _get_file_infos_parallel(
paths: List[str],
filesystem: "pyarrow.fs.FileSystem",
ignore_missing_paths: bool = False,
) -> Iterator[Tuple[str, int]]:
from ray.data.datasource.file_based_datasource import (
PATHS_PER_FILE_SIZE_FETCH_TASK,
_unwrap_s3_serialization_workaround,
_wrap_s3_serialization_workaround,
)

    logger.warning(
        f"Expanding {len(paths)} path(s). This may be a HIGH LATENCY "
        "operation on some cloud storage services. Moving all the "
        "paths to a common parent directory will lead to faster "
        "metadata fetching."
    )
# Capture the filesystem in the fetcher func closure, but wrap it in our
# serialization workaround to make sure that the pickle roundtrip works as expected.
filesystem = _wrap_s3_serialization_workaround(filesystem)

    def _file_infos_fetcher(paths: List[str]) -> List[Tuple[str, int]]:
fs = _unwrap_s3_serialization_workaround(filesystem)
return list(
itertools.chain.from_iterable(
_get_file_infos(path, fs, ignore_missing_paths) for path in paths
)
)
yield from _fetch_metadata_parallel(
paths, _file_infos_fetcher, PATHS_PER_FILE_SIZE_FETCH_TASK
)


Uri = TypeVar("Uri")
Meta = TypeVar("Meta")


def _fetch_metadata_parallel(
uris: List[Uri],
fetch_func: Callable[[List[Uri]], List[Meta]],
desired_uris_per_task: int,
**ray_remote_args,
) -> Iterator[Meta]:
"""Fetch file metadata in parallel using Ray tasks."""
remote_fetch_func = cached_remote_fn(fetch_func)
if ray_remote_args:
remote_fetch_func = remote_fetch_func.options(**ray_remote_args)
    # Choose a parallelism such that the number of metadata fetches per task
    # dominates the per-task Ray overhead while still ensuring good parallelism.
    # Always launch at least 2 parallel fetch tasks.
parallelism = max(len(uris) // desired_uris_per_task, 2)
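    # For example (illustrative numbers), 10_000 URIs with
    # desired_uris_per_task=16 yields parallelism = 625, i.e. roughly 16 file
    # info fetches per task.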
metadata_fetch_bar = ProgressBar(
"Metadata Fetch Progress", total=parallelism, unit="task"
)
fetch_tasks = []
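    # NOTE: np.array_split may return empty chunks when parallelism exceeds
    # len(uris); skip those rather than launching no-op fetch tasks.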
for uri_chunk in np.array_split(uris, parallelism):
if len(uri_chunk) == 0:
continue
fetch_tasks.append(remote_fetch_func.remote(uri_chunk))
results = metadata_fetch_bar.fetch_until_complete(fetch_tasks)
yield from itertools.chain.from_iterable(results)


def _get_file_infos(
path: str, filesystem: "pyarrow.fs.FileSystem", ignore_missing_path: bool = False
) -> List[Tuple[str, int]]:
"""Get the file info for all files at or under the provided path."""
from pyarrow.fs import FileType
file_infos = []
try:
ctx = ray.data.DataContext.get_current()
file_info = call_with_retry(
lambda: filesystem.get_file_info(path),
description="get file info",
match=ctx.retried_io_errors,
)
except OSError as e:
_handle_read_os_error(e, path)
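    # NOTE: `_handle_read_os_error` always raises, so `file_info` is guaranteed
    # to be bound whenever execution reaches this point.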
if file_info.type == FileType.Directory:
for file_path, file_size in _expand_directory(path, filesystem):
file_infos.append((file_path, file_size))
elif file_info.type == FileType.File:
file_infos.append((path, file_info.size))
elif file_info.type == FileType.NotFound and ignore_missing_path:
pass
else:
raise FileNotFoundError(path)
return file_infos


def _expand_directory(
path: str,
filesystem: "pyarrow.fs.FileSystem",
exclude_prefixes: Optional[List[str]] = None,
ignore_missing_path: bool = False,
) -> List[Tuple[str, int]]:
"""
Expand the provided directory path to a list of file paths.
Args:
path: The directory path to expand.
filesystem: The filesystem implementation that should be used for
reading these files.
exclude_prefixes: The file relative path prefixes that should be
excluded from the returned file set. Default excluded prefixes are
"." and "_".
Returns:
An iterator of (file_path, file_size) tuples.
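
    Example (an illustrative sketch; paths and sizes are hypothetical):

        >>> from pyarrow.fs import LocalFileSystem  # doctest: +SKIP
        >>> _expand_directory("/tmp/data", LocalFileSystem())  # doctest: +SKIP
        [('/tmp/data/part-0.csv', 1024), ('/tmp/data/part-1.csv', 2048)]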
"""
if exclude_prefixes is None:
exclude_prefixes = [".", "_"]
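    # With the defaults, entries whose names start with "." or "_"
    # (e.g. a "_SUCCESS" marker file; illustrative example) are filtered out
    # below.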
from pyarrow.fs import FileSelector
selector = FileSelector(path, recursive=True, allow_not_found=ignore_missing_path)
files = filesystem.get_file_info(selector)
base_path = selector.base_dir
out = []
for file_ in files:
if not file_.is_file:
continue
file_path = file_.path
if not file_path.startswith(base_path):
continue
        # Strip the base path (and any leading path separator) so that the
        # prefix check below is applied to the path relative to the directory.
        relative = file_path[len(base_path) :].lstrip("/")
        if any(relative.startswith(prefix) for prefix in exclude_prefixes):
            continue
out.append((file_path, file_.size))
# We sort the paths to guarantee a stable order.
return sorted(out)