| """ |
| Physics-Informed Regularization for LiquidFlow. |
| CORRECTED VERSION: fixed intensity tracking, proper buffer handling. |
| |
| Pattern from: Bastek & Sun (ICLR 2025) |
| - Physics losses computed on estimated x̂₀ during training |
| - Zero cost at inference |
| - Acts as implicit regularizer against artifacts |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| class PhysicsRegularizer(nn.Module): |
| """ |
| Physics-informed regularizer for diffusion training. |
| |
| Computed on estimated clean sample x̂₀ (DDIM one-step estimate). |
| All losses are differentiable through the noise predictor. |
| """ |
| |
| def __init__(self, tv_weight=0.01, cons_weight=0.001, spec_weight=0.01, grad_weight=0.001): |
| super().__init__() |
| self.tv_weight = tv_weight |
| self.cons_weight = cons_weight |
| self.spec_weight = spec_weight |
| self.grad_weight = grad_weight |
| |
| |
        # Running statistics for the intensity-conservation term; registered as
        # buffers so they follow .to(device) and persist in the state_dict.
        self.register_buffer('intensity_ema', torch.tensor(0.0))
        self.register_buffer('step_count', torch.tensor(0, dtype=torch.long))
| |
| def total_variation(self, x): |
| """L1 total variation: encourages spatial smoothness.""" |
| diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]) |
| diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]) |
| return diff_h.mean() + diff_w.mean() |
| |
    def conservation_intensity(self, x):
        """Penalize deviation of the batch-mean intensity from a running EMA."""
        batch_mean = x.mean()

        if self.training:
            with torch.no_grad():
                self.step_count += 1
                # Warm-up: alpha ramps toward 0.99 so early updates are not
                # dominated by the zero-initialized EMA.
                alpha = min(0.99, 1.0 - 1.0 / (int(self.step_count) + 1))
                # Update in place so the buffer is never rebound to a fresh
                # tensor (keeps device placement and state_dict handling intact).
                self.intensity_ema.mul_(alpha).add_((1 - alpha) * batch_mean)

        # Only penalize once the EMA has seen enough batches to be meaningful;
        # the EMA is detached so it acts as a constant target.
        if self.step_count > 100:
            return (batch_mean - self.intensity_ema.detach()) ** 2
        return x.new_zeros(())
| |
    def spectral_regularizer(self, x):
        """Penalize high-frequency energy (anti-checkerboard)."""
        B, C, H, W = x.shape

        # Real 2-D FFT; the last dim is halved to W // 2 + 1 by Hermitian symmetry.
        x_fft = torch.fft.rfft2(x, norm='ortho')
        mag = torch.abs(x_fft)

        # Frequency indices for each axis of the rfft2 output.
        freq_h = torch.arange(H, device=x.device).float()
        freq_w = torch.arange(W // 2 + 1, device=x.device).float()

        # Fold the height axis (full FFT, so frequencies wrap past H // 2),
        # then normalize both axes to [0, 1] where 1 is Nyquist.
        freq_h = torch.min(freq_h, H - freq_h) / (H / 2)
        freq_w = freq_w / (W / 2)

        # Radial distance from DC in normalized frequency space: [H, W//2+1].
        dist = torch.sqrt(freq_h.unsqueeze(1) ** 2 + freq_w.unsqueeze(0) ** 2)

        # Keep only the band above half the Nyquist radius.
        high_mask = (dist > 0.5).float()

        high_energy = (mag * high_mask.unsqueeze(0).unsqueeze(0)).mean()
        return high_energy
| |
| def gradient_penalty(self, x): |
| """Sobolev L2 gradient penalty.""" |
| grad_h = x[:, :, 1:, :] - x[:, :, :-1, :] |
| grad_w = x[:, :, :, 1:] - x[:, :, :, :-1] |
| return (grad_h ** 2).mean() + (grad_w ** 2).mean() |
| |
    def forward(self, x0_hat, x_ref=None):
        """
        Args:
            x0_hat: Estimated clean image [B, C, H, W]
            x_ref: Ground truth (unused; kept for API compatibility)
        Returns:
            (total_loss, loss_dict), where loss_dict holds the unweighted terms
        """
        losses = {}
        total = x0_hat.new_zeros(())
| |
| if self.tv_weight > 0: |
| tv = self.total_variation(x0_hat) |
| losses['tv'] = tv |
| total = total + self.tv_weight * tv |
| |
| if self.cons_weight > 0: |
| cons = self.conservation_intensity(x0_hat) |
| losses['cons'] = cons |
| total = total + self.cons_weight * cons |
| |
| if self.spec_weight > 0: |
| spec = self.spectral_regularizer(x0_hat) |
| losses['spec'] = spec |
| total = total + self.spec_weight * spec |
| |
| if self.grad_weight > 0: |
| grad = self.gradient_penalty(x0_hat) |
| losses['grad'] = grad |
| total = total + self.grad_weight * grad |
| |
| return total, losses |
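

# Illustrative standalone usage (a sketch: `decoder`, `z`, and `main_loss` are
# placeholders, not part of this module; tune the weights for your own data):
#
#     reg = PhysicsRegularizer(tv_weight=0.01, spec_weight=0.01)
#     reg.train()
#     x0_hat = decoder(z)            # any differentiable [B, C, H, W] output
#     phys_loss, parts = reg(x0_hat)
#     (main_loss + phys_loss).backward()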
|
|
|
|
| class DDIMEstimator: |
| """DDIM one-step clean sample estimation.""" |
| |
| @staticmethod |
| def estimate_x0(x_t, eps_pred, alpha_bar_t): |
| """ |
| x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t) |
| |
| Args: |
| x_t: [B, C, H, W] |
| eps_pred: [B, C, H, W] |
| alpha_bar_t: [B] — cumulative alpha at timestep t |
| """ |
        a = alpha_bar_t.reshape(-1, 1, 1, 1)
        x0_hat = (x_t - torch.sqrt(1 - a) * eps_pred) / (torch.sqrt(a) + 1e-8)
        # Clamp to a generous range: at early timesteps (small alpha_bar_t) the
        # one-step estimate is noisy, and unbounded values destabilize the
        # physics losses.
        return x0_hat.clamp(-5, 5)
|
|
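if __name__ == "__main__":
    # Smoke-test sketch of how the pieces compose in one training step. The
    # conv "model" and the linear beta schedule below are stand-in assumptions
    # for illustration; they are not the LiquidFlow training configuration.
    torch.manual_seed(0)

    model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # stand-in noise predictor
    regularizer = PhysicsRegularizer()
    regularizer.train()

    # Toy linear beta schedule -> cumulative product alpha_bar (assumed).
    T = 1000
    betas = torch.linspace(1e-4, 0.02, T)
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)

    # Forward-diffuse a random "clean" batch to random timesteps.
    x0 = torch.rand(4, 3, 32, 32)
    t = torch.randint(0, T, (4,))
    a = alpha_bar[t].reshape(-1, 1, 1, 1)
    noise = torch.randn_like(x0)
    x_t = torch.sqrt(a) * x0 + torch.sqrt(1.0 - a) * noise

    # Standard denoising loss plus the physics terms on the x0 estimate.
    eps_pred = model(x_t)
    denoise_loss = F.mse_loss(eps_pred, noise)
    x0_hat = DDIMEstimator.estimate_x0(x_t, eps_pred, alpha_bar[t])
    phys_loss, parts = regularizer(x0_hat)

    (denoise_loss + phys_loss).backward()
    print(f"denoise={denoise_loss.item():.4f}  physics={phys_loss.item():.4f}")
    print({k: f"{v.item():.6f}" for k, v in parts.items()})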