# Hugging Face Space (ZeroGPU): head-to-head comparison of speech-in LLM assistants.

import hashlib  # backs the stable vote-dedup digests below
import os
import random
import shutil
from pathlib import Path

import gradio as gr
import numpy as np
import soundfile as sf
import spaces
import torch
from datasets import Audio
from huggingface_hub import CommitScheduler, hf_hub_download
from tinydb import TinyDB
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from transformers.generation import GenerationConfig

from models.salmonn import SALMONN
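
# Vote storage: a TinyDB JSON file, periodically committed to a Hugging Face
# dataset repo so results survive Space restarts.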
DB_NAME = "user_study.json"
DB_PATH = DB_NAME
DB_DATASET_ID = "WillHeld/DiVAVotes"

# Download the existing DB, if any, so votes accumulate across restarts
if not os.path.isfile(DB_PATH):
    print("Downloading DB...")
    try:
        cache_path = hf_hub_download(
            repo_id=DB_DATASET_ID, repo_type="dataset", filename=DB_NAME
        )
        shutil.copyfile(cache_path, DB_PATH)
        print("Downloaded DB")
    except Exception as e:
        print("Error while downloading DB:", e)

db = TinyDB(DB_PATH)

# Sync the local DB with the remote repo every 5 minutes (only if a change is detected)
scheduler = CommitScheduler(
    repo_id=DB_DATASET_ID,
    repo_type="dataset",
    folder_path=Path(DB_PATH).parent,
    every=5,
    allow_patterns=DB_NAME,
)
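
# Model setup: the three contenders are DiVA (streaming), SALMONN, and
# Qwen-Audio-Chat, all expecting 16 kHz mono audio.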
tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")
# Llama 3 chat-template special-token ids (unused elsewhere in this file):
# 128000 = <|begin_of_text|>, 128006 = <|start_header_id|>, 882 = "user",
# 128007 = <|end_header_id|>, 271 = "\n\n", 128009 = <|eot_id|>, 78191 = "assistant"
prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to("cuda")
pre_user_suffix = torch.tensor([271]).to("cuda")
final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to("cuda")
# When True, model identities are hidden and response order is shuffled per user.
anonymous = False
resampler = Audio(sampling_rate=16_000)
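
# Qwen-Audio-Chat is loaded in float16; do_sample=False makes decoding greedy
# and deterministic (the top_k/top_p settings are then inert).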
qwen_tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-Audio-Chat", trust_remote_code=True
)
qwen_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-Audio-Chat",
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,
).eval()
qwen_model.generation_config = GenerationConfig.from_pretrained(
    "Qwen/Qwen-Audio-Chat",
    trust_remote_code=True,
    do_sample=False,
    top_k=50,
    top_p=1.0,
)

# SALMONN requires these checkpoints to be downloaded locally before launch;
# the rest of the app (three output boxes, vote buttons, model_order indexing)
# assumes all three models are live.
salmonn_model = SALMONN(
    ckpt="./SALMONN_PATHS/salmonn_v1.pth",
    whisper_path="./SALMONN_PATHS/whisper-large-v2",
    beats_path="./SALMONN_PATHS/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt",
    vicuna_path="./SALMONN_PATHS/vicuna-13b-v1.1",
    low_resource=False,
    device="cuda:0",
)
salmonn_tokenizer = salmonn_model.llama_tokenizer

diva = AutoModel.from_pretrained("WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True)
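
# The three wrappers below share the same preprocessing: peak-normalize the
# recording, resample to 16 kHz via the `datasets` Audio feature, and (for the
# file-based models) round-trip through tmp.wav. The do_sample/temperature
# parameters are accepted for API symmetry, but generation is greedy.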
def salmonn_fwd(audio_input, prompt, do_sample=False, temperature=0.001):
    if audio_input is None:
        return
    sr, y = audio_input
    y = y.astype(np.float32)
    peak = np.max(np.abs(y))
    if peak > 0:  # avoid divide-by-zero on silent recordings
        y /= peak
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    sf.write("tmp.wav", a["array"], a["sampling_rate"], format="wav")
    streamer = TextIteratorStreamer(salmonn_tokenizer)
    with torch.cuda.amp.autocast(dtype=torch.float16):
        llm_message = salmonn_model.generate(
            wav_path="tmp.wav",
            prompt=prompt,
            do_sample=False,
            top_p=1.0,
            temperature=0.0,
            device="cuda:0",
            streamer=streamer,
        )
    response = ""
    for new_tokens in streamer:
        response += new_tokens
        yield response.replace("</s>", "")

def qwen_audio(audio_input, prompt, do_sample=False, temperature=0.001):
    if audio_input is None:
        return ""
    sr, y = audio_input
    y = y.astype(np.float32)
    peak = np.max(np.abs(y))
    if peak > 0:  # avoid divide-by-zero on silent recordings
        y /= peak
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    sf.write("tmp.wav", a["array"], a["sampling_rate"], format="wav")
    query = qwen_tokenizer.from_list_format([{"audio": "tmp.wav"}, {"text": prompt}])
    response, history = qwen_model.chat(
        qwen_tokenizer,
        query=query,
        system="You are a helpful assistant.",
        history=None,
    )
    return response

def via(audio_input, prompt, do_sample=False, temperature=0.001):
    if audio_input is None:
        return
    sr, y = audio_input
    y = y.astype(np.float32)
    peak = np.max(np.abs(y))
    if peak > 0:  # avoid divide-by-zero on silent recordings
        y /= peak
    a = resampler.decode_example(
        resampler.encode_example({"array": y, "sampling_rate": sr})
    )
    audio = a["array"]
    yield from diva.generate_stream(audio, prompt)
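
# Main handler: streams all three model responses into the UI. Every yield is
# an 8-tuple matching outputs=[btn, out1, out2, out3, best1, best2, best3, state].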
@spaces.GPU  # ZeroGPU Space: request a GPU slice for the duration of each call
def transcribe(audio_input, text_prompt, state, model_order):
    yield (
        gr.Button(
            value="Waiting in queue for GPU time...",
            interactive=False,
            variant="primary",
        ),
        "",
        "",
        "",
        gr.Button(visible=False),
        gr.Button(visible=False),
        gr.Button(visible=False),
        state,
    )
    if audio_input is None:
        # Yield a full 8-tuple before stopping; a `return value` inside a
        # generator never reaches Gradio.
        yield (
            gr.Button(value="Record Audio to Submit!", interactive=False),
            "",
            "",
            "",
            gr.Button(visible=False),
            gr.Button(visible=False),
            gr.Button(visible=False),
            state,
        )
        return

    # Each inner generator updates only its own textbox; the other two values
    # are read from the enclosing scope, which the driver loop below keeps
    # current (the initial ("", "", "") pass binds them before any model runs).
    def gen_from_via():
        via_resp = via(audio_input, text_prompt)
        for resp in via_resp:
            v_resp = gr.Textbox(
                value=resp,
                visible=True,
                label=model_names[0] if not anonymous else f"Model {order}",
            )
            yield (v_resp, s_resp, q_resp)

    def gen_from_salmonn():
        salmonn_resp = salmonn_fwd(audio_input, text_prompt)
        for resp in salmonn_resp:
            s_resp = gr.Textbox(
                value=resp,
                visible=True,
                label=model_names[1] if not anonymous else f"Model {order}",
            )
            yield (v_resp, s_resp, q_resp)

    def gen_from_qwen():
        qwen_resp = qwen_audio(audio_input, text_prompt)
        q_resp = gr.Textbox(
            value=qwen_resp,
            visible=True,
            label=model_names[2] if not anonymous else f"Model {order}",
        )
        yield (v_resp, s_resp, q_resp)

    spinner_id = 0
    # Four-frame spinner cycled on the submit button while responses stream.
    spinners = ["⠋", "⠙", "⠸", "⠴"]
    initial_responses = [("", "", "")]
    resp_generators = [
        gen_from_via(),
        gen_from_salmonn(),
        gen_from_qwen(),
    ]
    order = -1
    # Reorder the generators to match this user's (possibly shuffled) layout.
    resp_generators = [
        resp_generators[model_order[0]],
        resp_generators[model_order[1]],
        resp_generators[model_order[2]],
    ]
    for generator in [initial_responses, *resp_generators]:
        order += 1
        for resps in generator:
            v_resp, s_resp, q_resp = resps
            resp_1 = resps[model_order[0]]
            resp_2 = resps[model_order[1]]
            resp_3 = resps[model_order[2]]
            spinner = spinners[spinner_id]
            spinner_id = (spinner_id + 1) % 4
            yield (
                gr.Button(
                    value=spinner + " Generating Responses " + spinner,
                    interactive=False,
                    variant="primary",
                ),
                resp_1,
                resp_2,
                resp_3,
                gr.Button(visible=False),
                gr.Button(visible=False),
                gr.Button(visible=False),
                state,
            )
    yield (
        gr.Button(
            value="Click to compare models!", interactive=True, variant="primary"
        ),
        resp_1,
        resp_2,
        resp_3,
        gr.Button(visible=True),
        gr.Button(visible=True),
        gr.Button(visible=True),
        responses_complete(state),
    )
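
# Session state machine: 0 = fresh page, 1 = instructions shown, 2 = recording
# submitted, 3 = responses complete (voting unlocked). Each transition fires a
# one-time gr.Info toast.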
def on_page_load(state, model_order):
    if state == 0:
        gr.Info(
            "Record what you want to say to your AI Assistant! All audio recordings are stored only temporarily and will be erased as soon as you exit this page."
        )
        state = 1
        if anonymous:
            random.shuffle(model_order)
    return state, model_order

def recording_complete(state):
    if state == 1:
        gr.Info(
            "Submit your recording to get responses from all three models! You can also influence the model responses with an optional prompt."
        )
        state = 2
    return (
        gr.Button(
            value="Click to compare models!", interactive=True, variant="primary"
        ),
        state,
    )

def responses_complete(state):
    if state == 2:
        gr.Info(
            "Give us your feedback! Mark which model gave you the best response so we can understand the quality of these different voice assistant models. NOTE: This will save an (irreversible) hash of your inputs to deduplicate any repeated votes."
        )
        state = 3
    return state
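
# clear_factory(i) builds the click handler for vote button i; with
# button_id=None it only resets the UI (used for audio_input.clear).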
def clear_factory(button_id):
    def clear(audio_input, text_prompt, model_order):
        if button_id is not None:
            sr, y = audio_input
            # Store stable, irreversible digests: the built-in `hash` is salted
            # per process, so it could not deduplicate votes across restarts.
            with scheduler.lock:
                db.insert(
                    {
                        "audio_hash": hashlib.sha256(y.tobytes()).hexdigest(),
                        "text_prompt": hashlib.sha256(text_prompt.encode()).hexdigest(),
                        "best": model_shorthand[model_order[button_id]],
                    }
                )
        if anonymous:
            random.shuffle(model_order)
        return (
            model_order,
            gr.Button(
                value="Record Audio to Submit!",
                interactive=False,
            ),
            gr.Button(visible=False),
            gr.Button(visible=False),
            gr.Button(visible=False),
            None,
            gr.Textbox(visible=False),
            gr.Textbox(visible=False),
            gr.Textbox(visible=False),
        )

    return clear
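
# Theme: Gradio Soft theme over a dark-red (#820000) primary palette, with the
# alpha channel varied per shade.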
theme = gr.themes.Soft(
    primary_hue=gr.themes.Color(
        c50="#8200007f",
        c100="#82000019",
        c200="#82000033",
        c300="#8200004c",
        c400="#82000066",
        c500="#8200007f",
        c600="#82000099",
        c700="#820000b2",
        c800="#820000cc",
        c900="#820000e5",
        c950="#820000f2",
    ),
    secondary_hue="rose",
    neutral_hue="stone",
)

model_names = ["Llama 3 DiVA", "SALMONN", "Qwen Audio"]
model_shorthand = ["via", "salmonn", "qwen"]
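
# UI: one microphone input, an optional text prompt, and three response
# columns, each with its own "This response is best" vote button.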
with gr.Blocks(theme=theme) as demo:
    state = gr.State(0)
    model_order = gr.State([0, 1, 2])
    with gr.Row():
        audio_input = gr.Audio(
            sources=["microphone"], streaming=False, label="Audio Input"
        )
    with gr.Row():
        prompt = gr.Textbox(
            value="",
            label="Text Prompt",
            placeholder="Optional: Additional text prompt to influence how the model responds to your speech, e.g. 'Respond in a Haiku style.' or 'Translate the input to Arabic.'",
        )
    with gr.Row():
        btn = gr.Button(value="Record Audio to Submit!", interactive=False)
    with gr.Row():
        with gr.Column(scale=1):
            out1 = gr.Textbox(visible=False)
            best1 = gr.Button(value="This response is best", visible=False)
        with gr.Column(scale=1):
            out2 = gr.Textbox(visible=False)
            best2 = gr.Button(value="This response is best", visible=False)
        with gr.Column(scale=1):
            out3 = gr.Textbox(visible=False)
            best3 = gr.Button(value="This response is best", visible=False)

    audio_input.stop_recording(
        recording_complete,
        [state],
        [btn, state],
    )
    audio_input.start_recording(
        lambda: gr.Button(
            value="Uploading Audio to Cloud", interactive=False, variant="primary"
        ),
        None,
        btn,
    )
    btn.click(
        fn=transcribe,
        inputs=[audio_input, prompt, state, model_order],
        outputs=[btn, out1, out2, out3, best1, best2, best3, state],
    )
    best1.click(
        fn=clear_factory(0),
        inputs=[audio_input, prompt, model_order],
        outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
    )
    best2.click(
        fn=clear_factory(1),
        inputs=[audio_input, prompt, model_order],
        outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
    )
    best3.click(
        fn=clear_factory(2),
        inputs=[audio_input, prompt, model_order],
        outputs=[model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
    )
    audio_input.clear(
        clear_factory(None),
        [audio_input, prompt, model_order],
        [model_order, btn, best1, best2, best3, audio_input, out1, out2, out3],
    )
    demo.load(
        fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order]
    )

demo.launch(share=True)