dreamlessx commited on
Commit
59c75b7
·
verified ·
1 Parent(s): 6fbb3ae

Upload landmarkdiff/evaluation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/evaluation.py +348 -0
landmarkdiff/evaluation.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Evaluation metrics: FID, LPIPS, NME, ArcFace sim, SSIM.
2
+
3
+ Stratified by Fitzpatrick skin type (I-VI) via ITA thresholding.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from dataclasses import dataclass, field
9
+
10
+ import numpy as np
11
+
12
+ try:
13
+ import cv2
14
+ except ImportError:
15
+ cv2 = None # type: ignore[assignment]
16
+
17
+
18
+ @dataclass
19
+ class EvalMetrics:
20
+ """Computed evaluation metrics for a batch of generated images."""
21
+
22
+ fid: float = 0.0
23
+ lpips: float = 0.0
24
+ nme: float = 0.0 # Normalized Mean landmark Error
25
+ identity_sim: float = 0.0 # ArcFace cosine similarity
26
+ ssim: float = 0.0
27
+
28
+ # Per-Fitzpatrick breakdown (all metrics stratified)
29
+ fid_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
30
+ nme_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
31
+ lpips_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
32
+ ssim_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
33
+ identity_sim_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
34
+ count_by_fitzpatrick: dict[str, int] = field(default_factory=dict)
35
+
36
+ # Per-procedure breakdown
37
+ nme_by_procedure: dict[str, float] = field(default_factory=dict)
38
+ lpips_by_procedure: dict[str, float] = field(default_factory=dict)
39
+ ssim_by_procedure: dict[str, float] = field(default_factory=dict)
40
+
41
+ def summary(self) -> str:
42
+ lines = [
43
+ f"FID: {self.fid:.2f}",
44
+ f"LPIPS: {self.lpips:.4f}",
45
+ f"NME: {self.nme:.4f}",
46
+ f"Identity Sim: {self.identity_sim:.4f}",
47
+ f"SSIM: {self.ssim:.4f}",
48
+ ]
49
+ if self.count_by_fitzpatrick:
50
+ lines.append("\nBy Fitzpatrick Type:")
51
+ for ftype in sorted(self.count_by_fitzpatrick):
52
+ n = self.count_by_fitzpatrick[ftype]
53
+ parts = [f" Type {ftype} (n={n}):"]
54
+ if ftype in self.lpips_by_fitzpatrick:
55
+ parts.append(f"LPIPS={self.lpips_by_fitzpatrick[ftype]:.4f}")
56
+ if ftype in self.ssim_by_fitzpatrick:
57
+ parts.append(f"SSIM={self.ssim_by_fitzpatrick[ftype]:.4f}")
58
+ if ftype in self.nme_by_fitzpatrick:
59
+ parts.append(f"NME={self.nme_by_fitzpatrick[ftype]:.4f}")
60
+ if ftype in self.identity_sim_by_fitzpatrick:
61
+ parts.append(f"ID={self.identity_sim_by_fitzpatrick[ftype]:.4f}")
62
+ lines.append(" ".join(parts))
63
+ if self.fid_by_fitzpatrick:
64
+ lines.append("\nFID by Fitzpatrick:")
65
+ for k, v in sorted(self.fid_by_fitzpatrick.items()):
66
+ lines.append(f" Type {k}: {v:.2f}")
67
+ return "\n".join(lines)
68
+
69
+ def to_dict(self) -> dict:
70
+ """Convert to flat dictionary for JSON/CSV export."""
71
+ d = {
72
+ "fid": self.fid,
73
+ "lpips": self.lpips,
74
+ "nme": self.nme,
75
+ "identity_sim": self.identity_sim,
76
+ "ssim": self.ssim,
77
+ }
78
+ for ftype in sorted(self.count_by_fitzpatrick):
79
+ prefix = f"fitz_{ftype}"
80
+ d[f"{prefix}_count"] = self.count_by_fitzpatrick.get(ftype, 0)
81
+ d[f"{prefix}_lpips"] = self.lpips_by_fitzpatrick.get(ftype, 0.0)
82
+ d[f"{prefix}_ssim"] = self.ssim_by_fitzpatrick.get(ftype, 0.0)
83
+ d[f"{prefix}_nme"] = self.nme_by_fitzpatrick.get(ftype, 0.0)
84
+ d[f"{prefix}_identity"] = self.identity_sim_by_fitzpatrick.get(ftype, 0.0)
85
+ for proc in sorted(self.nme_by_procedure):
86
+ d[f"proc_{proc}_nme"] = self.nme_by_procedure.get(proc, 0.0)
87
+ d[f"proc_{proc}_lpips"] = self.lpips_by_procedure.get(proc, 0.0)
88
+ d[f"proc_{proc}_ssim"] = self.ssim_by_procedure.get(proc, 0.0)
89
+ return d
90
+
91
+
92
+ def classify_fitzpatrick_ita(image: np.ndarray) -> str:
93
+ """Fitzpatrick I-VI from ITA angle (Chardon et al. 1991 thresholds)."""
94
+ if cv2 is None:
95
+ raise ImportError("opencv-python is required for Fitzpatrick classification")
96
+ lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
97
+
98
+ # Sample from face center region (avoid background)
99
+ h, w = image.shape[:2]
100
+ center = lab[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4]
101
+
102
+ L_mean = center[:, :, 0].mean() * 100 / 255 # scale to 0-100
103
+ b_mean = center[:, :, 2].mean() - 128 # center around 0
104
+
105
+ if abs(b_mean) < 1e-6:
106
+ b_mean = 1e-6
107
+
108
+ ita = np.arctan2(L_mean - 50, b_mean) * (180 / np.pi)
109
+
110
+ if ita > 55:
111
+ return "I"
112
+ elif ita > 41:
113
+ return "II"
114
+ elif ita > 28:
115
+ return "III"
116
+ elif ita > 10:
117
+ return "IV"
118
+ elif ita > -30:
119
+ return "V"
120
+ else:
121
+ return "VI"
122
+
123
+
124
+ def compute_nme(
125
+ pred_landmarks: np.ndarray,
126
+ target_landmarks: np.ndarray,
127
+ left_eye_idx: int = 33,
128
+ right_eye_idx: int = 263,
129
+ ) -> float:
130
+ """Compute Normalized Mean Error for landmarks."""
131
+ iod = np.linalg.norm(
132
+ target_landmarks[left_eye_idx] - target_landmarks[right_eye_idx]
133
+ )
134
+ if iod < 1.0:
135
+ iod = 1.0
136
+
137
+ distances = np.linalg.norm(pred_landmarks - target_landmarks, axis=1)
138
+ return float(np.mean(distances) / iod)
139
+
140
+
141
+ def compute_ssim(
142
+ pred: np.ndarray,
143
+ target: np.ndarray,
144
+ ) -> float:
145
+ """SSIM via skimage, falls back to global SSIM if not installed."""
146
+ try:
147
+ from skimage.metrics import structural_similarity
148
+ # Convert to grayscale if color, or compute per-channel
149
+ if pred.ndim == 3 and pred.shape[2] == 3:
150
+ return float(structural_similarity(pred, target, channel_axis=2, data_range=255))
151
+ else:
152
+ return float(structural_similarity(pred, target, data_range=255))
153
+ except ImportError:
154
+ # Fallback: simple global SSIM (not publication-quality)
155
+ pred_f = pred.astype(np.float64)
156
+ target_f = target.astype(np.float64)
157
+
158
+ mu_p = np.mean(pred_f)
159
+ mu_t = np.mean(target_f)
160
+ sigma_p = np.std(pred_f)
161
+ sigma_t = np.std(target_f)
162
+ sigma_pt = np.mean((pred_f - mu_p) * (target_f - mu_t))
163
+
164
+ C1 = (0.01 * 255) ** 2
165
+ C2 = (0.03 * 255) ** 2
166
+
167
+ ssim_val = (
168
+ (2 * mu_p * mu_t + C1) * (2 * sigma_pt + C2)
169
+ ) / (
170
+ (mu_p ** 2 + mu_t ** 2 + C1) * (sigma_p ** 2 + sigma_t ** 2 + C2)
171
+ )
172
+ return float(ssim_val)
173
+
174
+
175
+ _LPIPS_FN = None
176
+
177
+
178
+ def _get_lpips_fn():
179
+ """Get or create singleton LPIPS model."""
180
+ global _LPIPS_FN
181
+ if _LPIPS_FN is None:
182
+ import lpips
183
+ _LPIPS_FN = lpips.LPIPS(net="alex", verbose=False)
184
+ _LPIPS_FN.eval()
185
+ return _LPIPS_FN
186
+
187
+
188
+ def compute_lpips(
189
+ pred: np.ndarray,
190
+ target: np.ndarray,
191
+ ) -> float:
192
+ """LPIPS perceptual distance (lower = more similar)."""
193
+ try:
194
+ import lpips
195
+ import torch
196
+ except ImportError:
197
+ return 0.0
198
+
199
+ _lpips_fn = _get_lpips_fn()
200
+
201
+ def _to_tensor(img: np.ndarray) -> torch.Tensor:
202
+ t = torch.from_numpy(img.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0)
203
+ return t * 2 - 1 # LPIPS expects [-1, 1]
204
+
205
+ with torch.no_grad():
206
+ score = _lpips_fn(_to_tensor(pred), _to_tensor(target))
207
+ return float(score.item())
208
+
209
+
210
+ def compute_fid(
211
+ real_dir: str,
212
+ generated_dir: str,
213
+ ) -> float:
214
+ """Compute FID between directories of real and generated images."""
215
+ try:
216
+ from torch_fidelity import calculate_metrics
217
+ except ImportError:
218
+ raise ImportError(
219
+ "torch-fidelity is required for FID. Install with: pip install torch-fidelity"
220
+ )
221
+
222
+ metrics = calculate_metrics(
223
+ input1=generated_dir,
224
+ input2=real_dir,
225
+ cuda=True,
226
+ fid=True,
227
+ verbose=False,
228
+ )
229
+ return float(metrics["frechet_inception_distance"])
230
+
231
+
232
+ def compute_identity_similarity(
233
+ pred: np.ndarray,
234
+ target: np.ndarray,
235
+ ) -> float:
236
+ """ArcFace cosine sim [0,1]. Falls back to SSIM if no InsightFace."""
237
+ try:
238
+ from insightface.app import FaceAnalysis
239
+ app = FaceAnalysis(
240
+ name="buffalo_l",
241
+ providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
242
+ )
243
+ app.prepare(ctx_id=-1, det_size=(320, 320))
244
+
245
+ pred_bgr = pred if pred.shape[2] == 3 else cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
246
+ target_bgr = target if target.shape[2] == 3 else cv2.cvtColor(target, cv2.COLOR_RGB2BGR)
247
+
248
+ pred_faces = app.get(pred_bgr)
249
+ target_faces = app.get(target_bgr)
250
+
251
+ if pred_faces and target_faces:
252
+ pred_emb = pred_faces[0].embedding
253
+ target_emb = target_faces[0].embedding
254
+ sim = np.dot(pred_emb, target_emb) / (
255
+ np.linalg.norm(pred_emb) * np.linalg.norm(target_emb) + 1e-8
256
+ )
257
+ return float(np.clip(sim, 0, 1))
258
+ except Exception:
259
+ pass
260
+
261
+ # Fallback: SSIM-based proxy
262
+ return compute_ssim(pred, target)
263
+
264
+
265
+ def evaluate_batch(
266
+ predictions: list[np.ndarray],
267
+ targets: list[np.ndarray],
268
+ pred_landmarks: list[np.ndarray] | None = None,
269
+ target_landmarks: list[np.ndarray] | None = None,
270
+ procedures: list[str] | None = None,
271
+ compute_identity: bool = False,
272
+ ) -> EvalMetrics:
273
+ """Evaluate a batch of predicted vs target images."""
274
+ n = len(predictions)
275
+ ssim_scores = []
276
+ lpips_scores = []
277
+ nme_scores = []
278
+ identity_scores = []
279
+ fitz_groups: dict[str, list[int]] = {}
280
+ proc_groups: dict[str, list[int]] = {}
281
+
282
+ for i in range(n):
283
+ ssim_scores.append(compute_ssim(predictions[i], targets[i]))
284
+ lpips_scores.append(compute_lpips(predictions[i], targets[i]))
285
+
286
+ if pred_landmarks is not None and target_landmarks is not None:
287
+ nme_scores.append(compute_nme(pred_landmarks[i], target_landmarks[i]))
288
+
289
+ if compute_identity:
290
+ identity_scores.append(compute_identity_similarity(predictions[i], targets[i]))
291
+
292
+ # Fitzpatrick classification
293
+ if cv2 is not None:
294
+ try:
295
+ fitz = classify_fitzpatrick_ita(targets[i])
296
+ fitz_groups.setdefault(fitz, []).append(i)
297
+ except Exception:
298
+ pass
299
+
300
+ # Procedure grouping
301
+ if procedures is not None and i < len(procedures):
302
+ proc_groups.setdefault(procedures[i], []).append(i)
303
+
304
+ metrics = EvalMetrics(
305
+ ssim=float(np.mean(ssim_scores)) if ssim_scores else 0.0,
306
+ lpips=float(np.mean(lpips_scores)) if lpips_scores else 0.0,
307
+ nme=float(np.mean(nme_scores)) if nme_scores else 0.0,
308
+ identity_sim=float(np.mean(identity_scores)) if identity_scores else 0.0,
309
+ )
310
+
311
+ # Full Fitzpatrick stratification for ALL metrics
312
+ for ftype, indices in fitz_groups.items():
313
+ metrics.count_by_fitzpatrick[ftype] = len(indices)
314
+
315
+ group_lpips = [lpips_scores[i] for i in indices]
316
+ if group_lpips:
317
+ metrics.lpips_by_fitzpatrick[ftype] = float(np.mean(group_lpips))
318
+
319
+ group_ssim = [ssim_scores[i] for i in indices]
320
+ if group_ssim:
321
+ metrics.ssim_by_fitzpatrick[ftype] = float(np.mean(group_ssim))
322
+
323
+ if nme_scores:
324
+ group_nme = [nme_scores[i] for i in indices if i < len(nme_scores)]
325
+ if group_nme:
326
+ metrics.nme_by_fitzpatrick[ftype] = float(np.mean(group_nme))
327
+
328
+ if identity_scores:
329
+ group_id = [identity_scores[i] for i in indices if i < len(identity_scores)]
330
+ if group_id:
331
+ metrics.identity_sim_by_fitzpatrick[ftype] = float(np.mean(group_id))
332
+
333
+ # Per-procedure breakdown
334
+ for proc, indices in proc_groups.items():
335
+ group_lpips = [lpips_scores[i] for i in indices]
336
+ if group_lpips:
337
+ metrics.lpips_by_procedure[proc] = float(np.mean(group_lpips))
338
+
339
+ group_ssim = [ssim_scores[i] for i in indices]
340
+ if group_ssim:
341
+ metrics.ssim_by_procedure[proc] = float(np.mean(group_ssim))
342
+
343
+ if nme_scores:
344
+ group_nme = [nme_scores[i] for i in indices if i < len(nme_scores)]
345
+ if group_nme:
346
+ metrics.nme_by_procedure[proc] = float(np.mean(group_nme))
347
+
348
+ return metrics