Source code for ray.llm._internal.serve.configs.server_models

import pydantic
import os
import ray

from enum import Enum
from ray.llm._internal.serve.configs.error_handling import TooManyStoppingSequences

from typing import (
    Any,
    Dict,
    List,
    Optional,
    Type,
    TypeVar,
    Union,
    Tuple,
    Sequence,
    Set,
)
import time
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    PositiveInt,
    PrivateAttr,
    field_validator,
    model_validator,
)

from ray.llm._internal.utils import try_import


from ray.llm._internal.serve.observability.logging import get_logger
import ray.util.accelerators.accelerators as accelerators

from ray.llm._internal.serve.configs.constants import (
    DEFAULT_MULTIPLEX_DOWNLOAD_TIMEOUT_S,
    DEFAULT_MULTIPLEX_DOWNLOAD_TRIES,
    DEFAULT_TARGET_ONGOING_REQUESTS,
    MAX_NUM_STOPPING_SEQUENCES,
    ENABLE_WORKER_PROCESS_SETUP_HOOK,
)
from ray.llm._internal.serve.configs.prompt_formats import (
    Prompt,
    HuggingFacePromptFormat,
)
from ray.llm._internal.serve.configs.openai_api_models_patch import (
    ErrorResponse,
    ResponseFormatType,
)
from ray.llm._internal.serve.configs.base import BaseModelExtended

transformers = try_import("transformers")


GPUType = Enum("GPUType", vars(accelerators))
ModelT = TypeVar("ModelT", bound=BaseModel)


logger = get_logger(__name__)


class ExtraFiles(BaseModelExtended):
    bucket_uri: str
    destination_path: str


class MirrorConfig(BaseModelExtended):
    bucket_uri: Optional[str] = None
    extra_files: List[ExtraFiles] = Field(default_factory=list)


class S3AWSCredentials(BaseModelExtended):
    create_aws_credentials_url: str
    auth_token_env_variable: Optional[str] = None


class GCSMirrorConfig(MirrorConfig):
    @field_validator("bucket_uri")
    @classmethod
    def check_uri_format(cls, value):
        if value and not value.startswith("gs://"):
            raise ValueError(
                f'Got invalid value "{value}" for bucket_uri. '
                'Expected a URI that starts with "gs://".'
            )
        return value


class S3MirrorConfig(MirrorConfig):
    s3_sync_args: Optional[List[str]] = None
    s3_aws_credentials: Optional[S3AWSCredentials] = None

    @field_validator("bucket_uri")
    @classmethod
    def check_uri_format(cls, value):
        if value and not value.startswith("s3://"):
            raise ValueError(
                f'Got invalid value "{value}" for bucket_uri. '
                'Expected a URI that starts with "s3://".'
            )
        return value
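

# Illustrative usage sketch: how the mirror configs validate bucket URIs.
# The bucket names below are hypothetical placeholders.
def _example_mirror_configs() -> None:
    gcs = GCSMirrorConfig(bucket_uri="gs://my-bucket/llama")
    s3 = S3MirrorConfig(bucket_uri="s3://my-bucket/llama")
    assert gcs.bucket_uri.startswith("gs://")
    assert s3.bucket_uri.startswith("s3://")

    # Using the wrong scheme fails validation in check_uri_format.
    try:
        GCSMirrorConfig(bucket_uri="s3://my-bucket/llama")
    except pydantic.ValidationError:
        pass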


class AutoscalingConfig(BaseModel, extra="allow"):
    """This model provides reasonable defaults for LLM serving.

    Please note that field descriptions may be exposed to end users.
    """

    min_replicas: int = Field(
        1,
        description="min_replicas is the minimum number of replicas for the deployment.",
    )
    initial_replicas: int = Field(
        1,
        description="The number of replicas that are started initially for the deployment.",
    )
    max_replicas: int = Field(
        100,
        description="max_replicas is the maximum number of replicas for the deployment.",
    )
    target_ongoing_requests: Optional[int] = Field(
        None,
        description=(
            "target_ongoing_requests is the maximum number of queries that are "
            "sent to a replica of this deployment without receiving a response."
        ),
    )
    target_num_ongoing_requests_per_replica: Optional[int] = Field(
        None,
        description=(
            "target_num_ongoing_requests_per_replica is deprecated. "
            "If it is set, target_ongoing_requests is set to the same value. "
            "If neither field is set, DEFAULT_TARGET_ONGOING_REQUESTS is used."
        ),
        exclude=True,
    )
    metrics_interval_s: float = Field(
        10.0, description="How often to scrape for metrics, in seconds."
    )
    look_back_period_s: float = Field(
        30.0, description="Time window to average over for metrics, in seconds."
    )
    downscale_delay_s: float = Field(
        300.0, description="How long to wait before scaling down replicas, in seconds."
    )
    upscale_delay_s: float = Field(
        10.0, description="How long to wait before scaling up replicas, in seconds."
    )

    @model_validator(mode="before")
    def sync_target_ongoing_requests(cls, values):
        """Temporary validator that keeps target_ongoing_requests and the
        deprecated target_num_ongoing_requests_per_replica fields in sync.
        """
        target_ongoing_requests = values.get("target_ongoing_requests", None)
        target_num_ongoing_requests_per_replica = values.get(
            "target_num_ongoing_requests_per_replica", None
        )

        final_val = (
            target_ongoing_requests
            or target_num_ongoing_requests_per_replica
            or DEFAULT_TARGET_ONGOING_REQUESTS
        )
        values["target_ongoing_requests"] = final_val
        values["target_num_ongoing_requests_per_replica"] = final_val

        return values


class ServeMultiplexConfig(BaseModelExtended):
    max_num_models_per_replica: PositiveInt = Field(
        ..., description="The maximum number of models to be loaded on each replica."
    )
    download_timeout_s: Optional[float] = Field(
        DEFAULT_MULTIPLEX_DOWNLOAD_TIMEOUT_S,
        description=(
            "How much time the download subprocess has to download a single "
            "LoRA before a timeout. None means no timeout."
        ),
    )
    max_download_tries: int = Field(
        DEFAULT_MULTIPLEX_DOWNLOAD_TRIES,
        description="The maximum number of download retries.",
    )


# See: https://docs.ray.io/en/latest/serve/configure-serve-deployment.html
class DeploymentConfig(BaseModelExtended):
    class Config:
        extra = "forbid"

    autoscaling_config: Optional[AutoscalingConfig] = Field(
        default=None,
        description="Configuration for autoscaling the number of replicas.",
    )
    max_ongoing_requests: Optional[int] = Field(
        None,
        description=(
            "Sets the maximum number of in-flight queries that are sent to a "
            "single replica."
        ),
    )
    ray_actor_options: Optional[Dict[str, Any]] = Field(
        None, description="The Ray actor options to pass into the replica's actor."
    )
    graceful_shutdown_timeout_s: int = Field(
        300,
        description=(
            "The controller waits for this duration before forcefully killing "
            "the replica during shutdown, in seconds."
        ),
    )


class InputModality(str, Enum):
    text = "text"
    image = "image"


class LLMEngine(str, Enum):
    """Enum that represents an LLMEngine."""

    VLLM = "VLLM"


class JSONModeOptions(BaseModelExtended):
    num_processes: int = Field(
        default=8,
        description="The number of background processes for each replica.",
    )
    recreate_failed_actors: bool = Field(
        default=True, description="Whether to restart failed JSON mode actors."
    )


class LoraConfig(BaseModelExtended):
    dynamic_lora_loading_path: Optional[str] = Field(
        default=None,
        description="Cloud storage path where LoRA adapter weights are stored.",
    )
    max_num_adapters_per_replica: PositiveInt = Field(
        default=16,
        description="The maximum number of adapters to load on each replica.",
    )
    download_timeout_s: Optional[float] = Field(
        DEFAULT_MULTIPLEX_DOWNLOAD_TIMEOUT_S,
        description=(
            "How much time the download subprocess has to download a single "
            "LoRA before a timeout. None means no timeout."
        ),
    )
    max_download_tries: int = Field(
        DEFAULT_MULTIPLEX_DOWNLOAD_TRIES,
        description="The maximum number of download retries.",
    )

    @field_validator("dynamic_lora_loading_path")
    def validate_dynamic_lora_loading_path(cls, value: Optional[str]):
        if value is None:
            return value

        assert value.startswith("s3://") or value.startswith("gs://"), (
            "Only AWS S3 and Google Cloud Storage are supported. The "
            'dynamic_lora_loading_path must start with "s3://" or "gs://". '
            f'Got "{value}" instead.'
        )
        return value.rstrip("/")
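

# Illustrative usage sketch: LoraConfig requires an S3 or GCS path and strips a
# trailing slash from it. The path below is a hypothetical placeholder.
def _example_lora_config() -> None:
    lora_config = LoraConfig(
        dynamic_lora_loading_path="s3://my-bucket/lora-adapters/",
        max_num_adapters_per_replica=8,
    )
    assert lora_config.dynamic_lora_loading_path == "s3://my-bucket/lora-adapters"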


class ModelLoadingConfig(BaseModelExtended):
    model_id: str = Field(
        description="The ID that should be used by end users to access this model.",
    )
    model_source: Optional[Union[str, S3MirrorConfig, GCSMirrorConfig]] = Field(
        default=None,
        description=(
            "Where to obtain the model weights from. "
            "Should be a HuggingFace model ID, S3 mirror config, or GCS "
            "mirror config. When omitted, defaults to the model_id as a "
            "HuggingFace model ID."
        ),
    )
    tokenizer_source: Optional[str] = Field(
        default=None,
        description=(
            "Where to obtain the tokenizer from. If None, the tokenizer is "
            "obtained from the model source. Only HuggingFace IDs are "
            "supported for now."
        ),
    )


class LLMConfig(BaseModelExtended):
    # model_config is a Pydantic setting. This setting merges with
    # model_configs in parent classes.
    model_config = ConfigDict(
        extra="forbid",
    )

    runtime_env: Optional[Dict[str, Any]] = Field(
        None,
        description=(
            "The runtime_env to use for the model deployment replica "
            "and the engine workers."
        ),
    )
    model_loading_config: ModelLoadingConfig = Field(
        description="The settings for how to download and expose the model."
    )
    llm_engine: str = Field(
        default=LLMEngine.VLLM.value,
        description=(
            "The LLMEngine that should be used to run the model. Only the "
            f"following values are supported: {str([t.value for t in LLMEngine])}"
        ),
    )
    engine_kwargs: Dict[str, Any] = Field(
        default={},
        description=(
            "Additional keyword arguments for the engine. In case of vLLM, "
            "this will include all the configuration knobs they provide out "
            "of the box, except for tensor-parallelism, which is set "
            "automatically from Ray Serve configs."
        ),
    )
    accelerator_type: str = Field(
        description=(
            "The type of accelerator to run the model on. Only the following "
            f"values are supported: {str([t.value for t in GPUType])}"
        ),
    )
    lora_config: Optional[LoraConfig] = Field(
        default=None, description="Settings for LoRA adapters."
    )
    deployment_config: Dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "The Ray @serve.deployment options. "
            "See @serve.deployment for more details."
        ),
    )

    _supports_vision: bool = PrivateAttr(False)
    _prompt_format: HuggingFacePromptFormat = PrivateAttr(
        default_factory=HuggingFacePromptFormat
    )

    def _infer_supports_vision(self, model_id_or_path: str) -> None:
        """Called in the LLM node initializer together with other transformers calls.

        It loads the model config from Hugging Face and sets the
        _supports_vision attribute based on whether the config has a
        `vision_config`. All LVM models have `vision_config` set.
        """
        hf_config = transformers.PretrainedConfig.from_pretrained(model_id_or_path)
        self._supports_vision = hasattr(hf_config, "vision_config")

    def apply_checkpoint_info(
        self, model_id_or_path: str, trust_remote_code: bool = False
    ) -> None:
        """Apply the checkpoint info to the model config."""
        self._infer_supports_vision(model_id_or_path)
        self._prompt_format.set_processor(
            model_id_or_path,
            trust_remote_code=trust_remote_code,
        )

    @property
    def supports_vision(self) -> bool:
        return self._supports_vision

    @property
    def prompt_format(self) -> HuggingFacePromptFormat:
        return self._prompt_format

    @property
    def input_modality(self) -> str:
        """Returns the input modality of the model.

        There could be more modalities in the future. For now, if the model
        doesn't support vision, it is assumed to be text-only.
        """
        if self.supports_vision:
            return InputModality.image.value

        return InputModality.text.value

    @property
    def model_id(self) -> str:
        return self.model_loading_config.model_id

    @property
    def max_request_context_length(self) -> Optional[int]:
        return self.engine_kwargs.get("max_model_len")

    @field_validator("accelerator_type")
    def validate_accelerator_type(cls, value: str):
        # Ensure A10 is converted to A10G.
        if value == "A10":
            value = "A10G"

        if value not in [t.value for t in GPUType]:
            raise ValueError(f"Unsupported accelerator type: {value}")

        return value

    @field_validator("llm_engine")
    def validate_llm_engine(cls, value: str) -> str:
        """Validates the llm_engine string value."""
        try:
            # Validate the engine.
            LLMEngine(value)
        except ValueError as e:
            raise ValueError(f"Unsupported engine: {value}") from e
        return value

    @field_validator("deployment_config")
    def validate_deployment_config(cls, value: Dict[str, Any]) -> Dict[str, Any]:
        """Validates the deployment config dictionary."""
        try:
            # Only validate the deployment config.
            DeploymentConfig(**value)
        except Exception as e:
            raise ValueError(f"Invalid deployment config: {value}") from e

        return value

    def ray_accelerator_type(self) -> str:
        """Converts the accelerator type to the Ray Core format."""
        # Ray uses a hyphen instead of an underscore for accelerator_type.
        return f"accelerator_type:{self.accelerator_type.replace('_', '-')}"

    def multiplex_config(self) -> ServeMultiplexConfig:
        multiplex_config = None
        if self.lora_config:
            multiplex_config = ServeMultiplexConfig(
                max_num_models_per_replica=self.lora_config.max_num_adapters_per_replica,
                download_timeout_s=self.lora_config.download_timeout_s,
                max_download_tries=self.lora_config.max_download_tries,
            )
        return multiplex_config

    def get_engine_config(self):
        """Returns the engine config for the given LLM config.

        LLMConfig not only has the engine config but also the deployment
        config, etc.
        """
        if self.llm_engine == LLMEngine.VLLM:
            from ray.llm._internal.serve.deployments.llm.vllm.vllm_models import (
                VLLMEngineConfig,
            )

            return VLLMEngineConfig.from_llm_config(self)
        else:
            # Note (genesu): This should never happen because we validate the
            # engine in the config.
            raise ValueError(f"Unsupported engine: {self.llm_engine}")

    def _set_deployment_placement_options(self) -> Dict[str, Any]:
        deployment_config = self.deployment_config
        engine_config = self.get_engine_config()

        ray_actor_options = deployment_config.get("ray_actor_options", {})
        deployment_config["ray_actor_options"] = ray_actor_options

        replica_actor_resources = {
            "CPU": ray_actor_options.get("num_cpus", 1),
            "GPU": ray_actor_options.get("num_gpus", 0),
            **ray_actor_options.get("resources", {}),
        }
        if "memory" in ray_actor_options:
            replica_actor_resources["memory"] = ray_actor_options["memory"]

        if (
            "placement_group_bundles" in deployment_config
            or "placement_group_strategy" in deployment_config
        ):
            raise ValueError(
                "placement_group_bundles and placement_group_strategy must not "
                "be specified in deployment_config. "
                "Use scaling_config to configure replica placement group."
            )

        # TODO (Kourosh): There is some test code leakage happening here that
        # should be removed.
        try:
            # resources.mock_resource is a special key we use in tests to skip
            # the placement group on the GPU nodes.
            if "mock_resource" in ray_actor_options.get("resources", {}):
                bundles = []
            else:
                bundles = engine_config.placement_bundles
        except ValueError:
            # May happen if all bundles are empty.
            bundles = []

        bundles = [replica_actor_resources] + bundles
        deployment_config.update(
            {
                "placement_group_bundles": bundles,
                "placement_group_strategy": engine_config.placement_strategy,
            }
        )

        return deployment_config

    def _get_deployment_name(self, name_prefix: str) -> str:
        unsanitized_deployment_name = name_prefix + self.model_id
        return unsanitized_deployment_name.replace("/", "--").replace(".", "_")

    def get_serve_options(
        self,
        *,
        name_prefix: str,
    ) -> Dict[str, Any]:
        """Get the Serve options for the given LLM config.

        This method is used to generate the Serve options for the given LLM
        config.

        Examples:
            .. testcode::
                :skipif: True

                from ray import serve
                from ray.serve.llm.configs import LLMConfig, ModelLoadingConfig
                from ray.serve.llm.deployments import VLLMDeployment

                llm_config = LLMConfig(
                    model_loading_config=ModelLoadingConfig(model_id="test_model"),
                    accelerator_type="L4",
                    runtime_env={"env_vars": {"FOO": "bar"}},
                )

                serve_options = llm_config.get_serve_options(name_prefix="Test:")
                vllm_app = VLLMDeployment.options(**serve_options).bind(llm_config)
                serve.run(vllm_app)

        Keyword Args:
            name_prefix: The prefix to use for the deployment name.

        Returns:
            The dictionary to use in .options() when creating the deployment.
        """
        deployment_config = self._set_deployment_placement_options()

        default_runtime_env = ray.get_runtime_context().runtime_env
        if ENABLE_WORKER_PROCESS_SETUP_HOOK:
            default_runtime_env[
                "worker_process_setup_hook"
            ] = "ray.llm._internal.serve._worker_process_setup_hook"

        ray_actor_options = deployment_config.get("ray_actor_options", {})
        ray_actor_options["runtime_env"] = {
            **default_runtime_env,
            # Existing runtime_env should take precedence over the default.
            **ray_actor_options.get("runtime_env", {}),
            **(self.runtime_env if self.runtime_env else {}),
        }
        deployment_config["ray_actor_options"] = ray_actor_options

        # Set the name of the deployment config to map to the model ID.
        deployment_config["name"] = self._get_deployment_name(name_prefix)
        return deployment_config
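

# Illustrative usage sketch: how LLMConfig's validators and helpers behave.
# The model ID is a hypothetical placeholder; no model weights are loaded here.
def _example_llm_config_helpers() -> None:
    llm_config = LLMConfig(
        model_loading_config=ModelLoadingConfig(model_id="test_model"),
        # "A10" is normalized to "A10G" by validate_accelerator_type.
        accelerator_type="A10",
        lora_config=LoraConfig(max_num_adapters_per_replica=4),
    )
    assert llm_config.accelerator_type == "A10G"
    assert llm_config.ray_accelerator_type() == "accelerator_type:A10G"

    # The LoRA settings are surfaced to Serve through the multiplex config.
    multiplex_config = llm_config.multiplex_config()
    assert multiplex_config.max_num_models_per_replica == 4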


def _is_yaml_file(filename: str) -> bool:
    yaml_extensions = [".yml", ".yaml", ".json"]
    for s in yaml_extensions:
        if filename.endswith(s):
            return True
    return False


def _parse_path_args(path: str) -> List[LLMConfig]:
    assert os.path.exists(
        path
    ), f"Could not load model from {path}, as it does not exist."
    if os.path.isfile(path):
        with open(path, "r") as f:
            llm_config = LLMConfig.parse_yaml(f)
            return [llm_config]
    elif os.path.isdir(path):
        apps = []
        for root, _dirs, files in os.walk(path):
            for p in files:
                if _is_yaml_file(p):
                    with open(os.path.join(root, p), "r") as f:
                        llm_config = LLMConfig.parse_yaml(f)
                        apps.append(llm_config)
        return apps
    else:
        raise ValueError(
            f"Could not load model from {path}, as it is not a file or directory."
        )


def parse_args(
    args: Union[str, LLMConfig, Any, Sequence[Union[LLMConfig, str, Any]]],
) -> List[LLMConfig]:
    """Parse the input args and return a standardized list of LLMConfig objects.

    Supported args formats:

    1. The path to a YAML file defining your LLMConfig.
    2. The path to a folder containing YAML files, which define your LLMConfigs.
    3. A list of YAML files defining multiple LLMConfigs.
    4. A dict or LLMConfig object.
    5. A list of dicts or LLMConfig objects.
    """
    raw_models = [args]
    if isinstance(args, list):
        raw_models = args

    # Parse each raw model into an LLMConfig.
    models: List[LLMConfig] = []
    for raw_model in raw_models:
        if isinstance(raw_model, str):
            if os.path.exists(raw_model):
                parsed_models = _parse_path_args(raw_model)
            else:
                try:
                    llm_config = LLMConfig.parse_yaml(raw_model)
                    parsed_models = [llm_config]
                except pydantic.ValidationError as e:
                    raise ValueError(
                        "Could not parse string as YAML. If you are "
                        "specifying a path, make sure it exists and can be "
                        f"reached. raw_model: {raw_model}"
                    ) from e
        else:
            parsed_models = [LLMConfig.model_validate(raw_model)]
        models += parsed_models

    return models


class LLMServingArgs(BaseModel):
    llm_configs: List[Union[str, LLMConfig]] = Field(
        description="A list of LLMConfigs, or paths to LLMConfigs, to run.",
    )

    def parse_args(self) -> "LLMServingArgs":
        """Converts this LLMServingArgs object into a DeployArgs object."""
        llm_configs = []
        for config in self.llm_configs:
            parsed_config = parse_args(config)[0]
            if not isinstance(parsed_config, LLMConfig):
                raise ValueError(
                    "When using the new Serve config format, all model "
                    "configs must also use the new model config format. Got "
                    "a model config that doesn't match the new format. Type: "
                    f"{type(parsed_config)}. Contents: {parsed_config}."
                )
            llm_configs.append(parsed_config)

        return LLMServingArgs(llm_configs=llm_configs)
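

# Illustrative usage sketch for parse_args; the model ID and accelerator type
# below are hypothetical placeholder values.
def _example_parse_args() -> List[LLMConfig]:
    raw = {
        "model_loading_config": {"model_id": "test_model"},
        "accelerator_type": "L4",
    }
    # Dicts and already-constructed LLMConfig objects can be mixed freely.
    configs = parse_args([raw, LLMConfig.model_validate(raw)])
    assert all(isinstance(c, LLMConfig) for c in configs)
    return configs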


TModel = TypeVar("TModel", bound="Model")


class ModelData(BaseModel):
    model_config = ConfigDict(protected_namespaces=tuple())

    id: str
    object: str
    owned_by: str
    permission: List[str]
    rayllm_metadata: Dict[str, Any]

    @property
    def model_type(self) -> str:
        return self.rayllm_metadata["engine_config"]["model_type"]


class Model(BaseModel):
    data: List[ModelData]
    object: str = "list"

    @classmethod
    def list(cls) -> TModel:
        pass


class FinishReason(str, Enum):
    LENGTH = "length"
    STOP = "stop"
    ERROR = "error"
    CANCELLED = "cancelled"

    def __str__(self) -> str:
        return self.value

    @classmethod
    def from_vllm_finish_reason(
        cls, finish_reason: Optional[str]
    ) -> Optional["FinishReason"]:
        if finish_reason is None:
            return None
        if finish_reason == "stop":
            return cls.STOP
        if finish_reason == "length":
            return cls.LENGTH
        if finish_reason == "abort":
            return cls.CANCELLED
        return cls.STOP


class LoraMirrorConfig(BaseModelExtended):
    lora_model_id: str
    bucket_uri: str
    max_total_tokens: Optional[int]
    sync_args: Optional[List[str]] = None

    @field_validator("bucket_uri")
    @classmethod
    def validate_bucket_uri(cls, value: str):
        # TODO(tchordia): remove this. This is a short-term fix.
        # We should fix this on the LLM-forge side.
        if not value.startswith("s3://") and not value.startswith("gs://"):
            value = "s3://" + value
        return value

    @property
    def _bucket_name_and_path(self) -> str:
        for prefix in ["s3://", "gs://"]:
            if self.bucket_uri.startswith(prefix):
                return self.bucket_uri[len(prefix) :]
        return self.bucket_uri

    @property
    def bucket_name(self) -> str:
        return self._bucket_name_and_path.split("/")[0]

    @property
    def bucket_path(self) -> str:
        return "/".join(self._bucket_name_and_path.split("/")[1:])


class DiskMultiplexConfig(BaseModelExtended):
    model_id: str
    max_total_tokens: Optional[int]
    local_path: str

    # This is a per-process ID assigned to the model.
    lora_assigned_int_id: int


class ComputedPropertyMixin:
    """Include properties in the dict and json representations of the model."""

    # Replace with pydantic.computed_field once it's available.
    @classmethod
    def get_properties(cls):
        return [prop for prop in dir(cls) if isinstance(getattr(cls, prop), property)]

    def model_dump(self, *args, **kwargs):
        self.__dict__.update(
            {prop: getattr(self, prop) for prop in self.get_properties()}
        )
        return super().model_dump(*args, **kwargs)  # type: ignore

    def model_dump_json(
        self,
        *args,
        **kwargs,
    ) -> str:
        self.__dict__.update(
            {prop: getattr(self, prop) for prop in self.get_properties()}
        )
        return super().model_dump_json(*args, **kwargs)  # type: ignore


class LogProb(BaseModel):
    logprob: float
    token: str
    bytes: List[int]


class LogProbs(BaseModel):
    token: str
    logprob: float
    bytes: List[int]
    top_logprobs: List[LogProb]

    @classmethod
    def create(cls, logprobs: List[LogProb], top_logprobs: Optional[int] = None):
        assert len(logprobs) > 0, "logprobs must be a non-empty list"
        token = logprobs[0].token
        logprob = logprobs[0].logprob
        bytes = logprobs[0].bytes
        all_logprobs = logprobs if top_logprobs else []
        ret = cls(token=token, logprob=logprob, bytes=bytes, top_logprobs=all_logprobs)
        return ret


class LLMRawResponse(ComputedPropertyMixin, BaseModelExtended):
    """The response from a query to a RayLLM Model.

    Args:
        generated_text: The generated text.
        logprobs: Log probabilities of each token and possibly some of the
            unchosen tokens.
        num_input_tokens: The number of input tokens.
        num_generated_tokens: The number of generated tokens.
        num_input_tokens_batch: The number of input tokens in the batch.
        num_generated_tokens_batch: The number of generated tokens in the batch.
        preprocessing_time: The time spent preprocessing the request.
        generation_time: The time spent generating the response.
        timestamp: The timestamp of the response.
        finish_reason: The reason the generation finished.
        error: The error, if any.
    """

    generated_text: Optional[str] = None
    logprobs: Optional[List[LogProbs]] = None
    num_input_tokens: Optional[int] = None
    num_input_tokens_batch: Optional[int] = None
    num_generated_tokens: Optional[int] = None
    num_generated_tokens_batch: Optional[int] = None
    preprocessing_time: Optional[float] = None
    generation_time: Optional[float] = None
    timestamp: Optional[float] = Field(default_factory=time.time)
    finish_reason: Optional[str] = None
    error: Optional[ErrorResponse] = None

    @model_validator(mode="before")
    @classmethod
    def text_or_error_or_finish_reason(cls, values):
        if (
            values.get("generated_text") is None
            and values.get("error") is None
            and values.get("finish_reason") is None
        ):
            raise ValueError(
                "'generated_text', 'error', or 'finish_reason' must be set."
            )
        return values

    @classmethod
    def merge_stream(cls, *responses: "LLMRawResponse") -> "LLMRawResponse":
        """Merge a stream of responses into a single response.

        The generated text is concatenated. Fields are maxed, except for
        num_generated_tokens and generation_time, which are summed.
        """
        if len(responses) == 1:
            return responses[0]

        generated_text = (
            None
            if responses[0].generated_text is None
            else "".join([response.generated_text or "" for response in responses])
        )
        num_input_tokens = [
            response.num_input_tokens
            for response in responses
            if response.num_input_tokens is not None
        ]
        max_num_input_tokens = max(num_input_tokens) if num_input_tokens else None
        num_input_tokens_batch = [
            response.num_input_tokens_batch
            for response in responses
            if response.num_input_tokens_batch is not None
        ]
        max_num_input_tokens_batch = (
            max(num_input_tokens_batch) if num_input_tokens_batch else None
        )
        num_generated_tokens = [
            response.num_generated_tokens
            for response in responses
            if response.num_generated_tokens is not None
        ]
        total_generated_tokens = (
            sum(num_generated_tokens) if num_generated_tokens else None
        )
        num_generated_tokens_batch = [
            response.num_generated_tokens_batch
            for response in responses
            if response.num_generated_tokens_batch is not None
        ]
        total_generated_tokens_batch = (
            sum(num_generated_tokens_batch) if num_generated_tokens_batch else None
        )
        preprocessing_time = [
            response.preprocessing_time
            for response in responses
            if response.preprocessing_time is not None
        ]
        max_preprocessing_time = max(preprocessing_time) if preprocessing_time else None
        generation_time = [
            response.generation_time
            for response in responses
            if response.generation_time is not None
        ]
        total_generation_time = sum(generation_time) if generation_time else None
        error = next(
            (response.error for response in reversed(responses) if response.error),
            None,
        )

        logprobs = []
        for response in responses:
            if response.logprobs:
                logprobs.extend(response.logprobs)

        return cls(
            generated_text=generated_text,
            logprobs=logprobs,
            num_input_tokens=max_num_input_tokens,
            num_input_tokens_batch=max_num_input_tokens_batch,
            num_generated_tokens=total_generated_tokens,
            num_generated_tokens_batch=total_generated_tokens_batch,
            preprocessing_time=max_preprocessing_time,
            generation_time=total_generation_time,
            timestamp=responses[-1].timestamp,
            finish_reason=responses[-1].finish_reason,
            error=error,
        )

    @property
    def total_time(self) -> Optional[float]:
        if self.generation_time is None and self.preprocessing_time is None:
            return None
        return (self.preprocessing_time or 0) + (self.generation_time or 0)

    @property
    def num_total_tokens(self) -> Optional[float]:
        try:
            return (self.num_input_tokens or 0) + (self.num_generated_tokens or 0)
        except Exception:
            return None

    @property
    def num_total_tokens_batch(self) -> Optional[float]:
        try:
            return (self.num_input_tokens_batch or 0) + (
                self.num_generated_tokens_batch or 0
            )
        except Exception:
            return None

    def unpack(self) -> Tuple["LLMRawResponse", ...]:
        return (self,)


class BatchedLLMRawResponse(LLMRawResponse):
    # Same as LLMRawResponse, but persists the individual responses
    # that were batched together to produce this response.

    _individual_responses: Optional[List[LLMRawResponse]] = PrivateAttr(None)

    @classmethod
    def merge_stream(cls, *responses: LLMRawResponse) -> LLMRawResponse:
        if len(responses) == 1:
            return responses[0]

        obj = super().merge_stream(*responses)
        obj._individual_responses = list(responses)  # type: ignore
        return obj

    def unpack(self) -> Tuple[LLMRawResponse]:
        return tuple(self._individual_responses or [])


def merge_dicts(base: Dict, overwrite: Dict) -> Dict:
    """Merge overwrite into base. Modifies base in place."""
    for key in overwrite:
        if (
            key in base
            and isinstance(base[key], dict)
            and isinstance(overwrite[key], dict)
        ):
            merge_dicts(base[key], overwrite[key])
        else:
            base[key] = overwrite[key]

    return base


class SamplingParams(BaseModelExtended):
    """
    Args:
        max_tokens: The maximum number of tokens to generate. Defaults to inf.
        temperature: What sampling temperature to use.
        top_p: An alternative to sampling with temperature, called nucleus
            sampling.
        n: How many completions to generate for each prompt.
        logprobs: Include the log probabilities of the `logprobs` most likely
            tokens, as well as the chosen tokens.
        top_logprobs: The number of logprobs to return. Defaults to 1.
            `logprobs` must be set to `True` in order to use top_logprobs.
        stop: Up to 4 sequences where the API will stop generating further
            tokens. The returned text will not contain the stop sequence.
        stop_tokens: Tokens to stop on (applied before detokenization).
        presence_penalty: Number between -2.0 and 2.0. Positive values penalize
            new tokens based on whether they appear in the text so far,
            increasing the model's likelihood to talk about new topics.
        frequency_penalty: Number between -2.0 and 2.0. Positive values
            penalize new tokens based on their existing frequency in the text
            so far, decreasing the model's likelihood to repeat the same line
            verbatim.
        best_of: Generates `best_of` completions server-side and returns the
            "best".
        logit_bias: Modify the likelihood of specified tokens appearing in the
            completion.
        response_format: Format to return the final response in. For example:
            response_format={"type": "json", "schema": "{...}"}
    """

    _ignored_fields: Set[str] = set()

    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: int = 1
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    stop: Optional[List[str]] = None
    stop_tokens: Optional[List[int]] = None
    ignore_eos: Optional[bool] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    best_of: int = 1
    response_format: Optional[ResponseFormatType] = None

    def model_dump(self, **kwargs):
        if kwargs.get("exclude", None) is None:
            kwargs["exclude"] = self._ignored_fields
        return super().model_dump(**kwargs)

    @field_validator("stop", mode="before")
    @classmethod
    def validate_stopping_sequences(cls, values):
        if not values:
            return values

        unique_val = sorted(set(values))

        if len(unique_val) > MAX_NUM_STOPPING_SEQUENCES:
            TooManyStoppingSequences(
                len(unique_val), MAX_NUM_STOPPING_SEQUENCES
            ).raise_exception()

        return unique_val

    @classmethod
    def from_prompt(cls: Type[ModelT], prompt: Prompt) -> ModelT:
        # Extract the parameters object from the prompt.
        generate_kwargs = prompt.parameters or {}
        if not isinstance(generate_kwargs, dict):
            generate_kwargs = generate_kwargs.model_dump(exclude_unset=True)

        generate_kwargs["stop"] = set(generate_kwargs.get("stop", []))
        generate_kwargs["stop_tokens"] = set(generate_kwargs.get("stop_tokens", []))

        return cls.model_validate(generate_kwargs)


class GenerationRequest(BaseModelExtended):
    prompt: Union[str, List[int], List[str]]
    request_id: Union[str, List[str]]
    sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None
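

# Illustrative usage sketch: SamplingParams deduplicates and sorts stop
# sequences during validation. The values below are hypothetical.
def _example_sampling_params() -> SamplingParams:
    params = SamplingParams(
        max_tokens=64, temperature=0.2, stop=["###", "###", "\n\n"]
    )
    assert params.stop == sorted({"###", "\n\n"})
    return params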