SAD: Soft Ancestor Diffusion: Optional Intermediate States for Masked Discrete Diffusion Language Models

SAD is a block-wise hierarchical diffusion model for unconditional text generation. This repository supports three model variants:

  • SAD (sad): block-wise diffusion with a learnable ancestor hierarchy
  • Block Diffusion (block_diffusion): block-wise binary mask diffusion (no ancestors)
  • AR (ar): standard autoregressive GPT-2 baseline

Directory Structure

sad/
├── configs/              # YAML config files
├── scripts/              # Entry-point scripts (training / inference / eval)
├── src/                  # Core library
│   ├── models/           # Model definitions
│   ├── diffusion/        # Ancestor table & noisy state builder
│   ├── losses/           # Loss functions
│   ├── data/             # Data loaders
│   ├── eval/             # Sampling & metric utilities
│   └── utils/            # Helpers
├── data/                 # Preprocessed data & caches (see below)
├── models/               # Downloaded pretrained models (see below)
├── tokenizers/           # Local GPT-2 tokenizer
├── outputs/              # Training checkpoints & logs
└── eval/                 # Evaluation results

What to Put in data/ and models/

data/

| File / Subdir | What it is | How to get it |
|---|---|---|
| data/owt_cache/ | Hugging Face openwebtext cache | Auto-created on first training run |
| data/hierarchy_prototypes_*.pt | Hierarchy prototype tensors (k-means centroids) | Run scripts/build_hierarchy.py learn |
| data/ancestor_lut_*.pt | Sparse soft top-k ancestor lookup tables | Run scripts/build_hierarchy.py build_lut |

You do NOT need to commit data/ to git. Add it to .gitignore.

models/

| File / Subdir | What it is | How to get it |
|---|---|---|
| models/gpt2/ | GPT-2 base weights (for leaf embeddings & eval LM) | huggingface-cli download gpt2 --local-dir models/gpt2 |
| models/gpt2-large/ | GPT-2 Large (for MAUVE featurization & gen-PPL eval) | huggingface-cli download gpt2-large --local-dir models/gpt2-large |

You do NOT need to commit models/ to git. Add it to .gitignore.

tokenizers/

| File / Subdir | What it is | How to get it |
|---|---|---|
| tokenizers/gpt2/ | GPT-2 tokenizer files | huggingface-cli download gpt2 --local-dir tokenizers/gpt2 |

Scripts

Training

| Script | Purpose | Quick Start |
|---|---|---|
| scripts/train_sad.py | Train SAD (with ancestors) | torchrun --nproc_per_node=8 scripts/train_sad.py --config configs/sad_owt_b32_top3.yaml |
| scripts/train_block_diffusion.py | Train block-mask diffusion (no ancestors) | torchrun --nproc_per_node=8 scripts/train_block_diffusion.py --config configs/block_diffusion_owt_b32.yaml |
| scripts/train_ar.py | Train AR baseline | torchrun --nproc_per_node=8 scripts/train_ar.py --config configs/ar_owt.yaml |

All training scripts support:

  • --config <path>: YAML config (required)
  • --resume <path>: resume from checkpoint
  • --num_steps <N>: override training steps
  • --batch_size <N>: override per-GPU batch size
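
A minimal sketch of how these four flags could be parsed and layered over the YAML config, using only argparse. The helper names build_parser and apply_overrides are hypothetical, and the assumption that --num_steps and --batch_size map onto training.num_steps and training.batch_size is inferred from the config skeleton, not taken from the scripts themselves.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser for the four shared flags (hypothetical helper, not the repo's code)."""
    parser = argparse.ArgumentParser(description="Shared training flags (sketch)")
    parser.add_argument("--config", required=True, help="path to YAML config")
    parser.add_argument("--resume", default=None, help="checkpoint to resume from")
    parser.add_argument("--num_steps", type=int, default=None, help="override training steps")
    parser.add_argument("--batch_size", type=int, default=None, help="override per-GPU batch size")
    return parser


def apply_overrides(cfg: dict, args: argparse.Namespace) -> dict:
    """Return a copy of cfg with CLI overrides applied.

    Assumption: both overrides live under cfg["training"].
    Flags left at None keep the value from the config file.
    """
    training = dict(cfg.get("training", {}))
    for key in ("num_steps", "batch_size"):
        value = getattr(args, key)
        if value is not None:
            training[key] = value
    return {**cfg, "training": training}


# Example: override the step count, keep the batch size from the config.
args = build_parser().parse_args(
    ["--config", "configs/sad_owt_b32_top3.yaml", "--num_steps", "500000"]
)
cfg = apply_overrides({"training": {"num_steps": 1_000_000, "batch_size": 64}}, args)
print(cfg["training"])
```

Defaulting the overrides to None makes it possible to tell a flag that was not passed apart from one that was explicitly set.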

Inference / Sampling

| Script | Purpose | Quick Start |
|---|---|---|
| scripts/inference_sad.py | Sample from SAD checkpoint | python scripts/inference_sad.py --config configs/sad_owt_b32_top3.yaml --checkpoint outputs/sad/latest.pt --num_samples 4 |
| scripts/inference_block_diffusion.py | Sample from block-diffusion checkpoint | python scripts/inference_block_diffusion.py --config configs/block_diffusion_owt_b32.yaml --checkpoint outputs/block_diffusion/latest.pt --num_samples 4 |
| scripts/inference_ar.py | Sample from AR checkpoint | python scripts/inference_ar.py --config configs/ar_owt.yaml --checkpoint outputs/ar_baseline/latest.pt --num_samples 4 |

Evaluation

| Script | Purpose | Quick Start |
|---|---|---|
| scripts/eval_gen_ppl.py | Generative perplexity of SAD / block-diffusion samples under GPT-2 Large | python scripts/eval_gen_ppl.py --checkpoint outputs/sad/latest.pt --model_type sad --num_samples 256 |
| scripts/eval_ar_gen_ppl.py | Generative perplexity of AR samples under GPT-2 Large | python scripts/eval_ar_gen_ppl.py --checkpoint outputs/ar_baseline/latest.pt --num_samples 256 |
| scripts/compute_mauve.py | Compute MAUVE score against OpenWebText | python scripts/compute_mauve.py --checkpoint outputs/sad/latest.pt --model_type sad --num_samples 5000 |
| scripts/compute_diversity.py | Compute distinct-n / repetition-n diversity metrics | python scripts/compute_diversity.py --input samples.json --output diversity.json |
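
For orientation, here is a small sketch of the distinct-n and repetition-n metrics under their common definitions: distinct-n is the ratio of unique to total n-grams, and repetition-n is one minus that ratio. Whether scripts/compute_diversity.py uses exactly these definitions, or whitespace tokenization as below, is an assumption; the function names are illustrative.

```python
def ngrams(tokens, n):
    """All contiguous n-grams of a token list, as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def distinct_n(tokens, n):
    """Unique n-grams divided by total n-grams (0.0 if the text is too short)."""
    grams = ngrams(tokens, n)
    return len(set(grams)) / len(grams) if grams else 0.0


def repetition_n(tokens, n):
    """Share of n-gram occurrences that duplicate another: 1 - distinct_n."""
    grams = ngrams(tokens, n)
    return 1.0 - distinct_n(tokens, n) if grams else 0.0


# Whitespace tokenization is used here only to keep the example self-contained.
sample = "the cat sat on the cat".split()
for n in (1, 2):
    print(n, distinct_n(sample, n), repetition_n(sample, n))
```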

Hierarchy Construction (SAD only)

| Script | Purpose | Quick Start |
|---|---|---|
| scripts/build_hierarchy.py learn | Learn hierarchy prototypes via cosine k-means++ | python scripts/build_hierarchy.py learn --config configs/sad_owt_b32_top3.yaml |
| scripts/build_hierarchy.py build_lut | Build sparse soft top-k ancestor LUT | python scripts/build_hierarchy.py build_lut --config configs/sad_owt_b32_top3.yaml |
| scripts/build_hierarchy.py extend | Extend existing prototypes with deeper levels | python scripts/build_hierarchy.py extend --input data/hierarchy_50257-128.pt --levels 128,32,8 --output data/hierarchy_50257-128-32-8.pt |
| scripts/build_hierarchy.py merge | Merge per-level LUT files into one | python scripts/build_hierarchy.py merge --source 1:data/lut_top3.pt --source 2:data/lut_top2.pt --output data/lut_mixed.pt |
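
To make "sparse soft top-k" concrete, the sketch below shows one plausible way to turn a token embedding and a set of prototypes into a LUT row: rank prototypes by cosine similarity, keep the k best, and normalize their similarities with a temperature softmax. This is an assumption about what build_lut computes, not a description of it; soft_topk_ancestors and its arguments are hypothetical names, and plain Python lists stand in for tensors.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def soft_topk_ancestors(embedding, prototypes, k=3, temperature=1.0):
    """Return (indices, weights) of the k most similar prototypes.

    Weights are a softmax over the selected similarities divided by temperature,
    so they are positive and sum to 1. Everything else is implicitly zero (sparse).
    """
    sims = [cosine(embedding, p) for p in prototypes]
    top = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:k]
    logits = [sims[i] / temperature for i in top]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return top, [e / total for e in exps]


prototypes = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
indices, weights = soft_topk_ancestors([1.0, 0.0], prototypes, k=2, temperature=1.0)
print(indices, weights)
```

Under this reading, a lower temperature concentrates the weight on the nearest prototype, and k=1 reduces to a hard assignment.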

Configs

| Config | Model | Description |
|---|---|---|
| configs/sad_owt_b32_top3.yaml | SAD | Base SAD on OpenWebText (512 tokens, 2 levels), top-3 ancestors |
| configs/sad_owt_b32_h2_mixed.yaml | SAD | batch_size=32, 2 ancestor levels, mixed top-k |
| configs/sad_owt_b32_h3_mixed.yaml | SAD | batch_size=32, 3 ancestor levels, mixed top-k |
| configs/sad_owt_b32_top1.yaml | SAD | top-1 ancestors |
| configs/sad_owt_b32_top2.yaml | SAD | top-2 ancestors |
| configs/block_diffusion_owt_b32.yaml | Block Diffusion | Mask-only block diffusion (no ancestors) |
| configs/ar_owt.yaml | AR | Standard autoregressive GPT-2 baseline |

Core Source Modules

| File | Role |
|---|---|
| src/models/sad_model.py | SADModel: backbone with block-diff attention mask & FlexAttention |
| src/models/ar_model.py | ARModel: standard causal decoder baseline |
| src/models/dit_components.py | DiT blocks, Rotary embeddings, EmbeddingLayer |
| src/diffusion/ancestor_table.py | AncestorTable: fixed LUT + learnable ancestor embeddings |
| src/diffusion/noisy_state.py | NoisyStateBuilder: sample level per position & build noisy embeddings |
| src/losses/sad_loss.py | Leaf cross-entropy loss (an ancestor CE loss was not evaluated experimentally) |
| src/data/__init__.py | build_owt_dataloader, build_debug_dataloader |
| src/eval/metrics.py | Evaluation metric helpers |

1. SAD (scripts/train_sad.py)

Quick Start

# Single GPU
python scripts/train_sad.py --config configs/sad_owt_b32_top3.yaml

# Multi-GPU DDP (e.g. 8 GPUs)
torchrun --nproc_per_node=8 scripts/train_sad.py \
    --config configs/sad_owt_b32_top3.yaml

# Resume
torchrun --nproc_per_node=8 scripts/train_sad.py \
    --config configs/sad_owt_b32_top3.yaml \
    --resume outputs/sad/latest.pt

# Override steps / batch size
torchrun --nproc_per_node=8 scripts/train_sad.py \
    --config configs/sad_owt_b32_top3.yaml \
    --num_steps 500000 --batch_size 64

Training Paradigm

  • The sequence is divided into blocks of block_size tokens; attention is bidirectional within a block and causal across blocks.
  • Vectorized training: x_full = [noisy | clean], processed in a single forward pass with the block-diff mask:

| q → k | Allowed condition |
|---|---|
| noisy → noisy | Same block only |
| noisy → clean | Strictly earlier clean block |
| clean → clean | Same or earlier clean block |
| clean → noisy | Never |

  • No timestep conditioning; only a single learnable cond_bias.
  • Input is continuous embeddings (leaf / ancestor / mask), not discrete token ids.
  • Loss is computed only on the noisy half, as leaf cross-entropy. (Note: settings with lambda_ancestor > 0 were not tested.)
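
The four attention rules can be written as a single predicate over query and key indices. The sketch below assumes that x_full places the noisy half at indices 0..L-1 and the clean half at L..2L-1; the function name block_diff_allowed is illustrative and not taken from src/models/sad_model.py.

```python
def block_diff_allowed(q_idx, kv_idx, seq_len, block_size):
    """True where query q_idx may attend to key kv_idx in x_full = [noisy | clean].

    Assumed layout: 0..seq_len-1 is the noisy half, seq_len..2*seq_len-1 the clean half.
    """
    q_noisy, k_noisy = q_idx < seq_len, kv_idx < seq_len
    q_clean, k_clean = q_idx >= seq_len, kv_idx >= seq_len
    # Block index of each position within its own half.
    q_block = (q_idx % seq_len) // block_size
    k_block = (kv_idx % seq_len) // block_size

    noisy_to_noisy = q_noisy & k_noisy & (q_block == k_block)  # same block only
    noisy_to_clean = q_noisy & k_clean & (k_block < q_block)   # strictly earlier clean block
    clean_to_clean = q_clean & k_clean & (k_block <= q_block)  # same or earlier clean block
    # clean -> noisy is never allowed, so it contributes no term.
    return noisy_to_noisy | noisy_to_clean | clean_to_clean


# Dense mask for a toy case: L = 4 tokens, block_size = 2, so x_full has 8 positions.
L, B = 4, 2
mask = [[block_diff_allowed(q, k, L, B) for k in range(2 * L)] for q in range(2 * L)]
for row in mask:
    print("".join("1" if allowed else "0" for allowed in row))
```

Because the predicate uses only comparisons, integer arithmetic and the & and | operators, the same body should also work elementwise on index tensors, which is the shape a FlexAttention mask_mod(b, h, q_idx, kv_idx) callback expects. Wiring it up that way is an assumption about the model code.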

Positional Encoding

Three additive embeddings plus RoPE:

  • block_idx_embed: cross-block position (AR)
  • intra_pos_embed: intra-block position
  • segment_embed: noisy (0) / clean (1)
  • RoPE is applied to both halves with positions 0..L-1, so noisy[i] and clean[i] have relative offset 0
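
As an illustration of the indices these four components would consume, the sketch below builds them for x_full = [noisy | clean]. The function name position_ids and the dictionary keys are hypothetical; only the index arithmetic follows from the description above.

```python
def position_ids(seq_len, block_size):
    """Index lists for the 2 * seq_len positions of x_full = [noisy | clean]."""
    base = list(range(seq_len))
    rope_pos = base + base  # both halves reuse positions 0..L-1
    return {
        "rope_pos": rope_pos,
        "block_idx": [p // block_size for p in rope_pos],  # cross-block position
        "intra_pos": [p % block_size for p in rope_pos],   # position inside the block
        "segment": [0] * seq_len + [1] * seq_len,          # noisy = 0, clean = 1
    }


ids = position_ids(seq_len=4, block_size=2)
for name, values in ids.items():
    print(name, values)
```

Since noisy[i] and clean[i] share the same RoPE position, block index and intra-block index, the segment embedding is the only positional signal that tells the two halves apart.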

Config Skeleton (configs/sad_owt_b32_top3.yaml)

model:
  vocab_size: 50257
  hidden_size: 768
  n_blocks: 12
  n_heads: 12
  max_seq_len: 512
  block_size: 8
  num_levels: 2
  level_sizes: [50257, 128]

ancestor:
  lut_path: data/ancestor_lut_50257-128_top3_t1.0.pt    # fallback: random debug LUT
  proto_path: data/hierarchy_prototypes_50257-128.pt

loss:
  lambda_ancestor: 0.0      # ancestor CE loss was not evaluated
  mask_only: false

training:
  seed: 0
  batch_size: 64          # per-GPU
  num_steps: 1_000_000
  lr: 3.0e-4
  weight_decay: 0.01
  grad_clip: 1.0
  dtype: bf16
  log_interval: 100
  eval_interval: 5000
  save_interval: 10000
  compile: default        # "off" to disable torch.compile

data:
  dataset: openwebtext    # "debug" uses TinyDebugDataset
  seq_len: 512
  cache_dir: data/owt_cache
  mode: subsample         # subsample | pack

logging:
  use_wandb: true
  project: sad
  save_dir: outputs/sad
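
A few relationships between these fields can be checked before launching a run. The sketch below mirrors the skeleton as a plain dict so it stays dependency-free. The constraints it asserts (sequence length divisible by block size, one level size per level, leaf level equal to the vocabulary, hidden size divisible by head count) are reasonable assumptions rather than checks confirmed to exist in the training scripts, and check_config is a hypothetical helper.

```python
def check_config(cfg):
    """Validate assumed invariants and return a few derived quantities."""
    model, data = cfg["model"], cfg["data"]
    assert data["seq_len"] <= model["max_seq_len"], "data.seq_len exceeds model.max_seq_len"
    assert data["seq_len"] % model["block_size"] == 0, "seq_len must be a multiple of block_size"
    assert len(model["level_sizes"]) == model["num_levels"], "one size per hierarchy level"
    assert model["level_sizes"][0] == model["vocab_size"], "level 0 is assumed to be the leaf vocab"
    assert model["hidden_size"] % model["n_heads"] == 0, "hidden_size must split evenly across heads"
    return {
        "blocks_per_sequence": data["seq_len"] // model["block_size"],
        "head_dim": model["hidden_size"] // model["n_heads"],
    }


cfg = {
    "model": {
        "vocab_size": 50257, "hidden_size": 768, "n_blocks": 12, "n_heads": 12,
        "max_seq_len": 512, "block_size": 8, "num_levels": 2, "level_sizes": [50257, 128],
    },
    "data": {"dataset": "openwebtext", "seq_len": 512},
}
derived = check_config(cfg)
print(derived)
```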

DDP Invariants (do not change without care)

  1. AncestorTable lives outside the DDP wrapper, but its parameters are in the optimizer. After loss.backward() and before gradient clipping, manually call dist.all_reduce(grad, op=dist.ReduceOp.AVG) to sync its gradients.
  2. AncestorTable initial-value broadcast: after build_ancestor_table, broadcast parameters and buffers from rank 0 to all ranks via dist.broadcast, to counteract the divergence caused by set_seed(seed + local_rank).
  3. Streaming data sharding: use build_owt_dataloader(shard_across_ranks=True) for train and False for val (val runs only on rank 0, so sharding would bias metrics toward a single shard).
  4. FlexAttention ignores the padding mask; this is guarded by an assert that pad_token_id == eos_token_id. Future multi-document packing would need document ids in mask_mod.
  5. Debug LUT cross-rank consistency: build_ancestor_table uses an independent torch.Generator for the random LUT, not the global RNG.
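
Invariant 3 can be illustrated with a round-robin split over a stream. This is only a sketch of the idea behind shard_across_ranks; how build_owt_dataloader actually shards is not shown here, and shard_stream is a hypothetical helper.

```python
from itertools import islice


def shard_stream(stream, rank, world_size, shard_across_ranks=True):
    """Yield every world_size-th item starting at rank, or the full stream if sharding is off."""
    if not shard_across_ranks:
        return iter(stream)  # validation on rank 0 sees the whole stream
    return islice(stream, rank, None, world_size)


world_size = 4
train_shards = [list(shard_stream(range(10), r, world_size)) for r in range(world_size)]
val_items = list(shard_stream(range(10), 0, world_size, shard_across_ranks=False))
print(train_shards)
print(val_items)
```

With sharding on, the ranks see disjoint items that together cover the stream. With sharding off, rank 0 evaluates on every item instead of one quarter of them.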

Output

  • outputs/sad/ckpt_{step}.pt: periodic checkpoint
  • outputs/sad/latest.pt: latest checkpoint (overwritten)
  • outputs/sad/config.yaml: runtime config snapshot
  • wandb: project=sad (when logging.use_wandb: true)

Each checkpoint contains step / model / ancestor_table / optimizer / config / metrics, and resume restores all of them. The data iterator position is not restored, so streaming data restarts from the beginning.

Evaluate

Evaluation runs on rank 0 only, every eval_interval steps, over num_batches=50 batches (see evaluate()). Non-main ranks wait for rank 0 at the next backward all-reduce, so a long evaluation reduces throughput but does not deadlock.
