import asyncio
import copy
import os
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Dict,
List,
Optional,
Type,
TypeVar,
Union,
)
import ray
from ray import serve
from ray._common.utils import import_attr
from ray.llm._internal.serve.constants import (
ENABLE_WORKER_PROCESS_SETUP_HOOK,
ENGINE_START_TIMEOUT_S,
MODEL_RESPONSE_BATCH_TIMEOUT_MS,
RAYLLM_VLLM_ENGINE_CLS_ENV,
)
from ray.llm._internal.serve.core.configs.llm_config import (
DiskMultiplexConfig,
LLMConfig,
)
from ray.llm._internal.serve.core.engine.protocol import LLMEngine
from ray.llm._internal.serve.core.protocol import LLMServerProtocol, RawRequestInfo
from ray.llm._internal.serve.observability.logging import get_logger
from ray.llm._internal.serve.observability.usage_telemetry.usage import (
push_telemetry_report_for_all_models,
)
from ray.llm._internal.serve.utils.batcher import Batcher
from ray.llm._internal.serve.utils.lora_serve_utils import (
LoraModelLoader,
)
from ray.llm._internal.serve.utils.server_utils import (
get_serve_request_id,
)
if TYPE_CHECKING:
from ray.llm._internal.serve.core.configs.openai_api_models import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest,
EmbeddingResponse,
ErrorResponse,
ScoreRequest,
ScoreResponse,
TokenizeRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
)
# Module-level logger for this file.
logger = get_logger(__name__)

