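# Gradio app for the Taiwanese Hakka text-to-speech system
# (臺灣客語語音生成系統), built on E2-TTS with an IPA front end.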
import os
import gradio as gr
import librosa
import spaces
import torch
from e2_tts_pytorch import DurationPredictor
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
from ipa.ipa import get_ipa, parse_ipa
from patch.e2_tts_pytorch import E2TTSPatched as E2TTS
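
# Download a model snapshot from the Hugging Face Hub and assemble the E2-TTS
# network, its duration predictor, and the IPA tokenizer.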
def load_model(model_id):
model_dir = snapshot_download(model_id)
e2tts_ckpt_path = os.path.join(model_dir, "e2tts.pt")
duration_predictor_ckpt_path = os.path.join(model_dir, "duration_predictor.pt")
tokenizer_file_path = os.path.join(model_dir, "tokenizer.json")
    duration_predictor_ckpt = torch.load(duration_predictor_ckpt_path, map_location="cpu")
    e2tts_ckpt = torch.load(e2tts_ckpt_path, map_location="cpu")
tokenizer_object = Tokenizer.from_file(tokenizer_file_path)
fast_tokenizer_object = PreTrainedTokenizerFast(tokenizer_object=tokenizer_object)
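
    # The fast tokenizer pads with id 0, but e2-tts-pytorch expects -1 as the
    # padding/filler id, so pad positions are remapped after encoding.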
def tokenizer(text):
ids = fast_tokenizer_object(text, return_tensors="pt", padding=True).input_ids
ids[ids == 0] = -1
return ids
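
    # Rebuild both networks with the hyperparameters used at training time,
    # then load the checkpoint weights.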
duration_predictor = DurationPredictor(
transformer=dict(
dim=384,
depth=8,
heads=6,
attn_kwargs=dict(
gate_value_heads=True,
                flash=True,
),
),
text_num_embeds=fast_tokenizer_object.vocab_size,
tokenizer=tokenizer,
)
duration_predictor.load_state_dict(duration_predictor_ckpt["model_state_dict"])
e2tts = E2TTS(
cond_drop_prob=0.2,
transformer=dict(
dim=512,
depth=12,
heads=6,
attn_kwargs=dict(
gate_value_heads=True,
                flash=True,
),
),
text_num_embeds=fast_tokenizer_object.vocab_size,
tokenizer=tokenizer,
)
e2tts.load_state_dict(e2tts_ckpt["model_state_dict"])
duration_predictor.eval()
e2tts.eval()
e2tts.duration_predictor = duration_predictor
return e2tts
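
# configs/models.yaml builds each entry's "model" field through the
# ${load_model:...} resolver, so every configured model is loaded at startup.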
OmegaConf.register_new_resolver("load_model", load_model)
models_config = OmegaConf.to_object(OmegaConf.load("configs/models.yaml"))
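
# ZeroGPU: a GPU is attached only while this function runs, so models are
# moved to CUDA here rather than at load time.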
@spaces.GPU
def _do_tts(model_id, ipa, ref_wav, ref_transcript, speed):
with torch.inference_mode():
model = models_config[model_id]["model"].cuda()
ref_wav = librosa.load(ref_wav, sr=model.sampling_rate)[0]
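        # Append a trailing silence token so the utterance ends cleanly.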
ipa = ipa + " <sil>"
print(ref_transcript + ipa)
text = model.tokenizer([ref_transcript + ipa]).to(model.device)
generated = model.sample(
cond=torch.from_numpy(ref_wav).float().unsqueeze(0).cuda(),
text=text,
steps=32,
cfg_strength=1.0,
speed=speed,
)[0]
return generated.cpu().numpy()
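
# Full pipeline: segment the text, convert it to IPA for the selected dialect,
# then synthesize speech conditioned on the reference audio and transcript.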
def text_to_speech(
model_id: str,
use_default_or_custom: str,
speaker_name: str,
dialect: str,
speed: float,
text: str,
ref_wav: str,
ref_transcript: str,
):
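    # User-facing errors are in Chinese: "please do not enter an empty string"
    # and "[words] cannot currently be converted to IPA; try another sentence".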
if len(text) == 0:
raise gr.Error("請勿輸入空字串。")
words, ipa, pinyin, missing_words = get_ipa(text, dialect=dialect)
if len(missing_words) > 0:
raise gr.Error(
f"句子中的[{','.join(missing_words)}]目前無法轉成 ipa。請嘗試其他句子。"
)
parsed_ipa = parse_ipa(ipa)
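    # Southern Sixian (nansixian) shares the Sixian voice.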
if dialect == "nansixian":
dialect = "sixian"
wav = _do_tts(
model_id,
parsed_ipa,
ref_wav,
ref_transcript,
speed,
)
return (
words,
pinyin,
(
models_config[model_id]["model"].sampling_rate,
wav,
),
)
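
# Refresh speaker and dialect choices when a different model is selected and
# reset the speaker-type radio to the default speaker.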
def when_model_selected(model_id):
model = models_config[model_id]
return (
gr.update(
            choices=list(model["speaker_mapping"].keys()),
            value=list(model["speaker_mapping"].keys())[0],
),
gr.update(
            choices=list(model["dialect_mapping"].items()),
value=list(model["dialect_mapping"].values())[0],
),
gr.update(
value="預設語者",
),
)
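
# Load the selected default speaker's reference audio and transcript into the
# hidden reference components.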
def when_default_speaker_selected(model_id, speaker_name):
speaker_mapping = models_config[model_id]["speaker_mapping"]
ref_wav_path = speaker_mapping[speaker_name]["ref_wav"]
ref_transcript = speaker_mapping[speaker_name]["ref_transcript"]
return gr.update(
value=ref_wav_path,
), gr.update(
value=ref_transcript,
)
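
# Show the custom-audio upload when "客製化語者" (custom speaker) is selected;
# otherwise show the default-speaker dropdown.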
def use_default_or_custom_radio_input(use_default_or_custom):
if use_default_or_custom == "客製化語者":
return gr.update(visible=True), gr.update(visible=False)
return gr.update(visible=False), gr.update(visible=True)
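
# Build the Gradio UI.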
demo = gr.Blocks(
title="臺灣客語語音生成系統",
css="@import url(https://tauhu.tw/tauhu-oo.css);",
theme=gr.themes.Default(
font=(
"tauhu-oo",
gr.themes.GoogleFont("Source Sans Pro"),
"ui-sans-serif",
"system-ui",
"sans-serif",
)
),
)
with demo:
default_model_id = list(models_config.keys())[0]
model_drop_down = gr.Dropdown(
models_config.keys(),
value=default_model_id,
label="模型",
)
use_default_or_custom_radio = gr.Radio(
label="語者類型",
choices=["預設語者", "客製化語者"],
value="預設語者",
visible=True,
show_label=False,
interactive=False, # TODO
)
ref_wav = gr.Audio(
visible=False,
type="filepath",
value=list(models_config[default_model_id]["speaker_mapping"].values())[0][
"ref_wav"
],
waveform_options=gr.WaveformOptions(
show_controls=False,
sample_rate=24000,
),
)
ref_transcript = gr.Textbox(
value=list(models_config[default_model_id]["speaker_mapping"].values())[0][
"ref_transcript"
],
visible=False,
)
speaker_wav = gr.Audio(
label="客製化語音",
visible=False,
editable=False,
type="filepath",
waveform_options=gr.WaveformOptions(
show_controls=False,
sample_rate=24000,
),
)
speaker_drop_down = gr.Dropdown(
        choices=list(models_config[default_model_id]["speaker_mapping"].keys()),
value=list(models_config[default_model_id]["speaker_mapping"].keys())[0],
label="語者",
interactive=True,
visible=True,
)
speaker_drop_down.change(
when_default_speaker_selected,
inputs=[model_drop_down, speaker_drop_down],
outputs=[ref_wav, ref_transcript],
)
use_default_or_custom_radio.change(
use_default_or_custom_radio_input,
inputs=[use_default_or_custom_radio],
outputs=[speaker_wav, speaker_drop_down],
)
dialect_radio = gr.Radio(
        choices=list(models_config[default_model_id]["dialect_mapping"].items()),
value=list(models_config[default_model_id]["dialect_mapping"].values())[0],
label="腔調",
interactive=len(models_config[default_model_id]["dialect_mapping"]) > 1,
)
model_drop_down.input(
when_model_selected,
inputs=[model_drop_down],
outputs=[speaker_drop_down, dialect_radio, use_default_or_custom_radio],
)
input_text = gr.Textbox(
label="輸入文字",
value="",
)
speed = gr.Slider(maximum=1.5, minimum=0.5, value=1, label="語速(越大越慢)")
gr.Markdown(
"""
# 臺灣客語語音合成系統
### Taiwanese Hakka Text-to-Speech System
### 研發團隊
- **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)([聯和科創](https://www.104.com.tw/company/1a2x6bmu75))**
- **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)([聯和科創](https://www.104.com.tw/company/1a2x6bmu75))**
### 合作單位
- **[國立聯合大學智慧客家實驗室](https://www.gohakka.org)**
"""
)
gr.Interface(
text_to_speech,
inputs=[
model_drop_down,
use_default_or_custom_radio,
speaker_drop_down,
dialect_radio,
speed,
input_text,
ref_wav,
ref_transcript,
],
outputs=[
gr.Textbox(interactive=False, label="斷詞"),
gr.Textbox(interactive=False, label="客語拼音"),
gr.Audio(interactive=False, label="合成語音", show_download_button=True),
],
allow_flagging="auto",
)
gr.Examples(
[
[
"預設語者",
"XF",
"sixian",
"食飯愛正經食,正毋會食到半出半入",
],
[
"預設語者",
"XF",
"sixian",
"歸條路吊等長長个花燈,祈求風調雨順,歸屋下人个心願,親像花燈下燒暖个光華",
],
# [
# "預設語者",
# "戴君儒",
# "hailu",
# "男女平等个時代,平平做得受教育",
# ],
# [
# "預設語者",
# "宋涵葳",
# "dapu",
# "客家山城乜跈緊鬧熱䟘來咧",
# ],
# [
# "預設語者",
# "江芮敏",
# "raoping",
# "頭擺匱人,戴个毋係菅草屋,个創商品哦",
# ],
# [
# "預設語者",
# "洪藝晅",
# "zhaoan",
# "歇熱个時務,阿松歸屋下轉去在客莊个老屋",
# ],
# [
# "預設語者",
# "江芮敏",
# "nansixian",
# "在𠊎讀小學一年生个時節,阿爸輒常用自轉車載𠊎去學校讀書",
# ],
],
label="範例",
inputs=[
use_default_or_custom_radio,
speaker_drop_down,
dialect_radio,
input_text,
],
)
demo.launch()