import asyncio
import copy
import os
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
List,
Optional,
Type,
TypeVar,
Union,
)
import ray
from ray import serve
from ray._common.utils import import_attr
from ray.llm._internal.serve.constants import (
ENABLE_WORKER_PROCESS_SETUP_HOOK,
ENGINE_START_TIMEOUT_S,
MODEL_RESPONSE_BATCH_TIMEOUT_MS,
RAYLLM_VLLM_ENGINE_CLS_ENV,
)
from ray.llm._internal.serve.core.configs.llm_config import (
DiskMultiplexConfig,
LLMConfig,
)
from ray.llm._internal.serve.core.engine.protocol import LLMEngine
from ray.llm._internal.serve.core.protocol import LLMServerProtocol, RawRequestInfo
from ray.llm._internal.serve.engines.vllm.vllm_engine import VLLMEngine
from ray.llm._internal.serve.observability.logging import get_logger
from ray.llm._internal.serve.observability.usage_telemetry.usage import (
push_telemetry_report_for_all_models,
)
from ray.llm._internal.serve.utils.batcher import Batcher
from ray.llm._internal.serve.utils.lora_serve_utils import (
LoraModelLoader,
)
from ray.llm._internal.serve.utils.server_utils import (
get_serve_request_id,
)
# The OpenAI-style request/response models are referenced in this module only
# through string annotations, so they are imported for static type checkers
# alone and never at runtime.
if TYPE_CHECKING:
    from ray.llm._internal.serve.core.configs.openai_api_models import (
        ChatCompletionRequest,
        ChatCompletionResponse,
        CompletionRequest,
        CompletionResponse,
        EmbeddingRequest,
        EmbeddingResponse,
        ErrorResponse,
        ScoreRequest,
        ScoreResponse,
        TranscriptionRequest,
        TranscriptionResponse,
    )

# Module-level logger, named after this module.
logger = get_logger(__name__)

