# Source code for ray.data.namespace_expressions.struct_namespace

"""Struct namespace for expression operations on struct-typed columns."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Union

import pyarrow
import pyarrow.compute as pc

from ray.data.datatype import DataType
from ray.data.expressions import _create_pyarrow_compute_udf

if TYPE_CHECKING:
    from ray.data.expressions import Expr, PyArrowComputeUDFExpr


[docs] @dataclass class _StructNamespace: """Namespace for struct operations on expression columns. This namespace provides methods for operating on struct-typed columns using PyArrow compute functions. Example: >>> from ray.data.expressions import col >>> # Access a field using method >>> expr = col("user_record").struct.field("age") >>> # Access a field using bracket notation >>> expr = col("user_record").struct["age"] >>> # Access nested field >>> expr = col("user_record").struct["address"].struct["city"] """ _expr: Expr def __getitem__(self, key: Union[str, int]) -> "PyArrowComputeUDFExpr": """Extract a field using bracket notation. Args: key: The field name or index to extract. Returns: PyArrowComputeUDFExpr that extracts the specified field from each struct. Example: >>> from ray.data.expressions import col >>> expr = col("user").struct["age"] # Get age field by name >>> expr = col("user").struct[1] # Get second field by index >>> expr = col("user").struct["address"].struct["city"] # Get nested city field """ if isinstance(key, str): return self.field(key) if isinstance(key, int) and not isinstance(key, bool): return self.field_by_index(key) raise TypeError( f"Struct indices must be strings or integers, not {type(key).__name__}" )
[docs] def field(self, field_name: str) -> "PyArrowComputeUDFExpr": """Extract a field from a struct. Args: field_name: The name of the field to extract. Returns: UDFExpr that extracts the specified field from each struct. """ return_dtype = DataType(object) if self._expr.data_type.is_arrow_type(): arrow_type = self._expr.data_type.to_arrow_dtype() if pyarrow.types.is_struct(arrow_type): try: field_type = arrow_type.field(field_name).type return_dtype = DataType.from_arrow(field_type) except KeyError: pass return _create_pyarrow_compute_udf(pc.struct_field, return_dtype)( self._expr, field_name )
[docs] def field_by_index(self, index: int) -> "PyArrowComputeUDFExpr": """Extract a field from a struct by index. Args: index: The index of the field to extract. Returns: UDFExpr that extracts the specified field from each struct. """ if not isinstance(index, int) or isinstance(index, bool): raise TypeError( f"Struct field index must be an integer, not {type(index).__name__}" ) if index < 0: raise ValueError(f"Struct field index must be non-negative, got {index}") return_dtype = DataType(object) if self._expr.data_type.is_arrow_type(): arrow_type = self._expr.data_type.to_arrow_dtype() if pyarrow.types.is_struct(arrow_type): try: field_type = arrow_type[index].type return_dtype = DataType.from_arrow(field_type) except IndexError: pass return _create_pyarrow_compute_udf(pc.struct_field, return_dtype)( self._expr, index )