File size: 4,869 Bytes
af3d42a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from infer import OnnxInferenceSession
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text
import numpy as np
from huggingface_hub import hf_hub_download
import asyncio
from pathlib import Path

# Process-wide inference session. Stays None until get_onnx_session() builds
# it on first use, after which the same object is handed back to every caller.
OnnxSession = None

# Download manifest consumed by download_models(). Each entry gives the
# Hugging Face Hub repository ("repo_id"), the files required from it
# ("files"), and the local directory they are stored under ("local_path").
models = [
    # NOTE(review): presumably read by text.get_bert for Cantonese input —
    # the consumer is not visible in this file.
    {
        "local_path": "./bert/bert-large-cantonese",
        "repo_id": "hon9kon9ize/bert-large-cantonese",
        "files": [
            "pytorch_model.bin"
        ]
    },
    # NOTE(review): presumably read by text.get_bert for English input —
    # the consumer is not visible in this file.
    {
        "local_path": "./bert/deberta-v3-large",
        "repo_id": "microsoft/deberta-v3-large",
        "files": [
            "spm.model",
            "pytorch_model.bin"
        ]
    },
    # ONNX graphs; the six .onnx paths match the ones opened by
    # get_onnx_session().
    {
        "local_path": "./onnx",
        "repo_id": "hon9kon9ize/bert-vits-zoengjyutgaai-onnx",
        "files": [
            "BertVits2.2PT.json",
            "BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
            "BertVits2.2PT/BertVits2.2PT_emb.onnx",
            "BertVits2.2PT/BertVits2.2PT_dp.onnx",
            "BertVits2.2PT/BertVits2.2PT_sdp.onnx",
            "BertVits2.2PT/BertVits2.2PT_flow.onnx",
            "BertVits2.2PT/BertVits2.2PT_dec.onnx"
        ]
    }
]

def get_onnx_session():
    """Return the shared inference session, creating it on the first call.

    The session is kept in the module-level ``OnnxSession`` variable, so the
    ONNX files are only opened once per process; later calls return the
    cached object unchanged.

    Returns:
        The ``OnnxInferenceSession`` instance stored in ``OnnxSession``.
    """
    global OnnxSession

    if OnnxSession is None:
        # One ONNX graph per model component, keyed by the name the
        # inference session expects.
        component_paths = {
            "enc": "onnx/BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
            "emb_g": "onnx/BertVits2.2PT/BertVits2.2PT_emb.onnx",
            "dp": "onnx/BertVits2.2PT/BertVits2.2PT_dp.onnx",
            "sdp": "onnx/BertVits2.2PT/BertVits2.2PT_sdp.onnx",
            "flow": "onnx/BertVits2.2PT/BertVits2.2PT_flow.onnx",
            "dec": "onnx/BertVits2.2PT/BertVits2.2PT_dec.onnx",
        }
        OnnxSession = OnnxInferenceSession(
            component_paths,
            Providers=["CPUExecutionProvider"],
        )

    return OnnxSession

def download_model_files(repo_id, files, local_path):
    """Download the listed repository files that are missing locally.

    Args:
        repo_id: Hugging Face Hub repository identifier.
        files: Iterable of file paths inside the repository.
        local_path: Directory the files are stored under; a file is
            considered present when ``local_path/<file>`` exists.

    Returns:
        None. Files already on disk are left untouched.
    """
    for filename in files:
        if Path(local_path, filename).exists():
            continue  # already downloaded, nothing to do
        hf_hub_download(
            repo_id,
            filename,
            local_dir=local_path,
            local_dir_use_symlinks=False,
        )

def download_models():
    """Fetch every file listed in the module-level ``models`` manifest.

    Each manifest entry is handed to ``download_model_files``, which skips
    files that already exist on disk.
    """
    for entry in models:
        repo_id = entry["repo_id"]
        wanted_files = entry["files"]
        destination = entry["local_path"]
        download_model_files(repo_id, wanted_files, destination)

