dreamlessx committed on
Commit
1dc5f2f
·
verified ·
1 Parent(s): 21efdcf

Update landmarkdiff/arcface_torch.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/arcface_torch.py +58 -193
landmarkdiff/arcface_torch.py CHANGED
@@ -6,17 +6,17 @@ means the identity loss term contributes zero gradients during Phase B training.
6
  This module provides a fully differentiable path so that gradients flow back
7
  through the predicted image into the ControlNet.
8
 
9
- Architecture: IResNet-50 (the standard ArcFace backbone from InsightFace).
10
- conv1(3->64, 3x3) -> BN -> PReLU ->
11
- 4 IResNet blocks [3, 4, 14, 3] with channels [64, 128, 256, 512] ->
12
- BN -> Dropout -> Flatten -> FC(512*7*7 -> 512) -> BN (no bias)
13
- -> L2-normalize
14
 
15
- Each IBasicBlock: conv3x3-BN-PReLU-conv3x3-BN + SE attention + residual.
 
16
 
17
- Pretrained weights: InsightFace distributes IResNet-50 as a PyTorch .pth
18
- (backbone.pth inside the buffalo_l model pack). This module can load those
19
- weights directly, or fall back to random initialization with a warning.
20
 
21
  Usage in losses.py:
22
  from landmarkdiff.arcface_torch import ArcFaceLoss
@@ -41,35 +41,12 @@ logger = logging.getLogger(__name__)
41
  # Building blocks
42
  # ---------------------------------------------------------------------------
43
 
44
-
45
- class SEModule(nn.Module):
46
- """Squeeze-and-Excitation channel attention (Hu et al., 2018).
47
-
48
- Reduces channels by ``reduction``, applies ReLU, expands back, and uses
49
- sigmoid gating on the original feature map.
50
- """
51
-
52
- def __init__(self, channels: int, reduction: int = 4):
53
- super().__init__()
54
- mid = channels // reduction
55
- self.avg_pool = nn.AdaptiveAvgPool2d(1)
56
- self.fc1 = nn.Conv2d(channels, mid, kernel_size=1, bias=True)
57
- self.relu = nn.ReLU(inplace=True)
58
- self.fc2 = nn.Conv2d(mid, channels, kernel_size=1, bias=True)
59
- self.sigmoid = nn.Sigmoid()
60
-
61
- def forward(self, x: torch.Tensor) -> torch.Tensor:
62
- w = self.avg_pool(x)
63
- w = self.relu(self.fc1(w))
64
- w = self.sigmoid(self.fc2(w))
65
- return x * w
66
-
67
-
68
  class IBasicBlock(nn.Module):
69
  """Improved basic residual block for IResNet.
70
 
71
- Structure: BN -> conv3x3 -> BN -> PReLU -> conv3x3 -> BN -> SE -> + residual
72
- Uses pre-activation style BatchNorm and includes SE attention.
 
73
  """
74
 
75
  expansion: int = 1
@@ -80,129 +57,85 @@ class IBasicBlock(nn.Module):
80
  planes: int,
81
  stride: int = 1,
82
  downsample: nn.Module | None = None,
83
- use_se: bool = True,
84
  ):
85
  super().__init__()
86
- self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-5)
87
  self.conv1 = nn.Conv2d(
88
- inplanes,
89
- planes,
90
- kernel_size=3,
91
- stride=1,
92
- padding=1,
93
- bias=False,
94
  )
95
- self.bn2 = nn.BatchNorm2d(planes, eps=1e-5)
96
  self.prelu = nn.PReLU(planes)
97
  self.conv2 = nn.Conv2d(
98
- planes,
99
- planes,
100
- kernel_size=3,
101
- stride=stride,
102
- padding=1,
103
- bias=False,
104
  )
105
- self.bn3 = nn.BatchNorm2d(planes, eps=1e-5)
106
-
107
- self.se_module = SEModule(planes) if use_se else nn.Identity()
108
  self.downsample = downsample
109
- self.stride = stride
110
 
111
  def forward(self, x: torch.Tensor) -> torch.Tensor:
112
  identity = x
113
-
114
  out = self.bn1(x)
115
  out = self.conv1(out)
