File size: 7,415 Bytes
bf0a127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d5e4b2
 
 
 
 
 
 
 
 
 
 
 
 
bf0a127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b7b812
bf0a127
2b7b812
01381fa
bf0a127
 
 
 
 
 
 
 
 
 
 
 
 
55234f5
 
bf0a127
2b7b812
8052595
bf0a127
 
 
 
b6d7961
bf0a127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e659cb
 
 
 
 
 
 
 
 
 
 
 
33c803b
 
 
 
8052595
 
3dce183
f54871c
2b7b812
67c0171
3dce183
ed2ee73
 
 
f2850f7
7be2e7a
8052595
3dce183
3e659cb
 
bf0a127
55234f5
2b7b812
4b5c1de
bf0a127
 
4b5c1de
bf0a127
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import sys, os
# On macOS, enable PyTorch's MPS CPU-fallback switch (PYTORCH_ENABLE_MPS_FALLBACK).
# This is set before `import torch` below, presumably so torch sees it when it
# initializes -- keep this ordering.
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import logging

# Quiet noisy third-party loggers: only their WARNING and above are emitted.
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)

# Root logging for the script: INFO and above, formatted as "| name | LEVEL | message".
logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")

logger = logging.getLogger(__name__)

import torch
import argparse
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import cleaned_text_to_sequence, get_bert
from text.cleaner import clean_text
import gradio as gr
import webbrowser

import openai

# Uncomment to turn on the openai client's debug logging.
# openai.log = "debug"
# Send OpenAI API requests to the chatanywhere endpoint instead of the
# client's default base URL.
openai.api_base = "https://api.chatanywhere.com.cn/v1"


# Non-streaming request: waits for the complete reply before returning.
def gpt_35_api(gptkey, message):
    """Send one user message to gpt-3.5-turbo and return the reply text.

    Args:
        gptkey: API key as entered in the UI. The "sk-" prefix is optional:
            it is added only when missing, so pasting a full key also works.
            Surrounding whitespace is stripped.
        message: The question, sent as a single "user" message with no history.

    Returns:
        The ``message.content`` of the first choice in the completion.

    Raises:
        Whatever ``openai.ChatCompletion.create`` raises (auth, network, rate
        limit, ...); nothing is caught here.
    """
    api_key = (gptkey or "").strip()
    # The UI default key omits "sk-"; blindly prepending it turned a pasted
    # full key into "sk-sk-...", which the API rejects.
    if not api_key.startswith("sk-"):
        api_key = "sk-" + api_key
    # Keep the module-level key updated as before, but also pass the key with
    # the request itself so that, if requests run concurrently, one session
    # cannot end up using another session's key via the shared global.
    openai.api_key = api_key
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": message}],
        api_key=api_key,
    )
    return completion.choices[0].message.content


# Global synthesizer model read by infer(). It stays None until the __main__
# block builds SynthesizerTrn and loads the checkpoint into it.
net_g = None


def get_text(text, language_str, hps):
    """Turn raw text into the BERT features and id tensors used by ``net_g.infer``.

    Args:
        text: Raw input text.
        language_str: Language tag forwarded to ``clean_text``,
            ``cleaned_text_to_sequence`` and ``get_bert`` (``infer`` passes "ZH").
        hps: Hyper-parameters; ``hps.data.add_blank`` enables interleaving a
            blank id (0) into the phone/tone/language sequences.

    Returns:
        ``(bert, phone, tone, language)``: ``bert`` as returned by ``get_bert``,
        whose last dimension equals ``len(phone)``; the other three as
        ``torch.LongTensor``.

    Raises:
        ValueError: if the BERT feature length does not match the phone count.
    """
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        # Interleaving blanks doubles each word's phone span. The first word
        # also absorbs one extra position, presumably the leading blank, so
        # that the BERT features line up with len(phone) (checked below).
        word2ph = [count * 2 for count in word2ph]
        word2ph[0] += 1
    bert = get_bert(norm_text, word2ph, language_str)

    # An explicit check rather than `assert`, which `python -O` strips.
    if bert.shape[-1] != len(phone):
        raise ValueError(
            f"BERT feature length {bert.shape[-1]} does not match "
            f"phone sequence length {len(phone)}"
        )

    return (
        bert,
        torch.LongTensor(phone),
        torch.LongTensor(tone),
        torch.LongTensor(language),
    )

def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
    """Synthesize speech for ``text`` (processed as "ZH") with the global model.

    Reads the module globals ``net_g``, ``hps`` and ``device``, which are
    assigned in the ``__main__`` block.

    Args:
        text: Text to speak.
        sdp_ratio, noise_scale, noise_scale_w, length_scale: Forwarded
            unchanged to ``net_g.infer``.
        sid: Speaker name; looked up in ``hps.data.spk2id``.

    Returns:
        A numpy float array taken from ``net_g.infer(...)[0][0, 0]``
        (presumably batch 0, channel 0 of the generated audio).
    """
    # Log through the configured logger instead of bare print().
    logger.info("TTS input: %s", text)
    bert, phones, tones, lang_ids = get_text(text, "ZH", hps)
    with torch.no_grad():
        # Add a leading batch dimension of size 1 and move to the model device.
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        output = net_g.infer(
            x_tst,
            x_tst_lengths,
            speakers,
            tones,
            lang_ids,
            bert,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )
        return output[0][0, 0].data.cpu().float().numpy()

