Spaces:
Sleeping
Sleeping
animikhaich
commited on
Commit
·
032d7c2
1
Parent(s):
ac90b4d
Added Google API key input + text field persistence
Browse files- engine/video_descriptor.py +2 -2
- main.py +55 -22
engine/video_descriptor.py
CHANGED
@@ -37,9 +37,9 @@ You must return your response using this JSON schema: {json_schema}
|
|
37 |
|
38 |
|
39 |
class DescribeVideo:
|
40 |
-
def __init__(self, model="flash"):
|
41 |
self.model = self.get_model_name(model)
|
42 |
-
__api_key = self.load_api_key()
|
43 |
self.is_safety_set = False
|
44 |
self.safety_settings = self.get_safety_settings()
|
45 |
|
|
|
37 |
|
38 |
|
39 |
class DescribeVideo:
|
40 |
+
def __init__(self, model="flash", google_api_key=None):
|
41 |
self.model = self.get_model_name(model)
|
42 |
+
__api_key = google_api_key # self.load_api_key()
|
43 |
self.is_safety_set = False
|
44 |
self.safety_settings = self.get_safety_settings()
|
45 |
|
main.py
CHANGED
@@ -83,9 +83,30 @@ if "orig_audio_vol" not in st.session_state:
|
|
83 |
st.session_state.orig_audio_vol = 100
|
84 |
if "generated_audio_vol" not in st.session_state:
|
85 |
st.session_state.generated_audio_vol = 100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
# Sidebar
|
88 |
-
st.sidebar.title("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
# Basic Settings
|
91 |
st.session_state.video_model = st.sidebar.selectbox(
|
@@ -141,15 +162,19 @@ generate_button = st.sidebar.button("Generate Music")
|
|
141 |
|
142 |
# Cache the model loading
|
143 |
@st.cache_resource
|
144 |
-
def load_models(video_model_key, music_model_key):
|
145 |
-
video_descriptor = DescribeVideo(
|
|
|
|
|
146 |
audio_generator = GenerateAudio(model=music_model_map[music_model_key])
|
147 |
return video_descriptor, audio_generator
|
148 |
|
149 |
|
150 |
# Load models
|
151 |
video_descriptor, audio_generator = load_models(
|
152 |
-
st.session_state.video_model,
|
|
|
|
|
153 |
)
|
154 |
|
155 |
# Video Uploader
|
@@ -177,31 +202,37 @@ if generate_button:
|
|
177 |
user_keywords=st.session_state.user_keywords,
|
178 |
)
|
179 |
video_duration = VideoFileClip(f"{user_session_id}/temp.mp4").duration
|
180 |
-
|
|
|
|
|
|
|
181 |
|
182 |
st.success("Video description generated successfully.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
-
|
185 |
-
st.text_area(
|
186 |
-
"Video Description",
|
187 |
-
video_description["Content Description"],
|
188 |
-
disabled=True,
|
189 |
-
height=120,
|
190 |
-
)
|
191 |
-
music_prompt = st.text_area(
|
192 |
-
"Music Prompt",
|
193 |
-
music_prompt,
|
194 |
-
disabled=True,
|
195 |
-
height=120,
|
196 |
-
)
|
197 |
-
|
198 |
# Generate Music
|
199 |
with st.spinner("Generating music..."):
|
200 |
if video_duration > 30:
|
201 |
st.warning(
|
202 |
"Due to hardware limitations, the maximum music length is capped at 30 seconds."
|
203 |
)
|
204 |
-
music_prompt = [music_prompt] * st.session_state.num_samples
|
205 |
audio_generator.generate_audio(music_prompt, duration=video_duration)
|
206 |
st.session_state.audio_paths = audio_generator.save_audio()
|
207 |
st.success("Music generated successfully.")
|
@@ -210,6 +241,7 @@ if generate_button:
|
|
210 |
|
211 |
# Callback function for radio button selection change
|
212 |
def on_audio_selection_change():
|
|
|
213 |
selected_audio_index = st.session_state.selected_audio
|
214 |
if selected_audio_index > 0:
|
215 |
st.session_state.selected_audio_path = st.session_state.audio_paths[
|
@@ -235,14 +267,15 @@ if st.session_state.audio_paths:
|
|
235 |
format_func=lambda x: audio_options[x],
|
236 |
index=0,
|
237 |
key="selected_audio",
|
|
|
238 |
)
|
239 |
|
240 |
# Button to confirm the selection
|
241 |
if st.button("Add Generated Music to Video"):
|
242 |
-
|
243 |
|
244 |
# Handle Audio Mixing and Export
|
245 |
-
if st.session_state.selected_audio_path is not None:
|
246 |
with st.spinner("Mixing Audio..."):
|
247 |
orig_clip = VideoFileClip(f"{user_session_id}/temp.mp4")
|
248 |
orig_clip_audio = orig_clip.audio
|
|
|
83 |
st.session_state.orig_audio_vol = 100
|
84 |
if "generated_audio_vol" not in st.session_state:
|
85 |
st.session_state.generated_audio_vol = 100
|
86 |
+
if "generate_button_flag" not in st.session_state:
|
87 |
+
st.session_state.generate_button_flag = False
|
88 |
+
if "video_description_content" not in st.session_state:
|
89 |
+
st.session_state.video_description_content = ""
|
90 |
+
if "music_prompt" not in st.session_state:
|
91 |
+
st.session_state.music_prompt = ""
|
92 |
+
if "audio_mix_flag" not in st.session_state:
|
93 |
+
st.session_state.audio_mix_flag = False
|
94 |
+
if "google_api_key" not in st.session_state:
|
95 |
+
st.session_state.google_api_key = ""
|
96 |
|
97 |
# Sidebar
|
98 |
+
st.sidebar.title("Configuration")
|
99 |
+
|
100 |
+
# Google API Key
|
101 |
+
st.session_state.google_api_key = st.sidebar.text_input(
|
102 |
+
"Enter your Google API Key to get started:",
|
103 |
+
st.session_state.google_api_key,
|
104 |
+
type="password",
|
105 |
+
)
|
106 |
+
|
107 |
+
if not st.session_state.google_api_key:
|
108 |
+
st.warning("Please enter your Google API Key to proceed.")
|
109 |
+
st.stop()
|
110 |
|
111 |
# Basic Settings
|
112 |
st.session_state.video_model = st.sidebar.selectbox(
|
|
|
162 |
|
163 |
# Cache the model loading
|
164 |
@st.cache_resource
|
165 |
+
def load_models(video_model_key, music_model_key, google_api_key):
|
166 |
+
video_descriptor = DescribeVideo(
|
167 |
+
model=video_model_map[video_model_key], google_api_key=google_api_key
|
168 |
+
)
|
169 |
audio_generator = GenerateAudio(model=music_model_map[music_model_key])
|
170 |
return video_descriptor, audio_generator
|
171 |
|
172 |
|
173 |
# Load models
|
174 |
video_descriptor, audio_generator = load_models(
|
175 |
+
st.session_state.video_model,
|
176 |
+
st.session_state.music_model,
|
177 |
+
st.session_state.google_api_key,
|
178 |
)
|
179 |
|
180 |
# Video Uploader
|
|
|
202 |
user_keywords=st.session_state.user_keywords,
|
203 |
)
|
204 |
video_duration = VideoFileClip(f"{user_session_id}/temp.mp4").duration
|
205 |
+
st.session_state.video_description_content = video_description[
|
206 |
+
"Content Description"
|
207 |
+
]
|
208 |
+
st.session_state.music_prompt = video_description["Music Prompt"]
|
209 |
|
210 |
st.success("Video description generated successfully.")
|
211 |
+
st.session_state.generate_button_flag = True
|
212 |
+
|
213 |
+
# Display Video Description and Music Prompt
|
214 |
+
if st.session_state.generate_button_flag:
|
215 |
+
st.text_area(
|
216 |
+
"Video Description",
|
217 |
+
st.session_state.video_description_content,
|
218 |
+
disabled=True,
|
219 |
+
height=120,
|
220 |
+
)
|
221 |
+
music_prompt = st.text_area(
|
222 |
+
"Music Prompt",
|
223 |
+
st.session_state.music_prompt,
|
224 |
+
disabled=True,
|
225 |
+
height=120,
|
226 |
+
)
|
227 |
|
228 |
+
if generate_button:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
# Generate Music
|
230 |
with st.spinner("Generating music..."):
|
231 |
if video_duration > 30:
|
232 |
st.warning(
|
233 |
"Due to hardware limitations, the maximum music length is capped at 30 seconds."
|
234 |
)
|
235 |
+
music_prompt = [st.session_state.music_prompt] * st.session_state.num_samples
|
236 |
audio_generator.generate_audio(music_prompt, duration=video_duration)
|
237 |
st.session_state.audio_paths = audio_generator.save_audio()
|
238 |
st.success("Music generated successfully.")
|
|
|
241 |
|
242 |
# Callback function for radio button selection change
|
243 |
def on_audio_selection_change():
|
244 |
+
st.session_state.audio_mix_flag = False
|
245 |
selected_audio_index = st.session_state.selected_audio
|
246 |
if selected_audio_index > 0:
|
247 |
st.session_state.selected_audio_path = st.session_state.audio_paths[
|
|
|
267 |
format_func=lambda x: audio_options[x],
|
268 |
index=0,
|
269 |
key="selected_audio",
|
270 |
+
on_change=on_audio_selection_change,
|
271 |
)
|
272 |
|
273 |
# Button to confirm the selection
|
274 |
if st.button("Add Generated Music to Video"):
|
275 |
+
st.session_state.audio_mix_flag = True
|
276 |
|
277 |
# Handle Audio Mixing and Export
|
278 |
+
if st.session_state.selected_audio_path is not None and st.session_state.audio_mix_flag:
|
279 |
with st.spinner("Mixing Audio..."):
|
280 |
orig_clip = VideoFileClip(f"{user_session_id}/temp.mp4")
|
281 |
orig_clip_audio = orig_clip.audio
|