Spaces:
Paused
Paused
from typing import Any, List, Callable, Dict, Tuple | |
import cv2 | |
import insightface | |
import threading | |
import numpy as np | |
from functools import lru_cache | |
from pathlib import Path | |
import SwitcherAI.globals | |
import SwitcherAI.processors.frame.core as frame_processors | |
from SwitcherAI import wording | |
from SwitcherAI.core import update_status | |
from SwitcherAI.face_analyser import get_one_face, get_many_faces, find_similar_faces | |
from SwitcherAI.face_reference import get_face_reference, set_face_reference | |
from SwitcherAI.typing import Face, Frame | |
from SwitcherAI.utilities import resolve_relative_path, is_image, is_video | |
# Lazily-loaded model singletons; clear_frame_processor() resets both to None.
FRAME_PROCESSOR = None
EMBEDDING_CONVERTER = None
# Serialises the lazy loading done in get_frame_processor / get_embedding_converter.
THREAD_LOCK = threading.Lock()
NAME = 'FACEFUSION.FRAME_PROCESSOR.FACE_SWAPPER'


def _inswapper_config(model_path: str) -> Dict:
    """Build one inswapper entry (a fresh dict with fresh lists on every call)."""
    return {
        'path': model_path,
        'type': 'inswapper',
        'size': (128, 128),
        'mean': [0.0, 0.0, 0.0],
        'standard_deviation': [1.0, 1.0, 1.0],
        'requires_converter': False,
    }


# Per-model settings. Paths are relative and are passed through
# resolve_relative_path before use; only local files are used (no downloads).
MODEL_CONFIGS = {
    'inswapper_128': _inswapper_config('../.assets/models/inswapper_128.onnx'),
    'inswapper_128_fp16': _inswapper_config('../.assets/models/inswapper_128_fp16.onnx'),
    'simswap_256': {
        'path': '../.assets/models/simswap_256.onnx',
        'converter_path': '../.assets/models/simswap_256_converter.onnx',
        'type': 'simswap',
        'size': (256, 256),
        'mean': [0.485, 0.456, 0.406],
        'standard_deviation': [0.229, 0.224, 0.225],
        'requires_converter': True,
    },
}

# Model used when SwitcherAI.globals.face_swapper_model is unset or not a MODEL_CONFIGS key.
DEFAULT_MODEL = 'inswapper_128'
def get_current_model_config() -> Dict:
    """Return the MODEL_CONFIGS entry for the active face swapper model.

    The active name is ``SwitcherAI.globals.face_swapper_model`` (DEFAULT_MODEL when
    the attribute is missing); an unknown name falls back to the DEFAULT_MODEL entry.
    """
    selected_name = getattr(SwitcherAI.globals, 'face_swapper_model', DEFAULT_MODEL)
    if selected_name in MODEL_CONFIGS:
        return MODEL_CONFIGS[selected_name]
    return MODEL_CONFIGS[DEFAULT_MODEL]
@lru_cache(maxsize=None)
def get_static_model_initializer(model_path: str) -> np.ndarray:
    """Return the cached embedding projection matrix for ``model_path``.

    The per-model initializer is not extracted from the model file yet: every path
    maps to a 512x512 float32 identity matrix, so the inswapper projection in
    prepare_source_embedding reduces to dividing by the embedding norm.

    Results are memoised per ``model_path`` with ``functools.lru_cache``, which also
    provides the ``cache_clear()`` that post_process calls. The returned array is
    shared between callers, so it is marked read-only to protect the cached copy.
    """
    # TODO: load the real initializer stored in the model instead of identity.
    initializer = np.eye(512, dtype=np.float32)
    initializer.flags.writeable = False
    return initializer
def _model_load_order(preferred: Any, defaults: List[str], known: Any) -> List[str]:
    """Return model names to try: ``preferred`` first, then ``defaults``, de-duplicated.

    ``preferred`` is used only when it is truthy and a member of ``known``. Entries of
    ``defaults`` that are not in ``known`` are dropped; relative order is preserved.
    """
    order = [preferred] if preferred and preferred in known else []
    for name in defaults:
        if name in known and name not in order:
            order.append(name)
    return order


