dreamlessx commited on
Commit
0810860
·
verified ·
1 Parent(s): 7931476

Upload landmarkdiff/inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/inference.py +120 -34
landmarkdiff/inference.py CHANGED
@@ -1,12 +1,12 @@
1
  """Inference pipeline for surgical outcome prediction.
2
 
3
- Modes:
4
- 1. ControlNet: CrucibleAI/ControlNetMediaPipeFace + SD1.5 (HF auth + GPU)
5
- 2. ControlNet + IP-Adapter: ControlNet w/ identity preservation
6
- 3. Img2Img: SD1.5 img2img with mask compositing (MPS ok, no auth)
7
- 4. TPS-only: geometric warp, no diffusion, instant
8
 
9
- Works on MPS (Apple Silicon), CUDA, and CPU.
10
  """
11
 
12
  from __future__ import annotations
@@ -38,7 +38,7 @@ def get_device() -> torch.device:
38
  def numpy_to_pil(arr: np.ndarray) -> Image.Image:
39
  if len(arr.shape) == 2:
40
  return Image.fromarray(arr, mode="L")
41
- return Image.fromarray(arr[:, :, ::-1])
42
 
43
 
44
  def pil_to_numpy(img: Image.Image) -> np.ndarray:
@@ -86,7 +86,12 @@ def mask_composite(
86
  mask: np.ndarray,
87
  use_laplacian: bool = True,
88
  ) -> np.ndarray:
89
- """Blend warped region into original via Laplacian pyramid + LAB skin-tone match."""
 
 
 
 
 
90
  mask_f = mask.astype(np.float32)
91
  if mask_f.max() > 1.0:
92
  mask_f = mask_f / 255.0
@@ -112,7 +117,11 @@ def mask_composite(
112
 
113
 
114
  def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
115
- """LAB-space color transfer so warped region matches original skin tone."""
 
 
 
 
116
  mask_bool = mask > 0.3
117
  if not np.any(mask_bool):
118
  return source
@@ -120,7 +129,7 @@ def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -
120
  src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
121
  tgt_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
122
 
123
- # match per-channel stats in masked region
124
  for ch in range(3):
125
  src_vals = src_lab[:, :, ch][mask_bool]
126
  tgt_vals = tgt_lab[:, :, ch][mask_bool]
@@ -128,7 +137,7 @@ def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -
128
  src_mean, src_std = np.mean(src_vals), np.std(src_vals) + 1e-6
129
  tgt_mean, tgt_std = np.mean(tgt_vals), np.std(tgt_vals) + 1e-6
130
 
131
- # shift+scale to match target distribution
132
  src_lab[:, :, ch] = np.where(
133
  mask_bool,
134
  (src_lab[:, :, ch] - src_mean) * (tgt_std / src_std) + tgt_mean,
@@ -139,13 +148,15 @@ def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -
139
  return cv2.cvtColor(src_lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
140
 
141
 
142
- def poisson_blend(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
143
- """Poisson blend - just delegates to mask_composite (more reliable)."""
144
- return mask_composite(source, target, mask)
145
-
146
-
147
  class LandmarkDiffPipeline:
148
- """Image -> landmarks -> manipulate -> generate."""
 
 
 
 
 
 
 
149
 
150
  # Default IP-Adapter model for SD1.5 face identity
151
  IP_ADAPTER_REPO = "h94/IP-Adapter"
@@ -157,16 +168,29 @@ class LandmarkDiffPipeline:
157
  self,
158
  mode: str = "img2img",
159
  controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
 
160
  base_model_id: str | None = None,
161
  device: Optional[torch.device] = None,
162
  dtype: Optional[torch.dtype] = None,
163
  ip_adapter_scale: float = 0.6,
164
  clinical_flags: Optional["ClinicalFlags"] = None,
 
165
  ):
166
  self.mode = mode
167
  self.device = device or get_device()
168
  self.ip_adapter_scale = ip_adapter_scale
169
  self.clinical_flags = clinical_flags
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  if self.device.type == "mps":
172
  self.dtype = torch.float32
@@ -188,7 +212,7 @@ class LandmarkDiffPipeline:
188
 
189
  def load(self) -> None:
190
  if self.mode == "tps":
191
- print("TPS mode - no model to load")
192
  return
193
  if self.mode in ("controlnet", "controlnet_ip"):
194
  self._load_controlnet()
@@ -204,10 +228,21 @@ class LandmarkDiffPipeline:
204
  DPMSolverMultistepScheduler,
205
  )
206
 
207
- print(f"Loading ControlNet from {self.controlnet_id}...")
208
- controlnet = ControlNetModel.from_pretrained(
209
- self.controlnet_id, subfolder="diffusion_sd15", torch_dtype=self.dtype,
210
- )
 
 
 
 
 
 
 
 
 
 
 
211
  print(f"Loading base model from {self.base_model_id}...")
212
  self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
213
  self.base_model_id,
@@ -216,19 +251,23 @@ class LandmarkDiffPipeline:
216
  safety_checker=None,
217
  requires_safety_checker=False,
218
  )
219
- # DPM++ 2M Karras - better skin than UniPC
220
  self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
221
  self._pipe.scheduler.config,
222
  algorithm_type="dpmsolver++",
223
  use_karras_sigmas=True,
224
  )
