# SwitcherAI/processors/frame/modules/face_swapper.py
# Last change: "Update SwitcherAI/processors/frame/modules/face_swapper.py" by crash10155 (commit ddf9937, verified)
from typing import Any, List, Callable, Dict, Tuple
import cv2
import insightface
import threading
import numpy as np
from functools import lru_cache
from pathlib import Path
import SwitcherAI.globals
import SwitcherAI.processors.frame.core as frame_processors
from SwitcherAI import wording
from SwitcherAI.core import update_status
from SwitcherAI.face_analyser import get_one_face, get_many_faces, find_similar_faces
from SwitcherAI.face_reference import get_face_reference, set_face_reference
from SwitcherAI.typing import Face, Frame
from SwitcherAI.utilities import resolve_relative_path, is_image, is_video
# Lazily created model instances: filled in by get_frame_processor() and
# get_embedding_converter(), reset to None by clear_frame_processor().
FRAME_PROCESSOR = None
EMBEDDING_CONVERTER = None
# Serialises the lazy initialisation of the two globals above.
THREAD_LOCK = threading.Lock()
NAME = 'FACEFUSION.FRAME_PROCESSOR.FACE_SWAPPER'


def _inswapper_config(file_name: str) -> Dict:
    """Return a fresh config entry for an inswapper-family model file in ../.assets/models."""
    return {
        'path': '../.assets/models/' + file_name,
        'type': 'inswapper',
        'size': (128, 128),
        # Identity normalisation: mean 0 and standard deviation 1 on every channel.
        'mean': [0.0, 0.0, 0.0],
        'standard_deviation': [1.0, 1.0, 1.0],
        'requires_converter': False
    }


# Every swapper model this module can use. Paths are relative (resolved with
# resolve_relative_path()); the files must already be present locally.
MODEL_CONFIGS = {
    'inswapper_128': _inswapper_config('inswapper_128.onnx'),
    'inswapper_128_fp16': _inswapper_config('inswapper_128_fp16.onnx'),
    'simswap_256': {
        'path': '../.assets/models/simswap_256.onnx',
        'converter_path': '../.assets/models/simswap_256_converter.onnx',
        'type': 'simswap',
        'size': (256, 256),
        # The standard ImageNet per-channel mean / standard deviation.
        'mean': [0.485, 0.456, 0.406],
        'standard_deviation': [0.229, 0.224, 0.225],
        'requires_converter': True
    },
}
# Fallback used when SwitcherAI.globals.face_swapper_model is unset or names an
# unknown model; the active model itself is selected through that globals attribute.
DEFAULT_MODEL = 'inswapper_128'
def get_current_model_config() -> Dict:
    """Return the MODEL_CONFIGS entry of the active swapper model.

    The active model name comes from ``SwitcherAI.globals.face_swapper_model``. If that
    attribute is missing, or names a model that MODEL_CONFIGS does not contain, the
    entry of DEFAULT_MODEL is returned instead. The shared dict is returned, not a copy.
    """
    selected = getattr(SwitcherAI.globals, 'face_swapper_model', DEFAULT_MODEL)
    if selected in MODEL_CONFIGS:
        return MODEL_CONFIGS[selected]
    return MODEL_CONFIGS[DEFAULT_MODEL]
@lru_cache(maxsize=None)
def get_static_model_initializer(model_path: str) -> np.ndarray:
    """Return the cached 512x512 matrix applied to inswapper source embeddings.

    Placeholder implementation: nothing is read from the model file. A float32
    identity matrix is returned for every ``model_path``, so multiplying an
    embedding by it leaves the embedding unchanged. ``model_path`` only acts as
    the cache key.

    The array is cached and therefore shared by every caller. It is marked
    read-only so that one caller cannot silently corrupt it for the others.
    """
    # TODO: load the model's real initializer instead of this identity fallback.
    initializer = np.eye(512, dtype=np.float32)
    initializer.setflags(write=False)
    return initializer
