Sandesh Bharadwaj committed on
Commit
b81e951
2 Parent(s): 47c18e5 e957316

Merge pull request #3 from animikhaich/web-app-dev

Browse files
.streamlit/config.toml CHANGED
@@ -1,2 +1,5 @@
1
  [browser]
2
- gatherUsageStats = false
 
 
 
 
1
  [browser]
2
+ gatherUsageStats = false
3
+
4
+ [theme]
5
+ base = "light"
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use the official Python 3.9.19 image as the base image
2
+ FROM python:3.9.19
3
+
4
+ # Set the working directory in the container
5
+ WORKDIR /src
6
+
7
+ # Copy Requirements file
8
+ COPY requirements.txt /src
9
+
10
+ # Install the required packages
11
+ RUN pip install -r requirements.txt
12
+
13
+ # Expose port 8003 for Streamlit
14
+ EXPOSE 8003
15
+
16
+ # Copy the current directory contents into the container at /src
17
+ COPY . /src
18
+
19
+ # Run id_cleaner.py as a background process and then start Streamlit
20
+ CMD ["sh", "-c", "python id_cleaner.py & streamlit run main.py --server.port 8003"]
assets/VidTune-Logo-With-BG.png ADDED
assets/VidTune-Logo-Without-BG.png ADDED
assets/favicon.png ADDED
assets/homepage.png ADDED
engine/video_descriptor.py CHANGED
@@ -37,9 +37,9 @@ You must return your response using this JSON schema: {json_schema}
37
 
38
 
39
  class DescribeVideo:
40
- def __init__(self, model="flash"):
41
  self.model = self.get_model_name(model)
42
- __api_key = self.load_api_key()
43
  self.is_safety_set = False
44
  self.safety_settings = self.get_safety_settings()
45
 
 
37
 
38
 
39
  class DescribeVideo:
40
+ def __init__(self, model="flash", google_api_key=None):
41
  self.model = self.get_model_name(model)
42
+ __api_key = google_api_key # self.load_api_key()
43
  self.is_safety_set = False
44
  self.safety_settings = self.get_safety_settings()
45
 