225
- # FP32 VAE decode - prevents color banding on skin
226
  if hasattr(self._pipe, "vae") and self._pipe.vae is not None:
227
  self._pipe.vae.config.force_upcast = True
228
  self._apply_device_optimizations()
229
 
230
  def _load_ip_adapter(self) -> None:
231
- """Load IP-Adapter for identity preservation."""
 
 
 
 
232
  if self._pipe is None:
233
  raise RuntimeError("Base pipeline must be loaded before IP-Adapter")
234
  try:
@@ -305,12 +344,41 @@ class LandmarkDiffPipeline:
305
  if face is None:
306
  raise ValueError("No face detected in image.")
307
 
308
- # face view angle for multi-view awareness
309
  view_info = estimate_face_view(face)
310
 
311
- manipulated = apply_procedure_preset(
312
- face, procedure, intensity, image_size=512, clinical_flags=flags,
313
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  landmark_img = render_landmark_image(manipulated, 512, 512)
315
  mask = generate_surgical_mask(
316
  face, procedure, 512, 512, clinical_flags=flags,
@@ -322,7 +390,7 @@ class LandmarkDiffPipeline:
322
 
323
  prompt = PROCEDURE_PROMPTS.get(procedure, "a photo of a person's face")
324
 
325
- # TPS warp is always the geometric baseline
326
  tps_warped = warp_image_tps(image_512, face.pixel_coords, manipulated.pixel_coords)
327
 
328
  if self.mode == "tps":
@@ -340,7 +408,7 @@ class LandmarkDiffPipeline:
340
  guidance_scale, strength, generator,
341
  )
342
 
343
- # postprocess for photorealism
344
  identity_check = None
345
  restore_used = "none"
346
  if postprocess and self.mode != "tps":
@@ -378,6 +446,7 @@ class LandmarkDiffPipeline:
378
  "ip_adapter_active": self._ip_adapter_loaded,
379
  "identity_check": identity_check,
380
  "restore_used": restore_used,
 
381
  }
382
 
