Source code for ray.data.aggregate
from typing import TYPE_CHECKING, Callable, Optional, Union
from ray.data.block import AggType, Block, BlockAccessor, KeyType, T, U
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
import pyarrow as pa
[docs]
@PublicAPI
class AggregateFn:
    """Defines an aggregate function in the accumulator style.

    Aggregates a collection of inputs of type T into
    a single output value of type U.
    See https://www.sigops.org/s/conferences/sosp/2009/papers/yu-sosp09.pdf
    for more details about accumulator-based aggregation.

    Args:
        init: This is called once for each group to return the empty accumulator.
            For example, an empty accumulator for a sum would be 0.
        merge: This may be called multiple times, each time to merge
            two accumulators into one.
        name: The name of the aggregation. This will be used as the column name
            in the output Dataset.
        accumulate_row: This is called once per row of the same group.
            This combines the accumulator and the row, returns the updated
            accumulator. Exactly one of accumulate_row and accumulate_block must
            be provided.
        accumulate_block: This is used to calculate the aggregation for a
            single block, and is a vectorized alternative to accumulate_row.
            This will be given a base accumulator and the entire block, allowing
            for vectorized accumulation of the block. Exactly one of
            accumulate_row and accumulate_block must be provided.
        finalize: This is called once to compute the final aggregation
            result from the fully merged accumulator. If omitted, the merged
            accumulator itself is returned unchanged.

    Raises:
        ValueError: If neither or both of ``accumulate_row`` and
            ``accumulate_block`` are given.
        TypeError: If ``name`` is not a ``str``.
    """

    def __init__(
        self,
        init: Callable[[KeyType], AggType],
        merge: Callable[[AggType, AggType], AggType],
        name: str,
        accumulate_row: Optional[Callable[[AggType, T], AggType]] = None,
        accumulate_block: Optional[Callable[[AggType, Block], AggType]] = None,
        finalize: Optional[Callable[[AggType], U]] = None,
    ):
        # The two `is None` tests are equal exactly when both callbacks are
        # missing or both are supplied -- the two invalid configurations.
        if (accumulate_row is None) == (accumulate_block is None):
            raise ValueError(
                "Exactly one of accumulate_row or accumulate_block must be provided."
            )

        # Raised for any non-string value (including None), not only when the
        # argument is absent; the message text is kept as-is for compatibility.
        if not isinstance(name, str):
            raise TypeError("`name` must be provided.")

        if accumulate_block is None:
            # Only a per-row callback was supplied: derive the block-level
            # callback by folding that callback over every row of the block.
            def accumulate_block(a: AggType, block: Block) -> AggType:
                block_acc = BlockAccessor.for_block(block)
                for r in block_acc.iter_rows(public_row_format=False):
                    a = accumulate_row(a, r)
                return a

        if finalize is None:
            # Default finalizer is the identity function.
            finalize = lambda a: a  # noqa: E731

        # Only the block-level accumulator is stored; a row-level callback, if
        # given, lives on inside the closure defined above.
        self.init = init
        self.merge = merge
        self.name = name
        self.accumulate_block = accumulate_block
        self.finalize = finalize

    def _validate(self, schema: Optional[Union[type, "pa.lib.Schema"]]) -> None:
        """Raise an error if this cannot be applied to the given schema.

        This implementation performs no checks and always returns ``None``.
        """
        pass