Spaces:
Running
Running
# flake8: noqa: E402 | |
import os | |
import logging | |
import re_matching | |
from tools.sentence import split_by_language | |
logging.getLogger("numba").setLevel(logging.WARNING) | |
logging.getLogger("markdown_it").setLevel(logging.WARNING) | |
logging.getLogger("urllib3").setLevel(logging.WARNING) | |
logging.getLogger("matplotlib").setLevel(logging.WARNING) | |
logging.basicConfig( | |
level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s" | |
) | |
logger = logging.getLogger(__name__) | |
import torch | |
import utils | |
from infer import infer, latest_version, get_net_g, infer_multilang | |
import gradio as gr | |
import webbrowser | |
import numpy as np | |
from config import config | |
from tools.translate import translate | |
# Generator network used by the inference helpers in this module. It stays
# None until the __main__ block loads a checkpoint through get_net_g().
net_g = None
# Inference device name, read from the web-UI section of the project config.
device = config.webui_config.device
if device == "mps":
    # PyTorch environment switch that permits falling back to the CPU for
    # operators the MPS backend does not implement.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
def generate_audio(
    slices,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    speaker,
    language,
    skip_start=False,
    skip_end=False,
):
    """Synthesize each text piece in ``slices`` with one shared language.

    Every piece is handed to ``infer`` inside ``torch.no_grad()`` together
    with the module-level ``hps``, ``net_g`` and ``device``; the result is
    passed through gradio's ``convert_to_16_bit_wav``.

    Args:
        slices: Sized sequence of text pieces, synthesized in order.
        sdp_ratio, noise_scale, noise_scale_w, length_scale: Forwarded to
            ``infer`` unchanged.
        speaker: Forwarded to ``infer`` as ``sid``.
        language: Language tag applied to every piece.
        skip_start: Initial value of the flag forwarded to ``infer``; see the
            note inside the loop for how it evolves per piece.
        skip_end: Same as ``skip_start`` but for the end of a piece.

    Returns:
        A list with one converted waveform per piece, in input order.
    """
    waveforms = []
    with torch.no_grad():
        for position, piece in enumerate(slices):
            # NOTE(review): both flags are rebound on every pass, so a False
            # assigned for an earlier piece carries over to later ones. As a
            # result skip_start is False for every piece, while skip_end keeps
            # the caller's value until the last piece, where it becomes False.
            # Confirm whether independent per-piece values were intended.
            skip_start = skip_start if position != 0 else False
            skip_end = skip_end if position != len(slices) - 1 else False
            raw_audio = infer(
                piece,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
                hps=hps,
                net_g=net_g,
                device=device,
                skip_start=skip_start,
                skip_end=skip_end,
            )
            waveforms.append(gr.processing_utils.convert_to_16_bit_wav(raw_audio))
    return waveforms
def generate_audio_multilang(
    slices,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    speaker,
    language,
    skip_start=False,
    skip_end=False,
):
    """Synthesize each piece in ``slices`` with its own language entry.

    Every piece is handed to ``infer_multilang`` inside ``torch.no_grad()``
    together with the module-level ``hps``, ``net_g`` and ``device``; the
    result is passed through gradio's ``convert_to_16_bit_wav``.

    Args:
        slices: Sized sequence of pieces, synthesized in order.
        sdp_ratio, noise_scale, noise_scale_w, length_scale: Forwarded to
            ``infer_multilang`` unchanged.
        speaker: Forwarded to ``infer_multilang`` as ``sid``.
        language: Sequence indexed by piece position; ``language[i]`` is
            forwarded together with ``slices[i]``.
        skip_start: Initial value of the flag forwarded to
            ``infer_multilang``; see the note inside the loop.
        skip_end: Same as ``skip_start`` but for the end of a piece.

    Returns:
        A list with one converted waveform per piece, in input order.
    """
    waveforms = []
    with torch.no_grad():
        for position, piece in enumerate(slices):
            # NOTE(review): both flags are rebound on every pass, so a False
            # assigned for an earlier piece carries over to later ones. As a
            # result skip_start is False for every piece, while skip_end keeps
            # the caller's value until the last piece, where it becomes False.
            # Confirm whether independent per-piece values were intended.
            skip_start = skip_start if position != 0 else False
            skip_end = skip_end if position != len(slices) - 1 else False
            raw_audio = infer_multilang(
                piece,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language[position],
                hps=hps,
                net_g=net_g,
                device=device,
                skip_start=skip_start,
                skip_end=skip_end,
            )
            waveforms.append(gr.processing_utils.convert_to_16_bit_wav(raw_audio))
    return waveforms
