Set Up a gRPC Service#
This section helps you understand how to:
Build a user-defined gRPC service and protobuf
Start Serve with gRPC enabled
Deploy gRPC applications
Send gRPC requests to Serve deployments
Check proxy health
Work with gRPC metadata
Use streaming and model composition
Handle errors
Use gRPC context
Define a gRPC service#
Running a gRPC server starts with defining gRPC services, RPC methods, and protobuf messages, similar to the example below.
// user_defined_protos.proto
syntax = "proto3";

option java_multiple_files = true;
option java_package = "io.ray.examples.user_defined_protos";
option java_outer_classname = "UserDefinedProtos";

package userdefinedprotos;

message UserDefinedMessage {
  string name = 1;
  string origin = 2;
  int64 num = 3;
}

message UserDefinedResponse {
  string greeting = 1;
  int64 num = 2;
}

message UserDefinedMessage2 {}

message UserDefinedResponse2 {
  string greeting = 1;
}

message ImageData {
  string url = 1;
  string filename = 2;
}

message ImageClass {
  repeated string classes = 1;
  repeated float probabilities = 2;
}

service UserDefinedService {
  rpc __call__(UserDefinedMessage) returns (UserDefinedResponse);
  rpc Multiplexing(UserDefinedMessage2) returns (UserDefinedResponse2);
  rpc Streaming(UserDefinedMessage) returns (stream UserDefinedResponse);
}

service ImageClassificationService {
  rpc Predict(ImageData) returns (ImageClass);
}
This example creates a file named user_defined_protos.proto with two gRPC services: UserDefinedService and ImageClassificationService. UserDefinedService has three RPC methods: __call__, Multiplexing, and Streaming. ImageClassificationService has one RPC method: Predict. The corresponding input and output types are defined specifically for each RPC method.
Once you define the .proto services, use grpcio-tools to compile Python code for those services. An example command looks like the following:
python -m grpc_tools.protoc -I=. --python_out=. --grpc_python_out=. ./user_defined_protos.proto
It generates two files: user_defined_protos_pb2.py and user_defined_protos_pb2_grpc.py. For more details on grpcio-tools, see https://grpc.io/docs/languages/python/basics/#generating-client-and-server-code.
Note
Ensure that the generated files are in the directory where the Ray cluster is running so that Serve can import them when starting the proxies.
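As a quick, optional sanity check, you can import the generated modules from that directory and construct a message. This snippet is illustrative; the assert only confirms that the servicer-registration function used later in this guide exists:
import user_defined_protos_pb2
import user_defined_protos_pb2_grpc

# Construct a message defined in the proto file above.
message = user_defined_protos_pb2.UserDefinedMessage(name="foo", origin="bar", num=1)
print(message)

# The registration function that Serve's gRPC proxy uses in the next step.
assert hasattr(
    user_defined_protos_pb2_grpc, "add_UserDefinedServiceServicer_to_server"
)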
Start Serve with gRPC enabled#
The serve start CLI, the ray.serve.start API, and Serve config files all support starting Serve with a gRPC proxy. Two options are related to Serve's gRPC proxy: grpc_port and grpc_servicer_functions. grpc_port is the port that the gRPC proxies listen on; it defaults to 9000. grpc_servicer_functions is a list of import paths for the gRPC add_servicer_to_server functions to add to the gRPC proxy. It also serves as the flag that determines whether to start the gRPC server. The default is an empty list, meaning no gRPC server is started.
ray start --head
serve start \
--grpc-port 9000 \
--grpc-servicer-functions user_defined_protos_pb2_grpc.add_UserDefinedServiceServicer_to_server \
--grpc-servicer-functions user_defined_protos_pb2_grpc.add_ImageClassificationServiceServicer_to_server
from ray import serve
from ray.serve.config import gRPCOptions
grpc_port = 9000
grpc_servicer_functions = [
    "user_defined_protos_pb2_grpc.add_UserDefinedServiceServicer_to_server",
    "user_defined_protos_pb2_grpc.add_ImageClassificationServiceServicer_to_server",
]
serve.start(
    grpc_options=gRPCOptions(
        port=grpc_port,
        grpc_servicer_functions=grpc_servicer_functions,
    ),
)
# config.yaml
grpc_options:
  port: 9000
  grpc_servicer_functions:
    - user_defined_protos_pb2_grpc.add_UserDefinedServiceServicer_to_server
    - user_defined_protos_pb2_grpc.add_ImageClassificationServiceServicer_to_server

applications:
  - name: app1
    route_prefix: /app1
    import_path: test_deployment_v2:g
    runtime_env: {}
  - name: app2
    route_prefix: /app2
    import_path: test_deployment_v2:g2
    runtime_env: {}
# Start Serve with the above config file.
serve run config.yaml
Deploy gRPC applications#
gRPC applications in Serve work similarly to HTTP applications. The only differences are that the input and output of the methods need to match what's defined in the .proto file, and that the method names of the application need to exactly match (case sensitive) the predefined RPC methods. For example, if you want to deploy UserDefinedService with the __call__ method, the method name needs to be __call__, the input type needs to be UserDefinedMessage, and the output type needs to be UserDefinedResponse. Serve passes the protobuf object into the method and expects the protobuf object back from the method.
Example deployment:
import time
from typing import Generator

from user_defined_protos_pb2 import (
    UserDefinedMessage,
    UserDefinedMessage2,
    UserDefinedResponse,
    UserDefinedResponse2,
)

import ray
from ray import serve


@serve.deployment
class GrpcDeployment:
    def __call__(self, user_message: UserDefinedMessage) -> UserDefinedResponse:
        greeting = f"Hello {user_message.name} from {user_message.origin}"
        num = user_message.num * 2
        user_response = UserDefinedResponse(
            greeting=greeting,
            num=num,
        )
        return user_response

    @serve.multiplexed(max_num_models_per_replica=1)
    async def get_model(self, model_id: str) -> str:
        return f"loading model: {model_id}"

    async def Multiplexing(
        self, user_message: UserDefinedMessage2
    ) -> UserDefinedResponse2:
        model_id = serve.get_multiplexed_model_id()
        model = await self.get_model(model_id)
        user_response = UserDefinedResponse2(
            greeting=f"Method2 called model, {model}",
        )
        return user_response

    def Streaming(
        self, user_message: UserDefinedMessage
    ) -> Generator[UserDefinedResponse, None, None]:
        for i in range(10):
            greeting = f"{i}: Hello {user_message.name} from {user_message.origin}"
            num = user_message.num * 2 + i
            user_response = UserDefinedResponse(
                greeting=greeting,
                num=num,
            )
            yield user_response
            time.sleep(0.1)


g = GrpcDeployment.bind()
Deploy the application:
app1 = "app1"
serve.run(target=g, name=app1, route_prefix=f"/{app1}")
Note
route_prefix is still a required field as of Ray 2.7.0 due to a shared code path with HTTP. Future releases will make it optional for gRPC.
Send gRPC requests to Serve deployments#
Sending a gRPC request to a Serve deployment is similar to sending a gRPC request to any other gRPC server. Create a gRPC channel and stub, then call the RPC method on the stub with the appropriate input. The output is the protobuf object that your Serve application returns.
Sending a gRPC request:
import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage
channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")
response, call = stub.__call__.with_call(request=request)
print(f"status code: {call.code()}") # grpc.StatusCode.OK
print(f"greeting: {response.greeting}") # "Hello foo from bar"
print(f"num: {response.num}") # 60
Read more about gRPC clients in Python: https://grpc.io/docs/languages/python/basics/#client
Check proxy health#
Similar to the HTTP /-/routes and /-/healthz endpoints, Serve also provides gRPC service methods for health checks.

/ray.serve.RayServeAPIService/ListApplications: Lists all applications deployed in Serve.
/ray.serve.RayServeAPIService/Healthz: Checks the health of the proxy. It returns an OK status and a "success" message if the proxy is healthy.
The service methods and protobufs are defined as follows:
message ListApplicationsRequest {}

message ListApplicationsResponse {
  repeated string application_names = 1;
}

message HealthzRequest {}

message HealthzResponse {
  string message = 1;
}

service RayServeAPIService {
  rpc ListApplications(ListApplicationsRequest) returns (ListApplicationsResponse);
  rpc Healthz(HealthzRequest) returns (HealthzResponse);
}
You can call these service methods with the following code:
import grpc
from ray.serve.generated.serve_pb2_grpc import RayServeAPIServiceStub
from ray.serve.generated.serve_pb2 import HealthzRequest, ListApplicationsRequest
channel = grpc.insecure_channel("localhost:9000")
stub = RayServeAPIServiceStub(channel)
request = ListApplicationsRequest()
response = stub.ListApplications(request=request)
print(f"Applications: {response.application_names}") # ["app1"]
request = HealthzRequest()
response = stub.Healthz(request=request)
print(f"Health: {response.message}") # "success"
Note
Serve provides the RayServeAPIServiceStub stub, and the HealthzRequest and ListApplicationsRequest protobufs, for you to use. You don't need to generate them from the proto file. They are available for your reference.
Work with gRPC metadata#
Just like HTTP headers, gRPC also supports metadata for passing request-related information. You can pass metadata to Serve's gRPC proxy, and Serve knows how to parse and use it. Serve also passes trailing metadata back to the client.
List of metadata keys Serve accepts:

application: The name of the Serve application to route to. If it isn't passed and only one application is deployed, Serve routes to the only deployed application automatically.
request_id: The request ID to track the request.
multiplexed_model_id: The model ID to use for model multiplexing.
List of trailing metadata keys Serve returns:

request_id: The request ID to track the request.
Example of using metadata:
import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage2
channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage2()
app_name = "app1"
request_id = "123"
multiplexed_model_id = "999"
metadata = (
    ("application", app_name),
    ("request_id", request_id),
    ("multiplexed_model_id", multiplexed_model_id),
)
response, call = stub.Multiplexing.with_call(request=request, metadata=metadata)
print(f"greeting: {response.greeting}")  # "Method2 called model, loading model: 999"
for key, value in call.trailing_metadata():
    print(f"trailing metadata key: {key}, value {value}")  # "request_id: 123"
Use streaming and model composition#
The gRPC proxy remains at feature parity with the HTTP proxy. Here are more examples of using the gRPC proxy to get a streaming response and to do model composition.
Streaming#
The Streaming method is deployed with the app named "app1" above. The following code gets a streaming response.
import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage
channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")
metadata = (("application", "app1"),)
responses = stub.Streaming(request=request, metadata=metadata)
for response in responses:
    print(f"greeting: {response.greeting}")  # greeting: n: Hello foo from bar
    print(f"num: {response.num}")  # num: 60 + n
Model composition#
Assume you have the deployments below. ImageDownloader and DataPreprocessor are two separate steps that download and process the image before PyTorch can run inference. The ImageClassifier deployment initializes the model, calls both ImageDownloader and DataPreprocessor, and feeds the result into the ResNet model to get the classes and probabilities of the given image.
import requests
import torch
from typing import List
from PIL import Image
from io import BytesIO
from torchvision import transforms

from user_defined_protos_pb2 import (
    ImageClass,
    ImageData,
)

from ray import serve
from ray.serve.handle import DeploymentHandle


@serve.deployment
class ImageClassifier:
    def __init__(
        self,
        _image_downloader: DeploymentHandle,
        _data_preprocessor: DeploymentHandle,
    ):
        self._image_downloader = _image_downloader
        self._data_preprocessor = _data_preprocessor
        self.model = torch.hub.load(
            "pytorch/vision:v0.10.0", "resnet18", pretrained=True
        )
        self.model.eval()
        self.categories = self._image_labels()

    def _image_labels(self) -> List[str]:
        categories = []
        url = (
            "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
        )
        labels = requests.get(url).text
        for label in labels.split("\n"):
            categories.append(label.strip())
        return categories

    async def Predict(self, image_data: ImageData) -> ImageClass:
        # Download image
        image = await self._image_downloader.remote(image_data.url)
        # Preprocess image
        input_batch = await self._data_preprocessor.remote(image)
        # Predict image
        with torch.no_grad():
            output = self.model(input_batch)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
        return self.process_model_outputs(probabilities)

    def process_model_outputs(self, probabilities: torch.Tensor) -> ImageClass:
        image_classes = []
        image_probabilities = []
        # Show top categories per image
        top5_prob, top5_catid = torch.topk(probabilities, 5)
        for i in range(top5_prob.size(0)):
            image_classes.append(self.categories[top5_catid[i]])
            image_probabilities.append(top5_prob[i].item())

        return ImageClass(
            classes=image_classes,
            probabilities=image_probabilities,
        )


@serve.deployment
class ImageDownloader:
    def __call__(self, image_url: str):
        image_bytes = requests.get(image_url).content
        return Image.open(BytesIO(image_bytes)).convert("RGB")


@serve.deployment
class DataPreprocessor:
    def __init__(self):
        self.preprocess = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def __call__(self, image: Image):
        input_tensor = self.preprocess(image)
        return input_tensor.unsqueeze(0)  # Create a mini-batch as expected by the model.


image_downloader = ImageDownloader.bind()
data_preprocessor = DataPreprocessor.bind()
g2 = ImageClassifier.options(name="grpc-image-classifier").bind(
    image_downloader, data_preprocessor
)
You can deploy the application with the following code:
app2 = "app2"
serve.run(target=g2, name=app2, route_prefix=f"/{app2}")
The client code to call the application looks like the following:
import grpc
from user_defined_protos_pb2_grpc import ImageClassificationServiceStub
from user_defined_protos_pb2 import ImageData
channel = grpc.insecure_channel("localhost:9000")
stub = ImageClassificationServiceStub(channel)
request = ImageData(url="https://github.com/pytorch/hub/raw/master/images/dog.jpg")
metadata = (("application", "app2"),) # Make sure application metadata is passed.
response, call = stub.Predict.with_call(request=request, metadata=metadata)
print(f"status code: {call.code()}") # grpc.StatusCode.OK
print(f"Classes: {response.classes}") # ['Samoyed', ...]
print(f"Probabilities: {response.probabilities}") # [0.8846230506896973, ...]
Note
At this point, two applications are running on Serve: "app1" and "app2". If more than one application is running, you need to pass application in the metadata so Serve knows which application to route to.
Handle errors#
Similar to any other gRPC server, a request throws a grpc.RpcError when the response code is not "OK". Put your request code in a try-except block and handle the error accordingly.
import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage
channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")
try:
    response = stub.__call__(request=request)
except grpc.RpcError as rpc_error:
    print(f"status code: {rpc_error.code()}")  # StatusCode.NOT_FOUND
    print(f"details: {rpc_error.details()}")  # Application metadata not set...
Serve uses the following gRPC error codes (a client-side handling sketch follows the list):

NOT_FOUND: Multiple applications are deployed to Serve and the application is either not passed in the metadata or doesn't match any deployed application.
UNAVAILABLE: Only returned by the health check methods when the proxy is in a draining state. When a health check returns UNAVAILABLE, it means the health check failed on this node and you should no longer route to this node.
DEADLINE_EXCEEDED: The request took longer than the timeout setting and was cancelled.
INTERNAL: Other unhandled errors during the request.
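As an illustrative sketch (the timeout value and the handling in each branch are assumptions, not behavior prescribed by Serve), a client might branch on these codes like the following:
import grpc

from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage

channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")
try:
    # The 5-second timeout is an example value; a request that exceeds it
    # surfaces as DEADLINE_EXCEEDED.
    response = stub.__call__(
        request=request, metadata=(("application", "app1"),), timeout=5.0
    )
except grpc.RpcError as rpc_error:
    code = rpc_error.code()
    if code == grpc.StatusCode.NOT_FOUND:
        # Missing or mismatched application metadata.
        print(f"routing error: {rpc_error.details()}")
    elif code == grpc.StatusCode.DEADLINE_EXCEEDED:
        print("request timed out and was cancelled")
    elif code == grpc.StatusCode.UNAVAILABLE:
        # Health check methods only: stop routing to this node.
        print("proxy is draining or unhealthy")
    else:
        # INTERNAL or anything else unhandled.
        print(f"server error: {rpc_error.details()}")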
Use gRPC context#
Serve provides a gRPC context object to the deployment replica so you can get information about the request as well as set response metadata such as the status code and details. If the handler function is defined with a grpc_context argument, Serve passes a RayServegRPCContext object in for each request. Below is an example of how to set a custom status code, details, and trailing metadata.
from user_defined_protos_pb2 import UserDefinedMessage, UserDefinedResponse

from ray import serve
from ray.serve.grpc_util import RayServegRPCContext

import grpc
from typing import Tuple


@serve.deployment
class GrpcDeployment:
    def __init__(self):
        self.nums = {}

    def num_lookup(self, name: str) -> Tuple[int, grpc.StatusCode, str]:
        if name not in self.nums:
            self.nums[name] = len(self.nums)
            code = grpc.StatusCode.INVALID_ARGUMENT
            message = f"{name} not found, adding to nums."
        else:
            code = grpc.StatusCode.OK
            message = f"{name} found."
        return self.nums[name], code, message

    def __call__(
        self,
        user_message: UserDefinedMessage,
        grpc_context: RayServegRPCContext,  # to use grpc context, add this kwarg
    ) -> UserDefinedResponse:
        greeting = f"Hello {user_message.name} from {user_message.origin}"
        num, code, message = self.num_lookup(user_message.name)

        # Set custom code, details, and trailing metadata.
        grpc_context.set_code(code)
        grpc_context.set_details(message)
        grpc_context.set_trailing_metadata([("num", str(num))])

        user_response = UserDefinedResponse(
            greeting=greeting,
            num=num,
        )
        return user_response


g = GrpcDeployment.bind()
app1 = "app1"
serve.run(target=g, name=app1, route_prefix=f"/{app1}")
The client code to read those attributes looks like the following:
import grpc
from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage
channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")
metadata = (("application", "app1"),)
# First call is going to page miss and return INVALID_ARGUMENT status code.
try:
    response, call = stub.__call__.with_call(request=request, metadata=metadata)
except grpc.RpcError as rpc_error:
    assert rpc_error.code() == grpc.StatusCode.INVALID_ARGUMENT
    assert rpc_error.details() == "foo not found, adding to nums."
    assert any(
        [key == "num" and value == "0" for key, value in rpc_error.trailing_metadata()]
    )
    assert any([key == "request_id" for key, _ in rpc_error.trailing_metadata()])

# Second call is going to page hit and return OK status code.
response, call = stub.__call__.with_call(request=request, metadata=metadata)
assert call.code() == grpc.StatusCode.OK
assert call.details() == "foo found."
assert any([key == "num" and value == "0" for key, value in call.trailing_metadata()])
assert any([key == "request_id" for key, _ in call.trailing_metadata()])
Note
If the handler raises an unhandled exception, Serve returns an INTERNAL error code with the stack trace in the details, regardless of what code and details are set in the RayServegRPCContext object.
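For illustration, if the deployed handler raised an unhandled exception (a hypothetical scenario, not the example deployed above), the client would observe the INTERNAL code instead of the custom code set in the context:
import grpc

from user_defined_protos_pb2_grpc import UserDefinedServiceStub
from user_defined_protos_pb2 import UserDefinedMessage

channel = grpc.insecure_channel("localhost:9000")
stub = UserDefinedServiceStub(channel)
request = UserDefinedMessage(name="foo", num=30, origin="bar")
try:
    # Hypothetical: assume the handler raises, for example, a ValueError.
    response = stub.__call__(request=request, metadata=(("application", "app1"),))
except grpc.RpcError as rpc_error:
    # Any code and details set through RayServegRPCContext are overridden.
    assert rpc_error.code() == grpc.StatusCode.INTERNAL
    print(rpc_error.details())  # Includes the stack trace from the replica.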