Dynamic Request Batching#

Serve offers a request batching feature that can improve your service throughput without a large latency penalty. This improvement is possible because ML models can use efficient vectorized computation to process a batch of requests at a time. Batching is particularly valuable when your model is expensive to run and you want to maximize hardware utilization.

Machine Learning (ML) frameworks such as TensorFlow, PyTorch, and scikit-learn support evaluating multiple samples at the same time. Ray Serve allows you to take advantage of this feature with dynamic request batching. When a request arrives, Serve puts the request in a queue. This queue buffers the requests to form a batch. The deployment picks up the batch and evaluates it. After the evaluation, Ray Serve splits up the resulting batch and returns each response individually.
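The queue, batch, evaluate, and split flow described above can be sketched in plain asyncio. This is only an illustration of the idea, not Serve's implementation: TinyBatcher, submit, double_all, and seen_batches are hypothetical names, and the sketch ignores error handling.

```python
import asyncio
from typing import Awaitable, Callable, List, Optional

seen_batches: List[List[int]] = []


class TinyBatcher:
    """Toy dynamic batcher. Hypothetical helper, not Ray Serve's implementation."""

    def __init__(
        self,
        fn: Callable[[List[int]], Awaitable[List[int]]],
        max_batch_size: int,
        batch_wait_timeout_s: float,
    ) -> None:
        self._fn = fn
        self._max_batch_size = max_batch_size
        self._timeout_s = batch_wait_timeout_s
        self._queue: Optional[asyncio.Queue] = None
        self._worker: Optional[asyncio.Task] = None

    async def submit(self, item: int) -> int:
        # Create the queue and the background loop lazily, inside the event loop.
        if self._worker is None:
            self._queue = asyncio.Queue()
            self._worker = asyncio.ensure_future(self._run())
        future = asyncio.get_running_loop().create_future()
        await self._queue.put((item, future))
        return await future

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            # Block until the first request arrives, then start the timer.
            batch = [await self._queue.get()]
            deadline = loop.time() + self._timeout_s
            while len(batch) < self._max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            # Evaluate the whole batch once, then give each caller its own result.
            outputs = await self._fn([item for item, _ in batch])
            for (_, future), output in zip(batch, outputs):
                future.set_result(output)


async def double_all(samples: List[int]) -> List[int]:
    seen_batches.append(list(samples))  # Record batches for inspection.
    return [s * 2 for s in samples]


async def main() -> List[int]:
    batcher = TinyBatcher(double_all, max_batch_size=4, batch_wait_timeout_s=0.05)
    return await asyncio.gather(*(batcher.submit(i) for i in range(6)))


results = asyncio.run(main())
print(results, seen_batches)
```

Each caller awaits only its own future, so callers never see the batch. That is the property the decorator provides.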

Enable batching for your deployment#

You can enable batching by using the ray.serve.batch decorator. The following simple example shows the Model class before batching, when it accepts a single sample per call:

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class Model:
    def __call__(self, single_sample: int) -> int:
        return single_sample * 2


handle: DeploymentHandle = serve.run(Model.bind())
assert handle.remote(1).result() == 2

The batching decorators expect you to make the following changes in your method signature:

  • Declare the method as an async method because the decorator batches in an asyncio event loop.

  • Modify the method to accept a list of its original input types as input. For example, arg1: int, arg2: str should be changed to arg1: List[int], arg2: List[str].

  • Modify the method to return a list. The returned list must have the same length as the input list so that the decorator can split the output and return each response to its corresponding request.

from typing import List

import numpy as np

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class Model:
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def __call__(self, multiple_samples: List[int]) -> List[int]:
        # Use numpy's vectorized computation to efficiently process a batch.
        return np.array(multiple_samples) * 2


handle: DeploymentHandle = serve.run(Model.bind())
responses = [handle.remote(i) for i in range(8)]
assert list(r.result() for r in responses) == [i * 2 for i in range(8)]

You can supply four optional parameters to the decorator:

  • max_batch_size controls the size of the batch. The default value is 10.

  • batch_wait_timeout_s controls how long Serve should wait for a batch once the first request arrives. The default value is 0.01 (10 milliseconds).

  • max_concurrent_batches sets the maximum number of batches that can run concurrently. The default value is 1.

  • batch_size_fn is an optional function that computes the effective batch size. If provided, it takes a list of items and returns an integer representing the batch size. This is useful for batching based on custom metrics such as total nodes in graphs or total tokens in sequences. If None (the default), Serve computes the batch size as len(batch).

Once the first request arrives, the batching decorator waits until either the batch is full (it reaches max_batch_size) or batch_wait_timeout_s elapses. If the timeout elapses first, Serve sends the partial batch to the model regardless of its size.
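To make the timing rule concrete, the following sketch groups a list of request arrival times into batches. It is a simplified model rather than Serve's scheduler: form_batches is a hypothetical helper, and it assumes model execution time is negligible, so a new batch opens as soon as the next request arrives.

```python
from typing import List


def form_batches(
    arrivals_s: List[float], max_batch_size: int, batch_wait_timeout_s: float
) -> List[List[float]]:
    """Group arrival times into batches under the full-or-timeout rule."""
    batches: List[List[float]] = []
    current: List[float] = []
    deadline = 0.0
    for t in sorted(arrivals_s):
        # Close the open batch if it is full or its timeout has elapsed.
        if current and (len(current) >= max_batch_size or t > deadline):
            batches.append(current)
            current = []
        # The first request of a batch starts the timeout clock.
        if not current:
            deadline = t + batch_wait_timeout_s
        current.append(t)
    if current:
        batches.append(current)
    return batches


batches = form_batches(
    [0.0, 0.002, 0.004, 0.05, 0.051], max_batch_size=2, batch_wait_timeout_s=0.01
)
print(batches)  # [[0.0, 0.002], [0.004], [0.05, 0.051]]
```

The first batch closes because it is full, the second because its timeout elapses before another request arrives.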

Tip

You can reconfigure your batch_wait_timeout_s and max_batch_size parameters using the set_batch_wait_timeout_s and set_max_batch_size methods:

from typing import Dict, List

import numpy as np

from ray import serve


@serve.deployment(
    # These values can be overridden in the Serve config.
    user_config={
        "max_batch_size": 10,
        "batch_wait_timeout_s": 0.5,
    }
)
class Model:
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def __call__(self, multiple_samples: List[int]) -> List[int]:
        # Use numpy's vectorized computation to efficiently process a batch.
        return np.array(multiple_samples) * 2

    def reconfigure(self, user_config: Dict):
        self.__call__.set_max_batch_size(user_config["max_batch_size"])
        self.__call__.set_batch_wait_timeout_s(user_config["batch_wait_timeout_s"])


Use these methods in the constructor or the reconfigure method to control the @serve.batch parameters through your Serve configuration file.

Custom batch size functions#

By default, Ray Serve measures batch size as the number of items in the batch (len(batch)). However, in many workloads, the computational cost depends on properties of the items themselves rather than just the count. For example:

  • Graph Neural Networks (GNNs): The cost depends on the total number of nodes across all graphs, not the number of graphs

  • Natural Language Processing (NLP): For transformer models, the cost depends on the total token count, not the number of sequences

  • Variable-resolution images: Memory usage depends on total pixels, not the number of images

Use the batch_size_fn parameter to define a custom metric for batch size, as the following examples show.

Graph Neural Network example#

The following example shows how to batch graph data by total node count:

from typing import List

from ray import serve
from ray.serve.handle import DeploymentHandle


class Graph:
    """Simple graph data structure for GNN workloads."""

    def __init__(self, num_nodes: int, node_features: list):
        self.num_nodes = num_nodes
        self.node_features = node_features


@serve.deployment
class GraphNeuralNetwork:
    @serve.batch(
        max_batch_size=10000,  # Maximum total nodes per batch
        batch_wait_timeout_s=0.1,
        batch_size_fn=lambda graphs: sum(g.num_nodes for g in graphs),
    )
    async def predict(self, graphs: List[Graph]) -> List[float]:
        """Process a batch of graphs, batching by total node count."""
        # The batch_size_fn ensures that the total number of nodes
        # across all graphs in the batch doesn't exceed max_batch_size.
        # This prevents GPU memory overflow.
        results = []
        for graph in graphs:
            # Your GNN model inference logic here
            # For this example, just return a simple score
            score = float(graph.num_nodes * 0.1)
            results.append(score)
        return results

    async def __call__(self, graph: Graph) -> float:
        return await self.predict(graph)


handle: DeploymentHandle = serve.run(GraphNeuralNetwork.bind())

# Create test graphs with varying node counts
graphs = [
    Graph(num_nodes=100, node_features=[1.0] * 100),
    Graph(num_nodes=5000, node_features=[2.0] * 5000),
    Graph(num_nodes=3000, node_features=[3.0] * 3000),
]

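# Send all requests before waiting on any result so that Serve can batch them
# by total node count. Calling .result() inside the loop would send the
# requests one at a time.
responses = [handle.remote(g) for g in graphs]
results = [r.result() for r in responses]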
print(f"Results: {results}")

In this example, batch_size_fn=lambda graphs: sum(g.num_nodes for g in graphs) ensures that the batch contains at most 10,000 total nodes, preventing GPU memory overflow regardless of how many individual graphs are in the batch.
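The effect of a size-based limit can be illustrated without Ray by greedily packing items until the custom metric would exceed the limit. This is a sketch under assumptions, not Serve's batching algorithm: pack_by_size is a hypothetical helper, graphs are represented by their node counts alone, the wait timeout is ignored, and an item larger than the limit is placed in a batch of its own.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def pack_by_size(
    items: List[T], max_batch_size: int, batch_size_fn: Callable[[List[T]], int]
) -> List[List[T]]:
    """Greedily group items so each batch's custom size stays within the limit."""
    batches: List[List[T]] = []
    current: List[T] = []
    for item in items:
        # Start a new batch if adding this item would exceed the limit.
        if current and batch_size_fn(current + [item]) > max_batch_size:
            batches.append(current)
            current = []
        current.append(item)
    if current:
        batches.append(current)
    return batches


# Each integer stands in for a graph with that many nodes.
node_counts = [100, 5000, 3000, 4000, 9000]
batches = pack_by_size(node_counts, max_batch_size=10000, batch_size_fn=sum)
print(batches)  # [[100, 5000, 3000], [4000], [9000]]
```

The batches hold three, one, and one graphs respectively, which shows why a limit on item count alone would not bound memory use here.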

NLP token batching example#

The following example shows how to batch text sequences by total token count:

from typing import List

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class TokenBatcher:
    @serve.batch(
        max_batch_size=512,  # Maximum total tokens per batch
        batch_wait_timeout_s=0.1,
        batch_size_fn=lambda sequences: sum(len(s.split()) for s in sequences),
    )
    async def process(self, sequences: List[str]) -> List[int]:
        """Process text sequences, batching by total token count."""
        # The batch_size_fn ensures total tokens don't exceed max_batch_size.
        # This is useful for transformer models with fixed context windows.
        return [len(seq.split()) for seq in sequences]

    async def __call__(self, sequence: str) -> int:
        return await self.process(sequence)


handle: DeploymentHandle = serve.run(TokenBatcher.bind())

# Create sequences with different lengths
sequences = [
    "This is a short sentence",
    "This is a much longer sentence with many more words to process",
    "Short",
]

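# Send all requests before waiting on any result so that Serve can batch them
# by total token count. Calling .result() inside the loop would send the
# requests one at a time.
responses = [handle.remote(seq) for seq in sequences]
results = [r.result() for r in responses]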
print(f"Token counts: {results}")

This pattern ensures that the total number of tokens doesn’t exceed the model’s context window or memory limits.

Streaming batched requests#

Use an async generator to stream the outputs from your batched requests. The following example starts with an unbatched StreamingResponder class, which the next example converts to accept a batch.

import asyncio
from typing import AsyncGenerator
from starlette.requests import Request
from starlette.responses import StreamingResponse

from ray import serve


@serve.deployment
class StreamingResponder:
    async def generate_numbers(self, max: int) -> AsyncGenerator[str, None]:
        for i in range(max):
            yield str(i)
            await asyncio.sleep(0.1)

    def __call__(self, request: Request) -> StreamingResponse:
        max = int(request.query_params.get("max", "25"))
        gen = self.generate_numbers(max)
        return StreamingResponse(gen, status_code=200, media_type="text/plain")


Decorate async generator functions with the ray.serve.batch decorator. Similar to non-streaming methods, the function takes in a List of inputs and in each iteration it yields an iterable of outputs with the same length as the input batch size.

import asyncio
from typing import List, AsyncGenerator, Union
from starlette.requests import Request
from starlette.responses import StreamingResponse

from ray import serve


@serve.deployment
class StreamingResponder:
    @serve.batch(max_batch_size=5, batch_wait_timeout_s=0.1)
    async def generate_numbers(
        self, max_list: List[int]
    ) -> AsyncGenerator[List[Union[str, StopIteration]], None]:
        for i in range(max(max_list)):
            next_numbers = []
            for requested_max in max_list:
                if requested_max > i:
                    next_numbers.append(str(i))
                else:
                    next_numbers.append(StopIteration)
            yield next_numbers
            await asyncio.sleep(0.1)

    async def __call__(self, request: Request) -> StreamingResponse:
        max = int(request.query_params.get("max", "25"))
        gen = self.generate_numbers(max)
        return StreamingResponse(gen, status_code=200, media_type="text/plain")


Calling the serve.batch-decorated function returns an async generator that you can iterate over with async for to receive results.

Some inputs within a batch may generate fewer outputs than others. When a particular input has nothing left to yield, put StopIteration in that input's position in the output iterable. This action terminates the generator that Serve returned when the serve.batch function was called with that input. When serve.batch-decorated functions return streaming generators over HTTP, this action allows the end client's connection to terminate once its call is done, instead of waiting until the entire batch is done.
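The following Ray-free sketch shows how per-request streams can be recovered from a batched generator that uses StopIteration as an end marker. It only approximates what the decorator does for each caller: split_streams is a hypothetical helper, and generate_numbers is an undecorated copy of the batched method above.

```python
import asyncio
from typing import AsyncGenerator, List


async def generate_numbers(max_list: List[int]) -> AsyncGenerator[list, None]:
    # Yield one output per input on every iteration; finished inputs get StopIteration.
    for i in range(max(max_list)):
        yield [
            str(i) if requested_max > i else StopIteration
            for requested_max in max_list
        ]


async def split_streams(max_list: List[int]) -> List[List[str]]:
    """Collect each request's outputs, stopping at its StopIteration marker."""
    streams: List[List[str]] = [[] for _ in max_list]
    finished = [False] * len(max_list)
    async for outputs in generate_numbers(max_list):
        for idx, output in enumerate(outputs):
            if finished[idx]:
                continue
            if output is StopIteration:
                finished[idx] = True  # This request's stream has ended.
            else:
                streams[idx].append(output)
    return streams


streams = asyncio.run(split_streams([1, 3, 2]))
print(streams)  # [['0'], ['0', '1', '2'], ['0', '1']]
```

The request that asked for one number finishes after the first iteration even though the batch keeps running for the others.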

Tips for fine-tuning batching parameters#

max_batch_size should ideally be a power of 2 (2, 4, 8, 16, …) because CPUs and GPUs are both optimized for data of these shapes. Large batch sizes incur a high memory cost as well as a latency penalty for the first few requests, which wait while the batch fills.

When using batch_size_fn, set max_batch_size based on your custom metric rather than item count. For example, if batching by total nodes in graphs, set max_batch_size to your GPU’s maximum node capacity (such as 10,000 nodes) rather than a count of graphs.

Set batch_wait_timeout_s with your end-to-end latency SLO (Service Level Objective) in mind. For example, if your latency target is 150 ms and the model takes 100 ms to evaluate a batch, set batch_wait_timeout_s to a value much lower than 150 ms - 100 ms = 50 ms.
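The arithmetic above can be written as a small helper. This is a sketch only: max_batch_wait_s is a hypothetical function, and safety_factor is an assumed knob for "much lower" that the guidance above doesn't quantify, so choose it based on your own measurements.

```python
def max_batch_wait_s(
    latency_slo_ms: float, model_latency_ms: float, safety_factor: float = 0.5
) -> float:
    """Return a candidate batch_wait_timeout_s, in seconds, from a latency budget."""
    budget_ms = latency_slo_ms - model_latency_ms
    if budget_ms <= 0:
        raise ValueError("The model latency already meets or exceeds the SLO.")
    # Keep only a fraction of the remaining budget (assumed safety margin).
    return budget_ms * safety_factor / 1000.0


wait_s = max_batch_wait_s(latency_slo_ms=150, model_latency_ms=100)
print(wait_s)  # 0.025, half of the 50 ms budget
```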

When using batching in a Serve Deployment Graph, the relationship between an upstream node and a downstream node might also affect performance. Consider a chain of two models where the first model sets max_batch_size=8 and the second model sets max_batch_size=6. When the first model finishes a full batch of 8, the second model processes one batch of 6 and then starts its next batch with only the remaining 8 - 6 = 2 requests. That partially filled batch must wait for more requests or for the timeout, which adds latency. Ideally, the batch size of a downstream model is a multiple or a divisor of the upstream model's batch size so that the batches line up.
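The mismatch can be seen by counting how a burst of upstream results splits into downstream batches. This is a simplified model, not a measurement of Serve: downstream_batch_sizes is a hypothetical helper, and it assumes each upstream batch arrives as one burst and that the downstream model doesn't wait for the next burst.

```python
from typing import List


def downstream_batch_sizes(
    upstream_batch_size: int, downstream_max_batch_size: int, num_upstream_batches: int = 1
) -> List[int]:
    """Sizes of the downstream batches formed from each upstream burst."""
    sizes: List[int] = []
    for _ in range(num_upstream_batches):
        remaining = upstream_batch_size
        while remaining > 0:
            take = min(remaining, downstream_max_batch_size)
            sizes.append(take)
            remaining -= take
    return sizes


mismatched = downstream_batch_sizes(8, 6)
aligned = downstream_batch_sizes(8, 4)
print(mismatched)  # [6, 2]: the second batch is only one-third full
print(aligned)     # [4, 4]: every batch is full
```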