Source code for ray.data.namespace_expressions.string_namespace

"""String namespace for expression operations on string-typed columns."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Literal

import pyarrow
import pyarrow.compute as pc

from ray.data.datatype import DataType
from ray.data.expressions import pyarrow_udf

if TYPE_CHECKING:
    from ray.data.expressions import Expr, UDFExpr


def _create_str_udf(
    pc_func: Callable[..., pyarrow.Array], return_dtype: DataType
) -> Callable[..., "UDFExpr"]:
    """Helper to create a string UDF that wraps a PyArrow compute function.

    This helper handles all types of PyArrow compute operations:
    - Unary operations (no args): upper(), lower(), reverse()
    - Pattern operations (pattern + args): starts_with(), contains()
    - Multi-argument operations: replace(), replace_slice()

    Args:
        pc_func: PyArrow compute function that takes (array, *positional, **kwargs)
        return_dtype: The return data type

    Returns:
        A callable that creates UDFExpr instances
    """

    def wrapper(expr: Expr, *positional: Any, **kwargs: Any) -> "UDFExpr":
        @pyarrow_udf(return_dtype=return_dtype)
        def udf(arr: pyarrow.Array) -> pyarrow.Array:
            return pc_func(arr, *positional, **kwargs)

        return udf(expr)

    return wrapper


@dataclass
class _StringNamespace:
    """Namespace for string operations on expression columns.

    This namespace provides methods for operating on string-typed columns using
    PyArrow compute functions. Every method returns a new UDF expression built
    from the wrapped expression; nothing is evaluated when the method is called.

    Unless stated otherwise, extra ``*args`` / ``**kwargs`` accepted by a method
    are forwarded unchanged to the underlying PyArrow compute function.

    Example:
        >>> from ray.data.expressions import col
        >>> # Convert to uppercase
        >>> expr = col("name").str.upper()
        >>> # Get string length
        >>> expr = col("name").str.len()
        >>> # Check if string starts with a prefix
        >>> expr = col("name").str.starts_with("A")
    """

    # Expression whose values the string operations are applied to.
    _expr: Expr

    # Length methods

    def len(self) -> "UDFExpr":
        """Get the length of each string in characters."""
        return _create_str_udf(pc.utf8_length, DataType.int32())(self._expr)

    def byte_len(self) -> "UDFExpr":
        """Get the length of each string in bytes."""
        return _create_str_udf(pc.binary_length, DataType.int32())(self._expr)

    # Case methods

    def upper(self) -> "UDFExpr":
        """Convert strings to uppercase."""
        return _create_str_udf(pc.utf8_upper, DataType.string())(self._expr)

    def lower(self) -> "UDFExpr":
        """Convert strings to lowercase."""
        return _create_str_udf(pc.utf8_lower, DataType.string())(self._expr)

    def capitalize(self) -> "UDFExpr":
        """Capitalize the first character of each string."""
        return _create_str_udf(pc.utf8_capitalize, DataType.string())(self._expr)

    def title(self) -> "UDFExpr":
        """Convert strings to title case."""
        return _create_str_udf(pc.utf8_title, DataType.string())(self._expr)

    def swapcase(self) -> "UDFExpr":
        """Swap the case of each character."""
        return _create_str_udf(pc.utf8_swapcase, DataType.string())(self._expr)

    # Predicate methods

    def is_alpha(self) -> "UDFExpr":
        """Check if strings contain only alphabetic characters."""
        return _create_str_udf(pc.utf8_is_alpha, DataType.bool())(self._expr)

    def is_alnum(self) -> "UDFExpr":
        """Check if strings contain only alphanumeric characters."""
        return _create_str_udf(pc.utf8_is_alnum, DataType.bool())(self._expr)

    def is_digit(self) -> "UDFExpr":
        """Check if strings contain only digits."""
        return _create_str_udf(pc.utf8_is_digit, DataType.bool())(self._expr)

    def is_decimal(self) -> "UDFExpr":
        """Check if strings contain only decimal characters."""
        return _create_str_udf(pc.utf8_is_decimal, DataType.bool())(self._expr)

    def is_numeric(self) -> "UDFExpr":
        """Check if strings contain only numeric characters."""
        return _create_str_udf(pc.utf8_is_numeric, DataType.bool())(self._expr)

    def is_space(self) -> "UDFExpr":
        """Check if strings contain only whitespace."""
        return _create_str_udf(pc.utf8_is_space, DataType.bool())(self._expr)

    def is_lower(self) -> "UDFExpr":
        """Check if strings are lowercase."""
        return _create_str_udf(pc.utf8_is_lower, DataType.bool())(self._expr)

    def is_upper(self) -> "UDFExpr":
        """Check if strings are uppercase."""
        return _create_str_udf(pc.utf8_is_upper, DataType.bool())(self._expr)

    def is_title(self) -> "UDFExpr":
        """Check if strings are title-cased."""
        return _create_str_udf(pc.utf8_is_title, DataType.bool())(self._expr)

    def is_printable(self) -> "UDFExpr":
        """Check if strings contain only printable characters."""
        return _create_str_udf(pc.utf8_is_printable, DataType.bool())(self._expr)

    def is_ascii(self) -> "UDFExpr":
        """Check if strings contain only ASCII characters."""
        return _create_str_udf(pc.string_is_ascii, DataType.bool())(self._expr)

    # Searching methods

    def starts_with(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Check if strings start with a pattern.

        Args:
            pattern: Prefix to look for.
            *args: Extra positional arguments for ``pc.starts_with``.
            **kwargs: Extra keyword arguments for ``pc.starts_with``.
        """
        return _create_str_udf(pc.starts_with, DataType.bool())(
            self._expr, pattern, *args, **kwargs
        )

    def ends_with(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Check if strings end with a pattern.

        Args:
            pattern: Suffix to look for.
            *args: Extra positional arguments for ``pc.ends_with``.
            **kwargs: Extra keyword arguments for ``pc.ends_with``.
        """
        return _create_str_udf(pc.ends_with, DataType.bool())(
            self._expr, pattern, *args, **kwargs
        )

    def contains(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Check if strings contain a substring.

        Args:
            pattern: Literal substring to look for.
            *args: Extra positional arguments for ``pc.match_substring``.
            **kwargs: Extra keyword arguments for ``pc.match_substring``.
        """
        return _create_str_udf(pc.match_substring, DataType.bool())(
            self._expr, pattern, *args, **kwargs
        )

    def match(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Match strings against a SQL LIKE pattern.

        Args:
            pattern: SQL LIKE pattern.
            *args: Extra positional arguments for ``pc.match_like``.
            **kwargs: Extra keyword arguments for ``pc.match_like``.
        """
        return _create_str_udf(pc.match_like, DataType.bool())(
            self._expr, pattern, *args, **kwargs
        )

    def find(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Find the first occurrence of a substring.

        Args:
            pattern: Literal substring to locate.
            *args: Extra positional arguments for ``pc.find_substring``.
            **kwargs: Extra keyword arguments for ``pc.find_substring``.
        """
        return _create_str_udf(pc.find_substring, DataType.int32())(
            self._expr, pattern, *args, **kwargs
        )

    def count(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Count occurrences of a substring.

        Args:
            pattern: Literal substring to count.
            *args: Extra positional arguments for ``pc.count_substring``.
            **kwargs: Extra keyword arguments for ``pc.count_substring``.
        """
        return _create_str_udf(pc.count_substring, DataType.int32())(
            self._expr, pattern, *args, **kwargs
        )

    def find_regex(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Find the first occurrence matching a regex pattern.

        Args:
            pattern: Regular expression to locate.
            *args: Extra positional arguments for ``pc.find_substring_regex``.
            **kwargs: Extra keyword arguments for ``pc.find_substring_regex``.
        """
        return _create_str_udf(pc.find_substring_regex, DataType.int32())(
            self._expr, pattern, *args, **kwargs
        )

    def count_regex(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Count occurrences matching a regex pattern.

        Args:
            pattern: Regular expression to count.
            *args: Extra positional arguments for ``pc.count_substring_regex``.
            **kwargs: Extra keyword arguments for ``pc.count_substring_regex``.
        """
        return _create_str_udf(pc.count_substring_regex, DataType.int32())(
            self._expr, pattern, *args, **kwargs
        )

    def match_regex(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Check if strings match a regex pattern.

        Args:
            pattern: Regular expression to search for.
            *args: Extra positional arguments for ``pc.match_substring_regex``.
            **kwargs: Extra keyword arguments for ``pc.match_substring_regex``.
        """
        return _create_str_udf(pc.match_substring_regex, DataType.bool())(
            self._expr, pattern, *args, **kwargs
        )

    # Transformation methods

    def reverse(self) -> "UDFExpr":
        """Reverse each string."""
        return _create_str_udf(pc.utf8_reverse, DataType.string())(self._expr)

    def slice(self, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Slice strings by codeunit indices.

        Args:
            *args: Positional arguments for ``pc.utf8_slice_codeunits``.
            **kwargs: Keyword arguments for ``pc.utf8_slice_codeunits``.
        """
        return _create_str_udf(pc.utf8_slice_codeunits, DataType.string())(
            self._expr, *args, **kwargs
        )

    def replace(
        self, pattern: str, replacement: str, *args: Any, **kwargs: Any
    ) -> "UDFExpr":
        """Replace occurrences of a substring.

        Args:
            pattern: Literal substring to replace.
            replacement: Text substituted for each occurrence.
            *args: Extra positional arguments for ``pc.replace_substring``.
            **kwargs: Extra keyword arguments for ``pc.replace_substring``.
        """
        return _create_str_udf(pc.replace_substring, DataType.string())(
            self._expr, pattern, replacement, *args, **kwargs
        )

    def replace_regex(
        self, pattern: str, replacement: str, *args: Any, **kwargs: Any
    ) -> "UDFExpr":
        """Replace occurrences matching a regex pattern.

        Args:
            pattern: Regular expression to replace.
            replacement: Text substituted for each match.
            *args: Extra positional arguments for ``pc.replace_substring_regex``.
            **kwargs: Extra keyword arguments for ``pc.replace_substring_regex``.
        """
        return _create_str_udf(pc.replace_substring_regex, DataType.string())(
            self._expr, pattern, replacement, *args, **kwargs
        )

    def replace_slice(
        self, start: int, stop: int, replacement: str, *args: Any, **kwargs: Any
    ) -> "UDFExpr":
        """Replace a slice with a string.

        Args:
            start: Start index of the slice to replace.
            stop: Stop index of the slice to replace.
            replacement: Text inserted in place of the slice.
            *args: Extra positional arguments for ``pc.utf8_replace_slice``.
            **kwargs: Extra keyword arguments for ``pc.utf8_replace_slice``.
        """
        # Use the UTF-8 aware kernel: the binary variant measures start/stop in
        # bytes, which can cut through multi-byte characters of a string column
        # and disagrees with the character-based indices used by ``slice``.
        return _create_str_udf(pc.utf8_replace_slice, DataType.string())(
            self._expr, start, stop, replacement, *args, **kwargs
        )

    def split(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Split strings by a pattern.

        Args:
            pattern: Literal separator to split on.
            *args: Extra positional arguments for ``pc.split_pattern``.
            **kwargs: Extra keyword arguments for ``pc.split_pattern``.
        """
        return _create_str_udf(pc.split_pattern, DataType(object))(
            self._expr, pattern, *args, **kwargs
        )

    def split_regex(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Split strings by a regex pattern.

        Args:
            pattern: Regular expression to split on.
            *args: Extra positional arguments for ``pc.split_pattern_regex``.
            **kwargs: Extra keyword arguments for ``pc.split_pattern_regex``.
        """
        return _create_str_udf(pc.split_pattern_regex, DataType(object))(
            self._expr, pattern, *args, **kwargs
        )

    def split_whitespace(self, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Split strings on whitespace.

        Args:
            *args: Positional arguments for ``pc.utf8_split_whitespace``.
            **kwargs: Keyword arguments for ``pc.utf8_split_whitespace``.
        """
        return _create_str_udf(pc.utf8_split_whitespace, DataType(object))(
            self._expr, *args, **kwargs
        )

    def extract(self, pattern: str, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Extract substrings captured by a regex pattern.

        Args:
            pattern: Regular expression; per the PyArrow documentation the
                output fields come from its named capture groups.
            *args: Extra positional arguments for ``pc.extract_regex``.
            **kwargs: Extra keyword arguments for ``pc.extract_regex``.
        """
        # ``pc.extract_regex`` emits a struct of captured groups rather than a
        # plain string, so declare a generic object dtype (as the split methods
        # do for their non-scalar results) instead of ``string``.
        return _create_str_udf(pc.extract_regex, DataType(object))(
            self._expr, pattern, *args, **kwargs
        )

    def repeat(self, n: int, *args: Any, **kwargs: Any) -> "UDFExpr":
        """Repeat each string n times.

        Args:
            n: Number of repetitions.
            *args: Extra positional arguments for ``pc.binary_repeat``.
            **kwargs: Extra keyword arguments for ``pc.binary_repeat``.
        """
        return _create_str_udf(pc.binary_repeat, DataType.string())(
            self._expr, n, *args, **kwargs
        )

    def center(
        self, width: int, padding: str = " ", *args: Any, **kwargs: Any
    ) -> "UDFExpr":
        """Center strings in a field of given width.

        Args:
            width: Target width.
            padding: Fill string passed to ``pc.utf8_center``.
            *args: Extra positional arguments for ``pc.utf8_center``.
            **kwargs: Extra keyword arguments for ``pc.utf8_center``.
        """
        return _create_str_udf(pc.utf8_center, DataType.string())(
            self._expr, width, padding, *args, **kwargs
        )

    # Custom methods that need special logic beyond simple PyArrow function calls

    def strip(self, characters: str | None = None) -> "UDFExpr":
        """Remove leading and trailing whitespace or specified characters.

        Args:
            characters: Characters to remove. If None, removes whitespace.

        Returns:
            UDFExpr that strips characters from both ends.
        """

        @pyarrow_udf(return_dtype=DataType.string())
        def _str_strip(arr: pyarrow.Array) -> pyarrow.Array:
            if characters is None:
                return pc.utf8_trim_whitespace(arr)
            return pc.utf8_trim(arr, characters=characters)

        return _str_strip(self._expr)

    def lstrip(self, characters: str | None = None) -> "UDFExpr":
        """Remove leading whitespace or specified characters.

        Args:
            characters: Characters to remove. If None, removes whitespace.

        Returns:
            UDFExpr that strips characters from the left.
        """

        @pyarrow_udf(return_dtype=DataType.string())
        def _str_lstrip(arr: pyarrow.Array) -> pyarrow.Array:
            if characters is None:
                return pc.utf8_ltrim_whitespace(arr)
            return pc.utf8_ltrim(arr, characters=characters)

        return _str_lstrip(self._expr)

    def rstrip(self, characters: str | None = None) -> "UDFExpr":
        """Remove trailing whitespace or specified characters.

        Args:
            characters: Characters to remove. If None, removes whitespace.

        Returns:
            UDFExpr that strips characters from the right.
        """

        @pyarrow_udf(return_dtype=DataType.string())
        def _str_rstrip(arr: pyarrow.Array) -> pyarrow.Array:
            if characters is None:
                return pc.utf8_rtrim_whitespace(arr)
            return pc.utf8_rtrim(arr, characters=characters)

        return _str_rstrip(self._expr)

    # Padding

    def pad(
        self,
        width: int,
        fillchar: str = " ",
        side: Literal["left", "right", "both"] = "right",
    ) -> "UDFExpr":
        """Pad strings to a specified width.

        Args:
            width: Target width.
            fillchar: Character to use for padding.
            side: "left", "right", or "both" for padding side.

        Returns:
            UDFExpr that pads strings.

        Raises:
            ValueError: If ``side`` is not "left", "right", or "both".
        """
        # Resolve the kernel while the expression is being built so an invalid
        # ``side`` is reported immediately rather than when the UDF executes.
        if side == "right":
            pad_func = pc.utf8_rpad
        elif side == "left":
            pad_func = pc.utf8_lpad
        elif side == "both":
            pad_func = pc.utf8_center
        else:
            raise ValueError("side must be 'left', 'right', or 'both'")

        @pyarrow_udf(return_dtype=DataType.string())
        def _str_pad(arr: pyarrow.Array) -> pyarrow.Array:
            return pad_func(arr, width=width, padding=fillchar)

        return _str_pad(self._expr)