Online serving#

 


This tutorial launches an online service and shows you how to:

  • deploy model artifacts to generate predictions

  • deploy custom logic, for example, fetching similar images and filtering by the top k predictions

  • autoscale based on incoming traffic

  • observe and debug the service

Ray Serve is a highly scalable and flexible model serving library for building online inference APIs. It allows you to:

  • Wrap models and business logic as separate serve deployments and connect them together (pipeline, ensemble, etc.)

  • Avoid a single large service that’s network- and compute-bound and an inefficient use of resources.

  • Utilize fractional and heterogeneous resources, which isn’t possible with SageMaker, Vertex, KServe, etc., and horizontally scale with num_replicas.

  • Autoscale up and down based on traffic.

  • Integrate with FastAPI and HTTP.

  • Set up a gRPC service to build distributed systems and microservices.

  • Enable dynamic batching based on batch size, time, etc. See the sketch after this list.

  • Access a suite of utilities for serving LLMs that’s inference-engine agnostic and has batteries-included support for LLM-specific features such as multi-LoRA.
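
As a quick illustration of the dynamic batching item above, here is a minimal sketch of how serve.batch groups individual requests into a single batched call. The deployment below is hypothetical and not part of this tutorial's doggos app:

from ray import serve
from starlette.requests import Request

@serve.deployment
class BatchedEcho:
    # serve.batch accumulates calls until it has max_batch_size items or
    # batch_wait_timeout_s elapses, then invokes the method once with the batch.
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def handle_batch(self, urls: list[str]) -> list[dict]:
        # Return one result per input, in the same order.
        return [{"url": url} for url in urls]

    async def __call__(self, request: Request):
        data = await request.json()
        # Each caller passes a single item; Serve batches them behind the scenes.
        return await self.handle_batch(data["url"])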

https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/ray_serve.png
%%bash
pip install -q "matplotlib==3.10.0" "torch==2.7.0" "transformers==4.52.3" "scikit-learn==1.6.0" "mlflow==2.19.0" "ipywidgets==8.1.3"
%load_ext autoreload
%autoreload all
import os
import ray
import sys
sys.path.append(os.path.abspath(".."))
ray.init(runtime_env={"working_dir": "../"})
2025-05-29 17:39:48,311	INFO worker.py:1660 -- Connecting to existing Ray cluster at address: 10.0.56.137:6379...
2025-05-29 17:39:48,322	INFO worker.py:1843 -- Connected to Ray cluster. View the dashboard at https://session-gcwehd9xxjzkv5lxv8lgcdgx2n.i.anyscaleuserdata.com 
2025-05-29 17:39:48,414	INFO packaging.py:575 -- Creating a file package for local module '../'.
2025-05-29 17:39:48,476	WARNING packaging.py:417 -- File /home/ray/default/foundational-ray-app/notebooks/../.git/objects/pack/pack-b8b7f3cf34764341ace726e9197e18f11b5aaedc.pack is very large (15.84MiB). Consider adding this file to the 'excludes' list to skip uploading it: `ray.init(..., runtime_env={'excludes': ['/home/ray/default/foundational-ray-app/notebooks/../.git/objects/pack/pack-b8b7f3cf34764341ace726e9197e18f11b5aaedc.pack']})`
2025-05-29 17:39:48,540	INFO packaging.py:367 -- Pushing file package 'gcs://_ray_pkg_f145abd4c54ca9ee.zip' (29.53MiB) to Ray cluster...
2025-05-29 17:39:48,670	INFO packaging.py:380 -- Successfully pushed file package 'gcs://_ray_pkg_f145abd4c54ca9ee.zip'.
import os
from fastapi import FastAPI
import mlflow
import requests
from starlette.requests import Request
from urllib.parse import urlparse
from ray import serve
from doggos.embed import EmbeddingGenerator, get_top_matches
from doggos.infer import TorchPredictor
from doggos.utils import url_to_array

Deployments#

First create a deployment for the trained model that generates a probability distribution for a given image URL. You can specify the compute you want to use with ray_actor_options, and how you want to horizontally scale this specific deployment with num_replicas.

@serve.deployment(
    num_replicas=1,
    ray_actor_options={
        "num_cpus": 2,
        "num_gpus": 1,
        "accelerator_type": "L4",
    },
)
class ClassPredictor:
    def __init__(self, artifacts_dir):
        """Initialize the model."""
        self.predictor = TorchPredictor.from_artifacts_dir(artifacts_dir=artifacts_dir)
        self.preprocessor = self.predictor.preprocessor

    def get_probabilities(self, url):
        image = url_to_array(url=url)
        ds = ray.data.from_items([{"image": image, "url": url}])
        ds = self.preprocessor.transform(
            ds=ds,
            concurrency=1,
            batch_size=1,
            num_gpus=1,
        )
        ds = ds.map_batches(
            self.predictor.predict_probabilities,
            fn_kwargs={"device": "cuda"},
            concurrency=1,
            batch_size=1,
            num_gpus=1,
        )
        probabilities = ds.take_all()[0]["probabilities"]
        return probabilities

Do the same for another deployment that takes the input image, embeds it, and then retrieves the top similar matches. It uses the top k predictions from the model as a filter, class_filters, when retrieving the similar images.

@serve.deployment(
    num_replicas=1,
    ray_actor_options={
        "num_cpus": 2,
        "num_gpus": 1,
        "accelerator_type": "L4",
    },
)
class EmbeddingSimilarity:
    def __init__(self, embeddings_path):
        self.embedding_generator = EmbeddingGenerator(model_id="openai/clip-vit-base-patch32")
        self.embeddings_ds = ray.data.read_parquet(embeddings_path)  # use vector DB

    def get_top_matches(self, url, probabilities, k):
        # Top k class predictions
        sorted_probabilities = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
        top_k = [item[0] for item in sorted_probabilities[0:k]]

        # Generate embedding.
        image = url_to_array(url=url)
        embedding = self.embedding_generator({"image": [image]})["embedding"][0]

        # Filter for top matches.
        top_matches = get_top_matches(
            query_embedding=embedding,
            embeddings_ds=self.embeddings_ds,
            class_filters=top_k,
            n=5,
        )
        return top_matches

You can combine all the service deployments into one application. Here the setup is a simple pipeline, where you pass the model’s predictions to the embedding similarity service deployment.

🧱 Model composition

Ray Serve makes it easy to do model composition, where you compose multiple deployments containing ML models or business logic into a single application. You can independently scale and configure each of your deployments, even with fractional resources.

https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/serve_composition.png

Application#

# Define app.
api = FastAPI(
    title="doggos", 
    description="classify your dog", 
    version="0.1",
)
@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 2})
@serve.ingress(api)
class Doggos:
    def __init__(self, classifier, embedder):
        self.classifier = classifier
        self.embedder = embedder
        
    @api.post("/predict/")
    async def predict(self, request: Request):
        data = await request.json()
        probabilities = await self.classifier.get_probabilities.remote(url=data["url"])
        top_matches = await self.embedder.get_top_matches.remote(
            url=data["url"],
            probabilities=probabilities,
            k=data["k"],
        )
        return {
            "probabilities": probabilities, 
            "top_matches": top_matches,
        }
# Model registry.
model_registry = "/mnt/cluster_storage/mlflow/doggos"
experiment_name = "doggos"
mlflow.set_tracking_uri(f"file:{model_registry}")
embeddings_path = os.path.join("/mnt/cluster_storage", "doggos/embeddings")
# Get best_run's artifact_dir.
mlflow.set_tracking_uri(f"file:{model_registry}")
sorted_runs = mlflow.search_runs(
    experiment_names=[experiment_name], 
    order_by=["metrics.val_loss ASC"])
best_run = sorted_runs.iloc[0]
artifacts_dir = urlparse(best_run.artifact_uri).path
# Define app.
app = Doggos.bind(
    classifier=ClassPredictor.bind(artifacts_dir=artifacts_dir),
    embedder=EmbeddingSimilarity.bind(embeddings_path=embeddings_path),
)
# Run service locally.
serve.run(app, route_prefix="/")
(ProxyActor pid=113956) INFO 2025-05-29 17:40:02,366 proxy 10.0.56.137 -- Proxy starting on node 433b62a87a85338338747478dd7513ad3aeaf6bc14e890709b747eec (HTTP port: 8000).
INFO 2025-05-29 17:40:02,478 serve 113749 -- Started Serve in namespace "serve".
(ServeController pid=113895) INFO 2025-05-29 17:40:02,504 controller 113895 -- Deploying new version of Deployment(name='ClassPredictor', app='default') (initial target replicas: 1).
(ServeController pid=113895) INFO 2025-05-29 17:40:02,505 controller 113895 -- Deploying new version of Deployment(name='EmbeddingSimilarity', app='default') (initial target replicas: 1).
(ServeController pid=113895) INFO 2025-05-29 17:40:02,506 controller 113895 -- Deploying new version of Deployment(name='Doggos', app='default') (initial target replicas: 1).
(ProxyActor pid=113956) INFO 2025-05-29 17:40:02,441 proxy 10.0.56.137 -- Got updated endpoints: {}.
(ProxyActor pid=113956) INFO 2025-05-29 17:40:02,509 proxy 10.0.56.137 -- Got updated endpoints: {Deployment(name='Doggos', app='default'): EndpointInfo(route='/', app_is_cross_language=False)}.
(ServeController pid=113895) INFO 2025-05-29 17:40:02,611 controller 113895 -- Adding 1 replica to Deployment(name='ClassPredictor', app='default').
(ServeController pid=113895) INFO 2025-05-29 17:40:02,613 controller 113895 -- Adding 1 replica to Deployment(name='EmbeddingSimilarity', app='default').
(ServeController pid=113895) INFO 2025-05-29 17:40:02,615 controller 113895 -- Adding 1 replica to Deployment(name='Doggos', app='default').
(ProxyActor pid=113956) INFO 2025-05-29 17:40:02,533 proxy 10.0.56.137 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x78c90ee61e50>.
(ServeReplica:default:EmbeddingSimilarity pid=33586, ip=10.0.153.142) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
(ProxyActor pid=27365, ip=10.0.188.182) INFO 2025-05-29 17:40:06,609 proxy 10.0.188.182 -- Proxy starting on node 83642d8a20226a0114da9be1953fe119a133bf448b1a4c2abfed80c3 (HTTP port: 8000).
(ProxyActor pid=27365, ip=10.0.188.182) INFO 2025-05-29 17:40:06,670 proxy 10.0.188.182 -- Got updated endpoints: {Deployment(name='Doggos', app='default'): EndpointInfo(route='/', app_is_cross_language=False)}.
(ProxyActor pid=27365, ip=10.0.188.182) INFO 2025-05-29 17:40:06,692 proxy 10.0.188.182 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x75f4c8495d60>.
(ProxyActor pid=33803, ip=10.0.153.142) INFO 2025-05-29 17:40:09,665 proxy 10.0.153.142 -- Proxy starting on node 8bcf6da3780835acf76bf498a98be8b9617ee718765d73f10de0f4a3 (HTTP port: 8000).
(ProxyActor pid=33803, ip=10.0.153.142) INFO 2025-05-29 17:40:09,719 proxy 10.0.153.142 -- Got updated endpoints: {Deployment(name='Doggos', app='default'): EndpointInfo(route='/', app_is_cross_language=False)}.
(ProxyActor pid=33803, ip=10.0.153.142) INFO 2025-05-29 17:40:09,741 proxy 10.0.153.142 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x7eeedcf377a0>.
INFO 2025-05-29 17:40:10,613 serve 113749 -- Application 'default' is ready at http://127.0.0.1:8000/.
INFO 2025-05-29 17:40:10,622 serve 113749 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x7912fe07b950>.
DeploymentHandle(deployment='Doggos')
# Send a request.
url = "https://doggos-dataset.s3.us-west-2.amazonaws.com/samara.png"
data = {"url": url, "k": 4}
response = requests.post("http://127.0.0.1:8000/predict/", json=data)
response.json()
(ServeReplica:default:Doggos pid=27296, ip=10.0.188.182) INFO 2025-05-29 17:40:10,757 default_Doggos 0964vqxx 5759bf7a-7345-415c-8005-f8ef6e5a6041 -- Started <ray.serve._private.router.SharedRouterLongPollClient object at 0x7c5a0afc5ca0>.
(ServeReplica:default:ClassPredictor pid=33585, ip=10.0.153.142) /home/ray/anaconda3/lib/python3.12/site-packages/ray/serve/_private/replica.py:1320: UserWarning: Calling sync method 'get_probabilities' directly on the asyncio loop. In a future version, sync methods will be run in a threadpool by default. Ensure your sync methods are thread safe or keep the existing behavior by making them `async def`. Opt into the new behavior by setting RAY_SERVE_RUN_SYNC_IN_THREADPOOL=1.
(ServeReplica:default:ClassPredictor pid=33585, ip=10.0.153.142)   warnings.warn(
(ServeReplica:default:ClassPredictor pid=33585, ip=10.0.153.142) Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-05-29_17-07-53_816345_69024/logs/ray-data
(ServeReplica:default:ClassPredictor pid=33585, ip=10.0.153.142) Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[Map(Preprocessor.convert_to_label)] -> ActorPoolMapOperator[MapBatches(EmbeddingGenerator)] -> TaskPoolMapOperator[MapBatches(drop_columns)] -> TaskPoolMapOperator[MapBatches(TorchPredictor.predict_probabilities)]
(_MapWorker pid=33905, ip=10.0.153.142) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
(ServeReplica:default:ClassPredictor pid=33585, ip=10.0.153.142) INFO 2025-05-29 17:40:31,871 default_ClassPredictor orvh66v5 5759bf7a-7345-415c-8005-f8ef6e5a6041 -- CALL /predict/ OK 21100.2ms
(ServeReplica:default:EmbeddingSimilarity pid=33586, ip=10.0.153.142) /home/ray/anaconda3/lib/python3.12/site-packages/ray/serve/_private/replica.py:1320: UserWarning: Calling sync method 'get_top_matches' directly on the asyncio loop. In a future version, sync methods will be run in a threadpool by default. Ensure your sync methods are thread safe or keep the existing behavior by making them `async def`. Opt into the new behavior by setting RAY_SERVE_RUN_SYNC_IN_THREADPOOL=1.
(ServeReplica:default:EmbeddingSimilarity pid=33586, ip=10.0.153.142)   warnings.warn(
(ServeReplica:default:EmbeddingSimilarity pid=33586, ip=10.0.153.142) /home/ray/anaconda3/lib/python3.12/site-packages/ray/data/dataset.py:1331: UserWarning: Use 'expr' instead of 'fn' when possible for performant filters.
(ServeReplica:default:EmbeddingSimilarity pid=33586, ip=10.0.153.142)   warnings.warn(
(ServeReplica:default:EmbeddingSimilarity pid=33586, ip=10.0.153.142) Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
(ServeReplica:default:EmbeddingSimilarity pid=33586, ip=10.0.153.142) Starting execution of Dataset. Full logs are in /tmp/ray/session_2025-05-29_17-07-53_816345_69024/logs/ray-data
(ServeReplica:default:EmbeddingSimilarity pid=33586, ip=10.0.153.142) Execution plan of Dataset: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Filter(<lambda>)] -> TaskPoolMapOperator[MapBatches(compute_similarities)] -> AllToAllOperator[Sort] -> LimitOperator[limit=5]
{'probabilities': {'basset': 0.002643201034516096,
  'beagle': 0.018345436081290245,
  'bernese_mountain_dog': 0.15381267666816711,
  'bloodhound': 0.00039815512718632817,
  'border_collie': 0.09815514087677002,
  'boxer': 0.0015894499374553561,
  'bull_mastiff': 0.0002529473276808858,
  'chihuahua': 0.01949913240969181,
  'chow': 0.0006135992589406669,
  'cocker_spaniel': 0.004434243775904179,
  'collie': 0.2384808510541916,
  'dingo': 0.030564747750759125,
  'doberman': 0.001398741384036839,
  'eskimo_dog': 0.05846692994236946,
  'french_bulldog': 0.012217871844768524,
  'german_shepherd': 0.14378124475479126,
  'golden_retriever': 0.006297916639596224,
  'great_dane': 0.00164966587908566,
  'great_pyrenees': 0.00880351010710001,
  'italian_greyhound': 0.00497599458321929,
  'labrador_retriever': 0.001772725023329258,
  'malamute': 0.08132066577672958,
  'malinois': 0.0032257658895105124,
  'maltese_dog': 0.007002878934144974,
  'miniature_poodle': 0.002071917988359928,
  'miniature_schnauzer': 0.0036232389975339174,
  'pomeranian': 0.003027325263246894,
  'pug': 0.0008073864737525582,
  'rottweiler': 0.0028987338300794363,
  'saint_bernard': 0.006798494607210159,
  'samoyed': 0.002514089923352003,
  'shih_tzu': 0.01896044984459877,
  'siberian_husky': 0.016913356259465218,
  'tibetan_mastiff': 0.003269634675234556,
  'toy_poodle': 0.0006831369246356189,
  'yorkshire_terrier': 0.0387287363409996},
 'top_matches': [{'class': 'border_collie',
   'path': 'doggos-dataset/train/border_collie/border_collie_4107.jpg',
   'similarity': 0.817649365682867},
  {'class': 'border_collie',
   'path': 'doggos-dataset/train/border_collie/border_collie_7074.jpg',
   'similarity': 0.8066087255594375},
  {'class': 'collie',
   'path': 'doggos-dataset/train/collie/collie_7572.jpg',
   'similarity': 0.8050143036629519},
  {'class': 'border_collie',
   'path': 'doggos-dataset/train/border_collie/border_collie_90.jpg',
   'similarity': 0.8028407534978105},
  {'class': 'collie',
   'path': 'doggos-dataset/train/collie/collie_18685.jpg',
   'similarity': 0.8025033848054364}]}
# Top class predictions.
probabilities = response.json()["probabilities"]
sorted_probabilities = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
sorted_probabilities[0:3]
(ServeReplica:default:Doggos pid=27296, ip=10.0.188.182) INFO 2025-05-29 17:40:39,981 default_Doggos 0964vqxx 5759bf7a-7345-415c-8005-f8ef6e5a6041 -- POST /predict/ 200 29246.8ms
(ServeReplica:default:EmbeddingSimilarity pid=33586, ip=10.0.153.142) INFO 2025-05-29 17:40:39,979 default_EmbeddingSimilarity vmdcfuzx 5759bf7a-7345-415c-8005-f8ef6e5a6041 -- CALL /predict/ OK 8086.8ms
[('collie', 0.2384808510541916),
 ('bernese_mountain_dog', 0.15381267666816711),
 ('german_shepherd', 0.14378124475479126)]
# Top matches by embedding similarity
from doggos.embed import display_top_matches
display_top_matches(url=url, matches=response.json()["top_matches"])
🔎 Observability for services

The Ray dashboard, and specifically the Serve view, automatically captures observability for Ray Serve applications. You can view the service’s deployments and their replicas, and use time-series metrics to monitor the service’s health.

https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/serve_dashboard.png
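
You can also inspect the running application programmatically from the cluster. A minimal sketch using the serve.status() API:

from ray import serve

# Summarize the health of every Serve application and deployment in this cluster.
status = serve.status()
for app_name, app in status.applications.items():
    print(app_name, app.status)
    for deployment_name, deployment in app.deployments.items():
        print(f"  {deployment_name}: {deployment.status}")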

Production services#

Anyscale Services

Anyscale Services (API ref) offers a fault-tolerant, scalable, and optimized way to serve Ray Serve applications. You can:

  • roll out and update services with canary deployments and zero-downtime upgrades.

  • monitor services through a dedicated service page, unified log viewer, tracing, alerting, etc.

  • scale a service with num_replicas=auto and utilize replica compaction to consolidate fractionally utilized nodes. See the autoscaling sketch below.

  • get head node fault tolerance. OSS Ray recovers from failed workers and replicas but not head node crashes.

  • serve multiple applications in a single service.

https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/canary.png
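
For reference, autoscaling in open source Ray Serve is configured on the deployment itself. A minimal sketch, using hypothetical deployments that aren't part of the doggos app:

from ray import serve

# Let Serve choose replica counts based on traffic.
@serve.deployment(num_replicas="auto")
class AutoscaledModel:
    def __call__(self, request):
        return "ok"

# Or tune the autoscaler explicitly.
@serve.deployment(
    autoscaling_config={
        "min_replicas": 1,
        "max_replicas": 4,
        "target_ongoing_requests": 2,  # desired in-flight requests per replica
    },
)
class ExplicitlyAutoscaledModel:
    def __call__(self, request):
        return "ok"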

RayTurbo Serve on Anyscale has more functionality on top of Ray Serve:

  • fast autoscaling and model loading to get services up and running even faster, with 5x improvements even for LLMs.

  • 54% higher QPS and up to 3x streaming tokens per second for high-traffic serving use cases with no proxy bottlenecks.

  • replica compaction into fewer nodes where possible to reduce resource fragmentation and improve hardware utilization.

  • zero-downtime incremental rollouts so your service is never interrupted.

  • different environments for each service in a multi-serve application.

  • multi availability-zone aware scheduling of Ray Serve replicas to provide higher redundancy to availability zone failures.

Note:

  • This tutorial uses a containerfile to define dependencies, but you could easily use a pre-built image as well.

  • You can specify the compute as a compute config or inline in a Service config file.

  • When you don’t specify compute while launching from a workspace, this configuration defaults to the compute configuration of the workspace.

# Production online service.
anyscale service deploy doggos.serve:app --name=doggos-app \
    --containerfile="/home/ray/default/containerfile" \
    --working-dir="/home/ray/default" \
    --exclude=""
(anyscale +1.9s) Restarting existing service 'doggos-app'.
(anyscale +3.2s) Uploading local dir '/home/ray/default' to cloud storage.
(anyscale +5.2s) Including workspace-managed pip dependencies.
(anyscale +5.8s) Service 'doggos-app' deployed (version ID: akz9ul28).
(anyscale +5.8s) View the service in the UI: 'https://console.anyscale.com/services/service2_6hxismeqf1fkd2h7pfmljmncvm'
(anyscale +5.8s) Query the service once it's running using the following curl command (add the path you want to query):
(anyscale +5.8s) curl -H "Authorization: Bearer 4syAhpiTQPW_iskqwprfFNzfO_ZSeJY_SVNRrmDdm98" https://doggos-app-bxauk.cld-kvedzwag2qa8i5bj.s.anyscaleuserdata.com/
curl -X POST "https://doggos-app-bxauk.cld-kvedzwag2qa8i5bj.s.anyscaleuserdata.com/predict/" \
     -H "Authorization: Bearer <BEARER_TOKEN>" \
     -H "Content-Type: application/json" \
     -d '{"url": "https://doggos-dataset.s3.us-west-2.amazonaws.com/samara.png", "k": 4}'
# Terminate service.
anyscale service terminate --name doggos-app
(anyscale +1.5s) Service service2_6hxismeqf1fkd2h7pfmljmncvm terminate initiated.
(anyscale +1.5s) View the service in the UI at https://console.anyscale.com/services/service2_6hxismeqf1fkd2h7pfmljmncvm

CI/CD#

CI/CD

While Anyscale Jobs and Services are useful atomic concepts that help you productionize workloads, they’re also useful as nodes in a larger ML DAG or CI/CD workflow. You can chain Jobs together, store results, and then serve your application with those artifacts. From there, you can trigger updates to your service and retrigger the Jobs based on events, time, etc. While you can simply use the Anyscale CLI to integrate with any orchestration platform, Anyscale also supports some purpose-built integrations like Airflow and Prefect.

https://raw.githubusercontent.com/anyscale/foundational-ray-app/refs/heads/main/images/cicd.png