import abc
import math
from typing import TYPE_CHECKING, Any, Callable, List, Optional
import numpy as np
from ray.data._internal.util import is_null
from ray.data.block import AggType, Block, BlockAccessor, KeyType, T, U
from ray.util.annotations import Deprecated, PublicAPI
if TYPE_CHECKING:
from ray.data import Schema
@Deprecated(message="AggregateFn is deprecated, please use AggregateFnV2")
@PublicAPI
class AggregateFn:
"""NOTE: THIS IS DEPRECATED, PLEASE USE :class:`AggregateFnV2` INSTEAD
Defines how to perform a custom aggregation in Ray Data.
`AggregateFn` instances are passed to a Dataset's ``.aggregate(...)`` method to
specify the steps required to transform and combine rows sharing the same key.
This enables implementing custom aggregators beyond the standard
built-in options like Sum, Min, Max, Mean, etc.
Args:
init: Function that creates an initial aggregator for each group. Receives a key
(the group key) and returns the initial accumulator state (commonly 0,
an empty list, or an empty dictionary).
merge: Function that merges two accumulators generated by different workers
into one accumulator.
name: An optional display name for the aggregator. Useful for debugging.
accumulate_row: Function that processes an individual row. It receives the current
accumulator and a row, then returns an updated accumulator. Cannot be
used if `accumulate_block` is provided.
accumulate_block: Function that processes an entire block of rows at once. It receives the
current accumulator and a block of rows, then returns an updated accumulator.
This allows for vectorized operations. Cannot be used if `accumulate_row`
is provided.
finalize: Function that finishes the aggregation by transforming the final
accumulator state into the desired output. For example, if your
accumulator is a list of items, you may want to compute a statistic
from the list. If not provided, the final accumulator state is returned
as-is.
Example:
.. testcode::
import ray
from ray.data.aggregate import AggregateFn
# A simple aggregator that counts how many rows there are per group
count_agg = AggregateFn(
init=lambda k: 0,
accumulate_row=lambda counter, row: counter + 1,
merge=lambda c1, c2: c1 + c2,
name="custom_count"
)
ds = ray.data.from_items([{"group": "A"}, {"group": "B"}, {"group": "A"}])
result = ds.groupby("group").aggregate(count_agg).take_all()
# result: [{'group': 'A', 'custom_count': 2}, {'group': 'B', 'custom_count': 1}]
"""
def __init__(
self,
init: Callable[[KeyType], AggType],
merge: Callable[[AggType, AggType], AggType],
name: str,
accumulate_row: Callable[[AggType, T], AggType] = None,
accumulate_block: Callable[[AggType, Block], AggType] = None,
finalize: Optional[Callable[[AggType], U]] = None,
):
if (accumulate_row is None and accumulate_block is None) or (
accumulate_row is not None and accumulate_block is not None
):
raise ValueError(
"Exactly one of accumulate_row or accumulate_block must be provided."
)
if accumulate_block is None:
def accumulate_block(a: AggType, block: Block) -> AggType:
block_acc = BlockAccessor.for_block(block)
for r in block_acc.iter_rows(public_row_format=False):
a = accumulate_row(a, r)
return a
        if not isinstance(name, str):
            raise TypeError(f"`name` must be a string (got {type(name).__name__}).")
if finalize is None:
finalize = lambda a: a # noqa: E731
self.name = name
self.init = init
self.merge = merge
self.accumulate_block = accumulate_block
self.finalize = finalize
def _validate(self, schema: Optional["Schema"]) -> None:
"""Raise an error if this cannot be applied to the given schema."""
pass
@PublicAPI(stability="alpha")
class AggregateFnV2(AggregateFn, abc.ABC):
"""Provides an interface to implement efficient aggregations to be applied
to the dataset.
`AggregateFnV2` instances are passed to a Dataset's ``.aggregate(...)`` method to
perform distributed aggregations. To create a custom aggregation, you should subclass
`AggregateFnV2` and implement the `aggregate_block` and `combine` methods.
    The `finalize` method can also be overridden if the final accumulated state
    needs further transformation.
Aggregation follows these steps:
1. **Initialization**: For each group (if grouping) or for the entire dataset,
an initial accumulator is created using `zero_factory`.
2. **Block Aggregation**: The `aggregate_block` method is applied to
each block independently, producing a partial aggregation result for that block.
3. **Combination**: The `combine` method is used to merge these partial
results (or an existing accumulated result with a new partial result)
into a single, combined accumulator.
    4. **Finalization**: Optionally, the `finalize` method transforms the
       final combined accumulator into the desired output format.
Args:
name: The name of the aggregation. This will be used as the column name
in the output, e.g., "sum(my_col)".
zero_factory: A callable that returns the initial "zero" value for the
accumulator. For example, for a sum, this would be `lambda: 0`; for
            finding a minimum, `lambda: float("inf")`; and for finding a maximum,
            `lambda: float("-inf")`.
on: The name of the column to perform the aggregation on. If `None`,
the aggregation is performed over the entire row (e.g., for `Count()`).
ignore_nulls: Whether to ignore null values during aggregation.
If `True`, nulls are skipped.
If `False`, the presence of a null value might result in a null output,
depending on the aggregation logic.
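
    Example:

        A minimal sketch of a custom aggregation: a product over a numeric
        column. The ``Product`` class below is hypothetical, for illustration
        only, and assumes non-null numeric input.

        .. testcode::

            import ray
            from ray.data.aggregate import AggregateFnV2
            from ray.data.block import Block, BlockAccessor

            class Product(AggregateFnV2):
                def __init__(self, on: str):
                    super().__init__(
                        f"product({on})",
                        zero_factory=lambda: 1,
                        on=on,
                        ignore_nulls=True,
                    )

                def aggregate_block(self, block: Block):
                    # Multiply together all values of the target column in
                    # this block, producing a partial result.
                    acc = 1
                    block_acc = BlockAccessor.for_block(block)
                    for row in block_acc.iter_rows(public_row_format=False):
                        acc *= row.get(self._target_col_name)
                    return acc

                def combine(self, current_accumulator, new):
                    # Merge two partial products.
                    return current_accumulator * new

            ds = ray.data.from_items([{"v": 2}, {"v": 3}, {"v": 4}])
            result = ds.aggregate(Product(on="v"))
            # result: {'product(v)': 24}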
"""
def __init__(
self,
name: str,
zero_factory: Callable[[], AggType],
*,
on: Optional[str],
ignore_nulls: bool,
):
if not name:
raise ValueError(
f"Non-empty string has to be provided as name (got {name})"
)
self._target_col_name = on
self._ignore_nulls = ignore_nulls
_safe_combine = _null_safe_combine(self.combine, ignore_nulls)
_safe_aggregate = _null_safe_aggregate(self.aggregate_block, ignore_nulls)
_safe_finalize = _null_safe_finalize(self.finalize)
_safe_zero_factory = _null_safe_zero_factory(zero_factory, ignore_nulls)
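        # Wire the null-safe wrappers into the legacy ``AggregateFn`` protocol.
        # NOTE: ``accumulate_block`` deliberately ignores the incoming
        # accumulator: each block is aggregated independently by
        # ``aggregate_block``, and the partial results are merged via
        # ``combine``.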
super().__init__(
name=name,
init=_safe_zero_factory,
merge=_safe_combine,
accumulate_block=lambda _, block: _safe_aggregate(block),
finalize=_safe_finalize,
)
def get_target_column(self) -> Optional[str]:
return self._target_col_name
@abc.abstractmethod
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
"""Combines a new partial aggregation result with the current accumulator.
This method defines how two intermediate aggregation states are merged.
For example, if `aggregate_block` produces partial sums `s1` and `s2` from
two different blocks, `combine(s1, s2)` should return `s1 + s2`.
Args:
current_accumulator: The current accumulated state (e.g., the result of
previous `combine` calls or an initial value from `zero_factory`).
new: A new partially aggregated value, typically the output of
`aggregate_block` from a new block of data, or another accumulator
from a parallel task.
Returns:
The updated accumulator after combining it with the new value.
"""
...
@abc.abstractmethod
def aggregate_block(self, block: Block) -> AggType:
"""Aggregates data within a single block.
This method processes all rows in a given `Block` and returns a partial
aggregation result for that block. For instance, if implementing a sum,
this method would sum all relevant values within the block.
Args:
block: A `Block` of data to be aggregated.
Returns:
A partial aggregation result for the input block. The type of this
result (`AggType`) should be consistent with the `current_accumulator`
and `new` arguments of the `combine` method, and the `accumulator`
            argument of the `finalize` method.
"""
...
def finalize(self, accumulator: AggType) -> Optional[U]:
"""Transforms the final accumulated state into the desired output.
This method is called once per group after all blocks have been processed
and all partial results have been combined. It provides an opportunity
to perform a final transformation on the accumulated data.
For many aggregations (e.g., Sum, Count, Min, Max), the accumulated state
is already the final result, so this method can simply return the
accumulator as is (which is the default behavior).
For other aggregations, like Mean, this method is crucial.
        A Mean aggregation might accumulate `[sum, count]`. The `finalize`
        method would then compute `sum / count` to get the final mean.
Args:
accumulator: The final accumulated state for a group, after all
`aggregate_block` and `combine` operations.
Returns:
The final result of the aggregation for the group.
"""
return accumulator
def _validate(self, schema: Optional["Schema"]) -> None:
if self._target_col_name:
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
SortKey(self._target_col_name).validate_schema(schema)
@PublicAPI
class Count(AggregateFnV2):
"""Defines count aggregation.
Example:
.. testcode::
import ray
from ray.data.aggregate import Count
ds = ray.data.range(100)
# Schema: {'id': int64}
ds = ds.add_column("group_key", lambda x: x % 3)
# Schema: {'id': int64, 'group_key': int64}
# Counting all rows:
result = ds.aggregate(Count())
# result: {'count()': 100}
# Counting all rows per group:
result = ds.groupby("group_key").aggregate(Count(on="id")).take_all()
# result: [{'group_key': 0, 'count(id)': 34},
# {'group_key': 1, 'count(id)': 33},
# {'group_key': 2, 'count(id)': 33}]
Args:
on: Optional name of the column to count values on. If None, counts rows.
        ignore_nulls: Whether to ignore null values when counting. Only applies if
            `on` is specified. Defaults to `False`, which means `Count()` on a
            column counts null values. To match the pandas default behavior of
            not counting nulls, set `ignore_nulls=True`.
alias_name: Optional name for the resulting column.
"""
def __init__(
self,
on: Optional[str] = None,
ignore_nulls: bool = False,
alias_name: Optional[str] = None,
):
super().__init__(
alias_name if alias_name else f"count({on or ''})",
on=on,
ignore_nulls=ignore_nulls,
zero_factory=lambda: 0,
)
def aggregate_block(self, block: Block) -> AggType:
block_accessor = BlockAccessor.for_block(block)
if self._target_col_name is None:
# In case of global count, simply fetch number of rows
return block_accessor.num_rows()
return block_accessor.count(
self._target_col_name, ignore_nulls=self._ignore_nulls
)
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
return current_accumulator + new
@PublicAPI
class Sum(AggregateFnV2):
"""Defines sum aggregation.
Example:
.. testcode::
import ray
from ray.data.aggregate import Sum
ds = ray.data.range(100)
# Schema: {'id': int64}
ds = ds.add_column("group_key", lambda x: x % 3)
# Schema: {'id': int64, 'group_key': int64}
            # Summing all values in the "id" column:
result = ds.aggregate(Sum(on="id"))
# result: {'sum(id)': 4950}
Args:
on: The name of the numerical column to sum. Must be provided.
ignore_nulls: Whether to ignore null values during summation. If `True` (default),
nulls are skipped. If `False`, the sum will be null if any
value in the group is null.
alias_name: Optional name for the resulting column.
"""
def __init__(
self,
on: Optional[str] = None,
ignore_nulls: bool = True,
alias_name: Optional[str] = None,
):
super().__init__(
alias_name if alias_name else f"sum({str(on)})",
on=on,
ignore_nulls=ignore_nulls,
zero_factory=lambda: 0,
)
def aggregate_block(self, block: Block) -> AggType:
return BlockAccessor.for_block(block).sum(
self._target_col_name, self._ignore_nulls
)
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
return current_accumulator + new
@PublicAPI
class Min(AggregateFnV2):
"""Defines min aggregation.
Example:
.. testcode::
import ray
from ray.data.aggregate import Min
ds = ray.data.range(100)
# Schema: {'id': int64}
ds = ds.add_column("group_key", lambda x: x % 3)
# Schema: {'id': int64, 'group_key': int64}
# Finding the minimum value per group:
result = ds.groupby("group_key").aggregate(Min(on="id")).take_all()
# result: [{'group_key': 0, 'min(id)': 0},
# {'group_key': 1, 'min(id)': 1},
# {'group_key': 2, 'min(id)': 2}]
Args:
on: The name of the column to find the minimum value from. Must be provided.
ignore_nulls: Whether to ignore null values. If `True` (default), nulls are
skipped. If `False`, the minimum will be null if any value in
            the group is null (for most data types; some types follow their own
            null-comparison rules).
alias_name: Optional name for the resulting column.
"""
def __init__(
self,
on: Optional[str] = None,
ignore_nulls: bool = True,
alias_name: Optional[str] = None,
):
super().__init__(
alias_name if alias_name else f"min({str(on)})",
on=on,
ignore_nulls=ignore_nulls,
zero_factory=lambda: float("+inf"),
)
def aggregate_block(self, block: Block) -> AggType:
return BlockAccessor.for_block(block).min(
self._target_col_name, self._ignore_nulls
)
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
return min(current_accumulator, new)
@PublicAPI
class Max(AggregateFnV2):
"""Defines max aggregation.
Example:
.. testcode::
import ray
from ray.data.aggregate import Max
ds = ray.data.range(100)
# Schema: {'id': int64}
ds = ds.add_column("group_key", lambda x: x % 3)
# Schema: {'id': int64, 'group_key': int64}
# Finding the maximum value per group:
result = ds.groupby("group_key").aggregate(Max(on="id")).take_all()
# result: [{'group_key': 0, 'max(id)': ...},
# {'group_key': 1, 'max(id)': ...},
# {'group_key': 2, 'max(id)': ...}]
Args:
on: The name of the column to find the maximum value from. Must be provided.
ignore_nulls: Whether to ignore null values. If `True` (default), nulls are
skipped. If `False`, the maximum will be null if any value in
            the group is null (for most data types; some types follow their own
            null-comparison rules).
alias_name: Optional name for the resulting column.
"""
def __init__(
self,
on: Optional[str] = None,
ignore_nulls: bool = True,
alias_name: Optional[str] = None,
):
super().__init__(
alias_name if alias_name else f"max({str(on)})",
on=on,
ignore_nulls=ignore_nulls,
zero_factory=lambda: float("-inf"),
)
def aggregate_block(self, block: Block) -> AggType:
return BlockAccessor.for_block(block).max(
self._target_col_name, self._ignore_nulls
)
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
return max(current_accumulator, new)
@PublicAPI
class Mean(AggregateFnV2):
"""Defines mean (average) aggregation.
Example:
.. testcode::
import ray
from ray.data.aggregate import Mean
ds = ray.data.range(100)
# Schema: {'id': int64}
ds = ds.add_column("group_key", lambda x: x % 3)
# Schema: {'id': int64, 'group_key': int64}
# Calculating the mean value per group:
result = ds.groupby("group_key").aggregate(Mean(on="id")).take_all()
# result: [{'group_key': 0, 'mean(id)': ...},
# {'group_key': 1, 'mean(id)': ...},
# {'group_key': 2, 'mean(id)': ...}]
Args:
on: The name of the numerical column to calculate the mean on. Must be provided.
ignore_nulls: Whether to ignore null values. If `True` (default), nulls are
skipped. If `False`, the mean will be null if any value in the
group is null.
alias_name: Optional name for the resulting column.
"""
def __init__(
self,
on: Optional[str] = None,
ignore_nulls: bool = True,
alias_name: Optional[str] = None,
):
super().__init__(
alias_name if alias_name else f"mean({str(on)})",
on=on,
ignore_nulls=ignore_nulls,
            # The accumulator is: [current_sum, current_count].
            # NOTE: The factory returns a fresh list on every call, since some
            # internal mechanisms may modify accumulators in-place.
            zero_factory=lambda: [0, 0],
)
def aggregate_block(self, block: Block) -> AggType:
block_acc = BlockAccessor.for_block(block)
count = block_acc.count(self._target_col_name, self._ignore_nulls)
if count == 0 or count is None:
# Empty or all null.
return None
sum_ = block_acc.sum(self._target_col_name, self._ignore_nulls)
if is_null(sum_):
# In case of ignore_nulls=False and column containing 'null'
# return as is (to prevent unnecessary type conversions, when, for ex,
# using Pandas and returning None)
return sum_
return [sum_, count]
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
return [current_accumulator[0] + new[0], current_accumulator[1] + new[1]]
def finalize(self, accumulator: AggType) -> Optional[U]:
# The final accumulator for a group is [total_sum, total_count].
if accumulator[1] == 0:
# If total_count is 0 (e.g., group was empty or all nulls ignored),
# the mean is undefined. Return NaN
return np.nan
return accumulator[0] / accumulator[1]
@PublicAPI
class Std(AggregateFnV2):
"""Defines standard deviation aggregation.
    Uses Welford's online algorithm for numerical stability, computing the
    standard deviation in a single pass. Results may differ slightly from
    libraries like NumPy or Pandas that use a two-pass algorithm, but are
    generally more accurate.
See: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
Example:
.. testcode::
import ray
from ray.data.aggregate import Std
ds = ray.data.range(100)
# Schema: {'id': int64}
ds = ds.add_column("group_key", lambda x: x % 3)
# Schema: {'id': int64, 'group_key': int64}
# Calculating the standard deviation per group:
result = ds.groupby("group_key").aggregate(Std(on="id")).take_all()
# result: [{'group_key': 0, 'std(id)': ...},
# {'group_key': 1, 'std(id)': ...},
# {'group_key': 2, 'std(id)': ...}]
Args:
on: The name of the column to calculate standard deviation on.
ddof: Delta Degrees of Freedom. The divisor used in calculations is `N - ddof`,
where `N` is the number of elements. Default is 1.
ignore_nulls: Whether to ignore null values. Default is True.
alias_name: Optional name for the resulting column.
"""
def __init__(
self,
on: Optional[str] = None,
ddof: int = 1,
ignore_nulls: bool = True,
alias_name: Optional[str] = None,
):
super().__init__(
alias_name if alias_name else f"std({str(on)})",
on=on,
ignore_nulls=ignore_nulls,
            # Accumulator: [M2, mean, count]
            #   M2: the sum of squared differences from the current mean
            #   mean: the current mean
            #   count: the current count of non-null elements
            # NOTE: The factory returns a fresh list on every call, since some
            # internal mechanisms may modify accumulators in-place.
            zero_factory=lambda: [0, 0, 0],
)
self._ddof = ddof
def aggregate_block(self, block: Block) -> AggType:
block_acc = BlockAccessor.for_block(block)
count = block_acc.count(self._target_col_name, ignore_nulls=self._ignore_nulls)
if count == 0 or count is None:
# Empty or all null.
return None
sum_ = block_acc.sum(self._target_col_name, self._ignore_nulls)
if is_null(sum_):
# If sum is null (e.g., ignore_nulls=False and a null was encountered),
# return as is to prevent type conversions.
return sum_
mean = sum_ / count
M2 = block_acc.sum_of_squared_diffs_from_mean(
self._target_col_name, self._ignore_nulls, mean
)
return [M2, mean, count]
def combine(self, current_accumulator: List[float], new: List[float]) -> AggType:
# Merges two accumulators [M2, mean, count] using a parallel algorithm.
# See: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
M2_a, mean_a, count_a = current_accumulator
M2_b, mean_b, count_b = new
delta = mean_b - mean_a
count = count_a + count_b
# NOTE: We use this mean calculation since it's more numerically
# stable than mean_a + delta * count_b / count, which actually
# deviates from Pandas in the ~15th decimal place and causes our
# exact comparison tests to fail.
mean = (mean_a * count_a + mean_b * count_b) / count
# Update the sum of squared differences.
M2 = M2_a + M2_b + (delta**2) * count_a * count_b / count
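        # Worked check of the merge above (a sketch): combining
        # [M2=2.0, mean=2.0, count=3] (values 1, 2, 3) with
        # [M2=2.0, mean=5.0, count=3] (values 4, 5, 6) yields count=6,
        # mean = (2*3 + 5*3) / 6 = 3.5, and
        # M2 = 2 + 2 + 3**2 * 3 * 3 / 6 = 17.5, which matches
        # sum((x - 3.5)**2 for x in range(1, 7)) computed directly.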
return [M2, mean, count]
def finalize(self, accumulator: List[float]) -> Optional[U]:
# Compute the final standard deviation from the accumulated
# sum of squared differences from current mean and the count.
# Final accumulator: [M2, mean, count]
M2, mean, count = accumulator
# Denominator for variance calculation is count - ddof
if count - self._ddof <= 0:
# If count - ddof is not positive, variance/std is undefined (or zero).
# Return NaN, consistent with pandas/numpy.
return np.nan
# Standard deviation is the square root of variance (M2 / (count - ddof))
return math.sqrt(M2 / (count - self._ddof))
@PublicAPI
class AbsMax(AggregateFnV2):
"""Defines absolute max aggregation.
Example:
.. testcode::
import ray
from ray.data.aggregate import AbsMax
ds = ray.data.range(100)
# Schema: {'id': int64}
ds = ds.add_column("group_key", lambda x: x % 3)
# Schema: {'id': int64, 'group_key': int64}
# Calculating the absolute maximum value per group:
result = ds.groupby("group_key").aggregate(AbsMax(on="id")).take_all()
# result: [{'group_key': 0, 'abs_max(id)': ...},
# {'group_key': 1, 'abs_max(id)': ...},
# {'group_key': 2, 'abs_max(id)': ...}]
Args:
on: The name of the column to calculate absolute maximum on. Must be provided.
ignore_nulls: Whether to ignore null values. Default is True.
alias_name: Optional name for the resulting column.
"""
def __init__(
self,
on: Optional[str] = None,
ignore_nulls: bool = True,
alias_name: Optional[str] = None,
):
if on is None or not isinstance(on, str):
raise ValueError(f"Column to aggregate on has to be provided (got {on})")
super().__init__(
alias_name if alias_name else f"abs_max({str(on)})",
on=on,
ignore_nulls=ignore_nulls,
zero_factory=lambda: 0,
)
def aggregate_block(self, block: Block) -> AggType:
block_accessor = BlockAccessor.for_block(block)
max_ = block_accessor.max(self._target_col_name, self._ignore_nulls)
min_ = block_accessor.min(self._target_col_name, self._ignore_nulls)
if is_null(max_) or is_null(min_):
return None
return max(
abs(max_),
abs(min_),
)
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
return max(current_accumulator, new)
@PublicAPI
class Quantile(AggregateFnV2):
"""Defines Quantile aggregation.
Example:
.. testcode::
import ray
from ray.data.aggregate import Quantile
ds = ray.data.range(100)
# Schema: {'id': int64}
ds = ds.add_column("group_key", lambda x: x % 3)
# Schema: {'id': int64, 'group_key': int64}
# Calculating the 50th percentile (median) per group:
result = ds.groupby("group_key").aggregate(Quantile(q=0.5, on="id")).take_all()
# result: [{'group_key': 0, 'quantile(id)': ...},
# {'group_key': 1, 'quantile(id)': ...},
# {'group_key': 2, 'quantile(id)': ...}]
Args:
on: The name of the column to calculate the quantile on. Must be provided.
q: The quantile to compute, which must be between 0 and 1 inclusive.
For example, q=0.5 computes the median.
ignore_nulls: Whether to ignore null values. Default is True.
alias_name: Optional name for the resulting column.
"""
def __init__(
self,
on: Optional[str] = None,
q: float = 0.5,
ignore_nulls: bool = True,
alias_name: Optional[str] = None,
):
self._q = q
super().__init__(
alias_name if alias_name else f"quantile({str(on)})",
on=on,
ignore_nulls=ignore_nulls,
zero_factory=list,
)
    def combine(self, current_accumulator: List[Any], new: List[Any]) -> List[Any]:
        if isinstance(current_accumulator, list) and isinstance(new, list):
            current_accumulator.extend(new)
            return current_accumulator
        if isinstance(current_accumulator, list) and not isinstance(new, list):
            if new is not None and new != "":
                current_accumulator.append(new)
            return current_accumulator
        if isinstance(new, list) and not isinstance(current_accumulator, list):
            if current_accumulator is not None and current_accumulator != "":
                new.append(current_accumulator)
            return new
        ls = []
        if current_accumulator is not None and current_accumulator != "":
            ls.append(current_accumulator)
        if new is not None and new != "":
            ls.append(new)
        return ls
def aggregate_block(self, block: Block) -> AggType:
block_acc = BlockAccessor.for_block(block)
ls = []
for row in block_acc.iter_rows(public_row_format=False):
ls.append(row.get(self._target_col_name))
return ls
def finalize(self, accumulator: List[Any]) -> Optional[U]:
if self._ignore_nulls:
accumulator = [v for v in accumulator if not is_null(v)]
else:
nulls = [v for v in accumulator if is_null(v)]
if len(nulls) > 0:
# If nulls are present and not ignored, the quantile is undefined.
# Return the first null encountered to preserve column type.
return nulls[0]
if not accumulator:
# If the list is empty (e.g., all values were null and ignored, or no values),
# quantile is undefined.
return None
        input_values = sorted(accumulator)
        k = (len(input_values) - 1) * self._q
        f = math.floor(k)
        c = math.ceil(k)
        if f == c:
            return input_values[int(k)]
        # Interpolate linearly between the elements at the floor and ceil indices.
        d0 = input_values[int(f)] * (c - k)
        d1 = input_values[int(c)] * (k - f)
        return round(d0 + d1, 5)
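        # Worked example of the interpolation (a sketch): for values
        # [1, 2, 3, 4] and q=0.5, k = 3 * 0.5 = 1.5, f = 1, c = 2, so the
        # result is 2 * (2 - 1.5) + 3 * (1.5 - 1) = 2.5.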
@PublicAPI
class Unique(AggregateFnV2):
"""Defines unique aggregation.
Example:
.. testcode::
import ray
from ray.data.aggregate import Unique
ds = ray.data.range(100)
ds = ds.add_column("group_key", lambda x: x % 3)
# Calculating the unique values per group:
result = ds.groupby("group_key").aggregate(Unique(on="id")).take_all()
# result: [{'group_key': 0, 'unique(id)': ...},
# {'group_key': 1, 'unique(id)': ...},
# {'group_key': 2, 'unique(id)': ...}]
Args:
on: The name of the column from which to collect unique values.
ignore_nulls: Whether to ignore null values when collecting unique items.
Default is True (nulls are excluded).
alias_name: Optional name for the resulting column.
"""
def __init__(
self,
on: Optional[str] = None,
ignore_nulls: bool = True,
alias_name: Optional[str] = None,
):
super().__init__(
alias_name if alias_name else f"unique({str(on)})",
on=on,
ignore_nulls=ignore_nulls,
zero_factory=set,
)
def combine(self, current_accumulator: AggType, new: AggType) -> AggType:
return self._to_set(current_accumulator) | self._to_set(new)
def aggregate_block(self, block: Block) -> AggType:
import pyarrow.compute as pac
col = BlockAccessor.for_block(block).to_arrow().column(self._target_col_name)
return pac.unique(col).to_pylist()
@staticmethod
def _to_set(x):
if isinstance(x, set):
return x
elif isinstance(x, list):
return set(x)
else:
return {x}
def _null_safe_zero_factory(zero_factory, ignore_nulls: bool):
"""NOTE: PLEASE READ CAREFULLY BEFORE CHANGING
Null-safe zero factory is crucial for implementing proper aggregation
protocol (monoid) w/o the need for additional containers.
Main hurdle for implementing proper aggregation semantic is to be able to encode
semantic of an "empty accumulator" and be able to tell it from the case when
accumulator is actually holding null value:
- Empty container can be overridden with any value
- Container holding null can't be overridden if ignore_nulls=False
However, it's possible for us to exploit asymmetry in cases of ignore_nulls being
True or False:
- Case of ignore_nulls=False entails that if there's any "null" in the sequence,
aggregation is undefined and correspondingly expected to return null
- Case of ignore_nulls=True in turn, entails that if aggregation returns "null"
if and only if the sequence does NOT have any non-null value
Therefore, we apply this difference in semantic to zero-factory to make sure that
our aggregation protocol is adherent to that definition:
- If ignore_nulls=True, zero-factory returns null, therefore encoding empty
container
- If ignore_nulls=False, couldn't return null as aggregation will incorrectly
prioritize it, and instead it returns true zero value for the aggregation
(ie 0 for count/sum, -inf for max, etc).
"""
if ignore_nulls:
def _safe_zero_factory(_):
return None
else:
def _safe_zero_factory(_):
return zero_factory()
return _safe_zero_factory
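# Illustration of the asymmetry above (a sketch; the single argument, the
# group key, is ignored by both factories):
#
#   _null_safe_zero_factory(lambda: 0, ignore_nulls=True)(None)   # -> None
#   _null_safe_zero_factory(lambda: 0, ignore_nulls=False)(None)  # -> 0
#
# The None "empty" zero is then treated as an identity element by
# `_null_safe_combine` below.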
def _null_safe_aggregate(
aggregate: Callable[[Block], AggType],
ignore_nulls: bool,
) -> Callable[[Block], Optional[AggType]]:
def _safe_aggregate(block: Block) -> Optional[AggType]:
result = aggregate(block)
        # NOTE: If `ignore_nulls=True`, the aggregation only returns null
        # when the block does NOT contain any non-null elements, so a null
        # result can safely be normalized to the "empty" accumulator (None).
if is_null(result) and ignore_nulls:
return None
return result
return _safe_aggregate
def _null_safe_finalize(
finalize: Callable[[AggType], AggType]
) -> Callable[[Optional[AggType]], AggType]:
def _safe_finalize(acc: Optional[AggType]) -> AggType:
# If accumulator container is not null, finalize.
# Otherwise, return as is.
return acc if is_null(acc) else finalize(acc)
return _safe_finalize
def _null_safe_combine(
combine: Callable[[AggType, AggType], AggType], ignore_nulls: bool
) -> Callable[[Optional[AggType], Optional[AggType]], Optional[AggType]]:
"""Null-safe combination have to be an associative operation
with an identity element (zero) or in other words implement a monoid.
To achieve that in the presence of null values following semantic is
established:
- Case of ignore_nulls=True:
- If current accumulator is null (ie empty), return new accumulator
- If new accumulator is null (ie empty), return cur
- Otherwise combine (current and new)
- Case of ignore_nulls=False:
- If new accumulator is null (ie has null in the sequence, b/c we're
NOT ignoring nulls), return it
- If current accumulator is null (ie had null in the prior sequence,
b/c we're NOT ignoring nulls), return it
- Otherwise combine (current and new)
"""
if ignore_nulls:
def _safe_combine(
cur: Optional[AggType], new: Optional[AggType]
) -> Optional[AggType]:
if is_null(cur):
return new
elif is_null(new):
return cur
else:
return combine(cur, new)
else:
def _safe_combine(
cur: Optional[AggType], new: Optional[AggType]
) -> Optional[AggType]:
if is_null(new):
return new
elif is_null(cur):
return cur
else:
return combine(cur, new)
return _safe_combine
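# Combine semantics at a glance, written as (cur, new) -> result:
#
#   ignore_nulls=True:   (None, 5) -> 5,    (5, None) -> 5,    (5, 7) -> combine(5, 7)
#   ignore_nulls=False:  (None, 5) -> None, (5, None) -> None, (5, 7) -> combine(5, 7)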