krystv committed
Commit 508d27c · verified · 1 Parent(s): 3798d56

Upload LiquidFlow_Colab.ipynb

Files changed (1)
  1. LiquidFlow_Colab.ipynb +74 -243
LiquidFlow_Colab.ipynb CHANGED
@@ -12,13 +12,7 @@
  "- **CfC (Closed-form Continuous-time)** Liquid Neural Networks — adaptive time gates\n",
  "- **Mamba-2 SSD** — linear-time attention replacement, fully parallelizable\n",
  "- **Physics-Informed Regularization** — TV loss, spectral constraints\n",
- "- **TAESD VAE** — Tiny AutoEncoder (< 1M params) for fast encoding\n",
- "\n",
- "Based on:\n",
- "- CfC: Hasani et al., Nature MI 2022\n",
- "- Mamba-2: Dao & Gu, 2024\n",
- "- PINN Diffusion: Bastek & Sun, ICLR 2025\n",
- "- DiMSUM: NeurIPS 2024\n",
+ "- **TAESD VAE** — Tiny AutoEncoder (< 3M params) for fast encoding\n",
  "\n",
  "---\n",
  "## Quick Start\n",
@@ -33,9 +27,8 @@
  "execution_count": null,
  "metadata": {},
  "source": [
- "# @title 1. Install Dependencies (~2 min)\n",
- "!pip install -q torch torchvision diffusers tqdm pillow numpy\n",
- "!pip install -q git+https://github.com/huggingface/diffusers.git\n",
+ "# @title 1. Install Dependencies\n",
+ "!pip install -q torch torchvision diffusers tqdm pillow numpy accelerate\n",
  "\n",
  "import torch\n",
  "print(f\"PyTorch: {torch.__version__}\")\n",
@@ -52,7 +45,11 @@
  "metadata": {},
  "source": [
  "# @title 2. Clone LiquidFlow Repository\n",
- "!git clone https://huggingface.co/krystv/LiquidFlow-Gen /content/LiquidFlow\n",
+ "import os\n",
+ "if not os.path.exists('/content/LiquidFlow'):\n",
+ "    !git clone https://huggingface.co/krystv/LiquidFlow-Gen /content/LiquidFlow\n",
+ "else:\n",
+ "    !cd /content/LiquidFlow && git pull\n",
  "%cd /content/LiquidFlow\n",
  "\n",
  "import sys\n",
@@ -60,8 +57,7 @@
  "\n",
  "from liquid_flow.generator import create_liquidflow\n",
  "from liquid_flow.vae_wrapper import TAESDWrapper\n",
- "import torch.nn as nn\n",
- "import torch.nn.functional as F"
+ "print('LiquidFlow imported successfully!')"
  ],
  "outputs": []
  },
@@ -70,33 +66,35 @@
  "execution_count": null,
  "metadata": {},
  "source": [
- "# @title 3. Configuration — Adjust these settings!\n",
+ "# @title 3. Configuration\n",
  "\n",
- "# Model size: 'tiny' (~2M), 'small' (~8M), 'base' (~30M)\n",
+ "# Model variant: 'tiny' (~3.6M), 'small' (~11M), 'base' (~36M)\n",
  "MODEL_VARIANT = 'small' # @param ['tiny', 'small', 'base']\n",
  "\n",
- "# Image size: 128 recommended for T4, 512 needs more VRAM\n",
- "IMAGE_SIZE = 128 # @param [64, 128, 256, 512]\n",
+ "# Image size (128 recommended for T4 free tier)\n",
+ "IMAGE_SIZE = 128 # @param [64, 128, 256, 512] {type:\"integer\"}\n",
  "\n",
- "# Training\n",
- "BATCH_SIZE = 32 # @param [8, 16, 32, 64]\n",
- "EPOCHS = 50 # @param [10, 25, 50, 100]\n",
- "LEARNING_RATE = 2e-4 # @param [1e-4, 2e-4, 5e-4, 1e-3]\n",
+ "# Training hyperparameters\n",
+ "BATCH_SIZE = 32 # @param [8, 16, 32, 64] {type:\"integer\"}\n",
+ "EPOCHS = 50 # @param [10, 25, 50, 100] {type:\"integer\"}\n",
+ "LEARNING_RATE = 2e-4 # @param {type:\"number\"}\n",
  "\n",
  "# Dataset\n",
  "DATASET = 'cifar10' # @param ['cifar10', 'cifar100', 'stl10']\n",
  "\n",
- "# Sampling (DDIM steps)\n",
- "SAMPLE_EVERY = 5 # @param [1, 5, 10]\n",
- "SAMPLE_STEPS = 50 # @param [20, 50, 100]\n",
+ "# Sampling\n",
+ "SAMPLE_EVERY = 5 # @param {type:\"integer\"}\n",
+ "SAMPLE_STEPS = 50 # @param {type:\"integer\"}\n",
  "\n",
- "# Physics regularization weights\n",
- "PHYSICS_TV_WEIGHT = 0.01\n",
- "PHYSICS_SPEC_WEIGHT = 0.01\n",
- "PHYSICS_GRAD_WEIGHT = 0.001\n",
+ "# Ensure integer types (Colab forms can return strings)\n",
+ "IMAGE_SIZE = int(IMAGE_SIZE)\n",
+ "BATCH_SIZE = int(BATCH_SIZE)\n",
+ "EPOCHS = int(EPOCHS)\n",
+ "SAMPLE_EVERY = int(SAMPLE_EVERY)\n",
+ "SAMPLE_STEPS = int(SAMPLE_STEPS)\n",
+ "LEARNING_RATE = float(LEARNING_RATE)\n",
  "\n",
- "print(f\"Config: {MODEL_VARIANT} model, {IMAGE_SIZE}px, batch={BATCH_SIZE}, epochs={EPOCHS}, lr={LEARNING_RATE}\")\n",
- "print(f\"Physics loss: TV={PHYSICS_TV_WEIGHT}, Spec={PHYSICS_SPEC_WEIGHT}, Grad={PHYSICS_GRAD_WEIGHT}\")"
+ "print(f\"Config: {MODEL_VARIANT} model, {IMAGE_SIZE}px, batch={BATCH_SIZE}, epochs={EPOCHS}, lr={LEARNING_RATE}\")"
  ],
  "outputs": []
  },
@@ -117,31 +115,17 @@
  "# Load TAESD (Tiny AutoEncoder)\n",
  "print(\"Loading TAESD VAE...\")\n",
  "vae = TAESDWrapper.load(device)\n",
- "print(f\"VAE loaded! Latent compression: {IMAGE_SIZE}x{IMAGE_SIZE} → {IMAGE_SIZE//8}x{IMAGE_SIZE//8}\")\n",
+ "latent_size = IMAGE_SIZE // 8\n",
+ "print(f\"VAE loaded! Latent: {IMAGE_SIZE}x{IMAGE_SIZE} -> {latent_size}x{latent_size}x4\")\n",
  "\n",
  "# Create LiquidFlow model\n",
- "print(f\"Creating {MODEL_VARIANT} LiquidFlow model...\")\n",
- "model = create_liquidflow(\n",
- "    variant=MODEL_VARIANT,\n",
- "    image_size=IMAGE_SIZE,\n",
- "    physics_weights={\n",
- "        'tv': PHYSICS_TV_WEIGHT,\n",
- "        'cons': 0.001,\n",
- "        'spec': PHYSICS_SPEC_WEIGHT,\n",
- "        'grad': PHYSICS_GRAD_WEIGHT,\n",
- "    },\n",
- ")\n",
+ "print(f\"\\nCreating '{MODEL_VARIANT}' LiquidFlow model...\")\n",
+ "model = create_liquidflow(variant=MODEL_VARIANT, image_size=IMAGE_SIZE)\n",
  "model = model.to(device)\n",
  "\n",
- "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
- "print(f\"Model: {n_params:,} parameters ({n_params/1e6:.1f}M)\")\n",
- "\n",
- "# Memory estimate\n",
- "latent_size = IMAGE_SIZE // 8\n",
- "mem_per_sample = latent_size * latent_size * 4 * 4 / 1e6 # MB\n",
- "print(f\"Memory per sample: {mem_per_sample:.1f} MB\")\n",
- "print(f\"Estimated batch memory: {mem_per_sample * BATCH_SIZE:.1f} MB\")\n",
- "print(f\"T4 VRAM: 15 GB — should fit!\" if mem_per_sample * BATCH_SIZE < 10 else \"Watch memory!\")"
+ "n_params = model.count_parameters()\n",
+ "print(f\"Model parameters: {n_params:,} ({n_params/1e6:.1f}M)\")\n",
+ "print(f\"\\nReady to train!\")"
  ],
  "outputs": []
  },
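A note on the TAESD step above: `TAESDWrapper` is repository code whose internals do not appear in this diff, but the round-trip it performs can be sketched directly with `diffusers`' `AutoencoderTiny`. This is a rough equivalent only, assuming the standard `madebyollin/taesd` weights and images in the `[-1, 1]` range:

```python
import torch
from diffusers import AutoencoderTiny

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# TAESD as shipped in diffusers; the notebook's TAESDWrapper presumably wraps
# something like this, but its actual loading path is not shown in the diff.
tiny_vae = AutoencoderTiny.from_pretrained('madebyollin/taesd').to(device).eval()

images = torch.rand(2, 3, 128, 128, device=device) * 2 - 1  # dummy batch in [-1, 1]

with torch.no_grad():
    latents = tiny_vae.encode(images).latents  # (2, 4, 16, 16): 8x smaller spatially, 4 channels
    recon = tiny_vae.decode(latents).sample    # back to (2, 3, 128, 128)

print(latents.shape, recon.shape)
```

The 8x spatial compression to 4 channels is what the cell's `latent_size = IMAGE_SIZE // 8` arithmetic and the `{latent_size}x{latent_size}x4` print refer to.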
@@ -154,8 +138,9 @@
  "\n",
  "transform = transforms.Compose([\n",
  "    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),\n",
+ "    transforms.RandomHorizontalFlip(),\n",
  "    transforms.ToTensor(),\n",
- "    transforms.Normalize([0.5], [0.5]),\n",
+ "    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n",
  "])\n",
  "\n",
  "if DATASET == 'cifar10':\n",
@@ -164,17 +149,14 @@
  "    dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)\n",
  "elif DATASET == 'stl10':\n",
  "    dataset = datasets.STL10(root='./data', split='train', download=True, transform=transform)\n",
- "else:\n",
- "    raise ValueError(f\"Unknown dataset: {DATASET}\")\n",
  "\n",
  "dataloader = DataLoader(\n",
  "    dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
- "    num_workers=min(4, os.cpu_count() or 1),\n",
- "    pin_memory=True, drop_last=True,\n",
+ "    num_workers=2, pin_memory=True, drop_last=True,\n",
  ")\n",
  "\n",
- "print(f\"Dataset: {DATASET}\")\n",
- "print(f\"Images: {len(dataset):,}, Batches per epoch: {len(dataloader)}\")"
+ "print(f\"Dataset: {DATASET} ({len(dataset):,} images)\")\n",
+ "print(f\"Batches per epoch: {len(dataloader)}\")"
  ],
  "outputs": []
  },
@@ -183,109 +165,64 @@
  "execution_count": null,
  "metadata": {},
  "source": [
- "# @title 6. Training Loop\n",
- "\n",
+ "# @title 6. Train!\n",
  "from torchvision.utils import save_image\n",
- "import math\n",
  "\n",
  "os.makedirs('./outputs/samples', exist_ok=True)\n",
  "os.makedirs('./outputs/checkpoints', exist_ok=True)\n",
  "\n",
- "# Optimizer\n",
- "optimizer = torch.optim.AdamW(\n",
- "    model.parameters(),\n",
- "    lr=LEARNING_RATE,\n",
- "    betas=(0.9, 0.999),\n",
- "    weight_decay=1e-4,\n",
- ")\n",
- "\n",
- "# Cosine LR scheduler\n",
- "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
- "    optimizer, T_max=EPOCHS * len(dataloader)\n",
- ")\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)\n",
+ "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS * len(dataloader))\n",
  "\n",
- "# AMP\n",
  "use_amp = device.type == 'cuda'\n",
  "scaler = torch.cuda.amp.GradScaler() if use_amp else None\n",
  "\n",
- "print(f\"Training: {EPOCHS} epochs, LR={LEARNING_RATE}, AMP={use_amp}\")\n",
- "print(\"=\"*60)\n",
+ "print(f\"Training: {EPOCHS} epochs, AMP={use_amp}\")\n",
+ "print('=' * 60)\n",
  "\n",
- "global_step = 0\n",
  "best_loss = float('inf')\n",
  "\n",
  "for epoch in range(EPOCHS):\n",
  "    model.train()\n",
- "    epoch_total = 0\n",
- "    \n",
- "    pbar = tqdm(dataloader, desc=f\"Epoch {epoch+1}/{EPOCHS}\")\n",
+ "    epoch_loss = 0\n",
+ "    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{EPOCHS}')\n",
  "    \n",
  "    for images, _ in pbar:\n",
  "        images = images.to(device)\n",
  "        \n",
- "        # Encode to latent\n",
+ "        # Encode to latent space\n",
  "        with torch.no_grad():\n",
  "            latents = TAESDWrapper.encode(vae, images)\n",
  "        \n",
- "        # Training step with physics regularization\n",
+ "        # Training step\n",
  "        loss_dict = model.training_step(latents, optimizer, scaler, use_amp)\n",
- "        \n",
- "        # Track\n",
- "        total_loss = loss_dict['total']\n",
- "        epoch_total += total_loss\n",
- "        \n",
- "        # Update scheduler\n",
  "        scheduler.step()\n",
  "        \n",
- "        # Progress bar\n",
- "        pbar.set_postfix({\n",
- "            'loss': f\"{total_loss:.4f}\",\n",
- "            'diff': f\"{loss_dict.get('diffusion', 0):.4f}\",\n",
- "            'phys': f\"{loss_dict.get('physics', 0):.4f}\",\n",
- "            'lr': f\"{optimizer.param_groups[0]['lr']:.2e}\",\n",
- "        })\n",
- "        \n",
- "        global_step += 1\n",
+ "        epoch_loss += loss_dict['total']\n",
+ "        pbar.set_postfix(loss=f\"{loss_dict['total']:.4f}\", diff=f\"{loss_dict['diffusion']:.4f}\")\n",
  "    \n",
- "    avg_loss = epoch_total / len(dataloader)\n",
- "    print(f\"Epoch {epoch+1}: avg_loss={avg_loss:.4f}\")\n",
+ "    avg = epoch_loss / len(dataloader)\n",
+ "    print(f'Epoch {epoch+1}: loss={avg:.4f}')\n",
  "    \n",
  "    # Generate samples\n",
  "    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == EPOCHS - 1:\n",
- "        print(\"  Generating samples...\")\n",
  "        model.eval()\n",
  "        with torch.no_grad():\n",
- "            latents_gen = model.sample(\n",
- "                batch_size=16,\n",
- "                steps=SAMPLE_STEPS,\n",
- "                ddim=True,\n",
- "                progress=False,\n",
- "            )\n",
- "            images_gen = TAESDWrapper.decode(vae, latents_gen)\n",
- "        \n",
- "        save_image(\n",
- "            images_gen, f'./outputs/samples/epoch_{epoch+1:03d}.png',\n",
- "            nrow=4, normalize=True, value_range=(-1, 1)\n",
- "        )\n",
- "        print(f\"  Saved to ./outputs/samples/epoch_{epoch+1:03d}.png\")\n",
- "        \n",
- "        # Save checkpoint\n",
- "        if (epoch + 1) % 10 == 0 or epoch == EPOCHS - 1:\n",
- "            torch.save({\n",
- "                'epoch': epoch + 1,\n",
- "                'model_state_dict': model.state_dict(),\n",
- "                'optimizer_state_dict': optimizer.state_dict(),\n",
- "                'loss': avg_loss,\n",
- "            }, f'./outputs/checkpoints/epoch_{epoch+1:03d}.pt')\n",
+ "            z = model.sample(batch_size=16, steps=SAMPLE_STEPS, ddim=True, progress=False)\n",
+ "            imgs = TAESDWrapper.decode(vae, z)\n",
+ "        save_image(imgs, f'./outputs/samples/epoch_{epoch+1:03d}.png', nrow=4, normalize=True, value_range=(-1, 1))\n",
+ "        print(f'  Samples saved: ./outputs/samples/epoch_{epoch+1:03d}.png')\n",
  "    \n",
- "    if avg_loss < best_loss:\n",
- "        best_loss = avg_loss\n",
- "        torch.save(model.state_dict(), './outputs/checkpoints/best_model.pt')\n",
- "\n",
- "print(\"=\"*60)\n",
- "print(f\"Training complete! Best loss: {best_loss:.4f}\")\n",
- "print(f\"Checkpoints saved to ./outputs/checkpoints/\")\n",
- "print(f\"Samples saved to ./outputs/samples/\")"
+ "    # Checkpoint\n",
+ "    if avg < best_loss:\n",
+ "        best_loss = avg\n",
+ "        torch.save(model.state_dict(), './outputs/checkpoints/best.pt')\n",
+ "    if (epoch+1) % 10 == 0:\n",
+ "        torch.save({'epoch': epoch+1, 'model': model.state_dict(), 'opt': optimizer.state_dict()},\n",
+ "                   f'./outputs/checkpoints/epoch_{epoch+1:03d}.pt')\n",
+ "\n",
+ "print('=' * 60)\n",
+ "print(f'Done! Best loss: {best_loss:.4f}')"
  ],
  "outputs": []
  },
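The loop above delegates the actual optimization to `model.training_step(latents, optimizer, scaler, use_amp)`, a LiquidFlow method whose body does not appear in this diff. As a rough mental model only, a typical AMP-aware diffusion training step has this shape; every name below (`q_sample`, `num_timesteps`, the returned keys) is an assumption for illustration, not the repository's API:

```python
import torch

def training_step(model, latents, optimizer, scaler, use_amp):
    """Hypothetical sketch of an AMP training step; NOT LiquidFlow's code."""
    optimizer.zero_grad(set_to_none=True)
    # Sample a random timestep and noise the latents (standard DDPM-style recipe).
    t = torch.randint(0, model.num_timesteps, (latents.size(0),), device=latents.device)  # assumed attr
    noise = torch.randn_like(latents)
    noisy = model.q_sample(latents, t, noise)  # assumed forward-noising helper
    with torch.autocast('cuda', enabled=use_amp):
        pred = model(noisy, t)
        loss = torch.nn.functional.mse_loss(pred, noise)
    if scaler is not None:
        # GradScaler handles loss scaling so fp16 gradients don't underflow.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
    # Keys chosen to match the loop's loss_dict['total'] / loss_dict['diffusion'] usage.
    return {'total': loss.item(), 'diffusion': loss.item()}
```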
@@ -294,134 +231,28 @@
  "execution_count": null,
  "metadata": {},
  "source": [
- "# @title 7. Generate & Display Samples\n",
- "\n",
+ "# @title 7. Display Generated Samples\n",
  "import matplotlib.pyplot as plt\n",
  "from PIL import Image\n",
  "import glob\n",
  "\n",
- "# Load latest sample\n",
  "sample_files = sorted(glob.glob('./outputs/samples/epoch_*.png'))\n",
  "if sample_files:\n",
- "    latest = sample_files[-1]\n",
- "    img = Image.open(latest)\n",
- "    plt.figure(figsize=(12, 12))\n",
+ "    img = Image.open(sample_files[-1])\n",
+ "    plt.figure(figsize=(10, 10))\n",
  "    plt.imshow(img)\n",
- "    plt.title(f'LiquidFlow Samples — {MODEL_VARIANT} model, {IMAGE_SIZE}px')\n",
+ "    plt.title(f'LiquidFlow — {MODEL_VARIANT}, {IMAGE_SIZE}px (latest)')\n",
  "    plt.axis('off')\n",
  "    plt.show()\n",
  "else:\n",
- "    print(\"No samples generated yet. Train for more epochs!\")"
+ "    print('No samples yet. Run training first!')"
  ],
  "outputs": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "source": [
- "# @title 8. Export Model for Mobile (ONNX)\n",
- "\n",
- "# LiquidFlow can be exported to ONNX for mobile deployment\n",
- "# since it uses pure PyTorch (no custom CUDA kernels)\n",
- "\n",
- "def export_to_onnx(model, output_path='liquidflow_model.onnx', image_size=128):\n",
- "    \"\"\"Export LiquidFlow to ONNX for mobile deployment.\"\"\"\n",
- "    model = model.cpu()\n",
- "    model.eval()\n",
- "    \n",
- "    latent_size = image_size // 8\n",
- "    \n",
- "    # Dummy inputs\n",
- "    x = torch.randn(1, 4, latent_size, latent_size)\n",
- "    t = torch.tensor([500], dtype=torch.long)\n",
- "    \n",
- "    # Export\n",
- "    torch.onnx.export(\n",
- "        model,\n",
- "        (x, t),\n",
- "        output_path,\n",
- "        input_names=['noisy_latent', 'timestep'],\n",
- "        output_names=['predicted_noise'],\n",
- "        dynamic_axes={\n",
- "            'noisy_latent': {0: 'batch'},\n",
- "            'predicted_noise': {0: 'batch'},\n",
- "        },\n",
- "        opset_version=14,\n",
- "    )\n",
- "    \n",
- "    import os\n",
- "    size_mb = os.path.getsize(output_path) / 1e6\n",
- "    print(f\"ONNX model exported to {output_path} ({size_mb:.1f} MB)\")\n",
- "    return output_path\n",
- "\n",
- "# Load best model and export\n",
- "best_model_path = './outputs/checkpoints/best_model.pt'\n",
- "if os.path.exists(best_model_path):\n",
- "    model.load_state_dict(torch.load(best_model_path, map_location='cpu'))\n",
- "    export_to_onnx(model, 'liquidflow_128.onnx', IMAGE_SIZE)\n",
- "    print(\"Ready for mobile deployment!\")\n",
- "else:\n",
- "    print(\"Train model first before exporting.\")"
- ],
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Architecture Details\n",
- "\n",
- "### LiquidFlow Block Architecture\n",
- "```\n",
- "Input → [CfC Gate → Mamba-2 SSD → CfC Gate] → Output\n",
- "              ↑                        ↑\n",
- "      Adaptive time gate          Gated output\n",
- "```\n",
- "\n",
- "### CfC (Closed-form Continuous-time) Cell\n",
- "```\n",
- "h(t) = σ(-f(x,I;θ_f)·t) ⊙ g(x,I;θ_g) + (1-σ(-f(x,I;θ_f)·t)) ⊙ h(x,I;θ_h)\n",
- "```\n",
- "- **No ODE solver needed** — 100x faster than Neural ODEs\n",
- "- Time-continuous gating adaptively controls information flow\n",
- "- Closed-form solution → stable gradients\n",
- "\n",
- "### Mamba-2 SSD (State Space Duality)\n",
- "```\n",
- "h_t = A_t * h_{t-1} + B_t * x_t\n",
- "y_t = C_t^T * h_t\n",
- "```\n",
- "- **O(N) linear complexity** vs Transformers O(N²)\n",
- "- **Parallelizable** via associative scan (Blelloch)\n",
- "- **Scalar-A** formulation enables chunk-scan optimization\n",
- "- Pure PyTorch — no CUDA kernels needed\n",
- "\n",
- "### Physics-Informed Regularization\n",
- "- **Total Variation**: `L_TV = ||∇_x x̂||₁ + ||∇_y x̂||₁`\n",
- "- **Spectral**: Penalize high-frequency artifacts\n",
- "- **Gradient**: Sobolev norm for stable training\n",
- "- Pattern from Bastek & Sun (ICLR 2025): physics loss as training-only regularizer\n",
- "\n",
- "### Model Variants\n",
- "| Variant | Params | Hidden Dim | Stages | Blocks | T4 VRAM |\n",
- "|---------|--------|------------|--------|--------|---------|\n",
- "| Tiny | ~2M | 128 | 2 | 2 | < 2 GB |\n",
- "| Small | ~8M | 256 | 4 | 4 | ~4 GB |\n",
- "| Base | ~30M | 384 | 6 | 6 | ~8 GB |"
- ]
  }
  ],
  "metadata": {
- "colab": {
- "name": "LiquidFlow: LiquidNN + Mamba-2 SSD Image Generator",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- }
+ "colab": {"name": "LiquidFlow_Train", "provenance": []},
+ "kernelspec": {"display_name": "Python 3", "name": "python3"}
  },
  "nbformat": 4,
  "nbformat_minor": 0
 