MogensR commited on
Commit
d2502a6
·
1 Parent(s): b6786fa

Update utils/cv_processing.py

Browse files
Files changed (1) hide show
  1. utils/cv_processing.py +109 -82
utils/cv_processing.py CHANGED
@@ -1,6 +1,9 @@
1
  #!/usr/bin/env python3
2
  """
3
  cv_processing.py · FIXED VERSION with proper SAM2 handling + MatAnyone stateful integration
 
 
 
4
  """
5
 
6
  from __future__ import annotations
@@ -28,41 +31,48 @@
28
  PROFESSIONAL_BACKGROUNDS = PROFESSIONAL_BACKGROUNDS_LOCAL
29
 
30
  # ----------------------------------------------------------------------------
31
- # Helpers
32
  # ----------------------------------------------------------------------------
33
  def _ensure_rgb(img: np.ndarray) -> np.ndarray:
 
 
 
 
34
  if img is None:
35
  return img
36
- if img.ndim == 3 and img.shape[2] == 3:
37
- return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
38
- return img
 
 
 
39
 
40
- def _ensure_rgb01(frame_bgr: np.ndarray) -> np.ndarray:
41
  """
42
- Convert BGR uint8 [H,W,3] to RGB float32 in [0,1].
43
- Accepts a variety of layouts and coerces safely to HWC.
44
  """
45
- if frame_bgr is None:
46
- raise ValueError("frame_bgr is None")
47
- x = frame_bgr
48
- if x.ndim == 2:
49
- x = np.stack([x, x, x], axis=-1) # gray -> 3ch
50
- # channels-first -> HWC
51
- if x.ndim == 3 and x.shape[0] in (1, 3, 4) and x.shape[-1] not in (1, 3, 4):
52
- x = np.transpose(x, (1, 2, 0))
53
- if x.dtype != np.uint8:
54
- x = np.clip(x, 0, 255).astype(np.uint8)
55
- rgb = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
56
- return (rgb.astype(np.float32) / 255.0).copy()
57
 
58
  def _to_mask01(m: np.ndarray) -> np.ndarray:
59
  if m is None:
60
  return None
61
- if m.ndim == 3 and m.shape[2] in (1, 3):
62
  m = m[..., 0]
63
- m = m.astype(np.float32)
64
- if m.max() > 1.0:
65
- m = m / 255.0
 
 
66
  return np.clip(m, 0.0, 1.0)
67
 
68
  def _mask_to_2d(mask: np.ndarray) -> np.ndarray:
@@ -71,29 +81,31 @@ def _mask_to_2d(mask: np.ndarray) -> np.ndarray:
71
  Handles HWC/CHW/B1HW/1HW/HW, etc.
72
  """
73
  m = np.asarray(mask)
74
- # channels-first 1xHxW
 
75
  if m.ndim == 3 and m.shape[0] == 1 and (m.shape[1] > 1 and m.shape[2] > 1):
76
  m = m[0]
77
- # channels-last HxWx1
78
  if m.ndim == 3 and m.shape[-1] == 1:
79
  m = m[..., 0]
80
- # multi-channel -> take first channel
81
  if m.ndim == 3:
82
  m = m[..., 0] if m.shape[-1] in (1, 3, 4) else m[0]
83
- # squeeze anything left
84
  m = np.squeeze(m)
85
  if m.ndim != 2:
 
86
  h = int(m.shape[-2]) if m.ndim >= 2 else 512
87
  w = int(m.shape[-1]) if m.ndim >= 2 else 512
88
  logger.warning(f"_mask_to_2d: unexpected shape {mask.shape}, creating neutral mask.")
89
  m = np.full((h, w), 0.5, dtype=np.float32)
90
- # dtype/range
91
  if m.dtype == np.uint8:
92
  m = m.astype(np.float32) / 255.0
93
  elif m.dtype != np.float32:
94
  m = m.astype(np.float32)
95
- m = np.clip(m, 0.0, 1.0)
96
- return np.ascontiguousarray(m)
97
 
98
  def _feather(mask01: np.ndarray, k: int = 2) -> np.ndarray:
99
  if mask01.ndim == 3:
@@ -133,17 +145,18 @@ def create_professional_background(key_or_cfg: Any, width: int, height: int) ->
133
  return _vertical_gradient(dark, color, width, height)
134
 
135
  # ----------------------------------------------------------------------------
136
- # Improved Segmentation
137
  # ----------------------------------------------------------------------------
138
- def _simple_person_segmentation(frame_bgr: np.ndarray) -> np.ndarray:
139
- """Basic fallback segmentation using color detection"""
140
- h, w = frame_bgr.shape[:2]
141
- hsv = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2HSV)
142
 
143
  lower_skin = np.array([0, 20, 70], dtype=np.uint8)
144
  upper_skin = np.array([20, 255, 255], dtype=np.uint8)
145
  skin_mask = cv2.inRange(hsv, lower_skin, upper_skin)
146
 
 
147
  lower_green = np.array([40, 40, 40], dtype=np.uint8)
148
  upper_green = np.array([80, 255, 255], dtype=np.uint8)
149
  green_mask = cv2.inRange(hsv, lower_green, upper_green)
@@ -171,65 +184,77 @@ def segment_person_hq(
171
  **_compat_kwargs,
172
  ) -> np.ndarray:
173
  """
