ray.data.preprocessors.TorchVisionPreprocessor
- class ray.data.preprocessors.TorchVisionPreprocessor(columns: List[str], transform: Callable[[np.ndarray | torch.Tensor], torch.Tensor], output_columns: List[str] | None = None, batched: bool = False)
Bases: Preprocessor

Apply a TorchVision transform to image columns.
Examples
Torch models expect inputs of shape \((B, C, H, W)\) in the range \([0.0, 1.0]\). To convert images to this format, add ToTensor to your preprocessing pipeline.

    from torchvision import transforms

    import ray
    from ray.data.preprocessors import TorchVisionPreprocessor

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
    ])
    preprocessor = TorchVisionPreprocessor(["image"], transform=transform)

    dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
    dataset = preprocessor.transform(dataset)
For better performance, set batched to True and replace ToTensor with a batch-supporting Lambda.

    import numpy as np
    import torch
    from torchvision import transforms

    import ray
    from ray.data.preprocessors import TorchVisionPreprocessor

    def to_tensor(batch: np.ndarray) -> torch.Tensor:
        tensor = torch.as_tensor(batch, dtype=torch.float)
        # (B, H, W, C) -> (B, C, H, W)
        tensor = tensor.permute(0, 3, 1, 2).contiguous()
        # [0., 255.] -> [0., 1.]
        tensor = tensor.div(255)
        return tensor

    transform = transforms.Compose([
        transforms.Lambda(to_tensor),
        transforms.Resize((224, 224)),
    ])
    preprocessor = TorchVisionPreprocessor(["image"], transform=transform, batched=True)

    dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
    dataset = preprocessor.transform(dataset)
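The layout and range conversion that to_tensor performs can be checked on a synthetic batch without loading any images. The following is a minimal NumPy-only sketch, not part of Ray or TorchVision: the helper name to_chw_unit_range and the batch dimensions are assumptions chosen for illustration, and the input is assumed to be uint8 images in \((B, H, W, C)\) layout.

```python
import numpy as np


def to_chw_unit_range(batch: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy stand-in for the `to_tensor` function above."""
    # (B, H, W, C) -> (B, C, H, W)
    chw = np.transpose(batch, (0, 3, 1, 2))
    # [0, 255] -> [0.0, 1.0], stored as contiguous float32
    return np.ascontiguousarray(chw, dtype=np.float32) / 255.0


# Synthetic batch: 4 RGB images of 32x48 pixels (sizes are arbitrary).
rng = np.random.default_rng(seed=0)
batch = rng.integers(0, 256, size=(4, 32, 48, 3), dtype=np.uint8)

converted = to_chw_unit_range(batch)
print(converted.shape)  # (4, 3, 32, 48)
```

A check like this is a cheap way to confirm the channel axis ends up in position 1 before the function is wrapped in a Lambda and passed to the preprocessor.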
- Parameters:
columns – The columns to apply the TorchVision transform to.
transform – The TorchVision transform you want to apply. This transform should accept a np.ndarray or torch.Tensor as input and return a torch.Tensor as output.
output_columns – The output name for each input column. If not specified, this defaults to the input column names given in columns.
batched – If True, apply transform to batches of shape \((B, H, W, C)\). Otherwise, apply transform to individual images.
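How columns, output_columns, and batched interact can be illustrated with a small stand-alone sketch. This is not Ray's implementation: apply_to_columns and scale are hypothetical helpers, the batch is modeled as a plain dict of NumPy arrays, and both the length check and the handling of untouched columns are assumptions made for the example.

```python
from typing import Callable, Dict, List, Optional

import numpy as np


def apply_to_columns(
    batch: Dict[str, np.ndarray],
    columns: List[str],
    transform: Callable[[np.ndarray], np.ndarray],
    output_columns: Optional[List[str]] = None,
    batched: bool = False,
) -> Dict[str, np.ndarray]:
    """Hypothetical helper that mimics the documented parameter semantics."""
    # Default: write results back under the input column names.
    if output_columns is None:
        output_columns = columns
    # Assumption: each input column needs exactly one output name.
    if len(columns) != len(output_columns):
        raise ValueError("columns and output_columns must have the same length")

    # Assumption: untouched columns are carried over unchanged.
    result = dict(batch)
    for src, dst in zip(columns, output_columns):
        images = batch[src]
        if batched:
            # One call on the whole (B, H, W, C) array.
            result[dst] = transform(images)
        else:
            # One call per (H, W, C) image, then restack into a batch.
            result[dst] = np.stack([transform(image) for image in images])
    return result


def scale(x: np.ndarray) -> np.ndarray:
    """Hypothetical transform: map uint8 pixel values to [0.0, 1.0]."""
    return x.astype(np.float32) / 255.0


batch = {"image": np.full((2, 4, 4, 3), 255, dtype=np.uint8)}
out = apply_to_columns(batch, ["image"], scale, output_columns=["scaled"])
print(sorted(out))  # ['image', 'scaled']
```

With batched=True the transform is called once per batch instead of once per image, which is where the performance benefit mentioned in the examples would come from.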
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
Methods
deserialize – Load the original preprocessor serialized via self.serialize().
fit – Fit this Preprocessor to the Dataset.
fit_transform – Fit this Preprocessor to the Dataset and then transform the Dataset.
serialize – Return this preprocessor serialized as a string.
transform – Transform the given dataset.
transform_batch – Transform a single batch of data.