Source code for ray.dag.context

from dataclasses import dataclass
import os
import threading
from typing import Optional
from ray.util.annotations import DeveloperAPI

# The context singleton on this process.
_default_context: "Optional[DAGContext]" = None
_context_lock = threading.Lock()

DEFAULT_SUBMIT_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_submit_timeout", 10))
DEFAULT_GET_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_get_timeout", 10))
DEFAULT_TEARDOWN_TIMEOUT_S = int(os.environ.get("RAY_CGRAPH_teardown_timeout", 30))
DEFAULT_READ_ITERATION_TIMEOUT_S = float(
    os.environ.get("RAY_CGRAPH_read_iteration_timeout_s", 0.1)
)
# Default buffer size is 1MB.
DEFAULT_BUFFER_SIZE_BYTES = int(os.environ.get("RAY_CGRAPH_buffer_size_bytes", 1e6))
# The default number of in-flight executions that can be submitted before consuming the
# output.
DEFAULT_MAX_INFLIGHT_EXECUTIONS = int(
    os.environ.get("RAY_CGRAPH_max_inflight_executions", 10)
)

# The default number of results that can be buffered at the driver.
DEFAULT_MAX_BUFFERED_RESULTS = int(
    os.environ.get("RAY_CGRAPH_max_buffered_results", 1000)
)

DEFAULT_OVERLAP_GPU_COMMUNICATION = bool(
    os.environ.get("RAY_CGRAPH_overlap_gpu_communication", 0)
)


[docs] @DeveloperAPI @dataclass class DAGContext: """Global settings for Ray DAG. You can configure parameters in the DAGContext by setting the environment variables, `RAY_CGRAPH_<param>` (e.g., `RAY_CGRAPH_buffer_size_bytes`) or Python. Examples: >>> from ray.dag import DAGContext >>> DAGContext.get_current().buffer_size_bytes 1000000 >>> DAGContext.get_current().buffer_size_bytes = 500 >>> DAGContext.get_current().buffer_size_bytes 500 Args: submit_timeout: The maximum time in seconds to wait for execute() calls. get_timeout: The maximum time in seconds to wait when retrieving a result from the DAG during `ray.get`. This should be set to a value higher than the expected time to execute the entire DAG. teardown_timeout: The maximum time in seconds to wait for the DAG to cleanly shut down. read_iteration_timeout: The timeout in seconds for each read iteration that reads one of the input channels. If the timeout is reached, the read operation will be interrupted and will try to read the next input channel. It must be less than or equal to `get_timeout`. buffer_size_bytes: The initial buffer size in bytes for messages that can be passed between tasks in the DAG. The buffers will be automatically resized if larger messages are written to the channel. max_inflight_executions: The maximum number of in-flight executions that can be submitted via `execute` or `execute_async` before consuming the output using `ray.get()`. If the caller submits more executions, `RayCgraphCapacityExceeded` is raised. overlap_gpu_communication: (experimental) Whether to overlap GPU communication with computation during DAG execution. If True, the communication and computation can be overlapped, which can improve the performance of the DAG execution. """ submit_timeout: int = DEFAULT_SUBMIT_TIMEOUT_S get_timeout: int = DEFAULT_GET_TIMEOUT_S teardown_timeout: int = DEFAULT_TEARDOWN_TIMEOUT_S read_iteration_timeout: float = DEFAULT_READ_ITERATION_TIMEOUT_S buffer_size_bytes: int = DEFAULT_BUFFER_SIZE_BYTES max_inflight_executions: int = DEFAULT_MAX_INFLIGHT_EXECUTIONS max_buffered_results: int = DEFAULT_MAX_BUFFERED_RESULTS overlap_gpu_communication: bool = DEFAULT_OVERLAP_GPU_COMMUNICATION def __post_init__(self): if self.read_iteration_timeout > self.get_timeout: raise ValueError( "RAY_CGRAPH_read_iteration_timeout_s " f"({self.read_iteration_timeout}) must be less than or equal to " f"RAY_CGRAPH_get_timeout ({self.get_timeout})" )
[docs] @staticmethod def get_current() -> "DAGContext": """Get or create a singleton context. If the context has not yet been created in this process, it will be initialized with default settings. """ global _default_context with _context_lock: if _default_context is None: _default_context = DAGContext() return _default_context