MogensR committed on
Commit
f4b2697
·
1 Parent(s): 1af72fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -26
app.py CHANGED
@@ -1,12 +1,9 @@
1
- hERE IS MY APP
2
-
3
-
4
  #!/usr/bin/env python3
5
  """
6
  Final Fixed Video Background Replacement
7
  Uses proper functions from utilities.py to avoid transparency issues
8
- NEW: Added GPU detection, model caching, batch processing support,
9
- and improved error handling
10
  """
11
  import sys
12
  import cv2
@@ -21,6 +18,7 @@
21
  from typing import Optional, Tuple, Dict, Any
22
  import logging
23
  from huggingface_hub import hf_hub_download
 
24
  # Import utilities - CRITICAL: Use these functions, don't duplicate!
25
  from utilities import (
26
  segment_person_hq,
@@ -31,25 +29,30 @@
31
  PROFESSIONAL_BACKGROUNDS,
32
  validate_video_file
33
  )
 
34
  # Import two-stage processor if available
35
  try:
36
  from two_stage_processor import TwoStageProcessor, CHROMA_PRESETS
37
  TWO_STAGE_AVAILABLE = True
38
  except ImportError:
39
  TWO_STAGE_AVAILABLE = False
 
40
  logging.basicConfig(level=logging.INFO)
41
  logger = logging.getLogger(__name__)
 
42
  # ============================================================================ #
43
  # OPTIMIZATION SETTINGS
44
  # ============================================================================ #
45
  KEYFRAME_INTERVAL = 5 # Process MatAnyone every 5th frame
46
  FRAME_SKIP = 1 # Process every frame (set to 2 for every other frame)
47
  MEMORY_CLEANUP_INTERVAL = 30 # Clean memory every 30 frames
 
48
  # ============================================================================ #
49
  # MODEL CACHING SYSTEM
50
  # ============================================================================ #
51
  CACHE_DIR = Path("/tmp/model_cache")
52
  CACHE_DIR.mkdir(exist_ok=True, parents=True)
 
53
  # ============================================================================ #
54
  # GLOBAL MODEL STATE
55
  # ============================================================================ #
@@ -59,14 +62,33 @@
59
  loading_lock = threading.Lock()
60
  two_stage_processor = None
61
  PROCESS_CANCELLED = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  # ============================================================================ #
63
  # SAM2 LOADER WITH VALIDATION
64
  # ============================================================================ #
65
- def load_sam2_predictor_fixed(device: str = "cuda", progress_callback: Optional[callable] = None) -> Any:
66
  """Load SAM2 with proper error handling and validation"""
67
  def _prog(pct: float, desc: str):
68
  if progress_callback:
69
  progress_callback(pct, desc)
 
70
  # Format progress info for display in the UI
71
  if "Frame" in desc and "|" in desc:
72
  parts = desc.split("|")
@@ -87,8 +109,10 @@ def _prog(pct: float, desc: str):
87
  f.write(display_text)
88
  except Exception as e:
89
  logger.warning(f"Error writing processing info: {e}")
 
90
  try:
91
  _prog(0.1, "Initializing SAM2...")
 
92
  # Download checkpoint with caching
93
  checkpoint_path = hf_hub_download(
94
  repo_id="facebook/sam2-hiera-large",
@@ -96,14 +120,18 @@ def _prog(pct: float, desc: str):
96
  cache_dir=str(CACHE_DIR / "sam2_checkpoint"),
97
  force_download=False
98
  )
 
99
  _prog(0.5, "SAM2 checkpoint downloaded, building model...")
 
100
  # Import and build
101
  from sam2.build_sam import build_sam2
102
  from sam2.sam2_image_predictor import SAM2ImagePredictor
 
103
  # Build model with explicit config
104
  sam2_model = build_sam2("sam2_hiera_l.yaml", checkpoint_path)
105
  sam2_model.to(device)
106
  predictor = SAM2ImagePredictor(sam2_model)
 
107
  # Test the predictor with dummy data
108
  _prog(0.8, "Testing SAM2 functionality...")
109
  test_image = np.zeros((256, 256, 3), dtype=np.uint8)
@@ -115,15 +143,19 @@ def _prog(pct: float, desc: str):
115
  point_labels=test_labels,
116
  multimask_output=False
117
  )
 
118
  if masks is None or len(masks) == 0:
119
  raise Exception("SAM2 predictor test failed - no masks generated")
 
120
  _prog(1.0, "SAM2 loaded and validated successfully!")
121
- logger.info("SAM2 predictor loaded and tested successfully")
122
  return predictor
 
123
  except Exception as e:
124
  logger.error(f"SAM2 loading failed: {str(e)}")
125
  logger.error(f"Full traceback: {traceback.format_exc()}")
126
  raise Exception(f"SAM2 loading failed: {str(e)}")
 
127
  # ============================================================================ #
128
  # MATANYONE LOADER WITH VALIDATION
129
  # ============================================================================ #
@@ -132,15 +164,19 @@ def load_matanyone_fixed(progress_callback: Optional[callable] = None) -> Any:
132
  def _prog(pct: float, desc: str):
133
  if progress_callback:
134
  progress_callback(pct, desc)
 
135
  try:
136
  _prog(0.2, "Loading MatAnyone...")
 
137
  from matanyone import InferenceCore
138
  processor = InferenceCore("PeiqingYang/MatAnyone")
 
139
  # Test MatAnyone with dummy data
140
  _prog(0.8, "Testing MatAnyone functionality...")
141
  test_image = np.zeros((256, 256, 3), dtype=np.uint8)
142
  test_mask = np.zeros((256, 256), dtype=np.uint8)
143
  test_mask[64:192, 64:192] = 255
 
144
  # Test the processor
145
  try:
146
  if hasattr(processor, 'process') or hasattr(processor, '__call__'):