383
  def _generate_controlnet(
@@ -418,7 +487,13 @@ class LandmarkDiffPipeline:
418
 
419
 
420
  def estimate_face_view(face: FaceLandmarks) -> dict:
421
- """Yaw/pitch from nose-ear and forehead-chin distances. Returns view dict."""
 
 
 
 
 
 
422
  coords = face.pixel_coords
423
  nose_tip = coords[1]
424
  left_ear = coords[234]
@@ -471,6 +546,8 @@ def run_inference(
471
  seed: int = 42,
472
  mode: str = "img2img",
473
  ip_adapter_scale: float = 0.6,
 
 
474
  ) -> None:
475
  out = Path(output_dir)
476
  out.mkdir(parents=True, exist_ok=True)
@@ -480,7 +557,11 @@ def run_inference(
480
  print(f"ERROR: Could not load {image_path}")
481
  sys.exit(1)
482
 
483
- pipe = LandmarkDiffPipeline(mode=mode, ip_adapter_scale=ip_adapter_scale)
 
 
 
 
484
  pipe.load()
485
 
486
  print(f"\nGenerating {procedure} prediction (intensity={intensity}, mode={mode})...")
@@ -517,9 +598,14 @@ if __name__ == "__main__":
517
  choices=["img2img", "controlnet", "controlnet_ip", "tps"],
518
  )
519
  parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
 
 
 
 
520
  args = parser.parse_args()
521
 
522
  run_inference(
523
  args.image, args.procedure, args.intensity, args.output,
524
- args.seed, args.mode, args.ip_adapter_scale,
 
525
  )
 
1
  """Inference pipeline for surgical outcome prediction.
2
 
3
+ Four modes:
4
+ 1. ControlNet: CrucibleAI/ControlNetMediaPipeFace + SD1.5 (requires HF auth + GPU)
5
+ 2. ControlNet + IP-Adapter: ControlNet with identity preservation via face embeddings
6
+ 3. Img2Img: SD1.5 img2img with mask compositing (runs on MPS, no auth needed)
7
+ 4. TPS-only: Pure geometric warp — no diffusion model, instant results
8
 
9
+ Supports MPS (Apple Silicon), CUDA, and CPU backends.
10
  """
11
 
12
  from __future__ import annotations
 
38
  def numpy_to_pil(arr: np.ndarray) -> Image.Image:
39
  if len(arr.shape) == 2:
40
  return Image.fromarray(arr, mode="L")
41
+ return Image.fromarray(arr[:, :, ::-1].copy())
42
 
43
 
44
  def pil_to_numpy(img: Image.Image) -> np.ndarray:
 
86
  mask: np.ndarray,
87
  use_laplacian: bool = True,
88
  ) -> np.ndarray:
89
+ """Composite warped image into original using ONLY the mask region.
90
+
91
+ Uses Laplacian pyramid blending by default for seamless transitions.
92
+ Falls back to simple alpha blend if Laplacian unavailable.
93
+ Matches skin tone in LAB space to prevent any color shift.
94
+ """
95
  mask_f = mask.astype(np.float32)
96
  if mask_f.max() > 1.0:
97
  mask_f = mask_f / 255.0
 
117
 
118
 
119
  def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
120
+ """Match source skin tone to target within mask, preserving structure.
121
+
122
+ Works in LAB space: transfers L (luminance) and AB (color) statistics
123
+ from the original to the warped image so skin tone is preserved exactly.
124
+ """
125
  mask_bool = mask > 0.3
126
  if not np.any(mask_bool):
127
  return source
 
129
  src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
130
  tgt_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
131
 
132
+ # Match each LAB channel's statistics in the mask region
133
  for ch in range(3):
134
  src_vals = src_lab[:, :, ch][mask_bool]
135
  tgt_vals = tgt_lab[:, :, ch][mask_bool]
 
137
  src_mean, src_std = np.mean(src_vals), np.std(src_vals) + 1e-6
138
  tgt_mean, tgt_std = np.mean(tgt_vals), np.std(tgt_vals) + 1e-6
139
 
