ray.data.preprocessors.TorchVisionPreprocessor
- class ray.data.preprocessors.TorchVisionPreprocessor(columns: List[str], transform: Callable[[np.ndarray | torch.Tensor], torch.Tensor], output_columns: List[str] | None = None, batched: bool = False)
Bases: `Preprocessor`

Apply a TorchVision transform to image columns.

Examples

Torch models expect inputs of shape \((B, C, H, W)\) in the range \([0.0, 1.0]\). To convert images to this format, add `ToTensor` to your preprocessing pipeline.

```python
from torchvision import transforms

import ray
from ray.data.preprocessors import TorchVisionPreprocessor

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
])
preprocessor = TorchVisionPreprocessor(["image"], transform=transform)

dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
dataset = preprocessor.transform(dataset)
```

For better performance, set `batched` to `True` and replace `ToTensor` with a batch-supporting `Lambda`.

```python
import numpy as np
import torch
from torchvision import transforms

import ray
from ray.data.preprocessors import TorchVisionPreprocessor

def to_tensor(batch: np.ndarray) -> torch.Tensor:
    tensor = torch.as_tensor(batch, dtype=torch.float)
    # (B, H, W, C) -> (B, C, H, W)
    tensor = tensor.permute(0, 3, 1, 2).contiguous()
    # [0., 255.] -> [0., 1.]
    tensor = tensor.div(255)
    return tensor

transform = transforms.Compose([
    transforms.Lambda(to_tensor),
    transforms.Resize((224, 224)),
])
preprocessor = TorchVisionPreprocessor(["image"], transform=transform, batched=True)

dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
dataset = preprocessor.transform(dataset)
```

Parameters:
- columns – The columns to apply the TorchVision transform to. 
- transform – The TorchVision transform you want to apply. This transform should accept a `np.ndarray` or `torch.Tensor` as input and return a `torch.Tensor` as output.
- output_columns – The output name for each input column. If not specified, this defaults to the names in `columns`.
- batched – If `True`, apply `transform` to batches of shape \((B, H, W, C)\). Otherwise, apply `transform` to individual images.
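
The layout change that a batched transform has to perform can be checked without Ray or Torch. The following is a minimal NumPy-only sketch of the same conversion as the `to_tensor` example, assuming the input batch holds `uint8` pixel values; the helper name `to_channels_first_unit_range` and the sample batch are hypothetical and are not part of the Ray API.

```python
import numpy as np


def to_channels_first_unit_range(batch: np.ndarray) -> np.ndarray:
    """Convert a (B, H, W, C) uint8 batch to float32 (B, C, H, W) in [0, 1]."""
    if batch.ndim != 4:
        raise ValueError(f"expected a 4-D (B, H, W, C) batch, got shape {batch.shape}")
    # (B, H, W, C) -> (B, C, H, W)
    channels_first = np.transpose(batch, (0, 3, 1, 2))
    # Copy into contiguous float32 memory, then scale [0, 255] -> [0.0, 1.0]
    return np.ascontiguousarray(channels_first, dtype=np.float32) / 255.0


# Hypothetical batch: two 4x6 RGB images with random pixel values.
rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(2, 4, 6, 3), dtype=np.uint8)

converted = to_channels_first_unit_range(batch)
print(converted.shape)  # (2, 3, 4, 6)
```

With `batched=False`, the transform instead receives one image at a time, so the batch axis is absent and the equivalent transpose would be over three axes rather than four.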
 
PublicAPI (alpha): This API is in alpha and may change before becoming stable.

Methods

- deserialize – Load the original preprocessor serialized via `self.serialize()`.
- fit – Fit this Preprocessor to the Dataset.
- fit_transform – Fit this Preprocessor to the Dataset and then transform the Dataset.
- serialize – Return this preprocessor serialized as a string.
- transform – Transform the given dataset.
- transform_batch – Transform a single batch of data.