def get_frame_processor() -> Any:
    """Lazily load and return the shared face swap model.

    Candidates are tried in order: the model named by
    ``SwitcherAI.globals.face_swapper_model`` (when it is a MODEL_CONFIGS key), then
    inswapper_128_fp16, inswapper_128 and simswap_256. A candidate is skipped when its
    file is missing, smaller than 1 KiB, or fails to load. The first success is cached
    in FRAME_PROCESSOR and its name written back to
    ``SwitcherAI.globals.face_swapper_model``.

    Returns:
        The loaded model, or None when every candidate failed. Failure is not
        remembered, so the next call tries all candidates again.
    """
    global FRAME_PROCESSOR
    with THREAD_LOCK:
        if FRAME_PROCESSOR is None:
            # The user's chosen model goes first, then the remaining defaults in order.
            candidates = _model_load_order(
                getattr(SwitcherAI.globals, 'face_swapper_model', None),
                ['inswapper_128_fp16', 'inswapper_128', 'simswap_256'],
                MODEL_CONFIGS,
            )
            for model_name in candidates:
                try:
                    print(f"🔄 Trying to load face swap model: {model_name}")
                    # Path() accepts both str and Path results from resolve_relative_path.
                    model_path = Path(resolve_relative_path(MODEL_CONFIGS[model_name]['path']))
                    if not model_path.exists():
                        print(f"❌ Model {model_name} not found at: {model_path}")
                        continue
                    # A file under 1 KiB is treated as corrupted rather than a real model.
                    if model_path.stat().st_size < 1024:
                        print(f"⚠️ {model_name} appears corrupted (file too small), skipping...")
                        continue
                    model = insightface.model_zoo.get_model(
                        str(model_path),
                        providers=SwitcherAI.globals.execution_providers,
                    )
                    # Guard against a None result so it is not mistaken for a successful load.
                    if model is None:
                        print(f"❌ Failed to load {model_name}: no model returned")
                        continue
                    FRAME_PROCESSOR = model
                    # Record what actually loaded so get_current_model_config() agrees with it.
                    SwitcherAI.globals.face_swapper_model = model_name
                    print(f"✅ Successfully loaded face swap model: {model_name}")
                    break
                except Exception as e:
                    print(f"❌ Failed to load {model_name}: {e}")
            if FRAME_PROCESSOR is None:
                print("❌ All face swap models failed to load. Please ensure models are present in .assets/models folder.")
    return FRAME_PROCESSOR
def get_embedding_converter() -> Any:
    """Lazily load the embedding converter needed by the active model.

    Returns None straight away when the active config does not set
    ``requires_converter``. Otherwise the file at ``converter_path`` is loaded once,
    under THREAD_LOCK, into EMBEDDING_CONVERTER. A missing file or a load error
    prints a message and yields None; the load is attempted again on the next call.
    """
    global EMBEDDING_CONVERTER
    active_config = get_current_model_config()
    if not active_config.get('requires_converter', False):
        return None

    with THREAD_LOCK:
        if EMBEDDING_CONVERTER is not None:
            return EMBEDDING_CONVERTER
        try:
            raw_location = resolve_relative_path(active_config['converter_path'])
            location = Path(raw_location) if isinstance(raw_location, str) else raw_location
            if not location.exists():
                print(f"❌ Embedding converter not found at: {location}")
                print("Please ensure the converter model is present in .assets/models folder.")
                return None
            # NOTE(review): prepare_source_embedding calls .run() on this object; confirm
            # insightface.model_zoo.get_model returns something exposing that for a converter.
            EMBEDDING_CONVERTER = insightface.model_zoo.get_model(
                str(location),
                providers=SwitcherAI.globals.execution_providers
            )
            print("✅ Embedding converter initialized")
        except Exception as error:
            print(f"❌ Failed to initialize embedding converter: {error}")
            EMBEDDING_CONVERTER = None
        return EMBEDDING_CONVERTER