140
+ # Normalize source to match target's distribution
141
  src_lab[:, :, ch] = np.where(
142
  mask_bool,
143
  (src_lab[:, :, ch] - src_mean) * (tgt_std / src_std) + tgt_mean,
 
148
  return cv2.cvtColor(src_lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
149
 
150
 
 
 
 
 
 
151
  class LandmarkDiffPipeline:
152
+ """End-to-end pipeline: image -> landmarks -> manipulate -> generate.
153
+
154
+ Modes:
155
+ - 'controlnet': CrucibleAI/ControlNetMediaPipeFace + SD1.5
156
+ - 'controlnet_ip': ControlNet + IP-Adapter for identity preservation
157
+ - 'img2img': SD1.5 img2img with mask compositing
158
+ - 'tps': Pure geometric TPS warp (no diffusion, instant)
159
+ """
160
 
161
  # Default IP-Adapter model for SD1.5 face identity
162
  IP_ADAPTER_REPO = "h94/IP-Adapter"
 
168
  self,
169
  mode: str = "img2img",
170
  controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
171
+ controlnet_checkpoint: str | None = None,
172
  base_model_id: str | None = None,
173
  device: Optional[torch.device] = None,
174
  dtype: Optional[torch.dtype] = None,
175
  ip_adapter_scale: float = 0.6,
176
  clinical_flags: Optional["ClinicalFlags"] = None,
177
+ displacement_model_path: str | None = None,
178
  ):
179
  self.mode = mode
180
  self.device = device or get_device()
181
  self.ip_adapter_scale = ip_adapter_scale
182
  self.clinical_flags = clinical_flags
183
+ self.controlnet_checkpoint = controlnet_checkpoint
184
+
185
+ # Load displacement model for data-driven manipulation
186
+ self._displacement_model = None
187
+ if displacement_model_path:
188
+ try:
189
+ from landmarkdiff.displacement_model import DisplacementModel
190
+ self._displacement_model = DisplacementModel.load(displacement_model_path)
191
+ print(f"Displacement model loaded: {self._displacement_model.procedures}")
192
+ except Exception as e:
193
+ print(f"WARNING: Failed to load displacement model: {e}")
194
 
195
  if self.device.type == "mps":
196
  self.dtype = torch.float32
 
212
 
213
  def load(self) -> None:
214
  if self.mode == "tps":
215
+ print("TPS mode — no model to load")
216
  return
217
  if self.mode in ("controlnet", "controlnet_ip"):
218
  self._load_controlnet()
 
228
  DPMSolverMultistepScheduler,
229
  )
230
 
231
+ if self.controlnet_checkpoint:
232
+ # Load fine-tuned ControlNet from local checkpoint
233
+ ckpt_path = Path(self.controlnet_checkpoint)
234
+ # Support both direct path and training checkpoint structure
235
+ if (ckpt_path / "controlnet_ema").exists():
236
+ ckpt_path = ckpt_path / "controlnet_ema"
237
+ print(f"Loading fine-tuned ControlNet from {ckpt_path}...")
238
+ controlnet = ControlNetModel.from_pretrained(
239
+ str(ckpt_path), torch_dtype=self.dtype,
240
+ )
241
+ else:
242
+ print(f"Loading ControlNet from {self.controlnet_id}...")
243
+ controlnet = ControlNetModel.from_pretrained(
244
+ self.controlnet_id, subfolder="diffusion_sd15", torch_dtype=self.dtype,
245
+ )
246
  print(f"Loading base model from {self.base_model_id}...")
247
  self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
248
  self.base_model_id,
 
251
  safety_checker=None,
252
  requires_safety_checker=False,
253
  )
254
+ # DPM++ 2M Karras produces more photorealistic output than UniPC
255
  self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
256
  self._pipe.scheduler.config,
257
  algorithm_type="dpmsolver++",
258
  use_karras_sigmas=True,
259
  )
260
+ # FP32 VAE decode prevents color banding artifacts on skin tones
261
  if hasattr(self._pipe, "vae") and self._pipe.vae is not None:
262
  self._pipe.vae.config.force_upcast = True
263
  self._apply_device_optimizations()
264
 
265
  def _load_ip_adapter(self) -> None:
266
+ """Load IP-Adapter for identity-preserving generation.
267
+
268
+ Uses h94/IP-Adapter-FaceID with CLIP image encoder to condition
269
+ generation on the input face identity.
270
+ """
271
  if self._pipe is None:
272
  raise RuntimeError("Base pipeline must be loaded before IP-Adapter")
273
  try:
 
344
  if face is None:
345
  raise ValueError("No face detected in image.")
346
 
347
+ # Estimate face view angle for multi-view awareness
348
  view_info = estimate_face_view(face)
349
 
