Jackoatmon's picture
Update Feather H200 runtime: Nemotron streaming and HTM force-CPU canary fixes
c2bf4b6 verified
//! GPU Temporal Memory.
//!
//! Flat device storage. Pre-allocated segment slab:
//! n_cells = n_columns * cells_per_column
//! n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL
//! n_synapses_max = n_segments_max * MAX_SYN_PER_SEGMENT
//!
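//! Original sizing (CPU parity targets relaxed on GPU to keep memory tractable):
//! MAX_SEGMENTS_PER_CELL = 16
//! MAX_SYN_PER_SEGMENT = 32
//! NOTE: the constants compiled in this file are currently
//! MAX_SEGMENTS_PER_CELL = 4 and MAX_SYN_PER_SEGMENT = 20; the figures in
//! this header describe the 16/32 sizing, not the compiled one.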
//!
//! At n_cells = 65536:
//! n_segments_max = 1_048_576 (~1M)
//! n_synapses_max = 33_554_432 (~33M)
//! Storage:
//! syn_presyn : u32 × 33M = 128 MB
//! syn_perm : i16 × 33M = 64 MB
//! seg_cell : u32 × 1M = 4 MB
//! seg_syn_n : u32 × 1M = 4 MB
//! misc bitsets etc ~ <1 MB
//! -------------------------------
//! Total per region ~200 MB
//!
//! Permanences are stored as i16 scaled by 32767 (→ [0, 32767] represents
//! [0.0, 1.0]). inc/dec are provided pre-scaled.
use std::sync::Arc;
use cudarc::driver::{CudaDevice, CudaSlice, DriverError, DeviceRepr, LaunchAsync, LaunchConfig};
use cudarc::nvrtc::Ptx;
/// Packed config struct passed by value to TM kernels to stay under
/// cudarc's 12-tuple launch limit. Layout must match the C-side
/// `TmConfig` struct declared in each kernel.
///
/// All sixteen fields are 4 bytes wide (`u32` / `i32`), so under `#[repr(C)]`
/// the struct is 64 bytes with no padding. Do not reorder, add, or remove
/// fields without updating every kernel's copy of the declaration.
///
/// Values are filled in by `TemporalMemoryGpu::build_cfg`; the per-field notes
/// below say where each one comes from.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct TmConfig {
/// Copied from `TemporalMemoryGpu::activation_threshold`.
pub activation_threshold: u32,
/// Copied from `TemporalMemoryGpu::learning_threshold`.
pub learning_threshold: u32,
/// Copied from `TemporalMemoryGpu::cells_per_column` (narrowed to `u32`).
pub cells_per_column: u32,
/// Set to `MAX_SYN_PER_SEGMENT`.
pub synapses_per_segment: u32,
/// Set to `TemporalMemoryGpu::n_segments_max` (narrowed to `u32`).
pub n_segments: u32,
/// Set to `TemporalMemoryGpu::n_cells` (narrowed to `u32`).
pub n_cells: u32,
/// Set to `MAX_SEGMENTS_PER_CELL`.
pub max_segments_per_cell: u32,
/// Copied from `TemporalMemoryGpu::max_new_synapse_count`.
pub max_new_synapses: u32,
// The five permanence parameters below are stored as `i16` on the host and
// widened to `i32` here so every field stays 4-byte sized and aligned.
pub conn_thr_i16: i32, // i16 widened to i32 for alignment
pub perm_inc_i16: i32,
pub perm_dec_i16: i32,
pub predicted_seg_dec_i16: i32,
pub initial_perm_i16: i32,
/// Set to `TemporalMemoryGpu::iter_counter`, which `step` increments before
/// building the config, so the value changes on every step.
pub iter_seed: u32,
/// Set to `TemporalMemoryGpu::n_columns` (narrowed to `u32`).
pub n_cols: u32,
/// Set to `TemporalMemoryGpu::bits_words` (`n_cells / 32`).
pub bits_words: u32,
}
// SAFETY: `TmConfig` is `#[repr(C)]`, `Copy`, and consists solely of `u32` /
// `i32` fields — no pointers, references, or padding bytes — so handing it to
// a kernel by value is a plain 64-byte copy. Soundness still depends on the
// kernels declaring an identically laid-out struct.
unsafe impl DeviceRepr for TmConfig {}
// Embedded PTX.
//
// Each constant holds the text of one PTX file, embedded at compile time with
// `include_str!`. The directory is taken from the `HTM_GPU_PTX_DIR`
// environment variable, which `env!` reads while this crate is being compiled,
// so the variable must be set at build time (NOTE(review): presumably by a
// build script — confirm). `TemporalMemoryGpu::new` loads each string as its
// own module on the device.
const PTX_TM_PREDICT: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_predict.ptx"));
const PTX_TM_ACTIVATE: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_activate.ptx"));
const PTX_TM_LEARN: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_learn.ptx"));
const PTX_TM_PUNISH: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_punish.ptx"));
const PTX_TM_GROW: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_grow.ptx"));
const PTX_TM_ANOMALY: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_anomaly.ptx"));
const PTX_TM_RESET: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_reset.ptx"));
/// Number of segment slots reserved for every cell; the segment slab holds
/// `n_cells * MAX_SEGMENTS_PER_CELL` entries.
///
/// Capacity trade-offs for 6 GB VRAM (RTX 3060) shared with the model:
/// n_cells = 2048 × 32 = 65_536
/// n_segments_max = n_cells × MAX_SEGMENTS_PER_CELL
/// n_synapses_max = n_segments_max × MAX_SYN_PER_SEGMENT
///
/// At 4/20 these are 262_144 segments and ~5.2M synapses (~50 MB per region).
/// The training loop runs with `reset_each_forward=True`, so segment counts
/// per window stay well below 32K (typical: ~n_cols new segs per step until
/// the first matching segment is reused; in a 2048-step window that plateaus
/// around ~5K total live segments). The 262K ceiling is generous headroom.
pub const MAX_SEGMENTS_PER_CELL: usize = 4;
/// Number of synapse slots reserved for every segment. It sizes the
/// `syn_presyn` / `syn_perm` slabs (`n_segments_max * MAX_SYN_PER_SEGMENT`
/// entries each) and is handed to the kernels as
/// `TmConfig::synapses_per_segment`.
pub const MAX_SYN_PER_SEGMENT: usize = 20;
/// Fixed-point scale for permanences: a permanence of 1.0 is stored as 32767.
const PERM_SCALE: f32 = i16::MAX as f32;

/// Convert a floating-point permanence to its scaled `i16` representation.
///
/// The input is clamped to `[0.0, 1.0]`, multiplied by [`PERM_SCALE`], and
/// rounded to the nearest integer (ties away from zero), so the result always
/// lies in `[0, 32767]`.
fn perm_f32_to_i16(x: f32) -> i16 {
    let unit = f32::clamp(x, 0.0, 1.0);
    let scaled = unit * PERM_SCALE;
    scaled.round() as i16
}
/// GPU-resident temporal memory: host-side configuration plus the persistent
/// device buffers that the TM kernels read and update on every `step`.
///
/// All buffers are allocated once in `new`; `reset` re-initialises a subset of
/// them in place.
pub struct TemporalMemoryGpu {
// Device handle used for every allocation, copy, memset and kernel launch.
dev: Arc<CudaDevice>,
// Config mirror: host copies of the values packed into `TmConfig` each step.
pub n_columns: usize,
pub cells_per_column: usize,
pub activation_threshold: u32,
pub learning_threshold: u32,
// Permanence parameters, already converted to the scaled i16 representation
// by `perm_f32_to_i16` in `new`.
pub initial_perm_i16: i16,
pub conn_thr_i16: i16,
pub perm_inc_i16: i16,
pub perm_dec_i16: i16,
pub predicted_seg_dec_i16: i16,
pub max_new_synapse_count: u32,
// Sizes
// n_columns * cells_per_column
pub n_cells: usize,
// n_cells * MAX_SEGMENTS_PER_CELL
pub n_segments_max: usize,
pub bits_words: usize, // n_cells / 32
// Persistent device buffers
// One u32 per segment slot; initialised to u32::MAX, the "unused" sentinel
// that `reset` restores.
seg_cell_id: CudaSlice<u32>,
// One u32 per segment slot.
seg_syn_count: CudaSlice<u32>,
// n_segments_max * MAX_SYN_PER_SEGMENT entries each.
syn_presyn: CudaSlice<u32>,
syn_perm: CudaSlice<i16>,
// One u32 per cell.
cell_seg_count: CudaSlice<u32>,
// Per-cell bitsets, `bits_words` u32 words each (32 cells per word).
cell_active_bits: CudaSlice<u32>,
cell_winner_bits: CudaSlice<u32>,
cell_predictive_bits: CudaSlice<u32>,
prev_active_bits: CudaSlice<u32>,
prev_winner_bits: CudaSlice<u32>,
// One u8 per column.
col_predicted: CudaSlice<u8>,
// One u32 per segment slot; passed mutably to the predict kernel and then
// read-only to the learn (conn) and punish (pot) kernels.
seg_num_active_conn: CudaSlice<u32>,
seg_num_active_pot: CudaSlice<u32>,
// Single-element counter; passed mutably to the reset and activate kernels
// and read by the anomaly kernel.
unpredicted_count: CudaSlice<u32>,
// Up to n_columns entries plus a single-element count; written by the
// activate kernel and read by the grow kernel.
burst_cols_flat: CudaSlice<u32>,
burst_cols_count: CudaSlice<u32>,
// One u32 per column.
col_best_match: CudaSlice<u32>,
// Incremented (wrapping) at the start of every `step`, passed to the kernels
// as `TmConfig::iter_seed`, and zeroed by `reset`.
iter_counter: u32,
}
impl TemporalMemoryGpu {
/// Create a GPU temporal memory with `n_columns * cells_per_column` cells.
///
/// Allocates every persistent device buffer up front (segment slots start out
/// marked unused with `u32::MAX`, everything else is zeroed) and loads the
/// seven TM PTX modules onto `dev` unless they are already present.
///
/// # Panics
///
/// * if `n_columns * cells_per_column` is not a multiple of 32 (the per-cell
///   bitsets pack 32 cells per `u32` word);
/// * if any derived size overflows `usize`, or if a size that is later
///   narrowed to `u32` for the kernels (`n_columns`, `cells_per_column`,
///   `n_cells`, `n_segments_max`) does not fit in `u32`. Without this check
///   those values would be silently truncated by `as u32`.
///
/// # Errors
///
/// Returns the `DriverError` from the first device allocation, host-to-device
/// copy, or PTX load that fails.
pub fn new(
    dev: Arc<CudaDevice>,
    n_columns: usize,
    cells_per_column: usize,
) -> Result<Self, DriverError> {
    // Checked products: in release builds a plain `*` would wrap silently and
    // the buffers below would be allocated with the wrong size.
    let n_cells = n_columns
        .checked_mul(cells_per_column)
        .expect("n_columns * cells_per_column overflows usize");
    assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets");
    let n_segments_max = n_cells
        .checked_mul(MAX_SEGMENTS_PER_CELL)
        .expect("n_cells * MAX_SEGMENTS_PER_CELL overflows usize");
    let n_synapses_max = n_segments_max
        .checked_mul(MAX_SYN_PER_SEGMENT)
        .expect("n_segments_max * MAX_SYN_PER_SEGMENT overflows usize");

    // These sizes are narrowed with `as u32` when the kernel config and the
    // launch grids are built; reject values that would be truncated there.
    let fits_u32 = |v: usize| v <= u32::MAX as usize;
    assert!(
        fits_u32(n_columns)
            && fits_u32(cells_per_column)
            && fits_u32(n_cells)
            && fits_u32(n_segments_max),
        "n_columns, cells_per_column, n_cells and n_segments_max must each fit in u32"
    );

    let bits_words = n_cells / 32;

    // Numenta defaults.
    let activation_threshold = 15u32;
    let learning_threshold = 13u32;
    let initial_perm_i16 = perm_f32_to_i16(0.21);
    let conn_thr_i16 = perm_f32_to_i16(0.50);
    let perm_inc_i16 = perm_f32_to_i16(0.10);
    let perm_dec_i16 = perm_f32_to_i16(0.10);
    let predicted_seg_dec_i16 = perm_f32_to_i16(0.10);
    let max_new_synapse_count = 20u32;

    // Allocate buffers. Segment slots start as "unused" (u32::MAX); every
    // other buffer starts zeroed.
    let seg_cell_id_host: Vec<u32> = vec![u32::MAX; n_segments_max];
    let seg_cell_id = dev.htod_sync_copy(&seg_cell_id_host)?;
    let seg_syn_count = dev.alloc_zeros::<u32>(n_segments_max)?;
    let syn_presyn = dev.alloc_zeros::<u32>(n_synapses_max)?;
    let syn_perm = dev.alloc_zeros::<i16>(n_synapses_max)?;
    let cell_seg_count = dev.alloc_zeros::<u32>(n_cells)?;
    let cell_active_bits = dev.alloc_zeros::<u32>(bits_words)?;
    let cell_winner_bits = dev.alloc_zeros::<u32>(bits_words)?;
    let cell_predictive_bits = dev.alloc_zeros::<u32>(bits_words)?;
    let prev_active_bits = dev.alloc_zeros::<u32>(bits_words)?;
    let prev_winner_bits = dev.alloc_zeros::<u32>(bits_words)?;
    let col_predicted = dev.alloc_zeros::<u8>(n_columns)?;
    let seg_num_active_conn = dev.alloc_zeros::<u32>(n_segments_max)?;
    let seg_num_active_pot = dev.alloc_zeros::<u32>(n_segments_max)?;
    let unpredicted_count = dev.alloc_zeros::<u32>(1)?;
    // Bursting columns for one step bounded by n_columns.
    let burst_cols_flat = dev.alloc_zeros::<u32>(n_columns)?;
    let burst_cols_count = dev.alloc_zeros::<u32>(1)?;
    let col_best_match = dev.alloc_zeros::<u32>(n_columns)?;

    // Load PTX modules: (module name, PTX source, kernel entry point).
    // A module is skipped when its entry point is already registered on the
    // device, so several instances can share one device.
    let modules = [
        ("htm_tm_predict", PTX_TM_PREDICT, "tm_predict"),
        ("htm_tm_activate", PTX_TM_ACTIVATE, "tm_activate"),
        ("htm_tm_learn", PTX_TM_LEARN, "tm_learn_reinforce"),
        ("htm_tm_punish", PTX_TM_PUNISH, "tm_punish"),
        ("htm_tm_grow", PTX_TM_GROW, "tm_grow"),
        ("htm_tm_anomaly", PTX_TM_ANOMALY, "tm_anomaly"),
        ("htm_tm_reset", PTX_TM_RESET, "tm_reset_step"),
    ];
    for (modname, ptx, fnname) in modules {
        if dev.get_func(modname, fnname).is_none() {
            dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?;
        }
    }

    Ok(Self {
        dev,
        n_columns,
        cells_per_column,
        activation_threshold,
        learning_threshold,
        initial_perm_i16,
        conn_thr_i16,
        perm_inc_i16,
        perm_dec_i16,
        predicted_seg_dec_i16,
        max_new_synapse_count,
        n_cells,
        n_segments_max,
        bits_words,
        seg_cell_id,
        seg_syn_count,
        syn_presyn,
        syn_perm,
        cell_seg_count,
        cell_active_bits,
        cell_winner_bits,
        cell_predictive_bits,
        prev_active_bits,
        prev_winner_bits,
        col_predicted,
        seg_num_active_conn,
        seg_num_active_pot,
        unpredicted_count,
        burst_cols_flat,
        burst_cols_count,
        col_best_match,
        iter_counter: 0,
    })
}
// --- Fused-path accessors ---
// Shared, read-only handles to the persistent segment/synapse buffers, for
// code outside this type that needs to pass them to its own kernel launches.

/// Borrow the per-segment-slot cell-id buffer (`u32::MAX` marks an unused slot).
pub fn seg_cell_id_accessor(&self) -> &CudaSlice<u32> {
    &self.seg_cell_id
}

/// Borrow the per-segment-slot synapse-count buffer.
pub fn seg_syn_count_accessor(&self) -> &CudaSlice<u32> {
    &self.seg_syn_count
}

/// Borrow the synapse presynaptic-index slab.
pub fn syn_presyn_accessor(&self) -> &CudaSlice<u32> {
    &self.syn_presyn
}

/// Borrow the synapse permanence slab (scaled `i16` values).
pub fn syn_perm_accessor(&self) -> &CudaSlice<i16> {
    &self.syn_perm
}

/// Borrow the per-cell segment-count buffer.
pub fn cell_seg_count_accessor(&self) -> &CudaSlice<u32> {
    &self.cell_seg_count
}
/// Hard reset — clear everything (predictive + active + segments).
///
/// Every segment slot is marked unused again, the segment/synapse counters,
/// all five per-cell bitsets and `col_best_match` are zeroed, and the step
/// counter restarts at 0. The synapse slabs and the remaining scratch buffers
/// are left as they are.
///
/// # Errors
///
/// Returns the `DriverError` from the first copy or memset that fails; buffers
/// processed before the failure stay reset.
pub fn reset(&mut self) -> Result<(), DriverError> {
    // Restore the "unused" sentinel in every segment slot.
    let sentinel_fill = vec![u32::MAX; self.n_segments_max];
    self.dev
        .htod_sync_copy_into(&sentinel_fill, &mut self.seg_cell_id)?;

    // Zero the counters and bitsets, in the same order as always.
    let to_zero = [
        &mut self.seg_syn_count,
        &mut self.cell_seg_count,
        &mut self.cell_active_bits,
        &mut self.cell_winner_bits,
        &mut self.cell_predictive_bits,
        &mut self.prev_active_bits,
        &mut self.prev_winner_bits,
        &mut self.col_best_match,
    ];
    for buf in to_zero {
        self.dev.memset_zeros(buf)?;
    }

    self.iter_counter = 0;
    Ok(())
}
/// Pack the current host-side configuration into the `#[repr(C)]` struct that
/// is passed by value to the TM kernels.
///
/// `iter_seed` is taken from the current `iter_counter`, so the result is
/// only valid for the step it was built for.
fn build_cfg(&self) -> TmConfig {
    // Permanence parameters live on the host as i16 and are widened
    // losslessly to the i32 fields of `TmConfig`.
    let widen = |perm: i16| i32::from(perm);

    TmConfig {
        // Thresholds and growth limit.
        activation_threshold: self.activation_threshold,
        learning_threshold: self.learning_threshold,
        max_new_synapses: self.max_new_synapse_count,

        // Geometry.
        n_cols: self.n_columns as u32,
        cells_per_column: self.cells_per_column as u32,
        n_cells: self.n_cells as u32,
        bits_words: self.bits_words as u32,
        n_segments: self.n_segments_max as u32,
        max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32,
        synapses_per_segment: MAX_SYN_PER_SEGMENT as u32,

        // Scaled permanence parameters.
        initial_perm_i16: widen(self.initial_perm_i16),
        conn_thr_i16: widen(self.conn_thr_i16),
        perm_inc_i16: widen(self.perm_inc_i16),
        perm_dec_i16: widen(self.perm_dec_i16),
        predicted_seg_dec_i16: widen(self.predicted_seg_dec_i16),

        // Per-step seed.
        iter_seed: self.iter_counter,
    }
}
/// Run one TM step on the GPU. Takes the SP active-column mask (u8, already
/// on device) and writes `anomaly_out[t_slot]`.
pub fn step(
&mut self,
sp_active_mask: &CudaSlice<u8>,
anomaly_out: &mut CudaSlice<f32>,
t_slot: u32,
learn: bool,
) -> Result<(), DriverError> {
let n_cells = self.n_cells;
let n_cols = self.n_columns;
let predict_fn = self.dev.get_func("htm_tm_predict", "tm_predict").unwrap();
let activate_fn = self.dev.get_func("htm_tm_activate", "tm_activate").unwrap();
let learn_fn = self.dev.get_func("htm_tm_learn", "tm_learn_reinforce").unwrap();
let punish_fn = self.dev.get_func("htm_tm_punish", "tm_punish").unwrap();
let grow_fn = self.dev.get_func("htm_tm_grow", "tm_grow").unwrap();
let anom_fn = self.dev.get_func("htm_tm_anomaly", "tm_anomaly").unwrap();
let reset_fn = self.dev.get_func("htm_tm_reset", "tm_reset_step").unwrap();
self.iter_counter = self.iter_counter.wrapping_add(1);
let cfg_val = self.build_cfg();
// 0. Per-step reset.
let reset_words = self.bits_words.max(n_cols);
let reset_cfg = LaunchConfig {
grid_dim: (((reset_words + 255) / 256) as u32, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
reset_fn.clone().launch(
reset_cfg,
(
&mut self.cell_active_bits,
&mut self.cell_winner_bits,
&mut self.cell_predictive_bits,
&mut self.prev_active_bits,
&mut self.prev_winner_bits,
&mut self.col_predicted,
&mut self.unpredicted_count,
&mut self.burst_cols_count,
&mut self.col_best_match,
self.bits_words as u32,
n_cols as u32,
),
)?;
}
// 1. Predict (grid = n_cells; each block iterates its cell's segments).
let predict_cfg = LaunchConfig {
grid_dim: (n_cells as u32, 1, 1),
block_dim: (32, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
predict_fn.clone().launch(
predict_cfg,
(
&self.seg_cell_id,
&self.seg_syn_count,
&self.syn_presyn,
&self.syn_perm,
&self.prev_active_bits,
&mut self.cell_predictive_bits,
&mut self.col_predicted,
&mut self.seg_num_active_conn,
&mut self.seg_num_active_pot,
&mut self.col_best_match,
&self.cell_seg_count,
cfg_val,
),
)?;
}
// 2. Activate.
let activate_cfg = LaunchConfig {
grid_dim: (((n_cols + 255) / 256) as u32, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
activate_fn.clone().launch(
activate_cfg,
(
sp_active_mask,
&self.col_predicted,
&self.cell_predictive_bits,
&mut self.cell_active_bits,
&mut self.cell_winner_bits,
&mut self.unpredicted_count,
&mut self.burst_cols_flat,
&mut self.burst_cols_count,
cfg_val,
),
)?;
}
// 3. Anomaly.
let anom_cfg = LaunchConfig {
grid_dim: (1, 1, 1),
block_dim: (256, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
anom_fn.clone().launch(
anom_cfg,
(
sp_active_mask,
&self.unpredicted_count,
anomaly_out,
t_slot,
n_cols as u32,
),
)?;
}
if learn {
// 4. Reinforce (grid = n_cells).
let learn_cfg = LaunchConfig {
grid_dim: (n_cells as u32, 1, 1),
block_dim: (32, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
learn_fn.clone().launch(
learn_cfg,
(
&self.seg_cell_id,
&self.seg_syn_count,
&self.syn_presyn,
&mut self.syn_perm,
&self.seg_num_active_conn,
&self.prev_active_bits,
sp_active_mask,
&self.col_predicted,
&self.cell_seg_count,
cfg_val,
),
)?;
}
// 5. Punish.
unsafe {
punish_fn.clone().launch(
learn_cfg,
(
&self.seg_cell_id,
&self.seg_syn_count,
&self.syn_presyn,
&mut self.syn_perm,
&self.seg_num_active_pot,
&self.prev_active_bits,
sp_active_mask,
&self.cell_seg_count,
cfg_val,
),
)?;
}
// 6. Grow.
let grow_cfg = LaunchConfig {
grid_dim: (n_cols as u32, 1, 1),
block_dim: (32, 1, 1),
shared_mem_bytes: 0,
};
unsafe {
grow_fn.clone().launch(
grow_cfg,
(
&mut self.seg_cell_id,
&mut self.seg_syn_count,
&mut self.syn_presyn,
&mut self.syn_perm,
&mut self.cell_seg_count,
&self.burst_cols_flat,
&self.burst_cols_count,
&self.prev_winner_bits,
&self.prev_active_bits,
&self.col_best_match,
cfg_val,
),
)?;
}
}
Ok(())
}
}