sayashi committed on
Commit
b875bd7
1 Parent(s): fb078b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -156
app.py CHANGED
@@ -1,156 +1,154 @@
1
- # coding=utf-8
2
- import os
3
- import re
4
- import utils
5
- import commons
6
- import json
7
- import gradio as gr
8
- from models import SynthesizerTrn
9
- from text import text_to_sequence
10
- from torch import no_grad, LongTensor
11
- import logging
12
- logging.getLogger('numba').setLevel(logging.WARNING)
13
- hps_ms = utils.get_hparams_from_file(r'config/config.json')
14
-
15
- def get_text(text, hps):
16
- text_norm, clean_text = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
17
- if hps.data.add_blank:
18
- text_norm = commons.intersperse(text_norm, 0)
19
- text_norm = LongTensor(text_norm)
20
- return text_norm, clean_text
21
-
22
- def create_tts_fn(net_g_ms, speaker_id):
23
- def tts_fn(text, language, noise_scale, noise_scale_w, length_scale):
24
- text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
25
- text_len = len(re.sub("\[([A-Z]{2})\]", "", text))
26
- max_len = 150
27
- if text_len > max_len:
28
- return "Error: Text is too long", None
29
- if language == 0:
30
- text = f"[ZH]{text}[ZH]"
31
- elif language == 1:
32
- text = f"[JA]{text}[JA]"
33
- else:
34
- text = f"{text}"
35
- stn_tst, clean_text = get_text(text, hps_ms)
36
- with no_grad():
37
- x_tst = stn_tst.unsqueeze(0)
38
- x_tst_lengths = LongTensor([stn_tst.size(0)])
39
- sid = LongTensor([speaker_id])
40
- audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
41
- length_scale=length_scale)[0][0, 0].data.float().numpy()
42
-
43
- return "Success", (22050, audio)
44
- return tts_fn
45
-
46
- def change_lang(language):
47
- if language == 0:
48
- return 0.6, 0.668, 1.2
49
- else:
50
- return 0.6, 0.668, 1
51
-
52
- download_audio_js = """
53
- () =>{{
54
- let root = document.querySelector("body > gradio-app");
55
- if (root.shadowRoot != null)
56
- root = root.shadowRoot;
57
- let audio = root.querySelector("#tts-audio").querySelector("audio");
58
- let text = root.querySelector("#input-text").querySelector("textarea");
59
- if (audio == undefined)
60
- return;
61
- text = text.value;
62
- if (text == undefined)
63
- text = Math.floor(Math.random()*100000000);
64
- audio = audio.src;
65
- let oA = document.createElement("a");
66
- oA.download = text.substr(0, 20)+'.wav';
67
- oA.href = audio;
68
- document.body.appendChild(oA);
69
- oA.click();
70
- oA.remove();
71
- }}
72
- """
73
-
74
- if __name__ == '__main__':
75
- models = []
76
- with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
77
- models_info = json.load(f)
78
- for i, info in models_info.items():
79
- net_g_ms = SynthesizerTrn(
80
- len(hps_ms.symbols),
81
- hps_ms.data.filter_length // 2 + 1,
82
- hps_ms.train.segment_size // hps_ms.data.hop_length,
83
- n_speakers=hps_ms.data.n_speakers,
84
- **hps_ms.model)
85
- _ = net_g_ms.eval()
86
- sid = info['sid']
87
- name_en = info['name_en']
88
- name_zh = info['name_zh']
89
- title = info['title']
90
- cover = f"pretrained_models/{i}/{info['cover']}"
91
- utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None)
92
- models.append((sid, name_en, name_zh, title, cover, net_g_ms, create_tts_fn(net_g_ms, sid)))
93
- with gr.Blocks() as app:
94
- gr.Markdown(
95
- "# <center> vits-models\n"
96
- "<div align='center'>主要有赛马娘,原神中文,原神日语,崩坏3的音色</div>"
97
- '<div align="center"><a><font color="#dd0000">结果有随机性,语调可能很奇怪,可多次生成取最佳效果</font></a></div>'
98
- '<div align="center"><a><font color="#dd0000">标点符号会影响生成的结果</font></a></div>'
99
- )
100
-
101
- with gr.Tabs():
102
- with gr.TabItem("EN"):
103
- for (sid, name_en, name_zh, title, cover, net_g_ms, tts_fn) in models:
104
- with gr.TabItem(name_en):
105
- with gr.Row():
106
- gr.Markdown(
107
- '<div align="center">'
108
- f'<a><strong>{title}</strong></a>'
109
- f'<img src="file/{cover}">' if cover else ""
110
- '</div>'
111
- )
112
- with gr.Row():
113
- with gr.Column():
114
- input_text = gr.Textbox(label="Text (100 words limitation)", lines=5, value="先生。今日も全力であなたをアシストしますね。", elem_id=f"input-text")
115
- lang = gr.Dropdown(label="Language", choices=["Chinese", "Japanese", "Mix(wrap the Chinese text with [ZH][ZH], wrap the Japanese text with [JA][JA])"],
116
- type="index", value="Japanese")
117
- btn = gr.Button(value="Generate")
118
- with gr.Row():
119
- ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
120
- nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
121
- ls = gr.Slider(label="length_scale", minimum=0.1, maximum=2.0, step=0.1, value=1, interactive=True)
122
- with gr.Column():
123
- o1 = gr.Textbox(label="Output Message")
124
- o2 = gr.Audio(label="Output Audio", elem_id=f"tts-audio")
125
- download = gr.Button("Download Audio")
126
- btn.click(tts_fn, inputs=[input_text, lang, ns, nsw, ls], outputs=[o1, o2])
127
- download.click(None, [], [], _js=download_audio_js.format())
128
- lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])
129
- with gr.TabItem("中文"):
130
- for (sid, name_en, name_zh, title, cover, net_g_ms, tts_fn) in models:
131
- with gr.TabItem(name_zh):
132
- with gr.Row():
133
- gr.Markdown(
134
- '<div align="center">'
135
- f'<a><strong>{title}</strong></a>'
136
- f'<img src="file/{cover}">' if cover else ""
137
- '</div>'
138
- )
139
- with gr.Row():
140
- with gr.Column():
141
- input_text = gr.Textbox(label="文本 (100字上限)", lines=5, value="先生。今日も全力であなたをアシストしますね。", elem_id=f"input-text")
142
- lang = gr.Dropdown(label="语言", choices=["中文", "日语", "中日混合(中文用[ZH][ZH]包裹起来,日文用[JA][JA]包裹起来)"],
143
- type="index", value="日语")
144
- btn = gr.Button(value="生成")
145
- with gr.Row():
146
- ns = gr.Slider(label="控制感情变化程度", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
147
- nsw = gr.Slider(label="控制音素发音长度", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
148
- ls = gr.Slider(label="控制整体语速", minimum=0.1, maximum=2.0, step=0.1, value=1, interactive=True)
149
- with gr.Column():
150
- o1 = gr.Textbox(label="输出信息")
151
- o2 = gr.Audio(label="输出音频", elem_id=f"tts-audio")
152
- download = gr.Button("下载音频")
153
- btn.click(tts_fn, inputs=[input_text, lang, ns, nsw, ls], outputs=[o1, o2])
154
- download.click(None, [], [], _js=download_audio_js.format())
155
- lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])
156
- app.queue(concurrency_count=1).launch()
 