350
+ # Use displacement model for data-driven manipulation if available
351
+ manipulation_mode = "preset"
352
+ if self._displacement_model and procedure in self._displacement_model.procedures:
353
+ try:
354
+ from landmarkdiff.displacement_model import DisplacementModel
355
+ rng = np.random.default_rng(seed) if seed is not None else np.random.default_rng()
356
+ # Map UI intensity (0-100) to displacement model intensity (0-2)
357
+ dm_intensity = intensity / 50.0 # 50 -> 1.0x mean displacement
358
+ displacement = self._displacement_model.get_displacement_field(
359
+ procedure, intensity=dm_intensity, noise_scale=0.3, rng=rng,
360
+ )
361
+ # Apply displacement to landmarks
362
+ new_lm = face.landmarks.copy()
363
+ n = min(len(new_lm), len(displacement))
364
+ new_lm[:n, 0] += displacement[:n, 0]
365
+ new_lm[:n, 1] += displacement[:n, 1]
366
+ new_lm[:, 0] = np.clip(new_lm[:, 0], 0.01, 0.99)
367
+ new_lm[:, 1] = np.clip(new_lm[:, 1], 0.01, 0.99)
368
+ manipulated = FaceLandmarks(
369
+ landmarks=new_lm,
370
+ image_width=512, image_height=512,
371
+ confidence=face.confidence,
372
+ )
373
+ manipulation_mode = "displacement_model"
374
+ except Exception:
375
+ manipulated = apply_procedure_preset(
376
+ face, procedure, intensity, image_size=512, clinical_flags=flags,
377
+ )
378
+ else:
379
+ manipulated = apply_procedure_preset(
380
+ face, procedure, intensity, image_size=512, clinical_flags=flags,
381
+ )
382
  landmark_img = render_landmark_image(manipulated, 512, 512)
383
  mask = generate_surgical_mask(
384
  face, procedure, 512, 512, clinical_flags=flags,
 
390
 
391
  prompt = PROCEDURE_PROMPTS.get(procedure, "a photo of a person's face")
392
 
393
+ # Step 1: TPS geometric warp (always computed — the geometric baseline)
394
  tps_warped = warp_image_tps(image_512, face.pixel_coords, manipulated.pixel_coords)
395
 
396
  if self.mode == "tps":
 
408
  guidance_scale, strength, generator,
409
  )
410
 
411
+ # Step 2: Post-processing for photorealism (neural + classical pipeline)
412
  identity_check = None
413
  restore_used = "none"
414
  if postprocess and self.mode != "tps":
 
446
  "ip_adapter_active": self._ip_adapter_loaded,
447
  "identity_check": identity_check,
448
  "restore_used": restore_used,
449
+ "manipulation_mode": manipulation_mode,
450
  }
451
 
452
  def _generate_controlnet(
 
487
 
488
 
489
  def estimate_face_view(face: FaceLandmarks) -> dict:
490
+ """Estimate face orientation from landmarks for multi-view awareness.
491
+
492
+ Uses the nose tip (idx 1), left ear (idx 234), and right ear (idx 454) to
493
+ estimate yaw angle; pitch is estimated from the forehead (idx 10) and chin (idx 152).
494
+
495
+ Returns dict with yaw, pitch (degrees), and view classification.
496
+ """
497
  coords = face.pixel_coords
498
  nose_tip = coords[1]
499
  left_ear = coords[234]
 
546
  seed: int = 42,
547
  mode: str = "img2img",
548
  ip_adapter_scale: float = 0.6,
549
+ controlnet_checkpoint: str | None = None,
550
+ displacement_model_path: str | None = None,
551
  ) -> None:
552
  out = Path(output_dir)
553
  out.mkdir(parents=True, exist_ok=True)
 
557
  print(f"ERROR: Could not load {image_path}")
558
  sys.exit(1)
559
 
560
+ pipe = LandmarkDiffPipeline(
561
+ mode=mode, ip_adapter_scale=ip_adapter_scale,
562
+ controlnet_checkpoint=controlnet_checkpoint,
563
+ displacement_model_path=displacement_model_path,
564
+ )
565
  pipe.load()
566
 
567
  print(f"\nGenerating {procedure} prediction (intensity={intensity}, mode={mode})...")
 
598
  choices=["img2img", "controlnet", "controlnet_ip", "tps"],
599
  )
600
  parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
601
+ parser.add_argument("--checkpoint", default=None,
602
+ help="Path to fine-tuned ControlNet checkpoint")
603
+ parser.add_argument("--displacement-model", default=None,
604
+ help="Path to displacement_model.npz for data-driven manipulation")
605
  args = parser.parse_args()
606
 
607
  run_inference(
608
  args.image, args.procedure, args.intensity, args.output,
609
+ args.seed, args.mode, args.ip_adapter_scale, args.checkpoint,
610
+ args.displacement_model,
611
  )