import torch
import os
import numpy as np
import copy
import gradio as gr
import sys
import spaces

from vita_audio.tokenizer import get_audio_tokenizer
from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, AutoConfig, GenerationConfig
PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." | |

import math
from numba import jit
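
# Convert float audio (roughly in [-1, 1]) to int16 PCM, scaling by the peak
# (rounded up) so samples stay within int16 range for playback in Gradio.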
def float_to_int16(audio: np.ndarray) -> np.ndarray:
    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
    am = 32767 * 32768 // am
    return np.multiply(audio, am).astype(np.int16)


def is_wav(file_path):
    wav_extensions = {'.wav'}
    _, ext = os.path.splitext(file_path)
    return ext.lower() in wav_extensions
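
# Render chat text as HTML for gr.Chatbot: fenced ``` blocks become
# <pre><code> sections, and special characters inside them are HTML-escaped.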
def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    return "".join(lines)
def _launch_demo(model, tokenizer, audio_tokenizer):

    # Assumption: the `@spaces.GPU` decorator is needed so that ZeroGPU Spaces
    # attach a GPU while this function runs (the model is placed on "cuda").
    @spaces.GPU
    def predict(_chatbot, task_history, task):
        chat_query = task_history[-1][0]
        print(task_history)
        messages = []
        audio_path_list = []
        if task == 'Spoken QA':
            messages = [
                {
                    "role": "system",
                    # "content": "Your Name: Luke\nYour Gender: male\n\nRespond in a text-audio interleaved manner.",
                    # "content": "Your Name: Lucy\nYour Gender: female\nRespond in a text-audio interleaved manner.",
                    "content": "Your Name: Omni\nYour Gender: female\nRespond in a text-audio interleaved manner.",
                },
            ]
            for i, (q, a) in enumerate(task_history):
                if isinstance(q, (tuple, list)) and is_wav(q[0]):
                    audio_path_list.append(q[0])
                    messages = messages + [
                        {
                            "role": "user",
                            "content": "\n<|audio|>",
                        },
                    ]
                else:
                    messages = messages + [
                        {
                            "role": "user",
                            "content": q,
                        },
                    ]
                if a is not None:
                    messages = messages + [
                        {
                            "role": "assistant",
                            "content": a,
                        },
                    ]
            model.generation_config.do_sample = False
        elif task == 'TTS':
            for i, (q, a) in enumerate(task_history):
                if isinstance(q, (tuple, list)) and is_wav(q[0]):
                    audio_path_list.append(q[0])
                    messages = messages + [
                        {
                            "role": "user",
                            "content": "\n<|audio|>",
                        },
                    ]
                else:
                    messages = messages + [
                        {
                            "role": "user",
                            "content": f'Convert the text to speech.\n{q}',
                        },
                    ]
                if a is not None:
                    messages = messages + [
                        {
                            "role": "assistant",
                            "content": a,
                        },
                    ]
            model.generation_config.do_sample = True
        elif task == 'ASR':
            for i, (q, a) in enumerate(task_history):
                if isinstance(q, (tuple, list)) and is_wav(q[0]):
                    audio_path_list.append(q[0])
                    messages = messages + [
                        {
                            "role": "user",
                            "content": "Convert the speech to text.\n<|audio|>",
                        },
                    ]
                else:
                    messages = messages + [
                        {
                            "role": "user",
                            "content": f"{q}",
                        },
                    ]
                if a is not None:
                    messages = messages + [
                        {
                            "role": "assistant",
                            "content": a,
                        },
                    ]
            model.generation_config.do_sample = False
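
        # Apply the chat template, then splice the tokenized audio (features
        # and their positions in the sequence) into the prompt.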
        add_generation_prompt = True
        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=add_generation_prompt,
            # return_tensors="pt",
        )
        input_ids, audios, audio_indices = add_audio_input_contiguous(
            input_ids, audio_path_list, tokenizer, audio_tokenizer
        )
        input_ids = torch.tensor([input_ids], dtype=torch.long).to("cuda")
        print("input", tokenizer.decode(input_ids[0], skip_special_tokens=False), flush=True)

        if audio_path_list == []:
            audios = None
            audio_indices = None
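
        # Generation returns a single sequence in which text tokens and
        # discrete audio tokens are interleaved.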
        outputs = model.generate(
            input_ids,
            audios=audios,
            audio_indices=audio_indices,
        )
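
        # Split the response into text tokens and audio-codec tokens: audio
        # tokens sit at or above the id of <|audio_0|> and are re-based before
        # being decoded to a waveform.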
        output = tokenizer.decode(outputs[0], skip_special_tokens=False)
        # print(f"{output=}", flush=True)

        audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
        begin_of_audio = tokenizer.convert_tokens_to_ids("<|begin_of_audio|>")
        end_of_audio = tokenizer.convert_tokens_to_ids("<|end_of_audio|>")
        im_end = tokenizer.convert_tokens_to_ids("<|im_end|>")

        response = outputs[0][len(input_ids[0]):]
        audio_tokens = []
        text_tokens = []
        for token_id in response:
            if token_id >= audio_offset:
                audio_tokens.append(token_id - audio_offset)
            elif (token_id.item() != begin_of_audio) and (token_id.item() != end_of_audio) and (token_id.item() != im_end):
                text_tokens.append(token_id)

        if len(audio_tokens) > 0:
            tts_speech = audio_tokenizer.decode(audio_tokens)
            audio_np = float_to_int16(tts_speech.cpu().numpy())
            tts_speech = (22050, audio_np)
        else:
            tts_speech = None
        # import pdb;pdb.set_trace()

        history_response = tokenizer.decode(text_tokens)
        task_history[-1] = (chat_query, history_response)
        _chatbot[-1] = (chat_query, history_response)
        # print("query", chat_query)
        # print("task_history", task_history)
        # print(_chatbot)
        # print("answer: ", outputs)
        return _chatbot, tts_speech
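
    # UI callbacks: push user text/audio into both the visible chat history
    # and the raw task history that predict() consumes.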
    def add_text(history, task_history, text):
        task_text = text
        # import pdb;pdb.set_trace()
        if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
            task_text = text[:-1]
        history = history + [(_parse_text(text), None)]
        task_history = task_history + [(task_text, None)]
        return history, task_history, ""

    def add_audio(history, task_history, file):
        print(file)
        if file is None:
            return history, task_history
        history = history + [((file,), None)]
        task_history = task_history + [((file,), None)]
        return history, task_history

    def reset_user_input():
        # import pdb;pdb.set_trace()
        return gr.update(value="")

    def reset_state(task_history):
        task_history.clear()
        return []
    font_size = "2.5em"
    html = f"""
    <p align="center" style="font-size: {font_size}; line-height: 1;">
        <span style="display: inline-block; vertical-align: middle;">VITA-Audio-Plus-Vanilla</span>
    </p>
    <center>
    <font size=3>
        <p>
        <b>VITA-Audio</b> has been fully open-sourced on <a href='https://huggingface.co/VITA-MLLM'>😊 Huggingface</a> and <a href='https://github.com/VITA-MLLM/VITA-Audio'>🌟 GitHub</a>. If you find VITA-Audio useful, a like ❤️ or a star 🌟 would be appreciated.
        </p>
    </font>
    <font size=3>
        <p>
        VITA-Audio-Plus-Vanilla is deployed here in a non-streaming manner.
        The ASR and TTS tasks support single-turn dialogue only. In the Spoken QA task, only the generated text is kept as dialogue history, to reduce the context length.
        </p>
    </font>
    </center>
    """
    with gr.Blocks(title="VITA-Audio-Plus-Vanilla") as demo:
        gr.HTML(html)
        chatbot = gr.Chatbot(label='VITA-Audio-Plus-Vanilla', elem_classes="control-height", height=500)
        query = gr.Textbox(lines=2, label='Text Input')
        task_history = gr.State([])

        with gr.Row():
            add_text_button = gr.Button("Submit Text (提交文本)")
            add_audio_button = gr.Button("Submit Audio (提交音频)")
            empty_bin = gr.Button("🧹 Clear History (清除历史)")

        task = gr.Radio(
            choices=["ASR", "TTS", "Spoken QA"], label="TASK", value='Spoken QA'
        )

        with gr.Row(scale=1):
            record_btn = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="🎤 Record or Upload Audio (录音或上传音频)",
                show_download_button=True,
                waveform_options=gr.WaveformOptions(sample_rate=16000),
            )
            audio_output = gr.Audio(label="Play", streaming=True, autoplay=True, show_download_button=True)

        add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
            reset_user_input, [], [query]
        ).then(
            predict, [chatbot, task_history, task], [chatbot, audio_output], show_progress=True
        )
        empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
        add_audio_button.click(add_audio, [chatbot, task_history, record_btn], [chatbot, task_history], show_progress=True).then(
            predict, [chatbot, task_history, task], [chatbot, audio_output], show_progress=True
        )

    demo.launch(
        show_error=True,
    )
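
# Entry point: download the GLM-4-Voice tokenizer/decoder, load the
# VITA-Audio model and text tokenizer, then launch the demo.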
def main():
    model_name_or_path = "VITA-MLLM/VITA-Audio-Plus-Vanilla"
    device_map = "cuda:0"

    sys.path.append("third_party/GLM-4-Voice/")
    sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
    sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")

    from huggingface_hub import snapshot_download

    audio_tokenizer_path = snapshot_download(repo_id="THUDM/glm-4-voice-tokenizer")
    flow_path = snapshot_download(repo_id="THUDM/glm-4-voice-decoder")
    audio_tokenizer_rank = 0
    audio_tokenizer_type = "sensevoice_glm4voice"
    torch_dtype = torch.bfloat16

    audio_tokenizer = get_audio_tokenizer(
        audio_tokenizer_path, audio_tokenizer_type, flow_path=flow_path, rank=audio_tokenizer_rank
    )
    audio_tokenizer.load_model()
    from evaluation.get_chat_template import qwen2_chat_template as chat_template

    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        chat_template=chat_template,
    )
    # print(f"{tokenizer=}")
    # print(f"{tokenizer.get_chat_template()=}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
        device_map=device_map,
        torch_dtype=torch_dtype,
        attn_implementation="flash_attention_2",
    ).eval()
    # print(f"{model.config.model_type=}")

    model.generation_config = GenerationConfig.from_pretrained(
        model_name_or_path, trust_remote_code=True
    )
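
    # Decoding defaults for the demo. `mtp_inference_mode = [8192, 10]` is
    # kept from the released configuration; it appears to control the
    # multi-token prediction (MTP) schedule used during inference.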
    model.generation_config.max_new_tokens = 4096
    model.generation_config.chat_format = "chatml"
    model.generation_config.max_window_size = 8192
    model.generation_config.use_cache = True
    model.generation_config.do_sample = True
    model.generation_config.temperature = 1.0
    model.generation_config.top_k = 50
    model.generation_config.top_p = 1.0
    model.generation_config.num_beams = 1
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.generation_config.mtp_inference_mode = [8192, 10]

    _launch_demo(model, tokenizer, audio_tokenizer)


if __name__ == '__main__':
    main()