|
|
|
|
|
|
|
|
import os |
|
|
# Threading / CUDA allocator environment. These are set before torch, cv2 and
# numpy are imported further down so those libraries see the values when they
# initialise; keep these statements above those imports.

# Drop any inherited OMP_NUM_THREADS so the OpenMP runtime uses its own
# default. NOTE(review): presumably removed because the host injects an
# unsuitable value — confirm the intent.
os.environ.pop("OMP_NUM_THREADS", None)

# Pin MKL / OpenBLAS / NumExpr to one thread unless the environment already
# configures them (setdefault never overrides an existing value).
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

# max_split_size_mb stops PyTorch's CUDA caching allocator from splitting
# blocks larger than this many MB, which can reduce fragmentation.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:1024")

# Keep CUDA kernel launches asynchronous unless explicitly overridden.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "0")
|
|
|
|
|
|
|
|
""" |
|
|
FIXED Single-Stage Video Background Replacement with Working SAM2 + MatAnyone |
|
|
Core processing functions with proper AI model integration |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
import torch |
|
|
import traceback |
|
|
import time |
|
|
import shutil |
|
|
import gc |
|
|
import threading |
|
|
from typing import Optional |
|
|
import logging |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
from utilities import * |
|
|
|
|
|
# Root logging at INFO so the progress/info messages in this module are shown.
logging.basicConfig(level=logging.INFO)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)

# On-disk cache for downloaded checkpoints and saved state_dicts; created at
# import time. NOTE(review): /tmp is typically shared and world-writable, so
# anything loaded back from here should be treated as untrusted.
CACHE_DIR = Path("/tmp/model_cache")
CACHE_DIR.mkdir(exist_ok=True, parents=True)
|
|
|
|
|
def save_model_weights(model, model_name: str):
    """Persist a model's ``state_dict`` (never the whole object) to the cache.

    When *model* wraps the real network in a ``.model`` attribute, that inner
    object's weights are saved; otherwise *model* itself must provide
    ``state_dict()``. The file goes to ``CACHE_DIR / f"{model_name}_weights.pth"``.

    Args:
        model: the module, or a wrapper exposing the module as ``.model``.
        model_name: prefix used for the cache file name.

    Returns:
        True if the weights were written; False if no state_dict was found or
        saving raised (the error is logged, never propagated).
    """
    try:
        target_file = CACHE_DIR / f"{model_name}_weights.pth"
        if hasattr(model, 'model'):
            source = model.model
        elif hasattr(model, 'state_dict'):
            source = model
        else:
            logger.warning(f"Cannot save weights for {model_name} - no state_dict found")
            return False
        torch.save(source.state_dict(), target_file)
        logger.info(f"Model weights for {model_name} cached successfully")
        return True
    except Exception as err:
        logger.warning(f"Failed to cache {model_name} weights: {err}")
        return False
|
|
|
|
|
def load_model_weights(model, model_name: str):
    """Load cached weights from ``CACHE_DIR`` into an already-built model.

    The cache file is ``CACHE_DIR / f"{model_name}_weights.pth"``. Weights go
    into ``model.model`` when that attribute exists, otherwise into *model*.

    Args:
        model: the module, or a wrapper exposing the module as ``.model``.
        model_name: prefix used for the cache file name.

    Returns:
        True if weights were loaded. Otherwise False: the file is missing, the
        model cannot accept a state_dict, or loading raised (logged as a
        warning, never propagated).
    """
    try:
        cache_path = CACHE_DIR / f"{model_name}_weights.pth"
        if not cache_path.exists():
            return False

        # Pick the target before touching the file, so nothing is unpickled
        # for models that cannot take a state_dict anyway.
        if hasattr(model, 'model'):
            target = model.model
        elif hasattr(model, 'load_state_dict'):
            target = model
        else:
            return False

        try:
            # The cache lives in a shared /tmp directory. weights_only=True
            # limits unpickling to tensors and plain containers, so a planted
            # file cannot execute arbitrary code.
            weights = torch.load(cache_path, map_location='cpu', weights_only=True)
        except TypeError:
            # torch < 1.13 has no weights_only parameter.
            weights = torch.load(cache_path, map_location='cpu')

        target.load_state_dict(weights)
        logger.info(f"Model weights for {model_name} loaded from cache")
        return True

    except Exception as e:
        logger.warning(f"Failed to load {model_name} weights from cache: {e}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_sam2_predictor_fixed(device: str = "cuda", progress_callback=None):
    """Download, build and smoke-test a SAM2 (hiera-large) image predictor.

    The checkpoint ``sam2_hiera_large.pt`` is fetched from the
    ``facebook/sam2-hiera-large`` Hugging Face repo into
    ``CACHE_DIR / "sam2_checkpoint"``. It is built with ``sam2_hiera_l.yaml``,
    moved to *device* and wrapped in ``SAM2ImagePredictor``. The predictor is
    then run once on a blank 256x256 image to check that it returns masks.

    Args:
        device: torch device string. A CUDA device is replaced by "cpu" (with a
            warning) when CUDA is not available.
        progress_callback: optional ``callable(fraction, description)``.

    Returns:
        The validated ``SAM2ImagePredictor``.

    Raises:
        Exception: wrapping (and chained to) whatever failed during
            download, build or validation.
    """

    def _prog(pct: float, desc: str):
        if progress_callback:
            progress_callback(pct, desc)

    try:
        _prog(0.1, "Initializing SAM2...")

        # Moving the model to "cuda" on a CPU-only host would fail outright.
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            logger.warning("CUDA requested for SAM2 but not available; falling back to CPU")
            device = "cpu"

        checkpoint_path = hf_hub_download(
            repo_id="facebook/sam2-hiera-large",
            filename="sam2_hiera_large.pt",
            cache_dir=str(CACHE_DIR / "sam2_checkpoint")
        )
        _prog(0.5, "SAM2 checkpoint downloaded, building model...")

        # Imported lazily so the module can be imported without sam2 installed.
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        sam2_model = build_sam2("sam2_hiera_l.yaml", checkpoint_path)
        sam2_model.to(device)
        predictor = SAM2ImagePredictor(sam2_model)

        # Smoke test: one positive click in the centre of a blank image must
        # produce at least one mask.
        _prog(0.8, "Testing SAM2 functionality...")
        test_image = np.zeros((256, 256, 3), dtype=np.uint8)
        predictor.set_image(test_image)
        test_points = np.array([[128, 128]])
        test_labels = np.array([1])
        masks, scores, _ = predictor.predict(
            point_coords=test_points,
            point_labels=test_labels,
            multimask_output=False
        )

        if masks is None or len(masks) == 0:
            raise RuntimeError("SAM2 predictor test failed - no masks generated")

        _prog(1.0, "SAM2 loaded and validated successfully!")
        logger.info("SAM2 predictor loaded and tested successfully")
        return predictor

    except Exception as e:
        logger.error(f"SAM2 loading failed: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        raise Exception(f"SAM2 loading failed: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_matanyone_fixed(progress_callback=None):
    """Create the MatAnyone ``InferenceCore`` and check its call interface.

    The processor is built from the ``"PeiqingYang/MatAnyone"`` identifier.
    The only check is whether it exposes ``process`` or is callable; the
    result is logged, not enforced, because ``refine_mask_with_validation``
    has its own fallbacks.

    Args:
        progress_callback: optional ``callable(fraction, description)``.

    Returns:
        The MatAnyone ``InferenceCore`` instance.

    Raises:
        Exception: wrapping (and chained to) any import or construction error.
    """

    def _prog(pct: float, desc: str):
        if progress_callback:
            progress_callback(pct, desc)

    try:
        _prog(0.2, "Loading MatAnyone...")

        # Imported lazily so the module can be imported without matanyone installed.
        from matanyone import InferenceCore
        processor = InferenceCore("PeiqingYang/MatAnyone")

        _prog(0.8, "Testing MatAnyone functionality...")
        if hasattr(processor, 'process') or callable(processor):
            logger.info("MatAnyone processor interface detected")
        else:
            logger.warning("MatAnyone interface unclear, will use fallback refinement")

        _prog(1.0, "MatAnyone loaded successfully!")
        logger.info("MatAnyone processor loaded successfully")
        return processor

    except Exception as e:
        logger.error(f"MatAnyone loading failed: {str(e)}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        raise Exception(f"MatAnyone loading failed: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Process-wide model state read by the status helpers and process_video_fixed.
# Within this file, only load_models_with_validation assigns these.
sam2_predictor = None      # validated SAM2ImagePredictor once loaded, else None
matanyone_model = None     # MatAnyone InferenceCore once loaded, else None

# True only after both loader functions returned without raising.
models_loaded = False

# Serialises load_models_with_validation so concurrent callers don't load twice.
loading_lock = threading.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model_status():
    """Summarise model availability for display in the UI.

    Returns:
        dict with ``'sam2'`` and ``'matanyone'`` set to ``'Ready'`` or
        ``'Not loaded'``, and ``'validated'`` mirroring ``models_loaded``.
    """

    def _label(loaded_obj):
        return 'Not loaded' if loaded_obj is None else 'Ready'

    return {
        'sam2': _label(sam2_predictor),
        'matanyone': _label(matanyone_model),
        'validated': models_loaded,
    }
|
|
|
|
|
def load_models_with_validation(progress_callback=None):
    """Load and validate SAM2 and MatAnyone once, guarded by ``loading_lock``.

    Uses CUDA when available, otherwise CPU. Components already loaded by an
    earlier, partially failed attempt are reused rather than loaded again;
    both loaders only return after their own validation succeeds.

    Args:
        progress_callback: optional ``callable(fraction, description)`` passed
            through to the individual loaders.

    Returns:
        A human-readable status string. Load failures are reported in the
        string (and logged), not raised.
    """
    global sam2_predictor, matanyone_model, models_loaded

    with loading_lock:
        if models_loaded:
            return "Models already loaded and validated"

        try:
            start_time = time.time()
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Starting model loading on {device}")

            # If a previous attempt loaded SAM2 and then MatAnyone failed,
            # loading SAM2 again would leave a second copy in (GPU) memory.
            if sam2_predictor is None:
                sam2_predictor = load_sam2_predictor_fixed(device=device, progress_callback=progress_callback)
            else:
                logger.info("Reusing previously loaded SAM2 predictor")

            if matanyone_model is None:
                matanyone_model = load_matanyone_fixed(progress_callback=progress_callback)

            models_loaded = True
            load_time = time.time() - start_time

            message = f"SUCCESS: SAM2 + MatAnyone loaded and validated in {load_time:.1f}s"
            logger.info(message)
            return message

        except Exception as e:
            models_loaded = False
            error_msg = f"Model loading failed: {str(e)}"
            logger.error(error_msg)
            return error_msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def segment_person_with_validation(image, predictor):
    """Segment the main subject of a frame with SAM2 point prompts.

    Five positive clicks in a cross around the frame centre are sent to
    ``predictor.predict`` with ``multimask_output=True``. The highest-scoring
    mask is kept, closed with a 3x3 kernel and lightly Gaussian-blurred.

    Args:
        image: HxWx3 uint8 frame in OpenCV BGR channel order (as returned by
            ``cv2.VideoCapture.read``).
        predictor: object with ``set_image`` / ``predict`` in the style of
            ``SAM2ImagePredictor``.

    Returns:
        HxW uint8 mask in 0..255. On any failure, including a ``None``
        predictor, returns ``create_fallback_mask(image)``.
    """
    try:
        if predictor is None:
            raise Exception("SAM2 predictor is None")

        # SAM2ImagePredictor.set_image expects RGB, but OpenCV frames are BGR
        # (create_fallback_mask makes the same BGR assumption).
        predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        h, w = image.shape[:2]

        # Positive prompts on a cross through the centre. This assumes the
        # subject is roughly centred in the frame.
        points = np.array([
            [w // 2, h // 3],
            [w // 2, h // 2],
            [w // 2, 2 * h // 3],
            [w // 3, h // 2],
            [2 * w // 3, h // 2],
        ])
        labels = np.ones(len(points), dtype=np.int32)  # 1 = foreground click

        masks, scores, _ = predictor.predict(
            point_coords=points,
            point_labels=labels,
            multimask_output=True
        )

        if masks is None or len(masks) == 0:
            raise Exception("SAM2 returned no masks")

        best_mask = masks[np.argmax(scores)]
        if len(best_mask.shape) > 2:
            best_mask = best_mask.squeeze()
        if best_mask.dtype != np.uint8:
            # bool / float {0, 1} masks -> uint8 {0, 255}
            best_mask = (best_mask * 255).astype(np.uint8)

        # Fill pinholes, then soften the edge slightly.
        kernel = np.ones((3, 3), np.uint8)
        best_mask = cv2.morphologyEx(best_mask, cv2.MORPH_CLOSE, kernel)
        best_mask = cv2.GaussianBlur(best_mask.astype(np.float32), (3, 3), 0.8)

        # A uint8 {0, 1} mask from the predictor reaches this point unscaled;
        # stretch it to 0..255 in that case.
        final_mask = (best_mask * 255).astype(np.uint8) if best_mask.max() <= 1.0 else best_mask.astype(np.uint8)

        logger.info(f"SAM2 segmentation successful, mask shape: {final_mask.shape}, range: {final_mask.min()}-{final_mask.max()}")
        return final_mask

    except Exception as e:
        logger.error(f"SAM2 segmentation failed: {e}")
        return create_fallback_mask(image)
|
|
|
|
|
def _centered_box_mask(h, w):
    """Return an HxW uint8 mask with 255 over the central box (rows h/6..5h/6, cols w/4..3w/4)."""
    mask = np.zeros((h, w), dtype=np.uint8)
    mask[h // 6:5 * h // 6, w // 4:3 * w // 4] = 255
    return mask


def create_fallback_mask(image, min_area_ratio: float = 0.01):
    """Heuristic person mask used when SAM2 segmentation fails.

    Fills the largest external Canny-edge contour. When there is no contour,
    or the largest one covers less than *min_area_ratio* of the frame, a
    centred box is used instead. Contours traced from edge lines are often
    open curves with almost zero area, and filling one would give an
    essentially empty mask. The result is blurred (15x15, sigma 5) to soften
    edges.

    Args:
        image: HxWx3 BGR frame or HxW grayscale frame.
        min_area_ratio: minimum contour area, as a fraction of H*W, for the
            contour to be used instead of the centred box.

    Returns:
        HxW uint8 mask in 0..255. If the heuristic itself raises, the unblurred
        centred box is returned.
    """
    try:
        h, w = image.shape[:2]

        gray = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 50, 150)
        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        mask = None
        if contours:
            largest_contour = max(contours, key=cv2.contourArea)
            if cv2.contourArea(largest_contour) >= min_area_ratio * h * w:
                mask = np.zeros((h, w), dtype=np.uint8)
                cv2.drawContours(mask, [largest_contour], -1, 255, thickness=cv2.FILLED)

        if mask is None:
            mask = _centered_box_mask(h, w)

        mask = cv2.GaussianBlur(mask, (15, 15), 5)

        logger.warning("Using enhanced fallback segmentation")
        return mask

    except Exception as e:
        logger.error(f"Fallback segmentation failed: {e}")
        h, w = image.shape[:2]
        return _centered_box_mask(h, w)
|
|
|
|
|
def _as_uint8_mask(mask):
    """Convert a refinement result (numpy array or torch tensor) to a uint8 mask in 0..255.

    Singleton dimensions are squeezed when the input has more than two dims.
    Bool masks and float masks whose maximum is <= 1.0 are scaled by 255;
    other values are clipped to 0..255.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)
    if mask.ndim > 2:
        mask = np.squeeze(mask)
    if mask.dtype == np.bool_:
        return mask.astype(np.uint8) * 255
    if mask.dtype == np.uint8:
        return mask
    mask = mask.astype(np.float32)
    if mask.size and mask.max() <= 1.0:
        mask = mask * 255.0
    return np.clip(mask, 0, 255).astype(np.uint8)