174
- High-quality person segmentation with proper SAM2 handling
 
175
  """
176
- h, w = frame.shape[:2]
 
177
 
178
  if use_sam2 is False:
179
- return _simple_person_segmentation(frame)
180
 
181
  if predictor is not None:
182
  try:
183
  if hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
184
- rgb = _ensure_rgb(frame)
185
- predictor.set_image(rgb)
186
-
187
- points = []
188
- labels = []
 
189
 
190
- points.append([w // 2, h // 2]); labels.append(1)
191
- points.append([w // 2, h // 4]); labels.append(1)
192
- points.append([w // 2, h // 2 + h // 8]); labels.append(1)
193
 
194
- point_coords = np.array(points, dtype=np.float32)
195
- point_labels = np.array(labels, dtype=np.int32)
 
 
 
 
 
196
 
197
  result = predictor.predict(
198
- point_coords=point_coords,
199
- point_labels=point_labels,
200
  multimask_output=True
201
  )
202
 
 
203
  if isinstance(result, dict):
204
  masks = result.get("masks", None)
205
  scores = result.get("scores", None)
206
- elif isinstance(result, tuple) and len(result) >= 2:
207
  masks, scores = result[0], result[1]
208
  else:
209
- masks = result
210
- scores = None
211
 
212
  if masks is not None:
213
- masks = np.array(masks)
214
- if masks.size > 0:
215
- if masks.ndim == 3 and masks.shape[0] > 0:
216
- if scores is not None and len(scores) > 0:
217
- best_idx = np.argmax(scores)
218
- mask = masks[best_idx]
219
- else:
220
- mask = masks[0]
221
- elif masks.ndim == 2:
222
- mask = masks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  else:
224
- logger.warning(f"Unexpected mask shape from SAM2: {masks.shape}")
225
- mask = None
226
-
227
- if mask is not None:
228
- mask = _to_mask01(mask)
229
- if mask.max() > 0.1:
230
- return mask
231
- else:
232
- logger.warning("SAM2 mask too weak, using fallback")
233
  else:
234
  logger.warning("SAM2 returned no masks")
235
 
@@ -238,7 +263,7 @@ def segment_person_hq(
238
 
239
  if fallback_enabled:
240
  logger.debug("Using fallback segmentation")
241
- return _simple_person_segmentation(frame)
242
  else:
243
  return np.ones((h, w), dtype=np.float32)
244
 
@@ -276,7 +301,7 @@ def refine_mask_hq(
276
 
277
  if matanyone is not None and callable(matanyone):
278
  try:
279
- rgb01 = _ensure_rgb01(frame)
280
 
281
  # Stateful path (preferred)
282
  if frame_idx is not None:
@@ -285,7 +310,7 @@ def refine_mask_hq(
285
  else:
286
  refined = matanyone(rgb01) # propagate without mask
287
  refined = _mask_to_2d(refined)
288
- if refined.max() > 0.1:
289
  return _postprocess_mask(refined)
290
  logger.warning("MatAnyone stateful refinement produced empty/weak mask; falling back.")
291
 
@@ -315,7 +340,7 @@ def refine_mask_hq(
315
  except Exception as e:
316
  logger.debug(f"MatAnyone process failed: {e}")
317
 
318
- if refined is not None and refined.max() > 0.1:
319
  return _postprocess_mask(refined)
320
  else:
321
  logger.warning("MatAnyone refinement failed or produced empty mask")
@@ -331,7 +356,7 @@ def refine_mask_hq(
331
 
332
  def _postprocess_mask(mask01: np.ndarray) -> np.ndarray:
333
  """Post-process mask to clean edges and remove artifacts"""
334
- mask_uint8 = (mask01 * 255).astype(np.uint8)
335
 
336
  kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
337
  mask_uint8 = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel_close)
@@ -342,11 +367,12 @@ def _postprocess_mask(mask01: np.ndarray) -> np.ndarray:
342
 
343
  mask_uint8 = cv2.GaussianBlur(mask_uint8, (5, 5), 1)
344
 
345
- return mask_uint8.astype(np.float32) / 255.0
 
346
 
347
  def _fallback_refine(mask01: np.ndarray) -> np.ndarray:
348
  """Simple fallback refinement"""
349
- mask_uint8 = (mask01 * 255).astype(np.uint8)
350
 
351
  mask_uint8 = cv2.bilateralFilter(mask_uint8, 9, 75, 75)
352
 
@@ -356,10 +382,11 @@ def _fallback_refine(mask01: np.ndarray) -> np.ndarray:
356
 
357
  mask_uint8 = cv2.GaussianBlur(mask_uint8, (5, 5), 1)
358
 
359
- return mask_uint8.astype(np.float32) / 255.0
 
360
 
361
  # ----------------------------------------------------------------------------
362
- # Compositing
363
  # ----------------------------------------------------------------------------
364
  def replace_background_hq(
365
  frame: np.ndarray,
@@ -368,7 +395,7 @@ def replace_background_hq(
368
  fallback_enabled: bool = True,
369
  **_compat,
370
  ) -> np.ndarray:
371
- """High-quality background replacement with alpha blending"""
372
  try:
373
  H, W = frame.shape[:2]
374
 
 
1
  #!/usr/bin/env python3
2
  """
