Pattern: Branching Input

This deployment graph pattern lets you pass the same input to multiple deployments in parallel. You can then aggregate these deployments’ intermediate outputs in another deployment.

Code

# File name: branching_input.py

import ray
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.serve.http_adapters import json_request
from ray.serve.deployment_graph import InputNode


@serve.deployment
class Model:
    def __init__(self, weight: int):
        self.weight = weight

    def forward(self, input: int) -> int:
        return input + self.weight


@serve.deployment
def combine(value_refs):
    # value_refs is a list with one reference per branch; resolve them, then add.
    return sum(ray.get(value_refs))


model1 = Model.bind(0)
model2 = Model.bind(1)

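with InputNode() as user_input:
    # Both branches receive the same user_input.
    output1 = model1.forward.bind(user_input)
    output2 = model2.forward.bind(user_input)
    # Aggregate the two intermediate outputs in a third deployment.
    combine_output = combine.bind([output1, output2])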

graph = DAGDriver.bind(combine_output, http_adapter=json_request)

handle = serve.run(graph)

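# Send the input 1 through the graph and wait for the combined result.
# The variable is named `result` to avoid shadowing the built-in `sum`.
result = ray.get(handle.predict.remote(1))
print(result)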

Execution

This graph includes two Model nodes, with weights of 0 and 1. It passes the input into the two Models, and they add their own weights to it. Then, it uses the combine deployment to add the two Model deployments’ outputs together.
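The same fan-out and fan-in shape can be sketched in plain Python, without Ray, to make the data flow easier to see. This is only an illustration under stated assumptions: `PlainModel` and `run_branches` are hypothetical names that do not appear in Ray Serve, and a thread pool stands in for the parallel deployments.

```python
from concurrent.futures import ThreadPoolExecutor


class PlainModel:
    """Hypothetical stand-in for the Model deployment (no Ray involved)."""

    def __init__(self, weight: int):
        self.weight = weight

    def forward(self, value: int) -> int:
        return value + self.weight


def run_branches(user_input: int, models: list) -> int:
    # Fan out: submit the same input to every model.
    with ThreadPoolExecutor(max_workers=len(models)) as pool:
        futures = [pool.submit(m.forward, user_input) for m in models]
        # Fan in: wait for every branch, then add the results together.
        return sum(f.result() for f in futures)


result = run_branches(1, [PlainModel(0), PlainModel(1)])
print(result)
```

In the deployment graph, `InputNode` plays the role of `user_input`, each `forward.bind` call corresponds to one submitted branch, and `combine` performs the final sum.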

The resulting calculation is:

input = 1
output1 = input + weight_1 = 1 + 0 = 1
output2 = input + weight_2 = 1 + 1 = 2
combine_output = output1 + output2 = 1 + 2 = 3

The final output is 3:

$ python branching_input.py

3
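
To sanity-check the graph against inputs other than 1, the calculation above can be written as a small helper. This is a sketch, not part of the Ray Serve API: `expected_output` is a hypothetical name, and its default weights are the 0 and 1 bound to the two Model deployments.

```python
def expected_output(user_input: int, weights=(0, 1)) -> int:
    """Value the graph should return: the sum of (input + weight) per branch."""
    return sum(user_input + w for w in weights)


# Print the expected result for a few sample inputs.
for x in (0, 1, 5):
    print(x, expected_output(x))
```

With the default weights this reduces to `2 * input + 1`, so an input of 1 gives 3, matching the run above.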