116
- out = self.bn2(out)
117
  out = self.prelu(out)
118
  out = self.conv2(out)
119
- out = self.bn3(out)
120
- out = self.se_module(out)
121
-
122
  if self.downsample is not None:
123
  identity = self.downsample(x)
124
-
125
- out = out + identity
126
- return out
127
 
128
 
129
  # ---------------------------------------------------------------------------
130
  # Backbone
131
  # ---------------------------------------------------------------------------
132
 
133
-
134
  class ArcFaceBackbone(nn.Module):
135
  """IResNet-50 backbone for ArcFace identity embeddings.
136
 
137
  Input: (B, 3, 112, 112) face crops normalized to [-1, 1].
138
  Output: (B, 512) L2-normalized embeddings.
139
 
140
- Architecture follows the InsightFace IResNet-50 exactly so that
141
- pretrained weights can be loaded without key remapping.
142
  """
143
 
144
  def __init__(
145
  self,
146
  layers: tuple[int, ...] = (3, 4, 14, 3),
147
- dropout_rate: float = 0.0,
148
  embedding_dim: int = 512,
149
- use_se: bool = True,
150
  ):
151
  super().__init__()
152
  self.inplanes = 64
153
- self.use_se = use_se
154
 
155
- # Stem: conv1 -> BN -> PReLU
156
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
157
- self.bn1 = nn.BatchNorm2d(64, eps=1e-5)
158
  self.prelu = nn.PReLU(64)
159
 
160
  # 4 residual stages
161
- self.layer1 = self._make_layer(IBasicBlock, 64, layers[0], stride=2)
162
- self.layer2 = self._make_layer(IBasicBlock, 128, layers[1], stride=2)
163
- self.layer3 = self._make_layer(IBasicBlock, 256, layers[2], stride=2)
164
- self.layer4 = self._make_layer(IBasicBlock, 512, layers[3], stride=2)
165
-
166
- # Head: BN -> Dropout -> Flatten -> FC -> BN
167
- self.bn2 = nn.BatchNorm2d(512 * IBasicBlock.expansion, eps=1e-5)
168
- self.dropout = nn.Dropout(p=dropout_rate, inplace=True)
169
- self.fc = nn.Linear(512 * IBasicBlock.expansion * 7 * 7, embedding_dim)
170
- self.features = nn.BatchNorm1d(embedding_dim, eps=1e-5)
171
- # InsightFace convention: final BN has no bias
172
- nn.init.constant_(self.features.weight, 1.0)
173
- self.features.bias.requires_grad_(False)
174
 
175
  # Weight initialization
176
  self._initialize_weights()
177
 
178
  def _make_layer(
179
  self,
180
- block: type[IBasicBlock],
181
  planes: int,
182
  num_blocks: int,
183
  stride: int = 1,
184
  ) -> nn.Sequential:
185
  downsample = None
186
- if stride != 1 or self.inplanes != planes * block.expansion:
187
- downsample = nn.Sequential(
188
- nn.Conv2d(
189
- self.inplanes,
190
- planes * block.expansion,
191
- kernel_size=1,
192
- stride=stride,
193
- bias=False,
194
- ),
195
- nn.BatchNorm2d(planes * block.expansion, eps=1e-5),
196
  )
197
 
198
- layers = [
199
- block(self.inplanes, planes, stride, downsample, use_se=self.use_se),
200
- ]
201
- self.inplanes = planes * block.expansion
202
  for _ in range(1, num_blocks):
203
- layers.append(
204
- block(self.inplanes, planes, stride=1, use_se=self.use_se),
205
- )
206
 
207
  return nn.Sequential(*layers)
208
 
@@ -230,7 +163,6 @@ class ArcFaceBackbone(nn.Module):
230
  (B, 512) L2-normalized embeddings.
