krystv committed on
Commit
3214be6
·
verified ·
1 Parent(s): be1bcbb

Upload liquid_flow/liquid_flow_block.py

Browse files
Files changed (1) hide show
  1. liquid_flow/liquid_flow_block.py +104 -194
liquid_flow/liquid_flow_block.py CHANGED
@@ -1,58 +1,34 @@
1
  """
2
  LiquidFlow Block — Hybrid CfC + Mamba-2 SSD architecture.
3
-
4
- The core innovation: combine Liquid Neural Network dynamics (CfC)
5
- with Mamba-2's efficient linear-time state space model.
6
 
7
  Architecture per block:
8
- Input → [CfC Gate → Mamba2 SSD → CfC Gate] → Output
9
- ↑ ↑
10
- Adaptive gating Gated output
11
-
12
- The CfC provides:
13
- - Time-continuous adaptive gating (what to process/ignore)
14
- - State initialization for the SSM (the "liquid" memory)
15
 
16
- The Mamba-2 SSD provides:
17
- - Efficient O(N) sequence processing
18
- - Content-aware selection mechanism
19
- - Parallelizable computation (no sequential bottleneck)
20
-
21
- Together they create a "Liquid State Space Model" (LSSM):
22
- h_t = σ(-f(x_t;θ_f)·t) ⊙ SSM(x_t, h_{t-1}) + (1-σ(...)) ⊙ h(x_t;θ_h)
23
-
24
- Where SSM is the Mamba-2 selective state space model and the
25
- CfC time-gates control how much the SSM output influences state.
26
-
27
- This is inspired by:
28
- - LNNs: adaptive time constants for state evolution
29
- - Mamba-2: efficient selective state space models
30
- - DiMSUM: multi-scan architecture for 2D images
31
- - Gated SSM: gating mechanism from CfC applied to SSM
32
  """
33
 
34
  import torch
35
  import torch.nn as nn
36
  import torch.nn.functional as F
 
37
 
38
- from .cfc_cell import CfCCell
39
- from .mamba2_ssd import Mamba2SSD
40
 
41
 
42
  class LiquidMambaBlock(nn.Module):
43
  """
44
  LiquidMamba: CfC-gated Mamba-2 SSD block.
45
 
46
- The CfC cell acts as a learned gate on the Mamba-2 output,
47
- creating a liquid time-constant mechanism for the SSM:
48
-
49
- 1. Input goes through Mamba-2 SSD (multi-directional scan)
50
- 2. CfC cell receives the SSM output + original input
51
- 3. CfC produces a time-gated output: σ(f)·SSM_out + (1-σ(f))·input
52
- 4. The CfC's liquid dynamics adaptively mix SSM features with raw input
53
 
54
- This creates a "content-aware gating" that the CfC learns to
55
- control based on both the input and the SSM's processed features.
56
  """
57
 
58
  def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
@@ -60,15 +36,20 @@ class LiquidMambaBlock(nn.Module):
60
  self.dim = dim
61
 
62
  # LayerNorms
63
- self.norm_in = nn.LayerNorm(dim)
64
- self.norm_mamba = nn.LayerNorm(dim)
65
- self.norm_out = nn.LayerNorm(dim)
66
 
67
- # Mamba-2 SSD for efficient sequence processing
68
- self.mamba = Mamba2SSD(dim=dim, d_state=d_state, d_conv=d_conv, expand=expand)
 
 
69
 
70
- # CfC gate: controls the flow between Mamba output and residual
71
- self.cfc_gate = CfCCell(dim=dim, backbone_dropout=dropout, use_conv=True)
 
 
 
72
 
73
  # Feed-forward
74
  ff_dim = dim * expand
@@ -79,104 +60,52 @@ class LiquidMambaBlock(nn.Module):
79
  nn.Linear(ff_dim, dim),
80
  nn.Dropout(dropout),
81
  )
82
-
83
- # Learnable mixing ratio init
84
- self.gate_scale = nn.Parameter(torch.ones(1) * 0.5)
85
 
86
  def forward(self, x):
87
  """
88
  Args:
89
- x: [B, C, H, W] (2D) or [B, L, C] (1D seq)
90
  Returns:
91
  Same shape as input
92
  """
93
  is_2d = x.dim() == 4
94
-
95
  if is_2d:
96
  B, C, H, W = x.shape
97
- L = H * W
98
- x_flat = x.flatten(2).transpose(1, 2) # [B, HW, C]
99
- else:
100
- B, L, C = x.shape
101
- x_flat = x
102
 
103
- residual = x_flat
104
- x_norm = self.norm_in(x_flat)
 
105
 