def tts_fn(text, key, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
    """Gradio click handler: ask GPT ``text`` with ``key``, then voice the answer.

    Returns:
        ``("Success", reply_text, (hps.data.sampling_rate, audio))``, matching
        the (status, GPT answer, audio) output components.
    """
    reply = gpt_35_api(key, text)
    synthesis_kwargs = dict(
        sdp_ratio=sdp_ratio,
        noise_scale=noise_scale,
        noise_scale_w=noise_scale_w,
        length_scale=length_scale,
        sid=speaker,
    )
    with torch.no_grad():
        waveform = infer(reply, **synthesis_kwargs)
    sample_rate = hps.data.sampling_rate
    return "Success", reply, (sample_rate, waveform)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", default="./logs/Diana/G_4800.pth", help="path of your model")
    parser.add_argument("--config_dir", default="./configs/config.json", help="path of your config file")
    # A bare `--share` turns sharing on; an explicit value such as
    # `--share true` / `--share false` is parsed as a boolean (previously any
    # value, even "False", was truthy, and the flag never reached launch()).
    parser.add_argument(
        "--share",
        nargs="?",
        const=True,
        default=False,
        type=lambda v: str(v).strip().lower() in ("1", "true", "yes", "y", "on"),
        help="make link public",
    )
    parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")

    args = parser.parse_args()
    if args.debug:
        logger.info("Enable DEBUG-LEVEL log")
        # logging.basicConfig() already ran at import time, so calling it
        # again is a no-op; lower the root logger's level directly instead.
        logging.getLogger().setLevel(logging.DEBUG)

    # hps, device and net_g are module globals read by infer()/tts_fn().
    hps = utils.get_hparams_from_file(args.config_dir)
    # Only CUDA or CPU is selected; "mps" is not used even on macOS.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model).to(device)
    _ = net_g.eval()

    _ = utils.load_checkpoint(args.model_dir, net_g, None, skip_optimizer=True)

    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                gr.Markdown(value="""
                    # AI嘉然——在线gpt对话版(Bert-Vits2 + gpt)\n
                    (注:转发生成可能较慢,请等待大约2分钟哦,目前只支持中文语言的答案,其他语言的支持在开发中捏!)\n
                    作者:[Xz乔希](https://space.bilibili.com/5859321) & [碎语碎念](https://space.bilibili.com/4269384) 声音归属:[嘉然今天吃什么](https://space.bilibili.com/672328094) \n
                    Bert-VITS2项目:https://github.com/Stardust-minus/Bert-VITS2\n
                    GPT_API_free项目:https://github.com/chatanywhere/GPT_API_free\n
                    本项目中的apiKey可以从https://github.com/chatanywhere/GPT_API_free\n
                    免费获取(本项目默认提供了一个,如果没法用了去仓库申请替换就好啦)!\n
                    使用本模型请严格遵守法律法规!\n
                    发布二创作品请标注本项目作者及链接、作品使用Bert-VITS2 AI生成!\n                
                    """)
            with gr.Column():
                gr.Markdown(value="""
                    ![avatar](https://img0.baidu.com/it/u=3803234914,1472318531&fm=253&fmt=auto&app=120&f=PNG?w=500&h=500) \n               
                    """)
        with gr.Row():
            with gr.Column():
                text = gr.TextArea(label="请输入要提问然然的问题", placeholder="Input Text Here",
                                   value="虚拟主播是什么?")
                # NOTE(review): this default key is committed to source and
                # pre-filled for every visitor; consider loading it from an
                # environment variable instead.
                key = gr.Text(label="GPT Key(过期请更换)", placeholder="请输入上面提示中获取的gpt key",
                              value="izlrijShDu7tp2rIgvYfibcC2J0Eh3uWfdm9ndrxN5nWrL96")
                # Avoid an IndexError at startup if the config lists no speakers.
                speaker = gr.Dropdown(choices=speakers, value=speakers[0] if speakers else None, label='说话人')
                sdp_ratio = gr.Slider(minimum=0.1, maximum=1, value=0.2, step=0.01, label='SDP/DP混合比')
                noise_scale = gr.Slider(minimum=0.1, maximum=1, value=0.5, step=0.01, label='感情调节')
                noise_scale_w = gr.Slider(minimum=0.1, maximum=1, value=0.9, step=0.01, label='音素长度')
                length_scale = gr.Slider(minimum=0.1, maximum=2, value=1, step=0.01, label='生成长度')
            with gr.Column():
                btn = gr.Button("点击生成", variant="primary")
                text_output = gr.Textbox(label="获取状态")
                gpt_output = gr.TextArea(label="然然老师的文字回答")
                audio_output = gr.Audio(label="然然老师的语音回答")
        btn.click(tts_fn,
                  inputs=[text, key, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale],
                  outputs=[text_output, gpt_output, audio_output])

    # webbrowser.open("http://127.0.0.1:6006")
    # app.launch(server_port=6006, show_error=True)

    # Honor --share (it was previously parsed but never used).
    app.launch(share=args.share, show_error=True)