Siddhant committed on
Commit f78ed8b
1 Parent(s): 3168a2f

Add eval metrics

app.py CHANGED
@@ -1,7 +1,12 @@
1
  import os
2
  import shutil
3
- import os
4
  from huggingface_hub import HfApi
 
5
  api = HfApi()
6
  import nltk
7
  nltk.download('averaged_perceptron_tagger_eng')
@@ -29,7 +34,13 @@ text2speech = Text2Speech.from_pretrained(
29
  noise_scale=0.333,
30
  noise_scale_dur=0.333,
31
  )
32
-
33
  import numpy as np
34
  from VAD.vad_iterator import VADIterator
35
  import torch
@@ -66,6 +77,8 @@ import soundfile as sf
66
  import kaldiio
67
  from espnet2.bin.s2t_inference_ctc import Speech2TextGreedySearch
68
from espnet2.bin.s2t_inference import Speech2Text
69
 
70
  s2t = Speech2TextGreedySearch.from_pretrained(
71
  "pyf98/owsm_ctc_v3.1_1B",
@@ -74,6 +87,9 @@ s2t = Speech2TextGreedySearch.from_pretrained(
74
  lang_sym='<eng>',
75
  task_sym='<asr>',
76
)
77
 
78
  start_event = torch.cuda.Event(enable_timing=True)
79
  end_event = torch.cuda.Event(enable_timing=True)
@@ -84,23 +100,13 @@ res = s2t(speech)
84
  end_event.record()
85
  torch.cuda.synchronize()
86
 
87
- def int2float(sound):
88
- """
89
- Taken from https://github.com/snakers4/silero-vad
90
- """
91
-
92
- abs_max = np.abs(sound).max()
93
- sound = sound.astype("float32")
94
- if abs_max > 0:
95
- sound *= 1 / 32768
96
- sound = sound.squeeze() # depends on the use case
97
- return sound
98
-
99
  text_str=""
100
  asr_output_str=""
101
  vad_output=None
102
  audio_output = None
103
audio_output1 = None
104
  min_speech_ms=500
105
  max_speech_ms=float("inf")
106
  # ASR_model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)
@@ -239,8 +245,7 @@ def relevant_vote4_last_response(
239
  import json
240
  import time
241
 
242
-
243
- def transcribe(stream, new_chunk, option):
244
  sr, y = new_chunk
245
  global text_str
246
  global chat
@@ -252,6 +257,11 @@ def transcribe(stream, new_chunk, option):
252
  global start_record_time
253
  global sids
254
global spembs
255
  if stream is None:
256
  stream=y
257
  chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
@@ -295,11 +305,14 @@ def transcribe(stream, new_chunk, option):
295
  array = torch.cat(vad_output).cpu().numpy()
296
  duration_ms = len(array) / sr * 1000
297
  if (not(duration_ms < min_speech_ms or duration_ms > max_speech_ms)):
298
- print(len(array))
299
- array = librosa.util.fix_length(array, size=(16000 * 30))
300
- print(len(array))
301
  start_time = time.time()
302
- prompt=" ".join(s2t(array)[0][0].split()[1:])
303
  vad_output = None
304
  if len(prompt.strip().split())<2:
305
  text_str1=text_str
@@ -310,8 +323,11 @@ def transcribe(stream, new_chunk, option):
310
  print(len(prompt.strip().split()))
311
  print(prompt)
312
  asr_output_str=prompt
 
313
  # yield (stream,asr_output_str,text_str, audio_output)
314
- print("--- %s seconds ---" % (time.time() - start_time))
315
  # prompt=ASR_model.transcribe(array)["text"].strip()
316
  chat.append({"role": user_role, "content": prompt})
317
  chat_messages = chat.to_list()
@@ -322,14 +338,16 @@ def transcribe(stream, new_chunk, option):
322
  temperature=0.0,
323
  do_sample=False,
324
  )
325
-
326
- print("--- %s seconds ---" % (time.time() - start_time))
327
  generated_text = output[0]['generated_text'][-1]["content"]
328
 
329
  # torch.mps.empty_cache()
330
 
331
  chat.append({"role": "assistant", "content": generated_text})
332
text_str=generated_text
333
  # import pdb;pdb.set_trace()
334
  with torch.no_grad():
335
  if option=="ChatTTS":
@@ -347,7 +365,7 @@ def transcribe(stream, new_chunk, option):
347
  audio_output=(text2speech.fs, audio_chunk)
348
  audio_output1=(orig_sr,stream)
349
  stream=y
350
- print("--- %s seconds ---" % (time.time() - start_time))
351
  # else:
352
  # audio_output=None
353
  text_str1=text_str
@@ -362,20 +380,22 @@ def transcribe(stream, new_chunk, option):
362
  if current_record_time-start_record_time>300:
363
  gr.Info("Conversations are limited to 5 minutes. The session will restart in approximately 60 seconds. Please wait for the demo to reset. Close this message once you have read it.", duration=None)
364
  yield stream,gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False),gr.Audio(visible=False)
365
- # api.upload_folder(
366
- # folder_path="flagged_data_points",
367
- # path_in_repo="checkpoint_"+str(start_record_time),
368
- # repo_id="Siddhant/Cascaded_demo_data",
369
- # repo_type="dataset",
370
- # token=access_token,
371
- # # ignore_patterns="**/logs/*.txt", # Ignore all text logs
372
- # )
373
  chat.buffer=[{"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."}]
374
  text_str=""
375
  audio_output = None
376
  audio_output1 = None
377
  asr_output_str = ""
378
start_record_time = None
379
  shutil.rmtree('flagged_data_points')
380
  os.mkdir("flagged_data_points")
381
  yield (stream,asr_output_str,text_str1, audio_output, audio_output1)
@@ -486,6 +506,8 @@ def handle_LLM_selection(option):
486
 
487
  def handle_ASR_selection(option):
488
yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
489
  global s2t
490
  if option=="espnet/owsm_v3.1_ebf":
491
  s2t = Speech2Text.from_pretrained(
@@ -495,6 +517,14 @@ def handle_ASR_selection(option):
495
  beam_size=1,
496
  predict_time=False,
497
)
498
  else:
499
  s2t = Speech2TextGreedySearch.from_pretrained(
500
  option,
@@ -508,17 +538,36 @@ def handle_ASR_selection(option):
508
  end_event = torch.cuda.Event(enable_timing=True)
509
  torch.cuda.synchronize()
510
  start_event.record()
511
- speech = librosa.util.fix_length(dummy_input, size=(16000 * 30))
512
- res = s2t(speech)
513
  end_event.record()
514
  torch.cuda.synchronize()
515
  yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
516
- # demo = gr.Interface(
517
- # transcribe,
518
- # ["state", gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))],
519
- # ["state", gr.Textbox(label="ASR output"),gr.Textbox(label="LLM output"), gr.Audio(label="Output", autoplay=True)],
520
- # live=True,
521
- # )
522
  with gr.Blocks(
523
  title="E2E Spoken Dialog System",
524
  ) as demo:
@@ -527,7 +576,7 @@ with gr.Blocks(
527
  user_audio = gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
528
  with gr.Row():
529
  ASR_radio = gr.Radio(
530
- choices=["pyf98/owsm_ctc_v3.1_1B", "espnet/owsm_ctc_v3.2_ft_1B", "espnet/owsm_v3.1_ebf"],
531
  label="Choose ASR:",
532
  value="pyf98/owsm_ctc_v3.1_1B",
533
  )
@@ -574,6 +623,15 @@ with gr.Blocks(
574
  output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False)
575
  output_asr_text = gr.Textbox(label="ASR output")
576
output_text = gr.Textbox(label="LLM output")
577
  state = gr.State()
578
  with gr.Row():
579
  privacy_text = gr.Textbox(label="Privacy Notice",interactive=False, value="By using this demo, you acknowledge that interactions with this dialog system are collected for research and improvement purposes. The data will only be used to enhance the performance and understanding of the system. If you have any concerns about data collection, please discontinue use.")
@@ -604,10 +662,11 @@ with gr.Blocks(
604
  diversity_response = gr.Textbox(label="diversity_response",visible=False,interactive=False)
605
  ip_address = gr.Textbox(label="ip_address",visible=False,interactive=False)
606
  callback.setup([user_audio, output_asr_text, output_text, output_audio,output_audio1,natural_response,diversity_response,ip_address],"flagged_data_points")
607
- user_audio.stream(transcribe, inputs=[state, user_audio, radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
608
  radio.change(fn=handle_selection, inputs=[radio], outputs=[output_asr_text, output_text, output_audio])
609
  LLM_radio.change(fn=handle_LLM_selection, inputs=[LLM_radio], outputs=[output_asr_text, output_text, output_audio])
610
  ASR_radio.change(fn=handle_ASR_selection, inputs=[ASR_radio], outputs=[output_asr_text, output_text, output_audio])
 
611
  output_audio.play(
612
  flash_buttons, [], [natural_response,diversity_response]+btn_list
613
  ).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1], None,preprocess=False)
 
1
  import os
2
  import shutil
3
+ import soundfile
4
+ from eval.TTS_intelligibility import handle_espnet_TTS_intelligibility
5
+ from eval.ASR_WER import handle_espnet_ASR_WER
6
+ from eval.TTS_speech_quality import TTS_psuedomos
7
+ from eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_perplexity
8
  from huggingface_hub import HfApi
9
+ from utils import int2float
10
  api = HfApi()
11
  import nltk
12
  nltk.download('averaged_perceptron_tagger_eng')
 
34
  noise_scale=0.333,
35
  noise_scale_dur=0.333,
36
  )
37
+ try:
38
+ import versa
39
+ except ImportError:
40
+ from subprocess import call
41
+ with open('versa.sh', 'rb') as file:
42
+ script = file.read()
43
+ rc = call(script, shell=True)
44
  import numpy as np
45
  from VAD.vad_iterator import VADIterator
46
  import torch
 
77
  import kaldiio
78
  from espnet2.bin.s2t_inference_ctc import Speech2TextGreedySearch
79
  from espnet2.bin.s2t_inference import Speech2Text
80
+ from espnet2.bin.asr_inference import Speech2Text as S2T_ASR
81
+ import whisper
82
 
83
  s2t = Speech2TextGreedySearch.from_pretrained(
84
  "pyf98/owsm_ctc_v3.1_1B",
 
87
  lang_sym='<eng>',
88
  task_sym='<asr>',
89
  )
90
+ latency_ASR=0.0
91
+ latency_LM=0.0
92
+ latency_TTS=0.0
93
 
94
  start_event = torch.cuda.Event(enable_timing=True)
95
  end_event = torch.cuda.Event(enable_timing=True)
 
100
  end_event.record()
101
  torch.cuda.synchronize()
102

103
  text_str=""
104
  asr_output_str=""
105
  vad_output=None
106
  audio_output = None
107
  audio_output1 = None
108
+ LLM_response_arr=[]
109
+ total_response_arr=[]
110
  min_speech_ms=500
111
  max_speech_ms=float("inf")
112
  # ASR_model = LightningWhisperMLX(model="distil-large-v3", batch_size=6, quant=None)
 
245
  import json
246
  import time
247
 
248
+ def transcribe(stream, new_chunk, option, asr_option):
 
249
  sr, y = new_chunk
250
  global text_str
251
  global chat
 
257
  global start_record_time
258
  global sids
259
  global spembs
260
+ global latency_ASR
261
+ global latency_LM
262
+ global latency_TTS
263
+ global LLM_response_arr
264
+ global total_response_arr
265
  if stream is None:
266
  stream=y
267
  chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
 
305
  array = torch.cat(vad_output).cpu().numpy()
306
  duration_ms = len(array) / sr * 1000
307
if (not(duration_ms < min_speech_ms or duration_ms > max_speech_ms)):
308
  start_time = time.time()
309
+ if asr_option=="whisper":
310
+ prompt=s2t.transcribe(torch.tensor(array).float(), beam_size=1)["text"]
311
+ elif asr_option=="librispeech_asr":
312
+ prompt=s2t(array)[0][0]
313
+ else:
314
+ array = librosa.util.fix_length(array, size=(16000 * 30))
315
+ prompt=" ".join(s2t(array)[0][0].split()[1:])
316
  vad_output = None
317
  if len(prompt.strip().split())<2:
318
  text_str1=text_str
 
323
  print(len(prompt.strip().split()))
324
  print(prompt)
325
  asr_output_str=prompt
326
+ total_response_arr.append(prompt.replace("\n"," "))
327
  # yield (stream,asr_output_str,text_str, audio_output)
328
+ start_LM_time=time.time()
329
+ latency_ASR=(start_LM_time - start_time)
330
+ # print("--- %s seconds ---" % (time.time() - start_time))
331
  # prompt=ASR_model.transcribe(array)["text"].strip()
332
  chat.append({"role": user_role, "content": prompt})
333
  chat_messages = chat.to_list()
 
338
  temperature=0.0,
339
  do_sample=False,
340
  )
341
+ start_TTS_time=time.time()
342
+ latency_LM=(start_TTS_time - start_LM_time)
343
  generated_text = output[0]['generated_text'][-1]["content"]
344
 
345
  # torch.mps.empty_cache()
346
 
347
  chat.append({"role": "assistant", "content": generated_text})
348
  text_str=generated_text
349
+ LLM_response_arr.append(text_str.replace("\n"," "))
350
+ total_response_arr.append(text_str.replace("\n"," "))
351
  # import pdb;pdb.set_trace()
352
  with torch.no_grad():
353
  if option=="ChatTTS":
 
365
  audio_output=(text2speech.fs, audio_chunk)
366
  audio_output1=(orig_sr,stream)
367
  stream=y
368
+ latency_TTS=(time.time() - start_TTS_time)
369
  # else:
370
  # audio_output=None
371
  text_str1=text_str
 
380
  if current_record_time-start_record_time>300:
381
  gr.Info("Conversations are limited to 5 minutes. The session will restart in approximately 60 seconds. Please wait for the demo to reset. Close this message once you have read it.", duration=None)
382
  yield stream,gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False),gr.Audio(visible=False)
383
+ api.upload_folder(
384
+ folder_path="flagged_data_points",
385
+ path_in_repo="checkpoint_"+str(start_record_time),
386
+ repo_id="Siddhant/Cascaded_demo_data",
387
+ repo_type="dataset",
388
+ token=access_token,
389
+ # ignore_patterns="**/logs/*.txt", # Ignore all text logs
390
+ )
391
  chat.buffer=[{"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."}]
392
  text_str=""
393
  audio_output = None
394
  audio_output1 = None
395
  asr_output_str = ""
396
  start_record_time = None
397
+ LLM_response_arr=[]
398
+ total_response_arr=[]
399
  shutil.rmtree('flagged_data_points')
400
  os.mkdir("flagged_data_points")
401
  yield (stream,asr_output_str,text_str1, audio_output, audio_output1)
 
506
 
507
  def handle_ASR_selection(option):
508
  yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
509
+ if option=="librispeech_asr":
510
+ option="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp"
511
  global s2t
512
  if option=="espnet/owsm_v3.1_ebf":
513
  s2t = Speech2Text.from_pretrained(
 
517
  beam_size=1,
518
  predict_time=False,
519
  )
520
+ elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
521
+ s2t = S2T_ASR.from_pretrained(
522
+ model_tag=option,
523
+ device="cuda",
524
+ beam_size=1,
525
+ )
526
+ elif option=="whisper":
527
+ s2t = whisper.load_model("large", device="cuda")
528
  else:
529
  s2t = Speech2TextGreedySearch.from_pretrained(
530
  option,
 
538
  end_event = torch.cuda.Event(enable_timing=True)
539
  torch.cuda.synchronize()
540
  start_event.record()
541
+
542
+ if option=="whisper":
543
+ audio, rate = soundfile.read("tts_samples/sample1.wav")
544
+ array=librosa.resample(audio, orig_sr=rate, target_sr=16000)
545
+ res=s2t.transcribe(torch.tensor(array).float(), beam_size=1)["text"]
546
+ elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
547
+ res = s2t(dummy_input)[0][0]
548
+ else:
549
+ speech = librosa.util.fix_length(dummy_input, size=(16000 * 30))
550
+ res = s2t(speech)
551
  end_event.record()
552
  torch.cuda.synchronize()
553
  yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
554
+
555
+ def handle_eval_selection(option, TTS_audio_output, LLM_Output, ASR_audio_output, ASR_transcript):
556
+ global LLM_response_arr
557
+ global total_response_arr
558
+ yield (option,gr.Textbox(visible=True))
559
+ if option=="Latency":
560
+ text=f"ASR Latency: {latency_ASR:.2f}\nLLM Latency: {latency_LM:.2f}\nTTS Latency: {latency_TTS:.2f}"
561
+ yield (None,text)
562
+ elif option=="TTS Intelligibility":
563
+ yield (None,handle_espnet_TTS_intelligibility(TTS_audio_output,LLM_Output))
564
+ elif option=="TTS Speech Quality":
565
+ yield (None,TTS_psuedomos(TTS_audio_output))
566
+ elif option=="ASR WER":
567
+ yield (None,handle_espnet_ASR_WER(ASR_audio_output, ASR_transcript))
568
+ elif option=="Text Dialog Metrics":
569
+ yield (None,perplexity(LLM_Output.replace("\n"," "))+vert(LLM_response_arr)+bert_score(total_response_arr)+DialoGPT_perplexity(ASR_transcript.replace("\n"," "),LLM_Output.replace("\n"," ")))
570
+
571
  with gr.Blocks(
572
  title="E2E Spoken Dialog System",
573
  ) as demo:
 
576
  user_audio = gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
577
  with gr.Row():
578
  ASR_radio = gr.Radio(
579
+ choices=["pyf98/owsm_ctc_v3.1_1B", "espnet/owsm_ctc_v3.2_ft_1B", "espnet/owsm_v3.1_ebf", "librispeech_asr", "whisper"],
580
  label="Choose ASR:",
581
  value="pyf98/owsm_ctc_v3.1_1B",
582
  )
 
623
  output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False)
624
  output_asr_text = gr.Textbox(label="ASR output")
625
  output_text = gr.Textbox(label="LLM output")
626
+ eval_radio = gr.Radio(
627
+ choices=["Latency", "TTS Intelligibility", "TTS Speech Quality", "ASR WER","Text Dialog Metrics"],
628
+ label="Choose Evaluation metrics:",
629
+ )
630
+ # TTS Intelligibility_radio = gr.Radio(
631
+ # choices=["ESPnet", "TTS Intelligibility", "TTS Speech Quality"],
632
+ # label="Choose ASR model:",
633
+ # )
634
+ output_eval_text = gr.Textbox(label="Evaluation Results")
635
  state = gr.State()
636
  with gr.Row():
637
  privacy_text = gr.Textbox(label="Privacy Notice",interactive=False, value="By using this demo, you acknowledge that interactions with this dialog system are collected for research and improvement purposes. The data will only be used to enhance the performance and understanding of the system. If you have any concerns about data collection, please discontinue use.")
 
662
  diversity_response = gr.Textbox(label="diversity_response",visible=False,interactive=False)
663
  ip_address = gr.Textbox(label="ip_address",visible=False,interactive=False)
664
  callback.setup([user_audio, output_asr_text, output_text, output_audio,output_audio1,natural_response,diversity_response,ip_address],"flagged_data_points")
665
+ user_audio.stream(transcribe, inputs=[state, user_audio, radio, ASR_radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
666
  radio.change(fn=handle_selection, inputs=[radio], outputs=[output_asr_text, output_text, output_audio])
667
  LLM_radio.change(fn=handle_LLM_selection, inputs=[LLM_radio], outputs=[output_asr_text, output_text, output_audio])
668
  ASR_radio.change(fn=handle_ASR_selection, inputs=[ASR_radio], outputs=[output_asr_text, output_text, output_audio])
669
+ eval_radio.change(fn=handle_eval_selection, inputs=[eval_radio,output_audio,output_text,output_audio1,output_asr_text], outputs=[eval_radio,output_eval_text])
670
  output_audio.play(
671
  flash_buttons, [], [natural_response,diversity_response]+btn_list
672
  ).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1], None,preprocess=False)
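A minimal, self-contained sketch of the per-stage latency bookkeeping that the updated transcribe() performs (one wall-clock checkpoint per stage); the staged_latencies helper and the time.sleep stand-ins below are illustrative only, not part of the app:

import time

def staged_latencies():
    # One wall-clock checkpoint per stage; the differences are the per-stage
    # latencies, mirroring the latency_ASR / latency_LM / latency_TTS globals.
    t0 = time.time()
    time.sleep(0.01)   # stand-in for ASR decoding
    t1 = time.time()
    time.sleep(0.01)   # stand-in for LLM generation
    t2 = time.time()
    time.sleep(0.01)   # stand-in for TTS synthesis
    t3 = time.time()
    return t1 - t0, t2 - t1, t3 - t2

latency_ASR, latency_LM, latency_TTS = staged_latencies()
print(f"ASR Latency: {latency_ASR:.2f}\nLLM Latency: {latency_LM:.2f}\nTTS Latency: {latency_TTS:.2f}")
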
eval/ASR_WER.py ADDED
@@ -0,0 +1,55 @@
1
+ from utils import int2float
2
+ def handle_espnet_ASR_WER(ASR_audio_output,ASR_transcript):
3
+ from versa import espnet_levenshtein_metric, espnet_wer_setup, owsm_levenshtein_metric, owsm_wer_setup, whisper_levenshtein_metric, whisper_wer_setup
4
+ score_modules_espnet = {
5
+ "module": espnet_levenshtein_metric,
6
+ "args": espnet_wer_setup(
7
+ model_tag="default",
8
+ beam_size=1,
9
+ text_cleaner="whisper_en",
10
+ use_gpu=True,
11
+ ),
12
+ }
13
+ dict1=score_modules_espnet["module"](
14
+ score_modules_espnet["args"],
15
+ int2float(ASR_audio_output[1]),
16
+ ASR_transcript,
17
+ ASR_audio_output[0],
18
+ )
19
+ espnet_wer=(dict1["espnet_wer_delete"]+dict1["espnet_wer_insert"]+dict1["espnet_wer_replace"])/(dict1["espnet_wer_insert"]+dict1["espnet_wer_replace"]+dict1["espnet_wer_equal"])
20
+ espnet_cer=(dict1["espnet_cer_delete"]+dict1["espnet_cer_insert"]+dict1["espnet_cer_replace"])/(dict1["espnet_cer_insert"]+dict1["espnet_cer_replace"]+dict1["espnet_cer_equal"])
21
+ score_modules_owsm = {
22
+ "module": owsm_levenshtein_metric,
23
+ "args": owsm_wer_setup(
24
+ model_tag="default",
25
+ beam_size=1,
26
+ text_cleaner="whisper_en",
27
+ use_gpu=True,
28
+ ),
29
+ }
30
+ dict1=score_modules_owsm["module"](
31
+ score_modules_owsm["args"],
32
+ int2float(ASR_audio_output[1]),
33
+ ASR_transcript,
34
+ ASR_audio_output[0],
35
+ )
36
+ owsm_wer=(dict1["owsm_wer_delete"]+dict1["owsm_wer_insert"]+dict1["owsm_wer_replace"])/(dict1["owsm_wer_insert"]+dict1["owsm_wer_replace"]+dict1["owsm_wer_equal"])
37
+ owsm_cer=(dict1["owsm_cer_delete"]+dict1["owsm_cer_insert"]+dict1["owsm_cer_replace"])/(dict1["owsm_cer_insert"]+dict1["owsm_cer_replace"]+dict1["owsm_cer_equal"])
38
+ score_modules_whisper = {
39
+ "module": whisper_levenshtein_metric,
40
+ "args": whisper_wer_setup(
41
+ model_tag="default",
42
+ beam_size=1,
43
+ text_cleaner="whisper_en",
44
+ use_gpu=True,
45
+ ),
46
+ }
47
+ dict1=score_modules_whisper["module"](
48
+ score_modules_whisper["args"],
49
+ int2float(ASR_audio_output[1]),
50
+ ASR_transcript,
51
+ ASR_audio_output[0],
52
+ )
53
+ whisper_wer=(dict1["whisper_wer_delete"]+dict1["whisper_wer_insert"]+dict1["whisper_wer_replace"])/(dict1["whisper_wer_insert"]+dict1["whisper_wer_replace"]+dict1["whisper_wer_equal"])
54
+ whisper_cer=(dict1["whisper_cer_delete"]+dict1["whisper_cer_insert"]+dict1["whisper_cer_replace"])/(dict1["whisper_cer_insert"]+dict1["whisper_cer_replace"]+dict1["whisper_cer_equal"])
55
+ return f"ESPnet WER: {espnet_wer*100:.2f}\nESPnet CER: {espnet_cer*100:.2f}\nOWSM WER: {owsm_wer*100:.2f}\nOWSM CER: {owsm_cer*100:.2f}\nWhisper WER: {whisper_wer*100:.2f}\nWhisper CER: {whisper_cer*100:.2f}"
eval/LLM_Metrics.py ADDED
@@ -0,0 +1,106 @@
1
+ from multiprocessing import Pool
2
+ from eval.vert import get_self_bleu2_geometric, get_auto_bleu2_geometric, run_f
3
+ import numpy as np
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
5
+ import torch
6
+ from sklearn.metrics.pairwise import cosine_similarity
7
+ from scipy.stats import gmean
8
+ def perplexity(LLM_Output):
9
+ import evaluate
10
+ # import pdb;pdb.set_trace()
11
+ perplexity = evaluate.load("perplexity", module_type="metric")
12
+ results = perplexity.compute(model_id='gpt2',predictions=[LLM_Output])
13
+ return f"Perplexity: {results['mean_perplexity']:.2f}\n"
14
+
15
+ def vert(LLM_response_arr):
16
+ # import pdb;pdb.set_trace()
17
+ terms = [x.strip().split() for x in LLM_response_arr]
18
+
19
+
20
+ tasks = [
21
+ ('Self-BLEU2-geometric', get_self_bleu2_geometric),
22
+ ('Auto-BLEU2-geometric', get_auto_bleu2_geometric),
23
+ ]
24
+ n_processes = min(16, len(tasks))
25
+ with Pool(n_processes) as pool:
26
+ metrics = pool.map(run_f, [(t[1], terms) for t in tasks])
27
+ metric_arr=[]
28
+ str1=""
29
+ for (metric_name, _), metric in zip(tasks, metrics):
30
+ metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric))
31
+
32
+ metric, sem = [
33
+ round(100 * x, 2) for x in [metric, sem]
34
+ ]
35
+ metric_arr.append(metric)
36
+
37
+ str1+=(f'{metric_name}: {metric}\n')
38
+ str1+=(f'VERT: {round(gmean(metric_arr), 2)}\n')
39
+ return str1
40
+
41
+ def bert_score(total_response_arr):
42
+ # import pdb;pdb.set_trace()
43
+ def cosine_similarity_context_response(context, response, model, tokenizer):
44
+ # Tokenize and encode both context and response
45
+ context_inputs = tokenizer(context, return_tensors="pt", truncation=True)
46
+ response_inputs = tokenizer(response, return_tensors="pt", truncation=True)
47
+ for k in context_inputs:
48
+ context_inputs[k]=context_inputs[k].cuda()
49
+ for k in response_inputs:
50
+ response_inputs[k]=response_inputs[k].cuda()
51
+
52
+ # Get embeddings from the model
53
+ with torch.no_grad():
54
+ context_embedding = model(**context_inputs).last_hidden_state.mean(dim=1)
55
+ response_embedding = model(**response_inputs).last_hidden_state.mean(dim=1)
56
+
57
+ # Compute cosine similarity
58
+ similarity = cosine_similarity(context_embedding.cpu().numpy(), response_embedding.cpu().numpy())
59
+ return similarity[0][0]
60
+
61
+ bert_model_name = "bert-base-uncased"
62
+ bert_model = AutoModel.from_pretrained(bert_model_name).cuda()
63
+ bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
64
+ similarity = cosine_similarity_context_response(" ".join(total_response_arr[:-1]), total_response_arr[-1], bert_model, bert_tokenizer)
65
+ return (f"Cosine Similarity: {similarity*100:.2f}"+"\n")
66
+
67
+ def DialoGPT_perplexity(user_utterance, response):
68
+ # import pdb;pdb.set_trace()
69
+ def evaluate_response_with_dialoGPT(context, response, model, tokenizer):
70
+ """
71
+ Evaluate the appropriateness of a response based on the given context using DialoGPT.
72
+
73
+ Args:
74
+ context (str): The dialogue context (previous conversation).
75
+ response (str): The generated response to evaluate.
76
+ model: Pre-trained DialoGPT model.
77
+ tokenizer: Corresponding tokenizer for the DialoGPT model.
78
+
79
+ Returns:
80
+ float: Perplexity score of the response given the context.
81
+ """
82
+ model.eval()
83
+
84
+ # Combine context and response as input
85
+ input_text = context + tokenizer.eos_token + response + tokenizer.eos_token
86
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
87
+ inputs['input_ids']=inputs['input_ids'].cuda()
88
+ inputs['attention_mask']=inputs['attention_mask'].cuda()
89
+ # import pdb;pdb.set_trace()
90
+
91
+ # Compute model outputs and loss
92
+ with torch.no_grad():
93
+ outputs = model(**inputs, labels=inputs["input_ids"].cuda())
94
+ loss = outputs.loss
95
+
96
+ # Calculate perplexity
97
+ perplexity = torch.exp(loss)
98
+ return perplexity.cpu().item()
99
+
100
+ # Load DialoGPT model and tokenizer
101
+ model_name = "microsoft/DialoGPT-medium" # Choose small/medium/large based on your resources
102
+ model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
103
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
104
+ perplexity = evaluate_response_with_dialoGPT(user_utterance, response, model, tokenizer)
105
+ return (f"DialoGPT Perplexity: {perplexity:.2f}"+"\n")
106
+
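A sketch of how these text-dialog helpers compose: each returns a formatted string, so handle_eval_selection can simply concatenate them. A CUDA device plus the evaluate, transformers, scikit-learn and scipy dependencies are assumed; the toy dialog below is illustrative:

from eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_perplexity

llm_responses = ["Paris is the capital of France.", "It has about two million residents."]
dialog_history = ["what is the capital of france", llm_responses[0]]

report = (perplexity(llm_responses[-1])            # GPT-2 perplexity of the last reply
          + vert(llm_responses)                    # Self/Auto-BLEU2 diversity plus VERT
          + bert_score(dialog_history)             # BERT cosine similarity, context vs. reply
          + DialoGPT_perplexity(dialog_history[0], dialog_history[1]))
print(report)
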
eval/TTS_intelligibility.py ADDED
@@ -0,0 +1,58 @@
1
+ from utils import int2float
2
+ def handle_espnet_TTS_intelligibility(TTS_audio_output,LLM_Output):
3
+ from versa import espnet_levenshtein_metric, espnet_wer_setup, owsm_levenshtein_metric, owsm_wer_setup, whisper_levenshtein_metric, whisper_wer_setup
4
+ score_modules_espnet = {
5
+ "module": espnet_levenshtein_metric,
6
+ "args": espnet_wer_setup(
7
+ model_tag="default",
8
+ beam_size=1,
9
+ text_cleaner="whisper_en",
10
+ use_gpu=True,
11
+ ),
12
+ }
13
+ # import pdb;pdb.set_trace()
14
+ dict1=score_modules_espnet["module"](
15
+ score_modules_espnet["args"],
16
+ int2float(TTS_audio_output[1]),
17
+ LLM_Output,
18
+ TTS_audio_output[0],
19
+ )
20
+ espnet_wer=(dict1["espnet_wer_delete"]+dict1["espnet_wer_insert"]+dict1["espnet_wer_replace"])/(dict1["espnet_wer_delete"]+dict1["espnet_wer_replace"]+dict1["espnet_wer_equal"])
21
+ espnet_cer=(dict1["espnet_cer_delete"]+dict1["espnet_cer_insert"]+dict1["espnet_cer_replace"])/(dict1["espnet_cer_delete"]+dict1["espnet_cer_replace"]+dict1["espnet_cer_equal"])
22
+ score_modules_owsm = {
23
+ "module": owsm_levenshtein_metric,
24
+ "args": owsm_wer_setup(
25
+ model_tag="default",
26
+ beam_size=1,
27
+ text_cleaner="whisper_en",
28
+ use_gpu=True,
29
+ ),
30
+ }
31
+ # import pdb;pdb.set_trace()
32
+ dict1=score_modules_owsm["module"](
33
+ score_modules_owsm["args"],
34
+ int2float(TTS_audio_output[1]),
35
+ LLM_Output,
36
+ TTS_audio_output[0],
37
+ )
38
+ owsm_wer=(dict1["owsm_wer_delete"]+dict1["owsm_wer_insert"]+dict1["owsm_wer_replace"])/(dict1["owsm_wer_delete"]+dict1["owsm_wer_replace"]+dict1["owsm_wer_equal"])
39
+ owsm_cer=(dict1["owsm_cer_delete"]+dict1["owsm_cer_insert"]+dict1["owsm_cer_replace"])/(dict1["owsm_cer_delete"]+dict1["owsm_cer_replace"]+dict1["owsm_cer_equal"])
40
+ score_modules_whisper = {
41
+ "module": whisper_levenshtein_metric,
42
+ "args": whisper_wer_setup(
43
+ model_tag="default",
44
+ beam_size=1,
45
+ text_cleaner="whisper_en",
46
+ use_gpu=True,
47
+ ),
48
+ }
49
+ # import pdb;pdb.set_trace()
50
+ dict1=score_modules_whisper["module"](
51
+ score_modules_whisper["args"],
52
+ int2float(TTS_audio_output[1]),
53
+ LLM_Output,
54
+ TTS_audio_output[0],
55
+ )
56
+ whisper_wer=(dict1["whisper_wer_delete"]+dict1["whisper_wer_insert"]+dict1["whisper_wer_replace"])/(dict1["whisper_wer_delete"]+dict1["whisper_wer_replace"]+dict1["whisper_wer_equal"])
57
+ whisper_cer=(dict1["whisper_cer_delete"]+dict1["whisper_cer_insert"]+dict1["whisper_cer_replace"])/(dict1["whisper_cer_delete"]+dict1["whisper_cer_replace"]+dict1["whisper_cer_equal"])
58
+ return f"ESPnet WER: {espnet_wer*100:.2f}\nESPnet CER: {espnet_cer*100:.2f}\nOWSM WER: {owsm_wer*100:.2f}\nOWSM CER: {owsm_cer*100:.2f}\nWhisper WER: {whisper_wer*100:.2f}\nWhisper CER: {whisper_cer*100:.2f}"
eval/TTS_speech_quality.py ADDED
@@ -0,0 +1,42 @@
1
+ from utils import int2float
2
+ def TTS_psuedomos(TTS_audio_output):
3
+ from versa import pseudo_mos_metric, pseudo_mos_setup, sheet_ssqa, sheet_ssqa_setup
4
+
5
+ predictor_dict, predictor_fs = pseudo_mos_setup(
6
+ use_gpu=True,
7
+ predictor_types=["utmos", "dnsmos", "plcmos"],
8
+ predictor_args={"utmos":{"fs": 16000},"dnsmos":{"fs": 16000},"plcmos":{"fs": 16000}},
9
+ )
10
+ score_modules = {
11
+ "module": pseudo_mos_metric,
12
+ "args": {
13
+ "predictor_dict": predictor_dict,
14
+ "predictor_fs": predictor_fs,
15
+ "use_gpu": True,
16
+ },
17
+ }
18
+ dict1=score_modules["module"](
19
+ int2float(TTS_audio_output[1]),
20
+ TTS_audio_output[0],
21
+ **score_modules["args"],
22
+ )
23
+ str1=""
24
+ for k in dict1:
25
+ str1=str1+f"{k}: {dict1[k]:.2f}\n"
26
+ sheet_model = sheet_ssqa_setup(
27
+ model_tag="default",
28
+ model_path=None,
29
+ model_config=None,
30
+ use_gpu=True,
31
+ )
32
+ score_modules = {
33
+ "module": sheet_ssqa,
34
+ "args": {"model": sheet_model, "use_gpu": True},
35
+ }
36
+ dict1 = score_modules["module"](
37
+ score_modules["args"]["model"], int2float(TTS_audio_output[1]), TTS_audio_output[0],
38
+ use_gpu=score_modules["args"]["use_gpu"]
39
+ )
40
+ for k in dict1:
41
+ str1=str1+f"{k}: {dict1[k]:.2f}\n"
42
+ return str1
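A quick sketch of calling TTS_psuedomos directly (versa and a GPU assumed); the input is the same (sampling_rate, int16 waveform) tuple the demo's TTS path yields, and the noise buffer below only stands in for synthesized speech:

import numpy as np
from eval.TTS_speech_quality import TTS_psuedomos

fs = 16000
wav = (np.random.randn(fs) * 1000).astype(np.int16)   # 1 s of low-level noise
print(TTS_psuedomos((fs, wav)))   # one "metric: value" line per predictor
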
eval/vert.py ADDED
@@ -0,0 +1,272 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import nltk
8
+ import math
9
+ import sys
10
+ from fractions import Fraction
11
+ import warnings
12
+ from collections import Counter
13
+ from nltk.translate.bleu_score import modified_precision, closest_ref_length, brevity_penalty, SmoothingFunction
14
+ import warnings
15
+ def corpus_bleu(
16
+ list_of_references,
17
+ hypotheses,
18
+ weights=(0.25, 0.25, 0.25, 0.25),
19
+ smoothing_function=None,
20
+ auto_reweigh=False,
21
+ averaging_mode="geometric",
22
+ no_length_penalty=False
23
+ ):
24
+ """
25
+ Calculate a single corpus-level BLEU score (aka. system-level BLEU) for all
26
+ the hypotheses and their respective references.
27
+
28
+ Instead of averaging the sentence level BLEU scores (i.e. macro-average
29
+ precision), the original BLEU metric (Papineni et al. 2002) accounts for
30
+ the micro-average precision (i.e. summing the numerators and denominators
31
+ for each hypothesis-reference(s) pairs before the division).
32
+
33
+ >>> hyp1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which',
34
+ ... 'ensures', 'that', 'the', 'military', 'always',
35
+ ... 'obeys', 'the', 'commands', 'of', 'the', 'party']
36
+ >>> ref1a = ['It', 'is', 'a', 'guide', 'to', 'action', 'that',
37
+ ... 'ensures', 'that', 'the', 'military', 'will', 'forever',
38
+ ... 'heed', 'Party', 'commands']
39
+ >>> ref1b = ['It', 'is', 'the', 'guiding', 'principle', 'which',
40
+ ... 'guarantees', 'the', 'military', 'forces', 'always',
41
+ ... 'being', 'under', 'the', 'command', 'of', 'the', 'Party']
42
+ >>> ref1c = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the',
43
+ ... 'army', 'always', 'to', 'heed', 'the', 'directions',
44
+ ... 'of', 'the', 'party']
45
+
46
+ >>> hyp2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was',
47
+ ... 'interested', 'in', 'world', 'history']
48
+ >>> ref2a = ['he', 'was', 'interested', 'in', 'world', 'history',
49
+ ... 'because', 'he', 'read', 'the', 'book']
50
+
51
+ >>> list_of_references = [[ref1a, ref1b, ref1c], [ref2a]]
52
+ >>> hypotheses = [hyp1, hyp2]
53
+ >>> corpus_bleu(list_of_references, hypotheses) # doctest: +ELLIPSIS
54
+ 0.5920...
55
+
56
+ The example below shows that corpus_bleu() is different from averaging
57
+ sentence_bleu() for hypotheses
58
+
59
+ >>> score1 = sentence_bleu([ref1a, ref1b, ref1c], hyp1)
60
+ >>> score2 = sentence_bleu([ref2a], hyp2)
61
+ >>> (score1 + score2) / 2 # doctest: +ELLIPSIS
62
+ 0.6223...
63
+
64
+ :param list_of_references: a corpus of lists of reference sentences, w.r.t. hypotheses
65
+ :type list_of_references: list(list(list(str)))
66
+ :param hypotheses: a list of hypothesis sentences
67
+ :type hypotheses: list(list(str))
68
+ :param weights: weights for unigrams, bigrams, trigrams and so on
69
+ :type weights: list(float)
70
+ :param smoothing_function:
71
+ :type smoothing_function: SmoothingFunction
72
+ :param auto_reweigh: Option to re-normalize the weights uniformly.
73
+ :type auto_reweigh: bool
74
+ :return: The corpus-level BLEU score.
75
+ :rtype: float
76
+ """
77
+ # Before proceeding to compute BLEU, perform sanity checks.
78
+
79
+ p_numerators = Counter() # Key = ngram order, and value = no. of ngram matches.
80
+ p_denominators = Counter() # Key = ngram order, and value = no. of ngram in ref.
81
+ hyp_lengths, ref_lengths = 0, 0
82
+
83
+ assert len(list_of_references) == len(hypotheses), (
84
+ "The number of hypotheses and their reference(s) should be the " "same "
85
+ )
86
+
87
+ # Iterate through each hypothesis and their corresponding references.
88
+ for references, hypothesis in zip(list_of_references, hypotheses):
89
+ # For each order of ngram, calculate the numerator and
90
+ # denominator for the corpus-level modified precision.
91
+ for i, _ in enumerate(weights, start=1):
92
+ p_i = modified_precision(references, hypothesis, i)
93
+ p_numerators[i] += p_i.numerator
94
+ p_denominators[i] += p_i.denominator
95
+
96
+ # Calculate the hypothesis length and the closest reference length.
97
+ # Adds them to the corpus-level hypothesis and reference counts.
98
+ hyp_len = len(hypothesis)
99
+ hyp_lengths += hyp_len
100
+ ref_lengths += closest_ref_length(references, hyp_len)
101
+
102
+ # Calculate corpus-level brevity penalty.
103
+ if no_length_penalty and averaging_mode == 'geometric':
104
+ bp = 1.0
105
+ elif no_length_penalty and averaging_mode == 'arithmetic':
106
+ bp = 0.0
107
+ else:
108
+ assert not no_length_penalty
109
+ assert averaging_mode != 'arithmetic', 'Not sure how to apply length penalty in arithmetic mode'
110
+ bp = brevity_penalty(ref_lengths, hyp_lengths)
111
+
112
+ # Uniformly re-weighting based on maximum hypothesis lengths if largest
113
+ # order of n-grams < 4 and weights is set at default.
114
+ if auto_reweigh:
115
+ if hyp_lengths < 4 and weights == (0.25, 0.25, 0.25, 0.25):
116
+ weights = (1 / hyp_lengths,) * hyp_lengths
117
+
118
+ # Collects the various precision values for the different ngram orders.
119
+ p_n = [
120
+ Fraction(p_numerators[i], p_denominators[i], _normalize=False)
121
+ for i, _ in enumerate(weights, start=1)
122
+ ]
123
+
124
+ # Returns 0 if there's no matching n-grams
125
+ # We only need to check for p_numerators[1] == 0, since if there's
126
+ # no unigrams, there won't be any higher order ngrams.
127
+ if p_numerators[1] == 0:
128
+ return 0
129
+
130
+ # If there's no smoothing, use method0 from the SmoothingFunction class.
131
+ if not smoothing_function:
132
+ smoothing_function = SmoothingFunction().method0
133
+ # Smoothen the modified precision.
134
+ # Note: smoothing_function() may convert values into floats;
135
+ # it tries to retain the Fraction object as much as the
136
+ # smoothing method allows.
137
+ p_n = smoothing_function(
138
+ p_n, references=references, hypothesis=hypothesis, hyp_len=hyp_lengths
139
+ )
140
+
141
+ if averaging_mode == "geometric":
142
+ s = (w_i * math.log(p_i) for w_i, p_i in zip(weights, p_n))
143
+ s = bp * math.exp(math.fsum(s))
144
+ elif averaging_mode == "arithmetic":
145
+ s = (w_i * p_i for w_i, p_i in zip(weights, p_n))
146
+ s = math.fsum(s)
147
+
148
+ return s
149
+
150
+
151
+ def sentence_bleu(
152
+ references,
153
+ hypothesis,
154
+ weights=(0.25, 0.25, 0.25, 0.25),
155
+ smoothing_function=None,
156
+ auto_reweigh=False,
157
+ averaging_mode="geometric",
158
+ no_length_penalty=False
159
+ ):
160
+ return corpus_bleu(
161
+ [references], [hypothesis], weights, smoothing_function, auto_reweigh, averaging_mode, no_length_penalty
162
+ )
163
+
164
+ def get_target_sequences(manifest, ground_truth, to_take=1000):
165
+ import json
166
+ import pathlib
167
+
168
+ with open(ground_truth, 'r') as fin:
169
+ original_continuations = json.loads(fin.read())
170
+
171
+ sequence2length = [(k, v[0]) for k, v in original_continuations.items()]
172
+ assert all(float(v) >= 6.0 for (_, v) in sequence2length) # 6 seconds
173
+
174
+ sequence2length.sort(key=lambda x: x[1])
175
+ to_take_sequences = set(v[0] for v in sequence2length[:to_take])
176
+ to_take_ids = []
177
+
178
+ with open(manifest, 'r') as f:
179
+ f.readline()
180
+
181
+ for i, line in enumerate(f.readlines()):
182
+ seq_id = line.split()[0]
183
+ seq_id = pathlib.Path(seq_id).name.split('__')[0]
184
+
185
+ if seq_id in to_take_sequences:
186
+ to_take_ids.append(i)
187
+
188
+ print(f'Took {len(to_take_ids)} ids')
189
+ return set(to_take_ids)
190
+
191
+ def get_self_bleu(utterances, averaging_mode, weights):
192
+ self_bleu = []
193
+
194
+ for i in range(len(utterances)):
195
+ hypo = utterances[i]
196
+ rest = utterances[:i] + utterances[i+1:]
197
+
198
+ self_bleu.append(sentence_bleu(rest, hypo, weights,
199
+ no_length_penalty=True, averaging_mode=averaging_mode))
200
+
201
+ return self_bleu
202
+
203
+
204
+ def get_self_bleu2_arithmetic(utterances):
205
+ weights = (0.5, 0.5) # equal weight for unigrams and bigrams
206
+ return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights)
207
+
208
+
209
+ def get_self_bleu2_geometric(utterances):
210
+ weights = (0.5, 0.5)
211
+ return get_self_bleu(utterances, averaging_mode='geometric', weights=weights)
212
+
213
+
214
+ def get_auto_bleu2_arithmetic(utterances):
215
+ weights = (0.5, 0.5)
216
+ return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances]
217
+
218
+
219
+ def get_auto_bleu2_geometric(utterances):
220
+ weights = (0.5, 0.5)
221
+ return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances]
222
+
223
+
224
+ def get_auto_bleu3_geometric(utterances):
225
+ weights = (1./3, 1./3, 1./3)
226
+ return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances]
227
+
228
+
229
+ def get_auto_bleu3_arithmetic(utterances):
230
+ weights = (1./3, 1./3, 1./3)
231
+ return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances]
232
+
233
+
234
+ def get_self_bleu3_arithmetic(utterances):
235
+ weights = (1./3, 1./3, 1./3)
236
+ return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights)
237
+
238
+
239
+ def get_self_bleu3_geometric(utterances):
240
+ weights = (1./3, 1./3, 1./3)
241
+ return get_self_bleu(utterances, averaging_mode='geometric', weights=weights)
242
+
243
+
244
+ def auto_bleu(sentence, weights, mean_mode='arithmetic'):
245
+ if len(sentence) <= 1:
246
+ return 0
247
+
248
+ N = len(weights)
249
+
250
+ bleu_n = np.zeros([N])
251
+ for n in range(N):
252
+ targ_ngrams = list(nltk.ngrams(sentence, n+1))
253
+ for p in range(len(targ_ngrams)):
254
+ left = sentence[:p]
255
+ right = sentence[(p+n+1):]
256
+ rest_ngrams = list(nltk.ngrams(left, n+1)) + \
257
+ list(nltk.ngrams(right, n+1))
258
+ # compute the nb of matching ngrams
259
+ bleu_n[n] += targ_ngrams[p] in rest_ngrams
260
+ bleu_n[n] /= len(targ_ngrams) # average them to get a proportion
261
+
262
+ weights = np.array(weights)
263
+ if mean_mode == 'arithmetic':
264
+ return (bleu_n * weights).sum()
265
+ elif mean_mode == 'geometric':
266
+ return (bleu_n ** weights).prod()
267
+ else:
268
+ raise ValueError(f'Unknown aggregation mode {mean_mode}')
269
+
270
+ def run_f(task_params):
271
+ f, terms = task_params
272
+ return f(terms)
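A small sketch of how LLM_Metrics.vert() drives these helpers: tokenize the responses, score Self-BLEU2 and Auto-BLEU2 (geometric), and take the geometric mean as VERT; the toy responses are illustrative:

import numpy as np
from scipy.stats import gmean
from eval.vert import get_self_bleu2_geometric, get_auto_bleu2_geometric, run_f

responses = [
    "the cat sat on the mat and the cat slept",
    "a dog ran in the park and a dog barked",
    "the cat and the dog played in the park",
]
terms = [r.split() for r in responses]

self_bleu2 = 100 * np.mean(run_f((get_self_bleu2_geometric, terms)))
auto_bleu2 = 100 * np.mean(run_f((get_auto_bleu2_geometric, terms)))
print(f"Self-BLEU2-geometric: {self_bleu2:.2f}")
print(f"Auto-BLEU2-geometric: {auto_bleu2:.2f}")
print(f"VERT: {gmean([self_bleu2, auto_bleu2]):.2f}")
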
requirements.txt CHANGED
@@ -12,4 +12,5 @@ sounddevice==0.5.0
webrtcvad-wheels
webrtcvad==2.0.10
gradio==4.43.0
- ChatTTS
+ ChatTTS
+ evaluate
tts_samples/sample1.wav ADDED
Binary file (414 kB).
 
utils.py ADDED
@@ -0,0 +1,12 @@
+ import numpy as np
+ def int2float(sound):
+     """
+     Taken from https://github.com/snakers4/silero-vad
+     """
+
+     abs_max = np.abs(sound).max()
+     sound = sound.astype("float32")
+     if abs_max > 0:
+         sound *= 1 / 32768
+     sound = sound.squeeze()  # depends on the use case
+     return sound
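And a sketch of the int2float contract the eval helpers rely on, 16-bit PCM in, float32 in [-1, 1] out:

import numpy as np
from utils import int2float

chunk = np.random.randint(-32768, 32767, size=16000, dtype=np.int16)   # fake 1 s microphone chunk
wav = int2float(chunk)
print(wav.dtype, bool(np.abs(wav).max() <= 1.0))   # float32 True
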
versa.sh ADDED
@@ -0,0 +1,4 @@
+ git clone https://github.com/shinjiwlab/versa.git
+ cd versa
+ pip install .
+ cd ..