3
  cv_processing.py · FIXED VERSION with proper SAM2 handling + MatAnyone stateful integration
4
+
5
+ All public functions in this module expect RGB images (H,W,3) unless stated otherwise.
6
+ CoreVideoProcessor already converts BGR→RGB before calling into this module.
7
  """
8
 
9
  from __future__ import annotations
 
31
  PROFESSIONAL_BACKGROUNDS = PROFESSIONAL_BACKGROUNDS_LOCAL
32
 
33
  # ----------------------------------------------------------------------------
34
+ # Helpers (RGB-safe)
35
  # ----------------------------------------------------------------------------
36
  def _ensure_rgb(img: np.ndarray) -> np.ndarray:
37
+ """
38
+ Identity for RGB HWC images. If channels-first, convert to HWC.
39
+ DOES NOT perform BGR↔RGB swaps (the caller is responsible for color space).
40
+ """
41
  if img is None:
42
  return img
43
+ x = np.asarray(img)
44
+ if x.ndim == 3 and x.shape[-1] in (3, 4):
45
+ return x[..., :3]
46
+ if x.ndim == 3 and x.shape[0] in (1, 3, 4) and x.shape[-1] not in (1, 3, 4):
47
+ return np.transpose(x, (1, 2, 0))[..., :3]
48
+ return x
49
 
50
+ def _ensure_rgb01(frame_rgb: np.ndarray) -> np.ndarray:
51
  """
52
+ Convert RGB uint8/float to RGB float32 in [0,1], HWC.
53
+ No channel swaps are performed.
54
  """
55
+ if frame_rgb is None:
56
+ raise ValueError("frame_rgb is None")
57
+ x = _ensure_rgb(frame_rgb)
58
+ if x.dtype == np.uint8:
59
+ return (x.astype(np.float32) / 255.0).copy()
60
+ if np.issubdtype(x.dtype, np.floating):
61
+ return np.clip(x.astype(np.float32), 0.0, 1.0).copy()
62
+ # other integer types
63
+ x = np.clip(x, 0, 255).astype(np.uint8)
64
+ return (x.astype(np.float32) / 255.0).copy()
 
 
65
 
66
  def _to_mask01(m: np.ndarray) -> np.ndarray:
67
  if m is None:
68
  return None
69
+ if m.ndim == 3 and m.shape[2] in (1, 3, 4):
70
  m = m[..., 0]
71
+ m = np.asarray(m)
72
+ if m.dtype == np.uint8:
73
+ m = m.astype(np.float32) / 255.0
74
+ elif m.dtype != np.float32:
75
+ m = m.astype(np.float32)
76
  return np.clip(m, 0.0, 1.0)
77
 
78
  def _mask_to_2d(mask: np.ndarray) -> np.ndarray:
 
81
  Handles HWC/CHW/B1HW/1HW/HW, etc.
82
  """
