import logging
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from pydantic import Field, field_validator
import ray
from ray.data import Dataset
from ray.data.block import UserDefinedFunction
from ray.llm._internal.batch.stages import (
StatefulStage,
wrap_postprocess,
wrap_preprocess,
)
from ray.llm._internal.common.base_pydantic import BaseModelExtended
from ray.util.annotations import DeveloperAPI, PublicAPI
logger = logging.getLogger(__name__)
# Higher values here are better for prefetching and locality. It's ok for this to be
# fairly high since streaming backpressure prevents us from overloading actors.
DEFAULT_MAX_TASKS_IN_FLIGHT = 4
class ProcessorConfig(BaseModelExtended):
"""The processor configuration."""
batch_size: int = Field(
default=32,
description="Large batch sizes are likely to saturate the compute resources "
"and could achieve higher throughput. On the other hand, small batch sizes "
"are more fault-tolerant and could reduce bubbles in the data pipeline. "
"You can tune the batch size to balance the throughput and fault-tolerance "
"based on your use case. Defaults to 32.",
)
resources_per_bundle: Optional[Dict[str, float]] = Field(
default=None,
description="This will override the default resource bundles for placement groups. "
"You can specify a custom device label e.g. {'NPU': 1}. "
"The default resource bundle for LLM Stage is always a GPU resource i.e. {'GPU': 1}.",
)
accelerator_type: Optional[str] = Field(
default=None,
description="The accelerator type used by the LLM stage in a processor. "
"Default to None, meaning that only the CPU will be used.",
)
concurrency: Union[int, Tuple[int, int]] = Field(
default=1,
description="The number of workers for data parallelism. Default to 1. "
"If ``concurrency`` is a ``tuple`` ``(m, n)``, Ray creates an autoscaling "
"actor pool that scales between ``m`` and ``n`` workers (``1 <= m <= n``). "
"If ``concurrency`` is an ``int`` ``n``, Ray uses either a fixed pool of ``n`` "
"workers or an autoscaling pool from ``1`` to ``n`` workers, depending on "
"the processor and stage.",
)
experimental: Dict[str, Any] = Field(
default_factory=dict,
description="[Experimental] Experimental configurations."
"Supported keys:\n"
"`max_tasks_in_flight_per_actor`: The maximum number of tasks in flight per actor. Default to 4.",
)
@field_validator("concurrency")
def validate_concurrency(
cls, concurrency: Union[int, Tuple[int, int]]
) -> Union[int, Tuple[int, int]]:
"""Validate that `concurrency` is either:
- a positive int, or
- a 2-tuple `(min, max)` of positive ints with `min <= max`.
"""
def require(condition: bool, message: str) -> None:
if not condition:
raise ValueError(message)
if isinstance(concurrency, int):
require(
concurrency > 0,
f"A positive integer for `concurrency` is expected! Got: `{concurrency}`.",
)
elif isinstance(concurrency, tuple):
require(
all(c > 0 for c in concurrency),
f"`concurrency` tuple items must be positive integers! Got: `{concurrency}`.",
)
min_concurrency, max_concurrency = concurrency
require(
min_concurrency <= max_concurrency,
f"min > max in the concurrency tuple `{concurrency}`!",
)
return concurrency
def get_concurrency(self, autoscaling_enabled: bool = True) -> Tuple[int, int]:
"""Return a normalized `(min, max)` worker range from `self.concurrency`.
Behavior:
- If `concurrency` is an int `n`:
- `autoscaling_enabled` is True -> return `(1, n)` (autoscaling).
- `autoscaling_enabled` is False -> return `(n, n)` (fixed-size pool).
- If `concurrency` is a 2-tuple `(m, n)`, return it unchanged
(the `autoscaling_enabled` flag is ignored).
Args:
autoscaling_enabled: When False, treat an integer `concurrency` as fixed `(n, n)`;
otherwise treat it as a range `(1, n)`. Defaults to True.
Returns:
tuple[int, int]: The allowed worker range `(min, max)`.
        Examples:
            >>> config = ProcessorConfig(concurrency=(2, 4))
            >>> config.get_concurrency()
            (2, 4)
            >>> config = ProcessorConfig(concurrency=4)
            >>> config.get_concurrency()
            (1, 4)
            >>> config.get_concurrency(autoscaling_enabled=False)
            (4, 4)
"""
if isinstance(self.concurrency, int):
if autoscaling_enabled:
return 1, self.concurrency
else:
return self.concurrency, self.concurrency
return self.concurrency
class Config:
validate_assignment = True
arbitrary_types_allowed = True
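# The sketch below is illustrative only: it shows how `concurrency` values are
# normalized by `get_concurrency` and how the `experimental` dict is populated.
# The specific numbers are arbitrary, not tuning recommendations.
def _example_processor_config_usage() -> None:
    config = ProcessorConfig(
        batch_size=64,
        concurrency=(2, 4),
        experimental={"max_tasks_in_flight_per_actor": DEFAULT_MAX_TASKS_IN_FLIGHT},
    )
    # A (min, max) tuple is returned unchanged.
    assert config.get_concurrency() == (2, 4)
    # An int n is treated as (1, n) with autoscaling, or (n, n) without.
    fixed = ProcessorConfig(concurrency=8)
    assert fixed.get_concurrency() == (1, 8)
    assert fixed.get_concurrency(autoscaling_enabled=False) == (8, 8)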
class OfflineProcessorConfig(ProcessorConfig):
"""The processor configuration for offline processing."""
model_source: str = Field(
description="The model source to use for the offline processing.",
)
runtime_env: Optional[Dict[str, Any]] = Field(
default=None,
description="The runtime environment to use for the offline processing.",
)
max_pending_requests: Optional[int] = Field(
default=None,
description="The maximum number of pending requests. If not specified, "
"will use the default value from the backend engine.",
)
max_concurrent_batches: int = Field(
default=8,
description="The maximum number of concurrent batches in the engine. "
"This is to overlap the batch processing to avoid the tail latency of "
"each batch. The default value may not be optimal when the batch size "
"or the batch processing latency is too small, but it should be good "
"enough for batch size >= 32.",
)
# Processor stage configurations.
apply_chat_template: bool = Field(
default=True, description="Whether to apply chat template."
)
chat_template: Optional[str] = Field(
default=None,
description="The chat template to use. This is usually not needed if the "
"model checkpoint already contains the chat template.",
)
tokenize: bool = Field(
default=True,
description="Whether to tokenize the input before passing it to the "
"backend engine. If not, the backend engine will tokenize the prompt.",
)
detokenize: bool = Field(
default=True,
description="Whether to detokenize the output.",
)
has_image: bool = Field(
default=False,
description="Whether the input messages have images.",
)
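# Illustrative sketch of an OfflineProcessorConfig. The model source below is a
# hypothetical Hugging Face model ID used only as a placeholder; any field not
# set keeps the defaults documented above.
def _example_offline_processor_config() -> None:
    config = OfflineProcessorConfig(
        model_source="facebook/opt-125m",  # hypothetical placeholder model ID
        batch_size=64,
        concurrency=2,
        apply_chat_template=True,
        tokenize=True,
        detokenize=True,
    )
    assert config.max_concurrent_batches == 8  # default engine-side pipelining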
@PublicAPI(stability="alpha")
class Processor:
"""A processor is composed of a preprocess stage, followed by one or more
processing stages, and finally a postprocess stage. We use processor as a
paradigm for processing data using LLMs.
Args:
config: The processor config.
        preprocess: An optional function that takes a row (dict) as input
            and returns a preprocessed row (dict). The output row must contain the
            required fields for the following processing stages.
        postprocess: An optional function that takes a row (dict) as input
            and returns a postprocessed row (dict).
"""
    # The internally used data column name ("__data"). Your input
    # dataset should not contain this column. If you need to use this column
    # in your input dataset, you have to subclass and customize Processor.
DATA_COLUMN: str = "__data"
def __init__(
self,
config: ProcessorConfig,
stages: List[StatefulStage],
preprocess: Optional[UserDefinedFunction] = None,
postprocess: Optional[UserDefinedFunction] = None,
):
self.config = config
self.preprocess = None
self.postprocess = None
self.stages: OrderedDict[str, StatefulStage] = OrderedDict()
# FIXES: https://github.com/ray-project/ray/issues/53124
# TODO (Kourosh): Remove this once the issue is fixed
data_context = ray.data.DataContext.get_current()
data_context.wait_for_min_actors_s = 600
# TODO: Remove this when https://github.com/ray-project/ray/issues/53169
# is fixed.
data_context._enable_actor_pool_on_exit_hook = True
        # NOTE (Kourosh): If pre/postprocess is not provided, use the identity function.
        # Wrapping is required even for identity functions, because the data column
        # gets inserted/removed via wrap_preprocess/wrap_postprocess.
preprocess = preprocess or (lambda row: row)
postprocess = postprocess or (lambda row: row)
self.preprocess = wrap_preprocess(
preprocess,
self.DATA_COLUMN,
)
self.postprocess = wrap_postprocess(
postprocess,
self.DATA_COLUMN,
)
for stage in stages:
self._append_stage(stage)
def __call__(self, dataset: Dataset) -> Dataset:
"""Execute the processor:
preprocess -> stages -> postprocess.
Note that the dataset won't be materialized during the execution.
Args:
dataset: The input dataset.
Returns:
            The output dataset.
        """
"""
if self.preprocess is not None:
dataset = dataset.map(self.preprocess)
# Apply stages.
for stage in self.stages.values():
kwargs = stage.get_dataset_map_batches_kwargs(
batch_size=self.config.batch_size,
data_column=self.DATA_COLUMN,
)
dataset = dataset.map_batches(stage.fn, **kwargs)
if self.postprocess is not None:
dataset = dataset.map(self.postprocess)
return dataset
def _append_stage(self, stage: StatefulStage) -> None:
"""Append a stage before postprocess. The stage class name will be used as
the stage name. If there are multiple stages with the same type, a suffix
will be added to the stage name to avoid conflicts.
Args:
stage: The stage to append.
"""
stage_name = type(stage).__name__
        # When a processor has multiple stages of the same type,
        # append an index suffix to the stage name to avoid conflicts.
        if stage_name in self.stages:
            num_same_type_stage = len(
                [s for s in self.stages.values() if isinstance(s, type(stage))]
            )
            stage_name = f"{stage_name}_{num_same_type_stage + 1}"
self.stages[stage_name] = stage
def list_stage_names(self) -> List[str]:
"""List the stage names of this processor in order. Preprocess and postprocess
are not included.
Returns:
A list of stage names.
"""
return list(self.stages.keys())
def get_stage_by_name(self, name: str) -> StatefulStage:
"""Get a particular stage by its name. If the stage is not found,
a ValueError will be raised.
Args:
name: The stage name.
Returns:
The pipeline stage.
"""
if name in self.stages:
return self.stages[name]
raise ValueError(f"Stage {name} not found")
@DeveloperAPI
class ProcessorBuilder:
"""Build a processor based on the configuration."""
_registry: Dict[str, Callable] = {}
@classmethod
def register(cls, config_type: Type[ProcessorConfig], builder: Callable) -> None:
"""A decorator to associate a particular pipeline config
with its build function.
"""
type_name = config_type.__name__
if type_name in cls._registry:
raise ValueError(f"Processor config type {type_name} already registered.")
cls._registry[type_name] = builder
@classmethod
def build(
cls,
config: ProcessorConfig,
override_stage_config_fn: Optional[Callable] = None,
**kwargs,
) -> Processor:
"""Build a processor.
Args:
config: The processor config.
            override_stage_config_fn: An optional function that takes a stage name
                and stage instance, and mutates the stage configuration in place.
Returns:
The built processor.
"""
type_name = type(config).__name__
if type_name not in cls._registry:
            raise ValueError(
                f"Processor config type {type_name} is not registered. "
                f"Available types: {list(cls._registry.keys())}"
            )
processor = cls._registry[type_name](config, **kwargs)
if override_stage_config_fn is not None:
for name, stage in processor.stages.items():
override_stage_config_fn(name, stage)
return processor
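# Illustrative sketch of the registry flow. `MyConfig` and `build_my_processor`
# are hypothetical names; a real builder would construct the stages appropriate
# for its config type instead of an empty list.
def _example_register_and_build() -> Processor:
    class MyConfig(ProcessorConfig):
        pass
    def build_my_processor(config: MyConfig, **kwargs) -> Processor:
        return Processor(config, stages=[])
    ProcessorBuilder.register(MyConfig, build_my_processor)
    processor = ProcessorBuilder.build(MyConfig(concurrency=2))
    assert processor.list_stage_names() == []
    return processor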