File size: 10,684 Bytes
8a1292d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2355aad
 
 
 
 
 
 
 
 
 
 
 
031b408
2355aad
 
 
 
 
 
83fc5d0
 
bd158ce
83fc5d0
2355aad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75b94c9
2355aad
 
 
 
 
 
 
 
 
 
8a1292d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb57f70
97813ba
bb57f70
 
8a1292d
 
 
80e17e6
8a1292d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e5e066
3fc5ebe
5a2b8b8
58a7186
bb57f70
5a2b8b8
2355aad
 
 
 
2e5e066
2355aad
 
2e5e066
2355aad
2e5e066
2355aad
 
 
75b94c9
2355aad
 
 
 
 
75b94c9
2355aad
 
bc670cc
2355aad
 
8a1292d
 
75b94c9
 
5a2b8b8
 
 
d546fbc
5a2b8b8
 
8a1292d
 
2e5e066
 
8a1292d
 
 
 
5a2b8b8
d2565d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import sys, os

# macOS: allow PyTorch to run operators the MPS backend lacks on the CPU.
# Set here, before torch is imported further down.
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import logging

# Third-party libraries whose INFO/DEBUG chatter is suppressed.
for _quiet_name in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_quiet_name).setLevel(logging.WARNING)
del _quiet_name

# Root handler: INFO and above, formatted as "| name | LEVEL | message".
logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")

logger = logging.getLogger(__name__)

import torch
import argparse
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text
import gradio as gr
import webbrowser

# ChatGLM2

from transformers import AutoModel, AutoTokenizer, AutoConfig
import gradio as gr
import mdtex2html
import torch
import os

# P-Tuning v2 checkpoint directory whose prefix-encoder weights are applied on
# top of the base "chatglm2-6b" model.
CHECKPOINT_PATH = 'checkpoint-600'
# Key prefix of the prefix-encoder parameters inside the checkpoint state dict.
_PREFIX_ENCODER_KEY = "transformer.prefix_encoder."

