# liquid_flow/physics_loss.py
"""
Physics-Informed Regularization for LiquidFlow.
CORRECTED VERSION: fixed intensity tracking, proper buffer handling.
Pattern from: Bastek & Sun (ICLR 2025)
- Physics losses computed on estimated x̂₀ during training
- Zero cost at inference
- Acts as an implicit regularizer against artifacts
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

class PhysicsRegularizer(nn.Module):
    """
    Physics-informed regularizer for diffusion training.
    Computed on the estimated clean sample x̂₀ (DDIM one-step estimate).
    All losses are differentiable through the noise predictor.
    """
    def __init__(self, tv_weight=0.01, cons_weight=0.001, spec_weight=0.01, grad_weight=0.001):
        super().__init__()
        self.tv_weight = tv_weight
        self.cons_weight = cons_weight
        self.spec_weight = spec_weight
        self.grad_weight = grad_weight
        # EMA intensity tracking (buffers, so they follow .to()/.cuda() and checkpointing)
        self.register_buffer('intensity_ema', torch.tensor(0.0))
        self.register_buffer('step_count', torch.tensor(0, dtype=torch.long))
    def total_variation(self, x):
        """L1 total variation: encourages spatial smoothness."""
        diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
        diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
        return diff_h.mean() + diff_w.mean()
    def conservation_intensity(self, x):
        """Penalize deviation from the running mean intensity."""
        batch_mean = x.mean()
        if self.training:
            with torch.no_grad():
                self.step_count += 1
                alpha = min(0.99, 1.0 - 1.0 / (self.step_count.float() + 1))
                self.intensity_ema = alpha * self.intensity_ema + (1 - alpha) * batch_mean
        # Only activate after warmup (100 steps)
        if self.step_count > 100:
            return (batch_mean - self.intensity_ema.detach()) ** 2
        return torch.zeros((), device=x.device, requires_grad=True)
    def spectral_regularizer(self, x):
        """Penalize high-frequency energy (anti-checkerboard)."""
        B, C, H, W = x.shape
        # 2D real FFT; for rfft2, output shape is [B, C, H, W//2 + 1]
        x_fft = torch.fft.rfft2(x, norm='ortho')
        mag = torch.abs(x_fft)
        # Radial high-frequency mask over the rfft2 half-spectrum
        freq_h = torch.arange(H, device=x.device).float()
        freq_w = torch.arange(W // 2 + 1, device=x.device).float()
        # Normalize frequencies to [0, 1] (1 = Nyquist); fold negative row frequencies
        freq_h = torch.min(freq_h, H - freq_h) / (H / 2)
        freq_w = freq_w / (W / 2)
        # Radial distance from DC
        dist = torch.sqrt(freq_h.unsqueeze(1) ** 2 + freq_w.unsqueeze(0) ** 2)
        # High frequency: distance > 0.5 (half Nyquist)
        high_mask = (dist > 0.5).float()
        high_energy = (mag * high_mask.unsqueeze(0).unsqueeze(0)).mean()
        return high_energy
    def gradient_penalty(self, x):
        """Sobolev (H1 seminorm) L2 gradient penalty."""
        grad_h = x[:, :, 1:, :] - x[:, :, :-1, :]
        grad_w = x[:, :, :, 1:] - x[:, :, :, :-1]
        return (grad_h ** 2).mean() + (grad_w ** 2).mean()
    def forward(self, x0_hat, x_ref=None):
        """
        Args:
            x0_hat: Estimated clean image [B, C, H, W]
            x_ref: Ground truth (unused, kept for API compatibility)
        Returns:
            total_loss, loss_dict
        """
        losses = {}
        total = torch.zeros((), device=x0_hat.device, requires_grad=True)
        if self.tv_weight > 0:
            tv = self.total_variation(x0_hat)
            losses['tv'] = tv
            total = total + self.tv_weight * tv
        if self.cons_weight > 0:
            cons = self.conservation_intensity(x0_hat)
            losses['cons'] = cons
            total = total + self.cons_weight * cons
        if self.spec_weight > 0:
            spec = self.spectral_regularizer(x0_hat)
            losses['spec'] = spec
            total = total + self.spec_weight * spec
        if self.grad_weight > 0:
            grad = self.gradient_penalty(x0_hat)
            losses['grad'] = grad
            total = total + self.grad_weight * grad
        return total, losses
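
# ---------------------------------------------------------------------------
# Illustrative sanity check (not part of the original module): a minimal
# sketch showing that spectral_regularizer penalizes checkerboard patterns
# far more than smooth content. The helper name `_demo_spectral_check` and
# the 64x64 test images are assumptions for demonstration only.
# ---------------------------------------------------------------------------
def _demo_spectral_check():
    reg = PhysicsRegularizer()
    idx = torch.arange(64).float()
    # Checkerboard in {-1, +1}: energy concentrates at the Nyquist frequency
    checker = (((idx[:, None] + idx[None, :]) % 2) * 2 - 1).view(1, 1, 64, 64)
    # Smooth vertical ramp in [0, 1]: energy concentrates near DC
    ramp = (idx / 63.0).view(1, 1, 64, 1).expand(1, 1, 64, 64).contiguous()
    print('checkerboard penalty:', reg.spectral_regularizer(checker).item())
    print('smooth-ramp penalty: ', reg.spectral_regularizer(ramp).item())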

class DDIMEstimator:
    """DDIM one-step clean sample estimation."""
    @staticmethod
    def estimate_x0(x_t, eps_pred, alpha_bar_t):
        """
        x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t)
        Args:
            x_t: Noisy sample [B, C, H, W]
            eps_pred: Predicted noise [B, C, H, W]
            alpha_bar_t: [B] — cumulative alpha at timestep t
        """
        a = alpha_bar_t.reshape(-1, 1, 1, 1)
        x0_hat = (x_t - torch.sqrt(1 - a) * eps_pred) / (torch.sqrt(a) + 1e-8)
        # Clamp to prevent extreme values early in training
        return x0_hat.clamp(-5, 5)
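
# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not the repository's training loop): how the
# regularizer and estimator would plug into an epsilon-prediction DDPM step.
# `model(x_t, t)` and the schedule tensor `alpha_bars` (shape [T]) are
# hypothetical stand-ins for whatever LiquidFlow's trainer actually provides.
# ---------------------------------------------------------------------------
def training_step_sketch(model, x0, alpha_bars, physics_reg):
    B = x0.shape[0]
    t = torch.randint(0, alpha_bars.shape[0], (B,), device=x0.device)
    a = alpha_bars[t].reshape(-1, 1, 1, 1)
    noise = torch.randn_like(x0)
    # Forward diffusion: x_t = √(ᾱ_t)·x₀ + √(1-ᾱ_t)·ε
    x_t = torch.sqrt(a) * x0 + torch.sqrt(1 - a) * noise
    eps_pred = model(x_t, t)
    diffusion_loss = F.mse_loss(eps_pred, noise)
    # Physics losses on the one-step x̂₀ estimate; gradients reach the noise
    # predictor through eps_pred, and the terms cost nothing at inference.
    x0_hat = DDIMEstimator.estimate_x0(x_t, eps_pred, alpha_bars[t])
    phys_loss, phys_terms = physics_reg(x0_hat)
    return diffusion_loss + phys_loss, phys_terms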