ray.data.preprocessors.TorchVisionPreprocessor

class ray.data.preprocessors.TorchVisionPreprocessor(columns: List[str], transform: Callable[[np.ndarray | torch.Tensor], torch.Tensor], output_columns: List[str] | None = None, batched: bool = False)

Bases: Preprocessor

Apply a TorchVision transform to image columns.

Examples

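Many Torch models, including TorchVision's pretrained models, expect inputs of shape \((B, C, H, W)\) with values in the range \([0.0, 1.0]\). To convert images to this format, add ToTensor to your preprocessing pipeline. ToTensor converts each uint8 image of shape \((H, W, C)\) into a float tensor of shape \((C, H, W)\) scaled to \([0.0, 1.0]\).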

from torchvision import transforms

import ray
from ray.data.preprocessors import TorchVisionPreprocessor

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])
preprocessor = TorchVisionPreprocessor(["image"], transform=transform)

dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
dataset = preprocessor.transform(dataset)
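To make the ToTensor step concrete, here is a NumPy-only sketch of the per-image conversion it performs on a uint8 \((H, W, C)\) image: move the channel axis first and scale \([0, 255]\) to \([0.0, 1.0]\). The helper to_chw_float is hypothetical and only mirrors ToTensor for uint8 input; it is not part of Ray or TorchVision.

```python
import numpy as np


def to_chw_float(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: mimic ToTensor on a uint8 (H, W, C) image with NumPy."""
    if image.dtype != np.uint8:
        # ToTensor only rescales uint8 input; this sketch handles that case only.
        raise ValueError("this sketch only handles uint8 images")
    # (H, W, C) -> (C, H, W)
    chw = np.transpose(image, (2, 0, 1))
    # [0, 255] -> [0.0, 1.0]
    return chw.astype(np.float32) / 255.0


image = np.full((4, 6, 3), 255, dtype=np.uint8)  # a tiny all-white RGB image
out = to_chw_float(image)
print(out.shape, out.dtype, out.max())
```

The batch dimension \(B\) is not added here; it appears only once several converted images are stacked together.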

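For better performance, set batched=True and replace ToTensor, which only accepts individual images, with a Lambda that converts a whole batch at once. A batched transform receives a single array of shape \((B, H, W, C)\), so every image in a batch must have the same size and number of channels.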

import numpy as np
import torch

def to_tensor(batch: np.ndarray) -> torch.Tensor:
    tensor = torch.as_tensor(batch, dtype=torch.float)
    # (B, H, W, C) -> (B, C, H, W)
    tensor = tensor.permute(0, 3, 1, 2).contiguous()
    # [0., 255.] -> [0., 1.]
    tensor = tensor.div(255)
    return tensor

transform = transforms.Compose([
    transforms.Lambda(to_tensor),
    transforms.Resize((224, 224))
])
preprocessor = TorchVisionPreprocessor(["image"], transform=transform, batched=True)

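dataset = ray.data.read_images(
    "s3://anonymous@air-example-data-2/imagenet-sample-images",
    # Batched transforms receive one (B, H, W, C) array, so decode every
    # image to the same size and to 3-channel RGB.
    size=(256, 256),
    mode="RGB",
)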
dataset = preprocessor.transform(dataset)
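The batched to_tensor above should produce the same values as converting each image separately and stacking the results; it simply does the work in one vectorized call instead of a Python-level loop over images. The NumPy sketch below checks that equivalence on random data, with np.transpose(..., (0, 3, 1, 2)) standing in for tensor.permute(0, 3, 1, 2). The helpers per_image and batched_convert are hypothetical, and the expectation that the vectorized form is faster is a general assumption, not a measured result for Ray.

```python
import numpy as np


def per_image(image: np.ndarray) -> np.ndarray:
    # Hypothetical: (H, W, C) uint8 -> (C, H, W) float32 in [0.0, 1.0]
    return np.transpose(image, (2, 0, 1)).astype(np.float32) / 255.0


def batched_convert(batch: np.ndarray) -> np.ndarray:
    # Hypothetical: (B, H, W, C) uint8 -> (B, C, H, W) float32 in [0.0, 1.0]
    return np.transpose(batch, (0, 3, 1, 2)).astype(np.float32) / 255.0


rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(8, 5, 7, 3), dtype=np.uint8)

# Loop over images, then stack: what a per-image transform effectively does.
stacked = np.stack([per_image(img) for img in batch])
# One call on the whole batch: what a batched transform does.
vectorized = batched_convert(batch)

print(vectorized.shape, np.allclose(stacked, vectorized))
```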

Parameters:
  • columns – The columns to apply the TorchVision transform to.

  • transform – The TorchVision transform to apply. It should accept an np.ndarray or torch.Tensor as input and return a torch.Tensor.

  • output_columns – The output name for each input column, matched by position. If not specified, this defaults to columns, so each transformed column replaces its input column.

  • batched – If True, apply transform once per batch, passing an array of shape \((B, H, W, C)\). Otherwise, apply transform to each image individually.
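The parameters above interact roughly as follows. This is a hypothetical, NumPy-only sketch of the documented semantics, not Ray's implementation: apply_to_columns is an illustrative name, batches are assumed to be dicts mapping column names to arrays, and every per-image output is assumed to have the same shape so the results can be restacked.

```python
from typing import Callable, Dict, List, Optional

import numpy as np


def apply_to_columns(
    batch: Dict[str, np.ndarray],
    columns: List[str],
    transform: Callable[[np.ndarray], np.ndarray],
    output_columns: Optional[List[str]] = None,
    batched: bool = False,
) -> Dict[str, np.ndarray]:
    """Hypothetical illustration of columns/output_columns/batched; not Ray code."""
    # Default: write each result back to its input column.
    if output_columns is None:
        output_columns = columns
    if len(output_columns) != len(columns):
        raise ValueError("columns and output_columns must have the same length")

    result = dict(batch)  # shallow copy; untouched columns pass through
    for in_col, out_col in zip(columns, output_columns):
        images = batch[in_col]  # shape (B, H, W, C)
        if batched:
            # One call on the whole (B, H, W, C) array.
            result[out_col] = transform(images)
        else:
            # One call per (H, W, C) image, then restack into a batch.
            result[out_col] = np.stack([transform(image) for image in images])
    return result


def scale(x: np.ndarray) -> np.ndarray:
    # Works on a single image or a whole batch: [0, 255] -> [0.0, 1.0].
    return x.astype(np.float32) / 255.0


batch = {"image": np.full((2, 4, 4, 3), 51, dtype=np.uint8)}

in_place = apply_to_columns(batch, ["image"], scale)
renamed = apply_to_columns(
    batch, ["image"], scale, output_columns=["image_float"], batched=True
)
print(sorted(in_place), sorted(renamed))
```

With the default output_columns the input column is overwritten; with an explicit name the original column is kept alongside the new one.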

PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods

  • deserialize – Load the original preprocessor serialized via self.serialize().

  • fit – Fit this Preprocessor to the Dataset.

  • fit_transform – Fit this Preprocessor to the Dataset and then transform the Dataset.

  • serialize – Return this preprocessor serialized as a string.

  • transform – Transform the given dataset.

  • transform_batch – Transform a single batch of data.