crash10155's picture
Update SwitcherAI/processors/frame/modules/frame_enhancer.py
d454ef7 verified
from typing import Any, List, Callable, Dict, Optional
import cv2
import threading
import numpy
from functools import lru_cache
from pathlib import Path
import SwitcherAI.processors.frame.core as frame_processors
from SwitcherAI.typing import Frame, Face
from SwitcherAI.utilities import conditional_download, resolve_relative_path
# Module-level shared state used by the functions below.
# Identifier string for this frame-processor module.
NAME: str = 'FACEFUSION.FRAME_PROCESSOR.FRAME_ENHANCER'
# Lazily built Real-ESRGAN upscaler; None until get_frame_processor() creates it.
FRAME_PROCESSOR: Any = None
# Held by get_frame_processor() while it checks/creates FRAME_PROCESSOR.
THREAD_LOCK = threading.Lock()
# Single-permit semaphore held around processor.enhance() calls, so at most
# one thread runs the upscaler at a time.
THREAD_SEMAPHORE = threading.Semaphore(1)
# Enhanced model configuration inspired by FaceFusion
@lru_cache(maxsize=None)
def get_model_config() -> Dict[str, Any]:
    """Return the memoised settings for every supported upscaling model.

    Each entry maps a model key to the weight-file location plus the RRDBNet
    and tiling hyper-parameters used to build it. Because the result is
    cached by ``lru_cache``, call ``get_model_config.cache_clear()`` to force
    it to be rebuilt.
    """
    models_dir = resolve_relative_path('../.assets/models')
    # resolve_relative_path may hand back a plain string; normalise to Path
    # so the weight file can be joined with the ``/`` operator.
    if isinstance(models_dir, str):
        models_dir = Path(models_dir)
    x4_settings = dict(
        model_path=models_dir / 'RealESRGAN_x4plus.pth',
        scale=4,
        tile_size=256,
        tile_pad=16,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
    )
    return {'real_esrgan_x4': x4_settings}
def get_frame_processor() -> Any:
    """Return the shared Real-ESRGAN upscaler, creating it on first use.

    Creation runs under ``THREAD_LOCK`` so concurrent callers build at most
    one instance. If the weight file is missing, ``pre_check`` is invoked to
    try to download it. Import or initialisation failures are printed and
    ``None`` is returned; nothing is cached on failure, so a later call tries
    again.
    """
    global FRAME_PROCESSOR
    with THREAD_LOCK:
        if FRAME_PROCESSOR is not None:
            return FRAME_PROCESSOR
        try:
            # Optional heavy dependencies are imported lazily so this module
            # can still be imported when Real-ESRGAN is not installed.
            from basicsr.archs.rrdbnet_arch import RRDBNet
            from realesrgan import RealESRGANer
            import torch

            settings = get_model_config()['real_esrgan_x4']
            weights = settings['model_path']
            if not weights.exists():
                print(f"⚠️ Real-ESRGAN model not found at: {weights}")
                print("🔄 Attempting to download model...")
                if not pre_check():
                    print("❌ Failed to download Real-ESRGAN model")
                    return None
            network = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=settings['num_feat'],
                num_block=settings['num_block'],
                num_grow_ch=settings['num_grow_ch'],
                scale=settings['scale'],
            )
            FRAME_PROCESSOR = RealESRGANer(
                model_path=str(weights),
                model=network,
                device=frame_processors.get_device(),
                tile=settings['tile_size'],
                tile_pad=settings['tile_pad'],
                pre_pad=0,
                scale=settings['scale'],
            )
            # Selects CUDA device 0 whenever CUDA is present.
            if torch.cuda.is_available():
                torch.cuda.set_device(0)
            print("✅ Real-ESRGAN frame processor initialized")
        except ImportError as e:
            print(f"⚠️ Real-ESRGAN not available: {e}")
            print("💡 Install with: pip install realesrgan basicsr")
            FRAME_PROCESSOR = None
        except Exception as e:
            print(f"⚠️ Failed to initialize Real-ESRGAN: {e}")
            FRAME_PROCESSOR = None
        return FRAME_PROCESSOR
def clear_frame_processor() -> None:
    """Drop the shared Real-ESRGAN instance so the next lookup rebuilds it.

    The reset is done while holding ``THREAD_LOCK``, the same lock
    ``get_frame_processor`` holds while it creates the instance. Without
    it, a reset could land in the middle of an initialisation and then be
    overwritten when that initialisation assigns the new processor.
    """
    global FRAME_PROCESSOR
    with THREAD_LOCK:
        FRAME_PROCESSOR = None
