"""Gradio demo app for a Taiwanese Hakka text-to-speech system.

Loads E2-TTS models described in ``configs/models.yaml``, converts input text
to IPA, and synthesizes speech conditioned on a reference recording.
"""

import os

import gradio as gr
import librosa
import spaces
import torch
from e2_tts_pytorch import DurationPredictor
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast

from ipa.ipa import get_ipa, parse_ipa
from patch.e2_tts_pytorch import E2TTSPatched as E2TTS


def load_model(model_id):
    """Download a model snapshot and build the E2-TTS model from it.

    Args:
        model_id: Hugging Face Hub repository id passed to
            ``snapshot_download``. The snapshot must contain ``e2tts.pt``,
            ``duration_predictor.pt`` and ``tokenizer.json``.

    Returns:
        The ``E2TTS`` model in eval mode, with its ``duration_predictor``
        attribute set to the loaded duration predictor.
    """
    model_dir = snapshot_download(model_id)
    e2tts_ckpt_path = os.path.join(model_dir, "e2tts.pt")
    duration_predictor_ckpt_path = os.path.join(model_dir, "duration_predictor.pt")
    tokenizer_file_path = os.path.join(model_dir, "tokenizer.json")

    # Deserialize onto the CPU: this runs at import time, when a CUDA device
    # may not be available, and a checkpoint saved from a GPU would otherwise
    # fail to load. The model is moved to the GPU later, inside `_do_tts`.
    duration_predictor_ckpt = torch.load(
        duration_predictor_ckpt_path, map_location="cpu"
    )
    e2tts_ckpt = torch.load(e2tts_ckpt_path, map_location="cpu")

    tokenizer_object = Tokenizer.from_file(tokenizer_file_path)
    fast_tokenizer_object = PreTrainedTokenizerFast(tokenizer_object=tokenizer_object)

    def tokenizer(text):
        """Tokenize a batch of strings into a padded tensor of token ids."""
        ids = fast_tokenizer_object(text, return_tensors="pt", padding=True).input_ids
        # Every id equal to 0 is rewritten to -1.
        # NOTE(review): presumably 0 is the padding id and -1 is the value the
        # model treats as "ignore" -- confirm against the tokenizer/model.
        ids[ids == 0] = -1
        return ids

    duration_predictor = DurationPredictor(
        transformer=dict(
            dim=384,
            depth=8,
            heads=6,
            attn_kwargs=dict(
                gate_value_heads=True,
                flash=True,
            ),
        ),
        text_num_embeds=fast_tokenizer_object.vocab_size,
        tokenizer=tokenizer,
    )
    duration_predictor.load_state_dict(duration_predictor_ckpt["model_state_dict"])

    e2tts = E2TTS(
        cond_drop_prob=0.2,
        transformer=dict(
            dim=512,
            depth=12,
            heads=6,
            attn_kwargs=dict(
                gate_value_heads=True,
                flash=True,
            ),
        ),
        text_num_embeds=fast_tokenizer_object.vocab_size,
        tokenizer=tokenizer,
    )
    e2tts.load_state_dict(e2tts_ckpt["model_state_dict"])

    duration_predictor.eval()
    e2tts.eval()
    e2tts.duration_predictor = duration_predictor
    return e2tts


# Lets configs/models.yaml reference `${load_model:<model_id>}` so that models
# are instantiated while the config is converted to plain Python objects.
OmegaConf.register_new_resolver("load_model", load_model)
models_config = OmegaConf.to_object(OmegaConf.load("configs/models.yaml"))


@spaces.GPU
def _do_tts(model_id, ipa, ref_wav, ref_transcript, speed):
    """Synthesize audio for ``ipa`` conditioned on a reference recording.

    Args:
        model_id: Key into ``models_config``.
        ipa: Parsed IPA string of the text to synthesize.
        ref_wav: Path to the reference audio file.
        ref_transcript: Transcript of the reference audio; it is prepended to
            ``ipa`` before tokenization.
        speed: Value forwarded to ``model.sample`` as ``speed``.

    Returns:
        The first generated sample as a NumPy array.
    """
    with torch.inference_mode():
        model = models_config[model_id]["model"].cuda()
        # Resample the reference audio to the model's sampling rate.
        ref_audio = librosa.load(ref_wav, sr=model.sampling_rate)[0]
        ipa = ipa + " "
        prompt = ref_transcript + ipa
        print(prompt)
        text = model.tokenizer([prompt]).to(model.device)
        # Keep the conditioning audio on the same device as the model and text.
        cond = torch.from_numpy(ref_audio).float().unsqueeze(0).to(model.device)
        generated = model.sample(
            cond=cond,
            text=text,
            steps=32,
            cfg_strength=1.0,
            speed=speed,
        )[0]
        return generated.cpu().numpy()


