#!/usr/bin/env python3
"""
utils.segmentation
─────────────────────────────────────────────────────────────────────────────
All high-quality person-segmentation code for BackgroundFX Pro.

Exports
-------
segment_person_hq(image, predictor, fallback_enabled=True) β†’ np.ndarray
segment_person_hq_original(image, predictor, fallback_enabled=True) β†’ np.ndarray
SegmentationError - Custom exception for segmentation errors

Everything else is prefixed "_" and considered private.
"""

from __future__ import annotations
from typing import Any, Tuple, Optional
import logging

import cv2
import numpy as np
import torch

log = logging.getLogger(__name__)

# ============================================================================
# CUSTOM EXCEPTION
# ============================================================================
class SegmentationError(Exception):
    """Custom exception for segmentation-related errors"""
    pass

# ============================================================================
# TUNABLE CONSTANTS
# ============================================================================
USE_ENHANCED_SEGMENTATION   = True   # master switch for the smart-prompt path
USE_INTELLIGENT_PROMPTING   = True   # saliency-driven point placement
USE_ITERATIVE_REFINEMENT    = True   # enable the optional refinement pass

MIN_AREA_RATIO  = 0.015  # reject masks covering < 1.5 % of the frame
MAX_AREA_RATIO  = 0.97   # reject masks covering > 97 % of the frame
SALIENCY_THRESH = 0.65   # cut-off on the normalised saliency map
GRABCUT_ITERS   = 3      # iteration count passed to cv2.grabCut

# ----------------------------------------------------------------------------
# Public -- main entry-points
# ----------------------------------------------------------------------------
__all__ = [
    "segment_person_hq",
    "segment_person_hq_original",
    "SegmentationError",
]
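
# ----------------------------------------------------------------------------
# Usage sketch (illustrative only).  Assumes a SAM2-style `predictor` object
# exposing set_image()/predict() has already been built; the file name is
# hypothetical.
#
#   import cv2
#   from utils.segmentation import segment_person_hq
#
#   frame = cv2.imread("frame_0001.png")              # BGR uint8
#   mask = segment_person_hq(frame, predictor)        # uint8, values 0 or 255
#   cutout = cv2.bitwise_and(frame, frame, mask=mask)
# ----------------------------------------------------------------------------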

# ============================================================================
# SAM2 TO MATANYONE MASK BRIDGE
# ============================================================================
def _sam2_to_matanyone_mask(masks: Any, scores: Any = None) -> np.ndarray:
    """
    Convert SAM2 multi-mask output to single best mask for MatAnyone.
    SAM2 returns (N, H, W) where N is typically 3 masks.
    We need to return a single (H, W) mask.
    """
    if masks is None or len(masks) == 0:
        raise SegmentationError("No masks returned from SAM2")
    
    # Handle torch tensors
    if isinstance(masks, torch.Tensor):
        masks = masks.cpu().numpy()
    if scores is not None and isinstance(scores, torch.Tensor):
        scores = scores.cpu().numpy()
    
    # Ensure we have the right shape
    if masks.ndim == 4:  # (B, N, H, W)
        masks = masks[0]  # Take first batch
    if masks.ndim != 3:  # Should be (N, H, W)
        raise SegmentationError(f"Unexpected mask shape: {masks.shape}")
    
    # Select best mask
    if scores is not None and len(scores) > 0:
        best_idx = int(np.argmax(scores))
    else:
        # Fallback: pick mask with largest area
        areas = [np.sum(m > 0.5) for m in masks]
        best_idx = int(np.argmax(areas))
    
    mask = masks[best_idx]
    
    # Convert to uint8 binary mask.  Bool / 0-1 integer masks are scaled to
    # 0/255 so the 127 threshold below does not zero them out.
    if mask.dtype in (np.float32, np.float64):
        mask = (mask > 0.5).astype(np.uint8) * 255
    elif mask.dtype != np.uint8:
        mask = (mask > 0).astype(np.uint8) * 255
    
    # Ensure single channel
    if mask.ndim == 3:
        mask = mask[:, :, 0] if mask.shape[2] > 1 else mask.squeeze()
    
    # Binary threshold
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    
    # Verify output shape
    assert mask.ndim == 2, f"Output mask must be 2D, got shape {mask.shape}"
    
    return mask
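
# Illustrative check (synthetic values, not from a real SAM2 run): given three
# stacked float masks plus per-mask scores, the highest-scoring mask wins and
# is binarised to {0, 255}:
#
#   masks = np.stack([np.zeros((4, 4)), np.ones((4, 4)), np.zeros((4, 4))])
#   out = _sam2_to_matanyone_mask(masks, np.array([0.1, 0.9, 0.2]))
#   assert out.shape == (4, 4) and set(np.unique(out)) <= {0, 255}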

# ============================================================================
# MAIN API
# ============================================================================

def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    High-quality person segmentation.  Tries SAM-2 with smart prompts first,
    then a classical CV cascade, then a geometric fallback.
    Returns uint8 mask (0/255).  Never raises if fallback_enabled=True.
    """
    if not USE_ENHANCED_SEGMENTATION:
        return segment_person_hq_original(image, predictor, fallback_enabled)

    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    # 1) SAM-2 path ----------------------------------------------------------
    if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
        try:
            predictor.set_image(image)
            mask = (
                _segment_with_intelligent_prompts(image, predictor)
                if USE_INTELLIGENT_PROMPTING
                else _segment_with_basic_prompts(image, predictor)
            )
            if USE_ITERATIVE_REFINEMENT:
                mask = _auto_refine_mask_iteratively(image, mask, predictor)
            if _validate_mask_quality(mask, image.shape[:2]):
                return mask
            log.warning("SAM2 mask failed validation β†’ fallback")
        except Exception as e:
            log.warning(f"SAM2 path failed: {e}")

    # 2) Classical cascade ---------------------------------------------------
    try:
        mask = _classical_segmentation_cascade(image)
        if _validate_mask_quality(mask, image.shape[:2]):
            return mask
        log.warning("Classical cascade weak β†’ geometric fallback")
    except Exception as e:
        log.debug(f"Classical cascade error: {e}")

    # 3) Last-chance geometric ellipse ---------------------------------------
    return _geometric_person_mask(image)


def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
    """
    Very first implementation kept for rollback.  Fewer smarts, still robust.
    """
    if image is None or image.size == 0:
        raise SegmentationError("Invalid input image")

    try:
        if predictor and hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
            h, w = image.shape[:2]
            predictor.set_image(image)

            points = np.array([
                [w//2, h//4],
                [w//2, h//2],
                [w//2, 3*h//4],
                [w//3, h//2],
                [2*w//3, h//2],
            ], dtype=np.float32)
            labels = np.ones(len(points), np.int32)

            with torch.no_grad():
                masks, scores, _ = predictor.predict(
                    point_coords=points,
                    point_labels=labels,
                    multimask_output=True,
                )
            
            # Use the bridge function to get single best mask
            if masks is not None and len(masks):
                mask = _sam2_to_matanyone_mask(masks, scores)
                if _validate_mask_quality(mask, image.shape[:2]):
                    return mask
                    
    except Exception as e:
        log.warning(f"segment_person_hq_original error: {e}")

    # The fallback decision lives outside the try-block so a disabled fallback
    # raises instead of being swallowed by the except clause.
    if fallback_enabled:
        return _classical_segmentation_cascade(image)
    raise SegmentationError("SAM2 failed and fallback disabled")


# ============================================================================
# INTELLIGENT + BASIC PROMPTING
# ============================================================================

def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    pos, neg = _generate_smart_prompts(image)
    return _sam2_predict(image, predictor, pos, neg)


def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
    h, w = image.shape[:2]
    pos = np.array([[w//2, h//3], [w//2, h//2], [w//2, 2*h//3]], np.float32)
    neg = np.array([[10, 10], [w-10, 10], [10, h-10], [w-10, h-10]], np.float32)
    return _sam2_predict(image, predictor, pos, neg)


def _sam2_predict(image: np.ndarray, predictor: Any,
                  pos_points: np.ndarray, neg_points: np.ndarray) -> np.ndarray:
    if pos_points.size == 0:
        pos_points = np.array([[image.shape[1]//2, image.shape[0]//2]], np.float32)
    points = np.vstack([pos_points, neg_points])
    labels = np.hstack([np.ones(len(pos_points)), np.zeros(len(neg_points))]).astype(np.int32)
    with torch.no_grad():
        masks, scores, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=True,
        )
    
    # Use the bridge function to convert multi-mask to single mask
    return _sam2_to_matanyone_mask(masks, scores)


def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Simple saliency-based heuristic to auto-place positive / negative points.
    """
    h, w = image.shape[:2]
    sal = _compute_saliency(image)
    pos = []
    if sal is not None:
        high = sal > (SALIENCY_THRESH - .1)
        contours, _ = cv2.findContours((high*255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        for c in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
            M = cv2.moments(c)
            if M["m00"]:
                pos.append([int(M["m10"]/M["m00"]), int(M["m01"]/M["m00"])])
    if not pos:
        pos = [[w//2, h//2]]
    neg = [[10, 10], [w-10, 10], [10, h-10], [w-10, h-10]]
    return np.asarray(pos, np.float32), np.asarray(neg, np.float32)

# ============================================================================
# CLASSICAL SEGMENTATION CASCADE
# ============================================================================

def _classical_segmentation_cascade(image: np.ndarray) -> np.ndarray:
    """
    Edge-median background subtraction β†’ saliency flood-fill β†’ GrabCut.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edge_px = np.concatenate([gray[0], gray[-1], gray[:, 0], gray[:, -1]])
    diff = np.abs(gray.astype(float) - np.median(edge_px))
    mask = (diff > 30).astype(np.uint8) * 255
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE,
                            cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))
    if _validate_mask_quality(mask, image.shape[:2]):
        return mask
    # Saliency + flood-fill
    mask = _refine_with_saliency(image, mask)
    if _validate_mask_quality(mask, image.shape[:2]):
        return mask
    # GrabCut
    mask = _refine_with_grabcut(image, mask)
    if _validate_mask_quality(mask, image.shape[:2]):
        return mask
    # Geometric fallback
    return _geometric_person_mask(image)

#  Saliency, GrabCut helpers --------------------------------------------------

def _compute_saliency(image: np.ndarray) -> Optional[np.ndarray]:
    try:
        if hasattr(cv2, "saliency"):
            s = cv2.saliency.StaticSaliencySpectralResidual_create()
            ok, smap = s.computeSaliency(image)
            if ok:
                smap = (smap - smap.min()) / max(1e-6, smap.max()-smap.min())
                return smap
    except Exception:
        pass
    return None

def _auto_person_rect(image: np.ndarray) -> Optional[Tuple[int, int, int, int]]:
    sal = _compute_saliency(image)
    if sal is None:
        return None
    m = (sal > SALIENCY_THRESH).astype(np.uint8)
    cnts, _ = cv2.findContours(m * 255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not cnts:
        return None
    x, y, w, h = cv2.boundingRect(max(cnts, key=cv2.contourArea))
    H, W = image.shape[:2]
    pad = 0.05  # grow the box by 5 % on every side
    x = max(0, int(x - W * pad))
    y = max(0, int(y - H * pad))
    w = min(W - x, int(w * (1 + 2 * pad)))
    h = min(H - y, int(h * (1 + 2 * pad)))
    return x, y, w, h

def _refine_with_grabcut(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    h, w = image.shape[:2]
    gc = np.full((h, w), cv2.GC_PR_BGD, np.uint8)
    gc[seed > 200] = cv2.GC_FGD
    rect = _auto_person_rect(image) or (w//4, h//6, w//2, int(h*0.7))
    bgd, fgd = np.zeros((1, 65), np.float64), np.zeros((1, 65), np.float64)
    if (gc == cv2.GC_FGD).any():
        cv2.grabCut(image, gc, rect, bgd, fgd, GRABCUT_ITERS, cv2.GC_INIT_WITH_MASK)
    else:
        # No confident seed pixels: GC_INIT_WITH_MASK needs some foreground
        # labels, so initialise from the rect instead.
        cv2.grabCut(image, gc, rect, bgd, fgd, GRABCUT_ITERS, cv2.GC_INIT_WITH_RECT)
    return np.where((gc == cv2.GC_FGD) | (gc == cv2.GC_PR_FGD), 255, 0).astype(np.uint8)

def _refine_with_saliency(image: np.ndarray, seed: np.ndarray) -> np.ndarray:
    sal = _compute_saliency(image)
    if sal is None:
        return seed
    high = (sal > SALIENCY_THRESH).astype(np.uint8) * 255
    ys, xs = np.where(seed > 127)
    cy = int(np.mean(ys)) if len(ys) else image.shape[0] // 2
    cx = int(np.mean(xs)) if len(xs) else image.shape[1] // 2
    ff = high.copy()
    cv2.floodFill(ff, None, (cx, cy), 255, loDiff=5, upDiff=5)
    return ff

# ============================================================================
# QUALITY / HELPER FUNCTIONS
# ============================================================================

def _validate_mask_quality(mask: np.ndarray, shape: Tuple[int, int]) -> bool:
    h, w = shape
    ratio = np.sum(mask > 127) / (h * w)
    return MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO
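
# Example: on a 100x100 frame a mask covering 150 px (ratio 0.015) just
# passes, while an all-white mask (ratio 1.0) is rejected as implausible.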

def _process_mask(mask: np.ndarray) -> np.ndarray:
    """Legacy mask processor -- kept for compatibility, largely superseded by _sam2_to_matanyone_mask"""
    if mask.dtype in (np.float32, np.float64):
        if mask.max() <= 1.0:
            mask = (mask * 255).astype(np.uint8)
    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    if mask.ndim == 3:
        mask = mask.squeeze()
        if mask.ndim == 3:            # multi-channel mask -> collapse
            mask = mask[:, :, 0]
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    return mask

def _geometric_person_mask(image: np.ndarray) -> np.ndarray:
    h, w = image.shape[:2]
    mask = np.zeros((h, w), np.uint8)
    # Centered filled ellipse roughly covering a head-and-torso silhouette
    cv2.ellipse(mask, (w//2, h//2), (w//3, int(h/2.5)), 0, 0, 360, 255, -1)
    return mask

# ============================================================================
# OPTIONAL: Iterative auto-refinement (lightweight)
# ============================================================================

def _auto_refine_mask_iteratively(image, mask, predictor, max_iterations=1):
    # Simple one-pass hook (the full version lives in refinement.py).  The
    # unused arguments keep signature parity with that implementation.
    return mask
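
# ----------------------------------------------------------------------------
# Smoke test: a minimal sketch that exercises only the classical/geometric
# fallbacks (no SAM2 predictor required).  The synthetic frame is hypothetical.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo = np.full((480, 640, 3), 255, np.uint8)                   # white background
    cv2.rectangle(demo, (220, 120), (420, 460), (40, 40, 40), -1)  # dark "person"
    out = segment_person_hq(demo, predictor=None)
    print("mask shape:", out.shape, "coverage:", round(float(np.mean(out > 127)), 3))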