# vta-ldm / app.py
# Hugging Face Space by fffiloni; last commit 6a33c24 ("Update app.py", verified), 8.2 kB
import gradio as gr
import huggingface_hub
import os
import subprocess
import threading
import shutil
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from moviepy.editor import VideoFileClip, AudioFileClip
# Fetch the VTA-LDM (clip4clip, large) checkpoint from the Hugging Face Hub at
# import time. The inference command in infer() reads summary.jsonl and
# pytorch_model_2.bin from this local directory.
_VTA_LDM_REPO_ID = 'ariesssxu/vta-ldm-clip4clip-v-large'
_VTA_LDM_LOCAL_DIR = './ckpt/vta-ldm-clip4clip-v-large'
huggingface_hub.snapshot_download(repo_id=_VTA_LDM_REPO_ID, local_dir=_VTA_LDM_LOCAL_DIR)
def stream_output(pipe):
    """Echo every line read from *pipe* to stdout until EOF.

    *pipe* must be a text-mode stream whose ``readline()`` returns ``''`` at
    end of file, e.g. a pipe of ``subprocess.Popen(..., text=True)``.
    """
    while True:
        line = pipe.readline()
        if line == '':
            break
        # Each line already carries its own newline, so suppress print's.
        print(line, end='')
def print_directory_contents(path):
    """Print a recursive, indented listing of *path*.

    Directory names are suffixed with '/', each nesting level adds four
    spaces of indentation, and files are listed under their directory.
    """
    for current_dir, _subdirs, filenames in os.walk(path):
        depth = current_dir.replace(path, '').count(os.sep)
        print(' ' * (4 * depth) + os.path.basename(current_dir) + '/')
        file_prefix = ' ' * (4 * (depth + 1))
        for name in filenames:
            print(file_prefix + name)
# Print the ckpt directory contents
print_directory_contents('./ckpt')
def get_wav_files(path):
    """Walk *path*, print an indented tree of it, and collect its .wav files.

    Directories are printed with a trailing '/', each level indented by four
    spaces. Files ending in '.wav' (case-insensitive) are printed with their
    full path; other files are printed by name only.

    Returns:
        list[str]: ``os.path.join``-ed paths of every .wav file found, in
        ``os.walk`` order.
    """
    found = []
    for current_dir, _subdirs, filenames in os.walk(path):
        depth = current_dir.replace(path, '').count(os.sep)
        print(' ' * (4 * depth) + os.path.basename(current_dir) + '/')
        file_prefix = ' ' * (4 * (depth + 1))
        for name in filenames:
            full_path = os.path.join(current_dir, name)
            is_wav = name.lower().endswith('.wav')
            if is_wav:
                found.append(full_path)
            print(file_prefix + (full_path if is_wav else name))
    return found
