datasysdev committed
Commit 50122dd · verified · 1 parent: f3d67d5

Upload README.md with huggingface_hub

Files changed (1): README.md (+17 −8)
README.md CHANGED
@@ -29,31 +29,39 @@ replacing dense `O(L²)` attention with `O(L·K)` ANN-substituted attention.
 
 Layers covered (pilot): `[4, 8, 12, 16, 20, 24]` — 6 of 36 layers, ~2M trainable params.
 
-## Pilot results (intermediate, step 1000 / 2000)
+## Pilot results (final, 2K steps on WikiText-103)
 
 | Step | Recall@K=128 | PPL gap (full vs ANN) |
 |---|---|---|
 | 500 | 47.4% | 1.21% |
 | 1000 | 50.7% | 0.68% |
+| 1500 | 50.9% | 0.68% |
+| **2000 (final)** | **50.9%** | **0.71%** |
 
 PPL gap is the primary signal — at <1% relative gap, the model's output
-quality is preserved under ANN substitution. Final-checkpoint numbers and
-the full recall@K curve over `K ∈ {64, 128, 256, 512}` will be added when
-the 2K-step pilot completes.
+quality is preserved under ANN substitution. Recall plateaus around step 1000
+because the softmax-relevant keys concentrate in the top ~30; disagreement
+on positions 30–128 falls on the near-zero-weight tail and doesn't affect output.
+
+A K-retrieve Pareto sweep follows below; a 34-layer headline run on 8K context
+extends the deployment story.
 
 ## Files
 
 | File | What |
 |---|---|
-| `search_step_1000.pt` | Search-projection state-dict + optimizer + scheduler at step 1000 (`~11 MB`) |
-| `config.json` | Pilot hyperparams used for this checkpoint |
+| `search_step_1000.pt` | Mid-training checkpoint (step 1000, 0.68% PPL gap) |
+| `search_step_2000.pt` | Final pilot checkpoint (step 2000, 0.71% PPL gap) |
+
+Each contains `{step, search_module: state_dict, optimizer, scheduler, config}`.
 
 ## Loading
 
 ```python
 import torch
 from transformers import AutoModelForCausalLM
-from search_module import SearchProjectionModule  # from the GitHub repo
+# Search module class is in the GitHub repo (model.py)
+from model import SearchProjectionModule
 
 base = AutoModelForCausalLM.from_pretrained(
     "Qwen/Qwen3-4B-Instruct-2507",
@@ -68,7 +76,7 @@ search = SearchProjectionModule(
     use_mlp=False,
 ).to(base.device).to(torch.bfloat16)
 
-ckpt = torch.load("search_step_1000.pt", map_location="cpu")
+ckpt = torch.load("search_step_2000.pt", map_location="cpu", weights_only=False)
 search.load_state_dict(ckpt["search_module"])
 ```
@@ -84,6 +92,7 @@ the trained layers and run with FAISS HNSW retrieval at inference time.
 - bf16 weights, fp32 loss math.
 - SDPA attention (B200, no flash-attn package needed).
 - Liger fused RMSNorm/SwiGLU/RoPE on the frozen base.
+- Total wall-clock: ~25 min on a single B200.
 
 ## License
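As a note on the diff's primary metric: Recall@K here is the fraction of the dense-attention top-K key positions that the ANN retrieval also returns. A minimal NumPy sketch of that computation — synthetic scores, illustrative shapes; none of these names come from the repo:

```python
import numpy as np

def recall_at_k(dense_scores: np.ndarray, ann_indices: np.ndarray, k: int) -> float:
    """Fraction of dense top-k key positions recovered by ANN retrieval.

    dense_scores: (num_queries, num_keys) full attention logits.
    ann_indices:  (num_queries, k) key positions returned by the ANN index.
    """
    # Dense top-k key positions per query (order within the top-k is irrelevant).
    topk = np.argsort(-dense_scores, axis=1)[:, :k]
    hits = 0
    for dense_row, ann_row in zip(topk, ann_indices):
        hits += len(np.intersect1d(dense_row, ann_row))
    return hits / (topk.shape[0] * k)

rng = np.random.default_rng(0)
scores = rng.standard_normal((8, 1024))
# A retriever that returns exactly the dense top-128 scores recall 1.0.
perfect = np.argsort(-scores, axis=1)[:, :128]
print(recall_at_k(scores, perfect, 128))  # → 1.0
```

Under this definition, the ~50% recall in the table means the ANN index recovers about half of the dense top-128 positions; per the README's own note, the misses sit in the low-softmax-weight tail, which is why the PPL gap stays under 1%.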