ray.data.extensions.tensor_extension.TensorDtype#

class ray.data.extensions.tensor_extension.TensorDtype(shape: Tuple[Optional[int], ...], dtype: numpy.dtype)[source]#

Bases: pandas.core.dtypes.base.ExtensionDtype

Pandas extension type for a column of homogeneous-typed tensors.

This extension supports columns whose tensor elements have different shapes. However, each individual tensor element must be non-ragged, i.e. it must have a well-defined shape.
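
To make the distinction between "different shapes across elements" and "ragged within one element" concrete, here is a minimal sketch that uses plain nested lists instead of ndarrays. The helper `nested_shape` is hypothetical and is not part of Ray or pandas; it only illustrates what a well-defined shape means.

```python
from typing import Tuple


def nested_shape(value) -> Tuple[int, ...]:
    """Return the shape of a nested list, or raise ValueError if it is ragged.

    Hypothetical helper for illustration only; not part of Ray or pandas.
    """
    if not isinstance(value, (list, tuple)):
        return ()  # A scalar has an empty shape.
    if not value:
        return (0,)
    # Every child must report the same shape, otherwise the value is ragged.
    child_shapes = {nested_shape(child) for child in value}
    if len(child_shapes) != 1:
        raise ValueError(
            f"ragged tensor: children have shapes {sorted(child_shapes)}"
        )
    return (len(value),) + child_shapes.pop()


# Two elements with different shapes, each individually non-ragged.
column = [[[1, 2], [3, 4]], [[5, 6, 7]]]
print([nested_shape(element) for element in column])  # [(2, 2), (1, 3)]

# A single ragged element has no well-defined shape.
try:
    nested_shape([[1, 2], [3]])
except ValueError as err:
    print(err)
```

The first column is the kind of data the paragraph above describes as supported; the second value is the kind it rules out.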

See: https://github.com/pandas-dev/pandas/blob/master/pandas/core/dtypes/base.py for up-to-date interface documentation and the subclassing contract. The docstrings of the below properties and methods were copied from the base ExtensionDtype.

Examples

>>> # Create a DataFrame with a list of ndarrays as a column.
>>> import pandas as pd
>>> import numpy as np
>>> import ray
>>> df = pd.DataFrame({
...     "one": [1, 2, 3],
...     "two": list(np.arange(24).reshape((3, 2, 2, 2)))})
>>> # Note the opaque object dtype for this column.
>>> df.dtypes 
one     int64
two    object
dtype: object
>>> # Cast column to our TensorDtype extension type.
>>> from ray.data.extensions import TensorDtype
>>> df["two"] = df["two"].astype(TensorDtype((3, 2, 2, 2), np.int64))
>>> # Note that the column dtype is now TensorDtype instead of
>>> # object.
>>> df.dtypes 
one          int64
two    TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
dtype: object
>>> # Pandas is now aware of this tensor column, and we can do the
>>> # typical DataFrame operations on this column.
>>> col = 2 * (df["two"] + 1)
>>> # The ndarrays underlying the tensor column will be manipulated,
>>> # but the column itself will continue to be a Pandas type.
>>> type(col) 
pandas.core.series.Series
>>> col 
0   [[[ 2  4]
      [ 6  8]]
     [[10 12]
      [14 16]]]
1   [[[18 20]
      [22 24]]
     [[26 28]
      [30 32]]]
2   [[[34 36]
      [38 40]]
     [[42 44]
      [46 48]]]
Name: two, dtype: TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
>>> # Once you do an aggregation on that column that returns a single
>>> # row's value, you get back our TensorArrayElement type.
>>> tensor = col.mean()
>>> type(tensor) 
ray.data.extensions.tensor_extension.TensorArrayElement
>>> tensor 
array([[[18., 20.],
        [22., 24.]],
       [[26., 28.],
        [30., 32.]]])
>>> # This is a light wrapper around a NumPy ndarray, and can easily
>>> # be converted to an ndarray.
>>> type(tensor.to_numpy()) 
numpy.ndarray
>>> # In addition to doing Pandas operations on the tensor column,
>>> # you can now put the DataFrame into a Dataset.
>>> ds = ray.data.from_pandas(df) 
>>> # Internally, this column is represented as the corresponding
>>> # Arrow tensor extension type.
>>> ds.schema() 
one: int64
two: extension<arrow.py_extension_type<ArrowTensorType>>
>>> # You can write the dataset to Parquet.
>>> ds.write_parquet("/some/path") 
>>> # And you can read it back.
>>> read_ds = ray.data.read_parquet("/some/path") 
>>> read_ds.schema() 
one: int64
two: extension<arrow.py_extension_type<ArrowTensorType>>
>>> read_df = ray.get(read_ds.to_pandas_refs())[0] 
>>> read_df.dtypes 
one          int64
two    TensorDtype(shape=(3, 2, 2, 2), dtype=int64)
dtype: object
>>> # The tensor extension type is preserved along the
>>> # Pandas --> Arrow --> Parquet --> Arrow --> Pandas
>>> # conversion chain.
>>> read_df.equals(df) 
True

PublicAPI (beta): This API is in beta and may change before becoming stable.

property type#

The scalar type for the array, e.g. int. It is expected that ExtensionArray[item] returns an instance of ExtensionDtype.type for a scalar item, assuming that the value is valid (not NA). NA values do not need to be instances of type.

property element_dtype#

The dtype of the underlying tensor elements.

property element_shape#

The shape of the underlying tensor elements. This will be a tuple of Nones if the corresponding TensorArray for this TensorDtype holds variable-shaped tensor elements.

property is_variable_shaped#

Whether the corresponding TensorArray for this TensorDtype holds variable-shaped tensor elements.
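
The relationship between element_shape and is_variable_shaped described above can be sketched in plain Python. This is an illustration under the assumption that a variable-shaped dtype reports a shape consisting only of None entries, as the element_shape description states. The function below is a hypothetical stand-in, not Ray's implementation.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


def is_variable_shaped(element_shape: Shape) -> bool:
    """Hypothetical stand-in: True when every dimension is unknown (None).

    Assumes a variable-shaped dtype reports a tuple of Nones as its shape.
    """
    return bool(element_shape) and all(dim is None for dim in element_shape)


fixed_shape: Shape = (2, 2, 2)              # every element is 2 x 2 x 2
variable_shape: Shape = (None, None, None)  # 3-D elements of unknown sizes

print(is_variable_shaped(fixed_shape))     # False
print(is_variable_shaped(variable_shape))  # True
```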

property name: str#

A string identifying the data type. It is used for display in, e.g., Series.dtype.

classmethod construct_from_string(string: str)[source]#

Construct this type from a string.

This is useful mainly for data types that accept parameters. For example, a period dtype accepts a frequency parameter that can be set as period[H] (where H means hourly frequency).

By default, in the abstract class, just the name of the type is expected. But subclasses can override this method to accept parameters.

Parameters

string (str) – The name of the type, for example category.

Returns

Instance of the dtype.

Return type

ExtensionDtype

Raises

TypeError – If a class cannot be constructed from this ‘string’.

Examples

For extension dtypes with arguments the following may be an adequate implementation.

>>> import re
>>> @classmethod
... def construct_from_string(cls, string):
...     pattern = re.compile(r"^my_type\[(?P<arg_name>.+)\]$")
...     match = pattern.match(string)
...     if match:
...         return cls(**match.groupdict())
...     else:
...         raise TypeError(
...             f"Cannot construct a '{cls.__name__}' from '{string}'"
...         )
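
To show the pattern above end to end, here is a self-contained sketch that attaches the same parsing logic to a toy class. `MyType` and its `arg_name` parameter are hypothetical; the class deliberately does not subclass ExtensionDtype so that it runs with the standard library alone.

```python
import re


class MyType:
    """Hypothetical parameterized dtype-like class, for illustration only."""

    _pattern = re.compile(r"^my_type\[(?P<arg_name>.+)\]$")

    def __init__(self, arg_name: str):
        self.arg_name = arg_name

    @property
    def name(self) -> str:
        # The string form that construct_from_string is able to parse back.
        return f"my_type[{self.arg_name}]"

    @classmethod
    def construct_from_string(cls, string: str) -> "MyType":
        if not isinstance(string, str):
            raise TypeError(f"Expected a string, got {type(string).__name__}")
        match = cls._pattern.match(string)
        if match is None:
            raise TypeError(
                f"Cannot construct a '{cls.__name__}' from '{string}'"
            )
        # Named groups of the pattern become constructor keyword arguments.
        return cls(**match.groupdict())


parsed = MyType.construct_from_string("my_type[hourly]")
print(parsed.arg_name)  # hourly
print(parsed.name)      # my_type[hourly]
```

Keeping name and construct_from_string inverse to each other means a dtype can be round-tripped through its string form; a string that does not match the pattern raises TypeError, as the Raises entry above requires.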

classmethod construct_array_type()[source]#

Return the array type associated with this dtype.

Returns

Return type

type