def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements.

    The result has ``2 * len(lst) + 1`` entries: ``item`` sits at every even
    index and the elements of ``lst`` occupy the odd indices in order.
    """
    padded = [item for _ in range(2 * len(lst) + 1)]
    for position, element in enumerate(lst):
        padded[2 * position + 1] = element
    return padded

def get_text(text, language_str, style_text=None, style_weight=0.7):
    """Convert raw text into the arrays passed to the inference session.

    Args:
        text: Input text to synthesize.
        language_str: ``"EN"`` or ``"YUE"``; selects which BERT slot receives
            the real features.
        style_text: Optional style prompt forwarded to ``get_bert``; an empty
            string is treated as ``None``.
        style_weight: Style weight forwarded to ``get_bert``.

    Returns:
        Tuple ``(en_bert, yue_bert, phone, tone, language)`` of NumPy arrays.
        The BERT arrays are returned transposed; the slot for the language
        that is not in use is filled with random normal values generated
        with shape ``(1024, len(phone))`` before the transpose.

    Raises:
        ValueError: If ``language_str`` is neither ``"EN"`` nor ``"YUE"``.
        AssertionError: If the BERT feature length does not match the
            blank-padded phone sequence.
    """
    if style_text == "":
        style_text = None
    # Implement the current version of get_text here.
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    # Add blanks: pad every sequence with id 0 around each symbol.
    phone, tone, language = (
        intersperse(sequence, 0) for sequence in (phone, tone, language)
    )
    # Padding doubles the phones per word; the leading blank is charged to
    # the first word so the counts still sum to len(phone).
    for index, count in enumerate(word2ph):
        word2ph[index] = count * 2
    word2ph[0] += 1

    bert_ori = get_bert(
        norm_text, word2ph, language_str, "cpu", style_text, style_weight
    )
    del word2ph
    assert bert_ori.shape[-1] == len(phone), phone

    if language_str == "EN":
        en_bert, yue_bert = bert_ori, np.random.randn(1024, len(phone))
    elif language_str == "YUE":
        en_bert, yue_bert = np.random.randn(1024, len(phone)), bert_ori
    else:
        raise ValueError("language_str should be EN or YUE")

    assert yue_bert.shape[-1] == len(
        phone
    ), f"Bert seq len {yue_bert.shape[-1]} != {len(phone)}"

    phone_ids = np.asarray(phone)
    tone_ids = np.asarray(tone)
    language_ids = np.asarray(language)
    en_features = np.asarray(en_bert.T)
    yue_features = np.asarray(yue_bert.T)

    return en_features, yue_features, phone_ids, tone_ids, language_ids

# Text-to-speech function
async def text_to_speech(text, sid=0, language="YUE"):
    """Synthesize speech for ``text`` with the shared ONNX session.

    Args:
        text: Text to convert to speech.
        sid: Speaker id; passed to the session as a one-element array.
        language: Language code forwarded to ``get_text`` (``"EN"`` or
            ``"YUE"``).

    Returns:
        ``audio[0][0]`` of the session output (presumably the waveform of
        the first item — confirm against ``OnnxInferenceSession``), or
        ``None`` when ``text`` is empty or only whitespace.
    """
    # Validate before touching the model so a blank request never triggers
    # session creation, and return a single value on every path instead of
    # a (None, warning) tuple that callers would mistake for audio data.
    if not text or not text.strip():
        gr.Warning("Please enter text to convert.")
        return None

    session = get_onnx_session()
    # Keep the language code and the per-phone language ids in separate
    # names rather than overwriting the parameter.
    en_bert, yue_bert, phone, tone, language_ids = get_text(text, language)
    speaker = np.array([sid])
    audio = session(phone, tone, language_ids, en_bert, yue_bert, speaker)

    return audio[0][0]


# Create Gradio application
import gradio as gr

# Gradio interface function
def tts_interface(text):
    """Gradio callback: synthesize Cantonese speech for the textbox content.

    Args:
        text: Content of the input textbox.

    Returns:
        ``(44100, audio)`` for the audio output component, or ``None`` when
        there is nothing to play. Blank input shows a warning and returns
        ``None``.
    """
    # Blank input must not reach the audio component as a (rate, data) pair
    # with non-audio data; warn the user and leave the output empty.
    if not text or not text.strip():
        gr.Warning("Please enter text to convert.")
        return None

    audio = asyncio.run(text_to_speech(text, 0, "YUE"))
    if audio is None:
        return None
    return 44100, audio

async def create_demo():
    """Build the Gradio interface for the TTS demo.

    Returns:
        A ``gr.Interface`` with one multi-line textbox input and one audio
        output, wired to ``tts_interface``. Analytics and flagging are
        disabled.
    """
    description = """廣東話語音生成器,基於Bert-VITS2模型

注意:model 本身支持廣東話同英文,但呢個 space 未實現中英夾雜生成。
"""

    demo = gr.Interface(
        fn=tts_interface,
        inputs=[
            gr.Textbox(label="Input Text", lines=5),
        ],
        outputs=[
            gr.Audio(label="Generated Audio"),
        ],
        title="Cantonese TTS Text-to-Speech",
        description=description,
        analytics_enabled=False,
        # Interface expects the string form; the boolean False is a
        # deprecated spelling that Gradio maps to "never" with a warning.
        allow_flagging="never",
    )
    return demo


# Run the application
if __name__ == "__main__":
    # Make sure every file in the models manifest is on disk before the UI
    # starts.
    download_models()

    # create_demo is a coroutine function, so it is driven to completion with
    # asyncio.run to obtain the interface object before launching it.
    demo = asyncio.run(create_demo())
    demo.launch()