83
  m = np.asarray(mask)
84
+
85
+ # CHW with single channel
86
  if m.ndim == 3 and m.shape[0] == 1 and (m.shape[1] > 1 and m.shape[2] > 1):
87
  m = m[0]
88
+ # HWC with single channel
89
  if m.ndim == 3 and m.shape[-1] == 1:
90
  m = m[..., 0]
91
+ # generic 3D -> take first channel
92
  if m.ndim == 3:
93
  m = m[..., 0] if m.shape[-1] in (1, 3, 4) else m[0]
94
+
95
  m = np.squeeze(m)
96
  if m.ndim != 2:
97
+ # fall back to neutral 0.5 mask
98
  h = int(m.shape[-2]) if m.ndim >= 2 else 512
99
  w = int(m.shape[-1]) if m.ndim >= 2 else 512
100
  logger.warning(f"_mask_to_2d: unexpected shape {mask.shape}, creating neutral mask.")
101
  m = np.full((h, w), 0.5, dtype=np.float32)
102
+
103
  if m.dtype == np.uint8:
104
  m = m.astype(np.float32) / 255.0
105
  elif m.dtype != np.float32:
106
  m = m.astype(np.float32)
107
+
108
+ return np.ascontiguousarray(np.clip(m, 0.0, 1.0))
109
 
110
  def _feather(mask01: np.ndarray, k: int = 2) -> np.ndarray:
111
  if mask01.ndim == 3:
 
145
  return _vertical_gradient(dark, color, width, height)
146
 
147
  # ----------------------------------------------------------------------------
148
+ # Improved Segmentation (expects RGB input)
149
  # ----------------------------------------------------------------------------
150
+ def _simple_person_segmentation(frame_rgb: np.ndarray) -> np.ndarray:
151
+ """Basic fallback segmentation using color detection on RGB frames."""
152
+ h, w = frame_rgb.shape[:2]
153
+ hsv = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2HSV)
154
 
155
  lower_skin = np.array([0, 20, 70], dtype=np.uint8)
156
  upper_skin = np.array([20, 255, 255], dtype=np.uint8)
157
  skin_mask = cv2.inRange(hsv, lower_skin, upper_skin)
158
 
159
+ # detect greenscreen-ish
160
  lower_green = np.array([40, 40, 40], dtype=np.uint8)
161
  upper_green = np.array([80, 255, 255], dtype=np.uint8)
162
  green_mask = cv2.inRange(hsv, lower_green, upper_green)
 
184
  **_compat_kwargs,
185
  ) -> np.ndarray:
186
  """
187
+ High-quality person segmentation with proper SAM2 handling.
188
+ Expects RGB frame (H,W,3), uint8 or float in [0,1].
189
  """
190
+ frame_rgb = _ensure_rgb(frame)
191
+ h, w = frame_rgb.shape[:2]
192
 
193
  if use_sam2 is False:
194
+ return _simple_person_segmentation(frame_rgb)
195
 
196
  if predictor is not None:
197
  try:
198
  if hasattr(predictor, "set_image") and hasattr(predictor, "predict"):
199
+ # Predictor adapter expects RGB uint8; convert if needed
200
+ if frame_rgb.dtype != np.uint8:
201
+ rgb_u8 = np.clip(frame_rgb * (255.0 if frame_rgb.dtype != np.uint8 else 1.0), 0, 255).astype(np.uint8) \
202
+ if np.issubdtype(frame_rgb.dtype, np.floating) else frame_rgb.astype(np.uint8)
203
+ else:
204
+ rgb_u8 = frame_rgb
205
 
206
+ predictor.set_image(rgb_u8)
 
 
207
 