def _model_load_order(preferred: Any) -> List[str]:
    """Return the model names to try, with ``preferred`` first when it is a known model."""
    order = ['inswapper_128_fp16', 'inswapper_128', 'simswap_256']
    if preferred and preferred in MODEL_CONFIGS:
        # Move (not duplicate) the preferred model to the front of the queue.
        order = [preferred] + [name for name in order if name != preferred]
    return [name for name in order if name in MODEL_CONFIGS]


def _load_swapper_model(model_name: str) -> Any:
    """Load the model file of ``model_name``; return the model, or None if it is unusable."""
    model_path = Path(resolve_relative_path(MODEL_CONFIGS[model_name]['path']))
    if not model_path.exists():
        print(f"❌ Model {model_name} not found at: {model_path}")
        return None
    # A file under 1 KB cannot be a real swap model; treat it as corrupted.
    if model_path.stat().st_size < 1024:
        print(f"⚠️ {model_name} appears corrupted (file too small), skipping...")
        return None
    model = insightface.model_zoo.get_model(
        str(model_path),
        providers=SwitcherAI.globals.execution_providers
    )
    if model is None:
        # get_model can return None instead of raising (e.g. for an ONNX graph it
        # does not recognise); that must not count as a successful load.
        print(f"❌ Failed to load {model_name}: model type not recognised by insightface")
    return model


def get_frame_processor() -> Any:
    """Return the shared face-swap model, loading it on first use.

    Candidates are tried in this order:

    1. The model named by ``SwitcherAI.globals.face_swapper_model``, if MODEL_CONFIGS
       knows it.
    2. inswapper_128_fp16, inswapper_128 and simswap_256.

    The first candidate that loads is cached in FRAME_PROCESSOR. Its name is
    written back to ``SwitcherAI.globals.face_swapper_model`` so that
    get_current_model_config() describes the model that is actually loaded.

    Returns:
        The loaded model, or None when no candidate could be loaded. In that case
        loading is attempted again on the next call.
    """
    global FRAME_PROCESSOR
    with THREAD_LOCK:
        if FRAME_PROCESSOR is None:
            preferred = getattr(SwitcherAI.globals, 'face_swapper_model', None)
            for model_name in _model_load_order(preferred):
                print(f"🔄 Trying to load face swap model: {model_name}")
                try:
                    model = _load_swapper_model(model_name)
                except Exception as e:
                    print(f"❌ Failed to load {model_name}: {e}")
                    continue
                if model is None:
                    continue
                FRAME_PROCESSOR = model
                SwitcherAI.globals.face_swapper_model = model_name
                print(f"✅ Successfully loaded face swap model: {model_name}")
                break
            if FRAME_PROCESSOR is None:
                print("❌ All face swap models failed to load. Please ensure models are present in .assets/models folder.")
    return FRAME_PROCESSOR
def get_embedding_converter() -> Any:
    """Return the shared embedding-converter model for the active swapper, if it uses one.

    Only models whose config sets ``requires_converter`` need a converter. For all
    others, None is returned without taking the lock. The converter is loaded from
    the config's ``converter_path`` on first use and cached in EMBEDDING_CONVERTER.

    NOTE(review): the cached converter is not invalidated when the active model
    changes; clear_frame_processor() is what resets it.

    Returns:
        The converter model. None when it is not required, the file is missing, or
        loading failed.
    """
    global EMBEDDING_CONVERTER
    config = get_current_model_config()
    if not config.get('requires_converter', False):
        return None
    with THREAD_LOCK:
        if EMBEDDING_CONVERTER is None:
            try:
                converter_path = Path(resolve_relative_path(config['converter_path']))
                if not converter_path.exists():
                    print(f"❌ Embedding converter not found at: {converter_path}")
                    print("Please ensure the converter model is present in .assets/models folder.")
                    return None
                EMBEDDING_CONVERTER = insightface.model_zoo.get_model(
                    str(converter_path),
                    providers=SwitcherAI.globals.execution_providers
                )
                # get_model can return None instead of raising; don't report that as success.
                if EMBEDDING_CONVERTER is None:
                    print("❌ Failed to initialize embedding converter: model type not recognised by insightface")
                else:
                    print("✅ Embedding converter initialized")
            except Exception as e:
                print(f"❌ Failed to initialize embedding converter: {e}")
                EMBEDDING_CONVERTER = None
    return EMBEDDING_CONVERTER
