Pattern: Using ray.wait to limit the number of in-flight tasks

In this pattern, we use ray.wait() to limit the number of in-flight tasks.

If we submit tasks faster than they can be processed, they accumulate in the pending task queue, which can eventually exhaust memory and cause an out-of-memory (OOM) error. With ray.wait(), we can apply backpressure: before submitting a new task, we wait for an earlier one to finish, so the number of in-flight tasks stays bounded and the pending task queue cannot grow indefinitely.

Example use case

You have a worker actor that processes tasks at a rate of X tasks per second, and you want to submit tasks to it at a rate lower than X to avoid OOM.

For example, Ray Serve uses this pattern to limit the number of in-flight queries for each worker.

[Figure: Limit number of in-flight tasks]

Code example

Without backpressure:

import ray

ray.init()


@ray.remote
class Actor:
    async def heavy_compute(self):
        # Placeholder for long-running work, for example:
        # await asyncio.sleep(5)  # would also need `import asyncio`
        return


actor = Actor.remote()

NUM_TASKS = 1000
result_refs = []
# When NUM_TASKS is large enough, this will eventually OOM.
for _ in range(NUM_TASKS):
    result_refs.append(actor.heavy_compute.remote())
ray.get(result_refs)

With backpressure:

MAX_NUM_IN_FLIGHT_TASKS = 100
result_refs = []
for _ in range(NUM_TASKS):
    if len(result_refs) >= MAX_NUM_IN_FLIGHT_TASKS:
        # Block until one task finishes, then update result_refs
        # to track only the tasks that are still in flight.
        ready_refs, result_refs = ray.wait(result_refs, num_returns=1)
        ray.get(ready_refs)

    result_refs.append(actor.heavy_compute.remote())

ray.get(result_refs)
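
Standard-library analogue (not Ray code):

The same backpressure loop can be tried without a Ray cluster. The sketch below is an illustration only: it swaps in Python's concurrent.futures, where wait(..., return_when=FIRST_COMPLETED) plays roughly the role of ray.wait(..., num_returns=1). The function name run_with_backpressure, the pool size, and the sleep duration are hypothetical choices made for this example. It also records the peak number of tracked tasks so the bound can be checked.

```python
import time
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def heavy_compute(i):
    """Stand-in for a slow task (hypothetical workload)."""
    time.sleep(0.01)
    return i


def run_with_backpressure(num_tasks, max_in_flight):
    """Submit num_tasks tasks while tracking at most max_in_flight at once.

    Returns (results, peak), where peak is the largest number of
    submitted-but-not-yet-collected tasks observed.
    """
    results = []
    pending = set()
    peak = 0
    with ThreadPoolExecutor(max_workers=4) as pool:
        for i in range(num_tasks):
            if len(pending) >= max_in_flight:
                # Block until at least one task finishes. Unlike
                # ray.wait(num_returns=1), this may hand back several
                # finished futures at once.
                done, pending = wait(pending, return_when=FIRST_COMPLETED)
                results.extend(f.result() for f in done)
            pending.add(pool.submit(heavy_compute, i))
            peak = max(peak, len(pending))
        # Drain whatever is still in flight.
        done, _ = wait(pending)
        results.extend(f.result() for f in done)
    return results, peak


if __name__ == "__main__":
    results, peak = run_with_backpressure(num_tasks=50, max_in_flight=8)
    print(f"collected {len(results)} results, peak in flight: {peak}")
```

As in the Ray version, the submitter blocks only when the limit is reached, so throughput is unaffected as long as tasks finish at least as fast as they are submitted.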