| """ |
| CfC Cell — Closed-form Continuous-time neural network cell. |
| |
| From: "Closed-form Continuous-time Neural Networks" (Hasani et al., 2022) |
| |
| The CfC model provides an approximate closed-form solution to Liquid Time-Constant (LTC) |
| network dynamics without needing ODE solvers. |
| |
| Architecture: |
| x(t) = σ(-f(x,I;θ_f) · t) ⊙ g(x,I;θ_g) + (1 - σ(-f(x,I;θ_f) · t)) ⊙ h(x,I;θ_h) |
| |
| Where: |
| - f, g, h are neural network heads sharing a backbone |
| - σ is the sigmoid (replacing exponential decay for gradient stability) |
| - t is a time parameter |
| - The sigmoidal terms act as time-continuous gates between g and h |
| |
| Key properties: |
| - No ODE solving → 100x+ faster than Neural ODEs |
| - Time-continuous gating mechanism → adaptive computation |
| - Closed-form → stable gradients, easy to train |
| - Naturally causal → good for sequential processing |
| |
| For 2D image inputs: we treat the spatial sequence as "time" steps for the CfC, |
| allowing the liquid dynamics to model spatial dependencies with adaptive gates. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class CfCCell(nn.Module):
    """
    Single CfC cell with a shared backbone and three heads (f, g, h).

    Implements the closed-form continuous-time update

        h_new = sigma(-f * t) * g + (1 - sigma(-f * t)) * h + skip

    where f, g, h are computed from a shared backbone over [x, h_prev]
    and the sigmoid term acts as a time-continuous gate.

    Args:
        dim: Hidden dimension.
        backbone_dropout: Dropout probability in the backbone layers.
        time_scale: Range (a, b) for the time parameter. During training,
            t is drawn uniformly from [a, b]; during evaluation the fixed
            midpoint (a + b) / 2 is used so inference is deterministic.
        use_conv: Add a depthwise Conv1d applied to each head's features.
            NOTE(review): in `_step` the conv sees a length-1 sequence, so
            with kernel_size=3 and zero padding only the center tap
            contributes — effectively a learned per-channel affine.
    """

    def __init__(self, dim, backbone_dropout=0.0, time_scale=(0.0, 1.0), use_conv=True):
        super().__init__()
        self.dim = dim
        self.time_scale = time_scale

        # Shared backbone over [x, h_prev]; outputs 4 chunks: f, g, h bases + skip.
        backbone_dim = dim * 3
        self.backbone = nn.Sequential(
            nn.Linear(dim + dim, backbone_dim),
            nn.LayerNorm(backbone_dim),
            nn.SiLU(),
            nn.Dropout(backbone_dropout),
            nn.Linear(backbone_dim, dim * 4),
            nn.LayerNorm(dim * 4),
        )

        # Depthwise conv (groups=dim) for per-channel local mixing of head inputs.
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, padding=1, groups=dim) if use_conv else None

        # f gates between g and h; tanh keeps the decay rate bounded.
        self.f_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.Tanh())
        self.g_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())
        self.h_head = nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU())

        self.out_proj = nn.Linear(dim, dim)
        self._init_weights()

    def _init_weights(self):
        """Init all Linear layers with truncated-style normal(std=0.02), zero bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _sample_t(self, shape, device):
        """Return a time parameter of the given shape.

        Uniform in [time_scale[0], time_scale[1]] while training; the fixed
        midpoint at eval time so inference is deterministic (the original
        code sampled randomly in eval mode too, making outputs irreproducible).
        """
        lo, hi = self.time_scale
        if self.training:
            return torch.rand(*shape, device=device) * (hi - lo) + lo
        return torch.full(shape, (lo + hi) / 2.0, device=device)

    def forward(self, x, h_prev=None, t=None):
        """
        Args:
            x: [B, dim] single step or [B, L, dim] sequence.
            h_prev: Previous hidden state [B, dim] (zeros if None).
            t: Optional time parameter tensor; sampled via `_sample_t` if None.
        Returns:
            h: [B, dim] or [B, L, dim], matching the input layout.
        """
        if x.dim() == 3:
            return self._forward_seq(x, h_prev, t)

        B, device = x.shape[0], x.device
        if h_prev is None:
            h_prev = torch.zeros(B, self.dim, device=device)
        if t is None:
            t = self._sample_t((B, 1), device)
        elif t.dim() == 1:
            t = t.unsqueeze(1)  # [B] -> [B, 1] for broadcasting against [B, dim]

        return self._step(x, h_prev, t)

    def _forward_seq(self, x, h_prev=None, t=None):
        """Unroll the cell over a [B, L, dim] sequence, one step at a time."""
        B, L, D = x.shape
        device = x.device

        if t is None:
            t = self._sample_t((B, 1), device)  # one t shared across all steps
        elif t.dim() == 3:
            t = t.squeeze(-1)  # [B, 1, 1] -> [B, 1]; hoisted out of the loop

        outputs = []
        h = torch.zeros(B, D, device=device) if h_prev is None else h_prev
        for step in range(L):
            h = self._step(x[:, step, :], h, t)
            outputs.append(h)
        return torch.stack(outputs, dim=1)

    def _step(self, x, h_prev, t):
        """Core CfC step: backbone -> (f, g, h, skip) -> gated combination."""
        combined = torch.cat([x, h_prev], dim=-1)
        backbone_out = self.backbone(combined)
        f_base, g_base, h_base, skip = backbone_out.chunk(4, dim=-1)

        if self.conv is not None:
            # [B, dim] -> [B, dim, 1] so Conv1d treats features as channels.
            f_base = f_base + self.conv(f_base.unsqueeze(-1)).squeeze(-1)
            g_base = g_base + self.conv(g_base.unsqueeze(-1)).squeeze(-1)
            h_base = h_base + self.conv(h_base.unsqueeze(-1)).squeeze(-1)

        f_out = self.f_head(f_base)
        g_out = self.g_head(g_base)
        h_out = self.h_head(h_base)

        # Time-continuous gate: sigmoid replaces exponential decay for stability.
        gate = torch.sigmoid(-f_out * t)
        h = gate * g_out + (1 - gate) * h_out + skip
        return self.out_proj(h)