def clear_frame_processor() -> None:
    """Forget the cached swapper and converter models so the next getter call reloads them."""
    global EMBEDDING_CONVERTER, FRAME_PROCESSOR
    FRAME_PROCESSOR = EMBEDDING_CONVERTER = None
def pre_check() -> bool:
    """Verify that the active model's files are present on disk.

    The main model file is always checked. The embedding converter is also
    checked for models whose config sets ``requires_converter``. Nothing is
    downloaded: the first missing file is reported and makes the check fail.

    Returns:
        True when every required file exists. False when one is missing or any
        error occurs.
    """
    try:
        config = get_current_model_config()
        # (config key, "missing" message prefix, hint) for each file that must exist.
        required_files = [
            ('path', "❌ Main model not found at: ", "Please ensure the model file is present in .assets/models folder."),
        ]
        if config.get('requires_converter', False):
            required_files.append(
                ('converter_path', "❌ Converter model not found at: ", "Please ensure the converter model file is present in .assets/models folder.")
            )
        for config_key, missing_prefix, hint in required_files:
            location = Path(resolve_relative_path(config[config_key]))
            if not location.exists():
                print(f"{missing_prefix}{location}")
                print(hint)
                return False
        print("✅ All required models found locally")
        return True
    except Exception as e:
        print(f"❌ Face swap pre-check failed: {e}")
        return False
def pre_process() -> bool:
    """Validate the inputs and make sure a swap model is ready before processing starts.

    Returns False, reporting the reason through update_status, when:

    * the source is not an image, or contains no detectable face;
    * the target is neither an image nor a video;
    * no swap model can be loaded;
    * the loaded model's files are incomplete.

    Unexpected errors are printed and also yield False.
    """
    try:
        if not is_image(SwitcherAI.globals.source_path):
            update_status(wording.get('select_image_source') + wording.get('exclamation_mark'), NAME)
            return False
        if not get_one_face(cv2.imread(SwitcherAI.globals.source_path)):
            update_status(wording.get('no_source_face_detected') + wording.get('exclamation_mark'), NAME)
            return False
        if not is_image(SwitcherAI.globals.target_path) and not is_video(SwitcherAI.globals.target_path):
            update_status(wording.get('select_image_or_video_target') + wording.get('exclamation_mark'), NAME)
            return False
        # Load the processor BEFORE pre_check(). get_frame_processor() may fall back to
        # another model and records it in SwitcherAI.globals.face_swapper_model, which is
        # the model pre_check() then validates. Checking first would reject the run
        # whenever the configured/default model is missing but a fallback is available.
        if get_frame_processor() is None:
            update_status("Face swap processor not available", NAME)
            return False
        if not pre_check():
            update_status("Required models not found in .assets/models folder", NAME)
            return False
        return True
    except Exception as e:
        print(f"⚠️ Face swap pre-process failed: {e}")
        return False
def post_process() -> None:
    """Free per-run state: drop the loaded models, then empty the initializer cache."""
    for release in (clear_frame_processor, get_static_model_initializer.cache_clear):
        release()
def prepare_source_embedding(source_face: Face) -> np.ndarray:
    """Build the source identity embedding in the form the active model expects.

    * inswapper: the raw embedding is multiplied by get_static_model_initializer()
      and divided by the raw embedding's L2 norm.
    * simswap: the embedding is passed through the embedding converter (when one
      is available), then L2-normalised.
    * Any other type, or whenever a step fails or a norm is zero: the raw
      embedding is returned.

    Returns:
        A 2-D array with a single row.
    """
    try:
        config = get_current_model_config()
        model_type = config['type']
        raw_embedding = source_face.embedding.reshape(1, -1)
        if model_type == 'inswapper':
            model_path = resolve_relative_path(config['path'])
            model_initializer = get_static_model_initializer(str(model_path))
            norm = np.linalg.norm(raw_embedding)
            if norm == 0:
                # A zero vector cannot be normalised; avoid returning NaN/inf.
                return raw_embedding
            return np.dot(raw_embedding, model_initializer) / norm
        if model_type == 'simswap':
            converter = get_embedding_converter()
            if converter is not None:
                try:
                    # NOTE(review): assumes the converter exposes an onnxruntime-style
                    # run(output_names, feeds) taking an 'input' tensor -- confirm.
                    converted = converter.run(None, {'input': source_face.embedding.reshape(-1, 512)})[0].ravel()
                    converted_norm = np.linalg.norm(converted)
                    if converted_norm > 0:
                        return (converted / converted_norm).reshape(1, -1)
                    print("⚠️ Embedding converter returned a zero vector, using raw embedding")
                except Exception as e:
                    # Best effort: fall back to the raw embedding, but make the failure visible.
                    print(f"⚠️ Embedding conversion failed, using raw embedding: {e}")
        return raw_embedding
    except Exception as e:
        print(f"⚠️ Error preparing source embedding: {e}")
        return source_face.embedding.reshape(1, -1)