def check_outputs_folder(folder_path):
    """Delete everything inside *folder_path* while keeping the folder itself.

    Files and symlinks are unlinked (links are not followed); real
    subdirectories are removed recursively. A failure on one entry is printed
    and the remaining entries are still processed. If *folder_path* is not an
    existing directory, a message is printed and nothing is deleted.
    """
    if not os.path.isdir(folder_path):
        print(f'The folder {folder_path} does not exist.')
        return
    for entry in os.listdir(folder_path):
        target = os.path.join(folder_path, entry)
        try:
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print(f'Failed to delete {target}. Reason: {err}')
def plot_spectrogram(wav_file, output_image):
    """Render a compact, axis-free grayscale spectrogram of a WAV file.

    Args:
        wav_file: path of the WAV file to read with ``scipy.io.wavfile``.
        output_image: destination image path, written at 300 dpi with no
            padding around the plot.

    Multi-channel audio (a 2-D sample array) is averaged down to mono first.
    """
    sample_rate, audio_data = wavfile.read(wav_file)
    # wavfile.read returns a 2-D (samples, channels) array for multi-channel
    # files; average the channels so specgram receives a 1-D signal.
    if audio_data.ndim == 2:
        audio_data = audio_data.mean(axis=1)
    fig = plt.figure(figsize=(10, 2))
    try:
        plt.specgram(audio_data, Fs=sample_rate, NFFT=1024, noverlap=512,
                     cmap='gray', aspect='auto')
        # Remove gridlines and ticks for a clean, image-only rendering.
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])
        plt.savefig(output_image, bbox_inches='tight', pad_inches=0, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed. A bare
        # `plt.close` reference (no call) leaked one figure per request.
        plt.close(fig)
def merge_audio_to_video(input_vid, input_aud, output_path="output_video.mp4"):
    """Replace the soundtrack of *input_vid* with *input_aud* and save it.

    Args:
        input_vid: path of the source video; its existing audio is replaced.
        input_aud: path of the audio file to attach.
        output_path: where to write the merged H.264/AAC video. The default,
            "output_video.mp4" in the working directory, is shared by every
            call. NOTE(review): concurrent requests would overwrite each
            other's result; pass a per-request path if that matters.

    Returns:
        str: *output_path*.
    """
    video = VideoFileClip(input_vid)
    try:
        new_audio = AudioFileClip(input_aud)
        try:
            video_with_new_audio = video.set_audio(new_audio)
            video_with_new_audio.write_videofile(output_path, codec='libx264', audio_codec='aac')
        finally:
            new_audio.close()
    finally:
        # MoviePy clips hold open file readers until close() is called; the
        # previous version never released them.
        video.close()
    return output_path
def _trim_video_to_max_duration(video_path, max_seconds=10):
    """Ensure the video used downstream is at most *max_seconds* long.

    If the clip at *video_path* is longer, its first *max_seconds* seconds are
    re-encoded (H.264/AAC) to a sibling file ``<stem>_<max_seconds>sec_cut<ext>``.
    The original file is then deleted, so its folder holds only the cut clip.

    Returns:
        tuple[str, int]: (path of the video to use, duration in whole seconds,
        capped at *max_seconds*).
    """
    video = VideoFileClip(video_path)
    try:
        duration = video.duration
        print(f"Video duration: {int(duration)} seconds")
        # Compare the exact duration: truncating to int first let e.g. a
        # 10.9 s clip through uncut.
        if duration <= max_seconds:
            print(f"Video is {max_seconds} seconds or shorter; no cutting needed.")
            return video_path, int(duration)
        dir_name, base_name = os.path.split(video_path)
        stem, ext = os.path.splitext(base_name)
        # splitext (rather than str.replace(".mp4", ...)) guarantees the cut
        # gets a new name, so it can never overwrite the source that is
        # deleted below. Fall back to .mp4 so ffmpeg can pick a container.
        cut_path = os.path.join(dir_name, f"{stem}_{max_seconds}sec_cut{ext or '.mp4'}")
        video.subclip(0, max_seconds).write_videofile(cut_path, codec='libx264', audio_codec='aac')
        print(f"Cut video saved as: {cut_path}")
    finally:
        video.close()
    os.remove(video_path)
    print(f"Original video file {video_path} deleted.")
    return cut_path, max_seconds


def _run_vta_ldm_inference(data_path, max_duration):
    """Run inference_from_video.py on the folder *data_path*, echoing its output live.

    Returns:
        int: the script's exit code.
    """
    import sys  # only needed to locate the interpreter running this app

    command = [sys.executable, 'inference_from_video.py',
               '--original_args', 'ckpt/vta-ldm-clip4clip-v-large/summary.jsonl',
               '--model', 'ckpt/vta-ldm-clip4clip-v-large/pytorch_model_2.bin',
               '--data_path', data_path,
               '--max_duration', f"{max_duration}"]
    # The context manager closes both pipes once the child has finished.
    with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                          text=True, bufsize=1) as process:
        # Drain stdout and stderr concurrently so neither pipe can fill up
        # and block the child process.
        readers = [threading.Thread(target=stream_output, args=(stream,))
                   for stream in (process.stdout, process.stderr)]
        for reader in readers:
            reader.start()
        process.wait()
        for reader in readers:
            reader.join()
    return process.returncode


