MogensR commited on
Commit
26841b5
·
1 Parent(s): 19a46c2

Update utils/cv_processing.py

Browse files
Files changed (1) hide show
  1. utils/cv_processing.py +927 -925
utils/cv_processing.py CHANGED
@@ -1,1134 +1,1136 @@
1
- def _assess_mask_quality(mask: np.ndarray, image: np.ndarray) -> float:
2
- """Assess mask quality automatically"""
3
- try:
4
- h, w = image.shape[:2]
5
- scores = []
6
-
7
- mask_area = np.sum(mask > 127)
8
- total_area = h * w
9
- area_ratio = mask_area / total_area
10
-
11
- if 0.05 <= area_ratio <= 0.8:
12
- area_score = 1.0
13
- elif area_ratio < 0.05:
14
- area_score = area_ratio / 0.05
15
- else:
16
- area_score = max(0, 1.0 - (area_ratio - 0.8) / 0.2)
17
- scores.append(area_score)
18
-
19
- mask_binary = mask > 127
20
- if np.any(mask_binary):
21
- mask_center_y, mask_center_x = np.where(mask_binary)
22
- center_y = np.mean(mask_center_y) / h
23
- center_x = np.mean(mask_center_x) / w
24
-
25
- center_score = 1.0 - min(abs(center_x - 0.5), abs(center_y - 0.5))
26
- scores.append(center_score)
27
- else:
28
- scores.append(0.0)
29
-
30
- edges = cv2.Canny(mask, 50, 150)
31
- edge_density = np.sum(edges > 0) / total_area
32
- smoothness_score = max(0, 1.0 - edge_density * 10)
33
- scores.append(smoothness_score)
34
-
35
- num_labels, _ = cv2.connectedComponents(mask)
36
- connectivity_score = max(0, 1.0 - (num_labels - 2) * 0.2)
37
- scores.append(connectivity_score)
38
-
39
- weights = [0.3, 0.2, 0.3, 0.2]
40
- overall_score = np.average(scores, weights=weights)
41
-
42
- return overall_score
43
-
44
- except Exception as e:
45
- logger.warning(f"Quality assessment failed: {e}")
46
- return 0.5
47
 
