Pattern: Using ray.wait to limit the number of pending tasks
In this pattern, we use ray.wait() to limit the number of pending tasks.
If we continuously submit tasks faster than they are processed, tasks accumulate in the pending task queue, which can eventually cause out-of-memory (OOM) failures. With ray.wait(), we can apply backpressure and limit the number of pending tasks so that the pending task queue won't grow indefinitely and cause OOM.
If we submit a finite number of tasks, it’s unlikely that we will hit the issue mentioned above since each task only uses a small amount of memory for bookkeeping in the queue. It’s more likely to happen when we have an infinite stream of tasks to run.
This method is meant primarily to limit how many tasks are in flight at the same time. It can also be used to limit how many tasks run concurrently, but this is not recommended, as it can hurt scheduling performance. Ray automatically decides task parallelism based on resource availability, so the recommended way to adjust how many tasks run concurrently is to modify each task's resource requirements instead.
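As an illustration, here is a minimal sketch of the resource-based approach; the heavy_task function and the num_cpus value are assumptions for this example, not part of the pattern itself:

import ray

ray.init()

# Illustrative assumption: on a 16-CPU node, requiring 4 CPUs per task
# caps concurrency at 4 tasks, with no manual tracking of in-flight work.
@ray.remote(num_cpus=4)
def heavy_task():
    return

refs = [heavy_task.remote() for _ in range(100)]
ray.get(refs)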
Example use case
You have a worker actor that processes tasks at a rate of X tasks per second, and you want to submit tasks to it at a rate lower than X to avoid OOM.
For example, Ray Serve uses this pattern to limit the number of pending queries for each worker.
Without backpressure, all NUM_TASKS tasks are submitted immediately and accumulate in the pending task queue:

import ray

ray.init()

@ray.remote
class Actor:
    async def heavy_compute(self):
        # Simulate a computation that takes a long time, e.g.:
        # await asyncio.sleep(5)
        return

actor = Actor.remote()

NUM_TASKS = 1000
result_refs = []
# When NUM_TASKS is large enough, this will eventually OOM.
for _ in range(NUM_TASKS):
    result_refs.append(actor.heavy_compute.remote())
ray.get(result_refs)
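With ray.wait(), the same workload can apply backpressure so that the number of pending tasks never exceeds a fixed cap: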
MAX_NUM_PENDING_TASKS = 100
result_refs = []
for _ in range(NUM_TASKS):
    if len(result_refs) > MAX_NUM_PENDING_TASKS:
        # Wait for one task to finish and update result_refs
        # to only track the remaining tasks.
        ready_refs, result_refs = ray.wait(result_refs, num_returns=1)
        ray.get(ready_refs)
    result_refs.append(actor.heavy_compute.remote())
ray.get(result_refs)
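Once the cap is reached, each iteration blocks until at least one in-flight task finishes before submitting the next, so at most MAX_NUM_PENDING_TASKS + 1 results are ever outstanding; the final ray.get(result_refs) waits for the remaining ones. Passing a larger num_returns to ray.wait() is a possible variation that trades finer-grained backpressure for fewer wait calls.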