Serve a Chatbot with Request and Response Streaming#
This example deploys a chatbot that streams output back to the user. It shows:
How to stream outputs from a Serve application
How to use WebSockets in a Serve application
How to combine batching requests with streaming outputs
This tutorial should help you with the following use cases:
You want to serve a large language model and stream results back token-by-token.
You want to serve a chatbot that accepts a stream of inputs from the user.
This tutorial serves the DialoGPT language model. Install the Hugging Face library to access it:
pip install transformers
Create a streaming deployment#
Open a new Python file called textbot.py. First, add the imports and the Serve logger:
import asyncio
import logging
from queue import Empty
from fastapi import FastAPI
from starlette.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from ray import serve
logger = logging.getLogger("ray.serve")
Create a FastAPI deployment, and initialize the model and the tokenizer in the constructor:
fastapi_app = FastAPI()
@serve.deployment
@serve.ingress(fastapi_app)
class Textbot:
    def __init__(self, model_id: str):
        self.loop = asyncio.get_running_loop()
        self.model_id = model_id
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
Note that the constructor also caches an asyncio loop. This behavior is useful when you need to run a model and concurrently stream its tokens back to the user.
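As a point of reference, here's a minimal sketch of the same pattern outside of Serve (the Example class and kick_off method are hypothetical names): cache the running loop in the constructor, then use it later to schedule blocking work on a background thread without blocking the event loop.

# Minimal sketch, not part of the tutorial code; Example and kick_off are
# illustrative names only.
import asyncio

class Example:
    def __init__(self):
        # get_running_loop() only works inside a running event loop, which is
        # the case when Serve constructs a deployment replica.
        self.loop = asyncio.get_running_loop()

    def kick_off(self, blocking_fn, *args):
        # Run the blocking function on the default thread pool so the event
        # loop stays free to stream results concurrently.
        return self.loop.run_in_executor(None, blocking_fn, *args)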
Add the following logic to handle requests sent to the Textbot:
@fastapi_app.post("/")
def handle_request(self, prompt: str) -> StreamingResponse:
logger.info(f'Got prompt: "{prompt}"')
streamer = TextIteratorStreamer(
self.tokenizer, timeout=0, skip_prompt=True, skip_special_tokens=True
)
self.loop.run_in_executor(None, self.generate_text, prompt, streamer)
return StreamingResponse(
self.consume_streamer(streamer), media_type="text/plain"
)
def generate_text(self, prompt: str, streamer: TextIteratorStreamer):
input_ids = self.tokenizer([prompt], return_tensors="pt").input_ids
self.model.generate(input_ids, streamer=streamer, max_length=10000)
async def consume_streamer(self, streamer: TextIteratorStreamer):
while True:
try:
for token in streamer:
logger.info(f'Yielding token: "{token}"')
yield token
break
except Empty:
# The streamer raises an Empty exception if the next token
# hasn't been generated yet. `await` here to yield control
# back to the event loop so other coroutines can run.
await asyncio.sleep(0.001)
Textbot uses three methods to handle requests:

handle_request: the entrypoint for HTTP requests. FastAPI automatically unpacks the prompt query parameter and passes it into handle_request. This method then creates a TextIteratorStreamer. Hugging Face provides this streamer as a convenient interface to access tokens generated by a language model. handle_request then kicks off the model in a background thread using self.loop.run_in_executor. This behavior lets the model generate tokens while handle_request concurrently calls self.consume_streamer to stream the tokens back to the user. self.consume_streamer is a generator that yields tokens one by one from the streamer. Lastly, handle_request passes the self.consume_streamer generator into a Starlette StreamingResponse and returns the response. Serve unpacks the Starlette StreamingResponse and yields the contents of the generator back to the user one by one. A standalone sketch of this streaming pattern follows this list.

generate_text: the method that runs the model. This method runs in a background thread kicked off by handle_request. It pushes generated tokens into the streamer constructed by handle_request.

consume_streamer: a generator method that consumes the streamer constructed by handle_request. This method keeps yielding tokens from the streamer until the model in generate_text closes the streamer. This method avoids blocking the event loop by calling asyncio.sleep with a brief timeout whenever the streamer is empty and waiting for a new token.
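To see the Starlette streaming behavior in isolation, here's a minimal standalone FastAPI sketch (separate from the tutorial deployment; the number_stream generator and /numbers route are made up) that streams a generator's output chunk by chunk:

# Standalone sketch, assuming a plain FastAPI app rather than a Serve deployment.
import asyncio
from fastapi import FastAPI
from starlette.responses import StreamingResponse

app = FastAPI()

async def number_stream():
    for i in range(5):
        yield f"{i} "
        await asyncio.sleep(0.5)  # Simulate waiting for the next token.

@app.get("/numbers")
def numbers() -> StreamingResponse:
    # Starlette iterates the generator and sends each chunk as it's produced.
    return StreamingResponse(number_stream(), media_type="text/plain")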
Bind the Textbot to a language model. For this tutorial, use the "microsoft/DialoGPT-small" model:
app = Textbot.bind("microsoft/DialoGPT-small")
Run the model with serve run textbot:app, and query it from another terminal window with this script:
import requests
prompt = "Tell me a story about dogs."
response = requests.post(f"http://localhost:8000/?prompt={prompt}", stream=True)
response.raise_for_status()
for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
print(chunk, end="")
# Dogs are the best.
You should see the output printed token by token.
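An equivalent client sketch passes the prompt through params instead of formatting it into the URL, so requests handles the URL encoding:

import requests

prompt = "Tell me a story about dogs."
# Passing `params` lets requests percent-encode spaces and punctuation.
response = requests.post("http://localhost:8000/", params={"prompt": prompt}, stream=True)
response.raise_for_status()

for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="")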
Stream inputs and outputs using WebSockets#
WebSockets let you stream input into the application and stream output back to the client. Use WebSockets to create a chatbot that stores a conversation with a user.
Create a Python file called chatbot.py. First, add the imports:
import asyncio
import logging
from queue import Empty
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from ray import serve
logger = logging.getLogger("ray.serve")
Create a FastAPI deployment, and initialize the model and the tokenizer in the constructor:
fastapi_app = FastAPI()
@serve.deployment
@serve.ingress(fastapi_app)
class Chatbot:
    def __init__(self, model_id: str):
        self.loop = asyncio.get_running_loop()
        self.model_id = model_id
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
Add the following logic to handle requests sent to the Chatbot:
@fastapi_app.websocket("/")
async def handle_request(self, ws: WebSocket) -> None:
await ws.accept()
conversation = ""
try:
while True:
prompt = await ws.receive_text()
logger.info(f'Got prompt: "{prompt}"')
conversation += prompt
streamer = TextIteratorStreamer(
self.tokenizer,
timeout=0,
skip_prompt=True,
skip_special_tokens=True,
)
self.loop.run_in_executor(
None, self.generate_text, conversation, streamer
)
response = ""
async for text in self.consume_streamer(streamer):
await ws.send_text(text)
response += text
await ws.send_text("<<Response Finished>>")
conversation += response
except WebSocketDisconnect:
print("Client disconnected.")
def generate_text(self, prompt: str, streamer: TextIteratorStreamer):
input_ids = self.tokenizer([prompt], return_tensors="pt").input_ids
self.model.generate(input_ids, streamer=streamer, max_length=10000)
async def consume_streamer(self, streamer: TextIteratorStreamer):
while True:
try:
for token in streamer:
logger.info(f'Yielding token: "{token}"')
yield token
break
except Empty:
await asyncio.sleep(0.001)
The generate_text and consume_streamer methods are the same as they were for the Textbot. The handle_request method has been updated to handle WebSocket requests.
The handle_request method is decorated with a fastapi_app.websocket decorator, which lets it accept WebSocket requests. First, it awaits ws.accept() to accept the client’s WebSocket request. Then, until the client disconnects, it does the following:

gets the prompt from the client with ws.receive_text
starts a new TextIteratorStreamer to access generated tokens
runs the model in a background thread on the conversation so far
streams the model’s output back using ws.send_text
stores the prompt and the response in the conversation string
Each time handle_request gets a new prompt from a client, it runs the whole conversation, with the new prompt appended, through the model. When the model finishes generating tokens, handle_request sends the "<<Response Finished>>" string to inform the client that the model has generated all tokens. handle_request continues to run until the client explicitly disconnects. This disconnect raises a WebSocketDisconnect exception, which ends the call.
Read more about WebSockets in the FastAPI documentation.
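For a bare-bones view of the same pattern, here's a minimal echo sketch in plain FastAPI (separate from the Serve deployment; the /echo route is illustrative) showing the accept/receive/send loop and the WebSocketDisconnect exit path:

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()

@app.websocket("/echo")
async def echo(ws: WebSocket) -> None:
    await ws.accept()
    try:
        while True:
            # Wait for the client to send text, then send it straight back.
            text = await ws.receive_text()
            await ws.send_text(text)
    except WebSocketDisconnect:
        # Raised when the client closes the connection; ends the handler.
        pass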
Bind the Chatbot to a language model. For this tutorial, use the "microsoft/DialoGPT-small" model:
app = Chatbot.bind("microsoft/DialoGPT-small")
Run the model with serve run chatbot:app. Query it using the websockets package, which you can install with pip install websockets:
from websockets.sync.client import connect
with connect("ws://localhost:8000") as websocket:
websocket.send("Space the final")
while True:
received = websocket.recv()
if received == "<<Response Finished>>":
break
print(received, end="")
print("\n")
websocket.send(" These are the voyages")
while True:
received = websocket.recv()
if received == "<<Response Finished>>":
break
print(received, end="")
print("\n")
You should see the outputs printed token by token.
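If you prefer an asynchronous client, a sketch using the websockets asyncio API looks like this (same server and sentinel string as above):

import asyncio
from websockets import connect

async def chat():
    async with connect("ws://localhost:8000") as ws:
        await ws.send("Space the final")
        while True:
            token = await ws.recv()
            if token == "<<Response Finished>>":
                break
            print(token, end="")
        print()

asyncio.run(chat())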
Batch requests and stream the output for each#
Improve model utilization and request latency by batching requests together when running the model.
Create a Python file called batchbot.py. First, add the imports:
import asyncio
import logging
from queue import Empty, Queue
from typing import List
from fastapi import FastAPI
from starlette.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from ray import serve
logger = logging.getLogger("ray.serve")
Warning
Hugging Face’s support for Streamers is still under development and may change in the future. The RawStreamer defined below is compatible with the Streamers interface in Hugging Face 4.30.2. However, the Streamers interface may change, making the RawStreamer incompatible with Hugging Face models in the future.
Similar to Textbot and Chatbot, the Batchbot needs a streamer to stream outputs from batched requests, but Hugging Face Streamers don’t support batched requests. Add this custom RawStreamer to process batches of tokens:
class RawStreamer:
    def __init__(self, timeout: float = None):
        self.q = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def put(self, values):
        self.q.put(values)

    def end(self):
        self.q.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        result = self.q.get(timeout=self.timeout)
        if result == self.stop_signal:
            raise StopIteration()
        else:
            return result
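To see how RawStreamer behaves on its own, here's a quick illustration (not part of the deployment) that uses plain Python lists in place of token tensors:

streamer = RawStreamer()
streamer.put([101, 102])  # One batch of token IDs per put() call.
streamer.put([103, 104])
streamer.end()            # Pushes the stop signal (None).

for token_batch in streamer:
    print(token_batch)    # Prints [101, 102], then [103, 104], then stops.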
Create a FastAPI deployment, and initialize the model and the tokenizer in the constructor:
fastapi_app = FastAPI()
@serve.deployment
@serve.ingress(fastapi_app)
class Batchbot:
    def __init__(self, model_id: str):
        self.loop = asyncio.get_running_loop()
        self.model_id = model_id
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token
Unlike Textbot and Chatbot, the Batchbot constructor also sets a pad_token. You need to set this token to batch prompts with different lengths.
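Here's a small sketch of why the pad token matters, calling the tokenizer directly (the two prompts are just examples):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
tokenizer.pad_token = tokenizer.eos_token  # DialoGPT has no pad token by default.

# With padding=True, the shorter prompt is padded so both fit in one tensor.
batch = tokenizer(
    ["Hi", "Tell me a story about dogs."], return_tensors="pt", padding=True
)
print(batch.input_ids.shape)  # (2, length of the longest prompt)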
Add the following logic to handle requests sent to the Batchbot:
@fastapi_app.post("/")
async def handle_request(self, prompt: str) -> StreamingResponse:
logger.info(f'Got prompt: "{prompt}"')
return StreamingResponse(self.run_model(prompt), media_type="text/plain")
@serve.batch(max_batch_size=2, batch_wait_timeout_s=15)
async def run_model(self, prompts: List[str]):
streamer = RawStreamer()
self.loop.run_in_executor(None, self.generate_text, prompts, streamer)
on_prompt_tokens = True
async for decoded_token_batch in self.consume_streamer(streamer):
# The first batch of tokens contains the prompts, so we skip it.
if not on_prompt_tokens:
logger.info(f"Yielding decoded_token_batch: {decoded_token_batch}")
yield decoded_token_batch
else:
logger.info(f"Skipped prompts: {decoded_token_batch}")
on_prompt_tokens = False
def generate_text(self, prompts: str, streamer: RawStreamer):
input_ids = self.tokenizer(prompts, return_tensors="pt", padding=True).input_ids
self.model.generate(input_ids, streamer=streamer, max_length=10000)
async def consume_streamer(self, streamer: RawStreamer):
while True:
try:
for token_batch in streamer:
decoded_tokens = []
for token in token_batch:
decoded_tokens.append(
self.tokenizer.decode(token, skip_special_tokens=True)
)
logger.info(f"Yielding decoded tokens: {decoded_tokens}")
yield decoded_tokens
break
except Empty:
await asyncio.sleep(0.001)
Batchbot uses four methods to handle requests:

handle_request: the entrypoint method. This method simply takes in the request’s prompt and calls the run_model method on it. run_model is a generator method that also handles batching the requests. handle_request passes run_model into a Starlette StreamingResponse and returns the response, so the bot can stream generated tokens back to the client.

run_model: a generator method that performs batching. Since run_model is decorated with @serve.batch, it automatically takes in a batch of prompts. See the batching guide for more info. run_model creates a RawStreamer to access the generated tokens. It calls generate_text in a background thread, and passes in the prompts and the streamer, similar to the Textbot. Then it iterates through the consume_streamer generator, repeatedly yielding a batch of tokens generated by the model.

generate_text: the method that runs the model. It’s mostly the same as generate_text in Textbot, with two differences. First, it takes in and processes a batch of prompts instead of a single prompt. Second, it sets padding=True, so prompts with different lengths can be batched together.

consume_streamer: a generator method that consumes the streamer constructed by run_model. It’s mostly the same as consume_streamer in Textbot, with one difference. It uses the tokenizer to decode the generated tokens. Usually, the Hugging Face streamer handles the decoding. Because this implementation uses the custom RawStreamer, consume_streamer must handle the decoding.
Tip
Some inputs within a batch may generate fewer outputs than others. When a particular input has nothing left to yield, pass a StopIteration object into the output iterable to terminate that input’s request. See Streaming batched requests for more details.
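Here's a hedged sketch of that pattern (a hypothetical stream_batch handler with hard-coded outputs, not the tutorial's run_model):

from typing import List
from ray import serve

@serve.batch(max_batch_size=2, batch_wait_timeout_s=15)
async def stream_batch(prompts: List[str]):
    # Pretend the first prompt produced three tokens and the second only one.
    outputs = [["a", "b", "c"], ["x"]]
    longest = max(len(tokens) for tokens in outputs)
    for i in range(longest):
        # Yield one item per input; StopIteration marks a finished request.
        yield [
            tokens[i] if i < len(tokens) else StopIteration
            for tokens in outputs
        ]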
Bind the Batchbot to a language model. For this tutorial, use the "microsoft/DialoGPT-small" model:
app = Batchbot.bind("microsoft/DialoGPT-small")
Run the model with serve run batchbot:app. Query it from two other terminal windows with this script:
import requests
prompt = "Tell me a story about dogs."
response = requests.post(f"http://localhost:8000/?prompt={prompt}", stream=True)
response.raise_for_status()
for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
print(chunk, end="")
# Dogs are the best.
You should see the output printed token by token in both windows.