PyTorch Tutorial

In this guide, we will load and serve a PyTorch ResNet model. In particular, we show:

  • How to load the model from PyTorch’s pre-trained model zoo.

  • How to parse the request, transform the payload, and evaluate it with the model.

Please see Key Concepts for more general information about Ray Serve.

This tutorial requires PyTorch and Torchvision to be installed on your system. Ray Serve is framework agnostic and works with any version of PyTorch.

pip install torch torchvision

Let’s import Ray Serve and some other helpers.

from ray import serve

from io import BytesIO
from PIL import Image
import requests

import torch
from torchvision import transforms
from torchvision.models import resnet18

Services are defined as normal classes with __init__ and __call__ methods. The __init__ method loads the model and builds the preprocessing pipeline; the __call__ method is invoked once per request.

class ImageModel:
    def __init__(self):
        self.model = resnet18(pretrained=True).eval()
        self.preprocessor = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: t[:3, ...]),  # remove alpha channel
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, flask_request):
        image_payload_bytes = flask_request.data
        pil_image = Image.open(BytesIO(image_payload_bytes))
        print("[1/3] Parsed image data: {}".format(pil_image))

        pil_images = [pil_image]  # Our current batch size is one
        input_tensor = torch.cat(
            [self.preprocessor(i).unsqueeze(0) for i in pil_images])
        print("[2/3] Images transformed, tensor shape {}".format(
            input_tensor.shape))

        with torch.no_grad():
            output_tensor = self.model(input_tensor)
        print("[3/3] Inference done!")
        return {"class_index": int(torch.argmax(output_tensor[0]))}
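
To make two steps of the class above more concrete, here is a minimal plain-Python sketch of the arithmetic behind transforms.Normalize, which computes (value - mean) / std per channel, and of picking the index of the largest score, as torch.argmax does for a 1-D tensor. The helper names normalize_pixel and argmax are hypothetical and exist only for this illustration; the tutorial itself relies on torchvision and torch for both steps.

```python
# Plain-Python illustration only; the tutorial uses torchvision and torch.
MEAN = [0.485, 0.456, 0.406]  # per-channel means used in the tutorial
STD = [0.229, 0.224, 0.225]   # per-channel standard deviations used in the tutorial


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel with values in [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]


def argmax(scores):
    """Return the index of the largest score in a flat list."""
    return max(range(len(scores)), key=lambda i: scores[i])


# A pixel equal to the channel means is mapped to zero in every channel.
print(normalize_pixel(MEAN))

# The predicted class is the index of the highest score.
print(argmax([0.1, 2.5, -1.0]))
```

In the service, the same normalization is applied to every pixel of the 3 x 224 x 224 tensor, and the argmax runs over the model's output scores for the single image in the batch.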


Now that we’ve defined our service, let’s deploy the model to Ray Serve. We will define an endpoint for the route representing the image classifier task, a backend corresponding to the physical implementation, and connect them together.

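# Start Ray Serve and get a client handle.
client = serve.start()
# Register the service class as a backend under a versioned tag.
client.create_backend("resnet18:v0", ImageModel)
# Expose the backend over HTTP: POST requests to /image_predict reach it.
client.create_endpoint(
    "predictor",
    backend="resnet18:v0",
    route="/image_predict",
    methods=["POST"])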
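
The split between an endpoint (the route clients call) and a backend (the class that does the work) can be pictured with a toy router. The sketch below is not Ray Serve's implementation or API: ToyRouter, Echo, and the handle method are hypothetical names invented for this illustration, and the real system adds HTTP handling, replicas, and more. It only shows the idea that the service class is instantiated up front and its __call__ method then runs once per request.

```python
# Toy illustration only: this is not Ray Serve's implementation or API.
class ToyRouter:
    def __init__(self):
        self.backends = {}   # backend tag -> service instance
        self.endpoints = {}  # route -> backend tag

    def create_backend(self, tag, service_class):
        # Instantiate the class once; __init__ does the expensive setup.
        self.backends[tag] = service_class()

    def create_endpoint(self, route, backend):
        # Connect a route to a previously registered backend tag.
        self.endpoints[route] = backend

    def handle(self, route, request):
        # Look up the backend for the route and invoke its __call__ method.
        return self.backends[self.endpoints[route]](request)


class Echo:
    """Stand-in service that reports the payload size and a call counter."""

    def __init__(self):
        self.calls = 0

    def __call__(self, request):
        self.calls += 1
        return {"length": len(request), "calls": self.calls}


router = ToyRouter()
router.create_backend("echo:v0", Echo)
router.create_endpoint("/echo", backend="echo:v0")
print(router.handle("/echo", b"abc"))
print(router.handle("/echo", b"hello"))
```

Because the route only refers to a backend tag, a new version of the model can in principle be registered under a different tag without changing the URL that clients use.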

Let’s query it!

ray_logo_bytes = requests.get(
    "https://github.com/ray-project/ray/raw/"
    "master/doc/source/images/ray_header_logo.png").content

resp = requests.post(
    "http://localhost:8000/image_predict", data=ray_logo_bytes)
print(resp.json())
# Output
# {'class_index': 463}
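
The query above sends the raw image bytes as the POST body and reads a JSON object back. As a rough sketch of the same exchange using only the standard library, the code below builds the request and decodes a response body without contacting a server. build_request and parse_response are hypothetical helper names, and the payload is placeholder bytes rather than a real image; actually sending the request (for example with urllib.request.urlopen) assumes the service from this tutorial is running on localhost:8000.

```python
import json
import urllib.request

URL = "http://localhost:8000/image_predict"  # route defined in this tutorial


def build_request(image_bytes, url=URL):
    """Build (but do not send) a POST request carrying raw image bytes."""
    return urllib.request.Request(url, data=image_bytes, method="POST")


def parse_response(body):
    """Decode the JSON body returned by the service and return the class index."""
    return json.loads(body)["class_index"]


# Placeholder payload; a real call would pass the bytes of an image file.
req = build_request(b"placeholder image bytes")
print(req.get_method(), req.full_url)

# Decoding a body shaped like the output shown in this tutorial.
print(parse_response(b'{"class_index": 463}'))
```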