def refine_mask_with_validation(image, mask, matanyone_processor):
    """Refine a segmentation mask with MatAnyone, falling back to OpenCV.

    The processor is used through ``.process(image, mask)`` if present, else
    by calling it directly, else via ``refine_mask_hq``. Its result is
    normalised to a uint8 0..255 mask, so later blending with SAM2 masks
    (``cv2.addWeighted`` needs matching dtypes) works. If the processor is
    ``None``, raises, or returns a mask whose height/width differ from
    *mask*, ``enhance_mask_opencv_advanced`` is used instead.

    Args:
        image: the source frame.
        mask: HxW uint8 mask to refine.
        matanyone_processor: MatAnyone processor or ``None``.

    Returns:
        The refined mask (uint8 0..255 on the MatAnyone path).
    """
    try:
        if matanyone_processor is None:
            logger.warning("MatAnyone processor is None, using enhanced OpenCV refinement")
            return enhance_mask_opencv_advanced(image, mask)

        try:
            if hasattr(matanyone_processor, 'process'):
                refined_mask = matanyone_processor.process(image, mask)
            elif callable(matanyone_processor):
                refined_mask = matanyone_processor(image, mask)
            else:
                refined_mask = refine_mask_hq(image, mask, matanyone_processor)

            if refined_mask is None:
                raise Exception("MatAnyone returned invalid mask")

            refined_mask = _as_uint8_mask(refined_mask)
            if refined_mask.shape[:2] != mask.shape[:2]:
                raise Exception("MatAnyone returned invalid mask")

            logger.info("MatAnyone refinement successful")
            return refined_mask

        except Exception as ma_error:
            logger.warning(f"MatAnyone refinement failed: {ma_error}, using enhanced OpenCV")
            return enhance_mask_opencv_advanced(image, mask)

    except Exception as e:
        logger.error(f"Mask refinement error: {e}")
        return enhance_mask_opencv_advanced(image, mask)
