import asyncio
import logging
import pickle
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Set, Tuple
import grpc
import ray
from ray.actor import ActorHandle
from ray.serve._private.common import (
DeploymentID,
ReplicaID,
ReplicaQueueLengthInfo,
RequestMetadata,
RunningReplicaInfo,
)
from ray.serve._private.constants import (
RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH,
SERVE_LOGGER_NAME,
)
from ray.serve._private.replica_result import (
ActorReplicaResult,
ReplicaResult,
gRPCReplicaResult,
)
from ray.serve._private.request_router.common import PendingRequest
from ray.serve._private.serialization import RPCSerializer
from ray.serve._private.utils import JavaActorHandleProxy
from ray.serve.generated.serve_pb2 import (
ASGIRequest,
RequestMetadata as RequestMetadataProto,
)
from ray.serve.generated.serve_pb2_grpc import ASGIServiceStub
from ray.util.annotations import PublicAPI
from ray.util.tracing.tracing_helper import (
_DictPropagator,
_is_tracing_enabled,
)
logger = logging.getLogger(SERVE_LOGGER_NAME)
class ReplicaWrapper(ABC):
    """Abstracts away the transport layer used to talk to a replica.

    Concrete implementations exist for Ray actor calls and direct gRPC.
    """

    @abstractmethod
    def send_request_java(self, pr: PendingRequest) -> ReplicaResult:
        """Send request to Java replica."""
        ...

    @abstractmethod
    def send_request_python(
        self, pr: PendingRequest, *, with_rejection: bool
    ) -> ReplicaResult:
        """Send request to Python replica.

        When sent with rejection, the replica yields a system message
        (ReplicaQueueLengthInfo) before executing the actual request, and
        may use it to reject the request. The result is therefore *always*
        a generator; for non-streaming requests the caller is responsible
        for resolving it to its first (and only) ObjectRef.
        """
        ...
class ActorReplicaWrapper(ReplicaWrapper):
    """Talks to a replica through Ray actor method calls."""

    def __init__(self, actor_handle):
        self._actor_handle = actor_handle

    def send_request_java(self, pr: PendingRequest) -> ActorReplicaResult:
        """Send the request to a Java replica.

        Does not currently support streaming.
        """
        if pr.metadata.is_streaming:
            raise RuntimeError("Streaming not supported for Java.")

        if len(pr.args) != 1:
            raise ValueError("Java handle calls only support a single argument.")

        # Java replicas use "call" as the default entrypoint rather than
        # Python's "__call__".
        call_method = pr.metadata.call_method
        if call_method == "__call__":
            call_method = "call"

        metadata_proto = RequestMetadataProto(
            request_id=pr.metadata.request_id,
            call_method=call_method,
        )
        obj_ref = self._actor_handle.handle_request.remote(
            metadata_proto.SerializeToString(),
            pr.args,
        )
        return ActorReplicaResult(obj_ref, pr.metadata)

    def send_request_python(
        self, pr: PendingRequest, *, with_rejection: bool
    ) -> ActorReplicaResult:
        """Send the request to a Python replica."""
        if with_rejection:
            # The rejection-capable handler is *always* a streaming call; its
            # first message is a system message that accepts or rejects.
            method = self._actor_handle.handle_request_with_rejection.options(
                num_returns="streaming"
            )
        elif pr.metadata.is_streaming:
            method = self._actor_handle.handle_request_streaming.options(
                num_returns="streaming"
            )
        else:
            method = self._actor_handle.handle_request

        result_ref = method.remote(pickle.dumps(pr.metadata), *pr.args, **pr.kwargs)
        return ActorReplicaResult(result_ref, pr.metadata, with_rejection=with_rejection)
