# VidTune / main.py
# Author: animikhaich
# Commit e957316: Added: Dockerfile, Conda Env, Python Reqs, Watchdog Cleaner, CPU-only warning
import streamlit as st
from engine import DescribeVideo, GenerateAudio
import os
from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip
from moviepy.audio.fx.volumex import volumex
from streamlit.runtime.scriptrunner import get_script_run_ctx
def get_session_id():
    """Return a filesystem-friendly identifier for the current Streamlit session.

    The session UUID reported by Streamlit has its hyphens replaced with
    underscores and is prefixed with "_id_", so it can be used as a directory name.
    """
    raw_id = get_script_run_ctx().session_id
    return f"_id_{raw_id.replace('-', '_')}"


# Per-session working directory; the uploaded video and the mixed output
# video are written inside it further down the script.
user_session_id = get_session_id()
os.makedirs(user_session_id, exist_ok=True)
# Define model maps: each maps a sidebar preset label to the model identifier
# handed to the engine.

# Video descriptor presets, passed as DescribeVideo(model=...).
video_model_map = {
    "Fast": "flash",
    "Quality": "pro",
}

# Music generator presets, passed as GenerateAudio(model=...).
music_model_map = {
    "Fast": "musicgen-stereo-small",
    "Balanced": "musicgen-stereo-medium",
    "Quality": "musicgen-stereo-large",
}

# Sidebar genre label -> canonical genre name. The "None" label maps to None.
# Insertion order defines the order of the genre selectbox options.
genre_map = {
    "None": None,
    "Pop": "Pop",
    "Rock": "Rock",
    "Hip Hop": "Hip-Hop/Rap",
    "Jazz": "Jazz",
    "Classical": "Classical",
    "Blues": "Blues",
    "Country": "Country",
    "EDM": "Electronic/Dance",
    "Metal": "Metal",
    "Disco": "Disco",
    "Lo-Fi": "Lo-Fi",
}
# Streamlit page configuration (kept as the first Streamlit command).
st.set_page_config(
    page_icon="assets/favicon.png",
    page_title="VidTune: Where Videos Find Their Melody",
    layout="centered",
)

# Centre the logo by drawing it in the middle of three equal columns.
_, logo_column, _ = st.columns(3)
with logo_column:
    st.image("assets/VidTune-Logo-Without-BG.png", width=200, use_column_width=False)

