# flake8: noqa: E402
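"""Gradio WebUI for the "AI星瞳" Bert-VITS2 chat demo.

The app forwards the user's question to gpt-3.5-turbo (via the chatanywhere
api_base configured below) and synthesizes the returned answer with the
Bert-VITS2 model configured in config.webui_config.
"""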
import os
import logging
import re_matching
from tools.sentence import split_by_language, sentence_split
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("markdown_it").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.basicConfig(
    level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
)
logger = logging.getLogger(__name__)
import torch
import ssl

# Use an unverified default HTTPS context so that the nltk resource download
# below does not fail on hosts without a usable certificate store.
ssl._create_default_https_context = ssl._create_unverified_context

import nltk

nltk.download("cmudict")
import utils
from infer import infer, latest_version, get_net_g
import gradio as gr
import webbrowser
import numpy as np
from config import config
net_g = None
import openai
# openai.log = "debug"
openai.api_base = "https://api.chatanywhere.com.cn/v1"
# Non-streaming response
def gpt_35_api(gptkey, message):
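    """Send one user message to gpt-3.5-turbo and return the reply text (non-streaming)."""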
    openai.api_key = "sk-" + gptkey
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": message}],
    )
    return completion.choices[0].message.content

device = config.webui_config.device
if device == "mps":
    # Let PyTorch fall back to CPU for operators not implemented on the MPS backend
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

def generate_audio(
    slices,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    speaker,
    language,
):
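    """Synthesize each text slice with `infer` and return a list of 16-bit audio pieces, each followed by half a second of silence."""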
    audio_list = []
    # half a second of silence, appended after every generated slice
    silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
    with torch.no_grad():
        for piece in slices:
            audio = infer(
                piece,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
                hps=hps,
                net_g=net_g,
                device=device,
            )
            audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
            audio_list.append(audio16bit)
            audio_list.append(silence)  # add the silence gap to the list
    return audio_list

def tts_split(
    text: str,
    key,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
    cut_by_sent,
    interval_between_para,
    interval_between_sent,
):
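    """Ask GPT for an answer, cut it into paragraphs (and optionally sentences), synthesize every piece, and return (status, answer text, (sample rate, audio))."""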
    if language == "mix":
        return ("invalid", None, None)
    while text.find("\n\n") != -1:
        text = text.replace("\n\n", "\n")
    transfer_text = gpt_35_api(key, text)
    para_list = re_matching.cut_para(transfer_text)
    audio_list = []
    if not cut_by_sent:
        for p in para_list:
            audio = infer(
                p,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
                hps=hps,
                net_g=net_g,
                device=device,
            )
            audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
            audio_list.append(audio16bit)
            # pause between paragraphs (a 44100 Hz sample rate is assumed here)
            silence = np.zeros(int(44100 * interval_between_para), dtype=np.int16)
            audio_list.append(silence)
    else:
        for p in para_list:
            audio_list_sent = []
            sent_list = re_matching.cut_sent(p)
            for s in sent_list:
                audio = infer(
                    s,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                    sid=speaker,
                    language=language,
                    hps=hps,
                    net_g=net_g,
                    device=device,
                )
                audio_list_sent.append(audio)
                silence = np.zeros(int(44100 * interval_between_sent))
                audio_list_sent.append(silence)
            if (interval_between_para - interval_between_sent) > 0:
                silence = np.zeros(
                    int(44100 * (interval_between_para - interval_between_sent))
                )
                audio_list_sent.append(silence)
            audio16bit = gr.processing_utils.convert_to_16_bit_wav(
                np.concatenate(audio_list_sent)
            )  # volume-normalize over the complete sentence group
            audio_list.append(audio16bit)
    audio_concat = np.concatenate(audio_list)
    return ("Success", transfer_text, (44100, audio_concat))

def tts_fn(
    text: str,
    key,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
):
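    """Ask GPT for an answer and synthesize it in one pass, handling the "mix" and "auto" language modes; returns (status, answer text, (sample rate, audio))."""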
    audio_list = []
    transfer_text = gpt_35_api(key, text)
    if language == "mix":
        bool_valid, str_valid = re_matching.validate_text(transfer_text)
        if not bool_valid:
            return str_valid, transfer_text, (
                hps.data.sampling_rate,
                np.zeros(hps.data.sampling_rate // 2),
            )
        result = re_matching.text_matching(transfer_text)
        for one in result:
            _speaker = one.pop()
            for lang, content in one:
                audio_list.extend(
                    generate_audio(
                        content.split("|"),
                        sdp_ratio,
                        noise_scale,
                        noise_scale_w,
                        length_scale,
                        _speaker,
                        lang,
                    )
                )
    elif language.lower() == "auto":
        sentences_list = split_by_language(
            transfer_text, target_languages=["zh", "ja", "en"]
        )
        for sentences, lang in sentences_list:
            lang = lang.upper()
            if lang == "JA":
                lang = "JP"
            sentences = sentence_split(sentences, max=250)
            for content in sentences:
                audio_list.extend(
                    generate_audio(
                        content.split("|"),
                        sdp_ratio,
                        noise_scale,
                        noise_scale_w,
                        length_scale,
                        speaker,
                        lang,
                    )
                )
    else:
        audio_list.extend(
            generate_audio(
                transfer_text.split("|"),
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                speaker,
                language,
            )
        )
    audio_concat = np.concatenate(audio_list)
    return "Success", transfer_text, (hps.data.sampling_rate, audio_concat)

if __name__ == "__main__":
    if config.webui_config.debug:
        logger.info("Enable DEBUG-LEVEL log")
        logging.basicConfig(level=logging.DEBUG)
    hps = utils.get_hparams_from_file(config.webui_config.config_path)
    # Default to the latest version if config.json does not specify one
    version = hps.version if hasattr(hps, "version") else latest_version
    net_g = get_net_g(
        model_path=config.webui_config.model, version=version, device=device, hps=hps
    )
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    languages = ["ZH", "JP", "EN", "auto", "mix"]
    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                gr.Markdown(value="""
# AI星瞳——gpt对话版(Bert-Vits2 2.0中日英)\n
作者:[Xz乔希](https://space.bilibili.com/5859321) & [碎语碎念](https://space.bilibili.com/4269384) 声音归属:[星瞳_Official](https://space.bilibili.com/401315430) \n
Bert-VITS2项目:https://github.com/Stardust-minus/Bert-VITS2\n
GPT_API_free项目:https://github.com/chatanywhere/GPT_API_free\n
本项目中的apiKey可以从https://github.com/chatanywhere/GPT_API_free\n
免费获取(本项目默认提供了一个,如果没法用了去仓库申请替换就好啦)!\n
使用本模型请严格遵守法律法规!\n
发布二创作品请标注本项目作者及链接、作品使用Bert-VITS2 AI生成!\n
""")
            with gr.Column():
                gr.Markdown(value="""
![avatar](https://img1.baidu.com/it/u=381691319,2894195285&fm=253&fmt=auto&app=138&f=JPEG?w=400&h=300)\n
""")
        with gr.Row():
            with gr.Column():
                text = gr.TextArea(
                    label="请输入要向星瞳老师提问的问题",
                    placeholder="""
虚拟主播是什么?
""",
                    value="虚拟主播是什么?",
                )
                key = gr.Text(
                    label="GPT Key",
                    placeholder="请输入上面提示中获取的gpt key",
                    value="izlrijShDu7tp2rIgvYfibcC2J0Eh3uWfdm9ndrxN5nWrL96",
                )
                speaker = gr.Dropdown(
                    choices=speakers, value=speakers[0], label="选择说话人"
                )
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.2, step=0.01, label="SDP/DP混合比"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.5, step=0.01, label="感情"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.9, step=0.01, label="音素长度"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1.0, step=0.01, label="语速"
                )
                language = gr.Dropdown(
                    choices=languages, value=languages[3], label="选择语言"
                )
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        btn = gr.Button("点击生成", variant="primary")
                        interval_between_sent = gr.Slider(
                            minimum=0,
                            maximum=5,
                            value=0.2,
                            step=0.1,
                            label="句间停顿(秒),勾选按句切分才生效",
                        )
                        interval_between_para = gr.Slider(
                            minimum=0,
                            maximum=10,
                            value=1,
                            step=0.1,
                            label="段间停顿(秒),需要大于句间停顿才有效",
                        )
                        opt_cut_by_sent = gr.Checkbox(
                            label="按句切分 在按段落切分的基础上再按句子切分文本"
                        )
                        slicer = gr.Button("切分生成", variant="primary")
                text_output = gr.Textbox(label="状态信息")
                gpt_output = gr.TextArea(label="星瞳老师的文字答案")
                audio_output = gr.Audio(label="星瞳老师的语音答案")
                # explain_image = gr.Image(
                #     label="参数解释信息",
                #     show_label=True,
                #     show_share_button=False,
                #     show_download_button=False,
                #     value=os.path.abspath("./img/参数说明.png"),
                # )
        btn.click(
            tts_fn,
            inputs=[
                text,
                key,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
            ],
            outputs=[text_output, gpt_output, audio_output],
        )
        slicer.click(
            tts_split,
            inputs=[
                text,
                key,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
                opt_cut_by_sent,
                interval_between_para,
                interval_between_sent,
            ],
            outputs=[text_output, gpt_output, audio_output],
        )

    print("推理页面已开启!")
    webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
    app.launch(share=config.webui_config.share, server_port=config.webui_config.port)