Batch inference#
This tutorial executes a batch inference workload that connects the following heterogeneous workloads:
distributed read from cloud storage (CPU)
apply distributed preprocessing (CPU)
batch inference (GPU)
distributed write to cloud storage (CPU)

%%bash
pip install -q "matplotlib==3.10.0" "torch==2.7.1" "transformers==4.52.3" "scikit-learn==1.6.0" "mlflow==2.19.0" "ipywidgets==8.1.3"
Successfully registered `matplotlib, torch` and 4 other packages to be installed on all cluster nodes.
View and update dependencies here: https://console.anyscale.com/cld_kvedZWag2qA8i5BjxUevf5i7/prj_cz951f43jjdybtzkx1s5sjgz99/workspaces/expwrk_eys8cskj5aivghbf773dp2vmcd?workspace-tab=dependencies
%load_ext autoreload
%autoreload all
import os
import ray
import sys
sys.path.append(os.path.abspath(".."))
Data ingestion#
Start by reading the data from a public cloud storage bucket.
# Load data.
ds = ray.data.read_images(
"s3://doggos-dataset/train",
include_paths=True,
shuffle="files",
)
ds.take(1)
2025-06-23 14:04:39,249 INFO worker.py:1723 -- Connecting to existing Ray cluster at address: 10.0.52.172:6379...
2025-06-23 14:04:39,260 INFO worker.py:1908 -- Connected to Ray cluster. View the dashboard at https://session-gcwehd9xxjzkv5lxv8lgcdgx2n.i.anyscaleuserdata.com
2025-06-23 14:04:39,266 INFO packaging.py:380 -- Pushing file package 'gcs://_ray_pkg_a644723e367c78760222a7f2fcce949b2fe72f7b.zip' (1.92MiB) to Ray cluster...
2025-06-23 14:04:39,275 INFO packaging.py:393 -- Successfully pushed file package 'gcs://_ray_pkg_a644723e367c78760222a7f2fcce949b2fe72f7b.zip'.
2025-06-23 14:04:39,581 INFO dataset.py:3048 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2025-06-23 14:04:39,583 INFO logging.py:295 -- Registered dataset logger for dataset dataset_11_0
2025-06-23 14:04:39,594 INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_11_0. Full logs are in /tmp/ray/session_2025-06-23_13-49-50_102769_2149/logs/ray-data
2025-06-23 14:04:39,595 INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_11_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> LimitOperator[limit=1]
2025-06-23 14:05:56,467 INFO streaming_executor.py:227 -- ✔️ Dataset dataset_11_0 execution finished in 76.87 seconds
[{'image': array([[[ 32, 52, 77],
[ 27, 47, 72],
[ 28, 43, 72],
...,
[235, 235, 233],
[236, 236, 234],
[236, 236, 234]],
[[ 34, 51, 77],
[ 30, 47, 73],
[ 30, 45, 74],
...,
[233, 233, 231],
[233, 233, 231],
[233, 233, 231]],
[[ 35, 50, 79],
[ 32, 47, 76],
[ 33, 48, 79],
...,
[237, 237, 237],
[237, 237, 237],
[237, 237, 237]],
...,
[[ 55, 80, 76],
[ 65, 90, 86],
[ 56, 78, 75],
...,
[142, 168, 133],
[157, 184, 149],
[140, 170, 132]],
[[ 52, 72, 70],
[ 77, 97, 95],
[ 78, 97, 95],
...,
[125, 151, 112],
[141, 169, 128],
[180, 211, 167]],
[[ 92, 108, 107],
[123, 139, 138],
[135, 149, 149],
...,
[125, 152, 109],
[ 87, 116, 68],
[127, 159, 109]]], dtype=uint8),
'path': 'doggos-dataset/train/saint_bernard/saint_bernard_7024.jpg'}]
Ray Data supports a wide range of data sources for both loading and saving from generic binary files in cloud storage to structured data formats used by modern data platforms. This example reads data from a public S3 bucket prepared with the dataset. This read
operation, much like the write
operation in a later step, runs in a distributed fashion. As a result, Ray Data processes the data in parallel across the cluster and doesn’t need to load the data entirely into memory at once, making data loading scalable and memory-efficient.
trigger lazy execution: use
take
to trigger the execution because Ray has lazy execution mode, which decreases execution time and memory utilization. But, this approach means that you need an operation like take, count, write, etc., to actually execute the workflow DAG.shuffling strategies: to shuffle the dataset because it’s all ordered by class, randomly shuffle the ordering of input files before reading. Ray Data also provides an extensive list of shuffling strategies such as local shuffles, per-epoch shuffles, etc.
materialize
during development: usematerialize
to execute and materialize the dataset into Ray’s shared memory object store memory. This way, you save a checkpoint at this point and future operations on the dataset can start from this point. You won’t rerun all operations on the dataset again from scratch. This feature is convenient during development, especially in a stateful environment like Jupyter notebooks, because you can run from saved checkpoints.ds = ds.map(...) ds = ds.materialize()
Note: only use this during development and use it with small datasets, as it will load it all into memory.
You also want to add the class for each data point. When reading the data with include_paths
Ray Data saves the filename with each data point. The filename has the class label in it so add that to each data point’s row. Use Ray Data’s map function to apply the function to each row.
def add_class(row):
row["class"] = row["path"].rsplit("/", 3)[-2]
return row
# Add class.
ds = ds.map(add_class,
num_cpus=1,
num_gpus=0,
concurrency=4)
❌ Traditional batch execution, for example, non-streaming like Spark without pipelining, SageMaker Batch Transform:
Reads the entire dataset into memory or a persistent intermediate format.
Only then starts applying transformations like .map, .filter, etc.
Higher memory pressure and startup latency.
✅ Streaming execution with Ray Data:
Starts processing chunks (“blocks”) as they’re loaded. No need to wait for entire dataset to load.
Reduces memory footprint (no OOMs) and speeds up time to first output.
Increase resource utilization by reducing idle time.
Online-style inference pipelines with minimal latency.

Note: Ray Data isn’t a real-time stream processing engine like Flink or Kafka Streams. Instead, it’s batch processing with streaming execution, which is especially useful for iterative ML workloads, ETL pipelines, and preprocessing before training or inference. Ray typically has a 2-17x throughput improvement over solutions like Spark and SageMaker Batch Transform, etc.
Batch embeddings#
The previous section applied a mapping operation using a function to each row in the dataset. Now you’re ready to generate embeddings from the data and using Ray Data’s map_batches
to apply an operation across batches of the data. The operation is in the form of a callable, which is a function or a class with a __call__
method.
import numpy as np
from PIL import Image
import torch
from transformers import CLIPModel, CLIPProcessor
class EmbedImages(object):
def __init__(self, model_id, device):
# Load CLIP model and processor
self.processor = CLIPProcessor.from_pretrained(model_id)
self.model = CLIPModel.from_pretrained(model_id)
self.model.to(device)
self.device = device
def __call__(self, batch):
# Load and preprocess images
images = [Image.fromarray(np.uint8(img)).convert("RGB") for img in batch["image"]]
inputs = self.processor(images=images, return_tensors="pt", padding=True).to(self.device)
# Generate embeddings
with torch.inference_mode():
batch["embedding"] = self.model.get_image_features(**inputs).cpu().numpy()
return batch
Instead of initializing the same model for each instance of the class above, we can instead use references to Ray’s shared memory object store. We can load the model once, store it inside the default object store and then have each instance of our class refer to it.
model = load_model(...)
model_ref = ray.put(model)
class Foo:
def __init__(self, model_ref):
self.model = ray.get(model_ref)
...
# Generate batch embeddings
embeddings_ds = ds.map_batches(
EmbedImages,
fn_constructor_kwargs={
"model_id": "openai/clip-vit-base-patch32",
"device": "cuda",
}, # class kwargs
fn_kwargs={}, # __call__ kwargs
concurrency=4,
batch_size=64,
num_gpus=1,
accelerator_type="L4",
)
embeddings_ds = embeddings_ds.drop_columns(["image"]) # remove image column
Ray Data#
Ray Data not only makes it extremely easy to distribute workloads but also ensures that they with:
efficiency: minimize CPU/GPU idle time with heterogeneous resource scheduling.
scalability: streaming execution to petabyte-scale datasets, especially when working with LLMs
reliability by checkpointing processes, especially when running workloads on spot instances with on-demand fallback.
flexibility: connect to data from any source, apply transformations, and save to any format or location for your next workload.

