krystv committed on
Commit
f349bc4
·
verified ·
1 Parent(s): 02e8800

Add Colab/Kaggle training notebook with all dataset options

Files changed (1)
  1. LiquidFlow_Training.ipynb +747 -0
LiquidFlow_Training.ipynb ADDED
@@ -0,0 +1,747 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🌊 LiquidFlow — Liquid-SSM Flow Matching Image Generator\n",
+ "\n",
+ "A **novel architecture** combining:\n",
+ "- **Liquid Time-Constant Networks** (CfC closed-form) — adaptive ODE dynamics, bounded by construction\n",
+ "- **Selective State Space Models** (Mamba-style) — linear-time long-range context, parallelizable\n",
+ "- **Zigzag Scanning** — 2D spatial awareness for image patches\n",
+ "- **Physics-Informed Regularization** — smoothness + total variation constraints\n",
+ "- **Rectified Flow Matching** — ODE-based generation (no noise schedule tuning)\n",
+ "\n",
+ "### 📋 What this notebook does\n",
+ "1. **Install & clone** the LiquidFlow codebase\n",
+ "2. **Choose a dataset** (CIFAR-10, Flowers-102, CelebA, Fashion-MNIST, AFHQ, LSUN Churches, or a custom folder)\n",
+ "3. **Choose a model size** (tiny ~6M, small ~14M, base ~38M)\n",
+ "4. **Train** with one click — optimized for Colab and Kaggle\n",
+ "5. **Generate images** and visualize progress\n",
+ "6. **Export** the trained model for mobile deployment\n",
+ "\n",
+ "### 💻 Hardware Requirements\n",
+ "| Config | GPU VRAM | Best For |\n",
+ "|--------|----------|----------|\n",
+ "| tiny-128 (bs=32) | ~4 GB | Colab free T4, Kaggle |\n",
+ "| small-128 (bs=16) | ~8 GB | Colab free T4, Kaggle |\n",
+ "| base-256 (bs=8) | ~12 GB | Colab Pro, Kaggle |\n",
+ "| 512 (bs=4) | ~14 GB | Colab Pro, A100 |"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## 0. Setup & Install"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Check GPU\n",
+ "!nvidia-smi || echo 'No GPU — CPU training only (very slow)'\n",
+ "import torch\n",
+ "print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')\n",
+ "if torch.cuda.is_available():\n",
+ "    print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install dependencies\n",
+ "!pip install -q torch torchvision einops pillow matplotlib tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Clone the repo (or just copy the files if you already have them)\n",
+ "import os\n",
+ "if not os.path.exists('liquidflow'):\n",
+ "    !git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n",
+ "    !cp -r liquidflow_repo/liquidflow .\n",
+ "else:\n",
+ "    print('liquidflow/ already exists')\n",
+ "\n",
+ "# Verify\n",
+ "from liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n",
+ "from liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\n",
+ "from liquidflow.sampling import euler_sample, heun_sample, generate_grid, make_grid_image\n",
+ "print('✅ LiquidFlow imported successfully!')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## 1. ⚙️ Configuration — EDIT THIS CELL\n",
+ "\n",
+ "Choose your dataset, model size, and training hyperparameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#@title 🎛️ Training Configuration { display-mode: \"form\" }\n",
+ "\n",
+ "# ============== DATASET ==============\n",
+ "#@markdown ### Dataset\n",
+ "DATASET = 'cifar10' #@param ['cifar10', 'flowers', 'celeba', 'folder', 'fashion_mnist', 'afhq', 'lsun_churches']\n",
+ "CUSTOM_DATA_DIR = '/content/my_images' #@param {type:\"string\"}\n",
+ "#@markdown > For 'folder': put images in CUSTOM_DATA_DIR. Supports .png/.jpg/.webp\n",
+ "\n",
+ "# ============== MODEL ==============\n",
+ "#@markdown ### Model\n",
+ "MODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', '512']\n",
+ "IMG_SIZE = 128 #@param [32, 64, 128, 256, 512] {type:\"integer\"}\n",
+ "\n",
+ "# ============== TRAINING ==============\n",
+ "#@markdown ### Training\n",
+ "EPOCHS = 100 #@param {type:\"integer\"}\n",
+ "BATCH_SIZE = 32 #@param [4, 8, 16, 32, 64, 128] {type:\"integer\"}\n",
+ "LEARNING_RATE = 3e-4 #@param {type:\"number\"}\n",
+ "GRAD_ACCUM = 1 #@param [1, 2, 4, 8] {type:\"integer\"}\n",
+ "USE_AMP = True #@param {type:\"boolean\"}\n",
+ "\n",
+ "# ============== PHYSICS LOSS ==============\n",
+ "#@markdown ### Physics-Informed Regularization\n",
+ "LAMBDA_SMOOTH = 0.01 #@param {type:\"number\"}\n",
+ "LAMBDA_TV = 0.001 #@param {type:\"number\"}\n",
+ "\n",
+ "# ============== SAMPLING ==============\n",
+ "#@markdown ### Sampling & Logging\n",
+ "SAMPLE_EVERY = 5 #@param {type:\"integer\"}\n",
+ "SAMPLE_STEPS = 50 #@param [10, 25, 50, 100] {type:\"integer\"}\n",
+ "LOG_EVERY = 50 #@param {type:\"integer\"}\n",
+ "SAVE_EVERY = 10 #@param {type:\"integer\"}\n",
+ "\n",
+ "# ============== PATHS ==============\n",
+ "OUTPUT_DIR = './outputs'\n",
+ "DATA_DIR = './data'\n",
+ "\n",
+ "# ============== AUTO-CONFIG ==============\n",
+ "# Smart batch size based on GPU memory\n",
+ "import torch\n",
+ "if torch.cuda.is_available():\n",
+ "    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
+ "    print(f'GPU VRAM: {vram_gb:.1f} GB')\n",
+ "\n",
+ "    # Auto-adjust batch size if needed\n",
+ "    recommended = {\n",
+ "        (32, 'tiny'): 128, (64, 'tiny'): 64, (128, 'tiny'): 32,\n",
+ "        (32, 'small'): 64, (64, 'small'): 32, (128, 'small'): 16,\n",
+ "        (256, 'base'): 8, (512, '512'): 4,\n",
+ "    }\n",
+ "    key = (IMG_SIZE, MODEL_SIZE)\n",
+ "    if key in recommended and vram_gb < 16:\n",
+ "        rec_bs = recommended[key]\n",
+ "        if BATCH_SIZE > rec_bs:\n",
+ "            print(f'⚠️ Reducing batch size {BATCH_SIZE} → {rec_bs} for {vram_gb:.0f}GB VRAM')\n",
+ "            BATCH_SIZE = rec_bs\n",
+ "else:\n",
+ "    print('⚠️ No GPU detected — training will be very slow!')\n",
+ "    USE_AMP = False\n",
+ "\n",
+ "print(f'\\n📋 Config: {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={EPOCHS}')\n",
+ "print(f'   Physics: λ_smooth={LAMBDA_SMOOTH}, λ_tv={LAMBDA_TV}')\n",
+ "print(f'   AMP: {USE_AMP}, GradAccum: {GRAD_ACCUM}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## 2. 📦 Load Dataset"
+ ]
+ },
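+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If you chose `folder`, get your images onto the machine before running the cell below.\n",
+ "On Colab the usual route is mounting Google Drive and pointing `CUSTOM_DATA_DIR` at a\n",
+ "Drive folder (the path under `MyDrive` is just an example):\n",
+ "\n",
+ "```python\n",
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')\n",
+ "# then set CUSTOM_DATA_DIR = '/content/drive/MyDrive/my_images' in the config cell\n",
+ "```\n",
+ "\n",
+ "On Kaggle, attach a dataset to the notebook and point `CUSTOM_DATA_DIR` at `/kaggle/input/<your-dataset>`."
+ ]
+ },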
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torchvision\n",
+ "import torchvision.transforms as transforms\n",
+ "from torch.utils.data import DataLoader, Dataset, ConcatDataset\n",
+ "from pathlib import Path\n",
+ "from PIL import Image\n",
+ "import os\n",
+ "\n",
+ "# Standard transform\n",
+ "def get_transform(img_size):\n",
+ "    return transforms.Compose([\n",
+ "        transforms.Resize(img_size + img_size // 8),\n",
+ "        transforms.CenterCrop(img_size),\n",
+ "        transforms.RandomHorizontalFlip(),\n",
+ "        transforms.ToTensor(),\n",
+ "        transforms.Normalize([0.5]*3, [0.5]*3),\n",
+ "    ])\n",
+ "\n",
+ "class ImageFolderFlat(Dataset):\n",
+ "    \"\"\"Load all images from a folder (recursively).\"\"\"\n",
+ "    def __init__(self, root, transform):\n",
+ "        self.transform = transform\n",
+ "        self.files = []\n",
+ "        for ext in ['*.png', '*.jpg', '*.jpeg', '*.webp', '*.bmp']:\n",
+ "            self.files.extend(Path(root).rglob(ext))\n",
+ "        self.files = sorted(self.files)\n",
+ "        print(f'Found {len(self.files)} images in {root}')\n",
+ "    def __len__(self): return len(self.files)\n",
+ "    def __getitem__(self, idx):\n",
+ "        return self.transform(Image.open(self.files[idx]).convert('RGB'))\n",
+ "\n",
+ "class GrayscaleToRGB:\n",
+ "    \"\"\"Convert 1-channel grayscale to 3-channel RGB.\"\"\"\n",
+ "    def __call__(self, x):\n",
+ "        if x.shape[0] == 1:\n",
+ "            x = x.repeat(3, 1, 1)\n",
+ "        return x\n",
+ "\n",
+ "tfm = get_transform(IMG_SIZE)\n",
+ "\n",
+ "if DATASET == 'cifar10':\n",
+ "    dataset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=tfm)\n",
+ "    print(f'✅ CIFAR-10: {len(dataset)} images')\n",
+ "\n",
+ "elif DATASET == 'flowers':\n",
+ "    ds_train = torchvision.datasets.Flowers102(root=DATA_DIR, split='train', download=True, transform=tfm)\n",
+ "    ds_val = torchvision.datasets.Flowers102(root=DATA_DIR, split='val', download=True, transform=tfm)\n",
+ "    ds_test = torchvision.datasets.Flowers102(root=DATA_DIR, split='test', download=True, transform=tfm)\n",
+ "    dataset = ConcatDataset([ds_train, ds_val, ds_test])  # Use all splits for generation\n",
+ "    print(f'✅ Flowers-102: {len(dataset)} images (all splits)')\n",
+ "\n",
+ "elif DATASET == 'celeba':\n",
+ "    dataset = torchvision.datasets.CelebA(root=DATA_DIR, split='train', download=True, transform=tfm)\n",
+ "    print(f'✅ CelebA: {len(dataset)} images')\n",
+ "\n",
+ "elif DATASET == 'fashion_mnist':\n",
+ "    fm_tfm = transforms.Compose([\n",
+ "        transforms.Resize(IMG_SIZE),\n",
+ "        transforms.ToTensor(),\n",
+ "        transforms.Normalize([0.5], [0.5]),\n",
+ "        GrayscaleToRGB(),\n",
+ "    ])\n",
+ "    dataset = torchvision.datasets.FashionMNIST(root=DATA_DIR, train=True, download=True, transform=fm_tfm)\n",
+ "    print(f'✅ Fashion-MNIST: {len(dataset)} images (converted to RGB)')\n",
+ "\n",
+ "elif DATASET == 'afhq':\n",
+ "    # Download AFHQ from Kaggle or manually\n",
+ "    afhq_dir = os.path.join(DATA_DIR, 'afhq', 'train')\n",
+ "    if not os.path.exists(afhq_dir):\n",
+ "        print('⬇️ Downloading AFHQ...')\n",
+ "        !pip install -q gdown\n",
+ "        !gdown 1Gof5BaELXlmSJIlvKMYCe9ONYPebkNsf -O {DATA_DIR}/afhq.zip\n",
+ "        !unzip -q {DATA_DIR}/afhq.zip -d {DATA_DIR}/afhq\n",
+ "    dataset = ImageFolderFlat(afhq_dir, tfm)\n",
+ "    print(f'✅ AFHQ: {len(dataset)} images')\n",
+ "\n",
+ "elif DATASET == 'lsun_churches':\n",
+ "    # LSUN requires a manual download — point to the extracted folder\n",
+ "    lsun_dir = os.path.join(DATA_DIR, 'lsun_churches')\n",
+ "    if not os.path.exists(lsun_dir):\n",
+ "        print('❌ LSUN churches not found. Please download and extract to', lsun_dir)\n",
+ "        print('   See: https://github.com/fyu/lsun')\n",
+ "        raise FileNotFoundError(lsun_dir)\n",
+ "    dataset = ImageFolderFlat(lsun_dir, tfm)\n",
+ "    print(f'✅ LSUN Churches: {len(dataset)} images')\n",
+ "\n",
+ "elif DATASET == 'folder':\n",
+ "    dataset = ImageFolderFlat(CUSTOM_DATA_DIR, tfm)\n",
+ "    print(f'✅ Custom folder: {len(dataset)} images from {CUSTOM_DATA_DIR}')\n",
+ "\n",
+ "else:\n",
+ "    raise ValueError(f'Unknown dataset: {DATASET}')\n",
+ "\n",
+ "# Show a few samples\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "fig, axes = plt.subplots(1, 8, figsize=(16, 2))\n",
+ "for i, ax in enumerate(axes):\n",
+ "    sample = dataset[i]\n",
+ "    if isinstance(sample, (list, tuple)):\n",
+ "        sample = sample[0]  # (image, label) datasets\n",
+ "    img = sample * 0.5 + 0.5  # denormalize\n",
+ "    ax.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())\n",
+ "    ax.axis('off')\n",
+ "plt.suptitle(f'{DATASET} samples ({IMG_SIZE}×{IMG_SIZE})', fontsize=14)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## 3. 🏗️ Build Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "model_factories = {\n",
+ "    'tiny': liquidflow_tiny,\n",
+ "    'small': liquidflow_small,\n",
+ "    'base': liquidflow_base,\n",
+ "    '512': liquidflow_512,\n",
+ "}\n",
+ "\n",
+ "model = model_factories[MODEL_SIZE](img_size=IMG_SIZE).to(device)\n",
+ "\n",
+ "num_params = model.count_params()\n",
+ "print(f'🏗️ LiquidFlow-{MODEL_SIZE}')\n",
+ "print(f'   Parameters: {num_params:,} ({num_params/1e6:.1f}M)')\n",
+ "print(f'   Image size: {IMG_SIZE}×{IMG_SIZE}')\n",
+ "print(f'   Patch size: {model.patch_size}')\n",
+ "print(f'   Num patches: {model.num_patches}')\n",
+ "print(f'   Model dim: {model.d_model}')\n",
+ "print(f'   Depth: {model.depth}')\n",
+ "print(f'   Device: {device}')\n",
+ "\n",
+ "# Quick forward pass test\n",
+ "with torch.no_grad():\n",
+ "    test_x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)\n",
+ "    test_t = torch.tensor([0.5], device=device)\n",
+ "    test_v = model(test_x, test_t)\n",
+ "    assert test_v.shape == test_x.shape\n",
+ "    print(f'   ✅ Forward pass OK: {test_x.shape} → {test_v.shape}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## 4. 🚀 Train"
+ ]
+ },
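+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The loop below implements rectified flow matching: each image `x1` is paired with\n",
+ "Gaussian noise `x0`, interpolated at a random time `t`, and the model learns the\n",
+ "straight-line velocity between them. A minimal sketch of the target construction,\n",
+ "using the same variable names as the training cell (the exact weighting lives in\n",
+ "`liquidflow/losses.py`; this only shows the standard rectified-flow target it is\n",
+ "assumed to use):\n",
+ "\n",
+ "```python\n",
+ "x0 = torch.randn_like(x1)            # noise endpoint\n",
+ "t_e = t.view(-1, 1, 1, 1)            # broadcast t over C, H, W\n",
+ "x_t = t_e * x1 + (1 - t_e) * x0      # point on the straight path\n",
+ "v_target = x1 - x0                   # constant velocity along that path\n",
+ "# the flow term of the loss is then MSE(v_pred, v_target)\n",
+ "```"
+ ]
+ },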
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "import time\n",
+ "import json\n",
+ "import torch.nn as nn\n",
+ "from torch.cuda.amp import autocast, GradScaler\n",
+ "from liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\n",
+ "from liquidflow.sampling import euler_sample, make_grid_image\n",
+ "from IPython.display import display, clear_output\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Prepare\n",
+ "os.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\n",
+ "os.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n",
+ "\n",
+ "dataloader = DataLoader(\n",
+ "    dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
+ "    num_workers=2, pin_memory=True, drop_last=True\n",
+ ")\n",
+ "\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,\n",
+ "                              betas=(0.9, 0.999), weight_decay=0.01)\n",
+ "\n",
+ "total_steps = EPOCHS * len(dataloader) // GRAD_ACCUM\n",
+ "warmup_steps = min(500, total_steps // 10)\n",
+ "\n",
+ "def cosine_lr(step):\n",
+ "    if step < warmup_steps:\n",
+ "        return step / max(1, warmup_steps)\n",
+ "    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n",
+ "    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))\n",
+ "\n",
+ "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_lr)\n",
+ "criterion = PhysicsInformedFlowLoss(\n",
+ "    lambda_smooth=LAMBDA_SMOOTH, lambda_tv=LAMBDA_TV\n",
+ ").to(device)\n",
+ "ema = EMAModel(model, decay=0.9999)\n",
+ "scaler = GradScaler(enabled=USE_AMP)\n",
+ "\n",
+ "# Training log\n",
+ "all_losses = []\n",
+ "global_step = 0\n",
+ "\n",
+ "print(f'🚀 Training {EPOCHS} epochs, {total_steps} steps')\n",
+ "print(f'   Effective batch: {BATCH_SIZE} × {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}')\n",
+ "print(f'   LR: {LEARNING_RATE} → warmup {warmup_steps} steps → cosine decay')\n",
+ "print()\n",
+ "\n",
+ "t_start = time.time()\n",
+ "\n",
+ "for epoch in range(EPOCHS):\n",
+ "    model.train()\n",
+ "    epoch_loss = 0.0\n",
+ "    epoch_flow = 0.0\n",
+ "    n_batches = 0\n",
+ "\n",
+ "    for batch_idx, batch_data in enumerate(dataloader):\n",
+ "        if isinstance(batch_data, (list, tuple)):\n",
+ "            x1 = batch_data[0].to(device)  # (image, label) datasets\n",
+ "        else:\n",
+ "            x1 = batch_data.to(device)\n",
+ "\n",
+ "        B = x1.shape[0]\n",
+ "        x0 = torch.randn_like(x1)\n",
+ "        t = torch.rand(B, device=device)\n",
+ "        t_e = t.view(B, 1, 1, 1)\n",
+ "        x_t = t_e * x1 + (1 - t_e) * x0\n",
+ "\n",
+ "        with autocast(enabled=USE_AMP):\n",
+ "            v_pred = model(x_t, t)\n",
+ "            loss, ld = criterion(v_pred, x0, x1, t, step=global_step)\n",
+ "            loss = loss / GRAD_ACCUM\n",
+ "\n",
+ "        scaler.scale(loss).backward()\n",
+ "\n",
+ "        if (batch_idx + 1) % GRAD_ACCUM == 0:\n",
+ "            scaler.unscale_(optimizer)\n",
+ "            gn = nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
+ "            scaler.step(optimizer)\n",
+ "            scaler.update()\n",
+ "            optimizer.zero_grad()\n",
+ "            scheduler.step()\n",
+ "            ema.update(model)\n",
+ "            global_step += 1\n",
+ "\n",
+ "        epoch_loss += ld['total'].item()\n",
+ "        epoch_flow += ld['flow'].item()\n",
+ "        n_batches += 1\n",
+ "\n",
+ "        # Log only right after an optimizer step so gn and lr are fresh\n",
+ "        # (and defined even when GRAD_ACCUM > 1)\n",
+ "        if global_step % LOG_EVERY == 0 and (batch_idx + 1) % GRAD_ACCUM == 0:\n",
+ "            avg = epoch_loss / n_batches\n",
+ "            avg_f = epoch_flow / n_batches\n",
+ "            lr_now = scheduler.get_last_lr()[0]\n",
+ "            elapsed = time.time() - t_start\n",
+ "            it_s = global_step / elapsed\n",
+ "            all_losses.append({'step': global_step, 'loss': avg, 'flow': avg_f,\n",
+ "                               'lr': lr_now, 'epoch': epoch})\n",
+ "            print(f'  E{epoch+1} step {global_step}/{total_steps} | '\n",
+ "                  f'loss={avg:.4f} flow={avg_f:.4f} lr={lr_now:.2e} '\n",
+ "                  f'gn={gn:.2f} [{it_s:.1f} it/s]')\n",
+ "\n",
+ "    # End of epoch\n",
+ "    avg_epoch = epoch_loss / max(1, n_batches)\n",
+ "    print(f'\\n📊 Epoch {epoch+1}/{EPOCHS} — avg loss: {avg_epoch:.4f}\\n')\n",
+ "\n",
+ "    # Sample\n",
+ "    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:\n",
+ "        model.eval()\n",
+ "        ema.apply_shadow(model)\n",
+ "        with torch.no_grad():\n",
+ "            n_samples = min(16, BATCH_SIZE)\n",
+ "            imgs = euler_sample(model, (n_samples, 3, IMG_SIZE, IMG_SIZE),\n",
+ "                                num_steps=SAMPLE_STEPS, device=device)\n",
+ "        imgs = imgs.clamp(-1, 1) * 0.5 + 0.5\n",
+ "        grid = make_grid_image(imgs, nrow=4)\n",
+ "        grid.save(f'{OUTPUT_DIR}/samples/epoch_{epoch+1:04d}.png')\n",
+ "\n",
+ "        # Display inline\n",
+ "        fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n",
+ "        ax.imshow(grid)\n",
+ "        ax.set_title(f'Epoch {epoch+1} — {MODEL_SIZE}-{IMG_SIZE} on {DATASET}')\n",
+ "        ax.axis('off')\n",
+ "        plt.tight_layout()\n",
+ "        plt.show()\n",
+ "\n",
+ "        ema.restore(model)\n",
+ "        model.train()\n",
+ "\n",
+ "    # Checkpoint\n",
+ "    if (epoch + 1) % SAVE_EVERY == 0:\n",
+ "        ckpt = {\n",
+ "            'model': model.state_dict(),\n",
+ "            'optimizer': optimizer.state_dict(),\n",
+ "            'scheduler': scheduler.state_dict(),\n",
+ "            'ema': ema.state_dict(),\n",
+ "            'epoch': epoch,\n",
+ "            'global_step': global_step,\n",
+ "        }\n",
+ "        torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/epoch_{epoch+1:04d}.pt')\n",
+ "        torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/latest.pt')\n",
+ "        print(f'💾 Checkpoint saved: epoch {epoch+1}')\n",
+ "\n",
+ "# Save final EMA weights\n",
+ "ema.apply_shadow(model)\n",
+ "torch.save({'model': model.state_dict(), 'config': {\n",
+ "    'model_size': MODEL_SIZE, 'img_size': IMG_SIZE, 'dataset': DATASET,\n",
+ "    'num_params': num_params, 'epochs': EPOCHS,\n",
+ "}}, f'{OUTPUT_DIR}/liquidflow_final.pt')\n",
+ "ema.restore(model)\n",
+ "\n",
+ "elapsed = time.time() - t_start\n",
+ "print(f'\\n✅ Training complete! {elapsed/60:.1f} min total')\n",
+ "print(f'   Final model: {OUTPUT_DIR}/liquidflow_final.pt')"
+ ]
+ },
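+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If a Colab/Kaggle session dies mid-run, you can resume from `checkpoints/latest.pt`.\n",
+ "A minimal sketch: it restores exactly the keys saved in the checkpoint dict above, and\n",
+ "assumes `EMAModel` exposes a `load_state_dict` mirroring the `state_dict()` used when\n",
+ "saving. Re-run the earlier cells first so `model`, `optimizer`, `scheduler`, and `ema` exist:\n",
+ "\n",
+ "```python\n",
+ "ckpt = torch.load(f'{OUTPUT_DIR}/checkpoints/latest.pt', map_location=device, weights_only=False)\n",
+ "model.load_state_dict(ckpt['model'])\n",
+ "optimizer.load_state_dict(ckpt['optimizer'])\n",
+ "scheduler.load_state_dict(ckpt['scheduler'])\n",
+ "ema.load_state_dict(ckpt['ema'])   # assumption: EMAModel mirrors its state_dict()\n",
+ "global_step = ckpt['global_step']\n",
+ "start_epoch = ckpt['epoch'] + 1    # then: for epoch in range(start_epoch, EPOCHS)\n",
+ "```"
+ ]
+ },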
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## 5. 📈 Training Curves"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "if all_losses:\n",
+ "    steps = [d['step'] for d in all_losses]\n",
+ "    losses = [d['loss'] for d in all_losses]\n",
+ "    flows = [d['flow'] for d in all_losses]\n",
+ "    lrs = [d['lr'] for d in all_losses]\n",
+ "\n",
+ "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
+ "\n",
+ "    ax1.plot(steps, losses, label='Total Loss', alpha=0.8)\n",
+ "    ax1.plot(steps, flows, label='Flow Loss', alpha=0.8)\n",
+ "    ax1.set_xlabel('Step'); ax1.set_ylabel('Loss')\n",
+ "    ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True, alpha=0.3)\n",
+ "\n",
+ "    ax2.plot(steps, lrs, color='orange')\n",
+ "    ax2.set_xlabel('Step'); ax2.set_ylabel('LR')\n",
+ "    ax2.set_title('Learning Rate Schedule'); ax2.grid(True, alpha=0.3)\n",
+ "\n",
+ "    plt.tight_layout()\n",
+ "    plt.savefig(f'{OUTPUT_DIR}/training_curves.png', dpi=150)\n",
+ "    plt.show()\n",
+ "else:\n",
+ "    print('No training logs yet.')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## 6. 🎨 Generate Images"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#@title 🎨 Generation Settings { display-mode: \"form\" }\n",
+ "NUM_IMAGES = 16 #@param {type:\"integer\"}\n",
+ "GEN_STEPS = 50 #@param [10, 25, 50, 100, 200] {type:\"integer\"}\n",
+ "SAMPLER = 'euler' #@param ['euler', 'heun']\n",
+ "SEED = 42 #@param {type:\"integer\"}\n",
+ "\n",
+ "import torch\n",
+ "from liquidflow.sampling import euler_sample, heun_sample, make_grid_image\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "# Load the final (EMA) model\n",
+ "ckpt_path = f'{OUTPUT_DIR}/liquidflow_final.pt'\n",
+ "if os.path.exists(ckpt_path):\n",
+ "    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)\n",
+ "    model.load_state_dict(ckpt['model'])\n",
+ "    print(f'Loaded: {ckpt_path}')\n",
+ "else:\n",
+ "    print('No checkpoint found; using current model weights')\n",
+ "\n",
+ "model.eval()\n",
+ "torch.manual_seed(SEED)\n",
+ "\n",
+ "shape = (NUM_IMAGES, 3, IMG_SIZE, IMG_SIZE)\n",
+ "\n",
+ "with torch.no_grad():\n",
+ "    if SAMPLER == 'euler':\n",
+ "        images = euler_sample(model, shape, num_steps=GEN_STEPS, device=device)\n",
+ "    else:\n",
+ "        images = heun_sample(model, shape, num_steps=GEN_STEPS, device=device)\n",
+ "\n",
+ "images = images.clamp(-1, 1) * 0.5 + 0.5\n",
+ "grid = make_grid_image(images, nrow=int(NUM_IMAGES**0.5))\n",
+ "grid.save(f'{OUTPUT_DIR}/generated_final.png')\n",
+ "\n",
+ "plt.figure(figsize=(10, 10))\n",
+ "plt.imshow(grid)\n",
+ "plt.title(f'LiquidFlow-{MODEL_SIZE} | {DATASET} {IMG_SIZE}×{IMG_SIZE} | {GEN_STEPS} steps ({SAMPLER})')\n",
+ "plt.axis('off')\n",
+ "plt.show()"
+ ]
+ },
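+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For reference, here is roughly what the two samplers do: both integrate the learned\n",
+ "ODE `dx/dt = v_θ(x, t)` from noise at `t=0` to an image at `t=1`. This is a generic\n",
+ "sketch, not the repo's `liquidflow/sampling.py` implementation:\n",
+ "\n",
+ "```python\n",
+ "@torch.no_grad()\n",
+ "def euler_sketch(model, shape, num_steps, device):\n",
+ "    x = torch.randn(shape, device=device)\n",
+ "    dt = 1.0 / num_steps\n",
+ "    for i in range(num_steps):\n",
+ "        t = torch.full((shape[0],), i * dt, device=device)\n",
+ "        x = x + dt * model(x, t)             # 1st order: one model eval per step\n",
+ "    return x\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def heun_sketch(model, shape, num_steps, device):\n",
+ "    x = torch.randn(shape, device=device)\n",
+ "    dt = 1.0 / num_steps\n",
+ "    for i in range(num_steps):\n",
+ "        t = torch.full((shape[0],), i * dt, device=device)\n",
+ "        v1 = model(x, t)                     # predictor\n",
+ "        v2 = model(x + dt * v1, t + dt)      # corrector eval at the trial point\n",
+ "        x = x + dt * 0.5 * (v1 + v2)         # 2nd order: two model evals per step\n",
+ "    return x\n",
+ "```\n",
+ "\n",
+ "This is why `heun` at 25 steps costs roughly the same compute as `euler` at 50."
+ ]
+ },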
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## 7. 📱 Export for Mobile (ONNX + TorchScript)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Export to TorchScript for mobile deployment\n",
+ "model.eval()\n",
+ "\n",
+ "# TorchScript (for PyTorch Mobile / ExecuTorch)\n",
+ "example_x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)\n",
+ "example_t = torch.tensor([0.5], device=device)\n",
+ "\n",
+ "try:\n",
+ "    traced = torch.jit.trace(model, (example_x, example_t))\n",
+ "    ts_path = f'{OUTPUT_DIR}/liquidflow_mobile.pt'\n",
+ "    traced.save(ts_path)\n",
+ "    ts_size_mb = os.path.getsize(ts_path) / 1e6\n",
+ "    print(f'✅ TorchScript saved: {ts_path} ({ts_size_mb:.1f} MB)')\n",
+ "except Exception as e:\n",
+ "    print(f'⚠️ TorchScript export failed: {e}')\n",
+ "\n",
+ "# ONNX\n",
+ "try:\n",
+ "    onnx_path = f'{OUTPUT_DIR}/liquidflow.onnx'\n",
+ "    torch.onnx.export(\n",
+ "        model.cpu(), (example_x.cpu(), example_t.cpu()),\n",
+ "        onnx_path, opset_version=14,\n",
+ "        input_names=['image', 'timestep'],\n",
+ "        output_names=['velocity'],\n",
+ "        dynamic_axes={'image': {0: 'batch'}, 'timestep': {0: 'batch'}, 'velocity': {0: 'batch'}}\n",
+ "    )\n",
+ "    onnx_size_mb = os.path.getsize(onnx_path) / 1e6\n",
+ "    print(f'✅ ONNX saved: {onnx_path} ({onnx_size_mb:.1f} MB)')\n",
+ "    model.to(device)\n",
+ "except Exception as e:\n",
+ "    print(f'⚠️ ONNX export failed: {e}')\n",
+ "    model.to(device)"
+ ]
+ },
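+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To sanity-check the ONNX export, you can run one velocity evaluation with\n",
+ "`onnxruntime` (not installed above; run `pip install onnxruntime` first). A sketch,\n",
+ "using the `image`/`timestep`/`velocity` names declared in the export call:\n",
+ "\n",
+ "```python\n",
+ "import numpy as np\n",
+ "import onnxruntime as ort\n",
+ "\n",
+ "sess = ort.InferenceSession(f'{OUTPUT_DIR}/liquidflow.onnx')\n",
+ "x = np.random.randn(1, 3, IMG_SIZE, IMG_SIZE).astype(np.float32)\n",
+ "t = np.array([0.5], dtype=np.float32)\n",
+ "(v,) = sess.run(['velocity'], {'image': x, 'timestep': t})\n",
+ "print(v.shape)  # should match the image shape\n",
+ "```"
+ ]
+ },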
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## 8. 🔬 Architecture Deep Dive\n",
+ "\n",
+ "### How LiquidFlow works\n",
+ "\n",
+ "```\n",
+ "Noise x₀ ~ N(0,I) ──→ LiquidFlow v_θ(xₜ, t) ──→ Image x₁\n",
+ "                              │\n",
+ "                       ┌──────┴──────┐\n",
+ "                       │ Patchify    │  (img → non-overlapping patches)\n",
+ "                       │ + PosEmb    │  (2D learnable positions)\n",
+ "                       │ + DepthConv │  (local structure)\n",
+ "                       └──────┬──────┘\n",
+ "                              │\n",
+ "                 ┌────────────┼────────────┐\n",
+ "                 │  L × LiquidSSM Block    │\n",
+ "                 │  ┌──────────────────┐   │\n",
+ "                 │  │ AdaLN (t-cond)   │   │\n",
+ "                 │  │ Zigzag Scan      │   │  ← rotates scan pattern per layer\n",
+ "                 │  │ SelectiveSSM     │   │  ← Mamba-style, input-dependent\n",
+ "                 │  │ + LiquidCfC      │   │  ← CfC gating, bounded dynamics\n",
+ "                 │  │ + FFN            │   │\n",
+ "                 │  │ + Skip Connect   │   │  ← U-Net style long skips\n",
+ "                 │  └──────────────────┘   │\n",
+ "                 └────────────┼────────────┘\n",
+ "                              │\n",
+ "                       ┌──────┴──────┐\n",
+ "                       │ DepthConv   │\n",
+ "                       │ Unpatchify  │  (patches → img)\n",
+ "                       └──────┬──────┘\n",
+ "                              │\n",
+ "                         velocity v_θ\n",
+ "```\n",
+ "\n",
+ "### Key Innovations\n",
+ "\n",
+ "1. **Liquid CfC Cell**: Instead of solving the ODE `dx/dt = f(x,t)` numerically, we use the\n",
+ "   closed-form solution `x(t+Δt) = σ(-f_τ) ⊙ x(t) + (1 - σ(-f_τ)) ⊙ f_x`.\n",
+ "   The sigmoid gate makes each update a convex blend, so hidden states stay\n",
+ "   **bounded by construction** and cannot blow up during training (sketched in the next cell).\n",
+ "\n",
+ "2. **SSM + Liquid dual path**: The SSM branch captures long-range spatial dependencies\n",
+ "   via selective scanning; the Liquid branch adds continuous-time adaptive dynamics.\n",
+ "   A learnable mixing coefficient balances them.\n",
+ "\n",
+ "3. **Physics-informed loss**: Smoothness (Laplacian) and Total Variation regularizers\n",
+ "   act as soft PDE constraints on generated images, improving training stability\n",
+ "   and reducing artifacts without domain-specific physics knowledge (also sketched below).\n",
+ "\n",
+ "4. **Flow Matching = Liquid ODE**: Rectified flow trains `v_θ` to follow straight paths\n",
+ "   from noise to data. This is structurally identical to the LTC ODE, making Liquid\n",
+ "   networks a natural fit as the velocity field parameterization."
+ ]
+ },
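+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Two of the pieces above are easy to sketch in isolation. First, the CfC closed-form\n",
+ "update from item 1. Here `f_tau` and `f_x` stand in for the cell's two learned\n",
+ "projections (hypothetical names; the real cell lives in `liquidflow/model.py`):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "def cfc_step(x, f_tau, f_x):\n",
+ "    gate = torch.sigmoid(-f_tau)         # in (0, 1): the update is a convex blend\n",
+ "    return gate * x + (1 - gate) * f_x   # σ(-f_τ) ⊙ x(t) + (1 - σ(-f_τ)) ⊙ f_x\n",
+ "\n",
+ "x = 1e3 * torch.randn(4, 8)              # even an extreme state...\n",
+ "out = cfc_step(x, torch.randn(4, 8), torch.tanh(torch.randn(4, 8)))\n",
+ "print(out.abs().max() <= x.abs().max())  # ...cannot grow: bounded by its inputs\n",
+ "```\n",
+ "\n",
+ "Second, the two soft constraints from item 3, written as generic image penalties.\n",
+ "The repo's `PhysicsInformedFlowLoss` is assumed to implement something similar\n",
+ "(plus its automatic warmup):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "def tv_penalty(img):                     # total variation: mean absolute neighbor diffs\n",
+ "    dh = (img[..., 1:, :] - img[..., :-1, :]).abs().mean()\n",
+ "    dw = (img[..., :, 1:] - img[..., :, :-1]).abs().mean()\n",
+ "    return dh + dw\n",
+ "\n",
+ "def smooth_penalty(img):                 # Laplacian smoothness via a fixed depthwise kernel\n",
+ "    k = torch.tensor([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]])\n",
+ "    k = k.view(1, 1, 3, 3).repeat(img.shape[1], 1, 1, 1).to(img)\n",
+ "    lap = F.conv2d(img, k, groups=img.shape[1])\n",
+ "    return lap.pow(2).mean()\n",
+ "```"
+ ]
+ },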
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "## 9. 🧪 Recommended Experiments\n",
+ "\n",
+ "| Experiment | Dataset | Model | IMG_SIZE | Epochs | Notes |\n",
+ "|------------|---------|-------|----------|--------|-------|\n",
+ "| Quick sanity check | CIFAR-10 | tiny | 32 | 20 | ~5 min on T4 |\n",
+ "| Baseline 128×128 | CIFAR-10 | tiny | 128 | 100 | ~2 hrs on T4 |\n",
+ "| Quality 128×128 | Flowers-102 | small | 128 | 200 | ~4 hrs on T4 |\n",
+ "| Faces 128×128 | CelebA | small | 128 | 50 | ~6 hrs on T4 |\n",
+ "| High-res 512×512 | CelebA | 512 | 512 | 100 | needs ≥16 GB |\n",
+ "| Production | Your data | small | 128 | 300+ | best quality |\n",
+ "\n",
+ "### Tips for best results:\n",
+ "- Start with `tiny` and a few epochs to verify everything works\n",
+ "- Use `small` for 128×128 production quality\n",
+ "- Increase `SAMPLE_STEPS` to 100+ for the final generation\n",
+ "- `heun` is 2nd-order (two model evaluations per step), so it typically matches `euler` quality with about half the steps\n",
+ "- Physics loss warmup is automatic — don't raise λ_smooth/λ_tv too much"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }