MogensR committed
Commit f0dc2a6 · 1 Parent(s): a8c6577

Update utils/__init__.py

Files changed (1)
  1. utils/__init__.py +437 -0
utils/__init__.py CHANGED
@@ -0,0 +1,437 @@
+ """
+ Complete utils/__init__.py with all required functions
+ Device-safe, SAM2↔MatAnyOne interop, and compositing helpers.
+ """
+
+ from __future__ import annotations
+
+ import os
+ import logging
+ import tempfile
+ from typing import Optional, Tuple, Dict, Any, List, Iterable, Callable
+
+ import cv2
+ import numpy as np
+ from PIL import Image
+ import torch
+
+ # NEW: interop + bridge imports (add these files from the previous steps)
+ from utils.interop import ensure_image_nchw, ensure_mask_for_matanyone, log_shape
+ from utils.mask_bridge import sam2_to_matanyone_mask
+
+ logger = logging.getLogger(__name__)
+
+ # Professional backgrounds configuration
+ PROFESSIONAL_BACKGROUNDS = {
+     "office": {"color": (240, 248, 255), "gradient": True},
+     "studio": {"color": (32, 32, 32), "gradient": False},
+     "nature": {"color": (34, 139, 34), "gradient": True},
+     "abstract": {"color": (75, 0, 130), "gradient": True},
+     "white": {"color": (255, 255, 255), "gradient": False},
+     "black": {"color": (0, 0, 0), "gradient": False},
+ }
+
+ # -------------------------------
+ # Utility: device
+ # -------------------------------
+ def _default_device() -> str:
+     return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ # -------------------------------
+ # Video validation
+ # -------------------------------
+ def validate_video_file(video_path: str) -> bool:
+     """Validate that a video file exists and is readable."""
+     try:
+         if not os.path.exists(video_path):
+             return False
+         cap = cv2.VideoCapture(video_path)
+         if not cap.isOpened():
+             return False
+         ret, frame = cap.read()
+         cap.release()
+         return ret and frame is not None
+     except Exception as e:
+         logger.error(f"Video validation failed: {e}")
+         return False
+
+
+ # -------------------------------
+ # SAM2 person segmentation (first-frame bootstrapping)
+ # -------------------------------
+ def segment_person_hq(
+     frame_rgb: np.ndarray,
+     *,
+     use_sam2: bool = True,
+     sam2_predictor: Any = None,  # prefer injecting a ready predictor (from your ModelLoader)
+ ) -> Optional[np.ndarray]:
+     """
+     High-quality person segmentation for a single RGB frame.
+     Returns a float mask HxW in [0,1], or None on failure.
+
+     Preferred path: pass a ready-made SAM2 predictor (e.g., SAM2ImagePredictor).
+     Fallback path: simple color-based segmentation.
+     """
+     try:
+         if use_sam2 and sam2_predictor is not None:
+             try:
+                 # SAM2 official predictors accept RGB np.uint8; set + predict.
+                 # We use a simple center-point prompt; adapt to your UX if needed.
+                 if hasattr(sam2_predictor, "set_image"):
+                     sam2_predictor.set_image(frame_rgb)
+
+                 h, w = frame_rgb.shape[:2]
+                 center_point = np.array([[w // 2, h // 2]])
+                 center_label = np.array([1])
+
+                 # Try the SAM2 "predict" API (Meta's predictor style)
+                 if hasattr(sam2_predictor, "predict"):
+                     out = sam2_predictor.predict(
+                         point_coords=center_point,
+                         point_labels=center_label,
+                         multimask_output=True,
+                     )
+                     # Known Meta API returns (masks, scores, logits) as numpy
+                     if isinstance(out, (list, tuple)) and len(out) >= 1:
+                         masks = out[0]
+                         if masks is None or len(masks) == 0:
+                             return None
+                         # masks: (M,H,W); pick best by area
+                         areas = masks.reshape(masks.shape[0], -1).sum(axis=1)
+                         best = int(np.argmax(areas))
+                         m = masks[best].astype(np.float32)
+                         m = (m >= 0.5).astype(np.float32)
+                         return m
+
+                 # Some wrappers expose processor/post_process; if you use that, call separately
+                 logger.warning("SAM2 predictor provided but unknown API; falling back to simple segmentation")
+             except Exception as e:
+                 logger.warning(f"SAM2 segmentation failed: {e}; falling back to simple method")
+
+         # Fallback: color-based person segmentation
+         return _simple_person_segmentation(frame_rgb)
+     except Exception as e:
+         logger.error(f"Person segmentation failed: {e}")
+         return None
+
+
+ def _simple_person_segmentation(frame_rgb: np.ndarray) -> np.ndarray:
+     """Simple person segmentation using color-based methods"""
+     hsv = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2HSV)
+     # Green screen detection
+     lower_green = np.array([40, 40, 40])
+     upper_green = np.array([80, 255, 255])
+     green_mask = cv2.inRange(hsv, lower_green, upper_green)
+     # White background detection
+     lower_white = np.array([0, 0, 200])
+     upper_white = np.array([180, 30, 255])
+     white_mask = cv2.inRange(hsv, lower_white, upper_white)
+     # Combine + invert to person
+     bg_mask = cv2.bitwise_or(green_mask, white_mask)
+     person_mask = cv2.bitwise_not(bg_mask)
+     # Morph clean
+     kernel = np.ones((5, 5), np.uint8)
+     person_mask = cv2.morphologyEx(person_mask, cv2.MORPH_CLOSE, kernel)
+     person_mask = cv2.morphologyEx(person_mask, cv2.MORPH_OPEN, kernel)
+     return (person_mask.astype(np.float32) / 255.0)
+
+
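+ # Minimal usage sketch (assumes `predictor` is a SAM2 image predictor created
+ # elsewhere, e.g. by the app's ModelLoader, and `frame0_rgb` is an HxWx3 uint8 RGB frame):
+ #
+ #     mask0 = segment_person_hq(frame0_rgb, use_sam2=True, sam2_predictor=predictor)
+ #     if mask0 is not None:
+ #         logger.info("first-frame mask coverage: %.1f%%", 100.0 * float(mask0.mean()))
+
+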
+ # -------------------------------
+ # MatAnyOne integration (first-frame + per-frame)
+ # -------------------------------
+ def refine_mask_hq(
+     mask_hw_float01: np.ndarray,
+     frame_rgb: np.ndarray,
+     *,
+     use_matanyone: bool = True,
+     mat_core: Any = None,  # prefer injecting a ready InferenceCore from ModelLoader
+     first_frame: bool = True,
+     device: str | None = None,
+ ) -> np.ndarray:
+     """
+     High-quality mask refinement for a single frame + mask pair using MatAnyOne.
+     Returns refined mask HxW float in [0,1]. If use_matanyone=False or mat_core is None,
+     falls back to simple refinement.
+
+     NOTE: For videos, prefer using the seed/refine helpers below that keep temporal memory.
+     """
+     try:
+         if not use_matanyone or mat_core is None:
+             return _simple_mask_refinement(mask_hw_float01, frame_rgb)
+
+         device = device or _default_device()
+
+         # Image → (1,3,H,W)
+         img_nchw = ensure_image_nchw(torch.from_numpy(frame_rgb).to(device), device=device, want_batched=True)
+         log_shape("refine.image_nchw", img_nchw)
+
+         # Mask → (1,H,W)
+         mask_t = torch.from_numpy(mask_hw_float01).to(device)
+         mask_c_hw = ensure_mask_for_matanyone(mask_t, idx_mask=False, threshold=0.5, keep_soft=False, device=device)
+         log_shape("refine.mask_c_hw", mask_c_hw)
+
+         # MatAnyOne step (we let the global guard in ModelLoader do additional checks)
+         pred = mat_core.step(
+             image=img_nchw[0],  # CHW
+             mask=mask_c_hw if first_frame else None,
+             idx_mask=False,
+             matting=True,
+             first_frame_pred=bool(first_frame),
+         )
+
+         # Try to decode output into an alpha HxW float mask
+         refined = _coerce_pred_to_mask(pred, device=device)
+         if refined is None:
+             # If the core doesn't return alpha directly, fall back
+             return _simple_mask_refinement(mask_hw_float01, frame_rgb)
+
+         return refined
+     except Exception as e:
+         logger.warning(f"MatAnyOne refinement failed: {e}; using simple refinement")
+         return _simple_mask_refinement(mask_hw_float01, frame_rgb)
+
+
+ def _coerce_pred_to_mask(pred: Any, device: str = "cuda") -> Optional[np.ndarray]:
+     """
+     Best-effort: extract HxW float mask from MatAnyOne output variants.
+     Supports torch.Tensor, numpy, PIL, or dict with common keys.
+     """
+     try:
+         # Dict-like: look for common keys
+         if isinstance(pred, dict):
+             for k in ("alpha", "mask", "matte", "mattes"):
+                 if k in pred:
+                     v = pred[k]
+                     return _coerce_pred_to_mask(v, device=device)
+
+         # Torch tensor
+         if torch.is_tensor(pred):
+             t = pred.detach()
+             # possible shapes: (H,W), (1,H,W), (N,1,H,W)
+             if t.ndim == 4 and t.shape[1] == 1:
+                 t = t[0, 0]
+             elif t.ndim == 3 and t.shape[0] == 1:
+                 t = t[0]
+             t = t.float().clamp(0, 1).to("cpu").numpy()
+             if t.ndim == 2:
+                 return t.astype(np.float32)
+
+         # Numpy
+         if isinstance(pred, np.ndarray):
+             a = pred
+             if a.ndim == 3 and a.shape[0] == 1:
+                 a = a[0]
+             if a.ndim == 2:
+                 a = a.astype(np.float32)
+                 if a.max() > 1.0:
+                     a = a / 255.0
+                 return np.clip(a, 0.0, 1.0)
+
+         # PIL Image
+         if isinstance(pred, Image.Image):
+             a = np.array(pred).astype(np.float32)
+             if a.ndim == 3 and a.shape[2] == 1:
+                 a = a[:, :, 0]
+             if a.ndim == 2:
+                 if a.max() > 1.0:
+                     a = a / 255.0
+                 return np.clip(a, 0.0, 1.0)
+
+     except Exception as e:
+         logger.debug(f"_coerce_pred_to_mask fallback due to: {e}")
+     return None
+
+
+ def _simple_mask_refinement(mask: np.ndarray, frame_rgb: np.ndarray) -> np.ndarray:
+     """Simple mask refinement using OpenCV operations"""
+     mask_uint8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
+     mask_blurred = cv2.GaussianBlur(mask_uint8, (5, 5), 0)
+     mask_refined = cv2.bilateralFilter(mask_blurred, 9, 75, 75)
+     return (mask_refined.astype(np.float32) / 255.0)
+
+
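+ # Minimal single-frame sketch (assumes `mat_core` is a MatAnyOne InferenceCore
+ # injected from the ModelLoader; for videos, prefer the helpers below):
+ #
+ #     alpha0 = refine_mask_hq(mask0, frame0_rgb, use_matanyone=True,
+ #                             mat_core=mat_core, first_frame=True)
+
+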
+ # -------------------------------
+ # Two-stage video helpers (seed + propagate)
+ # -------------------------------
+ @torch.inference_mode()
+ def seed_with_sam2_post_masks(
+     core: Any,
+     frame0_rgb: np.ndarray,  # HxWx3 uint8 RGB
+     sam2_post_masks: torch.Tensor,  # (1,M,H,W)
+     iou_scores: Optional[torch.Tensor] = None,
+     *,
+     device: str | None = None,
+     idx_mask: bool = False,
+     threshold: float = 0.5,
+     keep_soft: bool = False,
+ ) -> Any:
+     """
+     Seed MatAnyOne on the first frame using SAM2 post-processed masks (preferred).
+     """
+     device = device or _default_device()
+     img0 = ensure_image_nchw(torch.from_numpy(frame0_rgb).to(device), device=device, want_batched=True)
+     log_shape("seed.image_nchw", img0)
+
+     if idx_mask:
+         m_c_hw = sam2_to_matanyone_mask(sam2_post_masks.to(device), iou_scores, threshold, "single", keep_soft=False)
+         idx_hw = ensure_mask_for_matanyone(m_c_hw, idx_mask=True, device=device, threshold=threshold)
+         log_shape("seed.idx_hw", idx_hw)
+         return core.step(
+             image=img0[0],
+             mask=idx_hw,
+             idx_mask=True,
+             matting=True,
+             first_frame_pred=True,
+         )
+     else:
+         m_c_hw = sam2_to_matanyone_mask(sam2_post_masks.to(device), iou_scores, threshold, "single", keep_soft=keep_soft)
+         log_shape("seed.mask_c_hw", m_c_hw)
+         return core.step(
+             image=img0[0],
+             mask=m_c_hw,
+             idx_mask=False,
+             matting=True,
+             first_frame_pred=True,
+         )
+
+
+ @torch.inference_mode()
+ def refine_next_frame(core: Any, frame_rgb: np.ndarray, *, device: str | None = None) -> Any:
+     """Step MatAnyOne forward on a subsequent frame (no mask; uses memory)."""
+     device = device or _default_device()
+     img = ensure_image_nchw(torch.from_numpy(frame_rgb).to(device), device=device, want_batched=True)
+     log_shape("refine.image_nchw", img)
+     return core.step(
+         image=img[0],
+         mask=None,
+         idx_mask=False,
+         matting=True,
+         first_frame_pred=False,
+     )
+
+
+ @torch.inference_mode()
+ def run_two_stage_matting(
+     core: Any,
+     frames_rgb_iter: Iterable[np.ndarray],  # iterable of HxWx3 uint8 RGB
+     sam2_post_masks: torch.Tensor,  # (1,M,H,W) for the first frame
+     iou_scores: Optional[torch.Tensor] = None,
+     *,
+     device: str | None = None,
+     on_pred: Optional[Callable[[int, Any], None]] = None,
+     progress: Optional[Callable[[int, Optional[int]], None]] = None,
+     total_frames: Optional[int] = None,
+     idx_mask: bool = False,
+     threshold: float = 0.5,
+     keep_soft: bool = False,
+ ) -> None:
+     """
+     Convenience runner for videos:
+       - Seeds on the first frame using SAM2 post-process outputs
+       - Propagates across the rest (one frame per step)
+     """
+     device = device or _default_device()
+     it = iter(frames_rgb_iter)
+     try:
+         f0 = next(it)
+     except StopIteration:
+         return
+
+     pred0 = seed_with_sam2_post_masks(
+         core, f0, sam2_post_masks, iou_scores,
+         device=device, idx_mask=idx_mask, threshold=threshold, keep_soft=keep_soft,
+     )
+     if on_pred:
+         on_pred(0, pred0)
+     if progress:
+         progress(1, total_frames)
+
+     t = 1
+     for frgb in it:
+         pred = refine_next_frame(core, frgb, device=device)
+         if on_pred:
+             on_pred(t, pred)
+         t += 1
+         if progress:
+             progress(t, total_frames)
+
+
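+ # Minimal video sketch (assumes `mat_core` and the first-frame SAM2 outputs
+ # `sam2_post_masks`/`iou_scores` come from the surrounding app; `read_frames`
+ # is a hypothetical helper shown only for illustration):
+ #
+ #     def read_frames(path):
+ #         cap = cv2.VideoCapture(path)
+ #         while True:
+ #             ok, bgr = cap.read()
+ #             if not ok:
+ #                 break
+ #             yield cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
+ #         cap.release()
+ #
+ #     alphas: List[np.ndarray] = []
+ #     run_two_stage_matting(
+ #         mat_core, read_frames("input.mp4"), sam2_post_masks, iou_scores,
+ #         on_pred=lambda i, pred: alphas.append(_coerce_pred_to_mask(pred)),
+ #     )
+
+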
+ # -------------------------------
+ # Background replacement
+ # -------------------------------
+ def replace_background_hq(frame_rgb: np.ndarray, mask_hw_float01: np.ndarray, background_rgb: np.ndarray) -> np.ndarray:
+     """High-quality background replacement with proper compositing"""
+     try:
+         h, w = frame_rgb.shape[:2]
+         background_resized = cv2.resize(background_rgb, (w, h))
+
+         # Ensure mask is HxW float in [0,1]
+         if mask_hw_float01.ndim == 3:
+             mask_hw_float01 = mask_hw_float01[..., 0]
+         m = np.clip(mask_hw_float01.astype(np.float32), 0.0, 1.0)
+
+         # Feather edges lightly
+         m_uint8 = (m * 255).astype(np.uint8)
+         m_feather = cv2.GaussianBlur(m_uint8, (7, 7), 0).astype(np.float32) / 255.0
+         m3 = np.stack([m_feather] * 3, axis=-1)
+
+         result = frame_rgb.astype(np.float32) * m3 + background_resized.astype(np.float32) * (1.0 - m3)
+         return np.clip(result, 0, 255).astype(np.uint8)
+     except Exception as e:
+         logger.error(f"Background replacement failed: {e}")
+         return frame_rgb
+
+
+ # -------------------------------
+ # Background generators
+ # -------------------------------
+ def create_professional_background(bg_type: str, width: int, height: int) -> np.ndarray:
+     """Create professional background of specified type and size"""
+     try:
+         if bg_type not in PROFESSIONAL_BACKGROUNDS:
+             bg_type = "office"  # Default fallback
+
+         config = PROFESSIONAL_BACKGROUNDS[bg_type]
+         color = config["color"]
+         use_gradient = config["gradient"]
+
+         if use_gradient:
+             background = _create_gradient_background(color, width, height)
+         else:
+             background = np.full((height, width, 3), color, dtype=np.uint8)
+
+         return background
+     except Exception as e:
+         logger.error(f"Background creation failed: {e}")
+         return np.full((height, width, 3), (255, 255, 255), dtype=np.uint8)
+
+
+ def _create_gradient_background(base_color: Tuple[int, int, int], width: int, height: int) -> np.ndarray:
+     """Create a vertical gradient background from base color"""
+     r, g, b = base_color
+     dark = (int(r * 0.7), int(g * 0.7), int(b * 0.7))
+     bg = np.zeros((height, width, 3), dtype=np.uint8)
+     for y in range(height):
+         t = y / max(height, 1)
+         bg[y, :] = [
+             int(dark[0] * (1 - t) + r * t),
+             int(dark[1] * (1 - t) + g * t),
+             int(dark[2] * (1 - t) + b * t),
+         ]
+     return bg
+
+
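+ # Minimal compositing sketch (assumes `alpha` is an HxW float mask from the
+ # matting stage and `frame_rgb` is the matching RGB frame):
+ #
+ #     bg = create_professional_background("studio", frame_rgb.shape[1], frame_rgb.shape[0])
+ #     composited = replace_background_hq(frame_rgb, alpha, bg)
+
+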
+ # -------------------------------
+ # Exports
+ # -------------------------------
+ __all__ = [
+     # segment / refine (single-frame)
+     "segment_person_hq",
+     "refine_mask_hq",
+     # video runner + steps
+     "seed_with_sam2_post_masks",
+     "refine_next_frame",
+     "run_two_stage_matting",
+     # backgrounds & utils
+     "replace_background_hq",
+     "create_professional_background",
+     "PROFESSIONAL_BACKGROUNDS",
+     "validate_video_file",
+ ]