@@ -149,13 +185,16 @@ def _prog(pct: float, desc: str):
149
  logger.warning("MatAnyone interface unclear, will use fallback refinement")
150
  except Exception as test_e:
151
  logger.warning(f"MatAnyone test failed: {test_e}, will use enhanced OpenCV")
 
152
  _prog(1.0, "MatAnyone loaded successfully!")
153
- logger.info("MatAnyone processor loaded successfully")
154
  return processor
 
155
  except Exception as e:
156
  logger.error(f"MatAnyone loading failed: {str(e)}")
157
  logger.error(f"Full traceback: {traceback.format_exc()}")
158
  raise Exception(f"MatAnyone loading failed: {str(e)}")
 
159
  # ============================================================================ #
160
  # MODEL MANAGEMENT FUNCTIONS
161
  # ============================================================================ #
@@ -165,53 +204,68 @@ def get_model_status() -> Dict[str, str]:
165
  return {
166
  'sam2': 'Ready' if sam2_predictor is not None else 'Not loaded',
167
  'matanyone': 'Ready' if matanyone_model is not None else 'Not loaded',
168
- 'validated': models_loaded
 
169
  }
 
170
  def get_cache_status() -> Dict[str, Any]:
171
  """Get current cache status"""
172
  return {
173
  "sam2_loaded": sam2_predictor is not None,
174
  "matanyone_loaded": matanyone_model is not None,
175
  "models_validated": models_loaded,
176
- "two_stage_available": TWO_STAGE_AVAILABLE
 
177
  }
 
178
  def load_models_with_validation(progress_callback: Optional[callable] = None) -> str:
179
  """Load models with comprehensive validation"""
180
  global sam2_predictor, matanyone_model, models_loaded, two_stage_processor, PROCESS_CANCELLED
 
181
  with loading_lock:
182
  if models_loaded and not PROCESS_CANCELLED:
183
  return "Models already loaded and validated"
 
184
  try:
185
  PROCESS_CANCELLED = False
186
  start_time = time.time()
187
- device = "cuda" if torch.cuda.is_available() else "cpu"
188
- logger.info(f"Starting model loading on {device}")
189
  if progress_callback:
190
- progress_callback(0.0, "Starting model loading...")
 
191
  # Load SAM2 with validation
192
- sam2_predictor = load_sam2_predictor_fixed(device=device, progress_callback=progress_callback)
 
193
  if PROCESS_CANCELLED:
194
  return "Model loading cancelled by user"
 
195
  # Load MatAnyone with validation
196
  matanyone_model = load_matanyone_fixed(progress_callback=progress_callback)
 
197
  if PROCESS_CANCELLED:
198
  return "Model loading cancelled by user"
 
199
  models_loaded = True
 
200
  # Initialize two-stage processor if available
201
  if TWO_STAGE_AVAILABLE:
202
  two_stage_processor = TwoStageProcessor(sam2_predictor, matanyone_model)
203
  logger.info("Two-stage processor initialized")
 
204
  load_time = time.time() - start_time
205
- message = f"SUCCESS: SAM2 + MatAnyone loaded and validated in {load_time:.1f}s"
206
  if TWO_STAGE_AVAILABLE:
207
  message += " (Two-stage mode available)"
208
  logger.info(message)
209
  return message
 
210
  except Exception as e:
211
  models_loaded = False
212
  error_msg = f"Model loading failed: {str(e)}"
213
  logger.error(error_msg)
214
  return error_msg
 
215
  # ============================================================================ #
216
  # MAIN VIDEO PROCESSING - USING UTILITIES FUNCTIONS
217
  # ============================================================================ #
@@ -227,21 +281,28 @@ def process_video_fixed(
227
  ) -> Tuple[Optional[str], str]:
228
  """Optimized video processing using proper functions from utilities"""
229
  global PROCESS_CANCELLED
 
230
  if PROCESS_CANCELLED:
231
  return None, "Processing cancelled by user"
 
232
  if not models_loaded:
233
  return None, "Models not loaded. Call load_models_with_validation() first."
 
234
  if not video_path or not os.path.exists(video_path):
235
  return None, f"Video file not found: {video_path}"
 
236
  # Validate video file
237
  is_valid, validation_msg = validate_video_file(video_path)
238
  if not is_valid:
239
  return None, f"Invalid video: {validation_msg}"
 
240
  def _prog(pct: float, desc: str):
241
  if PROCESS_CANCELLED:
242
  raise Exception("Processing cancelled by user")
 
243
  if progress_callback:
244
  progress_callback(pct, desc)
 
245
  # Update processing info file
246
  if "Frame" in desc and "|" in desc:
247
  parts = desc.split("|")
@@ -249,6 +310,7 @@ def _prog(pct: float, desc: str):
249
  time_info = parts[1].strip() if len(parts) > 1 else ""
250
  fps_info = parts[2].strip() if len(parts) > 2 else ""
251
  eta_info = parts[3].strip() if len(parts) > 3 else ""
 
