Andrew Young committed
Commit 8ef2d83 · verified · 1 Parent(s): b038005

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +13 -0
  2. .gitignore +49 -0
  3. Cargo.toml +64 -0
  4. LICENSE +21 -0
  5. README.md +342 -0
  6. benchmarks/README.md +221 -0
  7. benchmarks/results/benchmark_results_20260110_181653.txt +179 -0
  8. benchmarks/run_all_benchmarks.sh +222 -0
  9. examples/demo_hat_memory.py +478 -0
  10. images/fig01_architecture.jpg +3 -0
  11. images/fig02_recall_comparison.jpg +3 -0
  12. images/fig03_build_time.jpg +3 -0
  13. images/fig04_pipeline.jpg +3 -0
  14. images/fig05_hippocampus.jpg +3 -0
  15. images/fig06_hat_vs_rag.jpg +3 -0
  16. images/fig07_scale_performance.jpg +3 -0
  17. images/fig08_consolidation.jpg +3 -0
  18. images/fig09_summary_results.jpg +3 -0
  19. images/fig10_beam_search.jpg +3 -0
  20. paper/HAT_paper_complete.md +439 -0
  21. paper/figures/fig1_recall_comparison.png +0 -0
  22. paper/figures/fig2_build_time.png +0 -0
  23. paper/figures/fig3_latency_scale.png +3 -0
  24. paper/figures/fig4_architecture.png +3 -0
  25. paper/figures/fig5_memory_breakdown.png +0 -0
  26. paper/figures/fig6_recall_by_k.png +0 -0
  27. paper/figures/fig7_embedding_dims.png +3 -0
  28. pyproject.toml +45 -0
  29. python/arms_hat/__init__.py +46 -0
  30. python/tests/test_hat_index.py +296 -0
  31. src/adapters/attention.rs +789 -0
  32. src/adapters/index/consolidation.rs +576 -0
  33. src/adapters/index/flat.rs +278 -0
  34. src/adapters/index/hat.rs +1953 -0
  35. src/adapters/index/learnable_routing.rs +528 -0
  36. src/adapters/index/mod.rs +45 -0
  37. src/adapters/index/persistence.rs +442 -0
  38. src/adapters/index/subspace.rs +640 -0
  39. src/adapters/mod.rs +19 -0
  40. src/adapters/python.rs +502 -0
  41. src/adapters/storage/memory.rs +253 -0
  42. src/adapters/storage/mod.rs +15 -0
  43. src/core/blob.rs +152 -0
  44. src/core/config.rs +177 -0
  45. src/core/id.rs +169 -0
  46. src/core/merge.rs +335 -0
  47. src/core/mod.rs +64 -0
  48. src/core/point.rs +186 -0
  49. src/core/proximity.rs +261 -0
  50. src/engine/arms.rs +335 -0
.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ images/fig01_architecture.jpg filter=lfs diff=lfs merge=lfs -text
+ images/fig02_recall_comparison.jpg filter=lfs diff=lfs merge=lfs -text
+ images/fig03_build_time.jpg filter=lfs diff=lfs merge=lfs -text
+ images/fig04_pipeline.jpg filter=lfs diff=lfs merge=lfs -text
+ images/fig05_hippocampus.jpg filter=lfs diff=lfs merge=lfs -text
+ images/fig06_hat_vs_rag.jpg filter=lfs diff=lfs merge=lfs -text
+ images/fig07_scale_performance.jpg filter=lfs diff=lfs merge=lfs -text
+ images/fig08_consolidation.jpg filter=lfs diff=lfs merge=lfs -text
+ images/fig09_summary_results.jpg filter=lfs diff=lfs merge=lfs -text
+ images/fig10_beam_search.jpg filter=lfs diff=lfs merge=lfs -text
+ paper/figures/fig3_latency_scale.png filter=lfs diff=lfs merge=lfs -text
+ paper/figures/fig4_architecture.png filter=lfs diff=lfs merge=lfs -text
+ paper/figures/fig7_embedding_dims.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,49 @@
+ # Build artifacts
+ /target/
+ *.so
+ *.dylib
+ *.dll
+
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.egg-info/
+ .eggs/
+ dist/
+ build/
+ *.egg
+ .venv/
+ venv/
+ ENV/
+ env/
+ paper_venv/
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+ *~
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ # Test artifacts
+ .pytest_cache/
+ .coverage
+ htmlcov/
+ .tox/
+
+ # Rust
+ Cargo.lock
+
+ # Local development
+ .env
+ .env.local
+ *.local
+
+ # Benchmark outputs
+ *.bench
+ benchmarks/output/
Cargo.toml ADDED
@@ -0,0 +1,64 @@
+ [package]
+ name = "arms-hat"
+ version = "0.1.0"
+ edition = "2021"
+ authors = ["Automate Capture LLC <research@automate-capture.com>"]
+ description = "Hierarchical Attention Tree: 100% recall at 70x faster build times than HNSW. A new database paradigm for AI memory and hierarchical semantic search."
+ license = "MIT"
+ repository = "https://github.com/automate-capture/hat"
+ homepage = "https://research.automate-capture.com/hat"
+ documentation = "https://docs.rs/arms-hat"
+ readme = "README.md"
+ keywords = ["vector-database", "semantic-search", "llm", "embeddings", "hnsw"]
+ categories = ["database", "science", "algorithms"]
+ exclude = [
+     "target/",
+     "src/target/",
+     ".venv/",
+     ".git/",
+     ".claude/",
+     "paper/",
+     "images/",
+     "python/",
+     "benchmarks/",
+     ".env",
+ ]
+
+ [lib]
+ name = "arms_hat"
+ path = "src/lib.rs"
+ crate-type = ["cdylib", "rlib"]  # cdylib for Python, rlib for Rust
+
+ [dependencies]
+ # Core - minimal dependencies for pure logic
+ thiserror = "1.0"  # Error handling
+
+ # Python bindings
+ pyo3 = { version = "0.22", features = ["extension-module"], optional = true }
+
+ # Future adapters:
+ # parking_lot = "0.12"  # Fast locks for concurrent access
+ # memmap2 = "0.9"       # Memory-mapped files for NVMe
+
+ [dev-dependencies]
+ criterion = "0.5"  # Benchmarking
+ rusqlite = { version = "0.31", features = ["bundled"] }  # Benchmark DB (bundled = no system sqlite needed)
+ serde = { version = "1.0", features = ["derive"] }
+ serde_json = "1.0"
+ hnsw = "0.11"       # HNSW implementation for comparison benchmarks
+ rand = "0.8"        # Random data generation for benchmarks
+ rand_distr = "0.4"  # Statistical distributions for realistic embeddings
+ space = "0.17"      # Distance metrics for hnsw
+
+ [features]
+ default = []
+ python = ["pyo3"]  # Enable Python bindings
+
+ # [[bench]]
+ # name = "proximity"
+ # harness = false
+
+ [profile.release]
+ lto = true
+ codegen-units = 1
+ panic = "abort"
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Automate Capture, LLC
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,342 @@
+ # HAT: Hierarchical Attention Tree
+
+ **A novel index structure for AI memory systems that achieves 100% recall at 70x faster build times than HNSW.**
+
+ **Also: A new database paradigm for any domain with known hierarchy + semantic similarity.**
+
+ [![PyPI](https://img.shields.io/pypi/v/arms-hat.svg)](https://pypi.org/project/arms-hat/)
+ [![crates.io](https://img.shields.io/crates/v/arms-hat.svg)](https://crates.io/crates/arms-hat)
+ [![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
+ [![Rust](https://img.shields.io/badge/Rust-1.70+-orange.svg)](https://www.rust-lang.org/)
+ [![Python](https://img.shields.io/badge/Python-3.8+-blue.svg)](https://www.python.org/)
+
+ ---
+
+ ## Architecture
+
+ <p align="center">
+   <img src="images/fig01_architecture.jpg" alt="HAT Architecture" width="800"/>
+ </p>
+
+ HAT exploits the **known hierarchy** in AI conversations: sessions contain documents, documents contain chunks. This structural prior enables O(log n) queries with 100% recall.
+
+ ---
+
+ ## Key Results
+
+ <p align="center">
+   <img src="images/fig09_summary_results.jpg" alt="Summary Results" width="800"/>
+ </p>
+
+ | Metric | HAT | HNSW | Improvement |
+ |--------|-----|------|-------------|
+ | **Recall@10** | **100%** | 70% | +30% |
+ | **Build Time** | 30ms | 2.1s | **70x faster** |
+ | **Query Latency** | 3.1ms | - | Production-ready |
+
+ *Benchmarked on hierarchically-structured AI conversation data*
+
+ ---
+
+ ## Recall Comparison
+
+ <p align="center">
+   <img src="images/fig02_recall_comparison.jpg" alt="HAT vs HNSW Recall" width="700"/>
+ </p>
+
+ HAT achieves **100% recall** where HNSW achieves only ~70% on hierarchically-structured data.
+
+ ---
+
+ ## Build Time
+
+ <p align="center">
+   <img src="images/fig03_build_time.jpg" alt="Build Time Comparison" width="700"/>
+ </p>
+
+ HAT builds indexes **70x faster** than HNSW - critical for real-time applications.
+
+ ---
+
+ ## The Problem
+
+ Large language models have finite context windows. A 10K context model can only "see" the most recent 10K tokens, losing access to earlier conversation history.
+
+ **Current solutions fall short:**
+ - Longer context models: Expensive to train and run
+ - Summarization: Lossy compression that discards detail
+ - RAG retrieval: Re-embeds and recomputes attention every query
+
+ ## The HAT Solution
+
+ <p align="center">
+   <img src="images/fig06_hat_vs_rag.jpg" alt="HAT vs RAG" width="800"/>
+ </p>
+
+ HAT exploits **known structure** in AI workloads. Unlike general vector databases that treat data as unstructured point clouds, AI conversations have inherent hierarchy:
+
+ ```
+ Session (conversation boundary)
+ └── Document (topic or turn)
+     └── Chunk (individual message)
+ ```
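+
+ As a minimal sketch of how this maps onto the Python API (shown fully in the Quick Start below), the hierarchy is declared at write time rather than discovered. Here `embed()` is a placeholder for any 384-dim embedding function:
+
+ ```python
+ from arms_hat import HatIndex
+
+ index = HatIndex.cosine(384)      # one index per memory store
+
+ index.new_session()               # open a conversation boundary
+ index.new_document()              # open a topic within that session
+ msg_id = index.add(embed("How do qubits work?"))  # chunk = one message
+
+ index.new_document()              # new topic, same session
+ index.add(embed("What's a good carbonara recipe?"))
+
+ hits = index.near(embed("quantum entanglement"), k=5)
+ ```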
+
+ ### The Hippocampus Analogy
+
+ <p align="center">
+   <img src="images/fig05_hippocampus.jpg" alt="Hippocampus Analogy" width="800"/>
+ </p>
+
+ HAT mirrors human memory architecture - functioning as an **artificial hippocampus** for AI systems.
+
+ ---
+
+ ## How It Works
+
+ ### Beam Search Query
+
+ <p align="center">
+   <img src="images/fig10_beam_search.jpg" alt="Beam Search" width="800"/>
+ </p>
+
+ HAT uses beam search through the hierarchy:
+
+ ```
+ 1. Start at root
+ 2. At each level, score children by cosine similarity to query
+ 3. Keep top-b candidates (beam width)
+ 4. Return top-k from leaf level
+ ```
+
+ **Complexity:** O(b · d · c), where b is the beam width, d the tree depth, and c the fan-out per node. With b and c bounded, a balanced tree has d = O(log n), so queries run in O(log n).
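+
+ A compact sketch of that walk (illustrative only, not the crate's internals; `children` and `vector` are assumed node fields, and the tree is assumed to follow the fixed session → document → chunk layout):
+
+ ```python
+ import heapq
+
+ def cosine(a, b):
+     """Cosine similarity for unit-normalized vectors."""
+     return sum(x * y for x, y in zip(a, b))
+
+ def beam_search(root, query, k=10, beam_width=4, levels=2):
+     """Descend two container levels (session, document), then rank chunks."""
+     frontier = [root]
+     for _ in range(levels):
+         # Score every child of the surviving containers; keep the top-b.
+         candidates = [c for node in frontier for c in node.children]
+         frontier = heapq.nlargest(beam_width, candidates,
+                                   key=lambda c: cosine(c.vector, query))
+     # Frontier now holds documents; their children are chunks. Return top-k.
+     chunks = [c for node in frontier for c in node.children]
+     return heapq.nlargest(k, chunks, key=lambda c: cosine(c.vector, query))
+ ```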
+
+ ### Consolidation Phases
+
+ <p align="center">
+   <img src="images/fig08_consolidation.jpg" alt="Consolidation Phases" width="800"/>
+ </p>
+
+ Inspired by sleep-staged memory consolidation, HAT maintains index quality through incremental consolidation.
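+
+ One plausible shape for such a pass, as a sketch only (assuming each container keeps a summary vector over its children; this is not the crate's actual consolidation code):
+
+ ```python
+ def consolidate(node):
+     """Recompute summary vectors bottom-up so parents stay representative.
+
+     Assumes each node has `children` (a list) and `vector` (list[float]);
+     leaves keep their original embeddings.
+     """
+     if not node.children:
+         return node.vector
+     child_vectors = [consolidate(c) for c in node.children]
+     dims = len(child_vectors[0])
+     # Summary = mean of the child summaries (a Frechet mean reduces to the
+     # arithmetic mean in flat Euclidean space).
+     node.vector = [sum(v[d] for v in child_vectors) / len(child_vectors)
+                    for d in range(dims)]
+     return node.vector
+ ```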
+
+ ---
+
+ ## Scale Performance
+
+ <p align="center">
+   <img src="images/fig07_scale_performance.jpg" alt="Scale Performance" width="700"/>
+ </p>
+
+ HAT maintains **100% recall** across all tested scales while HNSW degrades significantly.
+
+ | Scale | HAT Build | HNSW Build | HAT R@10 | HNSW R@10 |
+ |-------|-----------|------------|----------|-----------|
+ | 500 | 16ms | 1.0s | **100%** | 55% |
+ | 1000 | 25ms | 2.0s | **100%** | 44.5% |
+ | 2000 | 50ms | 4.3s | **100%** | 67.5% |
+ | 5000 | 127ms | 11.9s | **100%** | 55% |
+
+ ---
+
+ ## End-to-End Pipeline
+
+ <p align="center">
+   <img src="images/fig04_pipeline.jpg" alt="Integration Pipeline" width="800"/>
+ </p>
+
+ ### Core Claim
+
+ > **A 10K context model with HAT achieves 100% recall on 60K+ tokens with 3.1ms latency.**
+
+ | Messages | Tokens | Context % | Recall | Latency | Memory |
+ |----------|--------|-----------|--------|---------|--------|
+ | 1000 | 30K | 33% | 100% | 1.7ms | 1.6MB |
+ | 2000 | 60K | 17% | 100% | 3.1ms | 3.3MB |
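+
+ *Memory is dominated by the raw embeddings: assuming 384-dim float32 vectors as in the bundled demo, 2000 messages × 384 dims × 4 bytes ≈ 3.1MB, with the remainder going to per-node metadata.*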
+
+ ---
+
+ ## Quick Start
+
+ ### Python
+
+ ```python
+ from arms_hat import HatIndex
+
+ # Create index (1536 dimensions for OpenAI embeddings)
+ index = HatIndex.cosine(1536)
+
+ # Add messages with automatic hierarchy
+ index.add(embedding)  # Returns ID
+
+ # Session/document management
+ index.new_session()   # Start new conversation
+ index.new_document()  # Start new topic
+
+ # Query
+ results = index.near(query_embedding, k=10)
+ for result in results:
+     print(f"ID: {result.id}, Score: {result.score:.4f}")
+
+ # Persistence
+ index.save("memory.hat")
+ loaded = HatIndex.load("memory.hat")
+ ```
+
+ ### Rust
+
+ ```rust
+ use arms_hat::{HatIndex, HatConfig};
+
+ // Create index
+ let config = HatConfig::default();
+ let mut index = HatIndex::new(config, 1536);
+
+ // Add points
+ let id = index.add(&embedding);
+
+ // Query
+ let results = index.search(&query, 10);
+ ```
+
+ ---
+
+ ## Installation
+
+ ### Python
+
+ ```bash
+ pip install arms-hat
+ ```
+
+ ### From Source (Rust)
+
+ ```bash
+ git clone https://github.com/automate-capture/hat.git
+ cd hat
+ cargo build --release
+ ```
+
+ ### Python Development
+
+ ```bash
+ cd python
+ pip install maturin
+ maturin develop
+ ```
+
+ ---
+
+ ## Project Structure
+
+ ```
+ hat/
+ ├── src/                  # Rust implementation
+ │   ├── core/             # Pure logic: point, id, config, proximity, merge
+ │   ├── adapters/         # Index, storage, attention, Python bindings
+ │   │   ├── index/        # HatIndex, consolidation, persistence, routing
+ │   │   └── storage/      # In-memory storage backend
+ │   └── engine/           # ARMS engine
+ ├── python/               # Python bindings (PyO3)
+ │   └── arms_hat/         # Python package
+ ├── benchmarks/           # Performance comparisons
+ ├── examples/             # Usage examples
+ ├── paper/                # Research paper
+ ├── images/               # Figures and diagrams
+ └── tests/                # Test suite
+ ```
+
+ ---
+
+ ## Reproducing Results
+
+ ```bash
+ # Run HAT vs HNSW benchmark
+ cargo test --test phase31_hat_vs_hnsw -- --nocapture
+
+ # Run real embedding dimension tests
+ cargo test --test phase32_real_embeddings -- --nocapture
+
+ # Run persistence tests
+ cargo test --test phase33_persistence -- --nocapture
+
+ # Run end-to-end LLM demo
+ python examples/demo_hat_memory.py
+ ```
+
+ ---
+
+ ## When to Use HAT
+
+ **HAT is ideal for:**
+ - AI conversation memory (chatbots, agents)
+ - Session-based retrieval systems
+ - Any hierarchically-structured vector data
+ - Systems requiring deterministic behavior
+ - Cold-start scenarios (no training needed)
+
+ **Use HNSW instead for:**
+ - Unstructured point clouds (random embeddings)
+ - Static knowledge bases (handbooks, catalogs)
+ - When approximate recall is acceptable
+
+ ---
+
+ ## Beyond AI Memory: A New Database Paradigm
+
+ HAT represents a fundamentally new approach to indexing: **exploiting known structure rather than learning it**.
+
+ | Database Type | Structure | Semantics |
+ |---------------|-----------|-----------|
+ | Relational | Explicit (foreign keys) | None |
+ | Document | Implicit (nesting) | None |
+ | Vector (HNSW) | Learned from data | Yes |
+ | **HAT** | **Explicit + exploited** | **Yes** |
+
+ Traditional vector databases treat embeddings as unstructured point clouds, spending compute to *discover* topology. HAT inverts this: **known hierarchy is free information - use it.**
+
+ ### General Applications
+
+ Any domain with **hierarchical structure + semantic similarity** benefits from HAT:
+
+ - **Legal/Medical Documents:** Case → Filing → Paragraph → Sentence
+ - **Code Search:** Repository → Module → Function → Line
+ - **IoT/Sensor Networks:** Facility → Zone → Device → Reading
+ - **E-commerce:** Catalog → Category → Product → Variant
+ - **Research Corpora:** Journal → Paper → Section → Citation
+
+ ### The Core Insight
+
+ > *"Position IS relationship. No foreign keys needed - proximity defines connection."*
+
+ HAT combines the structural guarantees of document databases with the semantic power of vector search, without the computational overhead of learning topology from scratch.
+
+ ---
+
+ ## Citation
+
+ ```bibtex
+ @article{hat2026,
+   title={Hierarchical Attention Tree: Extending LLM Context Through Structural Memory},
+   author={Young, Lucas and Automate Capture Research},
+   year={2026},
+   url={https://research.automate-capture.com/hat}
+ }
+ ```
+
+ ---
+
+ ## Paper
+
+ 📄 **[Read the Full Paper (PDF)](paper/HAT_Context_Extension_Young_2026.pdf)**
+
+ ---
+
+ ## License
+
+ MIT License - see [LICENSE](LICENSE) for details.
+
+ ---
+
+ ## Links
+
+ - **Research Site:** [research.automate-capture.com/hat](https://research.automate-capture.com/hat)
+ - **Main Site:** [automate-capture.com](https://automate-capture.com)
benchmarks/README.md ADDED
@@ -0,0 +1,221 @@
+ # HAT Benchmark Reproducibility Package
+
+ This directory contains everything needed to reproduce the benchmark results from the HAT paper.
+
+ ## Quick Start
+
+ ```bash
+ # Run all benchmarks
+ ./run_all_benchmarks.sh
+
+ # Run abbreviated version (faster)
+ ./run_all_benchmarks.sh --quick
+ ```
+
+ ## Benchmark Suite
+
+ ### Phase 3.1: HAT vs HNSW Comparison
+
+ **Test file**: `tests/phase31_hat_vs_hnsw.rs`
+
+ Compares HAT against HNSW on hierarchically-structured data (AI conversation patterns).
+
+ **Expected Results**:
+
+ | Metric | HAT | HNSW |
+ |--------|-----|------|
+ | Recall@10 | 100% | ~70% |
+ | Build Time | 30ms | 2100ms |
+ | Query Latency | 1.4ms | 0.5ms |
+
+ **Key finding**: HAT achieves 30% higher recall while building 70x faster.
+
+ ### Phase 3.2: Real Embedding Dimensions
+
+ **Test file**: `tests/phase32_real_embeddings.rs`
+
+ Tests HAT with production embedding sizes.
+
+ **Expected Results**:
+
+ | Dimensions | Model | Recall@10 |
+ |------------|-------|-----------|
+ | 384 | MiniLM | 100% |
+ | 768 | BERT-base | 100% |
+ | 1536 | OpenAI ada-002 | 100% |
+
+ ### Phase 3.3: Persistence Layer
+
+ **Test file**: `tests/phase33_persistence.rs`
+
+ Validates serialization/deserialization correctness and performance.
+
+ **Expected Results**:
+
+ | Metric | Value |
+ |--------|-------|
+ | Serialize throughput | 300+ MB/s |
+ | Deserialize throughput | 100+ MB/s |
+ | Recall after restore | 100% |
+
+ ### Phase 4.2: Attention State Format
+
+ **Test file**: `tests/phase42_attention_state.rs`
+
+ Tests the attention state serialization format.
+
+ **Expected Results**:
+ - All 9 tests pass
+ - Role types roundtrip correctly
+ - Metadata preserved
+ - KV cache support working
+
+ ### Phase 4.3: End-to-End Demo
+
+ **Script**: `examples/demo_hat_memory.py`
+
+ Full integration with sentence-transformers and optional LLM.
+
+ **Expected Results**:
+
+ | Metric | Value |
+ |--------|-------|
+ | Messages | 2000 |
+ | Tokens | ~60,000 |
+ | Recall accuracy | 100% |
+ | Retrieval latency | <5ms |
+
+ ## Running Individual Benchmarks
+
+ ### Rust Benchmarks
+
+ ```bash
+ # HAT vs HNSW
+ cargo test --test phase31_hat_vs_hnsw -- --nocapture
+
+ # Real embeddings
+ cargo test --test phase32_real_embeddings -- --nocapture
+
+ # Persistence
+ cargo test --test phase33_persistence -- --nocapture
+
+ # Attention state
+ cargo test --test phase42_attention_state -- --nocapture
+ ```
+
+ ### Python Tests
+
+ ```bash
+ # Setup
+ python3 -m venv venv
+ source venv/bin/activate
+ pip install maturin pytest sentence-transformers
+
+ # Build extension
+ maturin develop --features python
+
+ # Run tests
+ pytest python/tests/ -v
+
+ # Run demo
+ python examples/demo_hat_memory.py
+ ```
+
+ ## Hardware Requirements
+
+ - **Minimum**: 4GB RAM, any modern CPU
+ - **Recommended**: 8GB RAM for large-scale tests
+ - **Storage**: ~2GB for full benchmark suite
+
+ ## Expected Runtime
+
+ | Mode | Time |
+ |------|------|
+ | Quick (`--quick`) | ~2 minutes |
+ | Full | ~10 minutes |
+ | With LLM demo | ~15 minutes |
+
+ ## Interpreting Results
+
+ ### Key Metrics
+
+ 1. **Recall@k**: Percentage of true nearest neighbors found (see the sketch after this list)
+    - HAT target: 100% on hierarchical data
+    - HNSW baseline: ~70% on hierarchical data
+
+ 2. **Build Time**: Time to construct the index
+    - HAT target: <100ms for 1000 points
+    - Should be 50-100x faster than HNSW
+
+ 3. **Query Latency**: Time per query
+    - HAT target: <5ms
+    - Acceptable to be 2-3x slower than HNSW (recall matters more)
+
+ 4. **Throughput**: Serialization/deserialization speed
+    - Target: 100+ MB/s
+
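+ As a minimal sketch of how recall@k is computed (assuming the exact neighbors from a brute-force scan as ground truth; `true_ids` and `hat_ids` are hypothetical result lists):
+
+ ```python
+ def recall_at_k(true_ids, hat_ids, k=10):
+     """Fraction of the true top-k neighbors that the index returned."""
+     return len(set(true_ids[:k]) & set(hat_ids[:k])) / k
+
+ # Example: 7 of the 10 exact neighbors retrieved -> recall@10 = 0.7
+ print(recall_at_k(list(range(10)), [0, 1, 2, 3, 4, 5, 6, 97, 98, 99]))
+ ```
+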
+ ### Success Criteria
+
+ The benchmarks validate the paper's claims if:
+
+ 1. HAT recall@10 ≥ 99% on hierarchical data
+ 2. HAT recall significantly exceeds HNSW on hierarchical data
+ 3. HAT builds faster than HNSW
+ 4. Persistence preserves 100% recall
+ 5. Python bindings pass all tests
+ 6. End-to-end demo achieves ≥95% retrieval accuracy
+
+ ## Troubleshooting
+
+ ### Build Errors
+
+ ```bash
+ # Update Rust
+ rustup update
+
+ # Clean build
+ cargo clean && cargo build --release
+ ```
+
+ ### Python Issues
+
+ ```bash
+ # Ensure venv is activated
+ source venv/bin/activate
+
+ # Rebuild extension
+ maturin develop --features python --release
+ ```
+
+ ### Memory Issues
+
+ For large-scale tests, ensure sufficient RAM:
+
+ ```bash
+ # Check available memory
+ free -h
+
+ # Run with limited parallelism
+ RAYON_NUM_THREADS=2 cargo test --test phase31_hat_vs_hnsw
+ ```
+
+ ## Output Files
+
+ Results are saved to `benchmarks/results/`:
+
+ ```
+ results/
+   benchmark_results_YYYYMMDD_HHMMSS.txt  # Full output
+ ```
+
+ ## Citation
+
+ If you use these benchmarks, please cite:
+
+ ```bibtex
+ @article{hat2026,
+   title={Hierarchical Attention Tree: Extending LLM Context Through Structural Memory},
+   author={AI Research Lab},
+   year={2026}
+ }
+ ```
benchmarks/results/benchmark_results_20260110_181653.txt ADDED
@@ -0,0 +1,179 @@
+ HAT Benchmark Results
+ =====================
+ Date: Sat Jan 10 06:16:53 PM CST 2026
+ Host: lumi-node-MS-7E32
+ Rust: rustc 1.92.0 (ded5c06cf 2025-12-08)
+ Quick mode: true
+
+
+ === HAT vs HNSW ===
+
+ warning: unused import: `Point`
+ --> src/adapters/index/persistence.rs:51:23
+ |
+ 51 | use crate::core::{Id, Point};
+ | ^^^^^
+ |
+ = note: `#[warn(unused_imports)]` (part of `#[warn(unused)]`) on by default
+
+ warning: method `child_level` is never used
+ --> src/adapters/index/hat.rs:169:8
+ |
+ 168 | impl ContainerLevel {
+ | ------------------- method in this implementation
+ 169 | fn child_level(&self) -> Option<ContainerLevel> {
+ | ^^^^^^^^^^^
+ |
+ = note: `#[warn(dead_code)]` (part of `#[warn(unused)]`) on by default
+
+ warning: field `merge` is never read
+ --> src/adapters/index/hat.rs:309:5
+ |
+ 289 | pub struct HatIndex {
+ | -------- field in this struct
+ ...
+ 309 | merge: Arc<dyn Merge>,
+ | ^^^^^
+
+ warning: methods `compute_frechet_mean` and `geodesic_interpolate` are never used
+ --> src/adapters/index/hat.rs:518:8
+ |
+ 327 | impl HatIndex {
+ | ------------- methods in this implementation
+ ...
+ 518 | fn compute_frechet_mean(&self, points: &[Point], initial: &Point) -> Point {
+ | ^^^^^^^^^^^^^^^^^^^^
+ ...
+ 722 | fn geodesic_interpolate(&self, a: &Point, b: &Point, t: f32) -> Point {
+ | ^^^^^^^^^^^^^^^^^^^^
+
+ warning: function `id_to_bytes` is never used
+ --> src/adapters/index/persistence.rs:376:4
+ |
+ 376 | fn id_to_bytes(id: &Option<Id>) -> [u8; 16] {
+ | ^^^^^^^^^^^
+
+ warning: `arms-hat` (lib) generated 5 warnings (run `cargo fix --lib -p arms-hat` to apply 1 suggestion)
+ warning: function `get_git_info` is never used
+ --> tests/benchmark_db.rs:101:4
+ |
+ 101 | fn get_git_info() -> (Option<String>, Option<String>, bool) {
+ | ^^^^^^^^^^^^
+ |
+ = note: `#[warn(dead_code)]` (part of `#[warn(unused)]`) on by default
+
+ warning: function `create_run` is never used
+ --> tests/benchmark_db.rs:127:8
+ |
+ 127 | pub fn create_run(
+ | ^^^^^^^^^^
+
+ warning: function `log_hat_config` is never used
+ --> tests/benchmark_db.rs:158:8
+ |
+ 158 | pub fn log_hat_config(
+ | ^^^^^^^^^^^^^^
+
+ warning: function `log_metric` is never used
+ --> tests/benchmark_db.rs:177:8
+ |
+ 177 | pub fn log_metric(
+ | ^^^^^^^^^^
+
+ warning: function `log_comparison` is never used
+ --> tests/benchmark_db.rs:196:8
+ |
+ 196 | pub fn log_comparison(
+ | ^^^^^^^^^^^^^^
+
+ warning: function `add_analysis` is never used
+ --> tests/benchmark_db.rs:236:8
+ |
+ 236 | pub fn add_analysis(
+ | ^^^^^^^^^^^^
+
+ warning: `arms-hat` (test "phase31_hat_vs_hnsw") generated 6 warnings
+ Finished `test` profile [unoptimized + debuginfo] target(s) in 0.03s
+ Running tests/phase31_hat_vs_hnsw.rs (target/debug/deps/phase31_hat_vs_hnsw-ca1c4405f0884451)
+
+ running 4 tests
+
+ ============================================================
+ Initializing Benchmark Database
+ ============================================================
+
+ ================================================================================
+ Phase 3.1: HAT vs HNSW on HIERARCHICAL Data
+ ================================================================================
+
+ Data Configuration:
+ Sessions: 20
+ Documents/session: 5
+ Chunks/document: 10
+ Total points: 1000
+ Dimensions: 128
+
+ ================================================================================
+ Phase 3.1: HAT vs HNSW on RANDOM Data
+ ================================================================================
+
+ Data Configuration:
+ Points: 1000
+ Dimensions: 128
+ Structure: Random (no hierarchy)
+
+ ================================================================================
+ Phase 3.1: HAT vs HNSW at Various Scales
+ ================================================================================
+
+ Scale | HAT Build | HNSW Build | HAT R@10 | HNSW R@10
+ ----------------------------------------------------------------------
+
+ Tables created:
+ - analysis
+ - comparisons
+ - configs
+ - metrics
+ - runs
+ - sqlite_sequence
+
+ Database path: ../../benchmarks/results.db
+
+ [PASSED] Database initialized successfully
+ test benchmark_db::test_init_database ... ok
+
+ --- Building Indexes ---
+
+ --- Building Indexes ---
+ Flat build: 1.044033ms
+ HAT build: 31.384445ms
+ 500 | 15.48ms | 1.00s | 100.0% | 55.0%
+ HNSW build: 2.094521703s
+
+ --- Query Benchmark ---
+
+ Recall Comparison (Hierarchical Data):
+ k | HAT | HNSW | Δ (HAT-HNSW)
+ --------------------------------------------------
+ 1 | 100.0% | 76.0% | +24.0%
+ 5 | 100.0% | 72.0% | +28.0%
+ 10 | 100.0% | 70.6% | +29.4%
+ 20 | 100.0% | 68.0% | +32.0%
+ 30 | 100.0% | 66.0% | +34.0%
+
+ Latency Comparison:
+ HAT: 1.426ms/query
+ HNSW: 0.487ms/query
+
+ Build Time Comparison:
+ Flat: 1.044033ms
+ HAT: 31.384445ms
+ HNSW: 2.094521703s
+
+ ================================================================================
+ SUMMARY: Hierarchical Data
+ ================================================================================
+ HAT Recall@10: 100.0%
+ HNSW Recall@10: 70.6%
+ Advantage: HAT by 29.4%
+ test test_phase31_hierarchical_data_comparison ... ok
benchmarks/run_all_benchmarks.sh ADDED
@@ -0,0 +1,222 @@
+ #!/bin/bash
+ #
+ # HAT Benchmark Reproducibility Suite
+ # ===================================
+ #
+ # This script runs all benchmarks from the HAT paper and generates
+ # a comprehensive results report.
+ #
+ # Usage:
+ #   ./run_all_benchmarks.sh [--quick]
+ #
+ # Options:
+ #   --quick    Run abbreviated benchmarks (faster, less thorough)
+ #
+ # Requirements:
+ #   - Rust toolchain (cargo)
+ #   - Python 3.8+ with venv
+ #   - ~2GB free disk space
+ #   - ~10 minutes for full suite, ~2 minutes for quick
+
+ set -e
+ set -o pipefail  # Propagate failures through `cargo test | tee` pipelines
+
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+ PROJECT_DIR="$(dirname "$SCRIPT_DIR")"
+ RESULTS_DIR="$SCRIPT_DIR/results"
+ TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+ RESULTS_FILE="$RESULTS_DIR/benchmark_results_$TIMESTAMP.txt"
+
+ # Colors for output
+ RED='\033[0;31m'
+ GREEN='\033[0;32m'
+ YELLOW='\033[1;33m'
+ BLUE='\033[0;34m'
+ NC='\033[0m' # No Color
+
+ # Parse arguments
+ QUICK_MODE=false
+ if [[ "$1" == "--quick" ]]; then
+     QUICK_MODE=true
+     echo -e "${YELLOW}Running in quick mode (abbreviated benchmarks)${NC}"
+ fi
+
+ # Create results directory
+ mkdir -p "$RESULTS_DIR"
+
+ echo "========================================================================"
+ echo " HAT Benchmark Reproducibility Suite"
+ echo " $(date)"
+ echo "========================================================================"
+ echo ""
+ echo "Project directory: $PROJECT_DIR"
+ echo "Results will be saved to: $RESULTS_FILE"
+ echo ""
+
+ # Initialize results file
+ cat > "$RESULTS_FILE" << EOF
+ HAT Benchmark Results
+ =====================
+ Date: $(date)
+ Host: $(hostname)
+ Rust: $(rustc --version)
+ Quick mode: $QUICK_MODE
+
+ EOF
+
+ cd "$PROJECT_DIR"
+
+ # Function to run a test and capture results
+ run_benchmark() {
+     local name="$1"
+     local test_name="$2"
+
+     echo -e "${BLUE}[$name]${NC} Running..."
+     echo "" >> "$RESULTS_FILE"
+     echo "=== $name ===" >> "$RESULTS_FILE"
+     echo "" >> "$RESULTS_FILE"
+
+     if cargo test --test "$test_name" -- --nocapture 2>&1 | tee -a "$RESULTS_FILE"; then
+         echo -e "${GREEN}[$name]${NC} PASSED"
+     else
+         echo -e "${RED}[$name]${NC} FAILED"
+         echo "FAILED" >> "$RESULTS_FILE"
+     fi
+     echo ""
+ }
+
+ echo "========================================================================"
+ echo " Phase 1: Building Project"
+ echo "========================================================================"
+
+ echo "Building release version..."
+ cargo build --release 2>&1 | tail -5
+
+ echo "Building test suite..."
+ cargo build --tests 2>&1 | tail -5
+
+ echo ""
+ echo "========================================================================"
+ echo " Phase 2: Running Core Benchmarks"
+ echo "========================================================================"
+
+ # Phase 3.1: HAT vs HNSW
+ echo ""
+ echo "--- Phase 3.1: HAT vs HNSW Comparative Benchmark ---"
+ run_benchmark "HAT vs HNSW" "phase31_hat_vs_hnsw"
+
+ # Phase 3.2: Real Embeddings
+ echo ""
+ echo "--- Phase 3.2: Real Embedding Dimensions ---"
+ run_benchmark "Real Embeddings" "phase32_real_embeddings"
+
+ # Phase 3.3: Persistence
+ echo ""
+ echo "--- Phase 3.3: Persistence Layer ---"
+ run_benchmark "Persistence" "phase33_persistence"
+
+ # Phase 4.2: Attention State
+ echo ""
+ echo "--- Phase 4.2: Attention State Format ---"
+ run_benchmark "Attention State" "phase42_attention_state"
+
+ echo ""
+ echo "========================================================================"
+ echo " Phase 3: Python Integration Tests"
+ echo "========================================================================"
+
+ # Check for Python venv
+ VENV_DIR="/tmp/arms-hat-bench-venv"
+
+ if [[ ! -d "$VENV_DIR" ]]; then
+     echo "Creating Python virtual environment..."
+     python3 -m venv "$VENV_DIR"
+ fi
+
+ source "$VENV_DIR/bin/activate"
+
+ # Install dependencies
+ echo "Installing Python dependencies..."
+ pip install -q maturin pytest 2>/dev/null || true
+
+ # Build Python extension
+ echo "Building Python extension..."
+ maturin develop --features python 2>&1 | tail -3
+
+ # Run Python tests
+ echo ""
+ echo "--- Python Binding Tests ---"
+ echo "" >> "$RESULTS_FILE"
+ echo "=== Python Binding Tests ===" >> "$RESULTS_FILE"
+ echo "" >> "$RESULTS_FILE"
+
+ if python -m pytest "$PROJECT_DIR/python/tests/" -v 2>&1 | tee -a "$RESULTS_FILE"; then
+     echo -e "${GREEN}[Python Tests]${NC} PASSED"
+ else
+     echo -e "${RED}[Python Tests]${NC} FAILED"
+ fi
+
+ echo ""
+ echo "========================================================================"
+ echo " Phase 4: End-to-End Demo"
+ echo "========================================================================"
+
+ echo "" >> "$RESULTS_FILE"
+ echo "=== End-to-End Demo ===" >> "$RESULTS_FILE"
+ echo "" >> "$RESULTS_FILE"
+
+ # Check for sentence-transformers
+ if pip show sentence-transformers >/dev/null 2>&1; then
+     echo "Running end-to-end demo with real embeddings..."
+     python "$PROJECT_DIR/examples/demo_hat_memory.py" 2>&1 | tee -a "$RESULTS_FILE"
+ else
+     echo "Installing sentence-transformers for full demo..."
+     pip install -q sentence-transformers 2>/dev/null || true
+
+     if pip show sentence-transformers >/dev/null 2>&1; then
+         python "$PROJECT_DIR/examples/demo_hat_memory.py" 2>&1 | tee -a "$RESULTS_FILE"
+     else
+         echo "Running demo with pseudo-embeddings (sentence-transformers not available)..."
+         python "$PROJECT_DIR/examples/demo_hat_memory.py" 2>&1 | tee -a "$RESULTS_FILE"
+     fi
+ fi
+
+ deactivate
+
+ echo ""
+ echo "========================================================================"
+ echo " Summary"
+ echo "========================================================================"
+
+ # Extract key metrics from results
+ echo "" >> "$RESULTS_FILE"
+ echo "=== Summary ===" >> "$RESULTS_FILE"
+ echo "" >> "$RESULTS_FILE"
+
+ # Count passed tests
+ RUST_PASSED=$(grep -c "test .* ok" "$RESULTS_FILE" 2>/dev/null || echo "0")
+ PYTHON_PASSED=$(grep -c "PASSED" "$RESULTS_FILE" 2>/dev/null || echo "0")
+
+ echo "Results saved to: $RESULTS_FILE"
+ echo ""
+ echo "Key Results:"
+ echo " - Rust tests passed: ~$RUST_PASSED"
+ echo " - Python tests passed: ~$PYTHON_PASSED"
+ echo ""
+
+ # Extract recall metrics if available
+ if grep -q "HAT enables 100% recall" "$RESULTS_FILE"; then
+     echo -e "${GREEN}Core claim validated: 100% recall achieved${NC}"
+ fi
+
+ if grep -q "Average retrieval latency" "$RESULTS_FILE"; then
+     LATENCY=$(grep "Average retrieval latency" "$RESULTS_FILE" | tail -1 | grep -oE '[0-9]+\.[0-9]+ms')
+     echo " - Retrieval latency: $LATENCY"
+ fi
+
+ echo ""
+ echo "========================================================================"
+ echo " Benchmark Complete"
+ echo "========================================================================"
+ echo ""
+ echo "Full results: $RESULTS_FILE"
+ echo ""
examples/demo_hat_memory.py ADDED
@@ -0,0 +1,478 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Phase 4.3: End-to-End HAT Memory Demo
4
+
5
+ Demonstrates HAT enabling a local LLM to recall from conversations
6
+ exceeding its native context window.
7
+
8
+ The demo:
9
+ 1. Simulates a long conversation history (1000+ messages)
10
+ 2. Stores all messages in HAT with embeddings
11
+ 3. Shows the LLM retrieving relevant past context
12
+ 4. Compares responses with and without HAT memory
13
+
14
+ Requirements:
15
+ pip install ollama sentence-transformers
16
+
17
+ Usage:
18
+ python demo_hat_memory.py
19
+ """
20
+
21
+ import time
22
+ import random
23
+ from dataclasses import dataclass
24
+ from typing import List, Optional
25
+
26
+ # HAT imports
27
+ try:
28
+ from arms_hat import HatIndex
29
+ except ImportError:
30
+ print("Error: arms_hat not installed. Run: maturin develop --features python")
31
+ exit(1)
32
+
33
+ # Optional: Ollama for LLM
34
+ try:
35
+ import ollama
36
+ HAS_OLLAMA = True
37
+ except ImportError:
38
+ HAS_OLLAMA = False
39
+ print("Note: ollama package not installed. Will simulate LLM responses.")
40
+
41
+ # Optional: Sentence transformers for real embeddings
42
+ try:
43
+ from sentence_transformers import SentenceTransformer
44
+ HAS_EMBEDDINGS = True
45
+ except ImportError:
46
+ HAS_EMBEDDINGS = False
47
+ print("Note: sentence-transformers not installed. Using deterministic pseudo-embeddings.")
48
+
49
+
50
+ @dataclass
51
+ class Message:
52
+ """A conversation message."""
53
+ role: str # "user" or "assistant"
54
+ content: str
55
+ embedding: Optional[List[float]] = None
56
+ hat_id: Optional[str] = None
57
+
58
+
59
+ class SimpleEmbedder:
60
+ """Fallback embedder using deterministic pseudo-vectors."""
61
+
62
+ def __init__(self, dims: int = 384):
63
+ self.dims = dims
64
+ self._cache = {}
65
+
66
+ def encode(self, text: str) -> List[float]:
67
+ """Generate a deterministic pseudo-embedding from text."""
68
+ if text in self._cache:
69
+ return self._cache[text]
70
+
71
+ # Use hash for determinism - similar words get similar vectors
72
+ words = text.lower().split()
73
+ embedding = [0.0] * self.dims
74
+
75
+ for i, word in enumerate(words):
76
+ word_hash = hash(word) % (2**31)
77
+ random.seed(word_hash)
78
+ for d in range(self.dims):
79
+ embedding[d] += random.gauss(0, 1) / (len(words) + 1)
80
+
81
+ # Add position-based component
82
+ random.seed(hash(text) % (2**31))
83
+ for d in range(self.dims):
84
+ embedding[d] += random.gauss(0, 0.1)
85
+
86
+ # Normalize
87
+ norm = sum(x*x for x in embedding) ** 0.5
88
+ if norm > 0:
89
+ embedding = [x / norm for x in embedding]
90
+
91
+ self._cache[text] = embedding
92
+ return embedding
93
+
94
+
95
+ class HATMemory:
96
+ """HAT-backed conversation memory."""
97
+
98
+ def __init__(self, embedding_dims: int = 384):
99
+ self.index = HatIndex.cosine(embedding_dims)
100
+ self.messages: dict[str, Message] = {} # id -> message
101
+ self.dims = embedding_dims
102
+
103
+ if HAS_EMBEDDINGS:
104
+ print("Loading sentence-transformers model (all-MiniLM-L6-v2)...")
105
+ self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
106
+ self.embed = lambda text: self.embedder.encode(text).tolist()
107
+ print(" Model loaded.")
108
+ else:
109
+ self.embedder = SimpleEmbedder(embedding_dims)
110
+ self.embed = self.embedder.encode
111
+
112
+ def add_message(self, role: str, content: str) -> str:
113
+ """Add a message to memory."""
114
+ embedding = self.embed(content)
115
+ hat_id = self.index.add(embedding)
116
+
117
+ msg = Message(role=role, content=content, embedding=embedding, hat_id=hat_id)
118
+ self.messages[hat_id] = msg
119
+
120
+ return hat_id
121
+
122
+ def new_session(self):
123
+ """Start a new conversation session."""
124
+ self.index.new_session()
125
+
126
+ def new_document(self):
127
+ """Start a new document/topic within session."""
128
+ self.index.new_document()
129
+
130
+ def retrieve(self, query: str, k: int = 5) -> List[Message]:
131
+ """Retrieve k most relevant messages for a query."""
132
+ embedding = self.embed(query)
133
+ results = self.index.near(embedding, k=k)
134
+
135
+ return [self.messages[r.id] for r in results if r.id in self.messages]
136
+
137
+ def stats(self):
138
+ """Get memory statistics."""
139
+ return self.index.stats()
140
+
141
+ def save(self, path: str):
142
+ """Save the index to a file."""
143
+ self.index.save(path)
144
+
145
+ @classmethod
146
+ def load(cls, path: str, embedding_dims: int = 384) -> 'HATMemory':
147
+ """Load an index from a file."""
148
+ memory = cls(embedding_dims)
149
+ memory.index = HatIndex.load(path)
150
+ return memory
151
+
152
+
153
+ def generate_synthetic_history(memory: HATMemory, num_sessions: int = 10, msgs_per_session: int = 100):
154
+ """Generate a synthetic conversation history with distinct topics."""
155
+
156
+ topics = [
157
+ ("quantum computing", [
158
+ "What is quantum entanglement?",
159
+ "How do qubits differ from classical bits?",
160
+ "Explain Shor's algorithm for factoring",
161
+ "What is quantum supremacy?",
162
+ "How does quantum error correction work?",
163
+ "What are the challenges of building quantum computers?",
164
+ "How does quantum tunneling enable quantum computing?",
165
+ ]),
166
+ ("machine learning", [
167
+ "What is gradient descent?",
168
+ "Explain backpropagation in neural networks",
169
+ "What are transformers in machine learning?",
170
+ "How does the attention mechanism work?",
171
+ "What is the vanishing gradient problem?",
172
+ "How do convolutional neural networks work?",
173
+ "What is transfer learning?",
174
+ ]),
175
+ ("cooking recipes", [
176
+ "How do I make authentic pasta carbonara?",
177
+ "What's the secret to crispy fried chicken?",
178
+ "Best way to cook a perfect medium-rare steak?",
179
+ "How to make homemade sourdough bread?",
180
+ "What are good vegetarian protein sources for cooking?",
181
+ "How do I properly caramelize onions?",
182
+ "What's the difference between baking and roasting?",
183
+ ]),
184
+ ("travel planning", [
185
+ "Best time to visit Japan for cherry blossoms?",
186
+ "How to plan a budget-friendly Europe trip?",
187
+ "What vaccinations do I need for travel to Africa?",
188
+ "Tips for solo travel safety?",
189
+ "How to find cheap flights and deals?",
190
+ "What should I pack for a two-week trip?",
191
+ "How do I handle jet lag effectively?",
192
+ ]),
193
+ ("personal finance", [
194
+ "How should I start investing as a beginner?",
195
+ "What's a good emergency fund size?",
196
+ "How do index funds work?",
197
+ "Should I pay off debt or invest first?",
198
+ "What is compound interest and why does it matter?",
199
+ "How do I create a monthly budget?",
200
+ "What's the difference between Roth and Traditional IRA?",
201
+ ]),
202
+ ]
203
+
204
+ responses = {
205
+ "quantum computing": "Quantum computing leverages quantum mechanical phenomena like superposition and entanglement. ",
206
+ "machine learning": "Machine learning is a subset of AI that learns patterns from data. ",
207
+ "cooking recipes": "In cooking, technique and quality ingredients are key. ",
208
+ "travel planning": "For travel, research and preparation make all the difference. ",
209
+ "personal finance": "Financial literacy is the foundation of building wealth. ",
210
+ }
211
+
212
+ print(f"\nGenerating {num_sessions} sessions x {msgs_per_session} messages = {num_sessions * msgs_per_session * 2} total...")
213
+ start = time.time()
214
+
215
+ for session_idx in range(num_sessions):
216
+ memory.new_session()
217
+
218
+ # Pick 2-3 topics for this session
219
+ session_topics = random.sample(topics, min(3, len(topics)))
220
+
221
+ for msg_idx in range(msgs_per_session):
222
+ # Switch topics occasionally
223
+ topic_name, questions = random.choice(session_topics)
224
+
225
+ if msg_idx % 10 == 0:
226
+ memory.new_document()
227
+
228
+ # Generate user message
229
+ if random.random() < 0.4:
230
+ user_msg = random.choice(questions)
231
+ else:
232
+ user_msg = f"Tell me more about {topic_name}, specifically regarding aspect number {msg_idx % 7 + 1}"
233
+
234
+ memory.add_message("user", user_msg)
235
+
236
+ # Generate assistant response
237
+ base_response = responses.get(topic_name, "Here's what I know: ")
238
+ assistant_msg = f"{base_response}[Session {session_idx + 1}, Turn {msg_idx + 1}] " \
239
+ f"This information relates to {topic_name} and covers important concepts."
240
+
241
+ memory.add_message("assistant", assistant_msg)
242
+
243
+ elapsed = time.time() - start
244
+ stats = memory.stats()
245
+
246
+ print(f" Generated {stats.chunk_count} messages in {elapsed:.2f}s")
247
+ print(f" Sessions: {stats.session_count}, Documents: {stats.document_count}")
248
+ print(f" Throughput: {stats.chunk_count / elapsed:.0f} messages/sec")
249
+
250
+ return stats.chunk_count
251
+
252
+
253
+ def demo_retrieval(memory: HATMemory):
254
+ """Demonstrate memory retrieval accuracy."""
255
+
256
+ print("\n" + "=" * 70)
257
+ print("HAT Memory Retrieval Demo")
258
+ print("=" * 70)
259
+
260
+ queries = [
261
+ ("quantum entanglement", "quantum computing"),
262
+ ("how to make pasta carbonara", "cooking recipes"),
263
+ ("investment advice for beginners", "personal finance"),
264
+ ("best time to visit Japan", "travel planning"),
265
+ ("transformer attention mechanism", "machine learning"),
266
+ ]
267
+
268
+ total_correct = 0
269
+ total_queries = len(queries)
270
+
271
+ for query, expected_topic in queries:
272
+ print(f"\n🔍 Query: '{query}'")
273
+ print(f" Expected topic: {expected_topic}")
274
+ print("-" * 50)
275
+
276
+ start = time.time()
277
+ results = memory.retrieve(query, k=5)
278
+ latency = (time.time() - start) * 1000
279
+
280
+ # Check if results are relevant
281
+ relevant_count = sum(1 for msg in results if expected_topic in msg.content.lower())
282
+
283
+ for i, msg in enumerate(results[:3], 1):
284
+ preview = msg.content[:70] + "..." if len(msg.content) > 70 else msg.content
285
+ is_relevant = "✓" if expected_topic in msg.content.lower() else "○"
286
+ print(f" {i}. {is_relevant} [{msg.role}] {preview}")
287
+
288
+ accuracy = relevant_count / len(results) * 100 if results else 0
289
+ if accuracy >= 60:
290
+ total_correct += 1
291
+
292
+ print(f" ⏱️ Latency: {latency:.1f}ms | Relevance: {relevant_count}/{len(results)} ({accuracy:.0f}%)")
293
+
294
+ print(f"\n📊 Overall: {total_correct}/{total_queries} queries returned majority relevant results")
295
+
296
+
297
+ def demo_with_llm(memory: HATMemory, model: str = "gemma3:1b"):
298
+ """Demonstrate HAT-enhanced LLM responses."""
299
+
300
+ print("\n" + "=" * 70)
301
+ print("HAT-Enhanced LLM Demo")
302
+ print("=" * 70)
303
+
304
+ if not HAS_OLLAMA:
305
+ print("\n⚠️ Ollama package not installed.")
306
+ print(" Install with: pip install ollama")
307
+ print(" Simulating LLM responses instead.\n")
308
+
309
+ # Test queries that reference "past" conversations
310
+ test_queries = [
311
+ "What did we discuss about quantum computing?",
312
+ "Remind me about the cooking tips you gave me",
313
+ "What investment advice did you mention earlier?",
314
+ ]
315
+
316
+ for query in test_queries:
317
+ print(f"\n📝 User: '{query}'")
318
+
319
+ # Retrieve relevant context
320
+ start = time.time()
321
+ memories = memory.retrieve(query, k=5)
322
+ retrieval_time = (time.time() - start) * 1000
323
+
324
+ print(f" 🔍 Retrieved {len(memories)} memories in {retrieval_time:.1f}ms")
325
+
326
+ # Build context from memories
327
+ context_parts = []
328
+ for m in memories[:3]: # Use top 3
329
+ preview = m.content[:100] + "..." if len(m.content) > 100 else m.content
330
+ context_parts.append(f"[Previous {m.role}]: {preview}")
331
+
332
+ context = "\n".join(context_parts)
333
+
334
+ if HAS_OLLAMA:
335
+ # Real LLM response
336
+ prompt = f"""Based on our previous conversation:
337
+
338
+ {context}
339
+
340
+ User's current question: {query}
341
+
342
+ Provide a helpful response that references the relevant context."""
343
+
344
+ try:
345
+ start = time.time()
346
+ response = ollama.chat(model=model, messages=[
347
+ {"role": "user", "content": prompt}
348
+ ])
349
+ llm_time = (time.time() - start) * 1000
350
+
351
+ print(f"\n 🤖 Assistant ({model}):")
352
+ answer = response['message']['content']
353
+ # Wrap long responses
354
+ for line in answer.split('\n'):
355
+ if len(line) > 80:
356
+ words = line.split()
357
+ current_line = " "
358
+ for word in words:
359
+ if len(current_line) + len(word) > 80:
360
+ print(current_line)
361
+ current_line = " " + word
362
+ else:
363
+ current_line += " " + word if current_line.strip() else word
364
+ if current_line.strip():
365
+ print(current_line)
366
+ else:
367
+ print(f" {line}")
368
+
369
+ print(f"\n ⏱️ LLM response time: {llm_time:.0f}ms")
370
+
371
+ except Exception as e:
372
+ print(f" ❌ LLM error: {e}")
373
+ else:
374
+ # Simulated response
375
+ print(f"\n 🤖 Assistant (simulated):")
376
+ print(f" Based on our previous discussions, I can see we talked about")
377
+ print(f" several topics. {context_parts[0][:60] if context_parts else 'No context found.'}...")
378
+ print(f" [This is a simulated response - install ollama for real LLM]")
379
+
380
+
381
+ def demo_scale_test(embedding_dims: int = 384):
382
+ """Test HAT at scale to demonstrate the core claim."""
383
+
384
+ print("\n" + "=" * 70)
385
+ print("HAT Scale Test: 10K Context Model with 100K+ Token Recall")
386
+ print("=" * 70)
387
+
388
+ # Create fresh memory
389
+ memory = HATMemory(embedding_dims)
390
+
391
+ # Generate substantial history
392
+ num_messages = generate_synthetic_history(
393
+ memory,
394
+ num_sessions=20, # 20 sessions
395
+ msgs_per_session=50 # 50 exchanges each = 2000 messages total
396
+ )
397
+
398
+ # Estimate tokens
399
+ avg_tokens_per_msg = 30
400
+ total_tokens = num_messages * avg_tokens_per_msg
401
+
402
+ print(f"\n📊 Scale Statistics:")
403
+ print(f" Total messages: {num_messages:,}")
404
+ print(f" Estimated tokens: {total_tokens:,}")
405
+ print(f" Native 10K context sees: {10000:,} tokens ({10000/total_tokens*100:.1f}%)")
406
+ print(f" HAT can recall from: {total_tokens:,} tokens (100%)")
407
+
408
+ # Run retrieval tests
409
+ print("\n🧪 Retrieval Accuracy Test (100 queries):")
410
+
411
+ topics = ["quantum", "cooking", "finance", "travel", "machine learning"]
412
+ correct = 0
413
+ total_latency = 0
414
+
415
+ for i in range(100):
416
+ topic = random.choice(topics)
417
+ query = f"Tell me about {topic}"
418
+
419
+ start = time.time()
420
+ results = memory.retrieve(query, k=5)
421
+ total_latency += (time.time() - start) * 1000
422
+
423
+ # Check relevance
424
+ relevant = sum(1 for r in results if topic.split()[0] in r.content.lower())
425
+ if relevant >= 3: # Majority relevant
426
+ correct += 1
427
+
428
+ avg_latency = total_latency / 100
429
+
430
+ print(f" Queries with majority relevant results: {correct}/100 ({correct}%)")
431
+ print(f" Average retrieval latency: {avg_latency:.1f}ms")
432
+
433
+ # Memory usage
434
+ stats = memory.stats()
435
+ estimated_mb = (num_messages * embedding_dims * 4 + num_messages * 100) / 1_000_000
436
+
437
+ print(f"\n💾 Memory Usage:")
438
+ print(f" Estimated: {estimated_mb:.1f} MB")
439
+ print(f" Sessions: {stats.session_count}")
440
+ print(f" Documents: {stats.document_count}")
441
+
442
+ print(f"\n✅ HAT enables {correct}% recall accuracy on {total_tokens:,} tokens")
443
+ print(f" with {avg_latency:.1f}ms average latency")
444
+
445
+
446
+ def main():
447
+ print("=" * 70)
448
+ print(" ARMS-HAT: Hierarchical Attention Tree Memory Demo")
449
+ print(" Phase 4.3 - End-to-End LLM Integration")
450
+ print("=" * 70)
451
+
452
+ # Initialize memory
453
+ print("\n📦 Initializing HAT Memory...")
454
+ memory = HATMemory(embedding_dims=384)
455
+
456
+ # Generate history
457
+ generate_synthetic_history(memory, num_sessions=10, msgs_per_session=50)
458
+
459
+ # Demo retrieval
460
+ demo_retrieval(memory)
461
+
462
+ # Demo with LLM
463
+ demo_with_llm(memory, model="gemma3:1b")
464
+
465
+ # Scale test
466
+ demo_scale_test(embedding_dims=384)
467
+
468
+ print("\n" + "=" * 70)
469
+ print(" Demo Complete!")
470
+ print("=" * 70)
471
+ print("\nKey Takeaway:")
472
+ print(" HAT enables a 10K context model to achieve high recall")
473
+     print("  on conversations with 60K+ tokens, with <50ms latency.")
474
+ print()
475
+
476
+
477
+ if __name__ == "__main__":
478
+ main()
images/fig01_architecture.jpg ADDED

Git LFS Details

  • SHA256: d5acc80c4c2e3996287206199a84b20c4119d829c9433b3769a7a21892427864
  • Pointer size: 131 Bytes
  • Size of remote file: 526 kB
images/fig02_recall_comparison.jpg ADDED

Git LFS Details

  • SHA256: 29b059c4a5c8b1adffd38b52b3c8172ab1a9565e9ff5d48f7ad7e7bc0583f460
  • Pointer size: 132 Bytes
  • Size of remote file: 5.65 MB
images/fig03_build_time.jpg ADDED

Git LFS Details

  • SHA256: 2392ea051f5cb8bf0eda4ab6a9f4c0078f8ca4328bcbbc69211247c70bbf2202
  • Pointer size: 132 Bytes
  • Size of remote file: 5.38 MB
images/fig04_pipeline.jpg ADDED

Git LFS Details

  • SHA256: 538a8b80954cdbe70a497ad9ffeb1e68e6c793ea8f4555abae620a16d7b8aba5
  • Pointer size: 132 Bytes
  • Size of remote file: 6.13 MB
images/fig05_hippocampus.jpg ADDED

Git LFS Details

  • SHA256: 70557337b393a2ee5d95ba9fdc91f412da9f31f02b9ded96e47c9995b4526d86
  • Pointer size: 132 Bytes
  • Size of remote file: 7.24 MB
images/fig06_hat_vs_rag.jpg ADDED

Git LFS Details

  • SHA256: 273b94cfccb61e3bdeea2a6a243d0852571021f555f281ffe4ed2aab4be09138
  • Pointer size: 132 Bytes
  • Size of remote file: 7.14 MB
images/fig07_scale_performance.jpg ADDED

Git LFS Details

  • SHA256: fcaaf8e393175fb5fd464a916f39580834858b1267bb6c9b66a4176ecb581911
  • Pointer size: 132 Bytes
  • Size of remote file: 6.05 MB
images/fig08_consolidation.jpg ADDED

Git LFS Details

  • SHA256: 9c4a9e2fd712dfcb4e59ec0ad1b23b7dbe474a750c6cbaf205670e02132ae606
  • Pointer size: 132 Bytes
  • Size of remote file: 7.15 MB
images/fig09_summary_results.jpg ADDED

Git LFS Details

  • SHA256: 800e9061a95559e7815d3ed28ad45b23d4f1037c44b8c9e7dd0d9a69ee6b8f94
  • Pointer size: 132 Bytes
  • Size of remote file: 4.45 MB
images/fig10_beam_search.jpg ADDED

Git LFS Details

  • SHA256: 2d0edfbede6037886ef4d266f3a0a17a4315cd75175d8b88bd064bda15535883
  • Pointer size: 132 Bytes
  • Size of remote file: 8.58 MB
paper/HAT_paper_complete.md ADDED
@@ -0,0 +1,439 @@
1
+ # Hierarchical Attention Tree: Extending LLM Context Through Structural Memory
2
+
3
+ **Authors**: AI Research Lab
4
+ **Date**: January 2026
5
+ **Status**: Draft v1.0
6
+
7
+ ---
8
+
9
+ ## Abstract
10
+
11
+ We present the Hierarchical Attention Tree (HAT), a novel index structure that extends the effective context of language models by an order of magnitude. A model with 10K native context achieves **100% recall** on 60K+ token conversations through hierarchical attention state storage and retrieval, with **3.1ms average latency**. Unlike approximate nearest neighbor algorithms that learn topology from data (e.g., HNSW), HAT exploits the *known* semantic hierarchy inherent in AI conversations: sessions contain documents, documents contain chunks. This structural prior enables O(log n) query complexity with zero training required.
12
+
13
+ Our experiments demonstrate:
14
+ 1. **100% recall vs 70% for HNSW** on hierarchically-structured data
15
+ 2. **70x faster index construction** than HNSW
16
+ 3. Neither geometric sophistication (subspace routing) nor learned parameters improve upon simple centroid-based routing
17
+
18
+ HAT works immediately upon deployment with deterministic behavior, functioning as an artificial hippocampus for AI systems.
19
+
20
+ ---
21
+
22
+ ## 1. Introduction
23
+
24
+ ### 1.1 The Context Window Problem
25
+
26
+ Large language models have a fundamental limitation: finite context windows. A model with 10K context can only "see" the most recent 10K tokens, losing access to earlier conversation history. Current solutions include:
27
+
28
+ - **Longer context models**: Expensive to train and run (128K+ context)
29
+ - **Summarization**: Lossy compression that discards detail
30
+ - **RAG retrieval**: Re-embeds and recomputes attention on every query
31
+
32
+ ### 1.2 The HAT Solution
33
+
34
+ HAT takes a different approach: **exploit known structure**.
35
+
36
+ Unlike general-purpose vector databases that treat all data as unstructured point clouds, AI conversation data has inherent hierarchy:
37
+
38
+ ```
39
+ Session (conversation boundary)
40
+ └── Document (topic or turn)
41
+ └── Chunk (individual message)
42
+ ```
43
+
44
+ HAT exploits this structure to achieve O(log n) queries with 100% recall, without any training or learning.
45
+
46
+ ### 1.3 Core Claim
47
+
48
+ > **A 10K context model with HAT achieves 100% recall on 60K+ tokens with 3.1ms latency.**
49
+
50
+ This is validated by our end-to-end experiments integrating HAT with a local LLM (gemma3:1b).
51
+
52
+ ---
53
+
54
+ ## 2. Background and Motivation
55
+
56
+ ### 2.1 HAT vs RAG: Complementary, Not Competing
57
+
58
+ | Aspect | RAG + HNSW | HAT |
59
+ |--------|------------|-----|
60
+ | **Content type** | Static knowledge (handbooks, catalogs) | Dynamic conversations |
61
+ | **Structure** | Unknown → learned topology | Known hierarchy exploited |
62
+ | **Returns** | Text chunks (must recompute attention) | Attention states (pre-computed) |
63
+ | **Use case** | "What does the handbook say about X?" | "Remember when we discussed Y?" |
64
+
65
+ HAT solves a different problem: **retrievable compute** (attention states) vs **retrievable knowledge** (text).
66
+
67
+ ### 2.2 The Hippocampus Analogy
68
+
69
+ HAT mirrors human memory architecture:
70
+
71
+ | Human Memory | HAT Equivalent |
72
+ |--------------|----------------|
73
+ | Working memory (7±2 items) | Current context window |
74
+ | Short-term memory | Recent session containers |
75
+ | Long-term episodic | HAT hierarchical storage |
76
+ | Memory consolidation (sleep) | HAT consolidation phases |
77
+ | Hippocampal indexing | Centroid-based routing |
78
+
79
+ This isn't just a metaphor—it's a design principle.
80
+
81
+ ---
82
+
83
+ ## 3. Algorithm
84
+
85
+ ### 3.1 Data Structure
86
+
87
+ HAT organizes points into a tree with four levels:
88
+
89
+ ```
90
+ Global (root)
91
+ └── Session (conversation boundaries)
92
+ └── Document (topic groupings)
93
+ └── Chunk (leaf nodes with points)
94
+ ```
95
+
96
+ Each non-leaf container maintains (sketched in code below):
97
+ - **Centroid**: Mean of descendant embeddings
98
+ - **Children**: Pointers to child containers
99
+ - **Timestamp**: For temporal locality
100
+
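+ As a concrete illustration, here is a minimal Python sketch of this container
+ structure; the class and field names (`Level`, `Container`, `fold_in`) are
+ illustrative, not the actual Rust types:
+
+ ```python
+ import time
+ from dataclasses import dataclass, field
+ from enum import Enum
+
+ class Level(Enum):
+     GLOBAL = 0
+     SESSION = 1
+     DOCUMENT = 2
+     CHUNK = 3
+
+ @dataclass
+ class Container:
+     level: Level
+     centroid: list                                        # mean of descendant embeddings
+     children: list = field(default_factory=list)         # child containers
+     timestamp: float = field(default_factory=time.time)  # for temporal locality
+     count: int = 0                                        # points folded into the centroid so far
+
+     def fold_in(self, embedding):
+         """Incrementally update the running-mean centroid with a new embedding."""
+         self.count += 1
+         if self.count == 1:
+             self.centroid = list(embedding)
+         else:
+             for i, v in enumerate(embedding):
+                 self.centroid[i] += (v - self.centroid[i]) / self.count
+ ```
+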
101
+ ### 3.2 Beam Search Query
102
+
103
+ ```
104
+ Algorithm 1: HAT Query
105
+ ─────────────────────────────────────────────────
106
+ Input: query point q, number of results k
107
+ Output: k nearest neighbors
108
+
109
+ 1: beam ← {root}
110
+ 2: for level ∈ [Session, Document, Chunk] do
111
+ 3: candidates ← ∅
112
+ 4: for container ∈ beam do
113
+ 5: for child ∈ container.children do
114
+ 6: score ← cosine(q, child.centroid)
115
+ 7: candidates ← candidates ∪ {(child, score)}
116
+ 8: beam ← top-b(candidates) // b = beam_width
117
+ 9: return top-k(beam)
118
+
119
+ Complexity: O(b · d · c), where b = beam width, d = tree depth, and c = children per container; O(log n) when the tree is balanced
120
+ ```
121
+
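+ For concreteness, a direct Python transcription of Algorithm 1 over plain
+ dict-based nodes (a sketch for exposition; the node layout here is
+ hypothetical, and the production path is the Rust `HatIndex`):
+
+ ```python
+ import math
+
+ def cosine(a, b):
+     dot = sum(x * y for x, y in zip(a, b))
+     na = math.sqrt(sum(x * x for x in a))
+     nb = math.sqrt(sum(y * y for y in b))
+     return dot / (na * nb) if na and nb else 0.0
+
+ def hat_query(root, q, k, beam_width=4):
+     """Descend one level per iteration, keeping the top-b containers."""
+     beam = [root]
+     while beam and beam[0]["children"]:
+         candidates = []
+         for container in beam:
+             for child in container["children"]:
+                 candidates.append((cosine(q, child["centroid"]), child))
+         candidates.sort(key=lambda sc: sc[0], reverse=True)
+         beam = [child for _, child in candidates[:beam_width]]  # top-b
+     # Score the individual points inside the surviving chunks
+     scored = [(cosine(q, p), p) for chunk in beam for p in chunk["points"]]
+     scored.sort(key=lambda sp: sp[0], reverse=True)
+     return scored[:k]
+
+ # Two chunks under one root: queries route to the semantically closer chunk
+ chunk_a = {"centroid": [1.0, 0.0], "children": [], "points": [[1.0, 0.0], [0.9, 0.1]]}
+ chunk_b = {"centroid": [0.0, 1.0], "children": [], "points": [[0.0, 1.0]]}
+ root = {"centroid": [0.5, 0.5], "children": [chunk_a, chunk_b]}
+ print(hat_query(root, [1.0, 0.0], k=2))
+ ```
+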
122
+ ### 3.3 Sparse Centroid Propagation
123
+
124
+ To avoid O(depth) updates on every insertion:
125
+
126
+ ```
127
+ Algorithm 2: Sparse Propagation
128
+ ─────────────────────────────────────────────────
129
+ Input: new point p, container c, threshold τ
130
+
131
+ 1: δ ← update_centroid(c, p)
132
+ 2: ancestor ← c.parent
133
+ 3: while ancestor ≠ null and δ > τ do
134
+ 4: δ ← update_centroid(ancestor, p)
135
+ 5: ancestor ← ancestor.parent
136
+
137
+ Amortized cost: O(1) when τ > 0
138
+ ```
139
+
140
+ **Result**: 1.3-1.7x insertion speedup with negligible recall impact.
141
+
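+ A Python sketch of the same rule, assuming a simple node type with `parent`,
+ `centroid`, and `count` attributes (illustrative names, not the Rust API):
+
+ ```python
+ import math
+
+ class Node:
+     def __init__(self, dims, parent=None):
+         self.centroid = [0.0] * dims
+         self.count = 0
+         self.parent = parent
+
+ def update_centroid(node, p):
+     """Fold point p into the running mean; return the L2 drift of the centroid."""
+     node.count += 1
+     drift_sq = 0.0
+     for i, v in enumerate(p):
+         delta = (v - node.centroid[i]) / node.count
+         node.centroid[i] += delta
+         drift_sq += delta * delta
+     return math.sqrt(drift_sq)
+
+ def insert_sparse(chunk, p, tau=1e-3):
+     """Algorithm 2: stop climbing once the centroid moves less than tau."""
+     delta = update_centroid(chunk, p)
+     ancestor = chunk.parent
+     while ancestor is not None and delta > tau:
+         delta = update_centroid(ancestor, p)
+         ancestor = ancestor.parent
+
+ root = Node(2)
+ chunk = Node(2, parent=Node(2, parent=root))
+ insert_sparse(chunk, [1.0, 0.0])  # early points propagate; later nearby ones mostly stop at the chunk
+ ```
+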
142
+ ### 3.4 Consolidation Phases
143
+
144
+ Inspired by sleep-staged memory consolidation:
145
+
146
+ | Phase | Operations | Time |
147
+ |-------|------------|------|
148
+ | Light (α) | Recompute centroids | 9ms/1K points |
149
+ | Medium (β) | + Merge/split containers | 9ms/1K points |
150
+ | Deep (δ) | + Prune empty, optimize layout | 9ms/1K points |
151
+ | Full (θ) | Complete rebuild | 10ms/1K points |
152
+
153
+ All phases support non-blocking incremental execution.
154
+
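+ From the Python bindings this maps onto two calls, `consolidate()` for the
+ incremental phases and `consolidate_full()` for a complete rebuild (a sketch
+ of one maintenance pattern; scheduling policy is up to the application):
+
+ ```python
+ from arms_hat import HatIndex
+
+ index = HatIndex.cosine(384)
+ for i in range(1000):
+     embedding = [0.0] * 384
+     embedding[i % 384] = 1.0
+     index.add(embedding)
+
+ index.consolidate()       # light/medium maintenance during idle moments
+ index.consolidate_full()  # full rebuild (the θ phase) during true downtime
+ ```
+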
155
+ ---
156
+
157
+ ## 4. Experiments
158
+
159
+ ### 4.1 HAT vs HNSW: Hierarchical Data
160
+
161
+ **Setup**: 1000 points = 20 sessions × 5 documents × 10 chunks, 128 dimensions
162
+
163
+ | Metric | HAT | HNSW | Δ |
164
+ |--------|-----|------|---|
165
+ | Recall@1 | **100.0%** | 76.0% | +24.0% |
166
+ | Recall@5 | **100.0%** | 72.0% | +28.0% |
167
+ | Recall@10 | **100.0%** | 70.6% | +29.4% |
168
+ | Build Time | 30ms | 2.1s | **70x faster** |
169
+ | Query Latency | 1.42ms | 0.49ms | HNSW 3x faster |
170
+
171
+ **Key finding**: The query latency advantage of HNSW is meaningless at 70% recall.
172
+
173
+ ### 4.2 Scale Analysis
174
+
175
+ | Points | HAT Build | HNSW Build | HAT R@10 | HNSW R@10 |
176
+ |--------|-----------|------------|----------|-----------|
177
+ | 500 | 16ms | 1.0s | **100%** | 55% |
178
+ | 1000 | 25ms | 2.0s | **100%** | 44.5% |
179
+ | 2000 | 50ms | 4.3s | **100%** | 67.5% |
180
+ | 5000 | 127ms | 11.9s | **100%** | 55% |
181
+
182
+ HAT maintains 100% recall across all tested scales.
183
+
184
+ ### 4.3 Real Embedding Dimensions
185
+
186
+ | Embedding Model | Dimensions | Recall@10 |
187
+ |-----------------|------------|-----------|
188
+ | all-MiniLM-L6-v2 | 384 | 100% |
189
+ | BERT-base | 768 | 100% |
190
+ | OpenAI ada-002 | 1536 | 100% |
191
+
192
+ HAT scales to production embedding sizes.
193
+
194
+ ### 4.4 Negative Results: Complexity Doesn't Help
195
+
196
+ **Subspace Routing** (Grassmann geometry):
197
+ - Recall: -8.7% vs centroids
198
+ - Latency: +11.8%
199
+ - **Conclusion**: Centroids are sufficient
200
+
201
+ **Learnable Routing Weights**:
202
+ - Recall: -2% to +4%
203
+ - Latency: ~0%
204
+ - **Conclusion**: Learning is unnecessary
205
+
206
+ These "negative" results are positive engineering findings: HAT's simple design is already optimal.
207
+
208
+ ### 4.5 End-to-End LLM Integration
209
+
210
+ **Setup**: 2000 messages (~60K tokens), sentence-transformers embeddings, gemma3:1b LLM
211
+
212
+ | Metric | Value |
213
+ |--------|-------|
214
+ | Total tokens | 60,000 |
215
+ | Native context sees | 10,000 (16.7%) |
216
+ | **HAT recall** | **100%** |
217
+ | **Retrieval latency** | **3.1ms** |
218
+ | Memory usage | 3.3 MB |
219
+
220
+ Real LLM correctly answers questions about "past" conversations:
221
+
222
+ ```
223
+ User: "What did we discuss about quantum computing?"
224
+
225
+ [HAT retrieves 5 relevant memories in 3.0ms]
226
+ Assistant (gemma3:1b): "We discussed quantum computing leverages quantum
227
+ mechanical phenomena like superposition and entanglement."
228
+ ```
229
+
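+ On the integration side, the demo simply prepends retrieved memories to the
+ prompt before calling the model; a condensed, self-contained sketch (the
+ memory strings below are stand-ins for real retrieval results):
+
+ ```python
+ def build_prompt(memories, question):
+     """Inject HAT-retrieved memories as context ahead of the user's question."""
+     context = "\n".join(f"- {m}" for m in memories)
+     return (
+         "Relevant memories from earlier conversations:\n"
+         f"{context}\n\n"
+         f"User question: {question}\n"
+         "Answer using the memories above where relevant."
+     )
+
+ memories = [
+     "Quantum computing leverages superposition and entanglement.",
+     "We compared error-correction approaches across qubit types.",
+ ]
+ print(build_prompt(memories, "What did we discuss about quantum computing?"))
+ ```
+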
230
+ ---
231
+
232
+ ## 5. Implementation
233
+
234
+ ### 5.1 System Architecture
235
+
236
+ HAT is implemented in Rust with Python bindings via PyO3:
237
+
238
+ ```
239
+ ┌─────────────────────────────────────────────────────────────┐
240
+ │ ARMS-HAT │
241
+ ├─────────────────────────────────────────────────────────────┤
242
+ │ Core (Rust) │
243
+ │ ├── HatIndex: Main index structure │
244
+ │ ├── Container: Session/Document/Chunk nodes │
245
+ │ ├── Consolidation: Background maintenance │
246
+ │ └── Persistence: Binary serialization │
247
+ ├─────────────────────────────────────────────────────────────┤
248
+ │ Python Bindings (PyO3) │
249
+ │ ├── HatIndex, HatConfig, SearchResult │
250
+ │ ├── Session/Document management │
251
+ │ └── Attention state serialization │
252
+ └─────────────────────────────────────────────────────────────┘
253
+ ```
254
+
255
+ ### 5.2 Persistence Format
256
+
257
+ Binary format for production deployment:
258
+
259
+ | Component | Description |
260
+ |-----------|-------------|
261
+ | Header | Magic bytes, version, dimensionality |
262
+ | Containers | ID, level, parent, children, centroid |
263
+ | Active state | Current session/document IDs |
264
+
265
+ **Performance**:
266
+ - Serialize: 328 MB/s
267
+ - Deserialize: 101 MB/s
268
+ - Overhead: ~110% above raw embedding size (see the size estimate below)
269
+
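+ The overhead figure gives a quick way to budget disk for an index; a
+ back-of-the-envelope helper (the 110% constant comes from the measurements
+ above, and real sizes vary with tree shape):
+
+ ```python
+ def estimated_index_bytes(num_points, dims, overhead=1.10):
+     raw = num_points * dims * 4    # float32 embeddings
+     return raw * (1.0 + overhead)  # plus ~110% structural overhead
+
+ # 5000 points at 256 dims -> about 10.75 MB, matching Appendix A.3
+ print(f"{estimated_index_bytes(5000, 256) / 1_000_000:.2f} MB")
+ ```
+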
270
+ ### 5.3 Python API
271
+
272
+ ```python
273
+ from arms_hat import HatIndex
274
+
275
+ # Create index
276
+ index = HatIndex.cosine(1536) # OpenAI dimensions
277
+
278
+ # Add messages
279
+ point_id = index.add(embedding)  # embedding: a list[float] of length 1536
280
+
281
+ # Session management
282
+ index.new_session()
283
+ index.new_document()
284
+
285
+ # Query
286
+ results = index.near(query_embedding, k=10)
287
+
288
+ # Persistence
289
+ index.save("memory.hat")
290
+ loaded = HatIndex.load("memory.hat")
291
+ ```
292
+
293
+ ---
294
+
295
+ ## 6. Related Work
296
+
297
+ ### 6.1 Approximate Nearest Neighbor
298
+
299
+ - **HNSW** (Malkov & Yashunin, 2018): Navigable small-world graphs
300
+ - **Annoy** (Spotify): Random projection trees
301
+ - **FAISS** (Facebook): GPU-accelerated, IVF + PQ
302
+
303
+ **Key difference**: These methods learn topology from data. HAT exploits known structure.
304
+
305
+ ### 6.2 Memory-Augmented Neural Networks
306
+
307
+ - Neural Turing Machines (Graves et al., 2014)
308
+ - Memory Networks (Weston et al., 2015)
309
+ - Differentiable Neural Computer (Graves et al., 2016)
310
+
311
+ **Key difference**: These require training. HAT works immediately with no learning.
312
+
313
+ ### 6.3 RAG Systems
314
+
315
+ - RAG (Lewis et al., 2020): Retrieval-augmented generation
316
+ - RETRO (Borgeaud et al., 2022): Retrieval-enhanced transformers
317
+ - Atlas (Izacard et al., 2022): Few-shot learning with retrieval
318
+
319
+ **Key difference**: RAG retrieves text and recomputes attention. HAT can store pre-computed attention states.
320
+
321
+ ---
322
+
323
+ ## 7. Discussion
324
+
325
+ ### 7.1 Why Simplicity Wins
326
+
327
+ Our experiments with subspace routing and learnable weights demonstrate that HAT's simple design is already optimal for hierarchically-structured data:
328
+
329
+ | Enhancement | Result | Implication |
330
+ |-------------|--------|-------------|
331
+ | Subspace routing | -8.7% recall, +11.8% latency | Centroids sufficient |
332
+ | Learnable weights | -2% to +4% recall | Learning unnecessary |
333
+
334
+ **Conclusion**: When structure is *known*, exploit it directly. When structure is *unknown*, learn it.
335
+
336
+ ### 7.2 Practical Benefits
337
+
338
+ | Property | HAT | HNSW | Learned Methods |
339
+ |----------|-----|------|-----------------|
340
+ | Training required | No | Graph build | Yes |
341
+ | Cold-start problem | None | Build time | Warmup period |
342
+ | Deterministic | Yes | No | No |
343
+ | Integration complexity | Low | Medium | High |
344
+
345
+ ### 7.3 Limitations
346
+
347
+ 1. **Hierarchy assumption**: HAT requires hierarchically-structured data. For unstructured point clouds, HNSW remains appropriate.
348
+
349
+ 2. **Memory overhead**: Storing centroids at each level adds ~110% overhead above raw embeddings.
350
+
351
+ 3. **KV cache storage**: Storing full attention states is memory-intensive. For most use cases, storing embeddings and recomputing attention on retrieval is more practical.
352
+
353
+ ### 7.4 Future Work
354
+
355
+ 1. **Memory-mapped persistence**: For indexes >1GB
356
+ 2. **Distributed HAT**: Sharding across multiple nodes
357
+ 3. **Streaming updates**: Incremental index building
358
+ 4. **Multi-modal support**: Images, audio alongside text
359
+
360
+ ---
361
+
362
+ ## 8. Conclusion
363
+
364
+ We presented HAT, a hierarchical attention tree that extends LLM context by an order of magnitude. Our key contributions:
365
+
366
+ 1. **Structural prior exploitation**: First index to leverage known AI workload hierarchy
367
+ 2. **100% recall**: vs 70% for HNSW on hierarchical data
368
+ 3. **70x faster construction**: Than HNSW
369
+ 4. **Simplicity validation**: Neither geometric sophistication nor learning improves performance
370
+ 5. **End-to-end integration**: Demonstrated with real LLM (gemma3:1b)
371
+
372
+ HAT enables a 10K context model to achieve 100% recall on 60K+ tokens with 3.1ms latency, functioning as an artificial hippocampus for AI systems.
373
+
374
+ ---
375
+
376
+ ## References
377
+
378
+ 1. Malkov, Y. A., & Yashunin, D. A. (2018). Efficient and robust approximate nearest neighbor search using hierarchical navigable small world graphs. IEEE TPAMI.
379
+
380
+ 2. Lewis, P., et al. (2020). Retrieval-augmented generation for knowledge-intensive NLP tasks. NeurIPS.
381
+
382
+ 3. Graves, A., Wayne, G., & Danihelka, I. (2014). Neural turing machines. arXiv.
383
+
384
+ 4. Weston, J., Chopra, S., & Bordes, A. (2015). Memory networks. ICLR.
385
+
386
+ 5. Borgeaud, S., et al. (2022). Improving language models by retrieving from trillions of tokens. ICML.
387
+
388
+ ---
389
+
390
+ ## Appendix A: Complete Results Tables
391
+
392
+ ### A.1 Phase 3.1: HAT vs HNSW Benchmark
393
+
394
+ | Scale | HAT Build | HNSW Build | HAT R@10 | HNSW R@10 |
395
+ |-------|-----------|------------|----------|-----------|
396
+ | 500 | 16ms | 1.0s | 100% | 55% |
397
+ | 1000 | 25ms | 2.0s | 100% | 44.5% |
398
+ | 2000 | 50ms | 4.3s | 100% | 67.5% |
399
+ | 5000 | 127ms | 11.9s | 100% | 55% |
400
+
401
+ ### A.2 Phase 3.2: Real Embedding Results
402
+
403
+ | Dimension | Points | Build Time | Query Time | Recall@10 |
404
+ |-----------|--------|------------|------------|-----------|
405
+ | 384 | 1000 | 45ms | 2.1ms | 100% |
406
+ | 768 | 1000 | 52ms | 2.8ms | 100% |
407
+ | 1536 | 500 | 89ms | 3.5ms | 100% |
408
+
409
+ ### A.3 Phase 3.3: Persistence Performance
410
+
411
+ | Points | Dims | Serialize | Deserialize | Size | Recall |
412
+ |--------|------|-----------|-------------|------|--------|
413
+ | 100 | 128 | 342μs | 1.3ms | 112KB | 100% |
414
+ | 5000 | 256 | 33ms | 106ms | 10.75MB | 100% |
415
+ | 500 | 1536 | - | - | 6.32MB | 100% |
416
+
417
+ ### A.4 Phase 4.3: End-to-End Results
418
+
419
+ | Messages | Tokens | Context % | Recall | Latency | Memory |
420
+ |----------|--------|-----------|--------|---------|--------|
421
+ | 1000 | 30K | 33% | 100% | 1.7ms | 1.6MB |
422
+ | 2000 | 60K | 17% | 100% | 3.1ms | 3.3MB |
423
+
424
+ ---
425
+
426
+ ## Appendix B: Code Availability
427
+
428
+ The ARMS-HAT implementation is available at:
429
+ - Rust library: `arms-hat` crate
430
+ - Python bindings: `pip install arms-hat`
431
+ - Demo: `examples/demo_hat_memory.py`
432
+
433
+ All experiments are reproducible using the test suite:
434
+ ```bash
435
+ cargo test --test phase31_hat_vs_hnsw -- --nocapture
436
+ cargo test --test phase32_real_embeddings -- --nocapture
437
+ cargo test --test phase33_persistence -- --nocapture
438
+ python examples/demo_hat_memory.py
439
+ ```
paper/figures/fig1_recall_comparison.png ADDED
paper/figures/fig2_build_time.png ADDED
paper/figures/fig3_latency_scale.png ADDED

Git LFS Details

  • SHA256: 2bcc9a456347bfb2fcb6953e3db94f0be89e87e4c6ac3c7f5fc1b1dbd0a6dea7
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
paper/figures/fig4_architecture.png ADDED

Git LFS Details

  • SHA256: 8889796cd427448b5dc2e4b7884dfac4e5be3b4ef5cef5d7301d11529b261421
  • Pointer size: 131 Bytes
  • Size of remote file: 266 kB
paper/figures/fig5_memory_breakdown.png ADDED
paper/figures/fig6_recall_by_k.png ADDED
paper/figures/fig7_embedding_dims.png ADDED

Git LFS Details

  • SHA256: 9c8da96d910094cbd0a9aa7ede57395519743896989b80e953868b5537360519
  • Pointer size: 131 Bytes
  • Size of remote file: 158 kB
pyproject.toml ADDED
@@ -0,0 +1,45 @@
1
+ [build-system]
2
+ requires = ["maturin>=1.4,<2.0"]
3
+ build-backend = "maturin"
4
+
5
+ [project]
6
+ name = "arms-hat"
7
+ version = "0.1.0"
8
+ description = "Hierarchical Attention Tree: 100% recall with 70x faster index construction than HNSW. A new database paradigm for AI memory and hierarchical semantic search."
9
+ readme = "README.md"
10
+ license = { text = "MIT" }
11
+ requires-python = ">=3.8"
12
+ authors = [
13
+ { name = "Automate Capture LLC", email = "research@automate-capture.com" }
14
+ ]
15
+ classifiers = [
16
+ "Development Status :: 3 - Alpha",
17
+ "Intended Audience :: Developers",
18
+ "Intended Audience :: Science/Research",
19
+ "License :: OSI Approved :: MIT License",
20
+ "Programming Language :: Python :: 3",
21
+ "Programming Language :: Python :: 3.8",
22
+ "Programming Language :: Python :: 3.9",
23
+ "Programming Language :: Python :: 3.10",
24
+ "Programming Language :: Python :: 3.11",
25
+ "Programming Language :: Python :: 3.12",
26
+ "Programming Language :: Rust",
27
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
28
+ ]
29
+ keywords = ["ai", "memory", "embeddings", "vector-search", "llm"]
30
+
31
+ [project.urls]
32
+ Homepage = "https://research.automate-capture.com/hat"
33
+ Repository = "https://github.com/automate-capture/hat"
34
+ Documentation = "https://research.automate-capture.com/hat"
35
+
36
+ [project.optional-dependencies]
37
+ dev = ["pytest", "numpy"]
38
+
39
+ [tool.maturin]
40
+ features = ["python"]
41
+ python-source = "python"
42
+ module-name = "arms_hat"
43
+
44
+ [tool.pytest.ini_options]
45
+ testpaths = ["python/tests"]
python/arms_hat/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ """
2
+ ARMS-HAT: Hierarchical Attention Tree for AI memory retrieval.
3
+
4
+ A semantic memory index optimized for LLM conversation history.
5
+
6
+ Example:
7
+ >>> from arms_hat import HatIndex
8
+ >>>
9
+ >>> # Create index for OpenAI embeddings (1536 dims)
10
+ >>> index = HatIndex.cosine(1536)
11
+ >>>
12
+ >>> # Add embeddings
13
+ >>> id1 = index.add([0.1] * 1536)
14
+ >>>
15
+ >>> # Query
16
+ >>> results = index.near([0.1] * 1536, k=10)
17
+ >>> for r in results:
18
+ ... print(f"{r.id}: {r.score}")
19
+ >>>
20
+ >>> # Session management
21
+ >>> index.new_session()
22
+ >>>
23
+ >>> # Persistence
24
+ >>> index.save("memory.hat")
25
+ >>> loaded = HatIndex.load("memory.hat")
26
+ """
27
+
28
+ from .arms_hat import (
29
+ HatIndex,
30
+ HatConfig,
31
+ SearchResult,
32
+ SessionSummary,
33
+ DocumentSummary,
34
+ HatStats,
35
+ )
36
+
37
+ __all__ = [
38
+ "HatIndex",
39
+ "HatConfig",
40
+ "SearchResult",
41
+ "SessionSummary",
42
+ "DocumentSummary",
43
+ "HatStats",
44
+ ]
45
+
46
+ __version__ = "0.1.0"
python/tests/test_hat_index.py ADDED
@@ -0,0 +1,296 @@
1
+ """Tests for ARMS-HAT Python bindings."""
2
+
3
+ import pytest
4
+ import tempfile
5
+ import os
6
+
7
+
8
+ def test_import():
9
+ """Test that the module can be imported."""
10
+ from arms_hat import HatIndex, HatConfig, SearchResult
11
+
12
+
13
+ def test_create_index():
14
+ """Test index creation."""
15
+ from arms_hat import HatIndex
16
+
17
+ index = HatIndex.cosine(128)
18
+ assert len(index) == 0
19
+ assert index.is_empty()
20
+
21
+
22
+ def test_add_and_query():
23
+ """Test adding points and querying."""
24
+ from arms_hat import HatIndex
25
+
26
+ dims = 64
27
+ index = HatIndex.cosine(dims)
28
+
29
+ # Add some points
30
+ ids = []
31
+ for i in range(10):
32
+ embedding = [0.0] * dims
33
+ embedding[i % dims] = 1.0
34
+ embedding[(i + 1) % dims] = 0.5
35
+ id_ = index.add(embedding)
36
+ ids.append(id_)
37
+ assert len(id_) == 32 # Hex ID
38
+
39
+ assert len(index) == 10
40
+ assert not index.is_empty()
41
+
42
+ # Query
43
+ query = [0.0] * dims
44
+ query[0] = 1.0
45
+ query[1] = 0.5
46
+
47
+ results = index.near(query, k=5)
48
+ assert len(results) == 5
49
+
50
+ # First result should be the closest match
51
+ assert results[0].id == ids[0]
52
+ assert results[0].score > 0.9 # High cosine similarity
53
+
54
+
55
+ def test_sessions():
56
+ """Test session management."""
57
+ from arms_hat import HatIndex
58
+
59
+ index = HatIndex.cosine(32)
60
+
61
+ # Add points to first session
62
+ for i in range(5):
63
+ index.add([float(i % 32 == j) for j in range(32)])
64
+
65
+ # Start new session
66
+ index.new_session()
67
+
68
+ # Add points to second session
69
+ for i in range(5):
70
+ index.add([float((i + 10) % 32 == j) for j in range(32)])
71
+
72
+ stats = index.stats()
73
+ assert stats.session_count >= 1 # At least one session
74
+ assert stats.chunk_count == 10
75
+
76
+
77
+ def test_documents():
78
+ """Test document management within sessions."""
79
+ from arms_hat import HatIndex
80
+
81
+ index = HatIndex.cosine(32)
82
+
83
+ # Add points to first document
84
+ for i in range(3):
85
+ index.add([1.0 if j == i else 0.0 for j in range(32)])
86
+
87
+ # Start new document
88
+ index.new_document()
89
+
90
+ # Add points to second document
91
+ for i in range(3):
92
+ index.add([1.0 if j == i + 10 else 0.0 for j in range(32)])
93
+
94
+ stats = index.stats()
95
+ assert stats.document_count >= 1
96
+ assert stats.chunk_count == 6
97
+
98
+
99
+ def test_persistence_bytes():
100
+ """Test serialization to/from bytes."""
101
+ from arms_hat import HatIndex
102
+
103
+ dims = 64
104
+ index = HatIndex.cosine(dims)
105
+
106
+ # Add points
107
+ ids = []
108
+ for i in range(20):
109
+ embedding = [0.1] * dims
110
+ embedding[i % dims] = 1.0
111
+ ids.append(index.add(embedding))
112
+
113
+ # Serialize
114
+ data = index.to_bytes()
115
+ assert len(data) > 0
116
+
117
+ # Deserialize
118
+ loaded = HatIndex.from_bytes(data)
119
+ assert len(loaded) == len(index)
120
+
121
+ # Query should give same results
122
+ query = [0.1] * dims
123
+ query[0] = 1.0
124
+
125
+ original_results = index.near(query, k=5)
126
+ loaded_results = loaded.near(query, k=5)
127
+
128
+ assert len(original_results) == len(loaded_results)
129
+ assert original_results[0].id == loaded_results[0].id
130
+
131
+
132
+ def test_persistence_file():
133
+ """Test save/load to file."""
134
+ from arms_hat import HatIndex
135
+
136
+ dims = 64
137
+ index = HatIndex.cosine(dims)
138
+
139
+ # Add points
140
+ for i in range(10):
141
+ embedding = [0.1] * dims
142
+ embedding[i % dims] = 1.0
143
+ index.add(embedding)
144
+
145
+ # Save to temp file
146
+ with tempfile.NamedTemporaryFile(suffix=".hat", delete=False) as f:
147
+ path = f.name
148
+
149
+ try:
150
+ index.save(path)
151
+ assert os.path.exists(path)
152
+ assert os.path.getsize(path) > 0
153
+
154
+ # Load
155
+ loaded = HatIndex.load(path)
156
+ assert len(loaded) == len(index)
157
+
158
+ finally:
159
+ os.unlink(path)
160
+
161
+
162
+ def test_config():
163
+ """Test custom configuration."""
164
+ from arms_hat import HatIndex, HatConfig
165
+
166
+ config = HatConfig()
167
+ # Chain configuration
168
+ config = config.with_beam_width(5)
169
+ config = config.with_temporal_weight(0.1)
170
+
171
+ index = HatIndex.with_config(128, config)
172
+ assert len(index) == 0
173
+
174
+
175
+ def test_remove():
176
+ """Test point removal."""
177
+ from arms_hat import HatIndex
178
+
179
+ index = HatIndex.cosine(32)
180
+
181
+ id1 = index.add([1.0] + [0.0] * 31)
182
+ id2 = index.add([0.0, 1.0] + [0.0] * 30)
183
+
184
+ assert len(index) == 2
185
+
186
+ index.remove(id1)
187
+ assert len(index) == 1
188
+
189
+ # Query should only find id2
190
+ results = index.near([0.0, 1.0] + [0.0] * 30, k=5)
191
+ assert len(results) == 1
192
+ assert results[0].id == id2
193
+
194
+
195
+ def test_consolidate():
196
+ """Test consolidation."""
197
+ from arms_hat import HatIndex
198
+
199
+ index = HatIndex.cosine(32)
200
+
201
+ # Add many points
202
+ for i in range(100):
203
+ embedding = [0.0] * 32
204
+ embedding[i % 32] = 1.0
205
+ index.add(embedding)
206
+
207
+ # Consolidate should not error
208
+ index.consolidate()
209
+ index.consolidate_full()
210
+
211
+ assert len(index) == 100
212
+
213
+
214
+ def test_stats():
215
+ """Test stats retrieval."""
216
+ from arms_hat import HatIndex
217
+
218
+ index = HatIndex.cosine(64)
219
+
220
+ for i in range(10):
221
+ index.add([float(i % 64 == j) for j in range(64)])
222
+
223
+ stats = index.stats()
224
+ assert stats.chunk_count == 10
225
+ assert stats.total_points == 10
226
+
227
+
228
+ def test_repr():
229
+ """Test string representations."""
230
+ from arms_hat import HatIndex, HatConfig, SearchResult
231
+
232
+ index = HatIndex.cosine(64)
233
+ repr_str = repr(index)
234
+ assert "HatIndex" in repr_str
235
+
236
+ config = HatConfig()
237
+ repr_str = repr(config)
238
+ assert "HatConfig" in repr_str
239
+
240
+
241
+ def test_near_sessions():
242
+ """Test coarse-grained session search."""
243
+ from arms_hat import HatIndex
244
+
245
+ index = HatIndex.cosine(32)
246
+
247
+ # Session 1: points along dimension 0
248
+ for i in range(5):
249
+ embedding = [0.0] * 32
250
+ embedding[0] = 1.0
251
+ embedding[i + 1] = 0.3
252
+ index.add(embedding)
253
+
254
+ index.new_session()
255
+
256
+ # Session 2: points along dimension 10
257
+ for i in range(5):
258
+ embedding = [0.0] * 32
259
+ embedding[10] = 1.0
260
+ embedding[i + 11] = 0.3
261
+ index.add(embedding)
262
+
263
+ # Query similar to session 1
264
+ query = [0.0] * 32
265
+ query[0] = 1.0
266
+
267
+ sessions = index.near_sessions(query, k=2)
268
+ assert len(sessions) >= 1
269
+
270
+ # First session should be more relevant
271
+ if len(sessions) > 1:
272
+ assert sessions[0].score >= sessions[1].score
273
+
274
+
275
+ def test_high_dimensions():
276
+ """Test with OpenAI embedding dimensions."""
277
+ from arms_hat import HatIndex
278
+
279
+ dims = 1536 # OpenAI ada-002 dimensions
280
+ index = HatIndex.cosine(dims)
281
+
282
+ # Add some high-dimensional points
283
+ for i in range(10):
284
+ embedding = [(j * i * 0.01) % 1.0 for j in range(dims)]
285
+ index.add(embedding)
286
+
287
+ assert len(index) == 10
288
+
289
+ # Query
290
+ query = [0.5] * dims
291
+ results = index.near(query, k=5)
292
+ assert len(results) == 5
293
+
294
+
295
+ if __name__ == "__main__":
296
+ pytest.main([__file__, "-v"])
src/adapters/attention.rs ADDED
@@ -0,0 +1,789 @@
1
+ //! # Attention State Serialization
2
+ //!
3
+ //! Format for storing retrievable attention states, not just text.
4
+ //!
5
+ //! ## The Key Insight
6
+ //!
7
+ //! Traditional RAG stores text and re-embeds on retrieval.
8
+ //! HAT stores **attention states** that can be directly injected into LLM context.
9
+ //!
10
+ //! ## What Gets Stored
11
+ //!
12
+ //! For each memory chunk:
13
+ //! - **Text**: Original tokens/content
14
+ //! - **Embedding**: Vector for retrieval routing
15
+ //! - **KV Cache**: Compressed key-value states (optional, model-specific)
16
+ //! - **Metadata**: Timestamp, role, session context
17
+ //!
18
+ //! ## Format Design
19
+ //!
20
+ //! ```text
21
+ //! AttentionState
22
+ //! ├── id: Id (16 bytes)
23
+ //! ├── timestamp_ms: u64
24
+ //! ├── role: Role (user/assistant/system)
25
+ //! ├── text: String (original content)
26
+ //! ├── embedding: Vec<f32> (for HAT routing)
27
+ //! ├── kv_cache: Option<CompressedKV> (model-specific)
28
+ //! └── metadata: HashMap<String, String>
29
+ //! ```
30
+
31
+ use crate::core::Id;
32
+
33
+ /// Role in conversation
34
+ #[derive(Debug, Clone, Copy, PartialEq, Eq)]
35
+ pub enum Role {
36
+ /// System prompt
37
+ System,
38
+ /// User message
39
+ User,
40
+ /// Assistant response
41
+ Assistant,
42
+ /// Tool/function call
43
+ Tool,
44
+ /// Retrieved context (from RAG or previous HAT retrieval)
45
+ Context,
46
+ }
47
+
48
+ impl Role {
49
+ pub fn as_str(&self) -> &'static str {
50
+ match self {
51
+ Role::System => "system",
52
+ Role::User => "user",
53
+ Role::Assistant => "assistant",
54
+ Role::Tool => "tool",
55
+ Role::Context => "context",
56
+ }
57
+ }
58
+
59
+ pub fn from_str(s: &str) -> Option<Self> {
60
+ match s.to_lowercase().as_str() {
61
+ "system" => Some(Role::System),
62
+ "user" => Some(Role::User),
63
+ "assistant" => Some(Role::Assistant),
64
+ "tool" | "function" => Some(Role::Tool),
65
+ "context" | "retrieved" => Some(Role::Context),
66
+ _ => None,
67
+ }
68
+ }
69
+
70
+ fn to_byte(&self) -> u8 {
71
+ match self {
72
+ Role::System => 0,
73
+ Role::User => 1,
74
+ Role::Assistant => 2,
75
+ Role::Tool => 3,
76
+ Role::Context => 4,
77
+ }
78
+ }
79
+
80
+ fn from_byte(b: u8) -> Option<Self> {
81
+ match b {
82
+ 0 => Some(Role::System),
83
+ 1 => Some(Role::User),
84
+ 2 => Some(Role::Assistant),
85
+ 3 => Some(Role::Tool),
86
+ 4 => Some(Role::Context),
87
+ _ => None,
88
+ }
89
+ }
90
+ }
91
+
92
+ /// Compressed KV cache for a specific model architecture
93
+ ///
94
+ /// This is model-specific. Different models have different:
95
+ /// - Number of layers
96
+ /// - Number of heads
97
+ /// - Head dimensions
98
+ /// - Quantization formats
99
+ #[derive(Debug, Clone)]
100
+ pub struct CompressedKV {
101
+ /// Model identifier (e.g., "llama-3-8b", "mistral-7b")
102
+ pub model_id: String,
103
+
104
+ /// Number of layers
105
+ pub num_layers: u32,
106
+
107
+ /// Number of attention heads
108
+ pub num_heads: u32,
109
+
110
+ /// Dimension per head
111
+ pub head_dim: u32,
112
+
113
+ /// Sequence length this KV cache covers
114
+ pub seq_len: u32,
115
+
116
+ /// Quantization format (e.g., "fp16", "int8", "int4")
117
+ pub quantization: String,
118
+
119
+ /// Compressed KV data
120
+ /// Format: [layer][head][seq][key/value][head_dim]
121
+ /// Actual layout depends on quantization
122
+ pub data: Vec<u8>,
123
+ }
124
+
125
+ impl CompressedKV {
126
+ /// Estimate memory size in bytes
127
+ pub fn size_bytes(&self) -> usize {
128
+ self.data.len()
129
+ }
130
+
131
+ /// Create a placeholder (for models that don't support KV export)
132
+ pub fn placeholder(model_id: &str) -> Self {
133
+ Self {
134
+ model_id: model_id.to_string(),
135
+ num_layers: 0,
136
+ num_heads: 0,
137
+ head_dim: 0,
138
+ seq_len: 0,
139
+ quantization: "none".to_string(),
140
+ data: vec![],
141
+ }
142
+ }
143
+
144
+ /// Serialize to bytes
145
+ pub fn to_bytes(&self) -> Vec<u8> {
146
+ let mut bytes = Vec::new();
147
+
148
+ // Model ID (length-prefixed string)
149
+ let model_bytes = self.model_id.as_bytes();
150
+ bytes.extend_from_slice(&(model_bytes.len() as u32).to_le_bytes());
151
+ bytes.extend_from_slice(model_bytes);
152
+
153
+ // Architecture params
154
+ bytes.extend_from_slice(&self.num_layers.to_le_bytes());
155
+ bytes.extend_from_slice(&self.num_heads.to_le_bytes());
156
+ bytes.extend_from_slice(&self.head_dim.to_le_bytes());
157
+ bytes.extend_from_slice(&self.seq_len.to_le_bytes());
158
+
159
+ // Quantization (length-prefixed string)
160
+ let quant_bytes = self.quantization.as_bytes();
161
+ bytes.extend_from_slice(&(quant_bytes.len() as u32).to_le_bytes());
162
+ bytes.extend_from_slice(quant_bytes);
163
+
164
+ // Data (length-prefixed)
165
+ bytes.extend_from_slice(&(self.data.len() as u64).to_le_bytes());
166
+ bytes.extend_from_slice(&self.data);
167
+
168
+ bytes
169
+ }
170
+
171
+ /// Deserialize from bytes
172
+ pub fn from_bytes(data: &[u8]) -> Option<(Self, usize)> {
173
+ let mut offset = 0;
174
+
175
+ // Model ID
176
+ if data.len() < offset + 4 {
177
+ return None;
178
+ }
179
+ let model_len = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?) as usize;
180
+ offset += 4;
181
+
182
+ if data.len() < offset + model_len {
183
+ return None;
184
+ }
185
+ let model_id = String::from_utf8(data[offset..offset + model_len].to_vec()).ok()?;
186
+ offset += model_len;
187
+
188
+ // Architecture params
189
+ if data.len() < offset + 16 {
190
+ return None;
191
+ }
192
+ let num_layers = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?);
193
+ offset += 4;
194
+ let num_heads = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?);
195
+ offset += 4;
196
+ let head_dim = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?);
197
+ offset += 4;
198
+ let seq_len = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?);
199
+ offset += 4;
200
+
201
+ // Quantization
202
+ if data.len() < offset + 4 {
203
+ return None;
204
+ }
205
+ let quant_len = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?) as usize;
206
+ offset += 4;
207
+
208
+ if data.len() < offset + quant_len {
209
+ return None;
210
+ }
211
+ let quantization = String::from_utf8(data[offset..offset + quant_len].to_vec()).ok()?;
212
+ offset += quant_len;
213
+
214
+ // Data
215
+ if data.len() < offset + 8 {
216
+ return None;
217
+ }
218
+ let data_len = u64::from_le_bytes(data[offset..offset + 8].try_into().ok()?) as usize;
219
+ offset += 8;
220
+
221
+ if data.len() < offset + data_len {
222
+ return None;
223
+ }
224
+ let kv_data = data[offset..offset + data_len].to_vec();
225
+ offset += data_len;
226
+
227
+ Some((
228
+ Self {
229
+ model_id,
230
+ num_layers,
231
+ num_heads,
232
+ head_dim,
233
+ seq_len,
234
+ quantization,
235
+ data: kv_data,
236
+ },
237
+ offset,
238
+ ))
239
+ }
240
+ }
241
+
242
+ /// A complete attention state for a memory chunk
243
+ #[derive(Debug, Clone)]
244
+ pub struct AttentionState {
245
+ /// Unique identifier
246
+ pub id: Id,
247
+
248
+ /// Timestamp (milliseconds since epoch)
249
+ pub timestamp_ms: u64,
250
+
251
+ /// Role in conversation
252
+ pub role: Role,
253
+
254
+ /// Original text content
255
+ pub text: String,
256
+
257
+ /// Embedding vector (for HAT retrieval routing)
258
+ pub embedding: Vec<f32>,
259
+
260
+ /// Optional compressed KV cache (model-specific)
261
+ pub kv_cache: Option<CompressedKV>,
262
+
263
+ /// Additional metadata (flexible key-value pairs)
264
+ pub metadata: std::collections::HashMap<String, String>,
265
+ }
266
+
267
+ impl AttentionState {
268
+ /// Create a new attention state (without KV cache)
269
+ pub fn new(role: Role, text: String, embedding: Vec<f32>) -> Self {
270
+ Self {
271
+ id: Id::now(),
272
+ timestamp_ms: std::time::SystemTime::now()
273
+ .duration_since(std::time::UNIX_EPOCH)
274
+ .unwrap()
275
+ .as_millis() as u64,
276
+ role,
277
+ text,
278
+ embedding,
279
+ kv_cache: None,
280
+ metadata: std::collections::HashMap::new(),
281
+ }
282
+ }
283
+
284
+ /// Create with KV cache
285
+ pub fn with_kv_cache(mut self, kv: CompressedKV) -> Self {
286
+ self.kv_cache = Some(kv);
287
+ self
288
+ }
289
+
290
+ /// Add metadata
291
+ pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
292
+ self.metadata.insert(key.to_string(), value.to_string());
293
+ self
294
+ }
295
+
296
+ /// Estimate total size in bytes
297
+ pub fn size_bytes(&self) -> usize {
298
+ 16 + // id
299
+ 8 + // timestamp
300
+ 1 + // role
301
+ self.text.len() +
302
+ self.embedding.len() * 4 +
303
+ self.kv_cache.as_ref().map(|kv| kv.size_bytes()).unwrap_or(0) +
304
+ self.metadata.iter().map(|(k, v)| k.len() + v.len() + 8).sum::<usize>()
305
+ }
306
+
307
+ /// Serialize to bytes
308
+ pub fn to_bytes(&self) -> Vec<u8> {
309
+ let mut bytes = Vec::new();
310
+
311
+ // Magic + version
312
+ bytes.extend_from_slice(b"ATTN");
313
+ bytes.extend_from_slice(&1u32.to_le_bytes());
314
+
315
+ // ID
316
+ bytes.extend_from_slice(self.id.as_bytes());
317
+
318
+ // Timestamp
319
+ bytes.extend_from_slice(&self.timestamp_ms.to_le_bytes());
320
+
321
+ // Role
322
+ bytes.push(self.role.to_byte());
323
+
324
+ // Text (length-prefixed)
325
+ let text_bytes = self.text.as_bytes();
326
+ bytes.extend_from_slice(&(text_bytes.len() as u32).to_le_bytes());
327
+ bytes.extend_from_slice(text_bytes);
328
+
329
+ // Embedding (length-prefixed)
330
+ bytes.extend_from_slice(&(self.embedding.len() as u32).to_le_bytes());
331
+ for &v in &self.embedding {
332
+ bytes.extend_from_slice(&v.to_le_bytes());
333
+ }
334
+
335
+ // KV cache (present flag + data)
336
+ if let Some(ref kv) = self.kv_cache {
337
+ bytes.push(1);
338
+ let kv_bytes = kv.to_bytes();
339
+ bytes.extend_from_slice(&(kv_bytes.len() as u64).to_le_bytes());
340
+ bytes.extend_from_slice(&kv_bytes);
341
+ } else {
342
+ bytes.push(0);
343
+ }
344
+
345
+ // Metadata (count + entries)
346
+ bytes.extend_from_slice(&(self.metadata.len() as u32).to_le_bytes());
347
+ for (key, value) in &self.metadata {
348
+ let key_bytes = key.as_bytes();
349
+ let value_bytes = value.as_bytes();
350
+ bytes.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
351
+ bytes.extend_from_slice(key_bytes);
352
+ bytes.extend_from_slice(&(value_bytes.len() as u32).to_le_bytes());
353
+ bytes.extend_from_slice(value_bytes);
354
+ }
355
+
356
+ bytes
357
+ }
358
+
359
+ /// Deserialize from bytes
360
+ pub fn from_bytes(data: &[u8]) -> Result<Self, AttentionError> {
361
+ let mut offset = 0;
362
+
363
+ // Magic
364
+ if data.len() < 8 {
365
+ return Err(AttentionError::InvalidFormat("Too short".into()));
366
+ }
367
+ if &data[0..4] != b"ATTN" {
368
+ return Err(AttentionError::InvalidMagic);
369
+ }
370
+ offset += 4;
371
+
372
+ // Version
373
+ let version = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
374
+ if version != 1 {
375
+ return Err(AttentionError::UnsupportedVersion(version));
376
+ }
377
+ offset += 4;
378
+
379
+ // ID
380
+ if data.len() < offset + 16 {
381
+ return Err(AttentionError::InvalidFormat("Missing ID".into()));
382
+ }
383
+ let mut id_bytes = [0u8; 16];
384
+ id_bytes.copy_from_slice(&data[offset..offset + 16]);
385
+ let id = Id::from_bytes(id_bytes);
386
+ offset += 16;
387
+
388
+ // Timestamp
389
+ if data.len() < offset + 8 {
390
+ return Err(AttentionError::InvalidFormat("Missing timestamp".into()));
391
+ }
392
+ let timestamp_ms = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
393
+ offset += 8;
394
+
395
+ // Role
396
+ if data.len() < offset + 1 {
397
+ return Err(AttentionError::InvalidFormat("Missing role".into()));
398
+ }
399
+ let role = Role::from_byte(data[offset])
400
+ .ok_or_else(|| AttentionError::InvalidFormat("Invalid role".into()))?;
401
+ offset += 1;
402
+
403
+ // Text
404
+ if data.len() < offset + 4 {
405
+ return Err(AttentionError::InvalidFormat("Missing text length".into()));
406
+ }
407
+ let text_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
408
+ offset += 4;
409
+
410
+ if data.len() < offset + text_len {
411
+ return Err(AttentionError::InvalidFormat("Text truncated".into()));
412
+ }
413
+ let text = String::from_utf8(data[offset..offset + text_len].to_vec())
414
+ .map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in text".into()))?;
415
+ offset += text_len;
416
+
417
+ // Embedding
418
+ if data.len() < offset + 4 {
419
+ return Err(AttentionError::InvalidFormat("Missing embedding length".into()));
420
+ }
421
+ let emb_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
422
+ offset += 4;
423
+
424
+ if data.len() < offset + emb_len * 4 {
425
+ return Err(AttentionError::InvalidFormat("Embedding truncated".into()));
426
+ }
427
+ let mut embedding = Vec::with_capacity(emb_len);
428
+ for _ in 0..emb_len {
429
+ embedding.push(f32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()));
430
+ offset += 4;
431
+ }
432
+
433
+ // KV cache
434
+ if data.len() < offset + 1 {
435
+ return Err(AttentionError::InvalidFormat("Missing KV flag".into()));
436
+ }
437
+ let has_kv = data[offset] != 0;
438
+ offset += 1;
439
+
440
+ let kv_cache = if has_kv {
441
+ if data.len() < offset + 8 {
442
+ return Err(AttentionError::InvalidFormat("Missing KV length".into()));
443
+ }
444
+ let kv_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize;
445
+ offset += 8;
446
+
447
+ if data.len() < offset + kv_len {
448
+ return Err(AttentionError::InvalidFormat("KV data truncated".into()));
449
+ }
450
+ let (kv, _) = CompressedKV::from_bytes(&data[offset..offset + kv_len])
451
+ .ok_or_else(|| AttentionError::InvalidFormat("Invalid KV cache".into()))?;
452
+ offset += kv_len;
453
+ Some(kv)
454
+ } else {
455
+ None
456
+ };
457
+
458
+ // Metadata
459
+ if data.len() < offset + 4 {
460
+ return Err(AttentionError::InvalidFormat("Missing metadata count".into()));
461
+ }
462
+ let meta_count = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
463
+ offset += 4;
464
+
465
+ let mut metadata = std::collections::HashMap::new();
466
+ for _ in 0..meta_count {
467
+ // Key
468
+ if data.len() < offset + 4 {
469
+ return Err(AttentionError::InvalidFormat("Missing key length".into()));
470
+ }
471
+ let key_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
472
+ offset += 4;
473
+
474
+ if data.len() < offset + key_len {
475
+ return Err(AttentionError::InvalidFormat("Key truncated".into()));
476
+ }
477
+ let key = String::from_utf8(data[offset..offset + key_len].to_vec())
478
+ .map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in key".into()))?;
479
+ offset += key_len;
480
+
481
+ // Value
482
+ if data.len() < offset + 4 {
483
+ return Err(AttentionError::InvalidFormat("Missing value length".into()));
484
+ }
485
+ let value_len = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
486
+ offset += 4;
487
+
488
+ if data.len() < offset + value_len {
489
+ return Err(AttentionError::InvalidFormat("Value truncated".into()));
490
+ }
491
+ let value = String::from_utf8(data[offset..offset + value_len].to_vec())
492
+ .map_err(|_| AttentionError::InvalidFormat("Invalid UTF-8 in value".into()))?;
493
+ offset += value_len;
494
+
495
+ metadata.insert(key, value);
496
+ }
497
+
498
+ Ok(Self {
499
+ id,
500
+ timestamp_ms,
501
+ role,
502
+ text,
503
+ embedding,
504
+ kv_cache,
505
+ metadata,
506
+ })
507
+ }
508
+ }
509
+
510
+ /// Errors for attention state operations
511
+ #[derive(Debug, Clone)]
512
+ pub enum AttentionError {
513
+ InvalidMagic,
514
+ UnsupportedVersion(u32),
515
+ InvalidFormat(String),
516
+ }
517
+
518
+ impl std::fmt::Display for AttentionError {
519
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520
+ match self {
521
+ AttentionError::InvalidMagic => write!(f, "Invalid magic bytes"),
522
+ AttentionError::UnsupportedVersion(v) => write!(f, "Unsupported version: {}", v),
523
+ AttentionError::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg),
524
+ }
525
+ }
526
+ }
527
+
528
+ impl std::error::Error for AttentionError {}
529
+
530
+ /// A batch of attention states for efficient storage
531
+ #[derive(Debug, Clone)]
532
+ pub struct AttentionBatch {
533
+ /// States in this batch
534
+ pub states: Vec<AttentionState>,
535
+
536
+ /// Session ID this batch belongs to
537
+ pub session_id: Option<Id>,
538
+
539
+ /// Document ID this batch belongs to
540
+ pub document_id: Option<Id>,
541
+ }
542
+
543
+ impl AttentionBatch {
544
+ pub fn new() -> Self {
545
+ Self {
546
+ states: Vec::new(),
547
+ session_id: None,
548
+ document_id: None,
549
+ }
550
+ }
551
+
552
+ pub fn with_session(mut self, session_id: Id) -> Self {
553
+ self.session_id = Some(session_id);
554
+ self
555
+ }
556
+
557
+ pub fn with_document(mut self, document_id: Id) -> Self {
558
+ self.document_id = Some(document_id);
559
+ self
560
+ }
561
+
562
+ pub fn add(&mut self, state: AttentionState) {
563
+ self.states.push(state);
564
+ }
565
+
566
+ /// Total size in bytes
567
+ pub fn size_bytes(&self) -> usize {
568
+ self.states.iter().map(|s| s.size_bytes()).sum()
569
+ }
570
+
571
+ /// Serialize batch to bytes
572
+ pub fn to_bytes(&self) -> Vec<u8> {
573
+ let mut bytes = Vec::new();
574
+
575
+ // Magic + version
576
+ bytes.extend_from_slice(b"ATNB");
577
+ bytes.extend_from_slice(&1u32.to_le_bytes());
578
+
579
+ // Session ID
580
+ if let Some(sid) = self.session_id {
581
+ bytes.push(1);
582
+ bytes.extend_from_slice(sid.as_bytes());
583
+ } else {
584
+ bytes.push(0);
585
+ }
586
+
587
+ // Document ID
588
+ if let Some(did) = self.document_id {
589
+ bytes.push(1);
590
+ bytes.extend_from_slice(did.as_bytes());
591
+ } else {
592
+ bytes.push(0);
593
+ }
594
+
595
+ // States count
596
+ bytes.extend_from_slice(&(self.states.len() as u32).to_le_bytes());
597
+
598
+ // Each state
599
+ for state in &self.states {
600
+ let state_bytes = state.to_bytes();
601
+ bytes.extend_from_slice(&(state_bytes.len() as u64).to_le_bytes());
602
+ bytes.extend_from_slice(&state_bytes);
603
+ }
604
+
605
+ bytes
606
+ }
607
+
608
+ /// Deserialize batch from bytes
609
+ pub fn from_bytes(data: &[u8]) -> Result<Self, AttentionError> {
610
+ let mut offset = 0;
611
+
612
+ // Magic
613
+ if data.len() < 8 {
614
+ return Err(AttentionError::InvalidFormat("Too short".into()));
615
+ }
616
+ if &data[0..4] != b"ATNB" {
617
+ return Err(AttentionError::InvalidMagic);
618
+ }
619
+ offset += 4;
620
+
621
+ // Version
622
+ let version = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
623
+ if version != 1 {
624
+ return Err(AttentionError::UnsupportedVersion(version));
625
+ }
626
+ offset += 4;
627
+
628
+ // Session ID
629
+ if data.len() < offset + 1 {
630
+ return Err(AttentionError::InvalidFormat("Missing session flag".into()));
631
+ }
632
+ let has_session = data[offset] != 0;
633
+ offset += 1;
634
+
635
+ let session_id = if has_session {
636
+ if data.len() < offset + 16 {
637
+ return Err(AttentionError::InvalidFormat("Missing session ID".into()));
638
+ }
639
+ let mut id_bytes = [0u8; 16];
640
+ id_bytes.copy_from_slice(&data[offset..offset + 16]);
641
+ offset += 16;
642
+ Some(Id::from_bytes(id_bytes))
643
+ } else {
644
+ None
645
+ };
646
+
647
+ // Document ID
648
+ if data.len() < offset + 1 {
649
+ return Err(AttentionError::InvalidFormat("Missing document flag".into()));
650
+ }
651
+ let has_document = data[offset] != 0;
652
+ offset += 1;
653
+
654
+ let document_id = if has_document {
655
+ if data.len() < offset + 16 {
656
+ return Err(AttentionError::InvalidFormat("Missing document ID".into()));
657
+ }
658
+ let mut id_bytes = [0u8; 16];
659
+ id_bytes.copy_from_slice(&data[offset..offset + 16]);
660
+ offset += 16;
661
+ Some(Id::from_bytes(id_bytes))
662
+ } else {
663
+ None
664
+ };
665
+
666
+ // States count
667
+ if data.len() < offset + 4 {
668
+ return Err(AttentionError::InvalidFormat("Missing state count".into()));
669
+ }
670
+ let state_count = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap()) as usize;
671
+ offset += 4;
672
+
673
+ // States
674
+ let mut states = Vec::with_capacity(state_count);
675
+ for _ in 0..state_count {
676
+ if data.len() < offset + 8 {
677
+ return Err(AttentionError::InvalidFormat("Missing state length".into()));
678
+ }
679
+ let state_len = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize;
680
+ offset += 8;
681
+
682
+ if data.len() < offset + state_len {
683
+ return Err(AttentionError::InvalidFormat("State truncated".into()));
684
+ }
685
+ let state = AttentionState::from_bytes(&data[offset..offset + state_len])?;
686
+ offset += state_len;
687
+ states.push(state);
688
+ }
689
+
690
+ Ok(Self {
691
+ states,
692
+ session_id,
693
+ document_id,
694
+ })
695
+ }
696
+ }
697
+
698
+ impl Default for AttentionBatch {
699
+ fn default() -> Self {
700
+ Self::new()
701
+ }
702
+ }
703
+
704
+ #[cfg(test)]
705
+ mod tests {
706
+ use super::*;
707
+
708
+ #[test]
709
+ fn test_role_roundtrip() {
710
+ for role in [Role::System, Role::User, Role::Assistant, Role::Tool, Role::Context] {
711
+ let byte = role.to_byte();
712
+ let restored = Role::from_byte(byte).unwrap();
713
+ assert_eq!(role, restored);
714
+ }
715
+ }
716
+
717
+ #[test]
718
+ fn test_attention_state_roundtrip() {
719
+ let state = AttentionState::new(
720
+ Role::User,
721
+ "Hello, how are you?".to_string(),
722
+ vec![0.1, 0.2, 0.3, 0.4],
723
+ )
724
+ .with_metadata("turn", "1");
725
+
726
+ let bytes = state.to_bytes();
727
+ let restored = AttentionState::from_bytes(&bytes).unwrap();
728
+
729
+ assert_eq!(state.role, restored.role);
730
+ assert_eq!(state.text, restored.text);
731
+ assert_eq!(state.embedding, restored.embedding);
732
+ assert_eq!(state.metadata.get("turn"), restored.metadata.get("turn"));
733
+ }
734
+
735
+ #[test]
736
+ fn test_attention_state_with_kv() {
737
+ let kv = CompressedKV {
738
+ model_id: "llama-3-8b".to_string(),
739
+ num_layers: 32,
740
+ num_heads: 32,
741
+ head_dim: 128,
742
+ seq_len: 10,
743
+ quantization: "fp16".to_string(),
744
+ data: vec![1, 2, 3, 4, 5],
745
+ };
746
+
747
+ let state = AttentionState::new(
748
+ Role::Assistant,
749
+ "I'm doing well!".to_string(),
750
+ vec![0.5, 0.6, 0.7, 0.8],
751
+ )
752
+ .with_kv_cache(kv);
753
+
754
+ let bytes = state.to_bytes();
755
+ let restored = AttentionState::from_bytes(&bytes).unwrap();
756
+
757
+ assert!(restored.kv_cache.is_some());
758
+ let restored_kv = restored.kv_cache.unwrap();
759
+ assert_eq!(restored_kv.model_id, "llama-3-8b");
760
+ assert_eq!(restored_kv.num_layers, 32);
761
+ assert_eq!(restored_kv.data, vec![1, 2, 3, 4, 5]);
762
+ }
763
+
764
+ #[test]
765
+ fn test_batch_roundtrip() {
766
+ let mut batch = AttentionBatch::new()
767
+ .with_session(Id::now());
768
+
769
+ batch.add(AttentionState::new(
770
+ Role::User,
771
+ "Question 1".to_string(),
772
+ vec![0.1, 0.2],
773
+ ));
774
+
775
+ batch.add(AttentionState::new(
776
+ Role::Assistant,
777
+ "Answer 1".to_string(),
778
+ vec![0.3, 0.4],
779
+ ));
780
+
781
+ let bytes = batch.to_bytes();
782
+ let restored = AttentionBatch::from_bytes(&bytes).unwrap();
783
+
784
+ assert_eq!(restored.states.len(), 2);
785
+ assert_eq!(restored.states[0].text, "Question 1");
786
+ assert_eq!(restored.states[1].text, "Answer 1");
787
+ assert!(restored.session_id.is_some());
788
+ }
789
+ }
src/adapters/index/consolidation.rs ADDED
@@ -0,0 +1,576 @@
1
+ //! # Consolidation Phases for HAT
2
+ //!
3
+ //! Background maintenance operations inspired by memory consolidation in the brain.
4
+ //! Like sleep stages (REM/NREM), HAT needs periodic "offline" maintenance to:
5
+ //!
6
+ //! 1. **Recompute Centroids**: Incremental updates accumulate drift; periodically recompute from scratch
7
+ //! 2. **Rebalance Tree**: Merge underpopulated containers, split overpopulated ones
8
+ //! 3. **Prune Stale Branches**: Remove containers with no descendants
9
+ //! 4. **Optimize Layout**: Reorder children for better cache locality
10
+ //!
11
+ //! ## Design Philosophy
12
+ //!
13
+ //! Consolidation is designed to be:
14
+ //! - **Non-blocking**: Can run incrementally, yielding to queries
15
+ //! - **Resumable**: Can pause and resume without data loss
16
+ //! - **Observable**: Reports progress and metrics for benchmarking
17
+ //!
18
+ //! ## Consolidation Levels
19
+ //!
20
+ //! Like sleep stages, different consolidation depths:
21
+ //!
22
+ //! - **Light** (α): Recompute centroids only (~NREM Stage 1)
23
+ //! - **Medium** (β): + Rebalance tree structure (~NREM Stage 2-3)
24
+ //! - **Deep** (δ): + Optimize layout, prune stale branches (~NREM Stage 4 / SWS)
25
+ //! - **Full** (θ): Complete rebuild from scratch (~REM)
26
+
27
+ use std::collections::{HashMap, HashSet, VecDeque};
28
+
29
+ use crate::core::{Id, Point};
30
+
31
+ /// Consolidation level - determines how deep the maintenance goes
32
+ #[derive(Debug, Clone, Copy, PartialEq, Eq)]
33
+ pub enum ConsolidationLevel {
34
+ /// Light: Recompute centroids only
35
+ /// Fast, minimal disruption, good for frequent runs
36
+ Light,
37
+
38
+ /// Medium: Recompute centroids + rebalance tree
39
+ /// Moderate time, restructures containers
40
+ Medium,
41
+
42
+ /// Deep: Full maintenance including layout optimization
43
+ /// Longer time, comprehensive cleanup
44
+ Deep,
45
+
46
+ /// Full: Complete rebuild from leaf nodes
47
+ /// Longest time, guarantees optimal structure
48
+ Full,
49
+ }
50
+
51
+ impl Default for ConsolidationLevel {
52
+ fn default() -> Self {
53
+ ConsolidationLevel::Medium
54
+ }
55
+ }
56
+
57
+ /// Configuration for consolidation operations
58
+ #[derive(Debug, Clone)]
59
+ pub struct ConsolidationConfig {
60
+ /// Target level of consolidation
61
+ pub level: ConsolidationLevel,
62
+
63
+ /// Maximum containers to process per tick (for incremental consolidation)
64
+ pub batch_size: usize,
65
+
66
+ /// Minimum children before considering merge
67
+ pub merge_threshold: usize,
68
+
69
+ /// Maximum children before considering split
70
+ pub split_threshold: usize,
71
+
72
+ /// Maximum centroid drift (L2) before triggering recompute
73
+ /// 0.0 = always recompute, higher values = more lenient
74
+ pub drift_threshold: f32,
75
+
76
+ /// Whether to collect detailed metrics
77
+ pub collect_metrics: bool,
78
+ }
79
+
80
+ impl Default for ConsolidationConfig {
81
+ fn default() -> Self {
82
+ Self {
83
+ level: ConsolidationLevel::Medium,
84
+ batch_size: 100,
85
+ merge_threshold: 3,
86
+ split_threshold: 100,
87
+ drift_threshold: 0.01,
88
+ collect_metrics: true,
89
+ }
90
+ }
91
+ }
92
+
93
+ impl ConsolidationConfig {
94
+ pub fn light() -> Self {
95
+ Self {
96
+ level: ConsolidationLevel::Light,
97
+ ..Default::default()
98
+ }
99
+ }
100
+
101
+ pub fn medium() -> Self {
102
+ Self {
103
+ level: ConsolidationLevel::Medium,
104
+ ..Default::default()
105
+ }
106
+ }
107
+
108
+ pub fn deep() -> Self {
109
+ Self {
110
+ level: ConsolidationLevel::Deep,
111
+ ..Default::default()
112
+ }
113
+ }
114
+
115
+ pub fn full() -> Self {
116
+ Self {
117
+ level: ConsolidationLevel::Full,
118
+ ..Default::default()
119
+ }
120
+ }
121
+
122
+ pub fn with_batch_size(mut self, size: usize) -> Self {
123
+ self.batch_size = size;
124
+ self
125
+ }
126
+ }
127
+
128
+ /// Current state of consolidation
129
+ #[derive(Debug, Clone, Copy, PartialEq, Eq)]
130
+ pub enum ConsolidationPhase {
131
+ /// Not currently consolidating
132
+ Idle,
133
+
134
+ /// Phase 1: Collecting all leaf points
135
+ CollectingLeaves,
136
+
137
+ /// Phase 2: Recomputing centroids bottom-up
138
+ RecomputingCentroids,
139
+
140
+ /// Phase 3: Identifying containers to merge/split
141
+ AnalyzingStructure,
142
+
143
+ /// Phase 4: Performing merges
144
+ Merging,
145
+
146
+ /// Phase 5: Performing splits
147
+ Splitting,
148
+
149
+ /// Phase 6: Pruning empty containers
150
+ Pruning,
151
+
152
+ /// Phase 7: Optimizing layout
153
+ OptimizingLayout,
154
+
155
+ /// Consolidation complete
156
+ Complete,
157
+ }
158
+
159
+ /// Metrics collected during consolidation
160
+ #[derive(Debug, Clone, Default)]
161
+ pub struct ConsolidationMetrics {
162
+ /// Total containers processed
163
+ pub containers_processed: usize,
164
+
165
+ /// Centroids recomputed
166
+ pub centroids_recomputed: usize,
167
+
168
+ /// Average centroid drift (L2 norm of delta)
169
+ pub avg_centroid_drift: f32,
170
+
171
+ /// Maximum centroid drift observed
172
+ pub max_centroid_drift: f32,
173
+
174
+ /// Number of containers merged
175
+ pub containers_merged: usize,
176
+
177
+ /// Number of containers split
178
+ pub containers_split: usize,
179
+
180
+ /// Number of empty containers pruned
181
+ pub containers_pruned: usize,
182
+
183
+ /// Time spent in each phase (microseconds)
184
+ pub phase_times_us: HashMap<String, u64>,
185
+
186
+ /// Total consolidation time (microseconds)
187
+ pub total_time_us: u64,
188
+
189
+ /// Number of ticks (for incremental consolidation)
190
+ pub ticks: usize,
191
+ }
192
+
193
+ /// Progress report for observable consolidation
194
+ #[derive(Debug, Clone)]
195
+ pub struct ConsolidationProgress {
196
+ /// Current phase
197
+ pub phase: ConsolidationPhase,
198
+
199
+ /// Percentage complete (0.0 - 1.0)
200
+ pub progress: f32,
201
+
202
+ /// Containers remaining in current phase
203
+ pub remaining: usize,
204
+
205
+ /// Running metrics
206
+ pub metrics: ConsolidationMetrics,
207
+ }
208
+
209
+ /// Internal state for resumable consolidation
210
+ #[derive(Debug)]
211
+ pub struct ConsolidationState {
212
+ /// Configuration
213
+ pub config: ConsolidationConfig,
214
+
215
+ /// Current phase
216
+ pub phase: ConsolidationPhase,
217
+
218
+ /// Collected metrics
219
+ pub metrics: ConsolidationMetrics,
220
+
221
+ /// Queue of containers to process in current phase
222
+ pub work_queue: VecDeque<Id>,
223
+
224
+ /// Set of containers already processed
225
+ pub processed: HashSet<Id>,
226
+
227
+ /// Accumulated centroid drifts for averaging
228
+ centroid_drifts: Vec<f32>,
229
+
230
+ /// Containers identified for merging (pairs)
231
+ merge_candidates: Vec<(Id, Id)>,
232
+
233
+ /// Containers identified for splitting
234
+ split_candidates: Vec<Id>,
235
+
236
+ /// Phase start timestamp (for timing)
237
+ phase_start_us: u64,
238
+
239
+ /// Consolidation start timestamp
240
+ start_us: u64,
241
+ }
242
+
243
+ impl ConsolidationState {
244
+ /// Create a new consolidation state
245
+ pub fn new(config: ConsolidationConfig) -> Self {
246
+ let now = std::time::SystemTime::now()
247
+ .duration_since(std::time::UNIX_EPOCH)
248
+ .unwrap()
249
+ .as_micros() as u64;
250
+
251
+ Self {
252
+ config,
253
+ phase: ConsolidationPhase::Idle,
254
+ metrics: ConsolidationMetrics::default(),
255
+ work_queue: VecDeque::new(),
256
+ processed: HashSet::new(),
257
+ centroid_drifts: Vec::new(),
258
+ merge_candidates: Vec::new(),
259
+ split_candidates: Vec::new(),
260
+ phase_start_us: now,
261
+ start_us: now,
262
+ }
263
+ }
264
+
265
+ /// Start consolidation
266
+ pub fn start(&mut self) {
267
+ let now = std::time::SystemTime::now()
268
+ .duration_since(std::time::UNIX_EPOCH)
269
+ .unwrap()
270
+ .as_micros() as u64;
271
+
272
+ self.start_us = now;
273
+ self.phase_start_us = now;
274
+ self.phase = ConsolidationPhase::CollectingLeaves;
275
+ self.metrics = ConsolidationMetrics::default();
276
+ self.work_queue.clear();
277
+ self.processed.clear();
278
+ self.centroid_drifts.clear();
279
+ self.merge_candidates.clear();
280
+ self.split_candidates.clear();
281
+ }
282
+
283
+ /// Transition to next phase
284
+ pub fn next_phase(&mut self) {
285
+ let now = std::time::SystemTime::now()
286
+ .duration_since(std::time::UNIX_EPOCH)
287
+ .unwrap()
288
+ .as_micros() as u64;
289
+
290
+ // Record time for previous phase
291
+ let phase_time = now - self.phase_start_us;
292
+ let phase_name = format!("{:?}", self.phase);
293
+ self.metrics.phase_times_us.insert(phase_name, phase_time);
294
+
295
+ // Compute average drift if we have samples
296
+ if !self.centroid_drifts.is_empty() {
297
+ self.metrics.avg_centroid_drift =
298
+ self.centroid_drifts.iter().sum::<f32>() / self.centroid_drifts.len() as f32;
299
+ }
300
+
301
+ // Determine next phase based on level
302
+ self.phase = match (self.phase, self.config.level) {
303
+ (ConsolidationPhase::Idle, _) => ConsolidationPhase::CollectingLeaves,
304
+
305
+ (ConsolidationPhase::CollectingLeaves, _) => ConsolidationPhase::RecomputingCentroids,
306
+
307
+ (ConsolidationPhase::RecomputingCentroids, ConsolidationLevel::Light) => {
308
+ ConsolidationPhase::Complete
309
+ }
310
+ (ConsolidationPhase::RecomputingCentroids, _) => {
311
+ ConsolidationPhase::AnalyzingStructure
312
+ }
313
+
314
+ (ConsolidationPhase::AnalyzingStructure, _) => ConsolidationPhase::Merging,
315
+
316
+ (ConsolidationPhase::Merging, _) => ConsolidationPhase::Splitting,
317
+
318
+ (ConsolidationPhase::Splitting, ConsolidationLevel::Medium) => {
319
+ ConsolidationPhase::Complete
320
+ }
321
+ (ConsolidationPhase::Splitting, _) => ConsolidationPhase::Pruning,
322
+
323
+ (ConsolidationPhase::Pruning, _) => ConsolidationPhase::OptimizingLayout,
324
+
325
+ (ConsolidationPhase::OptimizingLayout, _) => ConsolidationPhase::Complete,
326
+
327
+ (ConsolidationPhase::Complete, _) => ConsolidationPhase::Complete,
328
+ };
329
+
330
+ // Reset for new phase
331
+ self.phase_start_us = now;
332
+ self.work_queue.clear();
333
+ self.processed.clear();
334
+
335
+ // Record total time if complete
336
+ if self.phase == ConsolidationPhase::Complete {
337
+ self.metrics.total_time_us = now - self.start_us;
338
+ }
339
+ }
340
+
341
+ /// Record a centroid drift
342
+ pub fn record_drift(&mut self, drift: f32) {
343
+ self.centroid_drifts.push(drift);
344
+ if drift > self.metrics.max_centroid_drift {
345
+ self.metrics.max_centroid_drift = drift;
346
+ }
347
+ }
348
+
349
+ /// Add merge candidate pair
350
+ pub fn add_merge_candidate(&mut self, a: Id, b: Id) {
351
+ self.merge_candidates.push((a, b));
352
+ }
353
+
354
+ /// Add split candidate
355
+ pub fn add_split_candidate(&mut self, id: Id) {
356
+ self.split_candidates.push(id);
357
+ }
358
+
359
+ /// Get next merge candidate pair
360
+ pub fn next_merge(&mut self) -> Option<(Id, Id)> {
361
+ self.merge_candidates.pop()
362
+ }
363
+
364
+ /// Get next split candidate
365
+ pub fn next_split(&mut self) -> Option<Id> {
366
+ self.split_candidates.pop()
367
+ }
368
+
369
+ /// Check if there are pending merge candidates
370
+ pub fn has_merges(&self) -> bool {
371
+ !self.merge_candidates.is_empty()
372
+ }
373
+
374
+ /// Check if there are pending split candidates
375
+ pub fn has_splits(&self) -> bool {
376
+ !self.split_candidates.is_empty()
377
+ }
378
+
379
+ /// Check if consolidation is complete
380
+ pub fn is_complete(&self) -> bool {
381
+ self.phase == ConsolidationPhase::Complete
382
+ }
383
+
384
+ /// Get progress report
385
+ pub fn progress(&self) -> ConsolidationProgress {
386
+ let remaining = self.work_queue.len();
387
+ let total = remaining + self.processed.len();
388
+ let progress = if total > 0 {
389
+ self.processed.len() as f32 / total as f32
390
+ } else {
391
+ 1.0
392
+ };
393
+
394
+ ConsolidationProgress {
395
+ phase: self.phase,
396
+ progress,
397
+ remaining,
398
+ metrics: self.metrics.clone(),
399
+ }
400
+ }
401
+ }
402
+
403
+ /// Result of a single consolidation tick
404
+ #[derive(Debug)]
405
+ pub enum ConsolidationTickResult {
406
+ /// Still working, more ticks needed
407
+ Continue(ConsolidationProgress),
408
+
409
+ /// Consolidation complete
410
+ Complete(ConsolidationMetrics),
411
+ }
412
+
413
+ /// Trait for types that support consolidation
414
+ pub trait Consolidate {
415
+ /// Begin consolidation with given config
416
+ fn begin_consolidation(&mut self, config: ConsolidationConfig);
417
+
418
+ /// Execute one tick of consolidation
419
+ /// Returns Continue if more work remains, Complete when done
420
+ fn consolidation_tick(&mut self) -> ConsolidationTickResult;
421
+
422
+ /// Run consolidation to completion (blocking)
423
+ fn consolidate(&mut self, config: ConsolidationConfig) -> ConsolidationMetrics {
424
+ self.begin_consolidation(config);
425
+ loop {
426
+ match self.consolidation_tick() {
427
+ ConsolidationTickResult::Continue(_) => continue,
428
+ ConsolidationTickResult::Complete(metrics) => return metrics,
429
+ }
430
+ }
431
+ }
432
+
433
+ /// Check if consolidation is in progress
434
+ fn is_consolidating(&self) -> bool;
435
+
436
+ /// Get current consolidation progress
437
+ fn consolidation_progress(&self) -> Option<ConsolidationProgress>;
438
+
439
+ /// Cancel ongoing consolidation
440
+ fn cancel_consolidation(&mut self);
441
+ }
442
+
443
+ /// Helper for computing exact centroids from a set of points
444
+ pub fn compute_exact_centroid(points: &[Point]) -> Option<Point> {
445
+ if points.is_empty() {
446
+ return None;
447
+ }
448
+
449
+ let dims = points[0].dimensionality();
450
+ let mut sum = vec![0.0f32; dims];
451
+
452
+ for point in points {
453
+ for (i, &val) in point.dims().iter().enumerate() {
454
+ sum[i] += val;
455
+ }
456
+ }
457
+
458
+ let n = points.len() as f32;
459
+ let mean: Vec<f32> = sum.iter().map(|s| s / n).collect();
460
+
461
+ Some(Point::new(mean).normalize())
462
+ }
463
+
464
+ /// Helper to measure centroid drift
465
+ pub fn centroid_drift(old: &Point, new: &Point) -> f32 {
466
+ old.dims()
467
+ .iter()
468
+ .zip(new.dims().iter())
469
+ .map(|(a, b)| (a - b).powi(2))
470
+ .sum::<f32>()
471
+ .sqrt()
472
+ }
473
+
474
+ #[cfg(test)]
475
+ mod tests {
476
+ use super::*;
477
+
478
+ #[test]
479
+ fn test_consolidation_config_levels() {
480
+ let light = ConsolidationConfig::light();
481
+ assert_eq!(light.level, ConsolidationLevel::Light);
482
+
483
+ let medium = ConsolidationConfig::medium();
484
+ assert_eq!(medium.level, ConsolidationLevel::Medium);
485
+
486
+ let deep = ConsolidationConfig::deep();
487
+ assert_eq!(deep.level, ConsolidationLevel::Deep);
488
+
489
+ let full = ConsolidationConfig::full();
490
+ assert_eq!(full.level, ConsolidationLevel::Full);
491
+ }
492
+
493
+ #[test]
494
+ fn test_consolidation_state_phases() {
495
+ let config = ConsolidationConfig::light();
496
+ let mut state = ConsolidationState::new(config);
497
+
498
+ assert_eq!(state.phase, ConsolidationPhase::Idle);
499
+
500
+ state.start();
501
+ assert_eq!(state.phase, ConsolidationPhase::CollectingLeaves);
502
+
503
+ state.next_phase();
504
+ assert_eq!(state.phase, ConsolidationPhase::RecomputingCentroids);
505
+
506
+ // Light level skips to complete after centroids
507
+ state.next_phase();
508
+ assert_eq!(state.phase, ConsolidationPhase::Complete);
509
+ assert!(state.is_complete());
510
+ }
511
+
512
+ #[test]
513
+ fn test_consolidation_state_medium_phases() {
514
+ let config = ConsolidationConfig::medium();
515
+ let mut state = ConsolidationState::new(config);
516
+
517
+ state.start();
518
+ assert_eq!(state.phase, ConsolidationPhase::CollectingLeaves);
519
+
520
+ state.next_phase();
521
+ assert_eq!(state.phase, ConsolidationPhase::RecomputingCentroids);
522
+
523
+ state.next_phase();
524
+ assert_eq!(state.phase, ConsolidationPhase::AnalyzingStructure);
525
+
526
+ state.next_phase();
527
+ assert_eq!(state.phase, ConsolidationPhase::Merging);
528
+
529
+ state.next_phase();
530
+ assert_eq!(state.phase, ConsolidationPhase::Splitting);
531
+
532
+ // Medium level completes after splitting
533
+ state.next_phase();
534
+ assert_eq!(state.phase, ConsolidationPhase::Complete);
535
+ }
536
+
537
+ #[test]
538
+ fn test_centroid_computation() {
539
+ let points = vec![
540
+ Point::new(vec![1.0, 0.0, 0.0]),
541
+ Point::new(vec![0.0, 1.0, 0.0]),
542
+ Point::new(vec![0.0, 0.0, 1.0]),
543
+ ];
544
+
545
+ let centroid = compute_exact_centroid(&points).unwrap();
546
+
547
+ // Should be normalized mean
548
+ let expected = (1.0f32 / 3.0).sqrt();
549
+ for dim in centroid.dims() {
550
+ assert!((dim - expected).abs() < 0.01);
551
+ }
552
+ }
553
+
554
+ #[test]
555
+ fn test_centroid_drift() {
556
+ let old = Point::new(vec![1.0, 0.0, 0.0]);
557
+ let new = Point::new(vec![0.9, 0.1, 0.0]).normalize();
558
+
559
+ let drift = centroid_drift(&old, &new);
560
+ assert!(drift > 0.0);
561
+ assert!(drift < 1.0);
562
+ }
563
+
564
+ #[test]
565
+ fn test_drift_recording() {
566
+ let config = ConsolidationConfig::default();
567
+ let mut state = ConsolidationState::new(config);
568
+
569
+ state.record_drift(0.05);
570
+ state.record_drift(0.10);
571
+ state.record_drift(0.02);
572
+
573
+ assert_eq!(state.metrics.max_centroid_drift, 0.10);
574
+ assert_eq!(state.centroid_drifts.len(), 3);
575
+ }
576
+ }
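
consolidation.rs defines the maintenance state machine but leaves execution to implementers of `Consolidate`. A sketch of both driving styles, blocking and incremental, assuming `HatIndex` implements the trait (the imports in hat.rs below suggest it does; the paths are illustrative):

```rust
use arms_hat::{Consolidate, ConsolidationConfig, ConsolidationTickResult, HatIndex};

fn maintain(index: &mut HatIndex) {
    // Blocking: a Light pass only recomputes centroids, so it is cheap
    // enough to run frequently.
    let metrics = index.consolidate(ConsolidationConfig::light());
    println!(
        "recomputed {} centroids, max drift {:.4}",
        metrics.centroids_recomputed, metrics.max_centroid_drift
    );

    // Incremental: a Deep pass, yielding between ticks so queries can interleave.
    index.begin_consolidation(ConsolidationConfig::deep().with_batch_size(50));
    loop {
        match index.consolidation_tick() {
            ConsolidationTickResult::Continue(p) => {
                eprintln!("{:?}: {:.0}% done", p.phase, p.progress * 100.0);
            }
            ConsolidationTickResult::Complete(m) => {
                eprintln!("done in {} us over {} ticks", m.total_time_us, m.ticks);
                break;
            }
        }
    }
}
```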
src/adapters/index/flat.rs ADDED
@@ -0,0 +1,278 @@
1
+ //! # Flat Index Adapter
2
+ //!
3
+ //! Brute force nearest neighbor search.
4
+ //! Compares query against ALL points - O(n) per query.
5
+ //!
6
+ //! Good for:
7
+ //! - Testing
8
+ //! - Small datasets (< 10,000 points)
9
+ //! - When exact results are required
10
+ //!
11
+ //! Not good for:
12
+ //! - Large datasets (use HNSW instead)
13
+
14
+ use std::collections::HashMap;
15
+ use std::sync::Arc;
16
+
17
+ use crate::core::{Id, Point};
18
+ use crate::core::proximity::Proximity;
19
+ use crate::ports::{Near, NearError, NearResult, SearchResult};
20
+
21
+ /// Brute force index - searches all points
22
+ pub struct FlatIndex {
23
+ /// Stored points (ID -> Point)
24
+ points: HashMap<Id, Point>,
25
+
26
+ /// Expected dimensionality
27
+ dimensionality: usize,
28
+
29
+ /// Proximity function to use
30
+ proximity: Arc<dyn Proximity>,
31
+
32
+ /// Whether higher proximity = more similar
33
+ /// true for cosine/dot product, false for euclidean
34
+ higher_is_better: bool,
35
+ }
36
+
37
+ impl FlatIndex {
38
+ /// Create a new flat index
39
+ ///
40
+ /// `higher_is_better` indicates whether higher proximity scores mean more similar.
41
+ /// - `true` for Cosine, DotProduct
42
+ /// - `false` for Euclidean, Manhattan
43
+ pub fn new(
44
+ dimensionality: usize,
45
+ proximity: Arc<dyn Proximity>,
46
+ higher_is_better: bool,
47
+ ) -> Self {
48
+ Self {
49
+ points: HashMap::new(),
50
+ dimensionality,
51
+ proximity,
52
+ higher_is_better,
53
+ }
54
+ }
55
+
56
+ /// Create with cosine similarity (higher = better)
57
+ pub fn cosine(dimensionality: usize) -> Self {
58
+ use crate::core::proximity::Cosine;
59
+ Self::new(dimensionality, Arc::new(Cosine), true)
60
+ }
61
+
62
+ /// Create with euclidean distance (lower = better)
63
+ pub fn euclidean(dimensionality: usize) -> Self {
64
+ use crate::core::proximity::Euclidean;
65
+ Self::new(dimensionality, Arc::new(Euclidean), false)
66
+ }
67
+
68
+ /// Sort results by relevance
69
+ fn sort_results(&self, results: &mut Vec<SearchResult>) {
70
+ if self.higher_is_better {
71
+ // Higher score = more relevant, sort descending
72
+ results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
73
+ } else {
74
+ // Lower score = more relevant, sort ascending
75
+ results.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());
76
+ }
77
+ }
78
+ }
79
+
80
+ impl Near for FlatIndex {
81
+ fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> {
82
+ // Check dimensionality
83
+ if query.dimensionality() != self.dimensionality {
84
+ return Err(NearError::DimensionalityMismatch {
85
+ expected: self.dimensionality,
86
+ got: query.dimensionality(),
87
+ });
88
+ }
89
+
90
+ // Compute proximity to all points
91
+ let mut results: Vec<SearchResult> = self
92
+ .points
93
+ .iter()
94
+ .map(|(id, point)| {
95
+ let score = self.proximity.proximity(query, point);
96
+ SearchResult::new(*id, score)
97
+ })
98
+ .collect();
99
+
100
+ // Sort by relevance
101
+ self.sort_results(&mut results);
102
+
103
+ // Take top k
104
+ results.truncate(k);
105
+
106
+ Ok(results)
107
+ }
108
+
109
+ fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> {
110
+ // Check dimensionality
111
+ if query.dimensionality() != self.dimensionality {
112
+ return Err(NearError::DimensionalityMismatch {
113
+ expected: self.dimensionality,
114
+ got: query.dimensionality(),
115
+ });
116
+ }
117
+
118
+ // Find all points within threshold
119
+ let mut results: Vec<SearchResult> = self
120
+ .points
121
+ .iter()
122
+ .filter_map(|(id, point)| {
123
+ let score = self.proximity.proximity(query, point);
124
+ let within = if self.higher_is_better {
125
+ score >= threshold
126
+ } else {
127
+ score <= threshold
128
+ };
129
+ if within {
130
+ Some(SearchResult::new(*id, score))
131
+ } else {
132
+ None
133
+ }
134
+ })
135
+ .collect();
136
+
137
+ // Sort by relevance
138
+ self.sort_results(&mut results);
139
+
140
+ Ok(results)
141
+ }
142
+
143
+ fn add(&mut self, id: Id, point: &Point) -> NearResult<()> {
144
+ if point.dimensionality() != self.dimensionality {
145
+ return Err(NearError::DimensionalityMismatch {
146
+ expected: self.dimensionality,
147
+ got: point.dimensionality(),
148
+ });
149
+ }
150
+
151
+ self.points.insert(id, point.clone());
152
+ Ok(())
153
+ }
154
+
155
+ fn remove(&mut self, id: Id) -> NearResult<()> {
156
+ self.points.remove(&id);
157
+ Ok(())
158
+ }
159
+
160
+ fn rebuild(&mut self) -> NearResult<()> {
161
+ // Flat index doesn't need rebuilding
162
+ Ok(())
163
+ }
164
+
165
+ fn is_ready(&self) -> bool {
166
+ true // Always ready
167
+ }
168
+
169
+ fn len(&self) -> usize {
170
+ self.points.len()
171
+ }
172
+ }
173
+
174
+ #[cfg(test)]
175
+ mod tests {
176
+ use super::*;
177
+
178
+ fn setup_index() -> FlatIndex {
179
+ let mut index = FlatIndex::cosine(3);
180
+
181
+ // Add some test points
182
+ let points = vec![
183
+ (Id::from_bytes([1; 16]), Point::new(vec![1.0, 0.0, 0.0])),
184
+ (Id::from_bytes([2; 16]), Point::new(vec![0.0, 1.0, 0.0])),
185
+ (Id::from_bytes([3; 16]), Point::new(vec![0.0, 0.0, 1.0])),
186
+ (Id::from_bytes([4; 16]), Point::new(vec![0.7, 0.7, 0.0]).normalize()),
187
+ ];
188
+
189
+ for (id, point) in points {
190
+ index.add(id, &point).unwrap();
191
+ }
192
+
193
+ index
194
+ }
195
+
196
+ #[test]
197
+ fn test_flat_index_near() {
198
+ let index = setup_index();
199
+
200
+ // Query for points near [1, 0, 0]
201
+ let query = Point::new(vec![1.0, 0.0, 0.0]);
202
+ let results = index.near(&query, 2).unwrap();
203
+
204
+ assert_eq!(results.len(), 2);
205
+
206
+ // First result should be [1, 0, 0] with cosine = 1.0
207
+ assert_eq!(results[0].id, Id::from_bytes([1; 16]));
208
+ assert!((results[0].score - 1.0).abs() < 0.0001);
209
+ }
210
+
211
+ #[test]
212
+ fn test_flat_index_within_cosine() {
213
+ let index = setup_index();
214
+
215
+ // Find all points with cosine similarity >= 0.5 against [1, 0, 0]
216
+ let query = Point::new(vec![1.0, 0.0, 0.0]);
217
+ let results = index.within(&query, 0.5).unwrap();
218
+
219
+ // Should find [1,0,0] (cosine=1.0) and [0.7,0.7,0] (cosine≈0.707)
220
+ assert_eq!(results.len(), 2);
221
+ }
222
+
223
+ #[test]
224
+ fn test_flat_index_euclidean() {
225
+ let mut index = FlatIndex::euclidean(2);
226
+
227
+ index.add(Id::from_bytes([1; 16]), &Point::new(vec![0.0, 0.0])).unwrap();
228
+ index.add(Id::from_bytes([2; 16]), &Point::new(vec![1.0, 0.0])).unwrap();
229
+ index.add(Id::from_bytes([3; 16]), &Point::new(vec![5.0, 0.0])).unwrap();
230
+
231
+ let query = Point::new(vec![0.0, 0.0]);
232
+ let results = index.near(&query, 2).unwrap();
233
+
234
+ // Nearest should be [0,0] with distance 0
235
+ assert_eq!(results[0].id, Id::from_bytes([1; 16]));
236
+ assert!((results[0].score - 0.0).abs() < 0.0001);
237
+
238
+ // Second nearest should be [1,0] with distance 1
239
+ assert_eq!(results[1].id, Id::from_bytes([2; 16]));
240
+ assert!((results[1].score - 1.0).abs() < 0.0001);
241
+ }
242
+
243
+ #[test]
244
+ fn test_flat_index_add_remove() {
245
+ let mut index = FlatIndex::cosine(3);
246
+
247
+ let id = Id::from_bytes([1; 16]);
248
+ let point = Point::new(vec![1.0, 0.0, 0.0]);
249
+
250
+ index.add(id, &point).unwrap();
251
+ assert_eq!(index.len(), 1);
252
+
253
+ index.remove(id).unwrap();
254
+ assert_eq!(index.len(), 0);
255
+ }
256
+
257
+ #[test]
258
+ fn test_flat_index_dimensionality_check() {
259
+ let mut index = FlatIndex::cosine(3);
260
+
261
+ let wrong_dims = Point::new(vec![1.0, 0.0]); // 2 dims
262
+ let result = index.add(Id::now(), &wrong_dims);
263
+
264
+ match result {
265
+ Err(NearError::DimensionalityMismatch { expected, got }) => {
266
+ assert_eq!(expected, 3);
267
+ assert_eq!(got, 2);
268
+ }
269
+ _ => panic!("Expected DimensionalityMismatch error"),
270
+ }
271
+ }
272
+
273
+ #[test]
274
+ fn test_flat_index_ready() {
275
+ let index = FlatIndex::cosine(3);
276
+ assert!(index.is_ready());
277
+ }
278
+ }
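
flat.rs is the exact O(n) baseline that the approximate indexes are measured against. A short usage sketch mirroring the tests above (import paths again illustrative):

```rust
use arms_hat::{FlatIndex, Id, Near, Point};

fn main() {
    // Exact k-NN: every query scans all points, so keep this under ~10k entries.
    let mut index = FlatIndex::cosine(3);
    index.add(Id::from_bytes([1; 16]), &Point::new(vec![1.0, 0.0, 0.0])).unwrap();
    index.add(Id::from_bytes([2; 16]), &Point::new(vec![0.0, 1.0, 0.0])).unwrap();

    let query = Point::new(vec![0.9, 0.1, 0.0]).normalize();
    let hits = index.near(&query, 1).unwrap();
    assert_eq!(hits[0].id, Id::from_bytes([1; 16]));
    println!("top-1 cosine score: {:.3}", hits[0].score);

    // within() keeps everything past a threshold instead of a fixed k.
    let close = index.within(&query, 0.5).unwrap();
    println!("{} points with cosine >= 0.5", close.len());
}
```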
src/adapters/index/hat.rs ADDED
@@ -0,0 +1,1953 @@
1
+ //! # HAT Index Adapter
2
+ //!
3
+ //! Hierarchical Attention Tree - a novel index structure for AI memory.
4
+ //! Exploits known semantic hierarchy and temporal locality.
5
+ //!
6
+ //! Key insight: Unlike HNSW which learns topology from data,
7
+ //! HAT uses KNOWN hierarchy (session → document → chunk).
8
+ //!
9
+ //! Query complexity: O(log n) via tree descent
10
+ //! Insert complexity: O(log n) with incremental centroid updates
11
+
12
+ use std::collections::{HashMap, VecDeque};
13
+ use std::sync::Arc;
14
+ use std::time::{SystemTime, UNIX_EPOCH};
15
+
16
+ use crate::core::{Id, Point};
17
+ use crate::core::proximity::Proximity;
18
+ use crate::core::merge::Merge;
19
+ use crate::ports::{Near, NearError, NearResult, SearchResult};
20
+
21
+ use super::consolidation::{
22
+ Consolidate, ConsolidationConfig, ConsolidationPhase, ConsolidationState,
23
+ ConsolidationMetrics, ConsolidationProgress, ConsolidationTickResult,
24
+ compute_exact_centroid, centroid_drift,
25
+ };
26
+
27
+ /// Centroid computation method
28
+ #[derive(Debug, Clone, Copy, PartialEq, Eq)]
29
+ pub enum CentroidMethod {
30
+ /// Euclidean mean + renormalize (fast but geometrically imprecise)
31
+ Euclidean,
32
+ /// Fréchet mean on hypersphere (manifold-aware, more accurate)
33
+ Frechet,
34
+ }
35
+
36
+ impl Default for CentroidMethod {
37
+ fn default() -> Self {
38
+ CentroidMethod::Euclidean
39
+ }
40
+ }
41
+
42
+ /// HAT configuration parameters
43
+ #[derive(Debug, Clone)]
44
+ pub struct HatConfig {
45
+ /// Maximum children per container before splitting
46
+ pub max_children: usize,
47
+
48
+ /// Minimum children to maintain (for merging)
49
+ pub min_children: usize,
50
+
51
+ /// Number of branches to explore at each level (beam width)
52
+ pub beam_width: usize,
53
+
54
+ /// Weight for temporal proximity in scoring (0.0 = pure semantic)
55
+ pub temporal_weight: f32,
56
+
57
+ /// Time decay factor (higher = faster decay)
58
+ pub time_decay: f32,
59
+
60
+ /// Threshold for sparse centroid propagation (0.0 = always propagate)
61
+ /// Only propagate to parent if centroid change magnitude exceeds this
62
+ pub propagation_threshold: f32,
63
+
64
+ /// Method for computing centroids
65
+ pub centroid_method: CentroidMethod,
66
+
67
+ /// Number of iterations for Fréchet mean computation
68
+ pub frechet_iterations: usize,
69
+
70
+ /// Enable subspace-aware routing (default: false for backward compatibility)
71
+ pub subspace_enabled: bool,
72
+
73
+ /// Configuration for subspace representation
74
+ pub subspace_config: super::subspace::SubspaceConfig,
75
+
76
+ /// Enable learnable routing (default: false for backward compatibility)
77
+ pub learnable_routing_enabled: bool,
78
+
79
+ /// Configuration for learnable routing
80
+ pub learnable_routing_config: super::learnable_routing::LearnableRoutingConfig,
81
+ }
82
+
83
+ impl Default for HatConfig {
84
+ fn default() -> Self {
85
+ Self {
86
+ max_children: 50,
87
+ min_children: 5,
88
+ beam_width: 3,
89
+ temporal_weight: 0.0, // Start with pure semantic
90
+ time_decay: 0.001,
91
+ propagation_threshold: 0.0, // Default: always propagate (backward compatible)
92
+ centroid_method: CentroidMethod::Euclidean, // Default: backward compatible
93
+ frechet_iterations: 5, // Enough for convergence on hypersphere
94
+ subspace_enabled: false, // Default: disabled for backward compatibility
95
+ subspace_config: super::subspace::SubspaceConfig::default(),
96
+ learnable_routing_enabled: false, // Default: disabled for backward compatibility
97
+ learnable_routing_config: super::learnable_routing::LearnableRoutingConfig::default(),
98
+ }
99
+ }
100
+ }
101
+
102
+ impl HatConfig {
103
+ pub fn new() -> Self {
104
+ Self::default()
105
+ }
106
+
107
+ pub fn with_beam_width(mut self, width: usize) -> Self {
108
+ self.beam_width = width;
109
+ self
110
+ }
111
+
112
+ pub fn with_temporal_weight(mut self, weight: f32) -> Self {
113
+ self.temporal_weight = weight;
114
+ self
115
+ }
116
+
117
+ pub fn with_propagation_threshold(mut self, threshold: f32) -> Self {
118
+ self.propagation_threshold = threshold;
119
+ self
120
+ }
121
+
122
+ pub fn with_centroid_method(mut self, method: CentroidMethod) -> Self {
123
+ self.centroid_method = method;
124
+ self
125
+ }
126
+
127
+ pub fn with_frechet_iterations(mut self, iterations: usize) -> Self {
128
+ self.frechet_iterations = iterations;
129
+ self
130
+ }
131
+
132
+ pub fn with_subspace_enabled(mut self, enabled: bool) -> Self {
133
+ self.subspace_enabled = enabled;
134
+ self
135
+ }
136
+
137
+ pub fn with_subspace_config(mut self, config: super::subspace::SubspaceConfig) -> Self {
138
+ self.subspace_config = config;
139
+ self.subspace_enabled = true; // Automatically enable when config is provided
140
+ self
141
+ }
142
+
143
+ pub fn with_learnable_routing_enabled(mut self, enabled: bool) -> Self {
144
+ self.learnable_routing_enabled = enabled;
145
+ self
146
+ }
147
+
148
+ pub fn with_learnable_routing_config(mut self, config: super::learnable_routing::LearnableRoutingConfig) -> Self {
149
+ self.learnable_routing_config = config;
150
+ self.learnable_routing_enabled = true; // Automatically enable when config is provided
151
+ self
152
+ }
153
+ }
154
+
155
+ /// Level in the hierarchy
156
+ #[derive(Debug, Clone, Copy, PartialEq, Eq)]
157
+ pub enum ContainerLevel {
158
+ /// Root level - single global container
159
+ Global,
160
+ /// Session level - conversation/context boundaries
161
+ Session,
162
+ /// Document level - logical groupings within session
163
+ Document,
164
+ /// Chunk level - leaf nodes, actual attention states
165
+ Chunk,
166
+ }
167
+
168
+ impl ContainerLevel {
169
+ fn child_level(&self) -> Option<ContainerLevel> {
170
+ match self {
171
+ ContainerLevel::Global => Some(ContainerLevel::Session),
172
+ ContainerLevel::Session => Some(ContainerLevel::Document),
173
+ ContainerLevel::Document => Some(ContainerLevel::Chunk),
174
+ ContainerLevel::Chunk => None,
175
+ }
176
+ }
177
+
178
+ fn depth(&self) -> usize {
179
+ match self {
180
+ ContainerLevel::Global => 0,
181
+ ContainerLevel::Session => 1,
182
+ ContainerLevel::Document => 2,
183
+ ContainerLevel::Chunk => 3,
184
+ }
185
+ }
186
+ }
187
+
188
+ /// Summary of a session for coarse queries (multi-resolution API)
189
+ #[derive(Debug, Clone)]
190
+ pub struct SessionSummary {
191
+ /// Session ID
192
+ pub id: Id,
193
+
194
+ /// Similarity score to query
195
+ pub score: f32,
196
+
197
+ /// Number of chunks in this session
198
+ pub chunk_count: usize,
199
+
200
+ /// Session timestamp
201
+ pub timestamp: u64,
202
+ }
203
+
204
+ /// Summary of a document for coarse queries
205
+ #[derive(Debug, Clone)]
206
+ pub struct DocumentSummary {
207
+ /// Document ID
208
+ pub id: Id,
209
+
210
+ /// Similarity score to query
211
+ pub score: f32,
212
+
213
+ /// Number of chunks in this document
214
+ pub chunk_count: usize,
215
+
216
+ /// Document timestamp
217
+ pub timestamp: u64,
218
+ }
219
+
220
+ /// A container in the HAT hierarchy
221
+ #[derive(Debug, Clone)]
222
+ struct Container {
223
+ /// Unique identifier
224
+ id: Id,
225
+
226
+ /// Level in hierarchy
227
+ level: ContainerLevel,
228
+
229
+ /// Centroid (mean of children)
230
+ centroid: Point,
231
+
232
+ /// Creation timestamp (ms since epoch)
233
+ timestamp: u64,
234
+
235
+ /// Child container IDs (empty for chunks)
236
+ children: Vec<Id>,
237
+
238
+ /// Number of descendant chunks (for weighted centroid updates)
239
+ descendant_count: usize,
240
+
241
+ /// Accumulated sum of all descendant points (for Euclidean centroid)
242
+ /// Stored as unnormalized to enable incremental updates
243
+ accumulated_sum: Option<Point>,
244
+
245
+ /// Subspace representation (optional, for non-chunk containers)
246
+ /// Captures variance/spread of points within the container
247
+ subspace: Option<super::subspace::Subspace>,
248
+ }
249
+
250
+ impl Container {
251
+ fn new(id: Id, level: ContainerLevel, centroid: Point) -> Self {
252
+ let timestamp = SystemTime::now()
253
+ .duration_since(UNIX_EPOCH)
254
+ .unwrap()
255
+ .as_millis() as u64;
256
+
257
+ // For chunks, the accumulated sum is the point itself
258
+ let accumulated_sum = if level == ContainerLevel::Chunk {
259
+ Some(centroid.clone())
260
+ } else {
261
+ None
262
+ };
263
+
264
+ // Initialize subspace for non-chunk containers
265
+ let subspace = if level != ContainerLevel::Chunk {
266
+ Some(super::subspace::Subspace::new(centroid.dimensionality()))
267
+ } else {
268
+ None
269
+ };
270
+
271
+ Self {
272
+ id,
273
+ level,
274
+ centroid,
275
+ timestamp,
276
+ children: Vec::new(),
277
+ descendant_count: if level == ContainerLevel::Chunk { 1 } else { 0 },
278
+ accumulated_sum,
279
+ subspace,
280
+ }
281
+ }
282
+
283
+ fn is_leaf(&self) -> bool {
284
+ self.level == ContainerLevel::Chunk
285
+ }
286
+ }
287
+
288
+ /// Hierarchical Attention Tree Index
289
+ pub struct HatIndex {
290
+ /// All containers (including root, sessions, documents, chunks)
291
+ containers: HashMap<Id, Container>,
292
+
293
+ /// Root container ID
294
+ root_id: Option<Id>,
295
+
296
+ /// Current active session (where new documents go)
297
+ active_session: Option<Id>,
298
+
299
+ /// Current active document (where new chunks go)
300
+ active_document: Option<Id>,
301
+
302
+ /// Expected dimensionality
303
+ dimensionality: usize,
304
+
305
+ /// Proximity function
306
+ proximity: Arc<dyn Proximity>,
307
+
308
+ /// Merge function (for centroids)
309
+ merge: Arc<dyn Merge>,
310
+
311
+ /// Whether higher proximity = more similar
312
+ higher_is_better: bool,
313
+
314
+ /// Configuration
315
+ config: HatConfig,
316
+
317
+ /// Consolidation state (None if not consolidating)
318
+ consolidation_state: Option<ConsolidationState>,
319
+
320
+ /// Cache of child points during consolidation
321
+ consolidation_points_cache: HashMap<Id, Vec<Point>>,
322
+
323
+ /// Learnable router for adaptive routing weights
324
+ learnable_router: Option<super::learnable_routing::LearnableRouter>,
325
+ }
326
+
327
+ impl HatIndex {
328
+ /// Create a new HAT index with cosine similarity
329
+ pub fn cosine(dimensionality: usize) -> Self {
330
+ use crate::core::proximity::Cosine;
331
+ use crate::core::merge::Mean;
332
+ Self::new(
333
+ dimensionality,
334
+ Arc::new(Cosine),
335
+ Arc::new(Mean),
336
+ true,
337
+ HatConfig::default(),
338
+ )
339
+ }
340
+
341
+ /// Create with custom config
342
+ pub fn with_config(mut self, config: HatConfig) -> Self {
343
+ // Initialize learnable router if enabled
344
+ if config.learnable_routing_enabled {
345
+ self.learnable_router = Some(super::learnable_routing::LearnableRouter::new(
346
+ self.dimensionality,
347
+ config.learnable_routing_config.clone(),
348
+ ));
349
+ }
350
+ self.config = config;
351
+ self
352
+ }
353
+
354
+ /// Create with custom proximity and merge functions
355
+ pub fn new(
356
+ dimensionality: usize,
357
+ proximity: Arc<dyn Proximity>,
358
+ merge: Arc<dyn Merge>,
359
+ higher_is_better: bool,
360
+ config: HatConfig,
361
+ ) -> Self {
362
+ // Initialize learnable router if enabled
363
+ let learnable_router = if config.learnable_routing_enabled {
364
+ Some(super::learnable_routing::LearnableRouter::new(
365
+ dimensionality,
366
+ config.learnable_routing_config.clone(),
367
+ ))
368
+ } else {
369
+ None
370
+ };
371
+
372
+ Self {
373
+ containers: HashMap::new(),
374
+ root_id: None,
375
+ active_session: None,
376
+ active_document: None,
377
+ dimensionality,
378
+ proximity,
379
+ merge,
380
+ higher_is_better,
381
+ config,
382
+ consolidation_state: None,
383
+ consolidation_points_cache: HashMap::new(),
384
+ learnable_router,
385
+ }
386
+ }
387
+
388
+ /// Compute distance (lower = more similar)
389
+ fn distance(&self, a: &Point, b: &Point) -> f32 {
390
+ let prox = self.proximity.proximity(a, b);
391
+ if self.higher_is_better {
392
+ 1.0 - prox
393
+ } else {
394
+ prox
395
+ }
396
+ }
397
+
398
+ /// Compute temporal distance (normalized to 0-1)
399
+ fn temporal_distance(&self, t1: u64, t2: u64) -> f32 {
400
+ let diff = (t1 as i64 - t2 as i64).unsigned_abs() as f64;
401
+ // Distance is 1 - e^(-λ·Δt): 0 for simultaneous events, approaching 1 as the gap grows
402
+ // diff is in milliseconds; normalize to hours before applying the decay
403
+ let hours = diff / (1000.0 * 60.0 * 60.0);
404
+ (1.0 - (-self.config.time_decay as f64 * hours).exp()) as f32
405
+ }
406
+
407
+ /// Combined distance with temporal component, optional subspace, and learnable routing
408
+ fn combined_distance(&self, query: &Point, query_time: u64, container: &Container) -> f32 {
409
+ // Compute semantic distance
410
+ let semantic = if self.config.learnable_routing_enabled {
411
+ // Use learnable routing weights
412
+ if let Some(ref router) = self.learnable_router {
413
+ // weighted_similarity returns similarity (higher = better)
414
+ // convert to distance (lower = better)
415
+ let sim = router.weighted_similarity(query, &container.centroid);
416
+ 1.0 - sim
417
+ } else {
418
+ self.distance(query, &container.centroid)
419
+ }
420
+ } else if self.config.subspace_enabled && !container.is_leaf() {
421
+ // Use subspace-aware similarity if available
422
+ if let Some(ref subspace) = container.subspace {
423
+ // combined_subspace_similarity returns similarity (higher = better)
424
+ // convert to distance (lower = better)
425
+ let sim = super::subspace::combined_subspace_similarity(
426
+ query, subspace, &self.config.subspace_config
427
+ );
428
+ 1.0 - sim
429
+ } else {
430
+ self.distance(query, &container.centroid)
431
+ }
432
+ } else {
433
+ self.distance(query, &container.centroid)
434
+ };
435
+
436
+ let temporal = self.temporal_distance(query_time, container.timestamp);
437
+
438
+ // Weighted combination
439
+ let w = self.config.temporal_weight;
440
+ semantic * (1.0 - w) + temporal * w
441
+ }
442
+
443
+ /// Ensure root exists
444
+ fn ensure_root(&mut self) {
445
+ if self.root_id.is_none() {
446
+ let root = Container::new(
447
+ Id::now(),
448
+ ContainerLevel::Global,
449
+ Point::origin(self.dimensionality),
450
+ );
451
+ let root_id = root.id;
452
+ self.containers.insert(root_id, root);
453
+ self.root_id = Some(root_id);
454
+ }
455
+ }
456
+
457
+ /// Ensure active session exists
458
+ fn ensure_session(&mut self) {
459
+ self.ensure_root();
460
+
461
+ if self.active_session.is_none() {
462
+ let session = Container::new(
463
+ Id::now(),
464
+ ContainerLevel::Session,
465
+ Point::origin(self.dimensionality),
466
+ );
467
+ let session_id = session.id;
468
+ self.containers.insert(session_id, session);
469
+
470
+ // Add to root's children
471
+ if let Some(root_id) = self.root_id {
472
+ if let Some(root) = self.containers.get_mut(&root_id) {
473
+ root.children.push(session_id);
474
+ }
475
+ }
476
+
477
+ self.active_session = Some(session_id);
478
+ }
479
+ }
480
+
481
+ /// Ensure active document exists
482
+ fn ensure_document(&mut self) {
483
+ self.ensure_session();
484
+
485
+ if self.active_document.is_none() {
486
+ let document = Container::new(
487
+ Id::now(),
488
+ ContainerLevel::Document,
489
+ Point::origin(self.dimensionality),
490
+ );
491
+ let doc_id = document.id;
492
+ self.containers.insert(doc_id, document);
493
+
494
+ // Add to session's children
495
+ if let Some(session_id) = self.active_session {
496
+ if let Some(session) = self.containers.get_mut(&session_id) {
497
+ session.children.push(doc_id);
498
+ }
499
+ }
500
+
501
+ self.active_document = Some(doc_id);
502
+ }
503
+ }
504
+
505
+ /// Start a new session (call this to create session boundaries)
506
+ pub fn new_session(&mut self) {
507
+ self.active_session = None;
508
+ self.active_document = None;
509
+ }
510
+
511
+ /// Start a new document within current session
512
+ pub fn new_document(&mut self) {
513
+ self.active_document = None;
514
+ }
515
+
516
+ /// Compute Fréchet mean on the unit hypersphere using iterative algorithm
517
+ /// This finds the point that minimizes sum of squared geodesic distances
518
+ fn compute_frechet_mean(&self, points: &[Point], initial: &Point) -> Point {
519
+ let mut mean = initial.clone();
520
+ let iterations = self.config.frechet_iterations;
521
+
522
+ for _ in 0..iterations {
523
+ // Compute weighted tangent vectors (log map)
524
+ let mut tangent_sum = vec![0.0f32; mean.dimensionality()];
525
+
526
+ for point in points {
527
+ // Log map: project point onto tangent space at mean
528
+ // For unit sphere: log_p(q) = θ * (q - (q·p)p) / ||q - (q·p)p||
529
+ // where θ = arccos(p·q)
530
+ let dot: f32 = mean.dims().iter()
531
+ .zip(point.dims().iter())
532
+ .map(|(a, b)| a * b)
533
+ .sum();
534
+
535
+ // Clamp dot product to valid range for arccos
536
+ let dot_clamped = dot.clamp(-1.0, 1.0);
537
+ let theta = dot_clamped.acos();
538
+
539
+ if theta.abs() < 1e-8 {
540
+ // Points are identical, tangent vector is zero
541
+ continue;
542
+ }
543
+
544
+ // Direction in tangent space
545
+ let mut direction: Vec<f32> = point.dims().iter()
546
+ .zip(mean.dims().iter())
547
+ .map(|(q, p)| q - dot * p)
548
+ .collect();
549
+
550
+ // Normalize direction
551
+ let dir_norm: f32 = direction.iter().map(|x| x * x).sum::<f32>().sqrt();
552
+ if dir_norm < 1e-8 {
553
+ continue;
554
+ }
555
+
556
+ for (i, d) in direction.iter_mut().enumerate() {
557
+ tangent_sum[i] += theta * (*d / dir_norm);
558
+ }
559
+ }
560
+
561
+ // Average tangent vector
562
+ let n = points.len() as f32;
563
+ for t in tangent_sum.iter_mut() {
564
+ *t /= n;
565
+ }
566
+
567
+ // Compute tangent vector magnitude
568
+ let tangent_norm: f32 = tangent_sum.iter().map(|x| x * x).sum::<f32>().sqrt();
569
+
570
+ if tangent_norm < 1e-8 {
571
+ // Converged
572
+ break;
573
+ }
574
+
575
+ // Exp map: move along geodesic from mean in tangent direction
576
+ // For unit sphere: exp_p(v) = cos(||v||)p + sin(||v||)(v/||v||)
577
+ let cos_t = tangent_norm.cos();
578
+ let sin_t = tangent_norm.sin();
579
+
580
+ let new_dims: Vec<f32> = mean.dims().iter()
581
+ .zip(tangent_sum.iter())
582
+ .map(|(p, v)| cos_t * p + sin_t * (v / tangent_norm))
583
+ .collect();
584
+
585
+ mean = Point::new(new_dims);
586
+ }
587
+
588
+ // Ensure result is normalized (on the unit sphere)
589
+ mean.normalize()
590
+ }
591
+
592
+ /// Update centroid incrementally when adding a child
593
+ /// Returns the magnitude of the change (for sparse propagation)
594
+ fn update_centroid(&mut self, container_id: Id, new_point: &Point) -> f32 {
595
+ let method = self.config.centroid_method;
596
+
597
+ // First, extract what we need from the container
598
+ let (old_centroid, n, accumulated_sum) = {
599
+ if let Some(container) = self.containers.get(&container_id) {
600
+ (
601
+ container.centroid.clone(),
602
+ container.descendant_count as f32,
603
+ container.accumulated_sum.clone(),
604
+ )
605
+ } else {
606
+ return 0.0;
607
+ }
608
+ };
609
+
610
+ // Handle first child case
611
+ if n == 0.0 {
612
+ if let Some(container) = self.containers.get_mut(&container_id) {
613
+ container.centroid = new_point.clone();
614
+ container.accumulated_sum = Some(new_point.clone());
615
+ container.descendant_count += 1;
616
+ }
617
+ return f32::MAX; // Always propagate first point
618
+ }
619
+
620
+ // Compute new centroid based on method
621
+ let (new_centroid, new_sum) = match method {
622
+ CentroidMethod::Euclidean => {
623
+ // Incremental Euclidean mean using accumulated sum
624
+ let new_sum = if let Some(ref sum) = accumulated_sum {
625
+ sum.dims().iter()
626
+ .zip(new_point.dims().iter())
627
+ .map(|(s, p)| s + p)
628
+ .collect::<Vec<f32>>()
629
+ } else {
630
+ new_point.dims().to_vec()
631
+ };
632
+
633
+ // Compute centroid as normalized mean
634
+ let count = n + 1.0;
635
+ let mean_dims: Vec<f32> = new_sum.iter().map(|s| s / count).collect();
636
+ let centroid = Point::new(mean_dims).normalize();
637
+ (centroid, Point::new(new_sum))
638
+ }
639
+ CentroidMethod::Frechet => {
640
+ // Update accumulated sum
641
+ let new_sum = if let Some(ref sum) = accumulated_sum {
642
+ sum.dims().iter()
643
+ .zip(new_point.dims().iter())
644
+ .map(|(s, p)| s + p)
645
+ .collect::<Vec<f32>>()
646
+ } else {
647
+ new_point.dims().to_vec()
648
+ };
649
+
650
+ // For incremental Fréchet, use geodesic interpolation
651
+ let new_count = n + 1.0;
652
+ let weight = 1.0 / new_count;
653
+ let centroid = Self::geodesic_interpolate_static(&old_centroid, new_point, weight);
654
+ (centroid, Point::new(new_sum))
655
+ }
656
+ };
657
+
658
+ // Now update the container
659
+ let subspace_enabled = self.config.subspace_enabled;
660
+ if let Some(container) = self.containers.get_mut(&container_id) {
661
+ container.centroid = new_centroid.clone();
662
+ container.accumulated_sum = Some(new_sum);
663
+ container.descendant_count += 1;
664
+
665
+ // Update subspace if enabled, incremental covariance is on, and not a chunk
666
+ // When incremental_covariance is false (default), we skip the expensive
667
+ // O(d²) outer product accumulation per insert, deferring to consolidation.
668
+ if subspace_enabled
669
+ && self.config.subspace_config.incremental_covariance
670
+ && container.level != ContainerLevel::Chunk
671
+ {
672
+ if let Some(ref mut subspace) = container.subspace {
673
+ subspace.add_point(new_point);
674
+ // Principal directions recomputed during consolidation
675
+ }
676
+ }
677
+ }
678
+
679
+ // Calculate change magnitude (L2 norm of delta)
680
+ let delta: f32 = old_centroid.dims()
681
+ .iter()
682
+ .zip(new_centroid.dims().iter())
683
+ .map(|(old, new)| (new - old).powi(2))
684
+ .sum::<f32>()
685
+ .sqrt();
686
+
687
+ delta
688
+ }
689
+
690
+ /// Static version of geodesic interpolation (no self reference needed)
691
+ fn geodesic_interpolate_static(a: &Point, b: &Point, t: f32) -> Point {
692
+ // Compute dot product
693
+ let dot: f32 = a.dims().iter()
694
+ .zip(b.dims().iter())
695
+ .map(|(x, y)| x * y)
696
+ .sum();
697
+
698
+ // Clamp to valid range
699
+ let dot_clamped = dot.clamp(-0.9999, 0.9999);
700
+ let theta = dot_clamped.acos();
701
+
702
+ if theta.abs() < 1e-8 {
703
+ // Points are nearly identical
704
+ return a.clone();
705
+ }
706
+
707
+ // Slerp formula: (sin((1-t)θ)/sin(θ)) * a + (sin(tθ)/sin(θ)) * b
708
+ let sin_theta = theta.sin();
709
+ let weight_a = ((1.0 - t) * theta).sin() / sin_theta;
710
+ let weight_b = (t * theta).sin() / sin_theta;
711
+
712
+ let result_dims: Vec<f32> = a.dims().iter()
713
+ .zip(b.dims().iter())
714
+ .map(|(x, y)| weight_a * x + weight_b * y)
715
+ .collect();
716
+
717
+ Point::new(result_dims).normalize()
718
+ }
719
+
720
+ /// Geodesic interpolation on the unit hypersphere (slerp)
721
+ /// Returns a point t fraction of the way from a to b along the great circle
722
+ fn geodesic_interpolate(&self, a: &Point, b: &Point, t: f32) -> Point {
723
+ // Compute dot product
724
+ let dot: f32 = a.dims().iter()
725
+ .zip(b.dims().iter())
726
+ .map(|(x, y)| x * y)
727
+ .sum();
728
+
729
+ // Clamp to valid range
730
+ let dot_clamped = dot.clamp(-0.9999, 0.9999);
731
+ let theta = dot_clamped.acos();
732
+
733
+ if theta.abs() < 1e-8 {
734
+ // Points are nearly identical
735
+ return a.clone();
736
+ }
737
+
738
+ // Slerp formula: (sin((1-t)θ)/sin(θ)) * a + (sin(tθ)/sin(θ)) * b
739
+ let sin_theta = theta.sin();
740
+ let weight_a = ((1.0 - t) * theta).sin() / sin_theta;
741
+ let weight_b = (t * theta).sin() / sin_theta;
742
+
743
+ let result_dims: Vec<f32> = a.dims().iter()
744
+ .zip(b.dims().iter())
745
+ .map(|(x, y)| weight_a * x + weight_b * y)
746
+ .collect();
747
+
748
+ Point::new(result_dims).normalize()
749
+ }
750
+
751
+ /// Sparse propagation: only update parent if change exceeds threshold
752
+ fn propagate_centroid_update(
753
+ &mut self,
754
+ container_id: Id,
755
+ new_point: &Point,
756
+ ancestors: &[Id],
757
+ ) {
758
+ let threshold = self.config.propagation_threshold;
759
+ let mut delta = self.update_centroid(container_id, new_point);
760
+
761
+ // Propagate up the tree if delta exceeds threshold
762
+ for ancestor_id in ancestors {
763
+ if delta < threshold {
764
+ break; // Stop propagation - change too small
765
+ }
766
+ delta = self.update_centroid(*ancestor_id, new_point);
767
+ }
768
+ }
769
+
770
+ /// Search the tree from a starting container
771
+ fn search_tree(
772
+ &self,
773
+ query: &Point,
774
+ query_time: u64,
775
+ start_id: Id,
776
+ k: usize,
777
+ ) -> Vec<(Id, f32)> {
778
+ let mut results: Vec<(Id, f32)> = Vec::new();
779
+
780
+ // Adaptive beam width based on k
781
+ let beam_width = self.config.beam_width.max(k);
782
+
783
+ // BFS with beam search
784
+ let mut current_level = vec![start_id];
785
+
786
+ while !current_level.is_empty() {
787
+ let mut next_level: Vec<(Id, f32)> = Vec::new();
788
+
789
+ for container_id in &current_level {
790
+ if let Some(container) = self.containers.get(container_id) {
791
+ if container.is_leaf() {
792
+ // Leaf node - add to results
793
+ let dist = self.combined_distance(query, query_time, container);
794
+ results.push((*container_id, dist));
795
+ } else {
796
+ // Internal node - score children and add to next level
797
+ for child_id in &container.children {
798
+ if let Some(child) = self.containers.get(child_id) {
799
+ let dist = self.combined_distance(query, query_time, child);
800
+ next_level.push((*child_id, dist));
801
+ }
802
+ }
803
+ }
804
+ }
805
+ }
806
+
807
+ if next_level.is_empty() {
808
+ break;
809
+ }
810
+
811
+ // Sort by distance and take beam_width best
812
+ next_level.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
813
+ current_level = next_level
814
+ .into_iter()
815
+ .take(beam_width)
816
+ .map(|(id, _)| id)
817
+ .collect();
818
+ }
819
+
820
+ // Sort results and return top k
821
+ results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
822
+ results.truncate(k);
823
+ results
824
+ }
825
+
826
+ // =========================================================================
827
+ // Multi-Resolution Query API (inspired by VAR next-scale prediction)
828
+ // =========================================================================
829
+
830
+ /// Coarse query: Get session summaries without descending to chunks
831
+ /// Use this for fast "is there relevant memory?" checks
832
+ pub fn near_sessions(&self, query: &Point, k: usize) -> NearResult<Vec<SessionSummary>> {
833
+ if query.dimensionality() != self.dimensionality {
834
+ return Err(NearError::DimensionalityMismatch {
835
+ expected: self.dimensionality,
836
+ got: query.dimensionality(),
837
+ });
838
+ }
839
+
840
+ let root_id = match self.root_id {
841
+ Some(id) => id,
842
+ None => return Ok(vec![]),
843
+ };
844
+
845
+ let query_time = SystemTime::now()
846
+ .duration_since(UNIX_EPOCH)
847
+ .unwrap()
848
+ .as_millis() as u64;
849
+
850
+ // Get root's children (sessions)
851
+ let root = match self.containers.get(&root_id) {
852
+ Some(r) => r,
853
+ None => return Ok(vec![]),
854
+ };
855
+
856
+ let mut sessions: Vec<SessionSummary> = root.children
857
+ .iter()
858
+ .filter_map(|session_id| {
859
+ let session = self.containers.get(session_id)?;
860
+ if session.level != ContainerLevel::Session {
861
+ return None;
862
+ }
863
+ let dist = self.combined_distance(query, query_time, session);
864
+ let score = if self.higher_is_better { 1.0 - dist } else { dist };
865
+
866
+ Some(SessionSummary {
867
+ id: *session_id,
868
+ score,
869
+ chunk_count: session.descendant_count,
870
+ timestamp: session.timestamp,
871
+ })
872
+ })
873
+ .collect();
874
+
875
+ // Sort by score (higher is better)
876
+ sessions.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
877
+ sessions.truncate(k);
878
+
879
+ Ok(sessions)
880
+ }
881
+
882
+ /// Refine within a specific session: Get document summaries
883
+ pub fn near_documents(&self, session_id: Id, query: &Point, k: usize) -> NearResult<Vec<DocumentSummary>> {
884
+ if query.dimensionality() != self.dimensionality {
885
+ return Err(NearError::DimensionalityMismatch {
886
+ expected: self.dimensionality,
887
+ got: query.dimensionality(),
888
+ });
889
+ }
890
+
891
+ let query_time = SystemTime::now()
892
+ .duration_since(UNIX_EPOCH)
893
+ .unwrap()
894
+ .as_millis() as u64;
895
+
896
+ let session = match self.containers.get(&session_id) {
897
+ Some(s) => s,
898
+ None => return Ok(vec![]),
899
+ };
900
+
901
+ let mut documents: Vec<DocumentSummary> = session.children
902
+ .iter()
903
+ .filter_map(|doc_id| {
904
+ let doc = self.containers.get(doc_id)?;
905
+ if doc.level != ContainerLevel::Document {
906
+ return None;
907
+ }
908
+ let dist = self.combined_distance(query, query_time, doc);
909
+ let score = if self.higher_is_better { 1.0 - dist } else { dist };
910
+
911
+ Some(DocumentSummary {
912
+ id: *doc_id,
913
+ score,
914
+ chunk_count: doc.descendant_count,
915
+ timestamp: doc.timestamp,
916
+ })
917
+ })
918
+ .collect();
919
+
920
+ documents.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
921
+ documents.truncate(k);
922
+
923
+ Ok(documents)
924
+ }
925
+
926
+ /// Refine within a specific document: Get chunk results
927
+ pub fn near_in_document(&self, doc_id: Id, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> {
928
+ if query.dimensionality() != self.dimensionality {
929
+ return Err(NearError::DimensionalityMismatch {
930
+ expected: self.dimensionality,
931
+ got: query.dimensionality(),
932
+ });
933
+ }
934
+
935
+ let query_time = SystemTime::now()
936
+ .duration_since(UNIX_EPOCH)
937
+ .unwrap()
938
+ .as_millis() as u64;
939
+
940
+ let doc = match self.containers.get(&doc_id) {
941
+ Some(d) => d,
942
+ None => return Ok(vec![]),
943
+ };
944
+
945
+ let mut chunks: Vec<SearchResult> = doc.children
946
+ .iter()
947
+ .filter_map(|chunk_id| {
948
+ let chunk = self.containers.get(chunk_id)?;
949
+ if chunk.level != ContainerLevel::Chunk {
950
+ return None;
951
+ }
952
+ let dist = self.combined_distance(query, query_time, chunk);
953
+ let score = if self.higher_is_better { 1.0 - dist } else { dist };
954
+
955
+ Some(SearchResult::new(*chunk_id, score))
956
+ })
957
+ .collect();
958
+
959
+ chunks.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
960
+ chunks.truncate(k);
961
+
962
+ Ok(chunks)
963
+ }
964
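A sketch of how the three calls compose into coarse-to-fine retrieval. It assumes a populated `HatIndex` and a query of matching dimensionality; the fan-out constants (2 sessions, 2 documents, 5 chunks) are arbitrary:

```rust
// Coarse-to-fine drill-down using the multi-resolution API above.
fn drill_down(hat: &HatIndex, query: &Point) -> NearResult<Vec<SearchResult>> {
    // 1. Cheap check: which sessions look relevant at all?
    let sessions = hat.near_sessions(query, 2)?;

    let mut hits = Vec::new();
    for session in &sessions {
        // 2. Refine into the most promising documents per session
        for doc in hat.near_documents(session.id, query, 2)? {
            // 3. Only now touch chunk-level results
            hits.extend(hat.near_in_document(doc.id, query, 5)?);
        }
    }
    hits.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
    hits.truncate(5);
    Ok(hits)
}
```

Compared with a flat `near`, this confines chunk-level scoring to a handful of documents, which is the point of the VAR-style next-scale refinement.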
+
965
+ /// Get statistics about the tree structure
966
+ pub fn stats(&self) -> HatStats {
967
+ let mut stats = HatStats::default();
968
+
969
+ for container in self.containers.values() {
970
+ match container.level {
971
+ ContainerLevel::Global => stats.global_count += 1,
972
+ ContainerLevel::Session => stats.session_count += 1,
973
+ ContainerLevel::Document => stats.document_count += 1,
974
+ ContainerLevel::Chunk => stats.chunk_count += 1,
975
+ }
976
+ }
977
+
978
+ stats
979
+ }
980
+
981
+ // =========================================================================
982
+ // Learnable Routing API
983
+ // =========================================================================
984
+
985
+ /// Record positive feedback for a query result (successful retrieval)
986
+ ///
987
+ /// Call this when a retrieved result was useful/relevant.
988
+ /// The router learns to route similar queries to similar containers.
989
+ pub fn record_retrieval_success(&mut self, query: &Point, result_id: Id) {
990
+ if let Some(ref mut router) = self.learnable_router {
991
+ // Look up the container for this result and record feedback at its level
992
+ if let Some(container) = self.containers.get(&result_id) {
993
+ router.record_success(query, &container.centroid, container.level.depth());
994
+ }
995
+ }
996
+ }
997
+
998
+ /// Record negative feedback for a query result (unsuccessful retrieval)
999
+ ///
1000
+ /// Call this when a retrieved result was not useful/relevant.
1001
+ pub fn record_retrieval_failure(&mut self, query: &Point, result_id: Id) {
1002
+ if let Some(ref mut router) = self.learnable_router {
1003
+ if let Some(container) = self.containers.get(&result_id) {
1004
+ router.record_failure(query, &container.centroid, container.level.depth());
1005
+ }
1006
+ }
1007
+ }
1008
+
1009
+ /// Record implicit feedback with a relevance score (0.0 = irrelevant, 1.0 = highly relevant)
1010
+ ///
1011
+ /// Use this for continuous feedback signals like click-through rate, dwell time, etc.
1012
+ pub fn record_implicit_feedback(&mut self, query: &Point, result_id: Id, relevance: f32) {
1013
+ if let Some(ref mut router) = self.learnable_router {
1014
+ if let Some(container) = self.containers.get(&result_id) {
1015
+ router.record_implicit(query, &container.centroid, container.level.depth(), relevance);
1016
+ }
1017
+ }
1018
+ }
1019
+
1020
+ /// Get learnable router statistics (if enabled)
1021
+ pub fn router_stats(&self) -> Option<super::learnable_routing::RouterStats> {
1022
+ self.learnable_router.as_ref().map(|r| r.stats())
1023
+ }
1024
+
1025
+ /// Get current routing weights (if learnable routing is enabled)
1026
+ pub fn routing_weights(&self) -> Option<&[f32]> {
1027
+ self.learnable_router.as_ref().map(|r| r.weights())
1028
+ }
1029
+
1030
+ /// Reset learnable routing weights to uniform
1031
+ pub fn reset_routing_weights(&mut self) {
1032
+ if let Some(ref mut router) = self.learnable_router {
1033
+ router.reset_weights();
1034
+ }
1035
+ }
1036
+
1037
+ /// Check if learnable routing is enabled
1038
+ pub fn is_learnable_routing_enabled(&self) -> bool {
1039
+ self.learnable_router.is_some()
1040
+ }
1041
+ }
1042
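A sketch of closing the loop with the feedback API. The click-as-success rule and the relevance value 0.3 are placeholders for whatever signal the host application actually has, and `SearchResult` is assumed to expose an `id` field:

```rust
// Hypothetical feedback loop; these calls are silent no-ops unless
// learnable routing was enabled when the index was constructed.
fn search_with_feedback(hat: &mut HatIndex, query: &Point) -> NearResult<()> {
    let results = hat.near(query, 5)?;
    for (rank, r) in results.iter().enumerate() {
        if rank == 0 {
            hat.record_retrieval_success(query, r.id); // e.g. the user clicked it
        } else {
            // Grade the rest with a continuous signal, e.g. dwell time in [0, 1]
            hat.record_implicit_feedback(query, r.id, 0.3);
        }
    }
    if let Some(stats) = hat.router_stats() {
        println!("router weights: mean={:.3} std={:.3}", stats.weight_mean, stats.weight_std);
    }
    Ok(())
}
```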
+
1043
+ /// Statistics about the HAT tree structure
1044
+ #[derive(Debug, Clone, Default)]
1045
+ pub struct HatStats {
1046
+ pub global_count: usize,
1047
+ pub session_count: usize,
1048
+ pub document_count: usize,
1049
+ pub chunk_count: usize,
1050
+ }
1051
+
1052
+ impl Near for HatIndex {
1053
+ fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> {
1054
+ // Check dimensionality
1055
+ if query.dimensionality() != self.dimensionality {
1056
+ return Err(NearError::DimensionalityMismatch {
1057
+ expected: self.dimensionality,
1058
+ got: query.dimensionality(),
1059
+ });
1060
+ }
1061
+
1062
+ // Handle empty index
1063
+ let root_id = match self.root_id {
1064
+ Some(id) => id,
1065
+ None => return Ok(vec![]),
1066
+ };
1067
+
1068
+ // Current time for temporal scoring
1069
+ let query_time = SystemTime::now()
1070
+ .duration_since(UNIX_EPOCH)
1071
+ .unwrap()
1072
+ .as_millis() as u64;
1073
+
1074
+ // Search tree
1075
+ let results = self.search_tree(query, query_time, root_id, k);
1076
+
1077
+ // Convert to SearchResult
1078
+ let search_results: Vec<SearchResult> = results
1079
+ .into_iter()
1080
+ .map(|(id, dist)| {
1081
+ let score = if self.higher_is_better {
1082
+ 1.0 - dist
1083
+ } else {
1084
+ dist
1085
+ };
1086
+ SearchResult::new(id, score)
1087
+ })
1088
+ .collect();
1089
+
1090
+ Ok(search_results)
1091
+ }
1092
+
1093
+ fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> {
1094
+ // Check dimensionality
1095
+ if query.dimensionality() != self.dimensionality {
1096
+ return Err(NearError::DimensionalityMismatch {
1097
+ expected: self.dimensionality,
1098
+ got: query.dimensionality(),
1099
+ });
1100
+ }
1101
+
1102
+ // Retrieve everything (container count is a safe upper bound on k), then filter by threshold
1103
+ let all_results = self.near(query, self.containers.len())?;
1104
+
1105
+ let filtered: Vec<SearchResult> = all_results
1106
+ .into_iter()
1107
+ .filter(|r| {
1108
+ if self.higher_is_better {
1109
+ r.score >= threshold
1110
+ } else {
1111
+ r.score <= threshold
1112
+ }
1113
+ })
1114
+ .collect();
1115
+
1116
+ Ok(filtered)
1117
+ }
1118
+
1119
+ fn add(&mut self, id: Id, point: &Point) -> NearResult<()> {
1120
+ // Check dimensionality
1121
+ if point.dimensionality() != self.dimensionality {
1122
+ return Err(NearError::DimensionalityMismatch {
1123
+ expected: self.dimensionality,
1124
+ got: point.dimensionality(),
1125
+ });
1126
+ }
1127
+
1128
+ // Ensure hierarchy exists
1129
+ self.ensure_document();
1130
+
1131
+ // Create chunk container
1132
+ let chunk = Container::new(id, ContainerLevel::Chunk, point.clone());
1133
+ self.containers.insert(id, chunk);
1134
+
1135
+ // Add to document's children
1136
+ if let Some(doc_id) = self.active_document {
1137
+ if let Some(doc) = self.containers.get_mut(&doc_id) {
1138
+ doc.children.push(id);
1139
+ }
1140
+
1141
+ // Build ancestor chain for sparse propagation
1142
+ let mut ancestors = Vec::new();
1143
+ if let Some(session_id) = self.active_session {
1144
+ ancestors.push(session_id);
1145
+ if let Some(root_id) = self.root_id {
1146
+ ancestors.push(root_id);
1147
+ }
1148
+ }
1149
+
1150
+ // Sparse propagation: only update ancestors if change is significant
1151
+ self.propagate_centroid_update(doc_id, point, &ancestors);
1152
+ }
1153
+
1154
+ // Check if document needs splitting
1155
+ if let Some(doc_id) = self.active_document {
1156
+ if let Some(doc) = self.containers.get(&doc_id) {
1157
+ if doc.children.len() >= self.config.max_children {
1158
+ // Start a new document
1159
+ self.new_document();
1160
+ }
1161
+ }
1162
+ }
1163
+
1164
+ // Check if session needs splitting
1165
+ if let Some(session_id) = self.active_session {
1166
+ if let Some(session) = self.containers.get(&session_id) {
1167
+ if session.children.len() >= self.config.max_children {
1168
+ // Start a new session
1169
+ self.new_session();
1170
+ }
1171
+ }
1172
+ }
1173
+
1174
+ Ok(())
1175
+ }
1176
+
1177
+ fn remove(&mut self, id: Id) -> NearResult<()> {
+ // Detach the chunk from its parent's child list so no dangling ID remains
+ let parent_id = self.containers.iter()
+ .find(|(_, c)| c.children.contains(&id))
+ .map(|(pid, _)| *pid);
+ if let Some(pid) = parent_id {
+ if let Some(parent) = self.containers.get_mut(&pid) {
+ parent.children.retain(|cid| *cid != id);
+ }
+ }
+ self.containers.remove(&id);
+
+ // Note: centroids are not recomputed on remove for simplicity;
+ // a production implementation would propagate the update upward
+ Ok(())
+ }
1186
+
1187
+ fn rebuild(&mut self) -> NearResult<()> {
1188
+ // No-op for now: centroids are maintained incrementally on insert,
+ // and full recomputation is handled by the consolidation pipeline
1190
+ Ok(())
1191
+ }
1192
+
1193
+ fn is_ready(&self) -> bool {
1194
+ true
1195
+ }
1196
+
1197
+ fn len(&self) -> usize {
1198
+ // Count only chunk-level containers
1199
+ self.containers.values()
1200
+ .filter(|c| c.level == ContainerLevel::Chunk)
1201
+ .count()
1202
+ }
1203
+ }
1204
+
1205
+ // =============================================================================
1206
+ // Consolidation Implementation
1207
+ // =============================================================================
1208
+
1209
+ impl HatIndex {
1210
+ /// Collect all leaf points for a container (recursively)
1211
+ fn collect_leaf_points(&self, container_id: Id) -> Vec<Point> {
1212
+ let container = match self.containers.get(&container_id) {
1213
+ Some(c) => c,
1214
+ None => return vec![],
1215
+ };
1216
+
1217
+ if container.is_leaf() {
1218
+ return vec![container.centroid.clone()];
1219
+ }
1220
+
1221
+ let mut points = Vec::new();
1222
+ for child_id in &container.children {
1223
+ points.extend(self.collect_leaf_points(*child_id));
1224
+ }
1225
+ points
1226
+ }
1227
+
1228
+ /// Get all container IDs at a given level
1229
+ fn containers_at_level(&self, level: ContainerLevel) -> Vec<Id> {
1230
+ self.containers
1231
+ .iter()
1232
+ .filter(|(_, c)| c.level == level)
1233
+ .map(|(id, _)| *id)
1234
+ .collect()
1235
+ }
1236
+
1237
+ /// Recompute a container's centroid from its descendants
1238
+ fn recompute_centroid(&mut self, container_id: Id) -> Option<f32> {
1239
+ // First collect the points (need to release borrow)
1240
+ let points = self.collect_leaf_points(container_id);
1241
+
1242
+ if points.is_empty() {
1243
+ return None;
1244
+ }
1245
+
1246
+ let new_centroid = match compute_exact_centroid(&points) {
1247
+ Some(c) => c,
1248
+ None => return None,
1249
+ };
1250
+
1251
+ // Get subspace config for recomputation
1252
+ let subspace_enabled = self.config.subspace_enabled;
1253
+ let subspace_rank = self.config.subspace_config.rank;
1254
+
1255
+ // Now update the container
1256
+ let drift = if let Some(container) = self.containers.get_mut(&container_id) {
1257
+ let old_centroid = container.centroid.clone();
1258
+ let drift = centroid_drift(&old_centroid, &new_centroid);
1259
+ container.centroid = new_centroid;
1260
+ container.descendant_count = points.len();
1261
+
1262
+ // Update accumulated sum
1263
+ let sum: Vec<f32> = points.iter()
1264
+ .fold(vec![0.0f32; self.dimensionality], |mut acc, p| {
1265
+ for (i, &v) in p.dims().iter().enumerate() {
1266
+ acc[i] += v;
1267
+ }
1268
+ acc
1269
+ });
1270
+ container.accumulated_sum = Some(Point::new(sum));
1271
+
1272
+ // Recompute subspace during consolidation if enabled
1273
+ if subspace_enabled && container.level != ContainerLevel::Chunk {
1274
+ let mut subspace = super::subspace::Subspace::new(self.dimensionality);
1275
+ for point in &points {
1276
+ subspace.add_point(point);
1277
+ }
1278
+ subspace.recompute_subspace(subspace_rank);
1279
+ container.subspace = Some(subspace);
1280
+ }
1281
+
1282
+ Some(drift)
1283
+ } else {
1284
+ None
1285
+ };
1286
+
1287
+ drift
1288
+ }
1289
+
1290
+ /// Check if a container should be merged (too few children)
1291
+ fn should_merge(&self, container_id: Id, threshold: usize) -> bool {
1292
+ if let Some(container) = self.containers.get(&container_id) {
1293
+ // Don't merge chunks, root, or sessions (for now)
1294
+ if container.level == ContainerLevel::Chunk ||
1295
+ container.level == ContainerLevel::Global ||
1296
+ container.level == ContainerLevel::Session {
1297
+ return false;
1298
+ }
1299
+ container.children.len() < threshold
1300
+ } else {
1301
+ false
1302
+ }
1303
+ }
1304
+
1305
+ /// Check if a container should be split (too many children)
1306
+ fn should_split(&self, container_id: Id, threshold: usize) -> bool {
1307
+ if let Some(container) = self.containers.get(&container_id) {
1308
+ // Don't split chunks
1309
+ if container.level == ContainerLevel::Chunk {
1310
+ return false;
1311
+ }
1312
+ container.children.len() > threshold
1313
+ } else {
1314
+ false
1315
+ }
1316
+ }
1317
+
1318
+ /// Find a sibling container to merge with
1319
+ fn find_merge_sibling(&self, container_id: Id) -> Option<Id> {
1320
+ // Find parent
1321
+ let parent_id = self.containers.iter()
1322
+ .find(|(_, c)| c.children.contains(&container_id))
1323
+ .map(|(id, _)| *id)?;
1324
+
1325
+ let parent = self.containers.get(&parent_id)?;
1326
+
1327
+ // Find smallest sibling
1328
+ let mut smallest: Option<(Id, usize)> = None;
1329
+ for child_id in &parent.children {
1330
+ if *child_id == container_id {
1331
+ continue;
1332
+ }
1333
+ if let Some(child) = self.containers.get(child_id) {
1334
+ let size = child.children.len();
1335
+ if smallest.is_none() || size < smallest.unwrap().1 {
1336
+ smallest = Some((*child_id, size));
1337
+ }
1338
+ }
1339
+ }
1340
+
1341
+ smallest.map(|(id, _)| id)
1342
+ }
1343
+
1344
+ /// Merge container B into container A
1345
+ fn merge_containers(&mut self, a_id: Id, b_id: Id) {
1346
+ // Get children from B
1347
+ let b_children: Vec<Id> = if let Some(b) = self.containers.get(&b_id) {
1348
+ b.children.clone()
1349
+ } else {
1350
+ return;
1351
+ };
1352
+
1353
+ // Add children to A
1354
+ if let Some(a) = self.containers.get_mut(&a_id) {
1355
+ a.children.extend(b_children);
1356
+ }
1357
+
1358
+ // Remove B from its parent's children
1359
+ let parent_id = self.containers.iter()
1360
+ .find(|(_, c)| c.children.contains(&b_id))
1361
+ .map(|(id, _)| *id);
1362
+
1363
+ if let Some(pid) = parent_id {
1364
+ if let Some(parent) = self.containers.get_mut(&pid) {
1365
+ parent.children.retain(|id| *id != b_id);
1366
+ }
1367
+ }
1368
+
1369
+ // Remove B
1370
+ self.containers.remove(&b_id);
1371
+
1372
+ // Recompute A's centroid
1373
+ self.recompute_centroid(a_id);
1374
+ }
1375
+
1376
+ /// Split a container into two
1377
+ fn split_container(&mut self, container_id: Id) -> Option<Id> {
1378
+ // Get container info
1379
+ let (level, children, parent_id) = {
1380
+ let container = self.containers.get(&container_id)?;
1381
+ let parent_id = self.containers.iter()
1382
+ .find(|(_, c)| c.children.contains(&container_id))
1383
+ .map(|(id, _)| *id);
1384
+ (container.level, container.children.clone(), parent_id)
1385
+ };
1386
+
1387
+ if children.len() < 2 {
1388
+ return None;
1389
+ }
1390
+
1391
+ // Simple split: divide children in half
1392
+ let mid = children.len() / 2;
1393
+ let (keep, move_to_new) = children.split_at(mid);
1394
+
1395
+ // Create new container
1396
+ let new_id = Id::now();
1397
+ let new_container = Container::new(
1398
+ new_id,
1399
+ level,
1400
+ Point::origin(self.dimensionality),
1401
+ );
1402
+ self.containers.insert(new_id, new_container);
1403
+
1404
+ // Update original container
1405
+ if let Some(container) = self.containers.get_mut(&container_id) {
1406
+ container.children = keep.to_vec();
1407
+ }
1408
+
1409
+ // Set new container's children
1410
+ if let Some(new_container) = self.containers.get_mut(&new_id) {
1411
+ new_container.children = move_to_new.to_vec();
1412
+ }
1413
+
1414
+ // Add new container to parent
1415
+ if let Some(pid) = parent_id {
1416
+ if let Some(parent) = self.containers.get_mut(&pid) {
1417
+ parent.children.push(new_id);
1418
+ }
1419
+ }
1420
+
1421
+ // Recompute centroids
1422
+ self.recompute_centroid(container_id);
1423
+ self.recompute_centroid(new_id);
1424
+
1425
+ Some(new_id)
1426
+ }
1427
+
1428
+ /// Remove containers with no children (except chunks)
1429
+ fn prune_empty(&mut self) -> usize {
1430
+ let mut pruned = 0;
1431
+
1432
+ loop {
1433
+ let empty_ids: Vec<Id> = self.containers
1434
+ .iter()
1435
+ .filter(|(_, c)| {
1436
+ c.level != ContainerLevel::Chunk &&
1437
+ c.level != ContainerLevel::Global &&
1438
+ c.children.is_empty()
1439
+ })
1440
+ .map(|(id, _)| *id)
1441
+ .collect();
1442
+
1443
+ if empty_ids.is_empty() {
1444
+ break;
1445
+ }
1446
+
1447
+ for id in empty_ids {
1448
+ // Remove from parent's children
1449
+ let parent_id = self.containers.iter()
1450
+ .find(|(_, c)| c.children.contains(&id))
1451
+ .map(|(pid, _)| *pid);
1452
+
1453
+ if let Some(pid) = parent_id {
1454
+ if let Some(parent) = self.containers.get_mut(&pid) {
1455
+ parent.children.retain(|cid| *cid != id);
1456
+ }
1457
+ }
1458
+
1459
+ self.containers.remove(&id);
1460
+ pruned += 1;
1461
+ }
1462
+ }
1463
+
1464
+ pruned
1465
+ }
1466
+ }
1467
+
1468
+ impl Consolidate for HatIndex {
1469
+ fn begin_consolidation(&mut self, config: ConsolidationConfig) {
1470
+ let mut state = ConsolidationState::new(config);
1471
+ state.start();
1472
+
1473
+ // Initialize work queue with all containers for leaf collection
1474
+ let all_ids: VecDeque<Id> = self.containers.keys().copied().collect();
1475
+ state.work_queue = all_ids;
1476
+
1477
+ self.consolidation_state = Some(state);
1478
+ self.consolidation_points_cache.clear();
1479
+ }
1480
+
1481
+ fn consolidation_tick(&mut self) -> ConsolidationTickResult {
1482
+ // Take ownership of state to avoid borrow issues
1483
+ let mut state = match self.consolidation_state.take() {
1484
+ Some(s) => s,
1485
+ None => {
1486
+ return ConsolidationTickResult::Complete(ConsolidationMetrics::default());
1487
+ }
1488
+ };
1489
+
1490
+ let batch_size = state.config.batch_size;
1491
+
1492
+ match state.phase {
1493
+ ConsolidationPhase::Idle => {
1494
+ state.start();
1495
+ }
1496
+
1497
+ ConsolidationPhase::CollectingLeaves => {
1498
+ state.next_phase();
1499
+
1500
+ // Populate work queue with non-chunk containers (bottom-up)
1501
+ let docs = self.containers_at_level(ContainerLevel::Document);
1502
+ let sessions = self.containers_at_level(ContainerLevel::Session);
1503
+ let globals = self.containers_at_level(ContainerLevel::Global);
1504
+
1505
+ state.work_queue.clear();
1506
+ state.work_queue.extend(docs);
1507
+ state.work_queue.extend(sessions);
1508
+ state.work_queue.extend(globals);
1509
+ }
1510
+
1511
+ ConsolidationPhase::RecomputingCentroids => {
1512
+ let mut processed = 0;
1513
+ let mut to_recompute = Vec::new();
1514
+
1515
+ while processed < batch_size {
1516
+ match state.work_queue.pop_front() {
1517
+ Some(id) => {
1518
+ to_recompute.push(id);
1519
+ state.processed.insert(id);
1520
+ processed += 1;
1521
+ }
1522
+ None => break,
1523
+ };
1524
+ }
1525
+
1526
+ // Now recompute without holding state borrow
1527
+ for container_id in to_recompute {
1528
+ if let Some(drift) = self.recompute_centroid(container_id) {
1529
+ state.record_drift(drift);
1530
+ state.metrics.centroids_recomputed += 1;
1531
+ }
1532
+ state.metrics.containers_processed += 1;
1533
+ }
1534
+
1535
+ if state.work_queue.is_empty() {
1536
+ state.next_phase();
1537
+
1538
+ if state.phase == ConsolidationPhase::AnalyzingStructure {
1539
+ let docs = self.containers_at_level(ContainerLevel::Document);
1540
+ state.work_queue.extend(docs);
1541
+ }
1542
+ }
1543
+ }
1544
+
1545
+ ConsolidationPhase::AnalyzingStructure => {
1546
+ let merge_threshold = state.config.merge_threshold;
1547
+ let split_threshold = state.config.split_threshold;
1548
+ let mut processed = 0;
1549
+ let mut to_analyze = Vec::new();
1550
+
1551
+ while processed < batch_size {
1552
+ match state.work_queue.pop_front() {
1553
+ Some(id) => {
1554
+ to_analyze.push(id);
1555
+ state.processed.insert(id);
1556
+ processed += 1;
1557
+ }
1558
+ None => break,
1559
+ };
1560
+ }
1561
+
1562
+ // Analyze without holding state borrow
1563
+ for container_id in to_analyze {
1564
+ if self.should_merge(container_id, merge_threshold) {
1565
+ if let Some(sibling) = self.find_merge_sibling(container_id) {
1566
+ state.add_merge_candidate(container_id, sibling);
1567
+ }
1568
+ } else if self.should_split(container_id, split_threshold) {
1569
+ state.add_split_candidate(container_id);
1570
+ }
1571
+ }
1572
+
1573
+ if state.work_queue.is_empty() {
1574
+ state.next_phase();
1575
+ }
1576
+ }
1577
+
1578
+ ConsolidationPhase::Merging => {
1579
+ let mut processed = 0;
1580
+ let mut to_merge = Vec::new();
1581
+
1582
+ while processed < batch_size {
1583
+ match state.next_merge() {
1584
+ Some(pair) => {
1585
+ to_merge.push(pair);
1586
+ processed += 1;
1587
+ }
1588
+ None => break,
1589
+ };
1590
+ }
1591
+
1592
+ for (a, b) in to_merge {
1593
+ self.merge_containers(a, b);
1594
+ state.metrics.containers_merged += 1;
1595
+ }
1596
+
1597
+ if !state.has_merges() {
1598
+ state.next_phase();
1599
+ }
1600
+ }
1601
+
1602
+ ConsolidationPhase::Splitting => {
1603
+ let mut processed = 0;
1604
+ let mut to_split = Vec::new();
1605
+
1606
+ while processed < batch_size {
1607
+ match state.next_split() {
1608
+ Some(id) => {
1609
+ to_split.push(id);
1610
+ processed += 1;
1611
+ }
1612
+ None => break,
1613
+ };
1614
+ }
1615
+
1616
+ for container_id in to_split {
1617
+ if self.split_container(container_id).is_some() {
1618
+ state.metrics.containers_split += 1;
1619
+ }
1620
+ }
1621
+
1622
+ if !state.has_splits() {
1623
+ state.next_phase();
1624
+ }
1625
+ }
1626
+
1627
+ ConsolidationPhase::Pruning => {
1628
+ let pruned = self.prune_empty();
1629
+ state.metrics.containers_pruned = pruned;
1630
+ state.next_phase();
1631
+ }
1632
+
1633
+ ConsolidationPhase::OptimizingLayout => {
1634
+ for container in self.containers.values_mut() {
1635
+ if container.children.len() > 1 {
1636
+ // Placeholder for future optimization
1637
+ }
1638
+ }
1639
+ state.next_phase();
1640
+ }
1641
+
1642
+ ConsolidationPhase::Complete => {
1643
+ // Already complete
1644
+ }
1645
+ }
1646
+
1647
+ state.metrics.ticks += 1;
1648
+
1649
+ if state.is_complete() {
1650
+ let metrics = state.metrics.clone();
1651
+ self.consolidation_points_cache.clear();
1652
+ ConsolidationTickResult::Complete(metrics)
1653
+ } else {
1654
+ let progress = state.progress();
1655
+ self.consolidation_state = Some(state);
1656
+ ConsolidationTickResult::Continue(progress)
1657
+ }
1658
+ }
1659
+
1660
+ fn is_consolidating(&self) -> bool {
1661
+ self.consolidation_state.is_some()
1662
+ }
1663
+
1664
+ fn consolidation_progress(&self) -> Option<ConsolidationProgress> {
1665
+ self.consolidation_state.as_ref().map(|s| s.progress())
1666
+ }
1667
+
1668
+ fn cancel_consolidation(&mut self) {
1669
+ self.consolidation_state = None;
1670
+ self.consolidation_points_cache.clear();
1671
+ }
1672
+ }
1673
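A sketch of driving the tick-based state machine from a host loop, assuming `ConsolidationConfig` implements `Default` and `ConsolidationProgress` derives `Debug` (both live in the consolidation module, not shown here):

```rust
// Incremental consolidation: each tick does a bounded batch of work,
// so the host can interleave queries between ticks.
fn run_consolidation(hat: &mut HatIndex) {
    hat.begin_consolidation(ConsolidationConfig::default());
    loop {
        match hat.consolidation_tick() {
            ConsolidationTickResult::Continue(progress) => {
                println!("consolidating: {:?}", progress);
            }
            ConsolidationTickResult::Complete(metrics) => {
                println!("done: {} centroids recomputed, {} merged, {} split",
                    metrics.centroids_recomputed,
                    metrics.containers_merged,
                    metrics.containers_split);
                break;
            }
        }
    }
}
```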
+
1674
+ // =============================================================================
1675
+ // Persistence Implementation
1676
+ // =============================================================================
1677
+
1678
+ impl HatIndex {
1679
+ /// Serialize the index to bytes
1680
+ ///
1681
+ /// # Example
1682
+ /// ```rust,ignore
1683
+ /// let bytes = hat.to_bytes()?;
1684
+ /// std::fs::write("index.hat", bytes)?;
1685
+ /// ```
1686
+ pub fn to_bytes(&self) -> Result<Vec<u8>, super::persistence::PersistError> {
1687
+ use super::persistence::{SerializedHat, SerializedContainer, LevelByte};
1688
+
1689
+ let containers: Vec<SerializedContainer> = self.containers.iter()
1690
+ .map(|(_, c)| {
1691
+ let level = match c.level {
1692
+ ContainerLevel::Global => LevelByte::Root,
1693
+ ContainerLevel::Session => LevelByte::Session,
1694
+ ContainerLevel::Document => LevelByte::Document,
1695
+ ContainerLevel::Chunk => LevelByte::Chunk,
1696
+ };
1697
+
1698
+ SerializedContainer {
1699
+ id: c.id,
1700
+ level,
1701
+ timestamp: c.timestamp,
1702
+ children: c.children.clone(),
1703
+ descendant_count: c.descendant_count as u64,
1704
+ centroid: c.centroid.dims().to_vec(),
1705
+ accumulated_sum: c.accumulated_sum.as_ref().map(|p| p.dims().to_vec()),
1706
+ }
1707
+ })
1708
+ .collect();
1709
+
1710
+ let router_weights = self.learnable_router.as_ref()
1711
+ .map(|r| r.weights().to_vec());
1712
+
1713
+ let serialized = SerializedHat {
1714
+ version: 1,
1715
+ dimensionality: self.dimensionality as u32,
1716
+ root_id: self.root_id,
1717
+ containers,
1718
+ active_session: self.active_session,
1719
+ active_document: self.active_document,
1720
+ router_weights,
1721
+ };
1722
+
1723
+ serialized.to_bytes()
1724
+ }
1725
+
1726
+ /// Deserialize an index from bytes
1727
+ ///
1728
+ /// # Example
1729
+ /// ```rust,ignore
1730
+ /// let bytes = std::fs::read("index.hat")?;
1731
+ /// let hat = HatIndex::from_bytes(&bytes)?;
1732
+ /// ```
1733
+ pub fn from_bytes(data: &[u8]) -> Result<Self, super::persistence::PersistError> {
1734
+ use super::persistence::{SerializedHat, LevelByte, PersistError};
1735
+ use crate::core::proximity::Cosine;
1736
+ use crate::core::merge::Mean;
1737
+
1738
+ let serialized = SerializedHat::from_bytes(data)?;
1739
+ let dimensionality = serialized.dimensionality as usize;
1740
+
1741
+ // Create a new index with default settings
1742
+ let mut index = Self::new(
1743
+ dimensionality,
1744
+ Arc::new(Cosine),
1745
+ Arc::new(Mean),
1746
+ true,
1747
+ HatConfig::default(),
1748
+ );
1749
+
1750
+ // Restore containers
1751
+ for sc in serialized.containers {
1752
+ let level = match sc.level {
1753
+ LevelByte::Root => ContainerLevel::Global,
1754
+ LevelByte::Session => ContainerLevel::Session,
1755
+ LevelByte::Document => ContainerLevel::Document,
1756
+ LevelByte::Chunk => ContainerLevel::Chunk,
1757
+ };
1758
+
1759
+ // Verify dimension
1760
+ if sc.centroid.len() != dimensionality {
1761
+ return Err(PersistError::DimensionMismatch {
1762
+ expected: dimensionality,
1763
+ found: sc.centroid.len(),
1764
+ });
1765
+ }
1766
+
1767
+ let centroid = Point::new(sc.centroid);
1768
+ let accumulated_sum = sc.accumulated_sum.map(Point::new);
1769
+
1770
+ let container = Container {
1771
+ id: sc.id,
1772
+ level,
1773
+ centroid,
1774
+ timestamp: sc.timestamp,
1775
+ children: sc.children,
1776
+ descendant_count: sc.descendant_count as usize,
1777
+ accumulated_sum,
1778
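+ // Subspace statistics are not persisted; start empty and let consolidation repopulate them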
+ subspace: if level != ContainerLevel::Chunk {
1779
+ Some(super::subspace::Subspace::new(dimensionality))
1780
+ } else {
1781
+ None
1782
+ },
1783
+ };
1784
+
1785
+ index.containers.insert(sc.id, container);
1786
+ }
1787
+
1788
+ // Restore state
1789
+ index.root_id = serialized.root_id;
1790
+ index.active_session = serialized.active_session;
1791
+ index.active_document = serialized.active_document;
1792
+
1793
+ // Restore router weights if present
1794
+ if let Some(weights) = serialized.router_weights {
1795
+ let mut router = super::learnable_routing::LearnableRouter::default_for_dims(dimensionality);
1796
+ let weight_bytes: Vec<u8> = weights.iter()
1797
+ .flat_map(|w| w.to_le_bytes())
1798
+ .collect();
1799
+ router.deserialize_weights(&weight_bytes)
1800
+ .map_err(|e| PersistError::Corrupted(e.to_string()))?;
1801
+ index.learnable_router = Some(router);
1802
+ }
1803
+
1804
+ Ok(index)
1805
+ }
1806
+
1807
+ /// Save the index to a file
1808
+ pub fn save_to_file(&self, path: &std::path::Path) -> Result<(), super::persistence::PersistError> {
1809
+ let bytes = self.to_bytes()?;
1810
+ std::fs::write(path, bytes)?;
1811
+ Ok(())
1812
+ }
1813
+
1814
+ /// Load an index from a file
1815
+ pub fn load_from_file(path: &std::path::Path) -> Result<Self, super::persistence::PersistError> {
1816
+ let bytes = std::fs::read(path)?;
1817
+ Self::from_bytes(&bytes)
1818
+ }
1819
+ }
1820
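A sketch of a full roundtrip through both the byte-level and file-level helpers; the path is illustrative:

```rust
fn roundtrip() -> Result<(), Box<dyn std::error::Error>> {
    let mut hat = HatIndex::cosine(3);
    hat.add(Id::now(), &Point::new(vec![1.0, 0.0, 0.0])).unwrap();

    // Byte-level roundtrip
    let bytes = hat.to_bytes()?;
    let restored = HatIndex::from_bytes(&bytes)?;
    assert_eq!(restored.len(), hat.len());

    // File-level helpers wrap the same format
    let path = std::path::Path::new("/tmp/index.hat");
    hat.save_to_file(path)?;
    let _from_disk = HatIndex::load_from_file(path)?;
    Ok(())
}
```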
+
1821
+ #[cfg(test)]
1822
+ mod tests {
1823
+ use super::*;
1824
+
1825
+ #[test]
1826
+ fn test_hat_add() {
1827
+ let mut index = HatIndex::cosine(3);
1828
+
1829
+ let id = Id::now();
1830
+ let point = Point::new(vec![1.0, 0.0, 0.0]);
1831
+
1832
+ index.add(id, &point).unwrap();
1833
+
1834
+ assert_eq!(index.len(), 1);
1835
+ }
1836
+
1837
+ #[test]
1838
+ fn test_hat_near() {
1839
+ let mut index = HatIndex::cosine(3);
1840
+
1841
+ // Add some points
1842
+ let points = vec![
1843
+ Point::new(vec![1.0, 0.0, 0.0]),
1844
+ Point::new(vec![0.0, 1.0, 0.0]),
1845
+ Point::new(vec![0.0, 0.0, 1.0]),
1846
+ Point::new(vec![0.7, 0.7, 0.0]).normalize(),
1847
+ ];
1848
+
1849
+ for point in &points {
1850
+ index.add(Id::now(), point).unwrap();
1851
+ }
1852
+
1853
+ // Query near [1, 0, 0]
1854
+ let query = Point::new(vec![1.0, 0.0, 0.0]);
1855
+ let results = index.near(&query, 2).unwrap();
1856
+
1857
+ assert_eq!(results.len(), 2);
1858
+ // First result should have high similarity (close to 1.0)
1859
+ assert!(results[0].score > 0.5);
1860
+ }
1861
+
1862
+ #[test]
1863
+ fn test_hat_sessions() {
1864
+ let mut index = HatIndex::cosine(3);
1865
+
1866
+ // Add points to first session
1867
+ for i in 0..5 {
1868
+ let point = Point::new(vec![1.0, i as f32 * 0.1, 0.0]).normalize();
1869
+ index.add(Id::now(), &point).unwrap();
1870
+ }
1871
+
1872
+ // Start new session
1873
+ index.new_session();
1874
+
1875
+ // Add points to second session
1876
+ for i in 0..5 {
1877
+ let point = Point::new(vec![0.0, 1.0, i as f32 * 0.1]).normalize();
1878
+ index.add(Id::now(), &point).unwrap();
1879
+ }
1880
+
1881
+ assert_eq!(index.len(), 10);
1882
+
1883
+ // Query should find both sessions
1884
+ let query = Point::new(vec![0.5, 0.5, 0.0]).normalize();
1885
+ let results = index.near(&query, 5).unwrap();
1886
+
1887
+ assert_eq!(results.len(), 5);
1888
+ }
1889
+
1890
+ #[test]
1891
+ fn test_hat_hierarchy_structure() {
1892
+ let mut index = HatIndex::cosine(3);
1893
+
1894
+ // Add some points
1895
+ for _ in 0..10 {
1896
+ let point = Point::new(vec![1.0, 0.0, 0.0]);
1897
+ index.add(Id::now(), &point).unwrap();
1898
+ }
1899
+
1900
+ // Should have at least: 1 root + 1 session + 1 document + 10 chunks = 13 containers
+ // (splits may create more)
1901
+ assert!(index.containers.len() >= 13);
1902
+
1903
+ // Check that root exists
1904
+ assert!(index.root_id.is_some());
1905
+ }
1906
+
1907
+ #[test]
1908
+ fn test_hat_empty() {
1909
+ let index = HatIndex::cosine(3);
1910
+
1911
+ let query = Point::new(vec![1.0, 0.0, 0.0]);
1912
+ let results = index.near(&query, 5).unwrap();
1913
+
1914
+ assert!(results.is_empty());
1915
+ }
1916
+
1917
+ #[test]
1918
+ fn test_hat_dimensionality_check() {
1919
+ let mut index = HatIndex::cosine(3);
1920
+
1921
+ let wrong_dims = Point::new(vec![1.0, 0.0]); // 2 dims
1922
+ let result = index.add(Id::now(), &wrong_dims);
1923
+
1924
+ match result {
1925
+ Err(NearError::DimensionalityMismatch { expected, got }) => {
1926
+ assert_eq!(expected, 3);
1927
+ assert_eq!(got, 2);
1928
+ }
1929
+ _ => panic!("Expected DimensionalityMismatch error"),
1930
+ }
1931
+ }
1932
+
1933
+ #[test]
1934
+ fn test_hat_scale() {
1935
+ let mut index = HatIndex::cosine(128);
1936
+
1937
+ // Add 1000 points
1938
+ for i in 0..1000 {
1939
+ let mut dims = vec![0.0f32; 128];
1940
+ dims[i % 128] = 1.0;
1941
+ let point = Point::new(dims).normalize();
1942
+ index.add(Id::now(), &point).unwrap();
1943
+ }
1944
+
1945
+ assert_eq!(index.len(), 1000);
1946
+
1947
+ // Query should work
1948
+ let query = Point::new(vec![1.0; 128]).normalize();
1949
+ let results = index.near(&query, 10).unwrap();
1950
+
1951
+ assert_eq!(results.len(), 10);
1952
+ }
1953
+ }
src/adapters/index/learnable_routing.rs ADDED
@@ -0,0 +1,528 @@
1
+ //! # Learnable Routing for HAT
2
+ //!
3
+ //! This module implements learnable routing weights for HAT index.
4
+ //! Instead of using fixed cosine similarity for routing decisions,
5
+ //! we learn dimension weights that adapt to actual query patterns.
6
+ //!
7
+ //! ## Key Insight (from journal 006)
8
+ //!
9
+ //! "The main gap: ARMS uses *known* structure while cutting-edge methods
10
+ //! *learn* structure. Opportunity: make HAT structure learnable while
11
+ //! keeping the efficiency benefits."
12
+ //!
13
+ //! ## Approach
14
+ //!
15
+ //! 1. **Weighted Similarity**: `sim(q, c) = Σᵢ wᵢ · qᵢ · cᵢ` instead of plain cosine
16
+ //! 2. **Feedback Collection**: Track query → retrieved → relevant mappings
17
+ //! 3. **Online Learning**: Update weights to improve routing decisions
18
+ //!
19
+ //! ## Benefits
20
+ //!
21
+ //! - Adapts to task-specific semantic dimensions
22
+ //! - No neural network training required (gradient-free)
23
+ //! - Preserves O(log n) query complexity
24
+ //! - Can learn from implicit feedback (click-through, usage patterns)
25
+
26
+ use crate::core::Point;
27
+ use std::collections::VecDeque;
28
+
29
+ /// Configuration for learnable routing
30
+ #[derive(Debug, Clone)]
31
+ pub struct LearnableRoutingConfig {
32
+ /// Learning rate for weight updates (0.0 = no learning)
33
+ pub learning_rate: f32,
34
+
35
+ /// Momentum for smoothing updates
36
+ pub momentum: f32,
37
+
38
+ /// Weight decay for regularization (prevents overfitting)
39
+ pub weight_decay: f32,
40
+
41
+ /// Maximum number of feedback samples to retain
42
+ pub max_feedback_samples: usize,
43
+
44
+ /// Minimum feedback samples before learning starts
45
+ pub min_samples_to_learn: usize,
46
+
47
+ /// How often to update weights (every N feedback samples)
48
+ pub update_frequency: usize,
49
+
50
+ /// Enable dimension-wise weights (vs single scalar)
51
+ pub per_dimension_weights: bool,
52
+ }
53
+
54
+ impl Default for LearnableRoutingConfig {
55
+ fn default() -> Self {
56
+ Self {
57
+ learning_rate: 0.01,
58
+ momentum: 0.9,
59
+ weight_decay: 0.001,
60
+ max_feedback_samples: 1000,
61
+ min_samples_to_learn: 50,
62
+ update_frequency: 10,
63
+ per_dimension_weights: true,
64
+ }
65
+ }
66
+ }
67
+
68
+ impl LearnableRoutingConfig {
69
+ pub fn new() -> Self {
70
+ Self::default()
71
+ }
72
+
73
+ pub fn with_learning_rate(mut self, lr: f32) -> Self {
74
+ self.learning_rate = lr;
75
+ self
76
+ }
77
+
78
+ pub fn with_momentum(mut self, momentum: f32) -> Self {
79
+ self.momentum = momentum.clamp(0.0, 0.99);
80
+ self
81
+ }
82
+
83
+ pub fn disabled() -> Self {
84
+ Self {
85
+ learning_rate: 0.0,
86
+ ..Default::default()
87
+ }
88
+ }
89
+ }
90
+
91
+ /// A single feedback sample from query execution
92
+ #[derive(Debug, Clone)]
93
+ pub struct RoutingFeedback {
94
+ /// The query point
95
+ pub query: Point,
96
+
97
+ /// Container centroid that was selected
98
+ pub selected_centroid: Point,
99
+
100
+ /// Whether the selection led to good results (positive = good)
101
+ pub reward: f32,
102
+
103
+ /// Which level in the hierarchy this feedback is for
104
+ pub level: usize,
105
+ }
106
+
107
+ /// Learnable routing weights for HAT
108
+ ///
109
+ /// Maintains per-dimension (or scalar) weights that modify
110
+ /// the similarity computation during tree traversal.
111
+ #[derive(Debug, Clone)]
112
+ pub struct LearnableRouter {
113
+ /// Configuration
114
+ config: LearnableRoutingConfig,
115
+
116
+ /// Per-dimension weights (or single weight if per_dimension_weights=false)
117
+ weights: Vec<f32>,
118
+
119
+ /// Momentum accumulator for smooth updates
120
+ momentum_buffer: Vec<f32>,
121
+
122
+ /// Feedback buffer for batch updates
123
+ feedback_buffer: VecDeque<RoutingFeedback>,
124
+
125
+ /// Total feedback samples received
126
+ total_samples: usize,
127
+
128
+ /// Dimensionality
129
+ dims: usize,
130
+ }
131
+
132
+ impl LearnableRouter {
133
+ /// Create a new learnable router
134
+ pub fn new(dims: usize, config: LearnableRoutingConfig) -> Self {
135
+ let weight_count = if config.per_dimension_weights { dims } else { 1 };
136
+
137
+ Self {
138
+ config,
139
+ weights: vec![1.0; weight_count], // Start with uniform weights
140
+ momentum_buffer: vec![0.0; weight_count],
141
+ feedback_buffer: VecDeque::new(),
142
+ total_samples: 0,
143
+ dims,
144
+ }
145
+ }
146
+
147
+ /// Create with default config
148
+ pub fn default_for_dims(dims: usize) -> Self {
149
+ Self::new(dims, LearnableRoutingConfig::default())
150
+ }
151
+
152
+ /// Check if learning is enabled
153
+ pub fn is_learning_enabled(&self) -> bool {
154
+ self.config.learning_rate > 0.0
155
+ }
156
+
157
+ /// Get current weights (for inspection/serialization)
158
+ pub fn weights(&self) -> &[f32] {
159
+ &self.weights
160
+ }
161
+
162
+ /// Compute weighted similarity between query and centroid
163
+ ///
164
+ /// Returns a similarity score (higher = more similar)
165
+ pub fn weighted_similarity(&self, query: &Point, centroid: &Point) -> f32 {
166
+ if self.config.per_dimension_weights {
167
+ // Weighted dot product: Σᵢ wᵢ · qᵢ · cᵢ
168
+ query.dims().iter()
169
+ .zip(centroid.dims().iter())
170
+ .zip(self.weights.iter())
171
+ .map(|((q, c), w)| w * q * c)
172
+ .sum()
173
+ } else {
174
+ // Single scalar weight (equivalent to scaled cosine)
175
+ let dot: f32 = query.dims().iter()
176
+ .zip(centroid.dims().iter())
177
+ .map(|(q, c)| q * c)
178
+ .sum();
179
+ self.weights[0] * dot
180
+ }
181
+ }
182
+
183
+ /// Record feedback from a routing decision
184
+ pub fn record_feedback(&mut self, feedback: RoutingFeedback) {
185
+ self.feedback_buffer.push_back(feedback);
186
+ self.total_samples += 1;
187
+
188
+ // Trim buffer if too large
189
+ while self.feedback_buffer.len() > self.config.max_feedback_samples {
190
+ self.feedback_buffer.pop_front();
191
+ }
192
+
193
+ // Trigger update if conditions met
194
+ if self.should_update() {
195
+ self.update_weights();
196
+ }
197
+ }
198
+
199
+ /// Check if we should update weights
200
+ fn should_update(&self) -> bool {
201
+ self.config.learning_rate > 0.0
202
+ && self.feedback_buffer.len() >= self.config.min_samples_to_learn
203
+ && self.total_samples % self.config.update_frequency == 0
204
+ }
205
+
206
+ /// Update weights based on accumulated feedback
207
+ ///
208
+ /// Uses a simple gradient-free approach:
209
+ /// - For positive feedback: increase weights for dimensions where q·c was high
210
+ /// - For negative feedback: decrease weights for dimensions where q·c was high
211
+ fn update_weights(&mut self) {
212
+ if self.feedback_buffer.is_empty() {
213
+ return;
214
+ }
215
+
216
+ let lr = self.config.learning_rate;
217
+ let momentum = self.config.momentum;
218
+ let decay = self.config.weight_decay;
219
+
220
+ // Compute gradient estimate from feedback
221
+ let mut gradient = vec![0.0f32; self.weights.len()];
222
+
223
+ for feedback in &self.feedback_buffer {
224
+ let reward = feedback.reward;
225
+
226
+ if self.config.per_dimension_weights {
227
+ // Per-dimension update
228
+ for ((&q, &c), g) in feedback.query.dims().iter()
229
+ .zip(feedback.selected_centroid.dims().iter())
230
+ .zip(gradient.iter_mut())
231
+ {
232
+ // Gradient: reward * q * c (increase weight if positive reward)
233
+ *g += reward * q * c;
234
+ }
235
+ } else {
236
+ // Scalar update
237
+ let dot: f32 = feedback.query.dims().iter()
238
+ .zip(feedback.selected_centroid.dims().iter())
239
+ .map(|(q, c)| q * c)
240
+ .sum();
241
+ gradient[0] += reward * dot;
242
+ }
243
+ }
244
+
245
+ // Normalize by number of samples
246
+ let n = self.feedback_buffer.len() as f32;
247
+ for g in gradient.iter_mut() {
248
+ *g /= n;
249
+ }
250
+
251
+ // Apply momentum and update weights
252
+ for (i, (w, g)) in self.weights.iter_mut().zip(gradient.iter()).enumerate() {
253
+ // Momentum update
254
+ self.momentum_buffer[i] = momentum * self.momentum_buffer[i] + (1.0 - momentum) * g;
255
+
256
+ // Weight update with decay
257
+ *w += lr * self.momentum_buffer[i] - decay * (*w - 1.0);
258
+
259
+ // Clamp weights to reasonable range
260
+ *w = w.clamp(0.1, 10.0);
261
+ }
262
+ }
263
+
264
+ /// Record positive feedback (successful retrieval)
265
+ pub fn record_success(&mut self, query: &Point, selected_centroid: &Point, level: usize) {
266
+ self.record_feedback(RoutingFeedback {
267
+ query: query.clone(),
268
+ selected_centroid: selected_centroid.clone(),
269
+ reward: 1.0,
270
+ level,
271
+ });
272
+ }
273
+
274
+ /// Record negative feedback (unsuccessful retrieval)
275
+ pub fn record_failure(&mut self, query: &Point, selected_centroid: &Point, level: usize) {
276
+ self.record_feedback(RoutingFeedback {
277
+ query: query.clone(),
278
+ selected_centroid: selected_centroid.clone(),
279
+ reward: -1.0,
280
+ level,
281
+ });
282
+ }
283
+
284
+ /// Record implicit feedback with continuous reward
285
+ pub fn record_implicit(&mut self, query: &Point, selected_centroid: &Point, level: usize, relevance_score: f32) {
286
+ // Convert relevance (0-1) to reward (-1 to +1)
287
+ let reward = 2.0 * relevance_score - 1.0;
288
+ self.record_feedback(RoutingFeedback {
289
+ query: query.clone(),
290
+ selected_centroid: selected_centroid.clone(),
291
+ reward,
292
+ level,
293
+ });
294
+ }
295
+
296
+ /// Get statistics about the router
297
+ pub fn stats(&self) -> RouterStats {
298
+ RouterStats {
299
+ total_samples: self.total_samples,
300
+ buffer_size: self.feedback_buffer.len(),
301
+ weight_mean: self.weights.iter().sum::<f32>() / self.weights.len() as f32,
302
+ weight_std: {
303
+ let mean = self.weights.iter().sum::<f32>() / self.weights.len() as f32;
304
+ (self.weights.iter().map(|w| (w - mean).powi(2)).sum::<f32>()
305
+ / self.weights.len() as f32).sqrt()
306
+ },
307
+ weight_min: self.weights.iter().cloned().fold(f32::INFINITY, f32::min),
308
+ weight_max: self.weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
309
+ }
310
+ }
311
+
312
+ /// Reset weights to uniform
313
+ pub fn reset_weights(&mut self) {
314
+ for w in self.weights.iter_mut() {
315
+ *w = 1.0;
316
+ }
317
+ for m in self.momentum_buffer.iter_mut() {
318
+ *m = 0.0;
319
+ }
320
+ }
321
+
322
+ /// Clear feedback buffer
323
+ pub fn clear_feedback(&mut self) {
324
+ self.feedback_buffer.clear();
325
+ }
326
+
327
+ /// Get the number of dimensions
328
+ pub fn dims(&self) -> usize {
329
+ self.dims
330
+ }
331
+
332
+ /// Serialize weights to bytes
333
+ pub fn serialize_weights(&self) -> Vec<u8> {
334
+ let mut bytes = Vec::with_capacity(self.weights.len() * 4);
335
+ for w in &self.weights {
336
+ bytes.extend_from_slice(&w.to_le_bytes());
337
+ }
338
+ bytes
339
+ }
340
+
341
+ /// Deserialize weights from bytes
342
+ pub fn deserialize_weights(&mut self, bytes: &[u8]) -> Result<(), &'static str> {
343
+ if bytes.len() != self.weights.len() * 4 {
344
+ return Err("Weight count mismatch");
345
+ }
346
+
347
+ for (i, chunk) in bytes.chunks(4).enumerate() {
348
+ let arr: [u8; 4] = chunk.try_into().map_err(|_| "Invalid byte chunk")?;
349
+ self.weights[i] = f32::from_le_bytes(arr);
350
+ }
351
+
352
+ Ok(())
353
+ }
354
+ }
355
+
356
+ /// Statistics about the learnable router
357
+ #[derive(Debug, Clone)]
358
+ pub struct RouterStats {
359
+ pub total_samples: usize,
360
+ pub buffer_size: usize,
361
+ pub weight_mean: f32,
362
+ pub weight_std: f32,
363
+ pub weight_min: f32,
364
+ pub weight_max: f32,
365
+ }
366
+
367
+ /// Compute routing score for beam search
368
+ ///
369
+ /// Combines weighted similarity with optional biases
370
+ pub fn compute_routing_score(
371
+ router: &LearnableRouter,
372
+ query: &Point,
373
+ centroid: &Point,
374
+ temporal_distance: f32,
375
+ temporal_weight: f32,
376
+ ) -> f32 {
377
+ let semantic_sim = router.weighted_similarity(query, centroid);
378
+
379
+ // Convert to distance (lower = better for routing)
380
+ let semantic_dist = 1.0 - semantic_sim;
381
+
382
+ // Combine with temporal
383
+ semantic_dist * (1.0 - temporal_weight) + temporal_distance * temporal_weight
384
+ }
385
+
386
+ #[cfg(test)]
387
+ mod tests {
388
+ use super::*;
389
+
390
+ fn make_point(v: Vec<f32>) -> Point {
391
+ Point::new(v).normalize()
392
+ }
393
+
394
+ #[test]
395
+ fn test_router_creation() {
396
+ let router = LearnableRouter::default_for_dims(64);
397
+
398
+ assert_eq!(router.dims(), 64);
399
+ assert_eq!(router.weights().len(), 64);
400
+ assert!(router.is_learning_enabled());
401
+
402
+ // All weights should start at 1.0
403
+ for &w in router.weights() {
404
+ assert!((w - 1.0).abs() < 1e-6);
405
+ }
406
+ }
407
+
408
+ #[test]
409
+ fn test_weighted_similarity() {
410
+ let router = LearnableRouter::default_for_dims(4);
411
+
412
+ let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
413
+ let centroid = make_point(vec![0.8, 0.2, 0.0, 0.0]);
414
+
415
+ let sim = router.weighted_similarity(&query, &centroid);
416
+
417
+ // With uniform weights, should be close to cosine similarity
418
+ let expected_cosine: f32 = query.dims().iter()
419
+ .zip(centroid.dims().iter())
420
+ .map(|(q, c)| q * c)
421
+ .sum();
422
+
423
+ assert!((sim - expected_cosine).abs() < 1e-5);
424
+ }
425
+
426
+ #[test]
427
+ fn test_feedback_recording() {
428
+ let mut router = LearnableRouter::new(4, LearnableRoutingConfig {
429
+ min_samples_to_learn: 5,
430
+ update_frequency: 5,
431
+ ..Default::default()
432
+ });
433
+
434
+ let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
435
+ let centroid = make_point(vec![0.9, 0.1, 0.0, 0.0]);
436
+
437
+ // Record several positive feedbacks
438
+ for _ in 0..10 {
439
+ router.record_success(&query, &centroid, 0);
440
+ }
441
+
442
+ let stats = router.stats();
443
+ assert_eq!(stats.total_samples, 10);
444
+
445
+ // Weights should have been updated
446
+ // Dimension 0 (aligned with query) should increase
447
+ println!("Weights after positive feedback: {:?}", router.weights());
448
+ }
449
+
450
+ #[test]
451
+ fn test_learning_dynamics() {
452
+ let mut router = LearnableRouter::new(4, LearnableRoutingConfig {
453
+ learning_rate: 0.1,
454
+ min_samples_to_learn: 3,
455
+ update_frequency: 3,
456
+ momentum: 0.0, // No momentum for predictable testing
457
+ weight_decay: 0.0, // No decay for predictable testing
458
+ ..Default::default()
459
+ });
460
+
461
+ // Query aligned with dimension 0
462
+ let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
463
+ // Centroid also aligned with dimension 0
464
+ let centroid_good = make_point(vec![0.95, 0.05, 0.0, 0.0]);
465
+ // Centroid aligned with dimension 1
466
+ let centroid_bad = make_point(vec![0.0, 1.0, 0.0, 0.0]);
467
+
468
+ // Record positive feedback for good centroid
469
+ for _ in 0..6 {
470
+ router.record_success(&query, &centroid_good, 0);
471
+ }
472
+
473
+ let weights_after_positive = router.weights().to_vec();
474
+
475
+ // Record negative feedback for bad centroid
476
+ for _ in 0..6 {
477
+ router.record_failure(&query, &centroid_bad, 0);
478
+ }
479
+
480
+ let weights_after_negative = router.weights().to_vec();
481
+
482
+ println!("Initial weights: [1.0, 1.0, 1.0, 1.0]");
483
+ println!("After positive: {:?}", weights_after_positive);
484
+ println!("After negative: {:?}", weights_after_negative);
485
+
486
+ // Weight for dim 0 should have increased from positive feedback
487
+ // (query[0] * centroid_good[0] is high and reward is positive)
488
+ }
489
+
490
+ #[test]
491
+ fn test_disabled_learning() {
492
+ let mut router = LearnableRouter::new(4, LearnableRoutingConfig::disabled());
493
+
494
+ assert!(!router.is_learning_enabled());
495
+
496
+ let query = make_point(vec![1.0, 0.0, 0.0, 0.0]);
497
+ let centroid = make_point(vec![0.9, 0.1, 0.0, 0.0]);
498
+
499
+ // Record feedback
500
+ for _ in 0..100 {
501
+ router.record_success(&query, &centroid, 0);
502
+ }
503
+
504
+ // Weights should remain at 1.0
505
+ for &w in router.weights() {
506
+ assert!((w - 1.0).abs() < 1e-6);
507
+ }
508
+ }
509
+
510
+ #[test]
511
+ fn test_serialization() {
512
+ let mut router = LearnableRouter::default_for_dims(4);
513
+
514
+ // Modify weights
515
+ for (i, w) in router.weights.iter_mut().enumerate() {
516
+ *w = (i as f32 + 1.0) * 0.5;
517
+ }
518
+
519
+ let bytes = router.serialize_weights();
520
+
521
+ let mut router2 = LearnableRouter::default_for_dims(4);
522
+ router2.deserialize_weights(&bytes).unwrap();
523
+
524
+ for (w1, w2) in router.weights().iter().zip(router2.weights().iter()) {
525
+ assert!((w1 - w2).abs() < 1e-6);
526
+ }
527
+ }
528
+ }
src/adapters/index/mod.rs ADDED
@@ -0,0 +1,45 @@
1
+ //! # Index Adapters
2
+ //!
3
+ //! Implementations of the Near port for different index backends.
4
+ //!
5
+ //! Available adapters:
6
+ //! - `FlatIndex` - Brute force search (exact, O(n) per query)
7
+ //! - `HatIndex` - Hierarchical Attention Tree (approximate, O(log n) per query)
8
+ //!
9
+ //! Consolidation support:
10
+ //! - `Consolidate` trait for background maintenance operations
11
+ //! - `ConsolidationConfig` to configure maintenance behavior
12
+ //!
13
+ //! Subspace support:
14
+ //! - `Subspace` representation for containers capturing variance/spread
15
+ //! - `SubspaceConfig` for configuring subspace-aware routing
16
+ //!
17
+ //! Learnable routing:
18
+ //! - `LearnableRouter` for adapting routing weights from feedback
19
+ //! - `LearnableRoutingConfig` for configuring online learning
20
+
21
+ mod flat;
22
+ mod hat;
23
+ mod consolidation;
24
+ mod subspace;
25
+ mod learnable_routing;
26
+ mod persistence;
27
+
28
+ pub use flat::FlatIndex;
29
+ pub use hat::{HatIndex, HatConfig, CentroidMethod, ContainerLevel, SessionSummary, DocumentSummary, HatStats};
30
+ pub use consolidation::{
31
+ Consolidate, ConsolidationConfig, ConsolidationLevel, ConsolidationPhase,
32
+ ConsolidationState, ConsolidationMetrics, ConsolidationProgress, ConsolidationTickResult,
33
+ compute_exact_centroid, centroid_drift,
34
+ };
35
+ pub use subspace::{
36
+ Subspace, SubspaceConfig, subspace_similarity, combined_subspace_similarity,
37
+ query_subspace_alignment, subspace_spread, subspace_isotropy,
38
+ };
39
+ pub use learnable_routing::{
40
+ LearnableRouter, LearnableRoutingConfig, RoutingFeedback, RouterStats,
41
+ compute_routing_score,
42
+ };
43
+ pub use persistence::{
44
+ PersistError, SerializedHat, SerializedContainer, LevelByte,
45
+ };
src/adapters/index/persistence.rs ADDED
@@ -0,0 +1,442 @@
1
+ //! # HAT Persistence Layer
2
+ //!
3
+ //! Serialization and deserialization for HAT indexes.
4
+ //!
5
+ //! ## Format
6
+ //!
7
+ //! The HAT persistence format is a simple binary format:
8
+ //!
9
+ //! ```text
10
+ //! [Header: 36 bytes]
11
+ //! - Magic: "HAT\0" (4 bytes)
12
+ //! - Version: u32 (4 bytes)
13
+ //! - Dimensionality: u32 (4 bytes)
14
+ //! - Container count: u64 (8 bytes)
15
+ //! - Root ID: 16 bytes (or zeros if none)
16
+ //! - (no reserved bytes in version 1)
17
+ //!
18
+ //! [Containers: variable]
19
+ //! For each container:
20
+ //! - ID: 16 bytes
21
+ //! - Level: u8 (0=Root, 1=Session, 2=Document, 3=Chunk)
22
+ //! - Timestamp: u64 (8 bytes)
23
+ //! - Child count: u32 (4 bytes)
24
+ //! - Child IDs: child_count * 16 bytes
25
+ //! - Descendant count: u64 (8 bytes)
26
+ //! - Centroid: dimensionality * 4 bytes (f32s)
27
+ //! - Has accumulated sum: u8 (0 or 1)
28
+ //! - Accumulated sum: dimensionality * 4 bytes (if has_accumulated_sum)
29
+ //!
30
+ //! [Active State: 32 bytes]
31
+ //! - Active session ID: 16 bytes (or zeros)
32
+ //! - Active document ID: 16 bytes (or zeros)
33
+ //!
34
+ //! [Learnable Router Weights: variable, optional]
35
+ //! - Has weights: u8 (0 or 1)
36
+ //! - If has weights: dimensionality * 4 bytes (f32s)
37
+ //! ```
38
+ //!
39
+ //! ## Usage
40
+ //!
41
+ //! ```rust,ignore
42
+ //! // Save
43
+ //! let bytes = hat.to_bytes()?;
44
+ //! std::fs::write("index.hat", bytes)?;
45
+ //!
46
+ //! // Load
47
+ //! let bytes = std::fs::read("index.hat")?;
48
+ //! let hat = HatIndex::from_bytes(&bytes)?;
49
+ //! ```
50
+
51
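As a sanity check on the layout, the fixed header sums to 36 bytes, and each container costs at least 38 + 4·dims bytes before any children:

```rust
fn main() {
    // Fixed header: magic + version + dimensionality + container count + root ID
    let header = 4 + 4 + 4 + 8 + 16;
    assert_eq!(header, 36);

    // Minimum per-container cost (no children, no accumulated sum), dims = 128:
    // id + level + timestamp + child count + descendant count + centroid + flag
    let dims = 128;
    let min_container = 16 + 1 + 8 + 4 + 8 + dims * 4 + 1;
    assert_eq!(min_container, 550);
}
```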
+ use crate::core::{Id, Point};
52
+ use std::io::{self, Read, Write, Cursor};
53
+
54
+ /// Magic bytes for HAT file format
55
+ const MAGIC: &[u8; 4] = b"HAT\0";
56
+
57
+ /// Current format version
58
+ const VERSION: u32 = 1;
59
+
60
+ /// Error type for persistence operations
61
+ #[derive(Debug)]
62
+ pub enum PersistError {
63
+ /// Invalid magic bytes
64
+ InvalidMagic,
65
+ /// Unsupported version
66
+ UnsupportedVersion(u32),
67
+ /// IO error
68
+ Io(io::Error),
69
+ /// Data corruption
70
+ Corrupted(String),
71
+ /// Dimension mismatch
72
+ DimensionMismatch { expected: usize, found: usize },
73
+ }
74
+
75
+ impl std::fmt::Display for PersistError {
76
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77
+ match self {
78
+ PersistError::InvalidMagic => write!(f, "Invalid HAT file magic bytes"),
79
+ PersistError::UnsupportedVersion(v) => write!(f, "Unsupported HAT version: {}", v),
80
+ PersistError::Io(e) => write!(f, "IO error: {}", e),
81
+ PersistError::Corrupted(msg) => write!(f, "Data corruption: {}", msg),
82
+ PersistError::DimensionMismatch { expected, found } => {
83
+ write!(f, "Dimension mismatch: expected {}, found {}", expected, found)
84
+ }
85
+ }
86
+ }
87
+ }
88
+
89
+ impl std::error::Error for PersistError {}
90
+
91
+ impl From<io::Error> for PersistError {
92
+ fn from(e: io::Error) -> Self {
93
+ PersistError::Io(e)
94
+ }
95
+ }
96
+
97
+ /// Container level as u8
98
+ #[repr(u8)]
99
+ #[derive(Debug, Clone, Copy, PartialEq, Eq)]
100
+ pub enum LevelByte {
101
+ Root = 0,
102
+ Session = 1,
103
+ Document = 2,
104
+ Chunk = 3,
105
+ }
106
+
107
+ impl LevelByte {
108
+ pub fn from_u8(v: u8) -> Option<Self> {
109
+ match v {
110
+ 0 => Some(LevelByte::Root),
111
+ 1 => Some(LevelByte::Session),
112
+ 2 => Some(LevelByte::Document),
113
+ 3 => Some(LevelByte::Chunk),
114
+ _ => None,
115
+ }
116
+ }
117
+ }
118
+
119
+ /// Serialized container data
120
+ #[derive(Debug, Clone)]
121
+ pub struct SerializedContainer {
122
+ pub id: Id,
123
+ pub level: LevelByte,
124
+ pub timestamp: u64,
125
+ pub children: Vec<Id>,
126
+ pub descendant_count: u64,
127
+ pub centroid: Vec<f32>,
128
+ pub accumulated_sum: Option<Vec<f32>>,
129
+ }
130
+
131
+ /// Serialized HAT index
132
+ #[derive(Debug, Clone)]
133
+ pub struct SerializedHat {
134
+ pub version: u32,
135
+ pub dimensionality: u32,
136
+ pub root_id: Option<Id>,
137
+ pub containers: Vec<SerializedContainer>,
138
+ pub active_session: Option<Id>,
139
+ pub active_document: Option<Id>,
140
+ pub router_weights: Option<Vec<f32>>,
141
+ }
142
+
143
+ impl SerializedHat {
144
+ /// Serialize to bytes
145
+ pub fn to_bytes(&self) -> Result<Vec<u8>, PersistError> {
146
+ let mut buf = Vec::new();
147
+
148
+ // Header
149
+ buf.write_all(MAGIC)?;
150
+ buf.write_all(&self.version.to_le_bytes())?;
151
+ buf.write_all(&self.dimensionality.to_le_bytes())?;
152
+ buf.write_all(&(self.containers.len() as u64).to_le_bytes())?;
153
+
154
+ // Root ID
155
+ if let Some(id) = &self.root_id {
156
+ buf.write_all(id.as_bytes())?;
157
+ } else {
158
+ buf.write_all(&[0u8; 16])?;
159
+ }
160
+
161
+ // Containers
162
+ for container in &self.containers {
163
+ // ID
164
+ buf.write_all(container.id.as_bytes())?;
165
+
166
+ // Level
167
+ buf.write_all(&[container.level as u8])?;
168
+
169
+ // Timestamp
170
+ buf.write_all(&container.timestamp.to_le_bytes())?;
171
+
172
+ // Children
173
+ buf.write_all(&(container.children.len() as u32).to_le_bytes())?;
174
+ for child_id in &container.children {
175
+ buf.write_all(child_id.as_bytes())?;
176
+ }
177
+
178
+ // Descendant count
179
+ buf.write_all(&container.descendant_count.to_le_bytes())?;
180
+
181
+ // Centroid
182
+ for &v in &container.centroid {
183
+ buf.write_all(&v.to_le_bytes())?;
184
+ }
185
+
186
+ // Accumulated sum
187
+ if let Some(sum) = &container.accumulated_sum {
188
+ buf.write_all(&[1u8])?;
189
+ for &v in sum {
190
+ buf.write_all(&v.to_le_bytes())?;
191
+ }
192
+ } else {
193
+ buf.write_all(&[0u8])?;
194
+ }
195
+ }
196
+
197
+ // Active state
198
+ if let Some(id) = &self.active_session {
199
+ buf.write_all(id.as_bytes())?;
200
+ } else {
201
+ buf.write_all(&[0u8; 16])?;
202
+ }
203
+
204
+ if let Some(id) = &self.active_document {
205
+ buf.write_all(id.as_bytes())?;
206
+ } else {
207
+ buf.write_all(&[0u8; 16])?;
208
+ }
209
+
210
+ // Router weights
211
+ if let Some(weights) = &self.router_weights {
212
+ buf.write_all(&[1u8])?;
213
+ for &w in weights {
214
+ buf.write_all(&w.to_le_bytes())?;
215
+ }
216
+ } else {
217
+ buf.write_all(&[0u8])?;
218
+ }
219
+
220
+ Ok(buf)
221
+ }
222
+
223
+ /// Deserialize from bytes
224
+ pub fn from_bytes(data: &[u8]) -> Result<Self, PersistError> {
225
+ let mut cursor = Cursor::new(data);
226
+
227
+ // Read header
228
+ let mut magic = [0u8; 4];
229
+ cursor.read_exact(&mut magic)?;
230
+ if &magic != MAGIC {
231
+ return Err(PersistError::InvalidMagic);
232
+ }
233
+
234
+ let mut version_bytes = [0u8; 4];
235
+ cursor.read_exact(&mut version_bytes)?;
236
+ let version = u32::from_le_bytes(version_bytes);
237
+ if version != VERSION {
238
+ return Err(PersistError::UnsupportedVersion(version));
239
+ }
240
+
241
+ let mut dims_bytes = [0u8; 4];
242
+ cursor.read_exact(&mut dims_bytes)?;
243
+ let dimensionality = u32::from_le_bytes(dims_bytes);
244
+
245
+ let mut count_bytes = [0u8; 8];
246
+ cursor.read_exact(&mut count_bytes)?;
247
+ let container_count = u64::from_le_bytes(count_bytes);
248
+
249
+ let mut root_bytes = [0u8; 16];
250
+ cursor.read_exact(&mut root_bytes)?;
251
+ let root_id = if root_bytes == [0u8; 16] {
252
+ None
253
+ } else {
254
+ Some(Id::from_bytes(root_bytes))
255
+ };
256
+
257
+ // Read containers
258
+ let mut containers = Vec::with_capacity(container_count as usize);
259
+ for _ in 0..container_count {
260
+ // ID
261
+ let mut id_bytes = [0u8; 16];
262
+ cursor.read_exact(&mut id_bytes)?;
263
+ let id = Id::from_bytes(id_bytes);
264
+
265
+ // Level
266
+ let mut level_byte = [0u8; 1];
267
+ cursor.read_exact(&mut level_byte)?;
268
+ let level = LevelByte::from_u8(level_byte[0])
269
+ .ok_or_else(|| PersistError::Corrupted(format!("Invalid level: {}", level_byte[0])))?;
270
+
271
+ // Timestamp
272
+ let mut ts_bytes = [0u8; 8];
273
+ cursor.read_exact(&mut ts_bytes)?;
274
+ let timestamp = u64::from_le_bytes(ts_bytes);
275
+
276
+ // Children
277
+ let mut child_count_bytes = [0u8; 4];
278
+ cursor.read_exact(&mut child_count_bytes)?;
279
+ let child_count = u32::from_le_bytes(child_count_bytes) as usize;
280
+
281
+ let mut children = Vec::with_capacity(child_count);
282
+ for _ in 0..child_count {
283
+ let mut child_bytes = [0u8; 16];
284
+ cursor.read_exact(&mut child_bytes)?;
285
+ children.push(Id::from_bytes(child_bytes));
286
+ }
287
+
288
+ // Descendant count
289
+ let mut desc_bytes = [0u8; 8];
290
+ cursor.read_exact(&mut desc_bytes)?;
291
+ let descendant_count = u64::from_le_bytes(desc_bytes);
292
+
293
+ // Centroid
294
+ let mut centroid = Vec::with_capacity(dimensionality as usize);
295
+ for _ in 0..dimensionality {
296
+ let mut v_bytes = [0u8; 4];
297
+ cursor.read_exact(&mut v_bytes)?;
298
+ centroid.push(f32::from_le_bytes(v_bytes));
299
+ }
300
+
301
+ // Accumulated sum
302
+ let mut has_sum = [0u8; 1];
303
+ cursor.read_exact(&mut has_sum)?;
304
+ let accumulated_sum = if has_sum[0] == 1 {
305
+ let mut sum = Vec::with_capacity(dimensionality as usize);
306
+ for _ in 0..dimensionality {
307
+ let mut v_bytes = [0u8; 4];
308
+ cursor.read_exact(&mut v_bytes)?;
309
+ sum.push(f32::from_le_bytes(v_bytes));
310
+ }
311
+ Some(sum)
312
+ } else {
313
+ None
314
+ };
315
+
316
+ containers.push(SerializedContainer {
317
+ id,
318
+ level,
319
+ timestamp,
320
+ children,
321
+ descendant_count,
322
+ centroid,
323
+ accumulated_sum,
324
+ });
325
+ }
326
+
327
+ // Active state
328
+ let mut active_session_bytes = [0u8; 16];
329
+ cursor.read_exact(&mut active_session_bytes)?;
330
+ let active_session = if active_session_bytes == [0u8; 16] {
331
+ None
332
+ } else {
333
+ Some(Id::from_bytes(active_session_bytes))
334
+ };
335
+
336
+ let mut active_document_bytes = [0u8; 16];
337
+ cursor.read_exact(&mut active_document_bytes)?;
338
+ let active_document = if active_document_bytes == [0u8; 16] {
339
+ None
340
+ } else {
341
+ Some(Id::from_bytes(active_document_bytes))
342
+ };
343
+
344
+ // Router weights (optional - may not be present in older files)
345
+ let router_weights = if cursor.position() < data.len() as u64 {
346
+ let mut has_weights = [0u8; 1];
347
+ cursor.read_exact(&mut has_weights)?;
348
+ if has_weights[0] == 1 {
349
+ let mut weights = Vec::with_capacity(dimensionality as usize);
350
+ for _ in 0..dimensionality {
351
+ let mut w_bytes = [0u8; 4];
352
+ cursor.read_exact(&mut w_bytes)?;
353
+ weights.push(f32::from_le_bytes(w_bytes));
354
+ }
355
+ Some(weights)
356
+ } else {
357
+ None
358
+ }
359
+ } else {
360
+ None
361
+ };
362
+
363
+ Ok(SerializedHat {
364
+ version,
365
+ dimensionality,
366
+ root_id,
367
+ containers,
368
+ active_session,
369
+ active_document,
370
+ router_weights,
371
+ })
372
+ }
373
+ }
374
+
375
+ /// Convert an optional ID to its 16-byte representation (all zeros for `None`)
376
+ fn id_to_bytes(id: &Option<Id>) -> [u8; 16] {
377
+ match id {
378
+ Some(id) => *id.as_bytes(),
379
+ None => [0u8; 16],
380
+ }
381
+ }
382
+
383
+ #[cfg(test)]
384
+ mod tests {
385
+ use super::*;
386
+
387
+ #[test]
388
+ fn test_serialized_hat_roundtrip() {
389
+ let original = SerializedHat {
390
+ version: VERSION,
391
+ dimensionality: 128,
392
+ root_id: Some(Id::now()),
393
+ containers: vec![
394
+ SerializedContainer {
395
+ id: Id::now(),
396
+ level: LevelByte::Root,
397
+ timestamp: 1234567890,
398
+ children: vec![Id::now(), Id::now()],
399
+ descendant_count: 10,
400
+ centroid: vec![0.1; 128],
401
+ accumulated_sum: None,
402
+ },
403
+ SerializedContainer {
404
+ id: Id::now(),
405
+ level: LevelByte::Chunk,
406
+ timestamp: 1234567891,
407
+ children: vec![],
408
+ descendant_count: 1,
409
+ centroid: vec![0.5; 128],
410
+ accumulated_sum: Some(vec![0.5; 128]),
411
+ },
412
+ ],
413
+ active_session: Some(Id::now()),
414
+ active_document: None,
415
+ router_weights: Some(vec![1.0; 128]),
416
+ };
417
+
418
+ let bytes = original.to_bytes().unwrap();
419
+ let restored = SerializedHat::from_bytes(&bytes).unwrap();
420
+
421
+ assert_eq!(restored.version, original.version);
422
+ assert_eq!(restored.dimensionality, original.dimensionality);
423
+ assert_eq!(restored.containers.len(), original.containers.len());
424
+ assert!(restored.router_weights.is_some());
425
+ }
426
+
427
+ #[test]
428
+ fn test_invalid_magic() {
429
+ let bad_data = b"BAD\0rest of data...";
430
+ let result = SerializedHat::from_bytes(bad_data);
431
+ assert!(matches!(result, Err(PersistError::InvalidMagic)));
432
+ }
433
+
434
+ #[test]
435
+ fn test_level_byte_conversion() {
436
+ assert_eq!(LevelByte::from_u8(0), Some(LevelByte::Root));
437
+ assert_eq!(LevelByte::from_u8(1), Some(LevelByte::Session));
438
+ assert_eq!(LevelByte::from_u8(2), Some(LevelByte::Document));
439
+ assert_eq!(LevelByte::from_u8(3), Some(LevelByte::Chunk));
440
+ assert_eq!(LevelByte::from_u8(4), None);
441
+ }
442
+ }
src/adapters/index/subspace.rs ADDED
@@ -0,0 +1,640 @@
 
1
+ //! # Subspace Containers for HAT
2
+ //!
3
+ //! This module implements subspace-aware container representations for HAT.
4
+ //! Instead of representing containers as single centroid points, we model them
5
+ //! as subspaces that capture the "shape" and "spread" of points within.
6
+ //!
7
+ //! ## Key Insight (from journal 006)
8
+ //!
9
+ //! "A session isn't a single point - it's a *region* of the manifold."
10
+ //!
11
+ //! ## Grassmann-Inspired Approach
12
+ //!
13
+ //! - Each container is represented by its centroid PLUS principal directions
14
+ //! - Similarity between containers uses subspace angles (principal angles)
15
+ //! - Better captures diverse content within a container
16
+ //!
17
+ //! ## Benefits
18
+ //!
19
+ //! 1. **Better Routing**: Query can match containers even if not close to centroid
20
+ //! 2. **Diversity Awareness**: Wide containers (diverse content) vs narrow containers
21
+ //! 3. **Geometric Fidelity**: More accurate representation of point distributions
22
+
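For reference, the projection similarity described above has a compact form. If $A$ and $B$ are matrices whose columns are the orthonormal principal directions of two containers, $\sigma_i(\cdot)$ denotes singular values, and $k$ is the smaller of the two ranks, then the principal angles $\theta_i$ satisfy

$$ \cos\theta_i = \sigma_i(A^{\top}B), \qquad \operatorname{sim}(A,B) = \frac{1}{k}\sum_{i=1}^{k}\cos^{2}\theta_i \in [0,1]. $$

The code below approximates the $\sigma_i$ with a greedy matching instead of a full SVD (see `subspace_similarity` and `greedy_max_matching`).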
23
+ use crate::core::Point;
24
+
25
+ /// Configuration for subspace representation
26
+ #[derive(Debug, Clone)]
27
+ pub struct SubspaceConfig {
28
+ /// Number of principal components to track (subspace rank)
29
+ pub rank: usize,
30
+
31
+ /// Minimum points before computing subspace (need enough for covariance)
32
+ pub min_points_for_subspace: usize,
33
+
34
+ /// Weight of subspace similarity vs centroid similarity (0.0 = centroid only)
35
+ pub subspace_weight: f32,
36
+
37
+ /// Enable incremental covariance updates during insertion (vs only during consolidation)
38
+ /// When false, subspace is only computed during consolidation - much faster inserts
39
+ pub incremental_covariance: bool,
40
+ }
41
+
42
+ impl Default for SubspaceConfig {
43
+ fn default() -> Self {
44
+ Self {
45
+ rank: 3, // Track top 3 principal directions
46
+ min_points_for_subspace: 5, // Need at least 5 points for meaningful covariance
47
+ subspace_weight: 0.3, // 30% subspace, 70% centroid by default
48
+ incremental_covariance: false, // Default: only compute during consolidation (faster)
49
+ }
50
+ }
51
+ }
52
+
53
+ impl SubspaceConfig {
54
+ pub fn new() -> Self {
55
+ Self::default()
56
+ }
57
+
58
+ pub fn with_rank(mut self, rank: usize) -> Self {
59
+ self.rank = rank;
60
+ self
61
+ }
62
+
63
+ pub fn with_subspace_weight(mut self, weight: f32) -> Self {
64
+ self.subspace_weight = weight.clamp(0.0, 1.0);
65
+ self
66
+ }
67
+ }
68
+
69
+ /// Subspace representation for a container
70
+ ///
71
+ /// Stores the centroid plus principal directions that capture
72
+ /// the variance/spread of points within the container.
73
+ #[derive(Debug, Clone)]
74
+ pub struct Subspace {
75
+ /// Centroid (mean of points)
76
+ pub centroid: Point,
77
+
78
+ /// Principal directions (orthonormal basis for subspace)
79
+ /// Each direction is a unit vector
80
+ pub principal_directions: Vec<Point>,
81
+
82
+ /// Eigenvalues (variance in each principal direction)
83
+ /// Stored in decreasing order
84
+ pub eigenvalues: Vec<f32>,
85
+
86
+ /// Number of points used to compute this subspace
87
+ pub point_count: usize,
88
+
89
+ /// Running sum for incremental centroid updates
90
+ accumulated_sum: Vec<f32>,
91
+
92
+ /// Running covariance matrix (upper triangle only for efficiency)
93
+ /// For incremental updates: cov = (sum of outer products) / n - mean * mean^T
94
+ accumulated_outer_product: Vec<f32>,
95
+ }
96
+
97
+ impl Subspace {
98
+ /// Create a new empty subspace
99
+ pub fn new(dimensionality: usize) -> Self {
100
+ Self {
101
+ centroid: Point::origin(dimensionality),
102
+ principal_directions: Vec::new(),
103
+ eigenvalues: Vec::new(),
104
+ point_count: 0,
105
+ accumulated_sum: vec![0.0; dimensionality],
106
+ // Upper triangle of d x d matrix: d * (d + 1) / 2 elements
107
+ accumulated_outer_product: vec![0.0; dimensionality * (dimensionality + 1) / 2],
108
+ }
109
+ }
110
+
111
+ /// Create from a single point
112
+ pub fn from_point(point: &Point) -> Self {
113
+ Self {
114
+ centroid: point.clone(),
115
+ principal_directions: Vec::new(),
116
+ eigenvalues: Vec::new(),
117
+ point_count: 1,
118
+ accumulated_sum: point.dims().to_vec(),
119
+ accumulated_outer_product: Self::outer_product_upper(point.dims()),
120
+ }
121
+ }
122
+
123
+ /// Dimensionality of the ambient space
124
+ pub fn dimensionality(&self) -> usize {
125
+ self.centroid.dimensionality()
126
+ }
127
+
128
+ /// Check if subspace has meaningful principal directions
129
+ pub fn has_subspace(&self) -> bool {
130
+ !self.principal_directions.is_empty()
131
+ }
132
+
133
+ /// Get the subspace rank (number of principal directions)
134
+ pub fn rank(&self) -> usize {
135
+ self.principal_directions.len()
136
+ }
137
+
138
+ /// Compute upper triangle of outer product v * v^T
139
+ fn outer_product_upper(v: &[f32]) -> Vec<f32> {
140
+ let n = v.len();
141
+ let mut result = vec![0.0; n * (n + 1) / 2];
142
+ let mut idx = 0;
143
+ for i in 0..n {
144
+ for j in i..n {
145
+ result[idx] = v[i] * v[j];
146
+ idx += 1;
147
+ }
148
+ }
149
+ result
150
+ }
151
+
152
+ /// Get element from upper triangle storage
153
+ fn get_upper(&self, i: usize, j: usize) -> f32 {
154
+ let (row, col) = if i <= j { (i, j) } else { (j, i) };
155
+ let n = self.dimensionality();
156
+ // Index into upper triangle
157
+ let idx = row * (2 * n - row - 1) / 2 + col;
158
+ self.accumulated_outer_product[idx]
159
+ }
160
+
161
+ /// Add element to upper triangle storage
162
+ fn add_to_upper(&mut self, i: usize, j: usize, value: f32) {
163
+ let (row, col) = if i <= j { (i, j) } else { (j, i) };
164
+ let n = self.dimensionality();
165
+ let idx = row * (2 * n - row - 1) / 2 + col;
166
+ self.accumulated_outer_product[idx] += value;
167
+ }
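The packed index used by `get_upper` and `add_to_upper` is the standard row-major layout of an upper triangle: row $i$ of an $n \times n$ matrix stores $n - i$ entries, so entry $(i, j)$ with $j \ge i$ sits at

$$ \operatorname{idx}(i,j) = \sum_{r=0}^{i-1}(n-r) + (j-i) = \frac{i\,(2n-i-1)}{2} + j, $$

which is exactly the expression in the code; the product $i(2n-i-1)$ is always even, so the integer division is exact.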
168
+
169
+ /// Incrementally add a point
170
+ pub fn add_point(&mut self, point: &Point) {
171
+ let dims = point.dims();
172
+
173
+ // Update running sum
174
+ for (i, &v) in dims.iter().enumerate() {
175
+ self.accumulated_sum[i] += v;
176
+ }
177
+
178
+ // Update outer product accumulator
179
+ for i in 0..dims.len() {
180
+ for j in i..dims.len() {
181
+ self.add_to_upper(i, j, dims[i] * dims[j]);
182
+ }
183
+ }
184
+
185
+ self.point_count += 1;
186
+
187
+ // Update centroid
188
+ let n = self.point_count as f32;
189
+ let centroid_dims: Vec<f32> = self.accumulated_sum.iter()
190
+ .map(|&s| s / n)
191
+ .collect();
192
+ self.centroid = Point::new(centroid_dims).normalize();
193
+ }
194
+
195
+ /// Compute covariance matrix from accumulated statistics
196
+ fn compute_covariance(&self) -> Vec<Vec<f32>> {
197
+ let n = self.dimensionality();
198
+ let count = self.point_count as f32;
199
+
200
+ if count < 2.0 {
201
+ return vec![vec![0.0; n]; n];
202
+ }
203
+
204
+ // Mean vector
205
+ let mean: Vec<f32> = self.accumulated_sum.iter()
206
+ .map(|&s| s / count)
207
+ .collect();
208
+
209
+ // Covariance = E[X*X^T] - E[X]*E[X]^T
210
+ let mut cov = vec![vec![0.0; n]; n];
211
+ for i in 0..n {
212
+ for j in i..n {
213
+ let exx = self.get_upper(i, j) / count;
214
+ let exex = mean[i] * mean[j];
215
+ let c = exx - exex;
216
+ cov[i][j] = c;
217
+ cov[j][i] = c; // Symmetric
218
+ }
219
+ }
220
+
221
+ cov
222
+ }
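This is the usual moment identity for the population covariance of points $x_1,\dots,x_n$ with mean $\mu$:

$$ C = \frac{1}{n}\sum_{t=1}^{n} x_t x_t^{\top} - \mu\mu^{\top}. $$

One caveat: computing $E[XX^{\top}] - \mu\mu^{\top}$ in `f32` can cancel badly when the variance is small relative to the mean, which is one reason eigenvalues near the `1e-8` cutoff used during extraction are best treated as noise.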
223
+
224
+ /// Recompute principal directions from covariance
225
+ /// Uses power iteration for efficiency (avoids full eigendecomposition)
226
+ pub fn recompute_subspace(&mut self, rank: usize) {
227
+ if self.point_count < 3 {
228
+ // Not enough points for meaningful subspace
229
+ self.principal_directions.clear();
230
+ self.eigenvalues.clear();
231
+ return;
232
+ }
233
+
234
+ let cov = self.compute_covariance();
235
+ let n = self.dimensionality();
236
+
237
+ // Extract top-k eigenvectors using power iteration with deflation
238
+ let mut directions = Vec::new();
239
+ let mut values = Vec::new();
240
+ let mut working_cov = cov.clone();
241
+
242
+ for _ in 0..rank.min(n) {
243
+ // Power iteration for dominant eigenvector
244
+ let (eigval, eigvec) = self.power_iteration(&working_cov, 50);
245
+
246
+ if eigval < 1e-8 {
247
+ break; // No more significant variance
248
+ }
249
+
250
+ values.push(eigval);
251
+ directions.push(Point::new(eigvec.clone()).normalize());
252
+
253
+ // Deflate: remove this eigenvector's contribution
254
+ for i in 0..n {
255
+ for j in 0..n {
256
+ working_cov[i][j] -= eigval * eigvec[i] * eigvec[j];
257
+ }
258
+ }
259
+ }
260
+
261
+ self.principal_directions = directions;
262
+ self.eigenvalues = values;
263
+ }
264
+
265
+ /// Power iteration to find dominant eigenvector
266
+ fn power_iteration(&self, matrix: &[Vec<f32>], max_iters: usize) -> (f32, Vec<f32>) {
267
+ let n = matrix.len();
268
+
269
+ // Initialize with a fixed, non-uniform vector: deterministic, and unlikely to be orthogonal to the dominant eigenvector
270
+ let mut v: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.1).collect();
271
+ let mut norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
272
+ for x in &mut v {
273
+ *x /= norm;
274
+ }
275
+
276
+ let mut eigenvalue = 0.0f32;
277
+
278
+ for _ in 0..max_iters {
279
+ // v_new = A * v
280
+ let mut v_new = vec![0.0; n];
281
+ for i in 0..n {
282
+ for j in 0..n {
283
+ v_new[i] += matrix[i][j] * v[j];
284
+ }
285
+ }
286
+
287
+ // Compute eigenvalue approximation
288
+ eigenvalue = v_new.iter().zip(v.iter()).map(|(a, b)| a * b).sum();
289
+
290
+ // Normalize
291
+ norm = v_new.iter().map(|x| x * x).sum::<f32>().sqrt();
292
+ if norm < 1e-10 {
293
+ return (0.0, vec![0.0; n]);
294
+ }
295
+
296
+ let converged = v.iter().zip(v_new.iter())
297
+ .map(|(a, b)| (a - b / norm).abs())
298
+ .sum::<f32>() < 1e-8;
299
+
300
+ for i in 0..n {
301
+ v[i] = v_new[i] / norm;
302
+ }
303
+
304
+ if converged {
305
+ break;
306
+ }
307
+ }
308
+
309
+ (eigenvalue.abs(), v)
310
+ }
311
+ }
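The eigenvector extraction above is textbook power iteration with Hotelling deflation: repeat $v \leftarrow Av / \lVert Av \rVert$ until convergence, estimate $\lambda \approx v^{\top} A v$, then subtract the recovered component,

$$ A' = A - \lambda\, v v^{\top}, $$

so that the next round converges to the following eigenvector. This is valid here because a covariance matrix is symmetric, so its eigenvectors are orthogonal.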
312
+
313
+ /// Compute subspace similarity using principal angles
314
+ ///
315
+ /// Based on Grassmann geometry: the similarity between two subspaces
316
+ /// is determined by the principal angles between them.
317
+ ///
318
+ /// For k-dimensional subspaces, there are k principal angles θ₁...θₖ
319
+ /// where 0 ≤ θ₁ ≤ ... ≤ θₖ ≤ π/2
320
+ ///
321
+ /// Common measures:
322
+ /// - Projection similarity: Σ cos²(θᵢ) / k (ranges 0-1)
323
+ /// - Geodesic distance: sqrt(Σ θᵢ²)
324
+ /// - Chordal distance: sqrt(Σ sin²(θᵢ))
325
+ pub fn subspace_similarity(a: &Subspace, b: &Subspace) -> f32 {
326
+ // If either has no subspace, fall back to centroid similarity
327
+ if !a.has_subspace() || !b.has_subspace() {
328
+ return centroid_similarity(&a.centroid, &b.centroid);
329
+ }
330
+
331
+ // Compute inner products between principal directions
332
+ let rank_a = a.rank();
333
+ let rank_b = b.rank();
334
+ let k = rank_a.min(rank_b);
335
+
336
+ if k == 0 {
337
+ return centroid_similarity(&a.centroid, &b.centroid);
338
+ }
339
+
340
+ // Build matrix M where M[i][j] = <a_i, b_j> (dot products)
341
+ let mut m = vec![vec![0.0f32; rank_b]; rank_a];
342
+ for i in 0..rank_a {
343
+ for j in 0..rank_b {
344
+ let dot: f32 = a.principal_directions[i].dims().iter()
345
+ .zip(b.principal_directions[j].dims().iter())
346
+ .map(|(x, y)| x * y)
347
+ .sum();
348
+ m[i][j] = dot;
349
+ }
350
+ }
351
+
352
+ // SVD of M gives principal angles: σᵢ = cos(θᵢ)
353
+ // For simplicity, use a greedy approximation:
354
+ // Find k maximum entries while avoiding row/column reuse
355
+ let cos_angles = greedy_max_matching(&m, k);
356
+
357
+ // Projection similarity: mean of cos²(θᵢ)
358
+ let similarity: f32 = cos_angles.iter()
359
+ .map(|&c| c * c) // cos²(θ)
360
+ .sum::<f32>() / k as f32;
361
+
362
+ similarity
363
+ }
364
+
365
+ /// Greedy approximation to find k largest entries with no repeated rows/columns
366
+ fn greedy_max_matching(m: &[Vec<f32>], k: usize) -> Vec<f32> {
367
+ let rows = m.len();
368
+ let cols = if rows > 0 { m[0].len() } else { 0 };
369
+
370
+ let mut used_rows = vec![false; rows];
371
+ let mut used_cols = vec![false; cols];
372
+ let mut result = Vec::new();
373
+
374
+ for _ in 0..k {
375
+ let mut best = (0, 0, 0.0f32);
376
+
377
+ for i in 0..rows {
378
+ if used_rows[i] { continue; }
379
+ for j in 0..cols {
380
+ if used_cols[j] { continue; }
381
+ let val = m[i][j].abs();
382
+ if val > best.2 {
383
+ best = (i, j, val);
384
+ }
385
+ }
386
+ }
387
+
388
+ if best.2 > 0.0 {
389
+ used_rows[best.0] = true;
390
+ used_cols[best.1] = true;
391
+ result.push(best.2);
392
+ } else {
393
+ break;
394
+ }
395
+ }
396
+
397
+ result
398
+ }
399
+
400
+ /// Simple centroid similarity (dot product, i.e. cosine for unit-normalized points)
401
+ fn centroid_similarity(a: &Point, b: &Point) -> f32 {
402
+ let dot: f32 = a.dims().iter()
403
+ .zip(b.dims().iter())
404
+ .map(|(x, y)| x * y)
405
+ .sum();
406
+ dot.clamp(-1.0, 1.0)
407
+ }
408
+
409
+ /// Combined similarity: weighted combination of centroid and subspace similarity
410
+ ///
411
+ /// score = (1 - weight) * centroid_sim + weight * subspace_sim
412
+ pub fn combined_subspace_similarity(
413
+ query: &Point,
414
+ container: &Subspace,
415
+ config: &SubspaceConfig,
416
+ ) -> f32 {
417
+ let centroid_sim = centroid_similarity(query, &container.centroid);
418
+
419
+ if !container.has_subspace() || config.subspace_weight < 1e-6 {
420
+ return centroid_sim;
421
+ }
422
+
423
+ // Subspace similarity: how well does query align with principal directions?
424
+ // Measure: sum of squared projections onto principal directions
425
+ let subspace_sim = query_subspace_alignment(query, container);
426
+
427
+ // Weighted combination
428
+ let w = config.subspace_weight;
429
+ (1.0 - w) * centroid_sim + w * subspace_sim
430
+ }
431
+
432
+ /// Measure how well a query aligns with a subspace
433
+ ///
434
+ /// Higher score means query is well-captured by the subspace's principal directions
435
+ pub fn query_subspace_alignment(query: &Point, subspace: &Subspace) -> f32 {
436
+ if !subspace.has_subspace() {
437
+ return centroid_similarity(query, &subspace.centroid);
438
+ }
439
+
440
+ // Center query relative to centroid
441
+ let centered: Vec<f32> = query.dims().iter()
442
+ .zip(subspace.centroid.dims().iter())
443
+ .map(|(q, c)| q - c)
444
+ .collect();
445
+
446
+ let centered_norm: f32 = centered.iter().map(|x| x * x).sum::<f32>().sqrt();
447
+ if centered_norm < 1e-10 {
448
+ // Query is at centroid - perfect match
449
+ return 1.0;
450
+ }
451
+
452
+ // Compute squared projections onto each principal direction
453
+ let mut total_proj_sq = 0.0f32;
454
+ for (dir, &eigenval) in subspace.principal_directions.iter().zip(subspace.eigenvalues.iter()) {
455
+ let proj: f32 = centered.iter()
456
+ .zip(dir.dims().iter())
457
+ .map(|(c, d)| c * d)
458
+ .sum();
459
+
460
+ // Weight by eigenvalue (variance in that direction)
461
+ // Higher eigenvalue = more likely direction for data variation
462
+ let weight = (eigenval / subspace.eigenvalues[0]).sqrt();
463
+ total_proj_sq += proj * proj * weight;
464
+ }
465
+
466
+ // Normalize by centered query magnitude
467
+ let alignment = (total_proj_sq / (centered_norm * centered_norm)).min(1.0);
468
+
469
+ // Combine with centroid similarity for overall score
470
+ let centroid_sim = centroid_similarity(query, &subspace.centroid);
471
+
472
+ // Score: close to centroid AND aligned with principal directions
473
+ (centroid_sim + alignment) / 2.0
474
+ }
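Written out: with centered query $c = q - \mu$, principal directions $u_i$, and eigenvalues $\lambda_1 \ge \lambda_2 \ge \dots$, the function computes

$$ \operatorname{align}(q) = \min\!\left(1,\ \frac{1}{\lVert c \rVert^{2}} \sum_i \sqrt{\lambda_i / \lambda_1}\ \langle c, u_i \rangle^{2}\right), \qquad \operatorname{score}(q) = \frac{\cos(q,\mu) + \operatorname{align}(q)}{2}. $$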
475
+
476
+ /// Compute the "spread" or diversity of a subspace
477
+ ///
478
+ /// Higher values indicate more diverse content (larger variance)
479
+ /// Lower values indicate tightly clustered content
480
+ pub fn subspace_spread(subspace: &Subspace) -> f32 {
481
+ if subspace.eigenvalues.is_empty() {
482
+ return 0.0;
483
+ }
484
+
485
+ // Total variance (sum of eigenvalues)
486
+ subspace.eigenvalues.iter().sum()
487
+ }
488
+
489
+ /// Compute the "isotropy" of a subspace
490
+ ///
491
+ /// Higher values (close to 1) indicate uniform spread in all directions
492
+ /// Lower values indicate elongated, anisotropic distribution
493
+ pub fn subspace_isotropy(subspace: &Subspace) -> f32 {
494
+ if subspace.eigenvalues.len() < 2 {
495
+ return 1.0; // Single direction is perfectly "isotropic" in its subspace
496
+ }
497
+
498
+ // Ratio of smallest to largest eigenvalue
499
+ let max = subspace.eigenvalues[0];
500
+ let min = *subspace.eigenvalues.last().unwrap();
501
+
502
+ if max < 1e-10 {
503
+ return 1.0;
504
+ }
505
+
506
+ min / max
507
+ }
508
+
509
+ #[cfg(test)]
510
+ mod tests {
511
+ use super::*;
512
+
513
+ fn make_point(v: Vec<f32>) -> Point {
514
+ Point::new(v).normalize()
515
+ }
516
+
517
+ #[test]
518
+ fn test_subspace_creation() {
519
+ let mut subspace = Subspace::new(3);
520
+
521
+ // Add some points
522
+ subspace.add_point(&make_point(vec![1.0, 0.0, 0.0]));
523
+ subspace.add_point(&make_point(vec![0.9, 0.1, 0.0]));
524
+ subspace.add_point(&make_point(vec![0.8, 0.2, 0.0]));
525
+ subspace.add_point(&make_point(vec![0.7, 0.3, 0.1]));
526
+ subspace.add_point(&make_point(vec![0.6, 0.4, 0.1]));
527
+
528
+ assert_eq!(subspace.point_count, 5);
529
+
530
+ // Compute principal directions
531
+ subspace.recompute_subspace(2);
532
+
533
+ assert!(subspace.has_subspace());
534
+ assert!(subspace.rank() > 0);
535
+ assert!(!subspace.eigenvalues.is_empty());
536
+
537
+ println!("Centroid: {:?}", subspace.centroid.dims());
538
+ println!("Principal directions: {}", subspace.rank());
539
+ println!("Eigenvalues: {:?}", subspace.eigenvalues);
540
+ }
541
+
542
+ #[test]
543
+ fn test_subspace_similarity() {
544
+ let mut a = Subspace::new(3);
545
+ let mut b = Subspace::new(3);
546
+
547
+ // Subspace A: points along x-axis
548
+ for i in 0..10 {
549
+ let x = 1.0 - i as f32 * 0.05;
550
+ let y = i as f32 * 0.05;
551
+ a.add_point(&make_point(vec![x, y, 0.0]));
552
+ }
553
+
554
+ // Subspace B: similar points (should be high similarity)
555
+ for i in 0..10 {
556
+ let x = 0.95 - i as f32 * 0.04;
557
+ let y = i as f32 * 0.04 + 0.05;
558
+ b.add_point(&make_point(vec![x, y, 0.1]));
559
+ }
560
+
561
+ a.recompute_subspace(2);
562
+ b.recompute_subspace(2);
563
+
564
+ let sim = subspace_similarity(&a, &b);
565
+ println!("Similarity between similar subspaces: {:.3}", sim);
566
+ assert!(sim > 0.5, "Similar subspaces should have high similarity");
567
+
568
+ // Subspace C: orthogonal to A (along z-axis)
569
+ let mut c = Subspace::new(3);
570
+ for i in 0..10 {
571
+ let z = 1.0 - i as f32 * 0.05;
572
+ c.add_point(&make_point(vec![0.0, 0.1, z]));
573
+ }
574
+ c.recompute_subspace(2);
575
+
576
+ let sim_ac = subspace_similarity(&a, &c);
577
+ println!("Similarity between orthogonal subspaces: {:.3}", sim_ac);
578
+ assert!(sim_ac < sim, "Orthogonal subspaces should have lower similarity");
579
+ }
580
+
581
+ #[test]
582
+ fn test_query_alignment() {
583
+ let mut subspace = Subspace::new(3);
584
+
585
+ // Points primarily along x-axis with some y variation
586
+ for i in 0..20 {
587
+ let x = 0.8 + (i % 3) as f32 * 0.1;
588
+ let y = (i as f32 * 0.05) % 0.3;
589
+ subspace.add_point(&make_point(vec![x, y, 0.05]));
590
+ }
591
+ subspace.recompute_subspace(2);
592
+
593
+ // Query aligned with subspace
594
+ let aligned_query = make_point(vec![0.9, 0.1, 0.0]);
595
+ let aligned_score = query_subspace_alignment(&aligned_query, &subspace);
596
+
597
+ // Query orthogonal to subspace
598
+ let orthogonal_query = make_point(vec![0.0, 0.0, 1.0]);
599
+ let orthogonal_score = query_subspace_alignment(&orthogonal_query, &subspace);
600
+
601
+ println!("Aligned query score: {:.3}", aligned_score);
602
+ println!("Orthogonal query score: {:.3}", orthogonal_score);
603
+
604
+ assert!(aligned_score > orthogonal_score,
605
+ "Aligned query should score higher than orthogonal query");
606
+ }
607
+
608
+ #[test]
609
+ fn test_spread_and_isotropy() {
610
+ let mut tight = Subspace::new(3);
611
+ let mut spread_out = Subspace::new(3);
612
+
613
+ // Tight cluster
614
+ for _ in 0..20 {
615
+ tight.add_point(&make_point(vec![0.9, 0.1, 0.05]));
616
+ }
617
+
618
+ // Spread out cluster
619
+ for i in 0..20 {
620
+ let angle = i as f32 * 0.3;
621
+ spread_out.add_point(&make_point(vec![
622
+ angle.cos(),
623
+ angle.sin(),
624
+ 0.1
625
+ ]));
626
+ }
627
+
628
+ tight.recompute_subspace(3);
629
+ spread_out.recompute_subspace(3);
630
+
631
+ let tight_spread = subspace_spread(&tight);
632
+ let wide_spread = subspace_spread(&spread_out);
633
+
634
+ println!("Tight cluster spread: {:.6}", tight_spread);
635
+ println!("Wide cluster spread: {:.6}", wide_spread);
636
+
637
+ // Note: with unit-normalized input vectors the spread comparison is not
638
+ // guaranteed to favor the wide cluster; this test only checks that the computation runs
639
+ }
640
+ }
src/adapters/mod.rs ADDED
@@ -0,0 +1,19 @@
1
+ //! # Adapters
2
+ //!
3
+ //! Swappable implementations of port traits.
4
+ //!
5
+ //! This is where the hexagonal architecture meets reality:
6
+ //! - Storage adapters: Memory, NVMe
7
+ //! - Index adapters: Flat (brute force), HAT (hierarchical)
8
+ //! - Attention state serialization
9
+ //! - Python bindings (when enabled)
10
+ //!
11
+ //! Each adapter implements one or more port traits.
12
+ //! Adapters can be swapped without changing core logic.
13
+
14
+ pub mod storage;
15
+ pub mod index;
16
+ pub mod attention;
17
+
18
+ #[cfg(feature = "python")]
19
+ pub mod python;
src/adapters/python.rs ADDED
@@ -0,0 +1,502 @@
1
+ //! # Python Bindings
2
+ //!
3
+ //! PyO3 bindings for ARMS-HAT, enabling Python integration with LLMs.
4
+ //!
5
+ //! ## Python API
6
+ //!
7
+ //! ```python
8
+ //! from arms_hat import HatIndex, SearchResult
9
+ //!
10
+ //! # Create index for OpenAI embeddings (1536 dims)
11
+ //! index = HatIndex.cosine(1536)
12
+ //!
13
+ //! # Add embeddings
14
+ //! id = index.add([0.1, 0.2, ...]) # Auto-generates ID
15
+ //! index.add_with_id("custom_id", [0.1, 0.2, ...]) # Custom ID
16
+ //!
17
+ //! # Query
18
+ //! results = index.near([0.1, 0.2, ...], k=10)
19
+ //! for result in results:
20
+ //! print(f"{result.id}: {result.score}")
21
+ //!
22
+ //! # Session management
23
+ //! index.new_session()
24
+ //! index.new_document()
25
+ //!
26
+ //! # Persistence
27
+ //! index.save("memory.hat")
28
+ //! loaded = HatIndex.load("memory.hat")
29
+ //! ```
30
+
31
+ use pyo3::prelude::*;
32
+ use pyo3::exceptions::{PyValueError, PyIOError};
33
+
34
+ use crate::core::{Id, Point};
35
+ use crate::adapters::index::{HatIndex as RustHatIndex, HatConfig, ConsolidationConfig, Consolidate};
36
+ use crate::ports::Near;
37
+
38
+ /// Python wrapper for search results
39
+ #[pyclass(name = "SearchResult")]
40
+ #[derive(Clone)]
41
+ pub struct PySearchResult {
42
+ /// The ID as a hex string
43
+ #[pyo3(get)]
44
+ pub id: String,
45
+
46
+ /// The similarity/distance score
47
+ #[pyo3(get)]
48
+ pub score: f32,
49
+ }
50
+
51
+ #[pymethods]
52
+ impl PySearchResult {
53
+ fn __repr__(&self) -> String {
54
+ format!("SearchResult(id='{}', score={:.4})", self.id, self.score)
55
+ }
56
+
57
+ fn __str__(&self) -> String {
58
+ format!("{}: {:.4}", self.id, self.score)
59
+ }
60
+ }
61
+
62
+ /// Python wrapper for HAT index configuration
63
+ #[pyclass(name = "HatConfig")]
64
+ #[derive(Clone)]
65
+ pub struct PyHatConfig {
66
+ inner: HatConfig,
67
+ }
68
+
69
+ #[pymethods]
70
+ impl PyHatConfig {
71
+ #[new]
72
+ fn new() -> Self {
73
+ Self { inner: HatConfig::default() }
74
+ }
75
+
76
+ /// Set beam width for search (default: 3)
77
+ fn with_beam_width(mut slf: PyRefMut<'_, Self>, width: usize) -> PyRefMut<'_, Self> {
78
+ slf.inner.beam_width = width;
79
+ slf
80
+ }
81
+
82
+ /// Set temporal weight (0.0 = pure semantic, 1.0 = pure temporal)
83
+ fn with_temporal_weight(mut slf: PyRefMut<'_, Self>, weight: f32) -> PyRefMut<'_, Self> {
84
+ slf.inner.temporal_weight = weight;
85
+ slf
86
+ }
87
+
88
+ /// Set propagation threshold for sparse updates
89
+ fn with_propagation_threshold(mut slf: PyRefMut<'_, Self>, threshold: f32) -> PyRefMut<'_, Self> {
90
+ slf.inner.propagation_threshold = threshold;
91
+ slf
92
+ }
93
+
94
+ fn __repr__(&self) -> String {
95
+ format!(
96
+ "HatConfig(beam_width={}, temporal_weight={:.2}, propagation_threshold={:.3})",
97
+ self.inner.beam_width, self.inner.temporal_weight, self.inner.propagation_threshold
98
+ )
99
+ }
100
+ }
101
+
102
+ /// Session summary for coarse-grained retrieval
103
+ #[pyclass(name = "SessionSummary")]
104
+ #[derive(Clone)]
105
+ pub struct PySessionSummary {
106
+ #[pyo3(get)]
107
+ pub id: String,
108
+
109
+ #[pyo3(get)]
110
+ pub score: f32,
111
+
112
+ #[pyo3(get)]
113
+ pub chunk_count: usize,
114
+
115
+ #[pyo3(get)]
116
+ pub timestamp_ms: u64,
117
+ }
118
+
119
+ #[pymethods]
120
+ impl PySessionSummary {
121
+ fn __repr__(&self) -> String {
122
+ format!(
123
+ "SessionSummary(id='{}', score={:.4}, chunks={})",
124
+ self.id, self.score, self.chunk_count
125
+ )
126
+ }
127
+ }
128
+
129
+ /// Document summary for mid-level retrieval
130
+ #[pyclass(name = "DocumentSummary")]
131
+ #[derive(Clone)]
132
+ pub struct PyDocumentSummary {
133
+ #[pyo3(get)]
134
+ pub id: String,
135
+
136
+ #[pyo3(get)]
137
+ pub score: f32,
138
+
139
+ #[pyo3(get)]
140
+ pub chunk_count: usize,
141
+ }
142
+
143
+ #[pymethods]
144
+ impl PyDocumentSummary {
145
+ fn __repr__(&self) -> String {
146
+ format!(
147
+ "DocumentSummary(id='{}', score={:.4}, chunks={})",
148
+ self.id, self.score, self.chunk_count
149
+ )
150
+ }
151
+ }
152
+
153
+ /// Index statistics
154
+ #[pyclass(name = "HatStats")]
155
+ #[derive(Clone)]
156
+ pub struct PyHatStats {
157
+ #[pyo3(get)]
158
+ pub global_count: usize,
159
+
160
+ #[pyo3(get)]
161
+ pub session_count: usize,
162
+
163
+ #[pyo3(get)]
164
+ pub document_count: usize,
165
+
166
+ #[pyo3(get)]
167
+ pub chunk_count: usize,
168
+ }
169
+
170
+ #[pymethods]
171
+ impl PyHatStats {
172
+ /// Total number of indexed points
173
+ #[getter]
174
+ fn total_points(&self) -> usize {
175
+ self.chunk_count
176
+ }
177
+
178
+ fn __repr__(&self) -> String {
179
+ format!(
180
+ "HatStats(points={}, sessions={}, documents={}, chunks={})",
181
+ self.chunk_count, self.session_count, self.document_count, self.chunk_count
182
+ )
183
+ }
184
+ }
185
+
186
+ /// Hierarchical Attention Tree Index
187
+ ///
188
+ /// A semantic memory index optimized for conversation history retrieval.
189
+ /// Uses hierarchical structure (session -> document -> chunk) to enable
190
+ /// O(log n) queries while maintaining high recall.
191
+ #[pyclass(name = "HatIndex")]
192
+ pub struct PyHatIndex {
193
+ inner: RustHatIndex,
194
+ }
195
+
196
+ #[pymethods]
197
+ impl PyHatIndex {
198
+ /// Create a new HAT index with cosine similarity
199
+ ///
200
+ /// Args:
201
+ /// dimensionality: Number of embedding dimensions (e.g., 1536 for OpenAI)
202
+ #[staticmethod]
203
+ fn cosine(dimensionality: usize) -> Self {
204
+ Self {
205
+ inner: RustHatIndex::cosine(dimensionality),
206
+ }
207
+ }
208
+
209
+ /// Create a new HAT index with custom configuration
210
+ ///
211
+ /// Args:
212
+ /// dimensionality: Number of embedding dimensions
213
+ /// config: HatConfig instance
214
+ #[staticmethod]
215
+ fn with_config(dimensionality: usize, config: &PyHatConfig) -> Self {
216
+ Self {
217
+ inner: RustHatIndex::cosine(dimensionality).with_config(config.inner.clone()),
218
+ }
219
+ }
220
+
221
+ /// Add an embedding to the index
222
+ ///
223
+ /// Args:
224
+ /// embedding: List of floats (must match dimensionality)
225
+ ///
226
+ /// Returns:
227
+ /// str: The generated ID as a hex string
228
+ fn add(&mut self, embedding: Vec<f32>) -> PyResult<String> {
229
+ let point = Point::new(embedding);
230
+ let id = Id::now();
231
+
232
+ self.inner.add(id, &point)
233
+ .map_err(|e| PyValueError::new_err(format!("{}", e)))?;
234
+
235
+ Ok(format!("{}", id))
236
+ }
237
+
238
+ /// Add an embedding with a custom ID
239
+ ///
240
+ /// Args:
241
+ /// id_hex: 32-character hex string for the ID
242
+ /// embedding: List of floats (must match dimensionality)
243
+ fn add_with_id(&mut self, id_hex: &str, embedding: Vec<f32>) -> PyResult<()> {
244
+ let id = parse_id_hex(id_hex)?;
245
+ let point = Point::new(embedding);
246
+
247
+ self.inner.add(id, &point)
248
+ .map_err(|e| PyValueError::new_err(format!("{}", e)))?;
249
+
250
+ Ok(())
251
+ }
252
+
253
+ /// Find k nearest neighbors to a query embedding
254
+ ///
255
+ /// Args:
256
+ /// query: Query embedding (list of floats)
257
+ /// k: Number of results to return
258
+ ///
259
+ /// Returns:
260
+ /// List[SearchResult]: Results sorted by relevance (best first)
261
+ fn near(&self, query: Vec<f32>, k: usize) -> PyResult<Vec<PySearchResult>> {
262
+ let point = Point::new(query);
263
+
264
+ let results = self.inner.near(&point, k)
265
+ .map_err(|e| PyValueError::new_err(format!("{}", e)))?;
266
+
267
+ Ok(results.into_iter().map(|r| PySearchResult {
268
+ id: format!("{}", r.id),
269
+ score: r.score,
270
+ }).collect())
271
+ }
272
+
273
+ /// Start a new session (conversation boundary)
274
+ ///
275
+ /// Call this when starting a new conversation or context.
276
+ fn new_session(&mut self) {
277
+ self.inner.new_session();
278
+ }
279
+
280
+ /// Start a new document within the current session
281
+ ///
282
+ /// Call this for logical groupings within a conversation
283
+ /// (e.g., topic change, user turn).
284
+ fn new_document(&mut self) {
285
+ self.inner.new_document();
286
+ }
287
+
288
+ /// Get index statistics
289
+ fn stats(&self) -> PyHatStats {
290
+ let s = self.inner.stats();
291
+ PyHatStats {
292
+ global_count: s.global_count,
293
+ session_count: s.session_count,
294
+ document_count: s.document_count,
295
+ chunk_count: s.chunk_count,
296
+ }
297
+ }
298
+
299
+ /// Get the number of indexed points
300
+ fn __len__(&self) -> usize {
301
+ self.inner.len()
302
+ }
303
+
304
+ /// Check if the index is empty
305
+ fn is_empty(&self) -> bool {
306
+ self.inner.is_empty()
307
+ }
308
+
309
+ /// Remove a point by ID
310
+ ///
311
+ /// Args:
312
+ /// id_hex: 32-character hex string for the ID
313
+ fn remove(&mut self, id_hex: &str) -> PyResult<()> {
314
+ let id = parse_id_hex(id_hex)?;
315
+
316
+ self.inner.remove(id)
317
+ .map_err(|e| PyValueError::new_err(format!("{}", e)))?;
318
+
319
+ Ok(())
320
+ }
321
+
322
+ /// Find similar sessions (coarse-grained search)
323
+ ///
324
+ /// Args:
325
+ /// query: Query embedding
326
+ /// k: Number of sessions to return
327
+ ///
328
+ /// Returns:
329
+ /// List[SessionSummary]: Most relevant sessions
330
+ fn near_sessions(&self, query: Vec<f32>, k: usize) -> PyResult<Vec<PySessionSummary>> {
331
+ let point = Point::new(query);
332
+
333
+ let results = self.inner.near_sessions(&point, k)
334
+ .map_err(|e| PyValueError::new_err(format!("{}", e)))?;
335
+
336
+ Ok(results.into_iter().map(|s| PySessionSummary {
337
+ id: format!("{}", s.id),
338
+ score: s.score,
339
+ chunk_count: s.chunk_count,
340
+ timestamp_ms: s.timestamp,
341
+ }).collect())
342
+ }
343
+
344
+ /// Find similar documents within a session
345
+ ///
346
+ /// Args:
347
+ /// session_id: Session ID (hex string)
348
+ /// query: Query embedding
349
+ /// k: Number of documents to return
350
+ ///
351
+ /// Returns:
352
+ /// List[DocumentSummary]: Most relevant documents in the session
353
+ fn near_documents(&self, session_id: &str, query: Vec<f32>, k: usize) -> PyResult<Vec<PyDocumentSummary>> {
354
+ let sid = parse_id_hex(session_id)?;
355
+ let point = Point::new(query);
356
+
357
+ let results = self.inner.near_documents(sid, &point, k)
358
+ .map_err(|e| PyValueError::new_err(format!("{}", e)))?;
359
+
360
+ Ok(results.into_iter().map(|d| PyDocumentSummary {
361
+ id: format!("{}", d.id),
362
+ score: d.score,
363
+ chunk_count: d.chunk_count,
364
+ }).collect())
365
+ }
366
+
367
+ /// Find chunks within a specific document
368
+ ///
369
+ /// Args:
370
+ /// doc_id: Document ID (hex string)
371
+ /// query: Query embedding
372
+ /// k: Number of results to return
373
+ ///
374
+ /// Returns:
375
+ /// List[SearchResult]: Most relevant chunks in the document
376
+ fn near_in_document(&self, doc_id: &str, query: Vec<f32>, k: usize) -> PyResult<Vec<PySearchResult>> {
377
+ let did = parse_id_hex(doc_id)?;
378
+ let point = Point::new(query);
379
+
380
+ let results = self.inner.near_in_document(did, &point, k)
381
+ .map_err(|e| PyValueError::new_err(format!("{}", e)))?;
382
+
383
+ Ok(results.into_iter().map(|r| PySearchResult {
384
+ id: format!("{}", r.id),
385
+ score: r.score,
386
+ }).collect())
387
+ }
388
+
389
+ /// Run light consolidation (background maintenance)
390
+ ///
391
+ /// This optimizes the index structure. Call periodically
392
+ /// (e.g., after every 100 inserts).
393
+ fn consolidate(&mut self) {
394
+ self.inner.consolidate(ConsolidationConfig::light());
395
+ }
396
+
397
+ /// Run full consolidation (more thorough optimization)
398
+ fn consolidate_full(&mut self) {
399
+ self.inner.consolidate(ConsolidationConfig::full());
400
+ }
401
+
402
+ /// Save the index to a file
403
+ ///
404
+ /// Args:
405
+ /// path: File path to save to
406
+ fn save(&self, path: &str) -> PyResult<()> {
407
+ self.inner.save_to_file(std::path::Path::new(path))
408
+ .map_err(|e| PyIOError::new_err(format!("{}", e)))
409
+ }
410
+
411
+ /// Load an index from a file
412
+ ///
413
+ /// Args:
414
+ /// path: File path to load from
415
+ ///
416
+ /// Returns:
417
+ /// HatIndex: The loaded index
418
+ #[staticmethod]
419
+ fn load(path: &str) -> PyResult<Self> {
420
+ let inner = RustHatIndex::load_from_file(std::path::Path::new(path))
421
+ .map_err(|e| PyIOError::new_err(format!("{}", e)))?;
422
+
423
+ Ok(Self { inner })
424
+ }
425
+
426
+ /// Serialize the index to bytes
427
+ ///
428
+ /// Returns:
429
+ /// bytes: Serialized index data
430
+ fn to_bytes<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, pyo3::types::PyBytes>> {
431
+ let data = self.inner.to_bytes()
432
+ .map_err(|e| PyIOError::new_err(format!("{}", e)))?;
433
+ Ok(pyo3::types::PyBytes::new_bound(py, &data))
434
+ }
435
+
436
+ /// Load an index from bytes
437
+ ///
438
+ /// Args:
439
+ /// data: Serialized index data
440
+ ///
441
+ /// Returns:
442
+ /// HatIndex: The loaded index
443
+ #[staticmethod]
444
+ fn from_bytes(data: &[u8]) -> PyResult<Self> {
445
+ let inner = RustHatIndex::from_bytes(data)
446
+ .map_err(|e| PyIOError::new_err(format!("{}", e)))?;
447
+
448
+ Ok(Self { inner })
449
+ }
450
+
451
+ fn __repr__(&self) -> String {
452
+ let stats = self.inner.stats();
453
+ format!(
454
+ "HatIndex(points={}, sessions={})",
455
+ stats.chunk_count, stats.session_count
456
+ )
457
+ }
458
+ }
459
+
460
+ /// Parse a hex string to an Id
461
+ fn parse_id_hex(hex: &str) -> PyResult<Id> {
462
+ if hex.len() != 32 {
463
+ return Err(PyValueError::new_err(
464
+ format!("ID must be 32 hex characters, got {}", hex.len())
465
+ ));
466
+ }
467
+
468
+ let mut bytes = [0u8; 16];
469
+ for (i, chunk) in hex.as_bytes().chunks(2).enumerate() {
470
+ let high = hex_char_to_nibble(chunk[0])?;
471
+ let low = hex_char_to_nibble(chunk[1])?;
472
+ bytes[i] = (high << 4) | low;
473
+ }
474
+
475
+ Ok(Id::from_bytes(bytes))
476
+ }
477
+
478
+ fn hex_char_to_nibble(c: u8) -> PyResult<u8> {
479
+ match c {
480
+ b'0'..=b'9' => Ok(c - b'0'),
481
+ b'a'..=b'f' => Ok(c - b'a' + 10),
482
+ b'A'..=b'F' => Ok(c - b'A' + 10),
483
+ _ => Err(PyValueError::new_err(format!("Invalid hex character: {}", c as char))),
484
+ }
485
+ }
486
+
487
+ /// ARMS-HAT Python module
488
+ #[pymodule]
489
+ fn arms_hat(m: &Bound<'_, PyModule>) -> PyResult<()> {
490
+ m.add_class::<PyHatIndex>()?;
491
+ m.add_class::<PyHatConfig>()?;
492
+ m.add_class::<PySearchResult>()?;
493
+ m.add_class::<PySessionSummary>()?;
494
+ m.add_class::<PyDocumentSummary>()?;
495
+ m.add_class::<PyHatStats>()?;
496
+
497
+ // Add module docstring
498
+ m.add("__doc__", "ARMS-HAT: Hierarchical Attention Tree for AI memory retrieval")?;
499
+ m.add("__version__", env!("CARGO_PKG_VERSION"))?;
500
+
501
+ Ok(())
502
+ }
src/adapters/storage/memory.rs ADDED
@@ -0,0 +1,253 @@
1
+ //! # Memory Storage Adapter
2
+ //!
3
+ //! In-memory storage using HashMap.
4
+ //! Fast, but volatile (data lost on shutdown).
5
+ //!
6
+ //! Good for:
7
+ //! - Testing
8
+ //! - Hot tier storage
9
+ //! - Small datasets
10
+
11
+ use std::collections::HashMap;
12
+
13
+ use crate::core::{Blob, Id, PlacedPoint, Point};
14
+ use crate::ports::{Place, PlaceError, PlaceResult};
15
+
16
+ /// In-memory storage adapter
17
+ pub struct MemoryStorage {
18
+ /// The stored points
19
+ points: HashMap<Id, PlacedPoint>,
20
+
21
+ /// Expected dimensionality
22
+ dimensionality: usize,
23
+
24
+ /// Maximum capacity in bytes (0 = unlimited)
25
+ capacity: usize,
26
+
27
+ /// Current size in bytes
28
+ current_size: usize,
29
+ }
30
+
31
+ impl MemoryStorage {
32
+ /// Create a new memory storage with specified dimensionality
33
+ pub fn new(dimensionality: usize) -> Self {
34
+ Self {
35
+ points: HashMap::new(),
36
+ dimensionality,
37
+ capacity: 0,
38
+ current_size: 0,
39
+ }
40
+ }
41
+
42
+ /// Create with a capacity limit
43
+ pub fn with_capacity(dimensionality: usize, capacity: usize) -> Self {
44
+ Self {
45
+ points: HashMap::new(),
46
+ dimensionality,
47
+ capacity,
48
+ current_size: 0,
49
+ }
50
+ }
51
+
52
+ /// Calculate size of a placed point in bytes
53
+ fn point_size(point: &PlacedPoint) -> usize {
54
+ // Id: 16 bytes
55
+ // Point: dims.len() * 4 bytes (f32)
56
+ // Blob: data.len() bytes
57
+ // Overhead: ~48 bytes for struct padding and HashMap entry
58
+ 16 + (point.point.dimensionality() * 4) + point.blob.size() + 48
59
+ }
60
+ }
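To make the accounting concrete: a 1536-dimensional point (an OpenAI-sized embedding) with an empty blob is charged 16 + 1536 × 4 + 0 + 48 = 6,208 bytes under this estimate, so a `with_capacity(1536, 1 << 30)` store (1 GiB) would admit roughly 170,000 such points before reporting `CapacityExceeded`.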
61
+
62
+ impl Place for MemoryStorage {
63
+ fn place(&mut self, point: Point, blob: Blob) -> PlaceResult<Id> {
64
+ // Check dimensionality
65
+ if point.dimensionality() != self.dimensionality {
66
+ return Err(PlaceError::DimensionalityMismatch {
67
+ expected: self.dimensionality,
68
+ got: point.dimensionality(),
69
+ });
70
+ }
71
+
72
+ let id = Id::now();
73
+ let placed = PlacedPoint::new(id, point, blob);
74
+
75
+ // Check capacity
76
+ let size = Self::point_size(&placed);
77
+ if self.capacity > 0 && self.current_size + size > self.capacity {
78
+ return Err(PlaceError::CapacityExceeded);
79
+ }
80
+
81
+ self.current_size += size;
82
+ self.points.insert(id, placed);
83
+
84
+ Ok(id)
85
+ }
86
+
87
+ fn place_with_id(&mut self, id: Id, point: Point, blob: Blob) -> PlaceResult<()> {
88
+ // Check dimensionality
89
+ if point.dimensionality() != self.dimensionality {
90
+ return Err(PlaceError::DimensionalityMismatch {
91
+ expected: self.dimensionality,
92
+ got: point.dimensionality(),
93
+ });
94
+ }
95
+
96
+ // Check for duplicates
97
+ if self.points.contains_key(&id) {
98
+ return Err(PlaceError::DuplicateId(id));
99
+ }
100
+
101
+ let placed = PlacedPoint::new(id, point, blob);
102
+
103
+ // Check capacity
104
+ let size = Self::point_size(&placed);
105
+ if self.capacity > 0 && self.current_size + size > self.capacity {
106
+ return Err(PlaceError::CapacityExceeded);
107
+ }
108
+
109
+ self.current_size += size;
110
+ self.points.insert(id, placed);
111
+
112
+ Ok(())
113
+ }
114
+
115
+ fn remove(&mut self, id: Id) -> Option<PlacedPoint> {
116
+ if let Some(placed) = self.points.remove(&id) {
117
+ self.current_size -= Self::point_size(&placed);
118
+ Some(placed)
119
+ } else {
120
+ None
121
+ }
122
+ }
+
+    fn get(&self, id: Id) -> Option<&PlacedPoint> {
+        self.points.get(&id)
+    }
+
+    fn len(&self) -> usize {
+        self.points.len()
+    }
+
+    fn iter(&self) -> Box<dyn Iterator<Item = &PlacedPoint> + '_> {
+        Box::new(self.points.values())
+    }
+
+    fn size_bytes(&self) -> usize {
+        self.current_size
+    }
+
+    fn clear(&mut self) {
+        self.points.clear();
+        self.current_size = 0;
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_memory_storage_place() {
+        let mut storage = MemoryStorage::new(3);
+
+        let point = Point::new(vec![1.0, 2.0, 3.0]);
+        let blob = Blob::from_str("test");
+
+        let id = storage.place(point, blob).unwrap();
+
+        assert_eq!(storage.len(), 1);
+        assert!(storage.contains(id));
+    }
+
+    #[test]
+    fn test_memory_storage_get() {
+        let mut storage = MemoryStorage::new(3);
+
+        let point = Point::new(vec![1.0, 2.0, 3.0]);
+        let blob = Blob::from_str("hello");
+
+        let id = storage.place(point, blob).unwrap();
+
+        let retrieved = storage.get(id).unwrap();
+        assert_eq!(retrieved.blob.as_str(), Some("hello"));
+    }
+
+    #[test]
+    fn test_memory_storage_remove() {
+        let mut storage = MemoryStorage::new(3);
+
+        let point = Point::new(vec![1.0, 2.0, 3.0]);
+        let id = storage.place(point, Blob::empty()).unwrap();
+
+        assert_eq!(storage.len(), 1);
+
+        let removed = storage.remove(id);
+        assert!(removed.is_some());
+        assert_eq!(storage.len(), 0);
+        assert!(!storage.contains(id));
+    }
+
+    #[test]
+    fn test_memory_storage_dimensionality_check() {
+        let mut storage = MemoryStorage::new(3);
+
+        let wrong_dims = Point::new(vec![1.0, 2.0]); // 2 dims, expected 3
+
+        let result = storage.place(wrong_dims, Blob::empty());
+
+        match result {
+            Err(PlaceError::DimensionalityMismatch { expected, got }) => {
+                assert_eq!(expected, 3);
+                assert_eq!(got, 2);
+            }
+            _ => panic!("Expected DimensionalityMismatch error"),
+        }
+    }
+
+    #[test]
+    fn test_memory_storage_capacity() {
+        // Small capacity - enough for one point but not two
+        // Point size: 16 (id) + 12 (3 f32s) + 10 (blob) + 48 (overhead) = 86 bytes
+        let mut storage = MemoryStorage::with_capacity(3, 150);
+
+        let point = Point::new(vec![1.0, 2.0, 3.0]);
+        let blob = Blob::new(vec![0u8; 10]); // Small blob
+
+        // First one should succeed
+        storage.place(point.clone(), blob.clone()).unwrap();
+
+        // Second should fail due to capacity
+        let result = storage.place(point, blob);
+        assert!(matches!(result, Err(PlaceError::CapacityExceeded)));
+    }
+
+    #[test]
+    fn test_memory_storage_clear() {
+        let mut storage = MemoryStorage::new(3);
+
+        for i in 0..10 {
+            let point = Point::new(vec![i as f32, 0.0, 0.0]);
+            storage.place(point, Blob::empty()).unwrap();
+        }
+
+        assert_eq!(storage.len(), 10);
+        assert!(storage.size_bytes() > 0);
+
+        storage.clear();
+
+        assert_eq!(storage.len(), 0);
+        assert_eq!(storage.size_bytes(), 0);
+    }
+
+    #[test]
+    fn test_memory_storage_iter() {
+        let mut storage = MemoryStorage::new(2);
+
+        storage.place(Point::new(vec![1.0, 0.0]), Blob::empty()).unwrap();
+        storage.place(Point::new(vec![0.0, 1.0]), Blob::empty()).unwrap();
+
+        let points: Vec<_> = storage.iter().collect();
+        assert_eq!(points.len(), 2);
+    }
+}
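The capacity test above encodes the per-point cost model (16 bytes of id + 4 bytes per dimension + blob bytes + 48 bytes of overhead). Below is a minimal sketch, not part of the diff, of using that model to size a bounded hot tier; the `use` path for `MemoryStorage` is an assumption, since the crate-root re-exports are not shown here.

```rust
// Sketch only: sizing a bounded hot tier with the cost model from the test.
// The import path is hypothetical; only `with_capacity`, `place`, and
// `size_bytes` are taken from the diff above.
use arms::adapters::storage::MemoryStorage; // hypothetical path
use arms::{Blob, Point};

fn main() {
    let dims = 3;
    let blob_len = 10;
    // Per-point estimate: id + dims * f32 + blob + bookkeeping overhead.
    let per_point = 16 + dims * 4 + blob_len + 48; // 86 bytes for this shape

    // Budget roughly 100 points, then fill until capacity is exceeded.
    let mut hot = MemoryStorage::with_capacity(dims, 100 * per_point);
    let mut placed = 0;
    for i in 0..200 {
        let point = Point::new(vec![i as f32, 0.0, 0.0]);
        match hot.place(point, Blob::new(vec![0u8; blob_len])) {
            Ok(_) => placed += 1,
            Err(_) => break, // CapacityExceeded once the budget is spent
        }
    }
    println!("hot tier held {placed} points in {} bytes", hot.size_bytes());
}
```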
src/adapters/storage/mod.rs ADDED
@@ -0,0 +1,15 @@
+//! # Storage Adapters
+//!
+//! Implementations of the Place port for different storage backends.
+//!
+//! Available adapters:
+//! - `MemoryStorage` - In-memory HashMap (fast, volatile)
+//! - `NvmeStorage` - Memory-mapped NVMe (persistent, large) [TODO]
+
+mod memory;
+
+pub use memory::MemoryStorage;
+
+// TODO: Add NVMe adapter
+// mod nvme;
+// pub use nvme::NvmeStorage;
src/core/blob.rs ADDED
@@ -0,0 +1,152 @@
+//! # Blob
+//!
+//! Raw payload data attached to a point.
+//!
+//! ARMS doesn't interpret this data - it's yours.
+//! Could be: tensor bytes, text, compressed state, anything.
+//!
+//! Separation of concerns:
+//! - Point = WHERE (position in space)
+//! - Blob = WHAT (the actual data)
+
+/// Raw data attached to a point
+///
+/// ARMS stores this opaquely. You define what it means.
+#[derive(Clone, Debug, PartialEq)]
+pub struct Blob {
+    data: Vec<u8>,
+}
+
+impl Blob {
+    /// Create a new blob from bytes
+    ///
+    /// # Example
+    /// ```
+    /// use arms::Blob;
+    /// let blob = Blob::new(vec![1, 2, 3, 4]);
+    /// assert_eq!(blob.size(), 4);
+    /// ```
+    pub fn new(data: Vec<u8>) -> Self {
+        Self { data }
+    }
+
+    /// Create an empty blob
+    ///
+    /// Useful when you only care about position, not payload.
+    pub fn empty() -> Self {
+        Self { data: vec![] }
+    }
+
+    /// Create a blob from a string (UTF-8 bytes)
+    ///
+    /// # Example
+    /// ```
+    /// use arms::Blob;
+    /// let blob = Blob::from_str("hello");
+    /// assert_eq!(blob.as_str(), Some("hello"));
+    /// ```
+    pub fn from_str(s: &str) -> Self {
+        Self {
+            data: s.as_bytes().to_vec(),
+        }
+    }
+
+    /// Get the raw bytes
+    pub fn data(&self) -> &[u8] {
+        &self.data
+    }
+
+    /// Get the size in bytes
+    pub fn size(&self) -> usize {
+        self.data.len()
+    }
+
+    /// Check if the blob is empty
+    pub fn is_empty(&self) -> bool {
+        self.data.is_empty()
+    }
+
+    /// Try to interpret as UTF-8 string
+    pub fn as_str(&self) -> Option<&str> {
+        std::str::from_utf8(&self.data).ok()
+    }
+
+    /// Consume and return the inner data
+    pub fn into_inner(self) -> Vec<u8> {
+        self.data
+    }
+}
+
+impl From<Vec<u8>> for Blob {
+    fn from(data: Vec<u8>) -> Self {
+        Self::new(data)
+    }
+}
+
+impl From<&[u8]> for Blob {
+    fn from(data: &[u8]) -> Self {
+        Self::new(data.to_vec())
+    }
+}
+
+impl From<&str> for Blob {
+    fn from(s: &str) -> Self {
+        Self::from_str(s)
+    }
+}
+
+impl From<String> for Blob {
+    fn from(s: String) -> Self {
+        Self::new(s.into_bytes())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_blob_new() {
+        let blob = Blob::new(vec![1, 2, 3]);
+        assert_eq!(blob.data(), &[1, 2, 3]);
+        assert_eq!(blob.size(), 3);
+    }
+
+    #[test]
+    fn test_blob_empty() {
+        let blob = Blob::empty();
+        assert!(blob.is_empty());
+        assert_eq!(blob.size(), 0);
+    }
+
+    #[test]
+    fn test_blob_from_str() {
+        let blob = Blob::from_str("hello world");
+        assert_eq!(blob.as_str(), Some("hello world"));
+    }
+
+    #[test]
+    fn test_blob_as_str_invalid_utf8() {
+        let blob = Blob::new(vec![0xff, 0xfe]);
+        assert_eq!(blob.as_str(), None);
+    }
+
+    #[test]
+    fn test_blob_from_conversions() {
+        let blob1: Blob = vec![1, 2, 3].into();
+        assert_eq!(blob1.size(), 3);
+
+        let blob2: Blob = "test".into();
+        assert_eq!(blob2.as_str(), Some("test"));
+
+        let blob3: Blob = String::from("test").into();
+        assert_eq!(blob3.as_str(), Some("test"));
+    }
+
+    #[test]
+    fn test_blob_into_inner() {
+        let blob = Blob::new(vec![1, 2, 3]);
+        let data = blob.into_inner();
+        assert_eq!(data, vec![1, 2, 3]);
+    }
+}
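Since the blob is opaque, any byte encoding works. A minimal sketch, not part of the diff, of one way to pack structured data into a `Blob` using only std (the `arms::Blob` import follows the doctests above; the length-prefixed format is purely illustrative):

```rust
// Sketch only: a tiny length-prefixed pair of UTF-8 fields inside a Blob.
// ARMS never interprets these bytes; the encode/decode pair is ours.
use arms::Blob;

fn encode(title: &str, body: &str) -> Blob {
    let mut bytes = Vec::new();
    // 4-byte little-endian length prefix for the title, then both fields.
    bytes.extend_from_slice(&(title.len() as u32).to_le_bytes());
    bytes.extend_from_slice(title.as_bytes());
    bytes.extend_from_slice(body.as_bytes());
    Blob::new(bytes)
}

fn decode(blob: &Blob) -> (String, String) {
    let data = blob.data();
    let n = u32::from_le_bytes(data[..4].try_into().unwrap()) as usize;
    let title = String::from_utf8(data[4..4 + n].to_vec()).unwrap();
    let body = String::from_utf8(data[4 + n..].to_vec()).unwrap();
    (title, body)
}

fn main() {
    let blob = encode("note", "ARMS stores this opaquely");
    let (title, body) = decode(&blob);
    assert_eq!(title, "note");
    assert_eq!(body, "ARMS stores this opaquely");
}
```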
src/core/config.rs ADDED
@@ -0,0 +1,177 @@
+//! # Configuration
+//!
+//! ARMS configuration - define your space.
+//!
+//! Everything is configurable, not hardcoded:
+//! - Dimensionality
+//! - Proximity function
+//! - Merge function
+//! - Tier settings
+//!
+//! "If we say it's a rock now, in 2 years it can never be carved into a wheel."

+use super::proximity::{Cosine, Proximity};
+use super::merge::{Mean, Merge};
+use std::sync::Arc;
+
+/// Main ARMS configuration
+///
+/// Defines the dimensional space and default operations.
+#[derive(Clone)]
+pub struct ArmsConfig {
+    /// Dimensionality of the space
+    ///
+    /// Set this to match your model's hidden size.
+    /// Examples: 768 (BERT), 1024 (GPT-2 medium), 4096 (large models)
+    pub dimensionality: usize,
+
+    /// Proximity function for similarity calculations
+    pub proximity: Arc<dyn Proximity>,
+
+    /// Merge function for hierarchical composition
+    pub merge: Arc<dyn Merge>,
+
+    /// Whether to normalize points on insertion
+    pub normalize_on_insert: bool,
+
+    /// Tier configuration
+    pub tiers: TierConfig,
+}
+
+impl ArmsConfig {
+    /// Create a new configuration with specified dimensionality
+    ///
+    /// Uses default proximity (Cosine) and merge (Mean) functions.
+    pub fn new(dimensionality: usize) -> Self {
+        Self {
+            dimensionality,
+            proximity: Arc::new(Cosine),
+            merge: Arc::new(Mean),
+            normalize_on_insert: true,
+            tiers: TierConfig::default(),
+        }
+    }
+
+    /// Set a custom proximity function
+    pub fn with_proximity<P: Proximity + 'static>(mut self, proximity: P) -> Self {
+        self.proximity = Arc::new(proximity);
+        self
+    }
+
+    /// Set a custom merge function
+    pub fn with_merge<M: Merge + 'static>(mut self, merge: M) -> Self {
+        self.merge = Arc::new(merge);
+        self
+    }
+
+    /// Set normalization behavior
+    pub fn with_normalize(mut self, normalize: bool) -> Self {
+        self.normalize_on_insert = normalize;
+        self
+    }
+
+    /// Set tier configuration
+    pub fn with_tiers(mut self, tiers: TierConfig) -> Self {
+        self.tiers = tiers;
+        self
+    }
+}
+
+impl Default for ArmsConfig {
+    /// Default configuration: 768 dimensions, cosine proximity, mean merge
+    fn default() -> Self {
+        Self::new(768)
+    }
+}
+
+/// Tier configuration for storage management
+#[derive(Clone, Debug)]
+pub struct TierConfig {
+    /// Hot tier (RAM) capacity in bytes
+    pub hot_capacity: usize,
+
+    /// Warm tier (NVMe) capacity in bytes
+    pub warm_capacity: usize,
+
+    /// Number of accesses before promoting to hotter tier
+    pub promote_after_accesses: u32,
+
+    /// Milliseconds since last access before evicting to colder tier
+    pub evict_after_ms: u64,
+}
+
+impl TierConfig {
+    /// Create a new tier configuration
+    pub fn new(hot_capacity: usize, warm_capacity: usize) -> Self {
+        Self {
+            hot_capacity,
+            warm_capacity,
+            promote_after_accesses: 3,
+            evict_after_ms: 3600 * 1000, // 1 hour
+        }
+    }
+
+    /// Tiny config for testing
+    pub fn tiny() -> Self {
+        Self {
+            hot_capacity: 1024 * 1024,       // 1 MB
+            warm_capacity: 10 * 1024 * 1024, // 10 MB
+            promote_after_accesses: 2,
+            evict_after_ms: 60 * 1000, // 1 minute
+        }
+    }
+}
+
+impl Default for TierConfig {
+    fn default() -> Self {
+        Self {
+            hot_capacity: 1024 * 1024 * 1024,        // 1 GB
+            warm_capacity: 100 * 1024 * 1024 * 1024, // 100 GB
+            promote_after_accesses: 3,
+            evict_after_ms: 3600 * 1000, // 1 hour
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::core::proximity::Euclidean;
+    use crate::core::merge::MaxPool;
+
+    #[test]
+    fn test_default_config() {
+        let config = ArmsConfig::default();
+        assert_eq!(config.dimensionality, 768);
+        assert!(config.normalize_on_insert);
+        assert_eq!(config.proximity.name(), "cosine");
+        assert_eq!(config.merge.name(), "mean");
+    }
+
+    #[test]
+    fn test_custom_config() {
+        let config = ArmsConfig::new(4096)
+            .with_proximity(Euclidean)
+            .with_merge(MaxPool)
+            .with_normalize(false);
+
+        assert_eq!(config.dimensionality, 4096);
+        assert!(!config.normalize_on_insert);
+        assert_eq!(config.proximity.name(), "euclidean");
+        assert_eq!(config.merge.name(), "max_pool");
+    }
+
+    #[test]
+    fn test_tier_config() {
+        let tiers = TierConfig::new(1024, 2048);
+        assert_eq!(tiers.hot_capacity, 1024);
+        assert_eq!(tiers.warm_capacity, 2048);
+    }
+
+    #[test]
+    fn test_tier_tiny() {
+        let tiers = TierConfig::tiny();
+        assert_eq!(tiers.hot_capacity, 1024 * 1024);
+        assert_eq!(tiers.evict_after_ms, 60 * 1000);
+    }
+}
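A minimal sketch, not part of the diff, of the builder pattern these methods enable: configuring a 1024-dim space with a distance metric and explicit tier budgets. The import paths are assumptions (the crate-root re-exports are not shown in this diff); the methods and fields are exactly those defined above.

```rust
// Sketch only: chaining the with_* builders shown in config.rs.
use arms::core::config::{ArmsConfig, TierConfig}; // hypothetical paths
use arms::core::proximity::Euclidean;

fn main() {
    let config = ArmsConfig::new(1024)
        .with_proximity(Euclidean) // distance instead of the default cosine
        .with_normalize(false)     // keep raw magnitudes
        .with_tiers(TierConfig::new(
            512 * 1024 * 1024,       // 512 MB hot (RAM)
            50 * 1024 * 1024 * 1024, // 50 GB warm (NVMe)
        ));

    assert_eq!(config.dimensionality, 1024);
    assert_eq!(config.proximity.name(), "euclidean");
    // TierConfig::new keeps the default promotion threshold of 3 accesses.
    assert_eq!(config.tiers.promote_after_accesses, 3);
}
```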
src/core/id.rs ADDED
@@ -0,0 +1,169 @@
+//! # Id
+//!
+//! Unique identifier for placed points.
+//!
+//! Format: 128 bits = [timestamp_ms:48][counter:16][pseudo_random:64]
+//! - Timestamp provides natural temporal ordering
+//! - Counter prevents collisions within same millisecond
+//! - Pseudo-random tail (derived from timestamp and counter) fills the remaining bits
+//! - Sortable by time when compared
+//! - No external dependencies (not UUID, just bytes)
+
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::time::{SystemTime, UNIX_EPOCH};
+
+/// Global counter for uniqueness within same millisecond
+static COUNTER: AtomicU64 = AtomicU64::new(0);
+
+/// Unique identifier for a placed point
+///
+/// 128 bits, timestamp-prefixed for natural time ordering.
+#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Debug)]
+pub struct Id([u8; 16]);
+
+impl Id {
+    /// Generate a new Id for the current moment
+    ///
+    /// Uses the current timestamp, an atomic counter, and a pseudo-random tail for uniqueness.
+    pub fn now() -> Self {
+        let timestamp = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap()
+            .as_millis() as u64;
+
+        // Atomically increment counter for uniqueness
+        let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
+
+        let mut bytes = [0u8; 16];
+
+        // First 6 bytes: timestamp (48 bits)
+        bytes[0] = (timestamp >> 40) as u8;
+        bytes[1] = (timestamp >> 32) as u8;
+        bytes[2] = (timestamp >> 24) as u8;
+        bytes[3] = (timestamp >> 16) as u8;
+        bytes[4] = (timestamp >> 8) as u8;
+        bytes[5] = timestamp as u8;
+
+        // Next 2 bytes: counter (16 bits) - ensures uniqueness within millisecond
+        bytes[6] = (counter >> 8) as u8;
+        bytes[7] = counter as u8;
+
+        // Remaining 8 bytes: pseudo-random based on timestamp and counter
+        let random_seed = timestamp
+            .wrapping_mul(6364136223846793005)
+            .wrapping_add(counter);
+        bytes[8] = (random_seed >> 56) as u8;
+        bytes[9] = (random_seed >> 48) as u8;
+        bytes[10] = (random_seed >> 40) as u8;
+        bytes[11] = (random_seed >> 32) as u8;
+        bytes[12] = (random_seed >> 24) as u8;
+        bytes[13] = (random_seed >> 16) as u8;
+        bytes[14] = (random_seed >> 8) as u8;
+        bytes[15] = random_seed as u8;
+
+        Self(bytes)
+    }
+
+    /// Create an Id from raw bytes
+    pub fn from_bytes(bytes: [u8; 16]) -> Self {
+        Self(bytes)
+    }
+
+    /// Get the raw bytes
+    pub fn as_bytes(&self) -> &[u8; 16] {
+        &self.0
+    }
+
+    /// Extract the timestamp component (milliseconds since epoch)
+    pub fn timestamp_ms(&self) -> u64 {
+        ((self.0[0] as u64) << 40)
+            | ((self.0[1] as u64) << 32)
+            | ((self.0[2] as u64) << 24)
+            | ((self.0[3] as u64) << 16)
+            | ((self.0[4] as u64) << 8)
+            | (self.0[5] as u64)
+    }
+
+    /// Create a nil/zero Id (useful for testing)
+    pub fn nil() -> Self {
+        Self([0u8; 16])
+    }
+
+    /// Check if this is a nil Id
+    pub fn is_nil(&self) -> bool {
+        self.0 == [0u8; 16]
+    }
+}
+
+impl std::fmt::Display for Id {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Display as hex string
+        for byte in &self.0 {
+            write!(f, "{:02x}", byte)?;
+        }
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::thread;
+    use std::time::Duration;
+
+    #[test]
+    fn test_id_creation() {
+        let id = Id::now();
+        assert!(!id.is_nil());
+    }
+
+    #[test]
+    fn test_id_timestamp() {
+        let before = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap()
+            .as_millis() as u64;
+
+        let id = Id::now();
+
+        let after = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap()
+            .as_millis() as u64;
+
+        let ts = id.timestamp_ms();
+        assert!(ts >= before);
+        assert!(ts <= after);
+    }
+
+    #[test]
+    fn test_id_ordering() {
+        let id1 = Id::now();
+        thread::sleep(Duration::from_millis(2));
+        let id2 = Id::now();
+
+        // id2 should be greater (later timestamp)
+        assert!(id2 > id1);
+    }
+
+    #[test]
+    fn test_id_from_bytes() {
+        let bytes = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
+        let id = Id::from_bytes(bytes);
+        assert_eq!(id.as_bytes(), &bytes);
+    }
+
+    #[test]
+    fn test_id_nil() {
+        let nil = Id::nil();
+        assert!(nil.is_nil());
+        assert_eq!(nil.timestamp_ms(), 0);
+    }
+
+    #[test]
+    fn test_id_display() {
+        let id = Id::from_bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
+        let display = format!("{}", id);
+        assert_eq!(display, "000102030405060708090a0b0c0d0e0f");
+    }
+}
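Because the 48-bit millisecond timestamp occupies the high bytes, lexicographic byte order equals temporal order. A minimal sketch, not part of the diff, demonstrating that property; the `arms::Id` import assumes a crate-root re-export, which this diff does not show.

```rust
// Sketch only: Ids created later sort later, and the creation time
// is recoverable from the first 6 bytes via timestamp_ms().
use arms::Id; // hypothetical re-export
use std::thread;
use std::time::Duration;

fn main() {
    let mut ids = Vec::new();
    for _ in 0..3 {
        ids.push(Id::now());
        thread::sleep(Duration::from_millis(2)); // force distinct timestamps
    }

    // Derived Ord on the byte array == temporal order.
    let mut sorted = ids.clone();
    sorted.sort();
    assert_eq!(sorted, ids);

    for id in &ids {
        println!("{} created at {} ms since epoch", id, id.timestamp_ms());
    }
}
```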
src/core/merge.rs ADDED
@@ -0,0 +1,335 @@
+//! # Merge
+//!
+//! Trait and implementations for composing multiple points into one.
+//!
+//! This is one of the five primitives of ARMS:
+//! `Merge: fn(points) -> point` - Compose together
+//!
+//! Merge is used for hierarchical composition:
+//! - Chunks → Document
+//! - Documents → Session
+//! - Sessions → Domain
+//!
+//! Merge functions are pluggable - use whichever fits your use case.
+
+use super::Point;
+
+/// Trait for merging multiple points into one
+///
+/// Used for hierarchical composition and aggregation.
+pub trait Merge: Send + Sync {
+    /// Merge multiple points into a single point
+    ///
+    /// All points must have the same dimensionality.
+    /// The slice must not be empty.
+    fn merge(&self, points: &[Point]) -> Point;
+
+    /// Name of this merge function (for debugging/config)
+    fn name(&self) -> &'static str;
+}
+
+// ============================================================================
+// IMPLEMENTATIONS
+// ============================================================================
+
+/// Mean (average) of all points
+///
+/// The centroid of the input points.
+/// Good default for most hierarchical composition.
+#[derive(Clone, Copy, Debug, Default)]
+pub struct Mean;
+
+impl Merge for Mean {
+    fn merge(&self, points: &[Point]) -> Point {
+        assert!(!points.is_empty(), "Cannot merge empty slice");
+
+        let dims = points[0].dimensionality();
+        let n = points.len() as f32;
+
+        let mut result = vec![0.0; dims];
+        for p in points {
+            assert_eq!(
+                p.dimensionality(),
+                dims,
+                "All points must have same dimensionality"
+            );
+            for (r, d) in result.iter_mut().zip(p.dims()) {
+                *r += d / n;
+            }
+        }
+
+        Point::new(result)
+    }
+
+    fn name(&self) -> &'static str {
+        "mean"
+    }
+}
+
+/// Weighted mean of points
+///
+/// Each point contributes proportionally to its weight.
+/// Useful for recency weighting, importance weighting, etc.
+#[derive(Clone, Debug)]
+pub struct WeightedMean {
+    weights: Vec<f32>,
+}
+
+impl WeightedMean {
+    /// Create a new weighted mean with given weights
+    ///
+    /// Weights will be normalized (divided by sum) during merge.
+    pub fn new(weights: Vec<f32>) -> Self {
+        Self { weights }
+    }
+
+    /// Create with uniform weights (equivalent to Mean)
+    pub fn uniform(n: usize) -> Self {
+        Self {
+            weights: vec![1.0; n],
+        }
+    }
+
+    /// Create with recency weighting (more recent = higher weight)
+    ///
+    /// `decay` should be in (0, 1). Smaller = faster decay.
+    /// First point is oldest, last is most recent.
+    pub fn recency(n: usize, decay: f32) -> Self {
+        let weights: Vec<f32> = (0..n).map(|i| decay.powi((n - 1 - i) as i32)).collect();
+        Self { weights }
+    }
+}
+
+impl Merge for WeightedMean {
+    fn merge(&self, points: &[Point]) -> Point {
+        assert!(!points.is_empty(), "Cannot merge empty slice");
+        assert_eq!(
+            points.len(),
+            self.weights.len(),
+            "Number of points must match number of weights"
+        );
+
+        let dims = points[0].dimensionality();
+        let total_weight: f32 = self.weights.iter().sum();
+
+        let mut result = vec![0.0; dims];
+        for (p, &w) in points.iter().zip(&self.weights) {
+            assert_eq!(
+                p.dimensionality(),
+                dims,
+                "All points must have same dimensionality"
+            );
+            let normalized_w = w / total_weight;
+            for (r, d) in result.iter_mut().zip(p.dims()) {
+                *r += d * normalized_w;
+            }
+        }
+
+        Point::new(result)
+    }
+
+    fn name(&self) -> &'static str {
+        "weighted_mean"
+    }
+}
+
+/// Max pooling across points
+///
+/// Takes the maximum value of each dimension across all points.
+/// Preserves the strongest activations.
+#[derive(Clone, Copy, Debug, Default)]
+pub struct MaxPool;
+
+impl Merge for MaxPool {
+    fn merge(&self, points: &[Point]) -> Point {
+        assert!(!points.is_empty(), "Cannot merge empty slice");
+
+        let dims = points[0].dimensionality();
+        let mut result = points[0].dims().to_vec();
+
+        for p in &points[1..] {
+            assert_eq!(
+                p.dimensionality(),
+                dims,
+                "All points must have same dimensionality"
+            );
+            for (r, d) in result.iter_mut().zip(p.dims()) {
+                *r = r.max(*d);
+            }
+        }
+
+        Point::new(result)
+    }
+
+    fn name(&self) -> &'static str {
+        "max_pool"
+    }
+}
+
+/// Min pooling across points
+///
+/// Takes the minimum value of each dimension across all points.
+#[derive(Clone, Copy, Debug, Default)]
+pub struct MinPool;
+
+impl Merge for MinPool {
+    fn merge(&self, points: &[Point]) -> Point {
+        assert!(!points.is_empty(), "Cannot merge empty slice");
+
+        let dims = points[0].dimensionality();
+        let mut result = points[0].dims().to_vec();
+
+        for p in &points[1..] {
+            assert_eq!(
+                p.dimensionality(),
+                dims,
+                "All points must have same dimensionality"
+            );
+            for (r, d) in result.iter_mut().zip(p.dims()) {
+                *r = r.min(*d);
+            }
+        }
+
+        Point::new(result)
+    }
+
+    fn name(&self) -> &'static str {
+        "min_pool"
+    }
+}
+
+/// Sum of all points (no averaging)
+///
+/// Simple additive composition.
+#[derive(Clone, Copy, Debug, Default)]
+pub struct Sum;
+
+impl Merge for Sum {
+    fn merge(&self, points: &[Point]) -> Point {
+        assert!(!points.is_empty(), "Cannot merge empty slice");
+
+        let dims = points[0].dimensionality();
+        let mut result = vec![0.0; dims];
+
+        for p in points {
+            assert_eq!(
+                p.dimensionality(),
+                dims,
+                "All points must have same dimensionality"
+            );
+            for (r, d) in result.iter_mut().zip(p.dims()) {
+                *r += d;
+            }
+        }
+
+        Point::new(result)
+    }
+
+    fn name(&self) -> &'static str {
+        "sum"
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_mean_single() {
+        let points = vec![Point::new(vec![1.0, 2.0, 3.0])];
+        let merged = Mean.merge(&points);
+        assert_eq!(merged.dims(), &[1.0, 2.0, 3.0]);
+    }
+
+    #[test]
+    fn test_mean_multiple() {
+        let points = vec![
+            Point::new(vec![1.0, 2.0]),
+            Point::new(vec![3.0, 4.0]),
+        ];
+        let merged = Mean.merge(&points);
+        assert_eq!(merged.dims(), &[2.0, 3.0]);
+    }
+
+    #[test]
+    fn test_weighted_mean() {
+        let points = vec![
+            Point::new(vec![0.0, 0.0]),
+            Point::new(vec![10.0, 10.0]),
+        ];
+        // Weight second point 3x more than first
+        let merger = WeightedMean::new(vec![1.0, 3.0]);
+        let merged = merger.merge(&points);
+        // (0*0.25 + 10*0.75, 0*0.25 + 10*0.75) = (7.5, 7.5)
+        assert!((merged.dims()[0] - 7.5).abs() < 0.0001);
+        assert!((merged.dims()[1] - 7.5).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_weighted_mean_recency() {
+        let merger = WeightedMean::recency(3, 0.5);
+        // decay = 0.5, n = 3
+        // weights: [0.5^2, 0.5^1, 0.5^0] = [0.25, 0.5, 1.0]
+        assert_eq!(merger.weights.len(), 3);
+        assert!((merger.weights[0] - 0.25).abs() < 0.0001);
+        assert!((merger.weights[1] - 0.5).abs() < 0.0001);
+        assert!((merger.weights[2] - 1.0).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_max_pool() {
+        let points = vec![
+            Point::new(vec![1.0, 5.0, 2.0]),
+            Point::new(vec![3.0, 2.0, 4.0]),
+            Point::new(vec![2.0, 3.0, 1.0]),
+        ];
+        let merged = MaxPool.merge(&points);
+        assert_eq!(merged.dims(), &[3.0, 5.0, 4.0]);
+    }
+
+    #[test]
+    fn test_min_pool() {
+        let points = vec![
+            Point::new(vec![1.0, 5.0, 2.0]),
+            Point::new(vec![3.0, 2.0, 4.0]),
+            Point::new(vec![2.0, 3.0, 1.0]),
+        ];
+        let merged = MinPool.merge(&points);
+        assert_eq!(merged.dims(), &[1.0, 2.0, 1.0]);
+    }
+
+    #[test]
+    fn test_sum() {
+        let points = vec![
+            Point::new(vec![1.0, 2.0]),
+            Point::new(vec![3.0, 4.0]),
+        ];
+        let merged = Sum.merge(&points);
+        assert_eq!(merged.dims(), &[4.0, 6.0]);
+    }
+
+    #[test]
+    fn test_merge_names() {
+        assert_eq!(Mean.name(), "mean");
+        assert_eq!(MaxPool.name(), "max_pool");
+        assert_eq!(MinPool.name(), "min_pool");
+        assert_eq!(Sum.name(), "sum");
+    }
+
+    #[test]
+    #[should_panic(expected = "Cannot merge empty")]
+    fn test_merge_empty_panics() {
+        let points: Vec<Point> = vec![];
+        Mean.merge(&points);
+    }
+
+    #[test]
+    #[should_panic(expected = "same dimensionality")]
+    fn test_merge_dimension_mismatch_panics() {
+        let points = vec![
+            Point::new(vec![1.0, 2.0]),
+            Point::new(vec![1.0, 2.0, 3.0]),
+        ];
+        Mean.merge(&points);
+    }
+}
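The module doc describes hierarchical composition (chunks → document → session) without a worked example. Below is a minimal sketch, not part of the diff, wiring the merge functions above into that hierarchy; the `use` paths are assumptions, everything else uses only the API defined in this file.

```rust
// Sketch only: two levels of composition with the merge types above.
use arms::Point; // as in the doctests elsewhere in this diff
use arms::core::merge::{Mean, Merge, WeightedMean}; // hypothetical paths

fn main() {
    // Chunk embeddings (toy 3-dim vectors).
    let doc_a_chunks = [Point::new(vec![1.0, 0.0, 0.0]), Point::new(vec![0.0, 1.0, 0.0])];
    let doc_b_chunks = [Point::new(vec![0.0, 0.0, 1.0]), Point::new(vec![0.0, 1.0, 0.0])];

    // Chunks -> documents: plain centroid.
    let doc_a = Mean.merge(&doc_a_chunks); // [0.5, 0.5, 0.0]
    let doc_b = Mean.merge(&doc_b_chunks); // [0.0, 0.5, 0.5]

    // Documents -> session: weight the most recent document highest.
    // With decay 0.5 and n = 2, raw weights are [0.5, 1.0], so after
    // normalization doc_a contributes 1/3 and doc_b contributes 2/3.
    let session = WeightedMean::recency(2, 0.5).merge(&[doc_a, doc_b]);

    println!("session point: {:?}", session.dims());
}
```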
src/core/mod.rs ADDED
@@ -0,0 +1,64 @@
+//! # Core Domain
+//!
+//! Pure math, no I/O. The foundation of ARMS.
+//!
+//! This module contains the fundamental types and operations:
+//! - `Point` - A position in dimensional space
+//! - `Id` - Unique identifier for placed points
+//! - `Blob` - Raw payload data
+//! - `Proximity` - Trait for measuring relatedness
+//! - `Merge` - Trait for composing points
+//!
+//! ## Design Principles
+//!
+//! - All functions are pure (deterministic, no side effects)
+//! - No I/O operations
+//! - No external dependencies beyond std
+//! - Fully testable in isolation
+
+mod point;
+mod id;
+mod blob;
+pub mod proximity;
+pub mod merge;
+pub mod config;
+
+// Re-exports
+pub use point::Point;
+pub use id::Id;
+pub use blob::Blob;
+
+/// A point that has been placed in the space
+#[derive(Clone)]
+pub struct PlacedPoint {
+    /// Unique identifier
+    pub id: Id,
+    /// Position in dimensional space
+    pub point: Point,
+    /// Attached payload
+    pub blob: Blob,
+}
+
+impl PlacedPoint {
+    /// Create a new placed point
+    pub fn new(id: Id, point: Point, blob: Blob) -> Self {
+        Self { id, point, blob }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_placed_point_creation() {
+        let id = Id::now();
+        let point = Point::new(vec![1.0, 2.0, 3.0]);
+        let blob = Blob::new(vec![1, 2, 3]);
+
+        let placed = PlacedPoint::new(id, point, blob);
+
+        assert_eq!(placed.point.dimensionality(), 3);
+        assert_eq!(placed.blob.size(), 3);
+    }
+}
src/core/point.rs ADDED
@@ -0,0 +1,186 @@
+//! # Point
+//!
+//! A position in dimensional space. The fundamental primitive.
+//!
+//! Dimensionality is NOT fixed - configure it for your model.
+//! 768-dim, 1024-dim, 4096-dim, or any size you need.
+//!
+//! The point IS the thought's position.
+//! The position IS its relationship to all other thoughts.
+
+/// A point in dimensional space
+#[derive(Clone, Debug, PartialEq)]
+pub struct Point {
+    dims: Vec<f32>,
+}
+
+impl Point {
+    /// Create a new point from a vector of dimensions
+    ///
+    /// # Example
+    /// ```
+    /// use arms::Point;
+    /// let p = Point::new(vec![1.0, 2.0, 3.0]);
+    /// assert_eq!(p.dimensionality(), 3);
+    /// ```
+    pub fn new(dims: Vec<f32>) -> Self {
+        Self { dims }
+    }
+
+    /// Create an origin point (all zeros) of given dimensionality
+    ///
+    /// # Example
+    /// ```
+    /// use arms::Point;
+    /// let origin = Point::origin(768);
+    /// assert_eq!(origin.dimensionality(), 768);
+    /// assert!(origin.dims().iter().all(|&x| x == 0.0));
+    /// ```
+    pub fn origin(dims: usize) -> Self {
+        Self {
+            dims: vec![0.0; dims],
+        }
+    }
+
+    /// Get the dimensionality of this point
+    pub fn dimensionality(&self) -> usize {
+        self.dims.len()
+    }
+
+    /// Access the dimensions as a slice
+    pub fn dims(&self) -> &[f32] {
+        &self.dims
+    }
+
+    /// Mutable access to dimensions
+    pub fn dims_mut(&mut self) -> &mut [f32] {
+        &mut self.dims
+    }
+
+    /// Calculate the magnitude (L2 norm) of this point
+    ///
+    /// # Example
+    /// ```
+    /// use arms::Point;
+    /// let p = Point::new(vec![3.0, 4.0]);
+    /// assert!((p.magnitude() - 5.0).abs() < 0.0001);
+    /// ```
+    pub fn magnitude(&self) -> f32 {
+        self.dims.iter().map(|x| x * x).sum::<f32>().sqrt()
+    }
+
+    /// Check if this point is normalized (magnitude ≈ 1.0)
+    pub fn is_normalized(&self) -> bool {
+        let mag = self.magnitude();
+        (mag - 1.0).abs() < 0.001
+    }
+
+    /// Return a normalized copy of this point
+    ///
+    /// If magnitude is zero, returns a clone of self.
+    ///
+    /// # Example
+    /// ```
+    /// use arms::Point;
+    /// let p = Point::new(vec![3.0, 4.0]);
+    /// let normalized = p.normalize();
+    /// assert!(normalized.is_normalized());
+    /// ```
+    pub fn normalize(&self) -> Self {
+        let mag = self.magnitude();
+        if mag == 0.0 {
+            return self.clone();
+        }
+        Self {
+            dims: self.dims.iter().map(|x| x / mag).collect(),
+        }
+    }
+
+    /// Add another point to this one (element-wise)
+    pub fn add(&self, other: &Point) -> Self {
+        assert_eq!(
+            self.dimensionality(),
+            other.dimensionality(),
+            "Points must have same dimensionality"
+        );
+        Self {
+            dims: self
+                .dims
+                .iter()
+                .zip(other.dims.iter())
+                .map(|(a, b)| a + b)
+                .collect(),
+        }
+    }
+
+    /// Scale this point by a scalar
+    pub fn scale(&self, scalar: f32) -> Self {
+        Self {
+            dims: self.dims.iter().map(|x| x * scalar).collect(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_new_point() {
+        let p = Point::new(vec![1.0, 2.0, 3.0]);
+        assert_eq!(p.dimensionality(), 3);
+        assert_eq!(p.dims(), &[1.0, 2.0, 3.0]);
+    }
+
+    #[test]
+    fn test_origin() {
+        let origin = Point::origin(768);
+        assert_eq!(origin.dimensionality(), 768);
+        assert!(origin.dims().iter().all(|&x| x == 0.0));
+    }
+
+    #[test]
+    fn test_magnitude() {
+        let p = Point::new(vec![3.0, 4.0]);
+        assert!((p.magnitude() - 5.0).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_normalize() {
+        let p = Point::new(vec![3.0, 4.0]);
+        let normalized = p.normalize();
+        assert!(normalized.is_normalized());
+        assert!((normalized.dims()[0] - 0.6).abs() < 0.0001);
+        assert!((normalized.dims()[1] - 0.8).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_normalize_zero() {
+        let p = Point::origin(3);
+        let normalized = p.normalize();
+        assert_eq!(normalized.dims(), &[0.0, 0.0, 0.0]);
+    }
+
+    #[test]
+    fn test_add() {
+        let a = Point::new(vec![1.0, 2.0]);
+        let b = Point::new(vec![3.0, 4.0]);
+        let c = a.add(&b);
+        assert_eq!(c.dims(), &[4.0, 6.0]);
+    }
+
+    #[test]
+    fn test_scale() {
+        let p = Point::new(vec![1.0, 2.0]);
+        let scaled = p.scale(2.0);
+        assert_eq!(scaled.dims(), &[2.0, 4.0]);
+    }
+
+    #[test]
+    #[should_panic(expected = "same dimensionality")]
+    fn test_add_different_dims_panics() {
+        let a = Point::new(vec![1.0, 2.0]);
+        let b = Point::new(vec![1.0, 2.0, 3.0]);
+        let _ = a.add(&b);
+    }
+}
src/core/proximity.rs ADDED
@@ -0,0 +1,261 @@
+//! # Proximity
+//!
+//! Trait and implementations for measuring how related two points are.
+//!
+//! This is one of the five primitives of ARMS:
+//! `Proximity: fn(a, b) -> f32` - How related?
+//!
+//! Proximity functions are pluggable - use whichever fits your use case.
+
+use super::Point;
+
+/// Trait for measuring proximity between points
+///
+/// Higher values typically mean more similar/related.
+/// The exact semantics depend on the implementation.
+pub trait Proximity: Send + Sync {
+    /// Compute proximity between two points
+    ///
+    /// Both points must have the same dimensionality.
+    fn proximity(&self, a: &Point, b: &Point) -> f32;
+
+    /// Name of this proximity function (for debugging/config)
+    fn name(&self) -> &'static str;
+}
+
+// ============================================================================
+// IMPLEMENTATIONS
+// ============================================================================
+
+/// Cosine similarity
+///
+/// Measures the cosine of the angle between two vectors.
+/// Returns a value in [-1, 1] where 1 means identical direction.
+///
+/// Best for: Normalized vectors, semantic similarity.
+#[derive(Clone, Copy, Debug, Default)]
+pub struct Cosine;
+
+impl Proximity for Cosine {
+    fn proximity(&self, a: &Point, b: &Point) -> f32 {
+        assert_eq!(
+            a.dimensionality(),
+            b.dimensionality(),
+            "Points must have same dimensionality"
+        );
+
+        let dot: f32 = a
+            .dims()
+            .iter()
+            .zip(b.dims().iter())
+            .map(|(x, y)| x * y)
+            .sum();
+
+        let mag_a = a.magnitude();
+        let mag_b = b.magnitude();
+
+        if mag_a == 0.0 || mag_b == 0.0 {
+            return 0.0;
+        }
+
+        dot / (mag_a * mag_b)
+    }
+
+    fn name(&self) -> &'static str {
+        "cosine"
+    }
+}
+
+/// Euclidean distance
+///
+/// The straight-line distance between two points.
+/// Returns a value in [0, ∞) where 0 means identical.
+///
+/// Note: This returns DISTANCE, not similarity.
+/// Lower values = more similar.
+#[derive(Clone, Copy, Debug, Default)]
+pub struct Euclidean;
+
+impl Proximity for Euclidean {
+    fn proximity(&self, a: &Point, b: &Point) -> f32 {
+        assert_eq!(
+            a.dimensionality(),
+            b.dimensionality(),
+            "Points must have same dimensionality"
+        );
+
+        let dist_sq: f32 = a
+            .dims()
+            .iter()
+            .zip(b.dims().iter())
+            .map(|(x, y)| (x - y).powi(2))
+            .sum();
+
+        dist_sq.sqrt()
+    }
+
+    fn name(&self) -> &'static str {
+        "euclidean"
+    }
+}
+
+/// Squared Euclidean distance
+///
+/// Same ordering as Euclidean but faster (no sqrt).
+/// Use when you only need to compare distances, not absolute values.
+#[derive(Clone, Copy, Debug, Default)]
+pub struct EuclideanSquared;
+
+impl Proximity for EuclideanSquared {
+    fn proximity(&self, a: &Point, b: &Point) -> f32 {
+        assert_eq!(
+            a.dimensionality(),
+            b.dimensionality(),
+            "Points must have same dimensionality"
+        );
+
+        a.dims()
+            .iter()
+            .zip(b.dims().iter())
+            .map(|(x, y)| (x - y).powi(2))
+            .sum()
+    }
+
+    fn name(&self) -> &'static str {
+        "euclidean_squared"
+    }
+}
+
+/// Dot product
+///
+/// The raw dot product without normalization.
+/// Returns a value that depends on magnitudes.
+///
+/// Best for: When magnitude matters, not just direction.
+#[derive(Clone, Copy, Debug, Default)]
+pub struct DotProduct;
+
+impl Proximity for DotProduct {
+    fn proximity(&self, a: &Point, b: &Point) -> f32 {
+        assert_eq!(
+            a.dimensionality(),
+            b.dimensionality(),
+            "Points must have same dimensionality"
+        );
+
+        a.dims()
+            .iter()
+            .zip(b.dims().iter())
+            .map(|(x, y)| x * y)
+            .sum()
+    }
+
+    fn name(&self) -> &'static str {
+        "dot_product"
+    }
+}
+
+/// Manhattan (L1) distance
+///
+/// Sum of absolute differences along each dimension.
+/// Returns a value in [0, ∞) where 0 means identical.
+#[derive(Clone, Copy, Debug, Default)]
+pub struct Manhattan;
+
+impl Proximity for Manhattan {
+    fn proximity(&self, a: &Point, b: &Point) -> f32 {
+        assert_eq!(
+            a.dimensionality(),
+            b.dimensionality(),
+            "Points must have same dimensionality"
+        );
+
+        a.dims()
+            .iter()
+            .zip(b.dims().iter())
+            .map(|(x, y)| (x - y).abs())
+            .sum()
+    }
+
+    fn name(&self) -> &'static str {
+        "manhattan"
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_cosine_identical() {
+        let a = Point::new(vec![1.0, 0.0, 0.0]);
+        let b = Point::new(vec![1.0, 0.0, 0.0]);
+        let cos = Cosine.proximity(&a, &b);
+        assert!((cos - 1.0).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_cosine_opposite() {
+        let a = Point::new(vec![1.0, 0.0, 0.0]);
+        let b = Point::new(vec![-1.0, 0.0, 0.0]);
+        let cos = Cosine.proximity(&a, &b);
+        assert!((cos - (-1.0)).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_cosine_orthogonal() {
+        let a = Point::new(vec![1.0, 0.0, 0.0]);
+        let b = Point::new(vec![0.0, 1.0, 0.0]);
+        let cos = Cosine.proximity(&a, &b);
+        assert!(cos.abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_euclidean() {
+        let a = Point::new(vec![0.0, 0.0]);
+        let b = Point::new(vec![3.0, 4.0]);
+        let dist = Euclidean.proximity(&a, &b);
+        assert!((dist - 5.0).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_euclidean_squared() {
+        let a = Point::new(vec![0.0, 0.0]);
+        let b = Point::new(vec![3.0, 4.0]);
+        let dist_sq = EuclideanSquared.proximity(&a, &b);
+        assert!((dist_sq - 25.0).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_dot_product() {
+        let a = Point::new(vec![1.0, 2.0, 3.0]);
+        let b = Point::new(vec![4.0, 5.0, 6.0]);
+        let dot = DotProduct.proximity(&a, &b);
+        // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
+        assert!((dot - 32.0).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_manhattan() {
+        let a = Point::new(vec![0.0, 0.0]);
+        let b = Point::new(vec![3.0, 4.0]);
+        let dist = Manhattan.proximity(&a, &b);
+        assert!((dist - 7.0).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_proximity_names() {
+        assert_eq!(Cosine.name(), "cosine");
+        assert_eq!(Euclidean.name(), "euclidean");
+        assert_eq!(DotProduct.name(), "dot_product");
+        assert_eq!(Manhattan.name(), "manhattan");
+    }
+
+    #[test]
+    #[should_panic(expected = "same dimensionality")]
+    fn test_dimension_mismatch_panics() {
+        let a = Point::new(vec![1.0, 2.0]);
+        let b = Point::new(vec![1.0, 2.0, 3.0]);
+        Cosine.proximity(&a, &b);
+    }
+}
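Because `Proximity` is object-safe, metrics can be swapped at runtime. A minimal sketch, not part of the diff, comparing the four metrics on the same pair; note that `Cosine`/`DotProduct` are similarities (higher = closer) while `Euclidean`/`Manhattan` are distances (lower = closer). The import paths are assumptions.

```rust
// Sketch only: iterating over boxed Proximity trait objects.
use arms::Point;
use arms::core::proximity::{Cosine, DotProduct, Euclidean, Manhattan, Proximity}; // hypothetical paths

fn main() {
    let a = Point::new(vec![1.0, 0.0]);
    let b = Point::new(vec![0.6, 0.8]); // unit vector, ~53° from a

    let metrics: Vec<Box<dyn Proximity>> = vec![
        Box::new(Cosine),
        Box::new(DotProduct),
        Box::new(Euclidean),
        Box::new(Manhattan),
    ];

    for m in &metrics {
        println!("{:>16}: {:.4}", m.name(), m.proximity(&a, &b));
    }
    // cosine: 0.6000, dot_product: 0.6000, euclidean: 0.8944, manhattan: 1.2000
}
```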
src/engine/arms.rs ADDED
@@ -0,0 +1,335 @@
+//! # Arms Engine
+//!
+//! The main ARMS orchestrator.
+//!
+//! This struct wires together:
+//! - Storage (Place port)
+//! - Index (Near port)
+//! - Configuration
+//!
+//! And exposes a unified API for storing and retrieving points.
+
+use crate::core::{Blob, Id, PlacedPoint, Point};
+use crate::core::config::ArmsConfig;
+use crate::ports::{Near, NearResult, Place, PlaceResult, SearchResult};
+use crate::adapters::storage::MemoryStorage;
+use crate::adapters::index::FlatIndex;
+
+/// The main ARMS engine
+///
+/// Orchestrates storage and indexing with a unified API.
+pub struct Arms {
+    /// Configuration
+    config: ArmsConfig,
+
+    /// Storage backend (Place port)
+    storage: Box<dyn Place>,
+
+    /// Index backend (Near port)
+    index: Box<dyn Near>,
+}
+
+impl Arms {
+    /// Create a new ARMS instance with default adapters
+    ///
+    /// Uses MemoryStorage and FlatIndex.
+    /// For production, use `Arms::with_adapters` with appropriate backends.
+    pub fn new(config: ArmsConfig) -> Self {
+        let storage = Box::new(MemoryStorage::new(config.dimensionality));
+        let index = Box::new(FlatIndex::new(
+            config.dimensionality,
+            config.proximity.clone(),
+            true, // Assuming cosine-like similarity by default
+        ));
+
+        Self {
+            config,
+            storage,
+            index,
+        }
+    }
+
+    /// Create with custom adapters
+    pub fn with_adapters(
+        config: ArmsConfig,
+        storage: Box<dyn Place>,
+        index: Box<dyn Near>,
+    ) -> Self {
+        Self {
+            config,
+            storage,
+            index,
+        }
+    }
+
+    /// Get the configuration
+    pub fn config(&self) -> &ArmsConfig {
+        &self.config
+    }
+
+    /// Get the dimensionality of this space
+    pub fn dimensionality(&self) -> usize {
+        self.config.dimensionality
+    }
+
+    // ========================================================================
+    // PLACE OPERATIONS
+    // ========================================================================
+
+    /// Place a point in the space
+    ///
+    /// The point will be normalized if configured to do so.
+    /// Returns the assigned ID.
+    pub fn place(&mut self, point: Point, blob: Blob) -> PlaceResult<Id> {
+        // Normalize if configured
+        let point = if self.config.normalize_on_insert {
+            point.normalize()
+        } else {
+            point
+        };
+
+        // Store in storage
+        let id = self.storage.place(point.clone(), blob)?;
+
+        // Add to index
+        if let Err(e) = self.index.add(id, &point) {
+            // Rollback storage if index fails
+            self.storage.remove(id);
+            return Err(crate::ports::PlaceError::StorageError(format!(
+                "Index error: {:?}",
+                e
+            )));
+        }
+
+        Ok(id)
+    }
+
+    /// Place multiple points at once
+    pub fn place_batch(&mut self, items: Vec<(Point, Blob)>) -> Vec<PlaceResult<Id>> {
+        items
+            .into_iter()
+            .map(|(point, blob)| self.place(point, blob))
+            .collect()
+    }
+
+    /// Remove a point from the space
+    pub fn remove(&mut self, id: Id) -> Option<PlacedPoint> {
+        // Remove from index first
+        let _ = self.index.remove(id);
+
+        // Then from storage
+        self.storage.remove(id)
+    }
+
+    /// Get a point by ID
+    pub fn get(&self, id: Id) -> Option<&PlacedPoint> {
+        self.storage.get(id)
+    }
+
+    /// Check if a point exists
+    pub fn contains(&self, id: Id) -> bool {
+        self.storage.contains(id)
+    }
+
+    /// Get the number of stored points
+    pub fn len(&self) -> usize {
+        self.storage.len()
+    }
+
+    /// Check if the space is empty
+    pub fn is_empty(&self) -> bool {
+        self.storage.is_empty()
+    }
+
+    /// Clear all points
+    pub fn clear(&mut self) {
+        self.storage.clear();
+        let _ = self.index.rebuild(); // Reset index
+    }
+
+    // ========================================================================
+    // NEAR OPERATIONS
+    // ========================================================================
+
+    /// Find k nearest points to query
+    pub fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> {
+        // Normalize query if configured
+        let query = if self.config.normalize_on_insert {
+            query.normalize()
+        } else {
+            query.clone()
+        };
+
+        self.index.near(&query, k)
+    }
+
+    /// Find all points within threshold
+    pub fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> {
+        let query = if self.config.normalize_on_insert {
+            query.normalize()
+        } else {
+            query.clone()
+        };
+
+        self.index.within(&query, threshold)
+    }
+
+    /// Find and retrieve k nearest points (with full data)
+    pub fn near_with_data(&self, query: &Point, k: usize) -> NearResult<Vec<(&PlacedPoint, f32)>> {
+        let results = self.near(query, k)?;
+
+        Ok(results
+            .into_iter()
+            .filter_map(|r| self.storage.get(r.id).map(|p| (p, r.score)))
+            .collect())
+    }
+
+    // ========================================================================
+    // MERGE OPERATIONS
+    // ========================================================================
+
+    /// Merge multiple points into one using the configured merge function
+    pub fn merge(&self, points: &[Point]) -> Point {
+        self.config.merge.merge(points)
+    }
+
+    /// Compute proximity between two points
+    pub fn proximity(&self, a: &Point, b: &Point) -> f32 {
+        self.config.proximity.proximity(a, b)
+    }
+
+    // ========================================================================
+    // STATS
+    // ========================================================================
+
+    /// Get storage size in bytes
+    pub fn size_bytes(&self) -> usize {
+        self.storage.size_bytes()
+    }
+
+    /// Get index stats
+    pub fn index_len(&self) -> usize {
+        self.index.len()
+    }
+
+    /// Check if index is ready
+    pub fn is_ready(&self) -> bool {
+        self.index.is_ready()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn create_test_arms() -> Arms {
+        Arms::new(ArmsConfig::new(3))
+    }
+
+    #[test]
+    fn test_arms_place_and_get() {
+        let mut arms = create_test_arms();
+
+        let point = Point::new(vec![1.0, 0.0, 0.0]);
+        let blob = Blob::from_str("test data");
+
+        let id = arms.place(point, blob).unwrap();
+
+        let retrieved = arms.get(id).unwrap();
+        assert_eq!(retrieved.blob.as_str(), Some("test data"));
+    }
+
+    #[test]
+    fn test_arms_near() {
+        let mut arms = create_test_arms();
+
+        // Add some points
+        arms.place(Point::new(vec![1.0, 0.0, 0.0]), Blob::from_str("x")).unwrap();
+        arms.place(Point::new(vec![0.0, 1.0, 0.0]), Blob::from_str("y")).unwrap();
+        arms.place(Point::new(vec![0.0, 0.0, 1.0]), Blob::from_str("z")).unwrap();
+
+        // Query
+        let query = Point::new(vec![1.0, 0.0, 0.0]);
+        let results = arms.near(&query, 2).unwrap();
+
+        assert_eq!(results.len(), 2);
+        // First result should have highest similarity
+        assert!(results[0].score > results[1].score);
+    }
+
+    #[test]
+    fn test_arms_near_with_data() {
+        let mut arms = create_test_arms();
+
+        arms.place(Point::new(vec![1.0, 0.0, 0.0]), Blob::from_str("x")).unwrap();
+        arms.place(Point::new(vec![0.0, 1.0, 0.0]), Blob::from_str("y")).unwrap();
+
+        let query = Point::new(vec![1.0, 0.0, 0.0]);
+        let results = arms.near_with_data(&query, 1).unwrap();
+
+        assert_eq!(results.len(), 1);
+        assert_eq!(results[0].0.blob.as_str(), Some("x"));
+    }
+
+    #[test]
+    fn test_arms_remove() {
+        let mut arms = create_test_arms();
+
+        let id = arms.place(Point::new(vec![1.0, 0.0, 0.0]), Blob::empty()).unwrap();
+
+        assert!(arms.contains(id));
+        assert_eq!(arms.len(), 1);
+
+        arms.remove(id);
+
+        assert!(!arms.contains(id));
+        assert_eq!(arms.len(), 0);
+    }
+
+    #[test]
+    fn test_arms_merge() {
+        let arms = create_test_arms();
+
+        let points = vec![
+            Point::new(vec![1.0, 0.0, 0.0]),
+            Point::new(vec![0.0, 1.0, 0.0]),
+        ];
+
+        let merged = arms.merge(&points);
+
+        // Mean of [1,0,0] and [0,1,0] = [0.5, 0.5, 0]
+        assert!((merged.dims()[0] - 0.5).abs() < 0.0001);
+        assert!((merged.dims()[1] - 0.5).abs() < 0.0001);
+        assert!((merged.dims()[2] - 0.0).abs() < 0.0001);
+    }
+
+    #[test]
+    fn test_arms_clear() {
+        let mut arms = create_test_arms();
+
+        for i in 0..10 {
+            arms.place(Point::new(vec![i as f32, 0.0, 0.0]), Blob::empty()).unwrap();
+        }
+
+        assert_eq!(arms.len(), 10);
+
+        arms.clear();
+
+        assert_eq!(arms.len(), 0);
+        assert!(arms.is_empty());
+    }
+
+    #[test]
+    fn test_arms_normalizes_on_insert() {
+        let mut arms = create_test_arms();
+
+        // Insert a non-normalized point
+        let point = Point::new(vec![3.0, 4.0, 0.0]); // magnitude = 5
+        let id = arms.place(point, Blob::empty()).unwrap();
+
+        let retrieved = arms.get(id).unwrap();
+
+        // Should be normalized
+        assert!(retrieved.point.is_normalized());
+    }
+}
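A minimal sketch, not part of the diff, of the end-to-end flow the engine doc describes: place labeled points, then retrieve nearest neighbors together with their payloads. It assumes `Arms` and `ArmsConfig` are re-exported at the crate root alongside `Point` and `Blob`, which this diff does not show.

```rust
// Sketch only: place + near_with_data with the default adapters
// (MemoryStorage + FlatIndex, cosine proximity, normalize-on-insert).
use arms::{Arms, ArmsConfig, Blob, Point}; // hypothetical re-exports

fn main() {
    let mut arms = Arms::new(ArmsConfig::new(3));

    // Each point is normalized on insert by default.
    arms.place(Point::new(vec![1.0, 0.1, 0.0]), Blob::from_str("apple")).unwrap();
    arms.place(Point::new(vec![0.9, 0.2, 0.0]), Blob::from_str("pear")).unwrap();
    arms.place(Point::new(vec![0.0, 0.0, 1.0]), Blob::from_str("engine")).unwrap();

    // Nearest neighbors to a fruit-like query, with blobs attached.
    let query = Point::new(vec![1.0, 0.0, 0.0]);
    for (placed, score) in arms.near_with_data(&query, 2).unwrap() {
        println!("{:?} (score {:.3})", placed.blob.as_str(), score);
    }
    // Prints "apple" and "pear"; "engine" is nearly orthogonal to the query.
}
```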