# Source code for ray.serve._private.request_router.common
import asyncio
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set
from ray.serve._private.common import ReplicaID, RequestMetadata
from ray.serve._private.constants import (
RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S,
SERVE_LOGGER_NAME,
)
from ray.util.annotations import PublicAPI
# Module-level logger routed through Serve's shared logger name so these
# messages land in the same stream as the rest of Serve's internals.
logger = logging.getLogger(SERVE_LOGGER_NAME)
@dataclass
class RequestRoutingContext:
    """Mutable per-request bookkeeping carried through routing.

    The flags record which routing strategies have already been attempted
    for one pending request; presumably the router consults them to decide
    what to try next and whether to back off (confirm against the router).
    """

    # Timestamp marking when multiplexed-model matching started, if it has.
    multiplexed_start_matching_time: Optional[float] = None
    # One flag per candidate-selection strategy already attempted.
    tried_fewest_multiplexed_models: bool = False
    tried_first_multiplexed_models: bool = False
    tried_same_node: bool = False
    tried_same_az: bool = False
    # Whether the router should back off before the next attempt.
    should_backoff: bool = False
@PublicAPI(stability="alpha")
@dataclass
class PendingRequest:
    """A request that is pending execution by a replica."""

    args: List[Any]
    """Positional arguments for the request."""

    kwargs: Dict[Any, Any]
    """Keyword arguments for the request."""

    metadata: RequestMetadata
    """Metadata for the request, including request ID and whether it's streaming."""

    created_at: float = field(default_factory=time.time)
    """Wall-clock timestamp recorded when the request object was created."""

    future: asyncio.Future = field(default_factory=asyncio.Future)
    """An asyncio Future that will be set when the request is routed."""

    routing_context: RequestRoutingContext = field(
        default_factory=RequestRoutingContext
    )
    """Per-request routing state: which strategies were tried and backoff."""

    def reset_future(self) -> None:
        """Reset the `asyncio.Future`, must be called if this request is re-used."""
        self.future = asyncio.Future()
@dataclass(frozen=True)
class ReplicaQueueLengthCacheEntry:
    """Immutable (queue length, timestamp) pair stored by the cache."""

    # Queue length value recorded for a replica.
    queue_len: int
    # Time (seconds, per the cache's clock) at which the entry was recorded.
    timestamp: float
class ReplicaQueueLengthCache:
    """Cache of per-replica queue lengths with a staleness timeout.

    Entries older than `staleness_timeout_s` are treated as missing by
    `get`. The clock is injectable to make timeout behavior testable.
    """

    def __init__(
        self,
        *,
        staleness_timeout_s: float = RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S,
        get_curr_time_s: Optional[Callable[[], float]] = None,
    ):
        """
        Args:
            staleness_timeout_s: Age in seconds beyond which a cached entry
                is considered stale and ignored by `get`.
            get_curr_time_s: Clock used to timestamp and age entries.
                Defaults to `time.time`.
        """
        self._cache: Dict[ReplicaID, ReplicaQueueLengthCacheEntry] = {}
        self._staleness_timeout_s = staleness_timeout_s
        self._get_curr_time_s = (
            get_curr_time_s if get_curr_time_s is not None else time.time
        )

    def _is_timed_out(self, timestamp_s: float) -> bool:
        # Timestamps come from `self._get_curr_time_s` (e.g. `time.time`),
        # so they are floats; the previous `int` annotation was incorrect.
        return self._get_curr_time_s() - timestamp_s > self._staleness_timeout_s

    def get(self, replica_id: ReplicaID) -> Optional[int]:
        """Get the queue length for a replica.

        Returns `None` if the replica ID is not present or the entry is timed out.
        """
        entry = self._cache.get(replica_id)
        if entry is None or self._is_timed_out(entry.timestamp):
            return None

        return entry.queue_len

    def update(self, replica_id: ReplicaID, queue_len: int):
        """Set (or update) the queue length for a replica ID."""
        # Re-stamping on every update means an entry's staleness is measured
        # from the most recent observation.
        self._cache[replica_id] = ReplicaQueueLengthCacheEntry(
            queue_len, self._get_curr_time_s()
        )

    def invalidate_key(self, replica_id: ReplicaID):
        """Remove the entry for `replica_id` (no-op if absent)."""
        self._cache.pop(replica_id, None)

    def remove_inactive_replicas(self, *, active_replica_ids: Set[ReplicaID]):
        """Removes entries for all replica IDs not in the provided active set."""
        # NOTE: snapshot the keys first — the dict is mutated while looping.
        for replica_id in list(self._cache.keys()):
            if replica_id not in active_replica_ids:
                self._cache.pop(replica_id)