# Generic item type used in the annotations of the stream-batching helper.
T = TypeVar("T")
def _merge_replica_actor_and_child_actor_bundles(
child_actor_bundles: List[Dict[str, float]],
replica_actor_bundle: Dict[str, float],
) -> List[Dict[str, float]]:
"""Sum up the bundles from replica actor bundles with the first bundle from child actor bundles.
This is because the replica actor will use the first bundle in the list, and we want to collocate the replica actor with the child actor.
So we need to group them together.
So for example:
child_actor_bundles = [{"GPU": 1, "CPU": 1}, {"GPU": 1, "CPU": 1}]
replica_actor_bundle = {"GPU": 0, "CPU": 1, "memory": 100}
return [{"GPU": 1, "CPU": 2, "memory": 100}, {"GPU": 1, "CPU": 1}]
"""
if not child_actor_bundles:
return [copy.copy(replica_actor_bundle)]
if not replica_actor_bundle:
return [copy.copy(bundle) for bundle in child_actor_bundles]
original_first_bundle = child_actor_bundles[0]
bundle_key_set = set(original_first_bundle.keys()) | set(
replica_actor_bundle.keys()
)
merged_first_bundle = {
key: original_first_bundle.get(key, 0) + replica_actor_bundle.get(key, 0)
for key in bundle_key_set
}
return [merged_first_bundle] + [
copy.copy(bundle) for bundle in child_actor_bundles[1:]
]
class LLMServer(LLMServerProtocol):
    """This is a shim layer to decouple the LLM engine from the ingress
    deployment.

    It has a very similar API as the engine. Almost all of the abstractions are
    implemented by the engine. This class just adds a little bit more logic on
    top:

    1. Logic for serve multiplexing (e.g. LoRA loading).
    2. Request id handling from serve context.
    3. Batching in case of streaming (for chat, completions and
       transcriptions).
    4. Telemetry reporting.

    Usage Patterns:

    1. Basic pattern (for testing):
        server = LLMServer.sync_init(llm_config)  # Sync constructor, unstarted
        await server.start()  # Must explicitly start

    2. Async context (default, used by Ray Serve):
        server = await LLMServer(llm_config)  # Async constructor, fully started

    3. Ray Serve deployment:
        # Ray Serve calls the async constructor directly
        deployment = serve.deployment(LLMServer).bind(llm_config)
    """

    # Engine class used when neither `engine_cls` nor the override environment
    # variable is provided.
    _default_engine_cls = VLLMEngine

    async def __init__(
        self,
        llm_config: LLMConfig,
        *,
        engine_cls: Optional[Type[LLMEngine]] = None,
        model_downloader: Optional[Type[LoraModelLoader]] = None,
    ):
        """Asynchronous constructor that returns a fully started instance.

        This is the default constructor used by Ray Serve deployments.

        Args:
            llm_config: LLMConfig for the model.
            engine_cls: Dependency injection for the vllm engine class.
                Defaults to `VLLMEngine`.
            model_downloader: Dependency injection for the model downloader.
                Defaults to `LoraModelLoader`.
        """
        super().__init__()
        self._init_shared(llm_config, engine_cls, model_downloader)
        await self.start()

    def _init_shared(
        self,
        llm_config: LLMConfig,
        engine_cls: Optional[Type[LLMEngine]] = None,
        model_downloader: Optional[Type[LoraModelLoader]] = None,
    ):
        """Shared initialization logic between constructors.

        Stores the config, resolves the engine class (without instantiating
        it) and sets up the LoRA multiplex loader.
        """
        self._llm_config = llm_config
        self._engine_cls = engine_cls or self._get_default_engine_class()
        # The engine instance is only created in `start()`.
        self.engine: Optional[LLMEngine] = None
        self._init_multiplex_loader(model_downloader)

    @classmethod
    def sync_init(
        cls,
        llm_config: LLMConfig,
        *,
        engine_cls: Optional[Type[LLMEngine]] = None,
        model_downloader: Optional[Type[LoraModelLoader]] = None,
    ) -> "LLMServer":
        """Synchronous constructor that returns an unstarted instance.

        This is used for testing the new pattern where initialization
        and starting are explicitly separated.

        Args:
            llm_config: LLMConfig for the model.
            engine_cls: Dependency injection for the vllm engine class.
                Defaults to `VLLMEngine`.
            model_downloader: Dependency injection for the model downloader.
                Defaults to `LoraModelLoader`.

        Returns:
            An unstarted LLMServer instance. Caller must call await start().
        """
        # `__init__` is a coroutine function, so bypass it and run the
        # synchronous parts of the initialization by hand.
        instance = cls.__new__(cls)
        LLMServerProtocol.__init__(instance)
        instance._init_shared(llm_config, engine_cls, model_downloader)
        return instance

    async def start(self):
        """Start the underlying engine. This handles async initialization.

        Instantiates the engine class with the LLMConfig and waits for it to
        start, bounded by `ENGINE_START_TIMEOUT_S`.
        """
        if self._engine_cls is not None:
            self.engine = self._engine_cls(self._llm_config)
            await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S)

    def _init_multiplex_loader(
        self, model_downloader_cls: Optional[Type[LoraModelLoader]] = None
    ):
        """Initialize the multiplex loader.

        Sets `self._load_model` to an async callable taking a LoRA model id.
        When the LLMConfig has a multiplex config, the callable loads the
        model through the downloader and is wrapped with `serve.multiplexed`;
        otherwise it always raises `ValueError`.

        Args:
            model_downloader_cls: Downloader class to instantiate. Defaults to
                `LoraModelLoader`.
        """
        model_downloader_cls = model_downloader_cls or LoraModelLoader
        mx_config = self._llm_config.multiplex_config()

        if mx_config is not None:
            model_downloader = model_downloader_cls(
                download_timeout_s=mx_config.download_timeout_s,
                max_tries=mx_config.max_download_tries,
            )

            async def _load_model(lora_model_id: str) -> DiskMultiplexConfig:
                return await model_downloader.load_model_from_config(
                    lora_model_id=lora_model_id,
                    llm_config=self._llm_config,
                )

            self._load_model = serve.multiplexed(
                max_num_models_per_replica=mx_config.max_num_models_per_replica
            )(_load_model)
        else:

            async def _load_model(lora_model_id: str) -> DiskMultiplexConfig:
                raise ValueError("LoRA config is not set in the LLMConfig")

            self._load_model = _load_model

    def _get_default_engine_class(self) -> Type[LLMEngine]:
        """Helper to load the engine class from the environment variable.

        This is used for testing or escape-hatch for patching purposes.
        If env variable is not set, it will fallback to the default engine class.
        """
        engine_cls_path = os.environ.get(RAYLLM_VLLM_ENGINE_CLS_ENV)
        if engine_cls_path:
            return import_attr(engine_cls_path)
        return self._default_engine_cls

    async def _start_engine(self):
        """Start the engine and push the telemetry report for this model.

        Raises:
            ValueError: If no engine instance has been created yet.
        """
        if self.engine is None:
            raise ValueError("Engine is not set")

        await self.engine.start()

        # Push telemetry reports for the model in the current deployment.
        push_telemetry_report_for_all_models(all_models=[self._llm_config])

    def _get_batch_interval_ms(self, stream: bool = True) -> Optional[int]:
        """Calculate the batching interval for responses.

        Args:
            stream: Whether the response is streamed.

        Returns:
            The `stream_batching_interval_ms` experimental config if set,
            otherwise `MODEL_RESPONSE_BATCH_TIMEOUT_MS`. `None` when `stream`
            is false.
        """
        stream_batching_interval_ms = self._llm_config.experimental_configs.get(
            "stream_batching_interval_ms"
        )
        if stream_batching_interval_ms is None:
            stream_batching_interval_ms = MODEL_RESPONSE_BATCH_TIMEOUT_MS
        return stream_batching_interval_ms if stream else None

    async def _maybe_add_request_id_to_request(
        self,
        request: Union[
            "ChatCompletionRequest",
            "CompletionRequest",
            "EmbeddingRequest",
            "TranscriptionRequest",
            "ScoreRequest",
        ],
    ):
        """Add the request id to the request.

        The request is only modified when the serve context provides a
        non-empty request id.
        """
        request_id = get_serve_request_id()
        if request_id:
            request.request_id = request_id

    async def _maybe_resolve_lora_from_multiplex(self) -> None:
        """Handle the lora model for the request.

        Does nothing unless the serve context carries a multiplexed model id.

        Raises:
            ValueError: If the request is multiplexed but the LLMConfig has no
                lora config.
        """
        multiplexed_model_id = serve.get_multiplexed_model_id()
        if multiplexed_model_id:
            if self._llm_config.lora_config is None:
                raise ValueError("Must setup lora config for multiplexed requests.")
            disk_lora_model = await self._load_model(multiplexed_model_id)
            await self.engine.resolve_lora(disk_lora_model)

    def _batch_output_stream(
        self, generator: AsyncGenerator[T, None]
    ) -> AsyncGenerator[List[T], None]:
        """Wrap `generator` in a `Batcher` using the streaming batch interval."""
        return Batcher(
            generator,
            interval_ms=self._get_batch_interval_ms(),
        ).stream()

    async def _run_request(
        self,
        request: Union[
            "ChatCompletionRequest",
            "CompletionRequest",
            "EmbeddingRequest",
            "TranscriptionRequest",
            "ScoreRequest",
        ],
        *,
        engine_method: str,
        batch_output_stream: bool = False,
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[Any, None]:
        """Run the engine method on the request + perform batching when stream=True.

        Args:
            request: The request to run.
            engine_method: The method to call on the engine.
            batch_output_stream: Whether to batch the output stream.
            raw_request_info: Optional RawRequestInfo containing data from the original
                HTTP request.

        Returns:
            An AsyncGenerator of the response. If stream is True and batching
            is enabled, then the generator will yield a list of streaming
            responses (strings of the format data: {response_json}\\n\\n).
            Otherwise, it will yield the non-streaming response from engine
            directly.

        Raises:
            ValueError: If the engine has not been created yet (the server
                was never started).
        """
        # Fail early with a clear message instead of an AttributeError on
        # None further down, and before any LoRA model gets loaded.
        if self.engine is None:
            raise ValueError("Engine is not set")

        await self._maybe_add_request_id_to_request(request)
        await self._maybe_resolve_lora_from_multiplex()

        # Not every request type has a `stream` attribute.
        is_stream = hasattr(request, "stream") and request.stream
        engine_stream = getattr(self.engine, engine_method)(request, raw_request_info)

        if is_stream and batch_output_stream:
            stream = self._batch_output_stream(engine_stream)
        else:
            stream = engine_stream

        return stream

    async def chat(
        self,
        request: "ChatCompletionRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[
        Union[List[Union[str, "ErrorResponse"]], "ChatCompletionResponse"], None
    ]:
        """Runs a chat request to the LLM engine and returns the response.

        Args:
            request: A ChatCompletionRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the original
                HTTP request.

        Returns:
            An AsyncGenerator of the response. If stream is True and batching
            is enabled, then the generator will yield a list of chat streaming
            responses (strings of the format data: {response_json}\\n\\n).
            Otherwise, it will yield the ChatCompletionResponse object directly.
        """
        return await self._run_request(
            request,
            engine_method="chat",
            batch_output_stream=True,
            raw_request_info=raw_request_info,
        )

    async def completions(
        self,
        request: "CompletionRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[
        Union[List[Union[str, "ErrorResponse"]], "CompletionResponse"], None
    ]:
        """Runs a completion request to the LLM engine and returns the response.

        Args:
            request: A CompletionRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the original
                HTTP request.

        Returns:
            An AsyncGenerator of the response. If stream is True and batching
            is enabled, then the generator will yield a list of completion
            streaming responses (strings of the format data: {response_json}\\n\\n).
            Otherwise, it will yield the CompletionResponse object directly.
        """
        return await self._run_request(
            request,
            engine_method="completions",
            batch_output_stream=True,
            raw_request_info=raw_request_info,
        )

    async def embeddings(
        self,
        request: "EmbeddingRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[Union[List["ErrorResponse"], "EmbeddingResponse"], None]:
        """Runs an embeddings request to the engine and returns the response.

        Returns an AsyncGenerator over the EmbeddingResponse object. This is so
        that the caller can have a consistent interface across all the methods
        of chat, completions, embeddings and transcriptions.

        Args:
            request: An EmbeddingRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the original
                HTTP request.

        Returns:
            An AsyncGenerator over the EmbeddingResponse object.
        """
        # NOTE: Embeddings does not need batching.
        return await self._run_request(
            request,
            engine_method="embeddings",
            batch_output_stream=False,
            raw_request_info=raw_request_info,
        )

    async def transcriptions(
        self,
        request: "TranscriptionRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[
        Union[List[Union[str, "ErrorResponse"]], "TranscriptionResponse"], None
    ]:
        """Runs a transcriptions request to the engine and returns the response.

        Returns an AsyncGenerator over the TranscriptionResponse object. This is
        so that the caller can have a consistent interface across all the
        methods of chat, completions, embeddings and transcriptions.

        Args:
            request: A TranscriptionRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the original
                HTTP request.

        Returns:
            An AsyncGenerator over the TranscriptionResponse object.
        """
        return await self._run_request(
            request,
            engine_method="transcriptions",
            batch_output_stream=True,
            raw_request_info=raw_request_info,
        )

    async def score(
        self,
        request: "ScoreRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[Union["ScoreResponse", "ErrorResponse"], None]:
        """Runs a score request to the engine and returns the response.

        Returns an AsyncGenerator over the ScoreResponse object. This is so that
        the caller can have a consistent interface across all the methods of
        chat, completions, embeddings, and score.

        Args:
            request: A ScoreRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the original
                HTTP request.

        Returns:
            An AsyncGenerator over the ScoreResponse object.
        """
        # NOTE: Score does not need batching, similar to embeddings.
        return await self._run_request(
            request,
            engine_method="score",
            batch_output_stream=False,
            raw_request_info=raw_request_info,
        )

    async def check_health(self) -> None:
        """
        Check the health of the replica. Does not return anything. Raise error when
        the engine is dead and needs to be restarted.
        """
        # No engine has been created, so there is nothing to check.
        if self.engine is None:
            return
        try:
            return await self.engine.check_health()
        except Exception as e:
            logger.error("Engine health check failed in LLMServer.check_health: %s", e)
            raise

    async def reset_prefix_cache(self) -> None:
        """Reset the prefix cache of the underlying engine.

        No-op when no engine has been created. Engine errors are logged and
        re-raised.
        """
        if self.engine is None:
            return
        try:
            await self.engine.reset_prefix_cache()
        except Exception as e:
            logger.error(
                "Engine reset prefix cache failed in LLMServer.reset_prefix_cache: %s",
                e,
            )
            raise

    async def start_profile(self) -> None:
        """Start profiling.

        No-op when no engine has been created. Engine errors are logged and
        re-raised.
        """
        if self.engine is None:
            return
        try:
            await self.engine.start_profile()
        except Exception as e:
            logger.error(
                "Engine start profile failed in LLMServer.start_profile: %s", e
            )
            raise

    async def stop_profile(self) -> None:
        """Stop profiling.

        No-op when no engine has been created. Engine errors are logged and
        re-raised.
        """
        if self.engine is None:
            return
        try:
            await self.engine.stop_profile()
        except Exception as e:
            logger.error("Engine stop profile failed in LLMServer.stop_profile: %s", e)
            raise

    async def llm_config(self) -> Optional[LLMConfig]:
        """Return the LLMConfig this server was initialized with."""
        return self._llm_config

    @classmethod
    def get_deployment_options(cls, llm_config: "LLMConfig") -> Dict[str, Any]:
        """Build the deployment options for this server from the LLMConfig.

        Starts from a deep copy of `llm_config.deployment_config`, adds the
        placement group bundles/strategy derived from the engine config (with
        the replica actor's resources merged into the first bundle), and
        fills in the `runtime_env` of `ray_actor_options`.

        Args:
            llm_config: LLMConfig for the model.

        Returns:
            The deployment options dict. `llm_config.deployment_config` itself
            is not modified.

        Raises:
            ValueError: If `placement_group_bundles` or
                `placement_group_strategy` is set in the deployment config.
        """
        engine_config = llm_config.get_engine_config()
        deployment_options = copy.deepcopy(llm_config.deployment_config)

        # Handle the ray_actor_options that could be passed in to
        # deployment_options
        ray_actor_options = deployment_options.get("ray_actor_options", {})

        # Resources of the replica actor itself: 1 CPU and 0 GPU unless
        # overridden, plus any custom resources.
        replica_actor_resources = {
            "CPU": ray_actor_options.get("num_cpus", 1),
            "GPU": ray_actor_options.get("num_gpus", 0),
            **ray_actor_options.get("resources", {}),
        }
        if "memory" in ray_actor_options:
            replica_actor_resources["memory"] = ray_actor_options["memory"]

        if (
            "placement_group_bundles" in llm_config.deployment_config
            or "placement_group_strategy" in llm_config.deployment_config
        ):
            raise ValueError(
                "placement_group_bundles and placement_group_strategy must not be specified in deployment_config. You can override the default values by setting the `placement_group_config` in the LLMConfig."
            )

        # TODO: Move this _merge_replica_actor_and_child_actor_bundles to a
        # more generic place.
        pg_bundles = _merge_replica_actor_and_child_actor_bundles(
            engine_config.placement_bundles, replica_actor_resources
        )

        deployment_options.update(
            {
                "placement_group_bundles": pg_bundles,
                "placement_group_strategy": engine_config.placement_strategy,
            }
        )

        # Handle env vars from runtime_env
        default_runtime_env = ray.get_runtime_context().runtime_env
        if ENABLE_WORKER_PROCESS_SETUP_HOOK:
            default_runtime_env[
                "worker_process_setup_hook"
            ] = "ray.llm._internal.serve._worker_process_setup_hook"

        ray_actor_options = deployment_options.get("ray_actor_options", {})
        # Top-level keys are merged with later entries winning: the default
        # runtime_env, then the one from ray_actor_options, then the one from
        # the LLMConfig.
        ray_actor_options["runtime_env"] = {
            **default_runtime_env,
            # Existing runtime_env should take precedence over the default.
            **ray_actor_options.get("runtime_env", {}),
            **(llm_config.runtime_env if llm_config.runtime_env else {}),
        }
        deployment_options["ray_actor_options"] = ray_actor_options

        return deployment_options