asdf98 committed on
Commit a55971a · verified · 1 Parent(s): 1470d5b

Add microforge_notebook.ipynb

Files changed (1)
  1. microforge_notebook.ipynb +837 -0
microforge_notebook.ipynb ADDED
@@ -0,0 +1,837 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🔨 MicroForge: A Novel Mobile-First Image Generation Architecture\n",
+ "\n",
+ "**A genuinely new architecture combining Recurrent Latent Planning, SSM-Conv Hybrid Backbone, and Deep Compression VAE**\n",
+ "\n",
+ "This notebook demonstrates the complete MicroForge architecture:\n",
+ "- Module-by-module construction and testing\n",
+ "- End-to-end training pipeline (VAE + backbone + planner)\n",
+ "- Inference for text-to-image generation\n",
+ "- Memory and compute profiling\n",
+ "- Staged training curriculum design\n",
+ "\n",
+ "## Architecture Overview\n",
+ "\n",
+ "```\n",
+ "┌────────────────────────────────────────────────────┐\n",
+ "│                MicroForge Pipeline                 │\n",
+ "├────────────────────────────────────────────────────┤\n",
+ "│                                                    │\n",
+ "│  Text ──→ [Text Encoder] ──→ text_emb, text_pooled │\n",
+ "│                                      │             │\n",
+ "│                                      ▼             │\n",
+ "│  Noise ──→ [Recurrent Latent Planner] ◄── plan_t-1 │\n",
+ "│              │  READ:   plan ◄── z_t               │\n",
+ "│              │  REASON: plan self-attention        │\n",
+ "│              │  OUTPUT: planner_tokens             │\n",
+ "│              ▼                                     │\n",
+ "│  z_t ──→ [SSM-Conv Backbone] ◄── planner_tokens    │\n",
+ "│              │  Per-block:                         │\n",
+ "│              │    AdaLN-Group conditioning         │\n",
+ "│              │    Bidirectional SSM (zigzag scan)  │\n",
+ "│              │    Cross-attention to text+plan     │\n",
+ "│              │    FFN (expansion=3)                │\n",
+ "│              │  Global: Shared MQA attention       │\n",
+ "│              ▼                                     │\n",
+ "│  v_pred ──→ [Euler ODE Step] ──→ z_{t-1}           │\n",
+ "│                                                    │\n",
+ "│  z_0 ──→ [DC-VAE Decoder] ──→ Image                │\n",
+ "│                                                    │\n",
+ "└────────────────────────────────────────────────────┘\n",
+ "```\n",
+ "\n",
+ "## Key Innovations\n",
+ "\n",
+ "1. **Recurrent Latent Planner (RLP)**: A compact set of 32 latent tokens that iteratively reason about the image before committing to pixel changes. Inspired by RIN but adapted for diffusion.\n",
+ "\n",
+ "2. **SSM-Conv Hybrid Backbone**: Bidirectional state-space model with zigzag scanning + local DWConv + one globally-shared attention block. O(N) complexity vs O(N²) for transformers.\n",
+ "\n",
+ "3. **Deep Compression VAE**: 32× spatial compression with residual space-to-channel shortcuts. 512px → 16×16×32 latent (only 256 spatial tokens).\n",
+ "\n",
+ "4. **Editing-Ready Architecture**: DreamLite-style spatial concatenation for unified generation + editing with zero extra parameters."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Setup & Installation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install dependencies\n",
+ "!pip install -q torch torchvision einops timm matplotlib"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import time\n",
+ "import os\n",
+ "\n",
+ "# Auto-detect device\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "print(f'Using device: {device}')\n",
+ "if device == 'cuda':\n",
+ "    print(f'GPU: {torch.cuda.get_device_name()}')\n",
+ "    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Architecture Module Tests"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from microforge.vae import MicroForgeVAE\n",
+ "from microforge.backbone import MicroForgeBackbone\n",
+ "from microforge.planner import RecurrentLatentPlanner\n",
+ "from microforge.pipeline import MicroForgePipeline, SimpleTextEncoder\n",
+ "from microforge.training import MicroForgeTrainer, FlowMatchingScheduler, MicroForgeLoss\n",
+ "\n",
+ "print('All modules imported successfully!')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.1 Deep Compression VAE\n",
+ "\n",
+ "The VAE compresses images by 32× spatially using residual space-to-channel shortcuts (DC-AE technique).\n",
+ "\n",
+ "- **Input**: `[B, 3, H, W]` images\n",
+ "- **Latent**: `[B, C_latent, H/32, W/32]` — for 256px: `[B, 16, 8, 8]` (tiny) or `[B, 32, 8, 8]` (small)\n",
+ "- **Key**: Space-to-channel rearrangement as non-parametric skip connection"
+ ]
+ },
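+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Aside:* to make the shortcut concrete, here is a minimal, self-contained sketch of a DC-AE-style downsampling block. It is illustrative only — `SpaceToChannelDown` is a hypothetical name, not part of `microforge` — and it assumes `4*c_in` is divisible by `c_out`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from einops import rearrange\n",
+ "\n",
+ "class SpaceToChannelDown(nn.Module):\n",
+ "    # Hypothetical sketch: a learned strided conv plus a non-parametric\n",
+ "    # space-to-channel shortcut whose channel groups are averaged down\n",
+ "    # to the target width (the DC-AE residual trick).\n",
+ "    def __init__(self, c_in, c_out):\n",
+ "        super().__init__()\n",
+ "        self.conv = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1)\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        y = self.conv(x)\n",
+ "        # space-to-channel: [B, C, H, W] -> [B, 4C, H/2, W/2]\n",
+ "        s = rearrange(x, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1=2, p2=2)\n",
+ "        # average channel groups down to c_out (no parameters)\n",
+ "        s = s.reshape(s.shape[0], y.shape[1], -1, s.shape[2], s.shape[3]).mean(dim=2)\n",
+ "        return y + s\n",
+ "\n",
+ "blk = SpaceToChannelDown(32, 64)\n",
+ "print(blk(torch.randn(1, 32, 16, 16)).shape)  # -> torch.Size([1, 64, 8, 8])"
+ ]
+ },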
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Test each VAE configuration\n",
+ "for config in ['tiny', 'small', 'base']:\n",
+ "    vae = MicroForgeVAE(config=config)\n",
+ "    params = sum(p.numel() for p in vae.parameters())\n",
+ "\n",
+ "    x = torch.randn(1, 3, 256, 256)\n",
+ "    x_recon, mu, logvar = vae(x)\n",
+ "\n",
+ "    print(f'{config:>5}: {params:>12,} params | '\n",
+ "          f'{params*4/1e6:>6.1f} MB fp32 | '\n",
+ "          f'{params*2/1e6:>6.1f} MB fp16 | '\n",
+ "          f'latent: {mu.shape}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.2 SSM-Conv Hybrid Backbone\n",
+ "\n",
+ "The denoising backbone replaces quadratic attention with:\n",
+ "- **Bidirectional SSM** with zigzag scanning (O(N) complexity)\n",
+ "- **Local DWConv** for spatial feature enhancement\n",
+ "- **One globally-shared MQA attention block** (from DiMSUM)\n",
+ "- **AdaLN-Group conditioning** (46% fewer params than full adaLN)"
+ ]
+ },
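+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Aside:* a minimal sketch of the zigzag (boustrophedon) scan order — a row-major flattening with every other row reversed, so consecutive tokens in the 1D SSM scan stay spatially adjacent. Hypothetical helper code, not the `microforge` implementation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "def zigzag_indices(h, w):\n",
+ "    # Row-major order, but every other row reversed, so neighbouring\n",
+ "    # positions in the 1D scan are also neighbours in 2D.\n",
+ "    idx = torch.arange(h * w).reshape(h, w)\n",
+ "    idx[1::2] = idx[1::2].flip(-1)\n",
+ "    return idx.flatten()\n",
+ "\n",
+ "h, w = 4, 4\n",
+ "order = zigzag_indices(h, w)\n",
+ "print(order.reshape(h, w))\n",
+ "\n",
+ "tokens = torch.randn(1, h * w, 8)    # [B, N, D] flattened latent tokens\n",
+ "scanned = tokens[:, order]           # forward zigzag scan\n",
+ "restored = torch.empty_like(scanned)\n",
+ "restored[:, order] = scanned         # invert the permutation\n",
+ "assert torch.equal(restored, tokens)"
+ ]
+ },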
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Test each backbone configuration\n",
+ "for config_name in ['tiny', 'small', 'base']:\n",
+ "    lc = 16 if config_name == 'tiny' else 32\n",
+ "    backbone = MicroForgeBackbone(latent_channels=lc, config=config_name)\n",
+ "    params = sum(p.numel() for p in backbone.parameters())\n",
+ "\n",
+ "    z = torch.randn(1, lc, 8, 8)\n",
+ "    t = torch.rand(1)\n",
+ "    text_emb = torch.randn(1, 10, 768)\n",
+ "    text_pooled = torch.randn(1, 768)\n",
+ "\n",
+ "    start = time.time()\n",
+ "    v = backbone(z, t, text_emb, text_pooled)\n",
+ "    elapsed = time.time() - start\n",
+ "\n",
+ "    print(f'{config_name:>5}: {params:>12,} params | '\n",
+ "          f'{params*4/1e6:>6.1f} MB fp32 | '\n",
+ "          f'{params*2/1e6:>6.1f} MB fp16 | '\n",
+ "          f'latency: {elapsed*1000:.0f}ms')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 2.3 Recurrent Latent Planner (Novel Component)\n",
+ "\n",
+ "The RLP is our key innovation — a \"reasoning core\" that maintains persistent plan tokens:\n",
+ "\n",
+ "```\n",
+ "plan_0 = init(text)\n",
+ "for each denoising step:\n",
+ "    plan = READ(plan, image_tokens)    # absorb image info\n",
+ "    plan = REASON(plan)                # self-attention over plan\n",
+ "    output = PROJECT(plan)             # inject into backbone\n",
+ "    z_{t-1} = backbone(z_t, output)    # guided denoising\n",
+ "```\n",
+ "\n",
+ "Only 32 plan tokens × D dims = negligible memory overhead."
+ ]
+ },
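+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Aside:* before testing the real planner below, here is a toy, self-contained version of the READ → REASON → PROJECT loop built from plain `nn.MultiheadAttention`. `TinyPlanner` is a hypothetical illustration, not the `RecurrentLatentPlanner` internals."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "class TinyPlanner(nn.Module):\n",
+ "    # Toy sketch of the persistent-plan loop (hypothetical; the real\n",
+ "    # RecurrentLatentPlanner adds time conditioning and self-conditioning).\n",
+ "    def __init__(self, dim=64, n_plan=8, n_heads=4):\n",
+ "        super().__init__()\n",
+ "        self.plan0 = nn.Parameter(torch.randn(1, n_plan, dim) * 0.02)\n",
+ "        self.read = nn.MultiheadAttention(dim, n_heads, batch_first=True)\n",
+ "        self.reason = nn.MultiheadAttention(dim, n_heads, batch_first=True)\n",
+ "        self.project = nn.Linear(dim, dim)\n",
+ "\n",
+ "    def forward(self, plan, img_tokens):\n",
+ "        plan = plan + self.read(plan, img_tokens, img_tokens)[0]  # READ\n",
+ "        plan = plan + self.reason(plan, plan, plan)[0]            # REASON\n",
+ "        return plan, self.project(plan)                           # PROJECT\n",
+ "\n",
+ "p = TinyPlanner()\n",
+ "plan = p.plan0.expand(2, -1, -1)   # persistent plan state, batch of 2\n",
+ "img = torch.randn(2, 64, 64)       # 64 image tokens of dim 64\n",
+ "for step in range(3):              # plan carries over denoising steps\n",
+ "    plan, out = p(plan, img)\n",
+ "print(plan.shape, out.shape)"
+ ]
+ },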
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "planner = RecurrentLatentPlanner(num_plan_tokens=32, dim=384, text_dim=768, latent_channels=32)\n",
+ "params = sum(p.numel() for p in planner.parameters())\n",
+ "print(f'Planner: {params:,} params = {params*4/1e6:.1f} MB fp32')\n",
+ "print(f'Plan state size: {planner.get_plan_size_bytes()} bytes = {planner.get_plan_size_bytes()/1024:.1f} KB')\n",
+ "\n",
+ "# Test planner with self-conditioning (simulating multi-step)\n",
+ "text_pooled = torch.randn(1, 768)\n",
+ "plan = planner.initialize_plan(text_pooled, batch_size=1)\n",
+ "print(f'\\nInitial plan: {plan.shape}')\n",
+ "\n",
+ "# Simulate 3 denoising steps with plan carry-forward\n",
+ "for step in range(3):\n",
+ "    z = torch.randn(1, 32, 8, 8)\n",
+ "    img_tokens = z.reshape(1, 32, -1).permute(0, 2, 1)\n",
+ "    t_emb = torch.randn(1, 384)\n",
+ "\n",
+ "    plan, output = planner(img_tokens, plan, t_emb)\n",
+ "\n",
+ "    # Self-condition for next step\n",
+ "    plan = planner.initialize_plan(text_pooled, 1, prev_plan=plan)\n",
+ "    print(f'Step {step}: plan_norm={plan.norm():.2f}, output_norm={output.norm():.2f}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Full Pipeline Assembly"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Assemble full pipeline with tiny config (for fast testing)\n",
+ "vae = MicroForgeVAE(config='tiny')\n",
+ "backbone = MicroForgeBackbone(latent_channels=16, config='tiny')\n",
+ "planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)\n",
+ "text_encoder = SimpleTextEncoder(vocab_size=8192, embed_dim=768, num_layers=2)\n",
+ "\n",
+ "pipeline = MicroForgePipeline(vae, backbone, text_encoder, planner, device='cpu')\n",
+ "\n",
+ "# Parameter count\n",
+ "params = pipeline.count_parameters()\n",
+ "print('=== MicroForge Parameter Budget ===')\n",
+ "for name, count in params.items():\n",
+ "    print(f'  {name:>15}: {count:>12,} ({count*4/1e6:.1f} MB fp32, {count*2/1e6:.1f} MB fp16)')\n",
+ "\n",
+ "# Memory estimate\n",
+ "print('\\n=== Memory Estimates ===')\n",
+ "for res in [128, 256, 512]:\n",
+ "    mem = pipeline.get_memory_estimate(res, res)\n",
+ "    print(f'  {res}x{res}: ~{mem[\"estimated_inference_mb\"]:.0f} MB inference')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. End-to-End Inference Test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Generate a test image (random weights = noise, but validates full pipeline)\n",
+ "tokens = torch.randint(0, 8192, (1, 10))\n",
+ "\n",
+ "start = time.time()\n",
+ "with torch.no_grad():\n",
+ "    images = pipeline.text2img(\n",
+ "        tokens,\n",
+ "        height=128, width=128,\n",
+ "        num_steps=4,     # Few steps for speed\n",
+ "        cfg_scale=1.0,   # No CFG for untrained model\n",
+ "        seed=42\n",
+ "    )\n",
+ "elapsed = time.time() - start\n",
+ "\n",
+ "print(f'Generated {images.shape} in {elapsed:.2f}s')\n",
+ "print(f'Range: [{images.min():.2f}, {images.max():.2f}]')\n",
+ "\n",
+ "# Visualize\n",
+ "img = images[0].permute(1, 2, 0).cpu().numpy()\n",
+ "img = (img - img.min()) / (img.max() - img.min() + 1e-8)\n",
+ "\n",
+ "plt.figure(figsize=(4, 4))\n",
+ "plt.imshow(img)\n",
+ "plt.title('MicroForge Output (untrained, random weights)')\n",
+ "plt.axis('off')\n",
+ "plt.tight_layout()\n",
+ "plt.savefig('test_generation.png', dpi=100)\n",
+ "plt.show()\n",
+ "print('Saved to test_generation.png')"
+ ]
+ },
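+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Aside:* the `cfg_scale` knob above follows standard classifier-free guidance. With a trained model and `cfg_scale > 1`, each step would blend a conditional and an unconditional velocity prediction like this (dummy tensors and an illustrative guidance value, not pipeline internals):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# Standard classifier-free guidance applied to the velocity prediction.\n",
+ "v_uncond = torch.randn(1, 16, 8, 8)  # empty-prompt branch\n",
+ "v_cond = torch.randn(1, 16, 8, 8)    # text-conditioned branch\n",
+ "cfg_scale = 4.5                      # hypothetical guidance strength\n",
+ "v_guided = v_uncond + cfg_scale * (v_cond - v_uncond)\n",
+ "print(v_guided.shape)"
+ ]
+ },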
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Training Pipeline Demo\n",
+ "\n",
+ "### 5.1 Stage 1: VAE Training\n",
+ "\n",
+ "Train the VAE on synthetic data to verify the training loop.\n",
+ "In production, use ImageNet or similar with perceptual + adversarial losses."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Stage 1: VAE Training\n",
+ "vae_train = MicroForgeVAE(config='tiny').train()\n",
+ "vae_opt = torch.optim.AdamW(vae_train.parameters(), lr=1e-4, weight_decay=0.01)\n",
+ "loss_fn = MicroForgeLoss(lambda_kl=1e-6)\n",
+ "\n",
+ "vae_losses = []\n",
+ "print('=== Stage 1: VAE Training ===')\n",
+ "for step in range(50):\n",
+ "    # Synthetic stand-in data: random noise images (use real images in practice)\n",
+ "    images = torch.randn(4, 3, 128, 128) * 0.5\n",
+ "\n",
+ "    x_recon, mu, logvar = vae_train(images)\n",
+ "    losses = loss_fn.vae_loss(x_recon, images, mu, logvar)\n",
+ "\n",
+ "    vae_opt.zero_grad()\n",
+ "    losses['total'].backward()\n",
+ "    torch.nn.utils.clip_grad_norm_(vae_train.parameters(), 2.0)\n",
+ "    vae_opt.step()\n",
+ "\n",
+ "    vae_losses.append(losses['recon'].item())\n",
+ "    if step % 10 == 0:\n",
+ "        print(f'  Step {step:3d}: recon={losses[\"recon\"].item():.4f}, kl={losses[\"kl\"].item():.2f}')\n",
+ "\n",
+ "plt.figure(figsize=(8, 3))\n",
+ "plt.plot(vae_losses)\n",
+ "plt.xlabel('Step')\n",
+ "plt.ylabel('Reconstruction Loss')\n",
+ "plt.title('Stage 1: VAE Training')\n",
+ "plt.tight_layout()\n",
+ "plt.savefig('vae_training.png', dpi=100)\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 5.2 Stage 2: Backbone Flow Matching Training\n",
+ "\n",
+ "Train the SSM backbone with rectified flow matching.\n",
+ "VAE is frozen; backbone learns to predict velocity v(z_t, t)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Stage 2: Backbone Training with Flow Matching\n",
+ "vae_train.eval()\n",
+ "backbone_train = MicroForgeBackbone(latent_channels=16, config='tiny')\n",
+ "planner_train = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)\n",
+ "\n",
+ "trainer = MicroForgeTrainer(\n",
+ "    vae_train, backbone_train, planner_train,\n",
+ "    lr=1e-4, weight_decay=0.01, use_ema=True\n",
+ ")\n",
+ "\n",
+ "flow_losses = []\n",
+ "print('=== Stage 2: Backbone Flow Matching Training ===')\n",
+ "for step in range(100):\n",
+ "    images = torch.randn(4, 3, 128, 128) * 0.5\n",
+ "    text_emb = torch.randn(4, 10, 768)\n",
+ "    text_pooled = torch.randn(4, 768)\n",
+ "\n",
+ "    losses = trainer.train_step(images, text_emb, text_pooled)\n",
+ "    flow_losses.append(losses['flow'])\n",
+ "\n",
+ "    if step % 20 == 0:\n",
+ "        print(f'  Step {step:3d}: flow_loss={losses[\"flow\"]:.4f}')\n",
+ "\n",
+ "plt.figure(figsize=(8, 3))\n",
+ "plt.plot(flow_losses)\n",
+ "plt.xlabel('Step')\n",
+ "plt.ylabel('Flow Matching Loss')\n",
+ "plt.title('Stage 2: Backbone Training')\n",
+ "plt.tight_layout()\n",
+ "plt.savefig('backbone_training.png', dpi=100)\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 6. Staged Training Curriculum (Production)\n",
+ "\n",
+ "The full training curriculum for a production model:\n",
+ "\n",
+ "```\n",
+ "STAGE 1 — VAE (freeze after):\n",
+ "  Data: ImageNet + SAM (mixed res)\n",
+ "  Loss: L1 recon + 1e-6*KL + perceptual (LPIPS) + adversarial (PatchGAN)\n",
+ "  Steps: 100K, batch=256, lr=1e-4\n",
+ "  Hardware: 4× A100 (or 1× T4 with grad accumulation)\n",
+ "\n",
+ "STAGE 2 — Backbone Low-Res (128-256px):\n",
+ "  Data: Teacher-generated synthetic data (FLUX/SD3.5 outputs)\n",
+ "  Loss: Flow matching ||v_pred - v_target||²\n",
+ "  Steps: 500K, batch=128, lr=1e-4\n",
+ "  Freeze: VAE encoder+decoder\n",
+ "  Train: Backbone + Planner\n",
+ "\n",
+ "STAGE 3 — Backbone High-Res (256-512px):\n",
+ "  Data: Same + high-res subset\n",
+ "  Loss: Flow matching + resolution-adaptive noise schedule\n",
+ "  Steps: 200K, batch=64, lr=5e-5\n",
+ "  Init: From Stage 2 weights\n",
+ "\n",
+ "STAGE 4 — Knowledge Distillation:\n",
+ "  Teacher: FLUX.1-dev or SD3.5-Large\n",
+ "  Loss: Flow matching + t-scaled distillation loss\n",
+ "  Steps: 100K, batch=64, lr=2e-5\n",
+ "\n",
+ "STAGE 5 — Editing (spatial concat):\n",
+ "  Data: InstructPix2Pix pairs + FLUX Kontext edits\n",
+ "  Loss: Flow matching on [target | source] concat\n",
+ "  Steps: 50K, batch=32, lr=1e-5\n",
+ "  Trick: Progressive T2I → Edit → Joint (DreamLite recipe)\n",
+ "\n",
+ "STAGE 6 — Step Distillation (4-step):\n",
+ "  Method: Consistency distillation + LADD\n",
+ "  Steps: 50K, batch=128, lr=1e-5\n",
+ "  Target: 1-4 step generation\n",
+ "```"
+ ]
+ },
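+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Aside:* Stage 3 mentions a resolution-adaptive noise schedule. One common way to realize this (an assumption here, not a MicroForge spec) is an SD3-style timestep shift that pushes training and sampling toward noisier t at higher resolutions; the shift values below are illustrative:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "def shift_timesteps(t, shift):\n",
+ "    # Monotone remap of t in [0, 1]; shift > 1 spends more of the\n",
+ "    # schedule at high-noise t (assumed values, shift=1 is identity).\n",
+ "    return shift * t / (1 + (shift - 1) * t)\n",
+ "\n",
+ "t = torch.linspace(0, 1, 5)\n",
+ "for res, shift in [(256, 1.0), (512, 3.0)]:\n",
+ "    ts = shift_timesteps(t, shift)\n",
+ "    print(res, 'px:', [round(float(x), 2) for x in ts])"
+ ]
+ },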
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Demonstrate staged freeze/thaw training\n",
+ "print('=== Staged Training Configuration ===')\n",
+ "print()\n",
+ "\n",
+ "# Stage 1: Only VAE trainable\n",
+ "vae_s = MicroForgeVAE(config='tiny')\n",
+ "backbone_s = MicroForgeBackbone(latent_channels=16, config='tiny')\n",
+ "planner_s = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)\n",
+ "\n",
+ "def count_trainable(model):\n",
+ "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "\n",
+ "def freeze(model):\n",
+ "    for p in model.parameters():\n",
+ "        p.requires_grad_(False)\n",
+ "\n",
+ "def unfreeze(model):\n",
+ "    for p in model.parameters():\n",
+ "        p.requires_grad_(True)\n",
+ "\n",
+ "# Stage 1: VAE only\n",
+ "freeze(backbone_s)\n",
+ "freeze(planner_s)\n",
+ "unfreeze(vae_s)\n",
+ "print(f'Stage 1 (VAE): {count_trainable(vae_s):,} trainable params')\n",
+ "\n",
+ "# Stage 2: Backbone + Planner only\n",
+ "freeze(vae_s)\n",
+ "unfreeze(backbone_s)\n",
+ "unfreeze(planner_s)\n",
+ "print(f'Stage 2 (Backbone+Planner): {count_trainable(backbone_s) + count_trainable(planner_s):,} trainable params')\n",
+ "\n",
+ "# Stage 5: Editing - all unfrozen but low LR\n",
+ "unfreeze(vae_s)\n",
+ "unfreeze(backbone_s)\n",
+ "unfreeze(planner_s)\n",
+ "total = count_trainable(vae_s) + count_trainable(backbone_s) + count_trainable(planner_s)\n",
+ "print(f'Stage 5 (Joint): {total:,} trainable params')"
+ ]
+ },
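+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Aside:* the Stage 1 hardware note above ('1× T4 with grad accumulation') relies on gradient accumulation; a minimal sketch with assumed batch sizes:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "# Minimal gradient-accumulation sketch (batch sizes are assumptions):\n",
+ "# emulate an effective batch of 256 with micro-batches of 16.\n",
+ "model = torch.nn.Linear(32, 32)\n",
+ "opt = torch.optim.AdamW(model.parameters(), lr=1e-4)\n",
+ "accum_steps = 256 // 16\n",
+ "\n",
+ "opt.zero_grad()\n",
+ "for _ in range(accum_steps):\n",
+ "    x = torch.randn(16, 32)\n",
+ "    loss = model(x).pow(2).mean()\n",
+ "    (loss / accum_steps).backward()  # scale so grads match the large batch\n",
+ "opt.step()\n",
+ "print(f'stepped once after {accum_steps} micro-batches')"
+ ]
+ },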
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 7. Memory Profiling for Mobile Deployment\n",
+ "\n",
+ "Target: < 3-4 GB RAM for inference on consumer devices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "print('=== MicroForge Memory Budget ===')\n",
+ "print()\n",
+ "\n",
+ "configs = {\n",
+ "    'Mobile (tiny)': ('tiny', 16, 16, 256),\n",
+ "    'Prototype (small)': ('small', 32, 32, 384),\n",
+ "    'Full (base)': ('base', 32, 32, 512),\n",
+ "}\n",
+ "\n",
+ "for name, (cfg, lc, plan_tokens, plan_dim) in configs.items():\n",
+ "    vae = MicroForgeVAE(config=cfg)\n",
+ "    bb = MicroForgeBackbone(latent_channels=lc, config=cfg)\n",
+ "    pl = RecurrentLatentPlanner(num_plan_tokens=plan_tokens, dim=plan_dim, text_dim=768, latent_channels=lc)\n",
+ "\n",
+ "    total_params = sum(p.numel() for p in vae.parameters()) + \\\n",
+ "                   sum(p.numel() for p in bb.parameters()) + \\\n",
+ "                   sum(p.numel() for p in pl.parameters())\n",
+ "\n",
+ "    fp32_mb = total_params * 4 / 1e6\n",
+ "    fp16_mb = total_params * 2 / 1e6\n",
+ "    int8_mb = total_params / 1e6\n",
+ "\n",
+ "    print(f'{name}:')\n",
+ "    print(f'  Total params: {total_params:,}')\n",
+ "    print(f'  FP32: {fp32_mb:.0f} MB | FP16: {fp16_mb:.0f} MB | INT8: {int8_mb:.0f} MB')\n",
+ "\n",
+ "    # Activation memory estimate (rough)\n",
+ "    # For 512px: latent = 16x16xC, backbone processes 256 tokens\n",
+ "    latent_tokens = 16 * 16  # at 512px\n",
+ "    act_mb = latent_tokens * plan_dim * 4 / 1e6 * 20  # ~20 intermediate tensors\n",
+ "    print(f'  Activation memory @512px: ~{act_mb:.0f} MB')\n",
+ "    print(f'  Total inference @512px (FP16): ~{fp16_mb + act_mb:.0f} MB')\n",
+ "    print()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 8. Editing Readiness Demo\n",
+ "\n",
+ "The architecture supports editing via spatial concatenation:\n",
+ "- **Generation**: `z_input = [z_noise | zeros]` (width-concat)\n",
+ "- **Editing**: `z_input = [z_noise | z_source]` (width-concat)\n",
+ "- **Inpainting**: `z_input = [z_noise | z_masked_source]`\n",
+ "- **Super-res**: `z_input = [z_noise | z_lowres_upsampled]`\n",
+ "\n",
+ "No extra parameters needed — same backbone handles all tasks.\n",
+ "Task is indicated by prepending task tokens to the text prompt."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Demonstrate spatial concatenation for different tasks\n",
+ "B, C, H, W = 1, 16, 8, 8  # Latent dimensions for 256px\n",
+ "\n",
+ "z_noise = torch.randn(B, C, H, W)\n",
+ "z_source = torch.randn(B, C, H, W)\n",
+ "z_zeros = torch.zeros(B, C, H, W)\n",
+ "\n",
+ "# Generation mode\n",
+ "z_gen = torch.cat([z_noise, z_zeros], dim=-1)  # [B, C, H, 2W]\n",
+ "print(f'Generation input: {z_gen.shape} (target + blank context)')\n",
+ "\n",
+ "# Editing mode\n",
+ "z_edit = torch.cat([z_noise, z_source], dim=-1)\n",
+ "print(f'Editing input: {z_edit.shape} (target + source context)')\n",
+ "\n",
+ "# Inpainting mode\n",
+ "mask = torch.ones(B, 1, H, W)\n",
+ "mask[:, :, 2:6, 2:6] = 0  # Mask out the center region (the area to inpaint)\n",
+ "z_masked = z_source * mask  # Zero out inpaint region\n",
+ "z_inpaint = torch.cat([z_noise, z_masked], dim=-1)\n",
+ "print(f'Inpaint input: {z_inpaint.shape} (target + masked source)')\n",
+ "\n",
+ "# The backbone processes all of these identically\n",
+ "bb = MicroForgeBackbone(latent_channels=C, config='tiny')\n",
+ "t = torch.rand(B)\n",
+ "text_emb = torch.randn(B, 5, 768)\n",
+ "text_pooled = torch.randn(B, 768)\n",
+ "\n",
+ "v_gen = bb(z_gen, t, text_emb, text_pooled)\n",
+ "print(f'\\nBackbone output: {v_gen.shape}')\n",
+ "print(f'Target velocity (left half): {v_gen[..., :W].shape}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 9. Mathematical Formulation Summary\n",
+ "\n",
+ "### Forward Process (Rectified Flow)\n",
+ "$$z_t = (1-t) \\cdot z_0 + t \\cdot \\epsilon, \\quad \\epsilon \\sim \\mathcal{N}(0, I)$$\n",
+ "\n",
+ "### Training Objective\n",
+ "$$\\mathcal{L}_{\\text{flow}} = \\mathbb{E}_{t, z_0, \\epsilon} \\left[ w(t) \\|v_\\theta(z_t, t, c) - (\\epsilon - z_0)\\|^2 \\right]$$\n",
+ "\n",
+ "where $w(t) = \\frac{1}{1 + |2t - 1|}$ (t-scaling, peaks at $t=0.5$)\n",
+ "\n",
+ "### Sampling (Euler ODE, integrating from $t=1$ to $t=0$)\n",
+ "$$z_{t-\\Delta t} = z_t - \\Delta t \\cdot v_\\theta(z_t, t, c)$$\n",
+ "\n",
+ "(since $v = \\epsilon - z_0$ points from data toward noise, stepping toward $t=0$ subtracts it)\n",
+ "\n",
+ "### Planner Update\n",
+ "$$p^{(l+1)} = \\text{SelfAttn}(\\text{CrossAttn}(p^{(l)}, \\text{Proj}(z_t)))$$\n",
+ "\n",
+ "### Self-Conditioning\n",
+ "$$p_t = \\sigma(w) \\cdot p_{t+1} + (1 - \\sigma(w)) \\cdot p_{\\text{init}}(c_{\\text{text}})$$\n",
+ "\n",
+ "### VAE Loss\n",
+ "$$\\mathcal{L}_{\\text{VAE}} = \\|x - \\hat{x}\\|_1 + \\lambda_{\\text{KL}} \\cdot D_{\\text{KL}}(q(z|x) \\| \\mathcal{N}(0, I))$$"
+ ]
+ },
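+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Aside:* a numeric check of the formulas above on toy tensors (not the real backbone). With the oracle velocity and the minus-sign Euler step, sampling retraces the straight rectified-flow path exactly:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "\n",
+ "z0 = torch.randn(1, 16, 8, 8)        # 'clean' latent\n",
+ "eps = torch.randn_like(z0)           # noise sample\n",
+ "t = torch.rand(1)\n",
+ "z_t = (1 - t) * z0 + t * eps         # forward interpolation\n",
+ "v_target = eps - z0                  # velocity target\n",
+ "w = 1.0 / (1.0 + (2 * t - 1).abs())  # t-scaling weight, peaks at t=0.5\n",
+ "\n",
+ "v_pred = torch.zeros_like(z_t)       # stand-in for v_theta(z_t, t, c)\n",
+ "loss = (w * (v_pred - v_target) ** 2).mean()\n",
+ "print(f'w({t.item():.2f}) = {w.item():.3f}, loss = {loss.item():.3f}')\n",
+ "\n",
+ "# Euler sampling, t: 1 -> 0, using z_{t-dt} = z_t - dt * v. Starting from\n",
+ "# the same eps with the oracle velocity, 4 steps land exactly on z0.\n",
+ "z = eps.clone()\n",
+ "steps = 4\n",
+ "for _ in range(steps):\n",
+ "    z = z - (1.0 / steps) * v_target\n",
+ "print('max |z - z0| after sampling:', (z - z0).abs().max().item())"
+ ]
+ },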
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 10. Ablation Plan\n",
+ "\n",
+ "To validate each component's contribution:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ablations = [\n",
+ "    ('Full MicroForge', True, True, True),\n",
+ "    ('No Planner', True, False, True),\n",
+ "    ('No SSM (attention only)', False, True, False),  # Replace SSM with self-attn\n",
+ "    ('No Shared Attention', True, True, False),       # Remove shared attn block\n",
+ "    ('No DWConv in SSM', True, True, True),           # Remove local_conv from SSM\n",
+ "]\n",
+ "\n",
+ "print('=== Ablation Plan ===')\n",
+ "print(f'{\"Configuration\":>30} | {\"SSM\":>5} | {\"Planner\":>8} | {\"SharedAttn\":>10}')\n",
+ "print('-' * 65)\n",
+ "for name, ssm, planner, shared in ablations:\n",
+ "    print(f'{name:>30} | {\"✓\" if ssm else \"✗\":>5} | {\"✓\" if planner else \"✗\":>8} | {\"✓\" if shared else \"✗\":>10}')\n",
+ "\n",
+ "print()\n",
+ "print('Metrics to track per ablation:')\n",
+ "print('  - FID (quality) on COCO-30K')\n",
+ "print('  - CLIP-Score (prompt adherence)')\n",
+ "print('  - ImageReward (aesthetics)')\n",
+ "print('  - Inference latency (ms)')\n",
+ "print('  - Peak memory (MB)')\n",
+ "print('  - Training convergence speed (steps to target FID)')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 11. Dataset Pipeline for Staged Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Dataset recommendations per training stage\n",
+ "print('=== Recommended Datasets ===')\n",
+ "print()\n",
+ "\n",
+ "stages = {\n",
+ "    'Stage 1 - VAE': {\n",
+ "        'datasets': [\n",
+ "            'ImageNet-1K (class-cond, 1.28M images)',\n",
+ "            'SAM-1M (diverse scenes, SA-1B subset)',\n",
+ "            'FFHQ (70K faces for quality tuning)',\n",
+ "        ],\n",
+ "        'hub_ids': ['ILSVRC/imagenet-1k', 'facebook/sam', 'NoCrypt/ffhq-512'],\n",
+ "    },\n",
+ "    'Stage 2 - Low-Res T2I': {\n",
+ "        'datasets': [\n",
+ "            'JourneyDB-4M (high aesthetic quality)',\n",
+ "            'LAION-Aesthetics-6.5+ (filtered subset)',\n",
+ "            'Teacher-generated synthetic data (FLUX/SD3.5 outputs)',\n",
+ "        ],\n",
+ "        'hub_ids': ['JourneyDB/JourneyDB', 'laion/laion2B-en-aesthetic'],\n",
+ "    },\n",
+ "    'Stage 3 - High-Res T2I': {\n",
+ "        'datasets': [\n",
+ "            'Same as Stage 2, filtered for >512px',\n",
+ "            'Unsplash-25K (very high quality photos)',\n",
+ "        ],\n",
+ "        'hub_ids': [],\n",
+ "    },\n",
+ "    'Stage 4 - Knowledge Distillation': {\n",
+ "        'datasets': [\n",
+ "            'Self-generated: 1M prompts → FLUX.1-dev outputs',\n",
+ "            'DiffusionDB-2M (real user prompts)',\n",
+ "        ],\n",
+ "        'hub_ids': ['poloclub/diffusiondb'],\n",
+ "    },\n",
+ "    'Stage 5 - Editing': {\n",
+ "        'datasets': [\n",
+ "            'InstructPix2Pix (454K editing pairs)',\n",
+ "            'MagicBrush (10K high-quality edits)',\n",
+ "            'GRIT-Entity (subject-driven, 200K)',\n",
+ "            'Custom: FLUX.1-Kontext-generated edit pairs',\n",
+ "        ],\n",
+ "        'hub_ids': ['timbrooks/instructpix2pix-clip-filtered', 'osunlp/MagicBrush'],\n",
+ "    },\n",
+ "}\n",
+ "\n",
+ "for stage, info in stages.items():\n",
+ "    print(f'\\n{stage}:')\n",
+ "    for ds in info['datasets']:\n",
+ "        print(f'  • {ds}')\n",
+ "    if info['hub_ids']:\n",
+ "        print(f'  HF Hub: {info[\"hub_ids\"]}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 12. Comparison with Existing Architectures"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "comparison = [\n",
+ "    ('SD-v1.5', '860M', '~3.4 GB', 'O(N²)', 'UNet', 'No', '20-50'),\n",
+ "    ('SDXL', '2.6B', '~6.5 GB', 'O(N²)', 'UNet', 'No', '20-50'),\n",
+ "    ('FLUX.1-dev', '12B', '~24 GB', 'O(N²)', 'MM-DiT', 'No', '20-50'),\n",
+ "    ('SD3.5-Medium', '2.5B', '~6 GB', 'O(N²)', 'MM-DiT', 'No', '28'),\n",
+ "    ('SANA-Sprint', '600M+2B', '~5.5 GB', 'O(N)', 'Linear DiT', 'No', '1-4'),\n",
+ "    ('SnapGen', '380M+2B', '~4 GB', 'O(N²)', 'Pruned UNet', 'No', '4-28'),\n",
+ "    ('DreamLite', '389M+2B', '~4 GB', 'O(N²)', 'Pruned UNet', 'Yes', '4'),\n",
+ "    ('MicroForge-tiny', '28M+text', '~0.2 GB*', 'O(N)', 'SSM-Conv', 'Yes', '4-20'),\n",
+ "    ('MicroForge-small', '114M+text', '~0.6 GB*', 'O(N)', 'SSM-Conv', 'Yes', '4-20'),\n",
+ "    ('MicroForge-base', '240M+text', '~1.2 GB*', 'O(N)', 'SSM-Conv', 'Yes', '4-20'),\n",
+ "]\n",
+ "\n",
+ "print(f'{\"Model\":>18} | {\"Params\":>12} | {\"VRAM\":>10} | {\"Complexity\":>10} | {\"Backbone\":>12} | {\"Edit\":>5} | {\"Steps\":>6}')\n",
+ "print('-' * 95)\n",
+ "for row in comparison:\n",
+ "    print(f'{row[0]:>18} | {row[1]:>12} | {row[2]:>10} | {row[3]:>10} | {row[4]:>12} | {row[5]:>5} | {row[6]:>6}')\n",
+ "\n",
+ "print()\n",
+ "print('* MicroForge VRAM excludes text encoder (shared/swappable component)')\n",
+ "print('  With CLIP-L (428M): add ~0.9 GB. With Gemma-2-2B: add ~4 GB.')\n",
+ "print('  For mobile: use TinyCLIP (~60M) adding only ~0.12 GB.')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 13. Export and Save Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save model checkpoint\n",
+ "os.makedirs('checkpoints', exist_ok=True)\n",
+ "\n",
+ "checkpoint = {\n",
+ "    'vae_state_dict': vae_train.state_dict(),\n",
+ "    'backbone_state_dict': backbone_train.state_dict(),\n",
+ "    'planner_state_dict': planner_train.state_dict(),\n",
+ "    'config': {\n",
+ "        'vae_config': 'tiny',\n",
+ "        'backbone_config': 'tiny',\n",
+ "        'latent_channels': 16,\n",
+ "        'plan_tokens': 16,\n",
+ "        'plan_dim': 256,\n",
+ "        'text_dim': 768,\n",
+ "    },\n",
+ "    'architecture_version': '0.1.0',\n",
+ "}\n",
+ "\n",
+ "torch.save(checkpoint, 'checkpoints/microforge_tiny_demo.pt')\n",
+ "size_mb = os.path.getsize('checkpoints/microforge_tiny_demo.pt') / 1e6\n",
+ "print(f'Saved checkpoint: {size_mb:.1f} MB')\n",
+ "print('Done!')"
+ ]
+ },
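+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "*Aside:* the natural round trip — rebuild the modules from the saved `config` dict and reload the weights. This sketch uses only keys written by the cell above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Reload check: reconstruct from the checkpoint's own config\n",
+ "ckpt = torch.load('checkpoints/microforge_tiny_demo.pt', map_location='cpu')\n",
+ "cfg = ckpt['config']\n",
+ "\n",
+ "vae_l = MicroForgeVAE(config=cfg['vae_config'])\n",
+ "bb_l = MicroForgeBackbone(latent_channels=cfg['latent_channels'], config=cfg['backbone_config'])\n",
+ "pl_l = RecurrentLatentPlanner(num_plan_tokens=cfg['plan_tokens'], dim=cfg['plan_dim'],\n",
+ "                              text_dim=cfg['text_dim'], latent_channels=cfg['latent_channels'])\n",
+ "\n",
+ "vae_l.load_state_dict(ckpt['vae_state_dict'])\n",
+ "bb_l.load_state_dict(ckpt['backbone_state_dict'])\n",
+ "pl_l.load_state_dict(ckpt['planner_state_dict'])\n",
+ "print('Reloaded MicroForge', ckpt['architecture_version'])"
+ ]
+ }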
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.12.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }