---
tags:
- text-diffusion
- machine-translation
- en-de
- masked-diffusion
- from-scratch
language:
- en
- de
datasets:
- wmt/wmt14
license: apache-2.0
---

# Text Diffusion Model for EN→DE Translation

A **masked discrete diffusion** model for English-to-German machine translation, trained from scratch on WMT14 EN-DE.

## Architecture

| Component | Detail |
|---|---|
| **Type** | Masked discrete diffusion |
| **Backbone** | DiT (Diffusion Transformer) with adaLN |
| **Parameters** | ~72M |
| **Blocks** | 12 DiT blocks |
| **Hidden dim** | 512, 8 attention heads |
| **Attention** | Bidirectional (no causal mask) with RoPE |
| **Conditioning** | Timestep via sinusoidal embeddings + adaLN; segment embeddings for src/tgt |
| **Weight tying** | Input embeddings tied to output projection |
| **Tokenizer** | [Helsinki-NLP/opus-mt-en-de](https://huggingface.co/Helsinki-NLP/opus-mt-en-de) (~58K vocab) |
| **Max sequence** | 128 src + 128 tgt tokens |

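To make the adaLN conditioning in the table concrete, here is a minimal PyTorch sketch of one DiT-style block: the timestep embedding supplies all of the LayerNorm scale/shift plus residual gates. Class and variable names are illustrative, not the repo's actual code, and RoPE and segment embeddings are omitted for brevity.

```python
import torch
import torch.nn as nn

class AdaLNBlock(nn.Module):
    """Sketch of a DiT block with adaLN timestep conditioning (illustrative only)."""
    def __init__(self, dim: int = 512, heads: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # One projection of the timestep embedding yields all six modulation vectors.
        self.ada = nn.Linear(dim, 6 * dim)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        # x: (B, L, dim) token states; t_emb: (B, dim) sinusoidal timestep embedding
        shift1, scale1, gate1, shift2, scale2, gate2 = self.ada(t_emb).chunk(6, dim=-1)
        h = self.norm1(x) * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
        x = x + gate1.unsqueeze(1) * self.attn(h, h, h, need_weights=False)[0]  # bidirectional: no mask
        h = self.norm2(x) * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
        return x + gate2.unsqueeze(1) * self.mlp(h)
```

The `elementwise_affine=False` norms are the point: every scale and shift comes from the timestep embedding, which is how adaLN injects the current noise level into each block.
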
### Inspired by
- **[MDLM](https://arxiv.org/abs/2406.07524)** — DiT backbone architecture, masked diffusion objective
- **[LLaDA](https://arxiv.org/abs/2502.09992)** — conditional generation via SFT (keep the prompt unmasked, mask only the target), 1/t ELBO weighting
- **[DiNoiSer](https://arxiv.org/abs/2302.10025)** — noise manipulation for conditional seq2seq diffusion

## How It Works

### Training (Forward Diffusion)
1. Source (EN) and target (DE) tokens are concatenated: `[source | target]`
2. A random masking rate `t ~ Uniform(0, 1)` is sampled per example
3. Each target token is independently masked with probability `t`
4. The bidirectional DiT predicts all masked tokens simultaneously
5. Loss = cross-entropy on masked positions only, weighted by `1/t` (the continuous-time ELBO weight; see the sketch below)

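A minimal sketch of that objective, assuming a model that maps token ids to per-position logits. The function name and signature are hypothetical; the repo's real loss lives in `train.py`.

```python
import torch
import torch.nn.functional as F

def masked_diffusion_loss(model, src_ids, tgt_ids, mask_token_id):
    # Hypothetical loss following steps 1-5 above; not the repo's actual code.
    B, L = tgt_ids.shape
    # Step 2: per-example masking rate t ~ Uniform(0, 1), clamped so 1/t stays finite
    t = torch.rand(B, device=tgt_ids.device).clamp(min=1e-3)
    # Step 3: mask each target token independently with probability t
    masked = torch.rand(B, L, device=tgt_ids.device) < t.unsqueeze(1)
    noisy_tgt = torch.where(masked, torch.full_like(tgt_ids, mask_token_id), tgt_ids)
    # Steps 1 and 4: concatenate [source | noisy target], predict all positions at once
    logits = model(torch.cat([src_ids, noisy_tgt], dim=1))[:, src_ids.shape[1]:]
    # Step 5: cross-entropy on masked positions only, weighted by 1/t
    ce = F.cross_entropy(logits.transpose(1, 2), tgt_ids, reduction="none")
    return (ce * masked / t.unsqueeze(1)).sum() / masked.sum().clamp(min=1)
```
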
### Inference (Reverse Diffusion)
1. Start with the source tokens plus a fully masked target: `[source | MASK MASK ... MASK]`
2. Over 50 denoising steps, iteratively predict and unmask tokens
3. At each step `t → s`: predict all masked tokens, then randomly re-mask a fraction `s/t` of them
4. Final step: all remaining masks are filled with predictions (see the sketch below)

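A sketch of that sampler under the same assumed model interface. This is illustrative; the repo's real implementation is the `generate` function in `train.py`.

```python
import torch

@torch.no_grad()
def sample(model, src_ids, mask_token_id, tgt_len=128, num_steps=50):
    # Hypothetical sampler following steps 1-4 above; not the repo's actual code.
    B = src_ids.shape[0]
    # Step 1: start from a fully masked target
    tgt = torch.full((B, tgt_len), mask_token_id, dtype=torch.long, device=src_ids.device)
    for i in range(num_steps):
        t = 1.0 - i / num_steps        # current noise level
        s = 1.0 - (i + 1) / num_steps  # next, lower noise level
        logits = model(torch.cat([src_ids, tgt], dim=1))[:, src_ids.shape[1]:]
        pred = logits.argmax(dim=-1)
        still_masked = tgt == mask_token_id
        if s > 0:
            # Step 3: a prediction stays masked with probability s/t, so the
            # expected fraction of masks shrinks from t to s at each step
            keep_masked = torch.rand(B, tgt_len, device=tgt.device) < (s / t)
            tgt = torch.where(still_masked & ~keep_masked, pred, tgt)
        else:
            # Step 4: the final step fills every remaining mask
            tgt = torch.where(still_masked, pred, tgt)
    return tgt
```
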
## Training Details

| Setting | Value |
|---|---|
| **Dataset** | WMT14 EN-DE (~4.5M parallel sentence pairs) |
| **Optimizer** | AdamW (lr=3e-4, β₁=0.9, β₂=0.98, wd=0.01) |
| **Schedule** | Cosine decay with 4K-step linear warmup |
| **Effective batch size** | 256 (64 × 4 gradient accumulation) |
| **Max steps** | 200,000 |
| **Mixed precision** | FP16 |
| **Gradient clipping** | max_norm=1.0 |
| **Evaluation** | SacreBLEU on the WMT14 test set every 20K steps |

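For orientation, the optimizer and schedule rows translate to roughly the following PyTorch setup. This is a sketch with assumed variable names, not the script's actual code.

```python
import math
import torch

model = torch.nn.Linear(512, 512)  # stand-in for the real DiffusionTranslator
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4,
                              betas=(0.9, 0.98), weight_decay=0.01)

def lr_lambda(step, warmup=4_000, max_steps=200_000):
    if step < warmup:
        return step / warmup  # linear warmup to the peak lr
    progress = (step - warmup) / (max_steps - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay to 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
```
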
## Quick Start

### Install dependencies

```bash
pip install torch transformers datasets trackio sacrebleu sacremoses sentencepiece protobuf
```

### Train

```bash
git clone https://huggingface.co/vedkdev/text-diffusion-en-de
cd text-diffusion-en-de
python train.py
```

The script will:
- Download WMT14 EN-DE automatically
- Train for 200K steps with logging via [Trackio](https://huggingface.co/docs/trackio)
- Evaluate with SacreBLEU periodically
- Push checkpoints to this repo

### Adjusting for your hardware

Edit the `TRAIN_CONFIG` dict in `train.py`; see the example after the table:

| GPU VRAM | Recommended `batch_size` | `gradient_accumulation_steps` |
|---|---|---|
| 24GB (A10G/3090/4090) | 64 | 4 |
| 16GB (T4/V100) | 32 | 8 |
| 12GB (3060) | 16 | 16 |
| 8GB (3070) | 8 | 32 |

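For example, on a 16GB card the two keys from the table would change like this (any other keys are placeholders; check `train.py` for the actual dict contents):

```python
TRAIN_CONFIG = {
    # ... other keys from train.py unchanged ...
    "batch_size": 32,                  # per-step batch that fits in 16GB
    "gradient_accumulation_steps": 8,  # 32 x 8 keeps the effective batch at 256
}
```
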
### Inference (after training)

```python
import torch, json
from train import DiffusionTranslator, DiffusionTranslatorConfig, generate
from transformers import AutoTokenizer

# Load checkpoint
config = DiffusionTranslatorConfig(**json.load(open("checkpoints/best/config.json")))
model = DiffusionTranslator(config)
model.load_state_dict(torch.load("checkpoints/best/model.pt", map_location="cpu"))
model.eval()

tokenizer = AutoTokenizer.from_pretrained("checkpoints/best/")

# Translate
text = "The weather is nice today."
src = tokenizer(f"translate English to German: {text}",
                max_length=128, truncation=True, padding="max_length",
                return_tensors="pt")

gen_ids = generate(model, src["input_ids"], torch.zeros_like(src["input_ids"]),
                   config, num_steps=50, device="cpu")
print(tokenizer.decode(gen_ids[0], skip_special_tokens=True))
```

## Expected Results

Based on published results for comparable architectures on WMT14 EN→DE:

| Model | BLEU | Reference / Notes |
|---|---|---|
| Autoregressive Transformer | ~27 | Vaswani et al. |
| DiNoiSer (continuous diffusion) | 24.6 | Ye et al. 2023 |
| SeqDiffuSeq | 19.8 | Yuan et al. 2022 |
| E2D2 (discrete diffusion) | 24.8 | Kuleshov et al. 2024 |
| **This model (target)** | **15-20** | ~72M params, no KD |

> Note: Text diffusion models typically score 2-5 BLEU below autoregressive transformers of similar size. Knowledge distillation (KD) from an AR teacher can close the gap by ~1-2 BLEU.

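For context, the BLEU figures above correspond to standard SacreBLEU corpus scoring, which looks like this (toy strings shown in place of real test outputs):

```python
import sacrebleu

# hyps: one model translation per test sentence; refs: one reference stream
hyps = ["Das Wetter ist heute schön."]
refs = [["Das Wetter ist heute schön."]]

bleu = sacrebleu.corpus_bleu(hyps, refs)
print(f"BLEU = {bleu.score:.1f}")
```
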
## Citation

If you use this model, please cite the foundational papers:

```bibtex
@article{sahoo2024mdlm,
  title={Simple and Effective Masked Diffusion Language Models},
  author={Sahoo, Subham Sekhar and Arriola, Marianne and Schiff, Yair and Gokaslan, Aaron and Marroquin, Edgar and Chiu, Justin T and Rush, Alexander and Kuleshov, Volodymyr},
  journal={NeurIPS},
  year={2024}
}

@article{nie2025llada,
  title={Large Language Diffusion Models},
  author={Nie, Shen and Zhu, Fengqi and You, Zebin and Zhang, Xiaolu and Ou, Jingyang and Hu, Jun and Lin, Zhouchen and Wen, Ji-Rong and Li, Chongxuan},
  journal={arXiv preprint arXiv:2502.09992},
  year={2025}
}

@article{ye2023dinoiser,
  title={DiNoiSer: Diffused Conditional Sequence Learning by Manipulating Noises},
  author={Ye, Jiasheng and Zheng, Zaixiang and Bao, Yu and Qian, Lihua and Wang, Mingxuan},
  journal={ACL},
  year={2023}
}
```