Source code for ray.serve.batching

import asyncio
import time
from dataclasses import dataclass
from functools import wraps
from inspect import isasyncgenfunction, iscoroutinefunction
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    overload,
)

from ray._private.signature import extract_signature, flatten_args, recover_args
from ray._private.utils import get_or_create_event_loop
from ray.serve._private.utils import extract_self_if_method_call
from ray.serve.exceptions import RayServeException
from ray.util.annotations import PublicAPI


@dataclass
class _SingleRequest:
    self_arg: Any
    flattened_args: List[Any]
    future: asyncio.Future


@dataclass
class _GeneratorResult:
    result: Any
    next_future: asyncio.Future


def _batch_args_kwargs(
    list_of_flattened_args: List[List[Any]],
) -> Tuple[Tuple[Any], Dict[Any, Any]]:
    """Batch a list of flatten args and returns regular args and kwargs"""
    # Ray's flatten arg format is a list with alternating key and values
    # e.g. args=(1, 2), kwargs={"key": "val"} got turned into
    #      [None, 1, None, 2, "key", "val"]
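    # For example, two queued calls f(1, key="a") and f(2, key="b") arrive
    # flattened as [None, 1, "key", "a"] and [None, 2, "key", "b"]; the keys
    # are taken from the first request and the values are collected into
    # lists, giving [None, [1, 2], "key", ["a", "b"]], which recover_args
    # turns into args=([1, 2],) and kwargs={"key": ["a", "b"]}.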
    arg_lengths = {len(args) for args in list_of_flattened_args}
    assert (
        len(arg_lengths) == 1
    ), "All batch requests should have the same number of parameters."
    arg_length = arg_lengths.pop()

    batched_flattened_args = []
    for idx in range(arg_length):
        if idx % 2 == 0:
            batched_flattened_args.append(list_of_flattened_args[0][idx])
        else:
            batched_flattened_args.append(
                [item[idx] for item in list_of_flattened_args]
            )

    return recover_args(batched_flattened_args)


class _BatchQueue:
    def __init__(
        self,
        max_batch_size: int,
        batch_wait_timeout_s: float,
        handle_batch_func: Optional[Callable] = None,
    ) -> None:
        """Async queue that accepts individual items and returns batches.

        Respects max_batch_size and batch_wait_timeout_s; a batch is returned
        when max_batch_size elements are available or batch_wait_timeout_s has
        passed since the previous get.

        If handle_batch_func is passed in, a background coroutine will run to
        poll from the queue and call handle_batch_func on the results.

        Cannot be pickled.

        Arguments:
            max_batch_size: max number of elements to return in a batch.
            batch_wait_timeout_s: time to wait before returning an incomplete
                batch.
            handle_batch_func: callback to run in the background to handle
                batches if provided.
        """
        self.queue: asyncio.Queue[_SingleRequest] = asyncio.Queue()
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s
        self.queue_put_event = asyncio.Event()

        self._handle_batch_task = None
        if handle_batch_func is not None:
            self._handle_batch_task = get_or_create_event_loop().create_task(
                self._process_batches(handle_batch_func)
            )

    def put(self, request: _SingleRequest) -> None:
        self.queue.put_nowait(request)
        self.queue_put_event.set()

    async def wait_for_batch(self) -> List[Any]:
        """Wait for batch respecting self.max_batch_size and self.timeout_s.

        Returns a batch of up to self.max_batch_size items. Waits for up to
        to self.timeout_s after receiving the first request that will be in
        the next batch. After the timeout, returns as many items as are ready.

        Always returns a batch with at least one item - will block
        indefinitely until an item comes in.
        """

        batch = []
        batch.append(await self.queue.get())

        # Cache current max_batch_size and batch_wait_timeout_s for this batch.
        max_batch_size = self.max_batch_size
        batch_wait_timeout_s = self.batch_wait_timeout_s

        # Wait up to batch_wait_timeout_s seconds for new queue arrivals.
        batch_start_time = time.time()
        while True:
            remaining_batch_time_s = max(
                batch_wait_timeout_s - (time.time() - batch_start_time), 0
            )
            try:
                # Wait for new arrivals.
                await asyncio.wait_for(
                    self.queue_put_event.wait(), remaining_batch_time_s
                )
            except asyncio.TimeoutError:
                pass

            # Add all new arrivals to the batch.
            while len(batch) < max_batch_size and not self.queue.empty():
                batch.append(self.queue.get_nowait())
            self.queue_put_event.clear()

            if (
                time.time() - batch_start_time >= batch_wait_timeout_s
                or len(batch) >= max_batch_size
            ):
                break

        return batch

    def _validate_results(
        self, results: Iterable[Any], input_batch_length: int
    ) -> None:
        if len(results) != input_batch_length:
            raise RayServeException(
                "Batched function doesn't preserve batch size. "
                f"The input list has length {input_batch_length} but the "
                f"returned list has length {len(results)}."
            )

    async def _consume_func_generator(
        self,
        func_generator: AsyncGenerator,
        initial_futures: List[asyncio.Future],
        input_batch_length: int,
    ) -> None:
        """Consumes batch function generator.

        This function only runs if the function decorated with @serve.batch
        is a generator.
        """

        FINISHED_TOKEN = None

        try:
            futures = initial_futures
            async for results in func_generator:
                self._validate_results(results, input_batch_length)
                next_futures = []
                for result, future in zip(results, futures):
                    if future is FINISHED_TOKEN:
                        # This caller has already terminated.
                        next_futures.append(FINISHED_TOKEN)
                    elif result in [StopIteration, StopAsyncIteration]:
                        # User's code returned sentinel. No values left
                        # for caller. Terminate iteration for caller.
                        future.set_exception(StopAsyncIteration)
                        next_futures.append(FINISHED_TOKEN)
                    else:
                        next_future = get_or_create_event_loop().create_future()
                        future.set_result(_GeneratorResult(result, next_future))
                        next_futures.append(next_future)
                futures = next_futures

            for future in futures:
                if future is not FINISHED_TOKEN:
                    future.set_exception(StopAsyncIteration)
        except Exception as e:
            for future in futures:
                if future is not FINISHED_TOKEN:
                    future.set_exception(e)

    async def _process_batches(self, func: Callable) -> None:
        """Loops infinitely and processes queued request batches."""

        while True:
            batch: List[_SingleRequest] = await self.wait_for_batch()
            assert len(batch) > 0
            self_arg = batch[0].self_arg
            args, kwargs = _batch_args_kwargs([item.flattened_args for item in batch])
            futures = [item.future for item in batch]

            # Method call.
            if self_arg is not None:
                func_future_or_generator = func(self_arg, *args, **kwargs)
            # Normal function call.
            else:
                func_future_or_generator = func(*args, **kwargs)

            if isasyncgenfunction(func):
                func_generator = func_future_or_generator
                await self._consume_func_generator(func_generator, futures, len(batch))
            else:
                try:
                    func_future = func_future_or_generator
                    results = await func_future
                    self._validate_results(results, len(batch))
                    for result, future in zip(results, futures):
                        future.set_result(result)
                except Exception as e:
                    for future in futures:
                        future.set_exception(e)

    def __del__(self):
        if (
            self._handle_batch_task is None
            or not get_or_create_event_loop().is_running()
        ):
            return

        # TODO(edoakes): although we try to gracefully shutdown here, it still
        # causes some errors when the process exits due to the asyncio loop
        # already being destroyed.
        self._handle_batch_task.cancel()


class _LazyBatchQueueWrapper:
    """Stores a _BatchQueue and updates its settings.

    Since a _BatchQueue cannot be pickled, it must be constructed lazily at
    runtime inside a replica. This class initializes the queue only upon
    first access.
    """

    def __init__(
        self,
        max_batch_size: int = 10,
        batch_wait_timeout_s: float = 0.0,
        handle_batch_func: Optional[Callable] = None,
        batch_queue_cls: Type[_BatchQueue] = _BatchQueue,
    ):
        self._queue: Optional[_BatchQueue] = None
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s
        self.handle_batch_func = handle_batch_func
        self.batch_queue_cls = batch_queue_cls

    @property
    def queue(self) -> _BatchQueue:
        """Returns the _BatchQueue.

        Initializes queue when called for the first time.
        """
        if self._queue is None:
            self._queue = self.batch_queue_cls(
                self.max_batch_size,
                self.batch_wait_timeout_s,
                self.handle_batch_func,
            )
        return self._queue

    def set_max_batch_size(self, new_max_batch_size: int) -> None:
        """Updates queue's max_batch_size."""

        self.max_batch_size = new_max_batch_size

        if self._queue is not None:
            self._queue.max_batch_size = new_max_batch_size

    def set_batch_wait_timeout_s(self, new_batch_wait_timeout_s: float) -> None:
        self.batch_wait_timeout_s = new_batch_wait_timeout_s

        if self._queue is not None:
            self._queue.batch_wait_timeout_s = new_batch_wait_timeout_s

    def get_max_batch_size(self) -> int:
        return self.max_batch_size

    def get_batch_wait_timeout_s(self) -> float:
        return self.batch_wait_timeout_s


def _validate_max_batch_size(max_batch_size):
    if not isinstance(max_batch_size, int):
        if isinstance(max_batch_size, float) and max_batch_size.is_integer():
            max_batch_size = int(max_batch_size)
        else:
            raise TypeError(
                f"max_batch_size must be integer >= 1, got {max_batch_size}"
            )

    if max_batch_size < 1:
        raise ValueError(
            f"max_batch_size must be an integer >= 1, got {max_batch_size}"
        )


def _validate_batch_wait_timeout_s(batch_wait_timeout_s):
    if not isinstance(batch_wait_timeout_s, (float, int)):
        raise TypeError(
            "batch_wait_timeout_s must be a float >= 0, " f"got {batch_wait_timeout_s}"
        )

    if batch_wait_timeout_s < 0:
        raise ValueError(
            "batch_wait_timeout_s must be a float >= 0, " f"got {batch_wait_timeout_s}"
        )


T = TypeVar("T")
R = TypeVar("R")
F = TypeVar("F", bound=Callable[[List[T]], List[R]])
G = TypeVar("G", bound=Callable[[T], R])


# Normal decorator use case (called with no arguments).
@overload
def batch(func: F) -> G:
    pass


# "Decorator factory" use case (called with arguments).
@overload
def batch(
    max_batch_size: int = 10,
    batch_wait_timeout_s: float = 0.0,
) -> Callable[[F], G]:
    pass


@PublicAPI(stability="stable")
def batch(
    _func: Optional[Callable] = None,
    max_batch_size: int = 10,
    batch_wait_timeout_s: float = 0.0,
    *,
    batch_queue_cls: Type[_BatchQueue] = _BatchQueue,
):
    """Converts a function to asynchronously handle batches.

    The function can be a standalone function or a class method. In both
    cases, the function must be `async def` and take a list of objects as
    its sole argument and return a list of the same length as a result.

    When invoked, the caller passes a single object. The objects are batched
    and the underlying function is executed asynchronously once a full batch
    of `max_batch_size` objects is available or `batch_wait_timeout_s` has
    elapsed, whichever occurs first.

    `max_batch_size` and `batch_wait_timeout_s` can be updated using setter
    methods on the batch handler (`set_max_batch_size` and
    `set_batch_wait_timeout_s`).

    Example:

    .. code-block:: python

            from typing import List

            from ray import serve
            from starlette.requests import Request

            @serve.deployment
            class BatchedDeployment:
                @serve.batch(max_batch_size=10, batch_wait_timeout_s=0.1)
                async def batch_handler(self, requests: List[Request]) -> List[str]:
                    response_batch = []
                    for r in requests:
                        name = (await r.json())["name"]
                        response_batch.append(f"Hello {name}!")

                    return response_batch

                def update_batch_params(self, max_batch_size, batch_wait_timeout_s):
                    self.batch_handler.set_max_batch_size(max_batch_size)
                    self.batch_handler.set_batch_wait_timeout_s(batch_wait_timeout_s)

                async def __call__(self, request: Request):
                    return await self.batch_handler(request)

            app = BatchedDeployment.bind()

    Arguments:
        max_batch_size: the maximum batch size that will be executed in
            one call to the underlying function.
        batch_wait_timeout_s: the maximum duration to wait for
            `max_batch_size` elements before running the current batch.
        batch_queue_cls: the class to use for the underlying batch queue.
    """
    # `_func` will be None in the case when the decorator is parametrized.
    # See the comment at the end of this function for a detailed explanation.
    if _func is not None:
        if not callable(_func):
            raise TypeError(
                "@serve.batch can only be used to decorate functions or methods."
            )

        if not iscoroutinefunction(_func):
            raise TypeError(
                "Functions decorated with @serve.batch must be 'async def'"
            )

    _validate_max_batch_size(max_batch_size)
    _validate_batch_wait_timeout_s(batch_wait_timeout_s)

    def _batch_decorator(_func):
        lazy_batch_queue_wrapper = _LazyBatchQueueWrapper(
            max_batch_size,
            batch_wait_timeout_s,
            _func,
            batch_queue_cls,
        )

        async def batch_handler_generator(
            first_future: asyncio.Future,
        ) -> AsyncGenerator:
            """Generator that handles generator batch functions."""
            future = first_future
            while True:
                try:
                    async_response: _GeneratorResult = await future
                    future = async_response.next_future
                    yield async_response.result
                except StopAsyncIteration:
                    break

        def enqueue_request(args, kwargs) -> asyncio.Future:
            self = extract_self_if_method_call(args, _func)
            flattened_args: List = flatten_args(extract_signature(_func), args, kwargs)

            if self is None:
                # For functions, inject the batch queue as an
                # attribute of the function.
                batch_queue_object = _func
            else:
                # For methods, inject the batch queue as an
                # attribute of the object.
                batch_queue_object = self
                # Trim the self argument from methods.
                flattened_args = flattened_args[2:]

            batch_queue = lazy_batch_queue_wrapper.queue

            # Magic batch_queue_object attributes that can be used to change
            # the batch queue attributes on the fly.
            # This is purposefully undocumented for now while we figure out
            # the best API.
            if hasattr(batch_queue_object, "_ray_serve_max_batch_size"):
                new_max_batch_size = getattr(
                    batch_queue_object, "_ray_serve_max_batch_size"
                )
                _validate_max_batch_size(new_max_batch_size)
                batch_queue.max_batch_size = new_max_batch_size

            if hasattr(batch_queue_object, "_ray_serve_batch_wait_timeout_s"):
                new_batch_wait_timeout_s = getattr(
                    batch_queue_object, "_ray_serve_batch_wait_timeout_s"
                )
                _validate_batch_wait_timeout_s(new_batch_wait_timeout_s)
                batch_queue.batch_wait_timeout_s = new_batch_wait_timeout_s

            future = get_or_create_event_loop().create_future()
            batch_queue.put(_SingleRequest(self, flattened_args, future))
            return future

        # TODO (shrekris-anyscale): deprecate batch_queue_cls argument and
        # convert batch_wrapper into a class once `self` argument is no
        # longer needed in `enqueue_request`.
        @wraps(_func)
        def generator_batch_wrapper(*args, **kwargs):
            first_future = enqueue_request(args, kwargs)
            return batch_handler_generator(first_future)

        @wraps(_func)
        async def batch_wrapper(*args, **kwargs):
            # This will raise if the underlying call raised an exception.
            return await enqueue_request(args, kwargs)

        if isasyncgenfunction(_func):
            wrapper = generator_batch_wrapper
        else:
            wrapper = batch_wrapper

        # We store the lazy_batch_queue_wrapper's getters and setters as
        # batch_wrapper attributes, so they can be accessed in user code.
        wrapper._get_max_batch_size = lazy_batch_queue_wrapper.get_max_batch_size
        wrapper._get_batch_wait_timeout_s = (
            lazy_batch_queue_wrapper.get_batch_wait_timeout_s
        )
        wrapper.set_max_batch_size = lazy_batch_queue_wrapper.set_max_batch_size
        wrapper.set_batch_wait_timeout_s = (
            lazy_batch_queue_wrapper.set_batch_wait_timeout_s
        )

        return wrapper

    # Unfortunately, this is required to handle both non-parametrized
    # (@serve.batch) and parametrized (@serve.batch(**kwargs)) usage.
    # In the former case, `serve.batch` will be called with the underlying
    # function as the sole argument. In the latter case, it will first be
    # called with **kwargs, then the result of that call will be called
    # with the underlying function as the sole argument (i.e., it must be a
    # "decorator factory").
    return _batch_decorator(_func) if callable(_func) else _batch_decorator
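

# A minimal usage sketch for the standalone-function case (illustrative only;
# `double` is a hypothetical function, not part of this module). The decorated
# function receives a list of queued inputs and must return a list of the same
# length; each caller awaits a single result.
#
#     from typing import List
#     from ray import serve
#
#     @serve.batch(max_batch_size=4, batch_wait_timeout_s=0.1)
#     async def double(values: List[int]) -> List[int]:
#         return [v * 2 for v in values]
#
#     async def caller():
#         return await double(21)  # resolves to 42 once the batch runs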