# Source code for ray.dag.context
from dataclasses import dataclass
import os
import threading
from typing import Optional
from ray.util.annotations import DeveloperAPI
# The context singleton on this process.
_default_context: "Optional[DAGContext]" = None
_context_lock = threading.Lock()

# Environment-variable values (compared case-insensitively, whitespace stripped)
# that turn a boolean flag off. Any other value that is set turns it on.
_FALSY_ENV_VALUES = frozenset({"", "0", "false", "no", "off"})


def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment variable ``name``.

    Calling ``bool()`` on the raw string would treat "0" and "false" as True,
    because every non-empty string is truthy. This helper parses the text
    instead.

    Args:
        name: Name of the environment variable to read.
        default: Value returned when the variable is not set.

    Returns:
        ``default`` if the variable is unset. Otherwise, False when the value is
        one of "", "0", "false", "no", "off" (case-insensitive, surrounding
        whitespace ignored), and True for any other value.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in _FALSY_ENV_VALUES


# Timeouts, in seconds.
DEFAULT_SUBMIT_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_submit_timeout", 10))
DEFAULT_GET_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_get_timeout", 10))
DEFAULT_TEARDOWN_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_teardown_timeout", 30))
DEFAULT_READ_ITERATION_TIMEOUT_S = float(
    os.environ.get("RAY_CGRAPH_read_iteration_timeout_s", 0.1)
)
# Default buffer size is 1MB (10**6 bytes).
DEFAULT_BUFFER_SIZE_BYTES = int(
    os.environ.get("RAY_CGRAPH_buffer_size_bytes", 1_000_000)
)
# The default number of in-flight executions that can be submitted before consuming
# the output.
DEFAULT_MAX_INFLIGHT_EXECUTIONS = int(
    os.environ.get("RAY_CGRAPH_max_inflight_executions", 10)
)
# The default number of results that can be buffered at the driver.
DEFAULT_MAX_BUFFERED_RESULTS = int(
    os.environ.get("RAY_CGRAPH_max_buffered_results", 1000)
)
# Whether GPU communication overlaps computation by default. Parsed with
# _env_flag so that e.g. RAY_CGRAPH_overlap_gpu_communication=0 disables it
# instead of (as with a plain bool() on the string) enabling it.
DEFAULT_OVERLAP_GPU_COMMUNICATION = _env_flag(
    "RAY_CGRAPH_overlap_gpu_communication", False
)
@DeveloperAPI
@dataclass
class DAGContext:
    """Global settings for Ray DAG.

    You can configure parameters in the DAGContext by setting the environment
    variables, `RAY_CGRAPH_<param>` (e.g., `RAY_CGRAPH_buffer_size_bytes`) or Python.
    Note that `read_iteration_timeout` is read from
    `RAY_CGRAPH_read_iteration_timeout_s`.

    Examples:
        >>> from ray.dag import DAGContext
        >>> DAGContext.get_current().buffer_size_bytes
        1000000
        >>> DAGContext.get_current().buffer_size_bytes = 500
        >>> DAGContext.get_current().buffer_size_bytes
        500

    Args:
        submit_timeout: The maximum time in seconds to wait for execute()
            calls.
        get_timeout: The maximum time in seconds to wait when retrieving
            a result from the DAG during `ray.get`. This should be set to a
            value higher than the expected time to execute the entire DAG.
        teardown_timeout: The maximum time in seconds to wait for the DAG to
            cleanly shut down.
        read_iteration_timeout: The timeout in seconds for each read iteration
            that reads one of the input channels. If the timeout is reached, the
            read operation will be interrupted and will try to read the next
            input channel. It must be less than or equal to `get_timeout`.
        buffer_size_bytes: The initial buffer size in bytes for messages
            that can be passed between tasks in the DAG. The buffers will
            be automatically resized if larger messages are written to the
            channel.
        max_inflight_executions: The maximum number of in-flight executions that
            can be submitted via `execute` or `execute_async` before consuming
            the output using `ray.get()`. If the caller submits more executions,
            `RayCgraphCapacityExceeded` is raised.
        max_buffered_results: The maximum number of results that can be
            buffered at the driver.
        overlap_gpu_communication: (experimental) Whether to overlap GPU
            communication with computation during DAG execution. If True, the
            communication and computation can be overlapped, which can improve
            the performance of the DAG execution.
    """

    submit_timeout: int = DEFAULT_SUBMIT_TIMEOUT_S
    get_timeout: int = DEFAULT_GET_TIMEOUT_S
    teardown_timeout: int = DEFAULT_TEARDOWN_TIMEOUT_S
    read_iteration_timeout: float = DEFAULT_READ_ITERATION_TIMEOUT_S
    buffer_size_bytes: int = DEFAULT_BUFFER_SIZE_BYTES
    max_inflight_executions: int = DEFAULT_MAX_INFLIGHT_EXECUTIONS
    max_buffered_results: int = DEFAULT_MAX_BUFFERED_RESULTS
    overlap_gpu_communication: bool = DEFAULT_OVERLAP_GPU_COMMUNICATION

    def __post_init__(self):
        """Validate the settings once the dataclass fields are assigned.

        The check runs only at construction time. Later attribute
        assignments (as in the class example) are not re-validated.

        Raises:
            ValueError: If `read_iteration_timeout` exceeds `get_timeout`.
        """
        if self.read_iteration_timeout > self.get_timeout:
            raise ValueError(
                "RAY_CGRAPH_read_iteration_timeout_s "
                f"({self.read_iteration_timeout}) must be less than or equal to "
                f"RAY_CGRAPH_get_timeout ({self.get_timeout})"
            )

    @staticmethod
    def get_current() -> "DAGContext":
        """Get or create a singleton context.

        If the context has not yet been created in this process, it will be
        initialized with default settings. Creation is guarded by the
        module-level `_context_lock`, so concurrent first calls share one
        instance.
        """
        global _default_context

        with _context_lock:
            if _default_context is None:
                _default_context = DAGContext()

            return _default_context