1
+ # coding=utf-8
2
+ import os
3
+ import re
4
+ import utils
5
+ import commons
6
+ import json
7
+ import gradio as gr
8
+ from models import SynthesizerTrn
9
+ from text import text_to_sequence
10
+ from torch import no_grad, LongTensor
11
+ import logging
12
+ logging.getLogger('numba').setLevel(logging.WARNING)
13
+ hps_ms = utils.get_hparams_from_file(r'config/config.json')
14
+
15
def get_text(text, hps):
    """Turn raw text into a model input tensor.

    Args:
        text: Input string, passed unchanged to ``text_to_sequence``.
        hps: Hyper-parameter object; ``hps.symbols``,
            ``hps.data.text_cleaners`` and ``hps.data.add_blank`` are read.

    Returns:
        A ``(tensor, cleaned)`` pair: the first value returned by
        ``text_to_sequence`` wrapped in a ``LongTensor`` (after optional
        interspersing), and its second value passed through untouched
        (presumably the cleaned text -- confirm against ``text_to_sequence``).
    """
    sequence, cleaned = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
    if hps.data.add_blank:
        # Interleave the sequence with 0 via the project's helper.
        sequence = commons.intersperse(sequence, 0)
    return LongTensor(sequence), cleaned
21
+
22
def create_tts_fn(net_g_ms, speaker_id):
    """Build the synthesis callback bound to one model and speaker id.

    Args:
        net_g_ms: Model object; only its ``infer`` method is used.
        speaker_id: Speaker id, wrapped in a ``LongTensor`` and passed to
            ``infer`` as ``sid``.

    Returns:
        A function ``tts_fn(text, language, noise_scale, noise_scale_w,
        length_scale)`` returning ``(message, audio)``. ``audio`` is ``None``
        when the text is rejected, otherwise ``(22050, samples)``.
    """
    # Raw string: "\[" inside a plain literal is an invalid escape sequence,
    # which newer Python versions warn about. Compiled once per callback.
    lang_tag = re.compile(r"\[([A-Z]{2})\]")
    # Newlines, carriage returns and spaces are all removed from the input.
    strip_whitespace = str.maketrans("", "", "\n\r ")
    max_len = 150

    def tts_fn(text, language, noise_scale, noise_scale_w, length_scale):
        text = text.translate(strip_whitespace)
        # [XX] language tags do not count towards the length limit.
        if len(lang_tag.sub("", text)) > max_len:
            return "Error: Text is too long", None
        # ``language`` is an index: 0 wraps in [ZH], 1 wraps in [JA]; any
        # other value leaves the text as typed (the user supplies the tags).
        if language == 0:
            text = f"[ZH]{text}[ZH]"
        elif language == 1:
            text = f"[JA]{text}[JA]"
        stn_tst, _ = get_text(text, hps_ms)
        with no_grad():
            x_tst = stn_tst.unsqueeze(0)
            x_tst_lengths = LongTensor([stn_tst.size(0)])
            sid = LongTensor([speaker_id])
            audio = net_g_ms.infer(
                x_tst, x_tst_lengths, sid=sid,
                noise_scale=noise_scale, noise_scale_w=noise_scale_w,
                length_scale=length_scale)[0][0, 0].data.float().numpy()
        # 22050 is presumably the sample rate expected by the audio output
        # component -- confirm against the model configuration.
        return "Success", (22050, audio)

    return tts_fn
