ray.train.predictor.Predictor

class ray.train.predictor.Predictor(preprocessor: Optional[ray.data.preprocessor.Preprocessor] = None)

Bases: abc.ABC

Predictors load models from checkpoints to perform inference.

Note

The base Predictor class cannot be instantiated directly. Use one of its subclasses instead.

How does a Predictor work?

Predictors expose a predict method that accepts an input batch of type DataBatchType and outputs predictions of the same type as the input batch.

When the predict method is called, the following occurs:

  • The input batch is converted into a pandas DataFrame. Tensor input (such as an np.ndarray) is converted into a single-column pandas DataFrame.

  • If a Preprocessor is saved in the provided Checkpoint, it is used to transform the DataFrame.

  • The transformed DataFrame is passed to the model for inference (via the predictor._predict_pandas method).

  • predict returns the predictions in the same type as the original input.
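The four steps above can be modeled with a small, dependency-free sketch. `ToyPredictor`, `Frame`, and the `"__value__"` column name are hypothetical stand-ins, not Ray's API: a column-oriented dict plays the role of a pandas DataFrame, and a plain list plays the role of an np.ndarray.

```python
from typing import Any, Callable, Dict, List, Optional, Union

# A "frame" is a column-oriented dict {column_name: [values...]}.
# It stands in for a pandas DataFrame so the sketch runs with the stdlib only.
Frame = Dict[str, List[Any]]
Batch = Union[List[Any], Frame]

# Assumption: name of the single column that tensor-like input is wrapped in.
TENSOR_COLUMN = "__value__"


class ToyPredictor:
    """Hypothetical stand-in that mirrors the steps of Predictor.predict."""

    def __init__(self, model: Callable[[Any], Any],
                 preprocessor: Optional[Callable[[Frame], Frame]] = None):
        self.model = model
        self.preprocessor = preprocessor

    def predict(self, data: Batch) -> Batch:
        # Step 1: convert the input batch into a tabular "frame".
        # A list (our tensor stand-in) becomes a single-column frame.
        is_tensor = isinstance(data, list)
        frame: Frame = {TENSOR_COLUMN: list(data)} if is_tensor else dict(data)
        # Step 2: apply the preprocessor, if one is set.
        if self.preprocessor is not None:
            frame = self.preprocessor(frame)
        # Step 3: run inference on the transformed frame.
        predictions = self._predict_pandas(frame)
        # Step 4: return predictions in the same type as the original input.
        return predictions["predictions"] if is_tensor else predictions

    def _predict_pandas(self, frame: Frame) -> Frame:
        # Toy model: apply self.model to every value of the first column.
        column = next(iter(frame.values()))
        return {"predictions": [self.model(x) for x in column]}


add_one = lambda f: {k: [v + 1 for v in vs] for k, vs in f.items()}
double = ToyPredictor(model=lambda x: 2 * x, preprocessor=add_one)
print(double.predict([1, 2, 3]))    # [4, 6, 8]
print(double.predict({"x": [10]}))  # {'predictions': [22]}
```

The list input comes back as a list and the dict input comes back as a dict, which is the "same type in, same type out" behavior the steps describe.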

How do I create a new Predictor?

To implement a new Predictor for your particular framework, subclass the base Predictor and implement the following methods:

  1. _predict_pandas: Given a pandas.DataFrame input, return a pandas.DataFrame containing predictions.

  2. from_checkpoint: Logic for creating a Predictor from an AIR Checkpoint.

  3. Optionally, _predict_numpy: Operate on tensor data directly, which can improve performance by avoiding the extra copies made by pandas conversions.
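A minimal sketch of this subclassing contract, assuming a checkpoint is just a dict of model parameters. `BasePredictor` and `LinearPredictor` are hypothetical names that mirror the structure described above without importing Ray. Declaring the required methods with `abc.abstractmethod` also shows one way a base class can refuse direct instantiation, as the Note above describes; column dicts again stand in for pandas DataFrames and lists for tensors.

```python
import abc
from typing import Any, Dict, List, Optional

Frame = Dict[str, List[float]]  # stand-in for a pandas DataFrame


class BasePredictor(abc.ABC):
    """Hypothetical mini base class mirroring the Predictor subclass contract."""

    def __init__(self, preprocessor: Optional[Any] = None):
        self._preprocessor = preprocessor

    @classmethod
    @abc.abstractmethod
    def from_checkpoint(cls, checkpoint: Dict[str, Any], **kwargs) -> "BasePredictor":
        """Build a predictor from a checkpoint (a plain dict in this sketch)."""

    @abc.abstractmethod
    def _predict_pandas(self, data: Frame) -> Frame:
        """Map a tabular batch to a tabular batch of predictions."""

    def _predict_numpy(self, data: List[float]) -> List[float]:
        # Default fallback: wrap tensor data in a single column (assumed name
        # "__value__") and route it through the tabular path.
        return self._predict_pandas({"__value__": data})["predictions"]


class LinearPredictor(BasePredictor):
    """Hypothetical predictor for a 1-D linear model y = w * x + b."""

    def __init__(self, w: float, b: float, preprocessor: Optional[Any] = None):
        super().__init__(preprocessor)  # subclasses call the base __init__
        self.w, self.b = w, b

    @classmethod
    def from_checkpoint(cls, checkpoint: Dict[str, Any], **kwargs) -> "LinearPredictor":
        # Assumed checkpoint layout: {"w": ..., "b": ..., optional "preprocessor"}.
        return cls(checkpoint["w"], checkpoint["b"], checkpoint.get("preprocessor"))

    def _predict_pandas(self, data: Frame) -> Frame:
        column = next(iter(data.values()))
        return {"predictions": [self.w * x + self.b for x in column]}

    def _predict_numpy(self, data: List[float]) -> List[float]:
        # Optional fast path: compute on the tensor directly, skipping the
        # tabular wrapper that the default fallback builds.
        return [self.w * x + self.b for x in data]


predictor = LinearPredictor.from_checkpoint({"w": 2.0, "b": 1.0})
print(predictor._predict_pandas({"x": [1.0, 2.0]}))  # {'predictions': [3.0, 5.0]}
print(predictor._predict_numpy([3.0]))               # [7.0]
try:
    BasePredictor()  # abstract methods are unimplemented
except TypeError as err:
    print("cannot instantiate:", type(err).__name__)  # cannot instantiate: TypeError
```

Because `_predict_numpy` has a working default, overriding it stays optional: a subclass that only implements the two required methods still handles tensor input, just through the slower tabular path.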

PublicAPI (beta): This API is in beta and may change before becoming stable.

Methods

__init__([preprocessor])

Subclasses must call Predictor.__init__() to set a preprocessor.

from_checkpoint(checkpoint, **kwargs)

Create a specific predictor from a checkpoint.

from_pandas_udf(pandas_udf)

Create a Predictor from a Pandas UDF.

get_preprocessor()

Get the preprocessor to use prior to executing predictions.

predict(data, **kwargs)

Perform inference on a batch of data.

preferred_batch_format()

A batch format hint that upstream producers can use to yield blocks in the most efficient format.

set_preprocessor(preprocessor)

Set the preprocessor to use prior to executing predictions.
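The remaining methods can be sketched the same way. `UDFPredictor` is a hypothetical class, again using column dicts in place of pandas DataFrames: `from_pandas_udf` wraps a plain function as the inference step, the preprocessor accessors store and return a callable, and `preferred_batch_format` returns the string "pandas" here, whereas Ray's real method may return a different type.

```python
from typing import Any, Callable, Dict, List, Optional

Frame = Dict[str, List[Any]]  # stand-in for a pandas DataFrame
Transform = Callable[[Frame], Frame]


class UDFPredictor:
    """Hypothetical predictor sketching from_pandas_udf and the accessor methods."""

    def __init__(self, udf: Transform, preprocessor: Optional[Transform] = None):
        self._udf = udf
        self._preprocessor = preprocessor

    @classmethod
    def from_pandas_udf(cls, pandas_udf: Transform) -> "UDFPredictor":
        # Wrap a plain function so it is used as the inference step.
        return cls(pandas_udf)

    def get_preprocessor(self) -> Optional[Transform]:
        return self._preprocessor

    def set_preprocessor(self, preprocessor: Optional[Transform]) -> None:
        self._preprocessor = preprocessor

    def preferred_batch_format(self) -> str:
        # Hint for upstream producers; this sketch always prefers tabular batches.
        return "pandas"

    def predict(self, data: Frame) -> Frame:
        # Apply the preprocessor (if any) before running the wrapped UDF.
        if self._preprocessor is not None:
            data = self._preprocessor(data)
        return self._udf(data)


# Row-wise sum across all columns.
row_sum = lambda f: {"sum": [sum(row) for row in zip(*f.values())]}
p = UDFPredictor.from_pandas_udf(row_sum)
print(p.predict({"a": [1, 2], "b": [10, 20]}))  # {'sum': [11, 22]}

p.set_preprocessor(lambda f: {k: [v * 10 for v in vs] for k, vs in f.items()})
print(p.predict({"a": [1, 2], "b": [10, 20]}))  # {'sum': [110, 220]}
```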