diffusionlm-from-scratch: masked diffusion LM (DiT, 142M)
A masked (absorbing-state) diffusion language model, built and trained from
scratch on TinyStories. Instead of generating left-to-right one token at a time,
it starts from a sequence of pure [MASK] tokens and denoises the whole
sequence in parallel, committing the tokens it is most confident about first,
in whatever order the meaning falls into place.
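The generation procedure described above can be sketched as a confidence-ordered unmasking loop. This is a minimal illustration, not the repo's sampler. `toy_denoiser` is a hypothetical random stand-in for the DiT, and `MASK = -1` is a placeholder id. The schedule that commits an even share of the remaining masks at each step is also an assumption.

```python
import math
import random

MASK = -1  # hypothetical stand-in for the [MASK] token id


def toy_denoiser(tokens, vocab_size, rng):
    """Hypothetical stand-in for the DiT. It ignores `tokens` and returns a
    random (predicted_token, confidence) pair per position; the real model
    conditions on the tokens committed so far."""
    preds = []
    for _ in tokens:
        logits = [rng.gauss(0.0, 1.0) for _ in range(vocab_size)]
        top = max(logits)
        exps = [math.exp(x - top) for x in logits]
        z = sum(exps)
        probs = [e / z for e in exps]  # softmax over the toy vocabulary
        best = max(range(vocab_size), key=lambda v: probs[v])
        preds.append((best, probs[best]))
    return preds


def sample(seq_len, steps, vocab_size=16, seed=0):
    rng = random.Random(seed)
    tokens = [MASK] * seq_len  # start from pure [MASK]
    for step in range(steps):
        masked = [i for i, tok in enumerate(tokens) if tok == MASK]
        if not masked:
            break
        # spread the remaining masks evenly over the remaining steps
        k = math.ceil(len(masked) / (steps - step))
        preds = toy_denoiser(tokens, vocab_size, rng)  # one parallel pass over all positions
        # commit the k masked positions with the highest confidence
        masked.sort(key=lambda i: preds[i][1], reverse=True)
        for i in masked[:k]:
            tokens[i] = preds[i][0]
    return tokens


print(sample(seq_len=12, steps=4))
```

Each step costs one forward pass no matter how many tokens it commits. That is why fewer steps means faster, if rougher, generation.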
- Code, training & sampling: https://github.com/tchauffi/diffusionlm-from-scratch
- Course / write-up: `RESEARCH.md`, a from-scratch course on discrete/text diffusion (D3PM, absorbing-state diffusion, sampling).
- Demo site: animated real generations live in `docs/`.
Model
| Property | Value |
|---|---|
| Architecture | DiT (transformer denoiser), bidirectional attention, adaLN-Zero |
| Parameters | ~142M |
| Hidden size / depth / heads | 768 / 12 / 12 |
| MLP ratio | 4.0 |
| Vocab | 8,192 (byte-level BPE, trained on TinyStories) |
| Max sequence length | 256 |
| Diffusion | absorbing-state (masked) discrete diffusion |
| Training data | TinyStories |
| Eval cross-entropy | 2.18 |
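The forward (noising) half of absorbing-state diffusion is simple enough to write out. The sketch below assumes a linear schedule in which each token is independently replaced by [MASK] with probability t. `MASK = -1` and the function name `corrupt` are hypothetical.

```python
import random

MASK = -1  # hypothetical stand-in for the [MASK] token id


def corrupt(tokens, t, u):
    """Absorbing-state forward process at noise level t in [0, 1] (linear
    schedule assumed): position i becomes [MASK] iff u[i] < t, where u holds
    one uniform draw per position."""
    return [MASK if ui < t else tok for tok, ui in zip(tokens, u)]


rng = random.Random(0)
x = list(range(20))              # a toy "clean" sequence of token ids
u = [rng.random() for _ in x]    # shared uniforms make the masks nested in t
x_light = corrupt(x, 0.3, u)
x_heavy = corrupt(x, 0.8, u)
# Absorbing: every position masked at t=0.3 is still masked at t=0.8.
assert all(b == MASK for a, b in zip(x_light, x_heavy) if a == MASK)
```

[MASK] is the only noise state, and no token is ever swapped for another real token. As a result, the denoiser only has to fill in blanks, never correct them.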
Key finding: uniform loss weighting (w(t) = 1), not the textbook ELBO
weight 1/Ο(t), is what turned word-salad into coherent stories.
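To make the key finding concrete, here is a sketch of a per-sequence loss under both weightings. It assumes a linear schedule in which each token is masked with probability t. Under that assumed schedule, the textbook ELBO weight for absorbing-state diffusion reduces to 1/t. The function name and arguments are hypothetical.

```python
import math


def diffusion_loss(log_probs, targets, is_masked, t, weighting="uniform"):
    """Masked-diffusion training loss for one sequence.

    log_probs[i][v] is the model's log-probability of token v at position i.
    Cross-entropy is taken on masked positions only and averaged over the
    sequence length. weighting="elbo" applies 1/t (the ELBO weight under an
    assumed linear schedule); weighting="uniform" uses w(t) = 1.
    """
    ce = -sum(lp[y] for lp, y, m in zip(log_probs, targets, is_masked) if m)
    ce /= len(targets)
    if weighting == "elbo":
        return ce / t  # grows without bound as t -> 0
    return ce


log_probs = [[math.log(0.5), math.log(0.5)], [math.log(0.25), math.log(0.75)]]
uniform = diffusion_loss(log_probs, [0, 1], [True, False], t=0.5)
elbo = diffusion_loss(log_probs, [0, 1], [True, False], t=0.5, weighting="elbo")
print(uniform, elbo)
```

With w(t) = 1, every noise level contributes on the same scale. With 1/t, lightly masked samples (small t) are scaled up sharply.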
Files
- `final.pt`: checkpoint with two state dicts, `model` (EMA weights, preferred) and `raw` (the non-averaged training weights), plus the `config` used to build the model.
- `tokenizer.json`, `tokenizer_config.json`: the byte-level BPE tokenizer (`PreTrainedTokenizerFast`; special tokens `[PAD]`, `[UNK]`, `[MASK]`, `<|endoftext|>`).
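The relationship between the two state dicts can be sketched with plain floats standing in for tensors. The EMA weights are an exponential moving average of the raw training weights, and a loader picks `model` when it is available. `ema_update`, `pick_state_dict`, and `decay=0.999` are illustrative assumptions, not the repo's code or hyperparameters.

```python
def ema_update(ema, params, decay=0.999):
    """One EMA step, in place: ema <- decay * ema + (1 - decay) * params.
    decay=0.999 is an illustrative assumption, not the value used for final.pt."""
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value
    return ema


def pick_state_dict(ck, prefer_ema=True):
    """Choose which weights to load from a checkpoint dict laid out like final.pt."""
    if prefer_ema and "model" in ck:
        return ck["model"]  # EMA weights
    return ck["raw"]        # raw training weights


# Toy "state dicts" holding floats instead of tensors.
raw = {"w": 1.0}
ema = ema_update({"w": 0.0}, raw, decay=0.9)  # 0.9 * 0.0 + 0.1 * 1.0
ck = {"config": {}, "model": ema, "raw": raw}
print(pick_state_dict(ck))
```

Averaging over many steps smooths out the noise of individual updates. That is the usual reason EMA weights are preferred for sampling.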
Usage
Install the model code from the GitHub repo, then generate stories in a few lines. `DiffusionLM` bundles the model, tokenizer, and absorbing-state scheduler:
```python
from diffusionlm_from_scratch import DiffusionLM

lm = DiffusionLM.from_pretrained("tchauffi/diffusionlm-from-scratch")
for story in lm.generate(n=4, seq_len=80, temperature=0.9):
    print(story)
```
`generate` exposes the sampler knobs (`order`, `steps`, `corrector_frac`, `confidence_threshold`, …). For lower-level access, load just the model:
```python
from diffusionlm_from_scratch.model import DiT

model = DiT.from_pretrained("tchauffi/diffusionlm-from-scratch")  # downloads final.pt
# final.pt itself is a dict ck with ck["config"], ck["model"] (EMA), and ck["raw"].
```
See `scripts/capture_trajectories.py` in the repo for the full parallel-denoising sampling loop.
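One common way a confidence threshold is used in masked-diffusion samplers is sketched below. Instead of committing a fixed number of tokens, each step commits every masked position the model is sufficiently sure about. Whether `generate`'s `confidence_threshold` follows exactly this rule is an assumption, and `positions_to_commit` is a hypothetical helper.

```python
def positions_to_commit(confidences, is_masked, threshold):
    """Hypothetical threshold rule: commit every masked position whose
    confidence is at least `threshold`; if none qualify, commit the single
    most confident masked position so each step still makes progress."""
    candidates = [i for i, m in enumerate(is_masked) if m]
    if not candidates:
        return []
    chosen = [i for i in candidates if confidences[i] >= threshold]
    if not chosen:
        chosen = [max(candidates, key=lambda i: confidences[i])]
    return chosen


# Position 2 is already committed, so its high confidence is ignored.
print(positions_to_commit([0.3, 0.2, 0.9, 0.5], [True, True, False, True], 0.8))
```

Under this rule, a high threshold behaves close to one token per step. A low threshold commits many tokens per step and finishes in fewer forward passes.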