tokenizer = AutoTokenizer.from_pretrained("chatglm2-6b", trust_remote_code=True)
# pre_seq_len presumably has to match the value the checkpoint was trained with
# (the prefix-encoder weights are loaded with strict key/shape checking below)
# — confirm before changing it.
config = AutoConfig.from_pretrained("chatglm2-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("chatglm2-6b", config=config, trust_remote_code=True)

# Load the checkpoint onto the CPU and keep only the prefix-encoder entries,
# with their keys made relative to the prefix_encoder submodule.
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"), map_location=torch.device('cpu'))
new_prefix_state_dict = {
    key[len(_PREFIX_ENCODER_KEY):]: value
    for key, value in prefix_state_dict.items()
    if key.startswith(_PREFIX_ENCODER_KEY)
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

# Keep the whole model in float32 (there is no .cuda() call here, so it stays
# where from_pretrained placed it). Casting to half and back first would only
# round every weight to fp16 precision before widening it again.
# GPU alternative: model = model.half().cuda()
model = model.float()
# Redundant after model.float(), but keeps the prefix encoder in float32 if the
# line above is switched to the half-precision GPU variant.
model.transformer.prefix_encoder.float()
model = model.eval()


"""Override Chatbot.postprocess"""


def postprocess(self, y):
    """Render each chat turn in *y* to HTML with mdtex2html, in place.

    Installed as the replacement for ``gr.Chatbot.postprocess``; ``self`` is
    the Chatbot instance and is not used.

    Args:
        y: List of ``(message, response)`` pairs, or None.

    Returns:
        ``[]`` when *y* is None; otherwise the same list object *y*, each
        entry replaced by a tuple whose non-None members were converted with
        ``mdtex2html.convert`` (None members stay None).
    """
    if y is None:
        return []

    def _render(part):
        # None marks a missing side of the turn and is passed through as-is.
        return None if part is None else mdtex2html.convert(part)

    for idx, pair in enumerate(y):
        user_msg, bot_msg = pair
        y[idx] = (_render(user_msg), _render(bot_msg))
    return y


# Monkey-patch Gradio so every Chatbot renders its turns through postprocess().
setattr(gr.Chatbot, "postprocess", postprocess)


# Per-character escapes applied to lines inside a fenced code block, in a
# single str.translate pass.
# - "&", "<" and ">" become entities so code is displayed literally. "&" is
#   included so entity-like code text such as "&lt;" is shown verbatim
#   instead of being decoded by the browser.
# - Spaces become &nbsp; so indentation survives.
# - The remaining characters are replaced by backslash/numeric escapes,
#   presumably so the Markdown/TeX renderer (mdtex2html) leaves them alone.
_CODE_ESCAPES = str.maketrans({
    "&": "&amp;",
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Convert chat text into the HTML fragment displayed in the Chatbot.

    Adapted from https://github.com/GaiZhenbiao/ChuanhuChatGPT/.

    - Empty lines are dropped.
    - The remaining lines are concatenated, and every line after the first
      is prefixed with ``<br>``.
    - A line containing a triple-backtick fence alternately opens a block and
      closes it. An opening fence becomes
      ``<pre><code class="language-X">``, where X is the text after the
      line's last backtick. A closing fence becomes ``<br></code></pre>``.
    - Lines inside an open block are escaped with ``_CODE_ESCAPES``.

    Args:
        text: Raw user input or model response.

    Returns:
        The converted string ("" for empty input).
    """
    lines = [line for line in text.split("\n") if line != ""]
    fence_count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                language = line.split("`")[-1]
                lines[i] = f'<pre><code class="language-{language}">'
            else:
                lines[i] = "<br></code></pre>"
        elif i > 0:
            if fence_count % 2 == 1:
                line = line.translate(_CODE_ESCAPES)
            lines[i] = "<br>" + line
    return "".join(lines)


def predict(input, chatbot, max_length, top_p, temperature, history, past_key_values):
    """Stream a ChatGLM2 reply to *input* into the chat UI.

    Appends a new ``(user_html, "")`` entry to *chatbot*. For every value
    yielded by ``model.stream_chat``, it then replaces that entry with the
    current response and yields the UI outputs.

    Args:
        input: The user's message. The name is kept for the Gradio wiring;
            it shadows the builtin only inside this function.
        chatbot: List of (user_html, bot_html) pairs shown by the Chatbot.
        max_length: Forwarded unchanged to ``model.stream_chat``.
        top_p: Forwarded unchanged to ``model.stream_chat``.
        temperature: Forwarded unchanged to ``model.stream_chat``.
        history: Conversation history passed to ``model.stream_chat``; the
            updated value it yields is passed back out.
        past_key_values: Value yielded by the previous ``stream_chat`` call
            (None at the start of a conversation).

    Yields:
        ``(chatbot, history, past_key_values, response)``. *response* is the
        unconverted text from ``stream_chat``, presumably the reply generated
        so far.
    """
    # The user's message does not change while streaming; convert it once.
    user_html = parse_text(input)
    chatbot.append((user_html, ""))
    for response, history, past_key_values in model.stream_chat(
        tokenizer,
        input,
        history,
        past_key_values=past_key_values,
        return_past_key_values=True,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
    ):
        chatbot[-1] = (user_html, parse_text(response))
        yield chatbot, history, past_key_values, response


def reset_user_input():
    """Return a Gradio update that empties the user's input textbox."""
    cleared_textbox = gr.update(value='')
    return cleared_textbox


def reset_state():
    """Return fresh values for the (chatbot, history, past_key_values) outputs.

    New empty lists are created on every call, so successive resets never
    share state.
    """
    empty_chat = []
    empty_history = []
    no_cached_state = None
    return empty_chat, empty_history, no_cached_state

# Bert-VITS2

# Bert-VITS2 synthesizer (a SynthesizerTrn). It stays None until the
# __main__ block builds it and loads the checkpoint; infer() reads it.
net_g = None


def get_text(text, language_str, hps):
    """Turn raw *text* into the Bert-VITS2 model inputs.

    The text is cleaned with ``clean_text`` and mapped to id sequences with
    ``cleaned_text_to_sequence``. When ``hps.data.add_blank`` is set, a 0 is
    interleaved into each sequence and the per-word phone counts
    (``word2ph``) are adjusted accordingly before BERT features are computed.

    Returns:
        ``(bert, phone, tone, language)``: the features from ``get_bert``
        followed by three ``torch.LongTensor`` id sequences.

    Raises:
        AssertionError: if ``bert.shape[-1]`` differs from the number of
            phones.
    """
    norm_text, phones, tones, word2ph = clean_text(text, language_str)
    phones, tones, lang_ids = cleaned_text_to_sequence(phones, tones, language_str)

    if hps.data.add_blank:
        phones, tones, lang_ids = (
            commons.intersperse(seq, 0) for seq in (phones, tones, lang_ids)
        )
        # Each word now spans twice as many positions. The extra +1 on the
        # first word presumably accounts for a leading blank inserted by
        # commons.intersperse — confirm against that helper.
        for idx, count in enumerate(word2ph):
            word2ph[idx] = count * 2
        word2ph[0] += 1

    bert = get_bert(norm_text, word2ph, language_str)

    assert bert.shape[-1] == len(phones)

    return (
        bert,
        torch.LongTensor(phones),
        torch.LongTensor(tones),
        torch.LongTensor(lang_ids),
    )

def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language="ZH"):
    """Synthesize speech for *text* with the loaded Bert-VITS2 model.

    Relies on the module-level ``net_g``, ``hps`` and ``device``, which are
    set up in the ``__main__`` block. ``net_g`` is only read here, so no
    ``global`` declaration is needed.

    Args:
        text: Text to speak.
        sdp_ratio: Forwarded unchanged to ``net_g.infer``.
        noise_scale: Forwarded unchanged to ``net_g.infer``.
        noise_scale_w: Forwarded unchanged to ``net_g.infer``.
        length_scale: Forwarded unchanged to ``net_g.infer``.
        sid: Speaker name, mapped to an id through ``hps.data.spk2id``.
        language: Language code passed to ``get_text`` (default ``"ZH"``).

    Returns:
        ``[0][0, 0]`` of ``net_g.infer``'s output as a float32 NumPy array on
        the CPU (presumably the waveform samples).
    """
    bert, phones, tones, lang_ids = get_text(text, language, hps)
    with torch.no_grad():
        # Add a batch dimension of 1 and move every input to the model device.
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = net_g.infer(
            x_tst, x_tst_lengths, speakers, tones, lang_ids, bert,
            sdp_ratio=sdp_ratio, noise_scale=noise_scale,
            noise_scale_w=noise_scale_w, length_scale=length_scale,
        )[0][0, 0].data.cpu().float().numpy()
        # Drop the device tensors eagerly (presumably to release accelerator
        # memory before returning).
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
        return audio

def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
    """Gradio callback: synthesize *text* with the selected *speaker*.

    Returns:
        ``("Success", (sampling_rate, audio))`` for the status textbox and the
        Audio component. ``sampling_rate`` is ``hps.data.sampling_rate``.
    """
    synth_options = {
        "sdp_ratio": sdp_ratio,
        "noise_scale": noise_scale,
        "noise_scale_w": noise_scale_w,
        "length_scale": length_scale,
        "sid": speaker,
    }
    with torch.no_grad():
        audio = infer(text, **synth_options)
    status = "Success"
    return status, (hps.data.sampling_rate, audio)

# Banner HTML rendered at the top of the page: a centered, clickable image
# linking to the talktalkai site. Attributes are whitespace-separated, as
# HTML requires (no comma between src and alt).
image_markdown = """
<h1 align="center"><a href="http://www.talktalkai.com"><img src="https://media.9game.cn/gamebase/2021/7/23/227829877.jpg" alt="talktalkai" border="0" style="margin: 0 auto; height: 200px;" /></a> </h1>
"""


def _parse_bool_flag(value):
    """Interpret a command-line value such as "true"/"1"/"yes" as a bool."""
    return str(value).strip().lower() in ("1", "true", "yes", "y", "on")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", default="./logs/OUTPUT_MODEL/G_13900.pth", help="path of your model")
    parser.add_argument("--config_dir", default="./configs/config.json", help="path of your config file")
    # Accepts a bare `--share` as well as the explicit `--share True` / `--share False` form.
    parser.add_argument("--share", nargs="?", const=True, default=False, type=_parse_bool_flag, help="make link public")
    parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")

    args = parser.parse_args()
    if args.debug:
        logger.info("Enable DEBUG-LEVEL log")
        # logging.basicConfig() already ran at import time, so calling it
        # again would do nothing; lower the root logger's level directly.
        logging.getLogger().setLevel(logging.DEBUG)
    # hps and device are module globals on purpose: infer() and tts_fn() read them.
    hps = utils.get_hparams_from_file(args.config_dir)
    # CUDA when available, otherwise CPU. An "mps" option (macOS with
    # torch.backends.mps.is_available()) is intentionally not selected here.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model).to(device)
    _ = net_g.eval()

    _ = utils.load_checkpoint(args.model_dir, net_g, None, skip_optimizer=True)

    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    with gr.Blocks() as app:
        gr.Markdown("# <center>🌊💕🎶 ChatGLM2 神里绫华 + Bert-VITS2</center>")
        gr.Markdown("## <center>🌟 - 和绫华 畅所欲言吧:稻妻神里流太刀术皆传,神里绫华,参上! </center>")
        gr.Markdown("### <center>🍻 - 更多精彩应用,尽在[滔滔AI](http://www.talktalkai.com);滔滔AI,为爱滔滔!💕</center>")
        with gr.Accordion("绫华", open=True):
            gr.Markdown(image_markdown)

        # --- Chat section (ChatGLM2) ---
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(show_label=False, placeholder="和绫华一起叙叙旧吧...", lines=8).style(
                        container=False)
                with gr.Column(min_width=32, scale=1):
                    submitBtn = gr.Button("开始对话吧!", variant="primary")
            with gr.Column(scale=1):
                emptyBtn = gr.Button("清空所有聊天记录")
                max_length = gr.Slider(0, 32768, value=8192, step=1.0, label="Maximum length", interactive=True)
                top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
                temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
                # Hidden textbox holding the latest raw reply; it feeds the TTS button below.
                response_lh = gr.Textbox(label="神里绫华的回答", visible=False)

        history = gr.State([])
        past_key_values = gr.State(None)

        submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history, past_key_values],
                        [chatbot, history, past_key_values, response_lh], show_progress=True)
        submitBtn.click(reset_user_input, [], [user_input])

        emptyBtn.click(reset_state, outputs=[chatbot, history, past_key_values], show_progress=True)

        # --- Speech section (Bert-VITS2) ---
        with gr.Row():
            with gr.Column():
                text = response_lh
                speaker = gr.Dropdown(choices=speakers, value=speakers[0], label='Speaker', visible=False)
                with gr.Row():
                    sdp_ratio = gr.Slider(minimum=0, maximum=1, value=0.2, step=0.1, label='语调变化')
                    noise_scale = gr.Slider(minimum=0.1, maximum=1.5, value=0.6, step=0.1, label='感情变化')
                with gr.Row():
                    noise_scale_w = gr.Slider(minimum=0.1, maximum=1.4, value=0.8, step=0.1, label='音节发音长度变化')
                    length_scale = gr.Slider(minimum=0.1, maximum=2, value=1, step=0.1, label='语速 (数值越小,语速越快)')
                btn = gr.Button("开启AI语音之旅吧!", variant="primary")
            with gr.Column():
                text_output = gr.Textbox(label="Message", visible=False)
                audio_output = gr.Audio(label="神里绫华发来的语音", autoplay=True)

        btn.click(tts_fn,
                  inputs=[text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale],
                  outputs=[text_output, audio_output])

    # predict() is a generator (streamed replies), which Gradio only serves
    # when the request queue is enabled. --share is now actually honored.
    app.queue().launch(show_error=True, share=args.share)