Serve a Chatbot with Request and Response Streaming#
This example deploys a chatbot that streams output back to the user. It shows:
How to stream outputs from a Serve application
How to use WebSockets in a Serve application
How to combine batching requests with streaming outputs
This tutorial should help you with the following use cases:
You want to serve a large language model and stream results back token-by-token.
You want to serve a chatbot that accepts a stream of inputs from the user.
This tutorial serves the DialoGPT language model. Install Ray Serve, Hugging Face Transformers, and PyTorch to access it:
pip install "ray[serve]" transformers torch
Create a streaming deployment#
Open a new Python file called textbot.py. First, add the imports and the Serve logger.
import asyncio
import logging
from queue import Empty
from fastapi import FastAPI
from starlette.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from ray import serve
logger = logging.getLogger("ray.serve")
Create a FastAPI deployment, and initialize the model and the tokenizer in the constructor:
fastapi_app = FastAPI()

@serve.deployment
@serve.ingress(fastapi_app)
class Textbot:
    def __init__(self, model_id: str):
        self.loop = asyncio.get_running_loop()
        self.model_id = model_id
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
Note that the constructor also caches the running asyncio event loop. This cached loop is useful when you need to run a model in a background thread and concurrently stream its tokens back to the user.
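For intuition, here is a minimal, Serve-independent sketch of that pattern: a blocking producer runs on a thread pool through run_in_executor while a coroutine polls a queue and yields control with asyncio.sleep whenever no token is ready. The producer, queue, and sentinel here are illustrative only and aren't part of the tutorial's code.

import asyncio
from queue import Empty, Queue

def blocking_producer(q: Queue):
    # Stand-in for model generation: push tokens, then a sentinel to end the stream.
    for token in ["Hello", " ", "world"]:
        q.put(token)
    q.put(None)

async def main():
    loop = asyncio.get_running_loop()
    q = Queue()
    loop.run_in_executor(None, blocking_producer, q)
    while True:
        try:
            token = q.get_nowait()
        except Empty:
            # Nothing ready yet; yield control back to the event loop.
            await asyncio.sleep(0.001)
            continue
        if token is None:
            break
        print(token, end="")

asyncio.run(main())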
Add the following logic to handle requests sent to the Textbot:
    @fastapi_app.post("/")
    def handle_request(self, prompt: str) -> StreamingResponse:
        logger.info(f'Got prompt: "{prompt}"')
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=0, skip_prompt=True, skip_special_tokens=True
        )
        self.loop.run_in_executor(None, self.generate_text, prompt, streamer)
        return StreamingResponse(
            self.consume_streamer(streamer), media_type="text/plain"
        )

    def generate_text(self, prompt: str, streamer: TextIteratorStreamer):
        input_ids = self.tokenizer([prompt], return_tensors="pt").input_ids
        self.model.generate(input_ids, streamer=streamer, max_length=10000)

    async def consume_streamer(self, streamer: TextIteratorStreamer):
        while True:
            try:
                for token in streamer:
                    logger.info(f'Yielding token: "{token}"')
                    yield token
                break
            except Empty:
                # The streamer raises an Empty exception if the next token
                # hasn't been generated yet. `await` here to yield control
                # back to the event loop so other coroutines can run.
                await asyncio.sleep(0.001)
Textbot uses three methods to handle requests:
handle_request: the entrypoint for HTTP requests. FastAPI automatically unpacks the prompt query parameter and passes it into handle_request. This method then creates a TextIteratorStreamer. Hugging Face provides this streamer as a convenient interface to access tokens generated by a language model. handle_request then kicks off the model in a background thread using self.loop.run_in_executor. This behavior lets the model generate tokens while handle_request concurrently calls self.consume_streamer to stream the tokens back to the user. self.consume_streamer is a generator that yields tokens one by one from the streamer. Lastly, handle_request passes the self.consume_streamer generator into a Starlette StreamingResponse and returns the response. Serve unpacks the Starlette StreamingResponse and yields the contents of the generator back to the user one by one.
generate_text: the method that runs the model. This method runs in a background thread kicked off by handle_request. It pushes generated tokens into the streamer constructed by handle_request.
consume_streamer: a generator method that consumes the streamer constructed by handle_request. This method keeps yielding tokens from the streamer until the model in generate_text closes the streamer. This method avoids blocking the event loop by calling asyncio.sleep with a brief timeout whenever the streamer is empty and waiting for a new token.
Bind the Textbot to a language model. For this tutorial, use the "microsoft/DialoGPT-small" model:
app = Textbot.bind("microsoft/DialoGPT-small")
Run the model with serve run textbot:app, and query it from another terminal window with this script:
import requests
prompt = "Tell me a story about dogs."
response = requests.post(f"http://localhost:8000/?prompt={prompt}", stream=True)
response.raise_for_status()
for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="")

# Dogs are the best.
You should see the output printed token by token.
Stream inputs and outputs using WebSockets#
WebSockets let you stream input into the application and stream output back to the client. Use WebSockets to create a chatbot that stores a conversation with a user.
Create a Python file called chatbot.py. First add the imports:
import asyncio
import logging
from queue import Empty
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from ray import serve
logger = logging.getLogger("ray.serve")
Create a FastAPI deployment, and initialize the model and the tokenizer in the constructor:
fastapi_app = FastAPI()

@serve.deployment
@serve.ingress(fastapi_app)
class Chatbot:
    def __init__(self, model_id: str):
        self.loop = asyncio.get_running_loop()
        self.model_id = model_id
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
Add the following logic to handle requests sent to the Chatbot:
    @fastapi_app.websocket("/")
    async def handle_request(self, ws: WebSocket) -> None:
        await ws.accept()

        conversation = ""
        try:
            while True:
                prompt = await ws.receive_text()
                logger.info(f'Got prompt: "{prompt}"')
                conversation += prompt
                streamer = TextIteratorStreamer(
                    self.tokenizer,
                    timeout=0,
                    skip_prompt=True,
                    skip_special_tokens=True,
                )
                self.loop.run_in_executor(
                    None, self.generate_text, conversation, streamer
                )
                response = ""
                async for text in self.consume_streamer(streamer):
                    await ws.send_text(text)
                    response += text
                await ws.send_text("<<Response Finished>>")
                conversation += response
        except WebSocketDisconnect:
            print("Client disconnected.")

    def generate_text(self, prompt: str, streamer: TextIteratorStreamer):
        input_ids = self.tokenizer([prompt], return_tensors="pt").input_ids
        self.model.generate(input_ids, streamer=streamer, max_length=10000)

    async def consume_streamer(self, streamer: TextIteratorStreamer):
        while True:
            try:
                for token in streamer:
                    logger.info(f'Yielding token: "{token}"')
                    yield token
                break
            except Empty:
                await asyncio.sleep(0.001)
The generate_text and consume_streamer methods are the same as they were for the Textbot. The handle_request method has been updated to handle WebSocket requests.
The handle_request method is decorated with the fastapi_app.websocket decorator, which lets it accept WebSocket connections. It first awaits ws.accept() to accept the client's WebSocket connection. Then, until the client disconnects, it does the following:
gets the prompt from the client with ws.receive_text
starts a new TextIteratorStreamer to access generated tokens
runs the model in a background thread on the conversation so far
streams the model's output back using ws.send_text
stores the prompt and the response in the conversation string
Each time handle_request gets a new prompt from a client, it runs the whole conversation, with the new prompt appended, through the model. When the model finishes generating tokens, handle_request sends the "<<Response Finished>>" string to inform the client that the model has generated all tokens. handle_request continues to run until the client explicitly disconnects. This disconnect raises a WebSocketDisconnect exception, which ends the call.
Read more about WebSockets in the FastAPI documentation.
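As a point of reference, the same accept/receive/send/disconnect flow works in a plain FastAPI app with no Serve involved. The echo endpoint below is a minimal sketch for illustration only; echo_app and echo are hypothetical names, not part of the tutorial's code.

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

echo_app = FastAPI()

@echo_app.websocket("/")
async def echo(ws: WebSocket) -> None:
    # Accept the connection, then echo each message back until the client disconnects.
    await ws.accept()
    try:
        while True:
            text = await ws.receive_text()
            await ws.send_text(text)
    except WebSocketDisconnect:
        pass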
Bind the Chatbot to a language model. For this tutorial, use the "microsoft/DialoGPT-small" model:
app = Chatbot.bind("microsoft/DialoGPT-small")
Run the model with serve run chatbot:app. Query it from another terminal window using the websockets package (install it with pip install websockets):
from websockets.sync.client import connect

with connect("ws://localhost:8000") as websocket:
    websocket.send("Space the final")
    while True:
        received = websocket.recv()
        if received == "<<Response Finished>>":
            break
        print(received, end="")
    print("\n")

    websocket.send(" These are the voyages")
    while True:
        received = websocket.recv()
        if received == "<<Response Finished>>":
            break
        print(received, end="")
    print("\n")
You should see the outputs printed token by token.
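If your client already runs inside an event loop, an asyncio-based variant of the same script should work as well. This sketch assumes the same "<<Response Finished>>" sentinel and the asynchronous API of the websockets package:

import asyncio
import websockets

async def chat():
    async with websockets.connect("ws://localhost:8000") as websocket:
        for prompt in ["Space the final", " These are the voyages"]:
            await websocket.send(prompt)
            # Print streamed tokens until the sentinel marks the end of the response.
            while True:
                received = await websocket.recv()
                if received == "<<Response Finished>>":
                    break
                print(received, end="")
            print("\n")

asyncio.run(chat())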
Batch requests and stream the output for each#
Improve model utilization and request latency by batching requests together when running the model.
Create a Python file called batchbot.py. First add the imports:
import asyncio
import logging
from queue import Empty, Queue
from typing import List

from fastapi import FastAPI
from starlette.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

from ray import serve

logger = logging.getLogger("ray.serve")
Warning
Hugging Face’s support for Streamers is still under development and may change in the future. RawStreamer is compatible with the Streamers interface in Hugging Face 4.30.2. However, the Streamers interface may change, making RawStreamer incompatible with Hugging Face models in the future.
Similar to Textbot and Chatbot, the Batchbot needs a streamer to stream outputs from batched requests, but Hugging Face Streamers don’t support batched requests. Add this custom RawStreamer to process batches of tokens:
class RawStreamer:
    def __init__(self, timeout: float = None):
        self.q = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def put(self, values):
        self.q.put(values)

    def end(self):
        self.q.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        result = self.q.get(timeout=self.timeout)
        if result == self.stop_signal:
            raise StopIteration()
        else:
            return result
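To see how RawStreamer behaves on its own, a quick standalone check like the following (not part of the deployment) puts two batches of values, closes the stream, and iterates until the stop signal:

streamer = RawStreamer()
streamer.put([1, 2])
streamer.put([3, 4])
streamer.end()  # pushes the stop signal
for token_batch in streamer:
    print(token_batch)
# [1, 2]
# [3, 4]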
Create a FastAPI deployment, and initialize the model and the tokenizer in the constructor:
fastapi_app = FastAPI()

@serve.deployment
@serve.ingress(fastapi_app)
class Batchbot:
    def __init__(self, model_id: str):
        self.loop = asyncio.get_running_loop()
        self.model_id = model_id
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token
Unlike Textbot and Chatbot, the Batchbot constructor also sets a pad_token. You need to set this token to batch prompts with different lengths.
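For a quick illustration of the effect, outside of the deployment, tokenizing two prompts of different lengths with padding enabled produces a single rectangular batch. The exact shape below is an example and depends on the tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
tokenizer.pad_token = tokenizer.eos_token
batch = tokenizer(["Hi", "Tell me a story about dogs."], return_tensors="pt", padding=True)
print(batch.input_ids.shape)  # for example, torch.Size([2, 8]); the shorter prompt is padded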
Add the following logic to handle requests sent to the Batchbot:
    @fastapi_app.post("/")
    async def handle_request(self, prompt: str) -> StreamingResponse:
        logger.info(f'Got prompt: "{prompt}"')
        return StreamingResponse(self.run_model(prompt), media_type="text/plain")

    @serve.batch(max_batch_size=2, batch_wait_timeout_s=15)
    async def run_model(self, prompts: List[str]):
        streamer = RawStreamer()
        self.loop.run_in_executor(None, self.generate_text, prompts, streamer)

        on_prompt_tokens = True
        async for decoded_token_batch in self.consume_streamer(streamer):
            # The first batch of tokens contains the prompts, so we skip it.
            if not on_prompt_tokens:
                logger.info(f"Yielding decoded_token_batch: {decoded_token_batch}")
                yield decoded_token_batch
            else:
                logger.info(f"Skipped prompts: {decoded_token_batch}")
                on_prompt_tokens = False

    def generate_text(self, prompts: List[str], streamer: RawStreamer):
        input_ids = self.tokenizer(prompts, return_tensors="pt", padding=True).input_ids
        self.model.generate(input_ids, streamer=streamer, max_length=10000)

    async def consume_streamer(self, streamer: RawStreamer):
        while True:
            try:
                for token_batch in streamer:
                    decoded_tokens = []
                    for token in token_batch:
                        decoded_tokens.append(
                            self.tokenizer.decode(token, skip_special_tokens=True)
                        )
                    logger.info(f"Yielding decoded tokens: {decoded_tokens}")
                    yield decoded_tokens
                break
            except Empty:
                await asyncio.sleep(0.001)
Batchbot uses four methods to handle requests:
handle_request: the entrypoint method. This method simply takes in the request’s prompt and calls the run_model method on it. run_model is a generator method that also handles batching the requests. handle_request passes run_model into a Starlette StreamingResponse and returns the response, so the bot can stream generated tokens back to the client.
run_model: a generator method that performs batching. Since run_model is decorated with @serve.batch, it automatically takes in a batch of prompts. See the batching guide for more info. run_model creates a RawStreamer to access the generated tokens. It calls generate_text in a background thread, and passes in the prompts and the streamer, similar to the Textbot. Then it iterates through the consume_streamer generator, repeatedly yielding a batch of tokens generated by the model.
generate_text: the method that runs the model. It’s mostly the same as generate_text in Textbot, with two differences. First, it takes in and processes a batch of prompts instead of a single prompt. Second, it sets padding=True, so prompts with different lengths can be batched together.
consume_streamer: a generator method that consumes the streamer constructed by run_model. It’s mostly the same as consume_streamer in Textbot, with one difference: it uses the tokenizer to decode the generated tokens. Usually, the Hugging Face streamer handles the decoding. Because this implementation uses the custom RawStreamer, consume_streamer must handle the decoding.
Tip
Some inputs within a batch may generate fewer outputs than others. When a particular input has nothing left to yield, pass a StopIteration object into the output iterable to terminate that input’s request. See Streaming batched requests for more details.
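The sketch below, using a hypothetical ToyBatcher deployment, illustrates the pattern the tip describes: the batched generator yields one list per step, and a finished request's slot gets a StopIteration object so that request terminates while the rest of the batch keeps streaming.

from typing import List

from ray import serve

@serve.deployment
class ToyBatcher:
    @serve.batch(max_batch_size=4, batch_wait_timeout_s=0.1)
    async def stream(self, prompts: List[str]):
        # Each prompt yields one chunk per word; shorter prompts finish early
        # and yield StopIteration in their slot from then on.
        token_lists = [prompt.split() for prompt in prompts]
        for i in range(max(len(tokens) for tokens in token_lists)):
            yield [
                tokens[i] if i < len(tokens) else StopIteration
                for tokens in token_lists
            ]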
Bind the Batchbot to a language model. For this tutorial, use the "microsoft/DialoGPT-small" model:
app = Batchbot.bind("microsoft/DialoGPT-small")
Run the model with serve run batchbot:app. Query it from two other terminal windows with this script:
import requests
prompt = "Tell me a story about dogs."
response = requests.post(f"http://localhost:8000/?prompt={prompt}", stream=True)
response.raise_for_status()
for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="")

# Dogs are the best.
You should see the output printed token by token in both windows.