Upload LiquidFlow_Colab.ipynb

LiquidFlow_Colab.ipynb · CHANGED · +74 −243
@@ -12,13 +12,7 @@
  "- **CfC (Closed-form Continuous-time)** Liquid Neural Networks — adaptive time gates\n",
  "- **Mamba-2 SSD** — linear-time attention replacement, fully parallelizable\n",
  "- **Physics-Informed Regularization** — TV loss, spectral constraints\n",
- "- **TAESD VAE** — Tiny AutoEncoder (<
- "\n",
- "Based on:\n",
- "- CfC: Hasani et al., Nature MI 2022\n",
- "- Mamba-2: Dao & Gu, 2024 \n",
- "- PINN Diffusion: Bastek & Sun, ICLR 2025\n",
- "- DiMSUM: NeurIPS 2024\n",
+ "- **TAESD VAE** — Tiny AutoEncoder (< 3M params) for fast encoding\n",
  "\n",
  "---\n",
  "## Quick Start\n",
@@ -33,9 +27,8 @@
  "execution_count": null,
  "metadata": {},
  "source": [
- "# @title 1. Install Dependencies\n",
- "!pip install -q torch torchvision diffusers tqdm pillow numpy\n",
- "!pip install -q git+https://github.com/huggingface/diffusers.git\n",
+ "# @title 1. Install Dependencies\n",
+ "!pip install -q torch torchvision diffusers tqdm pillow numpy accelerate\n",
  "\n",
  "import torch\n",
  "print(f\"PyTorch: {torch.__version__}\")\n",
@@ -52,7 +45,11 @@
  "metadata": {},
  "source": [
  "# @title 2. Clone LiquidFlow Repository\n",
- "
+ "import os\n",
+ "if not os.path.exists('/content/LiquidFlow'):\n",
+ " !git clone https://huggingface.co/krystv/LiquidFlow-Gen /content/LiquidFlow\n",
+ "else:\n",
+ " !cd /content/LiquidFlow && git pull\n",
  "%cd /content/LiquidFlow\n",
  "\n",
  "import sys\n",
@@ -60,8 +57,7 @@
  "\n",
  "from liquid_flow.generator import create_liquidflow\n",
  "from liquid_flow.vae_wrapper import TAESDWrapper\n",
- "
- "import torch.nn.functional as F"
+ "print('LiquidFlow imported successfully!')"
  ],
  "outputs": []
  },
@@ -70,33 +66,35 @@
  "execution_count": null,
  "metadata": {},
  "source": [
- "# @title 3. Configuration\n",
+ "# @title 3. Configuration\n",
  "\n",
- "# Model
+ "# Model variant: 'tiny' (~3.6M), 'small' (~11M), 'base' (~36M)\n",
  "MODEL_VARIANT = 'small' # @param ['tiny', 'small', 'base']\n",
  "\n",
- "# Image size
- "IMAGE_SIZE = 128 # @param [64, 128, 256, 512]\n",
+ "# Image size (128 recommended for T4 free tier)\n",
+ "IMAGE_SIZE = 128 # @param [64, 128, 256, 512] {type:\"integer\"}\n",
  "\n",
- "# Training\n",
- "BATCH_SIZE = 32 # @param [8, 16, 32, 64]\n",
- "EPOCHS = 50 # @param [10, 25, 50, 100]\n",
- "LEARNING_RATE = 2e-4 # @param
+ "# Training hyperparameters\n",
+ "BATCH_SIZE = 32 # @param [8, 16, 32, 64] {type:\"integer\"}\n",
+ "EPOCHS = 50 # @param [10, 25, 50, 100] {type:\"integer\"}\n",
+ "LEARNING_RATE = 2e-4 # @param {type:\"number\"}\n",
  "\n",
  "# Dataset\n",
  "DATASET = 'cifar10' # @param ['cifar10', 'cifar100', 'stl10']\n",
  "\n",
- "# Sampling
- "SAMPLE_EVERY = 5 # @param
- "SAMPLE_STEPS = 50 # @param
+ "# Sampling\n",
+ "SAMPLE_EVERY = 5 # @param {type:\"integer\"}\n",
+ "SAMPLE_STEPS = 50 # @param {type:\"integer\"}\n",
  "\n",
- "#
- "
- "
- "
+ "# Ensure integer types (Colab forms can return strings)\n",
+ "IMAGE_SIZE = int(IMAGE_SIZE)\n",
+ "BATCH_SIZE = int(BATCH_SIZE)\n",
+ "EPOCHS = int(EPOCHS)\n",
+ "SAMPLE_EVERY = int(SAMPLE_EVERY)\n",
+ "SAMPLE_STEPS = int(SAMPLE_STEPS)\n",
+ "LEARNING_RATE = float(LEARNING_RATE)\n",
  "\n",
- "print(f\"Config: {MODEL_VARIANT} model, {IMAGE_SIZE}px, batch={BATCH_SIZE}, epochs={EPOCHS}, lr={LEARNING_RATE}\")\n",
- "print(f\"Physics loss: TV={PHYSICS_TV_WEIGHT}, Spec={PHYSICS_SPEC_WEIGHT}, Grad={PHYSICS_GRAD_WEIGHT}\")"
+ "print(f\"Config: {MODEL_VARIANT} model, {IMAGE_SIZE}px, batch={BATCH_SIZE}, epochs={EPOCHS}, lr={LEARNING_RATE}\")"
  ],
  "outputs": []
  },
@@ -117,31 +115,17 @@
  "# Load TAESD (Tiny AutoEncoder)\n",
  "print(\"Loading TAESD VAE...\")\n",
  "vae = TAESDWrapper.load(device)\n",
- "
+ "latent_size = IMAGE_SIZE // 8\n",
+ "print(f\"VAE loaded! Latent: {IMAGE_SIZE}x{IMAGE_SIZE} -> {latent_size}x{latent_size}x4\")\n",
  "\n",
  "# Create LiquidFlow model\n",
- "print(f\"
- "model = create_liquidflow(\n",
- " variant=MODEL_VARIANT,\n",
- " image_size=IMAGE_SIZE,\n",
- " physics_weights={\n",
- " 'tv': PHYSICS_TV_WEIGHT,\n",
- " 'cons': 0.001,\n",
- " 'spec': PHYSICS_SPEC_WEIGHT,\n",
- " 'grad': PHYSICS_GRAD_WEIGHT,\n",
- " },\n",
- ")\n",
+ "print(f\"\\nCreating '{MODEL_VARIANT}' LiquidFlow model...\")\n",
+ "model = create_liquidflow(variant=MODEL_VARIANT, image_size=IMAGE_SIZE)\n",
  "model = model.to(device)\n",
  "\n",
- "n_params =
- "print(f\"Model: {n_params:,}
- "\
- "# Memory estimate\n",
- "latent_size = IMAGE_SIZE // 8\n",
- "mem_per_sample = latent_size * latent_size * 4 * 4 / 1e6 # MB\n",
- "print(f\"Memory per sample: {mem_per_sample:.1f} MB\")\n",
- "print(f\"Estimated batch memory: {mem_per_sample * BATCH_SIZE:.1f} MB\")\n",
- "print(f\"T4 VRAM: 15 GB — should fit!\" if mem_per_sample * BATCH_SIZE < 10 else \"Watch memory!\")"
+ "n_params = model.count_parameters()\n",
+ "print(f\"Model parameters: {n_params:,} ({n_params/1e6:.1f}M)\")\n",
+ "print(f\"\\nReady to train!\")"
  ],
  "outputs": []
  },
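Note: the `TAESDWrapper` used above is defined inside the repo and its body never appears in this diff. For orientation only, a minimal sketch of what such a wrapper could look like on top of the `diffusers` class `AutoencoderTiny` (the `madebyollin/taesd` checkpoint is the public TAESD release; the repo's actual wrapper may differ):

```python
# Hypothetical stand-in for liquid_flow.vae_wrapper.TAESDWrapper -- a sketch,
# not the repo's implementation.
import torch
from diffusers import AutoencoderTiny

class TAESDWrapperSketch:
    @staticmethod
    def load(device):
        # TAESD: a distilled, SD-latent-compatible tiny autoencoder
        return AutoencoderTiny.from_pretrained("madebyollin/taesd").to(device).eval()

    @staticmethod
    def encode(vae, images):
        # images: (B, 3, H, W) in [-1, 1] -> latents: (B, 4, H/8, W/8)
        return vae.encode(images).latents

    @staticmethod
    def decode(vae, latents):
        return vae.decode(latents).sample
```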
@@ -154,8 +138,9 @@
  "\n",
  "transform = transforms.Compose([\n",
  " transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),\n",
+ " transforms.RandomHorizontalFlip(),\n",
  " transforms.ToTensor(),\n",
- " transforms.Normalize([0.5], [0.5]),\n",
+ " transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),\n",
  "])\n",
  "\n",
  "if DATASET == 'cifar10':\n",
@@ -164,17 +149,14 @@
  " dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)\n",
  "elif DATASET == 'stl10':\n",
  " dataset = datasets.STL10(root='./data', split='train', download=True, transform=transform)\n",
- "else:\n",
- " raise ValueError(f\"Unknown dataset: {DATASET}\")\n",
  "\n",
  "dataloader = DataLoader(\n",
  " dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
- " num_workers=
- " pin_memory=True, drop_last=True,\n",
+ " num_workers=2, pin_memory=True, drop_last=True,\n",
  ")\n",
  "\n",
- "print(f\"Dataset: {DATASET}\")\n",
- "print(f\"
+ "print(f\"Dataset: {DATASET} ({len(dataset):,} images)\")\n",
+ "print(f\"Batches per epoch: {len(dataloader)}\")"
  ],
  "outputs": []
  },
@@ -183,109 +165,64 @@
  "execution_count": null,
  "metadata": {},
  "source": [
- "# @title 6.
- "\n",
+ "# @title 6. Train!\n",
  "from torchvision.utils import save_image\n",
- "import math\n",
  "\n",
  "os.makedirs('./outputs/samples', exist_ok=True)\n",
  "os.makedirs('./outputs/checkpoints', exist_ok=True)\n",
  "\n",
- "
- "
- " model.parameters(),\n",
- " lr=LEARNING_RATE,\n",
- " betas=(0.9, 0.999),\n",
- " weight_decay=1e-4,\n",
- ")\n",
- "\n",
- "# Cosine LR scheduler\n",
- "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
- " optimizer, T_max=EPOCHS * len(dataloader)\n",
- ")\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)\n",
+ "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS * len(dataloader))\n",
  "\n",
- "# AMP\n",
  "use_amp = device.type == 'cuda'\n",
  "scaler = torch.cuda.amp.GradScaler() if use_amp else None\n",
  "\n",
- "print(f\"Training: {EPOCHS} epochs,
- "print(
+ "print(f\"Training: {EPOCHS} epochs, AMP={use_amp}\")\n",
+ "print('=' * 60)\n",
  "\n",
- "global_step = 0\n",
  "best_loss = float('inf')\n",
  "\n",
  "for epoch in range(EPOCHS):\n",
  " model.train()\n",
- "
- " \n",
- " pbar = tqdm(dataloader, desc=f\"Epoch {epoch+1}/{EPOCHS}\")\n",
+ " epoch_loss = 0\n",
+ " pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{EPOCHS}')\n",
  " \n",
  " for images, _ in pbar:\n",
  " images = images.to(device)\n",
  " \n",
- " # Encode to latent\n",
+ " # Encode to latent space\n",
  " with torch.no_grad():\n",
  " latents = TAESDWrapper.encode(vae, images)\n",
  " \n",
- " # Training step\n",
+ " # Training step\n",
  " loss_dict = model.training_step(latents, optimizer, scaler, use_amp)\n",
- " \n",
- " # Track\n",
- " total_loss = loss_dict['total']\n",
- " epoch_total += total_loss\n",
- " \n",
- " # Update scheduler\n",
  " scheduler.step()\n",
  " \n",
- "
- " pbar.set_postfix({\n",
- " 'loss': f\"{total_loss:.4f}\",\n",
- " 'diff': f\"{loss_dict.get('diffusion', 0):.4f}\",\n",
- " 'phys': f\"{loss_dict.get('physics', 0):.4f}\",\n",
- " 'lr': f\"{optimizer.param_groups[0]['lr']:.2e}\",\n",
- " })\n",
- " \n",
- " global_step += 1\n",
+ " epoch_loss += loss_dict['total']\n",
+ " pbar.set_postfix(loss=f\"{loss_dict['total']:.4f}\", diff=f\"{loss_dict['diffusion']:.4f}\")\n",
  " \n",
- "
- " print(f
+ " avg = epoch_loss / len(dataloader)\n",
+ " print(f'Epoch {epoch+1}: loss={avg:.4f}')\n",
  " \n",
  " # Generate samples\n",
  " if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == EPOCHS - 1:\n",
- " print(\" Generating samples...\")\n",
  " model.eval()\n",
  " with torch.no_grad():\n",
- "
- "
- "
- "
- " progress=False,\n",
- " )\n",
- " images_gen = TAESDWrapper.decode(vae, latents_gen)\n",
- " \n",
- " save_image(\n",
- " images_gen, f'./outputs/samples/epoch_{epoch+1:03d}.png',\n",
- " nrow=4, normalize=True, value_range=(-1, 1)\n",
- " )\n",
- " print(f\" Saved to ./outputs/samples/epoch_{epoch+1:03d}.png\")\n",
- " \n",
- " # Save checkpoint\n",
- " if (epoch + 1) % 10 == 0 or epoch == EPOCHS - 1:\n",
- " torch.save({\n",
- " 'epoch': epoch + 1,\n",
- " 'model_state_dict': model.state_dict(),\n",
- " 'optimizer_state_dict': optimizer.state_dict(),\n",
- " 'loss': avg_loss,\n",
- " }, f'./outputs/checkpoints/epoch_{epoch+1:03d}.pt')\n",
+ " z = model.sample(batch_size=16, steps=SAMPLE_STEPS, ddim=True, progress=False)\n",
+ " imgs = TAESDWrapper.decode(vae, z)\n",
+ " save_image(imgs, f'./outputs/samples/epoch_{epoch+1:03d}.png', nrow=4, normalize=True, value_range=(-1,1))\n",
+ " print(f' Samples saved: ./outputs/samples/epoch_{epoch+1:03d}.png')\n",
  " \n",
- "
- "
- "
- "\n",
- "
- "
- "
- "
+ " # Checkpoint\n",
+ " if avg < best_loss:\n",
+ " best_loss = avg\n",
+ " torch.save(model.state_dict(), './outputs/checkpoints/best.pt')\n",
+ " if (epoch+1) % 10 == 0:\n",
+ " torch.save({'epoch': epoch+1, 'model': model.state_dict(), 'opt': optimizer.state_dict()},\n",
+ " f'./outputs/checkpoints/epoch_{epoch+1:03d}.pt')\n",
+ "\n",
+ "print('=' * 60)\n",
+ "print(f'Done! Best loss: {best_loss:.4f}')"
  ],
  "outputs": []
  },
@@ -294,134 +231,28 @@
  "execution_count": null,
  "metadata": {},
  "source": [
- "# @title 7.
- "\n",
+ "# @title 7. Display Generated Samples\n",
  "import matplotlib.pyplot as plt\n",
  "from PIL import Image\n",
  "import glob\n",
  "\n",
- "# Load latest sample\n",
  "sample_files = sorted(glob.glob('./outputs/samples/epoch_*.png'))\n",
  "if sample_files:\n",
- "
- "
- " plt.figure(figsize=(12, 12))\n",
+ " img = Image.open(sample_files[-1])\n",
+ " plt.figure(figsize=(10, 10))\n",
  " plt.imshow(img)\n",
- " plt.title(f'LiquidFlow
+ " plt.title(f'LiquidFlow — {MODEL_VARIANT}, {IMAGE_SIZE}px (latest)')\n",
  " plt.axis('off')\n",
  " plt.show()\n",
  "else:\n",
- " print(
+ " print('No samples yet — run training first!')"
  ],
  "outputs": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "source": [
- "# @title 8. Export Model for Mobile (ONNX)\n",
- "\n",
- "# LiquidFlow can be exported to ONNX for mobile deployment\n",
- "# since it uses pure PyTorch (no custom CUDA kernels)\n",
- "\n",
- "def export_to_onnx(model, output_path='liquidflow_model.onnx', image_size=128):\n",
- " \"\"\"Export LiquidFlow to ONNX for mobile deployment.\"\"\"\n",
- " model = model.cpu()\n",
- " model.eval()\n",
- " \n",
- " latent_size = image_size // 8\n",
- " \n",
- " # Dummy inputs\n",
- " x = torch.randn(1, 4, latent_size, latent_size)\n",
- " t = torch.tensor([500], dtype=torch.long)\n",
- " \n",
- " # Export\n",
- " torch.onnx.export(\n",
- " model,\n",
- " (x, t),\n",
- " output_path,\n",
- " input_names=['noisy_latent', 'timestep'],\n",
- " output_names=['predicted_noise'],\n",
- " dynamic_axes={\n",
- " 'noisy_latent': {0: 'batch'},\n",
- " 'predicted_noise': {0: 'batch'},\n",
- " },\n",
- " opset_version=14,\n",
- " )\n",
- " \n",
- " import os\n",
- " size_mb = os.path.getsize(output_path) / 1e6\n",
- " print(f\"ONNX model exported to {output_path} ({size_mb:.1f} MB)\")\n",
- " return output_path\n",
- "\n",
- "# Load best model and export\n",
- "best_model_path = './outputs/checkpoints/best_model.pt'\n",
- "if os.path.exists(best_model_path):\n",
- " model.load_state_dict(torch.load(best_model_path, map_location='cpu'))\n",
- " export_to_onnx(model, 'liquidflow_128.onnx', IMAGE_SIZE)\n",
- " print(\"Ready for mobile deployment!\")\n",
- "else:\n",
- " print(\"Train model first before exporting.\")"
- ],
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Architecture Details\n",
- "\n",
- "### LiquidFlow Block Architecture\n",
- "```\n",
- "Input → [CfC Gate → Mamba-2 SSD → CfC Gate] → Output\n",
- "         ↑                        ↑\n",
- "   Adaptive time gate        Gated output\n",
- "```\n",
- "\n",
- "### CfC (Closed-form Continuous-time) Cell\n",
- "```\n",
- "h(t) = σ(-f(x,I;θ_f)·t) ⊙ g(x,I;θ_g) + (1-σ(-f(x,I;θ_f)·t)) ⊙ h(x,I;θ_h)\n",
- "```\n",
- "- **No ODE solver needed** — 100x faster than Neural ODEs\n",
- "- Time-continuous gating adaptively controls information flow\n",
- "- Closed-form solution → stable gradients\n",
- "\n",
- "### Mamba-2 SSD (State Space Duality)\n",
- "```\n",
- "h_t = A_t * h_{t-1} + B_t * x_t\n",
- "y_t = C_t^T * h_t\n",
- "```\n",
- "- **O(N) linear complexity** vs Transformers O(N²)\n",
- "- **Parallelizable** via associative scan (Blelloch)\n",
- "- **Scalar-A** formulation enables chunk-scan optimization\n",
- "- Pure PyTorch — no CUDA kernels needed\n",
- "\n",
- "### Physics-Informed Regularization\n",
- "- **Total Variation**: `L_TV = ||∇_x x̂||₁ + ||∇_y x̂||₁`\n",
- "- **Spectral**: Penalize high-frequency artifacts\n",
- "- **Gradient**: Sobolev norm for stable training\n",
- "- Pattern from Bastek & Sun (ICLR 2025): physics loss as training-only regularizer\n",
- "\n",
- "### Model Variants\n",
- "| Variant | Params | Hidden Dim | Stages | Blocks | T4 VRAM |\n",
- "|---------|--------|------------|--------|--------|---------|\n",
- "| Tiny | ~2M | 128 | 2 | 2 | < 2 GB |\n",
- "| Small | ~8M | 256 | 4 | 4 | ~4 GB |\n",
- "| Base | ~30M | 384 | 6 | 6 | ~8 GB |"
- ]
  }
  ],
  "metadata": {
- "colab": {
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- }
+ "colab": {"name": "LiquidFlow_Train", "provenance": []},
+ "kernelspec": {"display_name": "Python 3", "name": "python3"}
  },
  "nbformat": 4,
  "nbformat_minor": 0
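The removed cell 8 exported the denoiser with input names `noisy_latent`/`timestep` and output `predicted_noise`. A quick sanity check of such an export with `onnxruntime` (an extra dependency the notebook does not install) could look like:

```python
# Hedged sketch: verify the exported graph with onnxruntime. Input/output
# names follow the removed export cell; file name assumes IMAGE_SIZE = 128.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("liquidflow_128.onnx", providers=["CPUExecutionProvider"])
x = np.random.randn(1, 4, 16, 16).astype(np.float32)  # 128px image -> 16x16 latents
t = np.array([500], dtype=np.int64)                    # arbitrary diffusion timestep
(eps,) = sess.run(["predicted_noise"], {"noisy_latent": x, "timestep": t})
print(eps.shape)  # expected: (1, 4, 16, 16)
```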
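The CfC cell equation in the removed architecture notes above maps onto a handful of linear layers. A toy sketch of that gate in plain PyTorch, assuming per-branch linear heads and `tanh` nonlinearities (the repo's actual CfC block is not shown in this diff):

```python
# Sketch of h(t) = σ(-f(x)·t) ⊙ g(x) + (1 - σ(-f(x)·t)) ⊙ h(x).
# A toy reading of the equation above, not the repo's module.
import torch
import torch.nn as nn

class CfCGateSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.f = nn.Linear(dim, dim)  # learned time constants
        self.g = nn.Linear(dim, dim)  # branch weighted more at small t (for f(x) > 0)
        self.h = nn.Linear(dim, dim)  # branch weighted more at large t

    def forward(self, x, t):
        # The sigmoid gate interpolates the two branches in closed form over
        # continuous time t > 0, so no ODE solver is ever invoked.
        gate = torch.sigmoid(-self.f(x) * t)
        return gate * torch.tanh(self.g(x)) + (1 - gate) * torch.tanh(self.h(x))

# e.g. CfCGateSketch(256)(torch.randn(8, 256), t=0.5)
```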
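Likewise, the SSD recurrence from the removed notes is easiest to sanity-check in its naive sequential form. The sketch below is only the O(N) reference loop; Mamba-2's contribution is computing the same outputs with a parallel chunked (associative) scan:

```python
# Naive reference for h_t = A_t·h_{t-1} + B_t·x_t, y_t = C_tᵀ·h_t.
# Sequential illustration only; the real layer uses a parallel scan.
import torch

def ssd_scan_reference(A, B, C, x):
    # A: (T,) scalar decay per step (the "scalar-A" formulation)
    # B, C: (T, d_state) input/output projections; x: (T,) input sequence
    T, d_state = B.shape
    h = x.new_zeros(d_state)
    ys = []
    for t in range(T):
        h = A[t] * h + B[t] * x[t]     # state update
        ys.append(torch.dot(C[t], h))  # readout
    return torch.stack(ys)

# e.g. ssd_scan_reference(torch.rand(10), torch.randn(10, 16),
#                         torch.randn(10, 16), torch.randn(10))
```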
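Finally, the physics-informed terms in the removed notes are standard image regularizers. Below is the total-variation loss exactly as written there, plus one plausible reading of the spectral penalty; the repo's exact weighting, reduction, and frequency cutoff are not visible in this diff:

```python
# Anisotropic L1 total variation, L_TV = ||∇_x x̂||₁ + ||∇_y x̂||₁, plus a
# simple high-frequency (spectral) penalty. Generic forms, not the repo's code.
import torch

def tv_loss(x):
    # x: (B, C, H, W); finite differences along width and height
    dx = (x[..., :, 1:] - x[..., :, :-1]).abs().mean()
    dy = (x[..., 1:, :] - x[..., :-1, :]).abs().mean()
    return dx + dy

def spectral_loss(x, cutoff=0.5):
    # Penalize spectral energy above a radial cutoff (fraction of the maximum
    # frequency); the cutoff value here is an assumption.
    fy = torch.fft.fftfreq(x.shape[-2], device=x.device).abs()
    fx = torch.fft.rfftfreq(x.shape[-1], device=x.device).abs()
    radius = torch.sqrt(fy[:, None] ** 2 + fx[None, :] ** 2)
    spec = torch.fft.rfft2(x).abs()
    return spec[..., radius > cutoff * radius.max()].mean()
```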