🔥 RayTurbo Data has more functionality on top of Ray Data:
accelerated metadata fetching to improve reading from large datasets (start processes earlier).
optimized autoscaling where actor pools are scaled up faster, start jobs before entire cluster is ready, etc.
high reliability where entire fails jobs (even on spot instances), like head node, cluster, uncaptured exceptions, etc., can resume from checkpoints. OSS Ray can only recover from worker node failures.
Data storage#
import shutil
# Save to artifact storage.
embeddings_path = os.path.join("/mnt/cluster_storage", "doggos/embeddings")
if os.path.exists(embeddings_path):
shutil.rmtree(embeddings_path) # clean up
embeddings_ds.write_parquet(embeddings_path)
2025-06-23 14:06:01,973 INFO logging.py:295 -- Registered dataset logger for dataset dataset_16_0
2025-06-23 14:06:02,000 INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_16_0. Full logs are in /tmp/ray/session_2025-06-23_13-49-50_102769_2149/logs/ray-data
2025-06-23 14:06:02,002 INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_16_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles] -> TaskPoolMapOperator[Map(add_class)] -> ActorPoolMapOperator[MapBatches(EmbedImages)] -> TaskPoolMapOperator[MapBatches(drop_columns)->Write]
2025-06-23 14:06:02,029 INFO actor_pool_map_operator.py:633 -- Scaling up actor pool by 4 (reason=scaling to min size, running=0, restarting=0, pending=0)
(_MapWorker pid=2910, ip=10.0.69.70) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
2025-06-23 14:06:22,379 INFO actor_pool_map_operator.py:661 -- Scaled down actor pool by 1 (reason=None; running=3, restarting=0, pending=0)
2025-06-23 14:06:22,744 INFO streaming_executor.py:227 -- ✔️ Dataset dataset_16_0 execution finished in 20.74 seconds
2025-06-23 14:06:22,842 INFO dataset.py:4603 -- Data sink Parquet finished. 2880 rows and 5.8MB data written.
(autoscaler +5m51s) Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.
(autoscaler +5m51s) [autoscaler] Downscaling node i-018706717a4455b75 (node IP: 10.0.65.200) due to node idle termination.
(autoscaler +5m51s) [autoscaler] Downscaling node i-0e3238b7f703616e7 (node IP: 10.0.127.236) due to node idle termination.
(autoscaler +5m51s) [autoscaler] Downscaling node i-0fcefb76d19edf42b (node IP: 10.0.49.153) due to node idle termination.
(autoscaler +5m56s) [autoscaler] Cluster resized to {8 CPU, 2 GPU}.
You can always store to the data inside any storage buckets but Anyscale offers a default storage bucket to make things easier. You also have plenty of other storage options as well, for example, shared at the cluster, user and cloud levels.
Note: ideally you would store these embeddings in a vector database like efficient search, filter, index, etc., but for this tutorial, just store to a shared file system.
Monitoring and Debugging#
While you’re developing out workloads, Ray offers logs and an observability dashboard that you can use to monitor and debug. The dashboard includes a lot of different components such as:
memory, utilization, etc., of the tasks running in the cluster

views to see all running tasks, utilization across instance types, autoscaling, etc.

🔥 While OSS Ray comes with an extensive observability suite, Anyscale takes it many steps further to make it easier and faster to monitor and debug workloads.
Ray workload specific dashboard, like Data, Train, etc., that can breakdown the tasks

unified log viewer to see logs from all driver and worker processes

Production jobs#
Anyscale Jobs (API ref) allows you to execute discrete workloads in production such as batch inference, embeddings generation, or model fine-tuning.
define and manage Jobs in many different ways, including with a CLI or Python SDK.
set up all the observability, alerting, etc. around your Jobs.