231
  """
232
  x = self.conv1(x)
233
- x = self.bn1(x)
234
  x = self.prelu(x)
235
 
236
  x = self.layer1(x)
@@ -239,7 +171,6 @@ class ArcFaceBackbone(nn.Module):
239
  x = self.layer4(x)
240
 
241
  x = self.bn2(x)
242
- x = self.dropout(x)
243
  x = torch.flatten(x, 1)
244
  x = self.fc(x)
245
  x = self.features(x)
@@ -253,60 +184,34 @@ class ArcFaceBackbone(nn.Module):
253
  # Pretrained weight loading
254
  # ---------------------------------------------------------------------------
255
 
256
- # Known locations where InsightFace buffalo_l backbone.pth may live
257
  _KNOWN_WEIGHT_PATHS = [
258
- Path.home() / ".insightface" / "models" / "buffalo_l" / "w600k_r50.onnx",
259
- Path.home() / ".insightface" / "models" / "buffalo_l" / "backbone.pth",
260
- # Common manual download location
261
  Path.home() / ".cache" / "arcface" / "backbone.pth",
 
262
  ]
263
 
264
- # Glint360K R50 weights URL (InsightFace official release)
265
- _WEIGHT_URL = (
266
- "https://github.com/deepinsight/insightface/releases/download/"
267
- "v0.7/glint360k_cosface_r50_fp16_0.1-backbone.pth"
268
- )
269
-
270
 
271
  def _find_pretrained_weights() -> Path | None:
272
  """Search known locations for pretrained IResNet-50 weights."""
273
  for p in _KNOWN_WEIGHT_PATHS:
274
- if p.exists() and p.suffix == ".pth":
275
  return p
276
  return None
277
 
278
 
279
- def _try_download_weights(dest: Path) -> bool:
280
- """Attempt to download pretrained weights from the InsightFace release."""
281
- try:
282
- import urllib.request
283
-
284
- dest.parent.mkdir(parents=True, exist_ok=True)
285
- logger.info("Downloading ArcFace IResNet-50 weights from %s ...", _WEIGHT_URL)
286
- urllib.request.urlretrieve(_WEIGHT_URL, str(dest))
287
- logger.info("Downloaded to %s", dest)
288
- return True
289
- except Exception as e:
290
- logger.warning("Failed to download ArcFace weights: %s", e)
291
- return False
292
-
293
-
294
  def load_pretrained_weights(
295
  model: ArcFaceBackbone,
296
  weights_path: str | None = None,
297
- download: bool = True,
298
  ) -> bool:
299
  """Load pretrained InsightFace IResNet-50 weights into the model.
300
 
301
- InsightFace distributes backbone weights as PyTorch state dicts. The key
302
- names match our module structure exactly (both follow the IResNet
303
- convention), so no key remapping is needed in most cases.
304
 
305
  Args:
306
  model: An ``ArcFaceBackbone`` instance.
307
  weights_path: Explicit path to a ``.pth`` file. If ``None``, searches
308
- known locations and optionally downloads.
309
- download: Whether to attempt downloading if no local weights found.
310
 
311
  Returns:
312
  ``True`` if weights were loaded successfully, ``False`` otherwise
