Working with LLMs#
The ray.data.llm module enables scalable batch inference on Ray Data datasets. It supports two modes: running LLM inference engines directly (vLLM, SGLang) or querying hosted and deployed endpoints through HttpRequestProcessorConfig and ServeDeploymentProcessorConfig.
Getting started:
Quickstart - Run your first batch inference job
Architecture - Understand the processor pipeline
Scaling - Scale your LLM stage to multiple replicas
Common use cases:
Text generation - Chat completions with LLMs
Embeddings - Generate text embeddings
Classification - Content classifiers and sentiment analyzers
Multimodality - Batch inference with VLM / omni models on multimodal data
OpenAI-compatible endpoints - Query deployed models
Serve deployments - Share vLLM engines across processors
Operations:
Troubleshooting - GPU memory, model loading issues
Advanced configuration - Parallelism, per-stage tuning, LoRA
Quickstart: vLLM batch inference#
Get started with vLLM batch inference in just a few steps. This example shows the minimal setup needed to run batch inference on a dataset.
Note
This quickstart requires a GPU as vLLM is GPU-accelerated.
First, install Ray Data with LLM support:
pip install -U "ray[data, llm]>=2.49.1"
Here’s a complete minimal example that runs batch inference:
import ray
from ray.data.llm import vLLMEngineProcessorConfig, build_processor
# Initialize Ray
ray.init()
# simple dataset
ds = ray.data.from_items([
{"prompt": "What is machine learning?"},
{"prompt": "Explain neural networks in one sentence."},
])
# Minimal vLLM configuration
config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
concurrency=1, # 1 vLLM engine replica
batch_size=32, # 32 samples per batch
engine_kwargs={
"max_model_len": 4096, # Fit into test GPU memory
}
)
# Build processor
# preprocess: converts input row to format expected by vLLM (OpenAI chat format)
# postprocess: extracts generated text from vLLM output
processor = build_processor(
config,
preprocess=lambda row: {
"messages": [{"role": "user", "content": row["prompt"]}],
"sampling_params": {"temperature": 0.7, "max_tokens": 100},
},
postprocess=lambda row: {
"prompt": row["prompt"],
"response": row["generated_text"],
},
)
# inference
ds = processor(ds)
# iterate through the results
for result in ds.iter_rows():
print(f"Q: {result['prompt']}")
print(f"A: {result['response']}\n")
# Alternative ways to get results:
# results = ds.take(10) # Get first 10 results
# ds.show(limit=5) # Print first 5 results
# ds.write_parquet("output.parquet") # Save to file
This example:
Creates a simple dataset with prompts
Configures a vLLM processor with minimal settings
Builds a processor that handles preprocessing (converting prompts to OpenAI chat format) and postprocessing (extracting generated text)
Runs inference on the dataset
Iterates through results
The processor expects input rows with a prompt field and outputs rows with both prompt and response fields. You can consume results using iter_rows(), take(), show(), or save to files with write_parquet().
For more configuration options and advanced features, see the sections below.
Processor architecture#
Ray Data LLM uses a multi-stage processor pipeline to transform your data through LLM inference. Understanding this architecture helps you optimize performance and debug issues.
Input Dataset
|
v
- Preprocess (Custom Function)
- PrepareMultimodal (Optional, for VLM / Omni models)
- ChatTemplate (Applies chat template to messages)
- Tokenize (Converts text to token IDs)
- LLM Engine (vLLM/SGLang inference on GPU)
- Detokenize (Converts token IDs back to text)
- Postprocess (Custom Function)
|
v
Output Dataset
Stage descriptions:
Preprocess: Your custom function that transforms input rows into the format expected by downstream stages (typically OpenAI chat format with messages).
PrepareMultimodal: Extracts and prepares multimodal inputs. Enable with prepare_multimodal_stage={"enabled": True}.
ChatTemplate: Applies the model’s chat template to convert messages into a prompt string.
Tokenize: Converts the prompt string into token IDs for the model.
LLM Engine: The accelerated (GPU/TPU) inference stage running vLLM or SGLang.
Detokenize: Converts output token IDs back to readable text.
Postprocess: Your custom function that extracts and formats the final output.
Each stage runs as a separate Ray actor pool, enabling independent scaling and resource allocation. CPU stages (ChatTemplate, Tokenize, Detokenize, and HttpRequestStage) use autoscaling actor pools; the ServeDeployment stage and the GPU engine stage use fixed-size pools.
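To make the data flow concrete, here's a simplified sketch (using the quickstart's fields; intermediate token fields are omitted) of how a single row evolves as it moves through the pipeline:
# Input row:
#   {"prompt": "What is machine learning?"}
# After Preprocess (your function adds messages and sampling_params):
#   {"messages": [{"role": "user", "content": "What is machine learning?"}],
#    "sampling_params": {"temperature": 0.7, "max_tokens": 100}}
# After ChatTemplate -> Tokenize -> LLM Engine -> Detokenize:
#   the engine's output is surfaced on the row as a generated_text field
# After Postprocess (your function selects the final output fields):
#   {"prompt": "What is machine learning?", "response": "Machine learning is ..."}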
Scaling to multiple GPUs#
Horizontally scale the LLM stage to multiple GPU replicas using the concurrency parameter:
config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4096,
"max_model_len": 16384,
},
concurrency=10,
batch_size=64,
)
Each replica runs an independent inference engine. Set concurrency to match the number of available GPUs or GPU nodes.
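Keep in mind that each replica reserves one GPU per degree of model parallelism, so the total GPU count for the LLM stage is the product of concurrency and the per-replica parallelism (see Model parallelism below):
# Sizing sketch for the LLM stage:
#   total_gpus = concurrency * tensor_parallel_size * pipeline_parallel_size
# The configuration above uses the defaults tensor_parallel_size=1 and
# pipeline_parallel_size=1, so concurrency=10 requests 10 GPUs, one per replica.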
Text generation#
Use vLLMEngineProcessorConfig or SGLangEngineProcessorConfig for chat completions and text generation tasks.
Key configuration options:
model_source: HuggingFace model ID or path to model weights
concurrency: Number of vLLM engine replicas (typically 1 per GPU node)
batch_size: Rows per batch (reduce if hitting memory limits)
# Basic vLLM configuration
config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4096, # Reduce if CUDA OOM occurs
"max_model_len": 4096, # Constrain to fit test GPU memory
},
concurrency=1,
batch_size=64,
)
For gated models requiring authentication, pass your HuggingFace token through runtime_env:
# Configuration with Hugging Face token
config_with_token = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
runtime_env={"env_vars": {"HF_TOKEN": "your_huggingface_token"}},
concurrency=1,
batch_size=64,
)
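In practice, avoid hard-coding the token. A minimal sketch that reads HF_TOKEN from the environment where you launch the job and forwards it to the workers (assumes the variable is already set):
import os

config_with_token = vLLMEngineProcessorConfig(
    model_source="unsloth/Llama-3.1-8B-Instruct",
    # Forward the driver's HF_TOKEN to the worker processes
    runtime_env={"env_vars": {"HF_TOKEN": os.environ["HF_TOKEN"]}},
    concurrency=1,
    batch_size=64,
)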
Multimodality#
Ray Data LLM also supports running batch inference with vision language and omni-modal models on multimodal data. To enable multimodal batch inference, apply the following two adjustments on top of the previous example:
Set prepare_multimodal_stage={"enabled": True} in the vLLMEngineProcessorConfig.
Prepare multimodal data inside the preprocessor.
Image batch inference with vision language model (VLM)#
First, load a vision dataset:
"""
Load vision dataset from Hugging Face.
This function loads the LMMs-Eval-Lite dataset which contains:
- Images with associated questions
- Multiple choice answers
- Various visual reasoning tasks
"""
try:
from huggingface_hub import HfFileSystem
# Load "LMMs-Eval-Lite" dataset from Hugging Face using HfFileSystem
path = "hf://datasets/lmms-lab/LMMs-Eval-Lite/coco2017_cap_val/"
fs = HfFileSystem()
vision_dataset = ray.data.read_parquet(path, filesystem=fs)
return vision_dataset
except ImportError:
print(
"huggingface_hub package not available. Install with: pip install huggingface_hub"
)
return None
except Exception as e:
print(f"Error loading dataset: {e}")
return None
Next, configure the VLM processor with the essential settings:
vision_processor_config = vLLMEngineProcessorConfig(
model_source="Qwen/Qwen2.5-VL-3B-Instruct",
engine_kwargs=dict(
tensor_parallel_size=1,
pipeline_parallel_size=1,
max_model_len=4096,
trust_remote_code=True,
limit_mm_per_prompt={"image": 1},
),
batch_size=16,
concurrency=1,
prepare_multimodal_stage={"enabled": True},
)
Define preprocessing and postprocessing functions to convert dataset rows into the format expected by the VLM and to extract model responses. Within the preprocessor, structure image data as part of an OpenAI-compatible message. Both image URLs and PIL.Image.Image objects are supported.
"""Supported image input formats: image URL, PIL Image object"""
{
"messages": [
{
"role": "system",
"content": "Provide a detailed description of the image."
},
{
"role": "user",
"content": [
{"type": "text", "text": "Describe what happens in this image."},
# Option 1: Provide image URL
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
# Option 2: Provide PIL Image object
{"type": "image_pil", "image_pil": PIL.Image.open("path/to/image.jpg")}
]
},
]
}
from io import BytesIO
from PIL import Image

def vision_preprocess(row: dict) -> dict:
"""
Preprocessing function for vision-language model inputs.
Converts dataset rows into the format expected by the VLM:
- System prompt for analysis instructions
- User message with text and image content
- Multiple choice formatting
- Sampling parameters
"""
choice_indices = ["A", "B", "C", "D", "E", "F", "G", "H"]
return {
"messages": [
{
"role": "system",
"content": (
"Analyze the image and question carefully, using step-by-step reasoning. "
"First, describe any image provided in detail. Then, present your reasoning. "
"And finally your final answer in this format: Final Answer: <answer> "
"where <answer> is: The single correct letter choice A, B, C, D, E, F, etc. when options are provided. "
"Only include the letter. Your direct answer if no options are given, as a single phrase or number. "
"IMPORTANT: Remember, to end your answer with Final Answer: <answer>."
),
},
{
"role": "user",
"content": [
{"type": "text", "text": row["question"] + "\n\n"},
{
"type": "image_pil",
"image_pil": Image.open(BytesIO(row["image"]["bytes"])),
},
{
"type": "text",
"text": "\n\nChoices:\n"
+ "\n".join(
[
f"{choice_indices[i]}. {choice}"
for i, choice in enumerate(row["answer"])
]
),
},
],
},
],
"sampling_params": {
"temperature": 0.3,
"max_tokens": 150,
"detokenize": False,
},
# Include original data for reference
"original_data": {
"question": row["question"],
"answer_choices": row["answer"],
"image_size": row["image"].get("width", 0) if row["image"] else 0,
},
}
def vision_postprocess(row: dict) -> dict:
return {
"resp": row["generated_text"],
}
Finally, run the VLM inference:
"""Run the complete VLM example workflow."""
config = create_vlm_config()
vision_dataset = load_vision_dataset()
if vision_dataset:
# Build processor with preprocessing and postprocessing
processor = build_processor(
config, preprocess=vision_preprocess, postprocess=vision_postprocess
)
print("VLM processor configured successfully")
print(f"Model: {config.model_source}")
print(f"Has multimodal support: {config.prepare_multimodal_stage.get('enabled', False)}")
result = processor(vision_dataset).take_all()
return config, processor, result
Video batch inference with vision language model (VLM)#
First, load a video dataset:
"""
Load video dataset from ShareGPTVideo Hugging Face dataset.
"""
try:
from huggingface_hub import hf_hub_download
import tarfile
from pathlib import Path
dataset_name = "ShareGPTVideo/train_raw_video"
tar_path = hf_hub_download(
repo_id=dataset_name,
filename="activitynet/chunk_0.tar.gz",
repo_type="dataset",
)
extract_dir = "/tmp/sharegpt_videos"
os.makedirs(extract_dir, exist_ok=True)
if not any(Path(extract_dir).glob("*.mp4")):
with tarfile.open(tar_path, "r:gz") as tar:
tar.extractall(extract_dir)
video_files = list(Path(extract_dir).rglob("*.mp4"))
# Limit to first 10 videos for the example
video_files = video_files[:10]
video_dataset = ray.data.from_items(
[
{
"video_path": str(video_file),
"video_url": f"file://{video_file}",
"text": "Describe what happens in this video.",
}
for video_file in video_files
]
)
return video_dataset
except Exception as e:
print(f"Error loading dataset: {e}")
return None
Next, configure the VLM processor with the essential settings:
video_processor_config = vLLMEngineProcessorConfig(
model_source="Qwen/Qwen3-VL-4B-Instruct",
engine_kwargs=dict(
tensor_parallel_size=4,
pipeline_parallel_size=1,
trust_remote_code=True,
limit_mm_per_prompt={"video": 1},
),
batch_size=1,
accelerator_type="L4",
concurrency=1,
prepare_multimodal_stage={
"enabled": True,
"model_config_kwargs": dict(
# See available model config kwargs at https://docs.vllm.ai/en/latest/api/vllm/config/#vllm.config.ModelConfig
allowed_local_media_path="/tmp",
),
},
chat_template_stage=True,
tokenize_stage=True,
detokenize_stage=True,
)
Define preprocessing and postprocessing functions to convert dataset rows into the format expected by the VLM and extract model responses. Within the preprocessor, structure video data as part of an OpenAI-compatible message.
def video_preprocess(row: dict) -> dict:
"""
Preprocessing function for video-language model inputs.
Converts dataset rows into the format expected by the VLM:
- System prompt for analysis instructions
- User message with text and video content
- Sampling parameters
- Multimodal processor kwargs for video processing
"""
return {
"messages": [
{
"role": "system",
"content": (
"You are a helpful assistant that analyzes videos. "
"Watch the video carefully and provide detailed descriptions."
),
},
{
"role": "user",
"content": [
{
"type": "text",
"text": row["text"],
},
{
"type": "video_url",
"video_url": {"url": row["video_url"]},
},
],
},
],
"sampling_params": {
"temperature": 0.3,
"max_tokens": 150,
"detokenize": False,
},
# Optional: Multimodal processor kwargs for video processing
"mm_processor_kwargs": dict(
min_pixels=28 * 28,
max_pixels=1280 * 28 * 28,
fps=1,
),
}
def video_postprocess(row: dict) -> dict:
return {
"resp": row["generated_text"],
}
Finally, run the VLM inference:
"""Run the complete VLM video example workflow."""
config = create_vlm_video_config()
video_dataset = load_video_dataset()
if video_dataset:
# Build processor with preprocessing and postprocessing
processor = build_processor(
config, preprocess=video_preprocess, postprocess=video_postprocess
)
print("VLM video processor configured successfully")
print(f"Model: {config.model_source}")
print(f"Has multimodal support: {config.prepare_multimodal_stage.get('enabled', False)}")
result = processor(video_dataset).take_all()
return config, processor, result
Audio batch inference with omni-modal model#
First, load an audio dataset:
"""
Load audio dataset from MRSAudio Hugging Face dataset.
"""
try:
from datasets import load_dataset
from huggingface_hub import hf_hub_download
import base64
dataset_name = "MRSAudio/MRSAudio"
dataset = load_dataset(dataset_name, split="train")
audio_items = []
# Limit to first 10 samples for the example
num_samples = min(10, len(dataset))
for i in range(num_samples):
item = dataset[i]
audio_path = hf_hub_download(
repo_id=dataset_name, filename=item["path"], repo_type="dataset"
)
with open(audio_path, "rb") as f:
audio_bytes = f.read()
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
audio_items.append(
{
"audio_data": audio_base64,
"text": item.get("text", "Describe this audio."),
}
)
audio_dataset = ray.data.from_items(audio_items)
return audio_dataset
except Exception as e:
print(f"Error loading dataset: {e}")
return None
Next, configure the omni-modal processor with the essential settings:
audio_processor_config = vLLMEngineProcessorConfig(
model_source="Qwen/Qwen2.5-Omni-3B",
task_type="generate",
engine_kwargs=dict(
limit_mm_per_prompt={"audio": 1},
),
batch_size=16,
accelerator_type="L4",
concurrency=1,
prepare_multimodal_stage={
"enabled": True,
"chat_template_content_format": "openai",
},
chat_template_stage=True,
tokenize_stage=True,
detokenize_stage=True,
)
Define preprocessing and postprocessing functions to convert dataset rows into the format expected by the omni-modal model and to extract model responses. Within the preprocessor, structure audio data as part of an OpenAI-compatible message. Both audio URLs and base64-encoded audio data are supported.
"""Supported audio input formats: audio URL, audio binary data"""
{
"messages": [
{
"role": "system",
"content": "Provide a detailed description of the audio."
},
{
"role": "user",
"content": [
{"type": "text", "text": "Describe what happens in this audio."},
# Option 1: Provide audio URL
{"type": "audio_url", "audio_url": {"url": "https://example.com/audio.wav"}},
# Option 2: Provide audio binary data
{"type": "input_audio", "input_audio": {"data": audio_base64, "format": "wav"}},
]
},
]
}
def audio_preprocess(row: dict) -> dict:
"""
Preprocessing function for audio-language model inputs.
Converts dataset rows into the format expected by the Omni model:
- System prompt for analysis instructions
- User message with text and audio content
- Sampling parameters
"""
return {
"messages": [
{
"role": "system",
"content": "You are a helpful assistant that analyzes audio. "
"Listen to the audio carefully and provide detailed descriptions.",
},
{
"role": "user",
"content": [
{
"type": "text",
"text": row["text"],
},
{
"type": "input_audio",
"input_audio": {
"data": row["audio_data"],
"format": "wav",
},
},
],
},
],
"sampling_params": {
"temperature": 0.3,
"max_tokens": 150,
"detokenize": False,
},
}
def audio_postprocess(row: dict) -> dict:
return {
"resp": row["generated_text"],
}
Finally, run the omni-modal inference:
"""Run the complete Omni audio example workflow."""
config = create_omni_audio_config()
audio_dataset = load_audio_dataset()
if audio_dataset:
# Build processor with preprocessing and postprocessing
processor = build_processor(
config, preprocess=audio_preprocess, postprocess=audio_postprocess
)
print("Omni audio processor configured successfully")
print(f"Model: {config.model_source}")
print(f"Has multimodal support: {config.prepare_multimodal_stage.get('enabled', False)}")
result = processor(audio_dataset).take_all()
return config, processor, result
Embeddings#
For embedding models, set task_type="embed" and disable chat templating:
import ray
from ray.data.llm import vLLMEngineProcessorConfig, build_processor
embedding_config = vLLMEngineProcessorConfig(
model_source="sentence-transformers/all-MiniLM-L6-v2",
task_type="embed",
engine_kwargs=dict(
enable_prefix_caching=False,
enable_chunked_prefill=False,
max_model_len=256,
enforce_eager=True,
),
batch_size=32,
concurrency=1,
chat_template_stage=False, # Skip chat templating for embeddings
detokenize_stage=False, # Skip detokenization for embeddings
)
embedding_processor = build_processor(
embedding_config,
preprocess=lambda row: dict(prompt=row["text"]),
postprocess=lambda row: {
"text": row["prompt"],
"embedding": row["embeddings"],
},
)
texts = [
"Hello world",
"This is a test sentence",
"Embedding models convert text to vectors",
]
ds = ray.data.from_items([{"text": text} for text in texts])
embedded_ds = embedding_processor(ds)
embedded_ds.show(limit=1)
Key differences from text generation:
Use prompt input instead of messages
Access results through row["embeddings"]
Classification#
Ray Data LLM supports batch inference with sequence classification models, such as content classifiers and sentiment analyzers:
import ray
from ray.data.llm import vLLMEngineProcessorConfig, build_processor
# Configure vLLM for a sequence classification model
classification_config = vLLMEngineProcessorConfig(
model_source="nvidia/nemocurator-fineweb-nemotron-4-edu-classifier",
task_type="classify", # Use 'classify' for sequence classification models
engine_kwargs=dict(
max_model_len=512,
enforce_eager=True,
),
batch_size=8,
concurrency=1,
chat_template_stage=False,
detokenize_stage=False,
)
classification_processor = build_processor(
classification_config,
preprocess=lambda row: dict(prompt=row["text"]),
postprocess=lambda row: {
"text": row["prompt"],
# Classification models return logits in the 'embeddings' field
"edu_score": float(row["embeddings"][0])
if row.get("embeddings") is not None and len(row["embeddings"]) > 0
else None,
},
)
# Sample texts with varying educational quality
texts = [
"lol that was so funny haha",
"Photosynthesis converts light energy into chemical energy.",
"Newton's laws describe the relationship between forces and motion.",
]
ds = ray.data.from_items([{"text": text} for text in texts])
if __name__ == "__main__":
try:
import torch
if torch.cuda.is_available():
classified_ds = classification_processor(ds)
classified_ds.show(limit=3)
else:
print("Skipping classification run (no GPU available)")
except Exception as e:
print(f"Skipping classification run due to environment error: {e}")
Key differences for classification models:
Set task_type="classify" (or task_type="score" for scoring models)
Set chat_template_stage=False and detokenize_stage=False
Use direct prompt input instead of messages
Access classification logits through row["embeddings"]
OpenAI-compatible endpoints#
Query deployed models with an OpenAI-compatible API:
import os

import ray
from ray.data.llm import HttpRequestProcessorConfig, build_processor

OPENAI_KEY = os.environ["OPENAI_API_KEY"]
ds = ray.data.from_items(["Hand me a haiku."])
config = HttpRequestProcessorConfig(
url="https://api.openai.com/v1/chat/completions",
headers={"Authorization": f"Bearer {OPENAI_KEY}"},
qps=1,
)
processor = build_processor(
config,
preprocess=lambda row: dict(
payload=dict(
model="gpt-4o-mini",
messages=[
{
"role": "system",
"content": "You are a bot that responds with haikus.",
},
{"role": "user", "content": row["item"]},
],
temperature=0.0,
max_tokens=150,
),
),
postprocess=lambda row: dict(
response=row["http_response"]["choices"][0]["message"]["content"]
),
)
ds = processor(ds)
print(ds.take_all())
Troubleshooting#
GPU memory and CUDA OOM#
If you encounter CUDA out of memory errors, try these strategies:
Reduce batch size: Start with 8-16 and increase gradually
Lower max_num_batched_tokens: Reduce from 4096 to 2048 or 1024
Decrease max_model_len: Use shorter context lengths
Set gpu_memory_utilization: Use 0.75-0.85 instead of the default 0.90
# GPU memory management configuration
# If you encounter CUDA out of memory errors, try these optimizations:
config_memory_optimized = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"max_model_len": 8192,
"max_num_batched_tokens": 2048,
"enable_chunked_prefill": True,
"gpu_memory_utilization": 0.85,
"block_size": 16,
},
concurrency=1,
batch_size=16,
)
# For very large models or limited GPU memory:
config_minimal_memory = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"max_model_len": 4096,
"max_num_batched_tokens": 1024,
"enable_chunked_prefill": True,
"gpu_memory_utilization": 0.75,
},
concurrency=1,
batch_size=8,
)
Model loading at scale#
For large clusters, HuggingFace downloads may be rate-limited. Cache models to S3 or GCS:
python -m ray.llm.utils.upload_model \
--model-source facebook/opt-350m \
--bucket-uri gs://my-bucket/path/to/model
Then reference the remote path in your config:
# S3 hosted model configuration
s3_config = vLLMEngineProcessorConfig(
model_source="s3://your-bucket/your-model-path/",
engine_kwargs={
"load_format": "runai_streamer",
"max_model_len": 16384,
},
concurrency=1,
batch_size=64,
)
Resiliency#
Row-level fault tolerance#
In Ray Data LLM, you enable row-level fault tolerance by setting the should_continue_on_error parameter to True in the processor config.
If a single row fails due to a request-level error from the engine, the job continues processing the remaining rows.
This is useful for long-running jobs where you want to minimize the impact of individual request failures.
# Row-level fault tolerance configuration
config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
concurrency=1,
batch_size=64,
should_continue_on_error=True,
)
Actor-level fault tolerance#
When an actor dies in the middle of a pipeline execution, it’s restarted and rejoins the actor pool to process the remaining rows. This feature is enabled by default, and no additional configuration is needed.
Checkpoint recovery#
Ray Data supports checkpoint recovery, which lets you resume pipeline execution from a checkpoint stored in local or cloud storage. Checkpointing works only for pipelines that start with a read operation and end with a write operation. For checkpointing to take effect, successful blocks must reach the write sink before a failure occurs. After a failure, you can resume processing from the checkpoint in a subsequent run.
First, set up the checkpoint configuration and specify the ID column for checkpointing.
from ray.data.checkpoint import CheckpointConfig
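# Assumes checkpoint_path points at local or cloud storage for checkpoint data,
# for example "/tmp/ray_checkpoints" or "s3://my-bucket/checkpoints".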
ctx = ray.data.DataContext.get_current()
ctx.checkpoint_config = CheckpointConfig(
id_column="id",
checkpoint_path=checkpoint_path,
delete_checkpoint_on_success=False,
)
Then, include a read and a write operation in the pipeline to enable checkpoint recovery. Preserve the ID column in both preprocess and postprocess so that it’s stored in the checkpoint.
processor_config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
concurrency=1,
batch_size=16,
)
processor = build_processor(
processor_config,
preprocess=lambda row: dict(
id=row["id"], # Preserve the ID column for checkpointing
prompt=row["prompt"],
sampling_params=dict(
temperature=0.3,
max_tokens=10,
),
),
postprocess=lambda row: {
"id": row["id"], # Preserve the ID column for checkpointing
"answer": row.get("generated_text"),
},
)
ds = ray.data.read_parquet(input_path)
ds = processor(ds)
ds.write_parquet(output_path)
To resume from a checkpoint, run the same code again. Ray Data discovers the checkpoint and resumes from the last successful block.
Advanced configuration#
Model parallelism#
For large models that don’t fit on a single GPU, use tensor and pipeline parallelism:
# Model parallelism configuration for larger models
# tensor_parallel_size=2: Split model across 2 GPUs for tensor parallelism
# pipeline_parallel_size=2: Use 2 pipeline stages (total 4 GPUs needed)
# Total GPUs required = tensor_parallel_size * pipeline_parallel_size = 4
config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"max_model_len": 16384,
"tensor_parallel_size": 2,
"pipeline_parallel_size": 2,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 2048,
},
concurrency=1,
batch_size=32,
accelerator_type="L4",
)
Cross-node parallelism#
Ray Data LLM supports cross-node parallelism, including tensor parallelism and pipeline parallelism. Configure the parallelism level through engine_kwargs. The distributed_executor_backend defaults to "ray" for cross-node support.
config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4096,
"max_model_len": 16384,
"pipeline_parallel_size": 4,
"tensor_parallel_size": 4,
"distributed_executor_backend": "ray",
},
batch_size=32,
concurrency=1,
)
You can customize the placement group strategy to control how Ray places vLLM engine workers across nodes. While you can specify the degree of tensor and pipeline parallelism, the specific assignment of model ranks to GPUs is managed by the vLLM engine.
config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4096,
"max_model_len": 16384,
"pipeline_parallel_size": 2,
"tensor_parallel_size": 2,
"distributed_executor_backend": "ray",
},
batch_size=32,
concurrency=1,
placement_group_config={
"bundles": [{"GPU": 1}] * 4,
"strategy": "STRICT_PACK",
},
)
Per-stage configuration#
Configure individual pipeline stages for fine-grained resource control:
config = vLLMEngineProcessorConfig(
model_source="meta-llama/Llama-3.1-8B-Instruct",
chat_template_stage={
"enabled": True,
"batch_size": 256,
"concurrency": 4,
},
tokenize_stage={
"enabled": True,
"batch_size": 512,
"num_cpus": 0.5,
},
detokenize_stage={
"enabled": True,
"concurrency": (2, 8), # Autoscaling pool
},
)
See stage config classes for all available fields.
LoRA adapters#
For multi-LoRA batch inference:
# Multi-LoRA configuration
config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"enable_lora": True,
"max_lora_rank": 32,
"max_loras": 1,
"max_model_len": 16384,
},
concurrency=1,
batch_size=32,
)
See the vLLM with LoRA example for details.
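One common pattern, shown in more detail in the linked example, is to select the adapter for each row through the model field of the preprocessed output. A minimal sketch, using a hypothetical adapter ID my-org/llama-3.1-lora-adapter:
lora_processor = build_processor(
    config,
    preprocess=lambda row: dict(
        # Hypothetical LoRA adapter; different rows can reference different adapters
        model="my-org/llama-3.1-lora-adapter",
        messages=[{"role": "user", "content": row["prompt"]}],
        sampling_params=dict(temperature=0.3, max_tokens=100),
    ),
    postprocess=lambda row: dict(response=row["generated_text"]),
)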
Accelerated model loading with RunAI streamer#
Use RunAI Model Streamer for faster model loading from cloud storage:
Note
Install vLLM with runai dependencies: pip install -U "vllm[runai]>=0.10.1"
# RunAI streamer configuration for optimized model loading
# Note: Install vLLM with runai dependencies: pip install -U "vllm[runai]>=0.10.1"
config = vLLMEngineProcessorConfig(
model_source="unsloth/Llama-3.1-8B-Instruct",
engine_kwargs={
"load_format": "runai_streamer",
"max_model_len": 16384,
},
concurrency=1,
batch_size=64,
)
Serve deployments#
For multi-turn conversations or complex agentic workflows, share a vLLM engine across multiple processors using Ray Serve:
import ray
from ray import serve
from ray.data.llm import ServeDeploymentProcessorConfig, build_processor
from ray.serve.llm import (
LLMConfig,
ModelLoadingConfig,
build_llm_deployment,
)
from ray.serve.llm.openai_api_models import CompletionRequest
llm_config = LLMConfig(
model_loading_config=ModelLoadingConfig(
model_id="facebook/opt-1.3b",
model_source="facebook/opt-1.3b",
),
deployment_config=dict(
name="demo_deployment_config",
autoscaling_config=dict(
min_replicas=1,
max_replicas=1,
),
),
engine_kwargs=dict(
enable_prefix_caching=True,
enable_chunked_prefill=True,
max_num_batched_tokens=4096,
),
)
APP_NAME = "demo_app"
DEPLOYMENT_NAME = "demo_deployment"
override_serve_options = dict(name=DEPLOYMENT_NAME)
llm_app = build_llm_deployment(
llm_config, override_serve_options=override_serve_options
)
app = serve.run(llm_app, name=APP_NAME)
config = ServeDeploymentProcessorConfig(
deployment_name=DEPLOYMENT_NAME,
app_name=APP_NAME,
dtype_mapping={
"CompletionRequest": CompletionRequest,
},
concurrency=1,
batch_size=64,
)
processor1 = build_processor(
config,
preprocess=lambda row: dict(
method="completions",
dtype="CompletionRequest",
request_kwargs=dict(
model="facebook/opt-1.3b",
prompt=f"This is a prompt for {row['id']}",
stream=False,
),
),
postprocess=lambda row: dict(
prompt=row["choices"][0]["text"],
),
)
processor2 = build_processor(
config,
preprocess=lambda row: dict(
method="completions",
dtype="CompletionRequest",
request_kwargs=dict(
model="facebook/opt-1.3b",
prompt=row["prompt"],
stream=False,
),
),
postprocess=lambda row: row,
)
ds = ray.data.range(10)
ds = processor2(processor1(ds))
print(ds.take_all())
Usage data collection: Ray collects anonymous usage data to improve Ray Data LLM. To opt out, see Ray usage stats.