class gRPCReplicaWrapper(ReplicaWrapper):
    """Talks to a replica over a direct gRPC channel."""

    def __init__(self, stub, actor_id):
        self._stub = stub
        self._actor_id = actor_id
        self._loop = asyncio.get_running_loop()

    def send_request_java(self, pr: PendingRequest):
        raise RuntimeError("gRPC requests not supported for Java.")

    def send_request_python(
        self, pr: PendingRequest, *, with_rejection: bool
    ) -> gRPCReplicaResult:
        """Send the request to a Python replica."""
        metadata = pr.metadata

        # Serializers are cached per (request, response) serialization option
        # pair to avoid per-request instantiation overhead.
        serializer = RPCSerializer.get_cached_serializer(
            metadata.request_serialization, metadata.response_serialization
        )

        # gRPC transport bypasses Ray's actor RPC, so Ray's tracing decorators
        # never inject `_ray_trace_ctx` here. Manually inject the current trace
        # context so it propagates to the replica (matching the actor path).
        if _is_tracing_enabled():
            pr.kwargs["_ray_trace_ctx"] = _DictPropagator.inject_current_context()

        asgi_request = ASGIRequest(
            pickled_request_metadata=pickle.dumps(metadata),
            request_args=serializer.dumps_request(pr.args),
            request_kwargs=serializer.dumps_request(pr.kwargs),
        )

        # Choose the RPC. The rejection-capable handlers always deliver a
        # leading system message that accepts or rejects the request:
        # a streaming RPC for streaming requests, a unary RPC otherwise.
        if with_rejection:
            if metadata.is_streaming:
                call = self._stub.HandleRequestWithRejectionStreaming(asgi_request)
            else:
                call = self._stub.HandleRequestWithRejection(asgi_request)
        elif metadata.is_streaming:
            call = self._stub.HandleRequestStreaming(asgi_request)
        else:
            call = self._stub.HandleRequest(asgi_request)

        return gRPCReplicaResult(
            call,
            metadata,
            self._actor_id,
            loop=self._loop,
            with_rejection=with_rejection,
        )
