Source code for ray.llm._internal.serve.deployments.routers.router

import asyncio
import json
import os
from typing import (
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Union,
)

# TODO (genesu): remove dependency on async_timeout.
import async_timeout
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from ray import serve
from ray._private.utils import get_or_create_event_loop
from ray.serve.handle import DeploymentHandle
from starlette.responses import JSONResponse, Response, StreamingResponse

from ray.llm._internal.serve.configs.constants import RAYLLM_ROUTER_HTTP_TIMEOUT
from ray.llm._internal.serve.observability.logging import get_logger
from ray.llm._internal.serve.observability.metrics.fast_api_metrics import (
    add_http_metrics_middleware,
    metrics_lifespan,
)
from ray.llm._internal.serve.deployments.llm.multiplex.utils import (
    get_base_model_id,
    get_lora_model_ids,
    get_lora_model_metadata,
)
from ray.llm._internal.serve.configs.openai_api_models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    CompletionRequest,
    CompletionResponse,
    CompletionStreamResponse,
    LLMChatResponse,
    LLMCompletionsResponse,
    OpenAIHTTPException,
    to_model_metadata,
)
from ray.llm._internal.serve.configs.openai_api_models_patch import (
    ErrorResponse,
)
from ray.llm._internal.serve.configs.server_models import (
    LLMConfig,
    ModelData,
    Model,
)
from ray.llm._internal.serve.deployments.routers.middleware import (
    SetRequestIdMiddleware,
    add_exception_handling_middleware,
)
from ray.llm._internal.serve.deployments.utils.server_utils import replace_prefix

logger = get_logger(__name__)


def init() -> FastAPI:
    _fastapi_router_app = FastAPI(lifespan=metrics_lifespan)

    # NOTE: PLEASE READ CAREFULLY BEFORE MODIFYING
    #
    # FastAPI middleware is executed in LIFO (last-in, first-out) order,
    # so maintaining the current ordering is crucial: some middleware has
    # data dependencies on others (for ex, the telemetry middleware
    # depends on the middleware that generates the request ID).
    #
    # Add exception handling middleware
    # NOTE: This middleware should be added first so that it intercepts
    #       exceptions from the handlers, preventing them from propagating
    #       to other middleware (for ex, telemetry)
    add_exception_handling_middleware(_fastapi_router_app)
    # Configure CORS middleware
    _fastapi_router_app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    # Add HTTP metrics middleware
    add_http_metrics_middleware(_fastapi_router_app)

    # Inject unique per-request ID
    #
    # NOTE: This middleware should be added last so that, with LIFO
    # execution, it runs first on incoming requests and the request ID is
    # already set when the other middleware (for ex, telemetry) executes.
    _fastapi_router_app.add_middleware(SetRequestIdMiddleware)

    return _fastapi_router_app


fastapi_router_app = init()


def _apply_openai_json_format(
    response: Union[ChatCompletionStreamResponse, CompletionStreamResponse]
) -> str:
    """Converts a CompletionStreamResponse to OpenAI format.

    Each model response is converted to the string:
        data: <response-json1>\n\n

    The converted strings are concatenated and returned:
        data: <response-json1>\n\ndata: <response-json2>\n\n...
    """

    return "".join(f"data: {response.model_dump_json()}\n\n")


async def _openai_json_wrapper(
    generator: AsyncGenerator[
        Union[ChatCompletionStreamResponse, CompletionStreamResponse], None
    ],
    first_response: Union[ChatCompletionStreamResponse, CompletionStreamResponse],
) -> AsyncGenerator[str, None]:
    """Wrapper that converts CompletionStreamResponse into OpenAI JSON strings.

    Args:
        generator: an async generator that yields CompletionStreamResponse.
            Each response is converted into OpenAI JSON
            format. The jsonified responses from a list are concatenated
            together and yielded as a single string.
        first_response: the first CompletionStreamResponse to yield.

    Yields:
        Concatenated JSON strings that represent CompletionStreamResponse.
    """
    yield _apply_openai_json_format(first_response)

    async for response in generator:
        yield _apply_openai_json_format(response)

    yield "data: [DONE]\n\n"


async def _peek_at_openai_json_generator(
    generator: Union[LLMChatResponse, LLMCompletionsResponse],
) -> Tuple[
    Union[ChatCompletionStreamResponse, CompletionStreamResponse, ErrorResponse],
    AsyncGenerator[str, None],
]:
    """Runs one iteration of the underlying generator
    and returns the result, alongside the generator itself (with the
    first iteration still there).
    """
    first_response = await generator.__anext__()

    return first_response, _openai_json_wrapper(generator, first_response)


class LLMRouter:
    def __init__(
        self,
        llm_deployments: List[DeploymentHandle],
        *,
        _get_lora_model_metadata_func: Optional[
            Callable[[str, LLMConfig], Awaitable[Dict[str, Any]]]
        ] = None,
    ):
        self._default_serve_handles: Dict[str, DeploymentHandle] = {}
        self._llm_configs: Dict[str, LLMConfig] = {}

        # Configuring a ServeHandle with .options() creates a new ServeHandle
        # object, which contains a new metrics pusher and long-polling call.
        # Creating too many ServeHandles can impact event-loop and Serve Controller
        # performance, so we save configured ServeHandles here and reuse them.
        self._configured_serve_handles: Dict[str, DeploymentHandle] = {}
        self._get_lora_model_metadata_func = (
            _get_lora_model_metadata_func or self._default_get_lora_model_metadata_func
        )

        # Setup _default_serve_handles and _llm_configs asynchronously.
        self._init_completed = asyncio.Event()
        self.running_setup_task = get_or_create_event_loop().create_task(
            self._setup_handle_and_config_maps(llm_deployments=llm_deployments)
        )

    async def _default_get_lora_model_metadata_func(
        self, model_id: str, llm_config: LLMConfig
    ) -> Dict[str, Any]:
        return await get_lora_model_metadata(model_id, llm_config)

    async def _setup_handle_and_config_maps(
        self, llm_deployments: List[DeploymentHandle]
    ):
        for handle in llm_deployments:
            llm_config = await handle.llm_config.remote()
            self._default_serve_handles[llm_config.model_id] = handle
            self._llm_configs[llm_config.model_id] = llm_config

        # Note (genesu): Model id uniqueness is already checked in
        # `router_application()` under run.py. However, once this router component is
        # open sourced, users can call this lower-level API directly and bypass that
        # check, so we verify again here that all the model ids are unique.
        if len(llm_deployments) != len(self._llm_configs):
            raise ValueError("Duplicate models found. Make sure model ids are unique.")

        self._init_completed.set()

    async def check_health(self):
        await self._init_completed.wait()

    def _get_configured_serve_handle(self, model_id: str):
        """Gets a ServeHandle to a model deployment.

        Configures the handle's options, and stores it in a cache.

        If the model_id includes LoRA suffix, we set the model ID as
        the multiplexed_model_id, so the request uses Serve's multiplexed
        routing logic.

        If the model_id is a base model (even if the model has LoRA
        adapters), we don't set multiplexed_model_id. Setting
        multiplexed_model_id would cause base model requests to be
        sent to a single model replica instead of being load
        balanced across all replicas. This is undesirable for base
        model requests (unlike LoRA requests) because every replica
        already has a copy of the base model.
        """

        if model_id not in self._configured_serve_handles:
            base_model_id = get_base_model_id(model_id)
            if base_model_id in self._default_serve_handles:
                if model_id == base_model_id:
                    default_handle = self._default_serve_handles[model_id]
                    configured_handle = default_handle.options(stream=True)
                    self._configured_serve_handles[model_id] = configured_handle
                else:
                    default_handle = self._default_serve_handles[base_model_id]
                    configured_handle = default_handle.options(
                        stream=True,
                        multiplexed_model_id=model_id,
                    )
                    self._configured_serve_handles[model_id] = configured_handle
            else:
                raise HTTPException(
                    status.HTTP_404_NOT_FOUND,
                    f'Could not find model with id "{model_id}".',
                )

        return self._configured_serve_handles[model_id]
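
    # For illustration (an assumption, not part of the original module): LoRA
    # model IDs are expected to extend a base model ID with an adapter suffix
    # that get_base_model_id() strips back off, e.g. something like
    # "my-org/my-base-model:my_adapter". A request for the base ID reuses the
    # cached `handle.options(stream=True)`, while a request for the full LoRA
    # ID additionally pins `multiplexed_model_id` so Serve routes it with its
    # multiplexed (LoRA-aware) routing logic.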

    async def _get_response(
        self,
        *,
        body: Union[CompletionRequest, ChatCompletionRequest],
        call_method: str,
    ) -> AsyncGenerator[Union[LLMChatResponse, LLMCompletionsResponse], None]:
        """Calls the model deployment and returns the stream."""
        model: str = body.model
        base_model_id = get_base_model_id(model)
        if base_model_id not in self._llm_configs:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND,
                f'Got request for model "{model}". '
                f'Could not find base model with ID "{base_model_id}".',
            )

        model_handle = self._get_configured_serve_handle(model)

        async for response in getattr(model_handle, call_method).remote(body):
            yield response

    async def model(self, model_id: str) -> Optional[ModelData]:
        if model_id in self._llm_configs:
            return to_model_metadata(model_id, self._llm_configs[model_id])

        base_model_id = get_base_model_id(model_id)
        if (
            base_model_id in self._llm_configs
            and self._llm_configs[base_model_id].lora_config
        ):
            try:
                overrides = await self._get_lora_model_metadata_func(
                    model_id, self._llm_configs[base_model_id]
                )

                return to_model_metadata(
                    model_id=model_id,
                    model_config=self._llm_configs[base_model_id],
                    overrides=overrides,
                )
            except HTTPException:
                logger.exception(
                    "Unable to retrieve LoRA adapter config file for "
                    f'"{model_id}". Omitting it from list of available models. '
                    "Check that adapter config file exists in cloud bucket."
                )

    @fastapi_router_app.get("/v1/models", response_model=Model)
    async def models(self) -> Model:
        """OpenAI API-compliant endpoint to get all rayllm models."""
        all_models = dict()
        for base_model_id, llm_config in self._llm_configs.items():
            # Add the base model.
            all_models[base_model_id] = await self.model(base_model_id)

            if llm_config.lora_config is not None:
                # Add all the fine-tuned models.
                lora_model_ids = get_lora_model_ids(
                    dynamic_lora_loading_path=llm_config.lora_config.dynamic_lora_loading_path,
                    base_model_id=base_model_id,
                )
                for lora_id in lora_model_ids:
                    model_data = await self.model(lora_id)
                    if model_data is not None:
                        all_models[lora_id] = model_data

        return Model(data=list(all_models.values()))
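
    # Illustrative client usage sketch (not part of the original module).
    # Assumptions: the router is deployed and reachable at ROUTER_URL (for
    # example "http://localhost:8000"), and no API key is enforced, so a dummy
    # key is passed. Because the endpoint is OpenAI API-compliant, the OpenAI
    # Python client can be pointed at it directly:
    #
    #     from openai import OpenAI
    #
    #     client = OpenAI(base_url=f"{ROUTER_URL}/v1", api_key="fake-key")
    #     for model in client.models.list():
    #         print(model.id)
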

    # :path allows us to have slashes in the model name
    @fastapi_router_app.get("/v1/models/{model:path}", response_model=ModelData)
    async def model_data(self, model: str) -> ModelData:
        """OpenAI API-compliant endpoint to get one rayllm model.

        :param model: The model ID (e.g. "amazon/LightGPT")
        """
        model = replace_prefix(model)
        model_data = await self.model(model)
        if model_data is None:
            raise OpenAIHTTPException(
                message=f"Unable to find {model}. Please ensure that the model exists and you have permission.",
                status_code=status.HTTP_404_NOT_FOUND,
                type="InvalidModel",
            )
        return model_data

    @fastapi_router_app.post("/v1/completions")
    async def completions(self, body: CompletionRequest) -> Response:
        """Given a prompt, the model will return one or more predicted completions,
        and can also return the probabilities of alternative tokens at each position.

        Returns:
            A response object with completions.
        """
        async with async_timeout.timeout(RAYLLM_ROUTER_HTTP_TIMEOUT):
            results = self._get_response(body=body, call_method="completions")
            if body.stream:
                first_response, wrapper = await _peek_at_openai_json_generator(results)
                if isinstance(first_response, ErrorResponse):
                    raise OpenAIHTTPException(
                        message=first_response.message,
                        status_code=first_response.code,
                        type=first_response.type,
                    )
                return StreamingResponse(wrapper, media_type="text/event-stream")

            result = await results.__anext__()
            if isinstance(result, ErrorResponse):
                raise OpenAIHTTPException(
                    message=result.message,
                    status_code=result.code,
                    type=result.type,
                )

            if isinstance(result, CompletionResponse):
                return JSONResponse(content=result.model_dump())

    @fastapi_router_app.post("/v1/chat/completions")
    async def chat(self, body: ChatCompletionRequest) -> Response:
        """Given a prompt, the model will return one or more predicted completions,
        and can also return the probabilities of alternative tokens at each position.

        Returns:
            A response object with completions.
        """

        async with async_timeout.timeout(RAYLLM_ROUTER_HTTP_TIMEOUT):
            results = self._get_response(body=body, call_method="chat")
            if body.stream:
                first_response, wrapper = await _peek_at_openai_json_generator(results)
                if isinstance(first_response, ErrorResponse):
                    raise OpenAIHTTPException(
                        message=first_response.message,
                        status_code=first_response.code,
                        type=first_response.type,
                    )
                return StreamingResponse(wrapper, media_type="text/event-stream")

            result = await results.__anext__()
            if isinstance(result, ErrorResponse):
                raise OpenAIHTTPException(
                    message=result.message,
                    status_code=result.code,
                    type=result.type,
                )

            if isinstance(result, ChatCompletionResponse):
                return JSONResponse(content=result.model_dump())
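
    # Illustrative client usage sketch (same assumptions as the /v1/models
    # example above; not part of the original module):
    #
    #     response = client.chat.completions.create(
    #         model="my-model-id",
    #         messages=[{"role": "user", "content": "Hello!"}],
    #         stream=False,
    #     )
    #     print(response.choices[0].message.content)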

    @classmethod
    def as_deployment(cls) -> serve.Deployment:
        """Converts this class to a Ray Serve deployment with ingress.

        Returns:
            A Ray Serve deployment.
        """
        ingress_cls = serve.ingress(fastapi_router_app)(cls)
        deployment_decorator = serve.deployment(
            # TODO (Kourosh): make this configurable
            autoscaling_config={
                "min_replicas": int(os.environ.get("RAYLLM_ROUTER_MIN_REPLICAS", 0)),
                "initial_replicas": int(
                    os.environ.get("RAYLLM_ROUTER_INITIAL_REPLICAS", 2)
                ),
                "max_replicas": int(os.environ.get("RAYLLM_ROUTER_MAX_REPLICAS", 16)),
                "target_ongoing_requests": int(
                    os.environ.get(
                        "RAYLLM_ROUTER_TARGET_ONGOING_REQUESTS",
                        os.environ.get(
                            "RAYLLM_ROUTER_TARGET_NUM_ONGOING_REQUESTS_PER_REPLICA",
                            200,
                        ),
                    )
                ),
            },
            ray_actor_options=json.loads(
                os.environ.get("RAYLLM_ROUTER_RAY_ACTOR_OPTIONS", "{}")
            ),
            max_ongoing_requests=1000,  # Maximum backlog for a single replica
        )

        deployment_cls = deployment_decorator(ingress_cls)

        return deployment_cls
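

# Illustrative wiring sketch (not part of the original module): how the
# deployment produced by `as_deployment()` might be bound and run.
# `llm_server_apps` is a hypothetical list of bound LLM server deployments
# whose replicas expose the `llm_config`, `chat`, and `completions` methods
# that this router calls:
#
#     router_cls = LLMRouter.as_deployment()
#     router_app = router_cls.bind(llm_deployments=llm_server_apps)
#     serve.run(router_app)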