# MuseTalk / inference.py
# Source: Hugging Face Hub (commit b9a578a, uploaded by Ultronprime)
"""MuseTalk Inference Module
Refactored for Long-Form Generation (5-10 mins)
using Memory-Efficient Streaming, Looping, and Audio Muxing.
"""
import os
import cv2
import torch
import numpy as np
import tempfile
import librosa
import mimetypes
import subprocess
from pathlib import Path
from typing import Optional, Tuple, Union
class MuseTalkInference:
"""MuseTalk inference engine for audio-driven video generation."""
def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
self.device = device
self.model = None
self.whisper_model = None
self.face_detector = None
self.pose_model = None
self.initialized = False
def load_models(self, progress_callback=None):
"""Load MuseTalk models from HuggingFace Hub."""
try:
if progress_callback:
progress_callback(0, "Loading MuseTalk models...")
# Placeholder: Initialize your actual PyTorch models here
self.initialized = True
if progress_callback:
progress_callback(5, "Models loaded successfully")
except Exception as e:
print(f"Error loading models: {e}")
raise
def extract_audio_features(self, audio_path: str, progress_callback=None) -> np.ndarray:
"""Extract audio features using Whisper/Mel-Spectrogram."""
try:
if progress_callback:
progress_callback(10, "Extracting audio features...")
try:
audio, sr = librosa.load(audio_path, sr=16000)
except:
try:
import scipy.io.wavfile as wavfile
sr, audio = wavfile.read(audio_path)
if sr != 16000:
ratio = 16000 / sr
audio = (audio * ratio).astype(np.int16)
except:
import soundfile as sf
audio, sr = sf.read(audio_path)
audio = audio.astype(np.float32)
audio = audio / (np.max(np.abs(audio)) + 1e-8)
n_mels = 80
n_fft = 400
hop_length = 160
mel_features = self._compute_mel_spectrogram(audio, sr, n_mels, n_fft, hop_length)
if progress_callback:
progress_callback(15, "Audio features extracted")
return mel_features
except Exception as e:
print(f"Error extracting audio features: {e}")
raise
def extract_source_frames(self, file_path: str, fps: int = 25, progress_callback=None) -> Tuple[list, int, int]:
"""Extracts frames from a short video or loads a single image to memory."""
try:
if progress_callback:
progress_callback(20, "Reading source image/video...")
mime_type, _ = mimetypes.guess_type(file_path)
frames = []
# Handle Single Image Input
if mime_type and mime_type.startswith('image'):
frame = cv2.imread(file_path)
if frame is None:
raise ValueError("Failed to read image")
frames.append(frame)
# Handle Short Video Input
else:
cap = cv2.VideoCapture(file_path)
while True:
ret, frame = cap.read()
if not ret:
break
frames.append(frame)
cap.release()
if not frames:
raise ValueError("No frames extracted from source file")
height, width = frames[0].shape[:2]
return frames, width, height
except Exception as e:
print(f"Error extracting video frames: {e}")
raise
def detect_faces(self, frames: list, progress_callback=None) -> list:
"""Detect faces ONLY on the short source clip to save compute."""
try:
if progress_callback:
progress_callback(25, "Detecting face in source media...")
face_detections = []
cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
face_cascade = cv2.CascadeClassifier(cascade_path)
for i, frame in enumerate(frames):
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
faces = face_cascade.detectMultiScale(gray, 1.1, 4)
if len(faces) > 0:
# Take the LARGEST face by area (width * height)
face = max(faces, key=lambda f: f[2] * f[3])
face_detections.append(face)
else:
if face_detections:
face_detections.append(face_detections[-1])
else:
h, w = frame.shape[:2]
face_detections.append(np.array([w//4, h//4, w//2, h//2]))
return face_detections
except Exception as e:
print(f"Error detecting faces: {e}")
raise
def generate(self, audio_path: str, video_path: str, output_path: str,
fps: int = 25, progress_callback=None) -> str:
"""
Memory-efficient generator for long videos.
Loops short inputs to match 5-10 minute audio.
"""
try:
if not self.initialized:
self.load_models(progress_callback)
# 1. Extract audio features
audio_features = self.extract_audio_features(audio_path, progress_callback)
# 2. Determine Total Output Frames based on Audio Length
audio_data, sr = librosa.load(audio_path, sr=16000)
audio_duration = len(audio_data) / sr
total_target_frames = int(audio_duration * fps)
if total_target_frames == 0:
raise ValueError("Audio file is too short or invalid.")
# 3. Extract Source Clip/Image (Only loads short clip into memory)
source_frames, width, height = self.extract_source_frames(video_path, fps, progress_callback)
# 4. Detect faces on the short source clip (Pre-cached)
source_faces = self.detect_faces(source_frames, progress_callback)
# 5. Stream Process (Write directly to file to avoid OOM crash)
temp_silent_video = output_path.replace('.mp4', '_silent.mp4')
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(temp_silent_video, fourcc, fps, (width, height))
if progress_callback:
progress_callback(30, f"Generating {total_target_frames} frames (Streaming)...")
for i in range(total_target_frames):
# LOOPING LOGIC: Loop the short video or image continuously
src_idx = i % len(source_frames)
frame = source_frames[src_idx].copy()
face = source_faces[src_idx]
# --- START AI LIP-SYNC INFERENCE ---
# NOTE: Put your actual AI model generation code here.
# Right now, this just draws a box around the face.
# Example: frame = self.model.infer(frame, face, audio_features[:, i])
x, y, w, h = int(face[0]), int(face[1]), int(face[2]), int(face[3])
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
# --- END AI LIP-SYNC INFERENCE ---
# Write directly to disk (Saves 30GB+ of RAM for 10 min videos)
out.write(frame)
# Report progress periodically
if (i + 1) % max(1, total_target_frames // 20) == 0 and progress_callback:
progress_pct = 30 + int((i / total_target_frames) * 60)
progress_callback(progress_pct, f"Generated frames: {i + 1}/{total_target_frames}")
out.release()
# 6. MUX AUDIO (Combine the generated silent video with original audio)
if progress_callback:
progress_callback(95, "Merging final audio and video...")
try:
cmd = [
"ffmpeg", "-y",
"-i", temp_silent_video, # The generated silent video
"-i", audio_path, # The original audio
"-c:v", "libx264", # Re-encode video for broad web compatibility
"-c:a", "aac", # Re-encode audio to AAC
"-map", "0:v:0",
"-map", "1:a:0",
"-shortest", # Cut at the shortest stream
output_path
]
subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# Cleanup temp file
if os.path.exists(temp_silent_video):
os.remove(temp_silent_video)
except subprocess.CalledProcessError as e:
print(f"FFMPEG Error: {e.stderr}")
# Fallback to silent video if FFMPEG fails
os.rename(temp_silent_video, output_path)
if progress_callback:
progress_callback(100, "Generation Complete!")
return output_path
except Exception as e:
print(f"Error during generation: {e}")
raise
def _compute_mel_spectrogram(self, audio: np.ndarray, sr: int, n_mels: int,
n_fft: int, hop_length: int) -> np.ndarray:
"""Compute mel-spectrogram from audio."""
try:
import librosa
mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft,
hop_length=hop_length, n_mels=n_mels)
mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
return mel_spec
except:
n_frames = len(audio) // hop_length
return np.random.randn(n_mels, n_frames)