45
+
46
def change_lang(language):
    """Return slider presets for the selected language index.

    Args:
        language: Index of the chosen dropdown entry.

    Returns:
        A ``(noise_scale, noise_scale_w, length_scale)`` tuple. Only the
        last value depends on the language: 1.2 for index 0, 1 otherwise.
    """
    length_scale = 1.2 if language == 0 else 1
    return 0.6, 0.668, length_scale
51
+
52
# JavaScript run in the browser by the "download" buttons. It saves the audio
# found under #tts-audio, naming the file after the first 20 characters of the
# text found under #input-text.
#
# Braces are doubled because callers pass this template through str.format().
# The script returns quietly when the page elements are missing, and the
# fallback name is converted with String() because numbers have no substr().
# An empty text box also gets the random fallback name.
download_audio_js = """
() =>{{
    let root = document.querySelector("body > gradio-app");
    if (root == null)
        return;
    if (root.shadowRoot != null)
        root = root.shadowRoot;
    let audioBox = root.querySelector("#tts-audio");
    let audio = audioBox == null ? null : audioBox.querySelector("audio");
    if (audio == null)
        return;
    let textBox = root.querySelector("#input-text");
    let textArea = textBox == null ? null : textBox.querySelector("textarea");
    let text = (textArea == null || textArea.value == undefined) ? "" : String(textArea.value);
    if (text == "")
        text = String(Math.floor(Math.random()*100000000));
    let oA = document.createElement("a");
    oA.download = text.substr(0, 20)+'.wav';
    oA.href = audio.src;
    document.body.appendChild(oA);
    oA.click();
    oA.remove();
}}
"""
73
+
74
# Example sentence pre-filled in every input box.
DEFAULT_INPUT_TEXT = "先生。今日も全力であなたをアシストしますね。"

# UI strings for the two top-level tabs, keyed by tab title. ``name_field``
# selects which model name labels the per-model sub-tabs.
# NOTE(review): the input labels advertise a limit of 100 while tts_fn rejects
# text longer than 150 characters -- confirm which is intended.
TAB_LABELS = {
    "EN": {
        "name_field": "name_en",
        "input": "Text (100 words limitation)",
        "language": "Language",
        "choices": ["Chinese", "Japanese", "Mix(wrap the Chinese text with [ZH][ZH], wrap the Japanese text with [JA][JA])"],
        "default_language": "Japanese",
        "generate": "Generate",
        "noise_scale": "noise_scale",
        "noise_scale_w": "noise_scale_w",
        "length_scale": "length_scale",
        "message": "Output Message",
        "audio": "Output Audio",
        "download": "Download Audio",
    },
    "中文": {
        "name_field": "name_zh",
        "input": "文本 (100字上限)",
        "language": "语言",
        "choices": ["中文", "日语", "中日混合(中文用[ZH][ZH]包裹起来,日文用[JA][JA]包裹起来)"],
        "default_language": "日语",
        "generate": "生成",
        "noise_scale": "控制感情变化程度",
        "noise_scale_w": "控制音素发音长度",
        "length_scale": "控制整体语速",
        "message": "输出信息",
        "audio": "输出音频",
        "download": "下载音频",
    },
}


