# NOTE(review): removed non-Python page residue ("Spaces: / Sleeping")
# that was captured ahead of the module docstring during extraction.
"""MuseTalk Inference Module

Refactored for long-form generation (5-10 minutes) using
memory-efficient streaming, source looping, and audio muxing.
"""
| import os | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| import tempfile | |
| import librosa | |
| import mimetypes | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Optional, Tuple, Union | |
class MuseTalkInference:
    """MuseTalk inference engine for audio-driven video generation.

    Built for long-form output (5-10 minute audio): the short source
    clip/image is looped to cover the audio duration, generated frames
    are streamed straight to a ``cv2.VideoWriter`` (constant memory),
    and the original audio is muxed back in with ffmpeg at the end.
    """

    def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        # Device used for model inference; defaults to CUDA when available.
        self.device = device
        self.model = None          # lip-sync generation model (loaded lazily)
        self.whisper_model = None  # audio-feature model (loaded lazily)
        self.face_detector = None  # face detection model (loaded lazily)
        self.pose_model = None     # pose estimation model (loaded lazily)
        self.initialized = False   # set True once load_models() succeeds

    def load_models(self, progress_callback=None):
        """Load MuseTalk models from HuggingFace Hub.

        Args:
            progress_callback: optional callable ``(percent, message)``
                used to report progress to a UI.

        Raises:
            Exception: re-raised after logging if model loading fails.
        """
        try:
            if progress_callback:
                progress_callback(0, "Loading MuseTalk models...")
            # Placeholder: initialize the actual PyTorch models here.
            self.initialized = True
            if progress_callback:
                progress_callback(5, "Models loaded successfully")
        except Exception as e:
            print(f"Error loading models: {e}")
            raise

    def extract_audio_features(self, audio_path: str, progress_callback=None) -> np.ndarray:
        """Extract mel-spectrogram audio features for lip-sync conditioning.

        Args:
            audio_path: path to the input audio file.
            progress_callback: optional callable ``(percent, message)``.

        Returns:
            Mel-spectrogram array of shape ``(n_mels, n_frames)``.

        Raises:
            Exception: re-raised after logging on any decoding failure.
        """
        try:
            if progress_callback:
                progress_callback(10, "Extracting audio features...")
            audio, sr = self._load_audio(audio_path)
            # Normalize to [-1, 1]; epsilon guards against all-silent audio.
            audio = audio.astype(np.float32)
            audio = audio / (np.max(np.abs(audio)) + 1e-8)
            # Whisper-style mel parameters for 16 kHz audio.
            n_mels = 80
            n_fft = 400
            hop_length = 160
            mel_features = self._compute_mel_spectrogram(audio, sr, n_mels, n_fft, hop_length)
            if progress_callback:
                progress_callback(15, "Audio features extracted")
            return mel_features
        except Exception as e:
            print(f"Error extracting audio features: {e}")
            raise

    @staticmethod
    def _load_audio(audio_path: str) -> Tuple[np.ndarray, int]:
        """Load audio as a mono float array, preferring 16 kHz.

        Falls back through librosa -> scipy -> soundfile backends.

        Returns:
            ``(audio, sample_rate)`` tuple.
        """
        try:
            return librosa.load(audio_path, sr=16000)
        except Exception:
            pass  # librosa missing or codec unsupported; try scipy below.
        try:
            import scipy.io.wavfile as wavfile
            sr, audio = wavfile.read(audio_path)
            audio = audio.astype(np.float32)
            if audio.ndim > 1:
                # Downmix multi-channel to mono.
                audio = audio.mean(axis=1)
            if sr != 16000:
                # BUGFIX: the previous code scaled the *amplitude* by the
                # rate ratio instead of resampling. Resample by linear
                # interpolation onto a 16 kHz time grid instead.
                target_len = max(1, int(len(audio) * 16000 / sr))
                audio = np.interp(
                    np.linspace(0.0, len(audio) - 1, target_len),
                    np.arange(len(audio)),
                    audio,
                )
                sr = 16000
            return audio, sr
        except Exception:
            # Last resort: soundfile keeps the file's native sample rate.
            import soundfile as sf
            audio, sr = sf.read(audio_path)
            return audio, sr

    def extract_source_frames(self, file_path: str, fps: int = 25, progress_callback=None) -> Tuple[list, int, int]:
        """Extract frames from a short video or load a single image to memory.

        Args:
            file_path: path to the source image or short video.
            fps: target frame rate (kept for interface compatibility).
            progress_callback: optional callable ``(percent, message)``.

        Returns:
            ``(frames, width, height)`` where frames is a list of BGR arrays.

        Raises:
            ValueError: if the file yields no frames.
        """
        try:
            if progress_callback:
                progress_callback(20, "Reading source image/video...")
            mime_type, _ = mimetypes.guess_type(file_path)
            frames = []
            if mime_type and mime_type.startswith('image'):
                # Single-image input: one frame, looped later by generate().
                frame = cv2.imread(file_path)
                if frame is None:
                    raise ValueError("Failed to read image")
                frames.append(frame)
            else:
                # Short-video input: read every frame into memory.
                cap = cv2.VideoCapture(file_path)
                try:
                    while True:
                        ret, frame = cap.read()
                        if not ret:
                            break
                        frames.append(frame)
                finally:
                    cap.release()
            if not frames:
                raise ValueError("No frames extracted from source file")
            height, width = frames[0].shape[:2]
            return frames, width, height
        except Exception as e:
            print(f"Error extracting video frames: {e}")
            raise

    def detect_faces(self, frames: list, progress_callback=None) -> list:
        """Detect one face box per frame of the short source clip.

        Running detection only on the source clip (not every output frame)
        keeps compute independent of audio length.

        Args:
            frames: list of BGR frames from the source media.
            progress_callback: optional callable ``(percent, message)``.

        Returns:
            One ``[x, y, w, h]`` array per input frame. Frames with no
            detection reuse the previous box, or a centered fallback box.
        """
        try:
            if progress_callback:
                progress_callback(25, "Detecting face in source media...")
            face_detections = []
            cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
            face_cascade = cv2.CascadeClassifier(cascade_path)
            for frame in frames:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                faces = face_cascade.detectMultiScale(gray, 1.1, 4)
                if len(faces) > 0:
                    # Take the LARGEST face by area (width * height).
                    face_detections.append(max(faces, key=lambda f: f[2] * f[3]))
                elif face_detections:
                    # No detection: reuse the last known box for stability.
                    face_detections.append(face_detections[-1])
                else:
                    # No detection ever: assume a centered half-size box.
                    h, w = frame.shape[:2]
                    face_detections.append(np.array([w // 4, h // 4, w // 2, h // 2]))
            return face_detections
        except Exception as e:
            print(f"Error detecting faces: {e}")
            raise

    def generate(self, audio_path: str, video_path: str, output_path: str,
                 fps: int = 25, progress_callback=None) -> str:
        """Generate a lip-synced video, streaming frames to disk.

        Loops short source inputs to match 5-10 minute audio tracks.

        Args:
            audio_path: driving audio file.
            video_path: source image or short video.
            output_path: destination video path.
            fps: output frame rate.
            progress_callback: optional callable ``(percent, message)``.

        Returns:
            ``output_path`` on success.

        Raises:
            ValueError: if the audio is too short to produce any frame.
        """
        try:
            if not self.initialized:
                self.load_models(progress_callback)
            # 1. Extract audio features.
            audio_features = self.extract_audio_features(audio_path, progress_callback)
            # 2. Determine total output frames from the audio length.
            audio_data, sr = librosa.load(audio_path, sr=16000)
            audio_duration = len(audio_data) / sr
            total_target_frames = int(audio_duration * fps)
            if total_target_frames == 0:
                raise ValueError("Audio file is too short or invalid.")
            # 3. Extract source clip/image (only the short clip in memory).
            source_frames, width, height = self.extract_source_frames(video_path, fps, progress_callback)
            # 4. Detect faces on the short source clip (pre-cached).
            source_faces = self.detect_faces(source_frames, progress_callback)
            # 5. Stream-process: write directly to file to avoid OOM.
            # BUGFIX: str.replace('.mp4', ...) returned output_path
            # unchanged for non-.mp4 paths, making ffmpeg read and write
            # the same file. Build the temp name from splitext instead.
            base, ext = os.path.splitext(output_path)
            temp_silent_video = f"{base}_silent{ext or '.mp4'}"
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(temp_silent_video, fourcc, fps, (width, height))
            if progress_callback:
                progress_callback(30, f"Generating {total_target_frames} frames (Streaming)...")
            try:
                for i in range(total_target_frames):
                    # LOOPING LOGIC: cycle the short video or image.
                    src_idx = i % len(source_frames)
                    frame = source_frames[src_idx].copy()
                    face = source_faces[src_idx]
                    # --- START AI LIP-SYNC INFERENCE ---
                    # NOTE: Put your actual AI model generation code here.
                    # Right now, this just draws a box around the face.
                    # Example: frame = self.model.infer(frame, face, audio_features[:, i])
                    x, y, w, h = int(face[0]), int(face[1]), int(face[2]), int(face[3])
                    cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
                    # --- END AI LIP-SYNC INFERENCE ---
                    # Write directly to disk (constant memory for long runs).
                    out.write(frame)
                    # Report progress roughly every 5% of the run.
                    if (i + 1) % max(1, total_target_frames // 20) == 0 and progress_callback:
                        progress_pct = 30 + int((i / total_target_frames) * 60)
                        progress_callback(progress_pct, f"Generated frames: {i + 1}/{total_target_frames}")
            finally:
                # Always release the writer, even if a frame step raised.
                out.release()
            # 6. Mux audio: combine the silent video with the original audio.
            if progress_callback:
                progress_callback(95, "Merging final audio and video...")
            try:
                cmd = [
                    "ffmpeg", "-y",
                    "-i", temp_silent_video,  # the generated silent video
                    "-i", audio_path,         # the original audio
                    "-c:v", "libx264",        # re-encode for broad web compatibility
                    "-c:a", "aac",            # re-encode audio to AAC
                    "-map", "0:v:0",
                    "-map", "1:a:0",
                    "-shortest",              # cut at the shortest stream
                    output_path
                ]
                subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
                # Cleanup temp file.
                if os.path.exists(temp_silent_video):
                    os.remove(temp_silent_video)
            except subprocess.CalledProcessError as e:
                # Decode stderr (captured as bytes) for a readable message.
                stderr_text = e.stderr.decode(errors="replace") if e.stderr else str(e)
                print(f"FFMPEG Error: {stderr_text}")
                # Fallback to the silent video if ffmpeg fails.
                os.replace(temp_silent_video, output_path)
            if progress_callback:
                progress_callback(100, "Generation Complete!")
            return output_path
        except Exception as e:
            print(f"Error during generation: {e}")
            raise

    def _compute_mel_spectrogram(self, audio: np.ndarray, sr: int, n_mels: int,
                                 n_fft: int, hop_length: int) -> np.ndarray:
        """Compute a dB-scaled mel-spectrogram from audio.

        Args:
            audio: mono float waveform.
            sr: sample rate of ``audio``.
            n_mels: number of mel bands.
            n_fft: FFT window size.
            hop_length: hop between successive frames.

        Returns:
            Array of shape ``(n_mels, n_frames)``.
        """
        try:
            import librosa
            mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft,
                                                      hop_length=hop_length, n_mels=n_mels)
            mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
            return mel_spec
        except Exception:
            # BUGFIX: the old fallback returned np.random.randn, silently
            # injecting nondeterministic garbage. Return deterministic
            # zeros and warn so the failure is visible and reproducible.
            print("Warning: librosa unavailable; returning zero mel features.")
            n_frames = len(audio) // hop_length
            return np.zeros((n_mels, n_frames), dtype=np.float32)