# --- Hugging Face file-viewer residue (not part of the program) ---
# 94insane's picture
# Update app.py
# 8c51768
# raw
# history blame
# No virus
# 6.12 kB
# @ 2023.10.23
# @ Elena
import sys, os
# Install the espeak system package at start-up (best effort: exit codes are
# not checked, matching the original behaviour).
# NOTE(review): presumably needed by the text front-end imported below - confirm.
os.system('apt-get update')
# `-y` (assume yes) answers apt's confirmation prompt; without it the install
# cannot proceed when no interactive terminal is attached.
os.system('apt-get install -y espeak')
import logging
import re
from scipy.io.wavfile import write
# Raise the threshold of these third-party loggers to WARNING.
for _quiet_name in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_quiet_name).setLevel(logging.WARNING)
del _quiet_name

# Root configuration: INFO and above, rendered as "| name | LEVEL | message".
logging.basicConfig(
    format="| %(name)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

# Logger for this module.
logger = logging.getLogger(__name__)
import torch
import argparse
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
import gradio as gr
import webbrowser
import numpy as np
'''# - paths
path_to_config = "config.json" # path to .json
path_to_model = "best.pth" # path to G_xxxx.pth'''
# Synthesizer model; stays None until the __main__ section builds and loads it.
net_g = None

# Default inference device, chosen in order of preference: MPS, CUDA, CPU.
# (The __main__ section reassigns `device` when the file is run as a script.)
if sys.platform == "darwin" and torch.backends.mps.is_available():
    device = "mps"
    # Ask PyTorch to fall back to the CPU for operators MPS does not support.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
elif torch.cuda.is_available():
    device = "cuda"
else:
    # "cuda" used to be selected unconditionally here, which breaks on hosts
    # without a CUDA device; use the CPU instead.
    device = "cpu"
def get_text(text, hps):
    """Turn a text string into a tensor of symbol ids.

    Args:
        text: The string to convert.
        hps: Hyper-parameter object; ``hps.data.text_cleaners`` and
            ``hps.data.add_blank`` are read.

    Returns:
        A ``torch.LongTensor`` built from the output of ``text_to_sequence``.
        When ``hps.data.add_blank`` is truthy the sequence is first passed
        through ``commons.intersperse(..., 0)``.
    """
    sequence = text_to_sequence(text, hps.data.text_cleaners)
    if not hps.data.add_blank:
        return torch.LongTensor(sequence)
    return torch.LongTensor(commons.intersperse(sequence, 0))
def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
    """Synthesize audio for one piece of text with the loaded model.

    Relies on the module-level ``hps``, ``net_g`` and ``device`` that the
    ``__main__`` section sets up.

    Args:
        text: Text to speak. The bracket characters ``[ ] ( ) { }`` are
            stripped before conversion.
        sdp_ratio: Accepted for interface compatibility; not forwarded to
            the model.
        noise_scale: Forwarded to ``net_g.infer`` as ``noise_scale``.
        noise_scale_w: Forwarded to ``net_g.infer`` as ``noise_scale_w``.
        length_scale: Forwarded to ``net_g.infer`` as ``length_scale``.
        sid: Speaker index. Accepted for interface compatibility; not
            forwarded to the model.
            NOTE(review): a multi-speaker checkpoint presumably needs this
            passed to ``net_g.infer`` - confirm against the model signature.

    Returns:
        The generated audio as a float32 numpy array (element ``[0, 0]`` of
        the first value returned by ``net_g.infer``).
    """
    # Drop bracket characters before the text reaches the cleaners.
    fltstr = re.sub(r"[\[\]\(\)\{\}]", "", text)
    stn_tst = get_text(fltstr, hps)
    with torch.no_grad():
        # Add a leading batch dimension of size 1.
        x_tst = stn_tst.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        # Use the caller's settings; these were previously hard-coded
        # (0.667 / 0.8 / 1), which made the UI sliders ineffective.
        outputs = net_g.infer(
            x_tst,
            x_tst_lengths,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )
        audio = outputs[0][0, 0].data.cpu().float().numpy()
    return audio
def tts_fn(
    text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale
):
    """Synthesize ``text`` segment by segment and join the results.

    The text is split on ``"|"``; each segment is passed to ``infer`` and is
    followed by ``hps.data.sampling_rate`` zero samples (one second of
    silence), including after the last segment.

    Args:
        text: Input text, with ``"|"`` separating segments.
        speaker: Passed to ``infer`` as ``sid``.
        sdp_ratio: Passed through to ``infer``.
        noise_scale: Passed through to ``infer``.
        noise_scale_w: Passed through to ``infer``.
        length_scale: Passed through to ``infer``.

    Returns:
        A tuple ``("Success", (sampling_rate, audio))`` where ``audio`` is
        the concatenation of all segments and pauses.
    """
    pieces = []
    with torch.no_grad():
        for segment in text.split("|"):
            waveform = infer(
                segment,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
            )
            # One second of silence after each segment.
            pause = np.zeros(hps.data.sampling_rate)
            pieces += [waveform, pause]
    merged = np.concatenate(pieces)
    return "Success", (hps.data.sampling_rate, merged)
# Script entry point: parse CLI options, load the model, build and launch the UI.
# `hps`, `device` and `net_g` are assigned at module level on purpose, because
# `infer` and `tts_fn` read them as globals.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m", "--model", default="./best.pth", help="path of your model"
    )
    parser.add_argument(
        "-c",
        "--config",
        default="./config.json",
        help="path of your config file",
    )
    parser.add_argument(
        "--share", default=False, help="make link public", action="store_true"
    )
    parser.add_argument(
        "-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log"
    )
    args = parser.parse_args()
    if args.debug:
        logger.info("Enable DEBUG-LEVEL log")
        # logging.basicConfig() does nothing once the root logger already has
        # handlers (it is configured at import time above), so the level is
        # set on the root logger directly.
        logging.getLogger().setLevel(logging.DEBUG)

    hps = utils.get_hparams_from_file(args.config)

    # Choose the posterior-encoder input size according to the model config.
    use_mel = (
        "use_mel_posterior_encoder" in hps.model.keys()
        and hps.model.use_mel_posterior_encoder
    )
    if use_mel:
        print("Using mel posterior encoder for VITS2")
        posterior_channels = 80  # vits2
        hps.data.use_mel_posterior_encoder = True
    else:
        print("Using lin posterior encoder for VITS1")
        posterior_channels = hps.data.filter_length // 2 + 1
        hps.data.use_mel_posterior_encoder = False

    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        # Keep using the second GPU when there is one; otherwise use the
        # first, since "cuda:1" is not a valid device on single-GPU hosts.
        device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
    elif sys.platform == "darwin" and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    net_g = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,  # >0 for multi speaker
        **hps.model
    ).to(device)
    _ = net_g.eval()

    # Load model
    _ = utils.load_checkpoint(args.model, net_g, None)

    speakers = hps.data.n_speakers
    languages = ["KO"]

    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                text = gr.TextArea(
                    label="Text",
                    placeholder="Input Text Here",
                    value="TTS는 텍스트 문서를 음성으로 출력시켜 주는 기술이며 텍스트 문서를 입력하면 음성으로 읽어주는 기술이다.",
                )
                speaker = gr.Slider(
                    minimum=0,
                    # Clamp so a config with n_speakers == 0 does not
                    # produce a slider whose maximum (-1) is below its minimum.
                    maximum=max(speakers - 1, 0),
                    value=0,
                    step=1,
                    label="성우",
                )
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise Scale"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise Scale W"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1, step=0.1, label="Length Scale"
                )
                # Shown in the UI only; not wired into the click handler.
                language = gr.Dropdown(
                    choices=languages, value=languages[0], label="Language"
                )
                btn = gr.Button("Generate!", variant="primary")
            with gr.Column():
                text_output = gr.Textbox(label="Message")
                audio_output = gr.Audio(label="Output Audio")

        btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
            ],
            outputs=[text_output, audio_output],
        )

    webbrowser.open("http://127.0.0.1:7860")
    # Honour --share; it was previously parsed but a public link was always
    # requested regardless of the flag.
    app.launch(share=args.share)