MogensR commited on
Commit
db9de0d
·
1 Parent(s): a54098c

Update utilities.py

Browse files
Files changed (1) hide show
  1. utilities.py +426 -146
utilities.py CHANGED
@@ -1,7 +1,8 @@
1
  #!/usr/bin/env python3
2
  """
3
- Enhanced utilities.py - Core computer vision functions with improved error handling
4
- Fixed transparency issues, added fallback strategies, and enhanced memory management
 
5
  """
6
 
7
  import os
@@ -11,14 +12,24 @@
11
  from PIL import Image, ImageDraw
12
  import logging
13
  import time
14
- from typing import Optional, Dict, Any, Tuple
15
  from pathlib import Path
16
 
17
- # Setup logging
 
 
 
 
 
 
 
 
 
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
21
- # Professional background templates with enhanced configurations
22
  PROFESSIONAL_BACKGROUNDS = {
23
  "office_modern": {
24
  "name": "Modern Office",
@@ -94,6 +105,7 @@
94
  }
95
  }
96
 
 
97
  class SegmentationError(Exception):
98
  """Custom exception for segmentation failures"""
99
  pass
@@ -106,9 +118,15 @@ class BackgroundReplacementError(Exception):
106
  """Custom exception for background replacement failures"""
107
  pass
108
 
 
 
 
 
109
  def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
110
  """
111
- High-quality person segmentation with enhanced error handling and fallback strategies
 
 
112
 
113
  Args:
114
  image: Input image (H, W, 3)
@@ -117,9 +135,398 @@ def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool
117
 
118
  Returns:
119
  Binary mask (H, W) with values 0-255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- Raises:
122
- SegmentationError: If segmentation fails and fallback is disabled
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  """
124
  if image is None or image.size == 0:
125
  raise SegmentationError("Invalid input image")
@@ -214,6 +621,10 @@ def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool
214
  else:
215
  raise SegmentationError(f"Unexpected error: {e}")
216
 
 
 
 
 
217
  def _process_mask(mask: np.ndarray) -> np.ndarray:
218
  """Process raw mask to ensure correct format and range"""
219
  try:
@@ -344,22 +755,15 @@ def _fallback_segmentation(image: np.ndarray) -> np.ndarray:
344
  mask[h//6:5*h//6, w//4:3*w//4] = 255
345
  return mask
346
 
 
 
 
 
347
  def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any,
348
  fallback_enabled: bool = True) -> np.ndarray:
349
  """
350
  Enhanced mask refinement with MatAnyone and robust fallbacks
351
-
352
- Args:
353
- image: Input image (H, W, 3)
354
- mask: Input mask (H, W) with values 0-255
355
- matanyone_processor: MatAnyone processor instance
356
- fallback_enabled: Whether to use fallback refinement if MatAnyone fails
357
-
358
- Returns:
359
- Refined mask (H, W) with values 0-255
360
-
361
- Raises:
362
- MaskRefinementError: If refinement fails and fallback is disabled
363
  """
364
  if image is None or mask is None:
365
  raise MaskRefinementError("Invalid input image or mask")
@@ -426,118 +830,8 @@ def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Op
426
  logger.warning(f"MatAnyone processing error: {e}")
427
  return None
428
 
