Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import librosa | |
import spaces | |
import torch | |
from e2_tts_pytorch import DurationPredictor | |
from huggingface_hub import snapshot_download | |
from omegaconf import OmegaConf | |
from tokenizers import Tokenizer | |
from transformers import PreTrainedTokenizerFast | |
from ipa.ipa import get_ipa, parse_ipa | |
from patch.e2_tts_pytorch import E2TTSPatched as E2TTS | |
def _transformer_config(dim, depth, heads):
    """Return the transformer kwargs shared by the duration predictor and E2-TTS.

    Both networks use gated value heads and flash attention; only the size
    parameters differ.
    """
    return dict(
        dim=dim,
        depth=depth,
        heads=heads,
        attn_kwargs=dict(
            gate_value_heads=True,
            flash=True,
        ),
    )


def load_model(model_id):
    """Download ``model_id`` from the Hugging Face Hub and build an E2-TTS model.

    Registered below as the ``load_model`` OmegaConf resolver. The repository
    snapshot must contain ``e2tts.pt``, ``duration_predictor.pt`` and
    ``tokenizer.json``; both checkpoints are read through their
    ``"model_state_dict"`` entry.

    Returns:
        The E2TTS model in eval mode, with the (also eval-mode) duration
        predictor attached as ``e2tts.duration_predictor``. Parameters live on
        the CPU; callers move the model to the GPU themselves.
    """
    model_dir = snapshot_download(model_id)
    e2tts_ckpt_path = os.path.join(model_dir, "e2tts.pt")
    duration_predictor_ckpt_path = os.path.join(model_dir, "duration_predictor.pt")
    tokenizer_file_path = os.path.join(model_dir, "tokenizer.json")

    # Deserialize onto the CPU: the modules below are constructed on the CPU
    # and load_state_dict copies into them, so keeping the saved (possibly
    # CUDA) device would only fail on GPU-less hosts or waste GPU memory.
    # NOTE(review): torch.load may unpickle arbitrary objects depending on
    # torch's weights_only default -- only point this at trusted repos.
    duration_predictor_ckpt = torch.load(duration_predictor_ckpt_path, map_location="cpu")
    e2tts_ckpt = torch.load(e2tts_ckpt_path, map_location="cpu")

    tokenizer_object = Tokenizer.from_file(tokenizer_file_path)
    fast_tokenizer_object = PreTrainedTokenizerFast(tokenizer_object=tokenizer_object)

    def tokenizer(text):
        """Tokenize a batch of strings into a padded ``torch`` id tensor."""
        ids = fast_tokenizer_object(text, return_tensors="pt", padding=True).input_ids
        # Remap id 0 to -1. NOTE(review): presumably 0 is the pad id and the
        # E2-TTS code treats -1 as padding -- confirm against the tokenizer.
        ids[ids == 0] = -1
        return ids

    duration_predictor = DurationPredictor(
        transformer=_transformer_config(dim=384, depth=8, heads=6),
        text_num_embeds=fast_tokenizer_object.vocab_size,
        tokenizer=tokenizer,
    )
    duration_predictor.load_state_dict(duration_predictor_ckpt["model_state_dict"])

    e2tts = E2TTS(
        cond_drop_prob=0.2,
        transformer=_transformer_config(dim=512, depth=12, heads=6),
        text_num_embeds=fast_tokenizer_object.vocab_size,
        tokenizer=tokenizer,
    )
    e2tts.load_state_dict(e2tts_ckpt["model_state_dict"])

    duration_predictor.eval()
    e2tts.eval()
    e2tts.duration_predictor = duration_predictor
    return e2tts