environment.yml ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: vidtune
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - _libgcc_mutex=0.1
6
+ - _openmp_mutex=5.1
7
+ - ca-certificates=2024.3.11
8
+ - ld_impl_linux-64=2.38
9
+ - libffi=3.4.4
10
+ - libgcc-ng=11.2.0
11
+ - libgomp=11.2.0
12
+ - libstdcxx-ng=11.2.0
13
+ - ncurses=6.4
14
+ - openssl=3.0.14
15
+ - pip=24.0
16
+ - python=3.9.19
17
+ - readline=8.2
18
+ - setuptools=69.5.1
19
+ - sqlite=3.45.3
20
+ - tk=8.6.14
21
+ - wheel=0.43.0
22
+ - xz=5.4.6
23
+ - zlib=1.2.13
24
+ - pip:
25
+ - aiofiles==23.2.1
26
+ - altair==5.3.0
27
+ - annotated-types==0.7.0
28
+ - antlr4-python3-runtime==4.9.3
29
+ - anyio==4.4.0
30
+ - asttokens==2.4.1
31
+ - attrs==23.2.0
32
+ - audiocraft==1.3.0
33
+ - audioread==3.0.1
34
+ - av==11.0.0
35
+ - backcall==0.2.0
36
+ - beautifulsoup4==4.12.3
37
+ - bleach==6.1.0
38
+ - blinker==1.8.2
39
+ - blis==0.7.11
40
+ - cachetools==5.3.3
41
+ - catalogue==2.0.10
42
+ - certifi==2024.7.4
43
+ - cffi==1.16.0
44
+ - charset-normalizer==3.3.2
45
+ - click==8.1.7
46
+ - cloudpathlib==0.18.1
47
+ - cloudpickle==3.0.0
48
+ - colorlog==6.8.2
49
+ - confection==0.1.5
50
+ - contourpy==1.2.1
51
+ - cycler==0.12.1
52
+ - cymem==2.0.8
53
+ - decorator==4.4.2
54
+ - defusedxml==0.7.1
55
+ - demucs==4.0.1
56
+ - dnspython==2.6.1
57
+ - docopt==0.6.2
58
+ - dora-search==0.1.12
59
+ - einops==0.8.0
60
+ - email-validator==2.2.0
61
+ - encodec==0.1.1
62
+ - exceptiongroup==1.2.2
63
+ - executing==2.0.1
64
+ - fastapi==0.111.0
65
+ - fastapi-cli==0.0.4
66
+ - fastjsonschema==2.20.0
67
+ - ffmpy==0.3.2
68
+ - filelock==3.15.4
69
+ - flashy==0.0.2
70
+ - fonttools==4.53.1
71
+ - fsspec==2024.6.1
72
+ - gitdb==4.0.11
73
+ - gitpython==3.1.43
74
+ - google-ai-generativelanguage==0.6.6
75
+ - google-api-core==2.19.1
76
+ - google-api-python-client==2.137.0
77
+ - google-auth==2.32.0
78
+ - google-auth-httplib2==0.2.0
79
+ - google-generativeai==0.7.2
80
+ - googleapis-common-protos==1.63.2
81
+ - gradio==4.38.1
82
+ - gradio-client==1.1.0
83
+ - grpcio==1.64.1
84
+ - grpcio-status==1.62.2
85
+ - h11==0.14.0
86
+ - httpcore==1.0.5
87
+ - httplib2==0.22.0
88
+ - httptools==0.6.1
89
+ - httpx==0.27.0
90
+ - huggingface-hub==0.23.4
91
+ - hydra-colorlog==1.2.0
92
+ - hydra-core==1.3.2
93
+ - idna==3.7
94
+ - imageio==2.34.2
95
+ - imageio-ffmpeg==0.5.1
96
+ - importlib-metadata==8.2.0
97
+ - importlib-resources==6.4.0
98
+ - ipython==8.12.3
99
+ - jedi==0.19.1
100
+ - jinja2==3.1.4
101
+ - joblib==1.4.2
102
+ - jsonschema==4.23.0
103
+ - jsonschema-specifications==2023.12.1
104
+ - julius==0.2.7
105
+ - jupyter-client==8.6.2
106
+ - jupyter-core==5.7.2
107
+ - jupyterlab-pygments==0.3.0
108
+ - kiwisolver==1.4.5
109
+ - lameenc==1.7.0
110
+ - langcodes==3.4.0
111
+ - language-data==1.2.0
112
+ - lazy-loader==0.4
113
+ - librosa==0.10.2.post1
114
+ - lightning-utilities==0.11.5
115
+ - llvmlite==0.43.0
116
+ - marisa-trie==1.2.0
117
+ - markdown-it-py==3.0.0
118
+ - markupsafe==2.1.5
119
+ - matplotlib==3.9.1
120
+ - matplotlib-inline==0.1.7
121
+ - mdurl==0.1.2
122
+ - mistune==3.0.2
123
+ - moviepy==1.0.3
124
+ - mpmath==1.3.0
125
+ - msgpack==1.0.8
126
+ - murmurhash==1.0.10
127
+ - nbclient==0.10.0
128
+ - nbconvert==7.16.4
129
+ - nbformat==5.10.4
130
+ - networkx==3.2.1
131
+ - num2words==0.5.13
132
+ - numba==0.60.0
133
+ - numpy==1.26.4
134
+ - nvidia-cublas-cu12==12.1.3.1
135
+ - nvidia-cuda-cupti-cu12==12.1.105
136
+ - nvidia-cuda-nvrtc-cu12==12.1.105
137
+ - nvidia-cuda-runtime-cu12==12.1.105
138
+ - nvidia-cudnn-cu12==8.9.2.26
139
+ - nvidia-cufft-cu12==11.0.2.54
140
+ - nvidia-curand-cu12==10.3.2.106
141
+ - nvidia-cusolver-cu12==11.4.5.107
142
+ - nvidia-cusparse-cu12==12.1.0.106
143
+ - nvidia-nccl-cu12==2.18.1
144
+ - nvidia-nvjitlink-cu12==12.5.82
145
+ - nvidia-nvtx-cu12==12.1.105
146
+ - omegaconf==2.3.0
147
+ - openunmix==1.3.0
148
+ - orjson==3.10.6
149
+ - packaging==24.1
150
+ - pandas==2.2.2
151
+ - pandocfilters==1.5.1
152
+ - parso==0.8.4
153
+ - pexpect==4.9.0
154
+ - pickleshare==0.7.5
155
+ - pillow==10.4.0
156
+ - pipreqs==0.5.0
157
+ - platformdirs==4.2.2
158
+ - pooch==1.8.2
159
+ - preshed==3.0.9
160
+ - proglog==0.1.10
161
+ - prompt-toolkit==3.0.47
162
+ - proto-plus==1.24.0
163
+ - protobuf==4.25.3
164
+ - psutil==6.0.0
165
+ - ptyprocess==0.7.0
166
+ - pure-eval==0.2.3
167
+ - pyarrow==16.1.0
168
+ - pyasn1==0.6.0
169
+ - pyasn1-modules==0.4.0
170
+ - pycparser==2.22
171
+ - pydantic==2.7.3
172
+ - pydantic-core==2.18.4
173
+ - pydeck==0.9.1
174
+ - pydub==0.25.1
175
+ - pygments==2.18.0
176
+ - pyparsing==3.1.2
177
+ - python-dateutil==2.9.0.post0
178
+ - python-dotenv==1.0.1
179
+ - python-multipart==0.0.9
180
+ - pytz==2024.1
181
+ - pyyaml==6.0.1
182
+ - pyzmq==26.1.0
183
+ - referencing==0.35.1
184
+ - regex==2024.5.15
185
+ - requests==2.32.3
186
+ - retrying==1.3.4
187
+ - rich==13.7.1
188
+ - rpds-py==0.19.0
189
+ - rsa==4.9
190
+ - ruff==0.5.2
191
+ - safetensors==0.4.3
192
+ - scikit-learn==1.5.1
193
+ - scipy==1.13.1
194
+ - semantic-version==2.10.0
195
+ - sentencepiece==0.2.0
196
+ - shellingham==1.5.4
197
+ - six==1.16.0
198
+ - smart-open==7.0.4
199
+ - smmap==5.0.1
200
+ - sniffio==1.3.1
201
+ - soundfile==0.12.1
202
+ - soupsieve==2.5
203
+ - soxr==0.3.7
204
+ - spacy==3.7.5
205
+ - spacy-legacy==3.0.12
206
+ - spacy-loggers==1.0.5
207
+ - srsly==2.4.8
208
+ - stack-data==0.6.3
209
+ - starlette==0.37.2
210
+ - streamlit==1.36.0
211
+ - submitit==1.5.1
212
+ - sympy==1.13.0
213
+ - tenacity==8.5.0
214
+ - thinc==8.2.5
215
+ - threadpoolctl==3.5.0
216
+ - tinycss2==1.3.0
217
+ - tokenizers==0.19.1
218
+ - toml==0.10.2
219
+ - tomlkit==0.12.0
220
+ - toolz==0.12.1
221
+ - --extra-index-url https://download.pytorch.org/whl/cu121
222
+ - torch==2.1.0
223
+ - torchaudio==2.1.0
224
+ - torchdata==0.7.0
225
+ - torchmetrics==1.4.0.post0
226
+ - torchtext==0.16.0
227
+ - torchvision==0.16.0
228
+ - tornado==6.4.1
229
+ - tqdm==4.66.4
230
+ - traitlets==5.14.3
231
+ - transformers==4.42.4
232
+ - treetable==0.2.5
233
+ - triton==2.1.0
234
+ - typer==0.12.3
235
+ - typing-extensions==4.12.2
236
+ - tzdata==2024.1
237
+ - ujson==5.10.0
238
+ - uritemplate==4.1.1
239
+ - urllib3==2.2.2
240
+ - uvicorn==0.30.1
241
+ - uvloop==0.19.0
242
+ - wasabi==1.1.3
243
+ - watchdog==4.0.1
244
+ - watchfiles==0.22.0
245
+ - wcwidth==0.2.13
246
+ - weasel==0.4.1
247
+ - webencodings==0.5.1
248
+ - websockets==11.0.3
249
+ - wrapt==1.16.0
250
+ - xformers==0.0.22.post7
251
+ - yarg==0.1.9
252
+ - zipp==3.19.2
id_cleaner.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import shutil
4
+ import logging
5
+ from datetime import datetime, timedelta
6
+ from watchdog.observers import Observer
7
+ from watchdog.events import FileSystemEventHandler
8
+
9
+ # Configure logging
10
+ FORMAT = "%(asctime)s: %(levelname)s: %(message)s"
11
+ logging.basicConfig(filename="logs.log", level=logging.INFO, format=FORMAT)
12
+ STDERRLOGGER = logging.StreamHandler()
13
+ STDERRLOGGER.setFormatter(logging.Formatter(FORMAT))
14
+ logging.getLogger().addHandler(STDERRLOGGER)
15
+
16
class DirectoryCleanupHandler(FileSystemEventHandler):
    """Delete stale ``_id_*`` directories found under the current directory.

    A directory is considered stale when its modification time is older than
    ``threshold_minutes``. Cleanup is triggered in two ways: by filesystem
    "modified" events on ``_id_*`` directories, and by a periodic polling loop
    (``start_cleanup_loop``).

    Args:
        threshold_minutes: Age, in minutes, after which an ``_id_*`` directory
            is deleted.
        check_interval_minutes: Minutes between two periodic cleanup passes.
    """

    def __init__(self, threshold_minutes=60, check_interval_minutes=5):
        super().__init__()
        self.threshold = timedelta(minutes=threshold_minutes)
        # Stored in seconds because it is compared against total_seconds()
        # and passed to time.sleep().
        self.check_interval = check_interval_minutes * 60
        self.last_check_time = datetime.now()

    def on_modified(self, event):
        """Run a cleanup pass when an ``_id_*`` directory is modified.

        Args:
            event: The filesystem event delivered by the observer.
        """
        if not event.is_directory:
            return
        # src_path is a path (it may carry the watched-root prefix and parent
        # directories), so match on the final component, exactly like
        # cleanup_directories() does. Testing the raw path for the "_id_"
        # prefix would miss a path such as "./_id_abc".
        dir_name = os.path.basename(os.path.normpath(event.src_path))
        if dir_name.startswith("_id_"):
            logging.info(f"Detected modification in directory: {event.src_path}")
            self.cleanup_directories()

    def cleanup_directories(self):
        """Walk the current directory and remove every stale ``_id_*`` directory.

        Errors while inspecting or deleting a directory are logged and do not
        stop the pass.
        """
        now = datetime.now()
        for dirpath, dirnames, _ in os.walk("."):
            if not os.path.basename(dirpath).startswith("_id_"):
                continue
            try:
                mtime_dt = datetime.fromtimestamp(os.path.getmtime(dirpath))
                if now - mtime_dt > self.threshold:
                    logging.info(f"Deleting directory: {dirpath}")
                    shutil.rmtree(dirpath)
                    # The subtree is gone; stop os.walk from trying to
                    # descend into the children it already listed.
                    dirnames[:] = []
            except Exception as e:
                # Best effort: one bad directory must not abort the pass.
                logging.error(f"Error deleting {dirpath}: {e}")

    def start_cleanup_loop(self):
        """Run periodic cleanup passes forever (blocks the calling thread).

        The first pass happens only after one full ``check_interval`` has
        elapsed, because ``last_check_time`` is initialised in ``__init__``.
        """
        while True:
            current_time = datetime.now()
            elapsed = (current_time - self.last_check_time).total_seconds()
            if elapsed >= self.check_interval:
                logging.info("Woke up to check directories")
                self.cleanup_directories()
                self.last_check_time = current_time
            time.sleep(self.check_interval)