def tts_split(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
    cut_by_sent,
    interval_between_para,
    interval_between_sent,
):
    """Synthesize ``text`` paragraph by paragraph, inserting pauses between parts.

    Runs of blank lines are collapsed, the text is cut into paragraphs with
    ``re_matching.cut_para`` and, when ``cut_by_sent`` is set, each paragraph
    is cut again with ``re_matching.cut_sent``. Each part is synthesized with
    ``infer`` using the module-level ``hps``, ``net_g`` and ``device``.

    Args:
        text: Input text; paragraphs are separated by newlines.
        speaker: Forwarded to ``infer`` as ``sid``.
        sdp_ratio, noise_scale, noise_scale_w, length_scale: Forwarded to
            ``infer`` unchanged.
        language: Language tag for every part; ``"mix"`` is rejected.
        cut_by_sent: When truthy, split paragraphs into sentences as well.
        interval_between_para: Pause after each paragraph, in seconds.
        interval_between_sent: Pause after each sentence, in seconds (only
            used when ``cut_by_sent`` is set).

    Returns:
        ``("Success", (sampling_rate, samples))`` on success, or
        ``("invalid", None)`` when ``language`` is ``"mix"`` or the text
        yields no paragraphs.
    """
    if language == "mix":
        return ("invalid", None)
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")
    para_list = re_matching.cut_para(text)
    if not para_list:
        # np.concatenate() raises on an empty list; report it through the
        # status channel instead.
        return ("invalid", None)

    # Pauses and the returned rate must use the model's own sampling rate,
    # otherwise playback speed and pause lengths are wrong for models that
    # are not 44.1 kHz.
    sampling_rate = hps.data.sampling_rate
    last_para = len(para_list) - 1
    audio_list = []
    for para_idx, para in enumerate(para_list):
        skip_start = para_idx != 0
        skip_end = para_idx != last_para

        if not cut_by_sent:
            audio = infer(
                para,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
                hps=hps,
                net_g=net_g,
                device=device,
                skip_start=skip_start,
                skip_end=skip_end,
            )
            audio_list.append(gr.processing_utils.convert_to_16_bit_wav(audio))
            audio_list.append(
                np.zeros(int(sampling_rate * interval_between_para), dtype=np.int16)
            )
            continue

        sent_list = re_matching.cut_sent(para)
        last_sent = len(sent_list) - 1
        para_pieces = []
        for sent_idx, sent in enumerate(sent_list):
            # NOTE(review): the flags are rebound per sentence, so the False
            # assigned for the first sentence sticks: skip_start is False for
            # every sentence, and skip_end is True only for non-final
            # sentences of non-final paragraphs. Behaviour kept as-is; confirm
            # whether per-sentence values were intended.
            skip_start = (sent_idx != 0) and skip_start
            skip_end = (sent_idx != last_sent) and skip_end
            audio = infer(
                sent,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
                hps=hps,
                net_g=net_g,
                device=device,
                skip_start=skip_start,
                skip_end=skip_end,
            )
            para_pieces.append(audio)
            para_pieces.append(np.zeros(int(sampling_rate * interval_between_sent)))

        # Top the last sentence pause up to the paragraph pause when the
        # paragraph pause is the longer of the two.
        extra_pause = interval_between_para - interval_between_sent
        if extra_pause > 0:
            para_pieces.append(np.zeros(int(sampling_rate * extra_pause)))

        # Convert the whole paragraph at once (original note: volume is
        # normalized over the complete sentences).
        audio_list.append(
            gr.processing_utils.convert_to_16_bit_wav(np.concatenate(para_pieces))
        )

    audio_concat = np.concatenate(audio_list)
    return ("Success", (sampling_rate, audio_concat))