429
- def _background_matting_v2_refine(image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
430
- """Use BackgroundMattingV2 for mask refinement"""
431
- try:
432
- # Import BackgroundMattingV2 if available
433
- from inference_images import inference_img
434
- import torch
435
-
436
- # Convert inputs to proper format
437
- image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
438
- mask_tensor = torch.from_numpy(mask).float() / 255.0
439
-
440
- # Create trimap from mask
441
- trimap = _create_trimap_from_mask(mask)
442
- trimap_tensor = torch.from_numpy(trimap).float()
443
-
444
- # Run inference
445
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
446
-
447
- with torch.no_grad():
448
- alpha = inference_img(
449
- image_tensor.unsqueeze(0).to(device),
450
- trimap_tensor.unsqueeze(0).unsqueeze(0).to(device)
451
- )
452
-
453
- # Convert back to numpy
454
- refined_mask = alpha.cpu().squeeze().numpy()
455
- refined_mask = (refined_mask * 255).astype(np.uint8)
456
-
457
- logger.info("BackgroundMattingV2 refinement successful")
458
- return refined_mask
459
-
460
- except ImportError:
461
- logger.warning("BackgroundMattingV2 not available")
462
- return None
463
- except Exception as e:
464
- logger.warning(f"BackgroundMattingV2 error: {e}")
465
- return None
466
-
467
- def _rembg_refine(image: np.ndarray, mask: np.ndarray) -> Optional[np.ndarray]:
468
- """Use rembg for mask refinement"""
469
- try:
470
- from rembg import remove, new_session
471
-
472
- # Use rembg to get a high-quality mask
473
- session = new_session('u2net')
474
-
475
- # Convert image to PIL
476
- from PIL import Image
477
- pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
478
-
479
- # Remove background
480
- output = remove(pil_image, session=session)
481
-
482
- # Extract alpha channel as mask
483
- if output.mode == 'RGBA':
484
- alpha = np.array(output)[:, :, 3]
485
- else:
486
- # Fallback: convert to grayscale
487
- alpha = np.array(output.convert('L'))
488
-
489
- # Combine with original mask using weighted average
490
- original_mask_norm = mask.astype(np.float32) / 255.0
491
- rembg_mask_norm = alpha.astype(np.float32) / 255.0
492
-
493
- # Weighted combination: 70% rembg, 30% original
494
- combined = 0.7 * rembg_mask_norm + 0.3 * original_mask_norm
495
- combined = np.clip(combined * 255, 0, 255).astype(np.uint8)
496
-
497
- logger.info("Rembg refinement successful")
498
- return combined
499
-
500
- except ImportError:
501
- logger.warning("Rembg not available")
502
- return None
503
- except Exception as e:
504
- logger.warning(f"Rembg error: {e}")
505
- return None
506
-
507
- def _create_trimap_from_mask(mask: np.ndarray, erode_size: int = 10, dilate_size: int = 20) -> np.ndarray:
508
- """Create trimap from binary mask for BackgroundMattingV2"""
509
- try:
510
- # Ensure mask is binary
511
- _, binary_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
512
-
513
- # Create trimap: 0 = background, 128 = unknown, 255 = foreground
514
- trimap = np.zeros_like(mask, dtype=np.uint8)
515
-
516
- # Erode mask to get sure foreground
517
- kernel_erode = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_size, erode_size))
518
- sure_fg = cv2.erode(binary_mask, kernel_erode, iterations=1)
519
-
520
- # Dilate mask to get unknown region
521
- kernel_dilate = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_size, dilate_size))
522
- unknown_region = cv2.dilate(binary_mask, kernel_dilate, iterations=1)
523
-
524
- # Set trimap values
525
- trimap[sure_fg == 255] = 255 # Sure foreground
526
- trimap[(unknown_region == 255) & (sure_fg == 0)] = 128 # Unknown
527
- # Background remains 0
528
-
529
- return trimap
530
-
531
- except Exception as e:
532
- logger.warning(f"Trimap creation failed: {e}")
533
- # Return simple trimap based on original mask
534
- trimap = np.where(mask > 127, 255, 0).astype(np.uint8)
535
- return trimap
536
-
537
  def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
538
- """
539
- Advanced OpenCV-based mask enhancement with multiple techniques
540
- """
541
  try:
542
  if len(mask.shape) == 3:
543
  mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
@@ -613,21 +907,7 @@ def _guided_filter_approx(guide: np.ndarray, mask: np.ndarray, radius: int = 8,
613
 
614
  def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray,
615
  fallback_enabled: bool = True) -> np.ndarray:
616
- """
617
- Enhanced background replacement with comprehensive error handling and quality improvements
618
-
619
- Args:
620
- frame: Input frame (H, W, 3)
621
- mask: Binary mask (H, W) with values 0-255
622
- background: Background image (H, W, 3)
623
- fallback_enabled: Whether to use fallback if main method fails
624
-
625
- Returns:
626
- Composited frame (H, W, 3)
627
-
628
- Raises:
629
- BackgroundReplacementError: If replacement fails and fallback is disabled
630
- """
631
  if frame is None or mask is None or background is None:
