Source code for ray.data.namespace_expressions.list_namespace
"""List namespace for expression operations on list-typed columns."""
from __future__ import annotations

import operator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union

import pyarrow
import pyarrow.compute as pc

from ray.data.datatype import DataType
from ray.data.expressions import pyarrow_udf
if TYPE_CHECKING:
from ray.data.expressions import Expr, UDFExpr
[docs]
@dataclass
class _ListNamespace:
    """Namespace for list operations on expression columns.

    This namespace provides methods for operating on list-typed columns using
    PyArrow compute functions. Every method builds a PyArrow-backed UDF
    expression around the wrapped expression; nothing is evaluated here.

    Example:
        >>> from ray.data.expressions import col
        >>> # Get length of list column
        >>> expr = col("items").list.len()
        >>> # Get first item using method
        >>> expr = col("items").list.get(0)
        >>> # Get first item using indexing
        >>> expr = col("items").list[0]
        >>> # Slice list
        >>> expr = col("items").list[1:3]
    """

    # The list-typed expression that every operation in this namespace wraps.
    _expr: Expr

    def len(self) -> "UDFExpr":
        """Get the length of each list.

        Returns:
            UDFExpr that yields the length of each list as an int32 value.
        """

        @pyarrow_udf(return_dtype=DataType.int32())
        def _list_len(arr: pyarrow.Array) -> pyarrow.Array:
            lengths = pc.list_value_length(arr)
            # ``list_value_length`` emits int64 for large_list inputs; cast so
            # the produced data always matches the declared int32 return dtype.
            # The default (safe) cast raises instead of silently truncating.
            if lengths.type != pyarrow.int32():
                lengths = pc.cast(lengths, pyarrow.int32())
            return lengths

        return _list_len(self._expr)

    def __getitem__(self, key: Union[int, slice]) -> "UDFExpr":
        """Get element or slice using bracket notation.

        Args:
            key: An integer (or any object implementing ``__index__``) for
                element access, or a slice for list slicing.

        Returns:
            UDFExpr that extracts the element or slice.

        Raises:
            TypeError: If ``key`` is neither a slice nor usable as an index.

        Example:
            >>> col("items").list[0]  # Get first item # doctest: +SKIP
            >>> col("items").list[1:3]  # Get slice [1, 3) # doctest: +SKIP
            >>> col("items").list[-1]  # Get last item # doctest: +SKIP
        """
        if isinstance(key, slice):
            return self.slice(key.start, key.stop, key.step)

        # ``operator.index`` returns ints unchanged and additionally accepts
        # index-like objects (anything defining ``__index__``), mirroring how
        # built-in sequences interpret subscripts.
        try:
            index = operator.index(key)
        except TypeError:
            raise TypeError(
                f"List indices must be integers or slices, not {type(key).__name__}"
            ) from None
        return self.get(index)

    def get(self, index: int) -> "UDFExpr":
        """Get element at the specified index from each list.

        Args:
            index: The index of the element to retrieve. It is forwarded
                unchanged to ``pyarrow.compute.list_element``.
                NOTE(review): negative indices are documented as supported by
                this API; confirm that the installed PyArrow accepts them.

        Returns:
            UDFExpr that extracts the element at the given index. Its dtype is
            the list's value type when that can be inferred from the wrapped
            expression, otherwise a generic ``object`` dtype.
        """
        # Infer return type from the list's value type; fall back to object
        # when the expression's type is unknown or not list-like.
        return_dtype = DataType(object)
        if self._expr.data_type.is_arrow_type():
            arrow_type = self._expr.data_type.to_arrow_dtype()
            is_list_like = (
                pyarrow.types.is_list(arrow_type)
                or pyarrow.types.is_large_list(arrow_type)
                or pyarrow.types.is_fixed_size_list(arrow_type)
            )
            if is_list_like:
                return_dtype = DataType.from_arrow(arrow_type.value_type)

        @pyarrow_udf(return_dtype=return_dtype)
        def _list_get(arr: pyarrow.Array) -> pyarrow.Array:
            return pc.list_element(arr, index)

        return _list_get(self._expr)

    def slice(
        self, start: int | None = None, stop: int | None = None, step: int | None = None
    ) -> "UDFExpr":
        """Slice each list.

        Args:
            start: Start index (inclusive). Defaults to 0.
            stop: Stop index (exclusive). ``None`` is forwarded to PyArrow
                as-is, i.e. no explicit upper bound is passed.
            step: Step size. Defaults to 1.

        Returns:
            UDFExpr that extracts a slice from each list. The declared dtype
            is the dtype of the wrapped expression.
        """
        # Return type is the same as the input list type.
        # NOTE(review): for fixed-size lists the sliced width may differ from
        # the input width -- confirm the declared dtype is still accurate.
        return_dtype = self._expr.data_type

        @pyarrow_udf(return_dtype=return_dtype)
        def _list_slice(arr: pyarrow.Array) -> pyarrow.Array:
            return pc.list_slice(
                arr,
                start=0 if start is None else start,
                stop=stop,
                step=1 if step is None else step,
            )

        return _list_slice(self._expr)