106
- # Mamba-2 SSD processing with multi-directional scan
107
- if is_2d:
108
- # Reshape for 2D scanning
109
- x_2d = x_norm.transpose(1, 2).reshape(B, C, H, W)
110
- mamba_out = self._mamba_2d_scan(x_2d)
111
- mamba_out = mamba_out.flatten(2).transpose(1, 2) # [B, HW, C]
112
- else:
113
- mamba_out = self.mamba(x_norm)
114
 
115
- # CfC gating: liquid dynamics control the mix
116
- mamba_norm = self.norm_mamba(mamba_out)
 
 
117
 
118
- # CfC receives both the Mamba output and the residual
119
- # This lets it learn when to trust the SSM vs the original signal
120
- cfc_input = mamba_norm + residual
121
- cfc_out = self.cfc_gate(cfc_input)
122
 
123
- # Gated mix: CfC controls the blend
124
- gate = torch.sigmoid(self.gate_scale * (cfc_out - mamba_out))
125
- mixed = gate * mamba_out + (1 - gate) * residual + cfc_out
126
 
127
- # Feed-forward + residual
128
- out_norm = self.norm_out(mixed)
129
- out = mixed + self.ff(out_norm)
130
 
131
  if is_2d:
132
- out = out.transpose(1, 2).reshape(B, C, H, W)
133
-
134
- return out
135
-
136
- def _mamba_2d_scan(self, x):
137
- """
138
- Multi-directional Mamba-2 scan for 2D images.
139
-
140
- Scans in forward and backward raster directions, then merges.
141
- This preserves 2D spatial structure better than single-direction scan.
142
- """
143
- B, C, H, W = x.shape
144
- device = x.device
145
-
146
- # Forward raster: left→right, top→bottom
147
- fwd = x.flatten(2) # [B, C, HW]
148
- fwd_seq = fwd.transpose(1, 2) # [B, HW, C]
149
- fwd_out = self.mamba(fwd_seq)
150
-
151
- # Backward raster: right→left, bottom→top
152
- bwd = torch.flip(x.flatten(2), dims=[-1]) # [B, C, HW]
153
- bwd_seq = bwd.transpose(1, 2)
154
- bwd_out = self.mamba(bwd_seq)
155
- bwd_out = torch.flip(bwd_out, dims=[1]) # Flip back
156
-
157
- # Merge both directions
158
- merged = (fwd_out + bwd_out) / 2
159
- merged = merged.transpose(1, 2).reshape(B, C, H, W)
160
-
161
- return merged
162
 
163
 
164
  class LiquidFlowStage(nn.Module):
165
- """
166
- A stage in LiquidFlow: multiple LiquidMamba blocks at the same resolution.
167
-
168
- Architecture:
169
- [LiquidMamba Block] × num_blocks
170
- [Optional Downsample/Upsample]
171
-
172
- This mirrors the hierarchical design from DiT/DiMSUM but with
173
- liquid neural network dynamics in every block.
174
- """
175
 
176
  def __init__(self, dim, num_blocks=4, d_state=16, expand=2, dropout=0.0):
177
  super().__init__()