def _generate_merged_group(
    pairs,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    skip_start,
):
    """Synthesize a non-empty list of ``(lang, content)`` pairs as one piece."""
    contents = [content for _, content in pairs]
    langs = [lang for lang, _ in pairs]
    # NOTE(review): these two values reproduce what the previous index
    # arithmetic produced. skip_start survived only when the group held more
    # than one pair, and skip_end was and-ed with a condition that was always
    # false, so False is what reached generate_audio_multilang.
    skip_start = (len(pairs) != 1) and skip_start
    skip_end = False
    print([contents], [langs])
    return generate_audio_multilang(
        [contents],
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        speaker,
        [langs],
        skip_start,
        skip_end,
    )


def tts_fn(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
):
    """Synthesize ``text`` and return a status string plus audio for the UI.

    Three modes are selected by ``language``:

    * ``"mix"``: the text is validated with ``re_matching.validate_text`` and
      parsed with ``re_matching.text_matching`` into per-speaker runs of
      ``(lang, content)`` pairs; a ``|`` inside a content starts a new group.
    * ``"auto"`` (case-insensitive): the text is split on ``|`` and each
      non-empty segment is split by ``split_by_language`` into
      ``(content, lang)`` pairs; ``JA`` is mapped to ``JP``.
    * anything else: the text is split on ``|`` and passed to
      ``generate_audio`` with ``language`` as the tag for every piece.

    Args:
        text: Input text.
        speaker: Speaker used in the ``auto`` and single-language modes; in
            ``mix`` mode the speaker comes from the parsed text.
        sdp_ratio, noise_scale, noise_scale_w, length_scale: Forwarded to the
            generation helpers unchanged.
        language: Mode / language tag as described above.

    Returns:
        ``("Success", (sampling_rate, samples))`` on success;
        ``(message, (sampling_rate, half_second_of_zeros))`` when ``mix``
        validation fails; ``("invalid", None)`` when nothing was synthesized.
    """
    audio_list = []
    if language == "mix":
        bool_valid, str_valid = re_matching.validate_text(text)
        if not bool_valid:
            return str_valid, (
                hps.data.sampling_rate,
                np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
            )

        # Build (speaker, [(lang, content), ...]) groups.
        groups = []
        for matched in re_matching.text_matching(text):
            group_speaker = matched.pop()  # last element is the speaker
            content_groups = []
            lang_groups = []
            for lang, content in matched:
                if "|" in content:
                    # Every "|"-separated part opens its own group; an empty
                    # part opens an empty group that later contents extend.
                    for part in content.split("|"):
                        if part != "":
                            content_groups.append([part])
                            lang_groups.append([lang])
                        else:
                            content_groups.append([])
                            lang_groups.append([])
                else:
                    if not content_groups:
                        content_groups.append([])
                        lang_groups.append([])
                    content_groups[-1].append(content)
                    lang_groups[-1].append(lang)
            for lang_group, content_group in zip(lang_groups, content_groups):
                groups.append((group_speaker, list(zip(lang_group, content_group))))

        for group_idx, (group_speaker, pairs) in enumerate(groups):
            if not pairs:
                continue
            audio_list.extend(
                _generate_merged_group(
                    pairs,
                    group_speaker,
                    sdp_ratio,
                    noise_scale,
                    noise_scale_w,
                    length_scale,
                    group_idx != 0,
                )
            )
    elif language.lower() == "auto":
        segments = text.split("|")
        for seg_idx, segment in enumerate(segments):
            if segment == "":
                continue
            sentences = split_by_language(
                segment, target_languages=["zh", "ja", "en"]
            )
            pairs = []
            for content, lang in sentences:
                lang = lang.upper()
                pairs.append(("JP" if lang == "JA" else lang, content))
            if not pairs:
                continue
            audio_list.extend(
                _generate_merged_group(
                    pairs,
                    speaker,
                    sdp_ratio,
                    noise_scale,
                    noise_scale_w,
                    length_scale,
                    seg_idx != 0,
                )
            )
    else:
        audio_list.extend(
            generate_audio(
                text.split("|"),
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                speaker,
                language,
            )
        )

    if not audio_list:
        # np.concatenate() raises on an empty list (e.g. empty text in "auto"
        # mode); report it through the status channel instead.
        return "invalid", None
    audio_concat = np.concatenate(audio_list)
    return "Success", (hps.data.sampling_rate, audio_concat)
