from typing import Any, List, Callable, Dict, Optional
import cv2
import threading
import numpy
from functools import lru_cache
from pathlib import Path

import SwitcherAI.processors.frame.core as frame_processors
from SwitcherAI.typing import Frame, Face
from SwitcherAI.utilities import conditional_download, resolve_relative_path

# Module-level state shared by every function below.
# FRAME_PROCESSOR holds the lazily created RealESRGANer instance (or None).
FRAME_PROCESSOR = None
# Serialises calls into the processor's ``enhance`` method.
THREAD_SEMAPHORE = threading.Semaphore(1)
# Guards lazy creation of FRAME_PROCESSOR in get_frame_processor().
THREAD_LOCK = threading.Lock()
NAME = 'FACEFUSION.FRAME_PROCESSOR.FRAME_ENHANCER'


# Enhanced model configuration inspired by FaceFusion
@lru_cache(maxsize=None)
def get_model_config() -> Dict[str, Any]:
    """Return the model configuration table, keyed by model name.

    The result is cached; ``post_process`` clears the cache. The model
    directory is resolved through ``resolve_relative_path`` and converted to
    a ``Path`` when that helper returns a string.
    """
    base_path = resolve_relative_path('../.assets/models')
    if isinstance(base_path, str):
        base_path = Path(base_path)

    return {
        'real_esrgan_x4': {
            'model_path': base_path / 'RealESRGAN_x4plus.pth',
            'scale': 4,
            'tile_size': 256,
            'tile_pad': 16,
            'num_feat': 64,
            'num_block': 23,
            'num_grow_ch': 32
        }
    }


def get_frame_processor() -> Any:
    """Return the shared RealESRGANer instance, creating it on first use.

    Creation happens under ``THREAD_LOCK``. If the weights file is missing,
    ``pre_check`` is asked to download it first.

    Returns:
        The processor, or ``None`` when the Real-ESRGAN packages cannot be
        imported, the model cannot be downloaded, or construction fails.
    """
    global FRAME_PROCESSOR

    with THREAD_LOCK:
        if FRAME_PROCESSOR is None:
            try:
                # Import Real-ESRGAN components lazily so the module can be
                # imported even when these optional packages are missing.
                from basicsr.archs.rrdbnet_arch import RRDBNet
                from realesrgan import RealESRGANer
                import torch

                config = get_model_config()['real_esrgan_x4']
                model_path = config['model_path']

                # Check if model exists
                if not model_path.exists():
                    print(f"⚠️ Real-ESRGAN model not found at: {model_path}")
                    print("🔄 Attempting to download model...")
                    if not pre_check():
                        print("❌ Failed to download Real-ESRGAN model")
                        return None

                FRAME_PROCESSOR = RealESRGANer(
                    model_path=str(model_path),
                    model=RRDBNet(
                        num_in_ch=3,
                        num_out_ch=3,
                        num_feat=config['num_feat'],
                        num_block=config['num_block'],
                        num_grow_ch=config['num_grow_ch'],
                        scale=config['scale']
                    ),
                    device=frame_processors.get_device(),
                    tile=config['tile_size'],
                    tile_pad=config['tile_pad'],
                    pre_pad=0,
                    scale=config['scale']
                )

                # Ensure CUDA device is set if available
                if torch.cuda.is_available():
                    torch.cuda.set_device(0)

                print("✅ Real-ESRGAN frame processor initialized")

            except ImportError as e:
                print(f"⚠️ Real-ESRGAN not available: {e}")
                print("💡 Install with: pip install realesrgan basicsr")
                FRAME_PROCESSOR = None
            except Exception as e:
                print(f"⚠️ Failed to initialize Real-ESRGAN: {e}")
                FRAME_PROCESSOR = None

    return FRAME_PROCESSOR


def clear_frame_processor() -> None:
    """Drop the shared processor so the next access re-creates it."""
    global FRAME_PROCESSOR
    FRAME_PROCESSOR = None


def pre_check() -> bool:
    """Download the Real-ESRGAN weights if needed and verify the file.

    Returns:
        ``True`` when the weights file exists and is non-empty after the
        download step, ``False`` otherwise (including on any exception).
    """
    try:
        download_directory_path = resolve_relative_path('../.assets/models')

        # Ensure download directory exists
        if isinstance(download_directory_path, str):
            download_directory_path = Path(download_directory_path)
        download_directory_path.mkdir(parents=True, exist_ok=True)

        # Download Real-ESRGAN model
        model_urls = [
            'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
        ]
        conditional_download(str(download_directory_path), model_urls)

        # Verify the model was downloaded
        model_path = download_directory_path / 'RealESRGAN_x4plus.pth'
        if model_path.exists() and model_path.stat().st_size > 0:
            print(f"✅ Real-ESRGAN model verified: {model_path.stat().st_size / (1024*1024):.1f}MB")
            return True
        else:
            print("❌ Real-ESRGAN model download failed or file is empty")
            return False

    except Exception as e:
        print(f"❌ Real-ESRGAN pre-check failed: {e}")
        return False


def pre_process() -> bool:
    """Report whether the frame processor can be obtained.

    Returns:
        ``True`` when ``get_frame_processor`` yields a processor, ``False``
        when it yields ``None`` or raises.
    """
    try:
        # Check if processor is available
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Real-ESRGAN not available, frame enhancement will be skipped")
            return False
        return True
    except Exception as e:
        print(f"⚠️ Frame enhancement pre-process failed: {e}")
        return False


def post_process() -> None:
    """Release the shared processor and clear the cached model config."""
    clear_frame_processor()
    # Clear cache as in FaceFusion version
    get_model_config.cache_clear()


def create_tile_frames(temp_vision_frame: Frame, tile_size: tuple = (256, 256)) -> tuple:
    """Split a frame into equally sized tiles, padding it first if needed.

    The frame is reflect-padded on the bottom and right so both dimensions
    become multiples of the tile size. The padding spec has three axes, so
    the frame must be a 3-D array (height, width, channels).

    Args:
        temp_vision_frame: Frame to split.
        tile_size: ``(tile_height, tile_width)``.

    Returns:
        ``(tiles, pad_width, pad_height)`` where ``tiles`` is a row-major
        list of tile views into the (possibly padded) frame.
    """
    height, width = temp_vision_frame.shape[:2]
    tile_height, tile_width = tile_size[0], tile_size[1]

    # Calculate padding (0 when the dimension is already a multiple)
    pad_height = (tile_height - height % tile_height) % tile_height
    pad_width = (tile_width - width % tile_width) % tile_width

    # Pad the frame
    if pad_height > 0 or pad_width > 0:
        temp_vision_frame = numpy.pad(
            temp_vision_frame,
            ((0, pad_height), (0, pad_width), (0, 0)),
            mode='reflect'
        )

    # Create tiles in row-major order; merge_tile_frames relies on this order.
    padded_height, padded_width = temp_vision_frame.shape[:2]
    tiles = [
        temp_vision_frame[y:y + tile_height, x:x + tile_width]
        for y in range(0, padded_height, tile_height)
        for x in range(0, padded_width, tile_width)
    ]

    return tiles, pad_width, pad_height


def merge_tile_frames(tiles: List[Frame], original_width: int, original_height: int,
                      pad_width: int, pad_height: int, tile_size: tuple) -> Frame:
    """Reassemble row-major tiles into one frame and crop away the padding.

    Args:
        tiles: Tiles in the order produced by ``create_tile_frames``.
        original_width: Width of the unpadded result.
        original_height: Height of the unpadded result.
        pad_width: Padding that was added on the right.
        pad_height: Padding that was added at the bottom.
        tile_size: ``(tile_height, tile_width)`` of each tile.

    Returns:
        A 3-channel ``uint8`` array of shape
        ``(original_height, original_width, 3)``. Grid slots without a
        corresponding tile stay zero.
    """
    tile_height, tile_width = tile_size[0], tile_size[1]
    padded_height = original_height + pad_height
    padded_width = original_width + pad_width

    # Reconstruct the image from tiles
    result = numpy.zeros((padded_height, padded_width, 3), dtype=numpy.uint8)

    tile_idx = 0
    for y in range(0, padded_height, tile_height):
        for x in range(0, padded_width, tile_width):
            if tile_idx < len(tiles):
                tile = tiles[tile_idx]
                result[y:y + tile_height, x:x + tile_width] = tile
                tile_idx += 1

    # Remove padding and return to original size
    if pad_height > 0 or pad_width > 0:
        result = result[:original_height, :original_width]

    return result


