Jackoatmon committed
Commit c2bf4b6 · verified · 1 Parent(s): e317e25

Update Feather H200 runtime: Nemotron streaming and HTM force-CPU canary fixes

Files changed (41)
  1. overlay/htm_rust/bench_gpu.py +81 -81
  2. overlay/htm_rust/build.rs +6 -12
  3. overlay/htm_rust/docs/GPU_HTM.md +302 -302
  4. overlay/htm_rust/src/gpu/fused.rs +50 -33
  5. overlay/htm_rust/src/gpu/kernels/sp_boost_fused.cu +59 -59
  6. overlay/htm_rust/src/gpu/kernels/sp_duty.cu +45 -45
  7. overlay/htm_rust/src/gpu/kernels/sp_learn.cu +45 -45
  8. overlay/htm_rust/src/gpu/kernels/sp_overlap.cu +78 -78
  9. overlay/htm_rust/src/gpu/kernels/sp_topk.cu +117 -117
  10. overlay/htm_rust/src/gpu/kernels/tm_activate.cu +66 -66
  11. overlay/htm_rust/src/gpu/kernels/tm_anomaly.cu +43 -43
  12. overlay/htm_rust/src/gpu/kernels/tm_grow.cu +155 -155
  13. overlay/htm_rust/src/gpu/kernels/tm_learn.cu +75 -75
  14. overlay/htm_rust/src/gpu/kernels/tm_predict.cu +102 -102
  15. overlay/htm_rust/src/gpu/kernels/tm_punish.cu +64 -64
  16. overlay/htm_rust/src/gpu/kernels/tm_reset.cu +36 -36
  17. overlay/htm_rust/src/gpu/mod.rs +549 -549
  18. overlay/htm_rust/src/gpu/sp_gpu.rs +796 -796
  19. overlay/htm_rust/src/gpu/tm_gpu.rs +460 -460
  20. overlay/htm_rust/uv.lock +8 -8
  21. overlay/hydra/config.py +2 -2
  22. overlay/hydra/engram.py +121 -104
  23. overlay/hydra/model.py +1 -0
  24. overlay/scripts/autoresearch.py +517 -517
  25. overlay/scripts/chat.py +458 -458
  26. overlay/scripts/chat_eval.py +300 -300
  27. overlay/scripts/compile_debug.py +213 -213
  28. overlay/scripts/dataset_audit.py +241 -241
  29. overlay/scripts/download_sft_data.py +457 -457
  30. overlay/scripts/eval_quality.py +525 -525
  31. overlay/scripts/fetch_corpus.py +211 -211
  32. overlay/scripts/grad_probe.py +196 -196
  33. overlay/scripts/launch_feather_hf_job.py +8 -2
  34. overlay/scripts/profile_forward.py +87 -87
  35. overlay/scripts/run_domain_expanded_pretrain.sh +1 -5
  36. overlay/scripts/sample_utils.py +107 -107
  37. overlay/scripts/sft.py +559 -559
  38. overlay/scripts/sft_orchestrator.sh +165 -165
  39. overlay/subsystems/fused_sdr_project.py +7 -0
  40. overlay/subsystems/htm.py +7 -1
  41. overlay/subsystems/sdr_semantic.py +5 -27
overlay/htm_rust/bench_gpu.py CHANGED
@@ -1,81 +1,81 @@
"""Microbenchmark: CPU vs GPU HTMLayer forward at HYDRA training sizes.

Usage:
    source .venv/bin/activate
    export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
    python htm_rust/bench_gpu.py
"""
import os
import sys
import time

# Ensure /home/mikeb/work/feather is on sys.path so `subsystems` imports.
_FEATHER = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _FEATHER not in sys.path:
    sys.path.insert(0, _FEATHER)

import numpy as np
import torch

from subsystems.htm import HTMLayer


def bench(layer: HTMLayer, sdr: torch.Tensor, warmup: int = 1, iters: int = 3) -> float:
    """Return mean ms/forward."""
    for _ in range(warmup):
        _ = layer(sdr)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        _ = layer(sdr)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dt = time.perf_counter() - t0
    return dt * 1000 / iters


def main() -> None:
    # HYDRA training config: B=8, T=2048, bits=16384, cols=2048.
    B, T, D = int(os.environ.get("B", 8)), int(os.environ.get("T", 2048)), 16384
    n_cols = 2048

    print(f"config: B={B} T={T} D={D} n_cols={n_cols}")
    print(f"torch: {torch.__version__} cuda={torch.cuda.is_available()}")

    # Build a fixed sparse SDR once.
    rng = np.random.default_rng(0)
    sdr = np.zeros((B, T, D), dtype=bool)
    on = int(D * 0.02)
    for b in range(B):
        for t in range(T):
            idx = rng.choice(D, size=on, replace=False)
            sdr[b, t, idx] = True
    sdr_t = torch.from_numpy(sdr)

    # CPU baseline.
    print("\n--- CPU ---")
    cpu_layer = HTMLayer(
        input_bits=D, n_columns=n_cols, cells_per_column=32,
        batch_size=B, seed=42, use_gpu=False,
    )
    cpu_layer.train()
    cpu_ms = bench(cpu_layer, sdr_t, warmup=1, iters=2)
    print(f"CPU: {cpu_ms:.1f} ms/forward ({cpu_ms/T:.2f} ms/step × T={T})")

    # GPU.
    print("\n--- GPU ---")
    gpu_layer = HTMLayer(
        input_bits=D, n_columns=n_cols, cells_per_column=32,
        batch_size=B, seed=42, use_gpu=True,
    )
    gpu_layer.train()
    sdr_cuda = sdr_t.cuda()
    gpu_ms = bench(gpu_layer, sdr_cuda, warmup=1, iters=2)
    print(f"GPU: {gpu_ms:.1f} ms/forward ({gpu_ms/T:.2f} ms/step × T={T})")

    print(f"\nSpeedup: {cpu_ms / gpu_ms:.2f}x")


if __name__ == "__main__":
    main()
overlay/htm_rust/build.rs CHANGED
@@ -26,11 +26,8 @@ fn main() {
         return;
     }
 
-    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
-    let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into());
-
-    // Base kernels — compile for any sm_80+ GPU. Each .cu file → one .ptx file.
-    let base_kernels: &[&str] = &[
+    // Kernels to compile. Each .cu file → one .ptx file, embedded by name.
+    let kernels: &[&str] = &[
         "sp_overlap",
         "sp_topk",
         "sp_learn",
@@ -43,20 +40,17 @@ fn main() {
         "tm_grow",
         "tm_anomaly",
         "tm_reset",
+        "htm_fused_step",
     ];
 
-    // htm_fused_step now compiles for ALL architectures (sm_80+).
-    // On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
-    // On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
-    // with grid.sync() for cross-block synchronization (cooperative launch).
-    let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect();
-
     let kernels_dir = PathBuf::from("src/gpu/kernels");
-    for k in &kernels {
+    for k in kernels {
         let src = kernels_dir.join(format!("{k}.cu"));
         println!("cargo:rerun-if-changed={}", src.display());
     }
 
+    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
+    let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into());
 
     let nvcc = find_nvcc();
     println!("cargo:warning=htm_rust: nvcc = {nvcc}");
overlay/htm_rust/docs/GPU_HTM.md CHANGED
@@ -1,302 +1,302 @@
# GPU HTM Backend

## Status

**FUSED MEGAKERNEL: entire T-timestep SP+TM forward collapsed into a single
CUDA launch per forward pass.**

* Legacy path: 12 kernels × T=2048 timesteps = 24K launches per forward.
* Fused path: **1 launch per forward** (24000× launch-overhead reduction).
* End-to-end training throughput: **~2.7k → ~60k tok/sec** (~22x speedup).
* Fused path uses per-column threshold inhibition instead of global top-K
  (see §Fused Kernel below — this is a real architectural change).

## Fused Kernel

### Why

Global top-K column selection requires cross-block synchronization at every
timestep. On WSL2/sm_86 without `-rdc=true`, `cooperative_groups::grid_sync()`
is unreliable. Without a grid sync, collapsing the T-loop into one kernel is
impossible, so every forward pays 12×T kernel launches and 90%+ of runtime is
CUDA launch overhead + small-kernel tails.

### How

Replace global top-K with **per-column threshold activation**:

    is_active[c] = (overlap[c] * boost[c]) > inhibition_threshold[c]

`inhibition_threshold[c]` is a per-column scalar, learned via EMA update:

    err = active_duty[c] - sparsity_target
    new_thr = clamp(thr + thr_adapt_rate * err * 100, 0.1, 1000)

This is biologically grounded (GABAergic local lateral inhibition in
neocortical columns) and supported by HTM theory. The duty-cycle-driven
feedback loop was already present; we simply redirect its output to drive
the activation threshold instead of multiplicative boost. The global top-K,
which had no biological basis, is removed.

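A minimal CUDA sketch of this rule, one thread per column (names are
illustrative; in the fused path this logic lives inside the megakernel rather
than a standalone kernel):

```cuda
extern "C" __global__ void threshold_inhibition(
    const float * __restrict__ overlap,      // (n_cols,) raw overlap counts
    const float * __restrict__ boost,        // (n_cols,)
    const float * __restrict__ active_duty,  // (n_cols,) EMA firing rate
    float * __restrict__ thr,                // (n_cols,) inhibition_threshold
    unsigned char * __restrict__ is_active,  // (n_cols,) 0/1 out
    float sparsity_target,
    float thr_adapt_rate,
    unsigned int n_cols
) {
    unsigned int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c >= n_cols) return;
    // Activation: boosted overlap must beat the column's learned threshold.
    is_active[c] = (overlap[c] * boost[c]) > thr[c] ? 1 : 0;
    // Threshold EMA: firing above target raises thr, below target lowers it.
    float err = active_duty[c] - sparsity_target;
    float t = thr[c] + thr_adapt_rate * err * 100.0f;
    thr[c] = fminf(fmaxf(t, 0.1f), 1000.0f);   // clamp(0.1, 1000)
}
```
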
### Cross-block coherence

- **Ping-pong bitsets** for `cell_active_bits` and `cell_winner_bits`: at
  even t write to `_a`, read from `_b`; at odd t reversed. This eliminates
  the need for an in-place snapshot kernel between timesteps.
- **Primary path: cooperative launch + hardware grid sync**. Host code probes
  `CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH`, computes the cooperative whole-grid
  residency limit from occupancy, and launches the fused megakernel with
  `cuLaunchCooperativeKernel`. In-kernel barriers use
  `cooperative_groups::this_grid().sync()`.
- **Fallback path: software grid barrier** via a 3-slot atomic counter array
  (`barrier_counters`). This remains as a compatibility fallback when
  cooperative launch is unavailable (see the sketch after this list).
- **Launch invariant**: cooperative launch is capped to the hardware residency
  limit for `blockDim.x = 1024`; the software fallback remains capped
  conservatively (`HTM_FUSED_GRID_CAP`, default 8) to avoid whole-grid spin
  deadlock.

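For reference, a minimal sketch of the software grid-barrier pattern, assuming
a zero-initialized 3-slot counter array; rotating slots between consecutive
barriers makes counter reuse race-free. This is the generic pattern, not the
repo's exact code:

```cuda
__device__ void soft_grid_barrier(unsigned int *counters,   // length 3, zeroed
                                  unsigned int barrier_id,  // increments per use
                                  unsigned int n_blocks) {
    __syncthreads();  // everyone in this block has arrived
    if (threadIdx.x == 0) {
        unsigned int slot = barrier_id % 3;  // rotate so reuse can't race
        __threadfence();                     // publish prior global writes
        unsigned int ticket = atomicAdd(&counters[slot], 1u);
        if (ticket == n_blocks - 1) {
            atomicExch(&counters[slot], 0u); // last arrival releases the grid
        } else {
            while (atomicAdd(&counters[slot], 0u) != 0u) { /* spin */ }
        }
    }
    __syncthreads();  // block waits for thread 0 to clear the barrier
}
```

The spin only terminates if every block is resident simultaneously, which is
why the fallback caps the grid size (`HTM_FUSED_GRID_CAP`).
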
### Kernel structure

```
for t in 0..T:
    # Phase 0: clear curr_active/curr_winner for my column range
    grid_barrier()
    # Phase A: SP overlap → boost → threshold → SP learn → duty + threshold EMA
    grid_barrier()
    # Phase B: TM predict (per cell, per seg) → TM learn (reinforce on match)
    #          → burst if none predicted → segment grow/reinforce
    grid_barrier()
    # Phase C: block 0 writes anomaly[t]
```

Each warp owns a contiguous slice of columns. At grid=24 blocks × 32 warps =
768 warps, n_columns=2048 → 2-3 columns per warp.

### Parity with legacy GPU path

**Semantics diverge**. Legacy: exactly `k = round(sparsity * n_cols)` columns
active per step. Fused: variable, converging to `sparsity * n_cols` on
average via the per-column EMA. Anomaly decay on repeating sequences is
preserved (see `gpu_fused_tm_anomaly_decays_on_repeating_sequence` test).

This is an intentional architectural change committed under
`no-bypass/full-architecture` per program.md rules. The legacy top-K path
(`step_many_cuda`) remains available for reference and can be re-enabled via
`HYDRA_HTM_FUSED=0`.

### Tests

- `gpu_threshold_converges_to_sparsity` (tests.rs): 1000-step warmup on
  random SDRs, then measure mean active cols/step on next 200 steps. Must
  land within [0.25×, 4×] of `sparsity_target * n_cols`.
- `gpu_fused_tm_anomaly_decays_on_repeating_sequence`: feed A,B,C repeating
  for 300 steps. Late anomaly must be < early anomaly AND < 0.5.

## Legacy Pipeline (kept for fallback)

* SP: 5 kernels, bit-identical parity with CPU under strict-parity mode.
* TM: 7 kernels, relaxed-parity with CPU.
* Speedup at training size (B=8, T=2048, bits=16384): **3.83x** vs CPU.

## Building

CPU-only (default, zero CUDA dep):
```bash
cargo build --release
```

GPU-enabled:
```bash
export PATH=/usr/local/cuda-12.1/bin:$PATH
export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
export HTM_PTX_VERSION=7.8   # lower if driver older than nvcc
cargo build --release --features gpu
cargo test --release --features gpu --lib   # fused path includes cooperative launch + grid-sync tests

# Python wheel:
maturin develop --release --features gpu --manifest-path htm_rust/Cargo.toml
```

## Architecture

### Module layout
```
src/gpu/
  mod.rs       # HTMRegionGpu pyclass + step_many_gpu (full pipeline)
  sp_gpu.rs    # Persistent SP device buffers + step_batch_with_tm
  tm_gpu.rs    # Persistent TM device buffers + step (predict→activate→learn)
  tests.rs     # CPU-vs-GPU SP parity + end-to-end TM anomaly decay
  kernels/
    sp_overlap.cu      # per-column overlap reduction
    sp_topk.cu         # k-WTA top-K winner selection
    sp_learn.cu        # Hebbian +inc/-dec on proximal synapses
    sp_duty.cu         # EMA duty-cycle update
    sp_boost_fused.cu  # fused mean + exp boost (GPU-side)
    tm_reset.cu        # per-step: snapshot active→prev, clear buffers
    tm_predict.cu      # per-cell: score owned segments vs prev_active_bits
    tm_activate.cu     # per-col: activate predicted cells OR burst
    tm_learn.cu        # per-cell: reinforce correctly-predicted segments
    tm_punish.cu       # per-cell: decay matching segs on inactive cols
    tm_grow.cu         # per-bursting-col: reuse matching seg OR create new,
                       #   grow synapses to prev_winners
    tm_anomaly.cu      # per-step: unpredicted/active ratio
```

### Persistent SP state (per region, unchanged from Phase 1)
At n_cols=2048, S=40, bits=16384: ~355 KB persistent + ~90 KB transient.

### Persistent TM state (per region)

Capacity knobs (configured in `tm_gpu.rs`):
- `MAX_SEGMENTS_PER_CELL = 4`
- `MAX_SYN_PER_SEGMENT = 20`

At cells_per_col=32, n_cols=2048:
- `n_cells = 65_536`
- `n_segments_max = 262_144` (~262K)
- `n_synapses_max = 5_242_880` (~5.2M)

| Buffer                 | Shape / type      | Notes                                  |
|------------------------|-------------------|----------------------------------------|
| `seg_cell_id`          | (n_segs,) u32     | owning cell; U32_MAX = unused          |
| `seg_syn_count`        | (n_segs,) u32     | #active synapses in slot               |
| `syn_presyn`           | (n_segs × S,) u32 | presynaptic cell indices               |
| `syn_perm`             | (n_segs × S,) i16 | permanence scaled 0..32767 (0.0..1.0)  |
| `cell_seg_count`       | (n_cells,) u32    | segments allocated on each cell        |
| `cell_active_bits`     | (n_cells/32,) u32 | packed bitset, current step            |
| `cell_winner_bits`     | (n_cells/32,) u32 | packed bitset, current step            |
| `cell_predictive_bits` | (n_cells/32,) u32 | set by predict, read by activate       |
| `prev_active_bits`     | (n_cells/32,) u32 | snapshot at step start                 |
| `prev_winner_bits`     | (n_cells/32,) u32 | snapshot at step start                 |
| `col_predicted`        | (n_cols,) u8      | set if any cell in col is predictive   |
| `col_best_match`       | (n_cols,) u32     | packed (pot<<21 \| seg_id), atomicMax  |
| `seg_num_active_conn`  | (n_segs,) u32     | output of predict                      |
| `seg_num_active_pot`   | (n_segs,) u32     | output of predict                      |
| `unpredicted_count`    | (1,) u32          | atomic counter for anomaly             |
| `burst_cols_flat`      | (n_cols,) u32     | list of bursting cols                  |
| `burst_cols_count`     | (1,) u32          | length of above list                   |

**Total per TM region: ~42 MB.** Batch of 8 regions: ~340 MB. Fits 6 GB RTX 3060.

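The total is dominated by the two synapse arrays; a back-of-envelope tally of
the large terms from the shapes above (arithmetic, not a measured breakdown):

    syn_presyn       5_242_880 × 4 B ≈ 21.0 MB
    syn_perm         5_242_880 × 2 B ≈ 10.5 MB
    seg_* u32 arrays 4 × 262_144 × 4 B ≈ 4.2 MB
    bitsets + per-column buffers: < 1 MB combined

These account for roughly 36 MB; the rest of the ~42 MB figure is presumably
the smaller buffers plus allocator padding.
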
### Per-step pipeline (single iteration of `step_batch_with_tm`)

```
SP side
---------
1. D2D input slice → inp_dev
2. sp_overlap       (n_cols blocks)
3. sp_topk          (1 block)
4. sp_learn         (n_cols blocks)
5. sp_duty          (n_cols/256 blocks)
6. sp_boost_fused   (1 block)
7. D2D active_mask → cols_dev[ti]

TM side
---------
8.  tm_reset_step   (ceil(n_cells/32/256))
9.  tm_predict      (n_cells blocks × 32 thr)
10. tm_activate     (n_cols/256 blocks)
11. tm_anomaly      (1 block)
if learn:
12. tm_learn        (n_cells blocks)
13. tm_punish       (n_cells blocks)
14. tm_grow         (n_cols blocks — early-exits)
```

No host sync in the T-step loop. At the end one `dtoh_sync_copy` each for
`cols_dev` (T × n_cols bytes) and `anom_dev` (T × f32).

## Parity

### SP: strict bit-identical
See Phase 1 docs — `gpu_sp_matches_cpu_with_learn` over 50 steps passes exact.

### TM: relaxed-parity
The GPU TM has known, deliberate deviations from CPU to admit massive parallelism:

1. **Bursting winner cell**: CPU picks the least-used cell (fewest segments) with
   random tiebreak. GPU picks cell 0 of the column (deterministic, branch-free).
   Learning dynamics are preserved because segment creation/reinforcement is
   the dominant effect, not which specific cell in a bursting column wins.

2. **Permanence storage**: i16 fixed-point (scale 32767) vs f32. Rounding
   differs by <=1 ULP of the scale (~3.0e-5), below any meaningful learning
   quantum (inc=0.10, dec=0.10, predicted_segment_dec=0.10). See the sketch
   after this list.

3. **Grown synapse candidate order**: CPU randomly samples from prev_winner_cells.
   GPU iterates prev_winner_bits words in a pseudo-random rotated order keyed
   by (bursting_col_idx, iter_seed). Output is a different subset but same size.

4. **Segment LRU eviction**: CPU tracks `last_used_iteration` per segment.
   GPU wraps around (slot = count % max_segments_per_cell). In the autoresearch
   loop where TM resets every forward, eviction rarely triggers.

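A hypothetical helper illustrating the i16 update in item 2 (the storage
scheme is as documented; the helper itself and its name are illustrative):

```cuda
// Permanence lives in i16 with 32767 == 1.0. Apply a float delta,
// round to nearest, and clamp to [0, 32767] (i.e. [0.0, 1.0]).
__device__ short perm_update_i16(short perm, float delta) {
    int scaled = (int)perm + __float2int_rn(delta * 32767.0f);
    scaled = max(0, min(32767, scaled));
    return (short)scaled;
}
```

At inc = dec = 0.10, one learning step moves the permanence by about 3277
fixed-point units, so the <=1-unit rounding difference noted above is
negligible.
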
The GPU parity test (`gpu_tm_anomaly_decays_on_repeating_sequence`) feeds a
repeating A,B,C sequence and asserts anomaly decays: **1.000 early → 0.000 late**.

## Bottleneck Analysis

| Source                            | Cost/step (B=8 T=2048) |
|-----------------------------------|-----------------------:|
| 14 kernel launches                |                 ~70 μs |
| ~262K predict/learn/punish blocks |                ~2.5 ms |
| No D2H until end-of-batch         |                   0 μs |
| Final D2H (T × n_cols + T × f32)  |     ~200 μs per region |

Per-step wall time at B=8 T=2048:
- CPU (reference): **~11.4 ms / step**
- GPU (current): **~2.98 ms / step**
- **Speedup: 3.83x**

## End-to-End Training Benchmark

**Config**: B=8, T=2048, vocab=8192, 60-second time budget, full HYDRA stack
(SDR Semantic + HTM + Mamba-3 + Engram + mHC + Hestia QAT).

**Results**:
- GPU util: **97-98% sustained**
- VRAM: **5.4 GB / 6.0 GB** (90% utilisation)
- Steps completed: 16
- tok/sec: **~2,200-2,500** (stable post-warmup)
- Final val_bpb: **2.249** (from ~3.1 initial)
- Factual eval: 1/9 hits

Compared to previous CPU-HTM baseline (~100 tok/s), the full-GPU HTM delivers
**~22x end-to-end throughput** — far above the 3-10x target.

## Bench Commands

```bash
source .venv/bin/activate
export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH

# Microbench
B=8 T=2048 python htm_rust/bench_gpu.py

# Full training
HYDRA_TIME_BUDGET=60 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=32768 python -u train.py
```

## Known Limitations / Future Work

- **Segment-compacted launches**: predict/learn/punish iterate all n_cells
  blocks, using `cell_seg_count` to skip empty cells. A compacted live-cell
  list would shave another ~40% of launch overhead.
- **Winner selection**: currently cell 0 of bursting col. Proper least-used
  selection would help stability of cross-column patterns.
- **Single CUDA stream per region**: with B=8 regions we serialise on stream 0.
  Multi-stream would lift the ~20% launch overhead at small batch sizes.
- **Permanence bump on chronically under-stimulated columns**: SP's strict-parity
  bump is not mirrored on GPU fast path. Effect on long runs needs measurement.
- **`seg_num_active_conn` output is reused across reinforce + punish**: the two
  kernels each launch n_cells blocks. They could be fused into one for one fewer
  kernel launch per step.

## Files

- `htm_rust/build.rs` — nvcc-driven PTX compilation, 12 kernels.
- `htm_rust/Cargo.toml` — `gpu` feature flag, cudarc dep.
- `htm_rust/src/gpu/mod.rs` — `HTMRegionGpu` pyclass + `step_many_gpu`.
- `htm_rust/src/gpu/sp_gpu.rs` — SP state + `step_batch_with_tm`.
- `htm_rust/src/gpu/tm_gpu.rs` — TM state + `step`.
- `htm_rust/src/gpu/tests.rs` — parity + correctness tests.
- `htm_rust/src/gpu/kernels/*.cu` — 5 SP + 7 TM kernels.
- `htm_rust/bench_gpu.py` — CPU-vs-GPU microbench.
- `subsystems/htm.py` — transparent GPU/CPU backend selection in `HTMLayer`.
overlay/htm_rust/src/gpu/fused.rs CHANGED
@@ -20,15 +20,15 @@
 use std::ffi::CString;
 use std::sync::Arc;
 
-use cudarc::driver::{result, sys, CudaDevice, CudaSlice, DeviceRepr, DevicePtr, DriverError,
-                     LaunchConfig};
+use cudarc::driver::{
+    result, sys, CudaDevice, CudaSlice, DevicePtr, DeviceRepr, DriverError, LaunchConfig,
+};
 use cudarc::nvrtc::Ptx;
 
 use super::sp_gpu::SpatialPoolerGpu;
 use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
 
-const PTX_HTM_FUSED: &str =
-    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
+const PTX_HTM_FUSED: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
 
 /// Struct-by-value pointer pack — matches C-side `FusedPtrs`.
 ///
@@ -132,11 +132,9 @@ pub(crate) fn plan_fused_launch(
     grid_cap_override: Option<u32>,
 ) -> Result<FusedLaunchPlan, String> {
     let sm_count = sm_count.max(1);
-    // 1024 threads/block exceeds the register file on Ampere (sm_86: 65536
-    // regs/SM ÷ 1024 = 64 regs/thread; fused kernel needs ~80+). 256 gives
-    // 256 regs/thread which is ample. Compensate with more blocks via
-    // cooperative launch. On Hopper (228 KB smem, 255 regs/thread baseline),
-    // 1024 works fine, but 256 is safe everywhere.
+    // 1024 threads/block exceeds the register file on Ampere and makes the
+    // cooperative-grid residency probe lie when the launch uses a different
+    // block size. Keep the planned block size identical to the occupancy probe.
     let block_dim_x = 256u32;
 
     // Cluster launch path: cooperative launch is not required. Keep the probe
@@ -145,10 +143,11 @@
         eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
     }
 
-    // Tested grid_cap: 4 blocks = 30ms (too serial), 16 blocks = 10.8ms (parallel wins).
-    // Parallelism in SP overlap + TM predict stages outweighs grid.sync() cost.
+    // Cluster constraint: grid_dim_x must equal the cluster size (16) so that
+    // each region maps to exactly one cluster. `HTM_FUSED_GRID_CAP` can lower
+    // this for debugging but should not exceed 16 for cluster correctness.
     let default_grid_cap = 16u32;
-    let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
+    let grid_cap = grid_cap_override.unwrap_or(default_grid_cap).min(16);
     let resident_bound = if cooperative_grid_limit > 0 {
         cooperative_grid_limit.max(sm_count * 2)
     } else {
@@ -218,7 +217,7 @@ pub struct FusedState {
     pub cell_active_bits_b: CudaSlice<u32>,
     pub cell_winner_bits_a: CudaSlice<u32>,
     pub cell_winner_bits_b: CudaSlice<u32>,
-    pub step_scratch: CudaSlice<u32>, // length 6
+    pub step_scratch: CudaSlice<u32>, // length 6
 
     pub grid_dim_x: u32,
     pub block_dim_x: u32,
@@ -241,7 +240,10 @@
         initial_threshold: f32,
     ) -> Result<Self, DriverError> {
         let n_cells = n_columns * cells_per_column;
-        assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
+        assert!(
+            n_cells % 32 == 0,
+            "n_cells must be divisible by 32 for bitsets"
+        );
         let bits_words = n_cells / 32;
 
         let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
@@ -278,7 +280,8 @@
         // every launched kernel function, otherwise cuLaunchKernelEx rejects
         // the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
         unsafe {
-            let attr = sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
+            let attr =
+                sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
             // Ignore errors: older CUDA may lack the attribute, in which case
             // only portable sizes (<= 8) work — plan_fused_launch caps at 8.
             let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
@@ -294,9 +297,9 @@
         };
 
         // T1: Probe Hopper cluster launch capability.
-        let max_cluster_size = match dev.attribute(
-            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH,
-        ) {
+        let max_cluster_size = match dev
+            .attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH)
+        {
            Ok(v) if v > 0 => {
                // H200/sm_90a supports up to 16 blocks per cluster.
                // There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
@@ -346,7 +349,11 @@
 
         Ok(Self {
             dev,
-            raw_kernel: RawFusedKernel { module, function, function_batched },
+            raw_kernel: RawFusedKernel {
+                module,
+                function,
+                function_batched,
+            },
             inhibition_threshold,
             cell_active_bits_a,
             cell_active_bits_b,
@@ -445,7 +452,7 @@ pub fn launch_fused(
         inputs: *inputs_flat.device_ptr(),
         cols_out: *cols_out.device_ptr(),
        anom_out: *anom_out.device_ptr(),
-        barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
+        barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
         step_scratch: *fused.step_scratch.device_ptr(),
     };
 
@@ -493,14 +500,17 @@
         }
     } else {
         // Pre-Hopper: cooperative kernel launch. The fused kernel uses
-        // grid.sync() for cross-block synchronization which REQUIRES
-        // cuLaunchCooperativeKernel (normal launch silently crashes on
-        // the first grid.sync() call).
+        // cg::this_grid().sync(); normal launches poison the CUDA context
+        // with an asynchronous unspecified launch failure.
         let ret = sys::lib().cuLaunchCooperativeKernel(
             fused.raw_kernel.function,
-            grid_x, 1, 1,
-            block_x, 1, 1,
-            0, // sharedMemBytes
+            grid_x,
+            1,
+            1,
+            block_x,
+            1,
+            1,
+            0,
             cu_stream,
             kernel_params.as_mut_ptr(),
         );
@@ -616,7 +626,7 @@ pub(super) fn launch_fused_batched_raw(
             inputs: inputs_per_region[i],
             cols_out: cols_per_region[i],
             anom_out: anom_per_region[i],
-            barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
+            barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
             step_scratch: *r.fused_state.step_scratch.device_ptr(),
         }
     })
@@ -636,8 +646,8 @@
         let r0 = unsafe { &*region_ptrs[0] };
         r0.fused_state.cluster_info.max_cluster_size > 0
     };
-    let grid_x = plan_batched_grid_dim(grid_x, cooperative_grid_limit, b, use_cluster)
-        .map_err(|msg| {
+    let grid_x =
+        plan_batched_grid_dim(grid_x, cooperative_grid_limit, b, use_cluster).map_err(|msg| {
             eprintln!("[htm_rust] FATAL: {msg}");
             DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE)
         })?;
@@ -678,12 +688,19 @@
             return Err(DriverError(ret));
         }
     } else {
-        // Pre-Hopper: cooperative kernel launch (grid.sync() requires it).
+        // Pre-Hopper: cooperative kernel launch. The fused kernel uses
+        // cg::this_grid().sync(), which is only valid under cooperative
+        // launch. A normal launch can run until the first grid.sync() and
+        // then poison the CUDA context with an unspecified launch failure.
         let ret = sys::lib().cuLaunchCooperativeKernel(
             function_batched,
-            grid_x, b as u32, 1,
-            block_x, 1, 1,
-            0, // sharedMemBytes
+            grid_x,
+            b as u32,
+            1,
+            block_x,
+            1,
+            1,
+            0,
             cu_stream,
             kernel_params.as_mut_ptr(),
         );
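To make the launch-mode constraint concrete, a minimal standalone sketch (not
the repo's kernel; the names are illustrative) of a grid-synchronizing kernel
that is only legal under cooperative launch:

    #include <cooperative_groups.h>
    namespace cg = cooperative_groups;

    // Legal only under cuLaunchCooperativeKernel / cudaLaunchCooperativeKernel
    // with the whole grid co-resident; an ordinary launch faults at grid.sync().
    extern "C" __global__ void two_phase(const float *in, float *mid, float *out,
                                         unsigned int n) {
        cg::grid_group grid = cg::this_grid();
        unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) mid[i] = in[i] * 2.0f;             // phase A
        grid.sync();                                  // cross-block barrier
        if (i + 1 < n) out[i] = mid[i] + mid[i + 1];  // phase B reads A across blocks
    }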
overlay/htm_rust/src/gpu/kernels/sp_boost_fused.cu CHANGED
@@ -1,59 +1,59 @@
// Fused mean-reduction + boost-update kernel.
//
// Inputs:
//   active_duty[n]  (f32)
//   boost_strength  (f32)
//
// Output:
//   boost[n] (f32) = expf(-boost_strength * (active_duty[c] - mean))
//
// Launch: single block (1024 threads), shared mem for reduction. At n=2048
// each thread handles 2 elements.

extern "C" __global__
void sp_boost_from_duty(
    const float * __restrict__ active_duty, // (n,)
    float * __restrict__ boost,             // (n,) in-place out
    float boost_strength,
    unsigned int n
) {
    extern __shared__ float smem_raw[];
    float * smem = smem_raw;
    const unsigned int tid = threadIdx.x;
    const unsigned int bsz = blockDim.x;

    // Phase 1: parallel sum of active_duty into smem[0..32] (warp-level).
    float local_sum = 0.0f;
    for (unsigned int i = tid; i < n; i += bsz) {
        local_sum += active_duty[i];
    }
    // Warp reduction.
    for (int off = 16; off > 0; off >>= 1) {
        local_sum += __shfl_down_sync(0xffffffff, local_sum, off);
    }
    unsigned int lane = tid & 31;
    unsigned int warp = tid >> 5;
    if (lane == 0) smem[warp] = local_sum;
    __syncthreads();

    // Warp 0 reduces warp-sums.
    __shared__ float mean_s;
    if (warp == 0) {
        unsigned int nwarps = (bsz + 31) / 32;
        float v = (lane < nwarps) ? smem[lane] : 0.0f;
        for (int off = 16; off > 0; off >>= 1) {
            v += __shfl_down_sync(0xffffffff, v, off);
        }
        if (tid == 0) {
            mean_s = v / (float)n;
        }
    }
    __syncthreads();

    // Phase 2: boost[c] = expf(-strength * (active_duty[c] - mean)).
    float mean = mean_s;
    for (unsigned int i = tid; i < n; i += bsz) {
        float d = active_duty[i] - mean;
        boost[i] = expf(-boost_strength * d);
    }
}
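The header comment implies this host-side launch shape; one float of dynamic
shared memory per warp is enough to stage the warp sums. A hedged runtime-API
example (the repo itself launches via cudarc from Rust):

    // 1024 threads = 32 warps, so 32 floats of dynamic smem hold the warp sums.
    const unsigned int block = 1024, nwarps = block / 32;
    sp_boost_from_duty<<<1, block, nwarps * sizeof(float), stream>>>(
        d_active_duty, d_boost, boost_strength, n_columns);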
overlay/htm_rust/src/gpu/kernels/sp_duty.cu CHANGED
@@ -1,45 +1,45 @@
// Duty cycle + boost update kernel.
//
// For each column c (one thread each):
//   active_sample  = active_mask[c] ? 1 : 0
//   overlap_sample = raw_overlap[c] >= stim_thr ? 1 : 0
//   active_duty[c]  = (1-alpha) * active_duty[c]  + alpha * active_sample
//   overlap_duty[c] = (1-alpha) * overlap_duty[c] + alpha * overlap_sample
//
// Then, if learn:
//   boost[c] = exp(-boost_strength * (active_duty[c] - mean_duty))
// mean_duty is computed on the host (one reduction) and passed in.

extern "C" __global__
void sp_duty_update(
    const unsigned char * __restrict__ active_mask, // (n_columns,)
    const unsigned int  * __restrict__ raw_overlap, // (n_columns,)
    float * __restrict__ active_duty,               // (n_columns,) in-place
    float * __restrict__ overlap_duty,              // (n_columns,) in-place
    float * __restrict__ boost,                     // (n_columns,) in-place
    float alpha,
    float stim_thr,
    float boost_strength,                           // 0 to skip boost
    float mean_duty,
    unsigned int learn_flag,                        // 0 or 1
    unsigned int n_columns
) {
    unsigned int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c >= n_columns) return;

    float ad = active_duty[c];
    float od = overlap_duty[c];

    float a_sample = (active_mask[c] != 0) ? 1.0f : 0.0f;
    float o_sample = ((float)raw_overlap[c] >= stim_thr) ? 1.0f : 0.0f;

    ad = (1.0f - alpha) * ad + alpha * a_sample;
    od = (1.0f - alpha) * od + alpha * o_sample;

    active_duty[c] = ad;
    overlap_duty[c] = od;

    if (learn_flag && boost_strength > 0.0f) {
        boost[c] = expf(-boost_strength * (ad - mean_duty));
    }
}
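The mean passed as `mean_duty` comes from one reduction before the launch. A
sketch of the implied call sequence; Thrust here is an assumption for brevity
(the repo drives this from Rust via cudarc instead):

    #include <thrust/device_ptr.h>
    #include <thrust/reduce.h>

    // Reduce active_duty once on the device, then launch one thread per column.
    void run_duty_update(const unsigned char *d_active_mask,
                         const unsigned int *d_raw_overlap,
                         float *d_active_duty, float *d_overlap_duty, float *d_boost,
                         float alpha, float stim_thr, float boost_strength,
                         unsigned int learn, unsigned int n_cols) {
        thrust::device_ptr<float> duty(d_active_duty);
        float mean_duty = thrust::reduce(duty, duty + n_cols, 0.0f) / (float)n_cols;
        unsigned int blocks = (n_cols + 255) / 256;  // 256 threads/block covers all columns
        sp_duty_update<<<blocks, 256>>>(d_active_mask, d_raw_overlap, d_active_duty,
                                        d_overlap_duty, d_boost, alpha, stim_thr,
                                        boost_strength, mean_duty, learn, n_cols);
    }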
overlay/htm_rust/src/gpu/kernels/sp_learn.cu CHANGED
@@ -1,45 +1,45 @@
// SP Hebbian learning kernel.
//
// For each active (winner) column c, for each of its synapses s:
//   if input[bit[c][s]] active: perm += inc
//   else:                       perm -= dec
// Clamp to [0, 1].
//
// Launch: one block per column (2048 blocks), but we predicate on
// active_mask[c] to avoid launching k-specific blocks.
//
// This matches the CPU reference line-for-line:
// src/sp.rs lines 157-169.

extern "C" __global__
void sp_learn(
    const unsigned char * __restrict__ active_mask, // (n_columns,) 0/1
    const unsigned char * __restrict__ inp,         // (input_bits,)
    const unsigned int  * __restrict__ syn_bit,     // (n_columns * S,)
    float * __restrict__ syn_perm,                  // (n_columns * S,) in-place
    float inc,
    float dec,
    unsigned int synapses_per_col,
    unsigned int n_columns
) {
    const unsigned int c = blockIdx.x;
    if (c >= n_columns) return;
    if (active_mask[c] == 0) return;

    const unsigned int base = c * synapses_per_col;
    const unsigned int tid = threadIdx.x;
    const unsigned int bsz = blockDim.x;

    for (unsigned int s = tid; s < synapses_per_col; s += bsz) {
        unsigned int b = syn_bit[base + s];
        float p = syn_perm[base + s];
        if (inp[b] != 0) {
            p += inc;
            if (p > 1.0f) p = 1.0f;
        } else {
            p -= dec;
            if (p < 0.0f) p = 0.0f;
        }
        syn_perm[base + s] = p;
    }
}
overlay/htm_rust/src/gpu/kernels/sp_overlap.cu CHANGED
@@ -1,78 +1,78 @@
1
- // SP overlap kernel.
2
- //
3
- // For each column c (one CUDA block), compute:
4
- // overlap[c] = sum over its synapse list of {inp[bit[c][s]] && perm[c][s] >= conn_thr}
5
- // boosted[c] = overlap[c] * boost[c]
6
- // raw_overlap[c] = overlap[c] (also returned so host can drive duty cycle)
7
- //
8
- // Memory layout (flat, column-major with per-column stride = synapses_per_col):
9
- // syn_bit[c * S + s] : u32 index into input SDR
10
- // syn_perm[c * S + s] : f32 permanence in [0, 1]
11
- // boost[c] : f32
12
- // inp[b] : u8 0/1
13
- // Output:
14
- // raw[c] : u32
15
- // boosted[c] : f32
16
- //
17
- // Launch:
18
- // grid = n_columns
19
- // block = 128 (or 256) β€” one warp-sweep across synapses; many warps give
20
- // parallel reduction across S (typically S=40).
21
- //
22
- // At S=40 this is completely latency-bound; we coalesce reads and do a
23
- // warp-shuffle reduction. For clarity we use a simple block-wide shared-mem
24
- // reduction which is sufficient for S <= 1024 and has zero correctness risk.
25
-
26
- extern "C" __global__
27
- void sp_overlap(
28
- const unsigned char * __restrict__ inp, // (input_bits,)
29
- const unsigned int * __restrict__ syn_bit, // (n_columns * S,)
30
- const float * __restrict__ syn_perm,// (n_columns * S,)
31
- const float * __restrict__ boost, // (n_columns,)
32
- float conn_thr,
33
- unsigned int synapses_per_col, // S
34
- unsigned int n_columns,
35
- unsigned int * __restrict__ raw_out, // (n_columns,)
36
- float * __restrict__ boosted_out // (n_columns,)
37
- ) {
38
- const unsigned int c = blockIdx.x;
39
- if (c >= n_columns) return;
40
-
41
- const unsigned int base = c * synapses_per_col;
42
- const unsigned int tid = threadIdx.x;
43
- const unsigned int bsz = blockDim.x;
44
-
45
- // Per-thread partial count.
46
- unsigned int local = 0;
47
- for (unsigned int s = tid; s < synapses_per_col; s += bsz) {
48
- unsigned int b = syn_bit[base + s];
49
- float p = syn_perm[base + s];
50
- // Branchless: only counts when input active AND perm connected.
51
- // Using (inp != 0) to tolerate u8 layout.
52
- unsigned int hit = ((inp[b] != 0) && (p >= conn_thr)) ? 1u : 0u;
53
- local += hit;
54
- }
55
-
56
- // Block-wide reduction in shared memory.
57
- __shared__ unsigned int smem[32];
58
-
59
- // Warp-level reduction via shuffle.
60
- unsigned int lane = tid & 31;
61
- unsigned int warp = tid >> 5;
62
- for (int off = 16; off > 0; off >>= 1) {
63
- local += __shfl_down_sync(0xffffffff, local, off);
64
- }
65
- if (lane == 0) smem[warp] = local;
66
- __syncthreads();
67
-
68
- if (warp == 0) {
69
- unsigned int v = (tid < (bsz + 31) / 32) ? smem[lane] : 0;
70
- for (int off = 16; off > 0; off >>= 1) {
71
- v += __shfl_down_sync(0xffffffff, v, off);
72
- }
73
- if (tid == 0) {
74
- raw_out[c] = v;
75
- boosted_out[c] = (float)v * boost[c];
76
- }
77
- }
78
- }
 
1
// SP overlap kernel.
//
// For each column c (one CUDA block), compute:
//   overlap[c]     = sum over its synapse list of {inp[bit[c][s]] && perm[c][s] >= conn_thr}
//   boosted[c]     = overlap[c] * boost[c]
//   raw_overlap[c] = overlap[c]  (also returned so host can drive duty cycle)
//
// Memory layout (flat; each column owns a contiguous run of synapses_per_col entries):
//   syn_bit[c * S + s]  : u32 index into input SDR
//   syn_perm[c * S + s] : f32 permanence in [0, 1]
//   boost[c]            : f32
//   inp[b]              : u8 0/1
// Output:
//   raw[c]     : u32
//   boosted[c] : f32
//
// Launch:
//   grid  = n_columns
//   block = 128 (or 256) — one warp-sweep across synapses; many warps give
//   parallel reduction across S (typically S=40).
//
// At S=40 this is completely latency-bound; we coalesce reads and do a
// warp-shuffle reduction. For clarity we use a simple block-wide shared-mem
// reduction which is sufficient for S <= 1024 and has zero correctness risk.

extern "C" __global__
void sp_overlap(
    const unsigned char * __restrict__ inp,       // (input_bits,)
    const unsigned int  * __restrict__ syn_bit,   // (n_columns * S,)
    const float         * __restrict__ syn_perm,  // (n_columns * S,)
    const float         * __restrict__ boost,     // (n_columns,)
    float conn_thr,
    unsigned int synapses_per_col,                // S
    unsigned int n_columns,
    unsigned int * __restrict__ raw_out,          // (n_columns,)
    float * __restrict__ boosted_out              // (n_columns,)
) {
    const unsigned int c = blockIdx.x;
    if (c >= n_columns) return;

    const unsigned int base = c * synapses_per_col;
    const unsigned int tid = threadIdx.x;
    const unsigned int bsz = blockDim.x;

    // Per-thread partial count.
    unsigned int local = 0;
    for (unsigned int s = tid; s < synapses_per_col; s += bsz) {
        unsigned int b = syn_bit[base + s];
        float p = syn_perm[base + s];
        // Branchless: only counts when input active AND perm connected.
        // Using (inp != 0) to tolerate u8 layout.
        unsigned int hit = ((inp[b] != 0) && (p >= conn_thr)) ? 1u : 0u;
        local += hit;
    }

    // Block-wide reduction in shared memory.
    __shared__ unsigned int smem[32];

    // Warp-level reduction via shuffle.
    unsigned int lane = tid & 31;
    unsigned int warp = tid >> 5;
    for (int off = 16; off > 0; off >>= 1) {
        local += __shfl_down_sync(0xffffffff, local, off);
    }
    if (lane == 0) smem[warp] = local;
    __syncthreads();

    if (warp == 0) {
        unsigned int v = (tid < (bsz + 31) / 32) ? smem[lane] : 0;
        for (int off = 16; off > 0; off >>= 1) {
            v += __shfl_down_sync(0xffffffff, v, off);
        }
        if (tid == 0) {
            raw_out[c] = v;
            boosted_out[c] = (float)v * boost[c];
        }
    }
}
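// The shuffle idiom above, factored out for reference (a sketch): offsets
// 16, 8, 4, 2, 1 fold a 32-lane warp in five steps, so after the loop lane 0
// holds the sum of all 32 lanes.
static __device__ inline unsigned int warp_reduce_sum(unsigned int v) {
    for (int off = 16; off > 0; off >>= 1) {
        v += __shfl_down_sync(0xffffffffu, v, off);
    }
    return v;  // valid in lane 0 only
}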
overlay/htm_rust/src/gpu/kernels/sp_topk.cu CHANGED
@@ -1,117 +1,117 @@
// Top-K column selection.
//
// Inputs:
//   boosted[n_columns]     : f32 score
// Output:
//   active_mask[n_columns] : u8 0/1, exactly k ones
//
// Tie-breaking: when scores are equal, the LOWER column index wins (matches
// CPU reference `select_nth_unstable_by` with secondary index comparator).
//
// Strategy: a single-block implementation. n_columns is typically 2048, which
// fits comfortably in shared memory. A radix-select over a sortable 64-bit key
// would also work at this size:
//
//   key = (float_to_sortable_u32(score) << 32) | (0xffffffff - index)
//
// Larger key = (higher score) OR (same score, smaller index). Find the k-th
// largest key via radix-select, then mark every column whose key >= that
// threshold. That is O(n_cols * log k) and well under 100 μs for n=2048,
// k=41 on sm_86.
//
// For simplicity and correctness this kernel instead uses a single-block
// parallel selection variant (find max -> mark -> set to -inf -> repeat,
// k iterations). At k=41 this is 41 passes over 2048 scores, ~84K ops,
// trivially fast.

extern "C" __global__
void sp_topk_select(
    const float * __restrict__ scores,        // (n_columns,)
    unsigned int n_columns,
    unsigned int k,
    unsigned char * __restrict__ active_out   // (n_columns,)
) {
    extern __shared__ float smem[];
    // smem[0..n) holds the working scores; selected entries are overwritten
    // with -inf. Reduction scratch lives in the static __shared__ arrays below.
    float * work = smem;
    const unsigned int tid = threadIdx.x;
    const unsigned int bsz = blockDim.x;

    // Load scores into shared; also init active_out = 0.
    for (unsigned int i = tid; i < n_columns; i += bsz) {
        work[i] = scores[i];
        active_out[i] = 0;
    }
    __syncthreads();

    __shared__ int winner_idx;
    __shared__ float winner_score;

    for (unsigned int iter = 0; iter < k; ++iter) {
        // Find (argmax score, lowest index for ties).
        float best_s = -INFINITY;
        int best_i = n_columns;  // sentinel larger than any index

        for (unsigned int i = tid; i < n_columns; i += bsz) {
            float s = work[i];
            if (s > best_s || (s == best_s && (int)i < best_i)) {
                best_s = s;
                best_i = (int)i;
            }
        }

        // Warp reduction. We reduce pairs (score, idx) keeping (max score, min idx on tie).
        unsigned int mask = 0xffffffff;
        for (int off = 16; off > 0; off >>= 1) {
            float os = __shfl_down_sync(mask, best_s, off);
            int   oi = __shfl_down_sync(mask, best_i, off);
            if (os > best_s || (os == best_s && oi < best_i)) {
                best_s = os;
                best_i = oi;
            }
        }
        // Warp 0 collects lane 0 values from other warps via shared mem.
        __shared__ float warp_s[32];
        __shared__ int   warp_i[32];
        unsigned int lane = tid & 31;
        unsigned int warp = tid >> 5;
        if (lane == 0) {
            warp_s[warp] = best_s;
            warp_i[warp] = best_i;
        }
        __syncthreads();

        if (warp == 0) {
            unsigned int nwarps = (bsz + 31) / 32;
            float s = (lane < nwarps) ? warp_s[lane] : -INFINITY;
            int   i = (lane < nwarps) ? warp_i[lane] : (int)n_columns;
            for (int off = 16; off > 0; off >>= 1) {
                float os = __shfl_down_sync(mask, s, off);
                int   oi = __shfl_down_sync(mask, i, off);
                if (os > s || (os == s && oi < i)) {
                    s = os;
                    i = oi;
                }
            }
            if (tid == 0) {
                winner_score = s;
                winner_idx = i;
            }
        }
        __syncthreads();

        if (tid == 0) {
            if (winner_idx < (int)n_columns) {
                active_out[winner_idx] = 1;
                work[winner_idx] = -INFINITY;
            }
        }
        __syncthreads();
    }
}
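// The radix-select alternative sketched in the header needs an
// order-preserving float-to-uint mapping; a standard one (a sketch -- not
// used by the kernel above; host callers need <string.h> for memcpy) is:
static __host__ __device__ inline unsigned int float_to_sortable_u32(float f) {
    unsigned int u;
    memcpy(&u, &f, sizeof u);  // bit-copy avoids type-punning UB
    // Negative floats: flip all bits; non-negative: set the sign bit.
    return (u & 0x80000000u) ? ~u : (u | 0x80000000u);
}
// After this mapping, unsigned comparison of keys agrees with float
// comparison of the scores (NaNs excluded).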
overlay/htm_rust/src/gpu/kernels/tm_activate.cu CHANGED
@@ -1,66 +1,66 @@
// TM activate kernel. See tm_predict.cu for TmConfig.

struct TmConfig {
    unsigned int activation_threshold;
    unsigned int learning_threshold;
    unsigned int cells_per_column;
    unsigned int synapses_per_segment;
    unsigned int n_segments;
    unsigned int n_cells;
    unsigned int max_segments_per_cell;
    unsigned int max_new_synapses;
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;
    unsigned int iter_seed;
    unsigned int n_cols;
    unsigned int bits_words;
};

extern "C" __global__
void tm_activate(
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned char * __restrict__ col_predicted,
    const unsigned int  * __restrict__ cell_predictive_bits,
    unsigned int * __restrict__ cell_active_bits,
    unsigned int * __restrict__ cell_winner_bits,
    unsigned int * __restrict__ unpredicted_count,
    unsigned int * __restrict__ burst_cols_flat,
    unsigned int * __restrict__ burst_cols_count,
    TmConfig cfg
) {
    unsigned int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= cfg.n_cols) return;
    if (sp_active_mask[col] == 0) return;

    unsigned int base_cell = col * cfg.cells_per_column;

    if (col_predicted[col]) {
        for (unsigned int k = 0; k < cfg.cells_per_column; k++) {
            unsigned int cell = base_cell + k;
            unsigned int word_idx = cell >> 5;
            unsigned int bit_mask = 1u << (cell & 31u);
            unsigned int pred_word = cell_predictive_bits[word_idx];
            if (pred_word & bit_mask) {
                atomicOr(&cell_active_bits[word_idx], bit_mask);
                atomicOr(&cell_winner_bits[word_idx], bit_mask);
            }
        }
    } else {
        atomicAdd(unpredicted_count, 1u);
        for (unsigned int k = 0; k < cfg.cells_per_column; k++) {
            unsigned int cell = base_cell + k;
            unsigned int word_idx = cell >> 5;
            unsigned int bit_mask = 1u << (cell & 31u);
            atomicOr(&cell_active_bits[word_idx], bit_mask);
        }
        unsigned int winner = base_cell;
        unsigned int word_idx = winner >> 5;
        unsigned int bit_mask = 1u << (winner & 31u);
        atomicOr(&cell_winner_bits[word_idx], bit_mask);
        unsigned int slot = atomicAdd(burst_cols_count, 1u);
        burst_cols_flat[slot] = col;
    }
}
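// Bit-packed mask indexing, worked through for orientation: cell 75 lives in
// word 75 >> 5 = 2 at bit 75 & 31 = 11, so setting it is
// atomicOr(&cell_active_bits[2], 1u << 11). atomicOr is required because
// threads for different columns may touch the same 32-cell word concurrently.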
overlay/htm_rust/src/gpu/kernels/tm_anomaly.cu CHANGED
@@ -1,43 +1,43 @@
// TM anomaly kernel.
//
// Computes:
//   n_active = sum of sp_active_mask
//   anomaly  = unpredicted_count / n_active  (if n_active > 0)
//            = 0                             (else)
//
// Launch: single block, 256 threads.

extern "C" __global__
void tm_anomaly(
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned int  * __restrict__ unpredicted_count,
    float * __restrict__ anomaly_out,  // (1,) or (T,); written at index t_slot
    unsigned int t_slot,
    unsigned int n_cols
) {
    const unsigned int tid = threadIdx.x;
    __shared__ unsigned int n_active_s;

    if (tid == 0) n_active_s = 0u;
    __syncthreads();

    unsigned int local = 0u;
    for (unsigned int i = tid; i < n_cols; i += blockDim.x) {
        if (sp_active_mask[i]) local += 1u;
    }
    // Warp reduce.
    for (int off = 16; off > 0; off >>= 1) {
        local += __shfl_down_sync(0xffffffffu, local, off);
    }
    if ((tid & 31u) == 0) {
        atomicAdd(&n_active_s, local);
    }
    __syncthreads();

    if (tid == 0) {
        unsigned int total = n_active_s;
        unsigned int bad = unpredicted_count[0];
        float anom = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
        anomaly_out[t_slot] = anom;
    }
}
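// Worked example (illustrative numbers): with 41 SP-active columns of which
// 12 went unpredicted, anomaly = 12 / 41 ≈ 0.293. A fully predicted step
// yields 0.0; a fully bursting step yields 1.0.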
overlay/htm_rust/src/gpu/kernels/tm_grow.cu CHANGED
@@ -1,155 +1,155 @@
// TM grow+reinforce kernel.
//
// For each bursting column:
//   If col_best_match[col] is non-zero (i.e. at least one matching segment
//   with num_active_potential >= learning_threshold exists on cells in this col):
//     Target = that matching segment.
//     Reinforce its existing synapses: +inc if presyn in prev_active, -dec otherwise.
//     Grow up to min(max_new, segment capacity - current_syn_count) additional
//     synapses to prev_winners.
//   Else:
//     Allocate a fresh segment slot on winner cell (cell 0 of col).
//     Grow up to max_new synapses to prev_winners (no reinforce needed — new seg).
//
// This mirrors the CPU TM burst logic.

struct TmConfig {
    unsigned int activation_threshold;
    unsigned int learning_threshold;
    unsigned int cells_per_column;
    unsigned int synapses_per_segment;
    unsigned int n_segments;
    unsigned int n_cells;
    unsigned int max_segments_per_cell;
    unsigned int max_new_synapses;
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;
    unsigned int iter_seed;
    unsigned int n_cols;
    unsigned int bits_words;
};

extern "C" __global__
void tm_grow(
    unsigned int * __restrict__ seg_cell_id,
    unsigned int * __restrict__ seg_syn_count,
    unsigned int * __restrict__ syn_presyn,
    short        * __restrict__ syn_perm,
    unsigned int * __restrict__ cell_seg_count,
    const unsigned int * __restrict__ burst_cols_flat,
    const unsigned int * __restrict__ burst_cols_count,
    const unsigned int * __restrict__ prev_winner_bits,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned int * __restrict__ col_best_match,
    TmConfig cfg
) {
    const unsigned int b = blockIdx.x;
    const unsigned int n_burst_cols = burst_cols_count[0];
    if (b >= n_burst_cols) return;
    const unsigned int tid = threadIdx.x;

    const unsigned int col = burst_cols_flat[b];

    __shared__ unsigned int shared_seg_id;
    __shared__ unsigned int shared_existing_syn_count;
    __shared__ unsigned int shared_grown;
    __shared__ unsigned int shared_is_new;
    __shared__ unsigned int shared_start_offset;

    if (tid == 0) {
        unsigned int match_key = col_best_match[col];
        if (match_key != 0u) {
            // Reuse matching segment.
            unsigned int seg_id = match_key & 0x1FFFFFu;
            shared_seg_id = seg_id;
            shared_existing_syn_count = seg_syn_count[seg_id];
            shared_is_new = 0u;
        } else {
            // Allocate new segment on winner cell (cell 0 of col).
            unsigned int winner_cell = col * cfg.cells_per_column;
            unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
            if (slot >= cfg.max_segments_per_cell) {
                slot = slot % cfg.max_segments_per_cell;
            }
            unsigned int seg_id = winner_cell * cfg.max_segments_per_cell + slot;
            seg_cell_id[seg_id] = winner_cell;
            seg_syn_count[seg_id] = 0;
            shared_seg_id = seg_id;
            shared_existing_syn_count = 0u;
            shared_is_new = 1u;
        }
        shared_grown = 0u;
        shared_start_offset = (b * 2654435761u + cfg.iter_seed) % cfg.bits_words;
    }
    __syncthreads();

    const unsigned int seg_id = shared_seg_id;
    const unsigned int seg_base = seg_id * cfg.synapses_per_segment;
    const unsigned int existing_syn = shared_existing_syn_count;
    const unsigned int is_new = shared_is_new;
    const unsigned int start = shared_start_offset;

    // PHASE 1: If reusing, reinforce existing synapses.
    if (!is_new) {
        for (unsigned int s = tid; s < existing_syn; s += 32u) {
            unsigned int presyn = syn_presyn[seg_base + s];
            unsigned int word = prev_active_bits[presyn >> 5];
            unsigned int bit = (word >> (presyn & 31u)) & 1u;
            int p = (int)syn_perm[seg_base + s];
            if (bit) {
                int np = p + cfg.perm_inc_i16;
                if (np > 32767) np = 32767;
                syn_perm[seg_base + s] = (short)np;
            } else {
                int np = p - cfg.perm_dec_i16;
                if (np < 0) np = 0;
                syn_perm[seg_base + s] = (short)np;
            }
        }
        __syncthreads();
    }

    // PHASE 2: Grow up to `max_new_synapses` (or room) synapses to prev_winners
    // that aren't already presynaptic to this segment.
    const unsigned int room = (cfg.synapses_per_segment > existing_syn)
        ? (cfg.synapses_per_segment - existing_syn) : 0u;
    const unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;

    for (unsigned int w_off = 0; w_off < cfg.bits_words; w_off += 32u) {
        if (shared_grown >= max_grow) break;
        unsigned int widx = (start + w_off + tid) % cfg.bits_words;
        unsigned int word = prev_winner_bits[widx];
        while (word != 0u) {
            if (shared_grown >= max_grow) break;
            unsigned int bit_pos = __ffs(word) - 1u;
            word &= ~(1u << bit_pos);
            unsigned int cell = widx * 32u + bit_pos;
            if (cell >= cfg.n_cells) continue;

            // Skip if already presynaptic (O(existing_syn) scan; usually small).
            bool exists = false;
            for (unsigned int s = 0; s < existing_syn; s++) {
                if (syn_presyn[seg_base + s] == cell) { exists = true; break; }
            }
            if (exists) continue;

            unsigned int slot = atomicAdd(&shared_grown, 1u);
            if (slot >= max_grow) break;
            unsigned int write_idx = existing_syn + slot;
            if (write_idx >= cfg.synapses_per_segment) break;
            syn_presyn[seg_base + write_idx] = cell;
            syn_perm[seg_base + write_idx] = (short)cfg.initial_perm_i16;
        }
    }
    __syncthreads();

    if (tid == 0) {
        unsigned int grown = shared_grown;
        if (grown > max_grow) grown = max_grow;
        unsigned int new_count = existing_syn + grown;
        if (new_count > cfg.synapses_per_segment) new_count = cfg.synapses_per_segment;
        seg_syn_count[seg_id] = new_count;
    }
}
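// Note on the start offset: (b * 2654435761u + iter_seed) % bits_words uses a
// common multiplicative-hash constant (close to 2^32 / golden ratio) to
// scatter where each bursting column starts scanning prev_winner_bits, so
// concurrent blocks don't all grow synapses onto the same low-indexed winner
// cells.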
overlay/htm_rust/src/gpu/kernels/tm_learn.cu CHANGED
@@ -1,75 +1,75 @@
// TM learn (reinforce correctly predicted segments) — cell-grouped launch.
//
// Grid: n_cells.
// For each cell in a predicted, SP-active column: iterate its segments.
// For each segment with num_active_connected >= activation_threshold,
// reinforce its synapses against prev_active_bits.

struct TmConfig {
    unsigned int activation_threshold;
    unsigned int learning_threshold;
    unsigned int cells_per_column;
    unsigned int synapses_per_segment;
    unsigned int n_segments;
    unsigned int n_cells;
    unsigned int max_segments_per_cell;
    unsigned int max_new_synapses;
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;
    unsigned int iter_seed;
    unsigned int n_cols;
    unsigned int bits_words;
};

extern "C" __global__
void tm_learn_reinforce(
    const unsigned int * __restrict__ seg_cell_id,
    const unsigned int * __restrict__ seg_syn_count,
    const unsigned int * __restrict__ syn_presyn,
    short * __restrict__ syn_perm,
    const unsigned int * __restrict__ seg_num_active_connected,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned char * __restrict__ col_predicted,
    const unsigned int * __restrict__ cell_seg_count,
    TmConfig cfg
) {
    const unsigned int cell = blockIdx.x;
    if (cell >= cfg.n_cells) return;
    const unsigned int col = cell / cfg.cells_per_column;
    if (sp_active_mask[col] == 0) return;
    if (col_predicted[col] == 0) return;

    const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
    if (n_segs_here == 0) return;

    const unsigned int tid = threadIdx.x;
    const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;

    for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
        const unsigned int seg = seg_base_id + local_seg;
        if (seg_num_active_connected[seg] < cfg.activation_threshold) continue;
        const unsigned int n_syn = seg_syn_count[seg];
        if (n_syn == 0) continue;
        const unsigned int syn_base = seg * cfg.synapses_per_segment;

        for (unsigned int s = tid; s < n_syn; s += 32u) {
            unsigned int presyn = syn_presyn[syn_base + s];
            unsigned int word = prev_active_bits[presyn >> 5];
            unsigned int bit = (word >> (presyn & 31u)) & 1u;
            int p = (int)syn_perm[syn_base + s];
            if (bit) {
                int np = p + cfg.perm_inc_i16;
                if (np > 32767) np = 32767;
                syn_perm[syn_base + s] = (short)np;
            } else {
                int np = p - cfg.perm_dec_i16;
                if (np < 0) np = 0;
                syn_perm[syn_base + s] = (short)np;
            }
        }
    }
}
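// Permanences on the GPU are i16 in [0, 32767] (see gpu/mod.rs). A minimal
// conversion sketch to/from the CPU's f32 representation, assuming that scale:
static __host__ __device__ inline float perm_to_f32(short p) {
    return (float)p / 32767.0f;            // one i16 step ≈ 3e-5 in f32 units
}
static __host__ __device__ inline short perm_from_f32(float f) {
    return (short)(f * 32767.0f + 0.5f);   // round-to-nearest, f in [0, 1]
}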
overlay/htm_rust/src/gpu/kernels/tm_predict.cu CHANGED
@@ -1,102 +1,102 @@
// TM predict kernel — cell-grouped launch.
//
// Grid: n_cells blocks (one per cell).
// Block: 32 threads (one warp).
//
// Each block iterates the segments owned by its cell (count in cell_seg_count[cell]).
// For each live segment, counts active connected/potential synapses against
// prev_active_bits. Updates per-segment counters, cell_predictive bit, and
// col_predicted flag.

struct TmConfig {
    unsigned int activation_threshold;
    unsigned int learning_threshold;
    unsigned int cells_per_column;
    unsigned int synapses_per_segment;
    unsigned int n_segments;
    unsigned int n_cells;
    unsigned int max_segments_per_cell;
    unsigned int max_new_synapses;
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;
    unsigned int iter_seed;
    unsigned int n_cols;
    unsigned int bits_words;
};

extern "C" __global__
void tm_predict(
    const unsigned int * __restrict__ seg_cell_id,
    const unsigned int * __restrict__ seg_syn_count,
    const unsigned int * __restrict__ syn_presyn,
    const short        * __restrict__ syn_perm,
    const unsigned int * __restrict__ cell_active_bits,
    unsigned int * __restrict__ cell_predictive_bits,
    unsigned char * __restrict__ col_predicted,
    unsigned int * __restrict__ seg_num_active_connected,
    unsigned int * __restrict__ seg_num_active_potential,
    unsigned int * __restrict__ col_best_match,
    const unsigned int * __restrict__ cell_seg_count,
    TmConfig cfg
) {
    const unsigned int cell = blockIdx.x;
    if (cell >= cfg.n_cells) return;

    const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
    if (n_segs_here == 0) return;

    const unsigned int tid = threadIdx.x;
    const unsigned int col = cell / cfg.cells_per_column;
    const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;

    for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
        const unsigned int seg = seg_base_id + local_seg;
        const unsigned int n_syn = seg_syn_count[seg];
        if (n_syn == 0) {
            if (tid == 0) {
                seg_num_active_connected[seg] = 0;
                seg_num_active_potential[seg] = 0;
            }
            continue;
        }
        const unsigned int syn_base = seg * cfg.synapses_per_segment;

        unsigned int local_conn = 0;
        unsigned int local_pot = 0;
        for (unsigned int s = tid; s < n_syn; s += 32u) {
            unsigned int presyn = syn_presyn[syn_base + s];
            unsigned int word = cell_active_bits[presyn >> 5];
            unsigned int bit = (word >> (presyn & 31u)) & 1u;
            if (bit) {
                local_pot += 1u;
                int p = (int)syn_perm[syn_base + s];
                if (p >= cfg.conn_thr_i16) {
                    local_conn += 1u;
                }
            }
        }
        for (int off = 16; off > 0; off >>= 1) {
            local_conn += __shfl_down_sync(0xffffffffu, local_conn, off);
            local_pot  += __shfl_down_sync(0xffffffffu, local_pot, off);
        }

        if (tid == 0) {
            seg_num_active_connected[seg] = local_conn;
            seg_num_active_potential[seg] = local_pot;
            if (local_conn >= cfg.activation_threshold) {
                unsigned int word_idx = cell >> 5;
                unsigned int bit_mask = 1u << (cell & 31u);
                atomicOr(&cell_predictive_bits[word_idx], bit_mask);
                col_predicted[col] = 1;
            }
            if (local_pot >= cfg.learning_threshold) {
                unsigned int pot_c = local_pot > 2047u ? 2047u : local_pot;
                unsigned int key = (pot_c << 21) | (seg & 0x1FFFFFu);
                atomicMax(&col_best_match[col], key);
            }
        }
    }
}
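// Decoding the packed col_best_match key (the inverse of the encoding above):
// the top 11 bits carry the clamped potential count, the low 21 bits the
// segment id, so atomicMax keeps the segment with the highest potential
// overlap, ties broken toward the higher segment id.
static __host__ __device__ inline void decode_best_match(
    unsigned int key, unsigned int *pot, unsigned int *seg)
{
    *pot = key >> 21;        // min(num_active_potential, 2047)
    *seg = key & 0x1FFFFFu;  // segment index
}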
overlay/htm_rust/src/gpu/kernels/tm_punish.cu CHANGED
@@ -1,64 +1,64 @@
// TM punish — cell-grouped launch.

struct TmConfig {
    unsigned int activation_threshold;
    unsigned int learning_threshold;
    unsigned int cells_per_column;
    unsigned int synapses_per_segment;
    unsigned int n_segments;
    unsigned int n_cells;
    unsigned int max_segments_per_cell;
    unsigned int max_new_synapses;
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;
    unsigned int iter_seed;
    unsigned int n_cols;
    unsigned int bits_words;
};

extern "C" __global__
void tm_punish(
    const unsigned int * __restrict__ seg_cell_id,
    const unsigned int * __restrict__ seg_syn_count,
    const unsigned int * __restrict__ syn_presyn,
    short * __restrict__ syn_perm,
    const unsigned int * __restrict__ seg_num_active_potential,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned int * __restrict__ cell_seg_count,
    TmConfig cfg
) {
    const unsigned int cell = blockIdx.x;
    if (cell >= cfg.n_cells) return;
    const unsigned int col = cell / cfg.cells_per_column;
    if (sp_active_mask[col] != 0) return;  // skip: col became active

    const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
    if (n_segs_here == 0) return;

    const unsigned int tid = threadIdx.x;
    const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;

    for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
        const unsigned int seg = seg_base_id + local_seg;
        if (seg_num_active_potential[seg] < cfg.learning_threshold) continue;
        const unsigned int n_syn = seg_syn_count[seg];
        if (n_syn == 0) continue;
        const unsigned int syn_base = seg * cfg.synapses_per_segment;

        for (unsigned int s = tid; s < n_syn; s += 32u) {
            unsigned int presyn = syn_presyn[syn_base + s];
            unsigned int word = prev_active_bits[presyn >> 5];
            unsigned int bit = (word >> (presyn & 31u)) & 1u;
            if (bit) {
                int p = (int)syn_perm[syn_base + s];
                int np = p - cfg.predicted_seg_dec_i16;
                if (np < 0) np = 0;
                syn_perm[syn_base + s] = (short)np;
            }
        }
    }
}
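// Note: punishment only touches segments whose num_active_potential reached
// learning_threshold, i.e. segments that actually "voted" for this
// now-inactive column; only their synapses to previously active cells lose
// predicted_seg_dec of permanence.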
overlay/htm_rust/src/gpu/kernels/tm_reset.cu CHANGED
@@ -1,36 +1,36 @@
// TM reset-per-step kernel.

extern "C" __global__
void tm_reset_step(
    unsigned int * __restrict__ cell_active_bits,
    unsigned int * __restrict__ cell_winner_bits,
    unsigned int * __restrict__ cell_predictive_bits,
    unsigned int * __restrict__ prev_active_bits,
    unsigned int * __restrict__ prev_winner_bits,
    unsigned char * __restrict__ col_predicted,
    unsigned int * __restrict__ unpredicted_count,
    unsigned int * __restrict__ burst_cols_count,
    unsigned int * __restrict__ col_best_match,
    unsigned int bits_words,
    unsigned int n_cols
) {
    unsigned int tid_global = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid_global < bits_words) {
        prev_active_bits[tid_global] = cell_active_bits[tid_global];
        prev_winner_bits[tid_global] = cell_winner_bits[tid_global];
        cell_active_bits[tid_global] = 0u;
        cell_winner_bits[tid_global] = 0u;
        cell_predictive_bits[tid_global] = 0u;
    }

    if (tid_global < n_cols) {
        col_predicted[tid_global] = 0;
        col_best_match[tid_global] = 0u;
    }

    if (tid_global == 0) {
        unpredicted_count[0] = 0u;
        burst_cols_count[0] = 0u;
    }
}
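// Illustrative launch (a sketch; grid sized so one pass covers both the
// bit-word copy/clear and the per-column flag clear):
//
//     unsigned int n = bits_words > n_cols ? bits_words : n_cols;
//     tm_reset_step<<<(n + 255) / 256, 256>>>(/* device buffers */, bits_words, n_cols);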
overlay/htm_rust/src/gpu/mod.rs CHANGED
@@ -1,549 +1,549 @@
//! GPU backend for HTM.
//!
//! Full-GPU pipeline (SP + TM). Per-step state lives entirely on device; the
//! batch API (`step_many_gpu`) uploads T steps of input once, runs T iterations
//! of the full HTM pipeline on GPU, and copies (T, n_cols) u8 + (T,) f32 back
//! to the host in one shot.
//!
//! TM parity with the CPU reference is approximate:
//!   - Segment growth: winner = cell 0 of bursting column (CPU picks
//!     least-used-cell with RNG tiebreak). This is a pragmatic simplification
//!     for GPU atomicity; learning dynamics are preserved.
//!   - Permanences stored as i16 (scaled 0..32767). Rounding differs from
//!     f32 by <= 1 ULP of the scale factor (≈ 3e-5) — inside any meaningful
//!     HTM learning quantum.

#![cfg(feature = "gpu")]

pub mod sp_gpu;
pub mod tm_gpu;
pub mod fused;

#[cfg(test)]
mod tests;

use std::mem::ManuallyDrop;

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};

use crate::region::HTMRegionCore;
use crate::sp::SpatialPoolerConfig;
use sp_gpu::SpatialPoolerGpu;
use tm_gpu::TemporalMemoryGpu;
use fused::FusedState;

/// Extract (device_ptr, shape, typestr) from a `__cuda_array_interface__` dict.
/// Returns Err if the dict is malformed. Used by `step_many_cuda` to wrap
/// torch-owned CUDA allocations zero-copy.
fn cai_parse(cai: &Bound<'_, PyDict>) -> PyResult<(u64, Vec<usize>, String)> {
    // `data` is a (ptr: int, readonly: bool) tuple.
    let data_obj = cai.get_item("data")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'data'"))?;
    let data_tup: Bound<'_, PyTuple> = data_obj.downcast_into()
        .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'data' must be a tuple"))?;
    let ptr: u64 = data_tup.get_item(0)?.extract()?;

    // `shape` is a tuple of ints.
    let shape_obj = cai.get_item("shape")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'shape'"))?;
    let shape_tup: Bound<'_, PyTuple> = shape_obj.downcast_into()
        .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'shape' must be a tuple"))?;
    let shape: Vec<usize> = (0..shape_tup.len())
        .map(|i| shape_tup.get_item(i).and_then(|v| v.extract::<usize>()))
        .collect::<PyResult<Vec<_>>>()?;

    // `typestr` (e.g. "|u1", "<f4").
    let typestr_obj = cai.get_item("typestr")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'typestr'"))?;
    let typestr: String = typestr_obj.extract()?;

    // Reject non-contiguous tensors — we don't handle strides.
    if let Some(strides) = cai.get_item("strides")? {
        if !strides.is_none() {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "CAI 'strides' must be None (tensor must be contiguous)",
            ));
        }
    }

    Ok((ptr, shape, typestr))
}
73
-
74
- /// Python-exposed GPU HTM region. Drop-in replacement for `HTMRegion`.
75
- #[pyclass(module = "htm_rust")]
76
- pub struct HTMRegionGpu {
77
- pub(super) sp_gpu: SpatialPoolerGpu,
78
- pub(super) tm_gpu: TemporalMemoryGpu,
79
- pub(super) fused_state: FusedState,
80
- pub(super) n_columns: usize,
81
- pub(super) input_bits: usize,
82
- pub(super) cells_per_column: usize,
83
- }
84
-
85
- #[pymethods]
86
- impl HTMRegionGpu {
87
- #[new]
88
- #[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
89
- fn new(
90
- input_bits: usize,
91
- n_columns: usize,
92
- cells_per_column: usize,
93
- seed: u64,
94
- ) -> PyResult<Self> {
95
- if input_bits == 0 || n_columns == 0 || cells_per_column == 0 {
96
- return Err(pyo3::exceptions::PyValueError::new_err(
97
- "input_bits, n_columns, cells_per_column must all be > 0",
98
- ));
99
- }
100
- // CPU reference for deterministic SP init.
101
- let cpu_ref = HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed);
102
- let sp_cfg: &SpatialPoolerConfig = &cpu_ref.sp.cfg;
103
- let sp_gpu = SpatialPoolerGpu::from_cpu(&cpu_ref.sp).map_err(|e| {
104
- pyo3::exceptions::PyRuntimeError::new_err(format!(
105
- "GPU SP init failed: {e:?}. Config: input_bits={}, n_columns={}",
106
- sp_cfg.input_bits, sp_cfg.n_columns,
107
- ))
108
- })?;
109
- let dev = sp_gpu.dev_ref().clone();
110
- let tm_gpu = TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_column).map_err(|e| {
111
- pyo3::exceptions::PyRuntimeError::new_err(format!(
112
- "GPU TM init failed: {e:?}",
113
- ))
114
- })?;
115
- let initial_threshold = sp_gpu.initial_threshold_estimate();
116
- let fused_state = FusedState::new(dev, n_columns, cells_per_column, initial_threshold)
117
- .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!(
118
- "GPU fused state init failed: {e:?}",
119
- )))?;
120
- Ok(Self {
121
- sp_gpu,
122
- tm_gpu,
123
- fused_state,
124
- n_columns,
125
- input_bits,
126
- cells_per_column,
127
- })
128
- }
129
-
130
- #[getter] fn input_bits(&self) -> usize { self.input_bits }
131
- #[getter] fn n_columns(&self) -> usize { self.n_columns }
132
- #[getter] fn cells_per_column(&self) -> usize { self.cells_per_column }
133
-
134
- /// Process T timesteps in one call on GPU. Per-step state (SP + TM) stays
135
- /// on device; only the final (T, n_cols) mask and (T,) anomaly are copied
136
- /// to the host at the end.
137
- #[pyo3(signature = (inputs, learn=true))]
138
- fn step_many_gpu<'py>(
139
- &mut self,
140
- py: Python<'py>,
141
- inputs: PyReadonlyArray2<'py, bool>,
142
- learn: bool,
143
- ) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
144
- let shape = inputs.shape();
145
- if shape.len() != 2 {
146
- return Err(pyo3::exceptions::PyValueError::new_err(
147
- "inputs must be 2-D (T, input_bits)",
148
- ));
149
- }
150
- let t = shape[0];
151
- let bits = shape[1];
152
- if bits != self.input_bits {
153
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
154
- "inputs last dim {bits} != expected input_bits {}",
155
- self.input_bits,
156
- )));
157
- }
158
- let slice = inputs.as_slice()?;
159
- let n_cols = self.n_columns;
160
- let input_vec: Vec<bool> = slice.to_vec();
161
-
162
- let result = py.allow_threads(|| -> Result<(Vec<u8>, Vec<f32>), String> {
163
- // 1. Upload T*input_bits bytes (32 MB at T=2048, bits=16384).
164
- let sdr_u8_all: Vec<u8> = input_vec.iter().map(|&b| b as u8).collect();
165
- let inputs_dev = self
166
- .sp_gpu
167
- .dev_ref()
168
- .htod_sync_copy(&sdr_u8_all)
169
- .map_err(|e| format!("H2D inputs: {e:?}"))?;
170
-
171
- // 2. Allocate output buffers on device.
172
- let mut cols_dev = self.sp_gpu.dev_ref()
173
- .alloc_zeros::<u8>(t * n_cols)
174
- .map_err(|e| format!("alloc cols: {e:?}"))?;
175
- let mut anom_dev = self.sp_gpu.dev_ref()
176
- .alloc_zeros::<f32>(t)
177
- .map_err(|e| format!("alloc anom: {e:?}"))?;
178
-
179
- // 3. Run T steps of SP + TM on GPU with NO per-step host sync.
180
- self.sp_gpu.step_batch_with_tm(
181
- &inputs_dev,
182
- t,
183
- self.input_bits,
184
- learn,
185
- &mut cols_dev,
186
- &mut anom_dev,
187
- &mut self.tm_gpu,
188
- ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
189
-
190
- // 4. ONE D2H for the whole run (T * n_cols bytes + T floats).
191
- let cols_host: Vec<u8> = self.sp_gpu.dev_ref()
192
- .dtoh_sync_copy(&cols_dev)
193
- .map_err(|e| format!("D2H cols: {e:?}"))?;
194
- let anom_host: Vec<f32> = self.sp_gpu.dev_ref()
195
- .dtoh_sync_copy(&anom_dev)
196
- .map_err(|e| format!("D2H anom: {e:?}"))?;
197
-
198
- Ok((cols_host, anom_host))
199
- });
200
-
201
- let (cols_u8, anom) = result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
202
-
203
- let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
204
- let cols_arr = numpy::PyArray1::from_vec_bound(py, cols_f32)
205
- .reshape([t, n_cols])
206
- .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
207
- let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
208
- Ok((cols_arr, anom_arr))
209
- }
210
-
211
- /// Zero-copy CUDA path: accept torch tensors via __cuda_array_interface__,
212
- /// write outputs directly into caller-allocated torch tensors. Skips the
213
- /// host round-trip that `step_many_gpu` pays on every call (sdr.cpu() +
214
- /// two D2H copies at the end). This is the hot path for `train.py`.
215
- ///
216
- /// Contract:
217
- /// sdr_cai.shape == (T, input_bits), dtype u8 (0/1 mask)
218
- /// cols_cai.shape == (T, n_columns), dtype u8 (written)
219
- /// anom_cai.shape == (T,), dtype f32 (written)
220
- /// All three tensors must live on the SAME CUDA device as this region.
221
- ///
222
- /// The torch tensors still own their memory β€” this method only wraps
223
- /// them as borrowed CudaSlice views (via ManuallyDrop) so cudarc's Drop
224
- /// impl can't free pytorch's allocator.
225
- #[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
226
- fn step_many_cuda(
227
- &mut self,
228
- py: Python<'_>,
229
- sdr_cai: &Bound<'_, PyDict>,
230
- cols_cai: &Bound<'_, PyDict>,
231
- anom_cai: &Bound<'_, PyDict>,
232
- learn: bool,
233
- ) -> PyResult<()> {
234
- let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
235
- let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
236
- let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
237
-
238
- // typestr sanity. numpy u1 is what torch.uint8 exports.
239
- if sdr_type != "|u1" {
240
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
241
- "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
242
- )));
243
- }
244
- if cols_type != "|u1" {
245
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
246
- "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
247
- )));
248
- }
249
- if anom_type != "<f4" && anom_type != "=f4" {
250
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
251
- "anom_cai typestr must be '<f4' (float32), got {anom_type}",
252
- )));
253
- }
254
-
255
- // Shape validation.
256
- if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
257
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
258
- "sdr_cai shape {sdr_shape:?} != (T, {})",
259
- self.input_bits,
260
- )));
261
- }
262
- let t = sdr_shape[0];
263
- if cols_shape != [t, self.n_columns] {
264
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
265
- "cols_cai shape {cols_shape:?} != ({t}, {})",
266
- self.n_columns,
267
- )));
268
- }
269
- if anom_shape != [t] {
270
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
271
- "anom_cai shape {anom_shape:?} != ({t},)",
272
- )));
273
- }
274
-
275
- let dev = self.sp_gpu.dev_ref().clone();
276
- let n_cols = self.n_columns;
277
- let input_bits = self.input_bits;
278
-
279
- let result = py.allow_threads(|| -> Result<(), String> {
280
- // SAFETY:
281
- // - ptrs came from torch CUDA tensors validated non-null by the
282
- // __cuda_array_interface__ contract.
283
- // - lens computed from validated shapes.
284
- // - We wrap the returned CudaSlice in ManuallyDrop so cudarc's
285
- // Drop (which calls cuMemFree) never runs against torch memory.
286
- // The underlying allocation is owned+freed by torch.
287
- // - The slices are used only for the duration of this call;
288
- // torch guarantees the backing tensors are live across it
289
- // (Python holds refs on the wrapping tensors).
290
- let inputs_dev = ManuallyDrop::new(unsafe {
291
- dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
292
- });
293
- let mut cols_dev = ManuallyDrop::new(unsafe {
294
- dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
295
- });
296
- let mut anom_dev = ManuallyDrop::new(unsafe {
297
- dev.upgrade_device_ptr::<f32>(anom_ptr, t)
298
- });
299
-
300
- self.sp_gpu.step_batch_with_tm(
301
- &inputs_dev,
302
- t,
303
- input_bits,
304
- learn,
305
- &mut cols_dev,
306
- &mut anom_dev,
307
- &mut self.tm_gpu,
308
- ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
309
-
310
- // Synchronize: kernel writes must be visible to the next torch
311
- // op that reads cols/anom. Pytorch's default stream is stream 0,
312
- // and cudarc launches on its own stream β€” a full device sync
313
- // is the simplest correct barrier. (Could narrow to a stream
314
- // wait event in PR 2.)
315
- // No dev.synchronize() here: caller must explicitly sync via the
316
- // `device_sync()` method (or PyTorch auto-syncs when the output
317
- // tensor is next consumed). Removing the per-launch barrier lets
318
- // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
319
- Ok(())
320
- });
321
-
322
- result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
323
- Ok(())
324
- }
325
-
326
- /// Clear TM state on the GPU.
327
- fn reset(&mut self) -> PyResult<()> {
328
- self.tm_gpu.reset().map_err(|e| {
329
- pyo3::exceptions::PyRuntimeError::new_err(format!("GPU TM reset: {e:?}"))
330
- })?;
331
- self.fused_state.reset().map_err(|e| {
332
- pyo3::exceptions::PyRuntimeError::new_err(format!("GPU fused reset: {e:?}"))
333
- })
334
- }
335
-
336
- /// FUSED MEGAKERNEL PATH: single CUDA launch for the entire T-step
337
- /// forward (SP + TM all in one). Accepts torch CUDA tensors via
338
- /// `__cuda_array_interface__` (zero-copy). Writes active-column mask +
339
- /// anomaly directly into caller-allocated torch tensors.
340
- ///
341
- /// Semantics diverge from `step_many_cuda` in one important way: column
342
- /// activation uses per-column threshold inhibition instead of global
343
- /// top-K. The threshold is EMA-adapted per column toward the sparsity
344
- /// target. See `docs/GPU_HTM.md` Β§Fused Kernel.
345
- #[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
346
- fn step_many_fused_cuda(
347
- &mut self,
348
- py: Python<'_>,
349
- sdr_cai: &Bound<'_, PyDict>,
350
- cols_cai: &Bound<'_, PyDict>,
351
- anom_cai: &Bound<'_, PyDict>,
352
- learn: bool,
353
- ) -> PyResult<()> {
354
- let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
355
- let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
356
- let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
357
-
358
- if sdr_type != "|u1" {
359
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
360
- "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
361
- )));
362
- }
363
- if cols_type != "|u1" {
364
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
365
- "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
366
- )));
367
- }
368
- if anom_type != "<f4" && anom_type != "=f4" {
369
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
370
- "anom_cai typestr must be '<f4' (float32), got {anom_type}",
371
- )));
372
- }
373
-
374
- if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
375
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
376
- "sdr_cai shape {sdr_shape:?} != (T, {})",
377
- self.input_bits,
378
- )));
379
- }
380
- let t = sdr_shape[0];
381
- if cols_shape != [t, self.n_columns] {
382
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
383
- "cols_cai shape {cols_shape:?} != ({t}, {})",
384
- self.n_columns,
385
- )));
386
- }
387
- if anom_shape != [t] {
388
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
389
- "anom_cai shape {anom_shape:?} != ({t},)",
390
- )));
391
- }
392
-
393
- let dev = self.sp_gpu.dev_ref().clone();
394
- let n_cols = self.n_columns;
395
- let input_bits = self.input_bits;
396
-
397
- let result = py.allow_threads(|| -> Result<(), String> {
398
- let inputs_dev = ManuallyDrop::new(unsafe {
399
- dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
400
- });
401
- let mut cols_dev = ManuallyDrop::new(unsafe {
402
- dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
403
- });
404
- let mut anom_dev = ManuallyDrop::new(unsafe {
405
- dev.upgrade_device_ptr::<f32>(anom_ptr, t)
406
- });
407
-
408
- fused::launch_fused(
409
- &mut self.sp_gpu,
410
- &mut self.tm_gpu,
411
- &mut self.fused_state,
412
- &inputs_dev,
413
- &mut cols_dev,
414
- &mut anom_dev,
415
- t,
416
- input_bits,
417
- learn,
418
- ).map_err(|e| format!("launch_fused: {e:?}"))?;
419
-
420
- // No dev.synchronize() here: caller must explicitly sync via the
421
- // `device_sync()` method (or PyTorch auto-syncs when the output
422
- // tensor is next consumed). Removing the per-launch barrier lets
423
- // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
424
- Ok(())
425
- });
426
-
427
- result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
428
- Ok(())
429
- }
430
-
431
- /// Explicit device synchronization β€” the caller must invoke this after
432
- /// all batched `step_many_*_cuda` calls complete, before reading the
433
- /// output tensors from a different CUDA stream. Equivalent to the old
434
- /// per-call `dev.synchronize()` that was removed for overlap.
435
- fn device_sync(&self) -> PyResult<()> {
436
- let dev = self.sp_gpu.dev_ref();
437
- dev.synchronize()
438
- .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("sync: {e:?}")))?;
439
- Ok(())
440
- }
441
- }
442
-
443
- /// Batch B regions into ONE cooperative kernel launch. Breaks through the
444
- /// CUDA cooperative-kernel device-level serialization: a single cooperative
445
- /// launch with grid.y=B processes all regions concurrently β€” ~BΓ— speedup
446
- /// over B sequential launches.
447
- ///
448
- /// All regions must have the same config (input_bits, n_columns,
449
- /// cells_per_column). Each region keeps its independent GPU state.
450
- /// Does NOT sync; caller must invoke `device_sync()` on any region
451
- /// afterwards (or rely on a downstream torch op to auto-sync).
452
- #[pyfunction]
453
- #[pyo3(signature = (regions, sdr_cais, cols_cais, anom_cais, learn=true))]
454
- fn step_batch_fused_cuda(
455
- py: Python<'_>,
456
- regions: Vec<Py<HTMRegionGpu>>,
457
- sdr_cais: Vec<Bound<'_, PyDict>>,
458
- cols_cais: Vec<Bound<'_, PyDict>>,
459
- anom_cais: Vec<Bound<'_, PyDict>>,
460
- learn: bool,
461
- ) -> PyResult<()> {
462
- let b = regions.len();
463
- if b == 0 {
464
- return Err(pyo3::exceptions::PyValueError::new_err("regions is empty"));
465
- }
466
- if sdr_cais.len() != b || cols_cais.len() != b || anom_cais.len() != b {
467
- return Err(pyo3::exceptions::PyValueError::new_err(
468
- "sdr_cais / cols_cais / anom_cais length must match regions",
469
- ));
470
- }
471
-
472
- // Parse all CAI dicts; collect device pointers. Validate shapes/dtypes.
473
- let mut sdr_ptrs = Vec::with_capacity(b);
474
- let mut cols_ptrs = Vec::with_capacity(b);
475
- let mut anom_ptrs = Vec::with_capacity(b);
476
- let (input_bits, n_columns, t) = {
477
- let r0 = regions[0].bind(py).borrow();
478
- (r0.input_bits, r0.n_columns, {
479
- let (_p, sh, _ty) = cai_parse(&sdr_cais[0])?;
480
- if sh.len() != 2 {
481
- return Err(pyo3::exceptions::PyValueError::new_err(
482
- format!("sdr_cai must be 2-D (T, input_bits), got {sh:?}"),
483
- ));
484
- }
485
- sh[0]
486
- })
487
- };
488
-
489
- for i in 0..b {
490
- let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_cais[i])?;
491
- let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_cais[i])?;
492
- let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_cais[i])?;
493
- if sdr_type != "|u1" || cols_type != "|u1" {
494
- return Err(pyo3::exceptions::PyValueError::new_err(
495
- "sdr/cols typestr must be '|u1' (uint8)",
496
- ));
497
- }
498
- if anom_type != "<f4" && anom_type != "=f4" {
499
- return Err(pyo3::exceptions::PyValueError::new_err(
500
- "anom typestr must be '<f4' (float32)",
501
- ));
502
- }
503
- if sdr_shape != [t, input_bits] {
504
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
505
- "sdr[{i}] shape {sdr_shape:?} != ({t}, {input_bits})"
506
- )));
507
- }
508
- if cols_shape != [t, n_columns] {
509
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
510
- "cols[{i}] shape {cols_shape:?} != ({t}, {n_columns})"
511
- )));
512
- }
513
- if anom_shape != [t] {
514
- return Err(pyo3::exceptions::PyValueError::new_err(format!(
515
- "anom[{i}] shape {anom_shape:?} != ({t},)"
516
- )));
517
- }
518
- sdr_ptrs.push(sdr_ptr);
519
- cols_ptrs.push(cols_ptr);
520
- anom_ptrs.push(anom_ptr);
521
- }
522
-
523
- // Exclusively borrow each region. PyRefMut guarantees uniqueness.
524
- let mut region_refs: Vec<pyo3::PyRefMut<HTMRegionGpu>> =
525
- regions.iter().map(|p| p.bind(py).borrow_mut()).collect();
526
- // Collect raw mutable pointers β€” each PyRefMut exclusively borrows its
527
- // region for the lifetime of this call, so pointers stay valid and
528
- // unique. launch_fused_batched_raw only dereferences one region at a
529
- // time, not constructing an aliased slice.
530
- let raw_ptrs: Vec<*mut HTMRegionGpu> = region_refs
531
- .iter_mut()
532
- .map(|r| &mut **r as *mut HTMRegionGpu)
533
- .collect();
534
-
535
- // No allow_threads: raw pointers aren't Send. The launch is GPU-queued
536
- // and sync'd downstream; holding the GIL for the duration is cheap.
537
- fused::launch_fused_batched_raw(
538
- &raw_ptrs, &sdr_ptrs, &cols_ptrs, &anom_ptrs,
539
- t, input_bits, learn,
540
- )
541
- .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("launch_fused_batched: {e:?}")))?;
542
- Ok(())
543
- }
544
-
545
- pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
546
- m.add_class::<HTMRegionGpu>()?;
547
- m.add_function(pyo3::wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
548
- Ok(())
549
- }
 
+ //! GPU backend for HTM.
+ //!
+ //! Full-GPU pipeline (SP + TM). Per-step state lives entirely on device; the
+ //! batch API (`step_many_gpu`) uploads T steps of input once, runs T iterations
+ //! of the full HTM pipeline on GPU, and copies (T, n_cols) u8 + (T,) f32 back
+ //! to the host in one shot.
+ //!
+ //! TM parity with the CPU reference is approximate:
+ //! - Segment growth: winner = cell 0 of bursting column (CPU picks
+ //! least-used-cell with RNG tiebreak). This is a pragmatic simplification
+ //! for GPU atomicity; learning dynamics are preserved.
+ //! - Permanences stored as i16 (scaled 0..32767). Rounding differs from
+ //! f32 by <= 1 ULP of the scale factor (≈ 3e-5) -- inside any meaningful
+ //! HTM learning quantum.
+
+ #![cfg(feature = "gpu")]
+
+ pub mod sp_gpu;
+ pub mod tm_gpu;
+ pub mod fused;
+
+ #[cfg(test)]
+ mod tests;
+
+ use std::mem::ManuallyDrop;
+
+ use pyo3::prelude::*;
+ use pyo3::types::{PyDict, PyTuple};
+ use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};
+
+ use crate::region::HTMRegionCore;
+ use crate::sp::SpatialPoolerConfig;
+ use sp_gpu::SpatialPoolerGpu;
+ use tm_gpu::TemporalMemoryGpu;
+ use fused::FusedState;
+
+ /// Extract (device_ptr, shape, typestr) from a `__cuda_array_interface__` dict.
+ /// Returns Err if the dict is malformed. Used by `step_many_cuda` to wrap
+ /// torch-owned CUDA allocations zero-copy.
+ fn cai_parse(cai: &Bound<'_, PyDict>) -> PyResult<(u64, Vec<usize>, String)> {
+ // `data` is a (ptr: int, readonly: bool) tuple.
+ let data_obj = cai.get_item("data")?
+ .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'data'"))?;
+ let data_tup: Bound<'_, PyTuple> = data_obj.downcast_into()
+ .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'data' must be a tuple"))?;
+ let ptr: u64 = data_tup.get_item(0)?.extract()?;
+
+ // `shape` is a tuple of ints.
+ let shape_obj = cai.get_item("shape")?
+ .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'shape'"))?;
+ let shape_tup: Bound<'_, PyTuple> = shape_obj.downcast_into()
+ .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'shape' must be a tuple"))?;
+ let shape: Vec<usize> = (0..shape_tup.len())
+ .map(|i| shape_tup.get_item(i).and_then(|v| v.extract::<usize>()))
+ .collect::<PyResult<Vec<_>>>()?;
+
+ // `typestr` (e.g. "|u1", "<f4").
+ let typestr_obj = cai.get_item("typestr")?
+ .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'typestr'"))?;
+ let typestr: String = typestr_obj.extract()?;
+
+ // Reject non-contiguous tensors -- we don't handle strides.
+ if let Some(strides) = cai.get_item("strides")? {
+ if !strides.is_none() {
+ return Err(pyo3::exceptions::PyValueError::new_err(
+ "CAI 'strides' must be None (tensor must be contiguous)",
+ ));
+ }
+ }
+
+ Ok((ptr, shape, typestr))
+ }
+
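For reference, the dict this parser consumes is what PyTorch exports on CUDA tensors. A minimal sketch (values illustrative; for a contiguous tensor torch reports `strides: None`, which is exactly what `cai_parse` requires):

```python
import torch

sdr = torch.zeros(2048, 16384, dtype=torch.uint8, device="cuda")
cai = sdr.__cuda_array_interface__
# Roughly: {'data': (<device ptr>, False), 'shape': (2048, 16384),
#           'typestr': '|u1', 'strides': None, 'version': 2}
assert cai["typestr"] == "|u1" and cai["strides"] is None
```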
+ /// Python-exposed GPU HTM region. Drop-in replacement for `HTMRegion`.
+ #[pyclass(module = "htm_rust")]
+ pub struct HTMRegionGpu {
+ pub(super) sp_gpu: SpatialPoolerGpu,
+ pub(super) tm_gpu: TemporalMemoryGpu,
+ pub(super) fused_state: FusedState,
+ pub(super) n_columns: usize,
+ pub(super) input_bits: usize,
+ pub(super) cells_per_column: usize,
+ }
+
+ #[pymethods]
+ impl HTMRegionGpu {
+ #[new]
+ #[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
+ fn new(
+ input_bits: usize,
+ n_columns: usize,
+ cells_per_column: usize,
+ seed: u64,
+ ) -> PyResult<Self> {
+ if input_bits == 0 || n_columns == 0 || cells_per_column == 0 {
+ return Err(pyo3::exceptions::PyValueError::new_err(
+ "input_bits, n_columns, cells_per_column must all be > 0",
+ ));
+ }
+ // CPU reference for deterministic SP init.
+ let cpu_ref = HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed);
+ let sp_cfg: &SpatialPoolerConfig = &cpu_ref.sp.cfg;
+ let sp_gpu = SpatialPoolerGpu::from_cpu(&cpu_ref.sp).map_err(|e| {
+ pyo3::exceptions::PyRuntimeError::new_err(format!(
+ "GPU SP init failed: {e:?}. Config: input_bits={}, n_columns={}",
+ sp_cfg.input_bits, sp_cfg.n_columns,
+ ))
+ })?;
+ let dev = sp_gpu.dev_ref().clone();
+ let tm_gpu = TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_column).map_err(|e| {
+ pyo3::exceptions::PyRuntimeError::new_err(format!(
+ "GPU TM init failed: {e:?}",
+ ))
+ })?;
+ let initial_threshold = sp_gpu.initial_threshold_estimate();
+ let fused_state = FusedState::new(dev, n_columns, cells_per_column, initial_threshold)
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!(
+ "GPU fused state init failed: {e:?}",
+ )))?;
+ Ok(Self {
+ sp_gpu,
+ tm_gpu,
+ fused_state,
+ n_columns,
+ input_bits,
+ cells_per_column,
+ })
+ }
+
+ #[getter] fn input_bits(&self) -> usize { self.input_bits }
+ #[getter] fn n_columns(&self) -> usize { self.n_columns }
+ #[getter] fn cells_per_column(&self) -> usize { self.cells_per_column }
+
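A hedged construction sketch (module name taken from the `#[pyclass(module = "htm_rust")]` attribute; sizes mirror the HYDRA training config, while `cells_per_column=8` is illustrative):

```python
import htm_rust  # assumes the extension is built with the `gpu` feature enabled

region = htm_rust.HTMRegionGpu(
    input_bits=16384, n_columns=2048, cells_per_column=8, seed=42
)
assert (region.input_bits, region.n_columns) == (16384, 2048)
```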
+ /// Process T timesteps in one call on GPU. Per-step state (SP + TM) stays
+ /// on device; only the final (T, n_cols) mask and (T,) anomaly are copied
+ /// to the host at the end.
+ #[pyo3(signature = (inputs, learn=true))]
+ fn step_many_gpu<'py>(
+ &mut self,
+ py: Python<'py>,
+ inputs: PyReadonlyArray2<'py, bool>,
+ learn: bool,
+ ) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
+ let shape = inputs.shape();
+ if shape.len() != 2 {
+ return Err(pyo3::exceptions::PyValueError::new_err(
+ "inputs must be 2-D (T, input_bits)",
+ ));
+ }
+ let t = shape[0];
+ let bits = shape[1];
+ if bits != self.input_bits {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "inputs last dim {bits} != expected input_bits {}",
+ self.input_bits,
+ )));
+ }
+ let slice = inputs.as_slice()?;
+ let n_cols = self.n_columns;
+ let input_vec: Vec<bool> = slice.to_vec();
+
+ let result = py.allow_threads(|| -> Result<(Vec<u8>, Vec<f32>), String> {
+ // 1. Upload T*input_bits bytes (32 MB at T=2048, bits=16384).
+ let sdr_u8_all: Vec<u8> = input_vec.iter().map(|&b| b as u8).collect();
+ let inputs_dev = self
+ .sp_gpu
+ .dev_ref()
+ .htod_sync_copy(&sdr_u8_all)
+ .map_err(|e| format!("H2D inputs: {e:?}"))?;
+
+ // 2. Allocate output buffers on device.
+ let mut cols_dev = self.sp_gpu.dev_ref()
+ .alloc_zeros::<u8>(t * n_cols)
+ .map_err(|e| format!("alloc cols: {e:?}"))?;
+ let mut anom_dev = self.sp_gpu.dev_ref()
+ .alloc_zeros::<f32>(t)
+ .map_err(|e| format!("alloc anom: {e:?}"))?;
+
+ // 3. Run T steps of SP + TM on GPU with NO per-step host sync.
+ self.sp_gpu.step_batch_with_tm(
+ &inputs_dev,
+ t,
+ self.input_bits,
+ learn,
+ &mut cols_dev,
+ &mut anom_dev,
+ &mut self.tm_gpu,
+ ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
+
+ // 4. ONE D2H for the whole run (T * n_cols bytes + T floats).
+ let cols_host: Vec<u8> = self.sp_gpu.dev_ref()
+ .dtoh_sync_copy(&cols_dev)
+ .map_err(|e| format!("D2H cols: {e:?}"))?;
+ let anom_host: Vec<f32> = self.sp_gpu.dev_ref()
+ .dtoh_sync_copy(&anom_dev)
+ .map_err(|e| format!("D2H anom: {e:?}"))?;
+
+ Ok((cols_host, anom_host))
+ });
+
+ let (cols_u8, anom) = result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
+
+ let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
+ let cols_arr = numpy::PyArray1::from_vec_bound(py, cols_f32)
+ .reshape([t, n_cols])
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
+ let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
+ Ok((cols_arr, anom_arr))
+ }
+
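A usage sketch for this host-copy path (`region` as constructed above; the ~2% SDR density is illustrative):

```python
import numpy as np

T = 128
inputs = np.random.rand(T, region.input_bits) < 0.02  # (T, bits) bool SDRs
cols, anomaly = region.step_many_gpu(inputs, learn=True)
assert cols.shape == (T, region.n_columns)
assert anomaly.shape == (T,)
```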
+ /// Zero-copy CUDA path: accept torch tensors via __cuda_array_interface__,
+ /// write outputs directly into caller-allocated torch tensors. Skips the
+ /// host round-trip that `step_many_gpu` pays on every call (sdr.cpu() +
+ /// two D2H copies at the end). This is the hot path for `train.py`.
+ ///
+ /// Contract:
+ /// sdr_cai.shape == (T, input_bits), dtype u8 (0/1 mask)
+ /// cols_cai.shape == (T, n_columns), dtype u8 (written)
+ /// anom_cai.shape == (T,), dtype f32 (written)
+ /// All three tensors must live on the SAME CUDA device as this region.
+ ///
+ /// The torch tensors still own their memory -- this method only wraps
+ /// them as borrowed CudaSlice views (via ManuallyDrop) so cudarc's Drop
+ /// impl never frees memory that PyTorch's allocator owns.
+ #[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
+ fn step_many_cuda(
+ &mut self,
+ py: Python<'_>,
+ sdr_cai: &Bound<'_, PyDict>,
+ cols_cai: &Bound<'_, PyDict>,
+ anom_cai: &Bound<'_, PyDict>,
+ learn: bool,
+ ) -> PyResult<()> {
+ let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
+ let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
+ let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
+
+ // typestr sanity. numpy u1 is what torch.uint8 exports.
+ if sdr_type != "|u1" {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
+ )));
+ }
+ if cols_type != "|u1" {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
+ )));
+ }
+ if anom_type != "<f4" && anom_type != "=f4" {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "anom_cai typestr must be '<f4' (float32), got {anom_type}",
+ )));
+ }
+
+ // Shape validation.
+ if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "sdr_cai shape {sdr_shape:?} != (T, {})",
+ self.input_bits,
+ )));
+ }
+ let t = sdr_shape[0];
+ if cols_shape != [t, self.n_columns] {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "cols_cai shape {cols_shape:?} != ({t}, {})",
+ self.n_columns,
+ )));
+ }
+ if anom_shape != [t] {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "anom_cai shape {anom_shape:?} != ({t},)",
+ )));
+ }
+
+ let dev = self.sp_gpu.dev_ref().clone();
+ let n_cols = self.n_columns;
+ let input_bits = self.input_bits;
+
+ let result = py.allow_threads(|| -> Result<(), String> {
+ // SAFETY:
+ // - ptrs came from torch CUDA tensors validated non-null by the
+ // __cuda_array_interface__ contract.
+ // - lens computed from validated shapes.
+ // - We wrap the returned CudaSlice in ManuallyDrop so cudarc's
+ // Drop (which calls cuMemFree) never runs against torch memory.
+ // The underlying allocation is owned+freed by torch.
+ // - The slices are used only for the duration of this call;
+ // torch guarantees the backing tensors are live across it
+ // (Python holds refs on the wrapping tensors).
+ let inputs_dev = ManuallyDrop::new(unsafe {
+ dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
+ });
+ let mut cols_dev = ManuallyDrop::new(unsafe {
+ dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
+ });
+ let mut anom_dev = ManuallyDrop::new(unsafe {
+ dev.upgrade_device_ptr::<f32>(anom_ptr, t)
+ });
+
+ self.sp_gpu.step_batch_with_tm(
+ &inputs_dev,
+ t,
+ input_bits,
+ learn,
+ &mut cols_dev,
+ &mut anom_dev,
+ &mut self.tm_gpu,
+ ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;
+
+ // No dev.synchronize() here: caller must explicitly sync via the
+ // `device_sync()` method (or PyTorch auto-syncs when the output
+ // tensor is next consumed). Removing the per-launch barrier lets
+ // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
+ Ok(())
+ });
+
+ result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
+ Ok(())
+ }
+
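A hedged sketch of the zero-copy contract with torch-allocated buffers (all three tensors must sit on the region's device):

```python
import torch

T = 2048
sdr = (torch.rand(T, region.input_bits, device="cuda") < 0.02).to(torch.uint8)
cols = torch.empty(T, region.n_columns, dtype=torch.uint8, device="cuda")
anom = torch.empty(T, dtype=torch.float32, device="cuda")
region.step_many_cuda(
    sdr.__cuda_array_interface__,
    cols.__cuda_array_interface__,
    anom.__cuda_array_interface__,
    learn=True,
)
region.device_sync()  # barrier before cols/anom are read on another stream
```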
+ /// Clear TM and fused-kernel state on the GPU.
+ fn reset(&mut self) -> PyResult<()> {
+ self.tm_gpu.reset().map_err(|e| {
+ pyo3::exceptions::PyRuntimeError::new_err(format!("GPU TM reset: {e:?}"))
+ })?;
+ self.fused_state.reset().map_err(|e| {
+ pyo3::exceptions::PyRuntimeError::new_err(format!("GPU fused reset: {e:?}"))
+ })
+ }
+
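Hedged usage: clear temporal state between unrelated sequences (`corpus` is a hypothetical iterable of `(T, input_bits)` bool arrays):

```python
for doc_sdrs in corpus:
    cols, anom = region.step_many_gpu(doc_sdrs, learn=True)
    region.reset()  # don't carry TM context across document boundaries
```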
+ /// FUSED MEGAKERNEL PATH: single CUDA launch for the entire T-step
+ /// forward (SP + TM all in one). Accepts torch CUDA tensors via
+ /// `__cuda_array_interface__` (zero-copy). Writes active-column mask +
+ /// anomaly directly into caller-allocated torch tensors.
+ ///
+ /// Semantics diverge from `step_many_cuda` in one important way: column
+ /// activation uses per-column threshold inhibition instead of global
+ /// top-K. The threshold is EMA-adapted per column toward the sparsity
+ /// target. See `docs/GPU_HTM.md` §Fused Kernel.
+ #[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
+ fn step_many_fused_cuda(
+ &mut self,
+ py: Python<'_>,
+ sdr_cai: &Bound<'_, PyDict>,
+ cols_cai: &Bound<'_, PyDict>,
+ anom_cai: &Bound<'_, PyDict>,
+ learn: bool,
+ ) -> PyResult<()> {
+ let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
+ let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
+ let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;
+
+ if sdr_type != "|u1" {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
+ )));
+ }
+ if cols_type != "|u1" {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
+ )));
+ }
+ if anom_type != "<f4" && anom_type != "=f4" {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "anom_cai typestr must be '<f4' (float32), got {anom_type}",
+ )));
+ }
+
+ if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "sdr_cai shape {sdr_shape:?} != (T, {})",
+ self.input_bits,
+ )));
+ }
+ let t = sdr_shape[0];
+ if cols_shape != [t, self.n_columns] {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "cols_cai shape {cols_shape:?} != ({t}, {})",
+ self.n_columns,
+ )));
+ }
+ if anom_shape != [t] {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "anom_cai shape {anom_shape:?} != ({t},)",
+ )));
+ }
+
+ let dev = self.sp_gpu.dev_ref().clone();
+ let n_cols = self.n_columns;
+ let input_bits = self.input_bits;
+
+ let result = py.allow_threads(|| -> Result<(), String> {
+ let inputs_dev = ManuallyDrop::new(unsafe {
+ dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
+ });
+ let mut cols_dev = ManuallyDrop::new(unsafe {
+ dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
+ });
+ let mut anom_dev = ManuallyDrop::new(unsafe {
+ dev.upgrade_device_ptr::<f32>(anom_ptr, t)
+ });
+
+ fused::launch_fused(
+ &mut self.sp_gpu,
+ &mut self.tm_gpu,
+ &mut self.fused_state,
+ &inputs_dev,
+ &mut cols_dev,
+ &mut anom_dev,
+ t,
+ input_bits,
+ learn,
+ ).map_err(|e| format!("launch_fused: {e:?}"))?;
+
+ // No dev.synchronize() here: caller must explicitly sync via the
+ // `device_sync()` method (or PyTorch auto-syncs when the output
+ // tensor is next consumed). Removing the per-launch barrier lets
+ // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
+ Ok(())
+ });
+
+ result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
+ Ok(())
+ }
+
+ /// Explicit device synchronization -- the caller must invoke this after
+ /// all batched `step_many_*_cuda` calls complete, before reading the
+ /// output tensors from a different CUDA stream. Equivalent to the old
+ /// per-call `dev.synchronize()` that was removed for overlap.
+ fn device_sync(&self) -> PyResult<()> {
+ let dev = self.sp_gpu.dev_ref();
+ dev.synchronize()
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("sync: {e:?}")))?;
+ Ok(())
+ }
+ }
+
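The intended call pattern, sketched (names like `other_gpu_work` are hypothetical stand-ins for whatever GPU work overlaps the HTM launch):

```python
region.step_many_fused_cuda(
    sdr.__cuda_array_interface__,
    cols.__cuda_array_interface__,
    anom.__cuda_array_interface__,
    learn=True,
)
other_gpu_work()      # hypothetical: e.g. the mamba3 forward, overlapped in time
region.device_sync()  # one explicit barrier before cols/anom are consumed
```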
+ /// Batch B regions into ONE cooperative kernel launch. Breaks through the
+ /// CUDA cooperative-kernel device-level serialization: a single cooperative
+ /// launch with grid.y=B processes all regions concurrently -- ~B× speedup
+ /// over B sequential launches.
+ ///
+ /// All regions must have the same config (input_bits, n_columns,
+ /// cells_per_column). Each region keeps its independent GPU state.
+ /// Does NOT sync; caller must invoke `device_sync()` on any region
+ /// afterwards (or rely on a downstream torch op to auto-sync).
+ #[pyfunction]
+ #[pyo3(signature = (regions, sdr_cais, cols_cais, anom_cais, learn=true))]
+ fn step_batch_fused_cuda(
+ py: Python<'_>,
+ regions: Vec<Py<HTMRegionGpu>>,
+ sdr_cais: Vec<Bound<'_, PyDict>>,
+ cols_cais: Vec<Bound<'_, PyDict>>,
+ anom_cais: Vec<Bound<'_, PyDict>>,
+ learn: bool,
+ ) -> PyResult<()> {
+ let b = regions.len();
+ if b == 0 {
+ return Err(pyo3::exceptions::PyValueError::new_err("regions is empty"));
+ }
+ if sdr_cais.len() != b || cols_cais.len() != b || anom_cais.len() != b {
+ return Err(pyo3::exceptions::PyValueError::new_err(
+ "sdr_cais / cols_cais / anom_cais length must match regions",
+ ));
+ }
+
+ // Parse all CAI dicts; collect device pointers. Validate shapes/dtypes.
+ let mut sdr_ptrs = Vec::with_capacity(b);
+ let mut cols_ptrs = Vec::with_capacity(b);
+ let mut anom_ptrs = Vec::with_capacity(b);
+ let (input_bits, n_columns, t) = {
+ let r0 = regions[0].bind(py).borrow();
+ (r0.input_bits, r0.n_columns, {
+ let (_p, sh, _ty) = cai_parse(&sdr_cais[0])?;
+ if sh.len() != 2 {
+ return Err(pyo3::exceptions::PyValueError::new_err(
+ format!("sdr_cai must be 2-D (T, input_bits), got {sh:?}"),
+ ));
+ }
+ sh[0]
+ })
+ };
+
+ for i in 0..b {
+ let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_cais[i])?;
+ let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_cais[i])?;
+ let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_cais[i])?;
+ if sdr_type != "|u1" || cols_type != "|u1" {
+ return Err(pyo3::exceptions::PyValueError::new_err(
+ "sdr/cols typestr must be '|u1' (uint8)",
+ ));
+ }
+ if anom_type != "<f4" && anom_type != "=f4" {
+ return Err(pyo3::exceptions::PyValueError::new_err(
+ "anom typestr must be '<f4' (float32)",
+ ));
+ }
+ if sdr_shape != [t, input_bits] {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "sdr[{i}] shape {sdr_shape:?} != ({t}, {input_bits})"
+ )));
+ }
+ if cols_shape != [t, n_columns] {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "cols[{i}] shape {cols_shape:?} != ({t}, {n_columns})"
+ )));
+ }
+ if anom_shape != [t] {
+ return Err(pyo3::exceptions::PyValueError::new_err(format!(
+ "anom[{i}] shape {anom_shape:?} != ({t},)"
+ )));
+ }
+ sdr_ptrs.push(sdr_ptr);
+ cols_ptrs.push(cols_ptr);
+ anom_ptrs.push(anom_ptr);
+ }
+
+ // Exclusively borrow each region. PyRefMut guarantees uniqueness.
+ let mut region_refs: Vec<pyo3::PyRefMut<HTMRegionGpu>> =
+ regions.iter().map(|p| p.bind(py).borrow_mut()).collect();
+ // Collect raw mutable pointers -- each PyRefMut exclusively borrows its
+ // region for the lifetime of this call, so pointers stay valid and
+ // unique. launch_fused_batched_raw only dereferences one region at a
+ // time and never constructs an aliased slice.
+ let raw_ptrs: Vec<*mut HTMRegionGpu> = region_refs
+ .iter_mut()
+ .map(|r| &mut **r as *mut HTMRegionGpu)
+ .collect();
+
+ // No allow_threads: raw pointers aren't Send. The launch is GPU-queued
+ // and sync'd downstream; holding the GIL for the duration is cheap.
+ fused::launch_fused_batched_raw(
+ &raw_ptrs, &sdr_ptrs, &cols_ptrs, &anom_ptrs,
+ t, input_bits, learn,
+ )
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("launch_fused_batched: {e:?}")))?;
+ Ok(())
+ }
+
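A hedged sketch of the batched entry point (B=8 identical regions; sizes and cell count illustrative):

```python
import htm_rust
import torch

T, BITS, COLS = 2048, 16384, 2048
regions = [htm_rust.HTMRegionGpu(BITS, COLS, 8, seed=s) for s in range(8)]
sdrs  = [torch.zeros(T, BITS, dtype=torch.uint8, device="cuda") for _ in regions]
colss = [torch.empty(T, COLS, dtype=torch.uint8, device="cuda") for _ in regions]
anoms = [torch.empty(T, dtype=torch.float32, device="cuda") for _ in regions]
htm_rust.step_batch_fused_cuda(
    regions,
    [x.__cuda_array_interface__ for x in sdrs],
    [x.__cuda_array_interface__ for x in colss],
    [x.__cuda_array_interface__ for x in anoms],
    learn=True,
)
regions[0].device_sync()  # any region's sync covers the shared device
```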
+ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
+ m.add_class::<HTMRegionGpu>()?;
+ m.add_function(pyo3::wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
+ Ok(())
+ }
overlay/htm_rust/src/gpu/sp_gpu.rs CHANGED
@@ -1,796 +1,796 @@
- //! GPU implementation of the Spatial Pooler.
- //!
- //! One `SpatialPoolerGpu` owns a set of persistent device buffers + five PTX
- //! kernels. `compute(input, learn)` performs one SP step and returns the
- //! sorted active-column indices (host `Vec<u32>`) -- this is what the CPU
- //! TemporalMemory consumes.
- //!
- //! Persistent state on device (per region):
- //! syn_bit : u32 [n_columns × S] (constant after init)
- //! syn_perm : f32 [n_columns × S] (updated by sp_learn)
- //! boost : f32 [n_columns]
- //! active_duty : f32 [n_columns]
- //! overlap_duty: f32 [n_columns]
- //!
- //! Per-step transient state:
- //! inp_dev : u8 [input_bits] (H2D copy each step)
- //! raw : u32 [n_columns]
- //! boosted : f32 [n_columns]
- //! active_mask : u8 [n_columns] (topk output, D2H at the end)
-
- use std::sync::Arc;
-
- use cudarc::driver::{CudaDevice, CudaSlice, DeviceSlice, DriverError, LaunchAsync, LaunchConfig};
- use cudarc::nvrtc::Ptx;
-
- use crate::sp::SpatialPooler;
-
- // Embed PTX at compile time. HTM_GPU_PTX_DIR is set by build.rs.
- const PTX_SP_OVERLAP: &str =
- include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_overlap.ptx"));
- const PTX_SP_TOPK: &str =
- include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_topk.ptx"));
- const PTX_SP_LEARN: &str =
- include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_learn.ptx"));
- const PTX_SP_DUTY: &str =
- include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_duty.ptx"));
- const PTX_SP_BOOST_FUSED: &str =
- include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_boost_fused.ptx"));
-
- pub struct SpatialPoolerGpu {
- dev: Arc<CudaDevice>,
-
- // Config mirror (we don't touch CPU SpatialPooler after init).
- input_bits: usize,
- n_columns: usize,
- synapses_per_col: usize,
- conn_thr: f32,
- inc: f32,
- dec: f32,
- sparsity: f32,
- duty_period: f32,
- boost_strength: f32,
-
- // Persistent device state.
- syn_bit: CudaSlice<u32>,
- syn_perm: CudaSlice<f32>,
- boost: CudaSlice<f32>,
- active_duty: CudaSlice<f32>,
- overlap_duty: CudaSlice<f32>,
-
- // Transient scratch (reused each step).
- inp_dev: CudaSlice<u8>,
- raw: CudaSlice<u32>,
- boosted: CudaSlice<f32>,
- active_mask: CudaSlice<u8>,
-
- // Reusable host buffer for D2H of active_mask.
- host_mask: Vec<u8>,
-
- /// Strict bit-parity with CPU reference. Enabled for tests.
- /// Forces host-side boost/exp computation and the overlap-duty bump check
- /// every step. Default false for max throughput.
- strict_parity: bool,
- }
-
- impl SpatialPoolerGpu {
- /// Copy CPU SpatialPooler state onto the device. This preserves the
- /// exact seeded proximal synapse layout + initial permanences, so the
- /// GPU SP is a bit-identical parallel implementation of the CPU SP.
- pub fn from_cpu(cpu: &SpatialPooler) -> Result<Self, DriverError> {
- let dev = CudaDevice::new(0)?;
- let cfg = &cpu.cfg;
- let n = cfg.n_columns;
- let s = cfg.potential_synapses;
-
- // Flatten proximal dendrites into column-major arrays.
- let mut syn_bit_h: Vec<u32> = Vec::with_capacity(n * s);
- let mut syn_perm_h: Vec<f32> = Vec::with_capacity(n * s);
- for col in &cpu.columns {
- debug_assert_eq!(col.inputs.len(), s);
- debug_assert_eq!(col.perms.len(), s);
- syn_bit_h.extend_from_slice(&col.inputs);
- syn_perm_h.extend_from_slice(&col.perms);
- }
-
- let syn_bit = dev.htod_sync_copy(&syn_bit_h)?;
- let syn_perm = dev.htod_sync_copy(&syn_perm_h)?;
- let boost = dev.htod_sync_copy(&cpu.boost)?;
- let active_duty = dev.htod_sync_copy(&cpu.active_duty_cycle)?;
- let overlap_duty = dev.htod_sync_copy(&cpu.overlap_duty_cycle)?;
-
- let inp_dev: CudaSlice<u8> = dev.alloc_zeros(cfg.input_bits)?;
- let raw: CudaSlice<u32> = dev.alloc_zeros(n)?;
- let boosted: CudaSlice<f32> = dev.alloc_zeros(n)?;
- let active_mask: CudaSlice<u8> = dev.alloc_zeros(n)?;
-
- // Load PTX modules. Each .ptx is a module containing one `extern "C"`
- // function. CudaDevice::load_ptx stores modules under the given name
- // globally on the device, so we use a deterministic naming scheme to
- // keep multiple SP instances from colliding (cudarc keys on the
- // (module, func) pair).
- let modules = [
- ("htm_sp_overlap", PTX_SP_OVERLAP, "sp_overlap"),
- ("htm_sp_topk", PTX_SP_TOPK, "sp_topk_select"),
- ("htm_sp_learn", PTX_SP_LEARN, "sp_learn"),
- ("htm_sp_duty", PTX_SP_DUTY, "sp_duty_update"),
- ("htm_sp_boost_fused", PTX_SP_BOOST_FUSED, "sp_boost_from_duty"),
- ];
- for (modname, ptx, fnname) in modules {
- // load_ptx is NOT idempotent -- calling twice errors. For multi-region
- // support we check-then-load.
- if dev.get_func(modname, fnname).is_none() {
- dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
- }
- }
-
- Ok(Self {
- dev,
- input_bits: cfg.input_bits,
- n_columns: n,
- synapses_per_col: s,
- conn_thr: cfg.connected_threshold,
- inc: cfg.syn_perm_active_inc,
- dec: cfg.syn_perm_inactive_dec,
- sparsity: cfg.sparsity,
- duty_period: cfg.duty_cycle_period,
- boost_strength: cfg.boost_strength,
- syn_bit,
- syn_perm,
- boost,
- active_duty,
- overlap_duty,
- inp_dev,
- raw,
- boosted,
- active_mask,
- host_mask: vec![0u8; n],
- strict_parity: false,
- })
- }
-
- /// Enable strict bit-parity mode. Parity tests use this.
- pub fn set_strict_parity(&mut self, strict: bool) {
- self.strict_parity = strict;
- }
-
- /// Access to the underlying CudaDevice for host-side orchestration.
- pub fn dev_ref(&self) -> &Arc<CudaDevice> {
- &self.dev
- }
-
- // --- Fused-path accessors (immutable state reads + pointer-grabs). ---
- pub fn n_columns_accessor(&self) -> usize { self.n_columns }
- #[allow(dead_code)]
- pub fn input_bits_accessor(&self) -> usize { self.input_bits }
- pub fn synapses_per_col_accessor(&self) -> usize { self.synapses_per_col }
- pub fn conn_thr_accessor(&self) -> f32 { self.conn_thr }
- pub fn inc_accessor(&self) -> f32 { self.inc }
- pub fn dec_accessor(&self) -> f32 { self.dec }
- pub fn sparsity_accessor(&self) -> f32 { self.sparsity }
- pub fn duty_period_accessor(&self) -> f32 { self.duty_period }
- #[allow(dead_code)]
- pub fn boost_strength_accessor(&self) -> f32 { self.boost_strength }
-
- pub fn syn_bit_accessor(&self) -> &CudaSlice<u32> { &self.syn_bit }
- pub fn syn_perm_accessor(&self) -> &CudaSlice<f32> { &self.syn_perm }
- pub fn boost_accessor(&self) -> &CudaSlice<f32> { &self.boost }
- pub fn active_duty_accessor(&self) -> &CudaSlice<f32> { &self.active_duty }
-
- /// Estimate an initial inhibition threshold so that the activation rate
- /// starts near the sparsity target. Used to seed `inhibition_threshold`.
- /// Placeholder (returns a conservative constant); the real warmup pass
- /// over raw overlaps happens on the Rust orchestrator side.
- pub fn initial_threshold_estimate(&self) -> f32 {
- // With conn_thr=0.5, init_perm around 0.5±0.1, S=40, sparse SDR at 2%:
- // expected overlap ~ 40 * 0.02 = 0.8 connected hits -> boosted ~ 0.8.
- // Top-K selects the top 2%, so the threshold is roughly the
- // 98th percentile of boosted. Conservative start: 2.0.
- // The per-column adaptation will quickly steer each column's thr.
- 2.0f32
- }
-
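A quick restatement of the estimate in the comment above (illustrative numbers only):

```python
S, input_density = 40, 0.02            # potential synapses, SDR density
expected_overlap = S * input_density   # = 0.8 connected hits per column
seed_threshold = 2.0                   # conservative; adapted per column later
```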
- /// Batched multi-step SP on the GPU. Processes T timesteps from a
- /// pre-uploaded device input buffer. Emits a `(T, n_cols)` u8 active-column
- /// mask to `cols_out` and appends each step's active column indices to
- /// `active_indices_host`, terminating every step with a u32::MAX sentinel.
- ///
- /// For each step, this runs the same 5-kernel pipeline as `compute`, but
- /// skips the per-step boost D2H→exp→H2D round-trip: the default mode runs
- /// the boost update as a fused on-device kernel (strict-parity mode keeps
- /// the host round-trip for bit-exactness).
- ///
- /// `step_batch_with_tm` below is the TM-fused variant of this path that
- /// `HTMRegionGpu.step_many_gpu` uses.
- #[allow(clippy::too_many_arguments)]
- pub fn step_batch(
- &mut self,
- inputs_flat_dev: &CudaSlice<u8>,
- t: usize,
- input_bits: usize,
- learn: bool,
- cols_out: &mut [u8],
- active_indices_host: &mut Vec<u32>,
- ) -> Result<(), DriverError> {
- let n = self.n_columns;
- let k = ((self.sparsity * n as f32).round() as usize).max(1);
- debug_assert_eq!(cols_out.len(), t * n);
-
- let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
- let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
- let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
- let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
-
- let overlap_cfg = LaunchConfig {
- grid_dim: (n as u32, 1, 1),
- block_dim: (128, 1, 1),
- shared_mem_bytes: 0,
- };
- let topk_cfg = LaunchConfig {
- grid_dim: (1, 1, 1),
- block_dim: (256, 1, 1),
- shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
- };
- let learn_cfg = overlap_cfg;
- let duty_cfg = LaunchConfig {
- grid_dim: ((n as u32 + 255) / 256, 1, 1),
- block_dim: (256, 1, 1),
- shared_mem_bytes: 0,
- };
- let alpha = 1.0f32 / self.duty_period.max(1.0);
-
- // Reusable host buffer for the per-step active_mask D2H.
- self.host_mask.resize(n, 0);
-
- active_indices_host.clear();
-
- for ti in 0..t {
- // Point overlap kernel at the ti-th slice of the pre-uploaded input.
- // cudarc CudaSlice doesn't have a "view" per se, so we must copy the
- // slice into the reusable inp_dev buffer. This is a D2D copy -- much
- // faster than H2D.
- // (Alternative: rewrite kernel to accept an offset; deferred.)
- let in_off = ti * input_bits;
- // Use dtod_copy via raw slice indexing: cudarc exposes slice() for this.
- let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
- self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
-
- // 1. sp_overlap
- unsafe {
- overlap_fn.clone().launch(
- overlap_cfg,
- (
- &self.inp_dev,
- &self.syn_bit,
- &self.syn_perm,
- &self.boost,
- self.conn_thr,
- self.synapses_per_col as u32,
- n as u32,
- &mut self.raw,
- &mut self.boosted,
- ),
- )?;
- }
-
- // 2. Clear active_mask, then sp_topk
- self.dev.memset_zeros(&mut self.active_mask)?;
- unsafe {
- topk_fn.clone().launch(
- topk_cfg,
- (&self.boosted, n as u32, k as u32, &mut self.active_mask),
- )?;
- }
-
- // 3. sp_learn
- if learn {
- unsafe {
- learn_fn.clone().launch(
- learn_cfg,
- (
- &self.active_mask,
- &self.inp_dev,
- &self.syn_bit,
- &mut self.syn_perm,
- self.inc,
- self.dec,
- self.synapses_per_col as u32,
- n as u32,
- ),
- )?;
- }
- }
-
- // 4. duty update (device)
- unsafe {
- duty_fn.clone().launch(
- duty_cfg,
- (
- &self.active_mask,
- &self.raw,
- &mut self.active_duty,
- &mut self.overlap_duty,
- &mut self.boost,
- alpha,
- 1.0f32,
- 0.0f32,
- 0.0f32,
- 0u32,
- n as u32,
- ),
- )?;
- }
-
- // 5. Boost update. Two modes:
- // * strict_parity (tests): host-side exp for bit-exact match.
- // * default (production): GPU expf is close enough and ~10x faster
- // since we skip the D2H/H2D round-trip.
- if learn && self.boost_strength > 0.0 {
- if self.strict_parity {
- let mut duty_host = vec![0f32; n];
- self.dev
- .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
- let sum: f32 = duty_host.iter().sum();
- let mean = sum / (n as f32);
- let mut boost_host = vec![0f32; n];
- for i in 0..n {
- boost_host[i] =
- (-self.boost_strength * (duty_host[i] - mean)).exp();
- }
- self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
-
- // Permanence bump (rare). Only evaluated in strict mode.
- let mut ov_host = vec![0f32; n];
- self.dev
- .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
- let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
- if max_ov > 0.0 {
- let thr = 0.001f32 * max_ov;
- let bump = self.inc * 0.1f32;
- let bump_cols: Vec<u32> = ov_host
- .iter()
- .enumerate()
- .filter_map(|(i, &o)| {
- if o < thr { Some(i as u32) } else { None }
- })
- .collect();
- if !bump_cols.is_empty() {
- let s = self.synapses_per_col;
- let mut perm_host = vec![0f32; n * s];
- self.dev
- .dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
- for &c in &bump_cols {
- let base = (c as usize) * s;
- for p in &mut perm_host[base..base + s] {
- *p = (*p + bump).min(1.0);
- }
- }
- self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
- }
- }
- } else {
- // Fast path: fused mean + boost = expf(-strength*(ad-mean))
- // in a single GPU block. Zero D2H, zero H2D -- fully async.
- let boost_fn = self
- .dev
- .get_func("htm_sp_boost_fused", "sp_boost_from_duty")
- .expect("sp_boost_fused not loaded");
- let boost_cfg = LaunchConfig {
- grid_dim: (1, 1, 1),
- block_dim: (1024, 1, 1),
- shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
- };
- unsafe {
- boost_fn.launch(
- boost_cfg,
- (
- &self.active_duty,
- &mut self.boost,
- self.boost_strength,
- n as u32,
- ),
- )?;
- }
- }
- }
-
- // D2H the active_mask for this step. This is the single
- // unavoidable sync point per step -- CPU TM needs the active
- // indices for its next state update. At 2048 bytes / step this
- // is tiny in bandwidth but costs a full synchronize (~5-10 μs).
- self.dev
- .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
- let co = ti * n;
- cols_out[co..co + n].copy_from_slice(&self.host_mask);
- // Extract active indices.
- for (i, &b) in self.host_mask.iter().enumerate() {
- if b != 0 {
- active_indices_host.push(i as u32);
- }
- }
- // Insert separator (u32::MAX) between steps to demarcate step boundaries.
- active_indices_host.push(u32::MAX);
- }
-
- Ok(())
- }
-
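A hedged decoding sketch for the flat index stream this method emits (`u32::MAX` = 0xFFFFFFFF terminates each step's list):

```python
SENTINEL = 0xFFFFFFFF

def split_steps(flat_indices):
    """Group step_batch's flat active-index stream into per-step lists."""
    steps, cur = [], []
    for idx in flat_indices:
        if idx == SENTINEL:
            steps.append(cur)
            cur = []
        else:
            cur.append(idx)
    return steps
```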
- /// Fully-on-GPU batched SP + TM. Zero per-step host sync.
- ///
- /// Inputs:
- /// inputs_flat_dev : (T * input_bits) u8 already uploaded
- /// cols_dev : (T * n_cols) u8 output -- active-column mask per step
- /// anom_dev : (T,) f32 output -- anomaly score per step
- /// tm : persistent GPU TemporalMemory for this region
- #[allow(clippy::too_many_arguments)]
- pub fn step_batch_with_tm(
- &mut self,
- inputs_flat_dev: &CudaSlice<u8>,
- t: usize,
- input_bits: usize,
- learn: bool,
- cols_dev: &mut CudaSlice<u8>,
- anom_dev: &mut CudaSlice<f32>,
- tm: &mut crate::gpu::tm_gpu::TemporalMemoryGpu,
- ) -> Result<(), DriverError> {
- let n = self.n_columns;
- let k = ((self.sparsity * n as f32).round() as usize).max(1);
- debug_assert_eq!(cols_dev.len(), t * n);
- debug_assert_eq!(anom_dev.len(), t);
-
- let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
- let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
- let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
- let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
-
- let overlap_cfg = LaunchConfig {
- grid_dim: (n as u32, 1, 1),
- block_dim: (128, 1, 1),
- shared_mem_bytes: 0,
- };
- let topk_cfg = LaunchConfig {
- grid_dim: (1, 1, 1),
- block_dim: (256, 1, 1),
- shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
- };
- let learn_cfg = overlap_cfg;
- let duty_cfg = LaunchConfig {
- grid_dim: ((n as u32 + 255) / 256, 1, 1),
- block_dim: (256, 1, 1),
- shared_mem_bytes: 0,
- };
- let alpha = 1.0f32 / self.duty_period.max(1.0);
-
- for ti in 0..t {
- let in_off = ti * input_bits;
- let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
- self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
-
- // 1. sp_overlap
- unsafe {
- overlap_fn.clone().launch(
- overlap_cfg,
- (
- &self.inp_dev,
- &self.syn_bit,
- &self.syn_perm,
- &self.boost,
- self.conn_thr,
- self.synapses_per_col as u32,
- n as u32,
- &mut self.raw,
- &mut self.boosted,
- ),
- )?;
- }
-
- // 2. clear + sp_topk
- self.dev.memset_zeros(&mut self.active_mask)?;
- unsafe {
- topk_fn.clone().launch(
- topk_cfg,
- (&self.boosted, n as u32, k as u32, &mut self.active_mask),
- )?;
- }
-
- // 3. sp_learn
- if learn {
- unsafe {
- learn_fn.clone().launch(
- learn_cfg,
- (
- &self.active_mask,
- &self.inp_dev,
- &self.syn_bit,
- &mut self.syn_perm,
- self.inc,
- self.dec,
- self.synapses_per_col as u32,
- n as u32,
- ),
- )?;
- }
- }
-
- // 4. duty update (stage 1: no-boost write)
- unsafe {
- duty_fn.clone().launch(
- duty_cfg,
- (
- &self.active_mask,
- &self.raw,
- &mut self.active_duty,
- &mut self.overlap_duty,
- &mut self.boost,
- alpha,
- 1.0f32,
- 0.0f32,
- 0.0f32,
- 0u32,
- n as u32,
- ),
- )?;
- }
-
- // 5. Boost update: fused GPU kernel (no D2H).
- if learn && self.boost_strength > 0.0 {
- let boost_fn = self.dev
- .get_func("htm_sp_boost_fused", "sp_boost_from_duty")
- .expect("sp_boost_fused not loaded");
- let boost_cfg = LaunchConfig {
- grid_dim: (1, 1, 1),
- block_dim: (1024, 1, 1),
- shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
- };
- unsafe {
- boost_fn.launch(
- boost_cfg,
- (
- &self.active_duty,
- &mut self.boost,
- self.boost_strength,
- n as u32,
- ),
- )?;
- }
- }
-
- // 6. Copy active_mask slice into cols_dev[ti*n .. (ti+1)*n].
- let mut dst_slice = cols_dev.slice_mut(ti * n..(ti + 1) * n);
- self.dev.dtod_copy(&self.active_mask, &mut dst_slice)?;
-
- // 7. GPU TM step: predict + activate + anomaly + learn, all on device.
- tm.step(&self.active_mask, anom_dev, ti as u32, learn)?;
563
- }
564
-
565
- Ok(())
566
- }
567
-
568
- /// One SP step on the GPU. Returns sorted active-column indices.
569
- pub fn compute(&mut self, input: &[u8], learn: bool) -> Result<Vec<u32>, DriverError> {
570
- debug_assert_eq!(input.len(), self.input_bits);
571
- let n = self.n_columns;
572
- let k = ((self.sparsity * n as f32).round() as usize).max(1);
573
-
574
- // 1. H2D input SDR.
575
- self.dev.htod_sync_copy_into(input, &mut self.inp_dev)?;
576
-
577
- // 2. Launch sp_overlap: grid=n_columns, block=128.
578
- let overlap_fn = self
579
- .dev
580
- .get_func("htm_sp_overlap", "sp_overlap")
581
- .expect("sp_overlap not loaded");
582
- let overlap_cfg = LaunchConfig {
583
- grid_dim: (n as u32, 1, 1),
584
- block_dim: (128, 1, 1),
585
- shared_mem_bytes: 0,
586
- };
587
- unsafe {
588
- overlap_fn.launch(
589
- overlap_cfg,
590
- (
591
- &self.inp_dev,
592
- &self.syn_bit,
593
- &self.syn_perm,
594
- &self.boost,
595
- self.conn_thr,
596
- self.synapses_per_col as u32,
597
- n as u32,
598
- &mut self.raw,
599
- &mut self.boosted,
600
- ),
601
- )?;
602
- }
603
-
604
- // 3. Launch sp_topk: single block, shared mem = n_columns * f32.
605
- let topk_fn = self
606
- .dev
607
- .get_func("htm_sp_topk", "sp_topk_select")
608
- .expect("sp_topk not loaded");
609
- let topk_cfg = LaunchConfig {
610
- grid_dim: (1, 1, 1),
611
- block_dim: (256, 1, 1),
612
- shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
613
- };
614
- // Clear active_mask first. memset_zeros avoids an H2D of a host
615
- // zeroes vector every step.
616
- self.dev.memset_zeros(&mut self.active_mask)?;
617
- unsafe {
618
- topk_fn.launch(
619
- topk_cfg,
620
- (
621
- &self.boosted,
622
- n as u32,
623
- k as u32,
624
- &mut self.active_mask,
625
- ),
626
- )?;
627
- }
628
-
629
- // 4. Optional: sp_learn on active columns.
630
- if learn {
631
- let learn_fn = self
632
- .dev
633
- .get_func("htm_sp_learn", "sp_learn")
634
- .expect("sp_learn not loaded");
635
- let learn_cfg = LaunchConfig {
636
- grid_dim: (n as u32, 1, 1),
637
- block_dim: (128, 1, 1),
638
- shared_mem_bytes: 0,
639
- };
640
- unsafe {
641
- learn_fn.launch(
642
- learn_cfg,
643
- (
644
- &self.active_mask,
645
- &self.inp_dev,
646
- &self.syn_bit,
647
- &mut self.syn_perm,
648
- self.inc,
649
- self.dec,
650
- self.synapses_per_col as u32,
651
- n as u32,
652
- ),
653
- )?;
654
- }
655
- }
656
-
657
- // 5. Duty cycle + boost update. Always runs (matches CPU).
658
- // We need mean_duty on the host β€” compute BEFORE the update (matches
659
- // CPU sp.rs line 200-205 where mean is computed then written).
660
- // Actually CPU computes mean of the PRE-update duty cycles too? Re-read:
661
- // sp.rs lines 186-196 update duty cycles (pre-mean).
662
- // Line 202: mean = sum(active_duty_cycle) / n ← after update.
663
- // Line 204: boost[i] = exp(-strength*(active_duty[i] - mean)).
664
- // So mean is on POST-update values.
665
- // Easiest: 1) run duty update with boost_strength=0 (skip boost calc),
666
- // 2) D2H active_duty, compute mean, 3) run a boost-only kernel
667
- // OR inline the exp() in a second launch with mean passed.
668
- //
669
- // For simplicity and correctness we fuse: run the duty kernel with
670
- // mean=0 and boost_strength=0 (disables boost write), then D2H to
671
- // compute mean, then re-launch with the true mean. Two launches, one
672
- // tiny D2H (n Γ— f32). At n=2048 this is 8KB per step β€” negligible.
673
- let alpha = 1.0f32 / self.duty_period.max(1.0);
674
- let duty_fn = self
675
- .dev
676
- .get_func("htm_sp_duty", "sp_duty_update")
677
- .expect("sp_duty not loaded");
678
- let duty_cfg = LaunchConfig {
679
- grid_dim: ((n as u32 + 255) / 256, 1, 1),
680
- block_dim: (256, 1, 1),
681
- shared_mem_bytes: 0,
682
- };
683
- // Stage 1: update duty cycles (boost_strength=0 -> no write).
684
- unsafe {
685
- duty_fn.launch(
686
- duty_cfg,
687
- (
688
- &self.active_mask,
689
- &self.raw,
690
- &mut self.active_duty,
691
- &mut self.overlap_duty,
692
- &mut self.boost,
693
- alpha,
694
- 1.0f32, // stim_thr
695
- 0.0f32, // boost_strength = 0 -> skip write
696
- 0.0f32, // mean_duty (unused)
697
- 0u32, // learn_flag = 0
698
- n as u32,
699
- ),
700
- )?;
701
- }
702
-
703
- if learn && self.boost_strength > 0.0 && self.strict_parity {
704
- // Boost update must bit-match CPU `f32::exp`, so we compute it on
705
- // the host and copy back. Cost per step: 8KB D2H + 8KB H2D at n=2048.
706
- // Critical for learning parity β€” CUDA expf (even without fast-math)
707
- // uses different rounding for some inputs than host libm.
708
- let mut duty_host = vec![0f32; n];
709
- self.dev
710
- .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
711
- let sum: f32 = duty_host.iter().sum();
712
- let mean = sum / (n as f32);
713
- let mut boost_host = vec![0f32; n];
714
- for i in 0..n {
715
- boost_host[i] = (-self.boost_strength * (duty_host[i] - mean)).exp();
716
- }
717
- self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
718
-
719
- // CPU sp.rs 210-226: permanence bump for chronically under-stimulated
720
- // columns. If overlap_duty_cycle[i] < 0.001 * max(overlap_duty_cycle),
721
- // add inc*0.1 to every synapse of column i (clamped to 1.0).
722
- // This runs only once per step and only for the rare cases, but we
723
- // need it for bit-exact parity with CPU learn.
724
- let mut ov_host = vec![0f32; n];
725
- self.dev
726
- .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
727
- let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
728
- if max_ov > 0.0 {
729
- let thr = 0.001f32 * max_ov;
730
- let bump = self.inc * 0.1f32;
731
- // Find columns needing a bump. Usually empty. Rare β†’ D2H/H2D
732
- // of syn_perm is cheap (n*S*4 = 320KB at n=2048,S=40).
733
- let bump_cols: Vec<u32> = ov_host
734
- .iter()
735
- .enumerate()
736
- .filter_map(|(i, &o)| if o < thr { Some(i as u32) } else { None })
737
- .collect();
738
- if !bump_cols.is_empty() {
739
- // Download, bump, upload. (Keeps implementation simple and
740
- // bit-exact. Could kernelize later.)
741
- let s = self.synapses_per_col;
742
- let mut perm_host = vec![0f32; n * s];
743
- self.dev.dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
744
- for &c in &bump_cols {
745
- let base = (c as usize) * s;
746
- for p in &mut perm_host[base..base + s] {
747
- *p = (*p + bump).min(1.0);
748
- }
749
- }
750
- self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
751
- }
752
- }
753
- } else if learn && self.boost_strength > 0.0 {
754
- // Fast path: GPU-side boost using the already-loaded duty kernel.
755
- let mut duty_host = vec![0f32; n];
756
- self.dev
757
- .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
758
- let sum: f32 = duty_host.iter().sum();
759
- let mean = sum / (n as f32);
760
- let boost_fn = self
761
- .dev
762
- .get_func("htm_sp_duty", "sp_duty_update")
763
- .expect("sp_duty not loaded");
764
- unsafe {
765
- boost_fn.launch(
766
- duty_cfg,
767
- (
768
- &self.active_mask,
769
- &self.raw,
770
- &mut self.active_duty,
771
- &mut self.overlap_duty,
772
- &mut self.boost,
773
- 0.0f32,
774
- 1.0f32,
775
- self.boost_strength,
776
- mean,
777
- 1u32,
778
- n as u32,
779
- ),
780
- )?;
781
- }
782
- }
783
-
784
- // 6. D2H active_mask and convert to sorted index list.
785
- self.dev
786
- .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
787
- let mut active: Vec<u32> = Vec::with_capacity(k);
788
- for (i, &b) in self.host_mask.iter().enumerate() {
789
- if b != 0 {
790
- active.push(i as u32);
791
- }
792
- }
793
- debug_assert_eq!(active.len(), k, "SP must emit exactly k winners");
794
- Ok(active)
795
- }
796
- }
 
1
+ //! GPU implementation of the Spatial Pooler.
2
+ //!
3
+ //! One `SpatialPoolerGpu` owns a set of persistent device buffers + 4 PTX
4
+ //! kernels. `compute(input, learn)` performs one SP step and returns the
5
+ //! sorted active-column indices (host `Vec<u32>`); this is what the CPU
+ //! TemporalMemory consumes.
+ //!
+ //! Persistent state on device (per region):
+ //! syn_bit : u32 [n_columns × S] (constant after init)
+ //! syn_perm : f32 [n_columns × S] (updated by sp_learn)
11
+ //! boost : f32 [n_columns]
12
+ //! active_duty : f32 [n_columns]
13
+ //! overlap_duty: f32 [n_columns]
14
+ //!
15
+ //! Per-step transient state:
16
+ //! inp_dev : u8 [input_bits] (H2D copy each step)
17
+ //! raw : u32 [n_columns]
18
+ //! boosted : f32 [n_columns]
19
+ //! active_mask : u8 [n_columns] (topk output, D2H at the end)
20
+
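+ // Editorial sketch (not in the original source): a back-of-envelope device
+ // footprint for the buffer inventory above. Field sizes follow the doc
+ // comment; the helper name is illustrative and unused.
+ #[allow(dead_code)]
+ fn sp_region_device_bytes(n_columns: usize, s: usize, input_bits: usize) -> usize {
+ let per_synapse = 4 + 4; // syn_bit u32 + syn_perm f32
+ let per_column = 3 * 4 + 4 + 4 + 1; // boost/active_duty/overlap_duty f32, raw u32, boosted f32, active_mask u8
+ n_columns * s * per_synapse + n_columns * per_column + input_bits // inp_dev u8
+ }
+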
21
+ use std::sync::Arc;
22
+
23
+ use cudarc::driver::{CudaDevice, CudaSlice, DeviceSlice, DriverError, LaunchAsync, LaunchConfig};
24
+ use cudarc::nvrtc::Ptx;
25
+
26
+ use crate::sp::SpatialPooler;
27
+
28
+ // Embed PTX at compile time. HTM_GPU_PTX_DIR is exported by build.rs.
29
+ const PTX_SP_OVERLAP: &str =
30
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_overlap.ptx"));
31
+ const PTX_SP_TOPK: &str =
32
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_topk.ptx"));
33
+ const PTX_SP_LEARN: &str =
34
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_learn.ptx"));
35
+ const PTX_SP_DUTY: &str =
36
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_duty.ptx"));
37
+ const PTX_SP_BOOST_FUSED: &str =
38
+ include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_boost_fused.ptx"));
39
+
40
+ pub struct SpatialPoolerGpu {
41
+ dev: Arc<CudaDevice>,
42
+
43
+ // Config mirror (we don't touch CPU SpatialPooler after init).
44
+ input_bits: usize,
45
+ n_columns: usize,
46
+ synapses_per_col: usize,
47
+ conn_thr: f32,
48
+ inc: f32,
49
+ dec: f32,
50
+ sparsity: f32,
51
+ duty_period: f32,
52
+ boost_strength: f32,
53
+
54
+ // Persistent device state.
55
+ syn_bit: CudaSlice<u32>,
56
+ syn_perm: CudaSlice<f32>,
57
+ boost: CudaSlice<f32>,
58
+ active_duty: CudaSlice<f32>,
59
+ overlap_duty: CudaSlice<f32>,
60
+
61
+ // Transient scratch (reused each step).
62
+ inp_dev: CudaSlice<u8>,
63
+ raw: CudaSlice<u32>,
64
+ boosted: CudaSlice<f32>,
65
+ active_mask: CudaSlice<u8>,
66
+
67
+ // Reusable host buffer for D2H of active_mask.
68
+ host_mask: Vec<u8>,
69
+
70
+ /// Strict bit-parity with CPU reference. Enabled for tests.
71
+ /// Forces host-side boost/exp computation and the overlap-duty bump check
72
+ /// every step. Default false for max throughput.
73
+ strict_parity: bool,
74
+ }
75
+
76
+ impl SpatialPoolerGpu {
77
+ /// Copy CPU SpatialPooler state onto the device. This preserves the
78
+ /// exact seeded proximal synapse layout + initial permanences, so the
79
+ /// GPU SP is a bit-identical parallel implementation of the CPU SP.
80
+ pub fn from_cpu(cpu: &SpatialPooler) -> Result<Self, DriverError> {
81
+ let dev = CudaDevice::new(0)?;
82
+ let cfg = &cpu.cfg;
83
+ let n = cfg.n_columns;
84
+ let s = cfg.potential_synapses;
85
+
86
+ // Flatten proximal dendrites into column-major arrays.
87
+ let mut syn_bit_h: Vec<u32> = Vec::with_capacity(n * s);
88
+ let mut syn_perm_h: Vec<f32> = Vec::with_capacity(n * s);
89
+ for col in &cpu.columns {
90
+ debug_assert_eq!(col.inputs.len(), s);
91
+ debug_assert_eq!(col.perms.len(), s);
92
+ syn_bit_h.extend_from_slice(&col.inputs);
93
+ syn_perm_h.extend_from_slice(&col.perms);
94
+ }
95
+
96
+ let syn_bit = dev.htod_sync_copy(&syn_bit_h)?;
97
+ let syn_perm = dev.htod_sync_copy(&syn_perm_h)?;
98
+ let boost = dev.htod_sync_copy(&cpu.boost)?;
99
+ let active_duty = dev.htod_sync_copy(&cpu.active_duty_cycle)?;
100
+ let overlap_duty = dev.htod_sync_copy(&cpu.overlap_duty_cycle)?;
101
+
102
+ let inp_dev: CudaSlice<u8> = dev.alloc_zeros(cfg.input_bits)?;
103
+ let raw: CudaSlice<u32> = dev.alloc_zeros(n)?;
104
+ let boosted: CudaSlice<f32> = dev.alloc_zeros(n)?;
105
+ let active_mask: CudaSlice<u8> = dev.alloc_zeros(n)?;
106
+
107
+ // Load PTX modules. Each .ptx is a module containing one `extern "C"`
+ // function, addressed by the (module, func) pair. CudaDevice::load_ptx
+ // registers the module name globally on the device, so a deterministic
+ // naming scheme keeps multiple SP instances from colliding.
112
+ let modules = [
113
+ ("htm_sp_overlap", PTX_SP_OVERLAP, "sp_overlap"),
114
+ ("htm_sp_topk", PTX_SP_TOPK, "sp_topk_select"),
115
+ ("htm_sp_learn", PTX_SP_LEARN, "sp_learn"),
116
+ ("htm_sp_duty", PTX_SP_DUTY, "sp_duty_update"),
117
+ ("htm_sp_boost_fused", PTX_SP_BOOST_FUSED, "sp_boost_from_duty"),
118
+ ];
119
+ for (modname, ptx, fnname) in modules {
120
+ // load_ptx is NOT idempotent; calling twice errors. For multi-region
121
+ // support we check-then-load.
122
+ if dev.get_func(modname, fnname).is_none() {
123
+ dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
124
+ }
125
+ }
126
+
127
+ Ok(Self {
128
+ dev,
129
+ input_bits: cfg.input_bits,
130
+ n_columns: n,
131
+ synapses_per_col: s,
132
+ conn_thr: cfg.connected_threshold,
133
+ inc: cfg.syn_perm_active_inc,
134
+ dec: cfg.syn_perm_inactive_dec,
135
+ sparsity: cfg.sparsity,
136
+ duty_period: cfg.duty_cycle_period,
137
+ boost_strength: cfg.boost_strength,
138
+ syn_bit,
139
+ syn_perm,
140
+ boost,
141
+ active_duty,
142
+ overlap_duty,
143
+ inp_dev,
144
+ raw,
145
+ boosted,
146
+ active_mask,
147
+ host_mask: vec![0u8; n],
148
+ strict_parity: false,
149
+ })
150
+ }
151
+
152
+ /// Enable strict bit-parity mode. Parity tests use this.
153
+ pub fn set_strict_parity(&mut self, strict: bool) {
154
+ self.strict_parity = strict;
155
+ }
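+
+ // Editorial usage sketch (hypothetical harness; assumes the CPU
+ // `SpatialPooler` exposes a matching `compute` returning Vec<u32>):
+ //     let mut gpu = SpatialPoolerGpu::from_cpu(&cpu_sp)?;
+ //     gpu.set_strict_parity(true);
+ //     for sdr in &inputs {
+ //         assert_eq!(gpu.compute(sdr, true)?, cpu_sp.compute(sdr, true));
+ //     }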
156
+
157
+ /// Access to the underlying CudaDevice for host-side orchestration.
158
+ pub fn dev_ref(&self) -> &Arc<CudaDevice> {
159
+ &self.dev
160
+ }
161
+
162
+ // --- Fused-path accessors (immutable state reads + pointer-grabs). ---
163
+ pub fn n_columns_accessor(&self) -> usize { self.n_columns }
164
+ #[allow(dead_code)]
165
+ pub fn input_bits_accessor(&self) -> usize { self.input_bits }
166
+ pub fn synapses_per_col_accessor(&self) -> usize { self.synapses_per_col }
167
+ pub fn conn_thr_accessor(&self) -> f32 { self.conn_thr }
168
+ pub fn inc_accessor(&self) -> f32 { self.inc }
169
+ pub fn dec_accessor(&self) -> f32 { self.dec }
170
+ pub fn sparsity_accessor(&self) -> f32 { self.sparsity }
171
+ pub fn duty_period_accessor(&self) -> f32 { self.duty_period }
172
+ #[allow(dead_code)]
173
+ pub fn boost_strength_accessor(&self) -> f32 { self.boost_strength }
174
+
175
+ pub fn syn_bit_accessor(&self) -> &CudaSlice<u32> { &self.syn_bit }
176
+ pub fn syn_perm_accessor(&self) -> &CudaSlice<f32> { &self.syn_perm }
177
+ pub fn boost_accessor(&self) -> &CudaSlice<f32> { &self.boost }
178
+ pub fn active_duty_accessor(&self) -> &CudaSlice<f32> { &self.active_duty }
179
+
180
+ /// Estimate an initial overlap threshold near the sparsity-target
+ /// percentile of raw overlaps (the 98th percentile at 2% sparsity),
+ /// after a short warmup pass. Used to seed `inhibition_threshold` such
+ /// that the activation rate starts near the sparsity target.
+ /// Placeholder (returns a conservative constant); the real warmup pass
+ /// happens on the Rust orchestrator side.
+ pub fn initial_threshold_estimate(&self) -> f32 {
+ // With conn_thr=0.5, init_perm around 0.5±0.1, S=40, sparse SDR at 2%:
+ // expected overlap ~ 40 * 0.02 = 0.8 connected hits → boosted ~ 0.8.
+ // Top-K selects the top 2%, so the cut-off sits roughly at the
+ // 98th percentile of boosted overlaps. Conservative start: 2.0.
+ // Per-column adaptation will quickly steer each column's threshold.
+ 2.0f32
192
+ }
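+
+ // Editorial worked example of the estimate above: with S = 40 synapses
+ // per column and a 2%-dense input SDR,
+ //     expected_hits = 40.0 * 0.02 = 0.8
+ // so a start threshold of 2.0 deliberately over-shoots and lets the
+ // per-column adaptation pull it down.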
193
+
194
+ /// Batched multi-step SP on the GPU. Processes T timesteps from a
+ /// pre-uploaded device input buffer. Writes a `(T, n_cols)` u8
+ /// active-column mask to `cols_out` and appends each step's active
+ /// column indices to `active_indices_host`, terminating every step's
+ /// run with a `u32::MAX` separator.
+ ///
+ /// Each step runs the same kernel pipeline as `compute`, but replaces
+ /// the per-step boost/duty D2H→exp→H2D round-trip with the fused GPU
+ /// boost kernel (strict_parity mode still takes the host path).
+ ///
+ /// This is the fast path used by `HTMRegionGpu.step_many_gpu`.
204
+ #[allow(clippy::too_many_arguments)]
205
+ pub fn step_batch(
206
+ &mut self,
207
+ inputs_flat_dev: &CudaSlice<u8>,
208
+ t: usize,
209
+ input_bits: usize,
210
+ learn: bool,
211
+ cols_out: &mut [u8],
212
+ active_indices_host: &mut Vec<u32>,
213
+ ) -> Result<(), DriverError> {
214
+ let n = self.n_columns;
215
+ let k = ((self.sparsity * n as f32).round() as usize).max(1);
216
+ debug_assert_eq!(cols_out.len(), t * n);
217
+
218
+ let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
219
+ let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
220
+ let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
221
+ let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
222
+
223
+ let overlap_cfg = LaunchConfig {
224
+ grid_dim: (n as u32, 1, 1),
225
+ block_dim: (128, 1, 1),
226
+ shared_mem_bytes: 0,
227
+ };
228
+ let topk_cfg = LaunchConfig {
229
+ grid_dim: (1, 1, 1),
230
+ block_dim: (256, 1, 1),
231
+ shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
232
+ };
233
+ let learn_cfg = overlap_cfg;
234
+ let duty_cfg = LaunchConfig {
235
+ grid_dim: ((n as u32 + 255) / 256, 1, 1),
236
+ block_dim: (256, 1, 1),
237
+ shared_mem_bytes: 0,
238
+ };
239
+ let alpha = 1.0f32 / self.duty_period.max(1.0);
240
+
241
+ // Reusable host buffer for the per-step active_mask D2H.
242
+ self.host_mask.resize(n, 0);
243
+
244
+ active_indices_host.clear();
245
+
246
+ for ti in 0..t {
247
+ // Point overlap kernel at the ti-th slice of the pre-uploaded input.
248
+ // cudarc CudaSlice doesn't have a "view" per se, so we must copy the
249
+ // slice into the reusable inp_dev buffer. This is a D2D copy, much
250
+ // faster than H2D.
251
+ // (Alternative: rewrite kernel to accept an offset; deferred.)
252
+ let in_off = ti * input_bits;
253
+ // Use dtod_copy via raw slice indexing: cudarc exposes slice() for this.
254
+ let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
255
+ self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
256
+
257
+ // 1. sp_overlap
258
+ unsafe {
259
+ overlap_fn.clone().launch(
260
+ overlap_cfg,
261
+ (
262
+ &self.inp_dev,
263
+ &self.syn_bit,
264
+ &self.syn_perm,
265
+ &self.boost,
266
+ self.conn_thr,
267
+ self.synapses_per_col as u32,
268
+ n as u32,
269
+ &mut self.raw,
270
+ &mut self.boosted,
271
+ ),
272
+ )?;
273
+ }
274
+
275
+ // 2. Clear active_mask, then sp_topk
276
+ self.dev.memset_zeros(&mut self.active_mask)?;
277
+ unsafe {
278
+ topk_fn.clone().launch(
279
+ topk_cfg,
280
+ (&self.boosted, n as u32, k as u32, &mut self.active_mask),
281
+ )?;
282
+ }
283
+
284
+ // 3. sp_learn
285
+ if learn {
286
+ unsafe {
287
+ learn_fn.clone().launch(
288
+ learn_cfg,
289
+ (
290
+ &self.active_mask,
291
+ &self.inp_dev,
292
+ &self.syn_bit,
293
+ &mut self.syn_perm,
294
+ self.inc,
295
+ self.dec,
296
+ self.synapses_per_col as u32,
297
+ n as u32,
298
+ ),
299
+ )?;
300
+ }
301
+ }
302
+
303
+ // 4. duty update (device)
304
+ unsafe {
305
+ duty_fn.clone().launch(
306
+ duty_cfg,
307
+ (
308
+ &self.active_mask,
309
+ &self.raw,
310
+ &mut self.active_duty,
311
+ &mut self.overlap_duty,
312
+ &mut self.boost,
313
+ alpha,
314
+ 1.0f32,
315
+ 0.0f32,
316
+ 0.0f32,
317
+ 0u32,
318
+ n as u32,
319
+ ),
320
+ )?;
321
+ }
322
+
323
+ // 5. Boost update. Two modes:
324
+ // * strict_parity (tests): host-side exp for bit-exact match.
325
+ // * default (production): GPU expf is close enough and ~10x faster
326
+ // since we skip the D2H/H2D round-trip.
327
+ if learn && self.boost_strength > 0.0 {
328
+ if self.strict_parity {
329
+ let mut duty_host = vec![0f32; n];
330
+ self.dev
331
+ .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
332
+ let sum: f32 = duty_host.iter().sum();
333
+ let mean = sum / (n as f32);
334
+ let mut boost_host = vec![0f32; n];
335
+ for i in 0..n {
336
+ boost_host[i] =
337
+ (-self.boost_strength * (duty_host[i] - mean)).exp();
338
+ }
339
+ self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
340
+
341
+ // Permanence bump (rare). Only evaluated in strict mode.
342
+ let mut ov_host = vec![0f32; n];
343
+ self.dev
344
+ .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
345
+ let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
346
+ if max_ov > 0.0 {
347
+ let thr = 0.001f32 * max_ov;
348
+ let bump = self.inc * 0.1f32;
349
+ let bump_cols: Vec<u32> = ov_host
350
+ .iter()
351
+ .enumerate()
352
+ .filter_map(|(i, &o)| {
353
+ if o < thr { Some(i as u32) } else { None }
354
+ })
355
+ .collect();
356
+ if !bump_cols.is_empty() {
357
+ let s = self.synapses_per_col;
358
+ let mut perm_host = vec![0f32; n * s];
359
+ self.dev
360
+ .dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
361
+ for &c in &bump_cols {
362
+ let base = (c as usize) * s;
363
+ for p in &mut perm_host[base..base + s] {
364
+ *p = (*p + bump).min(1.0);
365
+ }
366
+ }
367
+ self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
368
+ }
369
+ }
370
+ } else {
371
+ // Fast path: fused mean + boost = expf(-strength*(ad-mean))
372
+ // in a single GPU block. Zero D2H, zero H2D; fully async.
373
+ let boost_fn = self
374
+ .dev
375
+ .get_func("htm_sp_boost_fused", "sp_boost_from_duty")
376
+ .expect("sp_boost_fused not loaded");
377
+ let boost_cfg = LaunchConfig {
378
+ grid_dim: (1, 1, 1),
379
+ block_dim: (1024, 1, 1),
380
+ shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
381
+ };
382
+ unsafe {
383
+ boost_fn.launch(
384
+ boost_cfg,
385
+ (
386
+ &self.active_duty,
387
+ &mut self.boost,
388
+ self.boost_strength,
389
+ n as u32,
390
+ ),
391
+ )?;
392
+ }
393
+ }
394
+ }
395
+
396
+ // D2H the active_mask for this step. This is the single
+ // unavoidable sync point per step: the CPU TM needs the active
+ // indices for its next state update. At 2048 bytes / step this
+ // is tiny in bandwidth but costs a full synchronize (~5-10 µs).
400
+ self.dev
401
+ .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
402
+ let co = ti * n;
403
+ cols_out[co..co + n].copy_from_slice(&self.host_mask);
404
+ // Extract active indices.
405
+ for (i, &b) in self.host_mask.iter().enumerate() {
406
+ if b != 0 {
407
+ active_indices_host.push(i as u32);
408
+ }
409
+ }
410
+ // Insert separator (u32::MAX) between steps to demarcate step boundaries.
411
+ active_indices_host.push(u32::MAX);
412
+ }
413
+
414
+ Ok(())
415
+ }
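+
+ // Editorial usage sketch for `step_batch` (names illustrative; assumes
+ // `inputs_dev` was uploaded once with htod_sync_copy):
+ //     let mut cols = vec![0u8; t * sp.n_columns_accessor()];
+ //     let mut idx: Vec<u32> = Vec::new();
+ //     sp.step_batch(&inputs_dev, t, input_bits, true, &mut cols, &mut idx)?;
+ //     // Recover per-step index runs from the u32::MAX separators:
+ //     let per_step: Vec<&[u32]> = idx.split(|&v| v == u32::MAX)
+ //         .take(t) // split also yields a trailing empty slice
+ //         .collect();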
416
+
417
+ /// Fully-on-GPU batched SP + TM. Zero per-step host sync.
418
+ ///
419
+ /// Inputs:
420
+ /// inputs_flat_dev : (T * input_bits) u8 already uploaded
421
+ /// cols_dev : (T * n_cols) u8 output β€” active-column mask per step
422
+ /// anom_dev : (T,) f32 output β€” anomaly score per step
423
+ /// tm : persistent GPU TemporalMemory for this region
424
+ #[allow(clippy::too_many_arguments)]
425
+ pub fn step_batch_with_tm(
426
+ &mut self,
427
+ inputs_flat_dev: &CudaSlice<u8>,
428
+ t: usize,
429
+ input_bits: usize,
430
+ learn: bool,
431
+ cols_dev: &mut CudaSlice<u8>,
432
+ anom_dev: &mut CudaSlice<f32>,
433
+ tm: &mut crate::gpu::tm_gpu::TemporalMemoryGpu,
434
+ ) -> Result<(), DriverError> {
435
+ let n = self.n_columns;
436
+ let k = ((self.sparsity * n as f32).round() as usize).max(1);
437
+ debug_assert_eq!(cols_dev.len(), t * n);
438
+ debug_assert_eq!(anom_dev.len(), t);
439
+
440
+ let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
441
+ let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
442
+ let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
443
+ let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
444
+
445
+ let overlap_cfg = LaunchConfig {
446
+ grid_dim: (n as u32, 1, 1),
447
+ block_dim: (128, 1, 1),
448
+ shared_mem_bytes: 0,
449
+ };
450
+ let topk_cfg = LaunchConfig {
451
+ grid_dim: (1, 1, 1),
452
+ block_dim: (256, 1, 1),
453
+ shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
454
+ };
455
+ let learn_cfg = overlap_cfg;
456
+ let duty_cfg = LaunchConfig {
457
+ grid_dim: ((n as u32 + 255) / 256, 1, 1),
458
+ block_dim: (256, 1, 1),
459
+ shared_mem_bytes: 0,
460
+ };
461
+ let alpha = 1.0f32 / self.duty_period.max(1.0);
462
+
463
+ for ti in 0..t {
464
+ let in_off = ti * input_bits;
465
+ let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
466
+ self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
467
+
468
+ // 1. sp_overlap
469
+ unsafe {
470
+ overlap_fn.clone().launch(
471
+ overlap_cfg,
472
+ (
473
+ &self.inp_dev,
474
+ &self.syn_bit,
475
+ &self.syn_perm,
476
+ &self.boost,
477
+ self.conn_thr,
478
+ self.synapses_per_col as u32,
479
+ n as u32,
480
+ &mut self.raw,
481
+ &mut self.boosted,
482
+ ),
483
+ )?;
484
+ }
485
+
486
+ // 2. clear + sp_topk
487
+ self.dev.memset_zeros(&mut self.active_mask)?;
488
+ unsafe {
489
+ topk_fn.clone().launch(
490
+ topk_cfg,
491
+ (&self.boosted, n as u32, k as u32, &mut self.active_mask),
492
+ )?;
493
+ }
494
+
495
+ // 3. sp_learn
496
+ if learn {
497
+ unsafe {
498
+ learn_fn.clone().launch(
499
+ learn_cfg,
500
+ (
501
+ &self.active_mask,
502
+ &self.inp_dev,
503
+ &self.syn_bit,
504
+ &mut self.syn_perm,
505
+ self.inc,
506
+ self.dec,
507
+ self.synapses_per_col as u32,
508
+ n as u32,
509
+ ),
510
+ )?;
511
+ }
512
+ }
513
+
514
+ // 4. duty update (stage 1: no-boost write)
515
+ unsafe {
516
+ duty_fn.clone().launch(
517
+ duty_cfg,
518
+ (
519
+ &self.active_mask,
520
+ &self.raw,
521
+ &mut self.active_duty,
522
+ &mut self.overlap_duty,
523
+ &mut self.boost,
524
+ alpha,
525
+ 1.0f32,
526
+ 0.0f32,
527
+ 0.0f32,
528
+ 0u32,
529
+ n as u32,
530
+ ),
531
+ )?;
532
+ }
533
+
534
+ // 5. Boost update: fused GPU kernel (no D2H).
535
+ if learn && self.boost_strength > 0.0 {
536
+ let boost_fn = self.dev
537
+ .get_func("htm_sp_boost_fused", "sp_boost_from_duty")
538
+ .expect("sp_boost_fused not loaded");
539
+ let boost_cfg = LaunchConfig {
540
+ grid_dim: (1, 1, 1),
541
+ block_dim: (1024, 1, 1),
542
+ shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
543
+ };
544
+ unsafe {
545
+ boost_fn.launch(
546
+ boost_cfg,
547
+ (
548
+ &self.active_duty,
549
+ &mut self.boost,
550
+ self.boost_strength,
551
+ n as u32,
552
+ ),
553
+ )?;
554
+ }
555
+ }
556
+
557
+ // 6. Copy active_mask slice into cols_dev[ti*n .. (ti+1)*n].
558
+ let mut dst_slice = cols_dev.slice_mut(ti * n..(ti + 1) * n);
559
+ self.dev.dtod_copy(&self.active_mask, &mut dst_slice)?;
560
+
561
+ // 7. GPU TM step: predict + activate + anomaly + learn, all on device.
562
+ tm.step(&self.active_mask, anom_dev, ti as u32, learn)?;
563
+ }
564
+
565
+ Ok(())
566
+ }
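+
+ // Editorial usage sketch for the fully-device path (names illustrative):
+ //     let dev = sp.dev_ref().clone();
+ //     let mut cols_dev = dev.alloc_zeros::<u8>(t * n_cols)?;
+ //     let mut anom_dev = dev.alloc_zeros::<f32>(t)?;
+ //     sp.step_batch_with_tm(&inputs_dev, t, input_bits, true,
+ //         &mut cols_dev, &mut anom_dev, &mut tm)?;
+ //     let anomalies = dev.dtoh_sync_copy(&anom_dev)?; // one D2H per batch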
567
+
568
+ /// One SP step on the GPU. Returns sorted active-column indices.
569
+ pub fn compute(&mut self, input: &[u8], learn: bool) -> Result<Vec<u32>, DriverError> {
570
+ debug_assert_eq!(input.len(), self.input_bits);
571
+ let n = self.n_columns;
572
+ let k = ((self.sparsity * n as f32).round() as usize).max(1);
573
+
574
+ // 1. H2D input SDR.
575
+ self.dev.htod_sync_copy_into(input, &mut self.inp_dev)?;
576
+
577
+ // 2. Launch sp_overlap: grid=n_columns, block=128.
578
+ let overlap_fn = self
579
+ .dev
580
+ .get_func("htm_sp_overlap", "sp_overlap")
581
+ .expect("sp_overlap not loaded");
582
+ let overlap_cfg = LaunchConfig {
583
+ grid_dim: (n as u32, 1, 1),
584
+ block_dim: (128, 1, 1),
585
+ shared_mem_bytes: 0,
586
+ };
587
+ unsafe {
588
+ overlap_fn.launch(
589
+ overlap_cfg,
590
+ (
591
+ &self.inp_dev,
592
+ &self.syn_bit,
593
+ &self.syn_perm,
594
+ &self.boost,
595
+ self.conn_thr,
596
+ self.synapses_per_col as u32,
597
+ n as u32,
598
+ &mut self.raw,
599
+ &mut self.boosted,
600
+ ),
601
+ )?;
602
+ }
603
+
604
+ // 3. Launch sp_topk: single block, shared mem = n_columns * f32.
605
+ let topk_fn = self
606
+ .dev
607
+ .get_func("htm_sp_topk", "sp_topk_select")
608
+ .expect("sp_topk not loaded");
609
+ let topk_cfg = LaunchConfig {
610
+ grid_dim: (1, 1, 1),
611
+ block_dim: (256, 1, 1),
612
+ shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
613
+ };
614
+ // Clear active_mask first. memset_zeros avoids an H2D of a host
615
+ // zeroes vector every step.
616
+ self.dev.memset_zeros(&mut self.active_mask)?;
617
+ unsafe {
618
+ topk_fn.launch(
619
+ topk_cfg,
620
+ (
621
+ &self.boosted,
622
+ n as u32,
623
+ k as u32,
624
+ &mut self.active_mask,
625
+ ),
626
+ )?;
627
+ }
628
+
629
+ // 4. Optional: sp_learn on active columns.
630
+ if learn {
631
+ let learn_fn = self
632
+ .dev
633
+ .get_func("htm_sp_learn", "sp_learn")
634
+ .expect("sp_learn not loaded");
635
+ let learn_cfg = LaunchConfig {
636
+ grid_dim: (n as u32, 1, 1),
637
+ block_dim: (128, 1, 1),
638
+ shared_mem_bytes: 0,
639
+ };
640
+ unsafe {
641
+ learn_fn.launch(
642
+ learn_cfg,
643
+ (
644
+ &self.active_mask,
645
+ &self.inp_dev,
646
+ &self.syn_bit,
647
+ &mut self.syn_perm,
648
+ self.inc,
649
+ self.dec,
650
+ self.synapses_per_col as u32,
651
+ n as u32,
652
+ ),
653
+ )?;
654
+ }
655
+ }
656
+
657
+ // 5. Duty cycle + boost update. Always runs (matches CPU).
+ // Ordering matters. The CPU reference (sp.rs lines 186-196) updates the
+ // duty cycles first, then computes
+ //   mean = sum(active_duty_cycle) / n                  (line 202, POST-update)
+ //   boost[i] = exp(-strength*(active_duty[i] - mean))  (line 204)
+ // so the mean must be taken over the post-update duty cycles.
+ //
+ // We therefore split the work: run the duty kernel with mean=0 and
+ // boost_strength=0 (duty update only, boost write disabled), D2H
+ // active_duty to compute the mean on the host, then re-launch with the
+ // true mean. Two launches, one tiny D2H (n × f32). At n=2048 this is
+ // 8KB per step; negligible.
673
+ let alpha = 1.0f32 / self.duty_period.max(1.0);
674
+ let duty_fn = self
675
+ .dev
676
+ .get_func("htm_sp_duty", "sp_duty_update")
677
+ .expect("sp_duty not loaded");
678
+ let duty_cfg = LaunchConfig {
679
+ grid_dim: ((n as u32 + 255) / 256, 1, 1),
680
+ block_dim: (256, 1, 1),
681
+ shared_mem_bytes: 0,
682
+ };
683
+ // Stage 1: update duty cycles (boost_strength=0 -> no write).
684
+ unsafe {
685
+ duty_fn.launch(
686
+ duty_cfg,
687
+ (
688
+ &self.active_mask,
689
+ &self.raw,
690
+ &mut self.active_duty,
691
+ &mut self.overlap_duty,
692
+ &mut self.boost,
693
+ alpha,
694
+ 1.0f32, // stim_thr
695
+ 0.0f32, // boost_strength = 0 -> skip write
696
+ 0.0f32, // mean_duty (unused)
697
+ 0u32, // learn_flag = 0
698
+ n as u32,
699
+ ),
700
+ )?;
701
+ }
702
+
703
+ if learn && self.boost_strength > 0.0 && self.strict_parity {
704
+ // Boost update must bit-match CPU `f32::exp`, so we compute it on
705
+ // the host and copy back. Cost per step: 8KB D2H + 8KB H2D at n=2048.
706
+ // Critical for learning parity: CUDA expf (even without fast-math)
707
+ // uses different rounding for some inputs than host libm.
708
+ let mut duty_host = vec![0f32; n];
709
+ self.dev
710
+ .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
711
+ let sum: f32 = duty_host.iter().sum();
712
+ let mean = sum / (n as f32);
713
+ let mut boost_host = vec![0f32; n];
714
+ for i in 0..n {
715
+ boost_host[i] = (-self.boost_strength * (duty_host[i] - mean)).exp();
716
+ }
717
+ self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
718
+
719
+ // CPU sp.rs 210-226: permanence bump for chronically under-stimulated
720
+ // columns. If overlap_duty_cycle[i] < 0.001 * max(overlap_duty_cycle),
721
+ // add inc*0.1 to every synapse of column i (clamped to 1.0).
722
+ // This runs only once per step and only for the rare cases, but we
723
+ // need it for bit-exact parity with CPU learn.
724
+ let mut ov_host = vec![0f32; n];
725
+ self.dev
726
+ .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
727
+ let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
728
+ if max_ov > 0.0 {
729
+ let thr = 0.001f32 * max_ov;
730
+ let bump = self.inc * 0.1f32;
731
+ // Find columns needing a bump. Usually empty. Rare → D2H/H2D
732
+ // of syn_perm is cheap (n*S*4 = 320KB at n=2048,S=40).
733
+ let bump_cols: Vec<u32> = ov_host
734
+ .iter()
735
+ .enumerate()
736
+ .filter_map(|(i, &o)| if o < thr { Some(i as u32) } else { None })
737
+ .collect();
738
+ if !bump_cols.is_empty() {
739
+ // Download, bump, upload. (Keeps implementation simple and
740
+ // bit-exact. Could kernelize later.)
741
+ let s = self.synapses_per_col;
742
+ let mut perm_host = vec![0f32; n * s];
743
+ self.dev.dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
744
+ for &c in &bump_cols {
745
+ let base = (c as usize) * s;
746
+ for p in &mut perm_host[base..base + s] {
747
+ *p = (*p + bump).min(1.0);
748
+ }
749
+ }
750
+ self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
751
+ }
752
+ }
753
+ } else if learn && self.boost_strength > 0.0 {
754
+ // Fast path: GPU-side boost using the already-loaded duty kernel.
755
+ let mut duty_host = vec![0f32; n];
756
+ self.dev
757
+ .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
758
+ let sum: f32 = duty_host.iter().sum();
759
+ let mean = sum / (n as f32);
760
+ let boost_fn = self
761
+ .dev
762
+ .get_func("htm_sp_duty", "sp_duty_update")
763
+ .expect("sp_duty not loaded");
764
+ unsafe {
765
+ boost_fn.launch(
766
+ duty_cfg,
767
+ (
768
+ &self.active_mask,
769
+ &self.raw,
770
+ &mut self.active_duty,
771
+ &mut self.overlap_duty,
772
+ &mut self.boost,
773
+ 0.0f32,
774
+ 1.0f32,
775
+ self.boost_strength,
776
+ mean,
777
+ 1u32,
778
+ n as u32,
779
+ ),
780
+ )?;
781
+ }
782
+ }
783
+
784
+ // 6. D2H active_mask and convert to sorted index list.
785
+ self.dev
786
+ .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
787
+ let mut active: Vec<u32> = Vec::with_capacity(k);
788
+ for (i, &b) in self.host_mask.iter().enumerate() {
789
+ if b != 0 {
790
+ active.push(i as u32);
791
+ }
792
+ }
793
+ debug_assert_eq!(active.len(), k, "SP must emit exactly k winners");
794
+ Ok(active)
795
+ }
796
+ }
overlay/htm_rust/src/gpu/tm_gpu.rs CHANGED
@@ -1,460 +1,460 @@
 
1
+ //! GPU Temporal Memory.
2
+ //!
3
+ //! Flat device storage. Pre-allocated segment slab:
4
+ //! n_cells = n_columns * cells_per_column
5
+ //! n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL
6
+ //! n_synapses_max = n_segments_max * MAX_SYN_PER_SEGMENT
7
+ //!
8
+ //! Sizing example (CPU parity targets relaxed on GPU to keep memory
+ //! tractable). The figures below use an upper-bound configuration; the
+ //! shipped constants are smaller (see MAX_SEGMENTS_PER_CELL = 4 and
+ //! MAX_SYN_PER_SEGMENT = 20 further down):
+ //! MAX_SEGMENTS_PER_CELL = 16
+ //! MAX_SYN_PER_SEGMENT = 32
+ //!
+ //! At n_cells = 65536:
+ //! n_segments_max = 1_048_576 (~1M)
+ //! n_synapses_max = 33_554_432 (~33M)
+ //! Storage:
+ //! syn_presyn : u32 × 33M = 128 MB
+ //! syn_perm : i16 × 33M = 64 MB
+ //! seg_cell : u32 × 1M = 4 MB
+ //! seg_syn_n : u32 × 1M = 4 MB
+ //! misc bitsets etc ~ <1 MB
+ //! -------------------------------
+ //! Total per region ~200 MB
+ //!
+ //! Permanences are stored as i16 scaled by 32767 (→ [0, 32767] represents
+ //! [0.0, 1.0]). inc/dec are provided pre-scaled.
26
+
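+ // Editorial worked example of the i16 encoding above (see perm_f32_to_i16
+ // below): 0.21 → (0.21 * 32767.0).round() = 6881, and the connected
+ // threshold 0.50 → 16384, just above mid-scale.
+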
27
+ use std::sync::Arc;
28
+
29
+ use cudarc::driver::{CudaDevice, CudaSlice, DriverError, DeviceRepr, LaunchAsync, LaunchConfig};
30
+ use cudarc::nvrtc::Ptx;
31
+
32
+ /// Packed config struct passed by value to TM kernels to stay under
33
+ /// cudarc's 12-tuple launch limit. Layout must match the C-side
34
+ /// `TmConfig` struct declared in each kernel.
35
+ #[repr(C)]
36
+ #[derive(Clone, Copy)]
37
+ pub struct TmConfig {
38
+ pub activation_threshold: u32,
39
+ pub learning_threshold: u32,
40
+ pub cells_per_column: u32,
41
+ pub synapses_per_segment: u32,
42
+ pub n_segments: u32,
43
+ pub n_cells: u32,
44
+ pub max_segments_per_cell: u32,
45
+ pub max_new_synapses: u32,
46
+ pub conn_thr_i16: i32, // i16 widened to i32 for alignment
47
+ pub perm_inc_i16: i32,
48
+ pub perm_dec_i16: i32,
49
+ pub predicted_seg_dec_i16: i32,
50
+ pub initial_perm_i16: i32,
51
+ pub iter_seed: u32,
52
+ pub n_cols: u32,
53
+ pub bits_words: u32,
54
+ }
55
+
56
+ unsafe impl DeviceRepr for TmConfig {}
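+
+ // Editorial sketch: a compile-time layout guard one could add for the
+ // contract with the C-side struct (16 four-byte fields under repr(C)):
+ //     const _: () = assert!(std::mem::size_of::<TmConfig>() == 64);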
57
+
58
+ // Embedded PTX.
59
+ const PTX_TM_PREDICT: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_predict.ptx"));
60
+ const PTX_TM_ACTIVATE: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_activate.ptx"));
61
+ const PTX_TM_LEARN: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_learn.ptx"));
62
+ const PTX_TM_PUNISH: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_punish.ptx"));
63
+ const PTX_TM_GROW: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_grow.ptx"));
64
+ const PTX_TM_ANOMALY: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_anomaly.ptx"));
65
+ const PTX_TM_RESET: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_reset.ptx"));
66
+
67
+ /// Capacity trade-offs for 6 GB VRAM (RTX 3060) shared with the model:
68
+ /// n_cells = 2048 × 32 = 65_536
+ /// n_segments_max = n_cells × MAX_SEGMENTS_PER_CELL
+ /// n_synapses_max = n_segments_max × MAX_SYN_PER_SEGMENT
71
+ ///
72
+ /// At 4/20 these are 262_144 segments and ~5.2M synapses (~50 MB per region).
73
+ /// The training loop runs with `reset_each_forward=True`, so segment counts
74
+ /// per window stay well below 32K (typical: ~n_cols new segs per step until
75
+ /// the first matching segment is reused; in a 2048-step window that plateaus
76
+ /// around ~5K total live segments). The 262K ceiling is generous headroom.
77
+ pub const MAX_SEGMENTS_PER_CELL: usize = 4;
78
+ pub const MAX_SYN_PER_SEGMENT: usize = 20;
79
+
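These trade-offs are straightforward to recompute; a plain-Python sketch reproducing the arithmetic in the comments above (doc numbers, not measured allocations):

    def tm_budget_mb(n_cells: int, segs_per_cell: int, syn_per_seg: int) -> float:
        """Per-region device bytes for the four big slabs, in MB."""
        n_segments = n_cells * segs_per_cell
        n_synapses = n_segments * syn_per_seg
        nbytes = (4 * n_synapses    # syn_presyn  : u32
                  + 2 * n_synapses  # syn_perm    : i16
                  + 4 * n_segments  # seg_cell_id : u32
                  + 4 * n_segments) # seg_syn_cnt : u32
        return nbytes / 2**20

    print(tm_budget_mb(65_536, 16, 32))  # 200.0 MB (module-doc sizing example)
    print(tm_budget_mb(65_536, 4, 20))   # 32.0 MB for the four slabs at 4/20
    # The "~50 MB per region" note above also covers the remaining buffers
    # and headroom beyond these four slabs.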
80
+ const PERM_SCALE: f32 = 32767.0;
81
+
82
+ fn perm_f32_to_i16(x: f32) -> i16 {
83
+ let clamped = x.clamp(0.0, 1.0);
84
+ (clamped * PERM_SCALE).round() as i16
85
+ }
86
+
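The same quantization worked through in Python, using the two constants that appear below (0.50 connected threshold, 0.21 initial permanence); a sketch mirroring the Rust:

    PERM_SCALE = 32767.0

    def perm_f32_to_i16(x: float) -> int:
        # Clamp to [0, 1], scale, round half away from zero (what Rust's
        # .round() does; for non-negative inputs, +0.5-and-truncate matches).
        clamped = min(max(x, 0.0), 1.0)
        return int(clamped * PERM_SCALE + 0.5)

    assert perm_f32_to_i16(0.50) == 16384  # connected threshold
    assert perm_f32_to_i16(0.21) == 6881   # initial permanence
    # Quantization step is 1/32767 ~ 3.05e-5, negligible next to inc/dec = 0.10.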
87
+ pub struct TemporalMemoryGpu {
88
+ dev: Arc<CudaDevice>,
89
+
90
+ // Config mirror
91
+ pub n_columns: usize,
92
+ pub cells_per_column: usize,
93
+ pub activation_threshold: u32,
94
+ pub learning_threshold: u32,
95
+ pub initial_perm_i16: i16,
96
+ pub conn_thr_i16: i16,
97
+ pub perm_inc_i16: i16,
98
+ pub perm_dec_i16: i16,
99
+ pub predicted_seg_dec_i16: i16,
100
+ pub max_new_synapse_count: u32,
101
+
102
+ // Sizes
103
+ pub n_cells: usize,
104
+ pub n_segments_max: usize,
105
+ pub bits_words: usize, // n_cells / 32
106
+
107
+ // Persistent device buffers
108
+ seg_cell_id: CudaSlice<u32>,
109
+ seg_syn_count: CudaSlice<u32>,
110
+ syn_presyn: CudaSlice<u32>,
111
+ syn_perm: CudaSlice<i16>,
112
+ cell_seg_count: CudaSlice<u32>,
113
+
114
+ cell_active_bits: CudaSlice<u32>,
115
+ cell_winner_bits: CudaSlice<u32>,
116
+ cell_predictive_bits: CudaSlice<u32>,
117
+ prev_active_bits: CudaSlice<u32>,
118
+ prev_winner_bits: CudaSlice<u32>,
119
+
120
+ col_predicted: CudaSlice<u8>,
121
+ seg_num_active_conn: CudaSlice<u32>,
122
+ seg_num_active_pot: CudaSlice<u32>,
123
+ unpredicted_count: CudaSlice<u32>,
124
+ burst_cols_flat: CudaSlice<u32>,
125
+ burst_cols_count: CudaSlice<u32>,
126
+ col_best_match: CudaSlice<u32>,
127
+
128
+ iter_counter: u32,
129
+ }
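The `*_bits` fields are packed cell bitsets, `bits_words = n_cells / 32` words each. A sketch of the indexing this layout implies, assuming the conventional bit `i` of word `w` = cell `32*w + i` (the kernels' actual bit order is defined in the .cu files, not shown here):

    def cell_to_word_bit(cell: int) -> tuple[int, int]:
        return cell >> 5, cell & 31  # (u32 word index, bit within word)

    def set_cell(bits: list[int], cell: int) -> None:
        w, b = cell_to_word_bit(cell)
        bits[w] |= 1 << b

    def is_set(bits: list[int], cell: int) -> bool:
        w, b = cell_to_word_bit(cell)
        return (bits[w] >> b) & 1 == 1

    bits = [0] * (65_536 // 32)  # one region's cell_active_bits
    set_cell(bits, 1234)
    assert is_set(bits, 1234) and not is_set(bits, 1235)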
130
+
131
+ impl TemporalMemoryGpu {
132
+ pub fn new(
133
+ dev: Arc<CudaDevice>,
134
+ n_columns: usize,
135
+ cells_per_column: usize,
136
+ ) -> Result<Self, DriverError> {
137
+ let n_cells = n_columns * cells_per_column;
138
+ assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
139
+ let n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL;
140
+ let bits_words = n_cells / 32;
141
+
142
+ // Numenta defaults.
143
+ let activation_threshold = 15u32;
144
+ let learning_threshold = 13u32;
145
+ let initial_perm_i16 = perm_f32_to_i16(0.21);
146
+ let conn_thr_i16 = perm_f32_to_i16(0.50);
147
+ let perm_inc_i16 = perm_f32_to_i16(0.10);
148
+ let perm_dec_i16 = perm_f32_to_i16(0.10);
149
+ let predicted_seg_dec_i16 = perm_f32_to_i16(0.10);
150
+ let max_new_synapse_count = 20u32;
151
+
152
+ // Allocate buffers.
153
+ let seg_cell_id_host: Vec<u32> = vec![u32::MAX; n_segments_max];
154
+ let seg_cell_id = dev.htod_sync_copy(&seg_cell_id_host)?;
155
+ let seg_syn_count = dev.alloc_zeros::<u32>(n_segments_max)?;
156
+ let syn_presyn = dev.alloc_zeros::<u32>(n_segments_max * MAX_SYN_PER_SEGMENT)?;
157
+ let syn_perm = dev.alloc_zeros::<i16>(n_segments_max * MAX_SYN_PER_SEGMENT)?;
158
+ let cell_seg_count = dev.alloc_zeros::<u32>(n_cells)?;
159
+
160
+ let cell_active_bits = dev.alloc_zeros::<u32>(bits_words)?;
161
+ let cell_winner_bits = dev.alloc_zeros::<u32>(bits_words)?;
162
+ let cell_predictive_bits = dev.alloc_zeros::<u32>(bits_words)?;
163
+ let prev_active_bits = dev.alloc_zeros::<u32>(bits_words)?;
164
+ let prev_winner_bits = dev.alloc_zeros::<u32>(bits_words)?;
165
+
166
+ let col_predicted = dev.alloc_zeros::<u8>(n_columns)?;
167
+ let seg_num_active_conn = dev.alloc_zeros::<u32>(n_segments_max)?;
168
+ let seg_num_active_pot = dev.alloc_zeros::<u32>(n_segments_max)?;
169
+ let unpredicted_count = dev.alloc_zeros::<u32>(1)?;
170
+ // Bursting columns for one step bounded by n_columns.
171
+ let burst_cols_flat = dev.alloc_zeros::<u32>(n_columns)?;
172
+ let burst_cols_count = dev.alloc_zeros::<u32>(1)?;
173
+ let col_best_match = dev.alloc_zeros::<u32>(n_columns)?;
174
+
175
+ // Load PTX modules.
176
+ let modules = [
177
+ ("htm_tm_predict", PTX_TM_PREDICT, "tm_predict"),
178
+ ("htm_tm_activate", PTX_TM_ACTIVATE, "tm_activate"),
179
+ ("htm_tm_learn", PTX_TM_LEARN, "tm_learn_reinforce"),
180
+ ("htm_tm_punish", PTX_TM_PUNISH, "tm_punish"),
181
+ ("htm_tm_grow", PTX_TM_GROW, "tm_grow"),
182
+ ("htm_tm_anomaly", PTX_TM_ANOMALY, "tm_anomaly"),
183
+ ("htm_tm_reset", PTX_TM_RESET, "tm_reset_step"),
184
+ ];
185
+ for (modname, ptx, fnname) in modules {
186
+ if dev.get_func(modname, fnname).is_none() {
187
+ dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
188
+ }
189
+ }
190
+
191
+ Ok(Self {
192
+ dev,
193
+ n_columns,
194
+ cells_per_column,
195
+ activation_threshold,
196
+ learning_threshold,
197
+ initial_perm_i16,
198
+ conn_thr_i16,
199
+ perm_inc_i16,
200
+ perm_dec_i16,
201
+ predicted_seg_dec_i16,
202
+ max_new_synapse_count,
203
+ n_cells,
204
+ n_segments_max,
205
+ bits_words,
206
+ seg_cell_id,
207
+ seg_syn_count,
208
+ syn_presyn,
209
+ syn_perm,
210
+ cell_seg_count,
211
+ cell_active_bits,
212
+ cell_winner_bits,
213
+ cell_predictive_bits,
214
+ prev_active_bits,
215
+ prev_winner_bits,
216
+ col_predicted,
217
+ seg_num_active_conn,
218
+ seg_num_active_pot,
219
+ unpredicted_count,
220
+ burst_cols_flat,
221
+ burst_cols_count,
222
+ col_best_match,
223
+ iter_counter: 0,
224
+ })
225
+ }
226
+
227
+ // --- Fused-path accessors ---
228
+ pub fn seg_cell_id_accessor(&self) -> &CudaSlice<u32> { &self.seg_cell_id }
229
+ pub fn seg_syn_count_accessor(&self) -> &CudaSlice<u32> { &self.seg_syn_count }
230
+ pub fn syn_presyn_accessor(&self) -> &CudaSlice<u32> { &self.syn_presyn }
231
+ pub fn syn_perm_accessor(&self) -> &CudaSlice<i16> { &self.syn_perm }
232
+ pub fn cell_seg_count_accessor(&self) -> &CudaSlice<u32> { &self.cell_seg_count }
233
+
234
+ /// Hard reset — clear everything (predictive + active + segments).
235
+ pub fn reset(&mut self) -> Result<(), DriverError> {
236
+ // Restore "unused" sentinel in seg_cell_id.
237
+ let unused_host: Vec<u32> = vec![u32::MAX; self.n_segments_max];
238
+ self.dev.htod_sync_copy_into(&unused_host, &mut self.seg_cell_id)?;
239
+ self.dev.memset_zeros(&mut self.seg_syn_count)?;
240
+ self.dev.memset_zeros(&mut self.cell_seg_count)?;
241
+ self.dev.memset_zeros(&mut self.cell_active_bits)?;
242
+ self.dev.memset_zeros(&mut self.cell_winner_bits)?;
243
+ self.dev.memset_zeros(&mut self.cell_predictive_bits)?;
244
+ self.dev.memset_zeros(&mut self.prev_active_bits)?;
245
+ self.dev.memset_zeros(&mut self.prev_winner_bits)?;
246
+ self.dev.memset_zeros(&mut self.col_best_match)?;
247
+ self.iter_counter = 0;
248
+ Ok(())
249
+ }
250
+
251
+ fn build_cfg(&self) -> TmConfig {
252
+ TmConfig {
253
+ activation_threshold: self.activation_threshold,
254
+ learning_threshold: self.learning_threshold,
255
+ cells_per_column: self.cells_per_column as u32,
256
+ synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
257
+ n_segments: self.n_segments_max as u32,
258
+ n_cells: self.n_cells as u32,
259
+ max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
260
+ max_new_synapses: self.max_new_synapse_count,
261
+ conn_thr_i16: self.conn_thr_i16 as i32,
262
+ perm_inc_i16: self.perm_inc_i16 as i32,
263
+ perm_dec_i16: self.perm_dec_i16 as i32,
264
+ predicted_seg_dec_i16: self.predicted_seg_dec_i16 as i32,
265
+ initial_perm_i16: self.initial_perm_i16 as i32,
266
+ iter_seed: self.iter_counter,
267
+ n_cols: self.n_columns as u32,
268
+ bits_words: self.bits_words as u32,
269
+ }
270
+ }
271
+
272
+ /// Run one TM step on the GPU. Takes the SP active-column mask (u8, already
273
+ /// on device) and writes `anomaly_out[t_slot]`.
274
+ pub fn step(
275
+ &mut self,
276
+ sp_active_mask: &CudaSlice<u8>,
277
+ anomaly_out: &mut CudaSlice<f32>,
278
+ t_slot: u32,
279
+ learn: bool,
280
+ ) -> Result<(), DriverError> {
281
+ let n_cells = self.n_cells;
282
+ let n_cols = self.n_columns;
283
+
284
+ let predict_fn = self.dev.get_func("htm_tm_predict", "tm_predict").unwrap();
285
+ let activate_fn = self.dev.get_func("htm_tm_activate", "tm_activate").unwrap();
286
+ let learn_fn = self.dev.get_func("htm_tm_learn", "tm_learn_reinforce").unwrap();
287
+ let punish_fn = self.dev.get_func("htm_tm_punish", "tm_punish").unwrap();
288
+ let grow_fn = self.dev.get_func("htm_tm_grow", "tm_grow").unwrap();
289
+ let anom_fn = self.dev.get_func("htm_tm_anomaly", "tm_anomaly").unwrap();
290
+ let reset_fn = self.dev.get_func("htm_tm_reset", "tm_reset_step").unwrap();
291
+
292
+ self.iter_counter = self.iter_counter.wrapping_add(1);
293
+ let cfg_val = self.build_cfg();
294
+
295
+ // 0. Per-step reset.
296
+ let reset_words = self.bits_words.max(n_cols);
297
+ let reset_cfg = LaunchConfig {
298
+ grid_dim: (((reset_words + 255) / 256) as u32, 1, 1),
299
+ block_dim: (256, 1, 1),
300
+ shared_mem_bytes: 0,
301
+ };
302
+ unsafe {
303
+ reset_fn.clone().launch(
304
+ reset_cfg,
305
+ (
306
+ &mut self.cell_active_bits,
307
+ &mut self.cell_winner_bits,
308
+ &mut self.cell_predictive_bits,
309
+ &mut self.prev_active_bits,
310
+ &mut self.prev_winner_bits,
311
+ &mut self.col_predicted,
312
+ &mut self.unpredicted_count,
313
+ &mut self.burst_cols_count,
314
+ &mut self.col_best_match,
315
+ self.bits_words as u32,
316
+ n_cols as u32,
317
+ ),
318
+ )?;
319
+ }
320
+
321
+ // 1. Predict (grid = n_cells; each block iterates its cell's segments).
322
+ let predict_cfg = LaunchConfig {
323
+ grid_dim: (n_cells as u32, 1, 1),
324
+ block_dim: (32, 1, 1),
325
+ shared_mem_bytes: 0,
326
+ };
327
+ unsafe {
328
+ predict_fn.clone().launch(
329
+ predict_cfg,
330
+ (
331
+ &self.seg_cell_id,
332
+ &self.seg_syn_count,
333
+ &self.syn_presyn,
334
+ &self.syn_perm,
335
+ &self.prev_active_bits,
336
+ &mut self.cell_predictive_bits,
337
+ &mut self.col_predicted,
338
+ &mut self.seg_num_active_conn,
339
+ &mut self.seg_num_active_pot,
340
+ &mut self.col_best_match,
341
+ &self.cell_seg_count,
342
+ cfg_val,
343
+ ),
344
+ )?;
345
+ }
346
+
347
+ // 2. Activate.
348
+ let activate_cfg = LaunchConfig {
349
+ grid_dim: (((n_cols + 255) / 256) as u32, 1, 1),
350
+ block_dim: (256, 1, 1),
351
+ shared_mem_bytes: 0,
352
+ };
353
+ unsafe {
354
+ activate_fn.clone().launch(
355
+ activate_cfg,
356
+ (
357
+ sp_active_mask,
358
+ &self.col_predicted,
359
+ &self.cell_predictive_bits,
360
+ &mut self.cell_active_bits,
361
+ &mut self.cell_winner_bits,
362
+ &mut self.unpredicted_count,
363
+ &mut self.burst_cols_flat,
364
+ &mut self.burst_cols_count,
365
+ cfg_val,
366
+ ),
367
+ )?;
368
+ }
369
+
370
+ // 3. Anomaly.
371
+ let anom_cfg = LaunchConfig {
372
+ grid_dim: (1, 1, 1),
373
+ block_dim: (256, 1, 1),
374
+ shared_mem_bytes: 0,
375
+ };
376
+ unsafe {
377
+ anom_fn.clone().launch(
378
+ anom_cfg,
379
+ (
380
+ sp_active_mask,
381
+ &self.unpredicted_count,
382
+ anomaly_out,
383
+ t_slot,
384
+ n_cols as u32,
385
+ ),
386
+ )?;
387
+ }
388
+
389
+ if learn {
390
+ // 4. Reinforce (grid = n_cells).
391
+ let learn_cfg = LaunchConfig {
392
+ grid_dim: (n_cells as u32, 1, 1),
393
+ block_dim: (32, 1, 1),
394
+ shared_mem_bytes: 0,
395
+ };
396
+ unsafe {
397
+ learn_fn.clone().launch(
398
+ learn_cfg,
399
+ (
400
+ &self.seg_cell_id,
401
+ &self.seg_syn_count,
402
+ &self.syn_presyn,
403
+ &mut self.syn_perm,
404
+ &self.seg_num_active_conn,
405
+ &self.prev_active_bits,
406
+ sp_active_mask,
407
+ &self.col_predicted,
408
+ &self.cell_seg_count,
409
+ cfg_val,
410
+ ),
411
+ )?;
412
+ }
413
+
414
+ // 5. Punish.
415
+ unsafe {
416
+ punish_fn.clone().launch(
417
+ learn_cfg,
418
+ (
419
+ &self.seg_cell_id,
420
+ &self.seg_syn_count,
421
+ &self.syn_presyn,
422
+ &mut self.syn_perm,
423
+ &self.seg_num_active_pot,
424
+ &self.prev_active_bits,
425
+ sp_active_mask,
426
+ &self.cell_seg_count,
427
+ cfg_val,
428
+ ),
429
+ )?;
430
+ }
431
+
432
+ // 6. Grow.
433
+ let grow_cfg = LaunchConfig {
434
+ grid_dim: (n_cols as u32, 1, 1),
435
+ block_dim: (32, 1, 1),
436
+ shared_mem_bytes: 0,
437
+ };
438
+ unsafe {
439
+ grow_fn.clone().launch(
440
+ grow_cfg,
441
+ (
442
+ &mut self.seg_cell_id,
443
+ &mut self.seg_syn_count,
444
+ &mut self.syn_presyn,
445
+ &mut self.syn_perm,
446
+ &mut self.cell_seg_count,
447
+ &self.burst_cols_flat,
448
+ &self.burst_cols_count,
449
+ &self.prev_winner_bits,
450
+ &self.prev_active_bits,
451
+ &self.col_best_match,
452
+ cfg_val,
453
+ ),
454
+ )?;
455
+ }
456
+ }
457
+
458
+ Ok(())
459
+ }
460
+ }
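The anomaly kernel's source lives in tm_anomaly.cu elsewhere in this commit; from its launch arguments (active-column mask, unpredicted count, output slot) it plausibly computes the standard HTM anomaly score. A reference sketch under that assumption, not a quote of the kernel:

    import numpy as np

    def anomaly_score(sp_active_mask: np.ndarray, unpredicted_count: int) -> float:
        # Standard HTM anomaly: fraction of active columns that no cell predicted.
        n_active = int(np.count_nonzero(sp_active_mask))
        return unpredicted_count / n_active if n_active else 0.0

    mask = np.zeros(2048, dtype=np.uint8)
    mask[:40] = 1                       # 40 SP-active columns this step
    print(anomaly_score(mask, 10))      # 0.25: a quarter of them burst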
overlay/htm_rust/uv.lock CHANGED
@@ -1,8 +1,8 @@
1
- version = 1
2
- revision = 3
3
- requires-python = ">=3.11"
4
-
5
- [[package]]
6
- name = "htm-rust"
7
- version = "0.1.0"
8
- source = { editable = "." }
 
1
+ version = 1
2
+ revision = 3
3
+ requires-python = ">=3.11"
4
+
5
+ [[package]]
6
+ name = "htm-rust"
7
+ version = "0.1.0"
8
+ source = { editable = "." }
overlay/hydra/config.py CHANGED
@@ -110,8 +110,8 @@ class PostSemClawConfig:
110
  gdn_layers: tuple[int, ...] = field(default_factory=_parse_gdn_layers_env)
111
 
112
  # Label smoothing + Z-loss
113
- label_smoothing: float = field(default_factory=lambda: float(os.environ.get("HYDRA_LABEL_SMOOTHING", "0.0")))
114
- z_loss_weight: float = field(default_factory=lambda: float(os.environ.get("HYDRA_Z_LOSS_WEIGHT", "1e-4")))
115
 
116
 
117
  # ---------------------------------------------------------------------------
 
110
  gdn_layers: tuple[int, ...] = field(default_factory=_parse_gdn_layers_env)
111
 
112
  # Label smoothing + Z-loss
113
+ label_smoothing: float = 0.0 # disabled: any smoothing hurts in 5-min budget
114
+ z_loss_weight: float = 1e-4
115
 
116
 
117
  # ---------------------------------------------------------------------------
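For reference, the usual PaLM-style form of the auxiliary z-loss that a `z_loss_weight` like this scales; a sketch of the standard regularizer, not a quote of this repo's train loop:

    import torch

    def z_loss(logits: torch.Tensor, weight: float = 1e-4) -> torch.Tensor:
        # Penalize the squared log-partition so logit magnitudes stay bounded.
        z = torch.logsumexp(logits.float(), dim=-1)  # (B, T)
        return weight * (z ** 2).mean()

    logits = torch.randn(2, 8, 50_257)
    total = torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        torch.randint(0, 50_257, (16,)),
    ) + z_loss(logits)
    print(total)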
overlay/hydra/engram.py CHANGED
@@ -1,23 +1,93 @@
1
- """GPU Engram β€” Top-k Sparse Hopfield retrieval with optional Cantor/SDR nerve constraint."""
2
 
3
- from __future__ import annotations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  import torch
8
  import torch.nn as nn
9
 
 
 
 
10
 
11
- _ENGRAM_TOPK = int(os.environ.get("HYDRA_ENGRAM_TOPK", "64"))
 
12
 
 
 
 
13
 
14
- class GPUEngram(nn.Module):
15
- """GPU Engram: Top-k Sparse Hopfield retrieval.
 
 
16
 
17
- Default `routing_mode=flat` preserves the existing full-memory top-k path.
18
- `cantor_sdr` constrains candidates to the current Cantor leaf shard and SDR
19
- active offsets. `auto` only uses that local path when it is cheaper than the
20
- full score matrix (`K * d_model < n_columns`).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  """
22
 
23
  def __init__(
@@ -31,15 +101,20 @@ class GPUEngram(nn.Module):
31
  self.n_columns = n_columns
32
  self.max_ngram = max_ngram
33
  self.hebbian_boost = hebbian_boost
 
34
  self.memory = nn.Parameter(torch.randn(n_columns, d_model) * 0.01)
35
  self.gate = nn.Linear(d_model, 1, bias=True)
36
- nn.init.constant_(self.gate.bias, 0.0)
37
- self.topk_k = min(_ENGRAM_TOPK, n_columns)
38
  self.primes = [2654435761, 2246822519, 3266489917]
39
  self.hebbian_lr = 0.01
40
- self.routing_mode = os.environ.get("HYDRA_ENGRAM_ROUTING", "auto").lower()
 
 
 
41
 
42
  def _hash(self, token_ids: torch.Tensor) -> torch.Tensor:
 
43
  B, T = token_ids.shape
44
  h = token_ids * self.primes[0]
45
  if T > 1:
@@ -52,103 +127,44 @@ class GPUEngram(nn.Module):
52
  h = h ^ (shifted2 * self.primes[2])
53
  return h % self.n_columns
54
 
55
- def _validate_active_indices(self, sdr_active_indices: torch.Tensor, x: torch.Tensor) -> None:
56
- if not torch.is_floating_point(sdr_active_indices) and sdr_active_indices.dtype != torch.bool:
57
- pass
58
- else:
59
- raise ValueError("Engram Cantor/SDR routing expects compact active indices, not a dense SDR mask")
60
- if sdr_active_indices.dim() not in (2, 3):
61
- raise ValueError("compact active indices must have shape (B,T,K) or (B*T,K)")
62
- # Dense SDR masks arrive with K ~= n_bits; compact buffers are small
63
- # (retina target_active or RealityBridge l0_k). Refuse obviously dense
64
- # masks so forced cantor_sdr cannot silently route 0/1 values as offsets.
65
- if sdr_active_indices.shape[-1] > 1024 or sdr_active_indices.shape[-1] > self.n_columns:
66
- raise ValueError("Engram Cantor/SDR routing expects compact active indices, not a dense SDR mask")
67
-
68
- def _cantor_sdr_candidates(
69
- self,
70
- sdr_active_indices: torch.Tensor,
71
- cantor_leaf_ids: torch.Tensor,
72
- n_leaves: int,
73
- ) -> torch.Tensor:
74
- """Map SDR active offsets into each Cantor leaf's Engram column shard."""
75
- self._validate_active_indices(sdr_active_indices, cantor_leaf_ids)
76
- if sdr_active_indices.dim() == 2:
77
- B, T = cantor_leaf_ids.shape
78
- sdr_active_indices = sdr_active_indices.view(B, T, -1)
79
- sdr = sdr_active_indices.to(device=cantor_leaf_ids.device, dtype=torch.long)
80
- leaves = cantor_leaf_ids.to(dtype=torch.long).clamp(min=0, max=max(0, n_leaves - 1))
81
- cols_per_leaf = max(1, self.n_columns // max(1, n_leaves))
82
- offsets = sdr.remainder(cols_per_leaf)
83
- base = leaves.unsqueeze(-1) * cols_per_leaf
84
- return (base + offsets).clamp(max=self.n_columns - 1)
85
-
86
- def _flat_retrieve(self, x: torch.Tensor) -> torch.Tensor:
87
- scores = x @ self.memory.T
88
- topk_vals, topk_idx = scores.topk(self.topk_k, dim=-1)
89
- topk_w = torch.softmax(topk_vals, dim=-1)
90
- selected_mem = self.memory[topk_idx]
91
- return torch.einsum('btk,btkd->btd', topk_w, selected_mem)
92
 
93
- def _cantor_sdr_retrieve(
94
- self,
95
- x: torch.Tensor,
96
- sdr_active_indices: torch.Tensor,
97
- cantor_leaf_ids: torch.Tensor,
98
- cantor_n_leaves: int,
99
- ) -> torch.Tensor:
100
- candidates = self._cantor_sdr_candidates(
101
- sdr_active_indices,
102
- cantor_leaf_ids,
103
- n_leaves=cantor_n_leaves,
104
- )
105
- cand_mem = self.memory[candidates]
106
- scores = torch.einsum('btd,btkd->btk', x, cand_mem)
107
- k = min(self.topk_k, scores.shape[-1])
108
- topk_vals, local_idx = scores.topk(k, dim=-1)
109
- topk_w = torch.softmax(topk_vals, dim=-1)
110
- global_idx = candidates.gather(-1, local_idx)
111
- selected_mem = self.memory[global_idx]
112
- return torch.einsum('btk,btkd->btd', topk_w, selected_mem)
113
 
114
- def forward(
115
- self,
116
- x: torch.Tensor,
117
- token_ids: torch.Tensor,
118
- sdr_active_indices: torch.Tensor | None = None,
119
- cantor_leaf_ids: torch.Tensor | None = None,
120
- cantor_n_leaves: int | None = None,
121
- ):
122
- B, T, D = x.shape
123
- mode = self.routing_mode
124
- use_cantor = (
125
- mode in {"cantor_sdr", "auto"}
126
- and sdr_active_indices is not None
127
- and cantor_leaf_ids is not None
128
- and cantor_n_leaves is not None
129
- )
130
- if mode == "auto" and use_cantor:
131
- k_active = sdr_active_indices.shape[-1]
132
- # Compare actual retrieval candidates against the full-memory scan.
133
- # The previous `(k_active * D) < n_columns` check mixed candidate
134
- # count with feature dimension, so d256/k64 fell back to flat
135
- # retrieval even though Cantor/SDR scores only 64 candidates vs
136
- # 8k-16k memory columns. That kept required subsystems active but
137
- # spent tens of billions of extra MACs per forward.
138
- use_cantor = k_active < self.n_columns
139
-
140
- if use_cantor and mode in {"cantor_sdr", "auto"}:
141
- retrieved = self._cantor_sdr_retrieve(x, sdr_active_indices, cantor_leaf_ids, cantor_n_leaves)
142
- else:
143
- retrieved = self._flat_retrieve(x)
144
-
145
- alpha = torch.sigmoid(self.gate(x))
146
 
 
147
  if self.training and self.hebbian_boost:
148
  with torch.no_grad():
 
149
  indices = self._hash(token_ids)
150
- flat_idx = indices.reshape(-1)
151
- flat_x = x.detach().reshape(-1, D)
152
  mem_dtype = self.memory.data.dtype
153
  updates = (
154
  self.hebbian_lr * flat_x
@@ -156,5 +172,6 @@ class GPUEngram(nn.Module):
156
  ).to(mem_dtype)
157
  self.memory.data.index_add_(0, flat_idx, updates)
158
 
 
159
  hit_rate = (alpha.detach() > 0.1).float().mean()
160
  return x + alpha * retrieved, hit_rate
 
1
+ """GPU Engram β€” Sparse Modern Hopfield retrieval path.
2
 
3
+ ## What changed (scatter-gather → Hopfield matmul)
4
+
5
+ The original forward used `self.memory[indices]` (scatter-gather), which misses
6
+ L2 cache at n_columns > 4096 and creates a hard tps ceiling.
7
+
8
+ The replacement uses:
9
+ scores = x @ self.memory.T # (B, T, n_columns) — coalesced matmul
10
+ weights = entmax15(scores, dim=-1) # sparse attention; 95%+ exact zeros
11
+ retrieved = weights @ self.memory # (B, T, d_model) — coalesced matmul
12
+
13
+ Both matmuls are tile-friendly (cuBLAS GEMM), so L2 reuse is high regardless of
14
+ n_columns. Gradient flows through both matmuls so `self.memory` learns via
15
+ autograd in addition to (or instead of) the Hebbian EMA writes.
16
+
17
+ ## Sparsity mechanism
18
+
19
+ alpha-entmax with alpha=1.5 (entmax15) is a sparse attention operator that maps
20
+ logit vectors to distributions where many entries are *exactly* zero (not merely
21
+ small). It generalises softmax (alpha=1) and argmax (alphaβ†’βˆž). At n_columns=1024
22
+ with d_model=64 a random batch typically hits β‰₯95% zero entries β€” the key
23
+ property that keeps bandwidth proportional to *attended* columns, not all columns.
24
+
25
+ Fallback: if `entmax` is not pip-installed, top-k softmax (k=32) is used instead.
26
+ This is chosen at module-import time — NO runtime branching per forward call.
27
+
28
+ ## token_ids argument
29
 
30
+ token_ids is accepted for API compatibility with the rest of the hydra stack
31
+ (train.py, lightning_module.py call `engram(x, token_ids)`). It is NOT used in
32
+ the retrieval path — the Hopfield path computes dense similarity over the whole
33
+ memory bank, which subsumes any hash-based column selection. Documented here to
34
+ prevent confusion.
35
+
36
+ ## Hebbian writes (hebbian_boost=False by default)
37
+
38
+ With Hopfield retrieval, gradient signals reach self.memory through autograd, so
39
+ Hebbian EMA writes are no longer critical. They are preserved as an *optional*
40
+ boost (hebbian_boost=True) for experiments that want both signals. Default is off.
41
+
42
+ ## Checkpoint compatibility
43
+
44
+ `self.memory` shape (n_columns, d_model) is unchanged, so existing .pt / .ckpt
45
+ files load without modification.
46
+ """
47
+
48
+ from __future__ import annotations
49
 
50
  import torch
51
  import torch.nn as nn
52
 
53
+ # ---------------------------------------------------------------------------
54
+ # Sparse-attention backend — chosen ONCE at import time, no runtime branching.
55
+ # ---------------------------------------------------------------------------
56
 
57
+ try:
58
+ from entmax import entmax15 as _entmax15 # type: ignore[import]
59
 
60
+ def _sparse_attention(scores: torch.Tensor) -> torch.Tensor:
61
+ """alpha-entmax (alpha=1.5): truly sparse distribution over last dim."""
62
+ return _entmax15(scores, dim=-1).to(dtype=scores.dtype)
63
 
64
+ _BACKEND = "entmax15"
65
+
66
+ except ImportError: # pragma: no cover — entmax always installed in CI
67
+ _K = 32 # top-k for fallback
68
 
69
+ def _sparse_attention(scores: torch.Tensor) -> torch.Tensor: # type: ignore[misc]
70
+ """Top-k softmax fallback: zero outside the k highest-scoring columns."""
71
+ topk_vals, topk_idx = scores.topk(_K, dim=-1)
72
+ topk_w = torch.softmax(topk_vals, dim=-1)
73
+ weights = torch.zeros_like(scores)
74
+ weights.scatter_(-1, topk_idx, topk_w.to(dtype=weights.dtype))
75
+ return weights
76
+
77
+ _BACKEND = "topk32"
78
+
79
+
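The exact-zero property is cheap to verify when the optional `entmax` package is present; a standalone sketch (the measured fraction varies with the score distribution):

    import torch
    from entmax import entmax15  # optional dependency, as above

    scores = torch.randn(4, 16, 1024)               # (B, T, n_columns)
    weights = entmax15(scores, dim=-1)
    zero_frac = (weights == 0).float().mean().item()
    print(f"exact-zero fraction: {zero_frac:.3f}")  # typically well above 0.9
    # Each row is still a proper distribution:
    assert torch.allclose(weights.sum(-1), torch.ones(4, 16), atol=1e-5)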
80
+ class GPUEngram(nn.Module):
81
+ """GPU Engram: Sparse Modern Hopfield retrieval.
82
+
83
+ Args:
84
+ d_model: Model dimension — must match the surrounding transformer.
85
+ n_columns: Number of memory columns (key-value pairs). Safe at 32,768
86
+ with the matmul path; the old scatter-gather had an L2
87
+ cliff above ~4,096.
88
+ max_ngram: Retained for API compatibility; unused in retrieval path.
89
+ hebbian_boost: If True, also run a Hebbian EMA write on the memory bank
90
+ during training (old behaviour, now optional). Default False.
91
  """
92
 
93
  def __init__(
 
101
  self.n_columns = n_columns
102
  self.max_ngram = max_ngram
103
  self.hebbian_boost = hebbian_boost
104
+ # Shape unchanged from original β€” existing checkpoints load cleanly.
105
  self.memory = nn.Parameter(torch.randn(n_columns, d_model) * 0.01)
106
  self.gate = nn.Linear(d_model, 1, bias=True)
107
+ nn.init.constant_(self.gate.bias, 0.0) # START OPEN
108
+ # Retained for any external code that reads these attrs.
109
  self.primes = [2654435761, 2246822519, 3266489917]
110
  self.hebbian_lr = 0.01
111
+
112
+ # ------------------------------------------------------------------
113
+ # _hash: retained for API/checkpoint compat; unused in forward below.
114
+ # ------------------------------------------------------------------
115
 
116
  def _hash(self, token_ids: torch.Tensor) -> torch.Tensor:
117
+ """N-gram hash β†’ column index (kept for backward-compat; not used in retrieval)."""
118
  B, T = token_ids.shape
119
  h = token_ids * self.primes[0]
120
  if T > 1:
 
127
  h = h ^ (shifted2 * self.primes[2])
128
  return h % self.n_columns
129
 
130
+ # ------------------------------------------------------------------
131
+ # forward
132
+ # ------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ def forward(self, x: torch.Tensor, token_ids: torch.Tensor):
135
+ """Hopfield retrieve + soft gate + residual.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
+ Args:
138
+ x: (B, T, d_model) β€” input activations.
139
+ token_ids: (B, T) — token indices. Accepted for API compatibility;
140
+ NOT used in the retrieval path (see module docstring).
141
+
142
+ Returns:
143
+ (x + alpha * retrieved, hit_rate)
144
+ - x + alpha * retrieved: (B, T, d_model)
145
+ - hit_rate: scalar tensor — fraction of gate values > 0.1
146
+ """
147
+ # ---- 1. Similarity scores (coalesced GEMM) ----------------------
148
+ # scores[b, t, c] = dot(x[b,t], memory[c])
149
+ scores = x @ self.memory.T # (B, T, n_columns)
150
+
151
+ # ---- 2. Sparse attention weights --------------------------------
152
+ # _sparse_attention is fixed at import time (entmax15 or top-k).
153
+ weights = _sparse_attention(scores) # (B, T, n_columns), many exact zeros
154
+
155
+ # ---- 3. Retrieved vector (coalesced GEMM) -----------------------
156
+ retrieved = weights @ self.memory # (B, T, d_model)
157
+
158
+ # ---- 4. Soft gate (unchanged) -----------------------------------
159
+ alpha = torch.sigmoid(self.gate(x)) # (B, T, 1)
 
 
 
 
 
 
 
 
 
160
 
161
+ # ---- 5. Optional Hebbian EMA write ------------------------------
162
  if self.training and self.hebbian_boost:
163
  with torch.no_grad():
164
+ # Reuse the hash-based indices for the write target (sparse update).
165
  indices = self._hash(token_ids)
166
+ flat_idx = indices.reshape(-1) # (B*T,)
167
+ flat_x = x.detach().reshape(-1, x.shape[-1]) # (B*T, d_model)
168
  mem_dtype = self.memory.data.dtype
169
  updates = (
170
  self.hebbian_lr * flat_x
 
172
  ).to(mem_dtype)
173
  self.memory.data.index_add_(0, flat_idx, updates)
174
 
175
+ # ---- 6. Residual + hit_rate -------------------------------------
176
  hit_rate = (alpha.detach() > 0.1).float().mean()
177
  return x + alpha * retrieved, hit_rate
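The top-k fallback is the deleted `_flat_retrieve` rearranged into GEMM form; a self-contained sketch checking the two layouts agree numerically:

    import torch

    torch.manual_seed(0)
    B, T, D, C, K = 2, 4, 8, 64, 32
    x, memory = torch.randn(B, T, D), torch.randn(C, D)

    # Old path: gather the top-k memory rows, then a batched einsum.
    scores = x @ memory.T
    vals, idx = scores.topk(K, dim=-1)
    w = torch.softmax(vals, dim=-1)
    old = torch.einsum('btk,btkd->btd', w, memory[idx])

    # New path (top-k fallback backend): dense sparse weights, then one GEMM.
    dense = torch.zeros_like(scores).scatter_(-1, idx, w)
    new = dense @ memory

    assert torch.allclose(old, new, atol=1e-5)  # same math, GEMM-friendly layout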
overlay/hydra/model.py CHANGED
@@ -469,6 +469,7 @@ class PostSemClawModel(nn.Module):
469
  # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
470
  # dtypes in the same shape group would break lerp_ dtype checks.
471
  self.wte.to(dtype=torch.bfloat16)
 
472
  self.htm_proj.to(dtype=torch.bfloat16)
473
  self.sdr_proj.to(dtype=torch.bfloat16)
474
  self.engram.to(dtype=torch.bfloat16)
 
469
  # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
470
  # dtypes in the same shape group would break lerp_ dtype checks.
471
  self.wte.to(dtype=torch.bfloat16)
472
+ self.blocks.to(dtype=torch.bfloat16)
473
  self.htm_proj.to(dtype=torch.bfloat16)
474
  self.sdr_proj.to(dtype=torch.bfloat16)
475
  self.engram.to(dtype=torch.bfloat16)
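The failure mode the comment describes reproduces in isolation; a sketch assuming, per that comment, that the optimizer's momentum update goes through `Tensor.lerp_`:

    import torch

    buf = torch.zeros(4, 4, dtype=torch.bfloat16)  # buffer from a bf16 param
    grad = torch.randn(4, 4, dtype=torch.float32)  # same-shape fp32 straggler

    try:
        buf.lerp_(grad, 0.05)  # mixed dtypes inside one shape group
    except RuntimeError as err:
        print(err)  # in-place lerp_ refuses the fp32 -> bf16 downcast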
overlay/scripts/autoresearch.py CHANGED
@@ -1,517 +1,517 @@
1
- #!/usr/bin/env python3
2
- """HYDRA Autoresearch Mutation Loop.
3
-
4
- Runs baseline training -> evaluates -> picks ONE mutation at a time ->
5
- trains -> evaluates -> keeps if quality improves AND tps >= floor.
6
- Repeats until all mutations exhausted or Ctrl+C.
7
-
8
- State persisted in .omc/autoresearch_config.json for resume support.
9
-
10
- Usage:
11
- python scripts/autoresearch.py # run full loop
12
- python scripts/autoresearch.py --dry-run # show plan, don't train
13
- python scripts/autoresearch.py --baseline # only run baseline eval
14
- """
15
-
16
- from __future__ import annotations
17
-
18
- import argparse
19
- import json
20
- import math
21
- import os
22
- import re
23
- import signal
24
- import subprocess
25
- import sys
26
- import time
27
- from pathlib import Path
28
-
29
- _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
30
- if _PROJECT_ROOT not in sys.path:
31
- sys.path.insert(0, _PROJECT_ROOT)
32
-
33
- # ---------------------------------------------------------------------------
34
- # Mutation catalog (ordered by expected impact)
35
- # ---------------------------------------------------------------------------
36
-
37
- MUTATIONS = [
38
- # Learning dynamics — env vars verified in hydra/config.py
39
- {"name": "lr_matrix_0.012", "env": "HYDRA_MATRIX_LR=0.012"}, # default 0.12
40
- {"name": "lr_matrix_0.06", "env": "HYDRA_MATRIX_LR=0.06"}, # half default
41
- {"name": "lr_matrix_0.24", "env": "HYDRA_MATRIX_LR=0.24"}, # double default
42
- {"name": "lr_floor_50pct", "env": "HYDRA_LR_MIN_MULT=0.5"}, # default 0.0
43
- {"name": "lr_floor_20pct", "env": "HYDRA_LR_MIN_MULT=0.2"}, # default 0.0
44
- {"name": "embed_lr_0.5", "env": "HYDRA_EMBED_LR=0.5"}, # default 1.0
45
- {"name": "embed_lr_2.0", "env": "HYDRA_EMBED_LR=2.0"}, # default 1.0
46
- {"name": "unembed_lr_0.01", "env": "HYDRA_UNEMBED_LR=0.01"}, # default 0.005
47
- # Architecture — env vars verified in hydra/config.py
48
- {"name": "d_model_384", "env": "HYDRA_D_MODEL=384"}, # default 256
49
- {"name": "d_model_192", "env": "HYDRA_D_MODEL=192"}, # smaller
50
- {"name": "d_state_128", "env": "HYDRA_D_STATE=128"}, # default 64
51
- {"name": "d_state_32", "env": "HYDRA_D_STATE=32"}, # smaller
52
- {"name": "n_layer_6", "env": "HYDRA_N_LAYER=6"}, # default 4
53
- {"name": "n_layer_3", "env": "HYDRA_N_LAYER=3"}, # fewer
54
- {"name": "headdim_16", "env": "HYDRA_HEADDIM=16"}, # default 32 -> more heads
55
- {"name": "headdim_64", "env": "HYDRA_HEADDIM=64"}, # default 32 -> fewer heads
56
- {"name": "expand_3", "env": "HYDRA_EXPAND=3"}, # default 2
57
- {"name": "engram_2048", "env": "HYDRA_ENGRAM_N_COLUMNS=2048"}, # default 1024
58
- {"name": "engram_4096", "env": "HYDRA_ENGRAM_N_COLUMNS=4096"}, # default 1024
59
- {"name": "engram_512", "env": "HYDRA_ENGRAM_N_COLUMNS=512"}, # smaller
60
- # Batch size
61
- {"name": "batch_32k", "env": "HYDRA_TOTAL_BATCH=32768"}, # default 32768 (verify)
62
- {"name": "batch_16k", "env": "HYDRA_TOTAL_BATCH=16384"}, # smaller batch
63
- {"name": "batch_65k", "env": "HYDRA_TOTAL_BATCH=65536"}, # larger batch
64
- # Regularization — env vars verified in hydra/model.py + hydra/config.py
65
- {"name": "dropout_0.05", "env": "HYDRA_DROPOUT=0.05"}, # default 0.2
66
- {"name": "dropout_0.1", "env": "HYDRA_DROPOUT=0.1"}, # default 0.2
67
- {"name": "dropout_0.3", "env": "HYDRA_DROPOUT=0.3"}, # higher
68
- ]
69
-
70
- # ---------------------------------------------------------------------------
71
- # State management
72
- # ---------------------------------------------------------------------------
73
-
74
- STATE_DIR = os.path.join(_PROJECT_ROOT, ".omc")
75
- STATE_FILE = os.path.join(STATE_DIR, "autoresearch_config.json")
76
-
77
- DEFAULT_STATE = {
78
- "baseline_quality": None,
79
- "baseline_tps": None,
80
- "current_gen": 0,
81
- "mutations_tested": [],
82
- "mutations_kept": [],
83
- "tps_floor": 62000,
84
- "time_budget": 600,
85
- "history": [],
86
- }
87
-
88
-
89
- def load_state() -> dict:
90
- """Load state from disk or return default."""
91
- if os.path.exists(STATE_FILE):
92
- with open(STATE_FILE, "r") as f:
93
- state = json.load(f)
94
- # Backfill missing keys from defaults
95
- for k, v in DEFAULT_STATE.items():
96
- if k not in state:
97
- state[k] = v
98
- return state
99
- return dict(DEFAULT_STATE)
100
-
101
-
102
- def save_state(state: dict) -> None:
103
- """Persist state to disk."""
104
- os.makedirs(STATE_DIR, exist_ok=True)
105
- with open(STATE_FILE, "w") as f:
106
- json.dump(state, f, indent=2)
107
-
108
-
109
- # ---------------------------------------------------------------------------
110
- # Training subprocess
111
- # ---------------------------------------------------------------------------
112
-
113
- def build_env(extra_env: str | None = None) -> dict[str, str]:
114
- """Build environment for training subprocess."""
115
- env = os.environ.copy()
116
- # Ensure CUDA paths
117
- ld_paths = ["/usr/lib/wsl/lib", "/usr/local/cuda/lib64"]
118
- existing = env.get("LD_LIBRARY_PATH", "")
119
- for p in ld_paths:
120
- if p not in existing:
121
- existing = p + ":" + existing
122
- env["LD_LIBRARY_PATH"] = existing
123
-
124
- # Apply mutation env var
125
- if extra_env:
126
- key, val = extra_env.split("=", 1)
127
- env[key] = val
128
-
129
- return env
130
-
131
-
132
- def run_training(time_budget: int, extra_env: str | None = None) -> dict | None:
133
- """Run train.py with given time budget and optional env override.
134
-
135
- Returns dict with parsed metrics, or None on failure.
136
- """
137
- env = build_env(extra_env)
138
- env["HYDRA_TIME_BUDGET"] = str(time_budget)
139
-
140
- cmd = [os.path.join(_PROJECT_ROOT, ".venv", "bin", "python"), "-u", "train.py"]
141
-
142
- try:
143
- proc = subprocess.Popen(
144
- cmd,
145
- cwd=_PROJECT_ROOT,
146
- env=env,
147
- stdout=subprocess.PIPE,
148
- stderr=subprocess.STDOUT,
149
- text=True,
150
- bufsize=1,
151
- )
152
- except Exception as e:
153
- print(f" [ERROR] Failed to start training: {e}")
154
- return None
155
-
156
- output_lines: list[str] = []
157
- last_step_line = ""
158
-
159
- try:
160
- for line in proc.stdout:
161
- line = line.rstrip()
162
- output_lines.append(line)
163
- if line.startswith("step="):
164
- last_step_line = line
165
- # Print progress every 50 steps
166
- m = re.search(r"step=(\d+)", line)
167
- if m and int(m.group(1)) % 50 == 0:
168
- tps_m = re.search(r"tps=(\d+)", line)
169
- bpb_m = re.search(r"bpb=([\d.]+)", line)
170
- tps = tps_m.group(1) if tps_m else "?"
171
- bpb = bpb_m.group(1) if bpb_m else "?"
172
- print(f" step={m.group(1)} tps={tps} bpb={bpb}", flush=True)
173
- elif "val_bpb" in line or "factual_english_score" in line:
174
- print(f" {line}", flush=True)
175
- except KeyboardInterrupt:
176
- proc.terminate()
177
- proc.wait()
178
- raise
179
-
180
- proc.wait()
181
- if proc.returncode != 0:
182
- print(f" [ERROR] Training exited with code {proc.returncode}")
183
- # Print last 10 lines for debugging
184
- for line in output_lines[-10:]:
185
- print(f" {line}")
186
- return None
187
-
188
- return _parse_training_output(output_lines)
189
-
190
-
191
- def _parse_training_output(lines: list[str]) -> dict:
192
- """Extract metrics from training output lines."""
193
- metrics: dict[str, float] = {}
194
-
195
- for line in lines:
196
- # Key=value pairs from summary block
197
- for key in ["val_bpb", "training_seconds", "peak_vram_mb", "mfu_percent",
198
- "total_tokens_M", "num_steps", "factual_english_score",
199
- "factual_english_hits"]:
200
- m = re.match(rf"^{key}:\s+([\d.]+)", line.strip())
201
- if m:
202
- metrics[key] = float(m.group(1))
203
-
204
- # TPS from last step line
205
- if line.startswith("step="):
206
- tps_m = re.search(r"tps=(\d+)", line)
207
- if tps_m:
208
- metrics["tps"] = float(tps_m.group(1))
209
-
210
- return metrics
211
-
212
-
213
- # ---------------------------------------------------------------------------
214
- # Eval integration
215
- # ---------------------------------------------------------------------------
216
-
217
- def run_eval_after_training(extra_env: str | None = None) -> dict | None:
218
- """Run eval_quality.py after training. Returns metrics dict or None."""
219
- env = build_env(extra_env)
220
- cmd = [
221
- os.path.join(_PROJECT_ROOT, ".venv", "bin", "python"),
222
- os.path.join(_PROJECT_ROOT, "scripts", "eval_quality.py"),
223
- ]
224
-
225
- try:
226
- result = subprocess.run(
227
- cmd,
228
- cwd=_PROJECT_ROOT,
229
- env=env,
230
- capture_output=True,
231
- text=True,
232
- timeout=120, # 2 min max for eval
233
- )
234
- except subprocess.TimeoutExpired:
235
- print(" [ERROR] Eval timed out (120s)")
236
- return None
237
- except Exception as e:
238
- print(f" [ERROR] Eval failed: {e}")
239
- return None
240
-
241
- if result.returncode != 0:
242
- print(f" [ERROR] Eval exited with code {result.returncode}")
243
- for line in result.stdout.split("\n")[-10:]:
244
- print(f" {line}")
245
- for line in result.stderr.split("\n")[-5:]:
246
- print(f" {line}")
247
- return None
248
-
249
- # Parse key=value output
250
- metrics = {}
251
- for line in result.stdout.split("\n"):
252
- line = line.strip()
253
- m = re.match(r"^([\w]+)=([\d.eE+-]+)$", line)
254
- if m:
255
- try:
256
- metrics[m.group(1)] = float(m.group(2))
257
- except ValueError:
258
- pass
259
-
260
- return metrics if metrics else None
261
-
262
-
263
- # ---------------------------------------------------------------------------
264
- # Git operations
265
- # ---------------------------------------------------------------------------
266
-
267
- def git_commit(message: str) -> bool:
268
- """Stage all changes and commit."""
269
- try:
270
- subprocess.run(["git", "add", "-A"], cwd=_PROJECT_ROOT, check=True,
271
- capture_output=True, timeout=30)
272
- subprocess.run(
273
- ["git", "commit", "-m", message],
274
- cwd=_PROJECT_ROOT, check=True, capture_output=True, timeout=30,
275
- )
276
- return True
277
- except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
278
- print(f" [WARN] Git commit failed: {e}")
279
- return False
280
-
281
-
282
- # ---------------------------------------------------------------------------
283
- # Main loop
284
- # ---------------------------------------------------------------------------
285
-
286
- _SHUTDOWN = False
287
-
288
-
289
- def _handle_sigint(signum, frame):
290
- global _SHUTDOWN
291
- if _SHUTDOWN:
292
- print("\n[AUTORESEARCH] Double Ctrl+C β€” force exit")
293
- sys.exit(1)
294
- _SHUTDOWN = True
295
- print("\n[AUTORESEARCH] Ctrl+C received β€” finishing current gen then saving state...")
296
-
297
-
298
- def main():
299
- global _SHUTDOWN
300
- signal.signal(signal.SIGINT, _handle_sigint)
301
-
302
- parser = argparse.ArgumentParser(description="HYDRA autoresearch mutation loop")
303
- parser.add_argument("--dry-run", action="store_true", help="Show plan, don't train")
304
- parser.add_argument("--baseline", action="store_true", help="Only run baseline")
305
- parser.add_argument("--time-budget", type=int, default=600, help="Time budget per run (s)")
306
- parser.add_argument("--tps-floor", type=int, default=62000, help="Minimum acceptable TPS")
307
- args = parser.parse_args()
308
-
309
- state = load_state()
310
- state["time_budget"] = args.time_budget
311
- state["tps_floor"] = args.tps_floor
312
-
313
- tested = set(state["mutations_tested"])
314
- remaining = [m for m in MUTATIONS if m["name"] not in tested]
315
-
316
- print("=" * 70)
317
- print("HYDRA AUTORESEARCH MUTATION LOOP")
318
- print("=" * 70)
319
- print(f"Time budget per run: {state['time_budget']}s")
320
- print(f"TPS floor: {state['tps_floor']}")
321
- print(f"Current gen: {state['current_gen']}")
322
- print(f"Mutations tested: {len(tested)}/{len(MUTATIONS)}")
323
- print(f"Mutations kept: {state['mutations_kept']}")
324
- print(f"Remaining: {[m['name'] for m in remaining]}")
325
- print()
326
-
327
- if args.dry_run:
328
- print("[DRY RUN] Would test these mutations in order:")
329
- for i, m in enumerate(remaining):
330
- print(f" {i + 1}. {m['name']} ({m['env']})")
331
- return
332
-
333
- # -----------------------------------------------------------------------
334
- # Baseline (Gen 0)
335
- # -----------------------------------------------------------------------
336
- if state["baseline_quality"] is None:
337
- print("[GEN 0] Running baseline training + evaluation...")
338
- train_metrics = run_training(state["time_budget"])
339
- if train_metrics is None:
340
- print("[FAIL] Baseline training failed")
341
- save_state(state)
342
- return
343
-
344
- print("[GEN 0] Running quality evaluation...")
345
- eval_metrics = run_eval_after_training()
346
- if eval_metrics is None:
347
- print("[FAIL] Baseline eval failed")
348
- save_state(state)
349
- return
350
-
351
- baseline_tps = train_metrics.get("tps", 0)
352
- baseline_quality = eval_metrics.get("quality_score", 0)
353
-
354
- state["baseline_quality"] = baseline_quality
355
- state["baseline_tps"] = baseline_tps
356
- state["current_gen"] = 0
357
- state["history"].append({
358
- "gen": 0,
359
- "mutation": "baseline",
360
- "quality_score": baseline_quality,
361
- "baseline_score": baseline_quality,
362
- "delta": "0.0%",
363
- "tps": baseline_tps,
364
- "ppl": eval_metrics.get("ppl", 0),
365
- "bleu4": eval_metrics.get("bleu4", 0),
366
- "rouge_l": eval_metrics.get("rouge_l", 0),
367
- "factual": eval_metrics.get("factual", 0),
368
- "bpb": eval_metrics.get("bpb", 0),
369
- "repetition_rate": eval_metrics.get("repetition_rate", 0),
370
- "kept": True,
371
- })
372
- save_state(state)
373
- print(f"[GEN 0] BASELINE: quality={baseline_quality:.4f} tps={baseline_tps:.0f}")
374
-
375
- if args.baseline:
376
- return
377
- else:
378
- print(f"[RESUME] Baseline quality={state['baseline_quality']:.4f} tps={state['baseline_tps']:.0f}")
379
- if args.baseline:
380
- return
381
-
382
- # -----------------------------------------------------------------------
383
- # Mutation loop
384
- # -----------------------------------------------------------------------
385
- current_quality = state["baseline_quality"]
386
- # Track best quality so far (from last kept mutation, not just baseline)
387
- if state["history"]:
388
- kept_entries = [h for h in state["history"] if h.get("kept")]
389
- if kept_entries:
390
- current_quality = kept_entries[-1]["quality_score"]
391
-
392
- for mutation in remaining:
393
- if _SHUTDOWN:
394
- print("[AUTORESEARCH] Shutdown requested β€” saving state")
395
- save_state(state)
396
- return
397
-
398
- gen = state["current_gen"] + 1
399
- name = mutation["name"]
400
- env_str = mutation["env"]
401
-
402
- print(f"\n[GEN {gen}] Testing {name} ({env_str})...")
403
- print(f" Current best quality: {current_quality:.4f}")
404
-
405
- # Train with mutation
406
- print(f" Training ({state['time_budget']}s)...", flush=True)
407
- train_metrics = run_training(state["time_budget"], extra_env=env_str)
408
- if train_metrics is None:
409
- print(f" [SKIP] Training failed for {name}")
410
- state["mutations_tested"].append(name)
411
- state["current_gen"] = gen
412
- state["history"].append({
413
- "gen": gen, "mutation": name,
414
- "quality_score": 0, "baseline_score": current_quality,
415
- "delta": "FAIL", "tps": 0, "ppl": 0, "bleu4": 0,
416
- "rouge_l": 0, "factual": 0, "bpb": 0, "repetition_rate": 0,
417
- "kept": False,
418
- })
419
- save_state(state)
420
- continue
421
-
422
- tps = train_metrics.get("tps", 0)
423
-
424
- # TPS floor check
425
- if tps < state["tps_floor"]:
426
- print(f" [REJECT] TPS={tps:.0f} < floor={state['tps_floor']} β€” skipping eval")
427
- state["mutations_tested"].append(name)
428
- state["current_gen"] = gen
429
- state["history"].append({
430
- "gen": gen, "mutation": name,
431
- "quality_score": 0, "baseline_score": current_quality,
432
- "delta": f"TPS_FAIL({tps:.0f})", "tps": tps,
433
- "ppl": 0, "bleu4": 0, "rouge_l": 0, "factual": 0,
434
- "bpb": train_metrics.get("val_bpb", 0), "repetition_rate": 0,
435
- "kept": False,
436
- })
437
- save_state(state)
438
- continue
439
-
440
- # Evaluate
441
- print(f" Evaluating...", flush=True)
442
- eval_metrics = run_eval_after_training(extra_env=env_str)
443
- if eval_metrics is None:
444
- print(f" [SKIP] Eval failed for {name}")
445
- state["mutations_tested"].append(name)
446
- state["current_gen"] = gen
447
- state["history"].append({
448
- "gen": gen, "mutation": name,
449
- "quality_score": 0, "baseline_score": current_quality,
450
- "delta": "EVAL_FAIL", "tps": tps, "ppl": 0, "bleu4": 0,
451
- "rouge_l": 0, "factual": 0, "bpb": 0, "repetition_rate": 0,
452
- "kept": False,
453
- })
454
- save_state(state)
455
- continue
456
-
457
- quality = eval_metrics.get("quality_score", 0)
458
- delta_pct = ((quality - current_quality) / max(abs(current_quality), 1e-6)) * 100
459
- delta_str = f"{delta_pct:+.1f}%"
460
-
461
- kept = quality > current_quality and tps >= state["tps_floor"]
462
- status = "KEEP" if kept else "DISCARD"
463
-
464
- entry = {
465
- "gen": gen,
466
- "mutation": name,
467
- "quality_score": quality,
468
- "baseline_score": current_quality,
469
- "delta": delta_str,
470
- "tps": tps,
471
- "ppl": eval_metrics.get("ppl", 0),
472
- "bleu4": eval_metrics.get("bleu4", 0),
473
- "rouge_l": eval_metrics.get("rouge_l", 0),
474
- "factual": eval_metrics.get("factual", 0),
475
- "bpb": eval_metrics.get("bpb", 0),
476
- "repetition_rate": eval_metrics.get("repetition_rate", 0),
477
- "kept": kept,
478
- }
479
-
480
- print(f"\n[GEN {gen}] {name}: quality={quality:.4f} ({delta_str}) tps={tps:.0f} -> {status}")
481
-
482
- if kept:
483
- current_quality = quality
484
- state["mutations_kept"].append(name)
485
- git_commit(f"autoresearch: gen {gen} β€” {name} quality {delta_str}")
486
-
487
- state["mutations_tested"].append(name)
488
- state["current_gen"] = gen
489
- state["history"].append(entry)
490
- save_state(state)
491
-
492
- # -----------------------------------------------------------------------
493
- # Summary
494
- # -----------------------------------------------------------------------
495
- print("\n" + "=" * 70)
496
- print("AUTORESEARCH COMPLETE")
497
- print("=" * 70)
498
- print(f"Total generations: {state['current_gen']}")
499
- print(f"Mutations kept: {state['mutations_kept']}")
500
- print(f"Final quality: {current_quality:.4f}")
501
- if state["baseline_quality"]:
502
- total_delta = ((current_quality - state["baseline_quality"]) /
503
- max(abs(state["baseline_quality"]), 1e-6)) * 100
504
- print(f"Total improvement: {total_delta:+.1f}%")
505
- print()
506
-
507
- # Print history table
508
- print(f"{'Gen':>4} {'Mutation':>20} {'Quality':>8} {'Delta':>8} {'TPS':>7} {'PPL':>8} {'BPB':>7} {'Kept':>5}")
509
- print("-" * 75)
510
- for h in state["history"]:
511
- print(f"{h['gen']:4d} {h['mutation']:>20s} {h['quality_score']:8.4f} "
512
- f"{h['delta']:>8s} {h['tps']:7.0f} {h['ppl']:8.2f} "
513
- f"{h.get('bpb', 0):7.4f} {' YES' if h['kept'] else ' NO'}")
514
-
515
-
516
- if __name__ == "__main__":
517
- main()
 
1
+ #!/usr/bin/env python3
2
+ """HYDRA Autoresearch Mutation Loop.
3
+
4
+ Runs baseline training -> evaluates -> picks ONE mutation at a time ->
5
+ trains -> evaluates -> keeps if quality improves AND tps >= floor.
6
+ Repeats until all mutations exhausted or Ctrl+C.
7
+
8
+ State persisted in .omc/autoresearch_config.json for resume support.
9
+
10
+ Usage:
11
+ python scripts/autoresearch.py # run full loop
12
+ python scripts/autoresearch.py --dry-run # show plan, don't train
13
+ python scripts/autoresearch.py --baseline # only run baseline eval
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import json
20
+ import math
21
+ import os
22
+ import re
23
+ import signal
24
+ import subprocess
25
+ import sys
26
+ import time
27
+ from pathlib import Path
28
+
29
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
30
+ if _PROJECT_ROOT not in sys.path:
31
+ sys.path.insert(0, _PROJECT_ROOT)
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Mutation catalog (ordered by expected impact)
35
+ # ---------------------------------------------------------------------------
36
+
37
+ MUTATIONS = [
38
+ # Learning dynamics — env vars verified in hydra/config.py
39
+ {"name": "lr_matrix_0.012", "env": "HYDRA_MATRIX_LR=0.012"}, # default 0.12
40
+ {"name": "lr_matrix_0.06", "env": "HYDRA_MATRIX_LR=0.06"}, # half default
41
+ {"name": "lr_matrix_0.24", "env": "HYDRA_MATRIX_LR=0.24"}, # double default
42
+ {"name": "lr_floor_50pct", "env": "HYDRA_LR_MIN_MULT=0.5"}, # default 0.0
43
+ {"name": "lr_floor_20pct", "env": "HYDRA_LR_MIN_MULT=0.2"}, # default 0.0
44
+ {"name": "embed_lr_0.5", "env": "HYDRA_EMBED_LR=0.5"}, # default 1.0
45
+ {"name": "embed_lr_2.0", "env": "HYDRA_EMBED_LR=2.0"}, # default 1.0
46
+ {"name": "unembed_lr_0.01", "env": "HYDRA_UNEMBED_LR=0.01"}, # default 0.005
47
+ # Architecture — env vars verified in hydra/config.py
48
+ {"name": "d_model_384", "env": "HYDRA_D_MODEL=384"}, # default 256
49
+ {"name": "d_model_192", "env": "HYDRA_D_MODEL=192"}, # smaller
50
+ {"name": "d_state_128", "env": "HYDRA_D_STATE=128"}, # default 64
51
+ {"name": "d_state_32", "env": "HYDRA_D_STATE=32"}, # smaller
52
+ {"name": "n_layer_6", "env": "HYDRA_N_LAYER=6"}, # default 4
53
+ {"name": "n_layer_3", "env": "HYDRA_N_LAYER=3"}, # fewer
54
+ {"name": "headdim_16", "env": "HYDRA_HEADDIM=16"}, # default 32 -> more heads
55
+ {"name": "headdim_64", "env": "HYDRA_HEADDIM=64"}, # default 32 -> fewer heads
56
+ {"name": "expand_3", "env": "HYDRA_EXPAND=3"}, # default 2
57
+ {"name": "engram_2048", "env": "HYDRA_ENGRAM_N_COLUMNS=2048"}, # default 1024
58
+ {"name": "engram_4096", "env": "HYDRA_ENGRAM_N_COLUMNS=4096"}, # default 1024
59
+ {"name": "engram_512", "env": "HYDRA_ENGRAM_N_COLUMNS=512"}, # smaller
60
+ # Batch size
61
+ {"name": "batch_32k", "env": "HYDRA_TOTAL_BATCH=32768"}, # default 32768 (verify)
62
+ {"name": "batch_16k", "env": "HYDRA_TOTAL_BATCH=16384"}, # smaller batch
63
+ {"name": "batch_65k", "env": "HYDRA_TOTAL_BATCH=65536"}, # larger batch
64
+ # Regularization — env vars verified in hydra/model.py + hydra/config.py
65
+ {"name": "dropout_0.05", "env": "HYDRA_DROPOUT=0.05"}, # default 0.2
66
+ {"name": "dropout_0.1", "env": "HYDRA_DROPOUT=0.1"}, # default 0.2
67
+ {"name": "dropout_0.3", "env": "HYDRA_DROPOUT=0.3"}, # higher
68
+ ]
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # State management
72
+ # ---------------------------------------------------------------------------
73
+
74
+ STATE_DIR = os.path.join(_PROJECT_ROOT, ".omc")
75
+ STATE_FILE = os.path.join(STATE_DIR, "autoresearch_config.json")
76
+
77
+ DEFAULT_STATE = {
78
+ "baseline_quality": None,
79
+ "baseline_tps": None,
80
+ "current_gen": 0,
81
+ "mutations_tested": [],
82
+ "mutations_kept": [],
83
+ "tps_floor": 62000,
84
+ "time_budget": 600,
85
+ "history": [],
86
+ }
87
+
88
+
89
+ def load_state() -> dict:
90
+ """Load state from disk or return default."""
91
+ if os.path.exists(STATE_FILE):
92
+ with open(STATE_FILE, "r") as f:
93
+ state = json.load(f)
94
+ # Backfill missing keys from defaults
95
+ for k, v in DEFAULT_STATE.items():
96
+ if k not in state:
97
+ state[k] = v
98
+ return state
99
+ return dict(DEFAULT_STATE)
100
+
101
+
102
+ def save_state(state: dict) -> None:
103
+ """Persist state to disk."""
104
+ os.makedirs(STATE_DIR, exist_ok=True)
105
+ with open(STATE_FILE, "w") as f:
106
+ json.dump(state, f, indent=2)
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # Training subprocess
111
+ # ---------------------------------------------------------------------------
112
+
113
+ def build_env(extra_env: str | None = None) -> dict[str, str]:
114
+ """Build environment for training subprocess."""
115
+ env = os.environ.copy()
116
+ # Ensure CUDA paths
117
+ ld_paths = ["/usr/lib/wsl/lib", "/usr/local/cuda/lib64"]
118
+ existing = env.get("LD_LIBRARY_PATH", "")
119
+ for p in ld_paths:
120
+ if p not in existing:
121
+ existing = p + ":" + existing
122
+ env["LD_LIBRARY_PATH"] = existing
123
+
124
+ # Apply mutation env var
125
+ if extra_env:
126
+ key, val = extra_env.split("=", 1)
127
+ env[key] = val
128
+
129
+ return env
130
+
131
+
132
+ def run_training(time_budget: int, extra_env: str | None = None) -> dict | None:
133
+ """Run train.py with given time budget and optional env override.
134
+
135
+ Returns dict with parsed metrics, or None on failure.
136
+ """
137
+ env = build_env(extra_env)
138
+ env["HYDRA_TIME_BUDGET"] = str(time_budget)
139
+
140
+ cmd = [os.path.join(_PROJECT_ROOT, ".venv", "bin", "python"), "-u", "train.py"]
141
+
142
+ try:
143
+ proc = subprocess.Popen(
144
+ cmd,
145
+ cwd=_PROJECT_ROOT,
146
+ env=env,
147
+ stdout=subprocess.PIPE,
148
+ stderr=subprocess.STDOUT,
149
+ text=True,
150
+ bufsize=1,
151
+ )
152
+ except Exception as e:
153
+ print(f" [ERROR] Failed to start training: {e}")
154
+ return None
155
+
156
+ output_lines: list[str] = []
157
+ last_step_line = ""
158
+
159
+ try:
160
+ for line in proc.stdout:
161
+ line = line.rstrip()
162
+ output_lines.append(line)
163
+ if line.startswith("step="):
164
+ last_step_line = line
165
+ # Print progress every 50 steps
166
+ m = re.search(r"step=(\d+)", line)
167
+ if m and int(m.group(1)) % 50 == 0:
168
+ tps_m = re.search(r"tps=(\d+)", line)
169
+ bpb_m = re.search(r"bpb=([\d.]+)", line)
170
+ tps = tps_m.group(1) if tps_m else "?"
171
+ bpb = bpb_m.group(1) if bpb_m else "?"
172
+ print(f" step={m.group(1)} tps={tps} bpb={bpb}", flush=True)
173
+ elif "val_bpb" in line or "factual_english_score" in line:
174
+ print(f" {line}", flush=True)
175
+ except KeyboardInterrupt:
176
+ proc.terminate()
177
+ proc.wait()
178
+ raise
179
+
180
+ proc.wait()
181
+ if proc.returncode != 0:
182
+ print(f" [ERROR] Training exited with code {proc.returncode}")
183
+ # Print last 10 lines for debugging
184
+ for line in output_lines[-10:]:
185
+ print(f" {line}")
186
+ return None
187
+
188
+ return _parse_training_output(output_lines)
189
+
190
+
191
+ def _parse_training_output(lines: list[str]) -> dict:
192
+ """Extract metrics from training output lines."""
193
+ metrics: dict[str, float] = {}
194
+
195
+ for line in lines:
196
+ # Key=value pairs from summary block
197
+ for key in ["val_bpb", "training_seconds", "peak_vram_mb", "mfu_percent",
198
+ "total_tokens_M", "num_steps", "factual_english_score",
199
+ "factual_english_hits"]:
200
+ m = re.match(rf"^{key}:\s+([\d.]+)", line.strip())
201
+ if m:
202
+ metrics[key] = float(m.group(1))
203
+
204
+ # TPS from last step line
205
+ if line.startswith("step="):
206
+ tps_m = re.search(r"tps=(\d+)", line)
207
+ if tps_m:
208
+ metrics["tps"] = float(tps_m.group(1))
209
+
210
+ return metrics
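The regexes above assume `key=value` tokens on each progress line; a tiny illustration (the sample line is hypothetical, not captured train.py output):

    import re

    line = "step=150 tps=65321 bpb=1.0423"
    tps = re.search(r"tps=(\d+)", line)
    bpb = re.search(r"bpb=([\d.]+)", line)
    print(float(tps.group(1)), float(bpb.group(1)))  # 65321.0 1.0423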
211
+
212
+
213
+ # ---------------------------------------------------------------------------
214
+ # Eval integration
215
+ # ---------------------------------------------------------------------------
216
+
217
+ def run_eval_after_training(extra_env: str | None = None) -> dict | None:
218
+ """Run eval_quality.py after training. Returns metrics dict or None."""
219
+ env = build_env(extra_env)
220
+ cmd = [
221
+ os.path.join(_PROJECT_ROOT, ".venv", "bin", "python"),
222
+ os.path.join(_PROJECT_ROOT, "scripts", "eval_quality.py"),
223
+ ]
224
+
225
+ try:
226
+ result = subprocess.run(
227
+ cmd,
228
+ cwd=_PROJECT_ROOT,
229
+ env=env,
230
+ capture_output=True,
231
+ text=True,
232
+ timeout=120, # 2 min max for eval
233
+ )
234
+ except subprocess.TimeoutExpired:
235
+ print(" [ERROR] Eval timed out (120s)")
236
+ return None
237
+ except Exception as e:
238
+ print(f" [ERROR] Eval failed: {e}")
239
+ return None
240
+
241
+ if result.returncode != 0:
242
+ print(f" [ERROR] Eval exited with code {result.returncode}")
243
+ for line in result.stdout.split("\n")[-10:]:
244
+ print(f" {line}")
245
+ for line in result.stderr.split("\n")[-5:]:
246
+ print(f" {line}")
247
+ return None
248
+
249
+ # Parse key=value output
250
+ metrics = {}
251
+ for line in result.stdout.split("\n"):
252
+ line = line.strip()
253
+ m = re.match(r"^([\w]+)=([\d.eE+-]+)$", line)
254
+ if m:
255
+ try:
256
+ metrics[m.group(1)] = float(m.group(2))
257
+ except ValueError:
258
+ pass
259
+
260
+ return metrics if metrics else None
261
+
262
+
263
+ # ---------------------------------------------------------------------------
264
+ # Git operations
265
+ # ---------------------------------------------------------------------------
266
+
267
+ def git_commit(message: str) -> bool:
268
+ """Stage all changes and commit."""
269
+ try:
270
+ subprocess.run(["git", "add", "-A"], cwd=_PROJECT_ROOT, check=True,
271
+ capture_output=True, timeout=30)
272
+ subprocess.run(
273
+ ["git", "commit", "-m", message],
274
+ cwd=_PROJECT_ROOT, check=True, capture_output=True, timeout=30,
275
+ )
276
+ return True
277
+ except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
278
+ print(f" [WARN] Git commit failed: {e}")
279
+ return False
280
+
281
+
282
+ # ---------------------------------------------------------------------------
283
+ # Main loop
284
+ # ---------------------------------------------------------------------------
285
+
286
+ _SHUTDOWN = False
287
+
288
+
289
+ def _handle_sigint(signum, frame):
290
+ global _SHUTDOWN
291
+ if _SHUTDOWN:
292
+ print("\n[AUTORESEARCH] Double Ctrl+C β€” force exit")
293
+ sys.exit(1)
294
+ _SHUTDOWN = True
295
+ print("\n[AUTORESEARCH] Ctrl+C received β€” finishing current gen then saving state...")
296
+
297
+
298
+ def main():
299
+ global _SHUTDOWN
300
+ signal.signal(signal.SIGINT, _handle_sigint)
301
+
302
+ parser = argparse.ArgumentParser(description="HYDRA autoresearch mutation loop")
303
+ parser.add_argument("--dry-run", action="store_true", help="Show plan, don't train")
304
+ parser.add_argument("--baseline", action="store_true", help="Only run baseline")
305
+ parser.add_argument("--time-budget", type=int, default=600, help="Time budget per run (s)")
306
+ parser.add_argument("--tps-floor", type=int, default=62000, help="Minimum acceptable TPS")
307
+ args = parser.parse_args()
308
+
309
+ state = load_state()
310
+ state["time_budget"] = args.time_budget
311
+ state["tps_floor"] = args.tps_floor
312
+
313
+ tested = set(state["mutations_tested"])
314
+ remaining = [m for m in MUTATIONS if m["name"] not in tested]
315
+
316
+ print("=" * 70)
317
+ print("HYDRA AUTORESEARCH MUTATION LOOP")
318
+ print("=" * 70)
319
+ print(f"Time budget per run: {state['time_budget']}s")
320
+ print(f"TPS floor: {state['tps_floor']}")
321
+ print(f"Current gen: {state['current_gen']}")
322
+ print(f"Mutations tested: {len(tested)}/{len(MUTATIONS)}")
323
+ print(f"Mutations kept: {state['mutations_kept']}")
324
+ print(f"Remaining: {[m['name'] for m in remaining]}")
325
+ print()
326
+
327
+ if args.dry_run:
328
+ print("[DRY RUN] Would test these mutations in order:")
329
+ for i, m in enumerate(remaining):
330
+ print(f" {i + 1}. {m['name']} ({m['env']})")
331
+ return
332
+
333
+ # -----------------------------------------------------------------------
334
+ # Baseline (Gen 0)
335
+ # -----------------------------------------------------------------------
336
+ if state["baseline_quality"] is None:
337
+ print("[GEN 0] Running baseline training + evaluation...")
338
+ train_metrics = run_training(state["time_budget"])
339
+ if train_metrics is None:
340
+ print("[FAIL] Baseline training failed")
341
+ save_state(state)
342
+ return
343
+
344
+ print("[GEN 0] Running quality evaluation...")
345
+ eval_metrics = run_eval_after_training()
346
+ if eval_metrics is None:
347
+ print("[FAIL] Baseline eval failed")
348
+ save_state(state)
349
+ return
350
+
351
+ baseline_tps = train_metrics.get("tps", 0)
352
+ baseline_quality = eval_metrics.get("quality_score", 0)
353
+
354
+ state["baseline_quality"] = baseline_quality
355
+ state["baseline_tps"] = baseline_tps
356
+ state["current_gen"] = 0
357
+ state["history"].append({
358
+ "gen": 0,
359
+ "mutation": "baseline",
360
+ "quality_score": baseline_quality,
361
+ "baseline_score": baseline_quality,
362
+ "delta": "0.0%",
363
+ "tps": baseline_tps,
364
+ "ppl": eval_metrics.get("ppl", 0),
365
+ "bleu4": eval_metrics.get("bleu4", 0),
366
+ "rouge_l": eval_metrics.get("rouge_l", 0),
367
+ "factual": eval_metrics.get("factual", 0),
368
+ "bpb": eval_metrics.get("bpb", 0),
369
+ "repetition_rate": eval_metrics.get("repetition_rate", 0),
370
+ "kept": True,
371
+ })
372
+ save_state(state)
373
+ print(f"[GEN 0] BASELINE: quality={baseline_quality:.4f} tps={baseline_tps:.0f}")
374
+
375
+ if args.baseline:
376
+ return
377
+ else:
378
+ print(f"[RESUME] Baseline quality={state['baseline_quality']:.4f} tps={state['baseline_tps']:.0f}")
379
+ if args.baseline:
380
+ return
381
+
382
+ # -----------------------------------------------------------------------
383
+ # Mutation loop
384
+ # -----------------------------------------------------------------------
385
+ current_quality = state["baseline_quality"]
386
+ # Track best quality so far (from last kept mutation, not just baseline)
387
+ if state["history"]:
388
+ kept_entries = [h for h in state["history"] if h.get("kept")]
389
+ if kept_entries:
390
+ current_quality = kept_entries[-1]["quality_score"]
391
+
392
+ for mutation in remaining:
393
+ if _SHUTDOWN:
394
+ print("[AUTORESEARCH] Shutdown requested β€” saving state")
395
+ save_state(state)
396
+ return
397
+
398
+ gen = state["current_gen"] + 1
399
+ name = mutation["name"]
400
+ env_str = mutation["env"]
401
+
402
+ print(f"\n[GEN {gen}] Testing {name} ({env_str})...")
403
+ print(f" Current best quality: {current_quality:.4f}")
404
+
405
+ # Train with mutation
406
+ print(f" Training ({state['time_budget']}s)...", flush=True)
407
+ train_metrics = run_training(state["time_budget"], extra_env=env_str)
408
+ if train_metrics is None:
409
+ print(f" [SKIP] Training failed for {name}")
410
+ state["mutations_tested"].append(name)
411
+ state["current_gen"] = gen
412
+ state["history"].append({
413
+ "gen": gen, "mutation": name,
414
+ "quality_score": 0, "baseline_score": current_quality,
415
+ "delta": "FAIL", "tps": 0, "ppl": 0, "bleu4": 0,
416
+ "rouge_l": 0, "factual": 0, "bpb": 0, "repetition_rate": 0,
417
+ "kept": False,
418
+ })
419
+ save_state(state)
420
+ continue
421
+
422
+ tps = train_metrics.get("tps", 0)
423
+
424
+ # TPS floor check
425
+ if tps < state["tps_floor"]:
426
+ print(f" [REJECT] TPS={tps:.0f} < floor={state['tps_floor']} β€” skipping eval")
427
+ state["mutations_tested"].append(name)
428
+ state["current_gen"] = gen
429
+ state["history"].append({
430
+ "gen": gen, "mutation": name,
431
+ "quality_score": 0, "baseline_score": current_quality,
432
+ "delta": f"TPS_FAIL({tps:.0f})", "tps": tps,
433
+ "ppl": 0, "bleu4": 0, "rouge_l": 0, "factual": 0,
434
+ "bpb": train_metrics.get("val_bpb", 0), "repetition_rate": 0,
435
+ "kept": False,
436
+ })
437
+ save_state(state)
438
+ continue
439
+
440
+ # Evaluate
441
+ print(f" Evaluating...", flush=True)
442
+ eval_metrics = run_eval_after_training(extra_env=env_str)
443
+ if eval_metrics is None:
444
+ print(f" [SKIP] Eval failed for {name}")
445
+ state["mutations_tested"].append(name)
446
+ state["current_gen"] = gen
447
+ state["history"].append({
448
+ "gen": gen, "mutation": name,
449
+ "quality_score": 0, "baseline_score": current_quality,
450
+ "delta": "EVAL_FAIL", "tps": tps, "ppl": 0, "bleu4": 0,
451
+ "rouge_l": 0, "factual": 0, "bpb": 0, "repetition_rate": 0,
452
+ "kept": False,
453
+ })
454
+ save_state(state)
455
+ continue
456
+
457
+ quality = eval_metrics.get("quality_score", 0)
458
+ delta_pct = ((quality - current_quality) / max(abs(current_quality), 1e-6)) * 100
459
+ delta_str = f"{delta_pct:+.1f}%"
460
+
461
+ kept = quality > current_quality and tps >= state["tps_floor"]
462
+ status = "KEEP" if kept else "DISCARD"
463
+
464
+ entry = {
465
+ "gen": gen,
466
+ "mutation": name,
467
+ "quality_score": quality,
468
+ "baseline_score": current_quality,
469
+ "delta": delta_str,
470
+ "tps": tps,
471
+ "ppl": eval_metrics.get("ppl", 0),
472
+ "bleu4": eval_metrics.get("bleu4", 0),
473
+ "rouge_l": eval_metrics.get("rouge_l", 0),
474
+ "factual": eval_metrics.get("factual", 0),
475
+ "bpb": eval_metrics.get("bpb", 0),
476
+ "repetition_rate": eval_metrics.get("repetition_rate", 0),
477
+ "kept": kept,
478
+ }
479
+
480
+ print(f"\n[GEN {gen}] {name}: quality={quality:.4f} ({delta_str}) tps={tps:.0f} -> {status}")
481
+
482
+ if kept:
483
+ current_quality = quality
484
+ state["mutations_kept"].append(name)
485
+ git_commit(f"autoresearch: gen {gen} β€” {name} quality {delta_str}")
486
+
487
+ state["mutations_tested"].append(name)
488
+ state["current_gen"] = gen
489
+ state["history"].append(entry)
490
+ save_state(state)
491
+
492
+ # -----------------------------------------------------------------------
493
+ # Summary
494
+ # -----------------------------------------------------------------------
495
+ print("\n" + "=" * 70)
496
+ print("AUTORESEARCH COMPLETE")
497
+ print("=" * 70)
498
+ print(f"Total generations: {state['current_gen']}")
499
+ print(f"Mutations kept: {state['mutations_kept']}")
500
+ print(f"Final quality: {current_quality:.4f}")
501
+ if state["baseline_quality"]:
502
+ total_delta = ((current_quality - state["baseline_quality"]) /
503
+ max(abs(state["baseline_quality"]), 1e-6)) * 100
504
+ print(f"Total improvement: {total_delta:+.1f}%")
505
+ print()
506
+
507
+ # Print history table
508
+ print(f"{'Gen':>4} {'Mutation':>20} {'Quality':>8} {'Delta':>8} {'TPS':>7} {'PPL':>8} {'BPB':>7} {'Kept':>5}")
509
+ print("-" * 75)
510
+ for h in state["history"]:
511
+ print(f"{h['gen']:4d} {h['mutation']:>20s} {h['quality_score']:8.4f} "
512
+ f"{h['delta']:>8s} {h['tps']:7.0f} {h['ppl']:8.2f} "
513
+ f"{h.get('bpb', 0):7.4f} {' YES' if h['kept'] else ' NO'}")
514
+
515
+
516
+ if __name__ == "__main__":
517
+ main()
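
A quick illustration of the parsing contract above. The sample lines are invented for the sketch (the real strings are whatever train.py actually prints), and it assumes `_parse_training_output` is in scope, e.g. imported from this script:

    # Hypothetical training output: one progress line plus summary "key: value" lines.
    sample = [
        "step=500 tps=64210 bpb=1.2345",
        "val_bpb: 1.2301",
        "training_seconds: 600.0",
        "peak_vram_mb: 18432",
    ]
    print(_parse_training_output(sample))
    # {'tps': 64210.0, 'val_bpb': 1.2301, 'training_seconds': 600.0, 'peak_vram_mb': 18432.0}

Downstream, the mutation loop keeps a run only if its `quality_score` beats the current best while `tps` stays at or above `--tps-floor`; everything else is recorded in the history table and discarded.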
overlay/scripts/chat.py CHANGED
"""Interactive chat REPL for HYDRA.

Usage:
    python scripts/chat.py              # auto-select best checkpoint
    python scripts/chat.py --ckpt PATH  # explicit checkpoint
    python scripts/chat.py --sft        # prefer sft_final.pt
    python scripts/chat.py --random     # skip ckpt, use random weights

HONESTY: model is ~7.5M params at d_model=256/n_layer=4. Expect incoherent
output. This REPL validates the *interface* — tokenizer roundtrip, generation
loop, stop-token handling, conversation history truncation. Coherent dialogue
is not a goal at this scale.

Slash commands:
    /reset      clear conversation history
    /quit       exit
    /temp X     set temperature (default 0.8)
    /topk K     set top-k (default 40)
    /topp P     set top-p (default 0.9)
    /max N      set max new tokens per turn (default 200)
    /rep R      set repetition penalty (default 1.1)
    /sys S      set a system prefix prepended to every turn
    /info       print current settings + checkpoint path
"""

from __future__ import annotations

import argparse
import os
import sys
import time
from dataclasses import asdict
from pathlib import Path

# Make repo root importable when invoked as `python scripts/chat.py`.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

import torch  # noqa: E402

# Chat template — plain-text fallback (see .omc/chat_plan.md).
# If the SFT agent later reserves special tokens, redefine USER_TAG /
# ASSISTANT_TAG / END_TAG and the stop-string accordingly.
USER_TAG = "User:"
ASSISTANT_TAG = "Assistant:"
END_TAG = "\nUser:"  # stop-string matched on decoded output

CKPT_DIR = Path(os.path.expanduser("~/.cache/autoresearch/ckpts"))
CKPT_CANDIDATES_PRETRAIN = ["pretrain_final.pt", "latest.pt"]
CKPT_CANDIDATES_SFT = ["sft_final.pt"]


# ---------------------------------------------------------------------------
# Checkpoint resolution
# ---------------------------------------------------------------------------

def resolve_checkpoint(explicit: str | None, prefer_sft: bool) -> Path | None:
    """Return Path to checkpoint file, or None if nothing found.

    Order:
      1. `explicit` if provided and exists.
      2. If prefer_sft: sft_final.pt -> pretrain_final.pt -> latest.pt.
      3. Else: sft_final.pt (if exists) -> pretrain_final.pt -> latest.pt.
    """
    if explicit:
        p = Path(os.path.expanduser(explicit))
        if p.exists():
            return p
        print(f"[WARN] --ckpt {p} does not exist; falling through to auto-select.", file=sys.stderr)

    # Task spec: prefer sft_final.pt if it exists; otherwise pretrain_final.pt
    # then latest.pt. --sft just makes the preference explicit; it's already
    # the default behavior. We list SFT first in both orderings to honor the
    # spec, since the task description said "prefer sft if exists" by default.
    _ = prefer_sft  # reserved for future "pretrain-only" vs "sft-only" modes
    order = CKPT_CANDIDATES_SFT + CKPT_CANDIDATES_PRETRAIN
    for name in order:
        cand = CKPT_DIR / name
        if cand.exists():
            return cand
    return None


# ---------------------------------------------------------------------------
# Model + tokenizer loading
# ---------------------------------------------------------------------------

def load_model_and_tokenizer(ckpt_path: Path | None, device: torch.device):
    """Build model + tokenizer. If ckpt_path is None, random weights are used.

    Returns (model, tokenizer, meta) where meta is a dict with 'ckpt',
    'step', 'val_bpb' etc. for /info display.
    """
    from hydra.config import PostSemClawConfig
    from hydra.model import PostSemClawModel
    from prepare import Tokenizer

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    print(f"[chat] Tokenizer loaded (vocab={vocab_size:,})")

    meta: dict = {"ckpt": str(ckpt_path) if ckpt_path else "<random>", "step": None, "val_bpb": None}

    # Build config. If checkpoint provides one, use it; else use env-var defaults.
    ckpt_state = None
    config_kwargs: dict = {}
    if ckpt_path is not None:
        print(f"[chat] Loading checkpoint: {ckpt_path}")
        ckpt_state = torch.load(ckpt_path, map_location=device, weights_only=False)
        cfg_dict = ckpt_state.get("config")
        if isinstance(cfg_dict, dict):
            # Filter to kwargs PostSemClawConfig actually accepts.
            allowed = set(PostSemClawConfig.__dataclass_fields__.keys())
            config_kwargs = {k: v for k, v in cfg_dict.items() if k in allowed}
        meta["step"] = ckpt_state.get("step")
        meta["val_bpb"] = ckpt_state.get("val_bpb") or ckpt_state.get("bpb")

    # Env-var defaults are applied by PostSemClawConfig field defaults; but the
    # training run builds the config explicitly from hydra.config module-level
    # constants. We mirror that here so the random-weights path aligns with
    # what train.py would instantiate for the same env.
    if not config_kwargs:
        from hydra.config import (  # noqa: E402
            D_MODEL, D_STATE, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX,
            ENGRAM_N_COLUMNS, EXPAND, HEADDIM, N_HEADS, N_LAYER,
        )
        from prepare import MAX_SEQ_LEN  # noqa: E402
        config_kwargs = dict(
            sequence_len=MAX_SEQ_LEN,
            vocab_size=vocab_size,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            d_state=D_STATE,
            headdim=HEADDIM,
            n_heads=N_HEADS,
            expand=EXPAND,
            engram_n_columns=ENGRAM_N_COLUMNS,
            engram_key_dim=ENGRAM_KEY_DIM,
            engram_layer_idx=ENGRAM_LAYER_IDX,
        )

    # Build model on meta device then materialize — matches training.py path.
    with torch.device("meta"):
        model = PostSemClawModel(PostSemClawConfig(**config_kwargs))
    model.to_empty(device=device)
    model.init_weights()

    if ckpt_state is not None and "model_state_dict" in ckpt_state:
        # strict=False: the model has non-parameter buffers (SDR retina loaded
        # from npz, HTM Rust-side state, engram EMA stats) that may not be in
        # the state_dict. missing/unexpected-key warnings are expected and OK.
        missing, unexpected = model.load_state_dict(
            ckpt_state["model_state_dict"], strict=False
        )
        if missing:
            print(f"[chat] Note: {len(missing)} missing key(s) in state_dict (expected for HTM/SDR buffers).")
        if unexpected:
            print(f"[chat] Note: {len(unexpected)} unexpected key(s) in state_dict.")
    elif ckpt_path is None:
        print("[chat] [WARN] NO CHECKPOINT — using random weights. Output will be gibberish.", file=sys.stderr)

    model.eval()
    return model, tokenizer, meta


# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------

def generate_stream(
    model,
    tokenizer,
    prompt_ids: list[int],
    *,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    stop_strings: tuple[str, ...],
    max_seq_len: int,
    device: torch.device,
    rep_window: int = 64,
):
    """Yield decoded-text chunks as tokens are generated.

    Truncates `prompt_ids` to the last `max_seq_len` tokens if needed. Stops
    early when any `stop_strings` substring appears in the newly-decoded
    continuation.
    """
    from scripts.sample_utils import sample_token

    # Truncate prompt to window.
    if len(prompt_ids) > max_seq_len:
        prompt_ids = prompt_ids[-max_seq_len:]

    ctx = torch.tensor([prompt_ids], device=device, dtype=torch.long)
    generated: list[int] = []
    # Track already-streamed byte length so we can detect when the decoded
    # string has grown (BPE tokens may decode to multi-char strings mid-merge).
    streamed_chars = 0
    accumulated_text = ""

    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

    for _ in range(max_new_tokens):
        with torch.no_grad(), autocast_ctx:
            out = model(ctx, targets=None)
        # out shape: (1, T, vocab) or (1, vocab) depending on path.
        if out.dim() == 3:
            last_logits = out[0, -1, :]
        else:
            last_logits = out[0]

        recent = generated[-rep_window:] if generated else None
        next_id = sample_token(
            last_logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            recent_tokens=recent,
        )
        generated.append(next_id)

        # Decode everything so-far then diff — BPE decoding is not token-local,
        # so a per-token decode can drop bytes.
        new_text = tokenizer.decode(generated)
        delta = new_text[streamed_chars:]
        if delta:
            streamed_chars = len(new_text)
            accumulated_text = new_text
            yield delta

        # Stop-string check.
        hit_stop = any(s and s in accumulated_text for s in stop_strings)
        if hit_stop:
            break

        # Advance context. If we've filled the window, drop oldest token.
        ctx = torch.cat([ctx, torch.tensor([[next_id]], device=device, dtype=torch.long)], dim=1)
        if ctx.size(1) > max_seq_len:
            ctx = ctx[:, -max_seq_len:]

    # Final accumulated text is also returned for history tracking.
    return accumulated_text  # noqa: B901 (generator return for history)


def _consume_stream_with_print(stream_gen):
    """Iterate a generator, print each chunk, return the full text.

    Replacement for a naïve list(stream) since `generate_stream` is a generator
    that yields then returns the final text.
    """
    collected = []
    try:
        while True:
            chunk = next(stream_gen)
            collected.append(chunk)
            sys.stdout.write(chunk)
            sys.stdout.flush()
    except StopIteration as stop:
        # stop.value holds the return value of the generator.
        final = stop.value
        if final is not None:
            return final
    return "".join(collected)


# ---------------------------------------------------------------------------
# REPL
# ---------------------------------------------------------------------------

def build_prompt(system: str, history: list[tuple[str, str]], user_msg: str) -> str:
    """Assemble the text prompt fed to the tokenizer."""
    parts: list[str] = []
    if system:
        parts.append(system.rstrip() + "\n")
    for u, a in history:
        parts.append(f"{USER_TAG} {u}\n{ASSISTANT_TAG} {a}\n")
    parts.append(f"{USER_TAG} {user_msg}\n{ASSISTANT_TAG}")
    return "".join(parts)


def run_repl(
    model,
    tokenizer,
    meta: dict,
    *,
    device: torch.device,
    max_seq_len: int,
) -> None:
    settings = {
        "temperature": float(os.environ.get("HYDRA_CHAT_TEMP", "0.8")),
        "top_k": int(os.environ.get("HYDRA_CHAT_TOPK", "40")),
        "top_p": float(os.environ.get("HYDRA_CHAT_TOPP", "0.9")),
        "max_new_tokens": int(os.environ.get("HYDRA_CHAT_MAX", "200")),
        "repetition_penalty": float(os.environ.get("HYDRA_CHAT_REP", "1.1")),
        "system": os.environ.get("HYDRA_CHAT_SYSTEM", ""),
    }
    history: list[tuple[str, str]] = []

    print()
    print("=" * 60)
    print("HYDRA chat REPL")
    print(f"  checkpoint: {meta['ckpt']}")
    if meta.get("step") is not None:
        print(f"  step: {meta['step']}")
    if meta.get("val_bpb") is not None:
        print(f"  val_bpb: {meta['val_bpb']}")
    print("  type /info for settings, /quit to exit")
    print("=" * 60)
    print()

    while True:
        try:
            line = input(f"{USER_TAG} ")
        except (EOFError, KeyboardInterrupt):
            print()
            return

        line = line.rstrip()
        if not line:
            continue

        if line.startswith("/"):
            cmd, *rest = line.split(maxsplit=1)
            arg = rest[0] if rest else ""
            if cmd == "/quit" or cmd == "/exit":
                return
            elif cmd == "/reset":
                history = []
                print("[reset]")
                continue
            elif cmd == "/info":
                print(f"[info] ckpt={meta['ckpt']} settings={settings} history_turns={len(history)}")
                continue
            elif cmd == "/temp":
                try:
                    settings["temperature"] = float(arg)
                    print(f"[temp={settings['temperature']}]")
                except ValueError:
                    print(f"[err] /temp needs a float, got {arg!r}")
                continue
            elif cmd == "/topk":
                try:
                    settings["top_k"] = int(arg)
                    print(f"[topk={settings['top_k']}]")
                except ValueError:
                    print(f"[err] /topk needs an int, got {arg!r}")
                continue
            elif cmd == "/topp":
                try:
                    settings["top_p"] = float(arg)
                    print(f"[topp={settings['top_p']}]")
                except ValueError:
                    print(f"[err] /topp needs a float, got {arg!r}")
                continue
            elif cmd == "/max":
                try:
                    settings["max_new_tokens"] = int(arg)
                    print(f"[max={settings['max_new_tokens']}]")
                except ValueError:
                    print(f"[err] /max needs an int, got {arg!r}")
                continue
            elif cmd == "/rep":
                try:
                    settings["repetition_penalty"] = float(arg)
                    print(f"[rep={settings['repetition_penalty']}]")
                except ValueError:
                    print(f"[err] /rep needs a float, got {arg!r}")
                continue
            elif cmd == "/sys":
                settings["system"] = arg
                print(f"[sys set, {len(arg)} chars]")
                continue
            else:
                print(f"[err] unknown command {cmd!r}. Try /info /reset /quit.")
                continue

        # Normal chat turn.
        prompt_text = build_prompt(settings["system"], history, line)
        prompt_ids = tokenizer.encode(prompt_text)

        sys.stdout.write(f"{ASSISTANT_TAG} ")
        sys.stdout.flush()

        stream = generate_stream(
            model, tokenizer, prompt_ids,
            max_new_tokens=settings["max_new_tokens"],
            temperature=settings["temperature"],
            top_k=settings["top_k"],
            top_p=settings["top_p"],
            repetition_penalty=settings["repetition_penalty"],
            stop_strings=(END_TAG,),
            max_seq_len=max_seq_len,
            device=device,
        )
        response_text = _consume_stream_with_print(stream)
        if not response_text.endswith("\n"):
            sys.stdout.write("\n")
            sys.stdout.flush()

        # Strip trailing stop marker from the remembered history.
        clean = response_text
        if END_TAG in clean:
            clean = clean.split(END_TAG, 1)[0]
        clean = clean.strip()
        history.append((line, clean))


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    p = argparse.ArgumentParser(description="HYDRA chat REPL")
    p.add_argument("--ckpt", type=str, default=None,
                   help="Path to checkpoint (.pt). If omitted, auto-select.")
    p.add_argument("--sft", action="store_true",
                   help="Prefer an SFT checkpoint if available.")
    p.add_argument("--random", action="store_true",
                   help="Skip checkpoint load; use random weights.")
    p.add_argument("--device", type=str, default=None,
                   help="Torch device (default: cuda if available else cpu).")
    return p.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    args = _parse_args(argv)

    if args.device:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        print("[chat] [WARN] CUDA not available; HYDRA's HTM/Mamba kernels may fail on CPU.", file=sys.stderr)

    ckpt_path: Path | None
    if args.random:
        ckpt_path = None
    else:
        ckpt_path = resolve_checkpoint(args.ckpt, args.sft)

    t0 = time.time()
    model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
    dt = time.time() - t0
    print(f"[chat] Model ready in {dt:.1f}s on {device}")

    from prepare import MAX_SEQ_LEN
    run_repl(model, tokenizer, meta, device=device, max_seq_len=MAX_SEQ_LEN)
    return 0


if __name__ == "__main__":
    sys.exit(main())
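
One Python subtlety in the REPL above is worth spelling out: `generate_stream` both yields chunks and `return`s the final accumulated text, and a `return` inside a generator is delivered as `StopIteration.value` (PEP 380). `_consume_stream_with_print` therefore drives the generator with `next()` by hand instead of a `for` loop, which would silently discard that return value. A self-contained sketch of the same pattern (names here are illustrative, not from the repo):

    def stream_then_total():
        total = ""
        for chunk in ("Hel", "lo"):
            total += chunk
            yield chunk
        return total  # surfaces as StopIteration.value

    gen = stream_then_total()
    chunks = []
    try:
        while True:
            chunks.append(next(gen))
    except StopIteration as stop:
        final = stop.value  # "Hello"
    assert final == "".join(chunks)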
overlay/scripts/chat_eval.py CHANGED
@@ -1,300 +1,300 @@
1
- """Non-interactive chat eval for HYDRA.
2
-
3
- Runs a fixed set of prompts through the same chat template that `chat.py`
4
- uses, prints a markdown table with the response and coherence heuristics.
5
-
6
- Usage:
7
- python scripts/chat_eval.py # auto-select checkpoint
8
- python scripts/chat_eval.py --ckpt PATH
9
- python scripts/chat_eval.py --random
10
- python scripts/chat_eval.py --json out.json # also dump raw results
11
- python scripts/chat_eval.py --max 80 # cap new tokens per prompt
12
- """
13
-
14
- from __future__ import annotations
15
-
16
- import argparse
17
- import json
18
- import os
19
- import re
20
- import sys
21
- import time
22
- from pathlib import Path
23
-
24
- _REPO_ROOT = Path(__file__).resolve().parent.parent
25
- if str(_REPO_ROOT) not in sys.path:
26
- sys.path.insert(0, str(_REPO_ROOT))
27
-
28
- import torch # noqa: E402
29
-
30
- from scripts.chat import ( # noqa: E402
31
- ASSISTANT_TAG, END_TAG, USER_TAG, build_prompt,
32
- generate_stream, load_model_and_tokenizer, resolve_checkpoint,
33
- )
34
-
35
-
36
- PROMPTS: list[str] = [
37
- # Factual
38
- "What is the capital of France?",
39
- "Who wrote Romeo and Juliet?",
40
- "What is 2 plus 2?",
41
- "What color is the sky on a clear day?",
42
- # Completion
43
- "Once upon a time",
44
- "The cat sat on the",
45
- "In a hole in the ground there lived",
46
- # Instruction
47
- "Write one short sentence about rain.",
48
- "List three animals.",
49
- "Define the word 'library'.",
50
- # Conversational
51
- "Hello, how are you?",
52
- "Tell me a joke.",
53
- # Creative
54
- "Describe a sunset in one line.",
55
- "Give me a name for a pet robot.",
56
- "What is the meaning of friendship?",
57
- ]
58
-
59
- # Heuristic thresholds (printed, not enforced as pass/fail).
60
- THRESH_DISTINCT_2 = 0.30
61
- THRESH_SENT_MIN = 5
62
- THRESH_SENT_MAX = 30
63
- THRESH_EN_RATIO = 0.95
64
-
65
-
66
- # ---------------------------------------------------------------------------
67
- # Coherence heuristics
68
- # ---------------------------------------------------------------------------
69
-
70
- def _tokens(text: str) -> list[str]:
71
- return re.findall(r"[A-Za-z0-9']+", text)
72
-
73
-
74
- def distinct_2(text: str) -> float:
75
- toks = _tokens(text)
76
- if len(toks) < 2:
77
- return 0.0
78
- bigrams = [(toks[i], toks[i + 1]) for i in range(len(toks) - 1)]
79
- return len(set(bigrams)) / max(1, len(bigrams))
80
-
81
-
82
- def avg_sentence_len(text: str) -> float:
83
- sents = re.split(r"[.!?]+", text)
84
- lens = [len(_tokens(s)) for s in sents if _tokens(s)]
85
- if not lens:
86
- return 0.0
87
- return sum(lens) / len(lens)
88
-
89
-
90
- def english_char_ratio(text: str) -> float:
91
- if not text:
92
- return 0.0
93
- allowed = 0
94
- for c in text:
95
- if c.isalnum() or c.isspace() or c in ".,!?;:'\"-()[]{}/\\*#@&%+=_<>|$":
96
- allowed += 1
97
- return allowed / len(text)
98
-
99
-
100
- # ---------------------------------------------------------------------------
101
- # Runner
102
- # ---------------------------------------------------------------------------
103
-
104
- def _run_one(model, tokenizer, prompt: str, *, max_new_tokens: int, device: torch.device,
105
- max_seq_len: int, temperature: float, top_k: int, top_p: float,
106
- repetition_penalty: float) -> str:
107
- prompt_text = build_prompt(system="", history=[], user_msg=prompt)
108
- prompt_ids = tokenizer.encode(prompt_text)
109
-
110
- stream = generate_stream(
111
- model, tokenizer, prompt_ids,
112
- max_new_tokens=max_new_tokens,
113
- temperature=temperature,
114
- top_k=top_k,
115
- top_p=top_p,
116
- repetition_penalty=repetition_penalty,
117
- stop_strings=(END_TAG,),
118
- max_seq_len=max_seq_len,
119
- device=device,
120
- )
121
- collected: list[str] = []
122
- try:
123
- while True:
124
- collected.append(next(stream))
125
- except StopIteration as stop:
126
- if stop.value is not None:
127
- text = stop.value
128
- else:
129
- text = "".join(collected)
130
-
131
- if END_TAG in text:
132
- text = text.split(END_TAG, 1)[0]
133
- return text.strip()
134
-
135
-
136
- def _render_markdown(rows: list[dict]) -> str:
137
- lines = [
138
- "| # | Prompt | Response | dist-2 | sent_len | en_ratio | flags |",
139
- "|---|--------|----------|--------|----------|----------|-------|",
140
- ]
141
-
142
- def _cell(s: str, n: int = 60) -> str:
143
- s = s.replace("|", "\\|").replace("\n", " ")
144
- if len(s) > n:
145
- s = s[: n - 1] + "…"
146
- return s
147
-
148
- for i, r in enumerate(rows, 1):
149
- flags = []
150
- if r["distinct_2"] < THRESH_DISTINCT_2:
151
- flags.append("repetitive")
152
- if not (THRESH_SENT_MIN <= r["avg_sentence_len"] <= THRESH_SENT_MAX):
153
- flags.append("sent_len")
154
- if r["en_ratio"] < THRESH_EN_RATIO:
155
- flags.append("non_en")
156
- flag_str = ",".join(flags) or "ok"
157
- lines.append(
158
- f"| {i} | {_cell(r['prompt'], 40)} | {_cell(r['response'], 60)} | "
159
- f"{r['distinct_2']:.2f} | {r['avg_sentence_len']:.1f} | "
160
- f"{r['en_ratio']:.2f} | {flag_str} |"
161
- )
162
- return "\n".join(lines)
163
-
164
-
165
- # ---------------------------------------------------------------------------
166
- # CLI
167
- # ---------------------------------------------------------------------------
168
-
169
- def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
170
- p = argparse.ArgumentParser(description="HYDRA chat eval")
171
- p.add_argument("--ckpt", type=str, default=None, help="Checkpoint path.")
172
- p.add_argument("--sft", action="store_true", help="Prefer SFT checkpoint.")
173
- p.add_argument("--random", action="store_true", help="Use random weights.")
174
- p.add_argument("--max", dest="max_new_tokens", type=int, default=80)
175
- p.add_argument("--temp", dest="temperature", type=float, default=0.8)
176
- p.add_argument("--topk", dest="top_k", type=int, default=40)
177
- p.add_argument("--topp", dest="top_p", type=float, default=0.9)
178
- p.add_argument("--rep", dest="repetition_penalty", type=float, default=1.1)
179
- p.add_argument("--json", dest="json_out", type=str, default=None,
180
- help="Optional: dump raw results to this JSON path.")
181
- p.add_argument("--device", type=str, default=None)
182
- return p.parse_args(argv)
183
-
184
-
185
- def main(argv: list[str] | None = None) -> int:
186
- args = _parse_args(argv)
187
-
188
- if args.device:
189
- device = torch.device(args.device)
190
- elif torch.cuda.is_available():
191
- device = torch.device("cuda")
192
- else:
193
- device = torch.device("cpu")
194
-
195
- ckpt_path = None if args.random else resolve_checkpoint(args.ckpt, args.sft)
196
-
197
- t0 = time.time()
198
- model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
199
- dt_load = time.time() - t0
200
- print(f"[chat_eval] Loaded in {dt_load:.1f}s ckpt={meta['ckpt']}")
201
-
202
- from prepare import MAX_SEQ_LEN
203
-
204
- rows: list[dict] = []
205
- t_gen = time.time()
206
- for i, prompt in enumerate(PROMPTS, 1):
207
- t_start = time.time()
208
- try:
209
- resp = _run_one(
210
- model, tokenizer, prompt,
211
- max_new_tokens=args.max_new_tokens,
212
- device=device,
213
- max_seq_len=MAX_SEQ_LEN,
214
- temperature=args.temperature,
215
- top_k=args.top_k,
216
- top_p=args.top_p,
217
- repetition_penalty=args.repetition_penalty,
218
- )
219
- err = None
220
- except Exception as e: # noqa: BLE001 β€” eval must not abort mid-prompt.
221
- resp = ""
222
- err = repr(e)
223
- print(f"[chat_eval] prompt {i} failed: {err}", file=sys.stderr)
224
-
225
- rows.append({
226
- "prompt": prompt,
227
- "response": resp,
228
- "distinct_2": distinct_2(resp),
229
- "avg_sentence_len": avg_sentence_len(resp),
230
- "en_ratio": english_char_ratio(resp),
231
- "latency_s": round(time.time() - t_start, 2),
232
- "error": err,
233
- })
234
- print(f"[chat_eval] {i:2d}/{len(PROMPTS)} {rows[-1]['latency_s']:.1f}s {resp!r}")
235
-
236
- dt_gen = time.time() - t_gen
237
-
238
- print()
239
- print("## HYDRA chat_eval results")
240
- print(f"- checkpoint: `{meta['ckpt']}`")
241
- if meta.get("step") is not None:
242
- print(f"- step: {meta['step']}")
243
- if meta.get("val_bpb") is not None:
244
- print(f"- val_bpb: {meta['val_bpb']}")
245
- print(f"- prompts: {len(PROMPTS)}")
246
- print(f"- load: {dt_load:.1f}s generation: {dt_gen:.1f}s")
247
- print()
248
- print(_render_markdown(rows))
249
- print()
250
-
251
- # Summary heuristics
252
- any_empty = sum(1 for r in rows if not r["response"])
253
- any_error = sum(1 for r in rows if r["error"])
254
- mean_d2 = sum(r["distinct_2"] for r in rows) / max(1, len(rows))
255
- mean_en = sum(r["en_ratio"] for r in rows) / max(1, len(rows))
256
-
257
- print("### Aggregates")
258
- print(f"- empty responses: {any_empty}/{len(rows)}")
259
- print(f"- generation errors: {any_error}/{len(rows)}")
260
- print(f"- mean distinct-2: {mean_d2:.3f} (target > {THRESH_DISTINCT_2})")
261
- print(f"- mean en_ratio: {mean_en:.3f} (target > {THRESH_EN_RATIO})")
262
- print()
263
- print("_Quality at this model scale (~7.5M params) is NOT expected to meet thresholds; "
264
- "this eval verifies the chat interface, not dialogue coherence._")
265
-
266
- if args.json_out:
267
- out = {
268
- "meta": meta,
269
- "settings": {
270
- "max_new_tokens": args.max_new_tokens,
271
- "temperature": args.temperature,
272
- "top_k": args.top_k,
273
- "top_p": args.top_p,
274
- "repetition_penalty": args.repetition_penalty,
275
- },
276
- "rows": rows,
277
- "aggregates": {
278
- "empty": any_empty,
279
- "errors": any_error,
280
- "mean_distinct_2": mean_d2,
281
- "mean_en_ratio": mean_en,
282
- "load_s": dt_load,
283
- "gen_s": dt_gen,
284
- },
285
- }
286
- Path(args.json_out).write_text(json.dumps(out, indent=2))
287
- print(f"[chat_eval] JSON written to {args.json_out}")
288
-
289
- # Exit 0 if we loaded and generated *something* for each prompt (even if
290
- # quality was poor). Exit 1 only on load failure (caught by main's exception
291
- # propagation) or if ALL prompts returned empty strings β€” that signals a
292
- # broken generation loop, not poor quality.
293
- if any_empty == len(rows):
294
- print("[chat_eval] ALL prompts returned empty β€” generation loop is broken.", file=sys.stderr)
295
- return 1
296
- return 0
297
-
298
-
299
- if __name__ == "__main__":
300
- sys.exit(main())
 
1
+ """Non-interactive chat eval for HYDRA.
2
+
3
+ Runs a fixed set of prompts through the same chat template that `chat.py`
4
+ uses, prints a markdown table with the response and coherence heuristics.
5
+
6
+ Usage:
7
+ python scripts/chat_eval.py # auto-select checkpoint
8
+ python scripts/chat_eval.py --ckpt PATH
9
+ python scripts/chat_eval.py --random
10
+ python scripts/chat_eval.py --json out.json # also dump raw results
11
+ python scripts/chat_eval.py --max 80 # cap new tokens per prompt
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ import json
18
+ import os
19
+ import re
20
+ import sys
21
+ import time
22
+ from pathlib import Path
23
+
24
+ _REPO_ROOT = Path(__file__).resolve().parent.parent
25
+ if str(_REPO_ROOT) not in sys.path:
26
+ sys.path.insert(0, str(_REPO_ROOT))
27
+
28
+ import torch # noqa: E402
29
+
30
+ from scripts.chat import ( # noqa: E402
31
+ ASSISTANT_TAG, END_TAG, USER_TAG, build_prompt,
32
+ generate_stream, load_model_and_tokenizer, resolve_checkpoint,
33
+ )
34
+
35
+
36
+ PROMPTS: list[str] = [
37
+ # Factual
38
+ "What is the capital of France?",
39
+ "Who wrote Romeo and Juliet?",
40
+ "What is 2 plus 2?",
41
+ "What color is the sky on a clear day?",
42
+ # Completion
43
+ "Once upon a time",
44
+ "The cat sat on the",
45
+ "In a hole in the ground there lived",
46
+ # Instruction
47
+ "Write one short sentence about rain.",
48
+ "List three animals.",
49
+ "Define the word 'library'.",
50
+ # Conversational
51
+ "Hello, how are you?",
52
+ "Tell me a joke.",
53
+ # Creative
54
+ "Describe a sunset in one line.",
55
+ "Give me a name for a pet robot.",
56
+ "What is the meaning of friendship?",
57
+ ]
58
+
59
+ # Heuristic thresholds (printed, not enforced as pass/fail).
60
+ THRESH_DISTINCT_2 = 0.30
61
+ THRESH_SENT_MIN = 5
62
+ THRESH_SENT_MAX = 30
63
+ THRESH_EN_RATIO = 0.95
64
+
65
+
66
+ # ---------------------------------------------------------------------------
67
+ # Coherence heuristics
68
+ # ---------------------------------------------------------------------------
69
+
70
+ def _tokens(text: str) -> list[str]:
71
+ return re.findall(r"[A-Za-z0-9']+", text)
72
+
73
+
74
+ def distinct_2(text: str) -> float:
75
+ toks = _tokens(text)
76
+ if len(toks) < 2:
77
+ return 0.0
78
+ bigrams = [(toks[i], toks[i + 1]) for i in range(len(toks) - 1)]
79
+ return len(set(bigrams)) / max(1, len(bigrams))
80
+
81
+
82
+ def avg_sentence_len(text: str) -> float:
83
+ sents = re.split(r"[.!?]+", text)
84
+ lens = [len(_tokens(s)) for s in sents if _tokens(s)]
85
+ if not lens:
86
+ return 0.0
87
+ return sum(lens) / len(lens)
88
+
89
+
90
+ def english_char_ratio(text: str) -> float:
91
+ if not text:
92
+ return 0.0
93
+ allowed = 0
94
+ for c in text:
95
+ if c.isalnum() or c.isspace() or c in ".,!?;:'\"-()[]{}/\\*#@&%+=_<>|$":
96
+ allowed += 1
97
+ return allowed / len(text)
98
+
99
+
100
+ # ---------------------------------------------------------------------------
101
+ # Runner
102
+ # ---------------------------------------------------------------------------
103
+
104
+ def _run_one(model, tokenizer, prompt: str, *, max_new_tokens: int, device: torch.device,
105
+ max_seq_len: int, temperature: float, top_k: int, top_p: float,
106
+ repetition_penalty: float) -> str:
107
+ prompt_text = build_prompt(system="", history=[], user_msg=prompt)
108
+ prompt_ids = tokenizer.encode(prompt_text)
109
+
110
+ stream = generate_stream(
111
+ model, tokenizer, prompt_ids,
112
+ max_new_tokens=max_new_tokens,
113
+ temperature=temperature,
114
+ top_k=top_k,
115
+ top_p=top_p,
116
+ repetition_penalty=repetition_penalty,
117
+ stop_strings=(END_TAG,),
118
+ max_seq_len=max_seq_len,
119
+ device=device,
120
+ )
121
+ collected: list[str] = []
122
+ try:
123
+ while True:
124
+ collected.append(next(stream))
125
+ except StopIteration as stop:
126
+ if stop.value is not None:
127
+ text = stop.value
128
+ else:
129
+ text = "".join(collected)
130
+
131
+ if END_TAG in text:
132
+ text = text.split(END_TAG, 1)[0]
133
+ return text.strip()
134
+
135
+
136
+ def _render_markdown(rows: list[dict]) -> str:
137
+ lines = [
138
+ "| # | Prompt | Response | dist-2 | sent_len | en_ratio | flags |",
139
+ "|---|--------|----------|--------|----------|----------|-------|",
140
+ ]
141
+
142
+ def _cell(s: str, n: int = 60) -> str:
143
+ s = s.replace("|", "\\|").replace("\n", " ")
144
+ if len(s) > n:
145
+ s = s[: n - 1] + "…"
146
+ return s
147
+
148
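+ # Flag rows that trip a heuristic threshold; "ok" means none tripped.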
+ for i, r in enumerate(rows, 1):
149
+ flags = []
150
+ if r["distinct_2"] < THRESH_DISTINCT_2:
151
+ flags.append("repetitive")
152
+ if not (THRESH_SENT_MIN <= r["avg_sentence_len"] <= THRESH_SENT_MAX):
153
+ flags.append("sent_len")
154
+ if r["en_ratio"] < THRESH_EN_RATIO:
155
+ flags.append("non_en")
156
+ flag_str = ",".join(flags) or "ok"
157
+ lines.append(
158
+ f"| {i} | {_cell(r['prompt'], 40)} | {_cell(r['response'], 60)} | "
159
+ f"{r['distinct_2']:.2f} | {r['avg_sentence_len']:.1f} | "
160
+ f"{r['en_ratio']:.2f} | {flag_str} |"
161
+ )
162
+ return "\n".join(lines)
163
+
164
+
165
+ # ---------------------------------------------------------------------------
166
+ # CLI
167
+ # ---------------------------------------------------------------------------
168
+
169
+ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
170
+ p = argparse.ArgumentParser(description="HYDRA chat eval")
171
+ p.add_argument("--ckpt", type=str, default=None, help="Checkpoint path.")
172
+ p.add_argument("--sft", action="store_true", help="Prefer SFT checkpoint.")
173
+ p.add_argument("--random", action="store_true", help="Use random weights.")
174
+ p.add_argument("--max", dest="max_new_tokens", type=int, default=80)
175
+ p.add_argument("--temp", dest="temperature", type=float, default=0.8)
176
+ p.add_argument("--topk", dest="top_k", type=int, default=40)
177
+ p.add_argument("--topp", dest="top_p", type=float, default=0.9)
178
+ p.add_argument("--rep", dest="repetition_penalty", type=float, default=1.1)
179
+ p.add_argument("--json", dest="json_out", type=str, default=None,
180
+ help="Optional: dump raw results to this JSON path.")
181
+ p.add_argument("--device", type=str, default=None)
182
+ return p.parse_args(argv)
183
+
184
+
185
+ def main(argv: list[str] | None = None) -> int:
186
+ args = _parse_args(argv)
187
+
188
+ if args.device:
189
+ device = torch.device(args.device)
190
+ elif torch.cuda.is_available():
191
+ device = torch.device("cuda")
192
+ else:
193
+ device = torch.device("cpu")
194
+
195
+ ckpt_path = None if args.random else resolve_checkpoint(args.ckpt, args.sft)
196
+
197
+ t0 = time.time()
198
+ model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
199
+ dt_load = time.time() - t0
200
+ print(f"[chat_eval] Loaded in {dt_load:.1f}s ckpt={meta['ckpt']}")
201
+
202
+ from prepare import MAX_SEQ_LEN
203
+
204
+ rows: list[dict] = []
205
+ t_gen = time.time()
206
+ for i, prompt in enumerate(PROMPTS, 1):
207
+ t_start = time.time()
208
+ try:
209
+ resp = _run_one(
210
+ model, tokenizer, prompt,
211
+ max_new_tokens=args.max_new_tokens,
212
+ device=device,
213
+ max_seq_len=MAX_SEQ_LEN,
214
+ temperature=args.temperature,
215
+ top_k=args.top_k,
216
+ top_p=args.top_p,
217
+ repetition_penalty=args.repetition_penalty,
218
+ )
219
+ err = None
220
+ except Exception as e: # noqa: BLE001 - eval must not abort mid-prompt.
221
+ resp = ""
222
+ err = repr(e)
223
+ print(f"[chat_eval] prompt {i} failed: {err}", file=sys.stderr)
224
+
225
+ rows.append({
226
+ "prompt": prompt,
227
+ "response": resp,
228
+ "distinct_2": distinct_2(resp),
229
+ "avg_sentence_len": avg_sentence_len(resp),
230
+ "en_ratio": english_char_ratio(resp),
231
+ "latency_s": round(time.time() - t_start, 2),
232
+ "error": err,
233
+ })
234
+ print(f"[chat_eval] {i:2d}/{len(PROMPTS)} {rows[-1]['latency_s']:.1f}s {resp!r}")
235
+
236
+ dt_gen = time.time() - t_gen
237
+
238
+ print()
239
+ print("## HYDRA chat_eval results")
240
+ print(f"- checkpoint: `{meta['ckpt']}`")
241
+ if meta.get("step") is not None:
242
+ print(f"- step: {meta['step']}")
243
+ if meta.get("val_bpb") is not None:
244
+ print(f"- val_bpb: {meta['val_bpb']}")
245
+ print(f"- prompts: {len(PROMPTS)}")
246
+ print(f"- load: {dt_load:.1f}s generation: {dt_gen:.1f}s")
247
+ print()
248
+ print(_render_markdown(rows))
249
+ print()
250
+
251
+ # Summary heuristics
252
+ any_empty = sum(1 for r in rows if not r["response"])
253
+ any_error = sum(1 for r in rows if r["error"])
254
+ mean_d2 = sum(r["distinct_2"] for r in rows) / max(1, len(rows))
255
+ mean_en = sum(r["en_ratio"] for r in rows) / max(1, len(rows))
256
+
257
+ print("### Aggregates")
258
+ print(f"- empty responses: {any_empty}/{len(rows)}")
259
+ print(f"- generation errors: {any_error}/{len(rows)}")
260
+ print(f"- mean distinct-2: {mean_d2:.3f} (target > {THRESH_DISTINCT_2})")
261
+ print(f"- mean en_ratio: {mean_en:.3f} (target > {THRESH_EN_RATIO})")
262
+ print()
263
+ print("_Quality at this model scale (~7.5M params) is NOT expected to meet thresholds; "
264
+ "this eval verifies the chat interface, not dialogue coherence._")
265
+
266
+ if args.json_out:
267
+ out = {
268
+ "meta": meta,
269
+ "settings": {
270
+ "max_new_tokens": args.max_new_tokens,
271
+ "temperature": args.temperature,
272
+ "top_k": args.top_k,
273
+ "top_p": args.top_p,
274
+ "repetition_penalty": args.repetition_penalty,
275
+ },
276
+ "rows": rows,
277
+ "aggregates": {
278
+ "empty": any_empty,
279
+ "errors": any_error,
280
+ "mean_distinct_2": mean_d2,
281
+ "mean_en_ratio": mean_en,
282
+ "load_s": dt_load,
283
+ "gen_s": dt_gen,
284
+ },
285
+ }
286
+ Path(args.json_out).write_text(json.dumps(out, indent=2))
287
+ print(f"[chat_eval] JSON written to {args.json_out}")
288
+
289
+ # Exit 0 if we loaded and generated *something* for each prompt (even if
290
+ # quality was poor). Exit nonzero only on load failure (the exception simply
291
+ # propagates out of main) or if ALL prompts returned empty strings - that
292
+ # signals a broken generation loop, not poor quality.
293
+ if any_empty == len(rows):
294
+ print("[chat_eval] ALL prompts returned empty β€” generation loop is broken.", file=sys.stderr)
295
+ return 1
296
+ return 0
297
+
298
+
299
+ if __name__ == "__main__":
300
+ sys.exit(main())
overlay/scripts/compile_debug.py CHANGED
@@ -1,213 +1,213 @@
1
- """Diagnostic script for torch.compile deadlock after ~500 steps.
2
-
3
- F17 investigation: validates that the _compiled_core / forward split
4
- fixes the deadlock by running forward+backward loops with compile on.
5
-
6
- Usage:
7
- LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \
8
- HYDRA_TIME_BUDGET=30 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=16384 \
9
- HYDRA_HTM_LEARN_EVERY=4 HYDRA_HESTIA_INTERVAL=9999 \
10
- .venv/bin/python -u scripts/compile_debug.py [mode]
11
-
12
- Modes:
13
- eager - no compile (baseline)
14
- model_only - compile model _compiled_core only
15
- muon_only - compile muon step only
16
- both - compile both (default)
17
- """
18
-
19
- from __future__ import annotations
20
-
21
- import gc
22
- import os
23
- import signal
24
- import sys
25
- import threading
26
- import time
27
-
28
- # Set CUDA env before torch import
29
- os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
30
- os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
31
-
32
- import torch
33
- import torch.nn as nn
34
- import torch.nn.functional as F
35
-
36
- # -------------------------------------------------------------------------
37
- # Config
38
- # -------------------------------------------------------------------------
39
- MAX_STEPS = 800
40
- WATCHDOG_TIMEOUT_S = 20 # kill if no progress for this many seconds
41
- BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "8"))
42
- SEQ_LEN = 2048
43
- VOCAB_SIZE = 8192
44
-
45
-
46
- # -------------------------------------------------------------------------
47
- # Watchdog thread: kills process if no progress
48
- # -------------------------------------------------------------------------
49
- _last_progress = time.time()
50
- _watchdog_armed = True
51
-
52
- def _watchdog_fn():
53
- global _last_progress, _watchdog_armed
54
- while _watchdog_armed:
55
- time.sleep(1.0)
56
- elapsed = time.time() - _last_progress
57
- if elapsed > WATCHDOG_TIMEOUT_S:
58
- print(f"\n*** WATCHDOG: no progress for {elapsed:.1f}s β€” DEADLOCK DETECTED ***",
59
- flush=True)
60
- _dump_diagnostics()
61
- os.kill(os.getpid(), signal.SIGTERM)
62
- return
63
-
64
- def _dump_diagnostics():
65
- """Dump CUDA/dynamo state at deadlock time."""
66
- try:
67
- stats = torch.cuda.memory_stats()
68
- print(f" alloc_retries: {stats.get('num_alloc_retries', 'N/A')}")
69
- print(f" allocated_bytes: {stats.get('allocated_bytes.all.current', 0) / 1e6:.1f} MB")
70
- print(f" reserved_bytes: {stats.get('reserved_bytes.all.current', 0) / 1e6:.1f} MB")
71
- print(f" num_ooms: {stats.get('num_ooms', 0)}")
72
- except Exception as e:
73
- print(f" (memory_stats failed: {e})")
74
-
75
- try:
76
- import torch._dynamo.utils as du
77
- print(f" dynamo counters: {dict(du.counters)}")
78
- except Exception as e:
79
- print(f" (dynamo counters failed: {e})")
80
-
81
-
82
- def tick():
83
- global _last_progress
84
- _last_progress = time.time()
85
-
86
-
87
- # -------------------------------------------------------------------------
88
- # Test
89
- # -------------------------------------------------------------------------
90
- def run_test(mode: str) -> dict:
91
- """Run forward+backward loop with specified compile config."""
92
- print(f"\n{'='*70}")
93
- print(f"TEST MODE: {mode}")
94
- print(f"{'='*70}", flush=True)
95
-
96
- compile_model = mode in ("model_only", "both")
97
- compile_muon = mode in ("muon_only", "both")
98
-
99
- os.environ["HYDRA_MODEL_COMPILE"] = "1" if compile_model else "0"
100
- os.environ["HYDRA_MUON_COMPILE"] = "1" if compile_muon else "0"
101
- os.environ["HYDRA_ASYNC_POSTPROCESS"] = "0"
102
- os.environ["HYDRA_HESTIA_INTERVAL"] = "9999"
103
- os.environ["HYDRA_HTM_LEARN_EVERY"] = "4"
104
-
105
- # Clear cached modules for fresh env var reads
106
- for mod_name in list(sys.modules.keys()):
107
- if mod_name.startswith("hydra."):
108
- del sys.modules[mod_name]
109
-
110
- torch._dynamo.reset()
111
- torch.cuda.empty_cache()
112
- torch.cuda.reset_peak_memory_stats()
113
- gc.collect()
114
-
115
- from hydra.model import PostSemClawModel
116
- from hydra.config import PostSemClawConfig
117
-
118
- device = torch.device("cuda")
119
- config = PostSemClawConfig(
120
- d_model=256, n_layer=4, d_state=64, headdim=32, expand=2,
121
- vocab_size=VOCAB_SIZE, sequence_len=SEQ_LEN,
122
- )
123
-
124
- with torch.device("meta"):
125
- model = PostSemClawModel(config)
126
- model.to_empty(device=device)
127
- model.init_weights()
128
-
129
- optimizer = model.setup_optimizer()
130
- autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
131
-
132
- result = {"mode": mode, "max_step": 0, "tps_samples": []}
133
- alloc_retries_prev = 0
134
-
135
- tick()
136
-
137
- for step in range(MAX_STEPS):
138
- t0 = time.time()
139
-
140
- x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
141
- y = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
142
-
143
- with autocast_ctx:
144
- loss = model(x, y)
145
- loss.backward()
146
-
147
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
148
- optimizer.step()
149
- model.zero_grad(set_to_none=True)
150
-
151
- torch.cuda.synchronize()
152
- dt = time.time() - t0
153
- tps = int(BATCH_SIZE * SEQ_LEN / dt)
154
-
155
- tick()
156
-
157
- stats = torch.cuda.memory_stats()
158
- retries = stats.get("num_alloc_retries", 0)
159
- retry_delta = retries - alloc_retries_prev
160
- alloc_retries_prev = retries
161
-
162
- result["max_step"] = step
163
-
164
- if step % 50 == 0 or retry_delta > 0 or step < 3:
165
- alloc_mb = stats.get("allocated_bytes.all.current", 0) / 1e6
166
- print(
167
- f" step={step:04d} tps={tps:6d} dt={dt*1000:.0f}ms "
168
- f"alloc={alloc_mb:.0f}MB retries={retries}",
169
- flush=True,
170
- )
171
- result["tps_samples"].append((step, tps))
172
-
173
- result["completed"] = True
174
- print(f"\n COMPLETED: {MAX_STEPS} steps, mode={mode}", flush=True)
175
- return result
176
-
177
-
178
- def main():
179
- print(f"torch: {torch.__version__} CUDA: {torch.version.cuda}")
180
- print(f"GPU: {torch.cuda.get_device_name()}")
181
- print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
182
- print(f"Steps: {MAX_STEPS} Watchdog: {WATCHDOG_TIMEOUT_S}s")
183
-
184
- wd = threading.Thread(target=_watchdog_fn, daemon=True)
185
- wd.start()
186
-
187
- modes = sys.argv[1:] if len(sys.argv) > 1 else ["both"]
188
- results = []
189
-
190
- for mode in modes:
191
- try:
192
- r = run_test(mode)
193
- except SystemExit:
194
- print(f"\n DEADLOCK/KILLED mode={mode}", flush=True)
195
- r = {"mode": mode, "completed": False, "max_step": "?"}
196
- except Exception as e:
197
- print(f"\n ERROR mode={mode}: {e}", flush=True)
198
- r = {"mode": mode, "completed": False, "error": str(e)}
199
- results.append(r)
200
-
201
- print(f"\n{'='*70}")
202
- print("SUMMARY")
203
- print(f"{'='*70}")
204
- for r in results:
205
- status = "PASS" if r.get("completed") else "FAIL"
206
- print(f" {r['mode']:20s}: {status} (step {r.get('max_step', '?')})")
207
-
208
- global _watchdog_armed
209
- _watchdog_armed = False
210
-
211
-
212
- if __name__ == "__main__":
213
- main()
 
1
+ """Diagnostic script for torch.compile deadlock after ~500 steps.
2
+
3
+ F17 investigation: validates that the _compiled_core / forward split
4
+ fixes the deadlock by running forward+backward loops with compile on.
5
+
6
+ Usage:
7
+ LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \
8
+ HYDRA_TIME_BUDGET=30 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=16384 \
9
+ HYDRA_HTM_LEARN_EVERY=4 HYDRA_HESTIA_INTERVAL=9999 \
10
+ .venv/bin/python -u scripts/compile_debug.py [mode]
11
+
12
+ Modes:
13
+ eager - no compile (baseline)
14
+ model_only - compile model _compiled_core only
15
+ muon_only - compile muon step only
16
+ both - compile both (default)
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import gc
22
+ import os
23
+ import signal
24
+ import sys
25
+ import threading
26
+ import time
27
+
28
+ # Set CUDA env before torch import
29
+ os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
30
+ os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+
36
+ # -------------------------------------------------------------------------
37
+ # Config
38
+ # -------------------------------------------------------------------------
39
+ MAX_STEPS = 800
40
+ WATCHDOG_TIMEOUT_S = 20 # kill if no progress for this many seconds
41
+ BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "8"))
42
+ SEQ_LEN = 2048
43
+ VOCAB_SIZE = 8192
44
+
45
+
46
+ # -------------------------------------------------------------------------
47
+ # Watchdog thread: kills process if no progress
48
+ # -------------------------------------------------------------------------
49
+ _last_progress = time.time()
50
+ _watchdog_armed = True
51
+
52
+ def _watchdog_fn():
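+ """Poll once per second; on a stall, dump diagnostics and SIGTERM the process."""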
53
+ global _last_progress, _watchdog_armed
54
+ while _watchdog_armed:
55
+ time.sleep(1.0)
56
+ elapsed = time.time() - _last_progress
57
+ if elapsed > WATCHDOG_TIMEOUT_S:
58
+ print(f"\n*** WATCHDOG: no progress for {elapsed:.1f}s β€” DEADLOCK DETECTED ***",
59
+ flush=True)
60
+ _dump_diagnostics()
61
+ os.kill(os.getpid(), signal.SIGTERM)
62
+ return
63
+
64
+ def _dump_diagnostics():
65
+ """Dump CUDA/dynamo state at deadlock time."""
66
+ try:
67
+ stats = torch.cuda.memory_stats()
68
+ print(f" alloc_retries: {stats.get('num_alloc_retries', 'N/A')}")
69
+ print(f" allocated_bytes: {stats.get('allocated_bytes.all.current', 0) / 1e6:.1f} MB")
70
+ print(f" reserved_bytes: {stats.get('reserved_bytes.all.current', 0) / 1e6:.1f} MB")
71
+ print(f" num_ooms: {stats.get('num_ooms', 0)}")
72
+ except Exception as e:
73
+ print(f" (memory_stats failed: {e})")
74
+
75
+ try:
76
+ import torch._dynamo.utils as du
77
+ print(f" dynamo counters: {dict(du.counters)}")
78
+ except Exception as e:
79
+ print(f" (dynamo counters failed: {e})")
80
+
81
+
82
+ def tick():
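+ """Record forward progress; resets the watchdog's stall timer."""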
83
+ global _last_progress
84
+ _last_progress = time.time()
85
+
86
+
87
+ # -------------------------------------------------------------------------
88
+ # Test
89
+ # -------------------------------------------------------------------------
90
+ def run_test(mode: str) -> dict:
91
+ """Run forward+backward loop with specified compile config."""
92
+ print(f"\n{'='*70}")
93
+ print(f"TEST MODE: {mode}")
94
+ print(f"{'='*70}", flush=True)
95
+
96
+ compile_model = mode in ("model_only", "both")
97
+ compile_muon = mode in ("muon_only", "both")
98
+
99
+ os.environ["HYDRA_MODEL_COMPILE"] = "1" if compile_model else "0"
100
+ os.environ["HYDRA_MUON_COMPILE"] = "1" if compile_muon else "0"
101
+ os.environ["HYDRA_ASYNC_POSTPROCESS"] = "0"
102
+ os.environ["HYDRA_HESTIA_INTERVAL"] = "9999"
103
+ os.environ["HYDRA_HTM_LEARN_EVERY"] = "4"
104
+
105
+ # Clear cached modules for fresh env var reads
106
+ for mod_name in list(sys.modules.keys()):
107
+ if mod_name.startswith("hydra."):
108
+ del sys.modules[mod_name]
109
+
110
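+ # Drop compiled graphs and CUDA caches so each mode starts from a cold state.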
+ torch._dynamo.reset()
111
+ torch.cuda.empty_cache()
112
+ torch.cuda.reset_peak_memory_stats()
113
+ gc.collect()
114
+
115
+ from hydra.model import PostSemClawModel
116
+ from hydra.config import PostSemClawConfig
117
+
118
+ device = torch.device("cuda")
119
+ config = PostSemClawConfig(
120
+ d_model=256, n_layer=4, d_state=64, headdim=32, expand=2,
121
+ vocab_size=VOCAB_SIZE, sequence_len=SEQ_LEN,
122
+ )
123
+
124
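+ # Build on the meta device (no allocation), then materialize on GPU and init.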
+ with torch.device("meta"):
125
+ model = PostSemClawModel(config)
126
+ model.to_empty(device=device)
127
+ model.init_weights()
128
+
129
+ optimizer = model.setup_optimizer()
130
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
131
+
132
+ result = {"mode": mode, "max_step": 0, "tps_samples": []}
133
+ alloc_retries_prev = 0
134
+
135
+ tick()
136
+
137
+ for step in range(MAX_STEPS):
138
+ t0 = time.time()
139
+
140
+ x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
141
+ y = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
142
+
143
+ with autocast_ctx:
144
+ loss = model(x, y)
145
+ loss.backward()
146
+
147
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
148
+ optimizer.step()
149
+ model.zero_grad(set_to_none=True)
150
+
151
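+ # Synchronize so dt measures the full step, not just the kernel launches.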
+ torch.cuda.synchronize()
152
+ dt = time.time() - t0
153
+ tps = int(BATCH_SIZE * SEQ_LEN / dt)
154
+
155
+ tick()
156
+
157
+ stats = torch.cuda.memory_stats()
158
+ retries = stats.get("num_alloc_retries", 0)
159
+ retry_delta = retries - alloc_retries_prev
160
+ alloc_retries_prev = retries
161
+
162
+ result["max_step"] = step
163
+
164
+ if step % 50 == 0 or retry_delta > 0 or step < 3:
165
+ alloc_mb = stats.get("allocated_bytes.all.current", 0) / 1e6
166
+ print(
167
+ f" step={step:04d} tps={tps:6d} dt={dt*1000:.0f}ms "
168
+ f"alloc={alloc_mb:.0f}MB retries={retries}",
169
+ flush=True,
170
+ )
171
+ result["tps_samples"].append((step, tps))
172
+
173
+ result["completed"] = True
174
+ print(f"\n COMPLETED: {MAX_STEPS} steps, mode={mode}", flush=True)
175
+ return result
176
+
177
+
178
+ def main():
179
+ print(f"torch: {torch.__version__} CUDA: {torch.version.cuda}")
180
+ print(f"GPU: {torch.cuda.get_device_name()}")
181
+ print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
182
+ print(f"Steps: {MAX_STEPS} Watchdog: {WATCHDOG_TIMEOUT_S}s")
183
+
184
+ wd = threading.Thread(target=_watchdog_fn, daemon=True)
185
+ wd.start()
186
+
187
+ modes = sys.argv[1:] if len(sys.argv) > 1 else ["both"]
188
+ results = []
189
+
190
+ for mode in modes:
191
+ try:
192
+ r = run_test(mode)
193
+ except SystemExit:
194
+ print(f"\n DEADLOCK/KILLED mode={mode}", flush=True)
195
+ r = {"mode": mode, "completed": False, "max_step": "?"}
196
+ except Exception as e:
197
+ print(f"\n ERROR mode={mode}: {e}", flush=True)
198
+ r = {"mode": mode, "completed": False, "error": str(e)}
199
+ results.append(r)
200
+
201
+ print(f"\n{'='*70}")
202
+ print("SUMMARY")
203
+ print(f"{'='*70}")
204
+ for r in results:
205
+ status = "PASS" if r.get("completed") else "FAIL"
206
+ print(f" {r['mode']:20s}: {status} (step {r.get('max_step', '?')})")
207
+
208
+ global _watchdog_armed
209
+ _watchdog_armed = False
210
+
211
+
212
+ if __name__ == "__main__":
213
+ main()
overlay/scripts/dataset_audit.py CHANGED
@@ -1,241 +1,241 @@
1
- """
2
- Dataset audit - diagnostic tool for HYDRA's pretraining corpus.
3
-
4
- Usage:
5
- python scripts/dataset_audit.py # Quick audit
6
- python scripts/dataset_audit.py --sample 10 # Sample 10 shards for token counts
7
- python scripts/dataset_audit.py --full # Full tokenize of every shard (slow)
8
-
9
- Reports:
10
- - Shard count, total disk usage
11
- - Estimated total tokens (character-based + tokenized sample)
12
- - Training budget sufficiency vs 12h @ 65k tok/s = 2.8B token target
13
- - Document diversity sample
14
- - Warnings about shard ordering, shuffle, and streaming behavior
15
- """
16
- from __future__ import annotations
17
-
18
- import argparse
19
- import os
20
- import sys
21
- import time
22
- from pathlib import Path
23
-
24
- import pyarrow.parquet as pq
25
-
26
- # Resolve repo root so the script works regardless of CWD.
27
- REPO_ROOT = Path(__file__).resolve().parent.parent
28
- sys.path.insert(0, str(REPO_ROOT))
29
-
30
- from prepare import ( # noqa: E402
31
- DATA_DIR,
32
- MAX_SHARD,
33
- TOKENIZER_DIR,
34
- VAL_FILENAME,
35
- VAL_SHARD,
36
- )
37
-
38
- TARGET_TOKENS_12H = 2_800_000_000 # 65k tok/s * 12h * 3600s
39
- CHARS_PER_TOKEN_HEURISTIC = 4.0
40
-
41
-
42
- def human_bytes(n: int) -> str:
43
- for unit in ("B", "KB", "MB", "GB", "TB"):
44
- if n < 1024:
45
- return f"{n:.1f}{unit}"
46
- n /= 1024
47
- return f"{n:.1f}PB"
48
-
49
-
50
- def human_tokens(n: int | float) -> str:
51
- if n >= 1e9:
52
- return f"{n / 1e9:.2f}B"
53
- if n >= 1e6:
54
- return f"{n / 1e6:.1f}M"
55
- if n >= 1e3:
56
- return f"{n / 1e3:.1f}K"
57
- return f"{n:.0f}"
58
-
59
-
60
- def list_shards() -> tuple[list[Path], Path | None]:
61
- """Return (train_shards_sorted, val_shard_or_none)."""
62
- if not os.path.isdir(DATA_DIR):
63
- return [], None
64
- all_paths = sorted(Path(DATA_DIR).glob("shard_*.parquet"))
65
- val_path = Path(DATA_DIR) / VAL_FILENAME
66
- train = [p for p in all_paths if p.name != VAL_FILENAME]
67
- val = val_path if val_path.exists() else None
68
- return train, val
69
-
70
-
71
- def tokenized_sample(shard_path: Path, enc, row_groups: int = 5) -> tuple[int, int]:
72
- """Tokenize first N row groups of a shard. Returns (tokens, docs)."""
73
- pf = pq.ParquetFile(shard_path)
74
- tokens = 0
75
- docs = 0
76
- n = min(row_groups, pf.num_row_groups)
77
- for i in range(n):
78
- rg = pf.read_row_group(i)
79
- texts = rg.column("text").to_pylist()
80
- ids = enc.encode_ordinary_batch(texts, num_threads=8)
81
- tokens += sum(len(x) for x in ids)
82
- docs += len(texts)
83
- return tokens, docs, pf.num_row_groups
84
-
85
-
86
- def main() -> int:
87
- parser = argparse.ArgumentParser(description="Audit the HYDRA training corpus")
88
- parser.add_argument(
89
- "--sample",
90
- type=int,
91
- default=3,
92
- help="Number of shards to tokenize for token-count estimate",
93
- )
94
- parser.add_argument(
95
- "--full",
96
- action="store_true",
97
- help="Tokenize every shard (slow; gives exact total)",
98
- )
99
- args = parser.parse_args()
100
-
101
- print("=" * 72)
102
- print("HYDRA corpus audit")
103
- print("=" * 72)
104
- print(f"DATA_DIR: {DATA_DIR}")
105
- print(f"TOKENIZER_DIR: {TOKENIZER_DIR}")
106
- print(f"Source dataset: karpathy/climbmix-400b-shuffle")
107
- print(f"Max remote shard: {MAX_SHARD} (pinned val = shard_{VAL_SHARD:05d})")
108
- print()
109
-
110
- train_shards, val_shard = list_shards()
111
- if not train_shards:
112
- print("ERROR: no parquet shards found. Run `python prepare.py` first.")
113
- return 1
114
-
115
- total_disk = sum(p.stat().st_size for p in train_shards)
116
- val_disk = val_shard.stat().st_size if val_shard else 0
117
-
118
- print(f"Train shards: {len(train_shards)} ({train_shards[0].name} ... {train_shards[-1].name})")
119
- print(f"Val shard: {'present' if val_shard else 'MISSING'} ({VAL_FILENAME})")
120
- print(f"Disk (train): {human_bytes(total_disk)}")
121
- print(f"Disk (val): {human_bytes(val_disk)}")
122
- print()
123
-
124
- # Character-based pass (fast): count total chars in all shards.
125
- t0 = time.time()
126
- total_chars = 0
127
- total_docs = 0
128
- total_row_groups = 0
129
- for p in train_shards:
130
- pf = pq.ParquetFile(p)
131
- total_row_groups += pf.num_row_groups
132
- total_docs += pf.metadata.num_rows
133
- dt_meta = time.time() - t0
134
- print(f"Metadata scan: {len(train_shards)} shards in {dt_meta:.1f}s")
135
- print(f"Train documents: {total_docs:,}")
136
- print(f"Row groups: {total_row_groups:,}")
137
- print()
138
-
139
- # Tokenizer-based sampling.
140
- try:
141
- import pickle
142
-
143
- with open(os.path.join(TOKENIZER_DIR, "tokenizer.pkl"), "rb") as f:
144
- enc = pickle.load(f)
145
- print(f"Tokenizer vocab: {enc.n_vocab}")
146
- except FileNotFoundError:
147
- print("WARNING: tokenizer.pkl not found β€” skipping tokenized sample.")
148
- enc = None
149
-
150
- est_total_tokens = 0
151
- if enc is not None:
152
- if args.full:
153
- sample_shards = train_shards
154
- else:
155
- # Pick shards evenly across the range for a representative sample.
156
- n_sample = min(args.sample, len(train_shards))
157
- if n_sample == 1:
158
- sample_shards = [train_shards[0]]
159
- else:
160
- stride = max(1, len(train_shards) // n_sample)
161
- sample_shards = train_shards[::stride][:n_sample]
162
-
163
- t0 = time.time()
164
- sample_tokens = 0
165
- sample_docs = 0
166
- sample_row_groups = 0
167
- sample_shard_row_groups = 0
168
- print(f"Tokenizing sample: {len(sample_shards)} shards ...")
169
- for p in sample_shards:
170
- tok, docs, n_rg = tokenized_sample(p, enc, row_groups=5)
171
- sample_tokens += tok
172
- sample_docs += docs
173
- sample_row_groups += min(5, n_rg)
174
- sample_shard_row_groups += n_rg
175
- dt_tok = time.time() - t0
176
-
177
- tokens_per_rg = sample_tokens / max(sample_row_groups, 1)
178
- per_shard = tokens_per_rg * (sample_shard_row_groups / len(sample_shards))
179
- est_total_tokens = per_shard * len(train_shards)
180
-
181
- print(
182
- f"Sampled {sample_row_groups} row groups ({sample_docs:,} docs, "
183
- f"{sample_tokens:,} tokens) in {dt_tok:.1f}s"
184
- )
185
- print(f" tokens/row_group: {tokens_per_rg:,.0f}")
186
- print(f" tokens/shard: {per_shard:,.0f}")
187
- print(f" tokens/shard: {human_tokens(per_shard)}")
188
- else:
189
- # Fall back to character heuristic.
190
- per_shard_chars = total_disk / max(len(train_shards), 1)
191
- # Parquet compression ratio ~3x for text; decompressed ~3 * file size.
192
- # Chars per token heuristic: ~4.
193
- est_total_tokens = (total_disk * 3.0) / CHARS_PER_TOKEN_HEURISTIC
194
-
195
- print()
196
- print("-" * 72)
197
- print("Token budget analysis")
198
- print("-" * 72)
199
- print(f"Estimated total train tokens: {human_tokens(est_total_tokens)} "
200
- f"({est_total_tokens:,.0f})")
201
- print(f"12h @ 65k tok/s target: {human_tokens(TARGET_TOKENS_12H)}")
202
- ratio = est_total_tokens / TARGET_TOKENS_12H if TARGET_TOKENS_12H else 0
203
- if ratio >= 1.0:
204
- print(f" Ratio: {ratio:.1f}x ({'SUFFICIENT' if ratio >= 1.2 else 'TIGHT'})")
205
- else:
206
- print(f" Ratio: {ratio:.2f}x INSUFFICIENT β€” need {1 - ratio:.0%} more")
207
- print()
208
-
209
- # Warnings about the dataloader behavior.
210
- print("-" * 72)
211
- print("Dataloader behavior (prepare.py::_document_batches)")
212
- print("-" * 72)
213
- print("+ Infinite streaming: while True around shard list (no StopIteration)")
214
- print("+ Streams per shard, never loads full corpus into RAM")
215
- print("+ BOS-aligned best-fit packing gives document-level buffer shuffling")
216
- print("- Cross-shard order is LEXICOGRAPHIC and FIXED on every epoch")
217
- print("- Row groups / rows WITHIN a shard are read in fixed order")
218
- print(" (climbmix-400b-shuffle is pre-shuffled at source, mitigating this)")
219
- print()
220
-
221
- # Quick content diversity peek.
222
- if train_shards:
223
- print("-" * 72)
224
- print("Content sample (shard 0, first 3 docs)")
225
- print("-" * 72)
226
- pf = pq.ParquetFile(train_shards[0])
227
- rg = pf.read_row_group(0)
228
- texts = rg.column("text").to_pylist()
229
- for i, idx in enumerate([0, len(texts) // 2, len(texts) - 1]):
230
- if idx < len(texts):
231
- snippet = texts[idx][:160].replace("\n", " ")
232
- print(f" [{i}] len={len(texts[idx])}: {snippet!r}")
233
- print()
234
-
235
- print("=" * 72)
236
- print("Done.")
237
- return 0
238
-
239
-
240
- if __name__ == "__main__":
241
- raise SystemExit(main())
 
1
+ """
2
+ Dataset audit - diagnostic tool for HYDRA's pretraining corpus.
3
+
4
+ Usage:
5
+ python scripts/dataset_audit.py # Quick audit
6
+ python scripts/dataset_audit.py --sample 10 # Sample 10 shards for token counts
7
+ python scripts/dataset_audit.py --full # Full tokenize of every shard (slow)
8
+
9
+ Reports:
10
+ - Shard count, total disk usage
11
+ - Estimated total tokens (character-based + tokenized sample)
12
+ - Training budget sufficiency vs 12h @ 65k tok/s = 2.8B token target
13
+ - Document diversity sample
14
+ - Warnings about shard ordering, shuffle, and streaming behavior
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import os
20
+ import sys
21
+ import time
22
+ from pathlib import Path
23
+
24
+ import pyarrow.parquet as pq
25
+
26
+ # Resolve repo root so the script works regardless of CWD.
27
+ REPO_ROOT = Path(__file__).resolve().parent.parent
28
+ sys.path.insert(0, str(REPO_ROOT))
29
+
30
+ from prepare import ( # noqa: E402
31
+ DATA_DIR,
32
+ MAX_SHARD,
33
+ TOKENIZER_DIR,
34
+ VAL_FILENAME,
35
+ VAL_SHARD,
36
+ )
37
+
38
+ TARGET_TOKENS_12H = 2_800_000_000 # 65k tok/s * 12h * 3600s
39
+ CHARS_PER_TOKEN_HEURISTIC = 4.0
40
+
41
+
42
+ def human_bytes(n: int) -> str:
43
+ for unit in ("B", "KB", "MB", "GB", "TB"):
44
+ if n < 1024:
45
+ return f"{n:.1f}{unit}"
46
+ n /= 1024
47
+ return f"{n:.1f}PB"
48
+
49
+
50
+ def human_tokens(n: int | float) -> str:
51
+ if n >= 1e9:
52
+ return f"{n / 1e9:.2f}B"
53
+ if n >= 1e6:
54
+ return f"{n / 1e6:.1f}M"
55
+ if n >= 1e3:
56
+ return f"{n / 1e3:.1f}K"
57
+ return f"{n:.0f}"
58
+
59
+
60
+ def list_shards() -> tuple[list[Path], Path | None]:
61
+ """Return (train_shards_sorted, val_shard_or_none)."""
62
+ if not os.path.isdir(DATA_DIR):
63
+ return [], None
64
+ all_paths = sorted(Path(DATA_DIR).glob("shard_*.parquet"))
65
+ val_path = Path(DATA_DIR) / VAL_FILENAME
66
+ train = [p for p in all_paths if p.name != VAL_FILENAME]
67
+ val = val_path if val_path.exists() else None
68
+ return train, val
69
+
70
+
71
+ def tokenized_sample(shard_path: Path, enc, row_groups: int = 5) -> tuple[int, int, int]:
72
+ """Tokenize the first N row groups of a shard. Returns (tokens, docs, num_row_groups)."""
73
+ pf = pq.ParquetFile(shard_path)
74
+ tokens = 0
75
+ docs = 0
76
+ n = min(row_groups, pf.num_row_groups)
77
+ for i in range(n):
78
+ rg = pf.read_row_group(i)
79
+ texts = rg.column("text").to_pylist()
80
+ ids = enc.encode_ordinary_batch(texts, num_threads=8)
81
+ tokens += sum(len(x) for x in ids)
82
+ docs += len(texts)
83
+ return tokens, docs, pf.num_row_groups
84
+
85
+
86
+ def main() -> int:
87
+ parser = argparse.ArgumentParser(description="Audit the HYDRA training corpus")
88
+ parser.add_argument(
89
+ "--sample",
90
+ type=int,
91
+ default=3,
92
+ help="Number of shards to tokenize for token-count estimate",
93
+ )
94
+ parser.add_argument(
95
+ "--full",
96
+ action="store_true",
97
+ help="Tokenize every shard (slow; gives exact total)",
98
+ )
99
+ args = parser.parse_args()
100
+
101
+ print("=" * 72)
102
+ print("HYDRA corpus audit")
103
+ print("=" * 72)
104
+ print(f"DATA_DIR: {DATA_DIR}")
105
+ print(f"TOKENIZER_DIR: {TOKENIZER_DIR}")
106
+ print(f"Source dataset: karpathy/climbmix-400b-shuffle")
107
+ print(f"Max remote shard: {MAX_SHARD} (pinned val = shard_{VAL_SHARD:05d})")
108
+ print()
109
+
110
+ train_shards, val_shard = list_shards()
111
+ if not train_shards:
112
+ print("ERROR: no parquet shards found. Run `python prepare.py` first.")
113
+ return 1
114
+
115
+ total_disk = sum(p.stat().st_size for p in train_shards)
116
+ val_disk = val_shard.stat().st_size if val_shard else 0
117
+
118
+ print(f"Train shards: {len(train_shards)} ({train_shards[0].name} ... {train_shards[-1].name})")
119
+ print(f"Val shard: {'present' if val_shard else 'MISSING'} ({VAL_FILENAME})")
120
+ print(f"Disk (train): {human_bytes(total_disk)}")
121
+ print(f"Disk (val): {human_bytes(val_disk)}")
122
+ print()
123
+
124
+ # Metadata-only pass (fast): count documents and row groups across all shards.
125
+ t0 = time.time()
127
+ total_docs = 0
128
+ total_row_groups = 0
129
+ for p in train_shards:
130
+ pf = pq.ParquetFile(p)
131
+ total_row_groups += pf.num_row_groups
132
+ total_docs += pf.metadata.num_rows
133
+ dt_meta = time.time() - t0
134
+ print(f"Metadata scan: {len(train_shards)} shards in {dt_meta:.1f}s")
135
+ print(f"Train documents: {total_docs:,}")
136
+ print(f"Row groups: {total_row_groups:,}")
137
+ print()
138
+
139
+ # Tokenizer-based sampling.
140
+ try:
141
+ import pickle
142
+
143
+ with open(os.path.join(TOKENIZER_DIR, "tokenizer.pkl"), "rb") as f:
144
+ enc = pickle.load(f)
145
+ print(f"Tokenizer vocab: {enc.n_vocab}")
146
+ except FileNotFoundError:
147
+ print("WARNING: tokenizer.pkl not found β€” skipping tokenized sample.")
148
+ enc = None
149
+
150
+ est_total_tokens = 0
151
+ if enc is not None:
152
+ if args.full:
153
+ sample_shards = train_shards
154
+ else:
155
+ # Pick shards evenly across the range for a representative sample.
156
+ n_sample = min(args.sample, len(train_shards))
157
+ if n_sample == 1:
158
+ sample_shards = [train_shards[0]]
159
+ else:
160
+ stride = max(1, len(train_shards) // n_sample)
161
+ sample_shards = train_shards[::stride][:n_sample]
162
+
163
+ t0 = time.time()
164
+ sample_tokens = 0
165
+ sample_docs = 0
166
+ sample_row_groups = 0
167
+ sample_shard_row_groups = 0
168
+ print(f"Tokenizing sample: {len(sample_shards)} shards ...")
169
+ for p in sample_shards:
170
+ tok, docs, n_rg = tokenized_sample(p, enc, row_groups=5)
171
+ sample_tokens += tok
172
+ sample_docs += docs
173
+ sample_row_groups += min(5, n_rg)
174
+ sample_shard_row_groups += n_rg
175
+ dt_tok = time.time() - t0
176
+
177
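+ # Extrapolate: mean tokens per sampled row group, times mean row groups per
+ # shard, times the number of train shards.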
+ tokens_per_rg = sample_tokens / max(sample_row_groups, 1)
178
+ per_shard = tokens_per_rg * (sample_shard_row_groups / len(sample_shards))
179
+ est_total_tokens = per_shard * len(train_shards)
180
+
181
+ print(
182
+ f"Sampled {sample_row_groups} row groups ({sample_docs:,} docs, "
183
+ f"{sample_tokens:,} tokens) in {dt_tok:.1f}s"
184
+ )
185
+ print(f" tokens/row_group: {tokens_per_rg:,.0f}")
186
+ print(f" tokens/shard: {per_shard:,.0f}")
187
+ print(f" tokens/shard: {human_tokens(per_shard)}")
188
+ else:
189
+ # Fall back to character heuristic.
190
+ per_shard_chars = total_disk / max(len(train_shards), 1)
191
+ # Parquet compression ratio ~3x for text; decompressed ~3 * file size.
192
+ # Chars per token heuristic: ~4.
193
+ est_total_tokens = (total_disk * 3.0) / CHARS_PER_TOKEN_HEURISTIC
194
+
195
+ print()
196
+ print("-" * 72)
197
+ print("Token budget analysis")
198
+ print("-" * 72)
199
+ print(f"Estimated total train tokens: {human_tokens(est_total_tokens)} "
200
+ f"({est_total_tokens:,.0f})")
201
+ print(f"12h @ 65k tok/s target: {human_tokens(TARGET_TOKENS_12H)}")
202
+ ratio = est_total_tokens / TARGET_TOKENS_12H if TARGET_TOKENS_12H else 0
203
+ if ratio >= 1.0:
204
+ print(f" Ratio: {ratio:.1f}x ({'SUFFICIENT' if ratio >= 1.2 else 'TIGHT'})")
205
+ else:
206
+ print(f" Ratio: {ratio:.2f}x INSUFFICIENT β€” need {1 - ratio:.0%} more")
207
+ print()
208
+
209
+ # Warnings about the dataloader behavior.
210
+ print("-" * 72)
211
+ print("Dataloader behavior (prepare.py::_document_batches)")
212
+ print("-" * 72)
213
+ print("+ Infinite streaming: while True around shard list (no StopIteration)")
214
+ print("+ Streams per shard, never loads full corpus into RAM")
215
+ print("+ BOS-aligned best-fit packing gives document-level buffer shuffling")
216
+ print("- Cross-shard order is LEXICOGRAPHIC and FIXED on every epoch")
217
+ print("- Row groups / rows WITHIN a shard are read in fixed order")
218
+ print(" (climbmix-400b-shuffle is pre-shuffled at source, mitigating this)")
219
+ print()
220
+
221
+ # Quick content diversity peek.
222
+ if train_shards:
223
+ print("-" * 72)
224
+ print("Content sample (shard 0, first 3 docs)")
225
+ print("-" * 72)
226
+ pf = pq.ParquetFile(train_shards[0])
227
+ rg = pf.read_row_group(0)
228
+ texts = rg.column("text").to_pylist()
229
+ for i, idx in enumerate([0, len(texts) // 2, len(texts) - 1]):
230
+ if idx < len(texts):
231
+ snippet = texts[idx][:160].replace("\n", " ")
232
+ print(f" [{i}] len={len(texts[idx])}: {snippet!r}")
233
+ print()
234
+
235
+ print("=" * 72)
236
+ print("Done.")
237
+ return 0
238
+
239
+
240
+ if __name__ == "__main__":
241
+ raise SystemExit(main())
overlay/scripts/download_sft_data.py CHANGED
@@ -1,457 +1,457 @@
1
- """Download + tokenize instruction data for HYDRA SFT.
2
-
3
- Writes int16 token shards to `data/sft/shard_XXX.bin` plus a
4
- `data/sft/meta.json` with counts + special-token mapping.
5
-
6
- Chat format (vocab's 4 reserved special tokens are repurposed):
7
- <BOS=8188> <|user|=8189>\n{instruction}\n{input?}\n <|assistant|=8190>\n
8
- {output}<|end|=8191>\n
9
-
10
- Special-token IDs are constants derived from the tokenizer (they are the
11
- last 4 IDs in an 8192-vocab). They are stored in meta.json for the SFT
12
- script to read.
13
-
14
- Sources (tried in order):
15
- 1. yahma/alpaca-cleaned (~52K pairs via HF parquet auto-convert)
16
- 2. databricks/databricks-dolly-15k (~15K pairs)
17
- 3. Hard-coded 200 simple Q&A pairs (offline backup)
18
-
19
- Usage:
20
- python scripts/download_sft_data.py # full download
21
- python scripts/download_sft_data.py --test # small smoke run
22
- python scripts/download_sft_data.py --offline # skip network; use backup
23
- """
24
-
25
- from __future__ import annotations
26
-
27
- import argparse
28
- import json
29
- import os
30
- import pickle
31
- import sys
32
- import time
33
- from pathlib import Path
34
-
35
- import numpy as np
36
- import requests
37
-
38
- # Make `prepare` and `hydra.*` importable when run as a script
39
- _REPO_ROOT = Path(__file__).resolve().parent.parent
40
- if str(_REPO_ROOT) not in sys.path:
41
- sys.path.insert(0, str(_REPO_ROOT))
42
-
43
-
44
- # ---------------------------------------------------------------------------
45
- # Constants
46
- # ---------------------------------------------------------------------------
47
-
48
- CACHE_DIR = Path.home() / ".cache" / "autoresearch"
49
- TOKENIZER_PKL = CACHE_DIR / "tokenizer" / "tokenizer.pkl"
50
-
51
- SFT_DIR = _REPO_ROOT / "data" / "sft"
52
- SFT_DIR.mkdir(parents=True, exist_ok=True)
53
-
54
- # Reserved token repurposing - must match prepare.py SPECIAL_TOKENS list
55
- # (indices 8188-8191 in the 8192-vocab BPE).
56
- BOS_ID = 8188 # <|reserved_0|>
57
- USER_ID = 8189 # <|reserved_1|>
58
- ASSISTANT_ID = 8190 # <|reserved_2|>
59
- END_ID = 8191 # <|reserved_3|>
60
-
61
- # Shards are int16 arrays of packed token IDs.
62
- TOKENS_PER_SHARD = 1_048_576 # ~2 MB per shard
63
- DTYPE = np.int16 # vocab_size=8192 fits in int16
64
-
65
- TARGET_TOKENS_DEFAULT = 15_000_000 # ~15M instruction tokens
66
- TARGET_TOKENS_TEST = 1_500_000 # smoke run
67
-
68
- # HuggingFace auto-parquet endpoint - one file for alpaca-cleaned
69
- ALPACA_URL = (
70
- "https://huggingface.co/api/datasets/yahma/alpaca-cleaned/parquet/"
71
- "default/train/0.parquet"
72
- )
73
- DOLLY_URL = (
74
- "https://huggingface.co/api/datasets/databricks/databricks-dolly-15k/"
75
- "parquet/default/train/0.parquet"
76
- )
77
-
78
-
79
- # ---------------------------------------------------------------------------
80
- # Offline backup Q&A pairs (used only if network unavailable)
81
- # ---------------------------------------------------------------------------
82
-
83
- _BACKUP_QA = [
84
- ("What is the capital of France?", "The capital of France is Paris."),
85
- ("What is the capital of Germany?", "The capital of Germany is Berlin."),
86
- ("What is the capital of Japan?", "The capital of Japan is Tokyo."),
87
- ("What is the capital of Italy?", "The capital of Italy is Rome."),
88
- ("What is the capital of Spain?", "The capital of Spain is Madrid."),
89
- ("What is the capital of England?", "The capital of England is London."),
90
- ("What is the capital of Canada?", "The capital of Canada is Ottawa."),
91
- ("What is the capital of Australia?", "The capital of Australia is Canberra."),
92
- ("What is 2 plus 2?", "Two plus two equals four."),
93
- ("What is 5 plus 5?", "Five plus five equals ten."),
94
- ("What is 3 times 4?", "Three times four equals twelve."),
95
- ("What is 10 minus 3?", "Ten minus three equals seven."),
96
- ("Name a color.", "Red is a color. Other colors include blue, green, and yellow."),
97
- ("Name a fruit.", "Apple is a fruit. Other fruits include banana, orange, and grape."),
98
- ("Name an animal.", "A dog is an animal. Other animals include cat, horse, and cow."),
99
- ("What do birds do?", "Birds fly through the air using their wings."),
100
- ("What do fish do?", "Fish swim in water using their fins and tails."),
101
- ("What is water made of?", "Water is made of hydrogen and oxygen."),
102
- ("What is the sun?", "The sun is a star at the center of our solar system."),
103
- ("What is the moon?", "The moon is Earth's natural satellite."),
104
- ("What is a cloud?", "A cloud is a collection of tiny water droplets in the sky."),
105
- ("What is rain?", "Rain is water that falls from clouds to the ground."),
106
- ("What is snow?", "Snow is frozen water that falls from clouds in cold weather."),
107
- ("What is wind?", "Wind is moving air."),
108
- ("What is fire?", "Fire is a chemical reaction that produces heat and light."),
109
- ("What is ice?", "Ice is frozen water."),
110
- ("Say hello.", "Hello! How can I help you today?"),
111
- ("Say goodbye.", "Goodbye! Have a great day."),
112
- ("Tell me a short story.", "Once upon a time, a brave rabbit lived in the forest. The end."),
113
- ("Tell me a joke.", "Why did the chicken cross the road? To get to the other side."),
114
- ("Who wrote Hamlet?", "William Shakespeare wrote the play Hamlet."),
115
- ("Who wrote Romeo and Juliet?", "William Shakespeare wrote Romeo and Juliet."),
116
- ("Who painted the Mona Lisa?", "Leonardo da Vinci painted the Mona Lisa."),
117
- ("When did World War 2 end?", "World War 2 ended in 1945."),
118
- ("What is gravity?", "Gravity is the force that pulls objects toward the Earth."),
119
- ("What is the speed of light?", "The speed of light is approximately 300,000 kilometers per second."),
120
- ("What is the largest planet?", "Jupiter is the largest planet in our solar system."),
121
- ("What is the smallest planet?", "Mercury is the smallest planet in our solar system."),
122
- ("At what temperature does water boil?", "Water boils at 100 degrees Celsius or 212 degrees Fahrenheit."),
123
- ("At what temperature does water freeze?", "Water freezes at 0 degrees Celsius or 32 degrees Fahrenheit."),
124
- ("How many legs does a spider have?", "A spider has eight legs."),
125
- ("How many legs does an insect have?", "An insect has six legs."),
126
- ("What do plants need to grow?", "Plants need sunlight, water, soil, and air to grow."),
127
- ("What do humans eat?", "Humans eat a variety of foods including fruits, vegetables, meat, and grains."),
128
- ("What is a book?", "A book is a collection of written or printed pages bound together."),
129
- ("What is a computer?", "A computer is an electronic device that processes information."),
130
- ("What is a phone?", "A phone is a device used to communicate with people at a distance."),
131
- ("What is music?", "Music is an arrangement of sounds that is pleasing to hear."),
132
- ("What is art?", "Art is the expression of human creativity and imagination."),
133
- ("What is a language?", "A language is a system of communication used by a group of people."),
134
- ]
135
-
136
- # Duplicate to reach ~200 samples (each pair appears ~4x)
137
- BACKUP_QA = (_BACKUP_QA * 4)[:200]
138
-
139
-
140
- # ---------------------------------------------------------------------------
141
- # Tokenizer loader
142
- # ---------------------------------------------------------------------------
143
-
144
- class _TokenizerWrapper:
145
- """Minimal wrapper around the pickled tiktoken.Encoding. We avoid
146
- importing `prepare.Tokenizer` to sidestep its side effects (which
147
- touch the running pretrain's cache files)."""
148
-
149
- def __init__(self, enc):
150
- self.enc = enc
151
-
152
- def encode(self, text: str) -> list[int]:
153
- return self.enc.encode_ordinary(text)
154
-
155
- @property
156
- def vocab_size(self) -> int:
157
- return self.enc.n_vocab
158
-
159
-
160
- def load_tokenizer() -> _TokenizerWrapper:
161
- if not TOKENIZER_PKL.exists():
162
- raise FileNotFoundError(
163
- f"Tokenizer not found at {TOKENIZER_PKL}. Run `python prepare.py` "
164
- f"first."
165
- )
166
- with open(TOKENIZER_PKL, "rb") as f:
167
- enc = pickle.load(f)
168
- tok = _TokenizerWrapper(enc)
169
- assert tok.vocab_size == 8192, f"Expected vocab=8192, got {tok.vocab_size}"
170
- return tok
171
-
172
-
173
- # ---------------------------------------------------------------------------
174
- # Source downloaders
175
- # ---------------------------------------------------------------------------
176
-
177
- def _download_parquet(url: str, local_path: Path, timeout: int = 60) -> bool:
178
- """Stream-download a parquet file with retry. Returns True on success."""
179
- local_path.parent.mkdir(parents=True, exist_ok=True)
180
- tmp = local_path.with_suffix(local_path.suffix + ".tmp")
181
- for attempt in range(1, 4):
182
- try:
183
- with requests.get(url, stream=True, timeout=timeout,
184
- allow_redirects=True) as r:
185
- r.raise_for_status()
186
- with open(tmp, "wb") as f:
187
- for chunk in r.iter_content(chunk_size=1 << 20):
188
- if chunk:
189
- f.write(chunk)
190
- tmp.replace(local_path)
191
- return True
192
- except Exception as e:
193
- print(f" [net] attempt {attempt} failed: {e}", flush=True)
194
- for p in (tmp, local_path):
195
- try:
196
- p.unlink()
197
- except FileNotFoundError:
198
- pass
199
- time.sleep(2 ** attempt)
200
- return False
201
-
202
-
203
- def _iter_alpaca(local_path: Path):
204
- """Yield (instruction, input, output) from alpaca-cleaned parquet."""
205
- import pyarrow.parquet as pq
206
- pf = pq.ParquetFile(str(local_path))
207
- for rg_idx in range(pf.num_row_groups):
208
- rg = pf.read_row_group(rg_idx)
209
- instr_col = rg.column("instruction").to_pylist()
210
- input_col = rg.column("input").to_pylist()
211
- output_col = rg.column("output").to_pylist()
212
- for instruction, input_text, output in zip(instr_col, input_col, output_col):
213
- if instruction and output:
214
- yield instruction, (input_text or ""), output
215
-
216
-
217
- def _iter_dolly(local_path: Path):
218
- """Yield (instruction, input, output) from dolly-15k parquet."""
219
- import pyarrow.parquet as pq
220
- pf = pq.ParquetFile(str(local_path))
221
- # Schema: instruction, context, response, category
222
- for rg_idx in range(pf.num_row_groups):
223
- rg = pf.read_row_group(rg_idx)
224
- cols = {n: rg.column(n).to_pylist() for n in rg.schema.names}
225
- instr_col = cols.get("instruction") or cols.get("Instruction")
226
- ctx_col = cols.get("context") or cols.get("Context") or [""] * len(instr_col)
227
- resp_col = cols.get("response") or cols.get("Response")
228
- for instruction, context, response in zip(instr_col, ctx_col, resp_col):
229
- if instruction and response:
230
- yield instruction, (context or ""), response
231
-
232
-
233
- def _iter_backup():
234
- for q, a in BACKUP_QA:
235
- yield q, "", a
236
-
237
-
238
- # ---------------------------------------------------------------------------
239
- # Encoding
240
- # ---------------------------------------------------------------------------
241
-
242
- def encode_example(tok: _TokenizerWrapper, instruction: str,
243
- input_text: str, output: str) -> list[int]:
244
- """Serialize one instruction/response pair into a flat token list.
245
-
246
- Format:
247
- <BOS> <|user|> \\n {instr}\\n[{input}\\n] <|assistant|> \\n {output} <|end|> \\n
248
- """
249
- ids: list[int] = [BOS_ID, USER_ID]
250
- ids += tok.encode("\n" + instruction.strip())
251
- if input_text and input_text.strip():
252
- ids += tok.encode("\n" + input_text.strip())
253
- ids += tok.encode("\n")
254
- ids.append(ASSISTANT_ID)
255
- ids += tok.encode("\n" + output.strip())
256
- ids.append(END_ID)
257
- ids += tok.encode("\n")
258
- return ids
259
-
260
-
261
- def encode_example_with_mask(tok: _TokenizerWrapper, instruction: str,
262
- input_text: str, output: str
263
- ) -> tuple[list[int], list[int]]:
264
- """Return (tokens, mask) where mask[i]=1 means 'compute loss on token i'
265
- and mask[i]=0 means 'prompt, ignore'. The boundary is the <|assistant|>
266
- token: the assistant response (and <|end|>) contribute to loss; the
267
- user prompt does not."""
268
- prompt_ids = [BOS_ID, USER_ID] + tok.encode("\n" + instruction.strip())
269
- if input_text and input_text.strip():
270
- prompt_ids += tok.encode("\n" + input_text.strip())
271
- prompt_ids += tok.encode("\n")
272
- prompt_ids.append(ASSISTANT_ID)
273
-
274
- response_ids = tok.encode("\n" + output.strip())
275
- response_ids.append(END_ID)
276
- response_ids += tok.encode("\n")
277
-
278
- ids = prompt_ids + response_ids
279
- mask = [0] * len(prompt_ids) + [1] * len(response_ids)
280
- return ids, mask
281
-
282
-
283
- # ---------------------------------------------------------------------------
284
- # Shard writer
285
- # ---------------------------------------------------------------------------
286
-
287
- class ShardWriter:
288
- """Writes two parallel int16 files per shard:
289
- data/sft/shard_XXX.bin - token IDs
290
- data/sft/mask_XXX.bin - 0/1 loss mask
291
-
292
- Packs one example after another with no padding. At runtime, SFT builds
293
- sequences of length MAX_SEQ_LEN by slicing across these flat arrays.
294
- """
295
-
296
- def __init__(self, out_dir: Path, tokens_per_shard: int = TOKENS_PER_SHARD):
297
- self.out_dir = out_dir
298
- self.tokens_per_shard = tokens_per_shard
299
- self.shard_idx = 0
300
- self._buf_tok: list[int] = []
301
- self._buf_mask: list[int] = []
302
- self.total_tokens = 0
303
-
304
- def add(self, tokens: list[int], mask: list[int]):
305
- assert len(tokens) == len(mask)
306
- self._buf_tok.extend(tokens)
307
- self._buf_mask.extend(mask)
308
- self.total_tokens += len(tokens)
309
- while len(self._buf_tok) >= self.tokens_per_shard:
310
- self._flush_one(self.tokens_per_shard)
311
-
312
- def _flush_one(self, n: int):
313
- tok_path = self.out_dir / f"shard_{self.shard_idx:04d}.bin"
314
- mask_path = self.out_dir / f"mask_{self.shard_idx:04d}.bin"
315
- arr_tok = np.array(self._buf_tok[:n], dtype=DTYPE)
316
- arr_mask = np.array(self._buf_mask[:n], dtype=np.uint8)
317
- arr_tok.tofile(tok_path)
318
- arr_mask.tofile(mask_path)
319
- self._buf_tok = self._buf_tok[n:]
320
- self._buf_mask = self._buf_mask[n:]
321
- print(f" wrote {tok_path.name} ({n:,} tokens)", flush=True)
322
- self.shard_idx += 1
323
-
324
- def finalize(self):
325
- if self._buf_tok:
326
- self._flush_one(len(self._buf_tok))
327
-
328
-
329
- # ---------------------------------------------------------------------------
330
- # Main
331
- # ---------------------------------------------------------------------------
332
-
333
- def main():
334
- ap = argparse.ArgumentParser()
335
- ap.add_argument("--test", action="store_true",
336
- help="Small smoke run: write ~1.5M tokens and exit.")
337
- ap.add_argument("--offline", action="store_true",
338
- help="Skip network, use hard-coded backup only.")
339
- ap.add_argument("--target-tokens", type=int, default=None,
340
- help="Override target token count.")
341
- args = ap.parse_args()
342
-
343
- target = args.target_tokens or (
344
- TARGET_TOKENS_TEST if args.test else TARGET_TOKENS_DEFAULT
345
- )
346
-
347
- print(f"SFT_DIR: {SFT_DIR}")
348
- print(f"Target tokens: {target:,}")
349
- print(f"Offline mode: {args.offline}")
350
-
351
- # Clear any prior shards
352
- for p in SFT_DIR.glob("shard_*.bin"):
353
- p.unlink()
354
- for p in SFT_DIR.glob("mask_*.bin"):
355
- p.unlink()
356
-
357
- tok = load_tokenizer()
358
- print(f"Tokenizer vocab: {tok.vocab_size}")
359
- print(f"Special tokens: BOS={BOS_ID} USER={USER_ID} "
360
- f"ASSISTANT={ASSISTANT_ID} END={END_ID}")
361
-
362
- sources = [] # list of (name, iterator_fn)
363
- if not args.offline:
364
- alpaca_path = SFT_DIR / "alpaca_raw.parquet"
365
- print(f"\n[src] downloading alpaca-cleaned -> {alpaca_path.name} ...")
366
- if _download_parquet(ALPACA_URL, alpaca_path):
367
- print(f" ok ({alpaca_path.stat().st_size // (1 << 20)} MiB)")
368
- sources.append(("alpaca-cleaned", lambda: _iter_alpaca(alpaca_path)))
369
- else:
370
- print(" alpaca download FAILED, trying dolly...")
371
- dolly_path = SFT_DIR / "dolly_raw.parquet"
372
- if _download_parquet(DOLLY_URL, dolly_path):
373
- print(f" ok ({dolly_path.stat().st_size // (1 << 20)} MiB)")
374
- sources.append(("dolly-15k", lambda: _iter_dolly(dolly_path)))
375
-
376
- # Always include backup - cheap, catches tail
377
- sources.append(("backup-200", _iter_backup))
378
-
379
- if not sources:
380
- print("FATAL: no data sources available.", file=sys.stderr)
381
- sys.exit(1)
382
-
383
- # Stream-encode
384
- writer = ShardWriter(SFT_DIR)
385
- n_examples = 0
386
- n_assistant_tokens = 0
387
- source_counts = {}
388
-
389
- for src_name, src_fn in sources:
390
- print(f"\n[src] encoding {src_name} ...")
391
- src_examples = 0
392
- src_tokens = 0
393
- for (instruction, input_text, output) in src_fn():
394
- # Skip overly long outputs - 7.5M model can't use them
395
- if len(output) > 2000:
396
- output = output[:2000]
397
- ids, mask = encode_example_with_mask(tok, instruction,
398
- input_text, output)
399
- if len(ids) < 4 or len(ids) > 512:
400
- # Skip degenerate / too-long examples
401
- continue
402
- writer.add(ids, mask)
403
- n_examples += 1
404
- src_examples += 1
405
- src_tokens += len(ids)
406
- n_assistant_tokens += sum(mask)
407
- if writer.total_tokens >= target:
408
- break
409
- source_counts[src_name] = {
410
- "examples": src_examples,
411
- "tokens": src_tokens,
412
- }
413
- print(f" {src_name}: {src_examples:,} examples, {src_tokens:,} tokens")
414
- if writer.total_tokens >= target:
415
- break
416
-
417
- writer.finalize()
418
-
419
- meta = {
420
- "total_tokens": writer.total_tokens,
421
- "total_examples": n_examples,
422
- "assistant_tokens_in_loss": n_assistant_tokens,
423
- "num_shards": writer.shard_idx,
424
- "tokens_per_shard": TOKENS_PER_SHARD,
425
- "dtype": "int16",
426
- "vocab_size": tok.vocab_size,
427
- "special_tokens": {
428
- "bos": BOS_ID,
429
- "user": USER_ID,
430
- "assistant": ASSISTANT_ID,
431
- "end": END_ID,
432
- },
433
- "sources": source_counts,
434
- "format_hint": (
435
- "<BOS><|user|>\\n{instr}\\n[{input}\\n]<|assistant|>\\n"
436
- "{output}<|end|>\\n"
437
- ),
438
- }
439
- meta_path = SFT_DIR / "meta.json"
440
- with open(meta_path, "w") as f:
441
- json.dump(meta, f, indent=2)
442
-
443
- print(f"\n===== SFT data ready =====")
444
- print(f" examples: {n_examples:,}")
445
- print(f" total tokens: {writer.total_tokens:,}")
446
- print(f" loss tokens: {n_assistant_tokens:,}")
447
- print(f" shards: {writer.shard_idx}")
448
- print(f" meta: {meta_path}")
449
-
450
- if args.test and writer.total_tokens < 1_000_000:
451
- print(f"\nWARN: test mode produced only {writer.total_tokens:,} "
452
- f"tokens β€” below 1M threshold.")
453
- sys.exit(2)
454
-
455
-
456
- if __name__ == "__main__":
457
- main()
 
+ """Download + tokenize instruction data for HYDRA SFT.
+
+ Writes int16 token shards to `data/sft/shard_XXX.bin` plus a
+ `data/sft/meta.json` with counts + special-token mapping.
+
+ Chat format (vocab's 4 reserved special tokens are repurposed):
+     <BOS=8188> <|user|=8189>\n{instruction}\n{input?}\n <|assistant|=8190>\n
+     {output}<|end|=8191>\n
+
+ Special-token IDs are constants derived from the tokenizer (they are the
+ last 4 IDs in an 8192-vocab). They are stored in meta.json for the SFT
+ script to read.
+
+ Sources (tried in order):
+     1. yahma/alpaca-cleaned (~52K pairs via HF parquet auto-convert)
+     2. databricks/databricks-dolly-15k (~15K pairs)
+     3. Hard-coded 200 simple Q&A pairs (offline backup)
+
+ Usage:
+     python scripts/download_sft_data.py            # full download
+     python scripts/download_sft_data.py --test     # small smoke run
+     python scripts/download_sft_data.py --offline  # skip network; use backup
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ import pickle
+ import sys
+ import time
+ from pathlib import Path
+
+ import numpy as np
+ import requests
+
+ # Make `prepare` and `hydra.*` importable when run as a script
+ _REPO_ROOT = Path(__file__).resolve().parent.parent
+ if str(_REPO_ROOT) not in sys.path:
+     sys.path.insert(0, str(_REPO_ROOT))
+
+
+ # ---------------------------------------------------------------------------
+ # Constants
+ # ---------------------------------------------------------------------------
+
+ CACHE_DIR = Path.home() / ".cache" / "autoresearch"
+ TOKENIZER_PKL = CACHE_DIR / "tokenizer" / "tokenizer.pkl"
+
+ SFT_DIR = _REPO_ROOT / "data" / "sft"
+ SFT_DIR.mkdir(parents=True, exist_ok=True)
+
+ # Reserved token repurposing -- must match prepare.py SPECIAL_TOKENS list
+ # (indices 8188-8191 in the 8192-vocab BPE).
+ BOS_ID = 8188        # <|reserved_0|>
+ USER_ID = 8189       # <|reserved_1|>
+ ASSISTANT_ID = 8190  # <|reserved_2|>
+ END_ID = 8191        # <|reserved_3|>
+
+ # Shards are int16 arrays of packed token IDs.
+ TOKENS_PER_SHARD = 1_048_576  # ~2 MB per shard
+ DTYPE = np.int16              # vocab_size=8192 fits in int16
+
+ TARGET_TOKENS_DEFAULT = 15_000_000  # ~15M instruction tokens
+ TARGET_TOKENS_TEST = 1_500_000      # smoke run
+
+ # HuggingFace auto-parquet endpoint -- one file for alpaca-cleaned
+ ALPACA_URL = (
+     "https://huggingface.co/api/datasets/yahma/alpaca-cleaned/parquet/"
+     "default/train/0.parquet"
+ )
+ DOLLY_URL = (
+     "https://huggingface.co/api/datasets/databricks/databricks-dolly-15k/"
+     "parquet/default/train/0.parquet"
+ )
+
+
+ # ---------------------------------------------------------------------------
+ # Offline backup Q&A pairs (used only if network unavailable)
+ # ---------------------------------------------------------------------------
+
+ _BACKUP_QA = [
+     ("What is the capital of France?", "The capital of France is Paris."),
+     ("What is the capital of Germany?", "The capital of Germany is Berlin."),
+     ("What is the capital of Japan?", "The capital of Japan is Tokyo."),
+     ("What is the capital of Italy?", "The capital of Italy is Rome."),
+     ("What is the capital of Spain?", "The capital of Spain is Madrid."),
+     ("What is the capital of England?", "The capital of England is London."),
+     ("What is the capital of Canada?", "The capital of Canada is Ottawa."),
+     ("What is the capital of Australia?", "The capital of Australia is Canberra."),
+     ("What is 2 plus 2?", "Two plus two equals four."),
+     ("What is 5 plus 5?", "Five plus five equals ten."),
+     ("What is 3 times 4?", "Three times four equals twelve."),
+     ("What is 10 minus 3?", "Ten minus three equals seven."),
+     ("Name a color.", "Red is a color. Other colors include blue, green, and yellow."),
+     ("Name a fruit.", "Apple is a fruit. Other fruits include banana, orange, and grape."),
+     ("Name an animal.", "A dog is an animal. Other animals include cat, horse, and cow."),
+     ("What do birds do?", "Birds fly through the air using their wings."),
+     ("What do fish do?", "Fish swim in water using their fins and tails."),
+     ("What is water made of?", "Water is made of hydrogen and oxygen."),
+     ("What is the sun?", "The sun is a star at the center of our solar system."),
+     ("What is the moon?", "The moon is Earth's natural satellite."),
+     ("What is a cloud?", "A cloud is a collection of tiny water droplets in the sky."),
+     ("What is rain?", "Rain is water that falls from clouds to the ground."),
+     ("What is snow?", "Snow is frozen water that falls from clouds in cold weather."),
+     ("What is wind?", "Wind is moving air."),
+     ("What is fire?", "Fire is a chemical reaction that produces heat and light."),
+     ("What is ice?", "Ice is frozen water."),
+     ("Say hello.", "Hello! How can I help you today?"),
+     ("Say goodbye.", "Goodbye! Have a great day."),
+     ("Tell me a short story.", "Once upon a time, a brave rabbit lived in the forest. The end."),
+     ("Tell me a joke.", "Why did the chicken cross the road? To get to the other side."),
+     ("Who wrote Hamlet?", "William Shakespeare wrote the play Hamlet."),
+     ("Who wrote Romeo and Juliet?", "William Shakespeare wrote Romeo and Juliet."),
+     ("Who painted the Mona Lisa?", "Leonardo da Vinci painted the Mona Lisa."),
+     ("When did World War 2 end?", "World War 2 ended in 1945."),
+     ("What is gravity?", "Gravity is the force that pulls objects toward the Earth."),
+     ("What is the speed of light?", "The speed of light is approximately 300,000 kilometers per second."),
+     ("What is the largest planet?", "Jupiter is the largest planet in our solar system."),
+     ("What is the smallest planet?", "Mercury is the smallest planet in our solar system."),
+     ("At what temperature does water boil?", "Water boils at 100 degrees Celsius or 212 degrees Fahrenheit."),
+     ("At what temperature does water freeze?", "Water freezes at 0 degrees Celsius or 32 degrees Fahrenheit."),
+     ("How many legs does a spider have?", "A spider has eight legs."),
+     ("How many legs does an insect have?", "An insect has six legs."),
+     ("What do plants need to grow?", "Plants need sunlight, water, soil, and air to grow."),
+     ("What do humans eat?", "Humans eat a variety of foods including fruits, vegetables, meat, and grains."),
+     ("What is a book?", "A book is a collection of written or printed pages bound together."),
+     ("What is a computer?", "A computer is an electronic device that processes information."),
+     ("What is a phone?", "A phone is a device used to communicate with people at a distance."),
+     ("What is music?", "Music is an arrangement of sounds that is pleasing to hear."),
+     ("What is art?", "Art is the expression of human creativity and imagination."),
+     ("What is a language?", "A language is a system of communication used by a group of people."),
+ ]
+
+ # Duplicate to reach ~200 samples (each pair appears ~4x)
+ BACKUP_QA = (_BACKUP_QA * 4)[:200]
+
+
+ # ---------------------------------------------------------------------------
+ # Tokenizer loader
+ # ---------------------------------------------------------------------------
+
+ class _TokenizerWrapper:
+     """Minimal wrapper around the pickled tiktoken.Encoding. We avoid
+     importing `prepare.Tokenizer` to sidestep its side effects (which
+     touch the running pretrain's cache files)."""
+
+     def __init__(self, enc):
+         self.enc = enc
+
+     def encode(self, text: str) -> list[int]:
+         return self.enc.encode_ordinary(text)
+
+     @property
+     def vocab_size(self) -> int:
+         return self.enc.n_vocab
+
+
+ def load_tokenizer() -> _TokenizerWrapper:
+     if not TOKENIZER_PKL.exists():
+         raise FileNotFoundError(
+             f"Tokenizer not found at {TOKENIZER_PKL}. Run `python prepare.py` "
+             f"first."
+         )
+     with open(TOKENIZER_PKL, "rb") as f:
+         enc = pickle.load(f)
+     tok = _TokenizerWrapper(enc)
+     assert tok.vocab_size == 8192, f"Expected vocab=8192, got {tok.vocab_size}"
+     return tok
+
+
+ # ---------------------------------------------------------------------------
+ # Source downloaders
+ # ---------------------------------------------------------------------------
+
+ def _download_parquet(url: str, local_path: Path, timeout: int = 60) -> bool:
+     """Stream-download a parquet file with retry. Returns True on success."""
+     local_path.parent.mkdir(parents=True, exist_ok=True)
+     tmp = local_path.with_suffix(local_path.suffix + ".tmp")
+     for attempt in range(1, 4):
+         try:
+             with requests.get(url, stream=True, timeout=timeout,
+                               allow_redirects=True) as r:
+                 r.raise_for_status()
+                 with open(tmp, "wb") as f:
+                     for chunk in r.iter_content(chunk_size=1 << 20):
+                         if chunk:
+                             f.write(chunk)
+             tmp.replace(local_path)
+             return True
+         except Exception as e:
+             print(f" [net] attempt {attempt} failed: {e}", flush=True)
+             for p in (tmp, local_path):
+                 try:
+                     p.unlink()
+                 except FileNotFoundError:
+                     pass
+             time.sleep(2 ** attempt)
+     return False
+
+
+ def _iter_alpaca(local_path: Path):
+     """Yield (instruction, input, output) from alpaca-cleaned parquet."""
+     import pyarrow.parquet as pq
+     pf = pq.ParquetFile(str(local_path))
+     for rg_idx in range(pf.num_row_groups):
+         rg = pf.read_row_group(rg_idx)
+         instr_col = rg.column("instruction").to_pylist()
+         input_col = rg.column("input").to_pylist()
+         output_col = rg.column("output").to_pylist()
+         for instruction, input_text, output in zip(instr_col, input_col, output_col):
+             if instruction and output:
+                 yield instruction, (input_text or ""), output
+
+
+ def _iter_dolly(local_path: Path):
+     """Yield (instruction, input, output) from dolly-15k parquet."""
+     import pyarrow.parquet as pq
+     pf = pq.ParquetFile(str(local_path))
+     # Schema: instruction, context, response, category
+     for rg_idx in range(pf.num_row_groups):
+         rg = pf.read_row_group(rg_idx)
+         cols = {n: rg.column(n).to_pylist() for n in rg.schema.names}
+         instr_col = cols.get("instruction") or cols.get("Instruction")
+         ctx_col = cols.get("context") or cols.get("Context") or [""] * len(instr_col)
+         resp_col = cols.get("response") or cols.get("Response")
+         for instruction, context, response in zip(instr_col, ctx_col, resp_col):
+             if instruction and response:
+                 yield instruction, (context or ""), response
+
+
+ def _iter_backup():
+     for q, a in BACKUP_QA:
+         yield q, "", a
+
+
+ # ---------------------------------------------------------------------------
+ # Encoding
+ # ---------------------------------------------------------------------------
+
+ def encode_example(tok: _TokenizerWrapper, instruction: str,
+                    input_text: str, output: str) -> list[int]:
+     """Serialize one instruction/response pair into a flat token list.
+
+     Format:
+         <BOS> <|user|> \\n {instr}\\n[{input}\\n] <|assistant|> \\n {output} <|end|> \\n
+     """
+     ids: list[int] = [BOS_ID, USER_ID]
+     ids += tok.encode("\n" + instruction.strip())
+     if input_text and input_text.strip():
+         ids += tok.encode("\n" + input_text.strip())
+     ids += tok.encode("\n")
+     ids.append(ASSISTANT_ID)
+     ids += tok.encode("\n" + output.strip())
+     ids.append(END_ID)
+     ids += tok.encode("\n")
+     return ids
+
+
+ def encode_example_with_mask(tok: _TokenizerWrapper, instruction: str,
+                              input_text: str, output: str
+                              ) -> tuple[list[int], list[int]]:
+     """Return (tokens, mask) where mask[i]=1 means 'compute loss on token i'
+     and mask[i]=0 means 'prompt, ignore'. The boundary is the <|assistant|>
+     token: the assistant response (and <|end|>) contribute to loss; the
+     user prompt does not."""
+     prompt_ids = [BOS_ID, USER_ID] + tok.encode("\n" + instruction.strip())
+     if input_text and input_text.strip():
+         prompt_ids += tok.encode("\n" + input_text.strip())
+     prompt_ids += tok.encode("\n")
+     prompt_ids.append(ASSISTANT_ID)
+
+     response_ids = tok.encode("\n" + output.strip())
+     response_ids.append(END_ID)
+     response_ids += tok.encode("\n")
+
+     ids = prompt_ids + response_ids
+     mask = [0] * len(prompt_ids) + [1] * len(response_ids)
+     return ids, mask
+
+
+ # ---------------------------------------------------------------------------
+ # Shard writer
+ # ---------------------------------------------------------------------------
+
+ class ShardWriter:
+     """Writes two parallel flat files per shard:
+         data/sft/shard_XXX.bin -- int16 token IDs
+         data/sft/mask_XXX.bin  -- uint8 0/1 loss mask
+
+     Packs one example after another with no padding. At runtime, SFT builds
+     sequences of length MAX_SEQ_LEN by slicing across these flat arrays.
+     """
+
+     def __init__(self, out_dir: Path, tokens_per_shard: int = TOKENS_PER_SHARD):
+         self.out_dir = out_dir
+         self.tokens_per_shard = tokens_per_shard
+         self.shard_idx = 0
+         self._buf_tok: list[int] = []
+         self._buf_mask: list[int] = []
+         self.total_tokens = 0
+
+     def add(self, tokens: list[int], mask: list[int]):
+         assert len(tokens) == len(mask)
+         self._buf_tok.extend(tokens)
+         self._buf_mask.extend(mask)
+         self.total_tokens += len(tokens)
+         while len(self._buf_tok) >= self.tokens_per_shard:
+             self._flush_one(self.tokens_per_shard)
+
+     def _flush_one(self, n: int):
+         tok_path = self.out_dir / f"shard_{self.shard_idx:04d}.bin"
+         mask_path = self.out_dir / f"mask_{self.shard_idx:04d}.bin"
+         arr_tok = np.array(self._buf_tok[:n], dtype=DTYPE)
+         arr_mask = np.array(self._buf_mask[:n], dtype=np.uint8)
+         arr_tok.tofile(tok_path)
+         arr_mask.tofile(mask_path)
+         self._buf_tok = self._buf_tok[n:]
+         self._buf_mask = self._buf_mask[n:]
+         print(f" wrote {tok_path.name} ({n:,} tokens)", flush=True)
+         self.shard_idx += 1
+
+     def finalize(self):
+         if self._buf_tok:
+             self._flush_one(len(self._buf_tok))
+
+
+ # ---------------------------------------------------------------------------
+ # Main
+ # ---------------------------------------------------------------------------
+
+ def main():
+     ap = argparse.ArgumentParser()
+     ap.add_argument("--test", action="store_true",
+                     help="Small smoke run: write ~1.5M tokens and exit.")
+     ap.add_argument("--offline", action="store_true",
+                     help="Skip network, use hard-coded backup only.")
+     ap.add_argument("--target-tokens", type=int, default=None,
+                     help="Override target token count.")
+     args = ap.parse_args()
+
+     target = args.target_tokens or (
+         TARGET_TOKENS_TEST if args.test else TARGET_TOKENS_DEFAULT
+     )
+
+     print(f"SFT_DIR: {SFT_DIR}")
+     print(f"Target tokens: {target:,}")
+     print(f"Offline mode: {args.offline}")
+
+     # Clear any prior shards
+     for p in SFT_DIR.glob("shard_*.bin"):
+         p.unlink()
+     for p in SFT_DIR.glob("mask_*.bin"):
+         p.unlink()
+
+     tok = load_tokenizer()
+     print(f"Tokenizer vocab: {tok.vocab_size}")
+     print(f"Special tokens: BOS={BOS_ID} USER={USER_ID} "
+           f"ASSISTANT={ASSISTANT_ID} END={END_ID}")
+
+     sources = []  # list of (name, iterator_fn)
+     if not args.offline:
+         alpaca_path = SFT_DIR / "alpaca_raw.parquet"
+         print(f"\n[src] downloading alpaca-cleaned -> {alpaca_path.name} ...")
+         if _download_parquet(ALPACA_URL, alpaca_path):
+             print(f" ok ({alpaca_path.stat().st_size // (1 << 20)} MiB)")
+             sources.append(("alpaca-cleaned", lambda: _iter_alpaca(alpaca_path)))
+         else:
+             print(" alpaca download FAILED, trying dolly...")
+             dolly_path = SFT_DIR / "dolly_raw.parquet"
+             if _download_parquet(DOLLY_URL, dolly_path):
+                 print(f" ok ({dolly_path.stat().st_size // (1 << 20)} MiB)")
+                 sources.append(("dolly-15k", lambda: _iter_dolly(dolly_path)))
+
+     # Always include backup -- cheap, catches tail
+     sources.append(("backup-200", _iter_backup))
+
+     if not sources:
+         print("FATAL: no data sources available.", file=sys.stderr)
+         sys.exit(1)
+
+     # Stream-encode
+     writer = ShardWriter(SFT_DIR)
+     n_examples = 0
+     n_assistant_tokens = 0
+     source_counts = {}
+
+     for src_name, src_fn in sources:
+         print(f"\n[src] encoding {src_name} ...")
+         src_examples = 0
+         src_tokens = 0
+         for (instruction, input_text, output) in src_fn():
+             # Truncate overly long outputs -- 7.5M model can't use them
+             if len(output) > 2000:
+                 output = output[:2000]
+             ids, mask = encode_example_with_mask(tok, instruction,
+                                                  input_text, output)
+             if len(ids) < 4 or len(ids) > 512:
+                 # Skip degenerate / too-long examples
+                 continue
+             writer.add(ids, mask)
+             n_examples += 1
+             src_examples += 1
+             src_tokens += len(ids)
+             n_assistant_tokens += sum(mask)
+             if writer.total_tokens >= target:
+                 break
+         source_counts[src_name] = {
+             "examples": src_examples,
+             "tokens": src_tokens,
+         }
+         print(f" {src_name}: {src_examples:,} examples, {src_tokens:,} tokens")
+         if writer.total_tokens >= target:
+             break
+
+     writer.finalize()
+
+     meta = {
+         "total_tokens": writer.total_tokens,
+         "total_examples": n_examples,
+         "assistant_tokens_in_loss": n_assistant_tokens,
+         "num_shards": writer.shard_idx,
+         "tokens_per_shard": TOKENS_PER_SHARD,
+         "dtype": "int16",
+         "vocab_size": tok.vocab_size,
+         "special_tokens": {
+             "bos": BOS_ID,
+             "user": USER_ID,
+             "assistant": ASSISTANT_ID,
+             "end": END_ID,
+         },
+         "sources": source_counts,
+         "format_hint": (
+             "<BOS><|user|>\\n{instr}\\n[{input}\\n]<|assistant|>\\n"
+             "{output}<|end|>\\n"
+         ),
+     }
+     meta_path = SFT_DIR / "meta.json"
+     with open(meta_path, "w") as f:
+         json.dump(meta, f, indent=2)
+
+     print("\n===== SFT data ready =====")
+     print(f" examples: {n_examples:,}")
+     print(f" total tokens: {writer.total_tokens:,}")
+     print(f" loss tokens: {n_assistant_tokens:,}")
+     print(f" shards: {writer.shard_idx}")
+     print(f" meta: {meta_path}")
+
+     if args.test and writer.total_tokens < 1_000_000:
+         print(f"\nWARN: test mode produced only {writer.total_tokens:,} "
+               f"tokens -- below 1M threshold.")
+         sys.exit(2)
+
+
+ if __name__ == "__main__":
+     main()
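
For context on how these flat shard/mask pairs are consumed, here is a minimal sketch of the runtime slicing that the `ShardWriter` docstring refers to. It is illustrative only: the real consumer is `overlay/scripts/sft.py` (not shown in this hunk), and the function name and `seq_len` parameter are assumptions rather than that script's actual API.

import numpy as np

def sample_sft_window(shard_path, mask_path, seq_len, rng):
    # Memory-map the parallel arrays written by ShardWriter above:
    # int16 token IDs and a uint8 0/1 loss mask of identical length.
    tokens = np.memmap(shard_path, dtype=np.int16, mode="r")
    mask = np.memmap(mask_path, dtype=np.uint8, mode="r")
    # Examples are packed back-to-back with no padding, so a window may
    # straddle an <|end|> boundary; the mask keeps prompt tokens out of
    # the loss regardless of where the slice lands.
    start = int(rng.integers(0, len(tokens) - seq_len - 1))
    x = tokens[start:start + seq_len].astype(np.int64)          # inputs
    y = tokens[start + 1:start + seq_len + 1].astype(np.int64)  # next-token targets
    w = mask[start + 1:start + seq_len + 1].astype(np.float32)  # per-target loss weight
    return x, y, w

For example, sample_sft_window(SFT_DIR / "shard_0000.bin", SFT_DIR / "mask_0000.bin", 512, np.random.default_rng(0)) would yield one 512-token training window with its loss weights.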
overlay/scripts/eval_quality.py CHANGED
@@ -1,525 +1,525 @@
 
1
+ #!/usr/bin/env python3
2
+ """Comprehensive quality evaluation harness for HYDRA.
3
+
4
+ Computes: PPL, BLEU-1, BLEU-4, ROUGE-1, ROUGE-L, factual accuracy,
5
+ coherence metrics (distinct-2, repetition-rate, self-BLEU), and a
6
+ composite quality_score.
7
+
8
+ Usage:
9
+ python scripts/eval_quality.py # eval latest model
10
+ python scripts/eval_quality.py --checkpoint ckpt.pt # eval from checkpoint
11
+
12
+ All metrics printed as key=value (grep-friendly). Runs in <30s on RTX 3060.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import math
18
+ import os
19
+ import sys
20
+ import time
21
+ from collections import Counter
22
+ from typing import Optional
23
+
24
+ # Ensure project root is on path
25
+ _PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
26
+ if _PROJECT_ROOT not in sys.path:
27
+ sys.path.insert(0, _PROJECT_ROOT)
28
+
29
+ import torch
30
+ import torch.nn.functional as F
31
+
32
+ from hydra.config import (
33
+ D_MODEL, D_STATE, DEVICE_BATCH_SIZE, ENGRAM_KEY_DIM,
34
+ ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND, HEADDIM,
35
+ N_HEADS, N_LAYER, PostSemClawConfig,
36
+ )
37
+ from hydra.eval import FACTUAL_EVAL
38
+ from prepare import MAX_SEQ_LEN, Tokenizer, evaluate_bpb
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Eval prompts (hardcoded for reproducibility)
42
+ # ---------------------------------------------------------------------------
43
+
44
+ EVAL_PROMPTS = [
45
+ "The capital of France is",
46
+ "In 1969, humans first",
47
+ "Water boils at a temperature of",
48
+ "The theory of relativity was developed by",
49
+ "The largest planet in our solar system is",
50
+ "Photosynthesis is the process by which",
51
+ "The stock market crashed in",
52
+ "DNA stands for",
53
+ "The speed of light is approximately",
54
+ "Shakespeare wrote the play",
55
+ "The mitochondria is often called the",
56
+ "In computer science, an algorithm is",
57
+ "The chemical symbol for gold is",
58
+ "The Great Wall of China was built to",
59
+ "Gravity is a force that",
60
+ "The human heart pumps blood through",
61
+ "The Amazon rainforest is located in",
62
+ "Pi is approximately equal to",
63
+ "The first President of the United States was",
64
+ "Oxygen makes up approximately",
65
+ ]
66
+
67
+ # Reference continuations (approximate, for BLEU/ROUGE)
68
+ EVAL_REFERENCES = [
69
+ "Paris, which is also the largest city in France.",
70
+ "landed on the Moon during the Apollo 11 mission.",
71
+ "100 degrees Celsius or 212 degrees Fahrenheit at standard atmospheric pressure.",
72
+ "Albert Einstein in the early twentieth century.",
73
+ "Jupiter, which is a gas giant.",
74
+ "plants convert sunlight into chemical energy and produce oxygen.",
75
+ "1929, leading to the Great Depression.",
76
+ "deoxyribonucleic acid, which carries genetic information.",
77
+ "299,792 kilometers per second in a vacuum.",
78
+ "Romeo and Juliet, one of the most famous tragedies.",
79
+ "powerhouse of the cell because it produces energy.",
80
+ "a step by step procedure for solving a problem.",
81
+ "Au, from the Latin word aurum.",
82
+ "protect against invasions from the north.",
83
+ "attracts objects with mass toward each other.",
84
+ "the circulatory system to deliver oxygen and nutrients.",
85
+ "South America, primarily within Brazil.",
86
+ "3.14159, and it represents the ratio of circumference to diameter.",
87
+ "George Washington, who served from 1789 to 1797.",
88
+ "21 percent of the Earth's atmosphere.",
89
+ ]
90
+
91
+ COHERENCE_PROMPTS = [
92
+ "The history of science shows that",
93
+ "In modern society, technology has",
94
+ "The relationship between education and",
95
+ "Climate change is affecting the world because",
96
+ "The development of artificial intelligence has led to",
97
+ "Throughout human history, art has been",
98
+ "The economy of a nation depends on",
99
+ "Medical research has shown that",
100
+ "The role of government in society is",
101
+ "The ocean covers more than",
102
+ ]
103
+
104
+
105
+ # ---------------------------------------------------------------------------
106
+ # Manual BLEU implementation (no nltk dependency)
107
+ # ---------------------------------------------------------------------------
108
+
109
+ def _get_ngrams(tokens: list[str], n: int) -> Counter:
110
+ """Extract n-gram counts from token list."""
111
+ return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))
112
+
113
+
114
+ def _modified_precision(reference_tokens: list[str], hypothesis_tokens: list[str], n: int) -> tuple[int, int]:
115
+ """Compute modified precision for n-grams."""
116
+ ref_ngrams = _get_ngrams(reference_tokens, n)
117
+ hyp_ngrams = _get_ngrams(hypothesis_tokens, n)
118
+ clipped_count = 0
119
+ total_count = 0
120
+ for ngram, count in hyp_ngrams.items():
121
+ clipped_count += min(count, ref_ngrams.get(ngram, 0))
122
+ total_count += count
123
+ return clipped_count, max(total_count, 1)
124
+
125
+
126
+ def compute_bleu(references: list[list[str]], hypotheses: list[list[str]], max_n: int = 4) -> dict[str, float]:
127
+ """Corpus-level BLEU-1 through BLEU-max_n.
128
+
129
+ Uses brevity penalty and geometric mean of modified precisions.
130
+ """
131
+ precisions = []
132
+ for n in range(1, max_n + 1):
133
+ total_clip = 0
134
+ total_count = 0
135
+ for ref, hyp in zip(references, hypotheses):
136
+ clip, count = _modified_precision(ref, hyp, n)
137
+ total_clip += clip
138
+ total_count += count
139
+ precisions.append(total_clip / max(total_count, 1))
140
+
141
+ # Brevity penalty
142
+ ref_len = sum(len(r) for r in references)
143
+ hyp_len = sum(len(h) for h in hypotheses)
144
+ if hyp_len == 0:
145
+ return {f"bleu{n}": 0.0 for n in range(1, max_n + 1)}
146
+ bp = math.exp(min(0, 1 - ref_len / hyp_len))
147
+
148
+ result = {}
149
+ for n in range(1, max_n + 1):
150
+ # Geometric mean of precisions 1..n
151
+ log_avg = sum(math.log(max(p, 1e-10)) for p in precisions[:n]) / n
152
+ result[f"bleu{n}"] = bp * math.exp(log_avg)
153
+ return result
154
+
155
+
156
+ # ---------------------------------------------------------------------------
157
+ # Manual ROUGE implementation (no rouge_score dependency)
158
+ # ---------------------------------------------------------------------------
159
+
160
+ def _lcs_length(x: list[str], y: list[str]) -> int:
161
+ """Longest common subsequence length via DP."""
162
+ m, n = len(x), len(y)
163
+ if m == 0 or n == 0:
164
+ return 0
165
+ # Space-optimized: only keep current and previous row
166
+ prev = [0] * (n + 1)
167
+ curr = [0] * (n + 1)
168
+ for i in range(1, m + 1):
169
+ for j in range(1, n + 1):
170
+ if x[i - 1] == y[j - 1]:
171
+ curr[j] = prev[j - 1] + 1
172
+ else:
173
+ curr[j] = max(prev[j], curr[j - 1])
174
+ prev, curr = curr, [0] * (n + 1)
175
+ return prev[n]
176
+
177
+
178
+ def compute_rouge(references: list[list[str]], hypotheses: list[list[str]]) -> dict[str, float]:
179
+ """Compute ROUGE-1 (unigram F1) and ROUGE-L (LCS-based F1)."""
180
+ rouge1_scores = []
181
+ rougel_scores = []
182
+
183
+ for ref, hyp in zip(references, hypotheses):
184
+ if not ref or not hyp:
185
+ rouge1_scores.append(0.0)
186
+ rougel_scores.append(0.0)
187
+ continue
188
+
189
+ # ROUGE-1: unigram overlap
190
+ ref_unigrams = Counter(ref)
191
+ hyp_unigrams = Counter(hyp)
192
+ overlap = sum((ref_unigrams & hyp_unigrams).values())
193
+ r1_precision = overlap / max(len(hyp), 1)
194
+ r1_recall = overlap / max(len(ref), 1)
195
+ r1_f1 = 2 * r1_precision * r1_recall / max(r1_precision + r1_recall, 1e-10)
196
+ rouge1_scores.append(r1_f1)
197
+
198
+ # ROUGE-L: LCS-based
199
+ lcs = _lcs_length(ref, hyp)
200
+ rl_precision = lcs / max(len(hyp), 1)
201
+ rl_recall = lcs / max(len(ref), 1)
202
+ rl_f1 = 2 * rl_precision * rl_recall / max(rl_precision + rl_recall, 1e-10)
203
+ rougel_scores.append(rl_f1)
204
+
205
+ return {
206
+ "rouge1": sum(rouge1_scores) / max(len(rouge1_scores), 1),
207
+ "rouge_l": sum(rougel_scores) / max(len(rougel_scores), 1),
208
+ }
209
+
210
+
211
+ # ---------------------------------------------------------------------------
212
+ # Greedy generation
213
+ # ---------------------------------------------------------------------------
214
+
215
+ @torch.no_grad()
216
+ def greedy_generate(model, tokenizer, prompt: str, max_new_tokens: int = 32, device: str = "cuda") -> str:
217
+ """Greedy (argmax) autoregressive generation. Deterministic."""
218
+ ids = tokenizer.encode(prompt)
219
+ x = torch.tensor([ids], device=device, dtype=torch.long)
220
+
221
+ for _ in range(max_new_tokens):
222
+ logits = model(x, targets=None)
223
+ if logits.dim() == 3:
224
+ next_logits = logits[0, -1, :]
225
+ else:
226
+ next_logits = logits[0]
227
+ next_id = next_logits.argmax().unsqueeze(0).unsqueeze(0)
228
+ x = torch.cat([x, next_id], dim=1)
229
+ if x.size(1) >= MAX_SEQ_LEN:
230
+ break
231
+
232
+ all_ids = x[0].tolist()
233
+ return tokenizer.decode(all_ids[len(ids):])
234
+
235
+
236
+ # ---------------------------------------------------------------------------
237
+ # Coherence metrics
238
+ # ---------------------------------------------------------------------------
239
+
240
+ def compute_coherence(generations: list[str]) -> dict[str, float]:
241
+ """Compute distinct-2, repetition rate, and self-BLEU across generations."""
242
+ all_bigrams = []
243
+ all_fourgrams = []
244
+ tokenized_gens = []
245
+
246
+ for gen in generations:
247
+ tokens = gen.lower().split()
248
+ tokenized_gens.append(tokens)
249
+ bigrams = [tuple(tokens[i:i + 2]) for i in range(len(tokens) - 1)]
250
+ fourgrams = [tuple(tokens[i:i + 4]) for i in range(len(tokens) - 3)]
251
+ all_bigrams.extend(bigrams)
252
+ all_fourgrams.extend(fourgrams)
253
+
254
+ # Distinct-2: fraction of unique bigrams
255
+ distinct2 = len(set(all_bigrams)) / max(len(all_bigrams), 1)
256
+
257
+ # Repetition rate: fraction of 4-grams that appear more than once
258
+ fourgram_counts = Counter(all_fourgrams)
259
+ repeated = sum(1 for c in fourgram_counts.values() if c > 1)
260
+ repetition_rate = repeated / max(len(fourgram_counts), 1)
261
+
262
+ # Self-BLEU: average BLEU of each generation against all others
263
+ # Lower = more diverse
264
+ self_bleu_scores = []
265
+ for i, hyp in enumerate(tokenized_gens):
266
+ if not hyp:
267
+ continue
268
+ others = [g for j, g in enumerate(tokenized_gens) if j != i and g]
269
+ if not others:
270
+ continue
271
+ # Average BLEU against each other generation
272
+ pair_scores = []
273
+ for ref in others:
274
+ result = compute_bleu([ref], [hyp], max_n=4)
275
+ pair_scores.append(result.get("bleu4", 0.0))
276
+ self_bleu_scores.append(sum(pair_scores) / len(pair_scores))
277
+
278
+ self_bleu = sum(self_bleu_scores) / max(len(self_bleu_scores), 1)
279
+
280
+ return {
281
+ "distinct2": distinct2,
282
+ "repetition_rate": repetition_rate,
283
+ "self_bleu": self_bleu,
284
+ }
285
+
286
+
287
+ # ---------------------------------------------------------------------------
288
+ # Factual accuracy (reuse existing probes)
289
+ # ---------------------------------------------------------------------------
290
+
291
+ def compute_factual(model, tokenizer, device: str = "cuda") -> float:
292
+ """Run factual eval probes, return accuracy [0,1]."""
293
+ model.eval()
294
+ hits = 0
295
+
296
+ with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
297
+ for prompt, answers in FACTUAL_EVAL:
298
+ ids = tokenizer.encode(prompt)
299
+ x = torch.tensor([ids], device=device, dtype=torch.long)
300
+ logits = model(x, targets=None)
301
+ if logits.dim() == 3:
302
+ last_logits = logits[0, -1, :]
303
+ else:
304
+ last_logits = logits[0]
305
+
306
+ probs = torch.softmax(last_logits.float(), dim=-1)
307
+ top_k = min(20, probs.shape[-1])
308
+ top_ids = torch.topk(probs, top_k).indices.tolist()
309
+ top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]
310
+ answers_lower = [a.lower() for a in answers]
311
+ if any(any(a in tok for a in answers_lower) for tok in top_tokens):
312
+ hits += 1
313
+
314
+ return hits / max(len(FACTUAL_EVAL), 1)
315
+
316
+
317
+ # ---------------------------------------------------------------------------
318
+ # PPL (perplexity) via existing evaluate_bpb
319
+ # ---------------------------------------------------------------------------
320
+
321
+ def compute_ppl(model, tokenizer, batch_size: int = 8) -> tuple[float, float]:
322
+ """Compute BPB and PPL. Returns (bpb, ppl)."""
323
+ import prepare as _prepare_mod
324
+ # Use smaller eval set for speed (<30s budget)
325
+ orig_eval = _prepare_mod.EVAL_TOKENS
326
+ _prepare_mod.EVAL_TOKENS = 2 * 524288 # ~1M tokens, fast
327
+ try:
328
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
329
+ bpb = evaluate_bpb(model, tokenizer, batch_size)
330
+ finally:
331
+ _prepare_mod.EVAL_TOKENS = orig_eval
332
+ ppl = 2 ** bpb
333
+ return bpb, ppl
334
+
335
+
336
+ # ---------------------------------------------------------------------------
337
+ # Composite quality score
338
+ # ---------------------------------------------------------------------------
339
+
340
+ def compute_quality_score(ppl: float, bleu4: float, rouge_l: float,
341
+ factual: float, repetition_rate: float) -> float:
342
+ """Single composite metric for autoresearch optimization.
343
+
344
+ Formula rationale:
345
+ - PPL (30%): Primary language modeling metric, capped at 100
346
+ - BLEU-4 (20%): Generation quality vs references
347
+ - ROUGE-L (20%): Recall of reference content
348
+ - Factual (15%): Knowledge memorization
349
+ - 1-repetition (15%): Diversity/coherence
350
+ """
351
+ return (
352
+ 0.3 * (1 - min(ppl, 100) / 100) +
353
+ 0.2 * bleu4 +
354
+ 0.2 * rouge_l +
355
+ 0.15 * factual +
356
+ 0.15 * (1 - repetition_rate)
357
+ )
358
+
359
+
360
+ # ---------------------------------------------------------------------------
361
+ # Main evaluation entry point
362
+ # ---------------------------------------------------------------------------
363
+
364
+ def run_quality_eval(
+     model: torch.nn.Module,
+     tokenizer,
+     device: str = "cuda",
+     batch_size: int = 8,
+     verbose: bool = True,
+ ) -> dict[str, float]:
+     """Run full quality evaluation suite. Returns dict of all metrics."""
+     model.eval()
+     results: dict[str, float] = {}
+
+     t0 = time.time()
+
+     # 1. PPL / BPB
+     if verbose:
+         print("[eval] Computing PPL/BPB...", flush=True)
+     bpb, ppl = compute_ppl(model, tokenizer, batch_size)
+     results["bpb"] = bpb
+     results["ppl"] = ppl
+
+     # 2. Generate continuations for BLEU/ROUGE
+     if verbose:
+         print("[eval] Generating continuations (20 prompts, greedy)...", flush=True)
+     hypotheses_text = []
+     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+         for prompt in EVAL_PROMPTS:
+             gen = greedy_generate(model, tokenizer, prompt, max_new_tokens=32, device=device)
+             hypotheses_text.append(gen)
+
+     # Tokenize for BLEU/ROUGE (simple whitespace split)
+     ref_tokens = [ref.lower().split() for ref in EVAL_REFERENCES]
+     hyp_tokens = [hyp.lower().split() for hyp in hypotheses_text]
+
+     # 3. BLEU
+     if verbose:
+         print("[eval] Computing BLEU...", flush=True)
+     bleu = compute_bleu(ref_tokens, hyp_tokens, max_n=4)
+     results["bleu1"] = bleu["bleu1"]
+     results["bleu4"] = bleu["bleu4"]
+
+     # 4. ROUGE
+     if verbose:
+         print("[eval] Computing ROUGE...", flush=True)
+     rouge = compute_rouge(ref_tokens, hyp_tokens)
+     results["rouge1"] = rouge["rouge1"]
+     results["rouge_l"] = rouge["rouge_l"]
+
+     # 5. Factual accuracy
+     if verbose:
+         print("[eval] Computing factual accuracy...", flush=True)
+     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+         factual = compute_factual(model, tokenizer, device)
+     results["factual"] = factual
+
+     # 6. Coherence
+     if verbose:
+         print("[eval] Generating coherence passages (10 prompts, 64 tokens)...", flush=True)
+     coherence_gens = []
+     with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
+         for prompt in COHERENCE_PROMPTS:
+             gen = greedy_generate(model, tokenizer, prompt, max_new_tokens=64, device=device)
+             coherence_gens.append(gen)
+
+     coherence = compute_coherence(coherence_gens)
+     results["distinct2"] = coherence["distinct2"]
+     results["repetition_rate"] = coherence["repetition_rate"]
+     results["self_bleu"] = coherence["self_bleu"]
+
+     # 7. Composite score
+     results["quality_score"] = compute_quality_score(
+         ppl=results["ppl"],
+         bleu4=results["bleu4"],
+         rouge_l=results["rouge_l"],
+         factual=results["factual"],
+         repetition_rate=results["repetition_rate"],
+     )
+
+     elapsed = time.time() - t0
+     results["eval_time_s"] = elapsed
+
+     # Print all metrics
+     if verbose:
+         print("\n--- Quality Evaluation Results ---")
+         for k, v in sorted(results.items()):
+             print(f"{k}={v:.6f}")
+         print("--- End Quality Evaluation ---\n")
+
+     # Print sample generations
+     print("--- Sample Generations ---")
+     for i, (prompt, gen) in enumerate(zip(EVAL_PROMPTS[:5], hypotheses_text[:5])):
+         print(f' [{i}] "{prompt}" -> "{gen.strip()[:80]}"')
+     print("--- End Sample Generations ---\n")
+
+     print("--- Coherence Samples ---")
+     for i, (prompt, gen) in enumerate(zip(COHERENCE_PROMPTS[:3], coherence_gens[:3])):
+         print(f' [{i}] "{prompt}" -> "{gen.strip()[:100]}"')
+     print("--- End Coherence Samples ---\n")
+
+     return results
+
+
+ # ---------------------------------------------------------------------------
+ # Standalone CLI
+ # ---------------------------------------------------------------------------
+
+ def _build_model_and_tokenizer(checkpoint: Optional[str] = None):
+     """Build model + tokenizer, optionally loading from checkpoint."""
+     from hydra.model import PostSemClawModel
+
+     device = torch.device("cuda")
+     tokenizer = Tokenizer.from_directory()
+     vocab_size = tokenizer.get_vocab_size()
+
+     config = PostSemClawConfig(
+         sequence_len=MAX_SEQ_LEN,
+         vocab_size=vocab_size,
+         n_layer=N_LAYER,
+         d_model=D_MODEL,
+         d_state=D_STATE,
+         headdim=HEADDIM,
+         n_heads=N_HEADS,
+         expand=EXPAND,
+         engram_n_columns=ENGRAM_N_COLUMNS,
+         engram_key_dim=ENGRAM_KEY_DIM,
+         engram_layer_idx=ENGRAM_LAYER_IDX,
+     )
+
+     with torch.device("meta"):
+         model = PostSemClawModel(config)
+     model.to_empty(device=device)
+
+     if checkpoint and os.path.exists(checkpoint):
+         print(f"[eval] Loading checkpoint: {checkpoint}")
+         state = torch.load(checkpoint, map_location=device, weights_only=True)
+         model.load_state_dict(state, strict=False)
+     else:
+         print("[eval] No checkpoint — using freshly initialized weights")
+         model.init_weights()
+
+     model.eval()
+     return model, tokenizer, device
+
+
+ def main():
+     import argparse
+     parser = argparse.ArgumentParser(description="HYDRA quality evaluation")
+     parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint")
+     parser.add_argument("--batch-size", type=int, default=DEVICE_BATCH_SIZE, help="Batch size for PPL eval")
+     args = parser.parse_args()
+
+     model, tokenizer, device = _build_model_and_tokenizer(args.checkpoint)
+     results = run_quality_eval(model, tokenizer, str(device), args.batch_size, verbose=True)
+
+     # Final summary line (grep-friendly)
+     print(f"QUALITY_SCORE={results['quality_score']:.6f} PPL={results['ppl']:.3f} "
+           f"BPB={results['bpb']:.4f} BLEU4={results['bleu4']:.4f} "
+           f"ROUGE_L={results['rouge_l']:.4f} FACTUAL={results['factual']:.4f} "
+           f"REP_RATE={results['repetition_rate']:.4f}")
+
+
+ if __name__ == "__main__":
+     main()
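Aside: the summary line printed by main() above is deliberately grep-friendly, a single line of space-separated KEY=VALUE pairs. A minimal sketch (not part of the diff) of how a driver might parse it back into floats; the log-file path in the usage comment is hypothetical:

# Sketch: parse "QUALITY_SCORE=0.123456 PPL=24.510 ..." into a float dict.
def parse_summary_line(line: str) -> dict[str, float]:
    out: dict[str, float] = {}
    for pair in line.strip().split():
        key, _, value = pair.partition("=")
        if key and value:
            out[key] = float(value)
    return out

# Usage (hypothetical log path): keep the last summary line of a job log.
# with open("job.log") as fh:
#     summaries = [l for l in fh if l.startswith("QUALITY_SCORE=")]
# metrics = parse_summary_line(summaries[-1])
# print(metrics["PPL"], metrics["BLEU4"])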
overlay/scripts/fetch_corpus.py CHANGED
@@ -1,211 +1,211 @@
"""
Fetch additional training shards from karpathy/climbmix-400b-shuffle.

The repo already has ~500 shards (~31B tokens). This script is a
resumable, parallel downloader for cases where more shards are needed
(e.g., multi-day training, experiments requiring fresh-unseen data,
or when we want to split the corpus across processes).

Usage:
    # Fetch shards up to index 600 (total cap)
    python scripts/fetch_corpus.py --target-shards 600

    # Fetch a specific range
    python scripts/fetch_corpus.py --start 500 --end 800

    # Dry-run (list what would be downloaded)
    python scripts/fetch_corpus.py --target-shards 600 --dry-run

Notes:
    - Safe to run while training is active; only writes files not touched
      by the training process.
    - Resumable: skips shards already on disk.
    - Downloads to the same DATA_DIR used by prepare.py so they're picked
      up on next training launch.
"""
from __future__ import annotations

import argparse
import os
import shutil
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import requests

REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

from prepare import BASE_URL, DATA_DIR, MAX_SHARD, VAL_SHARD  # noqa: E402


def human_bytes(n: int) -> str:
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if n < 1024:
            return f"{n:.1f}{unit}"
        n /= 1024
    return f"{n:.1f}PB"


def download_one(
    index: int, data_dir: str, timeout: int = 30, max_attempts: int = 5
) -> tuple[int, bool, int, str]:
    """
    Download a single parquet shard. Resumable + retry with exponential backoff.
    Returns (index, success, bytes_written, message).
    """
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(data_dir, filename)
    tmp_path = filepath + ".tmp"

    if os.path.exists(filepath):
        return index, True, 0, "already-present"

    url = f"{BASE_URL}/{filename}"
    for attempt in range(1, max_attempts + 1):
        try:
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                bytes_written = 0
                with open(tmp_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        if chunk:
                            f.write(chunk)
                            bytes_written += len(chunk)
            os.rename(tmp_path, filepath)
            return index, True, bytes_written, f"ok (attempt {attempt})"
        except (requests.RequestException, OSError) as e:
            # Clean up partial file.
            for p in (tmp_path, filepath):
                if os.path.exists(p):
                    try:
                        os.remove(p)
                    except OSError:
                        pass
            if attempt < max_attempts:
                wait = 2 ** attempt
                time.sleep(wait)
                continue
            return index, False, 0, f"failed after {max_attempts} attempts: {e}"

    return index, False, 0, "unknown failure"


def check_disk_space(required_bytes: int, data_dir: str) -> tuple[bool, int]:
    """Ensure we have at least required_bytes + 10% headroom free."""
    os.makedirs(data_dir, exist_ok=True)
    stats = shutil.disk_usage(data_dir)
    headroom = int(required_bytes * 1.1)
    return stats.free >= headroom, stats.free


def main() -> int:
    parser = argparse.ArgumentParser(
        description="Fetch additional climbmix-400b-shuffle shards"
    )
    parser.add_argument(
        "--target-shards",
        type=int,
        default=None,
        help="Total train-shard count to reach (0..target-1). Mutually exclusive with --start/--end.",
    )
    parser.add_argument("--start", type=int, default=None, help="Starting shard index (inclusive)")
    parser.add_argument("--end", type=int, default=None, help="Ending shard index (exclusive)")
    parser.add_argument("--workers", type=int, default=8, help="Parallel download workers")
    parser.add_argument(
        "--include-val",
        action="store_true",
        help="Also fetch the pinned validation shard (normally present already)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="List what would be downloaded without fetching",
    )
    args = parser.parse_args()

    # Resolve shard range.
    if args.target_shards is not None:
        if args.start is not None or args.end is not None:
            print("ERROR: --target-shards is exclusive with --start/--end")
            return 1
        ids = list(range(min(args.target_shards, MAX_SHARD)))
    else:
        start = args.start or 0
        end = args.end if args.end is not None else MAX_SHARD
        end = min(end, MAX_SHARD)
        ids = list(range(start, end))

    if args.include_val and VAL_SHARD not in ids:
        ids.append(VAL_SHARD)

    os.makedirs(DATA_DIR, exist_ok=True)
    present = set()
    for p in Path(DATA_DIR).glob("shard_*.parquet"):
        try:
            idx = int(p.stem.split("_")[1])
            present.add(idx)
        except (IndexError, ValueError):
            continue

    to_fetch = [i for i in ids if i not in present]
    if not to_fetch:
        print(f"All {len(ids)} shards already present at {DATA_DIR}")
        return 0

    # Estimate space: shards are ~88MB; leave 10% headroom.
    avg_shard_bytes = 90 * (1 << 20)  # 90MB
    required = avg_shard_bytes * len(to_fetch)
    ok, free = check_disk_space(required, DATA_DIR)
    print(f"Plan: fetch {len(to_fetch)} shards (~{human_bytes(required)}); "
          f"disk free: {human_bytes(free)}")
    if not ok:
        print("ERROR: insufficient disk space (need 1.1x required)")
        return 2

    if args.dry_run:
        preview = to_fetch[:10]
        print(
            f"Dry-run — would fetch {len(to_fetch)} shards. First {len(preview)}: {preview}"
        )
        return 0

    print(f"Downloading {len(to_fetch)} shards with {args.workers} workers...")
    t_start = time.time()
    success = 0
    failed = 0
    total_bytes = 0

    with ThreadPoolExecutor(max_workers=args.workers) as ex:
        futs = {ex.submit(download_one, i, DATA_DIR): i for i in to_fetch}
        for fut in as_completed(futs):
            idx, ok, nbytes, msg = fut.result()
            if ok:
                success += 1
                total_bytes += nbytes
                if success % 10 == 0 or success == len(to_fetch):
                    elapsed = time.time() - t_start
                    rate = total_bytes / max(elapsed, 1)
                    print(
                        f" [{success}/{len(to_fetch)}] shard_{idx:05d} ok "
                        f"({human_bytes(total_bytes)} @ {human_bytes(int(rate))}/s)"
                    )
            else:
                failed += 1
                print(f" [FAIL] shard_{idx:05d}: {msg}")

    elapsed = time.time() - t_start
    print()
    print("=" * 60)
    print(f"Downloaded {success}/{len(to_fetch)} shards in {elapsed:.1f}s")
    print(f"Failed: {failed}")
    print(f"Total bytes: {human_bytes(total_bytes)}")
    print("=" * 60)

    return 0 if failed == 0 else 3


if __name__ == "__main__":
    raise SystemExit(main())
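Aside: a quick check (illustration only) of the backoff schedule in download_one above. The loop sleeps wait = 2 ** attempt before each retry, and only for attempt < max_attempts, so with the default max_attempts=5 a fully failing shard waits 2 + 4 + 8 + 16 = 30 seconds in total before the final error is returned:

max_attempts = 5
# Sleeps happen after attempts 1..4; attempt 5 returns the failure directly.
waits = [2 ** attempt for attempt in range(1, max_attempts)]
assert waits == [2, 4, 8, 16]
assert sum(waits) == 30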
overlay/scripts/grad_probe.py CHANGED
@@ -1,196 +1,196 @@
"""
Gradient flow probe for PostSemClawModel.

READ-ONLY diagnostic. Does NOT modify any source, does NOT train, does NOT
step an optimizer. Runs one forward + backward and reports, per-parameter:

    name, shape, dtype, requires_grad, grad-is-None?, |grad|.mean, |grad|.norm

Severity classification at the bottom:
    BLOCKER — requires_grad=True but p.grad is None (disconnected from graph)
    WARNING — grad present but literally zero (ops cancel, wd_init, etc.)
    WARNING — requires_grad=True but param missing from every optimizer group
    OK — everything else

Usage:
    .venv/bin/python -u scripts/grad_probe.py
"""

from __future__ import annotations

import os
import sys
from pathlib import Path

# Ensure the project root is on sys.path (so `train`, `subsystems`, `prepare`
# resolve when we run from any cwd). Probe is intentionally a thin wrapper.
HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, str(ROOT))

# Small model config to keep the probe fast (still exercises every component).
# K=4 MTP (default), d_model=256 (default), n_layer=4 (default).
os.environ.setdefault("HYDRA_D_MODEL", "256")
os.environ.setdefault("HYDRA_N_LAYER", "4")
os.environ.setdefault("HYDRA_MTP_K", "4")

import torch  # noqa: E402

from train import PostSemClawModel, PostSemClawConfig  # noqa: E402


def main() -> int:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("ERROR: CUDA required (model has mamba-ssm + bf16 autocast path).")
        return 2

    cfg = PostSemClawConfig(
        sequence_len=64,
        vocab_size=8192,
        n_layer=int(os.environ["HYDRA_N_LAYER"]),
        d_model=int(os.environ["HYDRA_D_MODEL"]),
        d_state=64,
        headdim=32,
        n_heads=8,
        expand=2,
        engram_n_columns=1024,
        engram_key_dim=64,
        engram_layer_idx=1,
        sdr_n_bits=16384,
        sdr_target_active=327,
        sdr_delta_rank=32,
        sdr_som_warmup=500,
        sdr_som_interval=100,
        htm_n_columns=2048,
        htm_cells_per_column=32,
        mtp_k=int(os.environ["HYDRA_MTP_K"]),
        mtp_weight_decay=0.5,
    )

    print(f"[probe] config: d_model={cfg.d_model} n_layer={cfg.n_layer} "
          f"mtp_k={cfg.mtp_k} vocab={cfg.vocab_size}")

    torch.manual_seed(0)
    model = PostSemClawModel(cfg).to(device)
    model.init_weights()
    model.train()

    # ---- Enumerate params & optimizer group assignment ----
    all_params = list(model.named_parameters())
    print(f"[probe] total named parameters: {len(all_params)}")

    # Build optimizer to check group coverage (no step, no zero_grad).
    opt = model.setup_optimizer()
    grouped_ids: set[int] = set()
    for group in opt.param_groups:
        for p in group["params"]:
            grouped_ids.add(id(p))
    unique_param_ids = {id(p) for _, p in all_params}
    missing_from_opt = unique_param_ids - grouped_ids
    print(f"[probe] params in opt groups: {len(grouped_ids)} / unique: {len(unique_param_ids)}")
    if missing_from_opt:
        print(f"[probe] WARNING: {len(missing_from_opt)} unique params missing from opt groups")

    # Tied weight check.
    tied = model.wte.weight.data_ptr() == model.lm_head.weight.data_ptr()
    print(f"[probe] tied lm_head<->wte (data_ptr match): {tied}")

    # ---- One forward + backward under bf16 autocast ----
    B, T = 1, 64
    idx = torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long, device=device)
    tgt = torch.roll(idx, -1, dims=1)

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, targets=tgt)
    print(f"[probe] fwd loss = {float(loss.detach()):.4f}")
    loss.backward()
    torch.cuda.synchronize()

    # ---- Report ----
    blockers: list[str] = []
    zero_grads: list[str] = []
    unexpected_frozen: list[str] = []
    not_in_opt: list[str] = []
    rows: list[tuple[str, tuple, str, bool, bool, float, float]] = []

    for name, p in all_params:
        grad_is_none = p.grad is None
        if p.requires_grad and grad_is_none:
            blockers.append(name)
            rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                         p.requires_grad, True, float("nan"), float("nan")))
            continue
        if not p.requires_grad:
            unexpected_frozen.append(name)
            rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                         False, True, float("nan"), float("nan")))
            continue
        g = p.grad.detach().float()
        abs_mean = float(g.abs().mean().item())
        norm = float(g.norm().item())
        if abs_mean == 0.0 and norm == 0.0:
            zero_grads.append(name)
        if id(p) not in grouped_ids:
            not_in_opt.append(name)
        rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                     p.requires_grad, False, abs_mean, norm))

    # Pretty table
    print("\n[probe] per-parameter grad table:")
    print(f" {'name':<56} {'shape':<22} {'dtype':<8} rg none {'|g|.mean':>10} {'|g|.norm':>10}")
    for name, shape, dtype, rg, none, mean, norm in rows:
        shape_s = "x".join(str(s) for s in shape)
        rg_s = "Y" if rg else "N"
        none_s = "Y" if none else "N"
        if none:
            mean_s, norm_s = " nan ", " nan "
        else:
            mean_s = f"{mean:>10.3e}"
            norm_s = f"{norm:>10.3e}"
        print(f" {name:<56} {shape_s:<22} {dtype:<8} {rg_s} {none_s} {mean_s} {norm_s}")

    # Identity checks
    print("\n[probe] identity checks:")
    print(f" id(wte.weight) = {id(model.wte.weight)}")
    print(f" id(lm_head.weight) = {id(model.lm_head.weight)}")
    print(f" same Python object = {model.wte.weight is model.lm_head.weight}")
    print(f" same storage ptr = {tied}")

    # Engram memory inspection
    print(f"\n[probe] engram.memory is nn.Parameter: "
          f"{isinstance(model.engram.memory, torch.nn.Parameter)}")
    print(f" engram.memory.requires_grad = {model.engram.memory.requires_grad}")
    if model.engram.memory.grad is None:
        print(" engram.memory.grad = None (Hebbian-only path; no autograd through detach())")
    else:
        g = model.engram.memory.grad.detach().float()
        print(f" engram.memory.grad |.mean| = {float(g.abs().mean()):.3e}")

    # Stash flag sanity: _last_sdr should be uint8, no graph
    last = getattr(model, "_last_sdr", None)
    if last is not None:
        print(f"\n[probe] model._last_sdr dtype={last.dtype}, requires_grad={last.requires_grad}")
    else:
        print("\n[probe] model._last_sdr is None (fwd didn't stash — ok if path changed)")

    # Summary
    print("\n[probe] ============ SUMMARY ============")
    print(f" BLOCKERS (requires_grad but grad is None): {len(blockers)}")
    for n in blockers:
        print(f"   - {n}")
    print(f" WARNINGS (grad is literally zero): {len(zero_grads)}")
    for n in zero_grads:
        print(f"   - {n}")
    print(f" WARNINGS (requires_grad=False): {len(unexpected_frozen)}")
    for n in unexpected_frozen:
        print(f"   - {n}")
    print(f" WARNINGS (missing from every opt group): {len(not_in_opt)}")
    for n in not_in_opt:
        print(f"   - {n}")

    return 0


if __name__ == "__main__":
    sys.exit(main())
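Aside: a minimal standalone illustration (assumes only torch) of the two tied-weight checks the probe prints, `is` for Python-object identity and Tensor.data_ptr() for shared underlying storage:

import torch

emb = torch.nn.Embedding(8, 4)
head = torch.nn.Linear(4, 8, bias=False)
head.weight = emb.weight  # tie, the same idiom used for wte <-> lm_head

assert head.weight is emb.weight                        # same Python object
assert head.weight.data_ptr() == emb.weight.data_ptr()  # same storage start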
overlay/scripts/launch_feather_hf_job.py CHANGED
@@ -211,9 +211,12 @@ def main() -> int:
     if not USE_SPACE_IMAGE:
         print(f'[launch] image={DEFAULT_IMAGE}', flush=True)
 
+    fast_start_streaming = should_enable_fast_start_streaming(TARGET_SHARDS, TIME_BUDGET)
     if DRY_RUN:
-        if 'HYDRA_USE_NEMOTRON' not in os.environ and should_enable_fast_start_streaming(TARGET_SHARDS, TIME_BUDGET):
+        if 'HYDRA_USE_NEMOTRON' not in os.environ and fast_start_streaming:
             print('[launch] auto-enabled HYDRA_USE_NEMOTRON=1 for short-budget fast-start profile', flush=True)
+        if 'HYDRA_LOCAL_SHARDS_ONLY' not in os.environ and fast_start_streaming:
+            print('[launch] auto-enabled HYDRA_LOCAL_SHARDS_ONLY=0 for Nemotron streaming fast-start profile', flush=True)
         print('[launch] dry-run mode; skipping repo creation, upload, and job submission', flush=True)
         return 0
 
@@ -277,9 +280,12 @@ def main() -> int:
         'TRITON_CACHE_DIR': f'/workspace/triton_cache/{GPU_PROFILE}',
         'TRITON_CACHE_REPO': f'{routing.owner}/feather-triton-cache-{GPU_PROFILE}',
     }
-    if 'HYDRA_USE_NEMOTRON' not in os.environ and should_enable_fast_start_streaming(TARGET_SHARDS, TIME_BUDGET):
+    if 'HYDRA_USE_NEMOTRON' not in os.environ and fast_start_streaming:
        env['HYDRA_USE_NEMOTRON'] = '1'
        print('[launch] auto-enabled HYDRA_USE_NEMOTRON=1 for short-budget fast-start profile', flush=True)
+    if 'HYDRA_LOCAL_SHARDS_ONLY' not in os.environ and fast_start_streaming:
+        env['HYDRA_LOCAL_SHARDS_ONLY'] = '0'
+        print('[launch] auto-enabled HYDRA_LOCAL_SHARDS_ONLY=0 for Nemotron streaming fast-start profile', flush=True)
     # A10 compatibility profile: avoid known PTX/compile runtime pitfalls and
     # keep throughput path enabled. Caller can explicitly override each key by
     # setting it in the parent environment.
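Aside: the pattern in this diff is override-friendly env gating, auto-set a default only when the caller has not pinned the key in the parent environment. A minimal sketch of that pattern; `fast_start_streaming` stands in for should_enable_fast_start_streaming(...), whose definition lives outside this hunk:

import os

def apply_auto_env(env: dict[str, str], fast_start_streaming: bool) -> None:
    """Auto-enable fast-start defaults unless the caller pinned them."""
    if not fast_start_streaming:
        return
    auto = {"HYDRA_USE_NEMOTRON": "1", "HYDRA_LOCAL_SHARDS_ONLY": "0"}
    for key, value in auto.items():
        if key not in os.environ:  # explicit caller override wins
            env[key] = value
            print(f"[launch] auto-enabled {key}={value} for fast-start profile", flush=True)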
overlay/scripts/profile_forward.py CHANGED
@@ -1,87 +1,87 @@
"""Per-subsystem timing to find the tok/s bottleneck.

Runs a single forward+backward at (B=8, T=2048) and times each stage via
torch.cuda.Event. Reports ms/stage and derived tok/s budget.
"""
import os, sys, time
os.environ.setdefault("LD_LIBRARY_PATH", "/usr/lib/wsl/lib:/usr/local/cuda/lib64")
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from train import PostSemClawModel, PostSemClawConfig, MAX_SEQ_LEN

B, T = 8, MAX_SEQ_LEN

def timeit(name, fn, warmup=1, n=3):
    for _ in range(warmup):
        fn(); torch.cuda.synchronize()
    s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(n):
        torch.cuda.synchronize()
        s.record(); fn(); e.record(); torch.cuda.synchronize()
        times.append(s.elapsed_time(e))
    avg = sum(times)/len(times)
    print(f" {name:30s} {avg:8.2f} ms (min {min(times):.2f} max {max(times):.2f})")
    return avg

cfg = PostSemClawConfig()
model = PostSemClawModel(cfg).cuda()
model.init_weights()
model.train()
idx = torch.randint(0, cfg.vocab_size, (B, T), device="cuda", dtype=torch.long)
y = idx.clone()

print(f"== Profile at B={B} T={T} n_params={sum(p.numel() for p in model.parameters())/1e6:.1f}M ==\n")

# Warmup full forward
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    _ = model(idx, y)
torch.cuda.synchronize()

print("Stage times (3 iter avg):\n")

# 1) wte
timeit("wte embedding", lambda: model.wte(idx).sum().item())

# 2) sdr_semantic (STE forward)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("sdr_semantic forward STE", lambda: model.sdr_semantic(idx).sum().item())

# 3) sdr binary_only
timeit("sdr binary_only", lambda: model.sdr_semantic.binary_only(idx).sum().item())

# 4) HTM full forward (with reset/learn)
with torch.no_grad():
    timeit("HTM forward (B=8, T=2048)", lambda: model.htm(model.sdr_semantic.binary_only(idx)).sum().item())

# 5) Mamba block stack only
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    def _blocks():
        x = model.wte(idx)
        from train import norm
        x = norm(x)
        streams = model.mhc[0].init_streams(x)
        for i, (block, mhc_layer) in enumerate(zip(model.blocks, model.mhc)):
            def _bfn(h, _b=block): return _b(norm(h))
            streams = mhc_layer(streams, _bfn)
        x = model.mhc[-1].merge_streams(streams)
        return x.sum().item()
    timeit("Mamba+mHC blocks (n_layer=4)", _blocks)

# 6) Full forward+loss
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("FULL forward+loss", lambda: model(idx, y).item())

# 7) Full forward+loss+backward
def full_fwd_bwd():
    model.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, y)
    loss.backward()
    return loss.item()
t_full = timeit("FULL forward+backward", full_fwd_bwd)

print()
print(f"FULL step (fwd+bwd): {t_full:.0f} ms for B*T = {B*T} tokens")
print(f"tok/s per forward: {B*T / (t_full/1000):.0f}")
print(f"Expected @MFU=20% on RTX3060 (~25 TFLOPS bf16): ~{25e12*0.2 / (6*7.5e6) / 1000:.0f}k tok/s")
overlay/scripts/run_domain_expanded_pretrain.sh CHANGED
@@ -188,11 +188,7 @@ fi
 
 RESUME_PATH="$(resolve_resume_path || true)"
 
-# Only inject WSL library paths when running on WSL. Cloud containers
-# (H200/A10G HF Jobs) already have their driver paths set by entrypoint.py.
-if [[ -d /usr/lib/wsl/lib ]]; then
-  export LD_LIBRARY_PATH="/usr/lib/wsl/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
-fi
+export LD_LIBRARY_PATH="/usr/lib/wsl/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
 export HYDRA_TIME_BUDGET="${HYDRA_TIME_BUDGET:-28800}"
 export HYDRA_TARGET_SHARDS="$TARGET_SHARDS"
 export HYDRA_DOWNLOAD_WORKERS="$DOWNLOAD_WORKERS"
overlay/scripts/sample_utils.py CHANGED
@@ -1,107 +1,107 @@
"""Shared sampling utilities for chat.py / chat_eval.py.

Pure functions: given a 1-D logits tensor (vocab_size,), return a single
sampled token id. No model/tokenizer knowledge here.
"""

from __future__ import annotations

from typing import Iterable, Optional

import torch


def apply_repetition_penalty(
    logits: torch.Tensor,
    recent_tokens: Optional[Iterable[int]],
    penalty: float,
) -> torch.Tensor:
    """Divide logits of recent positive tokens by `penalty`, multiply negatives.

    Operates in-place on a *copy* (logits is cloned first by caller if needed).
    `recent_tokens` may be any iterable of ints; duplicates are deduped internally.
    """
    if penalty == 1.0 or not recent_tokens:
        return logits
    seen = set(int(t) for t in recent_tokens)
    if not seen:
        return logits
    idx = torch.tensor(list(seen), device=logits.device, dtype=torch.long)
    vals = logits.index_select(0, idx)
    vals = torch.where(vals > 0, vals / penalty, vals * penalty)
    logits.index_copy_(0, idx, vals)
    return logits


def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep only the top-k logits; set the rest to -inf.

    top_k<=0 or top_k>=vocab disables the filter."""
    if top_k <= 0 or top_k >= logits.size(-1):
        return logits
    topk_vals, topk_idx = logits.topk(top_k)
    mask = torch.full_like(logits, float("-inf"))
    mask.scatter_(0, topk_idx, topk_vals)
    return mask


def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus filter: keep smallest set of tokens whose cumulative prob >= top_p."""
    if top_p >= 1.0 or top_p <= 0.0:
        return logits
    sorted_logits, sorted_idx = logits.sort(descending=True)
    cumulative_probs = sorted_logits.softmax(-1).cumsum(-1)
    mask = cumulative_probs > top_p
    # shift right so we always keep at least one token
    mask[1:] = mask[:-1].clone()
    mask[0] = False
    sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
    out = torch.full_like(logits, float("-inf"))
    out.scatter_(0, sorted_idx, sorted_logits)
    return out


def sample_token(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    recent_tokens: Optional[Iterable[int]] = None,
) -> int:
    """Return a single sampled token id (Python int).

    logits: 1-D float tensor of shape (vocab_size,). fp32 or upcast-safe.
    """
    if logits.dim() != 1:
        raise ValueError(f"sample_token expects 1-D logits, got shape {tuple(logits.shape)}")

    # Work in fp32 on a clone so the caller's tensor is unchanged.
    work = logits.detach().to(torch.float32).clone()

    if repetition_penalty != 1.0 and recent_tokens is not None:
        work = apply_repetition_penalty(work, recent_tokens, repetition_penalty)

    # Temperature. Greedy when temperature <= 0.
    if temperature <= 0.0:
        return int(work.argmax().item())
    work = work / max(temperature, 1e-6)

    work = apply_top_k(work, top_k)
    work = apply_top_p(work, top_p)

    # Guard against all-(-inf) (can happen if top_k/top_p filter everything out).
    if torch.isinf(work).all():
        return int(logits.argmax().item())

    probs = torch.softmax(work, dim=-1)
    # Numerical safety — replace any NaN with 0 and renormalize.
    if torch.isnan(probs).any():
        probs = torch.nan_to_num(probs, nan=0.0)
        s = probs.sum()
        if s <= 0:
            return int(logits.argmax().item())
        probs = probs / s

    tok = torch.multinomial(probs, num_samples=1)
    return int(tok.item())
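Aside: a small usage sketch for apply_top_p above. The shift-right on the cutoff mask is what guarantees at least one token survives, even when top_p is smaller than the probability of the most likely token:

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
filtered = apply_top_p(logits, top_p=0.05)
kept = torch.isfinite(filtered).sum().item()
assert kept == 1                    # only the single most likely token remains
assert torch.isfinite(filtered[0])  # and it is the argmax (index 0)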
overlay/scripts/sft.py CHANGED
@@ -1,559 +1,559 @@
1
- """HYDRA SFT β€” instruction fine-tune the pretrained 7.5M-param base.
2
-
3
- Mode selection:
4
- MODE=resume_from_pretrain iff ~/.cache/autoresearch/pretrain_final.pt
5
- exists AND loads cleanly into a fresh model.
6
- MODE=from_scratch otherwise (degraded fallback).
7
-
8
- Data: int16 shards written by `scripts/download_sft_data.py`, paired with
9
- uint8 loss-mask shards (1 on assistant tokens, 0 on user-prompt tokens).
10
- At runtime we pack consecutive examples into fixed-length rows; prompt
11
- positions get target=-1 so CE's `ignore_index=-1` drops them.
12
-
13
- Env vars (with defaults tuned for RTX 3060 6GB, 7.5M params):
14
- HYDRA_SFT_TIME_BUDGET 10800 SFT wall-clock budget (3h)
15
- HYDRA_SFT_SEQ_LEN 512 sequence length during SFT
16
- HYDRA_BATCH_SIZE 4 per-step device batch
17
- HYDRA_TOTAL_BATCH 8192 effective batch (grad-accum derived)
18
- HYDRA_SFT_LR_MULT 0.10 multiply pretrain LRs by this
19
- HYDRA_SFT_EVAL_INTERVAL 500 steps between sample generations
20
- HYDRA_SFT_CKPT_INTERVAL 2000 steps between interim checkpoints
21
-
22
- CLI:
23
- --dry-run load model+data, run 1 step, exit (validation path)
24
- --eval-only load `sft_final.pt`, run sample gen, exit
25
- """
26
-
27
- from __future__ import annotations
28
-
29
- import argparse
30
- import json
31
- import math
32
- import os
33
- import sys
34
- import time
35
- from dataclasses import asdict
36
- from pathlib import Path
37
-
38
- import numpy as np
39
- import torch
40
-
41
- # Repo root on path
42
- _REPO_ROOT = Path(__file__).resolve().parent.parent
43
- if str(_REPO_ROOT) not in sys.path:
44
- sys.path.insert(0, str(_REPO_ROOT))
45
-
46
- # Must import hydra.config BEFORE touching torch.cuda for CUDA env setup
47
- from hydra.config import (
48
- ADAM_BETAS, D_MODEL, D_STATE, DEVICE_BATCH_SIZE, EMBEDDING_LR,
49
- ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND,
50
- FINAL_LR_FRAC, GPU_BF16_PEAK_FLOPS, HEADDIM, MATRIX_LR, N_HEADS,
51
- N_LAYER, PostSemClawConfig, SCALAR_LR, SEED, TOTAL_BATCH_SIZE,
52
- UNEMBEDDING_LR, WARMUP_RATIO, WEIGHT_DECAY,
53
- )
54
- from hydra.model import PostSemClawModel
55
- from prepare import Tokenizer
56
-
57
- # Use line-buffered stdout
58
- try:
59
- sys.stdout.reconfigure(line_buffering=True)
60
- except Exception:
61
- pass
62
-
63
-
64
- # ---------------------------------------------------------------------------
65
- # Paths
66
- # ---------------------------------------------------------------------------
67
-
68
- CACHE_DIR = Path.home() / ".cache" / "autoresearch"
69
- PRETRAIN_CKPT = CACHE_DIR / "pretrain_final.pt"
70
- SFT_FINAL_CKPT = CACHE_DIR / "sft_final.pt"
71
- SFT_INTERIM_CKPT = CACHE_DIR / "sft_interim.pt"
72
- SFT_DATA_DIR = _REPO_ROOT / "data" / "sft"
73
-
74
-
75
- # ---------------------------------------------------------------------------
76
- # Env vars for SFT
77
- # ---------------------------------------------------------------------------
78
-
79
- SFT_TIME_BUDGET = int(os.environ.get("HYDRA_SFT_TIME_BUDGET", "10800"))
80
- SFT_SEQ_LEN = int(os.environ.get("HYDRA_SFT_SEQ_LEN", "512"))
81
- SFT_LR_MULT = float(os.environ.get("HYDRA_SFT_LR_MULT", "0.10"))
82
- SFT_EVAL_INTERVAL = int(os.environ.get("HYDRA_SFT_EVAL_INTERVAL", "500"))
83
- SFT_CKPT_INTERVAL = int(os.environ.get("HYDRA_SFT_CKPT_INTERVAL", "2000"))
84
-
85
-
86
- # ---------------------------------------------------------------------------
87
- # Data loading
88
- # ---------------------------------------------------------------------------
89
-
90
- def _load_meta() -> dict:
91
- meta_path = SFT_DATA_DIR / "meta.json"
92
- if not meta_path.exists():
93
- raise FileNotFoundError(
94
- f"SFT meta not found at {meta_path}. Run "
95
- f"`python scripts/download_sft_data.py` first."
96
- )
97
- with open(meta_path) as f:
98
- return json.load(f)
99
-
100
-
101
- def _load_shards():
102
- """Load all shard_XXX.bin + mask_XXX.bin as big flat arrays.
103
-
104
- Returns: (tokens: np.int64, mask: np.uint8)
105
- Both arrays are 1-D and the same length. Total len ~= target_tokens.
106
- """
107
- tok_shards = sorted(SFT_DATA_DIR.glob("shard_*.bin"))
108
- mask_shards = sorted(SFT_DATA_DIR.glob("mask_*.bin"))
109
- if not tok_shards:
110
- raise FileNotFoundError(f"No SFT shards in {SFT_DATA_DIR}")
111
- assert len(tok_shards) == len(mask_shards), (
112
- f"shard/mask count mismatch: {len(tok_shards)} vs {len(mask_shards)}"
113
- )
114
- tok_parts = []
115
- mask_parts = []
116
- for t, m in zip(tok_shards, mask_shards):
117
- tok_parts.append(np.fromfile(str(t), dtype=np.int16).astype(np.int64))
118
- mask_parts.append(np.fromfile(str(m), dtype=np.uint8))
119
- tokens = np.concatenate(tok_parts)
120
- mask = np.concatenate(mask_parts)
121
- assert tokens.shape == mask.shape
122
- # Guard against negative int16 values (unlikely with vocab=8192 but defensive)
123
- if tokens.min() < 0 or tokens.max() >= 8192:
124
- raise ValueError(
125
- f"Token IDs out of range: min={tokens.min()} max={tokens.max()}"
126
- )
127
- return tokens, mask
128
-
129
-
130
- def make_sft_dataloader(tokens: np.ndarray, mask: np.ndarray, B: int, T: int,
131
- device: torch.device, seed: int = 0):
132
- """Yield (x, y, epoch) forever.
133
-
134
- Each row is a slice of length T+1 sampled at a random start. We produce:
135
- x = slice[:-1] (B, T) int64 on device
136
- y = slice[1:] with mask=0 -> -1 (B, T) int64 on device
137
-
138
- The mask applies to target positions (y), not inputs. This way the
139
- chunked CE loss in model.forward sees ignore_index=-1 for prompt tokens.
140
- """
141
- N = tokens.shape[0]
142
- rng = np.random.default_rng(seed)
143
- # Pin CPU tensors; copy to GPU non-blocking.
144
- cpu_x = torch.empty(B, T, dtype=torch.long, pin_memory=True)
145
- cpu_y = torch.empty(B, T, dtype=torch.long, pin_memory=True)
146
- epoch = 1
147
- samples_drawn = 0
148
- samples_per_epoch = max(1, N // (T + 1))
149
-
150
- # Minimum loss-positions per window. If a sampled window has fewer than
151
- # this many assistant tokens, resample. Guards against all-prompt windows
152
- # producing NaN from 0/0 in the chunked CE loss.
153
- min_loss_positions = max(1, T // 32)
154
- max_resample = 8
155
-
156
- while True:
157
- for b in range(B):
158
- # Sample a starting index with a light rejection filter to ensure
159
- # the window contains enough assistant (mask=1) positions.
160
- if N <= T + 1:
161
- start = 0
162
- else:
163
- start = int(rng.integers(0, N - T - 1))
164
- for _ in range(max_resample):
165
- loss_in_window = int(mask[start + 1:start + T + 1].sum())
166
- if loss_in_window >= min_loss_positions:
167
- break
168
- start = int(rng.integers(0, N - T - 1))
169
- window_tok = tokens[start:start + T + 1]
170
- window_mask = mask[start:start + T + 1]
171
- # x = window[:-1], y = window[1:]
172
- cpu_x[b].copy_(torch.from_numpy(window_tok[:-1].astype(np.int64)))
173
- y_slice = window_tok[1:].astype(np.int64).copy()
174
- # Apply mask to targets: mask=0 (prompt) -> target=-1 (ignore)
175
- y_slice[window_mask[1:] == 0] = -1
176
- # Final guard: if no loss positions survived, force at least 1
177
- # valid target so the batch doesn't produce NaN (the rejection
178
- # filter makes this rare, but the guard is cheap).
179
- if (y_slice != -1).sum() == 0:
180
- y_slice[-1] = int(window_tok[-1])
181
- cpu_y[b].copy_(torch.from_numpy(y_slice))
182
- x = cpu_x.to(device, non_blocking=True)
183
- y = cpu_y.to(device, non_blocking=True)
184
- samples_drawn += B
185
- if samples_drawn >= samples_per_epoch:
186
- epoch += 1
187
- samples_drawn = 0
188
- yield x, y, epoch
189
-
190
-
191
- # ---------------------------------------------------------------------------
192
- # Model init + checkpoint load
193
- # ---------------------------------------------------------------------------
194
-
195
- def _peek_pretrain_config(vocab_size: int) -> PostSemClawConfig | None:
196
- """If pretrain checkpoint exists, return its saved config so we build
197
- the SFT model with matching architecture. Returns None if unavailable."""
198
- if not PRETRAIN_CKPT.exists():
199
- return None
200
- try:
201
- ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cpu",
202
- weights_only=False)
203
- cfg_dict = ckpt.get("config")
204
- if cfg_dict is None:
205
- return None
206
- # Override sequence_len to SFT's (shorter context) -- architecture
207
- # is independent of sequence_len since Mamba3 is recurrent.
208
- cfg_dict = dict(cfg_dict)
209
- cfg_dict["sequence_len"] = SFT_SEQ_LEN
210
- cfg_dict["vocab_size"] = vocab_size
211
- cfg = PostSemClawConfig(**cfg_dict)
212
- return cfg
213
- except Exception as e:
214
- print(f"[model] could not peek pretrain config: {type(e).__name__}: {e}",
215
- flush=True)
216
- return None
217
-
218
-
219
- def build_model(vocab_size: int, device: torch.device) -> PostSemClawModel:
220
- # Prefer checkpoint-derived config if available (guards against env-var drift)
221
- config = _peek_pretrain_config(vocab_size)
222
- if config is None:
223
- config = PostSemClawConfig(
224
- sequence_len=SFT_SEQ_LEN,
225
- vocab_size=vocab_size,
226
- n_layer=N_LAYER,
227
- d_model=D_MODEL,
228
- d_state=D_STATE,
229
- headdim=HEADDIM,
230
- n_heads=N_HEADS,
231
- expand=EXPAND,
232
- engram_n_columns=ENGRAM_N_COLUMNS,
233
- engram_key_dim=ENGRAM_KEY_DIM,
234
- engram_layer_idx=ENGRAM_LAYER_IDX,
235
- )
236
- print(f"[model] config (from env, no ckpt): {asdict(config)}", flush=True)
237
- else:
238
- print(f"[model] config (from pretrain ckpt): {asdict(config)}", flush=True)
239
- with torch.device("meta"):
240
- model = PostSemClawModel(config)
241
- model.to_empty(device=device)
242
- model.init_weights()
243
- return model
244
-
245
-
246
- def try_load_pretrain(model: PostSemClawModel) -> tuple[bool, str]:
247
- """Attempt to load pretrain checkpoint into model. Returns (loaded, msg)."""
248
- if not PRETRAIN_CKPT.exists():
249
- return False, f"no checkpoint at {PRETRAIN_CKPT}"
250
- try:
251
- ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cuda",
252
- weights_only=False)
253
- state = ckpt.get("model_state_dict", ckpt)
254
- # Use strict=False in case SDR/HTM params are excluded from state_dict
255
- # by torch.compile wrappers or similar.
256
- missing, unexpected = model.load_state_dict(state, strict=False)
257
- msg = (f"loaded {PRETRAIN_CKPT} β€” missing={len(missing)} "
258
- f"unexpected={len(unexpected)}")
259
- if missing:
260
- # Log first few missing keys to help diagnose architecture skew
261
- msg += f" first_missing={missing[:3]}"
262
- return True, msg
263
- except Exception as e:
264
- return False, f"load failed: {type(e).__name__}: {e}"
265
-
266
-
267
- # ---------------------------------------------------------------------------
268
- # Sample generation (for in-training eval prints)
269
- # ---------------------------------------------------------------------------
270
-
271
- _SAMPLE_PROMPTS = [
272
- "What is the capital of France?",
273
- "Write a haiku about winter.",
274
- "List three colors.",
275
- "How are you?",
276
- "Explain why the sky is blue in one sentence.",
277
- ]
278
-
279
-
280
- @torch.no_grad()
281
- def sample_once(model, tokenizer, meta: dict, prompt: str,
282
- max_new: int = 64, temperature: float = 0.8,
283
- top_k: int = 40) -> str:
284
- """Generate a chat-formatted reply. Stops on <|end|> or max_new tokens."""
285
- bos = meta["special_tokens"]["bos"]
286
- user = meta["special_tokens"]["user"]
287
- assistant = meta["special_tokens"]["assistant"]
288
- end = meta["special_tokens"]["end"]
289
-
290
- prompt_ids = [bos, user] + tokenizer.encode("\n" + prompt.strip())
291
- prompt_ids += tokenizer.encode("\n")
292
- prompt_ids.append(assistant)
293
- prompt_ids += tokenizer.encode("\n")
294
-
295
- ctx = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
296
- generated: list[int] = []
297
- for _ in range(max_new):
298
- with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
299
- logits = model(ctx, targets=None)
300
- last = logits[0, -1].float()
301
- if top_k and top_k < last.shape[-1]:
302
- kth = torch.topk(last, top_k).values[-1]
303
- last = torch.where(last < kth, torch.full_like(last, -1e9), last)
304
- probs = torch.softmax(last / max(temperature, 1e-6), dim=-1)
305
- next_id = int(torch.multinomial(probs, num_samples=1).item())
306
- generated.append(next_id)
307
- if next_id == end:
308
- break
309
- ctx = torch.cat(
310
- [ctx, torch.tensor([[next_id]], device="cuda", dtype=torch.long)],
311
- dim=1,
312
- )
313
- # Hard cap on ctx length (model was trained at 2048, SFT at 512,
314
- # but inference could theoretically go longer)
315
- if ctx.size(1) >= 2048:
316
- break
317
- try:
318
- text = tokenizer.decode(generated)
319
- except Exception:
320
- text = "<decode error>"
321
- return text
322
-
323
-
324
- def run_samples(model, tokenizer, meta: dict, step: int):
325
- model.eval()
326
- print(f"\n=== SFT samples @ step {step} ===", flush=True)
327
- for p in _SAMPLE_PROMPTS:
328
- try:
329
- resp = sample_once(model, tokenizer, meta, p)
330
- except Exception as e:
331
- resp = f"<sample failed: {type(e).__name__}: {e}>"
332
- # Sanitize newlines for log readability
333
- resp_clean = resp.replace("\n", " ⏎ ").replace("\r", " ")
334
- print(f" prompt: {p!r}")
335
- print(f" reply: {resp_clean!r}")
336
- print("=== end samples ===\n", flush=True)
337
- model.train()
338
-
339
-
340
- # ---------------------------------------------------------------------------
341
- # Checkpoint save
342
- # ---------------------------------------------------------------------------
343
-
344
- def save_ckpt(model, step: int, smoothed_loss: float, path: Path,
345
- mode: str, meta: dict):
346
- try:
347
- CACHE_DIR.mkdir(parents=True, exist_ok=True)
348
- payload = {
349
- "model_state_dict": model.state_dict(),
350
- "step": step,
351
- "smoothed_loss": smoothed_loss,
352
- "mode": mode,
353
- "sft_meta": meta,
354
- }
355
- torch.save(payload, str(path))
356
- print(f"[ckpt] saved {path} (step={step})", flush=True)
357
- except Exception as e:
358
- print(f"[ckpt] SAVE FAILED {path}: {type(e).__name__}: {e}", flush=True)
359
-
360
-
361
- # ---------------------------------------------------------------------------
362
- # Main
363
- # ---------------------------------------------------------------------------
364
-
365
- def main():
366
- ap = argparse.ArgumentParser()
367
- ap.add_argument("--dry-run", action="store_true",
368
- help="Load model+data, run 1 step, exit.")
369
- ap.add_argument("--eval-only", action="store_true",
370
- help="Load sft_final.pt and run sample gen.")
371
- args = ap.parse_args()
372
-
373
- t_start = time.time()
374
- torch.manual_seed(SEED + 1) # +1 so SFT draws different RNG than pretrain
375
- torch.cuda.manual_seed(SEED + 1)
376
- torch.set_float32_matmul_precision("high")
377
- device = torch.device("cuda")
378
- autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
379
-
380
- # --- Tokenizer ---
381
- tokenizer = Tokenizer.from_directory()
382
- vocab_size = tokenizer.get_vocab_size()
383
- print(f"[init] vocab: {vocab_size}", flush=True)
384
-
385
- # --- Data meta ---
386
- meta = _load_meta()
387
- print(f"[data] meta: {meta}", flush=True)
388
-
389
- # --- Model ---
390
- model = build_model(vocab_size, device)
391
- n_params = sum(p.numel() for p in model.parameters())
392
- print(f"[model] params: {n_params:,}", flush=True)
393
-
394
- loaded, msg = try_load_pretrain(model)
395
- mode = "resume_from_pretrain" if loaded else "from_scratch"
396
- print(f"[init] MODE={mode} :: {msg}", flush=True)
397
-
398
- # --- Eval-only path ---
399
- if args.eval_only:
400
- if SFT_FINAL_CKPT.exists():
401
- ckpt = torch.load(str(SFT_FINAL_CKPT), map_location=device,
402
- weights_only=False)
403
- state = ckpt.get("model_state_dict", ckpt)
404
- model.load_state_dict(state, strict=False)
405
- print(f"[eval-only] loaded {SFT_FINAL_CKPT}", flush=True)
406
- else:
407
- print(f"[eval-only] no SFT checkpoint β€” running on current weights",
408
- flush=True)
409
- run_samples(model, tokenizer, meta, step=-1)
410
- return
411
-
412
- # --- Dataloader ---
413
- print(f"[data] loading shards ...", flush=True)
414
- tokens, mask = _load_shards()
415
- print(f"[data] tokens: {len(tokens):,} loss-positions: {int(mask.sum()):,}",
416
- flush=True)
417
- B = DEVICE_BATCH_SIZE
418
- T = SFT_SEQ_LEN
419
- tokens_per_fwdbwd = B * T
420
- assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0, (
421
- f"TOTAL_BATCH_SIZE={TOTAL_BATCH_SIZE} not divisible by B*T={tokens_per_fwdbwd}"
422
- )
423
- grad_accum = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
424
- print(f"[train] B={B} T={T} accum={grad_accum} effective_batch={TOTAL_BATCH_SIZE}",
425
- flush=True)
426
- loader = make_sft_dataloader(tokens, mask, B, T, device, seed=SEED + 7)
427
- x, y, epoch = next(loader)
428
-
429
- # --- Optimizer (scaled LRs) ---
430
- matrix_lr = MATRIX_LR * SFT_LR_MULT
431
- embed_lr = EMBEDDING_LR * SFT_LR_MULT
432
- unembed_lr = UNEMBEDDING_LR * SFT_LR_MULT
433
- scalar_lr = SCALAR_LR * SFT_LR_MULT
434
- print(f"[opt] LRs scaled by {SFT_LR_MULT}: matrix={matrix_lr:.5f} "
435
- f"embed={embed_lr:.5f} unembed={unembed_lr:.6f}", flush=True)
436
- optimizer = model.setup_optimizer(
437
- unembedding_lr=unembed_lr,
438
- embedding_lr=embed_lr,
439
- scalar_lr=scalar_lr,
440
- adam_betas=ADAM_BETAS,
441
- matrix_lr=matrix_lr,
442
- weight_decay=WEIGHT_DECAY,
443
- )
444
-
445
- # --- Dry-run path (validation) ---
446
- if args.dry_run:
447
- print("[dry-run] running 1 step ...", flush=True)
448
- with autocast_ctx:
449
- loss = model(x, y)
450
- loss_f = float(loss.item())
451
- print(f"[dry-run] step0 loss={loss_f:.4f}", flush=True)
452
- loss.backward()
453
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
454
- optimizer.step()
455
- model.zero_grad(set_to_none=True)
456
- if math.isnan(loss_f) or loss_f > 100:
457
- print("[dry-run] FAILED (NaN / huge loss)", flush=True)
458
- sys.exit(1)
459
- print("[dry-run] OK", flush=True)
460
- return
461
-
462
- # --- Training loop ---
463
- print(f"[train] budget={SFT_TIME_BUDGET}s eval_every={SFT_EVAL_INTERVAL} "
464
- f"ckpt_every={SFT_CKPT_INTERVAL}", flush=True)
465
- t_loop_start = time.time()
466
- smooth_loss = 0.0
467
- step = 0
468
- total_train_secs = 0.0
469
-
470
- # Warmup schedule for SFT: linear 0->1 over first 5% of budget, then cosine.
471
- sft_warmup_frac = 0.05
472
-
473
- def lr_mult(progress: float) -> float:
474
- if progress < sft_warmup_frac:
475
- return progress / sft_warmup_frac if sft_warmup_frac > 0 else 1.0
476
- decay = (progress - sft_warmup_frac) / (1.0 - sft_warmup_frac)
477
- return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * \
478
- (1 + math.cos(math.pi * decay))
479
-
480
- while True:
481
- torch.cuda.synchronize()
482
- t0 = time.time()
483
- for _ in range(grad_accum):
484
- with autocast_ctx:
485
- loss = model(x, y)
486
- train_loss_val = loss.detach()
487
- (loss / grad_accum).backward()
488
- x, y, epoch = next(loader)
489
-
490
- progress = min(total_train_secs / SFT_TIME_BUDGET, 1.0)
491
- mult = lr_mult(progress)
492
- for group in optimizer.param_groups:
493
- group["lr"] = group["initial_lr"] * mult
494
-
495
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
496
- optimizer.step()
497
- model.zero_grad(set_to_none=True)
498
-
499
- loss_f = float(train_loss_val.item())
500
- if math.isnan(loss_f) or loss_f > 100:
501
- print(f"[FAIL] step={step} loss={loss_f} β€” aborting", flush=True)
502
- save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
503
- sys.exit(1)
504
-
505
- torch.cuda.synchronize()
506
- dt = time.time() - t0
507
- if step > 3:
508
- total_train_secs += dt
509
-
510
- # EMA loss (debiased)
511
- beta = 0.9
512
- smooth_loss = beta * smooth_loss + (1 - beta) * loss_f
513
- debiased = smooth_loss / (1 - beta ** (step + 1))
514
- bpt = debiased / math.log(2)
515
- tps = int(TOTAL_BATCH_SIZE / dt) if dt > 0 else 0
516
- vram_mib = torch.cuda.memory_allocated() / 1024 / 1024
517
- lr_now = optimizer.param_groups[0]["lr"]
518
- remaining = max(0, SFT_TIME_BUDGET - total_train_secs)
519
-
520
- print(
521
- f"sft_step={step:05d} loss={debiased:.4f} bpt={bpt:.3f} "
522
- f"tps={tps} dt_ms={dt*1000:.0f} lr={lr_now:.2e} "
523
- f"vram={vram_mib:.0f}MiB pct={100*progress:.1f} "
524
- f"epoch={epoch} remaining={remaining:.0f}s",
525
- flush=True,
526
- )
527
-
528
- if step > 0 and step % SFT_EVAL_INTERVAL == 0:
529
- run_samples(model, tokenizer, meta, step)
530
-
531
- if step > 0 and step % SFT_CKPT_INTERVAL == 0:
532
- save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
533
-
534
- step += 1
535
-
536
- if step > 5 and total_train_secs >= SFT_TIME_BUDGET:
537
- break
538
-
539
- # Final samples + save
540
- run_samples(model, tokenizer, meta, step)
541
- save_ckpt(model, step, smooth_loss, SFT_FINAL_CKPT, mode, meta)
542
-
543
- total_secs = time.time() - t_start
544
- print("---", flush=True)
545
- print(f"SFT_COMPLETE mode={mode} step={step} "
546
- f"smoothed_loss={smooth_loss:.4f} total_seconds={total_secs:.0f} "
547
- f"train_seconds={total_train_secs:.0f}", flush=True)
548
-
549
-
550
- if __name__ == "__main__":
551
- try:
552
- main()
553
- except SystemExit:
554
- raise
555
- except Exception as e:
556
- import traceback
557
- print(f"SFT_FAILED {type(e).__name__}: {e}", flush=True)
558
- traceback.print_exc()
559
- sys.exit(1)
 
1
+ """HYDRA SFT β€” instruction fine-tune the pretrained 7.5M-param base.
2
+
3
+ Mode selection:
4
+ MODE=resume_from_pretrain iff ~/.cache/autoresearch/pretrain_final.pt
5
+ exists AND loads cleanly into a fresh model.
6
+ MODE=from_scratch otherwise (degraded fallback).
7
+
8
+ Data: int16 shards written by `scripts/download_sft_data.py`, paired with
9
+ uint8 loss-mask shards (1 on assistant tokens, 0 on user-prompt tokens).
10
+ At runtime we pack consecutive examples into fixed-length rows; prompt
11
+ positions get target=-1 so CE's `ignore_index=-1` drops them.
12
+
13
+ Env vars (with defaults tuned for RTX 3060 6GB, 7.5M params):
14
+ HYDRA_SFT_TIME_BUDGET 10800 SFT wall-clock budget (3h)
15
+ HYDRA_SFT_SEQ_LEN 512 sequence length during SFT
16
+ HYDRA_BATCH_SIZE 4 per-step device batch
17
+ HYDRA_TOTAL_BATCH 8192 effective batch (grad-accum derived)
18
+ HYDRA_SFT_LR_MULT 0.10 multiply pretrain LRs by this
19
+ HYDRA_SFT_EVAL_INTERVAL 500 steps between sample generations
20
+ HYDRA_SFT_CKPT_INTERVAL 2000 steps between interim checkpoints
21
+
22
+ CLI:
23
+ --dry-run load model+data, run 1 step, exit (validation path)
24
+ --eval-only load `sft_final.pt`, run sample gen, exit
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import argparse
30
+ import json
31
+ import math
32
+ import os
33
+ import sys
34
+ import time
35
+ from dataclasses import asdict
36
+ from pathlib import Path
37
+
38
+ import numpy as np
39
+ import torch
40
+
41
+ # Repo root on path
42
+ _REPO_ROOT = Path(__file__).resolve().parent.parent
43
+ if str(_REPO_ROOT) not in sys.path:
44
+ sys.path.insert(0, str(_REPO_ROOT))
45
+
46
+ # Must import hydra.config BEFORE touching torch.cuda for CUDA env setup
47
+ from hydra.config import (
48
+ ADAM_BETAS, D_MODEL, D_STATE, DEVICE_BATCH_SIZE, EMBEDDING_LR,
49
+ ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND,
50
+ FINAL_LR_FRAC, GPU_BF16_PEAK_FLOPS, HEADDIM, MATRIX_LR, N_HEADS,
51
+ N_LAYER, PostSemClawConfig, SCALAR_LR, SEED, TOTAL_BATCH_SIZE,
52
+ UNEMBEDDING_LR, WARMUP_RATIO, WEIGHT_DECAY,
53
+ )
54
+ from hydra.model import PostSemClawModel
55
+ from prepare import Tokenizer
56
+
57
+ # Use line-buffered stdout
58
+ try:
59
+ sys.stdout.reconfigure(line_buffering=True)
60
+ except Exception:
61
+ pass
62
+
63
+
64
+ # ---------------------------------------------------------------------------
65
+ # Paths
66
+ # ---------------------------------------------------------------------------
67
+
68
+ CACHE_DIR = Path.home() / ".cache" / "autoresearch"
69
+ PRETRAIN_CKPT = CACHE_DIR / "pretrain_final.pt"
70
+ SFT_FINAL_CKPT = CACHE_DIR / "sft_final.pt"
71
+ SFT_INTERIM_CKPT = CACHE_DIR / "sft_interim.pt"
72
+ SFT_DATA_DIR = _REPO_ROOT / "data" / "sft"
73
+
74
+
75
+ # ---------------------------------------------------------------------------
76
+ # Env vars for SFT
77
+ # ---------------------------------------------------------------------------
78
+
79
+ SFT_TIME_BUDGET = int(os.environ.get("HYDRA_SFT_TIME_BUDGET", "10800"))
80
+ SFT_SEQ_LEN = int(os.environ.get("HYDRA_SFT_SEQ_LEN", "512"))
81
+ SFT_LR_MULT = float(os.environ.get("HYDRA_SFT_LR_MULT", "0.10"))
82
+ SFT_EVAL_INTERVAL = int(os.environ.get("HYDRA_SFT_EVAL_INTERVAL", "500"))
83
+ SFT_CKPT_INTERVAL = int(os.environ.get("HYDRA_SFT_CKPT_INTERVAL", "2000"))
84
+
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Data loading
88
+ # ---------------------------------------------------------------------------
89
+
90
+ def _load_meta() -> dict:
91
+ meta_path = SFT_DATA_DIR / "meta.json"
92
+ if not meta_path.exists():
93
+ raise FileNotFoundError(
94
+ f"SFT meta not found at {meta_path}. Run "
95
+ f"`python scripts/download_sft_data.py` first."
96
+ )
97
+ with open(meta_path) as f:
98
+ return json.load(f)
99
+
100
+
101
+ def _load_shards():
102
+ """Load all shard_XXX.bin + mask_XXX.bin as big flat arrays.
103
+
104
+ Returns: (tokens: np.int64, mask: np.uint8)
105
+ Both arrays are 1-D and the same length. Total len ~= target_tokens.
106
+ """
107
+ tok_shards = sorted(SFT_DATA_DIR.glob("shard_*.bin"))
108
+ mask_shards = sorted(SFT_DATA_DIR.glob("mask_*.bin"))
109
+ if not tok_shards:
110
+ raise FileNotFoundError(f"No SFT shards in {SFT_DATA_DIR}")
111
+ assert len(tok_shards) == len(mask_shards), (
112
+ f"shard/mask count mismatch: {len(tok_shards)} vs {len(mask_shards)}"
113
+ )
114
+ tok_parts = []
115
+ mask_parts = []
116
+ for t, m in zip(tok_shards, mask_shards):
117
+ tok_parts.append(np.fromfile(str(t), dtype=np.int16).astype(np.int64))
118
+ mask_parts.append(np.fromfile(str(m), dtype=np.uint8))
119
+ tokens = np.concatenate(tok_parts)
120
+ mask = np.concatenate(mask_parts)
121
+ assert tokens.shape == mask.shape
122
+ # Guard against negative int16 values (unlikely with vocab=8192 but defensive)
123
+ if tokens.min() < 0 or tokens.max() >= 8192:
124
+ raise ValueError(
125
+ f"Token IDs out of range: min={tokens.min()} max={tokens.max()}"
126
+ )
127
+ return tokens, mask
128
+
129
+
130
+ def make_sft_dataloader(tokens: np.ndarray, mask: np.ndarray, B: int, T: int,
131
+ device: torch.device, seed: int = 0):
132
+ """Yield (x, y, epoch) forever.
133
+
134
+ Each row is a slice of length T+1 sampled at a random start. We produce:
135
+ x = slice[:-1] (B, T) int64 on device
136
+ y = slice[1:] with mask=0 -> -1 (B, T) int64 on device
137
+
138
+ The mask applies to target positions (y), not inputs. This way the
139
+ chunked CE loss in model.forward sees ignore_index=-1 for prompt tokens.
140
+ """
141
+ N = tokens.shape[0]
142
+ rng = np.random.default_rng(seed)
143
+ # Pin CPU tensors; copy to GPU non-blocking.
144
+ cpu_x = torch.empty(B, T, dtype=torch.long, pin_memory=True)
145
+ cpu_y = torch.empty(B, T, dtype=torch.long, pin_memory=True)
146
+ epoch = 1
147
+ samples_drawn = 0
148
+ samples_per_epoch = max(1, N // (T + 1))
149
+
150
+ # Minimum loss-positions per window. If a sampled window has fewer than
151
+ # this many assistant tokens, resample. Guards against all-prompt windows
152
+ # producing NaN from 0/0 in the chunked CE loss.
153
+ min_loss_positions = max(1, T // 32)
154
+ max_resample = 8
155
+
156
+ while True:
157
+ for b in range(B):
158
+ # Sample a starting index with a light rejection filter to ensure
159
+ # the window contains enough assistant (mask=1) positions.
160
+ if N <= T + 1:
161
+ start = 0
162
+ else:
163
+ start = int(rng.integers(0, N - T - 1))
164
+ for _ in range(max_resample):
165
+ loss_in_window = int(mask[start + 1:start + T + 1].sum())
166
+ if loss_in_window >= min_loss_positions:
167
+ break
168
+ start = int(rng.integers(0, N - T - 1))
169
+ window_tok = tokens[start:start + T + 1]
170
+ window_mask = mask[start:start + T + 1]
171
+ # x = window[:-1], y = window[1:]
172
+ cpu_x[b].copy_(torch.from_numpy(window_tok[:-1].astype(np.int64)))
173
+ y_slice = window_tok[1:].astype(np.int64).copy()
174
+ # Apply mask to targets: mask=0 (prompt) -> target=-1 (ignore)
175
+ y_slice[window_mask[1:] == 0] = -1
176
+ # Final guard: if no loss positions survived, force at least 1
177
+ # valid target so the batch doesn't produce NaN (the rejection
178
+ # filter makes this rare, but the guard is cheap).
179
+ if (y_slice != -1).sum() == 0:
180
+ y_slice[-1] = int(window_tok[-1])
181
+ cpu_y[b].copy_(torch.from_numpy(y_slice))
182
+ x = cpu_x.to(device, non_blocking=True)
183
+ y = cpu_y.to(device, non_blocking=True)
184
+ samples_drawn += B
185
+ if samples_drawn >= samples_per_epoch:
186
+ epoch += 1
187
+ samples_drawn = 0
188
+ yield x, y, epoch
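The dataloader's pinned-buffer hand-off is the standard host-to-device overlap pattern. A minimal sketch of just that mechanic, with toy shapes; note that rewriting a pinned buffer on the host is only safe once its in-flight copy has completed, e.g. after a synchronization point:

    import torch

    if torch.cuda.is_available():
        buf = torch.empty(4, 512, dtype=torch.long, pin_memory=True)
        buf.fill_(1)                              # host-side fill
        dev = buf.to("cuda", non_blocking=True)   # async H2D from pinned memory
        torch.cuda.synchronize()                  # safe point before reusing buf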
189
+
190
+
191
+ # ---------------------------------------------------------------------------
192
+ # Model init + checkpoint load
193
+ # ---------------------------------------------------------------------------
194
+
195
+ def _peek_pretrain_config(vocab_size: int) -> PostSemClawConfig | None:
196
+ """If pretrain checkpoint exists, return its saved config so we build
197
+ the SFT model with matching architecture. Returns None if unavailable."""
198
+ if not PRETRAIN_CKPT.exists():
199
+ return None
200
+ try:
201
+ ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cpu",
202
+ weights_only=False)
203
+ cfg_dict = ckpt.get("config")
204
+ if cfg_dict is None:
205
+ return None
206
+ # Override sequence_len to SFT's (shorter context) -- architecture
207
+ # is independent of sequence_len since Mamba3 is recurrent.
208
+ cfg_dict = dict(cfg_dict)
209
+ cfg_dict["sequence_len"] = SFT_SEQ_LEN
210
+ cfg_dict["vocab_size"] = vocab_size
211
+ cfg = PostSemClawConfig(**cfg_dict)
212
+ return cfg
213
+ except Exception as e:
214
+ print(f"[model] could not peek pretrain config: {type(e).__name__}: {e}",
215
+ flush=True)
216
+ return None
217
+
218
+
219
+ def build_model(vocab_size: int, device: torch.device) -> PostSemClawModel:
220
+ # Prefer checkpoint-derived config if available (guards against env-var drift)
221
+ config = _peek_pretrain_config(vocab_size)
222
+ if config is None:
223
+ config = PostSemClawConfig(
224
+ sequence_len=SFT_SEQ_LEN,
225
+ vocab_size=vocab_size,
226
+ n_layer=N_LAYER,
227
+ d_model=D_MODEL,
228
+ d_state=D_STATE,
229
+ headdim=HEADDIM,
230
+ n_heads=N_HEADS,
231
+ expand=EXPAND,
232
+ engram_n_columns=ENGRAM_N_COLUMNS,
233
+ engram_key_dim=ENGRAM_KEY_DIM,
234
+ engram_layer_idx=ENGRAM_LAYER_IDX,
235
+ )
236
+ print(f"[model] config (from env, no ckpt): {asdict(config)}", flush=True)
237
+ else:
238
+ print(f"[model] config (from pretrain ckpt): {asdict(config)}", flush=True)
239
+ with torch.device("meta"):
240
+ model = PostSemClawModel(config)
241
+ model.to_empty(device=device)
242
+ model.init_weights()
243
+ return model
244
+
245
+
246
+ def try_load_pretrain(model: PostSemClawModel) -> tuple[bool, str]:
247
+ """Attempt to load pretrain checkpoint into model. Returns (loaded, msg)."""
248
+ if not PRETRAIN_CKPT.exists():
249
+ return False, f"no checkpoint at {PRETRAIN_CKPT}"
250
+ try:
251
+ ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cuda",
252
+ weights_only=False)
253
+ state = ckpt.get("model_state_dict", ckpt)
254
+ # Use strict=False in case SDR/HTM params are excluded from state_dict
255
+ # by torch.compile wrappers or similar.
256
+ missing, unexpected = model.load_state_dict(state, strict=False)
257
+ msg = (f"loaded {PRETRAIN_CKPT} β€” missing={len(missing)} "
258
+ f"unexpected={len(unexpected)}")
259
+ if missing:
260
+ # Log first few missing keys to help diagnose architecture skew
261
+ msg += f" first_missing={missing[:3]}"
262
+ return True, msg
263
+ except Exception as e:
264
+ return False, f"load failed: {type(e).__name__}: {e}"
265
+
266
+
267
+ # ---------------------------------------------------------------------------
268
+ # Sample generation (for in-training eval prints)
269
+ # ---------------------------------------------------------------------------
270
+
271
+ _SAMPLE_PROMPTS = [
272
+ "What is the capital of France?",
273
+ "Write a haiku about winter.",
274
+ "List three colors.",
275
+ "How are you?",
276
+ "Explain why the sky is blue in one sentence.",
277
+ ]
278
+
279
+
280
+ @torch.no_grad()
281
+ def sample_once(model, tokenizer, meta: dict, prompt: str,
282
+ max_new: int = 64, temperature: float = 0.8,
283
+ top_k: int = 40) -> str:
284
+ """Generate a chat-formatted reply. Stops on <|end|> or max_new tokens."""
285
+ bos = meta["special_tokens"]["bos"]
286
+ user = meta["special_tokens"]["user"]
287
+ assistant = meta["special_tokens"]["assistant"]
288
+ end = meta["special_tokens"]["end"]
289
+
290
+ prompt_ids = [bos, user] + tokenizer.encode("\n" + prompt.strip())
291
+ prompt_ids += tokenizer.encode("\n")
292
+ prompt_ids.append(assistant)
293
+ prompt_ids += tokenizer.encode("\n")
294
+
295
+ ctx = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
296
+ generated: list[int] = []
297
+ for _ in range(max_new):
298
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
299
+ logits = model(ctx, targets=None)
300
+ last = logits[0, -1].float()
301
+ if top_k and top_k < last.shape[-1]:
302
+ kth = torch.topk(last, top_k).values[-1]
303
+ last = torch.where(last < kth, torch.full_like(last, -1e9), last)
304
+ probs = torch.softmax(last / max(temperature, 1e-6), dim=-1)
305
+ next_id = int(torch.multinomial(probs, num_samples=1).item())
306
+ generated.append(next_id)
307
+ if next_id == end:
308
+ break
309
+ ctx = torch.cat(
310
+ [ctx, torch.tensor([[next_id]], device="cuda", dtype=torch.long)],
311
+ dim=1,
312
+ )
313
+ # Hard cap on ctx length (model was trained at 2048, SFT at 512,
314
+ # but inference could theoretically go longer)
315
+ if ctx.size(1) >= 2048:
316
+ break
317
+ try:
318
+ text = tokenizer.decode(generated)
319
+ except Exception:
320
+ text = "<decode error>"
321
+ return text
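The top-k filter in sample_once works by thresholding at the k-th largest logit: everything below it is pushed to -1e9 so softmax assigns it effectively zero mass. On a toy logit vector:

    import torch

    last = torch.tensor([2.0, 0.5, 1.5, -1.0, 3.0])
    top_k = 2
    kth = torch.topk(last, top_k).values[-1]    # 2.0, the 2nd-largest logit
    last = torch.where(last < kth, torch.full_like(last, -1e9), last)
    probs = torch.softmax(last / 0.8, dim=-1)   # temperature 0.8
    # only indices 0 and 4 keep non-negligible probability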
322
+
323
+
324
+ def run_samples(model, tokenizer, meta: dict, step: int):
325
+ model.eval()
326
+ print(f"\n=== SFT samples @ step {step} ===", flush=True)
327
+ for p in _SAMPLE_PROMPTS:
328
+ try:
329
+ resp = sample_once(model, tokenizer, meta, p)
330
+ except Exception as e:
331
+ resp = f"<sample failed: {type(e).__name__}: {e}>"
332
+ # Sanitize newlines for log readability
333
+ resp_clean = resp.replace("\n", " ⏎ ").replace("\r", " ")
334
+ print(f" prompt: {p!r}")
335
+ print(f" reply: {resp_clean!r}")
336
+ print("=== end samples ===\n", flush=True)
337
+ model.train()
338
+
339
+
340
+ # ---------------------------------------------------------------------------
341
+ # Checkpoint save
342
+ # ---------------------------------------------------------------------------
343
+
344
+ def save_ckpt(model, step: int, smoothed_loss: float, path: Path,
345
+ mode: str, meta: dict):
346
+ try:
347
+ CACHE_DIR.mkdir(parents=True, exist_ok=True)
348
+ payload = {
349
+ "model_state_dict": model.state_dict(),
350
+ "step": step,
351
+ "smoothed_loss": smoothed_loss,
352
+ "mode": mode,
353
+ "sft_meta": meta,
354
+ }
355
+ torch.save(payload, str(path))
356
+ print(f"[ckpt] saved {path} (step={step})", flush=True)
357
+ except Exception as e:
358
+ print(f"[ckpt] SAVE FAILED {path}: {type(e).__name__}: {e}", flush=True)
359
+
360
+
361
+ # ---------------------------------------------------------------------------
362
+ # Main
363
+ # ---------------------------------------------------------------------------
364
+
365
+ def main():
366
+ ap = argparse.ArgumentParser()
367
+ ap.add_argument("--dry-run", action="store_true",
368
+ help="Load model+data, run 1 step, exit.")
369
+ ap.add_argument("--eval-only", action="store_true",
370
+ help="Load sft_final.pt and run sample gen.")
371
+ args = ap.parse_args()
372
+
373
+ t_start = time.time()
374
+ torch.manual_seed(SEED + 1) # +1 so SFT draws different RNG than pretrain
375
+ torch.cuda.manual_seed(SEED + 1)
376
+ torch.set_float32_matmul_precision("high")
377
+ device = torch.device("cuda")
378
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
379
+
380
+ # --- Tokenizer ---
381
+ tokenizer = Tokenizer.from_directory()
382
+ vocab_size = tokenizer.get_vocab_size()
383
+ print(f"[init] vocab: {vocab_size}", flush=True)
384
+
385
+ # --- Data meta ---
386
+ meta = _load_meta()
387
+ print(f"[data] meta: {meta}", flush=True)
388
+
389
+ # --- Model ---
390
+ model = build_model(vocab_size, device)
391
+ n_params = sum(p.numel() for p in model.parameters())
392
+ print(f"[model] params: {n_params:,}", flush=True)
393
+
394
+ loaded, msg = try_load_pretrain(model)
395
+ mode = "resume_from_pretrain" if loaded else "from_scratch"
396
+ print(f"[init] MODE={mode} :: {msg}", flush=True)
397
+
398
+ # --- Eval-only path ---
399
+ if args.eval_only:
400
+ if SFT_FINAL_CKPT.exists():
401
+ ckpt = torch.load(str(SFT_FINAL_CKPT), map_location=device,
402
+ weights_only=False)
403
+ state = ckpt.get("model_state_dict", ckpt)
404
+ model.load_state_dict(state, strict=False)
405
+ print(f"[eval-only] loaded {SFT_FINAL_CKPT}", flush=True)
406
+ else:
407
+ print(f"[eval-only] no SFT checkpoint β€” running on current weights",
408
+ flush=True)
409
+ run_samples(model, tokenizer, meta, step=-1)
410
+ return
411
+
412
+ # --- Dataloader ---
413
+ print(f"[data] loading shards ...", flush=True)
414
+ tokens, mask = _load_shards()
415
+ print(f"[data] tokens: {len(tokens):,} loss-positions: {int(mask.sum()):,}",
416
+ flush=True)
417
+ B = DEVICE_BATCH_SIZE
418
+ T = SFT_SEQ_LEN
419
+ tokens_per_fwdbwd = B * T
420
+ assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0, (
421
+ f"TOTAL_BATCH_SIZE={TOTAL_BATCH_SIZE} not divisible by B*T={tokens_per_fwdbwd}"
422
+ )
423
+ grad_accum = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
424
+ print(f"[train] B={B} T={T} accum={grad_accum} effective_batch={TOTAL_BATCH_SIZE}",
425
+ flush=True)
426
+ loader = make_sft_dataloader(tokens, mask, B, T, device, seed=SEED + 7)
427
+ x, y, epoch = next(loader)
428
+
429
+ # --- Optimizer (scaled LRs) ---
430
+ matrix_lr = MATRIX_LR * SFT_LR_MULT
431
+ embed_lr = EMBEDDING_LR * SFT_LR_MULT
432
+ unembed_lr = UNEMBEDDING_LR * SFT_LR_MULT
433
+ scalar_lr = SCALAR_LR * SFT_LR_MULT
434
+ print(f"[opt] LRs scaled by {SFT_LR_MULT}: matrix={matrix_lr:.5f} "
435
+ f"embed={embed_lr:.5f} unembed={unembed_lr:.6f}", flush=True)
436
+ optimizer = model.setup_optimizer(
437
+ unembedding_lr=unembed_lr,
438
+ embedding_lr=embed_lr,
439
+ scalar_lr=scalar_lr,
440
+ adam_betas=ADAM_BETAS,
441
+ matrix_lr=matrix_lr,
442
+ weight_decay=WEIGHT_DECAY,
443
+ )
444
+
445
+ # --- Dry-run path (validation) ---
446
+ if args.dry_run:
447
+ print("[dry-run] running 1 step ...", flush=True)
448
+ with autocast_ctx:
449
+ loss = model(x, y)
450
+ loss_f = float(loss.item())
451
+ print(f"[dry-run] step0 loss={loss_f:.4f}", flush=True)
452
+ loss.backward()
453
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
454
+ optimizer.step()
455
+ model.zero_grad(set_to_none=True)
456
+ if math.isnan(loss_f) or loss_f > 100:
457
+ print("[dry-run] FAILED (NaN / huge loss)", flush=True)
458
+ sys.exit(1)
459
+ print("[dry-run] OK", flush=True)
460
+ return
461
+
462
+ # --- Training loop ---
463
+ print(f"[train] budget={SFT_TIME_BUDGET}s eval_every={SFT_EVAL_INTERVAL} "
464
+ f"ckpt_every={SFT_CKPT_INTERVAL}", flush=True)
465
+ t_loop_start = time.time()
466
+ smooth_loss = 0.0
467
+ step = 0
468
+ total_train_secs = 0.0
469
+
470
+ # Warmup schedule for SFT: linear 0->1 over first 5% of budget, then cosine.
471
+ sft_warmup_frac = 0.05
472
+
473
+ def lr_mult(progress: float) -> float:
474
+ if progress < sft_warmup_frac:
475
+ return progress / sft_warmup_frac if sft_warmup_frac > 0 else 1.0
476
+ decay = (progress - sft_warmup_frac) / (1.0 - sft_warmup_frac)
477
+ return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * \
478
+ (1 + math.cos(math.pi * decay))
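Sanity-checking the schedule at its endpoints (FINAL_LR_FRAC = 0.1 is assumed here for illustration; the real value lives in hydra.config):

    import math

    FINAL_LR_FRAC = 0.1   # assumption for this sketch
    warmup = 0.05

    def lr_mult(p):
        if p < warmup:
            return p / warmup
        d = (p - warmup) / (1.0 - warmup)
        return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * (1 + math.cos(math.pi * d))

    assert lr_mult(0.0) == 0.0                          # cold start
    assert abs(lr_mult(0.05) - 1.0) < 1e-12             # full LR at warmup end
    assert abs(lr_mult(1.0) - FINAL_LR_FRAC) < 1e-12    # cosine floor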
479
+
480
+ while True:
481
+ torch.cuda.synchronize()
482
+ t0 = time.time()
483
+ for _ in range(grad_accum):
484
+ with autocast_ctx:
485
+ loss = model(x, y)
486
+ train_loss_val = loss.detach()
487
+ (loss / grad_accum).backward()
488
+ x, y, epoch = next(loader)
489
+
490
+ progress = min(total_train_secs / SFT_TIME_BUDGET, 1.0)
491
+ mult = lr_mult(progress)
492
+ for group in optimizer.param_groups:
493
+ group["lr"] = group["initial_lr"] * mult
494
+
495
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
496
+ optimizer.step()
497
+ model.zero_grad(set_to_none=True)
498
+
499
+ loss_f = float(train_loss_val.item())
500
+ if math.isnan(loss_f) or loss_f > 100:
501
+ print(f"[FAIL] step={step} loss={loss_f} β€” aborting", flush=True)
502
+ save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
503
+ sys.exit(1)
504
+
505
+ torch.cuda.synchronize()
506
+ dt = time.time() - t0
507
+ if step > 3:
508
+ total_train_secs += dt
509
+
510
+ # EMA loss (debiased)
511
+ beta = 0.9
512
+ smooth_loss = beta * smooth_loss + (1 - beta) * loss_f
513
+ debiased = smooth_loss / (1 - beta ** (step + 1))
514
+ bpt = debiased / math.log(2)
515
+ tps = int(TOTAL_BATCH_SIZE / dt) if dt > 0 else 0
516
+ vram_mib = torch.cuda.memory_allocated() / 1024 / 1024
517
+ lr_now = optimizer.param_groups[0]["lr"]
518
+ remaining = max(0, SFT_TIME_BUDGET - total_train_secs)
519
+
520
+ print(
521
+ f"sft_step={step:05d} loss={debiased:.4f} bpt={bpt:.3f} "
522
+ f"tps={tps} dt_ms={dt*1000:.0f} lr={lr_now:.2e} "
523
+ f"vram={vram_mib:.0f}MiB pct={100*progress:.1f} "
524
+ f"epoch={epoch} remaining={remaining:.0f}s",
525
+ flush=True,
526
+ )
527
+
528
+ if step > 0 and step % SFT_EVAL_INTERVAL == 0:
529
+ run_samples(model, tokenizer, meta, step)
530
+
531
+ if step > 0 and step % SFT_CKPT_INTERVAL == 0:
532
+ save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
533
+
534
+ step += 1
535
+
536
+ if step > 5 and total_train_secs >= SFT_TIME_BUDGET:
537
+ break
538
+
539
+ # Final samples + save
540
+ run_samples(model, tokenizer, meta, step)
541
+ save_ckpt(model, step, smooth_loss, SFT_FINAL_CKPT, mode, meta)
542
+
543
+ total_secs = time.time() - t_start
544
+ print("---", flush=True)
545
+ print(f"SFT_COMPLETE mode={mode} step={step} "
546
+ f"smoothed_loss={smooth_loss:.4f} total_seconds={total_secs:.0f} "
547
+ f"train_seconds={total_train_secs:.0f}", flush=True)
548
+
549
+
550
+ if __name__ == "__main__":
551
+ try:
552
+ main()
553
+ except SystemExit:
554
+ raise
555
+ except Exception as e:
556
+ import traceback
557
+ print(f"SFT_FAILED {type(e).__name__}: {e}", flush=True)
558
+ traceback.print_exc()
559
+ sys.exit(1)
overlay/scripts/sft_orchestrator.sh CHANGED
@@ -1,165 +1,165 @@
1
- #!/usr/bin/env bash
2
- #
3
- # SFT orchestrator: waits for pretrain (train.py) to either complete or
4
- # reach the 8h budget, then kicks off SFT.
5
- #
6
- # Behavior:
7
- # - Polls for `train.py` process every 60 s
8
- # - Exits the wait loop on either:
9
- # (a) no train.py process found (pretrain completed naturally), or
10
- # (b) 8h elapsed since this script started
11
- # - Sends SIGTERM first (graceful -- triggers checkpoint-save patch if
12
- # applied), waits 30 s, then SIGKILL as fallback
13
- # - Invokes `scripts/download_sft_data.py` if shards don't exist
14
- # - Launches `scripts/sft.py` in the background with tuned env vars
15
- # - Redirects all output to `run_sft.log`
16
- #
17
- # Re-entrant: safe to invoke even if pretrain has already exited.
18
- # Does NOT re-launch if SFT is already running.
19
- #
20
- # Usage (typical):
21
- # cd /home/mikeb/work/feather
22
- # nohup bash scripts/sft_orchestrator.sh > orchestrator.log 2>&1 &
23
- # disown
24
-
25
- set -u # error on unset vars, but don't -e (we handle failures explicitly)
26
-
27
- REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
28
- cd "$REPO_ROOT" || { echo "cannot cd to $REPO_ROOT" >&2; exit 1; }
29
-
30
- PY="$REPO_ROOT/.venv/bin/python"
31
- if [ ! -x "$PY" ]; then
32
- echo "[orchestrator] ERROR: python not found at $PY" >&2
33
- exit 1
34
- fi
35
-
36
- LOG_FILE="$REPO_ROOT/run_sft.log"
37
- DATA_LOG="$REPO_ROOT/run_sft_download.log"
38
- MAX_WAIT_SECONDS=28800 # 8 hours
39
- POLL_INTERVAL=60
40
- GRACEFUL_SHUTDOWN_WAIT=30
41
-
42
- log() {
43
- echo "[orchestrator $(date -u '+%Y-%m-%dT%H:%M:%SZ')] $*"
44
- }
45
-
46
- # ---------------------------------------------------------------------------
47
- # Stage 1: wait for pretrain
48
- # ---------------------------------------------------------------------------
49
-
50
- log "starting; max wait = ${MAX_WAIT_SECONDS}s"
51
-
52
- # Guard against double-launch
53
- if pgrep -f "scripts/sft.py" > /dev/null; then
54
- log "SFT is already running β€” exiting orchestrator to avoid conflict"
55
- exit 0
56
- fi
57
-
58
- T_START=$(date +%s)
59
- while true; do
60
- NOW=$(date +%s)
61
- ELAPSED=$((NOW - T_START))
62
-
63
- if [ $ELAPSED -ge $MAX_WAIT_SECONDS ]; then
64
- log "reached 8h wait cap (${ELAPSED}s) β€” will kill pretrain"
65
- break
66
- fi
67
-
68
- # Count train.py processes owned by current user (not orchestrator/sft.py)
69
- PRETRAIN_PIDS=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
70
- # Strip pid of this script if pgrep matched something spurious
71
- PRETRAIN_PIDS=$(echo "$PRETRAIN_PIDS" | sed "s/\b$$\b//g" | xargs)
72
-
73
- if [ -z "$PRETRAIN_PIDS" ]; then
74
- log "no train.py process found β€” pretrain already exited"
75
- break
76
- fi
77
-
78
- # Log a status every 10 polls (~10 min)
79
- if [ $((ELAPSED / POLL_INTERVAL % 10)) -eq 0 ]; then
80
- log "waiting... elapsed=${ELAPSED}s pretrain PIDs: $PRETRAIN_PIDS"
81
- fi
82
-
83
- sleep $POLL_INTERVAL
84
- done
85
-
86
- # ---------------------------------------------------------------------------
87
- # Stage 2: kill any remaining pretrain processes
88
- # ---------------------------------------------------------------------------
89
-
90
- PRETRAIN_PIDS=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
91
- if [ -n "$PRETRAIN_PIDS" ]; then
92
- log "sending SIGTERM to pretrain PIDs: $PRETRAIN_PIDS"
93
- for pid in $PRETRAIN_PIDS; do
94
- kill -TERM "$pid" 2>/dev/null || true
95
- done
96
-
97
- # Wait for graceful shutdown (gives the checkpoint-save patch time to run)
98
- for _ in $(seq 1 $GRACEFUL_SHUTDOWN_WAIT); do
99
- REMAINING=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
100
- if [ -z "$REMAINING" ]; then break; fi
101
- sleep 1
102
- done
103
-
104
- # Force-kill any stragglers
105
- REMAINING=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
106
- if [ -n "$REMAINING" ]; then
107
- log "force-killing stragglers: $REMAINING"
108
- for pid in $REMAINING; do
109
- kill -9 "$pid" 2>/dev/null || true
110
- done
111
- sleep 5
112
- fi
113
- fi
114
-
115
- # ---------------------------------------------------------------------------
116
- # Stage 3: ensure SFT data exists
117
- # ---------------------------------------------------------------------------
118
-
119
- META_JSON="$REPO_ROOT/data/sft/meta.json"
120
- if [ ! -f "$META_JSON" ]; then
121
- log "no SFT data found β€” running download_sft_data.py"
122
- LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \
123
- "$PY" -u "$REPO_ROOT/scripts/download_sft_data.py" \
124
- > "$DATA_LOG" 2>&1
125
- DL_RC=$?
126
- if [ $DL_RC -ne 0 ] || [ ! -f "$META_JSON" ]; then
127
- log "ERROR: SFT data download failed (rc=$DL_RC)"
128
- log " last 20 lines of $DATA_LOG:"
129
- tail -20 "$DATA_LOG" 2>/dev/null | sed 's/^/ /'
130
- exit 2
131
- fi
132
- log "SFT data ready"
133
- else
134
- log "SFT data already present at $META_JSON"
135
- fi
136
-
137
- # ---------------------------------------------------------------------------
138
- # Stage 4: launch SFT
139
- # ---------------------------------------------------------------------------
140
-
141
- # Guard: if we somehow got here and SFT is now running, don't double-launch.
142
- if pgrep -f "scripts/sft.py" > /dev/null; then
143
- log "SFT is already running β€” skipping launch"
144
- exit 0
145
- fi
146
-
147
- log "launching SFT (log -> $LOG_FILE)"
148
-
149
- export LD_LIBRARY_PATH="/usr/lib/wsl/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
150
- export HYDRA_SFT_TIME_BUDGET="${HYDRA_SFT_TIME_BUDGET:-10800}"
151
- export HYDRA_BATCH_SIZE="${HYDRA_BATCH_SIZE:-4}"
152
- export HYDRA_TOTAL_BATCH="${HYDRA_TOTAL_BATCH:-8192}"
153
- export HYDRA_SFT_SEQ_LEN="${HYDRA_SFT_SEQ_LEN:-512}"
154
- export HYDRA_SFT_LR_MULT="${HYDRA_SFT_LR_MULT:-0.10}"
155
- export HYDRA_SFT_EVAL_INTERVAL="${HYDRA_SFT_EVAL_INTERVAL:-500}"
156
- export HYDRA_SFT_CKPT_INTERVAL="${HYDRA_SFT_CKPT_INTERVAL:-2000}"
157
- export HYDRA_DROPOUT="${HYDRA_DROPOUT:-0.1}"
158
-
159
- nohup "$PY" -u "$REPO_ROOT/scripts/sft.py" \
160
- > "$LOG_FILE" 2>&1 &
161
- SFT_PID=$!
162
- disown $SFT_PID 2>/dev/null || true
163
-
164
- log "SFT launched as PID $SFT_PID (budget=${HYDRA_SFT_TIME_BUDGET}s)"
165
- log "monitor with: tail -f $LOG_FILE"
 
1
+ #!/usr/bin/env bash
2
+ #
3
+ # SFT orchestrator: waits for pretrain (train.py) to either complete or
4
+ # reach the 8h budget, then kicks off SFT.
5
+ #
6
+ # Behavior:
7
+ # - Polls for `train.py` process every 60 s
8
+ # - Exits the wait loop on either:
9
+ # (a) no train.py process found (pretrain completed naturally), or
10
+ # (b) 8h elapsed since this script started
11
+ # - Sends SIGTERM first (graceful -- triggers checkpoint-save patch if
12
+ # applied), waits 30 s, then SIGKILL as fallback
13
+ # - Invokes `scripts/download_sft_data.py` if shards don't exist
14
+ # - Launches `scripts/sft.py` in the background with tuned env vars
15
+ # - Redirects all output to `run_sft.log`
16
+ #
17
+ # Re-entrant: safe to invoke even if pretrain has already exited.
18
+ # Does NOT re-launch if SFT is already running.
19
+ #
20
+ # Usage (typical):
21
+ # cd /home/mikeb/work/feather
22
+ # nohup bash scripts/sft_orchestrator.sh > orchestrator.log 2>&1 &
23
+ # disown
24
+
25
+ set -u # error on unset vars, but don't -e (we handle failures explicitly)
26
+
27
+ REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
28
+ cd "$REPO_ROOT" || { echo "cannot cd to $REPO_ROOT" >&2; exit 1; }
29
+
30
+ PY="$REPO_ROOT/.venv/bin/python"
31
+ if [ ! -x "$PY" ]; then
32
+ echo "[orchestrator] ERROR: python not found at $PY" >&2
33
+ exit 1
34
+ fi
35
+
36
+ LOG_FILE="$REPO_ROOT/run_sft.log"
37
+ DATA_LOG="$REPO_ROOT/run_sft_download.log"
38
+ MAX_WAIT_SECONDS=28800 # 8 hours
39
+ POLL_INTERVAL=60
40
+ GRACEFUL_SHUTDOWN_WAIT=30
41
+
42
+ log() {
43
+ echo "[orchestrator $(date -u '+%Y-%m-%dT%H:%M:%SZ')] $*"
44
+ }
45
+
46
+ # ---------------------------------------------------------------------------
47
+ # Stage 1: wait for pretrain
48
+ # ---------------------------------------------------------------------------
49
+
50
+ log "starting; max wait = ${MAX_WAIT_SECONDS}s"
51
+
52
+ # Guard against double-launch
53
+ if pgrep -f "scripts/sft.py" > /dev/null; then
54
+ log "SFT is already running β€” exiting orchestrator to avoid conflict"
55
+ exit 0
56
+ fi
57
+
58
+ T_START=$(date +%s)
59
+ while true; do
60
+ NOW=$(date +%s)
61
+ ELAPSED=$((NOW - T_START))
62
+
63
+ if [ $ELAPSED -ge $MAX_WAIT_SECONDS ]; then
64
+ log "reached 8h wait cap (${ELAPSED}s) β€” will kill pretrain"
65
+ break
66
+ fi
67
+
68
+ # Count train.py processes owned by current user (not orchestrator/sft.py)
69
+ PRETRAIN_PIDS=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
70
+ # Strip pid of this script if pgrep matched something spurious
71
+ PRETRAIN_PIDS=$(echo "$PRETRAIN_PIDS" | sed "s/\b$$\b//g" | xargs)
72
+
73
+ if [ -z "$PRETRAIN_PIDS" ]; then
74
+ log "no train.py process found β€” pretrain already exited"
75
+ break
76
+ fi
77
+
78
+ # Log a status every 10 polls (~10 min)
79
+ if [ $((ELAPSED / POLL_INTERVAL % 10)) -eq 0 ]; then
80
+ log "waiting... elapsed=${ELAPSED}s pretrain PIDs: $PRETRAIN_PIDS"
81
+ fi
82
+
83
+ sleep $POLL_INTERVAL
84
+ done
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Stage 2: kill any remaining pretrain processes
88
+ # ---------------------------------------------------------------------------
89
+
90
+ PRETRAIN_PIDS=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
91
+ if [ -n "$PRETRAIN_PIDS" ]; then
92
+ log "sending SIGTERM to pretrain PIDs: $PRETRAIN_PIDS"
93
+ for pid in $PRETRAIN_PIDS; do
94
+ kill -TERM "$pid" 2>/dev/null || true
95
+ done
96
+
97
+ # Wait for graceful shutdown (gives the checkpoint-save patch time to run)
98
+ for _ in $(seq 1 $GRACEFUL_SHUTDOWN_WAIT); do
99
+ REMAINING=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
100
+ if [ -z "$REMAINING" ]; then break; fi
101
+ sleep 1
102
+ done
103
+
104
+ # Force-kill any stragglers
105
+ REMAINING=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
106
+ if [ -n "$REMAINING" ]; then
107
+ log "force-killing stragglers: $REMAINING"
108
+ for pid in $REMAINING; do
109
+ kill -9 "$pid" 2>/dev/null || true
110
+ done
111
+ sleep 5
112
+ fi
113
+ fi
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Stage 3: ensure SFT data exists
117
+ # ---------------------------------------------------------------------------
118
+
119
+ META_JSON="$REPO_ROOT/data/sft/meta.json"
120
+ if [ ! -f "$META_JSON" ]; then
121
+ log "no SFT data found β€” running download_sft_data.py"
122
+ LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \
123
+ "$PY" -u "$REPO_ROOT/scripts/download_sft_data.py" \
124
+ > "$DATA_LOG" 2>&1
125
+ DL_RC=$?
126
+ if [ $DL_RC -ne 0 ] || [ ! -f "$META_JSON" ]; then
127
+ log "ERROR: SFT data download failed (rc=$DL_RC)"
128
+ log " last 20 lines of $DATA_LOG:"
129
+ tail -20 "$DATA_LOG" 2>/dev/null | sed 's/^/ /'
130
+ exit 2
131
+ fi
132
+ log "SFT data ready"
133
+ else
134
+ log "SFT data already present at $META_JSON"
135
+ fi
136
+
137
+ # ---------------------------------------------------------------------------
138
+ # Stage 4: launch SFT
139
+ # ---------------------------------------------------------------------------
140
+
141
+ # Guard: if we somehow got here and SFT is now running, don't double-launch.
142
+ if pgrep -f "scripts/sft.py" > /dev/null; then
143
+ log "SFT is already running β€” skipping launch"
144
+ exit 0
145
+ fi
146
+
147
+ log "launching SFT (log -> $LOG_FILE)"
148
+
149
+ export LD_LIBRARY_PATH="/usr/lib/wsl/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
150
+ export HYDRA_SFT_TIME_BUDGET="${HYDRA_SFT_TIME_BUDGET:-10800}"
151
+ export HYDRA_BATCH_SIZE="${HYDRA_BATCH_SIZE:-4}"
152
+ export HYDRA_TOTAL_BATCH="${HYDRA_TOTAL_BATCH:-8192}"
153
+ export HYDRA_SFT_SEQ_LEN="${HYDRA_SFT_SEQ_LEN:-512}"
154
+ export HYDRA_SFT_LR_MULT="${HYDRA_SFT_LR_MULT:-0.10}"
155
+ export HYDRA_SFT_EVAL_INTERVAL="${HYDRA_SFT_EVAL_INTERVAL:-500}"
156
+ export HYDRA_SFT_CKPT_INTERVAL="${HYDRA_SFT_CKPT_INTERVAL:-2000}"
157
+ export HYDRA_DROPOUT="${HYDRA_DROPOUT:-0.1}"
158
+
159
+ nohup "$PY" -u "$REPO_ROOT/scripts/sft.py" \
160
+ > "$LOG_FILE" 2>&1 &
161
+ SFT_PID=$!
162
+ disown $SFT_PID 2>/dev/null || true
163
+
164
+ log "SFT launched as PID $SFT_PID (budget=${HYDRA_SFT_TIME_BUDGET}s)"
165
+ log "monitor with: tail -f $LOG_FILE"
overlay/subsystems/fused_sdr_project.py CHANGED
@@ -114,6 +114,13 @@ class FusedSDRProject(torch.autograd.Function):
114
 
115
  out = torch.empty(P, D, device=active.device, dtype=sdr_proj_weight.dtype)
116
 
117
  BLOCK_D = min(256, triton.next_power_of_2(D))
118
  grid = (P * triton.cdiv(D, BLOCK_D),)
119
 
 
114
 
115
  out = torch.empty(P, D, device=active.device, dtype=sdr_proj_weight.dtype)
116
 
117
+ if not active.is_cuda:
118
+ # Local CPU validation has no Triton driver. Keep the same custom
119
+ # autograd contract but use a deterministic gather+sum fallback.
120
+ out = wt[active].sum(dim=1).to(dtype=sdr_proj_weight.dtype)
121
+ ctx.save_for_backward(active, token_ids, sdr_proj_weight, delta_u, delta_v)
122
+ return out.view(B, T, D)
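The fallback's gather+sum states densely what the Triton kernel computes per position: each output row is the sum of the weight rows selected by that position's active SDR bits. A standalone shape sketch with toy sizes, where `wt` stands in for the (n_bits, D) projection matrix implied by the surrounding code:

    import torch

    BT, K, n_bits, D = 6, 4, 16, 8                 # positions, active bits, dims
    wt = torch.randn(n_bits, D)
    active = torch.randint(0, n_bits, (BT, K))     # K active bit indices per row
    out = wt[active].sum(dim=1)                    # (BT, K, D) -> (BT, D)
    assert out.shape == (BT, D)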
123
+
124
  BLOCK_D = min(256, triton.next_power_of_2(D))
125
  grid = (P * triton.cdiv(D, BLOCK_D),)
126
 
overlay/subsystems/htm.py CHANGED
@@ -99,13 +99,19 @@ class HTMLayer(nn.Module):
99
  self._forward_counter = 0
100
  # GPU backend gate. Default: auto-detect -- use GPU when the pyo3
101
  # module was built with --features gpu AND CUDA is actually usable.
102
  if use_gpu is None:
103
- use_gpu = _HTM_HAS_GPU and torch.cuda.is_available()
104
  elif use_gpu and not _HTM_HAS_GPU:
105
  raise RuntimeError(
106
  "HTMLayer(use_gpu=True) but htm_rust was not built with "
107
  "--features gpu. Re-run `maturin develop --features gpu`."
108
  )
109
  self._use_gpu = bool(use_gpu)
110
  cls = htm_rust.HTMRegionGpu if self._use_gpu else htm_rust.HTMRegion
111
  self._region_cls = cls
 
99
  self._forward_counter = 0
100
  # GPU backend gate. Default: auto-detect -- use GPU when the pyo3
101
  # module was built with --features gpu AND CUDA is actually usable.
102
+ # HYDRA_FORCE_HTM_CPU=1 is an operational safety valve for paid remote
103
+ # canaries when the compiled CUDA HTM backend is present but unstable on
104
+ # a specific hardware/runtime combination.
105
+ force_cpu = _os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
106
  if use_gpu is None:
107
+ use_gpu = (not force_cpu) and _HTM_HAS_GPU and torch.cuda.is_available()
108
  elif use_gpu and not _HTM_HAS_GPU:
109
  raise RuntimeError(
110
  "HTMLayer(use_gpu=True) but htm_rust was not built with "
111
  "--features gpu. Re-run `maturin develop --features gpu`."
112
  )
113
+ elif use_gpu and force_cpu:
114
+ use_gpu = False
115
  self._use_gpu = bool(use_gpu)
116
  cls = htm_rust.HTMRegionGpu if self._use_gpu else htm_rust.HTMRegion
117
  self._region_cls = cls
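Spelled out as a standalone truth function, the resulting gate looks like this (a sketch mirroring the branch above, not the module itself). Operationally, HYDRA_FORCE_HTM_CPU=1 downgrades even an explicit use_gpu=True rather than raising:

    def resolve_use_gpu(requested, has_gpu_build, cuda_ok, force_cpu):
        # requested: None = auto-detect, True/False = explicit caller choice
        if requested is None:
            return (not force_cpu) and has_gpu_build and cuda_ok
        if requested and not has_gpu_build:
            raise RuntimeError("rebuild htm_rust with --features gpu")
        if requested and force_cpu:
            return False          # env safety valve wins over explicit True
        return requested

    assert resolve_use_gpu(None, True, True, force_cpu=True) is False
    assert resolve_use_gpu(True, True, True, force_cpu=True) is False
    assert resolve_use_gpu(None, True, True, force_cpu=False) is True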
overlay/subsystems/sdr_semantic.py CHANGED
@@ -46,19 +46,10 @@ class _SDRSTE(torch.autograd.Function):
46
  flat_grad = grad_out.reshape(B * T, n_bits).to(delta_v.dtype)
47
  flat_ids = token_ids.reshape(B * T)
48
  V = delta_u.shape[0]
49
- R = delta_u.shape[1] # delta_rank -- typically 32
50
- # OOM fix: old code allocated (V, n_bits) = 4GB buffer via index_add.
51
- # Instead, project to rank-R space first (small), then scatter.
52
- # grad_delta_u[t, r] = sum_{pos: flat_ids[pos]=t} (flat_grad[pos] @ delta_v[r])
53
- # = index_add(V, R, flat_ids, flat_grad @ delta_v.T)
54
- projected = flat_grad @ delta_v.t() # (B*T, R) -- ~1MB at B=8,T=1024,R=32
55
- per_tok_u = torch.zeros(V, R, device=flat_grad.device, dtype=delta_v.dtype)
56
- per_tok_u.index_add_(0, flat_ids, projected)
57
- grad_delta_u = per_tok_u # (V, R) -- ~8MB at V=65536
58
- # grad_delta_v = sum_{pos} delta_u[flat_ids[pos]]^T @ flat_grad[pos]
59
- # = delta_u[flat_ids].T @ flat_grad -- no intermediate buffer
60
- gathered_u = delta_u[flat_ids] # (B*T, R) -- ~1MB
61
- grad_delta_v = gathered_u.t() @ flat_grad # (R, n_bits) -- ~2MB
62
  return None, grad_delta_u, grad_delta_v, None
63
 
64
 
@@ -249,25 +240,12 @@ class SemanticFoldingSDR(nn.Module):
249
  sdr_binary = sdr_binary.view(B, T, self.n_bits)
250
  return _SDRSTE.apply(sdr_binary, self.delta_u, self.delta_v, token_ids)
251
 
252
- @torch.no_grad()
253
- def active_indices(self, token_ids: torch.Tensor) -> torch.Tensor:
254
- """Compact int16 Reality Buffer view: (B,T,K) active retina offsets.
255
-
256
- This is the production discrete bridge for Cantor/Engram routing. It
257
- avoids reconstructing dense (B,T,n_bits) masks when consumers only need
258
- the L0 support set.
259
- """
260
- if token_ids.dim() != 2:
261
- raise ValueError(f"expected (B, T) token_ids, got shape {tuple(token_ids.shape)}")
262
- B, T = token_ids.shape
263
- return self._retina_indices[token_ids.reshape(-1)].view(B, T, self.target_active)
264
-
265
  @torch.no_grad()
266
  def binary_only(self, token_ids: torch.Tensor) -> torch.Tensor:
267
  """uint8 retina view β€” no STE, no autocast cost. For HTM/consumers that
268
  only need the binary pattern. Reconstructs dense from CSR indices."""
269
  B, T = token_ids.shape
270
- idx = self.active_indices(token_ids).reshape(B * T, self.target_active)
271
  sdr = torch.zeros(
272
  B * T, self.n_bits, dtype=torch.uint8, device=token_ids.device,
273
  )
 
46
  flat_grad = grad_out.reshape(B * T, n_bits).to(delta_v.dtype)
47
  flat_ids = token_ids.reshape(B * T)
48
  V = delta_u.shape[0]
49
+ per_tok = torch.zeros(V, n_bits, device=flat_grad.device, dtype=delta_v.dtype)
50
+ per_tok.index_add_(0, flat_ids, flat_grad)
51
+ grad_delta_u = per_tok @ delta_v.t()
52
+ grad_delta_v = delta_u.t() @ per_tok
 
  return None, grad_delta_u, grad_delta_v, None
54
 
55
 
 
240
  sdr_binary = sdr_binary.view(B, T, self.n_bits)
241
  return _SDRSTE.apply(sdr_binary, self.delta_u, self.delta_v, token_ids)
242
 
243
  @torch.no_grad()
244
  def binary_only(self, token_ids: torch.Tensor) -> torch.Tensor:
245
  """uint8 retina view β€” no STE, no autocast cost. For HTM/consumers that
246
  only need the binary pattern. Reconstructs dense from CSR indices."""
247
  B, T = token_ids.shape
248
+ idx = self._retina_indices[token_ids.reshape(-1)] # (B*T, K) int16
249
  sdr = torch.zeros(
250
  B * T, self.n_bits, dtype=torch.uint8, device=token_ids.device,
251
  )
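The hunk ends before the densification itself; the standard way to turn such a (B*T, K) index view into the zeroed (B*T, n_bits) buffer is a row-wise scatter. A sketch, not necessarily the module's exact continuation:

    import torch

    B, T, K, n_bits = 2, 3, 4, 16
    idx = torch.randint(0, n_bits, (B * T, K))   # stand-in for the retina view
    sdr = torch.zeros(B * T, n_bits, dtype=torch.uint8)
    sdr.scatter_(1, idx.long(), 1)               # set the K active bits per row
    sdr = sdr.view(B, T, n_bits)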