# Title and description. The <style> rules use unscoped h2/p/div/img
# selectors, so they affect the whole page rather than only this snippet.
_HEADER_HTML = """
<style>
h2, p, div, img {
text-align: center;
}
</style>
<div style="font-size: 35px; font-weight: bold;">VidTune: Where Videos Find Their Melody</div>
<p>VidTune is a web application to effortlessly tailor perfect soundtracks for your videos with AI.</p>
"""
st.markdown(_HEADER_HTML, unsafe_allow_html=True)
# Initialize session state for advanced settings and other inputs.
# Each key is seeded only when absent, so user choices survive reruns.
# Streamlit re-executes this script on every rerun, so the mutable default
# below (the empty list) is a fresh object each time and is never shared.
_SESSION_DEFAULTS = {
    "show_advanced": False,
    "video_model": "Fast",
    "music_model": "Fast",
    "num_samples": 3,
    "music_genre": None,
    "music_bpm": 100,
    "user_keywords": None,
    # Key of the "generated audio" selectbox, whose options are integer
    # indices (range(...)); index 0 is the "None" entry. The seed must be a
    # valid option, hence 0 rather than the display string "None".
    "selected_audio": 0,
    "audio_paths": [],
    "selected_audio_path": None,
    "orig_audio_vol": 100,
    "generated_audio_vol": 100,
    "generate_button_flag": False,
    "video_description_content": "",
    "music_prompt": "",
    "audio_mix_flag": False,
    "google_api_key": "",
    # Read by the export/download section; seeded so it always exists.
    "final_video_path": None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Sidebar
st.sidebar.title("Configuration")

# Google API Key: stored in session state so it survives reruns. Nothing
# below this point executes until a non-empty key has been entered.
api_key_label = "Enter your [Google API Key](https://ai.google.dev/gemini-api/docs/api-key) to get started :"
entered_api_key = st.sidebar.text_input(
    api_key_label,
    st.session_state.google_api_key,
    type="password",
)
st.session_state.google_api_key = entered_api_key
if not entered_api_key:
    st.warning("Please enter your Google API Key to proceed.")
    st.stop()
# Basic Settings: model presets and sample count, each pre-selected from the
# value stored in session state.
_video_choices = ["Fast", "Quality"]
_music_choices = ["Fast", "Balanced", "Quality"]

st.session_state.video_model = st.sidebar.selectbox(
    "Select Video Descriptor",
    _video_choices,
    index=_video_choices.index(st.session_state.video_model),
)
st.session_state.music_model = st.sidebar.selectbox(
    "Select Music Generator",
    _music_choices,
    index=_music_choices.index(st.session_state.music_model),
)
st.session_state.num_samples = st.sidebar.slider(
    "Number of samples",
    min_value=1,
    max_value=5,
    value=st.session_state.num_samples,
)
# Sidebar for advanced settings: the "Advanced" button toggles visibility of
# the BPM, genre and keyword controls.
with st.sidebar:
    advanced_toggle_slot = st.empty()
    if advanced_toggle_slot.button("Advanced"):
        st.session_state.show_advanced = not st.session_state.show_advanced
        st.rerun()  # redraw right away so the sidebar reflects the new state

if not st.session_state.show_advanced:
    # Advanced settings hidden: clear them so they are passed on as None.
    st.session_state.music_genre = None
    st.session_state.music_bpm = None
    st.session_state.user_keywords = None
else:
    genre_labels = list(genre_map)
    current_genre = st.session_state.music_genre
    # Fall back to the first option ("None") when nothing valid is stored.
    genre_index = genre_labels.index(current_genre) if current_genre in genre_map else 0

    st.session_state.music_bpm = st.sidebar.slider("Beats Per Minute", 35, 180, 100)
    st.session_state.music_genre = st.sidebar.selectbox(
        "Select Music Genre",
        genre_labels,
        index=genre_index,
    )
    st.session_state.user_keywords = st.sidebar.text_input(
        "User Keywords",
        value=st.session_state.user_keywords,
        help="Enter keywords separated by commas.",
    )

# Generate Button
generate_button = st.sidebar.button("Generate Music")
# Cache the model loading
@st.cache_resource
def load_models(video_model_key, music_model_key, google_api_key):
    """Construct the video descriptor and music generator for the chosen presets.

    st.cache_resource keys the cache on all three arguments, so every distinct
    (video preset, music preset, API key) combination builds its own pair of
    objects once and reuses it on later reruns.

    Args:
        video_model_key: Key into video_model_map ("Fast" or "Quality").
        music_model_key: Key into music_model_map ("Fast", "Balanced", "Quality").
        google_api_key: API key forwarded to DescribeVideo.

    Returns:
        tuple: (DescribeVideo instance, GenerateAudio instance).
    """
    descriptor = DescribeVideo(
        model=video_model_map[video_model_key],
        google_api_key=google_api_key,
    )
    generator = GenerateAudio(model=music_model_map[music_model_key])

    running_on_cpu = generator.device == "cpu"
    if running_on_cpu:
        st.warning(
            "The music generator model is running on CPU. For faster results, consider using a GPU."
        )
    return descriptor, generator


# Load models (served from the resource cache after the first call).
video_descriptor, audio_generator = load_models(
    st.session_state.video_model,
    st.session_state.music_model,
    st.session_state.google_api_key,
)
# Video Uploader: the upload is copied into the session directory because the
# descriptor and moviepy later read it from that path.
uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
if uploaded_video is not None:
    st.session_state.uploaded_video = uploaded_video
    with open(f"{user_session_id}/temp.mp4", mode="wb") as temp_video_file:
        temp_video_file.write(uploaded_video.getvalue())

    # Video Player
    if os.path.exists(f"{user_session_id}/temp.mp4"):
        st.video(uploaded_video)
# Analyze the uploaded video when "Generate Music" is pressed; refuse to run
# without an upload.
if generate_button:
    if uploaded_video is None:
        st.error("Please upload a video before generating music.")
        st.stop()
    with st.spinner("Analyzing video..."):
        video_description = video_descriptor.describe_video(
            f"{user_session_id}/temp.mp4",
            # Translate the sidebar label through genre_map: "None" (and the
            # None stored when advanced settings are hidden) become None,
            # and labels such as "Hip Hop" become "Hip-Hop/Rap".
            genre=genre_map.get(st.session_state.music_genre),
            bpm=st.session_state.music_bpm,
            user_keywords=st.session_state.user_keywords,
        )
        # Open the clip only to read its duration, and release the file
        # handle / ffmpeg reader right away.
        with VideoFileClip(f"{user_session_id}/temp.mp4") as duration_clip:
            video_duration = duration_clip.duration
        st.session_state.video_description_content = video_description[
            "Content Description"
        ]
        st.session_state.music_prompt = video_description["Music Prompt"]
        st.success("Video description generated successfully.")
        st.session_state.generate_button_flag = True
# Display Video Description and Music Prompt as read-only text areas once a
# description has been generated in this session.
if st.session_state.generate_button_flag:
    readonly_fields = (
        ("Video Description", st.session_state.video_description_content),
        ("Music Prompt", st.session_state.music_prompt),
    )
    for field_label, field_text in readonly_fields:
        st.text_area(field_label, field_text, disabled=True, height=120)
# Generate Music: send one copy of the prompt per requested sample and record
# the paths returned by the generator.
if generate_button:
    with st.spinner("Generating music..."):
        if video_duration > 30:
            st.warning(
                "Due to hardware limitations, the maximum music length is capped at 30 seconds."
            )
        prompt_batch = [
            st.session_state.music_prompt for _ in range(st.session_state.num_samples)
        ]
        audio_generator.generate_audio(prompt_batch, duration=video_duration)
        st.session_state.audio_paths = audio_generator.save_audio()
        st.success("Music generated successfully.")
        st.balloons()
# Callback for the generated-audio selectbox.
def on_audio_selection_change():
    """Sync selected_audio_path with the chosen selectbox index.

    Any pending mix is cancelled. Index 0 is the "None" option, which clears
    the path; index k >= 1 selects audio_paths[k - 1].
    """
    st.session_state.audio_mix_flag = False
    chosen_index = st.session_state.selected_audio
    st.session_state.selected_audio_path = (
        st.session_state.audio_paths[chosen_index - 1] if chosen_index > 0 else None
    )
# Play every generated track, then let the user pick one to mix into the video.
if st.session_state.audio_paths:
    generated_paths = st.session_state.audio_paths

    # Display the audio files
    for wav_path in generated_paths:
        st.audio(wav_path, format="audio/wav")

    # Option 0 is "None"; option k (k >= 1) refers to generated_paths[k - 1].
    audio_options = ["None"]
    audio_options.extend(
        f"Generated Music {number}" for number in range(1, len(generated_paths) + 1)
    )
    st.selectbox(
        "Select one of the generated audio files for further processing:",
        range(len(audio_options)),
        format_func=audio_options.__getitem__,
        index=0,
        key="selected_audio",
        on_change=on_audio_selection_change,
    )

    # Button to confirm the selection
    if st.button("Add Generated Music to Video"):
        st.session_state.audio_mix_flag = True
def _mix_audio_into_video(
    video_path,
    music_path,
    original_gain,
    music_gain,
    output_path,
    temp_audio_path,
):
    """Overlay a music track on a video's soundtrack and write the result.

    Args:
        video_path: Path of the source video.
        music_path: Path of the generated music file to overlay.
        original_gain: Volume multiplier for the video's own audio track.
        music_gain: Volume multiplier for the generated music.
        output_path: Destination of the mixed video.
        temp_audio_path: Where moviepy writes its intermediate audio file.
            Without it, moviepy names that file after the output's basename
            in the current working directory, a location every session
            shares.

    A video without an audio stream gets the scaled music track only, instead
    of failing on the missing soundtrack.
    """
    video_clip = VideoFileClip(video_path)
    # None when the video file has no audio stream.
    source_audio = video_clip.audio
    music_clip = AudioFileClip(music_path)
    try:
        tracks = []
        if source_audio is not None:
            tracks.append(volumex(source_audio, original_gain))
        tracks.append(volumex(music_clip, music_gain))
        video_clip.audio = CompositeAudioClip(tracks)
        video_clip.write_videofile(output_path, temp_audiofile=temp_audio_path)
    finally:
        # video_clip.audio now points at the composite, so the original
        # tracks are closed explicitly as well.
        video_clip.close()
        music_clip.close()
        if source_audio is not None:
            source_audio.close()


# Handle Audio Mixing and Export
if st.session_state.selected_audio_path is not None and st.session_state.audio_mix_flag:
    with st.spinner("Mixing Audio..."):
        st.session_state.orig_audio_vol = st.slider(
            "Original Audio Volume",
            0,
            200,
            st.session_state.orig_audio_vol,
            format="%d%%",
        )
        st.session_state.generated_audio_vol = st.slider(
            "Generated Music Volume",
            0,
            200,
            st.session_state.generated_audio_vol,
            format="%d%%",
        )
        final_video_path = f"{user_session_id}/out_tmp.mp4"
        _mix_audio_into_video(
            f"{user_session_id}/temp.mp4",
            st.session_state.selected_audio_path,
            float(st.session_state.orig_audio_vol / 100),
            float(st.session_state.generated_audio_vol / 100),
            final_video_path,
            # Per-session temp audio file. The .mp3 extension matches moviepy
            # 1.x's default audio codec (libmp3lame) for .mp4 output.
            f"{user_session_id}/out_tmp_audio.mp3",
        )
        st.session_state.final_video_path = final_video_path
        st.video(final_video_path)

    # .get() so a missing key can never raise, whatever the init block seeds.
    if st.session_state.get("final_video_path"):
        with open(st.session_state.final_video_path, "rb") as video_file:
            st.download_button(
                label="Download final video",
                data=video_file,
                file_name="final_video.mp4",
                mime="video/mp4",
            )