import abc
import base64
import collections
import pickle
import warnings
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, Union
from ray.air.util.data_batch_conversion import BatchFormat
from ray.util.annotations import DeveloperAPI, PublicAPI
if TYPE_CHECKING:
import numpy as np
import pandas as pd
from ray.air.data_batch_type import DataBatchType
from ray.data import Dataset
@PublicAPI(stability="beta")
class PreprocessorNotFittedException(RuntimeError):
"""Error raised when the preprocessor needs to be fitted first."""
pass
@PublicAPI(stability="beta")
class Preprocessor(abc.ABC):
"""Implements an ML preprocessing operation.
Preprocessors are stateful objects that can be fitted against a Dataset and used
to transform both local data batches and distributed data. For example, a
    Normalization preprocessor may calculate the mean and stdev of a field during
    fitting and use these attributes to implement its normalization transform.
    Preprocessors can also be stateless and transform data without needing to be
    fitted. For example, a preprocessor may simply remove a column, which does
    not require any state to be fitted.
If you are implementing your own Preprocessor sub-class, you should override the
following:
* ``_fit`` if your preprocessor is stateful. Otherwise, set
``_is_fittable=False``.
    * ``_transform_pandas`` and/or ``_transform_numpy``: for best performance,
      implement both. Otherwise, the data will be converted to match the
      implemented method.
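
    Example:
        A minimal, illustrative sketch of a stateful subclass (the names are
        hypothetical, not part of the API)::

            class SimpleScaler(Preprocessor):
                def _fit(self, ds):
                    # Store fitted state in attributes ending with "_".
                    self.mean_ = ds.mean("value")
                    return self

                def _transform_pandas(self, df):
                    df["value"] -= self.mean_
                    return df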
"""
class FitStatus(str, Enum):
"""The fit status of preprocessor."""
NOT_FITTABLE = "NOT_FITTABLE"
NOT_FITTED = "NOT_FITTED"
        # Only meaningful for Chain preprocessors.
        # At least one contained preprocessor in the chain is fitted, and at
        # least one that can be fitted is not fitted yet. This state shouldn't
        # show up if the caller only interacts with the chain preprocessor
        # through the intended Preprocessor APIs.
PARTIALLY_FITTED = "PARTIALLY_FITTED"
FITTED = "FITTED"
# Preprocessors that do not need to be fitted must override this.
_is_fittable = True
def _check_has_fitted_state(self):
"""Checks if the Preprocessor has fitted state.
        This is also used as an indication that the Preprocessor has been fit,
        following the convention from Ray versions prior to 2.6.
This allows preprocessors that have been fit in older versions of Ray to be
used to transform data in newer versions.
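
        For example, a fitted preprocessor might store computed statistics in
        an attribute such as ``self.stats_`` (an illustrative name; any
        attribute ending in ``_`` counts as fitted state).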
"""
fitted_vars = [v for v in vars(self) if v.endswith("_")]
return bool(fitted_vars)
def fit_status(self) -> "Preprocessor.FitStatus":
if not self._is_fittable:
return Preprocessor.FitStatus.NOT_FITTABLE
elif (
hasattr(self, "_fitted") and self._fitted
) or self._check_has_fitted_state():
return Preprocessor.FitStatus.FITTED
else:
return Preprocessor.FitStatus.NOT_FITTED
def fit(self, ds: "Dataset") -> "Preprocessor":
"""Fit this Preprocessor to the Dataset.
Fitted state attributes will be directly set in the Preprocessor.
Calling it more than once will overwrite all previously fitted state:
``preprocessor.fit(A).fit(B)`` is equivalent to ``preprocessor.fit(B)``.
Args:
ds: Input dataset.
Returns:
Preprocessor: The fitted Preprocessor with state attributes.
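
        Example (illustrative; ``MyPreprocessor`` stands in for any concrete
        subclass)::

            import ray

            ds = ray.data.from_items([{"value": i} for i in range(10)])
            preprocessor = MyPreprocessor().fit(ds)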
"""
fit_status = self.fit_status()
if fit_status == Preprocessor.FitStatus.NOT_FITTABLE:
# No-op as there is no state to be fitted.
return self
if fit_status in (
Preprocessor.FitStatus.FITTED,
Preprocessor.FitStatus.PARTIALLY_FITTED,
):
            warnings.warn(
                "`fit` has already been called on the preprocessor (or on at "
                "least one contained preprocessor if this is a chain). "
                "All previously fitted state will be overwritten!"
            )
        fitted_preprocessor = self._fit(ds)
        self._fitted = True
        return fitted_preprocessor
@DeveloperAPI
def _fit(self, ds: "Dataset") -> "Preprocessor":
"""Sub-classes should override this instead of fit()."""
raise NotImplementedError()
def _determine_transform_to_use(self) -> BatchFormat:
"""Determine which batch format to use based on Preprocessor implementation.
* If only `_transform_pandas` is implemented, then use ``pandas`` batch format.
* If only `_transform_numpy` is implemented, then use ``numpy`` batch format.
        * If both are implemented, then use the preprocessor's
          ``preferred_batch_format`` (see the sketch below).
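
        For example, a subclass that implements both transforms might override
        the preferred format (an illustrative sketch)::

            @classmethod
            def preferred_batch_format(cls) -> BatchFormat:
                return BatchFormat.NUMPY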
"""
has_transform_pandas = (
self.__class__._transform_pandas != Preprocessor._transform_pandas
)
has_transform_numpy = (
self.__class__._transform_numpy != Preprocessor._transform_numpy
)
if has_transform_numpy and has_transform_pandas:
return self.preferred_batch_format()
elif has_transform_numpy:
return BatchFormat.NUMPY
elif has_transform_pandas:
return BatchFormat.PANDAS
else:
raise NotImplementedError(
"None of `_transform_numpy` or `_transform_pandas` are implemented. "
"At least one of these transform functions must be implemented "
"for Preprocessor transforms."
)
def _transform(self, ds: "Dataset") -> "Dataset":
# TODO(matt): Expose `batch_size` or similar configurability.
# The default may be too small for some datasets and too large for others.
transform_type = self._determine_transform_to_use()
        # Our user-facing batch format should only be pandas or NumPy; other
        # formats (arrow, simple) are internal.
kwargs = self._get_transform_config()
if transform_type == BatchFormat.PANDAS:
return ds.map_batches(
self._transform_pandas, batch_format=BatchFormat.PANDAS, **kwargs
)
elif transform_type == BatchFormat.NUMPY:
return ds.map_batches(
self._transform_numpy, batch_format=BatchFormat.NUMPY, **kwargs
)
else:
raise ValueError(
"Invalid transform type returned from _determine_transform_to_use; "
f'"pandas" and "numpy" allowed, but got: {transform_type}'
)
def _get_transform_config(self) -> Dict[str, Any]:
"""Returns kwargs to be passed to :meth:`ray.data.Dataset.map_batches`.
        Subclasses can override this, for example to tune the batch size, as
        sketched below.
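
        Example (an illustrative override)::

            def _get_transform_config(self) -> Dict[str, Any]:
                # Use a larger batch size for cheap, vectorized transforms.
                return {"batch_size": 4096}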
"""
return {}
    def _transform_batch(self, data: "DataBatchType") -> "DataBatchType":
        """Transform a single data batch, dispatching on the batch format."""
        # Import locally so that the minimal Ray install doesn't require these
        # modules at import time.
import numpy as np
import pandas as pd
from ray.air.util.data_batch_conversion import (
_convert_batch_type_to_numpy,
_convert_batch_type_to_pandas,
)
        try:
            import pyarrow
        except ImportError:
            pyarrow = None

        # Only include pyarrow.Table in the accepted types if pyarrow is
        # installed; the attribute lookup would raise otherwise.
        valid_types = (pd.DataFrame, collections.abc.Mapping, np.ndarray)
        if pyarrow is not None:
            valid_types += (pyarrow.Table,)
        if not isinstance(data, valid_types):
            raise ValueError(
                "`transform_batch` is currently only implemented for pandas "
                "DataFrames, pyarrow Tables, NumPy ndarrays, and dictionaries "
                f"of ndarrays. Got {type(data)}."
            )
transform_type = self._determine_transform_to_use()
if transform_type == BatchFormat.PANDAS:
return self._transform_pandas(_convert_batch_type_to_pandas(data))
elif transform_type == BatchFormat.NUMPY:
return self._transform_numpy(_convert_batch_type_to_numpy(data))
@DeveloperAPI
def _transform_pandas(self, df: "pd.DataFrame") -> "pd.DataFrame":
"""Run the transformation on a data batch in a Pandas DataFrame format."""
raise NotImplementedError()
@DeveloperAPI
def _transform_numpy(
self, np_data: Union["np.ndarray", Dict[str, "np.ndarray"]]
) -> Union["np.ndarray", Dict[str, "np.ndarray"]]:
"""Run the transformation on a data batch in a NumPy ndarray format."""
raise NotImplementedError()
@DeveloperAPI
def serialize(self) -> str:
"""Return this preprocessor serialized as a string.
Note: this is not a stable serialization format as it uses `pickle`.
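
        Example (illustrative round trip)::

            serialized = preprocessor.serialize()
            restored = Preprocessor.deserialize(serialized)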
"""
# Convert it to a plain string so that it can be included as JSON metadata
# in Trainer checkpoints.
return base64.b64encode(pickle.dumps(self)).decode("ascii")
@staticmethod
@DeveloperAPI
def deserialize(serialized: str) -> "Preprocessor":
"""Load the original preprocessor serialized via `self.serialize()`."""
return pickle.loads(base64.b64decode(serialized))