Tabular workload pattern
In this tutorial you take the classic Covertype forest-cover dataset (~580k rows, 54 tabular features) and scale an XGBoost model across an Anyscale cluster using Ray Train.
Learning objectives
Ingest tabular data at scale using Ray Data and persist it to Parquet for reproducibility
Launch a fault-tolerant, checkpoint-enabled XGBoost training loop on multiple CPUs using Ray Train
Resume training from checkpoints for protection against job restarts and hardware failures
Evaluate model accuracy, visualize feature importance, and scale batch inference using Ray Data
Understand how to port classic gradient boosting workflows into a fully distributed, multi-node training setup on Anyscale
What problem are you solving? (Forest cover classification with XGBoost)
You’re predicting which type of forest vegetation (for example, Lodgepole Pine, Spruce/Fir, Aspen) is present at a given location, using only numeric and binary cartographic features such as elevation, slope, soil type, and proximity to roads or hydrology.
What’s XGBoost?
XGBoost (Extreme Gradient Boosting) is a fast, scalable machine learning algorithm based on gradient-boosted decision trees. It builds a sequence of shallow decision trees, where each new tree tries to correct the errors of the previous ensemble by minimizing a differentiable loss (like log-loss).
In your case, you minimize the multi-class softmax log-loss, learning a function
$$f: \mathbb{R}^{54} \to \{0, 1, \dots, 6\}$$
that maps a 54-dimensional tabular input (raw geospatial features) to one of the seven forest cover types. Each boosting round fits a new tree to the gradient of the loss, gradually improving accuracy over many rounds.
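The following minimal NumPy sketch makes "fits a new tree to the gradient of the loss" concrete. It computes the per-row quantities a multi-class boosting round works from: the softmax log-loss, its gradient p - y with respect to the raw class scores, and the diagonal second-order term p(1 - p). It only illustrates the math and is not XGBoost's internal implementation. The toy labels are made up.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    # Subtract the row-wise max before exponentiating for numerical stability.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def softmax_logloss_grad_hess(raw_scores, labels, num_class):
    """Mean multi-class log-loss, plus per-row gradient and diagonal Hessian
    with respect to the raw (pre-softmax) class scores."""
    p = softmax(raw_scores)
    y = np.eye(num_class)[labels]  # one-hot targets, shape (n_rows, num_class)
    loss = -np.mean(np.log(p[np.arange(len(labels)), labels]))
    grad = p - y          # derivative of each row's loss w.r.t. its scores
    hess = p * (1.0 - p)  # diagonal of each row's Hessian
    return loss, grad, hess


# Toy example: 4 rows, 7 classes, all scores 0 (a uniform initial prediction).
labels = np.array([0, 3, 6, 1])
scores = np.zeros((4, 7))
loss, grad, hess = softmax_logloss_grad_hess(scores, labels, num_class=7)
print(f"{loss:.4f}")  # 1.9459, i.e. log(7)
```

Starting from uniform scores, every row's loss is log(7). Later rounds reduce it by moving each row's scores against grad.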
How to migrate this tabular workload to a distributed setup using Ray on Anyscale
This tutorial walks through the end-to-end process of migrating a local XGBoost training pipeline to a distributed Ray cluster running on Anyscale.
The following steps make that transition:
Store local data as remote data
Store the raw data as Parquet in a shared cloud directory and load it using Ray Data, which streams and shards the dataset across workers automatically.
Convert single-process training to multi-worker training
Define a custom train_func, then let Ray Train launch distributed training workers (two workers with four CPUs each in this tutorial) and run xgb.train in parallel, each on its own data shard.
Configure Ray for automated fault tolerance
With RayTrainReportCallback and CheckpointConfig, Ray saves checkpoints every 10 boosting rounds and can resume mid-training if a worker crashes or the job is relaunched.
Use Ray’s cluster-scale abstractions
Skip the boilerplate of manually slicing datasets, coordinating workers, or building launch scripts. Instead, declare intent (with ScalingConfig, RunConfig, and FailureConfig) and let Ray and Anyscale manage the execution.
Use offline inference
Run batch inference with Ray Data on CPU workers, which lets the same pipeline grow into large-scale production scoring.
This pattern turns a traditional single-node workflow into a scalable, resilient training pipeline with minimal code changes, and it works seamlessly on any cluster you provision through Anyscale.
1. Imports
Before you touch any data, import every tool you need.
Alongside the standard scientific-Python stack, bring in XGBoost for gradient-boosted decision trees and Ray for distributed data loading and training. Ray Train’s helper classes (RunConfig, ScalingConfig, CheckpointConfig, FailureConfig) give you fault-tolerant CPU training with almost no extra code.
# 00. Runtime setup
import os
import sys
import subprocess
# Enable the Ray Train V2 API (set this before ray.train is imported)
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
# Install Python dependencies
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "--no-cache-dir",
    "matplotlib==3.10.6",
    "scikit-learn==1.7.2",
    "pyarrow==14.0.2",
    "xgboost==3.0.5",
    "seaborn==0.13.2",
])
# 01. Imports
import os
import shutil
import json
import uuid
import tempfile
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import fetch_covtype
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
import xgboost as xgb
import pyarrow as pa
import ray
import ray.data as rd
from ray.data import ActorPoolStrategy
from ray.train import RunConfig, ScalingConfig, CheckpointConfig, FailureConfig, get_dataset_shard, get_checkpoint, get_context
from ray.train.xgboost import XGBoostTrainer, RayTrainReportCallback
2. Load the University of California, Irvine (UCI) Covertype dataset
The Covertype dataset contains ~580,000 forest-cover observations with 54 tabular features and a 7-class label. Fetch it with sklearn.datasets.fetch_covtype and rename the target column to label, which is the column name this tutorial later passes to the trainer as label_column. Then shift the classes from 1-7 to 0-6 so they’re zero-indexed, as XGBoost’s multi-class objectives expect. A quick value_counts sanity check confirms the mapping worked.
# 02. Load the UCI Cover type dataset (~580k rows, 54 features)
data = fetch_covtype(as_frame=True)
df = data.frame
df.rename(columns={"Cover_Type": "label"}, inplace=True)  # matches label_column passed to the trainer
df["label"] = df["label"] - 1 # 1-7 → 0-6
assert df["label"].between(0, 6).all()
print(df.shape, df.label.value_counts(normalize=True).head())
3. Visualize class balance
Highly imbalanced targets can bias tree-based models toward the majority classes, so plot the raw label counts. The cover type distribution is noticeably skewed: the two most common classes (Spruce/Fir and Lodgepole Pine) account for roughly 85% of rows. The bar chart lets you judge whether re-sampling or class weighting is warranted. This tutorial keeps XGBoost’s default, unweighted objective.
# 03. Visualize class distribution
df.label.value_counts().plot(kind="bar", figsize=(6,3), title="Cover Type distribution")
plt.ylabel("Frequency"); plt.show()
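If the chart suggests that class weighting is worth trying, one common approach is inverse-frequency ("balanced") sample weights. The NumPy sketch below computes them. You could then pass the weights to xgb.DMatrix through its weight argument. This is an optional extension, not part of the tutorial pipeline, and the toy labels are made up.

```python
import numpy as np


def balanced_sample_weights(labels: np.ndarray, num_class: int) -> np.ndarray:
    """Weight each row by n_rows / (num_class * count of its class), so every
    class contributes the same total weight."""
    counts = np.bincount(labels, minlength=num_class).astype(float)
    # Guard against classes that are absent from this sample.
    per_class = np.divide(
        len(labels), num_class * counts, out=np.zeros_like(counts), where=counts > 0
    )
    return per_class[labels]


# Made-up, skewed labels: class 1 dominates.
labels = np.array([1, 1, 1, 1, 1, 1, 0, 0, 2, 3])
w = balanced_sample_weights(labels, num_class=4)
# Total weight per class is now equal.
print(np.bincount(labels, weights=w, minlength=4))
```

Inside train_func, applying the weights would look like xgb.DMatrix(X, label=y, weight=w, feature_names=feature_cols).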
4. Write train / validation Parquet files
Rather than splitting a large dataset in memory later, you persist train and validation splits up front.
Each split is written to the cluster’s shared volume (/mnt/cluster_storage) so that all Ray workers can access it directly.
This approach keeps the workflow reproducible and avoids rematerializing the dataset during distributed training.
You perform a stratified 80 / 20 split to preserve class balance across splits, then write each subset to its own Parquet file.
Parquet is columnar and compressed, making it ideal for Ray Data ingestion and parallel reads.
# 04. Write separate train/val Parquets to /mnt/cluster_storage/covtype/
PARQUET_DIR = "/mnt/cluster_storage/covtype/parquet"
os.makedirs(PARQUET_DIR, exist_ok=True)
TRAIN_PARQUET = os.path.join(PARQUET_DIR, "train.parquet")
VAL_PARQUET = os.path.join(PARQUET_DIR, "val.parquet")
# Stratified 80/20 split for reproducibility
train_df, val_df = train_test_split(
df, test_size=0.2, random_state=42, stratify=df["label"]
)
train_df.to_parquet(TRAIN_PARQUET, index=False)
val_df.to_parquet(VAL_PARQUET, index=False)
print(f"Wrote Train → {TRAIN_PARQUET} ({len(train_df):,} rows)")
print(f"Wrote Val → {VAL_PARQUET} ({len(val_df):,} rows)")
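Stratification means each class is split 80 / 20 on its own, so class proportions match across the two splits. The following minimal NumPy sketch illustrates that idea. It simplifies what train_test_split(stratify=...) does and is not its exact algorithm. The labels are made up.

```python
import numpy as np


def stratified_split_indices(labels, test_size=0.2, seed=42):
    """Split row indices so every class keeps roughly the same proportion
    in the train and test parts."""
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        rng.shuffle(idx)
        n_test = int(round(test_size * len(idx)))
        test_idx.append(idx[:n_test])
        train_idx.append(idx[n_test:])
    return np.concatenate(train_idx), np.concatenate(test_idx)


# Made-up labels with a 70 / 20 / 10 class mix.
labels = np.array([0] * 70 + [1] * 20 + [2] * 10)
tr, te = stratified_split_indices(labels)
print(np.bincount(labels[tr]) / len(tr))  # class mix in the train part
print(np.bincount(labels[te]) / len(te))  # class mix in the test part
```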
5. Load the train and validation splits as Ray Datasets
Now that the data is stored in Parquet, you load each split directly with ray.data.read_parquet.
Each call returns a lazy, columnar Ray Dataset that supports distributed reads and transformations across the cluster.
Calling .random_shuffle() on the training split randomizes row order so each worker’s shard is a representative mix of the training data. Leaving the validation split unshuffled preserves its deterministic order for evaluation.
From this point forward, reads and transformations run in parallel across the cluster instead of on a single node.
# 05. Load the two splits as Ray Datasets (lazy, columnar)
train_ds = rd.read_parquet(TRAIN_PARQUET).random_shuffle()
val_ds = rd.read_parquet(VAL_PARQUET)
print(train_ds)
print(val_ds)
6. Inspect dataset sizes (optional)
After loading the Parquet files, confirm that both splits were read correctly by counting their rows. This step runs a distributed count across the cluster and verifies that the train / validation partitioning matches the expected 80 / 20 ratio before moving on to distributed training.
# 06. Count rows per split. count() executes the full pipeline, including the shuffle, so skip it at scale.
print(f"Train rows: {train_ds.count():,}, Val rows: {val_ds.count():,}")
7. Inspect a mini-batch
Taking a tiny pandas batch helps verify that feature columns and labels have the expected names and types. You also build feature_columns, a list you reuse when building XGBoost DMatrix objects for evaluation and batch inference.
# 07. Look into one batch to confirm feature dimensionality
batch = train_ds.take_batch(batch_size=5, batch_format="pandas")
print(batch.head())
feature_columns = [c for c in batch.columns if c != "label"]
8. Define the Ray Train worker loop (Arrow-based, memory-efficient)
Each Ray Train worker runs its own copy of train_func. Inside the loop, the worker pulls its shards of the train and validation datasets. You then:
Materialize each shard into a pyarrow.Table and drop any accidental index columns (like __index_level_0__) that might have been added during Parquet serialization.
Convert Arrow → NumPy → XGBoost DMatrix with explicit feature_names, ensuring consistent column order across all workers and splits.
Optionally resume from a prior checkpoint using get_checkpoint().
Train the booster with xgb.train, using the built-in RayTrainReportCallback to stream per-round metrics and periodic checkpoints back to Ray Train.
This design keeps data loading distributed (each worker materializes only its own shard) and avoids manual metric handling.
INDEX_COLS = {"__index_level_0__"}  # extend if needed
def _arrow_table_from_shard(name: str) -> pa.Table:
    """Collect this worker's Ray Dataset shard into one pyarrow.Table and
    drop accidental index columns (e.g., from pandas Parquet)."""
    ds_iter = get_dataset_shard(name)
    arrow_refs = ds_iter.materialize().to_arrow_refs()
    tables = [ray.get(r) for r in arrow_refs]
    tbl = pa.concat_tables(tables, promote_options="none") if tables else pa.table({})
    # Drop index columns if present
    keep = [c for c in tbl.column_names if c not in INDEX_COLS]
    if len(keep) != len(tbl.column_names):
        tbl = tbl.select(keep)
    return tbl
def _dmat_from_arrow(table: pa.Table, feature_cols, label_col: str):
    """Build XGBoost DMatrix from pyarrow.Table with explicit feature_names."""
    X = np.column_stack([table[c].to_numpy(zero_copy_only=False) for c in feature_cols])
    y = table[label_col].to_numpy(zero_copy_only=False)
    return xgb.DMatrix(X, label=y, feature_names=feature_cols)
def train_func(config):
    label_col = config["label_column"]
    # -------- 1) Materialize this worker's shards as Arrow tables ------------
    train_arrow = _arrow_table_from_shard("train")
    eval_arrow = _arrow_table_from_shard("evaluation")
    # Use the SAME ordered feature list for both splits
    feature_cols = [c for c in train_arrow.column_names if c != label_col]
    dtrain = _dmat_from_arrow(train_arrow, feature_cols, label_col)
    deval = _dmat_from_arrow(eval_arrow, feature_cols, label_col)
    # -------- 2) Optional resume from checkpoint ------------------------------
    ckpt = get_checkpoint()
    if ckpt:
        with ckpt.as_directory() as d:
            model_path = os.path.join(d, RayTrainReportCallback.CHECKPOINT_NAME)
            booster = xgb.Booster()
            booster.load_model(model_path)
        print(f"[Rank {get_context().get_world_rank()}] Resumed from checkpoint")
    else:
        booster = None
    # -------- 3) Train with per-round reporting & checkpointing ---------------
    evals_result = {}
    xgb.train(
        params=config["params"],
        dtrain=dtrain,
        evals=[(dtrain, "train"), (deval, "validation")],
        num_boost_round=config["num_boost_round"],
        xgb_model=booster,
        evals_result=evals_result,
        # Report metrics every round and save a checkpoint every 10 rounds.
        # frequency defaults to 0, which only checkpoints at the end of training.
        callbacks=[RayTrainReportCallback(frequency=10)],
    )
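Why derive feature_cols once and reuse it? If two splits store their columns in different orders, stacking columns by position would silently misalign features. The sketch below mimics _dmat_from_arrow but uses plain dicts of lists in place of Arrow tables, so it runs on its own. The column values are made up.

```python
import numpy as np


def to_feature_matrix(columns: dict, feature_cols: list) -> np.ndarray:
    """Stack named 1-D columns into a 2-D matrix in an explicit, fixed order."""
    return np.column_stack([np.asarray(columns[c]) for c in feature_cols])


# Hypothetical splits whose columns arrive in different orders.
train_cols = {"Elevation": [100, 110], "Slope": [1, 2], "label": [0, 1]}
val_cols = {"Slope": [3, 4], "label": [1, 0], "Elevation": [120, 130]}

# Derive the order once from the training split and reuse it everywhere.
feature_cols = [c for c in train_cols if c != "label"]
X_train = to_feature_matrix(train_cols, feature_cols)
X_val = to_feature_matrix(val_cols, feature_cols)

# Stacking val_cols in its own insertion order would put Slope in column 0.
naive_val = np.column_stack([val_cols[c] for c in val_cols if c != "label"])
print(feature_cols)     # ['Elevation', 'Slope']
print(X_val[:, 0])      # [120 130]: Elevation stays in column 0
print(naive_val[:, 0])  # [3 4]: Slope, misaligned with X_train
```

Passing feature_names to xgb.DMatrix adds a second safeguard, because Booster.predict validates feature names by default.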
9. Configure XGBoost and build the Trainer
Next, define the XGBoost hyperparameters and wrap the train_func in an XGBoostTrainer for distributed execution.
Each worker reserves CPUS_PER_WORKER CPUs (resources_per_worker={"CPU": CPUS_PER_WORKER}), and XGBoost uses those cores through the nthread parameter. Set CPUS_PER_WORKER to your node’s CPU count to give each worker an entire node.
Key settings:
ScalingConfig: controls how many workers to launch and their CPU/GPU allocation.
CheckpointConfig: ranks checkpoints by validation log-loss (validation-mlogloss, lower is better) and keeps only the top-ranked one (num_to_keep=1). How often checkpoints are written is set by the frequency argument of RayTrainReportCallback inside train_func, not by CheckpointConfig.
FailureConfig: restarts training once (max_failures=1) after a worker failure.
By passing the Ray Datasets directly into the trainer, Ray handles dataset sharding and distributed streaming automatically, so each worker trains on its own slice of the data without manual coordination.
# 09. XGBoost config and Trainer (full-node CPU workers)
# Adjust this to your node size if different (e.g., 16, 32, etc.)
CPUS_PER_WORKER = 4
xgb_params = {
    "objective": "multi:softprob",
    "num_class": 7,
    "eval_metric": "mlogloss",
    "tree_method": "hist",
    "eta": 0.3,
    "max_depth": 8,
    "nthread": CPUS_PER_WORKER,
}
trainer = XGBoostTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=2,
        use_gpu=False,
        resources_per_worker={"CPU": CPUS_PER_WORKER},
    ),
    datasets={"train": train_ds, "evaluation": val_ds},
    train_loop_config={
        "label_column": "label",
        "params": xgb_params,
        "num_boost_round": 50,
    },
    run_config=RunConfig(
        name="covtype_xgb_cpu",
        storage_path="/mnt/cluster_storage/covtype/results",
        checkpoint_config=CheckpointConfig(
            num_to_keep=1,
            checkpoint_score_attribute="validation-mlogloss",  # score by val loss
            checkpoint_score_order="min",
        ),
        failure_config=FailureConfig(max_failures=1),
    ),
)
10. Start distributed training
trainer.fit() blocks until all boosting rounds finish, or until Ray exhausts its retries. The returned Result holds the last reported metrics (result.metrics) and the latest checkpoint (result.checkpoint). Print the final validation log-loss and keep a handle to the checkpoint for inference.
# 10. Fit the trainer (reports eval metrics every boosting round)
result = trainer.fit()
print(f"Final validation log-loss: {result.metrics['validation-mlogloss']:.4f}")
best_ckpt = result.checkpoint  # latest checkpoint retained under CheckpointConfig
11. Evaluate the trained model
Pull the XGBoost Booster back from the checkpoint, run predictions on the entire validation set, and compute overall accuracy. Converting the Ray Dataset to pandas keeps the example short. In production you stream batches instead of materializing the whole frame.
# 11. Retrieve Booster object from Ray checkpoint
booster = RayTrainReportCallback.get_model(best_ckpt)
# Convert Ray Dataset to pandas for quick local scoring
val_pd = val_ds.to_pandas()
dmatrix = xgb.DMatrix(val_pd[feature_columns])
pred_prob = booster.predict(dmatrix)
pred_labels = np.argmax(pred_prob, axis=1)
acc = accuracy_score(val_pd.label, pred_labels)
print(f"Validation accuracy: {acc:.3f}")
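With objective="multi:softprob", booster.predict returns one probability per class for every row, shape (n_rows, 7) here. The code therefore takes an argmax to get class labels. The following NumPy sketch shows that step and the accuracy computation on made-up probabilities for three classes.

```python
import numpy as np

# Made-up softprob output: 4 rows x 3 classes, each row sums to 1.
pred_prob = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.5, 0.3],
    [0.4, 0.4, 0.2],  # tie: argmax picks the first maximal class (0)
])
true_labels = np.array([0, 2, 0, 1])

pred_labels = np.argmax(pred_prob, axis=1)      # most probable class per row
accuracy = np.mean(pred_labels == true_labels)  # fraction of exact matches
print(pred_labels, accuracy)                    # [0 2 1 0] 0.5
```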
12. Confusion matrix visualization
Raw counts and row-normalized ratios highlight which cover types the model confuses most often. Diagonal dominance indicates good performance; off-diagonal hot spots may suggest a need for more data or feature engineering for those specific classes.
# 12. Confusion matrix
cm = confusion_matrix(val_pd.label, pred_labels)  # rows: true class, columns: predicted class
sns.heatmap(cm, annot=True, fmt="d", cmap="viridis")
plt.title("Confusion Matrix with Counts")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()
cm_norm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
sns.heatmap(cm_norm, annot=True, fmt=".2f", cmap="viridis")
plt.title("Normalized Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()
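The row normalization divides each row of counts by that class’s total, so each row shows where one true class’s samples end up. The following NumPy sketch uses toy labels and assumes three classes. It builds the matrix with np.add.at, following the same rows-true / columns-predicted layout as sklearn. It also guards against a class that never appears in the true labels, where plain division would yield NaN.

```python
import numpy as np


def confusion_counts(y_true, y_pred, num_class):
    """cm[i, j] = number of rows with true class i predicted as class j."""
    cm = np.zeros((num_class, num_class), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm


def row_normalize(cm):
    """Divide each row by its total; rows with no samples stay all-zero."""
    totals = cm.sum(axis=1, keepdims=True).astype(float)
    return np.divide(cm, totals, out=np.zeros(cm.shape), where=totals > 0)


y_true = np.array([0, 0, 0, 1, 1, 0])
y_pred = np.array([0, 1, 0, 1, 0, 0])
cm = confusion_counts(y_true, y_pred, num_class=3)  # class 2 never occurs
print(cm)
print(row_normalize(cm))
```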
13. CPU batch inference with Ray Data
Use Ray Data for scalable, parallel inference. Each actor loads the trained model once and processes data batches in parallel. This avoids reloading the model for every batch and typically gives better throughput than launching ad-hoc remote tasks.
# 13. CPU batch inference with Ray Data
# Assumes: val_ds, feature_columns, best_ckpt already defined.
class XGBPredictor:
    """Stateful actor: load Booster once, reuse across batches."""
    def __init__(self, ckpt, feature_cols):
        self.model = RayTrainReportCallback.get_model(ckpt)
        self.feature_cols = feature_cols
    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        dmatrix = xgb.DMatrix(batch[self.feature_cols])
        probs = self.model.predict(dmatrix)
        preds = np.argmax(probs, axis=1)
        return pd.DataFrame(
            {"pred": preds.astype(np.int32), "label": batch["label"].astype(np.int32)}
        )
# Run XGBPredictor in a pool of actors; each actor loads the model once
pred_ds = val_ds.map_batches(
    XGBPredictor,
    fn_constructor_args=(best_ckpt, feature_columns),
    batch_format="pandas",
    compute=ActorPoolStrategy(),
    num_cpus=1,  # per-actor CPU; tune as needed
)
# Aggregate accuracy without collecting predictions to the driver
stats_ds = pred_ds.map_batches(
    lambda df: pd.DataFrame({
        "correct": [int((df["pred"].to_numpy() == df["label"].to_numpy()).sum())],
        "n": [int(len(df))],
    }),
    batch_format="pandas",
).materialize()  # run inference once; both sums below reuse the result
correct = int(stats_ds.sum("correct"))
n = int(stats_ds.sum("n"))
print(f"Validation accuracy (Ray Data inference): {correct / n:.3f}")
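The aggregation sums per-batch correct counts and row counts before dividing, rather than averaging per-batch accuracies. The following plain-Python sketch, with made-up batch results, shows why. Batches can have different sizes, and a simple mean of batch accuracies would over-weight small batches.

```python
# Made-up per-batch results: (number correct, batch size).
batch_stats = [(90, 100), (45, 50), (1, 10)]

# What the Ray Data pipeline does: sum counts, then divide once.
correct = sum(c for c, _ in batch_stats)
n = sum(size for _, size in batch_stats)
pooled_accuracy = correct / n

# Naive alternative: average each batch's accuracy with equal weight.
mean_of_batch_accuracies = sum(c / size for c, size in batch_stats) / len(batch_stats)

print(f"pooled: {pooled_accuracy:.3f}")                             # 0.850
print(f"mean of batch accuracies: {mean_of_batch_accuracies:.3f}")  # 0.633
```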
14. Feature-importance diagnostics
XGBoost’s built-in get_score(importance_type="gain") reports each feature’s average gain over all the splits that use it. Visualizing the top 15 helps connect model behavior back to domain knowledge. For example, elevation and soil type often dominate forest-cover prediction.
# 14. Gain-based feature importance
importances = booster.get_score(importance_type="gain")
keys, gains = zip(*sorted(importances.items(), key=lambda kv: kv[1], reverse=True)[:15])
plt.barh(range(len(gains)), gains)
plt.yticks(range(len(gains)), keys)
plt.gca().invert_yaxis()
plt.title("Top-15 Feature Importances (gain)"); plt.xlabel("Average gain"); plt.show()
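Soil type and wilderness area arrive as many one-hot columns, so their importance is split across dozens of features. One way to compare them with single columns such as elevation is to sum per-column total gain by feature group. Summing total gain, rather than average gain, keeps the group total meaningful. The pure-Python sketch below assumes scores from booster.get_score(importance_type="total_gain") and one-hot column names ending in _<index> (for example Soil_Type_3). The gain values are made up.

```python
import re
from collections import defaultdict


def group_gain(scores: dict) -> dict:
    """Sum per-column gain into groups, collapsing one-hot columns such as
    'Soil_Type_3' into 'Soil_Type'. Columns without a numeric suffix keep
    their own name."""
    groups = defaultdict(float)
    for name, gain in scores.items():
        base = re.sub(r"_\d+$", "", name)  # strip a trailing _<index>
        groups[base] += gain
    return dict(sorted(groups.items(), key=lambda kv: kv[1], reverse=True))


# Made-up total-gain scores for illustration only.
scores = {
    "Elevation": 500.0,
    "Soil_Type_3": 120.0,
    "Soil_Type_9": 200.0,
    "Soil_Type_21": 250.0,
    "Wilderness_Area_0": 80.0,
    "Horizontal_Distance_To_Roadways": 150.0,
}
print(group_gain(scores))
```

Here the three soil columns together outrank elevation, even though each one alone scores lower.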
15. Continue training from the latest checkpoint
Because train_func always checks get_checkpoint() and the trainer reuses the same run name and storage path, calling trainer.fit() again resumes boosting from the latest checkpoint and trains num_boost_round (50) more rounds. Call fit() a second time and print the new validation log-loss.
# 15. Run 50 more training iterations from the last saved checkpoint
result = trainer.fit()
print(f"Validation log-loss after resuming: {result.metrics['validation-mlogloss']:.4f}")
best_ckpt = result.checkpoint  # latest checkpoint from the resumed run
16. Verify post-training inference
Rerun the Ray Data inference pipeline with the latest checkpoint to check whether the additional boosting rounds improved validation accuracy. This reuses the XGBPredictor class with a fresh actor pool that loads the new checkpoint.
# 16. Rerun Ray Data inference to verify improved accuracy after continued training
# Reuse the existing Ray Data inference setup with the latest checkpoint
pred_ds = val_ds.map_batches(
XGBPredictor,
fn_constructor_args=(best_ckpt, feature_columns),
batch_format="pandas",
compute=ActorPoolStrategy(),
num_cpus=1,
)
# Aggregate accuracy across all batches
stats_ds = pred_ds.map_batches(
    lambda df: pd.DataFrame({
        "correct": [int((df["pred"] == df["label"]).sum())],
        "n": [int(len(df))],
    }),
    batch_format="pandas",
).materialize()  # run inference once; both sums below reuse the result
correct = int(stats_ds.sum("correct"))
n = int(stats_ds.sum("n"))
print(f"Validation accuracy after continued training: {correct / n:.3f}")
17. Clean up
Finally, delete the /mnt/cluster_storage/covtype directory, which holds the Parquet splits, checkpoints, and Ray Train results from this tutorial. Clearing out old artifacts frees disk space and leaves your workspace clean for whatever comes next.
# 17. Optional cleanup to free space
ARTIFACT_DIR = "/mnt/cluster_storage/covtype"
if os.path.exists(ARTIFACT_DIR):
    shutil.rmtree(ARTIFACT_DIR)
    print(f"Deleted {ARTIFACT_DIR}")
Wrap up and next steps
You built a fast and fault-tolerant XGBoost training loop that runs on real data, scales across CPUs, recovers from worker failures, and supports batch inference, all inside a single notebook.
This tutorial demonstrates:
Using Ray Data to ingest, shuffle, and shard large tabular datasets across a cluster.
Defining a custom train_func that runs on Ray Train workers and resumes seamlessly from checkpoints.
Tracking per-round metrics and saving checkpoints with RayTrainReportCallback.
Leveraging Ray’s distributed execution model to evaluate and monitor models without manual orchestration.
Running CPU batch inference with Ray Data actors for scalable batch scoring.
Next steps
Below are a few directions you might explore to adapt or extend the pattern:
Early stopping and best iteration tracking
Add early_stopping_rounds=10 to xgb.train and log the best round.
Track the performance delta across resumed runs.
Hyperparameter sweeps
Wrap the trainer with Ray Tune and search over eta, max_depth, or subsample.
Use Tune’s built-in checkpoint pruning and log callbacks.
Feature engineering at scale
Create new features using Ray Dataset.map_batches, such as terrain interactions or log-scaled distances.
Materialize multiple Parquet shards and benchmark load time.
Model interpretability
Use XGBoost’s built-in Booster.get_score for feature attributions.
Rank features by importance and validate with domain knowledge.
Serving the model
Package the Booster as a Ray task or Ray Serve endpoint.
Deploy an API that takes a feature vector and returns the predicted cover type.
Real-time logging
Integrate with MLflow or Weights & Biases to store logs, plots, and checkpoints.
Use tags and metadata to track experiments over time.
Alternative objectives
Try a binary objective (for example, presence versus absence of a species) or regression target (for example, canopy height).
Fine-tune loss functions for specific ecological tasks.
End-to-end MLOps
Schedule retraining with Ray Jobs or Anyscale Jobs.
Upload new data snapshots and trigger daily training runs with automatic checkpoint cleanup.
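The early-stopping idea above can be modeled with a short pure-Python sketch. XGBoost documents that the validation metric must improve at least once every early_stopping_rounds rounds for training to continue. The sketch applies that rule and records the best round. It is a simplified illustration with made-up losses, not XGBoost’s implementation.

```python
def early_stop(val_losses, early_stopping_rounds):
    """Return (best_round, stopped_round) for a minimized metric, using
    0-based round indices. stopped_round is the last round that ran."""
    best_round, best_loss = 0, float("inf")
    for rnd, loss in enumerate(val_losses):
        if loss < best_loss:
            best_round, best_loss = rnd, loss
        elif rnd - best_round >= early_stopping_rounds:
            return best_round, rnd  # no improvement for N consecutive rounds
    return best_round, len(val_losses) - 1


# Made-up validation log-loss: improves, then plateaus.
losses = [0.90, 0.70, 0.60, 0.58, 0.59, 0.60, 0.61, 0.62, 0.63, 0.64]
print(early_stop(losses, early_stopping_rounds=3))  # (3, 6)
```

When early stopping is enabled in xgb.train, the returned booster’s best_iteration attribute records the best round.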