50
+
51
+
52
if __name__ == "__main__":
    # Script entry point: watch the current directory recursively and run the
    # periodic cleanup loop in the foreground until interrupted.
    logging.info("Starting directory cleanup script")
    event_handler = DirectoryCleanupHandler(
        threshold_minutes=60, check_interval_minutes=30
    )
    observer = Observer()
    observer.schedule(event_handler, path=".", recursive=True)
    observer.start()

    try:
        event_handler.start_cleanup_loop()
    except KeyboardInterrupt:
        logging.info("Stopping directory cleanup script due to keyboard interrupt")
    finally:
        # Always shut the observer down, whatever ended the loop, so the
        # watcher is stopped before the process exits.
        observer.stop()
        observer.join()
main.py CHANGED
@@ -5,13 +5,16 @@ from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip
5
  from moviepy.audio.fx.volumex import volumex
6
  from streamlit.runtime.scriptrunner import get_script_run_ctx
7
 
 
8
  def get_session_id():
9
  session_id = get_script_run_ctx().session_id
10
- session_id = session_id.replace('-','_')
11
- session_id = '_id_' + session_id
12
  return session_id
13
 
14
- print(get_session_id())
 
 
15
  # Define model maps
16
  video_model_map = {
17
  "Fast": "flash",
@@ -46,13 +49,27 @@ genre_map = {
46
 
47
  # Streamlit page configuration
48
  st.set_page_config(
49
- page_title="VidTune: Where Videos Find Their Melody", layout="centered"
 
 
50
  )
51
 
 
 
 
 
52
  # Title and Description
53
- st.title("VidTune: Where Videos Find Their Melody")
54
- st.write(
55
- "VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video."
 
 
 
 
 
 
 
 
56
  )
57
 
58
  # Initialize session state for advanced settings and other inputs
@@ -80,9 +97,30 @@ if "orig_audio_vol" not in st.session_state:
80
  st.session_state.orig_audio_vol = 100
81
  if "generated_audio_vol" not in st.session_state:
82
  st.session_state.generated_audio_vol = 100
83
-
 
 
 
 
 
 
 
 
 
 
84
  # Sidebar
85
- st.sidebar.title("Settings")
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  # Basic Settings
88
  st.session_state.video_model = st.sidebar.selectbox(
@@ -138,26 +176,34 @@ generate_button = st.sidebar.button("Generate Music")
138
 
139
  # Cache the model loading
140
  @st.cache_resource
141
- def load_models(video_model_key, music_model_key):
142
- video_descriptor = DescribeVideo(model=video_model_map[video_model_key])
 
 
143
  audio_generator = GenerateAudio(model=music_model_map[music_model_key])
 
 
 
 
144
  return video_descriptor, audio_generator
145
 
146
 
147
  # Load models
148
  video_descriptor, audio_generator = load_models(
149
- st.session_state.video_model, st.session_state.music_model
 
 
150
  )
151
 
152
  # Video Uploader
153
  uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
154
  if uploaded_video is not None:
155
  st.session_state.uploaded_video = uploaded_video
156
- with open("temp.mp4", mode="wb") as w:
157
  w.write(uploaded_video.getvalue())
158
 
159
  # Video Player
160
- if os.path.exists("temp.mp4") and uploaded_video is not None:
161
  st.video(uploaded_video)
162
 
163
  # Submit button if video is not uploaded
@@ -168,101 +214,130 @@ if generate_button:
168
 
169
  with st.spinner("Analyzing video..."):
170
  video_description = video_descriptor.describe_video(
171
- "temp.mp4",
172
  genre=st.session_state.music_genre,
173
  bpm=st.session_state.music_bpm,
174
  user_keywords=st.session_state.user_keywords,
175
  )
176
- video_duration = VideoFileClip("temp.mp4").duration
177
- music_prompt = video_description["Music Prompt"]
 
 
 
178
 
179
  st.success("Video description generated successfully.")
 
180
 
181
- # Display Video Description and Music Prompt
182
- st.text_area(
183
- "Video Description",
184
- video_description["Content Description"],
185
- disabled=True,
186
- height=120,
187
- )
188
- music_prompt = st.text_area(
189
- "Music Prompt",
190
- music_prompt,
191
- disabled=False,
192
- height=120,
193
- )
 
194
 
 
195
  # Generate Music
196
  with st.spinner("Generating music..."):
197
  if video_duration > 30:
198
  st.warning(
199
  "Due to hardware limitations, the maximum music length is capped at 30 seconds."
200
  )
201
- music_prompt = [music_prompt] * st.session_state.num_samples
202
  audio_generator.generate_audio(music_prompt, duration=video_duration)
203
  st.session_state.audio_paths = audio_generator.save_audio()
204
  st.success("Music generated successfully.")
205
  st.balloons()
206
 
 
207
  # Callback function for radio button selection change
208
  def on_audio_selection_change():
209
- selected_index = audio_options.index(st.session_state.selected_audio) - 1
210
- if selected_index >= 0:
211
- st.session_state.selected_audio_path = st.session_state.audio_paths[selected_index]
 
 
 
212
  else:
213
  st.session_state.selected_audio_path = None
214
 
215
- # Display radio buttons and handle audio selections
216
  if st.session_state.audio_paths:
 
 
 
 
 
 
217
  for i, audio_path in enumerate(st.session_state.audio_paths):
218
  st.audio(audio_path, format="audio/wav")
219
-
220
- audio_options = ["None"]+[f"Sample {i+1}" for i in range(len(st.session_state.audio_paths))]
221
- st.radio(
222
  "Select one of the generated audio files for further processing:",
223
- audio_options,
 
224
  index=0,
225
  key="selected_audio",
226
- on_change=on_audio_selection_change
227
  )
228
-
229
- if st.session_state.selected_audio_path:
230
- st.write(f"**Selected Audio:** {st.session_state.selected_audio_path}")
 
231
 
232
  # Handle Audio Mixing and Export
233
- if st.session_state.selected_audio_path is not None:
234
- orig_clip = VideoFileClip("temp.mp4")
235
- orig_clip_audio = orig_clip.audio
236
- generated_audio = AudioFileClip(st.session_state.selected_audio_path)
237
-
238
- st.session_state.orig_audio_vol = st.slider(
239
- "Original Audio Volume", 0, 200, st.session_state.orig_audio_vol
240
- )
241
-
242
- st.session_state.generated_audio_vol = st.slider(
243
- "Selected Sample Volume", 0, 200, st.session_state.generated_audio_vol
244
- )
245
-
246
- orig_clip_audio = volumex(orig_clip_audio, float(st.session_state.orig_audio_vol/100))
247
- generated_audio = volumex(generated_audio, float(st.session_state.generated_audio_vol/100))
248
-
249
- orig_clip.audio = CompositeAudioClip([orig_clip_audio, generated_audio])
250
-
251
- final_video_path="out_tmp.mp4"
252
- orig_clip.write_videofile(final_video_path)
253
-
254
- orig_clip.close()
255
- generated_audio.close()
256
-
257
- st.session_state.final_video_path = final_video_path
258
-
259
- st.video(final_video_path)
260
-
261
- if st.session_state.final_video_path:
262
- with open(st.session_state.final_video_path, "rb") as video_file:
263
- st.download_button(
264
- label="Download final video",
265
- data=video_file,
266
- file_name="final_video.mp4",
267
- mime="video/mp4",
268
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from moviepy.audio.fx.volumex import volumex
6
  from streamlit.runtime.scriptrunner import get_script_run_ctx
7
 
8
+
9
  def get_session_id():
10
  session_id = get_script_run_ctx().session_id
11
+ session_id = session_id.replace("-", "_")
12
+ session_id = "_id_" + session_id
13
  return session_id
14
 
15
+
16
+ user_session_id = get_session_id()
17
+ os.makedirs(user_session_id, exist_ok=True)
18
  # Define model maps
19
  video_model_map = {
20
  "Fast": "flash",
 
49
 
50
  # Streamlit page configuration
51
  st.set_page_config(
52
+ page_title="VidTune: Where Videos Find Their Melody",
53
+ layout="centered",
54
+ page_icon="assets/favicon.png",
55
  )
56
 
57
+ left_co, cent_co, last_co = st.columns(3)
58
+ with cent_co:
59
+ st.image("assets/VidTune-Logo-Without-BG.png", use_column_width=False, width=200)
60
+
61
  # Title and Description
62
+ st.markdown(
63
+ """
64
+ <style>
65
+ h2, p, div, img {
66
+ text-align: center;
67
+ }
68
+ </style>
69
+ <div style="font-size: 35px; font-weight: bold;">VidTune: Where Videos Find Their Melody</div>
70
+ <p>VidTune is a web application to effortlessly tailor perfect soundtracks for your videos with AI.</p>
71
+ """,
72
+ unsafe_allow_html=True,
73
  )
74
 
75
  # Initialize session state for advanced settings and other inputs
 
97
  st.session_state.orig_audio_vol = 100
98
  if "generated_audio_vol" not in st.session_state:
99
  st.session_state.generated_audio_vol = 100
100
+ if "generate_button_flag" not in st.session_state:
101
+ st.session_state.generate_button_flag = False
102
+ if "video_description_content" not in st.session_state:
103
+ st.session_state.video_description_content = ""
104
+ if "music_prompt" not in st.session_state:
105
+ st.session_state.music_prompt = ""
106
+ if "audio_mix_flag" not in st.session_state:
107
+ st.session_state.audio_mix_flag = False
108
+ if "google_api_key" not in st.session_state:
109
+ st.session_state.google_api_key = ""
110
+
111
  # Sidebar
112
+ st.sidebar.title("Configuration")
113
+
114
+ # Google API Key
115
+ st.session_state.google_api_key = st.sidebar.text_input(
116
+ "Enter your [Google API Key](https://ai.google.dev/gemini-api/docs/api-key) to get started :",
117
+ st.session_state.google_api_key,
118
+ type="password",
119
+ )
120
+
121
+ if not st.session_state.google_api_key:
122
+ st.warning("Please enter your Google API Key to proceed.")
123
+ st.stop()
124
 
125
  # Basic Settings
126
  st.session_state.video_model = st.sidebar.selectbox(
 
176
 
177
  # Cache the model loading
178
  @st.cache_resource
179
+ def load_models(video_model_key, music_model_key, google_api_key):
180
+ video_descriptor = DescribeVideo(
181
+ model=video_model_map[video_model_key], google_api_key=google_api_key
182
+ )
183
  audio_generator = GenerateAudio(model=music_model_map[music_model_key])
184
+ if audio_generator.device == "cpu":
185
+ st.warning(
186
+ "The music generator model is running on CPU. For faster results, consider using a GPU."
187
+ )
188
  return video_descriptor, audio_generator
189
 
190
 
191
  # Load models
192
  video_descriptor, audio_generator = load_models(
193
+ st.session_state.video_model,
194
+ st.session_state.music_model,
195
+ st.session_state.google_api_key,
196
  )
197
 
198
  # Video Uploader
199
  uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
200
  if uploaded_video is not None:
201
  st.session_state.uploaded_video = uploaded_video
202
+ with open(f"{user_session_id}/temp.mp4", mode="wb") as w:
203
  w.write(uploaded_video.getvalue())
204
 
205
  # Video Player
206
+ if os.path.exists(f"{user_session_id}/temp.mp4") and uploaded_video is not None:
207
  st.video(uploaded_video)
208
 
209
  # Submit button if video is not uploaded
 
214
 
215
  with st.spinner("Analyzing video..."):
216
  video_description = video_descriptor.describe_video(
217
+ f"{user_session_id}/temp.mp4",
218
  genre=st.session_state.music_genre,
219
  bpm=st.session_state.music_bpm,
220
  user_keywords=st.session_state.user_keywords,
221
  )
222
+ video_duration = VideoFileClip(f"{user_session_id}/temp.mp4").duration
223
+ st.session_state.video_description_content = video_description[
224
+ "Content Description"
225
+ ]
226
+ st.session_state.music_prompt = video_description["Music Prompt"]
227
 
228
  st.success("Video description generated successfully.")
229
+ st.session_state.generate_button_flag = True
230
 
231
+ # Display Video Description and Music Prompt
232
+ if st.session_state.generate_button_flag:
233
+ st.text_area(
234
+ "Video Description",
235
+ st.session_state.video_description_content,
236
+ disabled=True,
237
+ height=120,
238
+ )
239
+ music_prompt = st.text_area(
240
+ "Music Prompt",
241
+ st.session_state.music_prompt,
242
+ disabled=True,
243
+ height=120,
244
+ )
245
 
246
+ if generate_button:
247
  # Generate Music
248
  with st.spinner("Generating music..."):
249
  if video_duration > 30:
250
  st.warning(
251
  "Due to hardware limitations, the maximum music length is capped at 30 seconds."
252
  )
253
+ music_prompt = [st.session_state.music_prompt] * st.session_state.num_samples
254
  audio_generator.generate_audio(music_prompt, duration=video_duration)
255
  st.session_state.audio_paths = audio_generator.save_audio()
256
  st.success("Music generated successfully.")
257
  st.balloons()
258
 
259
+
260
  # Callback invoked when the audio selectbox selection changes
261
  def on_audio_selection_change():
262
+ st.session_state.audio_mix_flag = False
263
+ selected_audio_index = st.session_state.selected_audio
264
+ if selected_audio_index > 0:
265
+ st.session_state.selected_audio_path = st.session_state.audio_paths[
266
+ selected_audio_index - 1
267
+ ]
268
  else:
269
  st.session_state.selected_audio_path = None
270
 
271
+
272
  if st.session_state.audio_paths:
273
+ # Dropdown to select one of the generated audio files
274
+ audio_options = ["None"] + [
275
+ f"Generated Music {i+1}" for i in range(len(st.session_state.audio_paths))
276
+ ]
277
+
278
+ # Display the audio files
279
  for i, audio_path in enumerate(st.session_state.audio_paths):
280
  st.audio(audio_path, format="audio/wav")
281
+
282
+ selected_audio_index = st.selectbox(
 
283
  "Select one of the generated audio files for further processing:",
284
+ range(len(audio_options)),
285
+ format_func=lambda x: audio_options[x],
286
  index=0,
287
  key="selected_audio",
288
+ on_change=on_audio_selection_change,
289
  )
290
+
291
+ # Button to confirm the selection
292
+ if st.button("Add Generated Music to Video"):
293
+ st.session_state.audio_mix_flag = True
294
 
295
  # Handle Audio Mixing and Export
296
+ if st.session_state.selected_audio_path is not None and st.session_state.audio_mix_flag:
297
+ with st.spinner("Mixing Audio..."):
298
+ orig_clip = VideoFileClip(f"{user_session_id}/temp.mp4")
299
+ orig_clip_audio = orig_clip.audio
300
+ generated_audio = AudioFileClip(st.session_state.selected_audio_path)
301
+
302
+ st.session_state.orig_audio_vol = st.slider(
303
+ "Original Audio Volume",
304
+ 0,
305
+ 200,
306
+ st.session_state.orig_audio_vol,
307
+ format="%d%%",
308
+ )
309
+
310
+ st.session_state.generated_audio_vol = st.slider(
311
+ "Generated Music Volume",
312
+ 0,
313
+ 200,
314
+ st.session_state.generated_audio_vol,
315
+ format="%d%%",
316
+ )
317
+
318
+ orig_clip_audio = volumex(
319
+ orig_clip_audio, float(st.session_state.orig_audio_vol / 100)
320
+ )
321
+ generated_audio = volumex(
322
+ generated_audio, float(st.session_state.generated_audio_vol / 100)
323
+ )
324
+
325
+ orig_clip.audio = CompositeAudioClip([orig_clip_audio, generated_audio])
326
+
327
+ final_video_path = f"{user_session_id}/out_tmp.mp4"
328
+ orig_clip.write_videofile(final_video_path)
329
+
330
+ orig_clip.close()
331
+ generated_audio.close()
332
+
333
+ st.session_state.final_video_path = final_video_path
334
+
335
+ st.video(final_video_path)
336
+ if st.session_state.final_video_path:
337
+ with open(st.session_state.final_video_path, "rb") as video_file:
338
+ st.download_button(
339
+ label="Download final video",
340
+ data=video_file,
341
+ file_name="final_video.mp4",
342
+ mime="video/mp4",
343
+ )
requirements.txt CHANGED
@@ -1,9 +1,228 @@
 
 
 
 
 
 
 
1
  audiocraft==1.3.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  fastapi==0.111.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  numpy==1.26.4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  pydantic==2.7.3
5
- Requests==2.32.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  scipy==1.13.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  torch==2.1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  uvicorn==0.30.1
9
- psutil==6.0.0
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.3.0
3
+ annotated-types==0.7.0
4
+ antlr4-python3-runtime==4.9.3
5
+ anyio==4.4.0
6
+ asttokens==2.4.1
7
+ attrs==23.2.0
8
  audiocraft==1.3.0
9
+ audioread==3.0.1
10
+ av==11.0.0
11
+ backcall==0.2.0
12
+ beautifulsoup4==4.12.3
13
+ bleach==6.1.0
14
+ blinker==1.8.2
15
+ blis==0.7.11
16
+ cachetools==5.3.3
17
+ catalogue==2.0.10
18
+ certifi==2024.7.4
19
+ cffi==1.16.0
20
+ charset-normalizer==3.3.2
21
+ click==8.1.7
22
+ cloudpathlib==0.18.1
23
+ cloudpickle==3.0.0
24
+ colorlog==6.8.2
25
+ confection==0.1.5
26
+ contourpy==1.2.1
27
+ cycler==0.12.1
28
+ cymem==2.0.8
29
+ decorator==4.4.2
30
+ defusedxml==0.7.1
31
+ demucs==4.0.1
32
+ dnspython==2.6.1
33
+ docopt==0.6.2
34
+ dora_search==0.1.12
35
+ einops==0.8.0
36
+ email_validator==2.2.0
37
+ encodec==0.1.1
38
+ exceptiongroup==1.2.2
39
+ executing==2.0.1
40
  fastapi==0.111.0
41
+ fastapi-cli==0.0.4
42
+ fastjsonschema==2.20.0
43
+ ffmpy==0.3.2
44
+ filelock==3.15.4
45
+ flashy==0.0.2
46
+ fonttools==4.53.1
47
+ fsspec==2024.6.1
48
+ gitdb==4.0.11
49
+ GitPython==3.1.43
50
+ google-ai-generativelanguage==0.6.6
51
+ google-api-core==2.19.1
52
+ google-api-python-client==2.137.0
53
+ google-auth==2.32.0
54
+ google-auth-httplib2==0.2.0
55
+ google-generativeai==0.7.2
56
+ googleapis-common-protos==1.63.2
57
+ gradio==4.38.1
58
+ gradio_client==1.1.0
59
+ grpcio==1.64.1
60
+ grpcio-status==1.62.2
61
+ h11==0.14.0
62
+ httpcore==1.0.5
63
+ httplib2==0.22.0
64
+ httptools==0.6.1
65
+ httpx==0.27.0
66
+ huggingface-hub==0.23.4
67
+ hydra-colorlog==1.2.0
68
+ hydra-core==1.3.2
69
+ idna==3.7
70
+ imageio==2.34.2
71
+ imageio-ffmpeg==0.5.1
72
+ importlib_metadata==8.2.0
73
+ importlib_resources==6.4.0
74
+ ipython==8.12.3
75
+ jedi==0.19.1
76
+ Jinja2==3.1.4
77
+ joblib==1.4.2
78
+ jsonschema==4.23.0
79
+ jsonschema-specifications==2023.12.1
80
+ julius==0.2.7
81
+ jupyter_client==8.6.2
82
+ jupyter_core==5.7.2
83
+ jupyterlab_pygments==0.3.0
84
+ kiwisolver==1.4.5
85
+ lameenc==1.7.0
86
+ langcodes==3.4.0
87
+ language_data==1.2.0
88
+ lazy_loader==0.4
89
+ librosa==0.10.2.post1
90
+ lightning-utilities==0.11.5
91
+ llvmlite==0.43.0
92
+ marisa-trie==1.2.0
93
+ markdown-it-py==3.0.0
94
+ MarkupSafe==2.1.5
95
+ matplotlib==3.9.1
96
+ matplotlib-inline==0.1.7
97
+ mdurl==0.1.2
98
+ mistune==3.0.2
99
+ moviepy==1.0.3
100
+ mpmath==1.3.0
101
+ msgpack==1.0.8
102
+ murmurhash==1.0.10
103
+ nbclient==0.10.0
104
+ nbconvert==7.16.4
105
+ nbformat==5.10.4
106
+ networkx==3.2.1
107
+ num2words==0.5.13
108
+ numba==0.60.0
109
  numpy==1.26.4
110
+ nvidia-cublas-cu12==12.1.3.1
111
+ nvidia-cuda-cupti-cu12==12.1.105
112
+ nvidia-cuda-nvrtc-cu12==12.1.105
113
+ nvidia-cuda-runtime-cu12==12.1.105
114
+ nvidia-cudnn-cu12==8.9.2.26
115
+ nvidia-cufft-cu12==11.0.2.54
116
+ nvidia-curand-cu12==10.3.2.106
117
+ nvidia-cusolver-cu12==11.4.5.107
118
+ nvidia-cusparse-cu12==12.1.0.106
119
+ nvidia-nccl-cu12==2.18.1
120
+ nvidia-nvjitlink-cu12==12.5.82
121
+ nvidia-nvtx-cu12==12.1.105
122
+ omegaconf==2.3.0
123
+ openunmix==1.3.0
124
+ orjson==3.10.6
125
+ packaging==24.1
126
+ pandas==2.2.2
127
+ pandocfilters==1.5.1
128
+ parso==0.8.4
129
+ pexpect==4.9.0
130
+ pickleshare==0.7.5
131
+ pillow==10.4.0
132
+ pipreqs==0.5.0
133
+ platformdirs==4.2.2
134
+ pooch==1.8.2
135
+ preshed==3.0.9
136
+ proglog==0.1.10
137
+ prompt_toolkit==3.0.47
138
+ proto-plus==1.24.0
139
+ protobuf==4.25.3
140
+ psutil==6.0.0
141
+ ptyprocess==0.7.0
142
+ pure_eval==0.2.3
143
+ pyarrow==16.1.0
144
+ pyasn1==0.6.0
145
+ pyasn1_modules==0.4.0
146
+ pycparser==2.22
147
  pydantic==2.7.3
148
+ pydantic_core==2.18.4
149
+ pydeck==0.9.1
150
+ pydub==0.25.1
151
+ Pygments==2.18.0
152
+ pyparsing==3.1.2
153
+ python-dateutil==2.9.0.post0
154
+ python-dotenv==1.0.1
155
+ python-multipart==0.0.9
156
+ pytz==2024.1
157
+ PyYAML==6.0.1
158
+ pyzmq==26.1.0
159
+ referencing==0.35.1
160
+ regex==2024.5.15
161
+ requests==2.32.3
162
+ retrying==1.3.4
163
+ rich==13.7.1
164
+ rpds-py==0.19.0
165
+ rsa==4.9
166
+ ruff==0.5.2
167
+ safetensors==0.4.3
168
+ scikit-learn==1.5.1
169
  scipy==1.13.1
170
+ semantic-version==2.10.0
171
+ sentencepiece==0.2.0
172
+ shellingham==1.5.4
173
+ six==1.16.0
174
+ smart-open==7.0.4
175
+ smmap==5.0.1
176
+ sniffio==1.3.1
177
+ soundfile==0.12.1
178
+ soupsieve==2.5
179
+ soxr==0.3.7
180
+ spacy==3.7.5
181
+ spacy-legacy==3.0.12
182
+ spacy-loggers==1.0.5
183
+ srsly==2.4.8
184
+ stack-data==0.6.3
185
+ starlette==0.37.2
186
+ streamlit==1.36.0
187
+ submitit==1.5.1
188
+ sympy==1.13.0
189
+ tenacity==8.5.0
190
+ thinc==8.2.5
191
+ threadpoolctl==3.5.0
192
+ tinycss2==1.3.0
193
+ tokenizers==0.19.1
194
+ toml==0.10.2
195
+ tomlkit==0.12.0
196
+ toolz==0.12.1
197
+ --extra-index-url https://download.pytorch.org/whl/cu121
198
  torch==2.1.0
199
+ torchaudio==2.1.0
200
+ torchdata==0.7.0
201
+ torchmetrics==1.4.0.post0
202
+ torchtext==0.16.0
203
+ torchvision==0.16.0
204
+ tornado==6.4.1
205
+ tqdm==4.66.4
206
+ traitlets==5.14.3
207
+ transformers==4.42.4
208
+ treetable==0.2.5
209
+ triton==2.1.0
210
+ typer==0.12.3
211
+ typing_extensions==4.12.2
212
+ tzdata==2024.1
213
+ ujson==5.10.0
214
+ uritemplate==4.1.1
215
+ urllib3==2.2.2
216
  uvicorn==0.30.1
217
+ uvloop==0.19.0
218
+ wasabi==1.1.3
219
+ watchdog==4.0.1
220
+ watchfiles==0.22.0
221
+ wcwidth==0.2.13
222
+ weasel==0.4.1
223
+ webencodings==0.5.1
224
+ websockets==11.0.3
225
+ wrapt==1.16.0
226
+ xformers==0.0.22.post7
227
+ yarg==0.1.9
228
+ zipp==3.19.2