@@ -323,11 +228,6 @@ def load_pretrained_weights(
323
  if path is None:
324
  path = _find_pretrained_weights()
325
 
326
- if path is None and download:
327
- dest = Path.home() / ".cache" / "arcface" / "backbone.pth"
328
- if _try_download_weights(dest):
329
- path = dest
330
-
331
  if path is None:
332
  warnings.warn(
333
  "No pretrained ArcFace weights found. The model will use random "
@@ -346,7 +246,7 @@ def load_pretrained_weights(
346
  if "state_dict" in state_dict:
347
  state_dict = state_dict["state_dict"]
348
 
349
- # Try direct load first (InsightFace uses the same key names)
350
  try:
351
  model.load_state_dict(state_dict, strict=True)
352
  logger.info("Loaded ArcFace weights (strict match)")
@@ -354,15 +254,12 @@ def load_pretrained_weights(
354
  except RuntimeError:
355
  pass
356
 
357
- # Try non-strict load (some checkpoints have extra keys like the
358
- # classification head 'fc_angular.*' or use 'output_layer' instead
359
- # of 'features' for the final BN)
360
  try:
361
  # Remap common differences
362
  remapped = {}
363
  for k, v in state_dict.items():
364
  new_k = k
365
- # Some checkpoints use 'output_layer' for the final BatchNorm1d
366
  if k.startswith("output_layer."):
367
  new_k = k.replace("output_layer.", "features.")
368
  remapped[new_k] = v
@@ -370,8 +267,7 @@ def load_pretrained_weights(
370
  missing, unexpected = model.load_state_dict(remapped, strict=False)
371
  if missing:
372
  logger.warning(
373
- "Missing keys when loading ArcFace weights (may be OK if only "
374
- "classification head keys): %s",
375
  missing[:10],
376
  )
377
  if unexpected:
@@ -380,7 +276,8 @@ def load_pretrained_weights(
380
  return True
381
  except Exception as e:
382
  warnings.warn(
383
- f"Failed to load ArcFace weights from {path}: {e}. Using random initialization.",
 
384
  UserWarning,
385
  stacklevel=2,
386
  )
@@ -391,7 +288,6 @@ def load_pretrained_weights(
391
  # Differentiable face alignment
392
  # ---------------------------------------------------------------------------
393
 
394
-
395
  def align_face(
396
  images: torch.Tensor,
397
  size: int = 112,
@@ -414,29 +310,21 @@ def align_face(
414
  """
415
  B, C, H, W = images.shape
416
 
417
- if size == H and size == W:
418
  return images
419
 
420
  # Crop fraction: keep central 80% to remove background padding
421
  crop_frac = 0.8
422
 
423
  # Build a normalized grid [-1, 1] covering the center crop region
424
- # The grid maps output pixels to input pixel locations
425
  half_crop = crop_frac / 2.0
426
- # grid_sample expects coordinates in [-1, 1] where -1 is top-left, +1 is bottom-right
427
- # Center crop: map [-1, 1] output range to [-crop_frac, +crop_frac] input range
428
  theta = torch.zeros(B, 2, 3, device=images.device, dtype=images.dtype)
429
- theta[:, 0, 0] = half_crop # x scale
430
- theta[:, 1, 1] = half_crop # y scale
431
- # translation stays 0 (centered)
432
 
433
  grid = F.affine_grid(theta, [B, C, size, size], align_corners=False)
434
  aligned = F.grid_sample(
435
- images,
436
- grid,
437
- mode="bilinear",
438
- padding_mode="border",
439
- align_corners=False,
440
  )
441
  return aligned
442
 
@@ -447,7 +335,7 @@ def align_face_no_crop(
447
  ) -> torch.Tensor:
448
  """Resize face images to (size x size) without cropping, differentiably.
449
 
450
- Simple bilinear resize using ``F.grid_sample`` for gradient flow. Use
451
  this when images are already tightly cropped faces.
452
 
453
  Args:
@@ -460,10 +348,7 @@ def align_face_no_crop(
460
  if images.shape[-2] == size and images.shape[-1] == size:
461
  return images
462
  return F.interpolate(
463
- images,
464
- size=(size, size),
465
- mode="bilinear",
466
- align_corners=False,
467
  )
468
 
469
 
@@ -471,7 +356,6 @@ def align_face_no_crop(
471
  # ArcFaceLoss: differentiable identity preservation loss
472
  # ---------------------------------------------------------------------------
473
 
474
-
475
  class ArcFaceLoss(nn.Module):
476
  """Differentiable identity loss using PyTorch-native ArcFace.
477
 
@@ -503,7 +387,7 @@ class ArcFaceLoss(nn.Module):
503
  device: Device to place the backbone on. If ``None``, determined
504
  from the first forward call.
505
  weights_path: Path to pretrained backbone.pth. If ``None``,
506
- searches known locations and attempts download.
507
  crop_face: Whether to center-crop images before embedding.
508
  Set ``False`` if images are already 112x112 face crops.
509
  """
@@ -532,7 +416,6 @@ class ArcFaceLoss(nn.Module):
532
  # Move to device and freeze
533
  self.backbone = self.backbone.to(device)
534
  self.backbone.eval()
535
- # Freeze all parameters -- we do NOT want to train ArcFace
536
  for param in self.backbone.parameters():
537
  param.requires_grad_(False)
538
 
@@ -547,7 +430,10 @@ class ArcFaceLoss(nn.Module):
547
  Returns:
548
  (B, 3, 112, 112) in [-1, 1].
549
  """
550
- x = align_face(images, size=112) if self.crop_face else align_face_no_crop(images, size=112)
 
 
 
551
 
552
  # Normalize from [0, 1] to [-1, 1]
553
  x = x * 2.0 - 1.0
@@ -560,10 +446,6 @@ class ArcFaceLoss(nn.Module):
560
  ) -> torch.Tensor:
561
  """Extract ArcFace embeddings.
562
 
563
- The backbone is in eval mode with frozen parameters, but when
564
- ``enable_grad=True`` we allow gradient computation through the
565
- forward pass (important for the predicted images).
566
-
567
  Args:
568
  images: (B, 3, 112, 112) in [-1, 1].
569
  enable_grad: If ``True``, gradients flow through the backbone's
@@ -573,12 +455,6 @@ class ArcFaceLoss(nn.Module):
573
  (B, 512) L2-normalized embeddings.
574
  """
575
  if enable_grad:
576
- # Gradients flow through the backbone forward pass so that
577
- # the generator receives gradient signal from the identity loss.
578
- # NOTE: backbone parameters are frozen (requires_grad=False), so
579
- # only the input tensor carries gradients, which is exactly what
580
- # we want -- gradients w.r.t. the predicted image, not w.r.t.
581
- # ArcFace weights.
582
  return self.backbone(images)
583
  else:
584
  with torch.no_grad():
@@ -619,16 +495,13 @@ class ArcFaceLoss(nn.Module):
619
  target_prepared = self._prepare_images(target_crop)
620
 
621
  # Extract embeddings
622
- # pred: WITH gradient flow (so generator gets identity signal)
623
  pred_emb = self._extract_embedding(pred_prepared, enable_grad=True)
624
- # target: WITHOUT gradient flow (no need to backprop through target)
625
  target_emb = self._extract_embedding(target_prepared, enable_grad=False)
626
 
627
  # Detach target to be absolutely sure no gradients leak
628
  target_emb = target_emb.detach()
629
 
630
  # Cosine similarity loss: 1 - cos_sim
631
- # Both embeddings are already L2-normalized by the backbone
632
  cosine_sim = (pred_emb * target_emb).sum(dim=1) # (B,)
633
 
634
  # Clamp to valid range (numerical safety for BF16)
@@ -642,21 +515,14 @@ class ArcFaceLoss(nn.Module):
642
  image: torch.Tensor,
643
  procedure: str,
644
  ) -> torch.Tensor:
645
- """Crop image based on surgical procedure for identity comparison.
646
-
647
- Matches the cropping logic from the original ``IdentityLoss`` in
648
- ``losses.py`` for consistency.
649
- """
650
  _, _, h, w = image.shape
651
 
652
  if procedure == "rhinoplasty":
653
- # Upper face crop (forehead to nose tip) -- exclude surgical region
654
  return image[:, :, : h * 2 // 3, :]
655
  elif procedure == "blepharoplasty":
656
- # Full face
657
  return image
658
  elif procedure == "rhytidectomy":
659
- # Upper face (above jawline)
660
  return image[:, :, : h * 3 // 4, :]
661
  else:
662
  return image
@@ -679,7 +545,6 @@ class ArcFaceLoss(nn.Module):
679
  # Convenience: create a pre-configured loss instance
680
  # ---------------------------------------------------------------------------
681
 
682
-
683
  def create_arcface_loss(
684
  device: torch.device | None = None,
685
  weights_path: str | None = None,
 
6
  This module provides a fully differentiable path so that gradients flow back
7
  through the predicted image into the ControlNet.
8
 
9
+ Architecture: IResNet-50 matching the InsightFace w600k_r50 ONNX model.
10
+ conv1(3->64, 3x3, bias) -> PReLU ->
11
+ 4 IResNet stages [3, 4, 14, 3] with channels [64, 128, 256, 512] ->
12
+ BN2d -> Flatten -> FC(512*7*7 -> 512) -> BN1d -> L2-normalize
 
13
 
14
+ Each IBasicBlock: BN -> conv3x3(bias) -> PReLU -> conv3x3(bias) + residual.
15
+ No SE module. Convolutions use bias=True.
16
 
17
+ Pretrained weights: converted from the InsightFace buffalo_l w600k_r50.onnx
18
+ model to a PyTorch state dict (backbone.pth). The conversion extracts weights
19
+ from the ONNX graph and maps them to matching PyTorch module keys.
20
 
21
  Usage in losses.py:
22
  from landmarkdiff.arcface_torch import ArcFaceLoss
 
41
  # Building blocks
42
  # ---------------------------------------------------------------------------
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  class IBasicBlock(nn.Module):
45
  """Improved basic residual block for IResNet.
46
 
47
+ Structure: BN -> conv3x3(bias) -> PReLU -> conv3x3(bias) -> + residual
48
+ Uses pre-activation style BatchNorm. Convolutions have bias=True to match
49
+ the InsightFace w600k_r50 ONNX weights.
50
  """
51
 
52
  expansion: int = 1
 
57
  planes: int,
58
  stride: int = 1,
59
  downsample: nn.Module | None = None,
 
60
  ):
61
  super().__init__()
62
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=2e-5, momentum=0.1)
63
  self.conv1 = nn.Conv2d(
64
+ inplanes, planes, kernel_size=3, stride=1, padding=1, bias=True,
 
 
 
 
 
65
  )
 
66
  self.prelu = nn.PReLU(planes)
67
  self.conv2 = nn.Conv2d(
68
+ planes, planes, kernel_size=3, stride=stride, padding=1, bias=True,
 
 
 
 
 
69
  )
 
 
 
70
  self.downsample = downsample
 
71
 
72
  def forward(self, x: torch.Tensor) -> torch.Tensor:
73
  identity = x
 
74
  out = self.bn1(x)
75
  out = self.conv1(out)
 
76
  out = self.prelu(out)
77
  out = self.conv2(out)
 
 
 
78
  if self.downsample is not None:
79
  identity = self.downsample(x)
80
+ return out + identity
 
 
81
 
82
 
83
  # ---------------------------------------------------------------------------
84
  # Backbone
85
  # ---------------------------------------------------------------------------
86
 
 
87
  class ArcFaceBackbone(nn.Module):
88
  """IResNet-50 backbone for ArcFace identity embeddings.
89
 
90
  Input: (B, 3, 112, 112) face crops normalized to [-1, 1].
91
  Output: (B, 512) L2-normalized embeddings.
92
 
93
+ Architecture matches the InsightFace w600k_r50 ONNX model exactly:
94
+ Conv(bias) -> PReLU -> 4 stages -> BN2d -> Flatten -> FC -> BN1d -> L2norm.
95
  """
96
 
97
  def __init__(
98
  self,
99
  layers: tuple[int, ...] = (3, 4, 14, 3),
 
100
  embedding_dim: int = 512,
 
101
  ):
102
  super().__init__()
103
  self.inplanes = 64
 
104
 
105
+ # Stem: conv1(bias) -> PReLU (no BN in stem)
106
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True)
 
107
  self.prelu = nn.PReLU(64)
108
 
109
  # 4 residual stages
110
+ self.layer1 = self._make_layer(64, layers[0], stride=2)
111
+ self.layer2 = self._make_layer(128, layers[1], stride=2)
112
+ self.layer3 = self._make_layer(256, layers[2], stride=2)
113
+ self.layer4 = self._make_layer(512, layers[3], stride=2)
114
+
115
+ # Head: BN2d -> Flatten -> FC -> BN1d
116
+ self.bn2 = nn.BatchNorm2d(512, eps=2e-5, momentum=0.1)
117
+ self.fc = nn.Linear(512 * 7 * 7, embedding_dim)
118
+ self.features = nn.BatchNorm1d(embedding_dim, eps=2e-5, momentum=0.1)
 
 
 
 
119
 
120
  # Weight initialization
121
  self._initialize_weights()
122
 
123
  def _make_layer(
124
  self,
 
125
  planes: int,
126
  num_blocks: int,
127
  stride: int = 1,
128
  ) -> nn.Sequential:
129
  downsample = None
130
+ if stride != 1 or self.inplanes != planes:
131
+ downsample = nn.Conv2d(
132
+ self.inplanes, planes, kernel_size=1, stride=stride, bias=True,
 
 
 
 
 
 
 
133
  )
134
 
135
+ layers = [IBasicBlock(self.inplanes, planes, stride, downsample)]
136
+ self.inplanes = planes
 
 
137
  for _ in range(1, num_blocks):
138
+ layers.append(IBasicBlock(self.inplanes, planes))
 
 
139
 
140
  return nn.Sequential(*layers)
141
 
 
163
  (B, 512) L2-normalized embeddings.
164
  """
165
  x = self.conv1(x)
 
166
  x = self.prelu(x)
167
 
168
  x = self.layer1(x)
 
171
  x = self.layer4(x)
172
 
173
  x = self.bn2(x)
 
174
  x = torch.flatten(x, 1)
175
  x = self.fc(x)
176
  x = self.features(x)
 
184
  # Pretrained weight loading
185
  # ---------------------------------------------------------------------------
186
 
187
+ # Known locations where converted backbone.pth may live
188
  _KNOWN_WEIGHT_PATHS = [
 
 
 
189
  Path.home() / ".cache" / "arcface" / "backbone.pth",
190
+ Path.home() / ".insightface" / "models" / "buffalo_l" / "backbone.pth",
191
  ]
192
 
 
 
 
 
 
 
193
 
194
  def _find_pretrained_weights() -> Path | None:
195
  """Search known locations for pretrained IResNet-50 weights."""
196
  for p in _KNOWN_WEIGHT_PATHS:
197
+ if p.exists() and p.suffix == ".pth" and p.stat().st_size > 0:
198
  return p
199
  return None
200
 
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  def load_pretrained_weights(
203
  model: ArcFaceBackbone,
204
  weights_path: str | None = None,
 
205
  ) -> bool:
206
  """Load pretrained InsightFace IResNet-50 weights into the model.
207
 
208
+ Weights are a PyTorch state dict converted from the InsightFace
209
+ w600k_r50.onnx model. Key names match our module structure exactly.
 
210
 
211
  Args:
212
  model: An ``ArcFaceBackbone`` instance.
213
  weights_path: Explicit path to a ``.pth`` file. If ``None``, searches
214
+ known locations.
 
215
 
216
  Returns:
217
  ``True`` if weights were loaded successfully, ``False`` otherwise
 
228
  if path is None:
229
  path = _find_pretrained_weights()
230
 
 
 
 
 
 
231
  if path is None:
232
  warnings.warn(
233
  "No pretrained ArcFace weights found. The model will use random "
 
246
  if "state_dict" in state_dict:
247
  state_dict = state_dict["state_dict"]
248
 
249
+ # Try direct load first
250
  try:
251
  model.load_state_dict(state_dict, strict=True)
252
  logger.info("Loaded ArcFace weights (strict match)")
 
254
  except RuntimeError:
255
  pass
256
 
257
+ # Try non-strict load (some checkpoints may have extra keys)
 
 
258
  try:
259
  # Remap common differences
260
  remapped = {}
261
  for k, v in state_dict.items():
262
  new_k = k
 
263
  if k.startswith("output_layer."):
264
  new_k = k.replace("output_layer.", "features.")
265
  remapped[new_k] = v
 
267
  missing, unexpected = model.load_state_dict(remapped, strict=False)
268
  if missing:
269
  logger.warning(
270
+ "Missing keys when loading ArcFace weights: %s",
 
271
  missing[:10],
272
  )
273
  if unexpected:
 
276
  return True
277
  except Exception as e:
278
  warnings.warn(
279
+ f"Failed to load ArcFace weights from {path}: {e}. "
280
+ "Using random initialization.",
281
  UserWarning,
282
  stacklevel=2,
283
  )
 
288
  # Differentiable face alignment
289
  # ---------------------------------------------------------------------------
290
 
 
291
  def align_face(
292
  images: torch.Tensor,
293
  size: int = 112,
 
310
  """
311
  B, C, H, W = images.shape
312
 
313
+ if H == size and W == size:
314
  return images
315
 
316
  # Crop fraction: keep central 80% to remove background padding
317
  crop_frac = 0.8
318
 
319
  # Build a normalized grid [-1, 1] covering the center crop region
 
320
  half_crop = crop_frac / 2.0
 
 
321
  theta = torch.zeros(B, 2, 3, device=images.device, dtype=images.dtype)
322
+ theta[:, 0, 0] = half_crop # x scale
323
+ theta[:, 1, 1] = half_crop # y scale
 
324
 
325
  grid = F.affine_grid(theta, [B, C, size, size], align_corners=False)
326
  aligned = F.grid_sample(
327
+ images, grid, mode="bilinear", padding_mode="border", align_corners=False,
 
 
 
 
328
  )
329
  return aligned
330
 
 
335
  ) -> torch.Tensor:
336
  """Resize face images to (size x size) without cropping, differentiably.
337
 
338
+ Simple bilinear resize using ``F.interpolate`` for gradient flow. Use
339
  this when images are already tightly cropped faces.
340
 
341
  Args:
 
348
  if images.shape[-2] == size and images.shape[-1] == size:
349
  return images
350
  return F.interpolate(
351
+ images, size=(size, size), mode="bilinear", align_corners=False,
 
 
 
352
  )
353
 
354
 
 
356
  # ArcFaceLoss: differentiable identity preservation loss
357
  # ---------------------------------------------------------------------------
358
 
 
359
  class ArcFaceLoss(nn.Module):
360
  """Differentiable identity loss using PyTorch-native ArcFace.
361
 
 
387
  device: Device to place the backbone on. If ``None``, determined
388
  from the first forward call.
389
  weights_path: Path to pretrained backbone.pth. If ``None``,
390
+ searches known locations.
391
  crop_face: Whether to center-crop images before embedding.
392
  Set ``False`` if images are already 112x112 face crops.
393
  """
 
416
  # Move to device and freeze
417
  self.backbone = self.backbone.to(device)
418
  self.backbone.eval()
 
419
  for param in self.backbone.parameters():
420
  param.requires_grad_(False)
421
 
 
430
  Returns:
431
  (B, 3, 112, 112) in [-1, 1].
432
  """
433
+ if self.crop_face:
434
+ x = align_face(images, size=112)
435
+ else:
436
+ x = align_face_no_crop(images, size=112)
437
 
438
  # Normalize from [0, 1] to [-1, 1]
439
  x = x * 2.0 - 1.0
 
446
  ) -> torch.Tensor:
447
  """Extract ArcFace embeddings.
448
 
 
 
 
 
449
  Args:
450
  images: (B, 3, 112, 112) in [-1, 1].
451
  enable_grad: If ``True``, gradients flow through the backbone's
 
455
  (B, 512) L2-normalized embeddings.
456
  """
457
  if enable_grad:
 
 
 
 
 
 
458
  return self.backbone(images)
459
  else:
460
  with torch.no_grad():
 
495
  target_prepared = self._prepare_images(target_crop)
496
 
497
  # Extract embeddings
 
498
  pred_emb = self._extract_embedding(pred_prepared, enable_grad=True)
 
499
  target_emb = self._extract_embedding(target_prepared, enable_grad=False)
500
 
501
  # Detach target to be absolutely sure no gradients leak
502
  target_emb = target_emb.detach()
503
 
504
  # Cosine similarity loss: 1 - cos_sim
 
505
  cosine_sim = (pred_emb * target_emb).sum(dim=1) # (B,)
506
 
507
  # Clamp to valid range (numerical safety for BF16)
 
515
  image: torch.Tensor,
516
  procedure: str,
517
  ) -> torch.Tensor:
518
+ """Crop image based on surgical procedure for identity comparison."""
 
 
 
 
519
  _, _, h, w = image.shape
520
 
521
  if procedure == "rhinoplasty":
 
522
  return image[:, :, : h * 2 // 3, :]
523
  elif procedure == "blepharoplasty":
 
524
  return image
525
  elif procedure == "rhytidectomy":
 
526
  return image[:, :, : h * 3 // 4, :]
527
  else:
528
  return image
 
545
  # Convenience: create a pre-configured loss instance
546
  # ---------------------------------------------------------------------------
547
 
 
548
  def create_arcface_loss(
549
  device: torch.device | None = None,
550
  weights_path: str | None = None,