|
|
|
|
|
def enhance_mask_opencv_advanced(image, mask, feather_radius: float = 10.0):
    """Refine a segmentation mask using OpenCV filters only.

    Pipeline:
    1. Bilateral filter.
    2. Elliptical close (5x5), then open (3x3).
    3. Median (5) and Gaussian (5x5) blur.
    4. A 60/40 blend with an inside-edge ramp from the distance transform.
    5. A final light blur.

    The ramp rises from 0 at the mask boundary to 255 at *feather_radius*
    pixels inside and stays at 255 beyond that. Only the boundary band is
    softened; the interior stays opaque. Normalising by the global maximum
    distance instead would leave most of the subject partly transparent.

    Args:
        image: source frame (not used by this OpenCV path).
        mask: HxW or HxWx3 (BGR) mask; uint8 0..255, bool, or float in 0..1
            or 0..255.
        feather_radius: width in pixels of the softened band inside the edge.

    Returns:
        HxW uint8 mask; if any step raises, the input *mask* unchanged.
    """
    try:
        work = np.asarray(mask)
        # distanceTransform and bilateralFilter need 8-bit input here, so
        # normalise bool / float masks to uint8 0..255 first.
        if work.dtype != np.uint8:
            work = work.astype(np.float32)
            if work.size and work.max() <= 1.0:
                work = work * 255.0
            work = np.clip(work, 0, 255).astype(np.uint8)
        if work.ndim == 3:
            work = cv2.cvtColor(work, cv2.COLOR_BGR2GRAY)

        refined = cv2.bilateralFilter(work, 15, 80, 80)

        kernel_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        kernel_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        refined = cv2.morphologyEx(refined, cv2.MORPH_CLOSE, kernel_close)
        refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, kernel_open)

        refined = cv2.medianBlur(refined, 5)
        refined = cv2.GaussianBlur(refined, (5, 5), 1.5)

        # Distance (in px) from each non-zero pixel to the nearest zero pixel,
        # capped at feather_radius and mapped to 0..255.
        dist = cv2.distanceTransform(refined, cv2.DIST_L2, 5)
        radius = max(float(feather_radius), 1e-6)
        edge_ramp = (np.minimum(dist / radius, 1.0) * 255.0).astype(np.uint8)

        alpha = 0.6
        refined = cv2.addWeighted(refined, alpha, edge_ramp, 1 - alpha, 0)

        refined = cv2.GaussianBlur(refined, (3, 3), 0.8)

        logger.info("Advanced OpenCV mask enhancement completed")
        return refined

    except Exception as e:
        logger.error(f"Advanced mask enhancement failed: {e}")
        return mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_background(background_choice, custom_background_path, frame_width, frame_height):
    """Return ``(background_image, display_name, error_message)``; error_message is None on success."""
    if background_choice == "custom" and custom_background_path:
        background = cv2.imread(custom_background_path)
        if background is None:
            return None, "", "Could not read custom background image."
        return background, "Custom Image", None

    if background_choice in PROFESSIONAL_BACKGROUNDS:
        bg_config = PROFESSIONAL_BACKGROUNDS[background_choice]
        background = create_professional_background(bg_config, frame_width, frame_height)
        if background is None:
            return None, "", "Failed to create background."
        return background, bg_config["name"], None

    return None, "", f"Invalid background selection: {background_choice}"


