# Source code for ray.serve.handle

import asyncio
import concurrent.futures
from dataclasses import dataclass
from functools import wraps
import inspect
import os
from typing import Coroutine, Dict, Optional, Union
import threading

import ray
from ray._private.utils import get_or_create_event_loop
from ray.actor import ActorHandle

from ray import serve
from ray.serve._private.common import EndpointTag
from ray.serve._private.constants import (
    SERVE_HANDLE_JSON_KEY,
    SYNC_HANDLE_IN_DAG_FEATURE_FLAG_ENV_KEY,
    ServeHandleType,
)
from ray.serve._private.utils import (
    get_random_letters,
    DEFAULT,
)
from ray.serve._private.router import Router, RequestMetadata
from ray.util import metrics
from ray.util.annotations import DeveloperAPI, PublicAPI

# Event loop used by sync handles in this process. It stays None until
# `_create_or_get_async_loop_in_thread` creates it on first use.
_global_async_loop: Optional[asyncio.AbstractEventLoop] = None


# Escape hatch (last resort): when the environment variable is set to "1",
# deployment handles used in dynamic dispatch revert to the legacy
# synchronous behavior.
FLAG_SERVE_DEPLOYMENT_HANDLE_IS_SYNC = (
    os.getenv(SYNC_HANDLE_IN_DAG_FEATURE_FLAG_ENV_KEY, "0") == "1"
)


def _wrap_into_async_task(async_func):
    """Make a coroutine function return a scheduled future when called.

    A bare coroutine object can only be awaited once; passing it through
    ``asyncio.ensure_future`` yields an object that can be awaited
    more than once.

    Args:
        async_func: The coroutine function (``async def``) to wrap.

    Returns:
        A regular callable carrying ``async_func``'s metadata. Calling it
        invokes ``async_func`` with the same arguments and returns the
        result of ``asyncio.ensure_future`` on the produced coroutine.
    """
    assert inspect.iscoroutinefunction(async_func)

    def _schedule(*call_args, **call_kwargs):
        pending = async_func(*call_args, **call_kwargs)
        return asyncio.ensure_future(pending)

    # Equivalent to decorating `_schedule` with `@wraps(async_func)`.
    return wraps(async_func)(_schedule)


def _create_or_get_async_loop_in_thread():
    """Return the module-wide event loop, starting it on first use.

    On the first call a new asyncio event loop is created, stored in the
    module-level ``_global_async_loop`` and driven by ``run_forever`` in a
    daemon thread. Later calls return that same loop.

    Returns:
        The event loop stored in ``_global_async_loop``.
    """
    global _global_async_loop
    if _global_async_loop is not None:
        return _global_async_loop

    loop = asyncio.new_event_loop()
    _global_async_loop = loop
    # Daemon thread: the forever-running loop must not keep the process alive.
    runner = threading.Thread(target=loop.run_forever, daemon=True)
    runner.start()
    return loop


@PublicAPI(stability="beta")
@dataclass(frozen=True)
class HandleOptions:
    """Options attached to each ServeHandle instance.

    The dataclass is frozen, so these fields are immutable once set.
    """

    # Name of the deployment method that requests sent through the handle
    # target; defaults to the deployment's `__call__`.
    method_name: str = "__call__"


