krystv committed on
Commit
4f47596
·
verified ·
1 Parent(s): f798aa0

Upload lrf/model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. lrf/model.py +950 -0
lrf/model.py ADDED
@@ -0,0 +1,950 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LatentRecurrentFlow (LRF) - Core Architecture Modules
3
+
4
+ Architecture Overview:
5
+ =====================
6
+ The LRF architecture consists of 4 main components:
7
+
8
+ 1. CompactEncoder/Decoder (VAE): f=16 spatial compression (four 2x stages) with tiny decoder
9
+ 2. TextConditioner: Lightweight text encoding (TinyCLIP or small LM)
10
+ 3. RecursiveLatentCore: The novel HRM-inspired denoising backbone
11
+ 4. FlowScheduler: Rectified flow for training and sampling
12
+
13
+ The RecursiveLatentCore is the key innovation:
14
+ - It contains N_blocks GLD (Gated Linear Diffusion) blocks
15
+ - These blocks are applied recursively T_outer * T_inner times
16
+ - The same parameters are reused across recursions (weight sharing)
17
+ - Training uses IFT (Implicit Function Theorem) for O(1) memory backprop
18
+ - This gives effective depth of T_outer * T_inner * N_blocks layers
19
+ from only N_blocks parameter sets
20
+
21
+ Memory budget at inference (1024x1024, INT8):
22
+ - Text encoder: ~150MB (TinyCLIP-ViT-B/16)
23
+ - VAE encoder: ~100MB (f32 encoder, only needed for editing)
24
+ - VAE decoder: ~6MB (SnapGen-style tiny decoder)
25
+ - LRF core: ~200-400MB (depending on config)
26
+ - Activations: ~500MB peak
27
+ - Total: ~1-1.5GB model + ~500MB activations = 1.5-2GB
28
+ """
29
+
30
+ import math
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ from einops import rearrange, repeat
35
+ from typing import Optional, Tuple, Dict, Any
36
+
37
+
38
+ # ============================================================================
39
+ # Utility Modules
40
+ # ============================================================================
41
+
42
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    The input is scaled by ``1 / sqrt(mean(x^2) + eps)`` (computed in float32),
    cast back to the input dtype and multiplied by a learned per-channel gain.

    Args:
        dim: Size of the last (normalized) dimension.
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Per-channel gain, initialised to one (identity scaling).
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Statistics are computed in float32 regardless of the input dtype.
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = (mean_square + self.eps).rsqrt()
        normalized = (x_fp32 * inv_rms).type_as(x)
        return normalized * self.weight
52
+
53
+
54
class SwiGLU(nn.Module):
    """SwiGLU feed-forward layer: ``w2(silu(w1(x)) * w3(x))`` followed by dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Inner width. When omitted (or falsy) it defaults to
            ``int(dim * 8 / 3)``. The value is then rounded UP to a multiple of 8.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, dim: int, hidden_dim: Optional[int] = None, dropout: float = 0.0):
        super().__init__()
        if not hidden_dim:
            hidden_dim = int(dim * 8 / 3)
        # Ceil to the next multiple of 8.
        hidden_dim = -(-hidden_dim // 8) * 8
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value branch
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.dropout(self.w2(gate * value))
68
+
69
+
70
class DepthwiseSeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution (both bias-free).

    Args:
        in_channels: Number of input channels (also the depthwise group count).
        out_channels: Number of channels produced by the pointwise convolution.
        kernel_size: Depthwise kernel size; padding is ``kernel_size // 2``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__()
        # One filter per input channel (groups == in_channels).
        self.dw = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            padding=kernel_size // 2,
            groups=in_channels,
            bias=False,
        )
        # 1x1 convolution mixes channels and sets the output width.
        self.pw = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        spatially_filtered = self.dw(x)
        return self.pw(spatially_filtered)
80
+
81
+
82
+ # ============================================================================
83
+ # 2D Positional Encoding
84
+ # ============================================================================
85
+
86
class RotaryPositionEncoding2D(nn.Module):
    """Builds sin/cos tables for 2D rotary position encoding on an H x W grid.

    ``dim // 4`` frequencies are shared by the row and column axes; the row and
    column angles are concatenated, giving tables of width ``dim // 2``.

    Args:
        dim: Feature size the tables are built for.
        max_res: Accepted for interface compatibility; not used by this module.
    """

    def __init__(self, dim: int, max_res: int = 64):
        super().__init__()
        self.dim = dim
        quarter = dim // 4
        exponent = torch.arange(quarter) * -(math.log(10000.0) / quarter)
        self.register_buffer('freqs', torch.exp(exponent))

    def forward(self, h: int, w: int, device=None):
        """Return ``(sin, cos)`` tables, each of shape ``[h * w, 2 * (dim // 4)]``.

        Rows of the tables follow row-major (h, w) token order.
        """
        device = device or self.freqs.device
        base = self.freqs.to(device)

        rows = torch.arange(h, device=device).float()
        cols = torch.arange(w, device=device).float()

        # [H, 1, D/4] -> [H, W, D/4] and [1, W, D/4] -> [H, W, D/4]
        row_angles = torch.outer(rows, base)[:, None, :].expand(-1, w, -1)
        col_angles = torch.outer(cols, base)[None, :, :].expand(h, -1, -1)

        # Row angles first, column angles second, flattened over the grid.
        angles = torch.cat([row_angles, col_angles], dim=-1).reshape(h * w, -1)
        return angles.sin(), angles.cos()