def clear_frame_processor() -> None:
    """Drop the cached swap model and embedding converter.

    Both getters reload lazily when they find None, so the next access reloads.
    """
    global FRAME_PROCESSOR, EMBEDDING_CONVERTER
    FRAME_PROCESSOR = EMBEDDING_CONVERTER = None
def pre_check() -> bool:
    """Check that the active model's files exist locally.

    Always checks the main ``path``; also checks ``converter_path`` when the config sets
    ``requires_converter``. The first missing file is printed with a hint and the
    function returns False. Any unexpected error is printed and also yields False.
    """
    try:
        active_config = get_current_model_config()
        # (config key, message prefix for a missing file, hint line)
        required_files = [
            ('path', "❌ Main model not found at: ", "Please ensure the model file is present in .assets/models folder."),
        ]
        if active_config.get('requires_converter', False):
            required_files.append(
                ('converter_path', "❌ Converter model not found at: ", "Please ensure the converter model file is present in .assets/models folder."),
            )
        for config_key, missing_prefix, hint in required_files:
            raw_location = resolve_relative_path(active_config[config_key])
            location = Path(raw_location) if isinstance(raw_location, str) else raw_location
            if not location.exists():
                print(f"{missing_prefix}{location}")
                print(hint)
                return False
        print("✅ All required models found locally")
        return True
    except Exception as error:
        print(f"❌ Face swap pre-check failed: {error}")
        return False
def pre_process() -> bool:
    """Validate the source, target and model availability before swapping.

    Returns False, after reporting through update_status, when the source path is not
    an image, no face is detected in it, the target is neither image nor video,
    pre_check() fails, or get_frame_processor() yields None. An unexpected exception
    is printed and also produces False. Returns True when every check passes.
    """
    def reject(wording_key: str) -> bool:
        # Report the wording entry followed by the exclamation mark, then signal failure.
        update_status(wording.get(wording_key) + wording.get('exclamation_mark'), NAME)
        return False

    try:
        source_path = SwitcherAI.globals.source_path
        if not is_image(source_path):
            return reject('select_image_source')
        if not get_one_face(cv2.imread(source_path)):
            return reject('no_source_face_detected')
        target_path = SwitcherAI.globals.target_path
        if not (is_image(target_path) or is_video(target_path)):
            return reject('select_image_or_video_target')
        if not pre_check():
            update_status("Required models not found in .assets/models folder", NAME)
            return False
        if get_frame_processor() is None:
            update_status("Face swap processor not available", NAME)
            return False
        return True
    except Exception as e:
        print(f"⚠️ Face swap pre-process failed: {e}")
        return False
def post_process() -> None:
    """Release the cached models and clear the initializer cache.

    ``cache_clear`` exists only when get_static_model_initializer is wrapped by
    ``functools.lru_cache``. It is looked up with getattr so this cleanup never
    raises AttributeError when the function is not wrapped.
    """
    clear_frame_processor()
    cache_clear = getattr(get_static_model_initializer, 'cache_clear', None)
    if cache_clear is not None:
        cache_clear()
