ray.data.extensions.tensor_extension.TensorDtype

class ray.data.extensions.tensor_extension.TensorDtype(shape: Tuple[Optional[int], ...], dtype: numpy.dtype)

Pandas extension type for a column of homogeneously typed tensors.

This extension supports tensor elements of different shapes. However, each element must be non-ragged, i.e. it must have a single, well-defined shape.
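The distinction between variable-shaped and ragged elements can be checked with NumPy alone. The sketch below uses a hypothetical helper, `has_well_defined_shape`, which is not part of Ray; it assumes an element is non-ragged exactly when NumPy can convert it to a regular ndarray with a non-object dtype.

```python
import warnings

import numpy as np


def has_well_defined_shape(element) -> bool:
    """Hypothetical helper (not a Ray API).

    Returns True if `element` converts to a regular ndarray, i.e. every
    nested sequence at a given depth has the same length.
    """
    with warnings.catch_warnings():
        # NumPy < 1.24 only warns on ragged input; silence that warning here.
        warnings.simplefilter("ignore")
        try:
            arr = np.asarray(element)
        except ValueError:
            # NumPy >= 1.24 raises ValueError for ragged nested sequences.
            return False
    # Older NumPy turns ragged input into an object array; treat any
    # object-dtype result as not well-shaped.
    return arr.dtype != object


# Elements may differ in shape, as long as each one is regular.
print([has_well_defined_shape(e) for e in [np.zeros((2, 2)), np.zeros((3, 1))]])  # [True, True]

# A single element whose rows have different lengths is ragged.
print(has_well_defined_shape([[1, 2, 3], [4, 5]]))  # False
```

This heuristic also rejects legitimate object arrays (for example, arrays of arbitrary Python objects). That is acceptable for numeric tensors, but it may not match Ray's own validation.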

See https://github.com/pandas-dev/pandas/blob/master/pandas/core/dtypes/base.py for up-to-date interface documentation and the subclassing contract. The docstrings of the properties and methods below are copied from the base ExtensionDtype.
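As a rough illustration of that subclassing contract, the sketch below defines a deliberately tiny, hypothetical dtype, `ToyDtype`, which is not part of Ray or pandas. It relies on the `construct_from_string` and `is_dtype` methods inherited from `pandas.api.extensions.ExtensionDtype`. TensorDtype is parameterized by shape and dtype, so the inherited `construct_from_string`, which calls `cls()` with no arguments, would not work for it as-is.

```python
from pandas.api.extensions import ExtensionDtype


class ToyDtype(ExtensionDtype):
    """Hypothetical, non-parameterized dtype (not part of Ray or pandas),
    used only to illustrate the ExtensionDtype contract."""

    name = "toy"  # string identifier; the inherited construct_from_string accepts exactly this
    type = float  # scalar type of the array's elements
    kind = "f"    # NumPy-style character code

    @classmethod
    def construct_array_type(cls):
        # A real dtype returns its ExtensionArray subclass here; this toy has none.
        raise NotImplementedError


# The inherited construct_from_string() maps the name back to an instance ...
print(ToyDtype.construct_from_string("toy") == ToyDtype())  # True
# ... and the inherited is_dtype() builds on it, treating TypeError as "no match".
print(ToyDtype.is_dtype("toy"), ToyDtype.is_dtype("int64"))  # True False
```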

Examples

>>> # Create a DataFrame with a list of ndarrays as a column.
>>> import pandas as pd
>>> import numpy as np
>>> import ray
>>> df = pd.DataFrame({
...     "one": [1, 2, 3],
...     "two": list(np.arange(24).reshape((3, 2, 2, 2)))})
>>> # Note the opaque object dtype for this column.
>>> df.dtypes
one     int64
two    object
dtype: object
>>> # Cast column to our TensorDtype extension type. The first argument
>>> # is the shape of each element, (2, 2, 2), not of the whole column.
>>> from ray.data.extensions import TensorDtype
>>> df["two"] = df["two"].astype(TensorDtype((2, 2, 2), np.int64))
>>> # Note that the column dtype is now TensorDtype instead of
>>> # object.
>>> df.dtypes
one          int64
two    TensorDtype(shape=(2, 2, 2), dtype=int64)
dtype: object
>>> # Pandas is now aware of this tensor column, and we can do the
>>> # typical DataFrame operations on this column.
>>> col = 2 * (df["two"] + 1)
>>> # The ndarrays underlying the tensor column will be manipulated,
>>> # but the column itself will continue to be a Pandas type.
>>> type(col)
pandas.core.series.Series
>>> col
0   [[[ 2  4]
      [ 6  8]]
     [[10 12]
      [14 16]]]
1   [[[18 20]
      [22 24]]
     [[26 28]
      [30 32]]]
2   [[[34 36]
      [38 40]]
     [[42 44]
      [46 48]]]
Name: two, dtype: TensorDtype(shape=(2, 2, 2), dtype=int64)
>>> # An aggregation that reduces the column to a single element,
>>> # such as mean(), returns our TensorArrayElement type.
>>> tensor = col.mean()
>>> type(tensor)
ray.data.extensions.tensor_extension.TensorArrayElement
>>> tensor
array([[[18., 20.],
        [22., 24.]],
       [[26., 28.],
        [30., 32.]]])
>>> # This is a light wrapper around a NumPy ndarray, and can easily
>>> # be converted to an ndarray.
>>> type(tensor.to_numpy())
numpy.ndarray
>>> # In addition to doing Pandas operations on the tensor column,
>>> # you can now put the DataFrame into a Dataset.
>>> ds = ray.data.from_pandas(df)
>>> # Internally, this column is represented with the corresponding
>>> # Arrow tensor extension type.
>>> ds.schema()
one: int64
two: extension<arrow.py_extension_type<ArrowTensorType>>
>>> # You can write the dataset to Parquet.
>>> ds.write_parquet("/some/path")
>>> # And you can read it back.
>>> read_ds = ray.data.read_parquet("/some/path")
>>> read_ds.schema()
one: int64
two: extension<arrow.py_extension_type<ArrowTensorType>>
>>> read_df = ray.get(read_ds.to_pandas_refs())[0]
>>> read_df.dtypes
one          int64
two    TensorDtype(shape=(2, 2, 2), dtype=int64)
dtype: object
>>> # The tensor extension type is preserved along the
>>> # Pandas --> Arrow --> Parquet --> Arrow --> Pandas
>>> # conversion chain.
>>> read_df.equals(df) 
True
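The shapes and values in the example above can be reproduced with NumPy alone, without Ray or pandas. This sketch assumes the column is built from `np.arange(24).reshape((3, 2, 2, 2))` as in the example: the column then holds three elements, each of shape (2, 2, 2), which is the per-element shape a TensorDtype describes. The printed `col` values correspond to `2 * (x + 1)` applied to each element `x`.

```python
import numpy as np

# The column in the example: three elements, each a (2, 2, 2) slice.
data = np.arange(24).reshape((3, 2, 2, 2))
elements = list(data)
print(len(elements), elements[0].shape)  # 3 (2, 2, 2)

# Element-wise arithmetic matching the printed `col` values.
col = [2 * (e + 1) for e in elements]
print(col[0].ravel())  # [ 2  4  6  8 10 12 14 16]

# Mean across the three rows, matching the printed col.mean() value.
mean = np.mean(np.stack(col), axis=0)
print(mean.ravel())  # [18. 20. 22. 24. 26. 28. 30. 32.]
```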

PublicAPI (beta): This API is in beta and may change before becoming stable.

__init__(shape: Tuple[Optional[int], ...], dtype: numpy.dtype)

Methods

__init__(shape, dtype)

construct_array_type()

Return the array type associated with this dtype.

construct_from_string(string)

Construct this type from a string.

is_dtype(dtype)

Check if we match 'dtype'.
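A parameterized dtype has to recover its parameters from the string passed to `construct_from_string`. The sketch below shows one way to do that for strings in the `TensorDtype(shape=..., dtype=...)` form printed in the example. The helper name `parse_tensor_dtype_string` and the accepted format are assumptions, and Ray's own implementation may accept other spellings. The helper raises `TypeError` on a mismatch because pandas' `is_dtype` treats that exception as "not this dtype".

```python
import ast
import re
from typing import Optional, Tuple

import numpy as np

# Hypothetical accepted format, e.g. "TensorDtype(shape=(2, 2, 2), dtype=int64)".
_DTYPE_STRING = re.compile(r"^TensorDtype\(shape=(\(.*\)), dtype=(\w+)\)$")


def parse_tensor_dtype_string(string: str) -> Tuple[Tuple[Optional[int], ...], np.dtype]:
    """Hypothetical helper (not a Ray API): split a dtype string into (shape, dtype).

    Raises TypeError on any mismatch, since pandas' is_dtype() interprets a
    TypeError from construct_from_string() as "this string names another dtype".
    """
    match = _DTYPE_STRING.match(string)
    if match is None:
        raise TypeError(f"Cannot construct a 'TensorDtype' from '{string}'")
    try:
        # literal_eval safely evaluates tuples such as (2, 2, 2) or (None, None).
        shape = ast.literal_eval(match.group(1))
    except (ValueError, SyntaxError) as exc:
        raise TypeError(f"Invalid shape in '{string}'") from exc
    if not isinstance(shape, tuple):
        raise TypeError(f"Shape in '{string}' is not a tuple")
    # np.dtype() itself raises TypeError for unknown names such as "foo".
    return shape, np.dtype(match.group(2))


shape, dtype = parse_tensor_dtype_string("TensorDtype(shape=(2, 2, 2), dtype=int64)")
print(shape, dtype)  # (2, 2, 2) int64
```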

Attributes

base

element_dtype

The dtype of the underlying tensor elements.

element_shape

The shape of the underlying tensor elements.

is_variable_shaped

Whether the corresponding TensorArray for this TensorDtype holds variable-shaped tensor elements.

kind

A character code (one of 'biufcmMOSUV'), default 'O'

na_value

Default NA value to use for this type.

name

A string identifying the data type.

names

Ordered list of field names, or None if there are no fields.

type

The scalar type for the array, e.g.
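To make `element_shape` and `is_variable_shaped` more concrete, the sketch below infers an element shape from a list of ndarrays, marking any dimension whose size differs between elements with `None`, as the `Tuple[Optional[int], ...]` annotation on `shape` permits. Both helpers are hypothetical. The rule that a single `None` makes a column variable-shaped is an assumption; Ray may encode variable-shaped columns differently, for example with every dimension set to `None`.

```python
from typing import Optional, Sequence, Tuple

import numpy as np


def infer_element_shape(elements: Sequence[np.ndarray]) -> Tuple[Optional[int], ...]:
    """Hypothetical helper (not a Ray API).

    Returns the common per-element shape, using None for any dimension whose
    size differs between elements. Requires every element to have the same
    number of dimensions.
    """
    shapes = [np.shape(e) for e in elements]
    if len({len(s) for s in shapes}) != 1:
        raise ValueError("elements must be non-empty and share the same ndim")
    # zip(*shapes) yields, for each dimension, that dimension's size in every element.
    return tuple(
        sizes[0] if len(set(sizes)) == 1 else None
        for sizes in zip(*shapes)
    )


def is_variable_shaped(shape: Tuple[Optional[int], ...]) -> bool:
    # Assumption: one unknown (None) dimension is enough to call it variable-shaped.
    return any(dim is None for dim in shape)


fixed = list(np.arange(24).reshape((3, 2, 2, 2)))
varied = [np.zeros((2, 3)), np.zeros((4, 3))]

print(infer_element_shape(fixed), is_variable_shaped(infer_element_shape(fixed)))
# (2, 2, 2) False
print(infer_element_shape(varied), is_variable_shaped(infer_element_shape(varied)))
# (None, 3) True
```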