def infer(video_in):
    """Generate a soundtrack for the uploaded video with VTA-LDM.

    Steps:
    1. Empty ./outputs/tmp.
    2. Trim the upload to at most 10 s.
    3. Run inference_from_video.py on the upload's folder.
    4. Plot a spectrogram of the first generated .wav.
    5. Mux that audio onto the (possibly trimmed) video.

    Args:
        video_in: filesystem path of the uploaded video.

    Returns:
        tuple[str, str, str]: (generated .wav path, spectrogram image path,
        merged video path), matching the three Gradio outputs.

    Raises:
        gr.Error: if the inference script produced no .wav file.
    """
    # Assumed to be where inference_from_video.py writes its results; TODO confirm.
    outputs_dir = './outputs/tmp'
    # Start from an empty results folder so the .wav picked below belongs to
    # this request.
    check_outputs_folder(outputs_dir)
    print(f"VIDEO IN PATH: {video_in}")
    # The inference script takes a folder, not a file: hand it the folder
    # that holds this upload.
    folder_path = os.path.dirname(video_in)
    video_in, video_duration = _trim_video_to_max_duration(video_in)

    return_code = _run_vta_ldm_inference(folder_path, video_duration)
    print("Inference script finished with return code:", return_code)

    # get_wav_files also prints the folder tree, so no separate listing is needed.
    wave_files = sorted(get_wav_files(outputs_dir))
    print(wave_files)
    if not wave_files:
        # Report a readable failure instead of an IndexError on wave_files[0].
        raise gr.Error(
            f"Audio generation failed (inference exit code {return_code}); "
            "no .wav file was produced."
        )
    generated_wav = wave_files[0]
    plot_spectrogram(generated_wav, 'spectrogram.png')
    final_merged_out = merge_audio_to_video(video_in, generated_wav)
    return generated_wav, 'spectrogram.png', final_merged_out
# Custom CSS for the Blocks app: centre the column with elem_id="col-container"
# and cap its width at 920px.
css="""
#col-container {
max-width: 920px;
margin: 0 auto;
}
"""
# Build the Gradio interface and wire it to infer():
# - left column: input video plus example clips;
# - right column: generated audio, its spectrogram and the merged video.
with gr.Blocks(css=css) as demo:
    # elem_id ties this column to the #col-container rule in `css`.
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Video-to-Audio Generation with Hidden Alignment")
        # Badge links to the project page and the paper.
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href='https://sites.google.com/view/vta-ldm'>
<img src='https://img.shields.io/badge/Project-Page-Green'>
</a>
<a href='https://huggingface.co/papers/2407.07464'>
<img src='https://img.shields.io/badge/HF-Paper-red'>
</a>
</div>
""")
        with gr.Row():
            # Left column: input video, submit button and example clips.
            with gr.Column():
                # format="mp4" requests an mp4 file for infer().
                # include_audio=False presumably discards the clip's own
                # soundtrack; confirm against the gr.Video docs.
                video_in = gr.Video(label='Video IN', format="mp4", include_audio=False)
                submit_btn = gr.Button("Submit")
                # No fn is given, so selecting an example only loads it into video_in.
                gr.Examples(
                    examples = [
                        ["./examples/lion_gt.mp4"],
                        ["./examples/ice_gt.mp4"],
                        ["./examples/seashore.mp4"],
                        ["./examples/typewriter.mp4"],
                        ["./examples/tennis_gt.mp4"],
                        ["./examples/chew.mp4"],
                    ],
                    inputs = [video_in]
                )
            # Right column: the three values infer() returns, in the same order.
            with gr.Column():
                output_sound = gr.Audio(label="Audio OUT")
                output_spectrogram = gr.Image(label='Spectrogram')
                merged_out = gr.Video(label="Merged video + generated audio")
        # infer(video path) -> (wav path, spectrogram image path, merged video path)
        submit_btn.click(
            fn = infer,
            inputs = [video_in],
            outputs = [output_sound, output_spectrogram, merged_out],
            show_api = False
        )
# show_error=True shows exceptions raised inside infer() in the browser.
demo.launch(show_api=False, show_error=True)