File size: 5,509 Bytes
b2458f3 f584511 b2458f3 f584511 b2458f3 f584511 b2458f3 |
|
import logging
import re
import gradio as gr
import numpy
import torch
import utils
from infer import infer, get_net_g
# Quiet chatty third-party libraries: only their WARNING-and-above records pass.
for _noisy in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_noisy).setLevel(logging.WARNING)
del _noisy

logging.basicConfig(
    level=logging.INFO,
    format="| %(name)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger(__name__)

# Model state; main() assigns these and tts_fn reads them as module globals.
net_g = None
hps = None

# Run inference on the GPU when PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Generator checkpoint loaded by main().
model_path = "models/G_1000.pth"
# Sample rate handed to Gradio alongside the synthesized audio; tts_fn also
# uses it to size the half-second silence between segments.
# NOTE(review): confirm this matches the model's hps.data.sampling_rate.
sampling_rate = 22050
def split_sentence(sentence: str):
    """Split *sentence* into alternating Chinese and English runs.

    Each character gets two flags:

    * English-compatible: the character is ASCII (letters, digits,
      whitespace, ASCII punctuation).
    * Chinese-compatible: the character is not an ASCII letter.

    Digits, whitespace and ASCII punctuation are compatible with both, so
    they stay attached to whichever run they occur in. A run ends at the
    first character its language cannot absorb:

    * an ASCII letter ends a Chinese run;
    * a non-ASCII character ends an English run.

    Args:
        sentence: Text to split.

    Returns:
        A list of ``(segment, language)`` tuples, where language is ``"ZH"``
        or ``"EN"``. The segments concatenate back to *sentence*, and
        consecutive segments alternate languages.

        Empty input yields ``[]``. Text containing neither ASCII letters nor
        non-ASCII characters (e.g. only digits/punctuation) is returned as a
        single ``"ZH"`` segment.
    """
    if not sentence:
        return []

    latin_letter = re.compile(r"[a-zA-Z]")
    is_english = [ch.isascii() for ch in sentence]
    is_chinese = [latin_letter.match(ch) is None for ch in sentence]
    chains = {"ZH": is_chinese, "EN": is_english}

    # The first character that only one language accepts decides the starting
    # language. No character is rejected by both, because a non-ASCII
    # character is never an ASCII letter.
    # Default to "ZH" when every character is neutral. Previously this case
    # produced an empty language tag that was forwarded to the synthesizer.
    current_language = "ZH"
    for english_ok, chinese_ok in zip(is_english, is_chinese):
        if not english_ok:
            current_language = "ZH"
            break
        if not chinese_ok:
            current_language = "EN"
            break

    result = []
    step = 0
    length = len(sentence)
    while step < length:
        try:
            # The current run stops at the first character it cannot absorb.
            next_step = chains[current_language].index(False, step)
        except ValueError:
            # Nothing left that breaks the run: it extends to the end.
            next_step = length
        result.append((sentence[step:next_step], current_language))
        step = next_step
        # The character that stopped the run is always accepted by the other
        # language, so every run is non-empty and the loop terminates.
        current_language = "EN" if current_language == "ZH" else "ZH"
    return result
def tts_fn(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
):
    """Synthesize *text* and return a Gradio-ready ``(status, audio)`` pair.

    The text is split into Chinese/English runs with ``split_sentence``. Each
    run is synthesized separately with ``infer`` and followed by half a second
    of silence. The joined waveform is then peak-normalized and converted to
    16-bit PCM.

    Args:
        text: Input text, possibly mixing Chinese and English.
        speaker: Speaker choice forwarded to ``infer`` as ``sid``.
        sdp_ratio: Forwarded unchanged to ``infer``.
        noise_scale: Forwarded unchanged to ``infer``.
        noise_scale_w: Forwarded unchanged to ``infer``.
        length_scale: Forwarded unchanged to ``infer``.
        language: UI label. ``"普通话"`` (Mandarin) selects ``"ZH"``; any
            other label selects ``"SH"``. It applies to the runs that
            ``split_sentence`` tags ``"ZH"``; other runs keep their own tag.

    Returns:
        ``("Success", (sampling_rate, int16 ndarray))``. Empty text yields an
        empty audio array instead of raising.
    """
    language = 'ZH' if language == '普通话' else 'SH'
    # Half a second of silence appended after every segment.
    silence = numpy.zeros(sampling_rate // 2, dtype=numpy.float32)

    # Collect the pieces and concatenate once, rather than re-copying a
    # growing buffer for every segment.
    pieces = []
    for sentence, sentence_language in split_sentence(text):
        sub_audio_data = infer(
            sentence,
            sdp_ratio,
            noise_scale,
            noise_scale_w,
            length_scale,
            sid=speaker,
            language=language if sentence_language == "ZH" else sentence_language,
            hps=hps,
            net_g=net_g,
            device=device,
        )
        pieces.append(sub_audio_data)
        pieces.append(silence)

    if pieces:
        audio_data = numpy.concatenate(pieces)
    else:
        audio_data = numpy.zeros(0, dtype=numpy.float32)

    # Peak-normalize to [-1, 1]. Skip this when there is no signal: max() of
    # an empty array raises, and dividing by a zero peak would produce NaN.
    peak = numpy.abs(audio_data).max() if audio_data.size else 0
    if peak > 0:
        audio_data = audio_data / peak
    audio_data = (audio_data * 32767).astype(numpy.int16)
    return "Success", (sampling_rate, audio_data)
def main():
    """Load the model, build the Gradio UI, and serve it.

    Loads ``configs/config.json`` and the ``model_path`` checkpoint into the
    module-level ``hps`` and ``net_g`` that ``tts_fn`` reads. It then builds
    the interface and blocks in ``app.launch``, listening on 0.0.0.0:7860.
    """
    # tts_fn reads these as module globals, so rebind the globals here
    # instead of creating locals.
    global hps, net_g
    hps = utils.get_hparams_from_file("configs/config.json")
    net_g = get_net_g(model_path=model_path, device=device, hps=hps)

    # The speaker dropdown lists the keys of the config's spk2id table.
    speakers = list(hps.data.spk2id.keys())
    # "上海话" = Shanghainese, "普通话" = Mandarin. tts_fn maps "普通话" to "ZH"
    # and any other label to "SH".
    languages = ["上海话", "普通话"]

    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                gr.Markdown('# Bert-VITS2-Shanghainese')
        with gr.Row():
            with gr.Column():
                # The default text alternates Shanghainese lines with English
                # ones to demonstrate mixed-language input.
                text = gr.TextArea(
                    label="输入文本内容",  # "Input text"
                    value="\n".join([
                        "站一个制高点看上海,",
                        "Looking at Shanghai from a commanding height,",
                        "上海的弄堂是壮观的景象。",
                        "The alleys in Shanghai are a great sight.",
                        "它是这城市背景一样的东西。",
                        "It is something with the same background as this city."
                    ]),
                )
                # Labels: SDP/DP mix ratio, "emotion", "phoneme length",
                # "speech rate".
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.2, step=0.1, label="SDP/DP混合比"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.6, step=0.1, label="感情"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.8, step=0.1, label="音素长度"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1.0, step=0.1, label="语速"
                )
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # "Select speaker"
                        speaker = gr.Dropdown(
                            choices=speakers, value=speakers[0], label="选择说话人"
                        )
                    with gr.Column():
                        # "Select language"; defaults to Shanghainese.
                        language = gr.Dropdown(
                            choices=languages, value=languages[0], label="选择语言"
                        )
                # "Generate audio" button plus "Status" / "Audio" outputs.
                submit_btn = gr.Button("生成音频", variant="primary")
                text_output = gr.Textbox(label="状态")
                audio_output = gr.Audio(label="音频")
        with gr.Row():
            with gr.Column():
                gr.Markdown('### Thanks <a href="https://github.com/fishaudio/Bert-VITS2" target="_blank">github@Bert-VITS2</a>')
        # The input order must match tts_fn's positional parameters.
        submit_btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
            ],
            outputs=[text_output, audio_output],
        )

    # server_name="0.0.0.0" binds every network interface, so other hosts can
    # reach the UI; share=False skips creating a public Gradio share link.
    app.launch(share=False, server_name="0.0.0.0", server_port=7860)
# Build and serve the web UI only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|