Commit 1f13233 (verified) · 1 parent: dbb833b
datasysdev committed: Correct asymptotic scoring analysis

Files changed (1):
  1. README.md +389 -166

README.md CHANGED
@@ -1,72 +1,130 @@
- ---
- license: mit
- base_model: Qwen/Qwen3-4B-Instruct-2507
- tags:
- - sparse-attention
- - approximate-nearest-neighbor
- - ann
- - qwen3
- - retrieval
- - attention
- - research-artifact
- library_name: pytorch
- ---
-
- # ANN Sparse Attention Checkpoints
-
- This repository contains checkpoint artifacts for a research prototype that trains tiny per-layer search projections on a frozen LLM, so dense attention can be approximated by retrieving a small causal key set in a learned low-dimensional space.
-
- The associated source repo is [unixsysdev/ann-sparseattention](https://github.com/unixsysdev/ann-sparseattention). The GitHub repo contains the training/eval code; this Hugging Face repo stores the checkpoints and JSON result artifacts.
-
- ## Status
-
- Research artifact, not a deployable inference package.
-
- The clean result is narrow but real: on a 6-layer Qwen3-4B pilot with packed block-causal WikiText evaluation, the learned d128 search projections preserve full-attention perplexity under exact sparse substitution.
-
- What survives clean methodology:
-
- - Full-attention parity on the block-causal d128 pilot.
- - Strong teacher-attention mass recovery with learned projections.
- - Learned search projections recover more teacher attention mass than the Quest-style page heuristic at the same token budget on this slice.
- - The earlier negative PPL gaps from packed-with-leakage runs do **not** survive as a clean denoising headline.
-
- What is not established yet:
-
- - Wall-clock speedup. The current runtime is a correctness prototype.
- - Confidence intervals across seeds.
- - LongBench/RULER/needle downstream task quality.
- - Dynamic decode-mode index insertion.
- - Whole-model / all-layer substitution.
- - GPU-resident ANN or fused sparse-attention kernels.
-
- ## Base Model
-
- - Base model: `Qwen/Qwen3-4B-Instruct-2507`
- - Layers trained in pilot: `[4, 8, 12, 16, 20, 24]`
- - Clean recommended checkpoint: `checkpoints_block_d128/search_step_1000.pt`
- - Search dimension: `d_search=128`
- - Trainable parameters: 3.93M total, about 0.1% of the base model
- - Base model weights are **not** included here. These checkpoints contain only the learned search projection module and training metadata.
-
- ## Folder Guide
-
- Use `checkpoints_block_d128/` for current clean claims.
-
- | Folder | Meaning | Use for claims? |
- |---|---|---|
- | `checkpoints_block_d128/` | Clean packed block-causal d128 run and eval artifacts | Yes |
- | `checkpoints_packed_d64/` | Packed d64 leakage-confounded capacity run | Capacity history only |
- | `checkpoints_packed_d128/` | Packed d128 leakage-confounded capacity run | Capacity history only |
- | `checkpoints_packed_d256/` | Packed d256 leakage-confounded capacity run | Capacity history only |
- | `checkpoints_d64/` | Earlier unpacked d64 checkpoints | Debug/history |
- | `checkpoints/` | Original pilot checkpoint and compare JSON | Debug/history |
-
- The clean block-causal run fixed the core packing issue by assigning each packed document a `segment_id`, resetting `position_ids`, and supplying a 4D block-causal attention mask so tokens can only attend causally within their own document.
-
- ## Clean Block-Causal Result
-
- Command used for the clean d128 checkpoint:
```bash
python train.py --config pilot_d128_block
@@ -76,9 +134,7 @@ python k_sweep.py \
  --no-use-faiss
```
- Evaluation slice: 16 packed block-causal WikiText batches at 4K context.
-
- `PPL_full = 30.44`
| K | Recall@K | mass@K | PPL_ANN | PPL gap |
|---|---:|---:|---:|---:|
@@ -86,13 +142,14 @@ Evaluation slice: 16 packed block-causal WikiText batches at 4K context.
| 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| 512 | n/a | n/a | 30.45 | +0.01% |
- K=512 has no meaningful mass/recall average on this WikiText slice because almost no same-segment queries have 512 valid causal keys. The PPL value is still shown, but K=512 should not be used as a retrieval-quality point for this dataset slice.
-
- Interpretation: the clean result supports **quality-preserving sparse substitution**, not a claim that sparse attention improves over full attention.
-
- ## Clean Per-layer Retrieval at K=128
-
- From `checkpoints_block_d128/search_step_1000.compare_retrieval.json`:
| Layer | raw-QK oracle mass | learned d128 mass |
|---|---:|---:|
@@ -104,22 +161,19 @@ From `checkpoints_block_d128/search_step_1000.compare_retrieval.json`:
| 24 | 0.978 | 0.984 |
| avg | 0.969 | 0.973 |
- This changes the interpretation from the earlier leakage-confounded pilot. With segment isolation, early trained layers are not diffuse or uniquely hard. All six trained layers have high raw-QK oracle mass, and learned projections match or slightly exceed raw-QK retrieval across the tested set.
-
- The next deployment hypothesis is therefore: substitute all tested layers, then validate on a broader all-layer run.
-
- ## Quest-style Page Baseline
`quest_sweep.py` implements a Quest-style min/max page selector for comparison:
-
- - Page size: 16
- - Native post-RoPE Q/K min/max metadata
- - Same block-causal token eligibility mask
- - Same sparse-attention gather path
-
- This is a correctness baseline, not an optimized Quest runtime.
-
- Command:
```bash
python quest_sweep.py \
@@ -128,7 +182,7 @@ python quest_sweep.py \
  --page-size 16
```
- Same 16-batch clean block-causal eval slice:
| Method | K | Recall@K | mass@K | PPL | PPL gap |
|---|---:|---:|---:|---:|---:|
@@ -137,8 +191,12 @@ Same 16-batch clean block-causal eval slice:
| learned search exact | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| Quest-style page | 256 | 0.838 | 0.909 | 30.45 | +0.03% |
- Both methods are effectively full-attention parity on PPL. Learned projections recover more teacher attention mass at the same token budget, especially at K=128, but do not yet show a clean PPL advantage over Quest on this slice.
-
Paired 32-batch NLL evaluation gives a sharper comparison:
@@ -147,11 +205,18 @@ Paired 32-batch NLL evaluation gives a sharper comparison:
| 128 | 28.03 | 28.07 | 28.01 | +0.00205 `[+0.00160, +0.00251]` | Quest slightly better |
| 256 | 28.03 | 28.04 | 28.04 | -0.00005 `[-0.00029, +0.00018]` | statistical tie |
- So the current clean result is: learned search has higher teacher-attention mass, but PPL is either tied with Quest (K=256) or slightly worse (K=128) on this paired WikiText slice.
-
- ## Clean FAISS-vs-exact Check
-
- The first block-causal FAISS prototype used one global index followed by segment filtering, which produced pathological filler rates after filtering. The current FAISS path builds per-segment indexes when a 4D block-causal mask is present. With that fix, CPU FAISS/HNSW tracks exact learned search on the same 16-batch clean eval slice:
| Method | K | PPL | PPL gap | FAISS filler rate |
|---|---:|---:|---:|---:|
@@ -160,110 +225,268 @@ The first block-causal FAISS prototype used one global index followed by segment
| learned exact | 256 | 30.45 | +0.01% | n/a |
| learned FAISS/HNSW | 256 | 30.46 | +0.04% | 0.683 |
- The remaining filler rate is expected for short same-segment prefixes where fewer than K valid causal keys exist; filler slots are masked out of the sparse-attention softmax. This demonstrates off-the-shelf ANN compatibility in the clean block-causal setting, but not production wall-clock speedup.
-
- ## Packed Leakage-confounded Ablations
-
- The packed d64/d128/d256 runs are included because they are useful for understanding capacity scaling, but they should not be used for clean quality claims. Those runs allowed cross-document attention inside packed examples.
-
- Packed d_search ablation at K=128:
-
- | d_search | Params | learned mass@K=128 | raw-QK oracle | learned/oracle | final PPL gap |
- |---|---:|---:|---:|---:|---:|
- | 64 | 1.97M | 0.492 | 0.488 | 1.01x | +2.39% |
- | 128 | 3.93M | 0.503 | 0.488 | 1.03x | -1.81% |
- | 256 | 7.86M | 0.509 | 0.488 | 1.04x | -1.85% |
-
- The packed leakage-confounded K-sweep showed large negative PPL gaps:
-
- | K | Recall@K | mass@K | PPL_ANN | PPL gap |
- |---|---:|---:|---:|---:|
- | 128 | 0.166 | 0.256 | 203.63 | -9.36% |
- | 256 | 0.233 | 0.318 | 207.06 | -7.83% |
- | 512 | 0.339 | 0.409 | 211.93 | -5.66% |
-
- A second leaked packed slice preserved the shape: K=128 `-8.78%`, K=256 `-7.59%`, K=512 `-6.21%`. These numbers are retained for transparency and debugging history. They should not be reported as the headline because the clean block-causal rerun shows parity, not denoising.
-
- ## What the Checkpoints Contain
-
- Each `.pt` file is a PyTorch checkpoint with the learned search projection module and config metadata. The base LLM is loaded separately from Hugging Face.
-
- Example loading pattern from the source repo:
-
- ```python
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from config import Config
- from model import SearchProjectionModule
-
- ckpt = torch.load("checkpoints_block_d128/search_step_1000.pt", map_location="cpu", weights_only=False)
- ckpt_cfg = ckpt["config"]
-
- cfg = Config()
- for key, value in ckpt_cfg.items():
-     if hasattr(cfg, key):
-         setattr(cfg, key, value)
-
- base = AutoModelForCausalLM.from_pretrained(
-     cfg.base_model_name,
-     dtype=torch.bfloat16,
-     device_map="auto",
-     attn_implementation="sdpa",
- )
-
- layers = [
-     i for i in cfg.full_attention_layer_indices
-     if i not in cfg.reserved_full_attention_indices
- ]
- search = SearchProjectionModule(
-     d_model=base.config.hidden_size,
-     d_search=cfg.d_search,
-     layer_indices=layers,
-     use_mlp=cfg.use_mlp_proj,
- ).to(base.device).to(torch.bfloat16)
- search.load_state_dict(ckpt["search_module"])
- search.eval()
- ```
-
- See the GitHub repo for full eval scripts and monkey-patched sparse-attention wrappers.
-
- ## Runtime Caveat
-
- The current `inference.py` path is a correctness prototype:
-
- - Exact top-K path materializes dense `[B, L, L]` similarity and is for analysis.
- - FAISS/HNSW path builds a CPU index per forward pass and transfers data across CPU/GPU.
- - Gathered sparse attention still uses dense-style tensor expansion internally.
-
- Therefore, any FLOP/scoring reductions are algorithmic estimates, not measured wall-clock speedups. A deployable runtime needs GPU-resident retrieval and a fused sparse/paged attention kernel.
-
- ## Recommended Use
-
- Use this repo for:
-
- - Reproducing the clean d128 block-causal result.
- - Inspecting search projection checkpoints.
- - Comparing learned search retrieval against raw-QK and Quest-style page retrieval.
- - Building follow-up experiments such as dynamic-index insertion or all-layer substitution.
-
- Do not use this repo as:
-
- - A drop-in accelerated inference engine.
- - Evidence that sparse attention beats full attention on clean methodology.
- - A complete comparison against all sparse-attention baselines.
-
- ## Next Experiments
-
- The most important follow-ups are:
-
- 1. Dynamic-index demonstration during long generation.
- 2. Multi-seed confidence intervals for block-causal d128.
- 3. LongBench/RULER/needle task evaluation.
- 4. All-layer substitution run.
- 5. GPU-resident retrieval and decode-mode KV-cache integration.
-
- ## Citation / Attribution
-
- This is an in-progress research artifact. If you use it, cite the GitHub repo and this Hugging Face checkpoint repository.
-
- Source: https://github.com/unixsysdev/ann-sparseattention
+ # ann-sparseattention
+
+ Train tiny per-layer "search projections" on a frozen LLM that replicate the
+ attention's top-K preferences in a low-dimensional space, so we can swap dense
+ quadratic attention for an off-the-shelf ANN index (FAISS HNSW) at inference
+ and lose almost no model quality.
+
+ ## Current status
+
+ Research prototype. The trained projections work in a narrow 6-layer packed
+ WikiText-103 pilot on `Qwen/Qwen3-4B-Instruct-2507`, but the runtime is still
+ a correctness prototype. Treat reported numbers as preliminary until confidence
+ intervals, downstream long-context tasks, and real baselines are run.
+
+ Checkpoint artifacts and JSON eval outputs are mirrored on Hugging Face:
+ [`datasysdev/ann-sparseattention`](https://huggingface.co/datasysdev/ann-sparseattention).
+ Use `checkpoints_block_d128/search_step_1000.pt` there for the current clean
+ block-causal result.
+
+ **What's validated:**
+ - 6-layer packed pilot on Qwen3-4B-Instruct-2507, layers
+   `[4, 8, 12, 16, 20, 24]`, 4K context, 1K training steps.
+ - `d_search=128` is the current recommended capacity from the packed capacity
+   ablation: 3.93M trainable parameters, mass@K=128 of 0.503 vs 0.488 for the
+   raw-QK exact-topK oracle, and -1.81% relative PPL gap at K=128 on the packed
+   eval slice.
+ - Block-causal packed masking is implemented. On the clean block-causal d128
+   rerun, exact sparse attention is near parity with full attention
+   (K=128: +0.07% PPL gap; K=256: +0.01%). The large negative PPL gaps from
+   packed-with-leakage do not survive as a clean-methodology headline.
+ - Capacity scaling is monotonic but saturating: d64 < d128 < d256 on mass@K,
+   while d128 and d256 are effectively tied on final PPL.
+ - Learned projections outperform raw-QK oracle mass in mid/late trained layers
+   (L12-L24), while early layers remain harder.
+
+ **Not yet validated (next iteration):**
+ - Confidence intervals for the block-causal result over multiple seeds and
+   larger eval slices.
+ - Quest / RetrievalAttention baselines.
+ - Long-context task quality (LongBench, RULER, needle-in-haystack).
+ - 34-layer / whole-model substitution.
+ - Wall-clock speedup vs. FlashAttention/SDPA — not measured.
+ - KV-cache decode-mode integration.
+ - GPU-resident ANN or fused gather-attention kernel.
+
+ **Runtime caveat.** The current FAISS path is a correctness prototype: it
+ builds a CPU index per forward pass and uses dense-style tensor expansion
+ internally for the gather step. The compute-reduction numbers below are
+ **algorithmic scoring reductions, not measured wall-clock speedups.** A
+ production runtime requires a GPU-resident topk kernel or integration with
+ paged/block-sparse attention kernels.
+
+ ### d_search ablation (packed WikiText-103, K=128)
+
+ The packed ablation trains the same 6 layers for 1K steps and evaluates all
+ variants with the same packed eval pipeline. `raw_qk` is exact top-K over
+ head-mean-aggregated native post-RoPE Q/K vectors; `learned` is exact top-K
+ over trained search projections. mass@K is teacher-attention probability
+ captured by the retrieved set.
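+
+ A minimal sketch of how these two metrics can be computed for one query row
+ (illustrative only; `teacher_probs` and `retrieved_idx` are hypothetical names,
+ not the repo's identifiers):
+
+ ```python
+ import torch
+
+ def retrieval_metrics(teacher_probs: torch.Tensor, retrieved_idx: torch.Tensor, K: int):
+     """teacher_probs: [L] teacher attention row over valid causal keys (sums to 1).
+     retrieved_idx:   [K] key positions returned by the learned search projection."""
+     K = min(K, teacher_probs.numel())
+     # mass@K: teacher probability captured by the retrieved set
+     mass = teacher_probs[retrieved_idx].sum().item()
+     # recall@K: overlap with the teacher's own top-K keys
+     teacher_topk = teacher_probs.topk(K).indices
+     recall = torch.isin(retrieved_idx, teacher_topk).float().mean().item()
+     return mass, recall
+ ```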
+
+ | d_search | Params | learned mass@K=128 | raw-QK oracle | learned / oracle | Final PPL gap |
+ |---|---:|---:|---:|---:|---:|
+ | 64 | 1.97M | 0.492 | 0.488 | 1.01x | +2.39% |
+ | **128** | **3.93M** | **0.503** | **0.488** | **1.03x** | **-1.81%** |
+ | 256 | 7.86M | 0.509 | 0.488 | 1.04x | -1.85% |
+
+ d128 is the recommended default for this pilot: it captures almost all of the
+ d256 quality with half the trainable parameters. d256 improves mass@K slightly
+ but does not materially improve final PPL.
+
+ PPL gap is the primary model-quality signal; mass@K is the more direct
+ retrieval-quality signal when teacher attention is sharp. Recall@K is logged,
+ but it is a weaker proxy because disagreement on near-zero-probability tail
+ positions can look like low recall while preserving model output.
+
+ Per-layer mass@K=128 for d128:
+
+ | Layer | raw-QK oracle | learned d128 |
+ |---|---:|---:|
+ | 4 | 0.422 | 0.382 |
+ | 8 | 0.518 | 0.421 |
+ | 12 | 0.404 | 0.533 |
+ | 16 | 0.475 | 0.481 |
+ | 20 | 0.499 | 0.551 |
+ | 24 | 0.614 | 0.648 |
+
+ Early layers remain harder for learned retrieval; mid/late trained layers
+ exceed raw-QK oracle mass.
+
+ ### K-retrieve Pareto (packed d128, leakage-confounded)
+
+ Exact top-K sweep for the recommended packed d128 checkpoint:
+
+ ```bash
+ python k_sweep.py \
+   --ckpt /tmp/checkpoints_packed_d128/search_step_1000.pt \
+   --K 128,256,512 \
+   --no-use-faiss
+ ```
+
+ `PPL_full = 224.64` on this packed eval slice.
+
+ | K | Recall@K | mass@K | PPL_ANN | PPL gap |
+ |---|---:|---:|---:|---:|
+ | 128 | 0.166 | 0.256 | 203.63 | -9.36% |
+ | 256 | 0.233 | 0.318 | 207.06 | -7.83% |
+ | 512 | 0.339 | 0.409 | 211.93 | -5.66% |
+
+ This disambiguates the earlier FAISS high-K failure on the leaked packed
+ pipeline: exact retrieval remains strongly negative at K=256/512, so the
+ denoising pattern is present on this packed eval slice. This should not be
+ used as a publication-strength denoising claim because packed examples can
+ attend across document boundaries.
+
+ A second exact sweep on the next 16 packed eval batches (`--skip-batches 16`)
+ preserved the shape: K=128 -8.78%, K=256 -7.59%, K=512 -6.21%. This is still
+ not a substitute for confidence intervals, but it reduces the chance that the
+ large negative gap is a single-slice accident.
+
+ ### Block-causal packed d128 (clean masking)
+
+ Packed block-causal masking assigns each packed document a `segment_id`, resets
+ `position_ids` at segment boundaries, and supplies a 4D additive mask so tokens
+ can only attend causally within their own document. Retrieval, loss masking,
+ mass@K, and recall@K use the same segment-causal eligibility mask.
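+
+ A minimal sketch of what such a mask can look like, built from per-token
+ segment IDs (a hypothetical helper, not the repo's exact implementation;
+ `position_ids` resetting is handled separately in the dataloader):
+
+ ```python
+ import torch
+
+ def block_causal_mask(segment_ids: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
+     """segment_ids: [B, L] int tensor of per-token document ids in a packed batch.
+     Returns a [B, 1, L, L] additive mask: 0 where attention is allowed
+     (same document AND causal), a large negative value everywhere else."""
+     B, L = segment_ids.shape
+     same_doc = segment_ids[:, :, None] == segment_ids[:, None, :]            # [B, L, L]
+     causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=segment_ids.device))
+     allowed = same_doc & causal
+     mask = torch.zeros(B, 1, L, L, dtype=dtype, device=segment_ids.device)
+     return mask.masked_fill(~allowed[:, None], torch.finfo(dtype).min)
+ ```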
+
+ Clean d128 block-causal run:
+
```bash
python train.py --config pilot_d128_block

  --no-use-faiss
```
+
+ `PPL_full = 30.44` on the 16-batch clean eval slice.
+
| K | Recall@K | mass@K | PPL_ANN | PPL gap |
|---|---:|---:|---:|---:|
| 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| 512 | n/a | n/a | 30.45 | +0.01% |
+
+ K=512 has no meaningful mass/recall average on this WikiText slice because
+ almost no same-segment queries have 512 valid causal keys. The quality result
+ is still useful: with filler slots masked out of the sparse-attention softmax,
+ the block-causal exact path is effectively at full-attention parity. The clean
+ result supports "quality-preserving sparse substitution" rather than the leaked
+ pipeline's stronger denoising claim.
+
+ Clean block-causal per-layer `compare_retrieval` at K=128:
+
| Layer | raw-QK oracle mass | learned d128 mass |
|---|---:|---:|
| 24 | 0.978 | 0.984 |
| avg | 0.969 | 0.973 |
+
+ This changes the per-layer interpretation from the leakage-confounded pilot:
+ with segment isolation, early trained layers are not diffuse or uniquely hard.
+ All six trained layers have high oracle mass, and learned projections match or
+ slightly exceed raw-QK retrieval across the set. The deployment hypothesis for
+ the next run is therefore "substitute all tested layers" rather than "keep early
+ layers as full attention," pending a broader all-layer run.
+
+ ### Quest-style page baseline (clean block-causal)
+
`quest_sweep.py` implements a Quest-style min/max page selector for comparison:
+ page size 16, native post-RoPE Q/K, same block-causal token eligibility mask,
+ and the same sparse-attention gather path. This is a correctness baseline, not
+ an optimized Quest runtime.
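+
+ For reference, a minimal sketch of the min/max page-scoring rule this baseline
+ approximates (hypothetical names and single-query shapes; `quest_sweep.py` is
+ the actual implementation):
+
+ ```python
+ import torch
+
+ def quest_page_scores(q: torch.Tensor, keys: torch.Tensor, page_size: int = 16) -> torch.Tensor:
+     """q: [d_head] query; keys: [L, d_head] keys visible to this query.
+     Scores each page with Quest-style min/max metadata: an upper bound on
+     q·k for any key in the page. Top-scoring pages are then gathered until
+     the K-token budget is filled (not shown)."""
+     L, d = keys.shape
+     n_pages = (L + page_size - 1) // page_size
+     pad = n_pages * page_size - L
+     if pad:
+         keys = torch.cat([keys, keys.new_full((pad, d), float("nan"))])
+     pages = keys.view(n_pages, page_size, d)
+     k_min = pages.nan_to_num(float("inf")).amin(dim=1)     # padding never wins the min
+     k_max = pages.nan_to_num(float("-inf")).amax(dim=1)    # padding never wins the max
+     return torch.maximum(q * k_min, q * k_max).sum(dim=-1)  # [n_pages]
+ ```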
 
 
 
 
 
 
```bash
python quest_sweep.py \

  --page-size 16
```
+
+ On the same 16-batch block-causal eval slice:
+
| Method | K | Recall@K | mass@K | PPL | PPL gap |
|---|---:|---:|---:|---:|---:|
| learned search exact | 256 | 0.879 | 0.953 | 30.45 | +0.01% |
| Quest-style page | 256 | 0.838 | 0.909 | 30.45 | +0.03% |
+
+ Both methods are effectively full-attention parity on PPL. The learned search
+ space recovers more teacher attention mass at the same token budget, especially
+ at K=128, while Quest remains a strong non-trained heuristic baseline. This
+ keeps the contribution narrow: learned projections improve retrieval fidelity
+ and support standard ANN indexing; they do not yet show a clean PPL advantage
+ over Quest on this slice.
+
Paired 32-batch NLL evaluation gives a sharper comparison:
| 128 | 28.03 | 28.07 | 28.01 | +0.00205 `[+0.00160, +0.00251]` | Quest slightly better |
| 256 | 28.03 | 28.04 | 28.04 | -0.00005 `[-0.00029, +0.00018]` | statistical tie |
+
+ So the current clean result is: learned search has higher teacher-attention
+ mass, but PPL is either tied with Quest (K=256) or slightly worse (K=128) on
+ this paired WikiText slice. The paper claim should be "retrieval-fidelity and
+ ANN-compatibility advantages," not "PPL advantage over Quest."
+
+ ### Clean FAISS-vs-exact check
+
+ The first block-causal FAISS prototype used one global index followed by
+ segment filtering, which produced pathological filler rates after filtering.
+ The current FAISS path builds per-segment indexes when a 4D block-causal mask
+ is present. With that fix, CPU FAISS/HNSW tracks exact learned search on the
+ same 16-batch clean eval slice:
+
| Method | K | PPL | PPL gap | FAISS filler rate |
|---|---:|---:|---:|---:|
| learned exact | 256 | 30.45 | +0.01% | n/a |
| learned FAISS/HNSW | 256 | 30.46 | +0.04% | 0.683 |
+
+ The remaining filler rate is expected for short same-segment prefixes where
+ fewer than K valid causal keys exist; filler slots are masked out of the sparse
+ attention softmax. This demonstrates off-the-shelf ANN compatibility in the
+ clean block-causal setting, but not production wall-clock speedup.
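+
+ A minimal sketch of the per-segment retrieval pattern described above, assuming
+ FAISS HNSW with an inner-product metric (a hypothetical helper, not the repo's
+ exact code):
+
+ ```python
+ import faiss
+ import numpy as np
+
+ def topk_per_segment(q_search, k_search, segment_ids, K, M=32, ef_search=64):
+     """q_search, k_search: [L, d_search] float32 arrays; segment_ids: [L] ints.
+     Builds one HNSW index per packed document and returns [L, K] key positions,
+     with -1 as filler when a segment has fewer than K keys. Causal filtering of
+     the returned positions (key pos <= query pos) happens downstream."""
+     L, d = k_search.shape
+     out = np.full((L, K), -1, dtype=np.int64)
+     for seg in np.unique(segment_ids):
+         pos = np.where(segment_ids == seg)[0]
+         index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
+         index.hnsw.efSearch = ef_search
+         index.add(np.ascontiguousarray(k_search[pos]))
+         _, idx = index.search(np.ascontiguousarray(q_search[pos]), K)  # -1 = missing
+         found = idx >= 0
+         out[pos] = np.where(found, pos[np.clip(idx, 0, None)], -1)
+     return out
+ ```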
+
+ ### Asymptotic scoring analysis
+
+ `artifacts/scaling_analysis.md` gives a deterministic operation-count proxy
+ for the per-query candidate scoring step. This is the cost of identifying
+ which keys to attend to, before the sparse attention softmax and value
+ multiply over the selected keys.
+
+ Assumptions:
+
+ - Full attention scoring: `N * d_head = N * 128`.
+ - Quest-style page scoring: `(N / page_size) * 2 * d_head = N * 16`
+   with `page_size=16`.
+ - Learned HNSW scoring: `M * ef_search * log2(N) * d_search`
+   with `M=32`, `ef_search=64`, and `d_search=128`.
+
+ ![Candidate-scoring operations per query](artifacts/scaling_plot.svg)
+
+ | Context | Full ops/query | Quest ops/query | Learned HNSW ops/query | Quest / learned |
+ |---:|---:|---:|---:|---:|
+ | 4K | 512,000 | 64,000 | 3,136,759 | 0.02x |
+ | 8K | 1,024,000 | 128,000 | 3,398,903 | 0.04x |
+ | 16K | 2,048,000 | 256,000 | 3,661,047 | 0.07x |
+ | 32K | 4,096,000 | 512,000 | 3,923,191 | 0.13x |
+ | 64K | 8,192,000 | 1,024,000 | 4,185,335 | 0.24x |
+ | 128K | 16,384,000 | 2,048,000 | 4,447,479 | 0.46x |
+ | 256K | 32,768,000 | 4,096,000 | 4,709,623 | 0.87x |
+ | 512K | 65,536,000 | 8,192,000 | 4,971,767 | 1.65x |
+ | 1M | 128,000,000 | 16,000,000 | 5,224,942 | 3.06x |
+ | 2M | 256,000,000 | 32,000,000 | 5,487,086 | 5.83x |
+ | 4M | 512,000,000 | 64,000,000 | 5,749,230 | 11.13x |
+
+ Under these conservative HNSW constants, Quest is cheaper below the
+ few-hundred-thousand-token regime and learned-projection scoring becomes
+ cheaper beyond roughly 300K tokens. At 1M context, the operation-count proxy is
+ about 3x in favor of learned projections. This supports the theoretical
+ scaling claim only; production speed claims still require GPU-resident
+ retrieval and KV-cache/decode integration.
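+
+ The table can be reproduced in a few lines under the same assumptions (context
+ sizes are read as decimal, e.g. 4K = 4,000 tokens, which is what the numbers
+ above imply):
+
+ ```python
+ import math
+
+ D_HEAD = D_SEARCH = 128
+ PAGE, M, EF = 16, 32, 64
+
+ def ops_per_query(n: int):
+     full = n * D_HEAD                            # score every causal key
+     quest = (n / PAGE) * 2 * D_HEAD              # score min/max metadata per page
+     hnsw = M * EF * math.log2(n) * D_SEARCH      # graph traversal in the search space
+     return full, quest, hnsw
+
+ for label, n in [("4K", 4_000), ("256K", 256_000), ("1M", 1_000_000), ("4M", 4_000_000)]:
+     full, quest, hnsw = ops_per_query(n)
+     print(f"{label:>4}  full={full:>12,.0f}  quest={quest:>12,.0f}  "
+           f"hnsw={hnsw:>12,.0f}  quest/hnsw={quest / hnsw:.2f}x")
+ ```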
+
+ ### Compute / quality knobs (FLOP-counted)
+
+ `L = 4096`. Compute reduction is the attention scoring step, `≈ L / K`.
+ These are FLOP estimates, not measured wall-clock — the FAISS path in this
+ repo is a research prototype that does CPU index builds and GPU↔CPU
+ transfers, so it is not the right thing to time. A GPU-resident topk
+ kernel is the natural next step.
+
+ | K | PPL gap | Attention scoring reduction |
+ |---|---|---|
+ | 512 | -5.66% (exact top-K over learned search space) | ~8x |
+ | 256 | -7.83% (exact top-K over learned search space) | ~16x |
+ | 128 | -9.36% exact; -1.81% FAISS/training eval | ~32x |
+ | 64 | +0.46% | ~64x |
+ | 32 | +0.03% | ~128x |
+ | 16 | +5.63% | ~256x |
+
+ Eval scope for the d_search table: 16 packed validation batches at 4K context
+ for PPL/recall during training, and 12 packed batches for `compare_retrieval`
+ mass@K. Numbers should be read as "what we observed on this slice", not
+ population-level estimates.
+
+ ### Caveats / what's next
+
+ A few things the pilot does not yet establish, and that the next iteration
+ will:
+
+ - **Packing**: the d_search ablation table is still from the packed
+   leakage-confounded run and is best read as a capacity comparison. The clean
+   block-causal d128 rerun removes cross-document leakage and should be used for
+   quality claims.
+ - **Exact-topK oracle**: the obvious follow-up is a four-way Pareto —
+   full attention vs. exact top-K (true `QK^T` argmax-K, then attention) vs.
+   search-topK (our projections, exact distance) vs. search-ANN (FAISS HNSW).
+   That separates "denoising from any sparsity" from "denoising from learned
+   projections."
+ - **Wall-clock**: the compute-reduction table above is FLOP-counted. The
+   FAISS path here is a research prototype (CPU index per forward, GPU↔CPU
+   transfer) and is the wrong thing to time. A GPU-resident topk kernel is
+   the next-step engineering.
+ - **34-layer headline**: was queued and the VM was reclaimed before launch.
+   Config is wired (`make_headline_config()`); rerun is a single command on
+   any B200/H100/H200.
+
+ The recall@K and mass@K reported here come from a 12-batch eval slice, not
+ a population-level estimate. Confidence intervals and downstream tasks
+ (LongBench / RULER / needle-in-haystack) are the natural next evals.
+
+ ### Headline run (queued)
+
+ 34 layers (every layer except 0 and 35), 8K context, 6K steps,
+ ~4-5h on a single B200. Tests whether the technique generalizes from a
+ 6-layer subset to broad layer coverage. Checkpoints will be mirrored at
+ [`datasysdev/ann-sparseattention`](https://huggingface.co/datasysdev/ann-sparseattention).
+
+ ## Relation to RetrievalAttention
+
+ The closest prior work is RetrievalAttention (Liu et al., 2024). They show
+ that **vanilla ANN over the model's native Q and K vectors fails** because
+ Q and K live in mismatched distributions — they were never trained to be
+ each other's nearest neighbors, only to score correctly via the dot
+ product. Their fix is at *index-construction time*: an attention-aware graph
+ (RoarGraph-style) that compensates for the Q/K out-of-distribution problem
+ when the index is searched.
+
+ This work attacks the same problem from the opposite direction. Instead of
+ patching the index over hostile vectors, we **train a tiny shared
+ low-dimensional projection** (`W_Qs, W_Ks → R^128` in the recommended pilot)
+ so that `q_search` and `k_search` *do* live in the same distribution by
+ construction. Off-the-shelf FAISS HNSW with default parameters is then
+ sufficient — there is no attention-aware index trick.
+
+ | | Search space | Index | Trainable |
+ |---|---|---|---|
+ | Raw Q/K + vanilla ANN | original Q/K | off-the-shelf | no — fails (Q/K OOD) |
+ | RetrievalAttention | original Q/K | attention-aware graph | no |
+ | **This work** | **learned Q\_s / K\_s** | **off-the-shelf** | **yes (~2-11M params)** |
+
+ The contribution claim: *eliminate the Q/K mismatch at index-build time
+ via distillation, instead of patching it at search time.* The clean
+ experiment to validate this — vanilla FAISS over raw Q/K vs. vanilla
+ FAISS over learned Q\_s/K\_s vs. exact teacher top-K, all at the same K —
+ is the next planned run. The current pilot establishes that the learned
+ projections retrieve attention-relevant keys; the comparison run isolates
+ how much of that came from the projection vs. the ANN approximation.
+
+ ## How it works
+
+ For each full-attention layer `i` we train two linear projections
+ `W_Qs^i, W_Ks^i : R^{d_model} → R^{d_search}` (recommended pilot: d_search=128),
+ so that for any hidden state `h`,
+
+ ```
+ q_search = W_Qs^i h
+ k_search = W_Ks^i h
+ softmax(q_search · k_search^T) ranks the same keys as
+ softmax(Q K^T / √d_head)   (the teacher's attention)
+ ```
+
+ Two losses, summed across layers:
+
+ - **InfoNCE** with teacher-derived positives (top-`K_pos` keys from the
+   teacher's attention serve as positives for each query).
+ - **KL(teacher ‖ student)** on the full attention distribution.
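+
+ A minimal sketch of how the two terms can be combined for a single layer and a
+ single sequence (hypothetical names and shapes; the repo's losses live in
+ `model.py`):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def distill_losses(teacher_attn, q_s, k_s, eligible, k_pos=8, tau=0.07):
+     """teacher_attn: [L, L] teacher attention probabilities (block-causal rows).
+     q_s, k_s: [L, d_search] search projections. eligible: [L, L] bool mask."""
+     logits = (q_s @ k_s.T) / tau
+     logits = logits.masked_fill(~eligible, float("-inf"))
+     log_p = F.log_softmax(logits, dim=-1)                       # student log-probs
+
+     # KL(teacher || student), ignoring ineligible positions
+     kl_elem = teacher_attn * (teacher_attn.clamp_min(1e-9).log() - log_p)
+     kl = kl_elem.masked_fill(~eligible, 0.0).sum(-1).mean()
+
+     # InfoNCE with teacher-derived positives: top-k_pos teacher keys per query
+     pos = teacher_attn.topk(k_pos, dim=-1).indices              # [L, k_pos]
+     pos_logp = log_p.gather(-1, pos)
+     valid = teacher_attn.gather(-1, pos) > 0                    # guard short prefixes
+     infonce = -(pos_logp.masked_fill(~valid, 0.0).sum(-1)
+                 / valid.sum(-1).clamp_min(1)).mean()
+     return infonce + kl, infonce, kl
+ ```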
+
+ At inference, we monkey-patch each trained layer's attention forward to:
+
+ 1. Compute `q_search`, `k_search` from the same hidden state.
+ 2. Build a per-batch FAISS HNSW index over `k_search` (default params).
+ 3. Retrieve top-`K_retrieve` positions (causal-respecting) per query.
+ 4. Run standard attention restricted to those `K_retrieve` keys.
+
+ The base model's parameters are never touched. The recommended d128 pilot
+ trains 3.93M parameters total.
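+
+ A minimal sketch of the substituted attention for one layer and one sequence,
+ using the exact-top-K analysis path rather than FAISS (hypothetical names and
+ plain MHA shapes for clarity; the repo's version lives in `inference.py`):
+
+ ```python
+ import torch
+
+ @torch.no_grad()
+ def sparse_attention_exact(q, k, v, q_s, k_s, eligible, K):
+     """q, k, v: [L, H, d_head]; q_s, k_s: [L, d_search];
+     eligible: [L, L] bool block-causal mask. Exact top-K in the learned search
+     space, then standard attention over only the selected keys."""
+     L = q.shape[0]
+     sim = (q_s @ k_s.T).masked_fill(~eligible, float("-inf"))   # [L, L]
+     idx = sim.topk(min(K, L), dim=-1).indices                   # [L, K] selected keys
+     keep = eligible.gather(-1, idx)                             # filler slots -> False
+     k_sel, v_sel = k[idx], v[idx]                               # [L, K, H, d_head]
+     scores = torch.einsum("lhd,lkhd->lhk", q, k_sel) / q.shape[-1] ** 0.5
+     scores = scores.masked_fill(~keep[:, None, :], float("-inf"))
+     return torch.einsum("lhk,lkhd->lhd", scores.softmax(dim=-1), v_sel)
+ ```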
 
 
 
+ ## Repo layout
+
+ ```
+ config.py     Run config (pilot defaults; make_headline_config() for follow-up)
+ model.py      SearchProjection, FrozenForwardCapture (with QK reconstruction
+               trick: capture (Q, K) post-RoPE while the forward stays in FA),
+               contrastive + KL distillation losses
+ data.py       Long-context dataloader (packing off by default to avoid
+               cross-segment attention leakage; pin_memory, prefetch)
+ inference.py  ANN-substituted attention (exact top-K for analysis;
+               CPU-FAISS HNSW prototype path — not a deployable kernel)
+ eval.py       recall@K curve, mass@K curve, full-vs-ANN PPL,
+               MoE router stability
+ train.py      Training loop, Liger setup, FA-3→FA-2→SDPA→eager fallback,
+               base-model freeze + drift check, auto-resume from latest ckpt
+ tests/        QK reconstruction verification + 50-step smoke test
+ ```
+
+ ## Quick start
+
+ ```bash
+ pip install -r requirements.txt
+ export WANDB_API_KEY=<key>   # only — never check it in
+ export HF_TOKEN=<token>      # for faster Hub downloads
+
+ # Pre-launch checks
+ python -c "from transformers import AutoConfig; \
+   print(AutoConfig.from_pretrained('Qwen/Qwen3-4B-Instruct-2507'))"
+ python tests/test_qk_reconstruction.py
+ python tests/smoke_test.py
+
+ # Packed d_search ablation
+ bash scripts/run_packed_ablation.sh
+
+ # Default clean pilot (packing off; data-sparse on WikiText articles)
+ python train.py --config pilot_d64_clean
+ ```
+
+ ## Configuration
+
+ The default `Config` is the 1-day pilot:
+
+ | Knob | Pilot | Headline |
+ |---|---|---|
+ | `seq_len` | 4096 | 8192 |
+ | `batch_size` | 8 | 8 |
+ | `total_steps` | 1000 | 6000 |
+ | layers trained | 6 (`[4,8,12,16,20,24]`) | 34 (`range(36)` minus reserved `[0, 35]`) |
+ | trainable params | 1.97M at d64; 3.93M at d128 | 11.1M at d64 |
+ | `d_search` | 64 default; d128 recommended from ablation | 64 default |
+ | `K_retrieve_eval` | 128 | 128 |
+
+ Pilot is the proof-of-concept; headline trains every attention layer except
+ the first (raw-embedding-adjacent) and last (output-logits-adjacent), which is
+ the deployment-relevant test of whether the technique scales to near-full
+ layer coverage.
+
+ Use `make_pilot_d128_packed_config()` to reproduce the current recommended
+ pilot, or `make_headline_config()` for the broader 34-layer run.
+ pilot, or `make_headline_config()` for the broader 34-layer run.
444
+
445
+ ## Performance choices
446
+
447
+ - `attn_implementation` resolves at load time as
448
+ `flash_attention_3 → flash_attention_2 → sdpa → eager`. On B200 with no
449
+ flash-attn package installed, SDPA wins — its built-in flash backend is
450
+ ~80-90% of FA-2's throughput with zero build dependency.
451
+ - Liger kernels applied via `apply_liger_kernel_to_qwen3` (RMSNorm, SwiGLU,
452
+ RoPE fused — typically 30-50% faster forward).
453
+ - The QK-reconstruction trick keeps SDPA/FA fast on the trained layers:
454
+ we monkey-patch them to capture `(Q, K)` post-RoPE, then reconstruct
455
+ `softmax(QK^T/√d)` ourselves *after* the forward returns. The forward
456
+ never sets `output_attentions=True` (which would force eager).
457
+ - `torch.compile(search_module, mode="max-autotune")` on the search
458
+ projections; base model uncompiled (works but flaky for novel architectures).
459
+ - bf16 throughout; loss math cast to fp32 for numerical stability of softmax.
460
+
461
+ ## Verifying the QK reconstruction
462
+
463
+ The post-RoPE Q/K capture must match what the model's eager attention computes
464
+ or distillation supervision is wrong. The test asserts top-32 agreement
465
+ > 99% per layer:
466
 
467
+ ```bash
468
+ python tests/test_qk_reconstruction.py --model Qwen/Qwen3-4B-Instruct-2507
469
+ # layer 0: PASS max|Δ|=2.54e-02 top-32 agree=0.9963
470
+ # layer 1: PASS max|Δ|=5.27e-02 top-32 agree=0.9941
471
+ # ...
472
+ # QK reconstruction verified.
473
+ ```
474
 
475
+ The bf16 max-abs differences (~0.05) are just numerical noise; the
476
+ *ranking* of attention positions matches.
477
 
478
+ ## Reproducing the pilot
479
 
480
+ ```bash
481
+ git clone git@github.com:unixsysdev/ann-sparseattention.git
482
+ cd ann-sparseattention
483
+ pip install -r requirements.txt
484
+ python train.py --config pilot_d128_packed
485
+ ```
486
 
487
+ A single H100/H200/B200 + 8GB GPU RAM for the 4B model + ~10GB for activations
488
+ at 4K context, batch 8.
489
 
490
+ ## License
491
 
492
+ MIT.