def _mux_audio(video_only_path, source_path, output_path):
    """Re-encode *video_only_path* to H.264 with the source's first audio track (if any).

    Runs ffmpeg with an argument list (no shell), so user-supplied paths cannot
    inject commands. If ffmpeg is missing, fails, or produces no file, the
    video-only file is copied to *output_path* instead.
    """
    import subprocess

    cmd = [
        "ffmpeg", "-y",
        "-i", video_only_path,
        "-i", source_path,
        "-c:v", "libx264", "-crf", "18", "-preset", "medium",
        "-c:a", "aac", "-b:a", "192k", "-ac", "2", "-ar", "48000",
        # "1:a:0?" makes the audio mapping optional for silent sources.
        "-map", "0:v:0", "-map", "1:a:0?",
        "-shortest", output_path,
    ]
    try:
        result = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, check=False)
        if result.returncode != 0 or not os.path.exists(output_path):
            stderr_tail = (result.stderr or b"").decode(errors="replace")[-500:]
            logger.warning(f"ffmpeg audio mux failed (exit {result.returncode}): {stderr_tail}")
            shutil.copy2(video_only_path, output_path)
    except Exception as e:
        logger.warning(f"Audio processing error: {e}")
        shutil.copy2(video_only_path, output_path)


def process_video_fixed(video_path, background_choice, custom_background_path, progress_callback=None):
    """Replace the background of every frame of a video using SAM2 + MatAnyone.

    Per frame:
    - SAM2 segmentation via ``segment_person_with_validation``.
    - On every ``keyframe_interval``-th frame (and whenever no refined mask
      exists yet), MatAnyone refinement via ``refine_mask_with_validation``.
    - On other frames, a 70/30 blend of the fresh SAM2 mask with the last
      refined mask.
    - Compositing over the chosen background with ``replace_background_hq``.

    Frames that raise are written unmodified, so the output keeps the input's
    frame count. Audio from the source is muxed back in with ffmpeg. A copy is
    saved under ``/tmp/MyAvatar/My_Videos/``.

    Args:
        video_path: path to the input video.
        background_choice: ``"custom"`` or a key of ``PROFESSIONAL_BACKGROUNDS``.
        custom_background_path: image path used when *background_choice* is ``"custom"``.
        progress_callback: optional ``callable(fraction, description)``.

    Returns:
        ``(output_path, message)`` on success, or ``(None, error_message)``.
        Errors are reported through the message rather than raised.
    """
    if not models_loaded:
        return None, "Models not loaded. Call load_models_with_validation() first."
    if not video_path:
        return None, "No video file provided."

    def _prog(pct: float, desc: str):
        if progress_callback:
            progress_callback(pct, desc)

    fallback_fps = 30.0
    cap = None
    final_writer = None
    final_path = None
    try:
        _prog(0.0, "Starting FIXED single-stage processing...")

        if not os.path.exists(video_path):
            return None, f"Video file not found: {video_path}"

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return None, "Could not open video file."

        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        if total_frames == 0:
            return None, "Video appears to be empty."

        # Some containers report 0 or NaN FPS, which VideoWriter cannot use.
        # (`not fps > 0` is also True for NaN.)
        if not fps > 0:
            logger.warning(f"Invalid FPS reported ({fps}); defaulting to {fallback_fps}")
            fps = fallback_fps

        background, background_name, bg_error = _resolve_background(
            background_choice, custom_background_path, frame_width, frame_height
        )
        if bg_error:
            return None, bg_error

        timestamp = int(time.time())
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

        _prog(0.1, f"Processing with VALIDATED SAM2 + MatAnyone: {background_name}...")
        final_path = f"/tmp/fixed_output_{timestamp}.mp4"
        final_writer = cv2.VideoWriter(final_path, fourcc, fps, (frame_width, frame_height))

        if not final_writer.isOpened():
            return None, "Could not create output video file."

        frame_count = 0
        successful_frames = 0
        keyframe_interval = 3
        last_refined_mask = None

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            try:
                _prog(0.1 + (frame_count / max(1, total_frames)) * 0.8,
                      f"Processing frame {frame_count + 1}/{total_frames} with AI")

                mask = segment_person_with_validation(frame, sam2_predictor)

                # Run the expensive MatAnyone refinement only on keyframes; in
                # between, blend the fresh SAM2 mask with the last refined one.
                if (frame_count % keyframe_interval == 0) or (last_refined_mask is None):
                    refined_mask = refine_mask_with_validation(frame, mask, matanyone_model)
                    last_refined_mask = refined_mask.copy()
                    logger.info(f"AI refinement on frame {frame_count}")
                else:
                    alpha = 0.7
                    refined_mask = cv2.addWeighted(mask, alpha, last_refined_mask, 1 - alpha, 0)

                result_frame = replace_background_hq(frame, refined_mask, background)
                final_writer.write(result_frame)
                successful_frames += 1

            except Exception as frame_error:
                logger.warning(f"Error processing frame {frame_count}: {frame_error}")
                # Keep the output the same length as the input.
                final_writer.write(frame)

            frame_count += 1
            if frame_count % 50 == 0:
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        # The writer must be closed before ffmpeg reads the file.
        final_writer.release()
        final_writer = None
        cap.release()
        cap = None

        if successful_frames == 0:
            return None, "No frames were processed successfully with AI."

        _prog(0.9, "Adding audio...")
        final_output = f"/tmp/final_fixed_{timestamp}.mp4"
        _mux_audio(final_path, video_path, final_output)

        try:
            myavatar_path = "/tmp/MyAvatar/My_Videos/"
            os.makedirs(myavatar_path, exist_ok=True)
            saved_filename = f"fixed_sam2_matanyone_{timestamp}.mp4"
            saved_path = os.path.join(myavatar_path, saved_filename)
            shutil.copy2(final_output, saved_path)
        except Exception as e:
            logger.warning(f"Could not save to MyAvatar: {e}")
            saved_filename = os.path.basename(final_output)

        _prog(1.0, "FIXED processing complete!")

        success_message = (
            f"FIXED Success!\n"
            f"Background: {background_name}\n"
            f"Total frames: {frame_count}\n"
            f"Successfully processed: {successful_frames}\n"
            f"AI model usage: SAM2 + MatAnyone validated\n"
            f"Saved: {saved_filename}"
        )
        return final_output, success_message

    except Exception as e:
        logger.error(f"Fixed processing error: {traceback.format_exc()}")
        return None, f"Processing Error: {str(e)}"

    finally:
        # Release capture/writer on every exit path (early returns included)
        # and remove the intermediate video-only file.
        if final_writer is not None:
            final_writer.release()
        if cap is not None:
            cap.release()
        if final_path and os.path.exists(final_path):
            try:
                os.remove(final_path)
            except OSError as cleanup_error:
                logger.debug(f"Could not remove temp file {final_path}: {cleanup_error}")