def text_to_speech(
    model_id: str,
    use_default_or_custom: str,
    speaker_name: str,
    dialect: str,
    speed: float,
    text: str,
    ref_wav: str,
    ref_transcript: str,
):
    """Convert ``text`` to IPA and synthesize it with the selected model.

    ``use_default_or_custom`` and ``speaker_name`` are accepted because the
    Gradio interface passes them, but they are not used here; the reference
    audio and transcript arrive through ``ref_wav`` and ``ref_transcript``.

    Returns:
        A tuple ``(words, pinyin, (sampling_rate, waveform))``.

    Raises:
        gr.Error: If ``text`` is empty or whitespace-only, or if some words
            cannot be converted to IPA.
    """
    # Reject whitespace-only input as well as the empty string.
    if not text or not text.strip():
        raise gr.Error("請勿輸入空字串。")

    words, ipa, pinyin, missing_words = get_ipa(text, dialect=dialect)
    if missing_words:
        raise gr.Error(
            f"句子中的[{','.join(missing_words)}]目前無法轉成 ipa。請嘗試其他句子。"
        )

    parsed_ipa = parse_ipa(ipa)
    # NOTE(review): a remap of dialect "nansixian" -> "sixian" used to sit here,
    # after get_ipa, where its result was never read. Confirm whether get_ipa
    # itself should receive the remapped dialect.

    wav = _do_tts(
        model_id,
        parsed_ipa,
        ref_wav,
        ref_transcript,
        speed,
    )
    return (
        words,
        pinyin,
        (
            models_config[model_id]["model"].sampling_rate,
            wav,
        ),
    )


def when_model_selected(model_id):
    """Refresh speaker, dialect and speaker-type widgets for a new model.

    Returns:
        Updates for the speaker dropdown, the dialect radio and the
        speaker-type radio, in that order.
    """
    model = models_config[model_id]
    speaker_names = list(model["speaker_mapping"])
    dialect_mapping = model["dialect_mapping"]
    return (
        gr.update(
            choices=speaker_names,
            value=speaker_names[0],
        ),
        gr.update(
            choices=list(dialect_mapping.items()),
            value=next(iter(dialect_mapping.values())),
            # Same rule as at construction time: only selectable when there is
            # more than one dialect to choose from.
            interactive=len(dialect_mapping) > 1,
        ),
        gr.update(
            value="預設語者",
        ),
    )


def when_default_speaker_selected(model_id, speaker_name):
    """Return updates carrying the chosen speaker's reference audio and transcript."""
    speaker = models_config[model_id]["speaker_mapping"][speaker_name]
    return (
        gr.update(value=speaker["ref_wav"]),
        gr.update(value=speaker["ref_transcript"]),
    )


def use_default_or_custom_radio_input(use_default_or_custom):
    """Toggle visibility between the custom-audio input and the speaker dropdown.

    Returns:
        Visibility updates for the custom speaker audio and the speaker
        dropdown, in that order.
    """
    is_custom = use_default_or_custom == "客製化語者"
    return gr.update(visible=is_custom), gr.update(visible=not is_custom)


