Lekr0 committed
Commit 40d87dd · verified · Parent(s): a4f7c99

Add files using upload-large-folder tool

Files changed (50)
  1. IDEA_REPORT.md +187 -0
  2. datasets/_workspace_hanrui_datasets_HuggingFaceH4___aime_2024_default_0.0.0_2fe88a2f1091d5048c0f36abc874fb997b3dd99a.lock +0 -0
  3. datasets/_workspace_hanrui_datasets_MathArena___aime_2025_default_0.0.0_beca2d7875cf92cdac07acefbccad3c4d16e2916.lock +0 -0
  4. datasets/_workspace_hanrui_datasets_google-research-datasets___mbpp_sanitized_0.0.0_4bb6404fdc6cacfda99d4ac4205087b89d32030c.lock +0 -0
  5. datasets/_workspace_hanrui_datasets_json_default-3ab01998402731b9_0.0.0_c181ad2be84b86e0b75142bbe88bda3f4906d051ee75b5ff536a5dba0ffbe8f2.lock +0 -0
  6. datasets/_workspace_hanrui_datasets_princeton-nlp___swe-bench_lite_default_0.0.0_6ec7bb89b9342f664a54a6e0a6ea6501d3437cc2.lock +0 -0
  7. datasets/_workspace_hanrui_datasets_tatsu-lab___alpaca_default_0.0.0_dce01c9b08f87459cf36a430d809084718273017.lock +0 -0
  8. datasets/download_nemotron_codealpha.sh +10 -0
  9. manage_subgits.sh +87 -0
  10. nohup.out +48 -0
  11. progress/dflash_lora_changelog.md +232 -0
  12. progress/list.md +12 -0
  13. progress/oom_fix_progress.md +42 -0
  14. progress/requirements.txt +20 -0
  15. progress/step1.md +139 -0
  16. sglang/.codespellrc +3 -0
  17. sglang/.editorconfig +25 -0
  18. sglang/.isort.cfg +3 -0
  19. sglang/.pre-commit-config.yaml +83 -0
  20. sglang/CODE_OF_CONDUCT.md +128 -0
  21. sglang/LICENSE +201 -0
  22. sglang/README.md +90 -0
  23. syxin_old/DFLASH_LORA_INJECT_FIXES.md +142 -0
  24. syxin_old/backup.log +0 -0
  25. syxin_old/dflash_8gpu_03-31-13:40.log +552 -0
  26. syxin_old/diagnostic_compare.py +301 -0
  27. syxin_old/eval_alignment_diff.md +132 -0
  28. syxin_old/eval_dflash_b16_baseline.py +354 -0
  29. syxin_old/eval_dflash_b16_baseline_changelog.md +143 -0
  30. syxin_old/eval_dflash_lora_inject.py +660 -0
  31. syxin_old/eval_gsm8k_humaneval_mtbench.log +81 -0
  32. syxin_old/eval_run.log +0 -0
  33. syxin_old/launch_train.sh +37 -0
  34. syxin_old/launch_train_dflash_wrapper.py +17 -0
  35. syxin_old/launch_train_random_anchor.py +15 -0
  36. syxin_old/launch_train_wrapper.py +21 -0
  37. syxin_old/list.md +12 -0
  38. syxin_old/merge_lora.py +66 -0
  39. syxin_old/oom_fix_progress.md +42 -0
  40. syxin_old/random_anchor_plan.md +82 -0
  41. syxin_old/requirements.txt +0 -0
  42. syxin_old/run_bench_dflash.sh +71 -0
  43. syxin_old/run_bench_dflash_b16_baseline.sh +60 -0
  44. syxin_old/run_bench_dflash_lora_inject.sh +60 -0
  45. syxin_old/run_qwen3_8b_sft_64gpu.sh +31 -0
  46. syxin_old/run_train_dflash_lora_inject.sh +73 -0
  47. syxin_old/run_train_multinode.sh +67 -0
  48. syxin_old/run_train_multinode_random_anchor.sh +72 -0
  49. syxin_old/start_server.sh +42 -0
  50. syxin_old/start_server_dflash.sh +54 -0
IDEA_REPORT.md ADDED
@@ -0,0 +1,187 @@
# DFlash Improvement Ideas: Higher Acceptance Length Without Training

**Goal:** Improve DFlash's acceptance length (tau) and acceleration ratio using only inference-time modifications — no additional training.

**Baseline:** Qwen3-4B + z-lab/Qwen3-4B-DFlash-b16, block_size=16, math500 (10 samples, 512 tokens)
- **Baseline avg tau = 8.63**, median = 8.0

---
## Idea 1: Iterative Block Refinement (Multi-Step Denoising)⭐⭐⭐⭐⭐

**Core Idea:** Run the DFlash draft model multiple times on the same block. After each pass, use the sampled tokens as updated noise embeddings for the next pass, mimicking multi-step diffusion denoising.

**Why it might work:** DFlash currently uses a single forward pass to predict all block tokens from mask tokens. The initial mask embeddings carry no information about what the draft should generate. By iterating, each pass conditions on an increasingly informed noise context — the first pass gives a rough draft, the second pass refines it with better token embeddings as context.

**Implementation complexity:** Low. Just loop the draft forward pass 2-3 times, feeding output back as input. No KV cache across steps.

**Expected improvement:** +0.5 to +2.0 tau (denoising is the core mechanism of diffusion models — more steps should help).

**Risk:** Extra draft compute may negate speedup gains. Must keep step count low (2-3) to maintain wall-clock advantage.

**Pilot result:** `[PENDING]`
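A minimal sketch of the refinement loop, assuming a hypothetical `draft_forward(context, noise_ids)` wrapper that returns per-position logits of shape `[block_size, vocab]` for one draft pass (numpy stands in for the real tensors):

```python
import numpy as np

def iterative_refine(draft_forward, context, block_size, num_steps=2, mask_id=0):
    """Multi-step denoising sketch: each pass's greedy tokens become the
    next pass's noise tokens. draft_forward is a hypothetical stand-in
    for one DFlash draft forward."""
    noise = np.full(block_size, mask_id, dtype=np.int64)  # pass 1: all-mask block
    for _ in range(num_steps):
        logits = draft_forward(context, noise)   # [block_size, vocab]
        noise = logits.argmax(axis=-1)           # feed greedy tokens back as new "noise"
    return noise
```

With `num_steps=1` this reduces to the current single-pass behavior, so the baseline is recovered exactly.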
## Idea 1 plus: Confidence-Gated Selective Redrafting

**Core Idea:** After the first draft pass, compute per-position entropy of the draft logits. If any position (especially early ones) has high entropy (>threshold), run a second draft pass with the partially-filled block as context. Only replace the high-entropy positions with the second pass's predictions.

**Why it might work:** High entropy at a position signals that the draft model is uncertain — these are the positions most likely to cause rejection. A second pass, now conditioned on a partially-correct draft, can refine exactly these problematic positions.

**Implementation complexity:** Medium. Two draft passes + entropy computation + selective replacement.

**Expected improvement:** +0.5 to +2.0 tau (targeted improvement where it matters most).

**Risk:** Extra compute for the second pass. Entropy threshold needs tuning per dataset/model.

**Pilot result:** `[PENDING]`
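The gate itself is cheap to prototype. A numpy sketch, where the entropy threshold is an assumed starting value that would need tuning:

```python
import numpy as np

def entropy_gate(logits, threshold=1.5):
    """Per-position Shannon entropy of draft logits (nats); positions above
    the threshold are flagged for redrafting."""
    z = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    entropy = -(p * np.log(np.clip(p, 1e-12, None))).sum(axis=-1)
    return entropy > threshold

def selective_replace(first_pass_tokens, second_pass_tokens, redo_mask):
    """Keep first-pass tokens except where the gate fired."""
    return np.where(redo_mask, second_pass_tokens, first_pass_tokens)
```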
---

## Idea 2: N-Best Draft Proposals (Multi-Candidate Selection)⭐

**Core Idea:** Generate K candidate draft blocks (K=2-4) using different sampling strategies (greedy + temperature-based), then select the candidate with the highest aggregate log-probability under the draft model's own distribution.

**Why it might work:** Exact-match acceptance is binary — a single wrong token kills the entire suffix. By generating multiple candidates and picking the most confident one, we increase the probability that at least one candidate matches the target's greedy output. The confidence score acts as a proxy for "likely to match target."

**Implementation complexity:** Low-Medium. K forward passes per block, simple confidence scoring.

**Expected improvement:** +0.5 to +2.5 tau (especially for "unlucky" blocks where the default greedy choice is wrong).

**Risk:** K times the draft compute cost. Must keep K small. Confidence score may not perfectly correlate with acceptance.

**Pilot result:** `[PENDING]`
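Candidate selection can be sketched as scoring each block under the draft's own log-probabilities (function names here are illustrative, not the real API):

```python
import numpy as np

def log_softmax(x):
    z = x - x.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def pick_most_confident(candidates, candidate_logits):
    """candidates: K token arrays; candidate_logits: K [block, vocab] arrays.
    Returns the candidate with the highest aggregate log-probability."""
    scores = [
        log_softmax(lg)[np.arange(len(tok)), tok].sum()
        for tok, lg in zip(candidates, candidate_logits)
    ]
    return candidates[int(np.argmax(scores))]
```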
---

## Idea 6: Token Recycling / Warm-Start Drafting⭐⭐⭐

**Core Idea:** When rejection occurs at position j in a block of B tokens, the rejected tokens at positions j+1..B are discarded. Instead, save these tokens and use them to warm-start the noise embeddings of the next draft block. This gives the draft model a better starting point than random mask tokens.

**Why it might work:** Even though the prefix was wrong, later tokens in the rejected draft may still carry useful distributional information about the continuation. Using them as initial noise (instead of mask tokens) gives the draft model more context for its single-pass prediction.

**Implementation complexity:** Low. Save rejected suffix, inject into next block's initial embeddings.

**Expected improvement:** +0.3 to +1.0 tau (modest, since the recycled tokens are conditioned on a wrong prefix).

**Risk:** Recycled tokens may actually mislead the draft model if they were generated from a very different prefix. Net effect could be negative.

**Pilot result:** `[PENDING]`
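A sketch of the warm start at the token-id level; `mask_id` and the `rejected_suffix` handoff are assumed details of the surrounding decode loop:

```python
import numpy as np

def warm_start_noise(rejected_suffix, block_size, mask_id=0):
    """Seed the next block's noise ids with the suffix tokens that were
    just rejected, instead of an all-mask block; leftover slots stay masked."""
    noise = np.full(block_size, mask_id, dtype=np.int64)
    n = min(len(rejected_suffix), block_size)
    noise[:n] = rejected_suffix[:n]
    return noise
```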
---

## Idea 9: Dynamic Target Layer Selection

**Core Idea:** Instead of always extracting features from the same 5 fixed target layers, try alternative layer selections (e.g., shifted by +2 or -2) and pick the one that produces the highest-confidence draft. Different parts of the sequence may benefit from different layers.

**Why it might work:** The paper's ablation (Table 5) shows that layer selection affects acceptance length. The optimal layers may vary by position in the sequence or by the type of content being generated. Late layers have more "final answer" information; early layers have more syntactic/structural information.

**Implementation complexity:** Medium. Multiple draft passes with different layer configs + scoring.

**Expected improvement:** +0.3 to +1.5 tau (if the fixed layers are suboptimal for certain content types).

**Risk:** The draft model's fc projection was trained on specific layer combinations. Using different layers degrades the learned alignment. Needs the fc layer to generalize.

**Pilot result:** `[PENDING]`
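A sketch of the selection loop, assuming a hypothetical `draft_forward(context, cfg)` that reruns the draft with a given layer set and returns block logits; mean max-probability serves as the confidence score:

```python
import numpy as np

def pick_layer_config(draft_forward, layer_configs, context):
    """Try each candidate layer set (e.g. the default and its ±2 shifts)
    and keep the one whose draft block is most confident."""
    best_cfg, best_conf = None, -np.inf
    for cfg in layer_configs:
        logits = draft_forward(context, cfg)              # [block, vocab]
        z = logits - logits.max(axis=-1, keepdims=True)
        p = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
        conf = p.max(axis=-1).mean()                      # mean max-prob
        if conf > best_conf:
            best_cfg, best_conf = cfg, conf
    return best_cfg
```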
---

## Idea 11: Top-K Constrained Draft Sampling⭐
**Core Idea:** Apply top-k filtering to draft logits before sampling, zeroing out all but the top-k tokens at each position. This forces the draft to choose among only the most probable tokens.

**Why it might work:** For exact-match acceptance under greedy target decoding, only the target's argmax token matters. By restricting the draft's vocabulary to its own top-k, we reduce the chance of sampling a low-probability token that definitely won't match the target.

**Implementation complexity:** Very low. Single top-k operation on logits.

**Expected improvement:** +0.1 to +0.5 tau (minor, since greedy draft already picks argmax; mainly helps with stochastic target).

**Risk:** Under greedy draft + greedy target, this is a no-op. Only helps when draft uses non-zero temperature.

**Pilot result:** `[PENDING]`
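The filtering step is a one-liner per position. A numpy sketch, masking everything outside the top-k to `-inf` so it can never be sampled:

```python
import numpy as np

def topk_filter(logits, k=10):
    """Keep only the top-k logits per position; everything else -> -inf."""
    out = np.full_like(logits, -np.inf)
    idx = np.argpartition(logits, -k, axis=-1)[..., -k:]   # indices of top-k
    np.put_along_axis(out, idx, np.take_along_axis(logits, idx, axis=-1), axis=-1)
    return out
```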
---

## Idea 12: Position-Weighted Logit Scaling⭐⭐

**Core Idea:** Scale draft logits by a position-dependent factor: early positions get more aggressive scaling (sharper distribution = higher confidence), later positions get gentler scaling. Rationale: early positions matter most for prefix-based acceptance.

**Why it might work:** By sharpening early positions, we increase the probability that positions 1-3 are correct (the most critical for tau). Later positions can afford to be less sharp since they only matter if all earlier positions are accepted.

**Implementation complexity:** Very low. Multiply logits by a position-dependent vector.

**Expected improvement:** +0.2 to +1.0 tau.

**Risk:** Over-sharpening may concentrate probability on a wrong token. Needs careful calibration of the scaling schedule.

**Pilot result:** `[PENDING]`
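A sketch with a linear temperature ramp; the 0.5→1.0 schedule is an arbitrary starting point, not a calibrated choice:

```python
import numpy as np

def position_scaled_logits(logits, t_start=0.5, t_end=1.0):
    """Per-position temperature ramp over a [block_size, vocab] logit block:
    sharp (low T) at early positions, unchanged (T=1) at the end."""
    block_size = logits.shape[0]
    temps = np.linspace(t_start, t_end, block_size)
    return logits / temps[:, None]
```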
---

## Bonus Ideas (Not Yet Implemented)

### Idea 13: Tree-Structured Verification
Verify multiple candidate continuations in a single batched target forward pass using packed attention with tree causal masks. This doesn't improve tau per-candidate but amortizes the verification cost across candidates, enabling higher effective throughput. Very promising for combining with N-best or beam approaches.

### Idea 16: Draft-Target KL Alignment via Inference-Time Calibration⭐⭐⭐
Compute a lightweight calibration mapping (affine transform on draft logits) by running a small calibration set and measuring draft vs target token agreement. Apply this calibration at inference time without retraining.

### Idea 17: Multi-Block Pipelining
Overlap the draft and verification phases across blocks. While the target model verifies block k, the draft model starts working on block k+1 using a speculative target_hidden extrapolation. If the speculation was right, the pipeline stays full.
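For Idea 16, a minimal sketch fitting one global scale/bias pair by least squares on paired draft/target logits; the real mapping could be richer (e.g. per-position or per-logit), so this only illustrates the affine-calibration step:

```python
import numpy as np

def fit_affine_calibration(draft_logits, target_logits):
    """Least-squares fit of a, b with a * draft + b ~= target over a small
    calibration set; apply the (a, b) pair at inference time."""
    x = draft_logits.ravel()
    y = target_logits.ravel()
    A = np.stack([x, np.ones_like(x)], axis=1)   # design matrix [x | 1]
    (a, b), *_ = np.linalg.lstsq(A, y, rcond=None)
    return a, b
```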
---

## Experiment Configuration

| Parameter | Value |
|-----------|-------|
| Target model | Qwen/Qwen3-4B |
| Draft model | z-lab/Qwen3-4B-DFlash-b16 |
| Block size | 16 |
| Dataset | math500 |
| Max samples | 10 |
| Max new tokens | 512 |
| Temperature | 0.0 (greedy) |
| GPU | NVIDIA H200 (single GPU) |
| Attention | SDPA |

## Results Summary

| # | Method | Avg tau | Delta | Pilot Signal |
|---|--------|---------|-------|--------------|
| 0 | **Baseline** | **8.63** | - | - |
| 1 | Iterative Refinement (2 steps) | `[PENDING]` | | |
| 2 | Iterative Refinement (3 steps) | `[PENDING]` | | |
| 3 | N-Best Draft (K=2) | `[PENDING]` | | |
| 4 | N-Best Draft (K=3) | `[PENDING]` | | |
| 5 | Adaptive Block Size (4-16) | `[PENDING]` | | |
| 6 | Early-Position Beam (width=3) | `[PENDING]` | | |
| 7 | Draft Temp t=0.3 | `[PENDING]` | | |
| 8 | Draft Temp t=0.1 | `[PENDING]` | | |
| 9 | Token Recycling | `[PENDING]` | | |
| 10 | Selective Redraft (ent>1.5) | `[PENDING]` | | |
| 11 | Selective Redraft (ent>1.0) | `[PENDING]` | | |
| 12 | Majority Vote (K=3) | `[PENDING]` | | |
| 13 | Majority Vote (K=5) | `[PENDING]` | | |
| 14 | Shifted Target Layers (+2) | `[PENDING]` | | |
| 15 | Logit Averaging (2 pass) | `[PENDING]` | | |
| 16 | Logit Averaging (3 pass) | `[PENDING]` | | |
| 17 | Top-K Constrained (k=10) | `[PENDING]` | | |
| 18 | Position-Weighted Temp | `[PENDING]` | | |

---

*Generated 2026-04-01. Experiments running on NVIDIA H200, dflash conda env.*
datasets/_workspace_hanrui_datasets_HuggingFaceH4___aime_2024_default_0.0.0_2fe88a2f1091d5048c0f36abc874fb997b3dd99a.lock ADDED
File without changes
datasets/_workspace_hanrui_datasets_MathArena___aime_2025_default_0.0.0_beca2d7875cf92cdac07acefbccad3c4d16e2916.lock ADDED
File without changes
datasets/_workspace_hanrui_datasets_google-research-datasets___mbpp_sanitized_0.0.0_4bb6404fdc6cacfda99d4ac4205087b89d32030c.lock ADDED
File without changes
datasets/_workspace_hanrui_datasets_json_default-3ab01998402731b9_0.0.0_c181ad2be84b86e0b75142bbe88bda3f4906d051ee75b5ff536a5dba0ffbe8f2.lock ADDED
File without changes
datasets/_workspace_hanrui_datasets_princeton-nlp___swe-bench_lite_default_0.0.0_6ec7bb89b9342f664a54a6e0a6ea6501d3437cc2.lock ADDED
File without changes
datasets/_workspace_hanrui_datasets_tatsu-lab___alpaca_default_0.0.0_dce01c9b08f87459cf36a430d809084718273017.lock ADDED
File without changes
datasets/download_nemotron_codealpha.sh ADDED
@@ -0,0 +1,10 @@
#!/bin/bash

export HF_TOKEN="YOUR_HF_TOKEN_HERE"
export HF_HUB_ENABLE_HF_TRANSFER=1
export HF_HUB_VERBOSITY=debug

hf download \
    --repo-type dataset \
    --local-dir /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    eigen-ai-labs/Nemotron-CodeAlpaca-qwen3-8b-800K
manage_subgits.sh ADDED
@@ -0,0 +1,87 @@
#!/bin/bash
# Manage the .git folders in subdirectories: back up, delete, restore
# Usage:
#   ./manage_subgits.sh backup  - back up and delete .git in subdirectories
#   ./manage_subgits.sh restore - restore .git from the backups

set -euo pipefail

cd "$(dirname "$0")"

BACKUP_DIR=".git_backups"
MANIFEST="$BACKUP_DIR/manifest.txt"

backup() {
    if [ -d "$BACKUP_DIR" ]; then
        echo "❌ Backup directory $BACKUP_DIR already exists; restore or delete it first"
        exit 1
    fi

    mkdir -p "$BACKUP_DIR"
    > "$MANIFEST"

    count=0
    while IFS= read -r gitdir; do
        count=$((count + 1))
        echo "$count|$gitdir" >> "$MANIFEST"

        echo "📦 Backing up: $gitdir"
        cp -a "$gitdir" "$BACKUP_DIR/$count"

        echo "🗑️ Deleting: $gitdir"
        rm -rf "$gitdir"
    done < <(find . -mindepth 2 -name ".git" -not -path "./$BACKUP_DIR/*" | sort)

    if [ "$count" -eq 0 ]; then
        rm -rf "$BACKUP_DIR"
        echo "ℹ️ No .git found in subdirectories; nothing to do"
    else
        echo ""
        echo "✅ Done! Backed up and deleted $count .git folder(s)"
        echo "📁 Backups stored in: $BACKUP_DIR/"
        echo "👉 After the upload finishes, run: $0 restore"
    fi
}

restore() {
    if [ ! -f "$MANIFEST" ]; then
        echo "❌ Manifest $MANIFEST not found; nothing to restore"
        exit 1
    fi

    count=0
    while IFS='|' read -r id gitdir; do
        if [ ! -d "$BACKUP_DIR/$id" ]; then
            echo "⚠️ Skipping: backup #$id does not exist ($gitdir)"
            continue
        fi

        mkdir -p "$(dirname "$gitdir")"

        echo "♻️ Restoring: $gitdir"
        cp -a "$BACKUP_DIR/$id" "$gitdir"
        count=$((count + 1))
    done < "$MANIFEST"

    rm -rf "$BACKUP_DIR"

    echo ""
    echo "✅ Done! Restored $count .git folder(s)"
    echo "🧹 Backup directory cleaned up"
}

case "${1:-}" in
    backup)
        backup
        ;;
    restore)
        restore
        ;;
    *)
        echo "Usage: $0 {backup|restore}"
        echo ""
        echo "  backup  - back up all .git in subdirectories, then delete them"
        echo "  restore - restore all .git from the backups"
        exit 1
        ;;
esac
nohup.out ADDED
@@ -0,0 +1,48 @@
/workspace/miniconda3/envs/dflash/bin/python3: can't open file '/workspace/hanrui/ ': [Errno 2] No such file or directory
E0317 16:57:14.100000 140364991186752 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 2) local_rank: 0 (pid: 14058) of binary: /workspace/miniconda3/envs/dflash/bin/python3
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/run.py", line 905, in <module>
    main()
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/run.py", line 901, in main
    run(args)
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-03-17_16:57:14
  host      : job-006ce80a7c47-20260302193512-5dcd4c9bbd-gfjsn
  rank      : 0 (local_rank: 0)
  exitcode  : 2 (pid: 14058)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
usage: run.py [-h] [--nnodes NNODES] [--nproc-per-node NPROC_PER_NODE]
              [--rdzv-backend RDZV_BACKEND] [--rdzv-endpoint RDZV_ENDPOINT]
              [--rdzv-id RDZV_ID] [--rdzv-conf RDZV_CONF] [--standalone]
              [--max-restarts MAX_RESTARTS]
              [--monitor-interval MONITOR_INTERVAL]
              [--start-method {spawn,fork,forkserver}] [--role ROLE] [-m]
              [--no-python] [--run-path] [--log-dir LOG_DIR] [-r REDIRECTS]
              [-t TEE] [--local-ranks-filter LOCAL_RANKS_FILTER]
              [--node-rank NODE_RANK] [--master-addr MASTER_ADDR]
              [--master-port MASTER_PORT] [--local-addr LOCAL_ADDR]
              [--logs-specs LOGS_SPECS]
              training_script ...
run.py: error: the following arguments are required: training_script, training_script_args
progress/dflash_lora_changelog.md ADDED
@@ -0,0 +1,232 @@
# DFlash LoRA Full Change Log

## Overview

To get Qwen3-8B DFlash LoRA training running on 2×H100 (resolving OOM), **5 files were added or modified, 1084 lines of code** in total. The changes fall into two stages: initial scaffolding + OOM fixes.

---

## New Files

| File | Lines | Purpose |
|------|------|------|
| `specforge/core/dflash_lora.py` | 453 | Training wrapper (OnlineDFlashLoRAModel) |
| `specforge/modeling/draft/dflash_lora.py` | 141 | LoRA draft model (DFlashLoRADraftModel) |
| `scripts/train_dflash_lora.py` | 449 | Training entry script |
| `scripts/run_train_dflash_lora.sh` | 31 | Launch shell script |
| `configs/qwen3-8b-dflash-lora.json` | 10 | LoRA config file |

---

## Step 1 Walkthrough

### 1.1 Analyzing the existing code

First, the full pipeline of the non-LoRA `train_dflash.py` was analyzed:

```
input_ids → target_model.generate_dflash_data() → hidden_states
         → OnlineDFlashModel.forward():
             1. truncate to a block boundary
             2. prepare_noise_input(): keep the anchor, replace the rest with MASK
             3. embed_tokens(noise_input_ids) → noise_embedding
             4. build the DFlash attention mask
             5. draft_model(noise_embedding, target_hidden, mask)
             6. lm_head(hidden) → logits → CE loss
```

The non-LoRA version uses a separate small draft model plus a frozen target model to extract hidden states.

### 1.2 Design differences for the LoRA version

| Aspect | Non-LoRA (`train_dflash.py`) | LoRA (`train_dflash_lora.py`) |
|------|------|------|
| Draft model | Custom small model (1-10 layers) | Qwen3-8B + PEFT LoRA |
| Target model | Frozen large model extracts hidden states | Not needed — the model uses its own representations |
| Attention | Custom Qwen3DFlashAttention, KV = [ctx, noise] concat | Standard HF attention + DFlash mask |
| KV layout | Q_LEN = noise_len, KV_LEN = 2×noise_len | Q_LEN = KV_LEN = seq_len |
| Trainable params | All draft model parameters | LoRA only (q/k/v/o_proj) |

### 1.3 Three new core files for the LoRA version

#### `specforge/modeling/draft/dflash_lora.py` — DFlashLoRADraftModel

- `from_pretrained()`: loads Qwen3-8B, injects PEFT LoRA, supports an `attn_implementation` argument
- `forward()`: standard HF forward, supports an `output_hidden_states` argument (needed for chunked loss)
- `get_lm_head()`: reaches through the PEFT hierarchy to get the lm_head reference
- `gradient_checkpointing_enable()`: proxies to the underlying model
- `save_pretrained()`: saves only the LoRA adapter weights

#### `specforge/core/dflash_lora.py` — OnlineDFlashLoRAModel

- `prepare_noise_input()`: context part unchanged; inside each block only the anchor is kept, the rest are replaced with MASK
- `build_dflash_full_attn_mask_fast()`: vectorized construction of the 4D additive mask `[bsz, 1, seq, seq]`
- `_compute_loss_weights()`: weight 0 for context + anchor, weight 1 (or decayed) for non-anchor positions
- `_full_lm_loss()`: standard CE loss path
- `_compute_accuracy()`: block-wise acceptance rate (cumulative correctly-predicted length / non-anchor block length)
- `forward()`: full training forward pass

Mask rules for the LoRA version:
- context token i → causal attention (j ≤ i)
- block token i (in block b) → all context + bidirectional attention within the same block

#### `scripts/train_dflash_lora.py` — training script

- Argument parsing: 7 groups — model/lora/dataset/training/output/distributed/tracker
- `build_model()`: loads the model + injects LoRA + wraps OnlineDFlashLoRAModel
- `build_dataloader()`: reuses `build_eagle3_dataset` and `prepare_dp_dataloaders`
- FSDP wrapping + BF16Optimizer
- Training loop: forward → backward → accumulation → optimizer step
- checkpoint save/resume

---

## OOM Fixes (4 changes)

### Change 1: FSDP FULL_SHARD (ZeRO-3)

**Problem**: `SHARD_GRAD_OP` (ZeRO-2) keeps the full Qwen3-8B parameters (~16GB bf16) on every GPU

**Fix**: `train_dflash_lora.py:362`
```python
# before
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
# after
sharding_strategy=ShardingStrategy.FULL_SHARD
```

**Effect**: parameters sharded across GPUs, saving ~8-12GB per GPU

### Change 2: batch_size=1 + accumulation_steps=8

**Problem**: peak GPU memory too high with `batch_size=2`

**Fix**: `run_train_dflash_lora.sh`
```bash
--batch-size 1 \
--accumulation-steps 8 \
```

**Effect**: effective global batch size unchanged, peak memory halved

### Change 3: flex_attention + BlockMask instead of the 4D additive mask

**Problem**: SDPA does not support 4D additive masks → falls back to the math backend → every layer materializes the full `[bsz, 32heads, 2048, 2048]` attention scores

**Fix**: ported the `_get_or_create_block_mask()` method from the non-LoRA `dflash.py`, adapted for the LoRA setting

Files involved:

1. **`specforge/core/dflash_lora.py`**
   - `__init__()`: added an `attention_backend` argument (default `"flex_attention"`) and a BlockMask cache field
   - new `_get_or_create_block_mask()`: builds a zero-memory BlockMask with `create_block_mask()`
   - `forward()`: chooses BlockMask or additive mask depending on `attention_backend`

2. **`specforge/modeling/draft/dflash_lora.py`**
   - `from_pretrained()`: when the backend is flex_attention, passes `attn_implementation="flex_attention"` to HuggingFace

3. **`scripts/train_dflash_lora.py`**
   - `parse_args()`: `--attention-backend` argument (`flex_attention` | `additive`)
   - `build_model()`: picks `attn_implementation` based on the backend

BlockMask mask function (LoRA version):
```python
def dflash_lora_mask_fn(b, h, q_idx, kv_idx):
    # Context query: standard causal
    is_q_ctx = q_idx < context_len
    ctx_visible = is_q_ctx & (kv_idx <= q_idx)

    # Block query: all context + bidirectional within the same block
    is_q_block = q_idx >= context_len
    is_k_ctx = kv_idx < context_len
    q_block_id = (q_idx - context_len) // block_size
    k_block_id = (kv_idx - context_len) // block_size
    block_attend_ctx = is_q_block & is_k_ctx
    block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id)

    return ctx_visible | (block_attend_ctx | block_attend_same)
```

**Verification**: compared BlockMask against the additive mask element by element; the patterns match exactly across three test settings (context_len=4/0, seq=12/16/64).

**Effect**: no more fallback to the SDPA math backend, eliminating the `[bsz, heads, seq, seq]` attention-score memory

### Change 4: chunked cross-entropy loss

**Problem**: `[bsz, 2048, 151936]` bf16 logits ≈ 1.18GB, plus gradients ~2.4GB+

**Fix**: ported the chunked loss from the non-LoRA `dflash.py:419-478`

Files involved:

1. **`specforge/core/dflash_lora.py`**
   - `__init__()`: added an `lm_head_chunk_size` argument (default 0 = disabled)
   - new `_chunked_lm_loss()`: runs lm_head + CE loss chunk by chunk with gradient checkpointing
   - extracted `_full_lm_loss()`: the original non-chunked path
   - `forward()`: takes the chunked path when `lm_head_chunk_size > 0`

2. **`specforge/modeling/draft/dflash_lora.py`**
   - `forward()`: new `output_hidden_states` argument; when True, returns the last hidden state instead of logits
   - `get_lm_head()`: reaches through the PEFT hierarchy to return the `base_model.lm_head` reference

3. **`scripts/train_dflash_lora.py`**
   - `parse_args()`: `--lm-head-chunk-size` argument (default 0, recommended 256)
   - `build_model()`: passed through to OnlineDFlashLoRAModel

Core chunked-loss logic:
```python
# Compute chunk by chunk; each chunk uses gradient checkpointing
# (logits are recomputed during backward instead of being stored)
for start in range(0, effective_len, chunk_size):
    end = min(start + chunk_size, effective_len)
    chunk_loss, chunk_weight = grad_checkpoint(
        _chunk_ce,                    # lm_head + CE
        hidden[:, start:end, :],      # only the current chunk
        input_ids[:, start:end],
        combined_mask[:, start:end],
        use_reentrant=False,
    )
    total_loss += chunk_loss
    total_weight += chunk_weight
loss = total_loss / total_weight
```

**Effect**: peak logits memory drops from `O(seq_len × vocab_size)` to `O(chunk_size × vocab_size)`; with chunk size 256 → ~150MB vs 1.18GB

---

## Current Training Command

```bash
bash run_train_dflash_lora.sh 2   # 2 = number of GPUs
```

Equivalent full command:
```bash
torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \
    --model-path /workspace/Qwen3-8B \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir outputs/qwen3-8b-dflash-lora \
    --lora-config configs/qwen3-8b-dflash-lora.json \
    --block-size 16 \
    --max-length 2048 \
    --batch-size 1 \
    --num-epochs 3 \
    --learning-rate 2e-4 \
    --accumulation-steps 8 \
    --loss-decay-gamma 7 \
    --attention-backend flex_attention \
    --lm-head-chunk-size 256 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500
```

---

## To Verify

- [ ] Run `bash run_train_dflash_lora.sh 2` and confirm no more OOM
- [ ] Confirm there is no SDPA math fallback warning
- [ ] Watch peak GPU memory
- [ ] Confirm the loss decreases and accuracy rises as expected
progress/list.md ADDED
@@ -0,0 +1,12 @@
### 1. `train_dflash_lora.py`
* Added LoRA; previously a small draft model was called, now prediction uses hidden states + LoRA.
* The `dflash_lora_mask_fn` function lets every token in the draft block being predicted see all other tokens in that block.

### 2. OOM optimizations
* Sharding strategy ZeRO-3: FSDP sharding upgraded from `SHARD_GRAD_OP` to `FULL_SHARD`.
* `batch-size=1`, `accumulation-steps=8`.
* FlexAttention (`dflash_lora_mask_fn`), following the earlier code.
* `_chunked_lm_loss()`: the loss is computed in chunks of 256 + gradient checkpointing.

### Run
* bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2
progress/oom_fix_progress.md ADDED
@@ -0,0 +1,42 @@
+ # DFlash LoRA OOM Fix Log
+
+ ## OOM root-cause analysis
+
+ 1. **SHARD_GRAD_OP (ZeRO-2)** — each GPU holds the full Qwen3-8B parameters (~16GB bf16); parameters are not sharded
+ 2. **SDPA + 4D additive mask** — FlashAttention does not support 4D additive masks, so attention falls back to the math backend, which materializes the full attention scores per layer (`bsz × 32heads × 2048 × 2048`)
+ 3. **Large vocab logits** — `[bsz, 2048, 151936]` bf16 ≈ 1.18GB; with gradients and the boolean-indexing copy, peak usage is ~3-4GB
+ 4. **The machine has only 2 H100s**, but the script defaults to `NUM_GPUS=4`
+
+ ## Completed changes
+
+ ### 1. FSDP sharding changed to FULL_SHARD (ZeRO-3)
+ - File: `SpecForge/scripts/train_dflash_lora.py:347`
+ - `ShardingStrategy.SHARD_GRAD_OP` → `ShardingStrategy.FULL_SHARD`
+ - Effect: parameters sharded across GPUs, saving ~8-12GB per GPU
+
+ ### 2. Lower batch-size, raise accumulation-steps
+ - File: `SpecForge/scripts/run_train_dflash_lora.sh`
+ - `--batch-size 2` → `1`, `--accumulation-steps 4` → `8`
+ - Effect: effective global batch size unchanged, peak memory halved
+
+ ## To verify / follow-up optimizations
+
+ - [ ] Launch with `bash run_train_dflash_lora.sh 2` to ensure 2 GPUs are used
+ - [x] If OOM persists, use a chunked cross-entropy loss to avoid materializing the full vocab logits
+ - [x] Longer term, explore a custom attention kernel with block-sparse mask support to bypass the SDPA math fallback
+
+ ### 3. flex_attention + BlockMask replaces the 4D additive mask
+ - Files: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
+ - Ported the `_get_or_create_block_mask()` method from the non-LoRA `dflash.py`, adapted to the LoRA setting (Q_LEN == KV_LEN == seq_len)
+ - LoRA-version mask: causal over the context + bidirectional within each block (the non-LoRA version uses [context, noise] concatenated as KV)
+ - Enabled via `--attention-backend flex_attention` (the default); `--attention-backend additive` falls back to the original 4D mask
+ - The HuggingFace model is loaded with `attn_implementation="flex_attention"`
+ - Effect: no more fallback to the SDPA math backend, saving the `[bsz, heads, seq, seq]` attention-scores memory
+
+ ### 4. Chunked cross-entropy loss
+ - Files: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
+ - Ported the `_chunked_lm_loss()` method from the non-LoRA `dflash.py`
+ - Runs lm_head + CE loss chunk by chunk with gradient checkpointing, avoiding materialization of the full `[bsz, seq, vocab]` logits
+ - Enabled via `--lm-head-chunk-size 256` (default 0 = disabled)
+ - `DFlashLoRADraftModel.forward()` gains an `output_hidden_states` parameter and returns hidden states when chunking
+ - Effect: peak logits memory drops from O(seq_len × vocab_size) to O(chunk_size × vocab_size)
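The memory figures in the root-cause list follow from simple size arithmetic. A quick sketch (bsz=2 reflects the pre-fix `--batch-size 2`; the log's ~1.18GB estimate is in the same ballpark):

```python
def tensor_bytes(shape, bytes_per_elem=2):  # bf16 = 2 bytes per element
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem

# Full-vocab logits for the pre-fix --batch-size 2 run:
logits = tensor_bytes((2, 2048, 151936))       # [bsz, seq, vocab]
# Per-layer attention scores under the SDPA math-backend fallback:
scores = tensor_bytes((2, 32, 2048, 2048))     # [bsz, heads, seq, seq]
print(f"logits: {logits / 2**30:.2f} GiB")     # ~1.16 GiB, forward only
print(f"attention scores per layer: {scores / 2**30:.2f} GiB")
```

Gradients double the logits footprint, and the boolean-indexing copy adds another full-size allocation, which is how the peak reaches the ~3-4GB range.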
progress/requirements.txt ADDED
@@ -0,0 +1,20 @@
+ torch==2.4.0
+ torchaudio==2.4.0
+ torchvision==0.19.0
+ transformers==4.57.1
+ accelerate
+ datasets
+ peft
+ numpy
+ tqdm
+ wandb
+ psutil
+ pydantic
+ packaging
+ ninja
+ setuptools
+ yunchang
+ sglang<=0.4.10
+ openai-harmony
+ qwen-vl-utils==0.0.11
+ pre-commit
progress/step1.md ADDED
@@ -0,0 +1,139 @@
+ Now I have a complete picture. Here's the full analysis and proposed plan.
+
+   ---
+   Current Logic of train_dflash.py
+
+   This script trains a small, standalone DFlash draft model (1–10 transformer layers) that learns to predict token blocks in parallel,
+   guided by hidden states extracted from a frozen target model.
+
+   Architecture (two separate models)
+
+   1. Target model (frozen, e.g. Qwen3-8B/32B) — runs a full forward pass to produce hidden states at selected layers.
+   2. Draft model (DFlashDraftModel, ~1–10 layers) — a lightweight Qwen3-based decoder that takes noise embeddings + target hidden states
+   and predicts the block tokens.
+   3. Target embed_tokens + lm_head — loaded separately via TargetEmbeddingsAndHead to avoid duplicating the full target model in memory.
+
+   Key locations
+
+   ┌──────────────────────────┬────────────────────────────────────┬───────────────────────────────────────────────────────┐
+   │        Component         │                File                │                         Lines                         │
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
+   │ Model init               │ scripts/train_dflash.py            │ build_models() L254–311                               │
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
+   │ Target hidden extraction │ scripts/train_dflash.py            │ L644–647 (target_model.generate_dflash_data)          │
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
+   │ Forward pass             │ specforge/core/dflash.py           │ OnlineDFlashModel.forward() L243–332                  │
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
+   │ Loss calculation         │ specforge/core/dflash.py           │ _full_lm_loss() L382–417, _chunked_lm_loss() L419–478 │
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
+   │ Loss mask                │ specforge/core/dflash.py           │ create_dflash_loss_mask() L481–509                    │
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
+   │ Draft model architecture │ specforge/modeling/draft/dflash.py │ DFlashDraftModel L212–266                             │
+   ├──────────────────────────┼────────────────────────────────────┼───────────────────────────────────────────────────────┤
+   │ DFlash attention         │ specforge/modeling/draft/dflash.py │ Qwen3DFlashAttention L42–134                          │
+   └──────────────────────────┴────────────────────────────────────┴───────────────────────────────────────────────────────┘
+
+   Forward pass flow (per training step)
+
+   input_ids, attention_mask, loss_mask  →  target_model.generate_dflash_data()
+                                                 ↓
+                                        hidden_states (from target layers [1,9,17,25,33])
+                                                 ↓
+                                 OnlineDFlashModel.forward():
+                                   1. Truncate to block boundary
+                                   2. prepare_noise_input(): anchor tokens kept, rest → MASK
+                                   3. embed_tokens(noise_input_ids) → noise_embedding
+                                   4. Build DFlash attention mask (flex_attention or additive)
+                                   5. draft_model(noise_embedding, target_hidden, mask)
+                                   6. lm_head(hidden) → logits
+                                   7. CE loss on non-anchor positions (weighted by loss_mask × decay)
+
+   The draft model's custom Qwen3DFlashAttention concatenates [context_hidden, noise_hidden] as KV, with queries only from noise tokens. The
+    attention mask enforces: block tokens see all preceding blocks' context + bidirectional within their own block.
+
+   ---
+   What already exists: train_dflash_lora.py
+
+   Interestingly, the repo already has a LoRA variant at scripts/train_dflash_lora.py with its own model (DFlashLoRADraftModel) and wrapper
+   (OnlineDFlashLoRAModel). This is exactly the approach you described — Qwen3-8B + LoRA, no separate target model, 1-step diffusion
+   training. The key differences from train_dflash.py:
+
+   ┌─────────────────┬─────────────────────────────────────────────────────────────┬────────────────────────────────────────────────────┐
+   │     Aspect      │                       train_dflash.py                       │                train_dflash_lora.py                │
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
+   │ Draft model     │ Small custom DFlashDraftModel (1–10 layers)                 │ Full Qwen3-8B + LoRA adapters                      │
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
+   │ Target model    │ Separate frozen model for hidden state extraction           │ None — model uses its own representations          │
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
+   │ Attention       │ Custom Qwen3DFlashAttention (Q from noise, KV from [ctx,    │ Standard HF attention with 4D additive DFlash mask │
+   │                 │ noise])                                                     │                                                    │
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
+   │ Forward         │ draft_model(noise_emb, target_hidden, mask)                 │ model(noise_input_ids, 4d_mask, position_ids) →    │
+   │                 │                                                             │ logits                                             │
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
+   │ Trainable       │ All draft model params                                      │ Only LoRA (q/k/v/o_proj), base frozen              │
+   │ params          │                                                             │                                                    │
+   ├─────────────────┼─────────────────────────────────────────────────────────────┼────────────────────────────────────────────────────┤
+   │ FSDP strategy   │ SHARD_GRAD_OP                                               │ FULL_SHARD                                         │
+   └─────────────────┴─────────────────────────────────────────────────────────────┴────────────────────────────────────────────────────┘
+
+   ---
+   Proposed Modification Plan
+
+   Since train_dflash_lora.py already implements the core idea, the plan focuses on what's missing or needs improvement to make it a proper
+   "1-step dLLM draft model" for your research:
+
+   Phase 1: Validate and extend the existing LoRA pipeline
+
+   1. Add MLP to LoRA targets — The current config only targets q_proj, k_proj, v_proj, o_proj. For stronger 1-step diffusion capability,
+   add gate_proj, up_proj, down_proj to lora_target_modules. This gives the model more capacity to learn the non-autoregressive distribution
+    shift.
+   2. Add multi-step noise schedule support — Currently the training is strictly 1-step (all non-anchors → MASK). For a proper diffusion/AR
+   fusion, add an option for a noise schedule where a fraction of block tokens are revealed (not just the anchor), controlled by a
+   noise_ratio parameter. This would modify prepare_noise_input() in OnlineDFlashLoRAModel:
+   # Instead of: all non-anchor → MASK
+   # Allow: randomly keep some non-anchor tokens with probability (1 - noise_ratio)
+   3. Add configurable context_len strategy — Currently context_len=0 treats the whole sequence as blocks. Add a --context-ratio arg that
+   dynamically sets context_len as a fraction of the sequence, so the model learns to condition on varying amounts of AR-decoded prefix.
+
+   Phase 2: Training logic improvements
+
+   4. Add KL divergence loss — In addition to CE loss against ground truth, add an optional KL loss against the base model's AR distribution
+    (teacher forcing). This regularizes the LoRA model to stay close to the original Qwen3-8B distribution. Modify
+   OnlineDFlashLoRAModel.forward():
+   # Compute base model logits (no_grad, no LoRA) as teacher
+   # KL(draft_logits || teacher_logits) on block positions
+   # total_loss = ce_loss + kl_weight * kl_loss
+   5. Add evaluation with speculative decoding metrics — The current accuracy metric is block-wise acceptance rate. Add an eval loop that
+   actually runs speculative decoding (draft → verify) to measure real speedup, using the LoRA model as the drafter and the base model (with
+    LoRA disabled) as the verifier.
+
+   Phase 3: Integration with train_dflash.py style features
+
+   6. Port random anchor sampling — train_dflash.py has --random-anchor (L147–156) which samples diverse anchor positions per sequence. This
+    is missing from the LoRA variant and would improve training diversity. Port _sample_anchor_positions and _build_blocks_from_anchors from
+    OnlineDFlashModel to OnlineDFlashLoRAModel.
+   7. Port chunked lm_head — The LoRA variant materializes full [bsz, seq_len, vocab_size] logits. For Qwen3-8B's 152K vocab, this is ~1.2GB
+    per sample at seq_len=2048. Port the _chunked_lm_loss from OnlineDFlashModel for memory efficiency.
+   8. Add tensor parallelism support — The LoRA script currently forces tp_size=1. For Qwen3-8B on multi-GPU setups, add TP support to shard
+    the base model across GPUs while keeping LoRA params on each rank.
+
+   Minimal starting point
+
+   If you want to just get training running immediately, the existing train_dflash_lora.py + configs/qwen3-8b-dflash-lora.json should work
+   out of the box:
+
+   torchrun --nproc_per_node=4 scripts/train_dflash_lora.py \
+     --model-path Qwen/Qwen3-8B \
+     --lora-config configs/qwen3-8b-dflash-lora.json \
+     --train-data-path /path/to/data.jsonl \
+     --output-dir ./outputs/qwen3-8b-dflash-lora \
+     --block-size 16 \
+     --gradient-checkpointing \
+     --learning-rate 2e-4 \
+     --batch-size 1 \
+     --max-length 2048 \
+     --num-epochs 3
+
+   Want me to start implementing any of these phases? I'd suggest starting with Phase 1 (noise schedule + MLP targets) since that's the most
+    impactful for the diffusion/AR fusion research angle.
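Item 7's chunked lm_head idea can be illustrated with a small NumPy sketch. This is not SpecForge's `_chunked_lm_loss` (which also applies gradient checkpointing and loss-mask weighting); it only shows that chunk-by-chunk log-softmax + cross-entropy matches the full computation while keeping just one `[chunk, vocab]` logits slab live at a time:

```python
import numpy as np

def chunked_ce_loss(hidden, lm_head, targets, chunk_size=256):
    """Mean cross-entropy over [seq] targets, computing [chunk, vocab]
    logits one chunk at a time instead of the full [seq, vocab] matrix."""
    losses = []
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start:start + chunk_size]              # [chunk, d]
        logits = h @ lm_head                              # [chunk, vocab]
        logits = logits - logits.max(axis=-1, keepdims=True)  # stability
        logp = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
        idx = np.arange(h.shape[0])
        losses.append(-logp[idx, targets[start:start + chunk_size]])
    return np.concatenate(losses).mean()

# Check against the full-matrix computation on toy shapes.
rng = np.random.default_rng(0)
hidden = rng.normal(size=(512, 32))
lm_head = rng.normal(size=(32, 1000))
targets = rng.integers(0, 1000, size=512)

full = hidden @ lm_head
full = full - full.max(axis=-1, keepdims=True)
logp = full - np.log(np.exp(full).sum(axis=-1, keepdims=True))
reference = -logp[np.arange(512), targets].mean()
assert np.isclose(chunked_ce_loss(hidden, lm_head, targets, chunk_size=100), reference)
```

Because per-token losses are simply concatenated before the mean, the chunked result is exact (no approximation), including when the sequence length is not a multiple of the chunk size.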
sglang/.codespellrc ADDED
@@ -0,0 +1,3 @@
+ [codespell]
+ ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS
+ skip = *.json,*.jsonl,*.patch,*.txt
sglang/.editorconfig ADDED
@@ -0,0 +1,25 @@
+ # https://editorconfig.org/
+
+ root = true
+
+ [*]
+ charset = utf-8
+ end_of_line = lf
+ indent_style = space
+ indent_size = 4
+ trim_trailing_whitespace = true
+ insert_final_newline = true
+
+ [*.{json,yaml,yml}]
+ indent_size = 2
+
+ [*.md]
+ indent_size = 2
+ x-soft-wrap-text = true
+
+ [*.rst]
+ indent_size = 4
+ x-soft-wrap-text = true
+
+ [Makefile]
+ indent_style = tab
sglang/.isort.cfg ADDED
@@ -0,0 +1,3 @@
+ [settings]
+ profile=black
+ known_first_party=sglang
sglang/.pre-commit-config.yaml ADDED
@@ -0,0 +1,83 @@
1
+ default_stages: [pre-commit, pre-push, manual]
2
+ exclude: ^(python/sglang/multimodal_gen/csrc|python/sglang/jit_kernel/flash_attention/cute)
3
+
4
+ repos:
5
+ - repo: https://github.com/pre-commit/pre-commit-hooks
6
+ rev: v6.0.0
7
+ hooks:
8
+ - id: check-symlinks
9
+ - id: destroyed-symlinks
10
+ - id: trailing-whitespace
11
+ - id: end-of-file-fixer
12
+ - id: check-yaml
13
+ args: [--allow-multiple-documents]
14
+ - id: check-toml
15
+ - id: check-ast
16
+ - id: check-added-large-files
17
+ - id: check-merge-conflict
18
+ - id: check-shebang-scripts-are-executable
19
+ - id: detect-private-key
20
+ exclude: ^sgl-model-gateway/tests/.*_test\.rs$
21
+ - id: debug-statements
22
+ - id: no-commit-to-branch
23
+ - repo: https://github.com/PyCQA/isort
24
+ rev: 7.0.0
25
+ hooks:
26
+ - id: isort
27
+ exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$'
28
+ - repo: https://github.com/astral-sh/ruff-pre-commit
29
+ rev: v0.15.1
30
+ hooks:
31
+ - id: ruff
32
+ args:
33
+ - --select=F401,F821
34
+ - --fix
35
+ files: ^(benchmark/|docs/|examples/|python/sglang/|sgl-model-gateway/py_*|test/)
36
+ exclude: |
37
+ (?x)^(
38
+ .*/__init__\.py$|
39
+ .*\.ipynb$|
40
+ python/sglang/srt/grpc/.*_pb2\.py$|
41
+ python/sglang/srt/grpc/.*_pb2_grpc\.py$|
42
+ python/sglang/srt/grpc/.*_pb2\.pyi$|
43
+ python/sglang/srt/grpc/.*_pb2_grpc\.pyi$|
44
+ )$
45
+ - repo: https://github.com/psf/black
46
+ rev: 26.1.0
47
+ hooks:
48
+ - id: black-jupyter
49
+ exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$'
50
+ - repo: https://github.com/codespell-project/codespell
51
+ rev: v2.4.1
52
+ hooks:
53
+ - id: codespell
54
+ args: ['--config', '.codespellrc']
55
+ - repo: https://github.com/pre-commit/mirrors-clang-format
56
+ rev: v20.1.7
57
+ hooks:
58
+ - id: clang-format
59
+ types_or: [c++, cuda]
60
+ args: [--style=file, --verbose]
61
+ - repo: https://github.com/kynan/nbstripout
62
+ rev: 0.9.0
63
+ hooks:
64
+ - id: nbstripout
65
+ args:
66
+ - '--keep-output'
67
+ - '--extra-keys=metadata.kernelspec metadata.language_info.version'
68
+ - repo: local
69
+ hooks:
70
+ - id: check-chinese-characters
71
+ name: check chinese characters in multimodal_gen
72
+ entry: >-
73
+ python3 -c 'import sys, re; p=re.compile(r"[\u4e00-\u9fff]"); ec=0; [ ([(print(f"{f}:{i+1}: {l.strip()}") or (ec:=1)) for i,l in enumerate(open(f, "r", encoding="utf-8", errors="ignore")) if p.search(l)]) for f in sys.argv[1:] ]; sys.exit(ec)'
74
+ language: system
75
+ files: ^python/sglang/multimodal_gen/.*
76
+ exclude: ^(python/sglang/multimodal_gen/configs/sample|python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows|python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages)(/|$)
77
+ types_or: [python, markdown, json, text]
78
+ - id: sort-ci-permissions
79
+ name: sort CI_PERMISSIONS.json
80
+ entry: python3 .github/update_ci_permission.py --sort-only
81
+ language: system
82
+ files: ^\.github/CI_PERMISSIONS\.json$
83
+ pass_filenames: false
sglang/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,128 @@
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ We as members, contributors, and leaders pledge to make participation in our
6
+ community a harassment-free experience for everyone, regardless of age, body
7
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
8
+ identity and expression, level of experience, education, socio-economic status,
9
+ nationality, personal appearance, race, religion, or sexual identity
10
+ and orientation.
11
+
12
+ We pledge to act and interact in ways that contribute to an open, welcoming,
13
+ diverse, inclusive, and healthy community.
14
+
15
+ ## Our Standards
16
+
17
+ Examples of behavior that contributes to a positive environment for our
18
+ community include:
19
+
20
+ * Demonstrating empathy and kindness toward other people
21
+ * Being respectful of differing opinions, viewpoints, and experiences
22
+ * Giving and gracefully accepting constructive feedback
23
+ * Accepting responsibility and apologizing to those affected by our mistakes,
24
+ and learning from the experience
25
+ * Focusing on what is best not just for us as individuals, but for the
26
+ overall community
27
+
28
+ Examples of unacceptable behavior include:
29
+
30
+ * The use of sexualized language or imagery, and sexual attention or
31
+ advances of any kind
32
+ * Trolling, insulting or derogatory comments, and personal or political attacks
33
+ * Public or private harassment
34
+ * Publishing others' private information, such as a physical or email
35
+ address, without their explicit permission
36
+ * Other conduct which could reasonably be considered inappropriate in a
37
+ professional setting
38
+
39
+ ## Enforcement Responsibilities
40
+
41
+ Community leaders are responsible for clarifying and enforcing our standards of
42
+ acceptable behavior and will take appropriate and fair corrective action in
43
+ response to any behavior that they deem inappropriate, threatening, offensive,
44
+ or harmful.
45
+
46
+ Community leaders have the right and responsibility to remove, edit, or reject
47
+ comments, commits, code, wiki edits, issues, and other contributions that are
48
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
49
+ decisions when appropriate.
50
+
51
+ ## Scope
52
+
53
+ This Code of Conduct applies within all community spaces, and also applies when
54
+ an individual is officially representing the community in public spaces.
55
+ Examples of representing our community include using an official e-mail address,
56
+ posting via an official social media account, or acting as an appointed
57
+ representative at an online or offline event.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported to the community leaders responsible for enforcement at
63
+ .
64
+ All complaints will be reviewed and investigated promptly and fairly.
65
+
66
+ All community leaders are obligated to respect the privacy and security of the
67
+ reporter of any incident.
68
+
69
+ ## Enforcement Guidelines
70
+
71
+ Community leaders will follow these Community Impact Guidelines in determining
72
+ the consequences for any action they deem in violation of this Code of Conduct:
73
+
74
+ ### 1. Correction
75
+
76
+ **Community Impact**: Use of inappropriate language or other behavior deemed
77
+ unprofessional or unwelcome in the community.
78
+
79
+ **Consequence**: A private, written warning from community leaders, providing
80
+ clarity around the nature of the violation and an explanation of why the
81
+ behavior was inappropriate. A public apology may be requested.
82
+
83
+ ### 2. Warning
84
+
85
+ **Community Impact**: A violation through a single incident or series
86
+ of actions.
87
+
88
+ **Consequence**: A warning with consequences for continued behavior. No
89
+ interaction with the people involved, including unsolicited interaction with
90
+ those enforcing the Code of Conduct, for a specified period of time. This
91
+ includes avoiding interactions in community spaces as well as external channels
92
+ like social media. Violating these terms may lead to a temporary or
93
+ permanent ban.
94
+
95
+ ### 3. Temporary Ban
96
+
97
+ **Community Impact**: A serious violation of community standards, including
98
+ sustained inappropriate behavior.
99
+
100
+ **Consequence**: A temporary ban from any sort of interaction or public
101
+ communication with the community for a specified period of time. No public or
102
+ private interaction with the people involved, including unsolicited interaction
103
+ with those enforcing the Code of Conduct, is allowed during this period.
104
+ Violating these terms may lead to a permanent ban.
105
+
106
+ ### 4. Permanent Ban
107
+
108
+ **Community Impact**: Demonstrating a pattern of violation of community
109
+ standards, including sustained inappropriate behavior, harassment of an
110
+ individual, or aggression toward or disparagement of classes of individuals.
111
+
112
+ **Consequence**: A permanent ban from any sort of public interaction within
113
+ the community.
114
+
115
+ ## Attribution
116
+
117
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118
+ version 2.0, available at
119
+ https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120
+
121
+ Community Impact Guidelines were inspired by [Mozilla's code of conduct
122
+ enforcement ladder](https://github.com/mozilla/diversity).
123
+
124
+ [homepage]: https://www.contributor-covenant.org
125
+
126
+ For answers to common questions about this code of conduct, see the FAQ at
127
+ https://www.contributor-covenant.org/faq. Translations are available at
128
+ https://www.contributor-covenant.org/translations.
sglang/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2023-2024 SGLang Team
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
sglang/README.md ADDED
@@ -0,0 +1,90 @@
+
+
+ <div align="center" id="sglangtop">
+ <img src="https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" alt="logo" width="400" margin="10px"></img>
+
+ [![PyPI](https://img.shields.io/pypi/v/sglang)](https://pypi.org/project/sglang)
+ ![PyPI - Downloads](https://static.pepy.tech/badge/sglang?period=month)
+ [![license](https://img.shields.io/github/license/sgl-project/sglang.svg)](https://github.com/sgl-project/sglang/tree/main/LICENSE)
+ [![issue resolution](https://img.shields.io/github/issues-closed-raw/sgl-project/sglang)](https://github.com/sgl-project/sglang/issues)
+ [![open issues](https://img.shields.io/github/issues-raw/sgl-project/sglang)](https://github.com/sgl-project/sglang/issues)
+ [![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/sgl-project/sglang)
+
+ </div>
+
+ --------------------------------------------------------------------------------
+
+ <p align="center">
+ <a href="https://lmsys.org/blog/"><b>Blog</b></a> |
+ <a href="https://docs.sglang.io/"><b>Documentation</b></a> |
+ <a href="https://roadmap.sglang.io/"><b>Roadmap</b></a> |
+ <a href="https://slack.sglang.io/"><b>Join Slack</b></a> |
+ <a href="https://meet.sglang.io/"><b>Weekly Dev Meeting</b></a> |
+ <a href="https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides"><b>Slides</b></a>
+ </p>
+
+ ## News
+ - [2026/01] 🔥 SGLang Diffusion accelerates video and image generation ([blog](https://lmsys.org/blog/2026-01-16-sglang-diffusion/)).
+ - [2025/12] SGLang provides day-0 support for latest open models ([MiMo-V2-Flash](https://lmsys.org/blog/2025-12-16-mimo-v2-flash/), [Nemotron 3 Nano](https://lmsys.org/blog/2025-12-15-run-nvidia-nemotron-3-nano/), [Mistral Large 3](https://github.com/sgl-project/sglang/pull/14213), [LLaDA 2.0 Diffusion LLM](https://lmsys.org/blog/2025-12-19-diffusion-llm/), [MiniMax M2](https://lmsys.org/blog/2025-11-04-miminmax-m2/)).
+ - [2025/10] 🔥 SGLang now runs natively on TPU with the SGLang-Jax backend ([blog](https://lmsys.org/blog/2025-10-29-sglang-jax/)).
+ - [2025/09] Deploying DeepSeek on GB200 NVL72 with PD and Large Scale EP (Part II): 3.8x Prefill, 4.8x Decode Throughput ([blog](https://lmsys.org/blog/2025-09-25-gb200-part-2/)).
+ - [2025/09] SGLang Day 0 Support for DeepSeek-V3.2 with Sparse Attention ([blog](https://lmsys.org/blog/2025-09-29-deepseek-V32/)).
+ - [2025/08] SGLang x AMD SF Meetup on 8/22: Hands-on GPU workshop, tech talks by AMD/xAI/SGLang, and networking ([Roadmap](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_sglang_roadmap.pdf), [Large-scale EP](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_sglang_ep.pdf), [Highlights](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_highlights.pdf), [AITER/MoRI](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_aiter_mori.pdf), [Wave](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_wave.pdf)).
+
+ <details>
+ <summary>More</summary>
+
+ - [2025/11] SGLang Diffusion accelerates video and image generation ([blog](https://lmsys.org/blog/2025-11-07-sglang-diffusion/)).
+ - [2025/10] PyTorch Conference 2025 SGLang Talk ([slide](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/sglang_pytorch_2025.pdf)).
+ - [2025/10] SGLang x Nvidia SF Meetup on 10/2 ([recap](https://x.com/lmsysorg/status/1975339501934510231)).
+ - [2025/08] SGLang provides day-0 support for OpenAI gpt-oss model ([instructions](https://github.com/sgl-project/sglang/issues/8833))
+ - [2025/06] SGLang, the high-performance serving infrastructure powering trillions of tokens daily, has been awarded the third batch of the Open Source AI Grant by a16z ([a16z blog](https://a16z.com/advancing-open-source-ai-through-benchmarks-and-bold-experimentation/)).
+ - [2025/05] Deploying DeepSeek with PD Disaggregation and Large-scale Expert Parallelism on 96 H100 GPUs ([blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/)).
+ - [2025/06] Deploying DeepSeek on GB200 NVL72 with PD and Large Scale EP (Part I): 2.7x Higher Decoding Throughput ([blog](https://lmsys.org/blog/2025-06-16-gb200-part-1/)).
+ - [2025/03] Supercharge DeepSeek-R1 Inference on AMD Instinct MI300X ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1-Part2/README.html))
+ - [2025/03] SGLang Joins PyTorch Ecosystem: Efficient LLM Serving Engine ([PyTorch blog](https://pytorch.org/blog/sglang-joins-pytorch/))
+ - [2025/02] Unlock DeepSeek-R1 Inference Performance on AMD Instinct™ MI300X GPU ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1_Perf/README.html))
+ - [2025/01] SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html), [10+ other companies](https://x.com/lmsysorg/status/1887262321636221412))
+ - [2024/12] v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
+ - [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
+ - [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
+ - [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
+ - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
+ - [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).
+ - [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
+
+ </details>
+
+ ## About
+ SGLang is a high-performance serving framework for large language models and multimodal models.
+ It is designed to deliver low-latency and high-throughput inference across a wide range of setups, from a single GPU to large distributed clusters.
+ Its core features include:
+
+ - **Fast Runtime**: Provides efficient serving with RadixAttention for prefix caching, a zero-overhead CPU scheduler, prefill-decode disaggregation, speculative decoding, continuous batching, paged attention, tensor/pipeline/expert/data parallelism, structured outputs, chunked prefill, quantization (FP4/FP8/INT4/AWQ/GPTQ), and multi-LoRA batching.
+ - **Broad Model Support**: Supports a wide range of language models (Llama, Qwen, DeepSeek, Kimi, GLM, GPT, Gemma, Mistral, etc.), embedding models (e5-mistral, gte, mcdse), reward models (Skywork), and diffusion models (WAN, Qwen-Image), with easy extensibility for adding new models. Compatible with most Hugging Face models and OpenAI APIs.
+ - **Extensive Hardware Support**: Runs on NVIDIA GPUs (GB200/B300/H100/A100/Spark), AMD GPUs (MI355/MI300), Intel Xeon CPUs, Google TPUs, Ascend NPUs, and more.
+ - **Active Community**: SGLang is open-source and supported by a vibrant community with widespread industry adoption, powering over 400,000 GPUs worldwide.
+ - **RL & Post-Training Backbone**: SGLang is a proven rollout backend across the world, with native RL integrations and adoption by well-known post-training frameworks such as [**AReaL**](https://github.com/inclusionAI/AReaL), [**Miles**](https://github.com/radixark/miles), [**slime**](https://github.com/THUDM/slime), [**Tunix**](https://github.com/google/tunix), [**verl**](https://github.com/volcengine/verl) and more.
+
+ ## Getting Started
+ - [Install SGLang](https://docs.sglang.io/get_started/install.html)
+ - [Quick Start](https://docs.sglang.io/basic_usage/send_request.html)
+ - [Backend Tutorial](https://docs.sglang.io/basic_usage/openai_api_completions.html)
+ - [Frontend Tutorial](https://docs.sglang.io/references/frontend/frontend_tutorial.html)
+ - [Contribution Guide](https://docs.sglang.io/developer_guide/contribution_guide.html)
+
+ ## Benchmark and Performance
+ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/), [Large-scale expert parallelism](https://lmsys.org/blog/2025-05-05-large-scale-ep/), [GB200 rack-scale parallelism](https://lmsys.org/blog/2025-09-25-gb200-part-2/).
+
+ ## Adoption and Sponsorship
+ SGLang has been deployed at large scale, generating trillions of tokens in production each day. It is trusted and adopted by a wide range of leading enterprises and institutions, including xAI, AMD, NVIDIA, Intel, LinkedIn, Cursor, Oracle Cloud, Google Cloud, Microsoft Azure, AWS, Atlas Cloud, Voltage Park, Nebius, DataCrunch, Novita, InnoMatrix, MIT, UCLA, the University of Washington, Stanford, UC Berkeley, Tsinghua University, Jam & Tea Studios, Baseten, and other major technology organizations across North America and Asia.
+ As an open-source LLM inference engine, SGLang has become the de facto industry standard, with deployments running on over 400,000 GPUs worldwide.
+ SGLang is currently hosted under the non-profit open-source organization [LMSYS](https://lmsys.org/about/).
+
+ <img src="https://raw.githubusercontent.com/sgl-project/sgl-learning-materials/refs/heads/main/slides/adoption.png" alt="logo" width="800" margin="10px"></img>
+
+ ## Contact Us
+ For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at sglang@lmsys.org
+
+ ## Acknowledgment
+ We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).
syxin_old/DFLASH_LORA_INJECT_FIXES.md ADDED
@@ -0,0 +1,142 @@
+ # DFlash LoRA Inject Fix Plan
+
+ ## Background
+
+ Evaluation results after DFlash LoRA inject training were extremely poor:
+ - GSM8K: 5.05× → **1.04×** (baseline → LoRA inject)
+ - HumanEval: 5.06× → **0.98×**
+ - MT-Bench: 2.70× → **0.85×**
+
+ Root cause: the LoRA weights were saved in an incorrect format, compounded by an inefficient draft-model forward pass.
+
+ ---
+
+ ## Fix 1: LoRA Weight Saving (most critical)
+
+ ### File
+ `Specforge/scripts/train_dflash_lora_inject.py`
+
+ ### Problem
+ The `save_checkpoint()` function (L292-306) manually extracts the state_dict keys containing `"lora_"` and saves them as `adapter_model.safetensors`, then calls `peft_config["default"].save_pretrained()`, which only writes out the LoraConfig JSON.
+
+ However, PEFT's `PeftModel.from_pretrained()` expects:
+ 1. **A standard `adapter_config.json`** (not the format produced by serializing LoraConfig directly)
+ 2. **Different key naming** (PEFT internally handles the `base_model.model.` prefix automatically)
+
+ As a result, loading during evaluation emits many **"Found missing adapter keys"** warnings, and the LoRA weights are never actually loaded.
+
+ ### Change
+
+ **Delete L295-306:**
+ ```python
+ lora_state_dict = {
+     k: v for k, v in module.draft_model.model.state_dict().items()
+     if "lora_" in k
+ }
+
+ try:
+     from safetensors.torch import save_file as safetensors_save
+     safetensors_save(lora_state_dict, os.path.join(save_dir, "adapter_model.safetensors"))
+ except (ImportError, Exception):
+     torch.save(lora_state_dict, os.path.join(save_dir, "adapter_model.bin"))
+
+ draft_model.model.peft_config["default"].save_pretrained(save_dir)
+ ```
+
+ **Replace with:**
+ ```python
+ # Use PEFT native save which handles key naming and adapter_config.json correctly
+ module.draft_model.model.save_pretrained(save_dir)
+ ```
+
+ `PeftModel.save_pretrained()` correctly:
+ - handles the key prefix mapping automatically
+ - generates a standard `adapter_config.json`
+ - saves `adapter_model.safetensors`
+
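The key mismatch can be illustrated offline. The sketch below is a minimal illustration; the exact key strings are assumptions based on typical PEFT naming conventions, not dumped from this checkpoint. It shows why keys taken straight from the module's state_dict never match what `PeftModel.from_pretrained()` looks up:

```python
def find_missing_adapter_keys(saved_keys, expected_keys):
    """Return the expected adapter keys that have no counterpart in the saved file."""
    saved = set(saved_keys)
    return sorted(k for k in expected_keys if k not in saved)

# Keys as the manual save wrote them: raw state_dict names, still carrying the
# ".default" adapter suffix and lacking the "base_model.model." prefix.
manually_saved = [
    "model.layers.0.self_attn.q_proj.lora_A.default.weight",
    "model.layers.0.self_attn.q_proj.lora_B.default.weight",
]

# Keys PEFT expects to find inside adapter_model.safetensors (illustrative).
peft_expected = [
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight",
]

missing = find_missing_adapter_keys(manually_saved, peft_expected)
print(f"Found missing adapter keys: {len(missing)}")  # every expected key is missing
```

Since no manually saved key matches any expected key, PEFT falls back to the base weights silently, which is consistent with the near-1.0× speedups observed.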
+
+ ---
+
+ ## Fix 2: Draft Model Forward Optimization (performance + code clarity)
+
+ ### File
+ `Specforge/specforge/modeling/draft/dflash_lora_inject.py`
+
+ ### Problem
+ `_forward_with_injection()` currently recomputes the following **inside the per-layer for loop**:
+ - `rotary_emb(layer_input, extended_pos)`: once per layer, repeated across all 36 layers
+ - the extended attention mask: an O(seq_len²) mask rebuilt at every layer
+ - `extended_pos = torch.cat([target_pos, position_ids])`: re-concatenated at every layer
+
+ ### Change
+ Hoist the computation of position_embeddings, extended_mask, and extended_pos **out of the for loop** (precompute once):
+
+ ```python
+ def _forward_with_injection(self, input_ids, attention_mask, target_hidden_states,
+                             position_ids=None, output_hidden_states=False, context_len=0):
+     # ... (get base_model, embed_tokens, layers, norm, lm_head)
+
+     hidden_states = embed_tokens(input_ids)
+     bsz, seq_len, hidden_dim = hidden_states.shape
+     ctx_len = target_hidden_states[0].shape[1] if target_hidden_states else 0
+     full_seq_len = ctx_len + seq_len
+
+     # ── Pre-compute position embeddings ONCE ──
+     target_pos = torch.arange(ctx_len, device=hidden_states.device)
+     draft_pos_ids = position_ids if position_ids is not None else torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
+     extended_pos = torch.cat([
+         target_pos.unsqueeze(0).expand(bsz, -1),
+         draft_pos_ids
+     ], dim=1)
+
+     position_embeddings = None
+     if hasattr(base_model.model, 'rotary_emb'):
+         dummy = torch.empty(1, full_seq_len, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
+         position_embeddings = base_model.model.rotary_emb(dummy, extended_pos)
+
+     # ── Pre-compute extended attention mask ONCE ──
+     extended_mask = attention_mask  # fallback
+     if attention_mask is not None and attention_mask.dim() == 4:
+         # ... (build ctx_mask_full + draft_mask_full, same logic as before)
+         # Key: use block_start (NOT block_start + 1) to prevent leakage
+         extended_mask = torch.cat([ctx_mask_full, draft_mask_full], dim=2)
+
+     # ── Layer-by-layer forward ──
+     for layer_idx, layer in enumerate(layers):
+         if target_hidden_states and layer_idx < len(target_hidden_states):
+             target_ctx = target_hidden_states[layer_idx]
+             layer_input = torch.cat([target_ctx, hidden_states], dim=1)
+
+             layer_output = layer(
+                 layer_input,
+                 attention_mask=extended_mask,
+                 position_ids=extended_pos,
+                 position_embeddings=position_embeddings,
+             )
+
+             hidden_states = layer_output[0][:, ctx_len:, :] if isinstance(layer_output, tuple) else layer_output[:, ctx_len:, :]
+         else:
+             layer_output = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
+             hidden_states = layer_output[0] if isinstance(layer_output, tuple) else layer_output
+
+     hidden_states = norm(hidden_states)
+     if output_hidden_states:
+         return hidden_states
+     return lm_head(hidden_states)
+ ```
+
+ ---
+
+ ## Files That Need No Changes
+
+ | File | Reason |
+ |------|--------|
+ | `eval_dflash_lora_inject.py` | Inference logic is correct; positions are already aligned |
+ | `specforge/core/dflash_lora_inject.py` | The training wrapper's mask (`block_start` without the +1) is already correct |
+
+ ---
+
+ ## Verification Plan
+
+ 1. **LoRA roundtrip**: after saving, `PeftModel.from_pretrained()` loads with no warnings
+ 2. **Forward consistency**: precomputed and per-layer recomputed outputs are identical
+ 3. **End-to-end evaluation**: retrain and run `eval_dflash_lora_inject.py` to verify the acceptance length improves
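The forward-consistency check (item 2) can be prototyped without loading the model: RoPE tables depend only on positions and head dimension, so computing them once for the extended positions must match what a per-layer recomputation would produce. A minimal NumPy sketch of that check, using the standard RoPE formula; `head_dim`, the position counts, and the 36-layer loop are illustrative values, not read from the actual config:

```python
import numpy as np

def rope_tables(positions, head_dim, base=10000.0):
    """Standard RoPE cos/sin tables for a 1-D array of positions."""
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(positions, inv_freq)  # (seq_len, head_dim // 2)
    return np.cos(angles), np.sin(angles)

# Extended positions = context positions followed by draft-block positions.
extended_pos = np.concatenate([np.arange(128), np.arange(8)])

# Precompute ONCE (the optimized path).
cos_once, sin_once = rope_tables(extended_pos, head_dim=64)

# Recompute per layer (the old path) and verify every layer matches.
for _ in range(36):
    cos_layer, sin_layer = rope_tables(extended_pos, head_dim=64)
    assert np.allclose(cos_once, cos_layer) and np.allclose(sin_once, sin_layer)

print("precomputed tables match per-layer recomputation")
```

The real check should additionally compare the full layer outputs with and without the hoisted mask and `extended_pos`, since those are the other two quantities moved out of the loop.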
syxin_old/backup.log ADDED
The diff for this file is too large to render. See raw diff
 
syxin_old/dflash_8gpu_03-31-13:40.log ADDED
@@ -0,0 +1,552 @@
1
+ nohup: ignoring input
2
+
3
+ *****************************************
4
+ Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
5
+ *****************************************
6
+ Set TORCH_CUDA_ARCH_LIST to 9.0Set TORCH_CUDA_ARCH_LIST to 9.0Set TORCH_CUDA_ARCH_LIST to 9.0Set TORCH_CUDA_ARCH_LIST to 9.0Set TORCH_CUDA_ARCH_LIST to 9.0
7
+
8
+
9
+
10
+
11
+ Set TORCH_CUDA_ARCH_LIST to 9.0
12
+ Set TORCH_CUDA_ARCH_LIST to 9.0
13
+ Set TORCH_CUDA_ARCH_LIST to 9.0
14
+ /workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
15
+ warnings.warn(
16
+ /workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
17
+ warnings.warn(
18
+ /workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
19
+ warnings.warn(
20
+ /workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
21
+ warnings.warn(
22
+ /workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
23
+ warnings.warn(
24
+ /workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
25
+ warnings.warn(
26
+ /workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
27
+ warnings.warn(
28
+ /workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
29
+ warnings.warn(
30
+ INFO:specforge.utils:rank 0: bind to device 0
31
+ INFO:specforge.utils:rank 7: bind to device 7
32
+ INFO:specforge.utils:rank 0: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
33
+ INFO:specforge.utils:rank 2: bind to device 2
34
+ INFO:specforge.utils:rank 7: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
35
+ INFO:specforge.utils:rank 0: Initialized distributed
36
+ INFO:specforge.utils:Loading target model from /workspace/models/Qwen3-8B using hf backend
37
+ INFO:specforge.utils:rank 7: Initialized distributed
38
+ INFO:specforge.utils:rank 1: bind to device 1
39
+ INFO:specforge.utils:rank 2: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
40
+ INFO:specforge.utils:rank 2: Initialized distributed
41
+ `torch_dtype` is deprecated! Use `dtype` instead!
42
+ `torch_dtype` is deprecated! Use `dtype` instead!
43
+ `torch_dtype` is deprecated! Use `dtype` instead!
44
+ INFO:specforge.utils:rank 6: bind to device 6
45
+ INFO:specforge.utils:rank 5: bind to device 5
46
+ INFO:specforge.utils:rank 1: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
47
+ INFO:specforge.utils:rank 1: Initialized distributed
48
+ `torch_dtype` is deprecated! Use `dtype` instead!
49
+ INFO:specforge.utils:rank 6: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
50
+ INFO:specforge.utils:rank 4: bind to device 4
51
+ INFO:specforge.utils:rank 5: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
52
+ INFO:specforge.utils:rank 6: Initialized distributed
53
+ INFO:specforge.utils:rank 5: Initialized distributed
54
+ INFO:specforge.utils:rank 4: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
55
+ `torch_dtype` is deprecated! Use `dtype` instead!
56
+ `torch_dtype` is deprecated! Use `dtype` instead!
57
+ INFO:specforge.utils:rank 4: Initialized distributed
58
+ `torch_dtype` is deprecated! Use `dtype` instead!
59
+ INFO:specforge.utils:rank 3: bind to device 3
60
+ [rank2]: Traceback (most recent call last):
61
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 563, in from_name
62
+ [rank2]: return next(cls.discover(name=name))
63
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
64
+ [rank2]: StopIteration
65
+
66
+ [rank2]: During handling of the above exception, another exception occurred:
67
+
68
+ [rank2]: Traceback (most recent call last):
69
+ [rank2]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 723, in <module>
70
+ [rank2]: main()
71
+ [rank2]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 475, in main
72
+ [rank2]: target_model, draft_model = build_models(args)
73
+ [rank2]: ^^^^^^^^^^^^^^^^^^
74
+ [rank2]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 265, in build_models
75
+ [rank2]: target_model = get_dflash_target_model(
76
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^
77
+ [rank2]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 341, in get_dflash_target_model
78
+ [rank2]: return HFDFlashTargetModel.from_pretrained(
79
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
80
+ [rank2]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 278, in from_pretrained
81
+ [rank2]: target_model = AutoModelForCausalLM.from_pretrained(
82
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
83
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
84
+ [rank2]: return model_class.from_pretrained(
85
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
86
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
87
+ [rank2]: return func(*args, **kwargs)
88
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^
89
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
90
+ [rank2]: model = cls(config, *model_args, **model_kwargs)
91
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
92
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 435, in __init__
93
+ [rank2]: super().__init__(config)
94
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2076, in __init__
95
+ [rank2]: self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
96
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
97
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
98
+ [rank2]: applicable_attn_implementation = self.get_correct_attn_implementation(
99
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
100
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
101
+ [rank2]: self._flash_attn_2_can_dispatch(is_init_check)
102
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2425, in _flash_attn_2_can_dispatch
103
+ [rank2]: flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
104
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
105
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
106
+ [rank2]: return distribution(distribution_name).version
107
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
108
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
109
+ [rank2]: return Distribution.from_name(distribution_name)
110
+ [rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
111
+ [rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
112
+ [rank2]: raise PackageNotFoundError(name)
113
+ [rank2]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
+ [rank0]: Traceback (most recent call last):
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 563, in from_name
+ [rank0]: return next(cls.discover(name=name))
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: StopIteration
+
+ [rank0]: During handling of the above exception, another exception occurred:
+
+ [rank0]: Traceback (most recent call last):
+ [rank0]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 723, in <module>
+ [rank0]: main()
+ [rank0]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 475, in main
+ [rank0]: target_model, draft_model = build_models(args)
+ [rank0]: ^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 265, in build_models
+ [rank0]: target_model = get_dflash_target_model(
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 341, in get_dflash_target_model
+ [rank0]: return HFDFlashTargetModel.from_pretrained(
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 278, in from_pretrained
+ [rank0]: target_model = AutoModelForCausalLM.from_pretrained(
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
+ [rank0]: return model_class.from_pretrained(
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
+ [rank0]: return func(*args, **kwargs)
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
+ [rank0]: model = cls(config, *model_args, **model_kwargs)
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 435, in __init__
+ [rank0]: super().__init__(config)
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2076, in __init__
+ [rank0]: self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
+ [rank0]: applicable_attn_implementation = self.get_correct_attn_implementation(
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
+ [rank0]: self._flash_attn_2_can_dispatch(is_init_check)
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2425, in _flash_attn_2_can_dispatch
+ [rank0]: flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
+ [rank0]: return distribution(distribution_name).version
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
+ [rank0]: return Distribution.from_name(distribution_name)
+ [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ [rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
+ [rank0]: raise PackageNotFoundError(name)
+ [rank0]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
+ [rank7]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
+ [rank1]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
+ INFO:specforge.utils:rank 3: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
+ INFO:specforge.utils:rank 3: Initialized distributed
+ `torch_dtype` is deprecated! Use `dtype` instead!
+ [rank5]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
+ [rank4]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
+ [rank6]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
+ [rank3]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
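Editor's note: every rank dies in transformers' flash-attention dispatch, which probes the installed `flash_attn` package via `importlib.metadata.version` before constructing the Qwen3 model. A minimal stdlib sketch of that probe (the helper name `flash_attn_available` is ours, not from `train_dflash.py`):

```python
import importlib.metadata


def flash_attn_available() -> bool:
    """Replicates the metadata probe that raised above."""
    try:
        # Raises PackageNotFoundError when flash_attn is not installed.
        importlib.metadata.version("flash_attn")
        return True
    except importlib.metadata.PackageNotFoundError:
        return False


if not flash_attn_available():
    # Either install flash-attn in this env, or avoid the probe entirely
    # by requesting a different backend, e.g.
    # AutoModelForCausalLM.from_pretrained(..., attn_implementation="sdpa")
    print("flash_attn not installed; install it or use attn_implementation='sdpa'")
```

Passing `attn_implementation="sdpa"` is a standard transformers fallback that sidesteps the flash-attn version check, assuming the training script exposes a way to forward that kwarg.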
+ [rank0]:[W331 13:41:10.473818504 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
+ [rank7]:[W331 13:41:10.548010235 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
+ [rank7]:[W331 13:41:10.659783753 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
+ [rank4]:[W331 13:41:11.950068591 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
+ [rank2]:[W331 13:41:11.951701730 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
+ [rank6]:[W331 13:41:11.974675832 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
+ [rank5]:[W331 13:41:11.997313679 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
+ [rank3]:[W331 13:41:11.024650758 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
+ [rank1]:[W331 13:41:11.024685351 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
+ [rank2]:[W331 13:41:11.101274402 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
+ [rank4]:[W331 13:41:11.102711684 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
+ [rank6]:[W331 13:41:11.121351120 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
+ [rank0]:[W331 13:41:11.122852367 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
+ [rank1]:[W331 13:41:11.167109415 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
+ [rank3]:[W331 13:41:11.170910568 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
+ [rank5]:[W331 13:41:11.173578451 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
+ W0331 13:41:11.393000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 641 closing signal SIGTERM
+ W0331 13:41:11.393000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 642 closing signal SIGTERM
+ W0331 13:41:11.394000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 643 closing signal SIGTERM
+ W0331 13:41:11.394000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 644 closing signal SIGTERM
+ W0331 13:41:11.394000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 645 closing signal SIGTERM
+ W0331 13:41:11.395000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 646 closing signal SIGTERM
+ W0331 13:41:11.395000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 647 closing signal SIGTERM
+ E0331 13:41:12.401000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 7 (pid: 648) of binary: /workspace/miniconda3/envs/spec/bin/python3
+ Traceback (most recent call last):
+ File "<frozen runpy>", line 198, in _run_module_as_main
+ File "<frozen runpy>", line 88, in _run_code
+ File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/run.py", line 940, in <module>
+ main()
+ File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
+ return f(*args, **kwargs)
526
+ ^^^^^^^^^^^^^^^^^^
527
+ File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/run.py", line 936, in main
528
+ run(args)
529
+ File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/run.py", line 927, in run
530
+ elastic_launch(
531
+ File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 156, in __call__
532
+ return launch_agent(self._config, self._entrypoint, list(args))
533
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
534
+ File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
535
+ raise ChildFailedError(
536
+ torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
537
+ ============================================================
538
+ /workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py FAILED
539
+ ------------------------------------------------------------
540
+ Failures:
541
+ <NO_OTHER_FAILURES>
542
+ ------------------------------------------------------------
543
+ Root Cause (first observed failure):
544
+ [0]:
545
+ time : 2026-03-31_13:41:11
546
+ host : job-006ce80a7c47-20260302193512-5cd88f7cfc-mlbh9
547
+ rank : 7 (local_rank: 7)
548
+ exitcode : 1 (pid: 648)
549
+ error_file: <N/A>
550
+ traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
551
+ ============================================================
552
+ [W331 13:41:12.379198750 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
syxin_old/diagnostic_compare.py ADDED
@@ -0,0 +1,301 @@
+ #!/usr/bin/env python3
+ """Diagnostic: compare training forward vs eval forward on the same block.
+ 
+ Goal: find where the train-eval mismatch is.
+ Runs on a single GPU (no distributed).
+ 
+ Usage:
+     /workspace/miniconda3/envs/dflash/bin/python3 /workspace/hanrui/syxin_old/diagnostic_compare.py
+ """
+ import sys, os, warnings
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ 
+ sys.path.insert(0, "/workspace/hanrui/syxin_old")
+ sys.path.insert(0, "/workspace/hanrui/syxin_old/Specforge")
+ 
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+ 
+ BASE_MODEL = "/workspace/models/Qwen3-8B"
+ ADAPTER_PATH = "/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject/epoch_3_step_4644"
+ BLOCK_SIZE = 16
+ MASK_TOKEN_ID = 151666
+ 
+ device = torch.device("cuda:0")
+ 
+ 
+ def main():
+     tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
+ 
+     print("Loading target model...")
+     target_model = AutoModelForCausalLM.from_pretrained(
+         BASE_MODEL, torch_dtype=torch.bfloat16,
+         attn_implementation="sdpa", device_map=device, trust_remote_code=True,
+     )
+     target_model.eval()
+ 
+     print("Loading draft model (base + LoRA adapter)...")
+     draft_model = AutoModelForCausalLM.from_pretrained(
+         BASE_MODEL, torch_dtype=torch.bfloat16,
+         attn_implementation="sdpa", device_map=device, trust_remote_code=True,
+     )
+     draft_model = PeftModel.from_pretrained(draft_model, ADAPTER_PATH)
+     draft_model = draft_model.merge_and_unload()
+     draft_model.eval()
+ 
+     num_layers = len(draft_model.model.layers)
+     draft_layers = draft_model.model.layers
+     draft_norm = draft_model.model.norm
+     draft_lm_head = draft_model.lm_head
+     rotary_emb = draft_model.model.rotary_emb
+ 
+     # Create a test sequence
+     text = "The quick brown fox jumps over the lazy dog. " * 10
+     messages = [{"role": "user", "content": text}]
+     input_text = tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True, enable_thinking=False,
+     )
+     full_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
+ 
+     # Extend to get a sequence long enough for multiple blocks
+     # Use the target model to generate some tokens
+     print(f"Input length: {full_ids.shape[1]}")
+ 
+     # We'll create a fixed sequence by repeating the prompt
+     # Actually, let's use the full_ids as is (it should be ~200 tokens)
+     seq_len = full_ids.shape[1]
+ 
+     # Make it align to block boundaries
+     n_blocks = seq_len // BLOCK_SIZE
+     effective_len = n_blocks * BLOCK_SIZE
+     input_ids = full_ids[:, :effective_len]
+     seq_len = effective_len
+ 
+     print(f"Using {n_blocks} blocks, seq_len = {seq_len}")
+ 
+     # ═══════════════════════════════════════════════════════
+     # TRAINING-STYLE FORWARD
+     # ═══════════════════════════════════════════════════════
+     print("\n" + "="*60)
+     print("TRAINING-STYLE FORWARD")
+     print("="*60)
+ 
+     with torch.no_grad():
+         # Step 1: Get target hidden states (full sequence)
+         target_output = target_model(
+             input_ids,
+             output_hidden_states=True,
+         )
+         # target hidden states: [hidden_states[0], ..., hidden_states[L-1]]
+         # hidden_states[k] = input to layer k
+         target_hidden_states = [target_output.hidden_states[i] for i in range(num_layers)]
+ 
+         # Step 2: Prepare noise input (mask non-anchors)
+         noise_input = input_ids.clone()
+         positions = torch.arange(seq_len, device=device)
+         is_anchor = (positions % BLOCK_SIZE) == 0
+         noise_input[:, ~is_anchor] = MASK_TOKEN_ID
+ 
+         # Step 3: Build DFlash mask (draft-to-draft)
+         NEG_INF = torch.finfo(torch.bfloat16).min
+         block_ids_mask = positions // BLOCK_SIZE
+         q_ids = block_ids_mask.unsqueeze(1)
+         k_ids = block_ids_mask.unsqueeze(0)
+         same_block = (q_ids == k_ids)
+         dflash_mask = torch.full((seq_len, seq_len), NEG_INF, device=device, dtype=torch.bfloat16)
+         dflash_mask[same_block] = 0.0
+         dflash_mask = dflash_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]
+ 
+         # Step 4: Forward through draft model with injection
+         # (Mimicking _forward_with_injection with context_len=0)
+         ctx_len = seq_len  # target hidden states span full sequence
+         full_seq_len = ctx_len + seq_len
+ 
+         # Position IDs: [0..N-1, 0..N-1]
+         orig_pos = torch.arange(seq_len, device=device)
+         extended_pos = torch.cat([orig_pos, orig_pos], dim=0).unsqueeze(0)
+ 
+         # Extended mask
+         # Context-to-context: causal
+         ctx_ctx_mask = torch.full((ctx_len, ctx_len), NEG_INF, device=device, dtype=torch.bfloat16)
+         ctx_ctx_mask = torch.triu(ctx_ctx_mask, diagonal=1)
+         ctx_draft_mask = torch.full((ctx_len, seq_len), NEG_INF, device=device, dtype=torch.bfloat16)
+         ctx_mask_full = torch.cat([ctx_ctx_mask, ctx_draft_mask], dim=-1)
+         ctx_mask_full = ctx_mask_full.unsqueeze(0).unsqueeze(0)
+ 
+         draft_mask_full = torch.full((1, 1, seq_len, full_seq_len), NEG_INF, device=device, dtype=torch.bfloat16)
+ 
+         # Draft-to-target visibility
+         draft_pos = torch.arange(seq_len, device=device)
+         target_pos = torch.arange(ctx_len, device=device)
+         context_len = 0  # training uses context_len=0
+ 
+         is_ctx = draft_pos < context_len  # all False
+         block_id = (draft_pos - context_len).clamp(min=0) // BLOCK_SIZE
+         block_start = context_len + block_id * BLOCK_SIZE
+         max_visible = torch.where(is_ctx, draft_pos + 1, block_start + 1)
+         visible = target_pos.unsqueeze(0) < max_visible.unsqueeze(1)
+         draft_mask_full[:, :, :, :ctx_len].masked_fill_(
+             visible.unsqueeze(0).unsqueeze(0), 0
+         )
+ 
+         # Draft-to-draft
+         draft_mask_full[:, :, :, ctx_len:] = dflash_mask
+ 
+         extended_mask = torch.cat([ctx_mask_full, draft_mask_full], dim=2)
+ 
+         # Position embeddings
+         dummy = torch.empty(1, full_seq_len, target_hidden_states[0].shape[-1],
+                             device=device, dtype=torch.bfloat16)
+         position_embeddings = rotary_emb(dummy, extended_pos)
+ 
+         # Layer-by-layer forward
+         hidden_states = draft_model.model.embed_tokens(noise_input)
+ 
+         for layer_idx in range(num_layers):
+             target_ctx = target_hidden_states[layer_idx]
+             layer_input = torch.cat([target_ctx, hidden_states], dim=1)
+ 
+             layer_output = draft_layers[layer_idx](
+                 layer_input,
+                 attention_mask=extended_mask,
+                 position_ids=extended_pos,
+                 position_embeddings=position_embeddings,
+             )
+             if isinstance(layer_output, tuple):
+                 layer_output = layer_output[0]
+             hidden_states = layer_output[:, ctx_len:, :]
+ 
+         hidden_states = draft_norm(hidden_states)
+         train_logits = draft_lm_head(hidden_states)  # [1, seq_len, vocab_size]
+         train_preds = train_logits.argmax(dim=-1)  # [1, seq_len]
+ 
+         # Compute training accuracy per block
+         print("\nTraining-style per-block accuracy (consecutive correct from pos 1):")
+         for b in range(n_blocks):
+             start = b * BLOCK_SIZE
+             block_preds = train_preds[0, start:start + BLOCK_SIZE]
+             block_labels = input_ids[0, start:start + BLOCK_SIZE]
+             correct = (block_preds[1:] == block_labels[1:])  # skip anchor
+             cumprod = correct.cumprod(dim=0)
+             accept_len = cumprod.sum().item()
+             print(f" Block {b} (pos {start}-{start+15}): accept_len={accept_len}, "
+                   f"token_acc={correct.float().mean():.3f}")
+ 
+     # ═══════════════════════════════════════════════════════
+     # EVAL-STYLE FORWARD (block by block)
+     # ═══════════════════════════════════════════════════════
+     print("\n" + "="*60)
+     print("EVAL-STYLE FORWARD (block by block)")
+     print("="*60)
+ 
+     with torch.no_grad():
+         # Get target hidden states for the full sequence (to use as context)
+         # In real eval, these would come from incremental target forwards
+         # Here we use the same full-sequence target hidden states
+ 
+         for b in range(n_blocks):
+             start = b * BLOCK_SIZE
+             end = start + BLOCK_SIZE
+ 
+             # Block input: anchor + MASK
+             block_ids = input_ids[:, start:end].clone()
+             block_ids[:, 1:] = MASK_TOKEN_ID  # mask non-anchors
+ 
+             # Context: target hidden states for positions 0..start (inclusive)
+             # This matches the training visibility: block k sees target 0..k*16
+             ctx_end = start + 1  # include anchor position
+ 
+             if ctx_end == 0:
+                 # Block 0 with no context — skip (can't have empty context)
+                 # Actually block 0 has anchor at position 0, so ctx_end = 1
+                 ctx_end = 1
+ 
+             ctx_len_eval = ctx_end
+             actual_bs = BLOCK_SIZE
+ 
+             # Build eval mask
+             full_len_eval = ctx_len_eval + actual_bs
+             eval_mask = torch.full((1, 1, full_len_eval, full_len_eval), NEG_INF,
+                                    device=device, dtype=torch.bfloat16)
+ 
+             # Context-to-context: causal
+             if ctx_len_eval > 0:
+                 ctx_rows = torch.arange(ctx_len_eval, device=device)
+                 ctx_cols = torch.arange(ctx_len_eval, device=device)
+                 causal = ctx_cols.unsqueeze(0) <= ctx_rows.unsqueeze(1)
+                 eval_mask[0, 0, :ctx_len_eval, :ctx_len_eval].masked_fill_(causal, 0)
+ 
+             # Block-to-context: all visible
+             eval_mask[0, 0, ctx_len_eval:, :ctx_len_eval] = 0
+             # Block-to-block: bidirectional
+             eval_mask[0, 0, ctx_len_eval:, ctx_len_eval:] = 0
+ 
+             # Position IDs
+             ctx_positions = torch.arange(ctx_len_eval, device=device)
+             block_positions = torch.arange(start, start + actual_bs, device=device)
+             combined_pos = torch.cat([ctx_positions, block_positions], dim=0).unsqueeze(0)
+ 
+             # Position embeddings
+             hidden_dim = target_hidden_states[0].shape[-1]
+             dummy_eval = torch.empty(1, full_len_eval, hidden_dim, device=device, dtype=torch.bfloat16)
+             pos_emb_eval = rotary_emb(dummy_eval, combined_pos)
+ 
+             # Draft forward
+             draft_hidden = draft_model.model.embed_tokens(block_ids)
+ 
+             for layer_idx in range(num_layers):
+                 target_ctx = target_hidden_states[layer_idx][:, :ctx_end, :]
+                 combined = torch.cat([target_ctx, draft_hidden], dim=1)
+ 
+                 layer_output = draft_layers[layer_idx](
+                     combined,
+                     attention_mask=eval_mask,
+                     position_ids=combined_pos,
+                     position_embeddings=pos_emb_eval,
+                 )
+                 if isinstance(layer_output, tuple):
+                     layer_output = layer_output[0]
+                 draft_hidden = layer_output[:, ctx_len_eval:, :]
+ 
+             draft_hidden = draft_norm(draft_hidden)
+             eval_logits = draft_lm_head(draft_hidden)  # [1, 16, vocab_size]
+             eval_preds = eval_logits.argmax(dim=-1)  # [1, 16]
+ 
+             # Compare with training
+             train_block_preds = train_preds[0, start:end]
+             eval_block_preds = eval_preds[0]
+             block_labels = input_ids[0, start:end]
+ 
+             train_correct = (train_block_preds[1:] == block_labels[1:])
+             eval_correct = (eval_block_preds[1:] == block_labels[1:])
+             preds_match = (train_block_preds == eval_block_preds)
+ 
+             train_accept = train_correct.cumprod(dim=0).sum().item()
+             eval_accept = eval_correct.cumprod(dim=0).sum().item()
+ 
+             # Check if logits are close
+             train_block_logits = train_logits[0, start:end, :]
+             eval_block_logits = eval_logits[0, :, :]
+             logit_diff = (train_block_logits - eval_block_logits).abs().max().item()
+             logit_rmse = ((train_block_logits - eval_block_logits)**2).mean().sqrt().item()
+ 
+             print(f" Block {b} (pos {start}-{start+15}):")
+             print(f"   Train accept_len={train_accept}, Eval accept_len={eval_accept}")
+             print(f"   Predictions match: {preds_match.sum().item()}/{BLOCK_SIZE}")
+             print(f"   Logit max_diff={logit_diff:.4f}, rmse={logit_rmse:.6f}")
+ 
+             if not preds_match.all():
+                 mismatch_pos = (~preds_match).nonzero(as_tuple=True)[0]
+                 for pos in mismatch_pos[:5]:  # show first 5 mismatches
+                     p = pos.item()
+                     print(f"     Mismatch at block pos {p}: "
+                           f"train={train_block_preds[p].item()}, "
+                           f"eval={eval_block_preds[p].item()}, "
+                           f"label={block_labels[p].item()}")
+ 
+ 
+ if __name__ == "__main__":
+     main()
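The block-diagonal draft-to-draft visibility built in Step 3 of the script above can be sketched standalone. A minimal NumPy version (BLOCK_SIZE shrunk to 4 for readability; `float("-inf")` stands in for the script's `torch.finfo(torch.bfloat16).min` fill value):

```python
import numpy as np

NEG_INF = float("-inf")

def dflash_draft_mask(seq_len: int, block_size: int) -> np.ndarray:
    """Additive attention mask: each query position attends only to key
    positions inside its own block (0.0 = visible, -inf = masked)."""
    block_ids = np.arange(seq_len) // block_size
    same_block = block_ids[:, None] == block_ids[None, :]
    mask = np.full((seq_len, seq_len), NEG_INF)
    mask[same_block] = 0.0
    return mask

mask = dflash_draft_mask(seq_len=8, block_size=4)
# Positions 0-3 (block 0) attend bidirectionally to each other but are
# fully masked from positions 4-7 (block 1), and vice versa.
assert (mask[:4, :4] == 0.0).all()
assert (mask[:4, 4:] == NEG_INF).all()
```

Within each block attention is bidirectional (no causal triangle), which is why the training and eval forwards above must agree on block boundaries to produce identical logits.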
syxin_old/eval_alignment_diff.md ADDED
@@ -0,0 +1,132 @@
+ # DFlash Eval Alignment Analysis: Your Scripts vs. the Official benchmark.py
+ 
+ ## Summary of Changes
+ 
+ | # | Issue | Impact | baseline | lora_inject |
+ |---|------|------|----------|-------------|
+ | 1 | Acceptance-length computation | 🔴 Different values | ✅ Fixed | ✅ Fixed |
+ | 2 | Multi-turn conversation support | 🔴 Different mt-bench results | ✅ Fixed | ✅ Fixed |
+ | 3 | Sample selection: sequential vs. shuffled | 🔴 Different subset | ✅ Fixed | ✅ Fixed |
+ | 4 | Only 3 datasets vs. the official 10 | 🔴 Incomplete coverage | ✅ Fixed | ✅ Fixed |
+ | 5 | stop_token_ids check range | 🟡 May stop too early/too late | ✅ Fixed | ✅ Fixed |
+ | 6 | Distributed aggregation | 🟡 Loses per-sample granularity | ✅ Fixed | ✅ Fixed |
+ | 7 | AR baseline includes draft forward | 🔴 Inflated speedup (inject only) | N/A | ✅ Fixed |
+ | 8 | max_new_tokens default | 🟡 | ✅ Fixed | ✅ Fixed |
+ 
+ ---
+ 
+ ## Change Details
+ 
+ ### 1. Acceptance-length computation
+ 
+ **Problem:** The official script takes the per-sample mean first and then averages across samples; yours pools all verify rounds globally.
+ 
+ ```python
+ # ❌ Yours (both scripts)
+ avg_accept_length = total_accept_sum / total_count
+ 
+ # ✅ Official
+ tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses])
+ ```
+ 
+ **Fix:** Collect each sample's accept_lengths list, compute the per-sample mean first, then average those means.
+ 
+ ---
+ 
+ ### 2. Multi-turn conversation support
+ 
+ **Problem:** For multi-turn datasets such as mt-bench, the official script generates turn by turn and appends each assistant reply to the context.
+ 
+ ```python
+ # ❌ Yours: only turns[0], single turn
+ messages = [{"role": "user", "content": prompt}]
+ 
+ # ✅ Official: generate turn by turn
+ for turn_index, user_content in enumerate(instance["turns"]):
+     messages.append({"role": "user", "content": user_content})
+     # generate ...
+     messages.append({"role": "assistant", "content": output_text})
+     responses.append(response)
+ ```
+ 
+ **Fix:** Data loading now returns the `{"turns": [...]}` format, and the generation loop iterates over turns.
+ 
+ ---
+ 
+ ### 3. Sample selection
+ 
+ **Problem:** The official script shuffles before selecting; yours takes the first N in order.
+ 
+ ```python
+ # ❌ Yours
+ prompts = prompts[:num_samples]
+ 
+ # ✅ Official
+ dataset = dataset.shuffle(seed=0).select(range(max_samples))
+ ```
+ 
+ **Fix:** Switch to the HF dataset shuffle + select.
+ 
+ ---
+ 
+ ### 4. Dataset coverage
+ 
+ **Problem:** Seven datasets are missing: math500, aime24, aime25, mbpp, livecodebench, swe-bench, alpaca.
+ 
+ **Fix:** Reuse the official `load_and_process_dataset()` function directly.
+ 
+ ---
+ 
+ ### 5. stop_token_ids check range
+ 
+ ```python
+ # ❌ Your baseline (checks only up to start)
+ output_ids[:, num_input_tokens:start]
+ 
+ # ✅ Official (checks everything generated so far)
+ output_ids[:, num_input_tokens:]
+ ```
+ 
+ **Fix:** Change to `output_ids[:, num_input_tokens:]`.
+ 
+ ---
+ 
+ ### 6. Distributed aggregation
+ 
+ **Problem:** You aggregate scalars with all_reduce, losing per-sample granularity.
+ 
+ ```python
+ # ❌ Yours: all_reduce sum/count
+ dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
+ 
+ # ✅ Official: gather the full response list
+ responses = dist.gather(responses, dst=0)
+ ```
+ 
+ **Fix:** Gather per-sample acceptance_lengths + time metrics to rank 0 and compute the statistics there.
+ 
+ ---
+ 
+ ### 7. AR baseline includes the draft forward (lora_inject only)
+ 
+ **Problem:** With block_size=1 the full inject pipeline (including the draft model) still runs, inflating the AR time and therefore the speedup.
+ 
+ ```python
+ # ❌ Your lora_inject AR baseline
+ spec_generate_inject(..., block_size=1)  # still runs the draft model layers
+ 
+ # ✅ Should be: pure target autoregressive
+ ar_generate(target_model, input_ids, ...)  # target model only
+ ```
+ 
+ **Fix:** Add a pure AR generation function; skip the draft model entirely when block_size=1.
+ 
+ ---
+ 
+ ### 8. max_new_tokens default
+ 
+ ```python
+ # The official shell scripts use 2048 (the Python default is 16384).
+ # Your default is 2048, consistent with the shell scripts; keep it,
+ # but add an explicit note about the difference.
+ ```
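To make fix #1 concrete, a toy example (numbers invented) showing that the globally pooled mean and the official per-sample mean-of-means disagree whenever samples contribute different numbers of verify rounds:

```python
# Hypothetical acceptance lengths from two samples
sample_a = [4, 4, 4, 4]   # 4 verify rounds, consistently high acceptance
sample_b = [1, 1]         # 2 verify rounds, low acceptance

# Global pooling (the old behavior): every verify round weighted equally,
# so samples with more rounds dominate the average
pooled = sum(sample_a + sample_b) / len(sample_a + sample_b)   # 18 / 6 = 3.0

# Official: per-sample mean first, then mean across samples,
# so each sample contributes equally regardless of its round count
per_sample_means = [sum(s) / len(s) for s in (sample_a, sample_b)]
tau = sum(per_sample_means) / len(per_sample_means)            # (4.0 + 1.0) / 2 = 2.5
```

The gap grows with the variance in rounds per sample, which is why the two scripts reported different acceptance lengths on the same runs.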
syxin_old/eval_dflash_b16_baseline.py ADDED
@@ -0,0 +1,354 @@
+ #!/usr/bin/env python3
+ """
+ Offline evaluation for DFlash-b16 baseline: measure accepted length.
+ 8 GPUs parallel, each GPU loads target + draft independently.
+ 
+ Usage:
+     # 8 GPUs
+     torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py
+ 
+     # quick test
+     torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py --num-samples 20
+ 
+     # single GPU
+     python3 eval_dflash_b16_baseline.py --benchmarks humaneval
+ """
+ import argparse
+ import json
+ import os
+ import sys
+ import time
+ from typing import List, Optional, Tuple
+ 
+ import torch
+ import torch.nn as nn
+ import torch.distributed as dist
+ from tqdm import tqdm
+ from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache
+ 
+ # Add DFlash model path so we can import utils
+ sys.path.insert(0, "/workspace/models/Qwen3-8B-DFlash-b16")
+ from utils import extract_context_feature, sample
+ 
+ # ──────────────────────────────────────────────────────────────────
+ BASE_MODEL = "/workspace/models/Qwen3-8B"
+ DRAFT_MODEL = "/workspace/models/Qwen3-8B-DFlash-b16"
+ RESULT_DIR = "/workspace/hanrui/syxin_old/Specforge/benchmarks/results"
+ 
+ 
+ # ──────────────────────────────────────────────────────────────────
+ # Distributed helpers
+ # ──────────────────────────────────────────────────────────────────
+ def is_distributed():
+     return dist.is_available() and dist.is_initialized()
+ 
+ def get_rank():
+     return dist.get_rank() if is_distributed() else 0
+ 
+ def get_world_size():
+     return dist.get_world_size() if is_distributed() else 1
+ 
+ def is_main():
+     return get_rank() == 0
+ 
+ def print_rank0(*args, **kwargs):
+     if is_main():
+         print(*args, **kwargs)
+ 
+ def split_list(lst, rank, world_size):
+     return [x for i, x in enumerate(lst) if i % world_size == rank]
+ 
+ 
+ # ──────────────────────────────────────────────────────────────────
+ # Prompts
+ # ──────────────────────────────────────────────────────────────────
+ def load_prompts(bench_name: str, num_samples: Optional[int] = None) -> List[str]:
+     local_paths = {
+         "humaneval": "/workspace/hanrui/datasets/humaneval/test.jsonl",
+         "mtbench": "/workspace/hanrui/datasets/mtbench/question.jsonl",
+         "gsm8k": "/workspace/hanrui/datasets/gsm8k/test.jsonl",
+     }
+     prompts = []
+     path = local_paths.get(bench_name)
+     if path and os.path.exists(path):
+         with open(path) as f:
+             for line in f:
+                 item = json.loads(line)
+                 if bench_name == "humaneval":
+                     p = f"Write a solution to the following problem and make sure that it passes the tests:\n```python\n{item['prompt']}\n```"
+                 elif bench_name == "mtbench":
+                     p = item.get("turns", [item.get("prompt", "")])[0]
+                 elif bench_name == "gsm8k":
+                     p = item["question"] + "\nPlease reason step by step, and put your final answer within \\boxed{}."
+                 else:
+                     p = str(item)
+                 prompts.append(p)
+     else:
+         from datasets import load_dataset
+         if bench_name == "humaneval":
+             ds = load_dataset("openai/openai_humaneval", split="test")
+             prompts = [f"Write a solution to the following problem and make sure that it passes the tests:\n```python\n{x['prompt']}\n```" for x in ds]
+         elif bench_name == "mtbench":
+             ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
+             prompts = [x["prompt"][0] for x in ds]
+         elif bench_name == "gsm8k":
+             ds = load_dataset("openai/gsm8k", "main", split="test")
+             prompts = [x["question"] + "\nPlease reason step by step, and put your final answer within \\boxed{}." for x in ds]
+     if num_samples is not None:
+         prompts = prompts[:num_samples]
+     return prompts
+ 
+ 
+ # ──────────────────────────────────────────────────────────────────
+ # spec_generate with acceptance_lengths returned
+ # (Same logic as DFlashDraftModel.spec_generate but returns accept lens)
+ # ──────────────────────────────────────────────────────────────────
+ @torch.inference_mode()
+ def spec_generate_b16(
+     draft_model,
+     target_model: nn.Module,
+     input_ids: torch.LongTensor,
+     max_new_tokens: int = 512,
+     temperature: float = 0.0,
+     stop_token_ids: Optional[List[int]] = None,
+ ) -> Tuple[torch.Tensor, List[int]]:
+     """Same as DFlashDraftModel.spec_generate but also returns acceptance_lengths."""
+     draft_model.eval()
+     device = target_model.device if hasattr(target_model, 'device') else input_ids.device
+     num_input_tokens = input_ids.shape[1]
+     max_length = num_input_tokens + max_new_tokens
+     block_size = draft_model.block_size
+     mask_token_id = draft_model.mask_token_id
+ 
+     output_ids = torch.full(
+         (1, max_length + block_size), mask_token_id,
+         dtype=torch.long, device=device,
+     )
+     position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)
+ 
+     past_key_values_target = DynamicCache()
+     past_key_values_draft = DynamicCache()
+ 
+     # Prefill
+     output = target_model(
+         input_ids,
+         position_ids=position_ids[:, :num_input_tokens],
+         past_key_values=past_key_values_target,
+         use_cache=True,
+         logits_to_keep=1,
+         output_hidden_states=True,
+     )
+     output_ids[:, :num_input_tokens] = input_ids
+     output_ids[:, num_input_tokens:num_input_tokens + 1] = sample(output.logits, temperature)
+     target_hidden = extract_context_feature(output.hidden_states, draft_model.target_layer_ids)
+ 
+     # Decode
+     acceptance_lengths = []
+     start = num_input_tokens
+     while start < max_length:
+         block_output_ids = output_ids[:, start:start + block_size].clone()
+         block_position_ids = position_ids[:, start:start + block_size]
+         noise_embedding = target_model.model.embed_tokens(block_output_ids)
+ 
+         draft_logits = target_model.lm_head(
+             draft_model(
+                 target_hidden=target_hidden,
+                 noise_embedding=noise_embedding,
+                 position_ids=position_ids[:, past_key_values_draft.get_seq_length():start + block_size],
+                 past_key_values=past_key_values_draft,
+                 use_cache=True,
+                 is_causal=False,
+             )[:, -block_size + 1:, :]
+         )
+         past_key_values_draft.crop(start)
+         block_output_ids[:, 1:] = sample(draft_logits)
+ 
+         output = target_model(
+             block_output_ids,
+             position_ids=block_position_ids,
+             past_key_values=past_key_values_target,
+             use_cache=True,
+             output_hidden_states=True,
+         )
+ 
+         posterior = sample(output.logits, temperature)
+         acceptance_length = (
+             (block_output_ids[:, 1:] == posterior[:, :-1])
+             .cumprod(dim=1).sum(dim=1)[0].item()
+         )
+         output_ids[:, start:start + int(acceptance_length) + 1] = block_output_ids[:, :int(acceptance_length) + 1]
+         output_ids[:, start + int(acceptance_length) + 1] = posterior[:, int(acceptance_length)]
+         start += int(acceptance_length) + 1
+         past_key_values_target.crop(start)
+         target_hidden = extract_context_feature(
+             output.hidden_states, draft_model.target_layer_ids
+         )[:, :int(acceptance_length) + 1, :]
+         acceptance_lengths.append(int(acceptance_length) + 1)
+ 
+         if stop_token_ids is not None and any(
+             sid in output_ids[:, num_input_tokens:start] for sid in stop_token_ids
+         ):
+             break
+ 
+     output_ids = output_ids[:, :max_length]
+     output_ids = output_ids[:, output_ids[0] != mask_token_id]
+     if stop_token_ids is not None:
+         stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
+         stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
+         if stop_idx.numel() > 0:
+             output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]
+ 
+     return output_ids, acceptance_lengths
+ 
+ 
+ # ──────────────────────────────────────────────────────────────────
+ def parse_args():
+     p = argparse.ArgumentParser()
+     p.add_argument("--base-model", default=BASE_MODEL)
+     p.add_argument("--draft-model", default=DRAFT_MODEL)
+     p.add_argument("--max-new-tokens", type=int, default=512)
+     p.add_argument("--temperature", type=float, default=0.0)
+     p.add_argument("--benchmarks", nargs="+", default=["humaneval", "mtbench", "gsm8k"])
+     p.add_argument("--num-samples", type=int, default=None)
+     p.add_argument("--output-dir", default=RESULT_DIR)
+     return p.parse_args()
+ 
+ 
+ def main():
+     args = parse_args()
+ 
+     local_rank = int(os.environ.get("LOCAL_RANK", 0))
+     world_size = int(os.environ.get("WORLD_SIZE", 1))
+ 
+     if world_size > 1:
+         dist.init_process_group(backend="nccl")
+         torch.cuda.set_device(local_rank)
+ 
+     device = f"cuda:{local_rank}"
+     rank = get_rank()
+ 
+     print_rank0(f"Running DFlash-b16 baseline on {world_size} GPU(s)")
+ 
+     # ── Load models ──
+     print_rank0(f"Loading target: {args.base_model}")
+     target_model = AutoModelForCausalLM.from_pretrained(
+         args.base_model,
+         torch_dtype=torch.bfloat16,
+         device_map=device,
+         trust_remote_code=True,
+     )
+     target_model.eval()
+ 
+     print_rank0(f"Loading DFlash-b16 draft: {args.draft_model}")
+     draft_model = AutoModel.from_pretrained(
+         args.draft_model,
+         torch_dtype=torch.bfloat16,
+         trust_remote_code=True,
+     ).to(device)
+     draft_model.eval()
+ 
+     tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
+     stop_token_ids = [tokenizer.eos_token_id]
+ 
+     print_rank0(f"DFlash-b16: block_size={draft_model.block_size}, "
+                 f"target_layer_ids={draft_model.target_layer_ids}, "
+                 f"num_layers={len(draft_model.layers)}")
+ 
+     # ── Run benchmarks ──
+     results = {"model": "Qwen3-8B-DFlash-b16", "type": "baseline",
+                "block_size": draft_model.block_size}
+ 
+     for bench_name in args.benchmarks:
+         print_rank0(f"\n{'='*60}")
+         print_rank0(f"Benchmark: {bench_name} ({world_size} GPUs)")
+         print_rank0(f"{'='*60}")
+ 
+         all_prompts = load_prompts(bench_name, args.num_samples)
+         my_prompts = split_list(all_prompts, rank, world_size)
+         print_rank0(f"Total {len(all_prompts)} prompts, ~{len(my_prompts)} per GPU")
+ 
+         local_accept_lengths = []
+         local_tokens = 0
+         t0 = time.time()
+ 
+         iterator = tqdm(my_prompts, desc=f"[GPU{rank}] {bench_name}", unit="sample",
+                         disable=(rank != 0))
+         for prompt in iterator:
+             messages = [{"role": "user", "content": prompt}]
+             text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+             input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
+ 
+             output_ids, accept_lens = spec_generate_b16(
+                 draft_model=draft_model,
+                 target_model=target_model,
+                 input_ids=input_ids,
+                 max_new_tokens=args.max_new_tokens,
+                 temperature=args.temperature,
+                 stop_token_ids=stop_token_ids,
+             )
+ 
+             local_accept_lengths.extend(accept_lens)
+             num_gen = output_ids.shape[1] - input_ids.shape[1]
+             local_tokens += num_gen
+ 
+             if rank == 0 and len(local_accept_lengths) > 0:
+                 avg = sum(local_accept_lengths) / len(local_accept_lengths)
+                 iterator.set_postfix(accept_len=f"{avg:.2f}", tokens=local_tokens, gen=num_gen)
+ 
+         elapsed = time.time() - t0
+ 
+         # ── Gather ──
+         if world_size > 1:
+             local_sum = torch.tensor(sum(local_accept_lengths), dtype=torch.float64, device=device)
+             local_count = torch.tensor(len(local_accept_lengths), dtype=torch.long, device=device)
+             local_tok = torch.tensor(local_tokens, dtype=torch.long, device=device)
+             dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
+             dist.all_reduce(local_count, op=dist.ReduceOp.SUM)
+             dist.all_reduce(local_tok, op=dist.ReduceOp.SUM)
+             total_accept_sum = local_sum.item()
+             total_count = local_count.item()
+             total_tokens = local_tok.item()
+         else:
+             total_accept_sum = sum(local_accept_lengths)
+             total_count = len(local_accept_lengths)
+             total_tokens = local_tokens
+ 
+         avg_accept_length = total_accept_sum / max(total_count, 1)
+         throughput = total_tokens / elapsed if elapsed > 0 else 0
+ 
+         print_rank0(f"\n{bench_name} Results:")
+         print_rank0(f" Avg Accept Length: {avg_accept_length:.3f}")
+         print_rank0(f" Total tokens: {total_tokens}")
+         print_rank0(f" Latency: {elapsed:.1f}s")
+         print_rank0(f" Throughput: {throughput:.1f} tok/s (aggregate {world_size} GPUs)")
+         print_rank0(f" Num verify rounds: {total_count}")
+         print_rank0(f" Num samples: {len(all_prompts)}")
+ 
+         results[bench_name] = {
+             "avg_accept_length": avg_accept_length,
+             "total_tokens": total_tokens,
+ "total_tokens": total_tokens,
330
+ "latency": elapsed,
331
+ "throughput": throughput,
332
+ "num_samples": len(all_prompts),
333
+ "num_verify_rounds": total_count,
334
+ "num_gpus": world_size,
335
+ }
336
+
337
+ # ── Save ──
338
+ if is_main():
339
+ os.makedirs(args.output_dir, exist_ok=True)
340
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
341
+ result_file = os.path.join(
342
+ args.output_dir,
343
+ f"dflash_b16_baseline_offline_{timestamp}.json",
344
+ )
345
+ with open(result_file, "w") as f:
346
+ json.dump(results, f, indent=2)
347
+ print(f"\nResults saved to: {result_file}")
348
+
349
+ if world_size > 1:
350
+ dist.destroy_process_group()
351
+
352
+
353
+ if __name__ == "__main__":
354
+ main()
syxin_old/eval_dflash_b16_baseline_changelog.md ADDED
@@ -0,0 +1,143 @@
1
+ # eval_dflash_b16_baseline.py Change Log
2
+
3
+ The following issues were fixed after comparing against the official repo's `/workspace/hanrui/dflash/benchmark.py` and `run_benchmark.sh`.
4
+
5
+ ---
6
+
7
+ ## 1. [Critical] Add the `attn_implementation` argument
8
+
9
+ **Problem**: No attention implementation was specified when loading the models, so they defaulted to eager attention, with performance far below the paper's.
10
+
11
+ **Fix**: Add `attn_implementation="flash_attention_2"` to both target_model and draft_model (automatically falling back to `"sdpa"` when flash_attn is not installed).
12
+
13
+ ```python
14
+ # Before
15
+ target_model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.bfloat16, ...)
16
+ draft_model = AutoModel.from_pretrained(args.draft_model, torch_dtype=torch.bfloat16, ...)
17
+
18
+ # After
19
+ attn_impl = "flash_attention_2" if installed_flash_attn else "sdpa"
20
+ target_model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.bfloat16, attn_implementation=attn_impl, ...)
21
+ draft_model = AutoModel.from_pretrained(args.draft_model, torch_dtype=torch.bfloat16, attn_implementation=attn_impl, ...)
22
+ ```
23
+
24
+ ---
25
+
26
+ ## 2. [Critical] Add `enable_thinking=False`
27
+
28
+ **Problem**: Qwen3-series models enable thinking mode by default, inserting large amounts of `<think>...</think>` content into the output. This makes the generated content and length completely different from the paper's test conditions, so the acceptance-length metric is not comparable.
29
+
30
+ **Fix**: Add `enable_thinking=False` to the `apply_chat_template` call.
31
+
32
+ ```python
33
+ # Before
34
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
35
+
36
+ # After
37
+ text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
38
+ ```
39
+
40
+ ---
41
+
42
+ ## 3. [Critical] Add an autoregressive baseline (block_size=1) to compute speedup
43
+
44
+ **Problem**: The original script only ran speculative decoding (block_size=16) and never the autoregressive baseline (block_size=1), so the decoding-speedup metric reported in the paper could not be computed.
45
+
46
+ **Fix**: For each prompt, first run a baseline with `block_size=1`, then run speculative decoding with `block_size=block_size`, and compute `speedup = t1 / tb`. The `spec_generate_b16` function gains a `block_size` parameter to support the block_size=1 mode.
47
+
48
+ ```python
49
+ # For each prompt:
50
+ _, _, t1 = spec_generate_b16(..., block_size=1, ...) # autoregressive
51
+ output_ids, accept_lens, tb = spec_generate_b16(..., block_size=block_size, ...) # speculative
52
+ # speedup = mean(t1) / mean(tb)
53
+ ```
54
+
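The mean-of-means convention above matters: the official speedup is the ratio of the mean AR time-per-token to the mean speculative time-per-token over all samples, not the mean of per-sample ratios. A minimal sketch, where `t1_times` and `tb_times` are hypothetical lists of per-sample `time_per_output_token` values:

```python
import numpy as np

def decoding_speedup(t1_times, tb_times):
    """Official-style speedup: mean over samples first, then the ratio
    (not the mean of per-sample ratios)."""
    return float(np.mean(t1_times) / np.mean(tb_times))

print(decoding_speedup([2.0, 4.0], [1.0, 2.0]))  # 2.0
```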
55
+ ---
56
+
57
+ ## 4. [Critical] Fix timing: `time.time()` → `cuda_time()`
58
+
59
+ **Problem**: The original script measured wall-clock time with `time.time()` without any CUDA synchronization; since GPU execution is asynchronous, the timings were inaccurate.
60
+
61
+ **Fix**: Add a `cuda_time()` helper (identical to the official one) that calls `torch.cuda.synchronize()` + `time.perf_counter()` at each timing point. `spec_generate_b16` now uses `cuda_time()` internally to measure the prefill and decode phases precisely, and returns `time_per_output_token`.
62
+
63
+ ```python
64
+ def cuda_time() -> float:
65
+ torch.cuda.synchronize()
66
+ return time.perf_counter()
67
+ ```
68
+
69
+ ---
70
+
71
+ ## 5. [Critical] Correct the draft prefill timing
72
+
73
+ **Problem**: The original script counted the draft model's first prefill in the decode phase, inflating time_per_token and deflating speedup.
74
+
75
+ **Fix**: In spec_generate_b16's decode loop, reset `decode_start` after the first draft forward completes (matching the official `draft_prefill` flag logic).
76
+
77
+ ```python
78
+ if draft_prefill:
79
+ draft_prefill = False
80
+ decode_start = cuda_time()  # reset to exclude the draft's first prefill
81
+ ```
82
+
83
+ ---
84
+
85
+ ## 6. [Important] Change the `max_new_tokens` default from 512 to 2048
86
+
87
+ **Problem**: The original script defaulted to `max_new_tokens=512`, while the official `run_benchmark.sh` uses `2048`. Generations that are too short yield too few samples for the acceptance-length statistics, making the metric incomparable with the paper.
88
+
89
+ **Fix**: Change the default to `2048`.
90
+
91
+ ```python
92
+ # Before
93
+ p.add_argument("--max-new-tokens", type=int, default=512)
94
+
95
+ # After
96
+ p.add_argument("--max-new-tokens", type=int, default=2048)
97
+ ```
98
+
99
+ ---
100
+
101
+ ## 7. [Important] Add fixed random seeds
102
+
103
+ **Problem**: The original script set no random seed, so repeated runs were not reproducible.
104
+
105
+ **Fix**: Add the same seed setup as the official code at the start of `main()`.
106
+
107
+ ```python
108
+ random.seed(0)
109
+ np.random.seed(0)
110
+ torch.manual_seed(0)
111
+ torch.cuda.manual_seed_all(0)
112
+ torch.backends.cudnn.deterministic = True
113
+ torch.backends.cudnn.benchmark = False
114
+ ```
115
+
116
+ ---
117
+
118
+ ## 8. [Important] Support a block_size=1 branch in spec_generate_b16
119
+
120
+ **Problem**: The original function hard-coded `draft_model.block_size` and always passed `output_hidden_states=True`. When block_size=1 (autoregressive baseline), the draft model should not be called and hidden states should not be extracted.
121
+
122
+ **Fix**:
123
+ - Add a `block_size` parameter
124
+ - Set `output_hidden_states` to True only when `block_size > 1`
125
+ - Run the draft-model forward and hidden-state extraction only when `block_size > 1`
126
+
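The three bullets above reduce to a single predicate on `block_size`. A minimal sketch of the implied per-step behavior (`spec_step_config` is a hypothetical helper for illustration, not a function in the script):

```python
def spec_step_config(block_size: int) -> dict:
    """Per-step behavior implied by the block_size branch:
    block_size == 1 -> pure autoregressive (no draft call, no hidden states);
    block_size  > 1 -> speculative (draft forward + hidden-state extraction)."""
    use_draft = block_size > 1
    return {
        "output_hidden_states": use_draft,  # only the draft model consumes them
        "run_draft_forward": use_draft,
        "tokens_proposed_per_round": block_size - 1 if use_draft else 0,
    }

print(spec_step_config(1))   # autoregressive: everything off
print(spec_step_config(16))  # speculative: 15 draft tokens proposed per round
```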
127
+ ---
128
+
129
+ ## 9. [Minor] Add speedup and an acceptance histogram to the output
130
+
131
+ **Fix**: The results output now additionally includes:
132
+ - `Decoding speedup: X.XXx` (the t1/tb ratio)
133
+ - `Acceptance length histogram` (the fraction of rounds at each acceptance length)
134
+
135
+ This aligns the output format with the official benchmark.py.
136
+
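The histogram is a simple frequency count over the per-round acceptance lengths. A sketch, assuming inputs shaped like the script's `acceptance_lengths` list and `block_size` (the helper name is illustrative):

```python
from collections import Counter

def acceptance_histogram(acceptance_lengths, block_size):
    """Fraction of verify rounds whose acceptance length equals b,
    for b = 0..block_size."""
    counts = Counter(acceptance_lengths)
    total = max(len(acceptance_lengths), 1)
    return [counts[b] / total for b in range(block_size + 1)]

print(acceptance_histogram([1, 2, 2, 4], 4))  # [0.0, 0.25, 0.5, 0.0, 0.25]
```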
137
+ ---
138
+
139
+ ## Items left unchanged
140
+
141
+ - **Model choice**: Keep Qwen3-8B (rather than the official default Qwen3-4B), since the local model is the 8B variant
142
+ - **Draft model loading**: Keep `AutoModel.from_pretrained` (relying on `trust_remote_code=True`) instead of the official `DFlashDraftModel`, since it needs the remote-code support shipped in the model directory
143
+ - **Dataset scope**: Keep the original 3 benchmarks (humaneval/mtbench/gsm8k) rather than expanding to the official 10
syxin_old/eval_dflash_lora_inject.py ADDED
@@ -0,0 +1,660 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Offline evaluation for DFlash-LoRA-Inject: measure accepted length & speedup.
4
+ Aligned with official DFlash benchmark.py methodology.
5
+
6
+ Unlike DFlash-b16, which uses a small 5-layer draft model with fc/hidden_norm,
7
+ LoRA-Inject uses a full Qwen3-8B with LoRA adapters that receives target hidden
8
+ states via layer-by-layer injection.
9
+
10
+ Usage:
11
+ conda activate spec
12
+
13
+ # 8 GPU parallel (default, all 10 benchmarks)
14
+ torchrun --nproc_per_node 8 eval_dflash_lora_inject.py
15
+
16
+ # single GPU
17
+ python3 eval_dflash_lora_inject.py
18
+
19
+ # specific checkpoint / benchmark
20
+ torchrun --nproc_per_node 8 eval_dflash_lora_inject.py --ckpt epoch_0_step_1000 --datasets humaneval
21
+
22
+ # quick test
23
+ torchrun --nproc_per_node 8 eval_dflash_lora_inject.py --max-samples 20
24
+ """
25
+ import argparse
26
+ import json
27
+ import os
28
+ import random
29
+ import sys
30
+ import time
31
+ import warnings
32
+ from itertools import chain
33
+ from types import SimpleNamespace
34
+ from typing import List, Optional, Tuple
35
+
36
+ import numpy as np
37
+ import torch
38
+ import torch.nn as nn
39
+ import torch.distributed as dist
40
+ from peft import PeftModel
41
+ from tqdm import tqdm
42
+ from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
43
+
44
+ # Import official dataset loader
45
+ sys.path.insert(0, "/workspace/hanrui/dflash")
46
+ from model.utils import load_and_process_dataset
47
+
48
+ # ──────────────────────────────────────────────────────────────────
49
+ # Config defaults
50
+ # ──────────────────────────────────────────────────────────────────
51
+ BASE_MODEL = "/workspace/models/Qwen3-8B"
52
+ ADAPTER_ROOT = "/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject"
53
+ DEFAULT_CKPT = "epoch_3_step_4644"
54
+ MASK_TOKEN_ID = 151669 # Qwen3 <|mask|>
55
+ BLOCK_SIZE = 16
56
+ RESULT_DIR = "/workspace/hanrui/syxin_old/Specforge/benchmarks/results"
57
+
58
+ # Official benchmark tasks (from run_benchmark.sh)
59
+ OFFICIAL_TASKS = {
60
+ "gsm8k": 128,
61
+ "math500": 128,
62
+ "aime24": 30,
63
+ "aime25": 30,
64
+ "humaneval": 164,
65
+ "mbpp": 128,
66
+ "livecodebench": 128,
67
+ "swe-bench": 128,
68
+ "mt-bench": 80,
69
+ "alpaca": 128,
70
+ }
71
+
72
+
73
+ # ──────────────────────────────────────────────────────────────────
74
+ # CUDA-synchronised timer (matches official benchmark.py)
75
+ # ──────────────────────────────────────────────────────────────────
76
+ def cuda_time() -> float:
77
+ torch.cuda.synchronize()
78
+ return time.perf_counter()
79
+
80
+
81
+ def has_flash_attn() -> bool:
82
+ try:
83
+ import flash_attn # noqa: F401
84
+ return True
85
+ except ImportError:
86
+ print("[WARN] flash_attn not installed, falling back to sdpa.")
87
+ return False
88
+
89
+
90
+ # ──────────────────────────────────────────────────────────────────
91
+ # Distributed helpers (mirrors official distributed.py)
92
+ # ──────────────────────────────────────────────────────────────────
93
+ def dist_init():
94
+ if "RANK" not in os.environ:
95
+ warnings.warn("RANK not set. Skipping distributed init.")
96
+ return
97
+ dist.init_process_group(backend="nccl", init_method="env://")
98
+
99
+ def dist_rank():
100
+ return int(os.environ.get("RANK", 0))
101
+
102
+ def dist_size():
103
+ return int(os.environ.get("WORLD_SIZE", 1))
104
+
105
+ def dist_local_rank():
106
+ return int(os.environ.get("LOCAL_RANK", 0))
107
+
108
+ def dist_is_main():
109
+ return dist_rank() == 0
110
+
111
+ def dist_gather(obj, dst=0):
112
+ if not dist.is_initialized():
113
+ return [obj]
114
+ if dist_is_main():
115
+ objs = [None for _ in range(dist_size())]
116
+ dist.gather_object(obj, objs, dst=dst)
117
+ return objs
118
+ else:
119
+ dist.gather_object(obj, dst=dst)
120
+ return None
121
+
122
+ def print_rank0(*args, **kwargs):
123
+ if dist_is_main():
124
+ print(*args, **kwargs)
125
+
126
+
127
+ # ──────────────────────────────────────────────────────────────────
128
+ # Sampling (matches official model/utils.py::sample)
129
+ # ──────────────────────────────────────────────────────────────────
130
+ def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
131
+ if temperature < 1e-5:
132
+ return torch.argmax(logits, dim=-1)
133
+ bsz, seq_len, vocab_size = logits.shape
134
+ logits = logits.view(-1, vocab_size)
135
+ logits = logits / temperature
136
+ probs = torch.softmax(logits, dim=-1)
137
+ return torch.multinomial(probs, num_samples=1).view(bsz, seq_len)
138
+
139
+
140
+ # ──────────────────────────────────────────────────────────────────
141
+ # Build DFlash attention mask (vectorized, no Python loops)
142
+ # ──────────────────────────────────────────────────────────────────
143
+ def build_dflash_mask(ctx_len: int, block_size: int, device, dtype=torch.bfloat16):
144
+ """
145
+ Build DFlash attention mask for [context | block] sequence.
146
+ - Context part: standard causal
147
+ - Block part: each token sees all context + all tokens in same block (bidirectional)
148
+ """
149
+ full_len = ctx_len + block_size
150
+ neg_inf = torch.finfo(dtype).min
151
+
152
+ mask = torch.full((1, 1, full_len, full_len), neg_inf, device=device, dtype=dtype)
153
+
154
+ if ctx_len > 0:
155
+ ctx_rows = torch.arange(ctx_len, device=device)
156
+ ctx_cols = torch.arange(ctx_len, device=device)
157
+ causal = ctx_cols.unsqueeze(0) <= ctx_rows.unsqueeze(1)
158
+ mask[0, 0, :ctx_len, :ctx_len].masked_fill_(causal, 0)
159
+
160
+ if ctx_len > 0:
161
+ mask[0, 0, ctx_len:, :ctx_len] = 0
162
+ mask[0, 0, ctx_len:, ctx_len:] = 0
163
+
164
+ return mask
165
+
166
+
167
+ # ──────────────────────────────────────────────────────────────────
168
+ # Pure autoregressive generation (target model only, no draft)
169
+ # Used for AR baseline timing -- avoids inflating AR time with draft overhead.
170
+ # ──────────────────────────────────────────────────────────────────
171
+ @torch.inference_mode()
172
+ def ar_generate(
173
+ target_model: nn.Module,
174
+ input_ids: torch.LongTensor,
175
+ max_new_tokens: int = 2048,
176
+ mask_token_id: int = MASK_TOKEN_ID,
177
+ temperature: float = 0.0,
178
+ stop_token_ids: Optional[List[int]] = None,
179
+ ) -> SimpleNamespace:
180
+ """
181
+ Pure autoregressive generation using only the target model.
182
+ Mirrors official benchmark.py with block_size=1 (no draft model involved).
183
+ Returns SimpleNamespace matching official dflash_generate output format.
184
+ """
185
+ device = input_ids.device
186
+ num_input_tokens = input_ids.shape[1]
187
+ max_length = num_input_tokens + max_new_tokens
188
+
189
+ output_ids = torch.full(
190
+ (1, max_length + 1), mask_token_id,
191
+ dtype=torch.long, device=device,
192
+ )
193
+ output_ids[:, :num_input_tokens] = input_ids
194
+ position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)
195
+ past_key_values = DynamicCache()
196
+
197
+ # Prefill
198
+ prefill_start = cuda_time()
199
+ output = target_model(
200
+ input_ids,
201
+ position_ids=position_ids[:, :num_input_tokens],
202
+ past_key_values=past_key_values,
203
+ use_cache=True,
204
+ logits_to_keep=1,
205
+ output_hidden_states=False,
206
+ )
207
+ first_token = sample(output.logits, temperature)
208
+ output_ids[:, num_input_tokens:num_input_tokens + 1] = first_token
209
+ time_to_first_token = cuda_time() - prefill_start
210
+
211
+ # Decode (autoregressive, one token at a time)
212
+ decode_start = cuda_time()
213
+ start = num_input_tokens
214
+
215
+ while start < max_length:
216
+ cur_token = output_ids[:, start:start + 1]
217
+ cur_pos = position_ids[:, start:start + 1]
218
+
219
+ output = target_model(
220
+ cur_token,
221
+ position_ids=cur_pos,
222
+ past_key_values=past_key_values,
223
+ use_cache=True,
224
+ output_hidden_states=False,
225
+ )
226
+
227
+ next_token = sample(output.logits, temperature)
228
+ start += 1
229
+ output_ids[:, start:start + 1] = next_token
230
+ past_key_values.crop(start)
231
+
232
+ # Check stop tokens (matches official: check all generated)
233
+ if stop_token_ids is not None and any(
234
+ sid in output_ids[:, num_input_tokens:] for sid in stop_token_ids
235
+ ):
236
+ break
237
+
238
+ output_ids = output_ids[:, :max_length]
239
+ output_ids = output_ids[:, output_ids[0] != mask_token_id]
240
+ if stop_token_ids is not None:
241
+ stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
242
+ stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
243
+ if stop_idx.numel() > 0:
244
+ output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]
245
+
246
+ num_output_tokens = output_ids.shape[1] - num_input_tokens
247
+ total_decode_time = cuda_time() - decode_start
248
+ time_per_output_token = total_decode_time / max(num_output_tokens, 1)
249
+
250
+ return SimpleNamespace(
251
+ output_ids=output_ids,
252
+ num_input_tokens=num_input_tokens,
253
+ num_output_tokens=num_output_tokens,
254
+ time_to_first_token=time_to_first_token,
255
+ time_per_output_token=time_per_output_token,
256
+ acceptance_lengths=[1] * max(num_output_tokens, 0), # AR: always 1
257
+ )
258
+
259
+
260
+ # ──────────────────────────────────────────────────────────────────
261
+ # Core: spec_generate with layer-by-layer injection (KV-cached)
262
+ # ──────────────────────────────────────────────────────────────────
263
+ @torch.inference_mode()
264
+ def spec_generate_inject(
265
+ target_model: nn.Module,
266
+ draft_model: nn.Module,
267
+ input_ids: torch.LongTensor,
268
+ max_new_tokens: int = 2048,
269
+ block_size: int = 16,
270
+ mask_token_id: int = MASK_TOKEN_ID,
271
+ temperature: float = 0.0,
272
+ stop_token_ids: Optional[List[int]] = None,
273
+ ) -> SimpleNamespace:
274
+ """
275
+ Speculative generation using DFlash-LoRA-Inject inference pattern.
276
+ Returns SimpleNamespace matching official dflash_generate output format.
277
+ """
278
+ device = input_ids.device
279
+ num_input_tokens = input_ids.shape[1]
280
+ max_length = num_input_tokens + max_new_tokens
281
+
282
+ draft_layers = draft_model.model.layers
283
+ draft_norm = draft_model.model.norm
284
+ draft_lm_head = draft_model.lm_head
285
+ rotary_emb = draft_model.model.rotary_emb
286
+ num_layers = len(draft_layers)
287
+
288
+ output_ids = torch.full(
289
+ (1, max_length + block_size), mask_token_id,
290
+ dtype=torch.long, device=device,
291
+ )
292
+ output_ids[:, :num_input_tokens] = input_ids
293
+
294
+ # ── Prefill: target with KV cache + hidden states ──
295
+ prefill_start = cuda_time()
296
+ target_kv = DynamicCache()
297
+ target_output = target_model(
298
+ input_ids,
299
+ past_key_values=target_kv,
300
+ use_cache=True,
301
+ output_hidden_states=True,
302
+ )
303
+ first_token = sample(target_output.logits[:, -1:, :], temperature)
304
+ output_ids[:, num_input_tokens] = first_token.squeeze()
305
+
306
+ ctx_hidden_per_layer = [
307
+ target_output.hidden_states[i]
308
+ for i in range(num_layers)
309
+ ]
310
+
311
+ time_to_first_token = cuda_time() - prefill_start
312
+
313
+ # Decode
314
+ decode_start = cuda_time()
315
+ acceptance_lengths = []
316
+ start = num_input_tokens
317
+ draft_prefill = True
318
+
319
+ while start < max_length:
320
+ end = min(start + block_size, max_length)
321
+ actual_block_size = end - start
322
+
323
+ block_ids = output_ids[:, start:end].clone()
324
+
325
+ # ── FIX: Get anchor's target hidden state before draft forward ──
326
+ # Training: block k's draft tokens see target hidden states at positions
327
+ # 0..k*block_size INCLUSIVE (the anchor position). But in eval, ctx_hidden
328
+ # only covers 0..start-1. We must process the anchor through the target
329
+ # model to get its hidden state, matching training's attention pattern.
330
+ anchor_token = output_ids[:, start:start + 1]
331
+ anchor_pos = torch.tensor([[start]], device=device)
332
+ anchor_output = target_model(
333
+ anchor_token,
334
+ position_ids=anchor_pos,
335
+ past_key_values=target_kv,
336
+ use_cache=True,
337
+ output_hidden_states=True,
338
+ )
339
+ # Save anchor hidden states (one per layer)
340
+ for i in range(num_layers):
341
+ anchor_hs = anchor_output.hidden_states[i] # [1, 1, hidden_dim]
342
+ ctx_hidden_per_layer[i] = torch.cat([ctx_hidden_per_layer[i], anchor_hs], dim=1)
343
+ # Roll back KV cache: verification will re-process from position start
344
+ target_kv.crop(start)
345
+
346
+ # ── Draft: forward with layer-by-layer injection ──
347
+ draft_hidden = draft_model.model.embed_tokens(block_ids)
348
+ ctx_len = ctx_hidden_per_layer[0].shape[1] # now includes anchor at position start
349
+
350
+ dflash_mask = build_dflash_mask(ctx_len, actual_block_size, device)
351
+
352
+ # Position IDs: context covers [0..start], block covers [start..start+bs-1]
353
+ # Position 'start' appears twice (in both context and block), matching
354
+ # training where target and draft share the same position IDs.
355
+ ctx_positions = torch.arange(ctx_len, device=device)
356
+ block_positions = torch.arange(start, start + actual_block_size, device=device)
357
+ combined_pos = torch.cat([ctx_positions, block_positions], dim=0).unsqueeze(0)
358
+
359
+ dummy_combined = torch.empty(1, ctx_len + actual_block_size, draft_hidden.shape[-1],
360
+ device=device, dtype=torch.bfloat16)
361
+ position_embeddings = rotary_emb(dummy_combined, combined_pos)
362
+
363
+ for layer_idx in range(num_layers):
364
+ target_ctx = ctx_hidden_per_layer[layer_idx]
365
+ combined = torch.cat([target_ctx, draft_hidden], dim=1)
366
+
367
+ layer_output = draft_layers[layer_idx](
368
+ combined,
369
+ attention_mask=dflash_mask,
370
+ position_ids=combined_pos,
371
+ position_embeddings=position_embeddings,
372
+ )
373
+ if isinstance(layer_output, tuple):
374
+ layer_output = layer_output[0]
375
+ draft_hidden = layer_output[:, ctx_len:, :]
376
+
377
+ draft_hidden = draft_norm(draft_hidden)
378
+ draft_logits = draft_lm_head(draft_hidden)
379
+
380
+ draft_predictions = sample(draft_logits[:, 1:, :], temperature)
381
+ block_ids[:, 1:actual_block_size] = draft_predictions[:, :actual_block_size - 1]
382
+
383
+ # Exclude draft's first prefill from decode timing (matches official pattern)
384
+ if draft_prefill:
385
+ draft_prefill = False
386
+ decode_start = cuda_time()
387
+
388
+ # ── Verify: target forward on block tokens (with KV cache) ──
389
+ position_ids_block = torch.arange(
390
+ start, start + actual_block_size, device=device
391
+ ).unsqueeze(0)
392
+
393
+ target_verify = target_model(
394
+ block_ids,
395
+ position_ids=position_ids_block,
396
+ past_key_values=target_kv,
397
+ use_cache=True,
398
+ output_hidden_states=True,
399
+ )
400
+ target_tokens = sample(target_verify.logits, temperature)
401
+
402
+ # Acceptance
403
+ matches = (block_ids[:, 1:actual_block_size] == target_tokens[:, :actual_block_size - 1])
404
+ acceptance_length = int(matches.cumprod(dim=1).sum(dim=1)[0].item())
405
+
406
+ output_ids[:, start:start + acceptance_length + 1] = block_ids[:, :acceptance_length + 1]
407
+ output_ids[:, start + acceptance_length + 1] = target_tokens[:, acceptance_length]
408
+
409
+ accepted_end = start + acceptance_length + 1
410
+ target_kv.crop(accepted_end)
411
+
412
+ # Remove the anchor hidden state we added above (it's position start);
413
+ # instead, save the verification's hidden states which include the anchor
414
+ # and accepted tokens computed with the correct full KV context.
415
+ for i in range(num_layers):
416
+ # Drop the anchor we appended earlier (last entry in ctx_hidden)
417
+ ctx_hidden_per_layer[i] = ctx_hidden_per_layer[i][:, :-1, :]
418
+ # Add verification hidden states for accepted positions
419
+ new_hidden = target_verify.hidden_states[i][:, :acceptance_length + 1, :]
420
+ ctx_hidden_per_layer[i] = torch.cat([ctx_hidden_per_layer[i], new_hidden], dim=1)
421
+
422
+ start += acceptance_length + 1
423
+ acceptance_lengths.append(acceptance_length + 1)
424
+
425
+ # Official: check ALL generated tokens
426
+ if stop_token_ids is not None and any(
427
+ sid in output_ids[:, num_input_tokens:] for sid in stop_token_ids
428
+ ):
429
+ break
430
+
431
+ output_ids = output_ids[:, :min(start, max_length)]
432
+ output_ids = output_ids[:, output_ids[0] != mask_token_id]
433
+ if stop_token_ids is not None:
434
+ stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
435
+ stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
436
+ if stop_idx.numel() > 0:
437
+ output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]
438
+
439
+ num_output_tokens = output_ids.shape[1] - num_input_tokens
440
+ total_decode_time = cuda_time() - decode_start
441
+ time_per_output_token = total_decode_time / max(num_output_tokens, 1)
442
+
443
+ return SimpleNamespace(
444
+ output_ids=output_ids,
445
+ num_input_tokens=num_input_tokens,
446
+ num_output_tokens=num_output_tokens,
447
+ time_to_first_token=time_to_first_token,
448
+ time_per_output_token=time_per_output_token,
449
+ acceptance_lengths=acceptance_lengths,
450
+ )
451
+
452
+
453
+ # ──────────────────────────────────────────────────────────────────
454
+ # Main
455
+ # ──────────────────────────────────────────────────────────────────
456
+ def parse_args():
457
+ p = argparse.ArgumentParser(description="Offline eval for DFlash-LoRA-Inject (aligned with official)")
458
+ p.add_argument("--base-model", default=BASE_MODEL)
459
+ p.add_argument("--adapter-root", default=ADAPTER_ROOT)
460
+ p.add_argument("--ckpt", default=DEFAULT_CKPT, help="Checkpoint folder name")
461
+ p.add_argument("--merged-path",
462
+ default="/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged",
463
+ help="Path to pre-merged model. If None, will merge on the fly.")
464
+ p.add_argument("--block-size", type=int, default=BLOCK_SIZE)
465
+ p.add_argument("--max-new-tokens", type=int, default=2048,
466
+ help="Max new tokens per turn (official shell uses 2048)")
467
+ p.add_argument("--temperature", type=float, default=0.0)
468
+ p.add_argument("--datasets", nargs="+", default=list(OFFICIAL_TASKS.keys()),
469
+ help="Benchmarks to run (default: all 10 official tasks)")
470
+ p.add_argument("--max-samples", type=int, default=None,
471
+ help="Override max samples per dataset (None = use official per-task counts)")
472
+ p.add_argument("--output-dir", default=RESULT_DIR)
473
+ return p.parse_args()
474
+
475
+
476
+ def main():
477
+ args = parse_args()
478
+
479
+ # Fix random seeds (matches official)
480
+ random.seed(0)
481
+ np.random.seed(0)
482
+ torch.manual_seed(0)
483
+ torch.cuda.manual_seed_all(0)
484
+ torch.backends.cudnn.deterministic = True
485
+ torch.backends.cudnn.benchmark = False
486
+
487
+ # ── Init distributed ──
488
+ dist_init()
489
+ torch.cuda.set_device(dist_local_rank())
490
+ device = torch.device(f"cuda:{dist_local_rank()}")
491
+
492
+ print_rank0(f"Running on {dist_size()} GPU(s)")
493
+
494
+ # Detect flash_attn (only for target model; draft needs sdpa for custom DFlash mask)
495
+ installed_flash_attn = has_flash_attn()
496
+ target_attn_impl = "flash_attention_2" if installed_flash_attn else "sdpa"
497
+ draft_attn_impl = "sdpa" # DFlash injection uses custom attention mask
498
+ print_rank0(f"Using attn_implementation: target={target_attn_impl}, draft={draft_attn_impl}")
499
+
500
+ # ── Load models ──
501
+ print_rank0(f"Loading target model: {args.base_model}")
502
+ target_model = AutoModelForCausalLM.from_pretrained(
503
+ args.base_model,
504
+ torch_dtype=torch.bfloat16,
505
+ attn_implementation=target_attn_impl,
506
+ device_map=device,
507
+ trust_remote_code=True,
508
+ )
509
+ target_model.eval()
510
+
511
+ if args.merged_path and os.path.isdir(args.merged_path):
512
+ print_rank0(f"Loading pre-merged draft model: {args.merged_path}")
513
+ draft_model = AutoModelForCausalLM.from_pretrained(
514
+ args.merged_path,
515
+ torch_dtype=torch.bfloat16,
516
+ attn_implementation=draft_attn_impl,
517
+ device_map=device,
518
+ trust_remote_code=True,
519
+ )
520
+ else:
521
+ adapter_path = os.path.join(args.adapter_root, args.ckpt)
522
+ print_rank0(f"Loading base + LoRA adapter: {adapter_path}")
523
+ draft_model = AutoModelForCausalLM.from_pretrained(
524
+ args.base_model,
525
+ torch_dtype=torch.bfloat16,
526
+ attn_implementation=draft_attn_impl,
527
+ device_map=device,
528
+ trust_remote_code=True,
529
+ )
530
+ draft_model = PeftModel.from_pretrained(draft_model, adapter_path)
531
+ draft_model = draft_model.merge_and_unload()
532
+ draft_model.eval()
533
+
534
+ tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
535
+ stop_token_ids = [tokenizer.eos_token_id]
536
+
537
+ block_size = args.block_size
538
+
539
+ # ── Run benchmarks ──
540
+ all_results = {"model": f"dflash-lora-inject/{args.ckpt}", "block_size": block_size}
541
+
542
+ for dataset_name in args.datasets:
543
+ print_rank0(f"\n{'=' * 60}")
544
+ print_rank0(f"Benchmark: {dataset_name} ({dist_size()} GPUs)")
545
+ print_rank0(f"{'=' * 60}")
546
+
547
+ # Load dataset using official loader
548
+ dataset = load_and_process_dataset(dataset_name)
549
+
550
+ # Sample selection: official uses shuffle(seed=0).select()
551
+ max_samples = args.max_samples if args.max_samples is not None else OFFICIAL_TASKS.get(dataset_name)
552
+ if max_samples is not None and len(dataset) > max_samples:
553
+ dataset = dataset.shuffle(seed=0).select(range(max_samples))
554
+
555
+ print_rank0(f"Total {len(dataset)} samples, distributed across {dist_size()} GPUs")
556
+
557
+ responses = []
558
+ indices = range(dist_rank(), len(dataset), dist_size())
559
+
560
+ iterator = tqdm(indices, desc=f"[GPU{dist_rank()}] {dataset_name}",
561
+ unit="sample", disable=not dist_is_main())
562
+
563
+ for idx in iterator:
564
+ instance = dataset[idx]
565
+
566
+ # Multi-turn support (matches official benchmark.py)
567
+ messages = []
568
+ for turn_index, user_content in enumerate(instance["turns"]):
569
+ messages.append({"role": "user", "content": user_content})
570
+ input_text = tokenizer.apply_chat_template(
571
+ messages, tokenize=False, add_generation_prompt=True,
572
+ enable_thinking=False,
573
+ )
574
+ input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
575
+
576
+ response = {}
577
+
578
+ # AR baseline: pure target-only autoregressive (no draft overhead)
579
+ response[1] = ar_generate(
580
+ target_model=target_model,
581
+ input_ids=input_ids,
582
+ max_new_tokens=args.max_new_tokens,
583
+ mask_token_id=MASK_TOKEN_ID,
584
+ temperature=args.temperature,
585
+ stop_token_ids=stop_token_ids,
586
+ )
587
+
588
+ # Speculative: DFlash-LoRA-Inject
589
+ response[block_size] = spec_generate_inject(
590
+ target_model=target_model,
591
+ draft_model=draft_model,
592
+ input_ids=input_ids,
593
+ max_new_tokens=args.max_new_tokens,
594
+ block_size=block_size,
595
+ mask_token_id=MASK_TOKEN_ID,
596
+ temperature=args.temperature,
597
+ stop_token_ids=stop_token_ids,
598
+ )
599
+
600
+ # Append assistant response for multi-turn context
601
+ spec_response = response[block_size]
602
+ generated_ids = spec_response.output_ids[0, spec_response.num_input_tokens:]
603
+ output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
604
+ messages.append({"role": "assistant", "content": output_text})
605
+ responses.append(response)
606
+
607
+ if dist_is_main() and responses:
608
+ recent_tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses[-5:]])
609
+ iterator.set_postfix(accept_len=f"{recent_tau:.2f}")
610
+
611
+ # ── Gather to rank 0 (matches official) ──
612
+ if dist_size() > 1:
613
+ gathered = dist_gather(responses, dst=0)
614
+ if not dist_is_main():
615
+ continue
616
+ responses = list(chain(*gathered))
617
+ elif not dist_is_main():
618
+ continue
619
+
620
+ # ── Compute metrics (exact official formulas) ──
621
+ t1 = np.mean([r[1].time_per_output_token for r in responses])
622
+ tb = np.mean([r[block_size].time_per_output_token for r in responses])
623
+ speedup = t1 / tb if tb > 0 else 0
624
+
625
+ # Acceptance length: per-sample mean, then mean of means (official)
626
+ tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses])
627
+
628
+ # Histogram
629
+ acceptance_lengths = list(chain(*[r[block_size].acceptance_lengths for r in responses]))
630
+ histogram = [acceptance_lengths.count(b) / len(acceptance_lengths) for b in range(block_size + 1)]
631
+
632
+ print_rank0(f"\n{dataset_name} Results:")
633
+ print_rank0(f" Decoding speedup: {speedup:.2f}x")
634
+ print_rank0(f" Average Acceptance length: {tau:.2f}")
635
+ print_rank0(f" Acceptance length histogram: {[f'{x * 100:.1f}%' for x in histogram]}")
636
+ print_rank0(f" Num responses: {len(responses)}")
637
+
638
+ all_results[dataset_name] = {
639
+ "decoding_speedup": speedup,
640
+ "avg_accept_length": tau,
641
+ "acceptance_histogram": histogram,
642
+ "num_responses": len(responses),
643
+ "num_gpus": dist_size(),
644
+ }
645
+
646
+ # ── Save results ──
647
+ if dist_is_main():
648
+ os.makedirs(args.output_dir, exist_ok=True)
649
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
650
+ result_file = os.path.join(
651
+ args.output_dir,
652
+ f"dflash_lora_inject_offline_{args.ckpt}_{timestamp}.json",
653
+ )
654
+ with open(result_file, "w") as f:
655
+ json.dump(all_results, f, indent=2)
656
+ print(f"\nResults saved to: {result_file}")
657
+
658
+
659
+ if __name__ == "__main__":
660
+ main()
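The metric formulas used at the end of `main()` can be restated standalone. This is a minimal sketch with hypothetical names (`summarize_acceptance`, `t_ar`, `t_spec`), mirroring the same three computations: per-sample-mean-of-means acceptance length, the acceptance-length histogram, and the AR/speculative speedup ratio:

```python
import numpy as np

def summarize_acceptance(acceptance_lengths_per_response, block_size, t_ar, t_spec):
    # Acceptance length: per-sample mean, then mean of means (as in the script)
    tau = float(np.mean([np.mean(a) for a in acceptance_lengths_per_response]))
    # Histogram over all acceptance lengths, one bucket per 0..block_size
    flat = [x for a in acceptance_lengths_per_response for x in a]
    histogram = [flat.count(b) / len(flat) for b in range(block_size + 1)]
    # Speedup = AR time-per-output-token / speculative time-per-output-token
    speedup = t_ar / t_spec if t_spec > 0 else 0.0
    return tau, histogram, speedup

tau, hist, speedup = summarize_acceptance([[2, 2, 1], [2, 3]], block_size=4,
                                          t_ar=10.0, t_spec=5.0)
```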
syxin_old/eval_gsm8k_humaneval_mtbench.log ADDED
@@ -0,0 +1,81 @@
1
+ nohup: ignoring input
2
+ WARNING:__main__:
3
+ *****************************************
4
+ Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
5
+ *****************************************
6
+ [W324 11:41:43.200488949 Utils.hpp:135] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
7
+ [W324 11:41:43.200586722 Utils.hpp:135] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
8
+ [W324 11:41:43.267031138 Utils.hpp:135] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
9
+ [W324 11:41:43.267675225 Utils.hpp:135] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
10
+ [W324 11:41:43.279640318 Utils.hpp:135] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
11
+ [W324 11:41:43.291758156 Utils.hpp:135] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
12
+ [W324 11:41:43.328250126 Utils.hpp:135] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
13
+ [W324 11:41:43.335890706 Utils.hpp:135] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
14
+ Running on 8 GPU(s)
15
+ Using attn_implementation: target=flash_attention_2, draft=sdpa
16
+ Loading target model: /workspace/models/Qwen3-8B
17
+ `torch_dtype` is deprecated! Use `dtype` instead!
18
+ `torch_dtype` is deprecated! Use `dtype` instead!
19
+ `torch_dtype` is deprecated! Use `dtype` instead!
20
+ `torch_dtype` is deprecated! Use `dtype` instead!
21
+ `torch_dtype` is deprecated! Use `dtype` instead!
22
+ `torch_dtype` is deprecated! Use `dtype` instead!
23
+ `torch_dtype` is deprecated! Use `dtype` instead!
24
+ `torch_dtype` is deprecated! Use `dtype` instead!
25
+
26
+
27
+
28
+
29
+
30
+
31
+
32
+
33
+ Loading base + LoRA adapter: /workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject/epoch_3_step_4644
34
+
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+ ============================================================
44
+ Benchmark: gsm8k (8 GPUs)
45
+ ============================================================
46
+ Total 128 samples, distributed across 8 GPUs
47
+
48
+ /workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/storage.py:414: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
49
+ return torch.load(io.BytesIO(b))
50
+
51
+ gsm8k Results:
52
+ Decoding speedup: 1.01x
53
+ Average Acceptance length: 1.99
54
+ Acceptance length histogram: ['0.0%', '3.6%', '94.8%', '1.3%', '0.3%', '0.1%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%']
55
+ Num responses: 128
56
+
57
+ ============================================================
58
+ Benchmark: humaneval (8 GPUs)
59
+ ============================================================
60
+ Total 164 samples, distributed across 8 GPUs
61
+
62
+
63
+ humaneval Results:
64
+ Decoding speedup: 0.96x
65
+ Average Acceptance length: 1.97
66
+ Acceptance length histogram: ['0.0%', '4.6%', '94.6%', '0.6%', '0.1%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%']
67
+ Num responses: 164
68
+
69
+ ============================================================
70
+ Benchmark: mt-bench (8 GPUs)
71
+ ============================================================
72
+ Total 80 samples, distributed across 8 GPUs
73
+
74
+
75
+ mt-bench Results:
76
+ Decoding speedup: 0.84x
77
+ Average Acceptance length: 1.94
78
+ Acceptance length histogram: ['0.0%', '6.7%', '92.6%', '0.4%', '0.2%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%']
79
+ Num responses: 160
80
+
81
+ Results saved to: /workspace/hanrui/syxin_old/Specforge/benchmarks/results/dflash_lora_inject_offline_epoch_3_step_4644_20260324_121731.json
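The saved JSON can be consumed downstream. A hypothetical reader (the per-benchmark keys match what the eval script writes; `best_bench` is an illustrative helper, and the numbers below are taken from the log above):

```python
# Hypothetical consumer of the results JSON saved above; per-benchmark entries
# use the keys written by the eval script ("decoding_speedup", etc.).
def best_bench(results):
    benches = {k: v for k, v in results.items() if isinstance(v, dict)}
    return max(benches, key=lambda k: benches[k]["decoding_speedup"])

results = {
    "model": "dflash-lora-inject/epoch_3_step_4644",
    "block_size": 16,
    "gsm8k": {"decoding_speedup": 1.01, "avg_accept_length": 1.99},
    "humaneval": {"decoding_speedup": 0.96, "avg_accept_length": 1.97},
    "mt-bench": {"decoding_speedup": 0.84, "avg_accept_length": 1.94},
}
```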
syxin_old/eval_run.log ADDED
The diff for this file is too large to render. See raw diff
 
syxin_old/launch_train.sh ADDED
@@ -0,0 +1,37 @@
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ cd /workspace/hanrui/syxin_old/Specforge
5
+
6
+ export TORCHINDUCTOR_CACHE_DIR=/workspace/hanrui/cache/compiled_kernels
7
+ export SPECFORGE_DATA_NUM_PROC=16
8
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
9
+ export PYTORCH_ALLOC_CONF=expandable_segments:True
10
+ export HF_DATASETS_CACHE=/workspace/hanrui/cache/hf_datasets
11
+ export HF_HOME=/workspace/hanrui/cache/hf_home
12
+
13
+ torchrun --nproc_per_node=8 \
14
+ scripts/train_dflash_lora_inject.py \
15
+ --target-model-path /workspace/models/Qwen3-8B \
16
+ --target-model-backend hf \
17
+ --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
18
+ --output-dir outputs/qwen3-8b-sft-32gpu-v2 \
19
+ --block-size 16 \
20
+ --attention-backend additive \
21
+ --attn-implementation sdpa \
22
+ --max-length 2048 \
23
+ --batch-size 4 \
24
+ --accumulation-steps 8 \
25
+ --num-epochs 3 \
26
+ --learning-rate 5e-5 \
27
+ --loss-decay-gamma 7 \
28
+ --gradient-checkpointing \
29
+ --chat-template qwen \
30
+ --log-interval 50 \
31
+ --save-interval 500 \
32
+ --cache-dir /workspace/hanrui/cache \
33
+ --lora-rank 32 \
34
+ --lora-alpha 64 \
35
+ --lora-dropout 0.1 \
36
+ --trust-remote-code \
37
+ --dataloader-num-workers 0
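The batching flags above imply an effective global batch size. A quick sanity check of the arithmetic (assuming one optimizer step per accumulation window and 8 data-parallel ranks from `--nproc_per_node=8`):

```python
# Effective global batch implied by the launch flags above:
# per-GPU batch 4, 8 accumulation steps, 8 GPUs.
per_gpu_batch = 4
accumulation_steps = 8
num_gpus = 8
effective_batch = per_gpu_batch * accumulation_steps * num_gpus
```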
syxin_old/launch_train_dflash_wrapper.py ADDED
@@ -0,0 +1,17 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Python wrapper to launch dflash training script via northjob/torchrun
4
+ """
5
+ import subprocess
6
+ import sys
7
+ import os
8
+
9
+ if __name__ == "__main__":
10
+ bash_script = "/workspace/hanrui/syxin_old/run_train_multinode_dflash.sh"
11
+ args = sys.argv[1:]
12
+
13
+ cmd = ["bash", bash_script] + args
14
+
15
+ result = subprocess.run(cmd, env=os.environ.copy())
16
+
17
+ sys.exit(result.returncode)
syxin_old/launch_train_random_anchor.py ADDED
@@ -0,0 +1,15 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Python wrapper to launch random anchor training script via northjob/torchrun
4
+ """
5
+ import subprocess
6
+ import sys
7
+ import os
8
+
9
+ if __name__ == "__main__":
10
+ bash_script = "/workspace/hanrui/syxin_old/run_train_multinode_random_anchor.sh"
11
+ args = sys.argv[1:]
12
+
13
+ cmd = ["bash", bash_script] + args
14
+ result = subprocess.run(cmd, env=os.environ.copy())
15
+ sys.exit(result.returncode)
syxin_old/launch_train_wrapper.py ADDED
@@ -0,0 +1,21 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Python wrapper to launch bash training script via torchrun
4
+ """
5
+ import subprocess
6
+ import sys
7
+ import os
8
+
9
+ if __name__ == "__main__":
10
+ # Get the bash script path and arguments
11
+ bash_script = "/workspace/hanrui/syxin_old/run_train_multinode.sh"
12
+ args = sys.argv[1:] # Pass through all arguments
13
+
14
+ # Build the command
15
+ cmd = ["bash", bash_script] + args
16
+
17
+ # Execute the bash script
18
+ result = subprocess.run(cmd, env=os.environ.copy())
19
+
20
+ # Exit with the same code as the bash script
21
+ sys.exit(result.returncode)
syxin_old/list.md ADDED
@@ -0,0 +1,12 @@
+ ### 1. `train_dflash_lora.py`
+ * Adds LoRA: instead of calling a separate small draft model, prediction now uses the target's hidden states plus LoRA adapters.
+ * The `dflash_lora_mask_fn` function lets every token in the predicted draft block attend to all other tokens in that block.
+
+ ### 2. OOM fixes
+ * Sharding strategy ZeRO-3: FSDP upgraded from `SHARD_GRAD_OP` to `FULL_SHARD`.
+ * `batch-size=1`, `accumulation-steps=8`.
+ * Adopted FlexAttention (`dflash_lora_mask_fn`), following the earlier code.
+ * `_chunked_lm_loss()`: the loss computation is split into 256-element chunks, with gradient checkpointing.
+
+ ### Run
+ * bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2
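The `_chunked_lm_loss()` idea mentioned above can be sketched in miniature. This is an illustrative numpy version (not the repo's implementation; gradient checkpointing and the lm_head module are replaced by a plain matmul): the logits and cross-entropy are computed chunk by chunk, so the full `[seq, vocab]` logits tensor never exists at once.

```python
import numpy as np

def chunked_ce(hidden, W, labels, chunk_size):
    # Compute logits = hidden @ W and the CE loss one chunk at a time,
    # so only [chunk_size, vocab] logits are live at any moment.
    total = 0.0
    for s in range(0, len(hidden), chunk_size):
        logits = hidden[s:s + chunk_size] @ W            # [chunk, vocab]
        logits = logits - logits.max(axis=1, keepdims=True)
        logp = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
        total += -logp[np.arange(len(logp)), labels[s:s + chunk_size]].sum()
    return total / len(hidden)

rng = np.random.default_rng(0)
hidden = rng.standard_normal((8, 4))
W = rng.standard_normal((4, 10))
labels = rng.integers(0, 10, size=8)
loss_chunked = chunked_ce(hidden, W, labels, chunk_size=3)
loss_full = chunked_ce(hidden, W, labels, chunk_size=8)
```

Chunking changes the memory profile, not the math: the chunked loss matches the full-materialization loss up to floating-point summation order.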
syxin_old/merge_lora.py ADDED
@@ -0,0 +1,66 @@
1
+ """
2
+ Step 1: Merge DFlash-LoRA adapter into base model.
3
+ Usage:
4
+ conda activate sglang
5
+ python3 merge_lora.py
6
+ python3 merge_lora.py --ckpt epoch_2_step_15000 # test another checkpoint
7
+ """
8
+ import argparse
9
+ import os
10
+
11
+ import torch
12
+ from peft import PeftModel
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+
15
+ BASE_MODEL = "/workspace/models/Qwen3-8B"
16
+ OUTPUT_ROOT = "/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora"
17
+ MERGE_ROOT = "/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-merged"
18
+
19
+ def parse_args():
20
+ p = argparse.ArgumentParser()
21
+ p.add_argument("--ckpt", default="epoch_3_step_18576",
22
+ help="Checkpoint folder name under OUTPUT_ROOT")
23
+ p.add_argument("--merged-path", default=MERGE_ROOT,
24
+ help="Where to save the merged model")
25
+ return p.parse_args()
26
+
27
+
28
+ def main():
29
+ args = parse_args()
30
+ adapter_path = os.path.join(OUTPUT_ROOT, args.ckpt)
31
+ merged_path = args.merged_path
32
+
33
+ if os.path.exists(merged_path):
34
+ print(f"[skip] Merged model already exists: {merged_path}")
35
+ return
36
+
37
+ assert os.path.isdir(adapter_path), f"Adapter not found: {adapter_path}"
38
+
39
+ print(f"Base model : {BASE_MODEL}")
40
+ print(f"Adapter : {adapter_path}")
41
+ print(f"Output : {merged_path}")
42
+ print()
43
+
44
+ print("[1/4] Loading base model to CPU ...")
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ BASE_MODEL,
47
+ dtype=torch.bfloat16,  # torch_dtype is deprecated in recent transformers
48
+ device_map="cpu",
49
+ )
50
+
51
+ print("[2/4] Loading LoRA adapter ...")
52
+ model = PeftModel.from_pretrained(model, adapter_path)
53
+
54
+ print("[3/4] Merging weights ...")
55
+ model = model.merge_and_unload()
56
+
57
+ print("[4/4] Saving merged model ...")
58
+ os.makedirs(merged_path, exist_ok=True)
59
+ model.save_pretrained(merged_path, safe_serialization=True)
60
+ AutoTokenizer.from_pretrained(BASE_MODEL).save_pretrained(merged_path)
61
+
62
+ print(f"\nDone. Merged model saved to: {merged_path}")
63
+
64
+
65
+ if __name__ == "__main__":
66
+ main()
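The `merge_and_unload()` call above folds the low-rank update back into the base weight, `W_merged = W + (alpha / r) * B @ A`, after which the adapter can be dropped. A toy check of that identity (made-up shapes, not Qwen3-8B):

```python
import numpy as np

rank, alpha = 2, 4
rng = np.random.default_rng(0)
W = rng.standard_normal((3, 3))        # base weight
A = rng.standard_normal((rank, 3))     # LoRA down-projection
B = rng.standard_normal((3, rank))     # LoRA up-projection
W_merged = W + (alpha / rank) * (B @ A)

x = rng.standard_normal(3)
# Forward through merged weights equals the base + adapter path.
base_plus_adapter = W @ x + (alpha / rank) * (B @ (A @ x))
```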
syxin_old/oom_fix_progress.md ADDED
@@ -0,0 +1,42 @@
+ # DFlash LoRA OOM Fix Notes
+
+ ## Root-cause analysis
+
+ 1. **SHARD_GRAD_OP (ZeRO-2)** — every GPU holds the full Qwen3-8B parameters (~16 GB bf16); parameters are not sharded
+ 2. **SDPA + 4D additive mask** — FlashAttention does not support 4D additive masks, so attention falls back to the math backend, which materializes the full attention scores per layer (`bsz × 32 heads × 2048 × 2048`)
+ 3. **Large-vocab logits** — `[bsz, 2048, 151936]` in bf16 ≈ 1.18 GB; with gradients and boolean-indexing copies, peak usage is ~3-4 GB
+ 4. **The machine only has 2 H100s**, while the script defaults to `NUM_GPUS=4`
+
+ ## Completed changes
+
+ ### 1. FSDP sharding switched to FULL_SHARD (ZeRO-3)
+ - File: `SpecForge/scripts/train_dflash_lora.py:347`
+ - `ShardingStrategy.SHARD_GRAD_OP` → `ShardingStrategy.FULL_SHARD`
+ - Effect: parameters are sharded across GPUs, saving ~8-12 GB per GPU
+
+ ### 2. Lower batch-size, raise accumulation-steps
+ - File: `SpecForge/scripts/run_train_dflash_lora.sh`
+ - `--batch-size 2` → `1`, `--accumulation-steps 4` → `8`
+ - Effect: effective global batch size unchanged, peak memory halved
+
+ ## To verify / follow-up optimizations
+
+ - [ ] Run with `bash run_train_dflash_lora.sh 2` to make sure only 2 GPUs are used
+ - [x] If still OOM, use a chunked cross-entropy loss to avoid materializing the full large-vocab logits
+ - [x] Longer term, explore a custom attention kernel with block-sparse mask support to bypass the SDPA math fallback
+
+ ### 3. flex_attention + BlockMask replacing the 4D additive mask
+ - Files: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
+ - Ported the `_get_or_create_block_mask()` method from the non-LoRA `dflash.py` and adapted it to the LoRA setup (Q_LEN == KV_LEN == seq_len)
+ - LoRA-variant mask: context causal + block bidirectional (the non-LoRA variant uses [context, noise] concatenated KV)
+ - Enabled with `--attention-backend flex_attention` (the default); fall back to the original 4D mask with `--attention-backend additive`
+ - The HuggingFace model is loaded with `attn_implementation="flex_attention"`
+ - Effect: no more fallback to the SDPA math backend, saving the `[bsz, heads, seq, seq]` attention-score memory
+
+ ### 4. Chunked cross-entropy loss
+ - Files: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
+ - Ported the `_chunked_lm_loss()` method from the non-LoRA `dflash.py`
+ - Runs lm_head + CE loss chunk by chunk with gradient checkpointing, avoiding materializing the full `[bsz, seq, vocab]` logits
+ - Enabled with `--lm-head-chunk-size 256` (default 0 = disabled)
+ - `DFlashLoRADraftModel.forward()` gains an `output_hidden_states` parameter; chunked mode returns hidden states
+ - Effect: peak logits memory drops from O(seq_len × vocab_size) to O(chunk_size × vocab_size)
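The mask shape described in change 3 ("context causal + block bidirectional") can be written as a per-position predicate. A minimal illustrative version (the repo builds this as a flex_attention BlockMask; the function name and signature here are assumptions):

```python
def lora_block_mask(q_idx, kv_idx, context_len, block_size):
    # Context tokens attend causally; draft tokens additionally see every
    # token inside their own block (bidirectional within the block).
    causal = kv_idx <= q_idx
    in_draft = q_idx >= context_len and kv_idx >= context_len
    same_block = in_draft and (
        (q_idx - context_len) // block_size == (kv_idx - context_len) // block_size
    )
    return causal or same_block
```

For example, with `context_len=4` and `block_size=2`, draft position 4 can see position 5 (same block, bidirectional) but not position 6 (next block).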
syxin_old/random_anchor_plan.md ADDED
@@ -0,0 +1,82 @@
+ # Plan: Add Random Anchor to DFlash LoRA-Inject (syxin_old)
+
+ ## Context
+ LoRA-Inject reaches τ≈3.9 vs τ≈6.5 for the original DFlash. The code has no bug; the gap comes from training strategy: the original DFlash uses `--random-anchor --num-anchors 512` to randomly sample block start positions at every step, while LoRA-Inject only uses fixed block boundaries.
+
+ ## Files to change (all under syxin_old)
+
+ 1. `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora_inject.py` — training wrapper
+ 2. `/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/dflash_lora_inject.py` — draft model
+ 3. `/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash_lora_inject.py` — training-script arguments
+ 4. `/workspace/hanrui/syxin_old/run_train_dflash_lora_inject.sh` — launch script
+
+ ## Reference code (also under syxin_old)
+ - `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora.py` lines 73-146: `_sample_anchor_positions`, `_build_blocks_from_anchors`
+ - `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora.py` lines 219-235: `_build_additive_mask_random_anchor`
+ - `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora.py` lines 384-408: `_compute_loss_weights_random_anchor`
+ - `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora.py` lines 529-578: random-anchor forward path
+
+ ## Step 1: training wrapper `core/dflash_lora_inject.py`
+
+ Add 4 methods and modify forward:
+
+ **1a. `_sample_anchor_positions()`** — copy verbatim from dflash_lora.py, identical
+
+ **1b. `_build_blocks_from_anchors()`** — copy from dflash_lora.py, but **additionally** gather target_layer_hidden_states (a List[Tensor], one per layer):
+ ```python
+ block_target_hidden = []
+ for layer_hs in target_layer_hidden_states:
+     gathered = torch.gather(layer_hs, 1,
+         gather_idx.unsqueeze(-1).expand(-1, -1, layer_hs.size(-1)))
+     block_target_hidden.append(gathered)
+ ```
+ Returns one extra value, `block_target_hidden_states`
+
+ **1c. `_build_additive_mask_random_anchor()`** — copy verbatim from dflash_lora.py (lines 219-235), identical. This is the draft-to-draft mask (bidirectional within a block); it gets expanded into the extended mask inside the draft model's `_forward_with_injection`.
+
+ **1d. `_compute_loss_weights_random_anchor()`** — copy verbatim from dflash_lora.py (lines 384-408), identical
+
+ **1e. Modify `forward()`** — insert the random-anchor branch after the target model forward:
+ ```python
+ if self.random_anchor and self.training:
+     # 1. sample anchors
+     # 2. build blocks (input_ids, loss_mask, target hidden per layer)
+     # 3. prepare_noise_input(block_ids=block_ids)
+     # 4. build draft-draft mask
+     # 5. position_ids = gather_idx (original sequence positions!)
+     # 6. draft_model.forward(... block_ids=block_ids)  # new parameter
+     # 7. loss + accuracy
+     return loss, accuracy
+ ```
+
+ ## Step 2: draft model `modeling/draft/dflash_lora_inject.py`
+
+ **Modify `forward()` and `_forward_with_injection()`** to accept a `block_ids` parameter.
+
+ Problems with the current `_forward_with_injection` under random-anchor mode:
+ 1. **Position IDs** (lines 237-238): uses `[0..seq_len-1, 0..seq_len-1]`, but random anchor needs `[gather_idx, gather_idx]` (original sequence positions)
+ 2. **Extended mask** (lines 269-274): computes leakage prevention from a fixed `context_len` and `block_size`, but with random anchors the block boundaries are determined by `block_ids`
+
+ Changes:
+ - `forward()` gains a `block_ids=None` parameter, passed through to `_forward_with_injection`
+ - `_forward_with_injection` gains a `block_ids=None` parameter
+ - When `block_ids is not None`:
+   - positions use the caller-provided `position_ids`, building `extended_pos = cat([position_ids, position_ids])` (the gather_idx positions)
+   - draft-to-target part of the extended mask: a draft token in block k may only see target tokens of blocks < k (decided via block_ids: `block_ids[target_pos] < block_ids[draft_pos]`, per sample)
+   - draft-to-draft part: use the incoming `attention_mask` (the same-block bidirectional mask already built by the wrapper)
+
+ ## Step 3: training scripts
+
+ **3a. `scripts/train_dflash_lora_inject.py`**
+ - Add argparse flags `--random-anchor` (store_true) and `--num-anchors` (default=512)
+ - Line 204: `random_anchor=False` → `random_anchor=args.random_anchor`
+ - Line 205: `num_anchors=512` → `num_anchors=args.num_anchors`
+
+ **3b. `run_train_dflash_lora_inject.sh`**
+ - Add `--random-anchor --num-anchors 512`
+ - Suggest raising lr to 6e-4 and epochs to 6 (matching the original DFlash)
+
+ ## Verification
+ 1. Run a few training steps and confirm the loss decreases normally
+ 2. Compare against the dflash_lora.py random-anchor path to confirm the mask/loss/position logic matches
+ 3. After a full training run, re-run eval
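The anchor-sampling step in 1a can be sketched as follows. This is an illustrative helper, not the repo's `_sample_anchor_positions` (which may differ, e.g. in deduplication or how the loss mask interacts with anchors): block start positions are drawn uniformly over valid offsets instead of using fixed boundaries.

```python
import random

def sample_anchor_positions(seq_len, num_anchors, block_size, rng=None):
    # Each anchor is a valid block start: the block must fit inside the
    # sequence, so starts range over [0, seq_len - block_size].
    rng = rng or random.Random()
    max_start = seq_len - block_size
    return sorted(rng.randrange(max_start + 1) for _ in range(num_anchors))

anchors = sample_anchor_positions(2048, 64, 16, rng=random.Random(0))
```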
syxin_old/requirements.txt ADDED
File without changes
syxin_old/run_bench_dflash.sh ADDED
@@ -0,0 +1,71 @@
1
+ #!/bin/bash
2
+ # Evaluate DFlash-LoRA-Inject accepted length (offline, 8 GPUs parallel).
3
+ # No sglang server needed. Each GPU loads its own target+draft and processes a shard.
4
+ #
5
+ # Usage:
6
+ # bash run_bench_dflash.sh # 8 GPUs, all 3 benches
7
+ # bash run_bench_dflash.sh humaneval # only humaneval
8
+ # bash run_bench_dflash.sh mtbench gsm8k # pick any subset
9
+ # bash run_bench_dflash.sh --quick # quick test (20 samples)
10
+ # bash run_bench_dflash.sh --ckpt epoch_0_step_500 # specific checkpoint
11
+ # NUM_GPUS=4 bash run_bench_dflash.sh # use 4 GPUs
12
+
13
+ set -e
14
+
15
+ SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
16
+ PYTHON=/workspace/miniconda3/envs/spec/bin/python3
17
+ RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
18
+ NUM_GPUS=${NUM_GPUS:-8}
19
+
20
+ # ---- parse args ----
21
+ BENCHMARKS=()
22
+ EXTRA_ARGS=()
23
+ QUICK=false
24
+
25
+ for arg in "$@"; do
26
+ case $arg in
27
+ humaneval|mtbench|gsm8k)
28
+ BENCHMARKS+=("$arg")
29
+ ;;
30
+ --quick)
31
+ QUICK=true
32
+ ;;
33
+ *)
34
+ EXTRA_ARGS+=("$arg")
35
+ ;;
36
+ esac
37
+ done
38
+
39
+ if [ ${#BENCHMARKS[@]} -eq 0 ]; then
40
+ BENCHMARKS=(humaneval mtbench gsm8k)
41
+ fi
42
+
43
+ if [ "$QUICK" = true ]; then
44
+ EXTRA_ARGS+=(--num-samples 20)
45
+ fi
46
+
47
+ TIMESTAMP=$(date +%Y%m%d_%H%M%S)
48
+
49
+ echo "============================================"
50
+ echo " DFlash-LoRA-Inject Offline Eval"
51
+ echo " GPUs : $NUM_GPUS"
52
+ echo " benchmarks : ${BENCHMARKS[*]}"
53
+ echo " extra args : ${EXTRA_ARGS[*]}"
54
+ echo " results : $RESULT_DIR"
55
+ echo "============================================"
56
+ echo ""
57
+
58
+ mkdir -p $RESULT_DIR
59
+
60
+ $PYTHON -m torch.distributed.run \
61
+ --standalone \
62
+ --nproc_per_node $NUM_GPUS \
63
+ $SCRIPT_DIR/eval_dflash_lora_inject.py \
64
+ --benchmarks ${BENCHMARKS[@]} \
65
+ --output-dir $RESULT_DIR \
66
+ "${EXTRA_ARGS[@]}" \
67
+ 2>&1 | tee $RESULT_DIR/bench_dflash_lora_inject_offline_${TIMESTAMP}.log
68
+
69
+ echo ""
70
+ echo "Done. Latest result files:"
71
+ ls -lht $RESULT_DIR/*.json 2>/dev/null | head -5
syxin_old/run_bench_dflash_b16_baseline.sh ADDED
@@ -0,0 +1,60 @@
1
+ #!/bin/bash
2
+ # DFlash-b16 baseline: measure accepted length offline, 8 GPUs parallel.
3
+ # Usage:
4
+ # bash run_bench_dflash_b16_baseline.sh # 8 GPUs, all 3 benches
5
+ # bash run_bench_dflash_b16_baseline.sh humaneval # only humaneval
6
+ # bash run_bench_dflash_b16_baseline.sh --quick # 20 samples per bench
7
+ # NUM_GPUS=4 bash run_bench_dflash_b16_baseline.sh # 4 GPUs
8
+
9
+ set -e
10
+
11
+ SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
12
+ PYTHON=/workspace/miniconda3/envs/spec/bin/python3
13
+ RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
14
+ NUM_GPUS=${NUM_GPUS:-8}
15
+
16
+ BENCHMARKS=()
17
+ EXTRA_ARGS=()
18
+ QUICK=false
19
+
20
+ for arg in "$@"; do
21
+ case $arg in
22
+ humaneval|mtbench|gsm8k) BENCHMARKS+=("$arg") ;;
23
+ --quick) QUICK=true ;;
24
+ *) EXTRA_ARGS+=("$arg") ;;
25
+ esac
26
+ done
27
+
28
+ if [ ${#BENCHMARKS[@]} -eq 0 ]; then
29
+ BENCHMARKS=(humaneval mtbench gsm8k)
30
+ fi
31
+
32
+ if [ "$QUICK" = true ]; then
33
+ EXTRA_ARGS+=(--num-samples 20)
34
+ fi
35
+
36
+ TIMESTAMP=$(date +%Y%m%d_%H%M%S)
37
+
38
+ echo "============================================"
39
+ echo " DFlash-b16 Baseline Offline Eval"
40
+ echo " GPUs : $NUM_GPUS"
41
+ echo " draft : /workspace/models/Qwen3-8B-DFlash-b16"
42
+ echo " benchmarks : ${BENCHMARKS[*]}"
43
+ echo " extra args : ${EXTRA_ARGS[*]}"
44
+ echo "============================================"
45
+ echo ""
46
+
47
+ mkdir -p $RESULT_DIR
48
+
49
+ $PYTHON -m torch.distributed.run \
50
+ --standalone \
51
+ --nproc_per_node $NUM_GPUS \
52
+ $SCRIPT_DIR/eval_dflash_b16_baseline.py \
53
+ --benchmarks ${BENCHMARKS[@]} \
54
+ --output-dir $RESULT_DIR \
55
+ "${EXTRA_ARGS[@]}" \
56
+ 2>&1 | tee $RESULT_DIR/bench_dflash_b16_baseline_${TIMESTAMP}.log
57
+
58
+ echo ""
59
+ echo "Done. Latest result files:"
60
+ ls -lht $RESULT_DIR/*.json 2>/dev/null | head -5
syxin_old/run_bench_dflash_lora_inject.sh ADDED
@@ -0,0 +1,60 @@
1
+ #!/bin/bash
2
+ # DFlash-LoRA-Inject: measure accepted length offline, 8 GPUs parallel.
3
+ # Usage:
4
+ # bash run_bench_dflash_lora_inject.sh # 8 GPUs, all 3 benches
5
+ # bash run_bench_dflash_lora_inject.sh humaneval # only humaneval
6
+ # bash run_bench_dflash_lora_inject.sh --quick # 20 samples per bench
7
+ # NUM_GPUS=4 bash run_bench_dflash_lora_inject.sh # 4 GPUs
8
+
9
+ set -e
10
+
11
+ SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
12
+ PYTHON=/workspace/miniconda3/envs/dflash/bin/python3
13
+ RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
14
+ NUM_GPUS=${NUM_GPUS:-8}
15
+
16
+ BENCHMARKS=()
17
+ EXTRA_ARGS=()
18
+ QUICK=false
19
+
20
+ for arg in "$@"; do
21
+ case $arg in
22
+ humaneval|mtbench|gsm8k) BENCHMARKS+=("$arg") ;;
23
+ --quick) QUICK=true ;;
24
+ *) EXTRA_ARGS+=("$arg") ;;
25
+ esac
26
+ done
27
+
28
+ if [ ${#BENCHMARKS[@]} -eq 0 ]; then
29
+ BENCHMARKS=(humaneval mtbench gsm8k)
30
+ fi
31
+
32
+ if [ "$QUICK" = true ]; then
33
+ EXTRA_ARGS+=(--num-samples 20)
34
+ fi
35
+
36
+ TIMESTAMP=$(date +%Y%m%d_%H%M%S)
37
+
38
+ echo "============================================"
39
+ echo " DFlash-LoRA-Inject Offline Eval"
40
+ echo " GPUs : $NUM_GPUS"
41
+ echo " draft : LoRA-inject (merged)"
42
+ echo " benchmarks : ${BENCHMARKS[*]}"
43
+ echo " extra args : ${EXTRA_ARGS[*]}"
44
+ echo "============================================"
45
+ echo ""
46
+
47
+ mkdir -p $RESULT_DIR
48
+
49
+ $PYTHON -m torch.distributed.run \
50
+ --standalone \
51
+ --nproc_per_node $NUM_GPUS \
52
+ $SCRIPT_DIR/eval_dflash_lora_inject.py \
53
+ --benchmarks ${BENCHMARKS[@]} \
54
+ --output-dir $RESULT_DIR \
55
+ "${EXTRA_ARGS[@]}" \
56
+ 2>&1 | tee $RESULT_DIR/bench_dflash_lora_inject_${TIMESTAMP}.log
57
+
58
+ echo ""
59
+ echo "Done. Latest result files:"
60
+ ls -lht $RESULT_DIR/*.json 2>/dev/null | head -5
syxin_old/run_qwen3_8b_sft_64gpu.sh ADDED
@@ -0,0 +1,31 @@
1
+ #!/bin/bash
2
+ export JOB_NAME='qwen3-32b-sft'
3
+ export GPU_NUMS=32
4
+ export TRAIN_SCRIPT='/workspace/hanrui/syxin_old/launch_train_dflash_wrapper.py'
5
+ export WORK_DIR='/workspace/hanrui/syxin_old/Specforge'
6
+
7
+ if [ $GPU_NUMS -lt 8 ]; then
8
+ export NNODES=1
9
+ export GPU_NUMS_PER_NODE=$GPU_NUMS
10
+ else
11
+ export NNODES=$((GPU_NUMS/8))
12
+ export GPU_NUMS_PER_NODE=8
13
+ fi
14
+
15
+ # use northjob from the spec environment
16
+ /workspace/miniconda3/envs/spec/bin/northjob \
17
+ create \
18
+ --job-type train \
19
+ --nproc-per-node $GPU_NUMS_PER_NODE \
20
+ --gpu-per-node $GPU_NUMS_PER_NODE \
21
+ --nnodes $NNODES \
22
+ --k8s-priority 3 \
23
+ --k8s-queue bg-agentic-coding \
24
+ --k8s-namespace bg-agentic-coding \
25
+ --k8s-pvc-name i-xinsiyang-y4zy0sik0a \
26
+ --k8s-pvc-mount-path /workspace \
27
+ --k8s-no-reclaim \
28
+ --k8s-images harbor.local.clusters/bp/megatron-bplm:25.03_fp8.ibgda.qwen3.next.fix_triton.fix_te.hf457.qwen3_vl \
29
+ --job-name $JOB_NAME \
30
+ --workspace $WORK_DIR \
31
+ $TRAIN_SCRIPT $GPU_NUMS_PER_NODE
syxin_old/run_train_dflash_lora_inject.sh ADDED
@@ -0,0 +1,73 @@
1
+ #!/bin/bash
2
+ set -euo pipefail
3
+
4
+ ROOT_DIR=/workspace/hanrui/syxin_old/Specforge
5
+ NUM_GPUS=8
6
+ OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-dflash-lora-inject-random-anchor
7
+ CACHE_DIR=/tmp/specforge_cache
8
+
9
+ # Parse arguments
10
+ if [[ $# -ge 1 ]]; then
11
+ NUM_GPUS=$1
12
+ shift
13
+ fi
14
+ if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
15
+ OUTPUT_DIR=$1
16
+ shift
17
+ fi
18
+ EXTRA_ARGS=("$@")
19
+
20
+ # Environment variables
21
+ export TORCHINDUCTOR_CACHE_DIR=/tmp/specforge_cache/compiled_kernels
22
+ export SPECFORGE_DATA_NUM_PROC=16
23
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
24
+ export PYTORCH_ALLOC_CONF=expandable_segments:True
25
+ export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
26
+ export HF_DATASETS_CACHE=/tmp/specforge_cache/hf_datasets
27
+ export HF_HOME=/tmp/specforge_cache/hf_home
28
+
29
+ # Python binary
30
+ DEFAULT_SPECFORGE_PY=/workspace/hanrui/specforge/bin/python3
31
+ if [[ -z "${PYTHON_BIN:-}" ]]; then
32
+ if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
33
+ PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
34
+ else
35
+ PYTHON_BIN=python3
36
+ fi
37
+ fi
38
+
39
+ cd $ROOT_DIR
40
+
41
+ $PYTHON_BIN -m torch.distributed.run \
42
+ --standalone \
43
+ --nproc_per_node $NUM_GPUS \
44
+ scripts/train_dflash_lora_inject.py \
45
+ --target-model-path /workspace/models/Qwen3-8B \
46
+ --target-model-backend hf \
47
+ --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
48
+ --output-dir $OUTPUT_DIR \
49
+ --block-size 16 \
50
+ --attention-backend additive \
51
+ --attn-implementation sdpa \
52
+ --random-anchor \
53
+ --num-anchors 64 \
54
+ --max-length 2048 \
55
+ --batch-size 1 \
56
+ --accumulation-steps 64 \
57
+ --num-epochs 6 \
58
+ --learning-rate 6e-4 \
59
+ --loss-decay-gamma 7 \
60
+ --gradient-checkpointing \
61
+ --chat-template qwen \
62
+ --log-interval 50 \
63
+ --save-interval 500 \
64
+ --cache-dir $CACHE_DIR \
65
+ --lora-rank 32 \
66
+ --lora-alpha 64 \
67
+ --lora-dropout 0.1 \
68
+ --trust-remote-code \
69
+ --dataloader-num-workers 0 \
70
+ --early-stop \
71
+ --early-stop-patience 5 \
72
+ --early-stop-min-delta 0.005 \
73
+ "${EXTRA_ARGS[@]}"
syxin_old/run_train_multinode.sh ADDED
@@ -0,0 +1,67 @@
+ #!/bin/bash
+ set -euo pipefail
+
+ ROOT_DIR=/workspace/hanrui/syxin_old/Specforge
+ NUM_GPUS=8
+ OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-dflash-lora-inject
+ CACHE_DIR=/tmp/specforge_cache
+
+ # Parse arguments
+ if [[ $# -ge 1 ]]; then
+ NUM_GPUS=$1
+ shift
+ fi
+ if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
+ OUTPUT_DIR=$1
+ shift
+ fi
+ EXTRA_ARGS=("$@")
+
+ # Environment variables
+ export TORCHINDUCTOR_CACHE_DIR=/tmp/specforge_cache/compiled_kernels
+ export SPECFORGE_DATA_NUM_PROC=16
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ export PYTORCH_ALLOC_CONF=expandable_segments:True
+ export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
+ export HF_DATASETS_CACHE=/tmp/specforge_cache/hf_datasets
+ export HF_HOME=/tmp/specforge_cache/hf_home
+
+ # Python binary
+ DEFAULT_SPECFORGE_PY=/workspace/miniconda3/envs/spec/bin/python3
+ if [[ -z "${PYTHON_BIN:-}" ]]; then
+ if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
+ PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
+ else
+ PYTHON_BIN=python3
+ fi
+ fi
+
+ cd $ROOT_DIR
+
+ # northjob has already set the distributed environment variables via torchrun,
+ # so run the training script directly; do not launch torch.distributed.run again.
+ $PYTHON_BIN scripts/train_dflash_lora_inject.py \
+ --target-model-path /workspace/models/Qwen3-8B \
+ --target-model-backend hf \
+ --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
+ --output-dir $OUTPUT_DIR \
+ --block-size 16 \
+ --attention-backend additive \
+ --attn-implementation sdpa \
+ --max-length 2048 \
+ --batch-size 8 \
+ --accumulation-steps 8 \
+ --num-epochs 3 \
+ --learning-rate 5e-5 \
+ --loss-decay-gamma 7 \
+ --gradient-checkpointing \
+ --chat-template qwen \
+ --log-interval 50 \
+ --save-interval 500 \
+ --cache-dir $CACHE_DIR \
+ --lora-rank 32 \
+ --lora-alpha 64 \
+ --lora-dropout 0.1 \
+ --trust-remote-code \
+ --dataloader-num-workers 0 \
+ "${EXTRA_ARGS[@]}"
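The comment above relies on the launcher (northjob, via torchrun) having already exported the distributed environment. A hedged preflight sketch that checks for the variables torchrun normally sets; the helper name is an assumption and the check is not part of the script:

```shell
# Verify the distributed variables that torchrun normally exports are
# present before invoking the training script directly (i.e. without
# torch.distributed.run). Uses bash indirect expansion (${!v}).
check_dist_env() {
  local v
  for v in RANK LOCAL_RANK WORLD_SIZE MASTER_ADDR MASTER_PORT; do
    if [[ -z "${!v:-}" ]]; then
      echo "missing: $v" >&2
      return 1
    fi
  done
  echo "distributed env OK"
}
```

Calling `check_dist_env || exit 1` right before the `$PYTHON_BIN` line would fail fast with a clear message instead of an opaque NCCL/init error when the script is run outside the launcher.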
syxin_old/run_train_multinode_random_anchor.sh ADDED
@@ -0,0 +1,72 @@
+ #!/bin/bash
+ set -euo pipefail
+
+ ROOT_DIR=/workspace/hanrui/syxin_old/Specforge
+ NUM_GPUS=8
+ OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-dflash-lora-inject-random-anchor
+ CACHE_DIR=/tmp/specforge_cache
+
+ # Parse arguments
+ if [[ $# -ge 1 ]]; then
+ NUM_GPUS=$1
+ shift
+ fi
+ if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
+ OUTPUT_DIR=$1
+ shift
+ fi
+ EXTRA_ARGS=("$@")
+
+ # Environment variables
+ export TORCHINDUCTOR_CACHE_DIR=/tmp/specforge_cache/compiled_kernels
+ export SPECFORGE_DATA_NUM_PROC=16
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ export PYTORCH_ALLOC_CONF=expandable_segments:True
+ export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
+ export HF_DATASETS_CACHE=/tmp/specforge_cache/hf_datasets
+ export HF_HOME=/tmp/specforge_cache/hf_home
+
+ # Python binary
+ DEFAULT_SPECFORGE_PY=/workspace/miniconda3/envs/spec/bin/python3
+ if [[ -z "${PYTHON_BIN:-}" ]]; then
+ if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
+ PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
+ else
+ PYTHON_BIN=python3
+ fi
+ fi
+
+ cd $ROOT_DIR
+
+ # northjob has already set the distributed environment variables via torchrun,
+ # so run the training script directly; do not launch torch.distributed.run again.
+ $PYTHON_BIN scripts/train_dflash_lora_inject.py \
+ --target-model-path /workspace/models/Qwen3-8B \
+ --target-model-backend hf \
+ --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
+ --output-dir $OUTPUT_DIR \
+ --block-size 16 \
+ --attention-backend additive \
+ --attn-implementation sdpa \
+ --random-anchor \
+ --num-anchors 64 \
+ --max-length 2048 \
+ --batch-size 1 \
+ --accumulation-steps 64 \
+ --num-epochs 6 \
+ --learning-rate 6e-4 \
+ --loss-decay-gamma 7 \
+ --gradient-checkpointing \
+ --chat-template qwen \
+ --log-interval 50 \
+ --save-interval 500 \
+ --cache-dir $CACHE_DIR \
+ --lora-rank 32 \
+ --lora-alpha 64 \
+ --lora-dropout 0.1 \
+ --trust-remote-code \
+ --dataloader-num-workers 0 \
+ --early-stop \
+ --early-stop-patience 5 \
+ --early-stop-min-delta 0.005 \
+ "${EXTRA_ARGS[@]}"
syxin_old/start_server.sh ADDED
@@ -0,0 +1,42 @@
+ #!/bin/bash
+ # Step 2: Launch SGLang server with STANDALONE speculative decoding.
+ # Usage:
+ # bash start_server.sh
+ # bash start_server.sh 8 # use tp=8
+
+ set -e
+
+ TP=${1:-2}
+
+ BASE_MODEL=/workspace/models/Qwen3-8B
+ MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-merged
+ INTRANET_IP=10.1.1.131
+ PORT=30000
+
+ if [ ! -d "$MERGED" ]; then
+ echo "[ERROR] Merged model not found: $MERGED"
+ echo " Run: conda activate sglang && python3 merge_lora.py"
+ exit 1
+ fi
+
+ echo "============================================"
+ echo " SGLang STANDALONE Speculative Decoding"
+ echo " target : $BASE_MODEL"
+ echo " draft : $MERGED"
+ echo " host : $INTRANET_IP:$PORT"
+ echo " tp : $TP"
+ echo "============================================"
+
+ /workspace/miniconda3/envs/sglang/bin/python3 -m sglang.launch_server \
+ --model-path $BASE_MODEL \
+ --speculative-algorithm STANDALONE \
+ --speculative-draft-model-path $MERGED \
+ --speculative-num-steps 4 \
+ --speculative-eagle-topk 1 \
+ --speculative-num-draft-tokens 4 \
+ --tp-size $TP \
+ --mem-fraction-static 0.30 \
+ --trust-remote-code \
+ --host $INTRANET_IP \
+ --port $PORT \
+ --dtype bfloat16
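The server comes up asynchronously, so clients that connect immediately after the script returns may hit connection refused. A minimal sketch of a readiness wait using bash's built-in `/dev/tcp` redirection; the helper name and timeout default are assumptions, not part of the script:

```shell
# Poll until host:port accepts a TCP connection, or give up after
# $timeout seconds. Relies on bash's /dev/tcp pseudo-device; the
# subshell keeps the probe fd from leaking into the caller.
wait_for_port() {
  local host=$1 port=$2 timeout=${3:-60} i
  for ((i = 0; i < timeout; i++)); do
    if (exec 3<>"/dev/tcp/$host/$port") 2>/dev/null; then
      return 0
    fi
    sleep 1
  done
  return 1
}
```

Usage against the values in the script: `wait_for_port "$INTRANET_IP" "$PORT" 120 && echo "server ready"`.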
syxin_old/start_server_dflash.sh ADDED
@@ -0,0 +1,54 @@
+ #!/bin/bash
+ # Evaluate DFlash-LoRA-Inject: measure accepted length OFFLINE.
+ # 8 GPUs parallel by default, each GPU runs a shard of prompts independently.
+ #
+ # WHY offline?
+ # sglang STANDALONE treats draft as an independent autoregressive model,
+ # completely ignoring the layer-by-layer injection that LoRA-Inject was
+ # trained with. Result: accept_length ≈ 4.7 for ALL models (no signal).
+ #
+ # sglang DFLASH expects the DFlash-b16 architecture (5-layer, fc+hidden_norm),
+ # which is structurally different from LoRA-Inject (full 36-layer + LoRA).
+ #
+ # So we run offline spec-generate with the correct injection pattern.
+ #
+ # Usage:
+ # bash start_server_dflash.sh # 8 GPUs, all benchmarks
+ # bash start_server_dflash.sh 4 # 4 GPUs
+ # bash start_server_dflash.sh 8 humaneval # specific benchmark
+ # bash start_server_dflash.sh 8 --num-samples 20 # quick test
+
+ set -e
+
+ SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
+
+ NUM_GPUS=${1:-8}
+ shift 2>/dev/null || true
+
+ # ---- defaults ----
+ BASE_MODEL=/workspace/models/Qwen3-8B
+ ADAPTER_ROOT=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject
+ CKPT=epoch_3_step_1400
+ MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged
+ RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
+ PYTHON=/workspace/miniconda3/envs/spec/bin/python3
+
+ echo "============================================"
+ echo " DFlash-LoRA-Inject Offline Evaluation"
+ echo " target : $BASE_MODEL"
+ echo " ckpt : $CKPT"
+ echo " merged : $MERGED"
+ echo " GPUs : $NUM_GPUS"
+ echo "============================================"
+
+ $PYTHON -m torch.distributed.run \
+ --standalone \
+ --nproc_per_node $NUM_GPUS \
+ $SCRIPT_DIR/eval_dflash_lora_inject.py \
+ --base-model $BASE_MODEL \
+ --adapter-root $ADAPTER_ROOT \
+ --ckpt $CKPT \
+ --merged-path $MERGED \
+ --block-size 16 \
+ --output-dir $RESULT_DIR \
+ "$@"
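`CKPT` above is hard-coded to one checkpoint. Assuming checkpoints land under `$ADAPTER_ROOT` as `epoch_<e>_step_<s>` directories, a naming scheme inferred from the default `CKPT=epoch_3_step_1400` and the training script's `--save-interval`, not confirmed by this diff, a small helper can list the candidates in training order:

```shell
# List checkpoint directory names under $1, sorted numerically by epoch
# then step. The epoch_<e>_step_<s> naming is an assumption inferred
# from the default CKPT value above.
list_ckpts() {
  ls -d "$1"/epoch_*_step_* 2>/dev/null | xargs -r -n1 basename \
    | sort -t_ -k2,2n -k4,4n
}
```

For example, `CKPT=$(list_ckpts "$ADAPTER_ROOT" | tail -1)` would select the most recent checkpoint instead of a pinned one.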