@PublicAPI(stability="beta")
class RayServeHandle:
    """A handle used to make requests from one deployment to another.

    This is used to compose multiple deployments into a single application. After
    building the application, this handle is substituted at runtime for deployments
    passed as arguments via `.bind()`.

    Example:

    .. code-block:: python

        import ray
        from ray import serve
        from ray.serve.handle import RayServeHandle, RayServeSyncHandle

        @serve.deployment
        class Downstream:
            def __init__(self, message: str):
                self._message = message

            def __call__(self, name: str) -> str:
                return self._message + name

        @serve.deployment
        class Ingress:
            def __init__(self, handle: RayServeHandle):
                self._handle = handle

            async def __call__(self, name: str) -> str:
                obj_ref: ray.ObjectRef = await self._handle.remote(name)
                return await obj_ref

        app = Ingress.bind(Downstream.bind("Hello "))
        handle: RayServeSyncHandle = serve.run(app)

        # Prints "Hello Mr. Magoo"
        print(ray.get(handle.remote("Mr. Magoo")))

    """

    def __init__(
        self,
        controller_handle: ActorHandle,
        deployment_name: EndpointTag,
        handle_options: Optional[HandleOptions] = None,
        *,
        _router: Optional[Router] = None,
        _internal_pickled_http_request: bool = False,
        _stream: bool = False,
    ):
        """Create a handle for `deployment_name`.

        Args:
            controller_handle: Actor handle of the Serve controller; passed on
                to the router.
            deployment_name: Name of the deployment this handle targets.
            handle_options: Options for this handle. Defaults to a fresh
                `HandleOptions()` when omitted.
            _router: Existing router to reuse. When omitted, a new one is
                built with `_make_router()`.
            _internal_pickled_http_request: Forwarded to the request metadata
                as `http_arg_is_pickled`.
            _stream: Forwarded to the router built by `_make_router()`.
        """
        self.controller_handle = controller_handle
        self.deployment_name = deployment_name
        self.handle_options = handle_options or HandleOptions()
        self.handle_tag = f"{self.deployment_name}#{get_random_letters()}"
        self._pickled_http_request = _internal_pickled_http_request
        self._stream = _stream

        self.request_counter = metrics.Counter(
            "serve_handle_request_counter",
            description=(
                "The number of handle.remote() calls that have been "
                "made on this handle."
            ),
            tag_keys=("handle", "deployment", "route", "application"),
        )
        self.request_counter.set_default_tags(
            {"handle": self.handle_tag, "deployment": self.deployment_name}
        )

        # `_stream` must be assigned before this point: `_make_router` reads it.
        self.router: Router = _router or self._make_router()

    def _make_router(self) -> Router:
        """Build a router bound to the caller's current event loop."""
        return Router(
            self.controller_handle,
            self.deployment_name,
            event_loop=get_or_create_event_loop(),
            _stream=self._stream,
        )

    @property
    def _is_polling(self) -> bool:
        """Whether this handle is actively polling for replica updates."""
        return self.router.long_poll_client.is_running

    @property
    def _is_same_loop(self) -> bool:
        """Whether the caller's asyncio loop is the same loop for handle.

        This is only useful for async handles.
        """
        return get_or_create_event_loop() == self.router._event_loop

    def _options(
        self,
        *,
        method_name: Union[str, DEFAULT] = DEFAULT.VALUE,
        multiplexed_model_id: Union[str, DEFAULT] = DEFAULT.VALUE,
    ):
        """Return a copy of this handle with the given options applied.

        The copy shares this handle's router and keeps its
        `_internal_pickled_http_request` and `_stream` settings.

        Args:
            method_name: Method to call on the deployment. Left unchanged
                when `DEFAULT.VALUE`.
            multiplexed_model_id: When given, it is written into the Serve
                request context instead of being stored on the handle.

        Returns:
            A new instance of `self.__class__`.
        """
        new_options_dict = self.handle_options.__dict__.copy()
        if method_name != DEFAULT.VALUE:
            new_options_dict["method_name"] = method_name
        new_options = HandleOptions(**new_options_dict)

        if multiplexed_model_id != DEFAULT.VALUE:
            # If the user specifies model id, we need to update the RequestContext
            # to include the model_id.
            ray.serve.context._set_request_context(
                multiplexed_model_id=multiplexed_model_id
            )

        return self.__class__(
            self.controller_handle,
            self.deployment_name,
            new_options,
            _router=self.router,
            _internal_pickled_http_request=self._pickled_http_request,
            # Keep the copy's `_stream` consistent with the handle it was
            # derived from; otherwise it would silently reset to False.
            _stream=self._stream,
        )

    def options(
        self,
        *,
        method_name: Union[str, DEFAULT] = DEFAULT.VALUE,
        multiplexed_model_id: Union[str, DEFAULT] = DEFAULT.VALUE,
    ) -> "RayServeHandle":
        """Set options for this handle and return an updated copy of it.

        Example:

        .. code-block:: python

            # The following two lines are equivalent:
            obj_ref = await handle.other_method.remote(*args)
            obj_ref = await handle.options(method_name="other_method").remote(*args)
            obj_ref = await handle.options(
                multiplexed_model_id="model:v1").remote(*args)
        """
        return self._options(
            method_name=method_name, multiplexed_model_id=multiplexed_model_id
        )

    def _remote(self, deployment_name, handle_options, args, kwargs) -> Coroutine:
        """Build request metadata, count the request, and hand it to the router.

        Args:
            deployment_name: Deployment the request is addressed to.
            handle_options: Options whose `method_name` selects the method
                to call.
            args: Positional arguments for the request.
            kwargs: Keyword arguments for the request.

        Returns:
            Whatever `router.assign_request` returns (annotated as a
            coroutine).
        """
        _request_context = ray.serve.context._serve_request_context.get()
        request_metadata = RequestMetadata(
            _request_context.request_id,
            deployment_name,
            call_method=handle_options.method_name,
            http_arg_is_pickled=self._pickled_http_request,
            route=_request_context.route,
            app_name=_request_context.app_name,
            multiplexed_model_id=_request_context.multiplexed_model_id,
        )
        self.request_counter.inc(
            tags={
                "route": _request_context.route,
                "application": _request_context.app_name,
            }
        )
        coro = self.router.assign_request(request_metadata, *args, **kwargs)
        return coro

    @_wrap_into_async_task
    async def remote(self, *args, **kwargs) -> asyncio.Task:
        """Issue an asynchronous request to the __call__ method of the deployment.

        Returns an `asyncio.Task` whose underlying result is a Ray ObjectRef that
        points to the final result of the request.

        The final result can be retrieved by awaiting the ObjectRef.

        Example:

        .. code-block:: python

            obj_ref = await handle.remote(*args)
            result = await obj_ref

        """
        return await self._remote(
            self.deployment_name, self.handle_options, args, kwargs
        )

    def __repr__(self):
        return f"{self.__class__.__name__}(deployment='{self.deployment_name}')"

    @classmethod
    def _deserialize(cls, kwargs):
        """Required for this class's __reduce__ method to be picklable."""
        return cls(**kwargs)

    def __reduce__(self):
        """Serialize the handle as the keyword arguments needed to rebuild it.

        The router is not serialized; the rebuilt handle creates its own.
        """
        serialized_data = {
            "controller_handle": self.controller_handle,
            "deployment_name": self.deployment_name,
            "handle_options": self.handle_options,
            "_internal_pickled_http_request": self._pickled_http_request,
            # Without this the rebuilt handle (and the router it creates)
            # would fall back to `_stream=False`.
            "_stream": self._stream,
        }
        return RayServeHandle._deserialize, (serialized_data,)

    def __getattr__(self, name):
        """Treat any attribute not found normally as a deployment method name.

        `handle.foo` is shorthand for `handle.options(method_name="foo")`.
        """
        return self.options(method_name=name)