252
  display_text = f"""📊 PROCESSING STATUS
253
  ━━━━━━━━━━━━━━━━━━━━━━━━━━
254
  🎬 {frame_info}
@@ -262,24 +324,31 @@ def _prog(pct: float, desc: str):
262
  f.write(display_text)
263
  except Exception as e:
264
  logger.warning(f"Error writing processing info: {e}")
 
265
  try:
266
- _prog(0.0, f"Starting {'TWO-STAGE' if use_two_stage else 'SINGLE-STAGE'} processing...")
 
267
  # Check if two-stage mode is requested
268
  if use_two_stage:
269
  if not TWO_STAGE_AVAILABLE:
270
  return None, "Two-stage mode not available. Please add two_stage_processor.py file."
 
271
  if two_stage_processor is None:
272
  return None, "Two-stage processor not initialized. Please reload models."
 
273
  _prog(0.05, "Starting TWO-STAGE green screen processing...")
 
274
  # Get video dimensions
275
  cap = cv2.VideoCapture(video_path)
276
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
277
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
278
  cap.release()
 
279
  # Prepare background
280
  if background_choice == "custom" and custom_background_path:
281
  if not os.path.exists(custom_background_path):
282
  return None, f"Custom background not found: {custom_background_path}"
 
283
  background = cv2.imread(custom_background_path)
284
  if background is None:
285
  return None, "Could not read custom background image."
@@ -291,11 +360,14 @@ def _prog(pct: float, desc: str):
291
  background_name = bg_config["name"]
292
  else:
293
  return None, f"Invalid background selection: {background_choice}"
 
294
  # Get chroma settings
295
  chroma_settings = CHROMA_PRESETS.get(chroma_preset, CHROMA_PRESETS['standard'])
 
296
  # Run two-stage pipeline
297
  timestamp = int(time.time())
298
  final_output = f"/tmp/twostage_final_{timestamp}.mp4"
 
299
  result, message = two_stage_processor.process_full_pipeline(
300
  video_path,
301
  background,
@@ -303,13 +375,17 @@ def _prog(pct: float, desc: str):
303
  chroma_settings=chroma_settings,
304
  progress_callback=_prog
305
  )
 
306
  if PROCESS_CANCELLED:
307
  return None, "Processing cancelled by user"
 
308
  if result is None:
309
  return None, message
 
310
  # Add audio back
311
  _prog(0.9, "Adding audio...")
312
  final_with_audio = f"/tmp/twostage_audio_{timestamp}.mp4"
 
313
  try:
314
  audio_cmd = (
315
  f'ffmpeg -y -i "{final_output}" -i "{video_path}" '
@@ -324,34 +400,46 @@ def _prog(pct: float, desc: str):
324
  except Exception as e:
325
  logger.warning(f"Audio processing error: {e}")
326
  final_with_audio = final_output # Fallback to video without audio
 
327
  _prog(1.0, "TWO-STAGE processing complete!")
 
328
  success_message = (
329
  f"TWO-STAGE Success!\n"
330
  f"Background: {background_name}\n"
331
  f"Method: Green Screen Chroma Key\n"
332
  f"Preset: {chroma_preset}\n"
333
- f"Quality: Professional cinema-grade"
 
334
  )
 
335
  return final_output, success_message
 
336
  # Single-stage processing
337
- _prog(0.05, "Starting SINGLE-STAGE processing...")
 
338
  cap = cv2.VideoCapture(video_path)
339
  if not cap.isOpened():
340
  return None, "Could not open video file."
 
341
  fps = cap.get(cv2.CAP_PROP_FPS)
342
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
343
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
344
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
345
  if total_frames == 0:
346
  return None, "Video appears to be empty."
 
347
  # Log video info
348
- logger.info(f"Video info: {frame_width}x{frame_height}, {fps}fps, {total_frames} frames")
 
349
  # Prepare background
350
  background = None
351
  background_name = ""
 
352
  if background_choice == "custom" and custom_background_path:
353
  if not os.path.exists(custom_background_path):
354
  return None, f"Custom background not found: {custom_background_path}"
 
355
  background = cv2.imread(custom_background_path)
356
  if background is None:
357
  return None, "Could not read custom background image."
@@ -363,11 +451,15 @@ def _prog(pct: float, desc: str):
363
  background_name = bg_config["name"]
364
  else:
365
  return None, f"Invalid background selection: {background_choice}"
 
366
  if background is None:
367
  return None, "Failed to create background."
 
368
  timestamp = int(time.time())
369
  fourcc = cv2.VideoWriter_fourcc(*'avc1') # H.264 for better compatibility
370
- _prog(0.1, f"Processing {total_frames} frames with {'TWO-STAGE' if use_two_stage else 'SINGLE-STAGE'} processing...")
 
 
371
  # Create temporary output for preview if needed
372
  if preview_mask or preview_greenscreen:
373
  temp_output = f"/tmp/preview_{timestamp}.mp4"
@@ -375,13 +467,17 @@ def _prog(pct: float, desc: str):
375
  else:
376
  final_path = f"/tmp/output_{timestamp}.mp4"
377
  final_writer = cv2.VideoWriter(final_path, fourcc, fps, (frame_width, frame_height))
 
378
  if not final_writer.isOpened():
379
  return None, "Could not create output video file."
 
380
  frame_count = 0
381
  successful_frames = 0
382
  last_refined_mask = None
 
383
  # Processing stats
384
  start_time = time.time()
 
385
  while True:
386
  if PROCESS_CANCELLED:
387
  cap.release()
@@ -389,13 +485,16 @@ def _prog(pct: float, desc: str):
389
  if os.path.exists(final_path):
390
  os.remove(final_path)
391
  return None, "Processing cancelled by user"
 
392
  ret, frame = cap.read()
393
  if not ret:
394
  break
 
395
  # Skip frames if FRAME_SKIP > 1
396
  if frame_count % FRAME_SKIP != 0:
397
  frame_count += 1
398
  continue
 
399
  try:
400
  # Update progress with detailed timing info and ETA
401
  elapsed_time = time.time() - start_time
@@ -403,13 +502,17 @@ def _prog(pct: float, desc: str):
403
  remaining_frames = total_frames - frame_count
404
  eta_seconds = remaining_frames / current_fps if current_fps > 0 else 0
405
  eta_display = f"{int(eta_seconds//60)}m {int(eta_seconds%60)}s" if eta_seconds > 60 else f"{int(eta_seconds)}s"
406
- progress_msg = f"Frame {frame_count + 1}/{total_frames} | {elapsed_time:.1f}s | {current_fps:.1f} fps | ETA: {eta_display}"
 
 
407
  # Log and display progress
408
  logger.info(progress_msg)
409
  _prog(0.1 + (frame_count / max(1, total_frames)) * 0.8, progress_msg)
 
410
  # CRITICAL: Use functions from utilities.py, not local implementations!
411
  # SAM2 segmentation using utilities function
412
  mask = segment_person_hq(frame, sam2_predictor)
 
413
  if preview_mask:
414
  # Save mask visualization
415
  mask_vis = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
@@ -417,15 +520,17 @@ def _prog(pct: float, desc: str):
417
  final_writer.write(mask_vis)
418
  frame_count += 1
419
  continue
 
420
  # MatAnyone refinement on keyframes using utilities function
421
  if (frame_count % KEYFRAME_INTERVAL == 0) or (last_refined_mask is None):
422
  refined_mask = refine_mask_hq(frame, mask, matanyone_model)
423
  last_refined_mask = refined_mask.copy()
424
- logger.info(f"Keyframe refinement at frame {frame_count}")
425
  else:
426
  # Blend SAM2 mask with last refined mask for temporal smoothness
427
  alpha = 0.7
428
  refined_mask = cv2.addWeighted(mask, alpha, last_refined_mask, 1-alpha, 0)
 
429
  if preview_greenscreen:
430
  # Create green screen preview
431
  green_bg = np.zeros_like(frame)
@@ -437,15 +542,19 @@ def _prog(pct: float, desc: str):
437
  final_writer.write(preview_frame.astype(np.uint8))
438
  frame_count += 1
439
  continue
 
440
  # CRITICAL: Use replace_background_hq from utilities which has the transparency fix!
441
  result_frame = replace_background_hq(frame, refined_mask, background)
442
  final_writer.write(result_frame)
443
  successful_frames += 1
 
444
  except Exception as frame_error:
445
  logger.warning(f"Error processing frame {frame_count}: {frame_error}")
446
  # Write original frame if processing fails
447
  final_writer.write(frame)
 
448
  frame_count += 1
 
449
  # Memory management
450
  if frame_count % MEMORY_CLEANUP_INTERVAL == 0:
451
  gc.collect()
@@ -454,25 +563,32 @@ def _prog(pct: float, desc: str):
454
  elapsed = time.time() - start_time
455
  fps_actual = frame_count / elapsed
456
  eta = (total_frames - frame_count) / fps_actual if fps_actual > 0 else 0
457
- logger.info(f"Progress: {frame_count}/{total_frames}, FPS: {fps_actual:.1f}, ETA: {eta:.0f}s")
 
458
  cap.release()
459
  final_writer.release()
 
460
  if PROCESS_CANCELLED:
461
  if os.path.exists(final_path):
462
  os.remove(final_path)
463
  return None, "Processing cancelled by user"
 
464
  if successful_frames == 0:
465
  return None, "No frames were processed successfully with AI."
 
466
  # Calculate processing stats
467
  total_time = time.time() - start_time
468
  avg_fps = frame_count / total_time if total_time > 0 else 0
 
469
  _prog(0.9, "Finalizing output...")
 
470
  if preview_mask or preview_greenscreen:
471
  final_output = temp_output
472
  else:
473
  # Add audio back for final output
474
  _prog(0.9, "Adding audio...")
475
  final_output = f"/tmp/final_{timestamp}.mp4"
 
476
  try:
477
  audio_cmd = (
478
  f'ffmpeg -y -i "{final_path}" -i "{video_path}" '
@@ -486,13 +602,16 @@ def _prog(pct: float, desc: str):
486
  except Exception as e:
487
  logger.warning(f"Audio processing error: {e}")
488
  shutil.copy2(final_path, final_output)
 
489
  # Cleanup
490
  try:
491
  if os.path.exists(final_path):
492
  os.remove(final_path)
493
  except Exception as e:
494
  logger.warning(f"Cleanup error: {e}")
 
495
  _prog(1.0, "Processing complete!")
 
496
  success_message = (
497
  f"Success!\n"
498
  f"Background: {background_name}\n"
@@ -502,12 +621,16 @@ def _prog(pct: float, desc: str):
502
  f"Processing time: {total_time:.1f}s\n"
503
  f"Average FPS: {avg_fps:.1f}\n"
504
  f"Keyframe interval: {KEYFRAME_INTERVAL}\n"
505
- f"Mode: {'TWO-STAGE' if use_two_stage else 'SINGLE-STAGE'}"
 
506
  )
 
507
  return final_output, success_message
 
508
  except Exception as e:
509
  logger.error(f"Processing error: {traceback.format_exc()}")
510
  return None, f"Processing Error: {str(e)}"
 
511
  # ============================================================================ #
512
  # MAIN - IMPORT UI COMPONENTS
513
  # ============================================================================ #
@@ -517,13 +640,18 @@ def main():
517
  print(f"Keyframe interval: {KEYFRAME_INTERVAL} frames")
518
  print(f"Frame skip: {FRAME_SKIP} (1=all frames, 2=every other)")
519
  print(f"Two-stage mode: {'AVAILABLE' if TWO_STAGE_AVAILABLE else 'NOT AVAILABLE'}")
 
520
  print("Loading UI components...")
 
521
  # Import UI components
522
  from ui_components import create_interface
 
523
  os.makedirs("/tmp/MyAvatar/My_Videos/", exist_ok=True)
524
  CACHE_DIR.mkdir(exist_ok=True, parents=True)
 
525
  print("Creating interface...")
526
  demo = create_interface()
 
527
  print("Launching...")
528
  demo.launch(
529
  server_name="0.0.0.0",
@@ -533,10 +661,10 @@ def main():
533
  debug=True,
534
  enable_queue=True
535
  )
 
536
  except Exception as e:
537
  logger.error(f"Startup failed: {e}")
538
  print(f"Startup failed: {e}")
 
539
  if __name__ == "__main__":
540
  main()
541
-
542
- pLEASE UPDATE
 
 
 
 
1
  #!/usr/bin/env python3
2
  """
