Source code for ray.dag.input_node
from typing import Any, Dict, List, Union, Optional
from ray.dag import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.experimental.gradio_utils import type_to_string
from ray.util.annotations import DeveloperAPI
IN_CONTEXT_MANAGER = "__in_context_manager__"
[docs]
@DeveloperAPI
class InputNode(DAGNode):
    r"""Ray dag node used in DAG building API to mark entrypoints of a DAG.

    Should only be function or class method. A DAG can have multiple
    entrypoints, but only one instance of InputNode exists per DAG, shared
    among all DAGNodes.

    Example:

    .. code-block::

                m1.forward
                /       \
        dag_input     ensemble -> dag_output
                \       /
                m2.forward

    In this pipeline, each user input is broadcast to both m1.forward and
    m2.forward as first stop of the DAG, and authored like

    .. code-block:: python

        import ray

        @ray.remote
        class Model:
            def __init__(self, val):
                self.val = val
            def forward(self, input):
                return self.val * input

        @ray.remote
        def combine(a, b):
            return a + b

        with InputNode() as dag_input:
            m1 = Model.bind(1)
            m2 = Model.bind(2)
            m1_output = m1.forward.bind(dag_input[0])
            m2_output = m2.forward.bind(dag_input.x)
            ray_dag = combine.bind(m1_output, m2_output)

        # Pass mix of args and kwargs as input.
        ray_dag.execute(1, x=2) # 1 sent to m1, 2 sent to m2

        # Alternatively user can also pass single data object, list or dict
        # and access them via list index, object attribute or dict key str.
        ray_dag.execute(UserDataObject(m1=1, m2=2))
        # dag_input.m1, dag_input.m2
        ray_dag.execute([1, 2])
        # dag_input[0], dag_input[1]
        ray_dag.execute({"m1": 1, "m2": 2})
        # dag_input["m1"], dag_input["m2"]
    """

    def __init__(
        self,
        *args,
        input_type: Optional[Union[type, Dict[Union[int, str], type]]] = None,
        _other_args_to_resolve=None,
        **kwargs,
    ):
        """InputNode should only take attributes of validating and converting
        input data rather than the input data itself. User input should be
        provided via `ray_dag.execute(user_input)`.

        Args:
            input_type: Describes the data type of inputs user will be giving.
                - if given through singular InputNode: type of InputNode
                - if given through InputAttributeNodes: map of key -> type
                Used when deciding what Gradio block to represent the input
                nodes with.
            _other_args_to_resolve: Internal only to keep InputNode's execution
                context throughout pickling, replacement and serialization.
                User should not use or pass this field.

        Raises:
            ValueError: If any positional or unexpected keyword argument is
                passed.
        """
        if args or kwargs:
            raise ValueError("InputNode should not take any args or kwargs.")

        # Cache of attribute nodes handed out by __getattr__ / __getitem__,
        # keyed by the accessed key.
        self.input_attribute_nodes = {}
        self.input_type = input_type

        # A single type describes the whole input; record its string form in
        # the resolvable args, which is where get_result_type() reads it from.
        if isinstance(input_type, type):
            if _other_args_to_resolve is None:
                _other_args_to_resolve = {}
            _other_args_to_resolve["result_type_string"] = type_to_string(input_type)

        super().__init__([], {}, {}, other_args_to_resolve=_other_args_to_resolve)

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Return a new InputNode built from ``new_other_args_to_resolve``.

        The other parameters are accepted but unused here.
        NOTE(review): ``self.input_type`` is not forwarded, so the copy only
        keeps what is carried inside ``new_other_args_to_resolve``.
        """
        return InputNode(_other_args_to_resolve=new_other_args_to_resolve)

    def _execute_impl(self, *args, **kwargs):
        """Executor of InputNode.

        Returns:
            The single positional argument unchanged when exactly one
            positional and no keyword argument is given; otherwise a
            ``DAGInputData`` wrapping all args and kwargs.
        """
        # Catch and assert singleton context at dag execution time.
        assert self._in_context_manager(), (
            "InputNode is a singleton instance that should be only used in "
            "context manager for dag building and execution. See the docstring "
            "of class InputNode for examples."
        )
        # If user only passed in one value, for simplicity we just return it.
        if len(args) == 1 and len(kwargs) == 0:
            return args[0]
        return DAGInputData(*args, **kwargs)

    def _in_context_manager(self) -> bool:
        """Return if InputNode is created in context manager."""
        context = self._bound_other_args_to_resolve
        if not context or IN_CONTEXT_MANAGER not in context:
            return False
        return context[IN_CONTEXT_MANAGER]

    def set_context(self, key: str, val: Any):
        """Set field in parent DAGNode attribute that can be resolved in both
        pickle and JSON serialization
        """
        self._bound_other_args_to_resolve[key] = val

    def __str__(self) -> str:
        return get_dag_node_str(self, "__InputNode__")

    def __getattr__(self, key: str):
        """Return the (cached) InputAttributeNode for attribute access ``key``.

        Python calls this only when normal attribute lookup fails.

        Raises:
            AttributeError: If the node cache itself is not available yet.
        """
        assert isinstance(
            key, str
        ), "Please only access dag input attributes with str key."
        # Fetch the cache without going through __getattr__ again: if the
        # cache attribute is missing (e.g. instance created via __new__),
        # `self.input_attribute_nodes` would re-enter this method forever.
        try:
            nodes = object.__getattribute__(self, "input_attribute_nodes")
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            ) from None
        if key not in nodes:
            nodes[key] = InputAttributeNode(self, key, "__getattr__")
        return nodes[key]

    def __getitem__(self, key: Union[int, str]) -> Any:
        """Return the (cached) InputAttributeNode for item access ``key``.

        When ``input_type`` is a per-key mapping containing ``key``, the
        string form of that type is attached to a newly created node.
        """
        assert isinstance(key, (str, int)), (
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )
        # NOTE(review): this cache is shared with __getattr__ and keyed only
        # by `key`, so `dag_input.x` followed by `dag_input["x"]` returns the
        # node created with the "__getattr__" accessor — confirm intended.
        if key not in self.input_attribute_nodes:
            input_type = None
            # `input_type` may be a single class describing the whole input;
            # `key in <class>` raises TypeError, so only consult it when it
            # is a container of per-key types.
            if (
                self.input_type is not None
                and not isinstance(self.input_type, type)
                and key in self.input_type
            ):
                input_type = type_to_string(self.input_type[key])
            self.input_attribute_nodes[key] = InputAttributeNode(
                self, key, "__getitem__", input_type
            )
        return self.input_attribute_nodes[key]

    def __enter__(self):
        # Mark the node as created inside a `with` block; _execute_impl
        # asserts on this flag.
        self.set_context(IN_CONTEXT_MANAGER, True)
        return self

    def __exit__(self, *args):
        # Nothing to release; the context flag is intentionally left set.
        pass

    def get_result_type(self) -> Optional[str]:
        """Get type of the output of this DAGNode.

        Generated by ray.experimental.gradio_utils.type_to_string().

        Returns:
            The stored type string, or None when none was recorded.
        """
        if "result_type_string" in self._bound_other_args_to_resolve:
            return self._bound_other_args_to_resolve["result_type_string"]
        return None
@DeveloperAPI
class InputAttributeNode(DAGNode):
    """Represents partial access of user input based on an index (int),
    object attribute or dict key (str).

    Examples:

        .. code-block:: python

            with InputNode() as dag_input:
                a = dag_input[0]
                b = dag_input.x
                ray_dag = add.bind(a, b)

            # This makes a = 1 and b = 2
            ray_dag.execute(1, x=2)

            with InputNode() as dag_input:
                a = dag_input[0]
                b = dag_input[1]
                ray_dag = add.bind(a, b)

            # This makes a = 2 and b = 3
            ray_dag.execute(2, 3)

            # Alternatively, you can input a single object
            # and the inputs are automatically indexed from the object:
            # This makes a = 2 and b = 3
            ray_dag.execute([2, 3])
    """

    def __init__(
        self,
        dag_input_node: InputNode,
        key: Union[int, str],
        accessor_method: str,
        input_type: Optional[str] = None,
    ):
        """Create a node that reads ``key`` from the DAG input.

        Args:
            dag_input_node: The InputNode this attribute node reads from.
                At execution time _execute_impl also handles this slot
                holding a DAGInputData or a raw user object.
            key: Int index or str key/attribute name to access.
            accessor_method: "__getitem__" or "__getattr__"; selects how a
                str key is applied to a raw user input object.
            input_type: Optional type string stored as "result_type_string".
        """
        self._dag_input_node = dag_input_node
        self._key = key
        self._accessor_method = accessor_method
        super().__init__(
            [],
            {},
            {},
            {
                "dag_input_node": dag_input_node,
                "key": key,
                "accessor_method": accessor_method,
                # Type of the input tied to this node. Used by
                # gradio_visualize_graph.GraphVisualizer to determine which Gradio
                # component should be used for this node.
                "result_type_string": input_type,
            },
        )

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Rebuild the node from the entries of ``new_other_args_to_resolve``.

        The other parameters are accepted but unused here.
        """
        return InputAttributeNode(
            new_other_args_to_resolve["dag_input_node"],
            new_other_args_to_resolve["key"],
            new_other_args_to_resolve["accessor_method"],
            new_other_args_to_resolve["result_type_string"],
        )

    def _execute_impl(self, *args, **kwargs):
        """Executor of InputAttributeNode.

        Args and kwargs are to match base class signature, but not in the
        implementation. All args and kwargs should be resolved and replaced
        with value in bound_args and bound_kwargs via bottom-up recursion when
        current node is executed.

        Raises:
            ValueError: If the key is neither int nor str, or a str key is
                paired with an unknown accessor method.
        """
        if isinstance(self._dag_input_node, DAGInputData):
            return self._dag_input_node[self._key]

        # dag.execute() is called with only one arg, thus when an
        # InputAttributeNode is executed, its dependent InputNode is
        # resolved with original user input python object.
        user_input_python_object = self._dag_input_node

        if isinstance(self._key, int):
            return user_input_python_object[self._key]

        if not isinstance(self._key, str):
            raise ValueError(
                "Please only use int index or str as first-level key to "
                "access fields of dag input."
            )

        if self._accessor_method == "__getitem__":
            return user_input_python_object[self._key]
        if self._accessor_method == "__getattr__":
            return getattr(user_input_python_object, self._key)

        # Previously this case fell through and silently returned None.
        raise ValueError(
            f"Unsupported accessor method {self._accessor_method!r} for dag "
            "input key; expected '__getitem__' or '__getattr__'."
        )

    def __str__(self) -> str:
        return get_dag_node_str(self, f'["{self._key}"]')

    def get_result_type(self) -> Optional[str]:
        """Get type of the output of this DAGNode.

        Generated by ray.experimental.gradio_utils.type_to_string().

        Returns:
            The stored type string; None when none was provided.
        """
        if "result_type_string" in self._bound_other_args_to_resolve:
            return self._bound_other_args_to_resolve["result_type_string"]
        return None

    @property
    def key(self) -> Union[int, str]:
        """The int index or str key this node accesses."""
        return self._key
@DeveloperAPI
class DAGInputData:
    """Wrapper bundling every positional and keyword argument given directly
    to dag.execute() into one object.

    Positional values are looked up by int index, keyword values by their
    str name.
    """

    def __init__(self, *args, **kwargs):
        # Positional inputs are kept as a list, keyword inputs as a dict.
        self._args = [*args]
        self._kwargs = kwargs

    def __getitem__(self, key: Union[int, str]) -> Any:
        """Return the positional (int key) or keyword (str key) input.

        Raises:
            ValueError: If ``key`` is neither an int nor a str.
        """
        if isinstance(key, str):
            # Keyword input, addressed by name.
            return self._kwargs[key]
        if isinstance(key, int):
            # Positional input, addressed by index.
            return self._args[key]
        raise ValueError(
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )