Source code for ray.data.stats

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import pandas as pd
import pyarrow as pa

from ray.air.util.tensor_extensions.arrow import convert_to_pyarrow_array
from ray.data.aggregate import (
    AggregateFnV2,
    ApproximateQuantile,
    ApproximateTopK,
    Count,
    Max,
    Mean,
    Min,
    MissingValuePercentage,
    Std,
    ZeroPercentage,
)
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.data.dataset import Schema
    from ray.data.datatype import DataType, TypeCategory


logger = logging.getLogger(__name__)


@PublicAPI(stability="alpha")
@dataclass
class DatasetSummary:
    """Wrapper for dataset summary statistics.

    Provides methods to access computed statistics.

    Attributes:
        dataset_schema: PyArrow schema of the original dataset
    """

    STATISTIC_COLUMN = "statistic"

    # PyArrow requires that each column's values conform to the column's dtype as
    # defined by the schema. However, aggregation results might produce statistics
    # with types different from the original column (e.g., 'count' is int64 even
    # for string columns). To handle this, we split statistics into two tables:
    # 1. _stats_matching_column_dtype: Statistics that share the same type as the
    #    original column (e.g., min/max for numerical columns). These preserve
    #    the original column's dtype.
    # 2. _stats_mismatching_column_dtype: Statistics with different types (e.g., count,
    #    missing_pct). These use inferred types (e.g., float64 for count).
    _stats_matching_column_dtype: pa.Table
    _stats_mismatching_column_dtype: pa.Table
    dataset_schema: pa.Schema
    columns: list[str]

    def _safe_convert_table(self, table: pa.Table):
        """Safely convert a PyArrow table to pandas, handling problematic extension types.

        Args:
            table: PyArrow table to convert

        Returns:
            pandas DataFrame with converted data
        """
        from ray.data.block import BlockAccessor

        try:
            return BlockAccessor.for_block(table).to_pandas()
        except (TypeError, ValueError, pa.ArrowInvalid) as e:
            logger.warning(
                f"Direct conversion to pandas failed ({e}), "
                "attempting column-by-column conversion"
            )
            result_data = {}
            for col_name in table.schema.names:
                col = table.column(col_name)
                try:
                    result_data[col_name] = col.to_pandas()
                except (TypeError, ValueError, pa.ArrowInvalid):
                    # Cast problematic columns to null type
                    null_col = pa.nulls(len(col), type=pa.null())
                    result_data[col_name] = null_col.to_pandas()
            return pd.DataFrame(result_data)

    def _set_statistic_index(self, df: pd.DataFrame) -> pd.DataFrame:
        """Set the statistic column as index if it exists, otherwise return an empty DataFrame.

        Args:
            df: DataFrame to set index on

        Returns:
            DataFrame with statistic column as index, or empty DataFrame if column missing
        """
        if self.STATISTIC_COLUMN in df.columns:
            return df.set_index(self.STATISTIC_COLUMN)
        return pd.DataFrame()

    def to_pandas(self):
        """Convert the summary to a single pandas DataFrame.

        Combines statistics from both schema-matching and schema-changing tables.

        Note: Some PyArrow extension types (like TensorExtensionType) may fail to
        convert to pandas when all values in a column are None. In such cases, this
        method attempts to convert column-by-column, casting problematic columns to
        null type.

        Returns:
            DataFrame with all statistics, where rows are unique statistics from
            both tables
        """
        df_matching = self._set_statistic_index(
            self._safe_convert_table(self._stats_matching_column_dtype)
        )
        df_changing = self._set_statistic_index(
            self._safe_convert_table(self._stats_mismatching_column_dtype)
        )

        # Handle case where both are empty
        if df_matching.empty and df_changing.empty:
            return pd.DataFrame(columns=[self.STATISTIC_COLUMN])

        # Combine tables: prefer schema_matching values, fill with schema_changing
        result = df_matching.combine_first(df_changing)

        return (
            result.reset_index()
            .sort_values(self.STATISTIC_COLUMN)
            .reset_index(drop=True)
        )

    def _extract_column_from_table(
        self, table: pa.Table, column: str
    ) -> Optional[pd.DataFrame]:
        """Extract a column from a PyArrow table if it exists.

        Args:
            table: PyArrow table to extract from
            column: Column name to extract

        Returns:
            DataFrame with 'statistic' and 'value' columns, or None if the column
            doesn't exist
        """
        if column not in table.schema.names:
            return None
        df = self._safe_convert_table(table)[[self.STATISTIC_COLUMN, column]]
        return df.rename(columns={column: "value"})

    def get_column_stats(self, column: str):
        """Get all statistics for a specific column, merging from both tables.

        Args:
            column: Column name to get statistics for

        Returns:
            DataFrame with all statistics for the column

        Raises:
            ValueError: If the column is not found in either summary table
        """
        dfs = [
            df
            for table in [
                self._stats_matching_column_dtype,
                self._stats_mismatching_column_dtype,
            ]
            if (df := self._extract_column_from_table(table, column)) is not None
        ]

        if not dfs:
            raise ValueError(f"Column '{column}' not found in summary tables")

        # Concatenate and merge duplicate statistics (prefer non-null values)
        combined = pd.concat(dfs, ignore_index=True)

        # Group by statistic and take the first non-null value for each group
        def first_non_null(series):
            non_null = series.dropna()
            return non_null.iloc[0] if len(non_null) > 0 else None

        result = (
            combined.groupby(self.STATISTIC_COLUMN, sort=False)["value"]
            .apply(first_non_null)
            .reset_index()
            .sort_values(self.STATISTIC_COLUMN)
            .reset_index(drop=True)
        )
        return result
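
A minimal sketch (not part of the module) of how the two stat tables feed the accessors above. The values are made-up illustrations, and it assumes DatasetSummary is importable from ray.data.stats and that the dataclass fields can be passed by keyword exactly as declared.

    import pyarrow as pa
    from ray.data.stats import DatasetSummary

    original_schema = pa.schema([("price", pa.float64())])

    # Stats whose dtype matches the original column (min/max stay float64).
    matching = pa.table({
        "statistic": ["min", "max"],
        "price": pa.array([1.0, 9.5], type=pa.float64()),
    })
    # Stats whose dtype differs from the original column (count is int64).
    mismatching = pa.table({
        "statistic": ["count"],
        "price": pa.array([100], type=pa.int64()),
    })

    summary = DatasetSummary(
        _stats_matching_column_dtype=matching,
        _stats_mismatching_column_dtype=mismatching,
        dataset_schema=original_schema,
        columns=["price"],
    )
    print(summary.to_pandas())                # one row per statistic, merged from both tables
    print(summary.get_column_stats("price"))  # "count", "max", "min" with their values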

@dataclass
class _DtypeAggregators:
    """Container for columns and their aggregators.

    Attributes:
        column_to_dtype: Mapping from column name to dtype string representation
        aggregators: List of all aggregators to apply
    """

    column_to_dtype: Dict[str, str]
    aggregators: List[AggregateFnV2]


def _numerical_aggregators(column: str) -> List[AggregateFnV2]:
    """Generate default metrics for numerical columns.

    This function returns a list of aggregators that compute the following metrics:

    - count
    - mean
    - min
    - max
    - std
    - approximate_quantile (median)
    - missing_value_percentage
    - zero_percentage

    Args:
        column: The name of the numerical column to compute metrics for.

    Returns:
        A list of AggregateFnV2 instances that can be used with Dataset.aggregate()
    """
    return [
        Count(on=column, ignore_nulls=False),
        Mean(on=column, ignore_nulls=True),
        Min(on=column, ignore_nulls=True),
        Max(on=column, ignore_nulls=True),
        Std(on=column, ignore_nulls=True, ddof=0),
        ApproximateQuantile(on=column, quantiles=[0.5]),
        MissingValuePercentage(on=column),
        ZeroPercentage(on=column, ignore_nulls=True),
    ]


def _temporal_aggregators(column: str) -> List[AggregateFnV2]:
    """Generate default metrics for temporal columns.

    This function returns a list of aggregators that compute the following metrics:

    - count
    - min
    - max
    - missing_value_percentage

    Args:
        column: The name of the temporal column to compute metrics for.

    Returns:
        A list of AggregateFnV2 instances that can be used with Dataset.aggregate()
    """
    return [
        Count(on=column, ignore_nulls=False),
        Min(on=column, ignore_nulls=True),
        Max(on=column, ignore_nulls=True),
        MissingValuePercentage(on=column),
    ]


def _basic_aggregators(column: str) -> List[AggregateFnV2]:
    """Generate default metrics for all columns.

    This function returns a list of aggregators that compute the following metrics:

    - count
    - missing_value_percentage
    - approximate_top_k (top 10 most frequent values)

    Args:
        column: The name of the column to compute metrics for.

    Returns:
        A list of AggregateFnV2 instances that can be used with Dataset.aggregate()
    """
    return [
        Count(on=column, ignore_nulls=False),
        MissingValuePercentage(on=column),
        ApproximateTopK(on=column, k=10),
    ]


def _default_dtype_aggregators() -> Dict[
    Union["DataType", "TypeCategory"], Callable[[str], List[AggregateFnV2]]
]:
    """Get the default mapping from Ray Data DataType to aggregator factory functions.

    This function returns factory functions that create aggregators for specific
    columns.

    Returns:
        Dict mapping DataType or TypeCategory to factory functions that take a
        column name and return a list of aggregators for that column.

    Examples:
        >>> from ray.data.datatype import DataType
        >>> from ray.data.stats import _default_dtype_aggregators
        >>> mapping = _default_dtype_aggregators()
        >>> factory = mapping.get(DataType.int32())
        >>> aggs = factory("my_column")  # Creates aggregators for "my_column"
    """
    from ray.data.datatype import DataType, TypeCategory

    # Use pattern-matching types for a cleaner mapping
    return {
        # Numerical types
        DataType.int8(): _numerical_aggregators,
        DataType.int16(): _numerical_aggregators,
        DataType.int32(): _numerical_aggregators,
        DataType.int64(): _numerical_aggregators,
        DataType.uint8(): _numerical_aggregators,
        DataType.uint16(): _numerical_aggregators,
        DataType.uint32(): _numerical_aggregators,
        DataType.uint64(): _numerical_aggregators,
        DataType.float32(): _numerical_aggregators,
        DataType.float64(): _numerical_aggregators,
        DataType.bool(): _numerical_aggregators,
        # String and binary types
        DataType.string(): _basic_aggregators,
        DataType.binary(): _basic_aggregators,
        # Temporal types - pattern matches all temporal types (timestamp, date, time, duration)
        TypeCategory.TEMPORAL: _temporal_aggregators,
        # Note: Complex types like lists, structs, and maps use the fallback logic
        # in _get_aggregators_for_dtype since they can't be easily enumerated.
    }
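
A minimal sketch of these factories in use, assuming ray.data.range() yields a single int64 column named "id" and that Dataset.aggregate() accepts the returned AggregateFnV2 instances, as the docstrings above state; the printed result is illustrative.

    import ray
    from ray.data.stats import _numerical_aggregators

    ds = ray.data.range(100)             # one int64 column named "id"
    aggs = _numerical_aggregators("id")  # count, mean, min, max, std, median, missing %, zero %
    result = ds.aggregate(*aggs)
    print(result)                        # e.g. {"count(id)": 100, "mean(id)": 49.5, ...}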

def _get_fallback_aggregators(column: str, dtype: "DataType") -> List[AggregateFnV2]:
    """Get aggregators using heuristic-based type detection.

    This is a fallback when no explicit mapping is found for the dtype.

    Args:
        column: Column name
        dtype: Ray Data DataType for the column

    Returns:
        List of aggregators suitable for the column type
    """
    try:
        # Check for null type first
        if dtype.is_arrow_type() and pa.types.is_null(dtype._physical_dtype):
            return [Count(on=column, ignore_nulls=False)]
        elif dtype.is_numerical_type():
            return _numerical_aggregators(column)
        elif dtype.is_temporal_type():
            return _temporal_aggregators(column)
        else:
            # Default for strings, binary, lists, nested types, etc.
            return _basic_aggregators(column)
    except Exception as e:
        logger.warning(
            f"Could not determine aggregators for column '{column}' with dtype {dtype}: {e}. "
            f"Using basic aggregators."
        )
        return _basic_aggregators(column)


def _get_aggregators_for_dtype(
    column: str,
    dtype: "DataType",
    dtype_agg_mapping: Dict[
        Union["DataType", "TypeCategory"], Callable[[str], List[AggregateFnV2]]
    ],
) -> List[AggregateFnV2]:
    """Get aggregators for a specific column based on its DataType.

    Attempts to match the dtype against the provided mapping first, then falls
    back to heuristic-based selection if no match is found.

    Args:
        column: Column name
        dtype: Ray Data DataType for the column
        dtype_agg_mapping: Mapping from DataType to factory functions

    Returns:
        List of aggregators with the column name properly set
    """
    from ray.data.datatype import DataType, TypeCategory

    # Try to find a match in the mapping
    for mapping_key, factory in dtype_agg_mapping.items():
        if isinstance(mapping_key, DataType) and dtype == mapping_key:
            return factory(column)
        elif isinstance(mapping_key, (TypeCategory, str)) and dtype.is_of(mapping_key):
            return factory(column)

    # Fallback: use heuristic-based selection
    return _get_fallback_aggregators(column, dtype)


def _dtype_aggregators_for_dataset(
    schema: Optional["Schema"],
    columns: Optional[List[str]] = None,
    dtype_agg_mapping: Optional[
        Dict[Union["DataType", "TypeCategory"], Callable[[str], List[AggregateFnV2]]]
    ] = None,
) -> _DtypeAggregators:
    """Generate aggregators for columns in a dataset based on their DataTypes.

    Args:
        schema: A Ray Schema instance
        columns: List of columns to include. If None, all columns will be included.
        dtype_agg_mapping: Optional user-provided mapping from DataType to aggregator
            factories. Each value should be a callable that takes a column name and
            returns aggregators. This will be merged with the default mapping (user
            mapping takes precedence).

    Returns:
        _DtypeAggregators containing the column-to-dtype mapping and aggregators

    Raises:
        ValueError: If schema is None or if specified columns don't exist in the schema
    """
    from ray.data.datatype import DataType

    if not schema:
        raise ValueError("Dataset must have a schema to determine column types")

    if columns is None:
        columns = schema.names

    # Validate that the requested columns exist in the schema
    missing_cols = set(columns) - set(schema.names)
    if missing_cols:
        raise ValueError(f"Columns {missing_cols} not found in dataset schema")

    # Build the final mapping: defaults + user overrides
    defaults = _default_dtype_aggregators()
    if dtype_agg_mapping:
        # Put user overrides first so they are checked before the default patterns
        final_mapping = dtype_agg_mapping.copy()
        for k, v in defaults.items():
            if k not in final_mapping:
                final_mapping[k] = v
    else:
        final_mapping = defaults

    # Generate aggregators for each column
    column_to_dtype = {}
    all_aggs = []
    name_to_type = dict(zip(schema.names, schema.types))
    for name in columns:
        pa_type = name_to_type[name]
        if pa_type is None or pa_type is object:
            logger.warning(f"Skipping field '{name}': type is None or unsupported")
            continue
        ray_dtype = DataType.from_arrow(pa_type)
        column_to_dtype[name] = str(ray_dtype)
        all_aggs.extend(_get_aggregators_for_dtype(name, ray_dtype, final_mapping))

    return _DtypeAggregators(
        column_to_dtype=column_to_dtype,
        aggregators=all_aggs,
    )
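
A minimal sketch of supplying a user override (the lambda below is a hypothetical factory, and the column counts assume the default mapping matches the inferred DataTypes): the string column gets only a Count aggregator, while the float column falls through to the numerical defaults.

    import ray
    from ray.data.aggregate import Count
    from ray.data.datatype import DataType
    from ray.data.stats import _dtype_aggregators_for_dataset

    ds = ray.data.from_items([{"name": "a", "value": 1.0}, {"name": "b", "value": 2.0}])
    override = {DataType.string(): lambda col: [Count(on=col, ignore_nulls=False)]}
    plan = _dtype_aggregators_for_dataset(ds.schema(), dtype_agg_mapping=override)
    print(plan.column_to_dtype)   # string representation of each column's DataType
    print(len(plan.aggregators))  # 1 aggregator for "name" + 8 for "value"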

def _format_stats(
    agg: AggregateFnV2, value: Any, agg_type: pa.DataType
) -> Dict[str, Tuple[Any, pa.DataType]]:
    """Format an aggregation result into stat entries.

    Takes the raw aggregation result and formats it into one or more stat entries.
    For scalar results, returns a single entry. For list results, expands into
    multiple indexed entries.

    Args:
        agg: The aggregator instance
        value: The aggregation result value
        agg_type: PyArrow type of the aggregation result

    Returns:
        Dictionary mapping stat names to (value, type) tuples
    """
    from ray.data.datatype import DataType

    agg_name = agg.get_agg_name()

    # Handle list results: expand into separate indexed stats.
    # If the value is None but the type is a list, it means we got a null result
    # for a list-type aggregator (e.g., ignore_nulls=True and all nulls).
    is_list_type = (
        pa.types.is_list(agg_type) or DataType.from_arrow(agg_type).is_list_type()
    )
    if isinstance(value, list) or (value is None and is_list_type):
        scalar_type = (
            agg_type.value_type
            if DataType.from_arrow(agg_type).is_list_type()
            else agg_type
        )
        if value is None:
            # Can't expand None without knowing the size, return as-is
            pass
        else:
            labels = [str(idx) for idx in range(len(value))]
            return {
                f"{agg_name}[{label}]": (list_val, scalar_type)
                for label, list_val in zip(labels, value)
            }

    # Fall back to a scalar result for non-list values or unexpandable Nones
    return {agg_name: (value, agg_type)}


def _parse_summary_stats(
    agg_result: Dict[str, Any],
    original_schema: pa.Schema,
    agg_schema: pa.Schema,
    aggregators: List[AggregateFnV2],
) -> tuple:
    """Parse aggregation results into schema-matching and schema-changing stats.

    Args:
        agg_result: Dictionary of aggregation results with keys like "count(col)"
        original_schema: Original dataset schema
        agg_schema: Schema of aggregation results
        aggregators: List of aggregators used to generate the results

    Returns:
        Tuple of (schema_matching_stats, schema_changing_stats, column_names)
    """
    schema_matching = {}
    schema_changing = {}
    columns = set()

    # Build a lookup map from "stat_name(col_name)" to aggregator
    agg_lookup = {agg.name: agg for agg in aggregators}

    for key, value in agg_result.items():
        if "(" not in key or not key.endswith(")"):
            continue

        # Get the aggregator and extract its info
        agg = agg_lookup.get(key)
        if not agg:
            continue

        col_name = agg.get_target_column()
        if not col_name:
            # Skip aggregations without a target column (e.g., Count())
            continue

        # Format the aggregation results
        agg_type = agg_schema.field(key).type
        original_type = original_schema.field(col_name).type
        formatted_stats = _format_stats(agg, value, agg_type)

        for stat_name, (stat_value, stat_type) in formatted_stats.items():
            # Add the formatted stats to the appropriate dict based on schema matching
            stats_dict = (
                schema_matching if stat_type == original_type else schema_changing
            )
            stats_dict.setdefault(stat_name, {})[col_name] = (stat_value, stat_type)
            columns.add(col_name)

    return schema_matching, schema_changing, columns
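
A minimal sketch of the list expansion in _format_stats above; the top-k value is made up, and the resulting keys follow the "<agg_name>[index]" pattern, where the exact aggregator name comes from get_agg_name().

    import pyarrow as pa
    from ray.data.aggregate import ApproximateTopK
    from ray.data.stats import _format_stats

    agg = ApproximateTopK(on="city", k=3)
    top_k = ["nyc", "sf", "la"]  # illustrative aggregation result
    stats = _format_stats(agg, top_k, pa.list_(pa.string()))
    print(stats)  # {"<agg_name>[0]": ("nyc", string), "<agg_name>[1]": ("sf", string), ...}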

def _create_pyarrow_array(
    col_data: List, col_type: Optional[pa.DataType] = None, col_name: str = ""
) -> pa.Array:
    """Create a PyArrow array with fallback strategies.

    Uses convert_to_pyarrow_array (imported from ray.air.util.tensor_extensions.arrow)
    for type inference and error handling when no specific type is provided.

    Args:
        col_data: List of column values
        col_type: Optional PyArrow type to use
        col_name: Column name for error messages (optional)

    Returns:
        PyArrow array
    """
    if col_type is not None:
        try:
            return pa.array(col_data, type=col_type)
        except (pa.ArrowTypeError, pa.ArrowInvalid):
            # Type mismatch - fall through to type inference
            pass

    # Use convert_to_pyarrow_array for type inference and error handling.
    # This handles tensors, extension types, and fallback to ArrowPythonObjectArray.
    return convert_to_pyarrow_array(col_data, col_name or "column")


def _build_summary_table(
    stats_dict: Dict[str, Dict[str, tuple]],
    all_columns: set,
    original_schema: pa.Schema,
    preserve_types: bool,
) -> pa.Table:
    """Build a PyArrow table from parsed statistics.

    Args:
        stats_dict: Nested dict of {stat_name: {col_name: (value, type)}}
        all_columns: Set of all column names across both tables
        original_schema: Original dataset schema
        preserve_types: If True, use original schema types for columns

    Returns:
        PyArrow table with statistics
    """
    if not stats_dict:
        return pa.table({})

    stat_names = sorted(stats_dict.keys())
    table_data = {DatasetSummary.STATISTIC_COLUMN: stat_names}

    for col_name in sorted(all_columns):
        # Collect values and infer the type
        col_data = []
        first_type = None
        for stat_name in stat_names:
            if col_name in stats_dict[stat_name]:
                value, agg_type = stats_dict[stat_name][col_name]
                col_data.append(value)
                if first_type is None:
                    first_type = agg_type
            else:
                col_data.append(None)

        # Determine the column type: prefer the original schema, then the first
        # aggregation type, then infer
        if preserve_types and col_name in original_schema.names:
            col_type = original_schema.field(col_name).type
        else:
            col_type = first_type

        table_data[col_name] = _create_pyarrow_array(col_data, col_type, col_name)

    return pa.table(table_data)
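
A minimal sketch of the fallback path in _create_pyarrow_array, assuming the helper is importable from ray.data.stats: the requested int64 type does not fit the string data, so the call falls through to convert_to_pyarrow_array and a type is inferred instead.

    import pyarrow as pa
    from ray.data.stats import _create_pyarrow_array

    arr = _create_pyarrow_array(["a", "b", None], col_type=pa.int64(), col_name="mixed")
    print(arr.type)  # an inferred string-like type, not the requested int64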