Source code for ray.util.dask.callbacks

import contextlib

from ray import ObjectRef
from collections import namedtuple, defaultdict
from datetime import datetime
from typing import Any, List, Optional

from dask.callbacks import Callback

# The names of the Ray-specific callbacks. These are the kwarg names that
# RayDaskCallback will accept on construction, and is considered the
# source-of-truth for what Ray-specific callbacks exist.
CBS = (
    "ray_presubmit",
    "ray_postsubmit",
    "ray_pretask",
    "ray_posttask",
    "ray_postsubmit_all",
    "ray_finish",
)
# The Ray-specific callback method names for RayDaskCallback.
CB_FIELDS = tuple("_" + field for field in CBS)
# The Ray-specific callbacks that we do _not_ wish to drop from RayCallbacks
# if not given on a RayDaskCallback instance (will be filled with None
# instead).
CBS_DONT_DROP = {"ray_pretask", "ray_posttask"}

# The Ray-specific callbacks for a single RayDaskCallback.
RayCallback = namedtuple("RayCallback", " ".join(CBS))

# The Ray-specific callbacks for one or more RayDaskCallbacks.
RayCallbacks = namedtuple("RayCallbacks", " ".join([field + "_cbs" for field in CBS]))


[docs]class RayDaskCallback(Callback): """ Extends Dask's `Callback` class with Ray-specific hooks. When instantiating or subclassing this class, both the normal Dask hooks (e.g. pretask, posttask, etc.) and the Ray-specific hooks can be provided. See `dask.callbacks.Callback` for usage. Caveats: Any Dask-Ray scheduler must bring the Ray-specific callbacks into context using the `local_ray_callbacks` context manager, since the built-in `local_callbacks` context manager provided by Dask isn't aware of this class. """ # Set of active Ray-specific callbacks. ray_active = set() def __init__(self, **kwargs): for cb in CBS: cb_func = kwargs.pop(cb, None) if cb_func is not None: setattr(self, "_" + cb, cb_func) super().__init__(**kwargs) @property def _ray_callback(self): return RayCallback(*[getattr(self, field, None) for field in CB_FIELDS]) def __enter__(self): self._ray_cm = add_ray_callbacks(self) self._ray_cm.__enter__() super().__enter__() return self def __exit__(self, *args): super().__exit__(*args) self._ray_cm.__exit__(*args) def register(self): type(self).ray_active.add(self._ray_callback) super().register() def unregister(self): type(self).ray_active.remove(self._ray_callback) super().unregister()
[docs] def _ray_presubmit(self, task, key, deps) -> Optional[Any]: """Run before submitting a Ray task. If this callback returns a non-`None` value, Ray does _not_ create a task and uses this value as the would-be task's result value. Args: task: A Dask task, where the first tuple item is the task function, and the remaining tuple items are the task arguments, which are either the actual argument values, or Dask keys into the deps dictionary whose corresponding values are the argument values. key: The Dask graph key for the given task. deps: The dependencies of this task. Returns: Either None, in which case Ray submits a task, or a non-None value, in which case Ray task doesn't submit a task and uses this return value as the would-be task result value. """ pass
[docs] def _ray_postsubmit(self, task, key, deps, object_ref: ObjectRef): """Run after submitting a Ray task. Args: task: A Dask task, where the first tuple item is the task function, and the remaining tuple items are the task arguments, which are either the actual argument values, or Dask keys into the deps dictionary whose corresponding values are the argument values. key: The Dask graph key for the given task. deps: The dependencies of this task. object_ref: The object reference for the return value of the Ray task. """ pass
[docs] def _ray_pretask(self, key, object_refs: List[ObjectRef]): """Run before executing a Dask task within a Ray task. This method executes after Ray submits the task within a Ray worker. Ray passes the return value of this task to the _ray_posttask callback, if provided. Args: key: The Dask graph key for the Dask task. object_refs: The object references for the arguments of the Ray task. Returns: A value that Ray passes to the corresponding _ray_posttask callback, if the callback is defined. """ pass
[docs] def _ray_posttask(self, key, result, pre_state): """Run after executing a Dask task within a Ray task. This method executes within a Ray worker. This callback receives the return value of the _ray_pretask callback, if provided. Args: key: The Dask graph key for the Dask task. result: The task result value. pre_state: The return value of the corresponding _ray_pretask callback, if said callback is defined. """ pass
[docs] def _ray_postsubmit_all(self, object_refs: List[ObjectRef], dsk): """Run after Ray submits all tasks. Args: object_refs: The object references for the output (leaf) Ray tasks of the task graph. dsk: The Dask graph. """ pass
[docs] def _ray_finish(self, result): """Run after Ray finishes executing all Ray tasks and returns the final result. Args: result: The final result (output) of the Dask computation, before any repackaging is done by Dask collection-specific post-compute callbacks. """ pass
class add_ray_callbacks: def __init__(self, *callbacks): self.callbacks = [normalize_ray_callback(c) for c in callbacks] RayDaskCallback.ray_active.update(self.callbacks) def __enter__(self): return self def __exit__(self, *args): for c in self.callbacks: RayDaskCallback.ray_active.discard(c) def normalize_ray_callback(cb): if isinstance(cb, RayDaskCallback): return cb._ray_callback elif isinstance(cb, RayCallback): return cb else: raise TypeError( "Callbacks must be either 'RayDaskCallback' or 'RayCallback' namedtuple" ) def unpack_ray_callbacks(cbs): """Take an iterable of callbacks, return a list of each callback.""" if cbs: # Only drop callback methods that aren't in CBS_DONT_DROP. return RayCallbacks( *( [cb for cb in cbs_ if cb or CBS[idx] in CBS_DONT_DROP] or None for idx, cbs_ in enumerate(zip(*cbs)) ) ) else: return RayCallbacks(*([()] * len(CBS))) @contextlib.contextmanager def local_ray_callbacks(callbacks=None): """ Allows Dask-Ray callbacks to work with nested schedulers. Callbacks will only be used by the first started scheduler they encounter. This means that only the outermost scheduler will use global callbacks. """ global_callbacks = callbacks is None if global_callbacks: callbacks, RayDaskCallback.ray_active = (RayDaskCallback.ray_active, set()) try: yield callbacks or () finally: if global_callbacks: RayDaskCallback.ray_active = callbacks class ProgressBarCallback(RayDaskCallback): def __init__(self): import ray @ray.remote class ProgressBarActor: def __init__(self): self._init() def submit(self, key, deps, now): for dep in deps.keys(): self.deps[key].add(dep) self.submitted[key] = now self.submission_queue.append((key, now)) def task_scheduled(self, key, now): self.scheduled[key] = now def finish(self, key, now): self.finished[key] = now def result(self): return len(self.submitted), len(self.finished) def report(self): result = defaultdict(dict) for key, finished in self.finished.items(): submitted = self.submitted[key] scheduled = self.scheduled[key] # deps = self.deps[key] result[key]["execution_time"] = ( finished - scheduled ).total_seconds() # Calculate the scheduling time. # This is inaccurate. # We should subtract scheduled - (last dep completed). # But currently it is not easy because # of how getitem is implemented in dask on ray sort. result[key]["scheduling_time"] = ( scheduled - submitted ).total_seconds() result["submission_order"] = self.submission_queue return result def ready(self): pass def reset(self): self._init() def _init(self): self.submission_queue = [] self.submitted = defaultdict(None) self.scheduled = defaultdict(None) self.finished = defaultdict(None) self.deps = defaultdict(set) try: self.pb = ray.get_actor("_dask_on_ray_pb") ray.get(self.pb.reset.remote()) except ValueError: self.pb = ProgressBarActor.options(name="_dask_on_ray_pb").remote() ray.get(self.pb.ready.remote()) def _ray_postsubmit(self, task, key, deps, object_ref): # Indicate the dask task is submitted. self.pb.submit.remote(key, deps, datetime.now()) def _ray_pretask(self, key, object_refs): self.pb.task_scheduled.remote(key, datetime.now()) def _ray_posttask(self, key, result, pre_state): # Indicate the dask task is finished. self.pb.finish.remote(key, datetime.now()) def _ray_finish(self, result): print("All tasks are completed.")