if __name__ == "__main__":
    # Script entry point: load the model, build the Gradio UI and serve it.
    if config.webui_config.debug:
        logger.info("Enable DEBUG-LEVEL log")
        # logging.basicConfig() already ran at import time, so calling it
        # again would be a no-op; raise the root logger level directly.
        logging.getLogger().setLevel(logging.DEBUG)

    # hps and net_g are module-level names read by the tts_* helpers.
    hps = utils.get_hparams_from_file(config.webui_config.config_path)
    # Default to the latest version when config.json does not specify one.
    version = hps.version if hasattr(hps, "version") else latest_version
    net_g = get_net_g(
        model_path=config.webui_config.model, version=version, device=device, hps=hps
    )
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    languages = ["ZH", "JP", "EN", "auto", "mix"]

    with gr.Blocks() as app:
        with gr.Row():
            # Left column: text input and synthesis parameters.
            with gr.Column():
                gr.Markdown(value="""
【AI塔菲】在线语音合成(Bert-Vits2 2.0中日英)\n
作者:Xz乔希 https://space.bilibili.com/5859321\n
声音归属:永雏塔菲 https://space.bilibili.com/1265680561\n
【AI合集】https://www.modelscope.cn/studios/xzjosh/Bert-VITS2\n
Bert-VITS2项目:https://github.com/Stardust-minus/Bert-VITS2\n
使用本模型请严格遵守法律法规!\n
发布二创作品请标注本项目作者及链接、作品使用Bert-VITS2 AI生成!\n
【提示】手机端容易误触调节,请刷新恢复默认!每次生成的结果都不一样,效果不好请尝试多次生成与调节,选择最佳结果!\n
""")
                text = gr.TextArea(
                    label="输入文本内容",
                    placeholder="""
推荐不同语言分开推理,因为无法连贯且可能影响最终效果!
如果选择语言为\'auto\',有概率无法识别。
如果选择语言为\'mix\',必须按照格式输入,否则报错:
格式举例(zh是中文,jp是日语,en是英语;不区分大小写):
[说话人]<zh>你好 <jp>こんにちは <en>Hello
另外,所有的语言选项都可以用'|'分割长段实现分句生成。
""",
                )
                speaker = gr.Dropdown(
                    choices=speakers, value=speakers[0], label="选择说话人"
                )
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.2, step=0.01, label="SDP/DP混合比"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.5, step=0.01, label="感情"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.9, step=0.01, label="音素长度"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1.0, step=0.01, label="语速"
                )
                language = gr.Dropdown(
                    choices=languages, value=languages[0], label="选择语言"
                )
                btn = gr.Button("点击生成", variant="primary")
            # Right column: split-synthesis controls and the outputs.
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        interval_between_sent = gr.Slider(
                            minimum=0,
                            maximum=5,
                            value=0.2,
                            step=0.1,
                            label="句间停顿(秒),勾选按句切分才生效",
                        )
                        interval_between_para = gr.Slider(
                            minimum=0,
                            maximum=10,
                            value=1,
                            step=0.1,
                            label="段间停顿(秒),需要大于句间停顿才有效",
                        )
                        opt_cut_by_sent = gr.Checkbox(
                            label="按句切分 在按段落切分的基础上再按句子切分文本"
                        )
                        slicer = gr.Button("切分生成", variant="primary")
                text_output = gr.Textbox(label="状态信息")
                audio_output = gr.Audio(label="输出音频")

        # Plain synthesis: tts_fn returns (status, audio).
        btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
            ],
            outputs=[text_output, audio_output],
        )
        # Paragraph/sentence-split synthesis: tts_split returns (status, audio).
        slicer.click(
            tts_split,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
                opt_cut_by_sent,
                interval_between_para,
                interval_between_sent,
            ],
            outputs=[text_output, audio_output],
        )

    print("推理页面已开启!")
    webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
    app.launch(share=config.webui_config.share, server_port=config.webui_config.port)