| """ |
| Mamba-2 SSD — OPTIMIZED: intra-chunk parallelism via matrix multiply. |
| |
| The key Mamba-2 insight (State Space Duality): |
| Within each chunk of size T, the SSM can be computed as a MATRIX MULTIPLY: |
| |
| Y_chunk = (L ⊙ (C B^T)) @ (Δ ⊙ X) |
| |
| Where L is a lower-triangular mask with cumulative A products. |
| This replaces the T sequential steps with a single matmul of size T×T. |
| |
| For L=256, T=16, num_chunks=16: |
| - Within chunk: parallel matmul (T×T = 16×16) |
| - Across chunks: 16 sequential state carries (unavoidable, but trivial) |
| |
| Total: 16 sequential state carries + 16 parallel matmuls = FAST. |
| |
| NO in-place ops. Fully autograd safe. Works on CPU and GPU. |
| """ |

import torch
import torch.nn as nn
import torch.nn.functional as F
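

# --- Illustration only --------------------------------------------------------
# A minimal, self-contained sketch of the identity described in the module
# docstring: for a single state dimension, the T-step recurrence
#     h_t = a_t * h_{t-1} + b_t * u_t,   y_t = c_t * h_t
# equals one masked T×T matmul built from cumulative products of a.
# Names and shapes here are illustrative and unused by the classes below.
def _ssd_identity_demo(T=8):
    a = torch.rand(T) * 0.8 + 0.1          # per-step decay factors in (0.1, 0.9)
    b, u, c = torch.randn(T), torch.randn(T), torch.randn(T)

    # Sequential recurrence: T dependent steps.
    h, y_seq = torch.zeros(()), []
    for t in range(T):
        h = a[t] * h + b[t] * u[t]
        y_seq.append(c[t] * h)
    y_seq = torch.stack(y_seq)

    # Same result as one matmul: L[t, s] = prod_{k=s+1..t} a_k for s <= t, else 0.
    log_a_cumsum = torch.cumsum(torch.log(a), dim=0)
    L_mat = torch.tril(torch.exp(log_a_cumsum.unsqueeze(1) - log_a_cumsum.unsqueeze(0)))
    y_par = c * (L_mat @ (b * u))

    return (y_seq - y_par).abs().max()      # ~0 up to float error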


class Mamba2SSD(nn.Module):
    """
    Mamba-2 SSD with intra-chunk matrix-multiply parallelism.

    Args:
        dim: Input/output dimension
        d_state: SSM state dimension (default 16)
        d_conv: Conv1d kernel size (default 4)
        expand: Inner dimension expansion (default 2)
        chunk_size: Chunk size for scan (default 64 — larger = more parallel)
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, chunk_size=64):
        super().__init__()
        self.dim = dim
        self.d_state = d_state
        self.chunk_size = chunk_size
        self.inner_dim = dim * expand

        # Input projection produces the SSM stream and the gate in one matmul.
        self.in_proj = nn.Linear(dim, self.inner_dim * 2, bias=False)

        # Depthwise causal conv (the extra padding is trimmed back to length L in forward).
        self.conv1d = nn.Conv1d(
            self.inner_dim, self.inner_dim,
            kernel_size=d_conv, padding=d_conv - 1,
            groups=self.inner_dim, bias=True
        )

        # Input-dependent SSM parameters: step size Δ and the B/C projections.
        self.dt_proj = nn.Linear(self.inner_dim, self.inner_dim, bias=True)
        self.B_proj = nn.Linear(self.inner_dim, d_state, bias=False)
        self.C_proj = nn.Linear(self.inner_dim, d_state, bias=False)

        # A is stored in log space; the effective A = -exp(A_log) is initialized
        # to -(1, 2, ..., d_state).
        A = torch.arange(1, d_state + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A))

        # Learnable skip ("D") connection.
        self.D = nn.Parameter(torch.ones(self.inner_dim))

        # Output normalization, gating and projection back to `dim`.
        self.norm = nn.LayerNorm(self.inner_dim)
        self.out_proj = nn.Linear(self.inner_dim, dim, bias=False)

        self._init_weights()

    def _init_weights(self):
        # dt bias of -4 keeps softplus(dt) ≈ 0.018 at init (slow, stable dynamics).
        nn.init.constant_(self.dt_proj.bias, -4.0)
        nn.init.xavier_uniform_(self.in_proj.weight, gain=0.1)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.1)

    def forward(self, x):
        """x: [B, L, dim] → [B, L, dim]"""
        return self._process(x)

    def _process(self, x):
        B, L, D = x.shape

        # Project to 2 * inner_dim and split into the SSM stream and the gate.
        xz = self.in_proj(x)
        x_inner, z = xz.chunk(2, dim=-1)

        # Depthwise causal conv; trim the right-side padding back to length L.
        x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :L].transpose(1, 2)
        x_conv = F.silu(x_conv)

        # Input-dependent SSM parameters (Δ > 0 via softplus) and A = -exp(A_log) < 0.
        dt = F.softplus(self.dt_proj(x_conv))
        B_mat = self.B_proj(x_conv)
        C_mat = self.C_proj(x_conv)
        A = -torch.exp(self.A_log)

        # Chunk-parallel scan.
        y = self._chunk_ssm(x_conv, dt, A, B_mat, C_mat)

        # Skip connection through D, then gated output normalization.
        y = y + x_conv * self.D.unsqueeze(0).unsqueeze(0)
        y = self.norm(y) * F.silu(z)

        return self.out_proj(y)

    def _chunk_ssm(self, u, dt, A, B, C):
        """
        Chunk-parallel SSM computation.

        Within each chunk: compute via a cumulative decay matrix (parallel).
        Across chunks: propagate the final state (sequential, only num_chunks steps).

        The intra-chunk computation uses the identity:
            h_t = sum_{s=0}^{t} (prod_{k=s+1}^{t} dA_k) * dB_s * u_s

        This is a lower-triangular matrix-vector product, computable in parallel.
        """
        batch, L, d_inner = u.shape
        d_state = A.shape[0]
        T = min(self.chunk_size, L)

        # Right-pad the sequence so its length is a multiple of the chunk size T.
        pad = (T - L % T) % T
        if pad > 0:
            u = F.pad(u, (0, 0, 0, pad))
            dt = F.pad(dt, (0, 0, 0, pad))
            B = F.pad(B, (0, 0, 0, pad))
            C = F.pad(C, (0, 0, 0, pad))

        L_pad = u.shape[1]
        n_chunks = L_pad // T

        # Fold the sequence into chunks: [batch, n_chunks, T, ...].
        u_c = u.reshape(batch, n_chunks, T, d_inner)
        dt_c = dt.reshape(batch, n_chunks, T, d_inner)
        B_c = B.reshape(batch, n_chunks, T, d_state)
        C_c = C.reshape(batch, n_chunks, T, d_state)

        # Simplification: one shared dt per timestep (channel mean) drives the
        # state decay; the input term below still uses the per-channel dt.
        dt_mean = dt_c.mean(dim=-1)                              # [B, n_chunks, T]

        # log dA_t = dt_t * A (A < 0, so this is <= 0): [B, n_chunks, T, d_state]
        log_dA = dt_mean.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0).unsqueeze(0)

        # Cumulative log-decay within each chunk.
        log_dA_cumsum = torch.cumsum(log_dA, dim=2)

        # decay_matrix[t, s] = sum_{k=s+1..t} log dA_k (for s <= t).
        decay_matrix = log_dA_cumsum.unsqueeze(3) - log_dA_cumsum.unsqueeze(2)

        # Zero the upper triangle BEFORE exp (there the exponent is positive and
        # could overflow), then mask again after exp to enforce causality.
        causal_mask = torch.tril(torch.ones(T, T, device=u.device))
        decay_matrix = decay_matrix * causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
        decay_matrix = torch.exp(decay_matrix) * causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1)

        # Discretized input dB_s * u_s = (dt_s ⊗ B_s) * u_s:
        # [B, n_chunks, T, d_state, d_inner]
        dBu = dt_c.unsqueeze(-2) * B_c.unsqueeze(-1) * u_c.unsqueeze(-2)

        # Intra-chunk states (history restricted to the current chunk):
        # h_intra[t] = sum_{s<=t} decay[t, s] * dBu[s] → [B, n_chunks, T, d_state, d_inner]
        h_intra = torch.einsum('bctsn,bcsnd->bctnd', decay_matrix, dBu)

        # Decay from the chunk start to each position; used to propagate the
        # state carried in from previous chunks.
        position_decay = torch.exp(log_dA_cumsum)

        # Sequential carry across chunks (only n_chunks steps).
        h_carry = torch.zeros(batch, d_state, d_inner, device=u.device, dtype=u.dtype)
        h_chunks = []

        for c_idx in range(n_chunks):
            # State carried in from previous chunks, decayed to every position
            # of this chunk: [B, T, d_state, d_inner]
            h_from_prev = position_decay[:, c_idx, :, :].unsqueeze(-1) * h_carry.unsqueeze(1)

            # Full state = intra-chunk contribution + decayed carry.
            h_total = h_intra[:, c_idx] + h_from_prev
            h_chunks.append(h_total)

            # The carry for the next chunk is the state at this chunk's last position.
            h_carry = h_total[:, -1, :, :]

        h_all = torch.stack(h_chunks, dim=1)                     # [B, n_chunks, T, d_state, d_inner]

        # Readout: y_t = C_t · h_t over the state dimension.
        y = torch.einsum('bctn,bctnd->bctd', C_c, h_all)

        # Merge chunks back into a sequence and drop the padding.
        y = y.reshape(batch, L_pad, d_inner)
        return y[:, :L, :]
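

# --- Testing aid (not used by the model) ---------------------------------------
# A step-by-step reference recurrence for Mamba2SSD._chunk_ssm, written for
# clarity rather than speed. It mirrors the same simplification used above
# (the state decay uses the channel-averaged dt), so its output should match
# _chunk_ssm up to floating-point error. The function name is ours, not part
# of any published API.
def naive_sequential_ssm(u, dt, A, B, C):
    """u, dt: [batch, L, d_inner], A: [d_state], B, C: [batch, L, d_state]
    → y: [batch, L, d_inner]."""
    batch, L, d_inner = u.shape
    d_state = A.shape[0]
    # Per-step decay dA_t = exp(mean_d(dt_t) * A): [batch, L, d_state]
    dA = torch.exp(dt.mean(dim=-1, keepdim=True) * A)
    h = torch.zeros(batch, d_state, d_inner, device=u.device, dtype=u.dtype)
    ys = []
    for t in range(L):
        # dB_t * u_t with per-channel dt: [batch, d_state, d_inner]
        dBu = B[:, t].unsqueeze(-1) * (dt[:, t] * u[:, t]).unsqueeze(1)
        h = dA[:, t].unsqueeze(-1) * h + dBu
        ys.append(torch.einsum('bn,bnd->bd', C[:, t], h))
    return torch.stack(ys, dim=1)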


class Mamba2Block(nn.Module):
    """
    Mamba-2 block with bidirectional scanning for 2D images.
    Forward + backward raster scan, merged via a learned projection.
    """

    def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Two independent SSD branches: one scans the sequence forward, one backward.
        self.ssd_fwd = Mamba2SSD(dim, d_state, d_conv, expand)
        self.ssd_bwd = Mamba2SSD(dim, d_state, d_conv, expand)
        self.merge = nn.Linear(dim * 2, dim, bias=False)

        ff_dim = dim * expand
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, dim), nn.Dropout(dropout),
        )

    def forward(self, x):
        """x: [B, C, H, W] or [B, L, C]"""
        is_2d = x.dim() == 4
        if is_2d:
            B, C, H, W = x.shape
            x = x.flatten(2).transpose(1, 2)        # [B, H*W, C], raster order

        residual = x
        x_norm = self.norm1(x)

        # Bidirectional scan: run the backward branch on the reversed sequence,
        # then flip its output back before merging with the forward branch.
        fwd = self.ssd_fwd(x_norm)
        bwd = torch.flip(self.ssd_bwd(torch.flip(x_norm, [1])), [1])
        merged = self.merge(torch.cat([fwd, bwd], dim=-1))

        x = residual + merged
        x = x + self.ff(self.norm2(x))

        if is_2d:
            x = x.transpose(1, 2).reshape(B, C, H, W)
        return x
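

# A minimal smoke test / consistency check, assuming this file is run directly.
# Shapes and sizes below are arbitrary; the second check compares the
# chunk-parallel scan against the naive sequential reference defined above.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Mamba2Block on a dummy [B, C, H, W] feature map.
    block = Mamba2Block(dim=32)
    img = torch.randn(2, 32, 8, 8)
    out = block(img)
    print("Mamba2Block:", tuple(img.shape), "->", tuple(out.shape))

    # Chunked scan vs. sequential recurrence on random SSM inputs
    # (sequence length deliberately not a multiple of chunk_size).
    ssd = Mamba2SSD(dim=32, chunk_size=16)
    u = torch.randn(2, 50, ssd.inner_dim)
    dt = F.softplus(torch.randn(2, 50, ssd.inner_dim))
    B = torch.randn(2, 50, ssd.d_state)
    C = torch.randn(2, 50, ssd.d_state)
    A = -torch.exp(ssd.A_log)
    y_chunk = ssd._chunk_ssm(u, dt, A, B, C)
    y_naive = naive_sequential_ssm(u, dt, A, B, C)
    print("SSD identity demo, max error:", _ssd_identity_demo().item())
    print("chunked vs naive, max error:", (y_chunk - y_naive).abs().max().item())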