import cv2
import torch
import numpy as np
from PIL import Image, UnidentifiedImageError
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

try:
    logger.info("Loading SegFormer face-parsing model...")
    processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
    model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {e}", exc_info=True)
    raise RuntimeError("SegFormer model load failed!") from e

hair_class_id = 13
ear_class_ids = [8, 9]  # l_ear=8, r_ear=9 (confirmed from model card)
skin_class_id = 1  # skin
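# The ids above follow the jonathandinu/face-parsing label map (a SegFormer
# fine-tuned on CelebAMask-HQ face classes, per its model card).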


def make_realistic_bald(input_image: Image.Image) -> Image.Image:
    if input_image is None:
        raise ValueError("No input image provided!")
    try:
        # Normalize to RGB: RGBA or grayscale inputs would break the
        # cv2.cvtColor(..., COLOR_RGB2BGR) conversions below.
        input_image = input_image.convert("RGB")
        orig_w, orig_h = input_image.size
        original_np = np.array(input_image)
        original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)

        MAX_DIM = 2048
        scale_factor = 1.0
        working_np = original_np.copy()
        working_bgr = original_bgr.copy()
        working_h, working_w = orig_h, orig_w
        if max(orig_w, orig_h) > MAX_DIM:
            scale_factor = MAX_DIM / max(orig_w, orig_h)
            working_w, working_h = int(orig_w * scale_factor), int(orig_h * scale_factor)
            working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
            working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)
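        # Capping the working size at MAX_DIM bounds segmentation and inpainting
        # cost on large uploads (INTER_AREA is the standard choice for shrinking);
        # the edit is upscaled and blended back at full resolution below.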

        # Segmentation
        pil_working = Image.fromarray(working_np)
        inputs = processor(images=pil_working, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        upsampled = torch.nn.functional.interpolate(
            logits, size=(working_h, working_w), mode="bilinear", align_corners=False
        )
        parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()
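        # parsing is an (H, W) integer map: each pixel holds the argmax class id,
        # upsampled from the model's low-resolution logits to the working size.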

        # Masks
        hair_mask = (parsing == hair_class_id).astype(np.uint8)
        ears_mask = np.zeros_like(hair_mask)
        for cls in ear_class_ids:
            ears_mask[parsing == cls] = 1
        hair_mask[ears_mask == 1] = 0
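        # Ear pixels are removed from the hair mask so the ears themselves are
        # never inpainted away, even where hair and ears meet.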

        # Clean mask
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
        hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
        hair_mask = (cv2.GaussianBlur(hair_mask.astype(np.float32), (5, 5), 0) > 0.28).astype(np.uint8)
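        # Closing fills pinholes in the segmentation, dilation grows the mask
        # past the hairline, and blurring then thresholding at 0.28 (below 0.5)
        # expands the boundary slightly more while rounding off jagged edges.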

        hair_pixels = int(np.sum(hair_mask))
        if hair_pixels < 50:
            raise ValueError("NO_HAIR_DETECTED")

        # Inpainting
        radius = 20 if hair_pixels > 220000 else 12  # Slightly larger radius for big areas
        flag = cv2.INPAINT_TELEA if hair_pixels > 220000 else cv2.INPAINT_NS
        inpainted_bgr = cv2.inpaint(working_bgr, hair_mask * 255, inpaintRadius=radius, flags=flag)
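        # TELEA (fast marching) tends to give smoother fills over large regions;
        # Navier-Stokes-based inpainting follows local structure more closely,
        # which suits smaller patches. The 220000-pixel cutoff is this script's
        # heuristic for "large" at the working resolution.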

        # Large hair: color correction + seamless blending
        if hair_pixels > 220000:
            skin_mask = (parsing == skin_class_id).astype(np.uint8)
            ref_mask = skin_mask.copy()
            ref_mask[hair_mask == 1] = 0
            ref_mask = cv2.erode(ref_mask, np.ones((3, 3), np.uint8), iterations=1)  # Avoid border artifacts
            if cv2.countNonZero(ref_mask) > 0:  # Valid skin reference exists
                ref_mean = cv2.mean(working_bgr, mask=ref_mask * 255)[:3]
                inp_mean = cv2.mean(inpainted_bgr, mask=hair_mask * 255)[:3]
                color_diff = np.array(ref_mean) - np.array(inp_mean)
                color_diff = np.clip(color_diff * 0.7, -40, 40)  # Softer adjustment to avoid over-correction
                # Index with the 2-D boolean mask so the selection has shape (N, 3)
                # and color_diff broadcasts per channel; a 3-channel boolean index
                # would flatten the selection and break the broadcast.
                hair_bool = hair_mask.astype(bool)
                inpainted_bgr[hair_bool] = np.clip(
                    inpainted_bgr[hair_bool].astype(np.float32) + color_diff, 0, 255
                ).astype(np.uint8)
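                # Rationale: over a very large region, cv2.inpaint tends to drag
                # residual hair color into the fill, so the fill's mean is nudged
                # back toward the face's own skin tone before blending.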
            # Seamless clone for a cleaner merge (MIXED_CLONE mixes gradients,
            # which works well on skin texture). cv2.seamlessClone places the
            # bounding-box center of the masked region at the given point, so use
            # the bounding-rect center; the mask centroid (via image moments)
            # would shift the patch whenever the region is asymmetric.
            x, y, w, h = cv2.boundingRect(hair_mask)
            center = (x + w // 2, y + h // 2)
            inpainted_bgr = cv2.seamlessClone(
                inpainted_bgr, working_bgr, hair_mask * 255, center, cv2.MIXED_CLONE
            )
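            # (NORMAL_CLONE would copy source gradients outright; MIXED_CLONE
            # keeps whichever gradient is stronger, so skin pores and lighting
            # from the original photo survive inside the blend region.)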

        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)

        # Upscale if downscaled
        if scale_factor < 1.0:
            bald_up = cv2.resize(inpainted_rgb, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
            mask_up = cv2.resize(hair_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
        else:
            bald_up = inpainted_rgb
            mask_up = hair_mask
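        # LANCZOS4 keeps the upscaled fill sharp; NEAREST keeps the mask strictly
        # binary, and the Gaussian blur in the blend below supplies the soft edge.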

        # Soft alpha blend (always, for smooth edges)
        mask_float = cv2.GaussianBlur(mask_up.astype(np.float32) * 255, (25, 25), 0) / 255.0
        result = (1 - mask_float[..., None]) * original_np.astype(np.float32) + \
                 mask_float[..., None] * bald_up.astype(np.float32)
        result = np.clip(result, 0, 255).astype(np.uint8)
        return Image.fromarray(result)

    except UnidentifiedImageError:
        raise ValueError("Invalid image format or corrupt image!")
    except ValueError:
        # Propagate sentinel errors (e.g. NO_HAIR_DETECTED) unchanged so the
        # caller can distinguish "no hair found" from a genuine failure.
        raise
    except Exception as e:
        logger.error(f"Bald processing failed: {e}", exc_info=True)
        raise RuntimeError(f"Bald processing failed: {e}") from e