import functools
import logging
from typing import Dict, Set, List, Tuple, Union, Optional, Any
import os
import time
import uuid
import ray
from ray.dag import DAGNode
from ray.dag.input_node import DAGInputData
from ray.remote_function import RemoteFunction
# avoid collision with arguments & APIs
from ray.workflow.common import (
WorkflowStatus,
Event,
asyncio_run,
validate_user_metadata,
)
from ray.workflow import serialization, workflow_access, workflow_context
from ray.workflow.event_listener import EventListener, EventListenerType, TimerListener
from ray.workflow.workflow_storage import WorkflowStorage
from ray.workflow.workflow_state_from_dag import workflow_state_from_dag
from ray.util.annotations import PublicAPI
from ray._private.usage import usage_lib
logger = logging.getLogger(__name__)
@PublicAPI(stability="alpha")
def init(
*,
max_running_workflows: Optional[int] = None,
max_pending_workflows: Optional[int] = None,
) -> None:
"""Initialize workflow.
If Ray is not initialized, we will initialize Ray and
use ``/tmp/ray/workflow_data`` as the default storage.
Args:
        max_running_workflows: The maximum number of concurrently running workflows.
            Use -1 for infinity. ``None`` means preserving the previous setting, or
            initializing the setting to infinity if there is no previous one.
        max_pending_workflows: The maximum number of queued workflows.
            Use -1 for infinity. ``None`` means preserving the previous setting, or
            initializing the setting to infinity if there is no previous one.
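
    Examples:
        A minimal sketch of capping concurrency; the limits shown are
        arbitrary illustrative values:

        .. testcode::

            from ray import workflow

            # Allow at most 2 running and 10 queued workflows;
            # -1 would mean unbounded.
            workflow.init(max_running_workflows=2, max_pending_workflows=10)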
"""
usage_lib.record_library_usage("workflow")
if max_running_workflows is not None:
if not isinstance(max_running_workflows, int):
raise TypeError("'max_running_workflows' must be None or an integer.")
if max_running_workflows < -1 or max_running_workflows == 0:
            raise ValueError(
                "'max_running_workflows' must be a positive integer, "
                "or -1 for infinity."
            )
if max_pending_workflows is not None:
if not isinstance(max_pending_workflows, int):
raise TypeError("'max_pending_workflows' must be None or an integer.")
if max_pending_workflows < -1:
            raise ValueError(
                "'max_pending_workflows' must be a non-negative integer, "
                "or -1 for infinity."
            )
if not ray.is_initialized():
        # We should use get_temp_dir_path, but it is unavailable for the Ray
        # client. We would need a flag telling whether this is a client or a
        # driver in order to pick the right directory.
        # For now, just use /tmp/ray/workflow_data
ray.init(storage="file:///tmp/ray/workflow_data")
workflow_access.init_management_actor(max_running_workflows, max_pending_workflows)
serialization.init_manager()
def _ensure_workflow_initialized() -> None:
    # NOTE: Trying to get the actor has a side effect: it initializes Ray with
    # default arguments. This is different from "init()", which assigns a
    # temporary storage. This is why we need to check "ray.is_initialized()" first.
if not ray.is_initialized():
init()
else:
try:
workflow_access.get_management_actor()
except ValueError:
init()
def client_mode_wrap(func):
"""Wraps a function called during client mode for execution as a remote task.
    Adapted from "ray._private.client_mode_hook.client_mode_wrap" with some changes
    (e.g., it initializes the workflow instead of Ray; the latter does not specify
    a storage during Ray init and would result in workflow failures).
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
from ray._private.client_mode_hook import client_mode_should_convert
if os.environ.get("RAY_ENABLE_AUTO_CONNECT", "") != "0":
_ensure_workflow_initialized()
# `is_client_mode_enabled_by_default` is used for testing with
# `RAY_CLIENT_MODE=1`. This flag means all tests run with client mode.
if client_mode_should_convert():
f = ray.remote(num_cpus=0)(func)
ref = f.remote(*args, **kwargs)
return ray.get(ref)
return func(*args, **kwargs)
return wrapper
@PublicAPI(stability="alpha")
def run(
dag: DAGNode,
*args,
workflow_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Run a workflow.
If the workflow with the given id already exists, it will be resumed.
Examples:
.. testcode::
            from typing import Any, List

            import ray
            from ray import workflow
@ray.remote
def book_flight(origin: str, dest: str):
return f"Flight: {origin}->{dest}"
@ray.remote
def book_hotel(location: str):
return f"Hotel: {location}"
@ray.remote
def finalize_trip(bookings: List[Any]):
return ' | '.join(ray.get(bookings))
flight1 = book_flight.bind("OAK", "SAN")
flight2 = book_flight.bind("SAN", "OAK")
hotel = book_hotel.bind("SAN")
trip = finalize_trip.bind([flight1, flight2, hotel])
print(workflow.run(trip))
.. testoutput::
Flight: OAK->SAN | Flight: SAN->OAK | Hotel: SAN
Args:
workflow_id: A unique identifier that can be used to resume the
workflow. If not specified, a random id will be generated.
        metadata: The metadata to add to the workflow. It has to be
            JSON-serializable.
Returns:
The running result.
"""
return ray.get(
run_async(dag, *args, workflow_id=workflow_id, metadata=metadata, **kwargs)
)
@PublicAPI(stability="alpha")
def run_async(
dag: DAGNode,
*args,
workflow_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs,
) -> ray.ObjectRef:
"""Run a workflow asynchronously.
If the workflow with the given id already exists, it will be resumed.
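
    Examples:
        A minimal sketch mirroring the ``run()`` example above; the workflow
        id "trip0" is an arbitrary illustrative name:

        .. testcode::

            import ray
            from ray import workflow

            @ray.remote
            def book_flight(origin: str, dest: str):
                return f"Flight: {origin}->{dest}"

            # run_async returns an ObjectRef immediately instead of blocking.
            ref = workflow.run_async(book_flight.bind("OAK", "SAN"),
                                     workflow_id="trip0")
            assert ray.get(ref) == "Flight: OAK->SAN"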
Args:
workflow_id: A unique identifier that can be used to resume the
workflow. If not specified, a random id will be generated.
        metadata: The metadata to add to the workflow. It has to be
            JSON-serializable.
Returns:
The running result as ray.ObjectRef.
"""
_ensure_workflow_initialized()
if not isinstance(dag, DAGNode):
raise TypeError("Input should be a DAG.")
input_data = DAGInputData(*args, **kwargs)
validate_user_metadata(metadata)
metadata = metadata or {}
if workflow_id is None:
# Workflow ID format: {Entry workflow UUID}.{Unix time to nanoseconds}
workflow_id = f"{str(uuid.uuid4())}.{time.time():.9f}"
workflow_manager = workflow_access.get_management_actor()
if ray.get(workflow_manager.is_workflow_non_terminating.remote(workflow_id)):
raise RuntimeError(f"Workflow '{workflow_id}' is already running or pending.")
state = workflow_state_from_dag(dag, input_data, workflow_id)
logger.info(f'Workflow job created. [id="{workflow_id}"].')
context = workflow_context.WorkflowTaskContext(workflow_id=workflow_id)
with workflow_context.workflow_task_context(context):
# checkpoint the workflow
@client_mode_wrap
def _try_checkpoint_workflow(workflow_state) -> bool:
ws = WorkflowStorage(workflow_id)
ws.save_workflow_user_metadata(metadata)
try:
ws.get_entrypoint_task_id()
return True
except Exception:
                # The workflow does not exist. We must checkpoint the entry workflow.
ws.save_workflow_execution_state("", workflow_state)
return False
wf_exists = _try_checkpoint_workflow(state)
if wf_exists:
return resume_async(workflow_id)
ray.get(
workflow_manager.submit_workflow.remote(
workflow_id, state, ignore_existing=False
)
)
job_id = ray.get_runtime_context().get_job_id()
return workflow_manager.execute_workflow.remote(job_id, context)
@PublicAPI(stability="alpha")
def resume(workflow_id: str) -> Any:
"""Resume a workflow.
    Resume a workflow and retrieve its output. If the workflow was incomplete,
    it will be re-executed from its checkpointed outputs. If the workflow was
    complete, the result is returned immediately.
Examples:
.. testcode::
            import ray
            from ray import workflow
@ray.remote
def start_trip():
return 3
trip = start_trip.bind()
res1 = workflow.run_async(trip, workflow_id="trip1")
res2 = workflow.resume("trip1")
assert ray.get(res1) == res2
Args:
workflow_id: The id of the workflow to resume.
Returns:
The output of the workflow.
"""
return ray.get(resume_async(workflow_id))
@PublicAPI(stability="alpha")
def resume_async(workflow_id: str) -> ray.ObjectRef:
"""Resume a workflow asynchronously.
    Resume a workflow and retrieve its output. If the workflow was incomplete,
    it will be re-executed from its checkpointed outputs. If the workflow was
    complete, the result is returned immediately.
Examples:
.. testcode::
            import ray
            from ray import workflow
@ray.remote
def start_trip():
return 3
trip = start_trip.bind()
res1 = workflow.run_async(trip, workflow_id="trip1")
res2 = workflow.resume_async("trip1")
assert ray.get(res1) == ray.get(res2)
Args:
workflow_id: The id of the workflow to resume.
Returns:
An object reference that can be used to retrieve the workflow result.
"""
_ensure_workflow_initialized()
logger.info(f'Resuming workflow [id="{workflow_id}"].')
workflow_manager = workflow_access.get_management_actor()
if ray.get(workflow_manager.is_workflow_non_terminating.remote(workflow_id)):
raise RuntimeError(f"Workflow '{workflow_id}' is already running or pending.")
    # NOTE: It is important to 'ray.get' the returned output. This
    # ensures the caller of 'run()' holds the reference to the workflow
    # result. Otherwise, if the actor removes the reference to the
    # workflow output, the caller may fail to resolve the result.
job_id = ray.get_runtime_context().get_job_id()
context = workflow_context.WorkflowTaskContext(workflow_id=workflow_id)
ray.get(workflow_manager.reconstruct_workflow.remote(job_id, context))
result = workflow_manager.execute_workflow.remote(job_id, context)
logger.info(f"Workflow job {workflow_id} resumed.")
return result
@PublicAPI(stability="alpha")
def get_output(workflow_id: str, *, task_id: Optional[str] = None) -> Any:
"""Get the output of a running workflow.
Args:
workflow_id: The workflow to get the output of.
task_id: If set, fetch the specific task instead of the output of the
workflow.
Examples:
.. testcode::
            import ray
            from ray import workflow
@ray.remote
def start_trip():
return 1
trip = start_trip.options(**workflow.options(task_id="trip")).bind()
res1 = workflow.run_async(trip, workflow_id="trip1")
            # You could call "get_output()" from another machine.
res2 = workflow.get_output("trip1")
assert ray.get(res1) == res2
task_output = workflow.get_output_async("trip1", task_id="trip")
assert ray.get(task_output) == ray.get(res1)
Returns:
The output of the workflow task.
"""
return ray.get(get_output_async(workflow_id, task_id=task_id))
@PublicAPI(stability="alpha")
@client_mode_wrap
def get_output_async(
workflow_id: str, *, task_id: Optional[str] = None
) -> ray.ObjectRef:
"""Get the output of a running workflow asynchronously.
Args:
workflow_id: The workflow to get the output of.
task_id: If set, fetch the specific task output instead of the output
of the workflow.
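
    Examples:
        A minimal sketch reusing the pattern from ``get_output()``; the
        workflow and task ids are arbitrary illustrative names:

        .. testcode::

            import ray
            from ray import workflow

            @ray.remote
            def start_trip():
                return 1

            trip = start_trip.options(**workflow.options(task_id="trip")).bind()
            res1 = workflow.run_async(trip, workflow_id="trip2")
            # Fetch the whole-workflow output and a single task's output.
            res2 = workflow.get_output_async("trip2")
            task_output = workflow.get_output_async("trip2", task_id="trip")
            assert ray.get(res1) == ray.get(res2) == ray.get(task_output)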
Returns:
An object reference that can be used to retrieve the workflow task result.
"""
_ensure_workflow_initialized()
try:
workflow_manager = workflow_access.get_management_actor()
except ValueError as e:
raise ValueError(
"Failed to connect to the workflow management "
"actor. The workflow could have already failed. You can use "
"workflow.resume() or workflow.resume_async() to resume the "
"workflow."
) from e
return workflow_manager.get_output.remote(workflow_id, task_id)
@PublicAPI(stability="alpha")
@client_mode_wrap
def list_all(
    status_filter: Optional[
        Union[WorkflowStatus, str, Set[Union[WorkflowStatus, str]]]
    ] = None
) -> List[Tuple[str, WorkflowStatus]]:
"""List all workflows matching a given status filter. When returning "RESUMEABLE"
workflows, the workflows that was running ranks before the workflow that was pending
in the result list.
Args:
        status_filter: If given, only returns workflows with that status. This can
be a single status or set of statuses. The string form of the
status is also acceptable, i.e.,
"RUNNING"/"FAILED"/"SUCCESSFUL"/"CANCELED"/"RESUMABLE"/"PENDING".
Examples:
.. testcode::
            import ray
            from ray import workflow
@ray.remote
def long_running_job():
import time
time.sleep(2)
workflow_task = long_running_job.bind()
wf = workflow.run_async(workflow_task,
workflow_id="long_running_job")
jobs = workflow.list_all(workflow.RUNNING)
assert jobs == [ ("long_running_job", workflow.RUNNING) ]
ray.get(wf)
jobs = workflow.list_all({workflow.RUNNING})
assert jobs == []
Returns:
        A list of tuples of workflow id and workflow status.
"""
_ensure_workflow_initialized()
    if isinstance(status_filter, str):
        status_filter = {WorkflowStatus(status_filter)}
    elif isinstance(status_filter, WorkflowStatus):
        status_filter = {status_filter}
elif isinstance(status_filter, set):
if all(isinstance(s, str) for s in status_filter):
status_filter = {WorkflowStatus(s) for s in status_filter}
elif not all(isinstance(s, WorkflowStatus) for s in status_filter):
            raise TypeError(
                "status_filter contains an element that is neither"
                " a `WorkflowStatus` nor a `str`:"
                f" {status_filter}"
            )
elif status_filter is None:
status_filter = set(WorkflowStatus)
status_filter.discard(WorkflowStatus.NONE)
else:
        raise TypeError(
            "status_filter must be None, a WorkflowStatus or str, "
            "or a set of WorkflowStatus or str."
        )
try:
workflow_manager = workflow_access.get_management_actor()
except ValueError:
workflow_manager = None
if workflow_manager is None:
non_terminating_workflows = {}
else:
non_terminating_workflows = ray.get(
workflow_manager.list_non_terminating_workflows.remote()
)
ret = []
if set(non_terminating_workflows.keys()).issuperset(status_filter):
for status, workflows in non_terminating_workflows.items():
if status in status_filter:
for w in workflows:
ret.append((w, status))
return ret
ret = []
    # Here we don't have a workflow id, so use an empty one instead.
store = WorkflowStorage("")
modified_status_filter = status_filter.copy()
    # Here we have to add non-terminating statuses to the status filter, because some
    # "RESUMABLE" workflows are converted from non-terminating workflows below.
    # This is the tricky part: the status "RESUMABLE" comes neither from
    # the workflow management actor nor from the storage. It is the status where
    # the storage says the workflow is non-terminating but the workflow management
    # actor is not running it. This usually happens after a sudden crash of the
    # whole Ray runtime or of the workflow management actor
    # (due to cluster failures, etc.). So we include non-terminating statuses in
    # the storage filter to get "RESUMABLE" candidates.
modified_status_filter.update(WorkflowStatus.non_terminating_status())
status_from_storage = store.list_workflow(modified_status_filter)
non_terminating_workflows = {
k: set(v) for k, v in non_terminating_workflows.items()
}
resume_running = []
resume_pending = []
for (k, s) in status_from_storage:
if s in non_terminating_workflows and k not in non_terminating_workflows[s]:
if s == WorkflowStatus.RUNNING:
resume_running.append(k)
elif s == WorkflowStatus.PENDING:
resume_pending.append(k)
else:
assert False, "This line of code should not be reachable."
continue
if s in status_filter:
ret.append((k, s))
if WorkflowStatus.RESUMABLE in status_filter:
        # The running workflows rank before the pending workflows.
for w in resume_running:
ret.append((w, WorkflowStatus.RESUMABLE))
for w in resume_pending:
ret.append((w, WorkflowStatus.RESUMABLE))
return ret
@PublicAPI(stability="alpha")
@client_mode_wrap
def resume_all(include_failed: bool = False) -> List[Tuple[str, ray.ObjectRef]]:
"""Resume all resumable workflow jobs.
    This can be used after cluster restart to resume all workflows.
Args:
include_failed: Whether to resume FAILED workflows.
Examples:
.. testcode::
            import ray
            from ray import workflow
@ray.remote
def failed_job():
raise ValueError()
workflow_task = failed_job.bind()
output = workflow.run_async(
workflow_task, workflow_id="failed_job")
try:
ray.get(output)
except Exception:
print("JobFailed")
assert workflow.get_status("failed_job") == workflow.FAILED
print(workflow.resume_all(include_failed=True))
.. testoutput::
JobFailed
[('failed_job', ObjectRef(...))]
Returns:
A list of (workflow_id, returned_obj_ref) resumed.
"""
_ensure_workflow_initialized()
filter_set = {WorkflowStatus.RESUMABLE}
if include_failed:
filter_set.add(WorkflowStatus.FAILED)
all_failed = list_all(filter_set)
try:
workflow_manager = workflow_access.get_management_actor()
except Exception as e:
raise RuntimeError("Failed to get management actor") from e
job_id = ray.get_runtime_context().get_job_id()
reconstructed_workflows = []
for wid, _ in all_failed:
context = workflow_context.WorkflowTaskContext(workflow_id=wid)
        # TODO(suquark): This is not very efficient, but it makes sure
        # running workflows have higher priority when getting reconstructed.
try:
ray.get(workflow_manager.reconstruct_workflow.remote(job_id, context))
except Exception as e:
            # TODO(suquark): Here some workflows got resumed successfully but some
            # failed, and the user has no idea about this, which is very weird.
            # Maybe we should raise an exception here instead?
logger.error(f"Failed to resume workflow {context.workflow_id}", exc_info=e)
raise
reconstructed_workflows.append(context)
results = []
for context in reconstructed_workflows:
results.append(
(
context.workflow_id,
workflow_manager.execute_workflow.remote(job_id, context),
)
)
return results
@PublicAPI(stability="alpha")
def get_status(workflow_id: str) -> WorkflowStatus:
"""Get the status for a given workflow.
Args:
workflow_id: The workflow to query.
Examples:
.. testcode::
            import ray
            from ray import workflow
@ray.remote
def trip():
pass
workflow_task = trip.bind()
output = workflow.run(workflow_task, workflow_id="local_trip")
assert workflow.SUCCESSFUL == workflow.get_status("local_trip")
Returns:
        The status of that workflow.
"""
_ensure_workflow_initialized()
if not isinstance(workflow_id, str):
raise TypeError("workflow_id has to be a string type.")
workflow_manager = workflow_access.get_management_actor()
return ray.get(workflow_manager.get_workflow_status.remote(workflow_id))
@PublicAPI(stability="alpha")
def wait_for_event(
event_listener_type: EventListenerType, *args, **kwargs
) -> "DAGNode[Event]":
if not issubclass(event_listener_type, EventListener):
raise TypeError(
f"Event listener type is {event_listener_type.__name__}"
", which is not a subclass of workflow.EventListener"
)
@ray.remote
def get_message(event_listener_type: EventListenerType, *args, **kwargs) -> Event:
event_listener = event_listener_type()
return asyncio_run(event_listener.poll_for_event(*args, **kwargs))
@ray.remote
def message_committed(
event_listener_type: EventListenerType, event: Event
) -> Event:
event_listener = event_listener_type()
asyncio_run(event_listener.event_checkpointed(event))
return event
return message_committed.bind(
event_listener_type, get_message.bind(event_listener_type, *args, **kwargs)
)
@PublicAPI(stability="alpha")
def sleep(duration: float) -> "DAGNode[Event]":
"""
    A workflow that resolves after sleeping for a given duration (in seconds).
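
    Examples:
        A minimal sketch; the workflow id is an arbitrary illustrative name:

        .. testcode::

            from ray import workflow

            # Blocks for roughly one second before the workflow finishes.
            workflow.run(workflow.sleep(1), workflow_id="sleep_demo")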
"""
@ray.remote
def end_time():
return time.time() + duration
return wait_for_event(TimerListener, end_time.bind())
@PublicAPI(stability="alpha")
def cancel(workflow_id: str) -> None:
"""Cancel a workflow. Workflow checkpoints will still be saved in storage. To
clean up saved checkpoints, see `workflow.delete()`.
Args:
workflow_id: The workflow to cancel.
Examples:
.. testcode::
            import ray
            from ray import workflow
@ray.remote
def some_job():
return 1
workflow_task = some_job.bind()
workflow.run(workflow_task, workflow_id="some_job")
workflow.cancel(workflow_id="some_job")
assert workflow.get_status("some_job") == workflow.CANCELED
Returns:
None
"""
_ensure_workflow_initialized()
if not isinstance(workflow_id, str):
raise TypeError("workflow_id has to be a string type.")
workflow_manager = workflow_access.get_management_actor()
ray.get(workflow_manager.cancel_workflow.remote(workflow_id))
@PublicAPI(stability="alpha")
def delete(workflow_id: str) -> None:
"""Delete a workflow, its checkpoints, and other information it may have
persisted to storage. To stop a running workflow, see
`workflow.cancel()`.
Args:
workflow_id: The workflow to delete.
Raises:
WorkflowStillActiveError: The workflow is still active.
WorkflowNotFoundError: The workflow does not exist.
Examples:
.. testcode::
            import ray
            from ray import workflow
@ray.remote
def some_job():
pass
workflow_task = some_job.bind()
workflow.run(workflow_task, workflow_id="some_job")
workflow.delete(workflow_id="some_job")
"""
_ensure_workflow_initialized()
workflow_manager = workflow_access.get_management_actor()
ray.get(workflow_manager.delete_workflow.remote(workflow_id))
@PublicAPI(stability="alpha")
def continuation(dag_node: "DAGNode") -> Union["DAGNode", Any]:
"""Converts a DAG into a continuation.
    The result depends on the context: if called inside a workflow, it
    returns the DAG as a continuation to be executed by the workflow engine;
    otherwise it executes the DAG and returns its result.
Args:
dag_node: The DAG to be converted.
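
    Examples:
        A minimal recursive sketch; ``factorial`` and ``multiply`` are
        illustrative task names:

        .. testcode::

            import ray
            from ray import workflow

            @ray.remote
            def multiply(a: int, b: int) -> int:
                return a * b

            @ray.remote
            def factorial(n: int) -> int:
                if n == 1:
                    return 1
                # Inside a workflow this returns a continuation DAG that the
                # engine keeps executing; outside, it runs the DAG eagerly.
                return workflow.continuation(multiply.bind(n, factorial.bind(n - 1)))

            assert workflow.run(factorial.bind(5), workflow_id="factorial") == 120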
"""
from ray.workflow.workflow_context import in_workflow_execution
if not isinstance(dag_node, DAGNode):
raise TypeError("Input should be a DAG.")
if in_workflow_execution():
return dag_node
return ray.get(dag_node.execute())
@PublicAPI(stability="alpha")
class options:
"""This class serves both as a decorator and options for workflow.
Examples:
.. testcode::
import ray
from ray import workflow
# specify workflow options with a decorator
@workflow.options(catch_exceptions=True)
@ray.remote
def foo():
return 1
# specify workflow options in ".options"
foo_new = foo.options(**workflow.options(catch_exceptions=False))
"""
def __init__(self, **workflow_options: Dict[str, Any]):
        # TODO(suquark): More rigorous argument validation, like that of @ray.remote.
        # This is fairly complex, but we should enable it later.
valid_options = {
"task_id",
"metadata",
"catch_exceptions",
"checkpoint",
}
invalid_keywords = set(workflow_options.keys()) - valid_options
if invalid_keywords:
raise ValueError(
f"Invalid option keywords {invalid_keywords} for workflow tasks. "
f"Valid ones are {valid_options}."
)
from ray.workflow.common import WORKFLOW_OPTIONS
validate_user_metadata(workflow_options.get("metadata"))
self.options = {"_metadata": {WORKFLOW_OPTIONS: workflow_options}}
def keys(self):
return ("_metadata",)
def __getitem__(self, key):
return self.options[key]
def __call__(self, f: RemoteFunction) -> RemoteFunction:
if not isinstance(f, RemoteFunction):
raise ValueError("Only apply 'workflow.options' to Ray remote functions.")
f._default_options.update(self.options)
return f
__all__ = (
"init",
"run",
"run_async",
"resume",
"resume_async",
"resume_all",
"cancel",
"list_all",
"delete",
"get_output",
"get_output_async",
"get_status",
"get_metadata",
"sleep",
"wait_for_event",
"options",
"continuation",
)