@PublicAPI(stability="beta")
class RayServeSyncHandle(RayServeHandle):
    """A handle used to make requests to the ingress deployment of an application.

    This is returned by `serve.run` and can be used to invoke the application from
    Python rather than over HTTP. For example:

    .. code-block:: python

        import ray
        from ray import serve
        from ray.serve.handle import RayServeSyncHandle

        @serve.deployment
        class Ingress:
            def __call__(self, name: str) -> str:
                return f"Hello {name}"

        app = Ingress.bind()
        handle: RayServeSyncHandle = serve.run(app)

        # Prints "Hello Mr. Magoo"
        print(ray.get(handle.remote("Mr. Magoo")))

    """

    @property
    def _is_same_loop(self) -> bool:
        # NOTE(simon): For sync handle, the caller doesn't have to be in the
        # same loop as the handle's loop, so we always return True here.
        return True

    def _make_router(self) -> Router:
        """Build a router that runs on the shared background-thread event loop."""
        return Router(
            self.controller_handle,
            self.deployment_name,
            event_loop=_create_or_get_async_loop_in_thread(),
            _stream=self._stream,
        )

    def options(
        self,
        *,
        method_name: Union[str, DEFAULT] = DEFAULT.VALUE,
        multiplexed_model_id: Union[str, DEFAULT] = DEFAULT.VALUE,
    ) -> "RayServeSyncHandle":
        """Set options for this handle and return an updated copy of it.

        Example:

        .. code-block:: python

            # The following two lines are equivalent:
            obj_ref = handle.other_method.remote(*args)
            obj_ref = handle.options(method_name="other_method").remote(*args)
            obj_ref = handle.options(multiplexed_model_id="model1").remote(*args)

        """
        return self._options(
            method_name=method_name, multiplexed_model_id=multiplexed_model_id
        )

    def remote(self, *args, **kwargs) -> ray.ObjectRef:
        """Issue an asynchronous request to the __call__ method of the deployment.

        Returns a Ray ObjectRef whose results can be waited for or retrieved
        using ray.wait or ray.get, respectively.

        .. code-block:: python

            obj_ref = handle.remote(*args)
            result = ray.get(obj_ref)

        """
        coro = self._remote(self.deployment_name, self.handle_options, args, kwargs)
        # The router lives on another thread's loop: submit the coroutine
        # there and block this thread until the assignment completes.
        future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
            coro, self.router._event_loop
        )
        return future.result()

    def __reduce__(self):
        """Serialize the handle as the keyword arguments needed to rebuild it.

        The router is not serialized; the rebuilt handle creates its own.
        """
        serialized_data = {
            "controller_handle": self.controller_handle,
            "deployment_name": self.deployment_name,
            "handle_options": self.handle_options,
            "_internal_pickled_http_request": self._pickled_http_request,
            # Without this the rebuilt handle (and the router it creates)
            # would fall back to `_stream=False`.
            "_stream": self._stream,
        }
        return RayServeSyncHandle._deserialize, (serialized_data,)