3
  Final Fixed Video Background Replacement
4
  Uses proper functions from utilities.py to avoid transparency issues
5
+ NEW: Added automatic device detection for Hugging Face Spaces compatibility,
6
+ improved error handling, and better resource management
7
  """
8
  import sys
9
  import cv2
 
18
  from typing import Optional, Tuple, Dict, Any
19
  import logging
20
  from huggingface_hub import hf_hub_download
21
+
22
  # Import utilities - CRITICAL: Use these functions, don't duplicate!
23
  from utilities import (
24
  segment_person_hq,
 
29
  PROFESSIONAL_BACKGROUNDS,
30
  validate_video_file
31
  )
32
+
33
  # Import two-stage processor if available
34
  try:
35
  from two_stage_processor import TwoStageProcessor, CHROMA_PRESETS
36
  TWO_STAGE_AVAILABLE = True
37
  except ImportError:
38
  TWO_STAGE_AVAILABLE = False
39
+
40
  logging.basicConfig(level=logging.INFO)
41
  logger = logging.getLogger(__name__)
42
+
43
  # ============================================================================ #
44
  # OPTIMIZATION SETTINGS
45
  # ============================================================================ #
46
  KEYFRAME_INTERVAL = 5 # Process MatAnyone every 5th frame
47
  FRAME_SKIP = 1 # Process every frame (set to 2 for every other frame)
48
  MEMORY_CLEANUP_INTERVAL = 30 # Clean memory every 30 frames
49
+
50
  # ============================================================================ #
51
  # MODEL CACHING SYSTEM
52
  # ============================================================================ #
53
  CACHE_DIR = Path("/tmp/model_cache")
54
  CACHE_DIR.mkdir(exist_ok=True, parents=True)
55
+
56
  # ============================================================================ #
57
  # GLOBAL MODEL STATE
58
  # ============================================================================ #
 
62
  loading_lock = threading.Lock()
63
  two_stage_processor = None
64
  PROCESS_CANCELLED = False
65
+
66
+ # ============================================================================ #
67
+ # DEVICE DETECTION FOR HUGGING FACE SPACES
68
+ # ============================================================================ #
69
def get_device():
    """Choose the compute device used for model loading.

    Returns:
        "cuda" when ``torch.cuda.is_available()`` reports a usable GPU,
        otherwise "cpu". The decision is written to the module logger.
    """
    # Guard clause: without CUDA there is nothing further to inspect.
    if not torch.cuda.is_available():
        logger.info("Using CPU (no GPU available)")
        return "cpu"

    # The name of CUDA device index 0 is looked up only for the log line.
    gpu_name = torch.cuda.get_device_name(0)
    logger.info(f"Using GPU: {gpu_name}")
    return "cuda"
79
+
80
+ # Set the device globally
81
+ DEVICE = get_device()
82
+
83
  # ============================================================================ #
84
  # SAM2 LOADER WITH VALIDATION
85
  # ============================================================================ #
86
+ def load_sam2_predictor_fixed(device: str = DEVICE, progress_callback: Optional[callable] = None) -> Any:
87
  """Load SAM2 with proper error handling and validation"""
88
  def _prog(pct: float, desc: str):
89
  if progress_callback:
90
  progress_callback(pct, desc)
91
+
92
  # Format progress info for display in the UI
93
  if "Frame" in desc and "|" in desc:
94
  parts = desc.split("|")
 
109
  f.write(display_text)
110
  except Exception as e:
111
  logger.warning(f"Error writing processing info: {e}")
112
+
113
  try:
114
  _prog(0.1, "Initializing SAM2...")
115
+
116
  # Download checkpoint with caching
117
  checkpoint_path = hf_hub_download(
118
  repo_id="facebook/sam2-hiera-large",
 
120
  cache_dir=str(CACHE_DIR / "sam2_checkpoint"),
121
  force_download=False
122
  )
123
+
124
  _prog(0.5, "SAM2 checkpoint downloaded, building model...")
125
+
126
  # Import and build
127
  from sam2.build_sam import build_sam2
128
  from sam2.sam2_image_predictor import SAM2ImagePredictor
129
+
130
  # Build model with explicit config
131
  sam2_model = build_sam2("sam2_hiera_l.yaml", checkpoint_path)
132
  sam2_model.to(device)
133
  predictor = SAM2ImagePredictor(sam2_model)
134
+
135
  # Test the predictor with dummy data
136
  _prog(0.8, "Testing SAM2 functionality...")
137
  test_image = np.zeros((256, 256, 3), dtype=np.uint8)
 
143
  point_labels=test_labels,
144
  multimask_output=False
145
  )
146
+
147
  if masks is None or len(masks) == 0:
148
  raise Exception("SAM2 predictor test failed - no masks generated")
149
+
150
  _prog(1.0, "SAM2 loaded and validated successfully!")
151
+ logger.info(f"SAM2 predictor loaded and tested successfully on {device}")
152
  return predictor
153
+
154
  except Exception as e:
155
  logger.error(f"SAM2 loading failed: {str(e)}")
156
  logger.error(f"Full traceback: {traceback.format_exc()}")
157
  raise Exception(f"SAM2 loading failed: {str(e)}")
158
+
159
  # ============================================================================ #
160
  # MATANYONE LOADER WITH VALIDATION
161
  # ============================================================================ #
 
164
  def _prog(pct: float, desc: str):
165
  if progress_callback:
166
  progress_callback(pct, desc)
167
+
168
  try:
169
  _prog(0.2, "Loading MatAnyone...")
170
+
171
  from matanyone import InferenceCore
172
  processor = InferenceCore("PeiqingYang/MatAnyone")
173
+
174
  # Test MatAnyone with dummy data
175
  _prog(0.8, "Testing MatAnyone functionality...")
176
  test_image = np.zeros((256, 256, 3), dtype=np.uint8)
177
  test_mask = np.zeros((256, 256), dtype=np.uint8)
178
  test_mask[64:192, 64:192] = 255
179
+
180
  # Test the processor
181
  try:
182
  if hasattr(processor, 'process') or hasattr(processor, '__call__'):
 
185
  logger.warning("MatAnyone interface unclear, will use fallback refinement")
186
  except Exception as test_e:
187
  logger.warning(f"MatAnyone test failed: {test_e}, will use enhanced OpenCV")
188
+
189
  _prog(1.0, "MatAnyone loaded successfully!")
190
+ logger.info(f"MatAnyone processor loaded successfully on {DEVICE}")
191
  return processor
192
+
193
  except Exception as e:
194
  logger.error(f"MatAnyone loading failed: {str(e)}")
195
  logger.error(f"Full traceback: {traceback.format_exc()}")
196
  raise Exception(f"MatAnyone loading failed: {str(e)}")
197
+
198
  # ============================================================================ #
199
  # MODEL MANAGEMENT FUNCTIONS
200
  # ============================================================================ #
 
204
  return {
205
  'sam2': 'Ready' if sam2_predictor is not None else 'Not loaded',
206
  'matanyone': 'Ready' if matanyone_model is not None else 'Not loaded',
207
+ 'validated': models_loaded,
208
+ 'device': DEVICE
209
  }
210
+
211
  def get_cache_status() -> Dict[str, Any]:
212
  """Get current cache status"""
213
  return {
214
  "sam2_loaded": sam2_predictor is not None,
215
  "matanyone_loaded": matanyone_model is not None,
216
  "models_validated": models_loaded,
217
+ "two_stage_available": TWO_STAGE_AVAILABLE,
218
+ "device": DEVICE
219
  }
220
+
221
  def load_models_with_validation(progress_callback: Optional[callable] = None) -> str:
222
  """Load models with comprehensive validation"""
223
  global sam2_predictor, matanyone_model, models_loaded, two_stage_processor, PROCESS_CANCELLED
224
+
225
  with loading_lock:
226
  if models_loaded and not PROCESS_CANCELLED:
227
  return "Models already loaded and validated"
228
+
229
  try:
230
  PROCESS_CANCELLED = False
231
  start_time = time.time()
232
+ logger.info(f"Starting model loading on {DEVICE}")
233
+
234
  if progress_callback:
235
+ progress_callback(0.0, f"Starting model loading on {DEVICE}...")
236
+
237
  # Load SAM2 with validation
238
+ sam2_predictor = load_sam2_predictor_fixed(device=DEVICE, progress_callback=progress_callback)
239
+
240
  if PROCESS_CANCELLED:
241
  return "Model loading cancelled by user"
242
+
243
  # Load MatAnyone with validation
244
  matanyone_model = load_matanyone_fixed(progress_callback=progress_callback)
245
+
246
  if PROCESS_CANCELLED:
247
  return "Model loading cancelled by user"
248
+
249
  models_loaded = True
250
+
251
  # Initialize two-stage processor if available
252
  if TWO_STAGE_AVAILABLE:
253
  two_stage_processor = TwoStageProcessor(sam2_predictor, matanyone_model)
254
  logger.info("Two-stage processor initialized")
255
+
256
  load_time = time.time() - start_time
257
+ message = f"SUCCESS: SAM2 + MatAnyone loaded and validated in {load_time:.1f}s on {DEVICE}"
258
  if TWO_STAGE_AVAILABLE:
259
  message += " (Two-stage mode available)"
260
  logger.info(message)
261
  return message
262
+
263
  except Exception as e:
264
  models_loaded = False
265
  error_msg = f"Model loading failed: {str(e)}"
266
  logger.error(error_msg)
267
  return error_msg
268
+
269
  # ============================================================================ #
270
  # MAIN VIDEO PROCESSING - USING UTILITIES FUNCTIONS
271
  # ============================================================================ #
 
281
  ) -> Tuple[Optional[str], str]:
282
  """Optimized video processing using proper functions from utilities"""
283
  global PROCESS_CANCELLED
284
+
285
  if PROCESS_CANCELLED:
286
  return None, "Processing cancelled by user"
287
+
288
  if not models_loaded:
289
  return None, "Models not loaded. Call load_models_with_validation() first."
290
+
291
  if not video_path or not os.path.exists(video_path):
292
  return None, f"Video file not found: {video_path}"
293
+
294
  # Validate video file
295
  is_valid, validation_msg = validate_video_file(video_path)
296
  if not is_valid:
297
  return None, f"Invalid video: {validation_msg}"
298
+
299
  def _prog(pct: float, desc: str):
300
  if PROCESS_CANCELLED:
301
  raise Exception("Processing cancelled by user")
302
+
303
  if progress_callback:
304
  progress_callback(pct, desc)
305
+
306
  # Update processing info file
307
  if "Frame" in desc and "|" in desc:
308
  parts = desc.split("|")
 
310
  time_info = parts[1].strip() if len(parts) > 1 else ""
311
  fps_info = parts[2].strip() if len(parts) > 2 else ""
312
  eta_info = parts[3].strip() if len(parts) > 3 else ""
313
+
314
  display_text = f"""📊 PROCESSING STATUS
