Serve a Chatbot with Request and Response Streaming

This example deploys a chatbot that streams output back to the user. It shows:

  • How to stream outputs from a Serve application

  • How to use WebSockets in a Serve application

  • How to combine batching requests with streaming outputs

This tutorial should help you with the following use cases:

  • You want to serve a large language model and stream results back token-by-token.

  • You want to serve a chatbot that accepts a stream of inputs from the user.

This tutorial serves the DialoGPT language model. Install the Hugging Face Transformers library to access it:

pip install transformers

Create a streaming deployment

Open a new Python file called textbot.py. First, add the imports and the Serve logger.

import asyncio
import logging
from queue import Empty

from fastapi import FastAPI
from starlette.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from ray import serve

logger = logging.getLogger("ray.serve")

Create a FastAPI deployment, and initialize the model and the tokenizer in the constructor:

fastapi_app = FastAPI()


@serve.deployment
@serve.ingress(fastapi_app)
class Textbot:
    def __init__(self, model_id: str):
        self.loop = asyncio.get_running_loop()

        self.model_id = model_id
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

Note that the constructor also caches the running asyncio event loop. The deployment uses this loop to run the model in a background thread while it concurrently streams the generated tokens back to the user.
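To see why the cached loop matters, here is a minimal standard-library sketch (no Serve or Hugging Face involved) of the asyncio pattern the Textbot relies on: loop.run_in_executor hands a blocking function to a thread pool and returns a future immediately, so the event loop keeps running other coroutines in the meantime. blocking_work is a hypothetical stand-in for model.generate.

```python
import asyncio
import time
from typing import Tuple


def blocking_work(duration: float) -> str:
    # Hypothetical stand-in for a blocking call such as model.generate().
    time.sleep(duration)
    return "done"


async def main() -> Tuple[int, str]:
    # Grab the running loop, as the Textbot constructor does.
    loop = asyncio.get_running_loop()
    # Schedule the blocking call on the loop's default thread pool (executor=None).
    future = loop.run_in_executor(None, blocking_work, 0.2)
    heartbeats = 0
    while not future.done():
        # The loop keeps running this coroutine while the worker thread sleeps.
        heartbeats += 1
        await asyncio.sleep(0.01)
    return heartbeats, await future


heartbeats, result = asyncio.run(main())
print(f"event loop ran {heartbeats} times while the worker was busy; result={result!r}")
```

If blocking_work were called directly inside main, the loop would be frozen for the whole 0.2 seconds and heartbeats would stay at zero.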

Add the following logic to handle requests sent to the Textbot:

    @fastapi_app.post("/")
    async def handle_request(self, prompt: str) -> StreamingResponse:
        logger.info(f'Got prompt: "{prompt}"')
        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=0, skip_prompt=True, skip_special_tokens=True
        )
        self.loop.run_in_executor(None, self.generate_text, prompt, streamer)
        return StreamingResponse(
            self.consume_streamer(streamer), media_type="text/plain"
        )

    def generate_text(self, prompt: str, streamer: TextIteratorStreamer):
        input_ids = self.tokenizer([prompt], return_tensors="pt").input_ids
        self.model.generate(input_ids, streamer=streamer, max_length=10000)

    async def consume_streamer(self, streamer: TextIteratorStreamer):
        while True:
            try:
                for token in streamer:
                    logger.info(f'Yielding token: "{token}"')
                    yield token
                break
            except Empty:
                # The streamer raises an Empty exception if the next token
                # hasn't been generated yet. `await` here to yield control
                # back to the event loop so other coroutines can run.
                await asyncio.sleep(0.001)

Textbot uses three methods to handle requests:

  • handle_request: the entrypoint for HTTP requests. FastAPI automatically unpacks the prompt query parameter and passes it into handle_request. The method then creates a TextIteratorStreamer, a Hugging Face interface for accessing the tokens that a language model generates. Next, handle_request starts the model in a background thread with self.loop.run_in_executor, so the model can generate tokens while the event loop streams them back to the user. Finally, handle_request passes the self.consume_streamer generator, which yields tokens from the streamer one by one, into a Starlette StreamingResponse and returns the response. Serve unpacks the StreamingResponse and sends the generator’s output back to the user as it’s produced.

  • generate_text: the method that runs the model. This method runs in a background thread kicked off by handle_request. It pushes generated tokens into the streamer constructed by handle_request.

  • consume_streamer: a generator method that consumes the streamer constructed by handle_request. This method keeps yielding tokens from the streamer until the model in generate_text closes the streamer. Because the streamer uses timeout=0, it raises an Empty exception whenever the next token isn’t ready yet. consume_streamer catches that exception and calls asyncio.sleep with a brief delay, which hands control back to the event loop instead of blocking it.
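The interplay between generate_text and consume_streamer can be simulated without a model. In this sketch, ToyTextStreamer is a hypothetical, queue-based stand-in that only mimics the parts of TextIteratorStreamer the Textbot relies on (put, end, and iteration that raises queue.Empty when timeout=0), and fake_generate plays the role of the model thread.

```python
import asyncio
import time
from queue import Empty, Queue
from typing import AsyncIterator, List, Optional


class ToyTextStreamer:
    """Hypothetical stand-in for TextIteratorStreamer: a queue with put/end/iteration."""

    _STOP = object()  # unique end-of-stream sentinel

    def __init__(self, timeout: Optional[float] = 0):
        self.queue: Queue = Queue()
        self.timeout = timeout

    def put(self, text: str) -> None:
        self.queue.put(text)

    def end(self) -> None:
        self.queue.put(self._STOP)

    def __iter__(self):
        return self

    def __next__(self) -> str:
        # With timeout=0, get() raises queue.Empty at once if nothing is ready.
        item = self.queue.get(timeout=self.timeout)
        if item is self._STOP:
            raise StopIteration
        return item


def fake_generate(streamer: ToyTextStreamer, tokens: List[str]) -> None:
    # Plays the role of generate_text: pushes tokens slowly, then closes the stream.
    for token in tokens:
        time.sleep(0.01)
        streamer.put(token)
    streamer.end()


async def consume_streamer(streamer: ToyTextStreamer) -> AsyncIterator[str]:
    while True:
        try:
            for token in streamer:
                yield token
            break  # StopIteration ended the for loop: generation finished
        except Empty:
            await asyncio.sleep(0.001)  # hand control back to the event loop


async def main() -> List[str]:
    loop = asyncio.get_running_loop()
    streamer = ToyTextStreamer(timeout=0)
    loop.run_in_executor(None, fake_generate, streamer, ["Dogs", " are", " the", " best."])
    return [token async for token in consume_streamer(streamer)]


tokens = asyncio.run(main())
print("".join(tokens))  # Dogs are the best.
```

Because ToyTextStreamer is an assumption rather than the real Hugging Face class, treat it as an illustration of the control flow, not of the library’s internals.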

Bind the Textbot to a language model. For this tutorial, use the "microsoft/DialoGPT-small" model:

app = Textbot.bind("microsoft/DialoGPT-small")

Run the model with serve run textbot:app, and query it from another terminal window with this script:

import requests

prompt = "Tell me a story about dogs."

# Pass the prompt through `params` so requests URL-encodes it.
response = requests.post(
    "http://localhost:8000/", params={"prompt": prompt}, stream=True
)
response.raise_for_status()
for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="")

# Dogs are the best.

You should see the output printed token by token.
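If you’d rather not depend on requests, the client side can be sketched with the standard library. Two details are worth handling explicitly: URL-encoding the prompt, and decoding UTF-8 incrementally, because a multi-byte character can straddle two network chunks. build_url and iter_text_chunks are hypothetical helper names, and the demo reads an in-memory io.BytesIO instead of a live response, so it runs without a server.

```python
import codecs
import io
from typing import BinaryIO, Iterator
from urllib.parse import urlencode


def build_url(prompt: str, base: str = "http://localhost:8000/") -> str:
    # urlencode escapes spaces and punctuation so the prompt survives as a query parameter.
    return base + "?" + urlencode({"prompt": prompt})


def iter_text_chunks(raw: BinaryIO, chunk_size: int = 64) -> Iterator[str]:
    # An incremental decoder buffers partial multi-byte characters between reads.
    decoder = codecs.getincrementaldecoder("utf-8")()
    while True:
        data = raw.read(chunk_size)
        if not data:
            break
        text = decoder.decode(data)
        if text:  # empty when the chunk ended mid-character
            yield text
    tail = decoder.decode(b"", final=True)  # raises if the stream ended mid-character
    if tail:
        yield tail


url = build_url("Tell me a story about dogs.")
print(url)  # http://localhost:8000/?prompt=Tell+me+a+story+about+dogs.

# Offline demo: an in-memory body read one byte at a time.
sample = "Dogs are the best. 🐶"
for chunk in iter_text_chunks(io.BytesIO(sample.encode("utf-8")), chunk_size=1):
    print(chunk, end="")
print()
```

Against a running server, you could pass the response object from urllib.request.urlopen(urllib.request.Request(build_url(prompt), method="POST")) to iter_text_chunks, since that object also provides read(n).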

Stream inputs and outputs using WebSockets

WebSockets let you stream input into the application and stream output back to the client. Use WebSockets to create a chatbot that stores a conversation with a user.

Create a Python file called chatbot.py. First, add the imports:

import asyncio
import logging
from queue import Empty

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from ray import serve

logger = logging.getLogger("ray.serve")

Create a FastAPI deployment, and initialize the model and the tokenizer in the constructor:

fastapi_app = FastAPI()


@serve.deployment
@serve.ingress(fastapi_app)
class Chatbot:
    def __init__(self, model_id: str):
        self.loop = asyncio.get_running_loop()

        self.model_id = model_id
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

Add the following logic to handle requests sent to the Chatbot:

    @fastapi_app.websocket("/")
    async def handle_request(self, ws: WebSocket) -> None:
        await ws.accept()

        conversation = ""
        try:
            while True:
                prompt = await ws.receive_text()
                logger.info(f'Got prompt: "{prompt}"')
                conversation += prompt
                streamer = TextIteratorStreamer(
                    self.tokenizer,
                    timeout=0,
                    skip_prompt=True,
                    skip_special_tokens=True,
                )
                self.loop.run_in_executor(
                    None, self.generate_text, conversation, streamer
                )
                response = ""
                async for text in self.consume_streamer(streamer):
                    await ws.send_text(text)
                    response += text
                await ws.send_text("<<Response Finished>>")
                conversation += response
        except WebSocketDisconnect:
            logger.info("Client disconnected.")

    def generate_text(self, prompt: str, streamer: TextIteratorStreamer):
        input_ids = self.tokenizer([prompt], return_tensors="pt").input_ids
        self.model.generate(input_ids, streamer=streamer, max_length=10000)

    async def consume_streamer(self, streamer: TextIteratorStreamer):
        while True:
            try:
                for token in streamer:
                    logger.info(f'Yielding token: "{token}"')
                    yield token
                break
            except Empty:
                await asyncio.sleep(0.001)


The generate_text and consume_streamer methods are the same as they were for the Textbot. The handle_request method has been updated to handle WebSocket requests.

The fastapi_app.websocket decorator lets handle_request accept WebSocket connections. First, handle_request awaits ws.accept() to accept the client’s connection. Then, until the client disconnects, it repeats the following steps:

  • gets the prompt from the client with ws.receive_text

  • starts a new TextIteratorStreamer to access generated tokens

  • runs the model in a background thread on the conversation so far

  • streams the model’s output back using ws.send_text

  • stores the prompt and the response in the conversation string

Each time handle_request gets a new prompt from a client, it runs the whole conversation, with the new prompt appended, through the model. When the model finishes generating tokens, handle_request sends the "<<Response Finished>>" string to tell the client that the response is complete. handle_request keeps running until the client disconnects, which raises a WebSocketDisconnect exception and ends the call.
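The per-connection protocol described above (accumulate the conversation, stream the reply, then send an end marker) can be exercised offline. This sketch is a simulation under stated assumptions: FakeWebSocket, built from two asyncio queues, stands in for FastAPI’s WebSocket and implements only receive_text and send_text, and fake_model is a hypothetical model that streams back a word count instead of real text. The loop runs a fixed number of turns rather than waiting for WebSocketDisconnect.

```python
import asyncio
from typing import AsyncIterator

END_OF_RESPONSE = "<<Response Finished>>"


class FakeWebSocket:
    """Hypothetical stand-in for a WebSocket: one asyncio queue per direction."""

    def __init__(self) -> None:
        self.incoming: asyncio.Queue = asyncio.Queue()
        self.outgoing: asyncio.Queue = asyncio.Queue()

    async def receive_text(self) -> str:
        return await self.incoming.get()

    async def send_text(self, text: str) -> None:
        await self.outgoing.put(text)


async def fake_model(conversation: str) -> AsyncIterator[str]:
    # Hypothetical model: streams back the word count of the conversation so far.
    for token in [" (", str(len(conversation.split())), " words)"]:
        await asyncio.sleep(0)
        yield token


async def handle_turns(ws: FakeWebSocket, turns: int) -> str:
    conversation = ""
    for _ in range(turns):  # the real handler loops until WebSocketDisconnect
        prompt = await ws.receive_text()
        conversation += prompt  # the model always sees the whole history
        response = ""
        async for token in fake_model(conversation):
            await ws.send_text(token)
            response += token
        await ws.send_text(END_OF_RESPONSE)  # tell the client this turn is done
        conversation += response
    return conversation


async def main() -> str:
    ws = FakeWebSocket()
    await ws.incoming.put("Space the final")
    await ws.incoming.put(" These are the voyages")
    return await handle_turns(ws, turns=2)


conversation = asyncio.run(main())
print(conversation)  # Space the final (3 words) These are the voyages (9 words)
```

Note how the second reply counts nine words: the model input includes the first prompt, the first reply, and the new prompt.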

Read more about WebSockets in the FastAPI documentation.

Bind the Chatbot to a language model. For this tutorial, use the "microsoft/DialoGPT-small" model:

app = Chatbot.bind("microsoft/DialoGPT-small")

Run the model with serve run chatbot:app. Query it with the websockets package, which you can install with pip install websockets:

from websockets.sync.client import connect

with connect("ws://localhost:8000") as websocket:
    websocket.send("Space the final")
    while True:
        received = websocket.recv()
        if received == "<<Response Finished>>":
            break
        print(received, end="")
    print("\n")

    websocket.send(" These are the voyages")
    while True:
        received = websocket.recv()
        if received == "<<Response Finished>>":
            break
        print(received, end="")
    print("\n")

You should see the outputs printed token by token.
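The client script repeats the same receive loop for every prompt. A small helper can factor it out. read_response is a hypothetical name; it accepts any zero-argument callable that returns the next message, so you can pass websocket.recv from a websockets.sync.client connection or, as in the offline demo here, a scripted iterator.

```python
from typing import Callable, List

END_OF_RESPONSE = "<<Response Finished>>"


def read_response(recv: Callable[[], str]) -> str:
    """Collect streamed chunks until the server's end-of-response marker."""
    chunks: List[str] = []
    while True:
        received = recv()
        if received == END_OF_RESPONSE:
            return "".join(chunks)
        print(received, end="", flush=True)  # show tokens as they arrive
        chunks.append(received)


# Offline demo: a scripted iterator stands in for websocket.recv.
messages = iter([" is", " the", " frontier", END_OF_RESPONSE])
reply = read_response(lambda: next(messages))
print()
```

With a live connection, each exchange becomes websocket.send(...) followed by read_response(websocket.recv).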

Batch requests and stream the output for each

Improve model utilization and request latency by batching requests together when running the model.

Create a Python file called batchbot.py. First, add the imports:

import asyncio
import logging
from queue import Empty, Queue
from typing import List

from fastapi import FastAPI
from starlette.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer

from ray import serve

logger = logging.getLogger("ray.serve")

Warning

Hugging Face’s support for Streamers is still under development and may change in the future. RawStreamer is compatible with the Streamers interface in Hugging Face Transformers 4.30.2. However, the interface may change, which could make RawStreamer incompatible with Hugging Face models in the future.

Like Textbot and Chatbot, Batchbot needs a streamer to stream outputs, but Hugging Face Streamers don’t support batched requests. Add this custom RawStreamer to process batches of tokens:

class RawStreamer:
    def __init__(self, timeout: float = None):
        self.q = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def put(self, values):
        self.q.put(values)

    def end(self):
        self.q.put(self.stop_signal)

    def __iter__(self):
        return self

    def __next__(self):
        result = self.q.get(timeout=self.timeout)
        # Compare by identity: the queue holds tensors, and the sentinel is None.
        if result is self.stop_signal:
            raise StopIteration()
        else:
            return result


Create a FastAPI deployment, and initialize the model and the tokenizer in the constructor:

fastapi_app = FastAPI()


@serve.deployment
@serve.ingress(fastapi_app)
class Batchbot:
    def __init__(self, model_id: str):
        self.loop = asyncio.get_running_loop()

        self.model_id = model_id
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
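        self.tokenizer.pad_token = self.tokenizer.eos_token
        # Decoder-only models continue from the end of each row, so pad on the left.
        self.tokenizer.padding_side = "left"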

Unlike Textbot and Chatbot, the Batchbot constructor also sets a pad_token. The tokenizer needs a pad token to batch prompts of different lengths, because it pads the shorter prompts up to the length of the longest one.
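To make the padding step concrete, here is a hand-rolled sketch of what padding does to a batch of token IDs. pad_batch is a hypothetical helper, not the tokenizer’s API, and the prompt token IDs are made up. 50256 is the GPT-2 tokenizer’s <|endoftext|> ID, which DialoGPT reuses, so it doubles as the pad token once pad_token is set to eos_token. The attention mask marks real tokens with 1 and padding with 0.

```python
from typing import List, Tuple


def pad_batch(
    batch: List[List[int]], pad_id: int, side: str = "left"
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-ID lists to a common length; return (input_ids, attention_mask)."""
    width = max(len(seq) for seq in batch)
    ids, mask = [], []
    for seq in batch:
        fill = width - len(seq)
        if side == "left":
            ids.append([pad_id] * fill + seq)
            mask.append([0] * fill + [1] * len(seq))
        else:
            ids.append(seq + [pad_id] * fill)
            mask.append([1] * len(seq) + [0] * fill)
    return ids, mask


EOS = 50256  # GPT-2 <|endoftext|>, reused as the pad ID
ids, mask = pad_batch([[15, 16, 17], [42]], pad_id=EOS)  # made-up prompt IDs
print(ids)   # [[15, 16, 17], [50256, 50256, 42]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

The sketch pads on the left by default. For decoder-only models such as DialoGPT, left padding keeps each prompt’s last real token at the end of its row, which is where generation continues; with right padding, a shorter prompt would be followed by pad tokens before any new text.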

Add the following logic to handle requests sent to the Batchbot:

    @fastapi_app.post("/")
    async def handle_request(self, prompt: str) -> StreamingResponse:
        logger.info(f'Got prompt: "{prompt}"')
        return StreamingResponse(self.run_model(prompt), media_type="text/plain")

    @serve.batch(max_batch_size=2, batch_wait_timeout_s=15)
    async def run_model(self, prompts: List[str]):
        # timeout=0 makes an empty queue raise Empty instead of blocking the event loop.
        streamer = RawStreamer(timeout=0)
        self.loop.run_in_executor(None, self.generate_text, prompts, streamer)
        on_prompt_tokens = True
        async for decoded_token_batch in self.consume_streamer(streamer):
            # The first batch of tokens contains the prompts, so we skip it.
            if not on_prompt_tokens:
                logger.info(f"Yielding decoded_token_batch: {decoded_token_batch}")
                yield decoded_token_batch
            else:
                logger.info(f"Skipped prompts: {decoded_token_batch}")
                on_prompt_tokens = False

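    def generate_text(self, prompts: List[str], streamer: RawStreamer):
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True)
        # Pass the attention mask so the model ignores the padding tokens.
        self.model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            streamer=streamer,
            max_length=10000,
        )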

    async def consume_streamer(self, streamer: RawStreamer):
        while True:
            try:
                for token_batch in streamer:
                    decoded_tokens = []
                    for token in token_batch:
                        decoded_tokens.append(
                            self.tokenizer.decode(token, skip_special_tokens=True)
                        )
                    logger.info(f"Yielding decoded tokens: {decoded_tokens}")
                    yield decoded_tokens
                break
            except Empty:
                await asyncio.sleep(0.001)


Batchbot uses four methods to handle requests:

  • handle_request: the entrypoint method. This method simply takes in the request’s prompt and calls the run_model method on it. run_model is a generator method that also handles batching the requests. handle_request passes run_model into a Starlette StreamingResponse and returns the response, so the bot can stream generated tokens back to the client.

  • run_model: a generator method that performs batching. Because run_model is decorated with @serve.batch, it automatically receives a batch of prompts. See the batching guide for more info. run_model creates a RawStreamer to access the generated tokens. It calls generate_text in a background thread and passes in the prompts and the streamer, similar to the Textbot. Then it iterates through the consume_streamer generator, skips the first batch (the prompts themselves), and yields each subsequent batch of decoded tokens, with one entry per request.

  • generate_text: the method that runs the model. It’s mostly the same as generate_text in Textbot, with two differences. First, it takes in and processes a batch of prompts instead of a single prompt. Second, it sets padding=True, so prompts with different lengths can be batched together.

  • consume_streamer: a generator method that consumes the streamer constructed by run_model. It’s mostly the same as consume_streamer in Textbot, with one difference: it uses the tokenizer to decode the generated tokens. Usually, the Hugging Face streamer handles the decoding, but because this implementation uses the custom RawStreamer, consume_streamer must decode the tokens itself.
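The batched data flow through run_model and consume_streamer can be traced with plain lists. This sketch assumes the pattern the Batchbot code relies on: the model first pushes the prompt batch, then one token ID per request at each step. VOCAB, decode, and PUSHED are hypothetical stand-ins for the tokenizer and the model, with ID 0 playing the pad/EOS token that decodes to an empty string. fan_out roughly imitates how Serve routes each yielded list, sending element i to request i.

```python
import asyncio
from typing import AsyncIterator, Dict, List

# Hypothetical vocabulary standing in for tokenizer.decode(..., skip_special_tokens=True).
# ID 0 plays the pad/EOS token, which decodes to an empty string.
VOCAB: Dict[int, str] = {0: "", 1: "Dogs", 2: " are", 3: " great", 4: "Cats", 5: " too"}

# Hypothetical values a batched generate() pushes for two prompts:
# first the padded prompt IDs, then one token ID per prompt at each step.
PUSHED = [
    [[7, 8], [0, 9]],  # prompt batch
    [1, 4],
    [2, 5],
    [3, 0],  # the second prompt has finished, so it only gets padding
]


def decode(token_id: int) -> str:
    return VOCAB[token_id]


async def run_model_sketch(pushed: list) -> AsyncIterator[List[str]]:
    on_prompt_tokens = True
    for token_batch in pushed:
        await asyncio.sleep(0)  # pretend to wait for the model
        if on_prompt_tokens:
            # Skip the prompt batch (this sketch skips it before decoding).
            on_prompt_tokens = False
            continue
        yield [decode(token) for token in token_batch]  # one entry per request


async def fan_out(num_requests: int) -> List[str]:
    # Route element i of every yielded list to request i.
    texts = [""] * num_requests
    async for step in run_model_sketch(PUSHED):
        for i, piece in enumerate(step):
            texts[i] += piece
    return texts


texts = asyncio.run(fan_out(2))
print(texts)  # ['Dogs are great', 'Cats too']
```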

Tip

Some inputs within a batch may generate fewer outputs than others. When a particular input has nothing left to yield, put a StopIteration object in that input’s position in the output list to terminate its request. See Streaming batched requests for more details.
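Here is a standard-library sketch of the tip’s idea. batched_stream is a hypothetical async generator shaped like a @serve.batch streaming method: each yielded list holds one entry per request, and a request with nothing left gets StopIteration in its slot. This assumes, as the tip describes, that Serve ends a request’s stream when it sees StopIteration there. The sketch uses the StopIteration class itself, so check the Streaming batched requests guide for the exact convention your Ray version expects.

```python
import asyncio
from typing import AsyncIterator, List, Union

Output = Union[str, type]  # either a text chunk or the StopIteration class


async def batched_stream(per_request: List[List[str]]) -> AsyncIterator[List[Output]]:
    """Yield one list per step; finished requests get StopIteration instead of a chunk."""
    for step in range(max(len(chunks) for chunks in per_request)):
        await asyncio.sleep(0)  # pretend to wait for the model
        yield [
            chunks[step] if step < len(chunks) else StopIteration
            for chunks in per_request
        ]


async def collect(per_request: List[List[str]]) -> List[List[Output]]:
    return [step async for step in batched_stream(per_request)]


steps = asyncio.run(collect([["Dogs", " are", " great"], ["Cats", " too"]]))
for step in steps:
    print(step)
```

Compared with yielding empty strings for finished requests, this lets the shorter request’s client see its stream close as soon as its output is done.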

Bind the Batchbot to a language model. For this tutorial, use the "microsoft/DialoGPT-small" model:

app = Batchbot.bind("microsoft/DialoGPT-small")

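Run the model with serve run batchbot:app. Query it from two other terminal windows with this script, starting both queries within 15 seconds of each other (the batch_wait_timeout_s value) so that Serve batches them together: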

import requests

prompt = "Tell me a story about dogs."

# Pass the prompt through `params` so requests URL-encodes it.
response = requests.post(
    "http://localhost:8000/", params={"prompt": prompt}, stream=True
)
response.raise_for_status()
for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="")

# Dogs are the best.

You should see the output printed token by token in both windows.