# bald_processor.py (Space file)
# Last commit: eb4d078 "Update bald_processor.py" by Seniordev22 (verified)
import cv2
import torch
import numpy as np
from PIL import Image, UnidentifiedImageError
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run inference on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

# Load the face-parsing processor and model once, at import time, so every call
# to make_realistic_bald reuses the same weights. A load failure aborts the
# import with a RuntimeError that chains the original exception.
try:
    logger.info("Loading SegFormer face-parsing model...")
    processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
    model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
    model.to(device)
    model.eval()  # inference only: disable dropout etc.
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Failed to load model: {e}", exc_info=True)
    raise RuntimeError("SegFormer model load failed!") from e

# Label ids of the face-parsing output that this module relies on.
hair_class_id = 13
ear_class_ids = [8, 9]  # l_ear=8, r_ear=9 (confirmed from model card)
skin_class_id = 1  # skin
def _segment_face(image_rgb: np.ndarray) -> np.ndarray:
    """Run the face-parsing model and return an (H, W) label map matching ``image_rgb``."""
    height, width = image_rgb.shape[:2]
    inputs = processor(images=Image.fromarray(image_rgb), return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # The logits come out at the model's own resolution; upsample them before
    # argmax so each label lines up with an image pixel.
    upsampled = torch.nn.functional.interpolate(
        logits, size=(height, width), mode="bilinear", align_corners=False
    )
    return upsampled.argmax(dim=1).squeeze(0).cpu().numpy()


def _build_hair_mask(parsing: np.ndarray) -> np.ndarray:
    """Return a cleaned 0/1 uint8 hair mask from a per-pixel label map, excluding ears."""
    hair_mask = (parsing == hair_class_id).astype(np.uint8)
    ears = np.isin(parsing, ear_class_ids)
    hair_mask[ears] = 0

    # Close small gaps between strands, then grow the mask so hair edges are covered.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
    hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
    hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
    # Blur + threshold smooths the jagged contour left by the morphology.
    hair_mask = (cv2.GaussianBlur(hair_mask.astype(np.float32), (5, 5), 0) > 0.28).astype(np.uint8)

    # The closing/dilation above grows the mask outward, back into the ears;
    # re-apply the exclusion so ears are never inpainted.
    hair_mask[ears] = 0
    return hair_mask


def _match_skin_tone_and_blend(
    inpainted_bgr: np.ndarray,
    working_bgr: np.ndarray,
    hair_mask: np.ndarray,
    parsing: np.ndarray,
) -> np.ndarray:
    """Shift the inpainted region toward the mean skin colour, then Poisson-blend it.

    ``inpainted_bgr`` may be modified in place; the blended BGR image is returned.
    """
    # Reference colour = skin pixels outside the hair mask, eroded so pixels on
    # the hair/skin border do not skew the mean.
    skin_mask = (parsing == skin_class_id).astype(np.uint8)
    skin_mask[hair_mask == 1] = 0
    ref_mask = cv2.erode(skin_mask, np.ones((3, 3), np.uint8), iterations=1)

    if ref_mask.any():
        ref_mean = np.array(cv2.mean(working_bgr, mask=ref_mask * 255)[:3])
        inp_mean = np.array(cv2.mean(inpainted_bgr, mask=hair_mask * 255)[:3])
        # Apply only 70% of the difference, capped at +/-40 per channel,
        # to avoid over-correction.
        color_diff = np.clip((ref_mean - inp_mean) * 0.7, -40, 40)
        # Index with the 2-D mask so the selection is (N, 3) and the per-channel
        # shift broadcasts; a 3-channel mask would flatten it to (3N,) and fail.
        hair_px = hair_mask == 1
        inpainted_bgr[hair_px] = np.clip(
            inpainted_bgr[hair_px].astype(np.float32) + color_diff, 0, 255
        ).astype(np.uint8)

    # seamlessClone ignores a 1-px frame around the mask and centres the mask's
    # bounding box on the given point. Zero the frame ourselves and pass the
    # bbox centre so the cloned patch lands exactly where it came from.
    clone_mask = hair_mask * 255
    clone_mask[[0, -1], :] = 0
    clone_mask[:, [0, -1]] = 0
    if not clone_mask.any():
        return inpainted_bgr
    x, y, w, h = cv2.boundingRect(clone_mask)
    center = (x + w // 2, y + h // 2)
    # NORMAL_CLONE: MIXED_CLONE keeps the stronger of source/destination
    # gradients, which would pull the original hair texture back in.
    return cv2.seamlessClone(inpainted_bgr, working_bgr, clone_mask, center, cv2.NORMAL_CLONE)


def make_realistic_bald(input_image: Image.Image) -> Image.Image:
    """Remove the hair from a portrait and return an RGB image of the original size.

    Pipeline: downscale to at most 2048 px on the longest side, segment the face
    with the SegFormer face-parsing model, build a hair mask (ears excluded),
    inpaint it, colour-correct and Poisson-blend large regions, then upscale and
    feather the result back onto the original image.

    Args:
        input_image: the source portrait; any PIL mode is accepted and is
            converted to RGB first.

    Returns:
        A new RGB ``PIL.Image.Image`` with the same size as ``input_image``.

    Raises:
        ValueError: if ``input_image`` is None.
        RuntimeError: on any processing failure. NOTE(review): this includes the
            no-hair case, which surfaces as
            ``RuntimeError("Bald processing failed: NO_HAIR_DETECTED")``.
    """
    if input_image is None:
        raise ValueError("No input image provided!")
    try:
        # Normalise to 3-channel RGB: RGBA / grayscale / palette images would
        # otherwise break the RGB<->BGR conversions and the final blend.
        original_np = np.array(input_image.convert("RGB"))
        orig_h, orig_w = original_np.shape[:2]

        # Work on a bounded-size copy to keep segmentation and inpainting tractable.
        MAX_DIM = 2048
        longest = max(orig_w, orig_h)
        scale_factor = 1.0
        working_np = original_np
        if longest > MAX_DIM:
            scale_factor = MAX_DIM / longest
            working_size = (int(orig_w * scale_factor), int(orig_h * scale_factor))
            working_np = cv2.resize(original_np, working_size, interpolation=cv2.INTER_AREA)
        working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        parsing = _segment_face(working_np)
        hair_mask = _build_hair_mask(parsing)
        hair_pixels = int(hair_mask.sum())
        if hair_pixels < 50:
            raise ValueError("NO_HAIR_DETECTED")

        # Large regions: Telea with a bigger radius, followed by colour matching
        # and Poisson blending. Small regions: Navier-Stokes only.
        large_hair = hair_pixels > 220000
        radius = 20 if large_hair else 12
        flag = cv2.INPAINT_TELEA if large_hair else cv2.INPAINT_NS
        inpainted_bgr = cv2.inpaint(working_bgr, hair_mask * 255, inpaintRadius=radius, flags=flag)
        if large_hair:
            inpainted_bgr = _match_skin_tone_and_blend(inpainted_bgr, working_bgr, hair_mask, parsing)
        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)

        # Bring the result and mask back to the original resolution if we downscaled.
        if scale_factor < 1.0:
            bald_up = cv2.resize(inpainted_rgb, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
            mask_up = cv2.resize(hair_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
        else:
            bald_up, mask_up = inpainted_rgb, hair_mask

        # Feather the mask so the bald region fades into the untouched original.
        alpha = cv2.GaussianBlur(mask_up.astype(np.float32), (25, 25), 0)[..., None]
        result = (1.0 - alpha) * original_np.astype(np.float32) + alpha * bald_up.astype(np.float32)
        return Image.fromarray(np.clip(result, 0, 255).astype(np.uint8))
    except UnidentifiedImageError as err:
        raise ValueError("Invalid image format or corrupt image!") from err
    except Exception as e:
        logger.error(f"Bald processing failed: {str(e)}", exc_info=True)
        raise RuntimeError(f"Bald processing failed: {str(e)}") from e