Building a Gradio demo with Ray Serve

In this example, we will show you how to wrap a machine learning model served by Ray Serve in a Gradio demo.

Specifically, we’re going to download a GPT-2 model from the transformers library, define a Ray Serve deployment with it, and then define and launch a Gradio Interface. Let’s take a look.

Deploying a model with Ray Serve

To start off, we import Ray Serve, Gradio, the transformers and requests libraries, and then simply start Ray Serve:

import gradio as gr
from ray import serve
from transformers import pipeline
import requests


serve.start()

Next, we define a Ray Serve deployment for GPT-2 by applying the @serve.deployment decorator to a function named model that takes an HTTP request as its argument. The function builds a GPT-2 text-generation model with a call to pipeline, reads the prompt from the request's query parameters, and returns the model's output.

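@serve.deployment
def model(request):
    # Build a GPT-2 text-generation pipeline. This runs on every request;
    # the weights are downloaded once and then read from the local cache.
    language_model = pipeline("text-generation", model="gpt2")
    # Read the prompt from the "query" URL parameter, e.g. /model?query=Hello
    query = request.query_params["query"]
    # Generate up to 100 tokens (prompt included) and return the pipeline output.
    return language_model(query, max_length=100)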

Because the function is named model, Ray Serve exposes this deployment at localhost:8000/model by default. We deploy it with a model.deploy() call and test it by sending a simple example query to that endpoint. The first request downloads the GPT-2 weights, which can take a while; later requests reuse the cached download and are faster.

model.deploy()
example = "What's the meaning of life?"
response = requests.get("http://localhost:8000/model", params={"query": example})
print(response.text)
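The model function above calls pipeline(...) inside the function body, so the GPT-2 pipeline is constructed again for every request, even though the weights themselves are cached after the first download. A common way to avoid repeated setup is to build the expensive object once per process and reuse it. The sketch below shows that idea with functools.lru_cache; load_expensive_model and handle_request are hypothetical stand-ins (not part of Ray Serve or transformers) so that it runs without either library.

```python
from functools import lru_cache

# Counts how often the hypothetical expensive loader actually runs.
load_calls = 0


@lru_cache(maxsize=1)
def load_expensive_model():
    """Hypothetical stand-in for pipeline("text-generation", model="gpt2")."""
    global load_calls
    load_calls += 1
    # Fake "model": echoes the prompt in the same shape a text-generation pipeline returns.
    return lambda prompt: [{"generated_text": prompt + " ..."}]


def handle_request(query):
    # After the first call, lru_cache returns the already-built model.
    language_model = load_expensive_model()
    return language_model(query)


for prompt in ["Hello", "What's the meaning of life?", "Goodbye"]:
    handle_request(prompt)

print(load_calls)  # → 1
```

In Ray Serve itself, the usual way to get the same effect is a class-based deployment that builds the model in __init__, as BoostingModel does in the scikit-learn example.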

Defining and launching a Gradio interface

Defining a Gradio interface is now straightforward. All we need is a function that Gradio can call to get a response from the model. That’s just a thin wrapper around our previous requests call that extracts the generated text from the JSON response:

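def gpt2(query):
    # Let requests URL-encode the prompt, so characters like "&" or "#" survive.
    response = requests.get("http://localhost:8000/model", params={"query": query})
    # The pipeline returns a list of dicts; take the text of the first one.
    return response.json()[0]["generated_text"]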
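Because gpt2 forwards whatever a user types into Gradio, the prompt has to be percent-encoded before it goes into the URL; otherwise characters such as & or # cut the query short. Passing params={"query": query} to requests.get handles this. The sketch below uses only the standard library's urllib.parse to show the difference; naive_url, encoded_url, and decoded_query are hypothetical helpers, and decoded_query mimics how a server parses the query string.

```python
from urllib.parse import parse_qs, urlencode, urlsplit

BASE_URL = "http://localhost:8000/model"


def naive_url(query):
    # Interpolates the raw prompt into the URL, as an f-string would.
    return f"{BASE_URL}?query={query}"


def encoded_url(query):
    # Percent-encodes the prompt, similar to requests.get(BASE_URL, params={"query": query}).
    return f"{BASE_URL}?{urlencode({'query': query})}"


def decoded_query(url):
    # The "query" value a server recovers after parsing the URL.
    return parse_qs(urlsplit(url).query).get("query", [""])[0]


prompt = "rock&roll#1"
print(decoded_query(naive_url(prompt)))    # → rock
print(decoded_query(encoded_url(prompt)))  # → rock&roll#1
print(encoded_url("What's the meaning of life?"))
# → http://localhost:8000/model?query=What%27s+the+meaning+of+life%3F
```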

Apart from our gpt2 function, the only other thing we need to define a Gradio interface is a description of the model inputs and outputs that Gradio understands. Since our model takes text as input and produces text as output, this is simple:

iface = gr.Interface(
    fn=gpt2,
    inputs=[gr.inputs.Textbox(
        default=example, label="Input prompt"
    )],
    outputs=[gr.outputs.Textbox(label="Model output")]
)

For more complex models served with Ray, you might need multiple gr.inputs and gr.outputs of different types.

Finally, we can launch the interface using iface.launch():

iface.launch()

This launches an interactive interface that looks like this:

https://raw.githubusercontent.com/ray-project/images/master/docs/serve/gradio_serve_gpt.png

You can run this example directly in the browser, for instance by launching this notebook in Google Colab or Binder via the rocket icon at the top right of this page. If you run this code locally in Python, Gradio prints the local URL it serves the app on when it starts, for example http://127.0.0.1:7860/ (the default port) or the next free port if that one is already taken.

Building a Gradio app from a scikit-learn model

Let’s look at a second example, so that you can compare it directly with the first one and see where the two differ.

This time we’re going to use a scikit-learn model that we quickly train ourselves on the well-known Iris dataset. We load the dataset with the load_iris function from sklearn.datasets (the data ships with scikit-learn, so nothing needs to be downloaded) and train a GradientBoostingClassifier from the sklearn.ensemble module.

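This time we apply the @serve.deployment decorator to a class called BoostingModel instead of a function. Ray Serve passes the arguments of BoostingModel.deploy(...) to the class constructor, here the trained model, and calls the asynchronous __call__ method for every incoming HTTP request; __call__ reads the feature vector from the request's JSON body. We also set route_prefix="/iris", so the deployment is served at localhost:8000/iris. Everything else works as in the first example.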

import gradio as gr
import requests
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier

from ray import serve

# Train your model.
iris_dataset = load_iris()
model = GradientBoostingClassifier()
model.fit(iris_dataset["data"], iris_dataset["target"])

# Start Ray Serve.
serve.start()

# Define your deployment.
@serve.deployment(route_prefix="/iris")
class BoostingModel:
    def __init__(self, model):
        self.model = model
        self.label_list = iris_dataset["target_names"].tolist()

    async def __call__(self, request):
        payload = (await request.json())["vector"]
        print(f"Received http request with data {payload}")

        prediction = self.model.predict([payload])[0]
        human_name = self.label_list[prediction]
        return {"result": human_name}


# Deploy your model.
BoostingModel.deploy(model)
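To see what happens inside __call__ without starting Ray, it can help to run the same logic by hand. The sketch below drops the decorator and replaces the HTTP request and the trained classifier with hypothetical stand-ins (FakeRequest, FakeModel, and BoostingModelSketch are illustrative names, not Ray Serve or scikit-learn classes). It also shows that __call__ returns a plain dict, which Ray Serve then serializes to JSON for the HTTP response.

```python
import asyncio


class FakeRequest:
    """Hypothetical stand-in for the HTTP request Ray Serve passes to __call__."""

    def __init__(self, body):
        self._body = body

    async def json(self):
        # Like request.json(), this must be awaited.
        return self._body


class FakeModel:
    """Hypothetical stand-in for the trained classifier: always predicts class 0."""

    def predict(self, rows):
        return [0 for _ in rows]


class BoostingModelSketch:
    # Same request handling as BoostingModel, minus the Ray Serve decorator.
    def __init__(self, model, label_list):
        self.model = model
        self.label_list = label_list

    async def __call__(self, request):
        payload = (await request.json())["vector"]
        prediction = self.model.predict([payload])[0]
        return {"result": self.label_list[prediction]}


# The Iris target names, in the order load_iris() reports them.
handler = BoostingModelSketch(FakeModel(), ["setosa", "versicolor", "virginica"])
result = asyncio.run(handler(FakeRequest({"vector": [5.1, 3.5, 1.4, 0.2]})))
print(result)  # → {'result': 'setosa'}
```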

With BoostingModel deployed, we can now define and launch a Gradio interface as follows. The Iris dataset has four numeric features: sepal length, sepal width, petal length, and petal width, all measured in centimeters. We use this to define an iris function that takes these four features, sends them to the deployed model, and returns the predicted class. This time the Gradio interface takes four Number inputs and returns the predicted class as text. Go ahead and try it out in the browser yourself.

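# Define the Gradio function
def iris(sl, sw, pl, pw):
    # Send the four features as a JSON body to the /iris deployment.
    request_input = {"vector": [sl, sw, pl, pw]}
    response = requests.get(
        "http://localhost:8000/iris", json=request_input)
    # BoostingModel returns a JSON object ({"result": ...}), not a list.
    return response.json()["result"]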


# Define gradio interface
iface = gr.Interface(
    fn=iris,
    inputs=[
        gr.inputs.Number(default=1.0, label="sepal length (cm)"),
        gr.inputs.Number(default=1.0, label="sepal width (cm)"),
        gr.inputs.Number(default=1.0, label="petal length (cm)"),
        gr.inputs.Number(default=1.0, label="petal width (cm)"),
    ],
    outputs="text")

# Launch the gradio interface
iface.launch()

When you launch it, you should see an interactive interface like this:

https://raw.githubusercontent.com/ray-project/images/master/docs/serve/gradio_serve_iris.png
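The two client functions read differently shaped responses. The GPT-2 deployment returns the text-generation pipeline's output, a list of dictionaries, so gpt2 takes the first element; BoostingModel returns a single dictionary, so the iris response is read by its "result" key rather than by position. The sketch below shows both with the standard json module; the response bodies and the helper names parse_gpt2_response and parse_iris_response are hypothetical examples of those shapes, not captured server output.

```python
import json


def parse_gpt2_response(text):
    data = json.loads(text)
    # Pipeline output: a list of dicts, so index the first item, then the key.
    return data[0]["generated_text"]


def parse_iris_response(text):
    data = json.loads(text)
    # BoostingModel output: one JSON object, so index by key only.
    return data["result"]


print(parse_gpt2_response('[{"generated_text": "Hello world"}]'))  # → Hello world
print(parse_iris_response('{"result": "versicolor"}'))  # → versicolor
```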

Conclusion

To summarize, building a Gradio app on top of a Ray Serve deployment takes little code: a small function that queries the deployment over HTTP, plus a description of your model’s inputs and outputs in a Gradio interface, and you’re good to go!