Will Held committed on
Commit 8f64bcf
1 Parent(s): 33e9a7f

Replace App

Files changed (2)
  1. app.py +394 -50
  2. demo.py +0 -407
app.py CHANGED
@@ -1,63 +1,407 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
      response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
          ),
-     ],
  )
 
-
- if __name__ == "__main__":
-     demo.launch()
 
+ import copy
+ import os
+ import random
+ import sys
+
+
+ import spaces
  import gradio as gr
+ import librosa
+ import numpy as np
+ import soundfile as sf
+ import torch
+ import torch.nn.functional as F
+ from accelerate import infer_auto_device_map
+ from datasets import Audio
+ from models.salmonn import SALMONN
+ from safetensors.torch import load, load_model
+ from tinydb import TinyDB
+ from torch import nn
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoProcessor,
+     AutoModel,
+     AutoTokenizer,
+     LlamaForCausalLM,
+     TextIteratorStreamer,
+     WhisperForConditionalGeneration,
+ )
+ from transformers.generation import GenerationConfig
+
+ tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")
+ prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to("cuda")
+ pre_user_suffix = torch.tensor([271]).to("cuda")
+ final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to("cuda")
+ cache = None
+ anonymous = False
+
+ resampler = Audio(sampling_rate=16_000)
+
+
+ qwen_tokenizer = AutoTokenizer.from_pretrained(
+     "Qwen/Qwen-Audio-Chat", trust_remote_code=True
+ )
+ qwen_model = AutoModelForCausalLM.from_pretrained(
+     "Qwen/Qwen-Audio-Chat",
+     device_map="auto",
+     trust_remote_code=True,
+     torch_dtype=torch.float16,
+ ).eval()
+
+ qwen_model.generation_config = GenerationConfig.from_pretrained(
+     "Qwen/Qwen-Audio-Chat",
+     trust_remote_code=True,
+     do_sample=False,
+     top_k=50,
+     top_p=1.0,
+ )
+
+
+ # salmonn_model = SALMONN(
+ #     ckpt="./SALMONN_PATHS/salmonn_v1.pth",
+ #     whisper_path="./SALMONN_PATHS/whisper-large-v2",
+ #     beats_path="./SALMONN_PATHS/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt",
+ #     vicuna_path="./SALMONN_PATHS/vicuna-13b-v1.1",
+ #     low_resource=False,
+ #     device="cuda:0",
+ # )
+ # salmonn_tokenizer = salmonn_model.llama_tokenizer
+
+
+ diva = AutoModel.from_pretrained("WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True)
+
+
+
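+ # Each wrapper below takes Gradio's (sample_rate, numpy array) microphone
+ # tuple, peak-normalizes it, resamples it to the 16 kHz the audio encoders
+ # expect, and queries one model: SALMONN (disabled while its checkpoint
+ # paths above are commented out), Qwen-Audio-Chat, or DiVA.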
+ @spaces.GPU
+ @torch.no_grad
+ def salmonn_fwd(audio_input, prompt, do_sample=False, temperature=0.001):
+     if audio_input is None:
+         return ""
+     sr, y = audio_input
+     y = y.astype(np.float32)
+     y /= np.max(np.abs(y))
+     a = resampler.decode_example(
+         resampler.encode_example({"array": y, "sampling_rate": sr})
+     )
+     sf.write("tmp.wav", a["array"], a["sampling_rate"], format="wav")
+     streamer = TextIteratorStreamer(salmonn_tokenizer)
+     with torch.cuda.amp.autocast(dtype=torch.float16):
+         llm_message = salmonn_model.generate(
+             wav_path="tmp.wav",
+             prompt=prompt,
+             do_sample=False,
+             top_p=1.0,
+             temperature=0.0,
+             device="cuda:0",
+             streamer=streamer,
+         )
+
      response = ""
+     for new_tokens in streamer:
+         response += new_tokens
+         yield response.replace("</s>", "")
+
+
+ @spaces.GPU
+ @torch.no_grad
+ def qwen_audio(audio_input, prompt, do_sample=False, temperature=0.001):
+     if audio_input is None:
+         return ""
+     sr, y = audio_input
+     y = y.astype(np.float32)
+     y /= np.max(np.abs(y))
+     a = resampler.decode_example(
+         resampler.encode_example({"array": y, "sampling_rate": sr})
+     )
+     sf.write("tmp.wav", a["array"], a["sampling_rate"], format="wav")
+     query = qwen_tokenizer.from_list_format([{"audio": "tmp.wav"}, {"text": prompt}])
+
+     response, history = qwen_model.chat(
+         qwen_tokenizer,
+         query=query,
+         system="You are a helpful assistant.",
+         history=None,
+     )
+     return response
+
+
+ @spaces.GPU
+ @torch.no_grad
+ def via(audio_input, prompt, do_sample=False, temperature=0.001):
+     if audio_input is None:
+         return ""
+     sr, y = audio_input
+     y = y.astype(np.float32)
+     y /= np.max(np.abs(y))
+     a = resampler.decode_example(
+         resampler.encode_example({"array": y, "sampling_rate": sr})
+     )
+
+     audio = a["array"]
+
+     yield from diva.generate_stream(audio, prompt)
+
+
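+ # transcribe is a generator: each yield updates the submit button, the three
+ # response textboxes, the three voting buttons, and the tutorial state in
+ # place. model_order controls which model's output lands in which box.
+ # NOTE: with gen_from_salmonn() commented out below, resp_generators holds
+ # two generators while model_order indexes three, so model_order[2] raises
+ # IndexError until SALMONN is re-enabled.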
+ def transcribe(audio_input, text_prompt, state, model_order):
+     yield (
+         gr.Button(
+             value="Waiting in queue for GPU time...",
+             interactive=False,
+             variant="primary",
+         ),
+         "",
+         "",
+         "",
+         gr.Button(visible=False),
+         gr.Button(visible=False),
+         gr.Button(visible=False),
+         state,
+     )
+     if audio_input is None:
+         return (
+             "",
+             "",
+             "",
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             state,
+         )
+
+     def gen_from_via():
+         via_resp = via(audio_input, text_prompt)
+         for resp in via_resp:
+             v_resp = gr.Textbox(
+                 value=resp,
+                 visible=True,
+                 label=model_names[0] if not anonymous else f"Model {order}",
+             )
+             yield (v_resp, s_resp, q_resp)
+
+     def gen_from_salmonn():
+         salmonn_resp = salmonn_fwd(audio_input, text_prompt)
+         for resp in salmonn_resp:
+             s_resp = gr.Textbox(
+                 value=resp,
+                 visible=True,
+                 label=model_names[1] if not anonymous else f"Model {order}",
+             )
+             yield (v_resp, s_resp, q_resp)
+
+     def gen_from_qwen():
+         qwen_resp = qwen_audio(audio_input, text_prompt)
+         q_resp = gr.Textbox(
+             value=qwen_resp,
+             visible=True,
+             label=model_names[2] if not anonymous else f"Model {order}",
+         )
+         yield (v_resp, s_resp, q_resp)
+
+     spinner_id = 0
+     spinners = ["◐ ", "◓ ", "◑", "◒"]
+     initial_responses = [("", "", "")]
+     resp_generators = [
+         gen_from_via(),
+         # gen_from_salmonn(),
+         gen_from_qwen(),
+     ]
+     order = -1
+     resp_generators = [
+         resp_generators[model_order[0]],
+         resp_generators[model_order[1]],
+         resp_generators[model_order[2]],
+     ]
+     for generator in [initial_responses, *resp_generators]:
+         order += 1
+         for resps in generator:
+             v_resp, s_resp, q_resp = resps
+             resp_1 = resps[model_order[0]]
+             resp_2 = resps[model_order[1]]
+             resp_3 = resps[model_order[2]]
+             spinner = spinners[spinner_id]
+             spinner_id = (spinner_id + 1) % 4
+             yield (
+                 gr.Button(
+                     value=spinner + " Generating Responses " + spinner,
+                     interactive=False,
+                     variant="primary",
+                 ),
+                 resp_1,
+                 resp_2,
+                 resp_3,
+                 gr.Button(visible=False),
+                 gr.Button(visible=False),
+                 gr.Button(visible=False),
+                 state,
+             )
+     yield (
+         gr.Button(
+             value="Click to compare models!", interactive=True, variant="primary"
+         ),
+         resp_1,
+         resp_2,
+         resp_3,
+         gr.Button(visible=True),
+         gr.Button(visible=False),
+         gr.Button(visible=True),
+         responses_complete(state),
+     )
+
+
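+ # state walks a small tutorial flow: 0 = fresh page, 1 = recording captured,
+ # 2 = responses shown, 3 = feedback requested; each gr.Info call below
+ # nudges the user one step forward.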
+ def on_page_load(state, model_order):
+     if state == 0:
+         gr.Info(
+             "Record what you want to say to your AI Assistant! All Audio recordings are stored only temporarily and will be erased as soon as you exit this page."
+         )
+         state = 1
+     if anonymous:
+         random.shuffle(model_order)
+     return state, model_order
+
+
+ def recording_complete(state):
+     if state == 1:
+         gr.Info(
+             "Submit your recording to get responses from all three models! You can also influence the model responses with an optional prompt."
+         )
+         state = 2
+     return (
+         gr.Button(
+             value="Click to compare models!", interactive=True, variant="primary"
+         ),
+         state,
+     )
+
+
+ def responses_complete(state):
+     if state == 2:
+         gr.Info(
+             "Give us your feedback! Mark which model gave you the best response so we can understand the quality of these different voice assistant models."
+         )
+         state = 3
+     return state
+
+
+ def clear_factory(button_id):
+     def clear(audio_input, text_prompt, model_order):
+         if button_id is not None:
+             sr, y = audio_input
+             db.insert(
+                 {
+                     "audio_hash": hash(str(y)),
+                     "text_prompt": text_prompt,
+                     "best": model_shorthand[model_order[button_id]],
+                 }
+             )
+         if anonymous:
+             random.shuffle(model_order)
+         return (
+             model_order,
+             gr.Button(
+                 value="Record Audio to Submit!",
+                 interactive=False,
+             ),
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             None,
+             gr.Textbox(visible=False),
+             gr.Textbox(visible=False),
+             gr.Textbox(visible=False),
+         )
+
+     return clear
+
+
+ theme = gr.themes.Soft(
+     primary_hue=gr.themes.Color(
+         c100="#82000019",
+         c200="#82000033",
+         c300="#8200004c",
+         c400="#82000066",
+         c50="#8200007f",
+         c500="#8200007f",
+         c600="#82000099",
+         c700="#820000b2",
+         c800="#820000cc",
+         c900="#820000e5",
+         c950="#820000f2",
+     ),
+     secondary_hue="rose",
+     neutral_hue="stone",
  )
 
+ db = TinyDB("user_study.json")
+
+ model_names = ["Llama 3 DiVA", "SALMONN", "Qwen Audio"]
+ model_shorthand = ["via", "salmonn", "qwen"]
+ with gr.Blocks(theme=theme) as demo:
+     state = gr.State(0)
+     model_order = gr.State([0, 1, 2])
+     with gr.Row():
+         audio_input = gr.Audio(
+             sources=["microphone"], streaming=False, label="Audio Input"
+         )
+     with gr.Row():
+         prompt = gr.Textbox(
+             value="",
+             label="Text Prompt",
+             placeholder="Optional: Additional text prompt to influence how the model responds to your speech. e.g. 'Respond in a Haiku style.'",
+         )
+
+     with gr.Row():
+         btn = gr.Button(value="Record Audio to Submit!", interactive=False)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             out1 = gr.Textbox(visible=False)
+             best1 = gr.Button(value="This response is best", visible=False)
+         with gr.Column(scale=1):
+             out2 = gr.Textbox(visible=False)
+             best2 = gr.Button(value="This response is best", visible=False)
+         with gr.Column(scale=1):
+             out3 = gr.Textbox(visible=False)
+             best3 = gr.Button(value="This response is best", visible=False)
+
+     audio_input.stop_recording(
+         recording_complete,
+         [state],
+         [btn, state],
+     )
+     audio_input.start_recording(
+         lambda: gr.Button(
+             value="Uploading Audio to Cloud", interactive=False, variant="primary"
+         ),
+         None,
+         btn,
+     )
+     btn.click(
+         fn=transcribe,
+         inputs=[audio_input, prompt, state, model_order],
+         outputs=[btn, out1, out2, out3, best1, best2, best3, state],
+     )
+     best1.click(
+         fn=clear_factory(0),
+         inputs=[audio_input, prompt, model_order],
+         outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
+     )
+     best2.click(
+         fn=clear_factory(1),
+         inputs=[audio_input, prompt, model_order],
+         outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
+     )
+     best3.click(
+         fn=clear_factory(2),
+         inputs=[audio_input, prompt, model_order],
+         outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
+     )
+     audio_input.clear(
+         clear_factory(None),
+         [audio_input, prompt, model_order],
+         [model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
+     )
+     demo.load(
+         fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order]
+     )
+
+ demo.launch(share=True)
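
All three model wrappers in the new app.py share the same preprocessing step, which is easy to miss inside the diff. A minimal standalone sketch of that step (the helper name preprocess_audio is illustrative, not part of the commit):

    import numpy as np
    from datasets import Audio

    resampler = Audio(sampling_rate=16_000)

    def preprocess_audio(audio_input):
        # audio_input is the (sample_rate, numpy array) tuple gr.Audio produces.
        sr, y = audio_input
        y = y.astype(np.float32)
        y /= np.max(np.abs(y))  # peak-normalize to [-1, 1]
        # Round-trip through datasets.Audio to resample to 16 kHz.
        a = resampler.decode_example(
            resampler.encode_example({"array": y, "sampling_rate": sr})
        )
        return a["array"], a["sampling_rate"]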
 
demo.py DELETED
@@ -1,407 +0,0 @@
- (407 lines deleted — demo.py was identical to the new app.py shown above)
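
The feedback path is the only persistent state in the Space: every "This response is best" click appends one record to a TinyDB JSON file. A minimal sketch of that record shape (the helper name log_vote is illustrative, not part of the commit):

    from tinydb import TinyDB

    db = TinyDB("user_study.json")

    def log_vote(y, text_prompt, best_shorthand):
        # Mirrors the record clear_factory() inserts on each vote.
        db.insert(
            {
                "audio_hash": hash(str(y)),  # coarse identifier for the recording
                "text_prompt": text_prompt,
                "best": best_shorthand,  # one of "via", "salmonn", "qwen"
            }
        )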