def prepare_crop_frame(crop_frame: Frame) -> np.ndarray:
    """Convert a cropped face image (H, W, C) into a normalised model input tensor.

    The steps are:

    1. Reverse the channel order (presumably OpenCV BGR -> RGB; confirm with callers).
    2. Scale to [0, 1].
    3. Normalise with the active model's mean and standard deviation.
    4. Move channels first and add a batch axis.

    Returns:
        A float32 array of shape (1, C, H, W). If any step fails, the input is
        returned unchanged rather than a half-transformed intermediate.
    """
    try:
        config = get_current_model_config()
        model_mean = config['mean']
        model_std = config['standard_deviation']
        prepared = crop_frame[:, :, ::-1] / 255.0
        prepared = (prepared - model_mean) / model_std
        prepared = prepared.transpose(2, 0, 1)
        return np.expand_dims(prepared, axis=0).astype(np.float32)
    except Exception as e:
        print(f"⚠️ Error preparing crop frame: {e}")
        return crop_frame
def normalize_crop_frame(crop_frame: np.ndarray) -> Frame:
    """Turn a channel-first model output back into an 8-bit image.

    ``crop_frame`` is transposed with ``(1, 2, 0)``, so it must have three axes.
    simswap outputs are first de-normalised with the model's mean and standard
    deviation. Every output is then:

    * clipped to [0, 1];
    * reversed in channel order;
    * scaled by 255 and cast (truncating) to uint8.

    On error, the array as far as it had been transformed is cast to uint8 and
    returned.
    """
    try:
        active = get_current_model_config()
        kind = active['type']
        mean, std = active['mean'], active['standard_deviation']
        crop_frame = crop_frame.transpose(1, 2, 0)
        if kind == 'simswap':
            # Undo the (x - mean) / std normalisation applied to the model's input.
            crop_frame = crop_frame * std + mean
        crop_frame = crop_frame.clip(0, 1)
        crop_frame = crop_frame[:, :, ::-1] * 255
        return crop_frame.astype(np.uint8)
    except Exception as e:
        print(f"⚠️ Error normalizing crop frame: {e}")
        return crop_frame.astype(np.uint8)
