Diffusion policy pattern

   

This notebook builds a mini diffusion-policy pipeline on a real Pendulum-v1 offline dataset and runs it end-to-end on an Anyscale cluster with Ray Train.

Learning objectives

  • How to use Ray Data to load Gymnasium rollouts and preprocess them in parallel across CPU workers.

  • How to scale training across multiple A10G GPUs using TorchTrainer with a minimal LightningModule.

  • How to checkpoint every epoch for robust fault tolerance and auto-resume.

  • How to log and visualize metrics using Ray’s built-in results and observability tooling.

  • How to generate actions from a trained policy directly in-notebook, with no need to repackage or redeploy.

  • How to run the full pipeline on Anyscale Workspaces with no infrastructure setup or cluster config required.

What problem are you solving? (Inverted Pendulum, Diffusion-Style)

You’re training a policy to swing up and balance an inverted pendulum, a classic control problem.
In the Gymnasium Pendulum-v1 environment, the agent observes the pendulum’s current state and must decide what torque to apply at the pivot.


What’s a policy?

A policy is a function that maps the current state to an action:

\[ \pi_\theta(s_{k}) \;\longrightarrow\; u_{k} \]

Here:

  • The state \(s_k\) describes where the pendulum is and how fast it’s moving

  • The action \(u_k\) is the torque you apply to influence future motion

  • The goal is to learn a policy that keeps the pendulum upright by generating the right torque at every step
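To make the definition concrete, here is a minimal sketch of two hand-written policies with the same state-in, torque-out signature. `random_policy` mirrors the random data-collection policy used later in this notebook; `pd_policy` is a hypothetical proportional-derivative controller with illustrative, untuned gains that only makes sense near the upright position. Neither is the learned diffusion policy. The torque bound of 2.0 matches Pendulum-v1’s action space.

```python
import math
import random

MAX_TORQUE = 2.0  # Pendulum-v1 action space: Box(-2.0, 2.0, shape=(1,))


def random_policy(state, rng=random):
    """Ignore the state and sample a torque uniformly, like env.action_space.sample()."""
    return rng.uniform(-MAX_TORQUE, MAX_TORQUE)


def pd_policy(state, kp=10.0, kd=1.0):
    """Hypothetical PD controller that drives theta and theta_dot toward zero (upright).

    state = [cos(theta), sin(theta), theta_dot]; the gains are illustrative, not tuned.
    """
    cos_t, sin_t, theta_dot = state
    theta = math.atan2(sin_t, cos_t)                  # recover the angle in [-pi, pi]
    torque = -kp * theta - kd * theta_dot             # proportional-derivative feedback
    return max(-MAX_TORQUE, min(MAX_TORQUE, torque))  # respect the torque limit


upright = [1.0, 0.0, 0.0]                       # theta = 0, at rest
tilted = [math.cos(0.1), math.sin(0.1), 0.0]    # theta = 0.1 rad, at rest
print(pd_policy(upright))   # zero torque: already balanced
print(pd_policy(tilted))    # negative torque pushes the pendulum back toward upright
print(random_policy(upright, random.Random(0)))  # some torque in [-2, 2]
```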


Environment state and action

At each timestep:

| Symbol | Dim | Meaning |
|---|---|---|
| \(\theta_{k}\) | scalar | Angle of the pendulum |
| \(\dot\theta_{k}\) | scalar | Angular velocity |
| \(u_{k}\) | scalar | Torque applied to the base |

Each episode starts from a random angle and angular velocity, so the agent usually has to swing the pendulum up before it can balance it.

Encode the state as:

\[ s_{k} = [\cos\theta_{k},\ \sin\theta_{k},\ \dot\theta_{k}] \in \mathbb{R}^3 \]

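This avoids angle discontinuities (no \(\pm\pi\) jumps) and keeps the angle features \(\cos\theta_k\) and \(\sin\theta_k\) in \([-1, 1]\); the environment clips the angular velocity \(\dot\theta_k\) to \([-8, 8]\).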
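A minimal sketch of this encoding, with a hypothetical helper `encode_state`: the angles π and −π describe the same physical pose, and although the raw angles differ by 2π, the encoded states agree up to floating-point rounding.

```python
import math


def encode_state(theta, theta_dot):
    """Map (theta, theta_dot) to the 3-D observation [cos(theta), sin(theta), theta_dot]."""
    return [math.cos(theta), math.sin(theta), theta_dot]


a = encode_state(math.pi, 0.5)
b = encode_state(-math.pi, 0.5)

# The raw angles differ by 2*pi, but the encoded states are numerically identical.
print(abs(math.pi - (-math.pi)))               # about 6.283
print(max(abs(x - y) for x, y in zip(a, b)))   # tiny, on the order of 1e-16
```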


1. Dataset tuples

Train on a log of actions from a random policy, then inject artificial noise to simulate the diffusion process:

\[ \varepsilon_{k} \sim \mathcal{N}(0, 1), \quad t_{k} \sim \text{Uniform}\{0,\dots,T{-}1\} \]

and construct a noisy action:

\[ \tilde{u}_k = u_{k} + \varepsilon_{k} \]
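As a sketch, one training tuple can be built in plain Python. `make_training_tuple` is a hypothetical helper that mirrors what the dataset-generation code later in this notebook does with NumPy: draw ε and t independently, then add ε to the logged action.

```python
import random


def make_training_tuple(u, T=1000, rng=random):
    """Return (noisy_action, noise, timestep) for a logged scalar action u."""
    eps = rng.gauss(0.0, 1.0)   # epsilon_k ~ N(0, 1)
    t = rng.randrange(T)        # t_k ~ Uniform{0, ..., T-1}
    u_tilde = u + eps           # noisy action; the noise scale is independent of t
    return u_tilde, eps, t


rng = random.Random(0)
u_tilde, eps, t = make_training_tuple(1.5, rng=rng)
print(f"u_tilde={u_tilde:.3f}, eps={eps:.3f}, t={t}")
```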

2. Training objective

Train a model \(f_\theta\) to predict the injected noise, given the state, the noisy action, and the timestep:

\[ \mathcal{L} = \mathbb{E}_{s_{k},\varepsilon_k,t_{k}}\ \big\|f_\theta(s_k, \tilde{u}_k, t_{k}) - \varepsilon_k\big\|_2^2 \]

Minimizing this loss teaches the model to de-noise \(\tilde{u}_{k}\) back toward the logged action \(u_k\). Note that, unlike standard diffusion schedules, the noise scale here does not depend on \(t_k\).
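The objective is an ordinary mean-squared error between predicted and true noise. This sketch, with a hypothetical `noise_prediction_loss` helper, computes it for a batch and shows two reference points that help when reading the loss curves: a perfect predictor scores 0, and always predicting 0 scores the mean of ε², which is about 1 for standard-normal noise.

```python
import random


def noise_prediction_loss(pred_eps, true_eps):
    """Mean squared error over a batch of scalar noise predictions (like nn.MSELoss)."""
    assert len(pred_eps) == len(true_eps)
    return sum((p - e) ** 2 for p, e in zip(pred_eps, true_eps)) / len(true_eps)


rng = random.Random(0)
eps = [rng.gauss(0.0, 1.0) for _ in range(10_000)]

print(noise_prediction_loss(eps, eps))                # perfect predictor
print(noise_prediction_loss([0.0] * len(eps), eps))   # predict-zero baseline, near 1.0
```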


3. Reverse diffusion (sampling)

At inference time, start from noise \(x_T \sim \mathcal{N}(0, 1)\) and de-noise step by step:

\[ x_{t} = x_{t+1} - \eta \cdot f_\theta(s, x_{t+1}, t), \quad t = T{-}1, \dots, 0 \]

After \(T\) steps, the result

\[ x_0 \approx u^\star \]

is a torque for the current state: a sample from your learned diffusion policy. Pendulum-v1 accepts torques in \([-2, 2]\) and clips anything outside that range.
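To see the update rule in isolation, this sketch runs it with a stand-in noise predictor instead of a trained network. `toy_predictor` is a hypothetical oracle that returns the offset x − u⋆ from a fixed target torque u⋆, which is exactly the noise under the simplified ũ = u + ε setup. With η = 0.1 (the step size used by the sampling helper later in this notebook), each step removes 10% of the remaining offset, so after T = 50 steps the offset shrinks by a factor of 0.9^50 ≈ 0.005.

```python
import random


def toy_predictor(state, x, t, u_star=0.8):
    """Hypothetical oracle: the 'noise' in x is its offset from the target torque u_star."""
    return x - u_star


def reverse_diffusion(predict, state, T=50, eta=0.1, rng=random):
    """x_t = x_{t+1} - eta * f(s, x_{t+1}, t) for t = T-1, ..., 0, starting from x_T ~ N(0, 1)."""
    x = rng.gauss(0.0, 1.0)                 # x_T
    for t in reversed(range(T)):
        x = x - eta * predict(state, x, t)  # one de-noising step
    return x                                # x_0


rng = random.Random(0)
x0 = reverse_diffusion(toy_predictor, state=[1.0, 0.0, 0.0], rng=rng)
print(f"x_0 = {x0:.4f} (target 0.8)")
print(f"remaining fraction of the initial offset: {0.9 ** 50:.4f}")
```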


How to scale this policy learning workload using Ray on Anyscale

This tutorial shows how to take a local PyTorch + Gymnasium workflow and migrate it to a fully distributed, fault-tolerant Ray pipeline running on Anyscale with minimal code changes.

Here’s how the transition works:

  1. Gym rollouts → Ray Dataset
    Generate simulation rollouts from Pendulum-v1 and load them into a Ray Dataset, enabling distributed preprocessing (for example, normalization) and automatic partitioning across workers.

  2. Local Training → Cluster-scale Distributed Training
    Wrap a minimal LightningModule in a Ray Train train_loop, then launch training with TorchTrainer across eight A10G GPUs. Ray handles data sharding, worker setup, and device placement without boilerplate.

  3. Manual State Saving → Lightning-Integrated Checkpointing & Auto-Resume
    Checkpointing and metric tracking are handled automatically by PyTorch Lightning and Ray Train V2.
    The RayTrainReportCallback() forwards Lightning’s checkpoint.ckpt files and logged metrics to Ray,
    enabling structured, fault-tolerant training with seamless resume support. No manual save or report logic required.

  4. Ad-hoc Coordination → Declarative Orchestration
    Replace manual logging, retry logic, and resource management with Ray-native configs (ScalingConfig, CheckpointConfig, FailureConfig), letting Ray + Anyscale own the orchestration.

  5. Notebook-only Inference → Cluster-aware Evaluation
    After training, perform reverse diffusion sampling in-notebook using the latest checkpoint. The same sampling logic can scale to batch inference with Ray Data.

This flow upgrades a local notebook into a multi-node, resilient training + inference pipeline, using Ray’s native abstractions and running seamlessly inside an Anyscale Workspace, without sacrificing dev agility.

1. Imports and setup

Standard scientific-Python stack, plus Ray for distributed data/training and Lightning for ergonomic model training.

# 00. Runtime setup 
import os
import sys
import subprocess

# Non-secret env var 
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"

# Install Python dependencies (gymnasium is imported in the next cell)
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "--no-cache-dir",
    "torch==2.8.0",
    "matplotlib==3.10.6",
    "lightning==2.5.5",
    "pyarrow==14.0.2",
    "gymnasium",
])
# 01. Imports

# Standard Python packages for math, plotting, and data handling
import os
import shutil
import glob
import json
import uuid
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import gymnasium as gym

# Ray libraries for distributed data and training
import ray
import ray.data
from ray.train.lightning import RayLightningEnvironment  
from ray.train import ScalingConfig, RunConfig, FailureConfig, CheckpointConfig, get_context, get_checkpoint, report, Checkpoint
from ray.train.torch import TorchTrainer

# PyTorch Lightning and base PyTorch for model definition and training
import lightning.pytorch as pl
import torch
from torch.utils.data import DataLoader
from torch import nn

2. Generate a real pendulum dataset

Roll out a random policy for 10,000 steps, logging:

| field | shape | description |
|---|---|---|
| obs | (3,) | [cos θ, sin θ, θ̇] |
| noisy_action | (1,) | ground-truth action + Gaussian noise |
| noise | (1,) | the injected noise (supervision target) |
| timestep | () | random diffusion step ∈ [0, 999] |

You wrap the list of dicts in a Ray Dataset for automatic sharding.

# 02. Generate Pendulum offline dataset 

def make_pendulum_dataset(n_steps: int = 10_000):
    """
    Roll out a random policy in Pendulum-v1 and log (obs, noisy_action, noise, timestep).
    Returns a Ray Dataset ready for sharding.
    """
    env = gym.make("Pendulum-v1")
    obs, _ = env.reset(seed=0)
    data = []

    for _ in range(n_steps):
        action = env.action_space.sample().astype(np.float32)      # shape (1,)
        noise   = np.random.randn(*action.shape).astype(np.float32)
        noisy_action = action + noise                              # add Gaussian noise
        timestep = np.random.randint(0, 1000, dtype=np.int64)

        data.append(
            {
                "obs":        obs.astype(np.float32),              # shape (3,)
                "noisy_action": noisy_action,                      # shape (1,)
                "noise":        noise,                             # shape (1,)
                "timestep":     timestep,
            }
        )

        # Step environment
        obs, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            obs, _ = env.reset()

    return ray.data.from_items(data)

ds = make_pendulum_dataset()

3. Normalize and split

Pendulum observations are [cos θ, sin θ, θ̇], where cos θ and sin θ lie in [–1, 1] and θ̇ lies in [–8, 8].
Rescale them by a fixed factor of 1/π, then shuffle and split 80 / 20 into train and val shards. All transformations execute in parallel across the Ray cluster.

# 03. Normalize and split (fixed 1/π rescaling of [cos θ, sin θ, θ̇])

# Rescale observations for training; sample_action applies the same factor at inference
def normalize(batch):
    # cos/sin in [-1, 1] → about [-0.32, 0.32]; θ̇ in [-8, 8] → about [-2.55, 2.55]
    batch["obs"] = batch["obs"] / np.pi
    return batch

# Apply normalization in parallel using Ray Data
ds = ds.map_batches(normalize, batch_format="numpy")

# Count total number of items (triggers actual execution)
total = ds.count()
print("Total dataset size:", total)

# Shuffle and split dataset into 80% training and 20% validation
split_idx = int(total * 0.8)
ds = ds.random_shuffle()
train_ds, val_ds = ds.split_at_indices([split_idx])

print("Train size:", train_ds.count())
print("Val size:", val_ds.count())
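Dividing every feature by π is a fixed rescaling rather than a true [-1, 1] normalization: cos θ and sin θ shrink to about ±0.32, while θ̇, which Pendulum-v1 clips to [-8, 8], maps to about ±2.55. If you want each feature in [-1, 1], one minimal alternative, assuming those environment bounds, is per-feature scaling. `scale_obs` and `OBS_SCALE` are hypothetical names; if you adopt them, apply the same scaling inside `sample_action` so training and inference stay consistent.

```python
import math

# Assumed per-feature bounds for [cos(theta), sin(theta), theta_dot] (Pendulum-v1 max speed 8).
OBS_SCALE = (1.0, 1.0, 8.0)


def scale_obs(obs, scale=OBS_SCALE):
    """Divide each feature by its bound so every component lies in [-1, 1]."""
    return [x / s for x, s in zip(obs, scale)]


extreme = [1.0, 0.0, 8.0]                # fastest possible spin at the upright position
print([x / math.pi for x in extreme])    # current notebook scaling (divide by pi)
print(scale_obs(extreme))                # per-feature scaling
```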

4. DiffusionPolicy LightningModule

A tiny MLP that predicts the injected noise ϵ given:

  • 3D normalized state

  • 1D noisy action

  • scalar timestep (normalized by max_t)

It logs per-epoch losses so you can plot later.

# 04. DiffusionPolicy for low-dim observation (3D) and action (1D)

class DiffusionPolicy(pl.LightningModule):
    """Tiny MLP that predicts injected noise ϵ given (obs, noisy_action, timestep)."""

    def __init__(self, obs_dim: int = 3, act_dim: int = 1, max_t: int = 1000):
        super().__init__()
        self.max_t = max_t

        # 3D obs + 1D action + 1 timestep → 1D noise
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim + 1, 128),
            nn.ReLU(),
            nn.Linear(128, act_dim),
        )
        self.loss_fn = nn.MSELoss()

    # ---------- forward ----------
    def forward(self, obs, noisy_action, timestep):
        t = timestep.view(-1, 1).float() / self.max_t
        x = torch.cat([obs, noisy_action, t], dim=1)
        return self.net(x)

    # ---------- shared loss ----------
    def _shared_step(self, batch):
        pred = self.forward(
            batch["obs"].float(),
            batch["noisy_action"],
            batch["timestep"],
        )
        return self.loss_fn(pred, batch["noise"])

    # ---------- training / validation ----------
    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
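For a sense of how tiny the network is, this worked count of trainable parameters follows directly from the layer sizes in `__init__` (5 inputs → 128 hidden units → 1 output). `mlp_param_count` is a hypothetical helper; on an instantiated model, `sum(p.numel() for p in model.parameters())` should give the same number, since ReLU and MSELoss have no parameters.

```python
def mlp_param_count(layer_sizes):
    """Count weights + biases of a fully connected stack (activations add no parameters)."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


obs_dim, act_dim = 3, 1
sizes = [obs_dim + act_dim + 1, 128, act_dim]  # [5, 128, 1], matching the nn.Sequential above
print(mlp_param_count(sizes))                  # 897 = (5*128 + 128) + (128*1 + 1)
```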

5. Distributed training loop with checkpointing

This per-worker function demonstrates:

  • Ray Data to PyTorch: iter_torch_batches() each epoch

  • Lightning-on-Ray: single-GPU trainer per worker

  • Fault tolerance: resume from the latest Ray Train checkpoint

  • Checkpointing: RayTrainReportCallback reports a Lightning checkpoint and the logged metrics to Ray every epoch

# 05. Ray Train Lightning-native training loop

def train_loop(config):
    import os, tempfile, torch, warnings
    import lightning.pytorch as pl
    from ray.train import get_checkpoint, get_context
    from ray.train.lightning import (
        RayLightningEnvironment,
        RayDDPStrategy,
        RayTrainReportCallback,
        prepare_trainer,
    )

    warnings.filterwarnings(
        "ignore", message="barrier.*using the device under current context"
    )

    # ---- Ray Dataset shards → iterable torch batches ----
    train_ds = ray.train.get_dataset_shard("train")
    val_ds   = ray.train.get_dataset_shard("val")
    train_loader = train_ds.iter_torch_batches(batch_size=config.get("batch_size", 32))
    val_loader   = val_ds.iter_torch_batches(batch_size=config.get("batch_size", 32))

    # ---- Model ----
    model = DiffusionPolicy()

    # ---- Local scratch for Lightning's ModelCheckpoint files ----
    # Not persisted by Ray: only checkpoints reported via RayTrainReportCallback reach storage_path.
    CKPT_ROOT = os.path.join(tempfile.gettempdir(), "ray_pl_ckpts")
    os.makedirs(CKPT_ROOT, exist_ok=True)

    # ---- Lightning Trainer configured for Ray ----
    trainer = pl.Trainer(
        max_epochs=config.get("epochs", 10),
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[
            RayTrainReportCallback(),       # forwards metrics + ckpt to Ray
            pl.callbacks.ModelCheckpoint(   # local PL checkpoints each epoch
                dirpath=CKPT_ROOT,
                filename="epoch-{epoch:03d}",
                every_n_epochs=1,
                save_top_k=-1,
                save_last=True,
            ),
        ],
        default_root_dir=CKPT_ROOT,
        enable_progress_bar=False,
        check_val_every_n_epoch=1,
    )

    # ---- Prepare trainer for Ray environment ----
    trainer = prepare_trainer(trainer)

    # ---- Resume from Ray checkpoint if available ----
    ckpt = get_checkpoint()
    if ckpt:
        # Call fit inside the context manager: with remote storage, the
        # directory can be a temporary download that is removed on exit.
        with ckpt.as_directory() as d:
            candidate = os.path.join(d, "checkpoint.ckpt")
            ckpt_path = candidate if os.path.exists(candidate) else None
            if ckpt_path and get_context().get_world_rank() == 0:
                print(f"✅ Resuming from Lightning checkpoint: {ckpt_path}")
            # ---- Run training (Lightning owns the loop) ----
            trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=ckpt_path)
    else:
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
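A quick back-of-the-envelope check of what each worker processes per epoch, assuming the 10,000-step dataset, the 80 / 20 split, eight workers, the default batch size of 32, and shards of roughly equal size (Ray Data may not split rows perfectly evenly):

```python
import math

total_rows, train_frac = 10_000, 0.8
num_workers, batch_size = 8, 32

train_rows = int(total_rows * train_frac)     # 8,000 training rows
val_rows = total_rows - train_rows            # 2,000 validation rows

train_per_worker = train_rows // num_workers  # about 1,000 rows per worker
val_per_worker = val_rows // num_workers      # about 250 rows per worker

# The last batch of each shard can be partial, hence the ceiling.
print(math.ceil(train_per_worker / batch_size))  # training batches per worker per epoch
print(math.ceil(val_per_worker / batch_size))    # validation batches per worker per epoch
```

With RayDDPStrategy, gradients are averaged across the eight workers, so each optimizer step effectively uses 8 × 32 = 256 samples.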

6. Launch Ray TorchTrainer

Eight A10G workers train in parallel.
CheckpointConfig keeps the five checkpoints with the highest reported epoch (the five most recent), and FailureConfig restarts training up to three times on failure.

# 06. Launch distributed training with Ray TorchTrainer

trainer = TorchTrainer(
    train_loop,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    datasets={"train": train_ds, "val": val_ds},
    run_config=RunConfig(
        name="pendulum_diffusion_ft",
        storage_path="/mnt/cluster_storage/pendulum_diffusion/pendulum_diffusion_results",
        checkpoint_config=CheckpointConfig(
            num_to_keep=5,
            checkpoint_score_attribute="epoch",
            checkpoint_score_order="max",
        ),
        failure_config=FailureConfig(max_failures=3),
    ),
)

result = trainer.fit()
print("Training complete →", result.metrics)
best_ckpt = result.checkpoint  # latest Ray-managed Lightning checkpoint
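To illustrate what `num_to_keep=5` with `checkpoint_score_attribute="epoch"` and `checkpoint_score_order="max"` retains, here is a plain-Python sketch of the selection rule. It is not Ray’s implementation; it assumes 10 epochs with a 0-based `epoch` value reported for each checkpoint, and the checkpoint names are hypothetical.

```python
def checkpoints_to_keep(reported, num_to_keep=5, score_attr="epoch", order="max"):
    """Return the reported checkpoints that survive pruning, best first."""
    ranked = sorted(reported, key=lambda c: c[score_attr], reverse=(order == "max"))
    return ranked[:num_to_keep]


reported = [{"epoch": e, "name": f"checkpoint_{e:06d}"} for e in range(10)]
kept = checkpoints_to_keep(reported)
print([c["epoch"] for c in kept])  # [9, 8, 7, 6, 5]: the five most recent epochs
```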

7. Plot train / val loss

Visualize convergence using Ray’s built-in metrics.
result.metrics_dataframe automatically collects the losses logged by Lightning each epoch.

# 07. Plot training and validation loss (Ray + Lightning integration)

df = result.metrics_dataframe
print(df.head())  # optional sanity check

if "train_loss" not in df.columns or "val_loss" not in df.columns:
    raise ValueError("train_loss / val_loss missing. Did you log them via self.log()?")

plt.figure(figsize=(7, 4))
plt.plot(df["epoch"], df["train_loss"], marker="o", label="Train")
plt.plot(df["epoch"], df["val_loss"], marker="o", label="Val")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.title("Pendulum Diffusion - Loss per Epoch (Ray Train + Lightning)")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()

8. Reverse diffusion helper

Starting from Gaussian noise, iteratively de-noise a 1-D action over 50 steps with a fixed step size of 0.1 to produce a Pendulum torque command.

# 08. Reverse diffusion sampling for 1-D action

# Function to simulate reverse diffusion process
def sample_action(model, obs, n_steps=50, device="cpu"):
    """
    Runs reverse diffusion starting from noise to generate a Pendulum action.
    obs: torch.Tensor of shape (3,)
    returns: torch.Tensor of shape (1,)
    """
    model.eval()
    with torch.no_grad():
        obs = obs.unsqueeze(0).to(device)      # [1, 3]
        obs = obs / np.pi                      # Same normalization used in training

        x = torch.randn(1, 1).to(device)       # Start from noise in action space

        for step in reversed(range(n_steps)):
            t = torch.tensor([step], device=device)
            pred_noise = model(obs, x, t)
            x = x - pred_noise * 0.1

        return x.squeeze(0)

9. Sample an action from the trained policy

Finally, load the latest epoch checkpoint, supply a sample state
[cos θ = 1, sin θ = 0, θ̇ = 0] (the pendulum upright and at rest), and generate a 1-D torque command.

# 09. In-notebook sampling from trained model (Ray Lightning checkpoint)

# A plausible pendulum state: [cos(theta), sin(theta), theta_dot]
obs_sample = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float32)   # shape (3,)

assert best_ckpt is not None, "No checkpoint found — did training complete successfully?"

# Load the trained model from Ray's latest Lightning checkpoint
model = DiffusionPolicy(obs_dim=3, act_dim=1)

with best_ckpt.as_directory() as ckpt_dir:
    # RayTrainReportCallback saves a file named "checkpoint.ckpt"
    ckpt_file = os.path.join(ckpt_dir, "checkpoint.ckpt")
    if not os.path.exists(ckpt_file):
        # Fallback: search any .ckpt file if name differs
        candidates = glob.glob(os.path.join(ckpt_dir, "*.ckpt"))
        ckpt_file = candidates[0] if candidates else None

    assert ckpt_file is not None, f"No Lightning checkpoint found in {ckpt_dir}"
    state = torch.load(ckpt_file, map_location="cpu")
    model.load_state_dict(state.get("state_dict", state), strict=False)

# Move to device
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Run reverse diffusion sampling
action = sample_action(model, obs_sample, n_steps=50, device=device)
print("Sampled action:", action)

10. Clean up

When you’re finished, delete the checkpoints and metrics that training wrote to cluster storage.
This frees shared storage for other jobs and avoids unnecessary storage costs.

# 10. Cleanup -- delete checkpoints and metrics from model training

TARGET_PATH = "/mnt/cluster_storage/pendulum_diffusion"

if os.path.exists(TARGET_PATH):
    shutil.rmtree(TARGET_PATH)
    print(f"✅ Deleted everything under {TARGET_PATH}")
else:
    print(f"⚠️ Path does not exist: {TARGET_PATH}")

Wrap up and next steps

You transformed a synthetic control demo into a Ray-native, real-data pipeline, training a diffusion policy across multiple GPUs, surviving worker restarts, and sampling feasible actions, all within a distributed Ray environment.

This tutorial demonstrates:

  • Logging continuous-control trajectories directly into a Ray Dataset for scalable preprocessing.

  • Streaming data into a Ray Train workload using Ray Data and Lightning with minimal integration overhead.

  • Saving structured checkpoints automatically through Lightning + Ray Train callbacks, ensuring seamless fault-tolerant recovery.

  • Running reverse diffusion sampling directly in-notebook.


Next steps

The following are a few directions you can explore to extend or adapt this workload:

  1. Evaluate in the environment

    • Load the best checkpoint, deploy the policy in Gymnasium’s Pendulum-v1, and log episode returns.

    • Compare against baseline behavior cloning or TD3/TD3+BC.

  2. Larger and richer datasets

    • Generate 100k+ steps with a scripted controller or collect data from a learned agent.

    • Swap in other classic-control tasks like CartPole or MountainCar.

  3. Model and loss upgrades

    • Add timestep embeddings or a small transformer for better temporal reasoning.

    • Experiment with different noise schedules or auxiliary consistency losses.

  4. Hyperparameter sweeps

    • Wrap the training loop in Ray Tune and grid-search learning rate, hidden size, or diffusion steps.

    • Use a Tune scheduler such as ASHA to stop weak trials early, and CheckpointConfig’s num_to_keep to retain only the top-N checkpoints per trial.

  5. Mixed precision and performance

    • Enable torch.set_float32_matmul_precision('high') to leverage A10G Tensor Cores.

    • Profile GPU utilization across workers and tune batch size accordingly.

  6. Real robotics logs

    • Replace Pendulum with logs from a real robotic apparatus stored in Parquet; Ray Data shards them the same way.

  7. Serving the policy

    • Export the trained MLP to TorchScript and deploy with Ray Serve for low-latency inference.

    • Hook it to a real-time simulator or a web dashboard.

  8. End-to-end MLOps

    • Track checkpoints and metrics with MLflow or Weights & Biases.

    • Schedule nightly Ray jobs on Anyscale to retrain as new data arrives.
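As a starting point for the noise-schedule direction under "Model and loss upgrades", here is a minimal sketch of the linear β schedule used in the original DDPM formulation (β rising from 1e-4 to 0.02 over T = 1000 steps). Unlike this notebook’s simplified ũ = u + ε, the forward process then depends on t, ũ = √ᾱₜ·u + √(1 − ᾱₜ)·ε, so late timesteps are almost pure noise. The helper names are hypothetical, and adopting this schedule would also require a matching reverse update in place of the fixed-step rule used here.

```python
import math
import random


def linear_beta_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_0 .. beta_{T-1}."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]


def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{i<=t} (1 - beta_i)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


def noisy_action(u, t, abar, rng=random):
    """DDPM-style forward noising; returns (u_tilde, eps) so eps stays the training target."""
    eps = rng.gauss(0.0, 1.0)
    return math.sqrt(abar[t]) * u + math.sqrt(1.0 - abar[t]) * eps, eps


abar = alpha_bars(linear_beta_schedule())
for t in (0, 250, 500, 999):
    print(f"t={t:4d}  signal weight sqrt(abar)={math.sqrt(abar[t]):.4f}")
```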