Batching Tutorial

In this guide, we will deploy a simple vectorized adder that takes a batch of queries and processes them at once. In particular, we show:

  • How to implement and deploy a Ray Serve model that accepts batches.

  • How to configure the batch size.

  • How to query the model in Python.

This tutorial is relevant to the following use cases:

  • You want to perform offline batch inference on a cluster of machines.

  • You want to serve online queries and your model can take advantage of batching. For example, linear regressions and neural networks use the vectorized instructions of CPUs and GPUs to perform computation in parallel. Performing inference with batching can increase both the throughput of the model and the utilization of the hardware, as the short sketch below illustrates.
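
To get a feel for why batching helps, here is a small standalone NumPy sketch (illustrative only, not part of the serving code in this tutorial) comparing a one-at-a-time Python loop against a single vectorized call:

import time

import numpy as np

x = np.arange(1_000_000)

start = time.time()
loop_result = [n + 1 for n in x]  # one element at a time
loop_seconds = time.time() - start

start = time.time()
vectorized_result = x + 1  # the whole batch in one vectorized operation
vectorized_seconds = time.time() - start

print("loop: {:.3f}s, vectorized: {:.3f}s".format(loop_seconds, vectorized_seconds))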

Let’s import Ray Serve and some other helpers.

import ray
from ray import serve

from typing import List
import time

import numpy as np
import requests

You can use the @serve.accept_batch decorator to annotate a function or a class. This annotation is needed because batched backends have a different API from single-request backends: in a batched backend, the inputs are a list of values.

For a single-query backend, the input type is a single Flask request or ServeRequest:

def single_request(
    request: Union[Flask.Request, ServeRequest],
):
    pass

For batched backends, the input types are converted to lists of their original types:

@serve.accept_batch
def batched_request(
    request: List[Union[Flask.Request, ServeRequest]],
):
    pass
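
The decorator works the same way on a class-based backend. Here is a minimal sketch (the class name is hypothetical) that decorates the class's __call__ method:

class BatchAdderClass:
    @serve.accept_batch
    def __call__(self, requests: List):
        # `requests` is a list; each element corresponds to one query.
        pass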

Let’s define the backend function. We will take in a list of requests, extract the input values, convert them into an array, and use NumPy to add 1 to each element.

@serve.accept_batch
def batch_adder_v0(flask_requests: List):
    numbers = [int(request.args["number"]) for request in flask_requests]

    input_array = np.array(numbers)
    print("Our input array has shape:", input_array.shape)
    # Sleep for 200 ms; in a real model this could be CPU-intensive
    # computation
    time.sleep(0.2)
    output_array = input_array + 1
    return output_array.astype(int).tolist()


Let’s deploy it. Note that in the config section of create_backend, we are specifying the maximum batch size via config={"max_batch_size": 4}. This configuration option limits the maximum batch size sent to the backend.

Note

Ray Serve performs opportunistic batching. When a worker is free to evaluate the next batch, Ray Serve looks at the pending queries and takes min(number_of_pending_queries, max_batch_size) of them to form a batch. You can set batch_wait_timeout to override this behavior and wait (up to the timeout) for a full batch to arrive before executing.

client = serve.start()
client.create_backend("adder:v0", batch_adder_v0, config={"max_batch_size": 4})
client.create_endpoint(
    "adder", backend="adder:v0", route="/adder", methods=["GET"])

Let’s define a Ray remote task to send queries in parallel. As the output below shows, the first batch has a batch size of 1, and the subsequent queries have a batch size of 4. Even though each query is issued independently, Ray Serve is able to evaluate them in batches.

@ray.remote
def send_query(number):
    resp = requests.get("http://localhost:8000/adder?number={}".format(number))
    return int(resp.text)


# Let's use Ray to send all queries in parallel
results = ray.get([send_query.remote(i) for i in range(9)])
print("Result returned:", results)
# Output
# (pid=...) Our input array has shape: (1,)
# (pid=...) Our input array has shape: (4,)
# (pid=...) Our input array has shape: (4,)
# Result returned: [1, 2, 3, 4, 5, 6, 7, 8, 9]

What if you want to evaluate a whole batch in Python? Ray Serve allows you to send queries via the Python API. A batch of queries can come either from the web server or from the Python API. Requests coming from the Python API have a similar interface to Flask.Request; see the ServeRequest API documentation for more details.

@serve.accept_batch
def batch_adder_v1(requests: List):
    numbers = [int(request.args["number"]) for request in requests]
    input_array = np.array(numbers)
    print("Our input array has shape:", input_array.shape)
    # Sleep for 200 ms; in a real model this could be CPU-intensive
    # computation
    time.sleep(0.2)
    output_array = input_array + 1
    return output_array.astype(int).tolist()


Let’s deploy the new version to the same endpoint. Don’t forget to set max_batch_size!

client.create_backend("adder:v1", batch_adder_v1, config={"max_batch_size": 4})
client.set_traffic("adder", {"adder:v1": 1})

To query the backend via the Python API, we can use client.get_handle to get a handle to the corresponding endpoint. To enqueue a query, call handle.remote(data, argument_name=argument_value). This call returns immediately with a Ray ObjectRef; you can call ray.get on it to retrieve the result.

handle = client.get_handle("adder")
print(handle)
# Output
# RayServeHandle(
#    Endpoint="adder",
#    Traffic={'adder:v1': 1}
# )

input_batch = list(range(9))
print("Input batch is", input_batch)
# Input batch is [0, 1, 2, 3, 4, 5, 6, 7, 8]

result_batch = ray.get([handle.remote(number=i) for i in input_batch])
# Output
# (pid=...) Our input array has shape: (1,)
# (pid=...) Our input array has shape: (4,)
# (pid=...) Our input array has shape: (4,)

print("Result batch is", result_batch)
# Result batch is [1, 2, 3, 4, 5, 6, 7, 8, 9]