Pattern: Combine two nodes that receive the same input in parallel

This example shows how to pass the same input to two nodes in parallel and combine their outputs.

Code

import ray
from ray import serve
from ray.serve.deployment_graph import InputNode


ray.init()
serve.start()


# Each Model instance adds its own fixed weight to the input it receives.
@serve.deployment
class Model:
    def __init__(self, weight):
        self.weight = weight

    def forward(self, input):
        return input + self.weight


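# The upstream outputs arrive as object references inside a list. Ray does not
# resolve references nested in a list automatically, so fetch them with ray.get.
@serve.deployment
def combine(value_refs):
    return sum(ray.get(value_refs))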


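# Build the graph: both models receive the same user input, and their
# outputs are passed together to combine.
with InputNode() as user_input:
    model1 = Model.bind(0)
    model2 = Model.bind(1)
    output1 = model1.forward.bind(user_input)
    output2 = model2.forward.bind(user_input)
    dag = combine.bind([output1, output2])

# Execute the graph with input 1 and fetch the combined result.
print(ray.get(dag.execute(1)))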

Outputs

The graph passes the same input to both model nodes and sums the outputs of the two models:
Model output1: 1 (input) + 0 (weight) = 1
Model output2: 1 (input) + 1 (weight) = 2
Combine sum: 1 (output1) + 2 (output2) = 3
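The same fan-out/fan-in data flow can be sketched without a Ray cluster. This is an illustration under stated assumptions, not Ray's API: it uses Python's standard `concurrent.futures.ThreadPoolExecutor` as a stand-in for Ray Serve replicas, and `run_graph` is a hypothetical helper. With input 1 and weights 0 and 1, it reproduces the arithmetic above.

```python
from concurrent.futures import ThreadPoolExecutor


class Model:
    """Stand-in for the Model deployment: adds a fixed weight to its input."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, value):
        return value + self.weight


def combine(values):
    """Stand-in for the combine deployment: sums the upstream outputs."""
    return sum(values)


def run_graph(user_input, weights=(0, 1)):
    """Hypothetical helper: threads stand in for Ray Serve replicas here."""
    models = [Model(w) for w in weights]
    # Fan out: submit the same input to every model concurrently.
    with ThreadPoolExecutor(max_workers=len(models)) as pool:
        futures = [pool.submit(m.forward, user_input) for m in models]
        # Fan in: wait for every branch, then combine the results.
        outputs = [f.result() for f in futures]
    return combine(outputs)


print(run_graph(1))  # 3
```

Adding more models only extends the `weights` tuple; the combine step still waits for every branch before summing, just as `combine` in the graph depends on both `output1` and `output2`.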

3