178
- self.dim = dim
179
-
180
  self.blocks = nn.ModuleList([
181
  LiquidMambaBlock(dim=dim, d_state=d_state, expand=expand, dropout=dropout)
182
  for _ in range(num_blocks)
@@ -190,27 +119,18 @@ class LiquidFlowStage(nn.Module):
190
 
191
  class LiquidFlowBackbone(nn.Module):
192
  """
193
- Complete LiquidFlow backbone for image generation.
194
-
195
- Architecture:
196
- Input (noisy latent) [B, C, H, W]
197
-
198
- [Patch Embed + Positional Encoding]
199
-
200
- [LiquidMamba Stages × N] (at uniform resolution)
201
-
202
- [Output Head] → predicted noise
203
 
204
- This is designed as a DiT-style noise predictor for diffusion models.
205
 
206
- Args:
207
- in_channels: Input channels (latent dim from VAE)
208
- hidden_dim: Hidden dimension
209
- num_stages: Number of processing stages
210
- blocks_per_stage: Number of blocks per stage
211
- d_state: SSM state dimension
212
- expand: Expansion factor
213
- dropout: Dropout rate
214
  """
215
 
216
  def __init__(
@@ -226,27 +146,30 @@ class LiquidFlowBackbone(nn.Module):
226
  super().__init__()
227
  self.in_channels = in_channels
228
  self.hidden_dim = hidden_dim
229
- self.num_stages = num_stages
230
 
231
- # Input embedding: patch embedding
232
- self.patch_size = 2 # Fixed patch size
233
  self.in_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
234
 
235
- # Time embedding (for diffusion timestep)
236
  self.time_embed = nn.Sequential(
237
  nn.Linear(hidden_dim, hidden_dim * 4),
238
  nn.SiLU(),
239
  nn.Linear(hidden_dim * 4, hidden_dim),
240
  )
241
 
242
- # Learnable positional encoding
243
- # For 128×128 with patch_size=2: 64×64 = 4096 positions
 
 
 
 
 
244
  self.pos_embed = nn.Parameter(torch.randn(1, 4096, hidden_dim) * 0.02)
245
 
246
  # LiquidFlow stages
247
  self.stages = nn.ModuleList([
248
  LiquidFlowStage(
249
- dim=hidden_dim,
250
  num_blocks=blocks_per_stage,
251
  d_state=d_state,
252
  expand=expand,
@@ -255,80 +178,67 @@ class LiquidFlowBackbone(nn.Module):
255
  for _ in range(num_stages)
256
  ])
257
 
258
- # Output head
259
  self.out_norm = nn.LayerNorm(hidden_dim)
260
- self.out_proj = nn.Sequential(
261
- nn.Linear(hidden_dim, hidden_dim),
262
- nn.GELU(),
263
- nn.Linear(hidden_dim, in_channels * self.patch_size * self.patch_size),
264
- )
265
 
266
- # Timestep conditioner (modulated conv trick)
267
- self.t_conditioner = nn.Sequential(
268
- nn.SiLU(),
269
- nn.Linear(hidden_dim, hidden_dim * 2), # scale, shift
270
- )
271
 
272
- def _get_timestep_embedding(self, timesteps, dim, max_period=10000):
273
- """Sinusoidal timestep embedding (from DiT)."""
 
 
 
 
 
274
  half = dim // 2
275
  freqs = torch.exp(
276
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
277
- ).to(timesteps.device)
278
  args = timesteps.float().unsqueeze(-1) * freqs.unsqueeze(0)
279
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
280
  if dim % 2:
281
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
282
- return embedding
283
 
284
  def forward(self, x, t):
285
  """
286
  Args:
287
- x: Noisy latent [B, C, H, W]
288
- t: Diffusion timesteps [B]
289
-
290
  Returns:
291
- Predicted noise [B, C, H, W]
292
  """
293
  B, C, H, W = x.shape
294
- device = x.device
295
- L = (H // self.patch_size) * (W // self.patch_size)
296
 
297
- # Input projection
298
  x = self.in_proj(x) # [B, hidden_dim, H, W]
 
299
 
300
- # Flatten and add positional encoding
301
- x_flat = x.flatten(2).transpose(1, 2) # [B, H*W, hidden_dim]
302
-
303
- # Time embedding
304
- t_emb = self._get_timestep_embedding(t, self.hidden_dim)
305
  t_emb = self.time_embed(t_emb) # [B, hidden_dim]
 
 
306
 
307
- # Add time conditioning as bias to input
308
- t_cond = self.t_conditioner(t_emb) # [B, hidden_dim * 2]
309
- t_scale, t_shift = t_cond.chunk(2, dim=-1)
310
- x_flat = x_flat * (1 + t_scale.unsqueeze(1)) + t_shift.unsqueeze(1)
311
-
312
- # Add positional encoding
313
- x_flat = x_flat + self.pos_embed[:, :L, :]
314
 
315
- # Reshape back to 2D for processing
316
- x_2d = x_flat.transpose(1, 2).reshape(B, self.hidden_dim, H, W)
317
 
318
  # Process through all stages
319
  for stage in self.stages:
320
- x_2d = stage(x_2d)
321
 
322
  # Output head
323
- x_out = x_2d.flatten(2).transpose(1, 2) # [B, H*W, hidden_dim]
324
- x_out = self.out_norm(x_out)
325
- x_out = self.out_proj(x_out) # [B, H*W, C * patch²]
326
 
327
- # Reshape to image
328
- x_out = x_out.reshape(B, H, W, C, self.patch_size, self.patch_size)
329
- x_out = x_out.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H * self.patch_size, W * self.patch_size)
330
 
331
- return x_out
332
-
333
-
334
- import math
 
1
  """
2
  LiquidFlow Block — Hybrid CfC + Mamba-2 SSD architecture.
3
+ CORRECTED VERSION: proper dimensions, no sequential loops.
 
 
4
 
5
  Architecture per block:
6
+ Input → Mamba-2 SSD (bidirectional) → CfC adaptive gate → Output
 
 
 
 
 
 
7
 
8
+ The CfC provides adaptive gating that modulates the SSM output
9
+ based on input-dependent "liquid" time constants.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  """
11
 
12
  import torch
13
  import torch.nn as nn
14
  import torch.nn.functional as F
15
+ import math
16
 
17
+ from .cfc_cell import CfCCell, CfCBlock
18
+ from .mamba2_ssd import Mamba2SSD, Mamba2Block
19
 
20
 
21
  class LiquidMambaBlock(nn.Module):
22
  """
23
  LiquidMamba: CfC-gated Mamba-2 SSD block.
24
 
25
+ Flow:
26
+ 1. Input → LayerNorm → Mamba-2 SSD (bidirectional scan)
27
+ 2. SSM output → CfC adaptive gate (parallel over all positions)
28
+ 3. Gated output → residual + feed-forward
 
 
 
29
 
30
+ The CfC gate learns WHEN to trust the SSM output vs the raw input,
31
+ creating content-aware adaptive processing.
32
  """
33
 
34
  def __init__(self, dim, d_state=16, d_conv=4, expand=2, dropout=0.0):
 
36
  self.dim = dim
37
 
38
  # LayerNorms
39
+ self.norm_ssm = nn.LayerNorm(dim)
40
+ self.norm_gate = nn.LayerNorm(dim)
41
+ self.norm_ff = nn.LayerNorm(dim)
42
 
43
+ # Mamba-2 SSD: bidirectional scan
44
+ self.ssd_fwd = Mamba2SSD(dim=dim, d_state=d_state, d_conv=d_conv, expand=expand)
45
+ self.ssd_bwd = Mamba2SSD(dim=dim, d_state=d_state, d_conv=d_conv, expand=expand)
46
+ self.merge = nn.Linear(dim * 2, dim, bias=False)
47
 
48
+ # CfC gate: parallel adaptive gating
49
+ self.cfc_gate = CfCCell(dim=dim, dropout=dropout)
50
+
51
+ # Gate projection (learnable mixing)
52
+ self.gate_proj = nn.Linear(dim, dim)
53
 
54
  # Feed-forward
55
  ff_dim = dim * expand
 
60
  nn.Linear(ff_dim, dim),
61
  nn.Dropout(dropout),
62
  )
 
 
 
63
 
64
  def forward(self, x):
65
  """
66
  Args:
67
+ x: [B, C, H, W] or [B, L, C]
68
  Returns:
69
  Same shape as input
70
  """
71
  is_2d = x.dim() == 4
 
72
  if is_2d:
73
  B, C, H, W = x.shape
74
+ x = x.flatten(2).transpose(1, 2) # [B, HW, C]
 
 
 
 
75
 
76
+ # === SSM branch ===
77
+ residual = x
78
+ x_norm = self.norm_ssm(x)
79
 
80
+ # Bidirectional Mamba-2 scan
81
+ fwd_out = self.ssd_fwd(x_norm)
82
+ bwd_out = torch.flip(self.ssd_bwd(torch.flip(x_norm, [1])), [1])
83
+ ssm_out = self.merge(torch.cat([fwd_out, bwd_out], dim=-1))
 
 
 
 
84
 
85
+ # === CfC gate ===
86
+ # CfC processes the SSM output and produces an adaptive gate
87
+ gate_input = self.norm_gate(ssm_out)
88
+ cfc_out = self.cfc_gate(gate_input) # [B, L, D] — parallel!
89
 
90
+ # Sigmoid gate: how much SSM output to use
91
+ gate = torch.sigmoid(self.gate_proj(cfc_out))
 
 
92
 
93
+ # Gated residual: blend SSM output with residual
94
+ x = residual + gate * ssm_out
 
95
 
96
+ # === Feed-forward ===
97
+ x = x + self.ff(self.norm_ff(x))
 
98
 
99
  if is_2d:
100
+ x = x.transpose(1, 2).reshape(B, C, H, W)
101
+ return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
 
104
  class LiquidFlowStage(nn.Module):
105
+ """Stack of LiquidMamba blocks at the same resolution."""
 
 
 
 
 
 
 
 
 
106
 
107
  def __init__(self, dim, num_blocks=4, d_state=16, expand=2, dropout=0.0):
108
  super().__init__()
 
 
109
  self.blocks = nn.ModuleList([
110
  LiquidMambaBlock(dim=dim, d_state=d_state, expand=expand, dropout=dropout)
111
  for _ in range(num_blocks)
 
119
 
120
  class LiquidFlowBackbone(nn.Module):
121
  """
122
+ Complete LiquidFlow backbone — DiT-style noise predictor.
 
 
 
 
 
 
 
 
 
123
 
124
+ FIXED: Output shape == Input shape (no patch_size confusion).
125
 
126
+ Architecture:
127
+ Input [B, in_ch, H, W]
128
+ Conv2d projection to hidden_dim
129
+ + sinusoidal timestep embedding (AdaLN-style)
130
+ + learnable positional encoding
131
+ N × LiquidMamba Stages
132
+ LayerNorm + Linear projection back to in_ch
133
+ Output [B, in_ch, H, W]
134
  """
135
 
136
  def __init__(
 
146
  super().__init__()
147
  self.in_channels = in_channels
148
  self.hidden_dim = hidden_dim
 
149
 
150
+ # Input projection (pointwise conv)
 
151
  self.in_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
152
 
153
+ # Timestep embedding
154
  self.time_embed = nn.Sequential(
155
  nn.Linear(hidden_dim, hidden_dim * 4),
156
  nn.SiLU(),
157
  nn.Linear(hidden_dim * 4, hidden_dim),
158
  )
159
 
160
+ # AdaLN-style conditioning: scale and shift
161
+ self.t_cond = nn.Sequential(
162
+ nn.SiLU(),
163
+ nn.Linear(hidden_dim, hidden_dim * 2),
164
+ )
165
+
166
+ # Positional encoding (learnable, supports up to 64×64 = 4096 positions)
167
  self.pos_embed = nn.Parameter(torch.randn(1, 4096, hidden_dim) * 0.02)
168
 
169
  # LiquidFlow stages
170
  self.stages = nn.ModuleList([
171
  LiquidFlowStage(
172
+ dim=hidden_dim,
173
  num_blocks=blocks_per_stage,
174
  d_state=d_state,
175
  expand=expand,
 
178
  for _ in range(num_stages)
179
  ])
180
 
181
+ # Output projection
182
  self.out_norm = nn.LayerNorm(hidden_dim)
183
+ self.out_proj = nn.Linear(hidden_dim, in_channels)
 
 
 
 
184
 
185
+ self._init_weights()
 
 
 
 
186
 
187
+ def _init_weights(self):
188
+ # Zero-init output projection for residual learning
189
+ nn.init.zeros_(self.out_proj.weight)
190
+ nn.init.zeros_(self.out_proj.bias)
191
+
192
+ def _sinusoidal_embedding(self, timesteps, dim):
193
+ """Sinusoidal positional embedding for diffusion timesteps."""
194
  half = dim // 2
195
  freqs = torch.exp(
196
+ -math.log(10000.0) * torch.arange(half, device=timesteps.device).float() / half
197
+ )
198
  args = timesteps.float().unsqueeze(-1) * freqs.unsqueeze(0)
199
+ emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
200
  if dim % 2:
201
+ emb = F.pad(emb, (0, 1))
202
+ return emb
203
 
204
  def forward(self, x, t):
205
  """
206
  Args:
207
+ x: [B, in_channels, H, W] — noisy latent
208
+ t: [B] — diffusion timesteps (integers 0..T-1)
 
209
  Returns:
210
+ [B, in_channels, H, W] — predicted noise (same shape as input!)
211
  """
212
  B, C, H, W = x.shape
213
+ L = H * W
 
214
 
215
+ # Project to hidden dim
216
  x = self.in_proj(x) # [B, hidden_dim, H, W]
217
+ x = x.flatten(2).transpose(1, 2) # [B, HW, hidden_dim]
218
 
219
+ # Timestep conditioning (AdaLN-style scale/shift, applied once to the projected input)
220
+ t_emb = self._sinusoidal_embedding(t, self.hidden_dim) # [B, hidden_dim]
 
 
 
221
  t_emb = self.time_embed(t_emb) # [B, hidden_dim]
222
+ t_cond = self.t_cond(t_emb) # [B, hidden_dim*2]
223
+ scale, shift = t_cond.chunk(2, dim=-1) # each [B, hidden_dim]
224
 
225
+ # Apply conditioning + positional encoding
226
+ x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
227
+ x = x + self.pos_embed[:, :L, :]
 
 
 
 
228
 
229
+ # Reshape to 2D for processing
230
+ x = x.transpose(1, 2).reshape(B, self.hidden_dim, H, W)
231
 
232
  # Process through all stages
233
  for stage in self.stages:
234
+ x = stage(x)
235
 
236
  # Output head
237
+ x = x.flatten(2).transpose(1, 2) # [B, HW, hidden_dim]
238
+ x = self.out_norm(x)
239
+ x = self.out_proj(x) # [B, HW, in_channels]
240
 
241
+ # Reshape back to image: [B, in_channels, H, W]
242
+ x = x.transpose(1, 2).reshape(B, self.in_channels, H, W)
 
243
 
244
+ return x