ray.data.preprocessors.TorchVisionPreprocessor

class ray.data.preprocessors.TorchVisionPreprocessor(columns: List[str], transform: Callable[[np.ndarray], torch.Tensor], batched: bool = False)

Bases: ray.data.preprocessor.Preprocessor

Apply a TorchVision transform to image columns.

Examples

>>> import ray
>>> dataset = ray.data.read_images("s3://[email protected]/imagenet-sample-images")
>>> dataset  
Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(..., 3), dtype=float)})

TorchVisionPreprocessor passes ndarrays to your transform. To convert ndarrays to Torch tensors, add ToTensor to your pipeline.

>>> from torchvision import transforms
>>> from ray.data.preprocessors import TorchVisionPreprocessor
>>> transform = transforms.Compose([
...     transforms.ToTensor(),
...     transforms.Resize((224, 224)),
... ])
>>> preprocessor = TorchVisionPreprocessor(["image"], transform=transform)
>>> preprocessor.transform(dataset)  
Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(3, 224, 224), dtype=float)})
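
The schema above changes from channels-last to channels-first because ToTensor, given a uint8 ndarray of shape (H, W, C), returns a float tensor of shape (C, H, W) scaled to [0.0, 1.0]. The following is a minimal NumPy-only sketch of that layout and scaling step, not TorchVision's implementation; the helper name hwc_to_chw is hypothetical and appears in neither Ray nor TorchVision.

```python
import numpy as np


def hwc_to_chw(image: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy stand-in for the layout/scaling ToTensor applies to a uint8 image."""
    if image.ndim != 3:
        raise ValueError(f"expected an (H, W, C) array, got shape {image.shape}")
    # (H, W, C) -> (C, H, W)
    chw = np.transpose(image, (2, 0, 1))
    # [0, 255] -> [0.0, 1.0]
    return np.ascontiguousarray(chw, dtype=np.float32) / 255.0


# A random 32x48 RGB image stands in for one row of the "image" column.
image = np.random.default_rng(0).integers(0, 256, size=(32, 48, 3), dtype=np.uint8)
converted = hwc_to_chw(image)
print(converted.shape, converted.dtype)
```

The sketch covers only the conversion step; resizing to (224, 224) is left to transforms.Resize in the pipeline.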

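For better performance, set batched to True so that the transform runs once per batch rather than once per image. ToTensor accepts only a single image, so replace it with a Lambda that converts a whole batch.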

>>> import numpy as np
>>> import torch
>>> def to_tensor(batch: np.ndarray) -> torch.Tensor:
...     tensor = torch.as_tensor(batch, dtype=torch.float)
...     # (B, H, W, C) -> (B, C, H, W)
...     tensor = tensor.permute(0, 3, 1, 2).contiguous()
...     # [0., 255.] -> [0., 1.]
...     tensor = tensor.div(255)
...     return tensor
>>> transform = transforms.Compose([
...     transforms.Lambda(to_tensor),
...     transforms.Resize((224, 224))
... ])
>>> preprocessor = TorchVisionPreprocessor(
...     ["image"], transform=transform, batched=True
... )
>>> preprocessor.transform(dataset)  
Dataset(num_blocks=..., num_rows=..., schema={image: ArrowTensorType(shape=(3, 224, 224), dtype=float)})

Parameters
  • columns – The columns to apply the TorchVision transform to.

  • transform – The TorchVision transform you want to apply. This transform should accept an np.ndarray as input and return a torch.Tensor as output.

  • batched – If True, apply transform to batches of shape (B, H, W, C). Otherwise, apply transform to individual images.
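
To make the effect of batched concrete, here is a rough sketch of the two call patterns the parameter selects between. It is an illustration under stated assumptions, not Ray's implementation: apply_transform and scale are hypothetical helpers, and NumPy arrays stand in for Torch tensors so the example runs without TorchVision.

```python
from typing import Callable

import numpy as np


def apply_transform(
    images: np.ndarray,
    fn: Callable[[np.ndarray], np.ndarray],
    batched: bool = False,
) -> np.ndarray:
    """Hypothetical dispatcher mirroring the documented meaning of `batched`."""
    if batched:
        # One call on the whole (B, H, W, C) array.
        return np.asarray(fn(images))
    # One call per (H, W, C) image, then stack the results back into a batch.
    return np.stack([np.asarray(fn(image)) for image in images])


def scale(array: np.ndarray) -> np.ndarray:
    """Hypothetical transform that works on any rank: [0, 255] -> [0.0, 1.0]."""
    return array.astype(np.float32) / 255.0


batch = np.random.default_rng(0).integers(0, 256, size=(4, 8, 8, 3), dtype=np.uint8)
per_image = apply_transform(batch, scale, batched=False)
whole_batch = apply_transform(batch, scale, batched=True)
print(per_image.shape, np.array_equal(per_image, whole_batch))
```

Both paths give the same result here because scale is rank-agnostic. A transform written for a single image, such as ToTensor, only fits the per-image path, which is why the batched example above swaps it for a Lambda.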

PublicAPI (alpha): This API is in alpha and may change before becoming stable.