# Generic type variable used by the streaming/batching helpers below.
T = TypeVar("T")
def _merge_replica_actor_and_child_actor_bundles(
child_actor_bundles: List[Dict[str, float]],
replica_actor_bundle: Dict[str, float],
) -> List[Dict[str, float]]:
"""Sum up the bundles from replica actor bundles with the first bundle from child actor bundles.
This is because the replica actor will use the first bundle in the list, and we want to collocate the replica actor with the child actor.
So we need to group them together.
So for example:
child_actor_bundles = [{"GPU": 1, "CPU": 1}, {"GPU": 1, "CPU": 1}]
replica_actor_bundle = {"GPU": 0, "CPU": 1, "memory": 100}
return [{"GPU": 1, "CPU": 2, "memory": 100}, {"GPU": 1, "CPU": 1}]
"""
if not child_actor_bundles:
return [copy.copy(replica_actor_bundle)]
if not replica_actor_bundle:
return [copy.copy(bundle) for bundle in child_actor_bundles]
original_first_bundle = child_actor_bundles[0]
bundle_key_set = set(original_first_bundle.keys()) | set(
replica_actor_bundle.keys()
)
merged_first_bundle = {
key: original_first_bundle.get(key, 0) + replica_actor_bundle.get(key, 0)
for key in bundle_key_set
}
return [merged_first_bundle] + [
copy.copy(bundle) for bundle in child_actor_bundles[1:]
]
class LLMServer(LLMServerProtocol):
    """This is a shim layer to decouple the LLM engine from the ingress
    deployment.

    It has a very similar API as the engine. Almost all of the abstractions
    are implemented by the engine. This class just adds a little bit more
    logic on top:

    1. Logic for serve multiplexing (e.g. LoRA loading).
    2. Request id handling from serve context.
    3. Batching in case of streaming (only for chat and completions).
    4. Telemetry reporting.

    Usage Patterns:

    1. Basic pattern (for testing):
        server = LLMServer.sync_init(llm_config)  # Sync constructor, unstarted
        await server.start()  # Must explicitly start

    2. Async context (default, used by Ray Serve):
        server = await LLMServer(llm_config)  # Async constructor, fully started

    3. Ray Serve deployment:
        # Ray Serve calls the async constructor directly
        deployment = serve.deployment(LLMServer).bind(llm_config)
    """

    # Engine class used when neither the `engine_cls` argument nor the
    # RAYLLM_VLLM_ENGINE_CLS_ENV override is provided. Subclasses may set
    # this; when None, VLLMEngine is imported lazily as the fallback.
    _default_engine_cls = None

    async def __init__(
        self,
        llm_config: LLMConfig,
        *,
        engine_cls: Optional[Type[LLMEngine]] = None,
        model_downloader: Optional[Type[LoraModelLoader]] = None,
    ):
        """Asynchronous constructor that returns a fully started instance.

        This is the default constructor used by Ray Serve deployments.

        Args:
            llm_config: LLMConfig for the model.
            engine_cls: Dependency injection for the vllm engine class.
                Defaults to `VLLMEngine`.
            model_downloader: Dependency injection for the model downloader.
                Defaults to `LoraModelLoader`.
        """
        super().__init__()
        self._init_shared(llm_config, engine_cls, model_downloader)
        # Unlike a plain __init__, this awaits engine startup so the awaited
        # instance is fully ready to serve requests.
        await self.start()

    def _init_shared(
        self,
        llm_config: LLMConfig,
        engine_cls: Optional[Type[LLMEngine]] = None,
        model_downloader: Optional[Type[LoraModelLoader]] = None,
    ):
        """Shared initialization logic between constructors.

        Sets config and engine class, and wires up the multiplex (LoRA)
        loader; does NOT start the engine.
        """
        self._llm_config = llm_config
        self._engine_cls = engine_cls or self._get_default_engine_class()
        # The engine instance is created in start(), not here.
        self.engine: Optional[LLMEngine] = None
        self._init_multiplex_loader(model_downloader)

    @classmethod
    def sync_init(
        cls,
        llm_config: LLMConfig,
        *,
        engine_cls: Optional[Type[LLMEngine]] = None,
        model_downloader: Optional[Type[LoraModelLoader]] = None,
    ) -> "LLMServer":
        """Synchronous constructor that returns an unstarted instance.

        This is used for testing the new pattern where initialization
        and starting are explicitly separated.

        Args:
            llm_config: LLMConfig for the model.
            engine_cls: Dependency injection for the vllm engine class.
                Defaults to `VLLMEngine`.
            model_downloader: Dependency injection for the model downloader.
                Defaults to `LoraModelLoader`.

        Returns:
            An unstarted LLMServer instance. Caller must call await start().
        """
        # __new__ + explicit base-class __init__ bypasses this class's async
        # __init__ (which would start the engine).
        instance = cls.__new__(cls)
        LLMServerProtocol.__init__(instance)
        instance._init_shared(llm_config, engine_cls, model_downloader)
        return instance

    async def start(self):
        """Start the underlying engine. This handles async initialization.

        Raises:
            asyncio.TimeoutError: if the engine does not come up within
                ENGINE_START_TIMEOUT_S seconds.
        """
        if self._engine_cls is not None:
            self.engine = self._engine_cls(self._llm_config)
            await asyncio.wait_for(self._start_engine(), timeout=ENGINE_START_TIMEOUT_S)

    async def __serve_build_asgi_app__(self):
        # Serve hook: delegate ASGI app construction to the engine.
        return await self.engine.build_asgi_app()

    def _init_multiplex_loader(
        self, model_downloader_cls: Optional[Type[LoraModelLoader]] = None
    ):
        """Initialize the multiplex loader.

        When a LoRA/multiplex config is present, wraps the model download in
        serve.multiplexed so Serve caches up to max_num_models_per_replica
        models per replica; otherwise installs a loader that always raises.
        """
        model_downloader_cls = model_downloader_cls or LoraModelLoader
        mx_config = self._llm_config.multiplex_config()
        if mx_config is not None:
            model_downloader = model_downloader_cls(
                download_timeout_s=mx_config.download_timeout_s,
                max_tries=mx_config.max_download_tries,
            )

            async def _load_model(lora_model_id: str) -> DiskMultiplexConfig:
                return await model_downloader.load_model_from_config(
                    lora_model_id=lora_model_id,
                    llm_config=self._llm_config,
                )

            self._load_model = serve.multiplexed(
                max_num_models_per_replica=mx_config.max_num_models_per_replica
            )(_load_model)
        else:

            async def _load_model(lora_model_id: str) -> DiskMultiplexConfig:
                raise ValueError("LoRA config is not set in the LLMConfig")

            self._load_model = _load_model

    def _get_default_engine_class(self) -> Type[LLMEngine]:
        """Helper to load the engine class from the environment variable.

        This is used for testing or escape-hatch for patching purposes.
        If env variable is not set, it will fallback to the default engine
        class (VLLMEngine, imported lazily to avoid a hard module-level
        dependency).
        """
        engine_cls_path = os.environ.get(RAYLLM_VLLM_ENGINE_CLS_ENV)
        if engine_cls_path:
            return import_attr(engine_cls_path)
        if self._default_engine_cls is not None:
            return self._default_engine_cls
        # Lazy import so this module does not require vLLM at import time.
        from ray.llm._internal.serve.engines.vllm.vllm_engine import VLLMEngine

        return VLLMEngine

    async def _start_engine(self):
        """Start the engine and report telemetry.

        Raises:
            ValueError: if start() has not created an engine instance.
        """
        if self.engine is None:
            raise ValueError("Engine is not set")

        await self.engine.start()

        # Push telemetry reports for the model in the current deployment.
        push_telemetry_report_for_all_models(all_models=[self._llm_config])

    def _get_batch_interval_ms(self, stream: bool = True) -> Optional[int]:
        """Calculate the batching interval for responses.

        Returns the configured (or default) interval in milliseconds when
        streaming, and None for non-streaming requests (no batching).
        """
        stream_batching_interval_ms = self._llm_config.experimental_configs.get(
            "stream_batching_interval_ms"
        )
        if stream_batching_interval_ms is None:
            stream_batching_interval_ms = MODEL_RESPONSE_BATCH_TIMEOUT_MS
        return stream_batching_interval_ms if stream else None

    async def _maybe_add_request_id_to_request(
        self,
        request: Union[
            "ChatCompletionRequest",
            "CompletionRequest",
            "EmbeddingRequest",
            "TranscriptionRequest",
        ],
    ):
        """Add the Serve request id to the request, if one is present."""
        request_id = get_serve_request_id()
        if request_id:
            request.request_id = request_id

    async def _maybe_resolve_lora_from_multiplex(self) -> None:
        """Handle the lora model for the request.

        If Serve routed this request with a multiplexed model id, download
        (or fetch from cache) the LoRA weights and register them with the
        engine.

        Raises:
            ValueError: if a multiplexed model id is present but no LoRA
                config was set on the LLMConfig.
        """
        multiplexed_model_id = serve.get_multiplexed_model_id()
        if multiplexed_model_id:
            if self._llm_config.lora_config is None:
                raise ValueError("Must setup lora config for multiplexed requests.")
            disk_lora_model = await self._load_model(multiplexed_model_id)
            await self.engine.resolve_lora(disk_lora_model)

    def _batch_output_stream(
        self, generator: AsyncGenerator[T, None]
    ) -> AsyncGenerator[List[T], None]:
        # Wrap the engine's stream so items are grouped into lists flushed
        # every _get_batch_interval_ms() milliseconds.
        return Batcher(
            generator,
            interval_ms=self._get_batch_interval_ms(),
        ).stream()

    async def _run_request(
        self,
        request: Union[
            "ChatCompletionRequest",
            "CompletionRequest",
            "EmbeddingRequest",
            "TranscriptionRequest",
            "ScoreRequest",
        ],
        *,
        engine_method: str,
        batch_output_stream: bool = False,
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[Any, None]:
        """Run the engine method on the request + perform batching when stream=True.

        Args:
            request: The request to run.
            engine_method: The method to call on the engine.
            batch_output_stream: Whether to batch the output stream.
            raw_request_info: Optional RawRequestInfo containing data from the
                original HTTP request.

        Returns:
            An AsyncGenerator of the response. If stream is True and batching
            is enabled, then the generator will yield a list of streaming
            responses (strings of the format data: {response_json}\n\n).
            Otherwise, it will yield the non-streaming response from engine
            directly.
        """
        await self._maybe_add_request_id_to_request(request)
        await self._maybe_resolve_lora_from_multiplex()

        # Not every request type has a `stream` field (e.g. embeddings).
        is_stream = hasattr(request, "stream") and request.stream
        # Dispatch to the engine method by name (chat/completions/...).
        engine_stream = getattr(self.engine, engine_method)(request, raw_request_info)

        if is_stream and batch_output_stream:
            stream = self._batch_output_stream(engine_stream)
        else:
            stream = engine_stream

        return stream

    async def chat(
        self,
        request: "ChatCompletionRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[
        Union[List[Union[str, "ErrorResponse"]], "ChatCompletionResponse"], None
    ]:
        """Runs a chat request to the LLM engine and returns the response.

        Args:
            request: A ChatCompletionRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the
                original HTTP request.

        Returns:
            An AsyncGenerator of the response. If stream is True and batching
            is enabled, then the generator will yield a list of chat streaming
            responses (strings of the format data: {response_json}\\n\\n).
            Otherwise, it will yield the ChatCompletionResponse object directly.
        """
        return await self._run_request(
            request,
            engine_method="chat",
            batch_output_stream=True,
            raw_request_info=raw_request_info,
        )

    async def completions(
        self,
        request: "CompletionRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[
        Union[List[Union[str, "ErrorResponse"]], "CompletionResponse"], None
    ]:
        """Runs a completion request to the LLM engine and returns the response.

        Args:
            request: A CompletionRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the
                original HTTP request.

        Returns:
            An AsyncGenerator of the response. If stream is True and batching
            is enabled, then the generator will yield a list of completion
            streaming responses (strings of the format data: {response_json}\\n\\n).
            Otherwise, it will yield the CompletionResponse object directly.
        """
        return await self._run_request(
            request,
            engine_method="completions",
            batch_output_stream=True,
            raw_request_info=raw_request_info,
        )

    async def embeddings(
        self,
        request: "EmbeddingRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[Union[List["ErrorResponse"], "EmbeddingResponse"], None]:
        """Runs an embeddings request to the engine and returns the response.

        Returns an AsyncGenerator over the EmbeddingResponse object. This is so
        that the caller can have a consistent interface across all the methods
        of chat, completions, embeddings and transcriptions.

        Args:
            request: An EmbeddingRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the
                original HTTP request.

        Returns:
            An AsyncGenerator over the EmbeddingResponse object.
        """
        # NOTE: Embeddings does not need batching.
        return await self._run_request(
            request,
            engine_method="embeddings",
            batch_output_stream=False,
            raw_request_info=raw_request_info,
        )

    async def transcriptions(
        self,
        request: "TranscriptionRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[
        Union[List[Union[str, "ErrorResponse"]], "TranscriptionResponse"], None
    ]:
        """Runs a transcriptions request to the engine and returns the response.

        Returns an AsyncGenerator over the TranscriptionResponse object. This
        is so that the caller can have a consistent interface across all the
        methods of chat, completions, embeddings and transcriptions.

        Args:
            request: A TranscriptionRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the
                original HTTP request.

        Returns:
            An AsyncGenerator over the TranscriptionResponse object.
        """
        return await self._run_request(
            request,
            engine_method="transcriptions",
            batch_output_stream=True,
            raw_request_info=raw_request_info,
        )

    async def score(
        self,
        request: "ScoreRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[Union["ScoreResponse", "ErrorResponse"], None]:
        """Runs a score request to the engine and returns the response.

        Returns an AsyncGenerator over the ScoreResponse object. This is so
        that the caller can have a consistent interface across all the methods
        of chat, completions, embeddings, and score.

        Args:
            request: A ScoreRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the
                original HTTP request.

        Returns:
            An AsyncGenerator over the ScoreResponse object.
        """
        # NOTE: Score does not need batching, similar to embeddings.
        return await self._run_request(
            request,
            engine_method="score",
            batch_output_stream=False,
            raw_request_info=raw_request_info,
        )

    async def tokenize(
        self,
        request: "TokenizeRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[Union["TokenizeResponse", "ErrorResponse"], None]:
        """Tokenize the input text.

        Args:
            request: A TokenizeRequest object (TokenizeCompletionRequest or
                TokenizeChatRequest).
            raw_request_info: Optional RawRequestInfo containing data from the
                original HTTP request.

        Returns:
            An AsyncGenerator over the TokenizeResponse object.
        """
        # NOTE: Tokenize does not need batching.
        return await self._run_request(
            request,
            engine_method="tokenize",
            batch_output_stream=False,
            raw_request_info=raw_request_info,
        )

    async def detokenize(
        self,
        request: "DetokenizeRequest",
        raw_request_info: Optional[RawRequestInfo] = None,
    ) -> AsyncGenerator[Union["DetokenizeResponse", "ErrorResponse"], None]:
        """Detokenize the input token IDs.

        Args:
            request: A DetokenizeRequest object.
            raw_request_info: Optional RawRequestInfo containing data from the
                original HTTP request.

        Returns:
            An AsyncGenerator over the DetokenizeResponse object.
        """
        # NOTE: Detokenize does not need batching.
        return await self._run_request(
            request,
            engine_method="detokenize",
            batch_output_stream=False,
            raw_request_info=raw_request_info,
        )

    async def check_health(self) -> None:
        """
        Check the health of the replica. Does not return anything. Raise error
        when the engine is dead and needs to be restarted.
        """
        # Engine may not be started yet (sync_init path); treat as healthy.
        if self.engine is None:
            return
        try:
            return await self.engine.check_health()
        except Exception as e:
            logger.error("Engine health check failed in LLMServer.check_health: %s", e)
            raise e

    async def sleep(self, **kwargs: Any) -> None:
        """Put the engine to sleep.

        Args:
            **kwargs: Engine-specific sleep options. Passed through to the
                engine.
        """
        # No engine -> nothing to do (best-effort no-op).
        if self.engine is None:
            return
        try:
            await self.engine.sleep(**kwargs)
        except Exception as e:
            logger.error("Engine sleep failed in LLMServer.sleep: %s", e)
            raise e

    async def wakeup(self, **kwargs: Any) -> None:
        """Wake up the engine from sleep mode.

        Args:
            **kwargs: Engine-specific wakeup options. Passed through to the
                engine.
        """
        if self.engine is None:
            return
        try:
            await self.engine.wakeup(**kwargs)
        except Exception as e:
            logger.error("Engine wakeup failed in LLMServer.wakeup: %s", e)
            raise e

    async def is_sleeping(self) -> bool:
        """Check whether the engine is currently sleeping.

        Returns:
            True if the engine is sleeping, False otherwise.
        """
        # An unstarted engine is reported as not sleeping.
        if self.engine is None:
            return False
        try:
            return await self.engine.is_sleeping()
        except Exception as e:
            logger.error("Engine is_sleeping failed in LLMServer.is_sleeping: %s", e)
            raise e

    async def reset_prefix_cache(self) -> None:
        """Reset the KV prefix cache on the engine.

        Clears cached key-value pairs from previous requests.
        """
        if self.engine is None:
            return
        try:
            await self.engine.reset_prefix_cache()
        except Exception as e:
            logger.error(
                "Engine reset_prefix_cache failed in LLMServer.reset_prefix_cache: %s",
                e,
            )
            raise e

    async def pause(self, **kwargs: Any) -> None:
        """Pause generation on the engine.

        This halts generation requests while keeping model weights
        in GPU memory. New requests are blocked until resume is called.

        Args:
            **kwargs: Engine-specific pause options. Passed through to the
                engine.
        """
        if self.engine is None:
            return
        try:
            await self.engine.pause(**kwargs)
        except Exception as e:
            logger.error("Engine pause failed in LLMServer.pause: %s", e)
            raise e

    async def resume(self, **kwargs: Any) -> None:
        """Resume generation on the engine after pause.

        Args:
            **kwargs: Engine-specific resume options. Passed through to the
                engine.
        """
        if self.engine is None:
            return
        try:
            await self.engine.resume(**kwargs)
        except Exception as e:
            logger.error("Engine resume failed in LLMServer.resume: %s", e)
            raise e

    async def is_paused(self) -> bool:
        """Check whether the engine is currently paused.

        Returns:
            True if the engine is paused, False otherwise.
        """
        # An unstarted engine is reported as not paused.
        if self.engine is None:
            return False
        try:
            return await self.engine.is_paused()
        except Exception as e:
            logger.error("Engine is_paused failed in LLMServer.is_paused: %s", e)
            raise e

    async def start_profile(self) -> None:
        """Start profiling"""
        if self.engine is None:
            return
        try:
            await self.engine.start_profile()
        except Exception as e:
            logger.error(
                "Engine start profile failed in LLMServer.start_profile: %s", e
            )
            raise e

    async def stop_profile(self) -> None:
        """Stop profiling"""
        if self.engine is None:
            return
        try:
            await self.engine.stop_profile()
        except Exception as e:
            logger.error("Engine stop profile failed in LLMServer.stop_profile: %s", e)
            raise e

    async def collective_rpc(
        self,
        method: str,
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict] = None,
    ) -> list:
        """Execute a collective RPC call on all workers.

        This is used for RLHF workflows where a trainer needs to execute
        methods on all TP/PP workers (e.g., for weight synchronization).

        Args:
            method: Name of the worker method to execute.
            timeout: Maximum time in seconds to wait for execution.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
        """
        # No engine -> no workers to call; return an empty result list.
        if self.engine is None:
            return []
        try:
            return await self.engine.collective_rpc(
                method=method,
                timeout=timeout,
                args=args,
                kwargs=kwargs,
            )
        except Exception as e:
            logger.error(
                "Engine collective_rpc failed in LLMServer.collective_rpc: %s", e
            )
            raise e

    async def llm_config(self) -> Optional[LLMConfig]:
        """Return the LLMConfig this server was constructed with."""
        return self._llm_config

    @classmethod
    def get_deployment_options(cls, llm_config: "LLMConfig"):
        """Build the Serve deployment options for this server.

        Derives placement-group bundles/strategy from the engine config
        (unless the accelerator requires deferred placement groups) and
        merges runtime_env settings into ray_actor_options.

        Args:
            llm_config: The LLMConfig to derive deployment options from.

        Returns:
            A dict of deployment options suitable for serve.deployment.

        Raises:
            ValueError: if the caller pre-set placement_group_bundles or
                placement_group_strategy in deployment_config.
        """
        engine_config = llm_config.get_engine_config()
        # Deep copy so we never mutate the user's deployment_config.
        deployment_options = copy.deepcopy(llm_config.deployment_config)

        if (
            "placement_group_bundles" in llm_config.deployment_config
            or "placement_group_strategy" in llm_config.deployment_config
        ):
            raise ValueError(
                "placement_group_bundles and placement_group_strategy must not be specified in deployment_config. You can override the default values by setting the `placement_group_config` in the LLMConfig."
            )

        # Handle the ray_actor_options that could be passed in to
        # deployment_options
        ray_actor_options = deployment_options.get("ray_actor_options", {})

        if not engine_config.accelerator.requires_deferred_placement_group:
            # Resources consumed by the replica actor itself; these get merged
            # into the first child-actor bundle so the two are collocated.
            replica_actor_resources = {
                "CPU": ray_actor_options.get("num_cpus", 1),
                "GPU": ray_actor_options.get("num_gpus", 0),
                **ray_actor_options.get("resources", {}),
            }
            if "memory" in ray_actor_options:
                replica_actor_resources["memory"] = ray_actor_options["memory"]

            # TODO: Move this _merge_replica_actor_and_child_actor_bundles to a
            # more generic place.
            pg_bundles = _merge_replica_actor_and_child_actor_bundles(
                engine_config.placement_bundles, replica_actor_resources
            )

            deployment_options.update(
                {
                    "placement_group_bundles": pg_bundles,
                    "placement_group_strategy": engine_config.placement_strategy,
                }
            )

        # Handle env vars from runtime_env
        default_runtime_env = ray.get_runtime_context().runtime_env
        if ENABLE_WORKER_PROCESS_SETUP_HOOK:
            default_runtime_env[
                "worker_process_setup_hook"
            ] = "ray.llm._internal.serve._worker_process_setup_hook"

        ray_actor_options = deployment_options.get("ray_actor_options", {})
        ray_actor_options["runtime_env"] = {
            **default_runtime_env,
            # Existing runtime_env should take precedence over the default.
            **ray_actor_options.get("runtime_env", {}),
            **(llm_config.runtime_env if llm_config.runtime_env else {}),
        }
        deployment_options["ray_actor_options"] = ray_actor_options

        return deployment_options