Pattern: Control flow based on user inputs

This example shows how to use user inputs to control the flow of the graph.



import ray
from ray import serve
from ray.dag.input_node import InputNode


@serve.deployment
class Model:
    def __init__(self, weight):
        self.weight = weight

    def forward(self, input):
        return input + self.weight


@serve.deployment
def combine(value1, value2, combine_type):
    # The second user input selects how the two model outputs are combined.
    if combine_type == "sum":
        return sum([value1, value2])
    else:
        return max([value1, value2])


with InputNode() as user_input:
    model1 = Model.bind(0)
    model2 = Model.bind(1)
    # Both models receive the first user input.
    output1 = model1.forward.bind(user_input[0])
    output2 = model2.forward.bind(user_input[0])
    # The second user input is forwarded to combine as combine_type.
    dag = combine.bind(output1, output2, user_input[1])

print(ray.get(dag.execute(1, "max")))
print(ray.get(dag.execute(1, "sum")))


  1. dag.execute() takes an arbitrary number of arguments. Internally, the arguments are wrapped in a data object, so the graph can access each one by index or by key through the InputNode, as with user_input[1] here.

    code example:

    dag = combine.bind(output1, output2, user_input[1])
  2. value1 and value2 are the ObjectRefs of the two model outputs passed into combine. Ray resolves each ObjectRef to its value at runtime, before combine runs, so combine operates on plain numbers.

  3. Alternatively, value1 and value2 can be passed as a single list. In this case the ObjectRefs inside the list are passed by reference (passing objects by reference) and are not resolved automatically, so combine must call ray.get() explicitly to fetch their values before applying sum() or max().

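    code example:

    @serve.deployment
    def combine(value_refs, combine_type):
        values = ray.get(value_refs)  # explicitly resolve the list of ObjectRefs
        return sum(values) if combine_type == "sum" else max(values)

    dag = combine.bind([output1, output2], user_input[1])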
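Item 1 can be illustrated with a small pure-Python sketch. The class ExecuteArgs below is hypothetical and is not Ray's internal implementation; it only shows how one object holding the arguments of an execute() call can serve both index lookups (positional arguments) and key lookups (assumed here to map to keyword arguments).

```python
from typing import Any, Union


class ExecuteArgs:
    """Hypothetical container for the arguments of one execute() call.

    Integer keys index the positional arguments; string keys look up
    keyword arguments. Illustration only, not Ray's internal class.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self._args = args
        self._kwargs = kwargs

    def __getitem__(self, key: Union[int, str]) -> Any:
        if isinstance(key, int):
            return self._args[key]
        if isinstance(key, str):
            return self._kwargs[key]
        raise TypeError(f"key must be int or str, got {type(key).__name__}")


# Mirrors dag.execute(1, "max"): index 0 -> 1, index 1 -> "max".
inputs = ExecuteArgs(1, "max")
print(inputs[0], inputs[1])

# Keyword arguments are reached by name instead of by position.
named = ExecuteArgs(value=1, combine_type="sum")
print(named["value"], named["combine_type"])
```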

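The difference between items 2 and 3 can be modeled without Ray. In this toy sketch, Ref, get and call_like_ray are hypothetical stand-ins for an ObjectRef, ray.get() and a remote call, assuming the rule described above: ObjectRefs passed directly as arguments are resolved before the call, while ObjectRefs nested in a list reach the function as references.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass(frozen=True)
class Ref:
    """Toy stand-in for an ObjectRef: a handle wrapping a value (hypothetical)."""
    value: Any


def get(refs):
    """Toy stand-in for ray.get(): resolves one Ref or a list of Refs."""
    if isinstance(refs, list):
        return [r.value for r in refs]
    return refs.value


def call_like_ray(fn, *args):
    """Resolve only top-level Ref arguments, then call fn.

    Refs nested inside a list are passed through untouched.
    """
    resolved = [a.value if isinstance(a, Ref) else a for a in args]
    return fn(*resolved)


def combine(value1, value2, combine_type):
    # Item 2: both arguments arrive as plain values.
    return sum([value1, value2]) if combine_type == "sum" else max([value1, value2])


def combine_list(value_refs: List[Ref], combine_type):
    # Item 3: the list still holds Refs, so resolve them explicitly.
    values = get(value_refs)
    return sum(values) if combine_type == "sum" else max(values)


out1, out2 = Ref(1), Ref(2)
print(call_like_ray(combine, out1, out2, "max"))
print(call_like_ray(combine_list, [out1, out2], "sum"))
```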

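The first call combines the outputs of the two models with "max", and the second call combines them with "sum". With an input of 1:

Model output1: 1 (input) + 0 (weight) = 1
Model output2: 1 (input) + 1 (weight) = 2
So combining with max returns 2, and combining with sum returns 3; these are the two values printed.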
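To check these numbers without starting Ray, the same graph can be evaluated eagerly in plain Python. This sketch drops the Serve decorator and the bind() calls; run_graph is a hypothetical helper, not part of the Ray API.

```python
class Model:
    """Plain-Python copy of the Model deployment, without Ray."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, x):
        return x + self.weight


def combine(value1, value2, combine_type):
    if combine_type == "sum":
        return sum([value1, value2])
    return max([value1, value2])


def run_graph(x, combine_type):
    """Evaluate the graph eagerly: two models on the same input, then combine."""
    output1 = Model(0).forward(x)  # x + 0
    output2 = Model(1).forward(x)  # x + 1
    return combine(output1, output2, combine_type)


print(run_graph(1, "max"))
print(run_graph(1, "sum"))
```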