#!/usr/bin/env python3
"""
utils.segmentation
─────────────────────────────────────────────────────────────────────────────
All high-quality person-segmentation code for BackgroundFX Pro.
Exports
-------
segment_person_hq(image, predictor, fallback_enabled=True) → np.ndarray
segment_person_hq_original(image, predictor, fallback_enabled=True) → np.ndarray
SegmentationError - custom exception for segmentation failures
Everything else is prefixed "_" and considered private.
"""
from __future__ import annotations
from typing import Any, Optional, Tuple
import logging
import cv2
import numpy as np
import torch
log = logging.getLogger(__name__)
# ============================================================================
# CUSTOM EXCEPTION
# ============================================================================
class SegmentationError(Exception):
"""Custom exception for segmentation-related errors"""
pass
# ============================================================================
# TUNABLE CONSTANTS
# ============================================================================
USE_ENHANCED_SEGMENTATION = True
USE_INTELLIGENT_PROMPTING = True
USE_ITERATIVE_REFINEMENT = True
MIN_AREA_RATIO = 0.015
MAX_AREA_RATIO = 0.97
SALIENCY_THRESH = 0.65
GRABCUT_ITERS = 3
# ----------------------------------------------------------------------------
# Public -- main entry-points
# ----------------------------------------------------------------------------
__all__ = [
"segment_person_hq",
"segment_person_hq_original",
"SegmentationError",
]
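# Typical usage (illustrative sketch; any object exposing set_image()/predict()
# with the SAM2 predictor signature works, and passing predictor=None exercises
# the classical fallback path):
#
#     import cv2
#     from utils.segmentation import segment_person_hq
#
#     frame = cv2.imread("frame.png")        # BGR frame
#     mask = segment_person_hq(frame, None)  # uint8 mask, values 0 or 255
#     cv2.imwrite("mask.png", mask)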
# ============================================================================
# SAM2 TO MATANYONE MASK BRIDGE
# ============================================================================
def _sam2_to_matanyone_mask(masks: Any, scores: Any = None) -> np.ndarray:
"""
Convert SAM2 multi-mask output to single best mask for MatAnyone.
SAM2 returns (N, H, W) where N is typically 3 masks.
We need to return a single (H, W) mask.
"""
if masks is None or len(masks) == 0:
raise SegmentationError("No masks returned from SAM2")
# Handle torch tensors
if isinstance(masks, torch.Tensor):
masks = masks.cpu().numpy()
if scores is not None and isinstance(scores, torch.Tensor):
scores = scores.cpu().numpy()
    # Normalise shape to (N, H, W)
    if masks.ndim == 4:  # (B, N, H, W): take the first batch
        masks = masks[0]
    if masks.ndim == 2:  # single (H, W) mask: add a leading axis
        masks = masks[None]
    if masks.ndim != 3:
        raise SegmentationError(f"Unexpected mask shape: {masks.shape}")
# Select best mask
if scores is not None and len(scores) > 0:
best_idx = int(np.argmax(scores))
else:
# Fallback: pick mask with largest area
areas = [np.sum(m > 0.5) for m in masks]
best_idx = int(np.argmax(areas))
mask = masks[best_idx]
    # Convert to uint8 binary mask (floats are probabilities, bools are 0/1)
    if mask.dtype in (np.float32, np.float64):
        mask = (mask > 0.5).astype(np.uint8) * 255
    elif mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.max() == 1:  # a 0/1 mask would be wiped out by the 127 threshold below
        mask = mask * 255
# Ensure single channel
if mask.ndim == 3:
mask = mask[:, :, 0] if mask.shape[2] > 1 else mask.squeeze()
# Binary threshold
_, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
# Verify output shape
assert mask.ndim == 2, f"Output mask must be 2D, got shape {mask.shape}"
return mask
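# Illustrative example with fabricated inputs: three candidate masks, the
# second scoring highest, so it is selected and scaled to 0/255.
#
#     masks = np.stack([np.zeros((4, 4)), np.ones((4, 4)), np.zeros((4, 4))])
#     scores = np.array([0.1, 0.9, 0.2])
#     out = _sam2_to_matanyone_mask(masks.astype(np.float32), scores)
#     assert out.shape == (4, 4) and out.max() == 255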
# ============================================================================
# MAIN API
# ============================================================================
def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
"""
High-quality person segmentation. Tries SAM-2 with smart prompts first,
then a classical CV cascade, then a geometric fallback.
    Returns a uint8 mask (0/255). With fallback_enabled=True it never raises
    on segmentation failure; an invalid input image still raises SegmentationError.
"""
if not USE_ENHANCED_SEGMENTATION:
return segment_person_hq_original(image, predictor, fallback_enabled)
if image is None or image.size == 0:
raise SegmentationError("Invalid input image")
    # 1) → SAM-2 path -------------------------------------------------------
if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
try:
predictor.set_image(image)
mask = (
_segment_with_intelligent_prompts(image, predictor)
if USE_INTELLIGENT_PROMPTING
else _segment_with_basic_prompts(image, predictor)
)
if USE_ITERATIVE_REFINEMENT:
mask = _auto_refine_mask_iteratively(image, mask, predictor)
if _validate_mask_quality(mask, image.shape[:2]):
return mask
log.warning("SAM2 mask failed validation β fallback")
except Exception as e:
log.warning(f"SAM2 path failed: {e}")
# 2) β Classical cascade ----------------------------------------------
try:
mask = _classical_segmentation_cascade(image)
if _validate_mask_quality(mask, image.shape[:2]):
return mask
log.warning("Classical cascade weak β geometric fallback")
except Exception as e:
log.debug(f"Classical cascade error: {e}")
# 3) β Last-chance geometric ellipse ----------------------------------
return _geometric_person_mask(image)
def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
"""
Very first implementation kept for rollback. Fewer smarts, still robust.
"""
if image is None or image.size == 0:
raise SegmentationError("Invalid input image")
try:
if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
h, w = image.shape[:2]
predictor.set_image(image)
points = np.array([
[w//2, h//4],
[w//2, h//2],
[w//2, 3*h//4],
[w//3, h//2],
[2*w//3, h//2],
], dtype=np.float32)
labels = np.ones(len(points), np.int32)
with torch.no_grad():
masks, scores, _ = predictor.predict(
point_coords=points,
point_labels=labels,
multimask_output=True,
)
# Use the bridge function to get single best mask
if masks is not None and len(masks):
mask = _sam2_to_matanyone_mask(masks, scores)
if _validate_mask_quality(mask, image.shape[:2]):
return mask
        if fallback_enabled:
            return _classical_segmentation_cascade(image)
        raise SegmentationError("SAM2 failed and fallback disabled")
    except Exception as e:
        if not fallback_enabled:
            raise  # don't swallow the deliberate no-fallback error above
        log.warning(f"segment_person_hq_original error: {e}")
        return _classical_segmentation_cascade(image)
# ============================================================================
# INTELLIGENT + BASIC PROMPTING
# ============================================================================
def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
pos, neg = _generate_smart_prompts(image)
return _sam2_predict(image, predictor, pos, neg)
def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
h, w = image.shape[:2]
pos = np.array([[w//2, h//3], [w//2, h//2], [w//2, 2*h//3]], np.float32)
neg = np.array([[10, 10], [w-10, 10], [10, h-10], [w-10, h-10]], np.float32)
return _sam2_predict(image, predictor, pos, neg)
def _sam2_predict(image: np.ndarray, predictor: Any,
pos_points: np.ndarray, neg_points: np.ndarray) -> np.ndarray:
if pos_points.size == 0:
pos_points = np.array([[image.shape[1]//2, image.shape[0]//2]], np.float32)
points = np.vstack([pos_points, neg_points])
labels = np.hstack([np.ones(len(pos_points)), np.zeros(len(neg_points))]).astype(np.int32)
with torch.no_grad():
masks, scores, _ = predictor.predict(
point_coords=points,
point_labels=labels,
multimask_output=True,
)
# Use the bridge function to convert multi-mask to single mask
return _sam2_to_matanyone_mask(masks, scores)
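# Prompt layout example: for any frame the basic-prompt path stacks 3 positive
# body points over 4 negative corner points, so `points` has shape (7, 2) and
# `labels` is [1, 1, 1, 0, 0, 0, 0].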
def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""
Simple saliency-based heuristic to auto-place positive / negative points.
"""
h, w = image.shape[:2]
sal = _compute_saliency(image)
pos, neg = [], []
if sal is not None:
high = sal > (SALIENCY_THRESH - .1)
contours, _ = cv2.findContours((high*255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
for c in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
M = cv2.moments(c)
if M["m00"]:
pos.append([int(M["m10"]/M["m00"]), int(M["m01"]/M["m00"])])
if not pos:
pos = [[w//2, h//2]]
neg = [[10, 10], [w-10, 10], [10, h-10], [w-10, h-10]]
return np.asarray(pos, np.float32), np.asarray(neg, np.float32)
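# Example behaviour: on a frame with one strong salient blob, the positives are
# the centroids of the (up to three) largest high-saliency contours; with no
# usable saliency map it degrades to a single centre point. The negatives are
# always the four near-corner pixels.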
# ============================================================================
# CLASSICAL SEGMENTATION CASCADE
# ============================================================================
def _classical_segmentation_cascade(image: np.ndarray) -> np.ndarray:
"""
    Edge-median background subtraction → saliency flood-fill → GrabCut.
"""
    # Step 1: distance from the border-median intensity (cheap background model)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edge_px = np.concatenate([gray[0], gray[-1], gray[:, 0], gray[:, -1]])
    diff = np.abs(gray.astype(float) - np.median(edge_px))
    mask = (diff > 30).astype(np.uint8) * 255
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE,
                            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))
if _validate_mask_quality(mask, image.shape[:2]):
return mask
# Saliency + flood-fill
mask = _refine_with_saliency(image, mask)
if _validate_mask_quality(mask, image.shape[:2]):
return mask
# GrabCut
mask = _refine_with_grabcut(image, mask)
if _validate_mask_quality(mask, image.shape[:2]):
return mask
# Geometric fallback
return _geometric_person_mask(image)
# Saliency, GrabCut helpers --------------------------------------------------
def _compute_saliency(image: np.ndarray) -> Optional[np.ndarray]:
    """Normalised spectral-residual saliency map, or None if unavailable.

    cv2.saliency ships with opencv-contrib-python; plain opencv-python builds
    lack it, hence the hasattr guard.
    """
    try:
        if hasattr(cv2, "saliency"):
            s = cv2.saliency.StaticSaliencySpectralResidual_create()
            ok, smap = s.computeSaliency(image)
            if ok:
                smap = (smap - smap.min()) / max(1e-6, smap.max() - smap.min())
                return smap
    except Exception:
        pass
    return None
def _auto_person_rect(image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    """Bounding rect (x, y, w, h) of the largest salient blob, padded by 5%."""
    sal = _compute_saliency(image)
    if sal is None:
        return None
    m = (sal > SALIENCY_THRESH).astype(np.uint8)
    cnts, _ = cv2.findContours(m * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return None
    x, y, w, h = cv2.boundingRect(max(cnts, key=cv2.contourArea))
    H, W = image.shape[:2]
    pad = 0.05
    x = max(0, int(x - W * pad)); y = max(0, int(y - H * pad))
    w = min(W - x, int(w * (1 + 2 * pad))); h = min(H - y, int(h * (1 + 2 * pad)))
    return x, y, w, h
def _refine_with_grabcut(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    """Run mask-initialised GrabCut, seeding definite foreground from `seed`."""
    h, w = image.shape[:2]
    gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
    gc[seed > 200] = cv2.GC_FGD
    rect = _auto_person_rect(image) or (w//4, h//6, w//2, int(h*0.7))
    bgd, fgd = np.zeros((1, 65), np.float64), np.zeros((1, 65), np.float64)
    cv2.grabCut(image, gc, rect, bgd, fgd, GRABCUT_ITERS, cv2.GC_INIT_WITH_MASK)
    return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)
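# GrabCut label reference (OpenCV constants): GC_BGD=0, GC_FGD=1, GC_PR_BGD=2,
# GC_PR_FGD=3. Everything starts as probable background, confident seed pixels
# (>200) become definite foreground, and the returned mask keeps definite plus
# probable foreground as 255.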
def _refine_with_saliency(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    """Flood-fill the thresholded saliency map from the seed mask's centroid."""
    sal = _compute_saliency(image)
    if sal is None:
        return seed
    high = (sal > SALIENCY_THRESH).astype(np.uint8) * 255
    ys, xs = np.where(seed > 127)
    cy = int(np.mean(ys)) if len(ys) else image.shape[0] // 2
    cx = int(np.mean(xs)) if len(xs) else image.shape[1] // 2
    ff = high.copy()
    cv2.floodFill(ff, None, (cx, cy), 255, loDiff=5, upDiff=5)
    return ff
# ============================================================================
# QUALITY / HELPER FUNCTIONS
# ============================================================================
def _validate_mask_quality(mask: np.ndarray, shape: Tuple[int, int]) -> bool:
    """Accept a mask only if its foreground area ratio is plausible for a person."""
    h, w = shape
    ratio = np.sum(mask > 127) / (h * w)
    return MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO
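# Concretely: a 1920x1080 frame has 2,073,600 px, so a mask passes only if it
# covers between ~31,104 px (1.5%) and ~2,011,392 px (97%). The bounds are
# deliberately wide and mainly reject empty or frame-filling masks.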
def _process_mask(mask: np.ndarray) -> np.ndarray:
"""Legacy mask processor - kept for compatibility but mostly replaced by _sam2_to_matanyone_mask"""
if mask.dtype in (np.float32, np.float64):
if mask.max() <= 1.0:
mask = (mask*255).astype(np.uint8)
if mask.dtype != np.uint8:
mask = mask.astype(np.uint8)
if mask.ndim == 3:
mask = mask.squeeze()
    if mask.ndim == 3:  # multi-channel mask → collapse to first channel
mask = mask[:,:,0]
_,mask = cv2.threshold(mask,127,255,cv2.THRESH_BINARY)
return mask
def _geometric_person_mask(image: np.ndarray) -> np.ndarray:
h,w = image.shape[:2]
mask = np.zeros((h,w), np.uint8)
cv2.ellipse(mask, (w//2,h//2), (w//3,int(h/2.5)), 0, 0,360, 255,-1)
return mask
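# Sanity check on the fallback: the ellipse has semi-axes (w/3, h/2.5), so its
# area is pi*(w/3)*(h/2.5) ~= 0.42*w*h, i.e. roughly 42% of the frame, which
# always lies inside the [MIN_AREA_RATIO, MAX_AREA_RATIO] validation band.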
# ============================================================================
# OPTIONAL: Iterative auto-refinement (lightweight)
# ============================================================================
def _auto_refine_mask_iteratively(image, mask, predictor, max_iterations=1):
    """One-pass refinement hook; the full iterative version lives in refinement.py."""
    return mask