import logging
import struct
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union
import pyarrow
from ray.air.util.tensor_extensions.arrow import pyarrow_table_from_pydict
from ray.data.aggregate import AggregateFn
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
import pandas as pd
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2
from ray.data.dataset import Dataset
logger = logging.getLogger(__name__)
@PublicAPI(stability="alpha")
@dataclass
class TFXReadOptions:
"""
Specifies read options when reading TFRecord files with TFX.
"""
# An int representing the number of consecutive elements of
# this dataset to combine in a single batch when tfx-bsl is used to read
# the tfrecord files.
batch_size: int = 2048
# Toggles the schema inference applied; applicable
# only if tfx-bsl is used and tf_schema argument is missing.
# Defaults to True.
auto_infer_schema: bool = True
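# A minimal usage sketch for these options, assuming the public
# `ray.data.read_tfrecords` API forwards a `tfx_read_options` argument to this
# datasource (the bucket path below is hypothetical):
#
#     import ray
#
#     ds = ray.data.read_tfrecords(
#         "s3://my-bucket/data/*.tfrecords",
#         tfx_read_options=TFXReadOptions(batch_size=1024, auto_infer_schema=False),
#     )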
class TFRecordDatasource(FileBasedDatasource):
"""TFRecord datasource, for reading and writing TFRecord files."""
_FILE_EXTENSIONS = ["tfrecords"]
def __init__(
self,
paths: Union[str, List[str]],
tf_schema: Optional["schema_pb2.Schema"] = None,
tfx_read_options: Optional["TFXReadOptions"] = None,
**file_based_datasource_kwargs,
):
"""
Args:
            tf_schema: Optional TensorFlow Schema used to explicitly set the
                schema of the underlying Dataset.
            tfx_read_options: Optional read options; if provided, TFRecord
                files are read using tfx-bsl.
"""
super().__init__(paths, **file_based_datasource_kwargs)
self._tf_schema = tf_schema
self._tfx_read_options = tfx_read_options
def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
if self._tfx_read_options:
yield from self._tfx_read_stream(f, path)
else:
yield from self._default_read_stream(f, path)
def _default_read_stream(
self, f: "pyarrow.NativeFile", path: str
) -> Iterator[Block]:
import tensorflow as tf
from google.protobuf.message import DecodeError
for record in _read_records(f, path):
example = tf.train.Example()
try:
example.ParseFromString(record)
except DecodeError as e:
raise ValueError(
"`TFRecordDatasource` failed to parse `tf.train.Example` "
f"record in '{path}'. This error can occur if your TFRecord "
f"file contains a message type other than `tf.train.Example`: {e}"
)
yield pyarrow_table_from_pydict(
_convert_example_to_dict(example, self._tf_schema)
)
def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
import tensorflow as tf
from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder
full_path = self._resolve_full_path(path)
compression = (self._open_stream_args or {}).get("compression", None)
if compression:
compression = compression.upper()
tf_schema_string = (
self._tf_schema.SerializeToString() if self._tf_schema else None
)
decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
exception_thrown = None
try:
for record in tf.data.TFRecordDataset(
full_path, compression_type=compression
).batch(self._tfx_read_options.batch_size):
yield _cast_large_list_to_list(
pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
)
except Exception as error:
logger.exception(f"Failed to read TFRecord file {full_path}")
exception_thrown = error
        # We need this hack of raising the exception outside of the except
        # block because TensorFlow's DataLossError is unpicklable; even if we
        # raised a RuntimeError from inside the except block, Ray would keep
        # information about the original error, which would still make it
        # unpicklable.
if exception_thrown:
raise RuntimeError(f"Failed to read TFRecord file {full_path}.")
def _resolve_full_path(self, relative_path):
if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
return f"s3://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
return f"gs://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
return f"hdfs:///{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
protocol = self._filesystem.handler.fs.protocol
if isinstance(protocol, list) or isinstance(protocol, tuple):
protocol = protocol[0]
if protocol == "gcs":
protocol = "gs"
return f"{protocol}://{relative_path}"
return relative_path
def _convert_example_to_dict(
example: "tf.train.Example",
tf_schema: Optional["schema_pb2.Schema"],
) -> Dict[str, pyarrow.Array]:
record = {}
schema_dict = {}
# Convert user-specified schema into dict for convenient mapping
if tf_schema is not None:
for schema_feature in tf_schema.feature:
schema_dict[schema_feature.name] = schema_feature.type
for feature_name, feature in example.features.feature.items():
if tf_schema is not None and feature_name not in schema_dict:
raise ValueError(
f"Found extra unexpected feature {feature_name} "
f"not in specified schema: {tf_schema}"
)
schema_feature_type = schema_dict.get(feature_name)
record[feature_name] = _get_feature_value(feature, schema_feature_type)
return record
def _get_single_true_type(dct) -> str:
"""Utility function for getting the single key which has a `True` value in
a dict. Used to filter a dict of `{field_type: is_valid}` to get
the field type from a schema or data source."""
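    # For example (an illustrative sketch):
    #     _get_single_true_type({"bytes": False, "float": True, "int": False})
    # returns "float"; if no entry is True, the function returns None.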
filtered_types = iter([_type for _type in dct if dct[_type]])
# In the case where there are no keys with a `True` value, return `None`
return next(filtered_types, None)
def _get_feature_value(
feature: "tf.train.Feature",
schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
) -> pyarrow.Array:
import pyarrow as pa
underlying_feature_type = {
"bytes": feature.HasField("bytes_list"),
"float": feature.HasField("float_list"),
"int": feature.HasField("int64_list"),
}
# At most one of `bytes_list`, `float_list`, and `int64_list`
# should contain values. If none contain data, this indicates
# an empty feature value.
assert sum(bool(value) for value in underlying_feature_type.values()) <= 1
if schema_feature_type is not None:
try:
from tensorflow_metadata.proto.v0 import schema_pb2
except ModuleNotFoundError:
raise ModuleNotFoundError(
"To use TensorFlow schemas, please install "
"the tensorflow-metadata package."
)
# If a schema is specified, compare to the underlying type
specified_feature_type = {
"bytes": schema_feature_type == schema_pb2.FeatureType.BYTES,
"float": schema_feature_type == schema_pb2.FeatureType.FLOAT,
"int": schema_feature_type == schema_pb2.FeatureType.INT,
}
und_type = _get_single_true_type(underlying_feature_type)
spec_type = _get_single_true_type(specified_feature_type)
if und_type is not None and und_type != spec_type:
raise ValueError(
"Schema field type mismatch during read: specified type is "
f"{spec_type}, but underlying type is {und_type}",
)
# Override the underlying value type with the type in the user-specified schema.
underlying_feature_type = specified_feature_type
if underlying_feature_type["bytes"]:
value = feature.bytes_list.value
type_ = pa.binary()
elif underlying_feature_type["float"]:
value = feature.float_list.value
type_ = pa.float32()
elif underlying_feature_type["int"]:
value = feature.int64_list.value
type_ = pa.int64()
else:
value = []
type_ = pa.null()
value = list(value)
if len(value) == 1 and schema_feature_type is None:
        # Use the value itself if the feature contains a single value.
        # This gives a better user experience when writing preprocessing UDFs
        # on these single-value lists.
value = value[0]
else:
# If the feature value is empty and no type is specified in the user-provided
# schema, set the type to null for now to allow pyarrow to construct a valid
# Array; later, infer the type from other records which have non-empty values
# for the feature.
if len(value) == 0:
type_ = pa.null()
type_ = pa.list_(type_)
return pa.array([value], type=type_)
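# Illustrative examples of what `_get_feature_value` returns (a sketch):
#
#     Feature(int64_list=Int64List(value=[7]))       -> pa.array([7], type=pa.int64())
#     Feature(int64_list=Int64List(value=[1, 2, 3])) -> pa.array([[1, 2, 3]], type=pa.list_(pa.int64()))
#     Feature()                                      -> pa.array([[]], type=pa.list_(pa.null()))
#
# When a schema type is specified (e.g. INT), even a single value stays wrapped
# in a list: pa.array([[7]], type=pa.list_(pa.int64())).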
# Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/reader.py#L16-L96 # noqa: E501
#
# MIT License
#
# Copyright (c) 2020 Vahid Kazemi
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
def _read_records(
file: "pyarrow.NativeFile",
path: str,
) -> Iterable[memoryview]:
"""
Read records from TFRecord file.
A TFRecord file contains a sequence of records. The file can only be read
sequentially. Each record is stored in the following formats:
uint64 length
uint32 masked_crc32_of_length
byte data[length]
uint32 masked_crc32_of_data
See https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecords_format_details
for more details.
"""
length_bytes = bytearray(8)
crc_bytes = bytearray(4)
datum_bytes = bytearray(1024 * 1024)
    row_count = 0
    # Initialized here so the error handler below can reference it even if a
    # failure occurs before the record length has been read.
    data_length = None
while True:
try:
# Read "length" field.
num_length_bytes_read = file.readinto(length_bytes)
if num_length_bytes_read == 0:
break
elif num_length_bytes_read != 8:
raise ValueError(
"Failed to read the length of record data. Expected 8 bytes but "
"got {num_length_bytes_read} bytes."
)
# Read "masked_crc32_of_length" field.
num_length_crc_bytes_read = file.readinto(crc_bytes)
if num_length_crc_bytes_read != 4:
raise ValueError(
"Failed to read the length of CRC-32C hashes. Expected 4 bytes "
"but got {num_length_crc_bytes_read} bytes."
)
# Read "data[length]" field.
(data_length,) = struct.unpack("<Q", length_bytes)
            if data_length > len(datum_bytes):
                # Grow the reusable buffer so the record fits; the padded
                # contents are overwritten by the subsequent `readinto`.
                datum_bytes = datum_bytes.zfill(int(data_length * 1.5))
datum_bytes_view = memoryview(datum_bytes)[:data_length]
num_datum_bytes_read = file.readinto(datum_bytes_view)
if num_datum_bytes_read != data_length:
raise ValueError(
f"Failed to read the record. Exepcted {data_length} bytes but got "
f"{num_datum_bytes_read} bytes."
)
# Read "masked_crc32_of_data" field.
# TODO(chengsu): ideally we should check CRC-32C against the actual data.
num_crc_bytes_read = file.readinto(crc_bytes)
if num_crc_bytes_read != 4:
raise ValueError(
"Failed to read the CRC-32C hashes. Expected 4 bytes but got "
f"{num_crc_bytes_read} bytes."
)
# Return the data.
yield datum_bytes_view
row_count += 1
data_length = None
except Exception as e:
error_message = (
f"Failed to read TFRecord file {path}. Please ensure that the "
f"TFRecord file has correct format. Already read {row_count} rows."
)
if data_length is not None:
error_message += f" Byte size of current record data is {data_length}."
raise RuntimeError(error_message) from e
def _cast_large_list_to_list(batch: pyarrow.Table):
"""
    This function transforms pyarrow.large_list into pyarrow.list_ and
    pyarrow.large_binary into pyarrow.binary so that all types produced by the
    TFRecord datasource are usable with dataset.to_tf().
"""
old_schema = batch.schema
fields = {}
for column_name in old_schema.names:
field_type = old_schema.field(column_name).type
        if isinstance(field_type, pyarrow.lib.LargeListType):
value_type = field_type.value_type
if value_type == pyarrow.large_binary():
value_type = pyarrow.binary()
fields[column_name] = pyarrow.list_(value_type)
elif field_type == pyarrow.large_binary():
fields[column_name] = pyarrow.binary()
else:
fields[column_name] = old_schema.field(column_name)
new_schema = pyarrow.schema(fields)
return batch.cast(new_schema)
def _infer_schema_and_transform(dataset: "Dataset"):
list_sizes = dataset.aggregate(_MaxListSize(dataset.schema().names))
return dataset.map_batches(
_unwrap_single_value_lists,
fn_kwargs={"col_lengths": list_sizes["max_list_size"]},
batch_format="pyarrow",
)
def _unwrap_single_value_lists(batch: pyarrow.Table, col_lengths: Dict[str, int]):
"""
    This function transforms the dataset by converting list columns that always
    contain a single value to their underlying data type
    (e.g. pyarrow.int64() or pyarrow.float64()).
"""
columns = {}
for col in col_lengths:
value_type = batch[col].type.value_type
if col_lengths[col] == 1:
if batch[col]:
columns[col] = pyarrow.array(
[x.as_py()[0] if x.as_py() else None for x in batch[col]],
type=value_type,
)
else:
columns[col] = batch[col]
return pyarrow.table(columns)
class _MaxListSize(AggregateFn):
def __init__(self, columns: List[str]):
self._columns = columns
super().__init__(
init=self._init,
merge=self._merge,
accumulate_row=self._accumulate_row,
finalize=lambda a: a,
name="max_list_size",
)
def _init(self, k: str):
return {col: 0 for col in self._columns}
def _merge(self, acc1: Dict[str, int], acc2: Dict[str, int]):
merged = {}
for col in self._columns:
merged[col] = max(acc1[col], acc2[col])
return merged
def _accumulate_row(self, acc: Dict[str, int], row: "pd.Series"):
for k in row:
value = row[k]
if value:
acc[k] = max(len(value), acc[k])
return acc