datasysdev committed · verified
Commit 01d6e75 · Parent(s): f013d0b

Upload README.md with huggingface_hub

Files changed (1): README.md (+90 -0)
README.md ADDED
---
language:
- en
license: apache-2.0
base_model: Qwen/Qwen3-4B-Instruct-2507
tags:
- sparse-attention
- ann-attention
- distillation
- search-projection
- inference-optimization
library_name: pytorch
---

# ann-sparseattention

Search projections for ANN-substituted attention on
[`Qwen/Qwen3-4B-Instruct-2507`](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507).

Code: [github.com/unixsysdev/ann-sparseattention](https://github.com/unixsysdev/ann-sparseattention)

## What's in this repo

Per-layer linear search projections `(W_Qs, W_Ks)` of shape `[2560, 64]`,
trained against the frozen base model's attention via contrastive +
distillation losses. At inference these produce 64-d "search vectors" that
let an off-the-shelf FAISS HNSW index pick the top-K keys to attend to,
replacing dense `O(L²)` attention with `O(L·K)` ANN-substituted attention.

Layers covered (pilot): `[4, 8, 12, 16, 20, 24]` — 6 of 36 layers, ~2M trainable params.
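
To make the retrieval path concrete, here is a minimal sketch of the ANN
substitution for one layer, assuming the projections are plain `[2560, 64]`
weight matrices. The tensor names, HNSW parameters, and inner-product metric
are illustrative assumptions, not the repo's API:

```python
import faiss
import torch

L, D, D_S, K = 4096, 2560, 64, 128      # seq len, hidden size, search dim, top-K

hidden = torch.randn(L, D)              # per-token hidden states (stand-in)
W_Qs = torch.randn(D, D_S)              # trained query search projection
W_Ks = torch.randn(D, D_S)              # trained key search projection

q_s = (hidden @ W_Qs).numpy()           # [L, 64] query search vectors
k_s = (hidden @ W_Ks).numpy()           # [L, 64] key search vectors

# Off-the-shelf HNSW index over the key search vectors (metric is assumed).
index = faiss.IndexHNSWFlat(D_S, 32, faiss.METRIC_INNER_PRODUCT)
index.add(k_s)

# Each query retrieves K candidate keys, so attention touches O(L·K) pairs
# instead of the dense O(L²). Causal masking of the retrieved indices is
# omitted here for brevity.
_, topk = index.search(q_s, K)          # [L, K] key indices per query
```

Dense attention is then computed only over each query's retrieved keys.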

## Pilot results (intermediate, step 1000 / 2000)

| Step | Recall@K=128 | PPL gap (full vs ANN) |
|---|---|---|
| 500 | 47.4% | 1.21% |
| 1000 | 50.7% | 0.68% |

PPL gap is the primary signal — at <1% relative gap, the model's output
quality is preserved under ANN substitution. Final-checkpoint numbers and
the full recall@K curve over `K ∈ {64, 128, 256, 512}` will be added when
the 2K-step pilot completes.
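
"PPL gap" is read here as the relative perplexity increase of the
ANN-substituted model over the full-attention baseline; the card doesn't
spell out the formula, so the definition below is an assumption:

```python
def ppl_gap_pct(ppl_full: float, ppl_ann: float) -> float:
    """Relative perplexity increase of ANN-substituted vs. full attention,
    in percent (assumed definition; not stated explicitly in this card)."""
    return (ppl_ann - ppl_full) / ppl_full * 100.0

# e.g. ppl_gap_pct(10.0, 10.068) ≈ 0.68, matching the step-1000 row.
```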

## Files

| File | What |
|---|---|
| `search_step_1000.pt` | Search-projection state dict + optimizer + scheduler at step 1000 (~11 MB) |
| `config.json` | Pilot hyperparameters used for this checkpoint |

## Loading

```python
import torch
from transformers import AutoModelForCausalLM
from search_module import SearchProjectionModule  # from the GitHub repo

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct-2507",
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa",
)

search = SearchProjectionModule(
    d_model=2560, d_search=64,
    layer_indices=[4, 8, 12, 16, 20, 24],
    use_mlp=False,
).to(base.device).to(torch.bfloat16)

# The checkpoint also carries optimizer/scheduler state; only the
# projection weights are needed for inference.
ckpt = torch.load("search_step_1000.pt", map_location="cpu")
search.load_state_dict(ckpt["search_module"])
```

Use `inference.install_ann_attention(...)` (in the GitHub repo) to monkey-patch
the trained layers and run with FAISS HNSW retrieval at inference time.
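
Once the layers are patched, generation goes through the standard
`transformers` API. A minimal sketch, with the prompt and tokenizer setup
purely illustrative and the `install_ann_attention` arguments elided because
its exact signature lives in the GitHub repo:

```python
from transformers import AutoTokenizer

# inference.install_ann_attention(...)  # exact arguments: see the GitHub repo

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
prompt = "Approximate nearest-neighbour search can stand in for dense attention because"
inputs = tok(prompt, return_tensors="pt").to(base.device)

out = base.generate(**inputs, max_new_tokens=64)
print(tok.decode(out[0], skip_special_tokens=True))
```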

## Training recipe

- Frozen base: Qwen3-4B-Instruct-2507 (36 layers, hidden 2560, GQA 32:8).
- Data: WikiText-103 raw, packed to 4K-token sequences.
- 2000 steps, batch 8, lr 1e-4 (cosine, 100-step warmup), AdamW.
- `α=β=1` weighting of the contrastive and KL-distillation losses, both averaged over the trained layers (see the sketch after this list).
- bf16 weights, fp32 loss math.
- SDPA attention (B200, no flash-attn package needed).
- Liger fused RMSNorm/SwiGLU/RoPE on the frozen base.
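
A minimal sketch of that objective for one layer. The exact contrastive
formulation, temperature, and choice of positives are not stated in this
card, so the InfoNCE-style term below is an assumption:

```python
import torch
import torch.nn.functional as F

def per_layer_loss(sim, pos_idx, teacher_attn, alpha=1.0, beta=1.0):
    """sim: [L, L] similarities between query and key search vectors;
    pos_idx: [L] index of each query's top key under the frozen model's
    dense attention (assumed positive); teacher_attn: [L, L] dense
    attention distribution from the frozen base."""
    contrastive = F.cross_entropy(sim, pos_idx)            # InfoNCE-style (assumed)
    kl = F.kl_div(F.log_softmax(sim, dim=-1), teacher_attn,
                  reduction="batchmean")                   # distill dense attention
    return alpha * contrastive + beta * kl                 # α = β = 1 in the pilot

# Toy usage; the total loss averages per-layer losses over the six layers.
L = 512
sim = torch.randn(L, L)
pos_idx = torch.randint(0, L, (L,))
teacher_attn = torch.softmax(torch.randn(L, L), dim=-1)
loss = per_layer_loss(sim, pos_idx, teacher_attn)
```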

## License

The search projections are released under Apache-2.0 (matching the base model).