garyuzair committed on
Commit e03938b · verified · 1 Parent(s): 3c36747

Update app.py

Files changed (1)
  app.py +354 -416
app.py CHANGED
@@ -1,526 +1,464 @@
  import streamlit as st
- import imageio
- import numpy as np
  from PIL import Image
- from transformers import pipeline, AutoProcessor, MusicgenForConditionalGeneration
- import soundfile as sf
  import torch
  import os
  import tempfile
  import math
- import gc  # Garbage collector
- import traceback  # For detailed error logging

- # Try importing moviepy, with fallback
  try:
      import moviepy.editor as mpy
- except ModuleNotFoundError:
-     st.error("The 'moviepy' library is not installed. Please install it (`pip install moviepy`) 🚨")
-     st.stop()
- except OSError as e:
-     st.error(f"Error initializing moviepy: {e}. This might be due to a missing ffmpeg. 🚨")
-     st.warning("Ensure ffmpeg is installed and accessible in your system's PATH (e.g., `sudo apt-get install ffmpeg` or `conda install ffmpeg`).")
-     st.stop()

- # --- Constants & Defaults ---
- MODEL_MUSICGEN = "facebook/musicgen-small"
- MODEL_MOONDREAM = "vikhyatk/moondream2"  # Using official Moondream2
- DEFAULT_AUDIO_DURATION_S = 10  # Reduced default for faster testing
- DEFAULT_FRAMES_TO_ANALYZE = 3  # Reduced for faster default processing
- DEFAULT_GUIDANCE = 3.0  # MusicGen default
- DEFAULT_TEMPERATURE = 1.0  # MusicGen default
- MAX_FRAMES_TO_SHOW_UI = 3
- MAX_MOONDREAM_DESCRIPTION_TOKENS = 70  # For concise sound descriptions
- MUSICGEN_PROMPT_MAX_TOKENS = 1024  # Safety limit for MusicGen prompt

- IS_CUDA_AVAILABLE = torch.cuda.is_available()
- DEVICE = torch.device("cuda:0" if IS_CUDA_AVAILABLE else "cpu")

- # --- Page Config ---
- st.set_page_config(
-     page_title="AI Video Sound Designer",
-     page_icon="🎬",
-     layout="wide"
- )

- # --- Cached Model Loaders ---
  @st.cache_resource
- def load_moondream_model():
      try:
-         model_kwargs = {"torch_dtype": torch.float16} if IS_CUDA_AVAILABLE else {}
-         pipe = pipeline(
-             "image-to-text",  # Correct pipeline type for Moondream-like models
-             model=MODEL_MOONDREAM,
-             trust_remote_code=True,  # Moondream requires this
-             model_kwargs=model_kwargs,
-             device=DEVICE
-         )
-         st.toast(f"Moondream2 loaded on: {str(DEVICE).upper()} ({model_kwargs.get('torch_dtype', 'float32')})", icon="🤖")
-         return pipe
      except Exception as e:
-         st.error(f"Error loading Moondream2 model ({MODEL_MOONDREAM}): {e}")
          st.error(traceback.format_exc())
-         return None

  @st.cache_resource
- def load_musicgen_model_and_processor():
      try:
-         processor = AutoProcessor.from_pretrained(MODEL_MUSICGEN)
-         model = MusicgenForConditionalGeneration.from_pretrained(MODEL_MUSICGEN)

-         dtype_str = "float32 (full)"
-         if IS_CUDA_AVAILABLE:
-             model = model.half().to(DEVICE)
-             dtype_str = "float16 (half)"
-         else:
-             model = model.to(DEVICE)  # Ensure it's on CPU if not CUDA

-         st.toast(f"MusicGen ({MODEL_MUSICGEN}) loaded on: {str(DEVICE).upper()} with {dtype_str} precision.", icon="🎶")
          return processor, model
      except Exception as e:
-         st.error(f"Error loading MusicGen model ({MODEL_MUSICGEN}): {e}")
          st.error(traceback.format_exc())
          return None, None

- # --- Utilities ---
- def clear_gpu_memory():
-     if IS_CUDA_AVAILABLE:
-         torch.cuda.empty_cache()
-     gc.collect()

- # --- Frame Extraction ---
- def extract_frames(video_path: str, num_frames_to_extract: int) -> list[Image.Image]:
      frames = []
      reader = None
      try:
          reader = imageio.get_reader(video_path, "ffmpeg")

-         total_video_frames = 0
-         # Try to get the frame count using different methods, for robustness
-         try:
-             total_video_frames = reader.count_frames()
-         except Exception:  # count_frames might not be implemented, or might fail
-             pass

-         if not isinstance(total_video_frames, (int, float)) or total_video_frames <= 0:
-             meta_data = reader.get_meta_data()
-             total_video_frames = meta_data.get('nframes')  # Check 'nframes' in the metadata
-             if not isinstance(total_video_frames, (int, float)) or total_video_frames <= 0:
-                 fps = meta_data.get('fps', 25)
-                 duration = meta_data.get('duration')
-                 if duration and fps:
-                     total_video_frames = int(fps * duration)
-                 else:  # Cannot determine the length
-                     st.warning("Could not reliably determine video length. Frame selection may be suboptimal.")
-                     # Fallback: reading ahead to estimate the length can get complex;
-                     # for now, if the above fails, treat the video as problematic.
-                     total_video_frames = 0  # Indicate failure to determine

-         if total_video_frames < 1:
-             st.error("Video appears to have 0 frames, or its length could not be determined accurately.")
-             if reader: reader.close()
-             return []

-         num_to_sample = max(1, min(num_frames_to_extract, int(total_video_frames)))

-         indices = np.linspace(0, int(total_video_frames) - 1, num_to_sample, dtype=int, endpoint=True)
-         indices = sorted(list(set(indices)))  # Ensure unique, sorted indices

          for i in indices:
-             frame_data = reader.get_data(i)
-             frames.append(Image.fromarray(frame_data).convert("RGB"))

      except (imageio.core.fetching.NeedDownloadError, OSError) as e_ffmpeg:
-         st.error(f"FFmpeg not found or failed: {e_ffmpeg} 🚨")
-         st.warning("Please install ffmpeg and ensure it's in your system's PATH (e.g., `sudo apt-get install ffmpeg` or `conda install ffmpeg`).")
          return []
      except Exception as e:
-         st.error(f"Could not extract frames: {e}")
          st.error(traceback.format_exc())
          return []
      finally:
          if reader:
              reader.close()

-     if not frames and num_frames_to_extract > 0:
-         st.warning("No frames were extracted. The video might be empty, corrupted, or in an unsupported format.")
-     return frames

- # --- Sound Prompt Generation ---
- def generate_sound_prompt(frames: list[Image.Image], moondream_pipe) -> str:
-     instruction = (
-         "Describe only the sounds implied by this image. Focus on: ambient noise, distinct sound events, "
-         "sound textures, actions producing sound, and the overall atmosphere. Be concise and evocative."
-     )
      descriptions = []
-     with st.spinner(f"Analyzing {len(frames)} frames with Moondream2..."):
          for i, frame in enumerate(frames):
-             # st.progress((i + 1) / len(frames), text=f"Analyzing frame {i+1}/{len(frames)}")  # More granular progress
              try:
-                 # The Moondream2 pipeline expects 'images' and 'prompt';
-                 # generate_kwargs can control output length for the captioner
-                 out = moondream_pipe(
-                     images=frame,
-                     prompt=instruction,
-                     generate_kwargs={"max_new_tokens": MAX_MOONDREAM_DESCRIPTION_TOKENS}
-                 )
-                 text = out[0].get('generated_text', '').strip()
-                 if text:
-                     descriptions.append(text)
              except Exception as e:
-                 st.warning(f"Could not analyze frame {i+1} with Moondream2: {e}")
                  continue

      if not descriptions:
-         return "ambient background noise, general atmosphere"  # Fallback if all frames fail

-     # Combine unique descriptions, trying to maintain some order
-     combined = "; ".join(list(dict.fromkeys(descriptions)))
-     return combined

- # --- Audio Generation ---
- def generate_audio(prompt: str, duration_s: int, musicgen_processor, musicgen_model, guidance_scale: float, temperature: float):
      try:
-         # Check the prompt length against a practical limit for MusicGen (e.g., the T5 encoder limit).
-         # The MusicGen processor will truncate anyway, but a warning is useful.
-         # This tokenization is for estimation only; the actual truncation is done by the processor.
-         prompt_tokens = musicgen_processor.tokenizer.tokenize(prompt)
-         if len(prompt_tokens) > MUSICGEN_PROMPT_MAX_TOKENS:
-             st.warning(
-                 f"Generated sound prompt is very long ({len(prompt_tokens)} tokens) and will be truncated by MusicGen. "
-                 f"This might affect audio quality. Consider using fewer analysis frames if the descriptions are too verbose."
-             )
-         # Truncate the prompt manually if desired, or let the processor handle it.
-         # For simplicity, we let the processor handle it.

-         inputs = musicgen_processor(
-             text=[prompt],
-             return_tensors="pt",
-             padding=True
-         )
-         inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

-         # Critical for .half() models: ensure non-float tensors are not cast to half
-         inputs = {
-             k: (v.to(musicgen_model.dtype) if v.dtype.is_floating_point else v)
-             for k, v in inputs.items()
-         }

-         # Determine max_new_tokens based on the duration.
-         # MusicGen's default rate is 50 tokens/second, and the max sequence length
-         # for musicgen-small's decoder is typically 2048, which must also leave
-         # space for the prompt tokens. 1500 generated tokens (30 s) is a safe upper bound.
-         tokens_per_second = musicgen_model.config.audio_encoder.token_per_second
-         max_new_tokens = min(int(duration_s * tokens_per_second), 1500)

-         with torch.inference_mode():
-             audio_tensor = musicgen_model.generate(
                  **inputs,
                  max_new_tokens=max_new_tokens,
                  do_sample=True,
-                 guidance_scale=guidance_scale,
-                 temperature=temperature,
-                 pad_token_id=musicgen_model.config.eos_token_id
              )

-         # Post-processing
-         arr = audio_tensor[0].cpu().float().numpy()
-         # Normalize the audio
-         peak = np.max(np.abs(arr))
-         if peak == 0:  # Avoid division by zero for silent audio
-             arr = np.zeros_like(arr)  # Return silence
-         else:
-             arr = arr / peak * 0.9  # Normalize with some headroom
-             arr = np.clip(arr, -1.0, 1.0)

-         sampling_rate = musicgen_model.config.audio_encoder.sampling_rate
-         return arr, sampling_rate

      except Exception as e:
-         st.error(f"Error during audio generation: {e}")
          st.error(traceback.format_exc())
          return None, None
-     finally:
-         clear_gpu_memory()
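The dtype handling in the deleted `generate_audio` ("ensure non-float tensors are not cast to half") is the essential trick for running a `.half()` model: `input_ids` and attention masks must stay integer tensors while floating-point tensors follow the model's dtype. A standalone sketch of that cast, with illustrative tensor names:

```python
import torch

# A processor-style batch: integer ids plus one float tensor (names illustrative).
inputs = {
    "input_ids": torch.ones(1, 4, dtype=torch.long),
    "attention_mask": torch.ones(1, 4, dtype=torch.long),
    "some_float_feature": torch.rand(1, 4),  # stands in for any float input
}
model_dtype = torch.float16  # what model.half() would give

# Cast only floating-point tensors; integer tensors must keep their dtype.
cast = {k: (v.to(model_dtype) if v.dtype.is_floating_point else v) for k, v in inputs.items()}
print({k: str(v.dtype) for k, v in cast.items()})
# {'input_ids': 'torch.int64', 'attention_mask': 'torch.int64', 'some_float_feature': 'torch.float16'}
```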

- # --- Sync Audio/Video ---
- def sync_audio_video(video_path: str, audio_arr: np.ndarray, sampling_rate: int, mix_original_audio: bool) -> str | None:
-     tmp_wav_path = None
      output_video_path = None
      video_clip = None
-     audio_clip_generated = None
      final_clip = None

      try:
-         # Write the generated audio to a temporary WAV file
-         with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_wav_file:
-             sf.write(tmp_wav_file.name, audio_arr, sampling_rate)
-             tmp_wav_path = tmp_wav_file.name

          video_clip = mpy.VideoFileClip(video_path)
-         audio_clip_generated = mpy.AudioFileClip(tmp_wav_path)

-         # Match the audio duration to the video duration (loop or trim)
-         if audio_clip_generated.duration < video_clip.duration:
-             num_loops = math.ceil(video_clip.duration / audio_clip_generated.duration)
-             audio_clip_generated = mpy.concatenate_audioclips([audio_clip_generated] * num_loops)

-         audio_clip_generated = audio_clip_generated.subclip(0, video_clip.duration)  # Trim to the exact video duration

-         final_audio = audio_clip_generated
-         if mix_original_audio and video_clip.audio:
-             # Simple mix: lower the volume of both and combine.
-             # Adjust volumes as needed for better mixing.
-             original_audio_processed = video_clip.audio.volumex(0.6)
-             generated_audio_processed = audio_clip_generated.volumex(0.8)  # Give the generated audio slightly more presence
-             final_audio = mpy.CompositeAudioClip([original_audio_processed, generated_audio_processed])

          final_clip = video_clip.set_audio(final_audio)

-         # Create a temporary file for the output video
-         with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_out_vid_file:
-             output_video_path = tmp_out_vid_file.name

          final_clip.write_videofile(
-             output_video_path,
-             codec='libx264',
-             audio_codec='aac',
-             threads=max(1, os.cpu_count() // 2),  # Use half the CPU cores, or at least 1
-             logger=None  # Suppress verbose moviepy logs
          )
          return output_video_path

      except Exception as e:
-         st.error(f"Error syncing audio with video: {e}")
          st.error(traceback.format_exc())
          return None
      finally:
-         # Close moviepy clips to release file locks
          if video_clip: video_clip.close()
-         if audio_clip_generated: audio_clip_generated.close()
-         if final_clip: final_clip.close()

-         if tmp_wav_path and os.path.exists(tmp_wav_path):
-             os.remove(tmp_wav_path)
-         # The output_video_path is returned and should be handled (e.g., deleted after download) by the caller
-         clear_gpu_memory()
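The loop-count arithmetic in `sync_audio_video` is easiest to sanity-check with concrete numbers; a worked example with invented durations:

```python
import math

video_duration = 12.0  # seconds, illustrative
audio_duration = 5.0   # length of the generated clip

num_loops = math.ceil(video_duration / audio_duration)  # ceil(2.4) == 3
total = num_loops * audio_duration                      # 15.0 s of concatenated audio
print(num_loops, total)  # subclip(0, 12.0) then trims the excess
```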

- # --- Main UI ---
- st.title("🎬 AI Video Sound Designer")
- st.markdown(
-     "Upload an MP4 video, and this tool will analyze its visuals using **Moondream2** "
-     "to generate relevant sound prompts, then synthesize sound effects or music using **MusicGen**. "
-     "You can download the audio or the video with new synchronized sound."
- )
- st.markdown("---")

- # Sidebar Settings
  with st.sidebar:
-     st.header("⚙️ Generation Settings")

-     num_frames_to_analyze = st.slider(
-         "Frames to Analyze (for sound ideas)",
-         min_value=1, max_value=10, value=DEFAULT_FRAMES_TO_ANALYZE,
-         help="Number of frames sampled from the video to generate sound descriptions. More frames can give diverse ideas but take longer to analyze."
-     )

-     audio_duration_s = st.slider(
-         "Target Audio Duration (seconds)",
-         min_value=5, max_value=30, value=DEFAULT_AUDIO_DURATION_S,
-         help="Duration of the generated audio. If shorter than the video, it will be looped."
-     )

      st.subheader("MusicGen Parameters")
-     guidance_scale = st.slider(
-         "Guidance Scale",
-         min_value=1.0, max_value=10.0, value=DEFAULT_GUIDANCE, step=0.5,
-         help="Higher values make the audio follow the prompt more closely, but can reduce diversity. (MusicGen default: 3.0)"
-     )
-     temperature = st.slider(
-         "Temperature",
-         min_value=0.1, max_value=2.0, value=DEFAULT_TEMPERATURE, step=0.1,
-         help="Controls randomness. Higher values mean more diversity/creativity; lower values are more deterministic. (MusicGen default: 1.0)"
-     )

-     st.subheader("Video Output")
-     mix_with_original_audio = st.checkbox(
-         "Mix with original video audio", value=False,
-         help="If checked, the generated sound will be mixed with the video's existing audio. Otherwise, the existing audio is replaced."
-     )

-     st.markdown("---")
-     st.caption(f"Using Moondream2: `{MODEL_MOONDREAM}`")
-     st.caption(f"Using MusicGen: `{MODEL_MUSICGEN}`")
-     st.caption(f"Processing on: `{str(DEVICE).upper()}`")
-     # if st.button("Clear Model Cache & Reload"):
-     #     st.cache_resource.clear()
-     #     st.experimental_rerun()

- # File Uploader
- uploaded_video_file = st.file_uploader("📤 Upload MP4 Video", type=['mp4', 'mov', 'avi'])  # Added more types

- if 'processed_video_path' not in st.session_state:
-     st.session_state.processed_video_path = None
- if 'generated_audio_path' not in st.session_state:
-     st.session_state.generated_audio_path = None

- if uploaded_video_file:
-     st.info(f"Uploaded: `{uploaded_video_file.name}` ({uploaded_video_file.size / (1024*1024):.2f} MB)")

-     # Use a button to trigger processing, to avoid re-processing on every widget change
      if st.button("✨ Generate Sound Design!", type="primary", use_container_width=True):
-         # Clear previous results from session state, if any
-         if st.session_state.processed_video_path and os.path.exists(st.session_state.processed_video_path):
-             os.remove(st.session_state.processed_video_path)
-         st.session_state.processed_video_path = None
-         if st.session_state.generated_audio_path and os.path.exists(st.session_state.generated_audio_path):
-             os.remove(st.session_state.generated_audio_path)
-         st.session_state.generated_audio_path = None

-         # Path for the uploaded video; will be cleaned up in finally
-         temp_uploaded_video_path = None

          try:
-             # --- 1. Save uploaded video to a temporary file ---
-             with st.spinner("Preparing video..."):
-                 with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_video_file.name)[1]) as tfile:
-                     tfile.write(uploaded_video_file.read())
-                     temp_uploaded_video_path = tfile.name

-             # --- 2. Extract Frames ---
-             with st.spinner(f"Extracting {num_frames_to_analyze} frames from video..."):
-                 extracted_frames = extract_frames(temp_uploaded_video_path, num_frames_to_analyze)

-             if not extracted_frames:
-                 st.error("Failed to extract frames. Cannot proceed.")
-                 st.stop()

-             st.subheader("🖼️ Sampled Frames for Sound Analysis")
-             cols_to_show = min(len(extracted_frames), MAX_FRAMES_TO_SHOW_UI)
-             if cols_to_show > 0:
-                 cols = st.columns(cols_to_show)
-                 for i in range(cols_to_show):
-                     cols[i].image(extracted_frames[i], caption=f"Frame {i+1}", use_column_width=True)
-             elif extracted_frames:
-                 st.write(f"{len(extracted_frames)} frames extracted (not shown due to display limit).")

-             # --- 3. Load Moondream2 and Generate Sound Prompt ---
-             moondream_pipe = load_moondream_model()
-             if not moondream_pipe:
-                 st.error("Moondream2 model could not be loaded. Cannot generate sound prompt.")
                  st.stop()
 
-             sound_prompt = generate_sound_prompt(extracted_frames, moondream_pipe)
-             del moondream_pipe  # Release the Moondream model from memory
-             clear_gpu_memory()

-             if not sound_prompt:
-                 st.error("Failed to generate a sound prompt from the video frames.")
-                 st.stop()
-             st.info(f"🧠 **Generated Sound Prompt:** {sound_prompt}")

-             # --- 4. Load MusicGen and Generate Audio ---
-             musicgen_processor, musicgen_model = load_musicgen_model_and_processor()
-             if not musicgen_processor or not musicgen_model:
-                 st.error("MusicGen model/processor could not be loaded. Cannot generate audio.")
-                 st.stop()

-             with st.spinner(f"Synthesizing {audio_duration_s}s audio with MusicGen... (This can take a few minutes)"):
-                 generated_audio_arr, generated_sr = generate_audio(
-                     sound_prompt, audio_duration_s, musicgen_processor, musicgen_model, guidance_scale, temperature
-                 )

-             del musicgen_processor, musicgen_model  # Release the MusicGen model from memory
-             clear_gpu_memory()

-             if generated_audio_arr is None or generated_sr is None:
-                 st.error("Audio generation failed.")
-                 st.stop()

-             st.subheader("🔊 Generated Sound Effect")
-             st.audio(generated_audio_arr, sample_rate=generated_sr)

-             # Save the generated audio to a temporary file for download
-             with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_audio_file:
-                 sf.write(tmp_audio_file.name, generated_audio_arr, generated_sr)
-                 st.session_state.generated_audio_path = tmp_audio_file.name  # Store the path in session_state

-             with open(st.session_state.generated_audio_path, 'rb') as f_audio:
-                 st.download_button(
-                     label="📥 Download Generated Audio Only (.wav)",
-                     data=f_audio,
-                     file_name=f"{os.path.splitext(uploaded_video_file.name)[0]}_sound_effect.wav",
-                     mime='audio/wav'
-                 )

-             # --- 5. Sync Audio with Video ---
-             with st.spinner("Synchronizing generated audio with video..."):
-                 output_video_path = sync_audio_video(
-                     temp_uploaded_video_path, generated_audio_arr, generated_sr, mix_with_original_audio
-                 )

-             if output_video_path and os.path.exists(output_video_path):
-                 st.subheader("🎥 Video with New Sound Design")
-                 st.video(output_video_path)
-                 st.session_state.processed_video_path = output_video_path  # Store the path

-                 with open(output_video_path, 'rb') as f_video:
                      st.download_button(
-                         label="🎬 Download Video with New Sound (.mp4)",
-                         data=f_video,
-                         file_name=f"{os.path.splitext(uploaded_video_file.name)[0]}_sound_designed.mp4",
-                         mime='video/mp4'
                      )
              else:
-                 st.error("Failed to create video with new sound.")

          except Exception as e:
-             st.error(f"An unexpected error occurred during processing: {e}")
              st.error(traceback.format_exc())
          finally:
-             # Clean up the initially uploaded temporary video file
-             if temp_uploaded_video_path and os.path.exists(temp_uploaded_video_path):
-                 os.remove(temp_uploaded_video_path)
-             clear_gpu_memory()
-             # Note: the temp files kept for download (st.session_state.processed_video_path and
-             # st.session_state.generated_audio_path) are not deleted here. Ideally they would be
-             # cleaned up when the session ends or a new file is processed, but Streamlit has no
-             # convenient session-end hook. Robust cleanup would need more complex session
-             # management or a background task; for now they are overwritten or deleted on the
-             # next "Generate" click.

- # Display results from session state if they exist (e.g., after a settings change without re-upload).
- # This part is tricky because changing settings should ideally re-trigger processing with the new settings.
- # The "Generate Sound Design!" button helps control this; if there are previous results and no button click,
- # they might still be shown. This part is simplified for now.
- elif st.session_state.generated_audio_path and os.path.exists(st.session_state.generated_audio_path):
-     st.subheader("🔊 Previously Generated Sound Effect")
-     st.audio(st.session_state.generated_audio_path)
-     with open(st.session_state.generated_audio_path, 'rb') as f_audio:
          st.download_button(
-             label="📥 Download Generated Audio Only (.wav)",
-             data=f_audio,
-             file_name=f"{os.path.splitext(uploaded_video_file.name)[0]}_sound_effect.wav",
-             mime='audio/wav',
-             key="prev_audio_download"  # Unique key
          )
-     if st.session_state.processed_video_path and os.path.exists(st.session_state.processed_video_path):
-         st.subheader("🎥 Previously Generated Video with New Sound Design")
-         st.video(st.session_state.processed_video_path)
-         with open(st.session_state.processed_video_path, 'rb') as f_video:
              st.download_button(
-                 label="🎬 Download Video with New Sound (.mp4)",
-                 data=f_video,
-                 file_name=f"{os.path.splitext(uploaded_video_file.name)[0]}_sound_designed.mp4",
-                 mime='video/mp4',
-                 key="prev_video_download"  # Unique key
              )

  else:
-     st.info("👋 Upload a video file to begin!")

  st.markdown("---")
- st.caption("Developed with ❤️ using Streamlit, Hugging Face Transformers, MoviePy, and ImageIO.")
  import streamlit as st
  from PIL import Image
+ import numpy as np
  import torch
+ import gc
  import os
  import tempfile
  import math
+ import imageio
+ import traceback

+ # --- Attempt to import moviepy for video processing ---
  try:
      import moviepy.editor as mpy
+     MOVIEPY_AVAILABLE = True
+ except (ImportError, OSError) as e:
+     MOVIEPY_AVAILABLE = False
+     st.warning(
+         "MoviePy library is not available or ffmpeg is missing. "
+         "Video syncing features will be disabled. "
+         "If running locally, install with: pip install moviepy. Ensure ffmpeg is installed."
+     )
+     print(f"MoviePy load error: {e}")

+ # --- Model Configuration ---
+ IMAGE_CAPTION_MODEL = "Salesforce/blip-image-captioning-base"
+ AUDIO_GEN_MODEL = "facebook/musicgen-small"

+ # --- Constants ---
+ DEFAULT_NUM_FRAMES = 2  # Fewer frames for faster processing on the free tier
+ DEFAULT_AUDIO_DURATION_S = 5  # Shorter audio for faster generation
+ MAX_FRAMES_TO_SHOW_UI = 3
+ DEVICE = torch.device("cpu")  # Explicitly use CPU for the Hugging Face free tier

+ # --- Page Setup ---
+ st.set_page_config(page_title="AI Video Sound Designer (HF Space)", layout="wide", page_icon="🎬")

+ st.title("🎬 AI Video Sound Designer (for Hugging Face Spaces)")
+ st.markdown("""
+ Upload a short MP4 video. The tool will:
+ 1. Extract frames from the video.
+ 2. Analyze frames using an image captioning model to generate sound ideas.
+ 3. Synthesize audio using MusicGen based on these ideas.
+ 4. Optionally, combine the new audio with your video.
+ ---
+ **Note:** Processing on CPU (especially audio generation) can be slow. Please be patient!
+ """)

+ # --- Utility Functions ---
+ def clear_memory(model_obj=None, processor_obj=None):
+     """Clears model objects from memory and runs garbage collection."""
+     if model_obj:
+         del model_obj
+     if processor_obj:
+         del processor_obj
+     gc.collect()
+     if torch.cuda.is_available():  # Though we target CPU, good practice
+         torch.cuda.empty_cache()
+     print("Memory cleared.")
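A caveat worth noting about `clear_memory`: `del model_obj` unbinds only the function-local name, so an object still referenced by the `st.cache_resource` cache (or by the caller) is not actually freed, and the subsequent `gc.collect()` has nothing extra to reclaim. A standalone illustration of that Python behaviour (not app code):

```python
import gc

def clear(obj=None):
    del obj       # unbinds only the local name 'obj'
    gc.collect()  # nothing to collect: the caller still holds a reference

x = [0] * 1000
clear(x)
print(len(x))  # 1000 -- the caller's object survives
```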
  @st.cache_resource
+ def load_image_caption_model_and_processor():
+     """Loads the image captioning model and processor."""
      try:
+         from transformers import BlipProcessor, BlipForConditionalGeneration
+         st.write(f"Loading Image Captioning Model: {IMAGE_CAPTION_MODEL} (this might take a moment)...")
+         processor = BlipProcessor.from_pretrained(IMAGE_CAPTION_MODEL)
+         model = BlipForConditionalGeneration.from_pretrained(IMAGE_CAPTION_MODEL).to(DEVICE)
+         st.toast("Image Captioning model loaded!", icon="🖼️")
+         return processor, model
      except Exception as e:
+         st.error(f"Error loading image captioning model: {e}")
          st.error(traceback.format_exc())
+         return None, None

  @st.cache_resource
+ def load_audio_gen_model_and_processor():
+     """Loads the audio generation model and processor."""
      try:
+         from transformers import AutoProcessor, MusicgenForConditionalGeneration
+         st.write(f"Loading Audio Generation Model: {AUDIO_GEN_MODEL} (this might take a while on CPU)...")
+         processor = AutoProcessor.from_pretrained(AUDIO_GEN_MODEL)
+         model = MusicgenForConditionalGeneration.from_pretrained(AUDIO_GEN_MODEL).to(DEVICE)
+         st.toast("Audio Generation model loaded! (CPU generation will be slow)", icon="🎶")
          return processor, model
      except Exception as e:
+         st.error(f"Error loading audio generation model: {e}")
          st.error(traceback.format_exc())
          return None, None

+ def extract_frames_from_video(video_path, num_frames):
+     """Extracts a specified number of frames evenly from a video."""
      frames = []
      reader = None
      try:
          reader = imageio.get_reader(video_path, "ffmpeg")
+         total_frames = reader.count_frames()
+         if total_frames == 0:  # If count_frames fails, try the metadata
+             meta = reader.get_meta_data()
+             duration = meta.get('duration')
+             fps = meta.get('fps', 25)
+             if duration:
+                 total_frames = int(duration * fps)
+             else:  # Fallback if the duration isn't available
+                 st.warning("Could not determine video length. Will attempt to read initial frames.")
+                 # Try to read a few frames anyway if the count fails
+                 for i, frame_data in enumerate(reader):
+                     if i < num_frames * 5:  # Read a bit more than needed to find distinct frames
+                         frames.append(Image.fromarray(frame_data).convert("RGB"))
+                     if len(frames) >= num_frames:
+                         break
+                 if reader: reader.close()
+                 step = max(1, len(frames) // num_frames)  # Guard against a zero step when few frames were read
+                 return frames[::step][:num_frames] if frames else []

+         if total_frames < num_frames:
+             indices = np.arange(total_frames)
+         else:
+             indices = np.linspace(0, total_frames - 1, num_frames, dtype=int, endpoint=True)

+         actual_frames_extracted = 0
          for i in indices:
+             if actual_frames_extracted >= num_frames:
+                 break
+             try:
+                 frame_data = reader.get_data(i)
+                 frames.append(Image.fromarray(frame_data).convert("RGB"))
+                 actual_frames_extracted += 1
+             except Exception as e:
+                 st.warning(f"Skipping problematic frame {i}: {e}")
+                 continue
+         return frames
      except (imageio.core.fetching.NeedDownloadError, OSError) as e_ffmpeg:
+         st.error(f"FFmpeg not found or failed: {e_ffmpeg}. Please ensure ffmpeg is installed and in PATH if running locally.")
          return []
      except Exception as e:
+         st.error(f"Error extracting frames: {e}")
          st.error(traceback.format_exc())
          return []
      finally:
          if reader:
              reader.close()
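In the common path, the frame selection above reduces to a single `np.linspace` call; in isolation, with assumed example numbers:

```python
import numpy as np

total_frames, num_frames = 120, 4
indices = np.linspace(0, total_frames - 1, num_frames, dtype=int, endpoint=True)
print(indices)  # [  0  39  79 119] -- the first and last frames are always included
```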

+ def generate_sound_prompt_from_frames(frames, caption_processor, caption_model):
+     """Generates sound descriptions from frames using BLIP."""
+     if not frames:
+         return "ambient background noise"

      descriptions = []
+     instruction = "A short description of this image, focusing on elements that might produce sound:"

+     with st.spinner(f"Generating sound ideas from {len(frames)} frames..."):
          for i, frame in enumerate(frames):
              try:
+                 inputs = caption_processor(images=frame, text=instruction, return_tensors="pt").to(DEVICE)
+                 # For BLIP, generate is typically used like this.
+                 # You might need to adjust max_length based on the desired description length.
+                 generated_ids = caption_model.generate(**inputs, max_length=50)  # Keep descriptions short
+                 description = caption_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+                 if description:
+                     descriptions.append(description)
+                 st.progress((i + 1) / len(frames), text=f"Frame {i+1}/{len(frames)} analyzed.")
              except Exception as e:
+                 st.warning(f"Could not get description for a frame: {e}")
                  continue

      if not descriptions:
+         return "general ambiance, subtle environmental sounds"  # Fallback

+     # Simple combination: join unique descriptions
+     unique_descriptions = list(dict.fromkeys(descriptions))
+     combined_prompt = ". ".join(unique_descriptions)
+     # Further processing to make it read more like a sound design brief
+     final_prompt = f"Sounds for a scene featuring: {combined_prompt}. Focus on atmosphere, key sound events, and textures."
+     return final_prompt

+ def generate_audio_from_prompt(prompt, duration_s, audio_processor, audio_model, guidance, temp):
+     """Generates audio using MusicGen."""
      try:
+         inputs = audio_processor(text=[prompt], return_tensors="pt", padding=True).to(DEVICE)

+         # MusicGen has a max sequence length for the prompt, often around 2048 tokens.
+         # Warn at 512 tokens to be safe on CPU and for typical descriptions;
+         # the processor handles the actual truncation.
+         if inputs.input_ids.shape[1] > 512:
+             st.warning(f"Prompt is long ({inputs.input_ids.shape[1]} tokens); it might be truncated by MusicGen.")
+             # inputs['input_ids'] = inputs['input_ids'][:, :512]
+             # inputs['attention_mask'] = inputs['attention_mask'][:, :512]

+         # Calculate max_new_tokens from the duration and the model's tokens/second rate.
+         # musicgen-small generates roughly 50 tokens/second; its max output length is ~2048 tokens.
+         tokens_per_second = audio_model.config.audio_encoder.token_per_second  # typically 50 for MusicGen
+         max_new_tokens = min(int(duration_s * tokens_per_second), 1500)  # Cap at 1500 (~30 s) as a practical limit

+         with st.spinner(f"Synthesizing {duration_s}s audio... (CPU: this will take several minutes!)"):
+             # On CPU, do_sample=False might be faster but less diverse. Try True first.
+             audio_values = audio_model.generate(
                  **inputs,
                  max_new_tokens=max_new_tokens,
                  do_sample=True,
+                 guidance_scale=guidance,
+                 temperature=temp,
+                 # No pad_token_id for MusicGen's generate function; it uses eos_token_id for padding by default if needed
              )

+         audio_array = audio_values[0, 0].cpu().numpy()
+         sampling_rate = audio_model.config.audio_encoder.sampling_rate

+         # Normalize
+         if np.abs(audio_array).max() > 0:
+             audio_array = audio_array / np.abs(audio_array).max() * 0.9
+         return audio_array, sampling_rate
      except Exception as e:
+         st.error(f"Error generating audio: {e}")
          st.error(traceback.format_exc())
          return None, None
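The duration-to-token arithmetic in `generate_audio_from_prompt`, worked through with the roughly 50 tokens-per-second rate the comment cites for musicgen-small (the app reads the real rate from the model config at run time; 50 is an assumption here):

```python
TOKENS_PER_SECOND = 50  # assumed; the app reads config.audio_encoder.token_per_second

for duration_s in (5, 15, 30, 60):
    max_new_tokens = min(int(duration_s * TOKENS_PER_SECOND), 1500)
    print(f"{duration_s:>2} s -> {max_new_tokens} tokens")
# 5 s -> 250, 15 s -> 750, 30 s -> 1500, 60 s -> 1500 (capped)
```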
+ def combine_audio_video(video_path, audio_array, sampling_rate, mix_original):
221
+ """Combines generated audio with the video using MoviePy."""
222
+ if not MOVIEPY_AVAILABLE:
223
+ st.error("MoviePy is not available. Cannot combine audio and video.")
224
+ return None
225
 
 
 
 
226
  output_video_path = None
227
+ temp_audio_path = None
228
  video_clip = None
229
+ generated_audio_clip = None
230
  final_clip = None
231
 
232
  try:
233
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio:
234
+ # Scipy.io.wavfile can also be used here, or soundfile
235
+ import scipy.io.wavfile
236
+ scipy.io.wavfile.write(tmp_audio.name, sampling_rate, audio_array)
237
+ temp_audio_path = tmp_audio.name
238
+
239
  video_clip = mpy.VideoFileClip(video_path)
240
+ generated_audio_clip = mpy.AudioFileClip(temp_audio_path)
241
+
242
+ # Loop or trim generated audio to match video duration
243
+ if generated_audio_clip.duration < video_clip.duration:
244
+ generated_audio_clip = generated_audio_clip.fx(mpy.afx.audio_loop, duration=video_clip.duration)
245
+ elif generated_audio_clip.duration > video_clip.duration:
246
+ generated_audio_clip = generated_audio_clip.subclip(0, video_clip.duration)
247
+
248
+ final_audio = generated_audio_clip
249
+ if mix_original and video_clip.audio:
250
+ # Adjust volumes for mixing
251
+ original_audio = video_clip.audio.volumex(0.5) # Lower original audio
252
+ generated_audio = generated_audio_clip.volumex(0.8) # Keep generated slightly louder
253
+ final_audio = mpy.CompositeAudioClip([original_audio, generated_audio])
254
+ final_audio = final_audio.set_duration(video_clip.duration) # Ensure composite duration matches
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  final_clip = video_clip.set_audio(final_audio)
257
 
258
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp_video_out:
259
+ output_video_path = tmp_video_out.name
260
+
 
261
  final_clip.write_videofile(
262
+ output_video_path,
263
+ codec="libx264",
264
+ audio_codec="aac",
265
+ temp_audiofile_path=os.path.dirname(temp_audio_path), # Ensure moviepy can write temp audio here
266
+ threads=2, # Limit threads on free tier
267
+ logger=None # or 'bar' for progress
268
  )
269
  return output_video_path
270
 
271
  except Exception as e:
272
+ st.error(f"Error combining audio and video: {e}")
273
  st.error(traceback.format_exc())
274
  return None
275
  finally:
276
+ # Close clips to release resources
277
  if video_clip: video_clip.close()
278
+ if generated_audio_clip: generated_audio_clip.close()
279
+ # if final_clip: final_clip.close() # final_clip is usually the same as video_clip with modified audio
280
+
281
+ if temp_audio_path and os.path.exists(temp_audio_path):
282
+ os.remove(temp_audio_path)
283
+ # The output_video_path is handled by the caller (downloaded, then potentially cleaned up)
 
 
 
 
 
 
 
 
 
 
284
 
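A possible smoke test for `combine_audio_video`, feeding it a synthetic tone instead of MusicGen output; `clip.mp4` is a placeholder path, and this assumes the function is exercised outside the Streamlit UI flow:

```python
import numpy as np

sr = 32000  # musicgen-small decodes at 32 kHz
t = np.linspace(0, 5.0, int(5.0 * sr), endpoint=False)
tone = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)  # 5 s of A440

out_path = combine_audio_video("clip.mp4", tone, sr, mix_original=False)
print("written:", out_path)
```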

+ # --- Sidebar for Settings ---
  with st.sidebar:
+     st.header("⚙️ Settings")
+     num_frames_analysis = st.slider("Number of Frames to Analyze", 1, 5, DEFAULT_NUM_FRAMES, 1,
+                                     help="More frames provide more context but increase analysis time.")
+     audio_duration = st.slider("Target Audio Duration (seconds)", 3, 15, DEFAULT_AUDIO_DURATION_S, 1,
+                                help="Shorter durations generate much faster on CPU.")

      st.subheader("MusicGen Parameters")
+     guidance = st.slider("Guidance Scale (MusicGen)", 1.0, 7.0, 3.0, 0.5,
+                          help="Higher values make audio follow the prompt more closely. Default is 3.0.")
+     temperature = st.slider("Temperature (MusicGen)", 0.5, 1.5, 1.0, 0.1,
+                             help="Controls randomness. Higher is more diverse. Default is 1.0.")

+     if MOVIEPY_AVAILABLE:
+         st.subheader("Video Output")
+         mix_audio = st.checkbox("Mix with original video audio", value=False)
+     else:
+         mix_audio = False  # Disable if moviepy is not available

+ # --- Main Application Logic ---
+ uploaded_file = st.file_uploader("📤 Upload your MP4 video file (short clips recommended):", type=["mp4", "mov", "avi"])

+ # Initialize session state for generated file paths
+ if 'generated_audio_file' not in st.session_state:
+     st.session_state.generated_audio_file = None
+ if 'output_video_file' not in st.session_state:
+     st.session_state.output_video_file = None

+ if uploaded_file is not None:
+     st.video(uploaded_file)

+     # Use a button to trigger processing
      if st.button("✨ Generate Sound Design!", type="primary", use_container_width=True):
+         # --- Clear previous results ---
+         if st.session_state.generated_audio_file and os.path.exists(st.session_state.generated_audio_file):
+             os.remove(st.session_state.generated_audio_file)
+         st.session_state.generated_audio_file = None
+         if st.session_state.output_video_file and os.path.exists(st.session_state.output_video_file):
+             os.remove(st.session_state.output_video_file)
+         st.session_state.output_video_file = None
+         clear_memory()

+         video_bytes = uploaded_file.read()
+         temp_video_path = None

          try:
+             with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_vid:
+                 tmp_vid.write(video_bytes)
+                 temp_video_path = tmp_vid.name

+             # === Stage 1: Frame Extraction ===
+             st.subheader("1. Extracting Frames")
+             with st.spinner("Extracting frames from video..."):
+                 frames = extract_frames_from_video(temp_video_path, num_frames_analysis)

+             if not frames:
+                 st.error("No frames extracted. Cannot proceed.")
                  st.stop()

+             st.success(f"Extracted {len(frames)} frames.")
+             if frames:
+                 cols_to_show = min(len(frames), MAX_FRAMES_TO_SHOW_UI)
+                 if cols_to_show > 0:
+                     st.write("Sampled Frames:")
+                     cols = st.columns(cols_to_show)
+                     for i, frame_img in enumerate(frames[:cols_to_show]):
+                         cols[i].image(frame_img, caption=f"Frame {i+1}", use_column_width=True)

+             # === Stage 2: Image Captioning (Sound Prompt Generation) ===
+             st.subheader("2. Generating Sound Ideas (Image Analysis)")
+             caption_processor, caption_model = load_image_caption_model_and_processor()
+             if caption_processor and caption_model:
+                 sound_prompt = generate_sound_prompt_from_frames(frames, caption_processor, caption_model)
+                 st.info(f"✍️ **Generated Sound Prompt:** {sound_prompt}")

+                 # Unload the captioning model immediately
+                 clear_memory(caption_model, caption_processor)
+             else:
+                 st.error("Failed to load image captioning model. Using a default prompt.")
+                 sound_prompt = "ambient nature sounds with a gentle breeze"  # Fallback

+             # === Stage 3: Audio Generation ===
+             st.subheader("3. Synthesizing Audio (MusicGen)")
+             st.warning("🎧 Audio generation on CPU can take several minutes. Please be patient!")
+             audio_processor, audio_model = load_audio_gen_model_and_processor()
+             generated_audio_array, sr = None, None  # Initialize

+             if audio_processor and audio_model:
+                 generated_audio_array, sr = generate_audio_from_prompt(sound_prompt, audio_duration, audio_processor, audio_model, guidance, temperature)
+                 # Unload the audio model immediately
+                 clear_memory(audio_model, audio_processor)
+             else:
+                 st.error("Failed to load audio generation model. Cannot generate audio.")

+             if generated_audio_array is not None and sr is not None:
+                 st.success("Audio generated!")
+                 st.audio(generated_audio_array, sample_rate=sr)

+                 # Save the audio for download
+                 with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_audio_out:
+                     import scipy.io.wavfile  # or soundfile
+                     scipy.io.wavfile.write(tmp_audio_out.name, sr, generated_audio_array)
+                     st.session_state.generated_audio_file = tmp_audio_out.name

+                 with open(st.session_state.generated_audio_file, "rb") as f:
                      st.download_button(
+                         "📥 Download Generated Audio (.wav)",
+                         f,
+                         file_name="generated_sound.wav",
+                         mime="audio/wav"
                      )

+                 # === Stage 4: (Optional) Video and Audio Syncing ===
+                 if MOVIEPY_AVAILABLE:
+                     st.subheader("4. Combining Audio with Video")
+                     with st.spinner("Processing video with new audio... (can be slow)"):
+                         output_video_file_path = combine_audio_video(temp_video_path, generated_audio_array, sr, mix_audio)

+                     if output_video_file_path and os.path.exists(output_video_file_path):
+                         st.success("Video processing complete!")
+                         st.video(output_video_file_path)
+                         st.session_state.output_video_file = output_video_file_path

+                         with open(output_video_file_path, "rb") as f_vid:
+                             st.download_button(
+                                 "🎬 Download Video with New Sound (.mp4)",
+                                 f_vid,
+                                 file_name="video_with_new_sound.mp4",
+                                 mime="video/mp4"
+                             )
+                     elif MOVIEPY_AVAILABLE:  # Only show an error if moviepy was expected to work
+                         st.error("Failed to combine audio and video.")
              else:
+                 st.error("Audio generation failed. Cannot proceed to video syncing.")

          except Exception as e:
+             st.error(f"An unexpected error occurred in the main processing pipeline: {e}")
              st.error(traceback.format_exc())
          finally:
+             if temp_video_path and os.path.exists(temp_video_path):
+                 os.remove(temp_video_path)
+             # Models are cleared within their stages using clear_memory().
+             # Generated download files (audio/video) are kept in session_state until the next run or the session ends.
+             print("Main processing finished or errored. Temp video (if any) cleaned up.")
+             clear_memory()  # Final catch-all clear

+ # Display download buttons if files were generated in a previous run within the session
+ elif st.session_state.generated_audio_file and os.path.exists(st.session_state.generated_audio_file):
+     st.markdown("---")
+     st.write("Previously generated audio:")
+     st.audio(st.session_state.generated_audio_file)
+     with open(st.session_state.generated_audio_file, "rb") as f:
          st.download_button(
+             "📥 Download Previously Generated Audio (.wav)",
+             f,
+             file_name="generated_sound_previous.wav",
+             mime="audio/wav",
+             key="prev_audio_dl"
          )
+     if st.session_state.output_video_file and os.path.exists(st.session_state.output_video_file) and MOVIEPY_AVAILABLE:
+         st.markdown("---")
+         st.write("Previously generated video with new sound:")
+         st.video(st.session_state.output_video_file)
+         with open(st.session_state.output_video_file, "rb") as f_vid:
              st.download_button(
+                 "🎬 Download Previously Generated Video (.mp4)",
+                 f_vid,
+                 file_name="video_with_new_sound_previous.mp4",
+                 mime="video/mp4",
+                 key="prev_video_dl"
              )

  else:
+     st.info("☝️ Upload a video to get started.")

  st.markdown("---")
+ st.markdown("Made for Hugging Face Spaces. Model loading & generation can be slow on CPU.")
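
For reference, a plausible `requirements.txt` inferred from the imports in this file; the entries are assumptions rather than part of the commit, and `moviepy` is pinned below 2.0 because the `moviepy.editor` module used here was removed in MoviePy 2.x:

```text
streamlit
torch
transformers
numpy
scipy
Pillow
imageio
imageio-ffmpeg
moviepy<2.0
```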