import streamlit as st
import os
import sys
import tempfile
import time
import shutil
import gc
from pathlib import Path
import cv2
import numpy as np
from PIL import Image
import logging
import base64
from io import BytesIO
import psutil
import torch
from contextlib import contextmanager, nullcontext

# Add project root to path
sys.path.append(str(Path(__file__).parent.absolute()))

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Persistent temp dir (survives beyond TemporaryDirectory scope)
TMP_DIR = Path("tmp")
TMP_DIR.mkdir(parents=True, exist_ok=True)

# Page config
st.set_page_config(
    page_title="MyAvatar - Video Background Replacer",
    page_icon="🎥",
    layout="wide",
    initial_sidebar_state="expanded"
)


# Memory management utilities
@contextmanager
def torch_memory_manager():
    """Context manager for CUDA memory cleanup."""
    try:
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()


def clear_model_cache():
    """Clear all cached models and free memory."""
    if hasattr(st, 'cache_resource'):
        st.cache_resource.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    logger.info("Model cache cleared")


def get_memory_usage():
    """Get current memory usage statistics (in GB)."""
    memory_info = {}
    if torch.cuda.is_available():
        memory_info['gpu_allocated'] = torch.cuda.memory_allocated() / 1e9
        memory_info['gpu_reserved'] = torch.cuda.memory_reserved() / 1e9
        memory_info['gpu_free'] = (torch.cuda.get_device_properties(0).total_memory
                                   - torch.cuda.memory_allocated()) / 1e9
    memory_info['ram_used'] = psutil.virtual_memory().used / 1e9
    memory_info['ram_available'] = psutil.virtual_memory().available / 1e9
    return memory_info


# Lazy model loading
@st.cache_resource(show_spinner=False)
def load_sam2_predictor():
    """Lazy-load the SAM2 image predictor only when needed."""
    try:
        logger.info("Loading SAM2 image predictor...")
        from sam2.build_sam import build_sam2
        from sam2.sam2_image_predictor import SAM2ImagePredictor

        checkpoint_path = "/home/user/app/checkpoints/sam2.1_hiera_large.pt"
        model_cfg = "/home/user/app/configs/sam2.1/sam2.1_hiera_l.yaml"

        if not os.path.exists(checkpoint_path) or not os.path.exists(model_cfg):
            logger.warning("Local checkpoints not found, using Hugging Face...")
            predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
        else:
            memory_info = get_memory_usage()
            if memory_info.get('gpu_free', 0) < 4.0:
                logger.warning("Limited GPU memory, using smaller SAM2 model...")
                try:
                    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
                except Exception:
                    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
            else:
                predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint_path))

        logger.info("✅ SAM2 image predictor loaded successfully!")
        return predictor
    except Exception as e:
        logger.error(f"Failed to load SAM2 predictor: {e}")
        st.error(f"❌ Failed to load SAM2: {e}")
        return None


@st.cache_resource(show_spinner=False)
def load_matanyone_processor():
    """Lazy-load the MatAnyone processor only when needed."""
    try:
        logger.info("Loading MatAnyone processor...")
        from matanyone import InferenceCore
        processor = InferenceCore("PeiqingYang/MatAnyone")
        logger.info("✅ MatAnyone processor loaded successfully!")
        return processor
    except Exception as e:
        logger.error(f"Failed to load MatAnyone: {e}")
        st.error(f"❌ Failed to load MatAnyone: {e}")
        return None
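# Hedged sketch (not part of the original pipeline): besides the single
# center-point prompt used in generate_mask_from_video_first_frame() below,
# SAM2's image predictor also accepts box prompts, which are often more
# reliable for a framed subject. The `box` keyword follows the public
# SAM2ImagePredictor.predict() API; verify it against your installed sam2
# version before relying on it.
def generate_mask_with_box(frame_rgb, sam2_predictor, box_xyxy):
    """Illustrative alternative: segment with an (x0, y0, x1, y1) box prompt."""
    with torch.inference_mode():
        sam2_predictor.set_image(frame_rgb)
        masks, scores, _ = sam2_predictor.predict(
            box=np.array(box_xyxy, dtype=np.float32),
            multimask_output=False,
        )
    return masks[0].astype(np.uint8) * 255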
def generate_mask_from_video_first_frame(video_path, sam2_predictor):
    """Generate a mask for the first frame of the video using SAM2."""
    try:
        with torch_memory_manager():
            cap = cv2.VideoCapture(video_path)
            ret, frame = cap.read()
            cap.release()

            if not ret:
                st.error("Failed to read video frame")
                return None

            # Resize frame if too large to save memory
            h, w = frame.shape[:2]
            max_size = 1080
            if max(h, w) > max_size:
                scale = max_size / max(h, w)
                new_w, new_h = int(w * scale), int(h * scale)
                frame = cv2.resize(frame, (new_w, new_h))
                logger.info(f"Resized frame from {w}x{h} to {new_w}x{new_h}")

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Autocast only makes sense on CUDA; use a no-op context on CPU.
            amp_ctx = (torch.autocast("cuda", dtype=torch.bfloat16)
                       if torch.cuda.is_available() else nullcontext())
            with torch.inference_mode(), amp_ctx:
                sam2_predictor.set_image(frame_rgb)

                # Use the frame center as the default point prompt
                h, w = frame_rgb.shape[:2]
                center_point = np.array([[w // 2, h // 2]], dtype=np.float32)
                center_label = np.array([1], dtype=np.int32)

                masks, scores, logits = sam2_predictor.predict(
                    point_coords=center_point,
                    point_labels=center_label,
                    multimask_output=True
                )

            # Keep the highest-scoring mask
            best_mask = masks[np.argmax(scores)]
            return best_mask.astype(np.uint8) * 255
    except Exception as e:
        st.error(f"Failed to generate mask: {e}")
        return None


def stage1_create_transparent_video(input_file):
    """STAGE 1: Create a transparent video using SAM2 + MatAnyone."""
    logger.info("Starting Stage 1: Create transparent video")

    memory_info = get_memory_usage()
    if memory_info.get('gpu_free', 0) < 2.0:
        st.warning("⚠️ Low GPU memory detected. Processing may be slower.")
        clear_model_cache()

    try:
        progress_bar = st.progress(0)
        status_text = st.empty()

        def update_progress(progress, message):
            progress = max(0, min(1, progress))
            progress_bar.progress(progress)
            status_text.text(
                f"Stage 1: {message} | GPU: {get_memory_usage().get('gpu_allocated', 0):.1f}GB"
            )
            logger.info(f"Stage 1 Progress: {progress:.2f} - {message}")

        # Load models
        update_progress(0.05, "Loading SAM2 model...")
        logger.info("Attempting to load SAM2 predictor...")
        sam2_predictor = load_sam2_predictor()
        if sam2_predictor is None:
            logger.error("SAM2 predictor failed to load")
            st.error("❌ Failed to load SAM2 model")
            return None
        logger.info("SAM2 predictor loaded successfully")

        update_progress(0.1, "Loading MatAnyone model...")
        logger.info("Attempting to load MatAnyone processor...")
        matanyone_processor = load_matanyone_processor()
        if matanyone_processor is None:
            logger.error("MatAnyone processor failed to load")
            st.error("❌ Failed to load MatAnyone model")
            return None
        logger.info("MatAnyone processor loaded successfully")

        # Process video to create transparent version
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir = Path(temp_dir)
            input_path = str(temp_dir / "input.mp4")

            # Save input video
            with open(input_path, "wb") as f:
                f.write(input_file.getvalue())

            update_progress(0.2, "Generating segmentation mask...")

            # Generate mask using SAM2
            with torch_memory_manager():
                mask = generate_mask_from_video_first_frame(input_path, sam2_predictor)
            if mask is None:
                return None

            mask_path = str(temp_dir / "mask.png")
            cv2.imwrite(mask_path, mask)

            update_progress(0.4, "Creating transparent video with MatAnyone...")

            # Process with MatAnyone to get foreground and alpha
            try:
                with torch_memory_manager():
                    foreground_path, alpha_path = matanyone_processor.process_video(
                        input_path=input_path,
                        mask_path=mask_path,
                        output_path=str(temp_dir),
                        max_size=720  # Limit resolution for memory efficiency
                    )

                update_progress(0.8, "Creating transparent .mov file...")

                # Create transparent video (.mov with alpha channel)
                transparent_path = create_transparent_mov(foreground_path, alpha_path, temp_dir)

                if transparent_path and os.path.exists(transparent_path):
                    # Copy to a persistent location before the temp dir is deleted
                    persist_path = TMP_DIR / "transparent_video.mov"
                    shutil.copyfile(transparent_path, persist_path)
                    update_progress(1.0, "Transparent video created!")
                    time.sleep(0.5)
                    return str(persist_path)
                else:
                    st.error("Failed to create transparent video")
                    return None
            except Exception as e:
                st.error(f"MatAnyone processing failed: {e}")
                return None
    except Exception as e:
        logger.error(f"Error in Stage 1 processing: {str(e)}", exc_info=True)
        st.error(f"❌ Stage 1 failed: {str(e)}")
        # Show additional debug info
        try:
            memory_info = get_memory_usage()
            st.info(
                f"Memory at failure - GPU: {memory_info.get('gpu_allocated', 0):.1f}GB, "
                f"RAM: {memory_info.get('ram_used', 0):.1f}GB"
            )
        except Exception:
            pass
        return None
    finally:
        logger.info("Stage 1 cleanup starting...")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logger.info("Stage 1 cleanup completed")
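# Hedged sketch (not wired into the UI, assumes ffmpeg on PATH): the tips
# expander in main() recommends short, modest-resolution clips. A clip can be
# trimmed and downscaled before Stage 1 to cut GPU memory use. The flags are
# standard ffmpeg options; the 30 s / 720 px limits are illustrative.
def preprocess_video_ffmpeg(src_path, dst_path, max_seconds=30, max_height=720):
    """Illustrative helper: trim to `max_seconds` and scale height to `max_height`."""
    import subprocess
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(src_path),
         "-t", str(max_seconds),           # keep only the first N seconds
         "-vf", f"scale=-2:{max_height}",  # -2 keeps width even and aspect-correct
         "-c:a", "copy",                   # pass audio through untouched
         str(dst_path)],
        check=True,
    )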
def create_transparent_mov(foreground_path, alpha_path, temp_dir):
    """Create a .mov file with an alpha channel from foreground and alpha videos."""
    try:
        output_path = str(temp_dir / "transparent.mov")

        # Read videos
        fg_cap = cv2.VideoCapture(foreground_path)
        alpha_cap = cv2.VideoCapture(alpha_path)

        # Get video properties
        fps = int(fg_cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(fg_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(fg_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # The PNG codec in a .mov container supports an alpha channel. Whether
        # OpenCV's VideoWriter actually preserves 4-channel frames depends on
        # the build; see the ffmpeg fallback sketch below if alpha is dropped.
        fourcc = cv2.VideoWriter_fourcc(*'png ')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height), True)

        frame_count = 0
        while True:
            ret_fg, fg_frame = fg_cap.read()
            ret_alpha, alpha_frame = alpha_cap.read()
            if not ret_fg or not ret_alpha:
                break

            # Convert alpha to single channel if needed
            if len(alpha_frame.shape) == 3:
                alpha_frame = cv2.cvtColor(alpha_frame, cv2.COLOR_BGR2GRAY)

            # fg_frame is already BGR as decoded by OpenCV, so build the BGRA
            # frame directly. (The original RGBA round-trip mislabeled the BGR
            # data as RGB and swapped the red and blue channels on conversion.)
            bgra_frame = cv2.cvtColor(fg_frame, cv2.COLOR_BGR2BGRA)
            bgra_frame[:, :, 3] = alpha_frame

            out.write(bgra_frame)
            frame_count += 1
            if frame_count % 10 == 0:
                gc.collect()

        fg_cap.release()
        alpha_cap.release()
        out.release()

        return output_path if os.path.exists(output_path) else None
    except Exception as e:
        logger.error(f"Failed to create transparent MOV: {e}")
        return None
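# Hedged sketch (assumes ffmpeg on PATH): many OpenCV builds silently flatten
# the alpha channel when writing video. A more dependable route is to dump
# BGRA frames as numbered PNGs and let ffmpeg mux them into a QuickTime
# Animation (qtrle) .mov, a codec with well-supported alpha. The helper name
# and frame-naming scheme are illustrative, not part of the original pipeline.
def frames_to_alpha_mov(frame_dir, output_path, fps=30):
    """Illustrative helper: encode PNGs named frame_%05d.png into an alpha .mov."""
    import subprocess
    subprocess.run(
        ["ffmpeg", "-y",
         "-framerate", str(fps),
         "-i", str(Path(frame_dir) / "frame_%05d.png"),
         "-c:v", "qtrle",      # QuickTime Animation: lossless, alpha-capable
         "-pix_fmt", "argb",   # keep the alpha plane
         str(output_path)],
        check=True,
    )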
def stage2_composite_background(transparent_video_path, background, bg_type):
    """STAGE 2: Composite the transparent video with a new background."""
    try:
        progress_bar = st.progress(0)
        status_text = st.empty()

        def update_progress(progress, message):
            progress = max(0, min(1, progress))
            progress_bar.progress(progress)
            status_text.text(f"Stage 2: {message}")

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_dir = Path(temp_dir)

            update_progress(0.2, "Loading transparent video...")

            # Read transparent video. Note: cv2.VideoCapture normally decodes
            # to 3-channel BGR, so the 4-channel branch below may never fire on
            # builds that discard alpha; the full-opacity fallback keeps
            # compositing working in that case (see the ffmpeg reader sketch
            # below for a more reliable way to recover the alpha plane).
            cap = cv2.VideoCapture(transparent_video_path)
            fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Prepare background
            update_progress(0.4, "Preparing background...")
            if bg_type == "image" and background is not None:
                bg_array = np.array(background)
                if len(bg_array.shape) == 3 and bg_array.shape[2] == 3:
                    bg_array = cv2.cvtColor(bg_array, cv2.COLOR_RGB2BGR)
                elif len(bg_array.shape) == 3 and bg_array.shape[2] == 4:
                    bg_array = cv2.cvtColor(bg_array, cv2.COLOR_RGBA2BGR)
                bg_resized = cv2.resize(bg_array, (width, height))
            elif bg_type == "color":
                # Use the hex color string passed by the caller (the original
                # reached back into st.session_state for the same value)
                color_hex = background.lstrip('#')
                r = int(color_hex[0:2], 16)
                g = int(color_hex[2:4], 16)
                b = int(color_hex[4:6], 16)
                bg_resized = np.full((height, width, 3), (b, g, r), dtype=np.uint8)
            else:
                # Fallback: green screen
                bg_resized = np.full((height, width, 3), (0, 255, 0), dtype=np.uint8)

            # Create output video
            output_path = str(temp_dir / "final_output.mp4")
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

            update_progress(0.6, "Compositing frames...")
            frame_count = 0
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Extract alpha channel if present (BGRA format)
                if frame.shape[2] == 4:
                    bgr_frame = frame[:, :, :3]
                    alpha_channel = frame[:, :, 3]
                else:
                    # Fallback: assume full opacity
                    bgr_frame = frame
                    alpha_channel = np.full((height, width), 255, dtype=np.uint8)

                # Normalize alpha to 0-1
                alpha_norm = alpha_channel.astype(np.float32) / 255.0
                alpha_norm = np.expand_dims(alpha_norm, axis=2)

                # Composite: result = fg * alpha + bg * (1 - alpha)
                fg_float = bgr_frame.astype(np.float32)
                bg_float = bg_resized.astype(np.float32)
                result = (fg_float * alpha_norm + bg_float * (1 - alpha_norm)).astype(np.uint8)

                out.write(result)
                frame_count += 1

                # Update progress
                if total_frames > 0 and frame_count % 5 == 0:
                    progress = 0.6 + 0.3 * (frame_count / total_frames)
                    update_progress(progress, f"Compositing frame {frame_count}/{total_frames}")
                if frame_count % 10 == 0:
                    gc.collect()

            cap.release()
            out.release()

            if os.path.exists(output_path):
                # Copy to persistent location
                persist_path = TMP_DIR / "final_video.mp4"
                shutil.copyfile(output_path, persist_path)
                update_progress(1.0, "Compositing complete!")
                time.sleep(0.5)
                return str(persist_path)
            else:
                return None
    except Exception as e:
        logger.error(f"Error in Stage 2 compositing: {str(e)}", exc_info=True)
        st.error(f"Stage 2 failed: {str(e)}")
        return None
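# Hedged sketch (assumes ffmpeg on PATH): if OpenCV drops the alpha plane when
# reading the .mov, BGRA frames can instead be pulled through an ffmpeg
# rawvideo pipe. The frame size must be known up front; the helper name and
# layout are illustrative, not part of the original pipeline.
def iter_bgra_frames(video_path, width, height):
    """Illustrative generator yielding HxWx4 BGRA frames via ffmpeg."""
    import subprocess
    proc = subprocess.Popen(
        ["ffmpeg", "-i", str(video_path),
         "-f", "rawvideo", "-pix_fmt", "bgra", "pipe:1"],
        stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
    )
    frame_bytes = width * height * 4
    while True:
        buf = proc.stdout.read(frame_bytes)
        if len(buf) < frame_bytes:
            break
        yield np.frombuffer(buf, dtype=np.uint8).reshape(height, width, 4)
    proc.wait()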
# UI functions (simplified for the two-stage approach)
def add_logo():
    # NOTE: the logo/CSS markup was not preserved in this copy of the file.
    st.markdown(
        """ """,
        unsafe_allow_html=True
    )


def show_memory_info():
    memory_info = get_memory_usage()
    with st.sidebar:
        st.markdown("### 🧠 Memory Usage")
        if 'gpu_allocated' in memory_info:
            st.metric("GPU Memory", f"{memory_info['gpu_allocated']:.1f}GB",
                      f"Free: {memory_info['gpu_free']:.1f}GB")
        st.metric("RAM Usage", f"{memory_info['ram_used']:.1f}GB",
                  f"Available: {memory_info['ram_available']:.1f}GB")

        # Test model loading
        if st.button("🧪 Test Models", help="Test if SAM2 and MatAnyone can load"):
            with st.spinner("Testing model loading..."):
                try:
                    sam2_test = load_sam2_predictor()
                    if sam2_test:
                        st.success("✅ SAM2 loads successfully")
                    else:
                        st.error("❌ SAM2 failed to load")

                    matanyone_test = load_matanyone_processor()
                    if matanyone_test:
                        st.success("✅ MatAnyone loads successfully")
                    else:
                        st.error("❌ MatAnyone failed to load")
                except Exception as e:
                    st.error(f"Model test failed: {e}")

        if st.button("🧹 Clear Cache", help="Free up memory by clearing model cache"):
            clear_model_cache()
            st.success("Cache cleared!")
            st.experimental_rerun()


def initialize_session_state():
    if 'uploaded_video' not in st.session_state:
        st.session_state.uploaded_video = None
    if 'bg_image' not in st.session_state:
        st.session_state.bg_image = None
    if 'bg_image_info' not in st.session_state:
        st.session_state.bg_image_info = None
    if 'bg_color' not in st.session_state:
        st.session_state.bg_color = "#00FF00"
    if 'bg_type' not in st.session_state:
        st.session_state.bg_type = "image"
    if 'transparent_video_path' not in st.session_state:
        st.session_state.transparent_video_path = None
    if 'final_video_path' not in st.session_state:
        st.session_state.final_video_path = None
    if 'processing_stage1' not in st.session_state:
        st.session_state.processing_stage1 = False
    if 'processing_stage2' not in st.session_state:
        st.session_state.processing_stage2 = False


def handle_video_upload():
    uploaded = st.file_uploader(
        "📹 Upload Video",
        type=["mp4", "mov", "avi", "mkv"],
        key="video_uploader",
        help="Recommended: Videos under 30 seconds for faster processing"
    )
    if uploaded is not None:
        file_size_mb = uploaded.size / (1024 * 1024)
        if file_size_mb > 100:
            st.warning(f"⚠️ Large file detected ({file_size_mb:.1f}MB). Processing may take longer.")
        st.session_state.uploaded_video = uploaded
        # Reset processed videos when a new video is uploaded
        st.session_state.transparent_video_path = None
        st.session_state.final_video_path = None


def show_video_preview():
    st.markdown("### Video Preview")
    if st.session_state.uploaded_video is not None:
        video_bytes = st.session_state.uploaded_video.getvalue()
        st.video(video_bytes)
        st.session_state.uploaded_video.seek(0)


def handle_background_selection():
    st.markdown("### Background Options")
    bg_type = st.radio(
        "Select Background Type:",
        ["Image", "Color"],
        horizontal=True,
        key="bg_type_radio"
    )
    st.session_state.bg_type = bg_type.lower()
    if bg_type == "Image":
        handle_image_background()
    elif bg_type == "Color":
        handle_color_background()


def handle_image_background():
    bg_image = st.file_uploader(
        "🖼️ Upload Background Image",
        type=["jpg", "png", "jpeg"],
        key="bg_image_uploader",
        help="Recommended: Images under 5MB for better performance"
    )
    if bg_image is not None:
        image_size_mb = bg_image.size / (1024 * 1024)
        if image_size_mb > 10:
            st.warning(f"⚠️ Large image ({image_size_mb:.1f}MB). Consider resizing for better performance.")

        current_file_info = f"{bg_image.name}_{bg_image.size}"
        if st.session_state.bg_image_info != current_file_info:
            st.session_state.bg_image = Image.open(bg_image)
            st.session_state.bg_image_info = current_file_info
            # Reset final video when background changes
            st.session_state.final_video_path = None

        if st.session_state.bg_image is not None:
            st.image(st.session_state.bg_image, caption="Selected Background", use_container_width=True)
    else:
        if 'bg_image' in st.session_state:
            st.session_state.bg_image = None
        if 'bg_image_info' in st.session_state:
            st.session_state.bg_image_info = None


def handle_color_background():
    st.markdown("#### Select a Color")
    old_color = st.session_state.get('bg_color', "#00FF00")

    color_presets = {
        "Pure White": "#FFFFFF",
        "Pure Black": "#000000",
        "Light Gray": "#F5F5F5",
        "Professional Blue": "#0078D4",
        "Corporate Green": "#107C10",
        "Custom": old_color
    }

    cols = st.columns(3)
    for i, (name, color) in enumerate(color_presets.items()):
        with cols[i % 3]:
            if name == "Custom":
                new_color = st.color_picker("Custom Color", old_color, key="custom_color_picker")
                if new_color != old_color:
                    st.session_state.bg_color = new_color
                    st.session_state.final_video_path = None  # Reset final video
            else:
                if st.button(name, key=f"color_{name}", use_container_width=True):
                    st.session_state.bg_color = color
                    st.session_state.final_video_path = None  # Reset final video

    # Color preview swatch (the original markup was lost in this copy; this is
    # a minimal stand-in that renders the currently selected color)
    st.markdown(
        f'<div style="background-color: {st.session_state.bg_color}; '
        f'height: 40px; border-radius: 4px;"></div>',
        unsafe_allow_html=True
    )
def main():
    add_logo()

    # Header (the original styled markup was lost in this copy; minimal stand-in)
    st.markdown(
        """
        <div style="text-align: center;">
            <h1>🎥 Video Background Replacer</h1>
            <p>Two-Stage Processing: SAM2 + MatAnyone → Transparent → Composite</p>
        </div>
        """,
        unsafe_allow_html=True
    )
    st.markdown("---")

    initialize_session_state()
    show_memory_info()

    col1, col2 = st.columns([1, 1], gap="large")

    with col1:
        st.header("1. Upload Video")
        handle_video_upload()
        show_video_preview()

        # STAGE 1 header (original styled markup lost; minimal stand-in)
        st.markdown('<h3>STAGE 1: Create Transparent Video</h3>', unsafe_allow_html=True)

        stage1_disabled = not st.session_state.uploaded_video or st.session_state.processing_stage1
        if st.button("🎭 Create Transparent Video", type="primary",
                     disabled=stage1_disabled, use_container_width=True,
                     help="Remove background using SAM2 + MatAnyone AI"):
            with st.spinner("Stage 1: Creating transparent video..."):
                st.session_state.processing_stage1 = True
                try:
                    transparent_path = stage1_create_transparent_video(st.session_state.uploaded_video)
                    if transparent_path:
                        st.session_state.transparent_video_path = transparent_path
                        st.success("✅ Stage 1 Complete: Transparent video created!")
                        st.balloons()
                    else:
                        st.error("❌ Stage 1 Failed: Could not create transparent video")
                except Exception as e:
                    st.error(f"❌ Stage 1 Error: {str(e)}")
                finally:
                    st.session_state.processing_stage1 = False

        # Show transparent video result
        if st.session_state.get('transparent_video_path'):
            st.markdown("#### Transparent Video Result")
            try:
                with open(st.session_state.transparent_video_path, 'rb') as f:
                    transparent_bytes = f.read()
                st.video(transparent_bytes)
                st.download_button(
                    label="💾 Download Transparent Video (.mov)",
                    data=transparent_bytes,
                    file_name="transparent_video.mov",
                    mime="video/quicktime",
                    use_container_width=True,
                    help="Download for use in other video editors"
                )
                file_size_mb = len(transparent_bytes) / (1024 * 1024)
                st.caption(f"Transparent video size: {file_size_mb:.1f}MB")
            except Exception as e:
                st.error(f"Error displaying transparent video: {str(e)}")

    with col2:
        st.header("2. Background Settings")
        handle_background_selection()

        # STAGE 2 header (original styled markup lost; minimal stand-in)
        st.markdown('<h3>STAGE 2: Composite with Background</h3>', unsafe_allow_html=True)

        stage2_disabled = (not st.session_state.get('transparent_video_path')
                           or st.session_state.processing_stage2
                           or (st.session_state.bg_type == "image"
                               and not st.session_state.get('bg_image')))
        if st.button("🎬 Composite Final Video", type="primary",
                     disabled=stage2_disabled, use_container_width=True,
                     help="Combine transparent video with selected background"):
            if st.session_state.bg_type == "image" and not st.session_state.get('bg_image'):
                st.error("Please upload a background image first.")
            else:
                with st.spinner("Stage 2: Compositing with background..."):
                    st.session_state.processing_stage2 = True
                    try:
                        background = None
                        if st.session_state.bg_type == "image":
                            background = st.session_state.bg_image
                        elif st.session_state.bg_type == "color":
                            background = st.session_state.bg_color

                        final_path = stage2_composite_background(
                            st.session_state.transparent_video_path,
                            background,
                            st.session_state.bg_type
                        )
                        if final_path:
                            st.session_state.final_video_path = final_path
                            st.success("✅ Stage 2 Complete: Final video ready!")
                            st.balloons()
                        else:
                            st.error("❌ Stage 2 Failed: Could not composite video")
                    except Exception as e:
                        st.error(f"❌ Stage 2 Error: {str(e)}")
                    finally:
                        st.session_state.processing_stage2 = False

        # Show final video result
        if st.session_state.get('final_video_path'):
            st.markdown("#### Final Video Result")
            try:
                with open(st.session_state.final_video_path, 'rb') as f:
                    final_bytes = f.read()
                st.video(final_bytes)
                st.download_button(
                    label="💾 Download Final Video (.mp4)",
                    data=final_bytes,
                    file_name="final_video.mp4",
                    mime="video/mp4",
                    use_container_width=True
                )
                file_size_mb = len(final_bytes) / (1024 * 1024)
                st.caption(f"Final video size: {file_size_mb:.1f}MB")
            except Exception as e:
                st.error(f"Error displaying final video: {str(e)}")

    # Processing tips
    with st.expander("💡 Two-Stage Processing Tips"):
        st.markdown("""
        **Stage 1 - Create Transparent Video:**
        - Uses SAM2 + MatAnyone AI to remove the background
        - Creates a .mov file with an alpha channel (transparency)
        - Only needs to be done once per video
        - Download the transparent video for use in other editors

        **Stage 2 - Composite Background:**
        - Fast compositing with your chosen background
        - Try multiple backgrounds without re-processing
        - Change the background and re-composite instantly
        - Much faster than Stage 1

        **Benefits:**
        - **Flexible**: Try different backgrounds easily
        - **Efficient**: Reuse the transparent video multiple times
        - **Professional**: Industry-standard workflow
        - **Cacheable**: Save the transparent video for future use
        """)


if __name__ == "__main__":
    main()
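# Usage note: as a Streamlit app, this module is meant to be launched with
# `streamlit run <this_file>.py` rather than executed directly; running it as
# a plain script will not start the web UI.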