def enhanced_swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
    """Swap ``source_face`` onto ``target_face`` in ``temp_frame`` with the loaded model.

    Every model type currently goes through the model's own
    ``get(frame, target_face, source_face, paste_back=True)``. A model-specific
    crop/paste pipeline for non-inswapper types is not implemented yet.

    Returns:
        The swapped frame. ``temp_frame`` unchanged when no processor is loaded or
        the swap raises.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Face swap processor not available")
            return temp_frame
        # TODO: model-specific path for non-inswapper models. No source embedding is
        # prepared here because nothing would consume it until that path exists, and
        # retrying the identical get() call after a failure cannot help.
        return processor.get(temp_frame, target_face, source_face, paste_back=True)
    except Exception as e:
        print(f"⚠️ Face swap failed: {e}")
        return temp_frame
def swap_face(source_face: Face, target_face: Face, temp_frame: Frame) -> Frame:
    """Swap one face, routing simswap/inswapper models through enhanced_swap_face().

    Other model types call the processor's ``get(..., paste_back=True)`` directly.
    ``temp_frame`` is returned unchanged when no processor is loaded or an error occurs.
    """
    try:
        processor = get_frame_processor()
        if processor is None:
            print("⚠️ Face swap processor not available, skipping swap")
            return temp_frame
        enhanced_types = ('simswap', 'inswapper')
        if get_current_model_config()['type'] not in enhanced_types:
            return processor.get(temp_frame, target_face, source_face, paste_back=True)
        return enhanced_swap_face(source_face, target_face, temp_frame)
    except Exception as e:
        print(f"⚠️ Error in swap_face: {e}")
        return temp_frame
def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
    """Swap ``source_face`` into the faces selected by ``SwitcherAI.globals.face_recognition``.

    * 'reference': every face that find_similar_faces() matches against
      ``reference_face`` (using ``reference_face_distance``) is swapped.
    * 'many': every detected face is swapped, largest bounding box first.

    Returns the swapped frame. If an error occurs, the frame with whatever swaps
    had already been applied is returned.
    """
    def bbox_area(face):
        left, top, right, bottom = face.bbox[0], face.bbox[1], face.bbox[2], face.bbox[3]
        return (right - left) * (bottom - top)

    try:
        if get_frame_processor() is None:
            print("⚠️ Face swap processor not available, skipping frame")
            return temp_frame
        if 'reference' in SwitcherAI.globals.face_recognition:
            matches = find_similar_faces(temp_frame, reference_face, SwitcherAI.globals.reference_face_distance)
            for matched_face in matches or []:
                temp_frame = swap_face(source_face, matched_face, temp_frame)
        if 'many' in SwitcherAI.globals.face_recognition:
            detected = get_many_faces(temp_frame)
            for detected_face in sorted(detected or [], key=bbox_area, reverse=True):
                temp_frame = swap_face(source_face, detected_face, temp_frame)
        return temp_frame
    except Exception as e:
        print(f"⚠️ Error in process_frame: {e}")
        return temp_frame
def get_average_face(faces: List[Face]) -> Face:
    """Pick the representative face from ``faces``.

    Despite the name, no averaging happens: the first face is returned, or None
    for an empty or missing list.

    Blending embeddings would also be unsafe for the current callers in this
    module. They pass every face detected in a single source image, and those
    faces may belong to different people.
    """
    return faces[0] if faces else None
def _face_area(face: Face) -> float:
    """Return the area of ``face``'s bounding box (x1, y1, x2, y2)."""
    return (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1])


def _advance_progress(update: Callable[[], None], steps: int) -> None:
    """Call ``update`` ``steps`` times; do nothing when ``update`` is falsy."""
    if update:
        for _ in range(steps):
            update()


def _select_source_face(source_frame: Frame) -> Face:
    """Return the face to swap in from ``source_frame`` (largest first), or None if none is found."""
    source_faces = get_many_faces(source_frame)
    if not source_faces:
        return get_one_face(source_frame)
    source_faces = sorted(source_faces, key=_face_area, reverse=True)
    if len(source_faces) > 1:
        return get_average_face(source_faces)
    return source_faces[0]


def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
    """Swap faces in every frame file of ``temp_frame_paths``, overwriting each file.

    The source face is chosen from the image at ``source_path`` (largest detected
    face). Each frame is read, passed through process_frame() and written back to
    the same path.

    ``update`` (if given) is called once per frame path, including on the early-exit
    paths, so a per-frame counter (presumably a progress bar) still reaches its total.
    Errors are printed, not raised.
    """
    try:
        if get_frame_processor() is None:
            print("⚠️ Face swap processor not available, skipping frame processing")
            _advance_progress(update, len(temp_frame_paths))
            return
        source_frame = cv2.imread(source_path)
        if source_frame is None:
            print(f"⚠️ Failed to read source image: {source_path}")
            _advance_progress(update, len(temp_frame_paths))
            return
        source_face = _select_source_face(source_frame)
        if source_face is None:
            print("⚠️ No source face found")
            _advance_progress(update, len(temp_frame_paths))
            return
        reference_face = get_face_reference() if 'reference' in SwitcherAI.globals.face_recognition else None
        for temp_frame_path in temp_frame_paths:
            try:
                temp_frame = cv2.imread(temp_frame_path)
                if temp_frame is not None:
                    result_frame = process_frame(source_face, reference_face, temp_frame)
                    cv2.imwrite(temp_frame_path, result_frame)
                else:
                    print(f"⚠️ Failed to read frame: {temp_frame_path}")
            except Exception as e:
                print(f"⚠️ Error processing frame {temp_frame_path}: {e}")
            # Count the frame even if it failed, so progress stays in step.
            _advance_progress(update, 1)
    except Exception as e:
        print(f"⚠️ Error in process_frames: {e}")
