Working with Tensors

N-dimensional arrays (in other words, tensors) are ubiquitous in ML workloads. This guide describes the limitations and best practices of working with such data.

Tensor data representation

Ray Data represents tensors as NumPy ndarrays.

import ray

ds = ray.data.read_images("s3://anonymous@air-example-data/digits")
print(ds)

Dataset(
   ...
   schema={image: numpy.ndarray(shape=(28, 28), dtype=uint8)}
)

Batches of fixed-shape tensors

If your tensors have a fixed shape, Ray Data represents batches as regular ndarrays.

>>> import ray
>>> ds = ray.data.read_images("s3://anonymous@air-example-data/digits")
>>> batch = ds.take_batch(batch_size=32)
>>> batch["image"].shape
(32, 28, 28)
>>> batch["image"].dtype
dtype('uint8')
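As a NumPy-only illustration of this layout (a sketch, not Ray Data's internal code), same-shaped arrays stack into a single ndarray with a leading batch axis. The zero-filled arrays below are hypothetical stand-ins for decoded images.

```python
import numpy as np

# Hypothetical stand-ins for 32 decoded 28x28 grayscale images.
images = [np.zeros((28, 28), dtype=np.uint8) for _ in range(32)]

# Same-shaped arrays stack into one ndarray with a leading batch axis.
fixed_batch = np.stack(images)

print(fixed_batch.shape)  # (32, 28, 28)
print(fixed_batch.dtype)  # uint8
```

Because the result is a regular ndarray, vectorized operations apply to the whole batch at once.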

Batches of variable-shape tensors

If your tensors vary in shape, Ray Data represents batches as arrays of object dtype.

>>> import ray
>>> ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection")
>>> batch = ds.take_batch(batch_size=32)
>>> batch["image"].shape
(32,)
>>> batch["image"].dtype
dtype('O')

The individual elements of these object arrays are regular ndarrays.

>>> batch["image"][0].dtype
dtype('uint8')
>>> batch["image"][0].shape
(375, 500, 3)
>>> batch["image"][3].shape
(333, 465, 3)
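The same structure can be reproduced with NumPy alone. The sketch below is an illustration under assumptions rather than Ray Data's implementation: it fills a one-dimensional object array with two hypothetical images whose shapes match the ones shown above.

```python
import numpy as np

# Hypothetical stand-ins for decoded images with different heights and widths.
images = [
    np.zeros((375, 500, 3), dtype=np.uint8),
    np.zeros((333, 465, 3), dtype=np.uint8),
]

# Differently shaped arrays can't be stacked, so hold them in a 1-D object array.
ragged_batch = np.empty(len(images), dtype=object)
for i, image in enumerate(images):
    ragged_batch[i] = image

print(ragged_batch.shape)     # (2,)
print(ragged_batch.dtype)     # object
print(ragged_batch[1].shape)  # (333, 465, 3)
```

Assigning elements one at a time avoids NumPy's attempt to infer a common shape, which can fail when some dimensions happen to match.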

Transforming tensor data

Call map() or map_batches() to transform tensor data.

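from typing import Any, Dict

import numpy as np
import ray

ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection")

def brighten(image: np.ndarray) -> np.ndarray:
    # Widen before adding so uint8 values don't wrap around, then clip and narrow back.
    return np.clip(image.astype(np.int16) + 4, 0, 255).astype(np.uint8)

def increase_brightness(row: Dict[str, Any]) -> Dict[str, Any]:
    row["image"] = brighten(row["image"])
    return row

# Increase the brightness, record at a time.
ds.map(increase_brightness)

def batch_increase_brightness(batch: Dict[str, np.ndarray]) -> Dict:
    # These images vary in shape, so the batch is an object array; brighten each image.
    batch["image"] = [brighten(image) for image in batch["image"]]
    return batch

# Increase the brightness, batch at a time.
ds.map_batches(batch_increase_brightness)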
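One detail to watch when adjusting 8-bit pixel values: uint8 arithmetic wraps around modulo 256, so clipping after the addition can't recover the overflowed values. The NumPy-only sketch below uses made-up pixel values to contrast the two orderings.

```python
import numpy as np

pixels = np.array([0, 100, 253, 255], dtype=np.uint8)

# Adding in uint8 wraps modulo 256, so clipping afterwards is too late.
wrapped = np.clip(pixels + np.uint8(4), 0, 255)

# Widening first keeps the true sums, which clip() can then saturate at 255.
saturated = np.clip(pixels.astype(np.int16) + 4, 0, 255).astype(np.uint8)

print(wrapped.tolist())    # [4, 104, 1, 3]
print(saturated.tolist())  # [4, 104, 255, 255]
```

Casting to a wider integer type costs a temporary copy, which is usually an acceptable trade for correct saturation.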

In addition to NumPy ndarrays, Ray Data also treats returned lists of NumPy ndarrays and objects implementing __array__ (for example, torch.Tensor) as tensor data.
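To make the `__array__` protocol concrete, here is a minimal sketch with a hypothetical `Thumbnail` class (not part of Ray or NumPy). It shows only the NumPy side of the conversion; how Ray Data invokes it internally isn't covered here.

```python
import numpy as np

class Thumbnail:
    """Hypothetical image wrapper that exposes its pixels through __array__."""

    def __init__(self, pixels):
        self._pixels = np.asarray(pixels, dtype=np.uint8)

    def __array__(self, dtype=None, copy=None):
        # NumPy calls this hook when asked to convert the object to an ndarray.
        return self._pixels if dtype is None else self._pixels.astype(dtype)

thumb = Thumbnail([[0, 128], [255, 64]])
as_array = np.asarray(thumb)

print(type(as_array).__name__)  # ndarray
print(as_array.shape)           # (2, 2)
```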

For more information on transforming data, read Transforming data.

Saving tensor data

Save tensor data with formats like Parquet, NumPy, and JSON. To view all supported formats, see the Input/Output reference.

Call write_parquet() to save data in Parquet files.

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_parquet("/tmp/simple")

Call write_numpy() to save an ndarray column in NumPy files.

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_numpy("/tmp/simple", column="image")
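NumPy files keep an array's shape and dtype, so a saved tensor column can be read back without extra metadata. The sketch below shows that round trip with NumPy alone, assuming the output consists of standard .npy files; the file name `images.npy` is hypothetical, since Ray Data chooses its own output file names.

```python
import tempfile
from pathlib import Path

import numpy as np

# Made-up data standing in for a small batch of 4x4 grayscale images.
images = np.arange(2 * 4 * 4, dtype=np.uint8).reshape(2, 4, 4)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "images.npy"  # hypothetical file name
    np.save(path, images)
    restored = np.load(path)

print(restored.shape)                    # (2, 4, 4)
print(np.array_equal(images, restored))  # True
```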

To save images in a JSON file, call write_json().

import ray

ds = ray.data.read_images("s3://anonymous@ray-example-data/image-datasets/simple")
ds.write_json("/tmp/simple")

For more information on saving data, read Saving data.