ray.data.preprocessors.TorchVisionPreprocessor
- class ray.data.preprocessors.TorchVisionPreprocessor(columns: List[str], transform: Callable[[np.ndarray], torch.Tensor], batched: bool = False)
Bases: ray.data.preprocessor.Preprocessor
Apply a TorchVision transform to image columns.
Examples
>>> import ray
>>> dataset = ray.data.read_images("s3://[email protected]/imagenet-sample-images")
>>> dataset
Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(..., 3), dtype=float)})
TorchVisionPreprocessor passes ndarrays to your transform. To convert ndarrays to Torch tensors, add ToTensor to your pipeline.

>>> from torchvision import transforms
>>> from ray.data.preprocessors import TorchVisionPreprocessor
>>> transform = transforms.Compose([
...     transforms.ToTensor(),
...     transforms.Resize((224, 224)),
... ])
>>> preprocessor = TorchVisionPreprocessor(["image"], transform=transform)
>>> preprocessor.transform(dataset)
Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(3, 224, 224), dtype=float)})
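To make the shape change in the schema above easier to follow, here is a minimal sketch of the layout and scaling conversion that ToTensor performs on a uint8 image. It uses plain NumPy as a stand-in rather than torchvision itself, and the helper name to_chw_float is hypothetical.

```python
import numpy as np


def to_chw_float(image: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy stand-in for ToTensor's effect on a uint8 image.

    Moves the channel axis first, (H, W, C) -> (C, H, W), and rescales
    pixel values from [0, 255] to [0.0, 1.0].
    """
    chw = np.transpose(image, (2, 0, 1))
    return np.ascontiguousarray(chw, dtype=np.float32) / 255.0


# A synthetic all-white image with height 32, width 48, and 3 channels.
image = np.full((32, 48, 3), 255, dtype=np.uint8)
converted = to_chw_float(image)
print(converted.shape, converted.dtype)
```

The real ToTensor returns a torch.Tensor rather than an ndarray; the sketch only illustrates why the image column goes from a channels-last shape to (3, 224, 224) once Resize is also applied.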
For better performance, set batched to True and replace ToTensor with a batch-supporting Lambda.

>>> import numpy as np
>>> import torch
>>> def to_tensor(batch: np.ndarray) -> torch.Tensor:
...     tensor = torch.as_tensor(batch, dtype=torch.float)
...     # (B, H, W, C) -> (B, C, H, W)
...     tensor = tensor.permute(0, 3, 1, 2).contiguous()
...     # [0., 255.] -> [0., 1.]
...     tensor = tensor.div(255)
...     return tensor
>>> transform = transforms.Compose([
...     transforms.Lambda(to_tensor),
...     transforms.Resize((224, 224))
... ])
>>> preprocessor = TorchVisionPreprocessor(
...     ["image"], transform=transform, batched=True
... )
>>> preprocessor.transform(dataset)
Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(3, 224, 224), dtype=float)})
- Parameters
columns – The columns to apply the TorchVision transform to.
transform – The TorchVision transform you want to apply. This transform should accept an np.ndarray as input and return a torch.Tensor as output.
batched – If True, apply transform to batches of shape \((B, H, W, C)\). Otherwise, apply transform to individual images.
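The difference between the two batched modes can be sketched with a small NumPy-only example. This is not Ray's implementation: apply_transform and scale are hypothetical helpers that only mimic the calling convention described above, where the transform receives either one image at a time or the whole (B, H, W, C) batch.

```python
import numpy as np


def apply_transform(images: np.ndarray, transform, batched: bool) -> np.ndarray:
    """Hypothetical helper mimicking the two calling conventions.

    batched=True:  transform is called once with the full (B, H, W, C) array.
    batched=False: transform is called once per (H, W, C) image.
    """
    if batched:
        return np.asarray(transform(images))
    return np.stack([np.asarray(transform(image)) for image in images])


def scale(pixels: np.ndarray) -> np.ndarray:
    # Works on a single image or a whole batch because it is elementwise.
    return pixels.astype(np.float32) / 255.0


# Synthetic batch: 4 images of 8x8 pixels with 3 channels.
rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(4, 8, 8, 3), dtype=np.uint8)

per_image = apply_transform(batch, scale, batched=False)
whole_batch = apply_transform(batch, scale, batched=True)
print(per_image.shape, whole_batch.shape)
```

Both paths give the same values here because scale is elementwise. A transform that assumes a fixed number of dimensions, such as a permute over four axes, only works in the mode it was written for.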
PublicAPI (alpha): This API is in alpha and may change before becoming stable.