Consuming Datasets

The data underlying a Dataset can be consumed in several ways:

  • Retrieving a limited prefix of rows.

  • Iterating over rows and batches.

  • Saving to files.

Retrieving a limited set of rows

A limited set of rows can be retrieved from a Dataset via the ds.take() API, along with its sibling helper APIs ds.take_all(), for retrieving all rows, and ds.show(), for printing a limited set of rows. These methods are convenient for quickly inspecting a subset (prefix) of rows. They have the benefit that, if used right after reading, they will only trigger more files to be read if needed to retrieve the requested rows; if inspecting a small prefix of rows, often only the first file will need to be read.

# Retrieve rows from a Dataset of the integers 0..9999 with take(),
# take_all() and show().
import ray

ds = ray.data.range(10000)

# take(n) returns the first n rows as a list.
print(ds.take(5))
# -> [0, 1, 2, 3, 4]

# Warning: This will print all of the rows!
print(ds.take_all())

# show(n) prints the first n rows directly, one per line.
ds.show(5)
# -> 0
#    1
#    2
#    3
#    4

Iterating over Datasets

Datasets can be consumed a row at a time using the ds.iter_rows() API

# Count the rows of a Dataset by iterating over it one row at a time.
import ray

ds = ray.data.range(10000)
num_rows = 0

# Consume all rows in the Dataset.
# Each row produced from ray.data.range() is a plain Python int.
for row in ds.iter_rows():
    assert isinstance(row, int)
    num_rows += 1

print(num_rows)
# -> 10000

or a batch at a time using the ds.iter_batches() API, where you can specify batch size as well as the desired batch format. By default, the batch format is "default", which depends on the kind of data: for tabular data, it's a Pandas DataFrame; for Python objects, it's a list.

# Consume a Dataset in batches, first in the default format and then as
# Pandas DataFrames.
import ray
import pandas as pd

ds = ray.data.range(10000)
num_batches = 0

# Consume all batches in the Dataset. For this simple integer Dataset the
# default batch format is a Python list.
for batch in ds.iter_batches(batch_size=2):
    assert isinstance(batch, list)
    num_batches += 1

print(num_batches)
# -> 5000

# Consume data as Pandas DataFrame batches.
cum_sum = 0
for batch in ds.iter_batches(batch_size=2, batch_format="pandas"):
    assert isinstance(batch, pd.DataFrame)
    # Simple integer Dataset is converted to a single-column Pandas DataFrame.
    # Reduce the column to a scalar: adding the Series itself would turn
    # cum_sum into a Series instead of an integer running total.
    cum_sum += batch["value"].sum()
print(cum_sum)
# -> 49995000

Datasets can be passed to Ray tasks or actors and accessed by these iteration methods. This does not incur a copy, since the blocks of the Dataset are passed by reference as Ray objects:

import ray

# A Ray task that counts the batches in the Dataset it receives. The Dataset's
# blocks are passed by reference as Ray objects, so no data copy is made.
@ray.remote
def consume(data: ray.data.Dataset[int]) -> int:
    num_batches = 0
    # Consume data in 2-record batches.
    for batch in data.iter_batches(batch_size=2):
        # Holds here because the 10,000-row Dataset passed below splits
        # evenly into 2-row batches.
        assert len(batch) == 2
        num_batches += 1
    return num_batches

ds = ray.data.range(10000)
ray.get(consume.remote(ds))
# -> 5000

Splitting Into and Consuming Shards

Datasets can be split up into disjoint sub-datasets, or shards. Locality-aware splitting is supported if you pass in a list of actor handles to the ds.split() function along with the number of desired splits. This is a common pattern useful for loading and sharding data between distributed training actors:

Note

If using Ray Train for distributed training, you do not need to split the dataset; Ray Train will automatically do locality-aware splitting into per-trainer shards for you!

import ray


# @ray.remote(num_gpus=1)  # Uncomment this to run on GPUs.
@ray.remote
class Worker:
    """Training actor that consumes one shard of a split Dataset."""

    def __init__(self, rank: int):
        # rank is unused in this example.
        pass

    def train(self, shard: ray.data.Dataset[int]) -> int:
        """Iterate over the shard in Torch batches and return its row count."""
        for batch in shard.iter_torch_batches(batch_size=256):
            # Placeholder: a real trainer would run a training step here.
            pass
        return shard.count()

workers = [Worker.remote(i) for i in range(4)]
# -> [Actor(Worker, ...), Actor(Worker, ...), ...]

ds = ray.data.range(10000)
# -> Dataset(num_blocks=200, num_rows=10000, schema=<class 'int'>)

# Passing the actor handles as locality hints enables locality-aware splitting.
shards = ds.split(n=4, locality_hints=workers)
# -> [Dataset(num_blocks=13, num_rows=2500, schema=<class 'int'>),
#     Dataset(num_blocks=13, num_rows=2500, schema=<class 'int'>), ...]

ray.get([w.train.remote(s) for w, s in zip(workers, shards)])
# -> [2500, 2500, 2500, 2500]

Saving Datasets

