# Source code for ray.serve._private.request_router.replica_wrapper
import asyncio
import pickle
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Set
import ray
from ray.actor import ActorHandle
from ray.serve._private.common import (
ReplicaID,
RunningReplicaInfo,
)
from ray.serve._private.replica_result import ActorReplicaResult, ReplicaResult
from ray.serve._private.request_router.common import PendingRequest
from ray.serve._private.utils import JavaActorHandleProxy
from ray.serve.generated.serve_pb2 import RequestMetadata as RequestMetadataProto
from ray.util.annotations import PublicAPI
class ReplicaWrapper(ABC):
    """Interface that hides the transport layer used to reach a replica.

    Concrete subclasses decide how a request is actually delivered; callers
    only hand over a ``PendingRequest`` and get a ``ReplicaResult`` back.
    """

    @abstractmethod
    def send_request_java(self, pr: PendingRequest) -> ReplicaResult:
        """Dispatch ``pr`` to a Java replica."""
        ...

    @abstractmethod
    def send_request_python(
        self, pr: PendingRequest, *, with_rejection: bool
    ) -> ReplicaResult:
        """Dispatch ``pr`` to a Python replica.

        When ``with_rejection`` is set, the replica yields a system message
        (ReplicaQueueLengthInfo) before it executes the actual request, which
        gives it the chance to reject the request. In that mode the result is
        *always* a generator, so for non-streaming requests the caller is
        responsible for resolving it to its first (and only) ObjectRef.
        """
        ...
class ActorReplicaWrapper(ReplicaWrapper):
    """``ReplicaWrapper`` that submits requests through a replica actor handle."""

    def __init__(self, actor_handle):
        # Handle whose remote methods are invoked to deliver requests.
        self._actor_handle = actor_handle

    def send_request_java(self, pr: PendingRequest) -> ActorReplicaResult:
        """Submit ``pr`` to a Java replica.

        Streaming is not currently supported for Java, and exactly one
        positional argument must be supplied.

        Raises:
            RuntimeError: if the request metadata marks it as streaming.
            ValueError: if ``pr.args`` does not hold exactly one argument.
        """
        if pr.metadata.is_streaming:
            raise RuntimeError("Streaming not supported for Java.")
        if len(pr.args) != 1:
            raise ValueError("Java handle calls only support a single argument.")

        submit = self._actor_handle.handle_request.remote

        # Java's default entry point is named "call", whereas Python's is
        # "__call__"; translate the default and pass anything else through.
        if pr.metadata.call_method == "__call__":
            java_call_method = "call"
        else:
            java_call_method = pr.metadata.call_method

        serialized_metadata = RequestMetadataProto(
            request_id=pr.metadata.request_id,
            call_method=java_call_method,
        ).SerializeToString()

        obj_ref = submit(serialized_metadata, pr.args)
        return ActorReplicaResult(obj_ref, pr.metadata)

    def send_request_python(
        self, pr: PendingRequest, *, with_rejection: bool
    ) -> ActorReplicaResult:
        """Submit ``pr`` to a Python replica.

        The request metadata is pickled and passed as the first positional
        argument, followed by the request's own args and kwargs.
        """
        if with_rejection:
            # Dedicated handler that may reject the request. It is *always* a
            # streaming call whose first message is a system message that
            # accepts or rejects.
            handler = self._actor_handle.handle_request_with_rejection
            streaming = True
        elif pr.metadata.is_streaming:
            handler = self._actor_handle.handle_request_streaming
            streaming = True
        else:
            handler = self._actor_handle.handle_request
            streaming = False

        if streaming:
            handler = handler.options(num_returns="streaming")

        obj_ref_gen = handler.remote(
            pickle.dumps(pr.metadata), *pr.args, **pr.kwargs
        )
        return ActorReplicaResult(
            obj_ref_gen, pr.metadata, with_rejection=with_rejection
        )
[docs]
@PublicAPI(stability="alpha")
class RunningReplica:
    """Information about a running replica.

    This class is also the interface a request router uses to talk to the
    replica.
    """

    def __init__(self, replica_info: RunningReplicaInfo):
        self._replica_info = replica_info
        # Build a separate set object from the model IDs in ``replica_info``.
        self._multiplexed_model_ids = set(replica_info.multiplexed_model_ids)
        # Cross-language replicas are reached through a Java proxy around
        # the raw actor handle; all others use the handle directly.
        self._actor_handle = (
            JavaActorHandleProxy(replica_info.actor_handle)
            if replica_info.is_cross_language
            else replica_info.actor_handle
        )

    @property
    def replica_id(self) -> ReplicaID:
        """The replica's ID."""
        return self._replica_info.replica_id

    @property
    def actor_id(self) -> ray.ActorID:
        """The actor ID of the replica."""
        return self._actor_handle._actor_id

    @property
    def node_id(self) -> str:
        """The ID of the node the replica is running on."""
        return self._replica_info.node_id

    @property
    def availability_zone(self) -> Optional[str]:
        """The availability zone of the node the replica is running on."""
        return self._replica_info.availability_zone

    @property
    def multiplexed_model_ids(self) -> Set[str]:
        """The set of model IDs on the replica."""
        return self._multiplexed_model_ids

    @property
    def routing_stats(self) -> Dict[str, Any]:
        """The replica's routing stats, as a dictionary."""
        return self._replica_info.routing_stats

    @property
    def max_ongoing_requests(self) -> int:
        """The maximum number of concurrent requests the replica accepts."""
        return self._replica_info.max_ongoing_requests

    @property
    def is_cross_language(self) -> bool:
        """True if the replica is cross-language (Java)."""
        return self._replica_info.is_cross_language

    def _get_replica_wrapper(self, pr: PendingRequest) -> ReplicaWrapper:
        # ``pr`` is accepted but not consulted here; every request gets an
        # actor-based wrapper around this replica's handle.
        return ActorReplicaWrapper(self._actor_handle)

    def push_proxy_handle(self, handle: ActorHandle):
        """Push the proxy's own handle to the replica (used when on a proxy).

        The remote call's result is neither awaited nor returned.
        """
        self._actor_handle.push_proxy_handle.remote(handle)

    async def get_queue_len(self, *, deadline_s: float) -> int:
        """Return the replica's current queue length.

        ``deadline_s`` is not used in this body; it is passed so that tests
        can verify backoff.
        """
        # NOTE(edoakes): the `get_num_ongoing_requests` method name is shared
        # by the Python and Java replica implementations. If you change it,
        # you need to change both (or introduce a branch here).
        pending_ref = self._actor_handle.get_num_ongoing_requests.remote()
        try:
            queue_len = await pending_ref
        except asyncio.CancelledError:
            # Cancel the outstanding remote call before propagating.
            ray.cancel(pending_ref)
            raise
        return queue_len

    def try_send_request(
        self, pr: PendingRequest, with_rejection: bool
    ) -> ReplicaResult:
        """Attempt to send the request to this replica; it may be rejected."""
        wrapper = self._get_replica_wrapper(pr)
        if not self._replica_info.is_cross_language:
            return wrapper.send_request_python(pr, with_rejection=with_rejection)

        assert not with_rejection, "Request rejection not supported for Java."
        return wrapper.send_request_java(pr)