SAD: Soft Ancestor Diffusion: Optional Intermediate States for Masked Discrete Diffusion Language Models
SAD is a block-wise hierarchical diffusion model for unconditional text generation. It supports three variants:
- SAD (sad): block-wise diffusion with a learnable ancestor hierarchy
- Block Diffusion (block_diffusion): block-wise binary mask diffusion (no ancestors)
- AR (ar): standard autoregressive GPT-2 baseline
Directory Structure
sad/
├── configs/          # YAML config files
├── scripts/          # Entry-point scripts (training / inference / eval)
├── src/              # Core library
│   ├── models/       # Model definitions
│   ├── diffusion/    # Ancestor table & noisy state builder
│   ├── losses/       # Loss functions
│   ├── data/         # Data loaders
│   ├── eval/         # Sampling & metric utilities
│   └── utils/        # Helpers
├── data/             # Preprocessed data & caches (see below)
├── models/           # Downloaded pretrained models (see below)
├── tokenizers/       # Local GPT-2 tokenizer
├── outputs/          # Training checkpoints & logs
└── eval/             # Evaluation results
What to Put in data/ and models/
data/
| File / Subdir | What it is | How to get it |
|---|---|---|
| data/owt_cache/ | HuggingFace openwebtext cache | Auto-created on first training run |
| data/hierarchy_prototypes_*.pt | Hierarchy prototype tensors (k-means centroids) | Run scripts/build_hierarchy.py learn |
| data/ancestor_lut_*.pt | Sparse soft top-k ancestor lookup tables | Run scripts/build_hierarchy.py build_lut |
You do NOT need to commit data/ to git. Add it to .gitignore.
models/
| File / Subdir | What it is | How to get it |
|---|---|---|
| models/gpt2/ | GPT-2 base weights (for leaf embeddings & eval LM) | huggingface-cli download gpt2 --local-dir models/gpt2 |
| models/gpt2-large/ | GPT-2 Large (for MAUVE featurization & gen-PPL eval) | huggingface-cli download gpt2-large --local-dir models/gpt2-large |
You do NOT need to commit models/ to git. Add it to .gitignore.
tokenizers/
| File / Subdir | What it is | How to get it |
|---|---|---|
| tokenizers/gpt2/ | GPT-2 tokenizer files | huggingface-cli download gpt2 --local-dir tokenizers/gpt2 |
Scripts
Training
| Script | Purpose | Quick Start |
|---|---|---|
| scripts/train_sad.py | Train SAD (with ancestors) | torchrun --nproc_per_node=8 scripts/train_sad.py --config configs/sad_owt_b32_top3.yaml |
| scripts/train_block_diffusion.py | Train block-mask diffusion (no ancestors) | torchrun --nproc_per_node=8 scripts/train_block_diffusion.py --config configs/block_diffusion_owt_b32.yaml |
| scripts/train_ar.py | Train AR baseline | torchrun --nproc_per_node=8 scripts/train_ar.py --config configs/ar_owt.yaml |
All training scripts support:
- --config <path>: YAML config (required)
- --resume <path>: resume from checkpoint
- --num_steps <N>: override training steps
- --batch_size <N>: override per-GPU batch size
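The override flags above could be merged into the loaded config roughly as follows. This is a minimal sketch, not the repository's code: a plain dict stands in for the parsed YAML, the helper name apply_overrides is hypothetical, and mapping --num_steps and --batch_size onto training.num_steps and training.batch_size is an assumption based on the config skeleton.

```python
import argparse
import copy


def build_parser() -> argparse.ArgumentParser:
    # Flags as listed above; the argument types are assumptions.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    parser.add_argument("--resume", default=None)
    parser.add_argument("--num_steps", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    return parser


def apply_overrides(cfg: dict, args: argparse.Namespace) -> dict:
    """Return a copy of cfg with CLI overrides applied (hypothetical helper)."""
    cfg = copy.deepcopy(cfg)
    training = cfg.setdefault("training", {})
    if args.num_steps is not None:
        training["num_steps"] = args.num_steps
    if args.batch_size is not None:
        training["batch_size"] = args.batch_size
    return cfg


# Stand-in for the parsed YAML file.
base_cfg = {"training": {"num_steps": 1_000_000, "batch_size": 64, "lr": 3.0e-4}}
args = build_parser().parse_args(
    ["--config", "configs/sad_owt_b32_top3.yaml", "--num_steps", "500000"]
)
cfg = apply_overrides(base_cfg, args)
print(cfg["training"])
```

Flags that are not passed stay at None, so the YAML value is kept unless an override is given explicitly.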
Inference / Sampling
| Script | Purpose | Quick Start |
|---|---|---|
| scripts/inference_sad.py | Sample from SAD checkpoint | python scripts/inference_sad.py --config configs/sad_owt_b32_top3.yaml --checkpoint outputs/sad/latest.pt --num_samples 4 |
| scripts/inference_block_diffusion.py | Sample from block-diffusion checkpoint | python scripts/inference_block_diffusion.py --config configs/block_diffusion_owt_b32.yaml --checkpoint outputs/block_diffusion/latest.pt --num_samples 4 |
| scripts/inference_ar.py | Sample from AR checkpoint | python scripts/inference_ar.py --config configs/ar_owt.yaml --checkpoint outputs/ar_baseline/latest.pt --num_samples 4 |
Evaluation
| Script | Purpose | Quick Start |
|---|---|---|
| scripts/eval_gen_ppl.py | Generative perplexity of SAD / block-diffusion samples under GPT-2 Large | python scripts/eval_gen_ppl.py --checkpoint outputs/sad/latest.pt --model_type sad --num_samples 256 |
| scripts/eval_ar_gen_ppl.py | Generative perplexity of AR samples under GPT-2 Large | python scripts/eval_ar_gen_ppl.py --checkpoint outputs/ar_baseline/latest.pt --num_samples 256 |
| scripts/compute_mauve.py | Compute MAUVE score against OpenWebText | python scripts/compute_mauve.py --checkpoint outputs/sad/latest.pt --model_type sad --num_samples 5000 |
| scripts/compute_diversity.py | Compute distinct-n / repetition-n diversity metrics | python scripts/compute_diversity.py --input samples.json --output diversity.json |
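For orientation, distinct-n is commonly defined as the number of unique n-grams divided by the total number of n-grams. The sketch below assumes whitespace tokenization and pools n-grams over all samples; the function name is hypothetical and scripts/compute_diversity.py may tokenize or aggregate differently. Repetition-n is not covered here because its exact definition in this repo is not stated.

```python
def distinct_n(samples: list[str], n: int) -> float:
    """Unique n-grams divided by total n-grams, pooled over all samples."""
    total = 0
    unique = set()
    for text in samples:
        tokens = text.split()  # assumption: whitespace tokenization
        for i in range(len(tokens) - n + 1):
            unique.add(tuple(tokens[i:i + n]))
            total += 1
    return len(unique) / total if total else 0.0


samples = ["the cat sat on the mat", "the cat ran"]
d1 = distinct_n(samples, 1)  # 6 unique unigrams out of 9
d2 = distinct_n(samples, 2)  # 6 unique bigrams out of 7
print(f"distinct-1={d1:.3f} distinct-2={d2:.3f}")
```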
Hierarchy Construction (SAD only)
| Script | Purpose | Quick Start |
|---|---|---|
| scripts/build_hierarchy.py learn | Learn hierarchy prototypes via cosine k-means++ | python scripts/build_hierarchy.py learn --config configs/sad_owt_b32_top3.yaml |
| scripts/build_hierarchy.py build_lut | Build sparse soft top-k ancestor LUT | python scripts/build_hierarchy.py build_lut --config configs/sad_owt_b32_top3.yaml |
| scripts/build_hierarchy.py extend | Extend existing prototypes with deeper levels | python scripts/build_hierarchy.py extend --input data/hierarchy_50257-128.pt --levels 128,32,8 --output data/hierarchy_50257-128-32-8.pt |
| scripts/build_hierarchy.py merge | Merge per-level LUT files into one | python scripts/build_hierarchy.py merge --source 1:data/lut_top3.pt --source 2:data/lut_top2.pt --output data/lut_mixed.pt |
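As a rough illustration of what one row of a "sparse soft top-k" lookup table might hold, the sketch below scores a token embedding against each prototype by cosine similarity, keeps the k best, and turns the kept similarities into weights with a temperature softmax. The order of operations (top-k first, softmax second), the meaning of the temperature, and the function names are assumptions; build_lut may differ. Pure Python is used only to keep the example self-contained.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def soft_topk_row(token_vec, prototypes, k=3, temperature=1.0):
    """Return [(prototype_index, weight), ...] for the k most similar prototypes.

    Assumption: top-k selection happens first, then a softmax over the kept
    similarities so that the weights sum to 1.
    """
    sims = [cosine(token_vec, p) for p in prototypes]
    top = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]
    logits = [sims[i] / temperature for i in top]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    denom = sum(exps)
    return [(i, e / denom) for i, e in zip(top, exps)]


# Toy 2-D embeddings standing in for GPT-2 leaf embeddings and k-means centroids.
prototypes = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]]
row = soft_topk_row([1.0, 0.2], prototypes, k=3)
print(row)
```

Storing only k (index, weight) pairs per token is what keeps such a table sparse: with 50257 tokens and 128 prototypes, a top-3 table holds 3 entries per token instead of 128.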
Configs
| Config | Model | Description |
|---|---|---|
| configs/sad_owt_b32_top3.yaml | SAD | Base SAD on OpenWebText (512 tokens, 2 levels) |
| configs/sad_owt_b32_h2_mixed.yaml | SAD | batch_size=32, 2 ancestor levels, mixed top-k |
| configs/sad_owt_b32_h3_mixed.yaml | SAD | batch_size=32, 3 ancestor levels, mixed top-k |
| configs/sad_owt_b32_top1.yaml | SAD | top-1 ancestors |
| configs/sad_owt_b32_top2.yaml | SAD | top-2 ancestors |
| configs/sad_owt_b32_top3.yaml | SAD | top-3 ancestors |
| configs/block_diffusion_owt_b32.yaml | Block Diffusion | Mask-only block diffusion (no ancestors) |
| configs/ar_owt.yaml | AR | Standard autoregressive GPT-2 baseline |
Core Source Modules
| File | Role |
|---|---|
| src/models/sad_model.py | SADModel: backbone with block-diff attention mask & FlexAttention |
| src/models/ar_model.py | ARModel: standard causal decoder baseline |
| src/models/dit_components.py | DiT blocks, Rotary embeddings, EmbeddingLayer |
| src/diffusion/ancestor_table.py | AncestorTable: fixed LUT + learnable ancestor embeddings |
| src/diffusion/noisy_state.py | NoisyStateBuilder: sample a level per position & build noisy embeddings |
| src/losses/sad_loss.py | Leaf cross-entropy (an ancestor CE loss was not tested in experiments) |
| src/data/__init__.py | build_owt_dataloader, build_debug_dataloader |
| src/eval/metrics.py | Evaluation metric helpers |
1. SAD (scripts/train_sad.py)
Quick Start
# Single process
python scripts/train_sad.py --config configs/sad_owt_b32_top3.yaml

# 8 GPUs on one node
torchrun --nproc_per_node=8 scripts/train_sad.py \
    --config configs/sad_owt_b32_top3.yaml

# Resume from the latest checkpoint
torchrun --nproc_per_node=8 scripts/train_sad.py \
    --config configs/sad_owt_b32_top3.yaml \
    --resume outputs/sad/latest.pt

# Override training steps and per-GPU batch size
torchrun --nproc_per_node=8 scripts/train_sad.py \
    --config configs/sad_owt_b32_top3.yaml \
    --num_steps 500000 --batch_size 64
Training Paradigm
- The sequence is divided into blocks of block_size tokens; attention is bidirectional within a block and causal across blocks.
- Vectorized training: x_full = [noisy | clean], processed in a single forward pass with the block-diff mask:

| q → k | Allowed condition |
|---|---|
| noisy → noisy | Same block only |
| noisy → clean | Strictly earlier clean block |
| clean → clean | Same or earlier clean block |
| clean → noisy | Never |

- No timestep conditioning; only a single learnable cond_bias.
- Input is continuous embeddings (leaf / ancestor / mask), not discrete token ids.
- The loss is computed only on the noisy half: leaf cross-entropy. (Note: lambda_ancestor > 0 was not tested.)
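The four mask rules above can be written as a single predicate over query and key positions. The sketch below uses plain integers so it runs without PyTorch; the function name and the layout (noisy half at positions 0..L-1, clean half at L..2L-1) are assumptions that follow from x_full = [noisy | clean]. In a FlexAttention mask_mod(b, h, q_idx, kv_idx) the same logic would have to be expressed with tensor operations rather than Python branches.

```python
def block_diff_allowed(q_idx: int, kv_idx: int, seq_len: int, block_size: int) -> bool:
    """True if query position q_idx may attend to key position kv_idx.

    Positions 0..seq_len-1 are the noisy half, seq_len..2*seq_len-1 the clean half.
    """
    q_clean = q_idx >= seq_len
    k_clean = kv_idx >= seq_len
    q_block = (q_idx % seq_len) // block_size
    k_block = (kv_idx % seq_len) // block_size
    if not q_clean and not k_clean:
        return q_block == k_block   # noisy -> noisy: same block only
    if not q_clean and k_clean:
        return k_block < q_block    # noisy -> clean: strictly earlier block
    if q_clean and k_clean:
        return k_block <= q_block   # clean -> clean: same or earlier block
    return False                    # clean -> noisy: never


L, B = 4, 2  # toy sizes: 4 tokens, 2 tokens per block
rows = [
    "".join("x" if block_diff_allowed(q, k, L, B) else "." for k in range(2 * L))
    for q in range(2 * L)
]
print("\n".join(rows))
```

Each printed row is one query position and each column one key position, with x marking an allowed pair. Noisy tokens in the first block see only their own block, which matches the first block being generated without any clean context.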
Positional Encoding
Three-way additive + RoPE:
- block_idx_embed: cross-block position (AR)
- intra_pos_embed: intra-block position
- segment_embed: noisy (0) / clean (1)
- RoPE is applied to both halves with positions 0..L-1, so noisy[i] and clean[i] have relative offset 0.
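To make the indexing concrete, the index vectors for x_full = [noisy | clean] could be built as below. This is a sketch under assumptions: the helper name is hypothetical, and how the model actually looks up block_idx_embed, intra_pos_embed and segment_embed is not shown in this README.

```python
def position_ids(seq_len: int, block_size: int):
    """Index lists for x_full = [noisy | clean] (length 2 * seq_len)."""
    base = list(range(seq_len))
    block_idx = [p // block_size for p in base] * 2  # cross-block position
    intra_pos = [p % block_size for p in base] * 2   # position inside the block
    segment = [0] * seq_len + [1] * seq_len          # noisy = 0, clean = 1
    rope_pos = base * 2                              # both halves use 0..L-1
    return block_idx, intra_pos, segment, rope_pos


block_idx, intra_pos, segment, rope_pos = position_ids(seq_len=4, block_size=2)
print(block_idx, intra_pos, segment, rope_pos, sep="\n")
```

Because both halves reuse the same RoPE positions, only segment distinguishes noisy[i] from clean[i].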
Config Skeleton (configs/sad_owt_b32_top3.yaml)
model:
  vocab_size: 50257
  hidden_size: 768
  n_blocks: 12
  n_heads: 12
  max_seq_len: 512
  block_size: 8
  num_levels: 2
  level_sizes: [50257, 128]
ancestor:
  lut_path: data/ancestor_lut_50257-128_top3_t1.0.pt
  proto_path: data/hierarchy_prototypes_50257-128.pt
loss:
  lambda_ancestor: 0.0
  mask_only: false
training:
  seed: 0
  batch_size: 64
  num_steps: 1_000_000
  lr: 3.0e-4
  weight_decay: 0.01
  grad_clip: 1.0
  dtype: bf16
  log_interval: 100
  eval_interval: 5000
  save_interval: 10000
  compile: default
data:
  dataset: openwebtext
  seq_len: 512
  cache_dir: data/owt_cache
  mode: subsample
logging:
  use_wandb: true
  project: sad
  save_dir: outputs/sad
DDP Invariants (do not change without care)
- AncestorTable is outside DDP, but its parameters are in the optimizer. After loss.backward() and before gradient clipping, manually call dist.all_reduce(grad, op=AVG) to sync its gradients.
- AncestorTable initial-value broadcast: after build_ancestor_table, broadcast parameters and buffers from rank 0 to all ranks via dist.broadcast, to counteract the drift caused by set_seed(seed + local_rank).
- Streaming data sharding: build_owt_dataloader(shard_across_ranks=True) for train, False for val (val runs only on rank 0; sharding would bias metrics toward a single shard).
- FlexAttention ignores the padding mask; this is protected by an assert that pad_token_id == eos_token_id. Future multi-document packing would need document ids in mask_mod.
- Debug LUT cross-rank consistency: build_ancestor_table uses an independent torch.Generator for the random LUT, not the global RNG.
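The last invariant relies on a general pattern: draw from a dedicated, explicitly seeded generator so the result does not depend on how each rank seeded the global RNG. The sketch below illustrates the idea with the standard library's random.Random as a stand-in for torch.Generator; build_debug_lut and its arguments are hypothetical and do not reproduce build_ancestor_table.

```python
import random


def build_debug_lut(num_tokens: int, num_ancestors: int, seed: int = 0) -> list[int]:
    """Random token -> ancestor assignment drawn from a private generator."""
    gen = random.Random(seed)  # stdlib stand-in for torch.Generator
    return [gen.randrange(num_ancestors) for _ in range(num_tokens)]


luts = []
for local_rank in range(4):
    # Per-rank global seeding, analogous to set_seed(seed + local_rank).
    random.seed(1234 + local_rank)
    random.random()  # the global stream now differs between ranks
    luts.append(build_debug_lut(num_tokens=16, num_ancestors=4))

print(all(lut == luts[0] for lut in luts))
```

Had build_debug_lut drawn from the global random module instead, each simulated rank would generally end up with a different table.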
Output
- outputs/sad/ckpt_{step}.pt: periodic checkpoint
- outputs/sad/latest.pt: latest checkpoint (overwritten)
- outputs/sad/config.yaml: runtime config snapshot
- wandb: project=sad (when logging.use_wandb: true)

The checkpoint contains step / model / ancestor_table / optimizer / config / metrics; resume restores all of them. The data iterator position is not restored, so streaming data restarts from the beginning.
Evaluate
Runs on rank 0 only, every eval_interval steps, over num_batches=50 batches (see evaluate()). Non-main ranks wait for rank 0 at the next backward all-reduce, so a long eval reduces throughput but does not deadlock.