XzJosh commited on
Commit
e33d8a8
1 Parent(s): c65ba09

Delete webui.py

Browse files
Files changed (1) hide show
  1. webui.py +0 -130
webui.py DELETED
@@ -1,130 +0,0 @@
1
- import sys, os
2
-
3
- if sys.platform == "darwin":
4
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
-
6
- import logging
7
-
8
- logging.getLogger("numba").setLevel(logging.WARNING)
9
- logging.getLogger("markdown_it").setLevel(logging.WARNING)
10
- logging.getLogger("urllib3").setLevel(logging.WARNING)
11
- logging.getLogger("matplotlib").setLevel(logging.WARNING)
12
-
13
- logging.basicConfig(level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s")
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
- import torch
18
- import argparse
19
- import commons
20
- import utils
21
- from models import SynthesizerTrn
22
- from text.symbols import symbols
23
- from text import cleaned_text_to_sequence, get_bert
24
- from text.cleaner import clean_text
25
- import gradio as gr
26
- import webbrowser
27
-
28
-
29
- net_g = None
30
-
31
-
32
- def get_text(text, language_str, hps):
33
- norm_text, phone, tone, word2ph = clean_text(text, language_str)
34
- phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
35
-
36
- if hps.data.add_blank:
37
- phone = commons.intersperse(phone, 0)
38
- tone = commons.intersperse(tone, 0)
39
- language = commons.intersperse(language, 0)
40
- for i in range(len(word2ph)):
41
- word2ph[i] = word2ph[i] * 2
42
- word2ph[0] += 1
43
- bert = get_bert(norm_text, word2ph, language_str)
44
- del word2ph
45
-
46
- assert bert.shape[-1] == len(phone)
47
-
48
- phone = torch.LongTensor(phone)
49
- tone = torch.LongTensor(tone)
50
- language = torch.LongTensor(language)
51
-
52
- return bert, phone, tone, language
53
-
54
- def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
55
- global net_g
56
- bert, phones, tones, lang_ids = get_text(text, "ZH", hps)
57
- with torch.no_grad():
58
- x_tst=phones.to(device).unsqueeze(0)
59
- tones=tones.to(device).unsqueeze(0)
60
- lang_ids=lang_ids.to(device).unsqueeze(0)
61
- bert = bert.to(device).unsqueeze(0)
62
- x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
63
- del phones
64
- speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
65
- audio = net_g.infer(x_tst, x_tst_lengths, speakers, tones, lang_ids, bert, sdp_ratio=sdp_ratio
66
- , noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0,0].data.cpu().float().numpy()
67
- del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
68
- return audio
69
-
70
- def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
71
- with torch.no_grad():
72
- audio = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker)
73
- return "Success", (hps.data.sampling_rate, audio)
74
-
75
-
76
- if __name__ == "__main__":
77
- parser = argparse.ArgumentParser()
78
- parser.add_argument("-m", "--model", default="./logs/as/G_8000.pth", help="path of your model")
79
- parser.add_argument("-c", "--config", default="./configs/config.json", help="path of your config file")
80
- parser.add_argument("--share", default=False, help="make link public")
81
- parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
82
-
83
- args = parser.parse_args()
84
- if args.debug:
85
- logger.info("Enable DEBUG-LEVEL log")
86
- logging.basicConfig(level=logging.DEBUG)
87
- hps = utils.get_hparams_from_file(args.config)
88
-
89
- device = (
90
- "cuda:0"
91
- if torch.cuda.is_available()
92
- else (
93
- "mps"
94
- if sys.platform == "darwin" and torch.backends.mps.is_available()
95
- else "cpu"
96
- )
97
- )
98
- net_g = SynthesizerTrn(
99
- len(symbols),
100
- hps.data.filter_length // 2 + 1,
101
- hps.train.segment_size // hps.data.hop_length,
102
- n_speakers=hps.data.n_speakers,
103
- **hps.model).to(device)
104
- _ = net_g.eval()
105
-
106
- _ = utils.load_checkpoint(args.model, net_g, None, skip_optimizer=True)
107
-
108
- speaker_ids = hps.data.spk2id
109
- speakers = list(speaker_ids.keys())
110
- with gr.Blocks() as app:
111
- with gr.Row():
112
- with gr.Column():
113
- text = gr.TextArea(label="Text", placeholder="Input Text Here",
114
- value="吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。")
115
- speaker = gr.Dropdown(choices=speakers, value=speakers[0], label='Speaker')
116
- sdp_ratio = gr.Slider(minimum=0, maximum=1, value=0.2, step=0.1, label='SDP Ratio')
117
- noise_scale = gr.Slider(minimum=0.1, maximum=2, value=0.6, step=0.1, label='Noise Scale')
118
- noise_scale_w = gr.Slider(minimum=0.1, maximum=2, value=0.8, step=0.1, label='Noise Scale W')
119
- length_scale = gr.Slider(minimum=0.1, maximum=2, value=1, step=0.1, label='Length Scale')
120
- btn = gr.Button("Generate!", variant="primary")
121
- with gr.Column():
122
- text_output = gr.Textbox(label="Message")
123
- audio_output = gr.Audio(label="Output Audio")
124
-
125
- btn.click(tts_fn,
126
- inputs=[text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale],
127
- outputs=[text_output, audio_output])
128
-
129
- webbrowser.open("http://127.0.0.1:7860")
130
- app.launch(share=args.share)