# Make load_model() available to the YAML config as the "load_model" resolver.
# NOTE(review): presumably each model entry in configs/models.yaml uses a
# `${load_model:<repo_id>}` interpolation for its "model" key -- confirm.
OmegaConf.register_new_resolver("load_model", load_model)
_raw_models_config = OmegaConf.load("configs/models.yaml")
# to_object() resolves all interpolations eagerly, so any load_model() calls
# run here, at import time, and models_config ends up as plain containers.
models_config = OmegaConf.to_object(_raw_models_config)
def _do_tts(model_id, ipa, ref_wav, ref_transcript, speed, *, steps=32, cfg_strength=1.0):
    """Run E2-TTS sampling for one utterance and return the generated waveform.

    Args:
        model_id: key into ``models_config`` selecting the model.
        ipa: parsed IPA string for the target text; " <sil>" is appended.
        ref_wav: path to the reference audio, loaded with librosa and
            resampled to ``model.sampling_rate``.
        ref_transcript: transcript of the reference audio; it is prepended,
            with no separator, to the target IPA to form the text prompt.
        speed: forwarded to ``model.sample``.
        steps: number of sampling steps passed to ``model.sample`` (default 32).
        cfg_strength: guidance strength passed to ``model.sample`` (default 1.0).

    Returns:
        The first item of ``model.sample``'s output, as a numpy array.
    """
    # NOTE(review): `spaces` is imported at file top but no `@spaces.GPU`
    # decorator is visible on this function; if the Space runs on ZeroGPU
    # hardware it probably needs one -- confirm.
    with torch.inference_mode():
        # Module.cuda() moves the shared model in place and returns it.
        model = models_config[model_id]["model"].cuda()
        device = model.device
        ref_audio = librosa.load(ref_wav, sr=model.sampling_rate)[0]
        ipa = ipa + " <sil>"
        prompt = ref_transcript + ipa
        print(prompt)
        text = model.tokenizer([prompt]).to(device)
        # Put the conditioning audio on the same device as the text ids.
        cond = torch.from_numpy(ref_audio).float().unsqueeze(0).to(device)
        generated = model.sample(
            cond=cond,
            text=text,
            steps=steps,
            cfg_strength=cfg_strength,
            speed=speed,
        )[0]
        return generated.cpu().numpy()
def text_to_speech(
    model_id: str,
    use_default_or_custom: str,
    speaker_name: str,
    dialect: str,
    speed: float,
    text: str,
    ref_wav: str,
    ref_transcript: str,
):
    """Gradio callback: convert Hakka text to IPA, then synthesize speech.

    Args:
        model_id: key into ``models_config`` selecting the TTS model.
        use_default_or_custom: speaker-type radio value (not used here).
        speaker_name: selected preset speaker (not used here; the voice comes
            from ``ref_wav`` / ``ref_transcript``).
        dialect: dialect code passed to ``get_ipa``.
        speed: forwarded to ``_do_tts``.
        text: Hakka text to synthesize.
        ref_wav: path to the reference audio file.
        ref_transcript: transcript of the reference audio.

    Returns:
        ``(words, pinyin, (sampling_rate, waveform))`` for the three output
        components (segmentation, romanization, audio).

    Raises:
        gr.Error: if ``text`` is empty or whitespace-only, or if ``get_ipa``
            reports words it could not convert.
    """
    # Whitespace-only input is treated as empty too.
    if not text or not text.strip():
        raise gr.Error("請勿輸入空字串。")

    words, ipa, pinyin, missing_words = get_ipa(text, dialect=dialect)
    if missing_words:
        raise gr.Error(
            f"句子中的[{','.join(missing_words)}]目前無法轉成 ipa。請嘗試其他句子。"
        )

    # NOTE(review): get_ipa receives the raw dialect value (including
    # "nansixian"); nothing below depends on the dialect.
    parsed_ipa = parse_ipa(ipa)
    wav = _do_tts(
        model_id,
        parsed_ipa,
        ref_wav,
        ref_transcript,
        speed,
    )
    return (
        words,
        pinyin,
        (
            models_config[model_id]["model"].sampling_rate,
            wav,
        ),
    )