|
|
|
|
class CfCBlock(nn.Module):
    """CfC block for 2D image processing with residual connection.

    Accepts either a [B, C, H, W] feature map (flattened to a length-H*W
    sequence) or a pre-flattened [B, L, C] sequence. Learned positional
    embeddings support sequences up to ``max_len`` tokens.

    Args:
        dim: Channel / feature dimension.
        dropout: Dropout used in the CfC backbone and the feed-forward.
        time_scale: Time-parameter range forwarded to ``CfCCell``.
        expansion_factor: Width multiplier for the feed-forward hidden layer.
        max_len: Maximum supported sequence length (H * W). Previously a
            hard-coded 4096; inputs longer than this used to fail with a
            cryptic broadcasting error — now raises a clear ValueError.
    """

    def __init__(self, dim, dropout=0.0, time_scale=(0.0, 1.0), expansion_factor=2, max_len=4096):
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.cfc = CfCCell(dim=dim, backbone_dropout=dropout, time_scale=time_scale, use_conv=True)

        ff_dim = dim * expansion_factor
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, dim), nn.Dropout(dropout),
        )

        # Learned positional embeddings, sliced to the actual sequence length.
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, dim) * 0.02)

    def forward(self, x, return_2d=True):
        """
        Args:
            x: [B, C, H, W] feature map or [B, L, C] token sequence.
            return_2d: If the input was 4D, reshape the output back to
                [B, C, H, W]; otherwise return the [B, L, C] sequence.
        Returns:
            Tensor with the same layout as the input (subject to return_2d).
        Raises:
            ValueError: If the sequence length exceeds ``max_len``.
        """
        is_2d = x.dim() == 4
        if is_2d:
            B, C, H, W = x.shape
            L = H * W
            x = x.flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, L, C]
        else:
            B, L, C = x.shape

        if L > self.max_len:
            raise ValueError(
                f"sequence length {L} exceeds max_len={self.max_len} "
                f"supported by the positional embeddings"
            )

        # Positions are added only to the CfC input; the residual is position-free.
        x_with_pos = x + self.pos_embed[:, :L, :]
        residual = x
        h = self.cfc(self.norm1(x_with_pos))
        x_out = h + self.ff(self.norm2(h + residual))

        if is_2d and return_2d:
            x_out = x_out.transpose(1, 2).reshape(B, C, H, W)
        return x_out
|
|