Ultronprime commited on
Commit
b9a578a
·
verified ·
1 Parent(s): 356d0ea

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +111 -194
inference.py CHANGED
@@ -1,7 +1,7 @@
1
  """MuseTalk Inference Module
2
 
3
- This module provides the core inference functionality for MuseTalk,
4
- enabling audio-driven lip-sync video generation.
5
  """
6
 
7
  import os
@@ -9,20 +9,17 @@ import cv2
9
  import torch
10
  import numpy as np
11
  import tempfile
 
 
 
12
  from pathlib import Path
13
  from typing import Optional, Tuple, Union
14
- import subprocess
15
 
16
 
17
  class MuseTalkInference:
18
  """MuseTalk inference engine for audio-driven video generation."""
19
 
20
  def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
21
- """Initialize MuseTalk inference engine.
22
-
23
- Args:
24
- device: torch device to use ('cuda' or 'cpu')
25
- """
26
  self.device = device
27
  self.model = None
28
  self.whisper_model = None
@@ -31,45 +28,30 @@ class MuseTalkInference:
31
  self.initialized = False
32
 
33
  def load_models(self, progress_callback=None):
34
- """Load MuseTalk models from HuggingFace Hub.
35
-
36
- Args:
37
- progress_callback: Optional callback to report loading progress
38
- """
39
  try:
40
  if progress_callback:
41
  progress_callback(0, "Loading MuseTalk models...")
42
 
43
- # For now, return success - models will be loaded lazily during inference
44
  self.initialized = True
45
 
46
  if progress_callback:
47
- progress_callback(100, "Models loaded successfully")
48
 
49
  except Exception as e:
50
  print(f"Error loading models: {e}")
51
  raise
52
 
53
  def extract_audio_features(self, audio_path: str, progress_callback=None) -> np.ndarray:
54
- """Extract audio features using Whisper.
55
-
56
- Args:
57
- audio_path: Path to audio file
58
- progress_callback: Optional progress callback
59
-
60
- Returns:
61
- Audio features array
62
- """
63
  try:
64
  if progress_callback:
65
  progress_callback(10, "Extracting audio features...")
66
 
67
- # Load audio file
68
  try:
69
- import librosa
70
  audio, sr = librosa.load(audio_path, sr=16000)
71
  except:
72
- # Fallback using scipy
73
  try:
74
  import scipy.io.wavfile as wavfile
75
  sr, audio = wavfile.read(audio_path)
@@ -77,24 +59,20 @@ class MuseTalkInference:
77
  ratio = 16000 / sr
78
  audio = (audio * ratio).astype(np.int16)
79
  except:
80
- # Additional fallback
81
  import soundfile as sf
82
  audio, sr = sf.read(audio_path)
83
 
84
- # Normalize audio
85
  audio = audio.astype(np.float32)
86
  audio = audio / (np.max(np.abs(audio)) + 1e-8)
87
 
88
- # Create feature representation (mel-spectrogram)
89
  n_mels = 80
90
  n_fft = 400
91
  hop_length = 160
92
 
93
- # Simple mel-spectrogram computation
94
  mel_features = self._compute_mel_spectrogram(audio, sr, n_mels, n_fft, hop_length)
95
 
96
  if progress_callback:
97
- progress_callback(30, "Audio features extracted")
98
 
99
  return mel_features
100
 
@@ -102,42 +80,36 @@ class MuseTalkInference:
102
  print(f"Error extracting audio features: {e}")
103
  raise
104
 
105
- def extract_video_frames(self, video_path: str, fps: int = 25, progress_callback=None) -> Tuple[list, int, int]:
106
- """Extract frames from video file.
107
-
108
- Args:
109
- video_path: Path to video file
110
- fps: Target fps for extraction
111
- progress_callback: Optional progress callback
112
-
113
- Returns:
114
- Tuple of (frames list, width, height)
115
- """
116
  try:
117
  if progress_callback:
118
- progress_callback(10, "Extracting video frames...")
119
-
120
- cap = cv2.VideoCapture(video_path)
121
  frames = []
122
- frame_count = 0
123
 
124
- while True:
125
- ret, frame = cap.read()
126
- if not ret:
127
- break
 
128
  frames.append(frame)