@PublicAPI(stability="alpha")
class RunningReplica:
    """Contains info on a running replica.

    Also defines the interface for a request router to talk to a replica.
    """

    def __init__(self, replica_info: RunningReplicaInfo):
        self._replica_info = replica_info
        self._multiplexed_model_ids = set(replica_info.multiplexed_model_ids)

        # Fetch and cache the actor handle once per RunningReplica instance.
        # This avoids the borrower-of-borrower pattern while minimizing GCS lookups.
        actor_handle = replica_info.get_actor_handle()
        if replica_info.is_cross_language:
            self._actor_handle = JavaActorHandleProxy(actor_handle)
        else:
            self._actor_handle = actor_handle

        # Lazily created on first access of `self.stub`.
        self._channel = None
        self._stub = None

        # Replica wrappers. The gRPC wrapper is created lazily because it
        # requires a running asyncio event loop and an open channel.
        self._actor_replica_wrapper = ActorReplicaWrapper(self._actor_handle)
        self._grpc_replica_wrapper = None

    def update_replica_info(self, replica_info: RunningReplicaInfo) -> None:
        """Update mutable fields from a new RunningReplicaInfo.

        Called when reusing an existing wrapper in _update_running_replicas.
        Replicas dynamically load/unload models via record_multiplexed_model_ids,
        which triggers a broadcast with updated RunningReplicaInfo. Without this
        update, the router would use stale multiplexed_model_ids and break
        multiplexed model routing.

        Because we reassign _replica_info, any property that reads from it
        (including max_ongoing_requests, node_id, availability_zone, etc.)
        will reflect the new values. Fields that are cached separately
        (e.g., _actor_handle) are NOT refreshed here because they are tied
        to the replica's identity and should never change for a live replica.
        """
        self._replica_info = replica_info
        self._multiplexed_model_ids = set(replica_info.multiplexed_model_ids)

    @property
    def replica_id(self) -> ReplicaID:
        """ID of this replica."""
        return self._replica_info.replica_id

    @property
    def actor_id(self) -> ray.ActorID:
        """Actor ID of this replica."""
        return self._actor_handle._actor_id

    @property
    def node_id(self) -> str:
        """Node ID of the node this replica is running on."""
        return self._replica_info.node_id

    @property
    def availability_zone(self) -> Optional[str]:
        """Availability zone of the node this replica is running on."""
        return self._replica_info.availability_zone

    @property
    def multiplexed_model_ids(self) -> Set[str]:
        """Set of model IDs on this replica."""
        return self._multiplexed_model_ids

    @property
    def routing_stats(self) -> Dict[str, Any]:
        """Dictionary of routing stats."""
        return self._replica_info.routing_stats

    @property
    def max_ongoing_requests(self) -> int:
        """Max concurrent requests that can be sent to this replica."""
        return self._replica_info.max_ongoing_requests

    @property
    def is_cross_language(self) -> bool:
        """Whether this replica is cross-language (Java)."""
        return self._replica_info.is_cross_language

    @property
    def backend_http_endpoint(self) -> Optional[Tuple[str, int]]:
        """Return (host, port) of the replica's backend HTTP server.

        Returns None when either the host or port is unknown.
        """
        port = self._replica_info.backend_http_port
        host = self._replica_info.node_ip
        if host is not None and port is not None:
            return (host, port)
        return None

    @property
    def stub(self):
        """Lazily create and cache the gRPC stub for this replica."""
        if self._stub is None:
            self._channel = grpc.aio.insecure_channel(
                f"{self._replica_info.node_ip}:{self._replica_info.port}",
                options=[
                    (
                        "grpc.max_receive_message_length",
                        RAY_SERVE_REPLICA_GRPC_MAX_MESSAGE_LENGTH,
                    )
                ],
            )
            self._stub = ASGIServiceStub(self._channel)

        return self._stub

    def _get_replica_wrapper(self, pr: PendingRequest) -> ReplicaWrapper:
        """Select the transport for this request.

        Requests marked `_by_reference` go over the Ray actor transport;
        everything else goes over gRPC (created lazily on first use).
        """
        if self._grpc_replica_wrapper is None:
            self._grpc_replica_wrapper = gRPCReplicaWrapper(
                self.stub, self._actor_handle._actor_id
            )

        return (
            self._actor_replica_wrapper
            if pr.metadata._by_reference
            else self._grpc_replica_wrapper
        )

    def push_proxy_handle(self, handle: ActorHandle):
        """When on proxy, push proxy's self handle to replica"""
        self._actor_handle.push_proxy_handle.remote(handle)

    async def get_queue_len(self, *, deadline_s: float) -> int:
        """Returns current queue len for the replica.

        `deadline_s` is passed to verify backoff for testing.
        """
        # NOTE(edoakes): the `get_num_ongoing_requests` method name is shared by
        # the Python and Java replica implementations. If you change it, you need to
        # change both (or introduce a branch here).
        obj_ref = self._actor_handle.get_num_ongoing_requests.remote()
        try:
            return await obj_ref
        except asyncio.CancelledError:
            # Cancel the in-flight actor call so it doesn't linger on the replica.
            ray.cancel(obj_ref)
            raise

    def try_send_request(
        self, pr: PendingRequest, with_rejection: bool
    ) -> ReplicaResult:
        """Try to send the request to this replica. It may be rejected."""
        wrapper = self._get_replica_wrapper(pr)
        if self._replica_info.is_cross_language:
            assert not with_rejection, "Request rejection not supported for Java."
            return wrapper.send_request_java(pr)

        return wrapper.send_request_python(pr, with_rejection=with_rejection)

    async def reserve_slot(
        self, request_metadata: RequestMetadata
    ) -> Tuple[str, ReplicaQueueLengthInfo]:
        """Reserve a slot on this replica for an upcoming request.

        Returns a unique token that can be used to release the slot later.
        This is used in the choose_replica/dispatch pattern to track
        reservations that haven't been dispatched yet.

        Raises:
            RuntimeError: If the replica is a Java replica.
        """
        if self._replica_info.is_cross_language:
            raise RuntimeError("Slot reservation not supported for Java.")

        slot_token = str(uuid.uuid4())
        obj_ref = self._actor_handle.reserve_slot.remote(request_metadata, slot_token)
        try:
            accepted, num_ongoing_requests = await obj_ref
        except asyncio.CancelledError:
            ray.cancel(obj_ref)
            self._actor_handle.release_slot.remote(slot_token)
            raise
        except Exception:
            # The actor may have reserved the slot before the reply was lost
            # (e.g. ActorUnavailableError). `release_slot` is idempotent for unknown
            # tokens, so this is safe even when the reservation never actually happened.
            self._actor_handle.release_slot.remote(slot_token)
            raise

        return slot_token, ReplicaQueueLengthInfo(
            accepted=accepted,
            num_ongoing_requests=num_ongoing_requests,
        )

    async def release_slot(self, slot_token: str) -> int:
        """Release a previously reserved slot.

        This should be called if a request is not dispatched after
        reserving a slot (e.g., due to an error or cancellation).

        Returns the replica's reported num_ongoing_requests after the release.

        Raises:
            RuntimeError: If the replica is a Java replica.
        """
        if self._replica_info.is_cross_language:
            raise RuntimeError("Slot reservation not supported for Java.")

        _, num_ongoing_requests = await self._actor_handle.release_slot.remote(
            slot_token
        )
        return num_ongoing_requests
