MogensR committed
Commit 6135232 · 1 Parent(s): 8ca115b

Update utils/__init__.py

Files changed (1)
  1. utils/__init__.py +419 -425
utils/__init__.py CHANGED
@@ -1,437 +1,431 @@
 
  """
- Complete utils/__init__.py with all required functions
- Device-safe, SAM2↔MatAnyOne interop, and compositing helpers.
  """

- from __future__ import annotations
-
- import os
- import logging
- import tempfile
- from typing import Optional, Tuple, Dict, Any, List, Iterable, Callable
-
- import cv2
  import numpy as np
  from PIL import Image
- import torch
-
- # NEW: interop + bridge imports (add these files from the previous steps)
- from utils.interop import ensure_image_nchw, ensure_mask_for_matanyone, log_shape
- from utils.mask_bridge import sam2_to_matanyone_mask
-
- logger = logging.getLogger(__name__)
-
- # Professional backgrounds configuration
- PROFESSIONAL_BACKGROUNDS = {
-     "office": {"color": (240, 248, 255), "gradient": True},
-     "studio": {"color": (32, 32, 32), "gradient": False},
-     "nature": {"color": (34, 139, 34), "gradient": True},
-     "abstract": {"color": (75, 0, 130), "gradient": True},
-     "white": {"color": (255, 255, 255), "gradient": False},
-     "black": {"color": (0, 0, 0), "gradient": False},
- }
-
- # -------------------------------
- # Utility: device
- # -------------------------------
- def _default_device() -> str:
-     return "cuda" if torch.cuda.is_available() else "cpu"
-
-
- # -------------------------------
- # Video validation
- # -------------------------------
- def validate_video_file(video_path: str) -> bool:
-     """Validate if video file is readable"""
-     try:
-         if not os.path.exists(video_path):
-             return False
-         cap = cv2.VideoCapture(video_path)
-         if not cap.isOpened():
-             return False
-         ret, frame = cap.read()
-         cap.release()
-         return ret and frame is not None
-     except Exception as e:
-         logger.error(f"Video validation failed: {e}")
-         return False
-
-
- # -------------------------------
- # SAM2 person segmentation (first-frame bootstrapping)
- # -------------------------------
- def segment_person_hq(
-     frame_rgb: np.ndarray,
-     *,
-     use_sam2: bool = True,
-     sam2_predictor: Any = None,  # prefer injecting a ready predictor (from your ModelLoader)
- ) -> Optional[np.ndarray]:
-     """
-     High-quality person segmentation for a single RGB frame.
-     Returns a float mask HxW in [0,1], or None on failure.
-
-     Preferred path: pass a ready-made SAM2 predictor (e.g., SAM2ImagePredictor).
-     Fallback path: simple color-based segmentation.
-     """
-     try:
-         if use_sam2 and sam2_predictor is not None:
-             try:
-                 # SAM2 official predictors accept RGB np.uint8; set + predict.
-                 # We use a simple center-point prompt; adapt to your UX if needed.
-                 if hasattr(sam2_predictor, "set_image"):
-                     sam2_predictor.set_image(frame_rgb)
-
-                 h, w = frame_rgb.shape[:2]
-                 center_point = np.array([[w // 2, h // 2]])
-                 center_label = np.array([1])
-
-                 # Try the SAM2 "predict" API (Meta's predictor style)
-                 if hasattr(sam2_predictor, "predict"):
-                     out = sam2_predictor.predict(
-                         point_coords=center_point,
-                         point_labels=center_label,
-                         multimask_output=True,
-                     )
-                     # Known Meta API returns (masks, scores, logits) as numpy
-                     if isinstance(out, (list, tuple)) and len(out) >= 1:
-                         masks = out[0]
-                         if masks is None or len(masks) == 0:
-                             return None
-                         # masks: (M,H,W); pick best by area
-                         areas = masks.reshape(masks.shape[0], -1).sum(axis=1)
-                         best = int(np.argmax(areas))
-                         m = masks[best].astype(np.float32)
-                         m = (m >= 0.5).astype(np.float32)
-                         return m
-
-                 # Some wrappers expose processor/post_process; if you use that, call separately
-                 logger.warning("SAM2 predictor provided but unknown API; falling back to simple segmentation")
-             except Exception as e:
-                 logger.warning(f"SAM2 segmentation failed: {e}; falling back to simple method")
-
-         # Fallback: color-based person segmentation
-         return _simple_person_segmentation(frame_rgb)
-     except Exception as e:
-         logger.error(f"Person segmentation failed: {e}")
-         return None
-
-
- def _simple_person_segmentation(frame_rgb: np.ndarray) -> np.ndarray:
-     """Simple person segmentation using color-based methods"""
-     hsv = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2HSV)
-     # Green screen detection
-     lower_green = np.array([40, 40, 40])
-     upper_green = np.array([80, 255, 255])
-     green_mask = cv2.inRange(hsv, lower_green, upper_green)
-     # White background detection
-     lower_white = np.array([0, 0, 200])
-     upper_white = np.array([180, 30, 255])
-     white_mask = cv2.inRange(hsv, lower_white, upper_white)
-     # Combine + invert to person
-     bg_mask = cv2.bitwise_or(green_mask, white_mask)
-     person_mask = cv2.bitwise_not(bg_mask)
-     # Morph clean
-     kernel = np.ones((5, 5), np.uint8)
-     person_mask = cv2.morphologyEx(person_mask, cv2.MORPH_CLOSE, kernel)
-     person_mask = cv2.morphologyEx(person_mask, cv2.MORPH_OPEN, kernel)
-     return (person_mask.astype(np.float32) / 255.0)
-
-
- # -------------------------------
- # MatAnyOne integration (first-frame + per-frame)
- # -------------------------------
- def refine_mask_hq(
-     mask_hw_float01: np.ndarray,
-     frame_rgb: np.ndarray,
-     *,
-     use_matanyone: bool = True,
-     mat_core: Any = None,  # prefer injecting a ready InferenceCore from ModelLoader
-     first_frame: bool = True,
-     device: str | None = None,
- ) -> np.ndarray:
-     """
-     High-quality mask refinement for a single frame + mask pair using MatAnyOne.
-     Returns refined mask HxW float in [0,1]. If use_matanyone=False or mat_core is None,
-     falls back to simple refinement.
-
-     NOTE: For videos, prefer using seed/refine helpers below that keep temporal memory.
-     """
-     try:
-         if not use_matanyone or mat_core is None:
-             return _simple_mask_refinement(mask_hw_float01, frame_rgb)
-
-         device = device or _default_device()
-
-         # Image → (1,3,H,W)
-         img_nchw = ensure_image_nchw(torch.from_numpy(frame_rgb).to(device), device=device, want_batched=True)
-         log_shape("refine.image_nchw", img_nchw)
-
-         # Mask → (1,H,W)
-         mask_t = torch.from_numpy(mask_hw_float01).to(device)
-         mask_c_hw = ensure_mask_for_matanyone(mask_t, idx_mask=False, threshold=0.5, keep_soft=False, device=device)
-         log_shape("refine.mask_c_hw", mask_c_hw)
-
-         # MatAnyOne step (we let the global guard in ModelLoader do additional checks)
-         pred = mat_core.step(
-             image=img_nchw[0],  # CHW
-             mask=mask_c_hw if first_frame else None,
-             idx_mask=False,
-             matting=True,
-             first_frame_pred=bool(first_frame),
-         )
-
-         # Try to decode output into an alpha HxW float mask
-         refined = _coerce_pred_to_mask(pred, device=device)
-         if refined is None:
-             # If the core doesn't return alpha directly, fall back
-             return _simple_mask_refinement(mask_hw_float01, frame_rgb)
-
-         return refined
-     except Exception as e:
-         logger.warning(f"MatAnyOne refinement failed: {e}; using simple refinement")
-         return _simple_mask_refinement(mask_hw_float01, frame_rgb)
-
-
- def _coerce_pred_to_mask(pred: Any, device: str = "cuda") -> Optional[np.ndarray]:
-     """
-     Best-effort: extract HxW float mask from MatAnyOne output variants.
-     Supports torch.Tensor, numpy, PIL, or dict with common keys.
-     """
-     try:
-         # Dict-like: look for common keys
-         if isinstance(pred, dict):
-             for k in ("alpha", "mask", "matte", "mattes"):
-                 if k in pred:
-                     v = pred[k]
-                     return _coerce_pred_to_mask(v, device=device)
-
-         # Torch tensor
-         if torch.is_tensor(pred):
-             t = pred.detach()
-             # possible shapes: (H,W), (1,H,W), (N,1,H,W)
-             if t.ndim == 4 and t.shape[1] == 1:
-                 t = t[0, 0]
-             elif t.ndim == 3 and t.shape[0] == 1:
-                 t = t[0]
-             t = t.float().clamp(0, 1).to("cpu").numpy()
-             if t.ndim == 2:
-                 return t.astype(np.float32)
-
-         # Numpy
-         if isinstance(pred, np.ndarray):
-             a = pred
-             if a.ndim == 3 and a.shape[0] == 1:
-                 a = a[0]
-             if a.ndim == 2:
-                 a = a.astype(np.float32)
-                 if a.max() > 1.0:
-                     a = a / 255.0
-                 return np.clip(a, 0.0, 1.0)
-
-         # PIL Image
-         if isinstance(pred, Image.Image):
-             a = np.array(pred).astype(np.float32)
-             if a.ndim == 3 and a.shape[2] == 1:
-                 a = a[:, :, 0]
-             if a.ndim == 2:
-                 if a.max() > 1.0:
-                     a = a / 255.0
-                 return np.clip(a, 0.0, 1.0)
-
-     except Exception as e:
-         logger.debug(f"_coerce_pred_to_mask fallback due to: {e}")
-     return None
-
-
- def _simple_mask_refinement(mask: np.ndarray, frame_rgb: np.ndarray) -> np.ndarray:
-     """Simple mask refinement using OpenCV operations"""
-     mask_uint8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
-     mask_blurred = cv2.GaussianBlur(mask_uint8, (5, 5), 0)
-     mask_refined = cv2.bilateralFilter(mask_blurred, 9, 75, 75)
-     return (mask_refined.astype(np.float32) / 255.0)
-
-
- # -------------------------------
- # Two-stage video helpers (seed + propagate)
- # -------------------------------
- @torch.inference_mode()
- def seed_with_sam2_post_masks(
-     core: Any,
-     frame0_rgb: np.ndarray,            # HxWx3 uint8 RGB
-     sam2_post_masks: torch.Tensor,     # (1,M,H,W)
-     iou_scores: Optional[torch.Tensor] = None,
-     *,
-     device: str | None = None,
-     idx_mask: bool = False,
-     threshold: float = 0.5,
-     keep_soft: bool = False,
- ) -> Any:
-     """
-     Seed MatAnyOne on the first frame using SAM2 post-processed masks (preferred).
-     """
-     device = device or _default_device()
-     img0 = ensure_image_nchw(torch.from_numpy(frame0_rgb).to(device), device=device, want_batched=True)
-     log_shape("seed.image_nchw", img0)
-
-     if idx_mask:
-         m_c_hw = sam2_to_matanyone_mask(sam2_post_masks.to(device), iou_scores, threshold, "single", keep_soft=False)
-         idx_hw = ensure_mask_for_matanyone(m_c_hw, idx_mask=True, device=device, threshold=threshold)
-         log_shape("seed.idx_hw", idx_hw)
-         return core.step(
-             image=img0[0],
-             mask=idx_hw,
-             idx_mask=True,
-             matting=True,
-             first_frame_pred=True,
          )
-     else:
-         m_c_hw = sam2_to_matanyone_mask(sam2_post_masks.to(device), iou_scores, threshold, "single", keep_soft=keep_soft)
-         log_shape("seed.mask_c_hw", m_c_hw)
-         return core.step(
-             image=img0[0],
-             mask=m_c_hw,
-             idx_mask=False,
-             matting=True,
-             first_frame_pred=True,
  )

- @torch.inference_mode()
- def refine_next_frame(core: Any, frame_rgb: np.ndarray, *, device: str | None = None) -> Any:
-     """Step MatAnyOne forward on a subsequent frame (no mask; uses memory)."""
-     device = device or _default_device()
-     img = ensure_image_nchw(torch.from_numpy(frame_rgb).to(device), device=device, want_batched=True)
-     log_shape("refine.image_nchw", img)
-     return core.step(
-         image=img[0],
-         mask=None,
-         idx_mask=False,
-         matting=True,
-         first_frame_pred=False,
-     )
-

- @torch.inference_mode()
- def run_two_stage_matting(
-     core: Any,
-     frames_rgb_iter: Iterable[np.ndarray],    # iterable of HxWx3 uint8 RGB
-     sam2_post_masks: torch.Tensor,            # (1,M,H,W) for the first frame
-     iou_scores: Optional[torch.Tensor] = None,
-     *,
-     device: str | None = None,
-     on_pred: Optional[Callable[[int, Any], None]] = None,
-     progress: Optional[Callable[[int, Optional[int]], None]] = None,
-     total_frames: Optional[int] = None,
-     idx_mask: bool = False,
-     threshold: float = 0.5,
-     keep_soft: bool = False,
- ) -> None:
-     """
-     Convenience runner for videos:
-       - Seeds on the first frame using SAM2 post-process outputs
-       - Propagates across the rest (one frame per step)
-     """
-     device = device or _default_device()
-     it = iter(frames_rgb_iter)
-     try:
-         f0 = next(it)
-     except StopIteration:
-         return
-
-     pred0 = seed_with_sam2_post_masks(
-         core, f0, sam2_post_masks, iou_scores,
-         device=device, idx_mask=idx_mask, threshold=threshold, keep_soft=keep_soft
      )
-     if on_pred: on_pred(0, pred0)
-     if progress: progress(1, total_frames)
-
-     t = 1
-     for frgb in it:
-         pred = refine_next_frame(core, frgb, device=device)
-         if on_pred: on_pred(t, pred)
-         t += 1
-         if progress: progress(t, total_frames)
-
-
- # -------------------------------
- # Background replacement
- # -------------------------------
- def replace_background_hq(frame_rgb: np.ndarray, mask_hw_float01: np.ndarray, background_rgb: np.ndarray) -> np.ndarray:
-     """High-quality background replacement with proper compositing"""
-     try:
-         h, w = frame_rgb.shape[:2]
-         background_resized = cv2.resize(background_rgb, (w, h))
-
-         # Ensure mask is HxW float in [0,1]
-         if mask_hw_float01.ndim == 3:
-             mask_hw_float01 = mask_hw_float01[..., 0]
-         m = np.clip(mask_hw_float01.astype(np.float32), 0.0, 1.0)
-
-         # Feather edges lightly
-         m_uint8 = (m * 255).astype(np.uint8)
-         m_feather = cv2.GaussianBlur(m_uint8, (7, 7), 0).astype(np.float32) / 255.0
-         m3 = np.stack([m_feather] * 3, axis=-1)
-
-         result = frame_rgb.astype(np.float32) * m3 + background_resized.astype(np.float32) * (1.0 - m3)
-         return np.clip(result, 0, 255).astype(np.uint8)
-     except Exception as e:
-         logger.error(f"Background replacement failed: {e}")
-         return frame_rgb
-
-
- # -------------------------------
- # Background generators
- # -------------------------------
- def create_professional_background(bg_type: str, width: int, height: int) -> np.ndarray:
-     """Create professional background of specified type and size"""
-     try:
-         if bg_type not in PROFESSIONAL_BACKGROUNDS:
-             bg_type = "office"  # Default fallback
-
-         config = PROFESSIONAL_BACKGROUNDS[bg_type]
-         color = config["color"]
-         use_gradient = config["gradient"]
-
-         if use_gradient:
-             background = _create_gradient_background(color, width, height)
-         else:
-             background = np.full((height, width, 3), color, dtype=np.uint8)
-
-         return background
-     except Exception as e:
-         logger.error(f"Background creation failed: {e}")
-         return np.full((height, width, 3), (255, 255, 255), dtype=np.uint8)
-
-
- def _create_gradient_background(base_color: Tuple[int, int, int], width: int, height: int) -> np.ndarray:
-     """Create a vertical gradient background from base color"""
-     r, g, b = base_color
-     dark = (int(r * 0.7), int(g * 0.7), int(b * 0.7))
-     bg = np.zeros((height, width, 3), dtype=np.uint8)
-     for y in range(height):
-         t = y / max(height, 1)
-         bg[y, :] = [
-             int(dark[0] * (1 - t) + r * t),
-             int(dark[1] * (1 - t) + g * t),
-             int(dark[2] * (1 - t) + b * t),
-         ]
-     return bg
-
-
- # -------------------------------
- # Exports
- # -------------------------------
- __all__ = [
-     # segment / refine (single-frame)
-     "segment_person_hq",
-     "refine_mask_hq",
-     # video runner + steps
-     "seed_with_sam2_post_masks",
-     "refine_next_frame",
-     "run_two_stage_matting",
-     # backgrounds & utils
-     "replace_background_hq",
-     "create_professional_background",
-     "PROFESSIONAL_BACKGROUNDS",
-     "validate_video_file",
- ]
 
+ #!/usr/bin/env python3
  """
+ BackgroundFX Pro - CSP-Safe Application Entry Point
+ Now with: live background preview + sources: Preset / Upload / Gradient / AI Generate
  """

+ import early_env  # <<< must be FIRST
+
+ import os, time
+ from typing import Optional, Dict, Any, Callable, Tuple
+
+ # 1) CSP-safe Gradio env
+ os.environ['GRADIO_ALLOW_FLAGGING'] = 'never'
+ os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+ os.environ['GRADIO_SERVER_NAME'] = '0.0.0.0'
+ os.environ['GRADIO_SERVER_PORT'] = '7860'
+
+ # 2) Gradio schema patch
+ try:
+     import gradio_client.utils as gc_utils
+     _orig_get_type = gc_utils.get_type
+     def _patched_get_type(schema):
+         if not isinstance(schema, dict):
+             if isinstance(schema, bool): return "boolean"
+             if isinstance(schema, str): return "string"
+             if isinstance(schema, (int, float)): return "number"
+             return "string"
+         return _orig_get_type(schema)
+     gc_utils.get_type = _patched_get_type
+ except Exception:
+     pass
+
+ # 3) Logging early
+ from utils.logging_setup import setup_logging, make_logger
+ setup_logging(app_name="backgroundfx")
+ logger = make_logger("entrypoint")
+ logger.info("Entrypoint starting…")
+
+ # 4) Imports
+ from core.exceptions import ModelLoadingError, VideoProcessingError
+ from config.app_config import get_config
+ from utils.hardware.device_manager import DeviceManager
+ from utils.system.memory_manager import MemoryManager
+ from models.loaders.model_loader import ModelLoader
+ from processing.video.video_processor import CoreVideoProcessor, ProcessorConfig
+ from processing.audio.audio_processor import AudioProcessor
+
+ # Background helpers
+ from utils import PROFESSIONAL_BACKGROUNDS, validate_video_file, create_professional_background
+ # Gradient helper (add to utils; fallback here for preview only if missing)
+ try:
+     from utils import create_gradient_background
+ except Exception:
+     def create_gradient_background(spec: Dict[str, Any], width: int, height: int):
+         # Lightweight fallback preview (linear only)
+         import numpy as np
+         import cv2
+         def _to_rgb(c):
+             if isinstance(c, (list, tuple)) and len(c) == 3:
+                 return tuple(int(x) for x in c)
+             if isinstance(c, str) and c.startswith("#") and len(c) == 7:
+                 return tuple(int(c[i:i+2], 16) for i in (1, 3, 5))
+             return (255, 255, 255)
+         start = _to_rgb(spec.get("start", "#222222"))
+         end = _to_rgb(spec.get("end", "#888888"))
+         angle = float(spec.get("angle_deg", 0))
+         bg = np.zeros((height, width, 3), np.uint8)
+         for y in range(height):
+             t = y / max(1, height - 1)
+             r = int(start[0] * (1 - t) + end[0] * t)
+             g = int(start[1] * (1 - t) + end[1] * t)
+             b = int(start[2] * (1 - t) + end[2] * t)
+             bg[y, :] = (r, g, b)
+         center = (width / 2, height / 2)
+         rot = cv2.getRotationMatrix2D(center, angle, 1.0)
+         return cv2.warpAffine(bg, rot, (width, height), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
+
+ # 5) CSP-safe fallbacks for models
+ class CSPSafeSAM2:
+     def set_image(self, image):
+         self.shape = getattr(image, 'shape', (512, 512, 3))
+     def predict(self, point_coords=None, point_labels=None, box=None, multimask_output=True, **kwargs):
+         import numpy as np
+         h, w = self.shape[:2] if hasattr(self, 'shape') else (512, 512)
+         n = 3 if multimask_output else 1
+         return np.ones((n, h, w), dtype=bool), np.array([0.9, 0.8, 0.7][:n]), np.ones((n, h, w), dtype=np.float32)
+
+ class CSPSafeMatAnyone:
+     def step(self, image_tensor, mask_tensor=None, objects=None, first_frame_pred=False, **kwargs):
+         import torch
+         if hasattr(image_tensor, "shape"):
+             if len(image_tensor.shape) == 3:
+                 _, H, W = image_tensor.shape
+             elif len(image_tensor.shape) == 4:
+                 _, _, H, W = image_tensor.shape
+             else:
+                 H, W = 256, 256
+         else:
+             H, W = 256, 256
+         return torch.ones((1, 1, H, W))
+     def output_prob_to_mask(self, output_prob):
+         return (output_prob > 0.5).float()
+     def process(self, image, mask, **kwargs):
+         return mask
+
+ # ---------- helpers for UI ----------
  import numpy as np
+ import cv2
  from PIL import Image

+ PREVIEW_W, PREVIEW_H = 640, 360  # 16:9
+
+ def _hex_to_rgb(x: str) -> Tuple[int, int, int]:
+     x = (x or "").strip()
+     if x.startswith("#") and len(x) == 7:
+         return tuple(int(x[i:i+2], 16) for i in (1, 3, 5))
+     return (255, 255, 255)
+
+ def _np_to_pil(arr: np.ndarray) -> Image.Image:
+     if arr.dtype != np.uint8:
+         arr = arr.clip(0, 255).astype(np.uint8)
+     return Image.fromarray(arr)
+
+ # ---------- main app ----------
+ class VideoBackgroundApp:
+     def __init__(self):
+         self.config = get_config()
+         self.device_mgr = DeviceManager()
+         self.memory_mgr = MemoryManager(self.device_mgr.get_optimal_device())
+         self.model_loader = ModelLoader(self.device_mgr, self.memory_mgr)
+         self.audio_proc = AudioProcessor()
+         self.models_loaded = False
+         self.core_processor: Optional[CoreVideoProcessor] = None
+         logger.info("VideoBackgroundApp initialized (device=%s)", self.device_mgr.get_optimal_device())
+
+     def load_models(self, progress_callback: Optional[Callable] = None) -> str:
+         logger.info("Loading models (CSP-safe)…")
+         try:
+             sam2, matanyone = self.model_loader.load_all_models(progress_callback=progress_callback)
+         except Exception as e:
+             logger.warning("Model loading failed (%s) - Using CSP-safe fallbacks", e)
+             sam2, matanyone = None, None
+
+         sam2_model = getattr(sam2, "model", sam2) if sam2 else CSPSafeSAM2()
+         matanyone_model = getattr(matanyone, "model", matanyone) if matanyone else CSPSafeMatAnyone()
+
+         cfg = ProcessorConfig(
+             background_preset="office",
+             write_fps=None,
+             max_model_size=1280,
+             use_nvenc=True,
+             nvenc_codec="h264",
+             nvenc_preset="p5",
+             nvenc_cq=18,
+             nvenc_tune_hq=True,
+             nvenc_pix_fmt="yuv420p",
  )
+         self.core_processor = CoreVideoProcessor(config=cfg, models=None)
+         self.core_processor.models = type('FakeModelManager', (), {
+             'get_sam2': lambda self_: sam2_model,
+             'get_matanyone': lambda self_: matanyone_model
+         })()
+
+         self.models_loaded = True
+         logger.info("Models ready (SAM2=%s, MatAnyOne=%s)",
+                     type(sam2_model).__name__, type(matanyone_model).__name__)
+         return "Models loaded (CSP-safe; fallbacks in use if actual AI models failed)."
+
+     # ---- PREVIEWS ----
+     def preview_preset(self, preset_key: str) -> Image.Image:
+         key = preset_key if preset_key in PROFESSIONAL_BACKGROUNDS else "office"
+         bg = create_professional_background(key, PREVIEW_W, PREVIEW_H)  # RGB
+         return _np_to_pil(bg)
+
+     def preview_upload(self, file) -> Optional[Image.Image]:
+         if file is None:
+             return None
+         try:
+             img = Image.open(file.name).convert("RGB")
+             img = img.resize((PREVIEW_W, PREVIEW_H), Image.LANCZOS)
+             return img
+         except Exception as e:
+             logger.warning("Upload preview failed: %s", e)
+             return None
+
+     def preview_gradient(self, gtype: str, color1: str, color2: str, angle: int) -> Image.Image:
+         spec = {
+             "type": (gtype or "linear").lower(),  # "linear" or "radial" (linear in fallback)
+             "start": _hex_to_rgb(color1 or "#222222"),
+             "end": _hex_to_rgb(color2 or "#888888"),
+             "angle_deg": float(angle or 0),
+         }
+         bg = create_gradient_background(spec, PREVIEW_W, PREVIEW_H)
+         return _np_to_pil(bg)
+
+     def ai_generate_background(self, prompt: str, seed: int, width: int, height: int) -> Tuple[Optional[Image.Image], Optional[str], str]:
+         """
+         Try generating a background with diffusers; save to /tmp and return (img, path, status).
+         """
+         try:
+             from diffusers import StableDiffusionPipeline
+             import torch
+             model_id = os.environ.get("BGFX_T2I_MODEL", "stabilityai/stable-diffusion-2-1")
+             dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype).to(device)
+             g = torch.Generator(device=device).manual_seed(int(seed)) if seed is not None else None
+             if device == "cuda":
+                 with torch.autocast("cuda"):
+                     img = pipe(prompt, height=height, width=width, guidance_scale=7.0, num_inference_steps=25, generator=g).images[0]
+             else:
+                 img = pipe(prompt, height=height, width=width, guidance_scale=7.0, num_inference_steps=25, generator=g).images[0]
+             tmp_path = f"/tmp/ai_bg_{int(time.time())}.png"
+             img.save(tmp_path)
+             return img.resize((PREVIEW_W, PREVIEW_H), Image.LANCZOS), tmp_path, f"AI background generated ✓ ({os.path.basename(tmp_path)})"
+         except Exception as e:
+             logger.warning("AI generation unavailable: %s", e)
+             return None, None, f"AI generation unavailable: {e}"
+
+     # ---- PROCESS VIDEO ----
+     def process_video(
+         self,
+         video: str,
+         bg_source: str,
+         preset_key: str,
+         custom_bg_file,
+         grad_type: str,
+         grad_color1: str,
+         grad_color2: str,
+         grad_angle: int,
+         ai_bg_path: Optional[str],
+     ):
+         if not self.models_loaded:
+             return None, "Models not loaded yet"
+
+         logger.info("process_video called (video=%s, source=%s, preset=%s, file=%s, grad=%s, ai=%s)",
+                     video, bg_source, preset_key,
+                     getattr(custom_bg_file, "name", None) if custom_bg_file else None,
+                     {"type": grad_type, "c1": grad_color1, "c2": grad_color2, "angle": grad_angle},
+                     ai_bg_path)
+
+         output_path = f"/tmp/output_{int(time.time())}.mp4"
+
+         # Validate input video
+         ok = validate_video_file(video)
+         if not ok:
+             logger.warning("Invalid/unreadable video: %s", video)
+             return None, "Invalid or unreadable video file"
+
+         # Build bg_config based on source
+         src = (bg_source or "Preset").lower()
+         if src == "upload" and custom_bg_file is not None:
+             bg_cfg: Dict[str, Any] = {"custom_path": custom_bg_file.name}
+         elif src == "gradient":
+             bg_cfg = {
+                 "gradient": {
+                     "type": (grad_type or "linear").lower(),
+                     "start": _hex_to_rgb(grad_color1 or "#222222"),
+                     "end": _hex_to_rgb(grad_color2 or "#888888"),
+                     "angle_deg": float(grad_angle or 0),
+                 }
+             }
+         elif src == "ai generate" and ai_bg_path:
+             bg_cfg = {"custom_path": ai_bg_path}
+         else:
+             key = preset_key if preset_key in PROFESSIONAL_BACKGROUNDS else "office"
+             bg_cfg = {"background_choice": key}
+
+         try:
+             result = self.core_processor.process_video(
+                 input_path=video,
+                 output_path=output_path,
+                 bg_config=bg_cfg
+             )
+             logger.info("Core processing done → %s", output_path)
+
+             output_with_audio = self.audio_proc.add_audio_to_video(video, output_path)
+             logger.info("Audio merged → %s", output_with_audio)
+
+             frames = (result.get('frames') if isinstance(result, dict) else None) or "n/a"
+             return output_with_audio, f"Processing complete ({frames} frames, background={bg_source})"
+
+         except Exception as e:
+             logger.exception("Processing failed")
+             return None, f"Processing failed: {e}"
+
+ # 7) Gradio UI
+ def create_csp_safe_gradio():
+     import gradio as gr
+     app = VideoBackgroundApp()
+
+     with gr.Blocks(
+         title="BackgroundFX Pro - CSP Safe",
+         analytics_enabled=False,
+         css="""
+         .gradio-container { max-width: 1100px; margin: auto; }
+         """
+     ) as demo:
+         gr.Markdown("# 🎬 BackgroundFX Pro (CSP-Safe)")
+         gr.Markdown("Replace your video background with cinema-quality AI matting. Now with live background preview.")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 video = gr.Video(label="Upload Video")
+                 bg_source = gr.Radio(
+                     ["Preset", "Upload", "Gradient", "AI Generate"],
+                     value="Preset",
+                     label="Background Source",
+                     interactive=True,
+                 )
+
+                 # PRESET
+                 preset_choices = list(PROFESSIONAL_BACKGROUNDS.keys())
+                 default_preset = "office" if "office" in preset_choices else (preset_choices[0] if preset_choices else "office")
+                 preset_key = gr.Dropdown(choices=preset_choices, value=default_preset, label="Preset")
+
+                 # UPLOAD
+                 custom_bg = gr.File(label="Custom Background (Image)", file_types=["image"], visible=False)
+
+                 # GRADIENT
+                 grad_type = gr.Dropdown(choices=["Linear", "Radial"], value="Linear", label="Gradient Type", visible=False)
+                 grad_color1 = gr.ColorPicker(value="#222222", label="Start Color", visible=False)
+                 grad_color2 = gr.ColorPicker(value="#888888", label="End Color", visible=False)
+                 grad_angle = gr.Slider(0, 360, value=0, step=1, label="Angle (degrees)", visible=False)
+
+                 # AI
+                 ai_prompt = gr.Textbox(label="AI Prompt", placeholder="e.g., sunlit modern office, soft bokeh, neutral palette", visible=False)
+                 ai_seed = gr.Slider(0, 2**31-1, step=1, value=42, label="Seed", visible=False)
+                 ai_size = gr.Dropdown(choices=["640x360", "960x540", "1280x720"], value="640x360", label="AI Image Size", visible=False)
+                 ai_go = gr.Button("✨ Generate Background", visible=False, variant="secondary")
+                 ai_status = gr.Markdown(visible=False)
+                 ai_bg_path_state = gr.State(value=None)  # store /tmp path
+
+                 btn_load = gr.Button("🔄 Load Models", variant="secondary")
+                 btn_run = gr.Button("🎬 Process Video", variant="primary")
+
+             with gr.Column(scale=1):
+                 status = gr.Textbox(label="Status", lines=4)
+                 bg_preview = gr.Image(label="Background Preview", width=PREVIEW_W, height=PREVIEW_H, interactive=False)
+                 out_video = gr.Video(label="Processed Video")
+
+         # ---------- UI wiring ----------
+
+         # background source → show/hide controls
+         def on_source_toggle(src):
+             src = (src or "Preset").lower()
+             return (
+                 gr.update(visible=(src == "preset")),
+                 gr.update(visible=(src == "upload")),
+                 gr.update(visible=(src == "gradient")),
+                 gr.update(visible=(src == "gradient")),
+                 gr.update(visible=(src == "gradient")),
+                 gr.update(visible=(src == "gradient")),
+                 gr.update(visible=(src == "ai generate")),
+                 gr.update(visible=(src == "ai generate")),
+                 gr.update(visible=(src == "ai generate")),
+                 gr.update(visible=(src == "ai generate")),
+                 gr.update(visible=(src == "ai generate")),
+             )
+         bg_source.change(
+             fn=on_source_toggle,
+             inputs=[bg_source],
+             outputs=[preset_key, custom_bg, grad_type, grad_color1, grad_color2, grad_angle, ai_prompt, ai_seed, ai_size, ai_go, ai_status],
  )

+         # When source changes, also refresh preview based on visible controls
+         def on_source_preview(src, pkey, gt, c1, c2, ang):
+             src_l = (src or "Preset").lower()
+             if src_l == "preset":
+                 return app.preview_preset(pkey)
+             elif src_l == "gradient":
+                 return app.preview_gradient(gt, c1, c2, ang)
+             # For upload/AI we keep whatever the component change handler sets (don't overwrite)
+             return gr.update()  # no-op
+         bg_source.change(
+             fn=on_source_preview,
+             inputs=[bg_source, preset_key, grad_type, grad_color1, grad_color2, grad_angle],
+             outputs=[bg_preview]
+         )

+         # live previews
+         preset_key.change(fn=lambda k: app.preview_preset(k), inputs=[preset_key], outputs=[bg_preview])
+         custom_bg.change(fn=lambda f: app.preview_upload(f), inputs=[custom_bg], outputs=[bg_preview])
+         for comp in (grad_type, grad_color1, grad_color2, grad_angle):
+             comp.change(
+                 fn=lambda gt, c1, c2, ang: app.preview_gradient(gt, c1, c2, ang),
+                 inputs=[grad_type, grad_color1, grad_color2, grad_angle],
+                 outputs=[bg_preview],
+             )
+
+         # AI generate
+         def ai_generate(prompt, seed, size):
+             try:
+                 w, h = map(int, size.split("x"))
+             except Exception:
+                 w, h = PREVIEW_W, PREVIEW_H
+             img, path, msg = app.ai_generate_background(
+                 prompt or "professional modern office background, neutral colors, depth of field",
+                 int(seed), w, h
+             )
+             return img, (path or None), msg
+         ai_go.click(fn=ai_generate, inputs=[ai_prompt, ai_seed, ai_size], outputs=[bg_preview, ai_bg_path_state, ai_status])
+
+         # model load / run
+         def safe_load():
+             msg = app.load_models()
+             logger.info("UI: models loaded")
+             return msg, app.preview_preset(preset_key.value if hasattr(preset_key, "value") else "office")
+         btn_load.click(fn=safe_load, outputs=[status, bg_preview])
+
+         def safe_process(vid, src, pkey, file, gtype, c1, c2, ang, ai_path):
+             return app.process_video(vid, src, pkey, file, gtype, c1, c2, ang, ai_path)
+         btn_run.click(
+             fn=safe_process,
+             inputs=[video, bg_source, preset_key, custom_bg, grad_type, grad_color1, grad_color2, grad_angle, ai_bg_path_state],
+             outputs=[out_video, status]
+         )

+     return demo
+
+ # 8) Launch
+ if __name__ == "__main__":
+     logger.info("Launching CSP-safe Gradio interface for Hugging Face Spaces")
+     demo = create_csp_safe_gradio()
+     demo.queue().launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_error=True,
+         debug=False,
+         inbrowser=False
  )
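
A minimal local launcher for the new entry point might look like the sketch below. It is not part of the commit; it assumes the repository root is on PYTHONPATH and that the module's own imports (early_env, core, config, models, processing, gradio) resolve in your environment. The file name run_local.py is hypothetical; the imported names come from the diff above.

    # run_local.py — hypothetical launcher, not part of this commit.
    # The commit places the entry-point code in utils/__init__.py, so the
    # UI factory is importable from the `utils` package.
    from utils import create_csp_safe_gradio

    # Builds the Blocks UI and an internal VideoBackgroundApp (models are
    # loaded lazily via the "Load Models" button, with CSP-safe fallbacks).
    demo = create_csp_safe_gradio()
    demo.queue().launch(server_name="127.0.0.1", server_port=7860, show_error=True)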