krystv committed on
Commit ce37111 · verified · Parent: 2ddc44d

Revert to last working version (verbose logs + VAE + simple MSE loss — before v3/v4 broke training)

Files changed (1): LiquidDiffusion_Training.ipynb (+1 -1)
LiquidDiffusion_Training.ipynb CHANGED
@@ -1 +1 @@
- {"nbformat": 4, "nbformat_minor": 0, "metadata": {"colab": {"provenance": [], "gpuType": "T4"}, "kernelspec": {"name": "python3", "display_name": "Python 3"}, "accelerator": "GPU"}, "cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# \ud83c\udf0a LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks\n", "\n", "A **novel image generation model** combining:\n", "- **Liquid Neural Networks** (CfC) for adaptive, time-aware processing\n", "- **Rectified Flow** for simple, stable training\n", "- **Pretrained SD-VAE** for efficient latent-space training\n", "- **Zero attention** \u2014 fully convolutional\n", "- **Fully parallelizable** \u2014 no sequential ODE loops\n", "\n", "**Repo**: [krystv/liquid-diffusion](https://huggingface.co/krystv/liquid-diffusion)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \u2699\ufe0f Configuration"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "#@title \u2699\ufe0f Training Configuration\n\n# === MODEL ===\nMODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', 'custom']\n# tiny = ~20M params \u2192 best for <50K images, fast on T4\n# small = ~62M params \u2192 best for 50K-200K images\n# base = ~140M params \u2192 best for 200K+ images, needs good GPU\nCUSTOM_CHANNELS = [48, 96, 192]\nCUSTOM_BLOCKS = [1, 2, 3]\nCUSTOM_T_DIM = 192\n\n# === TRAINING MODE ===\nTRAINING_MODE = 'latent' #@param ['latent', 'pixel']\n\n# === IMAGE RESOLUTION ===\nIMAGE_SIZE = 512 #@param [128, 256, 512] {type:\"integer\"}\n\n# === DATASET ===\nDATASET = 'huggan/wikiart' #@param ['huggan/wikiart', 'Dhiraj45/Animes', 'huggan/AFHQv2', 'nielsr/CelebA-faces', 'huggan/anime-faces', 'huggan/flowers-102-categories', 'reach-vb/pokemon-blip-captions', 'Norod78/cartoon-blip-captions']\n# huggan/wikiart \u2192 81K art/paintings/illustrations (RECOMMENDED)\n# Dhiraj45/Animes \u2192 83K anime scenes\n# huggan/AFHQv2 \u2192 16K animal faces (512px native)\n# nielsr/CelebA-faces \u2192 202K celebrity faces\n# huggan/anime-faces \u2192 63K anime faces (64px native - low res!)\n# huggan/flowers-102-categories \u2192 8K flowers\n# reach-vb/pokemon-blip-captions \u2192 833 pokemon\n# Norod78/cartoon-blip-captions \u2192 3K cartoons\nIMAGE_COLUMN = 'image'\nUSE_STREAMING = True #@param {type:\"boolean\"}\n# \u26a0\ufe0f USE STREAMING=True for large datasets (>10K) to avoid RAM issues\nMAX_SAMPLES = None # e.g. 
5000 for quick test, None = full\nSTREAMING_STEPS_PER_EPOCH = 2000 # only used when streaming (no len())\n\n# === TRAINING ===\nBATCH_SIZE = 10 #@param {type:\"integer\"}\nLEARNING_RATE = 2e-4 #@param {type:\"number\"}\nWEIGHT_DECAY = 0.01\nNUM_EPOCHS = 50 #@param {type:\"integer\"}\nGRAD_CLIP = 1.0\nEMA_DECAY = 0.9999\nNUM_WORKERS = 2\n\nTIME_SAMPLING = 'logit_normal' #@param ['logit_normal', 'uniform']\nUSE_AMP = True #@param {type:\"boolean\"}\nAMP_DTYPE = 'float16'\n\n# === LR SCHEDULE ===\nLR_SCHEDULE = 'cosine_restarts' #@param ['cosine_restarts', 'cosine', 'constant']\n# cosine_restarts = cosine decay with periodic warm restarts (BEST for breaking plateaus)\n# cosine = standard cosine decay to 0 (can plateau late)\n# constant = flat LR (simple but works)\nWARMUP_FRACTION = 0.02 # fraction of total steps (2% = fast warmup)\nNUM_RESTARTS = 3 # for cosine_restarts: how many times to restart LR\n\n# === RESUME FROM CHECKPOINT ===\nRESUME_FROM = None #@param {type:\"string\"}\n# Set to checkpoint path like './outputs/checkpoints/final.pt' to continue training\n\n# === SAMPLING & CHECKPOINTS ===\nSAMPLE_EVERY = 500 #@param {type:\"integer\"}\nNUM_SAMPLE_IMAGES = 8\nNUM_EULER_STEPS = 50\nSAVE_EVERY = 2000 #@param {type:\"integer\"}\nOUTPUT_DIR = './outputs'\nLOG_EVERY = 50\n\nprint(f'\u2705 Config: {MODEL_SIZE} model, {IMAGE_SIZE}px, mode={TRAINING_MODE}')\nprint(f' Dataset: {DATASET}')\nprint(f' bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={NUM_EPOCHS}')\nprint(f' Schedule: {LR_SCHEDULE}, warmup={WARMUP_FRACTION*100:.0f}%, restarts={NUM_RESTARTS}')\nprint(f' Streaming: {USE_STREAMING}')\nif RESUME_FROM:\n print(f' \ud83d\udcc2 Resuming from: {RESUME_FROM}')"}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udce6 Install Dependencies"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["!pip install -q datasets diffusers accelerate huggingface_hub Pillow matplotlib\n", "import torch\n", "print(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\n", "if torch.cuda.is_available():\n", " print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem/1e9:.1f}GB')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83c\udfd7\ufe0f Model Architecture"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "import math, copy, os, time\nimport torch, torch.nn as nn, torch.nn.functional as F\nfrom torch.utils.data import DataLoader, Dataset, IterableDataset\nfrom torchvision import transforms\nfrom torchvision.utils import save_image, make_grid\n\nclass SinusoidalTimeEmbedding(nn.Module):\n def __init__(self, dim, max_period=10000):\n super().__init__()\n self.dim, self.max_period = dim, max_period\n self.mlp = nn.Sequential(nn.Linear(dim, dim*4), nn.SiLU(), nn.Linear(dim*4, dim))\n def forward(self, t):\n half = self.dim // 2\n freqs = torch.exp(-math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half)\n args = t[:, None] * freqs[None, :]\n emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n if self.dim % 2: emb = F.pad(emb, (0, 1))\n return self.mlp(emb)\n\nclass AdaLN(nn.Module):\n def __init__(self, dim, cond_dim):\n super().__init__()\n ng = min(32, dim)\n while dim % ng != 0: ng -= 1\n self.norm = nn.GroupNorm(ng, dim, affine=False)\n self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))\n def forward(self, x, t_emb):\n s, sh = self.proj(t_emb).chunk(2, dim=1)\n return self.norm(x) * (1 + 
s[:,:,None,None]) + sh[:,:,None,None]\n\nclass CCA(nn.Module):\n \"\"\"Compact Channel Attention (from DiCo, May 2025).\n Fixes dead channels in depthwise conv blocks by learning channel gates.\n GAP \u2192 1\u00d71 conv \u2192 sigmoid \u2192 channel-wise multiply. Zero extra spatial cost.\"\"\"\n def __init__(self, dim):\n super().__init__()\n self.fc = nn.Conv2d(dim, dim, 1)\n def forward(self, x):\n return x * torch.sigmoid(self.fc(x.mean(dim=[2,3], keepdim=True)))\n\nclass ParallelCfCBlock(nn.Module):\n \"\"\"CfC Eq.10 + CCA + gate bias tracking (DeepSeek-V3 inspired).\n Fully parallel, no ODE solver. Diffusion timestep = liquid time constant.\"\"\"\n def __init__(self, dim, t_dim, expand_ratio=2.0, kernel_size=5, dropout=0.0):\n super().__init__()\n hidden = int(dim * expand_ratio)\n self.backbone = nn.Sequential(\n nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim),\n nn.Conv2d(dim, hidden, 1), nn.SiLU())\n self.cca = CCA(hidden) # DiCo: reactivate dormant channels\n self.f_head = nn.Conv2d(hidden, dim, 1)\n self.g_head = nn.Conv2d(hidden, dim, 1)\n self.h_head = nn.Conv2d(hidden, dim, 1)\n self.time_a, self.time_b = nn.Linear(t_dim, dim), nn.Linear(t_dim, dim)\n self.rho = nn.Parameter(torch.zeros(1, dim, 1, 1))\n self.output_gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim))\n self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()\n # DeepSeek-V3 aux-free gate bias (prevents gate collapse)\n self.register_buffer('gate_bias', torch.zeros(1, dim, 1, 1))\n def forward(self, x, t_emb):\n residual = x\n bb = self.cca(self.backbone(x)) # CCA on expanded features\n f, g, h = self.f_head(bb), self.g_head(bb), self.h_head(bb)\n ta, tb = self.time_a(t_emb)[:,:,None,None], self.time_b(t_emb)[:,:,None,None]\n gate = torch.sigmoid(ta * f - tb + self.gate_bias)\n # Track gate stats for bias update (only during training)\n if self.training:\n with torch.no_grad():\n mean_gate = gate.mean(dim=[0,2,3], keepdim=True)\n self.gate_bias += 0.001 * (0.5 - mean_gate) # push toward 0.5\n cfc_out = self.dropout(gate * g + (1.0 - gate) * h)\n t_sc = t_emb.mean(dim=1, keepdim=True)[:,:,None,None]\n alpha = torch.exp(-(F.softplus(self.rho) + 1e-6) * t_sc.abs().clamp(min=0.01))\n out = alpha * residual + (1.0 - alpha) * cfc_out\n return out * torch.sigmoid(self.output_gate(t_emb))[:,:,None,None]\n\nclass MultiScaleSpatialMix(nn.Module):\n def __init__(self, dim, t_dim, kernel_size=5):\n super().__init__()\n self.local_dw = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim)\n self.global_pool, self.global_proj = nn.AdaptiveAvgPool2d(1), nn.Conv2d(dim, dim, 1)\n self.merge, self.act, self.adaln = nn.Conv2d(dim*2, dim, 1), nn.SiLU(), AdaLN(dim, t_dim)\n def forward(self, x, t_emb):\n xn = self.adaln(x, t_emb)\n return x + self.act(self.merge(torch.cat([self.local_dw(xn), self.global_proj(self.global_pool(xn)).expand_as(xn)], dim=1)))\n\nclass LiquidDiffusionBlock(nn.Module):\n def __init__(self, dim, t_dim, expand_ratio=2.0, kernel_size=5, dropout=0.0):\n super().__init__()\n self.adaln1, self.cfc = AdaLN(dim, t_dim), ParallelCfCBlock(dim, t_dim, expand_ratio, kernel_size, dropout)\n self.spatial_mix, self.adaln2 = MultiScaleSpatialMix(dim, t_dim, kernel_size), AdaLN(dim, t_dim)\n ff_dim = int(dim * expand_ratio)\n self.ff = nn.Sequential(nn.Conv2d(dim, ff_dim, 1), nn.SiLU(), nn.Conv2d(ff_dim, dim, 1))\n self.res_scale = nn.Parameter(torch.ones(1) * 0.1)\n def forward(self, x, t_emb):\n x = x + self.res_scale * self.cfc(self.adaln1(x, t_emb), t_emb)\n x = 
self.spatial_mix(x, t_emb)\n return x + self.res_scale * self.ff(self.adaln2(x, t_emb))\n\nclass DownSample(nn.Module):\n def __init__(self, i, o): super().__init__(); self.conv = nn.Conv2d(i, o, 3, stride=2, padding=1)\n def forward(self, x): return self.conv(x)\nclass UpSample(nn.Module):\n def __init__(self, i, o): super().__init__(); self.conv = nn.Conv2d(i, o, 3, padding=1)\n def forward(self, x): return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))\nclass SkipFusion(nn.Module):\n def __init__(self, dim, t_dim):\n super().__init__()\n self.proj = nn.Conv2d(dim*2, dim, 1)\n self.gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim), nn.Sigmoid())\n def forward(self, x, skip, t_emb):\n m = self.proj(torch.cat([x, skip], dim=1)); g = self.gate(t_emb)[:,:,None,None]\n return m * g + x * (1 - g)\n\nclass LiquidDiffusionUNet(nn.Module):\n \"\"\"LiquidDiffusion v4: CfC + CCA + multi-scale output heads.\"\"\"\n def __init__(self, in_channels=3, channels=None, blocks_per_stage=None, t_dim=256, expand_ratio=2.0, kernel_size=5, dropout=0.0):\n super().__init__()\n channels = channels or [64,128,256]; blocks_per_stage = blocks_per_stage or [2,2,4]\n assert len(channels) == len(blocks_per_stage)\n self.channels, self.num_stages, self.in_channels = channels, len(channels), in_channels\n self.time_embed = SinusoidalTimeEmbedding(t_dim)\n self.stem = nn.Sequential(nn.Conv2d(in_channels, channels[0], 3, padding=1), nn.SiLU(), nn.Conv2d(channels[0], channels[0], 3, padding=1))\n self.encoder_blocks, self.downsamplers = nn.ModuleList(), nn.ModuleList()\n for i in range(self.num_stages):\n self.encoder_blocks.append(nn.ModuleList([LiquidDiffusionBlock(channels[i], t_dim, expand_ratio, kernel_size, dropout) for _ in range(blocks_per_stage[i])]))\n if i < self.num_stages - 1: self.downsamplers.append(DownSample(channels[i], channels[i+1]))\n self.bottleneck = nn.ModuleList([LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout) for _ in range(2)])\n self.decoder_blocks, self.upsamplers, self.skip_fusions = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()\n # Multi-scale output heads (DiMR-inspired: predict velocity at each decoder scale)\n self.aux_heads = nn.ModuleList()\n for i in range(self.num_stages-1, -1, -1):\n if i < self.num_stages - 1:\n self.upsamplers.append(UpSample(channels[i+1], channels[i])); self.skip_fusions.append(SkipFusion(channels[i], t_dim))\n self.decoder_blocks.append(nn.ModuleList([LiquidDiffusionBlock(channels[i], t_dim, expand_ratio, kernel_size, dropout) for _ in range(blocks_per_stage[i])]))\n self.aux_heads.append(nn.Conv2d(channels[i], in_channels, 1)) # aux velocity pred\n hg = min(32, channels[0])\n while channels[0] % hg != 0: hg -= 1\n self.head = nn.Sequential(nn.GroupNorm(hg, channels[0]), nn.SiLU(), nn.Conv2d(channels[0], in_channels, 3, padding=1))\n nn.init.zeros_(self.head[-1].weight); nn.init.zeros_(self.head[-1].bias)\n for ah in self.aux_heads: nn.init.zeros_(ah.weight); nn.init.zeros_(ah.bias)\n\n def forward(self, x, t, return_multiscale=False):\n t_emb, h = self.time_embed(t), self.stem(x)\n skips = []\n for i in range(self.num_stages):\n for blk in self.encoder_blocks[i]: h = blk(h, t_emb)\n skips.append(h)\n if i < self.num_stages - 1: h = self.downsamplers[i](h)\n for blk in self.bottleneck: h = blk(h, t_emb)\n aux_preds = []\n up_idx = 0\n for di in range(self.num_stages):\n si = self.num_stages - 1 - di\n if di > 0: h = self.upsamplers[up_idx](h); h = self.skip_fusions[up_idx](h, skips[si], t_emb); up_idx += 1\n for 
blk in self.decoder_blocks[di]: h = blk(h, t_emb)\n if return_multiscale:\n aux_preds.append(self.aux_heads[di](h))\n main_out = self.head(h)\n if return_multiscale:\n return main_out, aux_preds\n return main_out\n\n def count_params(self): return sum(p.numel() for p in self.parameters()), sum(p.numel() for p in self.parameters() if p.requires_grad)\n\nprint('\u2705 LiquidDiffusion v4 loaded (CCA + gate bias + multi-scale heads)')"}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udd27 Build Model + Load VAE"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["device = 'cuda' if torch.cuda.is_available() else 'cpu'\nvae, vae_scale, model_in_channels = None, 1.0, 3\n\nif TRAINING_MODE == 'latent':\n from diffusers import AutoencoderKL\n print('Loading pretrained SD-VAE (stabilityai/sd-vae-ft-mse)...')\n vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse',\n torch_dtype=torch.float16 if (USE_AMP and device=='cuda') else torch.float32\n ).to(device).eval()\n vae.requires_grad_(False)\n vae_scale = vae.config.scaling_factor # 0.18215\n model_in_channels = vae.config.latent_channels # 4\n latent_size = IMAGE_SIZE // 8\n print(f' VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)')\n print(f' Latent: {IMAGE_SIZE}px \\u2192 {latent_size}x{latent_size}x{model_in_channels}')\n if device == 'cuda': print(f' VAE VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB')\nelse:\n latent_size = IMAGE_SIZE\n print('Pixel mode: no VAE')\n\nMODEL_CONFIGS = {\n 'tiny': dict(channels=[64,128,256], blocks_per_stage=[2,2,4], t_dim=256),\n 'small': dict(channels=[96,192,384], blocks_per_stage=[2,3,6], t_dim=384),\n 'base': dict(channels=[128,256,512], blocks_per_stage=[2,4,8], t_dim=512),\n}\ncfg = MODEL_CONFIGS.get(MODEL_SIZE, dict(channels=CUSTOM_CHANNELS, blocks_per_stage=CUSTOM_BLOCKS, t_dim=CUSTOM_T_DIM))\ncfg['in_channels'] = model_in_channels\n\nmodel = LiquidDiffusionUNet(**cfg).to(device)\ntotal_p, _ = model.count_params()\nprint(f'\\nLiquidDiffusion [{MODEL_SIZE}]: {total_p:,} ({total_p/1e6:.1f}M) params')\nprint(f' in_ch={model_in_channels}, channels={cfg[\"channels\"]}, blocks={cfg[\"blocks_per_stage\"]}')\nwith torch.no_grad():\n tx = torch.randn(1, model_in_channels, latent_size, latent_size, device=device)\n to = model(tx, torch.tensor([0.5], device=device))\n print(f' Forward: {tx.shape} \\u2192 {to.shape} \\u2713'); del tx, to\nif device == 'cuda': torch.cuda.empty_cache(); print(f' Total VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcca Load Dataset"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "from PIL import Image\nfrom datasets import load_dataset\nimport matplotlib.pyplot as plt\n\nclass HFImageDataset(Dataset):\n def __init__(self, hf_data, image_size, image_column='image'):\n self.data, self.col = hf_data, image_column\n self.transform = transforms.Compose([\n transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(image_size), transforms.RandomHorizontalFlip(),\n transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])\n def __len__(self): return len(self.data)\n def __getitem__(self, idx):\n img = self.data[idx][self.col]\n if not hasattr(img, 'convert'): img = Image.fromarray(img)\n return self.transform(img.convert('RGB'))\n\nclass StreamingImageDataset(IterableDataset):\n \"\"\"Streams from HF Hub with auto-repeat so it never 
exhausts.\"\"\"\n def __init__(self, name, image_size, image_column='image'):\n self.name, self.col = name, image_column\n self.transform = transforms.Compose([\n transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(image_size), transforms.RandomHorizontalFlip(),\n transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])\n def __iter__(self):\n while True: # infinite repeat \u2014 cycles through dataset forever\n ds = load_dataset(self.name, split='train', streaming=True)\n for s in ds:\n img = s[self.col]\n if not hasattr(img, 'convert'): img = Image.fromarray(img)\n yield self.transform(img.convert('RGB'))\n\nprint(f'Loading: {DATASET}')\nif USE_STREAMING:\n dataset = StreamingImageDataset(DATASET, IMAGE_SIZE, IMAGE_COLUMN)\n dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True) # 0 workers for streaming\n print(' Streaming mode')\nelse:\n hf_data = load_dataset(DATASET, split='train')\n if MAX_SAMPLES: hf_data = hf_data.select(range(min(MAX_SAMPLES, len(hf_data))))\n dataset = HFImageDataset(hf_data, IMAGE_SIZE, IMAGE_COLUMN)\n dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)\n print(f' {len(dataset):,} images, {len(dataloader):,} steps/epoch')\n\n# Preview\nsb = next(iter(dataloader))\nfig, axes = plt.subplots(1, min(8, sb.shape[0]), figsize=(16, 2.5))\nif not hasattr(axes, '__len__'): axes = [axes]\nfor i, ax in enumerate(axes): ax.imshow((sb[i].permute(1,2,0)*0.5+0.5).clamp(0,1)); ax.axis('off')\nplt.suptitle(f'{DATASET} ({IMAGE_SIZE}px)'); plt.tight_layout(); plt.show()\n\nif vae is not None:\n with torch.no_grad():\n ti = sb[:4].to(device, dtype=vae.dtype)\n lat = vae.encode(ti).latent_dist.sample() * vae_scale\n dec = vae.decode(lat / vae_scale).sample\n print(f'\\n VAE: {ti.shape} \\u2192 {lat.shape} \\u2192 {dec.shape}')\n print(f' Latent: mean={lat.mean():.4f}, std={lat.std():.4f}')\n fig, axes = plt.subplots(2, 4, figsize=(12, 6))\n for i in range(4):\n axes[0,i].imshow((ti[i].cpu().float().permute(1,2,0)*0.5+0.5).clamp(0,1)); axes[0,i].set_title('Original'); axes[0,i].axis('off')\n axes[1,i].imshow((dec[i].cpu().float().permute(1,2,0)*0.5+0.5).clamp(0,1)); axes[1,i].set_title('VAE Recon'); axes[1,i].axis('off')\n plt.suptitle('VAE Quality Check'); plt.tight_layout(); plt.show()"}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\ude80 Training"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "os.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\nos.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))\n\nif USE_STREAMING: steps_per_epoch = STREAMING_STEPS_PER_EPOCH\nelse: steps_per_epoch = len(dataloader)\ntotal_steps = steps_per_epoch * NUM_EPOCHS\nwarmup_steps = max(50, int(total_steps * WARMUP_FRACTION))\n\nif LR_SCHEDULE == 'cosine_restarts':\n restart_period = max(1, (total_steps - warmup_steps) // (NUM_RESTARTS + 1))\n def lr_lambda(step):\n if step < warmup_steps: return float(step) / max(1, warmup_steps)\n cycle_pos = (step - warmup_steps) % restart_period\n return max(0.05, 0.5 * (1 + math.cos(math.pi * cycle_pos / restart_period)))\nelif LR_SCHEDULE == 'cosine':\n def lr_lambda(step):\n if step < warmup_steps: return float(step) / max(1, warmup_steps)\n return max(0.0, 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / max(1, 
total_steps - warmup_steps))))\nelse:\n def lr_lambda(step):\n if step < warmup_steps: return float(step) / max(1, warmup_steps)\n return 1.0\nscheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n\nema_model = copy.deepcopy(model).eval()\nfor p in ema_model.parameters(): p.requires_grad_(False)\nscaler = torch.amp.GradScaler('cuda', enabled=(USE_AMP and device=='cuda'))\namp_dtype = getattr(torch, AMP_DTYPE) if (USE_AMP and device=='cuda') else torch.float32\n\n# === TRAINING TRICKS (from DeepSeek-V3, FasterDiT, Min-SNR, DiMR research) ===\n\n# 1. Logit-Normal timestep sampling (FLUX/SD3/FasterDiT \u2014 focuses on mid-noise)\ndef sample_time(bs):\n eps = 1e-5\n if TIME_SAMPLING == 'uniform': return torch.rand(bs, device=device)*(1-2*eps)+eps\n return torch.sigmoid(torch.randn(bs, device=device)).clamp(eps, 1-eps)\n\n# 2. Min-SNR-\u03b3 loss weighting (3.4\u00d7 faster convergence \u2014 Min-SNR paper)\ndef min_snr_weight(t, gamma=5.0):\n snr = ((1 - t) / (t + 1e-8)).pow(2)\n return torch.clamp(snr, max=gamma) / (snr + 1)\n\n# 3. Multi-scale velocity loss (DiMR \u2014 breaks gradient starvation at deep layers)\ndef multi_scale_loss(main_pred, aux_preds, v_target, t, gamma=5.0):\n w = min_snr_weight(t, gamma).view(-1, 1, 1, 1)\n # Main loss\n loss = (w * (main_pred - v_target).pow(2)).mean()\n # Auxiliary losses at each decoder scale\n for i, aux in enumerate(aux_preds):\n scale = v_target.shape[-1] // aux.shape[-1]\n if scale > 1:\n target_down = F.avg_pool2d(v_target, scale)\n else:\n target_down = v_target\n aux_weight = 0.25 / (2 ** i) # decreasing weight for coarser scales\n loss += aux_weight * (w * (aux - target_down).pow(2)).mean()\n return loss\n\n# 4. Velocity direction loss (FasterDiT \u2014 cosine similarity on velocity direction)\ndef velocity_direction_loss(pred, target):\n return 1.0 - F.cosine_similarity(\n pred.flatten(2), target.flatten(2), dim=2\n ).mean()\n\nglobal_step, start_epoch, all_losses = 0, 0, []\nif RESUME_FROM and os.path.exists(RESUME_FROM):\n ckpt = torch.load(RESUME_FROM, map_location=device, weights_only=False)\n try: model.load_state_dict(ckpt['model'], strict=False)\n except: print(' \u26a0\ufe0f Partial weight load (architecture changed)')\n try: ema_model.load_state_dict(ckpt['ema_model'], strict=False)\n except: ema_model = copy.deepcopy(model).eval(); [p.requires_grad_(False) for p in ema_model.parameters()]\n global_step = ckpt.get('step', 0); start_epoch = ckpt.get('epoch', 0)\n all_losses = ckpt.get('losses', [])\n for _ in range(global_step): scheduler.step()\n print(f' \ud83d\udcc2 Resumed from step {global_step}')\n\n@torch.no_grad()\ndef generate_samples(step):\n ema_model.eval()\n z = torch.randn(NUM_SAMPLE_IMAGES, model_in_channels, latent_size, latent_size, device=device)\n dt = 1.0 / NUM_EULER_STEPS\n for i in range(NUM_EULER_STEPS, 0, -1):\n t = torch.full((NUM_SAMPLE_IMAGES,), i/NUM_EULER_STEPS, device=device)\n with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'):\n v = ema_model(z, t, return_multiscale=False)\n if USE_AMP and amp_dtype == torch.float16: v = v.float()\n z = z - v * dt\n if vae is not None: pixels = vae.decode((z / vae_scale).to(vae.dtype)).sample.float()\n else: pixels = z\n pixels = pixels.clamp(-1, 1)\n save_image(make_grid(pixels*0.5+0.5, nrow=int(math.sqrt(NUM_SAMPLE_IMAGES)), padding=2), f'{OUTPUT_DIR}/samples/step_{step:06d}.png')\n return pixels\n\ndef fmt_time(s):\n if s < 60: return f'{s:.0f}s'\n if s < 3600: return f'{s/60:.1f}m'\n return 
f'{int(s//3600)}h{int((s%3600)//60):02d}m'\n\nbest_loss = float('inf')\nloss_window = []\n\nprint(f'\\n{\"=\"*70}')\nprint(f' \ud83c\udf0a LiquidDiffusion v4 Training')\nprint(f'{\"=\"*70}')\nprint(f' Mode: {TRAINING_MODE} ({latent_size}x{latent_size}x{model_in_channels})')\nprint(f' Model: {MODEL_SIZE} ({total_p/1e6:.1f}M params)')\nprint(f' Dataset: {DATASET}')\nprint(f' Batch size: {BATCH_SIZE}')\nprint(f' Steps: ~{total_steps:,} ({steps_per_epoch}/epoch \u00d7 {NUM_EPOCHS})')\nprint(f' Schedule: {LR_SCHEDULE} (warmup={warmup_steps})')\nprint(f' LR: {LEARNING_RATE}')\nprint(f' Tricks: Min-SNR-\u03b3=5 + velocity direction loss + CCA + multi-scale + gate bias')\nprint(f' Device: {device}')\nif device == 'cuda':\n print(f' GPU: {torch.cuda.get_device_name(0)} ({torch.cuda.get_device_properties(0).total_mem/1e9:.1f}GB)')\nprint(f'{\"=\"*70}\\n')\n\ntrain_start = time.time()\nepoch_losses, step_times = [], []\n\nfor epoch in range(start_epoch, NUM_EPOCHS):\n model.train(); epoch_loss, nb_ = 0, 0\n epoch_start = time.time()\n\n for batch_idx, pixel_batch in enumerate(dataloader):\n if USE_STREAMING and batch_idx >= STREAMING_STEPS_PER_EPOCH: break\n step_start = time.time()\n pixel_batch = pixel_batch.to(device, non_blocking=True)\n\n # VAE encode\n if vae is not None:\n with torch.no_grad():\n x0 = vae.encode(pixel_batch.to(vae.dtype)).latent_dist.sample().float() * vae_scale\n else: x0 = pixel_batch\n\n\n\n # Rectified flow\n x1 = torch.randn_like(x0); t = sample_time(x0.shape[0]); te = t[:,None,None,None]\n x_t = (1-te)*x0 + te*x1; v_target = x1 - x0\n\n with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'):\n main_pred, aux_preds = model(x_t, t, return_multiscale=True)\n # Combined loss: Min-SNR weighted MSE + multi-scale + velocity direction\n loss_mse = multi_scale_loss(main_pred, aux_preds, v_target, t, gamma=5.0)\n loss_dir = 0.1 * velocity_direction_loss(main_pred, v_target)\n loss = loss_mse + loss_dir\n\n optimizer.zero_grad(set_to_none=True); scaler.scale(loss).backward()\n if GRAD_CLIP > 0: scaler.unscale_(optimizer); gn = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n else: gn = torch.tensor(0.0)\n _s = scaler.get_scale()\n scaler.step(optimizer); scaler.update()\n if scaler.get_scale() >= _s: scheduler.step()\n with torch.no_grad():\n for ep, mp in zip(ema_model.parameters(), model.parameters()): ep.data.mul_(EMA_DECAY).add_(mp.data, alpha=1-EMA_DECAY)\n\n global_step += 1; nb_ += 1\n lv = loss.item(); all_losses.append(lv); epoch_loss += lv\n step_times.append(time.time() - step_start)\n if len(step_times) > 200: step_times = step_times[-200:]\n loss_window.append(lv)\n if len(loss_window) > 500: loss_window.pop(0)\n if lv < best_loss: best_loss = lv\n\n if global_step % LOG_EVERY == 0:\n elapsed = time.time() - train_start\n avg_loss = sum(all_losses[-LOG_EVERY:]) / LOG_EVERY\n avg_step = sum(step_times) / len(step_times)\n sps = 1.0 / avg_step if avg_step > 0 else 0\n lr = scheduler.get_last_lr()[0]\n remaining = (total_steps - global_step) * avg_step\n pct = global_step / total_steps * 100\n if len(loss_window) >= 100:\n d = sum(loss_window[-50:])/50 - sum(loss_window[-100:-50])/50\n trend = f'\u2193{abs(d):.4f}' if d < -0.01 else f'\u2191{d:.4f}' if d > 0.01 else '\u2192stable'\n else: trend = '...'\n mem = f' | VRAM:{torch.cuda.memory_allocated()/1e9:.1f}/{torch.cuda.max_memory_allocated()/1e9:.1f}GB' if device=='cuda' else ''\n print(f'\\n Step {global_step:>6d}/{total_steps} [{pct:5.1f}%] | Epoch 
{epoch+1}/{NUM_EPOCHS}')\n print(f' Loss: {avg_loss:.4f} (best:{best_loss:.4f} trend:{trend})')\n print(f' LR: {lr:.2e} | Grad:{gn.item() if torch.is_tensor(gn) else gn:.3f} | {sps:.2f}it/s {avg_step*1000:.0f}ms/step')\n print(f' Time: {fmt_time(elapsed)} elapsed | ETA:{fmt_time(remaining)} | {global_step*BATCH_SIZE:,}imgs{mem}')\n\n if global_step % SAMPLE_EVERY == 0:\n t0 = time.time(); samples = generate_samples(global_step)\n print(f' \ud83d\udcf8 Sampled in {time.time()-t0:.1f}s')\n fig, axes = plt.subplots(1, min(8, NUM_SAMPLE_IMAGES), figsize=(16, 2.5))\n if not hasattr(axes, '__len__'): axes = [axes]\n for i, ax in enumerate(axes):\n if i < samples.shape[0]: ax.imshow((samples[i].cpu().permute(1,2,0)*0.5+0.5).clamp(0,1))\n ax.axis('off')\n plt.suptitle(f'Step {global_step} | Loss:{lv:.4f}'); plt.tight_layout(); plt.show()\n\n if global_step % SAVE_EVERY == 0:\n p = f'{OUTPUT_DIR}/checkpoints/step_{global_step:06d}.pt'\n torch.save({'model':model.state_dict(),'ema_model':ema_model.state_dict(),'optimizer':optimizer.state_dict(),'step':global_step,'epoch':epoch,'losses':all_losses[-2000:],'config':cfg}, p)\n print(f' \ud83d\udcbe Saved {p} ({os.path.getsize(p)/1e6:.0f}MB)')\n\n if nb_ > 0:\n avg_ep = epoch_loss / nb_; epoch_losses.append(avg_ep)\n ed = time.time() - epoch_start; re = (NUM_EPOCHS - epoch - 1) * ed\n delta_str = ''\n if len(epoch_losses) >= 2:\n d = epoch_losses[-1] - epoch_losses[-2]\n delta_str = f' | vs prev:{d:+.4f} {\"\u2705\" if d < 0 else \"\u26a0\ufe0f\" if d > 0.01 else \"\u2192\"}'\n print(f'\\n \u2550\u2550 Epoch {epoch+1}/{NUM_EPOCHS}: loss={avg_ep:.4f} | {fmt_time(ed)} | ETA:{fmt_time(re)}{delta_str} \u2550\u2550')\n\nfinal_p = f'{OUTPUT_DIR}/checkpoints/final.pt'\ntorch.save({'model':model.state_dict(),'ema_model':ema_model.state_dict(),'step':global_step,'config':cfg,'losses':all_losses[-2000:]}, final_p)\nprint(f'\\n{\"=\"*70}')\nprint(f' \u2705 Training complete! 
{fmt_time(time.time()-train_start)}')\nprint(f' Steps: {global_step:,} | Final loss: {all_losses[-1]:.4f} | Best: {best_loss:.4f}')\nprint(f'{\"=\"*70}')"}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcc8 Training Curves"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import numpy as np\nfig, (a1, a2) = plt.subplots(1, 2, figsize=(14, 5))\na1.plot(all_losses, alpha=0.3, color='blue', linewidth=0.5)\nw = min(200, max(1, len(all_losses)//5))\nif w > 1 and len(all_losses) > w:\n sm = np.convolve(all_losses, np.ones(w)/w, mode='valid')\n a1.plot(range(w-1, len(all_losses)), sm, color='red', linewidth=2, label=f'Smooth(w={w})')\na1.set_xlabel('Step'); a1.set_ylabel('Loss'); a1.set_title('Training Loss'); a1.legend(); a1.grid(True, alpha=0.3)\nif epoch_losses:\n a2.plot(range(1, len(epoch_losses)+1), epoch_losses, 'o-', color='green')\n a2.set_xlabel('Epoch'); a2.set_ylabel('Loss'); a2.set_title('Per Epoch'); a2.grid(True, alpha=0.3)\nplt.tight_layout(); plt.show()"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83c\udfa8 Generate Images"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["NUM_GENERATE = 16 #@param {type:\"integer\"}\nEULER_STEPS = 50 #@param {type:\"integer\"}\n\nprint(f'Generating {NUM_GENERATE} images ({EULER_STEPS} steps)...')\nema_model.eval()\nwith torch.no_grad():\n z = torch.randn(NUM_GENERATE, model_in_channels, latent_size, latent_size, device=device)\n dt = 1.0 / EULER_STEPS\n for i in range(EULER_STEPS, 0, -1):\n t = torch.full((NUM_GENERATE,), i/EULER_STEPS, device=device)\n with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'): v = ema_model(z, t)\n if USE_AMP and amp_dtype == torch.float16: v = v.float()\n z = z - v * dt\n if vae is not None: generated = vae.decode((z/vae_scale).to(vae.dtype)).sample.float().clamp(-1,1)\n else: generated = z.clamp(-1,1)\n\nnr = int(math.ceil(math.sqrt(NUM_GENERATE)))\nfig, axes = plt.subplots(nr, nr, figsize=(2.5*nr, 2.5*nr))\naxes = axes.flatten() if hasattr(axes, 'flatten') else [axes]\nfor i, ax in enumerate(axes):\n if i < NUM_GENERATE: ax.imshow((generated[i].cpu().permute(1,2,0)*0.5+0.5).clamp(0,1))\n ax.axis('off')\nplt.suptitle(f'LiquidDiffusion ({IMAGE_SIZE}px)', fontsize=14); plt.tight_layout(); plt.show()\nsave_image(make_grid(generated*0.5+0.5, nrow=nr, padding=2), f'{OUTPUT_DIR}/final_samples.png')\nprint(f'Saved to {OUTPUT_DIR}/final_samples.png')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcbe Save to Hub"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["PUSH_TO_HUB = False #@param {type:\"boolean\"}\nHUB_MODEL_ID = 'your-username/liquid-diffusion-model' #@param {type:\"string\"}\nif PUSH_TO_HUB:\n from huggingface_hub import HfApi\n api = HfApi(); api.create_repo(HUB_MODEL_ID, exist_ok=True)\n api.upload_file(path_or_fileobj=f'{OUTPUT_DIR}/checkpoints/final.pt', path_in_repo='model.pt', repo_id=HUB_MODEL_ID)\n print(f'Pushed to https://huggingface.co/{HUB_MODEL_ID}')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["---\n", "## \ud83d\udcd6 Architecture\n", "\n", "### CfC Time-Gating\n", "```\n", "gate = \u03c3(time_a(t) \u00b7 f(features) - time_b(t))\n", "out = gate \u00b7 g + (1-gate) \u00b7 h\n", "\u03b1 = exp(-\u03bb\u00b7|t|) \u2192 time-aware residual\n", "```\n", "\n", "### Latent Training Pipeline\n", "```\n", "pixel (3\u00d7256\u00d7256) \u2192 [SD-VAE encode] \u2192 latent 
(4\u00d732\u00d732) \u2192 [LiquidDiffusion] \u2192 [SD-VAE decode] \u2192 pixel\n", "```\n", "\n", "### References\n", "- [CfC (Nature MI 2022)](https://arxiv.org/abs/2106.13898)\n", "- [LiquidTAD](https://arxiv.org/abs/2604.18274)\n", "- [Rectified Flow (ICLR 2023)](https://arxiv.org/abs/2209.03003)\n", "- [SD-VAE ft-MSE](https://huggingface.co/stabilityai/sd-vae-ft-mse)"]}]}
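For reference, the combined objective being removed by this revert, condensed from the cell above into a self-contained sketch (tensor shapes assumed B×C×H×W velocity maps; `v4_loss` is an illustrative name, the notebook computes this inline):

```python
import torch
import torch.nn.functional as F

def min_snr_weight(t, gamma=5.0):
    # Min-SNR-gamma weighting as defined above: snr = ((1-t)/t)^2,
    # clamped at gamma and normalized by (snr + 1).
    snr = ((1 - t) / (t + 1e-8)).pow(2)
    return torch.clamp(snr, max=gamma) / (snr + 1)

def v4_loss(main_pred, aux_preds, v_target, t, gamma=5.0):
    # v4 objective: Min-SNR-weighted MSE on the main head, plus
    # down-weighted auxiliary MSEs at each decoder scale (targets
    # average-pooled to match), plus a small cosine direction term.
    w = min_snr_weight(t, gamma).view(-1, 1, 1, 1)
    loss = (w * (main_pred - v_target).pow(2)).mean()
    for i, aux in enumerate(aux_preds):
        scale = v_target.shape[-1] // aux.shape[-1]
        target = F.avg_pool2d(v_target, scale) if scale > 1 else v_target
        loss = loss + (0.25 / 2 ** i) * (w * (aux - target).pow(2)).mean()
    direction = 1.0 - F.cosine_similarity(
        main_pred.flatten(2), v_target.flatten(2), dim=2).mean()
    return loss + 0.1 * direction
```

The revert drops all three extra terms in favor of the plain MSE objective restored below.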
 
+ {"nbformat": 4, "nbformat_minor": 0, "metadata": {"colab": {"provenance": [], "gpuType": "T4"}, "kernelspec": {"name": "python3", "display_name": "Python 3"}, "accelerator": "GPU"}, "cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# \ud83c\udf0a LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks\n", "\n", "A **novel image generation model** combining:\n", "- **Liquid Neural Networks** (CfC) for adaptive, time-aware processing\n", "- **Rectified Flow** for simple, stable training\n", "- **Pretrained SD-VAE** for efficient latent-space training\n", "- **Zero attention** \u2014 fully convolutional\n", "- **Fully parallelizable** \u2014 no sequential ODE loops\n", "\n", "**Repo**: [krystv/liquid-diffusion](https://huggingface.co/krystv/liquid-diffusion)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \u2699\ufe0f Configuration"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["#@title \u2699\ufe0f Training Configuration\n", "\n", "# === MODEL ===\n", "MODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', 'custom']\n", "CUSTOM_CHANNELS = [48, 96, 192]\n", "CUSTOM_BLOCKS = [1, 2, 3]\n", "CUSTOM_T_DIM = 192\n", "\n", "# === TRAINING MODE ===\n", "TRAINING_MODE = 'latent' #@param ['latent', 'pixel']\n", "# latent = train in VAE latent space (4ch, 8x smaller) - RECOMMENDED\n", "# pixel = train directly on RGB pixels (3ch, full res)\n", "\n", "# === IMAGE RESOLUTION ===\n", "IMAGE_SIZE = 256 #@param [128, 256, 512] {type:\"integer\"}\n", "\n", "# === DATASET ===\n", "DATASET = 'huggan/AFHQv2' #@param ['huggan/AFHQv2', 'nielsr/CelebA-faces', 'huggan/flowers-102-categories', 'reach-vb/pokemon-blip-captions', 'huggan/anime-faces', 'Norod78/cartoon-blip-captions']\n", "# huggan/AFHQv2 \u2192 16K animal faces (512px native)\n", "# nielsr/CelebA-faces \u2192 202K celebrity faces\n", "# huggan/flowers-102-categories \u2192 8K flower photos\n", "# reach-vb/pokemon-blip-captions \u2192 833 pokemon illustrations\n", "# huggan/anime-faces \u2192 63K anime faces (64px native)\n", "# Norod78/cartoon-blip-captions \u2192 ~3K cartoon characters\n", "IMAGE_COLUMN = 'image'\n", "USE_STREAMING = False #@param {type:\"boolean\"}\n", "MAX_SAMPLES = None # Set to e.g. 
1000 for quick test\n", "\n", "# === TRAINING ===\n", "BATCH_SIZE = 8 #@param {type:\"integer\"}\n", "LEARNING_RATE = 1e-4 #@param {type:\"number\"}\n", "WEIGHT_DECAY = 0.01\n", "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n", "GRAD_CLIP = 1.0\n", "EMA_DECAY = 0.9999\n", "NUM_WORKERS = 2\n", "TIME_SAMPLING = 'logit_normal' #@param ['logit_normal', 'uniform']\n", "USE_AMP = True #@param {type:\"boolean\"}\n", "AMP_DTYPE = 'float16'\n", "\n", "# === SAMPLING & CHECKPOINTS ===\n", "SAMPLE_EVERY = 500 #@param {type:\"integer\"}\n", "NUM_SAMPLE_IMAGES = 8\n", "NUM_EULER_STEPS = 50\n", "SAVE_EVERY = 2000 #@param {type:\"integer\"}\n", "OUTPUT_DIR = './outputs'\n", "RESUME_FROM = None\n", "LOG_EVERY = 50\n", "\n", "print(f'\u2705 Config: {MODEL_SIZE} model, {IMAGE_SIZE}px, mode={TRAINING_MODE}')\n", "print(f' Dataset: {DATASET}')\n", "print(f' bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={NUM_EPOCHS}, AMP={USE_AMP}')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udce6 Install Dependencies"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["!pip install -q datasets diffusers accelerate huggingface_hub Pillow matplotlib\n", "import torch\n", "print(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\n", "if torch.cuda.is_available():\n", " print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem/1e9:.1f}GB')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83c\udfd7\ufe0f Model Architecture"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "import math, copy, os, time\nimport torch, torch.nn as nn, torch.nn.functional as F\nfrom torch.utils.data import DataLoader, Dataset, IterableDataset\nfrom torchvision import transforms\nfrom torchvision.utils import save_image, make_grid\n\nclass SinusoidalTimeEmbedding(nn.Module):\n def __init__(self, dim, max_period=10000):\n super().__init__()\n self.dim, self.max_period = dim, max_period\n self.mlp = nn.Sequential(nn.Linear(dim, dim*4), nn.SiLU(), nn.Linear(dim*4, dim))\n def forward(self, t):\n half = self.dim // 2\n freqs = torch.exp(-math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half)\n args = t[:, None] * freqs[None, :]\n emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n if self.dim % 2: emb = F.pad(emb, (0, 1))\n return self.mlp(emb)\n\nclass AdaLN(nn.Module):\n def __init__(self, dim, cond_dim):\n super().__init__()\n ng = min(32, dim)\n while dim % ng != 0: ng -= 1\n self.norm = nn.GroupNorm(ng, dim, affine=False)\n self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))\n def forward(self, x, t_emb):\n s, sh = self.proj(t_emb).chunk(2, dim=1)\n return self.norm(x) * (1 + s[:,:,None,None]) + sh[:,:,None,None]\n\nclass ParallelCfCBlock(nn.Module):\n \"\"\"CfC Eq.10: x(t) = \\u03c3(-f\\u00b7t)\\u2299g + (1-\\u03c3(-f\\u00b7t))\\u2299h \\u2014 fully parallel.\n Optimized: single depthwise in backbone, 1x1 heads only.\"\"\"\n def __init__(self, dim, t_dim, expand_ratio=2.0, kernel_size=5, dropout=0.0):\n super().__init__()\n hidden = int(dim * expand_ratio)\n self.backbone = nn.Sequential(\n nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim),\n nn.Conv2d(dim, hidden, 1), nn.SiLU())\n self.f_head = nn.Conv2d(hidden, dim, 1)\n self.g_head = nn.Conv2d(hidden, dim, 1)\n self.h_head = nn.Conv2d(hidden, dim, 1)\n self.time_a, self.time_b = nn.Linear(t_dim, dim), nn.Linear(t_dim, dim)\n self.rho = 
nn.Parameter(torch.zeros(1, dim, 1, 1))\n self.output_gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim))\n self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()\n def forward(self, x, t_emb):\n residual = x\n bb = self.backbone(x)\n f, g, h = self.f_head(bb), self.g_head(bb), self.h_head(bb)\n ta, tb = self.time_a(t_emb)[:,:,None,None], self.time_b(t_emb)[:,:,None,None]\n gate = torch.sigmoid(ta * f - tb)\n cfc_out = self.dropout(gate * g + (1.0 - gate) * h)\n t_sc = t_emb.mean(dim=1, keepdim=True)[:,:,None,None]\n alpha = torch.exp(-(F.softplus(self.rho) + 1e-6) * t_sc.abs().clamp(min=0.01))\n out = alpha * residual + (1.0 - alpha) * cfc_out\n return out * torch.sigmoid(self.output_gate(t_emb))[:,:,None,None]\n\nclass MultiScaleSpatialMix(nn.Module):\n \"\"\"Single large-kernel depthwise + global pool (replaces 3-conv version).\"\"\"\n def __init__(self, dim, t_dim, kernel_size=5):\n super().__init__()\n self.local_dw = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim)\n self.global_pool, self.global_proj = nn.AdaptiveAvgPool2d(1), nn.Conv2d(dim, dim, 1)\n self.merge, self.act, self.adaln = nn.Conv2d(dim*2, dim, 1), nn.SiLU(), AdaLN(dim, t_dim)\n def forward(self, x, t_emb):\n xn = self.adaln(x, t_emb)\n return x + self.act(self.merge(torch.cat([self.local_dw(xn), self.global_proj(self.global_pool(xn)).expand_as(xn)], dim=1)))\n\nclass LiquidDiffusionBlock(nn.Module):\n def __init__(self, dim, t_dim, expand_ratio=2.0, kernel_size=5, dropout=0.0):\n super().__init__()\n self.adaln1, self.cfc = AdaLN(dim, t_dim), ParallelCfCBlock(dim, t_dim, expand_ratio, kernel_size, dropout)\n self.spatial_mix, self.adaln2 = MultiScaleSpatialMix(dim, t_dim, kernel_size), AdaLN(dim, t_dim)\n ff_dim = int(dim * expand_ratio)\n self.ff = nn.Sequential(nn.Conv2d(dim, ff_dim, 1), nn.SiLU(), nn.Conv2d(ff_dim, dim, 1))\n self.res_scale = nn.Parameter(torch.ones(1) * 0.1)\n def forward(self, x, t_emb):\n x = x + self.res_scale * self.cfc(self.adaln1(x, t_emb), t_emb)\n x = self.spatial_mix(x, t_emb)\n return x + self.res_scale * self.ff(self.adaln2(x, t_emb))\n\nclass DownSample(nn.Module):\n def __init__(self, i, o): super().__init__(); self.conv = nn.Conv2d(i, o, 3, stride=2, padding=1)\n def forward(self, x): return self.conv(x)\nclass UpSample(nn.Module):\n def __init__(self, i, o): super().__init__(); self.conv = nn.Conv2d(i, o, 3, padding=1)\n def forward(self, x): return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))\nclass SkipFusion(nn.Module):\n def __init__(self, dim, t_dim):\n super().__init__()\n self.proj = nn.Conv2d(dim*2, dim, 1)\n self.gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim), nn.Sigmoid())\n def forward(self, x, skip, t_emb):\n m = self.proj(torch.cat([x, skip], dim=1)); g = self.gate(t_emb)[:,:,None,None]\n return m * g + x * (1 - g)\n\nclass LiquidDiffusionUNet(nn.Module):\n def __init__(self, in_channels=3, channels=None, blocks_per_stage=None, t_dim=256, expand_ratio=2.0, kernel_size=5, dropout=0.0):\n super().__init__()\n channels = channels or [64,128,256]; blocks_per_stage = blocks_per_stage or [2,2,4]\n assert len(channels) == len(blocks_per_stage)\n self.channels, self.num_stages, self.in_channels = channels, len(channels), in_channels\n self.time_embed = SinusoidalTimeEmbedding(t_dim)\n self.stem = nn.Sequential(nn.Conv2d(in_channels, channels[0], 3, padding=1), nn.SiLU(), nn.Conv2d(channels[0], channels[0], 3, padding=1))\n self.encoder_blocks, self.downsamplers = nn.ModuleList(), nn.ModuleList()\n for i in 
range(self.num_stages):\n self.encoder_blocks.append(nn.ModuleList([LiquidDiffusionBlock(channels[i], t_dim, expand_ratio, kernel_size, dropout) for _ in range(blocks_per_stage[i])]))\n if i < self.num_stages - 1: self.downsamplers.append(DownSample(channels[i], channels[i+1]))\n self.bottleneck = nn.ModuleList([LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout) for _ in range(2)])\n self.decoder_blocks, self.upsamplers, self.skip_fusions = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()\n for i in range(self.num_stages-1, -1, -1):\n if i < self.num_stages - 1:\n self.upsamplers.append(UpSample(channels[i+1], channels[i])); self.skip_fusions.append(SkipFusion(channels[i], t_dim))\n self.decoder_blocks.append(nn.ModuleList([LiquidDiffusionBlock(channels[i], t_dim, expand_ratio, kernel_size, dropout) for _ in range(blocks_per_stage[i])]))\n hg = min(32, channels[0])\n while channels[0] % hg != 0: hg -= 1\n self.head = nn.Sequential(nn.GroupNorm(hg, channels[0]), nn.SiLU(), nn.Conv2d(channels[0], in_channels, 3, padding=1))\n nn.init.zeros_(self.head[-1].weight); nn.init.zeros_(self.head[-1].bias)\n def forward(self, x, t):\n t_emb, h = self.time_embed(t), self.stem(x)\n skips = []\n for i in range(self.num_stages):\n for blk in self.encoder_blocks[i]: h = blk(h, t_emb)\n skips.append(h)\n if i < self.num_stages - 1: h = self.downsamplers[i](h)\n for blk in self.bottleneck: h = blk(h, t_emb)\n up_idx = 0\n for di in range(self.num_stages):\n si = self.num_stages - 1 - di\n if di > 0: h = self.upsamplers[up_idx](h); h = self.skip_fusions[up_idx](h, skips[si], t_emb); up_idx += 1\n for blk in self.decoder_blocks[di]: h = blk(h, t_emb)\n return self.head(h)\n def count_params(self): return sum(p.numel() for p in self.parameters()), sum(p.numel() for p in self.parameters() if p.requires_grad)\n\nprint('\\u2705 LiquidDiffusion v2 (optimized) loaded.')"}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udd27 Build Model + Load VAE"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["device = 'cuda' if torch.cuda.is_available() else 'cpu'\nvae, vae_scale, model_in_channels = None, 1.0, 3\n\nif TRAINING_MODE == 'latent':\n from diffusers import AutoencoderKL\n print('Loading pretrained SD-VAE (stabilityai/sd-vae-ft-mse)...')\n vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse',\n torch_dtype=torch.float16 if (USE_AMP and device=='cuda') else torch.float32\n ).to(device).eval()\n vae.requires_grad_(False)\n vae_scale = vae.config.scaling_factor # 0.18215\n model_in_channels = vae.config.latent_channels # 4\n latent_size = IMAGE_SIZE // 8\n print(f' VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)')\n print(f' Latent: {IMAGE_SIZE}px \\u2192 {latent_size}x{latent_size}x{model_in_channels}')\n if device == 'cuda': print(f' VAE VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB')\nelse:\n latent_size = IMAGE_SIZE\n print('Pixel mode: no VAE')\n\nMODEL_CONFIGS = {\n 'tiny': dict(channels=[64,128,256], blocks_per_stage=[2,2,4], t_dim=256),\n 'small': dict(channels=[96,192,384], blocks_per_stage=[2,3,6], t_dim=384),\n 'base': dict(channels=[128,256,512], blocks_per_stage=[2,4,8], t_dim=512),\n}\ncfg = MODEL_CONFIGS.get(MODEL_SIZE, dict(channels=CUSTOM_CHANNELS, blocks_per_stage=CUSTOM_BLOCKS, t_dim=CUSTOM_T_DIM))\ncfg['in_channels'] = model_in_channels\n\nmodel = LiquidDiffusionUNet(**cfg).to(device)\ntotal_p, _ = model.count_params()\nprint(f'\\nLiquidDiffusion [{MODEL_SIZE}]: {total_p:,} 
({total_p/1e6:.1f}M) params')\nprint(f' in_ch={model_in_channels}, channels={cfg[\"channels\"]}, blocks={cfg[\"blocks_per_stage\"]}')\nwith torch.no_grad():\n tx = torch.randn(1, model_in_channels, latent_size, latent_size, device=device)\n to = model(tx, torch.tensor([0.5], device=device))\n print(f' Forward: {tx.shape} \\u2192 {to.shape} \\u2713'); del tx, to\nif device == 'cuda': torch.cuda.empty_cache(); print(f' Total VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcca Load Dataset"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["from PIL import Image\nfrom datasets import load_dataset\nimport matplotlib.pyplot as plt\n\nclass HFImageDataset(Dataset):\n def __init__(self, hf_data, image_size, image_column='image'):\n self.data, self.col = hf_data, image_column\n self.transform = transforms.Compose([\n transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(image_size), transforms.RandomHorizontalFlip(),\n transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])\n def __len__(self): return len(self.data)\n def __getitem__(self, idx):\n img = self.data[idx][self.col]\n if not hasattr(img, 'convert'): img = Image.fromarray(img)\n return self.transform(img.convert('RGB'))\n\nclass StreamingImageDataset(IterableDataset):\n def __init__(self, name, image_size, image_column='image'):\n self.ds, self.col = load_dataset(name, split='train', streaming=True), image_column\n self.transform = transforms.Compose([\n transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(image_size), transforms.RandomHorizontalFlip(),\n transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])\n def __iter__(self):\n for s in self.ds:\n img = s[self.col]\n if not hasattr(img, 'convert'): img = Image.fromarray(img)\n yield self.transform(img.convert('RGB'))\n\nprint(f'Loading: {DATASET}')\nif USE_STREAMING:\n dataset = StreamingImageDataset(DATASET, IMAGE_SIZE, IMAGE_COLUMN)\n dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)\n print(' Streaming mode')\nelse:\n hf_data = load_dataset(DATASET, split='train')\n if MAX_SAMPLES: hf_data = hf_data.select(range(min(MAX_SAMPLES, len(hf_data))))\n dataset = HFImageDataset(hf_data, IMAGE_SIZE, IMAGE_COLUMN)\n dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)\n print(f' {len(dataset):,} images, {len(dataloader):,} steps/epoch')\n\n# Preview\nsb = next(iter(dataloader))\nfig, axes = plt.subplots(1, min(8, sb.shape[0]), figsize=(16, 2.5))\nif not hasattr(axes, '__len__'): axes = [axes]\nfor i, ax in enumerate(axes): ax.imshow((sb[i].permute(1,2,0)*0.5+0.5).clamp(0,1)); ax.axis('off')\nplt.suptitle(f'{DATASET} ({IMAGE_SIZE}px)'); plt.tight_layout(); plt.show()\n\nif vae is not None:\n with torch.no_grad():\n ti = sb[:4].to(device, dtype=vae.dtype)\n lat = vae.encode(ti).latent_dist.sample() * vae_scale\n dec = vae.decode(lat / vae_scale).sample\n print(f'\\n VAE: {ti.shape} \\u2192 {lat.shape} \\u2192 {dec.shape}')\n print(f' Latent: mean={lat.mean():.4f}, std={lat.std():.4f}')\n fig, axes = plt.subplots(2, 4, figsize=(12, 6))\n for i in range(4):\n axes[0,i].imshow((ti[i].cpu().float().permute(1,2,0)*0.5+0.5).clamp(0,1)); axes[0,i].set_title('Original'); axes[0,i].axis('off')\n 
axes[1,i].imshow((dec[i].cpu().float().permute(1,2,0)*0.5+0.5).clamp(0,1)); axes[1,i].set_title('VAE Recon'); axes[1,i].axis('off')\n plt.suptitle('VAE Quality Check'); plt.tight_layout(); plt.show()"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\ude80 Training"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": "os.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\nos.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n\n# Optimizer + scheduler\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))\ntotal_steps = len(dataloader) * NUM_EPOCHS if not USE_STREAMING else SAMPLE_EVERY * 200\nwarmup_steps = min(1000, total_steps // 10)\ndef lr_lambda(step):\n if step < warmup_steps: return float(step) / max(1, warmup_steps)\n return max(0.0, 0.5 * (1.0 + math.cos(math.pi * (step - warmup_steps) / max(1, total_steps - warmup_steps))))\nscheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n\n# EMA\nema_model = copy.deepcopy(model).eval()\nfor p in ema_model.parameters(): p.requires_grad_(False)\nscaler = torch.amp.GradScaler('cuda', enabled=(USE_AMP and device=='cuda'))\namp_dtype = getattr(torch, AMP_DTYPE) if (USE_AMP and device=='cuda') else torch.float32\n\ndef sample_time(bs):\n eps = 1e-5\n if TIME_SAMPLING == 'uniform': return torch.rand(bs, device=device)*(1-2*eps)+eps\n return torch.sigmoid(torch.randn(bs, device=device)).clamp(eps, 1-eps)\n\nglobal_step, start_epoch, all_losses = 0, 0, []\nif RESUME_FROM and os.path.exists(RESUME_FROM):\n ckpt = torch.load(RESUME_FROM, map_location=device, weights_only=False)\n model.load_state_dict(ckpt['model']); ema_model.load_state_dict(ckpt['ema_model'])\n optimizer.load_state_dict(ckpt['optimizer'])\n global_step, start_epoch = ckpt.get('step',0), ckpt.get('epoch',0)\n all_losses = ckpt.get('losses',[]); print(f'Resumed from step {global_step}')\n\n@torch.no_grad()\ndef generate_samples(step):\n ema_model.eval()\n z = torch.randn(NUM_SAMPLE_IMAGES, model_in_channels, latent_size, latent_size, device=device)\n dt = 1.0 / NUM_EULER_STEPS\n for i in range(NUM_EULER_STEPS, 0, -1):\n t = torch.full((NUM_SAMPLE_IMAGES,), i/NUM_EULER_STEPS, device=device)\n with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'): v = ema_model(z, t)\n if USE_AMP and amp_dtype == torch.float16: v = v.float()\n z = z - v * dt\n if vae is not None: pixels = vae.decode((z / vae_scale).to(vae.dtype)).sample.float()\n else: pixels = z\n pixels = pixels.clamp(-1, 1)\n save_image(make_grid(pixels*0.5+0.5, nrow=int(math.sqrt(NUM_SAMPLE_IMAGES)), padding=2), f'{OUTPUT_DIR}/samples/step_{step:06d}.png')\n return pixels\n\n# === Verbose logging helpers ===\ndef fmt_time(seconds):\n \"\"\"Format seconds into human-readable string.\"\"\"\n if seconds < 60: return f'{seconds:.0f}s'\n if seconds < 3600: return f'{seconds/60:.1f}m'\n h = int(seconds // 3600); m = int((seconds % 3600) // 60)\n return f'{h}h{m:02d}m'\n\ndef fmt_num(n):\n \"\"\"Format large numbers with K/M suffix.\"\"\"\n if n >= 1e6: return f'{n/1e6:.1f}M'\n if n >= 1e3: return f'{n/1e3:.1f}K'\n return str(n)\n\nbest_loss = float('inf')\nloss_window_500 = [] # track last 500 for trend\n\nprint(f'\\n{\"=\"*70}')\nprint(f' \\U0001f30a LiquidDiffusion Training')\nprint(f'{\"=\"*70}')\nprint(f' Mode: {TRAINING_MODE} ({latent_size}x{latent_size}x{model_in_channels})')\nprint(f' Model: {MODEL_SIZE} ({fmt_num(total_p)} params)')\nprint(f' Dataset: 
{DATASET}')\nprint(f' Batch size: {BATCH_SIZE}')\nprint(f' Epochs: {NUM_EPOCHS}')\nprint(f' Total steps:~{total_steps:,}')\nprint(f' Warmup: {warmup_steps} steps')\nprint(f' LR: {LEARNING_RATE} (cosine \u2192 0)')\nprint(f' AMP: {USE_AMP} ({AMP_DTYPE})')\nprint(f' Device: {device}')\nif device == 'cuda':\n print(f' GPU: {torch.cuda.get_device_name(0)}')\n print(f' VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB')\n print(f' VRAM total: {torch.cuda.get_device_properties(0).total_mem/1e9:.1f} GB')\nprint(f'{\"=\"*70}\\n')\n\ntrain_start = time.time()\nepoch_losses = []\nstep_times = []\n\nfor epoch in range(start_epoch, NUM_EPOCHS):\n model.train(); epoch_loss, nb_ = 0, 0\n epoch_start = time.time()\n\n for batch_idx, pixel_batch in enumerate(dataloader):\n step_start = time.time()\n pixel_batch = pixel_batch.to(device, non_blocking=True)\n if vae is not None:\n with torch.no_grad(): x0 = vae.encode(pixel_batch.to(vae.dtype)).latent_dist.sample().float() * vae_scale\n else: x0 = pixel_batch\n x1 = torch.randn_like(x0); t = sample_time(x0.shape[0]); te = t[:,None,None,None]\n x_t = (1-te)*x0 + te*x1; v_target = x1 - x0\n with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'):\n loss = F.mse_loss(model(x_t, t), v_target)\n optimizer.zero_grad(set_to_none=True); scaler.scale(loss).backward()\n if GRAD_CLIP > 0: scaler.unscale_(optimizer); gn = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n else: gn = torch.tensor(0.0)\n scaler.step(optimizer); scaler.update(); scheduler.step()\n with torch.no_grad():\n for ep, mp in zip(ema_model.parameters(), model.parameters()): ep.data.mul_(EMA_DECAY).add_(mp.data, alpha=1-EMA_DECAY)\n\n global_step += 1; nb_ += 1\n lv = loss.item(); all_losses.append(lv); epoch_loss += lv\n step_dur = time.time() - step_start\n step_times.append(step_dur)\n loss_window_500.append(lv)\n if len(loss_window_500) > 500: loss_window_500.pop(0)\n if lv < best_loss: best_loss = lv\n\n # === VERBOSE LOGGING ===\n if global_step % LOG_EVERY == 0:\n elapsed = time.time() - train_start\n avg_loss = sum(all_losses[-LOG_EVERY:]) / LOG_EVERY\n avg_step_time = sum(step_times[-LOG_EVERY:]) / len(step_times[-LOG_EVERY:])\n sps = 1.0 / avg_step_time if avg_step_time > 0 else 0\n imgs_per_sec = sps * BATCH_SIZE\n lr = scheduler.get_last_lr()[0]\n\n # ETA\n remaining_steps = total_steps - global_step\n eta_seconds = remaining_steps * avg_step_time\n pct = (global_step / total_steps) * 100 if total_steps > 0 else 0\n\n # Loss trend\n if len(loss_window_500) >= 100:\n recent_50 = sum(loss_window_500[-50:]) / 50\n older_50 = sum(loss_window_500[-100:-50]) / 50\n trend = recent_50 - older_50\n if trend < -0.01: trend_str = f'\\u2193{abs(trend):.4f}'\n elif trend > 0.01: trend_str = f'\\u2191{trend:.4f}'\n else: trend_str = '\\u2192stable'\n else:\n trend_str = '...'\n\n # Memory\n if device == 'cuda':\n vram_used = torch.cuda.memory_allocated() / 1e9\n vram_peak = torch.cuda.max_memory_allocated() / 1e9\n mem_str = f' | VRAM: {vram_used:.1f}/{vram_peak:.1f}GB'\n else:\n mem_str = ''\n\n # Grad norm\n gn_val = gn.item() if torch.is_tensor(gn) else gn\n\n print(f'\\n Step {global_step:>6d}/{total_steps} [{pct:5.1f}%] | Epoch {epoch+1}/{NUM_EPOCHS}')\n print(f' Loss: {avg_loss:.4f} (best: {best_loss:.4f}, trend: {trend_str})')\n print(f' LR: {lr:.2e} | Grad norm: {gn_val:.3f}')\n print(f' Speed: {sps:.2f} steps/s | {imgs_per_sec:.1f} imgs/s | {avg_step_time*1000:.0f}ms/step')\n print(f' Elapsed: {fmt_time(elapsed)} | ETA: {fmt_time(eta_seconds)} | 
Remaining: {remaining_steps:,} steps')\n print(f' Samples: {global_step * BATCH_SIZE:,} images seen{mem_str}')\n\n if global_step % SAMPLE_EVERY == 0:\n print(f'\\n \\U0001f4f8 Generating {NUM_SAMPLE_IMAGES} samples at step {global_step}...')\n t0 = time.time()\n samples = generate_samples(global_step)\n print(f' Sampling took {time.time()-t0:.1f}s ({NUM_EULER_STEPS} Euler steps)')\n fig, axes = plt.subplots(1, min(8, NUM_SAMPLE_IMAGES), figsize=(16, 2.5))\n if not hasattr(axes, '__len__'): axes = [axes]\n for i, ax in enumerate(axes):\n if i < samples.shape[0]: ax.imshow((samples[i].cpu().permute(1,2,0)*0.5+0.5).clamp(0,1))\n ax.axis('off')\n plt.suptitle(f'Step {global_step} | Loss: {lv:.4f}'); plt.tight_layout(); plt.show()\n\n if global_step % SAVE_EVERY == 0:\n ckpt_path = f'{OUTPUT_DIR}/checkpoints/step_{global_step:06d}.pt'\n torch.save({'model':model.state_dict(),'ema_model':ema_model.state_dict(),'optimizer':optimizer.state_dict(),'step':global_step,'epoch':epoch,'losses':all_losses[-2000:],'config':cfg}, ckpt_path)\n ckpt_mb = os.path.getsize(ckpt_path) / 1e6\n print(f' \\U0001f4be Checkpoint saved: {ckpt_path} ({ckpt_mb:.0f}MB)')\n\n # === END OF EPOCH ===\n if nb_ > 0:\n avg_epoch = epoch_loss / nb_\n epoch_losses.append(avg_epoch)\n epoch_dur = time.time() - epoch_start\n total_elapsed = time.time() - train_start\n remaining_epochs = NUM_EPOCHS - (epoch + 1)\n epoch_eta = remaining_epochs * epoch_dur\n\n print(f'\\n {\"=\"*60}')\n print(f' Epoch {epoch+1}/{NUM_EPOCHS} complete')\n print(f' Avg loss: {avg_epoch:.4f} (best step loss: {best_loss:.4f})')\n print(f' Duration: {fmt_time(epoch_dur)} ({nb_} steps)')\n print(f' Total: {fmt_time(total_elapsed)} elapsed | ~{fmt_time(epoch_eta)} remaining')\n if len(epoch_losses) >= 2:\n delta = epoch_losses[-1] - epoch_losses[-2]\n print(f' vs prev: {delta:+.4f} ({\"improving \\u2705\" if delta < 0 else \"worse \\u26a0\\ufe0f\" if delta > 0.01 else \"flat\"})')\n print(f' {\"=\"*60}')\n\n# === FINAL ===\nfinal_path = f'{OUTPUT_DIR}/checkpoints/final.pt'\ntorch.save({'model':model.state_dict(),'ema_model':ema_model.state_dict(),'step':global_step,'config':cfg,'losses':all_losses[-2000:]}, final_path)\ntotal_time = time.time() - train_start\nprint(f'\\n{\"=\"*70}')\nprint(f' \\u2705 Training complete!')\nprint(f' Total time: {fmt_time(total_time)}')\nprint(f' Total steps: {global_step:,}')\nprint(f' Final loss: {all_losses[-1]:.4f} (best: {best_loss:.4f})')\nprint(f' Checkpoint: {final_path}')\nprint(f' Samples in: {OUTPUT_DIR}/samples/')\nprint(f'{\"=\"*70}')"}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcc8 Training Curves"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import numpy as np\nfig, (a1, a2) = plt.subplots(1, 2, figsize=(14, 5))\na1.plot(all_losses, alpha=0.3, color='blue', linewidth=0.5)\nw = min(200, max(1, len(all_losses)//5))\nif w > 1 and len(all_losses) > w:\n sm = np.convolve(all_losses, np.ones(w)/w, mode='valid')\n a1.plot(range(w-1, len(all_losses)), sm, color='red', linewidth=2, label=f'Smooth(w={w})')\na1.set_xlabel('Step'); a1.set_ylabel('Loss'); a1.set_title('Training Loss'); a1.legend(); a1.grid(True, alpha=0.3)\nif epoch_losses:\n a2.plot(range(1, len(epoch_losses)+1), epoch_losses, 'o-', color='green')\n a2.set_xlabel('Epoch'); a2.set_ylabel('Loss'); a2.set_title('Per Epoch'); a2.grid(True, alpha=0.3)\nplt.tight_layout(); plt.show()"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83c\udfa8 Generate Images"]}, {"cell_type": 
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["NUM_GENERATE = 16 #@param {type:\"integer\"}\nEULER_STEPS = 50 #@param {type:\"integer\"}\n\nprint(f'Generating {NUM_GENERATE} images ({EULER_STEPS} steps)...')\nema_model.eval()\nwith torch.no_grad():\n z = torch.randn(NUM_GENERATE, model_in_channels, latent_size, latent_size, device=device)\n dt = 1.0 / EULER_STEPS\n for i in range(EULER_STEPS, 0, -1):\n t = torch.full((NUM_GENERATE,), i/EULER_STEPS, device=device)\n with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'): v = ema_model(z, t)\n if USE_AMP and amp_dtype == torch.float16: v = v.float()\n z = z - v * dt\n if vae is not None: generated = vae.decode((z/vae_scale).to(vae.dtype)).sample.float().clamp(-1,1)\n else: generated = z.clamp(-1,1)\n\nnr = int(math.ceil(math.sqrt(NUM_GENERATE)))\nfig, axes = plt.subplots(nr, nr, figsize=(2.5*nr, 2.5*nr))\naxes = axes.flatten() if hasattr(axes, 'flatten') else [axes]\nfor i, ax in enumerate(axes):\n if i < NUM_GENERATE: ax.imshow((generated[i].cpu().permute(1,2,0)*0.5+0.5).clamp(0,1))\n ax.axis('off')\nplt.suptitle(f'LiquidDiffusion ({IMAGE_SIZE}px)', fontsize=14); plt.tight_layout(); plt.show()\nsave_image(make_grid(generated*0.5+0.5, nrow=nr, padding=2), f'{OUTPUT_DIR}/final_samples.png')\nprint(f'Saved to {OUTPUT_DIR}/final_samples.png')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcbe Save to Hub"]}, {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["PUSH_TO_HUB = False #@param {type:\"boolean\"}\nHUB_MODEL_ID = 'your-username/liquid-diffusion-model' #@param {type:\"string\"}\nif PUSH_TO_HUB:\n from huggingface_hub import HfApi\n api = HfApi(); api.create_repo(HUB_MODEL_ID, exist_ok=True)\n api.upload_file(path_or_fileobj=f'{OUTPUT_DIR}/checkpoints/final.pt', path_in_repo='model.pt', repo_id=HUB_MODEL_ID)\n print(f'Pushed to https://huggingface.co/{HUB_MODEL_ID}')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["---\n", "## \ud83d\udcd6 Architecture\n", "\n", "### CfC Time-Gating\n", "```\n", "gate = \u03c3(time_a(t) \u00b7 f(features) - time_b(t))\n", "out = gate \u00b7 g + (1-gate) \u00b7 h\n", "\u03b1 = exp(-\u03bb\u00b7|t|) \u2192 time-aware residual\n", "```\n", "\n", "### Latent Training Pipeline\n", "```\n", "pixel (3\u00d7256\u00d7256) \u2192 [SD-VAE encode] \u2192 latent (4\u00d732\u00d732) \u2192 [LiquidDiffusion] \u2192 [SD-VAE decode] \u2192 pixel\n", "```\n", "\n", "### References\n", "- [CfC (Nature MI 2022)](https://arxiv.org/abs/2106.13898)\n", "- [LiquidTAD](https://arxiv.org/abs/2604.18274)\n", "- [Rectified Flow (ICLR 2023)](https://arxiv.org/abs/2209.03003)\n", "- [SD-VAE ft-MSE](https://huggingface.co/stabilityai/sd-vae-ft-mse)"]}]}