Source code for ray.data.preprocessors.torch

from typing import TYPE_CHECKING, Callable, Dict, List, Mapping, Optional, Union

import numpy as np

from ray.air.util.data_batch_conversion import BatchFormat
from ray.air.util.tensor_extensions.utils import _create_possibly_ragged_ndarray
from ray.data.preprocessor import Preprocessor
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    import torch


@PublicAPI(stability="alpha")
class TorchVisionPreprocessor(Preprocessor):
    """Apply a `TorchVision transform <https://pytorch.org/vision/stable/transforms.html>`_
    to image columns.

    Examples:

        Torch models expect inputs of shape :math:`(B, C, H, W)` in the range
        :math:`[0.0, 1.0]`. To convert images to this format, add ``ToTensor`` to your
        preprocessing pipeline.

        .. testcode::

            from torchvision import transforms

            import ray
            from ray.data.preprocessors import TorchVisionPreprocessor

            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
            ])
            preprocessor = TorchVisionPreprocessor(["image"], transform=transform)

            dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
            dataset = preprocessor.transform(dataset)

        For better performance, set ``batched`` to ``True`` and replace ``ToTensor``
        with a batch-supporting ``Lambda``.

        .. testcode::

            import numpy as np
            import torch

            def to_tensor(batch: np.ndarray) -> torch.Tensor:
                tensor = torch.as_tensor(batch, dtype=torch.float)
                # (B, H, W, C) -> (B, C, H, W)
                tensor = tensor.permute(0, 3, 1, 2).contiguous()
                # [0., 255.] -> [0., 1.]
                tensor = tensor.div(255)
                return tensor

            transform = transforms.Compose([
                transforms.Lambda(to_tensor),
                transforms.Resize((224, 224))
            ])
            preprocessor = TorchVisionPreprocessor(["image"], transform=transform, batched=True)

            dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
            dataset = preprocessor.transform(dataset)

    Args:
        columns: The columns to apply the TorchVision transform to.
        transform: The TorchVision transform you want to apply. This transform should
            accept a ``np.ndarray`` or ``torch.Tensor`` as input and return a
            ``torch.Tensor`` (or a ``np.ndarray``) as output.
        output_columns: The output name for each input column. If not specified, this
            defaults to the same set of columns as the columns.
        batched: If ``True``, apply ``transform`` to batches of shape
            :math:`(B, H, W, C)`. Otherwise, apply ``transform`` to individual images.

    Raises:
        ValueError: If ``columns`` and ``output_columns`` have different lengths.
    """  # noqa: E501

    # This class defines no fit step.
    # NOTE(review): presumably read by the base ``Preprocessor`` to skip the
    # "must be fitted before transform" check -- confirm against the base class.
    _is_fittable = False

    def __init__(
        self,
        columns: List[str],
        transform: Callable[[Union["np.ndarray", "torch.Tensor"]], "torch.Tensor"],
        output_columns: Optional[List[str]] = None,
        batched: bool = False,
    ):
        # ``None`` and an empty list both mean "write results back in place".
        if not output_columns:
            output_columns = columns
        if len(columns) != len(output_columns):
            raise ValueError(
                "The length of columns should match the "
                f"length of output_columns: {columns} vs {output_columns}."
            )
        self._columns = columns
        self._output_columns = output_columns
        self._torchvision_transform = transform
        self._batched = batched

    def __repr__(self) -> str:
        """Return a string showing the columns, output columns and transform."""
        return (
            f"{self.__class__.__name__}("
            f"columns={self._columns}, "
            f"output_columns={self._output_columns}, "
            f"transform={self._torchvision_transform!r})"
        )

    def _transform_numpy(
        self, data_batch: Dict[str, "np.ndarray"]
    ) -> Dict[str, "np.ndarray"]:
        """Apply the configured transform to a NumPy batch.

        Args:
            data_batch: Either a mapping from column name to array, or (legacy
                path) a bare array with no column names.

        Returns:
            For a mapping, the same mapping object with each output column set
            to the transformed input column (the input mapping is mutated).
            For a bare array, the transformed array.

        Raises:
            ValueError: If the transform returns something other than a
                ``torch.Tensor`` or ``np.ndarray``.
        """
        # Imported lazily so that importing this module does not require torch.
        import torch

        from ray.air._internal.torch_utils import convert_ndarray_to_torch_tensor

        def apply_torchvision_transform(array: np.ndarray) -> np.ndarray:
            try:
                tensor = convert_ndarray_to_torch_tensor(array)
                output = self._torchvision_transform(tensor)
            except TypeError:
                # Transforms like `ToTensor` expect a `np.ndarray` as input.
                output = self._torchvision_transform(array)
            if isinstance(output, torch.Tensor):
                # `Tensor.numpy()` raises for tensors that require grad or live
                # on a non-CPU device; detach and move to CPU first. Both calls
                # are no-ops for a plain CPU tensor.
                output = output.detach().cpu().numpy()
            if not isinstance(output, np.ndarray):
                raise ValueError(
                    "`TorchVisionPreprocessor` expected your transform to return a "
                    "`torch.Tensor` or `np.ndarray`, but your transform returned a "
                    f"`{type(output).__name__}` instead."
                )
            return output

        def transform_batch(batch: np.ndarray) -> np.ndarray:
            if self._batched:
                # The transform handles the whole batch in one call.
                return apply_torchvision_transform(batch)
            # Per-image outputs may differ in shape, so let the helper decide
            # how to pack them into a single array.
            return _create_possibly_ragged_ndarray(
                [apply_torchvision_transform(array) for array in batch]
            )

        if isinstance(data_batch, Mapping):
            for input_col, output_col in zip(self._columns, self._output_columns):
                data_batch[output_col] = transform_batch(data_batch[input_col])
        else:
            # TODO(ekl) deprecate this code path. Unfortunately, predictors are still
            # sending schemaless arrays to preprocessors.
            data_batch = transform_batch(data_batch)

        return data_batch

    @classmethod
    def preferred_batch_format(cls) -> BatchFormat:
        """Return the batch format this preprocessor operates on (NumPy)."""
        return BatchFormat.NUMPY