def pre_check() -> bool:
    """Ensure the Real-ESRGAN x4plus weights are available on disk.

    Creates the models directory if needed, hands the weights URL to
    ``conditional_download``, then verifies the file exists and is non-empty.

    Returns:
        True when the weight file is present and non-empty, False otherwise.
        Exceptions are printed and reported as False, never raised.
    """
    weights_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
    try:
        models_dir = resolve_relative_path('../.assets/models')
        if isinstance(models_dir, str):
            models_dir = Path(models_dir)
        models_dir.mkdir(parents=True, exist_ok=True)
        conditional_download(str(models_dir), [weights_url])
        weights_file = models_dir / 'RealESRGAN_x4plus.pth'
        # A zero-byte file is treated as a failed download.
        if not (weights_file.exists() and weights_file.stat().st_size > 0):
            print("❌ Real-ESRGAN model download failed or file is empty")
            return False
        size_mb = weights_file.stat().st_size / (1024 * 1024)
        print(f"✅ Real-ESRGAN model verified: {size_mb:.1f}MB")
        return True
    except Exception as error:
        print(f"❌ Real-ESRGAN pre-check failed: {error}")
        return False
def pre_process() -> bool:
    """Report whether frame enhancement can run.

    Builds (or reuses) the shared processor through ``get_frame_processor``.
    Returns False, after printing a warning, when the processor is
    unavailable or when obtaining it raises.
    """
    try:
        if get_frame_processor() is not None:
            return True
        print("⚠️ Real-ESRGAN not available, frame enhancement will be skipped")
    except Exception as error:
        print(f"⚠️ Frame enhancement pre-process failed: {error}")
    return False
def post_process() -> None:
    """Release the shared processor and forget the memoised model settings.

    Both are rebuilt lazily by the next ``get_frame_processor`` /
    ``get_model_config`` call.
    """
    clear_frame_processor()
    get_model_config.cache_clear()  # mirrors FaceFusion's cache reset
def create_tile_frames(temp_vision_frame: Frame, tile_size: tuple = (256, 256)) -> tuple:
    """Split a frame into a row-major list of equally sized tiles.

    The frame is first padded on its bottom and right edges (numpy
    ``reflect`` mode) so both dimensions become exact multiples of the tile
    size. The padding spec covers three axes, so a frame that needs padding
    is assumed to be (height, width, channels).

    Args:
        temp_vision_frame: Frame to split.
        tile_size: ``(tile_height, tile_width)``.

    Returns:
        ``(tiles, pad_width, pad_height)``. ``tiles`` are slices of the padded
        frame, ordered top-to-bottom then left-to-right. The pads are the
        number of columns and rows that were added.
    """
    tile_height, tile_width = tile_size[0], tile_size[1]
    height, width = temp_vision_frame.shape[:2]
    # Negative modulo gives the distance to the next multiple (0 if aligned).
    pad_height = -height % tile_height
    pad_width = -width % tile_width
    padded = temp_vision_frame
    if pad_height > 0 or pad_width > 0:
        padded = numpy.pad(
            padded,
            ((0, pad_height), (0, pad_width), (0, 0)),
            mode='reflect',
        )
    rows, cols = padded.shape[:2]
    tiles = [
        padded[top:top + tile_height, left:left + tile_width]
        for top in range(0, rows, tile_height)
        for left in range(0, cols, tile_width)
    ]
    return tiles, pad_width, pad_height
def merge_tile_frames(tiles: List[Frame], original_width: int, original_height: int,
                      pad_width: int, pad_height: int, tile_size: tuple) -> Frame:
    """Reassemble row-major tiles into one frame and strip the padding.

    Tiles are written left-to-right, top-to-bottom onto a zeroed uint8 canvas
    of shape ``(original_height + pad_height, original_width + pad_width, 3)``.
    This is the same grid order that ``create_tile_frames`` produces. Grid
    cells without a tile stay black; surplus tiles are ignored. Each tile
    must fit its ``tile_size`` cell, otherwise numpy raises on assignment.

    Returns:
        A uint8 array of shape ``(original_height, original_width, 3)``.
    """
    step_y, step_x = tile_size[0], tile_size[1]
    full_height = original_height + pad_height
    full_width = original_width + pad_width
    canvas = numpy.zeros((full_height, full_width, 3), dtype=numpy.uint8)
    grid = [
        (top, left)
        for top in range(0, full_height, step_y)
        for left in range(0, full_width, step_x)
    ]
    # zip stops at the shorter sequence: cells beyond the last tile stay zero.
    for (top, left), tile in zip(grid, tiles):
        canvas[top:top + step_y, left:left + step_x] = tile
    if pad_height > 0 or pad_width > 0:
        canvas = canvas[:original_height, :original_width]
    return canvas
