Source code for ray.dag.output_node
from typing import Any, Dict, List, Optional, Tuple, Union

import ray
from ray.dag import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.util.annotations import DeveloperAPI
[docs]
@DeveloperAPI
class MultiOutputNode(DAGNode):
    """Ray dag node used in DAG building API to mark the endpoint of DAG.

    Groups several nodes into a single terminal node so a DAG can expose more
    than one output. The grouped nodes are passed to ``DAGNode`` as this
    node's positional args; this node always binds empty kwargs and options.

    Args:
        args: The nodes that make up the DAG's outputs, as a list or a tuple
            of any length. A tuple is converted to a list.
        other_args_to_resolve: Extra values forwarded to ``DAGNode`` as
            ``other_args_to_resolve``. ``None`` (the default) is treated as
            an empty dict.

    Raises:
        ValueError: If ``args`` is neither a list nor a tuple.
    """

    def __init__(
        self,
        args: Union[List[DAGNode], Tuple[DAGNode, ...]],
        other_args_to_resolve: Optional[Dict[str, Any]] = None,
    ):
        # Normalize to a list so the base class always receives the same
        # container type, whether the caller (or _copy_impl) passed a tuple
        # or a list.
        if isinstance(args, tuple):
            args = list(args)
        if not isinstance(args, list):
            raise ValueError(f"Invalid input type for `args`, {type(args)}.")
        super().__init__(
            args,
            {},
            {},
            other_args_to_resolve=other_args_to_resolve or {},
        )

    def _execute_impl(
        self, *args, **kwargs
    ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
        """Return this node's bound args as the DAG's output.

        ``args`` and ``kwargs`` are accepted but not used.

        NOTE(review): the return annotation says ObjectRef/ActorHandle, but
        the value returned is ``self._bound_args``, presumably the outputs
        passed to ``__init__`` with child nodes already replaced by their
        results before this is called. The annotation is left as-is in case
        the base-class signature requires it; confirm against ``DAGNode``.
        """
        return self._bound_args

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ) -> "DAGNode":
        """Return a copy of this node with the given new args.

        Only ``new_args`` and ``new_other_args_to_resolve`` are used.
        ``__init__`` always binds empty dicts for the second and third
        ``DAGNode`` arguments (which this signature's order suggests are
        kwargs and options), so ``new_kwargs`` and ``new_options`` are
        ignored.
        """
        return MultiOutputNode(new_args, new_other_args_to_resolve)

    def __str__(self) -> str:
        """Format this node via ``get_dag_node_str`` with a fixed label."""
        return get_dag_node_str(self, "__MultiOutputNode__")