crypto-code committed · Commit 686c4ae · 1 parent: d9cb0bd
Update app.py

app.py CHANGED
@@ -20,6 +20,7 @@ import torchvision.transforms as transforms
 import av
 import subprocess
 import librosa
+import re
 
 args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
     "mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
@@ -33,8 +34,6 @@ class dotdict(dict):
 
 args = dotdict(args)
 
-generated_audio_files = []
-
 llama_type = args.llama_type
 llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
 llama_tokenzier_path = args.llama_dir
@@ -118,7 +117,6 @@ def parse_text(text, image_path, video_path, audio_path):
 
 
 def save_audio_to_local(audio, sec):
-    global generated_audio_files
     if not os.path.exists('temp'):
         os.mkdir('temp')
     filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.wav')
@@ -126,7 +124,6 @@ def save_audio_to_local(audio, sec):
         scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
     else:
         scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
-    generated_audio_files.append(filename)
     return filename
 
 
@@ -166,8 +163,6 @@ def reset_dialog():
 
 
 def reset_state():
-    global generated_audio_files
-    generated_audio_files = []
     return None, None, None, None, [], [], []
 
 
@@ -214,6 +209,12 @@ def get_video_length(filename):
 def get_audio_length(filename):
     return int(round(librosa.get_duration(path=filename)))
 
+def get_last_audio():
+    for hist in history[::-1]:
+        print(hist)
+        if "<audio controls playsinline>" in hist[1]:
+            return re.search('<audio controls playsinline><source src=\"\.\/file=(.*)\" type="audio\/wav"><\/audio>', hist[1]).group(1)
+    return None
 
 def predict(
     prompt_input,
@@ -226,7 +227,6 @@ def predict(
     history,
     modality_cache,
     audio_length_in_s):
-    global generated_audio_files
     prompts = [llama.format_prompt(prompt_input)]
     prompts = [model.tokenizer(x).input_ids for x in prompts]
     print(image_path, audio_path, video_path)
@@ -244,11 +244,11 @@ def predict(
         container = av.open(video_path)
         indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
         video = read_video_pyav(container=container, indices=indices)
-
-        if len(generated_audio_files) != 0:
-            audio_length_in_s = get_audio_length(generated_audio_files[-1])
+        generated_audio_file = get_last_audio()
+        if generated_audio_file is not None:
+            audio_length_in_s = get_audio_length(generated_audio_file)
             sample_rate = 24000
-            waveform, sr = torchaudio.load(generated_audio_files[-1])
+            waveform, sr = torchaudio.load(generated_audio_file)
             if sample_rate != sr:
                 waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
             audio = torch.mean(waveform, 0)
@@ -259,7 +259,6 @@ def predict(
         print(f"Video Length: {audio_length_in_s}")
     if audio_path is not None:
         audio_length_in_s = get_audio_length(audio_path)
-        generated_audio_files.append(audio_path)
         print(f"Audio Length: {audio_length_in_s}")
 
     print(image, video, audio)
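
The new get_last_audio() replaces the removed generated_audio_files global: rather than tracking generated files in a list, it recovers the most recent clip by scanning the chat history for the <audio> tag that the app embeds in bot replies. Below is a minimal, self-contained sketch of that logic. Unlike the committed version, which reads history as a free variable, the sketch takes it as a parameter; the compiled-pattern name, sample history, and file paths are illustrative, not from the commit.

import re

# Pattern the commit uses to locate generated audio in a bot message;
# the (.*) group captures the file path embedded in the <audio> tag.
AUDIO_TAG_RE = re.compile(
    r'<audio controls playsinline><source src="\./file=(.*)" type="audio/wav"></audio>'
)

def get_last_audio(history):
    """Return the path of the most recently generated audio file, or None."""
    # Walk the chat history newest-first; each entry is a (user, bot) pair.
    for _, bot_message in reversed(history):
        match = AUDIO_TAG_RE.search(bot_message)
        if match:
            return match.group(1)
    return None

# Hypothetical Gradio-style history for illustration:
history = [
    ("compose a short melody",
     'Done: <audio controls playsinline><source src="./file=temp/out.wav" type="audio/wav"></audio>'),
    ("what key is it in?", "C major."),
]
print(get_last_audio(history))  # -> temp/out.wav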
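
Once a path is recovered, the new predict() hunk loads the file, resamples it to a fixed 24 kHz, and averages the channels down to mono before passing the waveform to the model. A standalone sketch of that preprocessing follows; the helper name is hypothetical, since the commit inlines this logic in predict().

import torch
import torchaudio

def load_audio_mono_24k(path):
    """Load audio, resample to 24 kHz, and downmix to mono,
    mirroring the preprocessing in the updated predict()."""
    target_rate = 24000
    waveform, sr = torchaudio.load(path)  # (channels, frames), source rate
    if sr != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sr, new_freq=target_rate
        )
    # Average across channels to produce a 1-D mono signal.
    return torch.mean(waveform, 0), target_rate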