Update Feather H200 runtime: Nemotron streaming and HTM force-CPU canary fixes
- overlay/htm_rust/bench_gpu.py +81 -81
- overlay/htm_rust/build.rs +6 -12
- overlay/htm_rust/docs/GPU_HTM.md +302 -302
- overlay/htm_rust/src/gpu/fused.rs +50 -33
- overlay/htm_rust/src/gpu/kernels/sp_boost_fused.cu +59 -59
- overlay/htm_rust/src/gpu/kernels/sp_duty.cu +45 -45
- overlay/htm_rust/src/gpu/kernels/sp_learn.cu +45 -45
- overlay/htm_rust/src/gpu/kernels/sp_overlap.cu +78 -78
- overlay/htm_rust/src/gpu/kernels/sp_topk.cu +117 -117
- overlay/htm_rust/src/gpu/kernels/tm_activate.cu +66 -66
- overlay/htm_rust/src/gpu/kernels/tm_anomaly.cu +43 -43
- overlay/htm_rust/src/gpu/kernels/tm_grow.cu +155 -155
- overlay/htm_rust/src/gpu/kernels/tm_learn.cu +75 -75
- overlay/htm_rust/src/gpu/kernels/tm_predict.cu +102 -102
- overlay/htm_rust/src/gpu/kernels/tm_punish.cu +64 -64
- overlay/htm_rust/src/gpu/kernels/tm_reset.cu +36 -36
- overlay/htm_rust/src/gpu/mod.rs +549 -549
- overlay/htm_rust/src/gpu/sp_gpu.rs +796 -796
- overlay/htm_rust/src/gpu/tm_gpu.rs +460 -460
- overlay/htm_rust/uv.lock +8 -8
- overlay/hydra/config.py +2 -2
- overlay/hydra/engram.py +121 -104
- overlay/hydra/model.py +1 -0
- overlay/scripts/autoresearch.py +517 -517
- overlay/scripts/chat.py +458 -458
- overlay/scripts/chat_eval.py +300 -300
- overlay/scripts/compile_debug.py +213 -213
- overlay/scripts/dataset_audit.py +241 -241
- overlay/scripts/download_sft_data.py +457 -457
- overlay/scripts/eval_quality.py +525 -525
- overlay/scripts/fetch_corpus.py +211 -211
- overlay/scripts/grad_probe.py +196 -196
- overlay/scripts/launch_feather_hf_job.py +8 -2
- overlay/scripts/profile_forward.py +87 -87
- overlay/scripts/run_domain_expanded_pretrain.sh +1 -5
- overlay/scripts/sample_utils.py +107 -107
- overlay/scripts/sft.py +559 -559
- overlay/scripts/sft_orchestrator.sh +165 -165
- overlay/subsystems/fused_sdr_project.py +7 -0
- overlay/subsystems/htm.py +7 -1
- overlay/subsystems/sdr_semantic.py +5 -27
overlay/htm_rust/bench_gpu.py
CHANGED
@@ -1,81 +1,81 @@ (all 81 lines rewritten; the rendered content is identical on both sides, so the file is shown once)

```python
"""Microbenchmark: CPU vs GPU HTMLayer forward at HYDRA training sizes.

Usage:
    source .venv/bin/activate
    export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
    python htm_rust/bench_gpu.py
"""
import os
import sys
import time

# Ensure /home/mikeb/work/feather is on sys.path so `subsystems` imports.
_FEATHER = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _FEATHER not in sys.path:
    sys.path.insert(0, _FEATHER)

import numpy as np
import torch

from subsystems.htm import HTMLayer


def bench(layer: HTMLayer, sdr: torch.Tensor, warmup: int = 1, iters: int = 3) -> float:
    """Return mean ms/forward."""
    for _ in range(warmup):
        _ = layer(sdr)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        _ = layer(sdr)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    dt = time.perf_counter() - t0
    return dt * 1000 / iters


def main() -> None:
    # HYDRA training config: B=8, T=2048, bits=16384, cols=2048.
    B, T, D = int(os.environ.get("B", 8)), int(os.environ.get("T", 2048)), 16384
    n_cols = 2048

    print(f"config: B={B} T={T} D={D} n_cols={n_cols}")
    print(f"torch: {torch.__version__} cuda={torch.cuda.is_available()}")

    # Build a fixed sparse SDR once.
    rng = np.random.default_rng(0)
    sdr = np.zeros((B, T, D), dtype=bool)
    on = int(D * 0.02)
    for b in range(B):
        for t in range(T):
            idx = rng.choice(D, size=on, replace=False)
            sdr[b, t, idx] = True
    sdr_t = torch.from_numpy(sdr)

    # CPU baseline.
    print("\n--- CPU ---")
    cpu_layer = HTMLayer(
        input_bits=D, n_columns=n_cols, cells_per_column=32,
        batch_size=B, seed=42, use_gpu=False,
    )
    cpu_layer.train()
    cpu_ms = bench(cpu_layer, sdr_t, warmup=1, iters=2)
    print(f"CPU: {cpu_ms:.1f} ms/forward ({cpu_ms/T:.2f} ms/step × T={T})")

    # GPU.
    print("\n--- GPU ---")
    gpu_layer = HTMLayer(
        input_bits=D, n_columns=n_cols, cells_per_column=32,
        batch_size=B, seed=42, use_gpu=True,
    )
    gpu_layer.train()
    sdr_cuda = sdr_t.cuda()
    gpu_ms = bench(gpu_layer, sdr_cuda, warmup=1, iters=2)
    print(f"GPU: {gpu_ms:.1f} ms/forward ({gpu_ms/T:.2f} ms/step × T={T})")

    print(f"\nSpeedup: {cpu_ms / gpu_ms:.2f}x")


if __name__ == "__main__":
    main()
```
overlay/htm_rust/build.rs
CHANGED

Removed lines shown as `…` were truncated in this capture.

```diff
@@ -26,11 +26,8 @@ fn main() {
         return;
     }
 
-
-    let …
-
-    // Base kernels – compile for any sm_80+ GPU. Each .cu file → one .ptx file.
-    let base_kernels: &[&str] = &[
+    // Kernels to compile. Each .cu file → one .ptx file, embedded by name.
+    let kernels: &[&str] = &[
         "sp_overlap",
         "sp_topk",
         "sp_learn",
@@ -43,20 +40,17 @@ fn main() {
         "tm_grow",
         "tm_anomaly",
         "tm_reset",
+        "htm_fused_step",
     ];
 
-    // htm_fused_step now compiles for ALL architectures (sm_80+).
-    // On Hopper (sm_90+): uses cluster-distributed shared memory for hot state.
-    // On Ampere (sm_86) and other pre-Hopper: uses global memory reads/writes
-    // with grid.sync() for cross-block synchronization (cooperative launch).
-    let kernels: Vec<&str> = base_kernels.iter().chain(["htm_fused_step"].iter()).copied().collect();
-
     let kernels_dir = PathBuf::from("src/gpu/kernels");
-    for k in …
+    for k in kernels {
         let src = kernels_dir.join(format!("{k}.cu"));
         println!("cargo:rerun-if-changed={}", src.display());
    }
 
+    let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR"));
+    let arch = env::var("HTM_CUDA_ARCH").unwrap_or_else(|_| "sm_86".into());
 
     let nvcc = find_nvcc();
     println!("cargo:warning=htm_rust: nvcc = {nvcc}");
```
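For orientation, the PTX step this build script drives amounts to one `nvcc` invocation per kernel. A hedged Python sketch of the equivalent command line (only `HTM_CUDA_ARCH`, `OUT_DIR`, and the .cu → .ptx mapping come from the diff; the exact flags build.rs passes are an assumption):

```python
import os
import subprocess

def compile_ptx(kernel: str, kernels_dir: str = "src/gpu/kernels") -> str:
    """Mirror of what build.rs does per kernel: one .cu file -> one .ptx file."""
    arch = os.environ.get("HTM_CUDA_ARCH", "sm_86")  # default from the diff
    out_dir = os.environ["OUT_DIR"]                  # set by cargo for build scripts
    src = os.path.join(kernels_dir, f"{kernel}.cu")
    dst = os.path.join(out_dir, f"{kernel}.ptx")
    # --ptx emits PTX instead of a cubin; -arch pins the target architecture.
    subprocess.run(["nvcc", "--ptx", f"-arch={arch}", src, "-o", dst], check=True)
    return dst
```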
overlay/htm_rust/docs/GPU_HTM.md
CHANGED
@@ -1,302 +1,302 @@ (all 302 lines rewritten; the rendered content is identical on both sides, so the file is shown once)

# GPU HTM Backend

## Status

**FUSED MEGAKERNEL: entire T-timestep SP+TM forward collapsed into a single
CUDA launch per forward pass.**

* Legacy path: 12 kernels × T=2048 timesteps = 24K launches per forward.
* Fused path: **1 launch per forward** (24000× launch-overhead reduction).
* End-to-end training throughput: **~2.7k → ~60k tok/sec** (~22x speedup).
* Fused path uses per-column threshold inhibition instead of global top-K
  (see §Fused Kernel below – this is a real architectural change).

## Fused Kernel

### Why

Global top-K column selection requires cross-block synchronization at every
timestep. On WSL2/sm_86 without `-rdc=true`, `cooperative_groups::grid_sync()`
is unreliable. Without a grid sync, collapsing the T-loop into one kernel is
impossible, so every forward pays 12×T kernel launches and 90%+ of runtime is
CUDA launch overhead + small-kernel tails.

### How

Replace global top-K with **per-column threshold activation**:

    is_active[c] = (overlap[c] * boost[c]) > inhibition_threshold[c]

`inhibition_threshold[c]` is a per-column scalar, learned via EMA update:

    err = active_duty[c] - sparsity_target
    new_thr = clamp(thr + thr_adapt_rate * err * 100, 0.1, 1000)

This is biologically grounded (GABAergic local lateral inhibition in
neocortical columns) and supported by HTM theory. The duty-cycle-driven
feedback loop was already present; we simply redirect its output to drive
activation threshold instead of multiplicative boost. The global top-K,
which had no biological basis, is removed.
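To make the feedback loop concrete, here is a minimal NumPy sketch of the same control law (an editor illustration, not the kernel source; `alpha` and the synthetic overlap distribution are assumptions, while the clamp bounds, `err`, and the `* 100` gain come from the formulas above):

```python
import numpy as np

def threshold_step(overlap, boost, thr, active_duty,
                   sparsity_target=0.02, alpha=0.01, thr_adapt_rate=0.01):
    """One timestep of per-column threshold inhibition (host-side sketch)."""
    is_active = (overlap * boost) > thr                           # per-column activation
    active_duty = (1 - alpha) * active_duty + alpha * is_active   # duty-cycle EMA
    err = active_duty - sparsity_target                           # signed sparsity error
    thr = np.clip(thr + thr_adapt_rate * err * 100, 0.1, 1000.0)  # threshold EMA
    return is_active, thr, active_duty

# Columns that fire more often than the target see their threshold rise, so
# mean active columns/step converges toward sparsity_target * n_cols.
rng = np.random.default_rng(0)
n_cols, thr, duty = 2048, np.full(2048, 5.0), np.zeros(2048)
for _ in range(1000):
    overlap = rng.poisson(8.0, n_cols).astype(np.float64)
    active, thr, duty = threshold_step(overlap, np.ones(n_cols), thr, duty)
print(active.sum(), "active cols vs target", 0.02 * n_cols)
```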
### Cross-block coherence

- **Ping-pong bitsets** for `cell_active_bits` and `cell_winner_bits`: at
  even t write to `_a`, read from `_b`; at odd t reversed. This eliminates
  the need for an in-place snapshot kernel between timesteps (see the sketch
  after this list).
- **Primary path: cooperative launch + hardware grid sync**. Host code probes
  `CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH`, computes the cooperative whole-grid
  residency limit from occupancy, and launches the fused megakernel with
  `cuLaunchCooperativeKernel`. In-kernel barriers use
  `cooperative_groups::this_grid().sync()`.
- **Fallback path: software grid barrier** via a 3-slot atomic counter array
  (`barrier_counters`). This remains as a compatibility fallback when
  cooperative launch is unavailable.
- **Launch invariant**: cooperative launch is capped to the hardware residency
  limit for `blockDim.x = 1024`; the software fallback remains capped conservatively
  (`HTM_FUSED_GRID_CAP`, default 8) to avoid whole-grid spin deadlock.
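The ping-pong rule is plain parity indexing on `t`; a minimal Python sketch (illustrative only; the buffer names follow the `_a`/`_b` suffixes above):

```python
def pingpong(bits_a, bits_b, t):
    """Select (write, read) bitset buffers for timestep t.

    Even t: write _a, read _b.  Odd t: reversed.  Every block computes the
    same parity, so after a grid barrier readers always observe the buffer
    the previous timestep finished writing -- no snapshot copy needed.
    """
    if t % 2 == 0:
        return bits_a, bits_b  # (curr, prev)
    return bits_b, bits_a

curr, prev = pingpong([0] * 64, [0] * 64, t=0)
```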
### Kernel structure

```
for t in 0..T:
    # Phase 0: clear curr_active/curr_winner for my column range
    grid_barrier()
    # Phase A: SP overlap → boost → threshold → SP learn → duty + threshold EMA
    grid_barrier()
    # Phase B: TM predict (per cell, per seg) → TM learn (reinforce on match)
    #          → burst if none predicted → segment grow/reinforce
    grid_barrier()
    # Phase C: block 0 writes anomaly[t]
```

Each warp owns a contiguous slice of columns. At grid=24 blocks × 32 warps =
768 warps, n_columns=2048 → 2-3 columns per warp.

### Parity with legacy GPU path

**Semantics diverge**. Legacy: exactly `k = round(sparsity * n_cols)` columns
active per step. Fused: variable, converging to `sparsity * n_cols` on
average via the per-column EMA. Anomaly decay on repeating sequences is
preserved (see `gpu_fused_tm_anomaly_decays_on_repeating_sequence` test).

This is an intentional architectural change committed under
`no-bypass/full-architecture` per program.md rules. The legacy top-K path
(`step_many_cuda`) remains available for reference and can be re-enabled via
`HYDRA_HTM_FUSED=0`.

### Tests

- `gpu_threshold_converges_to_sparsity` (tests.rs): 1000-step warmup on
  random SDRs, then measure mean active cols/step on the next 200 steps. Must
  land within [0.25×, 4×] of `sparsity_target * n_cols`.
- `gpu_fused_tm_anomaly_decays_on_repeating_sequence`: feed A,B,C repeating
  for 300 steps. Late anomaly must be < early anomaly AND < 0.5.

## Legacy Pipeline (kept for fallback)

* SP: 5 kernels, bit-identical parity with CPU under strict-parity mode.
* TM: 7 kernels, relaxed parity with CPU.
* Speedup at training size (B=8, T=2048, bits=16384): **3.83x** vs CPU.

## Building

CPU-only (default, zero CUDA dep):
```bash
cargo build --release
```

GPU-enabled:
```bash
export PATH=/usr/local/cuda-12.1/bin:$PATH
export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
export HTM_PTX_VERSION=7.8   # lower if driver older than nvcc
cargo build --release --features gpu
cargo test --release --features gpu --lib   # fused path includes cooperative launch + grid-sync tests

# Python wheel:
maturin develop --release --features gpu --manifest-path htm_rust/Cargo.toml
```

## Architecture

### Module layout
```
src/gpu/
  mod.rs       # HTMRegionGpu pyclass + step_many_gpu (full pipeline)
  sp_gpu.rs    # Persistent SP device buffers + step_batch_with_tm
  tm_gpu.rs    # Persistent TM device buffers + step (predict→activate→learn)
  tests.rs     # CPU-vs-GPU SP parity + end-to-end TM anomaly decay
  kernels/
    sp_overlap.cu      # per-column overlap reduction
    sp_topk.cu         # k-WTA top-K winner selection
    sp_learn.cu        # Hebbian +inc/-dec on proximal synapses
    sp_duty.cu         # EMA duty-cycle update
    sp_boost_fused.cu  # fused mean + exp boost (GPU-side)
    tm_reset.cu        # per-step: snapshot active→prev, clear buffers
    tm_predict.cu      # per-cell: score owned segments vs prev_active_bits
    tm_activate.cu     # per-col: activate predicted cells OR burst
    tm_learn.cu        # per-cell: reinforce correctly-predicted segments
    tm_punish.cu       # per-cell: decay matching segs on inactive cols
    tm_grow.cu         # per-bursting-col: reuse matching seg OR create new,
                       #   grow synapses to prev_winners
    tm_anomaly.cu      # per-step: unpredicted/active ratio
```

### Persistent SP state (per region, unchanged from Phase 1)
At n_cols=2048, S=40, bits=16384: ~355 KB persistent + ~90 KB transient.

### Persistent TM state (per region)

Capacity knobs (configured in `tm_gpu.rs`):
- `MAX_SEGMENTS_PER_CELL = 4`
- `MAX_SYN_PER_SEGMENT = 20`

At cells_per_col=32, n_cols=2048:
- `n_cells = 65_536`
- `n_segments_max = 262_144` (~262K)
- `n_synapses_max = 5_242_880` (~5.2M)

| Buffer                 | Shape / type       | Notes                                 |
|------------------------|--------------------|---------------------------------------|
| `seg_cell_id`          | (n_segs,) u32      | owning cell; U32_MAX = unused         |
| `seg_syn_count`        | (n_segs,) u32      | #active synapses in slot              |
| `syn_presyn`           | (n_segs × S,) u32  | presynaptic cell indices              |
| `syn_perm`             | (n_segs × S,) i16  | permanence scaled 0..32767 (0.0..1.0) |
| `cell_seg_count`       | (n_cells,) u32     | segments allocated on each cell       |
| `cell_active_bits`     | (n_cells/32,) u32  | packed bitset, current step           |
| `cell_winner_bits`     | (n_cells/32,) u32  | packed bitset, current step           |
| `cell_predictive_bits` | (n_cells/32,) u32  | set by predict, read by activate      |
| `prev_active_bits`     | (n_cells/32,) u32  | snapshot at step start                |
| `prev_winner_bits`     | (n_cells/32,) u32  | snapshot at step start                |
| `col_predicted`        | (n_cols,) u8       | set if any cell in col is predictive  |
| `col_best_match`       | (n_cols,) u32      | packed (pot<<21 \| seg_id), atomicMax |
| `seg_num_active_conn`  | (n_segs,) u32      | output of predict                     |
| `seg_num_active_pot`   | (n_segs,) u32      | output of predict                     |
| `unpredicted_count`    | (1,) u32           | atomic counter for anomaly            |
| `burst_cols_flat`      | (n_cols,) u32      | list of bursting cols                 |
| `burst_cols_count`     | (1,) u32           | length of above list                  |

**Total per TM region: ~42 MB.** Batch of 8 regions: ~340 MB. Fits a 6 GB RTX 3060.

### Per-step pipeline (single iteration of `step_batch_with_tm`)

```
SP side                              TM side
---------                            ---------
 1. D2D input slice → inp_dev
 2. sp_overlap      (n_cols blocks)
 3. sp_topk         (1 block)
 4. sp_learn        (n_cols blocks)
 5. sp_duty         (n_cols/256 blocks)
 6. sp_boost_fused  (1 block)
 7. D2D active_mask → cols_dev[ti]
                                      8. tm_reset_step (ceil(n_cells/32/256))
                                      9. tm_predict    (n_cells blocks × 32 thr)
                                     10. tm_activate   (n_cols/256 blocks)
                                     11. tm_anomaly    (1 block)
                                     if learn:
                                     12. tm_learn      (n_cells blocks)
                                     13. tm_punish     (n_cells blocks)
                                     14. tm_grow       (n_cols blocks – early-exits)
```

No host sync in the T-step loop. At the end, one `dtoh_sync_copy` each for
`cols_dev` (T × n_cols bytes) and `anom_dev` (T × f32).

## Parity

### SP: strict bit-identical
See Phase 1 docs – `gpu_sp_matches_cpu_with_learn` over 50 steps passes exact.

### TM: relaxed parity
The GPU TM has known, deliberate deviations from CPU to admit massive parallelism:

1. **Bursting winner cell**: CPU picks the least-used cell (fewest segments) with
   random tiebreak. GPU picks cell 0 of the column (deterministic, branch-free).
   Learning dynamics are preserved because segment creation/reinforcement is
   the dominant effect, not which specific cell in a bursting column wins.

2. **Permanence storage**: i16 fixed-point (scale 32767) vs f32. Rounding
   differs by <=1 ULP of the scale (~3.0e-5), below any meaningful learning
   quantum (inc=0.10, dec=0.10, predicted_segment_dec=0.10); see the sketch
   after this list.

3. **Grown synapse candidate order**: CPU randomly samples from prev_winner_cells.
   GPU iterates prev_winner_bits words in a pseudo-random rotated order keyed
   by (bursting_col_idx, iter_seed). Output is a different subset but same size.

4. **Segment LRU eviction**: CPU tracks `last_used_iteration` per segment.
   GPU wraps around (slot = count % max_segments_per_cell). In the autoresearch
   loop where TM resets every forward, eviction rarely triggers.
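For deviation 2, the claimed error bound is easy to verify numerically; a small Python check (the 32767 scale and the 0.10 learning quanta come from the text above; the helper names are hypothetical):

```python
SCALE = 32767  # i16 fixed-point scale for permanence in [0.0, 1.0]

def to_i16(p: float) -> int:
    """Quantize a permanence to the i16 grid."""
    return round(p * SCALE)

def to_f32(q: int) -> float:
    """Dequantize back to float."""
    return q / SCALE

# Worst-case round-trip error is half a quantum (~1.5e-5); even a full
# quantum (1/32767 ~ 3.0e-5) is ~3000x smaller than the inc=0.10 step.
assert all(abs(to_f32(to_i16(p)) - p) <= 0.5 / SCALE
           for p in (0.0, 0.1, 0.33333, 0.5, 0.99999, 1.0))
```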
The GPU parity test (`gpu_tm_anomaly_decays_on_repeating_sequence`) feeds a
repeating A,B,C sequence and asserts anomaly decays: **1.000 early → 0.000 late**.

## Bottleneck Analysis

| Source                             | Cost/step (B=8 T=2048) |
|------------------------------------|-----------------------:|
| 14 kernel launches                 | ~70 µs                 |
| ~262K predict/learn/punish blocks  | ~2.5 ms                |
| No D2H until end-of-batch          | 0 µs                   |
| Final D2H (T × n_cols + T × f32)   | ~200 µs per region     |

Per-step wall time at B=8 T=2048:
- CPU (reference): **~11.4 ms / step**
- GPU (current): **~2.98 ms / step**
- **Speedup: 3.83x**

## End-to-End Training Benchmark

**Config**: B=8, T=2048, vocab=8192, 60-second time budget, full HYDRA stack
(SDR Semantic + HTM + Mamba-3 + Engram + mHC + Hestia QAT).

**Results**:
- GPU util: **97-98% sustained**
- VRAM: **5.4 GB / 6.0 GB** (90% utilisation)
- Steps completed: 16
- tok/sec: **~2,200-2,500** (stable post-warmup)
- Final val_bpb: **2.249** (from ~3.1 initial)
- Factual eval: 1/9 hits

Compared to the previous CPU-HTM baseline (~100 tok/s), the full-GPU HTM delivers
**~22x end-to-end throughput** – far above the 3-10x target.

## Bench Commands

```bash
source .venv/bin/activate
export LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH

# Microbench
B=8 T=2048 python htm_rust/bench_gpu.py

# Full training
HYDRA_TIME_BUDGET=60 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=32768 python -u train.py
```

## Known Limitations / Future Work

- **Segment-compacted launches**: predict/learn/punish iterate all n_cells
  blocks, using `cell_seg_count` to skip empty cells. A compacted live-cell
  list would shave another ~40% of launch overhead.
- **Winner selection**: currently cell 0 of the bursting col. Proper least-used
  selection would help stability of cross-column patterns.
- **Single CUDA stream per region**: with B=8 regions we serialise on stream 0.
  Multi-stream would lift the ~20% launch overhead at small batch sizes.
- **Permanence bump on chronically under-stimulated columns**: SP's strict-parity
  bump is not mirrored on the GPU fast path. Effect on long runs needs measurement.
- **`seg_num_active_conn` output is reused across reinforce + punish**: the two
  kernels each launch n_cells blocks. They could be fused into one for one fewer
  kernel launch per step.

## Files

- `htm_rust/build.rs` – nvcc-driven PTX compilation, 12 kernels.
- `htm_rust/Cargo.toml` – `gpu` feature flag, cudarc dep.
- `htm_rust/src/gpu/mod.rs` – `HTMRegionGpu` pyclass + `step_many_gpu`.
- `htm_rust/src/gpu/sp_gpu.rs` – SP state + `step_batch_with_tm`.
- `htm_rust/src/gpu/tm_gpu.rs` – TM state + `step`.
- `htm_rust/src/gpu/tests.rs` – parity + correctness tests.
- `htm_rust/src/gpu/kernels/*.cu` – 5 SP + 7 TM kernels.
- `htm_rust/bench_gpu.py` – CPU-vs-GPU microbench.
- `subsystems/htm.py` – transparent GPU/CPU backend selection in `HTMLayer`.
overlay/htm_rust/src/gpu/fused.rs
CHANGED

Removed lines shown as `…` were truncated in this capture.

```diff
@@ -20,15 +20,15 @@
 use std::ffi::CString;
 use std::sync::Arc;
 
-use cudarc::driver::{
-    …
+use cudarc::driver::{
+    result, sys, CudaDevice, CudaSlice, DevicePtr, DeviceRepr, DriverError, LaunchConfig,
+};
 use cudarc::nvrtc::Ptx;
 
 use super::sp_gpu::SpatialPoolerGpu;
 use super::tm_gpu::{TemporalMemoryGpu, MAX_SEGMENTS_PER_CELL, MAX_SYN_PER_SEGMENT};
 
-const PTX_HTM_FUSED: &str =
-    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
+const PTX_HTM_FUSED: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/htm_fused_step.ptx"));
 
 /// Struct-by-value pointer pack – matches C-side `FusedPtrs`.
 ///
@@ -132,11 +132,9 @@ pub(crate) fn plan_fused_launch(
     grid_cap_override: Option<u32>,
 ) -> Result<FusedLaunchPlan, String> {
     let sm_count = sm_count.max(1);
-    // 1024 threads/block exceeds the register file on Ampere
-    //
-    //
-    // cooperative launch. On Hopper (228 KB smem, 255 regs/thread baseline),
-    // 1024 works fine, but 256 is safe everywhere.
+    // 1024 threads/block exceeds the register file on Ampere and makes the
+    // cooperative-grid residency probe lie when the launch uses a different
+    // block size. Keep the planned block size identical to the occupancy probe.
     let block_dim_x = 256u32;
 
     // Cluster launch path: cooperative launch is not required. Keep the probe
@@ -145,10 +143,11 @@ pub(crate) fn plan_fused_launch(
         eprintln!("[htm_rust] INFO: cooperative launch unsupported; cluster path only.");
     }
 
-    //
-    //
+    // Cluster constraint: grid_dim_x must equal the cluster size (16) so that
+    // each region maps to exactly one cluster. `HTM_FUSED_GRID_CAP` can lower
+    // this for debugging but should not exceed 16 for cluster correctness.
     let default_grid_cap = 16u32;
-    let grid_cap = grid_cap_override.unwrap_or(default_grid_cap);
+    let grid_cap = grid_cap_override.unwrap_or(default_grid_cap).min(16);
     let resident_bound = if cooperative_grid_limit > 0 {
         cooperative_grid_limit.max(sm_count * 2)
     } else {
@@ -218,7 +217,7 @@ pub struct FusedState {
     pub cell_active_bits_b: CudaSlice<u32>,
     pub cell_winner_bits_a: CudaSlice<u32>,
     pub cell_winner_bits_b: CudaSlice<u32>,
-    pub step_scratch: CudaSlice<u32>,
+    pub step_scratch: CudaSlice<u32>, // length 6
 
     pub grid_dim_x: u32,
     pub block_dim_x: u32,
@@ -241,7 +240,10 @@ impl FusedState {
         initial_threshold: f32,
     ) -> Result<Self, DriverError> {
         let n_cells = n_columns * cells_per_column;
-        assert!(…
+        assert!(
+            n_cells % 32 == 0,
+            "n_cells must be divisible by 32 for bitsets"
+        );
         let bits_words = n_cells / 32;
 
         let mut inhibition_threshold = dev.alloc_zeros::<f32>(n_columns)?;
@@ -278,7 +280,8 @@ impl FusedState {
         // every launched kernel function, otherwise cuLaunchKernelEx rejects
         // the cluster dim with CUDA_ERROR_INVALID_CLUSTER_SIZE.
         unsafe {
-            let attr = …
+            let attr =
+                sys::CUfunction_attribute::CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED;
             // Ignore errors: older CUDA may lack the attribute, in which case
             // only portable sizes (<= 8) work – plan_fused_launch caps at 8.
             let _ = sys::lib().cuFuncSetAttribute(function, attr, 1);
@@ -294,9 +297,9 @@ impl FusedState {
         };
 
         // T1: Probe Hopper cluster launch capability.
-        let max_cluster_size = match dev…
-            cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH
-            …
+        let max_cluster_size = match dev
+            .attribute(cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_CLUSTER_LAUNCH)
+        {
             Ok(v) if v > 0 => {
                 // H200/sm_90a supports up to 16 blocks per cluster.
                 // There is no MAX_CLUSTER_SIZE attribute in CUDA 12.4; hard-code the
@@ -346,7 +349,11 @@ impl FusedState {
 
         Ok(Self {
             dev,
-            raw_kernel: RawFusedKernel {…
+            raw_kernel: RawFusedKernel {
+                module,
+                function,
+                function_batched,
+            },
             inhibition_threshold,
             cell_active_bits_a,
             cell_active_bits_b,
@@ -445,7 +452,7 @@ pub fn launch_fused(
         inputs: *inputs_flat.device_ptr(),
         cols_out: *cols_out.device_ptr(),
         anom_out: *anom_out.device_ptr(),
-        barrier_counters: 0u64,
+        barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
         step_scratch: *fused.step_scratch.device_ptr(),
     };
 
@@ -493,14 +500,17 @@ pub fn launch_fused(
         }
     } else {
         // Pre-Hopper: cooperative kernel launch. The fused kernel uses
-        //
-        //
-        // the first grid.sync() call).
+        // cg::this_grid().sync(); normal launches poison the CUDA context
+        // with an asynchronous unspecified launch failure.
         let ret = sys::lib().cuLaunchCooperativeKernel(
             fused.raw_kernel.function,
-            grid_x,
-            …
-            …
+            grid_x,
+            1,
+            1,
+            block_x,
+            1,
+            1,
+            0,
             cu_stream,
             kernel_params.as_mut_ptr(),
         );
@@ -616,7 +626,7 @@ pub(super) fn launch_fused_batched_raw(
             inputs: inputs_per_region[i],
             cols_out: cols_per_region[i],
             anom_out: anom_per_region[i],
-            barrier_counters: 0u64,
+            barrier_counters: 0u64, // ABI-compat dummy; cluster barrier replaces DLB.
             step_scratch: *r.fused_state.step_scratch.device_ptr(),
         }
     })
@@ -636,8 +646,8 @@ pub(super) fn launch_fused_batched_raw(
         let r0 = unsafe { &*region_ptrs[0] };
         r0.fused_state.cluster_info.max_cluster_size > 0
     };
-    let grid_x = …
-        .map_err(|msg| {
+    let grid_x =
+        plan_batched_grid_dim(grid_x, cooperative_grid_limit, b, use_cluster).map_err(|msg| {
             eprintln!("[htm_rust] FATAL: {msg}");
             DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE)
         })?;
@@ -678,12 +688,19 @@ pub(super) fn launch_fused_batched_raw(
             return Err(DriverError(ret));
         }
     } else {
-        // Pre-Hopper: cooperative kernel launch
+        // Pre-Hopper: cooperative kernel launch. The fused kernel uses
+        // cg::this_grid().sync(), which is only valid under cooperative
+        // launch. A normal launch can run until the first grid.sync() and
+        // then poison the CUDA context with an unspecified launch failure.
        let ret = sys::lib().cuLaunchCooperativeKernel(
             function_batched,
-            grid_x,
-            …
-            …
+            grid_x,
+            b as u32,
+            1,
+            block_x,
+            1,
+            1,
+            0,
             cu_stream,
             kernel_params.as_mut_ptr(),
         );
```
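The two cooperative-launch call sites above differ only in grid geometry; a Python sketch of the dimensions they pass (`grid_x`, `block_x`, and `b` are the names from the diff; the region-to-`blockIdx.y` mapping is an inference, and the dataclass is mine):

```python
from dataclasses import dataclass

@dataclass
class LaunchDims:
    grid: tuple[int, int, int]
    block: tuple[int, int, int]
    shared_bytes: int = 0

def single_region(grid_x: int, block_x: int) -> LaunchDims:
    # launch_fused: one region, 1-D grid, grid-wide sync via cooperative launch.
    return LaunchDims(grid=(grid_x, 1, 1), block=(block_x, 1, 1))

def batched(grid_x: int, block_x: int, b: int) -> LaunchDims:
    # launch_fused_batched_raw: presumably region i owns the blocks with
    # blockIdx.y == i, so b regions run in one cooperative grid.
    return LaunchDims(grid=(grid_x, b, 1), block=(block_x, 1, 1))
```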
overlay/htm_rust/src/gpu/kernels/sp_boost_fused.cu
CHANGED
|
@@ -1,59 +1,59 @@
|
|
| 1 |
-
// Fused mean-reduction + boost-update kernel.
|
| 2 |
-
//
|
| 3 |
-
// Inputs:
|
| 4 |
-
// active_duty[n] (f32)
|
| 5 |
-
// boost_strength (f32)
|
| 6 |
-
//
|
| 7 |
-
// Output:
|
| 8 |
-
// boost[n] (f32) = expf(-boost_strength * (active_duty[c] - mean))
|
| 9 |
-
//
|
| 10 |
-
// Launch: single block (1024 threads), shared mem for reduction. At n=2048
|
| 11 |
-
// each thread handles 2 elements.
|
| 12 |
-
|
| 13 |
-
extern "C" __global__
|
| 14 |
-
void sp_boost_from_duty(
|
| 15 |
-
const float * __restrict__ active_duty, // (n,)
|
| 16 |
-
float * __restrict__ boost, // (n,) in-place out
|
| 17 |
-
float boost_strength,
|
| 18 |
-
unsigned int n
|
| 19 |
-
) {
|
| 20 |
-
extern __shared__ float smem_raw[];
|
| 21 |
-
float * smem = smem_raw;
|
| 22 |
-
const unsigned int tid = threadIdx.x;
|
| 23 |
-
const unsigned int bsz = blockDim.x;
|
| 24 |
-
|
| 25 |
-
// Phase 1: parallel sum of active_duty into smem[0..32] (warp-level).
|
| 26 |
-
float local_sum = 0.0f;
|
| 27 |
-
for (unsigned int i = tid; i < n; i += bsz) {
|
| 28 |
-
local_sum += active_duty[i];
|
| 29 |
-
}
|
| 30 |
-
// Warp reduction.
|
| 31 |
-
for (int off = 16; off > 0; off >>= 1) {
|
| 32 |
-
local_sum += __shfl_down_sync(0xffffffff, local_sum, off);
|
| 33 |
-
}
|
| 34 |
-
unsigned int lane = tid & 31;
|
| 35 |
-
unsigned int warp = tid >> 5;
|
| 36 |
-
if (lane == 0) smem[warp] = local_sum;
|
| 37 |
-
__syncthreads();
|
| 38 |
-
|
| 39 |
-
// Warp 0 reduces warp-sums.
|
| 40 |
-
__shared__ float mean_s;
|
| 41 |
-
if (warp == 0) {
|
| 42 |
-
unsigned int nwarps = (bsz + 31) / 32;
|
| 43 |
-
float v = (lane < nwarps) ? smem[lane] : 0.0f;
|
| 44 |
-
for (int off = 16; off > 0; off >>= 1) {
|
| 45 |
-
v += __shfl_down_sync(0xffffffff, v, off);
|
| 46 |
-
}
|
| 47 |
-
if (tid == 0) {
|
| 48 |
-
mean_s = v / (float)n;
|
| 49 |
-
}
|
| 50 |
-
}
|
| 51 |
-
__syncthreads();
|
| 52 |
-
|
| 53 |
-
// Phase 2: boost[c] = expf(-strength * (active_duty[c] - mean)).
|
| 54 |
-
float mean = mean_s;
|
| 55 |
-
for (unsigned int i = tid; i < n; i += bsz) {
|
| 56 |
-
float d = active_duty[i] - mean;
|
| 57 |
-
boost[i] = expf(-boost_strength * d);
|
| 58 |
-
}
|
| 59 |
-
}
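A minimal host-side launch sketch for the kernel above (illustrative only: the committed host path lives on the Rust side under src/gpu, and the buffer names here are placeholders assumed allocated). One 1024-thread block covers n=2048 columns at two elements per thread, and the dynamic shared memory only needs one float per warp for the partial sums:

    const unsigned int n = 2048;             // HYDRA column count
    const size_t smem = 32 * sizeof(float);  // 1024 threads / 32 = 32 warp sums
    // d_active_duty, d_boost: device float buffers of length n.
    sp_boost_from_duty<<<1, 1024, smem>>>(d_active_duty, d_boost,
                                          /*boost_strength=*/1.0f, n);
    cudaDeviceSynchronize();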
overlay/htm_rust/src/gpu/kernels/sp_duty.cu
CHANGED
@@ -1,45 +1,45 @@
// Duty cycle + boost update kernel.
//
// For each column c (one thread each):
//   active_sample  = active_mask[c] ? 1 : 0
//   overlap_sample = raw_overlap[c] >= stim_thr ? 1 : 0
//   active_duty[c]  = (1-alpha) * active_duty[c]  + alpha * active_sample
//   overlap_duty[c] = (1-alpha) * overlap_duty[c] + alpha * overlap_sample
//
// Then, if learn:
//   boost[c] = exp(-boost_strength * (active_duty[c] - mean_duty))
//   mean_duty is computed on the host (one reduction) and passed in.

extern "C" __global__
void sp_duty_update(
    const unsigned char * __restrict__ active_mask,  // (n_columns,)
    const unsigned int * __restrict__ raw_overlap,   // (n_columns,)
    float * __restrict__ active_duty,                // (n_columns,) in-place
    float * __restrict__ overlap_duty,               // (n_columns,) in-place
    float * __restrict__ boost,                      // (n_columns,) in-place
    float alpha,
    float stim_thr,
    float boost_strength,  // 0 to skip boost
    float mean_duty,
    unsigned int learn_flag,  // 0 or 1
    unsigned int n_columns
) {
    unsigned int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c >= n_columns) return;

    float ad = active_duty[c];
    float od = overlap_duty[c];

    float a_sample = (active_mask[c] != 0) ? 1.0f : 0.0f;
    float o_sample = ((float)raw_overlap[c] >= stim_thr) ? 1.0f : 0.0f;

    ad = (1.0f - alpha) * ad + alpha * a_sample;
    od = (1.0f - alpha) * od + alpha * o_sample;

    active_duty[c] = ad;
    overlap_duty[c] = od;

    if (learn_flag && boost_strength > 0.0f) {
        boost[c] = expf(-boost_strength * (ad - mean_duty));
    }
}
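The duty-cycle update is a per-column exponential moving average, so one elementwise launch per step suffices. With alpha = 0.01, a column active this step moves its duty from d to 0.99*d + 0.01. A launch sketch with placeholder values for alpha and the thresholds (the real values come from the host config):

    unsigned int n_columns = 2048;
    dim3 block(256), grid((n_columns + 255) / 256);
    float mean_duty = 0.02f;  // host-side reduction over active_duty (placeholder)
    sp_duty_update<<<grid, block>>>(d_active_mask, d_raw_overlap,
                                    d_active_duty, d_overlap_duty, d_boost,
                                    /*alpha=*/0.01f, /*stim_thr=*/8.0f,
                                    /*boost_strength=*/1.0f, mean_duty,
                                    /*learn_flag=*/1u, n_columns);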
overlay/htm_rust/src/gpu/kernels/sp_learn.cu
CHANGED
@@ -1,45 +1,45 @@
// SP Hebbian learning kernel.
//
// For each active (winner) column c, for each of its synapses s:
//   if input[bit[c][s]] active: perm += inc
//   else:                       perm -= dec
// Clamp to [0, 1].
//
// Launch: one block per column (2048 blocks), but we predicate on
// active_mask[c] to avoid launching k-specific blocks.
//
// This matches the CPU reference line-for-line:
// src/sp.rs lines 157-169.

extern "C" __global__
void sp_learn(
    const unsigned char * __restrict__ active_mask,  // (n_columns,) 0/1
    const unsigned char * __restrict__ inp,          // (input_bits,)
    const unsigned int * __restrict__ syn_bit,       // (n_columns * S,)
    float * __restrict__ syn_perm,                   // (n_columns * S,) in-place
    float inc,
    float dec,
    unsigned int synapses_per_col,
    unsigned int n_columns
) {
    const unsigned int c = blockIdx.x;
    if (c >= n_columns) return;
    if (active_mask[c] == 0) return;

    const unsigned int base = c * synapses_per_col;
    const unsigned int tid = threadIdx.x;
    const unsigned int bsz = blockDim.x;

    for (unsigned int s = tid; s < synapses_per_col; s += bsz) {
        unsigned int b = syn_bit[base + s];
        float p = syn_perm[base + s];
        if (inp[b] != 0) {
            p += inc;
            if (p > 1.0f) p = 1.0f;
        } else {
            p -= dec;
            if (p < 0.0f) p = 0.0f;
        }
        syn_perm[base + s] = p;
    }
}
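A launch sketch matching the header comment: the grid is simply n_columns wide and inactive columns return immediately. The inc/dec values and buffer names are placeholders:

    const unsigned int S = 40;  // synapses_per_col (placeholder)
    sp_learn<<<n_columns, 128>>>(d_active_mask, d_inp, d_syn_bit, d_syn_perm,
                                 /*inc=*/0.05f, /*dec=*/0.008f, S, n_columns);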
overlay/htm_rust/src/gpu/kernels/sp_overlap.cu
CHANGED
@@ -1,78 +1,78 @@
// SP overlap kernel.
//
// For each column c (one CUDA block), compute:
//   overlap[c] = sum over its synapse list of {inp[bit[c][s]] && perm[c][s] >= conn_thr}
//   boosted[c] = overlap[c] * boost[c]
//   raw_overlap[c] = overlap[c]  (also returned so host can drive duty cycle)
//
// Memory layout (flat, column-major with per-column stride = synapses_per_col):
//   syn_bit[c * S + s]  : u32 index into input SDR
//   syn_perm[c * S + s] : f32 permanence in [0, 1]
//   boost[c]            : f32
//   inp[b]              : u8 0/1
// Output:
//   raw[c]     : u32
//   boosted[c] : f32
//
// Launch:
//   grid  = n_columns
//   block = 128 (or 256): one warp-sweep across synapses; many warps give
//   parallel reduction across S (typically S=40).
//
// At S=40 this is completely latency-bound; we coalesce reads and do a
// warp-shuffle reduction. For clarity we use a simple block-wide shared-mem
// reduction which is sufficient for S <= 1024 and has zero correctness risk.

extern "C" __global__
void sp_overlap(
    const unsigned char * __restrict__ inp,      // (input_bits,)
    const unsigned int * __restrict__ syn_bit,   // (n_columns * S,)
    const float * __restrict__ syn_perm,         // (n_columns * S,)
    const float * __restrict__ boost,            // (n_columns,)
    float conn_thr,
    unsigned int synapses_per_col,               // S
    unsigned int n_columns,
    unsigned int * __restrict__ raw_out,         // (n_columns,)
    float * __restrict__ boosted_out             // (n_columns,)
) {
    const unsigned int c = blockIdx.x;
    if (c >= n_columns) return;

    const unsigned int base = c * synapses_per_col;
    const unsigned int tid = threadIdx.x;
    const unsigned int bsz = blockDim.x;

    // Per-thread partial count.
    unsigned int local = 0;
    for (unsigned int s = tid; s < synapses_per_col; s += bsz) {
        unsigned int b = syn_bit[base + s];
        float p = syn_perm[base + s];
        // Branchless: only counts when input active AND perm connected.
        // Using (inp != 0) to tolerate u8 layout.
        unsigned int hit = ((inp[b] != 0) && (p >= conn_thr)) ? 1u : 0u;
        local += hit;
    }

    // Block-wide reduction in shared memory.
    __shared__ unsigned int smem[32];

    // Warp-level reduction via shuffle.
    unsigned int lane = tid & 31;
    unsigned int warp = tid >> 5;
    for (int off = 16; off > 0; off >>= 1) {
        local += __shfl_down_sync(0xffffffff, local, off);
    }
    if (lane == 0) smem[warp] = local;
    __syncthreads();

    if (warp == 0) {
        unsigned int v = (tid < (bsz + 31) / 32) ? smem[lane] : 0;
        for (int off = 16; off > 0; off >>= 1) {
            v += __shfl_down_sync(0xffffffff, v, off);
        }
        if (tid == 0) {
            raw_out[c] = v;
            boosted_out[c] = (float)v * boost[c];
        }
    }
}
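Launch sketch per the header (one block per column, 128 threads; conn_thr and buffer names are placeholders):

    sp_overlap<<<n_columns, 128>>>(d_inp, d_syn_bit, d_syn_perm, d_boost,
                                   /*conn_thr=*/0.5f, /*S=*/40u, n_columns,
                                   d_raw, d_boosted);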
overlay/htm_rust/src/gpu/kernels/sp_topk.cu
CHANGED
@@ -1,117 +1,117 @@
// Top-K column selection.
//
// Inputs:
//   boosted[n_columns] : f32 score
// Output:
//   active_mask[n_columns] : u8 0/1, exactly k ones
//
// Tie-breaking: when scores are equal, the LOWER column index wins (matches
// CPU reference `select_nth_unstable_by` with secondary index comparator).
//
// Strategy: a single-block implementation. n_columns is typically 2048, which
// fits comfortably in shared memory. One candidate approach is a radix-select
// over a per-thread (score, -index) key; at k≈41 of n=2048 the simplest
// correct form of it is a thresholding pass:
//
//   1. Radix-like bucket pass to find the k-th largest score.
//   2. Mark winners = strictly-greater-than-threshold AND ties until count hits k.
//
// For strict index-ordered tie-break we would materialise a 64-bit key:
//   key = (float_to_sortable_u32(score) << 32) | (0xffffffff - index)
// Larger key = (higher score) OR (same score, smaller index).
//
// We would then find the k-th largest 64-bit key via radix-select and mark all
// columns whose key >= threshold. This is O(n_cols * log k) and well under
// 100 μs for n=2048, k=41 on sm_86.
//
// For simplicity and correctness this kernel instead uses a single-block
// parallel selection sort variant (find max → mark → zero → repeat, k
// iterations). At k=41 this is 41 passes of 2048 threads = ~2048*41 = 84K
// ops, trivially fast.

extern "C" __global__
void sp_topk_select(
    const float * __restrict__ scores,  // (n_columns,)
    unsigned int n_columns,
    unsigned int k,
    unsigned char * __restrict__ active_out  // (n_columns,)
) {
    extern __shared__ float smem[];
    // Layout: smem[0..n] = working scores (we'll mark selected entries as -inf).
    // The per-warp reduction scratch lives in the static __shared__ arrays below.
    float * work = smem;
    const unsigned int tid = threadIdx.x;
    const unsigned int bsz = blockDim.x;

    // Load scores into shared; also init active_out = 0.
    for (unsigned int i = tid; i < n_columns; i += bsz) {
        work[i] = scores[i];
        active_out[i] = 0;
    }
    __syncthreads();

    __shared__ int winner_idx;
    __shared__ float winner_score;

    for (unsigned int iter = 0; iter < k; ++iter) {
        // Find (argmax score, lowest index for ties).
        float best_s = -INFINITY;
        int best_i = n_columns;  // sentinel larger than any index

        for (unsigned int i = tid; i < n_columns; i += bsz) {
            float s = work[i];
            if (s > best_s || (s == best_s && (int)i < best_i)) {
                best_s = s;
                best_i = (int)i;
            }
        }

        // Warp reduction. We reduce pairs (score, idx) keeping (max score, min idx on tie).
        unsigned int mask = 0xffffffff;
        for (int off = 16; off > 0; off >>= 1) {
            float os = __shfl_down_sync(mask, best_s, off);
            int   oi = __shfl_down_sync(mask, best_i, off);
            if (os > best_s || (os == best_s && oi < best_i)) {
                best_s = os;
                best_i = oi;
            }
        }
        // Warp 0 collects lane 0 values from other warps via shared mem.
        __shared__ float warp_s[32];
        __shared__ int warp_i[32];
        unsigned int lane = tid & 31;
        unsigned int warp = tid >> 5;
        if (lane == 0) {
            warp_s[warp] = best_s;
            warp_i[warp] = best_i;
        }
        __syncthreads();

        if (warp == 0) {
            unsigned int nwarps = (bsz + 31) / 32;
            float s = (lane < nwarps) ? warp_s[lane] : -INFINITY;
            int   i = (lane < nwarps) ? warp_i[lane] : (int)n_columns;
            for (int off = 16; off > 0; off >>= 1) {
                float os = __shfl_down_sync(mask, s, off);
                int   oi = __shfl_down_sync(mask, i, off);
                if (os > s || (os == s && oi < i)) {
                    s = os;
                    i = oi;
                }
            }
            if (tid == 0) {
                winner_score = s;
                winner_idx = i;
            }
        }
        __syncthreads();

        if (tid == 0) {
            if (winner_idx < (int)n_columns) {
                active_out[winner_idx] = 1;
                work[winner_idx] = -INFINITY;
            }
        }
        __syncthreads();
    }
}
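Because the warp scratch is statically allocated, the dynamic shared-memory request only has to cover the working copy of the scores. A single-block launch sketch with placeholder buffers:

    const size_t smem = n_columns * sizeof(float);  // work[] copy of the scores
    sp_topk_select<<<1, 1024, smem>>>(d_boosted, n_columns, /*k=*/41u,
                                      d_active_mask);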
overlay/htm_rust/src/gpu/kernels/tm_activate.cu
CHANGED
@@ -1,66 +1,66 @@
// TM activate kernel. See tm_predict.cu for TmConfig.

struct TmConfig {
    unsigned int activation_threshold;
    unsigned int learning_threshold;
    unsigned int cells_per_column;
    unsigned int synapses_per_segment;
    unsigned int n_segments;
    unsigned int n_cells;
    unsigned int max_segments_per_cell;
    unsigned int max_new_synapses;
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;
    unsigned int iter_seed;
    unsigned int n_cols;
    unsigned int bits_words;
};

extern "C" __global__
void tm_activate(
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned char * __restrict__ col_predicted,
    const unsigned int * __restrict__ cell_predictive_bits,
    unsigned int * __restrict__ cell_active_bits,
    unsigned int * __restrict__ cell_winner_bits,
    unsigned int * __restrict__ unpredicted_count,
    unsigned int * __restrict__ burst_cols_flat,
    unsigned int * __restrict__ burst_cols_count,
    TmConfig cfg
) {
    unsigned int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= cfg.n_cols) return;
    if (sp_active_mask[col] == 0) return;

    unsigned int base_cell = col * cfg.cells_per_column;

    if (col_predicted[col]) {
        for (unsigned int k = 0; k < cfg.cells_per_column; k++) {
            unsigned int cell = base_cell + k;
            unsigned int word_idx = cell >> 5;
            unsigned int bit_mask = 1u << (cell & 31u);
            unsigned int pred_word = cell_predictive_bits[word_idx];
            if (pred_word & bit_mask) {
                atomicOr(&cell_active_bits[word_idx], bit_mask);
                atomicOr(&cell_winner_bits[word_idx], bit_mask);
            }
        }
    } else {
        atomicAdd(unpredicted_count, 1u);
        for (unsigned int k = 0; k < cfg.cells_per_column; k++) {
            unsigned int cell = base_cell + k;
            unsigned int word_idx = cell >> 5;
            unsigned int bit_mask = 1u << (cell & 31u);
            atomicOr(&cell_active_bits[word_idx], bit_mask);
        }
        unsigned int winner = base_cell;
        unsigned int word_idx = winner >> 5;
        unsigned int bit_mask = 1u << (winner & 31u);
        atomicOr(&cell_winner_bits[word_idx], bit_mask);
        unsigned int slot = atomicAdd(burst_cols_count, 1u);
        burst_cols_flat[slot] = col;
    }
}
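The cell bitfields pack 32 cells per u32 word, which is the indexing all the atomicOr calls above rely on. A worked example with an illustrative cell id:

    unsigned int cell = 70;                      // column 8, cell 6 at cells_per_column = 8
    unsigned int word_idx = cell >> 5;           // 70 / 32 = 2
    unsigned int bit_mask = 1u << (cell & 31u);  // 70 % 32 = 6, mask = 0x40
    atomicOr(&cell_active_bits[word_idx], bit_mask);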
overlay/htm_rust/src/gpu/kernels/tm_anomaly.cu
CHANGED
@@ -1,43 +1,43 @@
// TM anomaly kernel.
//
// Computes:
//   n_active = sum of sp_active_mask
//   anomaly  = unpredicted_count / n_active  (if n_active > 0)
//            = 0                             (else)
//
// Launch: single block, 256 threads.

extern "C" __global__
void tm_anomaly(
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned int * __restrict__ unpredicted_count,
    float * __restrict__ anomaly_out,  // (1,) or (t_slot,)
    unsigned int t_slot,
    unsigned int n_cols
) {
    const unsigned int tid = threadIdx.x;
    __shared__ unsigned int n_active_s;

    if (tid == 0) n_active_s = 0u;
    __syncthreads();

    unsigned int local = 0u;
    for (unsigned int i = tid; i < n_cols; i += blockDim.x) {
        if (sp_active_mask[i]) local += 1u;
    }
    // Warp reduce.
    for (int off = 16; off > 0; off >>= 1) {
        local += __shfl_down_sync(0xffffffffu, local, off);
    }
    if ((tid & 31u) == 0) {
        atomicAdd(&n_active_s, local);
    }
    __syncthreads();

    if (tid == 0) {
        unsigned int total = n_active_s;
        unsigned int bad = unpredicted_count[0];
        float anom = (total > 0u) ? ((float)bad / (float)total) : 0.0f;
        anomaly_out[t_slot] = anom;
    }
}
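Single-block launch sketch (t_slot indexes a preallocated per-timestep anomaly buffer; names are placeholders):

    tm_anomaly<<<1, 256>>>(d_sp_active_mask, d_unpredicted_count,
                           d_anomaly, /*t_slot=*/t, n_cols);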
overlay/htm_rust/src/gpu/kernels/tm_grow.cu
CHANGED
@@ -1,155 +1,155 @@
// TM grow+reinforce kernel.
//
// For each bursting column:
//   If col_best_match[col] is non-zero (i.e. at least one matching segment
//   with num_active_potential >= learning_threshold exists on cells in this col):
//     Target = that matching segment.
//     Reinforce its existing synapses: +inc if presyn in prev_active, -dec otherwise.
//     Grow up to (max_new - current_syn_count) additional synapses to prev_winners.
//   Else:
//     Allocate a fresh segment slot on winner cell (cell 0 of col).
//     Grow up to max_new synapses to prev_winners (no reinforce needed: new seg).
//
// This mirrors the CPU TM burst logic.

struct TmConfig {
    unsigned int activation_threshold;
    unsigned int learning_threshold;
    unsigned int cells_per_column;
    unsigned int synapses_per_segment;
    unsigned int n_segments;
    unsigned int n_cells;
    unsigned int max_segments_per_cell;
    unsigned int max_new_synapses;
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;
    unsigned int iter_seed;
    unsigned int n_cols;
    unsigned int bits_words;
};

extern "C" __global__
void tm_grow(
    unsigned int * __restrict__ seg_cell_id,
    unsigned int * __restrict__ seg_syn_count,
    unsigned int * __restrict__ syn_presyn,
    short * __restrict__ syn_perm,
    unsigned int * __restrict__ cell_seg_count,
    const unsigned int * __restrict__ burst_cols_flat,
    const unsigned int * __restrict__ burst_cols_count,
    const unsigned int * __restrict__ prev_winner_bits,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned int * __restrict__ col_best_match,
    TmConfig cfg
) {
    const unsigned int b = blockIdx.x;
    const unsigned int n_burst_cols = burst_cols_count[0];
    if (b >= n_burst_cols) return;
    const unsigned int tid = threadIdx.x;

    const unsigned int col = burst_cols_flat[b];

    __shared__ unsigned int shared_seg_id;
    __shared__ unsigned int shared_existing_syn_count;
    __shared__ unsigned int shared_grown;
    __shared__ unsigned int shared_is_new;
    __shared__ unsigned int shared_start_offset;

    if (tid == 0) {
        unsigned int match_key = col_best_match[col];
        if (match_key != 0u) {
            // Reuse matching segment.
            unsigned int seg_id = match_key & 0x1FFFFFu;
            shared_seg_id = seg_id;
            shared_existing_syn_count = seg_syn_count[seg_id];
            shared_is_new = 0u;
        } else {
            // Allocate new segment on winner cell (cell 0 of col).
            unsigned int winner_cell = col * cfg.cells_per_column;
            unsigned int slot = atomicAdd(&cell_seg_count[winner_cell], 1u);
            if (slot >= cfg.max_segments_per_cell) {
                slot = slot % cfg.max_segments_per_cell;
            }
            unsigned int seg_id = winner_cell * cfg.max_segments_per_cell + slot;
            seg_cell_id[seg_id] = winner_cell;
            seg_syn_count[seg_id] = 0;
            shared_seg_id = seg_id;
            shared_existing_syn_count = 0u;
            shared_is_new = 1u;
        }
        shared_grown = 0u;
        shared_start_offset = (b * 2654435761u + cfg.iter_seed) % cfg.bits_words;
    }
    __syncthreads();

    const unsigned int seg_id = shared_seg_id;
    const unsigned int seg_base = seg_id * cfg.synapses_per_segment;
    const unsigned int existing_syn = shared_existing_syn_count;
    const unsigned int is_new = shared_is_new;
    const unsigned int start = shared_start_offset;

    // PHASE 1: If reusing, reinforce existing synapses.
    if (!is_new) {
        for (unsigned int s = tid; s < existing_syn; s += 32u) {
            unsigned int presyn = syn_presyn[seg_base + s];
            unsigned int word = prev_active_bits[presyn >> 5];
            unsigned int bit = (word >> (presyn & 31u)) & 1u;
            int p = (int)syn_perm[seg_base + s];
            if (bit) {
                int np = p + cfg.perm_inc_i16;
                if (np > 32767) np = 32767;
                syn_perm[seg_base + s] = (short)np;
            } else {
                int np = p - cfg.perm_dec_i16;
                if (np < 0) np = 0;
                syn_perm[seg_base + s] = (short)np;
            }
        }
        __syncthreads();
    }

    // PHASE 2: Grow up to `max_new_synapses` (or room) synapses to prev_winners
    // that aren't already presynaptic to this segment.
    const unsigned int room = (cfg.synapses_per_segment > existing_syn)
        ? (cfg.synapses_per_segment - existing_syn) : 0u;
    const unsigned int max_grow = (cfg.max_new_synapses < room) ? cfg.max_new_synapses : room;

    for (unsigned int w_off = 0; w_off < cfg.bits_words; w_off += 32u) {
        if (shared_grown >= max_grow) break;
        unsigned int widx = (start + w_off + tid) % cfg.bits_words;
        unsigned int word = prev_winner_bits[widx];
        while (word != 0u) {
            if (shared_grown >= max_grow) break;
            unsigned int bit_pos = __ffs(word) - 1u;
            word &= ~(1u << bit_pos);
            unsigned int cell = widx * 32u + bit_pos;
            if (cell >= cfg.n_cells) continue;

            // Skip if already presynaptic (O(existing_syn) scan; usually small).
            bool exists = false;
            for (unsigned int s = 0; s < existing_syn; s++) {
                if (syn_presyn[seg_base + s] == cell) { exists = true; break; }
            }
            if (exists) continue;

            unsigned int slot = atomicAdd(&shared_grown, 1u);
            if (slot >= max_grow) break;
            unsigned int write_idx = existing_syn + slot;
            if (write_idx >= cfg.synapses_per_segment) break;
            syn_presyn[seg_base + write_idx] = cell;
            syn_perm[seg_base + write_idx] = (short)cfg.initial_perm_i16;
        }
    }
    __syncthreads();

    if (tid == 0) {
        unsigned int grown = shared_grown;
        if (grown > max_grow) grown = max_grow;
        unsigned int new_count = existing_syn + grown;
        if (new_count > cfg.synapses_per_segment) new_count = cfg.synapses_per_segment;
        seg_syn_count[seg_id] = new_count;
    }
}
overlay/htm_rust/src/gpu/kernels/tm_learn.cu
CHANGED
@@ -1,75 +1,75 @@
// TM learn (reinforce correctly predicted segments): cell-grouped launch.
//
// Grid: n_cells.
// For each cell in a predicted, SP-active column: iterate its segments.
// For each segment with num_active_connected >= activation_threshold,
// reinforce its synapses against prev_active_bits.

struct TmConfig {
    unsigned int activation_threshold;
    unsigned int learning_threshold;
    unsigned int cells_per_column;
    unsigned int synapses_per_segment;
    unsigned int n_segments;
    unsigned int n_cells;
    unsigned int max_segments_per_cell;
    unsigned int max_new_synapses;
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;
    unsigned int iter_seed;
    unsigned int n_cols;
    unsigned int bits_words;
};

extern "C" __global__
void tm_learn_reinforce(
    const unsigned int * __restrict__ seg_cell_id,
    const unsigned int * __restrict__ seg_syn_count,
    const unsigned int * __restrict__ syn_presyn,
    short * __restrict__ syn_perm,
    const unsigned int * __restrict__ seg_num_active_connected,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned char * __restrict__ col_predicted,
    const unsigned int * __restrict__ cell_seg_count,
    TmConfig cfg
) {
    const unsigned int cell = blockIdx.x;
    if (cell >= cfg.n_cells) return;
    const unsigned int col = cell / cfg.cells_per_column;
    if (sp_active_mask[col] == 0) return;
    if (col_predicted[col] == 0) return;

    const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
    if (n_segs_here == 0) return;

    const unsigned int tid = threadIdx.x;
    const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;

    for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
        const unsigned int seg = seg_base_id + local_seg;
        if (seg_num_active_connected[seg] < cfg.activation_threshold) continue;
        const unsigned int n_syn = seg_syn_count[seg];
        if (n_syn == 0) continue;
        const unsigned int syn_base = seg * cfg.synapses_per_segment;

        for (unsigned int s = tid; s < n_syn; s += 32u) {
            unsigned int presyn = syn_presyn[syn_base + s];
            unsigned int word = prev_active_bits[presyn >> 5];
            unsigned int bit = (word >> (presyn & 31u)) & 1u;
            int p = (int)syn_perm[syn_base + s];
            if (bit) {
                int np = p + cfg.perm_inc_i16;
                if (np > 32767) np = 32767;
                syn_perm[syn_base + s] = (short)np;
            } else {
                int np = p - cfg.perm_dec_i16;
                if (np < 0) np = 0;
                syn_perm[syn_base + s] = (short)np;
            }
        }
    }
}
overlay/htm_rust/src/gpu/kernels/tm_predict.cu
CHANGED
@@ -1,102 +1,102 @@
// TM predict kernel - cell-grouped launch.
//
// Grid:  n_cells blocks (one per cell).
// Block: 32 threads (one warp).
//
// Each block iterates the segments owned by its cell (count in cell_seg_count[cell]).
// For each live segment, counts active connected/potential synapses against
// prev_active_bits. Updates per-segment counters, cell_predictive bit, and
// col_predicted flag.

struct TmConfig {
    unsigned int activation_threshold;
    unsigned int learning_threshold;
    unsigned int cells_per_column;
    unsigned int synapses_per_segment;
    unsigned int n_segments;
    unsigned int n_cells;
    unsigned int max_segments_per_cell;
    unsigned int max_new_synapses;
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;
    unsigned int iter_seed;
    unsigned int n_cols;
    unsigned int bits_words;
};

extern "C" __global__
void tm_predict(
    const unsigned int * __restrict__ seg_cell_id,
    const unsigned int * __restrict__ seg_syn_count,
    const unsigned int * __restrict__ syn_presyn,
    const short * __restrict__ syn_perm,
    const unsigned int * __restrict__ cell_active_bits,
    unsigned int * __restrict__ cell_predictive_bits,
    unsigned char * __restrict__ col_predicted,
    unsigned int * __restrict__ seg_num_active_connected,
    unsigned int * __restrict__ seg_num_active_potential,
    unsigned int * __restrict__ col_best_match,
    const unsigned int * __restrict__ cell_seg_count,
    TmConfig cfg
) {
    const unsigned int cell = blockIdx.x;
    if (cell >= cfg.n_cells) return;

    const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
    if (n_segs_here == 0) return;

    const unsigned int tid = threadIdx.x;
    const unsigned int col = cell / cfg.cells_per_column;
    const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;

    for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
        const unsigned int seg = seg_base_id + local_seg;
        const unsigned int n_syn = seg_syn_count[seg];
        if (n_syn == 0) {
            if (tid == 0) {
                seg_num_active_connected[seg] = 0;
                seg_num_active_potential[seg] = 0;
            }
            continue;
        }
        const unsigned int syn_base = seg * cfg.synapses_per_segment;

        unsigned int local_conn = 0;
        unsigned int local_pot = 0;
        for (unsigned int s = tid; s < n_syn; s += 32u) {
            unsigned int presyn = syn_presyn[syn_base + s];
            unsigned int word = cell_active_bits[presyn >> 5];
            unsigned int bit = (word >> (presyn & 31u)) & 1u;
            if (bit) {
                local_pot += 1u;
                int p = (int)syn_perm[syn_base + s];
                if (p >= cfg.conn_thr_i16) {
                    local_conn += 1u;
                }
            }
        }
        for (int off = 16; off > 0; off >>= 1) {
            local_conn += __shfl_down_sync(0xffffffffu, local_conn, off);
            local_pot += __shfl_down_sync(0xffffffffu, local_pot, off);
        }

        if (tid == 0) {
            seg_num_active_connected[seg] = local_conn;
            seg_num_active_potential[seg] = local_pot;
            if (local_conn >= cfg.activation_threshold) {
                unsigned int word_idx = cell >> 5;
                unsigned int bit_mask = 1u << (cell & 31u);
                atomicOr(&cell_predictive_bits[word_idx], bit_mask);
                col_predicted[col] = 1;
            }
            if (local_pot >= cfg.learning_threshold) {
                unsigned int pot_c = local_pot > 2047u ? 2047u : local_pot;
                unsigned int key = (pot_c << 21) | (seg & 0x1FFFFFu);
                atomicMax(&col_best_match[col], key);
            }
        }
    }
}
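The col_best_match update packs its result into a single u32 so a plain atomicMax can pick the winner: the top 11 bits carry the clamped potential count, the low 21 bits the segment id. A small sketch of the same pack/unpack arithmetic in Python (illustrative helper names, no CUDA involved):

POT_BITS, SEG_BITS = 11, 21  # 11 + 21 = 32; potential clamped to 2047

def pack(pot: int, seg: int) -> int:
    """Same packing as tm_predict: higher potential wins atomicMax;
    ties resolve deterministically toward the larger segment id."""
    pot_c = min(pot, (1 << POT_BITS) - 1)  # clamp to 2047, as in the kernel
    return (pot_c << SEG_BITS) | (seg & ((1 << SEG_BITS) - 1))

def unpack(key: int) -> tuple[int, int]:
    return key >> SEG_BITS, key & ((1 << SEG_BITS) - 1)

keys = [pack(9, 4000), pack(12, 17), pack(12, 90)]
best = max(keys)    # host stand-in for atomicMax on col_best_match[col]
print(unpack(best)) # -> (12, 90): highest potential, then highest seg id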
overlay/htm_rust/src/gpu/kernels/tm_punish.cu
CHANGED
@@ -1,64 +1,64 @@
// TM punish - cell-grouped launch.

struct TmConfig {
    unsigned int activation_threshold;
    unsigned int learning_threshold;
    unsigned int cells_per_column;
    unsigned int synapses_per_segment;
    unsigned int n_segments;
    unsigned int n_cells;
    unsigned int max_segments_per_cell;
    unsigned int max_new_synapses;
    int conn_thr_i16;
    int perm_inc_i16;
    int perm_dec_i16;
    int predicted_seg_dec_i16;
    int initial_perm_i16;
    unsigned int iter_seed;
    unsigned int n_cols;
    unsigned int bits_words;
};

extern "C" __global__
void tm_punish(
    const unsigned int * __restrict__ seg_cell_id,
    const unsigned int * __restrict__ seg_syn_count,
    const unsigned int * __restrict__ syn_presyn,
    short * __restrict__ syn_perm,
    const unsigned int * __restrict__ seg_num_active_potential,
    const unsigned int * __restrict__ prev_active_bits,
    const unsigned char * __restrict__ sp_active_mask,
    const unsigned int * __restrict__ cell_seg_count,
    TmConfig cfg
) {
    const unsigned int cell = blockIdx.x;
    if (cell >= cfg.n_cells) return;
    const unsigned int col = cell / cfg.cells_per_column;
    if (sp_active_mask[col] != 0) return;  // skip: col became active

    const unsigned int n_segs_here = min(cell_seg_count[cell], cfg.max_segments_per_cell);
    if (n_segs_here == 0) return;

    const unsigned int tid = threadIdx.x;
    const unsigned int seg_base_id = cell * cfg.max_segments_per_cell;

    for (unsigned int local_seg = 0; local_seg < n_segs_here; local_seg++) {
        const unsigned int seg = seg_base_id + local_seg;
        if (seg_num_active_potential[seg] < cfg.learning_threshold) continue;
        const unsigned int n_syn = seg_syn_count[seg];
        if (n_syn == 0) continue;
        const unsigned int syn_base = seg * cfg.synapses_per_segment;

        for (unsigned int s = tid; s < n_syn; s += 32u) {
            unsigned int presyn = syn_presyn[syn_base + s];
            unsigned int word = prev_active_bits[presyn >> 5];
            unsigned int bit = (word >> (presyn & 31u)) & 1u;
            if (bit) {
                int p = (int)syn_perm[syn_base + s];
                int np = p - cfg.predicted_seg_dec_i16;
                if (np < 0) np = 0;
                syn_perm[syn_base + s] = (short)np;
            }
        }
    }
}
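Both tm_predict and tm_punish test presynaptic activity through the same u32 bit-packing: cell i lives in word i >> 5, bit i & 31. A Python sketch of that convention, useful when inspecting prev_active_bits buffers from the host; the helper names are illustrative:

import numpy as np

def pack_bits(active: np.ndarray) -> np.ndarray:
    """Pack a boolean cell-activity vector into u32 words (bit i of word
    i >> 5 holds cell i), matching the kernels' cell_active_bits layout."""
    n_words = (active.size + 31) // 32
    words = np.zeros(n_words, dtype=np.uint32)
    for i in np.flatnonzero(active):
        words[i >> 5] |= np.uint32(1 << int(i & 31))
    return words

def is_active(words: np.ndarray, presyn: int) -> bool:
    # Same expression as the CUDA: (word >> (presyn & 31)) & 1
    return bool((int(words[presyn >> 5]) >> (presyn & 31)) & 1)

act = np.zeros(100, dtype=bool)
act[[3, 42, 99]] = True
w = pack_bits(act)
assert is_active(w, 42) and not is_active(w, 41)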
overlay/htm_rust/src/gpu/kernels/tm_reset.cu
CHANGED
@@ -1,36 +1,36 @@
// TM reset-per-step kernel.

extern "C" __global__
void tm_reset_step(
    unsigned int * __restrict__ cell_active_bits,
    unsigned int * __restrict__ cell_winner_bits,
    unsigned int * __restrict__ cell_predictive_bits,
    unsigned int * __restrict__ prev_active_bits,
    unsigned int * __restrict__ prev_winner_bits,
    unsigned char * __restrict__ col_predicted,
    unsigned int * __restrict__ unpredicted_count,
    unsigned int * __restrict__ burst_cols_count,
    unsigned int * __restrict__ col_best_match,
    unsigned int bits_words,
    unsigned int n_cols
) {
    unsigned int tid_global = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid_global < bits_words) {
        prev_active_bits[tid_global] = cell_active_bits[tid_global];
        prev_winner_bits[tid_global] = cell_winner_bits[tid_global];
        cell_active_bits[tid_global] = 0u;
        cell_winner_bits[tid_global] = 0u;
        cell_predictive_bits[tid_global] = 0u;
    }

    if (tid_global < n_cols) {
        col_predicted[tid_global] = 0;
        col_best_match[tid_global] = 0u;
    }

    if (tid_global == 0) {
        unpredicted_count[0] = 0u;
        burst_cols_count[0] = 0u;
    }
}
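tm_reset_step is the per-step double-buffer rotation: the current active/winner words become the prev_* buffers, the current state words are zeroed, and the per-column scratch (col_predicted, col_best_match) plus the two counters are cleared. The same rotation in host pseudocode, a sketch with dict keys mirroring the kernel arguments:

import numpy as np

def tm_reset_step(state: dict) -> None:
    """Host mirror of the kernel: rotate active/winner bitmasks into the
    prev_* buffers, then clear all per-step scratch for the next timestep."""
    state["prev_active_bits"][:] = state["cell_active_bits"]
    state["prev_winner_bits"][:] = state["cell_winner_bits"]
    state["cell_active_bits"][:] = 0
    state["cell_winner_bits"][:] = 0
    state["cell_predictive_bits"][:] = 0
    state["col_predicted"][:] = 0
    state["col_best_match"][:] = 0
    state["unpredicted_count"][0] = 0
    state["burst_cols_count"][0] = 0

bits_words, n_cols = 128, 2048
state = {k: np.zeros(bits_words, dtype=np.uint32)
         for k in ("cell_active_bits", "cell_winner_bits",
                   "cell_predictive_bits", "prev_active_bits",
                   "prev_winner_bits")}
state.update(col_predicted=np.zeros(n_cols, dtype=np.uint8),
             col_best_match=np.zeros(n_cols, dtype=np.uint32),
             unpredicted_count=np.zeros(1, dtype=np.uint32),
             burst_cols_count=np.zeros(1, dtype=np.uint32))
tm_reset_step(state)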
overlay/htm_rust/src/gpu/mod.rs
CHANGED
@@ -1,549 +1,549 @@
//! GPU backend for HTM.
//!
//! Full-GPU pipeline (SP + TM). Per-step state lives entirely on device; the
//! batch API (`step_many_gpu`) uploads T steps of input once, runs T iterations
//! of the full HTM pipeline on GPU, and copies (T, n_cols) u8 + (T,) f32 back
//! to the host in one shot.
//!
//! TM parity with the CPU reference is approximate:
//!   - Segment growth: winner = cell 0 of bursting column (CPU picks
//!     least-used-cell with RNG tiebreak). This is a pragmatic simplification
//!     for GPU atomicity; learning dynamics are preserved.
//!   - Permanences stored as i16 (scaled 0..32767). Rounding differs from
//!     f32 by <= 1 ULP of the scale factor (≈ 3e-5), inside any meaningful
//!     HTM learning quantum.

#![cfg(feature = "gpu")]

pub mod sp_gpu;
pub mod tm_gpu;
pub mod fused;

#[cfg(test)]
mod tests;

use std::mem::ManuallyDrop;

use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods};

use crate::region::HTMRegionCore;
use crate::sp::SpatialPoolerConfig;
use sp_gpu::SpatialPoolerGpu;
use tm_gpu::TemporalMemoryGpu;
use fused::FusedState;

/// Extract (device_ptr, shape, typestr) from a `__cuda_array_interface__` dict.
/// Returns Err if the dict is malformed. Used by `step_many_cuda` to wrap
/// torch-owned CUDA allocations zero-copy.
fn cai_parse(cai: &Bound<'_, PyDict>) -> PyResult<(u64, Vec<usize>, String)> {
    // `data` is a (ptr: int, readonly: bool) tuple.
    let data_obj = cai.get_item("data")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'data'"))?;
    let data_tup: Bound<'_, PyTuple> = data_obj.downcast_into()
        .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'data' must be a tuple"))?;
    let ptr: u64 = data_tup.get_item(0)?.extract()?;

    // `shape` is a tuple of ints.
    let shape_obj = cai.get_item("shape")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'shape'"))?;
    let shape_tup: Bound<'_, PyTuple> = shape_obj.downcast_into()
        .map_err(|_| pyo3::exceptions::PyValueError::new_err("CAI 'shape' must be a tuple"))?;
    let shape: Vec<usize> = (0..shape_tup.len())
        .map(|i| shape_tup.get_item(i).and_then(|v| v.extract::<usize>()))
        .collect::<PyResult<Vec<_>>>()?;

    // `typestr` (e.g. "|u1", "<f4").
    let typestr_obj = cai.get_item("typestr")?
        .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("CAI missing 'typestr'"))?;
    let typestr: String = typestr_obj.extract()?;

    // Reject non-contiguous tensors; we don't handle strides.
    if let Some(strides) = cai.get_item("strides")? {
        if !strides.is_none() {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "CAI 'strides' must be None (tensor must be contiguous)",
            ));
        }
    }

    Ok((ptr, shape, typestr))
}

/// Python-exposed GPU HTM region. Drop-in replacement for `HTMRegion`.
#[pyclass(module = "htm_rust")]
pub struct HTMRegionGpu {
    pub(super) sp_gpu: SpatialPoolerGpu,
    pub(super) tm_gpu: TemporalMemoryGpu,
    pub(super) fused_state: FusedState,
    pub(super) n_columns: usize,
    pub(super) input_bits: usize,
    pub(super) cells_per_column: usize,
}

#[pymethods]
impl HTMRegionGpu {
    #[new]
    #[pyo3(signature = (input_bits, n_columns, cells_per_column, seed=42))]
    fn new(
        input_bits: usize,
        n_columns: usize,
        cells_per_column: usize,
        seed: u64,
    ) -> PyResult<Self> {
        if input_bits == 0 || n_columns == 0 || cells_per_column == 0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "input_bits, n_columns, cells_per_column must all be > 0",
            ));
        }
        // CPU reference for deterministic SP init.
        let cpu_ref = HTMRegionCore::new(input_bits, n_columns, cells_per_column, seed);
        let sp_cfg: &SpatialPoolerConfig = &cpu_ref.sp.cfg;
        let sp_gpu = SpatialPoolerGpu::from_cpu(&cpu_ref.sp).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!(
                "GPU SP init failed: {e:?}. Config: input_bits={}, n_columns={}",
                sp_cfg.input_bits, sp_cfg.n_columns,
            ))
        })?;
        let dev = sp_gpu.dev_ref().clone();
        let tm_gpu = TemporalMemoryGpu::new(dev.clone(), n_columns, cells_per_column).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!(
                "GPU TM init failed: {e:?}",
            ))
        })?;
        let initial_threshold = sp_gpu.initial_threshold_estimate();
        let fused_state = FusedState::new(dev, n_columns, cells_per_column, initial_threshold)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!(
                "GPU fused state init failed: {e:?}",
            )))?;
        Ok(Self {
            sp_gpu,
            tm_gpu,
            fused_state,
            n_columns,
            input_bits,
            cells_per_column,
        })
    }

    #[getter] fn input_bits(&self) -> usize { self.input_bits }
    #[getter] fn n_columns(&self) -> usize { self.n_columns }
    #[getter] fn cells_per_column(&self) -> usize { self.cells_per_column }

    /// Process T timesteps in one call on GPU. Per-step state (SP + TM) stays
    /// on device; only the final (T, n_cols) mask and (T,) anomaly are copied
    /// to the host at the end.
    #[pyo3(signature = (inputs, learn=true))]
    fn step_many_gpu<'py>(
        &mut self,
        py: Python<'py>,
        inputs: PyReadonlyArray2<'py, bool>,
        learn: bool,
    ) -> PyResult<(Bound<'py, PyArray2<f32>>, Bound<'py, PyArray1<f32>>)> {
        let shape = inputs.shape();
        if shape.len() != 2 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "inputs must be 2-D (T, input_bits)",
            ));
        }
        let t = shape[0];
        let bits = shape[1];
        if bits != self.input_bits {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "inputs last dim {bits} != expected input_bits {}",
                self.input_bits,
            )));
        }
        let slice = inputs.as_slice()?;
        let n_cols = self.n_columns;
        let input_vec: Vec<bool> = slice.to_vec();

        let result = py.allow_threads(|| -> Result<(Vec<u8>, Vec<f32>), String> {
            // 1. Upload T*input_bits bytes (32 MB at T=2048, bits=16384).
            let sdr_u8_all: Vec<u8> = input_vec.iter().map(|&b| b as u8).collect();
            let inputs_dev = self
                .sp_gpu
                .dev_ref()
                .htod_sync_copy(&sdr_u8_all)
                .map_err(|e| format!("H2D inputs: {e:?}"))?;

            // 2. Allocate output buffers on device.
            let mut cols_dev = self.sp_gpu.dev_ref()
                .alloc_zeros::<u8>(t * n_cols)
                .map_err(|e| format!("alloc cols: {e:?}"))?;
            let mut anom_dev = self.sp_gpu.dev_ref()
                .alloc_zeros::<f32>(t)
                .map_err(|e| format!("alloc anom: {e:?}"))?;

            // 3. Run T steps of SP + TM on GPU with NO per-step host sync.
            self.sp_gpu.step_batch_with_tm(
                &inputs_dev,
                t,
                self.input_bits,
                learn,
                &mut cols_dev,
                &mut anom_dev,
                &mut self.tm_gpu,
            ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;

            // 4. ONE D2H for the whole run (T * n_cols bytes + T floats).
            let cols_host: Vec<u8> = self.sp_gpu.dev_ref()
                .dtoh_sync_copy(&cols_dev)
                .map_err(|e| format!("D2H cols: {e:?}"))?;
            let anom_host: Vec<f32> = self.sp_gpu.dev_ref()
                .dtoh_sync_copy(&anom_dev)
                .map_err(|e| format!("D2H anom: {e:?}"))?;

            Ok((cols_host, anom_host))
        });

        let (cols_u8, anom) = result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;

        let cols_f32: Vec<f32> = cols_u8.iter().map(|&b| b as f32).collect();
        let cols_arr = numpy::PyArray1::from_vec_bound(py, cols_f32)
            .reshape([t, n_cols])
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
        let anom_arr = numpy::PyArray1::from_vec_bound(py, anom);
        Ok((cols_arr, anom_arr))
    }

    /// Zero-copy CUDA path: accept torch tensors via __cuda_array_interface__,
    /// write outputs directly into caller-allocated torch tensors. Skips the
    /// host round-trip that `step_many_gpu` pays on every call (sdr.cpu() +
    /// two D2H copies at the end). This is the hot path for `train.py`.
    ///
    /// Contract:
    ///   sdr_cai.shape  == (T, input_bits), dtype u8 (0/1 mask)
    ///   cols_cai.shape == (T, n_columns),  dtype u8 (written)
    ///   anom_cai.shape == (T,),            dtype f32 (written)
    /// All three tensors must live on the SAME CUDA device as this region.
    ///
    /// The torch tensors still own their memory; this method only wraps
    /// them as borrowed CudaSlice views (via ManuallyDrop) so cudarc's Drop
    /// impl can't free memory owned by PyTorch's allocator.
    #[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
    fn step_many_cuda(
        &mut self,
        py: Python<'_>,
        sdr_cai: &Bound<'_, PyDict>,
        cols_cai: &Bound<'_, PyDict>,
        anom_cai: &Bound<'_, PyDict>,
        learn: bool,
    ) -> PyResult<()> {
        let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
        let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
        let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;

        // typestr sanity. numpy u1 is what torch.uint8 exports.
        if sdr_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
            )));
        }
        if cols_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
            )));
        }
        if anom_type != "<f4" && anom_type != "=f4" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai typestr must be '<f4' (float32), got {anom_type}",
            )));
        }

        // Shape validation.
        if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai shape {sdr_shape:?} != (T, {})",
                self.input_bits,
            )));
        }
        let t = sdr_shape[0];
        if cols_shape != [t, self.n_columns] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai shape {cols_shape:?} != ({t}, {})",
                self.n_columns,
            )));
        }
        if anom_shape != [t] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai shape {anom_shape:?} != ({t},)",
            )));
        }

        let dev = self.sp_gpu.dev_ref().clone();
        let n_cols = self.n_columns;
        let input_bits = self.input_bits;

        let result = py.allow_threads(|| -> Result<(), String> {
            // SAFETY:
            // - ptrs came from torch CUDA tensors validated non-null by the
            //   __cuda_array_interface__ contract.
            // - lens computed from validated shapes.
            // - We wrap the returned CudaSlice in ManuallyDrop so cudarc's
            //   Drop (which calls cuMemFree) never runs against torch memory.
            //   The underlying allocation is owned+freed by torch.
            // - The slices are used only for the duration of this call;
            //   torch guarantees the backing tensors are live across it
            //   (Python holds refs on the wrapping tensors).
            let inputs_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
            });
            let mut cols_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
            });
            let mut anom_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<f32>(anom_ptr, t)
            });

            self.sp_gpu.step_batch_with_tm(
                &inputs_dev,
                t,
                input_bits,
                learn,
                &mut cols_dev,
                &mut anom_dev,
                &mut self.tm_gpu,
            ).map_err(|e| format!("step_batch_with_tm: {e:?}"))?;

            // No dev.synchronize() here: PyTorch's default stream and
            // cudarc's stream differ, but the caller must explicitly sync
            // via the `device_sync()` method (or PyTorch auto-syncs when
            // the output tensor is next consumed). Removing the per-launch
            // barrier lets subsequent GPU work (mamba3 fwd, etc.) overlap
            // in time.
            Ok(())
        });

        result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
        Ok(())
    }

    /// Clear TM state on the GPU.
    fn reset(&mut self) -> PyResult<()> {
        self.tm_gpu.reset().map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("GPU TM reset: {e:?}"))
        })?;
        self.fused_state.reset().map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("GPU fused reset: {e:?}"))
        })
    }

    /// FUSED MEGAKERNEL PATH: single CUDA launch for the entire T-step
    /// forward (SP + TM all in one). Accepts torch CUDA tensors via
    /// `__cuda_array_interface__` (zero-copy). Writes active-column mask +
    /// anomaly directly into caller-allocated torch tensors.
    ///
    /// Semantics diverge from `step_many_cuda` in one important way: column
    /// activation uses per-column threshold inhibition instead of global
    /// top-K. The threshold is EMA-adapted per column toward the sparsity
    /// target. See `docs/GPU_HTM.md` §Fused Kernel.
    #[pyo3(signature = (sdr_cai, cols_cai, anom_cai, learn=true))]
    fn step_many_fused_cuda(
        &mut self,
        py: Python<'_>,
        sdr_cai: &Bound<'_, PyDict>,
        cols_cai: &Bound<'_, PyDict>,
        anom_cai: &Bound<'_, PyDict>,
        learn: bool,
    ) -> PyResult<()> {
        let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(sdr_cai)?;
        let (cols_ptr, cols_shape, cols_type) = cai_parse(cols_cai)?;
        let (anom_ptr, anom_shape, anom_type) = cai_parse(anom_cai)?;

        if sdr_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai typestr must be '|u1' (uint8), got {sdr_type}",
            )));
        }
        if cols_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai typestr must be '|u1' (uint8), got {cols_type}",
            )));
        }
        if anom_type != "<f4" && anom_type != "=f4" {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai typestr must be '<f4' (float32), got {anom_type}",
            )));
        }

        if sdr_shape.len() != 2 || sdr_shape[1] != self.input_bits {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr_cai shape {sdr_shape:?} != (T, {})",
                self.input_bits,
            )));
        }
        let t = sdr_shape[0];
        if cols_shape != [t, self.n_columns] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols_cai shape {cols_shape:?} != ({t}, {})",
                self.n_columns,
            )));
        }
        if anom_shape != [t] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom_cai shape {anom_shape:?} != ({t},)",
            )));
        }

        let dev = self.sp_gpu.dev_ref().clone();
        let n_cols = self.n_columns;
        let input_bits = self.input_bits;

        let result = py.allow_threads(|| -> Result<(), String> {
            let inputs_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(sdr_ptr, t * input_bits)
            });
            let mut cols_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<u8>(cols_ptr, t * n_cols)
            });
            let mut anom_dev = ManuallyDrop::new(unsafe {
                dev.upgrade_device_ptr::<f32>(anom_ptr, t)
            });

            fused::launch_fused(
                &mut self.sp_gpu,
                &mut self.tm_gpu,
                &mut self.fused_state,
                &inputs_dev,
                &mut cols_dev,
                &mut anom_dev,
                t,
                input_bits,
                learn,
            ).map_err(|e| format!("launch_fused: {e:?}"))?;

            // No dev.synchronize() here: caller must explicitly sync via the
            // `device_sync()` method (or PyTorch auto-syncs when the output
            // tensor is next consumed). Removing the per-launch barrier lets
            // subsequent GPU work (mamba3 fwd, etc.) overlap in time.
            Ok(())
        });

        result.map_err(pyo3::exceptions::PyRuntimeError::new_err)?;
        Ok(())
    }

    /// Explicit device synchronization: the caller must invoke this after
    /// all batched `step_many_*_cuda` calls complete, before reading the
    /// output tensors from a different CUDA stream. Equivalent to the old
    /// per-call `dev.synchronize()` that was removed for overlap.
    fn device_sync(&self) -> PyResult<()> {
        let dev = self.sp_gpu.dev_ref();
        dev.synchronize()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("sync: {e:?}")))?;
        Ok(())
    }
}

/// Batch B regions into ONE cooperative kernel launch. Breaks through the
/// CUDA cooperative-kernel device-level serialization: a single cooperative
/// launch with grid.y=B processes all regions concurrently, a ~B× speedup
/// over B sequential launches.
///
/// All regions must have the same config (input_bits, n_columns,
/// cells_per_column). Each region keeps its independent GPU state.
/// Does NOT sync; caller must invoke `device_sync()` on any region
/// afterwards (or rely on a downstream torch op to auto-sync).
#[pyfunction]
#[pyo3(signature = (regions, sdr_cais, cols_cais, anom_cais, learn=true))]
fn step_batch_fused_cuda(
    py: Python<'_>,
    regions: Vec<Py<HTMRegionGpu>>,
    sdr_cais: Vec<Bound<'_, PyDict>>,
    cols_cais: Vec<Bound<'_, PyDict>>,
    anom_cais: Vec<Bound<'_, PyDict>>,
    learn: bool,
) -> PyResult<()> {
    let b = regions.len();
    if b == 0 {
        return Err(pyo3::exceptions::PyValueError::new_err("regions is empty"));
    }
    if sdr_cais.len() != b || cols_cais.len() != b || anom_cais.len() != b {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "sdr_cais / cols_cais / anom_cais length must match regions",
        ));
    }

    // Parse all CAI dicts; collect device pointers. Validate shapes/dtypes.
    let mut sdr_ptrs = Vec::with_capacity(b);
    let mut cols_ptrs = Vec::with_capacity(b);
    let mut anom_ptrs = Vec::with_capacity(b);
    let (input_bits, n_columns, t) = {
        let r0 = regions[0].bind(py).borrow();
        (r0.input_bits, r0.n_columns, {
            let (_p, sh, _ty) = cai_parse(&sdr_cais[0])?;
            if sh.len() != 2 {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    format!("sdr_cai must be 2-D (T, input_bits), got {sh:?}"),
                ));
            }
            sh[0]
        })
    };

    for i in 0..b {
        let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_cais[i])?;
        let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_cais[i])?;
        let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_cais[i])?;
        if sdr_type != "|u1" || cols_type != "|u1" {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "sdr/cols typestr must be '|u1' (uint8)",
            ));
        }
        if anom_type != "<f4" && anom_type != "=f4" {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "anom typestr must be '<f4' (float32)",
            ));
        }
        if sdr_shape != [t, input_bits] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "sdr[{i}] shape {sdr_shape:?} != ({t}, {input_bits})"
            )));
        }
        if cols_shape != [t, n_columns] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "cols[{i}] shape {cols_shape:?} != ({t}, {n_columns})"
            )));
        }
        if anom_shape != [t] {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "anom[{i}] shape {anom_shape:?} != ({t},)"
            )));
        }
        sdr_ptrs.push(sdr_ptr);
        cols_ptrs.push(cols_ptr);
        anom_ptrs.push(anom_ptr);
    }

    // Exclusively borrow each region. PyRefMut guarantees uniqueness.
    let mut region_refs: Vec<pyo3::PyRefMut<HTMRegionGpu>> =
        regions.iter().map(|p| p.bind(py).borrow_mut()).collect();
    // Collect raw mutable pointers; each PyRefMut exclusively borrows its
    // region for the lifetime of this call, so pointers stay valid and
    // unique. launch_fused_batched_raw only dereferences one region at a
    // time, never constructing an aliased slice.
    let raw_ptrs: Vec<*mut HTMRegionGpu> = region_refs
        .iter_mut()
        .map(|r| &mut **r as *mut HTMRegionGpu)
        .collect();

    // No allow_threads: raw pointers aren't Send. The launch is GPU-queued
    // and sync'd downstream; holding the GIL for the duration is cheap.
    fused::launch_fused_batched_raw(
        &raw_ptrs, &sdr_ptrs, &cols_ptrs, &anom_ptrs,
        t, input_bits, learn,
    )
    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("launch_fused_batched: {e:?}")))?;
    Ok(())
}

pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<HTMRegionGpu>()?;
    m.add_function(pyo3::wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
    Ok(())
}
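Every zero-copy entry point funnels through cai_parse, which reads the standard `__cuda_array_interface__` dict that torch attaches to CUDA tensors. A quick Python look at the three fields the Rust side consumes; the tensor sizes are the HYDRA training sizes used elsewhere in this repo, and a CUDA device is assumed:

import torch

sdr = torch.zeros(2048, 16384, dtype=torch.uint8, device="cuda")
cai = sdr.__cuda_array_interface__
# cai_parse reads exactly these fields and rejects strided tensors:
print(cai["data"])     # (device_ptr, readonly) tuple; ptr is a plain int
print(cai["shape"])    # (2048, 16384)
print(cai["typestr"])  # '|u1' for uint8; float32 tensors export '<f4'
# torch exports strides=None for contiguous tensors, which cai_parse demands.
assert cai.get("strides") is None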
|
| 479 |
+
let (_p, sh, _ty) = cai_parse(&sdr_cais[0])?;
|
| 480 |
+
if sh.len() != 2 {
|
| 481 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 482 |
+
format!("sdr_cai must be 2-D (T, input_bits), got {sh:?}"),
|
| 483 |
+
));
|
| 484 |
+
}
|
| 485 |
+
sh[0]
|
| 486 |
+
})
|
| 487 |
+
};
|
| 488 |
+
|
| 489 |
+
for i in 0..b {
|
| 490 |
+
let (sdr_ptr, sdr_shape, sdr_type) = cai_parse(&sdr_cais[i])?;
|
| 491 |
+
let (cols_ptr, cols_shape, cols_type) = cai_parse(&cols_cais[i])?;
|
| 492 |
+
let (anom_ptr, anom_shape, anom_type) = cai_parse(&anom_cais[i])?;
|
| 493 |
+
if sdr_type != "|u1" || cols_type != "|u1" {
|
| 494 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 495 |
+
"sdr/cols typestr must be '|u1' (uint8)",
|
| 496 |
+
));
|
| 497 |
+
}
|
| 498 |
+
if anom_type != "<f4" && anom_type != "=f4" {
|
| 499 |
+
return Err(pyo3::exceptions::PyValueError::new_err(
|
| 500 |
+
"anom typestr must be '<f4' (float32)",
|
| 501 |
+
));
|
| 502 |
+
}
|
| 503 |
+
if sdr_shape != [t, input_bits] {
|
| 504 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 505 |
+
"sdr[{i}] shape {sdr_shape:?} != ({t}, {input_bits})"
|
| 506 |
+
)));
|
| 507 |
+
}
|
| 508 |
+
if cols_shape != [t, n_columns] {
|
| 509 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 510 |
+
"cols[{i}] shape {cols_shape:?} != ({t}, {n_columns})"
|
| 511 |
+
)));
|
| 512 |
+
}
|
| 513 |
+
if anom_shape != [t] {
|
| 514 |
+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
|
| 515 |
+
"anom[{i}] shape {anom_shape:?} != ({t},)"
|
| 516 |
+
)));
|
| 517 |
+
}
|
| 518 |
+
sdr_ptrs.push(sdr_ptr);
|
| 519 |
+
cols_ptrs.push(cols_ptr);
|
| 520 |
+
anom_ptrs.push(anom_ptr);
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
// Exclusively borrow each region. PyRefMut guarantees uniqueness.
|
| 524 |
+
let mut region_refs: Vec<pyo3::PyRefMut<HTMRegionGpu>> =
|
| 525 |
+
regions.iter().map(|p| p.bind(py).borrow_mut()).collect();
|
| 526 |
+
// Collect raw mutable pointers β each PyRefMut exclusively borrows its
|
| 527 |
+
// region for the lifetime of this call, so pointers stay valid and
|
| 528 |
+
// unique. launch_fused_batched_raw only dereferences one region at a
|
| 529 |
+
// time, not constructing an aliased slice.
|
| 530 |
+
let raw_ptrs: Vec<*mut HTMRegionGpu> = region_refs
|
| 531 |
+
.iter_mut()
|
| 532 |
+
.map(|r| &mut **r as *mut HTMRegionGpu)
|
| 533 |
+
.collect();
|
| 534 |
+
|
| 535 |
+
// No allow_threads: raw pointers aren't Send. The launch is GPU-queued
|
| 536 |
+
// and sync'd downstream; holding the GIL for the duration is cheap.
|
| 537 |
+
fused::launch_fused_batched_raw(
|
| 538 |
+
&raw_ptrs, &sdr_ptrs, &cols_ptrs, &anom_ptrs,
|
| 539 |
+
t, input_bits, learn,
|
| 540 |
+
)
|
| 541 |
+
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("launch_fused_batched: {e:?}")))?;
|
| 542 |
+
Ok(())
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
| 546 |
+
m.add_class::<HTMRegionGpu>()?;
|
| 547 |
+
m.add_function(pyo3::wrap_pyfunction!(step_batch_fused_cuda, m)?)?;
|
| 548 |
+
Ok(())
|
| 549 |
+
}
|
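For reference, a minimal sketch of driving the fused path above from Python. Only `step_many_fused_cuda`, `step_batch_fused_cuda`, `device_sync`, and their dtype/shape contracts are taken from the bindings in this diff; the module import name `htm_rust` and the `HTMRegionGpu` constructor signature are assumptions for illustration.

```python
# Sketch only: the module name and constructor args are assumed,
# not taken from this diff.
import torch
import htm_rust

T, INPUT_BITS, N_COLS = 2048, 16384, 2048
region = htm_rust.HTMRegionGpu(INPUT_BITS, N_COLS)  # assumed constructor

# Caller-allocated CUDA tensors; the binding validates these exact
# dtypes/shapes: u8 (T, input_bits), u8 (T, n_cols), f32 (T,).
sdr = (torch.rand(T, INPUT_BITS, device="cuda") < 0.02).to(torch.uint8)
cols = torch.empty(T, N_COLS, dtype=torch.uint8, device="cuda")
anom = torch.empty(T, dtype=torch.float32, device="cuda")

# Zero-copy hand-off: pass the __cuda_array_interface__ dicts, not tensors.
region.step_many_fused_cuda(
    sdr.__cuda_array_interface__,
    cols.__cuda_array_interface__,
    anom.__cuda_array_interface__,
    learn=True,
)

# The launch does not synchronize; sync explicitly before reading the
# outputs from another stream (per the binding's comment, a downstream
# torch op that consumes cols/anom would also sync).
region.device_sync()
print(float(anom.mean()))
```

For B regions with identical configs, the module-level `step_batch_fused_cuda(regions, sdr_cais, cols_cais, anom_cais, learn=True)` queues one cooperative launch for all of them; call `device_sync()` on any one region afterwards.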
overlay/htm_rust/src/gpu/sp_gpu.rs
CHANGED
@@ -1,796 +1,796 @@
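For orientation before the file body: the winner-count, duty-cycle, and boost arithmetic this file implements reduces to a few lines. A hedged NumPy sketch follows; `k`, `alpha`, and the boost formula are taken verbatim from the code below, while the EMA form of the duty update is an assumption (that kernel's source is not shown in this diff).

```python
# Sketch of the SP math in sp_gpu.rs. The duty-update EMA form is an
# assumption; k and the boost formula match the host-side code below.
import numpy as np

n, sparsity, boost_strength = 2048, 0.02, 1.0
duty_period = 1000.0

k = max(1, round(sparsity * n))        # winners per step (top-K)
alpha = 1.0 / max(duty_period, 1.0)    # EMA rate, as in step_batch

boosted = np.random.rand(n).astype(np.float32)  # stand-in for boost*overlap
active = np.zeros(n, dtype=np.uint8)
active[np.argpartition(-boosted, k - 1)[:k]] = 1  # what sp_topk_select emits

active_duty = np.zeros(n, dtype=np.float32)
active_duty = (1.0 - alpha) * active_duty + alpha * active  # assumed EMA

mean = active_duty.mean()              # mean over POST-update duty values
boost = np.exp(-boost_strength * (active_duty - mean))
```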
+//! GPU implementation of the Spatial Pooler.
+//!
+//! One `SpatialPoolerGpu` owns a set of persistent device buffers + five PTX
+//! kernels. `compute(input, learn)` performs one SP step and returns the
+//! sorted active-column indices (host `Vec<u32>`) – this is what the CPU
+//! TemporalMemory consumes.
+//!
+//! Persistent state on device (per region):
+//!   syn_bit     : u32 [n_columns × S]   (constant after init)
+//!   syn_perm    : f32 [n_columns × S]   (updated by sp_learn)
+//!   boost       : f32 [n_columns]
+//!   active_duty : f32 [n_columns]
+//!   overlap_duty: f32 [n_columns]
+//!
+//! Per-step transient state:
+//!   inp_dev     : u8  [input_bits]   (H2D copy each step)
+//!   raw         : u32 [n_columns]
+//!   boosted     : f32 [n_columns]
+//!   active_mask : u8  [n_columns]   (topk output, D2H at the end)
+
+use std::sync::Arc;
+
+use cudarc::driver::{CudaDevice, CudaSlice, DeviceSlice, DriverError, LaunchAsync, LaunchConfig};
+use cudarc::nvrtc::Ptx;
+
+use crate::sp::SpatialPooler;
+
+// Embed PTX at compile time. HTM_GPU_PTX_DIR is set by build.rs.
+const PTX_SP_OVERLAP: &str =
+    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_overlap.ptx"));
+const PTX_SP_TOPK: &str =
+    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_topk.ptx"));
+const PTX_SP_LEARN: &str =
+    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_learn.ptx"));
+const PTX_SP_DUTY: &str =
+    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_duty.ptx"));
+const PTX_SP_BOOST_FUSED: &str =
+    include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/sp_boost_fused.ptx"));
+
+pub struct SpatialPoolerGpu {
+    dev: Arc<CudaDevice>,
+
+    // Config mirror (we don't touch the CPU SpatialPooler after init).
+    input_bits: usize,
+    n_columns: usize,
+    synapses_per_col: usize,
+    conn_thr: f32,
+    inc: f32,
+    dec: f32,
+    sparsity: f32,
+    duty_period: f32,
+    boost_strength: f32,
+
+    // Persistent device state.
+    syn_bit: CudaSlice<u32>,
+    syn_perm: CudaSlice<f32>,
+    boost: CudaSlice<f32>,
+    active_duty: CudaSlice<f32>,
+    overlap_duty: CudaSlice<f32>,
+
+    // Transient scratch (reused each step).
+    inp_dev: CudaSlice<u8>,
+    raw: CudaSlice<u32>,
+    boosted: CudaSlice<f32>,
+    active_mask: CudaSlice<u8>,
+
+    // Reusable host buffer for D2H of active_mask.
+    host_mask: Vec<u8>,
+
+    /// Strict bit-parity with the CPU reference. Enabled for tests.
+    /// Forces host-side boost/exp computation and the overlap-duty bump check
+    /// every step. Default false for max throughput.
+    strict_parity: bool,
+}
+
+impl SpatialPoolerGpu {
+    /// Copy CPU SpatialPooler state onto the device. This preserves the
+    /// exact seeded proximal synapse layout + initial permanences, so the
+    /// GPU SP is a bit-identical parallel implementation of the CPU SP.
+    pub fn from_cpu(cpu: &SpatialPooler) -> Result<Self, DriverError> {
+        let dev = CudaDevice::new(0)?;
+        let cfg = &cpu.cfg;
+        let n = cfg.n_columns;
+        let s = cfg.potential_synapses;
+
+        // Flatten proximal dendrites into column-major arrays.
+        let mut syn_bit_h: Vec<u32> = Vec::with_capacity(n * s);
+        let mut syn_perm_h: Vec<f32> = Vec::with_capacity(n * s);
+        for col in &cpu.columns {
+            debug_assert_eq!(col.inputs.len(), s);
+            debug_assert_eq!(col.perms.len(), s);
+            syn_bit_h.extend_from_slice(&col.inputs);
+            syn_perm_h.extend_from_slice(&col.perms);
+        }
+
+        let syn_bit = dev.htod_sync_copy(&syn_bit_h)?;
+        let syn_perm = dev.htod_sync_copy(&syn_perm_h)?;
+        let boost = dev.htod_sync_copy(&cpu.boost)?;
+        let active_duty = dev.htod_sync_copy(&cpu.active_duty_cycle)?;
+        let overlap_duty = dev.htod_sync_copy(&cpu.overlap_duty_cycle)?;
+
+        let inp_dev: CudaSlice<u8> = dev.alloc_zeros(cfg.input_bits)?;
+        let raw: CudaSlice<u32> = dev.alloc_zeros(n)?;
+        let boosted: CudaSlice<f32> = dev.alloc_zeros(n)?;
+        let active_mask: CudaSlice<u8> = dev.alloc_zeros(n)?;
+
+        // Load PTX modules. Each .ptx is a module containing one `extern "C"`
+        // function. CudaDevice::load_ptx stores modules under the given name
+        // globally on the device (cudarc keys on the (module, func) pair), so
+        // we use a deterministic naming scheme to keep multiple SP instances
+        // from colliding.
+        let modules = [
+            ("htm_sp_overlap", PTX_SP_OVERLAP, "sp_overlap"),
+            ("htm_sp_topk", PTX_SP_TOPK, "sp_topk_select"),
+            ("htm_sp_learn", PTX_SP_LEARN, "sp_learn"),
+            ("htm_sp_duty", PTX_SP_DUTY, "sp_duty_update"),
+            ("htm_sp_boost_fused", PTX_SP_BOOST_FUSED, "sp_boost_from_duty"),
+        ];
+        for (modname, ptx, fnname) in modules {
+            // load_ptx is NOT idempotent – calling it twice errors. For
+            // multi-region support we check-then-load.
+            if dev.get_func(modname, fnname).is_none() {
+                dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
+            }
+        }
+
+        Ok(Self {
+            dev,
+            input_bits: cfg.input_bits,
+            n_columns: n,
+            synapses_per_col: s,
+            conn_thr: cfg.connected_threshold,
+            inc: cfg.syn_perm_active_inc,
+            dec: cfg.syn_perm_inactive_dec,
+            sparsity: cfg.sparsity,
+            duty_period: cfg.duty_cycle_period,
+            boost_strength: cfg.boost_strength,
+            syn_bit,
+            syn_perm,
+            boost,
+            active_duty,
+            overlap_duty,
+            inp_dev,
+            raw,
+            boosted,
+            active_mask,
+            host_mask: vec![0u8; n],
+            strict_parity: false,
+        })
+    }
+
+    /// Enable strict bit-parity mode. Parity tests use this.
+    pub fn set_strict_parity(&mut self, strict: bool) {
+        self.strict_parity = strict;
+    }
+
+    /// Access to the underlying CudaDevice for host-side orchestration.
+    pub fn dev_ref(&self) -> &Arc<CudaDevice> {
+        &self.dev
+    }
+
+    // --- Fused-path accessors (immutable state reads + pointer grabs). ---
+    pub fn n_columns_accessor(&self) -> usize { self.n_columns }
+    #[allow(dead_code)]
+    pub fn input_bits_accessor(&self) -> usize { self.input_bits }
+    pub fn synapses_per_col_accessor(&self) -> usize { self.synapses_per_col }
+    pub fn conn_thr_accessor(&self) -> f32 { self.conn_thr }
+    pub fn inc_accessor(&self) -> f32 { self.inc }
+    pub fn dec_accessor(&self) -> f32 { self.dec }
+    pub fn sparsity_accessor(&self) -> f32 { self.sparsity }
+    pub fn duty_period_accessor(&self) -> f32 { self.duty_period }
+    #[allow(dead_code)]
+    pub fn boost_strength_accessor(&self) -> f32 { self.boost_strength }
+
+    pub fn syn_bit_accessor(&self) -> &CudaSlice<u32> { &self.syn_bit }
+    pub fn syn_perm_accessor(&self) -> &CudaSlice<f32> { &self.syn_perm }
+    pub fn boost_accessor(&self) -> &CudaSlice<f32> { &self.boost }
+    pub fn active_duty_accessor(&self) -> &CudaSlice<f32> { &self.active_duty }
+
+    /// Compute the 95th-percentile-like initial threshold from raw overlaps
+    /// after a short warmup pass. Used to seed `inhibition_threshold` so
+    /// that the activation rate starts near the sparsity target.
+    /// Placeholder (returns a conservative constant); the real warmup pass
+    /// happens on the Rust orchestrator side.
+    pub fn initial_threshold_estimate(&self) -> f32 {
+        // With conn_thr=0.5, init_perm around 0.5±0.1, S=40, sparse SDR at 2%:
+        // expected overlap ~ 40 * 0.02 = 0.8 connected hits → boosted ~ 0.8.
+        // Top-K selects the top 2%, so the threshold is roughly the
+        // 98th percentile of boosted. Conservative start: 2.0.
+        // The per-column adaptation will quickly steer each column's thr.
+        2.0f32
+    }
+
+    /// Batched multi-step SP on the GPU. Processes T timesteps from a
+    /// pre-uploaded device input buffer. Emits a `(T, n_cols)` u8
+    /// active-column mask to `cols_out` and a flat active-column index list
+    /// to `active_indices_host`, with a `u32::MAX` separator after each step.
+    ///
+    /// For each step this runs the same 5-kernel pipeline as `compute`, but
+    /// skips the per-step boost/duty D2H→exp→H2D round-trip: the boost
+    /// update runs in a fused on-device kernel instead (host round-trips
+    /// happen only in strict-parity mode).
+    ///
+    /// This is the fast path used by `HTMRegionGpu.step_many_gpu`.
+    #[allow(clippy::too_many_arguments)]
+    pub fn step_batch(
+        &mut self,
+        inputs_flat_dev: &CudaSlice<u8>,
+        t: usize,
+        input_bits: usize,
+        learn: bool,
+        cols_out: &mut [u8],
+        active_indices_host: &mut Vec<u32>,
+    ) -> Result<(), DriverError> {
+        let n = self.n_columns;
+        let k = ((self.sparsity * n as f32).round() as usize).max(1);
+        debug_assert_eq!(cols_out.len(), t * n);
+
+        let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
+        let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
+        let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
+        let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
+
+        let overlap_cfg = LaunchConfig {
+            grid_dim: (n as u32, 1, 1),
+            block_dim: (128, 1, 1),
+            shared_mem_bytes: 0,
+        };
+        let topk_cfg = LaunchConfig {
+            grid_dim: (1, 1, 1),
+            block_dim: (256, 1, 1),
+            shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
+        };
+        let learn_cfg = overlap_cfg;
+        let duty_cfg = LaunchConfig {
+            grid_dim: ((n as u32 + 255) / 256, 1, 1),
+            block_dim: (256, 1, 1),
+            shared_mem_bytes: 0,
+        };
+        let alpha = 1.0f32 / self.duty_period.max(1.0);
+
+        // Reusable host buffer for the per-step active_mask D2H.
+        self.host_mask.resize(n, 0);
+
+        active_indices_host.clear();
+
+        for ti in 0..t {
+            // Point the overlap kernel at the ti-th slice of the pre-uploaded
+            // input. cudarc's CudaSlice has no "view" type per se, so we copy
+            // the slice into the reusable inp_dev buffer. This is a D2D copy
+            // – much faster than H2D.
+            // (Alternative: rewrite the kernel to accept an offset; deferred.)
+            let in_off = ti * input_bits;
+            // dtod_copy via raw slice indexing: cudarc exposes slice() for this.
+            let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
+            self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
+
+            // 1. sp_overlap
+            unsafe {
+                overlap_fn.clone().launch(
+                    overlap_cfg,
+                    (
+                        &self.inp_dev,
+                        &self.syn_bit,
+                        &self.syn_perm,
+                        &self.boost,
+                        self.conn_thr,
+                        self.synapses_per_col as u32,
+                        n as u32,
+                        &mut self.raw,
+                        &mut self.boosted,
+                    ),
+                )?;
+            }
+
+            // 2. Clear active_mask, then sp_topk.
+            self.dev.memset_zeros(&mut self.active_mask)?;
+            unsafe {
+                topk_fn.clone().launch(
+                    topk_cfg,
+                    (&self.boosted, n as u32, k as u32, &mut self.active_mask),
+                )?;
+            }
+
+            // 3. sp_learn
+            if learn {
+                unsafe {
+                    learn_fn.clone().launch(
+                        learn_cfg,
+                        (
+                            &self.active_mask,
+                            &self.inp_dev,
+                            &self.syn_bit,
+                            &mut self.syn_perm,
+                            self.inc,
+                            self.dec,
+                            self.synapses_per_col as u32,
+                            n as u32,
+                        ),
+                    )?;
+                }
+            }
+
+            // 4. Duty update (device).
+            unsafe {
+                duty_fn.clone().launch(
+                    duty_cfg,
+                    (
+                        &self.active_mask,
+                        &self.raw,
+                        &mut self.active_duty,
+                        &mut self.overlap_duty,
+                        &mut self.boost,
+                        alpha,
+                        1.0f32,
+                        0.0f32,
+                        0.0f32,
+                        0u32,
+                        n as u32,
+                    ),
+                )?;
+            }
+
+            // 5. Boost update. Two modes:
+            //  * strict_parity (tests): host-side exp for a bit-exact match.
+            //  * default (production): GPU expf is close enough and ~10x
+            //    faster since we skip the D2H/H2D round-trip.
+            if learn && self.boost_strength > 0.0 {
+                if self.strict_parity {
+                    let mut duty_host = vec![0f32; n];
+                    self.dev
+                        .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
+                    let sum: f32 = duty_host.iter().sum();
+                    let mean = sum / (n as f32);
+                    let mut boost_host = vec![0f32; n];
+                    for i in 0..n {
+                        boost_host[i] =
+                            (-self.boost_strength * (duty_host[i] - mean)).exp();
+                    }
+                    self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
+
+                    // Permanence bump (rare). Only evaluated in strict mode.
+                    let mut ov_host = vec![0f32; n];
+                    self.dev
+                        .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
+                    let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
+                    if max_ov > 0.0 {
+                        let thr = 0.001f32 * max_ov;
+                        let bump = self.inc * 0.1f32;
+                        let bump_cols: Vec<u32> = ov_host
+                            .iter()
+                            .enumerate()
+                            .filter_map(|(i, &o)| {
+                                if o < thr { Some(i as u32) } else { None }
+                            })
+                            .collect();
+                        if !bump_cols.is_empty() {
+                            let s = self.synapses_per_col;
+                            let mut perm_host = vec![0f32; n * s];
+                            self.dev
+                                .dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
+                            for &c in &bump_cols {
+                                let base = (c as usize) * s;
+                                for p in &mut perm_host[base..base + s] {
+                                    *p = (*p + bump).min(1.0);
+                                }
+                            }
+                            self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
+                        }
+                    }
+                } else {
+                    // Fast path: fused mean + boost = expf(-strength*(ad-mean))
+                    // in a single GPU block. Zero D2H, zero H2D – fully async.
+                    let boost_fn = self
+                        .dev
+                        .get_func("htm_sp_boost_fused", "sp_boost_from_duty")
+                        .expect("sp_boost_fused not loaded");
+                    let boost_cfg = LaunchConfig {
+                        grid_dim: (1, 1, 1),
+                        block_dim: (1024, 1, 1),
+                        shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
+                    };
+                    unsafe {
+                        boost_fn.launch(
+                            boost_cfg,
+                            (
+                                &self.active_duty,
+                                &mut self.boost,
+                                self.boost_strength,
+                                n as u32,
+                            ),
+                        )?;
+                    }
+                }
+            }
+
+            // D2H the active_mask for this step. This is the single
+            // unavoidable sync point per step – the CPU TM needs the active
+            // indices for its next state update. At 2048 bytes / step this
+            // is tiny in bandwidth but costs a full synchronize (~5-10μs).
+            self.dev
+                .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
+            let co = ti * n;
+            cols_out[co..co + n].copy_from_slice(&self.host_mask);
+            // Extract active indices.
+            for (i, &b) in self.host_mask.iter().enumerate() {
+                if b != 0 {
+                    active_indices_host.push(i as u32);
+                }
+            }
+            // Insert a u32::MAX separator to demarcate step boundaries.
+            active_indices_host.push(u32::MAX);
+        }
+
+        Ok(())
+    }
+
+    /// Fully-on-GPU batched SP + TM. Zero per-step host sync.
+    ///
+    /// Inputs:
+    ///   inputs_flat_dev : (T * input_bits) u8, already uploaded
+    ///   cols_dev        : (T * n_cols) u8 output – active-column mask per step
+    ///   anom_dev        : (T,) f32 output – anomaly score per step
+    ///   tm              : persistent GPU TemporalMemory for this region
+    #[allow(clippy::too_many_arguments)]
+    pub fn step_batch_with_tm(
+        &mut self,
+        inputs_flat_dev: &CudaSlice<u8>,
+        t: usize,
+        input_bits: usize,
+        learn: bool,
+        cols_dev: &mut CudaSlice<u8>,
+        anom_dev: &mut CudaSlice<f32>,
+        tm: &mut crate::gpu::tm_gpu::TemporalMemoryGpu,
+    ) -> Result<(), DriverError> {
+        let n = self.n_columns;
+        let k = ((self.sparsity * n as f32).round() as usize).max(1);
+        debug_assert_eq!(cols_dev.len(), t * n);
+        debug_assert_eq!(anom_dev.len(), t);
+
+        let overlap_fn = self.dev.get_func("htm_sp_overlap", "sp_overlap").unwrap();
+        let topk_fn = self.dev.get_func("htm_sp_topk", "sp_topk_select").unwrap();
+        let learn_fn = self.dev.get_func("htm_sp_learn", "sp_learn").unwrap();
+        let duty_fn = self.dev.get_func("htm_sp_duty", "sp_duty_update").unwrap();
+
+        let overlap_cfg = LaunchConfig {
+            grid_dim: (n as u32, 1, 1),
+            block_dim: (128, 1, 1),
+            shared_mem_bytes: 0,
+        };
+        let topk_cfg = LaunchConfig {
+            grid_dim: (1, 1, 1),
+            block_dim: (256, 1, 1),
+            shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
+        };
+        let learn_cfg = overlap_cfg;
+        let duty_cfg = LaunchConfig {
+            grid_dim: ((n as u32 + 255) / 256, 1, 1),
+            block_dim: (256, 1, 1),
+            shared_mem_bytes: 0,
+        };
+        let alpha = 1.0f32 / self.duty_period.max(1.0);
+
+        for ti in 0..t {
+            let in_off = ti * input_bits;
+            let sub = inputs_flat_dev.slice(in_off..in_off + input_bits);
+            self.dev.dtod_copy(&sub, &mut self.inp_dev)?;
+
+            // 1. sp_overlap
+            unsafe {
+                overlap_fn.clone().launch(
+                    overlap_cfg,
+                    (
+                        &self.inp_dev,
+                        &self.syn_bit,
+                        &self.syn_perm,
+                        &self.boost,
+                        self.conn_thr,
+                        self.synapses_per_col as u32,
+                        n as u32,
+                        &mut self.raw,
+                        &mut self.boosted,
+                    ),
+                )?;
+            }
+
+            // 2. Clear + sp_topk.
+            self.dev.memset_zeros(&mut self.active_mask)?;
+            unsafe {
+                topk_fn.clone().launch(
+                    topk_cfg,
+                    (&self.boosted, n as u32, k as u32, &mut self.active_mask),
+                )?;
+            }
+
+            // 3. sp_learn
+            if learn {
+                unsafe {
+                    learn_fn.clone().launch(
+                        learn_cfg,
+                        (
+                            &self.active_mask,
+                            &self.inp_dev,
+                            &self.syn_bit,
+                            &mut self.syn_perm,
+                            self.inc,
+                            self.dec,
+                            self.synapses_per_col as u32,
+                            n as u32,
+                        ),
+                    )?;
+                }
+            }
+
+            // 4. Duty update (stage 1: no boost write).
+            unsafe {
+                duty_fn.clone().launch(
+                    duty_cfg,
+                    (
+                        &self.active_mask,
+                        &self.raw,
+                        &mut self.active_duty,
+                        &mut self.overlap_duty,
+                        &mut self.boost,
+                        alpha,
+                        1.0f32,
+                        0.0f32,
+                        0.0f32,
+                        0u32,
+                        n as u32,
+                    ),
+                )?;
+            }
+
+            // 5. Boost update: fused GPU kernel (no D2H).
+            if learn && self.boost_strength > 0.0 {
+                let boost_fn = self.dev
+                    .get_func("htm_sp_boost_fused", "sp_boost_from_duty")
+                    .expect("sp_boost_fused not loaded");
+                let boost_cfg = LaunchConfig {
+                    grid_dim: (1, 1, 1),
+                    block_dim: (1024, 1, 1),
+                    shared_mem_bytes: 32 * std::mem::size_of::<f32>() as u32,
+                };
+                unsafe {
+                    boost_fn.launch(
+                        boost_cfg,
+                        (
+                            &self.active_duty,
+                            &mut self.boost,
+                            self.boost_strength,
+                            n as u32,
+                        ),
+                    )?;
+                }
+            }
+
+            // 6. Copy the active_mask slice into cols_dev[ti*n .. (ti+1)*n].
+            let mut dst_slice = cols_dev.slice_mut(ti * n..(ti + 1) * n);
+            self.dev.dtod_copy(&self.active_mask, &mut dst_slice)?;
+
+            // 7. GPU TM step: predict + activate + anomaly + learn, all on device.
|
| 562 |
+
tm.step(&self.active_mask, anom_dev, ti as u32, learn)?;
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
Ok(())
|
| 566 |
+
}
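Step 7 above delegates to `tm.step`, which writes one anomaly value per step into `anom_dev`. As a reference for what that scalar means, here is a host-side Python/NumPy sketch of the standard HTM anomaly score that `tm_anomaly` is presumably computing from the active mask and the unpredicted-column counter; the function name and array framing are illustrative assumptions, not crate API:

    import numpy as np

    def anomaly_score(active_mask: np.ndarray, col_predicted: np.ndarray) -> float:
        """Fraction of active columns that no cell predicted last step.
        0.0 = fully predicted sequence, 1.0 = fully surprising input."""
        active = active_mask != 0
        unpredicted = active & (col_predicted == 0)
        n_active = int(active.sum())
        return float(unpredicted.sum()) / n_active if n_active else 0.0

On the GPU the numerator arrives pre-reduced in `unpredicted_count`, so the kernel presumably only needs to count the active columns and divide.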
+
+    /// One SP step on the GPU. Returns sorted active-column indices.
+    pub fn compute(&mut self, input: &[u8], learn: bool) -> Result<Vec<u32>, DriverError> {
+        debug_assert_eq!(input.len(), self.input_bits);
+        let n = self.n_columns;
+        let k = ((self.sparsity * n as f32).round() as usize).max(1);
+
+        // 1. H2D input SDR.
+        self.dev.htod_sync_copy_into(input, &mut self.inp_dev)?;
+
+        // 2. Launch sp_overlap: grid=n_columns, block=128.
+        let overlap_fn = self
+            .dev
+            .get_func("htm_sp_overlap", "sp_overlap")
+            .expect("sp_overlap not loaded");
+        let overlap_cfg = LaunchConfig {
+            grid_dim: (n as u32, 1, 1),
+            block_dim: (128, 1, 1),
+            shared_mem_bytes: 0,
+        };
+        unsafe {
+            overlap_fn.launch(
+                overlap_cfg,
+                (
+                    &self.inp_dev,
+                    &self.syn_bit,
+                    &self.syn_perm,
+                    &self.boost,
+                    self.conn_thr,
+                    self.synapses_per_col as u32,
+                    n as u32,
+                    &mut self.raw,
+                    &mut self.boosted,
+                ),
+            )?;
+        }
+
+        // 3. Launch sp_topk: single block, shared mem = n_columns * f32.
+        let topk_fn = self
+            .dev
+            .get_func("htm_sp_topk", "sp_topk_select")
+            .expect("sp_topk not loaded");
+        let topk_cfg = LaunchConfig {
+            grid_dim: (1, 1, 1),
+            block_dim: (256, 1, 1),
+            shared_mem_bytes: (n * std::mem::size_of::<f32>()) as u32,
+        };
+        // Clear active_mask first. memset_zeros avoids an H2D of a host
+        // zeroes vector every step.
+        self.dev.memset_zeros(&mut self.active_mask)?;
+        unsafe {
+            topk_fn.launch(
+                topk_cfg,
+                (
+                    &self.boosted,
+                    n as u32,
+                    k as u32,
+                    &mut self.active_mask,
+                ),
+            )?;
+        }
+
+        // 4. Optional: sp_learn on active columns.
+        if learn {
+            let learn_fn = self
+                .dev
+                .get_func("htm_sp_learn", "sp_learn")
+                .expect("sp_learn not loaded");
+            let learn_cfg = LaunchConfig {
+                grid_dim: (n as u32, 1, 1),
+                block_dim: (128, 1, 1),
+                shared_mem_bytes: 0,
+            };
+            unsafe {
+                learn_fn.launch(
+                    learn_cfg,
+                    (
+                        &self.active_mask,
+                        &self.inp_dev,
+                        &self.syn_bit,
+                        &mut self.syn_perm,
+                        self.inc,
+                        self.dec,
+                        self.synapses_per_col as u32,
+                        n as u32,
+                    ),
+                )?;
+            }
+        }
+
+        // 5. Duty cycle + boost update. Always runs (matches CPU).
+        // CPU ordering (sp.rs): lines 186-196 update the duty cycles, line
+        // 202 computes mean = sum(active_duty_cycle) / n on the POST-update
+        // values, and line 204 sets
+        //   boost[i] = exp(-strength * (active_duty[i] - mean)).
+        // To match that here: 1) run the duty kernel with boost_strength=0
+        // (duty update only, no boost write), 2) D2H active_duty and compute
+        // the mean on the host, 3) re-launch with the true mean to write the
+        // boost. Two launches plus one tiny D2H (n x f32); at n=2048 that is
+        // 8KB per step, negligible.
+        let alpha = 1.0f32 / self.duty_period.max(1.0);
+        let duty_fn = self
+            .dev
+            .get_func("htm_sp_duty", "sp_duty_update")
+            .expect("sp_duty not loaded");
+        let duty_cfg = LaunchConfig {
+            grid_dim: ((n as u32 + 255) / 256, 1, 1),
+            block_dim: (256, 1, 1),
+            shared_mem_bytes: 0,
+        };
+        // Stage 1: update duty cycles (boost_strength=0 -> no write).
+        unsafe {
+            duty_fn.launch(
+                duty_cfg,
+                (
+                    &self.active_mask,
+                    &self.raw,
+                    &mut self.active_duty,
+                    &mut self.overlap_duty,
+                    &mut self.boost,
+                    alpha,
+                    1.0f32, // stim_thr
+                    0.0f32, // boost_strength = 0 -> skip write
+                    0.0f32, // mean_duty (unused)
+                    0u32,   // learn_flag = 0
+                    n as u32,
+                ),
+            )?;
+        }
+
+        if learn && self.boost_strength > 0.0 && self.strict_parity {
+            // The boost update must bit-match CPU `f32::exp`, so we compute
+            // it on the host and copy back. Cost per step: 8KB D2H + 8KB H2D
+            // at n=2048. Critical for learning parity: CUDA expf (even
+            // without fast-math) rounds some inputs differently than host libm.
+            let mut duty_host = vec![0f32; n];
+            self.dev
+                .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
+            let sum: f32 = duty_host.iter().sum();
+            let mean = sum / (n as f32);
+            let mut boost_host = vec![0f32; n];
+            for i in 0..n {
+                boost_host[i] = (-self.boost_strength * (duty_host[i] - mean)).exp();
+            }
+            self.dev.htod_sync_copy_into(&boost_host, &mut self.boost)?;
+
+            // CPU sp.rs 210-226: permanence bump for chronically
+            // under-stimulated columns. If
+            //   overlap_duty_cycle[i] < 0.001 * max(overlap_duty_cycle),
+            // add inc*0.1 to every synapse of column i (clamped to 1.0).
+            // This runs once per step and fires rarely, but we need it for
+            // bit-exact parity with CPU learn.
+            let mut ov_host = vec![0f32; n];
+            self.dev
+                .dtoh_sync_copy_into(&self.overlap_duty, &mut ov_host)?;
+            let max_ov = ov_host.iter().cloned().fold(0f32, f32::max);
+            if max_ov > 0.0 {
+                let thr = 0.001f32 * max_ov;
+                let bump = self.inc * 0.1f32;
+                // Find columns needing a bump. Usually empty, so the D2H/H2D
+                // of syn_perm is cheap (n*S*4 = 320KB at n=2048, S=40).
+                let bump_cols: Vec<u32> = ov_host
+                    .iter()
+                    .enumerate()
+                    .filter_map(|(i, &o)| if o < thr { Some(i as u32) } else { None })
+                    .collect();
+                if !bump_cols.is_empty() {
+                    // Download, bump, upload. (Keeps the implementation simple
+                    // and bit-exact. Could be kernelized later.)
+                    let s = self.synapses_per_col;
+                    let mut perm_host = vec![0f32; n * s];
+                    self.dev.dtoh_sync_copy_into(&self.syn_perm, &mut perm_host)?;
+                    for &c in &bump_cols {
+                        let base = (c as usize) * s;
+                        for p in &mut perm_host[base..base + s] {
+                            *p = (*p + bump).min(1.0);
+                        }
+                    }
+                    self.dev.htod_sync_copy_into(&perm_host, &mut self.syn_perm)?;
+                }
+            }
+        } else if learn && self.boost_strength > 0.0 {
+            // Fast path: GPU-side boost write via the already-loaded duty
+            // kernel; only the duty vector is pulled to the host for the mean.
+            let mut duty_host = vec![0f32; n];
+            self.dev
+                .dtoh_sync_copy_into(&self.active_duty, &mut duty_host)?;
+            let sum: f32 = duty_host.iter().sum();
+            let mean = sum / (n as f32);
+            let boost_fn = self
+                .dev
+                .get_func("htm_sp_duty", "sp_duty_update")
+                .expect("sp_duty not loaded");
+            unsafe {
+                boost_fn.launch(
+                    duty_cfg,
+                    (
+                        &self.active_mask,
+                        &self.raw,
+                        &mut self.active_duty,
+                        &mut self.overlap_duty,
+                        &mut self.boost,
+                        0.0f32,
+                        1.0f32,
+                        self.boost_strength,
+                        mean,
+                        1u32,
+                        n as u32,
+                    ),
+                )?;
+            }
+        }
+
+        // 6. D2H active_mask and convert to a sorted index list.
+        self.dev
+            .dtoh_sync_copy_into(&self.active_mask, &mut self.host_mask)?;
+        let mut active: Vec<u32> = Vec::with_capacity(k);
+        for (i, &b) in self.host_mask.iter().enumerate() {
+            if b != 0 {
+                active.push(i as u32);
+            }
+        }
+        debug_assert_eq!(active.len(), k, "SP must emit exactly k winners");
+        Ok(active)
+    }
+}
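The strict-parity and fast-path branches above implement the same update in different places; the invariant both must preserve is the CPU ordering: EMA duty update first, then the mean over the POST-update duties, then boost[i] = exp(-strength * (duty[i] - mean)). A NumPy reference of that ordering (hypothetical helper names, not crate API):

    import numpy as np

    def duty_and_boost(active_duty, active_mask, alpha, strength):
        """Reference for the two-stage update: EMA first, then boost from
        the post-update mean (the detail the parity comments hinge on)."""
        duty = (1.0 - alpha) * active_duty + alpha * (active_mask != 0)
        mean = duty.mean()  # mean of the POST-update values
        boost = np.exp(-strength * (duty - mean))
        return duty.astype(np.float32), boost.astype(np.float32)

Bit-exact parity additionally requires evaluating exp in the same f32 libm as the CPU path, which is exactly why the strict branch round-trips through the host.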
overlay/htm_rust/src/gpu/tm_gpu.rs
CHANGED

@@ -1,460 +1,460 @@
(Both sides of this hunk are identical in this view; the file content appears once below.)

//! GPU Temporal Memory.
//!
//! Flat device storage. Pre-allocated segment slab:
//!   n_cells        = n_columns * cells_per_column
//!   n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL
//!   n_synapses_max = n_segments_max * MAX_SYN_PER_SEGMENT
//!
//! Defaults (CPU parity targets relaxed on GPU to keep memory tractable):
//!   MAX_SEGMENTS_PER_CELL = 16
//!   MAX_SYN_PER_SEGMENT   = 32
//!
//! At n_cells = 65536:
//!   n_segments_max = 1_048_576 (~1M)
//!   n_synapses_max = 33_554_432 (~33M)
//! Storage:
//!   syn_presyn : u32 x 33M = 128 MB
//!   syn_perm   : i16 x 33M =  64 MB
//!   seg_cell   : u32 x 1M  =   4 MB
//!   seg_syn_n  : u32 x 1M  =   4 MB
//!   misc bitsets etc.       < 1 MB
//!   -------------------------------
//!   Total per region       ~200 MB
//!
//! Permanences are stored as i16 scaled by 32767 ([0, 32767] represents
//! [0.0, 1.0]). inc/dec are provided pre-scaled.

use std::sync::Arc;

use cudarc::driver::{CudaDevice, CudaSlice, DriverError, DeviceRepr, LaunchAsync, LaunchConfig};
use cudarc::nvrtc::Ptx;

/// Packed config struct passed by value to TM kernels to stay under
/// cudarc's 12-tuple launch limit. Layout must match the C-side
/// `TmConfig` struct declared in each kernel.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct TmConfig {
    pub activation_threshold: u32,
    pub learning_threshold: u32,
    pub cells_per_column: u32,
    pub synapses_per_segment: u32,
    pub n_segments: u32,
    pub n_cells: u32,
    pub max_segments_per_cell: u32,
    pub max_new_synapses: u32,
    pub conn_thr_i16: i32, // i16 widened to i32 for alignment
    pub perm_inc_i16: i32,
    pub perm_dec_i16: i32,
    pub predicted_seg_dec_i16: i32,
    pub initial_perm_i16: i32,
    pub iter_seed: u32,
    pub n_cols: u32,
    pub bits_words: u32,
}

unsafe impl DeviceRepr for TmConfig {}

// Embedded PTX.
const PTX_TM_PREDICT: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_predict.ptx"));
const PTX_TM_ACTIVATE: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_activate.ptx"));
const PTX_TM_LEARN: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_learn.ptx"));
const PTX_TM_PUNISH: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_punish.ptx"));
const PTX_TM_GROW: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_grow.ptx"));
const PTX_TM_ANOMALY: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_anomaly.ptx"));
const PTX_TM_RESET: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_reset.ptx"));

/// Capacity trade-offs for 6 GB VRAM (RTX 3060) shared with the model:
///   n_cells        = 2048 x 32 = 65_536
///   n_segments_max = n_cells x MAX_SEGMENTS_PER_CELL
///   n_synapses_max = n_segments_max x MAX_SYN_PER_SEGMENT
///
/// At 4/20 these are 262_144 segments and ~5.2M synapses (~50 MB per region).
/// The training loop runs with `reset_each_forward=True`, so segment counts
/// per window stay well below 32K (typical: ~n_cols new segs per step until
/// the first matching segment is reused; in a 2048-step window that plateaus
/// around ~5K total live segments). The 262K ceiling is generous headroom.
pub const MAX_SEGMENTS_PER_CELL: usize = 4;
pub const MAX_SYN_PER_SEGMENT: usize = 20;

const PERM_SCALE: f32 = 32767.0;

fn perm_f32_to_i16(x: f32) -> i16 {
    let clamped = x.clamp(0.0, 1.0);
    (clamped * PERM_SCALE).round() as i16
}

pub struct TemporalMemoryGpu {
    dev: Arc<CudaDevice>,

    // Config mirror
    pub n_columns: usize,
    pub cells_per_column: usize,
    pub activation_threshold: u32,
    pub learning_threshold: u32,
    pub initial_perm_i16: i16,
    pub conn_thr_i16: i16,
    pub perm_inc_i16: i16,
    pub perm_dec_i16: i16,
    pub predicted_seg_dec_i16: i16,
    pub max_new_synapse_count: u32,

    // Sizes
    pub n_cells: usize,
    pub n_segments_max: usize,
    pub bits_words: usize, // n_cells / 32

    // Persistent device buffers
    seg_cell_id: CudaSlice<u32>,
    seg_syn_count: CudaSlice<u32>,
    syn_presyn: CudaSlice<u32>,
    syn_perm: CudaSlice<i16>,
    cell_seg_count: CudaSlice<u32>,

    cell_active_bits: CudaSlice<u32>,
    cell_winner_bits: CudaSlice<u32>,
    cell_predictive_bits: CudaSlice<u32>,
    prev_active_bits: CudaSlice<u32>,
    prev_winner_bits: CudaSlice<u32>,

    col_predicted: CudaSlice<u8>,
    seg_num_active_conn: CudaSlice<u32>,
    seg_num_active_pot: CudaSlice<u32>,
    unpredicted_count: CudaSlice<u32>,
    burst_cols_flat: CudaSlice<u32>,
    burst_cols_count: CudaSlice<u32>,
    col_best_match: CudaSlice<u32>,

    iter_counter: u32,
}

impl TemporalMemoryGpu {
    pub fn new(
        dev: Arc<CudaDevice>,
        n_columns: usize,
        cells_per_column: usize,
    ) -> Result<Self, DriverError> {
        let n_cells = n_columns * cells_per_column;
        assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
        let n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL;
        let bits_words = n_cells / 32;

        // Numenta defaults.
        let activation_threshold = 15u32;
        let learning_threshold = 13u32;
        let initial_perm_i16 = perm_f32_to_i16(0.21);
        let conn_thr_i16 = perm_f32_to_i16(0.50);
        let perm_inc_i16 = perm_f32_to_i16(0.10);
        let perm_dec_i16 = perm_f32_to_i16(0.10);
        let predicted_seg_dec_i16 = perm_f32_to_i16(0.10);
        let max_new_synapse_count = 20u32;

        // Allocate buffers.
        let seg_cell_id_host: Vec<u32> = vec![u32::MAX; n_segments_max];
        let seg_cell_id = dev.htod_sync_copy(&seg_cell_id_host)?;
        let seg_syn_count = dev.alloc_zeros::<u32>(n_segments_max)?;
        let syn_presyn = dev.alloc_zeros::<u32>(n_segments_max * MAX_SYN_PER_SEGMENT)?;
        let syn_perm = dev.alloc_zeros::<i16>(n_segments_max * MAX_SYN_PER_SEGMENT)?;
        let cell_seg_count = dev.alloc_zeros::<u32>(n_cells)?;

        let cell_active_bits = dev.alloc_zeros::<u32>(bits_words)?;
        let cell_winner_bits = dev.alloc_zeros::<u32>(bits_words)?;
        let cell_predictive_bits = dev.alloc_zeros::<u32>(bits_words)?;
        let prev_active_bits = dev.alloc_zeros::<u32>(bits_words)?;
        let prev_winner_bits = dev.alloc_zeros::<u32>(bits_words)?;

        let col_predicted = dev.alloc_zeros::<u8>(n_columns)?;
        let seg_num_active_conn = dev.alloc_zeros::<u32>(n_segments_max)?;
        let seg_num_active_pot = dev.alloc_zeros::<u32>(n_segments_max)?;
        let unpredicted_count = dev.alloc_zeros::<u32>(1)?;
        // Bursting columns for one step, bounded by n_columns.
        let burst_cols_flat = dev.alloc_zeros::<u32>(n_columns)?;
        let burst_cols_count = dev.alloc_zeros::<u32>(1)?;
        let col_best_match = dev.alloc_zeros::<u32>(n_columns)?;

        // Load PTX modules.
        let modules = [
            ("htm_tm_predict", PTX_TM_PREDICT, "tm_predict"),
            ("htm_tm_activate", PTX_TM_ACTIVATE, "tm_activate"),
            ("htm_tm_learn", PTX_TM_LEARN, "tm_learn_reinforce"),
            ("htm_tm_punish", PTX_TM_PUNISH, "tm_punish"),
            ("htm_tm_grow", PTX_TM_GROW, "tm_grow"),
            ("htm_tm_anomaly", PTX_TM_ANOMALY, "tm_anomaly"),
            ("htm_tm_reset", PTX_TM_RESET, "tm_reset_step"),
        ];
        for (modname, ptx, fnname) in modules {
            if dev.get_func(modname, fnname).is_none() {
                dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
            }
        }

        Ok(Self {
            dev,
            n_columns,
            cells_per_column,
            activation_threshold,
            learning_threshold,
            initial_perm_i16,
            conn_thr_i16,
            perm_inc_i16,
            perm_dec_i16,
            predicted_seg_dec_i16,
            max_new_synapse_count,
            n_cells,
            n_segments_max,
            bits_words,
            seg_cell_id,
            seg_syn_count,
            syn_presyn,
            syn_perm,
            cell_seg_count,
            cell_active_bits,
            cell_winner_bits,
            cell_predictive_bits,
            prev_active_bits,
            prev_winner_bits,
            col_predicted,
            seg_num_active_conn,
            seg_num_active_pot,
            unpredicted_count,
            burst_cols_flat,
            burst_cols_count,
            col_best_match,
            iter_counter: 0,
        })
    }

    // --- Fused-path accessors ---
    pub fn seg_cell_id_accessor(&self) -> &CudaSlice<u32> { &self.seg_cell_id }
    pub fn seg_syn_count_accessor(&self) -> &CudaSlice<u32> { &self.seg_syn_count }
    pub fn syn_presyn_accessor(&self) -> &CudaSlice<u32> { &self.syn_presyn }
    pub fn syn_perm_accessor(&self) -> &CudaSlice<i16> { &self.syn_perm }
    pub fn cell_seg_count_accessor(&self) -> &CudaSlice<u32> { &self.cell_seg_count }

    /// Hard reset: clear everything (predictive + active + segments).
    pub fn reset(&mut self) -> Result<(), DriverError> {
        // Restore the "unused" sentinel in seg_cell_id.
        let unused_host: Vec<u32> = vec![u32::MAX; self.n_segments_max];
        self.dev.htod_sync_copy_into(&unused_host, &mut self.seg_cell_id)?;
        self.dev.memset_zeros(&mut self.seg_syn_count)?;
        self.dev.memset_zeros(&mut self.cell_seg_count)?;
        self.dev.memset_zeros(&mut self.cell_active_bits)?;
        self.dev.memset_zeros(&mut self.cell_winner_bits)?;
        self.dev.memset_zeros(&mut self.cell_predictive_bits)?;
        self.dev.memset_zeros(&mut self.prev_active_bits)?;
        self.dev.memset_zeros(&mut self.prev_winner_bits)?;
        self.dev.memset_zeros(&mut self.col_best_match)?;
        self.iter_counter = 0;
        Ok(())
    }

    fn build_cfg(&self) -> TmConfig {
        TmConfig {
            activation_threshold: self.activation_threshold,
            learning_threshold: self.learning_threshold,
            cells_per_column: self.cells_per_column as u32,
            synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,
            n_segments: self.n_segments_max as u32,
            n_cells: self.n_cells as u32,
            max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
            max_new_synapses: self.max_new_synapse_count,
            conn_thr_i16: self.conn_thr_i16 as i32,
            perm_inc_i16: self.perm_inc_i16 as i32,
            perm_dec_i16: self.perm_dec_i16 as i32,
            predicted_seg_dec_i16: self.predicted_seg_dec_i16 as i32,
            initial_perm_i16: self.initial_perm_i16 as i32,
            iter_seed: self.iter_counter,
            n_cols: self.n_columns as u32,
            bits_words: self.bits_words as u32,
        }
    }

    /// Run one TM step on the GPU. Takes the SP active-column mask (u8, already
    /// on device) and writes `anomaly_out[t_slot]`.
    pub fn step(
        &mut self,
        sp_active_mask: &CudaSlice<u8>,
        anomaly_out: &mut CudaSlice<f32>,
        t_slot: u32,
        learn: bool,
    ) -> Result<(), DriverError> {
        let n_cells = self.n_cells;
        let n_cols = self.n_columns;

        let predict_fn = self.dev.get_func("htm_tm_predict", "tm_predict").unwrap();
        let activate_fn = self.dev.get_func("htm_tm_activate", "tm_activate").unwrap();
        let learn_fn = self.dev.get_func("htm_tm_learn", "tm_learn_reinforce").unwrap();
        let punish_fn = self.dev.get_func("htm_tm_punish", "tm_punish").unwrap();
        let grow_fn = self.dev.get_func("htm_tm_grow", "tm_grow").unwrap();
        let anom_fn = self.dev.get_func("htm_tm_anomaly", "tm_anomaly").unwrap();
        let reset_fn = self.dev.get_func("htm_tm_reset", "tm_reset_step").unwrap();

        self.iter_counter = self.iter_counter.wrapping_add(1);
        let cfg_val = self.build_cfg();

        // 0. Per-step reset.
        let reset_words = self.bits_words.max(n_cols);
        let reset_cfg = LaunchConfig {
            grid_dim: (((reset_words + 255) / 256) as u32, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 0,
        };
        unsafe {
            reset_fn.clone().launch(
                reset_cfg,
                (
                    &mut self.cell_active_bits,
                    &mut self.cell_winner_bits,
                    &mut self.cell_predictive_bits,
                    &mut self.prev_active_bits,
                    &mut self.prev_winner_bits,
                    &mut self.col_predicted,
                    &mut self.unpredicted_count,
                    &mut self.burst_cols_count,
                    &mut self.col_best_match,
                    self.bits_words as u32,
                    n_cols as u32,
                ),
            )?;
        }

        // 1. Predict (grid = n_cells; each block iterates its cell's segments).
        let predict_cfg = LaunchConfig {
            grid_dim: (n_cells as u32, 1, 1),
            block_dim: (32, 1, 1),
            shared_mem_bytes: 0,
        };
        unsafe {
            predict_fn.clone().launch(
                predict_cfg,
                (
                    &self.seg_cell_id,
                    &self.seg_syn_count,
                    &self.syn_presyn,
                    &self.syn_perm,
                    &self.prev_active_bits,
                    &mut self.cell_predictive_bits,
                    &mut self.col_predicted,
                    &mut self.seg_num_active_conn,
                    &mut self.seg_num_active_pot,
                    &mut self.col_best_match,
                    &self.cell_seg_count,
                    cfg_val,
                ),
            )?;
        }

        // 2. Activate.
        let activate_cfg = LaunchConfig {
            grid_dim: (((n_cols + 255) / 256) as u32, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 0,
        };
        unsafe {
            activate_fn.clone().launch(
                activate_cfg,
                (
                    sp_active_mask,
                    &self.col_predicted,
                    &self.cell_predictive_bits,
                    &mut self.cell_active_bits,
                    &mut self.cell_winner_bits,
                    &mut self.unpredicted_count,
                    &mut self.burst_cols_flat,
                    &mut self.burst_cols_count,
                    cfg_val,
                ),
            )?;
        }

        // 3. Anomaly.
        let anom_cfg = LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (256, 1, 1),
            shared_mem_bytes: 0,
        };
        unsafe {
            anom_fn.clone().launch(
                anom_cfg,
                (
                    sp_active_mask,
                    &self.unpredicted_count,
                    anomaly_out,
                    t_slot,
                    n_cols as u32,
                ),
            )?;
        }

        if learn {
            // 4. Reinforce (grid = n_cells).
            let learn_cfg = LaunchConfig {
                grid_dim: (n_cells as u32, 1, 1),
                block_dim: (32, 1, 1),
                shared_mem_bytes: 0,
            };
            unsafe {
                learn_fn.clone().launch(
                    learn_cfg,
                    (
                        &self.seg_cell_id,
                        &self.seg_syn_count,
                        &self.syn_presyn,
                        &mut self.syn_perm,
                        &self.seg_num_active_conn,
                        &self.prev_active_bits,
                        sp_active_mask,
                        &self.col_predicted,
                        &self.cell_seg_count,
                        cfg_val,
                    ),
                )?;
            }

            // 5. Punish.
            unsafe {
                punish_fn.clone().launch(
                    learn_cfg,
                    (
                        &self.seg_cell_id,
                        &self.seg_syn_count,
                        &self.syn_presyn,
                        &mut self.syn_perm,
                        &self.seg_num_active_pot,
                        &self.prev_active_bits,
                        sp_active_mask,
                        &self.cell_seg_count,
                        cfg_val,
                    ),
                )?;
            }

            // 6. Grow.
            let grow_cfg = LaunchConfig {
                grid_dim: (n_cols as u32, 1, 1),
                block_dim: (32, 1, 1),
                shared_mem_bytes: 0,
            };
            unsafe {
                grow_fn.clone().launch(
                    grow_cfg,
                    (
                        &mut self.seg_cell_id,
                        &mut self.seg_syn_count,
                        &mut self.syn_presyn,
                        &mut self.syn_perm,
                        &mut self.cell_seg_count,
                        &self.burst_cols_flat,
                        &self.burst_cols_count,
                        &self.prev_winner_bits,
                        &self.prev_active_bits,
                        &self.col_best_match,
                        cfg_val,
                    ),
                )?;
            }
        }

        Ok(())
    }
}
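tm_gpu.rs quantizes permanences to i16 with a full scale of 32767 and sizes the slab from the 4/20 constants. A quick Python mirror of the quantizer and of the capacity arithmetic in the header comment (host-side sketch; the helper is illustrative, not part of the crate):

    PERM_SCALE = 32767  # [0, 32767] represents permanence [0.0, 1.0]

    def perm_f32_to_i16(x: float) -> int:
        """Mirror of the Rust quantizer: clamp to [0, 1], scale, round."""
        return round(min(max(x, 0.0), 1.0) * PERM_SCALE)

    n_cells        = 2048 * 32            # 65_536 cells
    n_segments_max = n_cells * 4          # 262_144 segments (MAX_SEGMENTS_PER_CELL)
    n_synapses_max = n_segments_max * 20  # 5_242_880 synapses (MAX_SYN_PER_SEGMENT)
    mb = 1024 * 1024
    print(perm_f32_to_i16(0.50))          # conn_thr_i16 -> 16384
    print(n_synapses_max * 4 / mb)        # syn_presyn (u32): 20 MB
    print(n_synapses_max * 2 / mb)        # syn_perm   (i16): 10 MB

The two big arrays dominate, which is consistent with the "~50 MB per region" figure once the per-segment and bitset buffers are added.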
overlay/htm_rust/uv.lock
CHANGED

@@ -1,8 +1,8 @@
(Both sides of this hunk are identical in this view; content shown once.)

version = 1
revision = 3
requires-python = ">=3.11"

[[package]]
name = "htm-rust"
version = "0.1.0"
source = { editable = "." }
overlay/hydra/config.py
CHANGED

@@ -110,8 +110,8 @@ class PostSemClawConfig:
     gdn_layers: tuple[int, ...] = field(default_factory=_parse_gdn_layers_env)

     # Label smoothing + Z-loss
-    label_smoothing: float =
-    z_loss_weight: float =
+    label_smoothing: float = 0.0  # disabled: any smoothing hurts in the 5-min budget
+    z_loss_weight: float = 1e-4


 # ---------------------------------------------------------------------------
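For context on the two knobs above: label smoothing is fully off, while the z-loss penalizes the squared log-partition of the logits so their scale cannot drift. A hypothetical sketch of how such a term combines with cross-entropy (illustrative only; HYDRA's actual loss code is not shown in this diff, and the function name is an assumption):

    import torch
    import torch.nn.functional as F

    def lm_loss(logits: torch.Tensor, targets: torch.Tensor,
                z_loss_weight: float = 1e-4) -> torch.Tensor:
        # logits: (N, vocab), targets: (N,). label_smoothing=0.0 per the config.
        ce = F.cross_entropy(logits, targets, label_smoothing=0.0)
        log_z = torch.logsumexp(logits, dim=-1)  # log partition per position
        return ce + z_loss_weight * (log_z ** 2).mean()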
overlay/hydra/engram.py
CHANGED

@@ -1,23 +1,93 @@
-"""GPU Engram: [rest of the old module docstring is not recoverable in this view]

-[old lines 3-5: docstring body not recoverable in this view]

 import torch
 import torch.nn as nn


-[old lines 11-20: class header and old class docstring; the deleted text is not recoverable in this view]
 """

     def __init__(
@@ -31,15 +101,20 @@ class GPUEngram(nn.Module):
         self.n_columns = n_columns
         self.max_ngram = max_ngram
         self.hebbian_boost = hebbian_boost
         self.memory = nn.Parameter(torch.randn(n_columns, d_model) * 0.01)
         self.gate = nn.Linear(d_model, 1, bias=True)
-        nn.init.constant_(self.gate.bias, 0.0)
-
         self.primes = [2654435761, 2246822519, 3266489917]
         self.hebbian_lr = 0.01
-        [old line 40 not recoverable in this view]

     def _hash(self, token_ids: torch.Tensor) -> torch.Tensor:
         B, T = token_ids.shape
         h = token_ids * self.primes[0]
         if T > 1:
@@ -52,103 +127,44 @@ class GPUEngram(nn.Module):
         h = h ^ (shifted2 * self.primes[2])
         return h % self.n_columns

-    [old lines 55-57: head of the _validate_active_indices check not recoverable in this view]
-        else:
-            raise ValueError("Engram Cantor/SDR routing expects compact active indices, not a dense SDR mask")
-        if sdr_active_indices.dim() not in (2, 3):
-            raise ValueError("compact active indices must have shape (B,T,K) or (B*T,K)")
-        # Dense SDR masks arrive with K ~= n_bits; compact buffers are small
-        # (retina target_active or RealityBridge l0_k). Refuse obviously dense
-        # masks so forced cantor_sdr cannot silently route 0/1 values as offsets.
-        if sdr_active_indices.shape[-1] > 1024 or sdr_active_indices.shape[-1] > self.n_columns:
-            raise ValueError("Engram Cantor/SDR routing expects compact active indices, not a dense SDR mask")
-
-    def _cantor_sdr_candidates(
-        self,
-        sdr_active_indices: torch.Tensor,
-        cantor_leaf_ids: torch.Tensor,
-        n_leaves: int,
-    ) -> torch.Tensor:
-        """Map SDR active offsets into each Cantor leaf's Engram column shard."""
-        self._validate_active_indices(sdr_active_indices, cantor_leaf_ids)
-        if sdr_active_indices.dim() == 2:
-            B, T = cantor_leaf_ids.shape
-            sdr_active_indices = sdr_active_indices.view(B, T, -1)
-        sdr = sdr_active_indices.to(device=cantor_leaf_ids.device, dtype=torch.long)
-        leaves = cantor_leaf_ids.to(dtype=torch.long).clamp(min=0, max=max(0, n_leaves - 1))
-        cols_per_leaf = max(1, self.n_columns // max(1, n_leaves))
-        offsets = sdr.remainder(cols_per_leaf)
-        base = leaves.unsqueeze(-1) * cols_per_leaf
-        return (base + offsets).clamp(max=self.n_columns - 1)
-
-    def _flat_retrieve(self, x: torch.Tensor) -> torch.Tensor:
-        scores = x @ self.memory.T
-        topk_vals, topk_idx = scores.topk(self.topk_k, dim=-1)
-        topk_w = torch.softmax(topk_vals, dim=-1)
-        selected_mem = self.memory[topk_idx]
-        return torch.einsum('btk,btkd->btd', topk_w, selected_mem)

-    def _cantor_sdr_retrieve(
-        self,
-        x: torch.Tensor,
-        sdr_active_indices: torch.Tensor,
-        cantor_leaf_ids: torch.Tensor,
-        cantor_n_leaves: int,
-    ) -> torch.Tensor:
-        candidates = self._cantor_sdr_candidates(
-            sdr_active_indices,
-            cantor_leaf_ids,
-            n_leaves=cantor_n_leaves,
-        )
-        cand_mem = self.memory[candidates]
-        scores = torch.einsum('btd,btkd->btk', x, cand_mem)
-        k = min(self.topk_k, scores.shape[-1])
-        topk_vals, local_idx = scores.topk(k, dim=-1)
-        topk_w = torch.softmax(topk_vals, dim=-1)
-        global_idx = candidates.gather(-1, local_idx)
-        selected_mem = self.memory[global_idx]
-        return torch.einsum('btk,btkd->btd', topk_w, selected_mem)

-    [old lines 114-136: forward(...) signature and mode/k_active setup not recoverable in this view]
-        # spent tens of billions of extra MACs per forward.
-        use_cantor = k_active < self.n_columns
-
-        if use_cantor and mode in {"cantor_sdr", "auto"}:
-            retrieved = self._cantor_sdr_retrieve(x, sdr_active_indices, cantor_leaf_ids, cantor_n_leaves)
-        else:
-            retrieved = self._flat_retrieve(x)
-
-        alpha = torch.sigmoid(self.gate(x))

         if self.training and self.hebbian_boost:
             with torch.no_grad():
                 indices = self._hash(token_ids)
-                flat_idx = indices.reshape(-1)
-                flat_x = x.detach().reshape(-1,
                 mem_dtype = self.memory.data.dtype
                 updates = (
                     self.hebbian_lr * flat_x
@@ -156,5 +172,6 @@ class GPUEngram(nn.Module):
                 ).to(mem_dtype)
                 self.memory.data.index_add_(0, flat_idx, updates)

         hit_rate = (alpha.detach() > 0.1).float().mean()
         return x + alpha * retrieved, hit_rate
|
+"""GPU Engram - Sparse Modern Hopfield retrieval path.
+
+## What changed (scatter-gather -> Hopfield matmul)
+
+The original forward used `self.memory[indices]` (scatter-gather), which misses
+L2 cache at n_columns > 4096 and creates a hard tps ceiling.
+
+The replacement uses:
+    scores    = x @ self.memory.T        # (B, T, n_columns) - coalesced matmul
+    weights   = entmax15(scores, dim=-1) # sparse attention; 95%+ exact zeros
+    retrieved = weights @ self.memory    # (B, T, d_model)   - coalesced matmul
+
+Both matmuls are tile-friendly (cuBLAS GEMM), so L2 reuse is high regardless of
+n_columns. Gradient flows through both matmuls so `self.memory` learns via
+autograd in addition to (or instead of) the Hebbian EMA writes.
+
+## Sparsity mechanism
+
+alpha-entmax with alpha=1.5 (entmax15) is a sparse attention operator that maps
+logit vectors to distributions where many entries are *exactly* zero (not merely
+small). It generalises softmax (alpha=1) and argmax (alpha -> inf). At n_columns=1024
+with d_model=64 a random batch typically hits >=95% zero entries - the key
+property that keeps bandwidth proportional to *attended* columns, not all columns.
+
+Fallback: if `entmax` is not pip-installed, top-k softmax (k=32) is used instead.
+This is chosen at module-import time - NO runtime branching per forward call.
+
+## token_ids argument
+
+token_ids is accepted for API compatibility with the rest of the hydra stack
+(train.py, lightning_module.py call `engram(x, token_ids)`). It is NOT used in
+the retrieval path - the Hopfield path computes dense similarity over the whole
+memory bank, which subsumes any hash-based column selection. Documented here to
+prevent confusion.
+
+## Hebbian writes (hebbian_boost=False by default)
+
+With Hopfield retrieval, gradient signals reach self.memory through autograd, so
+Hebbian EMA writes are no longer critical. They are preserved as an *optional*
+boost (hebbian_boost=True) for experiments that want both signals. Default is off.
+
+## Checkpoint compatibility
+
+`self.memory` shape (n_columns, d_model) is unchanged, so existing .pt / .ckpt
+files load without modification.
+"""
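A quick empirical check of the >=95% claim in the docstring above; a sketch assuming the `entmax` package is installed (the exact zero fraction varies with the logit scale):

    import torch
    from entmax import entmax15

    scores = torch.randn(8, 16, 1024)            # (B, T, n_columns) random logits
    weights = entmax15(scores, dim=-1)

    zero_frac = (weights == 0).float().mean().item()
    print(f"exact-zero fraction: {zero_frac:.3f}")              # typically > 0.9
    print(torch.allclose(weights.sum(-1), torch.ones(8, 16)))   # rows still sum to 1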
+
+from __future__ import annotations
+
 import torch
 import torch.nn as nn

+# ---------------------------------------------------------------------------
+# Sparse-attention backend - chosen ONCE at import time, no runtime branching.
+# ---------------------------------------------------------------------------
+
+try:
+    from entmax import entmax15 as _entmax15  # type: ignore[import]
+
+    def _sparse_attention(scores: torch.Tensor) -> torch.Tensor:
+        """alpha-entmax (alpha=1.5): truly sparse distribution over last dim."""
+        return _entmax15(scores, dim=-1).to(dtype=scores.dtype)
+
+    _BACKEND = "entmax15"
+
+except ImportError:  # pragma: no cover - entmax always installed in CI
+    _K = 32  # top-k for fallback
+
+    def _sparse_attention(scores: torch.Tensor) -> torch.Tensor:  # type: ignore[misc]
+        """Top-k softmax fallback: zero outside the k highest-scoring columns."""
+        topk_vals, topk_idx = scores.topk(_K, dim=-1)
+        topk_w = torch.softmax(topk_vals, dim=-1)
+        weights = torch.zeros_like(scores)
+        weights.scatter_(-1, topk_idx, topk_w.to(dtype=weights.dtype))
+        return weights
+
+    _BACKEND = "topk32"
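Whichever branch wins at import, `_sparse_attention` keeps the same contract: a proper distribution over columns with bounded support. A small sanity sketch (shapes arbitrary):

    scores = torch.randn(2, 4, 1024)
    w = _sparse_attention(scores)
    assert w.shape == scores.shape
    assert torch.allclose(w.sum(-1), torch.ones(2, 4))
    # Under the fallback, at most _K = 32 columns per position carry weight;
    # under entmax15 the support size is data-dependent but typically small.
    print(_BACKEND, int((w > 0).sum(-1).max()))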
+
+
+class GPUEngram(nn.Module):
+    """GPU Engram: Sparse Modern Hopfield retrieval.
+
+    Args:
+        d_model:       Model dimension - must match the surrounding transformer.
+        n_columns:     Number of memory columns (key-value pairs). Safe at 32 768
+                       with the matmul path; the old scatter-gather had an L2
+                       cliff above ~4 096.
+        max_ngram:     Retained for API compatibility; unused in retrieval path.
+        hebbian_boost: If True, also run a Hebbian EMA write on the memory bank
+                       during training (old behaviour, now optional). Default False.
     """

     def __init__(
 (unchanged __init__ signature, lines 94-100, elided in this view)
         self.n_columns = n_columns
         self.max_ngram = max_ngram
         self.hebbian_boost = hebbian_boost
+        # Shape unchanged from original - existing checkpoints load cleanly.
         self.memory = nn.Parameter(torch.randn(n_columns, d_model) * 0.01)
         self.gate = nn.Linear(d_model, 1, bias=True)
+        nn.init.constant_(self.gate.bias, 0.0)  # START OPEN
+        # Retained for any external code that reads these attrs.
         self.primes = [2654435761, 2246822519, 3266489917]
         self.hebbian_lr = 0.01
+
+    # ------------------------------------------------------------------
+    # _hash: retained for API/checkpoint compat; unused in forward below.
+    # ------------------------------------------------------------------

     def _hash(self, token_ids: torch.Tensor) -> torch.Tensor:
+        """N-gram hash -> column index (kept for backward-compat; not used in retrieval)."""
         B, T = token_ids.shape
         h = token_ids * self.primes[0]
         if T > 1:
 (unchanged hash-mixing body, lines 121-126, elided in this view)
             h = h ^ (shifted2 * self.primes[2])
         return h % self.n_columns

+    # ------------------------------------------------------------------
+    # forward
+    # ------------------------------------------------------------------
+    def forward(self, x: torch.Tensor, token_ids: torch.Tensor):
+        """Hopfield retrieve + soft gate + residual.
+
+        Args:
+            x:         (B, T, d_model) - input activations.
+            token_ids: (B, T) - token indices. Accepted for API compatibility;
+                       NOT used in the retrieval path (see module docstring).
+
+        Returns:
+            (x + alpha * retrieved, hit_rate)
+            - x + alpha * retrieved: (B, T, d_model)
+            - hit_rate: scalar tensor - fraction of gate values > 0.1
+        """
+        # ---- 1. Similarity scores (coalesced GEMM) ----------------------
+        # scores[b, t, c] = dot(x[b,t], memory[c])
+        scores = x @ self.memory.T  # (B, T, n_columns)
+
+        # ---- 2. Sparse attention weights --------------------------------
+        # _sparse_attention is fixed at import time (entmax15 or top-k).
+        weights = _sparse_attention(scores)  # (B, T, n_columns), many exact zeros
+
+        # ---- 3. Retrieved vector (coalesced GEMM) -----------------------
+        retrieved = weights @ self.memory  # (B, T, d_model)
+
+        # ---- 4. Soft gate (unchanged) -----------------------------------
+        alpha = torch.sigmoid(self.gate(x))  # (B, T, 1)

+        # ---- 5. Optional Hebbian EMA write ------------------------------
         if self.training and self.hebbian_boost:
             with torch.no_grad():
+                # Reuse the hash-based indices for the write target (sparse update).
                 indices = self._hash(token_ids)
+                flat_idx = indices.reshape(-1)  # (B*T,)
+                flat_x = x.detach().reshape(-1, x.shape[-1])  # (B*T, d_model)
                 mem_dtype = self.memory.data.dtype
                 updates = (
                     self.hebbian_lr * flat_x
 (one unchanged line of the update expression, line 171, elided in this view)
                 ).to(mem_dtype)
                 self.memory.data.index_add_(0, flat_idx, updates)

+        # ---- 6. Residual + hit_rate -------------------------------------
         hit_rate = (alpha.detach() > 0.1).float().mean()
         return x + alpha * retrieved, hit_rate
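A smoke test of the new forward contract, assuming the elided `__init__` signature accepts these keyword arguments (it assigns the matching attributes above):

    import torch

    engram = GPUEngram(d_model=64, n_columns=1024, max_ngram=3, hebbian_boost=False)
    x = torch.randn(2, 16, 64)                     # (B, T, d_model)
    token_ids = torch.randint(0, 32_000, (2, 16))  # API-compat only; unused in retrieval
    out, hit_rate = engram(x, token_ids)
    assert out.shape == x.shape                    # residual keeps the input shape
    print(float(hit_rate))                         # fraction of gate values > 0.1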
overlay/hydra/model.py
CHANGED

@@ -469,6 +469,7 @@ class PostSemClawModel(nn.Module):
         # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
         # dtypes in the same shape group would break lerp_ dtype checks.
         self.wte.to(dtype=torch.bfloat16)
+        self.blocks.to(dtype=torch.bfloat16)
         self.htm_proj.to(dtype=torch.bfloat16)
         self.sdr_proj.to(dtype=torch.bfloat16)
         self.engram.to(dtype=torch.bfloat16)
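Why the missing cast mattered: per the comment, Muon buckets parameters by shape, and PyTorch's in-place `lerp_` refuses to write a promoted float32 result back into a bf16 tensor. A minimal standalone illustration of that failure mode (not project code):

    import torch

    w_bf16 = torch.zeros(4, 4, dtype=torch.bfloat16)  # e.g. a cast block weight
    m_fp32 = torch.zeros(4, 4, dtype=torch.float32)   # an uncast one, same shape
    try:
        w_bf16.lerp_(m_fp32, 0.1)  # promotes to float32; in-place write-back fails
    except RuntimeError as e:
        print("lerp_ dtype check:", e)

Casting `self.blocks` alongside the other submodules keeps every same-shape group dtype-uniform.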
overlay/scripts/autoresearch.py
CHANGED

@@ -1,517 +1,517 @@
(The file was rewritten wholesale; the removed and re-added bodies are character-identical in this extracted view, likely a whitespace or line-ending normalization, so the content is listed once.)

#!/usr/bin/env python3
"""HYDRA Autoresearch Mutation Loop.

Runs baseline training -> evaluates -> picks ONE mutation at a time ->
trains -> evaluates -> keeps if quality improves AND tps >= floor.
Repeats until all mutations exhausted or Ctrl+C.

State persisted in .omc/autoresearch_config.json for resume support.

Usage:
    python scripts/autoresearch.py             # run full loop
    python scripts/autoresearch.py --dry-run   # show plan, don't train
    python scripts/autoresearch.py --baseline  # only run baseline eval
"""

from __future__ import annotations

import argparse
import json
import math
import os
import re
import signal
import subprocess
import sys
import time
from pathlib import Path

_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

# ---------------------------------------------------------------------------
# Mutation catalog (ordered by expected impact)
# ---------------------------------------------------------------------------

MUTATIONS = [
    # Learning dynamics - env vars verified in hydra/config.py
    {"name": "lr_matrix_0.012", "env": "HYDRA_MATRIX_LR=0.012"},  # default 0.12
    {"name": "lr_matrix_0.06", "env": "HYDRA_MATRIX_LR=0.06"},    # half default
    {"name": "lr_matrix_0.24", "env": "HYDRA_MATRIX_LR=0.24"},    # double default
    {"name": "lr_floor_50pct", "env": "HYDRA_LR_MIN_MULT=0.5"},   # default 0.0
    {"name": "lr_floor_20pct", "env": "HYDRA_LR_MIN_MULT=0.2"},   # default 0.0
    {"name": "embed_lr_0.5", "env": "HYDRA_EMBED_LR=0.5"},        # default 1.0
    {"name": "embed_lr_2.0", "env": "HYDRA_EMBED_LR=2.0"},        # default 1.0
    {"name": "unembed_lr_0.01", "env": "HYDRA_UNEMBED_LR=0.01"},  # default 0.005
    # Architecture - env vars verified in hydra/config.py
    {"name": "d_model_384", "env": "HYDRA_D_MODEL=384"},    # default 256
    {"name": "d_model_192", "env": "HYDRA_D_MODEL=192"},    # smaller
    {"name": "d_state_128", "env": "HYDRA_D_STATE=128"},    # default 64
    {"name": "d_state_32", "env": "HYDRA_D_STATE=32"},      # smaller
    {"name": "n_layer_6", "env": "HYDRA_N_LAYER=6"},        # default 4
    {"name": "n_layer_3", "env": "HYDRA_N_LAYER=3"},        # fewer
    {"name": "headdim_16", "env": "HYDRA_HEADDIM=16"},      # default 32 -> more heads
    {"name": "headdim_64", "env": "HYDRA_HEADDIM=64"},      # default 32 -> fewer heads
    {"name": "expand_3", "env": "HYDRA_EXPAND=3"},          # default 2
    {"name": "engram_2048", "env": "HYDRA_ENGRAM_N_COLUMNS=2048"},  # default 1024
    {"name": "engram_4096", "env": "HYDRA_ENGRAM_N_COLUMNS=4096"},  # default 1024
    {"name": "engram_512", "env": "HYDRA_ENGRAM_N_COLUMNS=512"},    # smaller
    # Batch size
    {"name": "batch_32k", "env": "HYDRA_TOTAL_BATCH=32768"},  # default 32768 (verify)
    {"name": "batch_16k", "env": "HYDRA_TOTAL_BATCH=16384"},  # smaller batch
    {"name": "batch_65k", "env": "HYDRA_TOTAL_BATCH=65536"},  # larger batch
    # Regularization - env vars verified in hydra/model.py + hydra/config.py
    {"name": "dropout_0.05", "env": "HYDRA_DROPOUT=0.05"},  # default 0.2
    {"name": "dropout_0.1", "env": "HYDRA_DROPOUT=0.1"},    # default 0.2
    {"name": "dropout_0.3", "env": "HYDRA_DROPOUT=0.3"},    # higher
]

# ---------------------------------------------------------------------------
# State management
# ---------------------------------------------------------------------------

STATE_DIR = os.path.join(_PROJECT_ROOT, ".omc")
STATE_FILE = os.path.join(STATE_DIR, "autoresearch_config.json")

DEFAULT_STATE = {
    "baseline_quality": None,
    "baseline_tps": None,
    "current_gen": 0,
    "mutations_tested": [],
    "mutations_kept": [],
    "tps_floor": 62000,
    "time_budget": 600,
    "history": [],
}


def load_state() -> dict:
    """Load state from disk or return default."""
    if os.path.exists(STATE_FILE):
        with open(STATE_FILE, "r") as f:
            state = json.load(f)
        # Backfill missing keys from defaults
        for k, v in DEFAULT_STATE.items():
            if k not in state:
                state[k] = v
        return state
    return dict(DEFAULT_STATE)


def save_state(state: dict) -> None:
    """Persist state to disk."""
    os.makedirs(STATE_DIR, exist_ok=True)
    with open(STATE_FILE, "w") as f:
        json.dump(state, f, indent=2)


# ---------------------------------------------------------------------------
# Training subprocess
# ---------------------------------------------------------------------------

def build_env(extra_env: str | None = None) -> dict[str, str]:
    """Build environment for training subprocess."""
    env = os.environ.copy()
    # Ensure CUDA paths
    ld_paths = ["/usr/lib/wsl/lib", "/usr/local/cuda/lib64"]
    existing = env.get("LD_LIBRARY_PATH", "")
    for p in ld_paths:
        if p not in existing:
            existing = p + ":" + existing
    env["LD_LIBRARY_PATH"] = existing

    # Apply mutation env var
    if extra_env:
        key, val = extra_env.split("=", 1)
        env[key] = val

    return env


def run_training(time_budget: int, extra_env: str | None = None) -> dict | None:
    """Run train.py with given time budget and optional env override.

    Returns dict with parsed metrics, or None on failure.
    """
    env = build_env(extra_env)
    env["HYDRA_TIME_BUDGET"] = str(time_budget)

    cmd = [os.path.join(_PROJECT_ROOT, ".venv", "bin", "python"), "-u", "train.py"]

    try:
        proc = subprocess.Popen(
            cmd,
            cwd=_PROJECT_ROOT,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        )
    except Exception as e:
        print(f"  [ERROR] Failed to start training: {e}")
        return None

    output_lines: list[str] = []
    last_step_line = ""

    try:
        for line in proc.stdout:
            line = line.rstrip()
            output_lines.append(line)
            if line.startswith("step="):
                last_step_line = line
                # Print progress every 50 steps
                m = re.search(r"step=(\d+)", line)
                if m and int(m.group(1)) % 50 == 0:
                    tps_m = re.search(r"tps=(\d+)", line)
                    bpb_m = re.search(r"bpb=([\d.]+)", line)
                    tps = tps_m.group(1) if tps_m else "?"
                    bpb = bpb_m.group(1) if bpb_m else "?"
                    print(f"  step={m.group(1)} tps={tps} bpb={bpb}", flush=True)
            elif "val_bpb" in line or "factual_english_score" in line:
                print(f"  {line}", flush=True)
    except KeyboardInterrupt:
        proc.terminate()
        proc.wait()
        raise

    proc.wait()
    if proc.returncode != 0:
        print(f"  [ERROR] Training exited with code {proc.returncode}")
        # Print last 10 lines for debugging
        for line in output_lines[-10:]:
            print(f"    {line}")
        return None

    return _parse_training_output(output_lines)


def _parse_training_output(lines: list[str]) -> dict:
    """Extract metrics from training output lines."""
    metrics: dict[str, float] = {}

    for line in lines:
        # Key=value pairs from summary block
        for key in ["val_bpb", "training_seconds", "peak_vram_mb", "mfu_percent",
                    "total_tokens_M", "num_steps", "factual_english_score",
                    "factual_english_hits"]:
            m = re.match(rf"^{key}:\s+([\d.]+)", line.strip())
            if m:
                metrics[key] = float(m.group(1))

        # TPS from last step line
        if line.startswith("step="):
            tps_m = re.search(r"tps=(\d+)", line)
            if tps_m:
                metrics["tps"] = float(tps_m.group(1))

    return metrics


# ---------------------------------------------------------------------------
# Eval integration
# ---------------------------------------------------------------------------

def run_eval_after_training(extra_env: str | None = None) -> dict | None:
    """Run eval_quality.py after training. Returns metrics dict or None."""
    env = build_env(extra_env)
    cmd = [
        os.path.join(_PROJECT_ROOT, ".venv", "bin", "python"),
        os.path.join(_PROJECT_ROOT, "scripts", "eval_quality.py"),
    ]

    try:
        result = subprocess.run(
            cmd,
            cwd=_PROJECT_ROOT,
            env=env,
            capture_output=True,
            text=True,
            timeout=120,  # 2 min max for eval
        )
    except subprocess.TimeoutExpired:
        print("  [ERROR] Eval timed out (120s)")
        return None
    except Exception as e:
        print(f"  [ERROR] Eval failed: {e}")
        return None

    if result.returncode != 0:
        print(f"  [ERROR] Eval exited with code {result.returncode}")
        for line in result.stdout.split("\n")[-10:]:
            print(f"    {line}")
        for line in result.stderr.split("\n")[-5:]:
            print(f"    {line}")
        return None

    # Parse key=value output
    metrics = {}
    for line in result.stdout.split("\n"):
        line = line.strip()
        m = re.match(r"^([\w]+)=([\d.eE+-]+)$", line)
        if m:
            try:
                metrics[m.group(1)] = float(m.group(2))
            except ValueError:
                pass

    return metrics if metrics else None


# ---------------------------------------------------------------------------
# Git operations
# ---------------------------------------------------------------------------

def git_commit(message: str) -> bool:
    """Stage all changes and commit."""
    try:
        subprocess.run(["git", "add", "-A"], cwd=_PROJECT_ROOT, check=True,
                       capture_output=True, timeout=30)
        subprocess.run(
            ["git", "commit", "-m", message],
            cwd=_PROJECT_ROOT, check=True, capture_output=True, timeout=30,
        )
        return True
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
        print(f"  [WARN] Git commit failed: {e}")
        return False


# ---------------------------------------------------------------------------
# Main loop
# ---------------------------------------------------------------------------

_SHUTDOWN = False


def _handle_sigint(signum, frame):
    global _SHUTDOWN
    if _SHUTDOWN:
        print("\n[AUTORESEARCH] Double Ctrl+C - force exit")
        sys.exit(1)
    _SHUTDOWN = True
    print("\n[AUTORESEARCH] Ctrl+C received - finishing current gen then saving state...")


def main():
    global _SHUTDOWN
    signal.signal(signal.SIGINT, _handle_sigint)

    parser = argparse.ArgumentParser(description="HYDRA autoresearch mutation loop")
    parser.add_argument("--dry-run", action="store_true", help="Show plan, don't train")
    parser.add_argument("--baseline", action="store_true", help="Only run baseline")
    parser.add_argument("--time-budget", type=int, default=600, help="Time budget per run (s)")
    parser.add_argument("--tps-floor", type=int, default=62000, help="Minimum acceptable TPS")
    args = parser.parse_args()

    state = load_state()
    state["time_budget"] = args.time_budget
    state["tps_floor"] = args.tps_floor

    tested = set(state["mutations_tested"])
    remaining = [m for m in MUTATIONS if m["name"] not in tested]

    print("=" * 70)
    print("HYDRA AUTORESEARCH MUTATION LOOP")
    print("=" * 70)
    print(f"Time budget per run: {state['time_budget']}s")
    print(f"TPS floor: {state['tps_floor']}")
    print(f"Current gen: {state['current_gen']}")
    print(f"Mutations tested: {len(tested)}/{len(MUTATIONS)}")
    print(f"Mutations kept: {state['mutations_kept']}")
    print(f"Remaining: {[m['name'] for m in remaining]}")
    print()

    if args.dry_run:
        print("[DRY RUN] Would test these mutations in order:")
        for i, m in enumerate(remaining):
            print(f"  {i + 1}. {m['name']} ({m['env']})")
        return

    # -----------------------------------------------------------------------
    # Baseline (Gen 0)
    # -----------------------------------------------------------------------
    if state["baseline_quality"] is None:
        print("[GEN 0] Running baseline training + evaluation...")
        train_metrics = run_training(state["time_budget"])
        if train_metrics is None:
            print("[FAIL] Baseline training failed")
            save_state(state)
            return

        print("[GEN 0] Running quality evaluation...")
        eval_metrics = run_eval_after_training()
        if eval_metrics is None:
            print("[FAIL] Baseline eval failed")
            save_state(state)
            return

        baseline_tps = train_metrics.get("tps", 0)
        baseline_quality = eval_metrics.get("quality_score", 0)

        state["baseline_quality"] = baseline_quality
        state["baseline_tps"] = baseline_tps
        state["current_gen"] = 0
        state["history"].append({
            "gen": 0,
            "mutation": "baseline",
            "quality_score": baseline_quality,
            "baseline_score": baseline_quality,
            "delta": "0.0%",
            "tps": baseline_tps,
            "ppl": eval_metrics.get("ppl", 0),
            "bleu4": eval_metrics.get("bleu4", 0),
            "rouge_l": eval_metrics.get("rouge_l", 0),
            "factual": eval_metrics.get("factual", 0),
            "bpb": eval_metrics.get("bpb", 0),
            "repetition_rate": eval_metrics.get("repetition_rate", 0),
            "kept": True,
        })
        save_state(state)
        print(f"[GEN 0] BASELINE: quality={baseline_quality:.4f} tps={baseline_tps:.0f}")

        if args.baseline:
            return
    else:
        print(f"[RESUME] Baseline quality={state['baseline_quality']:.4f} tps={state['baseline_tps']:.0f}")
        if args.baseline:
            return

    # -----------------------------------------------------------------------
    # Mutation loop
    # -----------------------------------------------------------------------
    current_quality = state["baseline_quality"]
    # Track best quality so far (from last kept mutation, not just baseline)
    if state["history"]:
        kept_entries = [h for h in state["history"] if h.get("kept")]
        if kept_entries:
            current_quality = kept_entries[-1]["quality_score"]

    for mutation in remaining:
        if _SHUTDOWN:
            print("[AUTORESEARCH] Shutdown requested - saving state")
            save_state(state)
            return

        gen = state["current_gen"] + 1
        name = mutation["name"]
        env_str = mutation["env"]

        print(f"\n[GEN {gen}] Testing {name} ({env_str})...")
        print(f"  Current best quality: {current_quality:.4f}")

        # Train with mutation
        print(f"  Training ({state['time_budget']}s)...", flush=True)
        train_metrics = run_training(state["time_budget"], extra_env=env_str)
        if train_metrics is None:
            print(f"  [SKIP] Training failed for {name}")
            state["mutations_tested"].append(name)
            state["current_gen"] = gen
            state["history"].append({
                "gen": gen, "mutation": name,
                "quality_score": 0, "baseline_score": current_quality,
                "delta": "FAIL", "tps": 0, "ppl": 0, "bleu4": 0,
                "rouge_l": 0, "factual": 0, "bpb": 0, "repetition_rate": 0,
                "kept": False,
            })
            save_state(state)
            continue

        tps = train_metrics.get("tps", 0)

        # TPS floor check
        if tps < state["tps_floor"]:
            print(f"  [REJECT] TPS={tps:.0f} < floor={state['tps_floor']} - skipping eval")
            state["mutations_tested"].append(name)
            state["current_gen"] = gen
            state["history"].append({
                "gen": gen, "mutation": name,
                "quality_score": 0, "baseline_score": current_quality,
                "delta": f"TPS_FAIL({tps:.0f})", "tps": tps,
                "ppl": 0, "bleu4": 0, "rouge_l": 0, "factual": 0,
                "bpb": train_metrics.get("val_bpb", 0), "repetition_rate": 0,
                "kept": False,
            })
            save_state(state)
            continue

        # Evaluate
        print(f"  Evaluating...", flush=True)
        eval_metrics = run_eval_after_training(extra_env=env_str)
        if eval_metrics is None:
            print(f"  [SKIP] Eval failed for {name}")
            state["mutations_tested"].append(name)
            state["current_gen"] = gen
            state["history"].append({
                "gen": gen, "mutation": name,
                "quality_score": 0, "baseline_score": current_quality,
                "delta": "EVAL_FAIL", "tps": tps, "ppl": 0, "bleu4": 0,
                "rouge_l": 0, "factual": 0, "bpb": 0, "repetition_rate": 0,
                "kept": False,
            })
            save_state(state)
            continue

        quality = eval_metrics.get("quality_score", 0)
        delta_pct = ((quality - current_quality) / max(abs(current_quality), 1e-6)) * 100
        delta_str = f"{delta_pct:+.1f}%"

        kept = quality > current_quality and tps >= state["tps_floor"]
        status = "KEEP" if kept else "DISCARD"

        entry = {
            "gen": gen,
            "mutation": name,
            "quality_score": quality,
            "baseline_score": current_quality,
            "delta": delta_str,
            "tps": tps,
            "ppl": eval_metrics.get("ppl", 0),
            "bleu4": eval_metrics.get("bleu4", 0),
            "rouge_l": eval_metrics.get("rouge_l", 0),
            "factual": eval_metrics.get("factual", 0),
            "bpb": eval_metrics.get("bpb", 0),
            "repetition_rate": eval_metrics.get("repetition_rate", 0),
            "kept": kept,
        }

        print(f"\n[GEN {gen}] {name}: quality={quality:.4f} ({delta_str}) tps={tps:.0f} -> {status}")

        if kept:
            current_quality = quality
            state["mutations_kept"].append(name)
            git_commit(f"autoresearch: gen {gen} - {name} quality {delta_str}")

        state["mutations_tested"].append(name)
        state["current_gen"] = gen
        state["history"].append(entry)
        save_state(state)

    # -----------------------------------------------------------------------
    # Summary
    # -----------------------------------------------------------------------
    print("\n" + "=" * 70)
    print("AUTORESEARCH COMPLETE")
    print("=" * 70)
    print(f"Total generations: {state['current_gen']}")
    print(f"Mutations kept: {state['mutations_kept']}")
    print(f"Final quality: {current_quality:.4f}")
    if state["baseline_quality"]:
        total_delta = ((current_quality - state["baseline_quality"]) /
                       max(abs(state["baseline_quality"]), 1e-6)) * 100
        print(f"Total improvement: {total_delta:+.1f}%")
    print()

    # Print history table
    print(f"{'Gen':>4} {'Mutation':>20} {'Quality':>8} {'Delta':>8} {'TPS':>7} {'PPL':>8} {'BPB':>7} {'Kept':>5}")
    print("-" * 75)
    for h in state["history"]:
        print(f"{h['gen']:4d} {h['mutation']:>20s} {h['quality_score']:8.4f} "
              f"{h['delta']:>8s} {h['tps']:7.0f} {h['ppl']:8.2f} "
              f"{h.get('bpb', 0):7.4f} {' YES' if h['kept'] else '  NO'}")


if __name__ == "__main__":
    main()
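The acceptance rule buried in `main()` is worth stating on its own; a standalone restatement of the same predicate (the `should_keep` helper is illustrative, not part of the script):

    def should_keep(quality: float, best_quality: float, tps: float, tps_floor: float) -> bool:
        """Mirror of main()'s keep/discard test: strictly better quality AND tps at or above the floor."""
        return quality > best_quality and tps >= tps_floor

    assert should_keep(0.51, 0.50, 70_000, 62_000)      # better and fast enough -> KEEP
    assert not should_keep(0.50, 0.50, 70_000, 62_000)  # ties are discarded
    assert not should_keep(0.60, 0.50, 61_999, 62_000)  # quality gains cannot buy back lost tps

Note the floor is also enforced before evaluation (the TPS_FAIL branch), so a slow mutation never spends the 120 s eval budget.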
overlay/scripts/chat.py
CHANGED

@@ -1,458 +1,458 @@
-"""Interactive chat REPL for HYDRA.
-
-Usage:
-    python scripts/chat.py               # auto-select best checkpoint
-    python scripts/chat.py --ckpt PATH   # explicit checkpoint
-    python scripts/chat.py --sft         # prefer sft_final.pt
-    python scripts/chat.py --random      # skip ckpt, use random weights
-
-HONESTY: model is ~7.5M params at d_model=256/n_layer=4. Expect incoherent
-output. This REPL validates the *interface* - tokenizer roundtrip, generation
-loop, stop-token handling, conversation history truncation. Coherent dialogue
-is not a goal at this scale.
-
-Slash commands:
-    /reset      clear conversation history
-    /quit       exit
-    /temp X     set temperature (default 0.8)
-    /topk K     set top-k (default 40)
-    /topp P     set top-p (default 0.9)
-    /max N      set max new tokens per turn (default 200)
-    /rep R      set repetition penalty (default 1.1)
-    /sys S      set a system prefix prepended to every turn
-    /info       print current settings + checkpoint path
-"""
-
-from __future__ import annotations
-
-import argparse
-import os
-import sys
-import time
-from dataclasses import asdict
-from pathlib import Path
-
-# Make repo root importable when invoked as `python scripts/chat.py`.
-_REPO_ROOT = Path(__file__).resolve().parent.parent
-if str(_REPO_ROOT) not in sys.path:
-    sys.path.insert(0, str(_REPO_ROOT))
-
-import torch  # noqa: E402
-
-# Chat template - plain-text fallback (see .omc/chat_plan.md).
-# If the SFT agent later reserves special tokens, redefine USER_TAG /
-# ASSISTANT_TAG / END_TAG and the stop-string accordingly.
-USER_TAG = "User:"
-ASSISTANT_TAG = "Assistant:"
-END_TAG = "\nUser:"  # stop-string matched on decoded output
-
-CKPT_DIR = Path(os.path.expanduser("~/.cache/autoresearch/ckpts"))
-CKPT_CANDIDATES_PRETRAIN = ["pretrain_final.pt", "latest.pt"]
-CKPT_CANDIDATES_SFT = ["sft_final.pt"]
-
-
-# ---------------------------------------------------------------------------
-# Checkpoint resolution
-# ---------------------------------------------------------------------------
-
-def resolve_checkpoint(explicit: str | None, prefer_sft: bool) -> Path | None:
-    """Return Path to checkpoint file, or None if nothing found.
-
-    Order:
-      1. `explicit` if provided and exists.
-      2. If prefer_sft: sft_final.pt -> pretrain_final.pt -> latest.pt.
-      3. Else: sft_final.pt (if exists) -> pretrain_final.pt -> latest.pt.
-    """
-    if explicit:
-        p = Path(os.path.expanduser(explicit))
-        if p.exists():
-            return p
-        print(f"[WARN] --ckpt {p} does not exist; falling through to auto-select.", file=sys.stderr)
-
-    # Task spec: prefer sft_final.pt if it exists; otherwise pretrain_final.pt
-    # then latest.pt. --sft just makes the preference explicit; it's already
-    # the default behavior. We list SFT first in both orderings to honor the
-    # spec, since the task description said "prefer sft if exists" by default.
-    _ = prefer_sft  # reserved for future "pretrain-only" vs "sft-only" modes
-    order = CKPT_CANDIDATES_SFT + CKPT_CANDIDATES_PRETRAIN
-    for name in order:
-        cand = CKPT_DIR / name
-        if cand.exists():
-            return cand
-    return None
-
-
-# ---------------------------------------------------------------------------
-# Model + tokenizer loading
-# ---------------------------------------------------------------------------
-
-def load_model_and_tokenizer(ckpt_path: Path | None, device: torch.device):
-    """Build model + tokenizer. If ckpt_path is None, random weights are used.
-
-    Returns (model, tokenizer, meta) where meta is a dict with 'ckpt',
-    'step', 'val_bpb' etc. for /info display.
-    """
-    from hydra.config import PostSemClawConfig
-    from hydra.model import PostSemClawModel
-    from prepare import Tokenizer
-
-    tokenizer = Tokenizer.from_directory()
-    vocab_size = tokenizer.get_vocab_size()
-    print(f"[chat] Tokenizer loaded (vocab={vocab_size:,})")
-
-    meta: dict = {"ckpt": str(ckpt_path) if ckpt_path else "<random>", "step": None, "val_bpb": None}
-
-    # Build config. If checkpoint provides one, use it; else use env-var defaults.
-    ckpt_state = None
-    config_kwargs: dict = {}
-    if ckpt_path is not None:
-        print(f"[chat] Loading checkpoint: {ckpt_path}")
-        ckpt_state = torch.load(ckpt_path, map_location=device, weights_only=False)
-        cfg_dict = ckpt_state.get("config")
-        if isinstance(cfg_dict, dict):
-            # Filter to kwargs PostSemClawConfig actually accepts.
-            allowed = set(PostSemClawConfig.__dataclass_fields__.keys())
-            config_kwargs = {k: v for k, v in cfg_dict.items() if k in allowed}
-        meta["step"] = ckpt_state.get("step")
-        meta["val_bpb"] = ckpt_state.get("val_bpb") or ckpt_state.get("bpb")
-
-    # Env-var defaults are applied by PostSemClawConfig field defaults; but the
-    # training run builds the config explicitly from hydra.config module-level
-    # constants. We mirror that here so the random-weights path aligns with
-    # what train.py would instantiate for the same env.
-    if not config_kwargs:
-        from hydra.config import (  # noqa: E402
-            D_MODEL, D_STATE, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX,
-            ENGRAM_N_COLUMNS, EXPAND, HEADDIM, N_HEADS, N_LAYER,
-        )
-        from prepare import MAX_SEQ_LEN  # noqa: E402
-        config_kwargs = dict(
-            sequence_len=MAX_SEQ_LEN,
-            vocab_size=vocab_size,
-            n_layer=N_LAYER,
-            d_model=D_MODEL,
-            d_state=D_STATE,
-            headdim=HEADDIM,
-            n_heads=N_HEADS,
-            expand=EXPAND,
-            engram_n_columns=ENGRAM_N_COLUMNS,
-            engram_key_dim=ENGRAM_KEY_DIM,
-            engram_layer_idx=ENGRAM_LAYER_IDX,
-        )
-
-    # Build model on meta device then materialize - matches training.py path.
-    with torch.device("meta"):
-        model = PostSemClawModel(PostSemClawConfig(**config_kwargs))
-    model.to_empty(device=device)
-    model.init_weights()
-
-    if ckpt_state is not None and "model_state_dict" in ckpt_state:
-        # strict=False: the model has non-parameter buffers (SDR retina loaded
-        # from npz, HTM Rust-side state, engram EMA stats) that may not be in
|
| 152 |
-
# the state_dict. missing/unexpected-key warnings are expected and OK.
|
| 153 |
-
missing, unexpected = model.load_state_dict(
|
| 154 |
-
ckpt_state["model_state_dict"], strict=False
|
| 155 |
-
)
|
| 156 |
-
if missing:
|
| 157 |
-
print(f"[chat] Note: {len(missing)} missing key(s) in state_dict (expected for HTM/SDR buffers).")
|
| 158 |
-
if unexpected:
|
| 159 |
-
print(f"[chat] Note: {len(unexpected)} unexpected key(s) in state_dict.")
|
| 160 |
-
elif ckpt_path is None:
|
| 161 |
-
print("[chat] [WARN] NO CHECKPOINT β using random weights. Output will be gibberish.", file=sys.stderr)
|
| 162 |
-
|
| 163 |
-
model.eval()
|
| 164 |
-
return model, tokenizer, meta
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
# ---------------------------------------------------------------------------
|
| 168 |
-
# Generation
|
| 169 |
-
# ---------------------------------------------------------------------------
|
| 170 |
-
|
| 171 |
-
def generate_stream(
|
| 172 |
-
model,
|
| 173 |
-
tokenizer,
|
| 174 |
-
prompt_ids: list[int],
|
| 175 |
-
*,
|
| 176 |
-
max_new_tokens: int,
|
| 177 |
-
temperature: float,
|
| 178 |
-
top_k: int,
|
| 179 |
-
top_p: float,
|
| 180 |
-
repetition_penalty: float,
|
| 181 |
-
stop_strings: tuple[str, ...],
|
| 182 |
-
max_seq_len: int,
|
| 183 |
-
device: torch.device,
|
| 184 |
-
rep_window: int = 64,
|
| 185 |
-
):
|
| 186 |
-
"""Yield decoded-text chunks as tokens are generated.
|
| 187 |
-
|
| 188 |
-
Truncates `prompt_ids` to the last `max_seq_len` tokens if needed. Stops
|
| 189 |
-
early when any `stop_strings` substring appears in the newly-decoded
|
| 190 |
-
continuation.
|
| 191 |
-
"""
|
| 192 |
-
from scripts.sample_utils import sample_token
|
| 193 |
-
|
| 194 |
-
# Truncate prompt to window.
|
| 195 |
-
if len(prompt_ids) > max_seq_len:
|
| 196 |
-
prompt_ids = prompt_ids[-max_seq_len:]
|
| 197 |
-
|
| 198 |
-
ctx = torch.tensor([prompt_ids], device=device, dtype=torch.long)
|
| 199 |
-
generated: list[int] = []
|
| 200 |
-
# Track already-streamed byte length so we can detect when the decoded
|
| 201 |
-
# string has grown (BPE tokens may decode to multi-char strings mid-merge).
|
| 202 |
-
streamed_chars = 0
|
| 203 |
-
accumulated_text = ""
|
| 204 |
-
|
| 205 |
-
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 206 |
-
|
| 207 |
-
for _ in range(max_new_tokens):
|
| 208 |
-
with torch.no_grad(), autocast_ctx:
|
| 209 |
-
out = model(ctx, targets=None)
|
| 210 |
-
# out shape: (1, T, vocab) or (1, vocab) depending on path.
|
| 211 |
-
if out.dim() == 3:
|
| 212 |
-
last_logits = out[0, -1, :]
|
| 213 |
-
else:
|
| 214 |
-
last_logits = out[0]
|
| 215 |
-
|
| 216 |
-
recent = generated[-rep_window:] if generated else None
|
| 217 |
-
next_id = sample_token(
|
| 218 |
-
last_logits,
|
| 219 |
-
temperature=temperature,
|
| 220 |
-
top_k=top_k,
|
| 221 |
-
top_p=top_p,
|
| 222 |
-
repetition_penalty=repetition_penalty,
|
| 223 |
-
recent_tokens=recent,
|
| 224 |
-
)
|
| 225 |
-
generated.append(next_id)
|
| 226 |
-
|
| 227 |
-
# Decode everything so-far then diff β BPE decoding is not token-local,
|
| 228 |
-
# so a per-token decode can drop bytes.
|
| 229 |
-
new_text = tokenizer.decode(generated)
|
| 230 |
-
delta = new_text[streamed_chars:]
|
| 231 |
-
if delta:
|
| 232 |
-
streamed_chars = len(new_text)
|
| 233 |
-
accumulated_text = new_text
|
| 234 |
-
yield delta
|
| 235 |
-
|
| 236 |
-
# Stop-string check.
|
| 237 |
-
hit_stop = any(s and s in accumulated_text for s in stop_strings)
|
| 238 |
-
if hit_stop:
|
| 239 |
-
break
|
| 240 |
-
|
| 241 |
-
# Advance context. If we've filled the window, drop oldest token.
|
| 242 |
-
ctx = torch.cat([ctx, torch.tensor([[next_id]], device=device, dtype=torch.long)], dim=1)
|
| 243 |
-
if ctx.size(1) > max_seq_len:
|
| 244 |
-
ctx = ctx[:, -max_seq_len:]
|
| 245 |
-
|
| 246 |
-
# Final accumulated text is also returned for history tracking.
|
| 247 |
-
return accumulated_text # noqa: B901 (generator return for history)
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
def _consume_stream_with_print(stream_gen):
|
| 251 |
-
"""Iterate a generator, print each chunk, return the full text.
|
| 252 |
-
|
| 253 |
-
Replacement for a naΓ―ve list(stream) since `generate_stream` is a generator
|
| 254 |
-
that yields then returns the final text.
|
| 255 |
-
"""
|
| 256 |
-
collected = []
|
| 257 |
-
try:
|
| 258 |
-
while True:
|
| 259 |
-
chunk = next(stream_gen)
|
| 260 |
-
collected.append(chunk)
|
| 261 |
-
sys.stdout.write(chunk)
|
| 262 |
-
sys.stdout.flush()
|
| 263 |
-
except StopIteration as stop:
|
| 264 |
-
# stop.value holds the return value of the generator.
|
| 265 |
-
final = stop.value
|
| 266 |
-
if final is not None:
|
| 267 |
-
return final
|
| 268 |
-
return "".join(collected)
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
# ---------------------------------------------------------------------------
|
| 272 |
-
# REPL
|
| 273 |
-
# ---------------------------------------------------------------------------
|
| 274 |
-
|
| 275 |
-
def build_prompt(system: str, history: list[tuple[str, str]], user_msg: str) -> str:
|
| 276 |
-
"""Assemble the text prompt fed to the tokenizer."""
|
| 277 |
-
parts: list[str] = []
|
| 278 |
-
if system:
|
| 279 |
-
parts.append(system.rstrip() + "\n")
|
| 280 |
-
for u, a in history:
|
| 281 |
-
parts.append(f"{USER_TAG} {u}\n{ASSISTANT_TAG} {a}\n")
|
| 282 |
-
parts.append(f"{USER_TAG} {user_msg}\n{ASSISTANT_TAG}")
|
| 283 |
-
return "".join(parts)
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
def run_repl(
|
| 287 |
-
model,
|
| 288 |
-
tokenizer,
|
| 289 |
-
meta: dict,
|
| 290 |
-
*,
|
| 291 |
-
device: torch.device,
|
| 292 |
-
max_seq_len: int,
|
| 293 |
-
) -> None:
|
| 294 |
-
settings = {
|
| 295 |
-
"temperature": float(os.environ.get("HYDRA_CHAT_TEMP", "0.8")),
|
| 296 |
-
"top_k": int(os.environ.get("HYDRA_CHAT_TOPK", "40")),
|
| 297 |
-
"top_p": float(os.environ.get("HYDRA_CHAT_TOPP", "0.9")),
|
| 298 |
-
"max_new_tokens": int(os.environ.get("HYDRA_CHAT_MAX", "200")),
|
| 299 |
-
"repetition_penalty": float(os.environ.get("HYDRA_CHAT_REP", "1.1")),
|
| 300 |
-
"system": os.environ.get("HYDRA_CHAT_SYSTEM", ""),
|
| 301 |
-
}
|
| 302 |
-
history: list[tuple[str, str]] = []
|
| 303 |
-
|
| 304 |
-
print()
|
| 305 |
-
print("=" * 60)
|
| 306 |
-
print("HYDRA chat REPL")
|
| 307 |
-
print(f" checkpoint: {meta['ckpt']}")
|
| 308 |
-
if meta.get("step") is not None:
|
| 309 |
-
print(f" step: {meta['step']}")
|
| 310 |
-
if meta.get("val_bpb") is not None:
|
| 311 |
-
print(f" val_bpb: {meta['val_bpb']}")
|
| 312 |
-
print(" type /info for settings, /quit to exit")
|
| 313 |
-
print("=" * 60)
|
| 314 |
-
print()
|
| 315 |
-
|
| 316 |
-
while True:
|
| 317 |
-
try:
|
| 318 |
-
line = input(f"{USER_TAG} ")
|
| 319 |
-
except (EOFError, KeyboardInterrupt):
|
| 320 |
-
print()
|
| 321 |
-
return
|
| 322 |
-
|
| 323 |
-
line = line.rstrip()
|
| 324 |
-
if not line:
|
| 325 |
-
continue
|
| 326 |
-
|
| 327 |
-
if line.startswith("/"):
|
| 328 |
-
cmd, *rest = line.split(maxsplit=1)
|
| 329 |
-
arg = rest[0] if rest else ""
|
| 330 |
-
if cmd == "/quit" or cmd == "/exit":
|
| 331 |
-
return
|
| 332 |
-
elif cmd == "/reset":
|
| 333 |
-
history = []
|
| 334 |
-
print("[reset]")
|
| 335 |
-
continue
|
| 336 |
-
elif cmd == "/info":
|
| 337 |
-
print(f"[info] ckpt={meta['ckpt']} settings={settings} history_turns={len(history)}")
|
| 338 |
-
continue
|
| 339 |
-
elif cmd == "/temp":
|
| 340 |
-
try:
|
| 341 |
-
settings["temperature"] = float(arg)
|
| 342 |
-
print(f"[temp={settings['temperature']}]")
|
| 343 |
-
except ValueError:
|
| 344 |
-
print(f"[err] /temp needs a float, got {arg!r}")
|
| 345 |
-
continue
|
| 346 |
-
elif cmd == "/topk":
|
| 347 |
-
try:
|
| 348 |
-
settings["top_k"] = int(arg)
|
| 349 |
-
print(f"[topk={settings['top_k']}]")
|
| 350 |
-
except ValueError:
|
| 351 |
-
print(f"[err] /topk needs an int, got {arg!r}")
|
| 352 |
-
continue
|
| 353 |
-
elif cmd == "/topp":
|
| 354 |
-
try:
|
| 355 |
-
settings["top_p"] = float(arg)
|
| 356 |
-
print(f"[topp={settings['top_p']}]")
|
| 357 |
-
except ValueError:
|
| 358 |
-
print(f"[err] /topp needs a float, got {arg!r}")
|
| 359 |
-
continue
|
| 360 |
-
elif cmd == "/max":
|
| 361 |
-
try:
|
| 362 |
-
settings["max_new_tokens"] = int(arg)
|
| 363 |
-
print(f"[max={settings['max_new_tokens']}]")
|
| 364 |
-
except ValueError:
|
| 365 |
-
print(f"[err] /max needs an int, got {arg!r}")
|
| 366 |
-
continue
|
| 367 |
-
elif cmd == "/rep":
|
| 368 |
-
try:
|
| 369 |
-
settings["repetition_penalty"] = float(arg)
|
| 370 |
-
print(f"[rep={settings['repetition_penalty']}]")
|
| 371 |
-
except ValueError:
|
| 372 |
-
print(f"[err] /rep needs a float, got {arg!r}")
|
| 373 |
-
continue
|
| 374 |
-
elif cmd == "/sys":
|
| 375 |
-
settings["system"] = arg
|
| 376 |
-
print(f"[sys set, {len(arg)} chars]")
|
| 377 |
-
continue
|
| 378 |
-
else:
|
| 379 |
-
print(f"[err] unknown command {cmd!r}. Try /info /reset /quit.")
|
| 380 |
-
continue
|
| 381 |
-
|
| 382 |
-
# Normal chat turn.
|
| 383 |
-
prompt_text = build_prompt(settings["system"], history, line)
|
| 384 |
-
prompt_ids = tokenizer.encode(prompt_text)
|
| 385 |
-
|
| 386 |
-
sys.stdout.write(f"{ASSISTANT_TAG} ")
|
| 387 |
-
sys.stdout.flush()
|
| 388 |
-
|
| 389 |
-
stream = generate_stream(
|
| 390 |
-
model, tokenizer, prompt_ids,
|
| 391 |
-
max_new_tokens=settings["max_new_tokens"],
|
| 392 |
-
temperature=settings["temperature"],
|
| 393 |
-
top_k=settings["top_k"],
|
| 394 |
-
top_p=settings["top_p"],
|
| 395 |
-
repetition_penalty=settings["repetition_penalty"],
|
| 396 |
-
stop_strings=(END_TAG,),
|
| 397 |
-
max_seq_len=max_seq_len,
|
| 398 |
-
device=device,
|
| 399 |
-
)
|
| 400 |
-
response_text = _consume_stream_with_print(stream)
|
| 401 |
-
if not response_text.endswith("\n"):
|
| 402 |
-
sys.stdout.write("\n")
|
| 403 |
-
sys.stdout.flush()
|
| 404 |
-
|
| 405 |
-
# Strip trailing stop marker from the remembered history.
|
| 406 |
-
clean = response_text
|
| 407 |
-
if END_TAG in clean:
|
| 408 |
-
clean = clean.split(END_TAG, 1)[0]
|
| 409 |
-
clean = clean.strip()
|
| 410 |
-
history.append((line, clean))
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
# ---------------------------------------------------------------------------
|
| 414 |
-
# CLI
|
| 415 |
-
# ---------------------------------------------------------------------------
|
| 416 |
-
|
| 417 |
-
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
| 418 |
-
p = argparse.ArgumentParser(description="HYDRA chat REPL")
|
| 419 |
-
p.add_argument("--ckpt", type=str, default=None,
|
| 420 |
-
help="Path to checkpoint (.pt). If omitted, auto-select.")
|
| 421 |
-
p.add_argument("--sft", action="store_true",
|
| 422 |
-
help="Prefer an SFT checkpoint if available.")
|
| 423 |
-
p.add_argument("--random", action="store_true",
|
| 424 |
-
help="Skip checkpoint load; use random weights.")
|
| 425 |
-
p.add_argument("--device", type=str, default=None,
|
| 426 |
-
help="Torch device (default: cuda if available else cpu).")
|
| 427 |
-
return p.parse_args(argv)
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
def main(argv: list[str] | None = None) -> int:
|
| 431 |
-
args = _parse_args(argv)
|
| 432 |
-
|
| 433 |
-
if args.device:
|
| 434 |
-
device = torch.device(args.device)
|
| 435 |
-
elif torch.cuda.is_available():
|
| 436 |
-
device = torch.device("cuda")
|
| 437 |
-
else:
|
| 438 |
-
device = torch.device("cpu")
|
| 439 |
-
print("[chat] [WARN] CUDA not available; HYDRA's HTM/Mamba kernels may fail on CPU.", file=sys.stderr)
|
| 440 |
-
|
| 441 |
-
ckpt_path: Path | None
|
| 442 |
-
if args.random:
|
| 443 |
-
ckpt_path = None
|
| 444 |
-
else:
|
| 445 |
-
ckpt_path = resolve_checkpoint(args.ckpt, args.sft)
|
| 446 |
-
|
| 447 |
-
t0 = time.time()
|
| 448 |
-
model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
|
| 449 |
-
dt = time.time() - t0
|
| 450 |
-
print(f"[chat] Model ready in {dt:.1f}s on {device}")
|
| 451 |
-
|
| 452 |
-
from prepare import MAX_SEQ_LEN
|
| 453 |
-
run_repl(model, tokenizer, meta, device=device, max_seq_len=MAX_SEQ_LEN)
|
| 454 |
-
return 0
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
if __name__ == "__main__":
|
| 458 |
-
sys.exit(main())
|
|
|
|
"""Interactive chat REPL for HYDRA.

Usage:
    python scripts/chat.py                 # auto-select best checkpoint
    python scripts/chat.py --ckpt PATH     # explicit checkpoint
    python scripts/chat.py --sft           # prefer sft_final.pt
    python scripts/chat.py --random        # skip ckpt, use random weights

HONESTY: model is ~7.5M params at d_model=256/n_layer=4. Expect incoherent
output. This REPL validates the *interface* -- tokenizer roundtrip, generation
loop, stop-token handling, conversation history truncation. Coherent dialogue
is not a goal at this scale.

Slash commands:
    /reset    clear conversation history
    /quit     exit
    /temp X   set temperature (default 0.8)
    /topk K   set top-k (default 40)
    /topp P   set top-p (default 0.9)
    /max N    set max new tokens per turn (default 200)
    /rep R    set repetition penalty (default 1.1)
    /sys S    set a system prefix prepended to every turn
    /info     print current settings + checkpoint path
"""

from __future__ import annotations

import argparse
import os
import sys
import time
from dataclasses import asdict
from pathlib import Path

# Make repo root importable when invoked as `python scripts/chat.py`.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

import torch  # noqa: E402

# Chat template -- plain-text fallback (see .omc/chat_plan.md).
# If the SFT agent later reserves special tokens, redefine USER_TAG /
# ASSISTANT_TAG / END_TAG and the stop-string accordingly.
USER_TAG = "User:"
ASSISTANT_TAG = "Assistant:"
END_TAG = "\nUser:"  # stop-string matched on decoded output

CKPT_DIR = Path(os.path.expanduser("~/.cache/autoresearch/ckpts"))
CKPT_CANDIDATES_PRETRAIN = ["pretrain_final.pt", "latest.pt"]
CKPT_CANDIDATES_SFT = ["sft_final.pt"]


# ---------------------------------------------------------------------------
# Checkpoint resolution
# ---------------------------------------------------------------------------

def resolve_checkpoint(explicit: str | None, prefer_sft: bool) -> Path | None:
    """Return Path to checkpoint file, or None if nothing found.

    Order:
    1. `explicit` if provided and exists.
    2. If prefer_sft: sft_final.pt -> pretrain_final.pt -> latest.pt.
    3. Else: sft_final.pt (if exists) -> pretrain_final.pt -> latest.pt.
    """
    if explicit:
        p = Path(os.path.expanduser(explicit))
        if p.exists():
            return p
        print(f"[WARN] --ckpt {p} does not exist; falling through to auto-select.", file=sys.stderr)

    # Task spec: prefer sft_final.pt if it exists; otherwise pretrain_final.pt
    # then latest.pt. --sft just makes the preference explicit; it's already
    # the default behavior. We list SFT first in both orderings to honor the
    # spec, since the task description said "prefer sft if exists" by default.
    _ = prefer_sft  # reserved for future "pretrain-only" vs "sft-only" modes
    order = CKPT_CANDIDATES_SFT + CKPT_CANDIDATES_PRETRAIN
    for name in order:
        cand = CKPT_DIR / name
        if cand.exists():
            return cand
    return None


# ---------------------------------------------------------------------------
# Model + tokenizer loading
# ---------------------------------------------------------------------------

def load_model_and_tokenizer(ckpt_path: Path | None, device: torch.device):
    """Build model + tokenizer. If ckpt_path is None, random weights are used.

    Returns (model, tokenizer, meta) where meta is a dict with 'ckpt',
    'step', 'val_bpb' etc. for /info display.
    """
    from hydra.config import PostSemClawConfig
    from hydra.model import PostSemClawModel
    from prepare import Tokenizer

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    print(f"[chat] Tokenizer loaded (vocab={vocab_size:,})")

    meta: dict = {"ckpt": str(ckpt_path) if ckpt_path else "<random>", "step": None, "val_bpb": None}

    # Build config. If checkpoint provides one, use it; else use env-var defaults.
    ckpt_state = None
    config_kwargs: dict = {}
    if ckpt_path is not None:
        print(f"[chat] Loading checkpoint: {ckpt_path}")
        ckpt_state = torch.load(ckpt_path, map_location=device, weights_only=False)
        cfg_dict = ckpt_state.get("config")
        if isinstance(cfg_dict, dict):
            # Filter to kwargs PostSemClawConfig actually accepts.
            allowed = set(PostSemClawConfig.__dataclass_fields__.keys())
            config_kwargs = {k: v for k, v in cfg_dict.items() if k in allowed}
        meta["step"] = ckpt_state.get("step")
        meta["val_bpb"] = ckpt_state.get("val_bpb") or ckpt_state.get("bpb")

    # Env-var defaults are applied by PostSemClawConfig field defaults; but the
    # training run builds the config explicitly from hydra.config module-level
    # constants. We mirror that here so the random-weights path aligns with
    # what train.py would instantiate for the same env.
    if not config_kwargs:
        from hydra.config import (  # noqa: E402
            D_MODEL, D_STATE, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX,
            ENGRAM_N_COLUMNS, EXPAND, HEADDIM, N_HEADS, N_LAYER,
        )
        from prepare import MAX_SEQ_LEN  # noqa: E402
        config_kwargs = dict(
            sequence_len=MAX_SEQ_LEN,
            vocab_size=vocab_size,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            d_state=D_STATE,
            headdim=HEADDIM,
            n_heads=N_HEADS,
            expand=EXPAND,
            engram_n_columns=ENGRAM_N_COLUMNS,
            engram_key_dim=ENGRAM_KEY_DIM,
            engram_layer_idx=ENGRAM_LAYER_IDX,
        )

    # Build model on meta device then materialize -- matches training.py path.
    with torch.device("meta"):
        model = PostSemClawModel(PostSemClawConfig(**config_kwargs))
    model.to_empty(device=device)
    model.init_weights()

    if ckpt_state is not None and "model_state_dict" in ckpt_state:
        # strict=False: the model has non-parameter buffers (SDR retina loaded
        # from npz, HTM Rust-side state, engram EMA stats) that may not be in
        # the state_dict. missing/unexpected-key warnings are expected and OK.
        missing, unexpected = model.load_state_dict(
            ckpt_state["model_state_dict"], strict=False
        )
        if missing:
            print(f"[chat] Note: {len(missing)} missing key(s) in state_dict (expected for HTM/SDR buffers).")
        if unexpected:
            print(f"[chat] Note: {len(unexpected)} unexpected key(s) in state_dict.")
    elif ckpt_path is None:
        print("[chat] [WARN] NO CHECKPOINT -- using random weights. Output will be gibberish.", file=sys.stderr)

    model.eval()
    return model, tokenizer, meta
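
def _meta_init_sketch() -> None:
    # [Illustrative sketch -- not part of the original file.] Shows the
    # meta-device pattern used in load_model_and_tokenizer: build the module
    # with shape-only parameters, then materialize storage on a real device.
    # `_ToyNet` is a hypothetical stand-in for PostSemClawModel.
    import torch.nn as nn

    class _ToyNet(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc = nn.Linear(16, 4)

    with torch.device("meta"):   # parameters allocated shape-only, no storage
        net = _ToyNet()
    net.to_empty(device="cpu")   # give parameters real (uninitialized) storage
    for p in net.parameters():
        nn.init.normal_(p, std=0.02)  # stands in for model.init_weights()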

# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------

def generate_stream(
    model,
    tokenizer,
    prompt_ids: list[int],
    *,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    stop_strings: tuple[str, ...],
    max_seq_len: int,
    device: torch.device,
    rep_window: int = 64,
):
    """Yield decoded-text chunks as tokens are generated.

    Truncates `prompt_ids` to the last `max_seq_len` tokens if needed. Stops
    early when any `stop_strings` substring appears in the newly-decoded
    continuation.
    """
    from scripts.sample_utils import sample_token

    # Truncate prompt to window.
    if len(prompt_ids) > max_seq_len:
        prompt_ids = prompt_ids[-max_seq_len:]

    ctx = torch.tensor([prompt_ids], device=device, dtype=torch.long)
    generated: list[int] = []
    # Track already-streamed byte length so we can detect when the decoded
    # string has grown (BPE tokens may decode to multi-char strings mid-merge).
    streamed_chars = 0
    accumulated_text = ""

    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

    for _ in range(max_new_tokens):
        with torch.no_grad(), autocast_ctx:
            out = model(ctx, targets=None)
        # out shape: (1, T, vocab) or (1, vocab) depending on path.
        if out.dim() == 3:
            last_logits = out[0, -1, :]
        else:
            last_logits = out[0]

        recent = generated[-rep_window:] if generated else None
        next_id = sample_token(
            last_logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            recent_tokens=recent,
        )
        generated.append(next_id)

        # Decode everything so-far then diff -- BPE decoding is not token-local,
        # so a per-token decode can drop bytes.
        new_text = tokenizer.decode(generated)
        delta = new_text[streamed_chars:]
        if delta:
            streamed_chars = len(new_text)
            accumulated_text = new_text
            yield delta

        # Stop-string check.
        hit_stop = any(s and s in accumulated_text for s in stop_strings)
        if hit_stop:
            break

        # Advance context. If we've filled the window, drop oldest token.
        ctx = torch.cat([ctx, torch.tensor([[next_id]], device=device, dtype=torch.long)], dim=1)
        if ctx.size(1) > max_seq_len:
            ctx = ctx[:, -max_seq_len:]

    # Final accumulated text is also returned for history tracking.
    return accumulated_text  # noqa: B901 (generator return for history)
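
def _decode_diff_sketch() -> None:
    # [Illustrative sketch -- not part of the original file.] Why the loop
    # above re-decodes the whole `generated` list and emits only the suffix:
    # with byte-level BPE a token can end mid UTF-8 sequence, so decoding one
    # token at a time can drop bytes. `decode` is a hypothetical stand-in for
    # tokenizer.decode, modeled here as raw UTF-8 bytes.
    def decode(ids: list[int]) -> str:
        return bytes(ids).decode("utf-8", errors="ignore")

    tokens = list("é".encode("utf-8"))  # two bytes, i.e. two "tokens"
    per_token = "".join(decode([t]) for t in tokens)
    assert per_token == ""              # each lone byte is invalid UTF-8

    ids: list[int] = []
    emitted, out = 0, []
    for t in tokens:                    # decode-then-diff, as in the loop above
        ids.append(t)
        text = decode(ids)
        if len(text) > emitted:
            out.append(text[emitted:])
            emitted = len(text)
    assert "".join(out) == "é"          # nothing lost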

def _consume_stream_with_print(stream_gen):
    """Iterate a generator, print each chunk, return the full text.

    Replacement for a naïve list(stream) since `generate_stream` is a generator
    that yields then returns the final text.
    """
    collected = []
    try:
        while True:
            chunk = next(stream_gen)
            collected.append(chunk)
            sys.stdout.write(chunk)
            sys.stdout.flush()
    except StopIteration as stop:
        # stop.value holds the return value of the generator.
        final = stop.value
        if final is not None:
            return final
    return "".join(collected)
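
def _stop_value_sketch() -> None:
    # [Illustrative sketch -- not part of the original file.] The helper above
    # relies on PEP 380 semantics: a generator's `return` value is delivered
    # as StopIteration.value when iteration ends.
    def gen():
        yield "a"
        yield "b"
        return "ab"                     # surfaces as StopIteration.value

    g = gen()
    chunks = []
    try:
        while True:
            chunks.append(next(g))
    except StopIteration as stop:
        assert stop.value == "ab"       # the generator's return value
    assert "".join(chunks) == "ab"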

# ---------------------------------------------------------------------------
# REPL
# ---------------------------------------------------------------------------

def build_prompt(system: str, history: list[tuple[str, str]], user_msg: str) -> str:
    """Assemble the text prompt fed to the tokenizer."""
    parts: list[str] = []
    if system:
        parts.append(system.rstrip() + "\n")
    for u, a in history:
        parts.append(f"{USER_TAG} {u}\n{ASSISTANT_TAG} {a}\n")
    parts.append(f"{USER_TAG} {user_msg}\n{ASSISTANT_TAG}")
    return "".join(parts)
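
# Worked example of the template (derived from build_prompt above, not new
# behavior): for system="", history=[("Hi", "Hello!")], user_msg="How are you?"
# the assembled prompt is:
#
#   User: Hi
#   Assistant: Hello!
#   User: How are you?
#   Assistant:
#
# Generation continues after the trailing "Assistant:", and END_TAG
# ("\nUser:") fires once the model starts inventing the next user turn.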

def run_repl(
    model,
    tokenizer,
    meta: dict,
    *,
    device: torch.device,
    max_seq_len: int,
) -> None:
    settings = {
        "temperature": float(os.environ.get("HYDRA_CHAT_TEMP", "0.8")),
        "top_k": int(os.environ.get("HYDRA_CHAT_TOPK", "40")),
        "top_p": float(os.environ.get("HYDRA_CHAT_TOPP", "0.9")),
        "max_new_tokens": int(os.environ.get("HYDRA_CHAT_MAX", "200")),
        "repetition_penalty": float(os.environ.get("HYDRA_CHAT_REP", "1.1")),
        "system": os.environ.get("HYDRA_CHAT_SYSTEM", ""),
    }
    history: list[tuple[str, str]] = []

    print()
    print("=" * 60)
    print("HYDRA chat REPL")
    print(f"  checkpoint: {meta['ckpt']}")
    if meta.get("step") is not None:
        print(f"  step: {meta['step']}")
    if meta.get("val_bpb") is not None:
        print(f"  val_bpb: {meta['val_bpb']}")
    print("  type /info for settings, /quit to exit")
    print("=" * 60)
    print()

    while True:
        try:
            line = input(f"{USER_TAG} ")
        except (EOFError, KeyboardInterrupt):
            print()
            return

        line = line.rstrip()
        if not line:
            continue

        if line.startswith("/"):
            cmd, *rest = line.split(maxsplit=1)
            arg = rest[0] if rest else ""
            if cmd == "/quit" or cmd == "/exit":
                return
            elif cmd == "/reset":
                history = []
                print("[reset]")
                continue
            elif cmd == "/info":
                print(f"[info] ckpt={meta['ckpt']} settings={settings} history_turns={len(history)}")
                continue
            elif cmd == "/temp":
                try:
                    settings["temperature"] = float(arg)
                    print(f"[temp={settings['temperature']}]")
                except ValueError:
                    print(f"[err] /temp needs a float, got {arg!r}")
                continue
            elif cmd == "/topk":
                try:
                    settings["top_k"] = int(arg)
                    print(f"[topk={settings['top_k']}]")
                except ValueError:
                    print(f"[err] /topk needs an int, got {arg!r}")
                continue
            elif cmd == "/topp":
                try:
                    settings["top_p"] = float(arg)
                    print(f"[topp={settings['top_p']}]")
                except ValueError:
                    print(f"[err] /topp needs a float, got {arg!r}")
                continue
            elif cmd == "/max":
                try:
                    settings["max_new_tokens"] = int(arg)
                    print(f"[max={settings['max_new_tokens']}]")
                except ValueError:
                    print(f"[err] /max needs an int, got {arg!r}")
                continue
            elif cmd == "/rep":
                try:
                    settings["repetition_penalty"] = float(arg)
                    print(f"[rep={settings['repetition_penalty']}]")
                except ValueError:
                    print(f"[err] /rep needs a float, got {arg!r}")
                continue
            elif cmd == "/sys":
                settings["system"] = arg
                print(f"[sys set, {len(arg)} chars]")
                continue
            else:
                print(f"[err] unknown command {cmd!r}. Try /info /reset /quit.")
                continue

        # Normal chat turn.
        prompt_text = build_prompt(settings["system"], history, line)
        prompt_ids = tokenizer.encode(prompt_text)

        sys.stdout.write(f"{ASSISTANT_TAG} ")
        sys.stdout.flush()

        stream = generate_stream(
            model, tokenizer, prompt_ids,
            max_new_tokens=settings["max_new_tokens"],
            temperature=settings["temperature"],
            top_k=settings["top_k"],
            top_p=settings["top_p"],
            repetition_penalty=settings["repetition_penalty"],
            stop_strings=(END_TAG,),
            max_seq_len=max_seq_len,
            device=device,
        )
        response_text = _consume_stream_with_print(stream)
        if not response_text.endswith("\n"):
            sys.stdout.write("\n")
        sys.stdout.flush()

        # Strip trailing stop marker from the remembered history.
        clean = response_text
        if END_TAG in clean:
            clean = clean.split(END_TAG, 1)[0]
        clean = clean.strip()
        history.append((line, clean))


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    p = argparse.ArgumentParser(description="HYDRA chat REPL")
    p.add_argument("--ckpt", type=str, default=None,
                   help="Path to checkpoint (.pt). If omitted, auto-select.")
    p.add_argument("--sft", action="store_true",
                   help="Prefer an SFT checkpoint if available.")
    p.add_argument("--random", action="store_true",
                   help="Skip checkpoint load; use random weights.")
    p.add_argument("--device", type=str, default=None,
                   help="Torch device (default: cuda if available else cpu).")
    return p.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    args = _parse_args(argv)

    if args.device:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        print("[chat] [WARN] CUDA not available; HYDRA's HTM/Mamba kernels may fail on CPU.", file=sys.stderr)

    ckpt_path: Path | None
    if args.random:
        ckpt_path = None
    else:
        ckpt_path = resolve_checkpoint(args.ckpt, args.sft)

    t0 = time.time()
    model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
    dt = time.time() - t0
    print(f"[chat] Model ready in {dt:.1f}s on {device}")

    from prepare import MAX_SEQ_LEN
    run_repl(model, tokenizer, meta, device=device, max_seq_len=MAX_SEQ_LEN)
    return 0


if __name__ == "__main__":
    sys.exit(main())
overlay/scripts/chat_eval.py
CHANGED
@@ -1,300 +1,300 @@
"""Non-interactive chat eval for HYDRA.

Runs a fixed set of prompts through the same chat template that `chat.py`
uses, prints a markdown table with the response and coherence heuristics.

Usage:
    python scripts/chat_eval.py                  # auto-select checkpoint
    python scripts/chat_eval.py --ckpt PATH
    python scripts/chat_eval.py --random
    python scripts/chat_eval.py --json out.json  # also dump raw results
    python scripts/chat_eval.py --max 80         # cap new tokens per prompt
"""

from __future__ import annotations

import argparse
import json
import os
import re
import sys
import time
from pathlib import Path

_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

import torch  # noqa: E402

from scripts.chat import (  # noqa: E402
    ASSISTANT_TAG, END_TAG, USER_TAG, build_prompt,
    generate_stream, load_model_and_tokenizer, resolve_checkpoint,
)


PROMPTS: list[str] = [
    # Factual
    "What is the capital of France?",
    "Who wrote Romeo and Juliet?",
    "What is 2 plus 2?",
    "What color is the sky on a clear day?",
    # Completion
    "Once upon a time",
    "The cat sat on the",
    "In a hole in the ground there lived",
    # Instruction
    "Write one short sentence about rain.",
    "List three animals.",
    "Define the word 'library'.",
    # Conversational
    "Hello, how are you?",
    "Tell me a joke.",
    # Creative
    "Describe a sunset in one line.",
    "Give me a name for a pet robot.",
    "What is the meaning of friendship?",
]

# Heuristic thresholds (printed, not enforced as pass/fail).
THRESH_DISTINCT_2 = 0.30
THRESH_SENT_MIN = 5
THRESH_SENT_MAX = 30
THRESH_EN_RATIO = 0.95


# ---------------------------------------------------------------------------
# Coherence heuristics
# ---------------------------------------------------------------------------

def _tokens(text: str) -> list[str]:
    return re.findall(r"[A-Za-z0-9']+", text)


def distinct_2(text: str) -> float:
    toks = _tokens(text)
    if len(toks) < 2:
        return 0.0
    bigrams = [(toks[i], toks[i + 1]) for i in range(len(toks) - 1)]
    return len(set(bigrams)) / max(1, len(bigrams))


def avg_sentence_len(text: str) -> float:
    sents = re.split(r"[.!?]+", text)
    lens = [len(_tokens(s)) for s in sents if _tokens(s)]
    if not lens:
        return 0.0
    return sum(lens) / len(lens)


def english_char_ratio(text: str) -> float:
    if not text:
        return 0.0
    allowed = 0
    for c in text:
        if c.isalnum() or c.isspace() or c in ".,!?;:'\"-()[]{}/\\*#@&%+=_<>|$":
            allowed += 1
    return allowed / len(text)
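
# Worked example of the heuristics above (values follow from the definitions,
# computed by hand): for text = "the cat sat. the cat sat."
#   _tokens          -> ['the', 'cat', 'sat', 'the', 'cat', 'sat']
#   bigrams          -> 5 total, 3 distinct (('sat', 'the') bridges the two
#                       sentences), so distinct_2 = 3/5 = 0.60
#   avg_sentence_len -> two non-empty sentence splits of 3 tokens each = 3.0
#   english_char_ratio -> every character is alnum/space/allowed punct = 1.0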
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ---------------------------------------------------------------------------
|
| 101 |
+
# Runner
|
| 102 |
+
# ---------------------------------------------------------------------------
|
| 103 |
+
|
| 104 |
+
def _run_one(model, tokenizer, prompt: str, *, max_new_tokens: int, device: torch.device,
|
| 105 |
+
max_seq_len: int, temperature: float, top_k: int, top_p: float,
|
| 106 |
+
repetition_penalty: float) -> str:
|
| 107 |
+
prompt_text = build_prompt(system="", history=[], user_msg=prompt)
|
| 108 |
+
prompt_ids = tokenizer.encode(prompt_text)
|
| 109 |
+
|
| 110 |
+
stream = generate_stream(
|
| 111 |
+
model, tokenizer, prompt_ids,
|
| 112 |
+
max_new_tokens=max_new_tokens,
|
| 113 |
+
temperature=temperature,
|
| 114 |
+
top_k=top_k,
|
| 115 |
+
top_p=top_p,
|
| 116 |
+
repetition_penalty=repetition_penalty,
|
| 117 |
+
stop_strings=(END_TAG,),
|
| 118 |
+
max_seq_len=max_seq_len,
|
| 119 |
+
device=device,
|
| 120 |
+
)
|
| 121 |
+
collected: list[str] = []
|
| 122 |
+
try:
|
| 123 |
+
while True:
|
| 124 |
+
collected.append(next(stream))
|
| 125 |
+
except StopIteration as stop:
|
| 126 |
+
if stop.value is not None:
|
| 127 |
+
text = stop.value
|
| 128 |
+
else:
|
| 129 |
+
text = "".join(collected)
|
| 130 |
+
|
| 131 |
+
if END_TAG in text:
|
| 132 |
+
text = text.split(END_TAG, 1)[0]
|
| 133 |
+
return text.strip()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _render_markdown(rows: list[dict]) -> str:
|
| 137 |
+
lines = [
|
| 138 |
+
"| # | Prompt | Response | dist-2 | sent_len | en_ratio | flags |",
|
| 139 |
+
"|---|--------|----------|--------|----------|----------|-------|",
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
def _cell(s: str, n: int = 60) -> str:
|
| 143 |
+
s = s.replace("|", "\\|").replace("\n", " ")
|
| 144 |
+
if len(s) > n:
|
| 145 |
+
s = s[: n - 1] + "β¦"
|
| 146 |
+
return s
|
| 147 |
+
|
| 148 |
+
for i, r in enumerate(rows, 1):
|
| 149 |
+
flags = []
|
| 150 |
+
if r["distinct_2"] < THRESH_DISTINCT_2:
|
| 151 |
+
flags.append("repetitive")
|
| 152 |
+
if not (THRESH_SENT_MIN <= r["avg_sentence_len"] <= THRESH_SENT_MAX):
|
| 153 |
+
flags.append("sent_len")
|
| 154 |
+
if r["en_ratio"] < THRESH_EN_RATIO:
|
| 155 |
+
flags.append("non_en")
|
| 156 |
+
flag_str = ",".join(flags) or "ok"
|
| 157 |
+
lines.append(
|
| 158 |
+
f"| {i} | {_cell(r['prompt'], 40)} | {_cell(r['response'], 60)} | "
|
| 159 |
+
f"{r['distinct_2']:.2f} | {r['avg_sentence_len']:.1f} | "
|
| 160 |
+
f"{r['en_ratio']:.2f} | {flag_str} |"
|
| 161 |
+
)
|
| 162 |
+
return "\n".join(lines)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# ---------------------------------------------------------------------------
|
| 166 |
+
# CLI
|
| 167 |
+
# ---------------------------------------------------------------------------
|
| 168 |
+
|
| 169 |
+
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
| 170 |
+
p = argparse.ArgumentParser(description="HYDRA chat eval")
|
| 171 |
+
p.add_argument("--ckpt", type=str, default=None, help="Checkpoint path.")
|
| 172 |
+
p.add_argument("--sft", action="store_true", help="Prefer SFT checkpoint.")
|
| 173 |
+
p.add_argument("--random", action="store_true", help="Use random weights.")
|
| 174 |
+
p.add_argument("--max", dest="max_new_tokens", type=int, default=80)
|
| 175 |
+
p.add_argument("--temp", dest="temperature", type=float, default=0.8)
|
| 176 |
+
p.add_argument("--topk", dest="top_k", type=int, default=40)
|
| 177 |
+
p.add_argument("--topp", dest="top_p", type=float, default=0.9)
|
| 178 |
+
p.add_argument("--rep", dest="repetition_penalty", type=float, default=1.1)
|
| 179 |
+
p.add_argument("--json", dest="json_out", type=str, default=None,
|
| 180 |
+
help="Optional: dump raw results to this JSON path.")
|
| 181 |
+
p.add_argument("--device", type=str, default=None)
|
| 182 |
+
return p.parse_args(argv)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def main(argv: list[str] | None = None) -> int:
|
| 186 |
+
args = _parse_args(argv)
|
| 187 |
+
|
| 188 |
+
if args.device:
|
| 189 |
+
device = torch.device(args.device)
|
| 190 |
+
elif torch.cuda.is_available():
|
| 191 |
+
device = torch.device("cuda")
|
| 192 |
+
else:
|
| 193 |
+
device = torch.device("cpu")
|
| 194 |
+
|
| 195 |
+
ckpt_path = None if args.random else resolve_checkpoint(args.ckpt, args.sft)
|
| 196 |
+
|
| 197 |
+
t0 = time.time()
|
| 198 |
+
model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
|
| 199 |
+
dt_load = time.time() - t0
|
| 200 |
+
print(f"[chat_eval] Loaded in {dt_load:.1f}s ckpt={meta['ckpt']}")
|
| 201 |
+
|
| 202 |
+
from prepare import MAX_SEQ_LEN
|
| 203 |
+
|
| 204 |
+
rows: list[dict] = []
|
| 205 |
+
t_gen = time.time()
|
| 206 |
+
for i, prompt in enumerate(PROMPTS, 1):
|
| 207 |
+
t_start = time.time()
|
| 208 |
+
try:
|
| 209 |
+
resp = _run_one(
|
| 210 |
+
model, tokenizer, prompt,
|
| 211 |
+
max_new_tokens=args.max_new_tokens,
|
| 212 |
+
device=device,
|
| 213 |
+
max_seq_len=MAX_SEQ_LEN,
|
| 214 |
+
temperature=args.temperature,
|
| 215 |
+
top_k=args.top_k,
|
| 216 |
+
top_p=args.top_p,
|
| 217 |
+
repetition_penalty=args.repetition_penalty,
|
| 218 |
+
)
|
| 219 |
+
err = None
|
| 220 |
+
except Exception as e: # noqa: BLE001 β eval must not abort mid-prompt.
|
| 221 |
+
resp = ""
|
| 222 |
+
err = repr(e)
|
| 223 |
+
print(f"[chat_eval] prompt {i} failed: {err}", file=sys.stderr)
|
| 224 |
+
|
| 225 |
+
rows.append({
|
| 226 |
+
"prompt": prompt,
|
| 227 |
+
"response": resp,
|
| 228 |
+
"distinct_2": distinct_2(resp),
|
| 229 |
+
"avg_sentence_len": avg_sentence_len(resp),
|
| 230 |
+
"en_ratio": english_char_ratio(resp),
|
| 231 |
+
"latency_s": round(time.time() - t_start, 2),
|
| 232 |
+
"error": err,
|
| 233 |
+
})
|
| 234 |
+
print(f"[chat_eval] {i:2d}/{len(PROMPTS)} {rows[-1]['latency_s']:.1f}s {resp!r}")
|
| 235 |
+
|
| 236 |
+
dt_gen = time.time() - t_gen
|
| 237 |
+
|
| 238 |
+
print()
|
| 239 |
+
print("## HYDRA chat_eval results")
|
| 240 |
+
print(f"- checkpoint: `{meta['ckpt']}`")
|
| 241 |
+
if meta.get("step") is not None:
|
| 242 |
+
print(f"- step: {meta['step']}")
|
| 243 |
+
if meta.get("val_bpb") is not None:
|
| 244 |
+
print(f"- val_bpb: {meta['val_bpb']}")
|
| 245 |
+
print(f"- prompts: {len(PROMPTS)}")
|
| 246 |
+
print(f"- load: {dt_load:.1f}s generation: {dt_gen:.1f}s")
|
| 247 |
+
print()
|
| 248 |
+
print(_render_markdown(rows))
|
| 249 |
+
print()
|
| 250 |
+
|
| 251 |
+
# Summary heuristics
|
| 252 |
+
any_empty = sum(1 for r in rows if not r["response"])
|
| 253 |
+
any_error = sum(1 for r in rows if r["error"])
|
| 254 |
+
mean_d2 = sum(r["distinct_2"] for r in rows) / max(1, len(rows))
|
| 255 |
+
mean_en = sum(r["en_ratio"] for r in rows) / max(1, len(rows))
|
| 256 |
+
|
| 257 |
+
print("### Aggregates")
|
| 258 |
+
print(f"- empty responses: {any_empty}/{len(rows)}")
|
| 259 |
+
print(f"- generation errors: {any_error}/{len(rows)}")
|
| 260 |
+
print(f"- mean distinct-2: {mean_d2:.3f} (target > {THRESH_DISTINCT_2})")
|
| 261 |
+
print(f"- mean en_ratio: {mean_en:.3f} (target > {THRESH_EN_RATIO})")
|
| 262 |
+
print()
|
| 263 |
+
print("_Quality at this model scale (~7.5M params) is NOT expected to meet thresholds; "
|
| 264 |
+
"this eval verifies the chat interface, not dialogue coherence._")
|
| 265 |
+
|
| 266 |
+
if args.json_out:
|
| 267 |
+
out = {
|
| 268 |
+
"meta": meta,
|
| 269 |
+
"settings": {
|
| 270 |
+
"max_new_tokens": args.max_new_tokens,
|
| 271 |
+
"temperature": args.temperature,
|
| 272 |
+
"top_k": args.top_k,
|
| 273 |
+
"top_p": args.top_p,
|
| 274 |
+
"repetition_penalty": args.repetition_penalty,
|
| 275 |
+
},
|
| 276 |
+
"rows": rows,
|
| 277 |
+
"aggregates": {
|
| 278 |
+
"empty": any_empty,
|
| 279 |
+
"errors": any_error,
|
| 280 |
+
"mean_distinct_2": mean_d2,
|
| 281 |
+
"mean_en_ratio": mean_en,
|
| 282 |
+
"load_s": dt_load,
|
| 283 |
+
"gen_s": dt_gen,
|
| 284 |
+
},
|
| 285 |
+
}
|
| 286 |
+
Path(args.json_out).write_text(json.dumps(out, indent=2))
|
| 287 |
+
print(f"[chat_eval] JSON written to {args.json_out}")
|
| 288 |
+
|
| 289 |
+
# Exit 0 if we loaded and generated *something* for each prompt (even if
|
| 290 |
+
# quality was poor). Exit 1 only on load failure (caught by main's exception
|
| 291 |
+
# propagation) or if ALL prompts returned empty strings β that signals a
|
| 292 |
+
# broken generation loop, not poor quality.
|
| 293 |
+
if any_empty == len(rows):
|
| 294 |
+
print("[chat_eval] ALL prompts returned empty β generation loop is broken.", file=sys.stderr)
|
| 295 |
+
return 1
|
| 296 |
+
return 0
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
if __name__ == "__main__":
|
| 300 |
+
sys.exit(main())
|
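
A minimal sketch of driving this eval programmatically rather than from the shell. The flags mirror _parse_args above; the bare import path and the /tmp output location are assumptions about how scripts/ sits on sys.path, not something this commit defines:

# Hypothetical harness: call chat_eval.main() with an explicit argv list.
# --random skips checkpoint resolution; --max keeps generations short.
from chat_eval import main  # assumed importable (scripts/ on sys.path)

exit_code = main(["--random", "--max", "16", "--json", "/tmp/chat_eval.json"])
assert exit_code in (0, 1)  # 1 only when every prompt came back empty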
overlay/scripts/compile_debug.py
CHANGED
"""Diagnostic script for torch.compile deadlock after ~500 steps.

F17 investigation: validates that the _compiled_core / forward split
fixes the deadlock by running forward+backward loops with compile on.

Usage:
    LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \
    HYDRA_TIME_BUDGET=30 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=16384 \
    HYDRA_HTM_LEARN_EVERY=4 HYDRA_HESTIA_INTERVAL=9999 \
    .venv/bin/python -u scripts/compile_debug.py [mode]

Modes:
    eager      - no compile (baseline)
    model_only - compile model _compiled_core only
    muon_only  - compile muon step only
    both       - compile both (default)
"""

from __future__ import annotations

import gc
import os
import signal
import sys
import threading
import time

# Set CUDA env before torch import
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")

import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------------------------------------------------------
# Config
# -------------------------------------------------------------------------
MAX_STEPS = 800
WATCHDOG_TIMEOUT_S = 20  # kill if no progress for this many seconds
BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "8"))
SEQ_LEN = 2048
VOCAB_SIZE = 8192


# -------------------------------------------------------------------------
# Watchdog thread: kills process if no progress
# -------------------------------------------------------------------------
_last_progress = time.time()
_watchdog_armed = True

def _watchdog_fn():
    global _last_progress, _watchdog_armed
    while _watchdog_armed:
        time.sleep(1.0)
        elapsed = time.time() - _last_progress
        if elapsed > WATCHDOG_TIMEOUT_S:
            print(f"\n*** WATCHDOG: no progress for {elapsed:.1f}s - DEADLOCK DETECTED ***",
                  flush=True)
            _dump_diagnostics()
            os.kill(os.getpid(), signal.SIGTERM)
            return

def _dump_diagnostics():
    """Dump CUDA/dynamo state at deadlock time."""
    try:
        stats = torch.cuda.memory_stats()
        print(f"  alloc_retries: {stats.get('num_alloc_retries', 'N/A')}")
        print(f"  allocated_bytes: {stats.get('allocated_bytes.all.current', 0) / 1e6:.1f} MB")
        print(f"  reserved_bytes: {stats.get('reserved_bytes.all.current', 0) / 1e6:.1f} MB")
        print(f"  num_ooms: {stats.get('num_ooms', 0)}")
    except Exception as e:
        print(f"  (memory_stats failed: {e})")

    try:
        import torch._dynamo.utils as du
        print(f"  dynamo counters: {dict(du.counters)}")
    except Exception as e:
        print(f"  (dynamo counters failed: {e})")


def tick():
    global _last_progress
    _last_progress = time.time()


# -------------------------------------------------------------------------
# Test
# -------------------------------------------------------------------------
def run_test(mode: str) -> dict:
    """Run forward+backward loop with specified compile config."""
    print(f"\n{'='*70}")
    print(f"TEST MODE: {mode}")
    print(f"{'='*70}", flush=True)

    compile_model = mode in ("model_only", "both")
    compile_muon = mode in ("muon_only", "both")

    os.environ["HYDRA_MODEL_COMPILE"] = "1" if compile_model else "0"
    os.environ["HYDRA_MUON_COMPILE"] = "1" if compile_muon else "0"
    os.environ["HYDRA_ASYNC_POSTPROCESS"] = "0"
    os.environ["HYDRA_HESTIA_INTERVAL"] = "9999"
    os.environ["HYDRA_HTM_LEARN_EVERY"] = "4"

    # Clear cached modules for fresh env var reads
    for mod_name in list(sys.modules.keys()):
        if mod_name.startswith("hydra."):
            del sys.modules[mod_name]

    torch._dynamo.reset()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()

    from hydra.model import PostSemClawModel
    from hydra.config import PostSemClawConfig

    device = torch.device("cuda")
    config = PostSemClawConfig(
        d_model=256, n_layer=4, d_state=64, headdim=32, expand=2,
        vocab_size=VOCAB_SIZE, sequence_len=SEQ_LEN,
    )

    with torch.device("meta"):
        model = PostSemClawModel(config)
    model.to_empty(device=device)
    model.init_weights()

    optimizer = model.setup_optimizer()
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

    result = {"mode": mode, "max_step": 0, "tps_samples": []}
    alloc_retries_prev = 0

    tick()

    for step in range(MAX_STEPS):
        t0 = time.time()

        x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
        y = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)

        with autocast_ctx:
            loss = model(x, y)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)

        torch.cuda.synchronize()
        dt = time.time() - t0
        tps = int(BATCH_SIZE * SEQ_LEN / dt)

        tick()

        stats = torch.cuda.memory_stats()
        retries = stats.get("num_alloc_retries", 0)
        retry_delta = retries - alloc_retries_prev
        alloc_retries_prev = retries

        result["max_step"] = step

        if step % 50 == 0 or retry_delta > 0 or step < 3:
            alloc_mb = stats.get("allocated_bytes.all.current", 0) / 1e6
            print(
                f"  step={step:04d} tps={tps:6d} dt={dt*1000:.0f}ms "
                f"alloc={alloc_mb:.0f}MB retries={retries}",
                flush=True,
            )
            result["tps_samples"].append((step, tps))

    result["completed"] = True
    print(f"\n  COMPLETED: {MAX_STEPS} steps, mode={mode}", flush=True)
    return result


def main():
    print(f"torch: {torch.__version__} CUDA: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Steps: {MAX_STEPS} Watchdog: {WATCHDOG_TIMEOUT_S}s")

    wd = threading.Thread(target=_watchdog_fn, daemon=True)
    wd.start()

    modes = sys.argv[1:] if len(sys.argv) > 1 else ["both"]
    results = []

    for mode in modes:
        try:
            r = run_test(mode)
        except SystemExit:
            print(f"\n  DEADLOCK/KILLED mode={mode}", flush=True)
            r = {"mode": mode, "completed": False, "max_step": "?"}
        except Exception as e:
            print(f"\n  ERROR mode={mode}: {e}", flush=True)
            r = {"mode": mode, "completed": False, "error": str(e)}
        results.append(r)

    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    for r in results:
        status = "PASS" if r.get("completed") else "FAIL"
        print(f"  {r['mode']:20s}: {status} (step {r.get('max_step', '?')})")

    global _watchdog_armed
    _watchdog_armed = False


if __name__ == "__main__":
    main()
overlay/scripts/dataset_audit.py
CHANGED
"""
Dataset audit - diagnostic tool for HYDRA's pretraining corpus.

Usage:
    python scripts/dataset_audit.py              # Quick audit
    python scripts/dataset_audit.py --sample 10  # Sample 10 shards for token counts
    python scripts/dataset_audit.py --full       # Full tokenize of every shard (slow)

Reports:
    - Shard count, total disk usage
    - Estimated total tokens (character-based + tokenized sample)
    - Training budget sufficiency vs 12h @ 65k tok/s = 2.8B token target
    - Document diversity sample
    - Warnings about shard ordering, shuffle, and streaming behavior
"""
from __future__ import annotations

import argparse
import os
import sys
import time
from pathlib import Path

import pyarrow.parquet as pq

# Resolve repo root so the script works regardless of CWD.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

from prepare import (  # noqa: E402
    DATA_DIR,
    MAX_SHARD,
    TOKENIZER_DIR,
    VAL_FILENAME,
    VAL_SHARD,
)

TARGET_TOKENS_12H = 2_800_000_000  # 65k tok/s * 12h * 3600s
CHARS_PER_TOKEN_HEURISTIC = 4.0


def human_bytes(n: int) -> str:
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if n < 1024:
            return f"{n:.1f}{unit}"
        n /= 1024
    return f"{n:.1f}PB"


def human_tokens(n: int | float) -> str:
    if n >= 1e9:
        return f"{n / 1e9:.2f}B"
    if n >= 1e6:
        return f"{n / 1e6:.1f}M"
    if n >= 1e3:
        return f"{n / 1e3:.1f}K"
    return f"{n:.0f}"


def list_shards() -> tuple[list[Path], Path | None]:
    """Return (train_shards_sorted, val_shard_or_none)."""
    if not os.path.isdir(DATA_DIR):
        return [], None
    all_paths = sorted(Path(DATA_DIR).glob("shard_*.parquet"))
    val_path = Path(DATA_DIR) / VAL_FILENAME
    train = [p for p in all_paths if p.name != VAL_FILENAME]
    val = val_path if val_path.exists() else None
    return train, val


def tokenized_sample(shard_path: Path, enc, row_groups: int = 5) -> tuple[int, int, int]:
    """Tokenize first N row groups of a shard. Returns (tokens, docs, num_row_groups)."""
    pf = pq.ParquetFile(shard_path)
    tokens = 0
    docs = 0
    n = min(row_groups, pf.num_row_groups)
    for i in range(n):
        rg = pf.read_row_group(i)
        texts = rg.column("text").to_pylist()
        ids = enc.encode_ordinary_batch(texts, num_threads=8)
        tokens += sum(len(x) for x in ids)
        docs += len(texts)
    return tokens, docs, pf.num_row_groups


def main() -> int:
    parser = argparse.ArgumentParser(description="Audit the HYDRA training corpus")
    parser.add_argument(
        "--sample",
        type=int,
        default=3,
        help="Number of shards to tokenize for token-count estimate",
    )
    parser.add_argument(
        "--full",
        action="store_true",
        help="Tokenize every shard (slow; gives exact total)",
    )
    args = parser.parse_args()

    print("=" * 72)
    print("HYDRA corpus audit")
    print("=" * 72)
    print(f"DATA_DIR: {DATA_DIR}")
    print(f"TOKENIZER_DIR: {TOKENIZER_DIR}")
    print("Source dataset: karpathy/climbmix-400b-shuffle")
    print(f"Max remote shard: {MAX_SHARD} (pinned val = shard_{VAL_SHARD:05d})")
    print()

    train_shards, val_shard = list_shards()
    if not train_shards:
        print("ERROR: no parquet shards found. Run `python prepare.py` first.")
        return 1

    total_disk = sum(p.stat().st_size for p in train_shards)
    val_disk = val_shard.stat().st_size if val_shard else 0

    print(f"Train shards: {len(train_shards)} ({train_shards[0].name} ... {train_shards[-1].name})")
    print(f"Val shard: {'present' if val_shard else 'MISSING'} ({VAL_FILENAME})")
    print(f"Disk (train): {human_bytes(total_disk)}")
    print(f"Disk (val): {human_bytes(val_disk)}")
    print()

    # Metadata pass (fast): count documents and row groups across all shards.
    t0 = time.time()
    total_docs = 0
    total_row_groups = 0
    for p in train_shards:
        pf = pq.ParquetFile(p)
        total_row_groups += pf.num_row_groups
        total_docs += pf.metadata.num_rows
    dt_meta = time.time() - t0
    print(f"Metadata scan: {len(train_shards)} shards in {dt_meta:.1f}s")
    print(f"Train documents: {total_docs:,}")
    print(f"Row groups: {total_row_groups:,}")
    print()

    # Tokenizer-based sampling.
    try:
        import pickle

        with open(os.path.join(TOKENIZER_DIR, "tokenizer.pkl"), "rb") as f:
            enc = pickle.load(f)
        print(f"Tokenizer vocab: {enc.n_vocab}")
    except FileNotFoundError:
        print("WARNING: tokenizer.pkl not found - skipping tokenized sample.")
        enc = None

    est_total_tokens = 0
    if enc is not None:
        if args.full:
            sample_shards = train_shards
        else:
            # Pick shards evenly across the range for a representative sample.
            n_sample = min(args.sample, len(train_shards))
            if n_sample == 1:
                sample_shards = [train_shards[0]]
            else:
                stride = max(1, len(train_shards) // n_sample)
                sample_shards = train_shards[::stride][:n_sample]

        t0 = time.time()
        sample_tokens = 0
        sample_docs = 0
        sample_row_groups = 0
        sample_shard_row_groups = 0
        print(f"Tokenizing sample: {len(sample_shards)} shards ...")
        for p in sample_shards:
            tok, docs, n_rg = tokenized_sample(p, enc, row_groups=5)
            sample_tokens += tok
            sample_docs += docs
            sample_row_groups += min(5, n_rg)
            sample_shard_row_groups += n_rg
        dt_tok = time.time() - t0

        tokens_per_rg = sample_tokens / max(sample_row_groups, 1)
        per_shard = tokens_per_rg * (sample_shard_row_groups / len(sample_shards))
        est_total_tokens = per_shard * len(train_shards)

        print(
            f"Sampled {sample_row_groups} row groups ({sample_docs:,} docs, "
            f"{sample_tokens:,} tokens) in {dt_tok:.1f}s"
        )
        print(f"  tokens/row_group: {tokens_per_rg:,.0f}")
        print(f"  tokens/shard: {per_shard:,.0f}")
        print(f"  tokens/shard: {human_tokens(per_shard)}")
    else:
        # Fall back to the character heuristic: parquet compresses text
        # roughly 3x, and ~4 chars per token.
        est_total_tokens = (total_disk * 3.0) / CHARS_PER_TOKEN_HEURISTIC

    print()
    print("-" * 72)
    print("Token budget analysis")
    print("-" * 72)
    print(f"Estimated total train tokens: {human_tokens(est_total_tokens)} "
          f"({est_total_tokens:,.0f})")
    print(f"12h @ 65k tok/s target: {human_tokens(TARGET_TOKENS_12H)}")
    ratio = est_total_tokens / TARGET_TOKENS_12H if TARGET_TOKENS_12H else 0
    if ratio >= 1.0:
        print(f"  Ratio: {ratio:.1f}x ({'SUFFICIENT' if ratio >= 1.2 else 'TIGHT'})")
    else:
        print(f"  Ratio: {ratio:.2f}x INSUFFICIENT - need {1 - ratio:.0%} more")
    print()

    # Warnings about the dataloader behavior.
    print("-" * 72)
    print("Dataloader behavior (prepare.py::_document_batches)")
    print("-" * 72)
    print("+ Infinite streaming: while True around shard list (no StopIteration)")
    print("+ Streams per shard, never loads full corpus into RAM")
    print("+ BOS-aligned best-fit packing gives document-level buffer shuffling")
    print("- Cross-shard order is LEXICOGRAPHIC and FIXED on every epoch")
    print("- Row groups / rows WITHIN a shard are read in fixed order")
    print("  (climbmix-400b-shuffle is pre-shuffled at source, mitigating this)")
    print()

    # Quick content diversity peek.
    if train_shards:
        print("-" * 72)
        print("Content sample (shard 0, first 3 docs)")
        print("-" * 72)
        pf = pq.ParquetFile(train_shards[0])
        rg = pf.read_row_group(0)
        texts = rg.column("text").to_pylist()
        for i, idx in enumerate([0, len(texts) // 2, len(texts) - 1]):
            if idx < len(texts):
                snippet = texts[idx][:160].replace("\n", " ")
                print(f"  [{i}] len={len(texts[idx])}: {snippet!r}")
        print()

    print("=" * 72)
    print("Done.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
overlay/scripts/download_sft_data.py
CHANGED
"""Download + tokenize instruction data for HYDRA SFT.

Writes int16 token shards to `data/sft/shard_XXX.bin` plus a
`data/sft/meta.json` with counts + special-token mapping.

Chat format (vocab's 4 reserved special tokens are repurposed):
    <BOS=8188> <|user|=8189>\n{instruction}\n{input?}\n <|assistant|=8190>\n
    {output}<|end|=8191>\n

Special-token IDs are constants derived from the tokenizer (they are the
last 4 IDs in an 8192-vocab). They are stored in meta.json for the SFT
script to read.

Sources (tried in order):
    1. yahma/alpaca-cleaned (~52K pairs via HF parquet auto-convert)
    2. databricks/databricks-dolly-15k (~15K pairs)
    3. Hard-coded 200 simple Q&A pairs (offline backup)

Usage:
    python scripts/download_sft_data.py            # full download
    python scripts/download_sft_data.py --test     # small smoke run
    python scripts/download_sft_data.py --offline  # skip network; use backup
"""

from __future__ import annotations

import argparse
import json
import os
import pickle
import sys
import time
from pathlib import Path

import numpy as np
import requests

# Make `prepare` and `hydra.*` importable when run as a script
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

CACHE_DIR = Path.home() / ".cache" / "autoresearch"
TOKENIZER_PKL = CACHE_DIR / "tokenizer" / "tokenizer.pkl"

SFT_DIR = _REPO_ROOT / "data" / "sft"
SFT_DIR.mkdir(parents=True, exist_ok=True)

# Reserved token repurposing - must match prepare.py SPECIAL_TOKENS list
# (indices 8188-8191 in the 8192-vocab BPE).
BOS_ID = 8188        # <|reserved_0|>
USER_ID = 8189       # <|reserved_1|>
ASSISTANT_ID = 8190  # <|reserved_2|>
END_ID = 8191        # <|reserved_3|>

# Shards are int16 arrays of packed token IDs.
TOKENS_PER_SHARD = 1_048_576  # ~2 MB per shard
DTYPE = np.int16              # vocab_size=8192 fits in int16

TARGET_TOKENS_DEFAULT = 15_000_000  # ~15M instruction tokens
TARGET_TOKENS_TEST = 1_500_000      # smoke run

# HuggingFace auto-parquet endpoint - one file for alpaca-cleaned
ALPACA_URL = (
    "https://huggingface.co/api/datasets/yahma/alpaca-cleaned/parquet/"
    "default/train/0.parquet"
)
DOLLY_URL = (
    "https://huggingface.co/api/datasets/databricks/databricks-dolly-15k/"
    "parquet/default/train/0.parquet"
)


# ---------------------------------------------------------------------------
# Offline backup Q&A pairs (used only if network unavailable)
# ---------------------------------------------------------------------------

_BACKUP_QA = [
    ("What is the capital of France?", "The capital of France is Paris."),
    ("What is the capital of Germany?", "The capital of Germany is Berlin."),
    ("What is the capital of Japan?", "The capital of Japan is Tokyo."),
    ("What is the capital of Italy?", "The capital of Italy is Rome."),
    ("What is the capital of Spain?", "The capital of Spain is Madrid."),
    ("What is the capital of England?", "The capital of England is London."),
    ("What is the capital of Canada?", "The capital of Canada is Ottawa."),
    ("What is the capital of Australia?", "The capital of Australia is Canberra."),
    ("What is 2 plus 2?", "Two plus two equals four."),
    ("What is 5 plus 5?", "Five plus five equals ten."),
    ("What is 3 times 4?", "Three times four equals twelve."),
    ("What is 10 minus 3?", "Ten minus three equals seven."),
    ("Name a color.", "Red is a color. Other colors include blue, green, and yellow."),
    ("Name a fruit.", "Apple is a fruit. Other fruits include banana, orange, and grape."),
    ("Name an animal.", "A dog is an animal. Other animals include cat, horse, and cow."),
    ("What do birds do?", "Birds fly through the air using their wings."),
    ("What do fish do?", "Fish swim in water using their fins and tails."),
    ("What is water made of?", "Water is made of hydrogen and oxygen."),
    ("What is the sun?", "The sun is a star at the center of our solar system."),
    ("What is the moon?", "The moon is Earth's natural satellite."),
    ("What is a cloud?", "A cloud is a collection of tiny water droplets in the sky."),
    ("What is rain?", "Rain is water that falls from clouds to the ground."),
    ("What is snow?", "Snow is frozen water that falls from clouds in cold weather."),
    ("What is wind?", "Wind is moving air."),
    ("What is fire?", "Fire is a chemical reaction that produces heat and light."),
    ("What is ice?", "Ice is frozen water."),
    ("Say hello.", "Hello! How can I help you today?"),
    ("Say goodbye.", "Goodbye! Have a great day."),
    ("Tell me a short story.", "Once upon a time, a brave rabbit lived in the forest. The end."),
    ("Tell me a joke.", "Why did the chicken cross the road? To get to the other side."),
    ("Who wrote Hamlet?", "William Shakespeare wrote the play Hamlet."),
    ("Who wrote Romeo and Juliet?", "William Shakespeare wrote Romeo and Juliet."),
    ("Who painted the Mona Lisa?", "Leonardo da Vinci painted the Mona Lisa."),
    ("When did World War 2 end?", "World War 2 ended in 1945."),
    ("What is gravity?", "Gravity is the force that pulls objects toward the Earth."),
    ("What is the speed of light?", "The speed of light is approximately 300,000 kilometers per second."),
    ("What is the largest planet?", "Jupiter is the largest planet in our solar system."),
    ("What is the smallest planet?", "Mercury is the smallest planet in our solar system."),
    ("At what temperature does water boil?", "Water boils at 100 degrees Celsius or 212 degrees Fahrenheit."),
    ("At what temperature does water freeze?", "Water freezes at 0 degrees Celsius or 32 degrees Fahrenheit."),
    ("How many legs does a spider have?", "A spider has eight legs."),
    ("How many legs does an insect have?", "An insect has six legs."),
    ("What do plants need to grow?", "Plants need sunlight, water, soil, and air to grow."),
    ("What do humans eat?", "Humans eat a variety of foods including fruits, vegetables, meat, and grains."),
    ("What is a book?", "A book is a collection of written or printed pages bound together."),
    ("What is a computer?", "A computer is an electronic device that processes information."),
    ("What is a phone?", "A phone is a device used to communicate with people at a distance."),
    ("What is music?", "Music is an arrangement of sounds that is pleasing to hear."),
    ("What is art?", "Art is the expression of human creativity and imagination."),
    ("What is a language?", "A language is a system of communication used by a group of people."),
]

# Duplicate to reach ~200 samples (each pair appears ~4x)
BACKUP_QA = (_BACKUP_QA * 4)[:200]


# ---------------------------------------------------------------------------
# Tokenizer loader
# ---------------------------------------------------------------------------

class _TokenizerWrapper:
    """Minimal wrapper around the pickled tiktoken.Encoding. We avoid
    importing `prepare.Tokenizer` to sidestep its side effects (which
    touch the running pretrain's cache files)."""

    def __init__(self, enc):
        self.enc = enc

    def encode(self, text: str) -> list[int]:
        return self.enc.encode_ordinary(text)

    @property
    def vocab_size(self) -> int:
        return self.enc.n_vocab


def load_tokenizer() -> _TokenizerWrapper:
    if not TOKENIZER_PKL.exists():
        raise FileNotFoundError(
            f"Tokenizer not found at {TOKENIZER_PKL}. Run `python prepare.py` "
            f"first."
        )
    with open(TOKENIZER_PKL, "rb") as f:
        enc = pickle.load(f)
    tok = _TokenizerWrapper(enc)
    assert tok.vocab_size == 8192, f"Expected vocab=8192, got {tok.vocab_size}"
    return tok


# ---------------------------------------------------------------------------
# Source downloaders
# ---------------------------------------------------------------------------

def _download_parquet(url: str, local_path: Path, timeout: int = 60) -> bool:
    """Stream-download a parquet file with retry. Returns True on success."""
    local_path.parent.mkdir(parents=True, exist_ok=True)
    tmp = local_path.with_suffix(local_path.suffix + ".tmp")
    for attempt in range(1, 4):
        try:
            with requests.get(url, stream=True, timeout=timeout,
                              allow_redirects=True) as r:
                r.raise_for_status()
                with open(tmp, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        if chunk:
                            f.write(chunk)
            tmp.replace(local_path)
            return True
        except Exception as e:
            print(f"  [net] attempt {attempt} failed: {e}", flush=True)
            for p in (tmp, local_path):
                try:
                    p.unlink()
                except FileNotFoundError:
                    pass
            time.sleep(2 ** attempt)
    return False


def _iter_alpaca(local_path: Path):
    """Yield (instruction, input, output) from alpaca-cleaned parquet."""
    import pyarrow.parquet as pq
    pf = pq.ParquetFile(str(local_path))
    for rg_idx in range(pf.num_row_groups):
        rg = pf.read_row_group(rg_idx)
        instr_col = rg.column("instruction").to_pylist()
        input_col = rg.column("input").to_pylist()
        output_col = rg.column("output").to_pylist()
        for instruction, input_text, output in zip(instr_col, input_col, output_col):
            if instruction and output:
                yield instruction, (input_text or ""), output


def _iter_dolly(local_path: Path):
    """Yield (instruction, input, output) from dolly-15k parquet."""
    import pyarrow.parquet as pq
    pf = pq.ParquetFile(str(local_path))
    # Schema: instruction, context, response, category
    for rg_idx in range(pf.num_row_groups):
        rg = pf.read_row_group(rg_idx)
        cols = {n: rg.column(n).to_pylist() for n in rg.schema.names}
        instr_col = cols.get("instruction") or cols.get("Instruction")
        ctx_col = cols.get("context") or cols.get("Context") or [""] * len(instr_col)
        resp_col = cols.get("response") or cols.get("Response")
        for instruction, context, response in zip(instr_col, ctx_col, resp_col):
            if instruction and response:
                yield instruction, (context or ""), response


def _iter_backup():
    for q, a in BACKUP_QA:
        yield q, "", a


# ---------------------------------------------------------------------------
# Encoding
# ---------------------------------------------------------------------------

def encode_example(tok: _TokenizerWrapper, instruction: str,
                   input_text: str, output: str) -> list[int]:
    """Serialize one instruction/response pair into a flat token list.

    Format:
        <BOS> <|user|> \\n {instr}\\n[{input}\\n] <|assistant|> \\n {output} <|end|> \\n
    """
    ids: list[int] = [BOS_ID, USER_ID]
    ids += tok.encode("\n" + instruction.strip())
    if input_text and input_text.strip():
        ids += tok.encode("\n" + input_text.strip())
    ids += tok.encode("\n")
    ids.append(ASSISTANT_ID)
    ids += tok.encode("\n" + output.strip())
    ids.append(END_ID)
    ids += tok.encode("\n")
    return ids


def encode_example_with_mask(tok: _TokenizerWrapper, instruction: str,
                             input_text: str, output: str
                             ) -> tuple[list[int], list[int]]:
    """Return (tokens, mask) where mask[i]=1 means 'compute loss on token i'
    and mask[i]=0 means 'prompt, ignore'. The boundary is the <|assistant|>
    token: the assistant response (and <|end|>) contribute to loss; the
    user prompt does not."""
    prompt_ids = [BOS_ID, USER_ID] + tok.encode("\n" + instruction.strip())
    if input_text and input_text.strip():
        prompt_ids += tok.encode("\n" + input_text.strip())
    prompt_ids += tok.encode("\n")
    prompt_ids.append(ASSISTANT_ID)

    response_ids = tok.encode("\n" + output.strip())
    response_ids.append(END_ID)
    response_ids += tok.encode("\n")

    ids = prompt_ids + response_ids
    mask = [0] * len(prompt_ids) + [1] * len(response_ids)
    return ids, mask


# ---------------------------------------------------------------------------
# Shard writer
# ---------------------------------------------------------------------------

class ShardWriter:
    """Writes two parallel int16 files per shard:
        data/sft/shard_XXX.bin -> token IDs
        data/sft/mask_XXX.bin  -> 0/1 loss mask

    Packs one example after another with no padding. At runtime, SFT builds
    sequences of length MAX_SEQ_LEN by slicing across these flat arrays.
    """

    def __init__(self, out_dir: Path, tokens_per_shard: int = TOKENS_PER_SHARD):
        self.out_dir = out_dir
        self.tokens_per_shard = tokens_per_shard
        self.shard_idx = 0
        self._buf_tok: list[int] = []
        self._buf_mask: list[int] = []
        self.total_tokens = 0

    def add(self, tokens: list[int], mask: list[int]):
        assert len(tokens) == len(mask)
        self._buf_tok.extend(tokens)
        self._buf_mask.extend(mask)
        self.total_tokens += len(tokens)
        while len(self._buf_tok) >= self.tokens_per_shard:
            self._flush_one(self.tokens_per_shard)

    def _flush_one(self, n: int):
        tok_path = self.out_dir / f"shard_{self.shard_idx:04d}.bin"
        mask_path = self.out_dir / f"mask_{self.shard_idx:04d}.bin"
        arr_tok = np.array(self._buf_tok[:n], dtype=DTYPE)
        arr_mask = np.array(self._buf_mask[:n], dtype=np.uint8)
        arr_tok.tofile(tok_path)
        arr_mask.tofile(mask_path)
        self._buf_tok = self._buf_tok[n:]
        self._buf_mask = self._buf_mask[n:]
        print(f"  wrote {tok_path.name} ({n:,} tokens)", flush=True)
        self.shard_idx += 1

    def finalize(self):
        if self._buf_tok:
            self._flush_one(len(self._buf_tok))


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--test", action="store_true",
                    help="Small smoke run: write ~1.5M tokens and exit.")
    ap.add_argument("--offline", action="store_true",
                    help="Skip network, use hard-coded backup only.")
    ap.add_argument("--target-tokens", type=int, default=None,
                    help="Override target token count.")
    args = ap.parse_args()

    target = args.target_tokens or (
        TARGET_TOKENS_TEST if args.test else TARGET_TOKENS_DEFAULT
    )

    print(f"SFT_DIR: {SFT_DIR}")
    print(f"Target tokens: {target:,}")
    print(f"Offline mode: {args.offline}")

    # Clear any prior shards
    for p in SFT_DIR.glob("shard_*.bin"):
        p.unlink()
    for p in SFT_DIR.glob("mask_*.bin"):
        p.unlink()

    tok = load_tokenizer()
    print(f"Tokenizer vocab: {tok.vocab_size}")
    print(f"Special tokens: BOS={BOS_ID} USER={USER_ID} "
          f"ASSISTANT={ASSISTANT_ID} END={END_ID}")

    sources = []  # list of (name, iterator_fn)
    if not args.offline:
        alpaca_path = SFT_DIR / "alpaca_raw.parquet"
        print(f"\n[src] downloading alpaca-cleaned -> {alpaca_path.name} ...")
        if _download_parquet(ALPACA_URL, alpaca_path):
            print(f"  ok ({alpaca_path.stat().st_size // (1 << 20)} MiB)")
            sources.append(("alpaca-cleaned", lambda: _iter_alpaca(alpaca_path)))
        else:
            print("  alpaca download FAILED, trying dolly...")
            dolly_path = SFT_DIR / "dolly_raw.parquet"
            if _download_parquet(DOLLY_URL, dolly_path):
                print(f"  ok ({dolly_path.stat().st_size // (1 << 20)} MiB)")
                sources.append(("dolly-15k", lambda: _iter_dolly(dolly_path)))

    # Always include backup - cheap, catches tail
    sources.append(("backup-200", _iter_backup))

    if not sources:
        print("FATAL: no data sources available.", file=sys.stderr)
        sys.exit(1)

    # Stream-encode
    writer = ShardWriter(SFT_DIR)
    n_examples = 0
    n_assistant_tokens = 0
    source_counts = {}

    for src_name, src_fn in sources:
        print(f"\n[src] encoding {src_name} ...")
        src_examples = 0
        src_tokens = 0
        for (instruction, input_text, output) in src_fn():
            # Skip overly long outputs - the 7.5M model can't use them
            if len(output) > 2000:
                output = output[:2000]
            ids, mask = encode_example_with_mask(tok, instruction,
                                                 input_text, output)
            if len(ids) < 4 or len(ids) > 512:
                # Skip degenerate / too-long examples
                continue
            writer.add(ids, mask)
            n_examples += 1
            src_examples += 1
            src_tokens += len(ids)
            n_assistant_tokens += sum(mask)
            if writer.total_tokens >= target:
                break
        source_counts[src_name] = {
            "examples": src_examples,
            "tokens": src_tokens,
        }
        print(f"  {src_name}: {src_examples:,} examples, {src_tokens:,} tokens")
        if writer.total_tokens >= target:
            break

    writer.finalize()

    meta = {
        "total_tokens": writer.total_tokens,
        "total_examples": n_examples,
        "assistant_tokens_in_loss": n_assistant_tokens,
        "num_shards": writer.shard_idx,
        "tokens_per_shard": TOKENS_PER_SHARD,
        "dtype": "int16",
        "vocab_size": tok.vocab_size,
        "special_tokens": {
            "bos": BOS_ID,
            "user": USER_ID,
            "assistant": ASSISTANT_ID,
            "end": END_ID,
        },
        "sources": source_counts,
        "format_hint": (
            "<BOS><|user|>\\n{instr}\\n[{input}\\n]<|assistant|>\\n"
|
| 436 |
-
"{output}<|end|>\\n"
|
| 437 |
-
),
|
| 438 |
-
}
|
| 439 |
-
meta_path = SFT_DIR / "meta.json"
|
| 440 |
-
with open(meta_path, "w") as f:
|
| 441 |
-
json.dump(meta, f, indent=2)
|
| 442 |
-
|
| 443 |
-
print(f"\n===== SFT data ready =====")
|
| 444 |
-
print(f" examples: {n_examples:,}")
|
| 445 |
-
print(f" total tokens: {writer.total_tokens:,}")
|
| 446 |
-
print(f" loss tokens: {n_assistant_tokens:,}")
|
| 447 |
-
print(f" shards: {writer.shard_idx}")
|
| 448 |
-
print(f" meta: {meta_path}")
|
| 449 |
-
|
| 450 |
-
if args.test and writer.total_tokens < 1_000_000:
|
| 451 |
-
print(f"\nWARN: test mode produced only {writer.total_tokens:,} "
|
| 452 |
-
f"tokens β below 1M threshold.")
|
| 453 |
-
sys.exit(2)
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
if __name__ == "__main__":
|
| 457 |
-
main()
|
|
|
|
"""Download + tokenize instruction data for HYDRA SFT.

Writes int16 token shards to `data/sft/shard_XXX.bin` plus a
`data/sft/meta.json` with counts + special-token mapping.

Chat format (vocab's 4 reserved special tokens are repurposed):
    <BOS=8188> <|user|=8189>\n{instruction}\n{input?}\n <|assistant|=8190>\n
    {output}<|end|=8191>\n

Special-token IDs are constants derived from the tokenizer (they are the
last 4 IDs in an 8192-vocab). They are stored in meta.json for the SFT
script to read.

Sources (tried in order):
  1. yahma/alpaca-cleaned (~52K pairs via HF parquet auto-convert)
  2. databricks/databricks-dolly-15k (~15K pairs)
  3. Hard-coded 200 simple Q&A pairs (offline backup)

Usage:
    python scripts/download_sft_data.py            # full download
    python scripts/download_sft_data.py --test     # small smoke run
    python scripts/download_sft_data.py --offline  # skip network; use backup
"""

from __future__ import annotations

import argparse
import json
import os
import pickle
import sys
import time
from pathlib import Path

import numpy as np
import requests

# Make `prepare` and `hydra.*` importable when run as a script
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

CACHE_DIR = Path.home() / ".cache" / "autoresearch"
TOKENIZER_PKL = CACHE_DIR / "tokenizer" / "tokenizer.pkl"

SFT_DIR = _REPO_ROOT / "data" / "sft"
SFT_DIR.mkdir(parents=True, exist_ok=True)

# Reserved token repurposing - must match prepare.py SPECIAL_TOKENS list
# (indices 8188-8191 in the 8192-vocab BPE).
BOS_ID = 8188        # <|reserved_0|>
USER_ID = 8189       # <|reserved_1|>
ASSISTANT_ID = 8190  # <|reserved_2|>
END_ID = 8191        # <|reserved_3|>

# Shards are int16 arrays of packed token IDs.
TOKENS_PER_SHARD = 1_048_576  # ~2 MB per shard
DTYPE = np.int16              # vocab_size=8192 fits in int16

TARGET_TOKENS_DEFAULT = 15_000_000  # ~15M instruction tokens
TARGET_TOKENS_TEST = 1_500_000      # smoke run

# HuggingFace auto-parquet endpoint - one file for alpaca-cleaned
ALPACA_URL = (
    "https://huggingface.co/api/datasets/yahma/alpaca-cleaned/parquet/"
    "default/train/0.parquet"
)
DOLLY_URL = (
    "https://huggingface.co/api/datasets/databricks/databricks-dolly-15k/"
    "parquet/default/train/0.parquet"
)
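
# Quick sanity sketch (illustrative addition, not part of the original
# script): the dtype and shard-size choices above can be checked directly.
# int16 tops out at 32767, so the largest special ID (8191) fits, and a
# full shard of 1_048_576 two-byte tokens is exactly 2 MiB on disk.
assert END_ID <= np.iinfo(DTYPE).max
assert TOKENS_PER_SHARD * np.dtype(DTYPE).itemsize == 2 * 1024 * 1024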

# ---------------------------------------------------------------------------
# Offline backup Q&A pairs (used only if network unavailable)
# ---------------------------------------------------------------------------

_BACKUP_QA = [
    ("What is the capital of France?", "The capital of France is Paris."),
    ("What is the capital of Germany?", "The capital of Germany is Berlin."),
    ("What is the capital of Japan?", "The capital of Japan is Tokyo."),
    ("What is the capital of Italy?", "The capital of Italy is Rome."),
    ("What is the capital of Spain?", "The capital of Spain is Madrid."),
    ("What is the capital of England?", "The capital of England is London."),
    ("What is the capital of Canada?", "The capital of Canada is Ottawa."),
    ("What is the capital of Australia?", "The capital of Australia is Canberra."),
    ("What is 2 plus 2?", "Two plus two equals four."),
    ("What is 5 plus 5?", "Five plus five equals ten."),
    ("What is 3 times 4?", "Three times four equals twelve."),
    ("What is 10 minus 3?", "Ten minus three equals seven."),
    ("Name a color.", "Red is a color. Other colors include blue, green, and yellow."),
    ("Name a fruit.", "Apple is a fruit. Other fruits include banana, orange, and grape."),
    ("Name an animal.", "A dog is an animal. Other animals include cat, horse, and cow."),
    ("What do birds do?", "Birds fly through the air using their wings."),
    ("What do fish do?", "Fish swim in water using their fins and tails."),
    ("What is water made of?", "Water is made of hydrogen and oxygen."),
    ("What is the sun?", "The sun is a star at the center of our solar system."),
    ("What is the moon?", "The moon is Earth's natural satellite."),
    ("What is a cloud?", "A cloud is a collection of tiny water droplets in the sky."),
    ("What is rain?", "Rain is water that falls from clouds to the ground."),
    ("What is snow?", "Snow is frozen water that falls from clouds in cold weather."),
    ("What is wind?", "Wind is moving air."),
    ("What is fire?", "Fire is a chemical reaction that produces heat and light."),
    ("What is ice?", "Ice is frozen water."),
    ("Say hello.", "Hello! How can I help you today?"),
    ("Say goodbye.", "Goodbye! Have a great day."),
    ("Tell me a short story.", "Once upon a time, a brave rabbit lived in the forest. The end."),
    ("Tell me a joke.", "Why did the chicken cross the road? To get to the other side."),
    ("Who wrote Hamlet?", "William Shakespeare wrote the play Hamlet."),
    ("Who wrote Romeo and Juliet?", "William Shakespeare wrote Romeo and Juliet."),
    ("Who painted the Mona Lisa?", "Leonardo da Vinci painted the Mona Lisa."),
    ("When did World War 2 end?", "World War 2 ended in 1945."),
    ("What is gravity?", "Gravity is the force that pulls objects toward the Earth."),
    ("What is the speed of light?", "The speed of light is approximately 300,000 kilometers per second."),
    ("What is the largest planet?", "Jupiter is the largest planet in our solar system."),
    ("What is the smallest planet?", "Mercury is the smallest planet in our solar system."),
    ("At what temperature does water boil?", "Water boils at 100 degrees Celsius or 212 degrees Fahrenheit."),
    ("At what temperature does water freeze?", "Water freezes at 0 degrees Celsius or 32 degrees Fahrenheit."),
    ("How many legs does a spider have?", "A spider has eight legs."),
    ("How many legs does an insect have?", "An insect has six legs."),
    ("What do plants need to grow?", "Plants need sunlight, water, soil, and air to grow."),
    ("What do humans eat?", "Humans eat a variety of foods including fruits, vegetables, meat, and grains."),
    ("What is a book?", "A book is a collection of written or printed pages bound together."),
    ("What is a computer?", "A computer is an electronic device that processes information."),
    ("What is a phone?", "A phone is a device used to communicate with people at a distance."),
    ("What is music?", "Music is an arrangement of sounds that is pleasing to hear."),
    ("What is art?", "Art is the expression of human creativity and imagination."),
    ("What is a language?", "A language is a system of communication used by a group of people."),
]

# Duplicate to reach ~200 samples (each pair appears ~4x)
BACKUP_QA = (_BACKUP_QA * 4)[:200]


# ---------------------------------------------------------------------------
# Tokenizer loader
# ---------------------------------------------------------------------------

class _TokenizerWrapper:
    """Minimal wrapper around the pickled tiktoken.Encoding. We avoid
    importing `prepare.Tokenizer` to sidestep its side effects (which
    touch the running pretrain's cache files)."""

    def __init__(self, enc):
        self.enc = enc

    def encode(self, text: str) -> list[int]:
        return self.enc.encode_ordinary(text)

    @property
    def vocab_size(self) -> int:
        return self.enc.n_vocab


def load_tokenizer() -> _TokenizerWrapper:
    if not TOKENIZER_PKL.exists():
        raise FileNotFoundError(
            f"Tokenizer not found at {TOKENIZER_PKL}. Run `python prepare.py` "
            f"first."
        )
    with open(TOKENIZER_PKL, "rb") as f:
        enc = pickle.load(f)
    tok = _TokenizerWrapper(enc)
    assert tok.vocab_size == 8192, f"Expected vocab=8192, got {tok.vocab_size}"
    return tok


# ---------------------------------------------------------------------------
# Source downloaders
# ---------------------------------------------------------------------------

def _download_parquet(url: str, local_path: Path, timeout: int = 60) -> bool:
    """Stream-download a parquet file with retry. Returns True on success."""
    local_path.parent.mkdir(parents=True, exist_ok=True)
    tmp = local_path.with_suffix(local_path.suffix + ".tmp")
    for attempt in range(1, 4):
        try:
            with requests.get(url, stream=True, timeout=timeout,
                              allow_redirects=True) as r:
                r.raise_for_status()
                with open(tmp, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        if chunk:
                            f.write(chunk)
            tmp.replace(local_path)
            return True
        except Exception as e:
            print(f"  [net] attempt {attempt} failed: {e}", flush=True)
            for p in (tmp, local_path):
                try:
                    p.unlink()
                except FileNotFoundError:
                    pass
            time.sleep(2 ** attempt)
    return False


def _iter_alpaca(local_path: Path):
    """Yield (instruction, input, output) from alpaca-cleaned parquet."""
    import pyarrow.parquet as pq
    pf = pq.ParquetFile(str(local_path))
    for rg_idx in range(pf.num_row_groups):
        rg = pf.read_row_group(rg_idx)
        instr_col = rg.column("instruction").to_pylist()
        input_col = rg.column("input").to_pylist()
        output_col = rg.column("output").to_pylist()
        for instruction, input_text, output in zip(instr_col, input_col, output_col):
            if instruction and output:
                yield instruction, (input_text or ""), output


def _iter_dolly(local_path: Path):
    """Yield (instruction, input, output) from dolly-15k parquet."""
    import pyarrow.parquet as pq
    pf = pq.ParquetFile(str(local_path))
    # Schema: instruction, context, response, category
    for rg_idx in range(pf.num_row_groups):
        rg = pf.read_row_group(rg_idx)
        cols = {n: rg.column(n).to_pylist() for n in rg.schema.names}
        instr_col = cols.get("instruction") or cols.get("Instruction")
        ctx_col = cols.get("context") or cols.get("Context") or [""] * len(instr_col)
        resp_col = cols.get("response") or cols.get("Response")
        for instruction, context, response in zip(instr_col, ctx_col, resp_col):
            if instruction and response:
                yield instruction, (context or ""), response


def _iter_backup():
    for q, a in BACKUP_QA:
        yield q, "", a


# ---------------------------------------------------------------------------
# Encoding
# ---------------------------------------------------------------------------

def encode_example(tok: _TokenizerWrapper, instruction: str,
                   input_text: str, output: str) -> list[int]:
    """Serialize one instruction/response pair into a flat token list.

    Format:
        <BOS> <|user|> \\n {instr}\\n[{input}\\n] <|assistant|> \\n {output} <|end|> \\n
    """
    ids: list[int] = [BOS_ID, USER_ID]
    ids += tok.encode("\n" + instruction.strip())
    if input_text and input_text.strip():
        ids += tok.encode("\n" + input_text.strip())
    ids += tok.encode("\n")
    ids.append(ASSISTANT_ID)
    ids += tok.encode("\n" + output.strip())
    ids.append(END_ID)
    ids += tok.encode("\n")
    return ids


def encode_example_with_mask(tok: _TokenizerWrapper, instruction: str,
                             input_text: str, output: str
                             ) -> tuple[list[int], list[int]]:
    """Return (tokens, mask) where mask[i]=1 means 'compute loss on token i'
    and mask[i]=0 means 'prompt, ignore'. The boundary is the <|assistant|>
    token: the assistant response (and <|end|>) contribute to loss; the
    user prompt does not."""
    prompt_ids = [BOS_ID, USER_ID] + tok.encode("\n" + instruction.strip())
    if input_text and input_text.strip():
        prompt_ids += tok.encode("\n" + input_text.strip())
    prompt_ids += tok.encode("\n")
    prompt_ids.append(ASSISTANT_ID)

    response_ids = tok.encode("\n" + output.strip())
    response_ids.append(END_ID)
    response_ids += tok.encode("\n")

    ids = prompt_ids + response_ids
    mask = [0] * len(prompt_ids) + [1] * len(response_ids)
    return ids, mask
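
# Illustrative sketch (not part of the original script): the mask boundary
# produced by encode_example_with_mask, exercised with a stub tokenizer.
# `_StubTok` is hypothetical; the real BPE yields different IDs, but the
# 0/1 split at the <|assistant|> token is the same.
class _StubTok:
    def encode(self, text: str) -> list[int]:
        return [ord(c) % 100 for c in text]

_demo_ids, _demo_mask = encode_example_with_mask(_StubTok(), "Say hi.", "", "Hi!")
assert _demo_ids[_demo_mask.index(1) - 1] == ASSISTANT_ID   # last prompt token
assert sum(_demo_mask) == len(_demo_ids) - _demo_mask.index(1)  # loss = response only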

# ---------------------------------------------------------------------------
# Shard writer
# ---------------------------------------------------------------------------

class ShardWriter:
    """Writes two parallel files per shard:
        data/sft/shard_XXX.bin -> token IDs (int16)
        data/sft/mask_XXX.bin  -> 0/1 loss mask (uint8)

    Packs one example after another with no padding. At runtime, SFT builds
    sequences of length MAX_SEQ_LEN by slicing across these flat arrays.
    """

    def __init__(self, out_dir: Path, tokens_per_shard: int = TOKENS_PER_SHARD):
        self.out_dir = out_dir
        self.tokens_per_shard = tokens_per_shard
        self.shard_idx = 0
        self._buf_tok: list[int] = []
        self._buf_mask: list[int] = []
        self.total_tokens = 0

    def add(self, tokens: list[int], mask: list[int]):
        assert len(tokens) == len(mask)
        self._buf_tok.extend(tokens)
        self._buf_mask.extend(mask)
        self.total_tokens += len(tokens)
        while len(self._buf_tok) >= self.tokens_per_shard:
            self._flush_one(self.tokens_per_shard)

    def _flush_one(self, n: int):
        tok_path = self.out_dir / f"shard_{self.shard_idx:04d}.bin"
        mask_path = self.out_dir / f"mask_{self.shard_idx:04d}.bin"
        arr_tok = np.array(self._buf_tok[:n], dtype=DTYPE)
        arr_mask = np.array(self._buf_mask[:n], dtype=np.uint8)
        arr_tok.tofile(tok_path)
        arr_mask.tofile(mask_path)
        self._buf_tok = self._buf_tok[n:]
        self._buf_mask = self._buf_mask[n:]
        print(f"  wrote {tok_path.name} ({n:,} tokens)", flush=True)
        self.shard_idx += 1

    def finalize(self):
        if self._buf_tok:
            self._flush_one(len(self._buf_tok))
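
# Illustrative reader sketch (not part of the original script): one way a
# consumer could slice fixed-length training windows out of the flat
# shard/mask pair that ShardWriter produces. `seq_len` stands in for the
# SFT script's MAX_SEQ_LEN; the actual loader may differ.
def _iter_windows(shard_path: Path, mask_path: Path, seq_len: int = 512):
    toks = np.fromfile(shard_path, dtype=DTYPE)
    mask = np.fromfile(mask_path, dtype=np.uint8)
    for start in range(0, len(toks) - seq_len + 1, seq_len):
        yield toks[start:start + seq_len], mask[start:start + seq_len]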

# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--test", action="store_true",
                    help="Small smoke run: write ~1.5M tokens and exit.")
    ap.add_argument("--offline", action="store_true",
                    help="Skip network, use hard-coded backup only.")
    ap.add_argument("--target-tokens", type=int, default=None,
                    help="Override target token count.")
    args = ap.parse_args()

    target = args.target_tokens or (
        TARGET_TOKENS_TEST if args.test else TARGET_TOKENS_DEFAULT
    )

    print(f"SFT_DIR: {SFT_DIR}")
    print(f"Target tokens: {target:,}")
    print(f"Offline mode: {args.offline}")

    # Clear any prior shards
    for p in SFT_DIR.glob("shard_*.bin"):
        p.unlink()
    for p in SFT_DIR.glob("mask_*.bin"):
        p.unlink()

    tok = load_tokenizer()
    print(f"Tokenizer vocab: {tok.vocab_size}")
    print(f"Special tokens: BOS={BOS_ID} USER={USER_ID} "
          f"ASSISTANT={ASSISTANT_ID} END={END_ID}")

    sources = []  # list of (name, iterator_fn)
    if not args.offline:
        alpaca_path = SFT_DIR / "alpaca_raw.parquet"
        print(f"\n[src] downloading alpaca-cleaned -> {alpaca_path.name} ...")
        if _download_parquet(ALPACA_URL, alpaca_path):
            print(f"  ok ({alpaca_path.stat().st_size // (1 << 20)} MiB)")
            sources.append(("alpaca-cleaned", lambda: _iter_alpaca(alpaca_path)))
        else:
            print("  alpaca download FAILED, trying dolly...")
            dolly_path = SFT_DIR / "dolly_raw.parquet"
            if _download_parquet(DOLLY_URL, dolly_path):
                print(f"  ok ({dolly_path.stat().st_size // (1 << 20)} MiB)")
                sources.append(("dolly-15k", lambda: _iter_dolly(dolly_path)))

    # Always include backup - cheap, catches tail
    sources.append(("backup-200", _iter_backup))

    if not sources:
        print("FATAL: no data sources available.", file=sys.stderr)
        sys.exit(1)

    # Stream-encode
    writer = ShardWriter(SFT_DIR)
    n_examples = 0
    n_assistant_tokens = 0
    source_counts = {}

    for src_name, src_fn in sources:
        print(f"\n[src] encoding {src_name} ...")
        src_examples = 0
        src_tokens = 0
        for (instruction, input_text, output) in src_fn():
            # Truncate overly long outputs - the 7.5M model can't use them
            if len(output) > 2000:
                output = output[:2000]
            ids, mask = encode_example_with_mask(tok, instruction,
                                                 input_text, output)
            if len(ids) < 4 or len(ids) > 512:
                # Skip degenerate / too-long examples
                continue
            writer.add(ids, mask)
            n_examples += 1
            src_examples += 1
            src_tokens += len(ids)
            n_assistant_tokens += sum(mask)
            if writer.total_tokens >= target:
                break
        source_counts[src_name] = {
            "examples": src_examples,
            "tokens": src_tokens,
        }
        print(f"  {src_name}: {src_examples:,} examples, {src_tokens:,} tokens")
        if writer.total_tokens >= target:
            break

    writer.finalize()

    meta = {
        "total_tokens": writer.total_tokens,
        "total_examples": n_examples,
        "assistant_tokens_in_loss": n_assistant_tokens,
        "num_shards": writer.shard_idx,
        "tokens_per_shard": TOKENS_PER_SHARD,
        "dtype": "int16",
        "vocab_size": tok.vocab_size,
        "special_tokens": {
            "bos": BOS_ID,
            "user": USER_ID,
            "assistant": ASSISTANT_ID,
            "end": END_ID,
        },
        "sources": source_counts,
        "format_hint": (
            "<BOS><|user|>\\n{instr}\\n[{input}\\n]<|assistant|>\\n"
            "{output}<|end|>\\n"
        ),
    }
    meta_path = SFT_DIR / "meta.json"
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)

    print(f"\n===== SFT data ready =====")
    print(f"  examples:     {n_examples:,}")
    print(f"  total tokens: {writer.total_tokens:,}")
    print(f"  loss tokens:  {n_assistant_tokens:,}")
    print(f"  shards:       {writer.shard_idx}")
    print(f"  meta:         {meta_path}")

    if args.test and writer.total_tokens < 1_000_000:
        print(f"\nWARN: test mode produced only {writer.total_tokens:,} "
              f"tokens - below 1M threshold.")
        sys.exit(2)


if __name__ == "__main__":
    main()
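
A minimal sketch of how the 0/1 mask written above would enter the SFT objective: per-token cross-entropy is computed everywhere, then averaged only over positions where the mask is 1, so prompt tokens carry no gradient. The function below is illustrative; the name is hypothetical and the actual training step (in overlay/scripts/sft.py) may differ.

import torch
import torch.nn.functional as F

def masked_lm_loss(logits: torch.Tensor, targets: torch.Tensor,
                   loss_mask: torch.Tensor) -> torch.Tensor:
    # [B, T, V] logits vs [B, T] targets; loss_mask is the 0/1 uint8 mask.
    per_tok = F.cross_entropy(
        logits.view(-1, logits.size(-1)), targets.view(-1), reduction="none"
    )
    m = loss_mask.view(-1).float()
    return (per_tok * m).sum() / m.sum().clamp(min=1)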
overlay/scripts/eval_quality.py
CHANGED
@@ -1,525 +1,525 @@
#!/usr/bin/env python3
"""Comprehensive quality evaluation harness for HYDRA.

Computes: PPL, BLEU-1, BLEU-4, ROUGE-1, ROUGE-L, factual accuracy,
coherence metrics (distinct-2, repetition-rate, self-BLEU), and a
composite quality_score.

Usage:
    python scripts/eval_quality.py                       # eval latest model
    python scripts/eval_quality.py --checkpoint ckpt.pt  # eval from checkpoint

All metrics printed as key=value (grep-friendly). Runs in <30s on RTX 3060.
"""

from __future__ import annotations

import math
import os
import sys
import time
from collections import Counter
from typing import Optional

# Ensure project root is on path
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

import torch
import torch.nn.functional as F

from hydra.config import (
    D_MODEL, D_STATE, DEVICE_BATCH_SIZE, ENGRAM_KEY_DIM,
    ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND, HEADDIM,
    N_HEADS, N_LAYER, PostSemClawConfig,
)
from hydra.eval import FACTUAL_EVAL
from prepare import MAX_SEQ_LEN, Tokenizer, evaluate_bpb

# ---------------------------------------------------------------------------
# Eval prompts (hardcoded for reproducibility)
# ---------------------------------------------------------------------------

EVAL_PROMPTS = [
    "The capital of France is",
    "In 1969, humans first",
    "Water boils at a temperature of",
    "The theory of relativity was developed by",
    "The largest planet in our solar system is",
    "Photosynthesis is the process by which",
    "The stock market crashed in",
    "DNA stands for",
    "The speed of light is approximately",
    "Shakespeare wrote the play",
    "The mitochondria is often called the",
    "In computer science, an algorithm is",
    "The chemical symbol for gold is",
    "The Great Wall of China was built to",
    "Gravity is a force that",
    "The human heart pumps blood through",
    "The Amazon rainforest is located in",
    "Pi is approximately equal to",
    "The first President of the United States was",
    "Oxygen makes up approximately",
]

# Reference continuations (approximate, for BLEU/ROUGE)
EVAL_REFERENCES = [
    "Paris, which is also the largest city in France.",
    "landed on the Moon during the Apollo 11 mission.",
    "100 degrees Celsius or 212 degrees Fahrenheit at standard atmospheric pressure.",
    "Albert Einstein in the early twentieth century.",
    "Jupiter, which is a gas giant.",
    "plants convert sunlight into chemical energy and produce oxygen.",
    "1929, leading to the Great Depression.",
    "deoxyribonucleic acid, which carries genetic information.",
    "299,792 kilometers per second in a vacuum.",
    "Romeo and Juliet, one of the most famous tragedies.",
    "powerhouse of the cell because it produces energy.",
    "a step by step procedure for solving a problem.",
    "Au, from the Latin word aurum.",
    "protect against invasions from the north.",
    "attracts objects with mass toward each other.",
    "the circulatory system to deliver oxygen and nutrients.",
    "South America, primarily within Brazil.",
    "3.14159, and it represents the ratio of circumference to diameter.",
    "George Washington, who served from 1789 to 1797.",
    "21 percent of the Earth's atmosphere.",
]

COHERENCE_PROMPTS = [
    "The history of science shows that",
    "In modern society, technology has",
    "The relationship between education and",
    "Climate change is affecting the world because",
    "The development of artificial intelligence has led to",
    "Throughout human history, art has been",
    "The economy of a nation depends on",
    "Medical research has shown that",
    "The role of government in society is",
    "The ocean covers more than",
]


# ---------------------------------------------------------------------------
# Manual BLEU implementation (no nltk dependency)
# ---------------------------------------------------------------------------

def _get_ngrams(tokens: list[str], n: int) -> Counter:
    """Extract n-gram counts from token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def _modified_precision(reference_tokens: list[str], hypothesis_tokens: list[str], n: int) -> tuple[int, int]:
    """Compute modified precision for n-grams."""
    ref_ngrams = _get_ngrams(reference_tokens, n)
    hyp_ngrams = _get_ngrams(hypothesis_tokens, n)
    clipped_count = 0
    total_count = 0
    for ngram, count in hyp_ngrams.items():
        clipped_count += min(count, ref_ngrams.get(ngram, 0))
        total_count += count
    return clipped_count, max(total_count, 1)


def compute_bleu(references: list[list[str]], hypotheses: list[list[str]], max_n: int = 4) -> dict[str, float]:
    """Corpus-level BLEU-1 through BLEU-max_n.

    Uses brevity penalty and geometric mean of modified precisions.
    """
    precisions = []
    for n in range(1, max_n + 1):
        total_clip = 0
        total_count = 0
        for ref, hyp in zip(references, hypotheses):
            clip, count = _modified_precision(ref, hyp, n)
            total_clip += clip
            total_count += count
        precisions.append(total_clip / max(total_count, 1))

    # Brevity penalty
    ref_len = sum(len(r) for r in references)
    hyp_len = sum(len(h) for h in hypotheses)
    if hyp_len == 0:
        return {f"bleu{n}": 0.0 for n in range(1, max_n + 1)}
    bp = math.exp(min(0, 1 - ref_len / hyp_len))

    result = {}
    for n in range(1, max_n + 1):
        # Geometric mean of precisions 1..n
        log_avg = sum(math.log(max(p, 1e-10)) for p in precisions[:n]) / n
        result[f"bleu{n}"] = bp * math.exp(log_avg)
    return result
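
# Worked example (illustrative, not part of the original file): an exact
# match scores 1.0, while a short hypothesis is pulled down by the brevity
# penalty even when every unigram matches.
_bleu_ref = ["the", "cat", "sat", "on", "the", "mat"]
assert abs(compute_bleu([_bleu_ref], [_bleu_ref])["bleu4"] - 1.0) < 1e-9
assert compute_bleu([_bleu_ref], [["the", "cat"]])["bleu1"] < 1.0  # bp = exp(1 - 6/2)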

# ---------------------------------------------------------------------------
# Manual ROUGE implementation (no rouge_score dependency)
# ---------------------------------------------------------------------------

def _lcs_length(x: list[str], y: list[str]) -> int:
    """Longest common subsequence length via DP."""
    m, n = len(x), len(y)
    if m == 0 or n == 0:
        return 0
    # Space-optimized: only keep current and previous row
    prev = [0] * (n + 1)
    curr = [0] * (n + 1)
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if x[i - 1] == y[j - 1]:
                curr[j] = prev[j - 1] + 1
            else:
                curr[j] = max(prev[j], curr[j - 1])
        prev, curr = curr, [0] * (n + 1)
    return prev[n]


def compute_rouge(references: list[list[str]], hypotheses: list[list[str]]) -> dict[str, float]:
    """Compute ROUGE-1 (unigram F1) and ROUGE-L (LCS-based F1)."""
    rouge1_scores = []
    rougel_scores = []

    for ref, hyp in zip(references, hypotheses):
        if not ref or not hyp:
            rouge1_scores.append(0.0)
            rougel_scores.append(0.0)
            continue

        # ROUGE-1: unigram overlap
        ref_unigrams = Counter(ref)
        hyp_unigrams = Counter(hyp)
        overlap = sum((ref_unigrams & hyp_unigrams).values())
        r1_precision = overlap / max(len(hyp), 1)
        r1_recall = overlap / max(len(ref), 1)
        r1_f1 = 2 * r1_precision * r1_recall / max(r1_precision + r1_recall, 1e-10)
        rouge1_scores.append(r1_f1)

        # ROUGE-L: LCS-based
        lcs = _lcs_length(ref, hyp)
        rl_precision = lcs / max(len(hyp), 1)
        rl_recall = lcs / max(len(ref), 1)
        rl_f1 = 2 * rl_precision * rl_recall / max(rl_precision + rl_recall, 1e-10)
        rougel_scores.append(rl_f1)

    return {
        "rouge1": sum(rouge1_scores) / max(len(rouge1_scores), 1),
        "rouge_l": sum(rougel_scores) / max(len(rougel_scores), 1),
    }
| 210 |
+
|
| 211 |
+
# ---------------------------------------------------------------------------
|
| 212 |
+
# Greedy generation
|
| 213 |
+
# ---------------------------------------------------------------------------
|
| 214 |
+
|
| 215 |
+
@torch.no_grad()
|
| 216 |
+
def greedy_generate(model, tokenizer, prompt: str, max_new_tokens: int = 32, device: str = "cuda") -> str:
|
| 217 |
+
"""Greedy (argmax) autoregressive generation. Deterministic."""
|
| 218 |
+
ids = tokenizer.encode(prompt)
|
| 219 |
+
x = torch.tensor([ids], device=device, dtype=torch.long)
|
| 220 |
+
|
| 221 |
+
for _ in range(max_new_tokens):
|
| 222 |
+
logits = model(x, targets=None)
|
| 223 |
+
if logits.dim() == 3:
|
| 224 |
+
next_logits = logits[0, -1, :]
|
| 225 |
+
else:
|
| 226 |
+
next_logits = logits[0]
|
| 227 |
+
next_id = next_logits.argmax().unsqueeze(0).unsqueeze(0)
|
| 228 |
+
x = torch.cat([x, next_id], dim=1)
|
| 229 |
+
if x.size(1) >= MAX_SEQ_LEN:
|
| 230 |
+
break
|
| 231 |
+
|
| 232 |
+
all_ids = x[0].tolist()
|
| 233 |
+
return tokenizer.decode(all_ids[len(ids):])
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
# ---------------------------------------------------------------------------
|
| 237 |
+
# Coherence metrics
|
| 238 |
+
# ---------------------------------------------------------------------------
|
| 239 |
+
|
| 240 |
+
def compute_coherence(generations: list[str]) -> dict[str, float]:
|
| 241 |
+
"""Compute distinct-2, repetition rate, and self-BLEU across generations."""
|
| 242 |
+
all_bigrams = []
|
| 243 |
+
all_fourgrams = []
|
| 244 |
+
tokenized_gens = []
|
| 245 |
+
|
| 246 |
+
for gen in generations:
|
| 247 |
+
tokens = gen.lower().split()
|
| 248 |
+
tokenized_gens.append(tokens)
|
| 249 |
+
bigrams = [tuple(tokens[i:i + 2]) for i in range(len(tokens) - 1)]
|
| 250 |
+
fourgrams = [tuple(tokens[i:i + 4]) for i in range(len(tokens) - 3)]
|
| 251 |
+
all_bigrams.extend(bigrams)
|
| 252 |
+
all_fourgrams.extend(fourgrams)
|
| 253 |
+
|
| 254 |
+
# Distinct-2: fraction of unique bigrams
|
| 255 |
+
distinct2 = len(set(all_bigrams)) / max(len(all_bigrams), 1)
|
| 256 |
+
|
| 257 |
+
# Repetition rate: fraction of 4-grams that appear more than once
|
| 258 |
+
fourgram_counts = Counter(all_fourgrams)
|
| 259 |
+
repeated = sum(1 for c in fourgram_counts.values() if c > 1)
|
| 260 |
+
repetition_rate = repeated / max(len(fourgram_counts), 1)
|
| 261 |
+
|
| 262 |
+
# Self-BLEU: average BLEU of each generation against all others
|
| 263 |
+
# Lower = more diverse
|
| 264 |
+
self_bleu_scores = []
|
| 265 |
+
for i, hyp in enumerate(tokenized_gens):
|
| 266 |
+
if not hyp:
|
| 267 |
+
continue
|
| 268 |
+
others = [g for j, g in enumerate(tokenized_gens) if j != i and g]
|
| 269 |
+
if not others:
|
| 270 |
+
continue
|
| 271 |
+
# Average BLEU against each other generation
|
| 272 |
+
pair_scores = []
|
| 273 |
+
for ref in others:
|
| 274 |
+
result = compute_bleu([ref], [hyp], max_n=4)
|
| 275 |
+
pair_scores.append(result.get("bleu4", 0.0))
|
| 276 |
+
self_bleu_scores.append(sum(pair_scores) / len(pair_scores))
|
| 277 |
+
|
| 278 |
+
self_bleu = sum(self_bleu_scores) / max(len(self_bleu_scores), 1)
|
| 279 |
+
|
| 280 |
+
return {
|
| 281 |
+
"distinct2": distinct2,
|
| 282 |
+
"repetition_rate": repetition_rate,
|
| 283 |
+
"self_bleu": self_bleu,
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# ---------------------------------------------------------------------------
|
| 288 |
+
# Factual accuracy (reuse existing probes)
|
| 289 |
+
# ---------------------------------------------------------------------------
|
| 290 |
+
|
| 291 |
+
def compute_factual(model, tokenizer, device: str = "cuda") -> float:
|
| 292 |
+
"""Run factual eval probes, return accuracy [0,1]."""
|
| 293 |
+
model.eval()
|
| 294 |
+
hits = 0
|
| 295 |
+
|
| 296 |
+
with torch.no_grad(), torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 297 |
+
for prompt, answers in FACTUAL_EVAL:
|
| 298 |
+
ids = tokenizer.encode(prompt)
|
| 299 |
+
x = torch.tensor([ids], device=device, dtype=torch.long)
|
| 300 |
+
logits = model(x, targets=None)
|
| 301 |
+
if logits.dim() == 3:
|
| 302 |
+
last_logits = logits[0, -1, :]
|
| 303 |
+
else:
|
| 304 |
+
last_logits = logits[0]
|
| 305 |
+
|
| 306 |
+
probs = torch.softmax(last_logits.float(), dim=-1)
|
| 307 |
+
top_k = min(20, probs.shape[-1])
|
| 308 |
+
top_ids = torch.topk(probs, top_k).indices.tolist()
|
| 309 |
+
top_tokens = [tokenizer.decode([tid]).strip().lower() for tid in top_ids]
|
| 310 |
+
answers_lower = [a.lower() for a in answers]
|
| 311 |
+
if any(any(a in tok for a in answers_lower) for tok in top_tokens):
|
| 312 |
+
hits += 1
|
| 313 |
+
|
| 314 |
+
return hits / max(len(FACTUAL_EVAL), 1)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# ---------------------------------------------------------------------------
|
| 318 |
+
# PPL (perplexity) via existing evaluate_bpb
|
| 319 |
+
# ---------------------------------------------------------------------------
|
| 320 |
+
|
| 321 |
+
def compute_ppl(model, tokenizer, batch_size: int = 8) -> tuple[float, float]:
|
| 322 |
+
"""Compute BPB and PPL. Returns (bpb, ppl)."""
|
| 323 |
+
import prepare as _prepare_mod
|
| 324 |
+
# Use smaller eval set for speed (<30s budget)
|
| 325 |
+
orig_eval = _prepare_mod.EVAL_TOKENS
|
| 326 |
+
_prepare_mod.EVAL_TOKENS = 2 * 524288 # ~1M tokens, fast
|
| 327 |
+
try:
|
| 328 |
+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 329 |
+
bpb = evaluate_bpb(model, tokenizer, batch_size)
|
| 330 |
+
finally:
|
| 331 |
+
_prepare_mod.EVAL_TOKENS = orig_eval
|
| 332 |
+
ppl = 2 ** bpb
|
| 333 |
+
return bpb, ppl
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# ---------------------------------------------------------------------------
|
| 337 |
+
# Composite quality score
|
| 338 |
+
# ---------------------------------------------------------------------------
|
| 339 |
+
|
| 340 |
+
def compute_quality_score(ppl: float, bleu4: float, rouge_l: float,
|
| 341 |
+
factual: float, repetition_rate: float) -> float:
|
| 342 |
+
"""Single composite metric for autoresearch optimization.
|
| 343 |
+
|
| 344 |
+
Formula rationale:
|
| 345 |
+
- PPL (30%): Primary language modeling metric, capped at 100
|
| 346 |
+
- BLEU-4 (20%): Generation quality vs references
|
| 347 |
+
- ROUGE-L (20%): Recall of reference content
|
| 348 |
+
- Factual (15%): Knowledge memorization
|
| 349 |
+
- 1-repetition (15%): Diversity/coherence
|
| 350 |
+
"""
|
| 351 |
+
return (
|
| 352 |
+
0.3 * (1 - min(ppl, 100) / 100) +
|
| 353 |
+
0.2 * bleu4 +
|
| 354 |
+
0.2 * rouge_l +
|
| 355 |
+
0.15 * factual +
|
| 356 |
+
0.15 * (1 - repetition_rate)
|
| 357 |
+
)
|
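
# Worked example (illustrative): ppl=40, bleu4=0.25, rouge_l=0.30, factual=0.4,
# repetition_rate=0.1 gives
#   0.3*(1 - 40/100) + 0.2*0.25 + 0.2*0.30 + 0.15*0.4 + 0.15*(1 - 0.1)
#   = 0.18 + 0.05 + 0.06 + 0.06 + 0.135 = 0.485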

# ---------------------------------------------------------------------------
# Main evaluation entry point
# ---------------------------------------------------------------------------

def run_quality_eval(
    model: torch.nn.Module,
    tokenizer,
    device: str = "cuda",
    batch_size: int = 8,
    verbose: bool = True,
) -> dict[str, float]:
    """Run full quality evaluation suite. Returns dict of all metrics."""
    model.eval()
    results: dict[str, float] = {}

    t0 = time.time()

    # 1. PPL / BPB
    if verbose:
        print("[eval] Computing PPL/BPB...", flush=True)
    bpb, ppl = compute_ppl(model, tokenizer, batch_size)
    results["bpb"] = bpb
    results["ppl"] = ppl

    # 2. Generate continuations for BLEU/ROUGE
    if verbose:
        print("[eval] Generating continuations (20 prompts, greedy)...", flush=True)
    hypotheses_text = []
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        for prompt in EVAL_PROMPTS:
            gen = greedy_generate(model, tokenizer, prompt, max_new_tokens=32, device=device)
            hypotheses_text.append(gen)

    # Tokenize for BLEU/ROUGE (simple whitespace split)
    ref_tokens = [ref.lower().split() for ref in EVAL_REFERENCES]
    hyp_tokens = [hyp.lower().split() for hyp in hypotheses_text]

    # 3. BLEU
    if verbose:
        print("[eval] Computing BLEU...", flush=True)
    bleu = compute_bleu(ref_tokens, hyp_tokens, max_n=4)
    results["bleu1"] = bleu["bleu1"]
    results["bleu4"] = bleu["bleu4"]

    # 4. ROUGE
    if verbose:
        print("[eval] Computing ROUGE...", flush=True)
    rouge = compute_rouge(ref_tokens, hyp_tokens)
    results["rouge1"] = rouge["rouge1"]
    results["rouge_l"] = rouge["rouge_l"]

    # 5. Factual accuracy
    if verbose:
        print("[eval] Computing factual accuracy...", flush=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        factual = compute_factual(model, tokenizer, device)
    results["factual"] = factual

    # 6. Coherence
    if verbose:
        print("[eval] Generating coherence passages (10 prompts, 64 tokens)...", flush=True)
    coherence_gens = []
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        for prompt in COHERENCE_PROMPTS:
            gen = greedy_generate(model, tokenizer, prompt, max_new_tokens=64, device=device)
            coherence_gens.append(gen)

    coherence = compute_coherence(coherence_gens)
    results["distinct2"] = coherence["distinct2"]
    results["repetition_rate"] = coherence["repetition_rate"]
    results["self_bleu"] = coherence["self_bleu"]

    # 7. Composite score
    results["quality_score"] = compute_quality_score(
        ppl=results["ppl"],
        bleu4=results["bleu4"],
        rouge_l=results["rouge_l"],
        factual=results["factual"],
        repetition_rate=results["repetition_rate"],
    )

    elapsed = time.time() - t0
    results["eval_time_s"] = elapsed

    # Print all metrics
    if verbose:
        print("\n--- Quality Evaluation Results ---")
        for k, v in sorted(results.items()):
            print(f"{k}={v:.6f}")
        print("--- End Quality Evaluation ---\n")

        # Print sample generations
        print("--- Sample Generations ---")
        for i, (prompt, gen) in enumerate(zip(EVAL_PROMPTS[:5], hypotheses_text[:5])):
            print(f' [{i}] "{prompt}" -> "{gen.strip()[:80]}"')
        print("--- End Sample Generations ---\n")

        print("--- Coherence Samples ---")
        for i, (prompt, gen) in enumerate(zip(COHERENCE_PROMPTS[:3], coherence_gens[:3])):
            print(f' [{i}] "{prompt}" -> "{gen.strip()[:100]}"')
        print("--- End Coherence Samples ---\n")

    return results


# ---------------------------------------------------------------------------
# Standalone CLI
# ---------------------------------------------------------------------------

def _build_model_and_tokenizer(checkpoint: Optional[str] = None):
    """Build model + tokenizer, optionally loading from checkpoint."""
    from hydra.model import PostSemClawModel

    device = torch.device("cuda")
    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()

    config = PostSemClawConfig(
        sequence_len=MAX_SEQ_LEN,
        vocab_size=vocab_size,
        n_layer=N_LAYER,
        d_model=D_MODEL,
        d_state=D_STATE,
        headdim=HEADDIM,
        n_heads=N_HEADS,
        expand=EXPAND,
        engram_n_columns=ENGRAM_N_COLUMNS,
        engram_key_dim=ENGRAM_KEY_DIM,
        engram_layer_idx=ENGRAM_LAYER_IDX,
    )

    with torch.device("meta"):
        model = PostSemClawModel(config)
    model.to_empty(device=device)

    if checkpoint and os.path.exists(checkpoint):
        print(f"[eval] Loading checkpoint: {checkpoint}")
        state = torch.load(checkpoint, map_location=device, weights_only=True)
        model.load_state_dict(state, strict=False)
    else:
        print("[eval] No checkpoint - using freshly initialized weights")
        model.init_weights()

    model.eval()
    return model, tokenizer, device


def main():
    import argparse
    parser = argparse.ArgumentParser(description="HYDRA quality evaluation")
    parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint")
    parser.add_argument("--batch-size", type=int, default=DEVICE_BATCH_SIZE, help="Batch size for PPL eval")
    args = parser.parse_args()

    model, tokenizer, device = _build_model_and_tokenizer(args.checkpoint)
    results = run_quality_eval(model, tokenizer, str(device), args.batch_size, verbose=True)

    # Final summary line (grep-friendly)
    print(f"QUALITY_SCORE={results['quality_score']:.6f} PPL={results['ppl']:.3f} "
          f"BPB={results['bpb']:.4f} BLEU4={results['bleu4']:.4f} "
          f"ROUGE_L={results['rouge_l']:.4f} FACTUAL={results['factual']:.4f} "
          f"REP_RATE={results['repetition_rate']:.4f}")


if __name__ == "__main__":
    main()
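
The metric helpers above are pure Python and can be smoke-tested on toy token lists before wiring a model in. A minimal sketch (assumes eval_quality.py imports cleanly, i.e. torch and the repo's prepare/train helpers resolve):

    from eval_quality import compute_bleu, compute_rouge, compute_quality_score

    refs = [["the", "cat", "sat", "on", "the", "mat"]]
    hyps = [["the", "cat", "sat", "on", "a", "mat"]]

    bleu = compute_bleu(refs, hyps, max_n=4)   # bleu1 = 5/6; bleu4 lower (geometric mean)
    rouge = compute_rouge(refs, hyps)          # rouge1 and rouge_l both ~0.833 here
    score = compute_quality_score(
        ppl=40.0,
        bleu4=bleu["bleu4"],
        rouge_l=rouge["rouge_l"],
        factual=0.5,
        repetition_rate=0.1,
    )
    print(f"bleu1={bleu['bleu1']:.3f} rouge_l={rouge['rouge_l']:.3f} score={score:.3f}")

The full suite runs against a checkpoint with "python scripts/eval_quality.py --checkpoint <path>" and ends with the grep-friendly QUALITY_SCORE summary line.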
overlay/scripts/fetch_corpus.py
CHANGED
|
@@ -1,211 +1,211 @@
"""
Fetch additional training shards from karpathy/climbmix-400b-shuffle.

The repo already has ~500 shards (~31B tokens). This script is a
resumable, parallel downloader for cases where more shards are needed
(e.g., multi-day training, experiments requiring fresh-unseen data,
or when we want to split the corpus across processes).

Usage:
    # Fetch shards up to index 600 (total cap)
    python scripts/fetch_corpus.py --target-shards 600

    # Fetch a specific range
    python scripts/fetch_corpus.py --start 500 --end 800

    # Dry-run (list what would be downloaded)
    python scripts/fetch_corpus.py --target-shards 600 --dry-run

Notes:
    - Safe to run while training is active; only writes files not touched
      by the training process.
    - Resumable: skips shards already on disk.
    - Downloads to the same DATA_DIR used by prepare.py so they're picked
      up on next training launch.
"""
from __future__ import annotations

import argparse
import os
import shutil
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import requests

REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

from prepare import BASE_URL, DATA_DIR, MAX_SHARD, VAL_SHARD  # noqa: E402


def human_bytes(n: int) -> str:
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if n < 1024:
            return f"{n:.1f}{unit}"
        n /= 1024
    return f"{n:.1f}PB"


def download_one(
    index: int, data_dir: str, timeout: int = 30, max_attempts: int = 5
) -> tuple[int, bool, int, str]:
    """
    Download a single parquet shard. Resumable + retry with exponential backoff.
    Returns (index, success, bytes_written, message).
    """
    filename = f"shard_{index:05d}.parquet"
    filepath = os.path.join(data_dir, filename)
    tmp_path = filepath + ".tmp"

    if os.path.exists(filepath):
        return index, True, 0, "already-present"

    url = f"{BASE_URL}/{filename}"
    for attempt in range(1, max_attempts + 1):
        try:
            with requests.get(url, stream=True, timeout=timeout) as r:
                r.raise_for_status()
                bytes_written = 0
                with open(tmp_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        if chunk:
                            f.write(chunk)
                            bytes_written += len(chunk)
            os.rename(tmp_path, filepath)
            return index, True, bytes_written, f"ok (attempt {attempt})"
        except (requests.RequestException, OSError) as e:
            # Clean up partial file.
            for p in (tmp_path, filepath):
                if os.path.exists(p):
                    try:
                        os.remove(p)
                    except OSError:
                        pass
            if attempt < max_attempts:
                wait = 2 ** attempt
                time.sleep(wait)
                continue
            return index, False, 0, f"failed after {max_attempts} attempts: {e}"

    return index, False, 0, "unknown failure"


def check_disk_space(required_bytes: int, data_dir: str) -> tuple[bool, int]:
    """Ensure we have at least required_bytes + 10% headroom free."""
    os.makedirs(data_dir, exist_ok=True)
    stats = shutil.disk_usage(data_dir)
    headroom = int(required_bytes * 1.1)
    return stats.free >= headroom, stats.free


def main() -> int:
    parser = argparse.ArgumentParser(
        description="Fetch additional climbmix-400b-shuffle shards"
    )
    parser.add_argument(
        "--target-shards",
        type=int,
        default=None,
        help="Total train-shard count to reach (0..target-1). Mutually exclusive with --start/--end.",
    )
    parser.add_argument("--start", type=int, default=None, help="Starting shard index (inclusive)")
    parser.add_argument("--end", type=int, default=None, help="Ending shard index (exclusive)")
    parser.add_argument("--workers", type=int, default=8, help="Parallel download workers")
    parser.add_argument(
        "--include-val",
        action="store_true",
        help="Also fetch the pinned validation shard (normally present already)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="List what would be downloaded without fetching",
    )
    args = parser.parse_args()

    # Resolve shard range.
    if args.target_shards is not None:
        if args.start is not None or args.end is not None:
            print("ERROR: --target-shards is exclusive with --start/--end")
            return 1
        ids = list(range(min(args.target_shards, MAX_SHARD)))
    else:
        start = args.start or 0
        end = args.end if args.end is not None else MAX_SHARD
        end = min(end, MAX_SHARD)
        ids = list(range(start, end))

    if args.include_val and VAL_SHARD not in ids:
        ids.append(VAL_SHARD)

    os.makedirs(DATA_DIR, exist_ok=True)
    present = set()
    for p in Path(DATA_DIR).glob("shard_*.parquet"):
        try:
            idx = int(p.stem.split("_")[1])
            present.add(idx)
        except (IndexError, ValueError):
            continue

    to_fetch = [i for i in ids if i not in present]
    if not to_fetch:
        print(f"All {len(ids)} shards already present at {DATA_DIR}")
        return 0

    # Estimate space: shards are ~88MB; leave 10% headroom.
    avg_shard_bytes = 90 * (1 << 20)  # 90MB
    required = avg_shard_bytes * len(to_fetch)
    ok, free = check_disk_space(required, DATA_DIR)
    print(f"Plan: fetch {len(to_fetch)} shards (~{human_bytes(required)}); "
          f"disk free: {human_bytes(free)}")
    if not ok:
        print("ERROR: insufficient disk space (need 1.1x required)")
        return 2

    if args.dry_run:
        preview = to_fetch[:10]
        print(
            f"Dry-run - would fetch {len(to_fetch)} shards. First {len(preview)}: {preview}"
        )
        return 0

    print(f"Downloading {len(to_fetch)} shards with {args.workers} workers...")
    t_start = time.time()
    success = 0
    failed = 0
    total_bytes = 0

    with ThreadPoolExecutor(max_workers=args.workers) as ex:
        futs = {ex.submit(download_one, i, DATA_DIR): i for i in to_fetch}
        for fut in as_completed(futs):
            idx, ok, nbytes, msg = fut.result()
            if ok:
                success += 1
                total_bytes += nbytes
                if success % 10 == 0 or success == len(to_fetch):
                    elapsed = time.time() - t_start
                    rate = total_bytes / max(elapsed, 1)
                    print(
                        f" [{success}/{len(to_fetch)}] shard_{idx:05d} ok "
                        f"({human_bytes(total_bytes)} @ {human_bytes(int(rate))}/s)"
                    )
            else:
                failed += 1
                print(f" [FAIL] shard_{idx:05d}: {msg}")

    elapsed = time.time() - t_start
    print()
    print("=" * 60)
    print(f"Downloaded {success}/{len(to_fetch)} shards in {elapsed:.1f}s")
    print(f"Failed: {failed}")
    print(f"Total bytes: {human_bytes(total_bytes)}")
    print("=" * 60)

    return 0 if failed == 0 else 3


if __name__ == "__main__":
    raise SystemExit(main())
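
A note on the failure path: download_one sleeps 2**attempt seconds after each failed attempt, so a shard that exhausts all five attempts backs off 2+4+8+16 = 30 s before being reported; check_disk_space then wants free >= 1.1x the estimate, so 100 shards at the assumed 90 MiB each need roughly 9.7 GiB free. The same retry pattern in isolation (minimal sketch; fetch_fn is a hypothetical callable):

    import time

    def with_backoff(fetch_fn, max_attempts: int = 5):
        # Mirrors download_one()'s retry loop: sleep 2, 4, 8, 16 s between
        # attempts, give up after max_attempts.
        for attempt in range(1, max_attempts + 1):
            try:
                return fetch_fn()
            except OSError:
                if attempt == max_attempts:
                    raise
                time.sleep(2 ** attempt)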
overlay/scripts/grad_probe.py
CHANGED
|
@@ -1,196 +1,196 @@
"""
Gradient flow probe for PostSemClawModel.

READ-ONLY diagnostic. Does NOT modify any source, does NOT train, does NOT
step an optimizer. Runs one forward + backward and reports, per-parameter:

    name, shape, dtype, requires_grad, grad-is-None?, |grad|.mean, |grad|.norm

Severity classification at the bottom:
    BLOCKER - requires_grad=True but p.grad is None (disconnected from graph)
    WARNING - grad present but literally zero (ops cancel, wd_init, etc.)
    WARNING - requires_grad=True but param missing from every optimizer group
    OK      - everything else

Usage:
    .venv/bin/python -u scripts/grad_probe.py
"""

from __future__ import annotations

import os
import sys
from pathlib import Path

# Ensure the project root is on sys.path (so `train`, `subsystems`, `prepare`
# resolve when we run from any cwd). Probe is intentionally a thin wrapper.
HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
sys.path.insert(0, str(ROOT))

# Small model config to keep the probe fast (still exercises every component).
# K=4 MTP (default), d_model=256 (default), n_layer=4 (default).
os.environ.setdefault("HYDRA_D_MODEL", "256")
os.environ.setdefault("HYDRA_N_LAYER", "4")
os.environ.setdefault("HYDRA_MTP_K", "4")

import torch  # noqa: E402

from train import PostSemClawModel, PostSemClawConfig  # noqa: E402


def main() -> int:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("ERROR: CUDA required (model has mamba-ssm + bf16 autocast path).")
        return 2

    cfg = PostSemClawConfig(
        sequence_len=64,
        vocab_size=8192,
        n_layer=int(os.environ["HYDRA_N_LAYER"]),
        d_model=int(os.environ["HYDRA_D_MODEL"]),
        d_state=64,
        headdim=32,
        n_heads=8,
        expand=2,
        engram_n_columns=1024,
        engram_key_dim=64,
        engram_layer_idx=1,
        sdr_n_bits=16384,
        sdr_target_active=327,
        sdr_delta_rank=32,
        sdr_som_warmup=500,
        sdr_som_interval=100,
        htm_n_columns=2048,
        htm_cells_per_column=32,
        mtp_k=int(os.environ["HYDRA_MTP_K"]),
        mtp_weight_decay=0.5,
    )

    print(f"[probe] config: d_model={cfg.d_model} n_layer={cfg.n_layer} "
          f"mtp_k={cfg.mtp_k} vocab={cfg.vocab_size}")

    torch.manual_seed(0)
    model = PostSemClawModel(cfg).to(device)
    model.init_weights()
    model.train()

    # ---- Enumerate params & optimizer group assignment ----
    all_params = list(model.named_parameters())
    print(f"[probe] total named parameters: {len(all_params)}")

    # Build optimizer to check group coverage (no step, no zero_grad).
    opt = model.setup_optimizer()
    grouped_ids: set[int] = set()
    for group in opt.param_groups:
        for p in group["params"]:
            grouped_ids.add(id(p))
    unique_param_ids = {id(p) for _, p in all_params}
    missing_from_opt = unique_param_ids - grouped_ids
    print(f"[probe] params in opt groups: {len(grouped_ids)} / unique: {len(unique_param_ids)}")
    if missing_from_opt:
        print(f"[probe] WARNING: {len(missing_from_opt)} unique params missing from opt groups")

    # Tied weight check.
    tied = model.wte.weight.data_ptr() == model.lm_head.weight.data_ptr()
    print(f"[probe] tied lm_head<->wte (data_ptr match): {tied}")

    # ---- One forward + backward under bf16 autocast ----
    B, T = 1, 64
    idx = torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long, device=device)
    tgt = torch.roll(idx, -1, dims=1)

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, targets=tgt)
    print(f"[probe] fwd loss = {float(loss.detach()):.4f}")
    loss.backward()
    torch.cuda.synchronize()

    # ---- Report ----
    blockers: list[str] = []
    zero_grads: list[str] = []
    unexpected_frozen: list[str] = []
    not_in_opt: list[str] = []
    rows: list[tuple[str, tuple, str, bool, bool, float, float]] = []

    for name, p in all_params:
        grad_is_none = p.grad is None
        if p.requires_grad and grad_is_none:
            blockers.append(name)
            rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                         p.requires_grad, True, float("nan"), float("nan")))
            continue
        if not p.requires_grad:
            unexpected_frozen.append(name)
            rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                         False, True, float("nan"), float("nan")))
            continue
        g = p.grad.detach().float()
        abs_mean = float(g.abs().mean().item())
        norm = float(g.norm().item())
        if abs_mean == 0.0 and norm == 0.0:
            zero_grads.append(name)
        if id(p) not in grouped_ids:
            not_in_opt.append(name)
        rows.append((name, tuple(p.shape), str(p.dtype).replace("torch.", ""),
                     p.requires_grad, False, abs_mean, norm))

    # Pretty table
    print("\n[probe] per-parameter grad table:")
    print(f" {'name':<56} {'shape':<22} {'dtype':<8} rg none {'|g|.mean':>10} {'|g|.norm':>10}")
    for name, shape, dtype, rg, none, mean, norm in rows:
        shape_s = "x".join(str(s) for s in shape)
        rg_s = "Y" if rg else "N"
        none_s = "Y" if none else "N"
        if none:
            mean_s, norm_s = "       nan", "       nan"
        else:
            mean_s = f"{mean:>10.3e}"
            norm_s = f"{norm:>10.3e}"
        print(f" {name:<56} {shape_s:<22} {dtype:<8} {rg_s} {none_s} {mean_s} {norm_s}")

    # Identity checks
    print("\n[probe] identity checks:")
    print(f" id(wte.weight) = {id(model.wte.weight)}")
    print(f" id(lm_head.weight) = {id(model.lm_head.weight)}")
    print(f" same Python object = {model.wte.weight is model.lm_head.weight}")
    print(f" same storage ptr = {tied}")

    # Engram memory inspection
    print(f"\n[probe] engram.memory is nn.Parameter: "
          f"{isinstance(model.engram.memory, torch.nn.Parameter)}")
    print(f" engram.memory.requires_grad = {model.engram.memory.requires_grad}")
    if model.engram.memory.grad is None:
        print(" engram.memory.grad = None (Hebbian-only path; no autograd through detach())")
    else:
        g = model.engram.memory.grad.detach().float()
        print(f" engram.memory.grad |.mean| = {float(g.abs().mean()):.3e}")

    # Stash flag sanity: _last_sdr should be uint8, no graph
    last = getattr(model, "_last_sdr", None)
    if last is not None:
        print(f"\n[probe] model._last_sdr dtype={last.dtype}, requires_grad={last.requires_grad}")
    else:
        print("\n[probe] model._last_sdr is None (fwd didn't stash - ok if path changed)")

    # Summary
    print("\n[probe] ============ SUMMARY ============")
    print(f" BLOCKERS (requires_grad but grad is None): {len(blockers)}")
    for n in blockers:
        print(f"   - {n}")
    print(f" WARNINGS (grad is literally zero): {len(zero_grads)}")
    for n in zero_grads:
        print(f"   - {n}")
    print(f" WARNINGS (requires_grad=False): {len(unexpected_frozen)}")
    for n in unexpected_frozen:
        print(f"   - {n}")
    print(f" WARNINGS (missing from every opt group): {len(not_in_opt)}")
    for n in not_in_opt:
        print(f"   - {n}")

    return 0


if __name__ == "__main__":
    sys.exit(main())
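
The BLOCKER class is easy to reproduce in isolation: autograd leaves .grad as None on any parameter the loss never touched. A minimal CPU-only sketch (standalone, not part of the probe):

    import torch

    lin = torch.nn.Linear(4, 4)
    orphan = torch.nn.Parameter(torch.zeros(4))  # requires_grad=True but unused

    lin(torch.randn(2, 4)).sum().backward()

    print(lin.weight.grad is None)  # False: reached by backward()
    print(orphan.grad is None)      # True: the probe would flag this as a BLOCKER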
overlay/scripts/launch_feather_hf_job.py
CHANGED
|
@@ -211,9 +211,12 @@ def main() -> int:
     if not USE_SPACE_IMAGE:
         print(f'[launch] image={DEFAULT_IMAGE}', flush=True)
 
+    fast_start_streaming = should_enable_fast_start_streaming(TARGET_SHARDS, TIME_BUDGET)
     if DRY_RUN:
-        if 'HYDRA_USE_NEMOTRON' not in os.environ and
+        if 'HYDRA_USE_NEMOTRON' not in os.environ and fast_start_streaming:
             print('[launch] auto-enabled HYDRA_USE_NEMOTRON=1 for short-budget fast-start profile', flush=True)
+        if 'HYDRA_LOCAL_SHARDS_ONLY' not in os.environ and fast_start_streaming:
+            print('[launch] auto-enabled HYDRA_LOCAL_SHARDS_ONLY=0 for Nemotron streaming fast-start profile', flush=True)
         print('[launch] dry-run mode; skipping repo creation, upload, and job submission', flush=True)
         return 0
 
@@ -277,9 +280,12 @@ def main() -> int:
         'TRITON_CACHE_DIR': f'/workspace/triton_cache/{GPU_PROFILE}',
         'TRITON_CACHE_REPO': f'{routing.owner}/feather-triton-cache-{GPU_PROFILE}',
     }
-    if 'HYDRA_USE_NEMOTRON' not in os.environ and
+    if 'HYDRA_USE_NEMOTRON' not in os.environ and fast_start_streaming:
         env['HYDRA_USE_NEMOTRON'] = '1'
         print('[launch] auto-enabled HYDRA_USE_NEMOTRON=1 for short-budget fast-start profile', flush=True)
+    if 'HYDRA_LOCAL_SHARDS_ONLY' not in os.environ and fast_start_streaming:
+        env['HYDRA_LOCAL_SHARDS_ONLY'] = '0'
+        print('[launch] auto-enabled HYDRA_LOCAL_SHARDS_ONLY=0 for Nemotron streaming fast-start profile', flush=True)
     # A10 compatibility profile: avoid known PTX/compile runtime pitfalls and
     # keep throughput path enabled. Caller can explicitly override each key by
     # setting it in the parent environment.
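
Both auto-enable branches guard on 'KEY' not in os.environ, so a value exported in the parent environment always wins over the fast-start defaults. The precedence rule in isolation (minimal sketch; the defaults dict is illustrative):

    import os

    def apply_auto_defaults(env: dict[str, str], auto: dict[str, str]) -> dict[str, str]:
        # Fill in auto-enabled defaults only for keys the caller did not export,
        # matching the launch script's `'KEY' not in os.environ` guards.
        for key, value in auto.items():
            if key not in os.environ:
                env[key] = value
        return env

    env = apply_auto_defaults({}, {"HYDRA_USE_NEMOTRON": "1", "HYDRA_LOCAL_SHARDS_ONLY": "0"})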
overlay/scripts/profile_forward.py
CHANGED
|
@@ -1,87 +1,87 @@
"""Per-subsystem timing to find the tok/s bottleneck.

Runs a single forward+backward at (B=8, T=2048) and times each stage via
torch.cuda.Event. Reports ms/stage and derived tok/s budget.
"""
import os, sys, time
os.environ.setdefault("LD_LIBRARY_PATH", "/usr/lib/wsl/lib:/usr/local/cuda/lib64")
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from train import PostSemClawModel, PostSemClawConfig, MAX_SEQ_LEN

B, T = 8, MAX_SEQ_LEN

def timeit(name, fn, warmup=1, n=3):
    for _ in range(warmup):
        fn(); torch.cuda.synchronize()
    s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(n):
        torch.cuda.synchronize()
        s.record(); fn(); e.record(); torch.cuda.synchronize()
        times.append(s.elapsed_time(e))
    avg = sum(times)/len(times)
    print(f" {name:30s} {avg:8.2f} ms (min {min(times):.2f} max {max(times):.2f})")
    return avg

cfg = PostSemClawConfig()
model = PostSemClawModel(cfg).cuda()
model.init_weights()
model.train()
idx = torch.randint(0, cfg.vocab_size, (B, T), device="cuda", dtype=torch.long)
y = idx.clone()

print(f"== Profile at B={B} T={T} n_params={sum(p.numel() for p in model.parameters())/1e6:.1f}M ==\n")

# Warmup full forward
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    _ = model(idx, y)
torch.cuda.synchronize()

print("Stage times (3 iter avg):\n")

# 1) wte
timeit("wte embedding", lambda: model.wte(idx).sum().item())

# 2) sdr_semantic (STE forward)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("sdr_semantic forward STE", lambda: model.sdr_semantic(idx).sum().item())

# 3) sdr binary_only
timeit("sdr binary_only", lambda: model.sdr_semantic.binary_only(idx).sum().item())

# 4) HTM full forward (with reset/learn)
with torch.no_grad():
    timeit("HTM forward (B=8, T=2048)", lambda: model.htm(model.sdr_semantic.binary_only(idx)).sum().item())

# 5) Mamba block stack only
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    def _blocks():
        x = model.wte(idx)
        from train import norm
        x = norm(x)
        streams = model.mhc[0].init_streams(x)
        for i, (block, mhc_layer) in enumerate(zip(model.blocks, model.mhc)):
            def _bfn(h, _b=block): return _b(norm(h))
            streams = mhc_layer(streams, _bfn)
        x = model.mhc[-1].merge_streams(streams)
        return x.sum().item()
    timeit("Mamba+mHC blocks (n_layer=4)", _blocks)

# 6) Full forward+loss
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("FULL forward+loss", lambda: model(idx, y).item())

# 7) Full forward+loss+backward
def full_fwd_bwd():
    model.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, y)
    loss.backward()
    return loss.item()
t_full = timeit("FULL forward+backward", full_fwd_bwd)

print()
print(f"FULL step (fwd+bwd): {t_full:.0f} ms for B*T = {B*T} tokens")
print(f"tok/s per forward: {B*T / (t_full/1000):.0f}")
print(f"Expected @MFU=20% on RTX3060 (~25 TFLOPS bf16): ~{25e12*0.2 / (6*7.5e6) / 1000:.0f}k tok/s")
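
# Sanity check on the expected-throughput line (illustrative): 25e12 * 0.2 = 5e12
# effective FLOP/s; the 6*N rule gives 6 * 7.5e6 = 4.5e7 FLOPs/token for the
# ~7.5M-param model; 5e12 / 4.5e7 ~ 111k tok/s if fully compute-bound.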
|
|
|
"""Per-subsystem timing to find the tok/s bottleneck.

Runs a single forward+backward at (B=8, T=2048) and times each stage via
torch.cuda.Event. Reports ms/stage and derived tok/s budget.
"""
import os, sys, time
os.environ.setdefault("LD_LIBRARY_PATH", "/usr/lib/wsl/lib:/usr/local/cuda/lib64")
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from train import PostSemClawModel, PostSemClawConfig, MAX_SEQ_LEN

B, T = 8, MAX_SEQ_LEN

def timeit(name, fn, warmup=1, n=3):
    for _ in range(warmup):
        fn(); torch.cuda.synchronize()
    s = torch.cuda.Event(enable_timing=True); e = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(n):
        torch.cuda.synchronize()
        s.record(); fn(); e.record(); torch.cuda.synchronize()
        times.append(s.elapsed_time(e))
    avg = sum(times)/len(times)
    print(f"  {name:30s} {avg:8.2f} ms  (min {min(times):.2f} max {max(times):.2f})")
    return avg

cfg = PostSemClawConfig()
model = PostSemClawModel(cfg).cuda()
model.init_weights()
model.train()
idx = torch.randint(0, cfg.vocab_size, (B, T), device="cuda", dtype=torch.long)
y = idx.clone()

print(f"== Profile at B={B} T={T} n_params={sum(p.numel() for p in model.parameters())/1e6:.1f}M ==\n")

# Warmup full forward
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    _ = model(idx, y)
torch.cuda.synchronize()

print("Stage times (3 iter avg):\n")

# 1) wte
timeit("wte embedding", lambda: model.wte(idx).sum().item())

# 2) sdr_semantic (STE forward)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("sdr_semantic forward STE", lambda: model.sdr_semantic(idx).sum().item())

# 3) sdr binary_only
timeit("sdr binary_only", lambda: model.sdr_semantic.binary_only(idx).sum().item())

# 4) HTM full forward (with reset/learn)
with torch.no_grad():
    timeit("HTM forward (B=8, T=2048)", lambda: model.htm(model.sdr_semantic.binary_only(idx)).sum().item())

# 5) Mamba block stack only
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    def _blocks():
        x = model.wte(idx)
        from train import norm
        x = norm(x)
        streams = model.mhc[0].init_streams(x)
        for i, (block, mhc_layer) in enumerate(zip(model.blocks, model.mhc)):
            def _bfn(h, _b=block): return _b(norm(h))
            streams = mhc_layer(streams, _bfn)
        x = model.mhc[-1].merge_streams(streams)
        return x.sum().item()
    timeit("Mamba+mHC blocks (n_layer=4)", _blocks)

# 6) Full forward+loss
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
    timeit("FULL forward+loss", lambda: model(idx, y).item())

# 7) Full forward+loss+backward
def full_fwd_bwd():
    model.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(idx, y)
    loss.backward()
    return loss.item()
t_full = timeit("FULL forward+backward", full_fwd_bwd)

print()
print(f"FULL step (fwd+bwd): {t_full:.0f} ms for B*T = {B*T} tokens")
print(f"tok/s per forward: {B*T / (t_full/1000):.0f}")
print(f"Expected @MFU=20% on RTX3060 (~25 TFLOPS bf16): ~{25e12*0.2 / (6*7.5e6) / 1000:.0f}k tok/s")
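The closing estimate applies the usual ~6·N training-FLOPs-per-token rule. A minimal sketch of the same arithmetic, assuming the script's own figures (25 TFLOPS bf16 peak for the RTX 3060, a 20% MFU target, 7.5M params):

    # Hedged sketch of the tok/s budget in the final print above.
    peak_bf16_flops = 25e12            # assumed RTX 3060 bf16 peak (script's figure)
    mfu = 0.20                         # assumed model-FLOPs utilization
    n_params = 7.5e6                   # parameter count quoted in the script
    flops_per_token = 6 * n_params     # fwd+bwd rule of thumb: ~6*N FLOPs/token
    print(f"~{peak_bf16_flops * mfu / flops_per_token / 1e3:.0f}k tok/s")  # ~111k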
overlay/scripts/run_domain_expanded_pretrain.sh
CHANGED
@@ -188,11 +188,7 @@ fi
 
 RESUME_PATH="$(resolve_resume_path || true)"
 
-
-# (H200/A10G HF Jobs) already have their driver paths set by entrypoint.py.
-if [[ -d /usr/lib/wsl/lib ]]; then
-  export LD_LIBRARY_PATH="/usr/lib/wsl/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
-fi
+export LD_LIBRARY_PATH="/usr/lib/wsl/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
 export HYDRA_TIME_BUDGET="${HYDRA_TIME_BUDGET:-28800}"
 export HYDRA_TARGET_SHARDS="$TARGET_SHARDS"
 export HYDRA_DOWNLOAD_WORKERS="$DOWNLOAD_WORKERS"
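The net effect: the WSL-only guard (and the HF Jobs comment with it) is dropped, so the WSL and CUDA library directories are now prepended unconditionally. The `${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}` expansion appends a colon plus the previous value only when the variable was already set and non-empty, so the export never leaves a trailing colon on machines where it was unset.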
overlay/scripts/sample_utils.py
CHANGED
"""Shared sampling utilities for chat.py / chat_eval.py.

Pure functions: given a 1-D logits tensor (vocab_size,), return a single
sampled token id. No model/tokenizer knowledge here.
"""

from __future__ import annotations

from typing import Iterable, Optional

import torch


def apply_repetition_penalty(
    logits: torch.Tensor,
    recent_tokens: Optional[Iterable[int]],
    penalty: float,
) -> torch.Tensor:
    """Divide logits of recent positive tokens by `penalty`, multiply negatives.

    Operates in-place on a *copy* (logits is cloned first by caller if needed).
    `recent_tokens` may be any iterable of ints; duplicates are deduped internally.
    """
    if penalty == 1.0 or not recent_tokens:
        return logits
    seen = set(int(t) for t in recent_tokens)
    if not seen:
        return logits
    idx = torch.tensor(list(seen), device=logits.device, dtype=torch.long)
    vals = logits.index_select(0, idx)
    vals = torch.where(vals > 0, vals / penalty, vals * penalty)
    logits.index_copy_(0, idx, vals)
    return logits


def apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep only the top-k logits; set the rest to -inf.

    top_k<=0 or top_k>=vocab disables the filter."""
    if top_k <= 0 or top_k >= logits.size(-1):
        return logits
    topk_vals, topk_idx = logits.topk(top_k)
    mask = torch.full_like(logits, float("-inf"))
    mask.scatter_(0, topk_idx, topk_vals)
    return mask


def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus filter: keep smallest set of tokens whose cumulative prob >= top_p."""
    if top_p >= 1.0 or top_p <= 0.0:
        return logits
    sorted_logits, sorted_idx = logits.sort(descending=True)
    cumulative_probs = sorted_logits.softmax(-1).cumsum(-1)
    mask = cumulative_probs > top_p
    # shift right so we always keep at least one token
    mask[1:] = mask[:-1].clone()
    mask[0] = False
    sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
    out = torch.full_like(logits, float("-inf"))
    out.scatter_(0, sorted_idx, sorted_logits)
    return out


def sample_token(
    logits: torch.Tensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    recent_tokens: Optional[Iterable[int]] = None,
) -> int:
    """Return a single sampled token id (Python int).

    logits: 1-D float tensor of shape (vocab_size,). fp32 or upcast-safe.
    """
    if logits.dim() != 1:
        raise ValueError(f"sample_token expects 1-D logits, got shape {tuple(logits.shape)}")

    # Work in fp32 on a clone so the caller's tensor is unchanged.
    work = logits.detach().to(torch.float32).clone()

    if repetition_penalty != 1.0 and recent_tokens is not None:
        work = apply_repetition_penalty(work, recent_tokens, repetition_penalty)

    # Temperature. Greedy when temperature <= 0.
    if temperature <= 0.0:
        return int(work.argmax().item())
    work = work / max(temperature, 1e-6)

    work = apply_top_k(work, top_k)
    work = apply_top_p(work, top_p)

    # Guard against all-(-inf) (can happen if top_k/top_p filter everything out).
    if torch.isinf(work).all():
        return int(logits.argmax().item())

    probs = torch.softmax(work, dim=-1)
    # Numerical safety – replace any NaN with 0 and renormalize.
    if torch.isnan(probs).any():
        probs = torch.nan_to_num(probs, nan=0.0)
        s = probs.sum()
        if s <= 0:
            return int(logits.argmax().item())
        probs = probs / s

    tok = torch.multinomial(probs, num_samples=1)
    return int(tok.item())
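A minimal usage sketch of the module above (illustrative only: the random logits and the import path are assumptions, not part of the file):

    import torch
    from sample_utils import sample_token  # assumes scripts/ is on sys.path

    logits = torch.randn(8192)             # fake 1-D logits over an 8192-token vocab
    tok = sample_token(
        logits,
        temperature=0.8,
        top_k=40,
        top_p=0.95,
        repetition_penalty=1.1,
        recent_tokens=[17, 42, 42],        # duplicates are deduped internally
    )
    print(tok)                             # a single Python int in [0, 8192)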
overlay/scripts/sft.py
CHANGED
"""HYDRA SFT – instruction fine-tune the pretrained 7.5M-param base.

Mode selection:
    MODE=resume_from_pretrain  iff ~/.cache/autoresearch/pretrain_final.pt
                               exists AND loads cleanly into a fresh model.
    MODE=from_scratch          otherwise (degraded fallback).

Data: int16 shards written by `scripts/download_sft_data.py`, paired with
uint8 loss-mask shards (1 on assistant tokens, 0 on user-prompt tokens).
At runtime we pack consecutive examples into fixed-length rows; prompt
positions get target=-1 so CE's `ignore_index=-1` drops them.

Env vars (with defaults tuned for RTX 3060 6GB, 7.5M params):
    HYDRA_SFT_TIME_BUDGET    10800   SFT wall-clock budget (3h)
    HYDRA_SFT_SEQ_LEN        512     sequence length during SFT
    HYDRA_BATCH_SIZE         4       per-step device batch
    HYDRA_TOTAL_BATCH        8192    effective batch (grad-accum derived)
    HYDRA_SFT_LR_MULT        0.10    multiply pretrain LRs by this
    HYDRA_SFT_EVAL_INTERVAL  500     steps between sample generations
    HYDRA_SFT_CKPT_INTERVAL  2000    steps between interim checkpoints

CLI:
    --dry-run    load model+data, run 1 step, exit (validation path)
    --eval-only  load `sft_final.pt`, run sample gen, exit
"""

from __future__ import annotations

import argparse
import json
import math
import os
import sys
import time
from dataclasses import asdict
from pathlib import Path

import numpy as np
import torch

# Repo root on path
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

# Must import hydra.config BEFORE touching torch.cuda for CUDA env setup
from hydra.config import (
    ADAM_BETAS, D_MODEL, D_STATE, DEVICE_BATCH_SIZE, EMBEDDING_LR,
    ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND,
    FINAL_LR_FRAC, GPU_BF16_PEAK_FLOPS, HEADDIM, MATRIX_LR, N_HEADS,
    N_LAYER, PostSemClawConfig, SCALAR_LR, SEED, TOTAL_BATCH_SIZE,
    UNEMBEDDING_LR, WARMUP_RATIO, WEIGHT_DECAY,
)
from hydra.model import PostSemClawModel
from prepare import Tokenizer

# Use line-buffered stdout
try:
    sys.stdout.reconfigure(line_buffering=True)
except Exception:
    pass


# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------

CACHE_DIR = Path.home() / ".cache" / "autoresearch"
PRETRAIN_CKPT = CACHE_DIR / "pretrain_final.pt"
SFT_FINAL_CKPT = CACHE_DIR / "sft_final.pt"
SFT_INTERIM_CKPT = CACHE_DIR / "sft_interim.pt"
SFT_DATA_DIR = _REPO_ROOT / "data" / "sft"


# ---------------------------------------------------------------------------
# Env vars for SFT
# ---------------------------------------------------------------------------

SFT_TIME_BUDGET = int(os.environ.get("HYDRA_SFT_TIME_BUDGET", "10800"))
SFT_SEQ_LEN = int(os.environ.get("HYDRA_SFT_SEQ_LEN", "512"))
SFT_LR_MULT = float(os.environ.get("HYDRA_SFT_LR_MULT", "0.10"))
SFT_EVAL_INTERVAL = int(os.environ.get("HYDRA_SFT_EVAL_INTERVAL", "500"))
SFT_CKPT_INTERVAL = int(os.environ.get("HYDRA_SFT_CKPT_INTERVAL", "2000"))


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

def _load_meta() -> dict:
    meta_path = SFT_DATA_DIR / "meta.json"
    if not meta_path.exists():
        raise FileNotFoundError(
            f"SFT meta not found at {meta_path}. Run "
            f"`python scripts/download_sft_data.py` first."
        )
    with open(meta_path) as f:
        return json.load(f)


def _load_shards():
    """Load all shard_XXX.bin + mask_XXX.bin as big flat arrays.

    Returns: (tokens: np.int64, mask: np.uint8)
    Both arrays are 1-D and the same length. Total len ~= target_tokens.
    """
    tok_shards = sorted(SFT_DATA_DIR.glob("shard_*.bin"))
    mask_shards = sorted(SFT_DATA_DIR.glob("mask_*.bin"))
    if not tok_shards:
        raise FileNotFoundError(f"No SFT shards in {SFT_DATA_DIR}")
    assert len(tok_shards) == len(mask_shards), (
        f"shard/mask count mismatch: {len(tok_shards)} vs {len(mask_shards)}"
    )
    tok_parts = []
    mask_parts = []
    for t, m in zip(tok_shards, mask_shards):
        tok_parts.append(np.fromfile(str(t), dtype=np.int16).astype(np.int64))
        mask_parts.append(np.fromfile(str(m), dtype=np.uint8))
    tokens = np.concatenate(tok_parts)
    mask = np.concatenate(mask_parts)
    assert tokens.shape == mask.shape
    # Guard against negative int16 values (unlikely with vocab=8192 but defensive)
    if tokens.min() < 0 or tokens.max() >= 8192:
        raise ValueError(
            f"Token IDs out of range: min={tokens.min()} max={tokens.max()}"
        )
    return tokens, mask


def make_sft_dataloader(tokens: np.ndarray, mask: np.ndarray, B: int, T: int,
                        device: torch.device, seed: int = 0):
    """Yield (x, y, epoch) forever.

    Each row is a slice of length T+1 sampled at a random start. We produce:
        x = slice[:-1]                    (B, T) int64 on device
        y = slice[1:] with mask=0 -> -1   (B, T) int64 on device

    The mask applies to target positions (y), not inputs. This way the
    chunked CE loss in model.forward sees ignore_index=-1 for prompt tokens.
    """
    N = tokens.shape[0]
    rng = np.random.default_rng(seed)
    # Pin CPU tensors; copy to GPU non-blocking.
    cpu_x = torch.empty(B, T, dtype=torch.long, pin_memory=True)
    cpu_y = torch.empty(B, T, dtype=torch.long, pin_memory=True)
    epoch = 1
    samples_drawn = 0
    samples_per_epoch = max(1, N // (T + 1))

    # Minimum loss-positions per window. If a sampled window has fewer than
    # this many assistant tokens, resample. Guards against all-prompt windows
    # producing NaN from 0/0 in the chunked CE loss.
    min_loss_positions = max(1, T // 32)
    max_resample = 8

    while True:
        for b in range(B):
            # Sample a starting index with a light rejection filter to ensure
            # the window contains enough assistant (mask=1) positions.
            if N <= T + 1:
                start = 0
            else:
                start = int(rng.integers(0, N - T - 1))
                for _ in range(max_resample):
                    loss_in_window = int(mask[start + 1:start + T + 1].sum())
                    if loss_in_window >= min_loss_positions:
                        break
                    start = int(rng.integers(0, N - T - 1))
            window_tok = tokens[start:start + T + 1]
            window_mask = mask[start:start + T + 1]
            # x = window[:-1], y = window[1:]
            cpu_x[b].copy_(torch.from_numpy(window_tok[:-1].astype(np.int64)))
            y_slice = window_tok[1:].astype(np.int64).copy()
            # Apply mask to targets: mask=0 (prompt) -> target=-1 (ignore)
            y_slice[window_mask[1:] == 0] = -1
            # Final guard: if no loss positions survived, force at least 1
            # valid target so the batch doesn't produce NaN (it's rare with
            # the rejection filter but defensive is cheap).
            if (y_slice != -1).sum() == 0:
                y_slice[-1] = int(window_tok[-1])
            cpu_y[b].copy_(torch.from_numpy(y_slice))
        x = cpu_x.to(device, non_blocking=True)
        y = cpu_y.to(device, non_blocking=True)
        samples_drawn += B
        if samples_drawn >= samples_per_epoch:
            epoch += 1
            samples_drawn = 0
        yield x, y, epoch
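# Worked instance of the rejection filter above: with the default T=512,
# min_loss_positions = 512 // 32 = 16, so a window is redrawn (up to
# max_resample=8 times) unless at least 16 of its T target positions land
# on assistant tokens.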


# ---------------------------------------------------------------------------
# Model init + checkpoint load
# ---------------------------------------------------------------------------

def _peek_pretrain_config(vocab_size: int) -> PostSemClawConfig | None:
    """If pretrain checkpoint exists, return its saved config so we build
    the SFT model with matching architecture. Returns None if unavailable."""
    if not PRETRAIN_CKPT.exists():
        return None
    try:
        ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cpu",
                          weights_only=False)
        cfg_dict = ckpt.get("config")
        if cfg_dict is None:
            return None
        # Override sequence_len to SFT's (shorter context) – architecture
        # is independent of sequence_len since Mamba3 is recurrent.
        cfg_dict = dict(cfg_dict)
        cfg_dict["sequence_len"] = SFT_SEQ_LEN
        cfg_dict["vocab_size"] = vocab_size
        cfg = PostSemClawConfig(**cfg_dict)
        return cfg
    except Exception as e:
        print(f"[model] could not peek pretrain config: {type(e).__name__}: {e}",
              flush=True)
        return None


def build_model(vocab_size: int, device: torch.device) -> PostSemClawModel:
    # Prefer checkpoint-derived config if available (guards against env-var drift)
    config = _peek_pretrain_config(vocab_size)
    if config is None:
        config = PostSemClawConfig(
            sequence_len=SFT_SEQ_LEN,
            vocab_size=vocab_size,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            d_state=D_STATE,
            headdim=HEADDIM,
            n_heads=N_HEADS,
            expand=EXPAND,
            engram_n_columns=ENGRAM_N_COLUMNS,
            engram_key_dim=ENGRAM_KEY_DIM,
            engram_layer_idx=ENGRAM_LAYER_IDX,
        )
        print(f"[model] config (from env, no ckpt): {asdict(config)}", flush=True)
    else:
        print(f"[model] config (from pretrain ckpt): {asdict(config)}", flush=True)
    with torch.device("meta"):
        model = PostSemClawModel(config)
    model.to_empty(device=device)
    model.init_weights()
    return model


def try_load_pretrain(model: PostSemClawModel) -> tuple[bool, str]:
    """Attempt to load pretrain checkpoint into model. Returns (loaded, msg)."""
    if not PRETRAIN_CKPT.exists():
        return False, f"no checkpoint at {PRETRAIN_CKPT}"
    try:
        ckpt = torch.load(str(PRETRAIN_CKPT), map_location="cuda",
                          weights_only=False)
        state = ckpt.get("model_state_dict", ckpt)
        # Use strict=False in case SDR/HTM params are excluded from state_dict
        # by torch.compile wrappers or similar.
        missing, unexpected = model.load_state_dict(state, strict=False)
        msg = (f"loaded {PRETRAIN_CKPT} – missing={len(missing)} "
               f"unexpected={len(unexpected)}")
        if missing:
            # Log first few missing keys to help diagnose architecture skew
            msg += f" first_missing={missing[:3]}"
        return True, msg
    except Exception as e:
        return False, f"load failed: {type(e).__name__}: {e}"


# ---------------------------------------------------------------------------
# Sample generation (for in-training eval prints)
# ---------------------------------------------------------------------------

_SAMPLE_PROMPTS = [
    "What is the capital of France?",
    "Write a haiku about winter.",
    "List three colors.",
    "How are you?",
    "Explain why the sky is blue in one sentence.",
]


@torch.no_grad()
def sample_once(model, tokenizer, meta: dict, prompt: str,
                max_new: int = 64, temperature: float = 0.8,
                top_k: int = 40) -> str:
    """Generate a chat-formatted reply. Stops on <|end|> or max_new tokens."""
    bos = meta["special_tokens"]["bos"]
    user = meta["special_tokens"]["user"]
    assistant = meta["special_tokens"]["assistant"]
    end = meta["special_tokens"]["end"]

    prompt_ids = [bos, user] + tokenizer.encode("\n" + prompt.strip())
    prompt_ids += tokenizer.encode("\n")
    prompt_ids.append(assistant)
    prompt_ids += tokenizer.encode("\n")

    ctx = torch.tensor([prompt_ids], device="cuda", dtype=torch.long)
    generated: list[int] = []
    for _ in range(max_new):
        with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(ctx, targets=None)
        last = logits[0, -1].float()
        if top_k and top_k < last.shape[-1]:
            kth = torch.topk(last, top_k).values[-1]
            last = torch.where(last < kth, torch.full_like(last, -1e9), last)
        probs = torch.softmax(last / max(temperature, 1e-6), dim=-1)
        next_id = int(torch.multinomial(probs, num_samples=1).item())
        generated.append(next_id)
        if next_id == end:
            break
        ctx = torch.cat(
            [ctx, torch.tensor([[next_id]], device="cuda", dtype=torch.long)],
            dim=1,
        )
        # Hard cap on ctx length (model was trained at 2048, SFT at 512,
        # but inference could theoretically go longer)
        if ctx.size(1) >= 2048:
            break
    try:
        text = tokenizer.decode(generated)
    except Exception:
        text = "<decode error>"
    return text
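# scripts/sample_utils.py offers a fuller sampler (sample_token adds top-p and
# repetition-penalty filtering alongside a comparable top-k filter); the inline
# top-k above keeps this in-training eval self-contained.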


def run_samples(model, tokenizer, meta: dict, step: int):
    model.eval()
    print(f"\n=== SFT samples @ step {step} ===", flush=True)
    for p in _SAMPLE_PROMPTS:
        try:
            resp = sample_once(model, tokenizer, meta, p)
        except Exception as e:
            resp = f"<sample failed: {type(e).__name__}: {e}>"
        # Sanitize newlines for log readability
        resp_clean = resp.replace("\n", " – ").replace("\r", " ")
        print(f"  prompt: {p!r}")
        print(f"  reply:  {resp_clean!r}")
    print("=== end samples ===\n", flush=True)
    model.train()


# ---------------------------------------------------------------------------
# Checkpoint save
# ---------------------------------------------------------------------------

def save_ckpt(model, step: int, smoothed_loss: float, path: Path,
              mode: str, meta: dict):
    try:
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        payload = {
            "model_state_dict": model.state_dict(),
            "step": step,
            "smoothed_loss": smoothed_loss,
            "mode": mode,
            "sft_meta": meta,
        }
        torch.save(payload, str(path))
        print(f"[ckpt] saved {path} (step={step})", flush=True)
    except Exception as e:
        print(f"[ckpt] SAVE FAILED {path}: {type(e).__name__}: {e}", flush=True)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--dry-run", action="store_true",
                    help="Load model+data, run 1 step, exit.")
    ap.add_argument("--eval-only", action="store_true",
                    help="Load sft_final.pt and run sample gen.")
    args = ap.parse_args()

    t_start = time.time()
    torch.manual_seed(SEED + 1)  # +1 so SFT draws different RNG than pretrain
    torch.cuda.manual_seed(SEED + 1)
    torch.set_float32_matmul_precision("high")
    device = torch.device("cuda")
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

    # --- Tokenizer ---
    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    print(f"[init] vocab: {vocab_size}", flush=True)

    # --- Data meta ---
    meta = _load_meta()
    print(f"[data] meta: {meta}", flush=True)

    # --- Model ---
    model = build_model(vocab_size, device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"[model] params: {n_params:,}", flush=True)

    loaded, msg = try_load_pretrain(model)
    mode = "resume_from_pretrain" if loaded else "from_scratch"
    print(f"[init] MODE={mode} :: {msg}", flush=True)

    # --- Eval-only path ---
    if args.eval_only:
        if SFT_FINAL_CKPT.exists():
            ckpt = torch.load(str(SFT_FINAL_CKPT), map_location=device,
                              weights_only=False)
            state = ckpt.get("model_state_dict", ckpt)
            model.load_state_dict(state, strict=False)
            print(f"[eval-only] loaded {SFT_FINAL_CKPT}", flush=True)
        else:
            print(f"[eval-only] no SFT checkpoint – running on current weights",
                  flush=True)
        run_samples(model, tokenizer, meta, step=-1)
        return

    # --- Dataloader ---
    print(f"[data] loading shards ...", flush=True)
    tokens, mask = _load_shards()
    print(f"[data] tokens: {len(tokens):,} loss-positions: {int(mask.sum()):,}",
          flush=True)
    B = DEVICE_BATCH_SIZE
    T = SFT_SEQ_LEN
    tokens_per_fwdbwd = B * T
    assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0, (
        f"TOTAL_BATCH_SIZE={TOTAL_BATCH_SIZE} not divisible by B*T={tokens_per_fwdbwd}"
    )
    grad_accum = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
    print(f"[train] B={B} T={T} accum={grad_accum} effective_batch={TOTAL_BATCH_SIZE}",
          flush=True)
    loader = make_sft_dataloader(tokens, mask, B, T, device, seed=SEED + 7)
    x, y, epoch = next(loader)

    # --- Optimizer (scaled LRs) ---
    matrix_lr = MATRIX_LR * SFT_LR_MULT
    embed_lr = EMBEDDING_LR * SFT_LR_MULT
    unembed_lr = UNEMBEDDING_LR * SFT_LR_MULT
    scalar_lr = SCALAR_LR * SFT_LR_MULT
    print(f"[opt] LRs scaled by {SFT_LR_MULT}: matrix={matrix_lr:.5f} "
          f"embed={embed_lr:.5f} unembed={unembed_lr:.6f}", flush=True)
    optimizer = model.setup_optimizer(
        unembedding_lr=unembed_lr,
        embedding_lr=embed_lr,
        scalar_lr=scalar_lr,
        adam_betas=ADAM_BETAS,
        matrix_lr=matrix_lr,
        weight_decay=WEIGHT_DECAY,
    )

    # --- Dry-run path (validation) ---
    if args.dry_run:
        print("[dry-run] running 1 step ...", flush=True)
        with autocast_ctx:
            loss = model(x, y)
        loss_f = float(loss.item())
        print(f"[dry-run] step0 loss={loss_f:.4f}", flush=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)
        if math.isnan(loss_f) or loss_f > 100:
            print("[dry-run] FAILED (NaN / huge loss)", flush=True)
            sys.exit(1)
        print("[dry-run] OK", flush=True)
        return

    # --- Training loop ---
    print(f"[train] budget={SFT_TIME_BUDGET}s eval_every={SFT_EVAL_INTERVAL} "
          f"ckpt_every={SFT_CKPT_INTERVAL}", flush=True)
    t_loop_start = time.time()
    smooth_loss = 0.0
    step = 0
    total_train_secs = 0.0

    # Warmup schedule for SFT: linear 0->1 over first 5% of budget, then cosine.
    sft_warmup_frac = 0.05

    def lr_mult(progress: float) -> float:
        if progress < sft_warmup_frac:
            return progress / sft_warmup_frac if sft_warmup_frac > 0 else 1.0
        decay = (progress - sft_warmup_frac) / (1.0 - sft_warmup_frac)
        return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * \
            (1 + math.cos(math.pi * decay))
|
| 479 |
+
|
| 480 |
+
while True:
|
| 481 |
+
torch.cuda.synchronize()
|
| 482 |
+
t0 = time.time()
|
| 483 |
+
for _ in range(grad_accum):
|
| 484 |
+
with autocast_ctx:
|
| 485 |
+
loss = model(x, y)
|
| 486 |
+
train_loss_val = loss.detach()
|
| 487 |
+
(loss / grad_accum).backward()
|
| 488 |
+
x, y, epoch = next(loader)
|
| 489 |
+
|
| 490 |
+
progress = min(total_train_secs / SFT_TIME_BUDGET, 1.0)
|
| 491 |
+
mult = lr_mult(progress)
|
| 492 |
+
for group in optimizer.param_groups:
|
| 493 |
+
group["lr"] = group["initial_lr"] * mult
|
| 494 |
+
|
| 495 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 496 |
+
optimizer.step()
|
| 497 |
+
model.zero_grad(set_to_none=True)
|
| 498 |
+
|
| 499 |
+
loss_f = float(train_loss_val.item())
|
| 500 |
+
if math.isnan(loss_f) or loss_f > 100:
|
| 501 |
+
print(f"[FAIL] step={step} loss={loss_f} β aborting", flush=True)
|
| 502 |
+
save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
|
| 503 |
+
sys.exit(1)
|
| 504 |
+
|
| 505 |
+
torch.cuda.synchronize()
|
| 506 |
+
dt = time.time() - t0
|
| 507 |
+
if step > 3:
|
| 508 |
+
total_train_secs += dt
|
| 509 |
+
|
| 510 |
+
# EMA loss (debiased)
|
| 511 |
+
beta = 0.9
|
| 512 |
+
smooth_loss = beta * smooth_loss + (1 - beta) * loss_f
|
| 513 |
+
debiased = smooth_loss / (1 - beta ** (step + 1))
|
| 514 |
+
bpt = debiased / math.log(2)
|
| 515 |
+
tps = int(TOTAL_BATCH_SIZE / dt) if dt > 0 else 0
|
| 516 |
+
vram_mib = torch.cuda.memory_allocated() / 1024 / 1024
|
| 517 |
+
lr_now = optimizer.param_groups[0]["lr"]
|
| 518 |
+
remaining = max(0, SFT_TIME_BUDGET - total_train_secs)
|
| 519 |
+
|
| 520 |
+
print(
|
| 521 |
+
f"sft_step={step:05d} loss={debiased:.4f} bpt={bpt:.3f} "
|
| 522 |
+
f"tps={tps} dt_ms={dt*1000:.0f} lr={lr_now:.2e} "
|
| 523 |
+
f"vram={vram_mib:.0f}MiB pct={100*progress:.1f} "
|
| 524 |
+
f"epoch={epoch} remaining={remaining:.0f}s",
|
| 525 |
+
flush=True,
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
if step > 0 and step % SFT_EVAL_INTERVAL == 0:
|
| 529 |
+
run_samples(model, tokenizer, meta, step)
|
| 530 |
+
|
| 531 |
+
if step > 0 and step % SFT_CKPT_INTERVAL == 0:
|
| 532 |
+
save_ckpt(model, step, smooth_loss, SFT_INTERIM_CKPT, mode, meta)
|
| 533 |
+
|
| 534 |
+
step += 1
|
| 535 |
+
|
| 536 |
+
if step > 5 and total_train_secs >= SFT_TIME_BUDGET:
|
| 537 |
+
break
|
| 538 |
+
|
| 539 |
+
# Final samples + save
|
| 540 |
+
run_samples(model, tokenizer, meta, step)
|
| 541 |
+
save_ckpt(model, step, smooth_loss, SFT_FINAL_CKPT, mode, meta)
|
| 542 |
+
|
| 543 |
+
total_secs = time.time() - t_start
|
| 544 |
+
print("---", flush=True)
|
| 545 |
+
print(f"SFT_COMPLETE mode={mode} step={step} "
|
| 546 |
+
f"smoothed_loss={smooth_loss:.4f} total_seconds={total_secs:.0f} "
|
| 547 |
+
f"train_seconds={total_train_secs:.0f}", flush=True)
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
if __name__ == "__main__":
|
| 551 |
+
try:
|
| 552 |
+
main()
|
| 553 |
+
except SystemExit:
|
| 554 |
+
raise
|
| 555 |
+
except Exception as e:
|
| 556 |
+
import traceback
|
| 557 |
+
print(f"SFT_FAILED {type(e).__name__}: {e}", flush=True)
|
| 558 |
+
traceback.print_exc()
|
| 559 |
+
sys.exit(1)
|
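The loop above ties the LR schedule to wall-clock progress (train seconds / SFT_TIME_BUDGET) rather than a step count, and logs an Adam-style bias-corrected EMA of the loss. A minimal standalone sketch of both pieces, assuming an illustrative 0.1 floor in place of the script's FINAL_LR_FRAC:

import math

FINAL_LR_FRAC = 0.1   # assumed floor; the real value comes from the script's config
WARMUP_FRAC = 0.05    # linear warmup over the first 5% of the time budget

def lr_mult(progress: float) -> float:
    # progress in [0, 1] = train_seconds / SFT_TIME_BUDGET
    if progress < WARMUP_FRAC:
        return progress / WARMUP_FRAC
    decay = (progress - WARMUP_FRAC) / (1.0 - WARMUP_FRAC)
    return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * (1 + math.cos(math.pi * decay))

# The EMA smooth = beta*smooth + (1-beta)*x starts biased toward 0;
# dividing by (1 - beta**(t+1)) removes that bias, exactly as in Adam.
beta, smooth = 0.9, 0.0
for t, x in enumerate([4.0, 3.5, 3.2]):
    smooth = beta * smooth + (1 - beta) * x
    print(f"step {t}: ema={smooth:.3f} debiased={smooth / (1 - beta ** (t + 1)):.3f}")

At t=0 the debiased value equals the first raw loss (4.0), which is why the logged loss is meaningful from the very first step.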
overlay/scripts/sft_orchestrator.sh
CHANGED
#!/usr/bin/env bash
#
# SFT orchestrator: waits for pretrain (train.py) to either complete or
# reach the 8h budget, then kicks off SFT.
#
# Behavior:
#   - Polls for `train.py` process every 60 s
#   - Exits the wait loop on either:
#       (a) no train.py process found (pretrain completed naturally), or
#       (b) 8h elapsed since this script started
#   - Sends SIGTERM first (graceful → triggers checkpoint-save patch if
#     applied), waits 30 s, then SIGKILL as fallback
#   - Invokes `scripts/download_sft_data.py` if shards don't exist
#   - Launches `scripts/sft.py` in the background with tuned env vars
#   - Redirects all output to `run_sft.log`
#
# Re-entrant: safe to invoke even if pretrain has already exited.
# Does NOT re-launch if SFT is already running.
#
# Usage (typical):
#   cd /home/mikeb/work/feather
#   nohup bash scripts/sft_orchestrator.sh > orchestrator.log 2>&1 &
#   disown

set -u  # error on unset vars, but don't -e (we handle failures explicitly)

REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$REPO_ROOT" || { echo "cannot cd to $REPO_ROOT" >&2; exit 1; }

PY="$REPO_ROOT/.venv/bin/python"
if [ ! -x "$PY" ]; then
  echo "[orchestrator] ERROR: python not found at $PY" >&2
  exit 1
fi

LOG_FILE="$REPO_ROOT/run_sft.log"
DATA_LOG="$REPO_ROOT/run_sft_download.log"
MAX_WAIT_SECONDS=28800   # 8 hours
POLL_INTERVAL=60
GRACEFUL_SHUTDOWN_WAIT=30

log() {
  echo "[orchestrator $(date -u '+%Y-%m-%dT%H:%M:%SZ')] $*"
}

# ---------------------------------------------------------------------------
# Stage 1: wait for pretrain
# ---------------------------------------------------------------------------

log "starting; max wait = ${MAX_WAIT_SECONDS}s"

# Guard against double-launch
if pgrep -f "scripts/sft.py" > /dev/null; then
  log "SFT is already running → exiting orchestrator to avoid conflict"
  exit 0
fi

T_START=$(date +%s)
while true; do
  NOW=$(date +%s)
  ELAPSED=$((NOW - T_START))

  if [ $ELAPSED -ge $MAX_WAIT_SECONDS ]; then
    log "reached 8h wait cap (${ELAPSED}s) → will kill pretrain"
    break
  fi

  # Count train.py processes owned by current user (not orchestrator/sft.py)
  PRETRAIN_PIDS=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
  # Strip pid of this script if pgrep matched something spurious
  PRETRAIN_PIDS=$(echo "$PRETRAIN_PIDS" | sed "s/\b$$\b//g" | xargs)

  if [ -z "$PRETRAIN_PIDS" ]; then
    log "no train.py process found → pretrain already exited"
    break
  fi

  # Log a status every 10 polls (~10 min)
  if [ $((ELAPSED / POLL_INTERVAL % 10)) -eq 0 ]; then
    log "waiting... elapsed=${ELAPSED}s pretrain PIDs: $PRETRAIN_PIDS"
  fi

  sleep $POLL_INTERVAL
done

# ---------------------------------------------------------------------------
# Stage 2: kill any remaining pretrain processes
# ---------------------------------------------------------------------------

PRETRAIN_PIDS=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
if [ -n "$PRETRAIN_PIDS" ]; then
  log "sending SIGTERM to pretrain PIDs: $PRETRAIN_PIDS"
  for pid in $PRETRAIN_PIDS; do
    kill -TERM "$pid" 2>/dev/null || true
  done

  # Wait for graceful shutdown (gives the checkpoint-save patch time to run)
  for _ in $(seq 1 $GRACEFUL_SHUTDOWN_WAIT); do
    REMAINING=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
    if [ -z "$REMAINING" ]; then break; fi
    sleep 1
  done

  # Force-kill any stragglers
  REMAINING=$(pgrep -u "$USER" -f "train\.py" 2>/dev/null | tr '\n' ' ')
  if [ -n "$REMAINING" ]; then
    log "force-killing stragglers: $REMAINING"
    for pid in $REMAINING; do
      kill -9 "$pid" 2>/dev/null || true
    done
    sleep 5
  fi
fi

# ---------------------------------------------------------------------------
# Stage 3: ensure SFT data exists
# ---------------------------------------------------------------------------

META_JSON="$REPO_ROOT/data/sft/meta.json"
if [ ! -f "$META_JSON" ]; then
  log "no SFT data found → running download_sft_data.py"
  LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \
    "$PY" -u "$REPO_ROOT/scripts/download_sft_data.py" \
    > "$DATA_LOG" 2>&1
  DL_RC=$?
  if [ $DL_RC -ne 0 ] || [ ! -f "$META_JSON" ]; then
    log "ERROR: SFT data download failed (rc=$DL_RC)"
    log "  last 20 lines of $DATA_LOG:"
    tail -20 "$DATA_LOG" 2>/dev/null | sed 's/^/  /'
    exit 2
  fi
  log "SFT data ready"
else
  log "SFT data already present at $META_JSON"
fi

# ---------------------------------------------------------------------------
# Stage 4: launch SFT
# ---------------------------------------------------------------------------

# Guard: if we somehow got here and SFT is now running, don't double-launch.
if pgrep -f "scripts/sft.py" > /dev/null; then
  log "SFT is already running → skipping launch"
  exit 0
fi

log "launching SFT (log -> $LOG_FILE)"

export LD_LIBRARY_PATH="/usr/lib/wsl/lib:/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
export HYDRA_SFT_TIME_BUDGET="${HYDRA_SFT_TIME_BUDGET:-10800}"
export HYDRA_BATCH_SIZE="${HYDRA_BATCH_SIZE:-4}"
export HYDRA_TOTAL_BATCH="${HYDRA_TOTAL_BATCH:-8192}"
export HYDRA_SFT_SEQ_LEN="${HYDRA_SFT_SEQ_LEN:-512}"
export HYDRA_SFT_LR_MULT="${HYDRA_SFT_LR_MULT:-0.10}"
export HYDRA_SFT_EVAL_INTERVAL="${HYDRA_SFT_EVAL_INTERVAL:-500}"
export HYDRA_SFT_CKPT_INTERVAL="${HYDRA_SFT_CKPT_INTERVAL:-2000}"
export HYDRA_DROPOUT="${HYDRA_DROPOUT:-0.1}"

nohup "$PY" -u "$REPO_ROOT/scripts/sft.py" \
  > "$LOG_FILE" 2>&1 &
SFT_PID=$!
disown $SFT_PID 2>/dev/null || true

log "SFT launched as PID $SFT_PID (budget=${HYDRA_SFT_TIME_BUDGET}s)"
log "monitor with: tail -f $LOG_FILE"
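Stage 2 of the script encodes a terminate-then-kill contract: SIGTERM so the checkpoint-save handler can run, a bounded wait, then SIGKILL for stragglers. A minimal Python sketch of the same contract, assuming the script's pgrep-based process match; the function names here are illustrative, not part of the repo:

import os
import signal
import subprocess
import time

def find_pids(pattern: str = "train.py") -> list[int]:
    # Same match the orchestrator uses: pgrep -u $USER -f <pattern>
    user = os.environ.get("USER") or str(os.getuid())
    out = subprocess.run(["pgrep", "-u", user, "-f", pattern],
                         capture_output=True, text=True)
    return [int(p) for p in out.stdout.split()]

def stop_gracefully(pattern: str = "train.py", grace_s: int = 30) -> None:
    for pid in find_pids(pattern):
        os.kill(pid, signal.SIGTERM)      # graceful: checkpoint handler fires
    deadline = time.time() + grace_s
    while time.time() < deadline and find_pids(pattern):
        time.sleep(1)
    for pid in find_pids(pattern):        # stragglers only
        os.kill(pid, signal.SIGKILL)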
overlay/subsystems/fused_sdr_project.py
CHANGED
@@ -114,6 +114,13 @@ class FusedSDRProject(torch.autograd.Function):
 
         out = torch.empty(P, D, device=active.device, dtype=sdr_proj_weight.dtype)
 
+        if not active.is_cuda:
+            # Local CPU validation has no Triton driver. Keep the same custom
+            # autograd contract but use a deterministic gather+sum fallback.
+            out = wt[active].sum(dim=1).to(dtype=sdr_proj_weight.dtype)
+            ctx.save_for_backward(active, token_ids, sdr_proj_weight, delta_u, delta_v)
+            return out.view(B, T, D)
+
         BLOCK_D = min(256, triton.next_power_of_2(D))
         grid = (P * triton.cdiv(D, BLOCK_D),)
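The CPU fallback in this hunk replaces the Triton kernel with a gather+sum over the K active rows per position. A toy-shape numeric check of that reduction; wt stands in for the kernel's (n_bits, D) weight view, and all sizes here are illustrative, not the production ones:

import torch

B, T, K, n_bits, D = 2, 3, 4, 64, 16
wt = torch.randn(n_bits, D)                    # stand-in projection table
active = torch.randint(0, n_bits, (B * T, K))  # K active bit indices per position

out = wt[active].sum(dim=1)                    # (B*T, K, D) -> (B*T, D)

# Reference: explicit per-position loop.
ref = torch.stack([wt[active[p]].sum(dim=0) for p in range(B * T)])
assert torch.allclose(out, ref)
print(out.view(B, T, D).shape)                 # torch.Size([2, 3, 16])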
overlay/subsystems/htm.py
CHANGED
@@ -99,13 +99,19 @@ class HTMLayer(nn.Module):
         self._forward_counter = 0
         # GPU backend gate. Default: auto-detect → use GPU when the pyo3
         # module was built with --features gpu AND CUDA is actually usable.
+        # HYDRA_FORCE_HTM_CPU=1 is an operational safety valve for paid remote
+        # canaries when the compiled CUDA HTM backend is present but unstable on
+        # a specific hardware/runtime combination.
+        force_cpu = _os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
         if use_gpu is None:
-            use_gpu = _HTM_HAS_GPU and torch.cuda.is_available()
+            use_gpu = (not force_cpu) and _HTM_HAS_GPU and torch.cuda.is_available()
         elif use_gpu and not _HTM_HAS_GPU:
             raise RuntimeError(
                 "HTMLayer(use_gpu=True) but htm_rust was not built with "
                 "--features gpu. Re-run `maturin develop --features gpu`."
             )
+        elif use_gpu and force_cpu:
+            use_gpu = False
         self._use_gpu = bool(use_gpu)
         cls = htm_rust.HTMRegionGpu if self._use_gpu else htm_rust.HTMRegion
         self._region_cls = cls
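The three-way gate above is small enough to test in isolation. A sketch that lifts the same logic into a standalone function (resolve_backend is an illustrative name, not part of the module):

import os

def resolve_backend(use_gpu, has_gpu_build: bool, cuda_ok: bool) -> bool:
    force_cpu = os.environ.get("HYDRA_FORCE_HTM_CPU", "0") == "1"
    if use_gpu is None:
        use_gpu = (not force_cpu) and has_gpu_build and cuda_ok
    elif use_gpu and not has_gpu_build:
        raise RuntimeError("use_gpu=True but no GPU build")
    elif use_gpu and force_cpu:
        use_gpu = False            # downgrade silently rather than raise
    return bool(use_gpu)

os.environ["HYDRA_FORCE_HTM_CPU"] = "1"
assert resolve_backend(None, True, True) is False   # auto-detect -> CPU
assert resolve_backend(True, True, True) is False   # explicit GPU downgraded

Note the ordering: the hard error for a missing GPU build still fires before the canary override, so a misbuilt wheel is reported rather than silently masked.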
overlay/subsystems/sdr_semantic.py
CHANGED
@@ -46,19 +46,10 @@ class _SDRSTE(torch.autograd.Function):
         flat_grad = grad_out.reshape(B * T, n_bits).to(delta_v.dtype)
         flat_ids = token_ids.reshape(B * T)
         V = delta_u.shape[0]
-        #              = index_add(V, R, flat_ids, flat_grad @ delta_v.T)
-        projected = flat_grad @ delta_v.t()   # (B*T, R) → ~1MB at B=8,T=1024,R=32
-        per_tok_u = torch.zeros(V, R, device=flat_grad.device, dtype=delta_v.dtype)
-        per_tok_u.index_add_(0, flat_ids, projected)
-        grad_delta_u = per_tok_u              # (V, R) → ~8MB at V=65536
-        # grad_delta_v = sum_{pos} delta_u[flat_ids[pos]]^T @ flat_grad[pos]
-        #              = delta_u[flat_ids].T @ flat_grad → no intermediate buffer
-        gathered_u = delta_u[flat_ids]        # (B*T, R) → ~1MB
-        grad_delta_v = gathered_u.t() @ flat_grad  # (R, n_bits) → ~2MB
+        per_tok = torch.zeros(V, n_bits, device=flat_grad.device, dtype=delta_v.dtype)
+        per_tok.index_add_(0, flat_ids, flat_grad)
+        grad_delta_u = per_tok @ delta_v.t()
+        grad_delta_v = delta_u.t() @ per_tok
         return None, grad_delta_u, grad_delta_v, None


@@ -249,25 +240,12 @@ class SemanticFoldingSDR(nn.Module):
         sdr_binary = sdr_binary.view(B, T, self.n_bits)
         return _SDRSTE.apply(sdr_binary, self.delta_u, self.delta_v, token_ids)

-    @torch.no_grad()
-    def active_indices(self, token_ids: torch.Tensor) -> torch.Tensor:
-        """Compact int16 Reality Buffer view: (B,T,K) active retina offsets.
-
-        This is the production discrete bridge for Cantor/Engram routing. It
-        avoids reconstructing dense (B,T,n_bits) masks when consumers only need
-        the L0 support set.
-        """
-        if token_ids.dim() != 2:
-            raise ValueError(f"expected (B, T) token_ids, got shape {tuple(token_ids.shape)}")
-        B, T = token_ids.shape
-        return self._retina_indices[token_ids.reshape(-1)].view(B, T, self.target_active)
-
     @torch.no_grad()
     def binary_only(self, token_ids: torch.Tensor) -> torch.Tensor:
         """uint8 retina view → no STE, no autocast cost. For HTM/consumers that
         only need the binary pattern. Reconstructs dense from CSR indices."""
         B, T = token_ids.shape
-        idx = self.
+        idx = self._retina_indices[token_ids.reshape(-1)]  # (B*T, K) int16
         sdr = torch.zeros(
             B * T, self.n_bits, dtype=torch.uint8, device=token_ids.device,
         )
|