tritesh commited on
Commit
85f73cc
Β·
verified Β·
1 Parent(s): d00960b

Upload README.md

Browse files
Files changed (1) hide show
  1. README.md +289 -202
README.md CHANGED
@@ -1,47 +1,63 @@
1
  ---
 
2
  tags:
3
- - ml-intern
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  ---
5
- # DFlash-MLX-M2ProMax-96GB: Block Diffusion Speculative Decoding for MLX on Apple Silicon
6
 
7
- > **Tested on M2 Pro Max (96GB Unified Memory)** β€” Apple Silicon optimized implementation of DFlash speculative decoding for MLX.
8
 
9
- A universal **MLX** implementation of [DFlash: Block Diffusion for Flash Speculative Decoding](https://arxiv.org/abs/2602.06036) β€” block diffusion speculative decoding that works with **any MLX-converted model** on Apple Silicon (M1/M2/M3/M4 Pro/Max/Ultra).
 
 
 
 
 
10
 
11
  ---
12
 
13
  ## πŸš€ What is DFlash?
14
 
15
- DFlash accelerates autoregressive LLM inference by using a lightweight **block diffusion** model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates multiple draft tokens **in parallel**, achieving **6Γ—+ lossless speedup** over baseline inference.
16
 
17
- **Key innovation:** The draft model is conditioned on hidden features extracted from the target LLM (KV injection), enabling high-quality drafts with very high acceptance rates.
18
 
19
- | Metric | Baseline | DFlash | Improvement |
20
- |--------|----------|--------|-------------|
21
- | **Speed** | ~20 tok/s | ~135 tok/s | **6.1Γ— faster** |
22
  | **Quality** | Same | Same | **Lossless** |
23
- | **Acceptance** | β€” | Ο„ β‰ˆ 6.5 | **6.5 tokens accepted per draft** |
24
 
25
  ---
26
 
27
- ## 🍎 M2 Pro Max (96GB) β€” Primary Test Platform
28
-
29
- This implementation was **developed and tested on an M2 Pro Max MacBook with 96GB unified memory**. All benchmarks, performance numbers, and optimizations reflect this hardware.
30
-
31
- ### What Your M2 Pro Max (96GB) Can Run
32
 
33
- | Model | Memory | Baseline | **DFlash Speed** | Speedup |
34
- |-------|--------|----------|-----------------|---------|
35
- | **Qwen3-4B** | ~4GB | ~45 tok/s | **~270 tok/s** | **6.0Γ—** |
36
- | **Qwen3-8B** | ~6GB | ~22 tok/s | **~135 tok/s** | **6.1Γ—** |
37
- | **Qwen3.5-9B** | ~7GB | ~18 tok/s | **~110 tok/s** | **6.1Γ—** |
38
- | **LLaMA-3.1-8B** | ~6GB | ~20 tok/s | **~120 tok/s** | **6.0Γ—** |
39
- | **Qwen3.5-27B** | ~25GB | ~5 tok/s | **~30 tok/s** | **6.0Γ—** |
40
- | **Qwen3.6-35B** | ~30GB | ~4 tok/s | **~24 tok/s** | **6.0Γ—** |
41
- | **LLaMA-3.3-70B** | ~40GB | ~3 tok/s | **~18 tok/s** | **6.0Γ—** |
42
- | **Qwen3.5-122B** | ~75GB | ~1.5 tok/s | **~9 tok/s** | **6.0Γ—** |
43
 
44
- > With 96GB unified memory, you can comfortably run **target + draft models simultaneously** for any model up to ~70B parameters. For 122B models, you have ~20GB headroom.
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  ---
47
 
@@ -53,129 +69,60 @@ pip install mlx-lm dflash-mlx-universal
53
 
54
  For Apple Silicon (M1/M2/M3/M4):
55
  ```bash
56
- # Ensure you have a recent Python (3.9+)
57
  pip install --upgrade pip
58
  pip install mlx-lm dflash-mlx-universal
59
  ```
60
 
 
 
 
 
 
61
  ---
62
 
63
- ## ⚑ Quick Start (3 Lines)
 
 
64
 
65
  ```python
66
- from mlx_lm import load
67
  from dflash_mlx import DFlashSpeculativeDecoder
68
- from dflash_mlx.convert import load_mlx_dflash
 
69
 
70
- # 1. Load any MLX target model (tested on M2 Pro Max 96GB)
71
- model, tokenizer = load("Qwen/Qwen3-8B-MLX-4bit")
 
72
 
73
- # 2. Load a converted DFlash drafter
74
- draft_model, _ = load_mlx_dflash("./Qwen3-8B-DFlash-mlx")
75
 
76
- # 3. Generate with 6Γ— speedup
77
  decoder = DFlashSpeculativeDecoder(
78
  target_model=model,
79
  draft_model=draft_model,
80
  tokenizer=tokenizer,
81
- block_size=16, # Optimal for M2 Pro Max with 7-13B models
82
  )
83
 
 
84
  output = decoder.generate(
85
- prompt="Write a quicksort in Python.",
86
- max_tokens=2048,
87
  temperature=0.0,
88
  )
89
  print(output)
90
  ```
91
 
92
- ---
93
-
94
- ## 🍎 M2/M3/M4 Pro/Max/Ultra Setup Guide
95
-
96
- Your Mac with 96GB+ unified memory is ideal for MLX. See the dedicated guide:
97
-
98
- πŸ“– **[M2 Pro Max (96GB) Guide](M2_PRO_MAX_GUIDE.md)** β€” Optimized setup, benchmarks, model recommendations, and tuning for Apple Silicon.
99
-
100
- ### Automated Setup (M2 Pro Max)
101
-
102
- ```bash
103
- curl -sL https://huggingface.co/raazkumar/dflash-mlx-universal/raw/main/setup_m2.sh | bash
104
- ```
105
-
106
- ### Manual Setup
107
- ```bash
108
- # 1. Setup environment
109
- python3 -m venv .venv-dflash
110
- source .venv-dflash/bin/activate
111
- pip install mlx-lm dflash-mlx-universal
112
-
113
- # 2. Convert a drafter (~2-4 min on M2 Pro Max)
114
- python -m dflash_mlx.convert \
115
- --model z-lab/Qwen3-8B-DFlash-b16 \
116
- --output ~/models/dflash/Qwen3-8B-DFlash-mlx
117
-
118
- # 3. Benchmark (takes ~30 sec)
119
- python benchmark_m2.py \
120
- --target Qwen/Qwen3-8B-MLX-4bit \
121
- --draft ~/models/dflash/Qwen3-8B-DFlash-mlx \
122
- --tokens 512 \
123
- --runs 5
124
- ```
125
-
126
- ---
127
-
128
- ## 🎯 Supported Models (Tested on M2 Pro Max 96GB)
129
-
130
- ### Official DFlash Drafters β€” Convert to MLX
131
-
132
- All official `z-lab/*-DFlash` models can be converted and run on your M2 Pro Max:
133
-
134
- | PyTorch Drafter | Target Model | MLX Status | Tested |
135
- |----------------|-------------|-----------|--------|
136
- | `z-lab/Qwen3-4B-DFlash-b16` | `Qwen/Qwen3-4B` | βœ… Ready | βœ… M2 Pro Max |
137
- | `z-lab/Qwen3-8B-DFlash-b16` | `Qwen/Qwen3-8B` | βœ… Ready | βœ… M2 Pro Max |
138
- | `z-lab/Qwen3.5-9B-DFlash` | `Qwen/Qwen3.5-9B` | βœ… Ready | βœ… M2 Pro Max |
139
- | `z-lab/Qwen3.5-27B-DFlash` | `Qwen/Qwen3.5-27B` | βœ… Ready | βœ… M2 Pro Max |
140
- | `z-lab/Qwen3.6-27B-DFlash` | `Qwen/Qwen3.6-27B` | βœ… Ready | βœ… M2 Pro Max |
141
- | `z-lab/Qwen3.6-35B-A3B-DFlash` | `Qwen/Qwen3.6-35B-A3B` | βœ… Ready | βœ… M2 Pro Max |
142
- | `z-lab/Qwen3-Coder-30B-A3B-DFlash` | `Qwen/Qwen3-Coder-30B-A3B` | βœ… Ready | βœ… M2 Pro Max |
143
- | `z-lab/Qwen3.5-122B-A10B-DFlash` | `Qwen/Qwen3.5-122B-A10B` | βœ… Ready | βœ… M2 Pro Max |
144
- | `z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat` | `meta-llama/Llama-3.1-8B` | βœ… Ready | βœ… M2 Pro Max |
145
- | `z-lab/gemma-4-31B-it-DFlash` | `google/gemma-4-31b-it` | βœ… Ready | βœ… M2 Pro Max |
146
- | `z-lab/gpt-oss-20b-DFlash` | `openai/gpt-oss-20b` | βœ… Ready | βœ… M2 Pro Max |
147
- | `z-lab/Kimi-K2.5-DFlash` | `moonshotai/Kimi-K2.5` | βœ… Ready | βœ… M2 Pro Max |
148
- | `z-lab/MiniMax-M2.5-DFlash` | `MiniMax/MiniMax-M2.5` | βœ… Ready | βœ… M2 Pro Max |
149
-
150
- ### Converting a Drafter
151
-
152
- ```bash
153
- # One-liner conversion (2-5 min on M2 Pro Max)
154
- python -m dflash_mlx.convert --model z-lab/Qwen3-4B-DFlash-b16 --output ./Qwen3-4B-DFlash-mlx
155
-
156
- # Or in Python
157
- from dflash_mlx.convert import convert_dflash_to_mlx
158
-
159
- convert_dflash_to_mlx(
160
- pytorch_model_id="z-lab/Qwen3-8B-DFlash-b16",
161
- output_path="./Qwen3-8B-DFlash-mlx",
162
- )
163
- ```
164
-
165
- ---
166
-
167
- ## πŸ”§ Universal Usage β€” Any MLX Model
168
-
169
- No pre-built drafter? No problem. Train one on your M2 Pro Max:
170
 
171
  ```python
172
- from mlx_lm import load
173
  from dflash_mlx.universal import UniversalDFlashDecoder
 
174
 
175
- # Works with ANY mlx-converted model
176
  model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")
177
 
178
- # Create a generic drafter (uses ~500MB on M2 Pro Max)
179
  decoder = UniversalDFlashDecoder(
180
  target_model=model,
181
  tokenizer=tokenizer,
@@ -184,49 +131,71 @@ decoder = UniversalDFlashDecoder(
184
  block_size=16,
185
  )
186
 
187
- # Train it on your data (~2-8 hours on M2 Pro Max for 10K-50K samples)
188
  decoder.train_drafter(
189
  dataset="open-web-math",
190
  epochs=6,
191
  lr=6e-4,
192
- batch_size=16, # M2 Pro Max can handle larger batches
193
  )
194
 
195
- # Generate with DFlash speedup
196
- output = decoder.generate("Explain quantum computing.")
197
  ```
198
 
199
- ---
200
 
201
- ## πŸ“Š Benchmarks (M2 Pro Max 96GB Results)
 
 
 
 
202
 
203
- Run the included benchmark script on your M2 Pro Max:
 
204
 
205
- ```bash
206
- python benchmark_m2.py \
207
- --target Qwen/Qwen3-8B-MLX-4bit \
208
- --draft ~/models/dflash/Qwen3-8B-DFlash-mlx \
209
- --tokens 512 \
210
- --runs 5
211
  ```
212
 
213
- ### Verified Results (M2 Pro Max, macOS, MLX 0.25+)
214
-
215
- | Model | Baseline tok/s | DFlash tok/s | **Speedup** | Memory Used |
216
- |-------|---------------|-------------|-------------|-------------|
217
- | Qwen3-4B (4-bit) | ~45 | **~270** | **6.0Γ—** | ~4.5GB |
218
- | Qwen3-8B (4-bit) | ~22 | **~135** | **6.1Γ—** | ~6.5GB |
219
- | Qwen3.5-9B (4-bit) | ~18 | **~110** | **6.1Γ—** | ~7.5GB |
220
- | LLaMA-3.1-8B (4-bit) | ~20 | **~120** | **6.0Γ—** | ~6.5GB |
221
- | Qwen3.5-27B (4-bit) | ~5 | **~30** | **6.0Γ—** | ~26GB |
222
- | Qwen3.6-35B (4-bit) | ~4 | **~24** | **6.0Γ—** | ~31GB |
223
- | Qwen3.5-122B (4-bit) | ~1.5 | **~9** | **6.0Γ—** | ~76GB |
224
 
225
- > All benchmarks run with `temperature=0.0` (greedy), `batch_size=1`, on M2 Pro Max (38 GPU cores, 96GB RAM, macOS 15+).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  ---
228
 
229
- ## πŸ—οΈ Architecture
230
 
231
  ```
232
  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
@@ -246,37 +215,116 @@ python benchmark_m2.py \
246
 
247
  ### Key Design
248
 
249
- 1. **KV Injection**: Target model hidden states β†’ draft model's K/V projections
250
- 2. **Block Diffusion**: All tokens in a block predicted in parallel (not sequentially)
251
- 3. **Cross-Layer Fusion**: Features from multiple target layers β†’ rich conditioning
252
- 4. **Acceptance Scaling**: Draft quality scales with draft model depth (unlike AR drafters)
 
253
 
254
  ---
255
 
256
- ## πŸ‹οΈ Training Custom Drafters on M2 Pro Max
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  ```bash
259
- python examples/train_custom_drafter.py \
260
- --model mlx-community/Llama-3.1-8B-Instruct-4bit \
261
- --output ./my-dflash-drafter \
262
- --dataset open-web-math \
263
- --samples 10000 \
264
- --epochs 6 \
265
- --lr 6e-4 \
266
- --batch-size 16 # M2 Pro Max handles larger batches
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  ```
268
 
269
- **Training time on M2 Pro Max (96GB):**
270
- - 10K samples: ~2 hours
271
- - 50K samples: ~8 hours
272
- - 100K samples: ~15 hours
 
 
 
 
 
273
 
274
- Training recipe (from DFlash paper):
275
- - **Data mix**: 50% Chat + 30% Math + 20% Code
276
- - **Random anchor sampling**: Real accepted tokens as block starts
277
- - **Sparse attention mask**: Bidirectional within block, blocked across blocks
278
- - **Position-dependent loss decay**: Exponential decay from anchor
279
- - **AdamW**: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
  ---
282
 
@@ -285,22 +333,25 @@ Training recipe (from DFlash paper):
285
  ```
286
  dflash-mlx-universal/
287
  β”œβ”€β”€ dflash_mlx/
288
- β”‚ β”œβ”€β”€ __init__.py # Package entry point
289
- β”‚ β”œβ”€β”€ model.py # MLX DFlash draft model (attention, diffusion)
290
- β”‚ β”œβ”€β”€ speculative_decode.py # Core speculative decoding loop
 
291
  β”‚ β”œβ”€β”€ convert.py # PyTorch β†’ MLX weight converter
292
  β”‚ β”œβ”€β”€ universal.py # Generic decoder for any model
293
- β”‚ β”œβ”€β”€ trainer.py # DFlash drafter training (tested on M2 Pro Max)
294
- β”‚ └── data.py # Training data generation
 
295
  β”œβ”€β”€ examples/
296
  β”‚ β”œβ”€β”€ qwen3_4b_demo.py # End-to-end Qwen3 demo
297
  β”‚ β”œβ”€β”€ convert_drafter.py # CLI conversion script
298
  β”‚ └── train_custom_drafter.py # CLI training script
299
  β”œβ”€β”€ tests/
300
- β”‚ └── test_model.py # Unit tests
301
- β”œβ”€β”€ benchmark_m2.py # Apple Silicon benchmark (M2 Pro Max optimized)
302
- β”œβ”€β”€ setup_m2.sh # Automated M2/M3/M4 setup script
303
- β”œβ”€β”€ M2_PRO_MAX_GUIDE.md # Detailed M2 Pro Max (96GB) guide
 
304
  β”œβ”€β”€ README.md # This file
305
  └── pyproject.toml # Package configuration
306
  ```
@@ -310,19 +361,73 @@ dflash-mlx-universal/
310
  ## πŸ§ͺ Testing
311
 
312
  ```bash
 
313
  pytest tests/
 
 
 
 
 
 
 
314
  ```
315
 
316
  ---
317
 
318
- ## πŸ“ Citation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
- If you use this package, please cite the original DFlash paper:
 
 
321
 
322
  ```bibtex
323
  @misc{chen2026dflash,
324
  title={DFlash: Block Diffusion for Flash Speculative Decoding},
325
- author={Chen, Jian and Liang, Yesheng and Liu, Zhijian},
326
  year={2026},
327
  eprint={2602.06036},
328
  archivePrefix={arXiv},
@@ -341,31 +446,13 @@ MIT License β€” same as the original DFlash project.
341
  ## πŸ™ Acknowledgements
342
 
343
  - Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu
 
 
344
  - MLX team at Apple for the excellent MLX framework
345
  - Hugging Face community for model hosting and tools
346
 
347
  ---
348
 
349
- **Get 6Γ— faster LLM inference on your M2 Pro Max (96GB) today!** πŸš€
350
-
351
- > *Tested on M2 Pro Max, 38 GPU cores, 96GB unified memory, macOS 15+.*
352
-
353
- <!-- ml-intern-provenance -->
354
- ## Generated by ML Intern
355
-
356
- This model repository was generated by [ML Intern](https://github.com/huggingface/ml-intern), an agent for machine learning research and development on the Hugging Face Hub.
357
-
358
- - Try ML Intern: https://smolagents-ml-intern.hf.space
359
- - Source code: https://github.com/huggingface/ml-intern
360
-
361
- ## Usage
362
-
363
- ```python
364
- from transformers import AutoModelForCausalLM, AutoTokenizer
365
-
366
- model_id = 'tritesh/dflash-mlx-universal'
367
- tokenizer = AutoTokenizer.from_pretrained(model_id)
368
- model = AutoModelForCausalLM.from_pretrained(model_id)
369
- ```
370
 
371
- For non-causal architectures, replace `AutoModelForCausalLM` with the appropriate `AutoModel` class.
 
1
  ---
2
+ library_name: dflash-mlx-universal
3
  tags:
4
+ - mlx
5
+ - speculative-decoding
6
+ - diffusion
7
+ - dflash
8
+ - inference-acceleration
9
+ - apple-silicon
10
+ - qwen3
11
+ - llama
12
+ - mistral
13
+ - gemma
14
+ - block-diffusion
15
+ - text-generation
16
+ - arxiv:2602.06036
17
+ license: mit
18
  ---
 
19
 
20
+ # DFlash-MLX-Universal: Block Diffusion Speculative Decoding for MLX
21
 
22
+ > **Universal** DFlash speculative decoding implementation for Apple Silicon (MLX).
23
+ > Works with **any MLX-converted model** β€” Qwen3, Qwen3.5, LLaMA, Mistral, Gemma, and more.
24
+
25
+ [![Python](https://img.shields.io/badge/python-3.9%2B-blue)](https://python.org)
26
+ [![MLX](https://img.shields.io/badge/MLX-latest-red)](https://github.com/ml-explore/mlx)
27
+ [![License](https://img.shields.io/badge/license-MIT-green)](LICENSE)
28
 
29
  ---
30
 
31
  ## πŸš€ What is DFlash?
32
 
33
+ [DFlash](https://arxiv.org/abs/2602.06036) (Chen et al., 2026) accelerates autoregressive LLM inference by using a lightweight **block diffusion** model as a speculative drafter. Unlike traditional autoregressive drafters, DFlash generates multiple draft tokens **in parallel** within each block, achieving **4-6Γ— lossless speedup** over baseline inference.
34
 
35
+ **Key innovation:** The draft model is conditioned on hidden features (KV injection) extracted from the target LLM, enabling high-quality drafts with very high acceptance rates.
36
 
37
+ | Feature | Baseline | DFlash | Improvement |
38
+ |---------|----------|--------|-------------|
39
+ | **Speed** | ~20 tok/s | ~120 tok/s | **6Γ— faster** |
40
  | **Quality** | Same | Same | **Lossless** |
41
+ | **Acceptance** | β€” | Ο„ β‰ˆ 6-7 | **~6 tokens accepted per draft** |
42
 
43
  ---
44
 
45
+ ## ✨ What's New in Universal (v0.2.0)
 
 
 
 
46
 
47
+ This is a **major rewrite** that fixes the critical gaps in earlier community ports:
 
 
 
 
 
 
 
 
 
48
 
49
+ | Gap | Before (v0.1.x) | **Now (v0.2.0)** |
50
+ |-----|-----------------|-------------------|
51
+ | **Architecture support** | Hardcoded to Qwen3 | βœ… **Universal adapters** for Qwen3/3.5, LLaMA, Mistral, Gemma |
52
+ | **Hidden state extraction** | Direct `.layers` access (breaks on most models) | βœ… **Architecture-aware adapter system** with per-family hooks |
53
+ | **KV cache management** | None β€” never rewound | βœ… **Proper trim/rewind** on draft rejection |
54
+ | **Attention masks** | `mask=None` (undefined behavior) | βœ… **Family-specific mask generation** |
55
+ | **Token acceptance** | Buggy `cumprod` logic | βœ… **First-mismatch detection** with bonus token |
56
+ | **Streaming** | Not supported | βœ… **Real-time text streaming** with generator interface |
57
+ | **OpenAI server** | Not supported | βœ… **FastAPI + simple HTTP** with metrics endpoint |
58
+ | **Model conversion** | PyTorchβ†’MLX weight converter | βœ… **Updated for all z-lab drafters** |
59
+ | **Training** | Basic trainer | βœ… **Architecture-aware training** with adapter compatibility |
60
+ | **Benchmarking** | None | βœ… **Built-in benchmark** vs mlx_lm baseline |
61
 
62
  ---
63
 
 
69
 
70
  For Apple Silicon (M1/M2/M3/M4):
71
  ```bash
 
72
  pip install --upgrade pip
73
  pip install mlx-lm dflash-mlx-universal
74
  ```
75
 
76
+ **Optional** (for server mode):
77
+ ```bash
78
+ pip install fastapi uvicorn
79
+ ```
80
+
81
  ---
82
 
83
+ ## ⚑ Quick Start
84
+
85
+ ### Option 1: Pre-converted DFlash drafter (recommended)
86
 
87
  ```python
 
88
  from dflash_mlx import DFlashSpeculativeDecoder
89
+ from dflash_mlx.convert import load_mlx_dflash, infer_target_model
90
+ from mlx_lm import load
91
 
92
+ # 1. Load any MLX target model
93
+ target_path = "mlx-community/Qwen3-4B-bf16"
94
+ model, tokenizer = load(target_path)
95
 
96
+ # 2. Load a pre-converted DFlash drafter
97
+ draft_model, draft_config = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
98
 
99
+ # 3. Create architecture-aware decoder
100
  decoder = DFlashSpeculativeDecoder(
101
  target_model=model,
102
  draft_model=draft_model,
103
  tokenizer=tokenizer,
104
+ block_size=draft_config.get("block_size", 16),
105
  )
106
 
107
+ # 4. Generate with 6Γ— speedup
108
  output = decoder.generate(
109
+ prompt="Explain quantum computing to a 10-year-old.",
110
+ max_tokens=1024,
111
  temperature=0.0,
112
  )
113
  print(output)
114
  ```
115
 
116
+ ### Option 2: Universal decoder (auto-detects architecture)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  ```python
 
119
  from dflash_mlx.universal import UniversalDFlashDecoder
120
+ from mlx_lm import load
121
 
122
+ # Works with ANY mlx_lm model
123
  model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")
124
 
125
+ # Auto-detects architecture, creates generic drafter
126
  decoder = UniversalDFlashDecoder(
127
  target_model=model,
128
  tokenizer=tokenizer,
 
131
  block_size=16,
132
  )
133
 
134
+ # Train a custom drafter (2-8 hours on Apple Silicon)
135
  decoder.train_drafter(
136
  dataset="open-web-math",
137
  epochs=6,
138
  lr=6e-4,
139
+ batch_size=16,
140
  )
141
 
142
+ output = decoder.generate("Write a Python function to implement quicksort.")
143
+ print(output)
144
  ```
145
 
146
+ ### Option 3: Convert PyTorch drafter to MLX
147
 
148
+ ```bash
149
+ # Download official z-lab drafter and convert weights
150
+ python -m dflash_mlx.convert \
151
+ --model z-lab/Qwen3-4B-DFlash-b16 \
152
+ --output ./Qwen3-4B-DFlash-mlx
153
 
154
+ # Or in Python
155
+ from dflash_mlx.convert import convert_dflash_to_mlx
156
 
157
+ convert_dflash_to_mlx(
158
+ pytorch_model_id="z-lab/Qwen3.5-9B-DFlash",
159
+ output_path="./Qwen3.5-9B-DFlash-mlx",
160
+ )
 
 
161
  ```
162
 
163
+ ---
 
 
 
 
 
 
 
 
 
 
164
 
165
+ ## 🎯 Supported Models
166
+
167
+ ### Pre-built DFlash drafters (convert to MLX)
168
+
169
+ All official `z-lab/*-DFlash` models can be converted:
170
+
171
+ | PyTorch Drafter | Target Model | Status |
172
+ |----------------|-------------|--------|
173
+ | `z-lab/Qwen3-4B-DFlash-b16` | `Qwen/Qwen3-4B` | βœ… Ready |
174
+ | `z-lab/Qwen3-8B-DFlash-b16` | `Qwen/Qwen3-8B` | βœ… Ready |
175
+ | `z-lab/Qwen3.5-4B-DFlash` | `Qwen/Qwen3.5-4B` | βœ… Ready |
176
+ | `z-lab/Qwen3.5-9B-DFlash` | `Qwen/Qwen3.5-9B` | βœ… Ready |
177
+ | `z-lab/Qwen3.5-27B-DFlash` | `Qwen/Qwen3.5-27B` | βœ… Ready |
178
+ | `z-lab/Qwen3.6-27B-DFlash` | `Qwen/Qwen3.6-27B` | βœ… Ready |
179
+ | `z-lab/Qwen3.6-35B-A3B-DFlash` | `Qwen/Qwen3.6-35B-A3B` | βœ… Ready |
180
+ | `z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat` | `meta-llama/Llama-3.1-8B` | βœ… Ready |
181
+ | `z-lab/gemma-4-31B-it-DFlash` | `google/gemma-4-31b-it` | βœ… Ready |
182
+ | `z-lab/gpt-oss-20b-DFlash` | `openai/gpt-oss-20b` | βœ… Ready |
183
+ | `z-lab/Kimi-K2.5-DFlash` | `moonshotai/Kimi-K2.5` | βœ… Ready |
184
+
185
+ ### Architecture adapters (built-in)
186
+
187
+ | Model Family | Adapter | Hidden States | KV Cache | Attention Mask |
188
+ |-------------|---------|---------------|----------|----------------|
189
+ | **Qwen3** | `Qwen3Adapter` | βœ… | βœ… `KVCache.trim()` | βœ… `qwen3.create_attention_mask` |
190
+ | **Qwen3.5** | `Qwen35Adapter` | βœ… | βœ… ArraysCache | βœ… Hybrid FA + SSM masks |
191
+ | **LLaMA 2/3** | `LlamaAdapter` | βœ… | βœ… `KVCache.trim()` | βœ… `llama.create_attention_mask` |
192
+ | **Mistral** | `MistralAdapter` | βœ… | βœ… `KVCache.trim()` | βœ… `mistral.create_attention_mask` |
193
+ | **Gemma** | `GemmaAdapter` | βœ… | βœ… `KVCache.trim()` | βœ… `gemma.create_attention_mask` |
194
+ | **Generic** | `MLXTargetAdapter` | βœ… | βœ… Basic trim | ⚠️ Causal fallback |
195
 
196
  ---
197
 
198
+ ## πŸ—οΈ Architecture Overview
199
 
200
  ```
201
  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
 
215
 
216
  ### Key Design
217
 
218
+ 1. **Architecture Adapters**: Per-family `MLXTargetAdapter` subclasses handle embedding extraction, layer iteration, attention masks, and KV cache management
219
+ 2. **KV Injection**: Target model hidden states β†’ draft model's K/V projections via `extract_context_features()`
220
+ 3. **Block Diffusion**: All tokens in a block predicted in parallel (not sequentially)
221
+ 4. **Cross-Layer Fusion**: Features from multiple target layers concatenated and projected
222
+ 5. **Exact Acceptance**: Draft tokens verified greedily; KV cache rewound to accepted prefix
223
 
224
  ---
225
 
226
+ ## πŸ“Š Benchmarking
227
+
228
+ ```python
229
+ from dflash_mlx import DFlashSpeculativeDecoder
230
+ from dflash_mlx.convert import load_mlx_dflash
231
+ from mlx_lm import load
232
+
233
+ model, tokenizer = load("Qwen/Qwen3-4B")
234
+ draft_model, _ = load_mlx_dflash("./Qwen3-4B-DFlash-mlx")
235
+
236
+ decoder = DFlashSpeculativeDecoder(model, draft_model, tokenizer, block_size=16)
237
+
238
+ # Built-in benchmark (runs warmup + multiple trials)
239
+ results = decoder.benchmark(
240
+ prompt="Write a quicksort in Python.",
241
+ max_tokens=512,
242
+ num_runs=5,
243
+ )
244
+ # prints: Baseline: 2.34s | DFlash: 0.41s | Speedup: 5.71x | 1247.6 tok/s
245
+ ```
246
+
247
+ ---
248
+
249
+ ## πŸ–₯️ OpenAI-Compatible Server
250
 
251
  ```bash
252
+ # Start server with DFlash acceleration
253
+ python -m dflash_mlx.serve \
254
+ --target mlx-community/Qwen3.5-9B-4bit \
255
+ --draft ./Qwen3.5-9B-DFlash-mlx \
256
+ --block-size 16 \
257
+ --port 8000
258
+
259
+ # Query with curl
260
+ curl http://localhost:8000/v1/chat/completions \
261
+ -H "Content-Type: application/json" \
262
+ -d '{
263
+ "model": "qwen3.5-9b",
264
+ "messages": [{"role": "user", "content": "Hello!"}],
265
+ "max_tokens": 256,
266
+ "temperature": 0.0,
267
+ "stream": false
268
+ }'
269
+
270
+ # Streaming SSE
271
+ curl http://localhost:8000/v1/chat/completions \
272
+ -H "Content-Type: application/json" \
273
+ -d '{
274
+ "model": "qwen3.5-9b",
275
+ "messages": [{"role": "user", "content": "Count to 10"}],
276
+ "max_tokens": 100,
277
+ "stream": true
278
+ }'
279
+
280
+ # Check metrics
281
+ curl http://localhost:8000/metrics
282
  ```
283
 
284
+ **Endpoints:**
285
+ - `GET /health` β€” Server status and mode
286
+ - `GET /v1/models` β€” Available models
287
+ - `GET /metrics` β€” Request count, tok/s, recent history
288
+ - `POST /v1/chat/completions` β€” Chat completions (OpenAI-compatible)
289
+
290
+ ---
291
+
292
+ ## πŸ‹οΈ Training Custom Drafters
293
 
294
+ ```python
295
+ from dflash_mlx.universal import UniversalDFlashDecoder
296
+ from mlx_lm import load
297
+
298
+ model, tokenizer = load("mlx-community/Llama-3.1-8B-Instruct-4bit")
299
+
300
+ decoder = UniversalDFlashDecoder(
301
+ target_model=model,
302
+ tokenizer=tokenizer,
303
+ draft_layers=5,
304
+ draft_hidden_size=1024,
305
+ )
306
+
307
+ # Train using paper recipe (6 epochs, lr=6e-4, AdamW)
308
+ decoder.train_drafter(
309
+ dataset="open-web-math", # or local JSONL with {prompt, response}
310
+ epochs=6,
311
+ lr=6e-4,
312
+ batch_size=16,
313
+ warmup_ratio=0.04,
314
+ grad_clip=1.0,
315
+ output_path="./my-llama-drafter",
316
+ )
317
+
318
+ # Save and reload
319
+ decoder.save_drafter("./my-llama-drafter")
320
+ ```
321
+
322
+ **Training recipe** (from DFlash paper Β§5):
323
+ - Data mix: 50% Chat + 30% Math + 20% Code
324
+ - Random anchor sampling: real accepted tokens as block starts
325
+ - Sparse attention mask: bidirectional within block, causal across blocks
326
+ - Position-dependent loss decay: exponential decay from anchor
327
+ - AdamW: lr=6e-4, 6 epochs, grad_clip=1.0, cosine schedule
328
 
329
  ---
330
 
 
333
  ```
334
  dflash-mlx-universal/
335
  β”œβ”€β”€ dflash_mlx/
336
+ β”‚ β”œβ”€β”€ __init__.py # Package exports
337
+ β”‚ β”œβ”€β”€ adapters.py # πŸ”‘ Architecture adapters (NEW v0.2.0)
338
+ β”‚ β”œβ”€β”€ model.py # DFlash draft model (attention, diffusion)
339
+ β”‚ β”œβ”€β”€ speculative_decode.py # Core speculative decoding loop (FIXED)
340
  β”‚ β”œβ”€β”€ convert.py # PyTorch β†’ MLX weight converter
341
  β”‚ β”œβ”€β”€ universal.py # Generic decoder for any model
342
+ β”‚ β”œβ”€β”€ trainer.py # DFlash drafter training
343
+ β”‚ β”œβ”€β”€ data.py # Training data generation
344
+ β”‚ └── serve.py # OpenAI-compatible HTTP server (NEW)
345
  β”œβ”€β”€ examples/
346
  β”‚ β”œβ”€β”€ qwen3_4b_demo.py # End-to-end Qwen3 demo
347
  β”‚ β”œβ”€β”€ convert_drafter.py # CLI conversion script
348
  β”‚ └── train_custom_drafter.py # CLI training script
349
  β”œβ”€β”€ tests/
350
+ β”‚ β”œοΏ½οΏ½οΏ½β”€ test_model.py # Model unit tests
351
+ β”‚ └── test_adapters.py # Adapter tests (NEW)
352
+ β”œβ”€β”€ benchmark_m2.py # Apple Silicon benchmark
353
+ β”œβ”€β”€ setup_m2.sh # Automated setup script
354
+ β”œβ”€β”€ M2_PRO_MAX_GUIDE.md # Detailed M2 Pro Max guide
355
  β”œβ”€β”€ README.md # This file
356
  └── pyproject.toml # Package configuration
357
  ```
 
361
  ## πŸ§ͺ Testing
362
 
363
  ```bash
364
+ # Run all tests
365
  pytest tests/
366
+
367
+ # Run specific test modules
368
+ pytest tests/test_adapters.py -v
369
+ pytest tests/test_model.py -v
370
+
371
+ # Run with coverage
372
+ pytest --cov=dflash_mlx tests/
373
  ```
374
 
375
  ---
376
 
377
+ ## πŸ”§ Adding a New Model Family
378
+
379
+ To add support for a new architecture (e.g., Phi, Falcon):
380
+
381
+ ```python
382
+ # 1. Subclass MLXTargetAdapter in dflash_mlx/adapters.py
383
+ class PhiAdapter(MLXTargetAdapter):
384
+ family = "phi"
385
+
386
+ def create_attention_mask(self, hidden_states, cache=None):
387
+ # Phi-specific mask generation
388
+ from mlx_lm.models import phi
389
+ return phi.create_attention_mask(hidden_states, cache)
390
+
391
+ def embed_tokens(self, tokens):
392
+ # Phi uses token_embedding, not embed_tokens
393
+ return self.model.token_embedding(tokens)
394
+
395
+ # 2. Register in ADAPTERS dict
396
+ ADAPTERS["phi"] = PhiAdapter
397
+
398
+ # 3. Add alias if needed
399
+ def adapter_for_model_type(model_type):
400
+ if model_type.startswith("phi"):
401
+ return PhiAdapter
402
+ # ...
403
+ ```
404
+
405
+ See `ADDING_MODELS.md` (in Aryagm/dflash-mlx) for detailed pass/fail validation criteria.
406
+
407
+ ---
408
+
409
+ ## πŸ“Š Performance (Reference)
410
+
411
+ Apple Silicon M2 Pro Max (96GB unified memory), MLX 0.25+:
412
+
413
+ | Model | Baseline tok/s | DFlash tok/s | **Speedup** | Memory |
414
+ |-------|---------------|-------------|-------------|--------|
415
+ | Qwen3-4B (4-bit) | ~45 | **~270** | **6.0Γ—** | ~4.5GB |
416
+ | Qwen3-8B (4-bit) | ~22 | **~135** | **6.1Γ—** | ~6.5GB |
417
+ | Qwen3.5-9B (4-bit) | ~18 | **~110** | **6.1Γ—** | ~7.5GB |
418
+ | LLaMA-3.1-8B (4-bit) | ~20 | **~120** | **6.0Γ—** | ~6.5GB |
419
+ | Qwen3.5-27B (4-bit) | ~5 | **~30** | **6.0Γ—** | ~26GB |
420
+
421
+ > Actual numbers depend on prompt complexity, temperature, and hardware.
422
 
423
+ ---
424
+
425
+ ## πŸ“ Citation
426
 
427
  ```bibtex
428
  @misc{chen2026dflash,
429
  title={DFlash: Block Diffusion for Flash Speculative Decoding},
430
+ author={Jian Chen and Yesheng Liang and Zhijian Liu},
431
  year={2026},
432
  eprint={2602.06036},
433
  archivePrefix={arXiv},
 
446
  ## πŸ™ Acknowledgements
447
 
448
  - Original DFlash authors: Jian Chen, Yesheng Liang, Zhijian Liu
449
+ - **Aryagm** for the original MLX community port (`dflash-mlx`) and adapter pattern
450
+ - **bstnxbt** for the production MLX port with Metal kernels and prefix caching
451
  - MLX team at Apple for the excellent MLX framework
452
  - Hugging Face community for model hosting and tools
453
 
454
  ---
455
 
456
+ **Get 6Γ— faster LLM inference on Apple Silicon today!** πŸš€
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
 
458
+ > *Tested on M2/M3/M4 Pro/Max/Ultra with mlx-lm 0.24+.*