def prepare_source_embedding(source_face: Face) -> np.ndarray:
    """Build the source identity embedding for the active swap model.

    - ``inswapper``: the one-row embedding multiplied by the model initializer and
      divided by the embedding's L2 norm (left unscaled if that norm is 0).
    - ``simswap``: the embedding passed through the embedding converter and
      L2-normalised. Falls back to the raw embedding when there is no converter, the
      converter fails, or it returns a zero vector.
    - any other type: the raw embedding.

    Returns:
        A 2-D array with a single row. On any unexpected error the raw
        ``source_face.embedding`` reshaped to one row is returned.
    """
    try:
        config = get_current_model_config()
        model_type = config['type']
        raw_embedding = source_face.embedding.reshape(1, -1)

        if model_type == 'inswapper':
            model_path = resolve_relative_path(config['path'])
            model_initializer = get_static_model_initializer(str(model_path))
            projected = np.dot(raw_embedding, model_initializer)
            norm = np.linalg.norm(raw_embedding)
            # A zero embedding cannot be normalised; avoid a NaN/inf result.
            return projected if norm == 0 else projected / norm

        if model_type == 'simswap':
            converter = get_embedding_converter()
            if converter is not None:
                try:
                    converted = converter.run(None, {'input': source_face.embedding.reshape(-1, 512)})[0].ravel()
                    converted_norm = np.linalg.norm(converted)
                    if converted_norm > 0:
                        return (converted / converted_norm).reshape(1, -1)
                    print("⚠️ Embedding converter returned a zero vector, using raw embedding")
                except Exception as e:
                    # Best-effort: report the failure and fall back to the raw embedding.
                    print(f"⚠️ Embedding converter failed, using raw embedding: {e}")
            return raw_embedding

        return raw_embedding
    except Exception as e:
        print(f"⚠️ Error preparing source embedding: {e}")
        return source_face.embedding.reshape(1, -1)
def prepare_crop_frame(crop_frame: Frame) -> np.ndarray:
    """Turn an (H, W, C) crop into the active model's normalised input tensor.

    Steps:
      1. Reverse the channel order (presumably OpenCV BGR -> RGB; TODO confirm).
      2. Scale to [0, 1].
      3. Normalise with the config's ``mean``/``standard_deviation``.
      4. Move channels first and add a batch axis.

    Returns:
        A float32 array of shape (1, C, H, W). If preparation fails the error is
        printed and the caller's ``crop_frame`` is returned unchanged.
    """
    try:
        config = get_current_model_config()
        mean = np.asarray(config['mean'])
        std = np.asarray(config['standard_deviation'])
        # Use a new name so the except branch still sees the untouched input.
        normalized = (crop_frame[:, :, ::-1] / 255.0 - mean) / std
        return np.expand_dims(normalized.transpose(2, 0, 1), axis=0).astype(np.float32)
    except Exception as e:
        print(f"⚠️ Error preparing crop frame: {e}")
        return crop_frame
def normalize_crop_frame(crop_frame: np.ndarray) -> Frame:
    """Convert a channel-first model output back into an (H, W, C) uint8 image.

    Steps:
      1. Accept (C, H, W), or (1, C, H, W) with the batch axis dropped, and move
         channels last.
      2. For ``simswap`` only, undo the mean/std normalisation.
      3. Clip to [0, 1].
      4. Reverse the channel order and scale to 0-255, rounding to nearest.

    Returns:
        The uint8 image. If conversion fails the error is printed and the original
        ``crop_frame`` cast to uint8 is returned.
    """
    try:
        config = get_current_model_config()
        output = crop_frame
        if output.ndim == 4 and output.shape[0] == 1:
            # A batched single-image output is accepted as well.
            output = output[0]
        output = output.transpose(1, 2, 0)
        if config['type'] in ['simswap']:
            output = output * config['standard_deviation'] + config['mean']
        output = output.clip(0, 1)
        # Round before casting: astype alone truncates and biases every pixel downward.
        return (output[:, :, ::-1] * 255).round().astype(np.uint8)
    except Exception as e:
        print(f"⚠️ Error normalizing crop frame: {e}")
        return crop_frame.astype(np.uint8)
