Pattern: Branching Input
This deployment graph pattern lets you pass the same input to multiple deployments in parallel. You can then aggregate these deployments’ intermediate outputs in another deployment.
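Independent of Ray Serve, the fan-out and fan-in shape of this pattern can be sketched in plain Python. The example below is only an illustration under assumptions: `fan_out_and_combine`, `add_zero`, and `add_one` are hypothetical names, and a thread pool stands in for the parallel deployments.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Sequence


def fan_out_and_combine(
    branches: Sequence[Callable[[int], int]],
    combine: Callable[[Sequence[int]], int],
    value: int,
) -> int:
    """Send `value` to every branch concurrently, then aggregate the results."""
    with ThreadPoolExecutor(max_workers=len(branches)) as pool:
        # Every branch receives the same input.
        futures = [pool.submit(branch, value) for branch in branches]
        # Collect results in submission order so `combine` sees a stable sequence.
        intermediate = [future.result() for future in futures]
    return combine(intermediate)


def add_zero(x: int) -> int:
    # Hypothetical branch: behaves like a model with weight 0.
    return x + 0


def add_one(x: int) -> int:
    # Hypothetical branch: behaves like a model with weight 1.
    return x + 1


total = fan_out_and_combine([add_zero, add_one], sum, 1)
print(total)  # prints 3
```

The branches never see each other's results; only the combining step does, which is the property the deployment graph below relies on.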
Code
# File name: branching_input.py

import ray
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.serve.http_adapters import json_request
from ray.serve.deployment_graph import InputNode


@serve.deployment
class Model:
    def __init__(self, weight: int):
        self.weight = weight

    def forward(self, input: int) -> int:
        return input + self.weight


@serve.deployment
def combine(value_refs):
    # Resolve the intermediate outputs, then add them together.
    return sum(ray.get(value_refs))


model1 = Model.bind(0)
model2 = Model.bind(1)

# Both models receive the same user input.
with InputNode() as user_input:
    output1 = model1.forward.bind(user_input)
    output2 = model2.forward.bind(user_input)
    combine_output = combine.bind([output1, output2])

graph = DAGDriver.bind(combine_output, http_adapter=json_request)

handle = serve.run(graph)

# Named `result` rather than `sum` to avoid shadowing the built-in.
result = ray.get(handle.predict.remote(1))
print(result)
Execution
This graph includes two Model nodes, with weights of 0 and 1. It passes the input into the two Model nodes, and each adds its own weight to it. Then, it uses the combine deployment to add the two Model deployments’ outputs together.
The resulting calculation is:
input = 1
output1 = input + weight_1 = 1 + 0 = 1
output2 = input + weight_2 = 1 + 1 = 2
combine_output = output1 + output2 = 1 + 2 = 3
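The arithmetic above can be checked without starting Ray. The following is a minimal sketch, assuming a plain-Python stand-in for the deployment; `LocalModel` and `trace` are hypothetical names introduced only for this check.

```python
from typing import List, Tuple


class LocalModel:
    """Plain-Python stand-in for the Model deployment (hypothetical name)."""

    def __init__(self, weight: int):
        self.weight = weight

    def forward(self, value: int) -> int:
        return value + self.weight


def trace(value: int, weights: List[int]) -> Tuple[List[int], int]:
    """Return each branch's output and the sum of those outputs."""
    outputs = [LocalModel(w).forward(value) for w in weights]
    return outputs, sum(outputs)


outputs, combined = trace(1, [0, 1])
print(outputs, combined)  # prints [1, 2] 3

# With n branches, the combined value equals n * input + sum(weights).
assert combined == 2 * 1 + (0 + 1)
```

Because every branch adds the same input once, changing the input by 1 changes the combined output by the number of branches.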
The final output is 3:
$ python branching_input.py
3