Wrap the batch embedding generation workload as an Anyscale Job by providing the main command to run, python doggos/embed.py
, and the appropriate compute and dependencies required for it. Also set the working directory to the default
directory so that the Job has access to all the files for the workload.
Note:
this step uses a
containerfile
to define dependencies, but you could easily use a pre-built image as well.you can specify the compute as a compute config or inline in a job config file.
when you don’t specify compute while launching from a workspace, the configuration defaults to the compute configuration of the workspace.
and of course we can launch Jobs from anywhere (not just from within Workspaces) where we can specify the compute config and dependencies for the Job to use. Learn more on how to create and manage Jobs.
# Production batch job.
anyscale job submit --name=generate-doggos-embeddings \
--containerfile="/home/ray/default/containerfile" \
--compute-config="/home/ray/default/configs/aws.yaml" \
--working-dir="/home/ray/default" \
--exclude="" \
--max-retries=0 \
-- python doggos/embed.py

Similar images#
Process a new image, embed it, and then retrieve the top similar images, based on embedding similarity (cosine), from the larger dataset of images you just computed batch embeddings for.
from io import BytesIO
from PIL import Image
import numpy as np
import requests
from doggos.embed import get_top_matches, display_top_matches
def url_to_array(url):
return np.array(Image.open(
BytesIO(requests.get(url).content)).convert("RGB"))
# Embed input image.
url = "https://doggos-dataset.s3.us-west-2.amazonaws.com/samara.png"
image = url_to_array(url=url)
embedding_generator = EmbedImages(model_id="openai/clip-vit-base-patch32", device="cpu")
embedding = embedding_generator({"image": [image]})["embedding"][0]
np.shape(embedding)
(512,)
# Top matches by embedding similarity.
embeddings_ds = ray.data.read_parquet(embeddings_path)
top_matches = get_top_matches(embedding, embeddings_ds, n=5)
display_top_matches(url, top_matches)
2025-06-23 14:13:37,591 INFO logging.py:295 -- Registered dataset logger for dataset dataset_20_0
2025-06-23 14:13:37,597 INFO streaming_executor.py:117 -- Starting execution of Dataset dataset_20_0. Full logs are in /tmp/ray/session_2025-06-23_13-49-50_102769_2149/logs/ray-data
2025-06-23 14:13:37,598 INFO streaming_executor.py:118 -- Execution plan of Dataset dataset_20_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ListFiles] -> TaskPoolMapOperator[ReadFiles]
2025-06-23 14:13:38,825 INFO streaming_executor.py:227 -- ✔️ Dataset dataset_20_0 execution finished in 1.23 seconds

(_MapWorker pid=3343, ip=10.0.102.235) Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
(autoscaler +16m26s) [autoscaler] [4xL4:48CPU-192GB] Attempting to add 2 node(s) to the cluster (increasing from 1 to 3).
(autoscaler +16m26s) [autoscaler] [4xL4:48CPU-192GB] Launched 2 instances.
(autoscaler +17m11s) [autoscaler] Cluster upscaled to {152 CPU, 14 GPU}.
(raylet) WARNING: 4 PYTHON worker processes have been started on node: 97b39558bc8a3057162823cead1b8e035f1be130c49bb311e538ed2d with address: 10.0.52.172. This could be a result of using a large number of actors, or due to tasks blocked in ray.get() calls (see https://github.com/ray-project/ray/issues/3644 for some discussion of workarounds).
(autoscaler +1h19m21s) [autoscaler] Downscaling node i-03a133888407b8cf8 (node IP: 10.0.103.152) due to node idle termination.
(autoscaler +1h19m21s) [autoscaler] Downscaling node i-06023e83fb012b7ae (node IP: 10.0.90.122) due to node idle termination.
(autoscaler +1h19m26s) [autoscaler] Cluster resized to {56 CPU, 6 GPU}.
🚨 Note: Reset this notebook using the “🔄 Restart” button location at the notebook’s menu bar. This way we can free up all the variables, utils, etc. used in this notebook.