Source code for ray.dag.output_node

import ray
from typing import Any, Dict, List, Union, Tuple

from ray.dag import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.util.annotations import DeveloperAPI


[docs] @DeveloperAPI class MultiOutputNode(DAGNode): """Ray dag node used in DAG building API to mark the endpoint of DAG""" def __init__( self, args: Union[List[DAGNode], Tuple[DAGNode]], other_args_to_resolve: Dict[str, Any] = None, ): if isinstance(args, tuple): args = list(args) if not isinstance(args, list): raise ValueError(f"Invalid input type for `args`, {type(args)}.") super().__init__( args, {}, {}, other_args_to_resolve=other_args_to_resolve or {}, ) def _execute_impl( self, *args, **kwargs ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]: return self._bound_args def _copy_impl( self, new_args: List[Any], new_kwargs: Dict[str, Any], new_options: Dict[str, Any], new_other_args_to_resolve: Dict[str, Any], ) -> "DAGNode": """Return a copy of this node with the given new args.""" return MultiOutputNode(new_args, new_other_args_to_resolve) def __str__(self) -> str: return get_dag_node_str(self, "__MultiOutputNode__")