Object Detection Batch Inference with PyTorch

This example demonstrates how to do object detection batch inference at scale with a pre-trained PyTorch model and Ray Data.

Here is what you’ll do:

  1. Perform object detection on a single image with a pre-trained PyTorch model.

  2. Scale the PyTorch model with Ray Data, and perform object detection batch inference on a large set of images.

  3. Verify the inference results and save them to external storage.

  4. Learn how to use Ray Data with GPUs.

Before You Begin

Install the following dependencies if you haven’t already.

!pip install "ray[data]" torchvision

Object Detection on a Single Image with PyTorch

Before diving into Ray Data, let’s look at an object detection example from PyTorch’s official documentation. It uses a pre-trained Faster R-CNN model (FasterRCNN_ResNet50_FPN_V2) to run object detection on a single image.

First, download an image from the Internet.

import requests
from PIL import Image

url = "https://s3-us-west-2.amazonaws.com/air-example-data/AnimalDetection/JPEGImages/2007_000063.jpg"
img = Image.open(requests.get(url, stream=True).raw)
display(img)

Second, load and initialize a pre-trained PyTorch model.

from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights

weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9)
model.eval()
FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (3): Conv2dNormActivation(
          (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer_blocks): ModuleList(
        (0-3): 4 x Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (extra_blocks): LastLevelMaxPool()
    )
  )
  (rpn): RegionProposalNetwork(
    (anchor_generator): AnchorGenerator()
    (head): RPNHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): ReLU(inplace=True)
        )
      )
      (cls_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
      (bbox_pred): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (roi_heads): RoIHeads(
    (box_roi_pool): MultiScaleRoIAlign(featmap_names=['0', '1', '2', '3'], output_size=(7, 7), sampling_ratio=2)
    (box_head): FastRCNNConvFCHead(
      (0): Conv2dNormActivation(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (1): Conv2dNormActivation(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (2): Conv2dNormActivation(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (3): Conv2dNormActivation(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
      (4): Flatten(start_dim=1, end_dim=-1)
      (5): Linear(in_features=12544, out_features=1024, bias=True)
      (6): ReLU(inplace=True)
    )
    (box_predictor): FastRCNNPredictor(
      (cls_score): Linear(in_features=1024, out_features=91, bias=True)
      (bbox_pred): Linear(in_features=1024, out_features=364, bias=True)
    )
  )
)

Then apply the preprocessing transforms.

img = transforms.Compose([transforms.PILToTensor()])(img)
preprocess = weights.transforms()
batch = [preprocess(img)]

Then use the model for inference.

prediction = model(batch)[0]

Lastly, visualize the result.

from torchvision.utils import draw_bounding_boxes
from torchvision.transforms.functional import to_pil_image

labels = [weights.meta["categories"][i] for i in prediction["labels"]]
box = draw_bounding_boxes(img, 
                          boxes=prediction["boxes"],
                          labels=labels,
                          colors="red",
                          width=4)
im = to_pil_image(box.detach())
display(im)
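The detection model returns one dict per image with `boxes`, `labels`, and `scores`, and `box_score_thresh=0.9` already keeps only detections scoring above 0.9. If you want to apply a different threshold after the fact, or turn label ids into names outside the plotting code, the logic looks roughly like the NumPy sketch below. `filter_detections` and the short category list are hypothetical stand-ins: in the example itself, the names come from `weights.meta["categories"]` and the arrays from `prediction` (converted with `.detach().cpu().numpy()`).

```python
from typing import Dict, List

import numpy as np

# Hypothetical, truncated category list for illustration only.
# The real list is weights.meta["categories"], where index 0 is the background class.
CATEGORIES: List[str] = ["__background__", "person", "bicycle", "car"]


def filter_detections(
    prediction: Dict[str, np.ndarray],
    categories: List[str],
    min_score: float = 0.9,
) -> Dict[str, object]:
    """Keep detections scoring above `min_score` and map label ids to category names."""
    keep = prediction["scores"] > min_score  # boolean mask over detections
    return {
        "boxes": prediction["boxes"][keep],  # (N, 4) boxes as (x1, y1, x2, y2)
        "scores": prediction["scores"][keep],
        "labels": [categories[i] for i in prediction["labels"][keep]],
    }


# Made-up prediction with two detections, in the same layout the model returns.
example = {
    "boxes": np.array([[10.0, 20.0, 110.0, 220.0], [5.0, 5.0, 50.0, 50.0]]),
    "labels": np.array([1, 3]),
    "scores": np.array([0.97, 0.42]),
}
result = filter_detections(example, CATEGORIES)
print(result["labels"])  # ['person']
```

Lowering `min_score` (for example to 0.4) would keep the second, low-confidence detection as well.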

Scaling with Ray Data

Next, let’s see how to scale the previous example to a large set of images. We will use Ray Data to run batch inference in a distributed fashion, using all the CPU and GPU resources in our cluster.

Loading the Image Dataset

The dataset we will use is a subset of Pascal VOC that contains cats and dogs (the full dataset has 20 classes). There are 2434 images in this dataset.

First, we use the ray.data.read_images API to load a prepared image dataset from S3, and the schema API to inspect it. As the output below shows, the dataset has one column named “image”, whose values are the images as np.ndarray objects with three dimensions (height, width, and channel).

import ray

ds = ray.data.read_images("s3://anonymous@air-example-data/AnimalDetection/JPEGImages")
display(ds.schema())
[2023-05-19 18:10:29]  INFO ray._private.worker::Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
[2023-05-19 18:10:35] [Ray Data] WARNING ray.data.dataset::Important: Ray Data requires schemas for all datasets in Ray 2.5. This means that standalone Python objects are no longer supported. In addition, the default batch format is fixed to NumPy. To revert to legacy behavior temporarily, set the environment variable RAY_DATA_STRICT_MODE=0 on all cluster processes.

Learn more here: https://docs.ray.io/en/master/data/faq.html#migrating-to-strict-mode
Column  Type
------  ----
image   numpy.ndarray(ndim=3, dtype=uint8)

Batch inference with Ray Data

As the PyTorch example shows, the inference pipeline consists of two steps: preprocessing the image and running the model.

Preprocessing

First, let’s convert the preprocessing code to Ray Data by packaging it in a preprocess_image function. The function takes a single argument: a dict representing one row of the dataset, where the “image” key holds the image as a NumPy array. It returns a new row that keeps the original image and adds the preprocessed image under the “transformed” key.

import numpy as np
import torch
from torchvision import transforms
from torchvision.models.detection import (FasterRCNN_ResNet50_FPN_V2_Weights,
                                          fasterrcnn_resnet50_fpn_v2)
from typing import Dict


def preprocess_image(data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
    preprocessor = transforms.Compose(
        [transforms.ToTensor(), weights.transforms()]
    )
    return {
        "image": data["image"],
        "transformed": preprocessor(data["image"]),
    }
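For this model, the preprocessing only changes layout and scale: `transforms.ToTensor()` turns an H×W×C `uint8` array into a C×H×W `float32` tensor in [0, 1], and the weights' own transform leaves an image that is already float unchanged (normalization and resizing happen later, inside the model's `GeneralizedRCNNTransform` shown in the printout above). The NumPy sketch below mimics that behavior so you can reason about the values in the “transformed” column. `to_chw_float` is a hypothetical helper, not a torchvision API, and it assumes a `uint8` image with channels last.

```python
import numpy as np


def to_chw_float(image: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy stand-in for ToTensor() applied to an H x W x C uint8 image."""
    # Move channels first, (H, W, C) -> (C, H, W), then scale 0..255 to 0.0..1.0.
    return np.transpose(image, (2, 0, 1)).astype(np.float32) / 255.0


# First pixel of the first image in the take_batch output later on: [173, 153, 142].
pixel = np.array([[[173, 153, 142]]], dtype=np.uint8)  # shape (1, 1, 3)
chw = to_chw_float(pixel)
print(chw.shape)  # (3, 1, 1)
print(chw[:, 0, 0])  # matches the first value of each channel in the "transformed" column
```

Dividing 173, 153, and 142 by 255 gives roughly 0.678, 0.6, and 0.557, the leading values of the three channels of the first “transformed” array.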

Then we use the map API to apply preprocess_image to the whole dataset. Ray Data’s map scales the preprocessing out to all the resources in our Ray cluster. Note that map is lazy: it doesn’t execute anything until we start consuming the results.

ds = ds.map(preprocess_image)
[2023-05-19 18:10:37] [Ray Data] WARNING ray.data.dataset::The `map`, `flat_map`, and `filter` operations are unvectorized and can be very slow. If you're using a vectorized transformation, consider using `.map_batches()` instead.
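The warning points to `.map_batches()`. Because the images here have different sizes, a batch-level version still loops over the images, but it amortizes per-call overhead (for example, a transform pipeline can be built once per batch instead of once per row, as preprocess_image does). The sketch below shows one way to lift a per-image function into a function with the dict-of-arrays signature that `map_batches` works with under the NumPy batch format. `make_batch_fn` is a hypothetical helper, and the lambda in the usage is a placeholder for the real preprocessing.

```python
from typing import Callable, Dict

import numpy as np


def make_batch_fn(
    per_image: Callable[[np.ndarray], np.ndarray],
) -> Callable[[Dict[str, np.ndarray]], Dict[str, np.ndarray]]:
    """Hypothetical adapter: turn a per-image function into a dict-of-arrays batch function."""

    def batch_fn(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        images = batch["image"]
        # Images can have different shapes, so store the results in an object array
        # (one ndarray per element) instead of stacking them into one dense array.
        transformed = np.empty(len(images), dtype=object)
        for i, image in enumerate(images):
            transformed[i] = per_image(image)
        return {"image": images, "transformed": transformed}

    return batch_fn


# Placeholder per-image transform: channels-first float32 in [0, 1].
batch_fn = make_batch_fn(lambda img: np.transpose(img, (2, 0, 1)).astype(np.float32) / 255.0)

# A batch of two differently sized images, stored as an object array.
images = np.empty(2, dtype=object)
images[0] = np.zeros((2, 3, 3), dtype=np.uint8)
images[1] = np.full((4, 5, 3), 255, dtype=np.uint8)
out = batch_fn({"image": images})
print([t.shape for t in out["transformed"]])  # [(3, 2, 3), (3, 4, 5)]
```

In the pipeline, a function like this would be passed to `ds.map_batches(...)` in place of `ds.map(preprocess_image)`.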

Model inference

Next, let’s convert the model inference part. Compared with preprocessing, model inference has two differences:

  1. Model loading and initialization are usually expensive.

  2. Model inference can be optimized with hardware acceleration if we process data in batches. Using larger batches improves GPU utilization and the overall runtime of the inference job.
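Batching pays off because the accelerator runs one large tensor operation instead of many small ones. Images of different sizes can't be stacked directly, so they are padded to a common shape first; as far as we know, torchvision's detection models do this internally (the `GeneralizedRCNNTransform` in the model printout resizes each image, then pads the batch so height and width are multiples of 32). The NumPy sketch below illustrates only the padding-and-stacking idea; `pad_to_batch` is a hypothetical helper, not a torchvision function.

```python
from typing import List

import numpy as np


def pad_to_batch(images: List[np.ndarray], size_divisible: int = 32) -> np.ndarray:
    """Zero-pad C x H x W images to one shared size and stack them into an N x C x H x W array."""
    channels = images[0].shape[0]
    # Largest height and width in the batch, rounded up to a multiple of `size_divisible`.
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    max_h = (max_h + size_divisible - 1) // size_divisible * size_divisible
    max_w = (max_w + size_divisible - 1) // size_divisible * size_divisible
    batch = np.zeros((len(images), channels, max_h, max_w), dtype=images[0].dtype)
    for i, img in enumerate(images):
        # Copy each image into the top-left corner; everything else stays zero padding.
        batch[i, :, : img.shape[1], : img.shape[2]] = img
    return batch


# Two differently shaped images, like consecutive rows of the dataset after preprocessing.
images = [np.ones((3, 500, 375), dtype=np.float32), np.ones((3, 333, 500), dtype=np.float32)]
print(pad_to_batch(images).shape)  # (2, 3, 512, 512)
```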

For these reasons, we convert the model inference code into the following ObjectDetectionModel class. The expensive model loading and initialization code goes in the __init__ constructor, which runs only once per worker, and the model inference code goes in the __call__ method, which is called for each batch.
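To see why the split matters, here is a framework-free sketch of the same pattern. It assumes, as Ray Data does for class-based functions, that each worker constructs the class once and then calls the same instance on many batches. `CountingPredictor` and its doubling "model" are hypothetical placeholders for ObjectDetectionModel.

```python
from typing import Dict, List

import numpy as np


class CountingPredictor:
    """Hypothetical stand-in for ObjectDetectionModel that records how often each method runs."""

    init_calls = 0  # class-level counter shared by all instances

    def __init__(self) -> None:
        # Expensive setup (loading weights, moving them to the GPU) would go here.
        CountingPredictor.init_calls += 1
        self.batches_seen = 0
        self.scale = 2.0  # placeholder "model parameter"

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Per-batch work reuses the state built once in __init__.
        self.batches_seen += 1
        return {"output": batch["x"] * self.scale}


# One worker constructs the class once, then feeds it many batches.
predictor = CountingPredictor()
results: List[np.ndarray] = [predictor({"x": np.arange(3)})["output"] for _ in range(4)]
print(CountingPredictor.init_calls, predictor.batches_seen)  # 1 4
```

The setup cost is paid once, while the per-batch cost is paid on every call, which is exactly the trade-off the ObjectDetectionModel class is designed around.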

The __call__ method takes a batch of data items instead of a single one. Here the batch is also a dict, but after preprocessing it has two keys, “image” and “transformed”, each mapping to a NumPy array with one entry per image. We can use the take_batch API to fetch a single batch and inspect its internal data structure.

single_batch = ds.take_batch(batch_size=3)
display(single_batch)
[2023-05-19 18:10:38] [Ray Data] INFO ray.data._internal.execution.streaming_executor.logfile::Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadImage->Map]
[2023-05-19 18:10:38] [Ray Data] INFO ray.data._internal.execution.streaming_executor.logfile::Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
[2023-05-19 18:10:38] [Ray Data] INFO ray.data._internal.execution.streaming_executor.logfile::Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`
[2023-05-19 18:10:40] [Ray Data] INFO ray.data._internal.execution.streaming_executor.logfile::Shutting down <StreamingExecutor(Thread-11, started daemon 16076255232)>.
{'image': array([array([[[173, 153, 142],
                [255, 246, 242],
                [255, 245, 245],
                ...,
                [255, 255, 244],
                [237, 235, 223],
                [214, 212, 200]],
 
               [[124, 105,  90],
                [255, 249, 238],
                [251, 244, 236],
                ...,
                [255, 252, 245],
                [255, 254, 247],
                [247, 244, 237]],
 
               [[ 56,  37,  20],
                [255, 253, 239],
                [248, 248, 236],
                ...,
                [248, 247, 243],
                [248, 247, 243],
                [254, 253, 249]],
 
               ...,
 
               [[ 64,  78,  87],
                [ 63,  74,  80],
                [105, 113, 115],
                ...,
                [ 94, 105, 109],
                [ 90,  99, 104],
                [ 84,  91,  97]],
 
               [[ 68,  86,  96],
                [ 69,  82,  88],
                [ 55,  63,  66],
                ...,
                [ 82,  98,  98],
                [ 54,  70,  70],
                [ 82,  96,  97]],
 
               [[ 67,  87,  96],
                [ 43,  60,  67],
                [ 80,  96,  96],
                ...,
                [ 63,  75,  75],
                [ 89, 101, 101],
                [ 54,  65,  67]]], dtype=uint8),
        array([[[31, 32, 26],
                [31, 32, 26],
                [30, 31, 25],
                ...,
                [82, 83, 78],
                [82, 83, 78],
                [82, 83, 78]],
 
               [[32, 33, 27],
                [29, 30, 24],
                [26, 27, 21],
                ...,
                [82, 83, 78],
                [82, 83, 78],
                [82, 83, 78]],
 
               [[27, 28, 22],
                [23, 24, 18],
                [21, 22, 16],
                ...,
                [84, 85, 80],
                [84, 85, 80],
                [84, 85, 80]],
 
               ...,
 
               [[43, 18, 21],
                [36, 14, 16],
                [39, 19, 20],
                ...,
                [19, 24, 18],
                [19, 24, 18],
                [13, 18, 12]],
 
               [[47, 21, 24],
                [39, 14, 17],
                [36, 16, 17],
                ...,
                [21, 26, 20],
                [24, 29, 23],
                [22, 27, 21]],
 
               [[47, 16, 22],
                [40, 13, 18],
                [36, 16, 18],
                ...,
                [ 9, 14,  8],
                [ 7, 12,  6],
                [ 1,  6,  0]]], dtype=uint8),
        array([[[ 17,   3,   2],
                [ 17,   3,   2],
                [ 19,   3,   3],
                ...,
                [ 55,  68,  84],
                [ 56,  69,  85],
                [ 56,  69,  85]],
 
               [[ 18,   4,   3],
                [ 18,   4,   3],
                [ 19,   3,   3],
                ...,
                [ 56,  69,  85],
                [ 56,  69,  85],
                [ 57,  70,  86]],
 
               [[ 18,   4,   3],
                [ 18,   4,   3],
                [ 19,   3,   3],
                ...,
                [ 56,  69,  85],
                [ 56,  69,  85],
                [ 57,  70,  86]],
 
               ...,
 
               [[  9,   0,   1],
                [  9,   0,   1],
                [  9,   0,   1],
                ...,
                [123, 124, 116],
                [121, 122, 114],
                [116, 117, 109]],
 
               [[  9,   0,   1],
                [  9,   0,   1],
                [  9,   0,   1],
                ...,
                [121, 122, 114],
                [119, 120, 112],
                [115, 116, 108]],
 
               [[  9,   0,   1],
                [  9,   0,   1],
                [  9,   0,   1],
                ...,
                [121, 122, 114],
                [119, 120, 112],
                [116, 117, 109]]], dtype=uint8)], dtype=object),
 'transformed': array([array([[[0.6784314 , 1.        , 1.        , ..., 1.        ,
                 0.92941177, 0.8392157 ],
                [0.4862745 , 1.        , 0.9843137 , ..., 1.        ,
                 1.        , 0.96862745],
                [0.21960784, 1.        , 0.972549  , ..., 0.972549  ,
                 0.972549  , 0.99607843],
                ...,
                [0.2509804 , 0.24705882, 0.4117647 , ..., 0.36862746,
                 0.3529412 , 0.32941177],
                [0.26666668, 0.27058825, 0.21568628, ..., 0.32156864,
                 0.21176471, 0.32156864],
                [0.2627451 , 0.16862746, 0.3137255 , ..., 0.24705882,
                 0.34901962, 0.21176471]],
 
               [[0.6       , 0.9647059 , 0.9607843 , ..., 1.        ,
                 0.92156863, 0.83137256],
                [0.4117647 , 0.9764706 , 0.95686275, ..., 0.9882353 ,
                 0.99607843, 0.95686275],
                [0.14509805, 0.99215686, 0.972549  , ..., 0.96862745,
                 0.96862745, 0.99215686],
                ...,
                [0.30588236, 0.2901961 , 0.44313726, ..., 0.4117647 ,
                 0.3882353 , 0.35686275],
                [0.3372549 , 0.32156864, 0.24705882, ..., 0.38431373,
                 0.27450982, 0.3764706 ],
                [0.34117648, 0.23529412, 0.3764706 , ..., 0.29411766,
                 0.39607844, 0.25490198]],
 
               [[0.5568628 , 0.9490196 , 0.9607843 , ..., 0.95686275,
                 0.8745098 , 0.78431374],
                [0.3529412 , 0.93333334, 0.9254902 , ..., 0.9607843 ,
                 0.96862745, 0.92941177],
                [0.07843138, 0.9372549 , 0.9254902 , ..., 0.9529412 ,
                 0.9529412 , 0.9764706 ],
                ...,
                [0.34117648, 0.3137255 , 0.4509804 , ..., 0.42745098,
                 0.40784314, 0.38039216],
                [0.3764706 , 0.34509805, 0.25882354, ..., 0.38431373,
                 0.27450982, 0.38039216],
                [0.3764706 , 0.2627451 , 0.3764706 , ..., 0.29411766,
                 0.39607844, 0.2627451 ]]], dtype=float32)           ,
        array([[[0.12156863, 0.12156863, 0.11764706, ..., 0.32156864,
                 0.32156864, 0.32156864],
                [0.1254902 , 0.11372549, 0.10196079, ..., 0.32156864,
                 0.32156864, 0.32156864],
                [0.10588235, 0.09019608, 0.08235294, ..., 0.32941177,
                 0.32941177, 0.32941177],
                ...,
                [0.16862746, 0.14117648, 0.15294118, ..., 0.07450981,
                 0.07450981, 0.05098039],
                [0.18431373, 0.15294118, 0.14117648, ..., 0.08235294,
                 0.09411765, 0.08627451],
                [0.18431373, 0.15686275, 0.14117648, ..., 0.03529412,
                 0.02745098, 0.00392157]],
 
               [[0.1254902 , 0.1254902 , 0.12156863, ..., 0.3254902 ,
                 0.3254902 , 0.3254902 ],
                [0.12941177, 0.11764706, 0.10588235, ..., 0.3254902 ,
                 0.3254902 , 0.3254902 ],
                [0.10980392, 0.09411765, 0.08627451, ..., 0.33333334,
                 0.33333334, 0.33333334],
                ...,
                [0.07058824, 0.05490196, 0.07450981, ..., 0.09411765,
                 0.09411765, 0.07058824],
                [0.08235294, 0.05490196, 0.0627451 , ..., 0.10196079,
                 0.11372549, 0.10588235],
                [0.0627451 , 0.05098039, 0.0627451 , ..., 0.05490196,
                 0.04705882, 0.02352941]],
 
               [[0.10196079, 0.10196079, 0.09803922, ..., 0.30588236,
                 0.30588236, 0.30588236],
                [0.10588235, 0.09411765, 0.08235294, ..., 0.30588236,
                 0.30588236, 0.30588236],
                [0.08627451, 0.07058824, 0.0627451 , ..., 0.3137255 ,
                 0.3137255 , 0.3137255 ],
                ...,
                [0.08235294, 0.0627451 , 0.07843138, ..., 0.07058824,
                 0.07058824, 0.04705882],
                [0.09411765, 0.06666667, 0.06666667, ..., 0.07843138,
                 0.09019608, 0.08235294],
                [0.08627451, 0.07058824, 0.07058824, ..., 0.03137255,
                 0.02352941, 0.        ]]], dtype=float32)           ,
        array([[[0.06666667, 0.06666667, 0.07450981, ..., 0.21568628,
                 0.21960784, 0.21960784],
                [0.07058824, 0.07058824, 0.07450981, ..., 0.21960784,
                 0.21960784, 0.22352941],
                [0.07058824, 0.07058824, 0.07450981, ..., 0.21960784,
                 0.21960784, 0.22352941],
                ...,
                [0.03529412, 0.03529412, 0.03529412, ..., 0.48235294,
                 0.4745098 , 0.45490196],
                [0.03529412, 0.03529412, 0.03529412, ..., 0.4745098 ,
                 0.46666667, 0.4509804 ],
                [0.03529412, 0.03529412, 0.03529412, ..., 0.4745098 ,
                 0.46666667, 0.45490196]],
 
               [[0.01176471, 0.01176471, 0.01176471, ..., 0.26666668,
                 0.27058825, 0.27058825],
                [0.01568628, 0.01568628, 0.01176471, ..., 0.27058825,
                 0.27058825, 0.27450982],
                [0.01568628, 0.01568628, 0.01176471, ..., 0.27058825,
                 0.27058825, 0.27450982],
                ...,
                [0.        , 0.        , 0.        , ..., 0.4862745 ,
                 0.47843137, 0.45882353],
                [0.        , 0.        , 0.        , ..., 0.47843137,
                 0.47058824, 0.45490196],
                [0.        , 0.        , 0.        , ..., 0.47843137,
                 0.47058824, 0.45882353]],
 
               [[0.00784314, 0.00784314, 0.01176471, ..., 0.32941177,
                 0.33333334, 0.33333334],
                [0.01176471, 0.01176471, 0.01176471, ..., 0.33333334,
                 0.33333334, 0.3372549 ],
                [0.01176471, 0.01176471, 0.01176471, ..., 0.33333334,
                 0.33333334, 0.3372549 ],
                ...,
                [0.00392157, 0.00392157, 0.00392157, ..., 0.45490196,
                 0.44705883, 0.42745098],
                [0.00392157, 0.00392157, 0.00392157, ..., 0.44705883,
                 0.4392157 , 0.42352942],
                [0.00392157, 0.00392157, 0.00392157, ..., 0.44705883,
                 0.4392157 , 0.42745098]]], dtype=float32)           ],
       dtype=object)}

from typing import Dict

import numpy as np
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights


class ObjectDetectionModel:
    def __init__(self):
        # Define the model loading and initialization code in `__init__`.
        # Ray Data calls it once per actor, so the weights load only once.
        self.weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT
        self.model = fasterrcnn_resnet50_fpn_v2(
            weights=self.weights,
            box_score_thresh=0.9,
        )
        if torch.cuda.is_available():
            # Move the model to GPU if it's available.
            self.model = self.model.cuda()
        self.model.eval()

    def __call__(self, input_batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Define the per-batch inference code in `__call__`.
        batch = [torch.from_numpy(image) for image in input_batch["transformed"]]
        if torch.cuda.is_available():
            # Move the data to GPU if it's available.
            batch = [image.cuda() for image in batch]
        # Gradients aren't needed for inference; skipping them saves memory and time.
        with torch.inference_mode():
            predictions = self.model(batch)
        return {
            "image": input_batch["image"],
            "labels": [pred["labels"].detach().cpu().numpy() for pred in predictions],
            "boxes": [pred["boxes"].detach().cpu().numpy() for pred in predictions],
        }
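The reason the model lives in a class is that Ray Data runs class-based UDFs as long-running actors. `__init__` runs once per actor, and `__call__` runs once per batch. The minimal, framework-free sketch below shows that lifecycle. `FakeDetector` and `ToyObjectDetectionModel` are hypothetical stand-ins for Faster R-CNN and `ObjectDetectionModel`, so the example runs without PyTorch or a GPU.

```python
from typing import Any, Dict, List


class FakeDetector:
    """Hypothetical stand-in for Faster R-CNN: returns one fixed box per image."""

    def __call__(self, images: List[Any]) -> List[Dict[str, list]]:
        return [{"labels": [1], "boxes": [[0.0, 0.0, 10.0, 10.0]]} for _ in images]


class ToyObjectDetectionModel:
    """Hypothetical toy version of ObjectDetectionModel with the same structure."""

    loads = 0  # Counts how many times the (normally expensive) model is loaded.

    def __init__(self) -> None:
        # Expensive setup: in Ray Data this runs once per actor.
        ToyObjectDetectionModel.loads += 1
        self.model = FakeDetector()

    def __call__(self, input_batch: Dict[str, list]) -> Dict[str, list]:
        # Cheap per-batch work that reuses the already-loaded model.
        predictions = self.model(input_batch["transformed"])
        return {
            "image": input_batch["image"],
            "labels": [pred["labels"] for pred in predictions],
            "boxes": [pred["boxes"] for pred in predictions],
        }


# Simulate one actor processing two batches.
udf = ToyObjectDetectionModel()
batches = [
    {"image": ["img0", "img1"], "transformed": ["t0", "t1"]},
    {"image": ["img2"], "transformed": ["t2"]},
]
outputs = [udf(batch) for batch in batches]
print(ToyObjectDetectionModel.loads)  # 1: the model was loaded only once
print([len(out["labels"]) for out in outputs])  # [2, 1]
```

A function-based UDF would have to reload the model on every call. That is why the heavy setup belongs in `__init__`.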

Then we use the map_batches API to apply the model to the whole dataset.

The first parameter of map and map_batches is the user-defined function (UDF), which can be either a function or a class. Function-based UDFs run as short-running Ray tasks, and class-based UDFs run as long-running Ray actors. For class-based UDFs, use the concurrency argument to specify the number of parallel actors. The batch_size argument specifies the number of images in each batch.
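The sketch below gives a rough mental model of how batch_size and concurrency interact. It splits rows into batches and hands them to a fixed pool of actors. This is a simplification, not Ray's scheduler: the round-robin assignment and the helpers `split_into_batches` and `assign_round_robin` are hypothetical. Ray Data's real dispatching is more dynamic and takes into account which actors are busy.

```python
from typing import Dict, List


def split_into_batches(rows: List[int], batch_size: int) -> List[List[int]]:
    """Hypothetical helper: chunk rows into consecutive batches of at most batch_size."""
    return [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]


def assign_round_robin(batches: List[List[int]], concurrency: int) -> Dict[int, List[List[int]]]:
    """Hypothetical helper: give batch k to actor k % concurrency."""
    assignment: Dict[int, List[List[int]]] = {actor: [] for actor in range(concurrency)}
    for k, batch in enumerate(batches):
        assignment[k % concurrency].append(batch)
    return assignment


rows = list(range(10))  # e.g. 10 images
batches = split_into_batches(rows, batch_size=4)
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

assignment = assign_round_robin(batches, concurrency=4)
idle = [actor for actor, work in assignment.items() if not work]
print(idle)  # [3]: with only 3 batches, one of the 4 actors has nothing to do
```

Parallelism is capped by the number of batches, not only by the number of actors. With a small dataset, a large batch_size can leave some actors idle.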

The num_gpus argument specifies the number of GPUs that each ObjectDetectionModel instance needs. The Ray scheduler can handle heterogeneous resource requirements to maximize resource utilization. In this case, the ObjectDetectionModel instances run on GPUs and preprocess_image runs on CPUs.

ds = ds.map_batches(
    ObjectDetectionModel,
    concurrency=4, # Use 4 GPUs. Change this number based on the number of GPUs in your cluster.
    batch_size=4, # Use the largest batch size that can fit in GPU memory.
    num_gpus=1,  # Specify 1 GPU per model replica. Remove this if you are doing CPU inference.
)
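With num_gpus=1, each actor reserves one whole GPU. All actors can only start if the cluster provides concurrency × num_gpus GPUs. The hypothetical helpers below (not part of Ray) do that arithmetic. Ray also accepts fractional GPU requests such as num_gpus=0.5, which lets two replicas share one GPU. This assumes both copies of the model and their batches fit in that GPU's memory.

```python
import math


def gpus_needed(concurrency: int, num_gpus_per_replica: float) -> float:
    """Hypothetical helper: GPUs the cluster must provide for all actors to start."""
    return concurrency * num_gpus_per_replica


def max_replicas(total_gpus: int, num_gpus_per_replica: float) -> int:
    """Hypothetical helper: how many model replicas fit on total_gpus GPUs."""
    if num_gpus_per_replica <= 0:
        raise ValueError("num_gpus_per_replica must be positive")
    return math.floor(total_gpus / num_gpus_per_replica)


print(gpus_needed(concurrency=4, num_gpus_per_replica=1))  # 4
print(max_replicas(total_gpus=4, num_gpus_per_replica=1))  # 4
print(max_replicas(total_gpus=4, num_gpus_per_replica=0.5))  # 8
```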

Verify and Save Results#

Then, let’s take a small batch and visualize the predictions to verify the inference results.

import torch
from torchvision.transforms.functional import convert_image_dtype, to_pil_image, to_tensor
from torchvision.utils import draw_bounding_boxes

batch = ds.take_batch(batch_size=2)
for image, labels, boxes in zip(batch["image"], batch["labels"], batch["boxes"]):
    # draw_bounding_boxes expects a uint8 CHW tensor, so convert the HWC NumPy image.
    image = convert_image_dtype(to_tensor(image), torch.uint8)
    # Map class indices to human-readable category names.
    labels = [weights.meta["categories"][i] for i in labels]
    boxes = torch.from_numpy(boxes)
    img = to_pil_image(draw_bounding_boxes(
        image,
        boxes,
        labels=labels,
        colors="red",
        width=4,
    ))
    display(img)
[2023-05-19 18:10:40] [Ray Data] INFO ray.data._internal.execution.streaming_executor.logfile::Executing DAG InputDataBuffer[Input] -> ActorPoolMapOperator[ReadImage->Map->MapBatches(ObjectDetectionModel)]
[2023-05-19 18:10:40] [Ray Data] INFO ray.data._internal.execution.streaming_executor.logfile::Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), locality_with_output=False, preserve_order=False, actor_locality_enabled=True, verbose_progress=False)
[2023-05-19 18:10:40] [Ray Data] INFO ray.data._internal.execution.streaming_executor.logfile::Tip: For detailed progress reporting, run `ray.data.DataContext.get_current().execution_options.verbose_progress = True`
[2023-05-19 18:10:40] [Ray Data] INFO ray.data._internal.execution.operators.actor_pool_map_operator.logfile::ReadImage->Map->MapBatches(ObjectDetectionModel): Waiting for 4 pool actors to start...
[2023-05-19 18:11:50] [Ray Data] INFO ray.data._internal.execution.streaming_executor.logfile::Shutting down <StreamingExecutor(Thread-26, started daemon 16076255232)>.
[2023-05-19 18:11:50] [Ray Data] WARNING ray.data._internal.execution.operators.actor_pool_map_operator.logfile::To ensure full parallelization across an actor pool of size 4, the specified batch size should be at most 3. Your configured batch size for this operator was 4.
[Output: two sample images with the predicted bounding boxes and labels drawn in red]
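Visual inspection only covers a couple of images. A cheap programmatic check can flag malformed boxes and tally the predicted categories. The sketch below assumes boxes are (x1, y1, x2, y2) pixel coordinates, the format torchvision's Faster R-CNN returns. The helper names, the sample predictions, and the three-entry `categories` list are hypothetical. In the notebook, you would use `weights.meta["categories"]` and the arrays from `ds.take_batch` instead.

```python
from collections import Counter
from typing import Dict, List, Sequence


def invalid_box_indices(boxes: Sequence[Sequence[float]], width: int, height: int) -> List[int]:
    """Return indices of boxes violating 0 <= x1 < x2 <= width and 0 <= y1 < y2 <= height."""
    bad = []
    for i, (x1, y1, x2, y2) in enumerate(boxes):
        if not (0 <= x1 < x2 <= width and 0 <= y1 < y2 <= height):
            bad.append(i)
    return bad


def count_categories(labels: Sequence[int], categories: Sequence[str]) -> Dict[str, int]:
    """Map class indices to names and count how often each name was predicted."""
    return dict(Counter(categories[i] for i in labels))


# Hypothetical predictions for one 500x375 image.
categories = ["__background__", "person", "bicycle"]
labels = [1, 1, 2]
boxes = [[10.0, 20.0, 110.0, 220.0], [50.0, 60.0, 40.0, 90.0], [0.0, 0.0, 500.0, 375.0]]

print(invalid_box_indices(boxes, width=500, height=375))  # [1]: x2 < x1
print(count_categories(labels, categories))  # {'person': 2, 'bicycle': 1}
```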

If the samples look good, we can proceed to save the results to external storage, such as S3 or local disk. See Ray Data Input/Output for all supported storage systems and file formats.

# The third slash makes /tmp an absolute path on the local filesystem.
ds.write_parquet("local:///tmp/inference_results")
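Each output row holds all detections for one image. A downstream consumer might expect one record per detection instead. In that case, a flattening step like the hypothetical `explode_detections` below could run before writing. It is a plain-Python sketch on a dict-of-lists batch shaped like the one `take_batch` returns, not a Ray Data API.

```python
from typing import Any, Dict, List


def explode_detections(batch: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Hypothetical helper: turn per-image predictions into one dict per detection."""
    records = []
    for image_idx, (labels, boxes) in enumerate(zip(batch["labels"], batch["boxes"])):
        for label, (x1, y1, x2, y2) in zip(labels, boxes):
            records.append(
                {"image_idx": image_idx, "label": label, "x1": x1, "y1": y1, "x2": x2, "y2": y2}
            )
    return records


batch = {
    "labels": [[1, 3], []],  # image 0 has two detections, image 1 has none
    "boxes": [[[0.0, 0.0, 5.0, 5.0], [1.0, 1.0, 4.0, 8.0]], []],
}
records = explode_detections(batch)
print(len(records))  # 2
print(records[1])  # {'image_idx': 0, 'label': 3, 'x1': 1.0, 'y1': 1.0, 'x2': 4.0, 'y2': 8.0}
```

Images with no detections produce no records. If you need to account for them, keep a stable image identifier rather than the in-batch position used here.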