@DeveloperAPI
class RayServeDeploymentHandle:
    """Send requests to a deployment. This class should not be manually created.

    The underlying Serve handle is resolved lazily, on the first call to
    `remote()`.
    """

    def __init__(
        self,
        deployment_name: str,
        handle_options: Optional[HandleOptions] = None,
    ):
        self.deployment_name = deployment_name
        self.handle_options = handle_options or HandleOptions()
        # A Serve DAG needs a placeholder while binding and building, before
        # serve.start is required, so the real handle is only filled in at
        # runtime.
        self.handle: Optional[RayServeHandle] = None

    def options(self, *, method_name: str) -> "RayServeDeploymentHandle":
        """Return a new, still unresolved handle that targets `method_name`."""
        updated_options = HandleOptions(method_name=method_name)
        return self.__class__(self.deployment_name, updated_options)

    def remote(self, *args, _ray_cache_refs: bool = False, **kwargs) -> asyncio.Task:
        """Resolve the underlying handle if needed, then forward the request.

        `_ray_cache_refs` is accepted but not used by this method.
        """
        if not self.handle:
            deployment = serve._private.api.get_deployment(self.deployment_name)
            base_handle = deployment._get_handle(
                sync=FLAG_SERVE_DEPLOYMENT_HANDLE_IS_SYNC
            )
            self.handle = base_handle.options(
                method_name=self.handle_options.method_name
            )
        return self.handle.remote(*args, **kwargs)

    @classmethod
    def _deserialize(cls, kwargs):
        """Required for this class's __reduce__ method to be picklable."""
        return cls(**kwargs)

    def __reduce__(self):
        # Only the constructor arguments are serialized; the resolved handle
        # is left out and gets looked up again after deserialization.
        state = {
            "deployment_name": self.deployment_name,
            "handle_options": self.handle_options,
        }
        return RayServeDeploymentHandle._deserialize, (state,)

    def __getattr__(self, name):
        # Any attribute not found normally is treated as a method name.
        return self.options(method_name=name)

    def __repr__(self):
        return f"{self.__class__.__name__}(deployment='{self.deployment_name}')"


def _serve_handle_to_json_dict(handle: RayServeHandle) -> Dict[str, str]:
    """Converts a Serve handle to a JSON-serializable dictionary.

    The dictionary can be converted back to a ServeHandle using
    _serve_handle_from_json_dict.
    """
    is_sync = isinstance(handle, RayServeSyncHandle)
    handle_type = ServeHandleType.SYNC if is_sync else ServeHandleType.ASYNC
    return {
        SERVE_HANDLE_JSON_KEY: handle_type,
        "deployment_name": handle.deployment_name,
    }


def _serve_handle_from_json_dict(d: Dict[str, str]) -> RayServeHandle:
    """Converts a JSON-serializable dictionary back to a ServeHandle.

    The dictionary should be constructed using _serve_handle_to_json_dict.

    Raises:
        ValueError: If `d` has no SERVE_HANDLE_JSON_KEY entry.
    """
    if SERVE_HANDLE_JSON_KEY not in d:
        raise ValueError(f"dict must contain {SERVE_HANDLE_JSON_KEY} key.")

    wants_sync = d[SERVE_HANDLE_JSON_KEY] == ServeHandleType.SYNC
    client = serve.context.get_global_client()
    return client.get_handle(
        d["deployment_name"],
        sync=wants_sync,
        missing_ok=True,
    )