Pattern: Conditional
This deployment graph pattern lets you control your graph's flow with conditionals. Use it to give requests a dynamic path through the graph, where the computation that runs depends on each request's input.
Code
```python
# File name: conditional.py
import ray
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.serve.http_adapters import json_request
from ray.serve.deployment_graph import InputNode


@serve.deployment
class Model:
    def __init__(self, weight: int):
        self.weight = weight

    def forward(self, input: int) -> int:
        return input + self.weight


@serve.deployment
def combine(value1: int, value2: int, operation: str) -> int:
    if operation == "sum":
        return sum([value1, value2])
    else:
        return max([value1, value2])


model1 = Model.bind(0)
model2 = Model.bind(1)

with InputNode() as user_input:
    input_number, input_operation = user_input[0], user_input[1]
    output1 = model1.forward.bind(input_number)
    output2 = model2.forward.bind(input_number)
    combine_output = combine.bind(output1, output2, input_operation)

graph = DAGDriver.bind(combine_output, http_adapter=json_request)

handle = serve.run(graph)

max_output = ray.get(handle.predict.remote(1, "max"))
print(max_output)

sum_output = ray.get(handle.predict.remote(1, "sum"))
print(sum_output)
```
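As written, combine treats every operation other than "sum" as "max", so a misspelled operation such as "mx" quietly returns the maximum. If you want unknown operations to fail loudly, one option is to express the conditional as a lookup table. The sketch below is plain Python with no Ray involved, and the OPERATIONS table is a hypothetical name; in the graph, the same body would sit inside the @serve.deployment-decorated combine function.

```python
import operator
from typing import Callable, Dict

# Hypothetical dispatch table: operation name -> two-argument function.
OPERATIONS: Dict[str, Callable[[int, int], int]] = {
    "sum": operator.add,
    "max": max,
}


def combine(value1: int, value2: int, operation: str) -> int:
    try:
        op = OPERATIONS[operation]
    except KeyError:
        # Reject unknown operations instead of silently falling back to max.
        raise ValueError(f"Unsupported operation: {operation!r}") from None
    return op(value1, value2)


print(combine(1, 2, "max"))  # 2
print(combine(1, 2, "sum"))  # 3
```

Adding an operation then means adding one table entry rather than another elif branch.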
Note
combine takes in intermediate values from the call graph as the individual arguments value1 and value2. You can also aggregate these intermediate values and pass them as a single list argument. However, that list contains references to the values rather than the values themselves, so you must explicitly await them before using them. A plain Python list is not itself awaitable, so await the references together, for example with asyncio.gather. Use await instead of ray.get to avoid blocking the deployment.
For example:
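```python
import asyncio

dag = combine.bind([output1, output2], user_input[1])
...

@serve.deployment
async def combine(value_refs, combine_type):
    # A list can't be awaited directly; gather awaits each reference
    # and returns the values in the same order.
    values = await asyncio.gather(*value_refs)
    value1, value2 = values
    ...
```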
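The difference between awaiting a list and awaiting its elements can be seen without Ray. In this sketch, asyncio tasks stand in for the object references and make_value is a hypothetical upstream step; the assumption is only that each element of the list is individually awaitable, as Ray object references are inside an async deployment. asyncio.gather returns results in the order the awaitables were passed, even when they finish out of order.

```python
import asyncio


async def make_value(value: int, delay: float) -> int:
    # Hypothetical upstream node that produces its result asynchronously.
    await asyncio.sleep(delay)
    return value


async def combine(value_refs, combine_type: str) -> int:
    # Await every reference concurrently; results keep the input order.
    value1, value2 = await asyncio.gather(*value_refs)
    if combine_type == "sum":
        return value1 + value2
    return max(value1, value2)


async def main() -> None:
    # The first "reference" finishes last, yet gather still yields [1, 2].
    refs = [
        asyncio.ensure_future(make_value(1, 0.02)),
        asyncio.ensure_future(make_value(2, 0.01)),
    ]
    try:
        await refs  # what `await value_refs` on a list would do
    except TypeError as exc:
        print(f"TypeError: {exc}")
    print(await combine(refs, "sum"))  # 3


asyncio.run(main())
```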
Execution
The graph creates two Model nodes, with weights of 0 and 1. It then takes the user_input and unpacks it into two parts: a number and an operation.
Note
handle.predict.remote() can take an arbitrary number of arguments. You can unpack these arguments by indexing into the InputNode. For example:

```python
with InputNode() as user_input:
    input_number, input_operation = user_input[0], user_input[1]
```
It passes the number into the two Model nodes, similar to the branching input pattern. Then it passes the requested operation, along with the intermediate outputs, to the combine deployment to get the final result.
The example script makes two requests to the graph, both with a number input of 1. The resulting calculations are:

max:

```
input = 1
output1 = input + weight_1 = 1 + 0 = 1
output2 = input + weight_2 = 1 + 1 = 2
combine_output = max(output1, output2) = max(1, 2) = 2
```

sum:

```
input = 1
output1 = input + weight_1 = 1 + 0 = 1
output2 = input + weight_2 = 1 + 1 = 2
combine_output = output1 + output2 = 1 + 2 = 3
```
The final outputs are 2 and 3:

```
$ python conditional.py
2
3
```
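To check the arithmetic without starting Ray, the graph's dataflow can be mirrored in plain Python. This is a sketch only: Model and combine are undecorated copies of the deployments above, and run_graph is a hypothetical function whose positional arguments play the role of the InputNode, indexed by position just like user_input[0] and user_input[1].

```python
class Model:
    # Undecorated copy of the Model deployment.
    def __init__(self, weight: int):
        self.weight = weight

    def forward(self, input: int) -> int:
        return input + self.weight


def combine(value1: int, value2: int, operation: str) -> int:
    # Same branching as the combine deployment.
    if operation == "sum":
        return sum([value1, value2])
    else:
        return max([value1, value2])


def run_graph(*user_input):
    # Positional arguments are unpacked by index, as with the InputNode.
    input_number, input_operation = user_input[0], user_input[1]
    model1, model2 = Model(0), Model(1)
    output1 = model1.forward(input_number)
    output2 = model2.forward(input_number)
    return combine(output1, output2, input_operation)


print(run_graph(1, "max"))  # 2
print(run_graph(1, "sum"))  # 3
```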