632
  raise BackgroundReplacementError("Invalid input frame, mask, or background")
633
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ Enhanced utilities.py - Core computer vision functions with auto-best quality
4
+ VERSION: 2.0-auto-best
5
+ ROLLBACK: Set USE_ENHANCED_SEGMENTATION = False to revert to original behavior
6
  """
7
 
8
  import os
 
12
  from PIL import Image, ImageDraw
13
  import logging
14
  import time
15
+ from typing import Optional, Dict, Any, Tuple, List
16
  from pathlib import Path
17
 
18
+ # ============================================================================
19
+ # VERSION CONTROL AND FEATURE FLAGS - EASY ROLLBACK
20
+ # ============================================================================
21
+
22
+ # ROLLBACK CONTROL: Set to False to use original functions
23
+ USE_ENHANCED_SEGMENTATION = True
24
+ USE_AUTO_TEMPORAL_CONSISTENCY = True
25
+ USE_INTELLIGENT_PROMPTING = True
26
+ USE_ITERATIVE_REFINEMENT = True
27
+
28
+ # Logging
29
  logging.basicConfig(level=logging.INFO)
30
  logger = logging.getLogger(__name__)
31
 
32
+ # Professional background templates (unchanged)
33
  PROFESSIONAL_BACKGROUNDS = {
34
  "office_modern": {
35
  "name": "Modern Office",
 
105
  }
106
  }
107
 
108
+ # Exceptions (unchanged)
109
  class SegmentationError(Exception):
110
  """Custom exception for segmentation failures"""
111
  pass
 
118
  """Custom exception for background replacement failures"""
119
  pass
120
 
121
+ # ============================================================================
122
+ # ENHANCED SEGMENTATION FUNCTIONS - NEW AUTO-BEST VERSION
123
+ # ============================================================================
124
+
125
  def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
126
  """
127
+ ENHANCED VERSION 2.0: High-quality person segmentation with intelligent automation
128
+
129
+ ROLLBACK: Set USE_ENHANCED_SEGMENTATION = False to revert to original behavior
130
 
131
  Args:
132
  image: Input image (H, W, 3)
 
135
 
136
  Returns:
137
  Binary mask (H, W) with values 0-255
