Pattern: Branching Input
This deployment graph pattern lets you pass the same input to multiple deployments in parallel. You can then aggregate these deployments’ intermediate outputs in another deployment.
# File name: branching_input.py
import ray
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.serve.http_adapters import json_request
from ray.serve.deployment_graph import InputNode


@serve.deployment
class Model:
    def __init__(self, weight: int):
        self.weight = weight

    def forward(self, input: int) -> int:
        # Each branch adds its own weight to the shared input.
        return input + self.weight


@serve.deployment
def combine(value_refs):
    # Resolve the branch outputs, then aggregate them.
    return sum(ray.get(value_refs))


model1 = Model.bind(0)
model2 = Model.bind(1)

with InputNode() as user_input:
    # Both branches receive the same input.
    output1 = model1.forward.bind(user_input)
    output2 = model2.forward.bind(user_input)
    combine_output = combine.bind([output1, output2])

graph = DAGDriver.bind(combine_output, http_adapter=json_request)

handle = serve.run(graph)

# Named `result` rather than `sum` to avoid shadowing the built-in used in combine.
result = ray.get(handle.predict.remote(1))
print(result)
This graph includes two Model nodes, with weights of 0 and 1. It passes the input into the two Models, and they add their own weights to it. Then, it uses the combine deployment to add the two Model deployments’ outputs together.
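The fan-out and fan-in shape of this graph can also be sketched without Ray, which may help when reasoning about the data flow in isolation. The following is a minimal standard-library analogy, not Ray Serve code: the plain Model class, the combine function, and the run_branching helper are hypothetical stand-ins introduced for illustration, and a thread pool takes the place of the parallel deployments.

```python
from concurrent.futures import ThreadPoolExecutor


class Model:
    """Plain-Python stand-in for the Model deployment (illustrative only)."""

    def __init__(self, weight: int):
        self.weight = weight

    def forward(self, value: int) -> int:
        return value + self.weight


def combine(values):
    """Stand-in for the combine deployment: sums the branch outputs."""
    return sum(values)


def run_branching(user_input: int, models) -> int:
    """Hypothetical helper: send one input to every model, then aggregate."""
    # Fan out: submit the same input to every model concurrently.
    with ThreadPoolExecutor(max_workers=len(models)) as pool:
        futures = [pool.submit(m.forward, user_input) for m in models]
        # Fan in: wait for every branch to finish, then aggregate.
        return combine(f.result() for f in futures)


if __name__ == "__main__":
    print(run_branching(1, [Model(0), Model(1)]))
```

In the Ray Serve version, the bind calls describe this same wiring declaratively, and Serve takes care of scheduling the branches on separate deployments.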
The resulting calculation is:
input = 1
output1 = input + weight_1 = 1 + 0 = 1
output2 = input + weight_2 = 1 + 1 = 2
combine_output = output1 + output2 = 1 + 2 = 3
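The same arithmetic extends to any number of branches. As a sketch, assume n Model deployments with weights w_1, ..., w_n, a shared input x, and a combine step that sums the branch outputs (the symbols n, x, and w_i are introduced here for illustration only):

```latex
\[
\text{combine\_output} \;=\; \sum_{i=1}^{n} (x + w_i) \;=\; n\,x + \sum_{i=1}^{n} w_i
\]
```

With n = 2, x = 1, and weights 0 and 1, this gives 2 · 1 + (0 + 1) = 3.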
The final output is 3:
$ python branching_input.py
3