from typing import Any, Callable, List

import cv2
import threading
import shutil
import numpy as np
import os

# Environment fixes
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'

import SwitcherAI.globals
import SwitcherAI.processors.frame.core as frame_processors
from SwitcherAI import wording
from SwitcherAI.core import update_status
from SwitcherAI.face_analyser import get_many_faces, get_one_face
from SwitcherAI.typing import Frame, Face
from SwitcherAI.utilities import conditional_download, resolve_relative_path, is_image, is_video

# Global variables matching the frame-processor pattern
FRAME_PROCESSOR = None
THREAD_SEMAPHORE = threading.Semaphore()
THREAD_LOCK = threading.Lock()
NAME = 'FACEFUSION.FRAME_PROCESSOR.LIP_SYNCER'

def get_frame_processor() -> Any:
    """Get the lip sync processor - an ONNX Runtime session, as in FaceFusion"""
    global FRAME_PROCESSOR
    with THREAD_LOCK:
        if FRAME_PROCESSOR is None:
            try:
                # Get the model name from globals
                model_name = getattr(SwitcherAI.globals, 'lip_syncer_model', 'wav2lip_gan_96')
                model_path = resolve_relative_path(f'../.assets/models/{model_name}.onnx')
                print(f"[{NAME}] Loading model: {model_path}")
                if os.path.exists(model_path):
                    # Load the ONNX model like FaceFusion does
                    import onnxruntime
                    providers = getattr(SwitcherAI.globals, 'execution_providers', ['CPUExecutionProvider'])
                    FRAME_PROCESSOR = onnxruntime.InferenceSession(model_path, providers=providers)
                    print(f"[{NAME}] ONNX model loaded successfully")
                else:
                    print(f"[{NAME}] Model file not found: {model_path}")
                    FRAME_PROCESSOR = None
            except ImportError:
                print(f"[{NAME}] onnxruntime not available, using passthrough mode")
                FRAME_PROCESSOR = None
            except Exception as e:
                print(f"[{NAME}] Error loading ONNX model: {e}")
                FRAME_PROCESSOR = None
    return FRAME_PROCESSOR
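
# Usage sketch (hedged): a None session means "no model available", and the
# callers below fall back to passthrough behaviour.
#
#   session = get_frame_processor()
#   if session is not None:
#       outputs = session.run(None, {...})   # see forward() below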

def clear_frame_processor() -> None:
    """Clear the cached frame processor"""
    global FRAME_PROCESSOR
    FRAME_PROCESSOR = None

def pre_check() -> bool:
    """Pre-check for lip syncer requirements"""
    print(f"[{NAME}] Pre-check starting...")
    try:
        # Check if we need to download models
        download_directory_path = resolve_relative_path('../.assets/models')
        # Get the model name from globals
        model_name = getattr(SwitcherAI.globals, 'lip_syncer_model', 'wav2lip_gan_96')
        model_path = os.path.join(download_directory_path, f'{model_name}.onnx')
        if not os.path.exists(model_path):
            print(f"[{NAME}] Model not found: {model_path}")
            # Model download URLs (placeholders - point these at the actual model host)
            model_urls = {
                'wav2lip_96': ['https://example.com/wav2lip_96.onnx'],
                'wav2lip_gan_96': ['https://example.com/wav2lip_gan_96.onnx']
            }
            if model_name in model_urls:
                print(f"[{NAME}] Attempting to download {model_name}")
                conditional_download(download_directory_path, model_urls[model_name])
        print(f"[{NAME}] Pre-check passed")
        return True
    except Exception as e:
        print(f"[{NAME}] Pre-check error: {e}")
        # Lenient by design: let the pipeline continue in passthrough mode
        return True

def pre_process() -> bool:
    """Pre-process initialization"""
    print(f"[{NAME}] Pre-processing...")
    # Check the target type like FaceFusion does
    if not is_image(SwitcherAI.globals.target_path) and not is_video(SwitcherAI.globals.target_path):
        update_status(wording.get('select_image_or_video_target') + wording.get('exclamation_mark'), NAME)
        return False
    print(f"[{NAME}] Pre-processing completed")
    return True

def post_process() -> None:
    """Post-process cleanup"""
    clear_frame_processor()
    print(f"[{NAME}] Post-processing completed")

def prepare_audio_frame(audio_frame: np.ndarray) -> np.ndarray:
    """Prepare an audio frame like FaceFusion - log-compress the mel spectrogram"""
    # Floor at exp(-5 * ln 10) == 1e-5 so log10 never sees zero, then rescale
    # into roughly [-4, 4] and add batch and channel dimensions
    audio_frame = np.maximum(np.exp(-5 * np.log(10)), audio_frame)
    audio_frame = np.log10(audio_frame) * 1.6 + 3.2
    audio_frame = audio_frame.clip(-4, 4).astype(np.float32)
    audio_frame = np.expand_dims(audio_frame, axis=(0, 1))
    return audio_frame
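
# Shape sketch (assuming the 80x16 mel window used elsewhere in this module):
#   (80, 16) float -> (1, 1, 80, 16) float32 in [-4, 4]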

def prepare_crop_frame(crop_vision_frame: np.ndarray) -> np.ndarray:
    """Prepare the crop frame like FaceFusion - wav2lip-style two-stream input"""
    crop_vision_frame = np.expand_dims(crop_vision_frame, axis=0)
    prepare_vision_frame = crop_vision_frame.copy()
    prepare_vision_frame[:, 48:] = 0  # Mask the bottom half (mouth region)
    crop_vision_frame = np.concatenate((prepare_vision_frame, crop_vision_frame), axis=3)
    crop_vision_frame = crop_vision_frame.transpose(0, 3, 1, 2).astype('float32') / 255.0
    return crop_vision_frame
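
# Shape sketch: (96, 96, 3) uint8 BGR
#   -> batch:                     (1, 96, 96, 3)
#   -> masked copy + original,
#      concatenated on channels:  (1, 96, 96, 6)
#   -> NCHW float32 in [0, 1]:    (1, 6, 96, 96)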

def normalize_close_frame(crop_vision_frame: np.ndarray) -> np.ndarray:
    """Normalize the model output back to an image like FaceFusion"""
    crop_vision_frame = crop_vision_frame[0].transpose(1, 2, 0)
    crop_vision_frame = crop_vision_frame.clip(0, 1) * 255
    crop_vision_frame = crop_vision_frame.astype(np.uint8)
    return crop_vision_frame
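
# Shape sketch - the inverse of the vision half of prepare_crop_frame:
#   (1, 3, 96, 96) float in [0, 1] -> (96, 96, 3) uint8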

def forward(temp_audio_frame: np.ndarray, close_vision_frame: np.ndarray) -> np.ndarray:
    """Forward pass through the model like FaceFusion"""
    lip_syncer = get_frame_processor()
    if lip_syncer is None:
        # Passthrough mode: return the unmasked vision channels (3:6) so the
        # output keeps the (1, 3, 96, 96) shape the caller expects
        return close_vision_frame[:, 3:]
    try:
        with THREAD_SEMAPHORE:
            # Map model inputs by name - FaceFusion uses 'source' (audio) and 'target' (video)
            inputs = {}
            for model_input in lip_syncer.get_inputs():
                name = model_input.name
                if 'source' in name.lower() or 'audio' in name.lower() or 'mel' in name.lower():
                    inputs[name] = temp_audio_frame
                elif 'target' in name.lower() or 'video' in name.lower() or 'frame' in name.lower():
                    inputs[name] = close_vision_frame
            # Run inference
            close_vision_frame = lip_syncer.run(None, inputs)[0]
        return close_vision_frame
    except Exception as e:
        print(f"[{NAME}] Forward pass error: {e}")
        # Fall back to the unmasked vision channels on failure
        return close_vision_frame[:, 3:]
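
# Debugging sketch (hedged - actual input names depend on how the ONNX graph
# was exported; FaceFusion's wav2lip models use 'source' and 'target'):
#
#   session = get_frame_processor()
#   if session is not None:
#       for model_input in session.get_inputs():
#           print(model_input.name, model_input.shape)
#       # e.g. source (1, 1, 80, 16), target (1, 6, 96, 96)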

def sync_lip(target_face: Face, temp_audio_frame: np.ndarray, temp_vision_frame: Frame) -> Frame:
    """Main lip sync function following FaceFusion's approach"""
    try:
        # Fall back to a silent mel window when no audio frame is provided
        if temp_audio_frame is None:
            # Empty mel spectrogram (80 mel bands x 16 time steps)
            temp_audio_frame = np.zeros((80, 16), dtype=np.float32)
        # Prepare the audio frame
        temp_audio_frame = prepare_audio_frame(temp_audio_frame)
        # Extract the face region from the bounding box
        if not hasattr(target_face, 'bbox'):
            return temp_vision_frame
        x1, y1, x2, y2 = map(int, target_face.bbox)
        # Clamp coordinates to the frame bounds
        h, w = temp_vision_frame.shape[:2]
        x1 = max(0, min(x1, w - 1))
        y1 = max(0, min(y1, h - 1))
        x2 = max(0, min(x2, w - 1))
        y2 = max(0, min(y2, h - 1))
        if x2 <= x1 or y2 <= y1:
            return temp_vision_frame
        # Extract and resize the face region to the model's 96x96 input size
        face_region = temp_vision_frame[y1:y2, x1:x2]
        close_vision_frame = cv2.resize(face_region, (96, 96))
        # Prepare the crop frame
        close_vision_frame = prepare_crop_frame(close_vision_frame)
        # Forward pass
        close_vision_frame = forward(temp_audio_frame, close_vision_frame)
        # Normalize the output
        close_vision_frame = normalize_close_frame(close_vision_frame)
        # Resize back to the original region and paste
        close_vision_frame = cv2.resize(close_vision_frame, (x2 - x1, y2 - y1))
        # Simple paste-back (no blending)
        result_frame = temp_vision_frame.copy()
        result_frame[y1:y2, x1:x2] = close_vision_frame
        return result_frame
    except Exception as e:
        print(f"[{NAME}] Lip sync error: {e}")
        return temp_vision_frame
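
# Usage sketch (hedged - assumes a face object with a valid .bbox from
# get_many_faces and a BGR frame from cv2.imread):
#
#   faces = get_many_faces(frame)
#   if faces:
#       mel_window = np.zeros((80, 16), dtype=np.float32)  # silence
#       frame = sync_lip(faces[0], mel_window, frame)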

def process_frame(source_face: Face, reference_face: Face, temp_frame: Frame) -> Frame:
    """Process a single frame (source_face and reference_face are kept for the
    frame-processor interface but are unused here)"""
    try:
        # Get all faces in the frame
        many_faces = get_many_faces(temp_frame)
        if not many_faces:
            return temp_frame
        # Apply lip sync to each detected face
        result_frame = temp_frame
        for target_face in many_faces:
            # Use a silent audio frame for now
            temp_audio_frame = np.zeros((80, 16), dtype=np.float32)
            result_frame = sync_lip(target_face, temp_audio_frame, result_frame)
        return result_frame
    except Exception as e:
        print(f"[{NAME}] Error processing frame: {e}")
        return temp_frame

def process_frames(source_path: str, temp_frame_paths: List[str], update: Callable[[], None]) -> None:
    """Process multiple frames"""
    total_frames = len(temp_frame_paths)
    print(f"[{NAME}] Processing {total_frames} frames")
    for i, temp_frame_path in enumerate(temp_frame_paths):
        try:
            # Read the frame
            temp_frame = cv2.imread(temp_frame_path)
            if temp_frame is None:
                continue
            # Process the frame and write it back in place
            result_frame = process_frame(None, None, temp_frame)
            cv2.imwrite(temp_frame_path, result_frame)
            # Update progress
            if update:
                update()
            # Periodic progress feedback
            if i % 100 == 0:
                print(f"[{NAME}] Progress: {i}/{total_frames} frames")
        except Exception as e:
            print(f"[{NAME}] Error processing frame {i}: {e}")
            continue
    print(f"[{NAME}] Frame processing completed")

def process_image(source_path: str, target_path: str, output_path: str) -> None:
    """Process a single image"""
    try:
        print(f"[{NAME}] Processing image: {os.path.basename(target_path)}")
        # Read the image; fall back to copying the original if it cannot be read
        target_frame = cv2.imread(target_path)
        if target_frame is None:
            shutil.copy2(target_path, output_path)
            return
        # Process the frame and save the result
        result_frame = process_frame(None, None, target_frame)
        cv2.imwrite(output_path, result_frame)
        print(f"[{NAME}] Image processing completed")
    except Exception as e:
        print(f"[{NAME}] Error processing image: {e}")
        # Fallback: copy the original
        shutil.copy2(target_path, output_path)

def process_video(source_path: str, temp_frame_paths: List[str]) -> None:
    """Process a video using the frame processor core"""
    frame_processors.process_video(source_path, temp_frame_paths, process_frames)
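
if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the pipeline): round-trips
    # synthetic data through the prepare/normalize helpers so the tensor
    # shapes can be checked without a model or media files on disk.
    mel = prepare_audio_frame(np.zeros((80, 16), dtype=np.float32))
    crop = prepare_crop_frame(np.zeros((96, 96, 3), dtype=np.uint8))
    restored = normalize_close_frame(crop[:, 3:])
    print(f"audio tensor: {mel.shape}")        # expected: (1, 1, 80, 16)
    print(f"crop tensor: {crop.shape}")        # expected: (1, 6, 96, 96)
    print(f"restored crop: {restored.shape}")  # expected: (96, 96, 3)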