Anti-pattern: Calling ray.get unnecessarily harms performance
TLDR: Avoid calling ray.get() unnecessarily for intermediate steps. Work with object references directly, and only call
ray.get() at the end to get the final result.
When ray.get() is called, objects must be transferred to the worker/node that calls ray.get(). If you don’t need to manipulate the object, you probably don’t need to call ray.get() on it!
Typically, it’s best practice to wait as long as possible before calling
ray.get(), or even design your program to avoid having to call
ray.get() at all.
Anti-pattern:

```python
import ray
import numpy as np

ray.init()

@ray.remote
def generate_rollout():
    return np.ones((10000, 10000))

@ray.remote
def reduce(rollout):
    return np.sum(rollout)

# `ray.get()` downloads the result here.
rollout = ray.get(generate_rollout.remote())
# Now we have to reupload `rollout`
reduced = ray.get(reduce.remote(rollout))
```

Fixed version:

```python
# Don't need ray.get here.
rollout_obj_ref = generate_rollout.remote()
# Rollout object is passed by reference.
reduced = ray.get(reduce.remote(rollout_obj_ref))
```
Notice that in the anti-pattern example, we call ray.get(), which forces us to transfer the large rollout to the driver and then again to the reduce worker.
In the fixed version, we pass only the reference to the object to the reduce task.
The reduce worker will implicitly call ray.get() to fetch the actual rollout data directly from the generate_rollout worker, avoiding the extra copy to the driver.
Other ray.get()-related anti-patterns are: