import asyncio
import logging
import pickle
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Set
import grpc
import ray
from ray.actor import ActorHandle
from ray.serve._private.common import (
ReplicaID,
RunningReplicaInfo,
)
from ray.serve._private.constants import (
RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH,
SERVE_LOGGER_NAME,
)
from ray.serve._private.replica_result import (
ActorReplicaResult,
ReplicaResult,
gRPCReplicaResult,
)
from ray.serve._private.request_router.common import PendingRequest
from ray.serve._private.serialization import RPCSerializer
from ray.serve._private.utils import JavaActorHandleProxy
from ray.serve.generated.serve_pb2 import (
ASGIRequest,
RequestMetadata as RequestMetadataProto,
)
from ray.serve.generated.serve_pb2_grpc import ASGIServiceStub
from ray.util.annotations import PublicAPI
logger = logging.getLogger(SERVE_LOGGER_NAME)
class ReplicaWrapper(ABC):
    """Transport-agnostic interface for sending a request to a replica.

    Subclasses hide the details of the transport layer used to reach the
    replica; callers hand in a ``PendingRequest`` and get a ``ReplicaResult``.
    """

    @abstractmethod
    def send_request_java(self, pr: PendingRequest) -> ReplicaResult:
        """Send the request ``pr`` to a Java replica."""

    @abstractmethod
    def send_request_python(
        self, pr: PendingRequest, *, with_rejection: bool
    ) -> ReplicaResult:
        """Send the request ``pr`` to a Python replica.

        With ``with_rejection`` set, the replica first yields a system
        message (ReplicaQueueLengthInfo) and only then executes the actual
        request, which gives it the chance to reject the request. In that
        mode the result is *always* a generator, so for non-streaming
        requests the caller is responsible for resolving it to its first
        (and only) ObjectRef.
        """
class ActorReplicaWrapper(ReplicaWrapper):
    """Replica transport that submits requests via the actor handle's
    ``.remote()`` methods.
    """

    def __init__(self, actor_handle):
        self._actor_handle = actor_handle

    def send_request_java(self, pr: PendingRequest) -> ActorReplicaResult:
        """Submit ``pr`` to a Java replica.

        Streaming is not currently supported, and exactly one positional
        argument is required.

        Raises:
            RuntimeError: if the request metadata is marked as streaming.
            ValueError: if ``pr.args`` does not hold exactly one argument.
        """
        metadata = pr.metadata
        if metadata.is_streaming:
            raise RuntimeError("Streaming not supported for Java.")
        if len(pr.args) != 1:
            raise ValueError("Java handle calls only support a single argument.")

        submit = self._actor_handle.handle_request.remote
        request_id = metadata.request_id
        # Java's default call method is named "call", whereas Python's is
        # "__call__"; translate the default and pass anything else through.
        if metadata.call_method == "__call__":
            java_method = "call"
        else:
            java_method = metadata.call_method
        serialized_metadata = RequestMetadataProto(
            request_id=request_id,
            call_method=java_method,
        ).SerializeToString()

        return ActorReplicaResult(submit(serialized_metadata, pr.args), metadata)

    def send_request_python(
        self, pr: PendingRequest, *, with_rejection: bool
    ) -> ActorReplicaResult:
        """Submit ``pr`` to a Python replica.

        The request metadata is pickled and passed as the first positional
        argument, followed by the request's own args and kwargs.
        """
        if with_rejection:
            # Separate handler that may reject the request. It is *always*
            # a streaming call whose first message is a system message that
            # accepts or rejects.
            handler_name, streaming = "handle_request_with_rejection", True
        elif pr.metadata.is_streaming:
            handler_name, streaming = "handle_request_streaming", True
        else:
            handler_name, streaming = "handle_request", False

        handler = getattr(self._actor_handle, handler_name)
        if streaming:
            handler = handler.options(num_returns="streaming")

        result_ref = handler.remote(pickle.dumps(pr.metadata), *pr.args, **pr.kwargs)
        return ActorReplicaResult(
            result_ref, pr.metadata, with_rejection=with_rejection
        )
class gRPCReplicaWrapper(ReplicaWrapper):
    """Replica transport that submits requests through a gRPC stub."""

    def __init__(self, stub, actor_id):
        self._stub = stub
        self._actor_id = actor_id
        # Captured at construction time, so a running event loop is required
        # here; the loop is handed to every gRPCReplicaResult created below.
        self._loop = asyncio.get_running_loop()

    def send_request_java(self, pr: PendingRequest):
        """Always raises: Java replicas cannot be reached over gRPC."""
        raise RuntimeError("gRPC requests not supported for Java.")

    def send_request_python(
        self, pr: PendingRequest, *, with_rejection: bool
    ) -> gRPCReplicaResult:
        """Submit ``pr`` to a Python replica over gRPC.

        The RPC is picked from the combination of ``with_rejection`` and the
        request's streaming flag.
        """
        metadata = pr.metadata

        # Serialization options come from the request metadata; the serializer
        # is looked up from a cache to avoid per-request instantiation overhead.
        serializer = RPCSerializer.get_cached_serializer(
            metadata.request_serialization, metadata.response_serialization
        )
        asgi_request = ASGIRequest(
            pickled_request_metadata=pickle.dumps(metadata),
            request_args=serializer.dumps_request(pr.args),
            request_kwargs=serializer.dumps_request(pr.kwargs),
        )

        if with_rejection:
            # The rejection-capable handlers send a system message first that
            # accepts or rejects the request (streaming or unary variant).
            if metadata.is_streaming:
                rpc_name = "HandleRequestWithRejectionStreaming"
            else:
                rpc_name = "HandleRequestWithRejection"
        elif metadata.is_streaming:
            rpc_name = "HandleRequestStreaming"
        else:
            rpc_name = "HandleRequest"
        call = getattr(self._stub, rpc_name)(asgi_request)

        return gRPCReplicaResult(
            call,
            metadata,
            self._actor_id,
            loop=self._loop,
            with_rejection=with_rejection,
        )
@PublicAPI(stability="alpha")
class RunningReplica:
    """Contains info on a running replica.

    Also defines the interface for a request router to talk to a replica.
    """

    def __init__(self, replica_info: RunningReplicaInfo):
        """Wrap ``replica_info`` and prepare the actor-based transport.

        The gRPC channel, stub and wrapper are not created here; they are
        built on first use.
        """
        self._replica_info = replica_info
        self._multiplexed_model_ids = set(replica_info.multiplexed_model_ids)

        # Fetch and cache the actor handle once per RunningReplica instance.
        # This avoids the borrower-of-borrower pattern while minimizing GCS lookups.
        actor_handle = replica_info.get_actor_handle()
        if replica_info.is_cross_language:
            self._actor_handle = JavaActorHandleProxy(actor_handle)
        else:
            self._actor_handle = actor_handle

        # Lazily created (see the `stub` property).
        self._channel = None
        self._stub = None

        # Replica wrappers. The gRPC one is created lazily, the first time a
        # request actually needs it (see `_get_replica_wrapper`).
        self._actor_replica_wrapper = ActorReplicaWrapper(self._actor_handle)
        self._grpc_replica_wrapper = None

    @property
    def replica_id(self) -> ReplicaID:
        """ID of this replica."""
        return self._replica_info.replica_id

    @property
    def actor_id(self) -> ray.ActorID:
        """Actor ID of this replica."""
        return self._actor_handle._actor_id

    @property
    def node_id(self) -> str:
        """Node ID of the node this replica is running on."""
        return self._replica_info.node_id

    @property
    def availability_zone(self) -> Optional[str]:
        """Availability zone of the node this replica is running on."""
        return self._replica_info.availability_zone

    @property
    def multiplexed_model_ids(self) -> Set[str]:
        """Set of model IDs on this replica."""
        return self._multiplexed_model_ids

    @property
    def routing_stats(self) -> Dict[str, Any]:
        """Dictionary of routing stats."""
        return self._replica_info.routing_stats

    @property
    def max_ongoing_requests(self) -> int:
        """Max concurrent requests that can be sent to this replica."""
        return self._replica_info.max_ongoing_requests

    @property
    def is_cross_language(self) -> bool:
        """Whether this replica is cross-language (Java)."""
        return self._replica_info.is_cross_language

    @property
    def stub(self):
        """gRPC stub for this replica, created on first access.

        The first access opens an insecure ``grpc.aio`` channel to the
        replica's ``node_ip:port`` with the max receive message length set
        to ``RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH``; later accesses
        return the cached stub.
        """
        if self._stub is None:
            self._channel = grpc.aio.insecure_channel(
                f"{self._replica_info.node_ip}:{self._replica_info.port}",
                options=[
                    (
                        "grpc.max_receive_message_length",
                        RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH,
                    )
                ],
            )
            self._stub = ASGIServiceStub(self._channel)
        return self._stub

    def _get_replica_wrapper(self, pr: PendingRequest) -> ReplicaWrapper:
        """Return the transport wrapper to use for ``pr``.

        By-reference requests go through the actor wrapper; everything else
        goes through the gRPC wrapper.
        """
        if pr.metadata._by_reference:
            return self._actor_replica_wrapper

        # Build the gRPC wrapper only when a request really needs it. Doing so
        # opens the gRPC channel (via `self.stub`) and requires a running
        # event loop, neither of which by-reference requests should pay for.
        if self._grpc_replica_wrapper is None:
            self._grpc_replica_wrapper = gRPCReplicaWrapper(self.stub, self.actor_id)
        return self._grpc_replica_wrapper

    def push_proxy_handle(self, handle: ActorHandle):
        """When on proxy, push proxy's self handle to replica.

        The remote call is submitted without waiting for its result.
        """
        self._actor_handle.push_proxy_handle.remote(handle)

    async def get_queue_len(self, *, deadline_s: float) -> int:
        """Returns current queue len for the replica.

        `deadline_s` is passed to verify backoff for testing; it is not read
        by this implementation.

        If the awaiting task is cancelled, the in-flight remote call is
        cancelled as well before the ``CancelledError`` is re-raised.
        """
        # NOTE(edoakes): the `get_num_ongoing_requests` method name is shared by
        # the Python and Java replica implementations. If you change it, you need to
        # change both (or introduce a branch here).
        obj_ref = self._actor_handle.get_num_ongoing_requests.remote()
        try:
            return await obj_ref
        except asyncio.CancelledError:
            ray.cancel(obj_ref)
            raise

    def try_send_request(
        self, pr: PendingRequest, with_rejection: bool
    ) -> ReplicaResult:
        """Try to send the request to this replica. It may be rejected.

        Cross-language (Java) replicas do not support rejection, so
        ``with_rejection`` must be False for them.
        """
        wrapper = self._get_replica_wrapper(pr)
        if self._replica_info.is_cross_language:
            assert not with_rejection, "Request rejection not supported for Java."
            return wrapper.send_request_java(pr)

        return wrapper.send_request_python(pr, with_rejection=with_rejection)