ray.data.preprocessors.TorchVisionPreprocessor
- class ray.data.preprocessors.TorchVisionPreprocessor(columns: List[str], transform: Callable[[np.ndarray | torch.Tensor], torch.Tensor], output_columns: List[str] | None = None, batched: bool = False)
Bases: Preprocessor
Apply a TorchVision transform to image columns.
Examples
Torch models expect inputs of shape \((B, C, H, W)\) in the range \([0.0, 1.0]\). To convert images to this format, add ToTensor to your preprocessing pipeline.

```python
from torchvision import transforms

import ray
from ray.data.preprocessors import TorchVisionPreprocessor

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])
preprocessor = TorchVisionPreprocessor(["image"], transform=transform)

dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
dataset = preprocessor.transform(dataset)
```
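To make the target format concrete, the following is a minimal NumPy sketch of the per-image conversion that ToTensor performs on a uint8 image: the channel axis moves to the front and pixel values are scaled into \([0.0, 1.0]\). The helper name `to_chw_unit_range` is hypothetical and used only for illustration; it is not part of Ray or TorchVision.

```python
import numpy as np


def to_chw_unit_range(image: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in: (H, W, C) uint8 -> (C, H, W) float32 in [0, 1]."""
    # Move the channel axis to the front: (H, W, C) -> (C, H, W).
    chw = np.transpose(image, (2, 0, 1))
    # Scale [0, 255] integer pixel values into [0.0, 1.0].
    return chw.astype(np.float32) / 255.0


# A small random image stands in for real data.
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)

converted = to_chw_unit_range(image)
# Stacking converted images adds the leading batch dimension B.
batch = np.stack([converted, converted])

print(converted.shape, converted.dtype)  # (3, 4, 6) float32
print(batch.shape)  # (2, 3, 4, 6)
```

Stacking the converted images yields the \((B, C, H, W)\) layout described above.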
For better performance, set batched to True and replace ToTensor with a batch-supporting Lambda.

```python
import numpy as np
import torch
from torchvision import transforms

import ray
from ray.data.preprocessors import TorchVisionPreprocessor

def to_tensor(batch: np.ndarray) -> torch.Tensor:
    tensor = torch.as_tensor(batch, dtype=torch.float)
    # (B, H, W, C) -> (B, C, H, W)
    tensor = tensor.permute(0, 3, 1, 2).contiguous()
    # [0., 255.] -> [0., 1.]
    tensor = tensor.div(255)
    return tensor

transform = transforms.Compose([
    transforms.Lambda(to_tensor),
    transforms.Resize((224, 224))
])
preprocessor = TorchVisionPreprocessor(["image"], transform=transform, batched=True)

dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
dataset = preprocessor.transform(dataset)
```
- Parameters:
  columns – The columns to apply the TorchVision transform to.
  transform – The TorchVision transform you want to apply. This transform should accept a np.ndarray or torch.Tensor as input and return a torch.Tensor as output.
  output_columns – The output name for each input column. If not specified, this defaults to the names in columns.
  batched – If True, apply transform to batches of shape \((B, H, W, C)\). Otherwise, apply transform to individual images.
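The difference between the two batched modes can be illustrated with a small sketch. The helper `apply_transform` below is hypothetical and is not Ray's implementation; it only mirrors the documented behavior, assuming the transform is called once on the whole \((B, H, W, C)\) array when batched is True and once per \((H, W, C)\) image otherwise.

```python
from typing import Callable, List, Tuple

import numpy as np


def apply_transform(
    images: np.ndarray,
    transform: Callable[[np.ndarray], np.ndarray],
    batched: bool,
) -> np.ndarray:
    """Hypothetical helper mirroring the documented `batched` semantics."""
    if batched:
        # One call on the whole (B, H, W, C) batch.
        return transform(images)
    # One call per (H, W, C) image, then re-stack the results.
    return np.stack([transform(image) for image in images])


seen_shapes: List[Tuple[int, ...]] = []


def record_and_scale(array: np.ndarray) -> np.ndarray:
    # Record the shape the transform receives, then scale to [0, 1].
    seen_shapes.append(array.shape)
    return array.astype(np.float32) / 255.0


images = np.zeros((4, 8, 8, 3), dtype=np.uint8)

apply_transform(images, record_and_scale, batched=True)
batched_shapes = list(seen_shapes)

seen_shapes.clear()
apply_transform(images, record_and_scale, batched=False)
per_image_shapes = list(seen_shapes)

print(batched_shapes)  # [(4, 8, 8, 3)]
print(len(per_image_shapes), per_image_shapes[0])  # 4 (8, 8, 3)
```

In this sketch the batched path makes a single call regardless of batch size, which is why a transform passed with batched set to True must itself handle the leading batch dimension.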
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
Methods
deserialize – Load the original preprocessor serialized via self.serialize().
fit – Fit this Preprocessor to the Dataset.
fit_transform – Fit this Preprocessor to the Dataset and then transform the Dataset.
serialize – Return this preprocessor serialized as a string.
transform – Transform the given dataset.
transform_batch – Transform a single batch of data.