# Source code for ray.data.namespace_expressions.list_namespace

"""List namespace for expression operations on list-typed columns."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Union

import pyarrow
import pyarrow.compute as pc

from ray.data.datatype import DataType
from ray.data.expressions import pyarrow_udf

if TYPE_CHECKING:
    from ray.data.expressions import Expr, UDFExpr


def _list_element_from_end(
    arr: Union[pyarrow.Array, pyarrow.ChunkedArray], index: int
) -> Union[pyarrow.Array, pyarrow.ChunkedArray]:
    """Extract the element at a negative ``index`` from every list in ``arr``.

    ``pyarrow.compute.list_element`` only accepts non-negative indices, so the
    element is located in the flattened values instead. ``list_flatten`` emits
    nothing for null lists, hence each list ends at the running total of the
    lengths of all non-null lists up to and including it, and its element
    ``index`` (which is < 0) sits at ``end + index`` in the flattened values.

    Args:
        arr: A list-like Arrow array or chunked array.
        index: A negative index; ``-1`` selects the last element.

    Returns:
        An array (or chunked array, mirroring the input) of the list value
        type. Null lists produce null.

    Raises:
        pyarrow.ArrowInvalid: If a non-null list has fewer than ``-index``
            elements, mirroring ``list_element``'s out-of-bounds error for
            non-negative indices.
    """
    if isinstance(arr, pyarrow.ChunkedArray):
        # Positions are computed per chunk so they stay chunk-local.
        return pyarrow.chunked_array(
            [_list_element_from_end(chunk, index) for chunk in arr.chunks],
            type=arr.type.value_type,
        )

    lengths = pc.list_value_length(arr)  # null where the list itself is null
    if pc.any(pc.less(lengths, -index)).as_py():
        raise pyarrow.ArrowInvalid(
            f"Index {index} is out of bounds: at least one list has fewer "
            f"than {-index} elements"
        )
    # Null lists contribute no flattened values, so count them as length 0.
    # int64 keeps the running total from overflowing on large inputs.
    ends = pc.cumulative_sum(pc.fill_null(pc.cast(lengths, pyarrow.int64()), 0))
    # Null lists must yield null rather than borrowing a neighbour's element.
    positions = pc.if_else(
        pc.is_valid(lengths),
        pc.add(ends, index),
        pyarrow.scalar(None, type=pyarrow.int64()),
    )
    return pc.take(pc.list_flatten(arr), positions)


@dataclass
class _ListNamespace:
    """Namespace for list operations on expression columns.

    This namespace provides methods for operating on list-typed columns using
    PyArrow compute functions.

    Example:
        >>> from ray.data.expressions import col
        >>> # Get length of list column
        >>> expr = col("items").list.len()
        >>> # Get first item using method
        >>> expr = col("items").list.get(0)
        >>> # Get first item using indexing
        >>> expr = col("items").list[0]
        >>> # Get last item using a negative index
        >>> expr = col("items").list[-1]
        >>> # Slice list
        >>> expr = col("items").list[1:3]
    """

    # The expression whose list-typed result these methods operate on.
    _expr: Expr

    def _input_list_type(self) -> pyarrow.DataType | None:
        """Return the input expression's Arrow list type, if statically known.

        Only ``list``, ``large_list`` and ``fixed_size_list`` are recognized;
        any other type (including a non-Arrow data type) yields ``None``.
        """
        data_type = self._expr.data_type
        if not data_type.is_arrow_type():
            return None
        arrow_type = data_type.to_arrow_dtype()
        if (
            pyarrow.types.is_list(arrow_type)
            or pyarrow.types.is_large_list(arrow_type)
            or pyarrow.types.is_fixed_size_list(arrow_type)
        ):
            return arrow_type
        return None

    def len(self) -> "UDFExpr":
        """Get the length of each list.

        Returns:
            UDFExpr producing the number of elements in each list; null lists
            produce null. The declared result type follows
            ``pyarrow.compute.list_value_length``: int64 for ``large_list``
            inputs, int32 otherwise.
        """
        list_type = self._input_list_type()
        if list_type is not None and pyarrow.types.is_large_list(list_type):
            return_dtype = DataType.from_arrow(pyarrow.int64())
        else:
            return_dtype = DataType.int32()

        @pyarrow_udf(return_dtype=return_dtype)
        def _list_len(arr: pyarrow.Array) -> pyarrow.Array:
            return pc.list_value_length(arr)

        return _list_len(self._expr)

    def __getitem__(self, key: Union[int, slice]) -> "UDFExpr":
        """Get element or slice using bracket notation.

        Args:
            key: An integer for element access (negative values count from the
                end of each list) or a slice for list slicing.

        Returns:
            UDFExpr that extracts the element or slice.

        Raises:
            TypeError: If ``key`` is neither an int nor a slice.

        Example:
            >>> col("items").list[0]  # Get first item  # doctest: +SKIP
            >>> col("items").list[1:3]  # Get slice [1, 3)  # doctest: +SKIP
            >>> col("items").list[-1]  # Get last item  # doctest: +SKIP
        """
        if isinstance(key, int):
            return self.get(key)
        elif isinstance(key, slice):
            return self.slice(key.start, key.stop, key.step)
        else:
            raise TypeError(
                f"List indices must be integers or slices, not {type(key).__name__}"
            )

    def get(self, index: int) -> "UDFExpr":
        """Get element at the specified index from each list.

        Args:
            index: The index of the element to retrieve. Negative indices
                count from the end of each list (``-1`` is the last element).

        Returns:
            UDFExpr that extracts the element at the given index. Null lists
            produce null. When executed, an index that is out of range for a
            non-null list raises ``pyarrow.ArrowInvalid``.
        """
        # Infer the element type from the list's value type when it is known.
        list_type = self._input_list_type()
        if list_type is not None:
            return_dtype = DataType.from_arrow(list_type.value_type)
        else:
            return_dtype = DataType(object)  # fallback: element type unknown

        @pyarrow_udf(return_dtype=return_dtype)
        def _list_get(arr: pyarrow.Array) -> pyarrow.Array:
            if index >= 0:
                return pc.list_element(arr, index)
            # list_element rejects negative indices, so resolve them manually.
            return _list_element_from_end(arr, index)

        return _list_get(self._expr)

    def slice(
        self, start: int | None = None, stop: int | None = None, step: int | None = None
    ) -> "UDFExpr":
        """Slice each list.

        The arguments are forwarded unchanged to ``pyarrow.compute.list_slice``.
        Unlike Python slicing, negative ``start``/``stop`` values are therefore
        not interpreted as counting from the end of each list.

        Args:
            start: Start index (inclusive). Defaults to 0.
            stop: Stop index (exclusive). Defaults to list length.
            step: Step size. Defaults to 1.

        Returns:
            UDFExpr that extracts a slice from each list.
        """
        # The declared return type is the input list type.
        # NOTE(review): for fixed_size_list inputs Arrow may produce a list of a
        # different size than the input's; confirm the declared dtype is right.
        return_dtype = self._expr.data_type

        @pyarrow_udf(return_dtype=return_dtype)
        def _list_slice(arr: pyarrow.Array) -> pyarrow.Array:
            return pc.list_slice(
                arr,
                start=0 if start is None else start,
                stop=stop,
                step=1 if step is None else step,
            )

        return _list_slice(self._expr)