Spaces:
Sleeping
Sleeping
Upload landmarkdiff/inference.py with huggingface_hub
Browse files - landmarkdiff/inference.py (+120 -34)
landmarkdiff/inference.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
"""Inference pipeline for surgical outcome prediction.
|
| 2 |
|
| 3 |
-
|
| 4 |
-
1. ControlNet: CrucibleAI/ControlNetMediaPipeFace + SD1.5 (HF auth + GPU)
|
| 5 |
-
2. ControlNet + IP-Adapter: ControlNet
|
| 6 |
-
3. Img2Img: SD1.5 img2img with mask compositing (
|
| 7 |
-
4. TPS-only: geometric warp
|
| 8 |
|
| 9 |
-
|
| 10 |
"""
|
| 11 |
|
| 12 |
from __future__ import annotations
|
|
@@ -38,7 +38,7 @@ def get_device() -> torch.device:
|
|
| 38 |
def numpy_to_pil(arr: np.ndarray) -> Image.Image:
|
| 39 |
if len(arr.shape) == 2:
|
| 40 |
return Image.fromarray(arr, mode="L")
|
| 41 |
-
return Image.fromarray(arr[:, :, ::-1])
|
| 42 |
|
| 43 |
|
| 44 |
def pil_to_numpy(img: Image.Image) -> np.ndarray:
|
|
@@ -86,7 +86,12 @@ def mask_composite(
|
|
| 86 |
mask: np.ndarray,
|
| 87 |
use_laplacian: bool = True,
|
| 88 |
) -> np.ndarray:
|
| 89 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
mask_f = mask.astype(np.float32)
|
| 91 |
if mask_f.max() > 1.0:
|
| 92 |
mask_f = mask_f / 255.0
|
|
@@ -112,7 +117,11 @@ def mask_composite(
|
|
| 112 |
|
| 113 |
|
| 114 |
def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 115 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
mask_bool = mask > 0.3
|
| 117 |
if not np.any(mask_bool):
|
| 118 |
return source
|
|
@@ -120,7 +129,7 @@ def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -
|
|
| 120 |
src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 121 |
tgt_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 122 |
|
| 123 |
-
#
|
| 124 |
for ch in range(3):
|
| 125 |
src_vals = src_lab[:, :, ch][mask_bool]
|
| 126 |
tgt_vals = tgt_lab[:, :, ch][mask_bool]
|
|
@@ -128,7 +137,7 @@ def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -
|
|
| 128 |
src_mean, src_std = np.mean(src_vals), np.std(src_vals) + 1e-6
|
| 129 |
tgt_mean, tgt_std = np.mean(tgt_vals), np.std(tgt_vals) + 1e-6
|
| 130 |
|
| 131 |
-
#
|
| 132 |
src_lab[:, :, ch] = np.where(
|
| 133 |
mask_bool,
|
| 134 |
(src_lab[:, :, ch] - src_mean) * (tgt_std / src_std) + tgt_mean,
|
|
@@ -139,13 +148,15 @@ def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -
|
|
| 139 |
return cv2.cvtColor(src_lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
|
| 140 |
|
| 141 |
|
| 142 |
-
def poisson_blend(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 143 |
-
"""Poisson blend - just delegates to mask_composite (more reliable)."""
|
| 144 |
-
return mask_composite(source, target, mask)
|
| 145 |
-
|
| 146 |
-
|
| 147 |
class LandmarkDiffPipeline:
|
| 148 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
# Default IP-Adapter model for SD1.5 face identity
|
| 151 |
IP_ADAPTER_REPO = "h94/IP-Adapter"
|
|
@@ -157,16 +168,29 @@ class LandmarkDiffPipeline:
|
|
| 157 |
self,
|
| 158 |
mode: str = "img2img",
|
| 159 |
controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
|
|
|
|
| 160 |
base_model_id: str | None = None,
|
| 161 |
device: Optional[torch.device] = None,
|
| 162 |
dtype: Optional[torch.dtype] = None,
|
| 163 |
ip_adapter_scale: float = 0.6,
|
| 164 |
clinical_flags: Optional["ClinicalFlags"] = None,
|
|
|
|
| 165 |
):
|
| 166 |
self.mode = mode
|
| 167 |
self.device = device or get_device()
|
| 168 |
self.ip_adapter_scale = ip_adapter_scale
|
| 169 |
self.clinical_flags = clinical_flags
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
if self.device.type == "mps":
|
| 172 |
self.dtype = torch.float32
|
|
@@ -188,7 +212,7 @@ class LandmarkDiffPipeline:
|
|
| 188 |
|
| 189 |
def load(self) -> None:
|
| 190 |
if self.mode == "tps":
|
| 191 |
-
print("TPS mode
|
| 192 |
return
|
| 193 |
if self.mode in ("controlnet", "controlnet_ip"):
|
| 194 |
self._load_controlnet()
|
|
@@ -204,10 +228,21 @@ class LandmarkDiffPipeline:
|
|
| 204 |
DPMSolverMultistepScheduler,
|
| 205 |
)
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
print(f"Loading base model from {self.base_model_id}...")
|
| 212 |
self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
| 213 |
self.base_model_id,
|
|
@@ -216,19 +251,23 @@ class LandmarkDiffPipeline:
|
|
| 216 |
safety_checker=None,
|
| 217 |
requires_safety_checker=False,
|
| 218 |
)
|
| 219 |
-
# DPM++ 2M Karras
|
| 220 |
self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 221 |
self._pipe.scheduler.config,
|
| 222 |
algorithm_type="dpmsolver++",
|
| 223 |
use_karras_sigmas=True,
|
| 224 |
)
|
| 225 |
-
# FP32 VAE decode
|
| 226 |
if hasattr(self._pipe, "vae") and self._pipe.vae is not None:
|
| 227 |
self._pipe.vae.config.force_upcast = True
|
| 228 |
self._apply_device_optimizations()
|
| 229 |
|
| 230 |
def _load_ip_adapter(self) -> None:
|
| 231 |
-
"""Load IP-Adapter for identity
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
if self._pipe is None:
|
| 233 |
raise RuntimeError("Base pipeline must be loaded before IP-Adapter")
|
| 234 |
try:
|
|
@@ -305,12 +344,41 @@ class LandmarkDiffPipeline:
|
|
| 305 |
if face is None:
|
| 306 |
raise ValueError("No face detected in image.")
|
| 307 |
|
| 308 |
-
# face view angle for multi-view awareness
|
| 309 |
view_info = estimate_face_view(face)
|
| 310 |
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
landmark_img = render_landmark_image(manipulated, 512, 512)
|
| 315 |
mask = generate_surgical_mask(
|
| 316 |
face, procedure, 512, 512, clinical_flags=flags,
|
|
@@ -322,7 +390,7 @@ class LandmarkDiffPipeline:
|
|
| 322 |
|
| 323 |
prompt = PROCEDURE_PROMPTS.get(procedure, "a photo of a person's face")
|
| 324 |
|
| 325 |
-
# TPS warp
|
| 326 |
tps_warped = warp_image_tps(image_512, face.pixel_coords, manipulated.pixel_coords)
|
| 327 |
|
| 328 |
if self.mode == "tps":
|
|
@@ -340,7 +408,7 @@ class LandmarkDiffPipeline:
|
|
| 340 |
guidance_scale, strength, generator,
|
| 341 |
)
|
| 342 |
|
| 343 |
-
#
|
| 344 |
identity_check = None
|
| 345 |
restore_used = "none"
|
| 346 |
if postprocess and self.mode != "tps":
|
|
@@ -378,6 +446,7 @@ class LandmarkDiffPipeline:
|
|
| 378 |
"ip_adapter_active": self._ip_adapter_loaded,
|
| 379 |
"identity_check": identity_check,
|
| 380 |
"restore_used": restore_used,
|
|
|
|
| 381 |
}
|
| 382 |
|
| 383 |
def _generate_controlnet(
|
|
@@ -418,7 +487,13 @@ class LandmarkDiffPipeline:
|
|
| 418 |
|
| 419 |
|
| 420 |
def estimate_face_view(face: FaceLandmarks) -> dict:
|
| 421 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
coords = face.pixel_coords
|
| 423 |
nose_tip = coords[1]
|
| 424 |
left_ear = coords[234]
|
|
@@ -471,6 +546,8 @@ def run_inference(
|
|
| 471 |
seed: int = 42,
|
| 472 |
mode: str = "img2img",
|
| 473 |
ip_adapter_scale: float = 0.6,
|
|
|
|
|
|
|
| 474 |
) -> None:
|
| 475 |
out = Path(output_dir)
|
| 476 |
out.mkdir(parents=True, exist_ok=True)
|
|
@@ -480,7 +557,11 @@ def run_inference(
|
|
| 480 |
print(f"ERROR: Could not load {image_path}")
|
| 481 |
sys.exit(1)
|
| 482 |
|
| 483 |
-
pipe = LandmarkDiffPipeline(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
pipe.load()
|
| 485 |
|
| 486 |
print(f"\nGenerating {procedure} prediction (intensity={intensity}, mode={mode})...")
|
|
@@ -517,9 +598,14 @@ if __name__ == "__main__":
|
|
| 517 |
choices=["img2img", "controlnet", "controlnet_ip", "tps"],
|
| 518 |
)
|
| 519 |
parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
args = parser.parse_args()
|
| 521 |
|
| 522 |
run_inference(
|
| 523 |
args.image, args.procedure, args.intensity, args.output,
|
| 524 |
-
args.seed, args.mode, args.ip_adapter_scale,
|
|
|
|
| 525 |
)
|
|
|
|
| 1 |
"""Inference pipeline for surgical outcome prediction.
|
| 2 |
|
| 3 |
+
Four modes:
|
| 4 |
+
1. ControlNet: CrucibleAI/ControlNetMediaPipeFace + SD1.5 (requires HF auth + GPU)
|
| 5 |
+
2. ControlNet + IP-Adapter: ControlNet with identity preservation via face embeddings
|
| 6 |
+
3. Img2Img: SD1.5 img2img with mask compositing (runs on MPS, no auth needed)
|
| 7 |
+
4. TPS-only: Pure geometric warp — no diffusion model, instant results
|
| 8 |
|
| 9 |
+
Supports MPS (Apple Silicon), CUDA, and CPU backends.
|
| 10 |
"""
|
| 11 |
|
| 12 |
from __future__ import annotations
|
|
|
|
| 38 |
def numpy_to_pil(arr: np.ndarray) -> Image.Image:
|
| 39 |
if len(arr.shape) == 2:
|
| 40 |
return Image.fromarray(arr, mode="L")
|
| 41 |
+
return Image.fromarray(arr[:, :, ::-1].copy())
|
| 42 |
|
| 43 |
|
| 44 |
def pil_to_numpy(img: Image.Image) -> np.ndarray:
|
|
|
|
| 86 |
mask: np.ndarray,
|
| 87 |
use_laplacian: bool = True,
|
| 88 |
) -> np.ndarray:
|
| 89 |
+
"""Composite warped image into original using ONLY the mask region.
|
| 90 |
+
|
| 91 |
+
Uses Laplacian pyramid blending by default for seamless transitions.
|
| 92 |
+
Falls back to simple alpha blend if Laplacian unavailable.
|
| 93 |
+
Matches skin tone in LAB space to prevent any color shift.
|
| 94 |
+
"""
|
| 95 |
mask_f = mask.astype(np.float32)
|
| 96 |
if mask_f.max() > 1.0:
|
| 97 |
mask_f = mask_f / 255.0
|
|
|
|
| 117 |
|
| 118 |
|
| 119 |
def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 120 |
+
"""Match source skin tone to target within mask, preserving structure.
|
| 121 |
+
|
| 122 |
+
Works in LAB space: transfers L (luminance) and AB (color) statistics
|
| 123 |
+
from the target to the source (per-channel mean and std inside the mask) so the overall skin tone matches.
|
| 124 |
+
"""
|
| 125 |
mask_bool = mask > 0.3
|
| 126 |
if not np.any(mask_bool):
|
| 127 |
return source
|
|
|
|
| 129 |
src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 130 |
tgt_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)
|
| 131 |
|
| 132 |
+
# Match each LAB channel's statistics in the mask region
|
| 133 |
for ch in range(3):
|
| 134 |
src_vals = src_lab[:, :, ch][mask_bool]
|
| 135 |
tgt_vals = tgt_lab[:, :, ch][mask_bool]
|
|
|
|
| 137 |
src_mean, src_std = np.mean(src_vals), np.std(src_vals) + 1e-6
|
| 138 |
tgt_mean, tgt_std = np.mean(tgt_vals), np.std(tgt_vals) + 1e-6
|
| 139 |
|
| 140 |
+
# Normalize source to match target's distribution
|
| 141 |
src_lab[:, :, ch] = np.where(
|
| 142 |
mask_bool,
|
| 143 |
(src_lab[:, :, ch] - src_mean) * (tgt_std / src_std) + tgt_mean,
|
|
|
|
| 148 |
return cv2.cvtColor(src_lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
|
| 149 |
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
class LandmarkDiffPipeline:
|
| 152 |
+
"""End-to-end pipeline: image -> landmarks -> manipulate -> generate.
|
| 153 |
+
|
| 154 |
+
Modes:
|
| 155 |
+
- 'controlnet': CrucibleAI/ControlNetMediaPipeFace + SD1.5
|
| 156 |
+
- 'controlnet_ip': ControlNet + IP-Adapter for identity preservation
|
| 157 |
+
- 'img2img': SD1.5 img2img with mask compositing
|
| 158 |
+
- 'tps': Pure geometric TPS warp (no diffusion, instant)
|
| 159 |
+
"""
|
| 160 |
|
| 161 |
# Default IP-Adapter model for SD1.5 face identity
|
| 162 |
IP_ADAPTER_REPO = "h94/IP-Adapter"
|
|
|
|
| 168 |
self,
|
| 169 |
mode: str = "img2img",
|
| 170 |
controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
|
| 171 |
+
controlnet_checkpoint: str | None = None,
|
| 172 |
base_model_id: str | None = None,
|
| 173 |
device: Optional[torch.device] = None,
|
| 174 |
dtype: Optional[torch.dtype] = None,
|
| 175 |
ip_adapter_scale: float = 0.6,
|
| 176 |
clinical_flags: Optional["ClinicalFlags"] = None,
|
| 177 |
+
displacement_model_path: str | None = None,
|
| 178 |
):
|
| 179 |
self.mode = mode
|
| 180 |
self.device = device or get_device()
|
| 181 |
self.ip_adapter_scale = ip_adapter_scale
|
| 182 |
self.clinical_flags = clinical_flags
|
| 183 |
+
self.controlnet_checkpoint = controlnet_checkpoint
|
| 184 |
+
|
| 185 |
+
# Load displacement model for data-driven manipulation
|
| 186 |
+
self._displacement_model = None
|
| 187 |
+
if displacement_model_path:
|
| 188 |
+
try:
|
| 189 |
+
from landmarkdiff.displacement_model import DisplacementModel
|
| 190 |
+
self._displacement_model = DisplacementModel.load(displacement_model_path)
|
| 191 |
+
print(f"Displacement model loaded: {self._displacement_model.procedures}")
|
| 192 |
+
except Exception as e:
|
| 193 |
+
print(f"WARNING: Failed to load displacement model: {e}")
|
| 194 |
|
| 195 |
if self.device.type == "mps":
|
| 196 |
self.dtype = torch.float32
|
|
|
|
| 212 |
|
| 213 |
def load(self) -> None:
|
| 214 |
if self.mode == "tps":
|
| 215 |
+
print("TPS mode — no model to load")
|
| 216 |
return
|
| 217 |
if self.mode in ("controlnet", "controlnet_ip"):
|
| 218 |
self._load_controlnet()
|
|
|
|
| 228 |
DPMSolverMultistepScheduler,
|
| 229 |
)
|
| 230 |
|
| 231 |
+
if self.controlnet_checkpoint:
|
| 232 |
+
# Load fine-tuned ControlNet from local checkpoint
|
| 233 |
+
ckpt_path = Path(self.controlnet_checkpoint)
|
| 234 |
+
# Support both direct path and training checkpoint structure
|
| 235 |
+
if (ckpt_path / "controlnet_ema").exists():
|
| 236 |
+
ckpt_path = ckpt_path / "controlnet_ema"
|
| 237 |
+
print(f"Loading fine-tuned ControlNet from {ckpt_path}...")
|
| 238 |
+
controlnet = ControlNetModel.from_pretrained(
|
| 239 |
+
str(ckpt_path), torch_dtype=self.dtype,
|
| 240 |
+
)
|
| 241 |
+
else:
|
| 242 |
+
print(f"Loading ControlNet from {self.controlnet_id}...")
|
| 243 |
+
controlnet = ControlNetModel.from_pretrained(
|
| 244 |
+
self.controlnet_id, subfolder="diffusion_sd15", torch_dtype=self.dtype,
|
| 245 |
+
)
|
| 246 |
print(f"Loading base model from {self.base_model_id}...")
|
| 247 |
self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
| 248 |
self.base_model_id,
|
|
|
|
| 251 |
safety_checker=None,
|
| 252 |
requires_safety_checker=False,
|
| 253 |
)
|
| 254 |
+
# DPM++ 2M Karras — produces more photorealistic output than UniPC
|
| 255 |
self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
|
| 256 |
self._pipe.scheduler.config,
|
| 257 |
algorithm_type="dpmsolver++",
|
| 258 |
use_karras_sigmas=True,
|
| 259 |
)
|
| 260 |
+
# FP32 VAE decode — prevents color banding artifacts on skin tones
|
| 261 |
if hasattr(self._pipe, "vae") and self._pipe.vae is not None:
|
| 262 |
self._pipe.vae.config.force_upcast = True
|
| 263 |
self._apply_device_optimizations()
|
| 264 |
|
| 265 |
def _load_ip_adapter(self) -> None:
|
| 266 |
+
"""Load IP-Adapter for identity-preserving generation.
|
| 267 |
+
|
| 268 |
+
Uses IP-Adapter weights (see IP_ADAPTER_REPO, "h94/IP-Adapter") with a CLIP image encoder to condition
|
| 269 |
+
generation on the input face identity.
|
| 270 |
+
"""
|
| 271 |
if self._pipe is None:
|
| 272 |
raise RuntimeError("Base pipeline must be loaded before IP-Adapter")
|
| 273 |
try:
|
|
|
|
| 344 |
if face is None:
|
| 345 |
raise ValueError("No face detected in image.")
|
| 346 |
|
| 347 |
+
# Estimate face view angle for multi-view awareness
|
| 348 |
view_info = estimate_face_view(face)
|
| 349 |
|
| 350 |
+
# Use displacement model for data-driven manipulation if available
|
| 351 |
+
manipulation_mode = "preset"
|
| 352 |
+
if self._displacement_model and procedure in self._displacement_model.procedures:
|
| 353 |
+
try:
|
| 354 |
+
from landmarkdiff.displacement_model import DisplacementModel
|
| 355 |
+
rng = np.random.default_rng(seed) if seed is not None else np.random.default_rng()
|
| 356 |
+
# Map UI intensity (0-100) to displacement model intensity (0-2)
|
| 357 |
+
dm_intensity = intensity / 50.0 # 50 -> 1.0x mean displacement
|
| 358 |
+
displacement = self._displacement_model.get_displacement_field(
|
| 359 |
+
procedure, intensity=dm_intensity, noise_scale=0.3, rng=rng,
|
| 360 |
+
)
|
| 361 |
+
# Apply displacement to landmarks
|
| 362 |
+
new_lm = face.landmarks.copy()
|
| 363 |
+
n = min(len(new_lm), len(displacement))
|
| 364 |
+
new_lm[:n, 0] += displacement[:n, 0]
|
| 365 |
+
new_lm[:n, 1] += displacement[:n, 1]
|
| 366 |
+
new_lm[:, 0] = np.clip(new_lm[:, 0], 0.01, 0.99)
|
| 367 |
+
new_lm[:, 1] = np.clip(new_lm[:, 1], 0.01, 0.99)
|
| 368 |
+
manipulated = FaceLandmarks(
|
| 369 |
+
landmarks=new_lm,
|
| 370 |
+
image_width=512, image_height=512,
|
| 371 |
+
confidence=face.confidence,
|
| 372 |
+
)
|
| 373 |
+
manipulation_mode = "displacement_model"
|
| 374 |
+
except Exception:
|
| 375 |
+
manipulated = apply_procedure_preset(
|
| 376 |
+
face, procedure, intensity, image_size=512, clinical_flags=flags,
|
| 377 |
+
)
|
| 378 |
+
else:
|
| 379 |
+
manipulated = apply_procedure_preset(
|
| 380 |
+
face, procedure, intensity, image_size=512, clinical_flags=flags,
|
| 381 |
+
)
|
| 382 |
landmark_img = render_landmark_image(manipulated, 512, 512)
|
| 383 |
mask = generate_surgical_mask(
|
| 384 |
face, procedure, 512, 512, clinical_flags=flags,
|
|
|
|
| 390 |
|
| 391 |
prompt = PROCEDURE_PROMPTS.get(procedure, "a photo of a person's face")
|
| 392 |
|
| 393 |
+
# Step 1: TPS geometric warp (always computed — the geometric baseline)
|
| 394 |
tps_warped = warp_image_tps(image_512, face.pixel_coords, manipulated.pixel_coords)
|
| 395 |
|
| 396 |
if self.mode == "tps":
|
|
|
|
| 408 |
guidance_scale, strength, generator,
|
| 409 |
)
|
| 410 |
|
| 411 |
+
# Step 2: Post-processing for photorealism (neural + classical pipeline)
|
| 412 |
identity_check = None
|
| 413 |
restore_used = "none"
|
| 414 |
if postprocess and self.mode != "tps":
|
|
|
|
| 446 |
"ip_adapter_active": self._ip_adapter_loaded,
|
| 447 |
"identity_check": identity_check,
|
| 448 |
"restore_used": restore_used,
|
| 449 |
+
"manipulation_mode": manipulation_mode,
|
| 450 |
}
|
| 451 |
|
| 452 |
def _generate_controlnet(
|
|
|
|
| 487 |
|
| 488 |
|
| 489 |
def estimate_face_view(face: FaceLandmarks) -> dict:
|
| 490 |
+
"""Estimate face orientation from landmarks for multi-view awareness.
|
| 491 |
+
|
| 492 |
+
Uses the nose tip (idx 1), left ear (idx 234), and right ear (idx 454) to
|
| 493 |
+
estimate yaw angle. Pitch from forehead (idx 10) and chin (idx 152).
|
| 494 |
+
|
| 495 |
+
Returns dict with yaw, pitch (degrees), and view classification.
|
| 496 |
+
"""
|
| 497 |
coords = face.pixel_coords
|
| 498 |
nose_tip = coords[1]
|
| 499 |
left_ear = coords[234]
|
|
|
|
| 546 |
seed: int = 42,
|
| 547 |
mode: str = "img2img",
|
| 548 |
ip_adapter_scale: float = 0.6,
|
| 549 |
+
controlnet_checkpoint: str | None = None,
|
| 550 |
+
displacement_model_path: str | None = None,
|
| 551 |
) -> None:
|
| 552 |
out = Path(output_dir)
|
| 553 |
out.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 557 |
print(f"ERROR: Could not load {image_path}")
|
| 558 |
sys.exit(1)
|
| 559 |
|
| 560 |
+
pipe = LandmarkDiffPipeline(
|
| 561 |
+
mode=mode, ip_adapter_scale=ip_adapter_scale,
|
| 562 |
+
controlnet_checkpoint=controlnet_checkpoint,
|
| 563 |
+
displacement_model_path=displacement_model_path,
|
| 564 |
+
)
|
| 565 |
pipe.load()
|
| 566 |
|
| 567 |
print(f"\nGenerating {procedure} prediction (intensity={intensity}, mode={mode})...")
|
|
|
|
| 598 |
choices=["img2img", "controlnet", "controlnet_ip", "tps"],
|
| 599 |
)
|
| 600 |
parser.add_argument("--ip-adapter-scale", type=float, default=0.6)
|
| 601 |
+
parser.add_argument("--checkpoint", default=None,
|
| 602 |
+
help="Path to fine-tuned ControlNet checkpoint")
|
| 603 |
+
parser.add_argument("--displacement-model", default=None,
|
| 604 |
+
help="Path to displacement_model.npz for data-driven manipulation")
|
| 605 |
args = parser.parse_args()
|
| 606 |
|
| 607 |
run_inference(
|
| 608 |
args.image, args.procedure, args.intensity, args.output,
|
| 609 |
+
args.seed, args.mode, args.ip_adapter_scale, args.checkpoint,
|
| 610 |
+
args.displacement_model,
|
| 611 |
)
|