| """ |
| Rectified Flow Training for LiquidDiffusion |
| |
| Training Objective (Rectified Flow): |
| x_t = (1-t)*x0 + t*x1, t ~ U[0,1], x1 ~ N(0,I) |
| v_target = x1 - x0 (constant velocity) |
| L = E[||v_θ(x_t, t) - v_target||²] (simple MSE) |
| |
| Sampling (Euler ODE): |
| Start from x_1 ~ N(0,I), integrate backward: |
| x_{t-dt} = x_t - v_θ(x_t, t) * dt |
| |
| References: |
| [1] Liu et al., "Flow Straight and Fast: Rectified Flow", ICLR 2023 |
| [2] Lee et al., "Improving the Training of Rectified Flows", 2024 |
| """ |
|
|
| import math |
| import copy |
| import os |
| import time |
| import json |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision import transforms |
| from torchvision.utils import save_image, make_grid |
|
|
|
|
| class RectifiedFlowTrainer: |
| """Trainer for LiquidDiffusion using Rectified Flow objective. |
| |
| Features: |
| - Simple MSE velocity loss (no noise schedule to tune) |
| - Optional logit-normal time sampling (from SD3) |
| - EMA model for stable sampling |
| - Gradient clipping, mixed precision |
| - Checkpoint save/load with resume support |
| """ |
| |
| def __init__(self, model, optimizer=None, lr=1e-4, weight_decay=0.01, |
| ema_decay=0.9999, grad_clip=1.0, time_sampling="logit_normal", |
| logit_normal_mean=0.0, logit_normal_std=1.0, device="cuda", |
| use_amp=True, amp_dtype="float16"): |
| self.device = device |
| self.model = model.to(device) |
| self.ema_decay = ema_decay |
| self.grad_clip = grad_clip |
| self.time_sampling = time_sampling |
| self.logit_normal_mean = logit_normal_mean |
| self.logit_normal_std = logit_normal_std |
| |
| |
| self.use_amp = use_amp and (device == "cuda") |
| self.amp_dtype = getattr(torch, amp_dtype) if self.use_amp else torch.float32 |
| |
| if optimizer is None: |
| self.optimizer = torch.optim.AdamW( |
| model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999)) |
| else: |
| self.optimizer = optimizer |
| |
| |
| if self.use_amp and amp_dtype == "float16": |
| self.scaler = torch.amp.GradScaler("cuda", enabled=True) |
| else: |
| self.scaler = torch.amp.GradScaler("cuda", enabled=False) |
| |
| self.ema_model = self._build_ema() |
| self.step = 0 |
| self.losses = [] |
| |
| def _build_ema(self): |
| """Create EMA copy of model.""" |
| ema = copy.deepcopy(self.model) |
| ema.eval() |
| for p in ema.parameters(): |
| p.requires_grad_(False) |
| return ema |
| |
| @torch.no_grad() |
| def _update_ema(self): |
| """Update EMA weights: ema = decay * ema + (1-decay) * model""" |
| for ema_p, model_p in zip(self.ema_model.parameters(), self.model.parameters()): |
| ema_p.data.mul_(self.ema_decay).add_(model_p.data, alpha=1 - self.ema_decay) |
| |
| def _sample_time(self, batch_size): |
| """Sample timesteps. logit_normal puts more weight near t=0.5.""" |
| eps = 1e-5 |
| if self.time_sampling == "uniform": |
| return torch.rand(batch_size, device=self.device) * (1 - 2*eps) + eps |
| elif self.time_sampling == "logit_normal": |
| u = torch.randn(batch_size, device=self.device) * self.logit_normal_std + self.logit_normal_mean |
| return torch.sigmoid(u).clamp(eps, 1 - eps) |
| raise ValueError(f"Unknown time_sampling: {self.time_sampling}") |
| |
| def train_step(self, x0): |
| """Single training step. x0: [B,C,H,W] images in [-1,1].""" |
| self.model.train() |
| x0 = x0.to(self.device) |
| x1 = torch.randn_like(x0) |
| t = self._sample_time(x0.shape[0]) |
| t_expand = t[:, None, None, None] |
| x_t = (1 - t_expand) * x0 + t_expand * x1 |
| v_target = x1 - x0 |
| |
| |
| if self.use_amp: |
| with torch.amp.autocast("cuda", dtype=self.amp_dtype): |
| v_pred = self.model(x_t, t) |
| loss = F.mse_loss(v_pred, v_target) |
| else: |
| v_pred = self.model(x_t, t) |
| loss = F.mse_loss(v_pred, v_target) |
| |
| |
| self.optimizer.zero_grad(set_to_none=True) |
| self.scaler.scale(loss).backward() |
| |
| if self.grad_clip > 0: |
| self.scaler.unscale_(self.optimizer) |
| grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) |
| else: |
| grad_norm = torch.tensor(0.0) |
| |
| self.scaler.step(self.optimizer) |
| self.scaler.update() |
| self._update_ema() |
| |
| self.step += 1 |
| loss_val = loss.item() |
| self.losses.append(loss_val) |
| return { |
| 'loss': loss_val, |
| 'grad_norm': grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm, |
| 'step': self.step, |
| } |
| |
| @torch.no_grad() |
| def sample(self, batch_size=4, image_size=256, channels=3, num_steps=50, use_ema=True): |
| """Generate images via Euler ODE integration from noise → data.""" |
| model = self.ema_model if use_ema else self.model |
| model.eval() |
| z = torch.randn(batch_size, channels, image_size, image_size, device=self.device) |
| dt = 1.0 / num_steps |
| |
| for i in range(num_steps, 0, -1): |
| t = torch.full((batch_size,), i / num_steps, device=self.device) |
| if self.use_amp: |
| with torch.amp.autocast("cuda", dtype=self.amp_dtype): |
| v = model(z, t) |
| v = v.float() |
| else: |
| v = model(z, t) |
| z = z - v * dt |
| |
| return z.clamp(-1, 1) |
| |
| def save_checkpoint(self, path, extra=None): |
| """Save model, EMA, optimizer, scaler, and training state.""" |
| ckpt = { |
| 'model': self.model.state_dict(), |
| 'ema_model': self.ema_model.state_dict(), |
| 'optimizer': self.optimizer.state_dict(), |
| 'scaler': self.scaler.state_dict(), |
| 'step': self.step, |
| 'losses': self.losses[-1000:], |
| } |
| if extra: |
| ckpt.update(extra) |
| dir_path = os.path.dirname(path) |
| if dir_path: |
| os.makedirs(dir_path, exist_ok=True) |
| torch.save(ckpt, path) |
| |
| def load_checkpoint(self, path): |
| """Load checkpoint and resume training.""" |
| ckpt = torch.load(path, map_location=self.device, weights_only=False) |
| self.model.load_state_dict(ckpt['model']) |
| self.ema_model.load_state_dict(ckpt['ema_model']) |
| self.optimizer.load_state_dict(ckpt['optimizer']) |
| if 'scaler' in ckpt: |
| self.scaler.load_state_dict(ckpt['scaler']) |
| self.step = ckpt.get('step', 0) |
| self.losses = ckpt.get('losses', []) |
| print(f"Resumed from step {self.step}") |
|
|
|
|
class ImageDataset(Dataset):
    """Image dataset from a local folder or the HuggingFace Hub.

    Yields [3, image_size, image_size] tensors normalized to [-1, 1]
    (resize shorter side, center crop, random horizontal flip).

    Usage:
        # From HuggingFace
        ds = ImageDataset("huggan/CelebA-HQ", image_size=256)

        # From local folder
        ds = ImageDataset("/path/to/images", image_size=256)

        # With pre-loaded HF dataset
        from datasets import load_dataset
        hf_ds = load_dataset("huggan/CelebA-HQ", split="train")
        ds = ImageDataset(None, image_size=256, hf_dataset=hf_ds)
    """

    def __init__(self, source, image_size=256, split="train",
                 image_column="image", max_samples=None, hf_dataset=None):
        """Resolve the image source and build the transform pipeline.

        Args:
            source: local folder path or HF dataset id; may be None when
                ``hf_dataset`` is supplied.
            image_size: output resolution (square).
            split: HF split name (hub mode only).
            image_column: column holding the image in HF rows.
            max_samples: optional cap on the dataset size.
            hf_dataset: pre-loaded HF dataset, takes precedence over source.

        Raises:
            ValueError: if neither ``source`` nor ``hf_dataset`` is given.
            FileNotFoundError: if a folder source contains no images.
        """
        self.image_size = image_size
        self.image_column = image_column
        self.transform = transforms.Compose([
            transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.CenterCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # [0,1] -> [-1,1], broadcast over channels
        ])

        if hf_dataset is not None:
            self.data = hf_dataset
            self.mode = "hf"
        elif source and os.path.isdir(source):
            from glob import glob
            self.files = []
            for ext in ['*.png', '*.jpg', '*.jpeg', '*.webp', '*.bmp']:
                self.files.extend(glob(os.path.join(source, '**', ext), recursive=True))
            self.files.sort()  # deterministic ordering across filesystems
            # FIX: fail loudly instead of silently producing a 0-length dataset.
            if not self.files:
                raise FileNotFoundError(f"No images found under folder: {source}")
            if max_samples:
                self.files = self.files[:max_samples]
            self.mode = "folder"
        elif source:
            from datasets import load_dataset
            self.data = load_dataset(source, split=split)
            if max_samples:
                self.data = self.data.select(range(min(max_samples, len(self.data))))
            self.mode = "hf"
        else:
            # FIX: previously fell through to load_dataset(None), which
            # crashed with an opaque error deep inside `datasets`.
            raise ValueError(
                "Provide `source` (folder path or HF dataset id) or `hf_dataset`.")

    def __len__(self):
        """Number of images available."""
        return len(self.files) if self.mode == "folder" else len(self.data)

    def __getitem__(self, idx):
        """Load image ``idx``, convert to RGB, and apply the transform."""
        if self.mode == "folder":
            from PIL import Image
            img = Image.open(self.files[idx]).convert("RGB")
        else:
            img = self.data[idx][self.image_column]
            # HF rows may hold raw arrays rather than PIL images.
            if not hasattr(img, 'convert'):
                from PIL import Image as PILImage
                img = PILImage.fromarray(img)
            img = img.convert("RGB")
        return self.transform(img)
|
|
|
|
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Return a LambdaLR: linear warmup, then cosine decay to zero.

    The multiplier ramps linearly from 0 to 1 over ``num_warmup_steps``,
    then follows half a cosine from 1 down to 0 over the remaining
    ``num_training_steps - num_warmup_steps`` steps.
    """
    def _multiplier(current_step):
        # Warmup phase: linear ramp 0 -> 1.
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)
        # Decay phase: cosine 1 -> 0, clamped at zero past the horizon.
        span = max(1, num_training_steps - num_warmup_steps)
        frac = (current_step - num_warmup_steps) / span
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * frac)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _multiplier)
|
|