114
+
115
+
116
def apply_rope_2d(x, sin_enc, cos_enc):
    """Rotate the two halves of ``x``'s last dimension by the given sin/cos tables.

    The last dimension is split into a first and a second half ``(a, b)``; the
    result is ``cat([a * cos - b * sin, b * cos + a * sin])``. Leading singleton
    dimensions are added to the tables until ``sin_enc`` has the same rank as
    the halves, so they broadcast over batch/head dimensions.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]

    # Both tables receive the same number of leading axes, driven by sin_enc's rank.
    for _ in range(max(0, first.dim() - sin_enc.dim())):
        sin_enc = sin_enc.unsqueeze(0)
        cos_enc = cos_enc.unsqueeze(0)

    rotated_first = first * cos_enc - second * sin_enc
    rotated_second = second * cos_enc + first * sin_enc
    return torch.cat([rotated_first, rotated_second], dim=-1)
126
+
127
+
128
+ # ============================================================================
129
+ # Gated Linear Diffusion (GLD) Block - The Core Spatial Mixer
130
+ # ============================================================================
131
+
132
class GatedLinearAttention(nn.Module):
    """
    Gated linear attention for mixing flattened 2D tokens.

    Pipeline implemented by ``forward``:
        1. Q, K, V = split(qkv(x)), reshaped to ``[B, heads, N, head_dim]``.
        2. Token differential: ``Q_i -= lam * Q_{i-1}`` and likewise for K,
           with ``lam = sigmoid(diff_lambda)`` and a zero predecessor for the
           first token of the flattened sequence.
        3. Feature map ``phi(t) = 1 + elu(t)`` applied to Q and K.
        4. A forward and a reversed chunked scan (see ``_scan``), summed.
        5. RMSNorm over the concatenated heads.
        6. Output = ``sigmoid(gate(x)) * normed * sigmoid(local_conv(local_gate(x)))``,
           where ``local_conv`` is a 3x3 depthwise convolution on the ``h x w``
           grid, then ``out_proj`` and dropout.

    Args:
        dim: Input/output feature size.
        num_heads: Number of heads.
        head_dim: Feature size per head (inner width is ``num_heads * head_dim``).
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim: int, num_heads: int = 8, head_dim: int = 32, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner_dim = num_heads * head_dim

        self.qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        self.out_proj = nn.Linear(inner_dim, dim, bias=False)

        # Per-head decay logit; the effective decay is sigmoid(log_decay).
        self.log_decay = nn.Parameter(torch.zeros(num_heads))

        # Multiplicative output gate.
        self.gate = nn.Linear(dim, inner_dim, bias=False)

        # 2D locality branch: linear projection followed by a depthwise 3x3 conv.
        self.local_conv = nn.Conv2d(inner_dim, inner_dim, 3, padding=1, groups=inner_dim, bias=False)
        self.local_gate = nn.Linear(dim, inner_dim, bias=False)

        # Token-differential strength logit (passed through sigmoid in forward).
        self.diff_lambda = nn.Parameter(torch.tensor(0.1))

        self.dropout = nn.Dropout(dropout)
        self.norm = RMSNorm(inner_dim)

    def _feature_map(self, x):
        """Non-negative feature map: 1 + elu(x)."""
        return 1.0 + F.elu(x)

    def _scan(self, Q, K, V, reverse=False):
        """Chunked linear-attention scan over the token axis.

        Args:
            Q, K, V: Tensors of shape ``[B, H, N, D]``.
            reverse: When True the token axis is flipped before the scan and
                the output is flipped back afterwards.

        Returns:
            Tensor of shape ``[B, H, N, D]``.

        NOTE(review): the state is updated once per chunk of up to 64 tokens
        (``S = decay * S + K_chunk^T V_chunk``) and every token of the chunk
        reads the state that already contains the whole chunk. The decay is
        therefore applied per chunk rather than per token, and tokens inside a
        chunk see each other in both directions. This is kept as-is because
        changing it would change the model's outputs.
        """
        B, H, N, D = Q.shape

        decay = torch.sigmoid(self.log_decay).view(1, H, 1, 1)  # [1, H, 1, 1]

        if reverse:
            Q = Q.flip(2)
            K = K.flip(2)
            V = V.flip(2)

        chunk_size = min(64, N)
        outputs = []
        S = torch.zeros(B, H, D, D, device=Q.device, dtype=Q.dtype)

        for start in range(0, N, chunk_size):
            stop = start + chunk_size
            q_chunk = Q[:, :, start:stop]  # [B, H, C, D]
            k_chunk = K[:, :, start:stop]
            v_chunk = V[:, :, start:stop]

            # Fold the chunk's key/value outer products into the running state.
            S = decay * S + torch.einsum('bhcd,bhce->bhde', k_chunk, v_chunk)

            # Read the state with the chunk's queries.
            outputs.append(torch.einsum('bhcd,bhde->bhce', q_chunk, S))

        output = torch.cat(outputs, dim=2)

        if reverse:
            output = output.flip(2)

        return output

    def forward(self, x, h: int, w: int):
        """
        Args:
            x: [B, N, D] where N = h * w
            h, w: spatial dimensions of the token grid
        Returns:
            [B, N, D]
        """
        # Project to Q, K, V and split into heads.
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads)

        # Token differential: subtract a scaled copy of the previous token
        # (zero for the first token) from each query and key.
        lam = torch.sigmoid(self.diff_lambda)
        q = q - lam * F.pad(q[:, :, :-1], (0, 0, 1, 0))
        k = k - lam * F.pad(k[:, :, :-1], (0, 0, 1, 0))

        # Non-negative feature map.
        q = self._feature_map(q)
        k = self._feature_map(k)

        # Bidirectional scan.
        output = self._scan(q, k, v, reverse=False) + self._scan(q, k, v, reverse=True)

        # Merge heads and normalize.
        output = rearrange(output, 'b h n d -> b n (h d)')
        output = self.norm(output)

        # 2D locality branch; this rearrange also checks that N == h * w.
        local_feat = rearrange(self.local_gate(x), 'b (h w) d -> b d h w', h=h, w=w)
        local_feat = self.local_conv(local_feat)
        local_feat = rearrange(local_feat, 'b d h w -> b (h w) d')

        # Gated output.
        g = torch.sigmoid(self.gate(x))
        output = g * output * torch.sigmoid(local_feat)

        return self.dropout(self.out_proj(output))
