krystv committed on
Commit
b48de91
·
verified ·
1 Parent(s): 992d967

V2: Add VAE, fix datasets, streaming, precaching

Files changed (1)
  1. LiquidDiffusion_Training.ipynb +31 -676
LiquidDiffusion_Training.ipynb CHANGED
@@ -1,678 +1,33 @@
1
  {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# 🌊 LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks\n",
8
- "\n",
9
- "**A novel image generation architecture** that replaces attention with Parallel CfC (Closed-form Continuous-depth) blocks from Liquid Neural Networks.\n",
10
- "\n",
11
- "## Key Innovations\n",
12
- "- **No attention mechanism** β€” all spatial mixing via multi-scale depthwise convolutions\n",
13
- "- **Fully parallelizable** β€” no sequential ODE solving loops (unlike original LTC/Neural ODE)\n",
14
- "- **Diffusion timestep IS the liquid time constant** β€” natural CfC-diffusion bridge\n",
15
- "- **Liquid relaxation residuals** β€” time-aware skip connections that adapt to noise level\n",
16
- "- **Fits in 16GB VRAM** β€” designed for Colab free tier (T4 GPU)\n",
17
- "\n",
18
- "## Architecture Based On\n",
19
- "- [CfC Networks](https://arxiv.org/abs/2106.13898) (Hasani et al., Nature Machine Intelligence 2022)\n",
20
- "- [LiquidTAD](https://arxiv.org/abs/2604.18274) β€” parallel liquid relaxation\n",
21
- "- [USM](https://arxiv.org/abs/2504.13499) β€” U-Shape architecture for diffusion\n",
22
- "- [Rectified Flow](https://arxiv.org/abs/2209.03003) β€” simplest flow matching objective\n",
23
- "\n",
24
- "## Training: Rectified Flow\n",
25
- "```\n",
26
- "x_t = (1-t)*x0 + t*noise, t ~ U[0,1]\n",
27
- "Loss = MSE(model(x_t, t), noise - x0) # velocity prediction\n",
28
- "```\n",
29
- "That's it β€” no noise schedule, no variance, just MSE on a straight-line velocity."
30
- ]
31
- },
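For reference, the objective above maps to just a few lines of PyTorch. This is an illustrative sketch assuming a velocity-prediction `model(x_t, t)` as described in the cell, not the repo's `RectifiedFlowTrainer`:

```python
# Minimal rectified-flow sketch (illustrative; mirrors the pseudocode above).
import torch
import torch.nn.functional as F

def rf_loss(model, x0):
    t = torch.rand(x0.shape[0], device=x0.device)   # t ~ U[0,1]
    noise = torch.randn_like(x0)
    t_ = t.view(-1, 1, 1, 1)
    x_t = (1 - t_) * x0 + t_ * noise                # straight-line interpolation
    return F.mse_loss(model(x_t, t), noise - x0)    # velocity target: noise - x0

@torch.no_grad()
def rf_sample(model, shape, steps=50, device='cuda'):
    z = torch.randn(shape, device=device)           # start from pure noise (t=1)
    dt = 1.0 / steps
    for i in range(steps, 0, -1):
        t = torch.full((shape[0],), i / steps, device=device)
        z = z - model(z, t) * dt                    # Euler step toward t=0
    return z
```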
32
- {
33
- "cell_type": "markdown",
34
- "metadata": {},
35
- "source": [
36
- "## πŸ”§ Setup"
37
- ]
38
- },
39
- {
40
- "cell_type": "code",
41
- "execution_count": null,
42
- "metadata": {},
43
- "outputs": [],
44
- "source": [
45
- "# Install dependencies\n",
46
- "!pip install -q torch torchvision datasets Pillow matplotlib tqdm accelerate"
47
- ]
48
- },
49
- {
50
- "cell_type": "code",
51
- "execution_count": null,
52
- "metadata": {},
53
- "outputs": [],
54
- "source": [
55
- "# Clone the repo\n",
56
- "!git clone https://huggingface.co/krystv/liquid-diffusion\n",
57
- "%cd liquid-diffusion"
58
- ]
59
- },
60
- {
61
- "cell_type": "code",
62
- "execution_count": null,
63
- "metadata": {},
64
- "outputs": [],
65
- "source": [
66
- "import torch\n",
67
- "print(f'PyTorch: {torch.__version__}')\n",
68
- "print(f'CUDA available: {torch.cuda.is_available()}')\n",
69
- "if torch.cuda.is_available():\n",
70
- " print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
71
- " print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')"
72
- ]
73
- },
74
- {
75
- "cell_type": "markdown",
76
- "metadata": {},
77
- "source": [
78
- "## πŸ“ Architecture Overview\n",
79
- "\n",
80
- "The core innovation is the **ParallelCfCBlock** β€” a parallelized version of CfC (Closed-form Continuous-depth) networks adapted for 2D image features:\n",
81
- "\n",
82
- "```\n",
83
- "CfC Equation (Hasani et al. 2022, Eq. 10):\n",
84
- " x(t) = Οƒ(-fΒ·t) βŠ™ g + (1 - Οƒ(-fΒ·t)) βŠ™ h\n",
85
- "\n",
86
- "Our adaptation for image generation:\n",
87
- " backbone = SiLU(PointwiseConv(DepthwiseConv(features))) # shared spatial context\n",
88
- " f = Conv1x1(backbone) # time-constant gate\n",
89
- " g = DWConv→SiLU→Conv1x1(backbone) # \"from\" state\n",
90
- " h = DWConv→SiLU→Conv1x1(backbone) # \"to\" state (attractor)\n",
91
- " gate = Οƒ(time_a(t_emb) Β· f - time_b(t_emb)) # liquid time gate\n",
92
- " cfc_out = gate Β· g + (1 - gate) Β· h # CfC interpolation\n",
93
- " \n",
94
- " # Liquid relaxation (from LiquidTAD):\n",
95
- " α = exp(-softplus(ρ) · |t|) # time-aware residual weight\n",
96
- " output = Ξ± Β· input + (1 - Ξ±) Β· cfc_out # adapts to noise level\n",
97
- "```\n",
98
- "\n",
99
- "The **diffusion timestep t** serves double duty:\n",
100
- "1. Standard: conditions the denoiser via AdaLN scale/shift\n",
101
- "2. Novel: acts as the CfC time parameter β€” controls interpolation between g and h\n",
102
- "\n",
103
- "This means: at low noise (tβ‰ˆ0), the gate is balanced β†’ flexible processing.\n",
104
- "At high noise (tβ‰ˆ1), the gate saturates β†’ specialized denoising."
105
- ]
106
- },
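As a concrete illustration of the pseudocode above, here is a minimal PyTorch sketch of the CfC time gate. Shapes are assumptions (features `[B, C, H, W]`, time embedding `[B, D]`), and the 1×1 heads are simplified for brevity; this is not the repo's exact module:

```python
# Illustrative CfC gate: interpolates between learned states g and h,
# with the interpolation rate set by the time embedding and feature-dependent f.
import torch
import torch.nn as nn

class CfCGate(nn.Module):
    def __init__(self, dim, t_dim):
        super().__init__()
        self.f = nn.Conv2d(dim, dim, 1)    # time-constant head
        self.g = nn.Conv2d(dim, dim, 1)    # "from" state head
        self.h = nn.Conv2d(dim, dim, 1)    # "to" state (attractor) head
        self.time_a = nn.Linear(t_dim, dim)
        self.time_b = nn.Linear(t_dim, dim)

    def forward(self, x, t_emb):
        f, g, h = self.f(x), self.g(x), self.h(x)
        a = self.time_a(t_emb)[:, :, None, None]
        b = self.time_b(t_emb)[:, :, None, None]
        gate = torch.sigmoid(a * f - b)    # σ(time_a(t)·f - time_b(t))
        return gate * g + (1 - gate) * h   # CfC interpolation
```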
107
- {
108
- "cell_type": "markdown",
109
- "metadata": {},
110
- "source": [
111
- "## πŸ§ͺ Quick Test (verify model works)"
112
- ]
113
- },
114
- {
115
- "cell_type": "code",
116
- "execution_count": null,
117
- "metadata": {},
118
- "outputs": [],
119
- "source": [
120
- "# Run the test suite\n",
121
- "!python test_model.py"
122
- ]
123
- },
124
- {
125
- "cell_type": "markdown",
126
- "metadata": {},
127
- "source": [
128
- "## βš™οΈ Training Configuration\n",
129
- "\n",
130
- "Choose your config based on GPU and target resolution:\n",
131
- "\n",
132
- "| Config | Params | Resolution | Batch Size | VRAM | Training Time |\n",
133
- "|--------|--------|-----------|------------|------|---------------|\n",
134
- "| tiny | ~8M | 256Γ—256 | 8 | ~6GB | ~3h (100K steps) |\n",
135
- "| small | ~25M | 256Γ—256 | 4 | ~10GB | ~6h (100K steps) |\n",
136
- "| base | ~65M | 512Γ—512 | 2 | ~14GB | ~12h (100K steps) |\n",
137
- "\n",
138
- "Recommended datasets:\n",
139
- "- `huggan/CelebA-HQ` β€” 30K high-quality face images (256px)\n",
140
- "- `huggan/flowers-102-categories` β€” flowers (various)\n",
141
- "- `lambdalabs/naruto-blip-captions` β€” anime style (~1K)\n",
142
- "- `Norod78/simpsons-blip-captions` β€” cartoon style\n",
143
- "- Any folder of images"
144
- ]
145
- },
146
- {
147
- "cell_type": "code",
148
- "execution_count": null,
149
- "metadata": {},
150
- "outputs": [],
151
- "source": [
152
- "#@title Training Configuration {display-mode: \"form\"}\n",
153
- "\n",
154
- "#@markdown ### Model\n",
155
- "model_size = \"tiny\" #@param [\"tiny\", \"small\", \"base\"]\n",
156
- "\n",
157
- "#@markdown ### Data\n",
158
- "dataset_name = \"huggan/CelebA-HQ\" #@param {type:\"string\"}\n",
159
- "image_column = \"image\" #@param {type:\"string\"}\n",
160
- "image_size = 256 #@param [64, 128, 256, 512] {type:\"integer\"}\n",
161
- "max_samples = 0 #@param {type:\"integer\"}\n",
162
- "\n",
163
- "#@markdown ### Training\n",
164
- "batch_size = 8 #@param {type:\"integer\"}\n",
165
- "learning_rate = 1e-4 #@param {type:\"number\"}\n",
166
- "weight_decay = 0.01 #@param {type:\"number\"}\n",
167
- "total_steps = 100000 #@param {type:\"integer\"}\n",
168
- "warmup_steps = 1000 #@param {type:\"integer\"}\n",
169
- "grad_clip = 1.0 #@param {type:\"number\"}\n",
170
- "ema_decay = 0.9999 #@param {type:\"number\"}\n",
171
- "time_sampling = \"logit_normal\" #@param [\"uniform\", \"logit_normal\"]\n",
172
- "\n",
173
- "#@markdown ### Sampling & Logging\n",
174
- "sample_every = 2000 #@param {type:\"integer\"}\n",
175
- "save_every = 5000 #@param {type:\"integer\"}\n",
176
- "num_sample_steps = 50 #@param {type:\"integer\"}\n",
177
- "num_sample_images = 4 #@param {type:\"integer\"}\n",
178
- "\n",
179
- "#@markdown ### Hardware\n",
180
- "use_amp = True #@param {type:\"boolean\"}\n",
181
- "amp_dtype = \"float16\" #@param [\"float16\", \"bfloat16\"]\n",
182
- "num_workers = 2 #@param {type:\"integer\"}\n",
183
- "\n",
184
- "# Auto-adjust batch size for resolution\n",
185
- "if image_size >= 512 and batch_size > 4:\n",
186
- " batch_size = min(batch_size, 2)\n",
187
- " print(f\"Auto-reduced batch_size to {batch_size} for {image_size}px\")\n",
188
- "\n",
189
- "if max_samples == 0:\n",
190
- " max_samples = None\n",
191
- "\n",
192
- "print(f\"\\nConfig: {model_size} model, {image_size}px, batch={batch_size}, lr={learning_rate}\")\n",
193
- "print(f\"Dataset: {dataset_name}, time_sampling={time_sampling}\")\n",
194
- "print(f\"Total steps: {total_steps:,}, AMP: {use_amp} ({amp_dtype})\")"
195
- ]
196
- },
197
- {
198
- "cell_type": "markdown",
199
- "metadata": {},
200
- "source": [
201
- "## πŸ“¦ Load Dataset"
202
- ]
203
- },
204
- {
205
- "cell_type": "code",
206
- "execution_count": null,
207
- "metadata": {},
208
- "outputs": [],
209
- "source": [
210
- "from datasets import load_dataset\n",
211
- "from liquid_diffusion.trainer import ImageDataset\n",
212
- "from torch.utils.data import DataLoader\n",
213
- "import matplotlib.pyplot as plt\n",
214
- "import numpy as np\n",
215
- "\n",
216
- "# Load dataset\n",
217
- "print(f\"Loading {dataset_name}...\")\n",
218
- "dataset = ImageDataset(\n",
219
- " source=dataset_name,\n",
220
- " image_size=image_size,\n",
221
- " image_column=image_column,\n",
222
- " max_samples=max_samples,\n",
223
- ")\n",
224
- "print(f\"Dataset size: {len(dataset)} images\")\n",
225
- "\n",
226
- "dataloader = DataLoader(\n",
227
- " dataset, batch_size=batch_size, shuffle=True,\n",
228
- " num_workers=num_workers, pin_memory=True, drop_last=True,\n",
229
- ")\n",
230
- "\n",
231
- "# Show some samples\n",
232
- "sample_batch = next(iter(dataloader))\n",
233
- "fig, axes = plt.subplots(1, min(4, batch_size), figsize=(16, 4))\n",
234
- "for i, ax in enumerate(axes):\n",
235
- " img = sample_batch[i].permute(1, 2, 0).numpy() * 0.5 + 0.5 # [-1,1] -> [0,1]\n",
236
- " ax.imshow(np.clip(img, 0, 1))\n",
237
- " ax.axis('off')\n",
238
- "plt.suptitle(f'Training samples ({image_size}Γ—{image_size})')\n",
239
- "plt.tight_layout()\n",
240
- "plt.show()"
241
- ]
242
- },
243
- {
244
- "cell_type": "markdown",
245
- "metadata": {},
246
- "source": [
247
- "## πŸ—οΈ Build Model"
248
- ]
249
- },
250
- {
251
- "cell_type": "code",
252
- "execution_count": null,
253
- "metadata": {},
254
- "outputs": [],
255
- "source": [
256
- "from liquid_diffusion.model import (\n",
257
- " liquid_diffusion_tiny, liquid_diffusion_small, liquid_diffusion_base\n",
258
- ")\n",
259
- "\n",
260
- "# Build model\n",
261
- "model_factories = {\n",
262
- " 'tiny': liquid_diffusion_tiny,\n",
263
- " 'small': liquid_diffusion_small,\n",
264
- " 'base': liquid_diffusion_base,\n",
265
- "}\n",
266
- "\n",
267
- "model = model_factories[model_size]()\n",
268
- "total_params, trainable_params = model.count_params()\n",
269
- "print(f\"Model: liquid_diffusion_{model_size}\")\n",
270
- "print(f\"Parameters: {total_params:,} ({total_params/1e6:.1f}M)\")\n",
271
- "print(f\"Trainable: {trainable_params:,}\")\n",
272
- "\n",
273
- "# Quick forward pass test\n",
274
- "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
275
- "model = model.to(device)\n",
276
- "test_x = torch.randn(1, 3, image_size, image_size, device=device)\n",
277
- "test_t = torch.tensor([0.5], device=device)\n",
278
- "with torch.no_grad():\n",
279
- " test_out = model(test_x, test_t)\n",
280
- "print(f\"Forward pass OK: {test_x.shape} β†’ {test_out.shape}\")\n",
281
- "del test_x, test_out\n",
282
- "if device == 'cuda':\n",
283
- " torch.cuda.empty_cache()"
284
- ]
285
- },
286
- {
287
- "cell_type": "markdown",
288
- "metadata": {},
289
- "source": [
290
- "## πŸš€ Train!"
291
- ]
292
- },
293
- {
294
- "cell_type": "code",
295
- "execution_count": null,
296
- "metadata": {},
297
- "outputs": [],
298
- "source": [
299
- "import os\n",
300
- "import time\n",
301
- "import math\n",
302
- "from tqdm.auto import tqdm\n",
303
- "from torchvision.utils import save_image, make_grid\n",
304
- "from liquid_diffusion.trainer import RectifiedFlowTrainer, get_cosine_schedule_with_warmup\n",
305
- "\n",
306
- "# Create output directories\n",
307
- "os.makedirs('checkpoints', exist_ok=True)\n",
308
- "os.makedirs('samples', exist_ok=True)\n",
309
- "\n",
310
- "# Build trainer\n",
311
- "trainer = RectifiedFlowTrainer(\n",
312
- " model=model,\n",
313
- " lr=learning_rate,\n",
314
- " weight_decay=weight_decay,\n",
315
- " ema_decay=ema_decay,\n",
316
- " grad_clip=grad_clip,\n",
317
- " time_sampling=time_sampling,\n",
318
- " device=device,\n",
319
- " use_amp=use_amp,\n",
320
- " amp_dtype=amp_dtype,\n",
321
- ")\n",
322
- "\n",
323
- "# Learning rate scheduler\n",
324
- "scheduler = get_cosine_schedule_with_warmup(\n",
325
- " trainer.optimizer, warmup_steps, total_steps\n",
326
- ")\n",
327
- "\n",
328
- "# Optional: resume from checkpoint\n",
329
- "resume_path = 'checkpoints/latest.pt'\n",
330
- "if os.path.exists(resume_path):\n",
331
- " trainer.load_checkpoint(resume_path)\n",
332
- " print(f\"Resumed from step {trainer.step}\")\n",
333
- "\n",
334
- "print(f\"\\n{'='*60}\")\n",
335
- "print(f\"Starting training: {total_steps:,} steps\")\n",
336
- "print(f\"Model: liquid_diffusion_{model_size} ({total_params/1e6:.1f}M params)\")\n",
337
- "print(f\"Resolution: {image_size}Γ—{image_size}, Batch: {batch_size}\")\n",
338
- "print(f\"LR: {learning_rate}, Warmup: {warmup_steps}, AMP: {use_amp}\")\n",
339
- "print(f\"{'='*60}\\n\")\n",
340
- "\n",
341
- "# Training loop\n",
342
- "start_time = time.time()\n",
343
- "data_iter = iter(dataloader)\n",
344
- "pbar = tqdm(range(trainer.step, total_steps), desc='Training', dynamic_ncols=True)\n",
345
- "loss_history = []\n",
346
- "\n",
347
- "for step in pbar:\n",
348
- " # Get batch (cycle through dataset)\n",
349
- " try:\n",
350
- " batch = next(data_iter)\n",
351
- " except StopIteration:\n",
352
- " data_iter = iter(dataloader)\n",
353
- " batch = next(data_iter)\n",
354
- " \n",
355
- " x0 = batch.to(device)\n",
356
- " \n",
357
- " # Train step\n",
358
- " metrics = trainer.train_step(x0)\n",
359
- " scheduler.step()\n",
360
- " \n",
361
- " # Logging\n",
362
- " loss_history.append(metrics['loss'])\n",
363
- " avg_loss = sum(loss_history[-100:]) / len(loss_history[-100:])\n",
364
- " lr_current = scheduler.get_last_lr()[0]\n",
365
- " \n",
366
- " pbar.set_postfix({\n",
367
- " 'loss': f\"{metrics['loss']:.4f}\",\n",
368
- " 'avg': f\"{avg_loss:.4f}\",\n",
369
- " 'lr': f\"{lr_current:.6f}\",\n",
370
- " 'gn': f\"{metrics['grad_norm']:.2f}\",\n",
371
- " })\n",
372
- " \n",
373
- " # Generate samples\n",
374
- " if (step + 1) % sample_every == 0 or step == 0:\n",
375
- " print(f\"\\nGenerating samples at step {step+1}...\")\n",
376
- " samples = trainer.sample(\n",
377
- " batch_size=num_sample_images, image_size=image_size,\n",
378
- " num_steps=num_sample_steps, use_ema=True\n",
379
- " )\n",
380
- " # Save grid\n",
381
- " grid = make_grid(samples * 0.5 + 0.5, nrow=int(math.sqrt(num_sample_images)), padding=2)\n",
382
- " save_image(grid, f'samples/step_{step+1:06d}.png')\n",
383
- " \n",
384
- " # Display\n",
385
- " fig, axes = plt.subplots(1, num_sample_images, figsize=(4*num_sample_images, 4))\n",
386
- " if num_sample_images == 1:\n",
387
- " axes = [axes]\n",
388
- " for i, ax in enumerate(axes):\n",
389
- " img = samples[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
390
- " ax.imshow(np.clip(img, 0, 1))\n",
391
- " ax.axis('off')\n",
392
- " plt.suptitle(f'Step {step+1} (EMA samples, {num_sample_steps} Euler steps)')\n",
393
- " plt.tight_layout()\n",
394
- " plt.show()\n",
395
- " \n",
396
- " # Save checkpoint\n",
397
- " if (step + 1) % save_every == 0:\n",
398
- " trainer.save_checkpoint(f'checkpoints/step_{step+1:06d}.pt', extra={'config': {\n",
399
- " 'model_size': model_size, 'image_size': image_size,\n",
400
- " 'batch_size': batch_size, 'learning_rate': learning_rate,\n",
401
- " }})\n",
402
- " trainer.save_checkpoint('checkpoints/latest.pt')\n",
403
- " print(f\"Saved checkpoint at step {step+1}\")\n",
404
- " \n",
405
- " # Safety: check for NaN\n",
406
- " if math.isnan(metrics['loss']):\n",
407
- " print(\"\\n⚠️ NaN loss detected! Stopping training.\")\n",
408
- " print(\"Try: reduce learning_rate, increase grad_clip, or use smaller model\")\n",
409
- " break\n",
410
- "\n",
411
- "elapsed = time.time() - start_time\n",
412
- "print(f\"\\nTraining complete! {trainer.step:,} steps in {elapsed/3600:.1f}h\")\n",
413
- "print(f\"Final avg loss: {sum(loss_history[-100:])/len(loss_history[-100:]):.4f}\")\n",
414
- "\n",
415
- "# Final save\n",
416
- "trainer.save_checkpoint('checkpoints/final.pt')\n",
417
- "print(\"Saved final checkpoint.\")"
418
- ]
419
- },
420
- {
421
- "cell_type": "markdown",
422
- "metadata": {},
423
- "source": [
424
- "## πŸ“Š Training Loss Curve"
425
- ]
426
- },
427
- {
428
- "cell_type": "code",
429
- "execution_count": null,
430
- "metadata": {},
431
- "outputs": [],
432
- "source": [
433
- "import matplotlib.pyplot as plt\n",
434
- "import numpy as np\n",
435
- "\n",
436
- "if loss_history:\n",
437
- " # Smooth the loss\n",
438
- " window = min(100, len(loss_history) // 5 + 1)\n",
439
- " smoothed = np.convolve(loss_history, np.ones(window)/window, mode='valid')\n",
440
- " \n",
441
- " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
442
- " \n",
443
- " ax1.plot(loss_history, alpha=0.3, label='Raw')\n",
444
- " ax1.plot(range(window-1, len(loss_history)), smoothed, label=f'Smoothed (w={window})')\n",
445
- " ax1.set_xlabel('Step')\n",
446
- " ax1.set_ylabel('Loss')\n",
447
- " ax1.set_title('Training Loss')\n",
448
- " ax1.legend()\n",
449
- " ax1.grid(True, alpha=0.3)\n",
450
- " \n",
451
- " ax2.plot(loss_history[-min(1000, len(loss_history)):], alpha=0.5)\n",
452
- " ax2.set_xlabel('Recent Steps')\n",
453
- " ax2.set_ylabel('Loss')\n",
454
- " ax2.set_title('Recent Loss (last 1000 steps)')\n",
455
- " ax2.grid(True, alpha=0.3)\n",
456
- " \n",
457
- " plt.tight_layout()\n",
458
- " plt.show()\n",
459
- "else:\n",
460
- " print(\"No training history yet.\")"
461
- ]
462
- },
463
- {
464
- "cell_type": "markdown",
465
- "metadata": {},
466
- "source": [
467
- "## 🎨 Generate Images"
468
- ]
469
- },
470
- {
471
- "cell_type": "code",
472
- "execution_count": null,
473
- "metadata": {},
474
- "outputs": [],
475
- "source": [
476
- "#@title Generation Settings {display-mode: \"form\"}\n",
477
- "num_images = 8 #@param {type:\"integer\"}\n",
478
- "sampling_steps = 50 #@param [25, 50, 100, 200] {type:\"integer\"}\n",
479
- "use_ema_model = True #@param {type:\"boolean\"}\n",
480
- "\n",
481
- "print(f\"Generating {num_images} images with {sampling_steps} Euler steps...\")\n",
482
- "samples = trainer.sample(\n",
483
- " batch_size=num_images, image_size=image_size,\n",
484
- " num_steps=sampling_steps, use_ema=use_ema_model,\n",
485
- ")\n",
486
- "\n",
487
- "# Display\n",
488
- "ncols = min(4, num_images)\n",
489
- "nrows = (num_images + ncols - 1) // ncols\n",
490
- "fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows))\n",
491
- "if nrows == 1 and ncols == 1:\n",
492
- " axes = [[axes]]\n",
493
- "elif nrows == 1:\n",
494
- " axes = [axes]\n",
495
- "for i in range(num_images):\n",
496
- " r, c = i // ncols, i % ncols\n",
497
- " img = samples[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
498
- " axes[r][c].imshow(np.clip(img, 0, 1))\n",
499
- " axes[r][c].axis('off')\n",
500
- "# Hide unused axes\n",
501
- "for i in range(num_images, nrows * ncols):\n",
502
- " r, c = i // ncols, i % ncols\n",
503
- " axes[r][c].axis('off')\n",
504
- "plt.suptitle(f'LiquidDiffusion Samples ({sampling_steps} steps, {\"EMA\" if use_ema_model else \"online\"})')\n",
505
- "plt.tight_layout()\n",
506
- "plt.show()\n",
507
- "\n",
508
- "# Save\n",
509
- "grid = make_grid(samples * 0.5 + 0.5, nrow=ncols, padding=2)\n",
510
- "save_image(grid, 'samples/generated.png')\n",
511
- "print(\"Saved to samples/generated.png\")"
512
- ]
513
- },
514
- {
515
- "cell_type": "markdown",
516
- "metadata": {},
517
- "source": [
518
- "## πŸ”¬ Visualize the Denoising Process"
519
- ]
520
- },
521
- {
522
- "cell_type": "code",
523
- "execution_count": null,
524
- "metadata": {},
525
- "outputs": [],
526
- "source": [
527
- "# Show step-by-step denoising\n",
528
- "num_vis_steps = 10\n",
529
- "total_euler_steps = 50\n",
530
- "vis_interval = total_euler_steps // num_vis_steps\n",
531
- "\n",
532
- "model_vis = trainer.ema_model\n",
533
- "model_vis.eval()\n",
534
- "\n",
535
- "z = torch.randn(1, 3, image_size, image_size, device=device)\n",
536
- "dt = 1.0 / total_euler_steps\n",
537
- "intermediates = [z.clone()]\n",
538
- "\n",
539
- "with torch.no_grad():\n",
540
- " for i in range(total_euler_steps, 0, -1):\n",
541
- " t = torch.full((1,), i / total_euler_steps, device=device)\n",
542
- " v = model_vis(z, t)\n",
543
- " z = z - v * dt\n",
544
- " if (total_euler_steps - i + 1) % vis_interval == 0:\n",
545
- " intermediates.append(z.clone())\n",
546
- "\n",
547
- "intermediates.append(z.clamp(-1, 1))\n",
548
- "\n",
549
- "fig, axes = plt.subplots(1, len(intermediates), figsize=(3*len(intermediates), 3))\n",
550
- "for idx, (ax, img_t) in enumerate(zip(axes, intermediates)):\n",
551
- " img = img_t[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
552
- " ax.imshow(np.clip(img, 0, 1))\n",
553
- " ax.axis('off')\n",
554
- " if idx == 0:\n",
555
- " ax.set_title('Noise (t=1)')\n",
556
- " elif idx == len(intermediates) - 1:\n",
557
- " ax.set_title('Output (t=0)')\n",
558
- " else:\n",
559
- " ax.set_title(f't={1-idx*vis_interval/total_euler_steps:.1f}')\n",
560
- "plt.suptitle('LiquidDiffusion Denoising Process')\n",
561
- "plt.tight_layout()\n",
562
- "plt.show()"
563
- ]
564
- },
565
- {
566
- "cell_type": "markdown",
567
- "metadata": {},
568
- "source": [
569
- "## πŸ’Ύ Save & Export Model"
570
- ]
571
- },
572
- {
573
- "cell_type": "code",
574
- "execution_count": null,
575
- "metadata": {},
576
- "outputs": [],
577
- "source": [
578
- "# Save final checkpoint\n",
579
- "trainer.save_checkpoint('checkpoints/final.pt', extra={\n",
580
- " 'config': {\n",
581
- " 'model_size': model_size,\n",
582
- " 'image_size': image_size,\n",
583
- " 'total_params': total_params,\n",
584
- " 'training_steps': trainer.step,\n",
585
- " 'dataset': dataset_name,\n",
586
- " }\n",
587
- "})\n",
588
- "print(f\"Saved checkpoint: checkpoints/final.pt\")\n",
589
- "print(f\"Model: liquid_diffusion_{model_size} ({total_params/1e6:.1f}M params)\")\n",
590
- "print(f\"Trained for {trainer.step:,} steps on {dataset_name}\")"
591
- ]
592
- },
593
- {
594
- "cell_type": "code",
595
- "execution_count": null,
596
- "metadata": {},
597
- "outputs": [],
598
- "source": [
599
- "# Optional: Push to Hugging Face Hub\n",
600
- "# Uncomment and fill in your details:\n",
601
- "\n",
602
- "# from huggingface_hub import HfApi, login\n",
603
- "# login() # or use token\n",
604
- "# api = HfApi()\n",
605
- "# repo_id = \"your-username/liquid-diffusion-celebahq-256\" # change this\n",
606
- "# api.create_repo(repo_id, exist_ok=True)\n",
607
- "# api.upload_file('checkpoints/final.pt', 'model.pt', repo_id)\n",
608
- "# api.upload_folder('liquid_diffusion/', 'liquid_diffusion/', repo_id)\n",
609
- "# print(f\"Uploaded to https://huggingface.co/{repo_id}\")"
610
- ]
611
- },
612
- {
613
- "cell_type": "markdown",
614
- "metadata": {},
615
- "source": [
616
- "## πŸ“š Architecture Details & Theory\n",
617
- "\n",
618
- "### Why Liquid Neural Networks for Image Generation?\n",
619
- "\n",
620
- "**Liquid Time-Constant (LTC) Networks** (Hasani et al., 2020) define neurons with input-dependent time constants:\n",
621
- "\n",
622
- "```\n",
623
- "dx/dt = -[1/Ο„ + f(x,I,ΞΈ)] Β· x + f(x,I,ΞΈ) Β· A\n",
624
- "```\n",
625
- "\n",
626
- "The system time constant `Ο„_sys = Ο„/(1 + τ·f)` adapts dynamically based on input β€” the neuron speeds up or slows down its response depending on what it sees. This is the \"liquid\" property.\n",
627
- "\n",
628
- "**CfC (Closed-form Continuous-depth)** networks (Hasani et al., 2022) solve this ODE in closed form:\n",
629
- "\n",
630
- "```\n",
631
- "x(t) = Οƒ(-fΒ·t) βŠ™ g + (1 - Οƒ(-fΒ·t)) βŠ™ h\n",
632
- "```\n",
633
- "\n",
634
- "This eliminates the ODE solver β€” making CfC **fully parallelizable** while preserving the adaptive time constant behavior.\n",
635
- "\n",
636
- "### Our Innovation: CfC Γ— Diffusion Timestep\n",
637
- "\n",
638
- "In diffusion models, the network must process images at different noise levels `t ∈ [0,1]`. We observe that:\n",
639
- "\n",
640
- "1. CfC's time parameter `t` controls interpolation between two learned states\n",
641
- "2. Diffusion's noise level `t` controls how the denoiser should behave\n",
642
- "3. **These are the same concept** β€” the CfC time parameter IS the diffusion timestep\n",
643
- "\n",
644
- "This gives us:\n",
645
- "- At `tβ‰ˆ0` (clean images): Οƒ(-fΒ·t)β‰ˆ0.5, balanced processing for detail refinement\n",
646
- "- At `tβ‰ˆ1` (noisy images): Οƒ(-fΒ·t) saturates, specialized denoising\n",
647
- "- The gate `f` is **input-dependent** β€” different image content gets different time responses\n",
648
- "\n",
649
- "### References\n",
650
- "\n",
651
- "1. Hasani et al., \"Liquid Time-constant Networks\" (AAAI 2021) β€” arxiv:2006.04439\n",
652
- "2. Hasani et al., \"Closed-form Continuous-time Neural Networks\" (Nature MI 2022) β€” arxiv:2106.13898\n",
653
- "3. LiquidTAD: Parallel liquid relaxation β€” arxiv:2604.18274\n",
654
- "4. USM: U-Shape Mamba for diffusion β€” arxiv:2504.13499\n",
655
- "5. DiffuSSM: Diffusion without attention β€” arxiv:2311.18257\n",
656
- "6. Liu et al., \"Flow Straight and Fast: Rectified Flow\" (ICLR 2023) β€” arxiv:2209.03003"
657
- ]
658
- }
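A quick numeric check of the claims above: at t=0 the CfC gate σ(-f·t) is exactly 0.5 for any f, and as t grows it saturates at a rate set by the input-dependent f. Illustrative values only:

```python
# Demonstrates gate balance at t=0 and input-dependent saturation at t=1.
import torch

f = torch.tensor([-2.0, 0.5, 2.0])   # example per-channel time constants
for t in (0.0, 0.5, 1.0):
    gate = torch.sigmoid(-f * t)     # σ(-f·t)
    print(f"t={t:.1f} -> gate={[round(g, 3) for g in gate.tolist()]}")
# t=0.0 -> [0.5, 0.5, 0.5]            balanced
# t=1.0 -> [0.881, 0.378, 0.119]      saturated, different per input
```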
659
- ],
660
- "metadata": {
661
- "accelerator": "GPU",
662
- "colab": {
663
- "gpuType": "T4",
664
- "provenance": [],
665
- "toc_visible": true
666
- },
667
- "kernelspec": {
668
- "display_name": "Python 3",
669
- "name": "python3"
670
- },
671
- "language_info": {
672
- "name": "python",
673
- "version": "3.10.0"
674
- }
675
- },
676
- "nbformat": 4,
677
- "nbformat_minor": 0
678
  }
 
1
  {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {"provenance": [], "gpuType": "T4"},
6
+ "kernelspec": {"name": "python3", "display_name": "Python 3"},
7
+ "accelerator": "GPU"
8
+ },
9
+ "cells": [
10
+ {"cell_type": "markdown", "metadata": {}, "source": ["# \ud83c\udf0a LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks\n", "\n", "A **novel** image generation model combining:\n", "- **Liquid Neural Networks** (CfC \u2014 Closed-form Continuous-depth) for adaptive processing\n", "- **Rectified Flow** for simple, stable training (MSE velocity prediction)\n", "- **Pretrained SD-VAE** for efficient latent-space training (4ch, 8\u00d7 downscale)\n", "- **Zero attention** \u2014 fully convolutional + multi-scale spatial mixing\n", "- **Fully parallelizable** \u2014 no ODE loops, no recurrence\n", "\n", "### Key Innovation\n", "Diffusion timestep = liquid time constant. CfC gate `\u03c3(-f\u00b7t)` adapts behavior to noise level.\n", "\n", "### References\n", "- [CfC Networks (Nature MI 2022)](https://arxiv.org/abs/2106.13898)\n", "- [LiquidTAD (2024)](https://arxiv.org/abs/2604.18274) | [USM (CVPR 2025)](https://arxiv.org/abs/2504.13499)\n", "- [Rectified Flow (ICLR 2023)](https://arxiv.org/abs/2209.03003)\n", "- **Repo**: [krystv/liquid-diffusion](https://huggingface.co/krystv/liquid-diffusion)"]},
11
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \u2699\ufe0f Configuration"]},
12
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["#@title \u2699\ufe0f Configuration { display-mode: \"form\" }\n", "\n", "#@markdown ### Model\n", "MODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'custom']\n", "#@markdown > `tiny`=23M (256px, ~6GB) | `small`=69M (256px, ~10GB)\n", "CUSTOM_CHANNELS = [48, 96, 192]\n", "CUSTOM_BLOCKS = [1, 2, 3]\n", "CUSTOM_T_DIM = 192\n", "\n", "#@markdown ### Resolution\n", "IMAGE_SIZE = 256 #@param [128, 256, 512] {type:\"raw\"}\n", "\n", "#@markdown ### VAE (Latent Space)\n", "USE_VAE = True #@param {type:\"boolean\"}\n", "#@markdown > Pretrained SD-VAE encodes images to 4ch latents (8\u00d7 smaller). **Highly recommended.**\n", "VAE_MODEL = 'stabilityai/sd-vae-ft-mse' #@param ['stabilityai/sd-vae-ft-mse', 'madebyollin/sdxl-vae-fp16-fix']\n", "PRECACHE_LATENTS = True #@param {type:\"boolean\"}\n", "#@markdown > Pre-encode all images once. Frees ~160MB VAE VRAM during training.\n", "\n", "#@markdown ### Dataset\n", "DATASET = 'nielsr/CelebA-faces' #@param ['nielsr/CelebA-faces', 'huggan/flowers-102-categories', 'reach-vb/pokemon-blip-captions', 'huggan/anime-faces', 'huggan/AFHQv2', 'Norod78/cartoon-blip-captions']\n", "#@markdown > All verified \u2713 | CelebA=202K faces | flowers=8K | pokemon=833 | anime=21K | AFHQ=16K animals | cartoon=2K\n", "IMAGE_COLUMN = 'image'\n", "MAX_SAMPLES = None # e.g. 5000 for quick test, None=full\n", "\n", "#@markdown ### Training\n", "BATCH_SIZE = 8 #@param {type:\"integer\"}\n", "LEARNING_RATE = 1e-4 #@param {type:\"number\"}\n", "WEIGHT_DECAY = 0.01 #@param {type:\"number\"}\n", "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n", "GRAD_CLIP = 1.0 #@param {type:\"number\"}\n", "EMA_DECAY = 0.9999 #@param {type:\"number\"}\n", "NUM_WORKERS = 2\n", "TIME_SAMPLING = 'logit_normal' #@param ['uniform', 'logit_normal']\n", "USE_AMP = True #@param {type:\"boolean\"}\n", "AMP_DTYPE = 'float16' #@param ['float16', 'bfloat16']\n", "\n", "#@markdown ### Sampling & Logging\n", "SAMPLE_EVERY = 500 #@param {type:\"integer\"}\n", "NUM_SAMPLE_IMAGES = 8 #@param {type:\"integer\"}\n", "NUM_EULER_STEPS = 50 #@param {type:\"integer\"}\n", "SAVE_EVERY = 2000 #@param {type:\"integer\"}\n", "OUTPUT_DIR = './outputs'\n", "RESUME_FROM = None\n", "LOG_EVERY = 50\n", "\n", "LATENT_SIZE = IMAGE_SIZE // 8 if USE_VAE else IMAGE_SIZE\n", "IN_CHANNELS = 4 if USE_VAE else 3\n", "print(f\"Config: {MODEL_SIZE} | {IMAGE_SIZE}px {'(latent '+str(LATENT_SIZE)+'px)' if USE_VAE else '(pixel)'} | {DATASET}\")\n", "print(f\"Training: bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={NUM_EPOCHS}, AMP={USE_AMP}\")"]},
13
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udce6 Install & Check GPU"]},
14
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["!pip install -q datasets diffusers accelerate huggingface_hub Pillow matplotlib transformers\n", "import torch\n", "print(f\"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}\")\n", "if torch.cuda.is_available():\n", " print(f\"GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem/1e9:.1f}GB\")\n", "else:\n", " print(\"\u26a0\ufe0f No GPU! Enable via Runtime \u2192 Change runtime type.\")"]},
15
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83c\udfd7\ufe0f Model Architecture"]},
16
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import math, copy, os, time\nfrom glob import glob\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision import transforms\nfrom torchvision.utils import save_image, make_grid\n\nclass SinusoidalTimeEmbedding(nn.Module):\n def __init__(self, dim, max_period=10000):\n super().__init__()\n self.dim, self.mp = dim, max_period\n self.mlp = nn.Sequential(nn.Linear(dim, dim*4), nn.SiLU(), nn.Linear(dim*4, dim))\n def forward(self, t):\n h = self.dim // 2\n f = torch.exp(-math.log(self.mp)*torch.arange(h, device=t.device, dtype=t.dtype)/h)\n e = torch.cat([torch.cos(t[:,None]*f[None]), torch.sin(t[:,None]*f[None])], -1)\n if self.dim%2: e = F.pad(e,(0,1))\n return self.mlp(e)\n\nclass AdaLN(nn.Module):\n def __init__(self, dim, cd):\n super().__init__()\n ng = min(32, dim)\n while dim%ng!=0: ng-=1\n self.norm = nn.GroupNorm(ng, dim, affine=False)\n self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cd, dim*2))\n def forward(self, x, te):\n s, sh = self.proj(te).chunk(2,1)\n return self.norm(x)*(1+s[:,:,None,None])+sh[:,:,None,None]\n\nclass ParallelCfCBlock(nn.Module):\n def __init__(self, dim, td, er=2.0, ks=7, dr=0.0):\n super().__init__()\n hid = int(dim*er)\n self.bdw = nn.Conv2d(dim, dim, ks, padding=ks//2, groups=dim)\n self.bpw = nn.Conv2d(dim, hid, 1)\n self.ba = nn.SiLU()\n self.fh = nn.Conv2d(hid, dim, 1)\n self.gh = nn.Sequential(nn.Conv2d(hid,hid,ks,padding=ks//2,groups=hid),nn.SiLU(),nn.Conv2d(hid,dim,1))\n self.hh = nn.Sequential(nn.Conv2d(hid,hid,ks,padding=ks//2,groups=hid),nn.SiLU(),nn.Conv2d(hid,dim,1))\n self.ta, self.tb = nn.Linear(td, dim), nn.Linear(td, dim)\n self.rho = nn.Parameter(torch.zeros(1,dim,1,1))\n self.og = nn.Sequential(nn.SiLU(), nn.Linear(td, dim))\n self.do = nn.Dropout(dr) if dr>0 else nn.Identity()\n def forward(self, x, te):\n res = x\n bb = self.ba(self.bpw(self.bdw(x)))\n f,g,h = self.fh(bb), self.gh(bb), self.hh(bb)\n gt = torch.sigmoid(self.ta(te)[:,:,None,None]*f - self.tb(te)[:,:,None,None])\n co = self.do(gt*g + (1-gt)*h)\n lam = F.softplus(self.rho)+1e-6\n al = torch.exp(-lam*te.mean(1,keepdim=True)[:,:,None,None].abs().clamp(min=0.01))\n return (al*res+(1-al)*co)*torch.sigmoid(self.og(te))[:,:,None,None]\n\nclass MultiScaleSpatialMix(nn.Module):\n def __init__(self, dim, td):\n super().__init__()\n self.d3=nn.Conv2d(dim,dim,3,padding=1,groups=dim)\n self.d5=nn.Conv2d(dim,dim,5,padding=2,groups=dim)\n self.d7=nn.Conv2d(dim,dim,7,padding=3,groups=dim)\n self.gp=nn.AdaptiveAvgPool2d(1); self.gpj=nn.Conv2d(dim,dim,1)\n self.mg=nn.Conv2d(dim*4,dim,1); self.ac=nn.SiLU(); self.an=AdaLN(dim,td)\n def forward(self, x, te):\n xn=self.an(x,te)\n return x+self.ac(self.mg(torch.cat([self.d3(xn),self.d5(xn),self.d7(xn),self.gpj(self.gp(xn)).expand_as(xn)],1)))\n\nclass LiquidDiffusionBlock(nn.Module):\n def __init__(self, dim, td, er=2.0, ks=7, dr=0.0):\n super().__init__()\n self.a1=AdaLN(dim,td); self.cfc=ParallelCfCBlock(dim,td,er,ks,dr)\n self.sm=MultiScaleSpatialMix(dim,td); self.a2=AdaLN(dim,td)\n ff=int(dim*er); self.ff=nn.Sequential(nn.Conv2d(dim,ff,1),nn.SiLU(),nn.Conv2d(ff,dim,1))\n self.rs=nn.Parameter(torch.ones(1)*0.1)\n def forward(self, x, te):\n x=x+self.rs*self.cfc(self.a1(x,te),te); x=self.sm(x,te)\n return x+self.rs*self.ff(self.a2(x,te))\n\nclass DS(nn.Module):\n def __init__(self,i,o): super().__init__(); self.c=nn.Conv2d(i,o,3,stride=2,padding=1)\n def forward(self,x): 
return self.c(x)\nclass US(nn.Module):\n def __init__(self,i,o): super().__init__(); self.c=nn.Conv2d(i,o,3,padding=1)\n def forward(self,x): return self.c(F.interpolate(x,scale_factor=2,mode='nearest'))\nclass SF(nn.Module):\n def __init__(self,d,td): super().__init__(); self.p=nn.Conv2d(d*2,d,1); self.g=nn.Sequential(nn.SiLU(),nn.Linear(td,d),nn.Sigmoid())\n def forward(self,x,sk,te): m=self.p(torch.cat([x,sk],1)); g=self.g(te)[:,:,None,None]; return m*g+x*(1-g)\n\nclass LiquidDiffusionUNet(nn.Module):\n def __init__(self, in_ch=3, chs=None, bps=None, td=256, er=2.0, ks=7, dr=0.0):\n super().__init__()\n chs=chs or [64,128,256]; bps=bps or [2,2,4]\n assert len(chs)==len(bps)\n self.chs,self.ns=chs,len(chs)\n self.te=SinusoidalTimeEmbedding(td)\n self.st=nn.Sequential(nn.Conv2d(in_ch,chs[0],3,padding=1),nn.SiLU(),nn.Conv2d(chs[0],chs[0],3,padding=1))\n self.enc,self.dn=nn.ModuleList(),nn.ModuleList()\n for i in range(self.ns):\n self.enc.append(nn.ModuleList([LiquidDiffusionBlock(chs[i],td,er,ks,dr) for _ in range(bps[i])]))\n if i<self.ns-1: self.dn.append(DS(chs[i],chs[i+1]))\n self.bot=nn.ModuleList([LiquidDiffusionBlock(chs[-1],td,er,ks,dr) for _ in range(2)])\n self.dec,self.up_,self.sf_=nn.ModuleList(),nn.ModuleList(),nn.ModuleList()\n for i in range(self.ns-1,-1,-1):\n if i<self.ns-1: self.up_.append(US(chs[i+1],chs[i])); self.sf_.append(SF(chs[i],td))\n self.dec.append(nn.ModuleList([LiquidDiffusionBlock(chs[i],td,er,ks,dr) for _ in range(bps[i])]))\n hg=min(32,chs[0])\n while chs[0]%hg!=0: hg-=1\n self.hd=nn.Sequential(nn.GroupNorm(hg,chs[0]),nn.SiLU(),nn.Conv2d(chs[0],in_ch,3,padding=1))\n nn.init.zeros_(self.hd[-1].weight); nn.init.zeros_(self.hd[-1].bias)\n def forward(self, x, t):\n te=self.te(t); h=self.st(x); sk=[]\n for i in range(self.ns):\n for b in self.enc[i]: h=b(h,te)\n sk.append(h)\n if i<self.ns-1: h=self.dn[i](h)\n for b in self.bot: h=b(h,te)\n ui=0\n for di in range(self.ns):\n si=self.ns-1-di\n if di>0: h=self.up_[ui](h); h=self.sf_[ui](h,sk[si],te); ui+=1\n for b in self.dec[di]: h=b(h,te)\n return self.hd(h)\n def count_params(self): return sum(p.numel() for p in self.parameters()), sum(p.numel() for p in self.parameters() if p.requires_grad)\n\nprint('\u2705 Model architecture defined.')"]},
17
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udd27 Build Model + Load VAE"]},
18
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["device = 'cuda' if torch.cuda.is_available() else 'cpu'\nCFGS = {'tiny': dict(chs=[64,128,256], bps=[2,2,4], td=256), 'small': dict(chs=[96,192,384], bps=[2,3,6], td=384)}\nif MODEL_SIZE=='custom': cfg=dict(chs=CUSTOM_CHANNELS,bps=CUSTOM_BLOCKS,td=CUSTOM_T_DIM)\nelse: cfg=CFGS[MODEL_SIZE]\nmodel = LiquidDiffusionUNet(in_ch=IN_CHANNELS, **cfg).to(device)\ntp,_=model.count_params()\nprint(f'Model: {MODEL_SIZE} | {tp:,} params ({tp/1e6:.1f}M) | in_ch={IN_CHANNELS}')\n\nvae=None; vae_scale=1.0\nif USE_VAE:\n from diffusers import AutoencoderKL\n print(f'Loading VAE: {VAE_MODEL}...')\n vae = AutoencoderKL.from_pretrained(VAE_MODEL, torch_dtype=torch.float16 if device=='cuda' else torch.float32)\n vae = vae.to(device).eval(); vae.requires_grad_(False)\n vae_scale = vae.config.scaling_factor\n print(f'VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params, latent_ch={vae.config.latent_channels}, scale={vae_scale}')\n print(f' {IMAGE_SIZE}px \u2192 {LATENT_SIZE}px latent (8\u00d7 downsample)')\n\nwith torch.no_grad():\n tx=torch.randn(1,IN_CHANNELS,LATENT_SIZE,LATENT_SIZE,device=device)\n assert model(tx, torch.tensor([0.5],device=device)).shape==tx.shape\n print(f'Forward OK: {tx.shape}')\n del tx\nif device=='cuda': torch.cuda.empty_cache(); print(f'VRAM: {torch.cuda.memory_allocated()/1e9:.2f}GB')"]},
19
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcca Load Dataset"]},
20
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["from PIL import Image\nfrom datasets import load_dataset\nimport matplotlib.pyplot as plt\n\nclass ImageDS(Dataset):\n def __init__(self, ds, sz, col='image'):\n self.ds, self.col = ds, col\n self.tf = transforms.Compose([transforms.Resize(sz, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(sz), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])\n def __len__(self): return len(self.ds)\n def __getitem__(self, i):\n img = self.ds[i][self.col]\n if not hasattr(img,'convert'): img=Image.fromarray(img)\n return self.tf(img.convert('RGB'))\n\nprint(f'Loading: {DATASET}')\nraw = load_dataset(DATASET, split='train')\nif MAX_SAMPLES: raw = raw.select(range(min(MAX_SAMPLES, len(raw))))\nprint(f' {len(raw):,} images')\ndataset = ImageDS(raw, IMAGE_SIZE, IMAGE_COLUMN)\ndata_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True, persistent_workers=True)\nprint(f' {len(data_loader):,} steps/epoch | ~{len(data_loader)*NUM_EPOCHS:,} total steps')\n\nsb=next(iter(data_loader))\nfig,axes=plt.subplots(1,min(8,BATCH_SIZE),figsize=(16,2.5))\nfor i,ax in enumerate(axes if hasattr(axes,'__len__') else [axes]): ax.imshow((sb[i].permute(1,2,0)*0.5+0.5).clamp(0,1)); ax.axis('off')\nplt.suptitle(f'Training samples ({IMAGE_SIZE}px)'); plt.tight_layout(); plt.show()"]},
21
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\uddc3\ufe0f Pre-cache Latents (if VAE enabled)"]},
22
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["cached_latents = None\ntrain_loader = data_loader\n\nif USE_VAE and PRECACHE_LATENTS:\n print(f'Pre-encoding {len(dataset):,} images...')\n cl = DataLoader(dataset, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n all_z = []\n vd = torch.float16 if device=='cuda' else torch.float32\n t0 = time.time()\n with torch.no_grad():\n for bi, imgs in enumerate(cl):\n z = vae.encode(imgs.to(device, dtype=vd)).latent_dist.sample() * vae_scale\n all_z.append(z.cpu().float())\n if (bi+1)%50==0: print(f' {(bi+1)*BATCH_SIZE*2:,}/{len(dataset):,}')\n cached_latents = torch.cat(all_z)\n print(f' Done in {time.time()-t0:.0f}s | Shape: {cached_latents.shape} | {cached_latents.numel()*4/1e9:.2f}GB')\n vae = vae.cpu()\n if device=='cuda': torch.cuda.empty_cache(); print(f' VAE \u2192 CPU. GPU VRAM: {torch.cuda.memory_allocated()/1e9:.2f}GB')\n class LatDS(Dataset):\n def __init__(self,z): self.z=z\n def __len__(self): return len(self.z)\n def __getitem__(self,i): return self.z[i]\n train_loader = DataLoader(LatDS(cached_latents), batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)\n print(f' Latent loader: {len(train_loader)} steps/epoch')\nelif USE_VAE:\n print('Online VAE encoding (VAE stays on GPU)')\nelse:\n print('Pixel-space training (no VAE)')"]},
23
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\ude80 Training"]},
24
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import matplotlib.pyplot as plt\nos.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True); os.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, betas=(0.9,0.999))\ntotal_steps = len(train_loader)*NUM_EPOCHS\nwarmup = min(1000, total_steps//10)\ndef lrl(step):\n if step<warmup: return step/max(1,warmup)\n return max(0.0, 0.5*(1+math.cos(math.pi*(step-warmup)/max(1,total_steps-warmup))))\nsched = torch.optim.lr_scheduler.LambdaLR(optimizer, lrl)\n\nema = copy.deepcopy(model).eval()\nfor p in ema.parameters(): p.requires_grad_(False)\nscaler = torch.amp.GradScaler('cuda', enabled=(USE_AMP and device=='cuda'))\namp_dt = getattr(torch, AMP_DTYPE) if USE_AMP and device=='cuda' else torch.float32\n\ndef st(bs):\n e=1e-5\n if TIME_SAMPLING=='uniform': return torch.rand(bs,device=device)*(1-2*e)+e\n return torch.sigmoid(torch.randn(bs,device=device)).clamp(e,1-e)\n\ngstep,start_ep,all_losses,ep_losses=0,0,[],[]\nif RESUME_FROM and os.path.exists(RESUME_FROM):\n ck=torch.load(RESUME_FROM,map_location=device,weights_only=False)\n model.load_state_dict(ck['model']); ema.load_state_dict(ck['ema_model']); optimizer.load_state_dict(ck['optimizer'])\n gstep=ck.get('step',0); start_ep=ck.get('epoch',0); all_losses=ck.get('losses',[])\n print(f'Resumed from step {gstep}')\n\n@torch.no_grad()\ndef gen_samples(step):\n ema.eval()\n z=torch.randn(NUM_SAMPLE_IMAGES,IN_CHANNELS,LATENT_SIZE,LATENT_SIZE,device=device)\n dt=1.0/NUM_EULER_STEPS\n for i in range(NUM_EULER_STEPS,0,-1):\n t=torch.full((NUM_SAMPLE_IMAGES,),i/NUM_EULER_STEPS,device=device)\n with torch.amp.autocast(device,dtype=amp_dt,enabled=USE_AMP and device=='cuda'): v=ema(z,t)\n if USE_AMP and amp_dt==torch.float16: v=v.float()\n z=z-v*dt\n z=z.clamp(-3,3)\n if USE_VAE:\n _v=vae.to(device); vd=torch.float16 if device=='cuda' else torch.float32\n imgs=_v.decode(z.to(vd)/vae_scale).sample.float()\n if PRECACHE_LATENTS: vae.cpu()\n else: imgs=z\n imgs=imgs.clamp(-1,1)\n save_image(make_grid(imgs*0.5+0.5,nrow=int(math.ceil(math.sqrt(NUM_SAMPLE_IMAGES))),padding=2),f'{OUTPUT_DIR}/samples/step_{step:06d}.png')\n return imgs\n\nprint(f'\\n{\"=\"*60}\\nTraining: {NUM_EPOCHS} epochs, {total_steps:,} steps\\n{\"=\"*60}\\n')\nt_start=time.time(); online_vae=USE_VAE and not PRECACHE_LATENTS; vd=torch.float16 if device=='cuda' else torch.float32\n\nfor epoch in range(start_ep, NUM_EPOCHS):\n model.train(); el=0\n for batch in train_loader:\n if online_vae:\n with torch.no_grad(): x0=vae.encode(batch.to(device,dtype=vd)).latent_dist.sample()*vae_scale; x0=x0.float()\n else: x0=batch.to(device)\n x1=torch.randn_like(x0); t=st(x0.shape[0]); te=t[:,None,None,None]\n xt=(1-te)*x0+te*x1; vt=x1-x0\n with torch.amp.autocast(device,dtype=amp_dt,enabled=USE_AMP and device=='cuda'):\n vp=model(xt,t); loss=F.mse_loss(vp,vt)\n optimizer.zero_grad(set_to_none=True); scaler.scale(loss).backward()\n if GRAD_CLIP>0: scaler.unscale_(optimizer); torch.nn.utils.clip_grad_norm_(model.parameters(),GRAD_CLIP)\n scaler.step(optimizer); scaler.update(); sched.step()\n with torch.no_grad():\n for ep,mp in zip(ema.parameters(),model.parameters()): ep.data.mul_(EMA_DECAY).add_(mp.data,alpha=1-EMA_DECAY)\n gstep+=1; lv=loss.item(); all_losses.append(lv); el+=lv\n if gstep%LOG_EVERY==0:\n avg=sum(all_losses[-LOG_EVERY:])/LOG_EVERY; lr=sched.get_last_lr()[0]\n sps=gstep/(time.time()-t_start); 
eta=(total_steps-gstep)/max(sps,1e-8)\n vm=f' | VRAM:{torch.cuda.max_memory_allocated()/1e9:.1f}GB' if device=='cuda' else ''\n print(f'Step {gstep:6d}/{total_steps} | Loss:{avg:.4f} | LR:{lr:.2e} | {sps:.1f}it/s | ETA:{eta/60:.0f}m{vm}')\n if gstep%SAMPLE_EVERY==0:\n print(' \\U0001f4f8 Generating...'); samps=gen_samples(gstep)\n fig,axes=plt.subplots(1,min(8,NUM_SAMPLE_IMAGES),figsize=(16,2.5))\n if not hasattr(axes,'__len__'): axes=[axes]\n for i,ax in enumerate(axes):\n if i<samps.shape[0]: ax.imshow((samps[i].cpu().permute(1,2,0)*0.5+0.5).clamp(0,1))\n ax.axis('off')\n plt.suptitle(f'Step {gstep} | Loss:{lv:.4f}'); plt.tight_layout(); plt.show()\n if gstep%SAVE_EVERY==0:\n cp=f'{OUTPUT_DIR}/checkpoints/step_{gstep:06d}.pt'\n torch.save({'model':model.state_dict(),'ema_model':ema.state_dict(),'optimizer':optimizer.state_dict(),'step':gstep,'epoch':epoch,'losses':all_losses[-2000:],'config':cfg},cp)\n print(f' \\U0001f4be Saved: {cp}')\n ep_losses.append(el/len(train_loader))\n print(f' Epoch {epoch+1}/{NUM_EPOCHS} | Avg loss:{ep_losses[-1]:.4f}')\n\nfp=f'{OUTPUT_DIR}/checkpoints/final.pt'\ntorch.save({'model':model.state_dict(),'ema_model':ema.state_dict(),'step':gstep,'config':cfg,'losses':all_losses[-2000:]},fp)\nprint(f'\\n\\u2705 Done! {fp} | {(time.time()-t_start)/3600:.1f}h')"]},
25
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcc8 Training Curves"]},
26
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import numpy as np\nfig,(a1,a2)=plt.subplots(1,2,figsize=(14,5))\na1.plot(all_losses,alpha=0.3,color='blue',lw=0.5)\nw=min(200,len(all_losses)//5)\nif w>1:\n sm=np.convolve(all_losses,np.ones(w)/w,mode='valid')\n a1.plot(range(w-1,len(all_losses)),sm,color='red',lw=2,label=f'Smooth(w={w})')\na1.set_xlabel('Step');a1.set_ylabel('Loss');a1.set_title('Training Loss');a1.legend();a1.grid(True,alpha=0.3)\nif ep_losses: a2.plot(range(1,len(ep_losses)+1),ep_losses,'o-',color='green'); a2.set_xlabel('Epoch');a2.set_ylabel('Loss');a2.set_title('Per Epoch');a2.grid(True,alpha=0.3)\nplt.tight_layout();plt.show()"]},
27
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83c\udfa8 Generate Images"]},
28
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["N_GEN=16; STEPS=50\nprint(f'Generating {N_GEN} images ({STEPS} steps)...')\nema.eval()\nif USE_VAE: vae=vae.to(device)\nwith torch.no_grad():\n z=torch.randn(N_GEN,IN_CHANNELS,LATENT_SIZE,LATENT_SIZE,device=device)\n dt=1.0/STEPS\n for i in range(STEPS,0,-1):\n t=torch.full((N_GEN,),i/STEPS,device=device)\n with torch.amp.autocast(device,dtype=amp_dt,enabled=USE_AMP and device=='cuda'): v=ema(z,t)\n if USE_AMP and amp_dt==torch.float16: v=v.float()\n z=z-v*dt\n if USE_VAE: vdd=torch.float16 if device=='cuda' else torch.float32; gen=vae.decode(z.clamp(-3,3).to(vdd)/vae_scale).sample.float()\n else: gen=z\n gen=gen.clamp(-1,1)\nnr=int(math.ceil(math.sqrt(N_GEN)))\nfig,axes=plt.subplots(nr,nr,figsize=(2.5*nr,2.5*nr))\naxes=axes.flatten() if hasattr(axes,'flatten') else [axes]\nfor i,ax in enumerate(axes):\n if i<N_GEN: ax.imshow((gen[i].cpu().permute(1,2,0)*0.5+0.5).clamp(0,1))\n ax.axis('off')\nplt.suptitle(f'LiquidDiffusion ({IMAGE_SIZE}px, {STEPS} steps)',fontsize=14);plt.tight_layout();plt.show()\nsave_image(make_grid(gen*0.5+0.5,nrow=nr,padding=2),f'{OUTPUT_DIR}/final_samples.png')\nprint(f'Saved: {OUTPUT_DIR}/final_samples.png')"]},
29
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcbe Push to Hub"]},
30
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["PUSH=False #@param {type:\"boolean\"}\nHUB_ID='your-username/liquid-diffusion-256' #@param {type:\"string\"}\nif PUSH:\n from huggingface_hub import HfApi\n api=HfApi(); api.create_repo(HUB_ID,exist_ok=True)\n api.upload_file(path_or_fileobj=fp,path_in_repo='model.pt',repo_id=HUB_ID)\n print(f'Pushed: https://huggingface.co/{HUB_ID}')"]},
31
+ {"cell_type": "markdown", "metadata": {}, "source": ["---\n", "## \ud83d\udcd6 Architecture Reference\n", "\n", "### CfC Time-Gating\n", "```\n", "gate = \u03c3(time_a(t) \u00b7 f(features) - time_b(t))\n", "out = gate \u00b7 g + (1-gate) \u00b7 h\n", "```\n", "### Liquid Relaxation\n", "```\n", "\u03b1 = exp(-\u03bb\u00b7|t|), out = \u03b1\u00b7input + (1-\u03b1)\u00b7CfC_out\n", "```\n", "High noise \u2192 \u03b1\u22480 \u2192 heavy processing. Low noise \u2192 \u03b1\u22481 \u2192 preserve.\n", "\n", "### VAE: `stabilityai/sd-vae-ft-mse`\n", "83M params, 4ch latents, 8\u00d7 downscale. 256px\u219232\u00d732\u00d74 latent.\n", "\n", "### Verified Datasets\n", "| Dataset | Size | Content |\n", "|---------|------|---------|\n", "| `nielsr/CelebA-faces` | 202K | Celebrity faces |\n", "| `huggan/flowers-102-categories` | 8K | Flowers |\n", "| `reach-vb/pokemon-blip-captions` | 833 | Pokemon art |\n", "| `huggan/anime-faces` | 21K | Anime faces |\n", "| `huggan/AFHQv2` | 16K | Cat/dog/wild |\n", "| `Norod78/cartoon-blip-captions` | 2K | Cartoon characters |"]}
32
+ ]
33
  }