Add Colab/Kaggle training notebook with all dataset options
LiquidFlow_Training.ipynb  ADDED  +747 -0
@@ -0,0 +1,747 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🌊 LiquidFlow — Liquid-SSM Flow Matching Image Generator\n",
    "\n",
    "A **novel architecture** combining:\n",
    "- **Liquid Time-Constant Networks** (CfC closed-form) → adaptive ODE dynamics, bounded by construction\n",
    "- **Selective State Space Models** (Mamba-style) → linear-time long-range context, parallelizable\n",
    "- **Zigzag Scanning** → 2D spatial awareness for image patches\n",
    "- **Physics-Informed Regularization** → smoothness + total variation constraints\n",
    "- **Rectified Flow Matching** → ODE-based generation, no noise schedule tuning (sketched in the cell right after this intro)\n",
    "\n",
    "### 📋 What this notebook does\n",
    "1. **Install & clone** the LiquidFlow codebase\n",
    "2. **Choose a dataset** (CIFAR-10, Flowers-102, CelebA, or custom folder)\n",
    "3. **Choose a model size** (tiny ~6M, small ~14M, base ~38M)\n",
    "4. **Train** with one click — all Colab/Kaggle optimized\n",
    "5. **Generate images** and visualize progress\n",
    "6. **Export** trained model for mobile deployment\n",
    "\n",
    "### 💻 Hardware Requirements\n",
    "| Config | GPU VRAM | Best For |\n",
    "|--------|----------|----------|\n",
    "| tiny-128 (bs=32) | ~4 GB | Colab free T4, Kaggle |\n",
    "| small-128 (bs=16) | ~8 GB | Colab free T4, Kaggle |\n",
    "| base-256 (bs=8) | ~12 GB | Colab Pro, Kaggle |\n",
    "| 512 (bs=4) | ~14 GB | Colab Pro, A100 |"
   ]
  },
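  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*(Aside, added for clarity and not part of the original notebook: a minimal sketch of the rectified-flow objective referenced above. The training cell in section 4 builds exactly this pairing inline; the helper name here is illustrative.)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def flow_matching_pair(x1):\n",
    "    \"\"\"One rectified-flow training example from a data batch x1 in [-1, 1].\"\"\"\n",
    "    x0 = torch.randn_like(x1)             # noise endpoint\n",
    "    t = torch.rand(x1.shape[0], device=x1.device)\n",
    "    t_e = t.view(-1, 1, 1, 1)\n",
    "    x_t = t_e * x1 + (1 - t_e) * x0       # straight-line interpolant\n",
    "    v_target = x1 - x0                    # constant velocity along the path\n",
    "    return x_t, t, v_target\n",
    "\n",
    "# Usage: v_pred = model(x_t, t); loss = ((v_pred - v_target) ** 2).mean()"
   ]
  },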
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 0. Setup & Install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check GPU\n",
    "!nvidia-smi || echo 'No GPU — CPU training only (very slow)'\n",
    "import torch\n",
    "print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')\n",
    "if torch.cuda.is_available():\n",
    "    print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install dependencies\n",
    "!pip install -q torch torchvision einops pillow matplotlib tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clone the repo (or just copy the files if you already have them)\n",
    "import os\n",
    "if not os.path.exists('liquidflow'):\n",
    "    !git clone https://huggingface.co/krystv/LiquidFlow liquidflow_repo\n",
    "    !cp -r liquidflow_repo/liquidflow .\n",
    "else:\n",
    "    print('liquidflow/ already exists')\n",
    "\n",
    "# Verify\n",
    "from liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n",
    "from liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\n",
    "from liquidflow.sampling import euler_sample, heun_sample, generate_grid, make_grid_image\n",
    "print('✅ LiquidFlow imported successfully!')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 1. ⚙️ Configuration — EDIT THIS CELL\n",
    "\n",
    "Choose your dataset, model size, and training hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title 🎛️ Training Configuration { display-mode: \"form\" }\n",
    "\n",
    "# ============== DATASET ==============\n",
    "#@markdown ### Dataset\n",
    "DATASET = 'cifar10' #@param ['cifar10', 'flowers', 'celeba', 'folder', 'fashion_mnist', 'afhq', 'lsun_churches']\n",
    "CUSTOM_DATA_DIR = '/content/my_images' #@param {type:\"string\"}\n",
    "#@markdown > For 'folder': put images in CUSTOM_DATA_DIR. Supports .png/.jpg/.webp\n",
    "\n",
    "# ============== MODEL ==============\n",
    "#@markdown ### Model\n",
    "MODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', '512']\n",
    "IMG_SIZE = 128 #@param [32, 64, 128, 256, 512] {type:\"integer\"}\n",
    "\n",
    "# ============== TRAINING ==============\n",
    "#@markdown ### Training\n",
    "EPOCHS = 100 #@param {type:\"integer\"}\n",
    "BATCH_SIZE = 32 #@param [4, 8, 16, 32, 64, 128] {type:\"integer\"}\n",
    "LEARNING_RATE = 3e-4 #@param {type:\"number\"}\n",
    "GRAD_ACCUM = 1 #@param [1, 2, 4, 8] {type:\"integer\"}\n",
    "USE_AMP = True #@param {type:\"boolean\"}\n",
    "\n",
    "# ============== PHYSICS LOSS ==============\n",
    "#@markdown ### Physics-Informed Regularization\n",
    "LAMBDA_SMOOTH = 0.01 #@param {type:\"number\"}\n",
    "LAMBDA_TV = 0.001 #@param {type:\"number\"}\n",
    "\n",
    "# ============== SAMPLING ==============\n",
    "#@markdown ### Sampling & Logging\n",
    "SAMPLE_EVERY = 5 #@param {type:\"integer\"}\n",
    "SAMPLE_STEPS = 50 #@param [10, 25, 50, 100] {type:\"integer\"}\n",
    "LOG_EVERY = 50 #@param {type:\"integer\"}\n",
    "SAVE_EVERY = 10 #@param {type:\"integer\"}\n",
    "\n",
    "# ============== PATHS ==============\n",
    "OUTPUT_DIR = './outputs'\n",
    "DATA_DIR = './data'\n",
    "\n",
    "# ============== AUTO-CONFIG ==============\n",
    "# Smart batch size based on GPU memory\n",
    "import torch\n",
    "if torch.cuda.is_available():\n",
    "    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
    "    print(f'GPU VRAM: {vram_gb:.1f} GB')\n",
    "\n",
    "    # Auto-adjust batch size if needed\n",
    "    recommended = {\n",
    "        (32, 'tiny'): 128, (64, 'tiny'): 64, (128, 'tiny'): 32,\n",
    "        (32, 'small'): 64, (64, 'small'): 32, (128, 'small'): 16,\n",
    "        (256, 'base'): 8, (512, '512'): 4,\n",
    "    }\n",
    "    key = (IMG_SIZE, MODEL_SIZE)\n",
    "    if key in recommended and vram_gb < 16:\n",
    "        rec_bs = recommended[key]\n",
    "        if BATCH_SIZE > rec_bs:\n",
    "            print(f'⚠️ Reducing batch size {BATCH_SIZE} → {rec_bs} for {vram_gb:.0f}GB VRAM')\n",
    "            BATCH_SIZE = rec_bs\n",
    "else:\n",
    "    print('⚠️ No GPU detected — training will be very slow!')\n",
    "    USE_AMP = False\n",
    "\n",
    "print(f'\\n📋 Config: {MODEL_SIZE}-{IMG_SIZE}, {DATASET}, bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={EPOCHS}')\n",
    "print(f'   Physics: λ_smooth={LAMBDA_SMOOTH}, λ_tv={LAMBDA_TV}')\n",
    "print(f'   AMP: {USE_AMP}, GradAccum: {GRAD_ACCUM}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 2. 📦 Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader, Dataset, ConcatDataset\n",
    "from pathlib import Path\n",
    "from PIL import Image\n",
    "import os\n",
    "\n",
    "# Standard transform\n",
    "def get_transform(img_size):\n",
    "    return transforms.Compose([\n",
    "        transforms.Resize(img_size + img_size // 8),\n",
    "        transforms.CenterCrop(img_size),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize([0.5]*3, [0.5]*3),\n",
    "    ])\n",
    "\n",
    "class ImageFolderFlat(Dataset):\n",
    "    \"\"\"Load all images from a folder (recursively).\"\"\"\n",
    "    def __init__(self, root, transform):\n",
    "        self.transform = transform\n",
    "        self.files = []\n",
    "        for ext in ['*.png', '*.jpg', '*.jpeg', '*.webp', '*.bmp']:\n",
    "            self.files.extend(Path(root).rglob(ext))\n",
    "        self.files = sorted(self.files)\n",
    "        print(f'Found {len(self.files)} images in {root}')\n",
    "    def __len__(self): return len(self.files)\n",
    "    def __getitem__(self, idx):\n",
    "        return self.transform(Image.open(self.files[idx]).convert('RGB'))\n",
    "\n",
    "class GrayscaleToRGB:\n",
    "    \"\"\"Convert 1-channel grayscale to 3-channel RGB.\"\"\"\n",
    "    def __call__(self, x):\n",
    "        if x.shape[0] == 1:\n",
    "            x = x.repeat(3, 1, 1)\n",
    "        return x\n",
    "\n",
    "tfm = get_transform(IMG_SIZE)\n",
    "\n",
    "if DATASET == 'cifar10':\n",
    "    dataset = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True, transform=tfm)\n",
    "    print(f'✅ CIFAR-10: {len(dataset)} images')\n",
    "\n",
    "elif DATASET == 'flowers':\n",
    "    ds_train = torchvision.datasets.Flowers102(root=DATA_DIR, split='train', download=True, transform=tfm)\n",
    "    ds_val = torchvision.datasets.Flowers102(root=DATA_DIR, split='val', download=True, transform=tfm)\n",
    "    ds_test = torchvision.datasets.Flowers102(root=DATA_DIR, split='test', download=True, transform=tfm)\n",
    "    dataset = ConcatDataset([ds_train, ds_val, ds_test])  # Use all splits for generation\n",
    "    print(f'✅ Flowers-102: {len(dataset)} images (all splits)')\n",
    "\n",
    "elif DATASET == 'celeba':\n",
    "    dataset = torchvision.datasets.CelebA(root=DATA_DIR, split='train', download=True, transform=tfm)\n",
    "    print(f'✅ CelebA: {len(dataset)} images')\n",
    "\n",
    "elif DATASET == 'fashion_mnist':\n",
    "    fm_tfm = transforms.Compose([\n",
    "        transforms.Resize(IMG_SIZE),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize([0.5], [0.5]),\n",
    "        GrayscaleToRGB(),\n",
    "    ])\n",
    "    dataset = torchvision.datasets.FashionMNIST(root=DATA_DIR, train=True, download=True, transform=fm_tfm)\n",
    "    print(f'✅ Fashion-MNIST: {len(dataset)} images (converted to RGB)')\n",
    "\n",
    "elif DATASET == 'afhq':\n",
    "    # Download AFHQ from Kaggle or manually\n",
    "    afhq_dir = os.path.join(DATA_DIR, 'afhq', 'train')\n",
    "    if not os.path.exists(afhq_dir):\n",
    "        print('⬇️ Downloading AFHQ...')\n",
    "        !pip install -q gdown\n",
    "        !gdown 1Gof5BaELXlmSJIlvKMYCe9ONYPebkNsf -O {DATA_DIR}/afhq.zip\n",
    "        !unzip -q {DATA_DIR}/afhq.zip -d {DATA_DIR}/afhq\n",
    "    dataset = ImageFolderFlat(afhq_dir, tfm)\n",
    "    print(f'✅ AFHQ: {len(dataset)} images')\n",
    "\n",
    "elif DATASET == 'lsun_churches':\n",
    "    # LSUN requires manual download — point to the extracted folder\n",
    "    lsun_dir = os.path.join(DATA_DIR, 'lsun_churches')\n",
    "    if not os.path.exists(lsun_dir):\n",
    "        print('❌ LSUN churches not found. Please download and extract to', lsun_dir)\n",
    "        print('   See: https://github.com/fyu/lsun')\n",
    "        raise FileNotFoundError(lsun_dir)\n",
    "    dataset = ImageFolderFlat(lsun_dir, tfm)\n",
    "    print(f'✅ LSUN Churches: {len(dataset)} images')\n",
    "\n",
    "elif DATASET == 'folder':\n",
    "    dataset = ImageFolderFlat(CUSTOM_DATA_DIR, tfm)\n",
    "    print(f'✅ Custom folder: {len(dataset)} images from {CUSTOM_DATA_DIR}')\n",
    "\n",
    "else:\n",
    "    raise ValueError(f'Unknown dataset: {DATASET}')\n",
    "\n",
    "# Show a few samples\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "fig, axes = plt.subplots(1, 8, figsize=(16, 2))\n",
    "for i, ax in enumerate(axes):\n",
    "    sample = dataset[i]\n",
    "    if isinstance(sample, (list, tuple)):\n",
    "        sample = sample[0]\n",
    "    img = sample * 0.5 + 0.5  # denormalize\n",
    "    ax.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())\n",
    "    ax.axis('off')\n",
    "plt.suptitle(f'{DATASET} samples ({IMG_SIZE}×{IMG_SIZE})', fontsize=14)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 3. 🏗️ Build Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from liquidflow.model import liquidflow_tiny, liquidflow_small, liquidflow_base, liquidflow_512\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "model_factories = {\n",
    "    'tiny': liquidflow_tiny,\n",
    "    'small': liquidflow_small,\n",
    "    'base': liquidflow_base,\n",
    "    '512': liquidflow_512,\n",
    "}\n",
    "\n",
    "model = model_factories[MODEL_SIZE](img_size=IMG_SIZE).to(device)\n",
    "\n",
    "num_params = model.count_params()\n",
    "print(f'🏗️ LiquidFlow-{MODEL_SIZE}')\n",
    "print(f'   Parameters: {num_params:,} ({num_params/1e6:.1f}M)')\n",
    "print(f'   Image size: {IMG_SIZE}×{IMG_SIZE}')\n",
    "print(f'   Patch size: {model.patch_size}')\n",
    "print(f'   Num patches: {model.num_patches}')\n",
    "print(f'   Model dim: {model.d_model}')\n",
    "print(f'   Depth: {model.depth}')\n",
    "print(f'   Device: {device}')\n",
    "\n",
    "# Quick forward-pass test\n",
    "with torch.no_grad():\n",
    "    test_x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)\n",
    "    test_t = torch.tensor([0.5], device=device)\n",
    "    test_v = model(test_x, test_t)\n",
    "    assert test_v.shape == test_x.shape\n",
    "    print(f'   ✅ Forward pass OK: {test_x.shape} → {test_v.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 4. 🚀 Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import time\n",
    "import json\n",
    "import torch.nn as nn\n",
    "from torch.cuda.amp import autocast, GradScaler\n",
    "from liquidflow.losses import PhysicsInformedFlowLoss, EMAModel\n",
    "from liquidflow.sampling import euler_sample, make_grid_image\n",
    "from IPython.display import display, clear_output\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Prepare\n",
    "os.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\n",
    "os.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n",
    "\n",
    "dataloader = DataLoader(\n",
    "    dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
    "    num_workers=2, pin_memory=True, drop_last=True\n",
    ")\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,\n",
    "                              betas=(0.9, 0.999), weight_decay=0.01)\n",
    "\n",
    "total_steps = EPOCHS * len(dataloader) // GRAD_ACCUM\n",
    "warmup_steps = min(500, total_steps // 10)\n",
    "\n",
    "def cosine_lr(step):\n",
    "    if step < warmup_steps:\n",
    "        return step / max(1, warmup_steps)\n",
    "    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)\n",
    "    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))\n",
    "\n",
    "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_lr)\n",
    "criterion = PhysicsInformedFlowLoss(\n",
    "    lambda_smooth=LAMBDA_SMOOTH, lambda_tv=LAMBDA_TV\n",
    ").to(device)\n",
    "ema = EMAModel(model, decay=0.9999)\n",
    "scaler = GradScaler(enabled=USE_AMP)\n",
    "\n",
    "# Training log\n",
    "all_losses = []\n",
    "global_step = 0\n",
    "gn = 0.0  # last gradient norm (initialized so the first log line is well-defined)\n",
    "\n",
    "print(f'🚀 Training {EPOCHS} epochs, {total_steps} steps')\n",
    "print(f'   Effective batch: {BATCH_SIZE} × {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}')\n",
    "print(f'   LR: {LEARNING_RATE} → warmup {warmup_steps} steps → cosine decay')\n",
    "print()\n",
    "\n",
    "t_start = time.time()\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    model.train()\n",
    "    epoch_loss = 0.0\n",
    "    epoch_flow = 0.0\n",
    "    n_batches = 0\n",
    "\n",
    "    for batch_idx, batch_data in enumerate(dataloader):\n",
    "        if isinstance(batch_data, (list, tuple)):\n",
    "            x1 = batch_data[0].to(device)\n",
    "        else:\n",
    "            x1 = batch_data.to(device)\n",
    "\n",
    "        B = x1.shape[0]\n",
    "        x0 = torch.randn_like(x1)\n",
    "        t = torch.rand(B, device=device)\n",
    "        t_e = t.view(B, 1, 1, 1)\n",
    "        x_t = t_e * x1 + (1 - t_e) * x0\n",
    "\n",
    "        with autocast(enabled=USE_AMP):\n",
    "            v_pred = model(x_t, t)\n",
    "            loss, ld = criterion(v_pred, x0, x1, t, step=global_step)\n",
    "            loss = loss / GRAD_ACCUM\n",
    "\n",
    "        scaler.scale(loss).backward()\n",
    "\n",
    "        if (batch_idx + 1) % GRAD_ACCUM == 0:\n",
    "            scaler.unscale_(optimizer)\n",
    "            gn = nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "            scaler.step(optimizer)\n",
    "            scaler.update()\n",
    "            optimizer.zero_grad()\n",
    "            scheduler.step()\n",
    "            ema.update(model)\n",
    "            global_step += 1\n",
    "\n",
    "        epoch_loss += ld['total'].item()\n",
    "        epoch_flow += ld['flow'].item()\n",
    "        n_batches += 1\n",
    "\n",
    "        if global_step % LOG_EVERY == 0:\n",
    "            avg = epoch_loss / n_batches\n",
    "            avg_f = epoch_flow / n_batches\n",
    "            lr_now = scheduler.get_last_lr()[0]\n",
    "            elapsed = time.time() - t_start\n",
    "            it_s = global_step / elapsed\n",
    "            all_losses.append({'step': global_step, 'loss': avg, 'flow': avg_f,\n",
    "                               'lr': lr_now, 'epoch': epoch})\n",
    "            print(f'  E{epoch+1} step {global_step}/{total_steps} | '\n",
    "                  f'loss={avg:.4f} flow={avg_f:.4f} lr={lr_now:.2e} '\n",
    "                  f'gn={gn:.2f} [{it_s:.1f} it/s]')\n",
    "\n",
    "    # End of epoch\n",
    "    avg_epoch = epoch_loss / max(1, n_batches)\n",
    "    print(f'\\n📊 Epoch {epoch+1}/{EPOCHS} — avg loss: {avg_epoch:.4f}\\n')\n",
    "\n",
    "    # Sample\n",
    "    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:\n",
    "        model.eval()\n",
    "        ema.apply_shadow(model)\n",
    "        with torch.no_grad():\n",
    "            n_samples = min(16, BATCH_SIZE)\n",
    "            imgs = euler_sample(model, (n_samples, 3, IMG_SIZE, IMG_SIZE),\n",
    "                                num_steps=SAMPLE_STEPS, device=device)\n",
    "            imgs = imgs.clamp(-1, 1) * 0.5 + 0.5\n",
    "            grid = make_grid_image(imgs, nrow=4)\n",
    "            grid.save(f'{OUTPUT_DIR}/samples/epoch_{epoch+1:04d}.png')\n",
    "\n",
    "            # Display inline\n",
    "            fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n",
    "            ax.imshow(grid)\n",
    "            ax.set_title(f'Epoch {epoch+1} — {MODEL_SIZE}-{IMG_SIZE} on {DATASET}')\n",
    "            ax.axis('off')\n",
    "            plt.tight_layout()\n",
    "            plt.show()\n",
    "\n",
    "        ema.restore(model)\n",
    "        model.train()\n",
    "\n",
    "    # Checkpoint\n",
    "    if (epoch + 1) % SAVE_EVERY == 0:\n",
    "        ckpt = {\n",
    "            'model': model.state_dict(),\n",
    "            'optimizer': optimizer.state_dict(),\n",
    "            'scheduler': scheduler.state_dict(),\n",
    "            'ema': ema.state_dict(),\n",
    "            'epoch': epoch,\n",
    "            'global_step': global_step,\n",
    "        }\n",
    "        torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/epoch_{epoch+1:04d}.pt')\n",
    "        torch.save(ckpt, f'{OUTPUT_DIR}/checkpoints/latest.pt')\n",
    "        print(f'💾 Checkpoint saved: epoch {epoch+1}')\n",
    "\n",
    "# Save final EMA weights\n",
    "ema.apply_shadow(model)\n",
    "torch.save({'model': model.state_dict(), 'config': {\n",
    "    'model_size': MODEL_SIZE, 'img_size': IMG_SIZE, 'dataset': DATASET,\n",
    "    'num_params': num_params, 'epochs': EPOCHS,\n",
    "}}, f'{OUTPUT_DIR}/liquidflow_final.pt')\n",
    "ema.restore(model)\n",
    "\n",
    "elapsed = time.time() - t_start\n",
    "print(f'\\n✅ Training complete! {elapsed/60:.1f} min total')\n",
    "print(f'   Final model: {OUTPUT_DIR}/liquidflow_final.pt')"
   ]
  },
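  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*(Aside, added for clarity and not part of the original notebook: `EMAModel` above comes from `liquidflow.losses`. This is a generic sketch of the shadow-weight pattern that its `update` / `apply_shadow` / `restore` calls suggest; the repo's actual implementation may differ.)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import torch\n",
    "\n",
    "class SimpleEMA:\n",
    "    \"\"\"Exponential moving average of model weights (illustrative sketch).\"\"\"\n",
    "    def __init__(self, model, decay=0.9999):\n",
    "        self.decay = decay\n",
    "        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}\n",
    "        self.backup = None\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def update(self, model):\n",
    "        for k, v in model.state_dict().items():\n",
    "            if v.dtype.is_floating_point:\n",
    "                self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay)\n",
    "            else:\n",
    "                self.shadow[k].copy_(v)  # non-float buffers are copied as-is\n",
    "\n",
    "    def apply_shadow(self, model):\n",
    "        self.backup = copy.deepcopy(model.state_dict())  # stash training weights\n",
    "        model.load_state_dict(self.shadow)\n",
    "\n",
    "    def restore(self, model):\n",
    "        model.load_state_dict(self.backup)"
   ]
  },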
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 5. 📈 Training Curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "if all_losses:\n",
    "    steps = [d['step'] for d in all_losses]\n",
    "    losses = [d['loss'] for d in all_losses]\n",
    "    flows = [d['flow'] for d in all_losses]\n",
    "    lrs = [d['lr'] for d in all_losses]\n",
    "\n",
    "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
    "\n",
    "    ax1.plot(steps, losses, label='Total Loss', alpha=0.8)\n",
    "    ax1.plot(steps, flows, label='Flow Loss', alpha=0.8)\n",
    "    ax1.set_xlabel('Step'); ax1.set_ylabel('Loss')\n",
    "    ax1.set_title('Training Loss'); ax1.legend(); ax1.grid(True, alpha=0.3)\n",
    "\n",
    "    ax2.plot(steps, lrs, color='orange')\n",
    "    ax2.set_xlabel('Step'); ax2.set_ylabel('LR')\n",
    "    ax2.set_title('Learning Rate Schedule'); ax2.grid(True, alpha=0.3)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f'{OUTPUT_DIR}/training_curves.png', dpi=150)\n",
    "    plt.show()\n",
    "else:\n",
    "    print('No training logs yet.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 6. 🎨 Generate Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title 🎨 Generation Settings { display-mode: \"form\" }\n",
    "NUM_IMAGES = 16 #@param {type:\"integer\"}\n",
    "GEN_STEPS = 50 #@param [10, 25, 50, 100, 200] {type:\"integer\"}\n",
    "SAMPLER = 'euler' #@param ['euler', 'heun']\n",
    "SEED = 42 #@param {type:\"integer\"}\n",
    "\n",
    "import torch\n",
    "from liquidflow.sampling import euler_sample, heun_sample, make_grid_image\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load the best model\n",
    "ckpt_path = f'{OUTPUT_DIR}/liquidflow_final.pt'\n",
    "if os.path.exists(ckpt_path):\n",
    "    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)\n",
    "    model.load_state_dict(ckpt['model'])\n",
    "    print(f'Loaded: {ckpt_path}')\n",
    "else:\n",
    "    print('No checkpoint found, using current model weights')\n",
    "\n",
    "model.eval()\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "shape = (NUM_IMAGES, 3, IMG_SIZE, IMG_SIZE)\n",
    "\n",
    "with torch.no_grad():\n",
    "    if SAMPLER == 'euler':\n",
    "        images = euler_sample(model, shape, num_steps=GEN_STEPS, device=device)\n",
    "    else:\n",
    "        images = heun_sample(model, shape, num_steps=GEN_STEPS, device=device)\n",
    "\n",
    "images = images.clamp(-1, 1) * 0.5 + 0.5\n",
    "grid = make_grid_image(images, nrow=int(NUM_IMAGES**0.5))\n",
    "grid.save(f'{OUTPUT_DIR}/generated_final.png')\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "plt.imshow(grid)\n",
    "plt.title(f'LiquidFlow-{MODEL_SIZE} | {DATASET} {IMG_SIZE}×{IMG_SIZE} | {GEN_STEPS} steps ({SAMPLER})')\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 7. 📱 Export for Mobile (ONNX + TorchScript)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export to TorchScript for mobile deployment\n",
    "model.eval()\n",
    "\n",
    "# TorchScript (for PyTorch Mobile / ExecuTorch)\n",
    "example_x = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)\n",
    "example_t = torch.tensor([0.5], device=device)\n",
    "\n",
    "try:\n",
    "    traced = torch.jit.trace(model, (example_x, example_t))\n",
    "    ts_path = f'{OUTPUT_DIR}/liquidflow_mobile.pt'\n",
    "    traced.save(ts_path)\n",
    "    ts_size_mb = os.path.getsize(ts_path) / 1e6\n",
    "    print(f'✅ TorchScript saved: {ts_path} ({ts_size_mb:.1f} MB)')\n",
    "except Exception as e:\n",
    "    print(f'⚠️ TorchScript export failed: {e}')\n",
    "\n",
    "# ONNX\n",
    "try:\n",
    "    onnx_path = f'{OUTPUT_DIR}/liquidflow.onnx'\n",
    "    torch.onnx.export(\n",
    "        model.cpu(), (example_x.cpu(), example_t.cpu()),\n",
    "        onnx_path, opset_version=14,\n",
    "        input_names=['image', 'timestep'],\n",
    "        output_names=['velocity'],\n",
    "        dynamic_axes={'image': {0: 'batch'}, 'timestep': {0: 'batch'}, 'velocity': {0: 'batch'}}\n",
    "    )\n",
    "    onnx_size_mb = os.path.getsize(onnx_path) / 1e6\n",
    "    print(f'✅ ONNX saved: {onnx_path} ({onnx_size_mb:.1f} MB)')\n",
    "    model.to(device)\n",
    "except Exception as e:\n",
    "    print(f'⚠️ ONNX export failed: {e}')\n",
    "    model.to(device)"
   ]
  },
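  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*(Aside, added for clarity and not part of the original notebook: a quick sanity check of the exported ONNX graph, assuming `onnxruntime` is installed. It only confirms the file loads and produces a velocity tensor of the expected shape.)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install -q onnxruntime\n",
    "import numpy as np\n",
    "import onnxruntime as ort\n",
    "\n",
    "sess = ort.InferenceSession(f'{OUTPUT_DIR}/liquidflow.onnx', providers=['CPUExecutionProvider'])\n",
    "x = np.random.randn(1, 3, IMG_SIZE, IMG_SIZE).astype(np.float32)\n",
    "t = np.array([0.5], dtype=np.float32)\n",
    "(v,) = sess.run(['velocity'], {'image': x, 'timestep': t})\n",
    "print('ONNX velocity output:', v.shape)  # should match the input image shape"
   ]
  },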
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 8. 🔬 Architecture Deep Dive\n",
    "\n",
    "### How LiquidFlow works\n",
    "\n",
    "```\n",
    "Noise x₀ ~ N(0,I) ──▶ LiquidFlow v_θ(xₜ, t) ──▶ Image x₁\n",
    "                             │\n",
    "                      ┌──────┴──────┐\n",
    "                      │ Patchify    │  (img → non-overlapping patches)\n",
    "                      │ + PosEmb    │  (2D learnable positions)\n",
    "                      │ + DepthConv │  (local structure)\n",
    "                      └──────┬──────┘\n",
    "                             │\n",
    "               ┌─────────────┼─────────────┐\n",
    "               │   L × LiquidSSM Block     │\n",
    "               │  ┌────────────────────┐   │\n",
    "               │  │ AdaLN (t-cond)     │   │\n",
    "               │  │ Zigzag Scan        │   │ ← rotates scan pattern per layer\n",
    "               │  │ SelectiveSSM       │   │ ← Mamba-style, input-dependent\n",
    "               │  │ + LiquidCfC        │   │ ← CfC gating, bounded dynamics\n",
    "               │  │ + FFN              │   │\n",
    "               │  │ + Skip Connect     │   │ ← U-Net style long skips\n",
    "               │  └────────────────────┘   │\n",
    "               └─────────────┼─────────────┘\n",
    "                             │\n",
    "                      ┌──────┴──────┐\n",
    "                      │ DepthConv   │\n",
    "                      │ Unpatchify  │  (patches → img)\n",
    "                      └──────┬──────┘\n",
    "                             │\n",
    "                        velocity v_θ\n",
    "```\n",
    "\n",
    "### Key Innovations\n",
    "\n",
    "1. **Liquid CfC Cell**: Instead of solving the ODE `dx/dt = f(x,t)` numerically, we use the\n",
    "   closed-form solution `x(t+Δt) = σ(-f_τ) ⊙ x(t) + (1 - σ(-f_τ)) ⊙ f_x`.\n",
    "   The sigmoid gating **guarantees bounded dynamics** — no training explosion possible.\n",
    "   (A standalone sketch of this update follows this cell.)\n",
    "\n",
    "2. **SSM + Liquid dual path**: The SSM branch captures long-range spatial dependencies\n",
    "   via selective scanning; the Liquid branch adds continuous-time adaptive dynamics.\n",
    "   A learnable mixing coefficient balances them.\n",
    "\n",
    "3. **Physics-informed loss**: Smoothness (Laplacian) and Total Variation regularizers\n",
    "   act as soft PDE constraints on generated images, improving training stability\n",
    "   and reducing artifacts without domain-specific physics knowledge.\n",
    "\n",
    "4. **Flow Matching = Liquid ODE**: Rectified flow trains `v_θ` to follow straight paths\n",
    "   from noise to data. This is structurally identical to the LTC ODE, making Liquid\n",
    "   networks a natural fit as the velocity field parameterization."
   ]
  },
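  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*(Aside, added for clarity and not part of the original notebook: innovation 1 in code. A minimal cell implementing the closed-form update `x(t+Δt) = σ(-f_τ) ⊙ x(t) + (1 - σ(-f_τ)) ⊙ f_x`; layer names are hypothetical, not the repo's `LiquidCfC`.)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class TinyCfCCell(nn.Module):\n",
    "    \"\"\"Closed-form continuous-time update: a sigmoid gate blends state and candidate.\"\"\"\n",
    "    def __init__(self, dim):\n",
    "        super().__init__()\n",
    "        self.f_tau = nn.Linear(dim, dim)  # time-constant head\n",
    "        self.f_x = nn.Linear(dim, dim)    # candidate head\n",
    "\n",
    "    def forward(self, x):\n",
    "        gate = torch.sigmoid(-self.f_tau(x))  # in (0, 1): bounded by construction\n",
    "        return gate * x + (1 - gate) * torch.tanh(self.f_x(x))\n",
    "\n",
    "h = TinyCfCCell(64)(torch.randn(2, 64))\n",
    "print(h.shape)  # the output is a convex blend of the state and a bounded candidate"
   ]
  },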
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## 9. 🧪 Recommended Experiments\n",
    "\n",
    "| Experiment | Dataset | Model | IMG_SIZE | Epochs | Notes |\n",
    "|------------|---------|-------|----------|--------|-------|\n",
    "| Quick sanity check | CIFAR-10 | tiny | 32 | 20 | ~5 min on T4 |\n",
    "| Baseline 128×128 | CIFAR-10 | tiny | 128 | 100 | ~2 hrs on T4 |\n",
    "| Quality 128×128 | Flowers-102 | small | 128 | 200 | ~4 hrs on T4 |\n",
    "| Faces 128×128 | CelebA | small | 128 | 50 | ~6 hrs on T4 |\n",
    "| High-res 512×512 | CelebA | 512 | 512 | 100 | needs ≥16GB |\n",
    "| Production | Your data | small | 128 | 300+ | best quality |\n",
    "\n",
    "### Tips for best results\n",
    "- Start with `tiny` + low epochs to verify everything works\n",
    "- Use `small` for 128×128 production quality\n",
    "- Increase `SAMPLE_STEPS` to 100+ for final generation\n",
    "- The `heun` sampler gives better quality at half the steps vs `euler` (sketched in the next cell)\n",
    "- Physics-loss warmup is automatic — don't increase λ too much"
   ]
  },
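  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*(Aside, added for clarity and not part of the original notebook: why `heun` can match `euler` at fewer steps. It adds a second velocity evaluation per step and averages the two slopes; a generic second-order ODE step is sketched below. `heun_sample` in `liquidflow.sampling` may differ in details.)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "@torch.no_grad()\n",
    "def heun_step(model, x, t, dt):\n",
    "    \"\"\"One Heun (2nd-order) step for dx/dt = v(x, t); t is a (B,) tensor.\"\"\"\n",
    "    v1 = model(x, t)\n",
    "    x_pred = x + dt * v1             # Euler predictor\n",
    "    v2 = model(x_pred, t + dt)       # slope at the predicted point\n",
    "    return x + dt * 0.5 * (v1 + v2)  # trapezoidal corrector"
   ]
  }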
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}