def _upscale_tile(tile: Frame, scale: int) -> Frame:
    """Resize ``tile`` by ``scale`` with bicubic interpolation (no model)."""
    tile_height, tile_width = tile.shape[:2]
    # Tiles are slices of a larger array; hand OpenCV a contiguous copy.
    return cv2.resize(
        numpy.ascontiguousarray(tile),
        (tile_width * scale, tile_height * scale),
        interpolation=cv2.INTER_CUBIC
    )


def _enhance_tiles(processor: Any, temp_frame: Frame) -> Frame:
    """Enhance ``temp_frame`` tile by tile and merge the result.

    Unlike ``enhance_frame_with_tiling`` this helper lets exceptions from
    tiling or merging propagate, so callers can choose their own fallback.
    A tile whose enhancement fails is replaced by a bicubic upscale of the
    same size, so the remaining enhanced tiles can still be merged.
    """
    config = get_model_config()['real_esrgan_x4']
    tile_size = (config['tile_size'], config['tile_size'])
    scale = config['scale']

    # Create tiles for processing
    tiles, pad_width, pad_height = create_tile_frames(temp_frame, tile_size)

    enhanced_tiles = []
    with THREAD_SEMAPHORE:
        for tile in tiles:
            try:
                # Process each tile individually to manage memory
                enhanced_tile, _ = processor.enhance(tile, outscale=scale)
            except Exception as e:
                print(f"⚠️ Tile enhancement failed: {e}")
                # The merge expects every tile at the scaled size; using the
                # raw tile here would make the merge fail on a shape mismatch.
                enhanced_tile = _upscale_tile(tile, scale)
            enhanced_tiles.append(enhanced_tile)

    # Merge tiles back together
    original_height, original_width = temp_frame.shape[:2]
    return merge_tile_frames(
        enhanced_tiles,
        original_width * scale,
        original_height * scale,
        pad_width * scale,
        pad_height * scale,
        (tile_size[0] * scale, tile_size[1] * scale)
    )


def enhance_frame_with_tiling(temp_frame: Frame) -> Frame:
    """Enhance a frame tile by tile; never raises.

    Returns:
        The enhanced frame, scaled by the configured model scale, or the
        original ``temp_frame`` when no processor is available or anything
        in the tiling pipeline fails.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Real-ESRGAN processor not available, returning original frame")
            return temp_frame

        return _enhance_tiles(processor, temp_frame)

    except Exception as e:
        print(f"⚠️ Enhanced tiling failed: {e}")
        return temp_frame


def enhance_frame(temp_frame: Frame) -> Frame:
    """Enhance a frame, preferring the tiled path with a whole-frame fallback.

    The tiled path is tried first. If it raises, the whole frame is passed
    to the processor with ``outscale=1``. If that fails too, or no processor
    is available, the original frame is returned.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Frame enhancer not available, returning original frame")
            return temp_frame

        # Try enhanced tiling method first. The raising helper is called
        # directly: enhance_frame_with_tiling swallows every error, which
        # would make the fallback below unreachable.
        try:
            return _enhance_tiles(processor, temp_frame)
        except Exception as e:
            print(f"⚠️ Tiling method failed: {e}, trying simple enhancement")

        # Fallback to original method. The semaphore is not held here: the
        # ``with`` block inside _enhance_tiles released it on the way out.
        with THREAD_SEMAPHORE:
            enhanced_frame, _ = processor.enhance(temp_frame, outscale=1)
        return enhanced_frame

    except Exception as e:
        print(f"⚠️ Frame enhancement failed completely: {e}")
        return temp_frame