def process_image(source_path: str, target_path: str, output_path: str) -> None:
    """Swap the source face into the target image and write the result to ``output_path``.

    When no swap model is loaded, the target is copied to ``output_path`` unchanged.
    If both paths are the same file, the copy is skipped: the output already holds
    the target, and shutil would raise SameFileError. Failures to read or detect
    faces are printed and leave ``output_path`` untouched. A failed write is
    reported. Errors are printed, not raised.
    """
    try:
        if get_frame_processor() is None:
            print("⚠️ Face swap processor not available, copying original image")
            import shutil
            try:
                shutil.copy2(target_path, output_path)
            except shutil.SameFileError:
                pass  # output_path already is the target image; nothing to copy
            return
        source_frame = cv2.imread(source_path)
        if source_frame is None:
            print(f"⚠️ Failed to read source image: {source_path}")
            return
        source_faces = get_many_faces(source_frame)
        if source_faces:
            # Prefer the largest source face.
            source_faces = sorted(source_faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]), reverse=True)
            source_face = source_faces[0]
            if len(source_faces) > 1:
                source_face = get_average_face(source_faces)
        else:
            source_face = get_one_face(source_frame)
        if source_face is None:
            print("⚠️ No source face found")
            return
        target_frame = cv2.imread(target_path)
        if target_frame is None:
            print(f"⚠️ Failed to read target image: {target_path}")
            return
        reference_face = get_one_face(target_frame, SwitcherAI.globals.reference_face_position) if 'reference' in SwitcherAI.globals.face_recognition else None
        result_frame = process_frame(source_face, reference_face, target_frame)
        if not cv2.imwrite(output_path, result_frame):
            print(f"⚠️ Failed to write output image: {output_path}")
    except Exception as e:
        print(f"⚠️ Error in process_image: {e}")
def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
    """Swap faces across a video's extracted frame files.

    First, conditional_set_face_reference() captures a face reference if one is
    needed. Then the frame paths are handed to the shared frame-processor driver,
    with process_frames as the per-batch callback. Errors are printed, not raised.
    """
    try:
        conditional_set_face_reference(temp_frame_paths)
        run_video = frame_processors.process_video
        run_video(source_path, temp_frame_paths, process_frames)
    except Exception as error:
        print(f"⚠️ Error in process_video: {error}")
def conditional_set_face_reference(temp_frame_paths: List[str]) -> None:
    """Store the reference face for 'reference' recognition if none is set yet.

    The reference frame is ``temp_frame_paths[SwitcherAI.globals.reference_frame_number]``.
    The first frame is used when that setting is absent or None. Within that frame,
    the face at ``reference_face_position`` is stored via set_face_reference().
    Errors (including an out-of-range frame number) are printed, not raised.
    """
    try:
        if 'reference' not in SwitcherAI.globals.face_recognition or get_face_reference():
            return
        # reference_face_position indexes faces within a frame (see get_one_face below),
        # so it must not be used to pick the frame itself.
        frame_number = getattr(SwitcherAI.globals, 'reference_frame_number', None) or 0
        reference_frame_path = temp_frame_paths[frame_number]
        reference_frame = cv2.imread(reference_frame_path)
        if reference_frame is None:
            print(f"⚠️ Failed to read reference frame: {reference_frame_path}")
            return
        reference_face = get_one_face(reference_frame, SwitcherAI.globals.reference_face_position)
        set_face_reference(reference_face)
    except Exception as e:
        print(f"⚠️ Error setting face reference: {e}")