|
|
|
|
|
def get_cache_status():
    """Report which models are currently held in memory.

    Returns:
        dict with boolean ``"sam2_loaded"`` and ``"matanyone_loaded"`` plus the
        ``"models_validated"`` flag (``models_loaded``).
    """
    status = {}
    status["sam2_loaded"] = sam2_predictor is not None
    status["matanyone_loaded"] = matanyone_model is not None
    status["models_validated"] = models_loaded
    return status
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Build the UI from ``ui_components.create_interface`` and launch it.

    Ensures the output directory (``/tmp/MyAvatar/My_Videos/``) and
    ``CACHE_DIR`` exist, then launches on 0.0.0.0:7860 with ``share=True``
    and ``show_error=True``. A startup error is logged with its traceback and
    the process exits with status 1, so supervisors see the failure.
    """
    try:
        print("===== FIXED SAM2 + MATANYONE CORE =====")
        print("Loading UI components...")

        # Imported here so this module stays importable without the UI package.
        from ui_components import create_interface

        os.makedirs("/tmp/MyAvatar/My_Videos/", exist_ok=True)
        CACHE_DIR.mkdir(exist_ok=True, parents=True)

        print("Creating interface...")
        demo = create_interface()

        print("Launching...")
        demo.launch(server_name="0.0.0.0", server_port=7860, share=True, show_error=True)

    except Exception as e:
        logger.exception(f"Startup failed: {e}")
        print(f"Startup failed: {e}")
        sys.exit(1)
|
|
# Launch the UI only when run as a script, not when imported.
if __name__ == "__main__":
    main()