ZenXir kevinwang676 commited on
Commit
8ed8dd2
0 Parent(s):

Duplicate from kevinwang676/FreeVC

Browse files

Co-authored-by: Kevin Wang <kevinwang676@users.noreply.huggingface.co>

.flake8 ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [flake8]
2
+ ignore =
3
+ # E203 whitespace before ':'
4
+ E203
5
+ D203,
6
+ # line too long
7
+ E501
8
+ per-file-ignores =
9
+ # imported but unused
10
+ # __init__.py: F401
11
+ test_*.py: F401
12
+ exclude =
13
+ .git,
14
+ __pycache__,
15
+ docs/source/conf.py,
16
+ old,
17
+ build,
18
+ dist,
19
+ .venv
20
+ pad*.py
21
+ max-complexity = 25
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__
2
+ flagged
3
+ call-activate.bat
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: FreeVC
3
+ emoji: 🚀
4
+ colorFrom: gray
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.36.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: kevinwang676/FreeVC
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import librosa
4
+ import gradio as gr
5
+ from scipy.io.wavfile import write
6
+ from transformers import WavLMModel
7
+
8
+ import utils
9
+ from models import SynthesizerTrn
10
+ from mel_processing import mel_spectrogram_torch
11
+ from speaker_encoder.voice_encoder import SpeakerEncoder
12
+
13
+ import time
14
+ from textwrap import dedent
15
+
16
+ import mdtex2html
17
+ from loguru import logger
18
+ from transformers import AutoModel, AutoTokenizer
19
+
20
+ from tts_voice import tts_order_voice
21
+ import edge_tts
22
+ import tempfile
23
+ import anyio
24
+
25
+ '''
26
+ def get_wavlm():
27
+ os.system('gdown https://drive.google.com/uc?id=12-cB34qCTvByWT-QtOcZaqwwO21FLSqU')
28
+ shutil.move('WavLM-Large.pt', 'wavlm')
29
+ '''
30
+
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+
33
+ smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
34
+
35
+ print("Loading FreeVC(24k)...")
36
+ hps = utils.get_hparams_from_file("configs/freevc-24.json")
37
+ freevc_24 = SynthesizerTrn(
38
+ hps.data.filter_length // 2 + 1,
39
+ hps.train.segment_size // hps.data.hop_length,
40
+ **hps.model).to(device)
41
+ _ = freevc_24.eval()
42
+ _ = utils.load_checkpoint("checkpoints/freevc-24.pth", freevc_24, None)
43
+
44
+ print("Loading WavLM for content...")
45
+ cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
46
+
47
+ def convert(model, src, tgt):
48
+ with torch.no_grad():
49
+ # tgt
50
+ wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
51
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
52
+ if model == "FreeVC" or model == "FreeVC (24kHz)":
53
+ g_tgt = smodel.embed_utterance(wav_tgt)
54
+ g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
55
+ else:
56
+ wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(device)
57
+ mel_tgt = mel_spectrogram_torch(
58
+ wav_tgt,
59
+ hps.data.filter_length,
60
+ hps.data.n_mel_channels,
61
+ hps.data.sampling_rate,
62
+ hps.data.hop_length,
63
+ hps.data.win_length,
64
+ hps.data.mel_fmin,
65
+ hps.data.mel_fmax
66
+ )
67
+ # src
68
+ wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
69
+ wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
70
+ c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device)
71
+ # infer
72
+ if model == "FreeVC":
73
+ audio = freevc.infer(c, g=g_tgt)
74
+ elif model == "FreeVC-s":
75
+ audio = freevc_s.infer(c, mel=mel_tgt)
76
+ else:
77
+ audio = freevc_24.infer(c, g=g_tgt)
78
+ audio = audio[0][0].data.cpu().float().numpy()
79
+ if model == "FreeVC" or model == "FreeVC-s":
80
+ write("out.wav", hps.data.sampling_rate, audio)
81
+ else:
82
+ write("out.wav", 24000, audio)
83
+ out = "out.wav"
84
+ return out
85
+
86
+ # GLM2
87
+
88
+ language_dict = tts_order_voice
89
+
90
+ # fix timezone in Linux
91
+ os.environ["TZ"] = "Asia/Shanghai"
92
+ try:
93
+ time.tzset() # type: ignore # pylint: disable=no-member
94
+ except Exception:
95
+ # Windows
96
+ logger.warning("Windows, cant run time.tzset()")
97
+
98
+ # model_name = "THUDM/chatglm2-6b"
99
+ model_name = "THUDM/chatglm2-6b-int4"
100
+
101
+ RETRY_FLAG = False
102
+
103
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
104
+
105
+ # model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
106
+
107
+ # 4/8 bit
108
+ # model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()
109
+
110
+ has_cuda = torch.cuda.is_available()
111
+
112
+ # has_cuda = False # force cpu
113
+
114
+ if has_cuda:
115
+ model_glm = (
116
+ AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
117
+ ) # 3.92G
118
+ else:
119
+ model_glm = AutoModel.from_pretrained(
120
+ model_name, trust_remote_code=True
121
+ ).float() # .float() .half().float()
122
+
123
+ model_glm = model_glm.eval()
124
+
125
+ _ = """Override Chatbot.postprocess"""
126
+
127
+
128
+ def postprocess(self, y):
129
+ if y is None:
130
+ return []
131
+ for i, (message, response) in enumerate(y):
132
+ y[i] = (
133
+ None if message is None else mdtex2html.convert((message)),
134
+ None if response is None else mdtex2html.convert(response),
135
+ )
136
+ return y
137
+
138
+
139
+ gr.Chatbot.postprocess = postprocess
140
+
141
+
142
+ def parse_text(text):
143
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
144
+ lines = text.split("\n")
145
+ lines = [line for line in lines if line != ""]
146
+ count = 0
147
+ for i, line in enumerate(lines):
148
+ if "```" in line:
149
+ count += 1
150
+ items = line.split("`")
151
+ if count % 2 == 1:
152
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
153
+ else:
154
+ lines[i] = "<br></code></pre>"
155
+ else:
156
+ if i > 0:
157
+ if count % 2 == 1:
158
+ line = line.replace("`", r"\`")
159
+ line = line.replace("<", "&lt;")
160
+ line = line.replace(">", "&gt;")
161
+ line = line.replace(" ", "&nbsp;")
162
+ line = line.replace("*", "&ast;")
163
+ line = line.replace("_", "&lowbar;")
164
+ line = line.replace("-", "&#45;")
165
+ line = line.replace(".", "&#46;")
166
+ line = line.replace("!", "&#33;")
167
+ line = line.replace("(", "&#40;")
168
+ line = line.replace(")", "&#41;")
169
+ line = line.replace("$", "&#36;")
170
+ lines[i] = "<br>" + line
171
+ text = "".join(lines)
172
+ return text
173
+
174
+
175
+ def predict(
176
+ RETRY_FLAG, input, chatbot, max_length, top_p, temperature, history, past_key_values
177
+ ):
178
+ try:
179
+ chatbot.append((parse_text(input), ""))
180
+ except Exception as exc:
181
+ logger.error(exc)
182
+ logger.debug(f"{chatbot=}")
183
+ _ = """
184
+ if chatbot:
185
+ chatbot[-1] = (parse_text(input), str(exc))
186
+ yield chatbot, history, past_key_values
187
+ # """
188
+ yield chatbot, history, past_key_values
189
+
190
+ for response, history, past_key_values in model_glm.stream_chat(
191
+ tokenizer,
192
+ input,
193
+ history,
194
+ past_key_values=past_key_values,
195
+ return_past_key_values=True,
196
+ max_length=max_length,
197
+ top_p=top_p,
198
+ temperature=temperature,
199
+ ):
200
+ chatbot[-1] = (parse_text(input), parse_text(response))
201
+ # chatbot[-1][-1] = parse_text(response)
202
+
203
+ yield chatbot, history, past_key_values, parse_text(response)
204
+
205
+
206
+ def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
207
+ if max_length < 10:
208
+ max_length = 4096
209
+ if top_p < 0.1 or top_p > 1:
210
+ top_p = 0.85
211
+ if temperature <= 0 or temperature > 1:
212
+ temperature = 0.01
213
+ try:
214
+ res, _ = model_glm.chat(
215
+ tokenizer,
216
+ input,
217
+ history=[],
218
+ past_key_values=None,
219
+ max_length=max_length,
220
+ top_p=top_p,
221
+ temperature=temperature,
222
+ )
223
+ # logger.debug(f"{res=} \n{_=}")
224
+ except Exception as exc:
225
+ logger.error(f"{exc=}")
226
+ res = str(exc)
227
+
228
+ return res
229
+
230
+
231
+ def reset_user_input():
232
+ return gr.update(value="")
233
+
234
+
235
+ def reset_state():
236
+ return [], [], None, ""
237
+
238
+
239
+ # Delete last turn
240
+ def delete_last_turn(chat, history):
241
+ if chat and history:
242
+ chat.pop(-1)
243
+ history.pop(-1)
244
+ return chat, history
245
+
246
+
247
+ # Regenerate response
248
+ def retry_last_answer(
249
+ user_input, chatbot, max_length, top_p, temperature, history, past_key_values
250
+ ):
251
+ if chatbot and history:
252
+ # Removing the previous conversation from chat
253
+ chatbot.pop(-1)
254
+ # Setting up a flag to capture a retry
255
+ RETRY_FLAG = True
256
+ # Getting last message from user
257
+ user_input = history[-1][0]
258
+ # Removing bot response from the history
259
+ history.pop(-1)
260
+
261
+ yield from predict(
262
+ RETRY_FLAG, # type: ignore
263
+ user_input,
264
+ chatbot,
265
+ max_length,
266
+ top_p,
267
+ temperature,
268
+ history,
269
+ past_key_values,
270
+ )
271
+
272
+ # print
273
+
274
+ def print(text):
275
+ return text
276
+
277
+ # TTS
278
+
279
+ async def text_to_speech_edge(text, language_code):
280
+ voice = language_dict[language_code]
281
+ communicate = edge_tts.Communicate(text, voice)
282
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
283
+ tmp_path = tmp_file.name
284
+
285
+ await communicate.save(tmp_path)
286
+
287
+ return tmp_path
288
+
289
+
290
+ with gr.Blocks(title="ChatGLM2-6B-int4", theme=gr.themes.Soft(text_size="sm")) as demo:
291
+ gr.HTML("<center>"
292
+ "<h1>🥳💕🎶 - ChatGLM2 + 声音克隆:和你喜欢的角色畅所欲言吧!</h1>"
293
+ "</center>")
294
+ gr.Markdown("## <center>💡 - 第二代ChatGLM大语言模型 + FreeVC变声,为您打造独一无二的沉浸式对话体验,支持中英双语</center>")
295
+ gr.Markdown("## <center>🌊 - 更多精彩应用,尽在[滔滔AI](http://www.talktalkai.com);滔滔AI,为爱滔滔!💕</center>")
296
+ gr.Markdown("### <center>⭐ - 如果您喜欢这个程序,欢迎给我的[Github项目](https://github.com/KevinWang676/ChatGLM2-Voice-Cloning)点赞支持!</center>")
297
+
298
+ with gr.Accordion("📒 相关信息", open=False):
299
+ _ = f""" ChatGLM2的可选参数信息:
300
+ * Low temperature: responses will be more deterministic and focused; High temperature: responses more creative.
301
+ * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4
302
+ * Top P controls dynamic vocabulary selection based on context.\n
303
+ 如果您想让ChatGLM2进行角色扮演并与之对话,请先输入恰当的提示词,如“请你扮演成动漫角色蜡笔小新并和我进行对话”;您也可以为ChatGLM2提供自定义的角色设定\n
304
+ 当您使用声音克隆功能时,请先在此程序的对应位置上传一段您喜欢的音频
305
+ """
306
+ gr.Markdown(dedent(_))
307
+ chatbot = gr.Chatbot(height=300)
308
+ with gr.Row():
309
+ with gr.Column(scale=4):
310
+ with gr.Column(scale=12):
311
+ user_input = gr.Textbox(
312
+ label="请在此处和GLM2聊天 (按回车键即可发送)",
313
+ placeholder="聊点什么吧",
314
+ )
315
+ RETRY_FLAG = gr.Checkbox(value=False, visible=False)
316
+ with gr.Column(min_width=32, scale=1):
317
+ with gr.Row():
318
+ submitBtn = gr.Button("开始和GLM2交流吧", variant="primary")
319
+ deleteBtn = gr.Button("删除最新一轮对话", variant="secondary")
320
+ retryBtn = gr.Button("重新生成最新一轮对话", variant="secondary")
321
+
322
+ with gr.Accordion("🔧 更多设置", open=False):
323
+ with gr.Row():
324
+ emptyBtn = gr.Button("清空所有聊天记录")
325
+ max_length = gr.Slider(
326
+ 0,
327
+ 32768,
328
+ value=8192,
329
+ step=1.0,
330
+ label="Maximum length",
331
+ interactive=True,
332
+ )
333
+ top_p = gr.Slider(
334
+ 0, 1, value=0.85, step=0.01, label="Top P", interactive=True
335
+ )
336
+ temperature = gr.Slider(
337
+ 0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True
338
+ )
339
+
340
+
341
+ with gr.Row():
342
+ test1 = gr.Textbox(label="GLM2的最新回答 (可编辑)", lines = 3)
343
+ with gr.Column():
344
+ language = gr.Dropdown(choices=list(language_dict.keys()), value="普通话 (中国大陆)-Xiaoxiao-女", label="请选择文本对应的语言及您喜欢的说话人")
345
+ tts_btn = gr.Button("生成对应的音频吧", variant="primary")
346
+ output_audio = gr.Audio(type="filepath", label="为您生成的音频", interactive=False)
347
+
348
+ tts_btn.click(text_to_speech_edge, inputs=[test1, language], outputs=[output_audio])
349
+
350
+ with gr.Row():
351
+ model_choice = gr.Dropdown(choices=["FreeVC", "FreeVC-s", "FreeVC (24kHz)"], value="FreeVC (24kHz)", label="Model", visible=False)
352
+ audio1 = output_audio
353
+ audio2 = gr.Audio(label="请上传您喜欢的声音进行声音克隆", type='filepath')
354
+ clone_btn = gr.Button("开始AI声音克隆吧", variant="primary")
355
+ audio_cloned = gr.Audio(label="为您生成的专属声音克隆音频", type='filepath')
356
+
357
+ clone_btn.click(convert, inputs=[model_choice, audio1, audio2], outputs=[audio_cloned])
358
+
359
+ history = gr.State([])
360
+ past_key_values = gr.State(None)
361
+
362
+ user_input.submit(
363
+ predict,
364
+ [
365
+ RETRY_FLAG,
366
+ user_input,
367
+ chatbot,
368
+ max_length,
369
+ top_p,
370
+ temperature,
371
+ history,
372
+ past_key_values,
373
+ ],
374
+ [chatbot, history, past_key_values, test1],
375
+ show_progress="full",
376
+ )
377
+ submitBtn.click(
378
+ predict,
379
+ [
380
+ RETRY_FLAG,
381
+ user_input,
382
+ chatbot,
383
+ max_length,
384
+ top_p,
385
+ temperature,
386
+ history,
387
+ past_key_values,
388
+ ],
389
+ [chatbot, history, past_key_values, test1],
390
+ show_progress="full",
391
+ api_name="predict",
392
+ )
393
+ submitBtn.click(reset_user_input, [], [user_input])
394
+
395
+ emptyBtn.click(
396
+ reset_state, outputs=[chatbot, history, past_key_values, test1], show_progress="full"
397
+ )
398
+
399
+ retryBtn.click(
400
+ retry_last_answer,
401
+ inputs=[
402
+ user_input,
403
+ chatbot,
404
+ max_length,
405
+ top_p,
406
+ temperature,
407
+ history,
408
+ past_key_values,
409
+ ],
410
+ # outputs = [chatbot, history, last_user_message, user_message]
411
+ outputs=[chatbot, history, past_key_values, test1],
412
+ )
413
+ deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
414
+
415
+ with gr.Accordion("📔 提示词示例", open=False):
416
+ etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
417
+ examples = gr.Examples(
418
+ examples=[
419
+ ["Explain the plot of Cinderella in a sentence."],
420
+ [
421
+ "How long does it take to become proficient in French, and what are the best methods for retaining information?"
422
+ ],
423
+ ["What are some common mistakes to avoid when writing code?"],
424
+ ["Build a prompt to generate a beautiful portrait of a horse"],
425
+ ["Suggest four metaphors to describe the benefits of AI"],
426
+ ["Write a pop song about leaving home for the sandy beaches."],
427
+ ["Write a summary demonstrating my ability to tame lions"],
428
+ ["鲁迅和周树人什么关系"],
429
+ ["从前有一头牛,这头牛后面有什么?"],
430
+ ["正无穷大加一大于正无穷大吗?"],
431
+ ["正无穷大加正无穷大大于正无穷大吗?"],
432
+ ["-2的平方根等于什么"],
433
+ ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
434
+ ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
435
+ ["鲁迅和周树人什么关系 用英文回答"],
436
+ ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
437
+ [f"{etext} 翻成中文,列出3个版本"],
438
+ [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本"],
439
+ ["js 判断一个数是不是质数"],
440
+ ["js 实现python 的 range(10)"],
441
+ ["js 实现python 的 [*(range(10)]"],
442
+ ["假定 1 + 2 = 4, 试求 7 + 8"],
443
+ ["Erkläre die Handlung von Cinderella in einem Satz."],
444
+ ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
445
+ ],
446
+ inputs=[user_input],
447
+ examples_per_page=30,
448
+ )
449
+
450
+ with gr.Accordion("For Chat/Translation API", open=False, visible=False):
451
+ input_text = gr.Text()
452
+ tr_btn = gr.Button("Go", variant="primary")
453
+ out_text = gr.Text()
454
+ tr_btn.click(
455
+ trans_api,
456
+ [input_text, max_length, top_p, temperature],
457
+ out_text,
458
+ # show_progress="full",
459
+ api_name="tr",
460
+ )
461
+ _ = """
462
+ input_text.submit(
463
+ trans_api,
464
+ [input_text, max_length, top_p, temperature],
465
+ out_text,
466
+ show_progress="full",
467
+ api_name="tr1",
468
+ )
469
+ # """
470
+
471
+ gr.Markdown("### <center>注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。</center>")
472
+ gr.Markdown("<center>💡 - 如何使用此程序:输入您对ChatGLM的提问后,依次点击“开始和GLM2交流吧”、“生成对应的音频吧”、“开始AI声音克隆吧”三个按键即可;使用声音克隆功能时,请先上传一段您喜欢的音频</center>")
473
+ gr.HTML('''
474
+ <div class="footer">
475
+ <p>🌊🏞️🎶 - 江水东流急,滔滔无尽声。 明·顾璘
476
+ </p>
477
+ </div>
478
+ ''')
479
+
480
+
481
+ demo.queue().launch(show_error=True, debug=True)
checkpoints/freevc-24.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b39a86fefbc9ec6e30be8d26ee2a6aa5ffe6d235f6ab15773d01cdf348e5b20
3
+ size 472644351
commons.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size*dilation - dilation)/2)
16
+
17
+
18
+ def convert_pad_shape(pad_shape):
19
+ l = pad_shape[::-1]
20
+ pad_shape = [item for sublist in l for item in sublist]
21
+ return pad_shape
22
+
23
+
24
+ def intersperse(lst, item):
25
+ result = [item] * (len(lst) * 2 + 1)
26
+ result[1::2] = lst
27
+ return result
28
+
29
+
30
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
31
+ """KL(P||Q)"""
32
+ kl = (logs_q - logs_p) - 0.5
33
+ kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def rand_spec_segments(x, x_lengths=None, segment_size=4):
68
+ b, d, t = x.size()
69
+ if x_lengths is None:
70
+ x_lengths = t
71
+ ids_str_max = x_lengths - segment_size
72
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
73
+ ret = slice_segments(x, ids_str, segment_size)
74
+ return ret, ids_str
75
+
76
+
77
+ def get_timing_signal_1d(
78
+ length, channels, min_timescale=1.0, max_timescale=1.0e4):
79
+ position = torch.arange(length, dtype=torch.float)
80
+ num_timescales = channels // 2
81
+ log_timescale_increment = (
82
+ math.log(float(max_timescale) / float(min_timescale)) /
83
+ (num_timescales - 1))
84
+ inv_timescales = min_timescale * torch.exp(
85
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
86
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
87
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
88
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
89
+ signal = signal.view(1, channels, length)
90
+ return signal
91
+
92
+
93
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
94
+ b, channels, length = x.size()
95
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
96
+ return x + signal.to(dtype=x.dtype, device=x.device)
97
+
98
+
99
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
100
+ b, channels, length = x.size()
101
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
102
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
103
+
104
+
105
+ def subsequent_mask(length):
106
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
107
+ return mask
108
+
109
+
110
+ @torch.jit.script
111
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
112
+ n_channels_int = n_channels[0]
113
+ in_act = input_a + input_b
114
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
115
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
116
+ acts = t_act * s_act
117
+ return acts
118
+
119
+
120
+ def convert_pad_shape(pad_shape):
121
+ l = pad_shape[::-1]
122
+ pad_shape = [item for sublist in l for item in sublist]
123
+ return pad_shape
124
+
125
+
126
+ def shift_1d(x):
127
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
128
+ return x
129
+
130
+
131
+ def sequence_mask(length, max_length=None):
132
+ if max_length is None:
133
+ max_length = length.max()
134
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
135
+ return x.unsqueeze(0) < length.unsqueeze(1)
136
+
137
+
138
+ def generate_path(duration, mask):
139
+ """
140
+ duration: [b, 1, t_x]
141
+ mask: [b, 1, t_y, t_x]
142
+ """
143
+ device = duration.device
144
+
145
+ b, _, t_y, t_x = mask.shape
146
+ cum_duration = torch.cumsum(duration, -1)
147
+
148
+ cum_duration_flat = cum_duration.view(b * t_x)
149
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
150
+ path = path.view(b, t_x, t_y)
151
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
152
+ path = path.unsqueeze(1).transpose(2,3) * mask
153
+ return path
154
+
155
+
156
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
157
+ if isinstance(parameters, torch.Tensor):
158
+ parameters = [parameters]
159
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
160
+ norm_type = float(norm_type)
161
+ if clip_value is not None:
162
+ clip_value = float(clip_value)
163
+
164
+ total_norm = 0
165
+ for p in parameters:
166
+ param_norm = p.grad.data.norm(norm_type)
167
+ total_norm += param_norm.item() ** norm_type
168
+ if clip_value is not None:
169
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
170
+ total_norm = total_norm ** (1. / norm_type)
171
+ return total_norm
configs/freevc-24.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 10000,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 2e-4,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 64,
11
+ "fp16_run": false,
12
+ "lr_decay": 0.999875,
13
+ "segment_size": 8640,
14
+ "init_lr_ratio": 1,
15
+ "warmup_epochs": 0,
16
+ "c_mel": 45,
17
+ "c_kl": 1.0,
18
+ "use_sr": true,
19
+ "max_speclen": 128,
20
+ "port": "8008"
21
+ },
22
+ "data": {
23
+ "training_files":"filelists/train.txt",
24
+ "validation_files":"filelists/val.txt",
25
+ "max_wav_value": 32768.0,
26
+ "sampling_rate": 16000,
27
+ "filter_length": 1280,
28
+ "hop_length": 320,
29
+ "win_length": 1280,
30
+ "n_mel_channels": 80,
31
+ "mel_fmin": 0.0,
32
+ "mel_fmax": null
33
+ },
34
+ "model": {
35
+ "inter_channels": 192,
36
+ "hidden_channels": 192,
37
+ "filter_channels": 768,
38
+ "n_heads": 2,
39
+ "n_layers": 6,
40
+ "kernel_size": 3,
41
+ "p_dropout": 0.1,
42
+ "resblock": "1",
43
+ "resblock_kernel_sizes": [3,7,11],
44
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
45
+ "upsample_rates": [10,6,4,2],
46
+ "upsample_initial_channel": 512,
47
+ "upsample_kernel_sizes": [16,16,4,4],
48
+ "n_layers_q": 3,
49
+ "use_spectral_norm": false,
50
+ "gin_channels": 256,
51
+ "ssl_dim": 1024,
52
+ "use_spk": true
53
+ }
54
+ }
mel_processing.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.data
8
+ import numpy as np
9
+ import librosa
10
+ import librosa.util as librosa_util
11
+ from librosa.util import normalize, pad_center, tiny
12
+ from scipy.signal import get_window
13
+ from scipy.io.wavfile import read
14
+ from librosa.filters import mel as librosa_mel_fn
15
+
16
+ MAX_WAV_VALUE = 32768.0
17
+
18
+
19
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20
+ """
21
+ PARAMS
22
+ ------
23
+ C: compression factor
24
+ """
25
+ return torch.log(torch.clamp(x, min=clip_val) * C)
26
+
27
+
28
+ def dynamic_range_decompression_torch(x, C=1):
29
+ """
30
+ PARAMS
31
+ ------
32
+ C: compression factor used to compress
33
+ """
34
+ return torch.exp(x) / C
35
+
36
+
37
+ def spectral_normalize_torch(magnitudes):
38
+ output = dynamic_range_compression_torch(magnitudes)
39
+ return output
40
+
41
+
42
+ def spectral_de_normalize_torch(magnitudes):
43
+ output = dynamic_range_decompression_torch(magnitudes)
44
+ return output
45
+
46
+
47
+ mel_basis = {}
48
+ hann_window = {}
49
+
50
+
51
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52
+ if torch.min(y) < -1.:
53
+ print('min value is ', torch.min(y))
54
+ if torch.max(y) > 1.:
55
+ print('max value is ', torch.max(y))
56
+
57
+ global hann_window
58
+ dtype_device = str(y.dtype) + '_' + str(y.device)
59
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
60
+ if wnsize_dtype_device not in hann_window:
61
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
62
+
63
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
64
+ y = y.squeeze(1)
65
+
66
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
67
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
68
+
69
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
70
+ return spec
71
+
72
+
73
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
74
+ global mel_basis
75
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
76
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
77
+ if fmax_dtype_device not in mel_basis:
78
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
79
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
80
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
81
+ spec = spectral_normalize_torch(spec)
82
+ return spec
83
+
84
+
85
+ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
86
+ if torch.min(y) < -1.:
87
+ print('min value is ', torch.min(y))
88
+ if torch.max(y) > 1.:
89
+ print('max value is ', torch.max(y))
90
+
91
+ global mel_basis, hann_window
92
+ dtype_device = str(y.dtype) + '_' + str(y.device)
93
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
94
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
95
+ if fmax_dtype_device not in mel_basis:
96
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
97
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
98
+ if wnsize_dtype_device not in hann_window:
99
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
100
+
101
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
102
+ y = y.squeeze(1)
103
+
104
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
105
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
106
+
107
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
108
+
109
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
110
+ spec = spectral_normalize_torch(spec)
111
+
112
+ return spec
models.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import commons
8
+ import modules
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from commons import init_weights, get_padding
13
+
14
+
15
+ class ResidualCouplingBlock(nn.Module):
16
+ def __init__(self,
17
+ channels,
18
+ hidden_channels,
19
+ kernel_size,
20
+ dilation_rate,
21
+ n_layers,
22
+ n_flows=4,
23
+ gin_channels=0):
24
+ super().__init__()
25
+ self.channels = channels
26
+ self.hidden_channels = hidden_channels
27
+ self.kernel_size = kernel_size
28
+ self.dilation_rate = dilation_rate
29
+ self.n_layers = n_layers
30
+ self.n_flows = n_flows
31
+ self.gin_channels = gin_channels
32
+
33
+ self.flows = nn.ModuleList()
34
+ for i in range(n_flows):
35
+ self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
36
+ self.flows.append(modules.Flip())
37
+
38
+ def forward(self, x, x_mask, g=None, reverse=False):
39
+ if not reverse:
40
+ for flow in self.flows:
41
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
42
+ else:
43
+ for flow in reversed(self.flows):
44
+ x = flow(x, x_mask, g=g, reverse=reverse)
45
+ return x
46
+
47
+
48
+ class Encoder(nn.Module):
49
+ def __init__(self,
50
+ in_channels,
51
+ out_channels,
52
+ hidden_channels,
53
+ kernel_size,
54
+ dilation_rate,
55
+ n_layers,
56
+ gin_channels=0):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ self.out_channels = out_channels
60
+ self.hidden_channels = hidden_channels
61
+ self.kernel_size = kernel_size
62
+ self.dilation_rate = dilation_rate
63
+ self.n_layers = n_layers
64
+ self.gin_channels = gin_channels
65
+
66
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
67
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
68
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
69
+
70
+ def forward(self, x, x_lengths, g=None):
71
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
72
+ x = self.pre(x) * x_mask
73
+ x = self.enc(x, x_mask, g=g)
74
+ stats = self.proj(x) * x_mask
75
+ m, logs = torch.split(stats, self.out_channels, dim=1)
76
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
77
+ return z, m, logs, x_mask
78
+
79
+
80
+ class Generator(torch.nn.Module):
81
+ def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
82
+ super(Generator, self).__init__()
83
+ self.num_kernels = len(resblock_kernel_sizes)
84
+ self.num_upsamples = len(upsample_rates)
85
+ self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
86
+ resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
87
+
88
+ self.ups = nn.ModuleList()
89
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
90
+ self.ups.append(weight_norm(
91
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
92
+ k, u, padding=(k-u)//2)))
93
+
94
+ self.resblocks = nn.ModuleList()
95
+ for i in range(len(self.ups)):
96
+ ch = upsample_initial_channel//(2**(i+1))
97
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
98
+ self.resblocks.append(resblock(ch, k, d))
99
+
100
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
101
+ self.ups.apply(init_weights)
102
+
103
+ if gin_channels != 0:
104
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
105
+
106
+ def forward(self, x, g=None):
107
+ x = self.conv_pre(x)
108
+ if g is not None:
109
+ x = x + self.cond(g)
110
+
111
+ for i in range(self.num_upsamples):
112
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
113
+ x = self.ups[i](x)
114
+ xs = None
115
+ for j in range(self.num_kernels):
116
+ if xs is None:
117
+ xs = self.resblocks[i*self.num_kernels+j](x)
118
+ else:
119
+ xs += self.resblocks[i*self.num_kernels+j](x)
120
+ x = xs / self.num_kernels
121
+ x = F.leaky_relu(x)
122
+ x = self.conv_post(x)
123
+ x = torch.tanh(x)
124
+
125
+ return x
126
+
127
+ def remove_weight_norm(self):
128
+ print('Removing weight norm...')
129
+ for l in self.ups:
130
+ remove_weight_norm(l)
131
+ for l in self.resblocks:
132
+ l.remove_weight_norm()
133
+
134
+
135
+ class DiscriminatorP(torch.nn.Module):
136
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
137
+ super(DiscriminatorP, self).__init__()
138
+ self.period = period
139
+ self.use_spectral_norm = use_spectral_norm
140
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
141
+ self.convs = nn.ModuleList([
142
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
143
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
144
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
145
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
146
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
147
+ ])
148
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
149
+
150
+ def forward(self, x):
151
+ fmap = []
152
+
153
+ # 1d to 2d
154
+ b, c, t = x.shape
155
+ if t % self.period != 0: # pad first
156
+ n_pad = self.period - (t % self.period)
157
+ x = F.pad(x, (0, n_pad), "reflect")
158
+ t = t + n_pad
159
+ x = x.view(b, c, t // self.period, self.period)
160
+
161
+ for l in self.convs:
162
+ x = l(x)
163
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
164
+ fmap.append(x)
165
+ x = self.conv_post(x)
166
+ fmap.append(x)
167
+ x = torch.flatten(x, 1, -1)
168
+
169
+ return x, fmap
170
+
171
+
172
+ class DiscriminatorS(torch.nn.Module):
173
+ def __init__(self, use_spectral_norm=False):
174
+ super(DiscriminatorS, self).__init__()
175
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
176
+ self.convs = nn.ModuleList([
177
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
178
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
179
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
180
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
181
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
182
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
183
+ ])
184
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
185
+
186
+ def forward(self, x):
187
+ fmap = []
188
+
189
+ for l in self.convs:
190
+ x = l(x)
191
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
192
+ fmap.append(x)
193
+ x = self.conv_post(x)
194
+ fmap.append(x)
195
+ x = torch.flatten(x, 1, -1)
196
+
197
+ return x, fmap
198
+
199
+
200
+ class MultiPeriodDiscriminator(torch.nn.Module):
201
+ def __init__(self, use_spectral_norm=False):
202
+ super(MultiPeriodDiscriminator, self).__init__()
203
+ periods = [2,3,5,7,11]
204
+
205
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
206
+ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
207
+ self.discriminators = nn.ModuleList(discs)
208
+
209
+ def forward(self, y, y_hat):
210
+ y_d_rs = []
211
+ y_d_gs = []
212
+ fmap_rs = []
213
+ fmap_gs = []
214
+ for i, d in enumerate(self.discriminators):
215
+ y_d_r, fmap_r = d(y)
216
+ y_d_g, fmap_g = d(y_hat)
217
+ y_d_rs.append(y_d_r)
218
+ y_d_gs.append(y_d_g)
219
+ fmap_rs.append(fmap_r)
220
+ fmap_gs.append(fmap_g)
221
+
222
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
223
+
224
+
225
+ class SpeakerEncoder(torch.nn.Module):
226
+ def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
227
+ super(SpeakerEncoder, self).__init__()
228
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
229
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
230
+ self.relu = nn.ReLU()
231
+
232
+ def forward(self, mels):
233
+ self.lstm.flatten_parameters()
234
+ _, (hidden, _) = self.lstm(mels)
235
+ embeds_raw = self.relu(self.linear(hidden[-1]))
236
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
237
+
238
+ def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
239
+ mel_slices = []
240
+ for i in range(0, total_frames-partial_frames, partial_hop):
241
+ mel_range = torch.arange(i, i+partial_frames)
242
+ mel_slices.append(mel_range)
243
+
244
+ return mel_slices
245
+
246
+ def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
247
+ mel_len = mel.size(1)
248
+ last_mel = mel[:,-partial_frames:]
249
+
250
+ if mel_len > partial_frames:
251
+ mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
252
+ mels = list(mel[:,s] for s in mel_slices)
253
+ mels.append(last_mel)
254
+ mels = torch.stack(tuple(mels), 0).squeeze(1)
255
+
256
+ with torch.no_grad():
257
+ partial_embeds = self(mels)
258
+ embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
259
+ #embed = embed / torch.linalg.norm(embed, 2)
260
+ else:
261
+ with torch.no_grad():
262
+ embed = self(last_mel)
263
+
264
+ return embed
265
+
266
+
267
+ class SynthesizerTrn(nn.Module):
268
+ """
269
+ Synthesizer for Training
270
+ """
271
+
272
+ def __init__(self,
273
+ spec_channels,
274
+ segment_size,
275
+ inter_channels,
276
+ hidden_channels,
277
+ filter_channels,
278
+ n_heads,
279
+ n_layers,
280
+ kernel_size,
281
+ p_dropout,
282
+ resblock,
283
+ resblock_kernel_sizes,
284
+ resblock_dilation_sizes,
285
+ upsample_rates,
286
+ upsample_initial_channel,
287
+ upsample_kernel_sizes,
288
+ gin_channels,
289
+ ssl_dim,
290
+ use_spk,
291
+ **kwargs):
292
+
293
+ super().__init__()
294
+ self.spec_channels = spec_channels
295
+ self.inter_channels = inter_channels
296
+ self.hidden_channels = hidden_channels
297
+ self.filter_channels = filter_channels
298
+ self.n_heads = n_heads
299
+ self.n_layers = n_layers
300
+ self.kernel_size = kernel_size
301
+ self.p_dropout = p_dropout
302
+ self.resblock = resblock
303
+ self.resblock_kernel_sizes = resblock_kernel_sizes
304
+ self.resblock_dilation_sizes = resblock_dilation_sizes
305
+ self.upsample_rates = upsample_rates
306
+ self.upsample_initial_channel = upsample_initial_channel
307
+ self.upsample_kernel_sizes = upsample_kernel_sizes
308
+ self.segment_size = segment_size
309
+ self.gin_channels = gin_channels
310
+ self.ssl_dim = ssl_dim
311
+ self.use_spk = use_spk
312
+
313
+ self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16)
314
+ self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
315
+ self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
316
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
317
+
318
+ if not self.use_spk:
319
+ self.enc_spk = SpeakerEncoder(model_hidden_size=gin_channels, model_embedding_size=gin_channels)
320
+
321
+ def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None):
322
+ if c_lengths == None:
323
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
324
+ if spec_lengths == None:
325
+ spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)
326
+
327
+ if not self.use_spk:
328
+ g = self.enc_spk(mel.transpose(1,2))
329
+ g = g.unsqueeze(-1)
330
+
331
+ _, m_p, logs_p, _ = self.enc_p(c, c_lengths)
332
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
333
+ z_p = self.flow(z, spec_mask, g=g)
334
+
335
+ z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
336
+ o = self.dec(z_slice, g=g)
337
+
338
+ return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
339
+
340
+ def infer(self, c, g=None, mel=None, c_lengths=None):
341
+ if c_lengths == None:
342
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
343
+ if not self.use_spk:
344
+ g = self.enc_spk.embed_utterance(mel.transpose(1,2))
345
+ g = g.unsqueeze(-1)
346
+
347
+ z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
348
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
349
+ o = self.dec(z * c_mask, g=g)
350
+
351
+ return o
modules.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ import commons
13
+ from commons import init_weights, get_padding
14
+
15
+
16
+ LRELU_SLOPE = 0.1
17
+
18
+
19
+ class LayerNorm(nn.Module):
20
+ def __init__(self, channels, eps=1e-5):
21
+ super().__init__()
22
+ self.channels = channels
23
+ self.eps = eps
24
+
25
+ self.gamma = nn.Parameter(torch.ones(channels))
26
+ self.beta = nn.Parameter(torch.zeros(channels))
27
+
28
+ def forward(self, x):
29
+ x = x.transpose(1, -1)
30
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31
+ return x.transpose(1, -1)
32
+
33
+
34
+ class ConvReluNorm(nn.Module):
35
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
36
+ super().__init__()
37
+ self.in_channels = in_channels
38
+ self.hidden_channels = hidden_channels
39
+ self.out_channels = out_channels
40
+ self.kernel_size = kernel_size
41
+ self.n_layers = n_layers
42
+ self.p_dropout = p_dropout
43
+ assert n_layers > 1, "Number of layers should be larger than 0."
44
+
45
+ self.conv_layers = nn.ModuleList()
46
+ self.norm_layers = nn.ModuleList()
47
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
48
+ self.norm_layers.append(LayerNorm(hidden_channels))
49
+ self.relu_drop = nn.Sequential(
50
+ nn.ReLU(),
51
+ nn.Dropout(p_dropout))
52
+ for _ in range(n_layers-1):
53
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
54
+ self.norm_layers.append(LayerNorm(hidden_channels))
55
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
56
+ self.proj.weight.data.zero_()
57
+ self.proj.bias.data.zero_()
58
+
59
+ def forward(self, x, x_mask):
60
+ x_org = x
61
+ for i in range(self.n_layers):
62
+ x = self.conv_layers[i](x * x_mask)
63
+ x = self.norm_layers[i](x)
64
+ x = self.relu_drop(x)
65
+ x = x_org + self.proj(x)
66
+ return x * x_mask
67
+
68
+
69
+ class DDSConv(nn.Module):
70
+ """
71
+ Dialted and Depth-Separable Convolution
72
+ """
73
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.kernel_size = kernel_size
77
+ self.n_layers = n_layers
78
+ self.p_dropout = p_dropout
79
+
80
+ self.drop = nn.Dropout(p_dropout)
81
+ self.convs_sep = nn.ModuleList()
82
+ self.convs_1x1 = nn.ModuleList()
83
+ self.norms_1 = nn.ModuleList()
84
+ self.norms_2 = nn.ModuleList()
85
+ for i in range(n_layers):
86
+ dilation = kernel_size ** i
87
+ padding = (kernel_size * dilation - dilation) // 2
88
+ self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
89
+ groups=channels, dilation=dilation, padding=padding
90
+ ))
91
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
92
+ self.norms_1.append(LayerNorm(channels))
93
+ self.norms_2.append(LayerNorm(channels))
94
+
95
+ def forward(self, x, x_mask, g=None):
96
+ if g is not None:
97
+ x = x + g
98
+ for i in range(self.n_layers):
99
+ y = self.convs_sep[i](x * x_mask)
100
+ y = self.norms_1[i](y)
101
+ y = F.gelu(y)
102
+ y = self.convs_1x1[i](y)
103
+ y = self.norms_2[i](y)
104
+ y = F.gelu(y)
105
+ y = self.drop(y)
106
+ x = x + y
107
+ return x * x_mask
108
+
109
+
110
+ class WN(torch.nn.Module):
111
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
112
+ super(WN, self).__init__()
113
+ assert(kernel_size % 2 == 1)
114
+ self.hidden_channels =hidden_channels
115
+ self.kernel_size = kernel_size,
116
+ self.dilation_rate = dilation_rate
117
+ self.n_layers = n_layers
118
+ self.gin_channels = gin_channels
119
+ self.p_dropout = p_dropout
120
+
121
+ self.in_layers = torch.nn.ModuleList()
122
+ self.res_skip_layers = torch.nn.ModuleList()
123
+ self.drop = nn.Dropout(p_dropout)
124
+
125
+ if gin_channels != 0:
126
+ cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
127
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
128
+
129
+ for i in range(n_layers):
130
+ dilation = dilation_rate ** i
131
+ padding = int((kernel_size * dilation - dilation) / 2)
132
+ in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
133
+ dilation=dilation, padding=padding)
134
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
135
+ self.in_layers.append(in_layer)
136
+
137
+ # last one is not necessary
138
+ if i < n_layers - 1:
139
+ res_skip_channels = 2 * hidden_channels
140
+ else:
141
+ res_skip_channels = hidden_channels
142
+
143
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
144
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
145
+ self.res_skip_layers.append(res_skip_layer)
146
+
147
+ def forward(self, x, x_mask, g=None, **kwargs):
148
+ output = torch.zeros_like(x)
149
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
150
+
151
+ if g is not None:
152
+ g = self.cond_layer(g)
153
+
154
+ for i in range(self.n_layers):
155
+ x_in = self.in_layers[i](x)
156
+ if g is not None:
157
+ cond_offset = i * 2 * self.hidden_channels
158
+ g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
159
+ else:
160
+ g_l = torch.zeros_like(x_in)
161
+
162
+ acts = commons.fused_add_tanh_sigmoid_multiply(
163
+ x_in,
164
+ g_l,
165
+ n_channels_tensor)
166
+ acts = self.drop(acts)
167
+
168
+ res_skip_acts = self.res_skip_layers[i](acts)
169
+ if i < self.n_layers - 1:
170
+ res_acts = res_skip_acts[:,:self.hidden_channels,:]
171
+ x = (x + res_acts) * x_mask
172
+ output = output + res_skip_acts[:,self.hidden_channels:,:]
173
+ else:
174
+ output = output + res_skip_acts
175
+ return output * x_mask
176
+
177
+ def remove_weight_norm(self):
178
+ if self.gin_channels != 0:
179
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
180
+ for l in self.in_layers:
181
+ torch.nn.utils.remove_weight_norm(l)
182
+ for l in self.res_skip_layers:
183
+ torch.nn.utils.remove_weight_norm(l)
184
+
185
+
186
+ class ResBlock1(torch.nn.Module):
187
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
188
+ super(ResBlock1, self).__init__()
189
+ self.convs1 = nn.ModuleList([
190
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
191
+ padding=get_padding(kernel_size, dilation[0]))),
192
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
193
+ padding=get_padding(kernel_size, dilation[1]))),
194
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
195
+ padding=get_padding(kernel_size, dilation[2])))
196
+ ])
197
+ self.convs1.apply(init_weights)
198
+
199
+ self.convs2 = nn.ModuleList([
200
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
201
+ padding=get_padding(kernel_size, 1))),
202
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
203
+ padding=get_padding(kernel_size, 1))),
204
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
205
+ padding=get_padding(kernel_size, 1)))
206
+ ])
207
+ self.convs2.apply(init_weights)
208
+
209
+ def forward(self, x, x_mask=None):
210
+ for c1, c2 in zip(self.convs1, self.convs2):
211
+ xt = F.leaky_relu(x, LRELU_SLOPE)
212
+ if x_mask is not None:
213
+ xt = xt * x_mask
214
+ xt = c1(xt)
215
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
216
+ if x_mask is not None:
217
+ xt = xt * x_mask
218
+ xt = c2(xt)
219
+ x = xt + x
220
+ if x_mask is not None:
221
+ x = x * x_mask
222
+ return x
223
+
224
+ def remove_weight_norm(self):
225
+ for l in self.convs1:
226
+ remove_weight_norm(l)
227
+ for l in self.convs2:
228
+ remove_weight_norm(l)
229
+
230
+
231
+ class ResBlock2(torch.nn.Module):
232
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
233
+ super(ResBlock2, self).__init__()
234
+ self.convs = nn.ModuleList([
235
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
236
+ padding=get_padding(kernel_size, dilation[0]))),
237
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
238
+ padding=get_padding(kernel_size, dilation[1])))
239
+ ])
240
+ self.convs.apply(init_weights)
241
+
242
+ def forward(self, x, x_mask=None):
243
+ for c in self.convs:
244
+ xt = F.leaky_relu(x, LRELU_SLOPE)
245
+ if x_mask is not None:
246
+ xt = xt * x_mask
247
+ xt = c(xt)
248
+ x = xt + x
249
+ if x_mask is not None:
250
+ x = x * x_mask
251
+ return x
252
+
253
+ def remove_weight_norm(self):
254
+ for l in self.convs:
255
+ remove_weight_norm(l)
256
+
257
+
258
+ class Log(nn.Module):
259
+ def forward(self, x, x_mask, reverse=False, **kwargs):
260
+ if not reverse:
261
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
262
+ logdet = torch.sum(-y, [1, 2])
263
+ return y, logdet
264
+ else:
265
+ x = torch.exp(x) * x_mask
266
+ return x
267
+
268
+
269
+ class Flip(nn.Module):
270
+ def forward(self, x, *args, reverse=False, **kwargs):
271
+ x = torch.flip(x, [1])
272
+ if not reverse:
273
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
274
+ return x, logdet
275
+ else:
276
+ return x
277
+
278
+
279
+ class ElementwiseAffine(nn.Module):
280
+ def __init__(self, channels):
281
+ super().__init__()
282
+ self.channels = channels
283
+ self.m = nn.Parameter(torch.zeros(channels,1))
284
+ self.logs = nn.Parameter(torch.zeros(channels,1))
285
+
286
+ def forward(self, x, x_mask, reverse=False, **kwargs):
287
+ if not reverse:
288
+ y = self.m + torch.exp(self.logs) * x
289
+ y = y * x_mask
290
+ logdet = torch.sum(self.logs * x_mask, [1,2])
291
+ return y, logdet
292
+ else:
293
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
294
+ return x
295
+
296
+
297
+ class ResidualCouplingLayer(nn.Module):
298
+ def __init__(self,
299
+ channels,
300
+ hidden_channels,
301
+ kernel_size,
302
+ dilation_rate,
303
+ n_layers,
304
+ p_dropout=0,
305
+ gin_channels=0,
306
+ mean_only=False):
307
+ assert channels % 2 == 0, "channels should be divisible by 2"
308
+ super().__init__()
309
+ self.channels = channels
310
+ self.hidden_channels = hidden_channels
311
+ self.kernel_size = kernel_size
312
+ self.dilation_rate = dilation_rate
313
+ self.n_layers = n_layers
314
+ self.half_channels = channels // 2
315
+ self.mean_only = mean_only
316
+
317
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
318
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
319
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
320
+ self.post.weight.data.zero_()
321
+ self.post.bias.data.zero_()
322
+
323
+ def forward(self, x, x_mask, g=None, reverse=False):
324
+ x0, x1 = torch.split(x, [self.half_channels]*2, 1)
325
+ h = self.pre(x0) * x_mask
326
+ h = self.enc(h, x_mask, g=g)
327
+ stats = self.post(h) * x_mask
328
+ if not self.mean_only:
329
+ m, logs = torch.split(stats, [self.half_channels]*2, 1)
330
+ else:
331
+ m = stats
332
+ logs = torch.zeros_like(m)
333
+
334
+ if not reverse:
335
+ x1 = m + x1 * torch.exp(logs) * x_mask
336
+ x = torch.cat([x0, x1], 1)
337
+ logdet = torch.sum(logs, [1,2])
338
+ return x, logdet
339
+ else:
340
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
341
+ x = torch.cat([x0, x1], 1)
342
+ return x
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy==1.22.0
2
+ scipy
3
+ torch
4
+ transformers
5
+ librosa==0.8.1
6
+ webrtcvad==2.0.10
7
+ protobuf
8
+ cpm_kernels
9
+ mdtex2html
10
+ sentencepiece
11
+ accelerate
12
+ loguru
13
+ edge_tts
14
+ altair
15
+ gradio==3.36.1
speaker_encoder/__init__.py ADDED
File without changes
speaker_encoder/audio.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.ndimage.morphology import binary_dilation
2
+ from speaker_encoder.params_data import *
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+ import numpy as np
6
+ import webrtcvad
7
+ import librosa
8
+ import struct
9
+
10
+ int16_max = (2 ** 15) - 1
11
+
12
+
13
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
14
+ source_sr: Optional[int] = None):
15
+ """
16
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
17
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
18
+
19
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
20
+ just .wav), either the waveform as a numpy array of floats.
21
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
22
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
23
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
24
+ this argument will be ignored.
25
+ """
26
+ # Load the wav from disk if needed
27
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
28
+ wav, source_sr = librosa.load(fpath_or_wav, sr=None)
29
+ else:
30
+ wav = fpath_or_wav
31
+
32
+ # Resample the wav if needed
33
+ if source_sr is not None and source_sr != sampling_rate:
34
+ wav = librosa.resample(wav, source_sr, sampling_rate)
35
+
36
+ # Apply the preprocessing: normalize volume and shorten long silences
37
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
38
+ wav = trim_long_silences(wav)
39
+
40
+ return wav
41
+
42
+
43
+ def wav_to_mel_spectrogram(wav):
44
+ """
45
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
46
+ Note: this not a log-mel spectrogram.
47
+ """
48
+ frames = librosa.feature.melspectrogram(
49
+ y=wav,
50
+ sr=sampling_rate,
51
+ n_fft=int(sampling_rate * mel_window_length / 1000),
52
+ hop_length=int(sampling_rate * mel_window_step / 1000),
53
+ n_mels=mel_n_channels
54
+ )
55
+ return frames.astype(np.float32).T
56
+
57
+
58
+ def trim_long_silences(wav):
59
+ """
60
+ Ensures that segments without voice in the waveform remain no longer than a
61
+ threshold determined by the VAD parameters in params.py.
62
+
63
+ :param wav: the raw waveform as a numpy array of floats
64
+ :return: the same waveform with silences trimmed away (length <= original wav length)
65
+ """
66
+ # Compute the voice detection window size
67
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
68
+
69
+ # Trim the end of the audio to have a multiple of the window size
70
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
71
+
72
+ # Convert the float waveform to 16-bit mono PCM
73
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
74
+
75
+ # Perform voice activation detection
76
+ voice_flags = []
77
+ vad = webrtcvad.Vad(mode=3)
78
+ for window_start in range(0, len(wav), samples_per_window):
79
+ window_end = window_start + samples_per_window
80
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
81
+ sample_rate=sampling_rate))
82
+ voice_flags = np.array(voice_flags)
83
+
84
+ # Smooth the voice detection with a moving average
85
+ def moving_average(array, width):
86
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
87
+ ret = np.cumsum(array_padded, dtype=float)
88
+ ret[width:] = ret[width:] - ret[:-width]
89
+ return ret[width - 1:] / width
90
+
91
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
92
+ audio_mask = np.round(audio_mask).astype(np.bool)
93
+
94
+ # Dilate the voiced regions
95
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
96
+ audio_mask = np.repeat(audio_mask, samples_per_window)
97
+
98
+ return wav[audio_mask == True]
99
+
100
+
101
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
102
+ if increase_only and decrease_only:
103
+ raise ValueError("Both increase only and decrease only are set")
104
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
105
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
106
+ return wav
107
+ return wav * (10 ** (dBFS_change / 20))
speaker_encoder/ckpt/pretrained_bak_5805000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
3
+ size 17090379
speaker_encoder/compute_embed.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder import inference as encoder
2
+ from multiprocessing.pool import Pool
3
+ from functools import partial
4
+ from pathlib import Path
5
+ # from utils import logmmse
6
+ # from tqdm import tqdm
7
+ # import numpy as np
8
+ # import librosa
9
+
10
+
11
+ def embed_utterance(fpaths, encoder_model_fpath):
12
+ if not encoder.is_loaded():
13
+ encoder.load_model(encoder_model_fpath)
14
+
15
+ # Compute the speaker embedding of the utterance
16
+ wav_fpath, embed_fpath = fpaths
17
+ wav = np.load(wav_fpath)
18
+ wav = encoder.preprocess_wav(wav)
19
+ embed = encoder.embed_utterance(wav)
20
+ np.save(embed_fpath, embed, allow_pickle=False)
21
+
22
+
23
+ def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
24
+
25
+ wav_dir = outdir_root.joinpath("audio")
26
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
27
+ assert wav_dir.exists() and metadata_fpath.exists()
28
+ embed_dir = synthesizer_root.joinpath("embeds")
29
+ embed_dir.mkdir(exist_ok=True)
30
+
31
+ # Gather the input wave filepath and the target output embed filepath
32
+ with metadata_fpath.open("r") as metadata_file:
33
+ metadata = [line.split("|") for line in metadata_file]
34
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
35
+
36
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
37
+ # Embed the utterances in separate threads
38
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
39
+ job = Pool(n_processes).imap(func, fpaths)
40
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
speaker_encoder/config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librispeech_datasets = {
2
+ "train": {
3
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4
+ "other": ["LibriSpeech/train-other-500"]
5
+ },
6
+ "test": {
7
+ "clean": ["LibriSpeech/test-clean"],
8
+ "other": ["LibriSpeech/test-other"]
9
+ },
10
+ "dev": {
11
+ "clean": ["LibriSpeech/dev-clean"],
12
+ "other": ["LibriSpeech/dev-other"]
13
+ },
14
+ }
15
+ libritts_datasets = {
16
+ "train": {
17
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18
+ "other": ["LibriTTS/train-other-500"]
19
+ },
20
+ "test": {
21
+ "clean": ["LibriTTS/test-clean"],
22
+ "other": ["LibriTTS/test-other"]
23
+ },
24
+ "dev": {
25
+ "clean": ["LibriTTS/dev-clean"],
26
+ "other": ["LibriTTS/dev-other"]
27
+ },
28
+ }
29
+ voxceleb_datasets = {
30
+ "voxceleb1" : {
31
+ "train": ["VoxCeleb1/wav"],
32
+ "test": ["VoxCeleb1/test_wav"]
33
+ },
34
+ "voxceleb2" : {
35
+ "train": ["VoxCeleb2/dev/aac"],
36
+ "test": ["VoxCeleb2/test_wav"]
37
+ }
38
+ }
39
+
40
+ other_datasets = [
41
+ "LJSpeech-1.1",
42
+ "VCTK-Corpus/wav48",
43
+ ]
44
+
45
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
speaker_encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
speaker_encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ class RandomCycler:
4
+ """
5
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
6
+ order. For a source sequence of n items and one or several consecutive queries of a total
7
+ of m items, the following guarantees hold (one implies the other):
8
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10
+ """
11
+
12
+ def __init__(self, source):
13
+ if len(source) == 0:
14
+ raise Exception("Can't create RandomCycler from an empty collection")
15
+ self.all_items = list(source)
16
+ self.next_items = []
17
+
18
+ def sample(self, count: int):
19
+ shuffle = lambda l: random.sample(l, len(l))
20
+
21
+ out = []
22
+ while count > 0:
23
+ if count >= len(self.all_items):
24
+ out.extend(shuffle(list(self.all_items)))
25
+ count -= len(self.all_items)
26
+ continue
27
+ n = min(count, len(self.next_items))
28
+ out.extend(self.next_items[:n])
29
+ count -= n
30
+ self.next_items = self.next_items[n:]
31
+ if len(self.next_items) == 0:
32
+ self.next_items = shuffle(list(self.all_items))
33
+ return out
34
+
35
+ def __next__(self):
36
+ return self.sample(1)[0]
37
+
speaker_encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.utterance import Utterance
3
+ from pathlib import Path
4
+
5
+ # Contains the set of utterances of a single speaker
6
+ class Speaker:
7
+ def __init__(self, root: Path):
8
+ self.root = root
9
+ self.name = root.name
10
+ self.utterances = None
11
+ self.utterance_cycler = None
12
+
13
+ def _load_utterances(self):
14
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
+ sources = [l.split(",") for l in sources_file]
16
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
+ self.utterance_cycler = RandomCycler(self.utterances)
19
+
20
+ def random_partial(self, count, n_frames):
21
+ """
22
+ Samples a batch of <count> unique partial utterances from the disk in a way that all
23
+ utterances come up at least once every two cycles and in a random order every time.
24
+
25
+ :param count: The number of partial utterances to sample from the set of utterances from
26
+ that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
+ the number of utterances available.
28
+ :param n_frames: The number of frames in the partial utterance.
29
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
+ frames are the frames of the partial utterances and range is the range of the partial
31
+ utterance with regard to the complete utterance.
32
+ """
33
+ if self.utterances is None:
34
+ self._load_utterances()
35
+
36
+ utterances = self.utterance_cycler.sample(count)
37
+
38
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
+
40
+ return a
speaker_encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+
5
+ class SpeakerBatch:
6
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
7
+ self.speakers = speakers
8
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
9
+
10
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
11
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
12
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
speaker_encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+ from speaker_encoder.params_data import partials_n_frames
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from pathlib import Path
7
+
8
+ # TODO: improve with a pool of speakers for data efficiency
9
+
10
+ class SpeakerVerificationDataset(Dataset):
11
+ def __init__(self, datasets_root: Path):
12
+ self.root = datasets_root
13
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14
+ if len(speaker_dirs) == 0:
15
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
16
+ "containing all preprocessed speaker directories.")
17
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18
+ self.speaker_cycler = RandomCycler(self.speakers)
19
+
20
+ def __len__(self):
21
+ return int(1e10)
22
+
23
+ def __getitem__(self, index):
24
+ return next(self.speaker_cycler)
25
+
26
+ def get_logs(self):
27
+ log_string = ""
28
+ for log_fpath in self.root.glob("*.txt"):
29
+ with log_fpath.open("r") as log_file:
30
+ log_string += "".join(log_file.readlines())
31
+ return log_string
32
+
33
+
34
+ class SpeakerVerificationDataLoader(DataLoader):
35
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37
+ worker_init_fn=None):
38
+ self.utterances_per_speaker = utterances_per_speaker
39
+
40
+ super().__init__(
41
+ dataset=dataset,
42
+ batch_size=speakers_per_batch,
43
+ shuffle=False,
44
+ sampler=sampler,
45
+ batch_sampler=batch_sampler,
46
+ num_workers=num_workers,
47
+ collate_fn=self.collate,
48
+ pin_memory=pin_memory,
49
+ drop_last=False,
50
+ timeout=timeout,
51
+ worker_init_fn=worker_init_fn
52
+ )
53
+
54
+ def collate(self, speakers):
55
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
+
speaker_encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class Utterance:
5
+ def __init__(self, frames_fpath, wave_fpath):
6
+ self.frames_fpath = frames_fpath
7
+ self.wave_fpath = wave_fpath
8
+
9
+ def get_frames(self):
10
+ return np.load(self.frames_fpath)
11
+
12
+ def random_partial(self, n_frames):
13
+ """
14
+ Crops the frames into a partial utterance of n_frames
15
+
16
+ :param n_frames: The number of frames of the partial utterance
17
+ :return: the partial utterance frames and a tuple indicating the start and end of the
18
+ partial utterance in the complete utterance.
19
+ """
20
+ frames = self.get_frames()
21
+ if frames.shape[0] == n_frames:
22
+ start = 0
23
+ else:
24
+ start = np.random.randint(0, frames.shape[0] - n_frames)
25
+ end = start + n_frames
26
+ return frames[start:end], (start, end)
speaker_encoder/hparams.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Mel-filterbank
2
+ mel_window_length = 25 # In milliseconds
3
+ mel_window_step = 10 # In milliseconds
4
+ mel_n_channels = 40
5
+
6
+
7
+ ## Audio
8
+ sampling_rate = 16000
9
+ # Number of spectrogram frames in a partial utterance
10
+ partials_n_frames = 160 # 1600 ms
11
+
12
+
13
+ ## Voice Activation Detection
14
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
15
+ # This sets the granularity of the VAD. Should not need to be changed.
16
+ vad_window_length = 30 # In milliseconds
17
+ # Number of frames to average together when performing the moving average smoothing.
18
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
19
+ vad_moving_average_width = 8
20
+ # Maximum number of consecutive silent frames a segment can have.
21
+ vad_max_silence_length = 6
22
+
23
+
24
+ ## Audio volume normalization
25
+ audio_norm_target_dBFS = -30
26
+
27
+
28
+ ## Model parameters
29
+ model_hidden_size = 256
30
+ model_embedding_size = 256
31
+ model_num_layers = 3
speaker_encoder/inference.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_data import *
2
+ from speaker_encoder.model import SpeakerEncoder
3
+ from speaker_encoder.audio import preprocess_wav # We want to expose this function from here
4
+ from matplotlib import cm
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+
11
+ _model = None # type: SpeakerEncoder
12
+ _device = None # type: torch.device
13
+
14
+
15
+ def load_model(weights_fpath: Path, device=None):
16
+ """
17
+ Loads the model in memory. If this function is not explicitely called, it will be run on the
18
+ first call to embed_frames() with the default weights file.
19
+
20
+ :param weights_fpath: the path to saved model weights.
21
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
22
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
23
+ If None, will default to your GPU if it"s available, otherwise your CPU.
24
+ """
25
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
26
+ # was saved on. Worth investigating.
27
+ global _model, _device
28
+ if device is None:
29
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ elif isinstance(device, str):
31
+ _device = torch.device(device)
32
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
33
+ checkpoint = torch.load(weights_fpath)
34
+ _model.load_state_dict(checkpoint["model_state"])
35
+ _model.eval()
36
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
37
+
38
+
39
+ def is_loaded():
40
+ return _model is not None
41
+
42
+
43
+ def embed_frames_batch(frames_batch):
44
+ """
45
+ Computes embeddings for a batch of mel spectrogram.
46
+
47
+ :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
48
+ (batch_size, n_frames, n_channels)
49
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
50
+ """
51
+ if _model is None:
52
+ raise Exception("Model was not loaded. Call load_model() before inference.")
53
+
54
+ frames = torch.from_numpy(frames_batch).to(_device)
55
+ embed = _model.forward(frames).detach().cpu().numpy()
56
+ return embed
57
+
58
+
59
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
60
+ min_pad_coverage=0.75, overlap=0.5):
61
+ """
62
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
63
+ partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
64
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
65
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
66
+ defined in params_data.py.
67
+
68
+ The returned ranges may be indexing further than the length of the waveform. It is
69
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
70
+
71
+ :param n_samples: the number of samples in the waveform
72
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
73
+ utterance
74
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
75
+ enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
76
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
77
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
78
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
79
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
80
+ utterances are entirely disjoint.
81
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
82
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
83
+ utterances.
84
+ """
85
+ assert 0 <= overlap < 1
86
+ assert 0 < min_pad_coverage <= 1
87
+
88
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
89
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
90
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
91
+
92
+ # Compute the slices
93
+ wav_slices, mel_slices = [], []
94
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
95
+ for i in range(0, steps, frame_step):
96
+ mel_range = np.array([i, i + partial_utterance_n_frames])
97
+ wav_range = mel_range * samples_per_frame
98
+ mel_slices.append(slice(*mel_range))
99
+ wav_slices.append(slice(*wav_range))
100
+
101
+ # Evaluate whether extra padding is warranted or not
102
+ last_wav_range = wav_slices[-1]
103
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
104
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
105
+ mel_slices = mel_slices[:-1]
106
+ wav_slices = wav_slices[:-1]
107
+
108
+ return wav_slices, mel_slices
109
+
110
+
111
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
112
+ """
113
+ Computes an embedding for a single utterance.
114
+
115
+ # TODO: handle multiple wavs to benefit from batching on GPU
116
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
117
+ :param using_partials: if True, then the utterance is split in partial utterances of
118
+ <partial_utterance_n_frames> frames and the utterance embedding is computed from their
119
+ normalized average. If False, the utterance is instead computed from feeding the entire
120
+ spectogram to the network.
121
+ :param return_partials: if True, the partial embeddings will also be returned along with the
122
+ wav slices that correspond to the partial embeddings.
123
+ :param kwargs: additional arguments to compute_partial_splits()
124
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
125
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
126
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
127
+ returned. If <using_partials> is simultaneously set to False, both these values will be None
128
+ instead.
129
+ """
130
+ # Process the entire utterance if not using partials
131
+ if not using_partials:
132
+ frames = audio.wav_to_mel_spectrogram(wav)
133
+ embed = embed_frames_batch(frames[None, ...])[0]
134
+ if return_partials:
135
+ return embed, None, None
136
+ return embed
137
+
138
+ # Compute where to split the utterance into partials and pad if necessary
139
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
140
+ max_wave_length = wave_slices[-1].stop
141
+ if max_wave_length >= len(wav):
142
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
143
+
144
+ # Split the utterance into partials
145
+ frames = audio.wav_to_mel_spectrogram(wav)
146
+ frames_batch = np.array([frames[s] for s in mel_slices])
147
+ partial_embeds = embed_frames_batch(frames_batch)
148
+
149
+ # Compute the utterance embedding from the partial embeddings
150
+ raw_embed = np.mean(partial_embeds, axis=0)
151
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
152
+
153
+ if return_partials:
154
+ return embed, partial_embeds, wave_slices
155
+ return embed
156
+
157
+
158
+ def embed_speaker(wavs, **kwargs):
159
+ raise NotImplemented()
160
+
161
+
162
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
163
+ if ax is None:
164
+ ax = plt.gca()
165
+
166
+ if shape is None:
167
+ height = int(np.sqrt(len(embed)))
168
+ shape = (height, -1)
169
+ embed = embed.reshape(shape)
170
+
171
+ cmap = cm.get_cmap()
172
+ mappable = ax.imshow(embed, cmap=cmap)
173
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174
+ cbar.set_clim(*color_range)
175
+
176
+ ax.set_xticks([]), ax.set_yticks([])
177
+ ax.set_title(title)
speaker_encoder/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_model import *
2
+ from speaker_encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels, # 40
19
+ hidden_size=model_hidden_size, # 256
20
+ num_layers=model_num_layers, # 3
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according the section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
speaker_encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+
speaker_encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Model parameters
3
+ model_hidden_size = 256
4
+ model_embedding_size = 256
5
+ model_num_layers = 3
6
+
7
+
8
+ ## Training parameters
9
+ learning_rate_init = 1e-4
10
+ speakers_per_batch = 64
11
+ utterances_per_speaker = 10
speaker_encoder/preprocess.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocess.pool import ThreadPool
2
+ from speaker_encoder.params_data import *
3
+ from speaker_encoder.config import librispeech_datasets, anglophone_nationalites
4
+ from datetime import datetime
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+
11
+ class DatasetLog:
12
+ """
13
+ Registers metadata about the dataset in a text file.
14
+ """
15
+ def __init__(self, root, name):
16
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
17
+ self.sample_data = dict()
18
+
19
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
20
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
21
+ self.write_line("-----")
22
+ self._log_params()
23
+
24
+ def _log_params(self):
25
+ from speaker_encoder import params_data
26
+ self.write_line("Parameter values:")
27
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
28
+ value = getattr(params_data, param_name)
29
+ self.write_line("\t%s: %s" % (param_name, value))
30
+ self.write_line("-----")
31
+
32
+ def write_line(self, line):
33
+ self.text_file.write("%s\n" % line)
34
+
35
+ def add_sample(self, **kwargs):
36
+ for param_name, value in kwargs.items():
37
+ if not param_name in self.sample_data:
38
+ self.sample_data[param_name] = []
39
+ self.sample_data[param_name].append(value)
40
+
41
+ def finalize(self):
42
+ self.write_line("Statistics:")
43
+ for param_name, values in self.sample_data.items():
44
+ self.write_line("\t%s:" % param_name)
45
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
46
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
47
+ self.write_line("-----")
48
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
49
+ self.write_line("Finished on %s" % end_time)
50
+ self.text_file.close()
51
+
52
+
53
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
54
+ dataset_root = datasets_root.joinpath(dataset_name)
55
+ if not dataset_root.exists():
56
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
57
+ return None, None
58
+ return dataset_root, DatasetLog(out_dir, dataset_name)
59
+
60
+
61
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
62
+ skip_existing, logger):
63
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
64
+
65
+ # Function to preprocess utterances for one speaker
66
+ def preprocess_speaker(speaker_dir: Path):
67
+ # Give a name to the speaker that includes its dataset
68
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
+
70
+ # Create an output directory with that name, as well as a txt file containing a
71
+ # reference to each source file.
72
+ speaker_out_dir = out_dir.joinpath(speaker_name)
73
+ speaker_out_dir.mkdir(exist_ok=True)
74
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
+
76
+ # There's a possibility that the preprocessing was interrupted earlier, check if
77
+ # there already is a sources file.
78
+ if sources_fpath.exists():
79
+ try:
80
+ with sources_fpath.open("r") as sources_file:
81
+ existing_fnames = {line.split(",")[0] for line in sources_file}
82
+ except:
83
+ existing_fnames = {}
84
+ else:
85
+ existing_fnames = {}
86
+
87
+ # Gather all audio files for that speaker recursively
88
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
89
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
90
+ # Check if the target output file already exists
91
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
92
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
93
+ if skip_existing and out_fname in existing_fnames:
94
+ continue
95
+
96
+ # Load and preprocess the waveform
97
+ wav = audio.preprocess_wav(in_fpath)
98
+ if len(wav) == 0:
99
+ continue
100
+
101
+ # Create the mel spectrogram, discard those that are too short
102
+ frames = audio.wav_to_mel_spectrogram(wav)
103
+ if len(frames) < partials_n_frames:
104
+ continue
105
+
106
+ out_fpath = speaker_out_dir.joinpath(out_fname)
107
+ np.save(out_fpath, frames)
108
+ logger.add_sample(duration=len(wav) / sampling_rate)
109
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
110
+
111
+ sources_file.close()
112
+
113
+ # Process the utterances for each speaker
114
+ with ThreadPool(8) as pool:
115
+ list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
116
+ unit="speakers"))
117
+ logger.finalize()
118
+ print("Done preprocessing %s.\n" % dataset_name)
119
+
120
+
121
+ # Function to preprocess utterances for one speaker
122
+ def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool):
123
+ # Give a name to the speaker that includes its dataset
124
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
125
+
126
+ # Create an output directory with that name, as well as a txt file containing a
127
+ # reference to each source file.
128
+ speaker_out_dir = out_dir.joinpath(speaker_name)
129
+ speaker_out_dir.mkdir(exist_ok=True)
130
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
131
+
132
+ # There's a possibility that the preprocessing was interrupted earlier, check if
133
+ # there already is a sources file.
134
+ # if sources_fpath.exists():
135
+ # try:
136
+ # with sources_fpath.open("r") as sources_file:
137
+ # existing_fnames = {line.split(",")[0] for line in sources_file}
138
+ # except:
139
+ # existing_fnames = {}
140
+ # else:
141
+ # existing_fnames = {}
142
+ existing_fnames = {}
143
+ # Gather all audio files for that speaker recursively
144
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
145
+
146
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
147
+ # Check if the target output file already exists
148
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
149
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
150
+ if skip_existing and out_fname in existing_fnames:
151
+ continue
152
+
153
+ # Load and preprocess the waveform
154
+ wav = audio.preprocess_wav(in_fpath)
155
+ if len(wav) == 0:
156
+ continue
157
+
158
+ # Create the mel spectrogram, discard those that are too short
159
+ frames = audio.wav_to_mel_spectrogram(wav)
160
+ if len(frames) < partials_n_frames:
161
+ continue
162
+
163
+ out_fpath = speaker_out_dir.joinpath(out_fname)
164
+ np.save(out_fpath, frames)
165
+ # logger.add_sample(duration=len(wav) / sampling_rate)
166
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
167
+
168
+ sources_file.close()
169
+ return len(wav)
170
+
171
+ def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
172
+ skip_existing, logger):
173
+ # from multiprocessing import Pool, cpu_count
174
+ from pathos.multiprocessing import ProcessingPool as Pool
175
+ # Function to preprocess utterances for one speaker
176
+ def __preprocess_speaker(speaker_dir: Path):
177
+ # Give a name to the speaker that includes its dataset
178
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
179
+
180
+ # Create an output directory with that name, as well as a txt file containing a
181
+ # reference to each source file.
182
+ speaker_out_dir = out_dir.joinpath(speaker_name)
183
+ speaker_out_dir.mkdir(exist_ok=True)
184
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
185
+
186
+ existing_fnames = {}
187
+ # Gather all audio files for that speaker recursively
188
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
189
+ wav_lens = []
190
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
191
+ # Check if the target output file already exists
192
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
193
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
194
+ if skip_existing and out_fname in existing_fnames:
195
+ continue
196
+
197
+ # Load and preprocess the waveform
198
+ wav = audio.preprocess_wav(in_fpath)
199
+ if len(wav) == 0:
200
+ continue
201
+
202
+ # Create the mel spectrogram, discard those that are too short
203
+ frames = audio.wav_to_mel_spectrogram(wav)
204
+ if len(frames) < partials_n_frames:
205
+ continue
206
+
207
+ out_fpath = speaker_out_dir.joinpath(out_fname)
208
+ np.save(out_fpath, frames)
209
+ # logger.add_sample(duration=len(wav) / sampling_rate)
210
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
211
+ wav_lens.append(len(wav))
212
+ sources_file.close()
213
+ return wav_lens
214
+
215
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
216
+ # Process the utterances for each speaker
217
+ # with ThreadPool(8) as pool:
218
+ # list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
219
+ # unit="speakers"))
220
+ pool = Pool(processes=20)
221
+ for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1):
222
+ for wav_len in wav_lens:
223
+ logger.add_sample(duration=wav_len / sampling_rate)
224
+ print(f'{i}/{len(speaker_dirs)} \r')
225
+
226
+ logger.finalize()
227
+ print("Done preprocessing %s.\n" % dataset_name)
228
+
229
+
230
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
231
+ for dataset_name in librispeech_datasets["train"]["other"]:
232
+ # Initialize the preprocessing
233
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
234
+ if not dataset_root:
235
+ return
236
+
237
+ # Preprocess all speakers
238
+ speaker_dirs = list(dataset_root.glob("*"))
239
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
240
+ skip_existing, logger)
241
+
242
+
243
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
244
+ # Initialize the preprocessing
245
+ dataset_name = "VoxCeleb1"
246
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
247
+ if not dataset_root:
248
+ return
249
+
250
+ # Get the contents of the meta file
251
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
252
+ metadata = [line.split("\t") for line in metafile][1:]
253
+
254
+ # Select the ID and the nationality, filter out non-anglophone speakers
255
+ nationalities = {line[0]: line[3] for line in metadata}
256
+ # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
257
+ # nationality.lower() in anglophone_nationalites]
258
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()]
259
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
260
+ (len(keep_speaker_ids), len(nationalities)))
261
+
262
+ # Get the speaker directories for anglophone speakers only
263
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
264
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
265
+ speaker_dir.name in keep_speaker_ids]
266
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
267
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
268
+
269
+ # Preprocess all speakers
270
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
271
+ skip_existing, logger)
272
+
273
+
274
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
275
+ # Initialize the preprocessing
276
+ dataset_name = "VoxCeleb2"
277
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
278
+ if not dataset_root:
279
+ return
280
+
281
+ # Get the speaker directories
282
+ # Preprocess all speakers
283
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
284
+ _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
285
+ skip_existing, logger)
speaker_encoder/train.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.visualizations import Visualizations
2
+ from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3
+ from speaker_encoder.params_model import *
4
+ from speaker_encoder.model import SpeakerEncoder
5
+ from utils.profiler import Profiler
6
+ from pathlib import Path
7
+ import torch
8
+
9
+ def sync(device: torch.device):
10
+ # FIXME
11
+ return
12
+ # For correct profiling (cuda operations are async)
13
+ if device.type == "cuda":
14
+ torch.cuda.synchronize(device)
15
+
16
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
17
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
18
+ no_visdom: bool):
19
+ # Create a dataset and a dataloader
20
+ dataset = SpeakerVerificationDataset(clean_data_root)
21
+ loader = SpeakerVerificationDataLoader(
22
+ dataset,
23
+ speakers_per_batch, # 64
24
+ utterances_per_speaker, # 10
25
+ num_workers=8,
26
+ )
27
+
28
+ # Setup the device on which to run the forward pass and the loss. These can be different,
29
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
30
+ # hyperparameters) faster on the CPU.
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ # FIXME: currently, the gradient is None if loss_device is cuda
33
+ loss_device = torch.device("cpu")
34
+
35
+ # Create the model and the optimizer
36
+ model = SpeakerEncoder(device, loss_device)
37
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
38
+ init_step = 1
39
+
40
+ # Configure file path for the model
41
+ state_fpath = models_dir.joinpath(run_id + ".pt")
42
+ backup_dir = models_dir.joinpath(run_id + "_backups")
43
+
44
+ # Load any existing model
45
+ if not force_restart:
46
+ if state_fpath.exists():
47
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
48
+ checkpoint = torch.load(state_fpath)
49
+ init_step = checkpoint["step"]
50
+ model.load_state_dict(checkpoint["model_state"])
51
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
52
+ optimizer.param_groups[0]["lr"] = learning_rate_init
53
+ else:
54
+ print("No model \"%s\" found, starting training from scratch." % run_id)
55
+ else:
56
+ print("Starting the training from scratch.")
57
+ model.train()
58
+
59
+ # Initialize the visualization environment
60
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
61
+ vis.log_dataset(dataset)
62
+ vis.log_params()
63
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
64
+ vis.log_implementation({"Device": device_name})
65
+
66
+ # Training loop
67
+ profiler = Profiler(summarize_every=10, disabled=False)
68
+ for step, speaker_batch in enumerate(loader, init_step):
69
+ profiler.tick("Blocking, waiting for batch (threaded)")
70
+
71
+ # Forward pass
72
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
73
+ sync(device)
74
+ profiler.tick("Data to %s" % device)
75
+ embeds = model(inputs)
76
+ sync(device)
77
+ profiler.tick("Forward pass")
78
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
79
+ loss, eer = model.loss(embeds_loss)
80
+ sync(loss_device)
81
+ profiler.tick("Loss")
82
+
83
+ # Backward pass
84
+ model.zero_grad()
85
+ loss.backward()
86
+ profiler.tick("Backward pass")
87
+ model.do_gradient_ops()
88
+ optimizer.step()
89
+ profiler.tick("Parameter update")
90
+
91
+ # Update visualizations
92
+ # learning_rate = optimizer.param_groups[0]["lr"]
93
+ vis.update(loss.item(), eer, step)
94
+
95
+ # Draw projections and save them to the backup folder
96
+ if umap_every != 0 and step % umap_every == 0:
97
+ print("Drawing and saving projections (step %d)" % step)
98
+ backup_dir.mkdir(exist_ok=True)
99
+ projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
100
+ embeds = embeds.detach().cpu().numpy()
101
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
102
+ vis.save()
103
+
104
+ # Overwrite the latest version of the model
105
+ if save_every != 0 and step % save_every == 0:
106
+ print("Saving the model (step %d)" % step)
107
+ torch.save({
108
+ "step": step + 1,
109
+ "model_state": model.state_dict(),
110
+ "optimizer_state": optimizer.state_dict(),
111
+ }, state_fpath)
112
+
113
+ # Make a backup
114
+ if backup_every != 0 and step % backup_every == 0:
115
+ print("Making a backup (step %d)" % step)
116
+ backup_dir.mkdir(exist_ok=True)
117
+ backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
118
+ torch.save({
119
+ "step": step + 1,
120
+ "model_state": model.state_dict(),
121
+ "optimizer_state": optimizer.state_dict(),
122
+ }, backup_fpath)
123
+
124
+ profiler.tick("Extras (visualizations, saving)")
125
+
speaker_encoder/visualizations.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from datetime import datetime
3
+ from time import perf_counter as timer
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ # import webbrowser
7
+ import visdom
8
+ import umap
9
+
10
+ colormap = np.array([
11
+ [76, 255, 0],
12
+ [0, 127, 70],
13
+ [255, 0, 0],
14
+ [255, 217, 38],
15
+ [0, 135, 255],
16
+ [165, 0, 165],
17
+ [255, 167, 255],
18
+ [0, 255, 255],
19
+ [255, 96, 38],
20
+ [142, 76, 0],
21
+ [33, 0, 127],
22
+ [0, 0, 0],
23
+ [183, 183, 183],
24
+ ], dtype=np.float) / 255
25
+
26
+
27
+ class Visualizations:
28
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
29
+ # Tracking data
30
+ self.last_update_timestamp = timer()
31
+ self.update_every = update_every
32
+ self.step_times = []
33
+ self.losses = []
34
+ self.eers = []
35
+ print("Updating the visualizations every %d steps." % update_every)
36
+
37
+ # If visdom is disabled TODO: use a better paradigm for that
38
+ self.disabled = disabled
39
+ if self.disabled:
40
+ return
41
+
42
+ # Set the environment name
43
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
44
+ if env_name is None:
45
+ self.env_name = now
46
+ else:
47
+ self.env_name = "%s (%s)" % (env_name, now)
48
+
49
+ # Connect to visdom and open the corresponding window in the browser
50
+ try:
51
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
52
+ except ConnectionError:
53
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
54
+ "start it.")
55
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
56
+
57
+ # Create the windows
58
+ self.loss_win = None
59
+ self.eer_win = None
60
+ # self.lr_win = None
61
+ self.implementation_win = None
62
+ self.projection_win = None
63
+ self.implementation_string = ""
64
+
65
+ def log_params(self):
66
+ if self.disabled:
67
+ return
68
+ from speaker_encoder import params_data
69
+ from speaker_encoder import params_model
70
+ param_string = "<b>Model parameters</b>:<br>"
71
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
72
+ value = getattr(params_model, param_name)
73
+ param_string += "\t%s: %s<br>" % (param_name, value)
74
+ param_string += "<b>Data parameters</b>:<br>"
75
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
76
+ value = getattr(params_data, param_name)
77
+ param_string += "\t%s: %s<br>" % (param_name, value)
78
+ self.vis.text(param_string, opts={"title": "Parameters"})
79
+
80
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
81
+ if self.disabled:
82
+ return
83
+ dataset_string = ""
84
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
85
+ dataset_string += "\n" + dataset.get_logs()
86
+ dataset_string = dataset_string.replace("\n", "<br>")
87
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
88
+
89
+ def log_implementation(self, params):
90
+ if self.disabled:
91
+ return
92
+ implementation_string = ""
93
+ for param, value in params.items():
94
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
95
+ implementation_string = implementation_string.replace("\n", "<br>")
96
+ self.implementation_string = implementation_string
97
+ self.implementation_win = self.vis.text(
98
+ implementation_string,
99
+ opts={"title": "Training implementation"}
100
+ )
101
+
102
+ def update(self, loss, eer, step):
103
+ # Update the tracking data
104
+ now = timer()
105
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
106
+ self.last_update_timestamp = now
107
+ self.losses.append(loss)
108
+ self.eers.append(eer)
109
+ print(".", end="")
110
+
111
+ # Update the plots every <update_every> steps
112
+ if step % self.update_every != 0:
113
+ return
114
+ time_string = "Step time: mean: %5dms std: %5dms" % \
115
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
116
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
117
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
118
+ if not self.disabled:
119
+ self.loss_win = self.vis.line(
120
+ [np.mean(self.losses)],
121
+ [step],
122
+ win=self.loss_win,
123
+ update="append" if self.loss_win else None,
124
+ opts=dict(
125
+ legend=["Avg. loss"],
126
+ xlabel="Step",
127
+ ylabel="Loss",
128
+ title="Loss",
129
+ )
130
+ )
131
+ self.eer_win = self.vis.line(
132
+ [np.mean(self.eers)],
133
+ [step],
134
+ win=self.eer_win,
135
+ update="append" if self.eer_win else None,
136
+ opts=dict(
137
+ legend=["Avg. EER"],
138
+ xlabel="Step",
139
+ ylabel="EER",
140
+ title="Equal error rate"
141
+ )
142
+ )
143
+ if self.implementation_win is not None:
144
+ self.vis.text(
145
+ self.implementation_string + ("<b>%s</b>" % time_string),
146
+ win=self.implementation_win,
147
+ opts={"title": "Training implementation"},
148
+ )
149
+
150
+ # Reset the tracking
151
+ self.losses.clear()
152
+ self.eers.clear()
153
+ self.step_times.clear()
154
+
155
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
156
+ max_speakers=10):
157
+ max_speakers = min(max_speakers, len(colormap))
158
+ embeds = embeds[:max_speakers * utterances_per_speaker]
159
+
160
+ n_speakers = len(embeds) // utterances_per_speaker
161
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
162
+ colors = [colormap[i] for i in ground_truth]
163
+
164
+ reducer = umap.UMAP()
165
+ projected = reducer.fit_transform(embeds)
166
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
167
+ plt.gca().set_aspect("equal", "datalim")
168
+ plt.title("UMAP projection (step %d)" % step)
169
+ if not self.disabled:
170
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
171
+ if out_fpath is not None:
172
+ plt.savefig(out_fpath)
173
+ plt.clf()
174
+
175
+ def save(self):
176
+ if not self.disabled:
177
+ self.vis.save([self.env_name])
178
+
speaker_encoder/voice_encoder.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.hparams import *
2
+ from speaker_encoder import audio
3
+ from pathlib import Path
4
+ from typing import Union, List
5
+ from torch import nn
6
+ from time import perf_counter as timer
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ class SpeakerEncoder(nn.Module):
12
+ def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True):
13
+ """
14
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
15
+ If None, defaults to cuda if it is available on your machine, otherwise the model will
16
+ run on cpu. Outputs are always returned on the cpu, as numpy arrays.
17
+ """
18
+ super().__init__()
19
+
20
+ # Define the network
21
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
22
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
23
+ self.relu = nn.ReLU()
24
+
25
+ # Get the target device
26
+ if device is None:
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ elif isinstance(device, str):
29
+ device = torch.device(device)
30
+ self.device = device
31
+
32
+ # Load the pretrained model'speaker weights
33
+ # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
34
+ # if not weights_fpath.exists():
35
+ # raise Exception("Couldn't find the voice encoder pretrained model at %s." %
36
+ # weights_fpath)
37
+
38
+ start = timer()
39
+ checkpoint = torch.load(weights_fpath, map_location="cpu")
40
+
41
+ self.load_state_dict(checkpoint["model_state"], strict=False)
42
+ self.to(device)
43
+
44
+ if verbose:
45
+ print("Loaded the voice encoder model on %s in %.2f seconds." %
46
+ (device.type, timer() - start))
47
+
48
+ def forward(self, mels: torch.FloatTensor):
49
+ """
50
+ Computes the embeddings of a batch of utterance spectrograms.
51
+ :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
52
+ (batch_size, n_frames, n_channels)
53
+ :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size).
54
+ Embeddings are positive and L2-normed, thus they lay in the range [0, 1].
55
+ """
56
+ # Pass the input through the LSTM layers and retrieve the final hidden state of the last
57
+ # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
58
+ _, (hidden, _) = self.lstm(mels)
59
+ embeds_raw = self.relu(self.linear(hidden[-1]))
60
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
61
+
62
+ @staticmethod
63
+ def compute_partial_slices(n_samples: int, rate, min_coverage):
64
+ """
65
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to
66
+ obtain partial utterances of <partials_n_frames> each. Both the waveform and the
67
+ mel spectrogram slices are returned, so as to make each partial utterance waveform
68
+ correspond to its spectrogram.
69
+
70
+ The returned ranges may be indexing further than the length of the waveform. It is
71
+ recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
72
+
73
+ :param n_samples: the number of samples in the waveform
74
+ :param rate: how many partial utterances should occur per second. Partial utterances must
75
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
76
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
77
+ the minimum rate is thus 0.625.
78
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
79
+ enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
80
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
81
+ it will be discarded. If there aren't enough frames for one partial utterance,
82
+ this parameter is ignored so that the function always returns at least one slice.
83
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
84
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
85
+ utterances.
86
+ """
87
+ assert 0 < min_coverage <= 1
88
+
89
+ # Compute how many frames separate two partial utterances
90
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
91
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
92
+ frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
93
+ assert 0 < frame_step, "The rate is too high"
94
+ assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
95
+ (sampling_rate / (samples_per_frame * partials_n_frames))
96
+
97
+ # Compute the slices
98
+ wav_slices, mel_slices = [], []
99
+ steps = max(1, n_frames - partials_n_frames + frame_step + 1)
100
+ for i in range(0, steps, frame_step):
101
+ mel_range = np.array([i, i + partials_n_frames])
102
+ wav_range = mel_range * samples_per_frame
103
+ mel_slices.append(slice(*mel_range))
104
+ wav_slices.append(slice(*wav_range))
105
+
106
+ # Evaluate whether extra padding is warranted or not
107
+ last_wav_range = wav_slices[-1]
108
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
109
+ if coverage < min_coverage and len(mel_slices) > 1:
110
+ mel_slices = mel_slices[:-1]
111
+ wav_slices = wav_slices[:-1]
112
+
113
+ return wav_slices, mel_slices
114
+
115
+ def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
116
+ """
117
+ Computes an embedding for a single utterance. The utterance is divided in partial
118
+ utterances and an embedding is computed for each. The complete utterance embedding is the
119
+ L2-normed average embedding of the partial utterances.
120
+
121
+ TODO: independent batched version of this function
122
+
123
+ :param wav: a preprocessed utterance waveform as a numpy array of float32
124
+ :param return_partials: if True, the partial embeddings will also be returned along with
125
+ the wav slices corresponding to each partial utterance.
126
+ :param rate: how many partial utterances should occur per second. Partial utterances must
127
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
128
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
129
+ the minimum rate is thus 0.625.
130
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
131
+ enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
132
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
133
+ it will be discarded. If there aren't enough frames for one partial utterance,
134
+ this parameter is ignored so that the function always returns at least one slice.
135
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
136
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
137
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
138
+ returned.
139
+ """
140
+ # Compute where to split the utterance into partials and pad the waveform with zeros if
141
+ # the partial utterances cover a larger range.
142
+ wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
143
+ max_wave_length = wav_slices[-1].stop
144
+ if max_wave_length >= len(wav):
145
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
146
+
147
+ # Split the utterance into partials and forward them through the model
148
+ mel = audio.wav_to_mel_spectrogram(wav)
149
+ mels = np.array([mel[s] for s in mel_slices])
150
+ with torch.no_grad():
151
+ mels = torch.from_numpy(mels).to(self.device)
152
+ partial_embeds = self(mels).cpu().numpy()
153
+
154
+ # Compute the utterance embedding from the partial embeddings
155
+ raw_embed = np.mean(partial_embeds, axis=0)
156
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
157
+
158
+ if return_partials:
159
+ return embed, partial_embeds, wav_slices
160
+ return embed
161
+
162
+ def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
163
+ """
164
+ Compute the embedding of a collection of wavs (presumably from the same speaker) by
165
+ averaging their embedding and L2-normalizing it.
166
+
167
+ :param wavs: list of wavs a numpy arrays of float32.
168
+ :param kwargs: extra arguments to embed_utterance()
169
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
170
+ """
171
+ raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \
172
+ for wav in wavs], axis=0)
173
+ return raw_embed / np.linalg.norm(raw_embed, 2)
tts_voice.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tts_order_voice = {'英语 (美国)-Jenny-女': 'en-US-JennyNeural',
2
+ '英语 (美国)-Guy-男': 'en-US-GuyNeural',
3
+ '英语 (美国)-Ana-女': 'en-US-AnaNeural',
4
+ '英语 (美国)-Aria-女': 'en-US-AriaNeural',
5
+ '英语 (美国)-Christopher-男': 'en-US-ChristopherNeural',
6
+ '英语 (美国)-Eric-男': 'en-US-EricNeural',
7
+ '英语 (美国)-Michelle-女': 'en-US-MichelleNeural',
8
+ '英语 (美国)-Roger-男': 'en-US-RogerNeural',
9
+ '西班牙语 (墨西哥)-Dalia-女': 'es-MX-DaliaNeural',
10
+ '西班牙语 (墨西哥)-Jorge-男': 'es-MX-JorgeNeural',
11
+ '韩语 (韩国)-Sun-Hi-女': 'ko-KR-SunHiNeural',
12
+ '韩语 (韩国)-InJoon-男': 'ko-KR-InJoonNeural',
13
+ '泰语 (泰国)-Premwadee-女': 'th-TH-PremwadeeNeural',
14
+ '泰语 (泰国)-Niwat-男': 'th-TH-NiwatNeural',
15
+ '越南语 (越南)-HoaiMy-女': 'vi-VN-HoaiMyNeural',
16
+ '越南语 (越南)-NamMinh-男': 'vi-VN-NamMinhNeural',
17
+ '日语 (日本)-Nanami-女': 'ja-JP-NanamiNeural',
18
+ '日语 (日本)-Keita-男': 'ja-JP-KeitaNeural',
19
+ '法语 (法国)-Denise-女': 'fr-FR-DeniseNeural',
20
+ '法语 (法国)-Eloise-女': 'fr-FR-EloiseNeural',
21
+ '法语 (法国)-Henri-男': 'fr-FR-HenriNeural',
22
+ '葡萄牙语 (巴西)-Francisca-女': 'pt-BR-FranciscaNeural',
23
+ '葡萄牙语 (巴西)-Antonio-男': 'pt-BR-AntonioNeural',
24
+ '印度尼西亚语 (印度尼西亚)-Ardi-男': 'id-ID-ArdiNeural',
25
+ '印度尼西亚语 (印度尼西亚)-Gadis-女': 'id-ID-GadisNeural',
26
+ '希伯来语 (以色列)-Avri-男': 'he-IL-AvriNeural',
27
+ '希伯来语 (以色列)-Hila-女': 'he-IL-HilaNeural',
28
+ '意大利语 (意大利)-Isabella-女': 'it-IT-IsabellaNeural',
29
+ '意大利语 (意大利)-Diego-男': 'it-IT-DiegoNeural',
30
+ '意大利语 (意大利)-Elsa-女': 'it-IT-ElsaNeural',
31
+ '荷兰语 (荷兰)-Colette-女': 'nl-NL-ColetteNeural',
32
+ '荷兰语 (荷兰)-Fenna-女': 'nl-NL-FennaNeural',
33
+ '荷兰语 (荷兰)-Maarten-男': 'nl-NL-MaartenNeural',
34
+ '马来语 (马来西亚)-Osman-男': 'ms-MY-OsmanNeural',
35
+ '马来语 (马来西亚)-Yasmin-女': 'ms-MY-YasminNeural',
36
+ '挪威语 (挪威)-Pernille-女': 'nb-NO-PernilleNeural',
37
+ '挪威语 (挪威)-Finn-男': 'nb-NO-FinnNeural',
38
+ '瑞典语 (瑞典)-Sofie-女': 'sv-SE-SofieNeural',
39
+ '瑞典语 (瑞典)-Mattias-男': 'sv-SE-MattiasNeural',
40
+ '阿拉伯语 (沙特阿拉伯)-Hamed-男': 'ar-SA-HamedNeural',
41
+ '阿拉伯语 (沙特阿拉伯)-Zariyah-女': 'ar-SA-ZariyahNeural',
42
+ '希腊语 (希腊)-Athina-女': 'el-GR-AthinaNeural',
43
+ '希腊语 (希腊)-Nestoras-男': 'el-GR-NestorasNeural',
44
+ '德语 (德国)-Katja-女': 'de-DE-KatjaNeural',
45
+ '德语 (德国)-Amala-女': 'de-DE-AmalaNeural',
46
+ '德语 (德国)-Conrad-男': 'de-DE-ConradNeural',
47
+ '德语 (德国)-Killian-男': 'de-DE-KillianNeural',
48
+ '阿拉伯语 (南非)-Adri-女': 'af-ZA-AdriNeural',
49
+ '阿拉伯语 (南非)-Willem-男': 'af-ZA-WillemNeural',
50
+ '阿姆哈拉语 (埃塞俄比亚)-Ameha-男': 'am-ET-AmehaNeural',
51
+ '阿姆哈拉语 (埃塞俄比亚)-Mekdes-女': 'am-ET-MekdesNeural',
52
+ '阿拉伯语 (阿拉伯联合酋长国)-Fatima-女': 'ar-AE-FatimaNeural',
53
+ '阿拉伯语 (阿拉伯联合酋长国)-Hamdan-男': 'ar-AE-HamdanNeural',
54
+ '阿拉伯语 (巴林)-Ali-男': 'ar-BH-AliNeural',
55
+ '阿拉伯语 (巴林)-Laila-女': 'ar-BH-LailaNeural',
56
+ '阿拉伯语 (阿尔及利亚)-Ismael-男': 'ar-DZ-IsmaelNeural',
57
+ '阿拉伯语 (埃及)-Salma-女': 'ar-EG-SalmaNeural',
58
+ '阿拉伯语 (埃及)-Shakir-男': 'ar-EG-ShakirNeural',
59
+ '阿拉伯语 (伊拉克)-Bassel-男': 'ar-IQ-BasselNeural',
60
+ '阿拉伯语 (伊拉克)-Rana-女': 'ar-IQ-RanaNeural',
61
+ '阿拉伯语 (约旦)-Sana-女': 'ar-JO-SanaNeural',
62
+ '阿拉伯语 (约旦)-Taim-男': 'ar-JO-TaimNeural',
63
+ '阿拉伯语 (科威特)-Fahed-男': 'ar-KW-FahedNeural',
64
+ '阿拉伯语 (科威特)-Noura-女': 'ar-KW-NouraNeural',
65
+ '阿拉伯语 (黎巴嫩)-Layla-女': 'ar-LB-LaylaNeural',
66
+ '阿拉伯语 (黎巴嫩)-Rami-男': 'ar-LB-RamiNeural',
67
+ '阿拉伯语 (利比亚)-Iman-女': 'ar-LY-ImanNeural',
68
+ '阿拉伯语 (利比亚)-Omar-男': 'ar-LY-OmarNeural',
69
+ '阿拉伯语 (摩洛哥)-Jamal-男': 'ar-MA-JamalNeural',
70
+ '阿拉伯语 (摩洛哥)-Mouna-女': 'ar-MA-MounaNeural',
71
+ '阿拉伯语 (阿曼)-Abdullah-男': 'ar-OM-AbdullahNeural',
72
+ '阿拉伯语 (阿曼)-Aysha-女': 'ar-OM-AyshaNeural',
73
+ '阿拉伯语 (卡塔尔)-Amal-女': 'ar-QA-AmalNeural',
74
+ '阿拉伯语 (卡塔尔)-Moaz-男': 'ar-QA-MoazNeural',
75
+ '阿拉伯语 (叙利亚)-Amany-女': 'ar-SY-AmanyNeural',
76
+ '阿拉伯语 (叙利亚)-Laith-男': 'ar-SY-LaithNeural',
77
+ '阿拉伯语 (突尼斯)-Hedi-男': 'ar-TN-HediNeural',
78
+ '阿拉伯语 (突尼斯)-Reem-女': 'ar-TN-ReemNeural',
79
+ '阿拉伯语 (也门)-Maryam-女': 'ar-YE-MaryamNeural',
80
+ '阿拉伯语 (也门)-Saleh-男': 'ar-YE-SalehNeural',
81
+ '阿塞拜疆语 (阿塞拜疆)-Babek-男': 'az-AZ-BabekNeural',
82
+ '阿塞拜疆语 (阿塞拜疆)-Banu-女': 'az-AZ-BanuNeural',
83
+ '保加利亚语 (保加利亚)-Borislav-男': 'bg-BG-BorislavNeural',
84
+ '保加利亚语 (保加利亚)-Kalina-女': 'bg-BG-KalinaNeural',
85
+ '孟加拉语 (孟加拉国)-Nabanita-女': 'bn-BD-NabanitaNeural',
86
+ '孟加拉语 (孟加拉国)-Pradeep-男': 'bn-BD-PradeepNeural',
87
+ '孟加拉语 (印度)-Bashkar-男': 'bn-IN-BashkarNeural',
88
+ '孟加拉语 (印度)-Tanishaa-女': 'bn-IN-TanishaaNeural',
89
+ '波斯尼亚语 (波斯尼亚和黑塞哥维那)-Goran-男': 'bs-BA-GoranNeural',
90
+ '波斯尼亚语 (波斯尼亚和黑塞哥维那)-Vesna-女': 'bs-BA-VesnaNeural',
91
+ '加泰罗尼亚语 (西班牙)-Joana-女': 'ca-ES-JoanaNeural',
92
+ '加泰罗尼亚语 (西班牙)-Enric-男': 'ca-ES-EnricNeural',
93
+ '捷克语 (捷克共和国)-Antonin-男': 'cs-CZ-AntoninNeural',
94
+ '捷克语 (捷克共和国)-Vlasta-女': 'cs-CZ-VlastaNeural',
95
+ '威尔士语 (英国)-Aled-男': 'cy-GB-AledNeural',
96
+ '威尔士语 (英国)-Nia-女': 'cy-GB-NiaNeural',
97
+ '丹麦语 (丹麦)-Christel-女': 'da-DK-ChristelNeural',
98
+ '丹麦语 (丹麦)-Jeppe-男': 'da-DK-JeppeNeural',
99
+ '德语 (奥地利)-Ingrid-女': 'de-AT-IngridNeural',
100
+ '德语 (奥地利)-Jonas-男': 'de-AT-JonasNeural',
101
+ '德语 (瑞士)-Jan-男': 'de-CH-JanNeural',
102
+ '德语 (瑞士)-Leni-女': 'de-CH-LeniNeural',
103
+ '英语 (澳大利亚)-Natasha-女': 'en-AU-NatashaNeural',
104
+ '英语 (澳大利亚)-William-男': 'en-AU-WilliamNeural',
105
+ '英语 (加拿大)-Clara-女': 'en-CA-ClaraNeural',
106
+ '英语 (加拿大)-Liam-男': 'en-CA-LiamNeural',
107
+ '英语 (英国)-Libby-女': 'en-GB-LibbyNeural',
108
+ '英语 (英国)-Maisie-女': 'en-GB-MaisieNeural',
109
+ '英语 (英国)-Ryan-男': 'en-GB-RyanNeural',
110
+ '英语 (英国)-Sonia-女': 'en-GB-SoniaNeural',
111
+ '英语 (英国)-Thomas-男': 'en-GB-ThomasNeural',
112
+ '英语 (香港)-Sam-男': 'en-HK-SamNeural',
113
+ '英语 (香港)-Yan-女': 'en-HK-YanNeural',
114
+ '英语 (爱尔兰)-Connor-男': 'en-IE-ConnorNeural',
115
+ '英语 (爱尔兰)-Emily-女': 'en-IE-EmilyNeural',
116
+ '英语 (印度)-Neerja-女': 'en-IN-NeerjaNeural',
117
+ '英语 (印度)-Prabhat-男': 'en-IN-PrabhatNeural',
118
+ '英语 (肯尼亚)-Asilia-女': 'en-KE-AsiliaNeural',
119
+ '英语 (肯尼亚)-Chilemba-男': 'en-KE-ChilembaNeural',
120
+ '英语 (尼日利亚)-Abeo-男': 'en-NG-AbeoNeural',
121
+ '英语 (尼日利亚)-Ezinne-女': 'en-NG-EzinneNeural',
122
+ '英语 (新西兰)-Mitchell-男': 'en-NZ-MitchellNeural',
123
+ '英语 (菲律宾)-James-男': 'en-PH-JamesNeural',
124
+ '英语 (菲律宾)-Rosa-女': 'en-PH-RosaNeural',
125
+ '英语 (新加坡)-Luna-女': 'en-SG-LunaNeural',
126
+ '英语 (新加坡)-Wayne-男': 'en-SG-WayneNeural',
127
+ '英语 (坦桑尼亚)-Elimu-男': 'en-TZ-ElimuNeural',
128
+ '英语 (坦桑尼亚)-Imani-女': 'en-TZ-ImaniNeural',
129
+ '英语 (南非)-Leah-女': 'en-ZA-LeahNeural',
130
+ '英语 (南非)-Luke-男': 'en-ZA-LukeNeural',
131
+ '西班牙语 (阿根廷)-Elena-女': 'es-AR-ElenaNeural',
132
+ '西班牙语 (阿根廷)-Tomas-男': 'es-AR-TomasNeural',
133
+ '西班牙语 (玻利维亚)-Marcelo-男': 'es-BO-MarceloNeural',
134
+ '西班牙语 (玻利维亚)-Sofia-女': 'es-BO-SofiaNeural',
135
+ '西班牙语 (哥伦比亚)-Gonzalo-男': 'es-CO-GonzaloNeural',
136
+ '西班牙语 (哥伦比亚)-Salome-女': 'es-CO-SalomeNeural',
137
+ '西班牙语 (哥斯达黎加)-Juan-男': 'es-CR-JuanNeural',
138
+ '西班牙语 (哥斯达黎加)-Maria-女': 'es-CR-MariaNeural',
139
+ '西班牙语 (古巴)-Belkys-女': 'es-CU-BelkysNeural',
140
+ '西班牙语 (多米尼加共和国)-Emilio-男': 'es-DO-EmilioNeural',
141
+ '西班牙语 (多米尼加共和国)-Ramona-女': 'es-DO-RamonaNeural',
142
+ '西班牙语 (厄瓜多尔)-Andrea-女': 'es-EC-AndreaNeural',
143
+ '西班牙语 (厄瓜多尔)-Luis-男': 'es-EC-LuisNeural',
144
+ '西班牙语 (西班牙)-Alvaro-男': 'es-ES-AlvaroNeural',
145
+ '西班牙语 (西班牙)-Elvira-女': 'es-ES-ElviraNeural',
146
+ '西班牙语 (赤道几内亚)-Teresa-女': 'es-GQ-TeresaNeural',
147
+ '西班牙语 (危地马拉)-Andres-男': 'es-GT-AndresNeural',
148
+ '西班牙语 (危地马拉)-Marta-女': 'es-GT-MartaNeural',
149
+ '西班牙语 (洪都拉斯)-Carlos-男': 'es-HN-CarlosNeural',
150
+ '西班牙语 (洪都拉斯)-Karla-女': 'es-HN-KarlaNeural',
151
+ '西班牙语 (尼加拉瓜)-Federico-男': 'es-NI-FedericoNeural',
152
+ '西班牙语 (尼加拉瓜)-Yolanda-女': 'es-NI-YolandaNeural',
153
+ '西班牙语 (巴拿马)-Margarita-女': 'es-PA-MargaritaNeural',
154
+ '西班牙语 (巴拿马)-Roberto-男': 'es-PA-RobertoNeural',
155
+ '西班牙语 (秘鲁)-Alex-男': 'es-PE-AlexNeural',
156
+ '西班牙语 (秘鲁)-Camila-女': 'es-PE-CamilaNeural',
157
+ '西班牙语 (波多黎各)-Karina-女': 'es-PR-KarinaNeural',
158
+ '西班牙语 (波多黎各)-Victor-男': 'es-PR-VictorNeural',
159
+ '西班牙语 (巴拉圭)-Mario-男': 'es-PY-MarioNeural',
160
+ '西班牙语 (巴拉圭)-Tania-女': 'es-PY-TaniaNeural',
161
+ '西班牙语 (萨尔瓦多)-Lorena-女': 'es-SV-LorenaNeural',
162
+ '西班牙语 (萨尔瓦多)-Rodrigo-男': 'es-SV-RodrigoNeural',
163
+ '西班牙语 (美国)-Alonso-男': 'es-US-AlonsoNeural',
164
+ '西班牙语 (美国)-Paloma-女': 'es-US-PalomaNeural',
165
+ '西班牙语 (乌拉圭)-Mateo-男': 'es-UY-MateoNeural',
166
+ '西班牙语 (乌拉圭)-Valentina-女': 'es-UY-ValentinaNeural',
167
+ '西班牙语 (委内瑞拉)-Paola-女': 'es-VE-PaolaNeural',
168
+ '西班牙语 (委内瑞拉)-Sebastian-男': 'es-VE-SebastianNeural',
169
+ '爱沙尼亚语 (爱沙尼亚)-Anu-���': 'et-EE-AnuNeural',
170
+ '爱沙尼亚语 (爱沙尼亚)-Kert-男': 'et-EE-KertNeural',
171
+ '波斯语 (伊朗)-Dilara-女': 'fa-IR-DilaraNeural',
172
+ '波斯语 (伊朗)-Farid-男': 'fa-IR-FaridNeural',
173
+ '芬兰语 (芬兰)-Harri-男': 'fi-FI-HarriNeural',
174
+ '芬兰语 (芬兰)-Noora-女': 'fi-FI-NooraNeural',
175
+ '法语 (比利时)-Charline-女': 'fr-BE-CharlineNeural',
176
+ '法语 (比利时)-Gerard-男': 'fr-BE-GerardNeural',
177
+ '法语 (加拿大)-Sylvie-女': 'fr-CA-SylvieNeural',
178
+ '法语 (加拿大)-Antoine-男': 'fr-CA-AntoineNeural',
179
+ '法语 (加拿大)-Jean-男': 'fr-CA-JeanNeural',
180
+ '法语 (瑞士)-Ariane-女': 'fr-CH-ArianeNeural',
181
+ '法语 (瑞士)-Fabrice-男': 'fr-CH-FabriceNeural',
182
+ '爱尔兰语 (爱尔兰)-Colm-男': 'ga-IE-ColmNeural',
183
+ '爱尔兰语 (爱尔兰)-Orla-女': 'ga-IE-OrlaNeural',
184
+ '加利西亚语 (西班牙)-Roi-男': 'gl-ES-RoiNeural',
185
+ '加利西亚语 (西班牙)-Sabela-女': 'gl-ES-SabelaNeural',
186
+ '古吉拉特语 (印度)-Dhwani-女': 'gu-IN-DhwaniNeural',
187
+ '古吉拉特语 (印度)-Niranjan-男': 'gu-IN-NiranjanNeural',
188
+ '印地语 (印度)-Madhur-男': 'hi-IN-MadhurNeural',
189
+ '印地语 (印度)-Swara-女': 'hi-IN-SwaraNeural',
190
+ '克罗地亚语 (克罗地亚)-Gabrijela-女': 'hr-HR-GabrijelaNeural',
191
+ '克罗地亚语 (克罗地亚)-Srecko-男': 'hr-HR-SreckoNeural',
192
+ '匈牙利语 (匈牙利)-Noemi-女': 'hu-HU-NoemiNeural',
193
+ '匈牙利语 (匈牙利)-Tamas-男': 'hu-HU-TamasNeural',
194
+ '冰岛语 (冰岛)-Gudrun-女': 'is-IS-GudrunNeural',
195
+ '冰岛语 (冰岛)-Gunnar-男': 'is-IS-GunnarNeural',
196
+ '爪哇语 (印度尼西亚)-Dimas-男': 'jv-ID-DimasNeural',
197
+ '爪哇语 (印度尼西亚)-Siti-女': 'jv-ID-SitiNeural',
198
+ '格鲁吉亚语 (格鲁吉亚)-Eka-女': 'ka-GE-EkaNeural',
199
+ '格鲁吉亚语 (格鲁吉亚)-Giorgi-男': 'ka-GE-GiorgiNeural',
200
+ '哈萨克语 (哈萨克斯坦)-Aigul-女': 'kk-KZ-AigulNeural',
201
+ '哈萨克语 (哈萨克斯坦)-Daulet-男': 'kk-KZ-DauletNeural',
202
+ '高棉语 (柬埔寨)-Piseth-男': 'km-KH-PisethNeural',
203
+ '高棉语 (柬埔寨)-Sreymom-女': 'km-KH-SreymomNeural',
204
+ '卡纳达语 (印度)-Gagan-男': 'kn-IN-GaganNeural',
205
+ '卡纳达语 (印度)-Sapna-女': 'kn-IN-SapnaNeural',
206
+ '老挝语 (老挝)-Chanthavong-男': 'lo-LA-ChanthavongNeural',
207
+ '老挝语 (老挝)-Keomany-女': 'lo-LA-KeomanyNeural',
208
+ '立陶宛语 (立陶宛)-Leonas-男': 'lt-LT-LeonasNeural',
209
+ '立陶宛语 (立陶宛)-Ona-女': 'lt-LT-OnaNeural',
210
+ '拉脱维亚语 (拉脱维亚)-Everita-女': 'lv-LV-EveritaNeural',
211
+ '拉脱维亚语 (拉脱维亚)-Nils-男': 'lv-LV-NilsNeural',
212
+ '马其顿语 (北马其顿共和国)-Aleksandar-男': 'mk-MK-AleksandarNeural',
213
+ '马其顿语 (北马其顿共和国)-Marija-女': 'mk-MK-MarijaNeural',
214
+ '马拉雅拉姆语 (印度)-Midhun-男': 'ml-IN-MidhunNeural',
215
+ '马拉雅拉姆语 (印度)-Sobhana-女': 'ml-IN-SobhanaNeural',
216
+ '蒙古语 (蒙古)-Bataa-男': 'mn-MN-BataaNeural',
217
+ '蒙古语 (蒙古)-Yesui-女': 'mn-MN-YesuiNeural',
218
+ '马拉地语 (印度)-Aarohi-女': 'mr-IN-AarohiNeural',
219
+ '马拉地语 (印度)-Manohar-男': 'mr-IN-ManoharNeural',
220
+ '马耳他语 (马耳他)-Grace-女': 'mt-MT-GraceNeural',
221
+ '马耳他语 (马耳他)-Joseph-男': 'mt-MT-JosephNeural',
222
+ '缅甸语 (缅甸)-Nilar-女': 'my-MM-NilarNeural',
223
+ '缅甸语 (缅甸)-Thiha-男': 'my-MM-ThihaNeural',
224
+ '尼泊尔语 (尼泊尔)-Hemkala-女': 'ne-NP-HemkalaNeural',
225
+ '尼泊尔语 (尼泊尔)-Sagar-男': 'ne-NP-SagarNeural',
226
+ '荷兰语 (比利时)-Arnaud-男': 'nl-BE-ArnaudNeural',
227
+ '荷兰语 (比利时)-Dena-女': 'nl-BE-DenaNeural',
228
+ '波兰语 (波兰)-Marek-男': 'pl-PL-MarekNeural',
229
+ '波兰语 (波兰)-Zofia-女': 'pl-PL-ZofiaNeural',
230
+ '普什图语 (阿富汗)-Gul Nawaz-男': 'ps-AF-GulNawazNeural',
231
+ '普什图语 (阿富汗)-Latifa-女': 'ps-AF-LatifaNeural',
232
+ '葡萄牙语 (葡萄牙)-Duarte-男': 'pt-PT-DuarteNeural',
233
+ '葡萄牙语 (葡萄牙)-Raquel-女': 'pt-PT-RaquelNeural',
234
+ '罗马尼亚语 (罗马尼亚)-Alina-女': 'ro-RO-AlinaNeural',
235
+ '罗马尼亚语 (罗马尼亚)-Emil-男': 'ro-RO-EmilNeural',
236
+ '俄语 (俄罗斯)-Svetlana-女': 'ru-RU-SvetlanaNeural',
237
+ '俄语 (俄罗斯)-Dmitry-男': 'ru-RU-DmitryNeural',
238
+ '僧伽罗语 (斯里兰卡)-Sameera-男': 'si-LK-SameeraNeural',
239
+ '僧伽罗语 (斯里兰卡)-Thilini-女': 'si-LK-ThiliniNeural',
240
+ '斯洛伐克语 (斯洛伐克)-Lukas-男': 'sk-SK-LukasNeural',
241
+ '斯洛伐克语 (斯洛伐克)-Viktoria-女': 'sk-SK-ViktoriaNeural',
242
+ '斯洛文尼亚语 (斯洛文尼亚)-Petra-女': 'sl-SI-PetraNeural',
243
+ '斯洛文尼亚语 (斯洛文尼亚)-Rok-男': 'sl-SI-RokNeural',
244
+ '索马里语 (索马里)-Muuse-男': 'so-SO-MuuseNeural',
245
+ '索马里语 (索马里)-Ubax-女': 'so-SO-UbaxNeural',
246
+ '阿尔巴尼亚语 (阿尔巴尼亚)-Anila-女': 'sq-AL-AnilaNeural',
247
+ '阿尔巴尼亚语 (阿尔巴尼亚)-Ilir-男': 'sq-AL-IlirNeural',
248
+ '塞尔维亚语 (塞尔维亚)-Nicholas-男': 'sr-RS-NicholasNeural',
249
+ '塞尔维亚语 (塞尔维亚)-Sophie-女': 'sr-RS-SophieNeural',
250
+ '巽他语 (印度尼西亚)-Jajang-男': 'su-ID-JajangNeural',
251
+ '巽他语 (印度尼��亚)-Tuti-女': 'su-ID-TutiNeural',
252
+ '斯瓦希里语 (肯尼亚)-Rafiki-男': 'sw-KE-RafikiNeural',
253
+ '斯瓦希里语 (肯尼亚)-Zuri-女': 'sw-KE-ZuriNeural',
254
+ '斯瓦希里语 (坦桑尼亚)-Daudi-男': 'sw-TZ-DaudiNeural',
255
+ '斯瓦希里语 (坦桑尼亚)-Rehema-女': 'sw-TZ-RehemaNeural',
256
+ '泰米尔语 (印度)-Pallavi-女': 'ta-IN-PallaviNeural',
257
+ '泰米尔语 (印度)-Valluvar-男': 'ta-IN-ValluvarNeural',
258
+ '泰米尔语 (斯里兰卡)-Kumar-男': 'ta-LK-KumarNeural',
259
+ '泰米尔语 (斯里兰卡)-Saranya-女': 'ta-LK-SaranyaNeural',
260
+ '泰米尔语 (马来西亚)-Kani-女': 'ta-MY-KaniNeural',
261
+ '泰米尔语 (马来西亚)-Surya-男': 'ta-MY-SuryaNeural',
262
+ '泰米尔语 (新加坡)-Anbu-男': 'ta-SG-AnbuNeural',
263
+ '泰卢固语 (印度)-Mohan-男': 'te-IN-MohanNeural',
264
+ '泰卢固语 (印度)-Shruti-女': 'te-IN-ShrutiNeural',
265
+ '土耳其语 (土耳其)-Ahmet-男': 'tr-TR-AhmetNeural',
266
+ '土耳其语 (土耳其)-Emel-女': 'tr-TR-EmelNeural',
267
+ '乌克兰语 (乌克兰)-Ostap-男': 'uk-UA-OstapNeural',
268
+ '乌克兰语 (乌克兰)-Polina-女': 'uk-UA-PolinaNeural',
269
+ '乌尔都语 (印度)-Gul-女': 'ur-IN-GulNeural',
270
+ '乌尔都语 (印度)-Salman-男': 'ur-IN-SalmanNeural',
271
+ '乌尔都语 (巴基斯坦)-Asad-男': 'ur-PK-AsadNeural',
272
+ '乌尔都语 (巴基斯坦)-Uzma-女': 'ur-PK-UzmaNeural',
273
+ '乌兹别克语 (乌兹别克斯坦)-Madina-女': 'uz-UZ-MadinaNeural',
274
+ '乌兹别克语 (乌兹别克斯坦)-Sardor-男': 'uz-UZ-SardorNeural',
275
+ '普通话 (中国大陆)-Xiaoxiao-女': 'zh-CN-XiaoxiaoNeural',
276
+ '普通话 (中国大陆)-Yunyang-男': 'zh-CN-YunyangNeural',
277
+ '普通话 (中国大陆)-Yunxi-男': 'zh-CN-YunxiNeural',
278
+ '普通话 (中国大陆)-Xiaoyi-女': 'zh-CN-XiaoyiNeural',
279
+ '普通话 (中国大陆)-Yunjian-男': 'zh-CN-YunjianNeural',
280
+ '普通话 (中国大陆)-Yunxia-男': 'zh-CN-YunxiaNeural',
281
+ '东北话 (中国大陆)-Xiaobei-女': 'zh-CN-liaoning-XiaobeiNeural',
282
+ '中原官话 (中国陕西)-Xiaoni-女': 'zh-CN-shaanxi-XiaoniNeural',
283
+ '粤语 (中国香港)-HiuMaan-女': 'zh-HK-HiuMaanNeural',
284
+ '粤语 (中国香港)-HiuGaai-女': 'zh-HK-HiuGaaiNeural',
285
+ '粤语 (中国香港)-WanLung-男': 'zh-HK-WanLungNeural',
286
+ '台湾普通话-HsiaoChen-女': 'zh-TW-HsiaoChenNeural',
287
+ '台湾普通话-HsiaoYu-女': 'zh-TW-HsiaoYuNeural',
288
+ '台湾普通话-YunJhe-男': 'zh-TW-YunJheNeural',
289
+ '祖鲁语 (南非)-Thando-女': 'zu-ZA-ThandoNeural',
290
+ '祖鲁语 (南非)-Themba-男': 'zu-ZA-ThembaNeural'}
utils.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import logging
5
+ import json
6
+ import subprocess
7
+ import numpy as np
8
+ from scipy.io.wavfile import read
9
+ import torch
10
+ from torch.nn import functional as F
11
+ from commons import sequence_mask
12
+
13
+ MATPLOTLIB_FLAG = False
14
+
15
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
16
+ logger = logging
17
+
18
+
19
+ def get_cmodel(rank):
20
+ checkpoint = torch.load('wavlm/WavLM-Large.pt')
21
+ cfg = WavLMConfig(checkpoint['cfg'])
22
+ cmodel = WavLM(cfg).cuda(rank)
23
+ cmodel.load_state_dict(checkpoint['model'])
24
+ cmodel.eval()
25
+ return cmodel
26
+
27
+
28
+ def get_content(cmodel, y):
29
+ with torch.no_grad():
30
+ c = cmodel.extract_features(y.squeeze(1))[0]
31
+ c = c.transpose(1, 2)
32
+ return c
33
+
34
+
35
+ def get_vocoder(rank):
36
+ with open("hifigan/config.json", "r") as f:
37
+ config = json.load(f)
38
+ config = hifigan.AttrDict(config)
39
+ vocoder = hifigan.Generator(config)
40
+ ckpt = torch.load("hifigan/generator_v1")
41
+ vocoder.load_state_dict(ckpt["generator"])
42
+ vocoder.eval()
43
+ vocoder.remove_weight_norm()
44
+ vocoder.cuda(rank)
45
+ return vocoder
46
+
47
+
48
+ def transform(mel, height): # 68-92
49
+ #r = np.random.random()
50
+ #rate = r * 0.3 + 0.85 # 0.85-1.15
51
+ #height = int(mel.size(-2) * rate)
52
+ tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1)))
53
+ if height >= mel.size(-2):
54
+ return tgt[:, :mel.size(-2), :]
55
+ else:
56
+ silence = tgt[:,-1:,:].repeat(1,mel.size(-2)-height,1)
57
+ silence += torch.randn_like(silence) / 10
58
+ return torch.cat((tgt, silence), 1)
59
+
60
+
61
+ def stretch(mel, width): # 0.5-2
62
+ return torchvision.transforms.functional.resize(mel, (mel.size(-2), width))
63
+
64
+
65
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
66
+ assert os.path.isfile(checkpoint_path)
67
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
68
+ iteration = checkpoint_dict['iteration']
69
+ learning_rate = checkpoint_dict['learning_rate']
70
+ if optimizer is not None:
71
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
72
+ saved_state_dict = checkpoint_dict['model']
73
+ if hasattr(model, 'module'):
74
+ state_dict = model.module.state_dict()
75
+ else:
76
+ state_dict = model.state_dict()
77
+ new_state_dict= {}
78
+ for k, v in state_dict.items():
79
+ try:
80
+ new_state_dict[k] = saved_state_dict[k]
81
+ except:
82
+ logger.info("%s is not in the checkpoint" % k)
83
+ new_state_dict[k] = v
84
+ if hasattr(model, 'module'):
85
+ model.module.load_state_dict(new_state_dict)
86
+ else:
87
+ model.load_state_dict(new_state_dict)
88
+ logger.info("Loaded checkpoint '{}' (iteration {})" .format(
89
+ checkpoint_path, iteration))
90
+ return model, optimizer, learning_rate, iteration
91
+
92
+
93
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
94
+ logger.info("Saving model and optimizer state at iteration {} to {}".format(
95
+ iteration, checkpoint_path))
96
+ if hasattr(model, 'module'):
97
+ state_dict = model.module.state_dict()
98
+ else:
99
+ state_dict = model.state_dict()
100
+ torch.save({'model': state_dict,
101
+ 'iteration': iteration,
102
+ 'optimizer': optimizer.state_dict(),
103
+ 'learning_rate': learning_rate}, checkpoint_path)
104
+
105
+
106
+ def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
107
+ for k, v in scalars.items():
108
+ writer.add_scalar(k, v, global_step)
109
+ for k, v in histograms.items():
110
+ writer.add_histogram(k, v, global_step)
111
+ for k, v in images.items():
112
+ writer.add_image(k, v, global_step, dataformats='HWC')
113
+ for k, v in audios.items():
114
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
115
+
116
+
117
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
118
+ f_list = glob.glob(os.path.join(dir_path, regex))
119
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
120
+ x = f_list[-1]
121
+ print(x)
122
+ return x
123
+
124
+
125
+ def plot_spectrogram_to_numpy(spectrogram):
126
+ global MATPLOTLIB_FLAG
127
+ if not MATPLOTLIB_FLAG:
128
+ import matplotlib
129
+ matplotlib.use("Agg")
130
+ MATPLOTLIB_FLAG = True
131
+ mpl_logger = logging.getLogger('matplotlib')
132
+ mpl_logger.setLevel(logging.WARNING)
133
+ import matplotlib.pylab as plt
134
+ import numpy as np
135
+
136
+ fig, ax = plt.subplots(figsize=(10,2))
137
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
138
+ interpolation='none')
139
+ plt.colorbar(im, ax=ax)
140
+ plt.xlabel("Frames")
141
+ plt.ylabel("Channels")
142
+ plt.tight_layout()
143
+
144
+ fig.canvas.draw()
145
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
146
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
147
+ plt.close()
148
+ return data
149
+
150
+
151
+ def plot_alignment_to_numpy(alignment, info=None):
152
+ global MATPLOTLIB_FLAG
153
+ if not MATPLOTLIB_FLAG:
154
+ import matplotlib
155
+ matplotlib.use("Agg")
156
+ MATPLOTLIB_FLAG = True
157
+ mpl_logger = logging.getLogger('matplotlib')
158
+ mpl_logger.setLevel(logging.WARNING)
159
+ import matplotlib.pylab as plt
160
+ import numpy as np
161
+
162
+ fig, ax = plt.subplots(figsize=(6, 4))
163
+ im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
164
+ interpolation='none')
165
+ fig.colorbar(im, ax=ax)
166
+ xlabel = 'Decoder timestep'
167
+ if info is not None:
168
+ xlabel += '\n\n' + info
169
+ plt.xlabel(xlabel)
170
+ plt.ylabel('Encoder timestep')
171
+ plt.tight_layout()
172
+
173
+ fig.canvas.draw()
174
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
175
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
176
+ plt.close()
177
+ return data
178
+
179
+
180
+ def load_wav_to_torch(full_path):
181
+ sampling_rate, data = read(full_path)
182
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
183
+
184
+
185
+ def load_filepaths_and_text(filename, split="|"):
186
+ with open(filename, encoding='utf-8') as f:
187
+ filepaths_and_text = [line.strip().split(split) for line in f]
188
+ return filepaths_and_text
189
+
190
+
191
+ def get_hparams(init=True):
192
+ parser = argparse.ArgumentParser()
193
+ parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
194
+ help='JSON file for configuration')
195
+ parser.add_argument('-m', '--model', type=str, required=True,
196
+ help='Model name')
197
+
198
+ args = parser.parse_args()
199
+ model_dir = os.path.join("./logs", args.model)
200
+
201
+ if not os.path.exists(model_dir):
202
+ os.makedirs(model_dir)
203
+
204
+ config_path = args.config
205
+ config_save_path = os.path.join(model_dir, "config.json")
206
+ if init:
207
+ with open(config_path, "r") as f:
208
+ data = f.read()
209
+ with open(config_save_path, "w") as f:
210
+ f.write(data)
211
+ else:
212
+ with open(config_save_path, "r") as f:
213
+ data = f.read()
214
+ config = json.loads(data)
215
+
216
+ hparams = HParams(**config)
217
+ hparams.model_dir = model_dir
218
+ return hparams
219
+
220
+
221
+ def get_hparams_from_dir(model_dir):
222
+ config_save_path = os.path.join(model_dir, "config.json")
223
+ with open(config_save_path, "r") as f:
224
+ data = f.read()
225
+ config = json.loads(data)
226
+
227
+ hparams =HParams(**config)
228
+ hparams.model_dir = model_dir
229
+ return hparams
230
+
231
+
232
+ def get_hparams_from_file(config_path):
233
+ with open(config_path, "r") as f:
234
+ data = f.read()
235
+ config = json.loads(data)
236
+
237
+ hparams =HParams(**config)
238
+ return hparams
239
+
240
+
241
+ def check_git_hash(model_dir):
242
+ source_dir = os.path.dirname(os.path.realpath(__file__))
243
+ if not os.path.exists(os.path.join(source_dir, ".git")):
244
+ logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
245
+ source_dir
246
+ ))
247
+ return
248
+
249
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
250
+
251
+ path = os.path.join(model_dir, "githash")
252
+ if os.path.exists(path):
253
+ saved_hash = open(path).read()
254
+ if saved_hash != cur_hash:
255
+ logger.warn("git hash values are different. {}(saved) != {}(current)".format(
256
+ saved_hash[:8], cur_hash[:8]))
257
+ else:
258
+ open(path, "w").write(cur_hash)
259
+
260
+
261
+ def get_logger(model_dir, filename="train.log"):
262
+ global logger
263
+ logger = logging.getLogger(os.path.basename(model_dir))
264
+ logger.setLevel(logging.DEBUG)
265
+
266
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
267
+ if not os.path.exists(model_dir):
268
+ os.makedirs(model_dir)
269
+ h = logging.FileHandler(os.path.join(model_dir, filename))
270
+ h.setLevel(logging.DEBUG)
271
+ h.setFormatter(formatter)
272
+ logger.addHandler(h)
273
+ return logger
274
+
275
+
276
+ class HParams():
277
+ def __init__(self, **kwargs):
278
+ for k, v in kwargs.items():
279
+ if type(v) == dict:
280
+ v = HParams(**v)
281
+ self[k] = v
282
+
283
+ def keys(self):
284
+ return self.__dict__.keys()
285
+
286
+ def items(self):
287
+ return self.__dict__.items()
288
+
289
+ def values(self):
290
+ return self.__dict__.values()
291
+
292
+ def __len__(self):
293
+ return len(self.__dict__)
294
+
295
+ def __getitem__(self, key):
296
+ return getattr(self, key)
297
+
298
+ def __setitem__(self, key, value):
299
+ return setattr(self, key, value)
300
+
301
+ def __contains__(self, key):
302
+ return key in self.__dict__
303
+
304
+ def __repr__(self):
305
+ return self.__dict__.__repr__()