Using Predictors for Inference
Tip
Refer to the blog on Model Batch Inference in Ray for an overview of batch inference strategies in Ray and additional examples.

After you train a model, you will often want to use the model to do inference and prediction. To do so, you can use a Ray AIR Predictor. In this guide, we’ll cover how to use the Predictor on different types of data.
What are predictors?
Ray AIR Predictors are classes that load models from a Checkpoint to perform inference. Predictors are used by BatchPredictor and PredictorDeployment to do large-scale batch scoring or online inference.
Let’s walk through a basic usage of the Predictor. In the example below, we create a Checkpoint object from a model definition. Checkpoints can be generated in a variety of ways; see the Checkpoints user guide for more details. The checkpoint is then used to create a framework-specific Predictor (in our example, a TensorflowPredictor), which can then be used for inference:
import numpy as np
import tensorflow as tf
import ray
from ray.train.batch_predictor import BatchPredictor
from ray.train.tensorflow import (
TensorflowCheckpoint,
TensorflowPredictor,
)
def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=()),
            # Add feature dimension, expanding (batch_size,) to (batch_size, 1).
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(1),
        ]
    )
    return model
model = build_model()
checkpoint = TensorflowCheckpoint.from_model(model)
predictor = TensorflowPredictor.from_checkpoint(
checkpoint, model_definition=build_model
)
data = np.array([1, 2, 3, 4])
predictions = predictor.predict(data)
print(predictions)
# [[-1.6930283]
# [-3.3860567]
# [-5.079085 ]
# [-6.7721133]]
Predictors expose a predict method that accepts an input batch of type DataBatchType (a typing union of standard Python ecosystem data types, such as a pandas DataFrame or a NumPy array) and outputs predictions of the same type as the input batch.
Life of a prediction: Under the hood, when the Predictor.predict method is called, the following occurs:
1. The input batch is converted into a pandas DataFrame. Tensor input (like a np.ndarray) is converted into a single-column pandas DataFrame.
2. If there is a Preprocessor saved in the provided Checkpoint, the preprocessor is used to transform the DataFrame.
3. The transformed DataFrame is passed to the model for inference.
4. The predictions are returned by predict in the same type as the original input.
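The four steps above can be mimicked in a few lines of plain Python. This is a minimal sketch, not Ray's implementation: a dict of lists stands in for a pandas DataFrame, and the names toy_predict, to_columns, from_columns, TENSOR_COLUMN, sum_model, and double are hypothetical helpers made up for illustration.

```python
from typing import Callable, Dict, List, Optional, Union

# A dict of column name -> list of values stands in for a pandas DataFrame,
# so this sketch has no third-party dependencies.
Columns = Dict[str, List[float]]
Batch = Union[List[float], Columns]

# Hypothetical name for the single column created from tensor (list) input.
TENSOR_COLUMN = "__value__"


def to_columns(batch: Batch) -> Columns:
    """Step 1: convert the input batch into a columnar table."""
    if isinstance(batch, dict):
        return {name: list(values) for name, values in batch.items()}
    # Tensor-like input becomes a single-column table.
    return {TENSOR_COLUMN: list(batch)}


def from_columns(predictions: List[float], like: Batch) -> Batch:
    """Step 4: return predictions in the same type as the original input."""
    if isinstance(like, dict):
        return {"predictions": predictions}
    return predictions


def toy_predict(
    batch: Batch,
    model: Callable[[Columns], List[float]],
    preprocessor: Optional[Callable[[Columns], Columns]] = None,
) -> Batch:
    table = to_columns(batch)
    if preprocessor is not None:
        # Step 2: apply the preprocessor saved with the checkpoint, if any.
        table = preprocessor(table)
    # Step 3: pass the transformed table to the model.
    predictions = model(table)
    return from_columns(predictions, like=batch)


def sum_model(table: Columns) -> List[float]:
    """Stand-in model: sum all columns row by row."""
    return [sum(row) for row in zip(*table.values())]


def double(table: Columns) -> Columns:
    """Stand-in preprocessor: multiply every value by 2."""
    return {name: [2 * v for v in values] for name, values in table.items()}


print(toy_predict([1.0, 2.0], sum_model, preprocessor=double))  # [2.0, 4.0]
print(toy_predict({"a": [1, 2], "b": [10, 20]}, sum_model))  # {'predictions': [11, 22]}
```

List input comes back as a list and dict input comes back as a dict, which is the "same type in, same type out" contract described above.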
Batch Prediction
Ray AIR provides a BatchPredictor utility for large-scale batch inference. The BatchPredictor takes in a checkpoint and a predictor class and, when you call predict(), executes large-scale batch prediction on a given dataset in a parallel, distributed fashion.
Note
predict() loads the entire given dataset into memory, which may be a problem if your dataset is larger than your available cluster memory. See the Lazy/Pipelined Prediction (experimental) section for a workaround.
import pandas as pd
from ray.train.batch_predictor import BatchPredictor
batch_predictor = BatchPredictor(
checkpoint, TensorflowPredictor, model_definition=build_model
)
# Create a dummy dataset.
ds = ray.data.from_pandas(pd.DataFrame({"feature_1": [1, 2, 3], "label": [1, 2, 3]}))
# Use `feature_columns` to specify the input columns to your model.
predictions = batch_predictor.predict(ds, feature_columns=["feature_1"])
# show() prints the rows itself and returns None, so it is not wrapped in print().
predictions.show()
# {'predictions': array([-1.2789773], dtype=float32)}
# {'predictions': array([-2.5579545], dtype=float32)}
# {'predictions': array([-3.8369317], dtype=float32)}
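Conceptually, BatchPredictor cuts the dataset into batches, scores them on a pool of workers, and reassembles the results. The stdlib-only sketch below imitates that flow with a thread pool standing in for Ray actors. split_into_batches and toy_batch_predict are hypothetical names, and the sketch ignores scheduling, autoscaling, and data locality, which Ray handles for you.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence


def split_into_batches(rows: Sequence[float], batch_size: int) -> List[List[float]]:
    """Cut the dataset into fixed-size batches (the last one may be smaller)."""
    return [list(rows[i : i + batch_size]) for i in range(0, len(rows), batch_size)]


def toy_batch_predict(
    rows: Sequence[float],
    predict_batch: Callable[[List[float]], List[float]],
    batch_size: int = 2,
    num_workers: int = 2,
) -> List[float]:
    """Score every batch on a pool of workers and stitch the results back together."""
    batches = split_into_batches(rows, batch_size)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # pool.map preserves input order, so outputs line up with input rows.
        scored = list(pool.map(predict_batch, batches))
    return [value for batch in scored for value in batch]


# Hypothetical per-batch model: y = -1.5 * x.
print(toy_batch_predict([1, 2, 3], lambda b: [-1.5 * x for x in b]))  # [-1.5, -3.0, -4.5]
```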
Additionally, you can compute metrics from the predictions. To do this:
1. Specify a function for computing metrics.
2. Use keep_columns to keep the label column in the returned dataset.
3. Use map_batches to compute metrics on a batch-by-batch basis.
4. Aggregate the batch metrics via mean().
def calculate_accuracy(df):
    # Each prediction is a length-1 array; take its integer part and compare it
    # with the label of the same row.
    predicted = np.array([int(np.asarray(p)[0]) for p in df["predictions"]])
    return pd.DataFrame({"correct": predicted == df["label"].to_numpy()})
predictions = batch_predictor.predict(
    ds, feature_columns=["feature_1"], keep_columns=["label"]
)
predictions.show()
# {'predictions': array([-1.2789773], dtype=float32), 'label': 1}
# {'predictions': array([-2.5579545], dtype=float32), 'label': 2}
# {'predictions': array([-3.8369317], dtype=float32), 'label': 3}
correct = predictions.map_batches(calculate_accuracy)
print("Final accuracy: ", correct.mean(on="correct"))
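The map-then-aggregate pattern matters when batches have different sizes: a global mean over rows, which is what mean(on="correct") computes over the whole dataset, is not the same as averaging per-batch accuracies. The plain-Python sketch below shows the difference; the batch contents and the helper names correct_per_row, global_accuracy, and mean_of_batch_accuracies are made up for illustration.

```python
from typing import Dict, List

# Each batch is a dict of column name -> values, standing in for a DataFrame.
Batch = Dict[str, List[int]]


def correct_per_row(batch: Batch) -> List[bool]:
    """Map step: flag each row whose prediction matches its label."""
    return [p == y for p, y in zip(batch["predictions"], batch["label"])]


def global_accuracy(batches: List[Batch]) -> float:
    """Aggregate step: average the flags over all rows in all batches."""
    flags = [flag for batch in batches for flag in correct_per_row(batch)]
    return sum(flags) / len(flags)


def mean_of_batch_accuracies(batches: List[Batch]) -> float:
    """For comparison only: averaging per-batch accuracies overweights small batches."""
    per_batch = [sum(correct_per_row(b)) / len(b["label"]) for b in batches]
    return sum(per_batch) / len(per_batch)


batches = [
    {"predictions": [1, 2], "label": [1, 0]},  # 1 of 2 rows correct
    {"predictions": [3], "label": [3]},  # 1 of 1 rows correct
]
print(global_accuracy(batches))  # 2 correct out of 3 rows, about 0.667
print(mean_of_batch_accuracies(batches))  # 0.75
```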
Configuring Batch Prediction
To configure the computation resources for your BatchPredictor, set the following parameters in predict():
- min_scoring_workers and max_scoring_workers: The BatchPredictor internally creates an actor pool that autoscales the number of workers within [min, max] to execute your transforms. If not set, the autoscaling range defaults to [1, inf).
- num_gpus_per_worker: If you want to use GPUs for batch prediction, set this parameter explicitly. If it is not specified, the BatchPredictor performs inference on CPUs by default.
- num_cpus_per_worker: The number of CPUs for each worker.
- separate_gpu_stage: If using GPUs, whether to use separate stages for GPU inference and data preprocessing. This is enabled by default to avoid an excessive preprocessing workload on GPU workers. You may disable it if your preprocessor is very lightweight.
Here are some examples:
1. Use multiple CPUs for batch prediction:
If num_gpus_per_worker is not specified, batch prediction uses CPUs by default. This example uses two workers with 3 CPUs each.
predictions = batch_predictor.predict(
ds,
feature_columns=["feature_1"],
min_scoring_workers=2,
max_scoring_workers=2,
num_cpus_per_worker=3,
)
2. Use multiple GPUs for batch prediction:
Two workers, each with 1 GPU and 1 CPU (by default).
predictions = batch_predictor.predict(
ds,
feature_columns=["feature_1"],
min_scoring_workers=2,
max_scoring_workers=2,
num_gpus_per_worker=1,
)
3. Configure autoscaling:
Scale from 1 to 4 workers, depending on your dataset size and cluster resources. If no min/max values are provided, BatchPredictor scales from 1 to inf workers by default.
predictions = batch_predictor.predict(
ds,
feature_columns=["feature_1"],
min_scoring_workers=1,
max_scoring_workers=4,
num_cpus_per_worker=3,
)
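As a rough way to reason about these settings, the resources a pool can claim are the per-worker request multiplied by the number of workers, and the number of workers that can run at once is also capped by the cluster. The sketch below does that arithmetic for the three examples above. ScoringConfig, peak_request, and workers_that_fit are hypothetical helpers, not part of Ray; they ignore the extra preprocessing stage that separate_gpu_stage may add and do not model Ray's actual autoscaling decisions.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class ScoringConfig:
    """Hypothetical mirror of the predict() resource arguments discussed above."""

    min_scoring_workers: int = 1
    max_scoring_workers: Optional[int] = None  # None means "no upper bound"
    num_cpus_per_worker: int = 1
    num_gpus_per_worker: int = 0

    def peak_request(self) -> Optional[Dict[str, int]]:
        """Total resources requested when the pool is at its maximum size."""
        if self.max_scoring_workers is None:
            return None  # unbounded pool, so there is no fixed peak
        n = self.max_scoring_workers
        return {"CPU": n * self.num_cpus_per_worker, "GPU": n * self.num_gpus_per_worker}

    def workers_that_fit(self, cluster_cpus: int) -> int:
        """Upper bound on concurrent workers, considering CPUs only."""
        fit = cluster_cpus // self.num_cpus_per_worker
        if self.max_scoring_workers is None:
            return fit
        return min(fit, self.max_scoring_workers)


cpu_example = ScoringConfig(2, 2, num_cpus_per_worker=3)
gpu_example = ScoringConfig(2, 2, num_gpus_per_worker=1)
autoscale_example = ScoringConfig(1, 4, num_cpus_per_worker=3)

print(cpu_example.peak_request())  # {'CPU': 6, 'GPU': 0}
print(gpu_example.peak_request())  # {'CPU': 2, 'GPU': 2}
print(autoscale_example.workers_that_fit(cluster_cpus=8))  # 2, since 8 // 3 == 2 < 4
```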
Batch Inference Examples
Below, we provide examples of using common frameworks to do batch inference for different data types:
Tabular
XGBoost:
import ray
from ray.data.preprocessors import StandardScaler
from ray.train.batch_predictor import BatchPredictor
from ray.train.xgboost import XGBoostTrainer, XGBoostPredictor
from ray.air.config import ScalingConfig
# Split data into train and validation.
dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")
train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)
test_dataset = valid_dataset.drop_columns(["target"])
columns_to_scale = ["mean radius", "mean texture"]
preprocessor = StandardScaler(columns=columns_to_scale)
trainer = XGBoostTrainer(
label_column="target",
num_boost_round=20,
scaling_config=ScalingConfig(num_workers=2),
params={
"objective": "binary:logistic",
"eval_metric": ["logloss", "error"],
},
datasets={"train": train_dataset},
preprocessor=preprocessor,
)
result = trainer.fit()
# You can also create a checkpoint from a trained model using
# `XGBoostCheckpoint.from_model`.
# import xgboost as xgb
# from ray.train.xgboost import XGBoostCheckpoint
# model = xgb.Booster()
# model.load_model(...)
# checkpoint = XGBoostCheckpoint.from_model(model, path=".")
checkpoint = result.checkpoint
batch_predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor)
predicted_probabilities = batch_predictor.predict(test_dataset)
# Call show on the output probabilities to trigger execution.
predicted_probabilities.show()
PyTorch:
import numpy as np
import torch.nn as nn
import ray
from ray.data.preprocessors import Concatenator
from ray.train.torch import TorchCheckpoint, TorchPredictor
from ray.train.batch_predictor import BatchPredictor
def create_model(input_features: int):
    return nn.Sequential(
        nn.Linear(in_features=input_features, out_features=16),
        nn.ReLU(),
        nn.Linear(16, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
        nn.Sigmoid(),
    )
dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")
# All columns are features except the target column.
num_features = len(dataset.schema().names) - 1
# Specify a preprocessor to concatenate all feature columns.
prep = Concatenator(
output_column_name="concat_features", exclude=["target"], dtype=np.float32
)
checkpoint = TorchCheckpoint.from_model(
model=create_model(num_features), preprocessor=prep
)
# You can also fetch a checkpoint from a Trainer
# checkpoint = best_result.checkpoint
batch_predictor = BatchPredictor.from_checkpoint(checkpoint, TorchPredictor)
# Predict on the features.
predicted_probabilities = batch_predictor.predict(
dataset, feature_columns=["concat_features"]
)
# Call show on the output probabilities to trigger execution.
predicted_probabilities.show()
# {'predictions': array([1.], dtype=float32)}
# {'predictions': array([0.], dtype=float32)}
TensorFlow:
import numpy as np
import ray
from ray.data.preprocessors import Concatenator
from ray.train.tensorflow import TensorflowCheckpoint, TensorflowPredictor
from ray.train.batch_predictor import BatchPredictor
def create_model(input_features):
    from tensorflow import keras  # this is needed for tf<2.9
    from tensorflow.keras import layers
    return keras.Sequential(
        [
            keras.Input(shape=(input_features,)),
            layers.Dense(16, activation="relu"),
            layers.Dense(16, activation="relu"),
            layers.Dense(1, activation="sigmoid"),
        ]
    )
dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")
# All columns are features except the target column.
num_features = len(dataset.schema().names) - 1
# Specify a preprocessor to concatenate all feature columns.
prep = Concatenator(
output_column_name="concat_features", exclude=["target"], dtype=np.float32
)
checkpoint = TensorflowCheckpoint.from_model(
model=create_model(num_features), preprocessor=prep
)
# You can also fetch a checkpoint from a Trainer
# checkpoint = trainer.fit().checkpoint
batch_predictor = BatchPredictor.from_checkpoint(
checkpoint, TensorflowPredictor, model_definition=lambda: create_model(num_features)
)
predicted_probabilities = batch_predictor.predict(
dataset, feature_columns=["concat_features"]
)
# Call show on the output probabilities to trigger execution.
predicted_probabilities.show()
# {'predictions': array([1.], dtype=float32)}
# {'predictions': array([0.], dtype=float32)}
Image
from torchvision import transforms
from torchvision.models import resnet18
import ray
from ray.train.torch import TorchCheckpoint, TorchPredictor
from ray.train.batch_predictor import BatchPredictor
from ray.data.preprocessors import TorchVisionPreprocessor
data_url = "s3://anonymous@air-example-data-2/1G-image-data-synthetic-raw"
print(f"Running GPU batch prediction with 1GB data from {data_url}")
dataset = ray.data.read_images(data_url, size=(256, 256)).limit(10)
model = resnet18(pretrained=True)
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.CenterCrop(224),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
preprocessor = TorchVisionPreprocessor(columns=["image"], transform=transform)
ckpt = TorchCheckpoint.from_model(model=model, preprocessor=preprocessor)
predictor = BatchPredictor.from_checkpoint(ckpt, TorchPredictor)
predictions = predictor.predict(dataset, batch_size=80, num_gpus_per_worker=1)
# Call show on the output probabilities to trigger execution
predictions.show()
Text
Coming soon!
Developer Guide: Implementing your own Predictor
If you’re using an unsupported framework, or if built-in predictors are too inflexible, you may need to implement a custom predictor.
To implement a custom Predictor, subclass Predictor and implement:
- __init__()
- from_checkpoint()
- _predict_numpy() or _predict_pandas()
Tip
You don’t need to implement both _predict_numpy() and _predict_pandas(). Pick the method that’s easiest to implement. If both are implemented, override preferred_batch_format() to specify which format is more performant. This allows upstream producers to choose the best format.
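The idea behind the tip can be sketched without Ray: the base class's predict() decides which batch format to use and converts the data, so a subclass only has to supply one method. In the toy below, lists of row dicts stand in for pandas DataFrames and dicts of lists stand in for NumPy batches. ToyPredictor, RowDoubler, and ColumnDoubler are hypothetical, and Ray's real conversion logic is more involved.

```python
from typing import Dict, List

Rows = List[Dict[str, float]]  # stands in for a pandas DataFrame
Columns = Dict[str, List[float]]  # stands in for a dict of NumPy arrays


class ToyPredictor:
    """Hypothetical base class: subclasses override one of the two _predict_* methods."""

    def preferred_batch_format(self) -> str:
        # Prefer whichever method the subclass overrides; a subclass that
        # implements both can override this method to pick the faster one.
        if type(self)._predict_columns is not ToyPredictor._predict_columns:
            return "columns"
        return "rows"

    def predict(self, data: Columns) -> Columns:
        if self.preferred_batch_format() == "columns":
            return self._predict_columns(data)
        # Convert columns -> rows, call the row-based method, convert back.
        names = list(data)
        rows = [dict(zip(names, values)) for values in zip(*data.values())]
        out_rows = self._predict_rows(rows)
        if not out_rows:
            return {}
        return {key: [row[key] for row in out_rows] for key in out_rows[0]}

    def _predict_rows(self, rows: Rows) -> Rows:
        raise NotImplementedError

    def _predict_columns(self, data: Columns) -> Columns:
        raise NotImplementedError


class RowDoubler(ToyPredictor):
    def _predict_rows(self, rows: Rows) -> Rows:
        return [{"predictions": 2 * row["x"]} for row in rows]


class ColumnDoubler(ToyPredictor):
    def _predict_columns(self, data: Columns) -> Columns:
        return {"predictions": [2 * x for x in data["x"]]}


print(RowDoubler().predict({"x": [1.0, 2.0]}))  # {'predictions': [2.0, 4.0]}
print(ColumnDoubler().predict({"x": [1.0, 2.0]}))  # {'predictions': [2.0, 4.0]}
```

Both subclasses produce the same output; they differ only in which format the base class hands them.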
Examples
We’ll walk through how to implement a predictor for two frameworks:
- MXNet: a deep learning framework like Torch.
- statsmodels: a Python library that provides regression and linear models.
For more examples, read the source code of built-in predictors like TorchPredictor, XGBoostPredictor, and SklearnPredictor.
Before you begin
For the MXNet example, first install MXNet and Ray AIR.
pip install mxnet 'ray[air]'
Then, import the objects required for this example.
import os
from typing import Dict, Optional, Union
import mxnet as mx
import numpy as np
from mxnet import gluon
import ray
from ray.air import Checkpoint
from ray.data.preprocessor import Preprocessor
from ray.data.preprocessors import BatchMapper
from ray.train.batch_predictor import BatchPredictor
from ray.train.predictor import Predictor
Finally, create a stub for the MXNetPredictor class.
class MXNetPredictor(Predictor):
    ...
For the statsmodels example, first install statsmodels and Ray AIR.
pip install statsmodels 'ray[air]'
Then, import the objects required for this example.
import os
from typing import Optional
import numpy as np # noqa: F401
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.base.model import Results
from statsmodels.regression.linear_model import OLSResults
import ray
from ray.air import Checkpoint
from ray.data.preprocessor import Preprocessor
from ray.train.batch_predictor import BatchPredictor
from ray.train.predictor import Predictor
Finally, create a stub for the StatsmodelPredictor class.
class StatsmodelPredictor(Predictor):
    ...
Create a model
You’ll need to pass a model to the MXNetPredictor constructor. To create the model, load a pre-trained computer vision model from the MXNet model zoo.
net = gluon.model_zoo.vision.resnet50_v1(pretrained=True)
You’ll need to pass a model to the StatsmodelPredictor constructor. To create the model, fit an ordinary least squares (OLS) linear model on the Guerry dataset.
data: pd.DataFrame = sm.datasets.get_rdataset("Guerry", "HistData").data
results = smf.ols("Lottery ~ Literacy + np.log(Pop1831)", data=data).fit()
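The formula string passed to smf.ols uses patsy syntax, which adds an intercept by default, so the fitted model has the form below (np.log is the natural logarithm). Calling results.predict(data) later evaluates the second line with the estimated coefficients.

```latex
\begin{aligned}
\text{Lottery}_i &= \beta_0 + \beta_1\,\text{Literacy}_i + \beta_2 \ln(\text{Pop1831}_i) + \varepsilon_i \\
\widehat{\text{Lottery}}_i &= \hat\beta_0 + \hat\beta_1\,\text{Literacy}_i + \hat\beta_2 \ln(\text{Pop1831}_i)
\end{aligned}
```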
Implement __init__
Use the constructor to set instance attributes required for prediction. In the code snippet below, we assign the model to an attribute named net.
def __init__(
    self,
    net: gluon.Block,
    preprocessor: Optional[Preprocessor] = None,
):
    self.net = net
    super().__init__(preprocessor)
Warning
You must call the base class’s constructor; otherwise, Predictor.predict raises a NotImplementedError.
Use the constructor to set instance attributes required for prediction. In the code snippet below, we assign the fitted model to an attribute named results.
def __init__(self, results: Results, preprocessor: Optional[Preprocessor] = None):
    self.results = results
    super().__init__(preprocessor)
Warning
You must call the base class’s constructor; otherwise, Predictor.predict raises a NotImplementedError.
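Why the base constructor matters can be seen in a toy version: if the base class stores state in __init__ that predict() later reads, a subclass that skips super().__init__() leaves that state missing. The sketch below assumes that design, raising NotImplementedError when the attribute is absent to match the behavior described in the warning. ToyPredictorBase, ScalePredictor, and ForgetfulPredictor are hypothetical classes, not Ray's code.

```python
from typing import Callable, List, Optional

Preprocess = Callable[[List[float]], List[float]]


class ToyPredictorBase:
    """Hypothetical base class whose constructor sets state that predict() needs."""

    def __init__(self, preprocessor: Optional[Preprocess] = None):
        self._preprocessor = preprocessor

    def predict(self, data: List[float]) -> List[float]:
        if not hasattr(self, "_preprocessor"):
            # The subclass never ran the base constructor.
            raise NotImplementedError("Subclasses must call super().__init__(preprocessor).")
        if self._preprocessor is not None:
            data = self._preprocessor(data)
        return self._predict(data)

    def _predict(self, data: List[float]) -> List[float]:
        raise NotImplementedError


class ScalePredictor(ToyPredictorBase):
    def __init__(self, scale: float, preprocessor: Optional[Preprocess] = None):
        self.scale = scale
        super().__init__(preprocessor)  # sets self._preprocessor

    def _predict(self, data: List[float]) -> List[float]:
        return [self.scale * x for x in data]


class ForgetfulPredictor(ScalePredictor):
    def __init__(self, scale: float):
        self.scale = scale  # bug: super().__init__() is never called


print(ScalePredictor(2.0).predict([1.0, 3.0]))  # [2.0, 6.0]
try:
    ForgetfulPredictor(2.0).predict([1.0])
except NotImplementedError as err:
    print("NotImplementedError:", err)
```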
Implement from_checkpoint
from_checkpoint() creates a Predictor from a Checkpoint.
Before implementing from_checkpoint(), save the model parameters to a directory, and create a Checkpoint from that directory.
os.makedirs("checkpoint", exist_ok=True)
net.save_parameters("checkpoint/net.params")
checkpoint = Checkpoint.from_directory("checkpoint")
Then, implement from_checkpoint().
@classmethod
def from_checkpoint(
    cls,
    checkpoint: Checkpoint,
    net: gluon.Block,
) -> Predictor:
    with checkpoint.as_directory() as directory:
        path = os.path.join(directory, "net.params")
        net.load_parameters(path)
    return cls(net, preprocessor=checkpoint.get_preprocessor())
from_checkpoint() creates a Predictor from a Checkpoint.
Before implementing from_checkpoint(), save the fitted model to a directory, and create a Checkpoint from that directory.
os.makedirs("checkpoint", exist_ok=True)
results.save("checkpoint/guerry.pickle")
checkpoint = Checkpoint.from_directory("checkpoint")
Then, implement from_checkpoint().
@classmethod
def from_checkpoint(
    cls,
    checkpoint: Checkpoint,
    filename: str,
) -> Predictor:
    with checkpoint.as_directory() as directory:
        path = os.path.join(directory, filename)
        results = OLSResults.load(path)
    return cls(results, checkpoint.get_preprocessor())
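One detail in both from_checkpoint() implementations is easy to miss: the model is loaded inside the with block, because the directory yielded by as_directory() may be a temporary directory that is cleaned up when the block exits. The stdlib sketch below shows the same save, reopen, and load-inside-the-block pattern with a JSON file in a temporary directory. ToyLinearPredictor, from_directory, predict_one, and model.json are hypothetical names used only for illustration.

```python
import json
import os
import tempfile
from typing import Dict


class ToyLinearPredictor:
    """Hypothetical predictor whose "model" is a dict of linear coefficients."""

    def __init__(self, coefs: Dict[str, float]):
        self.coefs = coefs

    @classmethod
    def from_directory(cls, directory: str, filename: str = "model.json") -> "ToyLinearPredictor":
        # Analogous to from_checkpoint(): find the saved file and build an instance.
        with open(os.path.join(directory, filename)) as f:
            return cls(json.load(f))

    def predict_one(self, row: Dict[str, float]) -> float:
        # intercept + sum of weight * feature for every non-intercept coefficient
        return self.coefs["intercept"] + sum(
            weight * row[name] for name, weight in self.coefs.items() if name != "intercept"
        )


with tempfile.TemporaryDirectory() as directory:
    # Save step, analogous to results.save("checkpoint/guerry.pickle").
    with open(os.path.join(directory, "model.json"), "w") as f:
        json.dump({"intercept": 1.0, "x": 2.0}, f)
    # Load while the directory still exists; it is deleted when this block exits.
    predictor = ToyLinearPredictor.from_directory(directory)

print(os.path.exists(directory))  # False: the files are gone, but the model is in memory
print(predictor.predict_one({"x": 3.0}))  # 7.0
```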
Implement _predict_numpy or _predict_pandas
Because MXNet models accept tensors as input, you should implement _predict_numpy(). _predict_numpy() performs inference on a batch of NumPy data. It accepts a np.ndarray or dict[str, np.ndarray] as input and returns a np.ndarray or dict[str, np.ndarray] as output.
The input type is determined by the type of Dataset passed to BatchPredictor.predict. If your dataset has columns, the input is a dict; otherwise, the input is a np.ndarray.
def _predict_numpy(
    self,
    data: Union[np.ndarray, Dict[str, np.ndarray]],
    dtype: Optional[np.dtype] = None,
) -> Dict[str, np.ndarray]:
    # If `data` looks like `{"features": array([...])}`, unwrap the `dict` and pass
    # the array directly to the model.
    if isinstance(data, dict) and len(data) == 1:
        data = next(iter(data.values()))
    inputs = mx.nd.array(data, dtype=dtype)
    outputs = self.net(inputs).asnumpy()
    return {"predictions": outputs}
Because your OLS model accepts DataFrames as input, you should implement _predict_pandas(). _predict_pandas() performs inference on a batch of pandas data. It accepts a pandas.DataFrame as input and returns a pandas.DataFrame as output.
def _predict_pandas(self, data: pd.DataFrame) -> pd.DataFrame:
    predictions: pd.Series = self.results.predict(data)
    return predictions.to_frame(name="predictions")
Perform inference
To perform inference with the completed MXNetPredictor:
1. Create a Preprocessor and set it in the Checkpoint. You can also use one of the built-in preprocessors instead of implementing your own (see Preprocessor).
2. Create a BatchPredictor from your checkpoint.
3. Read sample images into a Dataset.
4. Call predict to classify the images in the dataset.
# These images aren't normalized. In practice, normalize images before inference.
dataset = ray.data.read_images(
"s3://anonymous@air-example-data-2/imagenet-sample-images", size=(224, 224)
)
def preprocess(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    # (B, H, W, C) -> (B, C, H, W)
    batch["image"] = batch["image"].transpose(0, 3, 1, 2)
    return batch
# Create the preprocessor and set it in the checkpoint.
# This preprocessor will be used to transform the data prior to prediction.
preprocessor = BatchMapper(preprocess, batch_format="numpy")
checkpoint.set_preprocessor(preprocessor=preprocessor)
predictor = BatchPredictor.from_checkpoint(
checkpoint, MXNetPredictor, net=net
)
predictor.predict(dataset)
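The preprocess function above only reorders axes. Assuming NumPy's transpose convention (output axis i takes the size of input axis axes[i]), the resulting shape can be checked with plain arithmetic. MXNet Gluon vision models take channels-first (N, C, H, W) input, which is why the channel axis moves to position 1. permuted_shape is a hypothetical helper, and the batch size of 8 is arbitrary.

```python
from typing import Sequence, Tuple


def permuted_shape(shape: Sequence[int], axes: Sequence[int]) -> Tuple[int, ...]:
    """Shape after a transpose: output axis i takes the size of input axis axes[i]."""
    if sorted(axes) != list(range(len(shape))):
        raise ValueError("axes must be a permutation of the input dimensions")
    return tuple(shape[a] for a in axes)


# 8 RGB images of 224x224 pixels, channels-last (B, H, W, C) as read from disk.
print(permuted_shape((8, 224, 224, 3), (0, 3, 1, 2)))  # (8, 3, 224, 224)
```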
To perform inference with the completed StatsmodelPredictor:
1. Create a BatchPredictor from your checkpoint.
2. Read the Guerry dataset into a Dataset.
3. Call predict to perform regression on the samples in the dataset.
predictor = BatchPredictor.from_checkpoint(
checkpoint, StatsmodelPredictor, filename="guerry.pickle"
)
# This is the same data we trained our model on. Don't do this in practice.
dataset = ray.data.from_pandas(data)
predictions = predictor.predict(dataset)
predictions.show()
Lazy/Pipelined Prediction (experimental)
If you have a large dataset but not a lot of available memory, you can use the predict_pipelined method. Unlike predict(), which loads the entire dataset into memory, predict_pipelined creates a DatasetPipeline object, which lazily loads the data and performs inference on a smaller batch of data at a time. Lazy loading allows you to operate on datasets much larger than your available memory. Execution is triggered by pulling from the pipeline, as shown in the example below.
import pandas as pd
import ray
from ray.air import Checkpoint
from ray.train.predictor import Predictor
from ray.train.batch_predictor import BatchPredictor
# Create a BatchPredictor that always returns `42` for each input.
batch_pred = BatchPredictor.from_pandas_udf(
lambda data: pd.DataFrame({"a": [42] * len(data)})
)
# Create a dummy dataset.
ds = ray.data.range_tensor(200, parallelism=4)
# Setup a prediction pipeline.
pipeline = batch_pred.predict_pipelined(ds, blocks_per_window=1)
for batch in pipeline.iter_batches():
    print("Pipeline result", batch)
# 0 42
# 1 42
# ...
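The laziness that predict_pipelined relies on can be illustrated with Python generators: no block is read or scored until the consumer asks for it, so only one window needs to be in memory at a time. This is an analogy, not Ray's DatasetPipeline. read_blocks, predict_lazily, and the reads log are hypothetical, and scoring one block per step loosely plays the role of blocks_per_window=1. The 200 rows and the constant 42 mirror the example above.

```python
from typing import Callable, Iterator, List

reads: List[int] = []  # log of block start offsets, to show when data is loaded


def read_blocks(num_rows: int, block_size: int) -> Iterator[List[int]]:
    """Yield the dataset one block at a time instead of materializing it."""
    for start in range(0, num_rows, block_size):
        reads.append(start)
        yield list(range(start, min(start + block_size, num_rows)))


def predict_lazily(
    blocks: Iterator[List[int]], predict_batch: Callable[[List[int]], List[int]]
) -> Iterator[List[int]]:
    """Score each block only when the consumer pulls the next result."""
    for block in blocks:
        yield predict_batch(block)


pipeline = predict_lazily(read_blocks(200, block_size=50), lambda b: [42] * len(b))
print(reads)  # [] because nothing has been read yet
first = next(pipeline)
print(reads, len(first), first[:3])  # [0] 50 [42, 42, 42]
```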
Online Inference
Check out Deploying Predictors with Serve for details on how to perform online inference with AIR.