Key Concepts

Here, we cover the main concepts in AIR.


Ray Data is the standard way to load and exchange data in Ray AIR. It provides the Dataset abstraction, which is used extensively for data loading, preprocessing, and batch inference.


Preprocessors are primitives that can be used to transform input data into features. Preprocessors operate on Datasets, which makes them scalable and compatible with a variety of datasources and dataframe libraries.

A Preprocessor is fitted during training and then applied to data batches in the same way at both training and serving time. AIR comes with a collection of built-in preprocessors, and you can also define your own with simple templates.

See the documentation on Preprocessors.

import ray
import pandas as pd

from ray.data.preprocessors import StandardScaler

# Load the example data, then split it into train and validation sets.
dataset = ray.data.read_csv("s3://anonymous@air-example-data/breast_cancer.csv")
train_dataset, valid_dataset = dataset.train_test_split(test_size=0.3)
# The test set is the validation data without the label column.
test_dataset = valid_dataset.drop_columns(["target"])

# Standardize the selected feature columns.
columns_to_scale = ["mean radius", "mean texture"]
preprocessor = StandardScaler(columns=columns_to_scale)
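
The fit-once, apply-everywhere behavior described above can be sketched in plain Python. This is a minimal illustration, not Ray's implementation: the `SimpleStandardScaler` class, its methods, and the sample rows are hypothetical, and the use of the population standard deviation is an assumption.

```python
from statistics import fmean, pstdev


class SimpleStandardScaler:
    """Hypothetical stand-in for a preprocessor; not part of Ray."""

    def __init__(self, columns):
        self.columns = columns
        self.stats = {}

    def fit(self, rows):
        # Learn per-column mean and standard deviation from training rows only.
        for col in self.columns:
            values = [row[col] for row in rows]
            self.stats[col] = (fmean(values), pstdev(values))
        return self

    def transform(self, rows):
        # Apply the stored statistics; unlisted columns pass through unchanged.
        out = []
        for row in rows:
            scaled = dict(row)
            for col, (mean, std) in self.stats.items():
                scaled[col] = (row[col] - mean) / std if std else 0.0
            out.append(scaled)
        return out


# Made-up rows for illustration only.
train_rows = [
    {"mean radius": 10.0, "target": 0},
    {"mean radius": 14.0, "target": 1},
]
serving_rows = [{"mean radius": 12.0}]

scaler = SimpleStandardScaler(columns=["mean radius"]).fit(train_rows)
scaled_train = scaler.transform(train_rows)
scaled_serving = scaler.transform(serving_rows)
```

The serving rows are scaled with statistics learned from the training rows, which is the property that keeps training and serving inputs consistent.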


Trainers are wrapper classes around third-party training frameworks such as XGBoost and PyTorch. They are built to help integrate with core Ray actors (for distribution), Ray Tune, and Ray Data.

See the documentation on Trainers.

from ray.train.xgboost import XGBoostTrainer
from ray.air.config import ScalingConfig

num_workers = 2
use_gpu = False
# XGBoost specific params
params = {
    "tree_method": "approx",
    "objective": "binary:logistic",
    "eval_metric": ["logloss", "error"],
    "max_depth": 2,
}

trainer = XGBoostTrainer(
    scaling_config=ScalingConfig(
        num_workers=num_workers,
        use_gpu=use_gpu,
        # Make sure to leave some CPUs free for Ray Data operations.
        _max_cpu_fraction_per_node=0.9,
    ),
    label_column="target",
    params=params,
    datasets={"train": train_dataset, "valid": valid_dataset},
    preprocessor=preprocessor,
)

result = trainer.fit()

Calling .fit() on a Trainer returns a Result object, which contains training metrics as well as checkpoints for retrieving the best model.



Tuners offer scalable hyperparameter tuning as part of Ray Tune.

Tuners work with any Trainer and also support arbitrary training functions.

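from ray import tune
from ray.tune.tuner import Tuner, TuneConfig

tuner = Tuner(
    trainer,
    # Sample max_depth from the integers 1 through 8 (upper bound is exclusive).
    param_space={"params": {"max_depth": tune.randint(1, 9)}},
    # Run 5 trials and pick the one with the lowest training log loss.
    tune_config=TuneConfig(num_samples=5, metric="train-logloss", mode="min"),
)
result_grid = tuner.fit()
best_result = result_grid.get_best_result()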
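
Conceptually, the configuration above samples a handful of `max_depth` values and keeps the trial with the lowest metric. The following plain-Python sketch mirrors that idea only; it is not how Ray Tune is implemented. The `toy_logloss` objective and the `random_search` helper are hypothetical, and no model is actually trained.

```python
import random


def toy_logloss(max_depth):
    # Hypothetical objective standing in for a real training run.
    return abs(max_depth - 4) * 0.1 + 0.2


def random_search(objective, low, high, num_samples, mode="min", seed=0):
    rng = random.Random(seed)
    trials = []
    for _ in range(num_samples):
        # As with tune.randint, the upper bound is exclusive.
        max_depth = rng.randrange(low, high)
        trials.append(
            {"max_depth": max_depth, "train-logloss": objective(max_depth)}
        )
    # Select the best trial according to the optimization mode.
    pick = min if mode == "min" else max
    best = pick(trials, key=lambda t: t["train-logloss"])
    return trials, best


trials, best = random_search(toy_logloss, low=1, high=9, num_samples=5)
```

In the real Tuner each trial is a full Trainer run that can execute in parallel across the cluster; the sketch only shows the sample-then-select logic.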


AIR Trainers, Tuners, and custom pretrained models each generate a framework-specific Checkpoint object. Checkpoints are a common interface for models that are used across different AIR components and libraries.

There are two main ways to generate a checkpoint.

Checkpoint objects can be retrieved from the Result object returned by a Trainer or Tuner .fit() call.

checkpoint = result.checkpoint
# Checkpoint(local_path=..../checkpoint_000005)

tuned_checkpoint = result_grid.get_best_result().checkpoint
# Checkpoint(local_path=..../checkpoint_000005)

You can also generate a checkpoint from a pretrained model. Each machine learning (ML) framework supported by AIR has a Checkpoint class that can be used to generate an AIR checkpoint:

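from ray.train.tensorflow import TensorflowCheckpoint
import tensorflow as tf

# This can be a trained model.
def build_model() -> tf.keras.Model:
    # Placeholder architecture for illustration; substitute your own layers.
    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(1,)),
            tf.keras.layers.Dense(1),
        ]
    )
    return model

model = build_model()

# Wrap the in-memory model in an AIR checkpoint.
checkpoint = TensorflowCheckpoint.from_model(model)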

Checkpoints can be used to instantiate a Predictor, BatchPredictor, or PredictorDeployment, as seen below.

Batch Predictor

You can take a checkpoint and do batch inference using the BatchPredictor object.

from ray.train.batch_predictor import BatchPredictor
from ray.train.xgboost import XGBoostPredictor

batch_predictor = BatchPredictor.from_checkpoint(result.checkpoint, XGBoostPredictor)

# Bulk batch prediction.
predicted_probabilities = batch_predictor.predict(test_dataset)

# Pipelined batch prediction: instead of processing the data in bulk, process it
# incrementally in windows of the given size.
pipeline = batch_predictor.predict_pipelined(test_dataset, bytes_per_window=1048576)
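
The difference between bulk and pipelined prediction can be illustrated without Ray. This is a rough sketch under simplifying assumptions: windows are sized by record count rather than by bytes, and `iter_windows` and `toy_predict` are hypothetical helpers, not part of the BatchPredictor API.

```python
def iter_windows(records, records_per_window):
    # Yield consecutive slices so only one window is processed at a time.
    for start in range(0, len(records), records_per_window):
        yield records[start:start + records_per_window]


def toy_predict(window):
    # Hypothetical model: flags values above a fixed threshold.
    return [1 if value > 0.5 else 0 for value in window]


# Made-up inputs for illustration only.
records = [0.1, 0.9, 0.4, 0.7, 0.6]

# Bulk: the whole dataset is processed in one call.
bulk = toy_predict(records)

# Pipelined: predictions are produced window by window.
pipelined = []
for window in iter_windows(records, records_per_window=2):
    pipelined.extend(toy_predict(window))
```

Both paths yield the same predictions; the windowed path bounds how much data is in flight at once.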


Deploy the model as an inference service by using Ray Serve and the PredictorDeployment class.

from ray import serve
from fastapi import Request
from ray.serve import PredictorDeployment

# Convert the JSON body of each request into a pandas DataFrame.
async def adapter(request: Request):
    content = await request.json()
    return pd.DataFrame.from_dict(content)

# Deploy the predictor behind an HTTP endpoint.
serve.run(
    PredictorDeployment.bind(
        XGBoostPredictor, result.checkpoint, batching_params=False, http_adapter=adapter
    )
)

After deploying the service, you can send requests to it.

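import requests

# Take one row from the test set and convert it to a plain dict.
sample_input = test_dataset.take(1)
sample_input = dict(sample_input[0])

# POST the row as a JSON list of records to the local Serve endpoint.
output = requests.post("http://localhost:8000/", json=[sample_input]).json()
print(output)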
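
The request body above is a JSON list of row records, which the adapter turns into tabular form. As a rough illustration of that shape, the sketch below round-trips a record through JSON and regroups it by column. The `records_to_columns` helper and the feature values are hypothetical; the actual adapter relies on pandas instead.

```python
import json


def records_to_columns(records):
    # Group a list of row dicts into a dict of column lists.
    columns = {}
    for record in records:
        for name, value in record.items():
            columns.setdefault(name, []).append(value)
    return columns


# Made-up feature values for illustration only.
sample_input = {"mean radius": 12.3, "mean texture": 18.5}

# What the client sends: a JSON-encoded list containing one record.
body = json.dumps([sample_input])

# What the server side can recover from the request body.
columns = records_to_columns(json.loads(body))
```

Sending a list rather than a single object lets one request carry several rows.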