Spaces:
Runtime error
Runtime error
//! GPU Temporal Memory.
//!
//! Flat device storage. Pre-allocated segment slab:
//!   n_cells        = n_columns * cells_per_column
//!   n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL
//!   n_synapses_max = n_segments_max * MAX_SYN_PER_SEGMENT
//!
//! Defaults (CPU parity targets relaxed on GPU to keep memory tractable):
//!   MAX_SEGMENTS_PER_CELL = 4
//!   MAX_SYN_PER_SEGMENT   = 20
//!
//! At n_cells = 65536:
//!   n_segments_max = 262_144   (~262K)
//!   n_synapses_max = 5_242_880 (~5.2M)
//! Storage:
//!   syn_presyn : u32 × 5.2M = 20 MB
//!   syn_perm   : i16 × 5.2M = 10 MB
//!   seg_cell   : u32 × 262K =  1 MB
//!   seg_syn_n  : u32 × 262K =  1 MB
//!   misc bitsets etc        ~ <1 MB
//!   -------------------------------
//!   Total for the buffers listed ~32 MB
//!
//! Permanences are stored as i16 scaled by 32767 (→ [0, 32767] represents
//! [0.0, 1.0]). inc/dec are provided pre-scaled.
use std::sync::Arc;

use cudarc::driver::{
    CudaDevice, CudaSlice, DeviceRepr, DeviceSlice, DriverError, LaunchAsync, LaunchConfig,
};
use cudarc::nvrtc::Ptx;
/// Packed kernel configuration, passed by value to the TM kernels so each
/// launch stays under cudarc's 12-tuple parameter limit.
///
/// The layout must match the C-side `TmConfig` struct declared in each
/// kernel, so the struct is `#[repr(C)]`: without it Rust is free to reorder
/// fields. All sixteen fields are 4 bytes wide, giving a 64-byte struct with
/// no padding.
///
/// The type is `Copy` because a single value is handed to several kernel
/// launches within one step.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct TmConfig {
    /// Mirrors `TemporalMemoryGpu::activation_threshold`.
    pub activation_threshold: u32,
    /// Mirrors `TemporalMemoryGpu::learning_threshold`.
    pub learning_threshold: u32,
    /// Number of cells in each column.
    pub cells_per_column: u32,
    /// Row width of the synapse slab (`MAX_SYN_PER_SEGMENT`).
    pub synapses_per_segment: u32,
    /// Total number of segment slots in the slab.
    pub n_segments: u32,
    /// Total number of cells (`n_cols * cells_per_column`).
    pub n_cells: u32,
    /// Segment slots reserved per cell (`MAX_SEGMENTS_PER_CELL`).
    pub max_segments_per_cell: u32,
    /// Mirrors `TemporalMemoryGpu::max_new_synapse_count`.
    pub max_new_synapses: u32,
    /// Connected-permanence threshold on the i16 scale.
    pub conn_thr_i16: i32, // i16 widened to i32 for alignment
    /// Permanence increment on the i16 scale (widened to i32).
    pub perm_inc_i16: i32,
    /// Permanence decrement on the i16 scale (widened to i32).
    pub perm_dec_i16: i32,
    /// Decrement for predicted segments, on the i16 scale (widened to i32).
    pub predicted_seg_dec_i16: i32,
    /// Initial permanence on the i16 scale (widened to i32).
    pub initial_perm_i16: i32,
    /// Per-step seed; the host passes its step counter here.
    pub iter_seed: u32,
    /// Number of columns.
    pub n_cols: u32,
    /// Number of 32-bit words in each per-cell bitset (`n_cells / 32`).
    pub bits_words: u32,
}
| unsafe impl DeviceRepr for TmConfig {} | |
| // Embedded PTX. | |
| const PTX_TM_PREDICT: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_predict.ptx")); | |
| const PTX_TM_ACTIVATE: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_activate.ptx")); | |
| const PTX_TM_LEARN: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_learn.ptx")); | |
| const PTX_TM_PUNISH: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_punish.ptx")); | |
| const PTX_TM_GROW: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_grow.ptx")); | |
| const PTX_TM_ANOMALY: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_anomaly.ptx")); | |
| const PTX_TM_RESET: &str = include_str!(concat!(env!("HTM_GPU_PTX_DIR"), "/tm_reset.ptx")); | |
/// Number of segment slots the slab reserves for every cell.
///
/// Capacity trade-off for 6 GB of VRAM (RTX 3060) shared with the model:
///   n_cells        = 2048 × 32 = 65_536
///   n_segments_max = n_cells × MAX_SEGMENTS_PER_CELL
///   n_synapses_max = n_segments_max × MAX_SYN_PER_SEGMENT
///
/// With 4 segments per cell and 20 synapses per segment this gives 262_144
/// segments and ~5.2M synapses (~50 MB per region).
/// The training loop runs with `reset_each_forward=True`, so segment counts
/// per window stay well below 32K (typical: ~n_cols new segs per step until
/// the first matching segment is reused; in a 2048-step window that plateaus
/// around ~5K total live segments). The 262K ceiling is generous headroom.
pub const MAX_SEGMENTS_PER_CELL: usize = 4;
/// Number of synapse slots reserved for every segment (row width of the
/// synapse slab).
pub const MAX_SYN_PER_SEGMENT: usize = 20;
/// Scale factor mapping a permanence in `[0.0, 1.0]` onto `[0, 32767]`.
const PERM_SCALE: f32 = 32767.0;