Datasets can be written to local or remote storage in the desired data format. The supported formats include Parquet, CSV, JSON, NumPy, and TFRecords. To control the number of output files, you may use ds.repartition() to repartition the Dataset before writing out.

# Save a Dataset as Parquet. The number of blocks after repartition()
# determines how many files are written.
import ray

ds = ray.data.range(1000)
# -> Dataset(num_blocks=200, num_rows=1000, schema=<class 'int'>)
ds.take(5)
# -> [0, 1, 2, 3, 4]

# Write out just one file.
ds.repartition(1).write_parquet("/tmp/one_parquet")
# -> /tmp/one_parquet/d757569dfb2845589b0ccbcb263e8cc3_000000.parquet

# Write out multiple files (one per block: 3 here).
ds.repartition(3).write_parquet("/tmp/multi_parquet")
# -> /tmp/multi_parquet/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000000.parquet
# -> /tmp/multi_parquet/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000001.parquet
# -> /tmp/multi_parquet/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000002.parquet
# Save a Dataset as CSV. The number of blocks after repartition()
# determines how many files are written.
import ray

ds = ray.data.range(1000)
# -> Dataset(num_blocks=200, num_rows=1000, schema=<class 'int'>)
ds.take(5)
# -> [0, 1, 2, 3, 4]

# Write out just one file.
ds.repartition(1).write_csv("/tmp/one_csv")
# -> /tmp/one_csv/d757569dfb2845589b0ccbcb263e8cc3_000000.csv

# Write out multiple files (one per block: 3 here).
ds.repartition(3).write_csv("/tmp/multi_csv")
# -> /tmp/multi_csv/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000000.csv
# -> /tmp/multi_csv/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000001.csv
# -> /tmp/multi_csv/2b529dc5d8eb45e5ad03e69fb7ad8bc0_000002.csv
# Save a Dataset as JSON. The number of blocks after repartition()
# determines how many files are written.
import ray

ds = ray.data.range(1000)
# -> Dataset(num_blocks=200, num_rows=1000, schema=<class 'int'>)
ds.take(5)
# -> [0, 1, 2, 3, 4]

# Write out just one file.
ds.repartition(1).write_json("/tmp/one_json")
# -> /tmp/one_json/ab693fde13634f4c8cdaef1db9595ac1_000000.json

# Write out multiple files (one per block: 3 here).
ds.repartition(3).write_json("/tmp/multi_json")
# -> /tmp/multi_json/f467636b3c41420bb109505ab56c6eae_000000.json
# -> /tmp/multi_json/f467636b3c41420bb109505ab56c6eae_000001.json
# -> /tmp/multi_json/f467636b3c41420bb109505ab56c6eae_000002.json
# Save a tensor Dataset as NumPy .npy files. from_numpy() produces a
# single "value" column of tensor type, as the schema below shows.
import ray
import numpy as np

ds = ray.data.from_numpy(np.arange(1000))
# -> Dataset(
#        num_blocks=1,
#        num_rows=1000,
#        schema={value: <ArrowTensorType: shape=(), dtype=int64>},
#    )
ds.show(2)
# -> {'value': array(0)}
# -> {'value': array(1)}

# Write out just one file.
ds.repartition(1).write_numpy("/tmp/one_numpy")
# -> /tmp/one_numpy/78c91652e2364a7481cf171bed6d96e4_000000.npy

# Write out multiple files (one per block: 3 here).
ds.repartition(3).write_numpy("/tmp/multi_numpy")
# -> /tmp/multi_numpy/b837e5b5a18448bfa3f8388f5d99d033_000000.npy
# -> /tmp/multi_numpy/b837e5b5a18448bfa3f8388f5d99d033_000001.npy
# -> /tmp/multi_numpy/b837e5b5a18448bfa3f8388f5d99d033_000002.npy
# Save a Dataset of dict records (int, float and bytes fields) as TFRecords.
import ray

ds = ray.data.from_items(
    [
        {"some_int": 1, "some_float": 1.0, "some_bytestring": b"abc"},
        {"some_int": 2, "some_float": 2.0, "some_bytestring": b"def"},
    ]
)
# -> Dataset(
#        num_blocks=2,
#        num_rows=2,
#        schema={some_int: int64, some_float: double, some_bytestring: binary}
#    )

ds.show(2)
# -> {'some_int': 1, 'some_float': 1.0, 'some_bytestring': b'abc'}
# -> {'some_int': 2, 'some_float': 2.0, 'some_bytestring': b'def'}

# Write out just one file.
ds.repartition(1).write_tfrecords("/tmp/one_tfrecord")
# -> /tmp/one_tfrecord/6d41f90ee8ac4d7db0aa4f43efd3070c_000000.tfrecords

# Write out multiple files (2 here, since the Dataset has only 2 rows).
ds.repartition(2).write_tfrecords("/tmp/multi_tfrecords")
# -> /tmp/multi_tfrecords/1ba614cf75de47e184b7d8f4a1cdfc80_000000.tfrecords
# -> /tmp/multi_tfrecords/1ba614cf75de47e184b7d8f4a1cdfc80_000001.tfrecords