Siddhant committed
Commit 0c97eed · 1 Parent(s): 73e3428

Add SDS demo code

Files changed (3)
  1. app.py +442 -0
  2. requirements.txt +18 -0
  3. versa.sh +4 -0
app.py ADDED
@@ -0,0 +1,442 @@
+ try:
+     import versa
+ except ImportError:
+     from subprocess import call
+     with open('versa.sh', 'rb') as file:
+         script = file.read()
+     rc = call(script, shell=True)
+ import os
+ import shutil
+ from espnet2.sds.asr.espnet_asr import ESPnetASRModel
+ from espnet2.sds.asr.owsm_asr import OWSMModel
+ from espnet2.sds.asr.owsm_ctc_asr import OWSMCTCModel
+ from espnet2.sds.asr.whisper_asr import WhisperASRModel
+ from espnet2.sds.tts.espnet_tts import ESPnetTTSModel
+ from espnet2.sds.tts.chat_tts import ChatTTSModel
+ from espnet2.sds.llm.hugging_face_llm import HuggingFaceLLM
+ from espnet2.sds.vad.webrtc_vad import WebrtcVADModel
+ from espnet2.sds.eval.TTS_intelligibility import handle_espnet_TTS_intelligibility
+ from espnet2.sds.eval.ASR_WER import handle_espnet_ASR_WER
+ from espnet2.sds.eval.TTS_speech_quality import TTS_psuedomos
+ from espnet2.sds.eval.LLM_Metrics import perplexity, vert, bert_score, DialoGPT_perplexity
+ from espnet2.sds.utils.chat import Chat
+ import argparse
+
+ access_token = os.environ.get("HF_TOKEN")
+ ASR_name="pyf98/owsm_ctc_v3.1_1B"
+ LLM_name="meta-llama/Llama-3.2-1B-Instruct"
+ TTS_name="kan-bayashi/ljspeech_vits"
+ ASR_options="pyf98/owsm_ctc_v3.1_1B,espnet/owsm_ctc_v3.2_ft_1B,espnet/owsm_v3.1_ebf,librispeech_asr,whisper".split(",")
+ LLM_options="meta-llama/Llama-3.2-1B-Instruct,HuggingFaceTB/SmolLM2-1.7B-Instruct".split(",")
+ TTS_options="kan-bayashi/ljspeech_vits,kan-bayashi/libritts_xvector_vits,kan-bayashi/vctk_multi_spk_vits,ChatTTS".split(",")
+ Eval_options="Latency,TTS Intelligibility,TTS Speech Quality,ASR WER,Text Dialog Metrics"
+ upload_to_hub="Siddhant/Cascaded_demo_data"
+ # def read_args():
+ #     global access_token
+ #     global ASR_name
+ #     global LLM_name
+ #     global TTS_name
+ #     global ASR_options
+ #     global LLM_options
+ #     global TTS_options
+ #     global Eval_options
+ #     global upload_to_hub
+ #     parser = argparse.ArgumentParser(description="Run the app with HF_TOKEN as a command-line argument.")
+ #     parser.add_argument("--HF_TOKEN", required=True, help="Provide the Hugging Face token.")
+ #     parser.add_argument("--asr_options", required=True, help="Provide the possible ASR options available to user.")
+ #     parser.add_argument("--llm_options", required=True, help="Provide the possible LLM options available to user.")
+ #     parser.add_argument("--tts_options", required=True, help="Provide the possible TTS options available to user.")
+ #     parser.add_argument("--eval_options", required=True, help="Provide the possible automatic evaluation metrics available to user.")
+ #     parser.add_argument("--default_asr_model", required=False, default="pyf98/owsm_ctc_v3.1_1B", help="Provide the default ASR model.")
+ #     parser.add_argument("--default_llm_model", required=False, default="meta-llama/Llama-3.2-1B-Instruct", help="Provide the default LLM model.")
+ #     parser.add_argument("--default_tts_model", required=False, default="kan-bayashi/ljspeech_vits", help="Provide the default TTS model.")
+ #     parser.add_argument("--upload_to_hub", required=False, default=None, help="Hugging Face dataset to upload user data")
+ #     args = parser.parse_args()
+ #     access_token=args.HF_TOKEN
+ #     ASR_name=args.default_asr_model
+ #     LLM_name=args.default_llm_model
+ #     TTS_name=args.default_tts_model
+ #     ASR_options=args.asr_options.split(",")
+ #     LLM_options=args.llm_options.split(",")
+ #     TTS_options=args.tts_options.split(",")
+ #     Eval_options=args.eval_options.split(",")
+ #     upload_to_hub=args.upload_to_hub
+
+ # read_args()
+ from huggingface_hub import HfApi
+
+ api = HfApi()
+ import nltk
+ nltk.download('averaged_perceptron_tagger_eng')
+ import gradio as gr
+
+
+ import numpy as np
+
+ chat = Chat(2)
+ chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. The user is talking to you with their voice and you should respond in a conversational style. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
+ user_role = "user"
+
+ text2speech=None
+ s2t=None
+ LM_pipe=None
+
+ latency_ASR=0.0
+ latency_LM=0.0
+ latency_TTS=0.0
+
+ text_str=""
+ asr_output_str=""
+ vad_output=None
+ audio_output = None
+ audio_output1 = None
+ LLM_response_arr=[]
+ total_response_arr=[]
+
+ def handle_selection(option):
+     yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
+     global text2speech
+     tag = option
+     if tag=="ChatTTS":
+         text2speech = ChatTTSModel()
+     else:
+         text2speech = ESPnetTTSModel(tag)
+     text2speech.warmup()
+     yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
+
+ def handle_LLM_selection(option):
+     yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
+     global LM_pipe
+     LM_pipe = HuggingFaceLLM(access_token=access_token,tag = option)
+     LM_pipe.warmup()
+     yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
+
+ def handle_ASR_selection(option):
+     yield gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False)
+     if option=="librispeech_asr":
+         option="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp"
+     global s2t
+     if option=="espnet/owsm_v3.1_ebf":
+         s2t = OWSMModel()
+     elif option=="espnet/simpleoier_librispeech_asr_train_asr_conformer7_wavlm_large_raw_en_bpe5000_sp":
+         s2t = ESPnetASRModel(tag=option)
+     elif option=="whisper":
+         s2t = WhisperASRModel()
+     else:
+         s2t = OWSMCTCModel(tag=option)
+
+     s2t.warmup()
+     yield gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True)
+
+ def handle_eval_selection(option, TTS_audio_output, LLM_Output, ASR_audio_output, ASR_transcript):
+     global LLM_response_arr
+     global total_response_arr
+     yield (option,gr.Textbox(visible=True))
+     if option=="Latency":
+         text=f"ASR Latency: {latency_ASR:.2f}\nLLM Latency: {latency_LM:.2f}\nTTS Latency: {latency_TTS:.2f}"
+         yield (None,text)
+     elif option=="TTS Intelligibility":
+         yield (None,handle_espnet_TTS_intelligibility(TTS_audio_output,LLM_Output))
+     elif option=="TTS Speech Quality":
+         yield (None,TTS_psuedomos(TTS_audio_output))
+     elif option=="ASR WER":
+         yield (None,handle_espnet_ASR_WER(ASR_audio_output, ASR_transcript))
+     elif option=="Text Dialog Metrics":
+         yield (None,perplexity(LLM_Output.replace("\n"," "))+vert(LLM_response_arr)+bert_score(total_response_arr)+DialoGPT_perplexity(ASR_transcript.replace("\n"," "),LLM_Output.replace("\n"," ")))
+
+ for _ in handle_selection(TTS_name):
+     continue
+ for _ in handle_ASR_selection(ASR_name):
+     continue
+ for _ in handle_LLM_selection(LLM_name):
+     continue
+ vad_model=WebrtcVADModel()
+
+ callback = gr.CSVLogger()
+ start_record_time=None
+ enable_btn = gr.Button(interactive=True, visible=True)
+ disable_btn = gr.Button(interactive=False, visible=False)
+ def flash_buttons():
+     btn_updates = (enable_btn,) * 8
+     print(enable_btn)
+     yield ("","",)+btn_updates
+
+
+ def get_ip(request: gr.Request):
+     if "cf-connecting-ip" in request.headers:
+         ip = request.headers["cf-connecting-ip"]
+     elif "x-forwarded-for" in request.headers:
+         ip = request.headers["x-forwarded-for"]
+         if "," in ip:
+             ip = ip.split(",")[0]
+     else:
+         ip = request.client.host
+     return ip
+
+
+ def vote_last_response(vote_type, request: gr.Request):
+     with open("save_dict.json", "a") as fout:
+         data = {
+             "tstamp": round(time.time(), 4),
+             "type": vote_type,
+             "ip": get_ip(request),
+         }
+         fout.write(json.dumps(data) + "\n")
+
+
+ def natural_vote1_last_response(
+     request: gr.Request
+ ):
+     ip_address1=get_ip(request)
+     print(f"Very Natural (voted). ip: {ip_address1}")
+     return ("Very Natural",ip_address1,)+(disable_btn,) * 4
+
+ def natural_vote2_last_response(
+     request: gr.Request
+ ):
+     ip_address1=get_ip(request)
+     print(f"Somewhat Awkward (voted). ip: {ip_address1}")
+     return ("Somewhat Awkward",ip_address1,)+(disable_btn,) * 4
+
+ def natural_vote3_last_response(
+     request: gr.Request
+ ):
+     ip_address1=get_ip(request)
+     print(f"Very Awkward (voted). ip: {ip_address1}")
+     return ("Very Awkward",ip_address1,)+(disable_btn,) * 4
+
+ def natural_vote4_last_response(
+     request: gr.Request
+ ):
+     ip_address1=get_ip(request)
+     print(f"Unnatural (voted). ip: {ip_address1}")
+     return ("Unnatural",ip_address1,)+(disable_btn,) * 4
+
+ def relevant_vote1_last_response(
+     request: gr.Request
+ ):
+     ip_address1=get_ip(request)
+     print(f"Highly Relevant (voted). ip: {ip_address1}")
+     return ("Highly Relevant",ip_address1,)+(disable_btn,) * 4
+
+ def relevant_vote2_last_response(
+     request: gr.Request
+ ):
+     ip_address1=get_ip(request)
+     print(f"Partially Relevant (voted). ip: {ip_address1}")
+     return ("Partially Relevant",ip_address1,)+(disable_btn,) * 4
+
+ def relevant_vote3_last_response(
+     request: gr.Request
+ ):
+     ip_address1=get_ip(request)
+     print(f"Slightly Irrelevant (voted). ip: {ip_address1}")
+     return ("Slightly Irrelevant",ip_address1,)+(disable_btn,) * 4
+
+ def relevant_vote4_last_response(
+     request: gr.Request
+ ):
+     ip_address1=get_ip(request)
+     print(f"Completely Irrelevant (voted). ip: {ip_address1}")
+     return ("Completely Irrelevant",ip_address1,)+(disable_btn,) * 4
+
+ import json
+ import time
+
+ def transcribe(stream, new_chunk, option, asr_option):
+     sr, y = new_chunk
+     global text_str
+     global chat
+     global user_role
+     global audio_output
+     global audio_output1
+     global vad_output
+     global asr_output_str
+     global start_record_time
+     global sids
+     global spembs
+     global latency_ASR
+     global latency_LM
+     global latency_TTS
+     global LLM_response_arr
+     global total_response_arr
+     if stream is None:
+         stream=y
+         chat.init_chat({"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."})
+         text_str=""
+         audio_output = None
+         audio_output1 = None
+     else:
+         stream=np.concatenate((stream,y))
+     orig_sr=sr
+     sr=16000
+     array=vad_model(y,orig_sr)
+
+     if array is not None:
+         print("VAD: end of speech detected")
+         start_time = time.time()
+         prompt=s2t(array)
+         if len(prompt.strip().split())<2:
+             text_str1=text_str
+             yield (stream, asr_output_str, text_str1, audio_output, audio_output1)
+             return
+
+
+         asr_output_str=prompt
+         total_response_arr.append(prompt.replace("\n"," "))
+         start_LM_time=time.time()
+         latency_ASR=(start_LM_time - start_time)
+         chat.append({"role": user_role, "content": prompt})
+         chat_messages = chat.to_list()
+         generated_text = LM_pipe(chat_messages)
+         start_TTS_time=time.time()
+         latency_LM=(start_TTS_time - start_LM_time)
+
+         chat.append({"role": "assistant", "content": generated_text})
+         text_str=generated_text
+         LLM_response_arr.append(text_str.replace("\n"," "))
+         total_response_arr.append(text_str.replace("\n"," "))
+         audio_output=text2speech(text_str)
+         audio_output1=(orig_sr,stream)
+         stream=y
+         latency_TTS=(time.time() - start_TTS_time)
+     text_str1=text_str
+     if ((text_str!="") and (start_record_time is None)):
+         start_record_time=time.time()
+     elif start_record_time is not None:
+         current_record_time=time.time()
+         if current_record_time-start_record_time>300:
+             gr.Info("Conversations are limited to 5 minutes. The session will restart in approximately 60 seconds. Please wait for the demo to reset. Close this message once you have read it.", duration=None)
+             yield stream,gr.Textbox(visible=False),gr.Textbox(visible=False),gr.Audio(visible=False),gr.Audio(visible=False)
+             if upload_to_hub is not None:
+                 api.upload_folder(
+                     folder_path="flagged_data_points",
+                     path_in_repo="checkpoint_"+str(start_record_time),
+                     repo_id=upload_to_hub,
+                     repo_type="dataset",
+                     token=access_token,
+                 )
+             chat.buffer=[{"role": "system", "content": "You are a helpful and friendly AI assistant. You are polite, respectful, and aim to provide concise and complete responses of less than 15 words."}]
+             text_str=""
+             audio_output = None
+             audio_output1 = None
+             asr_output_str = ""
+             start_record_time = None
+             LLM_response_arr=[]
+             total_response_arr=[]
+             shutil.rmtree('flagged_data_points')
+             os.mkdir("flagged_data_points")
+             yield (stream,asr_output_str,text_str1, audio_output, audio_output1)
+             yield stream,gr.Textbox(visible=True),gr.Textbox(visible=True),gr.Audio(visible=True),gr.Audio(visible=False)
+
+     yield (stream,asr_output_str,text_str1, audio_output, audio_output1)
+
+
+ with gr.Blocks(
+     title="E2E Spoken Dialog System",
+ ) as demo:
+     with gr.Row():
+         with gr.Column(scale=1):
+             user_audio = gr.Audio(sources=["microphone"], streaming=True, waveform_options=gr.WaveformOptions(sample_rate=16000))
+             with gr.Row():
+                 ASR_radio = gr.Radio(
+                     choices=ASR_options,
+                     label="Choose ASR:",
+                     value=ASR_name,
+                 )
+             with gr.Row():
+                 LLM_radio = gr.Radio(
+                     choices=LLM_options,
+                     label="Choose LLM:",
+                     value=LLM_name,
+                 )
+             with gr.Row():
+                 radio = gr.Radio(
+                     choices=TTS_options,
+                     label="Choose TTS:",
+                     value=TTS_name,
+                 )
+             with gr.Row():
+                 feedback_btn = gr.Button(
+                     value="Please provide your feedback after each system response below.", visible=True, interactive=False, elem_id="button"
+                 )
+             with gr.Row():
+                 natural_btn1 = gr.Button(
+                     value="Very Natural", visible=False, interactive=False, scale=1
+                 )
+                 natural_btn2 = gr.Button(
+                     value="Somewhat Awkward", visible=False, interactive=False, scale=1
+                 )
+                 natural_btn3 = gr.Button(value="Very Awkward", visible=False, interactive=False, scale=1)
+                 natural_btn4 = gr.Button(
+                     value="Unnatural", visible=False, interactive=False, scale=1
+                 )
+             with gr.Row():
+                 relevant_btn1 = gr.Button(
+                     value="Highly Relevant", visible=False, interactive=False, scale=1
+                 )
+                 relevant_btn2 = gr.Button(
+                     value="Partially Relevant", visible=False, interactive=False, scale=1
+                 )
+                 relevant_btn3 = gr.Button(value="Slightly Irrelevant", visible=False, interactive=False, scale=1)
+                 relevant_btn4 = gr.Button(
+                     value="Completely Irrelevant", visible=False, interactive=False, scale=1
+                 )
+         with gr.Column(scale=1):
+             output_audio = gr.Audio(label="Output", autoplay=True, visible=True)
+             output_audio1 = gr.Audio(label="Output1", autoplay=False, visible=False)
+             output_asr_text = gr.Textbox(label="ASR output")
+             output_text = gr.Textbox(label="LLM output")
+             eval_radio = gr.Radio(
+                 choices=["Latency", "TTS Intelligibility", "TTS Speech Quality", "ASR WER","Text Dialog Metrics"],
+                 label="Choose Evaluation metrics:",
+             )
+             output_eval_text = gr.Textbox(label="Evaluation Results")
+             state = gr.State()
+     with gr.Row():
+         privacy_text = gr.Textbox(label="Privacy Notice",interactive=False, value="By using this demo, you acknowledge that interactions with this dialog system are collected for research and improvement purposes. The data will only be used to enhance the performance and understanding of the system. If you have any concerns about data collection, please discontinue use.")
+
+     btn_list=[
+         natural_btn1,
+         natural_btn2,
+         natural_btn3,
+         natural_btn4,
+         relevant_btn1,
+         relevant_btn2,
+         relevant_btn3,
+         relevant_btn4,
+     ]
+     natural_btn_list=[
+         natural_btn1,
+         natural_btn2,
+         natural_btn3,
+         natural_btn4,
+     ]
+     relevant_btn_list=[
+         relevant_btn1,
+         relevant_btn2,
+         relevant_btn3,
+         relevant_btn4,
+     ]
+     natural_response = gr.Textbox(label="natural_response",visible=False,interactive=False)
+     diversity_response = gr.Textbox(label="diversity_response",visible=False,interactive=False)
+     ip_address = gr.Textbox(label="ip_address",visible=False,interactive=False)
+     callback.setup([user_audio, output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address],"flagged_data_points")
+     user_audio.stream(transcribe, inputs=[state, user_audio, radio, ASR_radio], outputs=[state, output_asr_text, output_text, output_audio, output_audio1]).then(lambda *args: callback.flag(list(args)),[user_audio], None,preprocess=False)
+     radio.change(fn=handle_selection, inputs=[radio], outputs=[output_asr_text, output_text, output_audio])
+     LLM_radio.change(fn=handle_LLM_selection, inputs=[LLM_radio], outputs=[output_asr_text, output_text, output_audio])
+     ASR_radio.change(fn=handle_ASR_selection, inputs=[ASR_radio], outputs=[output_asr_text, output_text, output_audio])
+     eval_radio.change(fn=handle_eval_selection, inputs=[eval_radio,output_audio,output_text,output_audio1,output_asr_text], outputs=[eval_radio,output_eval_text])
+     output_audio.play(
+         flash_buttons, [], [natural_response,diversity_response]+btn_list
+     ).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio], None,preprocess=False)
+     natural_btn1.click(natural_vote1_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
+     natural_btn2.click(natural_vote2_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
+     natural_btn3.click(natural_vote3_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
+     natural_btn4.click(natural_vote4_last_response,[],[natural_response,ip_address]+natural_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
+     relevant_btn1.click(relevant_vote1_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
+     relevant_btn2.click(relevant_vote2_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
+     relevant_btn3.click(relevant_vote3_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
+     relevant_btn4.click(relevant_vote4_last_response,[],[diversity_response,ip_address]+relevant_btn_list).then(lambda *args: callback.flag(list(args)),[user_audio,output_asr_text, output_text, output_audio,output_audio1,ASR_radio,LLM_radio,radio,natural_response,diversity_response,ip_address], None,preprocess=False)
+
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,18 @@
+ typeguard==2.13.3
+ espnet @ git+https://github.com/siddhu001/espnet@sds_demo_recipe
+ espnet_model_zoo
+ huggingface_hub==0.23.2
+ transformers[sentencepiece]
+ sentencepiece
+ datasets
+ torch==2.5.1
+ torchaudio==2.5.1
+ librosa
+ sounddevice==0.5.0
+ webrtcvad-wheels
+ webrtcvad==2.0.10
+ gradio==4.43.0
+ ChatTTS
+ evaluate
+ snac==1.2.0
+ litgpt==0.4.3
versa.sh ADDED
@@ -0,0 +1,4 @@
+ git clone https://github.com/shinjiwlab/versa.git
+ cd versa
+ pip install .
+ cd ..
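
For reference, a minimal sketch of how the added demo can be launched locally. These commands are not part of the commit and are an assumption for illustration; they only rely on what the files above show: app.py reads HF_TOKEN from the environment, falls back to running versa.sh when `import versa` fails, and ends with demo.launch(share=True). The token is assumed to have access to meta-llama/Llama-3.2-1B-Instruct (and write access to the upload_to_hub dataset if data collection is left enabled).

  # install the pinned dependencies (espnet comes from the sds_demo_recipe branch)
  pip install -r requirements.txt
  # token read by app.py via os.environ.get("HF_TOKEN")
  export HF_TOKEN=<your_hf_token>
  # starts the Gradio UI; versa is installed via versa.sh on first run if missing
  python app.py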