MogensR committed (verified)
Commit 47df7fb · 1 Parent(s): 6d8cd04

Update early_env.py

Files changed (1):
  1. early_env.py  +1031  -25
early_env.py CHANGED
@@ -1,30 +1,1036 @@
- # early_env.py - import this BEFORE numpy/torch/cv2/etc.
- import os, re
-
- def _clean_int_env(name: str, default: str | None = None):
-     val = os.environ.get(name)
-     if val is None:
-         if default is not None:
-             os.environ[name] = default
          return
-     if not re.fullmatch(r"\d+", str(val).strip()):
-         if default is None:
-             os.environ.pop(name, None)
-         else:
-             os.environ[name] = default

- # Keep thread env sane for Spaces (prevents libgomp errors)
- _clean_int_env("OMP_NUM_THREADS", "2")
- _clean_int_env("MKL_NUM_THREADS", "2")
- _clean_int_env("OPENBLAS_NUM_THREADS", "2")
- _clean_int_env("NUMEXPR_NUM_THREADS", "2")

- # Configure PyTorch threading EARLY (before any parallel work)
  try:
      import torch
-     if hasattr(torch, "set_num_interop_threads"):
-         torch.set_num_interop_threads(2)
-     if hasattr(torch, "set_num_threads"):
-         torch.set_num_threads(2)
- except Exception:
-     pass
+ #!/usr/bin/env python3
+ """
+ Video Background Replacer (GPU-Optimized)
+
+ - MatAnyone (primary), SAM2 (mask seeding), rembg (fallback)
+ - K-Governor guards torch.topk/kthvalue (no __wrapped__ assumption)
+ - Adaptive MatAnyone loader (from_pretrained | constructor network/model | repo-id)
+ - Optional repo pinning via MATANYONE_COMMIT / SAM2_COMMIT
+ - First-run warmup → READY ✅ before first request
+ - Robust Gradio input coercion (path | dict | file-like | PIL | NumPy)
+ - Alpha probing & (optional) stitching alpha_*.png sequences to a video
+ - Short-clip stabilizer (pre-roll) with correct trim
+ - Concurrency lock for MatAnyone core
+ """
+
+ # =========================
+ # EARLY env & imports
+ # =========================
+ import os, sys, re, time, gc, shutil, subprocess, tempfile, threading, traceback, inspect, glob
+ from pathlib import Path
+
+ # ---- Thread/env sanitization (must run BEFORE numpy/torch/cv2) ----
+ def _safe_int_env(var: str, default: int = 2, cap: int = 8) -> int:
+     v = os.environ.get(var, "").strip()
+     if not v or not re.fullmatch(r"\d+", v):
+         os.environ[var] = str(default); return default
+     iv = max(1, min(int(v), cap))
+     os.environ[var] = str(iv); return iv
+
+ _safe_int_env("OMP_NUM_THREADS", 2, 8)
+ _safe_int_env("MKL_NUM_THREADS", 2, 8)
+
+ # General runtime defaults
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:512")
+ os.environ.setdefault("CUDA_MODULE_LOADING", "LAZY")
+ os.environ.setdefault("PYTHONUNBUFFERED", "1")
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
+ # MatAnyone prefs
+ os.environ.setdefault("MATANYONE_MAX_EDGE", "1024")
+ os.environ.setdefault("MATANYONE_TARGET_PIXELS", "1000000")
+ os.environ.setdefault("MATANYONE_WINDOWED", "1")
+ os.environ.setdefault("MATANYONE_WINDOW", "16")
+ os.environ.setdefault("MAX_MODEL_SIZE", "1920")
+
+ # CUDA + cuDNN
+ os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
+ os.environ.setdefault("TORCH_CUDNN_V8_API_ENABLED", "1")
+ os.environ.setdefault("CUDNN_BENCHMARK", "1")
+
+ # HF cache
+ os.environ.setdefault("HF_HOME", "./checkpoints/hf")
+ os.environ.setdefault("TRANSFORMERS_CACHE", "./checkpoints/hf")
+ os.environ.setdefault("HF_DATASETS_CACHE", "./checkpoints/hf")
+ os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS", "1")
+ os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
+ os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
+
+ # Gradio
+ os.environ.setdefault("GRADIO_SERVER_NAME", "0.0.0.0")
+ os.environ.setdefault("GRADIO_SERVER_PORT", "7860")
+
+ # Features
+ os.environ.setdefault("USE_MATANYONE", "true")
+ os.environ.setdefault("USE_SAM2", "true")
+ os.environ.setdefault("SELF_CHECK_MODE", "false")
+
+ # Stabilizer defaults
+ os.environ.setdefault("MATANYONE_STABILIZE", "true")
+ os.environ.setdefault("MATANYONE_PREROLL_FRAMES", "12")
+
+ # Optional strict re-sanitization later
+ os.environ.setdefault("STRICT_ENV_GUARD", "1")
+
+ # =========================
+ # Std imports (safe now)
+ # =========================
+ import cv2
+ import numpy as np
+ from PIL import Image
+ import gradio as gr
+ from moviepy.editor import VideoFileClip, ImageSequenceClip, concatenate_videoclips
+
+ print("=" * 50)
+ print("Application Startup at", os.popen('date').read().strip())
+ print("=" * 50)
+ print("Environment Configuration:")
+ print(f"Python: {sys.version}")
+ print(f"Working directory: {os.getcwd()}")
+ print(f"CUDA_MODULE_LOADING: {os.getenv('CUDA_MODULE_LOADING')}")
+ print(f"OMP_NUM_THREADS: {os.getenv('OMP_NUM_THREADS')}")
+ print("=" * 50)
+
+ # =========================
+ # Third-party repos & optional pinning
+ # =========================
+ BASE_DIR = Path(__file__).resolve().parent
+ TP_DIR = BASE_DIR / "third_party"
+ CHECKPOINTS_DIR = BASE_DIR / "checkpoints"
+ TP_DIR.mkdir(exist_ok=True); CHECKPOINTS_DIR.mkdir(exist_ok=True)
+
+ def _git_clone_if_missing(url: str, path: Path, name: str):
+     if path.exists():
          return
+     print(f"Cloning {name}…")
+     try:
+         subprocess.run(["git", "clone", "--depth", "1", url, str(path)], check=True, timeout=300)
+         print(f"{name} cloned successfully")
+     except Exception as e:
+         print(f"Failed to clone {name}: {e}")
+
+ _git_clone_if_missing("https://github.com/facebookresearch/segment-anything-2.git", TP_DIR/"sam2", "SAM2")
+ _git_clone_if_missing("https://github.com/pq-yang/MatAnyone.git", TP_DIR/"matanyone", "MatAnyone")
+
+ def _checkout(repo_dir: Path, commit: str):
+     if not commit:
+         print(f"{repo_dir.name} not pinned (env is empty) - using current HEAD.")
+         return
+     try:
+         subprocess.run(["git", "-C", str(repo_dir), "fetch", "--depth", "1", "origin", commit], check=True)
+         subprocess.run(["git", "-C", str(repo_dir), "checkout", "--detach", commit], check=True)
+         print(f"Locked {repo_dir.name} to {commit}")
+     except Exception as e:
+         print(f"Warning: failed to lock {repo_dir.name} to {commit}: {e}")
+
+ MATANYONE_COMMIT = os.getenv("MATANYONE_COMMIT", "").strip()
+ SAM2_COMMIT = os.getenv("SAM2_COMMIT", "").strip()
+ _checkout(TP_DIR / "matanyone", MATANYONE_COMMIT)
+ _checkout(TP_DIR / "sam2", SAM2_COMMIT)
+
+ # Ensure vendored paths are importable
+ for p in [TP_DIR / "sam2", TP_DIR / "matanyone"]:
+     if p.exists() and str(p) not in sys.path:
+         sys.path.insert(0, str(p)); print(f"Added to path: {p}")
+
+ # =========================
+ # K-Governor (with bypass; robust for PyTorch 2.2)
+ # =========================
+ if os.getenv("SAFE_TOPK_BYPASS", "0") not in ("1","true","TRUE"):
+     import re as _re
+     def _write_safe_ops_file(pkg_root: Path):
+         utils_dir = pkg_root / "matanyone" / "utils"
+         if not utils_dir.exists(): utils_dir = pkg_root / "utils"
+         utils_dir.mkdir(parents=True, exist_ok=True)
+         (utils_dir / "safe_ops.py").write_text(
+             """
+ import os
+ import torch

+ _VERBOSE = bool(int(os.environ.get("SAFE_TOPK_VERBOSE", "1")))

+ # Robust for builds where topk/kthvalue are builtins without attributes.
+ _ORIG_TOPK = getattr(torch.topk, "__wrapped__", torch.topk)
+ _ORIG_KTH = getattr(torch.kthvalue, "__wrapped__", torch.kthvalue)
+
+ def _log(msg):
+     if _VERBOSE:
+         print(f"[K-Governor] {msg}")
+
+ def safe_topk(x, k, dim=None, largest=True, sorted=True):
+     if not isinstance(k, int):
+         k = int(k)
+     if dim is None:
+         dim = -1
+     n = x.size(dim)
+     k_eff = max(1, min(k, int(n)))
+     if k_eff != k:
+         _log(f"torch.topk: clamp k {k} -> {k_eff} for dim={dim} shape={tuple(x.shape)}")
+     values, indices = _ORIG_TOPK(x, k_eff, dim=dim, largest=largest, sorted=sorted)
+     if k_eff < k:
+         pad = k - k_eff
+         pad_shape = list(values.shape); pad_shape[dim] = pad
+         pad_vals = values.new_full(pad_shape, float('-inf'))
+         pad_idx = indices.new_zeros(pad_shape, dtype=indices.dtype)
+         values = torch.cat([values, pad_vals], dim=dim)
+         indices = torch.cat([indices, pad_idx], dim=dim)
+     return values, indices
+
+ def safe_kthvalue(x, k, dim=None, keepdim=False):
+     if not isinstance(k, int):
+         k = int(k)
+     if dim is None:
+         dim = -1
+     n = x.size(dim)
+     k_eff = max(1, min(k, int(n)))
+     if k_eff != k:
+         _log(f"torch.kthvalue: clamp k {k} -> {k_eff} for dim={dim} shape={tuple(x.shape)}")
+     return _ORIG_KTH(x, k_eff, dim=dim, keepdim=keepdim)
+ """.lstrip(), encoding="utf-8")
+
+     def _patch_matanyone_sources(repo_dir: Path) -> int:
+         root = repo_dir / "matanyone"
+         if not root.exists(): root = repo_dir
+         changed = 0
+         header_import = "from matanyone.utils.safe_ops import safe_topk, safe_kthvalue\n"
+         pt = _re.compile(r"\btorch\.topk\s*\(")
+         pm = _re.compile(r"(\b[\w\.]+)\.topk\s*\(")
+         kt = _re.compile(r"\btorch\.kthvalue\s*\(")
+         km = _re.compile(r"(\b[\w\.]+)\.kthvalue\s*\(")
+         for py in root.rglob("*.py"):
+             try:
+                 txt = py.read_text(encoding="utf-8"); orig = txt
+                 if "safe_topk" not in txt and py.name != "__init__.py":
+                     lines = txt.splitlines(keepends=True)
+                     insert_at = 0
+                     for i, L in enumerate(lines[:80]):
+                         if L.startswith(("import ","from ")): insert_at = i+1
+                     lines.insert(insert_at, header_import)
+                     txt = "".join(lines)
+                 txt = pt.sub("safe_topk(", txt)
+                 txt = kt.sub("safe_kthvalue(", txt)
+                 def _mt(m): return f"safe_topk({m.group(1)}, "
+                 def _mk(m): return f"safe_kthvalue({m.group(1)}, "
+                 txt = pm.sub(_mt, txt); txt = km.sub(_mk, txt)
+                 if txt != orig:
+                     py.write_text(txt, encoding="utf-8"); changed += 1
+             except Exception as e:
+                 print(f"[K-Governor] Patch warning on {py}: {e}")
+         return changed
+
+     try:
+         MATANY_REPO_DIR = TP_DIR / "matanyone"
+         _write_safe_ops_file(MATANY_REPO_DIR)
+         patched_files = _patch_matanyone_sources(MATANY_REPO_DIR)
+         print(f"[K-Governor] Patched MatAnyone sources: {patched_files} files updated.")
+     except Exception as e:
+         print(f"[K-Governor] Patch failed: {e}")
+ else:
+     print("[K-Governor] BYPASSED via SAFE_TOPK_BYPASS")
+
+ # =========================
+ # Torch & device
+ # =========================
+ TORCH_AVAILABLE = False; CUDA_AVAILABLE = False; GPU_NAME = "N/A"; DEVICE = "cpu"
  try:
      import torch
+     TORCH_AVAILABLE = True
+     CUDA_AVAILABLE = torch.cuda.is_available()
+     if CUDA_AVAILABLE:
+         torch.backends.cudnn.enabled = True
+         torch.backends.cudnn.benchmark = True
+         torch.backends.cudnn.deterministic = False
+         GPU_NAME = torch.cuda.get_device_name(0); DEVICE = "cuda"
+         gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
+         print(f"GPU: {GPU_NAME}")
+         print(f"VRAM: {gpu_memory:.1f} GB")
+         print(f"CUDA Capability: {torch.cuda.get_device_capability(0)}")
+         try: torch.cuda.set_per_process_memory_fraction(0.9)
+         except Exception: pass
+     print(f"Torch version: {torch.__version__}")
+     print(f"CUDA available: {CUDA_AVAILABLE}")
+     print(f"Device: {DEVICE}")
+ except Exception as e:
+     print(f"Torch not available: {e}")
+
+ # =========================
+ # Light GPU monitor
+ # =========================
+ class GPUMonitor:
+     def __init__(self):
+         self.monitoring = False
+         self.stats = {"gpu_util": 0, "memory_used": 0, "memory_total": 0}
+     def start_monitoring(self):
+         if not CUDA_AVAILABLE: return
+         self.monitoring = True
+         threading.Thread(target=self._monitor_loop, daemon=True).start()
+     def stop_monitoring(self): self.monitoring = False
+     def _monitor_loop(self):
+         while self.monitoring:
+             try:
+                 if CUDA_AVAILABLE:
+                     mem_used = torch.cuda.memory_allocated(0) / 1024**3
+                     mem_total = torch.cuda.get_device_properties(0).total_memory / 1024**3
+                     self.stats.update({
+                         "memory_used": mem_used, "memory_total": mem_total,
+                         "memory_percent": (mem_used/mem_total)*100 if mem_total else 0
+                     })
+                 try:
+                     import pynvml
+                     pynvml.nvmlInit()
+                     h = pynvml.nvmlDeviceGetHandleByIndex(0)
+                     util = pynvml.nvmlDeviceGetUtilizationRates(h)
+                     self.stats["gpu_util"] = util.gpu
+                 except Exception:
+                     pass
+             except Exception as e:
+                 print(f"GPU monitoring error: {e}")
+             time.sleep(1)
+     def get_stats(self): return self.stats.copy()
+
+ gpu_monitor = GPUMonitor(); gpu_monitor.start_monitoring()
+
+ # =========================
+ # SAM2 (verified micro-inference)
+ # =========================
+ SAM2_IMPORTED = False; SAM2_AVAILABLE = False; SAM2_PREDICTOR = None
+ if TORCH_AVAILABLE and os.getenv("USE_SAM2","true").lower()=="true":
+     try:
+         print("Setting up SAM2…")
+         from hydra import initialize_config_dir, compose
+         from hydra.core.global_hydra import GlobalHydra
+         from sam2.build_sam import build_sam2
+         from sam2.sam2_image_predictor import SAM2ImagePredictor
+         SAM2_IMPORTED = True
+         ckpt = Path("./checkpoints/sam2.1_hiera_tiny.pt")
+         ckpt.parent.mkdir(parents=True, exist_ok=True)
+         if not ckpt.exists():
+             print("Downloading SAM2.1 checkpoint…")
+             import requests
+             url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt"
+             r = requests.get(url, stream=True, timeout=60); r.raise_for_status()
+             with open(ckpt, "wb") as f:
+                 for ch in r.iter_content(chunk_size=8192):
+                     if ch: f.write(ch)
+             print(f"SAM2 checkpoint downloaded to {ckpt}")
+         if GlobalHydra().is_initialized():
+             GlobalHydra.instance().clear()
+         config_dir = str(TP_DIR / "sam2" / "sam2" / "configs")
+         config_file = "sam2.1/sam2.1_hiera_t.yaml"
+         initialize_config_dir(config_dir=config_dir, version_base=None)
+         _ = compose(config_name=config_file)
+         model = build_sam2(config_file, str(ckpt), device="cuda" if CUDA_AVAILABLE else "cpu")
+         if CUDA_AVAILABLE and hasattr(torch, "compile"):
+             try: model = torch.compile(model, mode="max-autotune")
+             except Exception as _e: print(f"torch.compile not used: {_e}")
+         SAM2_PREDICTOR = SAM2ImagePredictor(model)
+         try:
+             dummy = np.zeros((64,64,3), dtype=np.uint8)
+             SAM2_PREDICTOR.set_image(dummy)
+             pts = np.array([[32,32]], dtype=np.int32); lbs = np.array([1], dtype=np.int32)
+             _m,_s,_l = SAM2_PREDICTOR.predict(point_coords=pts, point_labels=lbs, multimask_output=True)
+             SAM2_AVAILABLE = True; print("✅ SAM2 verified via micro-inference.")
+         except Exception as ver_e:
+             SAM2_AVAILABLE = False; SAM2_PREDICTOR = None
+             print(f"SAM2 verification failed: {ver_e}")
+     except Exception as e:
+         print(f"SAM2 setup failed: {e}")
+
+ # =========================
+ # MatAnyone import (canonical first, fallback)
+ # =========================
+ MATANYONE_IMPORTED = False; MatAnyInferenceCore = None
+ try:
+     from matanyone.inference.inference_core import InferenceCore as MatAnyInferenceCore
+     MATANYONE_IMPORTED = True
+     print("MatAnyone import OK: matanyone.inference.inference_core.InferenceCore")
+ except Exception as e1:
+     try:
+         from matanyone import InferenceCore as MatAnyInferenceCore
+         MATANYONE_IMPORTED = True
+         print("MatAnyone import OK: matanyone.InferenceCore")
+     except Exception as e2:
+         print(f"MatAnyone not importable: {e2 or e1}")
+
+ # =========================
+ # rembg fallback
+ # =========================
+ REMBG_AVAILABLE = False
+ try:
+     from rembg import remove
+     REMBG_AVAILABLE = True; print("rembg import OK (fallback ready).")
+ except Exception as e:
+     print(f"rembg not available: {e}")
+
+ # =========================
+ # Background helpers
+ # =========================
+ def make_solid(w, h, rgb): return np.full((h, w, 3), rgb, dtype=np.uint8)
+ def make_vertical_gradient(w, h, top_rgb, bottom_rgb):
+     top = np.array(top_rgb, dtype=np.float32); bot = np.array(bottom_rgb, dtype=np.float32)
+     t = np.linspace(0,1,h,dtype=np.float32)[:,None]
+     grad = (1-t)*top + t*bot; grad = np.clip(grad,0,255).astype(np.uint8)
+     return np.repeat(grad[None,...], w, axis=0).transpose(1,0,2)
+ def build_professional_bg(w, h, preset: str) -> np.ndarray:
+     p = (preset or "").lower()
+     if p == "office (soft gray)": return make_vertical_gradient(w,h,(245,246,248),(220,223,228))
+     if p == "studio (charcoal)": return make_vertical_gradient(w,h,(32,32,36),(64,64,70))
+     if p == "nature (green tint)":return make_vertical_gradient(w,h,(180,220,190),(100,160,120))
+     if p == "brand blue": return make_solid(w,h,(18,112,214))
+     return make_solid(w,h,(240,240,240))
+
+ # =========================
+ # MatAnyone wrapper (+ lock, adaptive constructor, alpha stitching)
+ # =========================
+ class OptimizedMatAnyoneProcessor:
+     def __init__(self):
+         self.processor = None
+         self.device = "cuda" if (TORCH_AVAILABLE and CUDA_AVAILABLE) else "cpu"
+         self.initialized = False
+         self.verified = False
+         self.last_error = None
+         self.stabilize = os.getenv("MATANYONE_STABILIZE","true").lower()=="true"
+         try: self.preroll_frames = max(0, int(os.getenv("MATANYONE_PREROLL_FRAMES","12")))
+         except Exception: self.preroll_frames = 12
+         self._lock = threading.Lock()
+
+     # ---- Adaptive core constructor
+     def _construct_inference_core(self, network_or_repo):
+         # prefer classmethod if available
+         try:
+             if hasattr(MatAnyInferenceCore, "from_pretrained"):
+                 return MatAnyInferenceCore.from_pretrained(
+                     network_or_repo,
+                     device=("cuda" if CUDA_AVAILABLE else "cpu")
+                 )
+         except Exception:
+             pass
+         # try constructor with introspection
+         try:
+             sig = inspect.signature(MatAnyInferenceCore)
+             if isinstance(network_or_repo, str):
+                 return MatAnyInferenceCore(network_or_repo)
+             if "network" in sig.parameters:
+                 return MatAnyInferenceCore(network=network_or_repo)
+             if "model" in sig.parameters:
+                 return MatAnyInferenceCore(model=network_or_repo)
+             return MatAnyInferenceCore(network_or_repo)
+         except Exception as e:
+             raise RuntimeError(f"InferenceCore construction failed: {type(e).__name__}: {e}")
+
+     # ---- Normalize return + disk probe + png sequence stitch
+     def _stitch_alpha_sequence(self, outdir: str, fps: float) -> str | None:
+         # common patterns
+         patt_list = ["alpha_%04d.png", "alpha_%03d.png", "alpha_%05d.png", "alpha_*.png"]
+         frames = []
+         for patt in patt_list:
+             frames = sorted(glob.glob(os.path.join(outdir, patt.replace("%0", "*").replace("d",""))))
+             if frames:
+                 break
+         if not frames:
+             return None
+         # read as float [0,1]
+         ary = []
+         for p in frames:
+             im = cv2.imread(p, cv2.IMREAD_GRAYSCALE)
+             if im is None: continue
+             ary.append((im.astype(np.float32) / 255.0))
+         if not ary:
+             return None
+         clip = ImageSequenceClip([f for f in ary], fps=max(1, int(round(fps or 24))))
+         alpha_mp4 = tempfile.NamedTemporaryFile(delete=False, suffix="_alpha_seq.mp4").name
+         clip.write_videofile(alpha_mp4, audio=False, logger=None)
+         clip.close()
+         return alpha_mp4
+
+     def _normalize_ret_and_probe(self, ret, outdir: str, fallback_fps: float = 24.0):
+         fg_path = alpha_path = None
+         if isinstance(ret, (list, tuple)):
+             if len(ret) >= 2: fg_path, alpha_path = ret[0], ret[1]
+             elif len(ret) == 1: alpha_path = ret[0]
+         elif isinstance(ret, str):
+             alpha_path = ret
+
+         def _valid(p: str) -> bool:
+             return p and os.path.exists(p) and os.path.getsize(p) > 0
+
+         # probe common video names
+         if not _valid(alpha_path):
+             for cand in ("alpha.mp4","alpha.mkv","alpha.mov","alpha.webm"):
+                 p = os.path.join(outdir, cand)
+                 if _valid(p):
+                     alpha_path = p; break
+
+         # try stitching sequences if needed
+         if not _valid(alpha_path):
+             stitched = self._stitch_alpha_sequence(outdir, fallback_fps)
+             if stitched and _valid(stitched):
+                 alpha_path = stitched
+
+         return fg_path, alpha_path
+
+     def _warmup(self) -> None:
+         import numpy as _np, cv2 as _cv2, os as _os
+         from moviepy.editor import ImageSequenceClip as _ISC
+         with tempfile.TemporaryDirectory() as td:
+             frames = []
+             for t in range(8):
+                 fr = _np.zeros((64,64,3), _np.uint8); x = 8 + t*4
+                 _cv2.rectangle(fr, (x,20), (x+12,44), 200, -1); frames.append(fr)
+             vid = _os.path.join(td,"warmup.mp4"); _ISC(frames, fps=10).write_videofile(vid, audio=False, logger=None)
+             m = _np.zeros((64,64), _np.uint8); _cv2.rectangle(m,(24,24),(40,40),255,-1)
+             mask = _os.path.join(td,"mask.png"); _cv2.imwrite(mask, m)
+             outdir = _os.path.join(td,"out"); os.makedirs(outdir, exist_ok=True)
+             # ensure method exists
+             if not hasattr(self.processor, "process_video"):
+                 if hasattr(self.processor, "process"):
+                     self.processor.process_video = self.processor.process
+                 else:
+                     raise RuntimeError("MatAnyone core lacks process_video/process")
+
+             ret = self.processor.process_video(input_path=vid, mask_path=mask, output_path=outdir, max_size=512)
+             _fg, alpha = self._normalize_ret_and_probe(ret, outdir, fallback_fps=10)
+             if not alpha or not os.path.exists(alpha) or os.path.getsize(alpha) == 0:
+                 raise RuntimeError("Warmup: MatAnyone produced no alpha")
+
+     def initialize(self) -> bool:
+         with self._lock:
+             if not MATANYONE_IMPORTED:
+                 print("MatAnyone not importable; skipping init."); return False
+             if self.initialized and self.processor is not None:
+                 return True
+             self.last_error = None
+
+             # HF path first
+             try:
+                 print(f"Initializing MatAnyone (HF repo-id) on {self.device}…")
+                 self.processor = self._construct_inference_core("PeiqingYang/MatAnyone")
+                 if self.device == "cuda":
+                     import torch as _t
+                     _t.cuda.empty_cache(); _ = _t.rand(1, device="cuda") * 0.0
+                 # alias method if needed
+                 if not hasattr(self.processor, "process_video") and hasattr(self.processor, "process"):
+                     self.processor.process_video = self.processor.process
+                 self._warmup()
+                 self.verified = True; self.initialized = True
+                 print("✅ MatAnyone initialized & warmed up (HF repo-id).")
+                 return True
+             except Exception as e:
+                 self.last_error = f"HF init failed: {type(e).__name__}: {e}"
+                 print(self.last_error)
+
+             # Local ckpt fallback
+             try:
+                 print("Falling back to local checkpoint init for MatAnyone…")
+                 from hydra.core.global_hydra import GlobalHydra
+                 if hasattr(GlobalHydra,"instance") and GlobalHydra().is_initialized():
+                     GlobalHydra.instance().clear()
+                 import requests
+                 from matanyone.utils.get_default_model import get_matanyone_model
+                 ckpt_dir = Path("./pretrained_models"); ckpt_dir.mkdir(parents=True, exist_ok=True)
+                 ckpt_path = ckpt_dir / "matanyone.pth"
+                 if not ckpt_path.exists():
+                     url = "https://github.com/pq-yang/MatAnyone/releases/download/v1.0.0/matanyone.pth"
+                     print(f"Downloading MatAnyone checkpoint from: {url}")
+                     with requests.get(url, stream=True, timeout=180) as r:
+                         r.raise_for_status()
+                         with open(ckpt_path, "wb") as f:
+                             for chunk in r.iter_content(chunk_size=8192):
+                                 if chunk: f.write(chunk)
+                     print(f"Checkpoint saved to {ckpt_path}")
+                 network = get_matanyone_model(str(ckpt_path), device=("cuda" if CUDA_AVAILABLE else "cpu"))
+                 self.processor = self._construct_inference_core(network)
+                 if self.device == "cuda":
+                     import torch as _t
+                     _t.cuda.empty_cache(); _ = _t.rand(1, device="cuda") * 0.0
+                 if not hasattr(self.processor, "process_video") and hasattr(self.processor, "process"):
+                     self.processor.process_video = self.processor.process
+                 self._warmup()
+                 self.verified = True; self.initialized = True
+                 print("✅ MatAnyone initialized & warmed up (local checkpoint).")
+                 return True
+             except Exception as e:
+                 self.last_error = f"Local init/warmup failed: {type(e).__name__}: {e}"
+                 print(f"MatAnyone initialization failed: {self.last_error}")
+                 traceback.print_exc(); return False
+
+     # ---- Pre-roll & trimming
+     @staticmethod
+     def _build_preroll_concat(input_path: str, frames: int) -> tuple[str, float, float]:
+         clip = VideoFileClip(input_path)
+         fps = float(clip.fps or 24.0)
+         preroll_frames = max(0, frames)
+         if preroll_frames == 0:
+             out = input_path; clip.close(); return out, 0.0, fps
+         first = clip.get_frame(0)
+         pre = ImageSequenceClip([first]*preroll_frames, fps=max(1, int(round(fps))))
+         concat = concatenate_videoclips([pre, clip])
+         tmp = tempfile.NamedTemporaryFile(delete=False, suffix="_concat.mp4")
+         concat.write_videofile(tmp.name, audio=False, logger=None)
+         pre.close(); concat.close(); clip.close()
+         return tmp.name, preroll_frames / fps, fps
+
+     @staticmethod
+     def _trim_head(video_path: str, seconds: float) -> str:
+         if seconds <= 0: return video_path
+         clip = VideoFileClip(video_path); dur = clip.duration or 0
+         start = min(seconds, max(0.0, dur - 0.001))
+         trimmed = tempfile.NamedTemporaryFile(delete=False, suffix="_trim.mp4").name
+         clip.subclip(start, None).write_videofile(trimmed, audio=False, logger=None)
+         clip.close(); return trimmed
+
+     def create_mask_optimized(self, video_path: str, output_path: str) -> str:
+         cap = cv2.VideoCapture(video_path); ret, frame = cap.read(); cap.release()
+         if not ret: raise ValueError("Could not read first frame from video.")
+         if SAM2_AVAILABLE and SAM2_PREDICTOR is not None:
+             try:
+                 print("Creating mask with SAM2 (first frame)…")
+                 rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 SAM2_PREDICTOR.set_image(rgb)
+                 h, w = rgb.shape[:2]
+                 pts = np.array([[w//2, h//2],[w//3, h//3],[2*w//3, 2*h//3]], dtype=np.int32)
+                 lbs = np.array([1,1,1], dtype=np.int32)
+                 masks, scores, _ = SAM2_PREDICTOR.predict(point_coords=pts, point_labels=lbs, multimask_output=True)
+                 best = masks[np.argmax(scores)]
+                 mask = ((best.astype(np.uint8) > 0).astype(np.uint8)) * 255  # 1ch u8 {0,255}
+                 cv2.imwrite(output_path, mask)
+                 print(f"Self-test mask uniques: {np.unique(mask//255)}")
+                 return output_path
+             except Exception as e:
+                 print(f"SAM2 mask creation failed; fallback rectangle. Error: {e}")
+         # Fallback: centered box
+         h, w = frame.shape[:2]
+         mask = np.zeros((h,w), dtype=np.uint8)
+         mx, my = int(w*0.15), int(h*0.10)
+         mask[my:h-my, mx:w-mx] = 255
+         cv2.imwrite(output_path, mask); return output_path
+
+     def process_video_optimized(self, input_path: str, output_dir: str):
+         with self._lock:
+             if not self.initialized and not self.initialize():
+                 return None
+             try:
+                 print("🚀 MatAnyone processing…")
+                 if CUDA_AVAILABLE:
+                     import torch as _t
+                     _t.cuda.empty_cache(); gc.collect()
+
+                 concat_path = input_path; preroll_sec = 0.0; fps_used = 24.0
+                 if self.stabilize and self.preroll_frames > 0:
+                     concat_path, preroll_sec, fps_used = self._build_preroll_concat(input_path, self.preroll_frames)
+                     print(f"[Stabilizer] Pre-rolled {self.preroll_frames} frames ({preroll_sec:.3f}s).")
+
+                 mask_path = os.path.join(output_dir, "mask.png")
+                 self.create_mask_optimized(input_path, mask_path)
+
+                 if not hasattr(self.processor, "process_video") and hasattr(self.processor, "process"):
+                     self.processor.process_video = self.processor.process
+
+                 ret = self.processor.process_video(
+                     input_path=concat_path,
+                     mask_path=mask_path,
+                     output_path=output_dir,
+                     max_size=int(os.getenv("MAX_MODEL_SIZE","1920"))
+                 )
+                 fg_path, alpha_path = self._normalize_ret_and_probe(ret, output_dir, fallback_fps=fps_used)
+
+                 if not alpha_path or not os.path.exists(alpha_path):
+                     raise RuntimeError("MatAnyone finished without a valid alpha video on disk.")
+
+                 if preroll_sec > 0.0:
+                     alpha_path = self._trim_head(alpha_path, preroll_sec)
+                     print(f"[Stabilizer] Trimmed {preroll_sec:.3f}s from alpha.")
+
+                 if not os.path.exists(alpha_path) or os.path.getsize(alpha_path) == 0:
+                     raise RuntimeError("Alpha exists but is empty/zero bytes after trim.")
+
+                 return alpha_path
+
+             except Exception as e:
+                 print(f"❌ MatAnyone processing failed: {e}")
+                 traceback.print_exc()
+                 return None
+
+ matanyone_processor = OptimizedMatAnyoneProcessor()
+
+ # =========================
+ # rembg helpers
+ # =========================
+ REMBG_AVAILABLE = REMBG_AVAILABLE
+ def process_frame_rembg_optimized(frame_bgr_u8, bg_img_rgb_u8):
+     if not REMBG_AVAILABLE:
+         return cv2.cvtColor(frame_bgr_u8, cv2.COLOR_BGR2RGB)
+     try:
+         frame_rgb = cv2.cvtColor(frame_bgr_u8, cv2.COLOR_BGR2RGB)
+         pil_im = Image.fromarray(frame_rgb)
+         from rembg import remove  # lazy import in case plugin is heavy
+         result = remove(pil_im).convert("RGBA")
+         result_np = np.array(result)
+         if result_np.shape[2] == 4:
+             alpha = (result_np[:, :, 3:4].astype(np.float32) / 255.0)
+             comp = alpha * result_np[:, :, :3].astype(np.float32) + (1 - alpha) * bg_img_rgb_u8.astype(np.float32)
+             return comp.astype(np.uint8)
+         return result_np.astype(np.uint8)
+     except Exception as e:
+         print(f"rembg processing error: {e}")
+         return cv2.cvtColor(frame_bgr_u8, cv2.COLOR_BGR2RGB)
+
+ # =========================
+ # Compositing
+ # =========================
+ def composite_with_background(original_path, alpha_path, bg_path=None, bg_preset=None):
+     print("🎬 Compositing final video…")
+     orig_clip = VideoFileClip(original_path)
+     alpha_clip = VideoFileClip(alpha_path)
+     fps = orig_clip.fps or 24
+     w, h = orig_clip.size
+     if bg_path:
+         bg_img = cv2.imread(bg_path)
+         if bg_img is None: raise ValueError(f"Could not read background image: {bg_path}")
+         bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB); bg_img = cv2.resize(bg_img, (w, h))
+     else:
+         bg_img = build_professional_bg(w, h, bg_preset)
+
+     def process_func(get_frame, t):
+         frame = get_frame(t)
+         a = alpha_clip.get_frame(t)
+         if a.ndim == 2: a = a[..., None]
+         elif a.shape[2] > 1: a = a[..., :1]
+         a = np.clip(a, 0.0, 1.0).astype(np.float32)
+         bg_f32 = (bg_img.astype(np.float32) / 255.0)
+         comp = a * frame.astype(np.float32) + (1.0 - a) * bg_f32
+         return comp.astype(np.float32)
+
+     new_clip = orig_clip.fl(process_func).set_fps(fps)
+     output_path = "final_output.mp4"
+     new_clip.write_videofile(output_path, audio=False, logger=None)
+     alpha_clip.close(); orig_clip.close(); new_clip.close()
+     return output_path
+
+ # =========================
+ # rembg whole-video fallback
+ # =========================
+ def process_video_rembg_fallback(video_path, bg_image_path=None, bg_preset=None):
+     print("🔄 Processing with rembg fallback…")
+     cap = cv2.VideoCapture(video_path); ret, frame = cap.read()
+     if not ret: cap.release(); raise ValueError("Could not read video")
+     h, w, _ = frame.shape; cap.release()
+     if bg_image_path:
+         bg_img = cv2.imread(bg_image_path)
+         if bg_img is None: raise ValueError(f"Could not read background image: {bg_image_path}")
+         bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB); bg_img = cv2.resize(bg_img, (w, h))
+     else:
+         bg_img = build_professional_bg(w, h, bg_preset)
+     clip = VideoFileClip(video_path)
+     fps = clip.fps or 24
+     def process_func(get_frame, t):
+         fr = get_frame(t)
+         fr_u8 = (fr * 255).astype(np.uint8)
+         comp = process_frame_rembg_optimized(cv2.cvtColor(fr_u8, cv2.COLOR_RGB2BGR), bg_img)
+         return (comp.astype(np.float32) / 255.0)
+     new_clip = clip.fl(process_func).set_fps(fps)
+     output_path = "rembg_output.mp4"
+     new_clip.write_videofile(output_path, audio=False, logger=None)
+     clip.close(); new_clip.close()
+     return output_path
+
+ # =========================
+ # Self-test harness
+ # =========================
+ def _ok(flag): return "✅" if flag else "❌"
+ def self_test_cuda():
+     try:
+         if not TORCH_AVAILABLE: return False, "Torch not importable"
+         if not CUDA_AVAILABLE: return False, "CUDA not available"
+         import torch as _t
+         a = _t.randn((1024,1024), device="cuda"); b = _t.randn((1024,1024), device="cuda")
+         c = (a @ b).mean().item(); return True, f"CUDA matmul ok, mean={c:.6f}"
+     except Exception as e: return False, f"CUDA op failed: {e}"
+ def self_test_ffmpeg_moviepy():
+     try:
+         ff = shutil.which("ffmpeg")
+         if not ff: return False, "ffmpeg not found on PATH"
+         frames = [(np.zeros((64,64,3), np.uint8) + i).clip(0,255) for i in range(0,200,25)]
+         clip = ImageSequenceClip(frames, fps=4)
+         with tempfile.TemporaryDirectory() as td:
+             vp = os.path.join(td, "tiny.mp4")
+             clip.write_videofile(vp, audio=False, logger=None); clip.close()
+             clip_r = VideoFileClip(vp); _ = clip_r.get_frame(0.1); clip_r.close()
+         return True, "FFmpeg/MoviePy encode/decode ok"
+     except Exception as e: return False, f"FFmpeg/MoviePy test failed: {e}"
+ def self_test_rembg():
+     try:
+         if not REMBG_AVAILABLE: return False, "rembg not importable"
+         from rembg import remove
+         img = np.zeros((64,64,3), dtype=np.uint8); img[:,:] = (0,255,0)
+         pil = Image.fromarray(img); out = remove(pil)
+         ok = isinstance(out, Image.Image) and out.size == (64,64)
+         return ok, "rembg ok" if ok else "rembg returned unexpected output"
+     except Exception as e: return False, f"rembg failed: {e}"
+ def self_test_sam2():
+     try:
+         if not SAM2_IMPORTED: return False, "SAM2 not importable"
+         if not SAM2_PREDICTOR: return False, "SAM2 predictor not initialized"
+         dummy = np.zeros((64,64,3), dtype=np.uint8)
+         SAM2_PREDICTOR.set_image(dummy)
+         pts = np.array([[32,32]], dtype=np.int32); lbs = np.array([1], dtype=np.int32)
+         masks, scores, _ = SAM2_PREDICTOR.predict(point_coords=pts, point_labels=lbs, multimask_output=True)
+         ok = masks is not None and len(masks) > 0
+         return ok, "SAM2 micro-inference ok" if ok else "SAM2 predict returned no masks"
+     except Exception as e: return False, f"SAM2 micro-inference failed: {e}"
+ def self_test_matanyone():
+     try:
+         ok_init = matanyone_processor.initialize()
+         if not ok_init: return False, f"MatAnyone init failed: {getattr(matanyone_processor,'last_error','no details')}"
+         if not matanyone_processor.verified: return False, "MatAnyone missing process_video API"
+         with tempfile.TemporaryDirectory() as td:
+             frames = []
+             for t in range(8):
+                 frame = np.zeros((64,64,3), dtype=np.uint8)
+                 x = 8 + t*4; cv2.rectangle(frame, (x,20),(x+12,44), 200, -1); frames.append(frame)
+             vid_path = os.path.join(td,"tiny_input.mp4")
+             clip = ImageSequenceClip(frames, fps=8); clip.write_videofile(vid_path, audio=False, logger=None); clip.close()
+             mask = np.zeros((64,64), dtype=np.uint8); cv2.rectangle(mask,(24,24),(40,40),255,-1)
+             mask_path = os.path.join(td,"mask.png"); cv2.imwrite(mask_path, mask)
+             alpha = matanyone_processor.process_video_optimized(vid_path, td)
+             if alpha is None or not os.path.exists(alpha): return False, "MatAnyone did not produce alpha video"
+             _alpha_clip = VideoFileClip(alpha); _ = _alpha_clip.get_frame(0.1); _alpha_clip.close()
+         return True, "MatAnyone process_video ok"
+     except Exception as e: return False, f"MatAnyone test failed: {e}"
+ def run_self_test() -> str:
+     lines = []
+     lines.append("=== SELF TEST REPORT ===")
+     lines.append(f"Python: {sys.version.split()[0]}")
+     lines.append(f"Torch: {torch.__version__ if TORCH_AVAILABLE else 'N/A'} | CUDA: {CUDA_AVAILABLE} | Device: {DEVICE} | GPU: {GPU_NAME}")
+     lines.append(f"FFmpeg on PATH: {bool(shutil.which('ffmpeg'))}")
+     lines.append("")
+     tests = [("CUDA", self_test_cuda), ("FFmpeg/MoviePy", self_test_ffmpeg_moviepy),
+              ("rembg", self_test_rembg), ("SAM2", self_test_sam2), ("MatAnyone", self_test_matanyone)]
+     for name, fn in tests:
+         t0 = time.time(); ok, msg = fn(); dt = time.time() - t0
+         lines.append(f"{_ok(ok)} {name}: {msg} [{dt:.2f}s]")
+     return "\n".join(lines)
+
+ # =========================
+ # Gradio input coercion helpers
+ # =========================
+ def _coerce_video_to_path(video_file):
+     if video_file is None:
+         return None
+     if isinstance(video_file, str):
+         return video_file
+     if isinstance(video_file, dict) and "name" in video_file:
+         return video_file["name"]
+     return getattr(video_file, "name", None)
+
+ def _coerce_bg_to_path(bg_image, temp_dir):
+     """Return filesystem path for background image, writing it to temp_dir if needed."""
+     if bg_image is None:
+         return None
+     if isinstance(bg_image, str):
+         return bg_image
+     if isinstance(bg_image, dict) and "name" in bg_image:
+         return bg_image["name"]
+     if hasattr(bg_image, "name") and isinstance(bg_image.name, str):
+         return bg_image.name
+     if isinstance(bg_image, Image.Image):
+         p = os.path.join(temp_dir, "bg_uploaded.png")
+         bg_image.save(p); return p
+     if isinstance(bg_image, np.ndarray):
+         p = os.path.join(temp_dir, "bg_uploaded.png")
+         arr = bg_image
+         if arr.ndim == 3 and arr.shape[2] == 3:
+             cv2.imwrite(p, cv2.cvtColor(arr, cv2.COLOR_RGB2BGR))
+         else:
+             cv2.imwrite(p, arr)
+         return p
+     return None
+
+ # =========================
+ # Gradio callback
+ # =========================
+ def gradio_interface_optimized(video_file, bg_image, use_matanyone=True, bg_preset="Office (Soft Gray)", stabilize=True, preroll_frames=12):
+     try:
+         if video_file is None:
+             return None, None, "Please upload a video."
+         print(f"UI types: video={type(video_file)}, bg={type(bg_image)}")
+
+         with tempfile.TemporaryDirectory() as temp_dir:
+             video_path = _coerce_video_to_path(video_file)
+             if not video_path or not os.path.exists(video_path):
+                 return None, None, "Could not read the uploaded video path."
+             bg_path = _coerce_bg_to_path(bg_image, temp_dir)  # may be None → preset is used
+
+             # reflect UI choices
+             matanyone_processor.stabilize = bool(stabilize)
+             try:
+                 matanyone_processor.preroll_frames = max(0, int(preroll_frames))
+             except Exception:
+                 pass
+
+             start_time = time.time()
+
+             if use_matanyone and MATANYONE_IMPORTED:
+                 if not matanyone_processor.initialized:
+                     matanyone_processor.initialize()
+
+                 if matanyone_processor.initialized and matanyone_processor.verified:
+                     alpha_video_path = matanyone_processor.process_video_optimized(video_path, temp_dir)
+                     if alpha_video_path is None:
+                         out = process_video_rembg_fallback(video_path, bg_path, bg_preset=bg_preset)
+                         method = "rembg (fallback after MatAnyone error)"
+                     else:
+                         out = composite_with_background(video_path, alpha_video_path, bg_path, bg_preset=bg_preset)
+                         method = f"MatAnyone (GPU: {CUDA_AVAILABLE})"
+                 else:
+                     out = process_video_rembg_fallback(video_path, bg_path, bg_preset=bg_preset)
+                     method = "rembg (MatAnyone not verified)"
+             else:
+                 out = process_video_rembg_fallback(video_path, bg_path, bg_preset=bg_preset)
+                 method = "rembg"
+
+             final_gpu = gpu_monitor.get_stats()
+             elapsed = time.time() - start_time
+             status = (
+                 f"✅ Processing complete\n"
+                 f"Method: {method}\n"
+                 f"Time: {elapsed:.2f}s\n"
+                 f"Output: {out}\n\n"
+                 f"GPU Stats:\n"
+                 f"• Mem: {final_gpu.get('memory_used', 0):.2f}GB / {final_gpu.get('memory_total', 0):.2f}GB"
+                 f" ({final_gpu.get('memory_percent', 0):.1f}%)\n"
+                 f"• Util: {final_gpu.get('gpu_util', 0)}%\n"
+                 f"• CUDA: {CUDA_AVAILABLE}"
+             )
+             return out, out, status
+
+     except Exception as e:
+         traceback.print_exc()
+         msg = (
+             f"❌ Error: {e}\n"
+             f"- MatAnyone imported: {MATANYONE_IMPORTED}\n"
+             f"- MatAnyone initialized: {matanyone_processor.initialized}\n"
+             f"- MatAnyone verified: {matanyone_processor.verified}\n"
+             f"- MatAnyone last_error: {matanyone_processor.last_error}\n"
+             f"- SAM2 imported: {SAM2_IMPORTED}\n"
+             f"- SAM2 verified: {SAM2_AVAILABLE}\n"
+             f"- rembg: {REMBG_AVAILABLE}\n"
+             f"- CUDA: {CUDA_AVAILABLE}\n"
+             f"(see server logs for traceback)"
+         )
+         return None, None, msg
+
+ def gradio_run_self_test(): return run_self_test()
+ def show_matanyone_diag():
+     try:
+         ok = matanyone_processor.initialized and matanyone_processor.verified
+         return "READY ✅" if ok else (matanyone_processor.last_error or "Not initialized yet")
+     except Exception as e:
+         return f"Diag error: {e}"
+
+ # =========================
+ # UI
+ # =========================
+ with gr.Blocks(title="Video Background Replacer - GPU Optimized", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# 🎬 Video Background Replacer (GPU Optimized)")
+     gr.Markdown("All green checks are earned by real tests. No guesses.")
+     gpu_status = f"✅ {GPU_NAME}" if CUDA_AVAILABLE else "❌ CPU Only"
+     matany_status = "✅ Module Imported" if MATANYONE_IMPORTED else "❌ Not Importable"
+     sam2_status = "✅ Verified" if SAM2_AVAILABLE else ("⚠️ Imported but unverified" if SAM2_IMPORTED else "❌ Not Ready")
+     rembg_status = "✅ Ready" if REMBG_AVAILABLE else "❌ Not Available"
+     torch_status = "✅ GPU" if CUDA_AVAILABLE else "❌ CPU"
+     status_html = f"""
+     <div style='padding: 15px; background: #f8f9fa; border-radius: 8px; margin-bottom: 20px; border-left: 4px solid #6c757d;'>
+         <h4 style='margin-top: 0;'>🖥️ System Status (verified)</h4>
+         <strong>GPU:</strong> {gpu_status}<br>
+         <strong>Device:</strong> {DEVICE}<br>
+         <strong>MatAnyone module:</strong> {matany_status}<br>
+         <strong>MatAnyone ready:</strong> {"✅ Yes" if getattr(matanyone_processor, "verified", False) else "❌ No"}<br>
+         <strong>SAM2:</strong> {sam2_status}<br>
+         <strong>rembg:</strong> {rembg_status}<br>
+         <strong>PyTorch:</strong> {torch_status}
+     </div>
+     """
+     gr.HTML(status_html)
+
+     with gr.Row():
+         with gr.Column():
+             video_input = gr.Video(label="📹 Input Video")
+             bg_input = gr.Image(label="🖼️ Background Image (optional)", type="filepath")
+             bg_preset = gr.Dropdown(
+                 label="🎨 Background Preset (if no image)",
+                 choices=["Office (Soft Gray)","Studio (Charcoal)","Nature (Green Tint)","Brand Blue","Plain Light"],
+                 value="Office (Soft Gray)",
+             )
+             use_matanyone = gr.Checkbox(label="🚀 Use MatAnyone (GPU accelerated, best quality)",
+                                         value=MATANYONE_IMPORTED, interactive=MATANYONE_IMPORTED)
+             stabilize = gr.Checkbox(label="🧱 Stabilize short clips (pre-roll first frame)",
+                                     value=os.getenv("MATANYONE_STABILIZE","true").lower()=="true")
+             preroll_frames = gr.Slider(label="Pre-roll frames", minimum=0, maximum=24, step=1,
+                                        value=int(os.getenv("MATANYONE_PREROLL_FRAMES","12")))
+             process_btn = gr.Button("🚀 Process Video", variant="primary")
+             gr.Markdown("### 🔎 Self-Verification"); selftest_btn = gr.Button("Run Self-Test")
+             selftest_out = gr.Textbox(label="Self-Test Report", lines=16)
+             gr.Markdown("### 🛠 MatAnyone Diagnostics"); mat_diag_btn = gr.Button("Show MatAnyone Diagnostics")
+             mat_diag_out = gr.Textbox(label="MatAnyone Last Error / Status", lines=6)
+         with gr.Column():
+             output_video = gr.Video(label="✨ Result")
+             download_file = gr.File(label="💾 Download")
+             status_text = gr.Textbox(label="📊 Status & Performance", lines=8)
+
+     process_btn.click(fn=gradio_interface_optimized,
+                       inputs=[video_input, bg_input, use_matanyone, bg_preset, stabilize, preroll_frames],
+                       outputs=[output_video, download_file, status_text])
+     selftest_btn.click(fn=gradio_run_self_test, inputs=[], outputs=[selftest_out])
+     mat_diag_btn.click(fn=show_matanyone_diag, inputs=[], outputs=[mat_diag_out])
+
+     gr.Markdown("---")
+     gr.Markdown("""
+     **Notes**
+     - K-Governor clamps/pads Top-K inside MatAnyone to prevent 'k out of range' crashes.
+     - Short-clip stabilizer pre-roll is trimmed out of alpha automatically.
+     - SAM2 shows ✅ only after a real micro-inference passes.
+     - FFmpeg/MoviePy, CUDA, and rembg are validated by actually running them.
+     """)
+
+ # =========================
+ # Proactive warmup at boot (before UI render)
+ # =========================
+ try:
+     if MATANYONE_IMPORTED and os.getenv("USE_MATANYONE","true").lower()=="true":
+         print("Warming up MatAnyone…")
+         matanyone_processor.initialize()
+         print("MatAnyone warmup complete.")
+ except Exception as e:
+     print(f"MatAnyone warmup failed (non-fatal): {e}")
+     traceback.print_exc()
+
+ # =========================
+ # Late re-sanitization for external .env overrides
+ # =========================
+ def _re_sanitize_threads():
+     for v in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
+         val = os.environ.get(v, "")
+         if not str(val).isdigit():
+             os.environ[v] = "2"
+             print(f"{v} had invalid value; reset to 2")
+
+ if os.getenv("STRICT_ENV_GUARD","1") in ("1","true","TRUE"):
+     _re_sanitize_threads()
+
+ # =========================
+ # Entrypoint / CLI self-test
+ # =========================
+ if __name__ == "__main__":
+     if "--self-test" in sys.argv:
+         report = run_self_test(); print(report)
+         exit_code = 0
+         for line in report.splitlines():
+             if line.startswith("❌"): exit_code = 2; break
+         sys.exit(exit_code)
+     print("\n" + "="*50)
+     print("🚀 Starting GPU-optimized Gradio app…")
+     print("URL: http://0.0.0.0:7860")
+     print(f"GPU Monitoring: {'Active' if CUDA_AVAILABLE else 'Disabled'}")
+     print("="*50 + "\n")
+     demo.launch(server_name="0.0.0.0", server_port=7860)
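
For readers skimming the diff, the K-Governor described in the docstring and in the UI notes reduces to a clamp-and-pad rule around torch.topk. Below is a minimal, standalone sketch of that rule; it is not part of the committed file, and the helper name demo_safe_topk is made up for illustration only: k is clamped into [1, n] before calling torch.topk, and the result is padded back to the requested k with -inf values and index 0, mirroring what the generated safe_ops.py does.

import torch

def demo_safe_topk(x: torch.Tensor, k: int, dim: int = -1):
    # Clamp k into [1, n] so torch.topk never sees an out-of-range k.
    n = x.size(dim)
    k_eff = max(1, min(int(k), int(n)))
    values, indices = torch.topk(x, k_eff, dim=dim)
    if k_eff < k:
        # Pad back to the caller's k: -inf for values, index 0 for indices.
        pad_shape = list(values.shape)
        pad_shape[dim] = k - k_eff
        values = torch.cat([values, values.new_full(pad_shape, float("-inf"))], dim=dim)
        indices = torch.cat([indices, indices.new_zeros(pad_shape)], dim=dim)
    return values, indices

# Asking for the top 5 of a length-3 tensor returns 3 real results plus 2 pad entries
# instead of raising a "selected index k out of range" error.
vals, idx = demo_safe_topk(torch.tensor([0.2, 0.9, 0.5]), k=5)
print(vals)  # tensor([0.9000, 0.5000, 0.2000, -inf, -inf])
print(idx)   # tensor([1, 2, 0, 0, 0])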