315
  ━━━━━━━━━━━━━━━━━━━━━━━━━━
316
  🎬 {frame_info}
 
324
  f.write(display_text)
325
  except Exception as e:
326
  logger.warning(f"Error writing processing info: {e}")
327
+
328
  try:
329
+ _prog(0.0, f"Starting {'TWO-STAGE' if use_two_stage else 'SINGLE-STAGE'} processing on {DEVICE}...")
330
+
331
  # Check if two-stage mode is requested
332
  if use_two_stage:
333
  if not TWO_STAGE_AVAILABLE:
334
  return None, "Two-stage mode not available. Please add two_stage_processor.py file."
335
+
336
  if two_stage_processor is None:
337
  return None, "Two-stage processor not initialized. Please reload models."
338
+
339
  _prog(0.05, "Starting TWO-STAGE green screen processing...")
340
+
341
  # Get video dimensions
342
  cap = cv2.VideoCapture(video_path)
343
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
344
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
345
  cap.release()
346
+
347
  # Prepare background
348
  if background_choice == "custom" and custom_background_path:
349
  if not os.path.exists(custom_background_path):
350
  return None, f"Custom background not found: {custom_background_path}"
351
+
352
  background = cv2.imread(custom_background_path)
353
  if background is None:
354
  return None, "Could not read custom background image."
 
