#!/usr/bin/env python3
"""
utils.segmentation
─────────────────────────────────────────────────────────────────────────────
All high-quality person-segmentation code for BackgroundFX Pro.

Exports
-------
segment_person_hq(image, predictor, fallback_enabled=True) → np.ndarray
segment_person_hq_original(image, predictor, fallback_enabled=True) → np.ndarray
SegmentationError — custom exception for segmentation errors

Everything else is prefixed "_" and considered private.
"""

from __future__ import annotations

from typing import Any, Optional, Tuple
import logging

import cv2
import numpy as np
import torch

log = logging.getLogger(__name__)

# ============================================================================
# CUSTOM EXCEPTION
# ============================================================================

class SegmentationError(Exception):
    """Custom exception for segmentation-related errors."""
    pass

# ============================================================================
# TUNABLE CONSTANTS
# ============================================================================

USE_ENHANCED_SEGMENTATION = True
USE_INTELLIGENT_PROMPTING = True
USE_ITERATIVE_REFINEMENT = True
MIN_AREA_RATIO = 0.015   # reject masks covering < 1.5 % of the frame
MAX_AREA_RATIO = 0.97    # reject masks covering > 97 % of the frame
SALIENCY_THRESH = 0.65
GRABCUT_ITERS = 3

# ----------------------------------------------------------------------------
# Public -- main entry-points
# ----------------------------------------------------------------------------
__all__ = [
    "segment_person_hq",
    "segment_person_hq_original",
    "SegmentationError",
]

# ============================================================================
# SAM2 TO MATANYONE MASK BRIDGE
# ============================================================================

def _sam2_to_matanyone_mask(masks: Any, scores: Any = None) -> np.ndarray:
    """
    Convert SAM2 multi-mask output to the single best mask for MatAnyone.

    SAM2 returns (N, H, W) where N is typically 3 candidate masks.
    We must return a single binary (H, W) uint8 mask (0/255).
    """
    if masks is None or len(masks) == 0:
        raise SegmentationError("No masks returned from SAM2")

    # Handle torch tensors
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy()
    if scores is not None and isinstance(scores, torch.Tensor):
        scores = scores.cpu().numpy()

    # Ensure we have the right shape
    if masks.ndim == 4:          # (B, N, H, W)
        masks = masks[0]         # take first batch
    if masks.ndim != 3:          # should be (N, H, W)
        raise SegmentationError(f"Unexpected mask shape: {masks.shape}")

    # Select best mask: highest score if available, else largest area
    if scores is not None and len(scores) > 0:
        best_idx = int(np.argmax(scores))
    else:
        areas = [np.sum(m > 0.5) for m in masks]
        best_idx = int(np.argmax(areas))

    mask = masks[best_idx]

    # Convert to uint8 binary mask. Boolean masks must be scaled to 0/255
    # explicitly, otherwise the 127-threshold below would zero them out.
    if mask.dtype == bool:
        mask = mask.astype(np.uint8) * 255
    elif mask.dtype in (np.float32, np.float64):
        mask = (mask > 0.5).astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    # Ensure single channel
    if mask.ndim == 3:
        mask = mask[:, :, 0] if mask.shape[2] > 1 else mask.squeeze()

    # Binary threshold
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    # Verify output shape
    assert mask.ndim == 2, f"Output mask must be 2D, got shape {mask.shape}"
    return mask
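
# ----------------------------------------------------------------------------
# Illustrative only: a minimal, self-contained sketch of what the bridge
# expects and produces, using synthetic data instead of a real SAM2 call.
# The shapes and dtypes below mirror typical SAM2 output but are assumptions;
# _demo_mask_bridge is not part of the public API.
# ----------------------------------------------------------------------------

def _demo_mask_bridge() -> None:
    """Tiny self-check of _sam2_to_matanyone_mask on fake SAM2 output."""
    h, w = 4, 6
    fake_masks = np.zeros((3, h, w), dtype=np.float32)   # (N, H, W), like SAM2
    fake_masks[1, 1:3, 2:5] = 0.9                        # candidate 1 is "best"
    fake_scores = np.array([0.2, 0.8, 0.5], dtype=np.float32)
    out = _sam2_to_matanyone_mask(fake_masks, fake_scores)
    assert out.shape == (h, w) and out.dtype == np.uint8
    assert set(np.unique(out)) <= {0, 255}               # strictly binary 0/255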

# ============================================================================
# MAIN API
# ============================================================================

def segment_person_hq(image: np.ndarray, predictor: Any,
                      fallback_enabled: bool = True) -> np.ndarray:
    """
    High-quality person segmentation.

    Tries SAM-2 with smart prompts first, then a classical CV cascade,
    then a geometric fallback. Returns a uint8 mask (0/255). The enhanced
    path never raises on segmentation failure; only an invalid input image
    raises SegmentationError.
    """
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    # 1) — SAM-2 path -------------------------------------------------------
    if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
        try:
            predictor.set_image(image)
            mask = (
                _segment_with_intelligent_prompts(image, predictor)
                if USE_INTELLIGENT_PROMPTING
                else _segment_with_basic_prompts(image, predictor)
            )
            if USE_ITERATIVE_REFINEMENT:
                mask = _auto_refine_mask_iteratively(image, mask, predictor)
            if _validate_mask_quality(mask, image.shape[:2]):
                return mask
            log.warning("SAM2 mask failed validation → fallback")
        except Exception as e:
            log.warning(f"SAM2 path failed: {e}")

    # 2) — Classical cascade ------------------------------------------------
    try:
        mask = _classical_segmentation_cascade(image)
        if _validate_mask_quality(mask, image.shape[:2]):
            return mask
        log.warning("Classical cascade weak → geometric fallback")
    except Exception as e:
        log.debug(f"Classical cascade error: {e}")

    # 3) — Last-chance geometric ellipse --------------------------------------
    return _geometric_person_mask(image)


def segment_person_hq_original(image: np.ndarray, predictor: Any,
                               fallback_enabled: bool = True) -> np.ndarray:
    """
    Very first implementation, kept for rollback. Fewer smarts, still robust.
    """
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    try:
        if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
            h, w = image.shape[:2]
            predictor.set_image(image)
            points = np.array([
                [w // 2, h // 4],
                [w // 2, h // 2],
                [w // 2, 3 * h // 4],
                [w // 3, h // 2],
                [2 * w // 3, h // 2],
            ], dtype=np.float32)
            labels = np.ones(len(points), np.int32)
            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True,
                )
            # Use the bridge function to get the single best mask
            if masks is not None and len(masks):
                mask = _sam2_to_matanyone_mask(masks, scores)
                if _validate_mask_quality(mask, image.shape[:2]):
                    return mask
    except Exception as e:
        log.warning(f"segment_person_hq_original error: {e}")

    if fallback_enabled:
        return _classical_segmentation_cascade(image)
    # Raised outside the try block so the blanket except above cannot
    # swallow it (previously the fallback ran even when disabled).
    raise SegmentationError("SAM2 failed and fallback disabled")
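
# ----------------------------------------------------------------------------
# Illustrative only: how a caller might exercise the public API without a
# real SAM2 checkpoint. _StubPredictor is a hypothetical stand-in that mimics
# the duck-typed set_image()/predict() interface this module checks for; it
# is not part of any SAM2 release.
# ----------------------------------------------------------------------------

class _StubPredictor:
    """Hypothetical predictor returning one hand-made (N, H, W) mask stack."""

    def set_image(self, image: np.ndarray) -> None:
        self._shape = image.shape[:2]

    def predict(self, point_coords=None, point_labels=None,
                multimask_output=True):
        h, w = self._shape
        masks = np.zeros((3, h, w), dtype=np.float32)
        masks[0, h // 4: 3 * h // 4, w // 4: 3 * w // 4] = 1.0  # centered box
        scores = np.array([0.9, 0.1, 0.1], dtype=np.float32)
        return masks, scores, None


def _demo_public_api() -> None:
    """Run the full pipeline on a dummy frame with the stub predictor."""
    frame = np.full((240, 320, 3), 128, np.uint8)
    mask = segment_person_hq(frame, _StubPredictor())
    assert mask.shape == frame.shape[:2] and mask.dtype == np.uint8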

# ============================================================================
# INTELLIGENT + BASIC PROMPTING
# ============================================================================

def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    pos, neg = _generate_smart_prompts(image)
    return _sam2_predict(image, predictor, pos, neg)


def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    h, w = image.shape[:2]
    pos = np.array([[w // 2, h // 3], [w // 2, h // 2], [w // 2, 2 * h // 3]], np.float32)
    neg = np.array([[10, 10], [w - 10, 10], [10, h - 10], [w - 10, h - 10]], np.float32)
    return _sam2_predict(image, predictor, pos, neg)


def _sam2_predict(image: np.ndarray, predictor: Any,
                  pos_points: np.ndarray, neg_points: np.ndarray) -> np.ndarray:
    if pos_points.size == 0:
        pos_points = np.array([[image.shape[1] // 2, image.shape[0] // 2]], np.float32)
    points = np.vstack([pos_points, neg_points])
    labels = np.hstack([np.ones(len(pos_points)),
                        np.zeros(len(neg_points))]).astype(np.int32)
    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=True,
        )
    # Use the bridge function to convert the multi-mask output to one mask
    return _sam2_to_matanyone_mask(masks, scores)


def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Simple saliency-based heuristic to auto-place positive / negative points.
    """
    h, w = image.shape[:2]
    sal = _compute_saliency(image)
    pos, neg = [], []
    if sal is not None:
        high = sal > (SALIENCY_THRESH - 0.1)
        contours, _ = cv2.findContours((high * 255).astype(np.uint8),
                                       cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        # Place a positive point at the centroid of each of the three
        # largest salient blobs
        for c in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
            M = cv2.moments(c)
            if M["m00"]:
                pos.append([int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])])
    if not pos:
        pos = [[w // 2, h // 2]]
    neg = [[10, 10], [w - 10, 10], [10, h - 10], [w - 10, h - 10]]
    return np.asarray(pos, np.float32), np.asarray(neg, np.float32)

# ============================================================================
# CLASSICAL SEGMENTATION CASCADE
# ============================================================================

def _classical_segmentation_cascade(image: np.ndarray) -> np.ndarray:
    """
    Edge-median background subtraction → saliency flood-fill → GrabCut.
    Each stage returns early if its mask already passes validation.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Stage 1: assume border pixels are background; keep anything whose
    # intensity differs from the border median by more than 30 levels.
    edge_px = np.concatenate([gray[0], gray[-1], gray[:, 0], gray[:, -1]])
    diff = np.abs(gray.astype(float) - np.median(edge_px))
    mask = (diff > 30).astype(np.uint8) * 255
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE,
                            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))
    if _validate_mask_quality(mask, image.shape[:2]):
        return mask

    # Stage 2: saliency + flood-fill
    mask = _refine_with_saliency(image, mask)
    if _validate_mask_quality(mask, image.shape[:2]):
        return mask

    # Stage 3: GrabCut
    mask = _refine_with_grabcut(image, mask)
    if _validate_mask_quality(mask, image.shape[:2]):
        return mask

    # Geometric fallback
    return _geometric_person_mask(image)

# Saliency, GrabCut helpers --------------------------------------------------

def _compute_saliency(image: np.ndarray) -> Optional[np.ndarray]:
    """Spectral-residual saliency map normalised to [0, 1], or None."""
    try:
        if hasattr(cv2, "saliency"):
            s = cv2.saliency.StaticSaliencySpectralResidual_create()
            ok, smap = s.computeSaliency(image)
            if ok:
                smap = (smap - smap.min()) / max(1e-6, smap.max() - smap.min())
                return smap
    except Exception:
        pass
    return None


def _auto_person_rect(image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """Bounding rect of the largest salient blob, padded by 5 % per side."""
    sal = _compute_saliency(image)
    if sal is None:
        return None
    m = (sal > SALIENCY_THRESH).astype(np.uint8)
    cnts, _ = cv2.findContours(m * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return None
    x, y, w, h = cv2.boundingRect(max(cnts, key=cv2.contourArea))
    H, W = image.shape[:2]
    pad = 0.05
    x = max(0, int(x - W * pad))
    y = max(0, int(y - H * pad))
    w = min(W - x, int(w * (1 + 2 * pad)))
    h = min(H - y, int(h * (1 + 2 * pad)))
    return x, y, w, h


def _refine_with_grabcut(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    h, w = image.shape[:2]
    gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
    gc[seed > 200] = cv2.GC_FGD
    # The rect is only a hint: with GC_INIT_WITH_MASK, grabCut initialises
    # from `gc` and ignores the rect argument.
    rect = _auto_person_rect(image) or (w // 4, h // 6, w // 2, int(h * 0.7))
    bgd = np.zeros((1, 65), np.float64)
    fgd = np.zeros((1, 65), np.float64)
    cv2.grabCut(image, gc, rect, bgd, fgd, GRABCUT_ITERS, cv2.GC_INIT_WITH_MASK)
    return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)


def _refine_with_saliency(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    sal = _compute_saliency(image)
    if sal is None:
        return seed
    high = (sal > SALIENCY_THRESH).astype(np.uint8) * 255
    # Flood-fill the thresholded saliency map from the seed mask's centroid
    ys, xs = np.where(seed > 127)
    cy = int(np.mean(ys)) if len(ys) else image.shape[0] // 2
    cx = int(np.mean(xs)) if len(xs) else image.shape[1] // 2
    ff = high.copy()
    cv2.floodFill(ff, None, (cx, cy), 255, loDiff=5, upDiff=5)
    return ff
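
# ----------------------------------------------------------------------------
# Illustrative only: the classical cascade on a synthetic frame — a bright
# square on a dark background — where stage 1 (border-median subtraction)
# should already succeed. Purely a sketch; the pass/fail outcome depends on
# the tunable constants above.
# ----------------------------------------------------------------------------

def _demo_classical_cascade() -> None:
    """Stage 1 of the cascade isolates a bright subject on a dark border."""
    frame = np.full((200, 200, 3), 20, np.uint8)      # dark "background"
    frame[50:150, 60:140] = 220                       # bright "person"
    mask = _classical_segmentation_cascade(frame)
    ratio = np.sum(mask > 127) / mask.size
    assert MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO  # plausible person mask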

# ============================================================================
# QUALITY / HELPER FUNCTIONS
# ============================================================================

def _validate_mask_quality(mask: np.ndarray, shape: Tuple[int, int]) -> bool:
    """A mask is plausible if its foreground area ratio is within bounds."""
    if mask is None or mask.size == 0:
        return False
    h, w = shape
    ratio = np.sum(mask > 127) / (h * w)
    return MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO


def _process_mask(mask: np.ndarray) -> np.ndarray:
    """
    Legacy mask processor — kept for compatibility but mostly replaced by
    _sam2_to_matanyone_mask.
    """
    if mask.dtype in (np.float32, np.float64):
        if mask.max() <= 1.0:
            mask = (mask * 255).astype(np.uint8)
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.ndim == 3:
        mask = mask.squeeze()
    if mask.ndim == 3:  # multi-channel mask → collapse to first channel
        mask = mask[:, :, 0]
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    return mask


def _geometric_person_mask(image: np.ndarray) -> np.ndarray:
    """Last-resort fallback: a filled ellipse roughly where a person stands."""
    h, w = image.shape[:2]
    mask = np.zeros((h, w), np.uint8)
    cv2.ellipse(mask, (w // 2, h // 2), (w // 3, int(h / 2.5)),
                0, 0, 360, 255, -1)
    return mask

# ============================================================================
# OPTIONAL: Iterative auto-refinement (lightweight)
# ============================================================================

def _auto_refine_mask_iteratively(image, mask, predictor, max_iterations=1):
    # Simple one-pass hook (the full version lives in refinement.py)
    return mask
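
# ----------------------------------------------------------------------------
# Illustrative only: run the demo sketches above as a quick smoke test when
# this module is executed directly. The demos use synthetic inputs and the
# hypothetical _StubPredictor; none of them require a SAM2 checkpoint.
# ----------------------------------------------------------------------------

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    _demo_mask_bridge()
    _demo_public_api()
    _demo_classical_cascade()
    log.info("All segmentation demo sketches passed.")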