Pattern: Using ray.wait to limit the number of pending tasks
In this pattern, we use ray.wait() to limit the number of pending tasks.
If we continuously submit tasks faster than they can be processed, tasks accumulate in the pending task queue, which can eventually cause an out-of-memory (OOM) error.
With ray.wait(), we can apply backpressure and limit the number of pending tasks so that the pending task queue won’t grow indefinitely and cause OOM.
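To make the queue growth concrete, here is a rough pure-Python simulation (no Ray involved). The function name, the rates, and the cap are all made up for illustration: a submitter producing 150 tasks per second feeds a worker that finishes 100 per second, first without a limit and then with a cap of 200 pending tasks.

```python
def simulate_pending(submit_rate, process_rate, seconds, max_pending=None):
    """Return the pending-queue length after each simulated second.

    Rates are in tasks per second and are hypothetical. When max_pending is
    set, the submitter stalls once the queue is full (backpressure).
    """
    pending = 0
    history = []
    for _ in range(seconds):
        # The worker drains up to process_rate tasks each second.
        pending = max(0, pending - process_rate)
        # The submitter adds tasks, optionally capped by backpressure.
        submitted = submit_rate
        if max_pending is not None:
            submitted = min(submitted, max_pending - pending)
        pending += submitted
        history.append(pending)
    return history


unbounded = simulate_pending(submit_rate=150, process_rate=100, seconds=5)
bounded = simulate_pending(submit_rate=150, process_rate=100, seconds=5, max_pending=200)
print(unbounded)  # [150, 200, 250, 300, 350]
print(bounded)    # [150, 200, 200, 200, 200]
```

Without a cap the backlog grows by the rate difference every second. With a cap it levels off, and the submitter is simply slowed down to the worker's pace.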
Note
If we submit a finite number of tasks, it’s unlikely that we will hit the issue mentioned above since each task only uses a small amount of memory for bookkeeping in the queue. It’s more likely to happen when we have an infinite stream of tasks to run.
Note
This method is meant primarily to limit how many tasks should be in flight at the same time. It can also be used to limit how many tasks can run concurrently, but it is not recommended, as it can hurt scheduling performance. Ray automatically decides task parallelism based on resource availability, so the recommended method for adjusting how many tasks can run concurrently is to modify each task’s resource requirements instead.
Example use case
You have a worker actor that processes tasks at a rate of X tasks per second, and you want to submit tasks to it at a rate lower than X to avoid OOM.
For example, Ray Serve uses this pattern to limit the number of pending queries for each worker.
Code example
Without backpressure:
import ray

ray.init()


@ray.remote
class Actor:
    async def heavy_compute(self):
        # taking a long time...
        # await asyncio.sleep(5)
        return


actor = Actor.remote()

NUM_TASKS = 1000
result_refs = []
# When NUM_TASKS is large enough, this will eventually OOM.
for _ in range(NUM_TASKS):
    result_refs.append(actor.heavy_compute.remote())

ray.get(result_refs)
With backpressure:
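MAX_NUM_PENDING_TASKS = 100
result_refs = []
for _ in range(NUM_TASKS):
    # Use >= so that no more than MAX_NUM_PENDING_TASKS tasks are ever pending.
    if len(result_refs) >= MAX_NUM_PENDING_TASKS:
        # Block until at least one task finishes, then update
        # result_refs to track only the remaining tasks.
        ready_refs, result_refs = ray.wait(result_refs, num_returns=1)
        ray.get(ready_refs)

    result_refs.append(actor.heavy_compute.remote())

# Wait for the tasks that are still pending.
ray.get(result_refs)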