# liquid_flow/mamba2_ssd.py
"""
Mamba-2 SSD — OPTIMIZED: intra-chunk parallelism via matrix multiply.
The key Mamba-2 insight (State Space Duality):
Within each chunk of size T, the SSM can be computed as a MATRIX MULTIPLY:
Y_chunk = (L ⊙ (C B^T)) @ (Δ ⊙ X)
Where L is a lower-triangular mask with cumulative A products.
This replaces the T sequential steps with a single matmul of size T×T.
For L=256, T=16, num_chunks=16:
- Within chunk: parallel matmul (T×T = 16×16)
- Across chunks: 16 sequential state carries (unavoidable, but trivial)
Total: 16 sequential state carries + 16 parallel matmuls = FAST.
NO in-place ops. Fully autograd safe. Works on CPU and GPU.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class Mamba2SSD(nn.Module):
    """
    Mamba-2 SSD layer with intra-chunk matrix-multiply parallelism.

    Within each chunk of size T, the SSM recurrence is evaluated as a single
    masked T×T matmul (the "state space duality" form); only the final state
    of each chunk is carried sequentially across chunks, so a length-L scan
    costs L/T sequential steps instead of L.

    Args:
        dim: Input/output dimension.
        d_state: SSM state dimension (default 16).
        d_conv: Depthwise causal Conv1d kernel size (default 4).
        expand: Inner dimension expansion factor (default 2).
        chunk_size: Chunk length T for the scan (default 64).
    """
    def __init__(self, dim, d_state=16, d_conv=4, expand=2, chunk_size=64):
        super().__init__()
        self.dim = dim
        self.d_state = d_state
        self.chunk_size = chunk_size
        self.inner_dim = dim * expand
        # Joint projection producing the SSM input x and the gate z.
        self.in_proj = nn.Linear(dim, self.inner_dim * 2, bias=False)
        # Depthwise causal conv for short-range local context; left padding
        # of d_conv - 1 keeps it causal after trimming the right overhang.
        self.conv1d = nn.Conv1d(
            self.inner_dim, self.inner_dim,
            kernel_size=d_conv, padding=d_conv - 1,
            groups=self.inner_dim, bias=True,
        )
        # Input-dependent SSM parameters (dt, B, C).
        self.dt_proj = nn.Linear(self.inner_dim, self.inner_dim, bias=True)
        self.B_proj = nn.Linear(self.inner_dim, d_state, bias=False)
        self.C_proj = nn.Linear(self.inner_dim, d_state, bias=False)
        # A is stored in log space; -exp(A_log) guarantees negative (stable)
        # decay rates regardless of the learned parameter value.
        A = torch.arange(1, d_state + 1, dtype=torch.float32)
        self.A_log = nn.Parameter(torch.log(A))
        # Per-channel residual skip weight.
        self.D = nn.Parameter(torch.ones(self.inner_dim))
        self.norm = nn.LayerNorm(self.inner_dim)
        self.out_proj = nn.Linear(self.inner_dim, dim, bias=False)
        self._init_weights()

    def _init_weights(self):
        # Small initial dt: softplus(-4) ≈ 0.018 → near-identity state decay
        # at the start of training. Small projection gains keep early outputs
        # close to the residual path.
        nn.init.constant_(self.dt_proj.bias, -4.0)
        nn.init.xavier_uniform_(self.in_proj.weight, gain=0.1)
        nn.init.xavier_uniform_(self.out_proj.weight, gain=0.1)

    def forward(self, x):
        """x: [B, L, dim] → [B, L, dim]"""
        return self._process(x)

    def _process(self, x):
        B, L, D = x.shape
        # Project input into SSM branch x_inner and gate branch z.
        xz = self.in_proj(x)
        x_inner, z = xz.chunk(2, dim=-1)
        # Causal conv: trim the right overhang introduced by the padding.
        x_conv = self.conv1d(x_inner.transpose(1, 2))[:, :, :L].transpose(1, 2)
        x_conv = F.silu(x_conv)
        # Input-dependent SSM parameters.
        dt = F.softplus(self.dt_proj(x_conv))  # [B, L, inner_dim], positive
        B_mat = self.B_proj(x_conv)            # [B, L, d_state]
        C_mat = self.C_proj(x_conv)            # [B, L, d_state]
        A = -torch.exp(self.A_log)             # [d_state], negative
        # Chunk-parallel scan.
        y = self._chunk_ssm(x_conv, dt, A, B_mat, C_mat)
        # Skip connection + norm + SiLU gating.
        y = y + x_conv * self.D.unsqueeze(0).unsqueeze(0)
        y = self.norm(y) * F.silu(z)
        return self.out_proj(y)

    def _chunk_ssm(self, u, dt, A, B, C):
        """
        Chunk-parallel SSM scan.

        Intra-chunk (parallel): h_t = sum_{s<=t} exp(sum_{k=s+1..t} dt_k A) dB_s u_s
        is a lower-triangular matmul of size T×T. Inter-chunk (sequential):
        only the final state of each chunk is decayed into the next chunk,
        i.e. n_chunks sequential steps.

        Args:
            u:  [B, L, d_inner] conv/SiLU features.
            dt: [B, L, d_inner] positive step sizes.
            A:  [d_state] negative decay rates.
            B:  [B, L, d_state] input projection.
            C:  [B, L, d_state] output projection.

        Returns:
            y: [B, L, d_inner]
        """
        batch, L, d_inner = u.shape
        d_state = A.shape[0]
        T = min(self.chunk_size, L)
        # Right-pad the sequence to a multiple of T. Padded positions carry
        # dt = 0 → zero state input and unit decay, so they cannot perturb
        # the real outputs, which are trimmed back at the end.
        pad = (T - L % T) % T
        if pad > 0:
            u = F.pad(u, (0, 0, 0, pad))
            dt = F.pad(dt, (0, 0, 0, pad))
            B = F.pad(B, (0, 0, 0, pad))
            C = F.pad(C, (0, 0, 0, pad))
        L_pad = u.shape[1]
        n_chunks = L_pad // T
        # Reshape into [B, n_chunks, T, ...].
        u_c = u.reshape(batch, n_chunks, T, d_inner)
        dt_c = dt.reshape(batch, n_chunks, T, d_inner)
        B_c = B.reshape(batch, n_chunks, T, d_state)
        C_c = C.reshape(batch, n_chunks, T, d_state)
        # Scalar-A simplification: use the channel-mean dt for the decay.
        dt_mean = dt_c.mean(dim=-1)  # [B, nc, T]
        # log(dA) per position; A < 0 so every entry is <= 0.
        log_dA = dt_mean.unsqueeze(-1) * A.view(1, 1, 1, -1)  # [B, nc, T, d_state]
        log_dA_cumsum = torch.cumsum(log_dA, dim=2)            # [B, nc, T, d_state]
        # decay[t, s] = exp(sum_{k=s+1..t} log_dA_k) = exp(cumsum[t] - cumsum[s])
        # for t >= s (s = t gives weight 1); zero above the diagonal.
        delta = log_dA_cumsum.unsqueeze(3) - log_dA_cumsum.unsqueeze(2)  # [B, nc, T, T, d_state]
        causal = torch.tril(torch.ones(T, T, device=u.device, dtype=u.dtype))
        causal = causal.view(1, 1, T, T, 1)
        # Mask BEFORE exp: t < s entries are positive and would overflow.
        decay_matrix = torch.exp(delta * causal) * causal  # [B, nc, T, T, d_state]
        # State inputs dB_s u_s: [B, nc, T, d_state, d_inner].
        dBu = dt_c.unsqueeze(-2) * B_c.unsqueeze(-1) * u_c.unsqueeze(-2)
        # Intra-chunk states via one masked matmul per chunk:
        # h[b,c,t,n,d] = sum_s decay[b,c,t,s,n] * dBu[b,c,s,n,d]
        h_intra = torch.einsum('bctsn,bcsnd->bctnd', decay_matrix, dBu)
        # Decay from the chunk boundary to each in-chunk position t is
        # exp(cumsum[t]) (product of dA over positions 0..t of the chunk).
        position_decay = torch.exp(log_dA_cumsum)  # [B, nc, T, d_state]
        # Sequential inter-chunk carry (only n_chunks steps). Match u's dtype
        # so mixed-precision inputs do not hit a dtype-mismatch add.
        h_carry = torch.zeros(batch, d_state, d_inner, device=u.device, dtype=u.dtype)
        h_chunks = []
        for c_idx in range(n_chunks):
            # Previous chunk's final state, decayed to every position here.
            h_from_prev = position_decay[:, c_idx].unsqueeze(-1) * h_carry.unsqueeze(1)
            h_total = h_intra[:, c_idx] + h_from_prev  # [B, T, d_state, d_inner]
            h_chunks.append(h_total)
            h_carry = h_total[:, -1]  # final state of this chunk
        h_all = torch.stack(h_chunks, dim=1)  # [B, nc, T, d_state, d_inner]
        # Readout: y_t = C_t · h_t.
        y = torch.einsum('bctn,bctnd->bctd', C_c, h_all)  # [B, nc, T, d_inner]
        # Flatten chunks back to a sequence and drop padding.
        return y.reshape(batch, L_pad, d_inner)[:, :L, :]
class Mamba2Block(nn.Module):
    """
    Bidirectional Mamba-2 block for 2D feature maps (or plain sequences).

    Runs a forward and a backward raster-order scan over the flattened
    sequence, merges the two directions with a learned linear projection,
    then applies a pre-norm feed-forward sublayer. Both sublayers are
    residual.
    """
    def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ssd_fwd = Mamba2SSD(dim, d_state, d_conv, expand)
        self.ssd_bwd = Mamba2SSD(dim, d_state, d_conv, expand)
        self.merge = nn.Linear(dim * 2, dim, bias=False)
        hidden = dim * expand
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """x: [B, C, H, W] or [B, L, C]"""
        spatial = x.dim() == 4
        if spatial:
            b, c, h, w = x.shape
            # Flatten to a raster-order sequence: [B, H*W, C].
            x = x.flatten(2).transpose(1, 2)
        normed = self.norm1(x)
        scan_f = self.ssd_fwd(normed)
        # Backward direction: reverse, scan, reverse back.
        scan_b = torch.flip(self.ssd_bwd(torch.flip(normed, [1])), [1])
        x = x + self.merge(torch.cat([scan_f, scan_b], dim=-1))
        x = x + self.ff(self.norm2(x))
        if spatial:
            x = x.transpose(1, 2).reshape(b, c, h, w)
        return x