def enhance_frame_with_tiling(temp_frame: Frame) -> Frame:
    """Upscale a frame by running Real-ESRGAN on fixed-size tiles.

    The frame is split with ``create_tile_frames`` using the configured tile
    size. Each tile is passed to ``processor.enhance`` with
    ``outscale=scale``, and the results are stitched back together with
    ``merge_tile_frames``. The returned frame is ``scale`` times the input's
    height and width.

    If enhancing a single tile fails, that tile is resized with
    ``cv2.resize`` to the upscaled tile size instead. Stitching therefore
    still succeeds and the output keeps the geometry of a fully enhanced
    frame. If the processor is unavailable or anything else fails, the
    original frame object is returned unchanged.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Real-ESRGAN processor not available, returning original frame")
            return temp_frame
        config = get_model_config()['real_esrgan_x4']
        tile_size = (config['tile_size'], config['tile_size'])
        scale = config['scale']
        # (height, width) every enhanced tile must have to fit its merge cell.
        scaled_tile_size = (tile_size[0] * scale, tile_size[1] * scale)
        tiles, pad_width, pad_height = create_tile_frames(temp_frame, tile_size)
        enhanced_tiles = []
        # NOTE(review): tiles are enhanced independently with no overlap, so
        # seams at tile borders are possible; RealESRGANer is also configured
        # with its own tile/tile_pad in get_frame_processor.
        with THREAD_SEMAPHORE:
            for tile in tiles:
                try:
                    # Process each tile individually to manage memory
                    enhanced_tile, _ = processor.enhance(tile, outscale=scale)
                except Exception as e:
                    print(f"⚠️ Tile enhancement failed: {e}")
                    # An un-scaled tile would not fit its scaled grid cell and
                    # the merge would raise, so upscale it conventionally.
                    # cv2.resize takes dsize as (width, height).
                    enhanced_tile = cv2.resize(
                        tile,
                        (scaled_tile_size[1], scaled_tile_size[0]),
                        interpolation=cv2.INTER_CUBIC,
                    )
                enhanced_tiles.append(enhanced_tile)
        original_height, original_width = temp_frame.shape[:2]
        return merge_tile_frames(
            enhanced_tiles,
            original_width * scale,
            original_height * scale,
            pad_width * scale,
            pad_height * scale,
            scaled_tile_size,
        )
    except Exception as e:
        print(f"⚠️ Enhanced tiling failed: {e}")
        return temp_frame
def enhance_frame(temp_frame: Frame) -> Frame:
    """Enhance a frame, preferring the tiled path with a whole-frame fallback.

    First tries ``enhance_frame_with_tiling``. If that does not produce a new
    frame, a single ``processor.enhance`` call on the whole frame is used
    instead. The fallback uses the model's scale as ``outscale`` so its
    output has the same size as a successful tiled enhancement. If
    everything fails, the original frame is returned.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Frame enhancer not available, returning original frame")
            return temp_frame
        enhanced = enhance_frame_with_tiling(temp_frame)
        # enhance_frame_with_tiling catches its own exceptions and returns its
        # input object unchanged on failure, so failure is detected by
        # identity rather than by an exception.
        if enhanced is not temp_frame:
            return enhanced
        print("⚠️ Tiling method failed, trying simple enhancement")
        scale = get_model_config()['real_esrgan_x4']['scale']
        with THREAD_SEMAPHORE:
            enhanced_frame, _ = processor.enhance(temp_frame, outscale=scale)
        return enhanced_frame
    except Exception as e:
        print(f"⚠️ Frame enhancement failed completely: {e}")
        return temp_frame
def blend_frame(original_frame: Frame, enhanced_frame: Frame, blend_ratio: float = 0.8) -> Frame:
    """Mix an enhanced frame with its source frame.

    ``blend_ratio`` is the weight given to ``enhanced_frame`` (1.0 means
    enhanced only, 0.0 means original only); the original gets
    ``1 - blend_ratio``. When the shapes differ, the original is first
    resized to the enhanced frame's width and height. On any error, the
    enhanced frame is returned unmodified.
    """
    try:
        if original_frame.shape != enhanced_frame.shape:
            height, width = enhanced_frame.shape[0], enhanced_frame.shape[1]
            original_frame = cv2.resize(original_frame, (width, height))
        keep_original = 1 - blend_ratio
        return cv2.addWeighted(original_frame, keep_original, enhanced_frame, blend_ratio, 0)
    except Exception as error:
        print(f"⚠️ Frame blending failed: {error}")
        return enhanced_frame