def blend_frame(original_frame: Frame, enhanced_frame: Frame, blend_ratio: float = 0.8) -> Frame:
    """Blend the original and enhanced frames.

    The original is resized to the enhanced frame's size when their shapes
    differ.

    Args:
        original_frame: Frame before enhancement.
        enhanced_frame: Frame after enhancement.
        blend_ratio: Weight of the enhanced frame; the original gets
            ``1 - blend_ratio``.

    Returns:
        The weighted sum, or ``enhanced_frame`` unchanged if blending fails.
    """
    try:
        if original_frame.shape != enhanced_frame.shape:
            original_frame = cv2.resize(original_frame, (enhanced_frame.shape[1], enhanced_frame.shape[0]))

        # Convert blend ratio (0-1 where 1 = full enhancement)
        return cv2.addWeighted(original_frame, 1 - blend_ratio, enhanced_frame, blend_ratio, 0)
    except Exception as e:
        print(f"⚠️ Frame blending failed: {e}")
        return enhanced_frame


def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
    """Enhance a single frame.

    ``source_face`` and ``reference_face`` are accepted but not used here.
    Returns the original frame if enhancement raises.
    """
    try:
        return enhance_frame(temp_frame)
    except Exception as e:
        print(f"⚠️ Error in process_frame: {e}")
        return temp_frame


def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
    """Enhance every frame file in place.

    Args:
        source_path: Not used by this function.
        temp_frame_paths: Image files to read, enhance and overwrite.
        update: Optional progress callback, invoked once per frame path (and
            once if the processor is unavailable).
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Frame enhancer not available, skipping frame enhancement")
            if update:
                update()
            return

        for temp_frame_path in temp_frame_paths:
            try:
                temp_frame = cv2.imread(temp_frame_path)
                if temp_frame is not None:
                    result_frame = process_frame(None, None, temp_frame)
                    # imwrite signals failure through its return value only.
                    if not cv2.imwrite(temp_frame_path, result_frame):
                        print(f"⚠️ Failed to write frame: {temp_frame_path}")
                else:
                    print(f"⚠️ Failed to read frame: {temp_frame_path}")
            except Exception as e:
                print(f"⚠️ Error processing frame {temp_frame_path}: {e}")

            if update:
                update()

    except Exception as e:
        print(f"⚠️ Error in process_frames: {e}")


def process_image(source_path: str, target_path: str, output_path: str) -> None:
    """Enhance one image file and write the result to ``output_path``.

    When no processor is available the target file is copied to the output
    path unchanged. ``source_path`` is not used by this function.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Frame enhancer not available, copying original image")
            import shutil
            shutil.copy2(target_path, output_path)
            return

        target_frame = cv2.imread(target_path)
        if target_frame is not None:
            result = process_frame(None, None, target_frame)
            # imwrite signals failure through its return value only.
            if not cv2.imwrite(output_path, result):
                print(f"⚠️ Failed to write image: {output_path}")
        else:
            print(f"⚠️ Failed to read image: {target_path}")

    except Exception as e:
        print(f"⚠️ Error in process_image: {e}")


def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
    """Run ``process_frames`` over the frame files via the core video helper.

    ``source_path`` is not forwarded; ``None`` is passed in its place.
    """
    try:
        frame_processors.process_video(None, temp_frame_paths, process_frames)
    except Exception as e:
        print(f"⚠️ Error in process_video: {e}")


# Additional utility functions inspired by FaceFusion
def get_model_scale() -> int:
    """Return the configured scale factor, or 1 if it cannot be read."""
    try:
        return get_model_config()['real_esrgan_x4']['scale']
    except Exception:
        return 1


def prepare_frame(frame: Frame) -> Frame:
    """Cast the frame to ``uint8`` if needed; return it as-is on failure."""
    try:
        if frame.dtype != numpy.uint8:
            frame = frame.astype(numpy.uint8)
        return frame
    except Exception:
        return frame


def normalize_frame(frame: Frame) -> Frame:
    """Clip values to 0-255 and cast to ``uint8``; return as-is on failure."""
    try:
        return numpy.clip(frame, 0, 255).astype(numpy.uint8)
    except Exception:
        return frame