from typing import TYPE_CHECKING, Callable, Dict, List, Mapping, Optional, Union
import numpy as np
from ray.air.util.data_batch_conversion import BatchFormat
from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray
from ray.data.preprocessor import Preprocessor
from ray.util.annotations import PublicAPI
if TYPE_CHECKING:
import torch
[docs]
@PublicAPI(stability="alpha")
class TorchVisionPreprocessor(Preprocessor):
"""Apply a `TorchVision transform <https://pytorch.org/vision/stable/transforms.html>`_
to image columns.
Examples:
Torch models expect inputs of shape :math:`(B, C, H, W)` in the range
:math:`[0.0, 1.0]`. To convert images to this format, add ``ToTensor`` to your
preprocessing pipeline.
.. testcode::
from torchvision import transforms
import ray
from ray.data.preprocessors import TorchVisionPreprocessor
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Resize((224, 224)),
])
preprocessor = TorchVisionPreprocessor(["image"], transform=transform)
dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
dataset = preprocessor.transform(dataset)
For better performance, set ``batched`` to ``True`` and replace ``ToTensor``
with a batch-supporting ``Lambda``.
.. testcode::
import numpy as np
import torch
def to_tensor(batch: np.ndarray) -> torch.Tensor:
tensor = torch.as_tensor(batch, dtype=torch.float)
# (B, H, W, C) -> (B, C, H, W)
tensor = tensor.permute(0, 3, 1, 2).contiguous()
# [0., 255.] -> [0., 1.]
tensor = tensor.div(255)
return tensor
transform = transforms.Compose([
transforms.Lambda(to_tensor),
transforms.Resize((224, 224))
])
preprocessor = TorchVisionPreprocessor(["image"], transform=transform, batched=True)
dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
dataset = preprocessor.transform(dataset)
Args:
columns: The columns to apply the TorchVision transform to.
transform: The TorchVision transform you want to apply. This transform should
accept a ``np.ndarray`` or ``torch.Tensor`` as input and return a
``torch.Tensor`` as output.
output_columns: The output name for each input column. If not specified, this
defaults to the same set of columns as the columns.
batched: If ``True``, apply ``transform`` to batches of shape
:math:`(B, H, W, C)`. Otherwise, apply ``transform`` to individual images.
""" # noqa: E501
_is_fittable = False
def __init__(
self,
columns: List[str],
transform: Callable[[Union["np.ndarray", "torch.Tensor"]], "torch.Tensor"],
output_columns: Optional[List[str]] = None,
batched: bool = False,
):
if not output_columns:
output_columns = columns
if len(columns) != len(output_columns):
raise ValueError(
"The length of columns should match the "
f"length of output_columns: {columns} vs {output_columns}."
)
self._columns = columns
self._output_columns = output_columns
self._torchvision_transform = transform
self._batched = batched
def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"columns={self._columns}, "
f"output_columns={self._output_columns}, "
f"transform={self._torchvision_transform!r})"
)
def _transform_numpy(
self, data_batch: Dict[str, "np.ndarray"]
) -> Dict[str, "np.ndarray"]:
import torch
from ray.air._internal.torch_utils import convert_ndarray_to_torch_tensor
def apply_torchvision_transform(array: np.ndarray) -> np.ndarray:
try:
tensor = convert_ndarray_to_torch_tensor(array)
output = self._torchvision_transform(tensor)
except TypeError:
# Transforms like `ToTensor` expect a `np.ndarray` as input.
output = self._torchvision_transform(array)
if isinstance(output, torch.Tensor):
output = output.numpy()
if not isinstance(output, np.ndarray):
raise ValueError(
"`TorchVisionPreprocessor` expected your transform to return a "
"`torch.Tensor` or `np.ndarray`, but your transform returned a "
f"`{type(output).__name__}` instead."
)
return output
def transform_batch(batch: np.ndarray) -> np.ndarray:
if self._batched:
return apply_torchvision_transform(batch)
return _create_possibly_ragged_ndarray(
[apply_torchvision_transform(array) for array in batch]
)
if isinstance(data_batch, Mapping):
for input_col, output_col in zip(self._columns, self._output_columns):
data_batch[output_col] = transform_batch(data_batch[input_col])
else:
# TODO(ekl) deprecate this code path. Unfortunately, predictors are still
# sending schemaless arrays to preprocessors.
data_batch = transform_batch(data_batch)
return data_batch
def preferred_batch_format(cls) -> BatchFormat:
return BatchFormat.NUMPY