360
  background_name = bg_config["name"]
361
  else:
362
  return None, f"Invalid background selection: {background_choice}"
363
+
364
  # Get chroma settings
365
  chroma_settings = CHROMA_PRESETS.get(chroma_preset, CHROMA_PRESETS['standard'])
366
+
367
  # Run two-stage pipeline
368
  timestamp = int(time.time())
369
  final_output = f"/tmp/twostage_final_{timestamp}.mp4"
370
+
371
  result, message = two_stage_processor.process_full_pipeline(
372
  video_path,
373
  background,
 
375
  chroma_settings=chroma_settings,
376
  progress_callback=_prog
377
  )
378
+
379
  if PROCESS_CANCELLED:
380
  return None, "Processing cancelled by user"
381
+
382
  if result is None:
383
  return None, message
384
+
385
  # Add audio back
386
  _prog(0.9, "Adding audio...")
387
  final_with_audio = f"/tmp/twostage_audio_{timestamp}.mp4"
388
+
389
  try:
390
  audio_cmd = (
391
  f'ffmpeg -y -i "{final_output}" -i "{video_path}" '
 
400
  except Exception as e:
401
  logger.warning(f"Audio processing error: {e}")
402
  final_with_audio = final_output # Fallback to video without audio
403
+
404
  _prog(1.0, "TWO-STAGE processing complete!")
405
+
406
  success_message = (
407
  f"TWO-STAGE Success!\n"
408
  f"Background: {background_name}\n"
409
  f"Method: Green Screen Chroma Key\n"
410
  f"Preset: {chroma_preset}\n"
411
+ f"Quality: Professional cinema-grade\n"
412
+ f"Device: {DEVICE}"
413
  )
414
+
415
  return final_output, success_message
416
+
417
  # Single-stage processing
418
+ _prog(0.05, f"Starting SINGLE-STAGE processing on {DEVICE}...")
419
+
420
  cap = cv2.VideoCapture(video_path)
421
  if not cap.isOpened():
422
  return None, "Could not open video file."
423
+
424
  fps = cap.get(cv2.CAP_PROP_FPS)
425
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
426
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
427
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
428
+
429
  if total_frames == 0:
430
  return None, "Video appears to be empty."
431
+
432
  # Log video info
433
+ logger.info(f"Video info: {frame_width}x{frame_height}, {fps}fps, {total_frames} frames, processing on {DEVICE}")
434
+
435
  # Prepare background
436
  background = None
437
  background_name = ""
438
+
439
  if background_choice == "custom" and custom_background_path:
440
  if not os.path.exists(custom_background_path):
441
  return None, f"Custom background not found: {custom_background_path}"
442
+
443
  background = cv2.imread(custom_background_path)
444
  if background is None:
445
  return None, "Could not read custom background image."
 
451
  background_name = bg_config["name"]
452
  else:
453
  return None, f"Invalid background selection: {background_choice}"
454
+
455
  if background is None:
456
  return None, "Failed to create background."
457
+
458
  timestamp = int(time.time())
459
  fourcc = cv2.VideoWriter_fourcc(*'avc1') # H.264 for better compatibility
460
+
461
+ _prog(0.1, f"Processing {total_frames} frames with {'TWO-STAGE' if use_two_stage else 'SINGLE-STAGE'} processing on {DEVICE}...")
462
+
463
  # Create temporary output for preview if needed
464
  if preview_mask or preview_greenscreen:
465
  temp_output = f"/tmp/preview_{timestamp}.mp4"
 
467
  else:
468
  final_path = f"/tmp/output_{timestamp}.mp4"
469
  final_writer = cv2.VideoWriter(final_path, fourcc, fps, (frame_width, frame_height))
470
+
471
  if not final_writer.isOpened():
472
  return None, "Could not create output video file."
473
+
474
  frame_count = 0
475
  successful_frames = 0
476
  last_refined_mask = None
477
+
478
  # Processing stats
479
  start_time = time.time()
480
+
481
  while True:
482
  if PROCESS_CANCELLED:
483
  cap.release()
 
485
  if os.path.exists(final_path):
486
  os.remove(final_path)
487
  return None, "Processing cancelled by user"
488
+
489
  ret, frame = cap.read()
490
  if not ret:
491
  break
492
+
493
  # Skip frames if FRAME_SKIP > 1
494
  if frame_count % FRAME_SKIP != 0:
495
  frame_count += 1
496
  continue
497
+
498
  try:
499
  # Update progress with detailed timing info and ETA
500
  elapsed_time = time.time() - start_time
 
502
  remaining_frames = total_frames - frame_count
503
  eta_seconds = remaining_frames / current_fps if current_fps > 0 else 0
504
  eta_display = f"{int(eta_seconds//60)}m {int(eta_seconds%60)}s" if eta_seconds > 60 else f"{int(eta_seconds)}s"
505
+
506
+ progress_msg = f"Frame {frame_count + 1}/{total_frames} | {elapsed_time:.1f}s | {current_fps:.1f} fps | ETA: {eta_display} | Device: {DEVICE}"
507
+
508
  # Log and display progress
509
  logger.info(progress_msg)
510
  _prog(0.1 + (frame_count / max(1, total_frames)) * 0.8, progress_msg)
511
+
512
  # CRITICAL: Use functions from utilities.py, not local implementations!
513
  # SAM2 segmentation using utilities function
514
  mask = segment_person_hq(frame, sam2_predictor)
515
+
516
  if preview_mask:
517
  # Save mask visualization
518
  mask_vis = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
 
520
  final_writer.write(mask_vis)
521
  frame_count += 1
522
  continue
523
+
524
  # MatAnyone refinement on keyframes using utilities function
525
  if (frame_count % KEYFRAME_INTERVAL == 0) or (last_refined_mask is None):
526
  refined_mask = refine_mask_hq(frame, mask, matanyone_model)
527
  last_refined_mask = refined_mask.copy()
528
+ logger.info(f"Keyframe refinement at frame {frame_count} on {DEVICE}")
529
  else:
530
  # Blend SAM2 mask with last refined mask for temporal smoothness
531
  alpha = 0.7
532
  refined_mask = cv2.addWeighted(mask, alpha, last_refined_mask, 1-alpha, 0)
533
+
534
  if preview_greenscreen:
535
  # Create green screen preview
536
  green_bg = np.zeros_like(frame)
 
542
  final_writer.write(preview_frame.astype(np.uint8))
543
  frame_count += 1
544
  continue
545
+
546
  # CRITICAL: Use replace_background_hq from utilities which has the transparency fix!
547
  result_frame = replace_background_hq(frame, refined_mask, background)
548
  final_writer.write(result_frame)
549
  successful_frames += 1
550
+
551
  except Exception as frame_error:
552
  logger.warning(f"Error processing frame {frame_count}: {frame_error}")
553
  # Write original frame if processing fails
554
  final_writer.write(frame)
555
+
556
  frame_count += 1
557
+
558
  # Memory management
559
  if frame_count % MEMORY_CLEANUP_INTERVAL == 0:
560
  gc.collect()
 
563
  elapsed = time.time() - start_time
564
  fps_actual = frame_count / elapsed
565
  eta = (total_frames - frame_count) / fps_actual if fps_actual > 0 else 0
566
+ logger.info(f"Progress: {frame_count}/{total_frames}, FPS: {fps_actual:.1f}, ETA: {eta:.0f}s, Device: {DEVICE}")
567
+
568
  cap.release()
569
  final_writer.release()
570
+
571
  if PROCESS_CANCELLED:
572
  if os.path.exists(final_path):
573
  os.remove(final_path)
574
  return None, "Processing cancelled by user"
575
+
576
  if successful_frames == 0:
577
  return None, "No frames were processed successfully with AI."
578
+
579
  # Calculate processing stats
580
  total_time = time.time() - start_time
581
  avg_fps = frame_count / total_time if total_time > 0 else 0
582
+
583
  _prog(0.9, "Finalizing output...")
584
+
585
  if preview_mask or preview_greenscreen:
586
  final_output = temp_output
587
  else:
588
  # Add audio back for final output
589
  _prog(0.9, "Adding audio...")
590
  final_output = f"/tmp/final_{timestamp}.mp4"
591
+
592
  try:
593
  audio_cmd = (
594
  f'ffmpeg -y -i "{final_path}" -i "{video_path}" '
 
602
  except Exception as e:
603
  logger.warning(f"Audio processing error: {e}")
604
  shutil.copy2(final_path, final_output)
605
+
606
  # Cleanup
607
  try:
608
  if os.path.exists(final_path):
609
  os.remove(final_path)
610
  except Exception as e:
611
  logger.warning(f"Cleanup error: {e}")
612
+
613
  _prog(1.0, "Processing complete!")
614
+
615
  success_message = (
616
  f"Success!\n"
617
  f"Background: {background_name}\n"
 
621
  f"Processing time: {total_time:.1f}s\n"
622
  f"Average FPS: {avg_fps:.1f}\n"
623
  f"Keyframe interval: {KEYFRAME_INTERVAL}\n"
624
+ f"Mode: {'TWO-STAGE' if use_two_stage else 'SINGLE-STAGE'}\n"
625
+ f"Device: {DEVICE}"
626
  )
627
+
628
  return final_output, success_message
629
+
630
  except Exception as e:
631
  logger.error(f"Processing error: {traceback.format_exc()}")
632
  return None, f"Processing Error: {str(e)}"
633
+
634
  # ============================================================================ #
635
  # MAIN - IMPORT UI COMPONENTS
636
  # ============================================================================ #
 
640
  print(f"Keyframe interval: {KEYFRAME_INTERVAL} frames")
641
  print(f"Frame skip: {FRAME_SKIP} (1=all frames, 2=every other)")
642
  print(f"Two-stage mode: {'AVAILABLE' if TWO_STAGE_AVAILABLE else 'NOT AVAILABLE'}")
643
+ print(f"Device: {DEVICE}")
644
  print("Loading UI components...")
645
+
646
  # Import UI components
647
  from ui_components import create_interface
648
+
649
  os.makedirs("/tmp/MyAvatar/My_Videos/", exist_ok=True)
650
  CACHE_DIR.mkdir(exist_ok=True, parents=True)
651
+
652
  print("Creating interface...")
653
  demo = create_interface()
654
+
655
  print("Launching...")
656
  demo.launch(
657
  server_name="0.0.0.0",
 
661
  debug=True,
662
  enable_queue=True
663
  )
664
+
665
  except Exception as e:
666
  logger.error(f"Startup failed: {e}")
667
  print(f"Startup failed: {e}")
668
+
669
  # Script entry point: call main() only when this module is executed directly,
  # not when it is imported by another module.
  if __name__ == "__main__":
670
  main()