SayaSS committed
Commit cf0491a
Parent: 51d1e40
.gitignore CHANGED
@@ -165,4 +165,6 @@ cython_debug/
 filelists/*
 !/filelists/esd.list
 data/*
-/infer_save
+/infer_save
+
+.idea
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Umamusume Bert Vits2
+title: Bert Vits2
 emoji: 📊
 colorFrom: red
 colorTo: green
app.py CHANGED
@@ -1,12 +1,17 @@
-# flake8: noqa: E402
-
-import sys, os
+import sys
 import logging
 import os
-import time
-import numpy as np  # assuming NumPy is used to process the audio data
-import shutil  # used to delete folders and files
-from scipy.io import wavfile
+import json
+import torch
+import argparse
+import commons
+import utils
+import gradio as gr
+
+from models import SynthesizerTrn
+from text.symbols import symbols
+from text import cleaned_text_to_sequence, get_bert
+from text.cleaner import clean_text
 
 logging.getLogger("numba").setLevel(logging.WARNING)
 logging.getLogger("markdown_it").setLevel(logging.WARNING)
@@ -18,29 +23,11 @@ logging.basicConfig(
 )
 
 logger = logging.getLogger(__name__)
+limitation = os.getenv("SYSTEM") == "spaces"  # limit text and audio length in huggingface spaces
 
-import torch
-import argparse
-import commons
-import utils
-from models import SynthesizerTrn
-from text.symbols import symbols
-from text import cleaned_text_to_sequence, get_bert
-from text.cleaner import clean_text
-import gradio as gr
-import webbrowser
-import numpy as np
-
-net_g = None
-
-if sys.platform == "darwin" and torch.backends.mps.is_available():
-    device = "mps"
-    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
-else:
-    device = "cuda"
 
-
-def get_text(text, language_str, hps):
+def get_text(text, hps):
+    language_str = "JP"
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
 
@@ -55,15 +42,8 @@ def get_text(text, language_str, hps):
     del word2ph
     assert bert.shape[-1] == len(phone), phone
 
-    if language_str == "ZH":
-        bert = bert
-        ja_bert = torch.zeros(768, len(phone))
-    elif language_str == "JP":
-        ja_bert = bert
-        bert = torch.zeros(1024, len(phone))
-    else:
-        bert = torch.zeros(1024, len(phone))
-        ja_bert = torch.zeros(768, len(phone))
+    ja_bert = bert
+    bert = torch.zeros(1024, len(phone))
 
     assert bert.shape[-1] == len(
         phone
@@ -75,9 +55,8 @@ def get_text(text, language_str, hps):
     return bert, ja_bert, phone, tone, language
 
 
-def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
-    global net_g
-    bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps)
+def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, net_g_ms, hps):
+    bert, ja_bert, phones, tones, lang_ids = get_text(text, hps)
     with torch.no_grad():
         x_tst = phones.to(device).unsqueeze(0)
         tones = tones.to(device).unsqueeze(0)
@@ -85,14 +64,13 @@ def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, langua
         bert = bert.to(device).unsqueeze(0)
         ja_bert = ja_bert.to(device).unsqueeze(0)
         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-        #print(x_tst.type(), tones.type(), lang_ids.type(), bert.type(), ja_bert.type(), x_tst_lengths.type())
        del phones
-        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
+        sid = torch.LongTensor([sid]).to(device)
         audio = (
-            net_g.infer(
+            net_g_ms.infer(
                 x_tst,
                 x_tst_lengths,
-                speakers,
+                sid,
                 tones,
                 lang_ids,
                 bert,
@@ -106,108 +84,25 @@ def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, langua
             .float()
             .numpy()
         )
-        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
-        torch.cuda.empty_cache()
-        return audio
-
-def infer_2(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
-    global net_g_2
-    bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps)
-    with torch.no_grad():
-        x_tst = phones.to(device).unsqueeze(0)
-        tones = tones.to(device).unsqueeze(0)
-        lang_ids = lang_ids.to(device).unsqueeze(0)
-        bert = bert.to(device).unsqueeze(0)
-        ja_bert = ja_bert.to(device).unsqueeze(0)
-        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-        #print(x_tst.type(), tones.type(), lang_ids.type(), bert.type(), ja_bert.type(), x_tst_lengths.type())
-        del phones
-        speakers = torch.LongTensor([hps_2.data.spk2id[sid]]).to(device)
-        audio = (
-            net_g_2.infer(
-                x_tst,
-                x_tst_lengths,
-                speakers,
-                tones,
-                lang_ids,
-                bert,
-                ja_bert,
-                sdp_ratio=sdp_ratio,
-                noise_scale=noise_scale,
-                noise_scale_w=noise_scale_w,
-                length_scale=length_scale,
-            )[0][0, 0]
-            .data.cpu()
-            .float()
-            .numpy()
-        )
-        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
+        del x_tst, tones, lang_ids, bert, x_tst_lengths, sid
         torch.cuda.empty_cache()
         return audio
 
-__LOG__ = "./generation_logs.txt"
-def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language, from_model=0):
-    # clear the ./infer_save folder
-    if os.path.exists('./infer_save'):
-        shutil.rmtree('./infer_save')
-    os.makedirs('./infer_save')
-
-    slices = text.split("\n")
-    slices = [slice for slice in slices if slice.strip() != ""]
-    audio_list = []
-    with torch.no_grad():
-        with open(__LOG__, "a", encoding="UTF-8") as f:
-            for slice in slices:
-                assert len(slice) < 150  # limit the input text length
-                if from_model == 0:
-                    audio = infer(slice, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker, language=language)
-                else:
-                    audio = infer_2(slice, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, sid=speaker, language=language)
-                audio_list.append(audio)
-
-                # create a unique file name
-                timestamp = str(int(time.time() * 1000))
-                audio_file_path = f'./infer_save/audio_{timestamp}.wav'
-
-                # save the audio data to a .wav file
-                wavfile.write(audio_file_path, hps.data.sampling_rate, audio)
-
-                silence = np.zeros(hps.data.sampling_rate, dtype=np.int16)  # generate 1 second of silence
-                audio_list.append(silence)  # append the silence to the list
-
-                f.write(f"{slice} | {speaker}\n")
-                print(f"{slice} | {speaker}")
-
-    audio_concat = np.concatenate(audio_list)
-    return "Success", (hps.data.sampling_rate, audio_concat)
-def tts_fn_2(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language, from_model=1):
-    return tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale, language, from_model)
+def create_tts_fn(net_g_ms, hps):
+    def tts_fn(text, speaker, sdp_ratio, noise_scale, noise_scale_w, length_scale):
+        print(f"{text} | {speaker}")
+        sid = hps.data.spk2id[speaker]
+        text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
+        if limitation:
+            max_len = 100
+            if len(text) > max_len:
+                return "Error: Text is too long", None
+        audio = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
+                      length_scale=length_scale, sid=sid, net_g_ms=net_g_ms, hps=hps)
+        return "Success", (hps.data.sampling_rate, audio)
+    return tts_fn
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-m", "--model", default="./logs/natuki/G_72000.pth", help="path of your model"
-    )
-    parser.add_argument(
-        "-c",
-        "--config",
-        default="./configs/config.json",
-        help="path of your config file",
-    )
-    parser.add_argument(
-        "--share", default=False, help="make link public", action="store_true"
-    )
-    parser.add_argument(
-        "-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log"
-    )
-
-    args = parser.parse_args()
-    if args.debug:
-        logger.info("Enable DEBUG-LEVEL log")
-        logging.basicConfig(level=logging.DEBUG)
-    hps = utils.get_hparams_from_file("./logs/umamusume/config.json")
-    hps_2 = utils.get_hparams_from_file("./logs/natuki/config.json")
-
     device = (
         "cuda:0"
         if torch.cuda.is_available()
@@ -217,128 +112,57 @@ if __name__ == "__main__":
         else "cpu"
         )
     )
-    net_g = SynthesizerTrn(
-        len(symbols),
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=hps.data.n_speakers,
-        **hps.model,
-    ).to(device)
-    _ = net_g.eval()
-
-    net_g_2 = SynthesizerTrn(
-        len(symbols),
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=hps.data.n_speakers,
-        **hps.model,
-    ).to(device)
 
-    _ = utils.load_checkpoint("./logs/clara/G_4400.pth", net_g, None, skip_optimizer=True)
-    _ = utils.load_checkpoint("./logs/kafka/G_4000.pth", net_g_2, None, skip_optimizer=True)
-
-    speaker_ids = hps.data.spk2id
-    speakers = list(speaker_ids.keys())
-    speaker_ids_2 = hps_2.data.spk2id
-    speakers_2 = list(speaker_ids_2.keys())
-
-
-    languages = ["ZH", "JP"]
-    with gr.Blocks() as app:
-        with gr.Tab(label="umamusume"):
-            with gr.Row():
-                with gr.Column():
-                    text = gr.TextArea(
-                        label="Text",
-                        placeholder="Input Text Here",
-                        value="はりきっていこう!",
-                    )
-                    speaker = gr.Dropdown(
-                        choices=speakers, value=speakers[0], label="Speaker"
-                    )
-                    sdp_ratio = gr.Slider(
-                        minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
-                    )
-                    noise_scale = gr.Slider(
-                        minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise Scale"
-                    )
-                    noise_scale_w = gr.Slider(
-                        minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise Scale W"
-                    )
-                    length_scale = gr.Slider(
-                        minimum=0.1, maximum=2, value=1, step=0.1, label="Length Scale"
-                    )
-                    language = gr.Dropdown(
-                        choices=languages, value=languages[1], label="Language"
-                    )
-                    btn = gr.Button("Generate!", variant="primary")
-                with gr.Column():
-                    text_output = gr.Textbox(label="Message")
-                    audio_output = gr.Audio(label="Output Audio")
-                    gr.Markdown("# 赛马娘 Bert-VITS2 语音合成\n"
-                                "Project page:[GitHub](https://github.com/fishaudio/Bert-VITS2)\n"
-                                "- 本项目在日语方面有所欠缺,特别是音调的设计上,需要帮助。\n"
-                                "- このプロジェクトは、日本語の方面で不足しています。特に、音調の設計に関して助けが欲しいです。")
-
-            btn.click(
-                tts_fn,
-                inputs=[
-                    text,
-                    speaker,
-                    sdp_ratio,
-                    noise_scale,
-                    noise_scale_w,
-                    length_scale,
-                    language,
-                ],
-                outputs=[text_output, audio_output],
-            )
-        with gr.Tab(label="natuki"):
-            with gr.Row():
-                with gr.Column():
-                    text2 = gr.TextArea(
-                        label="Text",
-                        placeholder="Input Text Here",
-                        value="はりきっていこう!",
-                    )
-                    speaker2 = gr.Dropdown(
-                        choices=speakers_2, value=speakers_2[0], label="Speaker"
-                    )
-                    sdp_ratio2 = gr.Slider(
-                        minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
-                    )
-                    noise_scale2 = gr.Slider(
-                        minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise Scale"
-                    )
-                    noise_scale_w2 = gr.Slider(
-                        minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise Scale W"
-                    )
-                    length_scale2 = gr.Slider(
-                        minimum=0.1, maximum=2, value=1, step=0.1, label="Length Scale"
-                    )
-                    language2 = gr.Dropdown(
-                        choices=languages, value=languages[1], label="Language"
-                    )
-                    btn2 = gr.Button("Generate!", variant="primary")
-                with gr.Column():
-                    text_output2 = gr.Textbox(label="Message")
-                    audio_output2 = gr.Audio(label="Output Audio")
-                    gr.Markdown("# 赛马娘 Bert-VITS2 语音合成\n"
-                                "Project page:[GitHub](https://github.com/fishaudio/Bert-VITS2)\n"
-                                "- 本项目在日语方面有所欠缺,特别是音调的设计上,需要帮助。\n"
-                                "- このプロジェクトは、日本語の方面で不足しています。特に、音調の設計に関して助けが欲しいです。")
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--share", default=False, help="make link public", action="store_true")
+    parser.add_argument("-d", "--debug", action="store_true", help="enable DEBUG-LEVEL log")
+    args = parser.parse_args()
+    if args.debug:
+        logger.info("Enable DEBUG-LEVEL log")
+        logging.basicConfig(level=logging.DEBUG)
 
-            btn2.click(
-                tts_fn_2,
-                inputs=[
-                    text2,
-                    speaker2,
-                    sdp_ratio2,
-                    noise_scale2,
-                    noise_scale_w2,
-                    length_scale2,
-                    language2,
-                ],
-                outputs=[text_output2, audio_output2],
-            )
-    app.launch()
+    models = []
+    with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
+        models_info = json.load(f)
+    for i, info in models_info.items():
+        if not info['enable']:
+            continue
+        name = info['name']
+        title = info['title']
+        example = info['example']
+        hps = utils.get_hparams_from_file(f"./pretrained_models/{name}/config.json")
+        net_g_ms = SynthesizerTrn(
+            len(symbols),
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **hps.model)
+        utils.load_checkpoint(f'pretrained_models/{i}/{i}.pth', net_g_ms, None, skip_optimizer=True)
+        _ = net_g_ms.eval().to(device)
+        models.append((name, title, example, list(hps.data.spk2id.keys()), net_g_ms, create_tts_fn(net_g_ms, hps)))
+    with gr.Blocks(theme='NoCrypt/miku') as app:
+        with gr.Tabs():
+            for (name, title, example, speakers, net_g_ms, tts_fn) in models:
+                with gr.TabItem(name):
+                    with gr.Row():
+                        gr.Markdown(
+                            '<div align="center">'
+                            f'<a><strong>{title}</strong></a>'
+                            f'</div>'
+                        )
+                    with gr.Row():
+                        with gr.Column():
+                            input_text = gr.Textbox(label="Text (100 words limitation)" if limitation else "Text", lines=5, value=example)
+                            btn = gr.Button(value="Generate", variant="primary")
+                            with gr.Row():
+                                sp = gr.Dropdown(choices=speakers, value=speakers[0], label="Speaker")
+                            with gr.Row():
+                                sdpr = gr.Slider(label="SDP Ratio", minimum=0, maximum=1, step=0.1, value=0.2)
+                                ns = gr.Slider(label="noise_scale", minimum=0.1, maximum=1.0, step=0.1, value=0.6)
+                                nsw = gr.Slider(label="noise_scale_w", minimum=0.1, maximum=1.0, step=0.1, value=0.8)
+                                ls = gr.Slider(label="length_scale", minimum=0.1, maximum=2.0, step=0.1, value=1)
+                        with gr.Column():
+                            o1 = gr.Textbox(label="Output Message")
+                            o2 = gr.Audio(label="Output Audio")
+                    btn.click(tts_fn, inputs=[input_text, sp, sdpr, ns, nsw, ls], outputs=[o1, o2])
+    app.queue(concurrency_count=1).launch(share=args.share)
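
The core of this commit is the replacement of the duplicated `infer`/`infer_2` and `tts_fn`/`tts_fn_2` pairs with a single `create_tts_fn` factory that closes over one loaded model and its hyperparameters, so any number of models listed in `info.json` each get their own Gradio tab. A minimal sketch of that closure pattern, with the synthesizer stubbed out (`fake_infer` and `SAMPLING_RATE` are illustrative stand-ins, not repo code):

```python
import numpy as np

SAMPLING_RATE = 44100  # stand-in; the real value is hps.data.sampling_rate, per model


def fake_infer(text, sid):
    # Stand-in for net_g_ms.infer(): returns half a second of silence.
    return np.zeros(SAMPLING_RATE // 2, dtype=np.float32)


def create_tts_fn(model_infer, spk2id, sampling_rate):
    # Bind one model and its speaker map into a per-tab callback, so tabs
    # never share mutable globals the way net_g / net_g_2 did before.
    def tts_fn(text, speaker):
        sid = spk2id[speaker]
        audio = model_infer(text, sid)
        return "Success", (sampling_rate, audio)
    return tts_fn


kafka_fn = create_tts_fn(fake_infer, {"kafka": 0}, SAMPLING_RATE)
clara_fn = create_tts_fn(fake_infer, {"clara": 0}, SAMPLING_RATE)
print(kafka_fn("こんにちは", "kafka")[0])  # -> Success
```

Each callback keeps its own model reference, which is what lets the tab-building loop in `app.py` stay generic.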
logs/clara/G_4400.pth → pretrained_models/clara/clara.pth RENAMED
File without changes
{logs → pretrained_models}/clara/config.json RENAMED
File without changes
pretrained_models/info.json ADDED
@@ -0,0 +1,14 @@
+{
+  "kafka": {
+    "enable": true,
+    "name": "kafka",
+    "title": "Honkai: Star Rail-カフカ",
+    "example": "嗅ぎます?この子は、特に香りもいいんです。艶があるっていうのかなぁ。とにかく、絶対に嗅いだ方がいい。ほら、どうです?"
+  },
+  "clara": {
+    "enable": true,
+    "name": "clara",
+    "title": "Honkai: Star Rail-クラーラ",
+    "example": "ーーーチャンスって何の?誰?どこから話してる?"
+  }
+}
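
`info.json` acts as a model registry: each top-level key names a model directory, and the loader in `app.py` skips entries whose `enable` flag is false. Note that the checkpoint path is built from the dict key while the config path uses the `name` field, so the two must agree for an entry to load. A small sketch of that lookup logic (model loading stubbed out; `print` stands in for the real `SynthesizerTrn`/`load_checkpoint` calls):

```python
import json

with open("pretrained_models/info.json", "r", encoding="utf-8") as f:
    models_info = json.load(f)

for key, info in models_info.items():
    if not info["enable"]:  # disabled entries are skipped entirely
        continue
    config_path = f"./pretrained_models/{info['name']}/config.json"
    ckpt_path = f"pretrained_models/{key}/{key}.pth"  # built from the dict key, not info['name']
    print(f"would load {ckpt_path} with {config_path} (tab title: {info['title']})")
```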
{logs → pretrained_models}/kafka/config.json RENAMED
File without changes
logs/kafka/G_4000.pth → pretrained_models/kafka/kafka.pth RENAMED
File without changes
server.py DELETED
@@ -1,170 +0,0 @@
-from flask import Flask, request, Response
-from io import BytesIO
-import torch
-from av import open as avopen
-
-import commons
-import utils
-from models import SynthesizerTrn
-from text.symbols import symbols
-from text import cleaned_text_to_sequence, get_bert
-from text.cleaner import clean_text
-from scipy.io import wavfile
-
-# Flask Init
-app = Flask(__name__)
-app.config["JSON_AS_ASCII"] = False
-
-
-def get_text(text, language_str, hps):
-    norm_text, phone, tone, word2ph = clean_text(text, language_str)
-    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
-
-    if hps.data.add_blank:
-        phone = commons.intersperse(phone, 0)
-        tone = commons.intersperse(tone, 0)
-        language = commons.intersperse(language, 0)
-        for i in range(len(word2ph)):
-            word2ph[i] = word2ph[i] * 2
-        word2ph[0] += 1
-    bert = get_bert(norm_text, word2ph, language_str)
-    del word2ph
-    assert bert.shape[-1] == len(phone), phone
-
-    if language_str == "ZH":
-        bert = bert
-        ja_bert = torch.zeros(768, len(phone))
-    elif language_str == "JA":
-        ja_bert = bert
-        bert = torch.zeros(1024, len(phone))
-    else:
-        bert = torch.zeros(1024, len(phone))
-        ja_bert = torch.zeros(768, len(phone))
-    assert bert.shape[-1] == len(
-        phone
-    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
-    phone = torch.LongTensor(phone)
-    tone = torch.LongTensor(tone)
-    language = torch.LongTensor(language)
-    return bert, ja_bert, phone, tone, language
-
-
-def infer(text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid, language):
-    bert, ja_bert, phones, tones, lang_ids = get_text(text, language, hps)
-    with torch.no_grad():
-        x_tst = phones.to(dev).unsqueeze(0)
-        tones = tones.to(dev).unsqueeze(0)
-        lang_ids = lang_ids.to(dev).unsqueeze(0)
-        bert = bert.to(dev).unsqueeze(0)
-        ja_bert = ja_bert.to(device).unsqueeze(0)
-        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(dev)
-        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(dev)
-        audio = (
-            net_g.infer(
-                x_tst,
-                x_tst_lengths,
-                speakers,
-                tones,
-                lang_ids,
-                bert,
-                ja_bert,
-                sdp_ratio=sdp_ratio,
-                noise_scale=noise_scale,
-                noise_scale_w=noise_scale_w,
-                length_scale=length_scale,
-            )[0][0, 0]
-            .data.cpu()
-            .float()
-            .numpy()
-        )
-        return audio
-
-
-def replace_punctuation(text, i=2):
-    punctuation = ",。?!"
-    for char in punctuation:
-        text = text.replace(char, char * i)
-    return text
-
-
-def wav2(i, o, format):
-    inp = avopen(i, "rb")
-    out = avopen(o, "wb", format=format)
-    if format == "ogg":
-        format = "libvorbis"
-
-    ostream = out.add_stream(format)
-
-    for frame in inp.decode(audio=0):
-        for p in ostream.encode(frame):
-            out.mux(p)
-
-    for p in ostream.encode(None):
-        out.mux(p)
-
-    out.close()
-    inp.close()
-
-
-# Load Generator
-hps = utils.get_hparams_from_file("./configs/config.json")
-
-dev = "cuda"
-net_g = SynthesizerTrn(
-    len(symbols),
-    hps.data.filter_length // 2 + 1,
-    hps.train.segment_size // hps.data.hop_length,
-    n_speakers=hps.data.n_speakers,
-    **hps.model,
-).to(dev)
-_ = net_g.eval()
-
-_ = utils.load_checkpoint("logs/G_649000.pth", net_g, None, skip_optimizer=True)
-
-
-@app.route("/")
-def main():
-    try:
-        speaker = request.args.get("speaker")
-        text = request.args.get("text").replace("/n", "")
-        sdp_ratio = float(request.args.get("sdp_ratio", 0.2))
-        noise = float(request.args.get("noise", 0.5))
-        noisew = float(request.args.get("noisew", 0.6))
-        length = float(request.args.get("length", 1.2))
-        language = request.args.get("language")
-        if length >= 2:
-            return "Too big length"
-        if len(text) >= 250:
-            return "Too long text"
-        fmt = request.args.get("format", "wav")
-        if None in (speaker, text):
-            return "Missing Parameter"
-        if fmt not in ("mp3", "wav", "ogg"):
-            return "Invalid Format"
-        if language not in ("JA", "ZH"):
-            return "Invalid language"
-    except:
-        return "Invalid Parameter"
-
-    with torch.no_grad():
-        audio = infer(
-            text,
-            sdp_ratio=sdp_ratio,
-            noise_scale=noise,
-            noise_scale_w=noisew,
-            length_scale=length,
-            sid=speaker,
-            language=language,
-        )
-
-    with BytesIO() as wav:
-        wavfile.write(wav, hps.data.sampling_rate, audio)
-        torch.cuda.empty_cache()
-        if fmt == "wav":
-            return Response(wav.getvalue(), mimetype="audio/wav")
-        wav.seek(0, 0)
-        with BytesIO() as ofp:
-            wav2(wav, ofp, fmt)
-            return Response(
-                ofp.getvalue(), mimetype="audio/mpeg" if fmt == "mp3" else "audio/ogg"
-            )
text/__init__.py CHANGED
@@ -19,10 +19,8 @@ def cleaned_text_to_sequence(cleaned_text, tones, language):
 
 
 def get_bert(norm_text, word2ph, language, device):
-    from .chinese_bert import get_bert_feature as zh_bert
-    from .english_bert_mock import get_bert_feature as en_bert
     from .japanese_bert import get_bert_feature as jp_bert
 
-    lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
+    lang_bert_func_map = {"JP": jp_bert}
     bert = lang_bert_func_map[language](norm_text, word2ph, device)
    return bert
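
With the ZH/EN branches gone, `get_bert` dispatches Japanese only, which matches `app.py` above, where `get_text` now hard-codes `language_str = "JP"`; any other language code falls through to a `KeyError`. An illustrative stub (not repo code) showing the new behavior and failure mode:

```python
def jp_bert_stub(norm_text, word2ph, device):
    # Stands in for japanese_bert.get_bert_feature, which returns 768-dim features.
    return f"jp-bert({norm_text})"


lang_bert_func_map = {"JP": jp_bert_stub}


def get_bert(norm_text, word2ph, language, device):
    return lang_bert_func_map[language](norm_text, word2ph, device)


print(get_bert("こんにちは", [2, 2, 2], "JP", "cpu"))  # works
# get_bert("hello", [1], "EN", "cpu")  # raises KeyError: 'EN'; EN/ZH support was removed
```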