Spaces: crypto-code
Commit e77fc2d • 1 parent: 686c4ae
Update app.py

This commit drops the regex-based get_last_audio() helper (and the now-unused import re) in favor of a module-level list, generated_audio_files, which records every audio file the app writes or receives; predict() reads the most recent entry, and the reset handlers clear the list.

app.py CHANGED
@@ -20,7 +20,6 @@ import torchvision.transforms as transforms
 import av
 import subprocess
 import librosa
-import re
 
 args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
         "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
@@ -34,6 +33,8 @@ class dotdict(dict):
 
 args = dotdict(args)
 
+generated_audio_files = []
+
 llama_type = args.llama_type
 llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
 llama_tokenzier_path = args.llama_dir
@@ -117,6 +118,7 @@ def parse_text(text, image_path, video_path, audio_path):
 
 
 def save_audio_to_local(audio, sec):
+    global generated_audio_files
     if not os.path.exists('temp'):
         os.mkdir('temp')
     filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.wav')
@@ -124,6 +126,7 @@ def save_audio_to_local(audio, sec):
         scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
     else:
         scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
+    generated_audio_files.append(filename)
     return filename
 
 
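A side note on the unchanged context in this hunk: tempfile._get_candidate_names() is a private CPython helper, not public API. A minimal sketch of an equivalent using the public API (the function name here is illustrative, not part of the commit):

import os
import tempfile

def make_temp_wav_path() -> str:
    # Equivalent of next(tempfile._get_candidate_names()) + '.wav', but via the
    # public API: mkstemp creates the file atomically and returns a unique path.
    os.makedirs('temp', exist_ok=True)
    fd, filename = tempfile.mkstemp(suffix='.wav', dir='temp')
    os.close(fd)  # scipy.io.wavfile.write reopens the file by name
    return filename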
@@ -159,10 +162,14 @@ def reset_user_input():
 
 
 def reset_dialog():
+    global generated_audio_files
+    generated_audio_files = []
     return [], []
 
 
 def reset_state():
+    global generated_audio_files
+    generated_audio_files = []
     return None, None, None, None, [], [], []
 
 
@@ -209,12 +216,6 @@ def get_video_length(filename):
 def get_audio_length(filename):
     return int(round(librosa.get_duration(path=filename)))
 
-def get_last_audio():
-    for hist in history[::-1]:
-        print(hist)
-        if "<audio controls playsinline>" in hist[1]:
-            return re.search('<audio controls playsinline><source src=\"\.\/file=(.*)\" type="audio\/wav"><\/audio>', hist[1]).group(1)
-    return None
 
 def predict(
     prompt_input,
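The deleted helper recovered the last generated file by regex-matching the HTML <audio> tag that the app embeds in the chat history. In isolation, that lookup amounts to the following (the history string below is an illustrative entry, not taken from the repo):

import re

hist_entry = '<audio controls playsinline><source src="./file=temp/xyz.wav" type="audio/wav"></audio>'
match = re.search(r'<audio controls playsinline><source src="\./file=(.*)" type="audio/wav"></audio>', hist_entry)
print(match.group(1))  # temp/xyz.wav

Appending filenames to generated_audio_files as they are written removes this dependency on the exact rendered HTML, which breaks silently if the markup ever changes.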
@@ -227,6 +228,7 @@ def predict(
     history,
     modality_cache,
     audio_length_in_s):
+    global generated_audio_files
     prompts = [llama.format_prompt(prompt_input)]
     prompts = [model.tokenizer(x).input_ids for x in prompts]
     print(image_path, audio_path, video_path)
@@ -244,11 +246,11 @@ def predict(
         container = av.open(video_path)
         indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
         video = read_video_pyav(container=container, indices=indices)
-
-        if
-            audio_length_in_s = get_audio_length(
+
+        if len(generated_audio_files) != 0:
+            audio_length_in_s = get_audio_length(generated_audio_files[-1])
         sample_rate = 24000
-        waveform, sr = torchaudio.load(
+        waveform, sr = torchaudio.load(generated_audio_files[-1])
         if sample_rate != sr:
             waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
         audio = torch.mean(waveform, 0)
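For reference, the audio path in this hunk amounts to the following preprocessing step, now applied to generated_audio_files[-1]. A minimal sketch using the same torchaudio calls (the function name is illustrative):

import torch
import torchaudio

def load_mono_24k(path: str) -> torch.Tensor:
    # Load an audio file, resample to 24 kHz if needed, then downmix to mono,
    # mirroring what predict() does with the most recent generated file.
    target_sr = 24000
    waveform, sr = torchaudio.load(path)  # shape: (channels, samples)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=target_sr)
    return torch.mean(waveform, 0)  # average channels -> shape: (samples,)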
@@ -259,6 +261,7 @@ def predict(
         print(f"Video Length: {audio_length_in_s}")
     if audio_path is not None:
         audio_length_in_s = get_audio_length(audio_path)
+        generated_audio_files.append(audio_path)
         print(f"Audio Length: {audio_length_in_s}")
 
     print(image, video, audio)
@@ -350,4 +353,4 @@ with gr.Blocks() as demo:
     ], show_progress=True)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
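Taken together, the commit swaps history-scraping for a simple producer/consumer list. A self-contained sketch of the resulting pattern (the file path and the driver at the bottom are illustrative):

# Module-level state, as in the patched app.py.
generated_audio_files = []

def record_audio_file(filename):
    # Producer: save_audio_to_local() appends every file it writes,
    # and predict() appends user-supplied audio_path inputs.
    generated_audio_files.append(filename)
    return filename

def latest_audio():
    # Consumer: predict() reads generated_audio_files[-1] instead of
    # regex-scraping the rendered chat history.
    return generated_audio_files[-1] if generated_audio_files else None

def reset_all():
    # Both reset_dialog() and reset_state() clear the list so a new
    # conversation cannot pick up audio from the previous one.
    generated_audio_files.clear()

record_audio_file('temp/example.wav')
assert latest_audio() == 'temp/example.wav'
reset_all()
assert latest_audio() is None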