Spaces:
Sleeping
Sleeping
Upload landmarkdiff/losses.py with huggingface_hub
Browse files- landmarkdiff/losses.py +68 -27
landmarkdiff/losses.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
-
"""4-term loss for ControlNet fine-tuning.
|
| 2 |
|
| 3 |
-
L_total =
|
| 4 |
|
| 5 |
-
Phase A (synthetic TPS data):
|
| 6 |
-
rubbery TPS warps
|
| 7 |
-
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
from __future__ import annotations
|
|
@@ -26,7 +27,7 @@ class LossWeights:
|
|
| 26 |
|
| 27 |
|
| 28 |
class DiffusionLoss:
|
| 29 |
-
"""
|
| 30 |
|
| 31 |
def __call__(
|
| 32 |
self,
|
|
@@ -37,9 +38,10 @@ class DiffusionLoss:
|
|
| 37 |
|
| 38 |
|
| 39 |
class LandmarkLoss:
|
| 40 |
-
"""L2 landmark distance
|
| 41 |
|
| 42 |
-
|
|
|
|
| 43 |
"""
|
| 44 |
|
| 45 |
def __call__(
|
|
@@ -66,10 +68,17 @@ class LandmarkLoss:
|
|
| 66 |
|
| 67 |
|
| 68 |
class IdentityLoss:
|
| 69 |
-
"""ArcFace cosine
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
"""
|
| 74 |
|
| 75 |
def __init__(self, device: torch.device | None = None):
|
|
@@ -78,7 +87,7 @@ class IdentityLoss:
|
|
| 78 |
self._has_arcface = None # None = not checked yet
|
| 79 |
|
| 80 |
def _ensure_loaded(self, device: torch.device) -> None:
|
| 81 |
-
"""Lazy-load ArcFace on first
|
| 82 |
if self._has_arcface is not None:
|
| 83 |
return
|
| 84 |
try:
|
|
@@ -95,7 +104,14 @@ class IdentityLoss:
|
|
| 95 |
|
| 96 |
@torch.no_grad()
|
| 97 |
def _extract_embedding(self, image_tensor: torch.Tensor) -> torch.Tensor:
|
| 98 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
if self._has_arcface:
|
| 100 |
import numpy as np
|
| 101 |
embeddings = []
|
|
@@ -151,23 +167,26 @@ class IdentityLoss:
|
|
| 151 |
if not any(valid):
|
| 152 |
return torch.tensor(0.0, device=pred_image.device)
|
| 153 |
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
-
# L2 normalize (safe
|
| 157 |
-
|
| 158 |
-
|
| 159 |
|
| 160 |
-
cosine_sim = (
|
| 161 |
-
|
| 162 |
-
cosine_sim = cosine_sim * valid_t.float()
|
| 163 |
-
return (1 - cosine_sim).sum() / valid_t.float().sum()
|
| 164 |
|
| 165 |
def _procedure_crop(
|
| 166 |
self,
|
| 167 |
image: torch.Tensor,
|
| 168 |
procedure: str,
|
| 169 |
) -> torch.Tensor:
|
| 170 |
-
"""
|
| 171 |
_, _, h, w = image.shape
|
| 172 |
|
| 173 |
if procedure == "rhinoplasty":
|
|
@@ -184,7 +203,11 @@ class IdentityLoss:
|
|
| 184 |
|
| 185 |
|
| 186 |
class PerceptualLoss:
|
| 187 |
-
"""LPIPS
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
def __init__(self):
|
| 190 |
self._lpips = None
|
|
@@ -211,8 +234,8 @@ class PerceptualLoss:
|
|
| 211 |
# Invert mask: we want loss OUTSIDE surgical region
|
| 212 |
outside_mask = 1 - mask
|
| 213 |
|
| 214 |
-
# Erode outside_mask
|
| 215 |
-
#
|
| 216 |
erode_kernel = 5
|
| 217 |
if outside_mask.shape[-1] >= erode_kernel and outside_mask.shape[-2] >= erode_kernel:
|
| 218 |
outside_mask = -F.max_pool2d(
|
|
@@ -238,20 +261,38 @@ class PerceptualLoss:
|
|
| 238 |
|
| 239 |
|
| 240 |
class CombinedLoss:
|
| 241 |
-
"""4-term
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
def __init__(
|
| 244 |
self,
|
| 245 |
weights: LossWeights | None = None,
|
| 246 |
phase: str = "A",
|
|
|
|
|
|
|
| 247 |
):
|
| 248 |
self.weights = weights or LossWeights()
|
| 249 |
self.phase = phase
|
| 250 |
self.diffusion_loss = DiffusionLoss()
|
| 251 |
self.landmark_loss = LandmarkLoss()
|
| 252 |
-
self.identity_loss = IdentityLoss()
|
| 253 |
self.perceptual_loss = PerceptualLoss()
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
def __call__(
|
| 256 |
self,
|
| 257 |
noise_pred: torch.Tensor,
|
|
|
|
| 1 |
+
"""4-term loss function module for ControlNet fine-tuning.
|
| 2 |
|
| 3 |
+
L_total = L_diffusion + w_landmark * L_landmark + w_identity * L_identity + w_perceptual * L_perceptual
|
| 4 |
|
| 5 |
+
Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
|
| 6 |
+
rubbery TPS warps — it would penalize realism.
|
| 7 |
+
|
| 8 |
+
Phase B (FEM/clinical data): All 4 terms enabled.
|
| 9 |
"""
|
| 10 |
|
| 11 |
from __future__ import annotations
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
class DiffusionLoss:
|
| 30 |
+
"""Standard epsilon-prediction MSE loss (primary training signal)."""
|
| 31 |
|
| 32 |
def __call__(
|
| 33 |
self,
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
class LandmarkLoss:
|
| 41 |
+
"""L2 landmark distance normalized by inter-ocular distance.
|
| 42 |
|
| 43 |
+
Computed INSIDE surgical mask only. Requires MediaPipe re-extraction
|
| 44 |
+
from generated image (done at eval, not every training step for speed).
|
| 45 |
"""
|
| 46 |
|
| 47 |
def __call__(
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
class IdentityLoss:
|
| 71 |
+
"""ArcFace cosine similarity loss with procedure-dependent masking.
|
| 72 |
+
|
| 73 |
+
Uses InsightFace ArcFace model (buffalo_l) for 512-dim identity embeddings.
|
| 74 |
+
Falls back to pixel-level cosine similarity if InsightFace is unavailable.
|
| 75 |
|
| 76 |
+
- Full face for blepharoplasty
|
| 77 |
+
- Upper-face crop for rhinoplasty
|
| 78 |
+
- Disabled for orthognathic
|
| 79 |
+
|
| 80 |
+
Input images MUST be normalized to [-1, 1] and cropped to 112x112
|
| 81 |
+
before passing to ArcFace (AdaFace outputs garbage for 1024x1024).
|
| 82 |
"""
|
| 83 |
|
| 84 |
def __init__(self, device: torch.device | None = None):
|
|
|
|
| 87 |
self._has_arcface = None # None = not checked yet
|
| 88 |
|
| 89 |
def _ensure_loaded(self, device: torch.device) -> None:
|
| 90 |
+
"""Lazy-load ArcFace model on first use."""
|
| 91 |
if self._has_arcface is not None:
|
| 92 |
return
|
| 93 |
try:
|
|
|
|
| 104 |
|
| 105 |
@torch.no_grad()
|
| 106 |
def _extract_embedding(self, image_tensor: torch.Tensor) -> torch.Tensor:
|
| 107 |
+
"""Extract ArcFace embedding from a batch of images.
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
image_tensor: (B, 3, 112, 112) in [-1, 1]
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
(B, 512) identity embeddings, or (B, D) pixel-level if fallback.
|
| 114 |
+
"""
|
| 115 |
if self._has_arcface:
|
| 116 |
import numpy as np
|
| 117 |
embeddings = []
|
|
|
|
| 167 |
if not any(valid):
|
| 168 |
return torch.tensor(0.0, device=pred_image.device)
|
| 169 |
|
| 170 |
+
valid_indices = [i for i, v in enumerate(valid) if v]
|
| 171 |
+
valid_idx_t = torch.tensor(valid_indices, device=pred_image.device, dtype=torch.long)
|
| 172 |
+
|
| 173 |
+
# Select ONLY valid embeddings before normalization to avoid 0/0 NaN
|
| 174 |
+
pred_valid_emb = pred_emb[valid_idx_t].float()
|
| 175 |
+
target_valid_emb = target_emb[valid_idx_t].float()
|
| 176 |
|
| 177 |
+
# L2 normalize (safe — zero vectors excluded above)
|
| 178 |
+
pred_valid_emb = F.normalize(pred_valid_emb, dim=1)
|
| 179 |
+
target_valid_emb = F.normalize(target_valid_emb, dim=1)
|
| 180 |
|
| 181 |
+
cosine_sim = (pred_valid_emb * target_valid_emb).sum(dim=1)
|
| 182 |
+
return (1 - cosine_sim).mean()
|
|
|
|
|
|
|
| 183 |
|
| 184 |
def _procedure_crop(
|
| 185 |
self,
|
| 186 |
image: torch.Tensor,
|
| 187 |
procedure: str,
|
| 188 |
) -> torch.Tensor:
|
| 189 |
+
"""Crop image based on procedure for identity comparison."""
|
| 190 |
_, _, h, w = image.shape
|
| 191 |
|
| 192 |
if procedure == "rhinoplasty":
|
|
|
|
| 203 |
|
| 204 |
|
| 205 |
class PerceptualLoss:
|
| 206 |
+
"""LPIPS perceptual loss on regions OUTSIDE surgical mask only.
|
| 207 |
+
|
| 208 |
+
LPIPS expects [-1, 1] input. VAE outputs [0, 1].
|
| 209 |
+
Must apply (x * 2) - 1 before every call.
|
| 210 |
+
"""
|
| 211 |
|
| 212 |
def __init__(self):
|
| 213 |
self._lpips = None
|
|
|
|
| 234 |
# Invert mask: we want loss OUTSIDE surgical region
|
| 235 |
outside_mask = 1 - mask
|
| 236 |
|
| 237 |
+
# Erode outside_mask to exclude boundary pixels — avoids artificial
|
| 238 |
+
# edge features where masked (0) meets unmasked (non-zero) values
|
| 239 |
erode_kernel = 5
|
| 240 |
if outside_mask.shape[-1] >= erode_kernel and outside_mask.shape[-2] >= erode_kernel:
|
| 241 |
outside_mask = -F.max_pool2d(
|
|
|
|
| 261 |
|
| 262 |
|
| 263 |
class CombinedLoss:
|
| 264 |
+
"""Combined 4-term loss with configurable weights.
|
| 265 |
+
|
| 266 |
+
Use phase='A' for Phase A training (diffusion only).
|
| 267 |
+
Use phase='B' for Phase B training (all terms).
|
| 268 |
+
|
| 269 |
+
For Phase B, set ``use_differentiable_arcface=True`` to use the
|
| 270 |
+
PyTorch-native ArcFace backbone (``arcface_torch.py``) that provides
|
| 271 |
+
actual gradient signal. The default ONNX-based IdentityLoss produces
|
| 272 |
+
zero gradients (DA2-03).
|
| 273 |
+
"""
|
| 274 |
|
| 275 |
def __init__(
|
| 276 |
self,
|
| 277 |
weights: LossWeights | None = None,
|
| 278 |
phase: str = "A",
|
| 279 |
+
use_differentiable_arcface: bool = False,
|
| 280 |
+
arcface_weights_path: str | None = None,
|
| 281 |
):
|
| 282 |
self.weights = weights or LossWeights()
|
| 283 |
self.phase = phase
|
| 284 |
self.diffusion_loss = DiffusionLoss()
|
| 285 |
self.landmark_loss = LandmarkLoss()
|
|
|
|
| 286 |
self.perceptual_loss = PerceptualLoss()
|
| 287 |
|
| 288 |
+
# Identity loss: differentiable PyTorch ArcFace for Phase B,
|
| 289 |
+
# or ONNX-based fallback
|
| 290 |
+
if use_differentiable_arcface:
|
| 291 |
+
from landmarkdiff.arcface_torch import ArcFaceLoss
|
| 292 |
+
self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)
|
| 293 |
+
else:
|
| 294 |
+
self.identity_loss = IdentityLoss()
|
| 295 |
+
|
| 296 |
def __call__(
|
| 297 |
self,
|
| 298 |
noise_pred: torch.Tensor,
|