Pattern: Combine two nodes that receive the same input in parallel
This example shows how to pass the same input to two nodes in parallel and combine their outputs.
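The same fan-out/fan-in shape can be sketched without Ray, using the standard library's `concurrent.futures` in place of Ray Serve deployments. This is only an illustration of the data flow, not how Ray schedules work. It assumes the same two weights (0 and 1) as the example below, and `run_graph` is a hypothetical helper name.

```python
from concurrent.futures import ThreadPoolExecutor


class Model:
    """Plain-Python stand-in for the Model deployment: adds a fixed weight."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, value):
        return value + self.weight


def combine(values):
    # Fan-in: reduce the outputs of the parallel branches to one result.
    return sum(values)


def run_graph(user_input):
    # Hypothetical helper: wires the two models and the combiner together.
    model1 = Model(0)
    model2 = Model(1)
    with ThreadPoolExecutor(max_workers=2) as pool:
        # Fan-out: submit the same input to both models concurrently.
        futures = [pool.submit(m.forward, user_input) for m in (model1, model2)]
        # Block until both branches finish (analogous to ray.get on the refs).
        outputs = [f.result() for f in futures]
    return combine(outputs)


print(run_graph(1))  # → 3
```

The thread pool only mimics the "both branches start from the same input and run independently" structure; in Ray Serve each `Model` would be its own deployment.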
Code
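import ray
from ray import serve
from ray.serve.deployment_graph import InputNode

ray.init()
serve.start()


@serve.deployment
class Model:
    def __init__(self, weight):
        self.weight = weight

    def forward(self, input):
        # Each model adds its own weight to the shared input.
        return input + self.weight


@serve.deployment
def combine(value_refs):
    # The branch outputs arrive as object refs; resolve them and sum.
    return sum(ray.get(value_refs))


with InputNode() as user_input:
    # Two independent model deployments with different weights.
    model1 = Model.bind(0)
    model2 = Model.bind(1)
    # Both forward calls take the same user_input, so they can run in parallel.
    output1 = model1.forward.bind(user_input)
    output2 = model2.forward.bind(user_input)
    # Fan-in: feed both outputs into combine.
    dag = combine.bind([output1, output2])

# Run the graph with input 1 and print the combined result.
print(ray.get(dag.execute(1)))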
Outputs
The graph passes the same input to both nodes and sums the outputs of the two models:
Model output1: 1 (input) + 0 (weight) = 1
Model output2: 1 (input) + 1 (weight) = 2
Combine sum: 1 (output1) + 2 (output2) = 3
3
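The arithmetic above generalizes: with input x and model weights w_1, ..., w_n, the combined result is n·x + (w_1 + ... + w_n). A small plain-Python check of that expected value, where `expected_combined` is a hypothetical helper and not part of Ray Serve:

```python
def expected_combined(user_input, weights):
    # Each branch computes user_input + weight; combine sums the branches.
    return sum(user_input + w for w in weights)


print(expected_combined(1, [0, 1]))  # → 3
print(expected_combined(10, [0, 1]))  # → 21
```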