RUSH-miaomi committed
Commit: 3530d5c
Parent(s): d7af1a0

Upload app.py

Files changed (1): app.py (+146, -0)
app.py ADDED
@@ -0,0 +1,146 @@
+ import sys, os
+
+ if sys.platform == "darwin":
+     # Let PyTorch fall back to CPU kernels for ops that MPS does not implement.
+     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+ import logging
+
+ logging.getLogger("numba").setLevel(logging.WARNING)
+ logging.getLogger("markdown_it").setLevel(logging.WARNING)
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
+
+ logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")
+
+ logger = logging.getLogger(__name__)
+
+ import torch
+ import argparse
+ import commons
+ import utils
+ from models import SynthesizerTrn
+ from text.symbols import symbols
+ from text import cleaned_text_to_sequence, get_bert
+ from text.cleaner import clean_text
+ import gradio as gr
+ import webbrowser
+
+
+ net_g = None
+
+
+ def get_text(text, language_str, hps):
+     # Normalize the text and convert it to phoneme / tone / language-id
+     # sequences, plus the matching BERT features.
+     norm_text, phone, tone, word2ph = clean_text(text, language_str)
+     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
+
+     if hps.data.add_blank:
+         phone = commons.intersperse(phone, 0)
+         tone = commons.intersperse(tone, 0)
+         language = commons.intersperse(language, 0)
+         for i in range(len(word2ph)):
+             word2ph[i] = word2ph[i] * 2
+         word2ph[0] += 1
+     bert = get_bert(norm_text, word2ph, language_str)
+     del word2ph
+
+     assert bert.shape[-1] == len(phone)
+
+     phone = torch.LongTensor(phone)
+     tone = torch.LongTensor(tone)
+     language = torch.LongTensor(language)
+
+     return bert, phone, tone, language
+
+
+ def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
+     # Run one synthesis pass and return the waveform as a NumPy float array.
+     global net_g
+     bert, phones, tones, lang_ids = get_text(text, "ZH", hps)
+     with torch.no_grad():
+         x_tst = phones.to(device).unsqueeze(0)
+         tones = tones.to(device).unsqueeze(0)
+         lang_ids = lang_ids.to(device).unsqueeze(0)
+         bert = bert.to(device).unsqueeze(0)
+         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+         del phones
+         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
+         audio = net_g.infer(
+             x_tst, x_tst_lengths, speakers, tones, lang_ids, bert,
+             sdp_ratio=sdp_ratio, noise_scale=noise_scale,
+             noise_scale_w=noise_scale_w, length_scale=length_scale,
+         )[0][0, 0].data.cpu().float().numpy()
+         del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
+         return audio
+
+
+ def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language="ZH"):
+     # `language` is supplied by the UI dropdown but is not used yet:
+     # infer() currently hard-codes the "ZH" cleaner.
+     with torch.no_grad():
+         audio = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale,
+                       noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker)
+     return "Success", (hps.data.sampling_rate, audio)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_dir", default="./logs/maolei/G_4800.pth", help="path of your model")
+     parser.add_argument("--config_dir", default="./configs/config.json", help="path of your config file")
+     parser.add_argument("--share", action="store_true", help="make link public")
+     parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-level log")
+
+     args = parser.parse_args()
+     if args.debug:
+         logger.info("Enable DEBUG-level log")
+         # basicConfig() has already run above, so adjust the root logger directly.
+         logging.getLogger().setLevel(logging.DEBUG)
+     hps = utils.get_hparams_from_file(args.config_dir)
+     device = "cuda:0" if torch.cuda.is_available() else "cpu"
+     '''
+     device = (
+         "cuda:0"
+         if torch.cuda.is_available()
+         else (
+             "mps"
+             if sys.platform == "darwin" and torch.backends.mps.is_available()
+             else "cpu"
+         )
+     )
+     '''
+     net_g = SynthesizerTrn(
+         len(symbols),
+         hps.data.filter_length // 2 + 1,
+         hps.train.segment_size // hps.data.hop_length,
+         n_speakers=hps.data.n_speakers,
+         **hps.model).to(device)
+     _ = net_g.eval()
+
+     _ = utils.load_checkpoint(args.model_dir, net_g, None, skip_optimizer=True)
+
+     speaker_ids = hps.data.spk2id
+     speakers = list(speaker_ids.keys())
+     languages = ["ZH"]  # only Chinese is wired up; infer() hard-codes "ZH"
+     with gr.Blocks() as app:
+         with gr.Row():
+             with gr.Column():
+                 text = gr.TextArea(label="Text", placeholder="Input Text Here",
+                                    value="猫雷最强!")  # default Chinese prompt; infer() uses the "ZH" cleaner
+                 speaker = gr.Dropdown(choices=speakers, value=speakers[0], label='Speaker')
+                 sdp_ratio = gr.Slider(minimum=0.1, maximum=1, value=0.2, step=0.01, label='SDP/DP mix ratio')
+                 noise_scale = gr.Slider(minimum=0.1, maximum=1, value=0.5, step=0.01, label='Emotion control')
+                 noise_scale_w = gr.Slider(minimum=0.1, maximum=1, value=0.9, step=0.01, label='Phoneme length')
+                 length_scale = gr.Slider(minimum=0.1, maximum=2, value=1, step=0.01, label='Output length')
+                 language = gr.Dropdown(choices=languages, value=languages[0],
+                                        label="Language (the mix option of this model has issues, leave unchanged for now)")
+                 btn = gr.Button("Generate", variant="primary")
+             with gr.Column():
+                 text_output = gr.Textbox(label="Message")
+                 audio_output = gr.Audio(label="Output Audio")
+
+         btn.click(
+             tts_fn,
+             inputs=[
+                 text,
+                 speaker,
+                 sdp_ratio,
+                 noise_scale,
+                 noise_scale_w,
+                 length_scale,
+                 language,
+             ],
+             outputs=[text_output, audio_output],
+         )
+
+     # webbrowser.open("http://127.0.0.1:6006")
+     # app.launch(server_port=6006, show_error=True)
+
+     app.launch(share=args.share, show_error=True)
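
For reference, app.py is meant to be launched directly. A minimal invocation, assuming the surrounding repository provides the modules the script imports (commons, utils, models, text) and using the checkpoint and config paths that appear above as argparse defaults, would be:

    python app.py --model_dir ./logs/maolei/G_4800.pth --config_dir ./configs/config.json

Adding -d/--debug turns on DEBUG-level logging and --share asks Gradio for a public link; this is a usage sketch inferred from the argparse block, not documentation included in the commit.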