def build_cover_html(title, cover):
    """Return the centered title/cover HTML shown at the top of a model tab.

    The original built this from adjacent string literals around a
    conditional expression. Adjacent literals are joined before the
    conditional is evaluated, so the closing ``</div>`` was dropped whenever
    a cover was present. Here the pieces are assembled explicitly.

    Args:
        title: Model title, inserted as-is.
        cover: Path of the cover image, or a falsy value for no image.

    Returns:
        An HTML string that always ends with ``</div>``.
    """
    image = f'<img width="300px" src="file/{cover}">' if cover else ""
    return f'<div align="center"><a><strong>{title}</strong></a>{image}</div>'


def build_model_panel(title, cover, tts_fn, labels):
    """Create the widgets and event wiring for one model tab.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        title: Model title shown above the controls.
        cover: Cover image path or a falsy value.
        tts_fn: Callback taking (text, language index, noise_scale,
            noise_scale_w, length_scale) and returning (message, audio).
        labels: One entry of ``TAB_LABELS``.
    """
    with gr.Row():
        gr.Markdown(build_cover_html(title, cover))
    with gr.Row():
        with gr.Column():
            # NOTE(review): every panel reuses the ids "input-text" and
            # "tts-audio", and download_audio_js finds them with
            # querySelector, so the download button appears to act on the
            # first panel in the page -- confirm and make the ids unique.
            input_text = gr.Textbox(label=labels["input"], lines=5, value=DEFAULT_INPUT_TEXT, elem_id="input-text")
            lang = gr.Dropdown(label=labels["language"], choices=labels["choices"],
                               type="index", value=labels["default_language"])
            btn = gr.Button(value=labels["generate"])
            with gr.Row():
                ns = gr.Slider(label=labels["noise_scale"], minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True)
                nsw = gr.Slider(label=labels["noise_scale_w"], minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True)
                ls = gr.Slider(label=labels["length_scale"], minimum=0.1, maximum=2.0, step=0.1, value=1, interactive=True)
        with gr.Column():
            o1 = gr.Textbox(label=labels["message"])
            o2 = gr.Audio(label=labels["audio"], elem_id="tts-audio")
            download = gr.Button(labels["download"])
    btn.click(tts_fn, inputs=[input_text, lang, ns, nsw, ls], outputs=[o1, o2])
    # .format() collapses the doubled braces in the JS template.
    download.click(None, [], [], _js=download_audio_js.format())
    # Selecting a language resets the three sliders to that language's presets.
    lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])


if __name__ == '__main__':
    with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)

    # One synthesizer per entry in info.json; the key doubles as the model's
    # directory name and checkpoint file stem.
    models = []
    for model_dir, info in models_info.items():
        net_g_ms = SynthesizerTrn(
            len(hps_ms.symbols),
            hps_ms.data.filter_length // 2 + 1,
            hps_ms.train.segment_size // hps_ms.data.hop_length,
            n_speakers=hps_ms.data.n_speakers,
            **hps_ms.model)
        net_g_ms.eval()
        utils.load_checkpoint(f'pretrained_models/{model_dir}/{model_dir}.pth', net_g_ms, None)
        # Only build a cover path when a cover file is actually configured,
        # so the "no cover" branch of build_cover_html can be reached.
        cover_file = info.get('cover')
        models.append({
            "name_en": info['name_en'],
            "name_zh": info['name_zh'],
            "title": info['title'],
            "cover": f"pretrained_models/{model_dir}/{cover_file}" if cover_file else None,
            "tts_fn": create_tts_fn(net_g_ms, info['sid']),
        })

    with gr.Blocks() as app:
        gr.Markdown(
            "# <center> vits-models\n"
            "![visitor badge](https://visitor-badge.glitch.me/badge?page_id=sayashi.vits-models)\n\n"
        )

        with gr.Tabs():
            for tab_title, labels in TAB_LABELS.items():
                with gr.TabItem(tab_title):
                    for model in models:
                        with gr.TabItem(model[labels["name_field"]]):
                            build_model_panel(model["title"], model["cover"], model["tts_fn"], labels)
    app.queue(concurrency_count=1).launch()