Diffusion policy pattern#
This notebook builds a mini diffusion-policy pipeline on a real Pendulum-v1 offline dataset and runs it end-to-end on an Anyscale cluster with Ray Train.
Learning objectives#
How to use Ray Data to stream and preprocess Gymnasium rollouts in parallel across CPU workers.
How to scale training across multiple A10G GPUs using `TorchTrainer` with a minimal `LightningModule`.
How to checkpoint every epoch for robust fault tolerance and auto-resume.
How to log and visualize metrics using Ray’s built-in results and observability tooling.
How to generate actions from a trained policy directly in-notebook, with no need to repackage or redeploy.
How to run the full pipeline on Anyscale Workspaces with no infrastructure setup or cluster config required.
What problem are you solving? (Inverted Pendulum, Diffusion-Style)#
You’re training a policy to swing up and balance an inverted pendulum — a classic control problem.
In the Gym Pendulum-v1 environment, the agent sees the current state of the pendulum and must decide what torque to apply at the pivot.
What’s a policy?#
A policy is a function that maps the current state to an action:
\[
u_k = \pi(s_k)
\]
Here:
The state \(s_k\) describes where the pendulum is and how fast it’s moving
The action \(u_k\) is the torque you apply to influence future motion
The goal is to learn a policy that keeps the pendulum upright by generating the right torque at every step
Environment state and action#
At each timestep:
| Symbol | Dim | Meaning |
|---|---|---|
| \(\theta_{k}\) | scalar | Angle of the pendulum |
| \(\dot\theta_{k}\) | scalar | Angular velocity |
| \(u_{k}\) | scalar | Torque applied to the base |
The pendulum starts hanging down and must swing up and maintain balance.
Encode the state as:
\[
s_k = [\cos\theta_k,\ \sin\theta_k,\ \dot\theta_k]
\]
This avoids angle discontinuities (no \(\pm\pi\) jumps) and keeps the angle components in \([-1, 1]\).
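For intuition, here's a minimal sketch (not part of the pipeline) of how a raw angle and angular velocity map to the three-element observation that Pendulum-v1 returns:

# Sketch: encode (angle, angular velocity) as the Pendulum-v1 observation vector
import numpy as np

def encode_state(theta: float, theta_dot: float) -> np.ndarray:
    """Return [cos(theta), sin(theta), theta_dot] as a float32 vector."""
    return np.array([np.cos(theta), np.sin(theta), theta_dot], dtype=np.float32)

# The pendulum hanging straight down (theta = pi, at rest) encodes to [-1, 0, 0].
print(encode_state(np.pi, 0.0))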
1. Dataset tuples#
Train on a log of actions from a random policy, then inject artificial noise to simulate the diffusion process. Sample Gaussian noise
\[
\epsilon_k \sim \mathcal{N}(0, 1)
\]
and construct a noisy action:
\[
\tilde{u}_k = u_k + \epsilon_k
\]
2. Training objective#
Train a model \(f_\theta\) to predict the injected noise, given the state, the noisy action, and the timestep:
\[
\mathcal{L}(\theta) = \big\| f_\theta(s_k, \tilde{u}_k, t) - \epsilon_k \big\|^2
\]
Minimizing this loss teaches the model to de-noise \(\tilde{u}_{k}\) back toward the expert action \(u_k\).
3. Reverse diffusion (sampling)#
At inference time, start from noise \(x_T \sim \mathcal{N}(0, 1)\) and de-noise step by step:
\[
x_{t-1} = x_t - \lambda \, f_\theta(s_k, x_t, t), \qquad t = T, \dots, 1
\]
where \(\lambda\) is a small step size (0.1 in the sampler below). After \(T\) steps,
\[
u_k = x_0
\]
is a valid torque for the current state, a sample from your learned diffusion policy.
How to scale this policy learning workload using Ray on Anyscale#
This tutorial shows how to take a local PyTorch + Gymnasium workflow and migrate it to a fully distributed, fault-tolerant Ray pipeline running on Anyscale with minimal code changes.
Here’s how the transition works:
Gym rollouts → Ray Dataset
Generate simulation rollouts from `Pendulum-v1` and stream them directly into a Ray Dataset, enabling distributed preprocessing (for example, normalization) and automatic partitioning across workers.

Local Training → Cluster-scale Distributed Training
Wrap a minimal `LightningModule` in a Ray Train `train_loop`, then launch training with `TorchTrainer` across eight A10G GPUs. Ray handles data sharding, worker setup, and device placement without boilerplate.

Manual State Saving → Lightning-Integrated Checkpointing & Auto-Resume
Checkpointing and metric tracking are handled automatically by PyTorch Lightning and Ray Train V2. The `RayTrainReportCallback()` forwards Lightning's `checkpoint.ckpt` files and logged metrics to Ray, enabling structured, fault-tolerant training with seamless resume support. No manual save or report logic required.

Ad-hoc Coordination → Declarative Orchestration
Replace manual logging, retry logic, and resource management with Ray-native configs (`ScalingConfig`, `CheckpointConfig`, `FailureConfig`), letting Ray + Anyscale own the orchestration.

Notebook-only Inference → Cluster-aware Evaluation
After training, perform reverse diffusion sampling in-notebook using the latest checkpoint; a sketch after the in-notebook sampling step (section 9) shows how the same evaluation can scale out with Ray Data.
This flow upgrades a local notebook into a multi-node, resilient training + inference pipeline, using Ray’s native abstractions and running seamlessly inside an Anyscale Workspace, without sacrificing dev agility.
1. Imports and setup#
Standard scientific-Python stack, plus Ray for distributed data/training and Lightning for ergonomic model training.
# 00. Runtime setup
import os
import sys
import subprocess
# Non-secret env var
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
# Install Python dependencies
subprocess.check_call([
sys.executable, "-m", "pip", "install", "--no-cache-dir",
"torch==2.8.0",
"matplotlib==3.10.6",
"lightning==2.5.5",
"pyarrow==14.0.2",
])
# 01. Imports
# Standard Python packages for math, plotting, and data handling
import os
import shutil
import glob
import json
import uuid
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import gymnasium as gym
# Ray libraries for distributed data and training
import ray
import ray.data
from ray.train.lightning import RayLightningEnvironment
from ray.train import ScalingConfig, RunConfig, FailureConfig, CheckpointConfig, get_context, get_checkpoint, report, Checkpoint
from ray.train.torch import TorchTrainer
# PyTorch Lightning and base PyTorch for model definition and training
import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader
from torch import nn
2. Generate a real pendulum dataset#
Roll out a random policy for 10,000 steps, logging:
| field | shape | description |
|---|---|---|
| `obs` | (3,) | pendulum state \([\cos\theta,\ \sin\theta,\ \dot\theta]\) |
| `noisy_action` | (1,) | ground-truth action + Gaussian noise |
| `noise` | (1,) | the injected noise (supervision target) |
| `timestep` | scalar | random diffusion step ∈ [0, 999] |
You wrap the list of dicts in a Ray Dataset for automatic sharding.
# 02. Generate Pendulum offline dataset
def make_pendulum_dataset(n_steps: int = 10_000):
"""
Roll out a random policy in Pendulum-v1 and log (obs, noisy_action, noise, timestep).
Returns a Ray Dataset ready for sharding.
"""
env = gym.make("Pendulum-v1")
obs, _ = env.reset(seed=0)
data = []
for _ in range(n_steps):
action = env.action_space.sample().astype(np.float32) # shape (1,)
noise = np.random.randn(*action.shape).astype(np.float32)
noisy_action = action + noise # add Gaussian noise
timestep = np.random.randint(0, 1000, dtype=np.int64)
data.append(
{
"obs": obs.astype(np.float32), # shape (3,)
"noisy_action": noisy_action, # shape (1,)
"noise": noise, # shape (1,)
"timestep": timestep,
}
)
# Step environment
obs, _, terminated, truncated, _ = env.step(action)
if terminated or truncated:
obs, _ = env.reset()
return ray.data.from_items(data)
ds = make_pendulum_dataset()
3. Normalize and split#
Pendulum states lie roughly in [–π, π].
Scale to [–1, 1], then shuffle and split 80 / 20 into train and val shards.
All transformations execute in parallel across the Ray cluster.
# 03. Normalize and split (vector obs ∈ [-π, π])
# Normalize observation values to roughly [-1, 1] for training
def normalize(batch):
# Pendulum observations are roughly in [-π, π] → scale to [-1, 1]
batch["obs"] = batch["obs"] / np.pi
return batch
# Apply normalization in parallel using Ray Data
ds = ds.map_batches(normalize, batch_format="numpy")
# Count total number of items (triggers actual execution)
total = ds.count()
print("Total dataset size:", total)
# Shuffle and split dataset into 80% training and 20% validation
split_idx = int(total * 0.8)
ds = ds.random_shuffle()
train_ds, val_ds = ds.split_at_indices([split_idx])
print("Train size:", train_ds.count())
print("Val size:", val_ds.count())
4. DiffusionPolicy LightningModule#
A tiny MLP that predicts the injected noise ϵ given:
3D normalized state
1D noisy action
scalar timestep (normalized by `max_t`)
It logs per-epoch losses so you can plot later.
# 04. DiffusionPolicy for low-dim observation (3D) and action (1D)
class DiffusionPolicy(pl.LightningModule):
"""Tiny MLP that predicts injected noise ϵ given (obs, noisy_action, timestep)."""
def __init__(self, obs_dim: int = 3, act_dim: int = 1, max_t: int = 1000):
super().__init__()
self.max_t = max_t
# 3D obs + 1D action + 1 timestep → 1D noise
self.net = nn.Sequential(
nn.Linear(obs_dim + act_dim + 1, 128),
nn.ReLU(),
nn.Linear(128, act_dim),
)
self.loss_fn = nn.MSELoss()
# ---------- forward ----------
def forward(self, obs, noisy_action, timestep):
t = timestep.view(-1, 1).float() / self.max_t
x = torch.cat([obs, noisy_action, t], dim=1)
return self.net(x)
# ---------- shared loss ----------
def _shared_step(self, batch):
pred = self.forward(
batch["obs"].float(),
batch["noisy_action"],
batch["timestep"],
)
return self.loss_fn(pred, batch["noise"])
# ---------- training / validation ----------
def training_step(self, batch, batch_idx):
loss = self._shared_step(batch)
self.log("train_loss", loss, on_epoch=True, prog_bar=False, sync_dist=True)
return loss
def validation_step(self, batch, batch_idx):
loss = self._shared_step(batch)
self.log("val_loss", loss, on_epoch=True, prog_bar=False, sync_dist=True)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
5. Distributed Train loop with checkpointing#
This per-worker function demonstrates:
Ray Data to PyTorch: `iter_torch_batches()` each epoch
Lightning-on-Ray: a single-GPU Lightning trainer per worker
Fault tolerance: resume from the latest Ray Train checkpoint
Checkpointing: saves metadata and metrics every epoch
# 05. Ray Train Lightning-native training loop
def train_loop(config):
import os, tempfile, torch, warnings
import lightning.pytorch as pl
from ray.train import get_checkpoint, get_context
from ray.train.lightning import (
RayLightningEnvironment,
RayDDPStrategy,
RayTrainReportCallback,
prepare_trainer,
)
warnings.filterwarnings(
"ignore", message="barrier.*using the device under current context"
)
# ---- Ray Dataset shards → iterable torch batches ----
train_ds = ray.train.get_dataset_shard("train")
val_ds = ray.train.get_dataset_shard("val")
train_loader = train_ds.iter_torch_batches(batch_size=config.get("batch_size", 32))
val_loader = val_ds.iter_torch_batches(batch_size=config.get("batch_size", 32))
# ---- Model ----
model = DiffusionPolicy()
# ---- Local scratch for PL checkpoints (Ray will persist to storage_path) ----
CKPT_ROOT = os.path.join(tempfile.gettempdir(), "ray_pl_ckpts")
os.makedirs(CKPT_ROOT, exist_ok=True)
# ---- Lightning Trainer configured for Ray ----
trainer = pl.Trainer(
max_epochs=config.get("epochs", 10),
devices="auto",
accelerator="auto",
strategy=RayDDPStrategy(),
plugins=[RayLightningEnvironment()],
callbacks=[
RayTrainReportCallback(), # forwards metrics + ckpt to Ray
pl.callbacks.ModelCheckpoint( # local PL checkpoints each epoch
dirpath=CKPT_ROOT,
filename="epoch-{epoch:03d}",
every_n_epochs=1,
save_top_k=-1,
save_last=True,
),
],
default_root_dir=CKPT_ROOT,
enable_progress_bar=False,
check_val_every_n_epoch=1,
)
# ---- Prepare trainer for Ray environment ----
trainer = prepare_trainer(trainer)
# ---- Resume from Ray checkpoint if available ----
ckpt_path = None
ckpt = get_checkpoint()
if ckpt:
with ckpt.as_directory() as d:
candidate = os.path.join(d, "checkpoint.ckpt")
if os.path.exists(candidate):
ckpt_path = candidate
if get_context().get_world_rank() == 0:
print(f"✅ Resuming from Lightning checkpoint: {ckpt_path}")
# ---- Run training (Lightning owns the loop) ----
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=ckpt_path)
6. Launch Ray TorchTrainer#
Eight A10G workers train in parallel.
RunConfig keeps the five most recent checkpoints and automatically restarts
up to three times on failure.
# 06. Launch distributed training with Ray TorchTrainer
trainer = TorchTrainer(
train_loop,
scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
datasets={"train": train_ds, "val": val_ds},
run_config=RunConfig(
name="pendulum_diffusion_ft",
storage_path="/mnt/cluster_storage/pendulum_diffusion/pendulum_diffusion_results",
checkpoint_config=CheckpointConfig(
num_to_keep=5,
checkpoint_score_attribute="epoch",
checkpoint_score_order="max",
),
failure_config=FailureConfig(max_failures=3),
),
)
result = trainer.fit()
print("Training complete →", result.metrics)
best_ckpt = result.checkpoint # latest Ray-managed Lightning checkpoint
7. Plot train / val loss#
Visualize convergence using Ray’s built-in metrics.
result.metrics_dataframe automatically collects the losses logged by Lightning each epoch.
# 07. Plot training and validation loss (Ray + Lightning integration)
df = result.metrics_dataframe
print(df.head()) # optional sanity check
if "train_loss" not in df.columns or "val_loss" not in df.columns:
raise ValueError("train_loss / val_loss missing. Did you log them via self.log()?")
plt.figure(figsize=(7, 4))
plt.plot(df["epoch"], df["train_loss"], marker="o", label="Train")
plt.plot(df["epoch"], df["val_loss"], marker="o", label="Val")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title("Pendulum Diffusion - Loss per Epoch (Ray Train + Lightning)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
8. Reverse diffusion helper#
Iteratively de-noise a random action vector over 50 steps into a feasible Pendulum command.
# 08. Reverse diffusion sampling for 1-D action
# Function to simulate reverse diffusion process
def sample_action(model, obs, n_steps=50, device="cpu"):
"""
Runs reverse diffusion starting from noise to generate a Pendulum action.
obs: torch.Tensor of shape (3,)
returns: torch.Tensor of shape (1,)
"""
model.eval()
with torch.no_grad():
obs = obs.unsqueeze(0).to(device) # [1, 3]
obs = obs / np.pi # Same normalization used in training
x = torch.randn(1, 1).to(device) # Start from noise in action space
for step in reversed(range(n_steps)):
t = torch.tensor([step], device=device)
pred_noise = model(obs, x, t)
x = x - pred_noise * 0.1
return x.squeeze(0)
9. Sample an action from the trained policy#
Finally, load the latest epoch checkpoint, supply a sample state
[cos θ = 1, sin θ = 0, θ̇ = 0], and generate a 1-D torque command.
# 09. In-notebook sampling from trained model (Ray Lightning checkpoint)
# A plausible pendulum state: [cos(theta), sin(theta), theta_dot]
obs_sample = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32) # shape (3,)
assert best_ckpt is not None, "No checkpoint found — did training complete successfully?"
# Load the trained model from Ray's latest Lightning checkpoint
model = DiffusionPolicy(obs_dim=3, act_dim=1)
with best_ckpt.as_directory() as ckpt_dir:
# RayTrainReportCallback saves a file named "checkpoint.ckpt"
ckpt_file = os.path.join(ckpt_dir, "checkpoint.ckpt")
if not os.path.exists(ckpt_file):
# Fallback: search any .ckpt file if name differs
candidates = glob.glob(os.path.join(ckpt_dir, "*.ckpt"))
ckpt_file = candidates[0] if candidates else None
assert ckpt_file is not None, f"No Lightning checkpoint found in {ckpt_dir}"
    state = torch.load(ckpt_file, map_location="cpu", weights_only=False)  # Lightning ckpts store more than plain tensors
model.load_state_dict(state.get("state_dict", state), strict=False)
# Move to device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# Run reverse diffusion sampling
action = sample_action(model, obs_sample, n_steps=50, device=device)
print("Sampled action:", action)
10. Clean up#
When you’re finished, release Ray resources and clear any temporary files.
This ensures the cluster is ready for other jobs and avoids unnecessary storage costs.
# 10. Cleanup -- delete checkpoints and metrics from model training
TARGET_PATH = "/mnt/cluster_storage/pendulum_diffusion"
if os.path.exists(TARGET_PATH):
shutil.rmtree(TARGET_PATH)
print(f"✅ Deleted everything under {TARGET_PATH}")
else:
print(f"⚠️ Path does not exist: {TARGET_PATH}")
Wrap up and next steps#
You transformed a synthetic control demo into a Ray-native, real-data pipeline, training a diffusion policy across multiple GPUs, surviving worker restarts, and sampling feasible actions, all within a distributed Ray environment.
This tutorial demonstrates:
Logging continuous-control trajectories directly into a Ray Dataset for scalable preprocessing.
Streaming data into a Ray Train workload using Ray Data and Lightning with minimal integration overhead.
Saving structured checkpoints automatically through Lightning + Ray Train callbacks, ensuring seamless fault-tolerant recovery.
Running reverse diffusion sampling directly in-notebook.
Next steps#
The following are a few directions you can explore to extend or adapt this workload:
Evaluate in the environment
Load the best checkpoint, deploy the policy in Gym's Pendulum-v1, and log episode returns (see the rollout sketch after this item).
Compare against a behavior-cloning baseline or TD3/TD3+BC.
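A minimal rollout sketch, assuming the `model` and `sample_action` objects from sections 8 and 9 are in scope; the episode length and seed are arbitrary:

# Sketch: roll out the trained policy in Pendulum-v1 and log the episode return
import gymnasium as gym
import torch

env = gym.make("Pendulum-v1")
obs, _ = env.reset(seed=42)
device = "cuda" if torch.cuda.is_available() else "cpu"
total_reward = 0.0
for _ in range(200):  # Pendulum-v1 episodes run for 200 steps
    obs_t = torch.as_tensor(obs, dtype=torch.float32)
    action = sample_action(model, obs_t, n_steps=50, device=device)
    obs, reward, terminated, truncated, _ = env.step(action.cpu().numpy())
    total_reward += float(reward)
    if terminated or truncated:
        break
print("Episode return:", total_reward)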
Larger and richer datasets
Generate 100 k+ steps with a scripted controller or collect data from a learned agent.
Swap in other classic-control tasks like CartPole or MountainCar.
Model and loss upgrades
Add timestep embeddings or a small transformer for better temporal reasoning; a sinusoidal-embedding sketch follows below.
Experiment with different noise schedules or auxiliary consistency losses.
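A minimal sketch of that timestep-embedding idea, assuming you keep the same (obs, noisy_action, timestep) inputs; the SinusoidalEmbedding module and its dimension are illustrative:

# Sketch: sinusoidal timestep embedding to replace the raw scalar t
import math
import torch
from torch import nn

class SinusoidalEmbedding(nn.Module):
    """Map an integer diffusion step to a fixed sinusoidal feature vector."""
    def __init__(self, dim: int = 16, max_t: int = 1000):
        super().__init__()
        self.dim = dim
        self.max_t = max_t

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        t = t.view(-1, 1).float()
        # Geometrically spaced frequencies, as in standard diffusion/transformer embeddings
        freqs = torch.exp(
            torch.arange(0, self.dim, 2, device=t.device).float() * (-math.log(self.max_t) / self.dim)
        )
        angles = t * freqs                                                  # [B, dim/2]
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)    # [B, dim]

# Inside DiffusionPolicy.forward you would then concatenate the embedding instead of t,
# and widen the first Linear layer to obs_dim + act_dim + dim:
# x = torch.cat([obs, noisy_action, self.t_embed(timestep)], dim=1)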
Hyperparameter sweeps
Wrap the training loop in Ray Tune and grid-search learning rate, hidden size, or diffusion steps (sketched below).
Use Tune’s automatic checkpoint pruning to keep only the top-N runs.
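One hedged sketch of that wrapper, using the long-standing pattern of passing the TorchTrainer from section 6 to tune.Tuner; the exact Tune-Train integration varies by Ray version (especially with Train V2 enabled), and the search space here is only an example:

# Sketch: hyperparameter sweep over the train_loop config (illustrative search space)
from ray import tune

param_space = {
    "train_loop_config": {
        "batch_size": tune.choice([32, 64, 128]),
        "epochs": 10,
    }
}

tuner = tune.Tuner(
    trainer,  # the TorchTrainer defined in section 6
    param_space=param_space,
    tune_config=tune.TuneConfig(metric="val_loss", mode="min", num_samples=8),
)
# results = tuner.fit()
# best = results.get_best_result(metric="val_loss", mode="min")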
Mixed precision and performance
Enable `torch.set_float32_matmul_precision('high')` to leverage A10G Tensor Cores (see the sketch below).
Profile GPU utilization across workers and tune batch size accordingly.
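A minimal sketch of where those knobs would go inside train_loop, showing only the arguments that change (everything else stays as in section 5); precision="16-mixed" is an assumption, not something the original loop sets:

# Sketch: TF32 matmuls + mixed precision inside train_loop (only the changed pieces shown)
torch.set_float32_matmul_precision("high")   # allow TF32 on A10G Tensor Cores

trainer = pl.Trainer(
    max_epochs=config.get("epochs", 10),
    accelerator="auto",
    devices="auto",
    precision="16-mixed",                    # Lightning automatic mixed precision
    strategy=RayDDPStrategy(),
    plugins=[RayLightningEnvironment()],
    callbacks=[RayTrainReportCallback()],
)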
Real robotics logs
Replace Pendulum with logs from a real robotic apparatus stored in Parquet; Ray Data shards them the same way.
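A hedged sketch of that swap; the S3 path and column names are hypothetical, and the rest of the pipeline (normalization, shuffling, splitting) stays exactly as in section 3:

# Sketch: read real robot logs from Parquet instead of simulating rollouts
import ray

ds = ray.data.read_parquet("s3://my-bucket/robot-logs/")  # expects obs, noisy_action, noise, timestep columns
ds = ds.map_batches(normalize, batch_format="numpy")      # reuse the normalization step from section 3
train_ds, val_ds = ds.random_shuffle().split_at_indices([int(ds.count() * 0.8)])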
Serving the policy
Export the trained MLP to TorchScript and deploy with Ray Serve for low-latency inference (a minimal deployment sketch follows).
Hook it to a real-time simulator or a web dashboard.
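A minimal Ray Serve sketch, assuming the DiffusionPolicy and sample_action definitions from this notebook and a hypothetical checkpoint path; it serves the eager model for brevity rather than a TorchScript export:

# Sketch: serve the trained policy behind an HTTP endpoint with Ray Serve
from ray import serve
import torch

@serve.deployment(num_replicas=1)
class PolicyServer:
    def __init__(self, ckpt_file: str):
        self.model = DiffusionPolicy(obs_dim=3, act_dim=1)
        state = torch.load(ckpt_file, map_location="cpu", weights_only=False)
        self.model.load_state_dict(state.get("state_dict", state), strict=False)
        self.model.eval()

    async def __call__(self, request):
        payload = await request.json()  # expects {"obs": [cos_theta, sin_theta, theta_dot]}
        obs = torch.tensor(payload["obs"], dtype=torch.float32)
        action = sample_action(self.model, obs)
        return {"action": action.tolist()}

# serve.run(PolicyServer.bind("/path/to/checkpoint.ckpt"))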
End-to-end MLOps
Track checkpoints and metrics with MLflow or Weights & Biases (see the logger sketch below).
Schedule nightly Ray jobs on Anyscale to retrain as new data arrives.
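For example, a hedged sketch of attaching an MLflow logger to the Lightning Trainer inside train_loop; it assumes the mlflow package is installed and uses a hypothetical tracking URI, with metrics then flowing to MLflow alongside Ray's results:

# Sketch: add an MLflow logger to the Lightning Trainer inside train_loop
from lightning.pytorch.loggers import MLFlowLogger

mlf_logger = MLFlowLogger(
    experiment_name="pendulum_diffusion",
    tracking_uri="file:/mnt/cluster_storage/mlflow",  # hypothetical tracking location
)
trainer = pl.Trainer(
    max_epochs=config.get("epochs", 10),
    accelerator="auto",
    devices="auto",
    logger=mlf_logger,                    # in addition to RayTrainReportCallback
    strategy=RayDDPStrategy(),
    plugins=[RayLightningEnvironment()],
    callbacks=[RayTrainReportCallback()],
)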