| """ |
| LiquidGen Training Pipeline v2 |
| |
| Optimized for Colab free tier: |
| - Fast VAE encoding: batch=64 for 256px, batch=32 for 512px (~5x faster) |
| - Auto-limits large datasets (WikiArt capped at 10K by default) |
| - Latent pre-caching: train on pure tensors, no VAE during training |
| - Gradient checkpointing + auto batch size = no OOM |
| - ETA shown on every log line |
| - All datasets pure parquet, open SDXL VAE (no login) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from torch.amp import autocast, GradScaler |
| import math |
| import os |
| import json |
| import time |
| from dataclasses import dataclass, asdict |
|
|
|
|
# Hugging Face datasets selectable through ``TrainConfig.dataset_preset``.
# Per-preset fields:
#   name         - Hub repo id handed to ``datasets.load_dataset``
#   config       - dataset config name; "" means the default config
#   image_column - column holding the image objects (converted to RGB before encoding)
#   label_column - column holding per-image labels; "" means unconditional
#   num_classes  - NOTE(review): not read by the code visible in this file; the
#                  model's class count comes from ``TrainConfig.num_classes``
#   max_default  - cap on encoded images when ``TrainConfig.max_images`` is 0 (0 = no cap)
#   description  - human-readable summary
DATASET_PRESETS = {
    "cartoon": dict(
        name="Norod78/cartoon-blip-captions",
        config="",
        image_column="image",
        label_column="",
        num_classes=0,
        max_default=0,
        description="~2.5K cartoon/anime, unconditional, 181MB — fast",
    ),
    "flowers": dict(
        name="huggan/flowers-102-categories",
        config="",
        image_column="image",
        label_column="",
        num_classes=0,
        max_default=0,
        description="~8K flower photos, unconditional, 331MB",
    ),
    # Labels from the "style" column are collected into the cache, but they are
    # only fed to the model when TrainConfig.num_classes > 0.
    "wikiart": dict(
        name="Artificio/WikiArt",
        config="",
        image_column="image",
        label_column="style",
        num_classes=0,
        max_default=10000,
        description="~105K paintings with styles (auto-capped to 10K for speed)",
    ),
    "art_painting": dict(
        name="huggan/few-shot-art-painting",
        config="",
        image_column="image",
        label_column="",
        num_classes=0,
        max_default=0,
        description="~6K art paintings, unconditional, 511MB",
    ),
}
|
|
|
|
def auto_batch_size(model_size, image_size, gpu_mem_gb):
    """Heuristically pick a training batch size that should fit in GPU memory.

    Subtracts a rough fixed cost for the model size (GB; presumably weights
    plus optimizer state) and a 1.5 GB margin from ``gpu_mem_gb``, divides the
    remainder by an estimated per-sample activation cost, and snaps the result
    down to 32, 16, 8 or 4. When fewer than 4 samples fit, the raw count is
    returned. The result is never below 1.

    Args:
        model_size: "small", "base" or "large"; unknown sizes use generic estimates.
        image_size: pixel resolution; 256 and 512 have tuned estimates.
        gpu_mem_gb: total GPU memory in GB.

    Returns:
        int batch size >= 1.
    """
    fixed_gb = {"small": 0.66, "base": 1.68, "large": 3.35}.get(model_size, 1.0)
    per_sample_gb = {
        ("small", 256): 0.02, ("small", 512): 0.07,
        ("base", 256): 0.03, ("base", 512): 0.13,
        ("large", 256): 0.05, ("large", 512): 0.21,
    }.get((model_size, image_size), 0.1)

    headroom_gb = gpu_mem_gb - fixed_gb - 1.5
    fits = max(1, int(headroom_gb / per_sample_gb))

    for bucket in (32, 16, 8, 4):
        if fits >= bucket:
            return bucket
    return fits
|
|
|
|
def _fmt_time(seconds):
    """Format a duration in seconds as a short string: "42s", "3.5m" or "1.2h".

    Unit thresholds are checked against the value as it will be displayed,
    so 59.7 s prints as "1.0m" (not "60s") and 3599 s as "1.0h" (not "60.0m").
    ``round(x, ndigits)`` is used rather than ``round(x)`` because it passes
    NaN/inf through instead of raising.
    """
    if round(seconds, 0) < 60:
        return f"{seconds:.0f}s"
    if round(seconds / 60, 1) < 60:
        return f"{seconds/60:.1f}m"
    return f"{seconds/3600:.1f}h"
|
|
|
|
@dataclass
class TrainConfig:
    """Hyperparameters and I/O settings consumed by ``train`` and ``precache_latents``.

    Step-based settings (warmup and the save/sample/log intervals) count
    optimizer steps, i.e. steps taken after gradient accumulation.
    """

    # --- model ---
    model_size: str = "small"  # key into get_model_config: "small" | "base" | "large"
    num_classes: int = 0  # 0 = unconditional; train() then passes labels=None to the model
    class_drop_prob: float = 0.1  # forwarded to the model config; presumably the label-dropout rate for CFG (confirm in model.py)
    # --- data ---
    dataset_preset: str = "cartoon"  # key into DATASET_PRESETS
    image_size: int = 256  # pixel resolution; train() assumes latents of image_size // 8 per side
    max_images: int = 0  # >0 caps the number of encoded images; 0 = preset max_default, else the whole split
    # --- VAE ---
    vae_id: str = "madebyollin/sdxl-vae-fp16-fix"
    vae_scaling_factor: float = 0.13025  # latents are multiplied by this after encoding and divided by it before decoding
    latent_channels: int = 4  # model in_channels; presumably must match the VAE latent channel count
    # --- optimization ---
    batch_size: int = 0  # 0 = auto_batch_size on GPU, 4 on CPU
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4  # peak AdamW LR, reached at the end of warmup
    weight_decay: float = 0.01
    max_grad_norm: float = 2.0  # gradient-norm clipping threshold
    num_epochs: int = 100
    warmup_steps: int = 500  # linear warmup, then cosine decay towards 0
    ema_decay: float = 0.9999
    mixed_precision: bool = True  # autocast + GradScaler; only effective when CUDA is available
    gradient_checkpointing: bool = True
    # --- training timestep range for flow matching ---
    min_timestep: float = 0.001
    max_timestep: float = 0.999
    # --- output / logging (intervals in optimizer steps) ---
    output_dir: str = "./outputs"
    save_every_n_steps: int = 2000
    sample_every_n_steps: int = 500
    log_every_n_steps: int = 25
    # --- preview sampling ---
    num_sample_steps: int = 50  # Euler integration steps per preview
    cfg_scale: float = 2.0  # classifier-free guidance; applied only when > 1.0 and labels are given
    num_samples: int = 4
    seed: int = 42  # passed to torch.manual_seed at the start of train()
    num_workers: int = 2  # DataLoader worker processes
    # NOTE(review): the three fields below are not read by any code visible in this file.
    compile_model: bool = False
    push_to_hub: bool = False
    hub_model_id: str = ""
|
|
|
|
def get_model_config(size, num_classes=0, class_drop_prob=0.1):
    """Return the LiquidGen constructor kwargs for a named model size.

    Args:
        size: one of "small", "base", "large".
        num_classes: number of label classes (0 = unconditional).
        class_drop_prob: forwarded unchanged into the returned config.

    Returns:
        A fresh dict, safe for the caller to mutate (``train`` adds ``in_channels``).

    Raises:
        KeyError: if ``size`` is not a known preset.
    """
    #            embed_dim  depth  expand_ratio  mlp_ratio
    presets = {
        "small": (512,      12,    2.0,          3.0),
        "base":  (640,      18,    2.0,          4.0),
        "large": (768,      24,    2.5,          4.0),
    }
    embed_dim, depth, expand_ratio, mlp_ratio = presets[size]
    return {
        "embed_dim": embed_dim,
        "depth": depth,
        "spatial_kernel": 7,
        "scan_kernel": 31,
        "expand_ratio": expand_ratio,
        "mlp_ratio": mlp_ratio,
        "num_classes": num_classes,
        "class_drop_prob": class_drop_prob,
        "use_zigzag": True,
    }
|
|
|
|
class CachedLatentDataset(Dataset):
    """Map-style dataset over a latent cache file.

    The file is a ``torch.load``-able dict with a ``"latents"`` tensor (first
    dimension = sample index) and an optional ``"labels"`` tensor, where -1
    marks an unlabeled sample. Items are ``(latent, label)``. The label is -1
    when the cache has no labels.
    """

    def __init__(self, cache_path):
        payload = torch.load(cache_path, map_location="cpu", weights_only=True)
        latents = payload["latents"]
        labels = payload.get("labels", None)
        self.latents = latents
        self.labels = labels
        print(f"Loaded {len(latents)} cached latents: {latents.shape}")
        if labels is None:
            return
        known = labels[labels >= 0]
        if known.numel() > 0:
            print(f" {known.unique().shape[0]} classes")

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, idx):
        latent = self.latents[idx]
        if self.labels is None:
            return latent, -1
        return latent, self.labels[idx]
|
|
|
|
def _encode_pixels(vae, pixels, device, scaling_factor):
    """Encode a list of (3, H, W) image tensors in [0, 1] into scaled fp32 latents on CPU."""
    with torch.no_grad():
        # ToTensor yields [0, 1]; the VAE expects [-1, 1].
        px = torch.stack(pixels).to(device, dtype=torch.float16) * 2 - 1
        lat = vae.encode(px).latent_dist.sample() * scaling_factor
    return lat.cpu().float()


def _label_for(item, lbl_col, style_to_id):
    """Map one dataset row to an integer label; -1 means "no usable label".

    String labels get ids in first-seen order (``style_to_id`` is updated in
    place). Int labels pass through unchanged. Any other type maps to -1.
    """
    if not lbl_col or lbl_col not in item:
        return -1
    raw = item[lbl_col]
    if isinstance(raw, str):
        return style_to_id.setdefault(raw, len(style_to_id))
    if isinstance(raw, int):
        return raw
    return -1


def precache_latents(config, cache_path=None):
    """Encode the configured dataset with the VAE once and cache the latents.

    The cache is a dict with ``"latents"`` (N, C, h, w) fp32, ``"labels"``
    (N,) long with -1 for unlabeled rows, and optionally ``"style_to_id"`` when
    string labels were seen.

    Args:
        config: TrainConfig-like object (output_dir, image_size, vae_id,
            vae_scaling_factor, dataset_preset, max_images).
        cache_path: where to read/write the cache; defaults to
            ``<output_dir>/cached_latents.pt``.

    Returns:
        The cache path. An existing file is reused as-is.

    Raises:
        RuntimeError: if no images were encoded (empty split or zero cap).
    """
    if cache_path is None:
        cache_path = os.path.join(config.output_dir, "cached_latents.pt")
    if os.path.exists(cache_path):
        print(f"Cache exists: {cache_path}")
        d = torch.load(cache_path, map_location="cpu", weights_only=True)
        print(f" {d['latents'].shape[0]} latents {d['latents'].shape[1:]}")
        # The cache is keyed only by its path, so changing the dataset or
        # resolution does not invalidate it. Warn about the mismatch we can
        # detect: train() samples at image_size // 8.
        expected = config.image_size // 8
        cached = d["latents"].shape[-1]
        if cached != expected:
            print(f" WARNING: cached latents are {cached} wide but image_size={config.image_size} "
                  f"expects {expected}; delete {cache_path} to re-encode")
        return cache_path

    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Loading VAE: {config.vae_id}...")
    from diffusers import AutoencoderKL
    vae = AutoencoderKL.from_pretrained(config.vae_id, torch_dtype=torch.float16).to(device).eval()
    for p in vae.parameters():
        p.requires_grad_(False)

    preset = DATASET_PRESETS[config.dataset_preset]
    print(f"Dataset: {preset['name']}")
    from datasets import load_dataset
    from torchvision import transforms

    ds_kwargs = {"split": "train"}
    if preset["config"]:
        ds_kwargs["name"] = preset["config"]
    dataset = load_dataset(preset["name"], **ds_kwargs)

    # Shorter side -> image_size, then a centered square crop.
    transform = transforms.Compose([
        transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.CenterCrop(config.image_size), transforms.ToTensor(),
    ])

    if config.max_images > 0:
        max_imgs = config.max_images
    elif preset.get("max_default", 0) > 0:
        max_imgs = preset["max_default"]
        print(f" Auto-capping to {max_imgs} images (set max_images to override)")
    else:
        max_imgs = len(dataset)
    max_imgs = min(max_imgs, len(dataset))

    encode_bs = 64 if config.image_size <= 256 else 32
    print(f" Encoding {max_imgs} images (batch={encode_bs})...")

    img_col, lbl_col = preset["image_column"], preset["label_column"]
    style_to_id = {}
    all_latents, all_labels = [], []
    batch_px, batch_lb = [], []
    count = 0
    t0 = time.time()

    for item in dataset:
        if count >= max_imgs:
            break
        img = item[img_col]
        if img.mode != "RGB":
            img = img.convert("RGB")
        batch_px.append(transform(img))
        batch_lb.append(_label_for(item, lbl_col, style_to_id))
        count += 1
        if len(batch_px) >= encode_bs:
            all_latents.append(_encode_pixels(vae, batch_px, device, config.vae_scaling_factor))
            all_labels.extend(batch_lb)
            batch_px, batch_lb = [], []
            elapsed = max(time.time() - t0, 1e-9)
            speed = count / elapsed
            eta = (max_imgs - count) / speed if speed > 0 else 0
            if count % (encode_bs * 4) == 0:
                print(f" {count}/{max_imgs} | {speed:.0f} img/s | ETA {_fmt_time(eta)}")

    # Flush the final partial batch.
    if batch_px:
        all_latents.append(_encode_pixels(vae, batch_px, device, config.vae_scaling_factor))
        all_labels.extend(batch_lb)

    if not all_latents:
        raise RuntimeError(
            f"No images were encoded from {preset['name']} (max_images={max_imgs}); nothing to cache"
        )

    all_latents = torch.cat(all_latents, dim=0)
    all_labels = torch.tensor(all_labels, dtype=torch.long)
    save_data = {"latents": all_latents, "labels": all_labels}
    if style_to_id:
        save_data["style_to_id"] = style_to_id
        print(f" {len(style_to_id)} style classes")

    # Write to a temp file and rename. An interrupted save (e.g. a Colab
    # disconnect) then never leaves a truncated file that the cache-exists
    # branch above would accept on the next run.
    tmp_path = cache_path + ".tmp"
    torch.save(save_data, tmp_path)
    os.replace(tmp_path, cache_path)

    mb = os.path.getsize(cache_path) / 1024**2
    print(f"Cached {count} latents -> {cache_path} ({mb:.0f}MB, {_fmt_time(time.time()-t0)})")
    del vae
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return cache_path
|
|
|
|
class EMAModel:
    """Exponential moving average of a model's trainable parameters.

    ``shadow`` maps each trainable parameter name (from ``named_parameters``)
    to a detached copy. ``apply`` swaps the shadow weights into the model and
    remembers the live weights. ``restore`` swaps them back.
    """

    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
        # Initialised here so restore() before any apply() is a harmless no-op
        # instead of an AttributeError.
        self.backup = {}

    @torch.no_grad()
    def update(self, model):
        """shadow <- decay * shadow + (1 - decay) * param, for each tracked parameter."""
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.shadow:
                self.shadow[n].mul_(self.decay).add_(p.data, alpha=1 - self.decay)

    def apply(self, model):
        """Load the EMA weights into ``model``, backing up the current weights.

        If a backup is already pending (apply called twice without restore),
        it is kept. Otherwise the second call would overwrite it with EMA
        weights and the training weights would be lost.
        """
        if not self.backup:
            self.backup = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad}
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.shadow:
                p.data.copy_(self.shadow[n])

    def restore(self, model):
        """Put back the weights saved by the last ``apply``; no-op if nothing is pending."""
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.backup:
                p.data.copy_(self.backup[n])
        self.backup = {}
|
|
|
|
class FlowMatchingScheduler:
    """Flow-matching noise path and Euler sampler.

    Uses the straight-line path ``x_t = (1 - t) * x0 + t * noise`` (t=0 is
    data, t=1 is noise), whose velocity target is ``noise - x0``. Sampling
    integrates the predicted velocity from t=1 down to t=0.
    """

    def __init__(self, min_t=0.001, max_t=0.999):
        self.min_t, self.max_t = min_t, max_t

    def sample_timesteps(self, bs, dev):
        """Draw ``bs`` timesteps uniformly from [min_t, max_t)."""
        return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t

    def add_noise(self, x0, noise, t):
        """Interpolate towards ``noise`` at per-sample time ``t`` of shape (B,); x0 is 4-D."""
        t = t.view(-1, 1, 1, 1)
        return (1 - t) * x0 + t * noise

    def get_velocity_target(self, x0, noise):
        """Velocity d x_t / d t of the linear path."""
        return noise - x0

    @torch.no_grad()
    def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):
        """Generate samples of ``shape`` with ``num_steps`` Euler steps.

        Leaves ``model`` in eval mode; the caller restores train mode. When
        ``cfg > 1`` and ``labels`` are given, applies classifier-free guidance:
        ``v = v_uncond + cfg * (v_cond - v_uncond)``.
        """
        model.eval()
        x = torch.randn(shape, device=dev)
        dt = 1.0 / num_steps
        # Autocast is only enabled on CUDA. Hard-coding "cuda" warned and
        # silently disabled itself on CPU.
        dev_type = torch.device(dev).type
        # linspace(1, dt, N) has spacing exactly dt, so the final update lands on t=0.
        for tv in torch.linspace(1.0, dt, num_steps, device=dev):
            t = torch.full((shape[0],), tv.item(), device=dev)
            with torch.amp.autocast(dev_type, enabled=dev_type == "cuda"):
                if cfg > 1.0 and labels is not None:
                    vc = model(x, t, labels)
                    # NOTE(review): zeros are used as the "unconditional" label,
                    # which equals real class 0 unless the model reserves 0 as
                    # its null class -- confirm against LiquidGen's label embedding.
                    vu = model(x, t, torch.zeros_like(labels))
                    v = vu + cfg * (vc - vu)
                else:
                    v = model(x, t, labels)
            x = x - dt * v.float()
        return x
|
|
|
|
def cosine_schedule(opt, warmup, total):
    """Return a LambdaLR with linear warmup then cosine decay to zero.

    The multiplier rises linearly from 0 to 1 over ``warmup`` steps, then
    follows half a cosine from 1 down to 0 at step ``total``. Past ``total``
    it stays at 0. Previously the cosine kept going past its trough and the
    LR climbed back up.

    Args:
        opt: the optimizer to schedule.
        warmup: number of warmup steps (0 disables warmup).
        total: total number of scheduler steps.
    """
    decay_span = max(1, total - warmup)

    def lr(s):
        if s < warmup:
            return s / max(1, warmup)
        progress = min(1.0, (s - warmup) / decay_span)
        return max(0.0, 0.5 * (1 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(opt, lr)
|
|
|
|
def train(config):
    """Train a LiquidGen model on cached VAE latents with a flow-matching loss.

    Steps:
      1. Pick the device and, if ``config.batch_size <= 0``, a batch size.
      2. Pre-encode the dataset with ``precache_latents``.
      3. Build the model.
      4. Optimize with AdamW, linear-warmup/cosine LR and EMA.

    Every ``log_every_n_steps`` / ``sample_every_n_steps`` /
    ``save_every_n_steps`` optimizer steps it logs progress, writes EMA
    preview images, and writes a checkpoint. A final checkpoint is written
    when all epochs finish.

    Args:
        config: TrainConfig. ``config.batch_size`` is overwritten in place
            when it is <= 0.

    Raises:
        ValueError: if the dataset cannot fill one batch (the DataLoader uses
            ``drop_last=True``, so zero training steps would run).

    Returns None. Returns early, without a final checkpoint, if a logged
    average loss is NaN or above 50.
    """
    from model import LiquidGen  # project-local model definition

    torch.manual_seed(config.seed)
    has_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if has_cuda else "cpu")
    # AMP (autocast + GradScaler) is only used on CUDA.
    use_amp = config.mixed_precision and has_cuda
    gpu_mem = 0
    if has_cuda:
        # The device-properties field is ``total_memory``. ``total_mem`` does
        # not exist and raised AttributeError on every GPU run.
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU: {torch.cuda.get_device_name(0)} ({gpu_mem:.1f} GB)")

    if config.batch_size <= 0:
        config.batch_size = auto_batch_size(config.model_size, config.image_size, gpu_mem) if gpu_mem > 0 else 4
        print(f"Auto batch: {config.batch_size}")

    os.makedirs(config.output_dir, exist_ok=True)
    os.makedirs(f"{config.output_dir}/samples", exist_ok=True)
    os.makedirs(f"{config.output_dir}/checkpoints", exist_ok=True)

    cache_path = precache_latents(config)
    train_ds = CachedLatentDataset(cache_path)
    # Pinned host memory only helps (and only avoids a warning) with a GPU.
    train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
                          num_workers=config.num_workers, pin_memory=has_cuda, drop_last=True)
    if len(train_dl) == 0:
        raise ValueError(
            f"{len(train_ds)} cached latents is fewer than batch_size={config.batch_size}; "
            "with drop_last=True no training step would run"
        )

    mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
    mcfg["in_channels"] = config.latent_channels
    model = LiquidGen(**mcfg).to(device)
    if config.gradient_checkpointing:
        model.enable_gradient_checkpointing()
    print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M (ckpt={'ON' if config.gradient_checkpointing else 'OFF'})")

    opt = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
                            weight_decay=config.weight_decay, betas=(0.9, 0.999))
    total_steps = len(train_dl) * config.num_epochs // config.gradient_accumulation_steps
    sched = cosine_schedule(opt, config.warmup_steps, total_steps)
    ema = EMAModel(model, config.ema_decay)
    scaler = GradScaler("cuda", enabled=use_amp)
    fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)
    lat_size = config.image_size // 8
    print(f"Steps: {total_steps} | Batch: {config.batch_size} | Epochs: {config.num_epochs}")

    gs = 0      # optimizer steps taken
    la = 0.0    # loss accumulated since the last log line
    vae = None  # decoder for preview images; loaded lazily at the first sample step
    print(f"\nTraining!\n")
    t_start = time.time()

    for epoch in range(config.num_epochs):
        model.train()
        et = time.time()
        # NOTE(review): when len(train_dl) is not a multiple of
        # gradient_accumulation_steps, the leftover micro-batch gradients are
        # carried into the first optimizer step of the next epoch.
        for bi, (lats, lbls) in enumerate(train_dl):
            lats = lats.to(device)
            lbls = lbls.to(device) if config.num_classes > 0 else None
            t = fm.sample_timesteps(lats.shape[0], device)
            noise = torch.randn_like(lats)
            xt = fm.add_noise(lats, noise, t)
            vtgt = fm.get_velocity_target(lats, noise)
            with autocast(device.type, enabled=use_amp):
                loss = F.mse_loss(model(xt, t, lbls), vtgt) / config.gradient_accumulation_steps
            scaler.scale(loss).backward()
            la += loss.item()
            if (bi + 1) % config.gradient_accumulation_steps != 0:
                continue

            # --- optimizer step ---
            scaler.unscale_(opt)  # so clipping sees true gradient magnitudes
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
            sched.step()
            ema.update(model)
            gs += 1

            if gs % config.log_every_n_steps == 0:
                al = la / config.log_every_n_steps
                elapsed = time.time() - t_start
                sps = gs / max(elapsed, 1)
                remaining = (total_steps - gs) / sps if sps > 0 else 0
                vram = torch.cuda.memory_allocated()/1024**3 if has_cuda else 0
                pct = gs / total_steps * 100
                print(f"step={gs:>6d}/{total_steps} ({pct:.0f}%) | ep={epoch} | "
                      f"loss={al:.4f} | gn={gn:.2f} | lr={opt.param_groups[0]['lr']:.2e} | "
                      f"vram={vram:.1f}G | {sps:.1f} st/s | ETA {_fmt_time(remaining)}")
                la = 0.0
                if math.isnan(al) or al > 50:
                    print("Diverged!")
                    return

            if gs % config.sample_every_n_steps == 0:
                if vae is None:
                    from diffusers import AutoencoderKL
                    vae = AutoencoderKL.from_pretrained(config.vae_id, torch_dtype=torch.float16).to(device).eval()
                    for p in vae.parameters():
                        p.requires_grad_(False)
                ema.apply(model)
                try:
                    model.eval()
                    sl = (torch.randint(0, config.num_classes, (config.num_samples,), device=device)
                          if config.num_classes > 0 else None)
                    samp = fm.sample(model, (config.num_samples, config.latent_channels, lat_size, lat_size),
                                     device, config.num_sample_steps, sl, config.cfg_scale)
                    with torch.no_grad():
                        imgs = ((vae.decode(samp.half() / config.vae_scaling_factor).sample + 1) / 2).clamp(0, 1).float()
                    from torchvision.utils import save_image
                    save_image(imgs, f"{config.output_dir}/samples/step_{gs:07d}.png", nrow=2)
                    print(f" Saved samples")
                finally:
                    # Always put the training weights back, even if sampling fails.
                    ema.restore(model)
                    model.train()

            if gs % config.save_every_n_steps == 0:
                torch.save({"model": model.state_dict(), "ema": ema.shadow,
                            "optimizer": opt.state_dict(), "step": gs, "model_config": mcfg},
                           f"{config.output_dir}/checkpoints/step_{gs:07d}.pt")

        ep_time = time.time() - et
        ep_eta = ep_time * (config.num_epochs - epoch - 1)
        print(f"Epoch {epoch}/{config.num_epochs} done | {_fmt_time(ep_time)} | ETA {_fmt_time(ep_eta)}\n")

    final = f"{config.output_dir}/checkpoints/final.pt"
    torch.save({"model": model.state_dict(), "ema": ema.shadow, "model_config": mcfg, "step": gs}, final)
    total_time = time.time() - t_start
    print(f"\nDone! {gs} steps in {_fmt_time(total_time)} -> {final}")
|
|