def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
    """Enhance a single frame.

    ``source_face`` and ``reference_face`` are accepted but unused here
    (presumably to match the shared frame-processor signature). Any error is
    printed, and the input frame is returned unchanged.
    """
    try:
        enhanced = enhance_frame(temp_frame)
    except Exception as error:
        print(f"⚠️ Error in process_frame: {error}")
        enhanced = temp_frame
    return enhanced
def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
    """Enhance the frame images at ``temp_frame_paths`` in place.

    Each path is read with ``cv2.imread``, enhanced via ``process_frame`` and
    written back to the same path. ``update`` (if given) is called once per
    path. That includes paths that fail to read or process, and every path
    when the enhancer is unavailable, so progress always advances by
    ``len(temp_frame_paths)``. ``source_path`` is not used. Errors are
    printed, never raised.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Frame enhancer not available, skipping frame enhancement")
            # Advance progress once per skipped frame, matching the per-frame
            # update() calls of the normal path below.
            if update:
                for _ in temp_frame_paths:
                    update()
            return
        for temp_frame_path in temp_frame_paths:
            try:
                temp_frame = cv2.imread(temp_frame_path)
                if temp_frame is None:
                    print(f"⚠️ Failed to read frame: {temp_frame_path}")
                else:
                    result_frame = process_frame(None, None, temp_frame)
                    # imwrite can signal failure by returning False rather than raising.
                    if not cv2.imwrite(temp_frame_path, result_frame):
                        print(f"⚠️ Failed to write frame: {temp_frame_path}")
            except Exception as e:
                print(f"⚠️ Error processing frame {temp_frame_path}: {e}")
            if update:
                update()
    except Exception as e:
        print(f"⚠️ Error in process_frames: {e}")
def process_image(source_path: str, target_path: str, output_path: str) -> None:
    """Enhance the image at ``target_path`` and save it to ``output_path``.

    If the enhancer is unavailable, the target file is copied unchanged with
    ``shutil.copy2``. If the target cannot be read, a warning is printed and
    no output is written. ``source_path`` is not used. Errors are printed,
    never raised.
    """
    try:
        if get_frame_processor() is None:
            print("⚠️ Frame enhancer not available, copying original image")
            import shutil
            shutil.copy2(target_path, output_path)
            return
        target_frame = cv2.imread(target_path)
        if target_frame is None:
            print(f"⚠️ Failed to read image: {target_path}")
            return
        cv2.imwrite(output_path, process_frame(None, None, target_frame))
    except Exception as error:
        print(f"⚠️ Error in process_image: {error}")
def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
    """Run ``process_frames`` over the extracted video frames.

    Delegates to ``frame_processors.process_video``, passing ``None`` as the
    source path: ``source_path`` is accepted but ignored, because
    ``process_frames`` does not use it. Errors are printed, never raised.
    """
    try:
        frame_processors.process_video(None, temp_frame_paths, process_frames)
    except Exception as error:
        print(f"⚠️ Error in process_video: {error}")
# Additional utility functions inspired by FaceFusion
def get_model_scale() -> int:
    """Return the upscale factor of the configured Real-ESRGAN model.

    Returns 1 (no scaling) if the configuration cannot be read.
    """
    try:
        return get_model_config()['real_esrgan_x4']['scale']
    except Exception:
        # Narrowed from a bare ``except:``, which would also swallow
        # KeyboardInterrupt and SystemExit.
        return 1
def prepare_frame(frame: Frame) -> Frame:
    """Return ``frame`` as a uint8 array suitable for processing.

    A frame that is already uint8 is returned as the same object. Other
    dtypes are clipped to [0, 255] before casting, like ``normalize_frame``,
    so out-of-range values saturate instead of wrapping around (e.g. -5
    becomes 0, not 251). Inputs that cannot be converted are returned
    unchanged.
    """
    try:
        if frame.dtype != numpy.uint8:
            frame = numpy.clip(frame, 0, 255).astype(numpy.uint8)
        return frame
    except Exception:
        # Best-effort: leave unconvertible input untouched. Narrowed from a
        # bare except so KeyboardInterrupt/SystemExit still propagate.
        return frame
def normalize_frame(frame: Frame) -> Frame:
    """Clip a processed frame to [0, 255] and cast it to uint8.

    Inputs that numpy cannot clip or cast are returned unchanged.
    """
    try:
        return numpy.clip(frame, 0, 255).astype(numpy.uint8)
    except Exception:
        # Best-effort: leave unconvertible input untouched. Narrowed from a
        # bare except so KeyboardInterrupt/SystemExit still propagate.
        return frame