Add eval metrics

Siddhant committed • Commit f78ed8b • 1 Parent(s): 3168a2f

Files changed:
- app.py                       +101 -42
- eval/ASR_WER.py               +55 -0
- eval/LLM_Metrics.py          +106 -0
- eval/TTS_intelligibility.py   +58 -0
- eval/TTS_speech_quality.py    +42 -0
- eval/vert.py                 +272 -0
- requirements.txt               +2 -1
- tts_samples/sample1.wav        +0 -0
- utils.py                      +12 -0
- versa.sh                       +4 -0
app.py  CHANGED

@@ -1,7 +1,12 @@
 import os
 import shutil
-import
+import soundfile
+from eval.TTS_intelligibility import handle_espnet_TTS_intelligibility
+from eval.ASR_WER import handle_espnet_ASR_WER
+from eval.TTS_speech_quality import TTS_psuedomos
+from eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_perplexity
 from huggingface_hub import HfApi
+from utils import int2float
 api = HfApi()
 import nltk
 nltk.download('averaged_perceptron_tagger_eng')
@@ -29,7 +34,13 @@ text2speech = Text2Speech.from_pretrained(
     noise_scale=0.333,
     noise_scale_dur=0.333,
 )
-
+try:
+    import versa
+except ImportError:
+    from subprocess import call
+    with open('versa.sh', 'rb') as file:
+        script = file.read()
+    rc = call(script, shell=True)
 import numpy as np
 from VAD.vad_iterator import VADIterator
 import torch
@@ -66,6 +77,8 @@ import soundfile as sf
 import kaldiio
 from espnet2.bin.s2t_inference_ctc import Speech2TextGreedySearch
 from espnet2.bin.s2t_inference import Speech2Text
+from espnet2.bin.asr_inference import Speech2Text as S2T_ASR
+import whisper
 
 s2t = Speech2TextGreedySearch.from_pretrained(
     "pyf98/owsm_ctc_v3.1_1B",
@@ -74,6 +87,9 @@ s2t = Speech2TextGreedySearch.from_pretrained(
     lang_sym='<eng>',
     task_sym='<asr>',
 )
+latency_ASR=0.0
+latency_LM=0.0
+latency_TTS=0.0
 
 start_event = torch.cuda.Event(enable_timing=True)
 end_event = torch.cuda.Event(enable_timing=True)
@@ -84,23 +100,13 @@ res = s2t(speech)
 end_event.record()
 torch.cuda.synchronize()
 
-def int2float(sound):
-    """
-    Taken from https://github.com/snakers4/silero-vad
-    """
-
-    abs_max = np.abs(sound).max()
-    sound = sound.astype("float32")
-    if abs_max > 0:
-        sound *= 1 / 32768
-    sound = sound.squeeze()  # depends on the use case
-    return sound
-
 text_str=""
 asr_output_str=""
 vad_output=None
 audio_output = None
 audio_output1 = None
+LLM_response_arr=[]
+total_response_arr=[]
 min_speech_ms=500
 max_speech_ms=float("inf")
 # ASR_model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)
@@ -239,8 +245,7 @@ def relevant_vote4_last_response(
 import json
 import time
 
-
-def transcribe(stream, new_chunk, option):
+def transcribe(stream, new_chunk, option, asr_option):
     sr, y = new_chunk
     global text_str
     global chat
@@ -252,6 +257,11 @@ def transcribe(stream, new_chunk, option):
     global start_record_time
     global sids
     global spembs
+    global latency_ASR
+    global latency_LM
+    global latency_TTS
+    global LLM_response_arr
+    global total_response_arr
     if stream is None:
         stream=y
         chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
@@ -295,11 +305,14 @@ def transcribe(stream, new_chunk, option):
         array = torch.cat(vad_output).cpu().numpy()
         duration_ms = len(array) / sr * 1000
         if (not(duration_ms < min_speech_ms or duration_ms > max_speech_ms)):
-            print(len(array))
-            array = librosa.util.fix_length(array, size=(16000 * 30))
-            print(len(array))
             start_time = time.time()
-
+            if asr_option=="whisper":
+                prompt=s2t.transcribe(torch.tensor(array).float(), beam_size=1)["text"]
+            elif asr_option=="librispeech_asr":
+                prompt=s2t(array)[0][0]
+            else:
+                array = librosa.util.fix_length(array, size=(16000 * 30))
+                prompt=" ".join(s2t(array)[0][0].split()[1:])
             vad_output = None
             if len(prompt.strip().split())<2:
                 text_str1=text_str
@@ -310,8 +323,11 @@ def transcribe(stream, new_chunk, option):
             print(len(prompt.strip().split()))
             print(prompt)
             asr_output_str=prompt
+            total_response_arr.append(prompt.replace("\n"," "))
            # yield (stream,asr_output_str,text_str, audio_output)
-
+            start_LM_time=time.time()
+            latency_ASR=(start_LM_time - start_time)
+            # print("--- %s seconds ---" % (time.time() - start_time))
             # prompt=ASR_model.transcribe(array)["text"].strip()
             chat.append({"role": user_role, "content": prompt})
             chat_messages = chat.to_list()
@@ -322,14 +338,16 @@ def transcribe(stream, new_chunk, option):
                 temperature=0.0,
                 do_sample=False,
             )
-
-
+            start_TTS_time=time.time()
+            latency_LM=(start_TTS_time - start_LM_time)
             generated_text = output[0]['generated_text'][-1]["content"]
 
             # torch.mps.empty_cache()
 
             chat.append({"role": "assistant", "content": generated_text})
             text_str=generated_text
+            LLM_response_arr.append(text_str.replace("\n"," "))
+            total_response_arr.append(text_str.replace("\n"," "))
             # import pdb;pdb.set_trace()
             with torch.no_grad():
                 if option=="ChatTTS":
@@ -347,7 +365,7 @@ def transcribe(stream, new_chunk, option):
                 audio_output=(text2speech.fs, audio_chunk)
                 audio_output1=(orig_sr,stream)
                 stream=y
-
+                latency_TTS=(time.time() - start_TTS_time)
             # else:
             #     audio_output=None
             text_str1=text_str
@@ -362,20 +380,22 @@ def transcribe(stream, new_chunk, option):
     if current_record_time-start_record_time>300:
         gr.Info("Conversations are limited to 5 minutes. The session will restart in approximately 60 seconds. Please wait for the demo to reset. Close this message once you have read it.", duration=None)
         yield stream,gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False),gr.Audio(visible=False)
-
-
-
-
-
-
-
-
+        api.upload_folder(
+            folder_path="flagged_data_points",
+            path_in_repo="checkpoint_"+str(start_record_time),
+            repo_id="Siddhant/Cascaded_demo_data",
+            repo_type="dataset",
+            token=access_token,
+            # ignore_patterns="**/logs/*.txt", # Ignore all text logs
+        )
         chat.buffer=[{"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."}]
         text_str=""
         audio_output = None
         audio_output1 = None
         asr_output_str = ""
         start_record_time = None
+        LLM_response_arr=[]
+        total_response_arr=[]
         shutil.rmtree('flagged_data_points')
         os.mkdir("flagged_data_points")
         yield (stream,asr_output_str,text_str1, audio_output, audio_output1)
@@ -486,6 +506,8 @@ def handle_LLM_selection(option):
 
 def handle_ASR_selection(option):
     yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
+    if option=="librispeech_asr":
+        option="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp"
     global s2t
     if option=="espnet/owsm_v3.1_ebf":
         s2t = Speech2Text.from_pretrained(
@@ -495,6 +517,14 @@ def handle_ASR_selection(option):
             beam_size=1,
             predict_time=False,
         )
+    elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
+        s2t = S2T_ASR.from_pretrained(
+            model_tag=option,
+            device="cuda",
+            beam_size=1,
+        )
+    elif option=="whisper":
+        s2t = whisper.load_model("large", device="cuda")
     else:
         s2t = Speech2TextGreedySearch.from_pretrained(
             option,
@@ -508,17 +538,36 @@ def handle_ASR_selection(option):
     end_event = torch.cuda.Event(enable_timing=True)
     torch.cuda.synchronize()
     start_event.record()
-
-
+
+    if option=="whisper":
+        audio, rate = soundfile.read("tts_samples/sample1.wav")
+        array=librosa.resample(audio, orig_sr=rate, target_sr=16000)
+        res=s2t.transcribe(torch.tensor(array).float(), beam_size=1)["text"]
+    elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
+        res = s2t(dummy_input)[0][0]
+    else:
+        speech = librosa.util.fix_length(dummy_input, size=(16000 * 30))
+        res = s2t(speech)
     end_event.record()
     torch.cuda.synchronize()
     yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
-
-
-
-
-
-
+
+def handle_eval_selection(option, TTS_audio_output, LLM_Output, ASR_audio_output, ASR_transcript):
+    global LLM_response_arr
+    global total_response_arr
+    yield (option,gr.Textbox(visible=True))
+    if option=="Latency":
+        text=f"ASR Latency: {latency_ASR:.2f}\nLLM Latency: {latency_LM:.2f}\nTTS Latency: {latency_TTS:.2f}"
+        yield (None,text)
+    elif option=="TTS Intelligibility":
+        yield (None,handle_espnet_TTS_intelligibility(TTS_audio_output,LLM_Output))
+    elif option=="TTS Speech Quality":
+        yield (None,TTS_psuedomos(TTS_audio_output))
+    elif option=="ASR WER":
+        yield (None,handle_espnet_ASR_WER(ASR_audio_output, ASR_transcript))
+    elif option=="Text Dialog Metrics":
+        yield (None,perplexity(LLM_Output.replace("\n"," "))+vert(LLM_response_arr)+bert_score(total_response_arr)+DialoGPT_perplexity(ASR_transcript.replace("\n"," "),LLM_Output.replace("\n"," ")))
+
 with gr.Blocks(
     title="E2E Spoken Dialog System",
 ) as demo:
@@ -527,7 +576,7 @@ with gr.Blocks(
         user_audio = gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
         with gr.Row():
             ASR_radio = gr.Radio(
-                choices=["pyf98/owsm_ctc_v3.1_1B", "espnet/owsm_ctc_v3.2_ft_1B", "espnet/owsm_v3.1_ebf"],
+                choices=["pyf98/owsm_ctc_v3.1_1B", "espnet/owsm_ctc_v3.2_ft_1B", "espnet/owsm_v3.1_ebf", "librispeech_asr", "whisper"],
                 label="Choose ASR:",
                 value="pyf98/owsm_ctc_v3.1_1B",
             )
@@ -574,6 +623,15 @@ with gr.Blocks(
         output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False)
         output_asr_text = gr.Textbox(label="ASR output")
         output_text = gr.Textbox(label="LLM output")
+        eval_radio = gr.Radio(
+            choices=["Latency", "TTS Intelligibility", "TTS Speech Quality", "ASR WER","Text Dialog Metrics"],
+            label="Choose Evaluation metrics:",
+        )
+        # TTS Intelligibility_radio = gr.Radio(
+        #     choices=["ESPnet", "TTS Intelligibility", "TTS Speech Quality"],
+        #     label="Choose ASR model:",
+        # )
+        output_eval_text = gr.Textbox(label="Evaluation Results")
         state = gr.State()
     with gr.Row():
         privacy_text = gr.Textbox(label="Privacy Notice",interactive=False, value="By using this demo, you acknowledge that interactions with this dialog system are collected for research and improvement purposes. The data will only be used to enhance the performance and understanding of the system. If you have any concerns about data collection, please discontinue use.")
@@ -604,10 +662,11 @@ with gr.Blocks(
     diversity_response = gr.Textbox(label="diversity_response",visible=False,interactive=False)
     ip_address = gr.Textbox(label="ip_address",visible=False,interactive=False)
    callback.setup([user_audio, output_asr_text, output_text, output_audio,output_audio1,natural_response,diversity_response,ip_address],"flagged_data_points")
-    user_audio.stream(transcribe, inputs=[state, user_audio, radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
+    user_audio.stream(transcribe, inputs=[state, user_audio, radio, ASR_radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
     radio.change(fn=handle_selection, inputs=[radio], outputs=[output_asr_text, output_text, output_audio])
     LLM_radio.change(fn=handle_LLM_selection, inputs=[LLM_radio], outputs=[output_asr_text, output_text, output_audio])
     ASR_radio.change(fn=handle_ASR_selection, inputs=[ASR_radio], outputs=[output_asr_text, output_text, output_audio])
+    eval_radio.change(fn=handle_eval_selection, inputs=[eval_radio,output_audio,output_text,output_audio1,output_asr_text], outputs=[eval_radio,output_eval_text])
     output_audio.play(
         flash_buttons, [], [natural_response,diversity_response]+btn_list
     ).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1], None,preprocess=False)
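The new `transcribe` signature threads the ASR radio selection into the streaming callback, and the three `latency_*` globals are plain wall-clock deltas taken between the ASR, LLM, and TTS stages. A minimal sketch of that bookkeeping pattern in isolation, with hypothetical `run_asr`/`run_llm`/`run_tts` stand-ins for the real models:

import time

def run_pipeline(run_asr, run_llm, run_tts, audio):
    # Same stage-latency bookkeeping as transcribe(); the three callables are
    # hypothetical stand-ins for the ASR, LLM, and TTS models wired up in app.py.
    start_time = time.time()
    prompt = run_asr(audio)
    start_LM_time = time.time()
    latency_ASR = start_LM_time - start_time      # ASR wall-clock seconds
    reply = run_llm(prompt)
    start_TTS_time = time.time()
    latency_LM = start_TTS_time - start_LM_time   # LLM wall-clock seconds
    wav = run_tts(reply)
    latency_TTS = time.time() - start_TTS_time    # TTS wall-clock seconds
    return wav, (latency_ASR, latency_LM, latency_TTS)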
eval/ASR_WER.py  ADDED

@@ -0,0 +1,55 @@
+from utils import int2float
+def handle_espnet_ASR_WER(ASR_audio_output,ASR_transcript):
+    from versa import espnet_levenshtein_metric, espnet_wer_setup, owsm_levenshtein_metric, owsm_wer_setup, whisper_levenshtein_metric, whisper_wer_setup
+    score_modules_espnet = {
+        "module": espnet_levenshtein_metric,
+        "args": espnet_wer_setup(
+            model_tag="default",
+            beam_size=1,
+            text_cleaner="whisper_en",
+            use_gpu=True,
+        ),
+    }
+    dict1=score_modules_espnet["module"](
+        score_modules_espnet["args"],
+        int2float(ASR_audio_output[1]),
+        ASR_transcript,
+        ASR_audio_output[0],
+    )
+    espnet_wer=(dict1["espnet_wer_delete"]+dict1["espnet_wer_insert"]+dict1["espnet_wer_replace"])/(dict1["espnet_wer_insert"]+dict1["espnet_wer_replace"]+dict1["espnet_wer_equal"])
+    espnet_cer=(dict1["espnet_cer_delete"]+dict1["espnet_cer_insert"]+dict1["espnet_cer_replace"])/(dict1["espnet_cer_insert"]+dict1["espnet_cer_replace"]+dict1["espnet_cer_equal"])
+    score_modules_owsm = {
+        "module": owsm_levenshtein_metric,
+        "args": owsm_wer_setup(
+            model_tag="default",
+            beam_size=1,
+            text_cleaner="whisper_en",
+            use_gpu=True,
+        ),
+    }
+    dict1=score_modules_owsm["module"](
+        score_modules_owsm["args"],
+        int2float(ASR_audio_output[1]),
+        ASR_transcript,
+        ASR_audio_output[0],
+    )
+    owsm_wer=(dict1["owsm_wer_delete"]+dict1["owsm_wer_insert"]+dict1["owsm_wer_replace"])/(dict1["owsm_wer_insert"]+dict1["owsm_wer_replace"]+dict1["owsm_wer_equal"])
+    owsm_cer=(dict1["owsm_cer_delete"]+dict1["owsm_cer_insert"]+dict1["owsm_cer_replace"])/(dict1["owsm_cer_insert"]+dict1["owsm_cer_replace"]+dict1["owsm_cer_equal"])
+    score_modules_whisper = {
+        "module": whisper_levenshtein_metric,
+        "args": whisper_wer_setup(
+            model_tag="default",
+            beam_size=1,
+            text_cleaner="whisper_en",
+            use_gpu=True,
+        ),
+    }
+    dict1=score_modules_whisper["module"](
+        score_modules_whisper["args"],
+        int2float(ASR_audio_output[1]),
+        ASR_transcript,
+        ASR_audio_output[0],
+    )
+    whisper_wer=(dict1["whisper_wer_delete"]+dict1["whisper_wer_insert"]+dict1["whisper_wer_replace"])/(dict1["whisper_wer_insert"]+dict1["whisper_wer_replace"]+dict1["whisper_wer_equal"])
+    whisper_cer=(dict1["whisper_cer_delete"]+dict1["whisper_cer_insert"]+dict1["whisper_cer_replace"])/(dict1["whisper_cer_insert"]+dict1["whisper_cer_replace"]+dict1["whisper_cer_equal"])
+    return f"ESPnet WER: {espnet_wer*100:.2f}\nESPnet CER: {espnet_cer*100:.2f}\nOWSM WER: {owsm_wer*100:.2f}\nOWSM CER: {owsm_cer*100:.2f}\nWhisper WER: {whisper_wer*100:.2f}\nWhisper CER: {whisper_cer*100:.2f}"
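The six ratios above are assembled from the edit-operation counts (`*_delete`, `*_insert`, `*_replace`, `*_equal`) returned by versa's Levenshtein metrics. A small helper that mirrors exactly the same arithmetic, shown only to make the formula explicit; it assumes the same key naming as the dictionaries used above:

def rates_from_counts(d, prefix):
    # WER/CER exactly as computed in handle_espnet_ASR_WER: edit operations in the
    # numerator, over the same denominator choice used in the function above.
    wer = (d[f"{prefix}_wer_delete"] + d[f"{prefix}_wer_insert"] + d[f"{prefix}_wer_replace"]) / (
        d[f"{prefix}_wer_insert"] + d[f"{prefix}_wer_replace"] + d[f"{prefix}_wer_equal"]
    )
    cer = (d[f"{prefix}_cer_delete"] + d[f"{prefix}_cer_insert"] + d[f"{prefix}_cer_replace"]) / (
        d[f"{prefix}_cer_insert"] + d[f"{prefix}_cer_replace"] + d[f"{prefix}_cer_equal"]
    )
    return wer, cer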
eval/LLM_Metrics.py  ADDED

@@ -0,0 +1,106 @@
+from multiprocessing import Pool
+from eval.vert import get_self_bleu2_geometric, get_auto_bleu2_geometric, run_f
+import numpy as np
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
+import torch
+from sklearn.metrics.pairwise import cosine_similarity
+from scipy.stats import gmean
+def perplexity(LLM_Output):
+    import evaluate
+    # import pdb;pdb.set_trace()
+    perplexity = evaluate.load("perplexity", module_type="metric")
+    results = perplexity.compute(model_id='gpt2',predictions=[LLM_Output])
+    return f"Perplexity: {results['mean_perplexity']:.2f}\n"
+
+def vert(LLM_response_arr):
+    # import pdb;pdb.set_trace()
+    terms = [x.strip().split() for x in LLM_response_arr]
+
+
+    tasks = [
+        ('Self-BLEU2-geometric', get_self_bleu2_geometric),
+        ('Auto-BLEU2-geometric', get_auto_bleu2_geometric),
+    ]
+    n_processes = min(16, len(tasks))
+    with Pool(n_processes) as pool:
+        metrics = pool.map(run_f, [(t[1], terms) for t in tasks])
+    metric_arr=[]
+    str1=""
+    for (metric_name, _), metric in zip(tasks, metrics):
+        metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric))
+
+        metric, sem = [
+            round(100 * x, 2) for x in [metric, sem]
+        ]
+        metric_arr.append(metric)
+
+        str1+=(f'{metric_name}: {metric}\n')
+    str1+=(f'VERT: {round(100*gmean(metric), 2)}\n')
+    return str1
+
+def bert_score(total_response_arr):
+    # import pdb;pdb.set_trace()
+    def cosine_similarity_context_response(context, response, model, tokenizer):
+        # Tokenize and encode both context and response
+        context_inputs = tokenizer(context, return_tensors="pt", truncation=True)
+        response_inputs = tokenizer(response, return_tensors="pt", truncation=True)
+        for k in context_inputs:
+            context_inputs[k]=context_inputs[k].cuda()
+        for k in response_inputs:
+            response_inputs[k]=response_inputs[k].cuda()
+
+        # Get embeddings from the model
+        with torch.no_grad():
+            context_embedding = model(**context_inputs).last_hidden_state.mean(dim=1)
+            response_embedding = model(**response_inputs).last_hidden_state.mean(dim=1)
+
+        # Compute cosine similarity
+        similarity = cosine_similarity(context_embedding.cpu().numpy(), response_embedding.cpu().numpy())
+        return similarity[0][0]
+
+    bert_model_name = "bert-base-uncased"
+    bert_model = AutoModel.from_pretrained(bert_model_name).cuda()
+    bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
+    similarity = cosine_similarity_context_response(" ".join(total_response_arr[:-1]), total_response_arr[-1], bert_model, bert_tokenizer)
+    return (f"Cosine Similarity: {similarity*100:.2f}"+"\n")
+
+def DialoGPT_perplexity(user_utterance, response):
+    # import pdb;pdb.set_trace()
+    def evaluate_response_with_dialoGPT(context, response, model, tokenizer):
+        """
+        Evaluate the appropriateness of a response based on the given context using DialoGPT.
+
+        Args:
+            context (str): The dialogue context (previous conversation).
+            response (str): The generated response to evaluate.
+            model: Pre-trained DialoGPT model.
+            tokenizer: Corresponding tokenizer for the DialoGPT model.
+
+        Returns:
+            float: Perplexity score of the response given the context.
+        """
+        model.eval()
+
+        # Combine context and response as input
+        input_text = context + tokenizer.eos_token + response + tokenizer.eos_token
+        inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
+        inputs['input_ids']=inputs['input_ids'].cuda()
+        inputs['attention_mask']=inputs['attention_mask'].cuda()
+        # import pdb;pdb.set_trace()
+
+        # Compute model outputs and loss
+        with torch.no_grad():
+            outputs = model(**inputs, labels=inputs["input_ids"].cuda())
+            loss = outputs.loss
+
+        # Calculate perplexity
+        perplexity = torch.exp(loss)
+        return perplexity.cpu().item()
+
+    # Load DialoGPT model and tokenizer
+    model_name = "microsoft/DialoGPT-medium"  # Choose small/medium/large based on your resources
+    model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    perplexity = evaluate_response_with_dialoGPT(user_utterance, response, model, tokenizer)
+    return (f"DialoGPT Perplexity: {perplexity:.2f}"+"\n")
+
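In app.py, the "Text Dialog Metrics" option simply concatenates the strings returned by these four helpers. A minimal standalone sketch on toy strings; it assumes the gpt2, bert-base-uncased, and DialoGPT-medium checkpoints can be downloaded on first use and that a CUDA device is available (bert_score and DialoGPT_perplexity call .cuda()):

from eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_perplexity

# Toy conversation state, mirroring the arrays app.py accumulates per session.
LLM_response_arr = ["Sure, I can help with that.", "The weather today is sunny."]
total_response_arr = ["What is the weather like?", "The weather today is sunny."]

report = (
    perplexity(LLM_response_arr[-1])
    + vert(LLM_response_arr)
    + bert_score(total_response_arr)
    + DialoGPT_perplexity(total_response_arr[0], total_response_arr[1])
)
print(report)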
eval/TTS_intelligibility.py  ADDED

@@ -0,0 +1,58 @@
+from utils import int2float
+def handle_espnet_TTS_intelligibility(TTS_audio_output,LLM_Output):
+    from versa import espnet_levenshtein_metric, espnet_wer_setup, owsm_levenshtein_metric, owsm_wer_setup, whisper_levenshtein_metric, whisper_wer_setup
+    score_modules_espnet = {
+        "module": espnet_levenshtein_metric,
+        "args": espnet_wer_setup(
+            model_tag="default",
+            beam_size=1,
+            text_cleaner="whisper_en",
+            use_gpu=True,
+        ),
+    }
+    # import pdb;pdb.set_trace()
+    dict1=score_modules_espnet["module"](
+        score_modules_espnet["args"],
+        int2float(TTS_audio_output[1]),
+        LLM_Output,
+        TTS_audio_output[0],
+    )
+    espnet_wer=(dict1["espnet_wer_delete"]+dict1["espnet_wer_insert"]+dict1["espnet_wer_replace"])/(dict1["espnet_wer_delete"]+dict1["espnet_wer_replace"]+dict1["espnet_wer_equal"])
+    espnet_cer=(dict1["espnet_cer_delete"]+dict1["espnet_cer_insert"]+dict1["espnet_cer_replace"])/(dict1["espnet_cer_delete"]+dict1["espnet_cer_replace"]+dict1["espnet_cer_equal"])
+    score_modules_owsm = {
+        "module": owsm_levenshtein_metric,
+        "args": owsm_wer_setup(
+            model_tag="default",
+            beam_size=1,
+            text_cleaner="whisper_en",
+            use_gpu=True,
+        ),
+    }
+    # import pdb;pdb.set_trace()
+    dict1=score_modules_owsm["module"](
+        score_modules_owsm["args"],
+        int2float(TTS_audio_output[1]),
+        LLM_Output,
+        TTS_audio_output[0],
+    )
+    owsm_wer=(dict1["owsm_wer_delete"]+dict1["owsm_wer_insert"]+dict1["owsm_wer_replace"])/(dict1["owsm_wer_delete"]+dict1["owsm_wer_replace"]+dict1["owsm_wer_equal"])
+    owsm_cer=(dict1["owsm_cer_delete"]+dict1["owsm_cer_insert"]+dict1["owsm_cer_replace"])/(dict1["owsm_cer_delete"]+dict1["owsm_cer_replace"]+dict1["owsm_cer_equal"])
+    score_modules_whisper = {
+        "module": whisper_levenshtein_metric,
+        "args": whisper_wer_setup(
+            model_tag="default",
+            beam_size=1,
+            text_cleaner="whisper_en",
+            use_gpu=True,
+        ),
+    }
+    # import pdb;pdb.set_trace()
+    dict1=score_modules_whisper["module"](
+        score_modules_whisper["args"],
+        int2float(TTS_audio_output[1]),
+        LLM_Output,
+        TTS_audio_output[0],
+    )
+    whisper_wer=(dict1["whisper_wer_delete"]+dict1["whisper_wer_insert"]+dict1["whisper_wer_replace"])/(dict1["whisper_wer_delete"]+dict1["whisper_wer_replace"]+dict1["whisper_wer_equal"])
+    whisper_cer=(dict1["whisper_cer_delete"]+dict1["whisper_cer_insert"]+dict1["whisper_cer_replace"])/(dict1["whisper_cer_delete"]+dict1["whisper_cer_replace"]+dict1["whisper_cer_equal"])
+    return f"ESPnet WER: {espnet_wer*100:.2f}\nESPnet CER: {espnet_cer*100:.2f}\nOWSM WER: {owsm_wer*100:.2f}\nOWSM CER: {owsm_cer*100:.2f}\nWhisper WER: {whisper_wer*100:.2f}\nWhisper CER: {whisper_cer*100:.2f}"
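Here the TTS waveform is re-transcribed by versa's ESPnet, OWSM, and Whisper recognizers and compared against the LLM text, so a low WER means the synthesized speech is intelligible. The audio argument follows Gradio's convention of a (sample_rate, int16 array) tuple, which is why int2float rescales by 1/32768 first. A hedged usage sketch against the bundled sample; the reference string is a placeholder, and versa plus a GPU are assumed:

import soundfile as sf
from eval.TTS_intelligibility import handle_espnet_TTS_intelligibility

# Read as int16 so the (rate, samples) tuple matches what Gradio hands to app.py.
audio, rate = sf.read("tts_samples/sample1.wav", dtype="int16")
print(handle_espnet_TTS_intelligibility((rate, audio), "text the TTS was asked to speak"))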
eval/TTS_speech_quality.py  ADDED

@@ -0,0 +1,42 @@
+from utils import int2float
+def TTS_psuedomos(TTS_audio_output):
+    from versa import pseudo_mos_metric, pseudo_mos_setup, sheet_ssqa, sheet_ssqa_setup
+
+    predictor_dict, predictor_fs = pseudo_mos_setup(
+        use_gpu=True,
+        predictor_types=["utmos", "dnsmos", "plcmos"],
+        predictor_args={"utmos":{"fs": 16000},"dnsmos":{"fs": 16000},"plcmos":{"fs": 16000}},
+    )
+    score_modules = {
+        "module": pseudo_mos_metric,
+        "args": {
+            "predictor_dict": predictor_dict,
+            "predictor_fs": predictor_fs,
+            "use_gpu": True,
+        },
+    }
+    dict1=score_modules["module"](
+        int2float(TTS_audio_output[1]),
+        TTS_audio_output[0],
+        **score_modules["args"],
+    )
+    str1=""
+    for k in dict1:
+        str1=str1+f"{k}: {dict1[k]:.2f}\n"
+    sheet_model = sheet_ssqa_setup(
+        model_tag="default",
+        model_path=None,
+        model_config=None,
+        use_gpu=True,
+    )
+    score_modules = {
+        "module": sheet_ssqa,
+        "args": {"model": sheet_model, "use_gpu": True},
+    }
+    dict1 = score_modules["module"](
+        score_modules["args"]["model"], int2float(TTS_audio_output[1]), TTS_audio_output[0],
+        use_gpu=score_modules["args"]["use_gpu"]
+    )
+    for k in dict1:
+        str1=str1+f"{k}: {dict1[k]:.2f}\n"
+    return str1
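TTS_psuedomos returns one "metric: value" line per reference-free quality predictor (the UTMOS/DNSMOS/PLCMOS pseudo-MOS scores, followed by the SHEET SSQA score). A usage sketch under the same assumptions as above (versa installed, GPU available, Gradio-style int16 audio tuple):

import soundfile as sf
from eval.TTS_speech_quality import TTS_psuedomos

audio, rate = sf.read("tts_samples/sample1.wav", dtype="int16")
print(TTS_psuedomos((rate, audio)))  # prints one "metric: value" line per predictor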
eval/vert.py  ADDED

@@ -0,0 +1,272 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import nltk
+import math
+import sys
+from fractions import Fraction
+import warnings
+from collections import Counter
+from nltk.translate.bleu_score import modified_precision, closest_ref_length, brevity_penalty, SmoothingFunction
+import warnings
+def corpus_bleu(
+    list_of_references,
+    hypotheses,
+    weights=(0.25, 0.25, 0.25, 0.25),
+    smoothing_function=None,
+    auto_reweigh=False,
+    averaging_mode="geometric",
+    no_length_penalty=False
+):
+    """
+    Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all
+    the hypotheses and their respective references.
+
+    Instead of averaging the sentence level BLEU scores (i.e. marco-average
+    precision), the original BLEU metric (Papineni et al. 2002) accounts for
+    the micro-average precision (i.e. summing the numerators and denominators
+    for each hypothesis-reference(s) pairs before the division).
+
+    >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
+    ...         'ensures', 'that', 'the', 'military', 'always',
+    ...         'obeys', 'the', 'commands', 'of', 'the', 'party']
+    >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
+    ...          'ensures', 'that', 'the', 'military', 'will', 'forever',
+    ...          'heed', 'Party', 'commands']
+    >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which',
+    ...          'guarantees', 'the', 'military', 'forces', 'always',
+    ...          'being', 'under', 'the', 'command', 'of', 'the', 'Party']
+    >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
+    ...          'army', 'always', 'to', 'heed', 'the', 'directions',
+    ...          'of', 'the', 'party']
+
+    >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
+    ...         'interested', 'in', 'world', 'history']
+    >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
+    ...          'because', 'he', 'read', 'the', 'book']
+
+    >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]]
+    >>> hypotheses = [hyp1, hyp2]
+    >>> corpus_bleu(list_of_references, hypotheses)  # doctest: +ELLIPSIS
+    0.5920...
+
+    The example below show that corpus_bleu() is different from averaging
+    sentence_bleu() for hypotheses
+
+    >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
+    >>> score2 = sentence_bleu([ref2a], hyp2)
+    >>> (score1 + score2) / 2  # doctest: +ELLIPSIS
+    0.6223...
+
+    :param list_of_references: a corpus of lists of reference sentences, w.r.t. hypotheses
+    :type list_of_references: list(list(list(str)))
+    :param hypotheses: a list of hypothesis sentences
+    :type hypotheses: list(list(str))
+    :param weights: weights for unigrams, bigrams, trigrams and so on
+    :type weights: list(float)
+    :param smoothing_function:
+    :type smoothing_function: SmoothingFunction
+    :param auto_reweigh: Option to re-normalize the weights uniformly.
+    :type auto_reweigh: bool
+    :return: The corpus-level BLEU score.
+    :rtype: float
+    """
+    # Before proceeding to compute BLEU, perform sanity checks.
+
+    p_numerators = Counter()  # Key = ngram order, and value = no. of ngram matches.
+    p_denominators = Counter()  # Key = ngram order, and value = no. of ngram in ref.
+    hyp_lengths, ref_lengths = 0, 0
+
+    assert len(list_of_references) == len(hypotheses), (
+        "The number of hypotheses and their reference(s) should be the " "same "
+    )
+
+    # Iterate through each hypothesis and their corresponding references.
+    for references, hypothesis in zip(list_of_references, hypotheses):
+        # For each order of ngram, calculate the numerator and
+        # denominator for the corpus-level modified precision.
+        for i, _ in enumerate(weights, start=1):
+            p_i = modified_precision(references, hypothesis, i)
+            p_numerators[i] += p_i.numerator
+            p_denominators[i] += p_i.denominator
+
+        # Calculate the hypothesis length and the closest reference length.
+        # Adds them to the corpus-level hypothesis and reference counts.
+        hyp_len = len(hypothesis)
+        hyp_lengths += hyp_len
+        ref_lengths += closest_ref_length(references, hyp_len)
+
+    # Calculate corpus-level brevity penalty.
+    if no_length_penalty and averaging_mode == 'geometric':
+        bp = 1.0
+    elif no_length_penalty and averaging_mode == 'arithmetic':
+        bp = 0.0
+    else:
+        assert not no_length_penalty
+        assert averaging_mode != 'arithmetic', 'Not sure how to apply length penalty when aurithmetic mode'
+        bp = brevity_penalty(ref_lengths, hyp_lengths)
+
+    # Uniformly re-weighting based on maximum hypothesis lengths if largest
+    # order of n-grams < 4 and weights is set at default.
+    if auto_reweigh:
+        if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
+            weights = (1 / hyp_lengths,) * hyp_lengths
+
+    # Collects the various precision values for the different ngram orders.
+    p_n = [
+        Fraction(p_numerators[i], p_denominators[i], _normalize=False)
+        for i, _ in enumerate(weights, start=1)
+    ]
+
+    # Returns 0 if there's no matching n-grams
+    # We only need to check for p_numerators[1] == 0, since if there's
+    # no unigrams, there won't be any higher order ngrams.
+    if p_numerators[1] == 0:
+        return 0
+
+    # If there's no smoothing, set use method0 from SmoothinFunction class.
+    if not smoothing_function:
+        smoothing_function = SmoothingFunction().method0
+    # Smoothen the modified precision.
+    # Note: smoothing_function() may convert values into floats;
+    # it tries to retain the Fraction object as much as the
+    # smoothing method allows.
+    p_n = smoothing_function(
+        p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths
+    )
+
+    if averaging_mode == "geometric":
+        s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
+        s = bp * math.exp(math.fsum(s))
+    elif averaging_mode == "arithmetic":
+        s = (w_i * p_i for w_i, p_i in zip(weights, p_n))
+        s = math.fsum(s)
+
+    return s
+
+
+def sentence_bleu(
+    references,
+    hypothesis,
+    weights=(0.25, 0.25, 0.25, 0.25),
+    smoothing_function=None,
+    auto_reweigh=False,
+    averaging_mode="geometric",
+    no_length_penalty=False
+):
+    return corpus_bleu(
+        [references], [hypothesis], weights, smoothing_function, auto_reweigh, averaging_mode, no_length_penalty
+    )
+
+def get_target_sequences(manifest, ground_truth, to_take=1000):
+    import json
+    import pathlib
+
+    with open(ground_truth, 'r') as fin:
+        original_continuations = json.loads(fin.read())
+
+    sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
+    assert all(float(v) >= 6.0 for (_, v) in sequence2length)  # 6 seconds
+
+    sequence2length.sort(key=lambda x: x[1])
+    to_take_sequences = set(v[0] for v in sequence2length[:to_take])
+    to_take_ids = []
+
+    with open(manifest, 'r') as f:
+        f.readline()
+
+        for i, line in enumerate(f.readlines()):
+            seq_id = line.split()[0]
+            seq_id = pathlib.Path(seq_id).name.split('__')[0]
+
+            if seq_id in to_take_sequences:
+                to_take_ids.append(i)
+
+    print(f'Took {len(to_take_ids)} ids')
+    return set(to_take_ids)
+
+def get_self_bleu(utterances, averaging_mode, weights):
+    self_bleu = []
+
+    for i in range(len(utterances)):
+        hypo = utterances[i]
+        rest = utterances[:i] + utterances[i+1:]
+
+        self_bleu.append(sentence_bleu(rest, hypo, weights,
+                                       no_length_penalty=True, averaging_mode=averaging_mode))
+
+    return self_bleu
+
+
+def get_self_bleu2_arithmetic(utterances):
+    weights = (0.5, 0.5)  # equal weight for unigrams and bigrams
+    return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights)
+
+
+def get_self_bleu2_geometric(utterances):
+    weights = (0.5, 0.5)
+    return get_self_bleu(utterances, averaging_mode='geometric', weights=weights)
+
+
+def get_auto_bleu2_arithmetic(utterances):
+    weights = (0.5, 0.5)
+    return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances]
+
+
+def get_auto_bleu2_geometric(utterances):
+    weights = (0.5, 0.5)
+    return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances]
+
+
+def get_auto_bleu3_geometric(utterances):
+    weights = (1./3, 1./3, 1./3)
+    return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances]
+
+
+def get_auto_bleu3_arithmetic(utterances):
+    weights = (1./3, 1./3, 1./3)
+    return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances]
+
+
+def get_self_bleu3_arithmetic(utterances):
+    weights = (1./3, 1./3, 1./3)
+    return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights)
+
+
+def get_self_bleu3_geometric(utterances):
+    weights = (1./3, 1./3, 1./3)
+    return get_self_bleu(utterances, averaging_mode='geometric', weights=weights)
+
+
+def auto_bleu(sentence, weights, mean_mode='arithmetic'):
+    if len(sentence) <= 1:
+        return 0
+
+    N = len(weights)
+
+    bleu_n = np.zeros([N])
+    for n in range(N):
+        targ_ngrams = list(nltk.ngrams(sentence, n+1))
+        for p in range(len(targ_ngrams)):
+            left = sentence[:p]
+            right = sentence[(p+n+1):]
+            rest_ngrams = list(nltk.ngrams(left, n+1)) + \
+                list(nltk.ngrams(right, n+1))
+            # compute the nb of matching ngrams
+            bleu_n[n] += targ_ngrams[p] in rest_ngrams
+        bleu_n[n] /= len(targ_ngrams)  # average them to get a proportion
+
+    weights = np.array(weights)
+    if mean_mode == 'arithmetic':
+        return (bleu_n * weights).sum()
+    elif mean_mode == 'geometric':
+        return (bleu_n ** weights).prod()
+    else:
+        raise ValueError(f'Unknown agggregation mode {mean_mode}')
+
+def run_f(task_params):
+    f, terms = task_params
+    return f(terms)
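vert() in eval/LLM_Metrics.py feeds whitespace-tokenized responses into the two geometric BLEU helpers: self-BLEU2 scores each utterance against all the others (cross-response diversity), while auto-BLEU2 scores each utterance against the rest of itself with the target n-gram removed (within-response repetition). A small sketch on toy utterances, assuming only that nltk is installed:

import numpy as np
from eval.vert import get_self_bleu2_geometric, get_auto_bleu2_geometric

# Whitespace-tokenized utterances, the same shape vert() builds with x.strip().split().
utterances = [
    "the cat sat on the mat".split(),
    "the dog sat on the rug".split(),
    "a bird flew over the mat".split(),
]
print(np.mean(get_self_bleu2_geometric(utterances)))  # cross-utterance diversity
print(np.mean(get_auto_bleu2_geometric(utterances)))  # within-utterance repetition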
requirements.txt  CHANGED

@@ -12,4 +12,5 @@ sounddevice==0.5.0
 webrtcvad-wheels
 webrtcvad==2.0.10
 gradio==4.43.0
-ChatTTS
+ChatTTS
+evaluate
tts_samples/sample1.wav  ADDED

Binary file (414 kB).
utils.py  ADDED

@@ -0,0 +1,12 @@
+import numpy as np
+def int2float(sound):
+    """
+    Taken from https://github.com/snakers4/silero-vad
+    """
+
+    abs_max = np.abs(sound).max()
+    sound = sound.astype("float32")
+    if abs_max > 0:
+        sound *= 1 / 32768
+    sound = sound.squeeze()  # depends on the use case
+    return sound
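int2float is what lets every eval helper accept Gradio's int16 microphone and TTS buffers: it rescales them to float32 in roughly [-1, 1]. A quick sanity-check sketch on a synthetic tone:

import numpy as np
from utils import int2float

# Gradio streams audio as int16; int2float rescales it to float32 in roughly [-1, 1].
chunk = (np.sin(np.linspace(0, 2 * np.pi, 16000)) * 32767).astype(np.int16)
out = int2float(chunk)
print(out.dtype, float(np.abs(out).max()))  # float32, ~1.0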
versa.sh  ADDED

@@ -0,0 +1,4 @@
+git clone https://github.com/shinjiwlab/versa.git
+cd versa
+pip install .
+cd ..