dreamlessx committed on
Commit
73fad7a
·
verified ·
1 Parent(s): fd53eef

Upload landmarkdiff/losses.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/losses.py +68 -27
landmarkdiff/losses.py CHANGED
@@ -1,10 +1,11 @@
1
- """4-term loss for ControlNet fine-tuning.
2
 
3
- L_total = L_diff + w_lm * L_landmark + w_id * L_identity + w_perc * L_perceptual
4
 
5
- Phase A (synthetic TPS data): diffusion loss only. No perceptual against
6
- rubbery TPS warps - it would penalize realism.
7
- Phase B (FEM/clinical data): all 4 terms.
 
8
  """
9
 
10
  from __future__ import annotations
@@ -26,7 +27,7 @@ class LossWeights:
26
 
27
 
28
  class DiffusionLoss:
29
- """Epsilon-prediction MSE."""
30
 
31
  def __call__(
32
  self,
@@ -37,9 +38,10 @@ class DiffusionLoss:
37
 
38
 
39
  class LandmarkLoss:
40
- """L2 landmark distance, IOD-normalized, inside surgical mask only.
41
 
42
- Requires re-extraction from generated image (eval only, too slow per step).
 
43
  """
44
 
45
  def __call__(
@@ -66,10 +68,17 @@ class LandmarkLoss:
66
 
67
 
68
  class IdentityLoss:
69
- """ArcFace cosine sim loss, procedure-dependent crop.
 
 
 
70
 
71
- buffalo_l 512-dim embeddings, falls back to pixel cosine if unavailable.
72
- Disabled for orthognathic. Images MUST be [-1,1] at 112x112 for ArcFace.
 
 
 
 
73
  """
74
 
75
  def __init__(self, device: torch.device | None = None):
@@ -78,7 +87,7 @@ class IdentityLoss:
78
  self._has_arcface = None # None = not checked yet
79
 
80
  def _ensure_loaded(self, device: torch.device) -> None:
81
- """Lazy-load ArcFace on first call."""
82
  if self._has_arcface is not None:
83
  return
84
  try:
@@ -95,7 +104,14 @@ class IdentityLoss:
95
 
96
  @torch.no_grad()
97
  def _extract_embedding(self, image_tensor: torch.Tensor) -> torch.Tensor:
98
- """(B,3,112,112) in [-1,1] -> (B,512) embeddings (or pixel fallback)."""
 
 
 
 
 
 
 
99
  if self._has_arcface:
100
  import numpy as np
101
  embeddings = []
@@ -151,23 +167,26 @@ class IdentityLoss:
151
  if not any(valid):
152
  return torch.tensor(0.0, device=pred_image.device)
153
 
154
- valid_t = torch.tensor(valid, device=pred_image.device)
 
 
 
 
 
155
 
156
- # L2 normalize (safe, only valid embeddings have nonzero norm)
157
- pred_emb = F.normalize(pred_emb.float(), dim=1)
158
- target_emb = F.normalize(target_emb.float(), dim=1)
159
 
160
- cosine_sim = (pred_emb * target_emb).sum(dim=1)
161
- # Zero out invalid entries before averaging
162
- cosine_sim = cosine_sim * valid_t.float()
163
- return (1 - cosine_sim).sum() / valid_t.float().sum()
164
 
165
  def _procedure_crop(
166
  self,
167
  image: torch.Tensor,
168
  procedure: str,
169
  ) -> torch.Tensor:
170
- """Procedure-specific crop for identity comparison."""
171
  _, _, h, w = image.shape
172
 
173
  if procedure == "rhinoplasty":
@@ -184,7 +203,11 @@ class IdentityLoss:
184
 
185
 
186
  class PerceptualLoss:
187
- """LPIPS outside surgical mask only. Remember: LPIPS wants [-1,1], VAE gives [0,1]."""
 
 
 
 
188
 
189
  def __init__(self):
190
  self._lpips = None
@@ -211,8 +234,8 @@ class PerceptualLoss:
211
  # Invert mask: we want loss OUTSIDE surgical region
212
  outside_mask = 1 - mask
213
 
214
- # Erode outside_mask by a few pixels to avoid artificial edge features
215
- # at the mask boundary (LPIPS VGG detects the hard 0->value transition)
216
  erode_kernel = 5
217
  if outside_mask.shape[-1] >= erode_kernel and outside_mask.shape[-2] >= erode_kernel:
218
  outside_mask = -F.max_pool2d(
@@ -238,20 +261,38 @@ class PerceptualLoss:
238
 
239
 
240
  class CombinedLoss:
241
- """4-term combined loss. phase='A' = diffusion only, phase='B' = all terms."""
 
 
 
 
 
 
 
 
 
242
 
243
  def __init__(
244
  self,
245
  weights: LossWeights | None = None,
246
  phase: str = "A",
 
 
247
  ):
248
  self.weights = weights or LossWeights()
249
  self.phase = phase
250
  self.diffusion_loss = DiffusionLoss()
251
  self.landmark_loss = LandmarkLoss()
252
- self.identity_loss = IdentityLoss()
253
  self.perceptual_loss = PerceptualLoss()
254
 
 
 
 
 
 
 
 
 
255
  def __call__(
256
  self,
257
  noise_pred: torch.Tensor,
 
1
+ """4-term loss function module for ControlNet fine-tuning.
2
 
3
+ L_total = L_diffusion + w_landmark * L_landmark + w_identity * L_identity + w_perceptual * L_perceptual
4
 
5
+ Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
6
+ rubbery TPS warps it would penalize realism.
7
+
8
+ Phase B (FEM/clinical data): All 4 terms enabled.
9
  """
10
 
11
  from __future__ import annotations
 
27
 
28
 
29
  class DiffusionLoss:
30
+ """Standard epsilon-prediction MSE loss (primary training signal)."""
31
 
32
  def __call__(
33
  self,
 
38
 
39
 
40
  class LandmarkLoss:
41
+ """L2 landmark distance normalized by inter-ocular distance.
42
 
43
+ Computed INSIDE surgical mask only. Requires MediaPipe re-extraction
44
+ from generated image (done at eval, not every training step for speed).
45
  """
46
 
47
  def __call__(
 
68
 
69
 
70
  class IdentityLoss:
71
+ """ArcFace cosine similarity loss with procedure-dependent masking.
72
+
73
+ Uses InsightFace ArcFace model (buffalo_l) for 512-dim identity embeddings.
74
+ Falls back to pixel-level cosine similarity if InsightFace is unavailable.
75
 
76
+ - Full face for blepharoplasty
77
+ - Upper-face crop for rhinoplasty
78
+ - Disabled for orthognathic
79
+
80
+ Input images MUST be normalized to [-1, 1] and cropped to 112x112
81
+ before passing to ArcFace (ArcFace outputs garbage for 1024x1024).
82
  """
83
 
84
  def __init__(self, device: torch.device | None = None):
 
87
  self._has_arcface = None # None = not checked yet
88
 
89
  def _ensure_loaded(self, device: torch.device) -> None:
90
+ """Lazy-load ArcFace model on first use."""
91
  if self._has_arcface is not None:
92
  return
93
  try:
 
104
 
105
  @torch.no_grad()
106
  def _extract_embedding(self, image_tensor: torch.Tensor) -> torch.Tensor:
107
+ """Extract ArcFace embedding from a batch of images.
108
+
109
+ Args:
110
+ image_tensor: (B, 3, 112, 112) in [-1, 1]
111
+
112
+ Returns:
113
+ (B, 512) identity embeddings, or (B, D) pixel-level if fallback.
114
+ """
115
  if self._has_arcface:
116
  import numpy as np
117
  embeddings = []
 
167
  if not any(valid):
168
  return torch.tensor(0.0, device=pred_image.device)
169
 
170
+ valid_indices = [i for i, v in enumerate(valid) if v]
171
+ valid_idx_t = torch.tensor(valid_indices, device=pred_image.device, dtype=torch.long)
172
+
173
+ # Select ONLY valid embeddings before normalization to avoid 0/0 NaN
174
+ pred_valid_emb = pred_emb[valid_idx_t].float()
175
+ target_valid_emb = target_emb[valid_idx_t].float()
176
 
177
+ # L2 normalize (safe zero vectors excluded above)
178
+ pred_valid_emb = F.normalize(pred_valid_emb, dim=1)
179
+ target_valid_emb = F.normalize(target_valid_emb, dim=1)
180
 
181
+ cosine_sim = (pred_valid_emb * target_valid_emb).sum(dim=1)
182
+ return (1 - cosine_sim).mean()
 
 
183
 
184
  def _procedure_crop(
185
  self,
186
  image: torch.Tensor,
187
  procedure: str,
188
  ) -> torch.Tensor:
189
+ """Crop image based on procedure for identity comparison."""
190
  _, _, h, w = image.shape
191
 
192
  if procedure == "rhinoplasty":
 
203
 
204
 
205
  class PerceptualLoss:
206
+ """LPIPS perceptual loss on regions OUTSIDE surgical mask only.
207
+
208
+ LPIPS expects [-1, 1] input. VAE outputs [0, 1].
209
+ Must apply (x * 2) - 1 before every call.
210
+ """
211
 
212
  def __init__(self):
213
  self._lpips = None
 
234
  # Invert mask: we want loss OUTSIDE surgical region
235
  outside_mask = 1 - mask
236
 
237
+ # Erode outside_mask to exclude boundary pixels — this avoids artificial
238
+ # edge features where masked (0) meets unmasked (non-zero) values
239
  erode_kernel = 5
240
  if outside_mask.shape[-1] >= erode_kernel and outside_mask.shape[-2] >= erode_kernel:
241
  outside_mask = -F.max_pool2d(
 
261
 
262
 
263
  class CombinedLoss:
264
+ """Combined 4-term loss with configurable weights.
265
+
266
+ Use phase='A' for Phase A training (diffusion only).
267
+ Use phase='B' for Phase B training (all terms).
268
+
269
+ For Phase B, set ``use_differentiable_arcface=True`` to use the
270
+ PyTorch-native ArcFace backbone (``arcface_torch.py``) that provides
271
+ actual gradient signal. The default ONNX-based IdentityLoss produces
272
+ zero gradients (DA2-03).
273
+ """
274
 
275
  def __init__(
276
  self,
277
  weights: LossWeights | None = None,
278
  phase: str = "A",
279
+ use_differentiable_arcface: bool = False,
280
+ arcface_weights_path: str | None = None,
281
  ):
282
  self.weights = weights or LossWeights()
283
  self.phase = phase
284
  self.diffusion_loss = DiffusionLoss()
285
  self.landmark_loss = LandmarkLoss()
 
286
  self.perceptual_loss = PerceptualLoss()
287
 
288
+ # Identity loss: differentiable PyTorch ArcFace for Phase B,
289
+ # or ONNX-based fallback
290
+ if use_differentiable_arcface:
291
+ from landmarkdiff.arcface_torch import ArcFaceLoss
292
+ self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
293
+ else:
294
+ self.identity_loss = IdentityLoss()
295
+
296
  def __call__(
297
  self,
298
  noise_pred: torch.Tensor,