| """ |
| SNNBlock: 完整的 SNN 隐状态空间 Block(并行化版本) |
| |
| 结构(每个 SNN 时间步): |
| spike_in {0,1}^D |
| ├─→ W_in → I[t] ∈ R^{D*N} |
| ├─→ W_β^(x) + b_β → σ → β(t) |
| ├─→ W_α^(x) + b_α → softplus → α(t) |
| ├─→ W_th^(x) + b_th → |·|+V_min → V_th(t) |
| ├─→ W_gate → sigmoid → gate ∈ (0,1)^D |
| └─→ W_skip → I_skip ∈ R^D |
| |
| SelectivePLIF(I, β, α, V_th) → s[t] ∈ {0,1}^{D*N} |
| |
| W_out · V_post[t] ⊙ gate + I_skip → 连续输出 ∈ R^D |
| |
| 数学原理见 SNN_SELECTIVE_STATE_SPACE.md。 |
| """ |
|
|
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import base, layer, surrogate
|
|
from .selective_plif import SelectivePLIFNode
from .parallel_scan import plif_parallel_forward


@torch.compile(backend='inductor', fullgraph=True)
def _fused_modulation(raw_beta, b_beta, raw_alpha, b_alpha, raw_th, b_th, v_th_min, I_all):
    """Elementwise modulation math, fused by torch.compile into compiled kernels."""
    beta = torch.sigmoid(raw_beta + b_beta)     # decay β(t) ∈ (0, 1)
    alpha = F.softplus(raw_alpha + b_alpha)     # input gain α(t) > 0
    v_th = v_th_min + torch.abs(raw_th + b_th)  # threshold V_th(t) ≥ v_th_min
    u = alpha * I_all                           # effective drive u(t) = α(t) · I(t)
    return beta, u, v_th
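
# For orientation: the per-step recurrence that `plif_parallel_forward`
# parallelises over time is, in the docstring's notation, the following
# (a sketch assuming a soft reset; parallel_scan.py and
# SNN_SELECTIVE_STATE_SPACE.md are authoritative):
#
#   V_pre[t]  = β(t) · V_post[t-1] + u(t),  with u(t) = α(t) · I(t)
#   s[t]      = H(V_pre[t] - V_th(t))       # Heaviside; surrogate grad in backward
#   V_post[t] = V_pre[t] - s[t] · V_th(t)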
|
|
|
|
class SNNBlock(base.MemoryModule):
    """
    A single SNN block (parallelised).

    Args:
        D: visible width (the dimension blocks use to communicate)
        N: state-expansion factor (hidden neurons per channel)
        v_th_min: lower bound of the dynamic threshold
        surrogate_function: surrogate gradient function
    """
|
|
    def __init__(
        self,
        D: int,
        N: int = 8,
        v_th_min: float = 0.1,
        surrogate_function=surrogate.Sigmoid(alpha=4.0),
    ):
        super().__init__()
        self.D = D
        self.N = N
        self.v_th_min = v_th_min
        DN = D * N  # total number of hidden neurons

        # Input-conditioned projections: as in a selective state-space model,
        # the dynamics parameters β, α and V_th are functions of the input.
        self.W_in = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_beta_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_alpha_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_th_x = layer.Linear(D, DN, bias=False, step_mode='s')
        self.W_gate = layer.Linear(D, D, bias=False, step_mode='s')
        self.W_skip = layer.Linear(D, D, bias=False, step_mode='s')

        # Per-neuron biases of the three modulation branches.
        self.b_beta = nn.Parameter(torch.empty(DN))
        self.b_alpha = nn.Parameter(torch.empty(DN))
        self.b_th = nn.Parameter(torch.empty(DN))

        # Readout from the D*N hidden membrane potentials back to width D.
        self.W_out = layer.Linear(DN, D, bias=False, step_mode='s')

        # Hidden selective PLIF neurons.
        self.hidden_neuron = SelectivePLIFNode(
            surrogate_function=surrogate_function,
            detach_reset=False,
        )

        self._initialize_parameters()
|
|
    def _initialize_parameters(self):
        """Function-guided initialization."""
        D, N = self.D, self.N
        K_ref = 16  # reference horizon (in steps) for the variance estimates below

        # Spread the initial decay factors β over the N sub-neurons of each channel.
        beta_values = torch.linspace(0.80, 0.99, N)

        # Invert the sigmoid so that sigmoid(b_beta) reproduces beta_values,
        # then add a little noise to break symmetry across channels.
        b_beta_per_n = torch.log(beta_values / (1.0 - beta_values))

        self.b_beta.data.copy_(b_beta_per_n.repeat(D))
        self.b_beta.data.add_(torch.empty_like(self.b_beta).normal_(0, 0.1))
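
        # The result: each channel owns a bank of timescales, from β = 0.80
        # (time constant τ = -1/ln β ≈ 4.5 steps) up to β = 0.99 (τ ≈ 100 steps).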
|
|
        # softplus(0.5413) = softplus(ln(e - 1)) = 1, so the input gain α
        # starts neutral at ~1.
        self.b_alpha.data.normal_(0.5413, 0.1)

        # Default init for the main paths; modulation weights are shrunk 10x
        # so input-dependent modulation starts weak.
        for lin in [self.W_in, self.W_gate, self.W_skip, self.W_out]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
        for lin in [self.W_beta_x, self.W_alpha_x, self.W_th_x]:
            nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5))
            lin.weight.data.mul_(0.1)
|
|
        # Equalise the stationary membrane variance across decay channels by
        # pre-scaling each W_in row by sqrt(1 - β²).
        scale_per_n = torch.sqrt(1.0 - beta_values ** 2)
        scale_DN = scale_per_n.repeat(D)
        with torch.no_grad():
            self.W_in.weight.mul_(scale_DN.unsqueeze(1))
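
        # Why sqrt(1 - β²): for V[t] = β·V[t-1] + I[t] with i.i.d. zero-mean
        # inputs, the stationary variance is Var(V) = Var(I) / (1 - β²), so
        # this scaling makes Var(V) ≈ Var(I_raw) regardless of β.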
|
|
        # Calibrate per-neuron thresholds toward target firing rates. Assume
        # input spikes fire with probability ~0.15; with kaiming_uniform
        # (a = sqrt(5)) the weight bound is 1/sqrt(fan_in), so the raw drive
        # current has Var(I) ≈ fan_in · p · 1/(3 · fan_in) = p / 3. After the
        # sqrt(1 - β²) row scaling above, the membrane accumulates
        # Var(V_K) = σ_I² · (1 - β^(2·K_ref)) over K_ref steps.
        p_assumed = 0.15
        sigma_I_base = math.sqrt(p_assumed / 3.0)
        sigma_V_per_n = sigma_I_base * torch.sqrt(
            1.0 - beta_values ** (2 * K_ref)
        )
        # Fast channels should fire often (p = 0.25), slow channels sparsely
        # (p = 0.08). Invert the Gaussian CDF: P(V > z·σ_V) = p for
        # z = sqrt(2) · erfinv(1 - 2p).
        target_p_fire = torch.linspace(0.25, 0.08, N)
        z_scores = math.sqrt(2.0) * torch.erfinv(
            2.0 * (1.0 - target_p_fire) - 1.0
        )
        target_V_th = sigma_V_per_n * z_scores
        b_th_per_n = torch.clamp(target_V_th - self.v_th_min, min=0.05)

        self.b_th.data.copy_(b_th_per_n.repeat(D))
        self.b_th.data.add_(torch.empty_like(self.b_th).normal_(0, 0.02))
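
        # Worked example for the fastest channel (β = 0.80, p_fire = 0.25,
        # default v_th_min = 0.1): σ_I = sqrt(0.15/3) ≈ 0.224; β^32 ≈ 0.0008,
        # so σ_V ≈ 0.224; z(0.25) ≈ 0.674; V_th ≈ 0.151, hence b_th ≈ 0.051.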
|
|
        # Rescale W_out columns by 1/sqrt(target firing rate), normalised to
        # mean 1 so the overall readout scale is unchanged.
        out_scale_per_n = 1.0 / torch.sqrt(target_p_fire)
        out_scale_per_n = out_scale_per_n / out_scale_per_n.mean()
        out_scale_DN = out_scale_per_n.repeat(D)
        with torch.no_grad():
            self.W_out.weight.mul_(out_scale_DN.unsqueeze(0))
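
        # The 1/sqrt(p) up-weighting presumably lets sparser-firing channels
        # contribute comparable variance to the readout. Note the readout taps
        # V_post rather than spikes, so this is a heuristic calibration.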
|
|
| def forward_parallel(self, spike_in_seq: torch.Tensor) -> torch.Tensor: |
| """ |
| 并行前向传播:使用 parallel scan 处理全序列。 |
| |
| Args: |
| spike_in_seq: (TK, batch, D) — 全部 T×K 帧的输入 spike |
| |
| Returns: |
| continuous_out: (TK, batch, D) — 全部 T×K 帧的连续输出(V_post 经 W_out 投影) |
| """ |
| TK, batch, D = spike_in_seq.shape |
| DN = self.D * self.N |
|
|
| |
| flat = spike_in_seq.reshape(TK * batch, D) |
|
|
        I_all = F.linear(flat, self.W_in.weight).reshape(TK, batch, DN)
        raw_beta = F.linear(flat, self.W_beta_x.weight).reshape(TK, batch, DN)
        raw_alpha = F.linear(flat, self.W_alpha_x.weight).reshape(TK, batch, DN)
        raw_th = F.linear(flat, self.W_th_x.weight).reshape(TK, batch, DN)
        gate_all = torch.sigmoid(
            F.linear(flat, self.W_gate.weight).reshape(TK, batch, D)
        )
        I_skip_all = F.linear(flat, self.W_skip.weight).reshape(TK, batch, D)

        # Elementwise modulation math, fused by torch.compile.
        beta_all, u_hidden, v_th_all = _fused_modulation(
            raw_beta, self.b_beta, raw_alpha, self.b_alpha,
            raw_th, self.b_th, self.v_th_min, I_all,
        )
|
|
        # Materialise the initial membrane state (spikingjelly keeps it as a
        # float scalar until the first step).
        v_init_hidden = self.hidden_neuron.v
        if isinstance(v_init_hidden, float):
            v_init_hidden = torch.zeros(batch, DN, device=flat.device, dtype=flat.dtype)

        # Parallel scan over all TK steps. The spike train is returned too,
        # but the readout below only uses the post-reset membrane potentials.
        s_hidden, V_post_hidden, _ = plif_parallel_forward(
            beta_all, u_hidden, v_th_all, v_init_hidden, max_iter=3,
            surrogate_function=self.hidden_neuron.surrogate_function,
        )

        # Carry the final state across calls, detached from the graph.
        self.hidden_neuron.v = V_post_hidden[-1].detach()

        # Readout: (W_out · V_post) ⊙ gate + I_skip.
        v_flat = V_post_hidden.reshape(TK * batch, DN)
        I_out_all = F.linear(v_flat, self.W_out.weight).reshape(TK, batch, D)
        I_total_all = I_out_all * gate_all + I_skip_all

        return I_total_all
|
|
    def single_step_forward(self, spike_in: torch.Tensor) -> torch.Tensor:
        """
        Single-step forward pass (for debugging / compatibility).

        Args:
            spike_in: binary spike input, shape (batch, D), values in {0, 1}

        Returns:
            continuous_out: continuous output, shape (batch, D)
        """
        # Materialise the float initial state as a zero tensor of the right
        # shape and write it back so the neuron starts from a proper tensor.
        V_prev = self.hidden_neuron.v
        if isinstance(V_prev, float):
            V_prev = torch.zeros(
                spike_in.shape[0], self.D * self.N,
                device=spike_in.device, dtype=spike_in.dtype,
            )
            self.hidden_neuron.v = V_prev
|
|
        I_t = self.W_in(spike_in)

        # Input-conditioned dynamics parameters (same math as _fused_modulation).
        beta = torch.sigmoid(self.W_beta_x(spike_in) + self.b_beta)
        alpha = F.softplus(self.W_alpha_x(spike_in) + self.b_alpha)
        v_th = self.v_th_min + torch.abs(self.W_th_x(spike_in) + self.b_th)

        gate = torch.sigmoid(self.W_gate(spike_in))
        I_skip = self.W_skip(spike_in)

        # One step of the selective PLIF neuron; the spikes update the state,
        # but the readout below taps the post-reset membrane potential.
        s_hidden = self.hidden_neuron(I_t, beta, alpha, v_th)

        V_post = self.hidden_neuron.v
        I_out = self.W_out(V_post)
        I_total = I_out * gate + I_skip

        return I_total
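

if __name__ == "__main__":
    # Minimal smoke test (a sketch: D, N, sequence length and the 15% input
    # firing rate are illustrative values, and because of the relative imports
    # this must be run as a module, e.g. `python -m <package>.snn_block`).
    from spikingjelly.activation_based import functional

    torch.manual_seed(0)
    block = SNNBlock(D=32, N=8)
    spikes = (torch.rand(64, 4, 32) < 0.15).float()  # (TK, batch, D) in {0, 1}
    out = block.forward_parallel(spikes)
    print(out.shape)  # expected: torch.Size([64, 4, 32])

    # The single-step path should yield the same shape frame by frame.
    functional.reset_net(block)
    out_step = torch.stack([block.single_step_forward(s) for s in spikes])
    print(out_step.shape)  # expected: torch.Size([64, 4, 32])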
|
|