Serving an inference model on AWS NeuronCores using FastAPI (Experimental)

This example compiles a BERT-based model and deploys the traced model on an AWS Inferentia (Inf2) or Trainium (Trn1) instance using Ray Serve and FastAPI.

Note

This setup assumes that you have followed the PyTorch Neuron setup guide and installed the AWS NeuronCore drivers and tools, as well as torch-neuronx, for your instance type.

python -m pip install "ray[serve]" requests transformers
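
Before compiling, you can optionally confirm that the Neuron extension for PyTorch imports cleanly on the instance. This quick check is not part of the original example; it simply fails fast if torch-neuronx is missing:

import torch
import torch_neuronx  # raises ImportError if the Neuron SDK for PyTorch is not installed

print(f"torch {torch.__version__} with torch-neuronx available")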

This example uses the j-hartmann/emotion-english-distilroberta-base model and FastAPI.

Use the following code to compile the model:

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch, torch_neuronx

hf_model = "j-hartmann/emotion-english-distilroberta-base"
neuron_model = "./sentiment_neuron.pt"

model = AutoModelForSequenceClassification.from_pretrained(hf_model)
tokenizer = AutoTokenizer.from_pretrained(hf_model)
sequence_0 = "The company HuggingFace is based in New York City"
sequence_1 = "HuggingFace's headquarters are situated in Manhattan"
example_inputs = tokenizer.encode_plus(
    sequence_0,
    sequence_1,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=128,
)
# torch_neuronx.trace compiles the model for NeuronCores using the example inputs;
# the traced model expects the same input shapes (batch size and sequence length) at inference time.
neuron_inputs = example_inputs["input_ids"], example_inputs["attention_mask"]
n_model = torch_neuronx.trace(model, neuron_inputs)
n_model.save(neuron_model)
print(f"Saved Neuron-compiled model {neuron_model}")

While the model compiles, you should see logs similar to the following:

Downloading (…)lve/main/config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.00k/1.00k [00:00<00:00, 242kB/s]
Downloading pytorch_model.bin: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 329M/329M [00:01<00:00, 217MB/s]
Downloading (…)okenizer_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 294/294 [00:00<00:00, 305kB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 798k/798k [00:00<00:00, 22.0MB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 456k/456k [00:00<00:00, 57.0MB/s]
Downloading (…)/main/tokenizer.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.36M/1.36M [00:00<00:00, 6.16MB/s]
Downloading (…)cial_tokens_map.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 239/239 [00:00<00:00, 448kB/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Saved Neuron-compiled model ./sentiment_neuron.pt
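
Optionally, you can reload the traced artifact and run it once to confirm that it returns logits before wiring it into Serve. This sanity check is not part of the original example and assumes it runs on the same Inf2/Trn1 instance; note that the padded length must match the length the model was traced with (128):

import torch
import torch_neuronx  # noqa: F401 -- registers the Neuron runtime needed to run the traced model
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
loaded = torch.jit.load("./sentiment_neuron.pt")

inputs = tokenizer.encode_plus(
    "Ray Serve is easy to use",
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=128,
)
output = loaded(inputs["input_ids"], inputs["attention_mask"])
print(output["logits"])  # one row of logits for the seven emotion classes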

The traced model is now ready for deployment. Save the following code to a file named aws_neuron_core_inference_serve.py.

Use serve run aws_neuron_core_inference_serve:entrypoint to start the Serve application.

from fastapi import FastAPI
import torch

from ray import serve
from ray.serve.handle import DeploymentHandle

app = FastAPI()

hf_model = "j-hartmann/emotion-english-distilroberta-base"
neuron_model = "./sentiment_neuron.pt"


# Ingress deployment: exposes the FastAPI app and forwards requests to the model deployment.
@serve.deployment(num_replicas=1)
@serve.ingress(app)
class APIIngress:
    def __init__(self, bert_base_model_handle: DeploymentHandle) -> None:
        self.handle = bert_base_model_handle

    @app.get("/infer")
    async def infer(self, sentence: str):
        return await self.handle.infer.remote(sentence)


# Model deployment: each replica reserves one NeuronCore via the custom neuron_cores resource.
@serve.deployment(
    ray_actor_options={"resources": {"neuron_cores": 1}},
    autoscaling_config={"min_replicas": 1, "max_replicas": 2},
)
class BertBaseModel:
    def __init__(self):
        import torch, torch_neuronx  # noqa: F401 -- torch_neuronx registers the Neuron runtime needed to run the traced model
        from transformers import AutoTokenizer

        self.model = torch.jit.load(neuron_model)
        self.tokenizer = AutoTokenizer.from_pretrained(hf_model)
        self.classmap = {
            0: "anger",
            1: "disgust",
            2: "fear",
            3: "joy",
            4: "neutral",
            5: "sadness",
            6: "surprise",
        }

    def infer(self, sentence: str):
        inputs = self.tokenizer.encode_plus(
            sentence,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=128,
        )
        output = self.model(*(inputs["input_ids"], inputs["attention_mask"]))
        class_id = torch.argmax(output["logits"], dim=1).item()
        return self.classmap[class_id]


entrypoint = APIIngress.bind(BertBaseModel.bind())
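
The BertBaseModel deployment is scheduled onto NeuronCores through the custom neuron_cores resource, which recent Ray versions advertise automatically on Inf2/Trn1 nodes. If replicas stay pending, a quick check like the following (not part of the original example, run on the Inf2/Trn1 node) shows whether Ray detects the NeuronCores:

import ray

ray.init()  # or ray.init(address="auto") to attach to an already-running cluster
# Expect an entry like 'neuron_cores': <count> alongside CPU and memory.
print(ray.available_resources())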


You should see the following logs for a successful deployment:

(ServeController pid=43105) INFO 2023-08-23 20:29:32,694 controller 43105 deployment_state.py:1372 - Deploying new version of deployment default_BertBaseModel.
(ServeController pid=43105) INFO 2023-08-23 20:29:32,695 controller 43105 deployment_state.py:1372 - Deploying new version of deployment default_APIIngress.
(ProxyActor pid=43147) INFO 2023-08-23 20:29:32,620 http_proxy 10.0.1.234 http_proxy.py:1328 - Proxy actor 8be14f6b6b10c0190cd0c39101000000 starting on node 46a7f740898fef723c3360ef598c1309701b07d11fb9dc45e236620a.
(ProxyActor pid=43147) INFO:     Started server process [43147]
(ServeController pid=43105) INFO 2023-08-23 20:29:32,799 controller 43105 deployment_state.py:1654 - Adding 1 replica to deployment default_BertBaseModel.
(ServeController pid=43105) INFO 2023-08-23 20:29:32,801 controller 43105 deployment_state.py:1654 - Adding 1 replica to deployment default_APIIngress.
2023-08-23 20:29:44,690 SUCC scripts.py:462 -- Deployed Serve app successfully.

Use the following code to send requests:

import requests

response = requests.get("http://127.0.0.1:8000/infer", params={"sentence": "Ray is super cool"})
print(response.status_code, response.json())

The response includes the status code and the classifier output:

200 joy
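
To handle more traffic, you can raise the replica count for the model deployment; each replica requests one NeuronCore, so the upper bound should not exceed the NeuronCores available on the instance. A minimal sketch of a scaled entrypoint, using Serve's options API under that assumption:

# Hypothetical variant of the entrypoint in aws_neuron_core_inference_serve.py.
entrypoint = APIIngress.bind(
    BertBaseModel.options(
        autoscaling_config={"min_replicas": 1, "max_replicas": 4},
    ).bind()
)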