def when_model_selected(model_id):
    """Refresh speaker/dialect choices and reset the speaker type for a newly chosen model.

    Returns updates for, in order: the speaker dropdown (choices and the first
    speaker), the dialect radio (``(label, value)`` choices, the first value,
    and interactivity), and the speaker-type radio (reset to "預設語者").
    """
    model = models_config[model_id]
    speaker_names = list(model["speaker_mapping"].keys())
    dialect_choices = list(model["dialect_mapping"].items())
    return (
        gr.update(
            choices=speaker_names,
            value=speaker_names[0],
        ),
        gr.update(
            choices=dialect_choices,
            value=dialect_choices[0][1],
            # Match the initial UI: the radio is only interactive when the
            # model offers more than one dialect.
            interactive=len(dialect_choices) > 1,
        ),
        gr.update(
            value="預設語者",
        ),
    )
def when_default_speaker_selected(model_id, speaker_name):
    """Point the hidden reference audio and transcript at the chosen preset speaker.

    Returns a pair of updates: one for the reference-audio component (the
    speaker's ``ref_wav`` path) and one for the transcript textbox (the
    speaker's ``ref_transcript``).
    """
    speaker_entry = models_config[model_id]["speaker_mapping"][speaker_name]
    wav_update = gr.update(value=speaker_entry["ref_wav"])
    transcript_update = gr.update(value=speaker_entry["ref_transcript"])
    return wav_update, transcript_update
def use_default_or_custom_radio_input(use_default_or_custom):
    """Toggle between the custom-voice upload and the preset-speaker dropdown.

    Returns ``(custom_audio_update, speaker_dropdown_update)``. Exactly one of
    the two is visible: the upload when "客製化語者" is selected, otherwise the
    dropdown.
    """
    custom_selected = use_default_or_custom == "客製化語者"
    return gr.update(visible=custom_selected), gr.update(visible=not custom_selected)
# Font stack for the UI. "tauhu-oo" comes first; the `css` @import below loads
# tauhu-oo.css, which presumably defines that face. Generic sans-serif fonts
# follow as fallbacks.
_FONT_STACK = (
    "tauhu-oo",
    gr.themes.GoogleFont("Source Sans Pro"),
    "ui-sans-serif",
    "system-ui",
    "sans-serif",
)