208
+ # Center + a couple of body-biased prompts
209
+ points = np.array([
210
+ [w // 2, h // 2],
211
+ [w // 2, h // 4],
212
+ [w // 2, h // 2 + h // 8],
213
+ ], dtype=np.float32)
214
+ labels = np.array([1, 1, 1], dtype=np.int32)
215
 
216
  result = predictor.predict(
217
+ point_coords=points,
218
+ point_labels=labels,
219
  multimask_output=True
220
  )
221
 
222
+ # normalize outputs
223
  if isinstance(result, dict):
224
  masks = result.get("masks", None)
225
  scores = result.get("scores", None)
226
+ elif isinstance(result, (tuple, list)) and len(result) >= 2:
227
  masks, scores = result[0], result[1]
228
  else:
229
+ masks, scores = result, None
 
230
 
231
  if masks is not None:
232
+ masks = np.asarray(masks)
233
+ if masks.ndim == 2:
234
+ mask = masks
235
+ elif masks.ndim == 3 and masks.shape[0] > 0:
236
+ if scores is not None:
237
+ best_idx = int(np.argmax(np.asarray(scores)))
238
+ mask = masks[best_idx]
239
+ else:
240
+ mask = masks[0]
241
+ elif masks.ndim == 4 and masks.shape[1] == 1:
242
+ # (N,1,H,W)
243
+ if scores is not None:
244
+ best_idx = int(np.argmax(np.asarray(scores)))
245
+ mask = masks[best_idx, 0]
246
+ else:
247
+ mask = masks[0, 0]
248
+ else:
249
+ logger.warning(f"Unexpected mask shape from SAM2: {masks.shape}")
250
+ mask = None
251
+
252
+ if mask is not None:
253
+ mask = _to_mask01(mask)
254
+ if float(mask.max()) > 0.1:
255
+ return np.ascontiguousarray(mask)
256
  else:
257
+ logger.warning("SAM2 mask too weak, using fallback")
 
 
 
 
 
 
 
 
258
  else:
259
  logger.warning("SAM2 returned no masks")
260
 
 
263
 
264
  if fallback_enabled:
265
  logger.debug("Using fallback segmentation")
266
+ return _simple_person_segmentation(frame_rgb)
267
  else:
268
  return np.ones((h, w), dtype=np.float32)
269
 
 
301
 
302
  if matanyone is not None and callable(matanyone):
303
  try:
304
+ rgb01 = _ensure_rgb01(frame) # RGB float32 in [0,1]
305
 
306
  # Stateful path (preferred)
307
  if frame_idx is not None:
 
310
  else:
311
  refined = matanyone(rgb01) # propagate without mask
312
  refined = _mask_to_2d(refined)
313
+ if float(refined.max()) > 0.1:
314
  return _postprocess_mask(refined)
315
  logger.warning("MatAnyone stateful refinement produced empty/weak mask; falling back.")
316
 
 
340
  except Exception as e:
341
  logger.debug(f"MatAnyone process failed: {e}")
342
 
343
+ if refined is not None and float(refined.max()) > 0.1:
344
  return _postprocess_mask(refined)
345
  else:
346
  logger.warning("MatAnyone refinement failed or produced empty mask")
 
356
 
357
  def _postprocess_mask(mask01: np.ndarray) -> np.ndarray:
358
  """Post-process mask to clean edges and remove artifacts"""
359
+ mask_uint8 = (np.clip(mask01, 0, 1) * 255).astype(np.uint8)
360
 
361
  kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
362
  mask_uint8 = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel_close)
 
367
 
368
  mask_uint8 = cv2.GaussianBlur(mask_uint8, (5, 5), 1)
369
 
370
+ out = mask_uint8.astype(np.float32) / 255.0
371
+ return np.ascontiguousarray(out)
372
 
373
  def _fallback_refine(mask01: np.ndarray) -> np.ndarray:
374
  """Simple fallback refinement"""
375
+ mask_uint8 = (np.clip(mask01, 0, 1) * 255).astype(np.uint8)
376
 
377
  mask_uint8 = cv2.bilateralFilter(mask_uint8, 9, 75, 75)
378
 
 
382
 
383
  mask_uint8 = cv2.GaussianBlur(mask_uint8, (5, 5), 1)
384
 
385
+ out = mask_uint8.astype(np.float32) / 255.0
386
+ return np.ascontiguousarray(out)
387
 
388
  # ----------------------------------------------------------------------------
389
+ # Compositing (expects RGB inputs)
390
  # ----------------------------------------------------------------------------
391
  def replace_background_hq(
392
  frame: np.ndarray,
 
395
  fallback_enabled: bool = True,
396
  **_compat,
397
  ) -> np.ndarray:
398
+ """High-quality background replacement with alpha blending (RGB in/out)."""
399
  try:
400
  H, W = frame.shape[:2]
401