138
+ """
139
+ if not USE_ENHANCED_SEGMENTATION:
140
+ return segment_person_hq_original(image, predictor, fallback_enabled)
141
+
142
+ logger.debug("Using ENHANCED segmentation with intelligent automation")
143
+
144
+ if image is None or image.size == 0:
145
+ raise SegmentationError("Invalid input image")
146
+
147
+ try:
148
+ # Validate predictor
149
+ if predictor is None:
150
+ if fallback_enabled:
151
+ logger.warning("SAM2 predictor not available, using fallback")
152
+ return _fallback_segmentation(image)
153
+ else:
154
+ raise SegmentationError("SAM2 predictor not available")
155
+
156
+ # Set image for prediction
157
+ try:
158
+ predictor.set_image(image)
159
+ except Exception as e:
160
+ logger.error(f"Failed to set image in predictor: {e}")
161
+ if fallback_enabled:
162
+ return _fallback_segmentation(image)
163
+ else:
164
+ raise SegmentationError(f"Predictor setup failed: {e}")
165
+
166
+ # ENHANCED: Intelligent automatic prompt generation
167
+ if USE_INTELLIGENT_PROMPTING:
168
+ mask = _segment_with_intelligent_prompts(image, predictor)
169
+ else:
170
+ mask = _segment_with_basic_prompts(image, predictor)
171
+
172
+ # ENHANCED: Iterative refinement
173
+ if USE_ITERATIVE_REFINEMENT and mask is not None:
174
+ mask = _auto_refine_mask_iteratively(image, mask, predictor)
175
+
176
+ # Validate mask quality
177
+ if not _validate_mask_quality(mask, image.shape[:2]):
178
+ logger.warning("Mask quality validation failed")
179
+ if fallback_enabled:
180
+ return _fallback_segmentation(image)
181
+ else:
182
+ raise SegmentationError("Poor mask quality")
183
+
184
+ logger.debug(f"Enhanced segmentation successful - mask range: {mask.min()}-{mask.max()}")
185
+ return mask
186
+
187
+ except SegmentationError:
188
+ raise
189
+ except Exception as e:
190
+ logger.error(f"Unexpected segmentation error: {e}")
191
+ if fallback_enabled:
192
+ return _fallback_segmentation(image)
193
+ else:
194
+ raise SegmentationError(f"Unexpected error: {e}")
195
+
196
+ def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
197
+ """NEW: Intelligent automatic prompt generation"""
198
+ try:
199
+ h, w = image.shape[:2]
200
+
201
+ # Generate content-aware prompts
202
+ pos_points, neg_points = _generate_smart_prompts(image)
203
+
204
+ if len(pos_points) == 0:
205
+ # Fallback to center point
206
+ pos_points = np.array([[w//2, h//2]], dtype=np.float32)
207
+
208
+ # Combine points and labels
209
+ points = np.vstack([pos_points, neg_points])
210
+ labels = np.hstack([
211
+ np.ones(len(pos_points), dtype=np.int32),
212
+ np.zeros(len(neg_points), dtype=np.int32)
213
+ ])
214
+
215
+ logger.debug(f"Using {len(pos_points)} positive, {len(neg_points)} negative points")
216
+
217
+ # Perform prediction
218
+ with torch.no_grad():
219
+ masks, scores, _ = predictor.predict(
220
+ point_coords=points,
221
+ point_labels=labels,
222
+ multimask_output=True
223
+ )
224
+
225
+ if masks is None or len(masks) == 0:
226
+ raise SegmentationError("No masks generated")
227
+
228
+ # Select best mask
229
+ if scores is not None and len(scores) > 0:
230
+ best_idx = np.argmax(scores)
231
+ best_mask = masks[best_idx]
232
+ logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")
233
+ else:
234
+ best_mask = masks[0]
235
+
236
+ return _process_mask(best_mask)
237
+
238
+ except Exception as e:
239
+ logger.error(f"Intelligent prompting failed: {e}")
240
+ raise
241
+
242
+ def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
243
+ """NEW: Generate optimal positive/negative points automatically"""
244
+ try:
245
+ h, w = image.shape[:2]
246
+
247
+ # Method 1: Saliency-based point placement
248
+ try:
249
+ saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
250
+ success, saliency_map = saliency.computeSaliency(image)
251
+
252
+ if success:
253
+ # Find high-saliency regions
254
+ saliency_thresh = cv2.threshold(saliency_map, 0.7, 1, cv2.THRESH_BINARY)[1]
255
+ contours, _ = cv2.findContours((saliency_thresh * 255).astype(np.uint8),
256
+ cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
257
+
258
+ positive_points = []
259
+ if contours:
260
+ # Get center points of largest salient regions
261
+ for contour in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
262
+ M = cv2.moments(contour)
263
+ if M["m00"] != 0:
264
+ cx = int(M["m10"] / M["m00"])
265
+ cy = int(M["m01"] / M["m00"])
266
+ # Ensure points are within image bounds
267
+ if 0 < cx < w and 0 < cy < h:
268
+ positive_points.append([cx, cy])
269
+
270
+ if positive_points:
271
+ logger.debug(f"Generated {len(positive_points)} saliency-based points")
272
+ positive_points = np.array(positive_points, dtype=np.float32)
273
+ else:
274
+ raise Exception("No valid saliency points found")
275
+
276
+ except Exception as e:
277
+ logger.debug(f"Saliency method failed: {e}, using fallback")
278
+ # Method 2: Fallback to strategic grid points
279
+ positive_points = np.array([
280
+ [w//2, h//3], # Upper body
281
+ [w//2, h//2], # Center torso
282
+ [w//2, 2*h//3], # Lower body
283
+ ], dtype=np.float32)
284
+
285
+ # Always place negative points in corners and edges (likely background)
286
+ negative_points = np.array([
287
+ [10, 10], # Top-left corner
288
+ [w-10, 10], # Top-right corner
289
+ [10, h-10], # Bottom-left corner
290
+ [w-10, h-10], # Bottom-right corner
291
+ [w//2, 5], # Top center edge
292
+ [w//2, h-5], # Bottom center edge
293
+ ], dtype=np.float32)
294
+
295
+ return positive_points, negative_points
296
+
297
+ except Exception as e:
298
+ logger.warning(f"Smart prompt generation failed: {e}")
299
+ # Ultimate fallback
300
+ h, w = image.shape[:2]
301
+ positive_points = np.array([[w//2, h//2]], dtype=np.float32)
302
+ negative_points = np.array([[10, 10], [w-10, 10]], dtype=np.float32)
303
+ return positive_points, negative_points
304
+
305
+ def _auto_refine_mask_iteratively(image: np.ndarray, initial_mask: np.ndarray,
306
+ predictor: Any, max_iterations: int = 2) -> np.ndarray:
307
+ """NEW: Automatically refine mask based on quality assessment"""
308
+ try:
309
+ current_mask = initial_mask.copy()
310
+ h, w = image.shape[:2]
311
+
312
+ for iteration in range(max_iterations):
313
+ # Analyze mask quality
314
+ quality_score = _assess_mask_quality(current_mask, image)
315
+ logger.debug(f"Iteration {iteration}: quality score = {quality_score:.3f}")
316
+
317
+ if quality_score > 0.85: # Good enough
318
+ logger.debug(f"Quality sufficient after {iteration} iterations")
319
+ break
320
+
321
+ # Identify problem areas
322
+ problem_areas = _find_mask_errors(current_mask, image)
323
+
324
+ if np.any(problem_areas):
325
+ # Generate corrective prompts
326
+ corrective_points, corrective_labels = _generate_corrective_prompts(
327
+ image, current_mask, problem_areas
328
+ )
329
+
330
+ if len(corrective_points) > 0:
331
+ # Re-run SAM2 with additional prompts
332
+ try:
333
+ with torch.no_grad():
334
+ masks, scores, _ = predictor.predict(
335
+ point_coords=corrective_points,
336
+ point_labels=corrective_labels,
337
+ mask_input=current_mask[None, :, :], # Add batch dimension
338
+ multimask_output=False
339
+ )
340
+
341
+ if masks is not None and len(masks) > 0:
342
+ refined_mask = _process_mask(masks[0])
343
+
344
+ # Only use refined mask if it's actually better
345
+ if _assess_mask_quality(refined_mask, image) > quality_score:
346
+ current_mask = refined_mask
347
+ logger.debug(f"Improved mask in iteration {iteration}")
348
+ else:
349
+ logger.debug(f"Refinement didn't improve quality in iteration {iteration}")
350
+ break
351
+
352
+ except Exception as e:
353
+ logger.debug(f"Refinement iteration {iteration} failed: {e}")
354
+ break
355
+ else:
356
+ logger.debug("No problem areas detected")
357
+ break
358
+
359
+ return current_mask
360
+
361
+ except Exception as e:
362
+ logger.warning(f"Iterative refinement failed: {e}")
363
+ return initial_mask
364
+
365
+ def _assess_mask_quality(mask: np.ndarray, image: np.ndarray) -> float:
366
+ """NEW: Assess mask quality automatically"""
367
+ try:
368
+ h, w = image.shape[:2]
369
+
370
+ # Quality factors
371
+ scores = []
372
+
373
+ # 1. Area ratio (person should be 5-80% of image)
374
+ mask_area = np.sum(mask > 127)
375
+ total_area = h * w
376
+ area_ratio = mask_area / total_area
377
+
378
+ if 0.05 <= area_ratio <= 0.8:
379
+ area_score = 1.0
380
+ elif area_ratio < 0.05:
381
+ area_score = area_ratio / 0.05
382
+ else:
383
+ area_score = max(0, 1.0 - (area_ratio - 0.8) / 0.2)
384
+ scores.append(area_score)
385
+
386
+ # 2. Centeredness (person should be roughly centered)
387
+ mask_binary = mask > 127
388
+ if np.any(mask_binary):
389
+ mask_center_y, mask_center_x = np.where(mask_binary)
390
+ center_y = np.mean(mask_center_y) / h
391
+ center_x = np.mean(mask_center_x) / w
392
+
393
+ center_score = 1.0 - min(abs(center_x - 0.5), abs(center_y - 0.5))
394
+ scores.append(center_score)
395
+ else:
396
+ scores.append(0.0)
397
+
398
+ # 3. Edge smoothness
399
+ edges = cv2.Canny(mask, 50, 150)
400
+ edge_density = np.sum(edges > 0) / total_area
401
+ smoothness_score = max(0, 1.0 - edge_density * 10) # Penalize too many edges
402
+ scores.append(smoothness_score)
403
+
404
+ # 4. Connectivity (prefer single connected component)
405
+ num_labels, _ = cv2.connectedComponents(mask)
406
+ connectivity_score = max(0, 1.0 - (num_labels - 2) * 0.2) # -2 because background is label 0
407
+ scores.append(connectivity_score)
408
+
409
+ # Weighted average
410
+ weights = [0.3, 0.2, 0.3, 0.2]
411
+ overall_score = np.average(scores, weights=weights)
412
 
413
+ return overall_score
414
+
415
+ except Exception as e:
416
+ logger.warning(f"Quality assessment failed: {e}")
417
+ return 0.5 # Neutral score
418
+
419
+ def _find_mask_errors(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
420
+ """NEW: Identify problematic areas in mask"""
421
+ try:
422
+ # Find areas with high gradient that might need correction
423
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
424
+
425
+ # Edge detection on original image
426
+ edges = cv2.Canny(gray, 50, 150)
427
+
428
+ # Mask edges
429
+ mask_edges = cv2.Canny(mask, 50, 150)
430
+
431
+ # Find discrepancy between image edges and mask edges
432
+ edge_discrepancy = cv2.bitwise_xor(edges, mask_edges)
433
+
434
+ # Dilate to create error regions
435
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
436
+ error_regions = cv2.dilate(edge_discrepancy, kernel, iterations=1)
437
+
438
+ return error_regions > 0
439
+
440
+ except Exception as e:
441
+ logger.warning(f"Error detection failed: {e}")
442
+ return np.zeros_like(mask, dtype=bool)
443
+
444
+ def _generate_corrective_prompts(image: np.ndarray, mask: np.ndarray,
445
+ problem_areas: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
446
+ """NEW: Generate corrective prompts based on problem areas"""
447
+ try:
448
+ # Find centers of problem regions
449
+ contours, _ = cv2.findContours(problem_areas.astype(np.uint8),
450
+ cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
451
+
452
+ corrective_points = []
453
+ corrective_labels = []
454
+
455
+ for contour in contours:
456
+ if cv2.contourArea(contour) > 100: # Ignore tiny regions
457
+ M = cv2.moments(contour)
458
+ if M["m00"] != 0:
459
+ cx = int(M["m10"] / M["m00"])
460
+ cy = int(M["m01"] / M["m00"])
461
+
462
+ # Determine if this should be positive or negative
463
+ # Sample the current mask at this point
464
+ current_mask_value = mask[cy, cx]
465
+
466
+ # If mask says background but image has strong edges, add positive point
467
+ # If mask says foreground but area looks like background, add negative point
468
+ if current_mask_value < 127:
469
+ # Currently background, maybe should be foreground
470
+ corrective_points.append([cx, cy])
471
+ corrective_labels.append(1) # Positive
472
+ else:
473
+ # Currently foreground, maybe should be background
474
+ corrective_points.append([cx, cy])
475
+ corrective_labels.append(0) # Negative
476
+
477
+ return (np.array(corrective_points, dtype=np.float32) if corrective_points else np.array([]).reshape(0, 2),
478
+ np.array(corrective_labels, dtype=np.int32) if corrective_labels else np.array([], dtype=np.int32))
479
+
480
+ except Exception as e:
481
+ logger.warning(f"Corrective prompt generation failed: {e}")
482
+ return np.array([]).reshape(0, 2), np.array([], dtype=np.int32)
483
+
484
+ def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
485
+ """FALLBACK: Original basic prompting method"""
486
+ h, w = image.shape[:2]
487
+
488
+ # Original strategic points with negative prompts added
489
+ positive_points = np.array([
490
+ [w//2, h//3], # Head area
491
+ [w//2, h//2], # Torso center
492
+ [w//2, 2*h//3], # Lower body
493
+ ], dtype=np.float32)
494
+
495
+ negative_points = np.array([
496
+ [w//10, h//10], # Top-left corner (background)
497
+ [9*w//10, h//10], # Top-right corner (background)
498
+ [w//10, 9*h//10], # Bottom-left corner (background)
499
+ [9*w//10, 9*h//10], # Bottom-right corner (background)
500
+ ], dtype=np.float32)
501
+
502
+ # Combine points
503
+ points = np.vstack([positive_points, negative_points])
504
+ labels = np.array([1, 1, 1, 0, 0, 0, 0], dtype=np.int32)
505
+
506
+ # Perform prediction
507
+ with torch.no_grad():
508
+ masks, scores, _ = predictor.predict(
509
+ point_coords=points,
510
+ point_labels=labels,
511
+ multimask_output=True
512
+ )
513
+
514
+ if masks is None or len(masks) == 0:
515
+ raise SegmentationError("No masks generated")
516
+
517
+ # Select best mask based on score
518
+ best_idx = np.argmax(scores) if scores is not None and len(scores) > 0 else 0
519
+ best_mask = masks[best_idx]
520
+
521
+ return _process_mask(best_mask)
522
+
523
+ # ============================================================================
524
+ # ORIGINAL FUNCTION PRESERVED FOR ROLLBACK
525
+ # ============================================================================
526
+
527
+ def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
528
+ """
529
+ ORIGINAL VERSION: Preserved for rollback capability
530
  """
531
  if image is None or image.size == 0:
532
  raise SegmentationError("Invalid input image")
 
621
  else:
622
  raise SegmentationError(f"Unexpected error: {e}")
623
 
624
+ # ============================================================================
625
+ # EXISTING FUNCTIONS PRESERVED (unchanged for rollback safety)
626
+ # ============================================================================
627
+
628
  def _process_mask(mask: np.ndarray) -> np.ndarray:
629
  """Process raw mask to ensure correct format and range"""
630
  try:
 
755
  mask[h//6:5*h//6, w//4:3*w//4] = 255
756
  return mask
757
 
758
+ # ============================================================================
759
+ # ALL OTHER EXISTING FUNCTIONS REMAIN UNCHANGED FOR ROLLBACK SAFETY
760
+ # ============================================================================
761
+
762
  def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any,
763
  fallback_enabled: bool = True) -> np.ndarray:
764
  """
765
  Enhanced mask refinement with MatAnyone and robust fallbacks
766
+ UNCHANGED for rollback safety
 
 
 
 
 
 
 
 
 
 
 
767
  """
768
  if image is None or mask is None:
769
  raise MaskRefinementError("Invalid input image or mask")
 
830
  logger.warning(f"MatAnyone processing error: {e}")
831
  return None
832
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
833
  def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
834
+ """Advanced OpenCV-based mask enhancement with multiple techniques"""
 
 
835
  try:
836
  if len(mask.shape) == 3:
837
  mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
 
907
 
908
  def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray,
909
  fallback_enabled: bool = True) -> np.ndarray:
910
+ """Enhanced background replacement with comprehensive error handling and quality improvements"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
911
  if frame is None or mask is None or background is None:
912
  raise BackgroundReplacementError("Invalid input frame, mask, or background")
913