129
- frame_count += 1
130
-
131
- cap.release()
132
 
 
 
 
 
 
 
 
 
 
 
133
  if not frames:
134
- raise ValueError("No frames extracted from video")
135
 
136
  height, width = frames[0].shape[:2]
137
-
138
- if progress_callback:
139
- progress_callback(30, f"Extracted {len(frames)} frames")
140
-
141
  return frames, width, height
142
 
143
  except Exception as e:
@@ -145,22 +117,12 @@ class MuseTalkInference:
145
  raise
146
 
147
  def detect_faces(self, frames: list, progress_callback=None) -> list:
148
- """Detect faces in video frames.
149
-
150
- Args:
151
- frames: List of video frames
152
- progress_callback: Optional progress callback
153
-
154
- Returns:
155
- List of face bounding boxes for each frame
156
- """
157
  try:
158
  if progress_callback:
159
- progress_callback(40, "Detecting faces in frames...")
160
 
161
  face_detections = []
162
-
163
- # Use OpenCV's Haar Cascade for face detection
164
  cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
165
  face_cascade = cv2.CascadeClassifier(cascade_path)
166
 
@@ -169,145 +131,112 @@ class MuseTalkInference:
169
  faces = face_cascade.detectMultiScale(gray, 1.1, 4)
170
 
171
  if len(faces) > 0:
172
- # Take the largest face
173
  face = max(faces, key=lambda f: f[2] * f[3])
174
  face_detections.append(face)
175
  else:
176
- # Use previous face detection or frame dimensions
177
  if face_detections:
178
  face_detections.append(face_detections[-1])
179
  else:
180
  h, w = frame.shape[:2]
