Model Registry Integration#
Ray Serve is Python-native, which means it integrates seamlessly with the broader MLOps ecosystem. You can connect Ray Serve deployments to a model registry such as MLflow Model Registry, enabling production-ready ML workflows without complex configuration or glue code. This guide shows you how to integrate Ray Serve with MLflow Model Registry to build end-to-end ML serving pipelines.
Why Python-native integration matters#
Unlike framework-specific serving solutions that require custom adapters or complex configuration, Ray Serve runs arbitrary Python code. This means you can:
- Load models directly from any model registry using standard Python clients
- Combine model loading and inference in a single deployment, as the sketch after this list shows
- Iterate quickly without wrestling with YAML configurations or custom serialization formats
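For example, the following minimal sketch loads a serialized scikit-learn model from Hugging Face Hub inside a deployment, using nothing but standard Python clients. The repo ID and filename are illustrative placeholders, not a published model:

```python
from ray import serve
import joblib
from huggingface_hub import hf_hub_download


@serve.deployment
class RegistryBackedModel:
    def __init__(self):
        # A standard Python client downloads the artifact. The repo ID and
        # filename are illustrative placeholders for your own artifact.
        path = hf_hub_download(repo_id="your-org/your-model", filename="model.joblib")
        # Model loading and inference live in the same deployment.
        self.model = joblib.load(path)

    async def __call__(self, request):
        payload = await request.json()
        return {"prediction": self.model.predict(payload["features"]).tolist()}


app = RegistryBackedModel.bind()
```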
Integrate with MLflow#
MLflow is a popular open-source platform for managing the ML lifecycle. Ray Serve makes it easy to load models from MLflow Model Registry and serve them in production.
Best practices for serving MLflow models#
- **Use model signatures and input schema validation:** Always log a model signature using `mlflow.models.infer_signature` so MLflow can validate inputs. This prevents silent failures when upstream code changes and enables automatic schema enforcement during serving.
- **Package dependencies explicitly:** Use `pip_requirements` when logging models and pin versions of core libraries. This ensures your model behaves identically across training, evaluation, and serving environments.
- **Persist preprocessing pipelines:** If you use scikit-learn, log complete `Pipeline` objects that include preprocessing steps. This ensures training and serving transformations stay aligned.
- **For LLMs and diffusion models, use Hugging Face Hub or Weights & Biases:** MLflow's built-in REST server isn't optimized for high-concurrency GPU workloads. For large language models, diffusion models, and other heavy transformer-based architectures, use Hugging Face Hub or Weights & Biases as your model registry. These platforms provide better tooling for large model artifacts, and Ray Serve handles GPU batching, autoscaling, and scheduling efficiently (see the sketch after this list).
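The following is a minimal sketch of that pattern: a deployment that pulls a generative model from Hugging Face Hub with `transformers`. The model ID (`distilgpt2`) and the single-GPU resource request are illustrative assumptions, not requirements:

```python
from ray import serve
from transformers import pipeline


@serve.deployment(ray_actor_options={"num_gpus": 1})
class TextGenDeployment:
    def __init__(self):
        # Hugging Face Hub acts as the model registry: the weights are
        # pulled by model ID. "distilgpt2" and device=0 (first GPU) are
        # illustrative assumptions.
        self.pipe = pipeline("text-generation", model="distilgpt2", device=0)

    async def __call__(self, request):
        prompt = (await request.json())["prompt"]
        output = self.pipe(prompt, max_new_tokens=32)
        return {"generated_text": output[0]["generated_text"]}


app = TextGenDeployment.bind()
```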
Train and register a model#
The following example shows how to train a scikit-learn model with best practices and register it with MLflow:
```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

import mlflow
import mlflow.sklearn
from mlflow.entities import LoggedModelStatus
from mlflow.models import infer_signature

import numpy as np
import sklearn


def train_and_register_model():
    # Initialize model in PENDING state
    logged_model = mlflow.initialize_logged_model(
        name="sk-learn-random-forest-reg-model",
        model_type="sklearn",
        tags={"model_type": "random_forest"},
    )

    try:
        with mlflow.start_run() as run:
            X, y = make_regression(
                n_features=4, n_informative=2, random_state=0, shuffle=False
            )
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42
            )
            params = {"max_depth": 2, "random_state": 42}

            # Best Practice: Use sklearn Pipeline to persist preprocessing
            # This ensures training and serving transformations stay aligned
            pipeline = Pipeline([
                ("scaler", StandardScaler()),
                ("regressor", RandomForestRegressor(**params)),
            ])
            pipeline.fit(X_train, y_train)

            # Log parameters and metrics
            mlflow.log_params(params)
            y_pred = pipeline.predict(X_test)
            mlflow.log_metrics({"mse": mean_squared_error(y_test, y_pred)})

            # Best Practice: Infer model signature for input validation
            # Prevents silent failures from mismatched feature order or missing columns
            signature = infer_signature(X_train, y_pred)

            # Best Practice: Pin dependency versions explicitly
            # Ensures identical behavior across training, evaluation, and serving
            pip_requirements = [
                f"scikit-learn=={sklearn.__version__}",
                f"numpy=={np.__version__}",
            ]

            # Log the sklearn pipeline with signature and dependencies
            mlflow.sklearn.log_model(
                sk_model=pipeline,
                name="sklearn-model",
                input_example=X_train[:1],
                signature=signature,
                pip_requirements=pip_requirements,
                registered_model_name="sk-learn-random-forest-reg-model",
                model_id=logged_model.model_id,
            )

        # Finalize model as READY
        mlflow.finalize_logged_model(logged_model.model_id, LoggedModelStatus.READY)
        mlflow.set_logged_model_tags(
            logged_model.model_id,
            tags={"production": "true"},
        )
    except Exception:
        # Mark model as FAILED if issues occur
        mlflow.finalize_logged_model(logged_model.model_id, LoggedModelStatus.FAILED)
        raise

    # Retrieve and work with the logged model
    final_model = mlflow.get_logged_model(logged_model.model_id)
    print(f"Model {final_model.name} is {final_model.status}")
```
This function trains a `RandomForestRegressor` wrapped in a `Pipeline` with preprocessing, logs the model with a signature and pinned dependencies, and registers it in the MLflow Model Registry under the name `sk-learn-random-forest-reg-model`.
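To exercise the function end to end, call it and then load the newest registered version back for a quick smoke test. This sketch uses the `models:/<name>/latest` URI form, which resolves to the most recent version of a registered model:

```python
import mlflow.pyfunc
import numpy as np

train_and_register_model()

# Smoke test: load the newest registered version by name and predict on
# a dummy input that matches the logged signature (shape (1, 4), float64).
model = mlflow.pyfunc.load_model("models:/sk-learn-random-forest-reg-model/latest")
print(model.predict(np.zeros((1, 4))))
```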
Load and serve the model#
Once you’ve registered a model in MLflow, you can load and serve it with Ray Serve. The following example shows how to create a deployment that loads a model from MLflow Model Registry with warm-start initialization:
```python
from ray import serve
import mlflow
import mlflow.pyfunc
import numpy as np


@serve.deployment
class MLflowModelDeployment:
    def __init__(self):
        # Search for models with production tag
        models = mlflow.search_logged_models(
            filter_string="tags.production='true' AND name='sk-learn-random-forest-reg-model'",
            order_by=[{"field_name": "creation_time", "ascending": False}],
        )
        if models.empty:
            raise ValueError("No model with production tag found")

        # Get the most recent production model
        model_row = models.iloc[0]
        artifact_location = model_row["artifact_location"]

        # Best Practice: Load model once during initialization (warm-start)
        # This eliminates first-request latency spikes
        self.model = mlflow.pyfunc.load_model(artifact_location)

        # Pre-warm the model with a dummy prediction
        dummy_input = np.zeros((1, 4))
        _ = self.model.predict(dummy_input)

    async def __call__(self, request):
        data = await request.json()
        features = np.array(data["features"])
        # MLflow validates input against the logged signature automatically
        prediction = self.model.predict(features)
        return {"prediction": prediction.tolist()}


app = MLflowModelDeployment.bind()
```
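To try the deployment locally, run the application and send it an HTTP request. This sketch assumes Ray Serve's default HTTP options (port 8000, route prefix `/`):

```python
import requests
from ray import serve

# Deploys the application on the local Ray cluster; returns without blocking.
serve.run(app)

# Feature values are floats so they match the float64 signature logged above.
response = requests.post(
    "http://localhost:8000/",
    json={"features": [[0.1, 0.2, 0.3, 0.4]]},
)
print(response.json())
```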