ray.data.preprocessors.TorchVisionPreprocessor
- class ray.data.preprocessors.TorchVisionPreprocessor(columns: List[str], transform: Callable[[np.ndarray | torch.Tensor], torch.Tensor], output_columns: List[str] | None = None, batched: bool = False)
Bases: SerializablePreprocessorBase
Apply a TorchVision transform to image columns.
Examples
Torch models expect inputs of shape \((B, C, H, W)\) in the range \([0.0, 1.0]\). To convert images to this format, add ToTensor to your preprocessing pipeline.

    from torchvision import transforms

    import ray
    from ray.data.preprocessors import TorchVisionPreprocessor

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
    ])
    preprocessor = TorchVisionPreprocessor(["image"], transform=transform)

    dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
    dataset = preprocessor.transform(dataset)
For better performance, set batched to True and replace ToTensor with a batch-supporting Lambda.

    import numpy as np
    import torch

    def to_tensor(batch: np.ndarray) -> torch.Tensor:
        tensor = torch.as_tensor(batch, dtype=torch.float)
        # (B, H, W, C) -> (B, C, H, W)
        tensor = tensor.permute(0, 3, 1, 2).contiguous()
        # [0., 255.] -> [0., 1.]
        tensor = tensor.div(255)
        return tensor

    transform = transforms.Compose([
        transforms.Lambda(to_tensor),
        transforms.Resize((224, 224))
    ])
    preprocessor = TorchVisionPreprocessor(["image"], transform=transform, batched=True)

    dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
    dataset = preprocessor.transform(dataset)
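To sanity-check the layout and scaling steps that to_tensor performs, the same two operations can be reproduced on a synthetic batch. The following is a minimal NumPy-only sketch, not part of Ray or TorchVision. The helper name to_bchw_unit_range and the batch dimensions are hypothetical and chosen only for illustration.

```python
import numpy as np


def to_bchw_unit_range(batch: np.ndarray) -> np.ndarray:
    """Hypothetical helper mirroring to_tensor with NumPy instead of torch."""
    # uint8 -> float32, matching dtype=torch.float in to_tensor
    scaled = batch.astype(np.float32)
    # (B, H, W, C) -> (B, C, H, W)
    scaled = np.transpose(scaled, (0, 3, 1, 2))
    # [0., 255.] -> [0., 1.]
    return scaled / 255.0


# Synthetic batch of 4 RGB images, 32 pixels high and 48 pixels wide (assumed sizes).
rng = np.random.default_rng(seed=0)
batch = rng.integers(0, 256, size=(4, 32, 48, 3), dtype=np.uint8)

converted = to_bchw_unit_range(batch)
print(converted.shape)  # (4, 3, 32, 48)
```

The channel axis moves from last to second position, and every value ends up in the closed interval from 0 to 1, which is the input format described above.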
- Parameters:
columns – The columns to apply the TorchVision transform to.
transform – The TorchVision transform to apply. This transform should accept a np.ndarray or torch.Tensor as input and return a torch.Tensor as output.
output_columns – The output name for each input column. If not specified, this defaults to the names in columns.
batched – If True, apply transform to batches of shape \((B, H, W, C)\). Otherwise, apply transform to individual images.
PublicAPI (alpha): This API is in alpha and may change before becoming stable.
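The effect of the batched parameter on what transform receives can be pictured with a small stand-in. This is a hypothetical sketch based only on the parameter description above, not Ray's implementation. The names apply_transform and record_shape are invented for illustration, and NumPy arrays stand in for the tensors a real transform would return.

```python
from typing import Callable

import numpy as np


def apply_transform(
    images: np.ndarray,
    transform: Callable[[np.ndarray], np.ndarray],
    batched: bool = False,
) -> np.ndarray:
    """Hypothetical illustration of the batched flag; not Ray's implementation."""
    if batched:
        # One call: transform receives the whole (B, H, W, C) array.
        return transform(images)
    # B calls: transform receives one (H, W, C) image at a time.
    return np.stack([transform(image) for image in images])


seen_shapes = []


def record_shape(array: np.ndarray) -> np.ndarray:
    # Identity transform that records the shape of whatever it is given.
    seen_shapes.append(array.shape)
    return array


images = np.zeros((2, 8, 8, 3), dtype=np.uint8)

apply_transform(images, record_shape, batched=False)
apply_transform(images, record_shape, batched=True)
print(seen_shapes)  # [(8, 8, 3), (8, 8, 3), (2, 8, 8, 3)]
```

With batched left at False the transform is called once per image, so it must handle \((H, W, C)\) inputs. With batched set to True it is called once with the full batch, so it must handle the leading batch dimension itself.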
Methods
Deserialize a preprocessor from serialized data.
Fit this Preprocessor to the Dataset.
Fit this Preprocessor to the Dataset and then transform the Dataset.
Get the preprocessor class identifier for this preprocessor class.
Get the version number for this preprocessor class.
Serialize this preprocessor to a string or bytes.
Set the preprocessor class identifier for this preprocessor class.
Set the version number for this preprocessor class.
Transform the given dataset.
Transform a single batch of data.
Attributes