Recommendation system pattern#
This notebook builds a scalable matrix factorization recommendation system using the MovieLens 100K dataset, fully distributed on an Anyscale cluster with Ray Train and Ray Data. For larger-scale recommendation use cases, Ray also integrates with TorchRec; a separate example covers that integration.
Learning objectives#
How to use Ray Data to load, encode, and shard tabular datasets across many workers.
How to stream training data directly into PyTorch using iter_torch_batches().
How to build a custom training loop with validation and checkpointing using ray.train.report().
How to use Ray Train’s fault-tolerant trainer to resume training from the latest checkpoint with only a few lines of resume logic.
How to separate training, evaluation, and inference while keeping all code modular and distributed-ready.
How to run real-world recommendation workloads with no changes to your model code, using Ray’s orchestration.
What problem are you solving? (matrix factorization for recommendations)#
Build a collaborative filtering recommendation system that predicts how much a user likes an item
based on historical interaction data—in this case, user ratings from the MovieLens 100K dataset.
Use matrix factorization, a classic yet scalable approach where you embed each user and item in a latent space.
The model learns to represent users and items as vectors and predicts ratings by computing their dot product.
Input: user–item–rating triples#
Each row in the dataset represents a user’s explicit rating of a movie:
(user_id, item_id, rating, timestamp)
Encode the raw IDs as contiguous integer indices (user_idx, item_idx) so they can serve directly as row indices for embedding lookup during training.
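A minimal pure-Python sketch of this encoding, using a few made-up triples (the IDs and ratings are illustrative, not MovieLens rows). It assigns indices in sorted-ID order, the same convention step 5 uses, so every index is a valid row in an embedding table of size len(user2idx) or len(item2idx).

```python
# Hypothetical (user_id, item_id, rating) triples with sparse, non-contiguous raw IDs
triples = [(42, 7, 5.0), (17, 7, 3.0), (42, 1001, 2.0), (99, 55, 4.0)]

# Sorted distinct IDs -> contiguous indices 0..N-1
user2idx = {uid: j for j, uid in enumerate(sorted({u for u, _, _ in triples}))}
item2idx = {iid: j for j, iid in enumerate(sorted({i for _, i, _ in triples}))}

# Replace raw IDs with embedding-row indices; ratings pass through unchanged
encoded = [(user2idx[u], item2idx[i], r) for u, i, r in triples]

print(user2idx)  # {17: 0, 42: 1, 99: 2}
print(item2idx)  # {7: 0, 55: 1, 1001: 2}
print(encoded)   # [(1, 0, 5.0), (0, 0, 3.0), (1, 2, 2.0), (2, 1, 4.0)]
```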
Model: embedding-based matrix factorization#
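Learn an embedding vector for each user \(u\) and each item \(i\):
\[
\mathbf{p}_u \in \mathbb{R}^d, \qquad \mathbf{q}_i \in \mathbb{R}^d
\]
The predicted rating is the dot product of these vectors:
\[
\hat{r}_{ui} = \mathbf{p}_u^\top \mathbf{q}_i = \sum_{k=1}^{d} p_{u,k}\, q_{i,k}
\]
The embedding dimension \(d\) controls model capacity.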
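A worked example of the prediction rule with \(d = 3\) and hypothetical embedding values, chosen so the arithmetic is exact in floating point:

```python
def predict(user_vec, item_vec):
    """Predicted rating: dot product of the user and item embeddings."""
    return sum(p * q for p, q in zip(user_vec, item_vec))

# Hypothetical d = 3 embeddings for one user and two items
p_u = [0.5, 1.0, -0.25]
q_a = [2.0, 1.5, 0.0]
q_b = [-1.0, 0.5, 2.0]

print(predict(p_u, q_a))  # 0.5*2.0 + 1.0*1.5 + (-0.25)*0.0 = 2.5
print(predict(p_u, q_b))  # 0.5*(-1.0) + 1.0*0.5 + (-0.25)*2.0 = -0.5
```

Item a scores higher for this user because its embedding points in a direction similar to the user's; a larger \(d\) gives the model more directions in which to express such preferences.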
Training objective#
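Minimize Mean Squared Error (MSE) between predicted and actual ratings over the set \(\mathcal{R}\) of observed (user, item, rating) triples, where \(r_{ui}\) is the observed rating and \(\hat{r}_{ui}\) the model's prediction for user \(u\) and item \(i\):
\[
\mathcal{L} = \frac{1}{|\mathcal{R}|} \sum_{(u, i, r_{ui}) \in \mathcal{R}} \left( r_{ui} - \hat{r}_{ui} \right)^2
\]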
This encourages the model to assign higher scores to user–item pairs that historically received high ratings.
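To make the objective concrete, here is a small pure-Python sketch that fits the dot-product model to four hypothetical ratings with plain stochastic gradient descent. The training loop later in this notebook uses PyTorch autograd and Adam instead; this sketch only shows that minimizing the MSE pulls each predicted rating toward its observed value.

```python
import random

# Hypothetical toy data: (user_idx, item_idx, rating) for 2 users x 2 items
ratings = [(0, 0, 5.0), (0, 1, 1.0), (1, 0, 4.0), (1, 1, 2.0)]
d, lr, epochs = 2, 0.05, 200

rng = random.Random(0)  # fixed seed so the run is reproducible
U = [[rng.gauss(0.0, 0.1) for _ in range(d)] for _ in range(2)]  # user embeddings
V = [[rng.gauss(0.0, 0.1) for _ in range(d)] for _ in range(2)]  # item embeddings

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def mse():
    """Mean squared error over all observed ratings."""
    return sum((r - dot(U[u], V[i])) ** 2 for u, i, r in ratings) / len(ratings)

initial_mse = mse()
for _ in range(epochs):
    for u, i, r in ratings:
        err = dot(U[u], V[i]) - r  # prediction minus target
        # Gradient of (pred - r)^2 w.r.t. each embedding; the constant 2 is folded into lr
        grad_u = [err * q for q in V[i]]
        grad_v = [err * p for p in U[u]]
        U[u] = [p - lr * g for p, g in zip(U[u], grad_u)]
        V[i] = [q - lr * g for q, g in zip(V[i], grad_v)]
final_mse = mse()

print(f"MSE before: {initial_mse:.3f}, after: {final_mse:.3f}")
```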
Inference: ranking items per user#
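Once the model trains, you can recommend items by computing predicted scores for a target user \(u\) against all items in the catalog, where \(\mathbf{p}_u\) and \(\mathbf{q}_i\) are the learned user and item embeddings (approximate methods can later be applied at scale):
\[
\text{score}(u, i) = \mathbf{p}_u^\top \mathbf{q}_i \quad \text{for every item } i
\]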
Sort these scores and return the top-N items as personalized recommendations.
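A minimal sketch of exhaustive top-N ranking in plain Python, with hypothetical 2-dimensional embeddings. It scores every item, which costs O(num_items × d) per user, and uses heapq.nlargest to avoid sorting the whole catalog; this full scan is the part that approximate methods replace at scale.

```python
import heapq

def top_n(user_vec, item_vecs, n):
    """Return (item_idx, score) pairs for the n highest-scoring items, best first."""
    scores = (
        (j, sum(u * v for u, v in zip(user_vec, vec)))
        for j, vec in enumerate(item_vecs)
    )
    return heapq.nlargest(n, scores, key=lambda pair: pair[1])

# Hypothetical embeddings: one user, four items (d = 2)
user_vec = [1.0, 0.5]
item_vecs = [[0.0, 1.0], [2.0, 0.0], [1.0, 1.0], [-1.0, 0.0]]

print(top_n(user_vec, item_vecs, 2))  # [(1, 2.0), (2, 1.5)]
```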
How to migrate this recommendation system workload to a distributed setup using Ray on Anyscale#
This tutorial migrates a local matrix factorization pipeline for recommendation into a distributed, fault-tolerant training loop using Ray Train and Ray Data on Anyscale.
Approach the transition with the following steps:
1. Convert Parquet files to a sharded Ray Dataset
Load MovieLens 100K into Parquet and encode the IDs to create a multi-block Ray Dataset. Ray Data distributes these blocks across workers, so each worker trains on its own shard.
2. Stream Torch data loaders
Instead of manually writing PyTorch Dataset logic, use iter_torch_batches() from Ray Data to stream batches directly into each worker. Ray handles the parallelism and sharding behind the scenes.
3. Convert a single-node PyTorch process to multi-GPU distributed training
Write a minimal train_loop_per_worker that runs on each Ray worker. With TorchTrainer and prepare_model(), scale this loop across eight GPU workers, each working on its own data shard.
4. Configure structured epoch logging and checkpoints
Each epoch logs train_loss and val_loss and reports a checkpoint with ray.train.report(checkpoint=...). This enables automatic recovery and metric tracking without custom logging code.
5. Declaratively configure fault tolerance, checkpointing, and scaling
Configure fault tolerance, checkpointing, and scaling using ScalingConfig, CheckpointConfig, and FailureConfig. This lets Ray and Anyscale handle retries, recovery, and GPU orchestration.
6. Write lightweight Python functions for post-training inference
After training, load the latest checkpoint and generate top-N recommendations for any user with a simple forward pass. No retraining, no re-initialization, just plain PyTorch inference.
With just a few changes to your core code, scale a traditional recommendation pipeline across a Ray cluster with distributed data loading, checkpointing, fault tolerance, and parallel training, all fully managed by Anyscale.
1. Imports#
Start by importing all the libraries you need for the rest of the notebook. These include standard utilities like os, json, and pandas, as well as deep learning libraries like PyTorch and visualization tools like matplotlib.
Also, import everything needed for distributed training and data processing with Ray:
ray and ray.data provide the core runtime and the high-level distributed data API.
ray.train gives you ScalingConfig, RunConfig, CheckpointConfig, FailureConfig, checkpointing, and metrics reporting; ray.train.torch provides TorchTrainer and prepare_model.
prepare_model moves your PyTorch model to the worker's device and wraps it with Distributed Data Parallel (DDP) for multi-worker training.
tqdm rounds out the list for progress bars.
# 00. Runtime setup
import os, sys, subprocess
# Opt in to the Ray Train V2 API (not a secret, safe to set in code)
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
# Install Python dependencies
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "--no-cache-dir",
    "torch==2.8.0",
    "matplotlib==3.10.6",
    "pyarrow==14.0.2",
])
# 01. Imports
# Standard libraries
import os
import uuid
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import zipfile
import shutil
import tempfile
# PyTorch
import torch
from torch import nn
import torch.nn.functional as F
# Ray
import ray
import ray.data
from ray.train import ScalingConfig, RunConfig, CheckpointConfig, FailureConfig, Checkpoint, get_checkpoint, get_context, get_dataset_shard, report
from ray.train.torch import TorchTrainer, prepare_model
# Other
from tqdm import tqdm
2. Load MovieLens 100K dataset#
Download and extract the MovieLens 100K dataset, then persist a cleaned copy under /mnt/cluster_storage/rec_sys_tutorial/raw/ in two formats:
CSV: ratings.csv (kept for later inference cells).
Parquet dataset: ratings_parquet/ as multiple shards (production-style blob store layout) so Ray Data can stream reads in parallel without materializing the full dataset.
The output has four columns: user_id, item_id, rating, and timestamp.
The MovieLens 100K dataset contains 100,000 ratings across 943 users and 1,682 movies — small enough for quick iteration, yet realistic for demonstrating distributed streaming and training with Ray Data + Ray Train.
# 02. Load MovieLens 100K Dataset and store in /mnt/cluster_storage/ as CSV + Parquet
# Define clean working paths
DATA_URL = "http://files.grouplens.org/datasets/movielens/ml-100k.zip"
LOCAL_ZIP = "/mnt/cluster_storage/rec_sys_tutorial/ml-100k.zip"
EXTRACT_DIR = "/mnt/cluster_storage/rec_sys_tutorial/ml-100k"
OUTPUT_CSV = "/mnt/cluster_storage/rec_sys_tutorial/raw/ratings.csv"
PARQUET_DIR = "/mnt/cluster_storage/rec_sys_tutorial/raw/ratings_parquet"
# Ensure target directories exist
os.makedirs("/mnt/cluster_storage/rec_sys_tutorial/raw", exist_ok=True)
# Download only if not already done (plain Python, so it also runs outside a notebook)
import urllib.request
if not os.path.exists(LOCAL_ZIP):
    urllib.request.urlretrieve(DATA_URL, LOCAL_ZIP)
# Extract cleanly
if not os.path.exists(EXTRACT_DIR):
    with zipfile.ZipFile(LOCAL_ZIP, "r") as zip_ref:
        zip_ref.extractall("/mnt/cluster_storage/rec_sys_tutorial")
# Load raw file
raw_path = os.path.join(EXTRACT_DIR, "u.data")
df = pd.read_csv(raw_path, sep="\t", names=["user_id", "item_id", "rating", "timestamp"])
# Persist CSV (kept for later inference cell that expects CSV)
df.to_csv(OUTPUT_CSV, index=False)
# Persist a Parquet *dataset* (multiple files) to simulate blob storage layout
if os.path.exists(PARQUET_DIR):
    shutil.rmtree(PARQUET_DIR)
os.makedirs(PARQUET_DIR, exist_ok=True)
NUM_PARQUET_SHARDS = 8
for i, shard in enumerate(np.array_split(df, NUM_PARQUET_SHARDS)):
    shard.to_parquet(os.path.join(PARQUET_DIR, f"part-{i:02d}.parquet"), index=False)
print(f"✅ Loaded {len(df):,} ratings → CSV: {OUTPUT_CSV}")
print(f"✅ Wrote Parquet dataset with {NUM_PARQUET_SHARDS} shards → {PARQUET_DIR}")
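The shard sizes follow np.array_split's rule: each of the n chunks gets len // n rows, and the first len % n chunks get one extra. A plain-Python sketch of that rule (split_sizes is a hypothetical helper, not part of the notebook) confirms that 100,000 ratings split into eight Parquet files of 12,500 rows each.

```python
def split_sizes(num_rows: int, num_shards: int) -> list:
    """Chunk sizes np.array_split produces when splitting num_rows rows into num_shards parts."""
    base, extra = divmod(num_rows, num_shards)
    # The first `extra` chunks receive one additional row
    return [base + 1 if k < extra else base for k in range(num_shards)]

print(split_sizes(100_000, 8))  # eight shards of 12,500 rows
print(split_sizes(10, 3))       # [4, 3, 3]
```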
3. Point to Parquet dataset URI#
Instead of creating a Ray Dataset from in-memory pandas objects, this tutorial reads the data directly from a Parquet dataset stored in persistent cluster storage.
Ray Data uses this URI to stream Parquet shards across workers without loading the full dataset into memory. To read from a different location, set the RATINGS_PARQUET_URI environment variable before running the cell.
# 03. Point to Parquet dataset URI
DATASET_URI = os.environ.get(
    "RATINGS_PARQUET_URI",
    "/mnt/cluster_storage/rec_sys_tutorial/raw/ratings_parquet",
)
print("Parquet dataset URI:", DATASET_URI)
4. Visualize dataset: ratings, users, and items#
Before training, visualize the distribution of ratings, user activity, and item popularity.
These plots serve as a quick sanity check to confirm the dataset loaded correctly and to highlight patterns in user–item interactions:
Rating distribution: shows how often each rating (1–5 stars) occurs, typically skewed toward higher scores.
Ratings per user: reveals the long-tail behavior where a few users rate many items, while most rate only a few.
Ratings per item: similarly shows that a handful of popular items receive most of the ratings.
This visualization works with either raw IDs (user_id, item_id) or encoded indices (user_idx, item_idx), depending on what’s available in the current DataFrame.
# 04. Visualize dataset: ratings, user and item activity
# Use encoded indices if present; otherwise fall back to raw IDs
user_col = "user_idx" if "user_idx" in df.columns else "user_id"
item_col = "item_idx" if "item_idx" in df.columns else "item_id"
plt.figure(figsize=(12, 4))
# Rating distribution
plt.subplot(1, 3, 1)
df["rating"].hist(bins=[0.5,1.5,2.5,3.5,4.5,5.5], edgecolor='black')
plt.title("Rating Distribution")
plt.xlabel("Rating"); plt.ylabel("Frequency")
# Number of ratings per user
plt.subplot(1, 3, 2)
df[user_col].value_counts().hist(bins=30, edgecolor='black')
plt.title("Ratings per User")
plt.xlabel("# Ratings"); plt.ylabel("Users")
# Number of ratings per item
plt.subplot(1, 3, 3)
df[item_col].value_counts().hist(bins=30, edgecolor='black')
plt.title("Ratings per Item")
plt.xlabel("# Ratings"); plt.ylabel("Items")
plt.tight_layout()
plt.show()
5. Create Ray Dataset from Parquet and encode IDs#
Read the MovieLens ratings directly from the Parquet dataset using ray.data.read_parquet(). This keeps data in a streaming, non-materialized form suitable for large-scale distributed processing.
Next, build lightweight global ID mappings for users and items on the driver to convert raw user_id and item_id values into contiguous integer indices (user_idx, item_idx) required for embedding layers.
This mapping step materializes only the distinct IDs (a small subset of the data) while keeping the main dataset lazy and scalable.
Finally, apply a map_batches() transformation to encode each batch of rows in parallel across the cluster.
The resulting Ray Dataset remains distributed and ready for streaming batches directly into the Ray Train workers.
# 05. Create Ray Dataset by reading Parquet, then encode IDs via Ray
# Read Parquet dataset directly
ratings_ds = ray.data.read_parquet(DATASET_URI)
print("✅ Parquet dataset loaded (streaming, non-materialized)")
ratings_ds.show(3)
# ---- Build global ID mappings on the driver ----
user_ids = sorted([r["user_id"] for r in ratings_ds.groupby("user_id").count().take_all()])
item_ids = sorted([r["item_id"] for r in ratings_ds.groupby("item_id").count().take_all()])
user2idx = {uid: j for j, uid in enumerate(user_ids)}
item2idx = {iid: j for j, iid in enumerate(item_ids)}
NUM_USERS = len(user2idx)
NUM_ITEMS = len(item2idx)
print(f"Users: {NUM_USERS:,} | Items: {NUM_ITEMS:,}")
# ---- Encode to contiguous indices within Ray (keeps everything distributed) ----
def encode_batch(pdf: pd.DataFrame) -> pd.DataFrame:
    pdf["user_idx"] = pdf["user_id"].map(user2idx).astype("int64")
    pdf["item_idx"] = pdf["item_id"].map(item2idx).astype("int64")
    return pdf[["user_idx", "item_idx", "rating", "timestamp"]]
ratings_ds = ratings_ds.map_batches(encode_batch, batch_format="pandas")
print("✅ Encoded Ray Dataset schema:", ratings_ds.schema())
ratings_ds.show(3)
6. Train/validation split using Ray Data#
Next, split the dataset into training and validation sets. First, randomize the order of the dataset's blocks with a fixed seed so the split is reproducible, then split by row index, using 80% for training and 20% for validation.
randomize_block_order() shuffles whole blocks rather than individual rows, so it's cheap but coarse: the validation set consists of contiguous runs of rows from a few blocks. For a full row-level shuffle, call random_shuffle() instead, at the cost of moving every row across the cluster. split_proportionately() materializes the dataset to compute exact split points; each resulting split is still a Ray Dataset, ready to stream into workers.
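A plain-Python simulation of this strategy (block_shuffle_split is a hypothetical helper that approximates randomize_block_order() followed by split_proportionately(); it is not Ray's implementation). With five equal blocks and an 80/20 split, the validation set is exactly one whole block, which shows why a block-level shuffle yields a coarse, rather than row-random, validation set.

```python
import random

def block_shuffle_split(blocks, train_frac, seed):
    """Shuffle the order of whole blocks, then cut the concatenated rows at one row index."""
    order = list(range(len(blocks)))
    random.Random(seed).shuffle(order)  # block order changes; rows inside a block do not
    rows = [row for b in order for row in blocks[b]]
    cut = int(len(rows) * train_frac)  # exact row-count split point
    return rows[:cut], rows[cut:]

# Hypothetical dataset: 5 blocks of 4 rows, each row tagged (block_id, row_in_block)
blocks = [[(b, r) for r in range(4)] for b in range(5)]
train, val = block_shuffle_split(blocks, train_frac=0.8, seed=42)

print(len(train), len(val))  # 16 4
print({b for b, _ in val})   # a single block id: validation is one whole block
```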
# 06. Train/val split using Ray Data
TRAIN_FRAC = 0.8
SEED = 42  # for reproducibility
# Block-level shuffle, then an exact row-count split (split_proportionately materializes the dataset)
train_ds, val_ds = (
    ratings_ds
    .randomize_block_order(seed=SEED)  # shuffles block order only; rows within a block keep their order
    .split_proportionately([TRAIN_FRAC])  # returns [train, remainder]
)
print("✅ Train/Val Split:")
print(f" Train → {train_ds.count():,} rows")
print(f" Val → {val_ds.count():,} rows")
7. Define matrix factorization model#
Define a simple but effective matrix factorization model using PyTorch. A learned embedding vector represents each user and item. The model predicts a rating by taking the dot product of the corresponding user and item embeddings.
This architecture is commonly used in collaborative filtering and serves as a strong baseline for recommendation tasks. It’s also well-suited for scaling with Ray Train and DistributedDataParallel (DDP).
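A quick sizing check for this architecture: the model's only parameters are the two embedding tables, so it has (num_users + num_items) × d weights. With MovieLens 100K's 943 users and 1,682 items and the default d = 64, that's 168,000 parameters, about 0.67 MB in float32, small enough for DDP to keep a full replica on every GPU. Much larger catalogs are where sharded embedding tables, as in the TorchRec integration, become relevant. The helper name below is hypothetical.

```python
def mf_param_count(num_users: int, num_items: int, embedding_dim: int) -> int:
    """Trainable parameters in a bias-free dot-product MF model: one row per user and per item."""
    return (num_users + num_items) * embedding_dim

params = mf_param_count(943, 1_682, 64)
print(params)                                   # 168000
print(f"{params * 4 / 1e6:.2f} MB in float32")  # 0.67 MB in float32
```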
# 07. Define matrix factorization model
class MatrixFactorizationModel(nn.Module):
    def __init__(self, num_users: int, num_items: int, embedding_dim: int = 64):
        super().__init__()
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.item_embedding = nn.Embedding(num_items, embedding_dim)
    def forward(self, user_idx, item_idx):
        user_vecs = self.user_embedding(user_idx)
        item_vecs = self.item_embedding(item_idx)
        dot_product = (user_vecs * item_vecs).sum(dim=1)
        return dot_product
8. Define Ray Train loop (with validation, checkpointing, and Ray-managed metrics)#
Define the train_loop_per_worker, the core function executed by each Ray Train worker.
This loop handles distributed training, validation, and checkpointing with Ray-managed metrics.
Each worker receives its own shard of the training and validation datasets using get_dataset_shard().
Batches are streamed directly into PyTorch via iter_torch_batches(), ensuring efficient, fully distributed data loading.
During each epoch:
Compute the average training and validation MSE losses over the worker's own data shards.
On rank 0 only, save a temporary checkpoint (model weights + epoch metadata) using tempfile.TemporaryDirectory().
Call ray.train.report() to report metrics and attach the checkpoint; other workers report metrics only.
All metrics are automatically captured by Ray and made available in result.metrics_dataframe, enabling progress tracking and fault-tolerant recovery without extra logging logic.
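One detail worth knowing about these averages: the loop averages per-batch mean losses, so a small final batch counts as much as a full 512-row batch. The sketch below, with hypothetical loss values, compares that with a per-example average that weights each batch by its size. The two differ noticeably only when a short batch has an unusual loss, but the size-weighted mean is the exact MSE over the shard.

```python
def unweighted_mean(batch_losses):
    """Average of per-batch mean losses (what the training loop computes)."""
    return sum(loss for loss, _ in batch_losses) / max(1, len(batch_losses))

def weighted_mean(batch_losses):
    """Per-example average: weight each batch's mean loss by its batch size."""
    total = sum(n for _, n in batch_losses)
    return sum(loss * n for loss, n in batch_losses) / max(1, total)

# Hypothetical (mean_loss, batch_size) pairs: two full batches and one short tail batch
batches = [(1.0, 512), (1.0, 512), (4.0, 16)]
print(unweighted_mean(batches))  # 2.0
print(weighted_mean(batches))    # 1088 / 1040, about 1.046
```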
# 08. Define Ray Train loop (with val loss, checkpointing, and Ray-managed metrics)
def train_loop_per_worker(config):
    import tempfile
    # ---------------- Dataset shards -> PyTorch-style iterators ---------------- #
    train_ds = get_dataset_shard("train")
    val_ds = get_dataset_shard("val")
    train_loader = train_ds.iter_torch_batches(batch_size=512, dtypes=torch.float32)
    val_loader = val_ds.iter_torch_batches(batch_size=512, dtypes=torch.float32)
    # ---------------- Model / Optimizer ---------------- #
    model = MatrixFactorizationModel(
        num_users=config["num_users"],
        num_items=config["num_items"],
        embedding_dim=config.get("embedding_dim", 64),
    )
    model = prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.get("lr", 1e-3))
    # ---------------- Checkpointing setup ---------------- #
    rank = get_context().get_world_rank()
    start_epoch = 0
    # If a checkpoint exists (auto-resume), load it
    ckpt = get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as ckpt_dir:
            model.load_state_dict(
                torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")
            )
            start_epoch = torch.load(os.path.join(ckpt_dir, "meta.pt")).get("epoch", 0) + 1
        if rank == 0:
            print(f"[Rank {rank}] ✅ Resumed from checkpoint at epoch {start_epoch}")
    # ---------------- Training loop ---------------- #
    for epoch in range(start_epoch, config.get("epochs", 5)):
        # ---- Train ----
        model.train()
        train_losses = []
        for batch in train_loader:
            user = batch["user_idx"].long()
            item = batch["item_idx"].long()
            rating = batch["rating"].float()
            pred = model(user, item)
            loss = F.mse_loss(pred, rating)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        avg_train_loss = sum(train_losses) / max(1, len(train_losses))
        # ---- Validate ----
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                user = batch["user_idx"].long()
                item = batch["item_idx"].long()
                rating = batch["rating"].float()
                pred = model(user, item)
                loss = F.mse_loss(pred, rating)
                val_losses.append(loss.item())
        avg_val_loss = sum(val_losses) / max(1, len(val_losses))
        # Console log (optional)
        if rank == 0:
            print(f"[Epoch {epoch}] Train MSE: {avg_train_loss:.4f} | Val MSE: {avg_val_loss:.4f}")
        metrics = {
            "epoch": epoch,
            "train_loss": avg_train_loss,
            "val_loss": avg_val_loss,
        }
        # ---- Save checkpoint & report (rank 0 attaches checkpoint; others report metrics only) ----
        if rank == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                torch.save(model.state_dict(), os.path.join(tmpdir, "model.pt"))
                torch.save({"epoch": epoch}, os.path.join(tmpdir, "meta.pt"))
                ckpt_out = Checkpoint.from_directory(tmpdir)
                report(metrics, checkpoint=ckpt_out)
        else:
            report(metrics, checkpoint=None)
9. Launch distributed training with Ray Train#
Now, launch distributed training using TorchTrainer, Ray Train’s high-level orchestration interface. Provide it with:
Your custom train_loop_per_worker function
A train_config dictionary that specifies model dimensions, learning rate, and number of epochs
The sharded train and val Ray Datasets
A ScalingConfig that sets the number of workers and GPU usage
Also, configure checkpointing and fault tolerance:
Ray keeps up to 20 checkpoints (num_to_keep=20), enough to retain one per epoch for the 20-epoch run.
If training fails, Ray retries the run up to two times (max_failures=2).
Calling trainer.fit() kicks off training across the cluster. If any workers fail or disconnect, Ray restarts them and resumes from the latest checkpoint.
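A toy, single-process simulation of how max_failures and per-epoch checkpoints interact (run_with_retries and SimulatedWorkerFailure are hypothetical names; this is not how Ray implements retries). Each retry restarts from the epoch after the last checkpoint, so a crash repeats only the interrupted epoch, and the run gives up once failures exceed max_failures.

```python
class SimulatedWorkerFailure(Exception):
    """Hypothetical failure injected to exercise the retry path."""

def run_with_retries(total_epochs, max_failures, fail_at):
    """Toy stand-in for FailureConfig(max_failures=...) plus checkpoint resume.

    fail_at: epochs that crash the first time they run.
    Returns every epoch that was executed, including repeated ones.
    """
    last_ckpt_epoch = None  # epoch of the latest checkpoint, None before the first one
    executed = []
    failures = 0
    pending = set(fail_at)
    while True:
        start_epoch = 0 if last_ckpt_epoch is None else last_ckpt_epoch + 1
        try:
            for epoch in range(start_epoch, total_epochs):
                executed.append(epoch)
                if epoch in pending:
                    pending.discard(epoch)
                    raise SimulatedWorkerFailure(f"crash during epoch {epoch}")
                last_ckpt_epoch = epoch  # checkpoint reported at the end of the epoch
            return executed
        except SimulatedWorkerFailure:
            failures += 1
            if failures > max_failures:
                raise  # retry budget exhausted: the run fails

# One crash during epoch 2: only that epoch is repeated after the retry
print(run_with_retries(total_epochs=5, max_failures=2, fail_at={2}))  # [0, 1, 2, 2, 3, 4]
```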
# 09. Launch distributed training with Ray TorchTrainer
# Define config params (use Ray-derived counts)
train_config = {
    "num_users": NUM_USERS,
    "num_items": NUM_ITEMS,
    "embedding_dim": 64,
    "lr": 1e-3,
    "epochs": 20,
}
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_config,
    scaling_config=ScalingConfig(
        num_workers=8,  # Increase as needed
        use_gpu=True,   # Set to False to train on CPUs only
    ),
    datasets={"train": train_ds, "val": val_ds},
    run_config=RunConfig(
        name="mf_ray_train",
        storage_path="/mnt/cluster_storage/rec_sys_tutorial/results",
        checkpoint_config=CheckpointConfig(num_to_keep=20),
        failure_config=FailureConfig(max_failures=2),
    ),
)
# Run distributed training
result = trainer.fit()
10. Plot train and validation loss curves#
After training, retrieve the full metrics history directly from Ray Train’s internal tracking via result.metrics_dataframe.
This DataFrame automatically includes all reported metrics across epochs (e.g., train_loss, val_loss) for every call to ray.train.report().
You use it to visualize model convergence and ensure the training loop, checkpointing, and reporting worked correctly.
The plotted curves show how the training and validation MSE losses evolve over time—confirming whether the model is learning effectively and when it begins to stabilize.
# 10. Plot train/val loss curves (from Ray Train results)
# Pull the full metrics history Ray stored for this run
df = result.metrics_dataframe.copy()
# Keep only the columns we need (guard against extra columns)
cols = [c for c in ["epoch", "train_loss", "val_loss"] if c in df.columns]
df = df[cols].dropna()
# If multiple rows per epoch exist, keep the last report per epoch
if "epoch" in df.columns:
    df = df.sort_index().groupby("epoch", as_index=False).last()
# Plot
plt.figure(figsize=(7, 4))
if "train_loss" in df.columns:
    plt.plot(df["epoch"], df["train_loss"], marker="o", label="Train")
if "val_loss" in df.columns:
    plt.plot(df["epoch"], df["val_loss"], marker="o", label="Val")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title("Matrix Factorization - Loss per Epoch")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
11. Resume training from checkpoint#
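Run trainer.fit() again to resume training from the most recent checkpoint. Because the trainer uses a fixed run name (mf_ray_train) and a persistent storage_path, Ray Train finds the checkpoints saved by the previous run, and get_checkpoint() inside train_loop_per_worker returns the latest one. The loop restores the model weights from it and continues with the epoch after the saved one. If the previous run already completed every configured epoch, there's nothing left to train; raise epochs in train_config and re-create the trainer to keep going.
This demonstrates Ray Train’s built-in support for fault tolerance and iterative experimentation: training picks up where it left off, and the only resume code you write is the small get_checkpoint() block in the training loop.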
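The resume arithmetic in train_loop_per_worker sets start_epoch to the checkpointed epoch plus one. A small sketch of that logic (epochs_to_run is a hypothetical helper) shows what a second fit() does in each case: a fresh run trains every epoch, a run whose last checkpoint is epoch 19 of 20 has no epochs left, and raising the epoch budget continues from epoch 20.

```python
def epochs_to_run(last_ckpt_epoch, total_epochs):
    """Epochs the training loop executes, mirroring start_epoch = saved epoch + 1."""
    start_epoch = 0 if last_ckpt_epoch is None else last_ckpt_epoch + 1
    return list(range(start_epoch, total_epochs))

print(len(epochs_to_run(None, 20)))  # fresh run: 20
print(epochs_to_run(19, 20))         # finished 20-epoch run: []
print(epochs_to_run(19, 25))         # budget raised to 25: [20, 21, 22, 23, 24]
```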
# 11. Run trainer.fit() again to resume from last checkpoint
result = trainer.fit()
12. Inference: recommend top-N items for a user#
To demonstrate inference, generate top-10 item recommendations for a randomly selected user. Please note that the following method is meant for this small example, and Ray Data should be used for inference at scale.
First, reload the original ratings.csv and rebuild the user and item ID mappings used during training. Then, load the latest model checkpoint and restore the trained embedding weights. If you trained the model with DDP, strip the 'module.' prefix from checkpoint keys.
Next, select a user, compute their embedding, and take the dot product against all item embeddings to produce predicted scores. Finally, extract the top-N items with the highest scores and print their IDs and associated scores.
# 12. Inference: recommend top-N items for a user
# ---------------------------------------------
# Step 1: Reload original ratings CSV and mappings
# ---------------------------------------------
df = pd.read_csv("/mnt/cluster_storage/rec_sys_tutorial/raw/ratings.csv")
# Recompute ID mappings (same as during preprocessing)
unique_users = sorted(df["user_id"].unique())
unique_items = sorted(df["item_id"].unique())
user2idx = {uid: j for j, uid in enumerate(unique_users)}
item2idx = {iid: j for j, iid in enumerate(unique_items)}
idx2item = {v: k for k, v in item2idx.items()}
# ---------------------------------------------
# Step 2: Load model from checkpoint
# ---------------------------------------------
model = MatrixFactorizationModel(
    num_users=len(user2idx),
    num_items=len(item2idx),
    embedding_dim=train_config["embedding_dim"]
)
with result.checkpoint.as_directory() as ckpt_dir:
    state_dict = torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")
# Remove 'module.' prefix if using DDP-trained model (removeprefix only touches the start of each key)
if any(k.startswith("module.") for k in state_dict):
    state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
# ---------------------------------------------
# Step 3: Select a user and generate recommendations
# ---------------------------------------------
# Choose a random user from the original dataset
original_user_id = df["user_id"].sample(1).iloc[0]
user_idx = user2idx[original_user_id]
print(f"Generating recommendations for user_id={original_user_id} (internal idx={user_idx})")
# Compute scores for all items for this user
with torch.no_grad():
    user_vector = model.user_embedding(torch.tensor([user_idx]))  # [1, D]
    item_vectors = model.item_embedding.weight  # [num_items, D]
    scores = torch.matmul(user_vector, item_vectors.T).squeeze(0)  # [num_items]
    topk = torch.topk(scores, k=10)
top_item_ids = [idx2item[j.item()] for j in topk.indices]
top_scores = topk.values.tolist()
# ---------------------------------------------
# Step 4: Print top-N recommendations
# ---------------------------------------------
print("\nTop 10 Recommended Item IDs:")
for i, (item_id, score) in enumerate(zip(top_item_ids, top_scores), 1):
    print(f"{i:2d}. Item ID: {item_id} | Score: {score:.2f}")
13. Join top-N item IDs with movie titles#
To make your recommendations more interpretable, join the top-10 recommended item_ids with movie titles from the original u.item metadata file.
Load only the relevant columns—item_id and title—from u.item, then merge them with the top-N predictions you computed in the previous step. The result is a user-friendly list of movie titles with associated predicted scores, rather than raw item IDs.
This small addition makes the model outputs easier to understand and more useful for downstream applications.
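A plain-Python sketch of the left join described above, using hypothetical lines in u.item's pipe-delimited layout (item ID first, title second; the rows are made up, not real catalog entries). It keeps only the first two fields, as usecols=[0, 1] does, and shows what how="left" guarantees: every recommended item stays in the output, and an ID with no metadata gets an empty title (None here, NaN in pandas) instead of being dropped.

```python
# Hypothetical lines in u.item's pipe-delimited layout (made up, not real catalog rows)
u_item_lines = [
    "7|Example Movie A (1995)|01-Jan-1995||http://example.com/a|0|1",
    "42|Example Movie B (1997)|01-Jan-1997||http://example.com/b|1|0",
]
# Keep only the first two fields, like usecols=[0, 1]
titles = {}
for line in u_item_lines:
    fields = line.split("|")
    titles[int(fields[0])] = fields[1]

# Left join: every recommendation is kept; IDs without metadata get None
top_items = [(42, 4.8), (7, 4.5), (999, 4.1)]
merged = [(item_id, titles.get(item_id), score) for item_id, score in top_items]

for rank, (item_id, title, score) in enumerate(merged, 1):
    print(f"{rank:2d}. {title} | Score: {score:.2f}")
```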
# 13. Join top-N item IDs with movie titles from u.item
item_metadata = pd.read_csv(
    "/mnt/cluster_storage/rec_sys_tutorial/ml-100k/u.item",
    sep="|",
    encoding="latin-1",
    header=None,
    usecols=[0, 1],  # Only item_id and title
    names=["item_id", "title"]
)
# Join with top-N items
top_items_df = pd.DataFrame({
    "item_id": top_item_ids,
    "score": top_scores
})
merged = top_items_df.merge(item_metadata, on="item_id", how="left")
print("\nTop 10 Recommended Movies:")
for j, row in merged.iterrows():
    print(f"{j+1:2d}. {row['title']} | Score: {row['score']:.2f}")
Wrap up and next steps#
In this tutorial, you used Ray Train and Ray Data on Anyscale to scale a full matrix factorization recommendation system, end-to-end, from raw MovieLens rating files to multi-GPU distributed training and personalized top-N item recommendations.
This tutorial demonstrates:
Using Ray Data to preprocess, encode, and shard large tabular datasets.
Streaming data into PyTorch with iter_torch_batches() for efficient training.
Scaling matrix factorization across multiple GPUs with Ray Train’s TorchTrainer.
Saving and resuming training with Ray Checkpoints.
Running multi-node, fault-tolerant jobs without touching orchestration code.
Performing post-training inference using Ray-restored model checkpoints and learned user and item embeddings.
Next steps#
The following are a few directions you can explore to extend or adapt this workload:
Ranking metrics and evaluation
Add metrics like Root Mean Squared Error (RMSE), Normalized Discounted Cumulative Gain (NDCG), or Hit@K to evaluate recommendation quality.
Filter out already-rated items during inference to measure novelty.
Two-tower and deep models
Replace dot product with a two-tower neural model or a deep MLP.
Add side features (for example, timestamp, genre) into each tower for better personalization.
Recommendation personalization
Store and cache user embeddings after training.
Run lightweight inference tasks to generate recommendations in real-time.
Content-based or hybrid models
Join movie metadata (genres, tags) and build a hybrid collaborative–content model.
Embed titles or genres using pre-trained language models.
Hyperparameter optimization
Use Ray Tune to sweep embedding sizes, learning rates, or regularization.
Track performance over epochs and checkpoint the best models automatically.
Data scaling
Switch from MovieLens 100K to 1M or 10M as Ray Data handles it seamlessly.
Save and load from cloud object storage (S3, GCS) for real-world deployments.
Production inference
Wrap the recommendation system in a Ray Serve endpoint to serve top-N results online, and use Ray Data for batch inference.
Build a simple demo that recommends movies to live users.
End-to-end MLOps
Register the best model with MLflow or Weights & Biases.
Package the training job as a Ray job and schedule it with Anyscale.
Multi-tenant recommendation systems
Extend this to support multiple audiences or contexts (for example, multi-country, A/B groups).
Train and serve context-aware models in parallel using Ray.
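As a starting point for the ranking-metrics direction, here is a minimal sketch, with hypothetical scores, of three pieces mentioned above: excluding already-rated items at inference time, Hit@K against a held-out set of relevant items, and RMSE. The function names are illustrative, not from a library.

```python
def recommend(scores, already_rated, k):
    """Top-k item indices by score, skipping items the user has already rated."""
    candidates = [j for j in range(len(scores)) if j not in already_rated]
    return sorted(candidates, key=lambda j: scores[j], reverse=True)[:k]

def hit_at_k(recommended, held_out_items):
    """1.0 if any held-out (relevant) item appears in the top-k list, else 0.0."""
    return 1.0 if set(recommended) & set(held_out_items) else 0.0

def rmse(preds, targets):
    """Root mean squared error between predicted and observed ratings."""
    return (sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)) ** 0.5

# Hypothetical scores for 6 items; the user already rated items 0 and 3
scores = [0.9, 0.1, 0.7, 0.95, 0.4, 0.8]
recs = recommend(scores, already_rated={0, 3}, k=2)
print(recs)                          # [5, 2]
print(hit_at_k(recs, {2}))           # 1.0
print(rmse([3.0, 4.0], [4.0, 4.0]))  # sqrt(0.5), about 0.7071
```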
This pattern gives you a solid foundation for scaling recommendation workloads across real datasets and real infrastructure—without rewriting your model or managing your cluster.