Batching Tutorial

In this guide, we will deploy a simple vectorized adder that takes a batch of queries and processes them at once. In particular, we show:

  • How to implement and deploy a Ray Serve backend that accepts batches.

  • How to configure the batch size.

  • How to query the model in Python.

This tutorial should help with the following use cases:

  • You want to perform offline batch inference on a cluster of machines.

  • You want to serve online queries and your model can take advantage of batching. For example, linear regressions and neural networks use the vectorized instructions of CPUs and GPUs to perform computation in parallel. Performing inference with batching can increase the throughput of the model as well as the utilization of the hardware.

Let’s import Ray Serve and some other helpers.

from typing import List, Union
import time

import numpy as np
import requests
from starlette.requests import Request

import ray
from ray import serve
from ray.serve import ServeRequest

You can use the @serve.batch decorator to annotate a function or a method. This annotation will automatically cause calls to the function to be batched together. The decorated function should accept a list of objects; callers invoke it with a single object, and Ray Serve groups those calls into batches before executing the function. The function must also be async def so that you can handle multiple queries concurrently:

@serve.batch
async def my_batch_handler(self, requests: List):
    pass

This batch handler can then be called from another async def method in your backend. These calls will be batched and executed together, but each call returns an individual result as if it were a normal function call:

class MyBackend:
    @serve.batch
    async def my_batch_handler(self, requests: List):
        results = []
        for request in requests:
            # Starlette's request.json() is a coroutine, so it must be awaited.
            results.append(await request.json())
        return results

    async def __call__(self, request):
        return await self.my_batch_handler(request)

Note

By default, Ray Serve performs opportunistic batching. This means that as soon as the batch handler is called, the method will be executed without waiting for a full batch. If there are more queries available after this call finishes, a larger batch may be executed. This behavior can be tuned using the batch_wait_timeout_s option to @serve.batch (defaults to 0). Increasing this timeout may improve throughput at the cost of latency under low load.
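
For illustration, here is a minimal sketch of a handler that combines a batch size cap with a short wait window. The class name TimeoutBatcher and the specific values (max_batch_size=8, batch_wait_timeout_s=0.1) are arbitrary placeholders, not part of this tutorial's deployment:

class TimeoutBatcher:
    # Wait up to 100ms for more queries to arrive, capping each batch at 8.
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def handle_batch(self, requests: List):
        return [int(request.query_params["number"]) for request in requests]

    async def __call__(self, request):
        return await self.handle_batch(request)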

Let’s define a backend that takes in a list of requests, extracts the input value from each, converts the values into a NumPy array, and adds 1 to each element.

class BatchAdder:
    @serve.batch(max_batch_size=4)
    async def handle_batch(self, requests: List[Union[Request, ServeRequest]]):
        numbers = [int(request.query_params["number"]) for request in requests]

        input_array = np.array(numbers)
        print("Our input array has shape:", input_array.shape)
        # Sleep for 200ms; in a real model this could be CPU-intensive
        # computation.
        time.sleep(0.2)
        output_array = input_array + 1
        return output_array.astype(int).tolist()

    async def __call__(self, request: Union[Request, ServeRequest]):
        return await self.handle_batch(request)


Let’s deploy it. Note that in the @serve.batch decorator, we are specifying the maximum batch size via max_batch_size=4. This option limits the maximum possible batch size that will be executed at once.

ray.init(num_cpus=8)
serve.start()
serve.create_backend("adder:v0", BatchAdder)
serve.create_endpoint(
    "adder", backend="adder:v0", route="/adder", methods=["GET"])

Let’s define a Ray remote task to send queries in parallel. As you can see from the output below, the first batch has a batch size of 1, and the subsequent queries have a batch size of 4. Even though each query is issued independently, Ray Serve was able to evaluate them in batches.

@ray.remote
def send_query(number):
    resp = requests.get("http://localhost:8000/adder?number={}".format(number))
    return int(resp.text)


# Let's use Ray to send all queries in parallel
results = ray.get([send_query.remote(i) for i in range(9)])
print("Result returned:", results)
# Output
# (pid=...) Our input array has shape: (1,)
# (pid=...) Our input array has shape: (4,)
# (pid=...) Our input array has shape: (4,)
# Result returned: [1, 2, 3, 4, 5, 6, 7, 8, 9]

What if you want to evaluate a whole batch in Python? Ray Serve allows you to send queries via the Python API. A batch of queries can come either from the web server or from the Python API. Requests coming from the Python API have an interface similar to a Starlette Request; see the Ray Serve API documentation for more details.

To query the backend via the Python API, we can use serve.get_handle to retrieve a handle to the corresponding endpoint. To enqueue a query, call handle.remote(data). This call returns immediately with a Ray ObjectRef; you can call ray.get on it to retrieve the result.

handle = serve.get_handle("adder")
print(handle)
# Output
# RayServeHandle(
#    Endpoint="adder",
#    Traffic={'adder:v0': 1}
# )

input_batch = list(range(9))
print("Input batch is", input_batch)
# Input batch is [0, 1, 2, 3, 4, 5, 6, 7, 8]

result_batch = ray.get([handle.remote(number=i) for i in input_batch])
# Output
# (pid=...) Current context is python
# (pid=...) Our input array has shape: (1,)
# (pid=...) Current context is python
# (pid=...) Our input array has shape: (4,)
# (pid=...) Current context is python
# (pid=...) Our input array has shape: (4,)

print("Result batch is", result_batch)
# Result batch is [1, 2, 3, 4, 5, 6, 7, 8, 9]