# Voila-demo / app.py
# Author: guangyil — last commit "Update app.py" (37c3308, verified)
import spaces
import subprocess
import os
import sys

# Install flash-attn into the *running interpreter's* environment before the
# project modules imported below (which may depend on it) are loaded.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is passed to the install so flash-attn's
# setup skips compiling its CUDA extension locally (presumably relying on a
# prebuilt wheel instead — confirm against flash-attn's setup.py).
# An argument list (no shell) avoids shell parsing, and `sys.executable -m pip`
# guarantees the package lands in the environment this process imports from.
# The exit status is not checked, so a failed install does not abort startup.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
import tempfile
import random
import shutil
import pickle
import gradio as gr
import soundfile as sf
from pathlib import Path
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from infer import load_model, eval_model
from spkr import SpeakerEmbedding
@spaces.GPU
def spkr_model_init():
    """Create and return the speaker-embedding model, instantiated on the CPU."""
    return SpeakerEmbedding(device="cpu")
# spkr_model_init() builds the speaker-embedding model on the CPU; here its
# `.model` attribute is moved to CUDA and the wrapper's `device` is updated to
# match. NOTE(review): the return value of `.model.to(...)` is discarded, so this
# relies on `.model` being a torch nn.Module (whose .to() moves parameters in
# place) — confirm.
spkr_model = spkr_model_init()
spkr_model.model.to("cuda")
spkr_model.device = "cuda"
# load_model returns (model, tokenizer, tokenizer_voila, model_type); all four
# are passed straight through to eval_model by call_bot / run_tts / run_asr.
model, tokenizer, tokenizer_voila, model_type = load_model("maitrix-org/Voila-chat", "maitrix-org/Voila-Tokenizer")
model = model.to("cuda")
# NOTE(review): the return value of .to() is discarded, so this only moves the
# tokenizer if its .to() works in place (as nn.Module.to does) — confirm.
tokenizer_voila.to("cuda")
# Bundled demo voices (pickle mapping voice name -> reference embedding).
default_ref_file = "examples/character_ref_emb_demo.pkl"
# Voice selected at startup; must be a key of default_ref_emb_mask_list.
default_ref_name = "Homer Simpson"
# First chunk of the "Million Voice" dataset, downloaded from the HF Hub.
million_voice_ref_file = hf_hub_download(repo_id="maitrix-org/Voila-million-voice", filename="character_ref_emb_chunk0.pkl", repo_type="dataset")
# System instruction sent with every voice-chat request.
instruction = "You are a smart AI agent created by Maitrix.org."
# Generated audio is written under <save_path>/<session_hash>/.
save_path = os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
intro = """**Voila**
For more demos, please goto [https://voila.maitrix.org](https://voila.maitrix.org)."""


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards."""
    # SECURITY: pickle.load can execute arbitrary code. These files come from
    # this repository and the maitrix-org dataset; never pass a user-supplied path.
    with open(path, "rb") as f:
        return pickle.load(f)


# Both map a voice name to its reference embedding; keys feed the dropdown /
# random picker and values are stored in gr.State for inference.
default_ref_emb_mask_list = _load_pickle(default_ref_file)
million_voice_ref_emb_mask_list = _load_pickle(million_voice_ref_file)
@spaces.GPU
def get_ref_embs(ref_audio):
    """Compute the speaker embedding of a reference recording.

    Args:
        ref_audio: path to the recorded/uploaded audio file.

    Returns:
        The embedding produced by ``spkr_model``, moved to the CPU.
    """
    waveform, sample_rate = torchaudio.load(ref_audio)
    embedding = spkr_model(waveform, sample_rate)
    return embedding.cpu()
def delete_directory(request: gr.Request):
    """Remove a session's generated-audio directory when the session unloads.

    Output files for a session live in ``<save_path>/<session_hash>/``.
    Sessions without a session hash are ignored.

    The session hash is not validated anywhere else (and presumably originates
    from the client), so the directory is removed only if it resolves to a
    direct child of ``save_path``. Without this check a hash such as ``"."``
    or ``".."`` would make ``rmtree`` delete the temp root or its parent.
    """
    session_hash = request.session_hash
    if not session_hash:
        return
    user_dir = Path(f"{save_path}/{session_hash}")
    if user_dir.resolve().parent != Path(save_path).resolve():
        return
    # is_dir(): rmtree would raise on a plain file or a missing path.
    if user_dir.is_dir():
        shutil.rmtree(user_dir)
def add_message(history, message):
    """Append the user's audio clip to the chat history and reset the input.

    Returns:
        The updated history, an emptied gr.Audio, and a disabled Submit button.
    """
    user_turn = {"role": "user", "content": {"path": message}}
    history.append(user_turn)
    cleared_audio = gr.Audio(value=None)
    disabled_submit = gr.Button(interactive=False)
    return history, cleared_audio, disabled_submit
def _history_item_to_turn(item):
    """Convert one preprocessed chat message into an eval_model conversation turn.

    File (audio) messages are indexed as ``content[0]`` to get the file path.
    Plain-string content (a text reply) is forwarded as a ``'text'`` turn
    instead of being indexed, which would yield a single character.
    """
    content = item["content"]
    if isinstance(content, str):
        # NOTE(review): assumes eval_model accepts 'text' turns in "chat_aiao"
        # mode, as run_tts relies on for "chat_tts" — confirm.
        return {'from': item["role"], 'text': content}
    return {'from': item["role"], 'audio': {"file": content[0]}}


@spaces.GPU
def call_bot(history, ref_embs, request: gr.Request):
    """Generate the assistant's reply to the voice conversation so far.

    Args:
        history: chat history in Gradio "messages" format.
        ref_embs: reference speaker embedding from the tab's gr.State.
        request: injected by Gradio; its session hash names the directory the
            reply audio is written to.

    Returns:
        ``history`` with the assistant turn appended: an audio file path when
        the model returned audio, otherwise the generated text as a string.

    Raises:
        gr.Error: if the session hash does not map to a directory directly
            under ``save_path``.
    """
    formated_history = {
        "instruction": instruction,
        "conversations": [_history_item_to_turn(item) for item in history],
    }
    formated_history["conversations"].append({"from": "assistant"})
    print(formated_history)
    # Tensor.to() is not in-place; create the tensors on the GPU directly.
    # as_tensor also avoids copying (and warning) when ref_embs is a tensor.
    ref_embs = torch.as_tensor(ref_embs, dtype=torch.float32, device="cuda")
    ref_embs_mask = torch.tensor([1], device="cuda")
    out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_aiao", formated_history, ref_embs, ref_embs_mask, max_new_tokens=512)
    if 'audio' in out:
        wav, sr = out['audio']
        user_dir = Path(f"{save_path}/{request.session_hash}")
        # Only write into a direct child of save_path ("." / ".." / "a/b" rejected).
        if user_dir.resolve().parent != Path(save_path).resolve():
            raise gr.Error("Invalid session.")
        user_dir.mkdir(parents=True, exist_ok=True)
        save_name = f"{user_dir}/{len(history)}.wav"
        sf.write(save_name, wav, sr)
        history.append({"role": "assistant", "content": {"path": save_name}})
    else:
        # Plain string: the valid messages-format content for a text reply.
        history.append({"role": "assistant", "content": out['text']})
    return history
@spaces.GPU
def run_tts(text, ref_embs):
    """Synthesize ``text`` in the voice given by ``ref_embs``.

    Args:
        text: text to speak.
        ref_embs: reference speaker embedding from the tab's gr.State.

    Returns:
        ``(sample_rate, waveform)`` for the gr.Audio output.

    Raises:
        gr.Error: if ``text`` is empty/blank or the model returns no audio
            (gr.Error subclasses Exception and is shown to the user).
    """
    if not (text and text.strip()):
        raise gr.Error("Please enter some text to synthesize.")
    formated_history = {
        "instruction": "",
        "conversations": [{'from': "user", 'text': text}, {"from": "assistant"}],
    }
    # Tensor.to() is not in-place; create the tensors on the GPU directly.
    ref_embs = torch.as_tensor(ref_embs, dtype=torch.float32, device="cuda")
    ref_embs_mask = torch.tensor([1], device="cuda")
    out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_tts", formated_history, ref_embs, ref_embs_mask, max_new_tokens=512)
    if 'audio' not in out:
        raise gr.Error("No audio output")
    wav, sr = out['audio']
    return (sr, wav)
@spaces.GPU
def run_asr(audio):
    """Transcribe an audio file.

    Args:
        audio: path to the recorded/uploaded audio, or None if nothing was given.

    Returns:
        The transcript text returned by eval_model.

    Raises:
        gr.Error: if no audio was provided or the model returns no text
            (gr.Error subclasses Exception and is shown to the user).
    """
    if audio is None:
        raise gr.Error("Please record or upload audio first.")
    formated_history = {
        "instruction": "",
        "conversations": [{'from': "user", 'audio': {"file": audio}}, {"from": "assistant"}],
    }
    # ASR needs no reference voice, hence the two None arguments.
    out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_asr", formated_history, None, None, max_new_tokens=512)
    if 'text' not in out:
        raise gr.Error("No text output")
    return out['text']
def markdown_ref_name(ref_name):
    """Return the Markdown heading that names the active reference voice."""
    heading = "### Current voice id:"
    return f"{heading} {ref_name}"
def random_million_voice():
    """Pick a random voice from the Million Voice set.

    Returns:
        A ``(markdown label, reference embedding)`` pair for the chosen voice.
    """
    picked = random.choice(list(million_voice_ref_emb_mask_list))
    label = markdown_ref_name(picked)
    return label, million_voice_ref_emb_mask_list[picked]
def get_ref_modules(cur_ref_embs):
    """Build the reference-voice picker and wire every control to ``cur_ref_embs``.

    A voice can be chosen from the bundled dropdown, drawn at random from the
    Million Voice set, or computed from a custom recording/upload.

    Args:
        cur_ref_embs: gr.State holding the active reference embedding; the
            dropdown, random button, custom-voice checkbox and "Change voice"
            button all write to it.

    Returns:
        The gr.Row containing the picker.
    """
    with gr.Row() as ref_row:
        with gr.Row():
            current_ref_name = gr.Markdown(markdown_ref_name(default_ref_name))
        with gr.Row() as ref_name_row:
            with gr.Column(scale=2, min_width=160):
                ref_name_dropdown = gr.Dropdown(
                    choices=list(default_ref_emb_mask_list.keys()),
                    value=default_ref_name,
                    label="Reference voice",
                    min_width=160,
                )
            with gr.Column(scale=1, min_width=80):
                random_ref_button = gr.Button(
                    "Random from Million Voice", size="md",
                )
        with gr.Row(visible=False) as ref_audio_row:
            with gr.Column(scale=2, min_width=80):
                ref_audio = gr.Audio(
                    sources=["microphone", "upload"],
                    type="filepath",
                    show_label=False,
                    min_width=80,
                )
            with gr.Column(scale=1, min_width=80):
                change_ref_button = gr.Button(
                    "Change voice",
                    interactive=False,
                    min_width=80,
                )
        ref_name_dropdown.change(
            lambda x: (markdown_ref_name(x), default_ref_emb_mask_list[x]),
            ref_name_dropdown,
            [current_ref_name, cur_ref_embs]
        )
        random_ref_button.click(
            random_million_voice,
            None,
            [current_ref_name, cur_ref_embs],
        )
        # Enable "Change voice" once the user has recorded/uploaded audio.
        ref_audio.input(lambda: gr.Button(interactive=True), None, change_ref_button)
        # If the custom voice checkbox is checked, show the Audio component to
        # record or upload a reference voice.
        custom_ref_voice = gr.Checkbox(label="Use custom voice", value=False)

        def custom_ref_voice_change(use_custom, current_embs):
            """Toggle between the preset picker and the custom-voice recorder.

            Checked: hide the preset row, show the recorder, label the voice
            "Custom voice" and keep the current embedding until "Change voice"
            computes a new one. Unchecked: show the preset row again and reset
            label, embedding and dropdown to the default voice so all three agree.
            The signature matches the two wired inputs (checkbox, state).
            """
            if use_custom:
                return [
                    gr.Row(visible=False),
                    gr.Audio(value=None),
                    gr.Row(visible=True),
                    markdown_ref_name("Custom voice"),
                    current_embs,
                    gr.update(),  # leave the (hidden) dropdown untouched
                ]
            return [
                gr.Row(visible=True),
                gr.Audio(value=None),
                gr.Row(visible=False),
                markdown_ref_name(default_ref_name),
                default_ref_emb_mask_list[default_ref_name],
                gr.Dropdown(value=default_ref_name),
            ]

        custom_ref_voice.change(
            custom_ref_voice_change,
            [custom_ref_voice, cur_ref_embs],
            [ref_name_row, ref_audio, ref_audio_row, current_ref_name, cur_ref_embs, ref_name_dropdown]
        )
        # When "Change voice" is clicked, disable it, then compute the speaker
        # embedding of the recorded audio into the reference-voice state.
        change_ref_button.click(
            lambda: gr.Button(interactive=False), None, [change_ref_button]
        ).then(
            get_ref_embs, ref_audio, cur_ref_embs
        )
    return ref_row
def get_chat_tab():
    """Build the voice-chat tab: voice picker and audio input on the left, chat log on the right."""
    voice_state = gr.State(default_ref_emb_mask_list[default_ref_name])
    with gr.Row() as chat_tab:
        with gr.Column(scale=1):
            get_ref_modules(voice_state)
            # Voice chat input
            user_audio = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                show_label=False,
            )
            send_button = gr.Button("Submit", interactive=False)
            gr.Markdown(intro)
        with gr.Column(scale=9):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                type="messages",
                bubble_full_width=False,
                scale=1,
                show_copy_button=False,
                avatar_images=(None, None),  # no user / assistant avatars
            )
        # "Submit" becomes clickable only after the user provides audio.
        user_audio.input(lambda: gr.Button(interactive=True), None, send_button)
        # Post the user's clip into the chat first, then generate the reply.
        send_button.click(
            add_message, [chatbot, user_audio], [chatbot, user_audio, send_button]
        ).then(
            call_bot, [chatbot, voice_state], chatbot, api_name="bot_response"
        )
    return chat_tab
def get_tts_tab():
    """Build the text-to-speech tab: voice picker on the left, text input and audio output on the right."""
    voice_state = gr.State(default_ref_emb_mask_list[default_ref_name])
    with gr.Row() as tts_tab:
        with gr.Column(scale=1):
            get_ref_modules(voice_state)
            gr.Markdown(intro)
        with gr.Column(scale=9):
            speech_out = gr.Audio(label="TTS output", interactive=False)
            with gr.Row():
                text_box = gr.Textbox(label="Text", placeholder="Text to TTS")
                synth_button = gr.Button("Submit")
        synth_button.click(run_tts, [text_box, voice_state], speech_out)
    return tts_tab
def get_asr_tab():
    """Build the speech-recognition tab: audio input on the left, transcript on the right."""
    with gr.Row() as asr_tab:
        with gr.Column():
            speech_in = gr.Audio(
                label="ASR input",
                sources=["microphone", "upload"],
                type="filepath",
            )
            transcribe_button = gr.Button("Submit")
            gr.Markdown(intro)
        with gr.Column():
            transcript_out = gr.Textbox(label="ASR output", interactive=False)
        transcribe_button.click(run_asr, [speech_in], transcript_out)
    return asr_tab
# Assemble the app: one tab per task, shown in creation order.
with gr.Blocks(fill_height=True) as demo:
    with gr.Tab("Chat"):
        chat_tab = get_chat_tab()
    with gr.Tab("TTS"):
        tts_tab = get_tts_tab()
    with gr.Tab("ASR"):
        asr_tab = get_asr_tab()
    # When a browser session unloads, remove the audio saved for it.
    demo.unload(delete_directory)

if __name__ == "__main__":
    demo.launch()