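# Gradio demo for Voila (maitrix-org): spoken chat, TTS, and ASR in one app,
# built to run on a Hugging Face Space with ZeroGPU (the `spaces` package).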
import spaces

import subprocess
import os
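
# Install FlashAttention at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips
# the slow from-source CUDA compile, a common workaround on Spaces.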
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    shell=True,
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
)

import tempfile
import random
import shutil
import pickle
import gradio as gr
import soundfile as sf
from pathlib import Path

import torch
import torchaudio

from huggingface_hub import hf_hub_download

from infer import load_model, eval_model
from spkr import SpeakerEmbedding
    
@spaces.GPU
def spkr_model_init():
    # Build the speaker-embedding model on CPU; on ZeroGPU, CUDA access is only
    # guaranteed inside @spaces.GPU-decorated functions.
    return SpeakerEmbedding(device="cpu")

spkr_model = spkr_model_init()
# Module-level .to("cuda") is safe on ZeroGPU: the `spaces` runtime handles
# device placement for models loaded in the global scope.
spkr_model.model.to("cuda")
spkr_model.device = "cuda"
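
# Load the Voila chat model, its text tokenizer, and the audio tokenizer.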
model, tokenizer, tokenizer_voila, model_type = load_model("maitrix-org/Voila-chat", "maitrix-org/Voila-Tokenizer")
model = model.to("cuda")
tokenizer_voila.to("cuda")
default_ref_file = "examples/character_ref_emb_demo.pkl"
default_ref_name = "Homer Simpson"
million_voice_ref_file = hf_hub_download(repo_id="maitrix-org/Voila-million-voice", filename="character_ref_emb_chunk0.pkl", repo_type="dataset")

instruction = "You are a smart AI agent created by Maitrix.org."
save_path = os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())

intro = """**Voila**

For more demos, please go to [https://voila.maitrix.org](https://voila.maitrix.org)."""

# Reference-embedding tables mapping a voice id to its speaker embedding: one
# small set bundled with the demo, one chunk of the million-voice dataset.
with open(default_ref_file, "rb") as f:
    default_ref_emb_mask_list = pickle.load(f)
with open(million_voice_ref_file, "rb") as f:
    million_voice_ref_emb_mask_list = pickle.load(f)

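# Embed a recorded or uploaded reference clip with the speaker model.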
@spaces.GPU
def get_ref_embs(ref_audio):
    wav, sr = torchaudio.load(ref_audio)
    ref_embs = spkr_model(wav, sr).cpu()
    return ref_embs

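# Delete this session's generated audio when the browser session ends.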
def delete_directory(request: gr.Request):
    if not request.session_hash:
        return
    user_dir = Path(f"{save_path}/{str(request.session_hash)}")
    if user_dir.exists():
        shutil.rmtree(str(user_dir))

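# Append the user's recorded clip to the chat history and reset the inputs.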
def add_message(history, message):
    history.append({"role": "user", "content": {"path": message}})
    return history, gr.Audio(value=None), gr.Button(interactive=False)
    
@spaces.GPU
def call_bot(history, ref_embs, request: gr.Request):
    # Convert the Gradio message history into eval_model's conversation format;
    # each message's content carries the audio file path at index 0.
    formatted_history = {
        "instruction": instruction,
        "conversations": [{'from': item["role"], 'audio': {"file": item["content"][0]}} for item in history],
    }
    formatted_history["conversations"].append({"from": "assistant"})
    print(formatted_history)
    ref_embs = torch.tensor(ref_embs, dtype=torch.float32, device="cpu")
    ref_embs_mask = torch.tensor([1], device="cpu")
    # Tensor.to() is not in-place; reassign, or the tensors stay on the CPU.
    ref_embs = ref_embs.to("cuda")
    ref_embs_mask = ref_embs_mask.to("cuda")
    out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_aiao", formatted_history, ref_embs, ref_embs_mask, max_new_tokens=512)
    if 'audio' in out:
        wav, sr = out['audio']

        user_dir = Path(f"{save_path}/{str(request.session_hash)}")
        user_dir.mkdir(parents=True, exist_ok=True)
        save_name = f"{user_dir}/{len(history)}.wav"
        sf.write(save_name, wav, sr)

        history.append({"role": "assistant", "content": {"path": save_name}})
    else:
        history.append({"role": "assistant", "content": {"text": out['text']}})

    return history
    
@spaces.GPU
def run_tts(text, ref_embs):
    formatted_history = {
        "instruction": "",
        "conversations": [{'from': "user", 'text': text}],
    }
    formatted_history["conversations"].append({"from": "assistant"})
    ref_embs = torch.tensor(ref_embs, dtype=torch.float32, device="cpu")
    ref_embs_mask = torch.tensor([1], device="cpu")
    # Tensor.to() returns a new tensor; reassign so the move actually happens.
    ref_embs = ref_embs.to("cuda")
    ref_embs_mask = ref_embs_mask.to("cuda")
    out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_tts", formatted_history, ref_embs, ref_embs_mask, max_new_tokens=512)
    if 'audio' in out:
        wav, sr = out['audio']
        return (sr, wav)
    else:
        raise Exception("No audio output")
        
@spaces.GPU
def run_asr(audio):
    formatted_history = {
        "instruction": "",
        "conversations": [{'from': "user", 'audio': {"file": audio}}],
    }
    formatted_history["conversations"].append({"from": "assistant"})
    out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_asr", formatted_history, None, None, max_new_tokens=512)
    if 'text' in out:
        return out['text']
    else:
        raise Exception("No text output")


def markdown_ref_name(ref_name):
    return f"### Current voice id: {ref_name}"

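# Draw a random voice id from the million-voice embedding table.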
def random_million_voice():
    voice_id = random.choice(list(million_voice_ref_emb_mask_list.keys()))
    return markdown_ref_name(voice_id), million_voice_ref_emb_mask_list[voice_id]

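# Shared voice-selection UI: a dropdown of bundled voices, a random pick from
# the million-voice set, or a custom recorded/uploaded reference clip.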
def get_ref_modules(cur_ref_embs):
    with gr.Row() as ref_row:
        with gr.Row():
            current_ref_name = gr.Markdown(markdown_ref_name(default_ref_name))
        with gr.Row() as ref_name_row:
            with gr.Column(scale=2, min_width=160):
                ref_name_dropdown = gr.Dropdown(
                    choices=list(default_ref_emb_mask_list.keys()),
                    value=default_ref_name,
                    label="Reference voice",
                    min_width=160,
                )
            with gr.Column(scale=1, min_width=80):
                random_ref_button = gr.Button(
                    "Random from Million Voice", size="md",
                )
        with gr.Row(visible=False) as ref_audio_row:
            with gr.Column(scale=2, min_width=80):
                ref_audio = gr.Audio(
                    sources=["microphone", "upload"],
                    type="filepath",
                    show_label=False,
                    min_width=80,
                )
            with gr.Column(scale=1, min_width=80):
                change_ref_button = gr.Button(
                    "Change voice",
                    interactive=False,
                    min_width=80,
                )
    ref_name_dropdown.change(
        lambda x: (markdown_ref_name(x), default_ref_emb_mask_list[x]),
        ref_name_dropdown,
        [current_ref_name, cur_ref_embs]
    )
    random_ref_button.click(
        random_million_voice,
        None,
        [current_ref_name, cur_ref_embs],
    )
    ref_audio.input(lambda: gr.Button(interactive=True), None, change_ref_button)
    # If custom ref voice checkbox is checked, show the Audio component to record or upload a reference voice
    custom_ref_voice = gr.Checkbox(label="Use custom voice", value=False)
    # Checked: show the custom-audio row; unchecked: restore the named-voice
    # row and fall back to the default reference embedding.
    def custom_ref_voice_change(x, cur_ref_embs):
        if not x:
            cur_ref_embs = default_ref_emb_mask_list[default_ref_name]
        ref_name = "Custom voice" if x else default_ref_name
        return [gr.Row(visible=not x), gr.Audio(value=None), gr.Row(visible=x), markdown_ref_name(ref_name), cur_ref_embs]
    custom_ref_voice.change(
        custom_ref_voice_change,
        [custom_ref_voice, cur_ref_embs],
        [ref_name_row, ref_audio, ref_audio_row, current_ref_name, cur_ref_embs]
    )
    # When change ref button is clicked, get the reference voice and update the reference voice state
    change_ref_button.click(
        lambda: gr.Button(interactive=False), None, [change_ref_button]
    ).then(
        get_ref_embs, ref_audio, cur_ref_embs
    )
    return ref_row

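# Chat tab: record or upload a clip, submit it, and receive a spoken reply.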
def get_chat_tab():
    cur_ref_embs = gr.State(default_ref_emb_mask_list[default_ref_name])
    with gr.Row() as chat_tab:
        with gr.Column(scale=1):
            ref_row = get_ref_modules(cur_ref_embs)
            # Voice chat input
            chat_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                show_label=False,
            )
            submit = gr.Button("Submit", interactive=False)
            gr.Markdown(intro)
        with gr.Column(scale=9):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                type="messages",
                bubble_full_width=False,
                scale=1,
                show_copy_button=False,
                avatar_images=(
                    None,  # os.path.join("files", "avatar.png"),
                    None, # os.path.join("files", "avatar.png"),
                ),
            )

        chat_input.input(lambda: gr.Button(interactive=True), None, submit)
        submit.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input, submit]
        ).then(
            call_bot, [chatbot, cur_ref_embs], chatbot, api_name="bot_response"
        )
    return chat_tab

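# TTS tab: synthesize typed text in the currently selected reference voice.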
def get_tts_tab():
    cur_ref_embs = gr.State(default_ref_emb_mask_list[default_ref_name])
    with gr.Row() as tts_tab:
        with gr.Column(scale=1):
            ref_row = get_ref_modules(cur_ref_embs)
            gr.Markdown(intro)
        with gr.Column(scale=9):
            tts_output = gr.Audio(label="TTS output", interactive=False)
            with gr.Row():
                text_input = gr.Textbox(label="Text", placeholder="Text to TTS")
                submit = gr.Button("Submit")
        submit.click(
            run_tts, [text_input, cur_ref_embs], tts_output
        )
    return tts_tab

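# ASR tab: transcribe a recorded or uploaded clip to text.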
def get_asr_tab():
    with gr.Row() as asr_tab:
        with gr.Column():
            asr_input = gr.Audio(
                label="ASR input",
                sources=["microphone", "upload"],
                type="filepath",
            )
            submit = gr.Button("Submit")
            gr.Markdown(intro)
        with gr.Column():
            asr_output = gr.Textbox(label="ASR output", interactive=False)
    submit.click(
        run_asr, [asr_input], asr_output
    )
    return asr_tab

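# Assemble the three tabs; per-session files are cleaned up on unload.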
with gr.Blocks(fill_height=True) as demo:
    with gr.Tab("Chat"):
        chat_tab = get_chat_tab()
    with gr.Tab("TTS"):
        tts_tab = get_tts_tab()
    with gr.Tab("ASR"):
        asr_tab = get_asr_tab()
    demo.unload(delete_directory)

if __name__ == "__main__":
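    # Hypothetical tuning, not in the original source: ZeroGPU Spaces often
    # also enable the request queue to serialize concurrent GPU calls, e.g.
    # demo.queue(default_concurrency_limit=1).launch()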
    demo.launch()