Use a pretrained model for batch or online inference

Ray AIR moves end-to-end machine learning workloads seamlessly through the Checkpoint construct. A Checkpoint is the output of training and tuning, and the input to downstream inference tasks.

That said, it is entirely possible, and fully supported, to use Ray AIR in a piecemeal fashion.

If you already have a model trained elsewhere, you can still use Ray AIR for downstream tasks such as batch and online inference. To do so, convert the pretrained model, together with any preprocessing steps, into a Checkpoint.

To facilitate this, we provide framework-specific to_air_checkpoint helper functions.

Example:

import ray
import tensorflow as tf
from ray.train.batch_predictor import BatchPredictor
from ray.train.tensorflow import (
    to_air_checkpoint,
    TensorflowPredictor,
)


# Simulate having a model pretrained elsewhere.
def build_model() -> tf.keras.Model:
    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(1,)),
            tf.keras.layers.Dense(1),
        ]
    )
    return model


# Convert the pretrained model into an AIR Checkpoint.
model = build_model()
checkpoint = to_air_checkpoint(model)

# Build a BatchPredictor that restores the model from the checkpoint.
batch_predictor = BatchPredictor(
    checkpoint, TensorflowPredictor, model_definition=build_model
)

# Run batch inference over a small Ray Dataset.
predict_dataset = ray.data.range(3)
predictions = batch_predictor.predict(predict_dataset)