@dataclass
class ReplicaSelection:
    """Represents a selected replica, holding information for dispatch or coordination.

    This class is returned by the choose_replica() context manager.
    The slot reservation lifecycle is managed by the context manager.
    """

    # Public, user-accessible fields
    replica_id: str
    """Unique identifier for the selected replica."""

    node_ip: str
    """IP address of the node running this replica."""

    port: Optional[int]
    """Port number for direct communication (if configured)."""

    node_id: str
    """Ray node ID where the replica is running."""

    availability_zone: Optional[str]
    """Cloud availability zone of the replica's node."""

    # Internal fields (not part of public API)
    _replica: RunningReplica
    _deployment_id: Optional[DeploymentID]
    _request_metadata: RequestMetadata
    _method_name: str
    _slot_token: str  # Token for reserved slot

    # Flipped to True exactly once, by `_mark_dispatched`.
    _dispatched: bool = field(
        default=False, init=False
    )  # Tracks if dispatch was called

    # Set by dispatch once the result's done-callback is wired up. Read by
    # choose_replica's finally to decide whether to fire on_request_completed
    # manually (only one of the two paths should fire it).
    _completion_callback_registered: bool = field(default=False, init=False)

    @property
    def address(self) -> str:
        """Returns the replica address in host:port format."""
        return f"{self.node_ip}:{self.port}" if self.port else self.node_ip

    def to_dict(self) -> Dict[str, Any]:
        """Serialize public fields to a dictionary."""
        return {
            "replica_id": self.replica_id,
            "node_ip": self.node_ip,
            "port": self.port,
            "node_id": self.node_id,
            "availability_zone": self.availability_zone,
        }

    def _mark_dispatched(self) -> None:
        """Internal: Mark this selection as dispatched (slot consumed).

        Raises:
            RuntimeError: If the selection has already been dispatched.
        """
        if not self._dispatched:
            self._dispatched = True
            return

        raise RuntimeError(
            f"ReplicaSelection for {self.replica_id} has already been dispatched. "
            "Each selection can only be dispatched once."
        )

    async def _release_slot(self, *, force: bool = False) -> Optional[int]:
        """Internal: Release the reserved slot.

        Returns the replica's reported num_ongoing_requests after the release,
        or None if dispatch already consumed the slot (and ``force`` is False).
        """
        if not self._dispatched or force:
            return await self._replica.release_slot(self._slot_token)

        return None