271
+
272
+
273
class GLDBlock(nn.Module):
    """
    Gated Linear Diffusion block.

    Three residual sub-layers, applied in order:
        1. Modulated GatedLinearAttention (spatial mixing).
        2. Single-head cross-attention to text tokens, only when ``text_ctx``
           is given; scaled by ``tanh(cross_gate)`` which starts at zero.
        3. Modulated SwiGLU feed-forward (channel mixing).

    The attention and feed-forward sub-layers each receive a shift, scale and
    gate produced from ``cond`` by ``adaLN_modulation``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: int = 32,
        ffn_mult: float = 2.67,
        dropout: float = 0.0,
        cond_dim: int = 256,
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.norm2 = RMSNorm(dim)

        self.attn = GatedLinearAttention(dim, num_heads, head_dim, dropout)
        self.ffn = SwiGLU(dim, int(dim * ffn_mult), dropout)

        # Produces (shift, scale, gate) for the attention and FFN sub-layers.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, 6 * dim, bias=False),
        )

        # Cross-attention to text tokens.
        self.cross_norm = RMSNorm(dim)
        self.cross_q = nn.Linear(dim, dim, bias=False)
        self.cross_kv = nn.Linear(cond_dim, 2 * dim, bias=False)
        self.cross_out = nn.Linear(dim, dim, bias=False)
        self.cross_gate = nn.Parameter(torch.zeros(1))  # zero-init residual gate

    def _cross_attend(self, x: torch.Tensor, text_ctx: torch.Tensor) -> torch.Tensor:
        """Add the gated single-head softmax attention over ``text_ctx`` to ``x``."""
        queries = self.cross_q(self.cross_norm(x))
        keys, values = self.cross_kv(text_ctx).chunk(2, dim=-1)

        inv_sqrt_dim = queries.shape[-1] ** -0.5
        scores = torch.bmm(queries, keys.transpose(-2, -1)) * inv_sqrt_dim
        attended = torch.bmm(F.softmax(scores, dim=-1), values)
        return x + torch.tanh(self.cross_gate) * self.cross_out(attended)

    def forward(
        self,
        x: torch.Tensor,  # [B, N, D]
        cond: torch.Tensor,  # [B, cond_dim] - timestep + global condition
        text_ctx: Optional[torch.Tensor] = None,  # [B, T, cond_dim] - text tokens
        h: int = 32,
        w: int = 32,
    ) -> torch.Tensor:
        # Unpacking doubles as a rank-3 check on the input.
        batch, tokens, channels = x.shape

        # One [B, 1, D] tensor per modulation term, broadcast over tokens.
        shift1, scale1, gate1, shift2, scale2, gate2 = (
            term.unsqueeze(1) for term in self.adaLN_modulation(cond).chunk(6, dim=-1)
        )

        # Spatial mixing.
        attn_in = self.norm1(x) * (1 + scale1) + shift1
        x = x + gate1 * self.attn(attn_in, h, w)

        # Text conditioning (skipped when no text tokens are supplied).
        if text_ctx is not None:
            x = self._cross_attend(x, text_ctx)

        # Channel mixing.
        ffn_in = self.norm2(x) * (1 + scale2) + shift2
        x = x + gate2 * self.ffn(ffn_in)

        return x
354
+
355
+
356
+ # ============================================================================
357
+ # Recursive Latent Refinement (RLR) Core - THE KEY INNOVATION
358
+ # ============================================================================
359
+
360
class RecursiveLatentCore(nn.Module):
    """
    Recursive Latent Refinement (RLR) core: a small stack of shared GLD blocks
    applied repeatedly to predict the rectified-flow velocity.

    Schedule (identical in training and inference):

        z_abstract = mean_over_tokens(z)
        for j in range(T_outer):
            z_abstract += tanh(abstract_gate) * abstract_update([norm(z), mean(z)])
            for i in range(T_inner):
                cond = cond_base + recursion_embed[step]
                z = z + 0.5 * (blocks(z + z_abstract, cond) - z)   # damped update

    The same ``num_blocks`` blocks are reused at every step, so the effective
    depth is ``T_outer * T_inner * num_blocks`` block applications.

    Training with ``use_ift_training=True`` runs the first ``T_outer - 1`` outer
    iterations under ``torch.no_grad()`` and keeps the autograd graph only for
    the final outer iteration, so activation memory does not grow with
    ``T_outer``. Forward values match the inference path (dropout aside).

    NOTE(review): in that mode with ``T_outer > 1`` the latent entering the
    final outer iteration carries no graph, so ``input_proj`` receives no
    gradient. This was already the case before the schedule fix; confirm it is
    intended.

    Args:
        dim: Token feature size (must equal the channel count of ``z_t``).
        cond_dim: Size of the conditioning vector and of text token features.
        num_blocks: Number of shared GLD blocks.
        num_heads, head_dim, ffn_mult, dropout: Forwarded to each GLDBlock.
        T_inner: Block-stack applications per outer iteration.
        T_outer: Number of outer iterations (abstract-state updates).
        use_ift_training: Enable the truncated-gradient training path.
    """

    def __init__(
        self,
        dim: int = 384,
        cond_dim: int = 256,
        num_blocks: int = 4,
        num_heads: int = 6,
        head_dim: int = 64,
        T_inner: int = 4,
        T_outer: int = 2,
        ffn_mult: float = 2.67,
        dropout: float = 0.0,
        use_ift_training: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.cond_dim = cond_dim
        self.num_blocks = num_blocks
        self.T_inner = T_inner
        self.T_outer = T_outer
        self.use_ift_training = use_ift_training

        # The shared GLD blocks (applied recursively).
        self.blocks = nn.ModuleList([
            GLDBlock(
                dim=dim,
                num_heads=num_heads,
                head_dim=head_dim,
                ffn_mult=ffn_mult,
                dropout=dropout,
                cond_dim=cond_dim,
            )
            for _ in range(num_blocks)
        ])

        # Abstract ("slow") state updater, run once per outer iteration.
        self.abstract_norm = RMSNorm(dim)
        self.abstract_update = nn.Sequential(
            nn.Linear(dim * 2, dim, bias=False),
            nn.SiLU(),
            nn.Linear(dim, dim, bias=False),
        )
        self.abstract_gate = nn.Parameter(torch.zeros(1))  # zero-init

        # Input projection.
        self.input_proj = nn.Linear(dim, dim, bias=False)

        # Timestep embedding (input is a 256-wide sinusoidal embedding).
        self.time_embed = nn.Sequential(
            nn.Linear(256, cond_dim),
            nn.SiLU(),
            nn.Linear(cond_dim, cond_dim),
        )

        # Output projection (predicts velocity v for rectified flow).
        self.out_norm = RMSNorm(dim)
        self.out_proj = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.SiLU(),
            nn.Linear(dim, dim, bias=False),
        )

        # One embedding per recursion step (indices 0 .. T_outer*T_inner - 1 are used).
        self.recursion_embed = nn.Embedding(T_outer * T_inner + 1, cond_dim)

        # 2D positional encoding. Not applied anywhere in this class; kept so
        # the module's state_dict layout ('rope.freqs') stays unchanged.
        self.rope = RotaryPositionEncoding2D(head_dim)

    def _sinusoidal_embedding(self, t: torch.Tensor, dim: int = 256) -> torch.Tensor:
        """Sinusoidal timestep embedding of shape ``[B, 2 * (dim // 2)]``."""
        half_dim = dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t.unsqueeze(-1) * emb.unsqueeze(0)
        return torch.cat([emb.sin(), emb.cos()], dim=-1)

    def _apply_blocks(self, z, cond, text_ctx, h, w):
        """Apply all GLD blocks once, in order."""
        for block in self.blocks:
            z = block(z, cond, text_ctx, h, w)
        return z

    def _outer_step(self, z, z_abstract, cond_base, text_ctx, h, w, step_idx):
        """Run one outer iteration: one abstract update, then T_inner block passes.

        Returns the updated ``(z, z_abstract, step_idx)``.
        """
        # Abstract state update from the normalized tokens and their mean.
        z_pooled = z.mean(dim=1, keepdim=True).expand_as(z)
        abstract_input = torch.cat([self.abstract_norm(z), z_pooled], dim=-1)
        z_abstract = z_abstract + torch.tanh(self.abstract_gate) * self.abstract_update(abstract_input)

        for _ in range(self.T_inner):
            # Tell the blocks which recursion step they are on.
            rec_emb = self.recursion_embed(
                torch.tensor([step_idx], device=z.device)
            ).expand(z.shape[0], -1)
            cond = cond_base + rec_emb

            # Damped update: move z halfway towards the block output.
            z_input = z + z_abstract
            z = z + (self._apply_blocks(z_input, cond, text_ctx, h, w) - z) * 0.5

            step_idx += 1

        return z, z_abstract, step_idx

    def _recursive_refinement(self, z, cond_base, text_ctx, h, w, no_grad_outer: int = 0):
        """
        Full recursive refinement loop (T_outer * T_inner block-stack passes).

        Args:
            z: Tokens of shape [B, N, dim].
            cond_base: Conditioning vector [B, cond_dim].
            text_ctx: Optional text tokens forwarded to the blocks.
            h, w: Spatial size of the token grid.
            no_grad_outer: Number of leading outer iterations executed under
                ``torch.no_grad()``. The default 0 keeps the full graph.

        Returns:
            The refined tokens, same shape as ``z``.
        """
        z_abstract = z.mean(dim=1, keepdim=True).expand_as(z)  # initial abstract state

        step_idx = 0
        for j in range(self.T_outer):
            if j < no_grad_outer:
                # Warm-up iteration: values only, no autograd graph.
                with torch.no_grad():
                    z, z_abstract, step_idx = self._outer_step(
                        z, z_abstract, cond_base, text_ctx, h, w, step_idx
                    )
            else:
                z, z_abstract, step_idx = self._outer_step(
                    z, z_abstract, cond_base, text_ctx, h, w, step_idx
                )

        return z

    def forward(
        self,
        z_t: torch.Tensor,  # [B, C, H, W] - noisy latent
        t: torch.Tensor,  # [B] - timestep (0 to 1)
        text_emb: Optional[torch.Tensor] = None,  # [B, T, cond_dim] - text tokens
        text_global: Optional[torch.Tensor] = None,  # [B, cond_dim] - global text embedding
        image_cond: Optional[torch.Tensor] = None,  # [B, C, H, W] - for editing tasks
    ) -> torch.Tensor:
        """
        Predict the velocity v_theta(z_t, t, c).

        For rectified flow: z_t = (1-t) * z_0 + t * epsilon
        Target: v = epsilon - z_0

        Returns:
            Tensor with the same [B, C, H, W] shape as ``z_t``.
        """
        B, C, H, W = z_t.shape

        # Flatten spatial dims.
        z = rearrange(z_t, 'b c h w -> b (h w) c')

        # Editing: the condition image is ADDED token-wise before projection.
        if image_cond is not None:
            z = z + rearrange(image_cond, 'b c h w -> b (h w) c')

        z = self.input_proj(z)

        # Build conditioning: timestep embedding plus optional global text embedding.
        cond = self.time_embed(self._sinusoidal_embedding(t))  # [B, cond_dim]
        if text_global is not None:
            cond = cond + text_global

        # Previously the training path called the full refinement T_outer times
        # (T_outer - 1 without grad, then once with grad), so training ran
        # T_outer times more steps than inference. Now a single refinement is
        # run in every mode; only the gradient truncation differs.
        truncate_grad = self.training and self.use_ift_training
        warmup_outer = max(self.T_outer - 1, 0) if truncate_grad else 0
        z = self._recursive_refinement(z, cond, text_emb, H, W, no_grad_outer=warmup_outer)

        # Output projection.
        v = self.out_proj(self.out_norm(z))

        # Reshape back to spatial.
        return rearrange(v, 'b (h w) c -> b c h w', h=H, w=W)
565
+
566
+
567
+ # ============================================================================
568
+ # Compact VAE (Tiny Decoder inspired by SnapGen)
569
+ # ============================================================================
570
+
571
class TinyResBlock(nn.Module):
    """Small pre-activation residual block built from depthwise-separable convs.

    Each of the two stages is GroupNorm -> SiLU -> DepthwiseSeparableConv2d.
    The skip path is the identity when the channel count is unchanged and a
    bias-free 1x1 convolution otherwise.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count; defaults to ``in_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int = None):
        super().__init__()
        if not out_channels:
            out_channels = in_channels

        # Group count is capped at 8 (and at the channel count for narrow layers).
        self.norm1 = nn.GroupNorm(min(8, in_channels), in_channels)
        self.conv1 = DepthwiseSeparableConv2d(in_channels, out_channels, 3)
        self.norm2 = nn.GroupNorm(min(8, out_channels), out_channels)
        self.conv2 = DepthwiseSeparableConv2d(out_channels, out_channels, 3)

        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        else:
            self.skip = nn.Identity()

    def forward(self, x):
        branch = self.conv1(F.silu(self.norm1(x)))
        branch = self.conv2(F.silu(self.norm2(branch)))
        return self.skip(x) + branch
586
+
587
+
588
class CompactEncoder(nn.Module):
    """
    Compact image encoder: image -> (mean, logvar) of the latent distribution.

    A 3x3 stem is followed by four stages. Each stage runs its residual blocks
    and then halves the resolution with a stride-2 convolution, giving 16x
    spatial compression overall (e.g. 256 -> 128 -> 64 -> 32 -> 16).

    Args:
        in_channels: Image channels.
        latent_channels: Channels of the latent; the head emits twice as many
            (mean and log-variance).
        base_channels: Width of the first stage; stage widths are
            ``[1, 2, 4, 4] * base_channels``.
        num_res_blocks: Residual blocks per stage (at least one is always built).
    """

    def __init__(
        self,
        in_channels: int = 3,
        latent_channels: int = 32,
        base_channels: int = 64,
        num_res_blocks: int = 2,
    ):
        super().__init__()
        stage_widths = [base_channels, base_channels * 2, base_channels * 4, base_channels * 4]

        self.stem = nn.Conv2d(in_channels, stage_widths[0], 3, padding=1, bias=False)

        self.downs = nn.ModuleList()
        prev_width = stage_widths[0]
        for width in stage_widths:
            # The first block of a stage performs the channel transition.
            res_blocks = nn.ModuleList(
                [TinyResBlock(prev_width, width)]
                + [TinyResBlock(width, width) for _ in range(num_res_blocks - 1)]
            )
            self.downs.append(nn.ModuleDict({
                'blocks': res_blocks,
                # Stride-2 convolution halves height and width.
                'down': nn.Conv2d(width, width, 4, stride=2, padding=1, bias=False),
            }))
            prev_width = width

        # Head producing mean and log-variance stacked along the channel axis.
        self.to_latent = nn.Sequential(
            nn.GroupNorm(8, prev_width),
            nn.SiLU(),
            nn.Conv2d(prev_width, latent_channels * 2, 1, bias=False),
        )

    def forward(self, x):
        """Return ``(mean, logvar)``; ``logvar`` is clamped to [-30, 20]."""
        feats = self.stem(x)
        for stage in self.downs:
            for res_block in stage['blocks']:
                feats = res_block(feats)
            feats = stage['down'](feats)

        mean, logvar = self.to_latent(feats).chunk(2, dim=1)
        return mean, logvar.clamp(-30.0, 20.0)
643
+
644
+
645
class TinyDecoder(nn.Module):
    """
    Attention-free decoder: latent -> image in [-1, 1].

    A 1x1 convolution lifts the latent to the first stage width. Each of the
    four stages runs its residual blocks at the incoming width, then doubles
    the resolution (nearest-neighbour upsample) and changes the width with a
    depthwise-separable convolution. A GroupNorm/SiLU/3x3-conv/tanh head
    produces the image.

    Args:
        latent_channels: Channels of the input latent.
        out_channels: Image channels.
        base_channels: Stage widths are ``[2, 2, 1, 0.5] * base_channels``.
        num_res_blocks: Residual blocks per stage.
    """

    def __init__(
        self,
        latent_channels: int = 32,
        out_channels: int = 3,
        base_channels: int = 128,
        num_res_blocks: int = 2,
    ):
        super().__init__()
        stage_widths = [base_channels * 2, base_channels * 2, base_channels, base_channels // 2]

        self.from_latent = nn.Conv2d(latent_channels, stage_widths[0], 1, bias=False)

        self.ups = nn.ModuleList()
        width = stage_widths[0]
        for next_width in stage_widths:
            res_blocks = nn.ModuleList(
                [TinyResBlock(width, width) for _ in range(num_res_blocks)]
            )
            # 2x upsample followed by the channel transition.
            upsample = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                DepthwiseSeparableConv2d(width, next_width, 3),
            )
            self.ups.append(nn.ModuleDict({
                'blocks': res_blocks,
                'up': upsample,
            }))
            width = next_width

        self.to_image = nn.Sequential(
            nn.GroupNorm(min(8, width), width),
            nn.SiLU(),
            nn.Conv2d(width, out_channels, 3, padding=1),
            nn.Tanh(),  # output range [-1, 1]
        )

    def forward(self, z):
        feats = self.from_latent(z)
        for stage in self.ups:
            for res_block in stage['blocks']:
                feats = res_block(feats)
            feats = stage['up'](feats)
        return self.to_image(feats)
696
+
697
+
698
class CompactVAE(nn.Module):
    """
    VAE pairing ``CompactEncoder`` with ``TinyDecoder`` (16x spatial compression).

    Args:
        in_channels: Image channels (also the decoder's output channels).
        latent_channels: Channels of the latent.
        encoder_base_ch: ``base_channels`` of the encoder.
        decoder_base_ch: ``base_channels`` of the decoder.
    """

    def __init__(
        self,
        in_channels: int = 3,
        latent_channels: int = 32,
        encoder_base_ch: int = 64,
        decoder_base_ch: int = 128,
    ):
        super().__init__()
        self.encoder = CompactEncoder(in_channels, latent_channels, encoder_base_ch)
        self.decoder = TinyDecoder(latent_channels, in_channels, decoder_base_ch)
        self.latent_channels = latent_channels

    def encode(self, x):
        """Return ``(z, mean, logvar)``.

        In training mode ``z`` is a reparameterized sample
        ``mean + eps * exp(0.5 * logvar)``; in eval mode ``z`` is the mean.
        """
        mean, logvar = self.encoder(x)
        if not self.training:
            return mean, mean, logvar

        std = torch.exp(0.5 * logvar)
        z = mean + torch.randn_like(std) * std
        return z, mean, logvar

    def decode(self, z):
        """Decode a latent back to image space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode then decode; returns ``(reconstruction, mean, logvar)``."""
        z, mean, logvar = self.encode(x)
        return self.decode(z), mean, logvar
732
+
733
+
734
+ # ============================================================================
735
+ # Text Conditioner (Lightweight)
736
+ # ============================================================================
737
+
738
class SimpleTextEncoder(nn.Module):
    """
    Lightweight text encoder for the standalone prototype: learned token and
    position embeddings, a small pre-norm Transformer encoder, RMSNorm, and a
    pooled global embedding.

    Args:
        vocab_size: Size of the token embedding table.
        max_length: Size of the position embedding table, i.e. the longest
            sequence ``forward`` accepts.
        dim: Feature size of token features and of the global embedding.
        num_layers: Number of Transformer encoder layers.
        num_heads: Attention heads per layer.
    """

    def __init__(
        self,
        vocab_size: int = 32000,
        max_length: int = 77,
        dim: int = 256,
        num_layers: int = 4,
        num_heads: int = 4,
    ):
        super().__init__()
        self.dim = dim
        self.token_embed = nn.Embedding(vocab_size, dim)
        self.pos_embed = nn.Embedding(max_length, dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=num_heads, dim_feedforward=dim * 4,
            dropout=0.1, activation='gelu', batch_first=True, norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = RMSNorm(dim)

        # Projection applied to the pooled embedding.
        self.global_proj = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, token_ids, attention_mask=None):
        """Encode token ids.

        Args:
            token_ids: Integer tensor of shape [B, T].
            attention_mask: Optional [B, T] mask, non-zero for real tokens and
                zero for padding.

        Returns:
            ``(tokens, global_emb)`` with shapes [B, T, dim] and [B, dim].
            Features at padded positions are still present in ``tokens``.

        Raises:
            ValueError: If ``T`` exceeds the position embedding table.
        """
        B, T = token_ids.shape

        # Fail early with a readable message instead of an out-of-range
        # embedding lookup.
        max_positions = self.pos_embed.num_embeddings
        if T > max_positions:
            raise ValueError(
                f"Sequence length {T} exceeds the maximum supported length {max_positions}"
            )

        pos_ids = torch.arange(T, device=token_ids.device).unsqueeze(0).expand(B, -1)
        x = self.token_embed(token_ids) + self.pos_embed(pos_ids)

        # nn.TransformerEncoder expects True at positions to IGNORE.
        if attention_mask is not None:
            src_key_padding_mask = ~attention_mask.bool()
        else:
            src_key_padding_mask = None

        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
        x = self.norm(x)

        # Global embedding: mean over non-padded tokens.
        if attention_mask is not None:
            # Keep the mask in the activation dtype so the pooled embedding is
            # not promoted to float32 when the model runs in half precision
            # (which would make global_proj fail on mismatched dtypes).
            mask = attention_mask.unsqueeze(-1).to(x.dtype)
            global_emb = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
        else:
            global_emb = x.mean(dim=1)

        global_emb = self.global_proj(global_emb)

        return x, global_emb  # [B, T, D], [B, D]
798
+
799
+
800
+ # ============================================================================
801
+ # Full LRF Model
802
+ # ============================================================================
803
+
804
class LatentRecurrentFlow(nn.Module):
    """
    LatentRecurrentFlow (LRF) - complete model.

    Bundles three sub-modules plus a learnable latent scale:
        1. ``vae``: CompactVAE for image encoding/decoding
        2. ``text_encoder``: SimpleTextEncoder for text conditioning
        3. ``core``: RecursiveLatentCore for velocity prediction

    Training modes (driven by the external training pipeline):
        - 'vae': Train only the VAE
        - 'denoise': Train only the denoising core (freeze VAE)
        - 'e2e': End-to-end fine-tuning
        - 'distill': Consistency distillation from teacher

    Args:
        config: Configuration dictionary; ``default_config()`` is used when it
            is ``None`` or empty. ``latent_channels``, ``cond_dim`` and
            ``num_blocks`` are required keys, the rest have fallbacks.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__()

        if not config:
            config = self.default_config()
        self.config = config

        latent_channels = config['latent_channels']
        cond_dim = config['cond_dim']

        # VAE
        self.vae = CompactVAE(
            in_channels=3,
            latent_channels=latent_channels,
            encoder_base_ch=config.get('encoder_base_ch', 64),
            decoder_base_ch=config.get('decoder_base_ch', 128),
        )

        # Text encoder
        self.text_encoder = SimpleTextEncoder(
            vocab_size=config.get('vocab_size', 32000),
            max_length=config.get('max_text_length', 77),
            dim=cond_dim,
            num_layers=config.get('text_layers', 4),
            num_heads=config.get('text_heads', 4),
        )

        # Denoising core; its token width equals the latent channel count.
        self.core = RecursiveLatentCore(
            dim=latent_channels,
            cond_dim=cond_dim,
            num_blocks=config['num_blocks'],
            num_heads=config.get('num_heads', 6),
            head_dim=config.get('head_dim', 64),
            T_inner=config.get('T_inner', 4),
            T_outer=config.get('T_outer', 2),
            ffn_mult=config.get('ffn_mult', 2.67),
            dropout=config.get('dropout', 0.0),
            use_ift_training=config.get('use_ift', True),
        )

        # Learnable scalar applied to latents on encode and removed on decode.
        self.latent_scale = nn.Parameter(torch.tensor(1.0))

    @staticmethod
    def default_config():
        """Default config targeting ~50M params, trainable on 16GB."""
        return {
            'latent_channels': 32,
            'cond_dim': 256,
            'num_blocks': 4,
            'num_heads': 4,
            'head_dim': 64,
            'T_inner': 4,
            'T_outer': 2,
            'ffn_mult': 2.67,
            'dropout': 0.0,
            'use_ift': True,
            'encoder_base_ch': 64,
            'decoder_base_ch': 128,
            'vocab_size': 32000,
            'max_text_length': 77,
            'text_layers': 4,
            'text_heads': 4,
        }

    @staticmethod
    def tiny_config():
        """Tiny config for quick testing."""
        return {
            'latent_channels': 16,
            'cond_dim': 128,
            'num_blocks': 2,
            'num_heads': 2,
            'head_dim': 32,
            'T_inner': 2,
            'T_outer': 1,
            'ffn_mult': 2.0,
            'dropout': 0.0,
            'use_ift': False,
            'encoder_base_ch': 32,
            'decoder_base_ch': 64,
            'vocab_size': 32000,
            'max_text_length': 77,
            'text_layers': 2,
            'text_heads': 2,
        }

    def encode_image(self, x):
        """Encode an image; returns ``(z * latent_scale, mean, logvar)``."""
        z, mean, logvar = self.vae.encode(x)
        scaled = z * self.latent_scale
        return scaled, mean, logvar

    def decode_latent(self, z):
        """Undo the latent scale and decode to an image."""
        unscaled = z / self.latent_scale
        return self.vae.decode(unscaled)

    def encode_text(self, token_ids, attention_mask=None):
        """Run the text encoder; returns its ``(tokens, global_emb)`` pair."""
        return self.text_encoder(token_ids, attention_mask)

    def predict_velocity(self, z_t, t, text_emb=None, text_global=None, image_cond=None):
        """Predict the rectified-flow velocity with the recursive core."""
        return self.core(z_t, t, text_emb, text_global, image_cond)

    def get_param_groups(self):
        """Return parameter lists keyed by component, for staged training."""
        groups = {
            'vae_encoder': list(self.vae.encoder.parameters()),
            'vae_decoder': list(self.vae.decoder.parameters()),
            'text_encoder': list(self.text_encoder.parameters()),
            'core': list(self.core.parameters()),
        }
        groups['latent_scale'] = [self.latent_scale]
        return groups

    def count_parameters(self):
        """Return parameter counts per component plus a ``'total'`` entry."""
        components = {
            'vae_encoder': self.vae.encoder,
            'vae_decoder': self.vae.decoder,
            'text_encoder': self.text_encoder,
            'core': self.core,
        }
        counts = {
            name: sum(p.numel() for p in module.parameters())
            for name, module in components.items()
        }
        counts['latent_scale'] = 1
        # Sum is taken before 'total' itself is inserted.
        counts['total'] = sum(counts.values())
        return counts

    def forward(self, x=None, token_ids=None, attention_mask=None, **kwargs):
        """Not supported: training goes through the external pipeline functions."""
        raise NotImplementedError(
            "Use the training pipeline functions instead of calling forward() directly. "
            "See LRFTrainer for VAE training, denoiser training, and distillation."
        )