def enhanced_swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
    """Swap ``source_face`` onto ``target_face`` in ``temp_frame`` with the loaded model.

    Every model type currently uses the model's own ``get(..., paste_back=True)``.
    The custom pipeline (prepare_source_embedding / prepare_crop_frame /
    normalize_crop_frame) is not wired in yet, so no source embedding is computed here.

    Returns:
        The swapped frame. If no model is loaded, or the swap raises, the error is
        printed and ``temp_frame`` is returned unchanged.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Face swap processor not available")
            return temp_frame
        # TODO: implement model-specific crop/paste using prepare_source_embedding.
        return processor.get(temp_frame, target_face, source_face, paste_back=True)
    except Exception as e:
        print(f"⚠️ Face swap failed: {e}")
        return temp_frame
def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
    """Swap ``source_face`` onto ``target_face`` inside ``temp_frame``.

    Models of type ``simswap`` or ``inswapper`` are routed through enhanced_swap_face;
    anything else uses the model's ``get(..., paste_back=True)`` directly.

    Returns:
        The resulting frame, or ``temp_frame`` unchanged when no model is loaded or an
        error occurs (the error is printed).
    """
    try:
        swapper = get_frame_processor()
        if swapper is None:
            print("⚠️ Face swap processor not available, skipping swap")
            return temp_frame
        uses_enhanced_path = get_current_model_config()['type'] in ('simswap', 'inswapper')
        if uses_enhanced_path:
            return enhanced_swap_face(source_face, target_face, temp_frame)
        return swapper.get(temp_frame, target_face, source_face, paste_back=True)
    except Exception as error:
        print(f"⚠️ Error in swap_face: {error}")
        return temp_frame
def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
    """Apply the face swap to one frame according to the face-recognition mode.

    If ``SwitcherAI.globals.face_recognition`` contains ``'reference'``, every face that
    find_similar_faces matches against ``reference_face`` (within
    ``reference_face_distance``) is swapped. If it contains ``'many'``, every detected
    face is swapped, largest bounding box first.

    Returns:
        The processed frame. If no model is loaded the input is returned; on error
        the message is printed and the frame as processed so far is returned.
    """
    def bbox_area(face):
        return (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1])

    try:
        if get_frame_processor() is None:
            print("⚠️ Face swap processor not available, skipping frame")
            return temp_frame
        recognition_mode = SwitcherAI.globals.face_recognition
        if 'reference' in recognition_mode:
            matches = find_similar_faces(temp_frame, reference_face, SwitcherAI.globals.reference_face_distance)
            for matched_face in matches or []:
                temp_frame = swap_face(source_face, matched_face, temp_frame)
        if 'many' in recognition_mode:
            detected = get_many_faces(temp_frame)
            if detected:
                for candidate in sorted(detected, key=bbox_area, reverse=True):
                    temp_frame = swap_face(source_face, candidate, temp_frame)
        return temp_frame
    except Exception as error:
        print(f"⚠️ Error in process_frame: {error}")
        return temp_frame
def get_average_face(faces: List[Face]) -> Face:
    """Pick one representative face from a group of source faces.

    Placeholder implementation: embeddings are not averaged yet, so the first face
    is returned (None when ``faces`` is empty or falsy).
    """
    # TODO: average the face embeddings instead of taking the first entry.
    return faces[0] if faces else None
def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
    """Swap the source face into every frame file, overwriting each in place.

    Source-face selection:
      - faces from get_many_faces are ranked by bounding-box area, largest first;
      - with several faces, get_average_face picks the face; with one, it is used as is;
      - if none are found, get_one_face is used instead.

    Frame processing:
      - each readable frame is run through process_frame and written back;
      - ``update`` (if given) is called after every frame, and once when processing is
        skipped early (no model, unreadable source, no source face);
      - per-frame errors are printed and do not stop the batch.
    """
    try:
        if get_frame_processor() is None:
            print("⚠️ Face swap processor not available, skipping frame processing")
            if update:
                update()
            return

        source_frame = cv2.imread(source_path)
        if source_frame is None:
            print(f"⚠️ Failed to read source image: {source_path}")
            if update:
                update()
            return

        # Normalise a None/empty detection result to a list so len() below is safe.
        source_faces = get_many_faces(source_frame) or []
        if source_faces:
            source_faces = sorted(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse=True)
            source_face = get_average_face(source_faces) if len(source_faces) > 1 else source_faces[0]
        else:
            source_face = get_one_face(source_frame)
        if source_face is None:
            print("⚠️ No source face found")
            if update:
                update()
            return

        reference_face = get_face_reference() if 'reference' in SwitcherAI.globals.face_recognition else None
        for temp_frame_path in temp_frame_paths:
            try:
                temp_frame = cv2.imread(temp_frame_path)
                if temp_frame is None:
                    print(f"⚠️ Failed to read frame: {temp_frame_path}")
                else:
                    cv2.imwrite(temp_frame_path, process_frame(source_face, reference_face, temp_frame))
            except Exception as e:
                print(f"⚠️ Error processing frame {temp_frame_path}: {e}")
            if update:
                update()
    except Exception as e:
        print(f"⚠️ Error in process_frames: {e}")
def process_image(source_path: str, target_path: str, output_path: str) -> None:
    """Swap the source face into a single target image and write it to ``output_path``.

    If no swap model is loaded, the target is copied to ``output_path`` unchanged.
    An unreadable source or target image, or a missing source face, prints a warning
    and returns without writing anything. Errors are printed, never raised.
    """
    def largest_first(faces):
        return sorted(faces, key=lambda face: (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1]), reverse=True)

    try:
        if get_frame_processor() is None:
            print("⚠️ Face swap processor not available, copying original image")
            import shutil
            shutil.copy2(target_path, output_path)
            return

        source_frame = cv2.imread(source_path)
        if source_frame is None:
            print(f"⚠️ Failed to read source image: {source_path}")
            return

        detected = get_many_faces(source_frame)
        if detected:
            ranked = largest_first(detected)
            chosen_face = get_average_face(ranked) if len(ranked) > 1 else ranked[0]
        else:
            chosen_face = get_one_face(source_frame)
            if chosen_face is None:
                print("⚠️ No source face found")
                return

        target_frame = cv2.imread(target_path)
        if target_frame is None:
            print(f"⚠️ Failed to read target image: {target_path}")
            return

        if 'reference' in SwitcherAI.globals.face_recognition:
            reference = get_one_face(target_frame, SwitcherAI.globals.reference_face_position)
        else:
            reference = None
        cv2.imwrite(output_path, process_frame(chosen_face, reference, target_frame))
    except Exception as error:
        print(f"⚠️ Error in process_image: {error}")
def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
    """Prepare the reference face, then hand the frames to the shared video runner.

    ``frame_processors.process_video`` receives process_frames as the per-batch
    callback. Any error is printed instead of raised.
    """
    frame_callback = process_frames
    try:
        conditional_set_face_reference(temp_frame_paths)
        frame_processors.process_video(source_path, temp_frame_paths, frame_callback)
    except Exception as error:
        print(f"⚠️ Error in process_video: {error}")
def conditional_set_face_reference(temp_frame_paths: List[str]) -> None:
    """Store a reference face taken from the video frames when reference mode needs one.

    Acts only when ``'reference'`` is in ``SwitcherAI.globals.face_recognition`` and no
    reference face is stored yet. The frame is chosen by
    ``SwitcherAI.globals.reference_frame_number`` (0 when unset or None). The face within
    that frame is chosen by ``reference_face_position``. Problems are printed, never
    raised.
    """
    try:
        if 'reference' not in SwitcherAI.globals.face_recognition or get_face_reference():
            return
        if not temp_frame_paths:
            print("⚠️ No frames available to pick a reference face from")
            return
        # The frame index and the face index are separate settings; the face position
        # must not be used to pick which frame to read.
        frame_number = getattr(SwitcherAI.globals, 'reference_frame_number', 0) or 0
        reference_frame = cv2.imread(temp_frame_paths[frame_number])
        if reference_frame is None:
            print("⚠️ Failed to read reference frame")
            return
        reference_face = get_one_face(reference_frame, SwitcherAI.globals.reference_face_position)
        set_face_reference(reference_face)
    except Exception as e:
        print(f"⚠️ Error setting face reference: {e}")