Keras and Tensorflow Tutorial

In this guide, we will train and deploy a simple Tensorflow neural net. In particular, we show:

  • How to load the model from the file system in your Ray Serve definition

  • How to parse the JSON request and evaluate it in Tensorflow

See Key Concepts for more general information about Ray Serve.

Ray Serve is framework agnostic – you can use any version of Tensorflow. However, for this tutorial, we use Tensorflow 2 and Keras. Please make sure you have Tensorflow 2 installed.

pip install "tensorflow>=2.0"
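To confirm which Tensorflow release is actually present, a small check along these lines may help. It is only a sketch: the helper name `installed_major_version` is hypothetical, and it assumes the package was installed under the distribution name `tensorflow` (variants such as `tensorflow-cpu` would be reported as missing).

```python
from importlib import metadata


def installed_major_version(package: str):
    """Return the major version of an installed package, or None if it is absent."""
    try:
        version = metadata.version(package)
    except metadata.PackageNotFoundError:
        return None
    # "2.15.0" -> 2; only the leading numeric component is inspected
    return int(version.split(".")[0])


major = installed_major_version("tensorflow")
if major is None:
    print("tensorflow is not installed")
elif major < 2:
    print(f"tensorflow {major}.x found; this tutorial expects 2.x")
else:
    print(f"tensorflow {major}.x found")
```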

Let’s import Ray Serve and some other helpers.

from ray import serve

import os
import tempfile
import numpy as np
import requests

We will train a simple MNIST model using Keras.

TRAINED_MODEL_PATH = os.path.join(tempfile.gettempdir(), "mnist_model.h5")

def train_and_save_model():
    import tensorflow as tf

    # Load mnist dataset
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Train a simple neural net model
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation="relu"),
            # Output layer: one logit per digit class (0-9), matching from_logits=True below
            tf.keras.layers.Dense(10),
        ]
    )
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
    model.fit(x_train, y_train, epochs=1)

    model.evaluate(x_test, y_test, verbose=2)

    # Save the model in h5 format in local file system
    model.save(TRAINED_MODEL_PATH)


# Train only if no saved model exists yet
if not os.path.exists(TRAINED_MODEL_PATH):
    train_and_save_model()

Services are defined as ordinary Python classes with __init__ and __call__ methods. The __init__ method runs once when the service starts (here it loads the saved model), and the __call__ method is invoked for each request.

class TFMnistModel:
    def __init__(self, model_path):
        import tensorflow as tf

        self.model_path = model_path
        self.model = tf.keras.models.load_model(model_path)

    async def __call__(self, starlette_request):
        # Step 1: transform HTTP request -> tensorflow input
        # Here we define the request schema to be a JSON object whose "array"
        # field holds 784 numbers (one flattened 28 x 28 image).
        input_array = np.array((await starlette_request.json())["array"])
        reshaped_array = input_array.reshape((1, 28, 28))

        # Step 2: tensorflow input -> tensorflow output
        prediction = self.model(reshaped_array)

        # Step 3: tensorflow output -> web output
        return {"prediction": prediction.numpy().tolist(), "file": self.model_path}
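The class above reshapes whatever arrives in the "array" field, so a malformed request would surface as a reshape error. As a rough illustration of the request schema, the following sketch validates a payload using only the standard library. The helper `parse_mnist_payload` and the constant `IMAGE_SIDE` are hypothetical names and are not part of the tutorial's service.

```python
import json

IMAGE_SIDE = 28  # MNIST images are 28 x 28 pixels


def parse_mnist_payload(body: dict) -> list:
    """Validate {"array": [...784 numbers...]} and return a 28 x 28 nested list."""
    if "array" not in body:
        raise ValueError('request body must contain an "array" field')
    flat = body["array"]
    if len(flat) != IMAGE_SIDE * IMAGE_SIDE:
        raise ValueError(f"expected {IMAGE_SIDE * IMAGE_SIDE} values, got {len(flat)}")
    # Split the flat list into 28 rows of 28 values, row-major like ndarray.reshape
    return [flat[r * IMAGE_SIDE:(r + 1) * IMAGE_SIDE] for r in range(IMAGE_SIDE)]


# Round-trip a payload through JSON, as it would arrive over HTTP
body = json.loads(json.dumps({"array": [0.0] * 784}))
image = parse_mnist_payload(body)
print(len(image), len(image[0]))
```

A check of this kind could run before the reshape in `__call__` so that clients get a clear error message instead of a NumPy exception.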

Now that we’ve defined our services, let’s deploy the model to Ray Serve. We will define a Serve deployment that will be exposed over an HTTP route.
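The exact deployment calls depend on the Ray release in use. The following is a minimal sketch that assumes a recent Ray 2.x Serve API (`serve.deployment`, `.bind()`, and `serve.run` with a `route_prefix` argument); older releases exposed a different interface. The helper name `deploy_mnist_model` is hypothetical.

```python
def deploy_mnist_model(model_cls, model_path, route_prefix="/mnist"):
    """Wrap a plain class as a Serve deployment and expose it over HTTP.

    Assumes the Ray 2.x Serve API; the helper name is hypothetical.
    """
    # Imported lazily so that defining this helper does not require Ray
    from ray import serve

    # Turn the class into a deployment; decorating the class with
    # @serve.deployment at definition time is equivalent
    deployment = serve.deployment(model_cls)

    # bind() records the constructor arguments; the model is built on the replica
    app = deployment.bind(model_path)

    # Starts Serve if needed and routes HTTP traffic under route_prefix
    return serve.run(app, route_prefix=route_prefix)
```

Under these assumptions, calling `deploy_mnist_model(TFMnistModel, TRAINED_MODEL_PATH)` should make the model reachable at http://localhost:8000/mnist, since 8000 is Serve's default HTTP port.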


Let’s query it!

resp = requests.get(
    "http://localhost:8000/mnist", json={"array": np.random.randn(28 * 28).tolist()}
)
print(resp.json())
# {
#  "prediction": [[-1.504277229309082, ..., -6.793371200561523]],
#  "file": "/tmp/mnist_model.h5"
# }
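Because the model is trained with from_logits=True, the values under "prediction" are raw logits rather than probabilities. One way to turn them into a predicted digit is sketched below, using only the standard library. The helper names and the example logits are made up for illustration and do not come from a real model run.

```python
import math


def softmax(logits):
    """Convert raw logits to probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predicted_digit(response_json):
    """Return (digit, probability) for the single image in the response."""
    logits = response_json["prediction"][0]  # batch of size 1
    probs = softmax(logits)
    digit = max(range(len(probs)), key=probs.__getitem__)
    return digit, probs[digit]


# Made-up logits for illustration only; index 3 holds the largest value
example = {"prediction": [[-1.5, 0.2, -0.7, 4.1, -2.0, 0.0, -3.3, 1.1, -0.4, -6.8]]}
digit, prob = predicted_digit(example)
print(digit, round(prob, 3))
```

With the random input used in the query above, the predicted digit carries no meaning; a real MNIST image scaled to the 0-1 range would be needed for a sensible result.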