/// Converts a floating-point permanence to its fixed-point `i16` form.
///
/// The input is clamped to `[0.0, 1.0]`, multiplied by [`PERM_SCALE`] and
/// rounded to the nearest integer (ties away from zero). A NaN input passes
/// through the clamp and the saturating cast turns it into `0`.
fn perm_f32_to_i16(x: f32) -> i16 {
    let unit = f32::clamp(x, 0.0, 1.0);
    let scaled = unit * PERM_SCALE;
    scaled.round() as i16
}
| pub struct TemporalMemoryGpu { | |
| dev: Arc<CudaDevice>, | |
| // Config mirror | |
| pub n_columns: usize, | |
| pub cells_per_column: usize, | |
| pub activation_threshold: u32, | |
| pub learning_threshold: u32, | |
| pub initial_perm_i16: i16, | |
| pub conn_thr_i16: i16, | |
| pub perm_inc_i16: i16, | |
| pub perm_dec_i16: i16, | |
| pub predicted_seg_dec_i16: i16, | |
| pub max_new_synapse_count: u32, | |
| // Sizes | |
| pub n_cells: usize, | |
| pub n_segments_max: usize, | |
| pub bits_words: usize, // n_cells / 32 | |
| // Persistent device buffers | |
| seg_cell_id: CudaSlice<u32>, | |
| seg_syn_count: CudaSlice<u32>, | |
| syn_presyn: CudaSlice<u32>, | |
| syn_perm: CudaSlice<i16>, | |
| cell_seg_count: CudaSlice<u32>, | |
| cell_active_bits: CudaSlice<u32>, | |
| cell_winner_bits: CudaSlice<u32>, | |
| cell_predictive_bits: CudaSlice<u32>, | |
| prev_active_bits: CudaSlice<u32>, | |
| prev_winner_bits: CudaSlice<u32>, | |
| col_predicted: CudaSlice<u8>, | |
| seg_num_active_conn: CudaSlice<u32>, | |
| seg_num_active_pot: CudaSlice<u32>, | |
| unpredicted_count: CudaSlice<u32>, | |
| burst_cols_flat: CudaSlice<u32>, | |
| burst_cols_count: CudaSlice<u32>, | |
| col_best_match: CudaSlice<u32>, | |
| iter_counter: u32, | |
| } | |
| impl TemporalMemoryGpu { | |
| pub fn new( | |
| dev: Arc<CudaDevice>, | |
| n_columns: usize, | |
| cells_per_column: usize, | |
| ) -> Result<Self, DriverError> { | |
| let n_cells = n_columns * cells_per_column; | |
| assert!(n_cells % 32 == 0, "n_cells must be divisible by 32 for bitsets"); | |
| let n_segments_max = n_cells * MAX_SEGMENTS_PER_CELL; | |
| let bits_words = n_cells / 32; | |
| // Numenta defaults. | |
| let activation_threshold = 15u32; | |
| let learning_threshold = 13u32; | |
| let initial_perm_i16 = perm_f32_to_i16(0.21); | |
| let conn_thr_i16 = perm_f32_to_i16(0.50); | |
| let perm_inc_i16 = perm_f32_to_i16(0.10); | |
| let perm_dec_i16 = perm_f32_to_i16(0.10); | |
| let predicted_seg_dec_i16 = perm_f32_to_i16(0.10); | |
| let max_new_synapse_count = 20u32; | |
| // Allocate buffers. | |
| let seg_cell_id_host: Vec<u32> = vec![u32::MAX; n_segments_max]; | |
| let seg_cell_id = dev.htod_sync_copy(&seg_cell_id_host)?; | |
| let seg_syn_count = dev.alloc_zeros::<u32>(n_segments_max)?; | |
| let syn_presyn = dev.alloc_zeros::<u32>(n_segments_max * MAX_SYN_PER_SEGMENT)?; | |
| let syn_perm = dev.alloc_zeros::<i16>(n_segments_max * MAX_SYN_PER_SEGMENT)?; | |
| let cell_seg_count = dev.alloc_zeros::<u32>(n_cells)?; | |
| let cell_active_bits = dev.alloc_zeros::<u32>(bits_words)?; | |
| let cell_winner_bits = dev.alloc_zeros::<u32>(bits_words)?; | |
| let cell_predictive_bits = dev.alloc_zeros::<u32>(bits_words)?; | |
| let prev_active_bits = dev.alloc_zeros::<u32>(bits_words)?; | |
| let prev_winner_bits = dev.alloc_zeros::<u32>(bits_words)?; | |
| let col_predicted = dev.alloc_zeros::<u8>(n_columns)?; | |
| let seg_num_active_conn = dev.alloc_zeros::<u32>(n_segments_max)?; | |
| let seg_num_active_pot = dev.alloc_zeros::<u32>(n_segments_max)?; | |
| let unpredicted_count = dev.alloc_zeros::<u32>(1)?; | |
| // Bursting columns for one step bounded by n_columns. | |
| let burst_cols_flat = dev.alloc_zeros::<u32>(n_columns)?; | |
| let burst_cols_count = dev.alloc_zeros::<u32>(1)?; | |
| let col_best_match = dev.alloc_zeros::<u32>(n_columns)?; | |
| // Load PTX modules. | |
| let modules = [ | |
| ("htm_tm_predict", PTX_TM_PREDICT, "tm_predict"), | |
| ("htm_tm_activate", PTX_TM_ACTIVATE, "tm_activate"), | |
| ("htm_tm_learn", PTX_TM_LEARN, "tm_learn_reinforce"), | |
| ("htm_tm_punish", PTX_TM_PUNISH, "tm_punish"), | |
| ("htm_tm_grow", PTX_TM_GROW, "tm_grow"), | |
| ("htm_tm_anomaly", PTX_TM_ANOMALY, "tm_anomaly"), | |
| ("htm_tm_reset", PTX_TM_RESET, "tm_reset_step"), | |
| ]; | |
| for (modname, ptx, fnname) in modules { | |
| if dev.get_func(modname, fnname).is_none() { | |
| dev.load_ptx(Ptx::from_src(ptx), modname, &[fnname])?; | |
| } | |
| } | |
| Ok(Self { | |
| dev, | |
| n_columns, | |
| cells_per_column, | |
| activation_threshold, | |
| learning_threshold, | |
| initial_perm_i16, | |
| conn_thr_i16, | |
| perm_inc_i16, | |
| perm_dec_i16, | |
| predicted_seg_dec_i16, | |
| max_new_synapse_count, | |
| n_cells, | |
| n_segments_max, | |
| bits_words, | |
| seg_cell_id, | |
| seg_syn_count, | |
| syn_presyn, | |
| syn_perm, | |
| cell_seg_count, | |
| cell_active_bits, | |
| cell_winner_bits, | |
| cell_predictive_bits, | |
| prev_active_bits, | |
| prev_winner_bits, | |
| col_predicted, | |
| seg_num_active_conn, | |
| seg_num_active_pot, | |
| unpredicted_count, | |
| burst_cols_flat, | |
| burst_cols_count, | |
| col_best_match, | |
| iter_counter: 0, | |
| }) | |
| } | |
| // --- Fused-path accessors --- | |
| pub fn seg_cell_id_accessor(&self) -> &CudaSlice<u32> { &self.seg_cell_id } | |
| pub fn seg_syn_count_accessor(&self) -> &CudaSlice<u32> { &self.seg_syn_count } | |
| pub fn syn_presyn_accessor(&self) -> &CudaSlice<u32> { &self.syn_presyn } | |
| pub fn syn_perm_accessor(&self) -> &CudaSlice<i16> { &self.syn_perm } | |
| pub fn cell_seg_count_accessor(&self) -> &CudaSlice<u32> { &self.cell_seg_count } | |
| /// Hard reset — clear everything (predictive + active + segments). | |
| pub fn reset(&mut self) -> Result<(), DriverError> { | |
| // Restore "unused" sentinel in seg_cell_id. | |
| let unused_host: Vec<u32> = vec![u32::MAX; self.n_segments_max]; | |
| self.dev.htod_sync_copy_into(&unused_host, &mut self.seg_cell_id)?; | |
| self.dev.memset_zeros(&mut self.seg_syn_count)?; | |
| self.dev.memset_zeros(&mut self.cell_seg_count)?; | |
| self.dev.memset_zeros(&mut self.cell_active_bits)?; | |
| self.dev.memset_zeros(&mut self.cell_winner_bits)?; | |
| self.dev.memset_zeros(&mut self.cell_predictive_bits)?; | |
| self.dev.memset_zeros(&mut self.prev_active_bits)?; | |
| self.dev.memset_zeros(&mut self.prev_winner_bits)?; | |
| self.dev.memset_zeros(&mut self.col_best_match)?; | |
| self.iter_counter = 0; | |
| Ok(()) | |
| } | |
| fn build_cfg(&self) -> TmConfig { | |
| TmConfig { | |
| activation_threshold: self.activation_threshold, | |
| learning_threshold: self.learning_threshold, | |
| cells_per_column: self.cells_per_column as u32, | |
| synapses_per_segment: MAX_SYN_PER_SEGMENT as u32, | |
| n_segments: self.n_segments_max as u32, | |
| n_cells: self.n_cells as u32, | |
| max_segments_per_cell: MAX_SEGMENTS_PER_CELL as u32, | |
| max_new_synapses: self.max_new_synapse_count, | |
| conn_thr_i16: self.conn_thr_i16 as i32, | |
| perm_inc_i16: self.perm_inc_i16 as i32, | |
| perm_dec_i16: self.perm_dec_i16 as i32, | |
| predicted_seg_dec_i16: self.predicted_seg_dec_i16 as i32, | |
| initial_perm_i16: self.initial_perm_i16 as i32, | |
| iter_seed: self.iter_counter, | |
| n_cols: self.n_columns as u32, | |
| bits_words: self.bits_words as u32, | |
| } | |
| } | |
| /// Run one TM step on the GPU. Takes the SP active-column mask (u8, already | |
| /// on device) and writes `anomaly_out[t_slot]`. | |
| pub fn step( | |
| &mut self, | |
| sp_active_mask: &CudaSlice<u8>, | |
| anomaly_out: &mut CudaSlice<f32>, | |
| t_slot: u32, | |
| learn: bool, | |
| ) -> Result<(), DriverError> { | |
| let n_cells = self.n_cells; | |
| let n_cols = self.n_columns; | |
| let predict_fn = self.dev.get_func("htm_tm_predict", "tm_predict").unwrap(); | |
| let activate_fn = self.dev.get_func("htm_tm_activate", "tm_activate").unwrap(); | |
| let learn_fn = self.dev.get_func("htm_tm_learn", "tm_learn_reinforce").unwrap(); | |
| let punish_fn = self.dev.get_func("htm_tm_punish", "tm_punish").unwrap(); | |
| let grow_fn = self.dev.get_func("htm_tm_grow", "tm_grow").unwrap(); | |
| let anom_fn = self.dev.get_func("htm_tm_anomaly", "tm_anomaly").unwrap(); | |
| let reset_fn = self.dev.get_func("htm_tm_reset", "tm_reset_step").unwrap(); | |
| self.iter_counter = self.iter_counter.wrapping_add(1); | |
| let cfg_val = self.build_cfg(); | |
| // 0. Per-step reset. | |
| let reset_words = self.bits_words.max(n_cols); | |
| let reset_cfg = LaunchConfig { | |
| grid_dim: (((reset_words + 255) / 256) as u32, 1, 1), | |
| block_dim: (256, 1, 1), | |
| shared_mem_bytes: 0, | |
| }; | |
| unsafe { | |
| reset_fn.clone().launch( | |
| reset_cfg, | |
| ( | |
| &mut self.cell_active_bits, | |
| &mut self.cell_winner_bits, | |
| &mut self.cell_predictive_bits, | |
| &mut self.prev_active_bits, | |
| &mut self.prev_winner_bits, | |
| &mut self.col_predicted, | |
| &mut self.unpredicted_count, | |
| &mut self.burst_cols_count, | |
| &mut self.col_best_match, | |
| self.bits_words as u32, | |
| n_cols as u32, | |
| ), | |
| )?; | |
| } | |
| // 1. Predict (grid = n_cells; each block iterates its cell's segments). | |
| let predict_cfg = LaunchConfig { | |
| grid_dim: (n_cells as u32, 1, 1), | |
| block_dim: (32, 1, 1), | |
| shared_mem_bytes: 0, | |
| }; | |
| unsafe { | |
| predict_fn.clone().launch( | |
| predict_cfg, | |
| ( | |
| &self.seg_cell_id, | |
| &self.seg_syn_count, | |
| &self.syn_presyn, | |
| &self.syn_perm, | |
| &self.prev_active_bits, | |
| &mut self.cell_predictive_bits, | |
| &mut self.col_predicted, | |
| &mut self.seg_num_active_conn, | |
| &mut self.seg_num_active_pot, | |
| &mut self.col_best_match, | |
| &self.cell_seg_count, | |
| cfg_val, | |
| ), | |
| )?; | |
| } | |
| // 2. Activate. | |
| let activate_cfg = LaunchConfig { | |
| grid_dim: (((n_cols + 255) / 256) as u32, 1, 1), | |
| block_dim: (256, 1, 1), | |
| shared_mem_bytes: 0, | |
| }; | |
| unsafe { | |
| activate_fn.clone().launch( | |
| activate_cfg, | |
| ( | |
| sp_active_mask, | |
| &self.col_predicted, | |
| &self.cell_predictive_bits, | |
| &mut self.cell_active_bits, | |
| &mut self.cell_winner_bits, | |
| &mut self.unpredicted_count, | |
| &mut self.burst_cols_flat, | |
| &mut self.burst_cols_count, | |
| cfg_val, | |
| ), | |
| )?; | |
| } | |
| // 3. Anomaly. | |
| let anom_cfg = LaunchConfig { | |
| grid_dim: (1, 1, 1), | |
| block_dim: (256, 1, 1), | |
| shared_mem_bytes: 0, | |
| }; | |
| unsafe { | |
| anom_fn.clone().launch( | |
| anom_cfg, | |
| ( | |
| sp_active_mask, | |
| &self.unpredicted_count, | |
| anomaly_out, | |
| t_slot, | |
| n_cols as u32, | |
| ), | |
| )?; | |
| } | |
| if learn { | |
| // 4. Reinforce (grid = n_cells). | |
| let learn_cfg = LaunchConfig { | |
| grid_dim: (n_cells as u32, 1, 1), | |
| block_dim: (32, 1, 1), | |
| shared_mem_bytes: 0, | |
| }; | |
| unsafe { | |
| learn_fn.clone().launch( | |
| learn_cfg, | |
| ( | |
| &self.seg_cell_id, | |
| &self.seg_syn_count, | |
| &self.syn_presyn, | |
| &mut self.syn_perm, | |
| &self.seg_num_active_conn, | |
| &self.prev_active_bits, | |
| sp_active_mask, | |
| &self.col_predicted, | |
| &self.cell_seg_count, | |
| cfg_val, | |
| ), | |
| )?; | |
| } | |
| // 5. Punish. | |
| unsafe { | |
| punish_fn.clone().launch( | |
| learn_cfg, | |
| ( | |
| &self.seg_cell_id, | |
| &self.seg_syn_count, | |
| &self.syn_presyn, | |
| &mut self.syn_perm, | |
| &self.seg_num_active_pot, | |
| &self.prev_active_bits, | |
| sp_active_mask, | |
| &self.cell_seg_count, | |
| cfg_val, | |
| ), | |
| )?; | |
| } | |
| // 6. Grow. | |
| let grow_cfg = LaunchConfig { | |
| grid_dim: (n_cols as u32, 1, 1), | |
| block_dim: (32, 1, 1), | |
| shared_mem_bytes: 0, | |
| }; | |
| unsafe { | |
| grow_fn.clone().launch( | |
| grow_cfg, | |
| ( | |
| &mut self.seg_cell_id, | |
| &mut self.seg_syn_count, | |
| &mut self.syn_presyn, | |
| &mut self.syn_perm, | |
| &mut self.cell_seg_count, | |
| &self.burst_cols_flat, | |
| &self.burst_cols_count, | |
| &self.prev_winner_bits, | |
| &self.prev_active_bits, | |
| &self.col_best_match, | |
| cfg_val, | |
| ), | |
| )?; | |
| } | |
| } | |
| Ok(()) | |
| } | |
| } | |