Add files using upload-large-folder tool
- IDEA_REPORT.md +187 -0
- datasets/_workspace_hanrui_datasets_HuggingFaceH4___aime_2024_default_0.0.0_2fe88a2f1091d5048c0f36abc874fb997b3dd99a.lock +0 -0
- datasets/_workspace_hanrui_datasets_MathArena___aime_2025_default_0.0.0_beca2d7875cf92cdac07acefbccad3c4d16e2916.lock +0 -0
- datasets/_workspace_hanrui_datasets_google-research-datasets___mbpp_sanitized_0.0.0_4bb6404fdc6cacfda99d4ac4205087b89d32030c.lock +0 -0
- datasets/_workspace_hanrui_datasets_json_default-3ab01998402731b9_0.0.0_c181ad2be84b86e0b75142bbe88bda3f4906d051ee75b5ff536a5dba0ffbe8f2.lock +0 -0
- datasets/_workspace_hanrui_datasets_princeton-nlp___swe-bench_lite_default_0.0.0_6ec7bb89b9342f664a54a6e0a6ea6501d3437cc2.lock +0 -0
- datasets/_workspace_hanrui_datasets_tatsu-lab___alpaca_default_0.0.0_dce01c9b08f87459cf36a430d809084718273017.lock +0 -0
- datasets/download_nemotron_codealpha.sh +10 -0
- manage_subgits.sh +87 -0
- nohup.out +48 -0
- progress/dflash_lora_changelog.md +232 -0
- progress/list.md +12 -0
- progress/oom_fix_progress.md +42 -0
- progress/requirements.txt +20 -0
- progress/step1.md +139 -0
- sglang/.codespellrc +3 -0
- sglang/.editorconfig +25 -0
- sglang/.isort.cfg +3 -0
- sglang/.pre-commit-config.yaml +83 -0
- sglang/CODE_OF_CONDUCT.md +128 -0
- sglang/LICENSE +201 -0
- sglang/README.md +90 -0
- syxin_old/DFLASH_LORA_INJECT_FIXES.md +142 -0
- syxin_old/backup.log +0 -0
- syxin_old/dflash_8gpu_03-31-13:40.log +552 -0
- syxin_old/diagnostic_compare.py +301 -0
- syxin_old/eval_alignment_diff.md +132 -0
- syxin_old/eval_dflash_b16_baseline.py +354 -0
- syxin_old/eval_dflash_b16_baseline_changelog.md +143 -0
- syxin_old/eval_dflash_lora_inject.py +660 -0
- syxin_old/eval_gsm8k_humaneval_mtbench.log +81 -0
- syxin_old/eval_run.log +0 -0
- syxin_old/launch_train.sh +37 -0
- syxin_old/launch_train_dflash_wrapper.py +17 -0
- syxin_old/launch_train_random_anchor.py +15 -0
- syxin_old/launch_train_wrapper.py +21 -0
- syxin_old/list.md +12 -0
- syxin_old/merge_lora.py +66 -0
- syxin_old/oom_fix_progress.md +42 -0
- syxin_old/random_anchor_plan.md +82 -0
- syxin_old/requirements.txt +0 -0
- syxin_old/run_bench_dflash.sh +71 -0
- syxin_old/run_bench_dflash_b16_baseline.sh +60 -0
- syxin_old/run_bench_dflash_lora_inject.sh +60 -0
- syxin_old/run_qwen3_8b_sft_64gpu.sh +31 -0
- syxin_old/run_train_dflash_lora_inject.sh +73 -0
- syxin_old/run_train_multinode.sh +67 -0
- syxin_old/run_train_multinode_random_anchor.sh +72 -0
- syxin_old/start_server.sh +42 -0
- syxin_old/start_server_dflash.sh +54 -0
IDEA_REPORT.md
ADDED
@@ -0,0 +1,187 @@
# DFlash Improvement Ideas: Higher Acceptance Length Without Training

**Goal:** Improve DFlash's acceptance length (tau) and acceleration ratio using only inference-time modifications — no additional training.

**Baseline:** Qwen3-4B + z-lab/Qwen3-4B-DFlash-b16, block_size=16, math500 (10 samples, 512 tokens)
- **Baseline avg tau = 8.63**, median = 8.0

---
## Idea 1: Iterative Block Refinement (Multi-Step Denoising) ⭐⭐⭐⭐⭐

**Core Idea:** Run the DFlash draft model multiple times on the same block. After each pass, use the sampled tokens as updated noise embeddings for the next pass, mimicking multi-step diffusion denoising.

**Why it might work:** DFlash currently uses a single forward pass to predict all block tokens from mask tokens. The initial mask embeddings carry no information about what the draft should generate. By iterating, each pass conditions on an increasingly informed noise context — the first pass gives a rough draft, the second pass refines it with better token embeddings as context.

**Implementation complexity:** Low. Just loop the draft forward pass 2-3 times, feeding output back as input. No KV cache across steps.

**Expected improvement:** +0.5 to +2.0 tau (denoising is the core mechanism of diffusion models — more steps should help).

**Risk:** Extra draft compute may negate speedup gains. Must keep step count low (2-3) to maintain wall-clock advantage.

**Pilot result:** `[PENDING]`

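The control flow is a small loop; a framework-agnostic sketch, where `draft_forward` stands in for one block-wide draft pass and is not a real API:

```python
def iterative_refine(draft_forward, block_tokens, num_steps=2):
    """Run the draft pass num_steps times, feeding each pass's sampled
    tokens back in as the next pass's noise input (multi-step denoising)."""
    tokens = block_tokens
    for _ in range(num_steps):
        tokens = draft_forward(tokens)  # one full-block draft pass
    return tokens
```

With `num_steps=1` this reduces to the current single-pass DFlash behavior, so the extra cost is exactly `num_steps - 1` additional draft passes per block.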
## Idea 1 plus: Confidence-Gated Selective Redrafting

**Core Idea:** After the first draft pass, compute per-position entropy of the draft logits. If any position (especially early ones) has high entropy (>threshold), run a second draft pass with the partially-filled block as context. Only replace the high-entropy positions with the second pass's predictions.

**Why it might work:** High entropy at a position signals that the draft model is uncertain — these are the positions most likely to cause rejection. A second pass, now conditioned on a partially-correct draft, can refine exactly these problematic positions.

**Implementation complexity:** Medium. Two draft passes + entropy computation + selective replacement.

**Expected improvement:** +0.5 to +2.0 tau (targeted improvement where it matters most).

**Risk:** Extra compute for the second pass. Entropy threshold needs tuning per dataset/model.

**Pilot result:** `[PENDING]`

---
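The gating criterion is cheap to compute; a minimal pure-Python sketch (the threshold value here is illustrative, not tuned):

```python
import math

def entropy(probs):
    """Shannon entropy (in nats) of one position's draft distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def positions_to_redraft(per_position_probs, threshold=1.5):
    """Indices whose draft distribution is too uncertain to trust;
    these are the candidates for the second, selective draft pass."""
    return [i for i, probs in enumerate(per_position_probs)
            if entropy(probs) > threshold]
```

A uniform distribution over 8 tokens has entropy ln(8) ≈ 2.08 and gets flagged; a sharply peaked one does not.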
## Idea 2: N-Best Draft Proposals (Multi-Candidate Selection) ⭐

**Core Idea:** Generate K candidate draft blocks (K=2-4) using different sampling strategies (greedy + temperature-based), then select the candidate with the highest aggregate log-probability under the draft model's own distribution.

**Why it might work:** Exact-match acceptance is binary — a single wrong token kills the entire suffix. By generating multiple candidates and picking the most confident one, we increase the probability that at least one candidate matches the target's greedy output. The confidence score acts as a proxy for "likely to match target."

**Implementation complexity:** Low-Medium. K forward passes per block, simple confidence scoring.

**Expected improvement:** +0.5 to +2.5 tau (especially for "unlucky" blocks where the default greedy choice is wrong).

**Risk:** K times the draft compute cost. Must keep K small. Confidence score may not perfectly correlate with acceptance.

**Pilot result:** `[PENDING]`

---
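Candidate selection itself is a one-liner; a sketch assuming each candidate carries its per-token draft log-probs (the function name and data layout are ours, not the codebase's):

```python
def select_best_candidate(candidates):
    """candidates: list of (tokens, per_token_logprobs) pairs.
    Return the tokens of the candidate with the highest aggregate
    log-probability under the draft model's own distribution."""
    best_tokens, _ = max(candidates, key=lambda c: sum(c[1]))
    return best_tokens
```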
## Idea 6: Token Recycling / Warm-Start Drafting ⭐⭐⭐

**Core Idea:** When rejection occurs at position j in a block of B tokens, the rejected tokens at positions j+1..B are discarded. Instead, save these tokens and use them to warm-start the noise embeddings of the next draft block. This gives the draft model a better starting point than random mask tokens.

**Why it might work:** Even though the prefix was wrong, later tokens in the rejected draft may still carry useful distributional information about the continuation. Using them as initial noise (instead of mask tokens) gives the draft model more context for its single-pass prediction.

**Implementation complexity:** Low. Save rejected suffix, inject into next block's initial embeddings.

**Expected improvement:** +0.3 to +1.0 tau (modest, since the recycled tokens are conditioned on a wrong prefix).

**Risk:** Recycled tokens may actually mislead the draft model if they were generated from a very different prefix. Net effect could be negative.

**Pilot result:** `[PENDING]`

---
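The warm-start is just a seeding step at the token-id level; a sketch in which `mask_token` plays the role of DFlash's MASK id (the value here is a placeholder, not the real id):

```python
def warm_start_block(rejected_suffix, block_size, mask_token=-1):
    """Seed the next draft block with the previously rejected suffix
    instead of all-mask tokens; pad the remainder with mask tokens."""
    seed = list(rejected_suffix)[:block_size]
    return seed + [mask_token] * (block_size - len(seed))
```

An empty rejected suffix falls back to the standard all-mask initialization, so the change is strictly additive.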
## Idea 9: Dynamic Target Layer Selection

**Core Idea:** Instead of always extracting features from the same 5 fixed target layers, try alternative layer selections (e.g., shifted by +2 or -2) and pick the one that produces the highest-confidence draft. Different parts of the sequence may benefit from different layers.

**Why it might work:** The paper's ablation (Table 5) shows that layer selection affects acceptance length. The optimal layers may vary by position in the sequence or by the type of content being generated. Late layers have more "final answer" information; early layers have more syntactic/structural information.

**Implementation complexity:** Medium. Multiple draft passes with different layer configs + scoring.

**Expected improvement:** +0.3 to +1.5 tau (if the fixed layers are suboptimal for certain content types).

**Risk:** The draft model's fc projection was trained on specific layer combinations. Using different layers degrades the learned alignment. Needs the fc layer to generalize.

**Pilot result:** `[PENDING]`

---
## Idea 11: Top-K Constrained Draft Sampling ⭐

**Core Idea:** Apply top-k filtering to draft logits before sampling, zeroing out all but the top-k tokens at each position. This forces the draft to choose among only the most probable tokens.

**Why it might work:** For exact-match acceptance under greedy target decoding, only the target's argmax token matters. By restricting the draft's vocabulary to its own top-k, we reduce the chance of sampling a low-probability token that definitely won't match the target.

**Implementation complexity:** Very low. Single top-k operation on logits.

**Expected improvement:** +0.1 to +0.5 tau (minor, since greedy draft already picks argmax; mainly helps with stochastic target).

**Risk:** Under greedy draft + greedy target, this is a no-op. Only helps when draft uses non-zero temperature.

**Pilot result:** `[PENDING]`

---
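The filtering step, sketched over a plain list of logits (a real implementation would use the framework's batched top-k op; note that ties at the k-th value are kept here):

```python
def top_k_filter(logits, k):
    """Keep the k largest logits; push the rest to -inf so that
    sampling from the resulting softmax can never pick them."""
    kth = sorted(logits, reverse=True)[k - 1]
    return [x if x >= kth else float("-inf") for x in logits]
```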
## Idea 12: Position-Weighted Logit Scaling ⭐⭐

**Core Idea:** Scale draft logits by a position-dependent factor: early positions get more aggressive scaling (sharper distribution = higher confidence), later positions get gentler scaling. Rationale: early positions matter most for prefix-based acceptance.

**Why it might work:** By sharpening early positions, we increase the probability that positions 1-3 are correct (the most critical for tau). Later positions can afford to be less sharp since they only matter if all earlier positions are accepted.

**Implementation complexity:** Very low. Multiply logits by a position-dependent vector.

**Expected improvement:** +0.2 to +1.0 tau.

**Risk:** Over-sharpening may concentrate probability on a wrong token. Needs careful calibration of the scaling schedule.

**Pilot result:** `[PENDING]`

---
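One possible schedule is a linear temperature ramp across the block; a sketch (the endpoint values are illustrative, not tuned):

```python
def position_scaled_logits(block_logits, t_start=0.5, t_end=1.0):
    """Divide each position's logits by a temperature that ramps
    linearly from t_start (sharp, early positions) to t_end (late
    positions). t < 1 sharpens the distribution."""
    n = len(block_logits)
    scaled = []
    for i, logits in enumerate(block_logits):
        t = t_start + (t_end - t_start) * (i / max(n - 1, 1))
        scaled.append([x / t for x in logits])
    return scaled
```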
## Bonus Ideas (Not Yet Implemented)

### Idea 13: Tree-Structured Verification
Verify multiple candidate continuations in a single batched target forward pass using packed attention with tree causal masks. This doesn't improve tau per-candidate but amortizes the verification cost across candidates, enabling higher effective throughput. Very promising for combining with N-best or beam approaches.

### Idea 16: Draft-Target KL Alignment via Inference-Time Calibration ⭐⭐⭐
Compute a lightweight calibration mapping (affine transform on draft logits) by running a small calibration set and measuring draft vs target token agreement. Apply this calibration at inference time without retraining.
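The inference-time part of Idea 16 is trivially cheap once the affine parameters are fit offline; a sketch (the fitting procedure is not shown, and `scale`/`bias` are assumed to come from the calibration run; note a scalar map leaves the argmax unchanged and only reshapes the sampling distribution):

```python
def calibrate_logits(logits, scale=1.0, bias=0.0):
    """Apply the learned affine map scale*z + bias to each draft logit.
    With scale=1, bias=0 this is the identity (uncalibrated) baseline."""
    return [scale * z + bias for z in logits]
```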
### Idea 17: Multi-Block Pipelining
Overlap the draft and verification phases across blocks. While the target model verifies block k, the draft model starts working on block k+1 using a speculative target_hidden extrapolation. If the speculation was right, the pipeline stays full.

---
## Experiment Configuration

| Parameter | Value |
|-----------|-------|
| Target model | Qwen/Qwen3-4B |
| Draft model | z-lab/Qwen3-4B-DFlash-b16 |
| Block size | 16 |
| Dataset | math500 |
| Max samples | 10 |
| Max new tokens | 512 |
| Temperature | 0.0 (greedy) |
| GPU | NVIDIA H200 (single GPU) |
| Attention | SDPA |
## Results Summary

| # | Method | Avg tau | Delta | Pilot Signal |
|---|--------|---------|-------|--------------|
| 0 | **Baseline** | **8.63** | - | - |
| 1 | Iterative Refinement (2 steps) | `[PENDING]` | | |
| 2 | Iterative Refinement (3 steps) | `[PENDING]` | | |
| 3 | N-Best Draft (K=2) | `[PENDING]` | | |
| 4 | N-Best Draft (K=3) | `[PENDING]` | | |
| 5 | Adaptive Block Size (4-16) | `[PENDING]` | | |
| 6 | Early-Position Beam (width=3) | `[PENDING]` | | |
| 7 | Draft Temp t=0.3 | `[PENDING]` | | |
| 8 | Draft Temp t=0.1 | `[PENDING]` | | |
| 9 | Token Recycling | `[PENDING]` | | |
| 10 | Selective Redraft (ent>1.5) | `[PENDING]` | | |
| 11 | Selective Redraft (ent>1.0) | `[PENDING]` | | |
| 12 | Majority Vote (K=3) | `[PENDING]` | | |
| 13 | Majority Vote (K=5) | `[PENDING]` | | |
| 14 | Shifted Target Layers (+2) | `[PENDING]` | | |
| 15 | Logit Averaging (2 pass) | `[PENDING]` | | |
| 16 | Logit Averaging (3 pass) | `[PENDING]` | | |
| 17 | Top-K Constrained (k=10) | `[PENDING]` | | |
| 18 | Position-Weighted Temp | `[PENDING]` | | |

---

*Generated 2026-04-01. Experiments running on NVIDIA H200, dflash conda env.*
datasets/_workspace_hanrui_datasets_HuggingFaceH4___aime_2024_default_0.0.0_2fe88a2f1091d5048c0f36abc874fb997b3dd99a.lock
ADDED
File without changes

datasets/_workspace_hanrui_datasets_MathArena___aime_2025_default_0.0.0_beca2d7875cf92cdac07acefbccad3c4d16e2916.lock
ADDED
File without changes

datasets/_workspace_hanrui_datasets_google-research-datasets___mbpp_sanitized_0.0.0_4bb6404fdc6cacfda99d4ac4205087b89d32030c.lock
ADDED
File without changes

datasets/_workspace_hanrui_datasets_json_default-3ab01998402731b9_0.0.0_c181ad2be84b86e0b75142bbe88bda3f4906d051ee75b5ff536a5dba0ffbe8f2.lock
ADDED
File without changes

datasets/_workspace_hanrui_datasets_princeton-nlp___swe-bench_lite_default_0.0.0_6ec7bb89b9342f664a54a6e0a6ea6501d3437cc2.lock
ADDED
File without changes

datasets/_workspace_hanrui_datasets_tatsu-lab___alpaca_default_0.0.0_dce01c9b08f87459cf36a430d809084718273017.lock
ADDED
File without changes
datasets/download_nemotron_codealpha.sh
ADDED
@@ -0,0 +1,10 @@
#!/bin/bash

export HF_TOKEN="YOUR_HF_TOKEN_HERE"
export HF_HUB_ENABLE_HF_TRANSFER=1
export HF_HUB_VERBOSITY=debug

hf download \
    --repo-type dataset \
    --local-dir /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    eigen-ai-labs/Nemotron-CodeAlpaca-qwen3-8b-800K
manage_subgits.sh
ADDED
@@ -0,0 +1,87 @@
#!/bin/bash
# Manage .git folders in subdirectories: back up, delete, restore
# Usage:
#   ./manage_subgits.sh backup  - back up and delete .git in subdirectories
#   ./manage_subgits.sh restore - restore .git from backups

set -euo pipefail

cd "$(dirname "$0")"

BACKUP_DIR=".git_backups"
MANIFEST="$BACKUP_DIR/manifest.txt"

backup() {
    if [ -d "$BACKUP_DIR" ]; then
        echo "❌ Backup directory $BACKUP_DIR already exists; run restore or delete it manually first"
        exit 1
    fi

    mkdir -p "$BACKUP_DIR"
    > "$MANIFEST"

    count=0
    while IFS= read -r gitdir; do
        count=$((count + 1))
        echo "$count|$gitdir" >> "$MANIFEST"

        echo "📦 Backing up: $gitdir"
        cp -a "$gitdir" "$BACKUP_DIR/$count"

        echo "🗑️ Deleting: $gitdir"
        rm -rf "$gitdir"
    done < <(find . -mindepth 2 -name ".git" -not -path "./$BACKUP_DIR/*" | sort)

    if [ "$count" -eq 0 ]; then
        rm -rf "$BACKUP_DIR"
        echo "ℹ️ No .git found in subdirectories; nothing to do"
    else
        echo ""
        echo "✅ Done! Backed up and deleted $count .git folder(s)"
        echo "📁 Backups stored in: $BACKUP_DIR/"
        echo "👉 After the upload finishes, run: $0 restore"
    fi
}

restore() {
    if [ ! -f "$MANIFEST" ]; then
        echo "❌ Manifest $MANIFEST not found; nothing to restore"
        exit 1
    fi

    count=0
    while IFS='|' read -r id gitdir; do
        if [ ! -d "$BACKUP_DIR/$id" ]; then
            echo "⚠️ Skipping: backup #$id does not exist ($gitdir)"
            continue
        fi

        mkdir -p "$(dirname "$gitdir")"

        echo "♻️ Restoring: $gitdir"
        cp -a "$BACKUP_DIR/$id" "$gitdir"
        count=$((count + 1))
    done < "$MANIFEST"

    rm -rf "$BACKUP_DIR"

    echo ""
    echo "✅ Done! Restored $count .git folder(s)"
    echo "🧹 Backup directory cleaned up"
}

case "${1:-}" in
    backup)
        backup
        ;;
    restore)
        restore
        ;;
    *)
        echo "Usage: $0 {backup|restore}"
        echo ""
        echo "  backup  - back up all .git in subdirectories and delete them"
        echo "  restore - restore all .git from backups"
        exit 1
        ;;
esac
nohup.out
ADDED
@@ -0,0 +1,48 @@
/workspace/miniconda3/envs/dflash/bin/python3: can't open file '/workspace/hanrui/ ': [Errno 2] No such file or directory
E0317 16:57:14.100000 140364991186752 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 2) local_rank: 0 (pid: 14058) of binary: /workspace/miniconda3/envs/dflash/bin/python3
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/run.py", line 905, in <module>
    main()
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/run.py", line 901, in main
    run(args)
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-03-17_16:57:14
  host      : job-006ce80a7c47-20260302193512-5dcd4c9bbd-gfjsn
  rank      : 0 (local_rank: 0)
  exitcode  : 2 (pid: 14058)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
usage: run.py [-h] [--nnodes NNODES] [--nproc-per-node NPROC_PER_NODE]
              [--rdzv-backend RDZV_BACKEND] [--rdzv-endpoint RDZV_ENDPOINT]
              [--rdzv-id RDZV_ID] [--rdzv-conf RDZV_CONF] [--standalone]
              [--max-restarts MAX_RESTARTS]
              [--monitor-interval MONITOR_INTERVAL]
              [--start-method {spawn,fork,forkserver}] [--role ROLE] [-m]
              [--no-python] [--run-path] [--log-dir LOG_DIR] [-r REDIRECTS]
              [-t TEE] [--local-ranks-filter LOCAL_RANKS_FILTER]
              [--node-rank NODE_RANK] [--master-addr MASTER_ADDR]
              [--master-port MASTER_PORT] [--local-addr LOCAL_ADDR]
              [--logs-specs LOGS_SPECS]
              training_script ...
run.py: error: the following arguments are required: training_script, training_script_args
progress/dflash_lora_changelog.md
ADDED
@@ -0,0 +1,232 @@
# DFlash LoRA: Complete Change Log

## Overview

To get Qwen3-8B DFlash LoRA training running on 2×H100 (resolving OOM), a total of **5 files and 1084 lines of code** were added or modified. The changes fall into two phases: initial scaffolding + OOM fixes.

---

## New File List

| File | Lines | Purpose |
|------|------|------|
| `specforge/core/dflash_lora.py` | 453 | Training wrapper (OnlineDFlashLoRAModel) |
| `specforge/modeling/draft/dflash_lora.py` | 141 | LoRA draft model (DFlashLoRADraftModel) |
| `scripts/train_dflash_lora.py` | 449 | Training entry script |
| `scripts/run_train_dflash_lora.sh` | 31 | Launch shell script |
| `configs/qwen3-8b-dflash-lora.json` | 10 | LoRA config file |

---

## Step 1 Completion Process

### 1.1 Analysis of Existing Code

First, the full pipeline of the non-LoRA `train_dflash.py` was analyzed:

```
input_ids → target_model.generate_dflash_data() → hidden_states
→ OnlineDFlashModel.forward():
    1. Truncate to block boundary
    2. prepare_noise_input(): keep anchors, replace the rest with MASK
    3. embed_tokens(noise_input_ids) → noise_embedding
    4. Build the DFlash attention mask
    5. draft_model(noise_embedding, target_hidden, mask)
    6. lm_head(hidden) → logits → CE loss
```

The non-LoRA version uses a separate small draft model + a frozen target model to extract hidden states.

### 1.2 Design Differences of the LoRA Version

| Aspect | Non-LoRA (`train_dflash.py`) | LoRA (`train_dflash_lora.py`) |
|------|------|------|
| Draft model | Custom small model (1-10 layers) | Qwen3-8B + PEFT LoRA |
| Target model | Frozen large model extracting hidden states | Not needed — the model uses its own representations |
| Attention | Custom Qwen3DFlashAttention, KV = [ctx, noise] concat | Standard HF attention + DFlash mask |
| KV structure | Q_LEN = noise_len, KV_LEN = 2×noise_len | Q_LEN = KV_LEN = seq_len |
| Trainable params | All draft model params | LoRA only (q/k/v/o_proj) |

### 1.3 Three New Core Files for the LoRA Version

#### `specforge/modeling/draft/dflash_lora.py` — DFlashLoRADraftModel

- `from_pretrained()`: loads Qwen3-8B, injects PEFT LoRA, supports an `attn_implementation` argument
- `forward()`: standard HF forward, supports an `output_hidden_states` argument (needed by chunked loss)
- `get_lm_head()`: reaches through the PEFT hierarchy to get an lm_head reference
- `gradient_checkpointing_enable()`: proxies to the underlying model
- `save_pretrained()`: saves only the LoRA adapter weights

#### `specforge/core/dflash_lora.py` — OnlineDFlashLoRAModel

- `prepare_noise_input()`: context part unchanged; in the block part, keep only anchors and replace the rest with MASK
- `build_dflash_full_attn_mask_fast()`: vectorized construction of a 4D additive mask `[bsz, 1, seq, seq]`
- `_compute_loss_weights()`: context + anchor weight 0, non-anchor weight 1 (or decay)
- `_full_lm_loss()`: standard CE loss path
- `_compute_accuracy()`: block-wise acceptance rate (cumulative correct-prediction length / non-anchor block length)
- `forward()`: the full training forward pass

Mask rules for the LoRA version:
- context token i → causal attention (j ≤ i)
- block token i (in block b) → all context + bidirectional attention within the same block

#### `scripts/train_dflash_lora.py` — Training Script

- Argument parsing: 7 groups of args (model/lora/dataset/training/output/distributed/tracker)
- `build_model()`: loads the model + injects LoRA + wraps OnlineDFlashLoRAModel
- `build_dataloader()`: reuses `build_eagle3_dataset` and `prepare_dp_dataloaders`
- FSDP wrapping + BF16Optimizer
- Training loop: forward → backward → accumulation → optimizer step
- Checkpoint save/resume

---

## OOM Fixes (4 Changes)

### Change 1: FSDP FULL_SHARD (ZeRO-3)

**Problem**: `SHARD_GRAD_OP` (ZeRO-2) keeps the full Qwen3-8B parameters (~16GB bf16) on every GPU

**Fix**: `train_dflash_lora.py:362`
```python
# before
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
# after
sharding_strategy=ShardingStrategy.FULL_SHARD
```

**Effect**: parameters are sharded across GPUs, saving ~8-12GB per GPU

### Change 2: batch_size=1 + accumulation_steps=8

**Problem**: peak GPU memory too high with `batch_size=2`

**Fix**: `run_train_dflash_lora.sh`
```bash
--batch-size 1 \
--accumulation-steps 8 \
```

**Effect**: equivalent global batch size unchanged, peak memory halved

### Change 3: flex_attention + BlockMask instead of the 4D additive mask

**Problem**: SDPA does not support 4D additive masks → falls back to the math backend → every layer materializes the full `[bsz, 32heads, 2048, 2048]` attention scores

**Fix**: ported the `_get_or_create_block_mask()` method from the non-LoRA `dflash.py`, adapted for the LoRA setting

Files involved:

1. **`specforge/core/dflash_lora.py`**
   - `__init__()`: added an `attention_backend` argument (default `"flex_attention"`) and BlockMask cache fields
   - new `_get_or_create_block_mask()`: builds a zero-memory BlockMask with `create_block_mask()`
   - `forward()`: chooses BlockMask or additive mask based on `attention_backend`

2. **`specforge/modeling/draft/dflash_lora.py`**
   - `from_pretrained()`: when the backend is flex_attention, passes `attn_implementation="flex_attention"` to HuggingFace

3. **`scripts/train_dflash_lora.py`**
   - `parse_args()`: `--attention-backend` argument (`flex_attention` | `additive`)
   - `build_model()`: selects `attn_implementation` based on the backend

BlockMask mask function (LoRA version):
```python
def dflash_lora_mask_fn(b, h, q_idx, kv_idx):
    # Context query: standard causal
    is_q_ctx = q_idx < context_len
    ctx_visible = is_q_ctx & (kv_idx <= q_idx)

    # Block query: all context + bidirectional within the same block
    is_q_block = q_idx >= context_len
    is_k_ctx = kv_idx < context_len
    q_block_id = (q_idx - context_len) // block_size
    k_block_id = (kv_idx - context_len) // block_size
    block_attend_ctx = is_q_block & is_k_ctx
    block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id)

    return ctx_visible | (block_attend_ctx | block_attend_same)
```

**Verification**: manual element-wise comparison of BlockMask vs additive mask outputs; three test groups (context_len=4/0, seq=12/16/64) match the pattern exactly.
| 151 |
+
|
| 152 |
+
**效果**: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 显存
|
| 153 |
+
|
| 154 |
+
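The mask semantics can be sanity-checked without flex_attention at all: the sketch below re-implements the same predicate in plain Python over scalar indices. The `context_len`/`block_size`/`seq_len` values are made up for the test, not project defaults:

```python
# Pure-Python sanity check of the DFlash LoRA mask pattern.
# Illustrative sizes: 4 context tokens + 2 draft blocks of 2 tokens each.
context_len = 4
block_size = 2
seq_len = 8

def dflash_lora_mask_fn(q_idx, kv_idx):
    if q_idx < context_len:                # context query: strictly causal
        return kv_idx <= q_idx
    if kv_idx < context_len:               # block query sees all context
        return True
    # block query: bidirectional only within its own block
    return (q_idx - context_len) // block_size == (kv_idx - context_len) // block_size

mask = [[dflash_lora_mask_fn(q, k) for k in range(seq_len)] for q in range(seq_len)]

# Context part is causal: token 1 sees token 0, but not the reverse.
assert mask[1][0] and not mask[0][1]
# A draft token in block 0 (q=4) sees the context (k=0) and its block-mate (k=5),
# but not block 1 (k=6).
assert mask[4][0] and mask[4][5] and not mask[4][6]
```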
### Change 4: chunked cross-entropy loss

**Problem**: the `[bsz, 2048, 151936]` bf16 logits tensor is ≈ 1.18 GB; with gradients it grows to ~2.4 GB+.

**Fix**: port the chunked loss from the non-LoRA `dflash.py:419-478`.

Files touched:

1. **`specforge/core/dflash_lora.py`**
   - `__init__()`: add an `lm_head_chunk_size` parameter (default 0 = disabled)
   - new `_chunked_lm_loss()`: runs lm_head + CE loss chunk by chunk under gradient checkpointing
   - extracted `_full_lm_loss()`: the original non-chunked path
   - `forward()`: takes the chunked path when `lm_head_chunk_size > 0`

2. **`specforge/modeling/draft/dflash_lora.py`**
   - `forward()`: new `output_hidden_states` parameter; when True, returns the last hidden state instead of logits
   - `get_lm_head()`: reaches through the PEFT wrapper to return a reference to `base_model.lm_head`

3. **`scripts/train_dflash_lora.py`**
   - `parse_args()`: `--lm-head-chunk-size` argument (default 0, recommended 256)
   - `build_model()`: passes it through to OnlineDFlashLoRAModel

Core chunked-loss logic:

```python
# Compute chunk by chunk, each under gradient checkpointing
# (logits are recomputed during backward instead of being stored)
for start in range(0, effective_len, chunk_size):
    end = min(start + chunk_size, effective_len)
    chunk_loss, chunk_weight = grad_checkpoint(
        _chunk_ce,                     # lm_head + CE
        hidden[:, start:end, :],       # only the current chunk
        input_ids[:, start:end],
        combined_mask[:, start:end],
        use_reentrant=False,
    )
    total_loss += chunk_loss
    total_weight += chunk_weight
loss = total_loss / total_weight
```

**Effect**: peak logits memory drops from `O(seq_len × vocab_size)` to `O(chunk_size × vocab_size)`; with a 256-token chunk, ~150 MB vs 1.18 GB.

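Numerically, chunking changes nothing: accumulating per-chunk (loss, weight) sums and normalizing once at the end equals the full weighted-mean CE. A dependency-free sketch with a tiny made-up vocab and weights (gradient checkpointing only affects memory, so it is omitted here):

```python
import math

# Toy data: 6 positions, vocab of 4; all values illustrative.
logits = [[0.1 * (p + v) for v in range(4)] for p in range(6)]
targets = [0, 1, 2, 3, 0, 1]
weights = [1.0, 0.5, 0.0, 1.0, 1.0, 0.5]  # loss mask / decay weights

def ce(logit_row, target):
    """Cross-entropy for one position, computed stably via log-sum-exp."""
    m = max(logit_row)
    log_z = m + math.log(sum(math.exp(x - m) for x in logit_row))
    return log_z - logit_row[target]

# Full computation: weighted mean over all positions at once.
full = sum(w * ce(l, t) for l, t, w in zip(logits, targets, weights)) / sum(weights)

# Chunked: accumulate (sum-loss, sum-weight) per chunk, normalize at the end.
chunk_size, total_loss, total_weight = 2, 0.0, 0.0
for start in range(0, len(logits), chunk_size):
    end = start + chunk_size
    total_loss += sum(w * ce(l, t) for l, t, w in
                      zip(logits[start:end], targets[start:end], weights[start:end]))
    total_weight += sum(weights[start:end])
chunked = total_loss / total_weight

assert abs(full - chunked) < 1e-12
```

This is why `_chunked_lm_loss()` returns per-chunk (loss, weight) pairs rather than per-chunk means: means would weight chunks incorrectly when the mask is uneven.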
---

## Current training command

```bash
bash run_train_dflash_lora.sh 2  # 2 = number of GPUs
```

Full equivalent arguments:

```bash
torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \
    --model-path /workspace/Qwen3-8B \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir outputs/qwen3-8b-dflash-lora \
    --lora-config configs/qwen3-8b-dflash-lora.json \
    --block-size 16 \
    --max-length 2048 \
    --batch-size 1 \
    --num-epochs 3 \
    --learning-rate 2e-4 \
    --accumulation-steps 8 \
    --loss-decay-gamma 7 \
    --attention-backend flex_attention \
    --lm-head-chunk-size 256 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500
```

---

## To verify

- [ ] Run `bash run_train_dflash_lora.sh 2` and confirm it no longer OOMs
- [ ] Confirm there is no SDPA math-fallback warning
- [ ] Watch peak GPU memory
- [ ] Confirm loss decreases and accuracy trends upward as expected

progress/list.md
ADDED
@@ -0,0 +1,12 @@

### 1. `train_dflash_lora.py`
* Added LoRA: previously a separate small draft model was called; now prediction uses the base model's hidden states + LoRA adapters.
* The `dflash_lora_mask_fn` function lets every token in the draft block being predicted see all tokens within that block at once.

### 2. OOM optimizations
* ZeRO-3 sharding strategy: FSDP sharding upgraded from `SHARD_GRAD_OP` to `FULL_SHARD`.
* `batch-size=1`, `accumulation-steps=8`.
* Adopted FlexAttention (`dflash_lora_mask_fn`), following the earlier non-LoRA code.
* `_chunked_lm_loss()`: loss computed in chunks of 256, with gradient checkpointing.

### Run
* `bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2`

progress/oom_fix_progress.md
ADDED
@@ -0,0 +1,42 @@

# DFlash LoRA OOM Fix Log

## Root-cause analysis

1. **SHARD_GRAD_OP (ZeRO-2)** — each GPU holds the full Qwen3-8B parameters (~16 GB bf16); parameters are not sharded
2. **SDPA + 4D additive mask** — FlashAttention does not support a 4D additive mask, so SDPA falls back to the math backend and every layer materializes the full attention scores (`bsz × 32 heads × 2048 × 2048`)
3. **Large-vocab logits** — `[bsz, 2048, 151936]` bf16 ≈ 1.18 GB; with gradients and boolean-indexing copies, the peak is ~3-4 GB
4. **The machine has only 2 H100s**, but the script defaults to `NUM_GPUS=4`
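The headline numbers above follow from straightforward byte arithmetic (bf16 = 2 bytes per element; batch size 2 matches the original run):

```python
# Back-of-the-envelope memory estimates for the OOM root causes above.
BF16 = 2  # bytes per element
GB = 1024 ** 3

# Root cause 3: full logits tensor [bsz=2, seq=2048, vocab=151936]
logits_bytes = 2 * 2048 * 151936 * BF16
assert round(logits_bytes / GB, 2) == 1.16  # ≈ the 1.18 GB figure above

# Root cause 2: one layer's materialized attention scores
# [bsz=2, heads=32, 2048, 2048] under the SDPA math fallback
scores_bytes = 2 * 32 * 2048 * 2048 * BF16
assert round(scores_bytes / GB, 2) == 0.5   # 0.5 GB *per layer*
```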

## Completed changes

### 1. FSDP sharding changed to FULL_SHARD (ZeRO-3)
- File: `SpecForge/scripts/train_dflash_lora.py:347`
- `ShardingStrategy.SHARD_GRAD_OP` → `ShardingStrategy.FULL_SHARD`
- Effect: parameters sharded across GPUs, saving ~8-12 GB per GPU

### 2. Lower batch-size, raise accumulation-steps
- File: `SpecForge/scripts/run_train_dflash_lora.sh`
- `--batch-size 2` → `1`, `--accumulation-steps 4` → `8`
- Effect: effective global batch size unchanged, peak memory roughly halved

## To verify / follow-up optimizations

- [ ] At runtime, launch with `bash run_train_dflash_lora.sh 2` to ensure 2 GPUs are used
- [x] If still OOM, use a chunked cross-entropy loss to avoid materializing the full large-vocab logits
- [x] Longer term, explore a custom attention kernel with block-sparse mask support to bypass the SDPA math fallback

### 3. flex_attention + BlockMask instead of the 4D additive mask
- Files: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
- Ported the `_get_or_create_block_mask()` method from the non-LoRA `dflash.py` and adapted it to the LoRA setting (Q_LEN == KV_LEN == seq_len)
- LoRA-version mask: context causal + block bidirectional (the non-LoRA version uses [context, noise] concatenated as KV)
- Enabled with `--attention-backend flex_attention` (the default); `--attention-backend additive` falls back to the original 4D mask
- The HuggingFace model is loaded with `attn_implementation="flex_attention"`
- Effect: no more fallback to the SDPA math backend, eliminating the `[bsz, heads, seq, seq]` attention-scores memory

### 4. chunked cross-entropy loss
- Files: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
- Ported the `_chunked_lm_loss()` method from the non-LoRA `dflash.py`
- Runs lm_head + CE loss chunk by chunk under gradient checkpointing, avoiding materializing the full `[bsz, seq, vocab]` logits
- Enabled with `--lm-head-chunk-size 256` (default 0 = disabled)
- `DFlashLoRADraftModel.forward()` gains an `output_hidden_states` parameter; in chunked mode it returns hidden states
- Effect: peak logits memory drops from O(seq_len × vocab_size) to O(chunk_size × vocab_size)

progress/requirements.txt
ADDED
@@ -0,0 +1,20 @@

torch==2.4.0
torchaudio==2.4.0
torchvision==0.19.0
transformers==4.57.1
accelerate
datasets
peft
numpy
tqdm
wandb
psutil
pydantic
packaging
ninja
setuptools
yunchang
sglang<=0.4.10
openai-harmony
qwen-vl-utils==0.0.11
pre-commit

progress/step1.md
ADDED
@@ -0,0 +1,139 @@

Now I have a complete picture. Here's the full analysis and proposed plan.

---
Current Logic of train_dflash.py

This script trains a small, standalone DFlash draft model (1–10 transformer layers) that learns to predict token blocks in parallel, guided by hidden states extracted from a frozen target model.

Architecture (two separate models)

1. Target model (frozen, e.g. Qwen3-8B/32B) — runs a full forward pass to produce hidden states at selected layers.
2. Draft model (DFlashDraftModel, ~1–10 layers) — a lightweight Qwen3-based decoder that takes noise embeddings + target hidden states and predicts the block tokens.
3. Target embed_tokens + lm_head — loaded separately via TargetEmbeddingsAndHead to avoid duplicating the full target model in memory.

Key locations

| Component                | File                               | Lines                                                 |
|--------------------------|------------------------------------|-------------------------------------------------------|
| Model init               | scripts/train_dflash.py            | build_models() L254–311                               |
| Target hidden extraction | scripts/train_dflash.py            | L644–647 (target_model.generate_dflash_data)          |
| Forward pass             | specforge/core/dflash.py           | OnlineDFlashModel.forward() L243–332                  |
| Loss calculation         | specforge/core/dflash.py           | _full_lm_loss() L382–417, _chunked_lm_loss() L419–478 |
| Loss mask                | specforge/core/dflash.py           | create_dflash_loss_mask() L481–509                    |
| Draft model architecture | specforge/modeling/draft/dflash.py | DFlashDraftModel L212–266                             |
| DFlash attention         | specforge/modeling/draft/dflash.py | Qwen3DFlashAttention L42–134                          |

Forward pass flow (per training step)

input_ids, attention_mask, loss_mask → target_model.generate_dflash_data()
    ↓
hidden_states (from target layers [1,9,17,25,33])
    ↓
OnlineDFlashModel.forward():
  1. Truncate to block boundary
  2. prepare_noise_input(): anchor tokens kept, rest → MASK
  3. embed_tokens(noise_input_ids) → noise_embedding
  4. Build DFlash attention mask (flex_attention or additive)
  5. draft_model(noise_embedding, target_hidden, mask)
  6. lm_head(hidden) → logits
  7. CE loss on non-anchor positions (weighted by loss_mask × decay)

The draft model's custom Qwen3DFlashAttention concatenates [context_hidden, noise_hidden] as KV, with queries only from noise tokens. The attention mask enforces: block tokens see all preceding blocks' context + bidirectional within their own block.

---
What already exists: train_dflash_lora.py

Interestingly, the repo already has a LoRA variant at scripts/train_dflash_lora.py with its own model (DFlashLoRADraftModel) and wrapper (OnlineDFlashLoRAModel). This is exactly the approach you described — Qwen3-8B + LoRA, no separate target model, 1-step diffusion training. The key differences from train_dflash.py:

| Aspect           | train_dflash.py                                        | train_dflash_lora.py                               |
|------------------|--------------------------------------------------------|----------------------------------------------------|
| Draft model      | Small custom DFlashDraftModel (1–10 layers)            | Full Qwen3-8B + LoRA adapters                      |
| Target model     | Separate frozen model for hidden state extraction      | None — model uses its own representations          |
| Attention        | Custom Qwen3DFlashAttention (Q from noise, KV from [ctx, noise]) | Standard HF attention with 4D additive DFlash mask |
| Forward          | draft_model(noise_emb, target_hidden, mask)            | model(noise_input_ids, 4d_mask, position_ids) → logits |
| Trainable params | All draft model params                                 | Only LoRA (q/k/v/o_proj), base frozen              |
| FSDP strategy    | SHARD_GRAD_OP                                          | FULL_SHARD                                         |

---
Proposed Modification Plan

Since train_dflash_lora.py already implements the core idea, the plan focuses on what's missing or needs improvement to make it a proper "1-step dLLM draft model" for your research:

Phase 1: Validate and extend the existing LoRA pipeline

1. Add MLP to LoRA targets — The current config only targets q_proj, k_proj, v_proj, o_proj. For stronger 1-step diffusion capability, add gate_proj, up_proj, down_proj to lora_target_modules. This gives the model more capacity to learn the non-autoregressive distribution shift.
2. Add multi-step noise schedule support — Currently the training is strictly 1-step (all non-anchors → MASK). For a proper diffusion/AR fusion, add an option for a noise schedule where a fraction of block tokens are revealed (not just the anchor), controlled by a noise_ratio parameter. This would modify prepare_noise_input() in OnlineDFlashLoRAModel:
   # Instead of: all non-anchor → MASK
   # Allow: randomly keep some non-anchor tokens with probability (1 - noise_ratio)
3. Add configurable context_len strategy — Currently context_len=0 treats the whole sequence as blocks. Add a --context-ratio arg that dynamically sets context_len as a fraction of the sequence, so the model learns to condition on varying amounts of AR-decoded prefix.

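The noise-schedule idea in Phase 1 can be sketched as a variant of prepare_noise_input(). Everything below — the function shape, the `MASK_ID` placeholder, the `noise_ratio` semantics — is a hypothetical illustration, not existing project code:

```python
import random

MASK_ID = -1  # placeholder mask-token id, illustrative only

def prepare_noise_input(block_tokens, noise_ratio, rng):
    """Keep the anchor (position 0); reveal each remaining token with
    probability (1 - noise_ratio), otherwise replace it with MASK."""
    out = [block_tokens[0]]  # the anchor is always kept
    for tok in block_tokens[1:]:
        out.append(tok if rng.random() >= noise_ratio else MASK_ID)
    return out

rng = random.Random(0)
block = [11, 12, 13, 14, 15, 16, 17, 18]

# noise_ratio=1.0 reproduces the current strict 1-step setup: all non-anchors masked.
assert prepare_noise_input(block, 1.0, rng) == [11] + [MASK_ID] * 7
# noise_ratio=0.0 reveals everything (pure teacher forcing).
assert prepare_noise_input(block, 0.0, rng) == block
```

Intermediate `noise_ratio` values would give the partially revealed blocks that a multi-step diffusion schedule needs.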
Phase 2: Training logic improvements

4. Add KL divergence loss — In addition to CE loss against ground truth, add an optional KL loss against the base model's AR distribution (teacher forcing). This regularizes the LoRA model to stay close to the original Qwen3-8B distribution. Modify OnlineDFlashLoRAModel.forward():
   # Compute base model logits (no_grad, no LoRA) as teacher
   # KL(draft_logits || teacher_logits) on block positions
   # total_loss = ce_loss + kl_weight * kl_loss
5. Add evaluation with speculative decoding metrics — The current accuracy metric is block-wise acceptance rate. Add an eval loop that actually runs speculative decoding (draft → verify) to measure real speedup, using the LoRA model as the drafter and the base model (with LoRA disabled) as the verifier.

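The KL term sketched in item 4 combines with CE as total_loss = ce_loss + kl_weight * kl_loss. A dependency-free sketch of the per-position computation, in the KL(draft ‖ teacher) direction written above (tiny made-up logits; `kl_weight` and `ce_loss` are stand-in values):

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

def kl(draft_logits, teacher_logits):
    """KL(draft || teacher) over one position's vocabulary distribution."""
    p = softmax(draft_logits)
    q = softmax(teacher_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

draft   = [1.0, 0.5, -0.2]
teacher = [0.8, 0.6,  0.0]

# KL is zero iff the distributions match, and positive otherwise.
assert kl(draft, draft) < 1e-12
assert kl(draft, teacher) > 0.0

kl_weight = 0.1   # illustrative weight
ce_loss = 1.7     # stand-in CE value
total_loss = ce_loss + kl_weight * kl(draft, teacher)
```

In a real implementation the teacher pass would be run under `no_grad` with the LoRA adapters disabled, so only the draft distribution receives gradients from this term.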
Phase 3: Integration with train_dflash.py style features

6. Port random anchor sampling — train_dflash.py has --random-anchor (L147–156) which samples diverse anchor positions per sequence. This is missing from the LoRA variant and would improve training diversity. Port _sample_anchor_positions and _build_blocks_from_anchors from OnlineDFlashModel to OnlineDFlashLoRAModel.
7. Port chunked lm_head — The LoRA variant materializes full [bsz, seq_len, vocab_size] logits. For Qwen3-8B's 152K vocab, this is ~1.2GB per sample at seq_len=2048. Port the _chunked_lm_loss from OnlineDFlashModel for memory efficiency.
8. Add tensor parallelism support — The LoRA script currently forces tp_size=1. For Qwen3-8B on multi-GPU setups, add TP support to shard the base model across GPUs while keeping LoRA params on each rank.

Minimal starting point

If you want to just get training running immediately, the existing train_dflash_lora.py + configs/qwen3-8b-dflash-lora.json should work out of the box:

torchrun --nproc_per_node=4 scripts/train_dflash_lora.py \
  --model-path Qwen/Qwen3-8B \
  --lora-config configs/qwen3-8b-dflash-lora.json \
  --train-data-path /path/to/data.jsonl \
  --output-dir ./outputs/qwen3-8b-dflash-lora \
  --block-size 16 \
  --gradient-checkpointing \
  --learning-rate 2e-4 \
  --batch-size 1 \
  --max-length 2048 \
  --num-epochs 3

Want me to start implementing any of these phases? I'd suggest starting with Phase 1 (noise schedule + MLP targets) since that's the most impactful for the diffusion/AR fusion research angle.

sglang/.codespellrc
ADDED
@@ -0,0 +1,3 @@

[codespell]
ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS
skip = *.json,*.jsonl,*.patch,*.txt

sglang/.editorconfig
ADDED
@@ -0,0 +1,25 @@

# https://editorconfig.org/

root = true

[*]
charset = utf-8
end_of_line = lf
indent_style = space
indent_size = 4
trim_trailing_whitespace = true
insert_final_newline = true

[*.{json,yaml,yml}]
indent_size = 2

[*.md]
indent_size = 2
x-soft-wrap-text = true

[*.rst]
indent_size = 4
x-soft-wrap-text = true

[Makefile]
indent_style = tab

sglang/.isort.cfg
ADDED
@@ -0,0 +1,3 @@

[settings]
profile=black
known_first_party=sglang

sglang/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,83 @@

default_stages: [pre-commit, pre-push, manual]
exclude: ^(python/sglang/multimodal_gen/csrc|python/sglang/jit_kernel/flash_attention/cute)

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v6.0.0
    hooks:
      - id: check-symlinks
      - id: destroyed-symlinks
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
        args: [--allow-multiple-documents]
      - id: check-toml
      - id: check-ast
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: check-shebang-scripts-are-executable
      - id: detect-private-key
        exclude: ^sgl-model-gateway/tests/.*_test\.rs$
      - id: debug-statements
      - id: no-commit-to-branch
  - repo: https://github.com/PyCQA/isort
    rev: 7.0.0
    hooks:
      - id: isort
        exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$'
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.15.1
    hooks:
      - id: ruff
        args:
          - --select=F401,F821
          - --fix
        files: ^(benchmark/|docs/|examples/|python/sglang/|sgl-model-gateway/py_*|test/)
        exclude: |
          (?x)^(
              .*/__init__\.py$|
              .*\.ipynb$|
              python/sglang/srt/grpc/.*_pb2\.py$|
              python/sglang/srt/grpc/.*_pb2_grpc\.py$|
              python/sglang/srt/grpc/.*_pb2\.pyi$|
              python/sglang/srt/grpc/.*_pb2_grpc\.pyi$|
          )$
  - repo: https://github.com/psf/black
    rev: 26.1.0
    hooks:
      - id: black-jupyter
        exclude: '^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$'
  - repo: https://github.com/codespell-project/codespell
    rev: v2.4.1
    hooks:
      - id: codespell
        args: ['--config', '.codespellrc']
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v20.1.7
    hooks:
      - id: clang-format
        types_or: [c++, cuda]
        args: [--style=file, --verbose]
  - repo: https://github.com/kynan/nbstripout
    rev: 0.9.0
    hooks:
      - id: nbstripout
        args:
          - '--keep-output'
          - '--extra-keys=metadata.kernelspec metadata.language_info.version'
  - repo: local
    hooks:
      - id: check-chinese-characters
        name: check chinese characters in multimodal_gen
        entry: >-
          python3 -c 'import sys, re; p=re.compile(r"[\u4e00-\u9fff]"); ec=0; [ ([(print(f"{f}:{i+1}: {l.strip()}") or (ec:=1)) for i,l in enumerate(open(f, "r", encoding="utf-8", errors="ignore")) if p.search(l)]) for f in sys.argv[1:] ]; sys.exit(ec)'
        language: system
        files: ^python/sglang/multimodal_gen/.*
        exclude: ^(python/sglang/multimodal_gen/configs/sample|python/sglang/multimodal_gen/apps/ComfyUI_SGLDiffusion/workflows|python/sglang/multimodal_gen/runtime/pipelines_core/stages/model_specific_stages)(/|$)
        types_or: [python, markdown, json, text]
      - id: sort-ci-permissions
        name: sort CI_PERMISSIONS.json
        entry: python3 .github/update_ci_permission.py --sort-only
        language: system
        files: ^\.github/CI_PERMISSIONS\.json$
        pass_filenames: false

sglang/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,128 @@

# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

sglang/LICENSE
ADDED
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2023-2024 SGLang Team

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
sglang/README.md
ADDED
<div align="center" id="sglangtop">
<img src="https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" alt="logo" width="400" margin="10px"></img>

[PyPI](https://pypi.org/project/sglang)
[License](https://github.com/sgl-project/sglang/tree/main/LICENSE)
[Open Issues](https://github.com/sgl-project/sglang/issues)
[Resolved Issues](https://github.com/sgl-project/sglang/issues)
[Ask DeepWiki](https://deepwiki.com/sgl-project/sglang)

</div>

--------------------------------------------------------------------------------

<p align="center">
<a href="https://lmsys.org/blog/"><b>Blog</b></a> |
<a href="https://docs.sglang.io/"><b>Documentation</b></a> |
<a href="https://roadmap.sglang.io/"><b>Roadmap</b></a> |
<a href="https://slack.sglang.io/"><b>Join Slack</b></a> |
<a href="https://meet.sglang.io/"><b>Weekly Dev Meeting</b></a> |
<a href="https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides"><b>Slides</b></a>
</p>

## News
- [2026/01] 🔥 SGLang Diffusion accelerates video and image generation ([blog](https://lmsys.org/blog/2026-01-16-sglang-diffusion/)).
- [2025/12] SGLang provides day-0 support for the latest open models ([MiMo-V2-Flash](https://lmsys.org/blog/2025-12-16-mimo-v2-flash/), [Nemotron 3 Nano](https://lmsys.org/blog/2025-12-15-run-nvidia-nemotron-3-nano/), [Mistral Large 3](https://github.com/sgl-project/sglang/pull/14213), [LLaDA 2.0 Diffusion LLM](https://lmsys.org/blog/2025-12-19-diffusion-llm/), [MiniMax M2](https://lmsys.org/blog/2025-11-04-miminmax-m2/)).
- [2025/10] 🔥 SGLang now runs natively on TPU with the SGLang-Jax backend ([blog](https://lmsys.org/blog/2025-10-29-sglang-jax/)).
- [2025/09] Deploying DeepSeek on GB200 NVL72 with PD and Large Scale EP (Part II): 3.8x Prefill, 4.8x Decode Throughput ([blog](https://lmsys.org/blog/2025-09-25-gb200-part-2/)).
- [2025/09] SGLang Day 0 Support for DeepSeek-V3.2 with Sparse Attention ([blog](https://lmsys.org/blog/2025-09-29-deepseek-V32/)).
- [2025/08] SGLang x AMD SF Meetup on 8/22: Hands-on GPU workshop, tech talks by AMD/xAI/SGLang, and networking ([Roadmap](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_sglang_roadmap.pdf), [Large-scale EP](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_sglang_ep.pdf), [Highlights](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_highlights.pdf), [AITER/MoRI](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_aiter_mori.pdf), [Wave](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/amd_meetup_wave.pdf)).

<details>
<summary>More</summary>

- [2025/11] SGLang Diffusion accelerates video and image generation ([blog](https://lmsys.org/blog/2025-11-07-sglang-diffusion/)).
- [2025/10] PyTorch Conference 2025 SGLang Talk ([slide](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/sglang_pytorch_2025.pdf)).
- [2025/10] SGLang x Nvidia SF Meetup on 10/2 ([recap](https://x.com/lmsysorg/status/1975339501934510231)).
- [2025/08] SGLang provides day-0 support for the OpenAI gpt-oss model ([instructions](https://github.com/sgl-project/sglang/issues/8833)).
- [2025/06] SGLang, the high-performance serving infrastructure powering trillions of tokens daily, has been awarded the third batch of the Open Source AI Grant by a16z ([a16z blog](https://a16z.com/advancing-open-source-ai-through-benchmarks-and-bold-experimentation/)).
- [2025/06] Deploying DeepSeek on GB200 NVL72 with PD and Large Scale EP (Part I): 2.7x Higher Decoding Throughput ([blog](https://lmsys.org/blog/2025-06-16-gb200-part-1/)).
- [2025/05] Deploying DeepSeek with PD Disaggregation and Large-scale Expert Parallelism on 96 H100 GPUs ([blog](https://lmsys.org/blog/2025-05-05-large-scale-ep/)).
- [2025/03] Supercharge DeepSeek-R1 Inference on AMD Instinct MI300X ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1-Part2/README.html)).
- [2025/03] SGLang Joins PyTorch Ecosystem: Efficient LLM Serving Engine ([PyTorch blog](https://pytorch.org/blog/sglang-joins-pytorch/)).
- [2025/02] Unlock DeepSeek-R1 Inference Performance on AMD Instinct™ MI300X GPU ([AMD blog](https://rocm.blogs.amd.com/artificial-intelligence/DeepSeekR1_Perf/README.html)).
- [2025/01] SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html), [10+ other companies](https://x.com/lmsysorg/status/1887262321636221412)).
- [2024/12] v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
- [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
- [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
- [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
- [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
- [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).
- [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).

</details>

## About
SGLang is a high-performance serving framework for large language models and multimodal models.
It is designed to deliver low-latency and high-throughput inference across a wide range of setups, from a single GPU to large distributed clusters.
Its core features include:

- **Fast Runtime**: Provides efficient serving with RadixAttention for prefix caching, a zero-overhead CPU scheduler, prefill-decode disaggregation, speculative decoding, continuous batching, paged attention, tensor/pipeline/expert/data parallelism, structured outputs, chunked prefill, quantization (FP4/FP8/INT4/AWQ/GPTQ), and multi-LoRA batching.
- **Broad Model Support**: Supports a wide range of language models (Llama, Qwen, DeepSeek, Kimi, GLM, GPT, Gemma, Mistral, etc.), embedding models (e5-mistral, gte, mcdse), reward models (Skywork), and diffusion models (WAN, Qwen-Image), with easy extensibility for adding new models. Compatible with most Hugging Face models and OpenAI APIs.
- **Extensive Hardware Support**: Runs on NVIDIA GPUs (GB200/B300/H100/A100/Spark), AMD GPUs (MI355/MI300), Intel Xeon CPUs, Google TPUs, Ascend NPUs, and more.
- **Active Community**: SGLang is open-source and supported by a vibrant community with widespread industry adoption, powering over 400,000 GPUs worldwide.
- **RL & Post-Training Backbone**: SGLang is a proven rollout backend, with native RL integrations and adoption by well-known post-training frameworks such as [**AReaL**](https://github.com/inclusionAI/AReaL), [**Miles**](https://github.com/radixark/miles), [**slime**](https://github.com/THUDM/slime), [**Tunix**](https://github.com/google/tunix), [**verl**](https://github.com/volcengine/verl), and more.

## Getting Started
- [Install SGLang](https://docs.sglang.io/get_started/install.html)
- [Quick Start](https://docs.sglang.io/basic_usage/send_request.html)
- [Backend Tutorial](https://docs.sglang.io/basic_usage/openai_api_completions.html)
- [Frontend Tutorial](https://docs.sglang.io/references/frontend/frontend_tutorial.html)
- [Contribution Guide](https://docs.sglang.io/developer_guide/contribution_guide.html)

## Benchmark and Performance
Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/), [Large-scale expert parallelism](https://lmsys.org/blog/2025-05-05-large-scale-ep/), [GB200 rack-scale parallelism](https://lmsys.org/blog/2025-09-25-gb200-part-2/).

## Adoption and Sponsorship
SGLang has been deployed at large scale, generating trillions of tokens in production each day. It is trusted and adopted by a wide range of leading enterprises and institutions, including xAI, AMD, NVIDIA, Intel, LinkedIn, Cursor, Oracle Cloud, Google Cloud, Microsoft Azure, AWS, Atlas Cloud, Voltage Park, Nebius, DataCrunch, Novita, InnoMatrix, MIT, UCLA, the University of Washington, Stanford, UC Berkeley, Tsinghua University, Jam & Tea Studios, Baseten, and other major technology organizations across North America and Asia.
As an open-source LLM inference engine, SGLang has become the de facto industry standard, with deployments running on over 400,000 GPUs worldwide.
SGLang is currently hosted under the non-profit open-source organization [LMSYS](https://lmsys.org/about/).

<img src="https://raw.githubusercontent.com/sgl-project/sgl-learning-materials/refs/heads/main/slides/adoption.png" alt="adoption" width="800" margin="10px"></img>

## Contact Us
For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at sglang@lmsys.org.

## Acknowledgment
We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).
syxin_old/DFLASH_LORA_INJECT_FIXES.md
ADDED
# DFlash LoRA Inject Fix Plan

## Background

Evaluation after DFlash LoRA inject training was extremely poor:
- GSM8K: 5.05× → **1.04×** (baseline → LoRA inject)
- HumanEval: 5.06× → **0.98×**
- MT-Bench: 2.70× → **0.85×**

Root causes: the LoRA weights were saved in the wrong format, and the draft model forward pass is inefficient.

---

## Fix 1: LoRA weight saving (most critical)

### File
`Specforge/scripts/train_dflash_lora_inject.py`

### Problem
`save_checkpoint()` (L292-306) manually extracts the state_dict keys containing `"lora_"` and saves them as `adapter_model.safetensors`, then calls `peft_config["default"].save_pretrained()`, which only writes the LoraConfig JSON.

PEFT's `PeftModel.from_pretrained()`, however, expects:
1. **A standard `adapter_config.json`** (not a directly serialized LoraConfig)
2. **A different key naming convention** (PEFT handles the `base_model.model.` prefix internally)

As a result, loading during evaluation emits a flood of **"Found missing adapter keys"** warnings, and the LoRA weights are never actually loaded.

### Change

**Delete L295-306:**
```python
lora_state_dict = {
    k: v for k, v in module.draft_model.model.state_dict().items()
    if "lora_" in k
}

try:
    from safetensors.torch import save_file as safetensors_save
    safetensors_save(lora_state_dict, os.path.join(save_dir, "adapter_model.safetensors"))
except (ImportError, Exception):
    torch.save(lora_state_dict, os.path.join(save_dir, "adapter_model.bin"))

draft_model.model.peft_config["default"].save_pretrained(save_dir)
```

**Replace with:**
```python
# Use PEFT native save which handles key naming and adapter_config.json correctly
module.draft_model.model.save_pretrained(save_dir)
```

`PeftModel.save_pretrained()` correctly:
- handles key prefix mapping automatically
- generates a standard `adapter_config.json`
- saves `adapter_model.safetensors`

---
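Fix 1's failure mode can be reproduced without any model at all. Below is a minimal pure-Python sketch; the key names are hypothetical, modeled on PEFT's `base_model.model.` prefix and `default` adapter-name conventions, and illustrate why the manually filtered keys never match what `from_pretrained()` looks for:

```python
# Keys as they appear in the wrapped model's state_dict: the active adapter
# name "default" is embedded inside every LoRA key.
raw_state_dict = {
    "base_model.model.layers.0.self_attn.q_proj.lora_A.default.weight": "A0",
    "base_model.model.layers.0.self_attn.q_proj.lora_B.default.weight": "B0",
    "base_model.model.layers.0.self_attn.q_proj.base_layer.weight": "W0",
}

# The buggy save path: filter on "lora_" and dump the keys verbatim.
manual_save = {k: v for k, v in raw_state_dict.items() if "lora_" in k}

# What the loader looks for in adapter_model.safetensors: the same keys,
# but with the ".default" adapter name stripped out.
expected_keys = {
    "base_model.model.layers.0.self_attn.q_proj.lora_A.weight",
    "base_model.model.layers.0.self_attn.q_proj.lora_B.weight",
}

# None of the manually saved keys match, so every adapter key is "missing"
# and the LoRA weights are silently left unloaded.
missing = expected_keys - set(manual_save)
assert missing == expected_keys

# PEFT's own save_pretrained() performs this renaming itself:
peft_style_save = {k.replace(".default.", "."): v for k, v in manual_save.items()}
assert set(peft_style_save) == expected_keys
```

This is exactly the "Found missing adapter keys" symptom: the weights are all present on disk, just under names the loader never queries.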

## Fix 2: Draft model forward optimization (performance + code clarity)

### File
`Specforge/specforge/modeling/draft/dflash_lora_inject.py`

### Problem
`_forward_with_injection()` currently recomputes, **inside the per-layer for loop**:
- `rotary_emb(layer_input, extended_pos)`: once per layer, i.e. 36 redundant computations
- the extended attention mask: an O(seq_len²) mask rebuilt for every layer
- `extended_pos = torch.cat([target_pos, position_ids])`: re-concatenated for every layer

### Change
Hoist the computation of position_embeddings, extended_mask, and extended_pos **above the for loop** (compute them once):

```python
def _forward_with_injection(self, input_ids, attention_mask, target_hidden_states,
                            position_ids=None, output_hidden_states=False, context_len=0):
    # ... (get base_model, embed_tokens, layers, norm, lm_head)

    hidden_states = embed_tokens(input_ids)
    bsz, seq_len, hidden_dim = hidden_states.shape
    ctx_len = target_hidden_states[0].shape[1] if target_hidden_states else 0
    full_seq_len = ctx_len + seq_len

    # ── Pre-compute position embeddings ONCE ──
    target_pos = torch.arange(ctx_len, device=hidden_states.device)
    draft_pos_ids = position_ids if position_ids is not None else torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).expand(bsz, -1)
    extended_pos = torch.cat([
        target_pos.unsqueeze(0).expand(bsz, -1),
        draft_pos_ids
    ], dim=1)

    position_embeddings = None
    if hasattr(base_model.model, 'rotary_emb'):
        dummy = torch.empty(1, full_seq_len, hidden_dim, device=hidden_states.device, dtype=hidden_states.dtype)
        position_embeddings = base_model.model.rotary_emb(dummy, extended_pos)

    # ── Pre-compute extended attention mask ONCE ──
    extended_mask = attention_mask  # fallback
    if attention_mask is not None and attention_mask.dim() == 4:
        # ... (build ctx_mask_full + draft_mask_full, same logic as before)
        # Key: use block_start (NOT block_start + 1) to prevent leakage
        extended_mask = torch.cat([ctx_mask_full, draft_mask_full], dim=2)

    # ── Layer-by-layer forward ──
    for layer_idx, layer in enumerate(layers):
        if target_hidden_states and layer_idx < len(target_hidden_states):
            target_ctx = target_hidden_states[layer_idx]
            layer_input = torch.cat([target_ctx, hidden_states], dim=1)

            layer_output = layer(
                layer_input,
                attention_mask=extended_mask,
                position_ids=extended_pos,
                position_embeddings=position_embeddings,
            )

            hidden_states = layer_output[0][:, ctx_len:, :] if isinstance(layer_output, tuple) else layer_output[:, ctx_len:, :]
        else:
            layer_output = layer(hidden_states, attention_mask=attention_mask, position_ids=position_ids)
            hidden_states = layer_output[0] if isinstance(layer_output, tuple) else layer_output

    hidden_states = norm(hidden_states)
    if output_hidden_states:
        return hidden_states
    return lm_head(hidden_states)
```

---
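The hoisting is safe because nothing in the loop body changes `extended_pos`, so the rotary table is loop-invariant. A toy pure-Python sketch (a simplified RoPE-style angle table, not the real `rotary_emb`) of the before/after equivalence:

```python
def rope_angles(positions, dim=8, base=10000.0):
    """Toy RoPE angle table: angles[p][i] = p / base**(2i/dim)."""
    inv_freq = [base ** (-2 * i / dim) for i in range(dim // 2)]
    return [[p * f for f in inv_freq] for p in positions]

positions = list(range(16))
num_layers = 36

# Before: the identical table is recomputed inside every layer iteration.
per_layer = [rope_angles(positions) for _ in range(num_layers)]

# After: compute once, reuse in every layer.
hoisted = rope_angles(positions)

# The table depends only on the positions, so hoisting changes nothing
# while saving num_layers - 1 redundant computations.
assert all(layer_tbl == hoisted for layer_tbl in per_layer)
```

The same argument covers `extended_pos` and the extended mask: both are functions of the (fixed) context and draft lengths, not of the layer index.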

## Files that do NOT need changes

| File | Reason |
|------|--------|
| `eval_dflash_lora_inject.py` | Inference logic is correct; positions are already aligned |
| `specforge/core/dflash_lora_inject.py` | The training wrapper's mask (`block_start`, without the +1) is already correct |

---
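The `block_start` remark in the table is the subtle part: a draft block whose first speculated token sits at position `block_start` must only see context strictly before that position. A toy 1-D boolean sketch (a hypothetical simplification of the real 4-D mask) of what the `+ 1` variant would leak:

```python
def ctx_mask(block_start, ctx_len):
    """True = the draft block may attend that context position.
    Toy 1-D version; the real code builds a 4-D mask per batch/head."""
    return [pos < block_start for pos in range(ctx_len)]

ctx_len = 8
block_start = 5   # context positions 0..4 are visible to this draft block

correct = ctx_mask(block_start, ctx_len)
leaky = ctx_mask(block_start + 1, ctx_len)  # the off-by-one variant

assert correct[4] and not correct[5]  # boundary is exclusive
assert leaky[5]                       # position 5 leaks: one step of lookahead
```

One leaked position is enough to corrupt training: the draft learns to rely on a token it will not have at inference time.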
|
| 137 |
+
|
| 138 |
+
## 验证方案
|
| 139 |
+
|
| 140 |
+
1. **LoRA roundtrip**: 保存后 `PeftModel.from_pretrained()` 加载无 warning
|
| 141 |
+
2. **Forward 一致性**: 预计算 vs 每层重算输出相同
|
| 142 |
+
3. **端到端评估**: 重新训练 + `eval_dflash_lora_inject.py` 验证 acceptance length 提升
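Check 2 reduces to an exact-equality comparison between values computed once up front and the same values recomputed inside the loop. The snippet below is a dependency-free stand-in for the shape of such a test; `rope_angle` is a hypothetical RoPE-style helper, not the project's actual rotary-embedding code.

```python
import math

# Stand-in for the forward-consistency check: a table precomputed ONCE must
# match values recomputed step by step. In the real check the compared values
# would be layer outputs, not these toy angles.
def rope_angle(pos: int, dim_idx: int, dim: int = 8) -> float:
    return pos / (10000 ** (2 * dim_idx / dim))

positions = range(6)
precomputed = [[rope_angle(p, d) for d in range(4)] for p in positions]

for p in positions:
    per_step = [rope_angle(p, d) for d in range(4)]
    assert all(math.isclose(a, b) for a, b in zip(precomputed[p], per_step))
```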
|
syxin_old/backup.log
ADDED
The diff for this file is too large to render.
syxin_old/dflash_8gpu_03-31-13:40.log
ADDED
nohup: ignoring input

*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Set TORCH_CUDA_ARCH_LIST to 9.0Set TORCH_CUDA_ARCH_LIST to 9.0Set TORCH_CUDA_ARCH_LIST to 9.0Set TORCH_CUDA_ARCH_LIST to 9.0Set TORCH_CUDA_ARCH_LIST to 9.0




Set TORCH_CUDA_ARCH_LIST to 9.0
Set TORCH_CUDA_ARCH_LIST to 9.0
Set TORCH_CUDA_ARCH_LIST to 9.0
/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
  warnings.warn(
/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
  warnings.warn(
/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
  warnings.warn(
/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
  warnings.warn(
/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
  warnings.warn(
/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
  warnings.warn(
/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
  warnings.warn(
/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/llama3_eagle.py:29: UserWarning: flash_attn is not found, falling back to flex_attention. Please install flash_attn if you want to use the flash attention backend.
  warnings.warn(
INFO:specforge.utils:rank 0: bind to device 0
INFO:specforge.utils:rank 7: bind to device 7
INFO:specforge.utils:rank 0: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
INFO:specforge.utils:rank 2: bind to device 2
INFO:specforge.utils:rank 7: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
INFO:specforge.utils:rank 0: Initialized distributed
INFO:specforge.utils:Loading target model from /workspace/models/Qwen3-8B using hf backend
INFO:specforge.utils:rank 7: Initialized distributed
INFO:specforge.utils:rank 1: bind to device 1
INFO:specforge.utils:rank 2: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
INFO:specforge.utils:rank 2: Initialized distributed
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
INFO:specforge.utils:rank 6: bind to device 6
INFO:specforge.utils:rank 5: bind to device 5
INFO:specforge.utils:rank 1: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
INFO:specforge.utils:rank 1: Initialized distributed
`torch_dtype` is deprecated! Use `dtype` instead!
INFO:specforge.utils:rank 6: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
INFO:specforge.utils:rank 4: bind to device 4
INFO:specforge.utils:rank 5: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
INFO:specforge.utils:rank 6: Initialized distributed
INFO:specforge.utils:rank 5: Initialized distributed
INFO:specforge.utils:rank 4: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!
INFO:specforge.utils:rank 4: Initialized distributed
`torch_dtype` is deprecated! Use `dtype` instead!
INFO:specforge.utils:rank 3: bind to device 3
[rank2]: Traceback (most recent call last):
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 563, in from_name
[rank2]: return next(cls.discover(name=name))
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: StopIteration

[rank2]: During handling of the above exception, another exception occurred:

[rank2]: Traceback (most recent call last):
[rank2]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 723, in <module>
[rank2]: main()
[rank2]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 475, in main
[rank2]: target_model, draft_model = build_models(args)
[rank2]: ^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 265, in build_models
[rank2]: target_model = get_dflash_target_model(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 341, in get_dflash_target_model
[rank2]: return HFDFlashTargetModel.from_pretrained(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 278, in from_pretrained
[rank2]: target_model = AutoModelForCausalLM.from_pretrained(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
[rank2]: return model_class.from_pretrained(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
[rank2]: return func(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
[rank2]: model = cls(config, *model_args, **model_kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 435, in __init__
[rank2]: super().__init__(config)
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2076, in __init__
[rank2]: self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
[rank2]: applicable_attn_implementation = self.get_correct_attn_implementation(
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
[rank2]: self._flash_attn_2_can_dispatch(is_init_check)
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2425, in _flash_attn_2_can_dispatch
[rank2]: flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
[rank2]: return distribution(distribution_name).version
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
[rank2]: return Distribution.from_name(distribution_name)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
[rank2]: raise PackageNotFoundError(name)
[rank2]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
[rank0]: Traceback (most recent call last):
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 563, in from_name
[rank0]: return next(cls.discover(name=name))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: StopIteration

[rank0]: During handling of the above exception, another exception occurred:

[rank0]: Traceback (most recent call last):
[rank0]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 723, in <module>
[rank0]: main()
[rank0]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 475, in main
[rank0]: target_model, draft_model = build_models(args)
[rank0]: ^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 265, in build_models
[rank0]: target_model = get_dflash_target_model(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 341, in get_dflash_target_model
[rank0]: return HFDFlashTargetModel.from_pretrained(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 278, in from_pretrained
[rank0]: target_model = AutoModelForCausalLM.from_pretrained(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
[rank0]: return model_class.from_pretrained(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
[rank0]: model = cls(config, *model_args, **model_kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 435, in __init__
[rank0]: super().__init__(config)
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2076, in __init__
[rank0]: self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
[rank0]: applicable_attn_implementation = self.get_correct_attn_implementation(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
[rank0]: self._flash_attn_2_can_dispatch(is_init_check)
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2425, in _flash_attn_2_can_dispatch
[rank0]: flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
[rank0]: return distribution(distribution_name).version
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
[rank0]: return Distribution.from_name(distribution_name)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
[rank0]: raise PackageNotFoundError(name)
[rank0]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
[rank7]: Traceback (most recent call last):
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 563, in from_name
[rank7]: return next(cls.discover(name=name))
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: StopIteration

[rank7]: During handling of the above exception, another exception occurred:

[rank7]: Traceback (most recent call last):
[rank7]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 723, in <module>
[rank7]: main()
[rank7]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 475, in main
[rank7]: target_model, draft_model = build_models(args)
[rank7]: ^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 265, in build_models
[rank7]: target_model = get_dflash_target_model(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 341, in get_dflash_target_model
[rank7]: return HFDFlashTargetModel.from_pretrained(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 278, in from_pretrained
[rank7]: target_model = AutoModelForCausalLM.from_pretrained(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
[rank7]: return model_class.from_pretrained(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
[rank7]: return func(*args, **kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
[rank7]: model = cls(config, *model_args, **model_kwargs)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 435, in __init__
[rank7]: super().__init__(config)
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2076, in __init__
[rank7]: self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
[rank7]: applicable_attn_implementation = self.get_correct_attn_implementation(
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
[rank7]: self._flash_attn_2_can_dispatch(is_init_check)
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2425, in _flash_attn_2_can_dispatch
[rank7]: flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
[rank7]: return distribution(distribution_name).version
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
[rank7]: return Distribution.from_name(distribution_name)
[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
[rank7]: raise PackageNotFoundError(name)
[rank7]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
[rank1]: Traceback (most recent call last):
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 563, in from_name
[rank1]: return next(cls.discover(name=name))
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: StopIteration

[rank1]: During handling of the above exception, another exception occurred:

[rank1]: Traceback (most recent call last):
[rank1]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 723, in <module>
[rank1]: main()
[rank1]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 475, in main
[rank1]: target_model, draft_model = build_models(args)
[rank1]: ^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 265, in build_models
[rank1]: target_model = get_dflash_target_model(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 341, in get_dflash_target_model
[rank1]: return HFDFlashTargetModel.from_pretrained(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 278, in from_pretrained
[rank1]: target_model = AutoModelForCausalLM.from_pretrained(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
[rank1]: return model_class.from_pretrained(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
[rank1]: model = cls(config, *model_args, **model_kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 435, in __init__
[rank1]: super().__init__(config)
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2076, in __init__
[rank1]: self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
[rank1]: applicable_attn_implementation = self.get_correct_attn_implementation(
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
[rank1]: self._flash_attn_2_can_dispatch(is_init_check)
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2425, in _flash_attn_2_can_dispatch
[rank1]: flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
[rank1]: return distribution(distribution_name).version
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
[rank1]: return Distribution.from_name(distribution_name)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
[rank1]: raise PackageNotFoundError(name)
[rank1]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
INFO:specforge.utils:rank 3: device mesh: DeviceMesh((dp=8, tp=1), device: 'cuda', stride: (1, 1))
INFO:specforge.utils:rank 3: Initialized distributed
`torch_dtype` is deprecated! Use `dtype` instead!
[rank5]: Traceback (most recent call last):
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 563, in from_name
[rank5]: return next(cls.discover(name=name))
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: StopIteration

[rank5]: During handling of the above exception, another exception occurred:

[rank5]: Traceback (most recent call last):
[rank5]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 723, in <module>
[rank5]: main()
[rank5]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 475, in main
[rank5]: target_model, draft_model = build_models(args)
[rank5]: ^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 265, in build_models
[rank5]: target_model = get_dflash_target_model(
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 341, in get_dflash_target_model
[rank5]: return HFDFlashTargetModel.from_pretrained(
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 278, in from_pretrained
[rank5]: target_model = AutoModelForCausalLM.from_pretrained(
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
[rank5]: return model_class.from_pretrained(
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
[rank5]: return func(*args, **kwargs)
[rank5]: ^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
[rank5]: model = cls(config, *model_args, **model_kwargs)
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 435, in __init__
[rank5]: super().__init__(config)
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2076, in __init__
[rank5]: self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
[rank5]: applicable_attn_implementation = self.get_correct_attn_implementation(
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
[rank5]: self._flash_attn_2_can_dispatch(is_init_check)
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2425, in _flash_attn_2_can_dispatch
[rank5]: flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
[rank5]: return distribution(distribution_name).version
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
[rank5]: return Distribution.from_name(distribution_name)
[rank5]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
[rank5]: raise PackageNotFoundError(name)
[rank5]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
[rank4]: Traceback (most recent call last):
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 563, in from_name
[rank4]: return next(cls.discover(name=name))
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: StopIteration

[rank4]: During handling of the above exception, another exception occurred:

[rank4]: Traceback (most recent call last):
[rank4]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 723, in <module>
[rank4]: main()
[rank4]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 475, in main
[rank4]: target_model, draft_model = build_models(args)
[rank4]: ^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 265, in build_models
[rank4]: target_model = get_dflash_target_model(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 341, in get_dflash_target_model
[rank4]: return HFDFlashTargetModel.from_pretrained(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 278, in from_pretrained
[rank4]: target_model = AutoModelForCausalLM.from_pretrained(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
[rank4]: return model_class.from_pretrained(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
[rank4]: return func(*args, **kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
[rank4]: model = cls(config, *model_args, **model_kwargs)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 435, in __init__
[rank4]: super().__init__(config)
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2076, in __init__
[rank4]: self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
[rank4]: applicable_attn_implementation = self.get_correct_attn_implementation(
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
[rank4]: self._flash_attn_2_can_dispatch(is_init_check)
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2425, in _flash_attn_2_can_dispatch
[rank4]: flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
[rank4]: return distribution(distribution_name).version
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
[rank4]: return Distribution.from_name(distribution_name)
[rank4]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
[rank4]: raise PackageNotFoundError(name)
[rank4]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
[rank6]: Traceback (most recent call last):
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 563, in from_name
[rank6]: return next(cls.discover(name=name))
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: StopIteration

[rank6]: During handling of the above exception, another exception occurred:

[rank6]: Traceback (most recent call last):
[rank6]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 723, in <module>
[rank6]: main()
[rank6]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 475, in main
[rank6]: target_model, draft_model = build_models(args)
[rank6]: ^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 265, in build_models
[rank6]: target_model = get_dflash_target_model(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 341, in get_dflash_target_model
[rank6]: return HFDFlashTargetModel.from_pretrained(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 278, in from_pretrained
[rank6]: target_model = AutoModelForCausalLM.from_pretrained(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
[rank6]: return model_class.from_pretrained(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
[rank6]: return func(*args, **kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
[rank6]: model = cls(config, *model_args, **model_kwargs)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 435, in __init__
[rank6]: super().__init__(config)
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2076, in __init__
[rank6]: self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
[rank6]: applicable_attn_implementation = self.get_correct_attn_implementation(
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
[rank6]: self._flash_attn_2_can_dispatch(is_init_check)
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2425, in _flash_attn_2_can_dispatch
[rank6]: flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
[rank6]: return distribution(distribution_name).version
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
[rank6]: return Distribution.from_name(distribution_name)
[rank6]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
[rank6]: raise PackageNotFoundError(name)
[rank6]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
[rank3]: Traceback (most recent call last):
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 563, in from_name
[rank3]: return next(cls.discover(name=name))
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: StopIteration

[rank3]: During handling of the above exception, another exception occurred:

[rank3]: Traceback (most recent call last):
[rank3]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 723, in <module>
[rank3]: main()
[rank3]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 475, in main
[rank3]: target_model, draft_model = build_models(args)
[rank3]: ^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py", line 265, in build_models
[rank3]: target_model = get_dflash_target_model(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 341, in get_dflash_target_model
[rank3]: return HFDFlashTargetModel.from_pretrained(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py", line 278, in from_pretrained
[rank3]: target_model = AutoModelForCausalLM.from_pretrained(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
[rank3]: return model_class.from_pretrained(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
[rank3]: return func(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
[rank3]: model = cls(config, *model_args, **model_kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/models/qwen3/modeling_qwen3.py", line 435, in __init__
[rank3]: super().__init__(config)
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2076, in __init__
[rank3]: self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2686, in _check_and_adjust_attn_implementation
[rank3]: applicable_attn_implementation = self.get_correct_attn_implementation(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2714, in get_correct_attn_implementation
[rank3]: self._flash_attn_2_can_dispatch(is_init_check)
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2425, in _flash_attn_2_can_dispatch
[rank3]: flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 1009, in version
[rank3]: return distribution(distribution_name).version
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 982, in distribution
[rank3]: return Distribution.from_name(distribution_name)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/miniconda3/envs/spec/lib/python3.11/importlib/metadata/__init__.py", line 565, in from_name
[rank3]: raise PackageNotFoundError(name)
[rank3]: importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn
[rank0]:[W331 13:41:10.473818504 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank7]:[W331 13:41:10.548010235 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank7]:[W331 13:41:10.659783753 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
[rank4]:[W331 13:41:11.950068591 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W331 13:41:11.951701730 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank6]:[W331 13:41:11.974675832 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank5]:[W331 13:41:11.997313679 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank3]:[W331 13:41:11.024650758 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W331 13:41:11.024685351 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W331 13:41:11.101274402 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
[rank4]:[W331 13:41:11.102711684 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
[rank6]:[W331 13:41:11.121351120 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
[rank0]:[W331 13:41:11.122852367 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
[rank1]:[W331 13:41:11.167109415 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
[rank3]:[W331 13:41:11.170910568 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
[rank5]:[W331 13:41:11.173578451 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
W0331 13:41:11.393000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 641 closing signal SIGTERM
W0331 13:41:11.393000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 642 closing signal SIGTERM
W0331 13:41:11.394000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 643 closing signal SIGTERM
W0331 13:41:11.394000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 644 closing signal SIGTERM
W0331 13:41:11.394000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 645 closing signal SIGTERM
W0331 13:41:11.395000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 646 closing signal SIGTERM
W0331 13:41:11.395000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 647 closing signal SIGTERM
E0331 13:41:12.401000 540 site-packages/torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 7 (pid: 648) of binary: /workspace/miniconda3/envs/spec/bin/python3
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/run.py", line 940, in <module>
main()
File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/run.py", line 936, in main
run(args)
File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/run.py", line 927, in run
elastic_launch(
File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 156, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspace/miniconda3/envs/spec/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash.py FAILED
------------------------------------------------------------
Failures:
<NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
time : 2026-03-31_13:41:11
host : job-006ce80a7c47-20260302193512-5cd88f7cfc-mlbh9
rank : 7 (local_rank: 7)
exitcode : 1 (pid: 648)
error_file: <N/A>
traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
[W331 13:41:12.379198750 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
|
syxin_old/diagnostic_compare.py
ADDED
@@ -0,0 +1,301 @@
#!/usr/bin/env python3
"""Diagnostic: compare training forward vs eval forward on the same block.

Goal: find where the train-eval mismatch is.
Runs on a single GPU (no distributed).

Usage:
    /workspace/miniconda3/envs/dflash/bin/python3 /workspace/hanrui/syxin_old/diagnostic_compare.py
"""
import sys, os, warnings
import torch
import torch.nn.functional as F
import numpy as np

sys.path.insert(0, "/workspace/hanrui/syxin_old")
sys.path.insert(0, "/workspace/hanrui/syxin_old/Specforge")

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "/workspace/models/Qwen3-8B"
ADAPTER_PATH = "/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject/epoch_3_step_4644"
BLOCK_SIZE = 16
MASK_TOKEN_ID = 151666

device = torch.device("cuda:0")


def main():
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)

    print("Loading target model...")
    target_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16,
        attn_implementation="sdpa", device_map=device, trust_remote_code=True,
    )
    target_model.eval()

    print("Loading draft model (base + LoRA adapter)...")
    draft_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16,
        attn_implementation="sdpa", device_map=device, trust_remote_code=True,
    )
    draft_model = PeftModel.from_pretrained(draft_model, ADAPTER_PATH)
    draft_model = draft_model.merge_and_unload()
    draft_model.eval()

    num_layers = len(draft_model.model.layers)
    draft_layers = draft_model.model.layers
    draft_norm = draft_model.model.norm
    draft_lm_head = draft_model.lm_head
    rotary_emb = draft_model.model.rotary_emb

    # Create a test sequence
    text = "The quick brown fox jumps over the lazy dog. " * 10
    messages = [{"role": "user", "content": text}]
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, enable_thinking=False,
    )
    full_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Extend to get a sequence long enough for multiple blocks
    # Use the target model to generate some tokens
    print(f"Input length: {full_ids.shape[1]}")

    # We'll create a fixed sequence by repeating the prompt
    # Actually, let's use the full_ids as is (it should be ~200 tokens)
    seq_len = full_ids.shape[1]

    # Make it align to block boundaries
    n_blocks = seq_len // BLOCK_SIZE
    effective_len = n_blocks * BLOCK_SIZE
    input_ids = full_ids[:, :effective_len]
    seq_len = effective_len

    print(f"Using {n_blocks} blocks, seq_len = {seq_len}")

    # ═══════════════════════════════════════════════════════
    # TRAINING-STYLE FORWARD
    # ═══════════════════════════════════════════════════════
    print("\n" + "="*60)
    print("TRAINING-STYLE FORWARD")
    print("="*60)

    with torch.no_grad():
        # Step 1: Get target hidden states (full sequence)
        target_output = target_model(
            input_ids,
            output_hidden_states=True,
        )
        # target hidden states: [hidden_states[0], ..., hidden_states[L-1]]
        # hidden_states[k] = input to layer k
        target_hidden_states = [target_output.hidden_states[i] for i in range(num_layers)]

        # Step 2: Prepare noise input (mask non-anchors)
        noise_input = input_ids.clone()
        positions = torch.arange(seq_len, device=device)
        is_anchor = (positions % BLOCK_SIZE) == 0
        noise_input[:, ~is_anchor] = MASK_TOKEN_ID

        # Step 3: Build DFlash mask (draft-to-draft)
        NEG_INF = torch.finfo(torch.bfloat16).min
        block_ids_mask = positions // BLOCK_SIZE
        q_ids = block_ids_mask.unsqueeze(1)
        k_ids = block_ids_mask.unsqueeze(0)
        same_block = (q_ids == k_ids)
        dflash_mask = torch.full((seq_len, seq_len), NEG_INF, device=device, dtype=torch.bfloat16)
        dflash_mask[same_block] = 0.0
        dflash_mask = dflash_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]

        # Step 4: Forward through draft model with injection
        # (Mimicking _forward_with_injection with context_len=0)
        ctx_len = seq_len  # target hidden states span full sequence
        full_seq_len = ctx_len + seq_len

        # Position IDs: [0..N-1, 0..N-1]
        orig_pos = torch.arange(seq_len, device=device)
        extended_pos = torch.cat([orig_pos, orig_pos], dim=0).unsqueeze(0)

        # Extended mask
        # Context-to-context: causal
        ctx_ctx_mask = torch.full((ctx_len, ctx_len), NEG_INF, device=device, dtype=torch.bfloat16)
        ctx_ctx_mask = torch.triu(ctx_ctx_mask, diagonal=1)
        ctx_draft_mask = torch.full((ctx_len, seq_len), NEG_INF, device=device, dtype=torch.bfloat16)
        ctx_mask_full = torch.cat([ctx_ctx_mask, ctx_draft_mask], dim=-1)
        ctx_mask_full = ctx_mask_full.unsqueeze(0).unsqueeze(0)

        draft_mask_full = torch.full((1, 1, seq_len, full_seq_len), NEG_INF, device=device, dtype=torch.bfloat16)

        # Draft-to-target visibility
        draft_pos = torch.arange(seq_len, device=device)
        target_pos = torch.arange(ctx_len, device=device)
        context_len = 0  # training uses context_len=0

        is_ctx = draft_pos < context_len  # all False
        block_id = (draft_pos - context_len).clamp(min=0) // BLOCK_SIZE
        block_start = context_len + block_id * BLOCK_SIZE
        max_visible = torch.where(is_ctx, draft_pos + 1, block_start + 1)
        visible = target_pos.unsqueeze(0) < max_visible.unsqueeze(1)
        draft_mask_full[:, :, :, :ctx_len].masked_fill_(
            visible.unsqueeze(0).unsqueeze(0), 0
        )

        # Draft-to-draft
        draft_mask_full[:, :, :, ctx_len:] = dflash_mask

        extended_mask = torch.cat([ctx_mask_full, draft_mask_full], dim=2)

        # Position embeddings
        dummy = torch.empty(1, full_seq_len, target_hidden_states[0].shape[-1],
                            device=device, dtype=torch.bfloat16)
        position_embeddings = rotary_emb(dummy, extended_pos)

        # Layer-by-layer forward
        hidden_states = draft_model.model.embed_tokens(noise_input)

        for layer_idx in range(num_layers):
            target_ctx = target_hidden_states[layer_idx]
            layer_input = torch.cat([target_ctx, hidden_states], dim=1)

            layer_output = draft_layers[layer_idx](
                layer_input,
                attention_mask=extended_mask,
                position_ids=extended_pos,
                position_embeddings=position_embeddings,
            )
            if isinstance(layer_output, tuple):
                layer_output = layer_output[0]
            hidden_states = layer_output[:, ctx_len:, :]

        hidden_states = draft_norm(hidden_states)
        train_logits = draft_lm_head(hidden_states)  # [1, seq_len, vocab_size]
        train_preds = train_logits.argmax(dim=-1)  # [1, seq_len]

        # Compute training accuracy per block
        print("\nTraining-style per-block accuracy (consecutive correct from pos 1):")
        for b in range(n_blocks):
            start = b * BLOCK_SIZE
            block_preds = train_preds[0, start:start + BLOCK_SIZE]
            block_labels = input_ids[0, start:start + BLOCK_SIZE]
            correct = (block_preds[1:] == block_labels[1:])  # skip anchor
            cumprod = correct.cumprod(dim=0)
            accept_len = cumprod.sum().item()
            print(f"  Block {b} (pos {start}-{start+15}): accept_len={accept_len}, "
                  f"token_acc={correct.float().mean():.3f}")

    # ═══════════════════════════════════════════════════════
    # EVAL-STYLE FORWARD (block by block)
    # ═══════════════════════════════════════════════════════
    print("\n" + "="*60)
    print("EVAL-STYLE FORWARD (block by block)")
    print("="*60)

    with torch.no_grad():
        # Get target hidden states for the full sequence (to use as context)
        # In real eval, these would come from incremental target forwards
        # Here we use the same full-sequence target hidden states

        for b in range(n_blocks):
            start = b * BLOCK_SIZE
            end = start + BLOCK_SIZE

            # Block input: anchor + MASK
            block_ids = input_ids[:, start:end].clone()
            block_ids[:, 1:] = MASK_TOKEN_ID  # mask non-anchors

            # Context: target hidden states for positions 0..start (inclusive)
            # This matches the training visibility: block k sees target 0..k*16
            ctx_end = start + 1  # include anchor position

            if ctx_end == 0:
                # Block 0 with no context — skip (can't have empty context)
                # Actually block 0 has anchor at position 0, so ctx_end = 1
                ctx_end = 1

            ctx_len_eval = ctx_end
            actual_bs = BLOCK_SIZE

            # Build eval mask
            full_len_eval = ctx_len_eval + actual_bs
            eval_mask = torch.full((1, 1, full_len_eval, full_len_eval), NEG_INF,
                                   device=device, dtype=torch.bfloat16)

            # Context-to-context: causal
            if ctx_len_eval > 0:
                ctx_rows = torch.arange(ctx_len_eval, device=device)
                ctx_cols = torch.arange(ctx_len_eval, device=device)
                causal = ctx_cols.unsqueeze(0) <= ctx_rows.unsqueeze(1)
                eval_mask[0, 0, :ctx_len_eval, :ctx_len_eval].masked_fill_(causal, 0)

            # Block-to-context: all visible
            eval_mask[0, 0, ctx_len_eval:, :ctx_len_eval] = 0
            # Block-to-block: bidirectional
            eval_mask[0, 0, ctx_len_eval:, ctx_len_eval:] = 0
|
| 235 |
+
|
| 236 |
+
# Position IDs
|
| 237 |
+
ctx_positions = torch.arange(ctx_len_eval, device=device)
|
| 238 |
+
block_positions = torch.arange(start, start + actual_bs, device=device)
|
| 239 |
+
combined_pos = torch.cat([ctx_positions, block_positions], dim=0).unsqueeze(0)
|
| 240 |
+
|
| 241 |
+
# Position embeddings
|
| 242 |
+
hidden_dim = target_hidden_states[0].shape[-1]
|
| 243 |
+
dummy_eval = torch.empty(1, full_len_eval, hidden_dim, device=device, dtype=torch.bfloat16)
|
| 244 |
+
pos_emb_eval = rotary_emb(dummy_eval, combined_pos)
|
| 245 |
+
|
| 246 |
+
# Draft forward
|
| 247 |
+
draft_hidden = draft_model.model.embed_tokens(block_ids)
|
| 248 |
+
|
| 249 |
+
for layer_idx in range(num_layers):
|
| 250 |
+
target_ctx = target_hidden_states[layer_idx][:, :ctx_end, :]
|
| 251 |
+
combined = torch.cat([target_ctx, draft_hidden], dim=1)
|
| 252 |
+
|
| 253 |
+
layer_output = draft_layers[layer_idx](
|
| 254 |
+
combined,
|
| 255 |
+
attention_mask=eval_mask,
|
| 256 |
+
position_ids=combined_pos,
|
| 257 |
+
position_embeddings=pos_emb_eval,
|
| 258 |
+
)
|
| 259 |
+
if isinstance(layer_output, tuple):
|
| 260 |
+
layer_output = layer_output[0]
|
| 261 |
+
draft_hidden = layer_output[:, ctx_len_eval:, :]
|
| 262 |
+
|
| 263 |
+
draft_hidden = draft_norm(draft_hidden)
|
| 264 |
+
eval_logits = draft_lm_head(draft_hidden) # [1, 16, vocab_size]
|
| 265 |
+
eval_preds = eval_logits.argmax(dim=-1) # [1, 16]
|
| 266 |
+
|
| 267 |
+
# Compare with training
|
| 268 |
+
train_block_preds = train_preds[0, start:end]
|
| 269 |
+
eval_block_preds = eval_preds[0]
|
| 270 |
+
block_labels = input_ids[0, start:end]
|
| 271 |
+
|
| 272 |
+
train_correct = (train_block_preds[1:] == block_labels[1:])
|
| 273 |
+
eval_correct = (eval_block_preds[1:] == block_labels[1:])
|
| 274 |
+
preds_match = (train_block_preds == eval_block_preds)
|
| 275 |
+
|
| 276 |
+
train_accept = train_correct.cumprod(dim=0).sum().item()
|
| 277 |
+
eval_accept = eval_correct.cumprod(dim=0).sum().item()
|
| 278 |
+
|
| 279 |
+
# Check if logits are close
|
| 280 |
+
train_block_logits = train_logits[0, start:end, :]
|
| 281 |
+
eval_block_logits = eval_logits[0, :, :]
|
| 282 |
+
logit_diff = (train_block_logits - eval_block_logits).abs().max().item()
|
| 283 |
+
logit_rmse = ((train_block_logits - eval_block_logits)**2).mean().sqrt().item()
|
| 284 |
+
|
| 285 |
+
print(f" Block {b} (pos {start}-{start+15}):")
|
| 286 |
+
print(f" Train accept_len={train_accept}, Eval accept_len={eval_accept}")
|
| 287 |
+
print(f" Predictions match: {preds_match.sum().item()}/{BLOCK_SIZE}")
|
| 288 |
+
print(f" Logit max_diff={logit_diff:.4f}, rmse={logit_rmse:.6f}")
|
| 289 |
+
|
| 290 |
+
if not preds_match.all():
|
| 291 |
+
mismatch_pos = (~preds_match).nonzero(as_tuple=True)[0]
|
| 292 |
+
for pos in mismatch_pos[:5]: # show first 5 mismatches
|
| 293 |
+
p = pos.item()
|
| 294 |
+
print(f" Mismatch at block pos {p}: "
|
| 295 |
+
f"train={train_block_preds[p].item()}, "
|
| 296 |
+
f"eval={eval_block_preds[p].item()}, "
|
| 297 |
+
f"label={block_labels[p].item()}")
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
if __name__ == "__main__":
|
| 301 |
+
main()
|
syxin_old/eval_alignment_diff.md
ADDED
@@ -0,0 +1,132 @@
# DFlash Eval Alignment Analysis: Your Scripts vs the Official benchmark.py

## Summary of Changes

| # | Issue | Impact | baseline | lora_inject |
|---|-------|--------|----------|-------------|
| 1 | Acceptance-length computation | 🔴 different numbers | ✅ fixed | ✅ fixed |
| 2 | Multi-turn conversation support | 🔴 different mt-bench results | ✅ fixed | ✅ fixed |
| 3 | Sample selection: sequential vs shuffle | 🔴 different subsets | ✅ fixed | ✅ fixed |
| 4 | Only 3 datasets vs the official 10 | 🔴 incomplete coverage | ✅ fixed | ✅ fixed |
| 5 | stop_token_ids check range | 🟡 may stop early or late | ✅ fixed | ✅ fixed |
| 6 | Distributed aggregation | 🟡 loses per-sample granularity | ✅ fixed | ✅ fixed |
| 7 | AR baseline includes draft forward | 🔴 inflated speedup (inject only) | N/A | ✅ fixed |
| 8 | max_new_tokens default | 🟡 | ✅ fixed | ✅ fixed |

---

## Change Details

### 1. Acceptance-length computation

**Problem:** the official code computes a per-sample mean first and then averages those means; yours pools all verify rounds globally.

```python
# ❌ Yours (both scripts)
avg_accept_length = total_accept_sum / total_count

# ✅ Official
tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses])
```

**Fix:** collect each sample's accept_lengths list, compute the per-sample mean first, then take the mean over samples.

---
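The gap between the two statistics is easy to see on toy numbers. The sketch below (pure Python, hypothetical acceptance lengths) mirrors the official mean-of-means against the old global pooling:

```python
from statistics import mean

# hypothetical acceptance lengths per verify round, grouped by sample
per_sample = [[4, 2, 6], [1, 1]]

# global pooling (the old behaviour): every verify round weighs equally,
# so samples with more rounds dominate the statistic
pooled = mean(l for s in per_sample for l in s)

# official style: per-sample mean first, then the mean over samples
tau = mean(mean(s) for s in per_sample)

print(pooled, tau)  # 2.8 vs 2.5
```

Samples that generate longer (more verify rounds) pull the pooled number toward their own acceptance profile, which is exactly why the two scripts disagreed.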

### 2. Multi-turn conversation support

**Problem:** for multi-turn datasets such as mt-bench, the official code generates turn by turn and appends each assistant reply to the context.

```python
# ❌ Yours: only turns[0], single turn
messages = [{"role": "user", "content": prompt}]

# ✅ Official: generate turn by turn
for turn_index, user_content in enumerate(instance["turns"]):
    messages.append({"role": "user", "content": user_content})
    # generate ...
    messages.append({"role": "assistant", "content": output_text})
    responses.append(response)
```

**Fix:** data loading now returns a `{"turns": [...]}` format, and the generation loop runs turn by turn.

---
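As a self-contained illustration of the turn-by-turn loop, the sketch below replaces the real generation call with a hypothetical `fake_generate` stub; only the message-accumulation pattern is the point:

```python
def fake_generate(messages):
    # stand-in for the real spec-decode generation; echoes the last user turn
    return "reply to: " + messages[-1]["content"]

instance = {"turns": ["q1", "q2"]}
messages, responses = [], []

for user_content in instance["turns"]:
    messages.append({"role": "user", "content": user_content})
    output_text = fake_generate(messages)  # generates with the full history
    messages.append({"role": "assistant", "content": output_text})
    responses.append(output_text)

print(len(messages))  # 4: two user turns plus two assistant replies
```

Because each assistant reply is appended before the next user turn, turn 2 is conditioned on the model's own answer to turn 1, matching the official behaviour.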

### 3. Sample selection

**Problem:** the official code shuffles before selecting; yours takes the first N in order.

```python
# ❌ Yours
prompts = prompts[:num_samples]

# ✅ Official
dataset = dataset.shuffle(seed=0).select(range(max_samples))
```

**Fix:** switched to the HF dataset shuffle + select.

---

### 4. Completing the dataset list

**Problem:** 7 datasets were missing: math500, aime24, aime25, mbpp, livecodebench, swe-bench, alpaca.

**Fix:** reuse the official `load_and_process_dataset()` function directly.

---

### 5. stop_token_ids check range

```python
# ❌ Your baseline (checks only up to start)
output_ids[:, num_input_tokens:start]

# ✅ Official (checks everything generated)
output_ids[:, num_input_tokens:]
```

**Fix:** changed to `output_ids[:, num_input_tokens:]`.

---
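The post-hoc truncation that accompanies this check can be sketched in plain Python (a list-based stand-in for the tensor version; the helper name is illustrative):

```python
def truncate_at_stop(generated, stop_token_ids):
    # keep everything up to and including the first stop token;
    # scanning all generated tokens matches the official check range
    for i, tok in enumerate(generated):
        if tok in stop_token_ids:
            return generated[:i + 1]
    return generated

print(truncate_at_stop([5, 9, 151645, 7], {151645}))  # [5, 9, 151645]
```

Checking only a prefix of the generated tokens, as the old code did, can miss a stop token that landed late inside an accepted block and keep decoding past it.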

### 6. Distributed aggregation

**Problem:** you aggregate scalars with all_reduce, losing per-sample granularity.

```python
# ❌ Yours: all_reduce sum/count
dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)

# ✅ Official: gather the full response list
responses = dist.gather(responses, dst=0)
```

**Fix:** gather the per-sample acceptance_lengths + time metrics to rank 0 and compute the statistics there.

---

### 7. AR baseline includes draft forward (lora_inject only)

**Problem:** with block_size=1 the full inject pipeline (including the draft model) still runs, inflating AR time and therefore speedup.

```python
# ❌ Your lora_inject AR baseline
spec_generate_inject(..., block_size=1)  # still runs the draft model layers

# ✅ Should be: pure target autoregressive
ar_generate(target_model, input_ids, ...)  # target only
```

**Fix:** added a pure AR generation function; block_size=1 no longer goes through the draft.

---

### 8. max_new_tokens default

```python
# The official shell script uses 2048 (the Python default is 16384)
# Your default is 2048, consistent with the shell script; kept unchanged
# but now called out explicitly
```
syxin_old/eval_dflash_b16_baseline.py
ADDED
@@ -0,0 +1,354 @@
#!/usr/bin/env python3
"""
Offline evaluation for DFlash-b16 baseline: measure accepted length.
8 GPUs parallel, each GPU loads target + draft independently.

Usage:
    # 8 GPUs
    torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py

    # quick test
    torchrun --nproc_per_node 8 eval_dflash_b16_baseline.py --num-samples 20

    # single GPU
    python3 eval_dflash_b16_baseline.py --benchmarks humaneval
"""
import argparse
import json
import os
import sys
import time
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, DynamicCache

# Add DFlash model path so we can import utils
sys.path.insert(0, "/workspace/models/Qwen3-8B-DFlash-b16")
from utils import extract_context_feature, sample

# ──────────────────────────────────────────────────────────────────
BASE_MODEL = "/workspace/models/Qwen3-8B"
DRAFT_MODEL = "/workspace/models/Qwen3-8B-DFlash-b16"
RESULT_DIR = "/workspace/hanrui/syxin_old/Specforge/benchmarks/results"


# ──────────────────────────────────────────────────────────────────
# Distributed helpers
# ──────────────────────────────────────────────────────────────────
def is_distributed():
    return dist.is_available() and dist.is_initialized()

def get_rank():
    return dist.get_rank() if is_distributed() else 0

def get_world_size():
    return dist.get_world_size() if is_distributed() else 1

def is_main():
    return get_rank() == 0

def print_rank0(*args, **kwargs):
    if is_main():
        print(*args, **kwargs)

def split_list(lst, rank, world_size):
    return [x for i, x in enumerate(lst) if i % world_size == rank]


# ──────────────────────────────────────────────────────────────────
# Prompts
# ──────────────────────────────────────────────────────────────────
def load_prompts(bench_name: str, num_samples: Optional[int] = None) -> List[str]:
    local_paths = {
        "humaneval": "/workspace/hanrui/datasets/humaneval/test.jsonl",
        "mtbench": "/workspace/hanrui/datasets/mtbench/question.jsonl",
        "gsm8k": "/workspace/hanrui/datasets/gsm8k/test.jsonl",
    }
    prompts = []
    path = local_paths.get(bench_name)
    if path and os.path.exists(path):
        with open(path) as f:
            for line in f:
                item = json.loads(line)
                if bench_name == "humaneval":
                    p = f"Write a solution to the following problem and make sure that it passes the tests:\n```python\n{item['prompt']}\n```"
                elif bench_name == "mtbench":
                    p = item.get("turns", [item.get("prompt", "")])[0]
                elif bench_name == "gsm8k":
                    p = item["question"] + "\nPlease reason step by step, and put your final answer within \\boxed{}."
                else:
                    p = str(item)
                prompts.append(p)
    else:
        from datasets import load_dataset
        if bench_name == "humaneval":
            ds = load_dataset("openai/openai_humaneval", split="test")
            prompts = [f"Write a solution to the following problem and make sure that it passes the tests:\n```python\n{x['prompt']}\n```" for x in ds]
        elif bench_name == "mtbench":
            ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
            prompts = [x["prompt"][0] for x in ds]
        elif bench_name == "gsm8k":
            ds = load_dataset("openai/gsm8k", "main", split="test")
            prompts = [x["question"] + "\nPlease reason step by step, and put your final answer within \\boxed{}." for x in ds]
    if num_samples is not None:
        prompts = prompts[:num_samples]
    return prompts


# ──────────────────────────────────────────────────────────────────
# spec_generate with acceptance_lengths returned
# (Same logic as DFlashDraftModel.spec_generate but returns accept lens)
# ──────────────────────────────────────────────────────────────────
@torch.inference_mode()
def spec_generate_b16(
    draft_model,
    target_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, List[int]]:
    """Same as DFlashDraftModel.spec_generate but also returns acceptance_lengths."""
    draft_model.eval()
    device = target_model.device if hasattr(target_model, 'device') else input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens
    block_size = draft_model.block_size
    mask_token_id = draft_model.mask_token_id

    output_ids = torch.full(
        (1, max_length + block_size), mask_token_id,
        dtype=torch.long, device=device,
    )
    position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)

    past_key_values_target = DynamicCache()
    past_key_values_draft = DynamicCache()

    # Prefill
    output = target_model(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values_target,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=True,
    )
    output_ids[:, :num_input_tokens] = input_ids
    output_ids[:, num_input_tokens:num_input_tokens + 1] = sample(output.logits, temperature)
    target_hidden = extract_context_feature(output.hidden_states, draft_model.target_layer_ids)

    # Decode
    acceptance_lengths = []
    start = num_input_tokens
    while start < max_length:
        block_output_ids = output_ids[:, start:start + block_size].clone()
        block_position_ids = position_ids[:, start:start + block_size]
        noise_embedding = target_model.model.embed_tokens(block_output_ids)

        draft_logits = target_model.lm_head(
            draft_model(
                target_hidden=target_hidden,
                noise_embedding=noise_embedding,
                position_ids=position_ids[:, past_key_values_draft.get_seq_length():start + block_size],
                past_key_values=past_key_values_draft,
                use_cache=True,
                is_causal=False,
            )[:, -block_size + 1:, :]
        )
        past_key_values_draft.crop(start)
        block_output_ids[:, 1:] = sample(draft_logits)

        output = target_model(
            block_output_ids,
            position_ids=block_position_ids,
            past_key_values=past_key_values_target,
            use_cache=True,
            output_hidden_states=True,
        )

        posterior = sample(output.logits, temperature)
        acceptance_length = (
            (block_output_ids[:, 1:] == posterior[:, :-1])
            .cumprod(dim=1).sum(dim=1)[0].item()
        )
        output_ids[:, start:start + int(acceptance_length) + 1] = block_output_ids[:, :int(acceptance_length) + 1]
        output_ids[:, start + int(acceptance_length) + 1] = posterior[:, int(acceptance_length)]
        start += int(acceptance_length) + 1
        past_key_values_target.crop(start)
        target_hidden = extract_context_feature(
            output.hidden_states, draft_model.target_layer_ids
        )[:, :int(acceptance_length) + 1, :]
        acceptance_lengths.append(int(acceptance_length) + 1)

        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:start] for sid in stop_token_ids
        ):
            break

    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]

    return output_ids, acceptance_lengths


# ──────────────────────────────────────────────────────────────────
def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--base-model", default=BASE_MODEL)
    p.add_argument("--draft-model", default=DRAFT_MODEL)
    p.add_argument("--max-new-tokens", type=int, default=512)
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--benchmarks", nargs="+", default=["humaneval", "mtbench", "gsm8k"])
    p.add_argument("--num-samples", type=int, default=None)
    p.add_argument("--output-dir", default=RESULT_DIR)
    return p.parse_args()


def main():
    args = parse_args()

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    if world_size > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

    device = f"cuda:{local_rank}"
    rank = get_rank()

    print_rank0(f"Running DFlash-b16 baseline on {world_size} GPU(s)")

    # ── Load models ──
    print_rank0(f"Loading target: {args.base_model}")
    target_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )
    target_model.eval()

    print_rank0(f"Loading DFlash-b16 draft: {args.draft_model}")
    draft_model = AutoModel.from_pretrained(
        args.draft_model,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(device)
    draft_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    stop_token_ids = [tokenizer.eos_token_id]

    print_rank0(f"DFlash-b16: block_size={draft_model.block_size}, "
                f"target_layer_ids={draft_model.target_layer_ids}, "
                f"num_layers={len(draft_model.layers)}")

    # ── Run benchmarks ──
    results = {"model": "Qwen3-8B-DFlash-b16", "type": "baseline",
               "block_size": draft_model.block_size}

    for bench_name in args.benchmarks:
        print_rank0(f"\n{'='*60}")
        print_rank0(f"Benchmark: {bench_name} ({world_size} GPUs)")
        print_rank0(f"{'='*60}")

        all_prompts = load_prompts(bench_name, args.num_samples)
        my_prompts = split_list(all_prompts, rank, world_size)
        print_rank0(f"Total {len(all_prompts)} prompts, ~{len(my_prompts)} per GPU")

        local_accept_lengths = []
        local_tokens = 0
        t0 = time.time()

        iterator = tqdm(my_prompts, desc=f"[GPU{rank}] {bench_name}", unit="sample",
                        disable=(rank != 0))
        for prompt in iterator:
            messages = [{"role": "user", "content": prompt}]
            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)

            output_ids, accept_lens = spec_generate_b16(
                draft_model=draft_model,
                target_model=target_model,
                input_ids=input_ids,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                stop_token_ids=stop_token_ids,
            )

            local_accept_lengths.extend(accept_lens)
            num_gen = output_ids.shape[1] - input_ids.shape[1]
            local_tokens += num_gen

            if rank == 0 and len(local_accept_lengths) > 0:
                avg = sum(local_accept_lengths) / len(local_accept_lengths)
                iterator.set_postfix(accept_len=f"{avg:.2f}", tokens=local_tokens, gen=num_gen)

        elapsed = time.time() - t0

        # ── Gather ──
        if world_size > 1:
            local_sum = torch.tensor(sum(local_accept_lengths), dtype=torch.float64, device=device)
            local_count = torch.tensor(len(local_accept_lengths), dtype=torch.long, device=device)
            local_tok = torch.tensor(local_tokens, dtype=torch.long, device=device)
            dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
            dist.all_reduce(local_count, op=dist.ReduceOp.SUM)
            dist.all_reduce(local_tok, op=dist.ReduceOp.SUM)
            total_accept_sum = local_sum.item()
            total_count = local_count.item()
            total_tokens = local_tok.item()
        else:
            total_accept_sum = sum(local_accept_lengths)
            total_count = len(local_accept_lengths)
            total_tokens = local_tokens

        avg_accept_length = total_accept_sum / max(total_count, 1)
        throughput = total_tokens / elapsed if elapsed > 0 else 0

        print_rank0(f"\n{bench_name} Results:")
        print_rank0(f"  Avg Accept Length: {avg_accept_length:.3f}")
        print_rank0(f"  Total tokens: {total_tokens}")
        print_rank0(f"  Latency: {elapsed:.1f}s")
        print_rank0(f"  Throughput: {throughput:.1f} tok/s (aggregate {world_size} GPUs)")
        print_rank0(f"  Num verify rounds: {total_count}")
        print_rank0(f"  Num samples: {len(all_prompts)}")

        results[bench_name] = {
            "avg_accept_length": avg_accept_length,
            "total_tokens": total_tokens,
            "latency": elapsed,
            "throughput": throughput,
            "num_samples": len(all_prompts),
            "num_verify_rounds": total_count,
            "num_gpus": world_size,
        }

    # ── Save ──
    if is_main():
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = os.path.join(
            args.output_dir,
            f"dflash_b16_baseline_offline_{timestamp}.json",
        )
        with open(result_file, "w") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to: {result_file}")

    if world_size > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
syxin_old/eval_dflash_b16_baseline_changelog.md
ADDED
@@ -0,0 +1,143 @@
# eval_dflash_b16_baseline.py change log

The following issues were fixed by comparison against the official repo's `/workspace/hanrui/dflash/benchmark.py` and `run_benchmark.sh`.

---

## 1. [Critical] Add the `attn_implementation` argument

**Problem**: model loading did not specify an attention implementation, so it fell back to eager attention, with performance far below the paper's.

**Fix**: add `attn_implementation="flash_attention_2"` to both target_model and draft_model (falling back automatically to `"sdpa"` when flash_attn is not installed).

```python
# Before
target_model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.bfloat16, ...)
draft_model = AutoModel.from_pretrained(args.draft_model, torch_dtype=torch.bfloat16, ...)

# After
attn_impl = "flash_attention_2" if installed_flash_attn else "sdpa"
target_model = AutoModelForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.bfloat16, attn_implementation=attn_impl, ...)
draft_model = AutoModel.from_pretrained(args.draft_model, torch_dtype=torch.bfloat16, attn_implementation=attn_impl, ...)
```

---
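The `installed_flash_attn` flag above is presumably the result of an availability probe; a minimal sketch of such a check (the helper name is illustrative):

```python
import importlib.util

def pick_attn_impl() -> str:
    # prefer FlashAttention-2 when the flash_attn package is importable,
    # otherwise fall back to PyTorch's built-in SDPA kernels
    return "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"

attn_impl = pick_attn_impl()
print(attn_impl)
```

`find_spec` only checks importability, so the probe is cheap and does not actually import flash_attn.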

## 2. [Critical] Add `enable_thinking=False`

**Problem**: Qwen3-series models enable thinking mode by default, inserting large `<think>...</think>` blocks into the output. That makes the generated content and length completely different from the paper's test conditions, so the acceptance-length metric is not comparable.

**Fix**: add `enable_thinking=False` to the `apply_chat_template` call.

```python
# Before
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# After
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
```

---

## 3. [Critical] Add an autoregressive baseline (block_size=1) to compute speedup

**Problem**: the original script only ran speculative decoding (block_size=16), never the autoregressive baseline (block_size=1), so the decoding speedup reported in the paper could not be computed.

**Fix**: for every prompt, first run a baseline with `block_size=1`, then speculative decoding with `block_size=block_size`, and compute `speedup = t1 / tb`. `spec_generate_b16` gains a `block_size` parameter to support the block_size=1 mode.

```python
# For every prompt:
_, _, t1 = spec_generate_b16(..., block_size=1, ...)  # autoregressive
output_ids, accept_lens, tb = spec_generate_b16(..., block_size=block_size, ...)  # speculative
# speedup = mean(t1) / mean(tb)
```

---
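With per-sample timings collected, the aggregation reduces to a ratio of means; a sketch with hypothetical per-token timings:

```python
from statistics import mean

# hypothetical per-sample time-per-output-token measurements (seconds)
t1 = [0.030, 0.034, 0.032]  # autoregressive, block_size=1
tb = [0.012, 0.011, 0.013]  # speculative, block_size=16

speedup = mean(t1) / mean(tb)
print(f"speedup = {speedup:.2f}x")
```

Averaging before dividing keeps one unusually slow sample from dominating the ratio the way per-sample speedups averaged directly would.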
| 56 |
+
|
| 57 |
+
## 4. [Critical] 修复计时方式:`time.time()` → `cuda_time()`
|
| 58 |
+
|
| 59 |
+
**问题**: 原脚本使用 `time.time()` 测量 wall clock 时间,不含 CUDA synchronize,GPU 异步执行导致计时不准确。
|
| 60 |
+
|
| 61 |
+
**修改**: 新增 `cuda_time()` 函数(与官方一致),在计时点调用 `torch.cuda.synchronize()` + `time.perf_counter()`。`spec_generate_b16` 内部使用 `cuda_time()` 精确测量 prefill 和 decode 阶段耗时,并返回 `time_per_output_token`。
|
| 62 |
+
|
| 63 |
+
```python
|
| 64 |
+
def cuda_time() -> float:
|
| 65 |
+
torch.cuda.synchronize()
|
| 66 |
+
return time.perf_counter()
|
| 67 |
+
```

---

## 5. [Critical] Add a draft-prefill timing correction

**Problem**: The original script counted the draft model's first prefill towards the decode phase, inflating time_per_token and deflating speedup.

**Fix**: In spec_generate_b16's decode loop, reset `decode_start` once the first draft forward has completed (same as the official `draft_prefill` flag logic).

```python
if draft_prefill:
    draft_prefill = False
    decode_start = cuda_time()  # reset to exclude the draft's first prefill
```

---

## 6. [Important] `max_new_tokens` default 512 → 2048

**Problem**: The original script defaulted to `max_new_tokens=512`, while the official `run_benchmark.sh` uses `2048`. Generations that are too short leave the acceptance-length statistics under-sampled, so the metric is not comparable with the paper.

**Fix**: Change the default to `2048`.

```python
# Before
p.add_argument("--max-new-tokens", type=int, default=512)

# After
p.add_argument("--max-new-tokens", type=int, default=2048)
```

---

## 7. [Important] Add fixed random seeds

**Problem**: The original script set no random seeds, so repeated runs were not reproducible.

**Fix**: Add the same seed setup as the official code at the start of `main()`.

```python
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```

---

## 8. [Important] spec_generate_b16 gains a block_size=1 branch

**Problem**: The original function hard-coded `draft_model.block_size` and always passed `output_hidden_states=True`. When block_size=1 (the autoregressive baseline), the draft model should not be called and hidden states should not be extracted.

**Fix**:
- Add a `block_size` parameter
- Set `output_hidden_states=True` only when `block_size > 1`
- Run the draft-model forward and hidden-state extraction only when `block_size > 1`
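The branching above can be sketched as follows. This is a simplified, framework-free sketch of the control flow only; `run_target` and `run_draft` are hypothetical stand-ins for the actual target/draft forwards, not functions from the real script:

```python
def decode_step(block_size, run_target, run_draft):
    # Hidden states (and the draft forward) are only needed when
    # speculative decoding is active, i.e. block_size > 1.
    need_hidden = block_size > 1

    target_out = run_target(output_hidden_states=need_hidden)
    if need_hidden:
        # Speculative path: the draft proposes a block from target hidden states.
        return run_draft(target_out["hidden"])
    # block_size == 1: pure autoregressive, no draft call at all.
    return [target_out["token"]]

# Toy stand-ins to exercise both paths.
def run_target(output_hidden_states):
    return {"token": 7, "hidden": [0.1, 0.2] if output_hidden_states else None}

def run_draft(hidden):
    return [7, 8, 9]

assert decode_step(1, run_target, run_draft) == [7]
assert decode_step(16, run_target, run_draft) == [7, 8, 9]
```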

---

## 9. [Minor] Output adds speedup and an acceptance histogram

**Fix**: The results output now also reports:
- `Decoding speedup: X.XXx` (the t1/tb ratio)
- `Acceptance length histogram` (the fraction of verification rounds at each acceptance length)

This matches the output format of the official benchmark.py.
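For reference, the two new metrics can be computed from the collected per-prompt timings and acceptance lengths roughly like this (a sketch with illustrative variable names and toy input values, not the script's actual aggregation code):

```python
from collections import Counter

def summarize(t1_list, tb_list, accept_lens):
    # Speedup: mean AR seconds/token over mean speculative seconds/token.
    speedup = (sum(t1_list) / len(t1_list)) / (sum(tb_list) / len(tb_list))
    # Histogram: fraction of verification rounds at each acceptance length.
    counts = Counter(accept_lens)
    total = sum(counts.values())
    hist = {k: counts[k] / total for k in sorted(counts)}
    return speedup, hist

speedup, hist = summarize(
    t1_list=[0.030, 0.032],     # AR seconds/token, one entry per prompt
    tb_list=[0.010, 0.012],     # speculative seconds/token, one entry per prompt
    accept_lens=[3, 5, 5, 16],  # accepted tokens per verification round
)
print(f"Decoding speedup: {speedup:.2f}x")
print("Acceptance length histogram:", hist)
```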

---

## Unchanged

- **Model choice**: Kept Qwen3-8B (instead of the official default Qwen3-4B), since the local model is the 8B one
- **Draft-model loading**: Kept `AutoModel.from_pretrained` (which relies on `trust_remote_code=True`) instead of the official `DFlashDraftModel`, because it needs the remote-code support shipped in the model directory
- **Dataset scope**: Kept the original 3 benchmarks (humaneval/mtbench/gsm8k) rather than expanding to the official 10
syxin_old/eval_dflash_lora_inject.py ADDED
@@ -0,0 +1,660 @@
#!/usr/bin/env python3
"""
Offline evaluation for DFlash-LoRA-Inject: measure accepted length & speedup.
Aligned with official DFlash benchmark.py methodology.

Unlike DFlash-b16 which uses a small 5-layer draft model with fc/hidden_norm,
LoRA-Inject uses a full Qwen3-8B with LoRA adapters that receives target hidden
states via layer-by-layer injection.

Usage:
    conda activate spec

    # 8 GPU parallel (default, all 10 benchmarks)
    torchrun --nproc_per_node 8 eval_dflash_lora_inject.py

    # single GPU
    python3 eval_dflash_lora_inject.py

    # specific checkpoint / benchmark
    torchrun --nproc_per_node 8 eval_dflash_lora_inject.py --ckpt epoch_0_step_1000 --datasets humaneval

    # quick test
    torchrun --nproc_per_node 8 eval_dflash_lora_inject.py --max-samples 20
"""
import argparse
import json
import os
import random
import sys
import time
import warnings
from itertools import chain
from types import SimpleNamespace
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.distributed as dist
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

# Import official dataset loader
sys.path.insert(0, "/workspace/hanrui/dflash")
from model.utils import load_and_process_dataset

# ──────────────────────────────────────────────────────────────────
# Config defaults
# ──────────────────────────────────────────────────────────────────
BASE_MODEL = "/workspace/models/Qwen3-8B"
ADAPTER_ROOT = "/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject"
DEFAULT_CKPT = "epoch_3_step_4644"
MASK_TOKEN_ID = 151669  # Qwen3 <|mask|>
BLOCK_SIZE = 16
RESULT_DIR = "/workspace/hanrui/syxin_old/Specforge/benchmarks/results"

# Official benchmark tasks (from run_benchmark.sh)
OFFICIAL_TASKS = {
    "gsm8k": 128,
    "math500": 128,
    "aime24": 30,
    "aime25": 30,
    "humaneval": 164,
    "mbpp": 128,
    "livecodebench": 128,
    "swe-bench": 128,
    "mt-bench": 80,
    "alpaca": 128,
}


# ──────────────────────────────────────────────────────────────────
# CUDA-synchronised timer (matches official benchmark.py)
# ──────────────────────────────────────────────────────────────────
def cuda_time() -> float:
    torch.cuda.synchronize()
    return time.perf_counter()


def has_flash_attn() -> bool:
    try:
        import flash_attn  # noqa: F401
        return True
    except ImportError:
        print("[WARN] flash_attn not installed, falling back to sdpa.")
        return False


# ──────────────────────────────────────────────────────────────────
# Distributed helpers (mirrors official distributed.py)
# ──────────────────────────────────────────────────────────────────
def dist_init():
    if "RANK" not in os.environ:
        warnings.warn("RANK not set. Skipping distributed init.")
        return
    dist.init_process_group(backend="nccl", init_method="env://")


def dist_rank():
    return int(os.environ.get("RANK", 0))


def dist_size():
    return int(os.environ.get("WORLD_SIZE", 1))


def dist_local_rank():
    return int(os.environ.get("LOCAL_RANK", 0))


def dist_is_main():
    return dist_rank() == 0


def dist_gather(obj, dst=0):
    if not dist.is_initialized():
        return [obj]
    if dist_is_main():
        objs = [None for _ in range(dist_size())]
        dist.gather_object(obj, objs, dst=dst)
        return objs
    else:
        dist.gather_object(obj, dst=dst)
        return None


def print_rank0(*args, **kwargs):
    if dist_is_main():
        print(*args, **kwargs)


# ──────────────────────────────────────────────────────────────────
# Sampling (matches official model/utils.py::sample)
# ──────────────────────────────────────────────────────────────────
def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    bsz, seq_len, vocab_size = logits.shape
    logits = logits.view(-1, vocab_size)
    logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(bsz, seq_len)


# ──────────────────────────────────────────────────────────────────
# Build DFlash attention mask (vectorized, no Python loops)
# ──────────────────────────────────────────────────────────────────
def build_dflash_mask(ctx_len: int, block_size: int, device, dtype=torch.bfloat16):
    """
    Build DFlash attention mask for [context | block] sequence.
    - Context part: standard causal
    - Block part: each token sees all context + all tokens in same block (bidirectional)
    """
    full_len = ctx_len + block_size
    neg_inf = torch.finfo(dtype).min

    mask = torch.full((1, 1, full_len, full_len), neg_inf, device=device, dtype=dtype)

    if ctx_len > 0:
        ctx_rows = torch.arange(ctx_len, device=device)
        ctx_cols = torch.arange(ctx_len, device=device)
        causal = ctx_cols.unsqueeze(0) <= ctx_rows.unsqueeze(1)
        mask[0, 0, :ctx_len, :ctx_len].masked_fill_(causal, 0)

    if ctx_len > 0:
        mask[0, 0, ctx_len:, :ctx_len] = 0
    mask[0, 0, ctx_len:, ctx_len:] = 0

    return mask


# ──────────────────────────────────────────────────────────────────
# Pure autoregressive generation (target model only, no draft)
# Used for AR baseline timing -- avoids inflating AR time with draft overhead.
# ──────────────────────────────────────────────────────────────────
@torch.inference_mode()
def ar_generate(
    target_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 2048,
    mask_token_id: int = MASK_TOKEN_ID,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> SimpleNamespace:
    """
    Pure autoregressive generation using only the target model.
    Mirrors official benchmark.py with block_size=1 (no draft model involved).
    Returns SimpleNamespace matching official dflash_generate output format.
    """
    device = input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens

    output_ids = torch.full(
        (1, max_length + 1), mask_token_id,
        dtype=torch.long, device=device,
    )
    output_ids[:, :num_input_tokens] = input_ids
    position_ids = torch.arange(output_ids.shape[1], device=device).unsqueeze(0)
    past_key_values = DynamicCache()

    # Prefill
    prefill_start = cuda_time()
    output = target_model(
        input_ids,
        position_ids=position_ids[:, :num_input_tokens],
        past_key_values=past_key_values,
        use_cache=True,
        logits_to_keep=1,
        output_hidden_states=False,
    )
    first_token = sample(output.logits, temperature)
    output_ids[:, num_input_tokens:num_input_tokens + 1] = first_token
    time_to_first_token = cuda_time() - prefill_start

    # Decode (autoregressive, one token at a time)
    decode_start = cuda_time()
    start = num_input_tokens

    while start < max_length:
        cur_token = output_ids[:, start:start + 1]
        cur_pos = position_ids[:, start:start + 1]

        output = target_model(
            cur_token,
            position_ids=cur_pos,
            past_key_values=past_key_values,
            use_cache=True,
            output_hidden_states=False,
        )

        next_token = sample(output.logits, temperature)
        start += 1
        output_ids[:, start:start + 1] = next_token
        past_key_values.crop(start)

        # Check stop tokens (matches official: check all generated)
        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:] for sid in stop_token_ids
        ):
            break

    output_ids = output_ids[:, :max_length]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / max(num_output_tokens, 1)

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=[1] * max(num_output_tokens, 0),  # AR: always 1
    )


# ──────────────────────────────────────────────────────────────────
# Core: spec_generate with layer-by-layer injection (KV-cached)
# ──────────────────────────────────────────────────────────────────
@torch.inference_mode()
def spec_generate_inject(
    target_model: nn.Module,
    draft_model: nn.Module,
    input_ids: torch.LongTensor,
    max_new_tokens: int = 2048,
    block_size: int = 16,
    mask_token_id: int = MASK_TOKEN_ID,
    temperature: float = 0.0,
    stop_token_ids: Optional[List[int]] = None,
) -> SimpleNamespace:
    """
    Speculative generation using DFlash-LoRA-Inject inference pattern.
    Returns SimpleNamespace matching official dflash_generate output format.
    """
    device = input_ids.device
    num_input_tokens = input_ids.shape[1]
    max_length = num_input_tokens + max_new_tokens

    draft_layers = draft_model.model.layers
    draft_norm = draft_model.model.norm
    draft_lm_head = draft_model.lm_head
    rotary_emb = draft_model.model.rotary_emb
    num_layers = len(draft_layers)

    output_ids = torch.full(
        (1, max_length + block_size), mask_token_id,
        dtype=torch.long, device=device,
    )
    output_ids[:, :num_input_tokens] = input_ids

    # ── Prefill: target with KV cache + hidden states ──
    prefill_start = cuda_time()
    target_kv = DynamicCache()
    target_output = target_model(
        input_ids,
        past_key_values=target_kv,
        use_cache=True,
        output_hidden_states=True,
    )
    first_token = sample(target_output.logits[:, -1:, :], temperature)
    output_ids[:, num_input_tokens] = first_token.squeeze()

    ctx_hidden_per_layer = [
        target_output.hidden_states[i]
        for i in range(num_layers)
    ]

    time_to_first_token = cuda_time() - prefill_start

    # Decode
    decode_start = cuda_time()
    acceptance_lengths = []
    start = num_input_tokens
    draft_prefill = True

    while start < max_length:
        end = min(start + block_size, max_length)
        actual_block_size = end - start

        block_ids = output_ids[:, start:end].clone()

        # ── FIX: Get anchor's target hidden state before draft forward ──
        # Training: block k's draft tokens see target hidden states at positions
        # 0..k*block_size INCLUSIVE (the anchor position). But in eval, ctx_hidden
        # only covers 0..start-1. We must process the anchor through the target
        # model to get its hidden state, matching training's attention pattern.
        anchor_token = output_ids[:, start:start + 1]
        anchor_pos = torch.tensor([[start]], device=device)
        anchor_output = target_model(
            anchor_token,
            position_ids=anchor_pos,
            past_key_values=target_kv,
            use_cache=True,
            output_hidden_states=True,
        )
        # Save anchor hidden states (one per layer)
        for i in range(num_layers):
            anchor_hs = anchor_output.hidden_states[i]  # [1, 1, hidden_dim]
            ctx_hidden_per_layer[i] = torch.cat([ctx_hidden_per_layer[i], anchor_hs], dim=1)
        # Roll back KV cache: verification will re-process from position start
        target_kv.crop(start)

        # ── Draft: forward with layer-by-layer injection ──
        draft_hidden = draft_model.model.embed_tokens(block_ids)
        ctx_len = ctx_hidden_per_layer[0].shape[1]  # now includes anchor at position start

        dflash_mask = build_dflash_mask(ctx_len, actual_block_size, device)

        # Position IDs: context covers [0..start], block covers [start..start+bs-1]
        # Position 'start' appears twice (in both context and block), matching
        # training where target and draft share the same position IDs.
        ctx_positions = torch.arange(ctx_len, device=device)
        block_positions = torch.arange(start, start + actual_block_size, device=device)
        combined_pos = torch.cat([ctx_positions, block_positions], dim=0).unsqueeze(0)

        dummy_combined = torch.empty(1, ctx_len + actual_block_size, draft_hidden.shape[-1],
                                     device=device, dtype=torch.bfloat16)
        position_embeddings = rotary_emb(dummy_combined, combined_pos)

        for layer_idx in range(num_layers):
            target_ctx = ctx_hidden_per_layer[layer_idx]
            combined = torch.cat([target_ctx, draft_hidden], dim=1)

            layer_output = draft_layers[layer_idx](
                combined,
                attention_mask=dflash_mask,
                position_ids=combined_pos,
                position_embeddings=position_embeddings,
            )
            if isinstance(layer_output, tuple):
                layer_output = layer_output[0]
            draft_hidden = layer_output[:, ctx_len:, :]

        draft_hidden = draft_norm(draft_hidden)
        draft_logits = draft_lm_head(draft_hidden)

        draft_predictions = sample(draft_logits[:, 1:, :], temperature)
        block_ids[:, 1:actual_block_size] = draft_predictions[:, :actual_block_size - 1]

        # Exclude draft's first prefill from decode timing (matches official pattern)
        if draft_prefill:
            draft_prefill = False
            decode_start = cuda_time()

        # ── Verify: target forward on block tokens (with KV cache) ──
        position_ids_block = torch.arange(
            start, start + actual_block_size, device=device
        ).unsqueeze(0)

        target_verify = target_model(
            block_ids,
            position_ids=position_ids_block,
            past_key_values=target_kv,
            use_cache=True,
            output_hidden_states=True,
        )
        target_tokens = sample(target_verify.logits, temperature)

        # Acceptance
        matches = (block_ids[:, 1:actual_block_size] == target_tokens[:, :actual_block_size - 1])
        acceptance_length = int(matches.cumprod(dim=1).sum(dim=1)[0].item())

        output_ids[:, start:start + acceptance_length + 1] = block_ids[:, :acceptance_length + 1]
        output_ids[:, start + acceptance_length + 1] = target_tokens[:, acceptance_length]

        accepted_end = start + acceptance_length + 1
        target_kv.crop(accepted_end)

        # Remove the anchor hidden state we added above (it's position start);
        # instead, save the verification's hidden states which include the anchor
        # and accepted tokens computed with the correct full KV context.
        for i in range(num_layers):
            # Drop the anchor we appended earlier (last entry in ctx_hidden)
            ctx_hidden_per_layer[i] = ctx_hidden_per_layer[i][:, :-1, :]
            # Add verification hidden states for accepted positions
            new_hidden = target_verify.hidden_states[i][:, :acceptance_length + 1, :]
            ctx_hidden_per_layer[i] = torch.cat([ctx_hidden_per_layer[i], new_hidden], dim=1)

        start += acceptance_length + 1
        acceptance_lengths.append(acceptance_length + 1)

        # Official: check ALL generated tokens
        if stop_token_ids is not None and any(
            sid in output_ids[:, num_input_tokens:] for sid in stop_token_ids
        ):
            break

    output_ids = output_ids[:, :min(start, max_length)]
    output_ids = output_ids[:, output_ids[0] != mask_token_id]
    if stop_token_ids is not None:
        stop_t = torch.tensor(stop_token_ids, device=output_ids.device)
        stop_idx = torch.isin(output_ids[0][num_input_tokens:], stop_t).nonzero(as_tuple=True)[0]
        if stop_idx.numel() > 0:
            output_ids = output_ids[:, :num_input_tokens + stop_idx[0] + 1]

    num_output_tokens = output_ids.shape[1] - num_input_tokens
    total_decode_time = cuda_time() - decode_start
    time_per_output_token = total_decode_time / max(num_output_tokens, 1)

    return SimpleNamespace(
        output_ids=output_ids,
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        time_to_first_token=time_to_first_token,
        time_per_output_token=time_per_output_token,
        acceptance_lengths=acceptance_lengths,
    )


# ──────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────
def parse_args():
    p = argparse.ArgumentParser(description="Offline eval for DFlash-LoRA-Inject (aligned with official)")
    p.add_argument("--base-model", default=BASE_MODEL)
    p.add_argument("--adapter-root", default=ADAPTER_ROOT)
    p.add_argument("--ckpt", default=DEFAULT_CKPT, help="Checkpoint folder name")
    p.add_argument("--merged-path",
                   default="/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged",
                   help="Path to pre-merged model. If None, will merge on the fly.")
    p.add_argument("--block-size", type=int, default=BLOCK_SIZE)
    p.add_argument("--max-new-tokens", type=int, default=2048,
                   help="Max new tokens per turn (official shell uses 2048)")
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument("--datasets", nargs="+", default=list(OFFICIAL_TASKS.keys()),
                   help="Benchmarks to run (default: all 10 official tasks)")
    p.add_argument("--max-samples", type=int, default=None,
                   help="Override max samples per dataset (None = use official per-task counts)")
    p.add_argument("--output-dir", default=RESULT_DIR)
    return p.parse_args()


def main():
    args = parse_args()

    # Fix random seeds (matches official)
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # ── Init distributed ──
    dist_init()
    torch.cuda.set_device(dist_local_rank())
    device = torch.device(f"cuda:{dist_local_rank()}")

    print_rank0(f"Running on {dist_size()} GPU(s)")

    # Detect flash_attn (only for target model; draft needs sdpa for custom DFlash mask)
    installed_flash_attn = has_flash_attn()
    target_attn_impl = "flash_attention_2" if installed_flash_attn else "sdpa"
    draft_attn_impl = "sdpa"  # DFlash injection uses custom attention mask
    print_rank0(f"Using attn_implementation: target={target_attn_impl}, draft={draft_attn_impl}")

    # ── Load models ──
    print_rank0(f"Loading target model: {args.base_model}")
    target_model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        torch_dtype=torch.bfloat16,
        attn_implementation=target_attn_impl,
        device_map=device,
        trust_remote_code=True,
    )
    target_model.eval()

    if args.merged_path and os.path.isdir(args.merged_path):
        print_rank0(f"Loading pre-merged draft model: {args.merged_path}")
        draft_model = AutoModelForCausalLM.from_pretrained(
            args.merged_path,
            torch_dtype=torch.bfloat16,
            attn_implementation=draft_attn_impl,
            device_map=device,
            trust_remote_code=True,
        )
    else:
        adapter_path = os.path.join(args.adapter_root, args.ckpt)
        print_rank0(f"Loading base + LoRA adapter: {adapter_path}")
        draft_model = AutoModelForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=torch.bfloat16,
            attn_implementation=draft_attn_impl,
            device_map=device,
            trust_remote_code=True,
        )
        draft_model = PeftModel.from_pretrained(draft_model, adapter_path)
        draft_model = draft_model.merge_and_unload()
    draft_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(args.base_model, trust_remote_code=True)
    stop_token_ids = [tokenizer.eos_token_id]

    block_size = args.block_size

    # ── Run benchmarks ──
    all_results = {"model": f"dflash-lora-inject/{args.ckpt}", "block_size": block_size}

    for dataset_name in args.datasets:
        print_rank0(f"\n{'=' * 60}")
        print_rank0(f"Benchmark: {dataset_name} ({dist_size()} GPUs)")
        print_rank0(f"{'=' * 60}")

        # Load dataset using official loader
        dataset = load_and_process_dataset(dataset_name)

        # Sample selection: official uses shuffle(seed=0).select()
        max_samples = args.max_samples if args.max_samples is not None else OFFICIAL_TASKS.get(dataset_name)
        if max_samples is not None and len(dataset) > max_samples:
            dataset = dataset.shuffle(seed=0).select(range(max_samples))

        print_rank0(f"Total {len(dataset)} samples, distributed across {dist_size()} GPUs")

        responses = []
        indices = range(dist_rank(), len(dataset), dist_size())

        iterator = tqdm(indices, desc=f"[GPU{dist_rank()}] {dataset_name}",
                        unit="sample", disable=not dist_is_main())

        for idx in iterator:
            instance = dataset[idx]

            # Multi-turn support (matches official benchmark.py)
            messages = []
            for turn_index, user_content in enumerate(instance["turns"]):
                messages.append({"role": "user", "content": user_content})
                input_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True,
                    enable_thinking=False,
                )
                input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

                response = {}

                # AR baseline: pure target-only autoregressive (no draft overhead)
                response[1] = ar_generate(
                    target_model=target_model,
                    input_ids=input_ids,
                    max_new_tokens=args.max_new_tokens,
                    mask_token_id=MASK_TOKEN_ID,
                    temperature=args.temperature,
                    stop_token_ids=stop_token_ids,
                )

                # Speculative: DFlash-LoRA-Inject
                response[block_size] = spec_generate_inject(
                    target_model=target_model,
                    draft_model=draft_model,
                    input_ids=input_ids,
                    max_new_tokens=args.max_new_tokens,
                    block_size=block_size,
                    mask_token_id=MASK_TOKEN_ID,
                    temperature=args.temperature,
                    stop_token_ids=stop_token_ids,
                )

                # Append assistant response for multi-turn context
                spec_response = response[block_size]
                generated_ids = spec_response.output_ids[0, spec_response.num_input_tokens:]
                output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
                messages.append({"role": "assistant", "content": output_text})
                responses.append(response)

            if dist_is_main() and responses:
                recent_tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses[-5:]])
|
| 609 |
+
iterator.set_postfix(accept_len=f"{recent_tau:.2f}")
|
| 610 |
+
|
| 611 |
+
# ── Gather to rank 0 (matches official) ──
|
| 612 |
+
if dist_size() > 1:
|
| 613 |
+
gathered = dist_gather(responses, dst=0)
|
| 614 |
+
if not dist_is_main():
|
| 615 |
+
continue
|
| 616 |
+
responses = list(chain(*gathered))
|
| 617 |
+
elif not dist_is_main():
|
| 618 |
+
continue
|
| 619 |
+
|
| 620 |
+
# ── Compute metrics (exact official formulas) ──
|
| 621 |
+
t1 = np.mean([r[1].time_per_output_token for r in responses])
|
| 622 |
+
tb = np.mean([r[block_size].time_per_output_token for r in responses])
|
| 623 |
+
speedup = t1 / tb if tb > 0 else 0
|
| 624 |
+
|
| 625 |
+
# Acceptance length: per-sample mean, then mean of means (official)
|
| 626 |
+
tau = np.mean([np.mean(r[block_size].acceptance_lengths) for r in responses])
|
| 627 |
+
|
| 628 |
+
# Histogram
|
| 629 |
+
acceptance_lengths = list(chain(*[r[block_size].acceptance_lengths for r in responses]))
|
| 630 |
+
histogram = [acceptance_lengths.count(b) / len(acceptance_lengths) for b in range(block_size + 1)]
|
| 631 |
+
|
| 632 |
+
print_rank0(f"\n{dataset_name} Results:")
|
| 633 |
+
print_rank0(f" Decoding speedup: {speedup:.2f}x")
|
| 634 |
+
print_rank0(f" Average Acceptance length: {tau:.2f}")
|
| 635 |
+
print_rank0(f" Acceptance length histogram: {[f'{x * 100:.1f}%' for x in histogram]}")
|
| 636 |
+
print_rank0(f" Num responses: {len(responses)}")
|
| 637 |
+
|
| 638 |
+
all_results[dataset_name] = {
|
| 639 |
+
"decoding_speedup": speedup,
|
| 640 |
+
"avg_accept_length": tau,
|
| 641 |
+
"acceptance_histogram": histogram,
|
| 642 |
+
"num_responses": len(responses),
|
| 643 |
+
"num_gpus": dist_size(),
|
| 644 |
+
}
|
| 645 |
+
|
| 646 |
+
# ── Save results ──
|
| 647 |
+
if dist_is_main():
|
| 648 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 649 |
+
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
| 650 |
+
result_file = os.path.join(
|
| 651 |
+
args.output_dir,
|
| 652 |
+
f"dflash_lora_inject_offline_{args.ckpt}_{timestamp}.json",
|
| 653 |
+
)
|
| 654 |
+
with open(result_file, "w") as f:
|
| 655 |
+
json.dump(all_results, f, indent=2)
|
| 656 |
+
print(f"\nResults saved to: {result_file}")
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
if __name__ == "__main__":
|
| 660 |
+
main()
|
syxin_old/eval_gsm8k_humaneval_mtbench.log
ADDED

```
nohup: ignoring input
WARNING:__main__:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
[W324 11:41:43.200488949 Utils.hpp:135] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
Running on 8 GPU(s)
Using attn_implementation: target=flash_attention_2, draft=sdpa
Loading target model: /workspace/models/Qwen3-8B
`torch_dtype` is deprecated! Use `dtype` instead!
Loading base + LoRA adapter: /workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject/epoch_3_step_4644

============================================================
Benchmark: gsm8k (8 GPUs)
============================================================
Total 128 samples, distributed across 8 GPUs
/workspace/miniconda3/envs/dflash/lib/python3.11/site-packages/torch/storage.py:414: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  return torch.load(io.BytesIO(b))

gsm8k Results:
  Decoding speedup: 1.01x
  Average Acceptance length: 1.99
  Acceptance length histogram: ['0.0%', '3.6%', '94.8%', '1.3%', '0.3%', '0.1%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%']
  Num responses: 128

============================================================
Benchmark: humaneval (8 GPUs)
============================================================
Total 164 samples, distributed across 8 GPUs

humaneval Results:
  Decoding speedup: 0.96x
  Average Acceptance length: 1.97
  Acceptance length histogram: ['0.0%', '4.6%', '94.6%', '0.6%', '0.1%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%']
  Num responses: 164

============================================================
Benchmark: mt-bench (8 GPUs)
============================================================
Total 80 samples, distributed across 8 GPUs

mt-bench Results:
  Decoding speedup: 0.84x
  Average Acceptance length: 1.94
  Acceptance length histogram: ['0.0%', '6.7%', '92.6%', '0.4%', '0.2%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%', '0.0%']
  Num responses: 160

Results saved to: /workspace/hanrui/syxin_old/Specforge/benchmarks/results/dflash_lora_inject_offline_epoch_3_step_4644_20260324_121731.json
```
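As a sanity check on the numbers above, the expected acceptance length can be approximately recovered from the pooled gsm8k histogram as Σ_b b·p(b). The reported 1.99 is a mean of per-sample means rather than a pooled mean, so a small discrepancy is expected:

```python
# Expected acceptance length from the gsm8k histogram (percent, lengths 0..16)
hist = [0.0, 3.6, 94.8, 1.3, 0.3, 0.1] + [0.0] * 11
expected_tau = sum(b * p / 100 for b, p in enumerate(hist))
print(f"{expected_tau:.2f}")  # → 1.99
```

The histogram also makes the speedup plausible: with ~95% of blocks accepting exactly 2 tokens, two tokens per verification step roughly offsets the draft-model overhead, hence speedups near 1x.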
syxin_old/eval_run.log
ADDED

The diff for this file is too large to render. See raw diff.
syxin_old/launch_train.sh
ADDED

```bash
#!/bin/bash
set -euo pipefail

cd /workspace/hanrui/syxin_old/Specforge

export TORCHINDUCTOR_CACHE_DIR=/workspace/hanrui/cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export HF_DATASETS_CACHE=/workspace/hanrui/cache/hf_datasets
export HF_HOME=/workspace/hanrui/cache/hf_home

torchrun --nproc_per_node=8 \
    scripts/train_dflash_lora_inject.py \
    --target-model-path /workspace/models/Qwen3-8B \
    --target-model-backend hf \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir outputs/qwen3-8b-sft-32gpu-v2 \
    --block-size 16 \
    --attention-backend additive \
    --attn-implementation sdpa \
    --max-length 2048 \
    --batch-size 4 \
    --accumulation-steps 8 \
    --num-epochs 3 \
    --learning-rate 5e-5 \
    --loss-decay-gamma 7 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500 \
    --cache-dir /workspace/hanrui/cache \
    --lora-rank 32 \
    --lora-alpha 64 \
    --lora-dropout 0.1 \
    --trust-remote-code \
    --dataloader-num-workers 0
```
syxin_old/launch_train_dflash_wrapper.py
ADDED

```python
#!/usr/bin/env python3
"""
Python wrapper to launch dflash training script via northjob/torchrun
"""
import os
import subprocess
import sys

if __name__ == "__main__":
    bash_script = "/workspace/hanrui/syxin_old/run_train_multinode_dflash.sh"
    args = sys.argv[1:]

    cmd = ["bash", bash_script] + args

    result = subprocess.run(cmd, env=os.environ.copy())

    sys.exit(result.returncode)
```
syxin_old/launch_train_random_anchor.py
ADDED

```python
#!/usr/bin/env python3
"""
Python wrapper to launch random anchor training script via northjob/torchrun
"""
import os
import subprocess
import sys

if __name__ == "__main__":
    bash_script = "/workspace/hanrui/syxin_old/run_train_multinode_random_anchor.sh"
    args = sys.argv[1:]

    cmd = ["bash", bash_script] + args
    result = subprocess.run(cmd, env=os.environ.copy())
    sys.exit(result.returncode)
```
syxin_old/launch_train_wrapper.py
ADDED

```python
#!/usr/bin/env python3
"""
Python wrapper to launch bash training script via torchrun
"""
import os
import subprocess
import sys

if __name__ == "__main__":
    # Get the bash script path and arguments
    bash_script = "/workspace/hanrui/syxin_old/run_train_multinode.sh"
    args = sys.argv[1:]  # Pass through all arguments

    # Build the command
    cmd = ["bash", bash_script] + args

    # Execute the bash script
    result = subprocess.run(cmd, env=os.environ.copy())

    # Exit with the same code as the bash script
    sys.exit(result.returncode)
```
syxin_old/list.md
ADDED

### 1. `train_dflash_lora.py`
* Added LoRA: the original version called a separate small draft model; this version predicts from the target's hidden states plus LoRA.
* The `dflash_lora_mask_fn` function lets every token in the draft block being predicted attend to all tokens within that block.

### 2. OOM optimizations
* Sharding strategy ZeRO-3: the FSDP strategy is upgraded from `SHARD_GRAD_OP` to `FULL_SHARD`.
* `batch-size=1`, `accumulation-steps=8`.
* Uses FlexAttention (`dflash_lora_mask_fn`), following the earlier code.
* `_chunked_lm_loss()`: the loss is computed in chunks of 256, with gradient checkpointing.

### Run
* bash /workspace/hanrui/junquan/SpecForge/scripts/run_train_dflash_lora.sh 2
syxin_old/merge_lora.py
ADDED

```python
"""
Step 1: Merge DFlash-LoRA adapter into base model.
Usage:
    conda activate sglang
    python3 merge_lora.py
    python3 merge_lora.py --ckpt epoch_2_step_15000   # evaluate another checkpoint
"""
import argparse
import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "/workspace/models/Qwen3-8B"
OUTPUT_ROOT = "/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora"
MERGE_ROOT = "/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-merged"


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", default="epoch_3_step_18576",
                   help="Checkpoint folder name under OUTPUT_ROOT")
    p.add_argument("--merged-path", default=MERGE_ROOT,
                   help="Where to save the merged model")
    return p.parse_args()


def main():
    args = parse_args()
    adapter_path = os.path.join(OUTPUT_ROOT, args.ckpt)
    merged_path = args.merged_path

    if os.path.exists(merged_path):
        print(f"[skip] Merged model already exists: {merged_path}")
        return

    assert os.path.isdir(adapter_path), f"Adapter not found: {adapter_path}"

    print(f"Base model : {BASE_MODEL}")
    print(f"Adapter    : {adapter_path}")
    print(f"Output     : {merged_path}")
    print()

    print("[1/4] Loading base model to CPU ...")
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print("[2/4] Loading LoRA adapter ...")
    model = PeftModel.from_pretrained(model, adapter_path)

    print("[3/4] Merging weights ...")
    model = model.merge_and_unload()

    print("[4/4] Saving merged model ...")
    os.makedirs(merged_path, exist_ok=True)
    model.save_pretrained(merged_path, safe_serialization=True)
    AutoTokenizer.from_pretrained(BASE_MODEL).save_pretrained(merged_path)

    print(f"\nDone. Merged model saved to: {merged_path}")


if __name__ == "__main__":
    main()
```
syxin_old/oom_fix_progress.md
ADDED

# DFlash LoRA OOM Fix Log

## Root-cause analysis

1. **SHARD_GRAD_OP (ZeRO-2)** — every GPU holds the full Qwen3-8B parameters (~16 GB bf16); parameters are not sharded.
2. **SDPA + 4D additive mask** — FlashAttention does not support 4D additive masks, so attention falls back to the math backend, which materializes the full attention scores at every layer (`bsz × 32 heads × 2048 × 2048`).
3. **Large-vocab logits** — `[bsz, 2048, 151936]` in bf16 ≈ 1.18 GB; with gradients and boolean-indexing copies the peak is ~3-4 GB.
4. **The machine has only 2 H100s**, but the script defaults to `NUM_GPUS=4`.

## Completed changes

### 1. Switch FSDP sharding to FULL_SHARD (ZeRO-3)
- File: `SpecForge/scripts/train_dflash_lora.py:347`
- `ShardingStrategy.SHARD_GRAD_OP` → `ShardingStrategy.FULL_SHARD`
- Effect: parameters sharded across GPUs, saving ~8-12 GB per GPU

### 2. Lower batch size, raise accumulation steps
- File: `SpecForge/scripts/run_train_dflash_lora.sh`
- `--batch-size 2` → `1`, `--accumulation-steps 4` → `8`
- Effect: effective global batch size unchanged; peak memory halved

## To verify / follow-ups

- [ ] Launch with `bash run_train_dflash_lora.sh 2` to make sure only 2 GPUs are used
- [x] If still OOM, use a chunked cross-entropy loss to avoid materializing the full large-vocab logits
- [x] Longer term, explore a custom attention kernel with block-sparse mask support to bypass the SDPA math fallback

### 3. flex_attention + BlockMask replaces the 4D additive mask
- Files: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
- Ported the `_get_or_create_block_mask()` method from the non-LoRA `dflash.py`, adapted to the LoRA setting (Q_LEN == KV_LEN == seq_len)
- LoRA-version mask: causal over the context plus bidirectional within each block (the non-LoRA version uses [context, noise] concatenated KV)
- Enabled via `--attention-backend flex_attention` (the default); `--attention-backend additive` falls back to the original 4D mask
- The HuggingFace model is loaded with `attn_implementation="flex_attention"`
- Effect: no more fallback to the SDPA math backend, eliminating the `[bsz, heads, seq, seq]` attention-score memory

### 4. Chunked cross-entropy loss
- Files: `SpecForge/specforge/core/dflash_lora.py`, `specforge/modeling/draft/dflash_lora.py`, `scripts/train_dflash_lora.py`
- Ported the `_chunked_lm_loss()` method from the non-LoRA `dflash.py`
- Runs lm_head + CE loss chunk by chunk with gradient checkpointing, avoiding materializing the full `[bsz, seq, vocab]` logits
- Enabled via `--lm-head-chunk-size 256` (default 0 = disabled)
- `DFlashLoRADraftModel.forward()` gains an `output_hidden_states` parameter; in chunked mode it returns hidden states
- Effect: peak logits memory drops from O(seq_len × vocab_size) to O(chunk_size × vocab_size)
syxin_old/random_anchor_plan.md
ADDED

# Plan: Add Random Anchor to DFlash LoRA-Inject (syxin_old)

## Context
LoRA-Inject reaches τ≈3.9 vs τ≈6.5 for the original DFlash. The code has no bug; the gap comes from the training strategy: the original DFlash trains with `--random-anchor --num-anchors 512`, randomly sampling block start positions at every step, while LoRA-Inject only uses fixed block boundaries.

## Files to change (all under syxin_old)

1. `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora_inject.py` — training wrapper
2. `/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/dflash_lora_inject.py` — draft model
3. `/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash_lora_inject.py` — training-script arguments
4. `/workspace/hanrui/syxin_old/run_train_dflash_lora_inject.sh` — launch script

## Reference code (also under syxin_old)
- `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora.py` lines 73-146: `_sample_anchor_positions`, `_build_blocks_from_anchors`
- `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora.py` lines 219-235: `_build_additive_mask_random_anchor`
- `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora.py` lines 384-408: `_compute_loss_weights_random_anchor`
- `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora.py` lines 529-578: random-anchor forward path

## Step 1: training wrapper `core/dflash_lora_inject.py`

Add four methods and modify forward:

**1a. `_sample_anchor_positions()`** — copy verbatim from dflash_lora.py

**1b. `_build_blocks_from_anchors()`** — copy from dflash_lora.py, but **additionally** gather target_layer_hidden_states (a List[Tensor], one per layer):
```python
block_target_hidden = []
for layer_hs in target_layer_hidden_states:
    gathered = torch.gather(layer_hs, 1,
        gather_idx.unsqueeze(-1).expand(-1, -1, layer_hs.size(-1)))
    block_target_hidden.append(gathered)
```
Return one extra value, `block_target_hidden_states`.

**1c. `_build_additive_mask_random_anchor()`** — copy verbatim from dflash_lora.py (lines 219-235). This is the draft-to-draft mask (bidirectional within a block); it gets expanded into the extended mask inside the draft model's `_forward_with_injection`.

**1d. `_compute_loss_weights_random_anchor()`** — copy verbatim from dflash_lora.py (lines 384-408)

**1e. Modify `forward()`** — insert the random-anchor branch after the target-model forward:
```python
if self.random_anchor and self.training:
    # 1. sample anchors
    # 2. build blocks (input_ids, loss_mask, target hidden per layer)
    # 3. prepare_noise_input(block_ids=block_ids)
    # 4. build draft-draft mask
    # 5. position_ids = gather_idx  (original sequence positions!)
    # 6. draft_model.forward(... block_ids=block_ids)  # new parameter
    # 7. loss + accuracy
    return loss, accuracy
```

## Step 2: draft model `modeling/draft/dflash_lora_inject.py`

**Modify `forward()` and `_forward_with_injection()`** to accept a `block_ids` parameter.

Problems with the current `_forward_with_injection` in random-anchor mode:
1. **Position IDs** (lines 237-238): uses `[0..seq_len-1, 0..seq_len-1]`, but random anchor needs `[gather_idx, gather_idx]` (the original sequence positions)
2. **Extended mask** (lines 269-274): computes leakage prevention from a fixed `context_len` and `block_size`, but with random anchors the block boundaries are determined by `block_ids`

Changes:
- `forward()` gains a `block_ids=None` parameter, passed through to `_forward_with_injection`
- `_forward_with_injection` gains a `block_ids=None` parameter
- When `block_ids is not None`:
  - positions use the caller-provided `position_ids`; build `extended_pos = cat([position_ids, position_ids])` (the gather_idx positions)
  - draft-to-target part of the extended mask: a draft token in block k may only see target tokens in blocks < k (decided per sample via `block_ids[target_pos] < block_ids[draft_pos]`)
  - draft-to-draft part: use the passed-in `attention_mask` (the within-block bidirectional mask already built by the wrapper)

## Step 3: training scripts

**3a. `scripts/train_dflash_lora_inject.py`**
- Add argparse flags `--random-anchor` (store_true) and `--num-anchors` (default=512)
- Line 204: `random_anchor=False` → `random_anchor=args.random_anchor`
- Line 205: `num_anchors=512` → `num_anchors=args.num_anchors`

**3b. `run_train_dflash_lora_inject.sh`**
- Add `--random-anchor --num-anchors 512`
- Suggest raising lr to 6e-4 and epochs to 6 (matching the original DFlash)

## Verification
1. Run a few training steps and confirm the loss decreases normally
2. Compare against the random-anchor path in dflash_lora.py to confirm mask/loss/position logic matches
3. Re-run the full eval after complete training
syxin_old/requirements.txt
ADDED

File without changes
syxin_old/run_bench_dflash.sh
ADDED

```bash
#!/bin/bash
# Evaluate DFlash-LoRA-Inject accepted length (offline, 8 GPUs parallel).
# No sglang server needed. Each GPU loads its own target+draft and processes a shard.
#
# Usage:
#   bash run_bench_dflash.sh                        # 8 GPUs, all 3 benches
#   bash run_bench_dflash.sh humaneval              # only humaneval
#   bash run_bench_dflash.sh mtbench gsm8k          # pick any subset
#   bash run_bench_dflash.sh --quick                # quick test (20 samples)
#   bash run_bench_dflash.sh --ckpt epoch_0_step_500  # specific checkpoint
#   NUM_GPUS=4 bash run_bench_dflash.sh             # use 4 GPUs

set -e

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
PYTHON=/workspace/miniconda3/envs/spec/bin/python3
RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
NUM_GPUS=${NUM_GPUS:-8}

# ---- parse args ----
BENCHMARKS=()
EXTRA_ARGS=()
QUICK=false

for arg in "$@"; do
    case $arg in
        humaneval|mtbench|gsm8k)
            BENCHMARKS+=("$arg")
            ;;
        --quick)
            QUICK=true
            ;;
        *)
            EXTRA_ARGS+=("$arg")
            ;;
    esac
done

if [ ${#BENCHMARKS[@]} -eq 0 ]; then
    BENCHMARKS=(humaneval mtbench gsm8k)
fi

if [ "$QUICK" = true ]; then
    EXTRA_ARGS+=(--num-samples 20)
fi

TIMESTAMP=$(date +%Y%m%d_%H%M%S)

echo "============================================"
echo " DFlash-LoRA-Inject Offline Eval"
echo "   GPUs       : $NUM_GPUS"
echo "   benchmarks : ${BENCHMARKS[*]}"
echo "   extra args : ${EXTRA_ARGS[*]}"
echo "   results    : $RESULT_DIR"
echo "============================================"
echo ""

mkdir -p $RESULT_DIR

$PYTHON -m torch.distributed.run \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $SCRIPT_DIR/eval_dflash_lora_inject.py \
    --benchmarks ${BENCHMARKS[@]} \
    --output-dir $RESULT_DIR \
    "${EXTRA_ARGS[@]}" \
    2>&1 | tee $RESULT_DIR/bench_dflash_lora_inject_offline_${TIMESTAMP}.log

echo ""
echo "Done. Latest result files:"
ls -lht $RESULT_DIR/*.json 2>/dev/null | head -5
```
syxin_old/run_bench_dflash_b16_baseline.sh
ADDED

```bash
#!/bin/bash
# DFlash-b16 baseline: measure accepted length offline, 8 GPUs parallel.
# Usage:
#   bash run_bench_dflash_b16_baseline.sh             # 8 GPUs, all 3 benches
#   bash run_bench_dflash_b16_baseline.sh humaneval   # only humaneval
#   bash run_bench_dflash_b16_baseline.sh --quick     # 20 samples per bench
#   NUM_GPUS=4 bash run_bench_dflash_b16_baseline.sh  # 4 GPUs

set -e

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
PYTHON=/workspace/miniconda3/envs/spec/bin/python3
RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
NUM_GPUS=${NUM_GPUS:-8}

BENCHMARKS=()
EXTRA_ARGS=()
QUICK=false

for arg in "$@"; do
    case $arg in
        humaneval|mtbench|gsm8k) BENCHMARKS+=("$arg") ;;
        --quick) QUICK=true ;;
        *) EXTRA_ARGS+=("$arg") ;;
    esac
done

if [ ${#BENCHMARKS[@]} -eq 0 ]; then
    BENCHMARKS=(humaneval mtbench gsm8k)
fi

if [ "$QUICK" = true ]; then
    EXTRA_ARGS+=(--num-samples 20)
fi

TIMESTAMP=$(date +%Y%m%d_%H%M%S)

echo "============================================"
echo " DFlash-b16 Baseline Offline Eval"
echo "   GPUs       : $NUM_GPUS"
echo "   draft      : /workspace/models/Qwen3-8B-DFlash-b16"
echo "   benchmarks : ${BENCHMARKS[*]}"
echo "   extra args : ${EXTRA_ARGS[*]}"
echo "============================================"
echo ""

mkdir -p $RESULT_DIR

$PYTHON -m torch.distributed.run \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $SCRIPT_DIR/eval_dflash_b16_baseline.py \
    --benchmarks ${BENCHMARKS[@]} \
    --output-dir $RESULT_DIR \
    "${EXTRA_ARGS[@]}" \
    2>&1 | tee $RESULT_DIR/bench_dflash_b16_baseline_${TIMESTAMP}.log

echo ""
echo "Done. Latest result files:"
ls -lht $RESULT_DIR/*.json 2>/dev/null | head -5
```
|
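The benchmark wrappers in this commit share one argument-classification loop: known benchmark names are collected, `--quick` flips a flag, and everything else is passed through. A standalone sketch of that pattern (the `classify_args` function name is hypothetical, for illustration only):

```shell
# Classify CLI args the way the bench wrappers do: benchmark names,
# the --quick flag, and pass-through extras; default to all benches.
classify_args() {
    local benches=() extras=() quick=false
    for arg in "$@"; do
        case $arg in
            humaneval|mtbench|gsm8k) benches+=("$arg") ;;
            --quick) quick=true ;;
            *) extras+=("$arg") ;;
        esac
    done
    if [ ${#benches[@]} -eq 0 ]; then
        benches=(humaneval mtbench gsm8k)
    fi
    echo "benches=${benches[*]} quick=$quick extras=${extras[*]}"
}

classify_args gsm8k --quick   # -> benches=gsm8k quick=true extras=
```

Because unrecognized tokens fall into `extras`, flags like `--num-samples 20` reach the Python eval script unchanged.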
syxin_old/run_bench_dflash_lora_inject.sh
ADDED
@@ -0,0 +1,60 @@

#!/bin/bash
# DFlash-LoRA-Inject: measure accepted length offline, 8 GPUs parallel.
# Usage:
#   bash run_bench_dflash_lora_inject.sh            # 8 GPUs, all 3 benches
#   bash run_bench_dflash_lora_inject.sh humaneval  # only humaneval
#   bash run_bench_dflash_lora_inject.sh --quick    # 20 samples per bench
#   NUM_GPUS=4 bash run_bench_dflash_lora_inject.sh # 4 GPUs

set -e

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
PYTHON=/workspace/miniconda3/envs/dflash/bin/python3
RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
NUM_GPUS=${NUM_GPUS:-8}

BENCHMARKS=()
EXTRA_ARGS=()
QUICK=false

for arg in "$@"; do
    case $arg in
        humaneval|mtbench|gsm8k) BENCHMARKS+=("$arg") ;;
        --quick) QUICK=true ;;
        *) EXTRA_ARGS+=("$arg") ;;
    esac
done

if [ ${#BENCHMARKS[@]} -eq 0 ]; then
    BENCHMARKS=(humaneval mtbench gsm8k)
fi

if [ "$QUICK" = true ]; then
    EXTRA_ARGS+=(--num-samples 20)
fi

TIMESTAMP=$(date +%Y%m%d_%H%M%S)

echo "============================================"
echo "  DFlash-LoRA-Inject Offline Eval"
echo "  GPUs       : $NUM_GPUS"
echo "  draft      : LoRA-inject (merged)"
echo "  benchmarks : ${BENCHMARKS[*]}"
echo "  extra args : ${EXTRA_ARGS[*]}"
echo "============================================"
echo ""

mkdir -p "$RESULT_DIR"

$PYTHON -m torch.distributed.run \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    "$SCRIPT_DIR/eval_dflash_lora_inject.py" \
    --benchmarks "${BENCHMARKS[@]}" \
    --output-dir "$RESULT_DIR" \
    "${EXTRA_ARGS[@]}" \
    2>&1 | tee "$RESULT_DIR/bench_dflash_lora_inject_${TIMESTAMP}.log"

echo ""
echo "Done. Latest result files:"
ls -lht "$RESULT_DIR"/*.json 2>/dev/null | head -5
syxin_old/run_qwen3_8b_sft_64gpu.sh
ADDED
@@ -0,0 +1,31 @@

#!/bin/bash
export JOB_NAME='qwen3-32b-sft'
export GPU_NUMS=32
export TRAIN_SCRIPT='/workspace/hanrui/syxin_old/launch_train_dflash_wrapper.py'
export WORK_DIR='/workspace/hanrui/syxin_old/Specforge'

if [ $GPU_NUMS -lt 8 ]; then
    export NNODES=1
    export GPU_NUMS_PER_NODE=$GPU_NUMS
else
    export NNODES=$((GPU_NUMS/8))
    export GPU_NUMS_PER_NODE=8
fi

# Use the northjob binary from the spec environment
/workspace/miniconda3/envs/spec/bin/northjob \
    create \
    --job-type train \
    --nproc-per-node $GPU_NUMS_PER_NODE \
    --gpu-per-node $GPU_NUMS_PER_NODE \
    --nnodes $NNODES \
    --k8s-priority 3 \
    --k8s-queue bg-agentic-coding \
    --k8s-namespace bg-agentic-coding \
    --k8s-pvc-name i-xinsiyang-y4zy0sik0a \
    --k8s-pvc-mount-path /workspace \
    --k8s-no-reclaim \
    --k8s-images harbor.local.clusters/bp/megatron-bplm:25.03_fp8.ibgda.qwen3.next.fix_triton.fix_te.hf457.qwen3_vl \
    --job-name $JOB_NAME \
    --workspace $WORK_DIR \
    $TRAIN_SCRIPT $GPU_NUMS_PER_NODE
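The launcher above derives the node layout from a total GPU count, assuming 8 GPUs per node. The same rule can be exercised standalone (the `gpus_to_layout` helper is hypothetical, for illustration only):

```shell
# Map a total GPU count to (nnodes, gpus per node), assuming 8-GPU nodes,
# mirroring the if/else split in the launcher script.
gpus_to_layout() {
    local total=$1
    if [ "$total" -lt 8 ]; then
        echo "nnodes=1 per_node=$total"
    else
        echo "nnodes=$((total/8)) per_node=8"
    fi
}

gpus_to_layout 32   # -> nnodes=4 per_node=8
gpus_to_layout 4    # -> nnodes=1 per_node=4
```

Note the integer division: totals that are not a multiple of 8 (e.g. 12) silently drop the remainder, so the script expects GPU_NUMS to be under 8 or a multiple of 8.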
syxin_old/run_train_dflash_lora_inject.sh
ADDED
@@ -0,0 +1,73 @@

#!/bin/bash
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin_old/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-dflash-lora-inject-random-anchor
CACHE_DIR=/tmp/specforge_cache

# Parse arguments
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

# Environment variables
export TORCHINDUCTOR_CACHE_DIR=/tmp/specforge_cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
export HF_DATASETS_CACHE=/tmp/specforge_cache/hf_datasets
export HF_HOME=/tmp/specforge_cache/hf_home

# Python binary
DEFAULT_SPECFORGE_PY=/workspace/hanrui/specforge/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd $ROOT_DIR

$PYTHON_BIN -m torch.distributed.run \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    scripts/train_dflash_lora_inject.py \
    --target-model-path /workspace/models/Qwen3-8B \
    --target-model-backend hf \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir $OUTPUT_DIR \
    --block-size 16 \
    --attention-backend additive \
    --attn-implementation sdpa \
    --random-anchor \
    --num-anchors 64 \
    --max-length 2048 \
    --batch-size 1 \
    --accumulation-steps 64 \
    --num-epochs 6 \
    --learning-rate 6e-4 \
    --loss-decay-gamma 7 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500 \
    --cache-dir $CACHE_DIR \
    --lora-rank 32 \
    --lora-alpha 64 \
    --lora-dropout 0.1 \
    --trust-remote-code \
    --dataloader-num-workers 0 \
    --early-stop \
    --early-stop-patience 5 \
    --early-stop-min-delta 0.005 \
    "${EXTRA_ARGS[@]}"
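The training wrappers above all use the same positional-argument convention: an optional first bare argument overrides the GPU count, an optional second bare argument (anything not starting with `-`) overrides the output dir, and the rest is forwarded verbatim. A standalone sketch of that convention (the `parse_launch_args` helper and its defaults are hypothetical):

```shell
# Mirror the wrappers' argument convention:
#   [NUM_GPUS] [OUTPUT_DIR] [extra flags...]
# where OUTPUT_DIR is only consumed if it does not look like a flag.
parse_launch_args() {
    local num_gpus=8 output_dir=out
    if [ $# -ge 1 ]; then
        num_gpus=$1
        shift
    fi
    if [ $# -ge 1 ] && [ "${1:0:1}" != "-" ]; then
        output_dir=$1
        shift
    fi
    echo "gpus=$num_gpus out=$output_dir extra=$*"
}

parse_launch_args 4 /tmp/run --resume   # -> gpus=4 out=/tmp/run extra=--resume
parse_launch_args 2 --quick             # -> gpus=2 out=out extra=--quick
```

The `${1:0:1}` first-character test is what lets `bash run_train_dflash_lora_inject.sh 8 --some-flag` skip the output-dir slot entirely.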
syxin_old/run_train_multinode.sh
ADDED
@@ -0,0 +1,67 @@

#!/bin/bash
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin_old/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-dflash-lora-inject
CACHE_DIR=/tmp/specforge_cache

# Parse arguments
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

# Environment variables
export TORCHINDUCTOR_CACHE_DIR=/tmp/specforge_cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
export HF_DATASETS_CACHE=/tmp/specforge_cache/hf_datasets
export HF_HOME=/tmp/specforge_cache/hf_home

# Python binary
DEFAULT_SPECFORGE_PY=/workspace/miniconda3/envs/spec/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd $ROOT_DIR

# northjob already sets the distributed environment variables via torchrun,
# so run the training script directly; do not launch torch.distributed.run again.
$PYTHON_BIN scripts/train_dflash_lora_inject.py \
    --target-model-path /workspace/models/Qwen3-8B \
    --target-model-backend hf \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir $OUTPUT_DIR \
    --block-size 16 \
    --attention-backend additive \
    --attn-implementation sdpa \
    --max-length 2048 \
    --batch-size 8 \
    --accumulation-steps 8 \
    --num-epochs 3 \
    --learning-rate 5e-5 \
    --loss-decay-gamma 7 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500 \
    --cache-dir $CACHE_DIR \
    --lora-rank 32 \
    --lora-alpha 64 \
    --lora-dropout 0.1 \
    --trust-remote-code \
    --dataloader-num-workers 0 \
    "${EXTRA_ARGS[@]}"
syxin_old/run_train_multinode_random_anchor.sh
ADDED
@@ -0,0 +1,72 @@

#!/bin/bash
set -euo pipefail

ROOT_DIR=/workspace/hanrui/syxin_old/Specforge
NUM_GPUS=8
OUTPUT_DIR=$ROOT_DIR/outputs/qwen3-8b-dflash-lora-inject-random-anchor
CACHE_DIR=/tmp/specforge_cache

# Parse arguments
if [[ $# -ge 1 ]]; then
    NUM_GPUS=$1
    shift
fi
if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
fi
EXTRA_ARGS=("$@")

# Environment variables
export TORCHINDUCTOR_CACHE_DIR=/tmp/specforge_cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=16
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True
export PYTHONPATH="$ROOT_DIR:${PYTHONPATH:-}"
export HF_DATASETS_CACHE=/tmp/specforge_cache/hf_datasets
export HF_HOME=/tmp/specforge_cache/hf_home

# Python binary
DEFAULT_SPECFORGE_PY=/workspace/miniconda3/envs/spec/bin/python3
if [[ -z "${PYTHON_BIN:-}" ]]; then
    if [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
        PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
    else
        PYTHON_BIN=python3
    fi
fi

cd $ROOT_DIR

# northjob already sets the distributed environment variables via torchrun,
# so run the training script directly; do not launch torch.distributed.run again.
$PYTHON_BIN scripts/train_dflash_lora_inject.py \
    --target-model-path /workspace/models/Qwen3-8B \
    --target-model-backend hf \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir $OUTPUT_DIR \
    --block-size 16 \
    --attention-backend additive \
    --attn-implementation sdpa \
    --random-anchor \
    --num-anchors 64 \
    --max-length 2048 \
    --batch-size 1 \
    --accumulation-steps 64 \
    --num-epochs 6 \
    --learning-rate 6e-4 \
    --loss-decay-gamma 7 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500 \
    --cache-dir $CACHE_DIR \
    --lora-rank 32 \
    --lora-alpha 64 \
    --lora-dropout 0.1 \
    --trust-remote-code \
    --dataloader-num-workers 0 \
    --early-stop \
    --early-stop-patience 5 \
    --early-stop-min-delta 0.005 \
    "${EXTRA_ARGS[@]}"
syxin_old/start_server.sh
ADDED
@@ -0,0 +1,42 @@

#!/bin/bash
# Step 2: Launch SGLang server with STANDALONE speculative decoding.
# Usage:
#   bash start_server.sh
#   bash start_server.sh 8   # use tp=8

set -e

TP=${1:-2}

BASE_MODEL=/workspace/models/Qwen3-8B
MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-merged
INTRANET_IP=10.1.1.131
PORT=30000

if [ ! -d "$MERGED" ]; then
    echo "[ERROR] Merged model not found: $MERGED"
    echo "        Run: conda activate sglang && python3 merge_lora.py"
    exit 1
fi

echo "============================================"
echo "  SGLang STANDALONE Speculative Decoding"
echo "  target : $BASE_MODEL"
echo "  draft  : $MERGED"
echo "  host   : $INTRANET_IP:$PORT"
echo "  tp     : $TP"
echo "============================================"

/workspace/miniconda3/envs/sglang/bin/python3 -m sglang.launch_server \
    --model-path $BASE_MODEL \
    --speculative-algorithm STANDALONE \
    --speculative-draft-model-path $MERGED \
    --speculative-num-steps 4 \
    --speculative-eagle-topk 1 \
    --speculative-num-draft-tokens 4 \
    --tp-size $TP \
    --mem-fraction-static 0.30 \
    --trust-remote-code \
    --host $INTRANET_IP \
    --port $PORT \
    --dtype bfloat16
syxin_old/start_server_dflash.sh
ADDED
@@ -0,0 +1,54 @@

#!/bin/bash
# Evaluate DFlash-LoRA-Inject: measure accepted length OFFLINE.
# 8 GPUs parallel by default, each GPU runs a shard of prompts independently.
#
# WHY offline?
#   sglang STANDALONE treats the draft as an independent autoregressive model,
#   completely ignoring the layer-by-layer injection that LoRA-Inject was
#   trained with. Result: accept_length ≈ 4.7 for ALL models (no signal).
#
#   sglang DFLASH expects the DFlash-b16 architecture (5-layer, fc+hidden_norm),
#   which is structurally different from LoRA-Inject (full 36-layer + LoRA).
#
#   So we run offline spec-generate with the correct injection pattern.
#
# Usage:
#   bash start_server_dflash.sh                    # 8 GPUs, all benchmarks
#   bash start_server_dflash.sh 4                  # 4 GPUs
#   bash start_server_dflash.sh 8 humaneval        # specific benchmark
#   bash start_server_dflash.sh 8 --num-samples 20 # quick test

set -e

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)

NUM_GPUS=${1:-8}
shift 2>/dev/null || true

# ---- defaults ----
BASE_MODEL=/workspace/models/Qwen3-8B
ADAPTER_ROOT=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject
CKPT=epoch_3_step_1400
MERGED=/workspace/hanrui/syxin_old/Specforge/outputs/qwen3-8b-dflash-lora-inject-merged
RESULT_DIR=/workspace/hanrui/syxin_old/Specforge/benchmarks/results
PYTHON=/workspace/miniconda3/envs/spec/bin/python3

echo "============================================"
echo "  DFlash-LoRA-Inject Offline Evaluation"
echo "  target : $BASE_MODEL"
echo "  ckpt   : $CKPT"
echo "  merged : $MERGED"
echo "  GPUs   : $NUM_GPUS"
echo "============================================"

$PYTHON -m torch.distributed.run \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    "$SCRIPT_DIR/eval_dflash_lora_inject.py" \
    --base-model $BASE_MODEL \
    --adapter-root $ADAPTER_ROOT \
    --ckpt $CKPT \
    --merged-path $MERGED \
    --block-size 16 \
    --output-dir $RESULT_DIR \
    "$@"
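The "accepted length" these eval scripts measure is, in the usual speculative-decoding sense, the number of leading draft tokens that match the target model's tokens, plus the one bonus token the target always contributes. A toy sketch on whitespace-separated token-ID strings (the `accepted_length` helper is hypothetical and not the eval script's actual implementation):

```shell
# Count leading draft tokens that agree with the target's tokens,
# then add 1 for the bonus token the target model emits itself.
accepted_length() {
    local -a draft=($1) target=($2)
    local n=0
    while [ $n -lt ${#draft[@]} ] && [ "${draft[$n]}" = "${target[$n]}" ]; do
        n=$((n+1))
    done
    echo $((n + 1))
}

accepted_length "12 7 99 3" "12 7 42 3"   # -> 3 (2 matches + bonus)
```

Under this definition a draft that contributes nothing still scores 1, which is why a flat accept_length across models (as in the STANDALONE case above) means the measurement carries no signal about the draft.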