|
import logging |
|
import re |
|
|
|
import gradio as gr |
|
import numpy |
|
import torch |
|
|
|
import utils |
|
from infer import infer, get_net_g |
|
|
|
# Raise the threshold of these third-party loggers to WARNING so only their
# warnings and errors reach the output.
for _noisy_logger in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger

logging.basicConfig(
    format="| %(name)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

logger = logging.getLogger(__name__)

# Model state shared across the module; both stay None until main() assigns
# them.
net_g = None
hps = None

# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Generator checkpoint path and the sample rate reported with generated audio.
model_path = "models/G_1000.pth"
sampling_rate = 22050
|
|
|
|
|
def split_sentence(sentence: str):
    """Split mixed Chinese/English text into runs tagged with a language.

    Characters are classified one at a time: a non-ASCII character counts as
    Chinese (``'ZH'``), an ASCII letter as English (``'EN'``), and any other
    ASCII character (digits, punctuation, whitespace) is neutral. Neutral
    characters stay in whichever run they occur in, and leading neutral
    characters join the first run.

    Args:
        sentence: Text to split.

    Returns:
        A list of ``(text, language)`` tuples in input order whose texts
        concatenate back to ``sentence``. Empty input yields ``[]``. Text
        containing only neutral characters is returned as one ``'ZH'`` run
        rather than being tagged with an empty language.
    """
    if not sentence:
        return []

    segments = []
    current_language = None  # language of the run being built; None until known
    start = 0
    for idx, char in enumerate(sentence):
        if not char.isascii():
            char_language = 'ZH'
        elif char.isalpha():
            # For an ASCII character, isalpha() is true exactly for a-z / A-Z.
            char_language = 'EN'
        else:
            # Neutral character: never starts a new run.
            continue

        if char_language == current_language:
            continue
        if current_language is not None:
            # The language changed: close the run that ends just before idx.
            segments.append((sentence[start:idx], current_language))
            start = idx
        current_language = char_language

    # Neutral-only text never sets a language; fall back to 'ZH' so callers
    # always receive a usable language tag.
    segments.append((sentence[start:], current_language or 'ZH'))
    return segments
|
|
|
|
|
def tts_fn(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
):
    """Synthesize *text* and return a status string plus audio for the UI.

    The text is split into language runs with ``split_sentence``; each run is
    passed to ``infer`` and followed by ``sampling_rate // 2`` samples of
    silence. The joined waveform is peak-normalized and converted to int16.

    Args:
        text: Text to synthesize.
        speaker: Speaker identifier, forwarded to ``infer`` as ``sid``.
        sdp_ratio: Forwarded to ``infer`` unchanged.
        noise_scale: Forwarded to ``infer`` unchanged.
        noise_scale_w: Forwarded to ``infer`` unchanged.
        length_scale: Forwarded to ``infer`` unchanged.
        language: UI language choice; ``'普通话'`` selects ``'ZH'``, anything
            else selects ``'SH'``. It is applied to runs tagged ``'ZH'``;
            other runs keep the tag ``split_sentence`` gave them.

    Returns:
        ``("Success", (sampling_rate, int16_array))`` on success, or an error
        message and ``None`` when there is no text to synthesize.
    """
    # Nothing to synthesize: report it instead of failing later on an empty
    # waveform.
    if not text or not text.strip():
        return "Error: input text is empty", None

    language = 'ZH' if language == '普通话' else 'SH'

    # float32 silence keeps the joined waveform floating point, matching the
    # float32 accumulator this function used previously.
    silence = numpy.zeros(sampling_rate // 2, dtype=numpy.float32)

    # Collect the pieces and join them once, rather than re-copying the whole
    # buffer for every sentence.
    chunks = []
    for sentence, sentence_language in split_sentence(text):
        # NOTE(review): assumes infer returns a 1-D numeric waveform — confirm.
        sub_audio_data = infer(
            sentence,
            sdp_ratio,
            noise_scale,
            noise_scale_w,
            length_scale,
            sid=speaker,
            language=language if sentence_language == "ZH" else sentence_language,
            hps=hps,
            net_g=net_g,
            device=device)
        chunks.append(sub_audio_data)
        chunks.append(silence)

    audio_data = numpy.concatenate(chunks)

    # Peak-normalize to the int16 range; skip the division for all-zero audio
    # so it stays silent instead of becoming NaN.
    peak = numpy.abs(audio_data).max()
    if peak > 0:
        audio_data = audio_data / peak
    audio_data = (audio_data * 32767).astype(numpy.int16)

    return "Success", (sampling_rate, audio_data)
|
|
|
|
|
def main():
    """Load the model, build the Gradio interface and start serving it.

    Assigns the module-level ``hps`` and ``net_g``, then launches the app on
    0.0.0.0:7860 with sharing disabled.
    """
    global hps, net_g

    # NOTE(review): the root logger is already configured by the module-level
    # basicConfig call, and basicConfig does nothing once the root logger has
    # handlers, so this call presumably has no effect — confirm the intent.
    logging.basicConfig(level=logging.DEBUG)

    hps = utils.get_hparams_from_file("configs/config.json")
    net_g = get_net_g(model_path=model_path, device=device, hps=hps)

    speaker_names = list(hps.data.spk2id.keys())
    language_choices = ["上海话", "普通话"]

    # Sample text pre-filled in the input box: alternating Chinese and English
    # lines.
    default_text = "\n".join([
        "站一个制高点看上海,",
        "Looking at Shanghai from a commanding height,",
        "上海的弄堂是壮观的景象。",
        "The alleys in Shanghai are a great sight.",
        "它是这城市背景一样的东西。",
        "It is something with the same background as this city."
    ])

    with gr.Blocks() as app:
        # Title row.
        with gr.Row(), gr.Column():
            gr.Markdown('# Bert-VITS2-Shanghainese')

        with gr.Row():
            # Text and synthesis sliders.
            with gr.Column():
                text = gr.TextArea(label="输入文本内容", value=default_text)
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.2, step=0.1, label="SDP/DP混合比")
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.6, step=0.1, label="感情")
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.8, step=0.1, label="音素长度")
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1.0, step=0.1, label="语速")

            # Speaker/language selection, submit button and outputs.
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        speaker = gr.Dropdown(
                            choices=speaker_names,
                            value=speaker_names[0],
                            label="选择说话人",
                        )
                    with gr.Column():
                        language = gr.Dropdown(
                            choices=language_choices,
                            value=language_choices[0],
                            label="选择语言",
                        )
                submit_btn = gr.Button("生成音频", variant="primary")
                text_output = gr.Textbox(label="状态")
                audio_output = gr.Audio(label="音频")

        # Credits row.
        with gr.Row(), gr.Column():
            gr.Markdown('### Thanks <a href="https://github.com/fishaudio/Bert-VITS2" target="_blank">github@Bert-VITS2</a>')

        # The input order must match the positional parameters of tts_fn.
        tts_inputs = [
            text,
            speaker,
            sdp_ratio,
            noise_scale,
            noise_scale_w,
            length_scale,
            language,
        ]
        submit_btn.click(
            tts_fn,
            inputs=tts_inputs,
            outputs=[text_output, audio_output],
        )

    app.launch(share=False, server_name="0.0.0.0", server_port=7860)
|
|
|
|
|
# Start the web UI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|
|