Pattern: Conditional

This deployment graph pattern allows you to control your graph’s flow using conditionals. You can use this pattern to introduce a dynamic path for your requests to flow through.

Code

# File name: conditional.py

import ray
from ray import serve
from ray.serve.drivers import DAGDriver
from ray.serve.http_adapters import json_request
from ray.serve.deployment_graph import InputNode


@serve.deployment
class Model:
    def __init__(self, weight: int):
        self.weight = weight

    def forward(self, input: int) -> int:
        return input + self.weight


@serve.deployment
def combine(value1: int, value2: int, operation: str) -> int:
    if operation == "sum":
        return sum([value1, value2])
    else:
        return max([value1, value2])


model1 = Model.bind(0)
model2 = Model.bind(1)

with InputNode() as user_input:
    input_number, input_operation = user_input[0], user_input[1]
    output1 = model1.forward.bind(input_number)
    output2 = model2.forward.bind(input_number)
    combine_output = combine.bind(output1, output2, input_operation)

graph = DAGDriver.bind(combine_output, http_adapter=json_request)

handle = serve.run(graph)

max_output = ray.get(handle.predict.remote(1, "max"))
print(max_output)

sum_output = ray.get(handle.predict.remote(1, "sum"))
print(sum_output)

Note

combine takes in intermediate values from the call graph as the individual arguments, value1 and value2. You can also aggregate and pass these intermediate values as a list argument. However, this list contains references to the values, rather than the values themselves. You must explicitly use await to get the actual values before using them. Use await instead of ray.get to avoid blocking the deployment.

For example:

dag = combine.bind([output1, output2], user_input[1])
...
@serve.deployment
async def combine(value_refs, combine_type):
    values = await value_refs
    value1, value2 = values
...
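The reason to prefer await over a blocking call can be illustrated without Ray. The following is a minimal sketch in plain asyncio, not Ray Serve code: slow_forward is a hypothetical stand-in for a model call whose result arrives later, and asyncio.gather is swapped in here to resolve the list of pending results, whereas the Serve example above awaits value_refs directly. While combine waits, the event loop stays free to run other work.

```python
import asyncio


async def slow_forward(value: int, weight: int) -> int:
    # Hypothetical stand-in for a model call that resolves later.
    await asyncio.sleep(0.01)
    return value + weight


async def combine(value_refs, combine_type: str) -> int:
    # Resolve every pending result without blocking the event loop.
    # (Assumption: plain asyncio awaitables, not Ray object references.)
    value1, value2 = await asyncio.gather(*value_refs)
    if combine_type == "sum":
        return value1 + value2
    return max(value1, value2)


async def main():
    # Each list holds unresolved results, mirroring [output1, output2].
    refs_for_max = [slow_forward(1, 0), slow_forward(1, 1)]
    refs_for_sum = [slow_forward(1, 0), slow_forward(1, 1)]
    max_result = await combine(refs_for_max, "max")
    sum_result = await combine(refs_for_sum, "sum")
    return (max_result, sum_result)


results = asyncio.run(main())
print(results)  # (2, 3)
```

The list passed to combine holds unresolved results rather than integers, which is why the values must be awaited before they can be compared or added.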

Execution

The graph creates two Model nodes, with weights of 0 and 1. It then takes the user_input and unpacks it into two parts: a number and an operation.

Note

handle.predict.remote() can take an arbitrary number of arguments. You can unpack these arguments by indexing into the InputNode. For example:

with InputNode() as user_input:
    input_number, input_operation = user_input[0], user_input[1]

It passes the number into the two Model nodes, similar to the branching input pattern. Then it passes the requested operation, as well as the intermediate outputs, to the combine deployment to get a final result.

The example script makes two requests to the graph, both with a number input of 1. The resulting calculations are:

max:

input = 1
output1 = input + weight_1 = 1 + 0 = 1
output2 = input + weight_2 = 1 + 1 = 2
combine_output = max(output1, output2) = max(1, 2) = 2

sum:

input = 1
output1 = input + weight_1 = 1 + 0 = 1
output2 = input + weight_2 = 1 + 1 = 2
combine_output = output1 + output2 = 1 + 2 = 3

The final outputs are 2 and 3:

$ python conditional.py

2
3
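To check the arithmetic without starting a Serve instance, the same data flow can be traced in plain Python. This is only a sketch under stated assumptions: the classes below are ordinary Python objects rather than deployments, and run_graph is a hypothetical helper that plays the role of the graph by calling each step directly.

```python
class Model:
    """Plain-Python copy of the Model deployment (no Ray involved)."""

    def __init__(self, weight: int):
        self.weight = weight

    def forward(self, input: int) -> int:
        return input + self.weight


def combine(value1: int, value2: int, operation: str) -> int:
    # Same conditional as the combine deployment.
    if operation == "sum":
        return value1 + value2
    return max(value1, value2)


def run_graph(input_number: int, input_operation: str) -> int:
    # Hypothetical helper: mirrors the graph wiring step by step.
    model1, model2 = Model(0), Model(1)
    output1 = model1.forward(input_number)  # 1 + 0 = 1 for input 1
    output2 = model2.forward(input_number)  # 1 + 1 = 2 for input 1
    return combine(output1, output2, input_operation)


print(run_graph(1, "max"))  # 2
print(run_graph(1, "sum"))  # 3
```

The printed values match the output of conditional.py, which confirms that the operation argument alone decides which branch of combine runs.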