ray.data.read_tfrecords

ray.data.read_tfrecords(paths: Union[str, List[str]], *, filesystem: Optional[pyarrow.fs.FileSystem] = None, parallelism: int = -1, arrow_open_stream_args: Optional[Dict[str, Any]] = None, meta_provider: ray.data.datasource.file_meta_provider.BaseFileMetadataProvider = <ray.data.datasource.file_meta_provider.DefaultFileMetadataProvider object>, partition_filter: Optional[ray.data.datasource.partitioning.PathPartitionFilter] = None, ignore_missing_paths: bool = False, tf_schema: Optional[schema_pb2.Schema] = None) → ray.data.dataset.Dataset

Create a dataset from TFRecord files that contain tf.train.Example messages.

Warning

This function supports only tf.train.Example messages. If a file contains a message that isn’t of type tf.train.Example, this function raises an error.

Examples

>>> import os
>>> import tempfile
>>> import tensorflow as tf
>>> features = tf.train.Features(
...     feature={
...         "length": tf.train.Feature(float_list=tf.train.FloatList(value=[5.1])),
...         "width": tf.train.Feature(float_list=tf.train.FloatList(value=[3.5])),
...         "species": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"setosa"])),
...     }
... )
>>> example = tf.train.Example(features=features)
>>> path = os.path.join(tempfile.gettempdir(), "data.tfrecords")
>>> with tf.io.TFRecordWriter(path=path) as writer:
...     writer.write(example.SerializeToString())

This function reads tf.train.Example messages into a tabular Dataset.

>>> import ray
>>> ds = ray.data.read_tfrecords(path)
>>> ds.to_pandas()  
   length  width    species
0     5.1    3.5  b'setosa'
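
A TFRecord file is a sequence of length-prefixed records, each guarded by masked CRC-32C checksums; in the files this function reads, every record payload is one serialized tf.train.Example. The plain-Python sketch below illustrates that container layout only. It is not Ray's reader: crc32c, masked_crc, write_record, and read_records are hypothetical helpers, and the sketch assumes the standard TensorFlow record layout (8-byte little-endian length, 4-byte masked CRC of the length, payload, 4-byte masked CRC of the payload).

```python
import io
import struct
from typing import BinaryIO, Iterator


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli polynomial, reflected form 0x82F63B78)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(stream: BinaryIO, payload: bytes) -> None:
    """Append one record: length, masked CRC of length, payload, masked CRC of payload."""
    header = struct.pack("<Q", len(payload))
    stream.write(header)
    stream.write(struct.pack("<I", masked_crc(header)))
    stream.write(payload)
    stream.write(struct.pack("<I", masked_crc(payload)))


def _read_exact(stream: BinaryIO, n: int) -> bytes:
    """Read exactly n bytes or fail, so truncated files are reported clearly."""
    data = stream.read(n)
    if len(data) != n:
        raise ValueError("truncated TFRecord data")
    return data


def read_records(stream: BinaryIO) -> Iterator[bytes]:
    """Yield each record payload, verifying both checksums."""
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of file
        if len(header) != 8:
            raise ValueError("truncated TFRecord data")
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", _read_exact(stream, 4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupted length header")
        payload = _read_exact(stream, length)
        (payload_crc,) = struct.unpack("<I", _read_exact(stream, 4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupted record payload")
        yield payload


# Usage: round-trip two opaque payloads through an in-memory "file".
buf = io.BytesIO()
for payload in (b"first", b"second"):
    write_record(buf, payload)
buf.seek(0)
print(list(read_records(buf)))  # [b'first', b'second']
```

Decoding each payload into feature columns, which read_tfrecords does for tf.train.Example messages, additionally requires the protobuf message definitions; this sketch deliberately stops at the raw bytes.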

We can also read compressed TFRecord files that use one of the compression types supported by Arrow:

>>> compressed_path = os.path.join(tempfile.gettempdir(), "data_compressed.tfrecords")
>>> options = tf.io.TFRecordOptions(compression_type="GZIP")  # "ZLIB" is also supported by TensorFlow
>>> with tf.io.TFRecordWriter(path=compressed_path, options=options) as writer:
...     writer.write(example.SerializeToString())
>>> ds = ray.data.read_tfrecords(
...     [compressed_path],
...     arrow_open_stream_args={"compression": "gzip"},
... )
>>> ds.to_pandas()  
   length  width    species
0     5.1    3.5  b'setosa'
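
In a GZIP-compressed TFRecord file, the compression wraps the concatenated length-prefixed records as a whole, which is why decompression is configured on the input stream through arrow_open_stream_args rather than per record. The sketch below illustrates that layering with Python's standard gzip module standing in for Arrow's stream decompression. iter_payloads and read_payloads are hypothetical helpers (not Ray or Arrow APIs), and they skip CRC validation to keep the example short.

```python
import gzip
import io
import struct
from typing import BinaryIO, Iterator, List, Optional


def iter_payloads(stream: BinaryIO) -> Iterator[bytes]:
    """Yield record payloads from an uncompressed stream (CRC checks skipped)."""
    while True:
        header = stream.read(8)
        if not header:
            return  # end of stream
        (length,) = struct.unpack("<Q", header)
        stream.read(4)  # masked CRC of the length field, ignored in this sketch
        payload = stream.read(length)
        stream.read(4)  # masked CRC of the payload, ignored in this sketch
        yield payload


def read_payloads(raw: bytes, compression: Optional[str] = None) -> List[bytes]:
    """Optionally wrap the byte stream in a gzip decompressor, then parse records."""
    stream = io.BytesIO(raw)
    if compression == "gzip":
        # The whole record stream is compressed, so decompress before framing.
        stream = gzip.GzipFile(fileobj=stream, mode="rb")
    return list(iter_payloads(stream))


# Hypothetical framed bytes with zeroed CRC fields: enough for this reader,
# but NOT a valid TFRecord for TensorFlow, which verifies the CRCs.
payload = b"serialized Example"
framed = struct.pack("<Q", len(payload)) + b"\x00" * 4 + payload + b"\x00" * 4
compressed = gzip.compress(framed)

print(compressed[:2] == b"\x1f\x8b")  # True: gzip magic bytes, not a length header
print(read_payloads(compressed, compression="gzip"))  # [b'serialized Example']
```

Without the decompression step, this reader would parse the first eight bytes of the gzip header as a record length, so the records could not be recovered.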

Parameters
  • paths – A single file/directory path or a list of file/directory paths. A list of paths can contain both files and directories.

  • filesystem – The filesystem implementation to read from.

  • parallelism – The requested parallelism of the read. Defaults to -1, which lets Ray determine the parallelism automatically. Parallelism may be limited by the number of files in the dataset.

  • arrow_open_stream_args – Keyword arguments passed to pyarrow.fs.FileSystem.open_input_stream. To read a compressed TFRecord file, pass the corresponding compression type (e.g. for GZIP or ZLIB, use arrow_open_stream_args={'compression': 'gzip'}).

  • meta_provider – File metadata provider. Custom metadata providers may be able to resolve file metadata more quickly and/or accurately.

  • partition_filter – Path-based partition filter, if any. Can be used with a custom callback to read only selected partitions of a dataset. By default, this filters out any file paths whose file extension does not match "*.tfrecords*".

  • ignore_missing_paths – If True, ignores any file paths in paths that are not found. Defaults to False.

  • tf_schema – Optional TensorFlow Schema which is used to explicitly set the schema of the underlying Dataset.
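
The documented default for partition_filter keeps only files whose names match "*.tfrecords*". The minimal plain-Python sketch below shows which paths such a pattern admits, assuming the pattern is matched case-sensitively against the final path component. keep_tfrecord_paths is a hypothetical helper, not Ray's partition filter implementation.

```python
import fnmatch
import posixpath
from typing import List


def keep_tfrecord_paths(paths: List[str], pattern: str = "*.tfrecords*") -> List[str]:
    """Keep paths whose final component matches the pattern (case-sensitive)."""
    return [p for p in paths if fnmatch.fnmatchcase(posixpath.basename(p), pattern)]


paths = [
    "s3://bucket/data/part-0.tfrecords",
    "s3://bucket/data/part-1.tfrecords.gz",
    "s3://bucket/data/_SUCCESS",
    "s3://bucket/data/notes.txt",
]
print(keep_tfrecord_paths(paths))
# ['s3://bucket/data/part-0.tfrecords', 's3://bucket/data/part-1.tfrecords.gz']
```

Because of the trailing wildcard, a pattern of this shape also admits compressed files such as part-1.tfrecords.gz, which can then be read by passing the matching compression through arrow_open_stream_args.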

Returns

A Dataset that contains the example features.

Raises

ValueError – If a file contains a message that isn’t a tf.train.Example.

PublicAPI (alpha): This API is in alpha and may change before becoming stable.