demo = gr.Blocks(
    title="臺灣客語語音生成系統",
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    theme=gr.themes.Default(font=_FONT_STACK),
)
with demo:
    # The initial UI state comes from the first model in the config.
    # Component creation order below determines layout, so it is significant.
    default_model_id = list(models_config.keys())[0]
    _default_model = models_config[default_model_id]
    _default_speaker_names = list(_default_model["speaker_mapping"].keys())
    _default_speaker_entry = list(_default_model["speaker_mapping"].values())[0]
    _default_dialects = _default_model["dialect_mapping"]

    model_drop_down = gr.Dropdown(
        models_config.keys(),
        value=default_model_id,
        label="模型",
    )

    # Speaker-type selector. It is non-interactive for now, which keeps the UI
    # on preset speakers.
    use_default_or_custom_radio = gr.Radio(
        label="語者類型",
        choices=["預設語者", "客製化語者"],
        value="預設語者",
        visible=True,
        show_label=False,
        interactive=False,  # TODO
    )

    # Hidden holders for the reference voice passed to text_to_speech.
    ref_wav = gr.Audio(
        visible=False,
        type="filepath",
        value=_default_speaker_entry["ref_wav"],
        waveform_options=gr.WaveformOptions(
            show_controls=False,
            sample_rate=24000,
        ),
    )
    ref_transcript = gr.Textbox(
        value=_default_speaker_entry["ref_transcript"],
        visible=False,
    )

    # Upload slot for a custom voice (shown only in custom-speaker mode).
    speaker_wav = gr.Audio(
        label="客製化語音",
        visible=False,
        editable=False,
        type="filepath",
        waveform_options=gr.WaveformOptions(
            show_controls=False,
            sample_rate=24000,
        ),
    )

    speaker_drop_down = gr.Dropdown(
        choices=_default_speaker_names,
        value=_default_speaker_names[0],
        label="語者",
        interactive=True,
        visible=True,
    )

    # Choosing a preset speaker rewrites the hidden reference audio/transcript.
    speaker_drop_down.change(
        when_default_speaker_selected,
        inputs=[model_drop_down, speaker_drop_down],
        outputs=[ref_wav, ref_transcript],
    )
    # Switching speaker type toggles the upload slot versus the preset dropdown.
    use_default_or_custom_radio.change(
        use_default_or_custom_radio_input,
        inputs=[use_default_or_custom_radio],
        outputs=[speaker_wav, speaker_drop_down],
    )

    dialect_radio = gr.Radio(
        choices=list(_default_dialects.items()),
        value=list(_default_dialects.values())[0],
        label="腔調",
        interactive=len(_default_dialects) > 1,
    )

    # A user-initiated model change refreshes the speakers and dialects and
    # resets the speaker type.
    model_drop_down.input(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=[speaker_drop_down, dialect_radio, use_default_or_custom_radio],
    )

    input_text = gr.Textbox(
        label="輸入文字",
        value="",
    )
    speed = gr.Slider(maximum=1.5, minimum=0.5, value=1, label="語速(越大越慢)")

    gr.Markdown(
        """
# 臺灣客語語音合成系統
### Taiwanese Hakka Text-to-Speech System
### 研發團隊
- **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)([聯和科創](https://www.104.com.tw/company/1a2x6bmu75))**
- **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)([聯和科創](https://www.104.com.tw/company/1a2x6bmu75))**
### 合作單位
- **[國立聯合大學智慧客家實驗室](https://www.gohakka.org)**
"""
    )

    # Main interface: the input order must match text_to_speech's parameters.
    gr.Interface(
        text_to_speech,
        inputs=[
            model_drop_down,
            use_default_or_custom_radio,
            speaker_drop_down,
            dialect_radio,
            speed,
            input_text,
            ref_wav,
            ref_transcript,
        ],
        outputs=[
            gr.Textbox(interactive=False, label="斷詞"),
            gr.Textbox(interactive=False, label="客語拼音"),
            gr.Audio(interactive=False, label="合成語音", show_download_button=True),
        ],
        allow_flagging="auto",
    )

    # Each example row fills: speaker type, speaker, dialect, text.
    gr.Examples(
        [
            [
                "預設語者",
                "XF",
                "sixian",
                "食飯愛正經食,正毋會食到半出半入",
            ],
            [
                "預設語者",
                "XF",
                "sixian",
                "歸條路吊等長長个花燈,祈求風調雨順,歸屋下人个心願,親像花燈下燒暖个光華",
            ],
            # [
            #     "預設語者",
            #     "戴君儒",
            #     "hailu",
            #     "男女平等个時代,平平做得受教育",
            # ],
            # [
            #     "預設語者",
            #     "宋涵葳",
            #     "dapu",
            #     "客家山城乜跈緊鬧熱䟘來咧",
            # ],
            # [
            #     "預設語者",
            #     "江芮敏",
            #     "raoping",
            #     "頭擺匱人,戴个毋係菅草屋,个創商品哦",
            # ],
            # [
            #     "預設語者",
            #     "洪藝晅",
            #     "zhaoan",
            #     "歇熱个時務,阿松歸屋下轉去在客莊个老屋",
            # ],
            # [
            #     "預設語者",
            #     "江芮敏",
            #     "nansixian",
            #     "在𠊎讀小學一年生个時節,阿爸輒常用自轉車載𠊎去學校讀書",
            # ],
        ],
        label="範例",
        inputs=[
            use_default_or_custom_radio,
            speaker_drop_down,
            dialect_radio,
            input_text,
        ],
    )
# Start serving the Gradio app assembled in `demo`.
demo.launch()