krystv committed on
Commit 8a00562 · verified · 1 Parent(s): e507fd7

Upload liquid_diffusion/model.py

Files changed (1): liquid_diffusion/model.py +419 -0
liquid_diffusion/model.py ADDED
@@ -0,0 +1,419 @@
"""
LiquidDiffusion Model — A Novel Attention-Free Image Generation Architecture

Core Innovation: Parallel Liquid Neural Network blocks for image generation.
The CfC (Closed-form Continuous-depth) time-gating mechanism naturally bridges
with diffusion timesteps — the diffusion noise level IS the liquid time constant.

Mathematical Foundation:
    CfC Eq.10: x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h

For image generation, we adapt this as:
    φ'(t) = σ(-f(φ)·t_diff) ⊙ g(φ) + (1 - σ(-f(φ)·t_diff)) ⊙ h(φ)

where t_diff is the diffusion timestep and f/g/h are spatial feature transforms.
This is FULLY PARALLEL — no ODE solver, no sequential scanning.

Additionally, we use learnable exponential relaxation (from LiquidTAD):
    α = exp(-λ·t_diff),  out = α·φ + (1-α)·S(φ)
This gives depth-dependent, time-aware residual connections.

Architecture:
    Input (noisy image) → Conv stem → [Encoder: DownBlocks with LiquidCfC]
    → Bottleneck (LiquidCfC) → [Decoder: UpBlocks with LiquidCfC + skip]
    → Conv head → Velocity prediction (for rectified flow)

No attention anywhere. All spatial mixing is done via depthwise convolutions +
multi-scale parallel processing in liquid blocks.

References:
    [1] Hasani et al., "Closed-form Continuous-time Neural Networks", Nature MI 2022 (CfC)
    [2] arXiv 2604.18274 — LiquidTAD (parallel liquid relaxation)
    [3] arXiv 2504.13499 — USM (U-Shape Mamba for diffusion)
    [4] Liu et al., "Flow Straight and Fast: Rectified Flow", ICLR 2023
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
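
# Quick numeric sketch of CfC Eq.10 from the docstring above (our illustration,
# not part of the architecture; the model applies this gating inside
# ParallelCfCBlock below, with learned per-channel time modulation). With f > 0,
# the gate σ(-f·t) starts at 0.5 at t = 0 and decays toward 0 as t grows, so
# x(t) relaxes from the g/h midpoint toward the attractor state h.
def _cfc_eq10_sketch(f: torch.Tensor, g: torch.Tensor, h: torch.Tensor,
                     t: float) -> torch.Tensor:
    gate = torch.sigmoid(-f * t)          # σ(-f·t), elementwise
    return gate * g + (1.0 - gate) * h    # closed form: no ODE solver needed
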

# =============================================================================
# 1. TIME EMBEDDING — Sinusoidal + MLP
# =============================================================================

class SinusoidalTimeEmbedding(nn.Module):
    """Maps scalar timestep t to a high-dimensional embedding.
    Uses sinusoidal positional encoding followed by a 2-layer MLP.
    """
    def __init__(self, dim: int, max_period: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """t: [B] timestep values in [0, 1] → [B, dim] embeddings"""
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half
        )
        args = t[:, None] * freqs[None, :]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return self.mlp(emb)

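
# Minimal usage sketch for the embedding above (illustration only; dim=256 is
# an arbitrary choice, matching the tiny config's t_dim):
#
#     >>> emb = SinusoidalTimeEmbedding(dim=256)
#     >>> emb(torch.rand(8)).shape    # eight timesteps drawn from [0, 1]
#     torch.Size([8, 256])
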

# =============================================================================
# 2. ADAPTIVE LAYER NORM (AdaLN) — Timestep conditioning via scale/shift
# =============================================================================

class AdaLN(nn.Module):
    """Adaptive normalization: out = norm(x) * (1 + scale(t)) + shift(t).

    Despite the AdaLN name, the norm here is a non-affine GroupNorm over
    channels, a common 2D stand-in for LayerNorm on feature maps.
    """
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=min(32, dim), num_channels=dim, affine=False)
        self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """x: [B,C,H,W], t_emb: [B, cond_dim] → [B,C,H,W]"""
        scale, shift = self.proj(t_emb).chunk(2, dim=1)
        return self.norm(x) * (1 + scale[:, :, None, None]) + shift[:, :, None, None]


# =============================================================================
# 3. PARALLEL CfC BLOCK — Core liquid neural network layer
# =============================================================================

class ParallelCfCBlock(nn.Module):
    """Parallel Closed-form Continuous-depth block for spatial features.

    CfC Eq.10: x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h

    Adaptations for image generation:
    1. f/g/h heads operate on 2D feature maps via conv layers
    2. Diffusion timestep t IS the CfC time parameter
    3. Multi-directional depthwise convolutions for spatial context
    4. No recurrence — each spatial position computed independently
    5. Liquid relaxation residual: α·input + (1-α)·CfC_output,
       where α = exp(-λ·t_diff) adapts residual strength to noise level
    """
    def __init__(self, dim: int, t_dim: int, expand_ratio: float = 2.0,
                 kernel_size: int = 7, dropout: float = 0.0):
        super().__init__()
        hidden = int(dim * expand_ratio)

        # Shared backbone: depthwise + pointwise for local spatial context
        self.backbone_dw = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim)
        self.backbone_pw = nn.Conv2d(dim, hidden, 1)
        self.backbone_act = nn.SiLU()

        # Three CfC heads
        self.f_head = nn.Conv2d(hidden, dim, 1)  # time-constant gate
        self.g_head = nn.Sequential(             # "from" state
            nn.Conv2d(hidden, hidden, kernel_size, padding=kernel_size // 2, groups=hidden),
            nn.SiLU(),
            nn.Conv2d(hidden, dim, 1),
        )
        self.h_head = nn.Sequential(             # "to" state (attractor)
            nn.Conv2d(hidden, hidden, kernel_size, padding=kernel_size // 2, groups=hidden),
            nn.SiLU(),
            nn.Conv2d(hidden, dim, 1),
        )

        # CfC time parameters: maps t_emb to per-channel gate modulation
        self.time_a = nn.Linear(t_dim, dim)
        self.time_b = nn.Linear(t_dim, dim)

        # Liquid relaxation decay (LiquidTAD-inspired)
        self.rho = nn.Parameter(torch.zeros(1, dim, 1, 1))

        # Output gate conditioned on timestep
        self.output_gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim))

        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        """x: [B,C,H,W], t_emb: [B, t_dim] → [B,C,H,W]"""
        residual = x

        # Shared backbone
        backbone = self.backbone_act(self.backbone_pw(self.backbone_dw(x)))

        # Three CfC heads
        f = self.f_head(backbone)  # time constant logits
        g = self.g_head(backbone)  # "from" state
        h = self.h_head(backbone)  # "to" state

        # CfC time-gating: σ(time_a(t) · f - time_b(t))
        ta = self.time_a(t_emb)[:, :, None, None]
        tb = self.time_b(t_emb)[:, :, None, None]
        gate = torch.sigmoid(ta * f - tb)

        # CfC interpolation: gate*g + (1-gate)*h
        cfc_out = gate * g + (1.0 - gate) * h
        cfc_out = self.dropout(cfc_out)

        # Liquid relaxation: α = exp(-λ · |t_mean|)
        t_scalar = t_emb.mean(dim=1, keepdim=True)[:, :, None, None]
        lam = F.softplus(self.rho) + 1e-6
        alpha = torch.exp(-lam * t_scalar.abs().clamp(min=0.01))

        out = alpha * residual + (1.0 - alpha) * cfc_out

        # Output gate
        out_gate = torch.sigmoid(self.output_gate(t_emb))[:, :, None, None]
        return out * out_gate

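
# Minimal shape sketch for ParallelCfCBlock (illustrative; dims are arbitrary).
# On the relaxation residual: at init λ = softplus(0) ≈ 0.693, so
# α = exp(-λ·|t_mean|) < 1 and the CfC branch already contributes; α shrinks
# (more CfC, less identity) as λ is learned larger or |t_mean| grows.
#
#     >>> block = ParallelCfCBlock(dim=64, t_dim=256)
#     >>> block(torch.randn(2, 64, 32, 32), torch.randn(2, 256)).shape
#     torch.Size([2, 64, 32, 32])
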

# =============================================================================
# 4. MULTI-SCALE SPATIAL MIXING — Global context without attention
# =============================================================================

class MultiScaleSpatialMix(nn.Module):
    """Multi-scale depthwise conv + global pooling for spatial context.

    Uses parallel depthwise convolutions at 3x3, 5x5, and 7x7 scales,
    plus adaptive average pooling for a global receptive field.
    This replaces self-attention's global spatial mixing.
    """
    def __init__(self, dim: int, t_dim: int):
        super().__init__()
        self.dw3 = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.dw5 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.dw7 = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.global_proj = nn.Conv2d(dim, dim, 1)
        self.merge = nn.Conv2d(dim * 4, dim, 1)
        self.act = nn.SiLU()
        self.adaln = AdaLN(dim, t_dim)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        x_norm = self.adaln(x, t_emb)
        s3 = self.dw3(x_norm)
        s5 = self.dw5(x_norm)
        s7 = self.dw7(x_norm)
        sg = self.global_proj(self.global_pool(x_norm)).expand_as(x_norm)
        return x + self.act(self.merge(torch.cat([s3, s5, s7, sg], dim=1)))

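
# Cost note (our back-of-envelope, not a benchmark): per pixel and channel, the
# three depthwise convs cost 3² + 5² + 7² = 83 multiply-adds, so spatial mixing
# is O(H·W) overall, versus O((H·W)²) pairwise interactions for full
# self-attention over the same feature map.
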

# =============================================================================
# 5. LIQUID DIFFUSION BLOCK — Complete processing unit
# =============================================================================

class LiquidDiffusionBlock(nn.Module):
    """One complete LiquidDiffusion block:
    AdaLN → ParallelCfC → MultiScaleSpatialMix → FeedForward
    """
    def __init__(self, dim: int, t_dim: int, expand_ratio: float = 2.0,
                 kernel_size: int = 7, dropout: float = 0.0):
        super().__init__()
        self.adaln1 = AdaLN(dim, t_dim)
        self.cfc = ParallelCfCBlock(dim, t_dim, expand_ratio, kernel_size, dropout)
        self.spatial_mix = MultiScaleSpatialMix(dim, t_dim)
        self.adaln2 = AdaLN(dim, t_dim)
        ff_dim = int(dim * expand_ratio)
        self.ff = nn.Sequential(
            nn.Conv2d(dim, ff_dim, 1), nn.SiLU(), nn.Conv2d(ff_dim, dim, 1),
        )
        self.res_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        x = x + self.res_scale * self.cfc(self.adaln1(x, t_emb), t_emb)
        x = self.spatial_mix(x, t_emb)
        x = x + self.res_scale * self.ff(self.adaln2(x, t_emb))
        return x


# =============================================================================
# 6. DOWN/UP SAMPLE + SKIP FUSION
# =============================================================================

class DownSample(nn.Module):
    """Strided-convolution downsampling (2x)."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class UpSample(nn.Module):
    """Nearest-neighbor interpolation + conv upsampling (2x)."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.conv = nn.Conv2d(in_dim, out_dim, 3, padding=1)

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))


class SkipFusion(nn.Module):
    """Timestep-gated skip connection fusion."""
    def __init__(self, dim: int, t_dim: int):
        super().__init__()
        self.proj = nn.Conv2d(dim * 2, dim, 1)
        self.gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim), nn.Sigmoid())

    def forward(self, x, skip, t_emb):
        merged = self.proj(torch.cat([x, skip], dim=1))
        g = self.gate(t_emb)[:, :, None, None]
        return merged * g + x * (1 - g)


# =============================================================================
# 7. LIQUID DIFFUSION U-NET — The complete denoiser
# =============================================================================

class LiquidDiffusionUNet(nn.Module):
    """LiquidDiffusion: attention-free image generation with liquid neural networks.

    A U-Net where every processing block uses parallel CfC layers instead of
    attention. The diffusion timestep serves a dual purpose:
    1. Conditions the denoiser via AdaLN scale/shift
    2. Acts as the CfC "time parameter" — controlling liquid neuron interpolation

    Scales:
        tiny:  channels=[64, 128, 256],        blocks=[2, 2, 4],    ~8M   (256px, fast)
        small: channels=[96, 192, 384],        blocks=[2, 3, 6],    ~25M  (256px, quality)
        base:  channels=[128, 256, 512],       blocks=[2, 4, 8],    ~65M  (512px)
        large: channels=[128, 256, 512, 768],  blocks=[2, 4, 8, 4], ~120M (512px HQ)
    """
    def __init__(self, in_channels=3, channels=None, blocks_per_stage=None,
                 t_dim=256, expand_ratio=2.0, kernel_size=7, dropout=0.0):
        super().__init__()
        if channels is None:
            channels = [64, 128, 256]
        if blocks_per_stage is None:
            blocks_per_stage = [2, 2, 4]

        assert len(channels) == len(blocks_per_stage)
        self.channels = channels
        self.num_stages = len(channels)

        # Time embedding
        self.time_embed = SinusoidalTimeEmbedding(t_dim)

        # Input stem
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, channels[0], 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(channels[0], channels[0], 3, padding=1),
        )

        # Encoder
        self.encoder_blocks = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        for i in range(self.num_stages):
            stage = nn.ModuleList()
            for _ in range(blocks_per_stage[i]):
                stage.append(LiquidDiffusionBlock(
                    channels[i], t_dim, expand_ratio, kernel_size, dropout))
            self.encoder_blocks.append(stage)
            if i < self.num_stages - 1:
                self.downsamplers.append(DownSample(channels[i], channels[i + 1]))

        # Bottleneck
        self.bottleneck = nn.ModuleList([
            LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout),
            LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout),
        ])

        # Decoder
        self.decoder_blocks = nn.ModuleList()
        self.upsamplers = nn.ModuleList()
        self.skip_fusions = nn.ModuleList()
        for i in range(self.num_stages - 1, -1, -1):
            if i < self.num_stages - 1:
                self.upsamplers.append(UpSample(channels[i + 1], channels[i]))
                self.skip_fusions.append(SkipFusion(channels[i], t_dim))
            stage = nn.ModuleList()
            for _ in range(blocks_per_stage[i]):
                stage.append(LiquidDiffusionBlock(
                    channels[i], t_dim, expand_ratio, kernel_size, dropout))
            self.decoder_blocks.append(stage)

        # Output head (final conv initialized to zero for a stable start)
        self.head = nn.Sequential(
            nn.GroupNorm(min(32, channels[0]), channels[0]),
            nn.SiLU(),
            nn.Conv2d(channels[0], in_channels, 3, padding=1),
        )
        nn.init.zeros_(self.head[-1].weight)
        nn.init.zeros_(self.head[-1].bias)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, C, H, W] noisy image
            t: [B] timestep values in [0, 1]
        Returns:
            [B, C, H, W] predicted velocity
        """
        t_emb = self.time_embed(t)
        h = self.stem(x)

        # Encoder
        skips = []
        for i in range(self.num_stages):
            for block in self.encoder_blocks[i]:
                h = block(h, t_emb)
            skips.append(h)
            if i < self.num_stages - 1:
                h = self.downsamplers[i](h)

        # Bottleneck
        for block in self.bottleneck:
            h = block(h, t_emb)

        # Decoder
        up_idx = 0
        for dec_i in range(self.num_stages):
            stage_idx = self.num_stages - 1 - dec_i
            if dec_i > 0:
                h = self.upsamplers[up_idx](h)
                h = self.skip_fusions[up_idx](h, skips[stage_idx], t_emb)
                up_idx += 1
            for block in self.decoder_blocks[dec_i]:
                h = block(h, t_emb)

        return self.head(h)

    def count_params(self):
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return total, trainable


# =============================================================================
# 8. MODEL CONFIGS
# =============================================================================

def liquid_diffusion_tiny(**kwargs):
    """~8M params, 256px, fits in ~4GB VRAM."""
    return LiquidDiffusionUNet(
        channels=[64, 128, 256], blocks_per_stage=[2, 2, 4],
        t_dim=256, expand_ratio=2.0, kernel_size=7, **kwargs)


def liquid_diffusion_small(**kwargs):
    """~25M params, 256px, fits in ~8GB VRAM."""
    return LiquidDiffusionUNet(
        channels=[96, 192, 384], blocks_per_stage=[2, 3, 6],
        t_dim=384, expand_ratio=2.0, kernel_size=7, **kwargs)


def liquid_diffusion_base(**kwargs):
    """~65M params, 512px, fits in ~14GB VRAM."""
    return LiquidDiffusionUNet(
        channels=[128, 256, 512], blocks_per_stage=[2, 4, 8],
        t_dim=512, expand_ratio=2.0, kernel_size=7, **kwargs)


def liquid_diffusion_large(**kwargs):
    """~120M params, 512px, needs ~24GB VRAM."""
    return LiquidDiffusionUNet(
        channels=[128, 256, 512, 768], blocks_per_stage=[2, 4, 8, 4],
        t_dim=512, expand_ratio=2.0, kernel_size=7, **kwargs)
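

# Minimal smoke test (our addition): builds the tiny config on CPU, runs one
# forward pass, and checks that the predicted velocity matches the input shape.
# The 64px input is arbitrary; the U-Net needs H and W divisible by
# 2 ** (num_stages - 1) so the skip resolutions line up.
if __name__ == "__main__":
    model = liquid_diffusion_tiny()
    x = torch.randn(2, 3, 64, 64)   # [B, C, H, W] noisy images
    t = torch.rand(2)               # [B] timesteps in [0, 1]
    with torch.no_grad():
        v = model(x, t)
    assert v.shape == x.shape
    total, trainable = model.count_params()
    print(f"output {tuple(v.shape)} | params {total / 1e6:.1f}M "
          f"({trainable / 1e6:.1f}M trainable)")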