krystv committed
Commit 0cf988f · verified · 1 Parent(s): 3214be6

Upload liquid_flow/physics_loss.py

Files changed (1)
  1. liquid_flow/physics_loss.py +65 -176
liquid_flow/physics_loss.py CHANGED
@@ -1,37 +1,11 @@
 """
 Physics-Informed Regularization for LiquidFlow.
 
-From: "Physics-Informed Diffusion Models" (Bastek & Sun, ICLR 2025)
-and "PID: Physics-Informed Diffusion for IR Image Generation" (Mao et al., 2024)
-
-Physics losses act as TRAINING-ONLY regularizers: they don't affect
-inference speed. The pattern:
-
-1. During training: denoise to get x̂₀, compute the physics residual, add it to the loss
-2. During inference: no change at all
-
-Implemented physics constraints for image generation:
-
-A. Total Variation (TV) — penalizes non-smooth outputs
-   L_TV = ||∇_x x̂₀||₁ + ||∇_y x̂₀||₁
-   → Enforces spatial smoothness, reduces artifacts
-
-B. Conservation of Intensity — mass conservation across the image
-   L_cons = ||mean(x̂₀) - E[mean(x_ref)]||²
-   → Prevents intensity drift
-
-C. Spectral Regularizer — penalizes high-frequency noise
-   L_spec = ||FFT_high(x̂₀)||²
-   → Reduces checkerboard artifacts
-
-D. Gradient Magnitude Balance — prevents exploding gradients in dark regions
-   L_grad = ||∇x̂₀||² (Sobolev regularization)
-   → Stabilizes training in low-signal regions
-
-Pattern: L_total = L_diffusion + λ_TV * L_TV + λ_cons * L_cons + λ_spec * L_spec + λ_grad * L_grad
-
-The virtual-observable paradigm (from PAD-Hand, 2026):
-Physics constraints are SOFT — they guide without requiring perfect satisfaction.
 """
 
 import torch
@@ -41,209 +15,124 @@ import torch.nn.functional as F
 
 
 class PhysicsRegularizer(nn.Module):
     """
-    Physics-informed regularizer for image generation training.
-
-    All losses are computed on the estimated clean sample x̂₀ during training.
-    They are ADDITIVE regularizers — just add to the diffusion loss.
-
-    Args:
-        tv_weight: Total Variation weight (default 0.01)
-        cons_weight: Conservation of intensity weight (default 0.001)
-        spec_weight: Spectral regularizer weight (default 0.01)
-        grad_weight: Gradient magnitude penalty weight (default 0.001)
     """
 
-    def __init__(
-        self,
-        tv_weight=0.01,
-        cons_weight=0.001,
-        spec_weight=0.01,
-        grad_weight=0.001,
-    ):
         super().__init__()
         self.tv_weight = tv_weight
         self.cons_weight = cons_weight
         self.spec_weight = spec_weight
         self.grad_weight = grad_weight
 
-        # Running mean for intensity conservation
-        self.register_buffer('intensity_mean', torch.tensor(0.0))
-        self.register_buffer('intensity_count', torch.tensor(0))
-        self.intensity_alpha = 0.99  # EMA decay
 
     def total_variation(self, x):
-        """
-        Total Variation loss on image batch x.
-
-        L_TV = mean(|x_{i+1,j} - x_{i,j}| + |x_{i,j+1} - x_{i,j}|)
-
-        Args:
-            x: [B, C, H, W] images
-        Returns:
-            tv_loss: scalar
-        """
         diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
         diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
         return diff_h.mean() + diff_w.mean()
 
     def conservation_intensity(self, x):
-        """
-        Conservation of image intensity (mass).
-
-        L_cons = (mean(x) - running_mean)^2
-
-        This prevents the generator from drifting into producing
-        images that are too dark or too bright.
-
-        Args:
-            x: [B, C, H, W] images
-        Returns:
-            cons_loss: scalar
-        """
         batch_mean = x.mean()
 
-        # Update running statistics
         if self.training:
            with torch.no_grad():
-                self.intensity_mean = (
-                    self.intensity_alpha * self.intensity_mean +
-                    (1 - self.intensity_alpha) * batch_mean.detach()
-                )
-
-        # Conservation loss: penalize deviation from running mean
-        if self.intensity_count > 100:  # Only after some warmup
-            return ((batch_mean - self.intensity_mean) ** 2).mean()
-        return torch.tensor(0.0, device=x.device)
 
     def spectral_regularizer(self, x):
-        """
-        Spectral regularizer: penalize high-frequency content.
-
-        Uses FFT and penalizes high-frequency components.
-        This prevents high-frequency artifacts (checkerboard patterns).
-
-        Args:
-            x: [B, C, H, W] images
-        Returns:
-            spec_loss: scalar
-        """
         # 2D FFT
-        x_fft = torch.fft.fft2(x)
-        x_fft_shift = torch.fft.fftshift(x_fft)
-
-        # Create high-frequency mask (center is low frequency)
-        B, C, H, W = x.shape
-        h_center, w_center = H // 2, W // 2
-
-        y, x_coord = torch.meshgrid(
-            torch.arange(H, device=x.device),
-            torch.arange(W, device=x.device),
-            indexing='ij'
-        )
-        dist = torch.sqrt((y - h_center) ** 2 + (x_coord - w_center) ** 2)
-
-        # High frequency: distance > quarter of image size
-        high_freq_mask = (dist > min(H, W) / 4).float()
-
-        # Penalize high-frequency magnitude
-        spec_mag = torch.abs(x_fft_shift)
-        high_freq_energy = (spec_mag * high_freq_mask.unsqueeze(0).unsqueeze(0)).mean()
-
-        return high_freq_energy
 
     def gradient_penalty(self, x):
-        """
-        Sobolev gradient penalty.
-
-        L_grad = ||∇x||² (mean squared gradient magnitude)
-
-        This prevents the generator from creating regions where
-        gradients explode (common in GAN-like training).
-        For diffusion, this helps stabilize the noise prediction.
-
-        Args:
-            x: [B, C, H, W] images
-        Returns:
-            grad_loss: scalar
-        """
         grad_h = x[:, :, 1:, :] - x[:, :, :-1, :]
         grad_w = x[:, :, :, 1:] - x[:, :, :, :-1]
-
-        grad_mag = (grad_h ** 2).mean() + (grad_w ** 2).mean()
-        return grad_mag
 
     def forward(self, x0_hat, x_ref=None):
         """
-        Compute total physics loss.
-
         Args:
            x0_hat: Estimated clean image [B, C, H, W]
-            x_ref: Optional ground truth reference (for intensity tracking)
-
         Returns:
-            total_loss: Combined physics regularizer (scalar)
-            loss_dict: Dict of individual losses
         """
         losses = {}
 
-        # Total Variation
         if self.tv_weight > 0:
-            losses['tv'] = self.total_variation(x0_hat)
 
-        # Conservation of Intensity
         if self.cons_weight > 0:
-            losses['cons'] = self.conservation_intensity(x0_hat)
 
-        # Spectral Regularizer
         if self.spec_weight > 0:
-            losses['spec'] = self.spectral_regularizer(x0_hat)
 
-        # Gradient Penalty
         if self.grad_weight > 0:
-            losses['grad'] = self.gradient_penalty(x0_hat)
-
-        # Weighted sum
-        total = (
-            self.tv_weight * losses.get('tv', 0.0) +
-            self.cons_weight * losses.get('cons', 0.0) +
-            self.spec_weight * losses.get('spec', 0.0) +
-            self.grad_weight * losses.get('grad', 0.0)
-        )
 
         return total, losses
 
 
 class DDIMEstimator:
-    """
-    DDIM clean-sample estimator for physics loss computation.
-
-    From the Bastek & Sun (ICLR 2025) pattern:
-        x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t)
-
-    This provides an estimate of the clean sample at training time
-    without requiring full reverse diffusion.
-    """
 
     @staticmethod
     def estimate_x0(x_t, eps_pred, alpha_bar_t):
         """
-        Estimate clean sample from noisy sample and predicted noise.
 
         Args:
-            x_t: Noisy sample [B, C, H, W]
-            eps_pred: Predicted noise [B, C, H, W]
-            alpha_bar_t: Cumulative product of alphas at timestep t [B]
-
-        Returns:
-            x0_hat: Estimated clean sample [B, C, H, W]
         """
-        alpha_bar_t = alpha_bar_t.reshape(-1, 1, 1, 1)
-        x0_hat = (x_t - torch.sqrt(1 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)
-        return x0_hat
-
-    @staticmethod
-    def estimate_noise(x_t, x0_hat, alpha_bar_t):
-        """Reverse: estimate noise from clean sample."""
-        alpha_bar_t = alpha_bar_t.reshape(-1, 1, 1, 1)
-        eps_pred = (x_t - torch.sqrt(alpha_bar_t) * x0_hat) / torch.sqrt(1 - alpha_bar_t)
-        return eps_pred
 
 """
 Physics-Informed Regularization for LiquidFlow.
+CORRECTED VERSION: fixed intensity tracking, proper buffer handling.
 
+Pattern from: Bastek & Sun (ICLR 2025)
+- Physics losses computed on estimated x̂₀ during training
+- Zero cost at inference
+- Acts as an implicit regularizer against artifacts
 """
 
 import torch
 class PhysicsRegularizer(nn.Module):
     """
+    Physics-informed regularizer for diffusion training.
 
+    Computed on the estimated clean sample x̂₀ (DDIM one-step estimate).
+    All losses are differentiable through the noise predictor.
     """
 
+    def __init__(self, tv_weight=0.01, cons_weight=0.001, spec_weight=0.01, grad_weight=0.001):
         super().__init__()
         self.tv_weight = tv_weight
         self.cons_weight = cons_weight
         self.spec_weight = spec_weight
         self.grad_weight = grad_weight
 
+        # EMA intensity tracking
+        self.register_buffer('intensity_ema', torch.tensor(0.0))
+        self.register_buffer('step_count', torch.tensor(0, dtype=torch.long))
 
     def total_variation(self, x):
+        """L1 total variation: encourages spatial smoothness."""
         diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
         diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
         return diff_h.mean() + diff_w.mean()
 
     def conservation_intensity(self, x):
+        """Penalize deviation from the running mean intensity."""
         batch_mean = x.mean()
 
         if self.training:
             with torch.no_grad():
+                self.step_count += 1
+                alpha = min(0.99, 1.0 - 1.0 / (self.step_count.float() + 1))
+                self.intensity_ema = alpha * self.intensity_ema + (1 - alpha) * batch_mean
+
+        # Only activate after warmup (100 steps)
+        if self.step_count > 100:
+            return (batch_mean - self.intensity_ema.detach()) ** 2
+        return torch.zeros(1, device=x.device, requires_grad=True).squeeze()
 
     def spectral_regularizer(self, x):
+        """Penalize high-frequency energy (anti-checkerboard)."""
+        B, C, H, W = x.shape
 
         # 2D FFT
+        x_fft = torch.fft.rfft2(x, norm='ortho')
+        mag = torch.abs(x_fft)
 
+        # High-frequency mask over the half-spectrum:
+        # for rfft2, the output shape is [B, C, H, W//2+1]
+        freq_h = torch.arange(H, device=x.device).float()
+        freq_w = torch.arange(W // 2 + 1, device=x.device).float()
 
+        # Normalize frequencies to [0, 1]
+        freq_h = torch.min(freq_h, H - freq_h) / (H / 2)
+        freq_w = freq_w / (W / 2)
 
+        # Distance from DC
+        dist = torch.sqrt(freq_h.unsqueeze(1) ** 2 + freq_w.unsqueeze(0) ** 2)
 
+        # High frequency: distance > 0.5 (half Nyquist)
+        high_mask = (dist > 0.5).float()
 
+        high_energy = (mag * high_mask.unsqueeze(0).unsqueeze(0)).mean()
+        return high_energy
 
     def gradient_penalty(self, x):
+        """Sobolev L2 gradient penalty."""
         grad_h = x[:, :, 1:, :] - x[:, :, :-1, :]
         grad_w = x[:, :, :, 1:] - x[:, :, :, :-1]
+        return (grad_h ** 2).mean() + (grad_w ** 2).mean()
 
     def forward(self, x0_hat, x_ref=None):
         """
         Args:
             x0_hat: Estimated clean image [B, C, H, W]
+            x_ref: Ground truth (unused, kept for API compat)
         Returns:
+            total_loss, loss_dict
         """
         losses = {}
+        total = torch.zeros(1, device=x0_hat.device, requires_grad=True).squeeze()
 
         if self.tv_weight > 0:
+            tv = self.total_variation(x0_hat)
+            losses['tv'] = tv
+            total = total + self.tv_weight * tv
 
         if self.cons_weight > 0:
+            cons = self.conservation_intensity(x0_hat)
+            losses['cons'] = cons
+            total = total + self.cons_weight * cons
 
         if self.spec_weight > 0:
+            spec = self.spectral_regularizer(x0_hat)
+            losses['spec'] = spec
+            total = total + self.spec_weight * spec
 
         if self.grad_weight > 0:
+            grad = self.gradient_penalty(x0_hat)
+            losses['grad'] = grad
+            total = total + self.grad_weight * grad
 
         return total, losses
 
 
 class DDIMEstimator:
+    """DDIM one-step clean sample estimation."""
 
     @staticmethod
     def estimate_x0(x_t, eps_pred, alpha_bar_t):
         """
+        x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t)
 
         Args:
+            x_t: [B, C, H, W]
+            eps_pred: [B, C, H, W]
+            alpha_bar_t: [B] cumulative alpha at timestep t
         """
+        a = alpha_bar_t.reshape(-1, 1, 1, 1)
+        x0_hat = (x_t - torch.sqrt(1 - a) * eps_pred) / (torch.sqrt(a) + 1e-8)
+        # Clamp to prevent extreme values early in training
+        return x0_hat.clamp(-5, 5)
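
Usage sketch (not part of the commit): the pattern described in the module docstring, L_total = L_diffusion plus the weighted physics terms, wires together roughly as below. The epsilon-prediction model `model(x_t, t)`, the precomputed `alpha_bar` schedule, and the `training_step` helper are illustrative assumptions, not APIs from this repository; only `PhysicsRegularizer` and `DDIMEstimator` come from the file above.

import torch
import torch.nn.functional as F

from liquid_flow.physics_loss import PhysicsRegularizer, DDIMEstimator

# Weights mirror the defaults above. Call physics.to(device) before use:
# register_buffer puts the EMA state on CPU until the module is moved.
physics = PhysicsRegularizer(tv_weight=0.01, cons_weight=0.001,
                             spec_weight=0.01, grad_weight=0.001)

def training_step(model, x0, alpha_bar, optimizer):
    """One DDPM training step with the physics regularizer added on."""
    B = x0.shape[0]
    t = torch.randint(0, alpha_bar.shape[0], (B,), device=x0.device)
    a = alpha_bar[t].reshape(-1, 1, 1, 1)

    # Forward diffusion: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
    eps = torch.randn_like(x0)
    x_t = torch.sqrt(a) * x0 + torch.sqrt(1 - a) * eps

    eps_pred = model(x_t, t)
    diffusion_loss = F.mse_loss(eps_pred, eps)

    # One-step clean-sample estimate; the physics losses stay
    # differentiable through eps_pred, so they regularize the predictor.
    x0_hat = DDIMEstimator.estimate_x0(x_t, eps_pred, alpha_bar[t])
    physics_loss, parts = physics(x0_hat)

    loss = diffusion_loss + physics_loss  # lambdas already applied in forward()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach(), parts

Since the regularizer weights are applied inside forward(), the total physics term adds to the diffusion loss unscaled; per-term values in `parts` are the raw (unweighted) losses, which is convenient for logging.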