48
- def _find_mask_errors(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
49
- """Identify problematic areas in mask"""
50
- try:
51
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
52
- edges = cv2.Canny(gray, 50, 150)
53
- mask_edges = cv2.Canny(mask, 50, 150)
54
- edge_discrepancy = cv2.bitwise_xor(edges, mask_edges)
55
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
56
- error_regions = cv2.dilate(edge_discrepancy, kernel, iterations=1)
57
- return error_regions > 0
58
- except Exception as e:
59
- logger.warning(f"Error detection failed: {e}")
60
- return np.zeros_like(mask, dtype=bool)
61
 
62
- def _generate_corrective_prompts(image: np.ndarray, mask: np.ndarray,
63
- problem_areas: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
64
- """Generate corrective prompts based on problem areas"""
65
- try:
66
- contours, _ = cv2.findContours(problem_areas.astype(np.uint8),
67
- cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
68
-
69
- corrective_points = []
70
- corrective_labels = []
71
-
72
- for contour in contours:
73
- if cv2.contourArea(contour) > 100:
74
- M = cv2.moments(contour)
75
- if M["m00"] != 0:
76
- cx = int(M["m10"] / M["m00"])
77
- cy = int(M["m01"] / M["m00"])
78
-
79
- current_mask_value = mask[cy, cx]
80
-
81
- if current_mask_value < 127:
82
- corrective_points.append([cx, cy])
83
- corrective_labels.append(1)
84
- else:
85
- corrective_points.append([cx, cy])
86
- corrective_labels.append(0)
87
-
88
- return (np.array(corrective_points, dtype=np.float32) if corrective_points else np.array([]).reshape(0, 2),
89
- np.array(corrective_labels, dtype=np.int32) if corrective_labels else np.array([], dtype=np.int32))
90
-
91
- except Exception as e:
92
- logger.warning(f"Corrective prompt generation failed: {e}")
93
- return np.array([]).reshape(0, 2), np.array([], dtype=np.int32)
94
 
95
  # ============================================================================
96
- # HELPER FUNCTIONS - PROCESSING
97
  # ============================================================================
98
 
99
- def _process_mask(mask: np.ndarray) -> np.ndarray:
100
- """Process raw mask to ensure correct format and range"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  try:
102
- if len(mask.shape) > 2:
103
- mask = mask.squeeze()
104
-
105
- if len(mask.shape) > 2:
106
- mask = mask[:, :, 0] if mask.shape[2] > 0 else mask.sum(axis=2)
 
107
 
108
- if mask.dtype == bool:
109
- mask = mask.astype(np.uint8) * 255
110
- elif mask.dtype == np.float32 or mask.dtype == np.float64:
111
- if mask.max() <= 1.0:
112
- mask = (mask * 255).astype(np.uint8)
 
113
  else:
114
- mask = np.clip(mask, 0, 255).astype(np.uint8)
 
 
 
115
  else:
116
- mask = mask.astype(np.uint8)
117
 
118
- kernel = np.ones((3, 3), np.uint8)
119
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
120
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
121
 
122
- _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
 
 
 
 
 
123
 
 
124
  return mask
125
 
 
 
126
  except Exception as e:
127
- logger.error(f"Mask processing failed: {e}")
128
- h, w = mask.shape[:2] if len(mask.shape) >= 2 else (256, 256)
129
- fallback = np.zeros((h, w), dtype=np.uint8)
130
- fallback[h//4:3*h//4, w//4:3*w//4] = 255
131
- return fallback
132
-
133
- def _validate_mask_quality(mask: np.ndarray, image_shape: Tuple[int, int]) -> bool:
134
- """Validate that the mask meets quality criteria"""
 
 
 
135
  try:
136
- h, w = image_shape
137
- mask_area = np.sum(mask > 127)
138
- total_area = h * w
139
-
140
- area_ratio = mask_area / total_area
141
- if area_ratio < 0.05 or area_ratio > 0.8:
142
- logger.warning(f"Suspicious mask area ratio: {area_ratio:.3f}")
143
- return False
144
-
145
- mask_binary = mask > 127
146
- mask_center_y, mask_center_x = np.where(mask_binary)
147
-
148
- if len(mask_center_y) == 0:
149
- logger.warning("Empty mask")
150
- return False
151
 
152
- center_y = np.mean(mask_center_y)
153
- center_x = np.mean(mask_center_x)
 
 
 
 
 
 
154
 
155
- if center_y < h * 0.2 or center_y > h * 0.9:
156
- logger.warning(f"Mask center too far from expected person location: y={center_y/h:.2f}")
157
- return False
158
 
159
- return True
 
 
 
 
 
 
 
 
 
160
 
161
- except Exception as e:
162
- logger.warning(f"Mask validation error: {e}")
163
- return True
164
-
165
- def _fallback_segmentation(image: np.ndarray) -> np.ndarray:
166
- """Fallback segmentation when AI models fail"""
167
- try:
168
- logger.info("Using fallback segmentation strategy")
169
- h, w = image.shape[:2]
170
 
171
  try:
172
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
173
-
174
- edge_pixels = np.concatenate([
175
- gray[0, :], gray[-1, :], gray[:, 0], gray[:, -1]
176
- ])
177
- bg_color = np.median(edge_pixels)
178
-
179
- diff = np.abs(gray.astype(float) - bg_color)
180
- mask = (diff > 30).astype(np.uint8) * 255
181
-
182
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
183
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
184
- mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
185
-
186
- if _validate_mask_quality(mask, image.shape[:2]):
187
- logger.info("Background subtraction fallback successful")
188
- return mask
189
-
190
  except Exception as e:
191
- logger.warning(f"Background subtraction fallback failed: {e}")
 
 
 
 
192
 
193
- mask = np.zeros((h, w), dtype=np.uint8)
 
 
 
 
 
194
 
195
- center_x, center_y = w // 2, h // 2
196
- radius_x, radius_y = w // 3, h // 2.5
 
 
 
 
 
197
 
198
- y, x = np.ogrid[:h, :w]
199
- mask_ellipse = ((x - center_x) / radius_x) ** 2 + ((y - center_y) / radius_y) ** 2 <= 1
200
- mask[mask_ellipse] = 255
201
 
202
- logger.info("Using geometric fallback mask")
 
 
 
 
 
 
 
203
  return mask
204
 
 
 
205
  except Exception as e:
206
- logger.error(f"All fallback strategies failed: {e}")
207
- h, w = image.shape[:2]
208
- mask = np.zeros((h, w), dtype=np.uint8)
209
- mask[h//6:5*h//6, w//4:3*w//4] = 255
210
- return mask
211
 
212
- def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]:
213
- """Attempt MatAnyone mask refinement"""
 
 
 
 
 
 
 
 
214
  try:
215
- if hasattr(processor, 'infer'):
216
- refined_mask = processor.infer(image, mask)
217
- elif hasattr(processor, 'process'):
218
- refined_mask = processor.process(image, mask)
219
- elif callable(processor):
220
- refined_mask = processor(image, mask)
221
- else:
222
- logger.warning("Unknown MatAnyone interface")
223
- return None
224
-
225
- if refined_mask is None:
226
- return None
227
 
228
- refined_mask = _process_mask(refined_mask)
229
- logger.debug("MatAnyone refinement successful")
230
- return refined_mask
 
 
 
 
 
 
 
 
 
 
231
 
 
 
 
 
 
 
 
 
232
  except Exception as e:
233
- logger.warning(f"MatAnyone processing error: {e}")
234
- return None
 
 
 
235
 
236
- def _guided_filter_approx(guide: np.ndarray, mask: np.ndarray, radius: int = 8, eps: float = 0.2) -> np.ndarray:
237
- """Approximation of guided filter for edge-aware smoothing"""
238
  try:
239
- guide_gray = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY) if len(guide.shape) == 3 else guide
240
- guide_gray = guide_gray.astype(np.float32) / 255.0
241
- mask_float = mask.astype(np.float32) / 255.0
242
 
243
- kernel_size = 2 * radius + 1
 
244
 
245
- mean_guide = cv2.boxFilter(guide_gray, -1, (kernel_size, kernel_size))
246
- mean_mask = cv2.boxFilter(mask_float, -1, (kernel_size, kernel_size))
247
- corr_guide_mask = cv2.boxFilter(guide_gray * mask_float, -1, (kernel_size, kernel_size))
248
 
249
- cov_guide_mask = corr_guide_mask - mean_guide * mean_mask
250
- mean_guide_sq = cv2.boxFilter(guide_gray * guide_gray, -1, (kernel_size, kernel_size))
251
- var_guide = mean_guide_sq - mean_guide * mean_guide
252
 
253
- a = cov_guide_mask / (var_guide + eps)
254
- b = mean_mask - a * mean_guide
255
 
256
- mean_a = cv2.boxFilter(a, -1, (kernel_size, kernel_size))
257
- mean_b = cv2.boxFilter(b, -1, (kernel_size, kernel_size))
258
 
259
- output = mean_a * guide_gray + mean_b
260
- output = np.clip(output * 255, 0, 255).astype(np.uint8)
261
 
262
- return output
263
 
264
  except Exception as e:
265
- logger.warning(f"Guided filter approximation failed: {e}")
266
- return mask
267
 
268
  # ============================================================================
269
- # HELPER FUNCTIONS - COMPOSITING
270
  # ============================================================================
271
 
272
- def _advanced_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
273
- """Advanced compositing with edge feathering and color correction"""
 
 
 
 
274
  try:
275
- threshold = 100
276
- _, mask_binary = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
277
 
278
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
279
- mask_binary = cv2.morphologyEx(mask_binary, cv2.MORPH_CLOSE, kernel)
280
- mask_binary = cv2.morphologyEx(mask_binary, cv2.MORPH_OPEN, kernel)
281
 
282
- mask_smooth = cv2.GaussianBlur(mask_binary.astype(np.float32), (5, 5), 1.0)
283
- mask_smooth = mask_smooth / 255.0
284
 
285
- mask_smooth = np.power(mask_smooth, 0.8)
286
-
287
- mask_smooth = np.where(mask_smooth > 0.5,
288
- np.minimum(mask_smooth * 1.1, 1.0),
289
- mask_smooth * 0.9)
290
-
291
- frame_adjusted = _color_match_edges(frame, background, mask_smooth)
292
 
293
- alpha_3ch = np.stack([mask_smooth] * 3, axis=2)
 
 
 
 
 
 
 
 
 
 
294
 
295
- frame_float = frame_adjusted.astype(np.float32)
296
- background_float = background.astype(np.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
- result = frame_float * alpha_3ch + background_float * (1 - alpha_3ch)
299
- result = np.clip(result, 0, 255).astype(np.uint8)
300
 
301
- return result
302
 
303
  except Exception as e:
304
- logger.error(f"Advanced compositing error: {e}")
305
- raise
306
 
307
- def _color_match_edges(frame: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
308
- """Subtle color matching at edges to reduce halos"""
 
 
 
 
 
 
 
309
  try:
310
- edge_mask = cv2.Sobel(alpha, cv2.CV_64F, 1, 1, ksize=3)
311
- edge_mask = np.abs(edge_mask)
312
- edge_mask = (edge_mask > 0.1).astype(np.float32)
313
 
314
- edge_areas = edge_mask > 0
315
- if not np.any(edge_areas):
316
- return frame
317
 
318
- frame_adjusted = frame.copy().astype(np.float32)
319
- background_float = background.astype(np.float32)
 
320
 
321
- adjustment_strength = 0.1
322
- for c in range(3):
323
- frame_adjusted[:, :, c] = np.where(
324
- edge_areas,
325
- frame_adjusted[:, :, c] * (1 - adjustment_strength) +
326
- background_float[:, :, c] * adjustment_strength,
327
- frame_adjusted[:, :, c]
328
- )
329
 
330
- return np.clip(frame_adjusted, 0, 255).astype(np.uint8)
331
 
332
- except Exception as e:
333
- logger.warning(f"Color matching failed: {e}")
334
- return frame
335
-
336
- def _simple_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
337
- """Simple fallback compositing method"""
338
- try:
339
- logger.info("Using simple compositing fallback")
340
 
341
- background = cv2.resize(background, (frame.shape[1], frame.shape[0]))
 
342
 
343
- if len(mask.shape) == 3:
344
- mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
345
- if mask.max() <= 1.0:
346
- mask = (mask * 255).astype(np.uint8)
347
 
348
- _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
 
349
 
350
- mask_norm = mask_binary.astype(np.float32) / 255.0
351
- mask_3ch = np.stack([mask_norm] * 3, axis=2)
 
352
 
353
- result = frame * mask_3ch + background * (1 - mask_3ch)
354
- return result.astype(np.uint8)
355
 
356
  except Exception as e:
357
- logger.error(f"Simple compositing failed: {e}")
358
- return frame
359
 
360
  # ============================================================================
361
- # HELPER FUNCTIONS - BACKGROUND CREATION
362
  # ============================================================================
363
 
364
- def _create_solid_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
365
- """Create solid color background"""
366
- color_hex = bg_config["colors"][0].lstrip('#')
367
- color_rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
368
- color_bgr = color_rgb[::-1]
369
- return np.full((height, width, 3), color_bgr, dtype=np.uint8)
370
-
371
- def _create_gradient_background_enhanced(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
372
- """Create enhanced gradient background with better quality"""
373
  try:
374
- colors = bg_config["colors"]
375
- direction = bg_config.get("direction", "vertical")
376
 
377
- rgb_colors = []
378
- for color_hex in colors:
379
- color_hex = color_hex.lstrip('#')
380
- rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
381
- rgb_colors.append(rgb)
382
 
383
- if not rgb_colors:
384
- rgb_colors = [(128, 128, 128)]
 
 
 
385
 
386
- if direction == "vertical":
387
- background = _create_vertical_gradient(rgb_colors, width, height)
388
- elif direction == "horizontal":
389
- background = _create_horizontal_gradient(rgb_colors, width, height)
390
- elif direction == "diagonal":
391
- background = _create_diagonal_gradient(rgb_colors, width, height)
392
- elif direction in ["radial", "soft_radial"]:
393
- background = _create_radial_gradient(rgb_colors, width, height, direction == "soft_radial")
 
 
 
 
 
 
 
 
394
  else:
395
- background = _create_vertical_gradient(rgb_colors, width, height)
396
 
397
- return cv2.cvtColor(background, cv2.COLOR_RGB2BGR)
398
 
399
  except Exception as e:
400
- logger.error(f"Gradient creation error: {e}")
401
- return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
402
-
403
- def _create_vertical_gradient(colors: list, width: int, height: int) -> np.ndarray:
404
- """Create vertical gradient using NumPy for performance"""
405
- gradient = np.zeros((height, width, 3), dtype=np.uint8)
406
-
407
- for y in range(height):
408
- progress = y / height if height > 0 else 0
409
- color = _interpolate_color(colors, progress)
410
- gradient[y, :] = color
411
-
412
- return gradient
413
 
414
- def _create_horizontal_gradient(colors: list, width: int, height: int) -> np.ndarray:
415
- """Create horizontal gradient using NumPy for performance"""
416
- gradient = np.zeros((height, width, 3), dtype=np.uint8)
417
 
418
- for x in range(width):
419
- progress = x / width if width > 0 else 0
420
- color = _interpolate_color(colors, progress)
421
- gradient[:, x] = color
 
422
 
423
- return gradient
424
-
425
- def _create_diagonal_gradient(colors: list, width: int, height: int) -> np.ndarray:
426
- """Create diagonal gradient using vectorized operations"""
427
- y_coords, x_coords = np.mgrid[0:height, 0:width]
428
- max_distance = width + height
429
- progress = (x_coords + y_coords) / max_distance
430
- progress = np.clip(progress, 0, 1)
431
 
432
- gradient = np.zeros((height, width, 3), dtype=np.uint8)
433
- for c in range(3):
434
- gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)
435
 
436
- return gradient
437
-
438
- def _create_radial_gradient(colors: list, width: int, height: int, soft: bool = False) -> np.ndarray:
439
- """Create radial gradient using vectorized operations"""
440
- center_x, center_y = width // 2, height // 2
441
- max_distance = np.sqrt(center_x**2 + center_y**2)
442
 
443
- y_coords, x_coords = np.mgrid[0:height, 0:width]
444
- distances = np.sqrt((x_coords - center_x)**2 + (y_coords - center_y)**2)
445
- progress = distances / max_distance
446
- progress = np.clip(progress, 0, 1)
447
 
448
- if soft:
449
- progress = np.power(progress, 0.7)
450
 
451
- gradient = np.zeros((height, width, 3), dtype=np.uint8)
452
- for c in range(3):
453
- gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)
454
-
455
- return gradient
456
-
457
- def _vectorized_color_interpolation(colors: list, progress: np.ndarray, channel: int) -> np.ndarray:
458
- """Vectorized color interpolation for performance"""
459
- if len(colors) == 1:
460
- return np.full_like(progress, colors[0][channel], dtype=np.uint8)
461
-
462
- num_segments = len(colors) - 1
463
- segment_progress = progress * num_segments
464
- segment_indices = np.floor(segment_progress).astype(int)
465
- segment_indices = np.clip(segment_indices, 0, num_segments - 1)
466
- local_progress = segment_progress - segment_indices
467
-
468
- start_colors = np.array([colors[i][channel] for i in range(len(colors))])
469
- end_colors = np.array([colors[min(i + 1, len(colors) - 1)][channel] for i in range(len(colors))])
470
-
471
- start_vals = start_colors[segment_indices]
472
- end_vals = end_colors[segment_indices]
473
-
474
- result = start_vals + (end_vals - start_vals) * local_progress
475
- return np.clip(result, 0, 255).astype(np.uint8)
476
-
477
- def _interpolate_color(colors: list, progress: float) -> tuple:
478
- """Interpolate between multiple colors"""
479
- if len(colors) == 1:
480
- return colors[0]
481
- elif len(colors) == 2:
482
- r = int(colors[0][0] + (colors[1][0] - colors[0][0]) * progress)
483
- g = int(colors[0][1] + (colors[1][1] - colors[0][1]) * progress)
484
- b = int(colors[0][2] + (colors[1][2] - colors[0][2]) * progress)
485
- return (r, g, b)
486
- else:
487
- segment = progress * (len(colors) - 1)
488
- idx = int(segment)
489
- local_progress = segment - idx
490
- if idx >= len(colors) - 1:
491
- return colors[-1]
492
- c1, c2 = colors[idx], colors[idx + 1]
493
- r = int(c1[0] + (c2[0] - c1[0]) * local_progress)
494
- g = int(c1[1] + (c2[1] - c1[1]) * local_progress)
495
- b = int(c1[2] + (c2[2] - c1[2]) * local_progress)
496
- return (r, g, b)
497
 
498
- def _apply_background_adjustments(background: np.ndarray, bg_config: Dict[str, Any]) -> np.ndarray:
499
- """Apply brightness and contrast adjustments to background"""
500
  try:
501
- brightness = bg_config.get("brightness", 1.0)
502
- contrast = bg_config.get("contrast", 1.0)
503
 
504
- if brightness != 1.0 or contrast != 1.0:
505
- background = background.astype(np.float32)
506
- background = background * contrast * brightness
507
- background = np.clip(background, 0, 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
 
509
- return background
 
 
 
 
 
 
 
 
 
510
 
511
  except Exception as e:
512
- logger.warning(f"Background adjustment failed: {e}")
513
- return background"""
514
- Computer Vision Processing Module for BackgroundFX Pro
515
- Contains segmentation, mask refinement, background replacement, and helper functions
516
- """
517
-
518
- # Set OMP_NUM_THREADS at the very beginning to prevent libgomp errors
519
- import os
520
- if 'OMP_NUM_THREADS' not in os.environ:
521
- os.environ['OMP_NUM_THREADS'] = '4'
522
- os.environ['MKL_NUM_THREADS'] = '4'
523
-
524
- import logging
525
- from typing import Optional, Tuple, Dict, Any
526
- import numpy as np
527
- import cv2
528
- import torch
529
-
530
- logger = logging.getLogger(__name__)
531
-
532
- # ============================================================================
533
- # CONFIGURATION AND CONSTANTS
534
- # ============================================================================
535
-
536
- # Version control flags for CV functions
537
- USE_ENHANCED_SEGMENTATION = True
538
- USE_AUTO_TEMPORAL_CONSISTENCY = True
539
- USE_INTELLIGENT_PROMPTING = True
540
- USE_ITERATIVE_REFINEMENT = True
541
-
542
- # Professional background templates
543
- PROFESSIONAL_BACKGROUNDS = {
544
- "office_modern": {
545
- "name": "Modern Office",
546
- "type": "gradient",
547
- "colors": ["#f8f9fa", "#e9ecef", "#dee2e6"],
548
- "direction": "diagonal",
549
- "description": "Clean, contemporary office environment",
550
- "brightness": 0.95,
551
- "contrast": 1.1
552
- },
553
- "studio_blue": {
554
- "name": "Professional Blue",
555
- "type": "gradient",
556
- "colors": ["#1e3c72", "#2a5298", "#3498db"],
557
- "direction": "radial",
558
- "description": "Broadcast-quality blue studio",
559
- "brightness": 0.9,
560
- "contrast": 1.2
561
- },
562
- "studio_green": {
563
- "name": "Broadcast Green",
564
- "type": "color",
565
- "colors": ["#00b894"],
566
- "chroma_key": True,
567
- "description": "Professional green screen replacement",
568
- "brightness": 1.0,
569
- "contrast": 1.0
570
- },
571
- "minimalist": {
572
- "name": "Minimalist White",
573
- "type": "gradient",
574
- "colors": ["#ffffff", "#f1f2f6", "#ddd"],
575
- "direction": "soft_radial",
576
- "description": "Clean, minimal background",
577
- "brightness": 0.98,
578
- "contrast": 0.9
579
- },
580
- "warm_gradient": {
581
- "name": "Warm Sunset",
582
- "type": "gradient",
583
- "colors": ["#ff7675", "#fd79a8", "#fdcb6e"],
584
- "direction": "diagonal",
585
- "description": "Warm, inviting atmosphere",
586
- "brightness": 0.85,
587
- "contrast": 1.15
588
- },
589
- "tech_dark": {
590
- "name": "Tech Dark",
591
- "type": "gradient",
592
- "colors": ["#0c0c0c", "#2d3748", "#4a5568"],
593
- "direction": "vertical",
594
- "description": "Modern tech/gaming setup",
595
- "brightness": 0.7,
596
- "contrast": 1.3
597
- }
598
- }
599
 
600
  # ============================================================================
601
- # CUSTOM EXCEPTIONS
602
  # ============================================================================
603
 
604
- class SegmentationError(Exception):
605
- """Custom exception for segmentation failures"""
606
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607
 
608
- class MaskRefinementError(Exception):
609
- """Custom exception for mask refinement failures"""
610
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
 
612
- class BackgroundReplacementError(Exception):
613
- """Custom exception for background replacement failures"""
614
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
615
 
616
  # ============================================================================
617
- # MAIN SEGMENTATION FUNCTIONS
618
  # ============================================================================
619
 
620
- def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
621
- """High-quality person segmentation with intelligent automation"""
622
- if not USE_ENHANCED_SEGMENTATION:
623
- return segment_person_hq_original(image, predictor, fallback_enabled)
624
-
625
- logger.debug("Using ENHANCED segmentation with intelligent automation")
626
-
627
- if image is None or image.size == 0:
628
- raise SegmentationError("Invalid input image")
629
-
630
  try:
631
- if predictor is None:
632
- if fallback_enabled:
633
- logger.warning("SAM2 predictor not available, using fallback")
634
- return _fallback_segmentation(image)
635
- else:
636
- raise SegmentationError("SAM2 predictor not available")
637
 
638
- try:
639
- predictor.set_image(image)
640
- except Exception as e:
641
- logger.error(f"Failed to set image in predictor: {e}")
642
- if fallback_enabled:
643
- return _fallback_segmentation(image)
644
- else:
645
- raise SegmentationError(f"Predictor setup failed: {e}")
646
 
647
- if USE_INTELLIGENT_PROMPTING:
648
- mask = _segment_with_intelligent_prompts(image, predictor)
 
 
 
 
 
649
  else:
650
- mask = _segment_with_basic_prompts(image, predictor)
651
 
652
- if USE_ITERATIVE_REFINEMENT and mask is not None:
653
- mask = _auto_refine_mask_iteratively(image, mask, predictor)
 
654
 
655
- if not _validate_mask_quality(mask, image.shape[:2]):
656
- logger.warning("Mask quality validation failed")
657
- if fallback_enabled:
658
- return _fallback_segmentation(image)
659
- else:
660
- raise SegmentationError("Poor mask quality")
661
 
662
- logger.debug(f"Enhanced segmentation successful - mask range: {mask.min()}-{mask.max()}")
663
  return mask
664
 
665
- except SegmentationError:
666
- raise
667
  except Exception as e:
668
- logger.error(f"Unexpected segmentation error: {e}")
669
- if fallback_enabled:
670
- return _fallback_segmentation(image)
671
- else:
672
- raise SegmentationError(f"Unexpected error: {e}")
673
 
674
- def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
675
- """Original version of person segmentation for rollback"""
676
- if image is None or image.size == 0:
677
- raise SegmentationError("Invalid input image")
678
-
679
  try:
680
- if predictor is None:
681
- if fallback_enabled:
682
- logger.warning("SAM2 predictor not available, using fallback")
683
- return _fallback_segmentation(image)
684
- else:
685
- raise SegmentationError("SAM2 predictor not available")
686
 
687
- try:
688
- predictor.set_image(image)
689
- except Exception as e:
690
- logger.error(f"Failed to set image in predictor: {e}")
691
- if fallback_enabled:
692
- return _fallback_segmentation(image)
693
- else:
694
- raise SegmentationError(f"Predictor setup failed: {e}")
695
 
696
- h, w = image.shape[:2]
 
697
 
698
- points = np.array([
699
- [w//2, h//4],
700
- [w//2, h//2],
701
- [w//2, 3*h//4],
702
- [w//3, h//2],
703
- [2*w//3, h//2],
704
- [w//2, h//6],
705
- [w//4, 2*h//3],
706
- [3*w//4, 2*h//3],
707
- ], dtype=np.float32)
708
 
709
- labels = np.ones(len(points), dtype=np.int32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
710
 
711
  try:
712
- with torch.no_grad():
713
- masks, scores, _ = predictor.predict(
714
- point_coords=points,
715
- point_labels=labels,
716
- multimask_output=True
717
- )
 
 
 
 
 
 
 
 
 
 
 
 
718
  except Exception as e:
719
- logger.error(f"SAM2 prediction failed: {e}")
720
- if fallback_enabled:
721
- return _fallback_segmentation(image)
722
- else:
723
- raise SegmentationError(f"Prediction failed: {e}")
724
-
725
- if masks is None or len(masks) == 0:
726
- logger.warning("SAM2 returned no masks")
727
- if fallback_enabled:
728
- return _fallback_segmentation(image)
729
- else:
730
- raise SegmentationError("No masks generated")
731
 
732
- if scores is None or len(scores) == 0:
733
- logger.warning("SAM2 returned no scores")
734
- best_mask = masks[0]
735
- else:
736
- best_idx = np.argmax(scores)
737
- best_mask = masks[best_idx]
738
- logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")
739
 
740
- mask = _process_mask(best_mask)
 
741
 
742
- if not _validate_mask_quality(mask, image.shape[:2]):
743
- logger.warning("Mask quality validation failed")
744
- if fallback_enabled:
745
- return _fallback_segmentation(image)
746
- else:
747
- raise SegmentationError("Poor mask quality")
748
 
749
- logger.debug(f"Segmentation successful - mask range: {mask.min()}-{mask.max()}")
750
  return mask
751
 
752
- except SegmentationError:
753
- raise
754
  except Exception as e:
755
- logger.error(f"Unexpected segmentation error: {e}")
756
- if fallback_enabled:
757
- return _fallback_segmentation(image)
758
- else:
759
- raise SegmentationError(f"Unexpected error: {e}")
760
-
761
- # ============================================================================
762
- # MASK REFINEMENT FUNCTIONS
763
- # ============================================================================
764
 
765
- def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any,
766
- fallback_enabled: bool = True) -> np.ndarray:
767
- """Enhanced mask refinement with MatAnyone and robust fallbacks"""
768
- if image is None or mask is None:
769
- raise MaskRefinementError("Invalid input image or mask")
770
-
771
  try:
772
- mask = _process_mask(mask)
 
 
 
 
 
 
 
 
773
 
774
- if matanyone_processor is not None:
775
- try:
776
- logger.debug("Attempting MatAnyone refinement")
777
- refined_mask = _matanyone_refine(image, mask, matanyone_processor)
778
-
779
- if refined_mask is not None and _validate_mask_quality(refined_mask, image.shape[:2]):
780
- logger.debug("MatAnyone refinement successful")
781
- return refined_mask
782
- else:
783
- logger.warning("MatAnyone produced poor quality mask")
784
-
785
- except Exception as e:
786
- logger.warning(f"MatAnyone refinement failed: {e}")
787
 
788
- if fallback_enabled:
789
- logger.debug("Using enhanced OpenCV refinement")
790
- return enhance_mask_opencv_advanced(image, mask)
791
- else:
792
- raise MaskRefinementError("MatAnyone failed and fallback disabled")
793
-
794
- except MaskRefinementError:
795
- raise
796
  except Exception as e:
797
- logger.error(f"Unexpected mask refinement error: {e}")
798
- if fallback_enabled:
799
- return enhance_mask_opencv_advanced(image, mask)
800
- else:
801
- raise MaskRefinementError(f"Unexpected error: {e}")
802
 
803
- def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
804
- """Advanced OpenCV-based mask enhancement with multiple techniques"""
805
  try:
806
- if len(mask.shape) == 3:
807
- mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
 
808
 
809
- if mask.max() <= 1.0:
810
- mask = (mask * 255).astype(np.uint8)
811
 
812
- refined_mask = cv2.bilateralFilter(mask, 9, 75, 75)
813
- refined_mask = _guided_filter_approx(image, refined_mask, radius=8, eps=0.2)
 
814
 
815
- kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
816
- refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_CLOSE, kernel_close)
 
817
 
818
- kernel_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
819
- refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_OPEN, kernel_open)
820
 
821
- refined_mask = cv2.GaussianBlur(refined_mask, (3, 3), 0.8)
 
822
 
823
- _, refined_mask = cv2.threshold(refined_mask, 127, 255, cv2.THRESH_BINARY)
 
824
 
825
- return refined_mask
826
 
827
  except Exception as e:
828
- logger.warning(f"Enhanced OpenCV refinement failed: {e}")
829
- return cv2.GaussianBlur(mask, (5, 5), 1.0)
830
 
831
  # ============================================================================
832
- # BACKGROUND REPLACEMENT FUNCTIONS
833
  # ============================================================================
834
 
835
- def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray,
836
- fallback_enabled: bool = True) -> np.ndarray:
837
- """Enhanced background replacement with comprehensive error handling"""
838
- if frame is None or mask is None or background is None:
839
- raise BackgroundReplacementError("Invalid input frame, mask, or background")
840
-
841
  try:
842
- background = cv2.resize(background, (frame.shape[1], frame.shape[0]),
843
- interpolation=cv2.INTER_LANCZOS4)
844
 
845
- if len(mask.shape) == 3:
846
- mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
 
847
 
848
- if mask.dtype != np.uint8:
849
- mask = mask.astype(np.uint8)
850
 
851
- if mask.max() <= 1.0:
852
- logger.debug("Converting normalized mask to 0-255 range")
853
- mask = (mask * 255).astype(np.uint8)
854
 
855
- try:
856
- result = _advanced_compositing(frame, mask, background)
857
- logger.debug("Advanced compositing successful")
858
- return result
859
-
860
- except Exception as e:
861
- logger.warning(f"Advanced compositing failed: {e}")
862
- if fallback_enabled:
863
- return _simple_compositing(frame, mask, background)
864
- else:
865
- raise BackgroundReplacementError(f"Advanced compositing failed: {e}")
866
 
867
- except BackgroundReplacementError:
868
- raise
869
- except Exception as e:
870
- logger.error(f"Unexpected background replacement error: {e}")
871
- if fallback_enabled:
872
- return _simple_compositing(frame, mask, background)
873
- else:
874
- raise BackgroundReplacementError(f"Unexpected error: {e}")
875
-
876
- def create_professional_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
877
- """Enhanced professional background creation with quality improvements"""
878
- try:
879
- if bg_config["type"] == "color":
880
- background = _create_solid_background(bg_config, width, height)
881
- elif bg_config["type"] == "gradient":
882
- background = _create_gradient_background_enhanced(bg_config, width, height)
883
- else:
884
- background = np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
885
 
886
- background = _apply_background_adjustments(background, bg_config)
887
 
888
- return background
 
 
 
 
 
 
889
 
890
  except Exception as e:
891
- logger.error(f"Background creation error: {e}")
892
- return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
893
-
894
- # ============================================================================
895
- # VALIDATION FUNCTION
896
- # ============================================================================
897
 
898
- def validate_video_file(video_path: str) -> Tuple[bool, str]:
899
- """Enhanced video file validation with detailed checks"""
900
- if not video_path or not os.path.exists(video_path):
901
- return False, "Video file not found"
902
-
903
  try:
904
- file_size = os.path.getsize(video_path)
905
- if file_size == 0:
906
- return False, "Video file is empty"
907
 
908
- if file_size > 2 * 1024 * 1024 * 1024:
909
- return False, "Video file too large (>2GB)"
 
910
 
911
- cap = cv2.VideoCapture(video_path)
912
- if not cap.isOpened():
913
- return False, "Cannot open video file"
914
 
915
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
916
- fps = cap.get(cv2.CAP_PROP_FPS)
917
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
918
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
 
 
 
919
 
920
- cap.release()
921
 
922
- if frame_count == 0:
923
- return False, "Video appears to be empty (0 frames)"
 
 
 
 
 
 
924
 
925
- if fps <= 0 or fps > 120:
926
- return False, f"Invalid frame rate: {fps}"
927
 
928
- if width <= 0 or height <= 0:
929
- return False, f"Invalid resolution: {width}x{height}"
 
 
930
 
931
- if width > 4096 or height > 4096:
932
- return False, f"Resolution too high: {width}x{height} (max 4096x4096)"
933
 
934
- duration = frame_count / fps
935
- if duration > 300:
936
- return False, f"Video too long: {duration:.1f}s (max 300s)"
937
 
938
- return True, f"Valid video: {width}x{height}, {fps:.1f}fps, {duration:.1f}s"
 
939
 
940
  except Exception as e:
941
- return False, f"Error validating video: {str(e)}"
 
942
 
943
  # ============================================================================
944
- # HELPER FUNCTIONS - SEGMENTATION
945
  # ============================================================================
946
 
947
- def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
948
- """Intelligent automatic prompt generation for segmentation"""
 
 
 
 
 
 
 
949
  try:
950
- h, w = image.shape[:2]
951
- pos_points, neg_points = _generate_smart_prompts(image)
952
-
953
- if len(pos_points) == 0:
954
- pos_points = np.array([[w//2, h//2]], dtype=np.float32)
955
-
956
- points = np.vstack([pos_points, neg_points])
957
- labels = np.hstack([
958
- np.ones(len(pos_points), dtype=np.int32),
959
- np.zeros(len(neg_points), dtype=np.int32)
960
- ])
961
-
962
- logger.debug(f"Using {len(pos_points)} positive, {len(neg_points)} negative points")
963
 
964
- with torch.no_grad():
965
- masks, scores, _ = predictor.predict(
966
- point_coords=points,
967
- point_labels=labels,
968
- multimask_output=True
969
- )
970
 
971
- if masks is None or len(masks) == 0:
972
- raise SegmentationError("No masks generated")
973
 
974
- if scores is not None and len(scores) > 0:
975
- best_idx = np.argmax(scores)
976
- best_mask = masks[best_idx]
977
- logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")
 
 
 
 
978
  else:
979
- best_mask = masks[0]
980
 
981
- return _process_mask(best_mask)
982
 
983
  except Exception as e:
984
- logger.error(f"Intelligent prompting failed: {e}")
985
- raise
986
 
987
- def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
988
- """Basic prompting method for segmentation"""
989
- h, w = image.shape[:2]
990
 
991
- positive_points = np.array([
992
- [w//2, h//3],
993
- [w//2, h//2],
994
- [w//2, 2*h//3],
995
- ], dtype=np.float32)
996
 
997
- negative_points = np.array([
998
- [w//10, h//10],
999
- [9*w//10, h//10],
1000
- [w//10, 9*h//10],
1001
- [9*w//10, 9*h//10],
1002
- ], dtype=np.float32)
1003
 
1004
- points = np.vstack([positive_points, negative_points])
1005
- labels = np.array([1, 1, 1, 0, 0, 0, 0], dtype=np.int32)
 
 
1006
 
1007
- with torch.no_grad():
1008
- masks, scores, _ = predictor.predict(
1009
- point_coords=points,
1010
- point_labels=labels,
1011
- multimask_output=True
1012
- )
 
 
1013
 
1014
- if masks is None or len(masks) == 0:
1015
- raise SegmentationError("No masks generated")
 
1016
 
1017
- best_idx = np.argmax(scores) if scores is not None and len(scores) > 0 else 0
1018
- best_mask = masks[best_idx]
 
 
 
 
1019
 
1020
- return _process_mask(best_mask)
 
 
 
 
 
 
 
 
 
 
 
 
1021
 
1022
- def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
1023
- """Generate optimal positive/negative points automatically"""
1024
- try:
1025
- h, w = image.shape[:2]
1026
-
1027
- try:
1028
- saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
1029
- success, saliency_map = saliency.computeSaliency(image)
1030
-
1031
- if success:
1032
- saliency_thresh = cv2.threshold(saliency_map, 0.7, 1, cv2.THRESH_BINARY)[1]
1033
- contours, _ = cv2.findContours((saliency_thresh * 255).astype(np.uint8),
1034
- cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
1035
-
1036
- positive_points = []
1037
- if contours:
1038
- for contour in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
1039
- M = cv2.moments(contour)
1040
- if M["m00"] != 0:
1041
- cx = int(M["m10"] / M["m00"])
1042
- cy = int(M["m01"] / M["m00"])
1043
- if 0 < cx < w and 0 < cy < h:
1044
- positive_points.append([cx, cy])
1045
-
1046
- if positive_points:
1047
- logger.debug(f"Generated {len(positive_points)} saliency-based points")
1048
- positive_points = np.array(positive_points, dtype=np.float32)
1049
- else:
1050
- raise Exception("No valid saliency points found")
1051
-
1052
- except Exception as e:
1053
- logger.debug(f"Saliency method failed: {e}, using fallback")
1054
- positive_points = np.array([
1055
- [w//2, h//3],
1056
- [w//2, h//2],
1057
- [w//2, 2*h//3],
1058
- ], dtype=np.float32)
1059
-
1060
- negative_points = np.array([
1061
- [10, 10],
1062
- [w-10, 10],
1063
- [10, h-10],
1064
- [w-10, h-10],
1065
- [w//2, 5],
1066
- [w//2, h-5],
1067
- ], dtype=np.float32)
1068
-
1069
- return positive_points, negative_points
1070
-
1071
- except Exception as e:
1072
- logger.warning(f"Smart prompt generation failed: {e}")
1073
- h, w = image.shape[:2]
1074
- positive_points = np.array([[w//2, h//2]], dtype=np.float32)
1075
- negative_points = np.array([[10, 10], [w-10, 10]], dtype=np.float32)
1076
- return positive_points, negative_points
1077
 
1078
- # ============================================================================
1079
- # HELPER FUNCTIONS - REFINEMENT
1080
- # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1081
 
1082
- def _auto_refine_mask_iteratively(image: np.ndarray, initial_mask: np.ndarray,
1083
- predictor: Any, max_iterations: int = 2) -> np.ndarray:
1084
- """Automatically refine mask based on quality assessment"""
1085
  try:
1086
- current_mask = initial_mask.copy()
 
1087
 
1088
- for iteration in range(max_iterations):
1089
- quality_score = _assess_mask_quality(current_mask, image)
1090
- logger.debug(f"Iteration {iteration}: quality score = {quality_score:.3f}")
1091
-
1092
- if quality_score > 0.85:
1093
- logger.debug(f"Quality sufficient after {iteration} iterations")
1094
- break
1095
-
1096
- problem_areas = _find_mask_errors(current_mask, image)
1097
-
1098
- if np.any(problem_areas):
1099
- corrective_points, corrective_labels = _generate_corrective_prompts(
1100
- image, current_mask, problem_areas
1101
- )
1102
-
1103
- if len(corrective_points) > 0:
1104
- try:
1105
- with torch.no_grad():
1106
- masks, scores, _ = predictor.predict(
1107
- point_coords=corrective_points,
1108
- point_labels=corrective_labels,
1109
- mask_input=current_mask[None, :, :],
1110
- multimask_output=False
1111
- )
1112
-
1113
- if masks is not None and len(masks) > 0:
1114
- refined_mask = _process_mask(masks[0])
1115
-
1116
- if _assess_mask_quality(refined_mask, image) > quality_score:
1117
- current_mask = refined_mask
1118
- logger.debug(f"Improved mask in iteration {iteration}")
1119
- else:
1120
- logger.debug(f"Refinement didn't improve quality in iteration {iteration}")
1121
- break
1122
-
1123
- except Exception as e:
1124
- logger.debug(f"Refinement iteration {iteration} failed: {e}")
1125
- break
1126
- else:
1127
- logger.debug("No problem areas detected")
1128
- break
1129
 
1130
- return current_mask
1131
 
1132
  except Exception as e:
1133
- logger.warning(f"Iterative refinement failed: {e}")
1134
- return initial_mask
 
1
+ """
2
+ Computer Vision Processing Module for BackgroundFX Pro
3
+ Contains segmentation, mask refinement, background replacement, and helper functions
4
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ # Set OMP_NUM_THREADS at the very beginning to prevent libgomp errors
7
+ import os
8
+ if 'OMP_NUM_THREADS' not in os.environ:
9
+ os.environ['OMP_NUM_THREADS'] = '4'
10
+ os.environ['MKL_NUM_THREADS'] = '4'
 
 
 
 
 
 
 
 
11
 
12
+ import logging
13
+ from typing import Optional, Tuple, Dict, Any
14
+ import numpy as np
15
+ import cv2
16
+ import torch
17
+
18
+ logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # ============================================================================
21
+ # CONFIGURATION AND CONSTANTS
22
  # ============================================================================
23
 
24
+ # Version control flags for CV functions
25
+ USE_ENHANCED_SEGMENTATION = True
26
+ USE_AUTO_TEMPORAL_CONSISTENCY = True
27
+ USE_INTELLIGENT_PROMPTING = True
28
+ USE_ITERATIVE_REFINEMENT = True
29
+
30
+ # Professional background templates
31
+ PROFESSIONAL_BACKGROUNDS = {
32
+ "office_modern": {
33
+ "name": "Modern Office",
34
+ "type": "gradient",
35
+ "colors": ["#f8f9fa", "#e9ecef", "#dee2e6"],
36
+ "direction": "diagonal",
37
+ "description": "Clean, contemporary office environment",
38
+ "brightness": 0.95,
39
+ "contrast": 1.1
40
+ },
41
+ "studio_blue": {
42
+ "name": "Professional Blue",
43
+ "type": "gradient",
44
+ "colors": ["#1e3c72", "#2a5298", "#3498db"],
45
+ "direction": "radial",
46
+ "description": "Broadcast-quality blue studio",
47
+ "brightness": 0.9,
48
+ "contrast": 1.2
49
+ },
50
+ "studio_green": {
51
+ "name": "Broadcast Green",
52
+ "type": "color",
53
+ "colors": ["#00b894"],
54
+ "chroma_key": True,
55
+ "description": "Professional green screen replacement",
56
+ "brightness": 1.0,
57
+ "contrast": 1.0
58
+ },
59
+ "minimalist": {
60
+ "name": "Minimalist White",
61
+ "type": "gradient",
62
+ "colors": ["#ffffff", "#f1f2f6", "#ddd"],
63
+ "direction": "soft_radial",
64
+ "description": "Clean, minimal background",
65
+ "brightness": 0.98,
66
+ "contrast": 0.9
67
+ },
68
+ "warm_gradient": {
69
+ "name": "Warm Sunset",
70
+ "type": "gradient",
71
+ "colors": ["#ff7675", "#fd79a8", "#fdcb6e"],
72
+ "direction": "diagonal",
73
+ "description": "Warm, inviting atmosphere",
74
+ "brightness": 0.85,
75
+ "contrast": 1.15
76
+ },
77
+ "tech_dark": {
78
+ "name": "Tech Dark",
79
+ "type": "gradient",
80
+ "colors": ["#0c0c0c", "#2d3748", "#4a5568"],
81
+ "direction": "vertical",
82
+ "description": "Modern tech/gaming setup",
83
+ "brightness": 0.7,
84
+ "contrast": 1.3
85
+ }
86
+ }
87
+
88
+ # ============================================================================
89
+ # CUSTOM EXCEPTIONS
90
+ # ============================================================================
91
+
92
+ class SegmentationError(Exception):
93
+ """Custom exception for segmentation failures"""
94
+ pass
95
+
96
+ class MaskRefinementError(Exception):
97
+ """Custom exception for mask refinement failures"""
98
+ pass
99
+
100
+ class BackgroundReplacementError(Exception):
101
+ """Custom exception for background replacement failures"""
102
+ pass
103
+
104
+ # ============================================================================
105
+ # MAIN SEGMENTATION FUNCTIONS
106
+ # ============================================================================
107
+
108
+ def segment_person_hq(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
109
+ """High-quality person segmentation with intelligent automation"""
110
+ if not USE_ENHANCED_SEGMENTATION:
111
+ return segment_person_hq_original(image, predictor, fallback_enabled)
112
+
113
+ logger.debug("Using ENHANCED segmentation with intelligent automation")
114
+
115
+ if image is None or image.size == 0:
116
+ raise SegmentationError("Invalid input image")
117
+
118
  try:
119
+ if predictor is None:
120
+ if fallback_enabled:
121
+ logger.warning("SAM2 predictor not available, using fallback")
122
+ return _fallback_segmentation(image)
123
+ else:
124
+ raise SegmentationError("SAM2 predictor not available")
125
 
126
+ try:
127
+ predictor.set_image(image)
128
+ except Exception as e:
129
+ logger.error(f"Failed to set image in predictor: {e}")
130
+ if fallback_enabled:
131
+ return _fallback_segmentation(image)
132
  else:
133
+ raise SegmentationError(f"Predictor setup failed: {e}")
134
+
135
+ if USE_INTELLIGENT_PROMPTING:
136
+ mask = _segment_with_intelligent_prompts(image, predictor)
137
  else:
138
+ mask = _segment_with_basic_prompts(image, predictor)
139
 
140
+ if USE_ITERATIVE_REFINEMENT and mask is not None:
141
+ mask = _auto_refine_mask_iteratively(image, mask, predictor)
 
142
 
143
+ if not _validate_mask_quality(mask, image.shape[:2]):
144
+ logger.warning("Mask quality validation failed")
145
+ if fallback_enabled:
146
+ return _fallback_segmentation(image)
147
+ else:
148
+ raise SegmentationError("Poor mask quality")
149
 
150
+ logger.debug(f"Enhanced segmentation successful - mask range: {mask.min()}-{mask.max()}")
151
  return mask
152
 
153
+ except SegmentationError:
154
+ raise
155
  except Exception as e:
156
+ logger.error(f"Unexpected segmentation error: {e}")
157
+ if fallback_enabled:
158
+ return _fallback_segmentation(image)
159
+ else:
160
+ raise SegmentationError(f"Unexpected error: {e}")
161
+
162
+ def segment_person_hq_original(image: np.ndarray, predictor: Any, fallback_enabled: bool = True) -> np.ndarray:
163
+ """Original version of person segmentation for rollback"""
164
+ if image is None or image.size == 0:
165
+ raise SegmentationError("Invalid input image")
166
+
167
  try:
168
+ if predictor is None:
169
+ if fallback_enabled:
170
+ logger.warning("SAM2 predictor not available, using fallback")
171
+ return _fallback_segmentation(image)
172
+ else:
173
+ raise SegmentationError("SAM2 predictor not available")
 
 
 
 
 
 
 
 
 
174
 
175
+ try:
176
+ predictor.set_image(image)
177
+ except Exception as e:
178
+ logger.error(f"Failed to set image in predictor: {e}")
179
+ if fallback_enabled:
180
+ return _fallback_segmentation(image)
181
+ else:
182
+ raise SegmentationError(f"Predictor setup failed: {e}")
183
 
184
+ h, w = image.shape[:2]
 
 
185
 
186
+ points = np.array([
187
+ [w//2, h//4],
188
+ [w//2, h//2],
189
+ [w//2, 3*h//4],
190
+ [w//3, h//2],
191
+ [2*w//3, h//2],
192
+ [w//2, h//6],
193
+ [w//4, 2*h//3],
194
+ [3*w//4, 2*h//3],
195
+ ], dtype=np.float32)
196
 
197
+ labels = np.ones(len(points), dtype=np.int32)
 
 
 
 
 
 
 
 
198
 
199
  try:
200
+ with torch.no_grad():
201
+ masks, scores, _ = predictor.predict(
202
+ point_coords=points,
203
+ point_labels=labels,
204
+ multimask_output=True
205
+ )
 
 
 
 
 
 
 
 
 
 
 
 
206
  except Exception as e:
207
+ logger.error(f"SAM2 prediction failed: {e}")
208
+ if fallback_enabled:
209
+ return _fallback_segmentation(image)
210
+ else:
211
+ raise SegmentationError(f"Prediction failed: {e}")
212
 
213
+ if masks is None or len(masks) == 0:
214
+ logger.warning("SAM2 returned no masks")
215
+ if fallback_enabled:
216
+ return _fallback_segmentation(image)
217
+ else:
218
+ raise SegmentationError("No masks generated")
219
 
220
+ if scores is None or len(scores) == 0:
221
+ logger.warning("SAM2 returned no scores")
222
+ best_mask = masks[0]
223
+ else:
224
+ best_idx = np.argmax(scores)
225
+ best_mask = masks[best_idx]
226
+ logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")
227
 
228
+ mask = _process_mask(best_mask)
 
 
229
 
230
+ if not _validate_mask_quality(mask, image.shape[:2]):
231
+ logger.warning("Mask quality validation failed")
232
+ if fallback_enabled:
233
+ return _fallback_segmentation(image)
234
+ else:
235
+ raise SegmentationError("Poor mask quality")
236
+
237
+ logger.debug(f"Segmentation successful - mask range: {mask.min()}-{mask.max()}")
238
  return mask
239
 
240
+ except SegmentationError:
241
+ raise
242
  except Exception as e:
243
+ logger.error(f"Unexpected segmentation error: {e}")
244
+ if fallback_enabled:
245
+ return _fallback_segmentation(image)
246
+ else:
247
+ raise SegmentationError(f"Unexpected error: {e}")
248
 
249
+ # ============================================================================
250
+ # MASK REFINEMENT FUNCTIONS
251
+ # ============================================================================
252
+
253
+ def refine_mask_hq(image: np.ndarray, mask: np.ndarray, matanyone_processor: Any,
254
+ fallback_enabled: bool = True) -> np.ndarray:
255
+ """Enhanced mask refinement with MatAnyone and robust fallbacks"""
256
+ if image is None or mask is None:
257
+ raise MaskRefinementError("Invalid input image or mask")
258
+
259
  try:
260
+ mask = _process_mask(mask)
 
 
 
 
 
 
 
 
 
 
 
261
 
262
+ if matanyone_processor is not None:
263
+ try:
264
+ logger.debug("Attempting MatAnyone refinement")
265
+ refined_mask = _matanyone_refine(image, mask, matanyone_processor)
266
+
267
+ if refined_mask is not None and _validate_mask_quality(refined_mask, image.shape[:2]):
268
+ logger.debug("MatAnyone refinement successful")
269
+ return refined_mask
270
+ else:
271
+ logger.warning("MatAnyone produced poor quality mask")
272
+
273
+ except Exception as e:
274
+ logger.warning(f"MatAnyone refinement failed: {e}")
275
 
276
+ if fallback_enabled:
277
+ logger.debug("Using enhanced OpenCV refinement")
278
+ return enhance_mask_opencv_advanced(image, mask)
279
+ else:
280
+ raise MaskRefinementError("MatAnyone failed and fallback disabled")
281
+
282
+ except MaskRefinementError:
283
+ raise
284
  except Exception as e:
285
+ logger.error(f"Unexpected mask refinement error: {e}")
286
+ if fallback_enabled:
287
+ return enhance_mask_opencv_advanced(image, mask)
288
+ else:
289
+ raise MaskRefinementError(f"Unexpected error: {e}")
290
 
291
+ def enhance_mask_opencv_advanced(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
292
+ """Advanced OpenCV-based mask enhancement with multiple techniques"""
293
  try:
294
+ if len(mask.shape) == 3:
295
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
 
296
 
297
+ if mask.max() <= 1.0:
298
+ mask = (mask * 255).astype(np.uint8)
299
 
300
+ refined_mask = cv2.bilateralFilter(mask, 9, 75, 75)
301
+ refined_mask = _guided_filter_approx(image, refined_mask, radius=8, eps=0.2)
 
302
 
303
+ kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
304
+ refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_CLOSE, kernel_close)
 
305
 
306
+ kernel_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
307
+ refined_mask = cv2.morphologyEx(refined_mask, cv2.MORPH_OPEN, kernel_open)
308
 
309
+ refined_mask = cv2.GaussianBlur(refined_mask, (3, 3), 0.8)
 
310
 
311
+ _, refined_mask = cv2.threshold(refined_mask, 127, 255, cv2.THRESH_BINARY)
 
312
 
313
+ return refined_mask
314
 
315
  except Exception as e:
316
+ logger.warning(f"Enhanced OpenCV refinement failed: {e}")
317
+ return cv2.GaussianBlur(mask, (5, 5), 1.0)
318
 
319
  # ============================================================================
320
+ # BACKGROUND REPLACEMENT FUNCTIONS
321
  # ============================================================================
322
 
323
+ def replace_background_hq(frame: np.ndarray, mask: np.ndarray, background: np.ndarray,
324
+ fallback_enabled: bool = True) -> np.ndarray:
325
+ """Enhanced background replacement with comprehensive error handling"""
326
+ if frame is None or mask is None or background is None:
327
+ raise BackgroundReplacementError("Invalid input frame, mask, or background")
328
+
329
  try:
330
+ background = cv2.resize(background, (frame.shape[1], frame.shape[0]),
331
+ interpolation=cv2.INTER_LANCZOS4)
332
 
333
+ if len(mask.shape) == 3:
334
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
 
335
 
336
+ if mask.dtype != np.uint8:
337
+ mask = mask.astype(np.uint8)
338
 
339
+ if mask.max() <= 1.0:
340
+ logger.debug("Converting normalized mask to 0-255 range")
341
+ mask = (mask * 255).astype(np.uint8)
 
 
 
 
342
 
343
+ try:
344
+ result = _advanced_compositing(frame, mask, background)
345
+ logger.debug("Advanced compositing successful")
346
+ return result
347
+
348
+ except Exception as e:
349
+ logger.warning(f"Advanced compositing failed: {e}")
350
+ if fallback_enabled:
351
+ return _simple_compositing(frame, mask, background)
352
+ else:
353
+ raise BackgroundReplacementError(f"Advanced compositing failed: {e}")
354
 
355
+ except BackgroundReplacementError:
356
+ raise
357
+ except Exception as e:
358
+ logger.error(f"Unexpected background replacement error: {e}")
359
+ if fallback_enabled:
360
+ return _simple_compositing(frame, mask, background)
361
+ else:
362
+ raise BackgroundReplacementError(f"Unexpected error: {e}")
363
+
364
+ def create_professional_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
365
+ """Enhanced professional background creation with quality improvements"""
366
+ try:
367
+ if bg_config["type"] == "color":
368
+ background = _create_solid_background(bg_config, width, height)
369
+ elif bg_config["type"] == "gradient":
370
+ background = _create_gradient_background_enhanced(bg_config, width, height)
371
+ else:
372
+ background = np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
373
 
374
+ background = _apply_background_adjustments(background, bg_config)
 
375
 
376
+ return background
377
 
378
  except Exception as e:
379
+ logger.error(f"Background creation error: {e}")
380
+ return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
381
 
382
+ # ============================================================================
383
+ # VALIDATION FUNCTION
384
+ # ============================================================================
385
+
386
+ def validate_video_file(video_path: str) -> Tuple[bool, str]:
387
+ """Enhanced video file validation with detailed checks"""
388
+ if not video_path or not os.path.exists(video_path):
389
+ return False, "Video file not found"
390
+
391
  try:
392
+ file_size = os.path.getsize(video_path)
393
+ if file_size == 0:
394
+ return False, "Video file is empty"
395
 
396
+ if file_size > 2 * 1024 * 1024 * 1024:
397
+ return False, "Video file too large (>2GB)"
 
398
 
399
+ cap = cv2.VideoCapture(video_path)
400
+ if not cap.isOpened():
401
+ return False, "Cannot open video file"
402
 
403
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
404
+ fps = cap.get(cv2.CAP_PROP_FPS)
405
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
406
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
 
 
 
407
 
408
+ cap.release()
409
 
410
+ if frame_count == 0:
411
+ return False, "Video appears to be empty (0 frames)"
 
 
 
 
 
 
412
 
413
+ if fps <= 0 or fps > 120:
414
+ return False, f"Invalid frame rate: {fps}"
415
 
416
+ if width <= 0 or height <= 0:
417
+ return False, f"Invalid resolution: {width}x{height}"
 
 
418
 
419
+ if width > 4096 or height > 4096:
420
+ return False, f"Resolution too high: {width}x{height} (max 4096x4096)"
421
 
422
+ duration = frame_count / fps
423
+ if duration > 300:
424
+ return False, f"Video too long: {duration:.1f}s (max 300s)"
425
 
426
+ return True, f"Valid video: {width}x{height}, {fps:.1f}fps, {duration:.1f}s"
 
427
 
428
  except Exception as e:
429
+ return False, f"Error validating video: {str(e)}"
 
430
 
431
  # ============================================================================
432
+ # HELPER FUNCTIONS - SEGMENTATION
433
  # ============================================================================
434
 
435
+ def _segment_with_intelligent_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
436
+ """Intelligent automatic prompt generation for segmentation"""
 
 
 
 
 
 
 
437
  try:
438
+ h, w = image.shape[:2]
439
+ pos_points, neg_points = _generate_smart_prompts(image)
440
 
441
+ if len(pos_points) == 0:
442
+ pos_points = np.array([[w//2, h//2]], dtype=np.float32)
 
 
 
443
 
444
+ points = np.vstack([pos_points, neg_points])
445
+ labels = np.hstack([
446
+ np.ones(len(pos_points), dtype=np.int32),
447
+ np.zeros(len(neg_points), dtype=np.int32)
448
+ ])
449
 
450
+ logger.debug(f"Using {len(pos_points)} positive, {len(neg_points)} negative points")
451
+
452
+ with torch.no_grad():
453
+ masks, scores, _ = predictor.predict(
454
+ point_coords=points,
455
+ point_labels=labels,
456
+ multimask_output=True
457
+ )
458
+
459
+ if masks is None or len(masks) == 0:
460
+ raise SegmentationError("No masks generated")
461
+
462
+ if scores is not None and len(scores) > 0:
463
+ best_idx = np.argmax(scores)
464
+ best_mask = masks[best_idx]
465
+ logger.debug(f"Selected mask {best_idx} with score {scores[best_idx]:.3f}")
466
  else:
467
+ best_mask = masks[0]
468
 
469
+ return _process_mask(best_mask)
470
 
471
  except Exception as e:
472
+ logger.error(f"Intelligent prompting failed: {e}")
473
+ raise
 
 
 
 
 
 
 
 
 
 
 
474
 
475
+ def _segment_with_basic_prompts(image: np.ndarray, predictor: Any) -> np.ndarray:
476
+ """Basic prompting method for segmentation"""
477
+ h, w = image.shape[:2]
478
 
479
+ positive_points = np.array([
480
+ [w//2, h//3],
481
+ [w//2, h//2],
482
+ [w//2, 2*h//3],
483
+ ], dtype=np.float32)
484
 
485
+ negative_points = np.array([
486
+ [w//10, h//10],
487
+ [9*w//10, h//10],
488
+ [w//10, 9*h//10],
489
+ [9*w//10, 9*h//10],
490
+ ], dtype=np.float32)
 
 
491
 
492
+ points = np.vstack([positive_points, negative_points])
493
+ labels = np.array([1, 1, 1, 0, 0, 0, 0], dtype=np.int32)
 
494
 
495
+ with torch.no_grad():
496
+ masks, scores, _ = predictor.predict(
497
+ point_coords=points,
498
+ point_labels=labels,
499
+ multimask_output=True
500
+ )
501
 
502
+ if masks is None or len(masks) == 0:
503
+ raise SegmentationError("No masks generated")
 
 
504
 
505
+ best_idx = np.argmax(scores) if scores is not None and len(scores) > 0 else 0
506
+ best_mask = masks[best_idx]
507
 
508
+ return _process_mask(best_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
 
510
+ def _generate_smart_prompts(image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
511
+ """Generate optimal positive/negative points automatically"""
512
  try:
513
+ h, w = image.shape[:2]
 
514
 
515
+ try:
516
+ saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
517
+ success, saliency_map = saliency.computeSaliency(image)
518
+
519
+ if success:
520
+ saliency_thresh = cv2.threshold(saliency_map, 0.7, 1, cv2.THRESH_BINARY)[1]
521
+ contours, _ = cv2.findContours((saliency_thresh * 255).astype(np.uint8),
522
+ cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
523
+
524
+ positive_points = []
525
+ if contours:
526
+ for contour in sorted(contours, key=cv2.contourArea, reverse=True)[:3]:
527
+ M = cv2.moments(contour)
528
+ if M["m00"] != 0:
529
+ cx = int(M["m10"] / M["m00"])
530
+ cy = int(M["m01"] / M["m00"])
531
+ if 0 < cx < w and 0 < cy < h:
532
+ positive_points.append([cx, cy])
533
+
534
+ if positive_points:
535
+ logger.debug(f"Generated {len(positive_points)} saliency-based points")
536
+ positive_points = np.array(positive_points, dtype=np.float32)
537
+ else:
538
+ raise Exception("No valid saliency points found")
539
+
540
+ except Exception as e:
541
+ logger.debug(f"Saliency method failed: {e}, using fallback")
542
+ positive_points = np.array([
543
+ [w//2, h//3],
544
+ [w//2, h//2],
545
+ [w//2, 2*h//3],
546
+ ], dtype=np.float32)
547
 
548
+ negative_points = np.array([
549
+ [10, 10],
550
+ [w-10, 10],
551
+ [10, h-10],
552
+ [w-10, h-10],
553
+ [w//2, 5],
554
+ [w//2, h-5],
555
+ ], dtype=np.float32)
556
+
557
+ return positive_points, negative_points
558
 
559
  except Exception as e:
560
+ logger.warning(f"Smart prompt generation failed: {e}")
561
+ h, w = image.shape[:2]
562
+ positive_points = np.array([[w//2, h//2]], dtype=np.float32)
563
+ negative_points = np.array([[10, 10], [w-10, 10]], dtype=np.float32)
564
+ return positive_points, negative_points
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
 
566
  # ============================================================================
567
+ # HELPER FUNCTIONS - REFINEMENT
568
  # ============================================================================
569
 
570
+ def _auto_refine_mask_iteratively(image: np.ndarray, initial_mask: np.ndarray,
571
+ predictor: Any, max_iterations: int = 2) -> np.ndarray:
572
+ """Automatically refine mask based on quality assessment"""
573
+ try:
574
+ current_mask = initial_mask.copy()
575
+
576
+ for iteration in range(max_iterations):
577
+ quality_score = _assess_mask_quality(current_mask, image)
578
+ logger.debug(f"Iteration {iteration}: quality score = {quality_score:.3f}")
579
+
580
+ if quality_score > 0.85:
581
+ logger.debug(f"Quality sufficient after {iteration} iterations")
582
+ break
583
+
584
+ problem_areas = _find_mask_errors(current_mask, image)
585
+
586
+ if np.any(problem_areas):
587
+ corrective_points, corrective_labels = _generate_corrective_prompts(
588
+ image, current_mask, problem_areas
589
+ )
590
+
591
+ if len(corrective_points) > 0:
592
+ try:
593
+ with torch.no_grad():
594
+ masks, scores, _ = predictor.predict(
595
+ point_coords=corrective_points,
596
+ point_labels=corrective_labels,
597
+ mask_input=current_mask[None, :, :],
598
+ multimask_output=False
599
+ )
600
+
601
+ if masks is not None and len(masks) > 0:
602
+ refined_mask = _process_mask(masks[0])
603
+
604
+ if _assess_mask_quality(refined_mask, image) > quality_score:
605
+ current_mask = refined_mask
606
+ logger.debug(f"Improved mask in iteration {iteration}")
607
+ else:
608
+ logger.debug(f"Refinement didn't improve quality in iteration {iteration}")
609
+ break
610
+
611
+ except Exception as e:
612
+ logger.debug(f"Refinement iteration {iteration} failed: {e}")
613
+ break
614
+ else:
615
+ logger.debug("No problem areas detected")
616
+ break
617
+
618
+ return current_mask
619
+
620
+ except Exception as e:
621
+ logger.warning(f"Iterative refinement failed: {e}")
622
+ return initial_mask
623
 
624
+ def _assess_mask_quality(mask: np.ndarray, image: np.ndarray) -> float:
625
+ """Assess mask quality automatically"""
626
+ try:
627
+ h, w = image.shape[:2]
628
+ scores = []
629
+
630
+ mask_area = np.sum(mask > 127)
631
+ total_area = h * w
632
+ area_ratio = mask_area / total_area
633
+
634
+ if 0.05 <= area_ratio <= 0.8:
635
+ area_score = 1.0
636
+ elif area_ratio < 0.05:
637
+ area_score = area_ratio / 0.05
638
+ else:
639
+ area_score = max(0, 1.0 - (area_ratio - 0.8) / 0.2)
640
+ scores.append(area_score)
641
+
642
+ mask_binary = mask > 127
643
+ if np.any(mask_binary):
644
+ mask_center_y, mask_center_x = np.where(mask_binary)
645
+ center_y = np.mean(mask_center_y) / h
646
+ center_x = np.mean(mask_center_x) / w
647
+
648
+ center_score = 1.0 - min(abs(center_x - 0.5), abs(center_y - 0.5))
649
+ scores.append(center_score)
650
+ else:
651
+ scores.append(0.0)
652
+
653
+ edges = cv2.Canny(mask, 50, 150)
654
+ edge_density = np.sum(edges > 0) / total_area
655
+ smoothness_score = max(0, 1.0 - edge_density * 10)
656
+ scores.append(smoothness_score)
657
+
658
+ num_labels, _ = cv2.connectedComponents(mask)
659
+ connectivity_score = max(0, 1.0 - (num_labels - 2) * 0.2)
660
+ scores.append(connectivity_score)
661
+
662
+ weights = [0.3, 0.2, 0.3, 0.2]
663
+ overall_score = np.average(scores, weights=weights)
664
+
665
+ return overall_score
666
+
667
+ except Exception as e:
668
+ logger.warning(f"Quality assessment failed: {e}")
669
+ return 0.5
670
 
671
+ def _find_mask_errors(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
672
+ """Identify problematic areas in mask"""
673
+ try:
674
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
675
+ edges = cv2.Canny(gray, 50, 150)
676
+ mask_edges = cv2.Canny(mask, 50, 150)
677
+ edge_discrepancy = cv2.bitwise_xor(edges, mask_edges)
678
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
679
+ error_regions = cv2.dilate(edge_discrepancy, kernel, iterations=1)
680
+ return error_regions > 0
681
+ except Exception as e:
682
+ logger.warning(f"Error detection failed: {e}")
683
+ return np.zeros_like(mask, dtype=bool)
684
+
685
+ def _generate_corrective_prompts(image: np.ndarray, mask: np.ndarray,
686
+ problem_areas: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
687
+ """Generate corrective prompts based on problem areas"""
688
+ try:
689
+ contours, _ = cv2.findContours(problem_areas.astype(np.uint8),
690
+ cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
691
+
692
+ corrective_points = []
693
+ corrective_labels = []
694
+
695
+ for contour in contours:
696
+ if cv2.contourArea(contour) > 100:
697
+ M = cv2.moments(contour)
698
+ if M["m00"] != 0:
699
+ cx = int(M["m10"] / M["m00"])
700
+ cy = int(M["m01"] / M["m00"])
701
+
702
+ current_mask_value = mask[cy, cx]
703
+
704
+ if current_mask_value < 127:
705
+ corrective_points.append([cx, cy])
706
+ corrective_labels.append(1)
707
+ else:
708
+ corrective_points.append([cx, cy])
709
+ corrective_labels.append(0)
710
+
711
+ return (np.array(corrective_points, dtype=np.float32) if corrective_points else np.array([]).reshape(0, 2),
712
+ np.array(corrective_labels, dtype=np.int32) if corrective_labels else np.array([], dtype=np.int32))
713
+
714
+ except Exception as e:
715
+ logger.warning(f"Corrective prompt generation failed: {e}")
716
+ return np.array([]).reshape(0, 2), np.array([], dtype=np.int32)
717
 
718
  # ============================================================================
719
+ # HELPER FUNCTIONS - PROCESSING
720
  # ============================================================================
721
 
722
+ def _process_mask(mask: np.ndarray) -> np.ndarray:
723
+ """Process raw mask to ensure correct format and range"""
 
 
 
 
 
 
 
 
724
  try:
725
+ if len(mask.shape) > 2:
726
+ mask = mask.squeeze()
 
 
 
 
727
 
728
+ if len(mask.shape) > 2:
729
+ mask = mask[:, :, 0] if mask.shape[2] > 0 else mask.sum(axis=2)
 
 
 
 
 
 
730
 
731
+ if mask.dtype == bool:
732
+ mask = mask.astype(np.uint8) * 255
733
+ elif mask.dtype == np.float32 or mask.dtype == np.float64:
734
+ if mask.max() <= 1.0:
735
+ mask = (mask * 255).astype(np.uint8)
736
+ else:
737
+ mask = np.clip(mask, 0, 255).astype(np.uint8)
738
  else:
739
+ mask = mask.astype(np.uint8)
740
 
741
+ kernel = np.ones((3, 3), np.uint8)
742
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
743
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
744
 
745
+ _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
 
 
 
 
 
746
 
 
747
  return mask
748
 
 
 
749
  except Exception as e:
750
+ logger.error(f"Mask processing failed: {e}")
751
+ h, w = mask.shape[:2] if len(mask.shape) >= 2 else (256, 256)
752
+ fallback = np.zeros((h, w), dtype=np.uint8)
753
+ fallback[h//4:3*h//4, w//4:3*w//4] = 255
754
+ return fallback
755
 
756
+ def _validate_mask_quality(mask: np.ndarray, image_shape: Tuple[int, int]) -> bool:
757
+ """Validate that the mask meets quality criteria"""
 
 
 
758
  try:
759
+ h, w = image_shape
760
+ mask_area = np.sum(mask > 127)
761
+ total_area = h * w
 
 
 
762
 
763
+ area_ratio = mask_area / total_area
764
+ if area_ratio < 0.05 or area_ratio > 0.8:
765
+ logger.warning(f"Suspicious mask area ratio: {area_ratio:.3f}")
766
+ return False
 
 
 
 
767
 
768
+ mask_binary = mask > 127
769
+ mask_center_y, mask_center_x = np.where(mask_binary)
770
 
771
+ if len(mask_center_y) == 0:
772
+ logger.warning("Empty mask")
773
+ return False
 
 
 
 
 
 
 
774
 
775
+ center_y = np.mean(mask_center_y)
776
+ center_x = np.mean(mask_center_x)
777
+
778
+ if center_y < h * 0.2 or center_y > h * 0.9:
779
+ logger.warning(f"Mask center too far from expected person location: y={center_y/h:.2f}")
780
+ return False
781
+
782
+ return True
783
+
784
+ except Exception as e:
785
+ logger.warning(f"Mask validation error: {e}")
786
+ return True
787
+
788
+ def _fallback_segmentation(image: np.ndarray) -> np.ndarray:
789
+ """Fallback segmentation when AI models fail"""
790
+ try:
791
+ logger.info("Using fallback segmentation strategy")
792
+ h, w = image.shape[:2]
793
 
794
  try:
795
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
796
+
797
+ edge_pixels = np.concatenate([
798
+ gray[0, :], gray[-1, :], gray[:, 0], gray[:, -1]
799
+ ])
800
+ bg_color = np.median(edge_pixels)
801
+
802
+ diff = np.abs(gray.astype(float) - bg_color)
803
+ mask = (diff > 30).astype(np.uint8) * 255
804
+
805
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
806
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
807
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
808
+
809
+ if _validate_mask_quality(mask, image.shape[:2]):
810
+ logger.info("Background subtraction fallback successful")
811
+ return mask
812
+
813
  except Exception as e:
814
+ logger.warning(f"Background subtraction fallback failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
815
 
816
+ mask = np.zeros((h, w), dtype=np.uint8)
 
 
 
 
 
 
817
 
818
+ center_x, center_y = w // 2, h // 2
819
+ radius_x, radius_y = w // 3, h // 2.5
820
 
821
+ y, x = np.ogrid[:h, :w]
822
+ mask_ellipse = ((x - center_x) / radius_x) ** 2 + ((y - center_y) / radius_y) ** 2 <= 1
823
+ mask[mask_ellipse] = 255
 
 
 
824
 
825
+ logger.info("Using geometric fallback mask")
826
  return mask
827
 
 
 
828
  except Exception as e:
829
+ logger.error(f"All fallback strategies failed: {e}")
830
+ h, w = image.shape[:2]
831
+ mask = np.zeros((h, w), dtype=np.uint8)
832
+ mask[h//6:5*h//6, w//4:3*w//4] = 255
833
+ return mask
 
 
 
 
834
 
835
+ def _matanyone_refine(image: np.ndarray, mask: np.ndarray, processor: Any) -> Optional[np.ndarray]:
836
+ """Attempt MatAnyone mask refinement"""
 
 
 
 
837
  try:
838
+ if hasattr(processor, 'infer'):
839
+ refined_mask = processor.infer(image, mask)
840
+ elif hasattr(processor, 'process'):
841
+ refined_mask = processor.process(image, mask)
842
+ elif callable(processor):
843
+ refined_mask = processor(image, mask)
844
+ else:
845
+ logger.warning("Unknown MatAnyone interface")
846
+ return None
847
 
848
+ if refined_mask is None:
849
+ return None
850
+
851
+ refined_mask = _process_mask(refined_mask)
852
+ logger.debug("MatAnyone refinement successful")
853
+ return refined_mask
 
 
 
 
 
 
 
854
 
 
 
 
 
 
 
 
 
855
  except Exception as e:
856
+ logger.warning(f"MatAnyone processing error: {e}")
857
+ return None
 
 
 
858
 
859
+ def _guided_filter_approx(guide: np.ndarray, mask: np.ndarray, radius: int = 8, eps: float = 0.2) -> np.ndarray:
860
+ """Approximation of guided filter for edge-aware smoothing"""
861
  try:
862
+ guide_gray = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY) if len(guide.shape) == 3 else guide
863
+ guide_gray = guide_gray.astype(np.float32) / 255.0
864
+ mask_float = mask.astype(np.float32) / 255.0
865
 
866
+ kernel_size = 2 * radius + 1
 
867
 
868
+ mean_guide = cv2.boxFilter(guide_gray, -1, (kernel_size, kernel_size))
869
+ mean_mask = cv2.boxFilter(mask_float, -1, (kernel_size, kernel_size))
870
+ corr_guide_mask = cv2.boxFilter(guide_gray * mask_float, -1, (kernel_size, kernel_size))
871
 
872
+ cov_guide_mask = corr_guide_mask - mean_guide * mean_mask
873
+ mean_guide_sq = cv2.boxFilter(guide_gray * guide_gray, -1, (kernel_size, kernel_size))
874
+ var_guide = mean_guide_sq - mean_guide * mean_guide
875
 
876
+ a = cov_guide_mask / (var_guide + eps)
877
+ b = mean_mask - a * mean_guide
878
 
879
+ mean_a = cv2.boxFilter(a, -1, (kernel_size, kernel_size))
880
+ mean_b = cv2.boxFilter(b, -1, (kernel_size, kernel_size))
881
 
882
+ output = mean_a * guide_gray + mean_b
883
+ output = np.clip(output * 255, 0, 255).astype(np.uint8)
884
 
885
+ return output
886
 
887
  except Exception as e:
888
+ logger.warning(f"Guided filter approximation failed: {e}")
889
+ return mask
890
 
891
  # ============================================================================
892
+ # HELPER FUNCTIONS - COMPOSITING
893
  # ============================================================================
894
 
895
+ def _advanced_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
896
+ """Advanced compositing with edge feathering and color correction"""
 
 
 
 
897
  try:
898
+ threshold = 100
899
+ _, mask_binary = cv2.threshold(mask, threshold, 255, cv2.THRESH_BINARY)
900
 
901
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
902
+ mask_binary = cv2.morphologyEx(mask_binary, cv2.MORPH_CLOSE, kernel)
903
+ mask_binary = cv2.morphologyEx(mask_binary, cv2.MORPH_OPEN, kernel)
904
 
905
+ mask_smooth = cv2.GaussianBlur(mask_binary.astype(np.float32), (5, 5), 1.0)
906
+ mask_smooth = mask_smooth / 255.0
907
 
908
+ mask_smooth = np.power(mask_smooth, 0.8)
 
 
909
 
910
+ mask_smooth = np.where(mask_smooth > 0.5,
911
+ np.minimum(mask_smooth * 1.1, 1.0),
912
+ mask_smooth * 0.9)
 
 
 
 
 
 
 
 
913
 
914
+ frame_adjusted = _color_match_edges(frame, background, mask_smooth)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
915
 
916
+ alpha_3ch = np.stack([mask_smooth] * 3, axis=2)
917
 
918
+ frame_float = frame_adjusted.astype(np.float32)
919
+ background_float = background.astype(np.float32)
920
+
921
+ result = frame_float * alpha_3ch + background_float * (1 - alpha_3ch)
922
+ result = np.clip(result, 0, 255).astype(np.uint8)
923
+
924
+ return result
925
 
926
  except Exception as e:
927
+ logger.error(f"Advanced compositing error: {e}")
928
+ raise
 
 
 
 
929
 
930
+ def _color_match_edges(frame: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
931
+ """Subtle color matching at edges to reduce halos"""
 
 
 
932
  try:
933
+ edge_mask = cv2.Sobel(alpha, cv2.CV_64F, 1, 1, ksize=3)
934
+ edge_mask = np.abs(edge_mask)
935
+ edge_mask = (edge_mask > 0.1).astype(np.float32)
936
 
937
+ edge_areas = edge_mask > 0
938
+ if not np.any(edge_areas):
939
+ return frame
940
 
941
+ frame_adjusted = frame.copy().astype(np.float32)
942
+ background_float = background.astype(np.float32)
 
943
 
944
+ adjustment_strength = 0.1
945
+ for c in range(3):
946
+ frame_adjusted[:, :, c] = np.where(
947
+ edge_areas,
948
+ frame_adjusted[:, :, c] * (1 - adjustment_strength) +
949
+ background_float[:, :, c] * adjustment_strength,
950
+ frame_adjusted[:, :, c]
951
+ )
952
 
953
+ return np.clip(frame_adjusted, 0, 255).astype(np.uint8)
954
 
955
+ except Exception as e:
956
+ logger.warning(f"Color matching failed: {e}")
957
+ return frame
958
+
959
+ def _simple_compositing(frame: np.ndarray, mask: np.ndarray, background: np.ndarray) -> np.ndarray:
960
+ """Simple fallback compositing method"""
961
+ try:
962
+ logger.info("Using simple compositing fallback")
963
 
964
+ background = cv2.resize(background, (frame.shape[1], frame.shape[0]))
 
965
 
966
+ if len(mask.shape) == 3:
967
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
968
+ if mask.max() <= 1.0:
969
+ mask = (mask * 255).astype(np.uint8)
970
 
971
+ _, mask_binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
 
972
 
973
+ mask_norm = mask_binary.astype(np.float32) / 255.0
974
+ mask_3ch = np.stack([mask_norm] * 3, axis=2)
 
975
 
976
+ result = frame * mask_3ch + background * (1 - mask_3ch)
977
+ return result.astype(np.uint8)
978
 
979
  except Exception as e:
980
+ logger.error(f"Simple compositing failed: {e}")
981
+ return frame
982
 
983
  # ============================================================================
984
+ # HELPER FUNCTIONS - BACKGROUND CREATION
985
  # ============================================================================
986
 
987
+ def _create_solid_background(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
988
+ """Create solid color background"""
989
+ color_hex = bg_config["colors"][0].lstrip('#')
990
+ color_rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
991
+ color_bgr = color_rgb[::-1]
992
+ return np.full((height, width, 3), color_bgr, dtype=np.uint8)
993
+
994
+ def _create_gradient_background_enhanced(bg_config: Dict[str, Any], width: int, height: int) -> np.ndarray:
995
+ """Create enhanced gradient background with better quality"""
996
  try:
997
+ colors = bg_config["colors"]
998
+ direction = bg_config.get("direction", "vertical")
 
 
 
 
 
 
 
 
 
 
 
999
 
1000
+ rgb_colors = []
1001
+ for color_hex in colors:
1002
+ color_hex = color_hex.lstrip('#')
1003
+ rgb = tuple(int(color_hex[i:i+2], 16) for i in (0, 2, 4))
1004
+ rgb_colors.append(rgb)
 
1005
 
1006
+ if not rgb_colors:
1007
+ rgb_colors = [(128, 128, 128)]
1008
 
1009
+ if direction == "vertical":
1010
+ background = _create_vertical_gradient(rgb_colors, width, height)
1011
+ elif direction == "horizontal":
1012
+ background = _create_horizontal_gradient(rgb_colors, width, height)
1013
+ elif direction == "diagonal":
1014
+ background = _create_diagonal_gradient(rgb_colors, width, height)
1015
+ elif direction in ["radial", "soft_radial"]:
1016
+ background = _create_radial_gradient(rgb_colors, width, height, direction == "soft_radial")
1017
  else:
1018
+ background = _create_vertical_gradient(rgb_colors, width, height)
1019
 
1020
+ return cv2.cvtColor(background, cv2.COLOR_RGB2BGR)
1021
 
1022
  except Exception as e:
1023
+ logger.error(f"Gradient creation error: {e}")
1024
+ return np.full((height, width, 3), (128, 128, 128), dtype=np.uint8)
1025
 
1026
+ def _create_vertical_gradient(colors: list, width: int, height: int) -> np.ndarray:
1027
+ """Create vertical gradient using NumPy for performance"""
1028
+ gradient = np.zeros((height, width, 3), dtype=np.uint8)
1029
 
1030
+ for y in range(height):
1031
+ progress = y / height if height > 0 else 0
1032
+ color = _interpolate_color(colors, progress)
1033
+ gradient[y, :] = color
 
1034
 
1035
+ return gradient
1036
+
1037
+ def _create_horizontal_gradient(colors: list, width: int, height: int) -> np.ndarray:
1038
+ """Create horizontal gradient using NumPy for performance"""
1039
+ gradient = np.zeros((height, width, 3), dtype=np.uint8)
 
1040
 
1041
+ for x in range(width):
1042
+ progress = x / width if width > 0 else 0
1043
+ color = _interpolate_color(colors, progress)
1044
+ gradient[:, x] = color
1045
 
1046
+ return gradient
1047
+
1048
+ def _create_diagonal_gradient(colors: list, width: int, height: int) -> np.ndarray:
1049
+ """Create diagonal gradient using vectorized operations"""
1050
+ y_coords, x_coords = np.mgrid[0:height, 0:width]
1051
+ max_distance = width + height
1052
+ progress = (x_coords + y_coords) / max_distance
1053
+ progress = np.clip(progress, 0, 1)
1054
 
1055
+ gradient = np.zeros((height, width, 3), dtype=np.uint8)
1056
+ for c in range(3):
1057
+ gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)
1058
 
1059
+ return gradient
1060
+
1061
+ def _create_radial_gradient(colors: list, width: int, height: int, soft: bool = False) -> np.ndarray:
1062
+ """Create radial gradient using vectorized operations"""
1063
+ center_x, center_y = width // 2, height // 2
1064
+ max_distance = np.sqrt(center_x**2 + center_y**2)
1065
 
1066
+ y_coords, x_coords = np.mgrid[0:height, 0:width]
1067
+ distances = np.sqrt((x_coords - center_x)**2 + (y_coords - center_y)**2)
1068
+ progress = distances / max_distance
1069
+ progress = np.clip(progress, 0, 1)
1070
+
1071
+ if soft:
1072
+ progress = np.power(progress, 0.7)
1073
+
1074
+ gradient = np.zeros((height, width, 3), dtype=np.uint8)
1075
+ for c in range(3):
1076
+ gradient[:, :, c] = _vectorized_color_interpolation(colors, progress, c)
1077
+
1078
+ return gradient
1079
 
1080
+ def _vectorized_color_interpolation(colors: list, progress: np.ndarray, channel: int) -> np.ndarray:
1081
+ """Vectorized color interpolation for performance"""
1082
+ if len(colors) == 1:
1083
+ return np.full_like(progress, colors[0][channel], dtype=np.uint8)
1084
+
1085
+ num_segments = len(colors) - 1
1086
+ segment_progress = progress * num_segments
1087
+ segment_indices = np.floor(segment_progress).astype(int)
1088
+ segment_indices = np.clip(segment_indices, 0, num_segments - 1)
1089
+ local_progress = segment_progress - segment_indices
1090
+
1091
+ start_colors = np.array([colors[i][channel] for i in range(len(colors))])
1092
+ end_colors = np.array([colors[min(i + 1, len(colors) - 1)][channel] for i in range(len(colors))])
1093
+
1094
+ start_vals = start_colors[segment_indices]
1095
+ end_vals = end_colors[segment_indices]
1096
+
1097
+ result = start_vals + (end_vals - start_vals) * local_progress
1098
+ return np.clip(result, 0, 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1099
 
1100
+ def _interpolate_color(colors: list, progress: float) -> tuple:
1101
+ """Interpolate between multiple colors"""
1102
+ if len(colors) == 1:
1103
+ return colors[0]
1104
+ elif len(colors) == 2:
1105
+ r = int(colors[0][0] + (colors[1][0] - colors[0][0]) * progress)
1106
+ g = int(colors[0][1] + (colors[1][1] - colors[0][1]) * progress)
1107
+ b = int(colors[0][2] + (colors[1][2] - colors[0][2]) * progress)
1108
+ return (r, g, b)
1109
+ else:
1110
+ segment = progress * (len(colors) - 1)
1111
+ idx = int(segment)
1112
+ local_progress = segment - idx
1113
+ if idx >= len(colors) - 1:
1114
+ return colors[-1]
1115
+ c1, c2 = colors[idx], colors[idx + 1]
1116
+ r = int(c1[0] + (c2[0] - c1[0]) * local_progress)
1117
+ g = int(c1[1] + (c2[1] - c1[1]) * local_progress)
1118
+ b = int(c1[2] + (c2[2] - c1[2]) * local_progress)
1119
+ return (r, g, b)
1120
 
1121
+ def _apply_background_adjustments(background: np.ndarray, bg_config: Dict[str, Any]) -> np.ndarray:
1122
+ """Apply brightness and contrast adjustments to background"""
 
1123
  try:
1124
+ brightness = bg_config.get("brightness", 1.0)
1125
+ contrast = bg_config.get("contrast", 1.0)
1126
 
1127
+ if brightness != 1.0 or contrast != 1.0:
1128
+ background = background.astype(np.float32)
1129
+ background = background * contrast * brightness
1130
+ background = np.clip(background, 0, 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1131
 
1132
+ return background
1133
 
1134
  except Exception as e:
1135
+ logger.warning(f"Background adjustment failed: {e}")
1136
+ return background