demo = gr.Blocks(
    title="臺灣客語語音生成系統",
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    theme=gr.themes.Default(
        font=(
            "tauhu-oo",
            gr.themes.GoogleFont("Source Sans Pro"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        )
    ),
)

with demo:
    # The first configured model, and its first speaker, seed the initial UI.
    default_model_id = list(models_config.keys())[0]
    default_speaker_mapping = models_config[default_model_id]["speaker_mapping"]
    default_dialect_mapping = models_config[default_model_id]["dialect_mapping"]
    default_speaker_names = list(default_speaker_mapping)
    default_speaker = default_speaker_mapping[default_speaker_names[0]]

    model_drop_down = gr.Dropdown(
        list(models_config.keys()),
        value=default_model_id,
        label="模型",
    )

    use_default_or_custom_radio = gr.Radio(
        label="語者類型",
        choices=["預設語者", "客製化語者"],
        value="預設語者",
        visible=True,
        show_label=False,
        interactive=False,  # TODO
    )

    # Hidden holders for the selected speaker's reference audio and transcript.
    ref_wav = gr.Audio(
        visible=False,
        type="filepath",
        value=default_speaker["ref_wav"],
        waveform_options=gr.WaveformOptions(
            show_controls=False,
            sample_rate=24000,
        ),
    )
    ref_transcript = gr.Textbox(
        value=default_speaker["ref_transcript"],
        visible=False,
    )

    speaker_wav = gr.Audio(
        label="客製化語音",
        visible=False,
        editable=False,
        type="filepath",
        waveform_options=gr.WaveformOptions(
            show_controls=False,
            sample_rate=24000,
        ),
    )

    speaker_drop_down = gr.Dropdown(
        choices=default_speaker_names,
        value=default_speaker_names[0],
        label="語者",
        interactive=True,
        visible=True,
    )

    speaker_drop_down.change(
        when_default_speaker_selected,
        inputs=[model_drop_down, speaker_drop_down],
        outputs=[ref_wav, ref_transcript],
    )

    use_default_or_custom_radio.change(
        use_default_or_custom_radio_input,
        inputs=[use_default_or_custom_radio],
        outputs=[speaker_wav, speaker_drop_down],
    )

    dialect_radio = gr.Radio(
        choices=list(default_dialect_mapping.items()),
        value=next(iter(default_dialect_mapping.values())),
        label="腔調",
        interactive=len(default_dialect_mapping) > 1,
    )

    model_drop_down.input(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=[speaker_drop_down, dialect_radio, use_default_or_custom_radio],
    )

    input_text = gr.Textbox(
        label="輸入文字",
        value="",
    )
    speed = gr.Slider(maximum=1.5, minimum=0.5, value=1, label="語速(越大越慢)")

    gr.Markdown(
        """
        # 臺灣客語語音合成系統

        ### Taiwanese Hakka Text-to-Speech System

        ### 研發團隊
        - **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)([聯和科創](https://www.104.com.tw/company/1a2x6bmu75))**
        - **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)([聯和科創](https://www.104.com.tw/company/1a2x6bmu75))**

        ### 合作單位
        - **[國立聯合大學智慧客家實驗室](https://www.gohakka.org)**
        """
    )

    gr.Interface(
        text_to_speech,
        inputs=[
            model_drop_down,
            use_default_or_custom_radio,
            speaker_drop_down,
            dialect_radio,
            speed,
            input_text,
            ref_wav,
            ref_transcript,
        ],
        outputs=[
            gr.Textbox(interactive=False, label="斷詞"),
            gr.Textbox(interactive=False, label="客語拼音"),
            gr.Audio(interactive=False, label="合成語音", show_download_button=True),
        ],
        allow_flagging="auto",
    )

    gr.Examples(
        [
            [
                "預設語者",
                "XF",
                "sixian",
                "食飯愛正經食,正毋會食到半出半入",
            ],
            [
                "預設語者",
                "XF",
                "sixian",
                "歸條路吊等長長个花燈,祈求風調雨順,歸屋下人个心願,親像花燈下燒暖个光華",
            ],
            # The examples below are currently disabled.
            # [
            #     "預設語者",
            #     "戴君儒",
            #     "hailu",
            #     "男女平等个時代,平平做得受教育",
            # ],
            # [
            #     "預設語者",
            #     "宋涵葳",
            #     "dapu",
            #     "客家山城乜跈緊鬧熱䟘來咧",
            # ],
            # [
            #     "預設語者",
            #     "江芮敏",
            #     "raoping",
            #     "頭擺匱人,戴个毋係菅草屋,个創商品哦",
            # ],
            # [
            #     "預設語者",
            #     "洪藝晅",
            #     "zhaoan",
            #     "歇熱个時務,阿松歸屋下轉去在客莊个老屋",
            # ],
            # [
            #     "預設語者",
            #     "江芮敏",
            #     "nansixian",
            #     "在𠊎讀小學一年生个時節,阿爸輒常用自轉車載𠊎去學校讀書",
            # ],
        ],
        label="範例",
        inputs=[
            use_default_or_custom_radio,
            speaker_drop_down,
            dialect_radio,
            input_text,
        ],
    )

demo.launch()