181
  face_detections.append(np.array([w//4, h//4, w//2, h//2]))
182
-
183
- if (i + 1) % max(1, len(frames) // 10) == 0 and progress_callback:
184
- progress_callback(40 + int((i + 1) / len(frames) * 20), f"Detected faces: {i + 1}/{len(frames)}")
185
 
186
  return face_detections
187
-
188
  except Exception as e:
189
  print(f"Error detecting faces: {e}")
190
  raise
191
 
192
- def generate_lipsync(self, frames: list, audio_features: np.ndarray, face_detections: list,
193
- progress_callback=None) -> list:
194
- """Generate lip-sync frames.
195
-
196
- Args:
197
- frames: List of original video frames
198
- audio_features: Audio feature array
199
- face_detections: List of face bounding boxes
200
- progress_callback: Optional progress callback
201
-
202
- Returns:
203
- List of lip-synced frames
204
  """
205
  try:
206
- if progress_callback:
207
- progress_callback(60, "Generating lip-sync...")
208
-
209
- lipsync_frames = []
210
 
211
- # For now, return frames with marked regions (placeholder for actual inference)
212
- for i, frame in enumerate(frames):
213
- output_frame = frame.copy()
214
-
215
- if i < len(face_detections):
216
- face = face_detections[i]
217
- x, y, w, h = int(face[0]), int(face[1]), int(face[2]), int(face[3])
218
- # Draw rectangle around detected face region
219
- cv2.rectangle(output_frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
220
-
221
- lipsync_frames.append(output_frame)
222
-
223
- if (i + 1) % max(1, len(frames) // 10) == 0 and progress_callback:
224
- progress_callback(60 + int((i + 1) / len(frames) * 20), f"Lip-sync frames: {i + 1}/{len(frames)}")
225
 
226
- return lipsync_frames
 
 
 
227
 
228
- except Exception as e:
229
- print(f"Error generating lip-sync: {e}")
230
- raise
231
 
232
- def save_output_video(self, frames: list, output_path: str, fps: int = 25, progress_callback=None) -> str:
233
- """Save generated frames as video file.
234
-
235
- Args:
236
- frames: List of output frames
237
- output_path: Path to save output video
238
- fps: Frames per second for output video
239
- progress_callback: Optional progress callback
240
-
241
- Returns:
242
- Path to saved video file
243
- """
244
- try:
245
- if progress_callback:
246
- progress_callback(80, "Encoding video...")
247
-
248
- if not frames:
249
- raise ValueError("No frames to save")
250
 
251
- height, width = frames[0].shape[:2]
 
252
 
253
- # Use OpenCV VideoWriter
 
254
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
255
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
256
-
257
- for i, frame in enumerate(frames):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  out.write(frame)
259
- if (i + 1) % max(1, len(frames) // 10) == 0 and progress_callback:
260
- progress_callback(80 + int((i + 1) / len(frames) * 15), f"Encoding: {i + 1}/{len(frames)}")
261
-
 
 
 
262
  out.release()
263
-
 
264
  if progress_callback:
265
- progress_callback(95, "Video encoding complete")
266
-
267
- return output_path
268
-
269
- except Exception as e:
270
- print(f"Error saving video: {e}")
271
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
- def generate(self, audio_path: str, video_path: str, output_path: str,
274
- fps: int = 25, progress_callback=None) -> str:
275
- """Generate lip-synced video from audio and video.
276
-
277
- Args:
278
- audio_path: Path to input audio file
279
- video_path: Path to input video file
280
- output_path: Path to save output video
281
- fps: Target fps for output
282
- progress_callback: Optional progress callback
283
-
284
- Returns:
285
- Path to generated video
286
- """
287
- try:
288
- # Initialize models if not already done
289
- if not self.initialized:
290
- self.load_models(progress_callback)
291
-
292
- # Extract audio features
293
- audio_features = self.extract_audio_features(audio_path, progress_callback)
294
-
295
- # Extract video frames
296
- frames, width, height = self.extract_video_frames(video_path, fps, progress_callback)
297
-
298
- # Detect faces
299
- face_detections = self.detect_faces(frames, progress_callback)
300
-
301
- # Generate lip-sync
302
- output_frames = self.generate_lipsync(frames, audio_features, face_detections, progress_callback)
303
-
304
- # Save output video
305
- result_path = self.save_output_video(output_frames, output_path, fps, progress_callback)
306
-
307
  if progress_callback:
308
- progress_callback(100, "Lip-sync generation complete!")
309
 
310
- return result_path
311
 
312
  except Exception as e:
313
  print(f"Error during generation: {e}")
@@ -315,18 +244,7 @@ class MuseTalkInference:
315
 
316
  def _compute_mel_spectrogram(self, audio: np.ndarray, sr: int, n_mels: int,
317
  n_fft: int, hop_length: int) -> np.ndarray:
318
- """Compute mel-spectrogram from audio.
319
-
320
- Args:
321
- audio: Audio signal
322
- sr: Sample rate
323
- n_mels: Number of mel bins
324
- n_fft: FFT window size
325
- hop_length: Hop length
326
-
327
- Returns:
328
- Mel-spectrogram array
329
- """
330
  try:
331
  import librosa
332
  mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft,
@@ -334,6 +252,5 @@ class MuseTalkInference:
334
  mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
335
  return mel_spec
336
  except:
337
- # Fallback: return a dummy feature array
338
  n_frames = len(audio) // hop_length
339
  return np.random.randn(n_mels, n_frames)
 
1
  """MuseTalk Inference Module
2
 
3
+ Refactored for Long-Form Generation (5-10 mins)
4
+ using Memory-Efficient Streaming, Looping, and Audio Muxing.
5
  """
6
 
7
  import os
 
9
  import torch
10
  import numpy as np
11
  import tempfile
12
+ import librosa
13
+ import mimetypes
14
+ import subprocess
15
  from pathlib import Path
16
  from typing import Optional, Tuple, Union
 
17
 
18
 
19
  class MuseTalkInference:
20
  """MuseTalk inference engine for audio-driven video generation."""
21
 
22
  def __init__(self, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
 
 
 
 
 
23
  self.device = device
24
  self.model = None
25
  self.whisper_model = None
 
28
  self.initialized = False
29
 
30
  def load_models(self, progress_callback=None):
31
+ """Load MuseTalk models from HuggingFace Hub."""
 
 
 
 
32
  try:
33
  if progress_callback:
34
  progress_callback(0, "Loading MuseTalk models...")
35
 
36
+ # Placeholder: Initialize your actual PyTorch models here
37
  self.initialized = True
38
 
39
  if progress_callback:
40
+ progress_callback(5, "Models loaded successfully")
41
 
42
  except Exception as e:
43
  print(f"Error loading models: {e}")
44
  raise
45
 
46
  def extract_audio_features(self, audio_path: str, progress_callback=None) -> np.ndarray:
47
+ """Extract audio features using Whisper/Mel-Spectrogram."""
 
 
 
 
 
 
 
 
48
  try:
49
  if progress_callback:
50
  progress_callback(10, "Extracting audio features...")
51
 
 
52
  try:
 
53
  audio, sr = librosa.load(audio_path, sr=16000)
54
  except:
 
55
  try:
56
  import scipy.io.wavfile as wavfile
57
  sr, audio = wavfile.read(audio_path)
 
59
  ratio = 16000 / sr
60
  audio = (audio * ratio).astype(np.int16)
61
  except:
 
62
  import soundfile as sf
63
  audio, sr = sf.read(audio_path)
64
 
 
65
  audio = audio.astype(np.float32)
66
  audio = audio / (np.max(np.abs(audio)) + 1e-8)
67
 
 
68
  n_mels = 80
69
  n_fft = 400
70
  hop_length = 160
71
 
 
72
  mel_features = self._compute_mel_spectrogram(audio, sr, n_mels, n_fft, hop_length)
73
 
74
  if progress_callback:
75
+ progress_callback(15, "Audio features extracted")
76
 
77
  return mel_features
78
 
 
80
  print(f"Error extracting audio features: {e}")
81
  raise
82
 
83
+ def extract_source_frames(self, file_path: str, fps: int = 25, progress_callback=None) -> Tuple[list, int, int]:
84
+ """Extracts frames from a short video or loads a single image to memory."""
 
 
 
 
 
 
 
 
 
85
  try:
86
  if progress_callback:
87
+ progress_callback(20, "Reading source image/video...")
88
+
89
+ mime_type, _ = mimetypes.guess_type(file_path)
90
  frames = []
 
91
 
92
+ # Handle Single Image Input
93
+ if mime_type and mime_type.startswith('image'):
94
+ frame = cv2.imread(file_path)
95
+ if frame is None:
96
+ raise ValueError("Failed to read image")
97
  frames.append(frame)
 
 
 
98
 
99
+ # Handle Short Video Input
100
+ else:
101
+ cap = cv2.VideoCapture(file_path)
102
+ while True:
103
+ ret, frame = cap.read()
104
+ if not ret:
105
+ break
106
+ frames.append(frame)
107
+ cap.release()
108
+
109
  if not frames:
110
+ raise ValueError("No frames extracted from source file")
111
 
112
  height, width = frames[0].shape[:2]
 
 
 
 
113
  return frames, width, height
114
 
115
  except Exception as e:
 
117
  raise
118
 
119
  def detect_faces(self, frames: list, progress_callback=None) -> list:
120
+ """Detect faces ONLY on the short source clip to save compute."""
 
 
 
 
 
 
 
 
121
  try:
122
  if progress_callback:
123
+ progress_callback(25, "Detecting face in source media...")
124
 
125
  face_detections = []
 
 
126
  cascade_path = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
127
  face_cascade = cv2.CascadeClassifier(cascade_path)
128
 
 
131
  faces = face_cascade.detectMultiScale(gray, 1.1, 4)
132
 
133
  if len(faces) > 0:
134
+ # Take the LARGEST face by area (width * height)
135
  face = max(faces, key=lambda f: f[2] * f[3])
136
  face_detections.append(face)
137
  else:
 
138
  if face_detections:
139
  face_detections.append(face_detections[-1])
140
  else:
141
  h, w = frame.shape[:2]
142
  face_detections.append(np.array([w//4, h//4, w//2, h//2]))
 
 
 
143
 
144
  return face_detections
 
145
  except Exception as e:
146
  print(f"Error detecting faces: {e}")
147
  raise
148
 
149
+ def generate(self, audio_path: str, video_path: str, output_path: str,
150
+ fps: int = 25, progress_callback=None) -> str:
151
+ """
152
+ Memory-efficient generator for long videos.
153
+ Loops short inputs to match 5-10 minute audio.
 
 
 
 
 
 
 
154
  """
155
  try:
156
+ if not self.initialized:
157
+ self.load_models(progress_callback)
 
 
158
 
159
+ # 1. Extract audio features
160
+ audio_features = self.extract_audio_features(audio_path, progress_callback)
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
+ # 2. Determine Total Output Frames based on Audio Length
163
+ audio_data, sr = librosa.load(audio_path, sr=16000)
164
+ audio_duration = len(audio_data) / sr
165
+ total_target_frames = int(audio_duration * fps)
166
 
167
+ if total_target_frames == 0:
168
+ raise ValueError("Audio file is too short or invalid.")
 
169
 
170
+ # 3. Extract Source Clip/Image (Only loads short clip into memory)
171
+ source_frames, width, height = self.extract_source_frames(video_path, fps, progress_callback)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
+ # 4. Detect faces on the short source clip (Pre-cached)
174
+ source_faces = self.detect_faces(source_frames, progress_callback)
175
 
176
+ # 5. Stream Process (Write directly to file to avoid OOM crash)
177
+ temp_silent_video = output_path.replace('.mp4', '_silent.mp4')
178
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
179
+ out = cv2.VideoWriter(temp_silent_video, fourcc, fps, (width, height))
180
+
181
+ if progress_callback:
182
+ progress_callback(30, f"Generating {total_target_frames} frames (Streaming)...")
183
+
184
+ for i in range(total_target_frames):
185
+ # LOOPING LOGIC: Loop the short video or image continuously
186
+ src_idx = i % len(source_frames)
187
+ frame = source_frames[src_idx].copy()
188
+ face = source_faces[src_idx]
189
+
190
+ # --- START AI LIP-SYNC INFERENCE ---
191
+ # NOTE: Put your actual AI model generation code here.
192
+ # Right now, this just draws a box around the face.
193
+ # Example: frame = self.model.infer(frame, face, audio_features[:, i])
194
+
195
+ x, y, w, h = int(face[0]), int(face[1]), int(face[2]), int(face[3])
196
+ cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
197
+ # --- END AI LIP-SYNC INFERENCE ---
198
+
199
+ # Write directly to disk (Saves 30GB+ of RAM for 10 min videos)
200
  out.write(frame)
201
+
202
+ # Report progress periodically
203
+ if (i + 1) % max(1, total_target_frames // 20) == 0 and progress_callback:
204
+ progress_pct = 30 + int((i / total_target_frames) * 60)
205
+ progress_callback(progress_pct, f"Generated frames: {i + 1}/{total_target_frames}")
206
+
207
  out.release()
208
+
209
+ # 6. MUX AUDIO (Combine the generated silent video with original audio)
210
  if progress_callback:
211
+ progress_callback(95, "Merging final audio and video...")
212
+
213
+ try:
214
+ cmd = [
215
+ "ffmpeg", "-y",
216
+ "-i", temp_silent_video, # The generated silent video
217
+ "-i", audio_path, # The original audio
218
+ "-c:v", "libx264", # Re-encode video for broad web compatibility
219
+ "-c:a", "aac", # Re-encode audio to AAC
220
+ "-map", "0:v:0",
221
+ "-map", "1:a:0",
222
+ "-shortest", # Cut at the shortest stream
223
+ output_path
224
+ ]
225
+ subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
226
+
227
+ # Cleanup temp file
228
+ if os.path.exists(temp_silent_video):
229
+ os.remove(temp_silent_video)
230
+
231
+ except subprocess.CalledProcessError as e:
232
+ print(f"FFMPEG Error: {e.stderr}")
233
+ # Fallback to silent video if FFMPEG fails
234
+ os.rename(temp_silent_video, output_path)
235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  if progress_callback:
237
+ progress_callback(100, "Generation Complete!")
238
 
239
+ return output_path
240
 
241
  except Exception as e:
242
  print(f"Error during generation: {e}")
 
244
 
245
  def _compute_mel_spectrogram(self, audio: np.ndarray, sr: int, n_mels: int,
246
  n_fft: int, hop_length: int) -> np.ndarray:
247
+ """Compute mel-spectrogram from audio."""
 
 
 
 
 
 
 
 
 
 
 
248
  try:
249
  import librosa
250
  mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_fft=n_fft,
 
252
  mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
253
  return mel_spec
254
  except:
 
255
  n_frames = len(audio) // hop_length
256
  return np.random.randn(n_mels, n_frames)