qxdn committed on
Commit
3787550
1 Parent(s): 2493faf

init space

Browse files
.gitignore ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ .DS_Store
Data/configs/haruka.json ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 42,
6
+ "epochs": 1000,
7
+ "learning_rate": 0.0002,
8
+ "betas": [
9
+ 0.8,
10
+ 0.99
11
+ ],
12
+ "eps": 1e-09,
13
+ "batch_size": 4,
14
+ "fp16_run": false,
15
+ "lr_decay": 0.99995,
16
+ "segment_size": 16384,
17
+ "init_lr_ratio": 1,
18
+ "warmup_epochs": 0,
19
+ "c_mel": 45,
20
+ "c_kl": 1.0,
21
+ "skip_optimizer": true
22
+ },
23
+ "data": {
24
+ "training_files": "Data/filelists/train.list",
25
+ "validation_files": "Data/filelists/val.list",
26
+ "max_wav_value": 32768.0,
27
+ "sampling_rate": 44100,
28
+ "filter_length": 2048,
29
+ "hop_length": 512,
30
+ "win_length": 2048,
31
+ "n_mel_channels": 128,
32
+ "mel_fmin": 0.0,
33
+ "mel_fmax": null,
34
+ "add_blank": true,
35
+ "n_speakers": 4,
36
+ "cleaned_text": true,
37
+ "spk2id": {
38
+ "希亜": 0,
39
+ "天": 1,
40
+ "春風": 2,
41
+ "都": 3
42
+ }
43
+ },
44
+ "model": {
45
+ "use_spk_conditioned_encoder": true,
46
+ "use_noise_scaled_mas": true,
47
+ "use_mel_posterior_encoder": false,
48
+ "use_duration_discriminator": true,
49
+ "inter_channels": 192,
50
+ "hidden_channels": 192,
51
+ "filter_channels": 768,
52
+ "n_heads": 2,
53
+ "n_layers": 6,
54
+ "kernel_size": 3,
55
+ "p_dropout": 0.1,
56
+ "resblock": "1",
57
+ "resblock_kernel_sizes": [
58
+ 3,
59
+ 7,
60
+ 11
61
+ ],
62
+ "resblock_dilation_sizes": [
63
+ [
64
+ 1,
65
+ 3,
66
+ 5
67
+ ],
68
+ [
69
+ 1,
70
+ 3,
71
+ 5
72
+ ],
73
+ [
74
+ 1,
75
+ 3,
76
+ 5
77
+ ]
78
+ ],
79
+ "upsample_rates": [
80
+ 8,
81
+ 8,
82
+ 2,
83
+ 2,
84
+ 2
85
+ ],
86
+ "upsample_initial_channel": 512,
87
+ "upsample_kernel_sizes": [
88
+ 16,
89
+ 16,
90
+ 8,
91
+ 2,
92
+ 2
93
+ ],
94
+ "n_layers_q": 3,
95
+ "use_spectral_norm": false,
96
+ "gin_channels": 256
97
+ },
98
+ "version": "2.1"
99
+ }
Data/models/G_82000.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9cff7a84b5bc7fbb44df429e1737132eebe0a01588acb0e7c2a30cb638ac08f5
3
+ size 718770997
app.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E402
2
+ import os
3
+ import logging
4
+ import re_matching
5
+ from tools.sentence import split_by_language
6
+
7
+ logging.getLogger("numba").setLevel(logging.WARNING)
8
+ logging.getLogger("markdown_it").setLevel(logging.WARNING)
9
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
10
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
11
+
12
+ logging.basicConfig(
13
+ level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
14
+ )
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ import torch
19
+ import utils
20
+ from infer import infer, latest_version, get_net_g, infer_multilang
21
+ import gradio as gr
22
+ import webbrowser
23
+ import numpy as np
24
+ from config import config
25
+ from tools.translate import translate
26
+ import librosa
27
+
28
+ net_g = None
29
+
30
+ device = config.webui_config.device
31
+ if device == "mps":
32
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
33
+
34
+
35
def generate_audio(
    slices,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    speaker,
    language,
    reference_audio,
    emotion,
    skip_start=False,
    skip_end=False,
):
    """Run single-language inference on every text slice and collect the audio.

    Each element of ``slices`` is passed to ``infer`` together with the
    module-level ``hps``, ``net_g`` and ``device``; every result is converted
    with Gradio's ``convert_to_16_bit_wav`` before it is collected.

    Args:
        slices: Sized sequence of text pieces, synthesised in order.
        sdp_ratio: Forwarded unchanged to ``infer``.
        noise_scale: Forwarded unchanged to ``infer``.
        noise_scale_w: Forwarded unchanged to ``infer``.
        length_scale: Forwarded unchanged to ``infer``.
        speaker: Forwarded to ``infer`` as ``sid``.
        language: Forwarded unchanged to ``infer`` for every slice.
        reference_audio: Forwarded unchanged to ``infer``.
        emotion: Forwarded unchanged to ``infer``.
        skip_start: Initial value of the ``skip_start`` flag (see note below).
        skip_end: Initial value of the ``skip_end`` flag (see note below).

    Returns:
        A list with one converted waveform per slice (empty for no slices).
    """
    waveforms = []
    with torch.no_grad():
        for position, segment in enumerate(slices):
            # The flags are re-bound, not recomputed from the arguments:
            # skip_start is cleared on the first slice and therefore stays
            # False for every later slice, while skip_end keeps the caller's
            # value until it is cleared on the last slice.
            # NOTE(review): confirm the sticky skip_start is intended.
            if position == 0:
                skip_start = False
            if position == len(slices) - 1:
                skip_end = False
            raw_audio = infer(
                segment,
                reference_audio=reference_audio,
                emotion=emotion,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
                hps=hps,
                net_g=net_g,
                device=device,
                skip_start=skip_start,
                skip_end=skip_end,
            )
            waveforms.append(gr.processing_utils.convert_to_16_bit_wav(raw_audio))
    return waveforms
74
+
75
+
76
def generate_audio_multilang(
    slices,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    speaker,
    language,
    reference_audio,
    emotion,
    skip_start=False,
    skip_end=False,
):
    """Run multi-language inference on every text slice and collect the audio.

    Works like ``generate_audio`` but calls ``infer_multilang`` and picks the
    language entry that belongs to each slice (``language[i]`` for
    ``slices[i]``). The module-level ``hps``, ``net_g`` and ``device`` are
    forwarded, and every result is converted with Gradio's
    ``convert_to_16_bit_wav``.

    Args:
        slices: Sized sequence of text pieces, synthesised in order.
        sdp_ratio: Forwarded unchanged to ``infer_multilang``.
        noise_scale: Forwarded unchanged to ``infer_multilang``.
        noise_scale_w: Forwarded unchanged to ``infer_multilang``.
        length_scale: Forwarded unchanged to ``infer_multilang``.
        speaker: Forwarded to ``infer_multilang`` as ``sid``.
        language: Indexable collection holding one entry per slice.
        reference_audio: Forwarded unchanged to ``infer_multilang``.
        emotion: Forwarded unchanged to ``infer_multilang``.
        skip_start: Initial value of the ``skip_start`` flag (see note below).
        skip_end: Initial value of the ``skip_end`` flag (see note below).

    Returns:
        A list with one converted waveform per slice (empty for no slices).
    """
    waveforms = []
    with torch.no_grad():
        for position, segment in enumerate(slices):
            # The flags are re-bound, not recomputed from the arguments:
            # skip_start is cleared on the first slice and therefore stays
            # False for every later slice, while skip_end keeps the caller's
            # value until it is cleared on the last slice.
            # NOTE(review): confirm the sticky skip_start is intended.
            if position == 0:
                skip_start = False
            if position == len(slices) - 1:
                skip_end = False
            raw_audio = infer_multilang(
                segment,
                reference_audio=reference_audio,
                emotion=emotion,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language[position],
                hps=hps,
                net_g=net_g,
                device=device,
                skip_start=skip_start,
                skip_end=skip_end,
            )
            waveforms.append(gr.processing_utils.convert_to_16_bit_wav(raw_audio))
    return waveforms
115
+
116
+
117
def tts_split(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
    cut_by_sent,
    interval_between_para,
    interval_between_sent,
    reference_audio,
    emotion,
):
    """Synthesise ``text`` piece by piece and insert silence between the pieces.

    The text is split with ``re_matching.cut_para`` and, when ``cut_by_sent``
    is truthy, each resulting piece is split again with
    ``re_matching.cut_sent``. Every piece is synthesised with ``infer`` using
    the module-level ``hps``, ``net_g`` and ``device``.

    Args:
        text: Input text; runs of blank lines are collapsed before splitting.
        speaker: Forwarded to ``infer`` as ``sid``.
        sdp_ratio: Forwarded unchanged to ``infer``.
        noise_scale: Forwarded unchanged to ``infer``.
        noise_scale_w: Forwarded unchanged to ``infer``.
        length_scale: Forwarded unchanged to ``infer``.
        language: Language code forwarded to ``infer``; ``"mix"`` is rejected.
        cut_by_sent: Whether to additionally split every paragraph.
        interval_between_para: Seconds of silence after each paragraph.
        interval_between_sent: Seconds of silence after each sentence (only
            used when ``cut_by_sent`` is truthy).
        reference_audio: Forwarded unchanged to ``infer``.
        emotion: Forwarded unchanged to ``infer``.

    Returns:
        ``("Success", (sampling_rate, waveform))`` on success, or
        ``("invalid", None)`` when ``language`` is ``"mix"`` or the text
        yields nothing to synthesise.
    """
    if language == "mix":
        return ("invalid", None)
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")

    # Use the model's configured rate (as tts_fn does) instead of a literal
    # 44100, so silence lengths and the returned rate match the audio.
    sampling_rate = hps.data.sampling_rate
    para_list = re_matching.cut_para(text)
    audio_list = []
    if not cut_by_sent:
        for para_idx, para in enumerate(para_list):
            audio = infer(
                para,
                reference_audio=reference_audio,
                emotion=emotion,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
                sid=speaker,
                language=language,
                hps=hps,
                net_g=net_g,
                device=device,
                skip_start=para_idx != 0,
                skip_end=para_idx != len(para_list) - 1,
            )
            audio_list.append(gr.processing_utils.convert_to_16_bit_wav(audio))
            audio_list.append(
                np.zeros(int(sampling_rate * interval_between_para), dtype=np.int16)
            )
    else:
        for para_idx, para in enumerate(para_list):
            skip_start = para_idx != 0
            skip_end = para_idx != len(para_list) - 1
            audio_list_sent = []
            sent_list = re_matching.cut_sent(para)
            # A distinct index name keeps the paragraph index from being
            # shadowed by the sentence loop.
            for sent_idx, sent in enumerate(sent_list):
                # Existing behaviour kept as-is: the flags are re-bound, so
                # skip_start is cleared by the first sentence and stays False
                # for the rest of the paragraph.
                # NOTE(review): confirm the sticky skip_start is intended.
                skip_start = (sent_idx != 0) and skip_start
                skip_end = (sent_idx != len(sent_list) - 1) and skip_end
                audio = infer(
                    sent,
                    reference_audio=reference_audio,
                    emotion=emotion,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                    sid=speaker,
                    language=language,
                    hps=hps,
                    net_g=net_g,
                    device=device,
                    skip_start=skip_start,
                    skip_end=skip_end,
                )
                audio_list_sent.append(audio)
                audio_list_sent.append(
                    np.zeros(int(sampling_rate * interval_between_sent))
                )
            # Top the last sentence pause up to the paragraph pause.
            extra_pause = interval_between_para - interval_between_sent
            if extra_pause > 0:
                audio_list_sent.append(np.zeros(int(sampling_rate * extra_pause)))
            if not audio_list_sent:
                # Nothing to join for this paragraph; np.concatenate would raise.
                continue
            # Convert the whole paragraph at once so its volume is normalised
            # as a unit.
            audio_list.append(
                gr.processing_utils.convert_to_16_bit_wav(
                    np.concatenate(audio_list_sent)
                )
            )

    if not audio_list:
        # np.concatenate([]) raises ValueError; report the empty input instead.
        return ("invalid", None)
    return ("Success", (sampling_rate, np.concatenate(audio_list)))
200
+
201
+
202
def tts_fn(
    text: str,
    speaker,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    language,
    reference_audio,
    emotion,
):
    """Synthesise ``text`` and return a status message plus the audio.

    Three modes are selected by ``language``:

    * ``"mix"``: the text is validated and parsed by ``re_matching`` and
      synthesised with ``generate_audio_multilang``.
    * ``"auto"`` (case-insensitive): the text is split on ``"|"``, each part
      is segmented by ``split_by_language`` and synthesised with
      ``generate_audio_multilang``.
    * anything else: the text is split on ``"|"`` and synthesised with
      ``generate_audio`` using ``language`` for every part.

    Returns:
        ``(message, (hps.data.sampling_rate, waveform))``. ``message`` is
        ``"Success"``, or the validation message from
        ``re_matching.validate_text`` when "mix" input is rejected.
    """
    audio_list = []
    if language == "mix":
        bool_valid, str_valid = re_matching.validate_text(text)
        if not bool_valid:
            # Rejected input: return the validator's message together with
            # sampling_rate // 2 zero samples (half a second of silence).
            return str_valid, (
                hps.data.sampling_rate,
                np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
            )
        result = []
        # presumably each item yielded by text_matching is a list of
        # (lang, content) pairs followed by the speaker name - TODO confirm
        # against re_matching.
        for slice in re_matching.text_matching(text):
            _speaker = slice.pop()
            temp_contant = []
            temp_lang = []
            for lang, content in slice:
                if "|" in content:
                    # Every "|"-separated part starts its own group; an empty
                    # part becomes an empty group.
                    temp = []
                    temp_ = []
                    for i in content.split("|"):
                        if i != "":
                            temp.append([i])
                            temp_.append([lang])
                        else:
                            temp.append([])
                            temp_.append([])
                    temp_contant += temp
                    temp_lang += temp_
                else:
                    # Content without "|" is appended to the most recent group
                    # (creating the first group when none exists yet).
                    if len(temp_contant) == 0:
                        temp_contant.append([])
                        temp_lang.append([])
                    temp_contant[-1].append(content)
                    temp_lang[-1].append(lang)
            # One result entry per group: its (lang, content) pairs followed
            # by the speaker.
            for i, j in zip(temp_lang, temp_contant):
                result.append([*zip(i, j), _speaker])
        for i, one in enumerate(result):
            skip_start = i != 0
            skip_end = i != len(result) - 1
            # NOTE(review): _speaker is popped here but the call below passes
            # the `speaker` argument instead - confirm which one is intended.
            _speaker = one.pop()
            idx = 0
            while idx < len(one):
                text_to_generate = []
                lang_to_generate = []
                # This inner loop only breaks at the last entry of `one`, so
                # all of its pairs end up in a single group.
                while True:
                    lang, content = one[idx]
                    temp_text = [content]
                    if len(text_to_generate) > 0:
                        text_to_generate[-1] += [temp_text.pop(0)]
                        lang_to_generate[-1] += [lang]
                    if len(temp_text) > 0:
                        text_to_generate += [[i] for i in temp_text]
                        lang_to_generate += [[lang]] * len(temp_text)
                    if idx + 1 < len(one):
                        idx += 1
                    else:
                        break
                # idx is len(one) - 1 here, so skip_end is always cleared and
                # skip_start is cleared when `one` holds a single pair.
                skip_start = (idx != 0) and skip_start
                skip_end = (idx != len(one) - 1) and skip_end
                print(text_to_generate, lang_to_generate)
                audio_list.extend(
                    generate_audio_multilang(
                        text_to_generate,
                        sdp_ratio,
                        noise_scale,
                        noise_scale_w,
                        length_scale,
                        speaker,
                        lang_to_generate,
                        reference_audio,
                        emotion,
                        skip_start,
                        skip_end,
                    )
                )
                idx += 1
    elif language.lower() == "auto":
        for idx, slice in enumerate(text.split("|")):
            if slice == "":
                continue
            skip_start = idx != 0
            skip_end = idx != len(text.split("|")) - 1
            sentences_list = split_by_language(
                slice, target_languages=["zh", "ja", "en"]
            )
            # idx is reused as the position inside sentences_list; the
            # enclosing for-loop re-binds it on its next iteration.
            idx = 0
            while idx < len(sentences_list):
                text_to_generate = []
                lang_to_generate = []
                # As in the "mix" branch, this loop only breaks at the last
                # entry, so all detected segments form a single group.
                while True:
                    content, lang = sentences_list[idx]
                    temp_text = [content]
                    # Upper-case the detected code and rename "JA" to "JP",
                    # the spelling used in this app's language list.
                    lang = lang.upper()
                    if lang == "JA":
                        lang = "JP"
                    if len(text_to_generate) > 0:
                        text_to_generate[-1] += [temp_text.pop(0)]
                        lang_to_generate[-1] += [lang]
                    if len(temp_text) > 0:
                        text_to_generate += [[i] for i in temp_text]
                        lang_to_generate += [[lang]] * len(temp_text)
                    if idx + 1 < len(sentences_list):
                        idx += 1
                    else:
                        break
                skip_start = (idx != 0) and skip_start
                skip_end = (idx != len(sentences_list) - 1) and skip_end
                print(text_to_generate, lang_to_generate)
                audio_list.extend(
                    generate_audio_multilang(
                        text_to_generate,
                        sdp_ratio,
                        noise_scale,
                        noise_scale_w,
                        length_scale,
                        speaker,
                        lang_to_generate,
                        reference_audio,
                        emotion,
                        skip_start,
                        skip_end,
                    )
                )
                idx += 1
    else:
        # Single-language mode: every "|"-separated part is synthesised with
        # the selected language.
        audio_list.extend(
            generate_audio(
                text.split("|"),
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                speaker,
                language,
                reference_audio,
                emotion,
            )
        )

    audio_concat = np.concatenate(audio_list)
    return "Success", (hps.data.sampling_rate, audio_concat)
352
+
353
+
354
if __name__ == "__main__":
    # Script entry point: load the model, build the Gradio UI and launch it.
    if config.webui_config.debug:
        logger.info("Enable DEBUG-LEVEL log")
        # logging.basicConfig() is a no-op once the root logger has handlers
        # (it was already configured at import time), so set the level
        # directly on the root logger instead.
        logging.getLogger().setLevel(logging.DEBUG)
    hps = utils.get_hparams_from_file(config.webui_config.config_path)
    # Fall back to the latest version when the config does not specify one.
    version = hps.version if hasattr(hps, "version") else latest_version
    net_g = get_net_g(
        model_path=config.webui_config.model, version=version, device=device, hps=hps
    )
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    languages = ["ZH", "JP", "EN", "mix", "auto"]
    with gr.Blocks() as app:
        with gr.Row():
            with gr.Column():
                text = gr.TextArea(
                    label="输入文本内容",
                    placeholder="""
如果你选择语言为\'mix\',必须按照格式输入,否则报错:
格式举例(zh是中文,jp是日语,不区分大小写;说话人举例:gongzi):
[说话人1]<zh>你好,こんにちは! <jp>こんにちは,世界。
[说话人2]<zh>你好吗?<jp>元気ですか?
[说话人3]<zh>谢谢。<jp>どういたしまして。
...
另外,所有的语言选项都可以用'|'分割长段实现分句生成。
""",
                )
                trans = gr.Button("中翻日", variant="primary")
                # NOTE(review): this button is re-bound by the second
                # `slicer` assignment below before any handler is attached,
                # so it currently does nothing when clicked.
                slicer = gr.Button("快速切分", variant="primary")
                speaker = gr.Dropdown(
                    choices=speakers, value=speakers[0], label="Speaker"
                )
                emotion = gr.Slider(
                    minimum=0, maximum=9, value=0, step=1, label="Emotion"
                )
                sdp_ratio = gr.Slider(
                    minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
                )
                noise_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
                )
                noise_scale_w = gr.Slider(
                    minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise_W"
                )
                length_scale = gr.Slider(
                    minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
                )
                language = gr.Dropdown(
                    choices=languages, value=languages[0], label="Language"
                )
                btn = gr.Button("生成音频!", variant="primary")
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        interval_between_sent = gr.Slider(
                            minimum=0,
                            maximum=5,
                            value=0.2,
                            step=0.1,
                            label="句间停顿(秒),勾选按句切分才生效",
                        )
                        interval_between_para = gr.Slider(
                            minimum=0,
                            maximum=10,
                            value=1,
                            step=0.1,
                            label="段间停顿(秒),需要大于句间停顿才有效",
                        )
                        opt_cut_by_sent = gr.Checkbox(
                            label="按句切分 在按段落切分的基础上再按句子切分文本"
                        )
                        slicer = gr.Button("切分生成", variant="primary")
                text_output = gr.Textbox(label="状态信息")
                audio_output = gr.Audio(label="输出音频")
                reference_text = gr.Markdown(value="## 情感参考音频(WAV 格式):用于生成语音的情感参考。")
                reference_audio = gr.Audio(label="情感参考音频(WAV 格式)", type="filepath")
        btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
                reference_audio,
                emotion,
            ],
            outputs=[text_output, audio_output],
        )

        trans.click(
            translate,
            inputs=[text],
            outputs=[text],
        )
        slicer.click(
            tts_split,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
                opt_cut_by_sent,
                interval_between_para,
                interval_between_sent,
                reference_audio,
                emotion,
            ],
            outputs=[text_output, audio_output],
        )

        reference_audio.upload(
            # `sr` is passed by keyword: recent librosa releases reject it as
            # a positional argument. load() returns (samples, rate); the
            # reversal yields the (rate, samples) order handed back to Gradio.
            lambda x: librosa.load(x, sr=16000)[::-1],
            inputs=[reference_audio],
            outputs=[reference_audio],
        )
    print("推理页面已开启!")
    webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
    app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
attentions.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class LayerNorm(nn.Module):
    """Layer normalisation applied along dimension 1 of the input.

    The input's dimension 1 is swapped with its last dimension,
    ``F.layer_norm`` is applied over that last dimension with the learnable
    ``gamma`` (scale) and ``beta`` (shift), and the swap is undone.

    Args:
        channels: Size of dimension 1 of the inputs.
        eps: Epsilon passed to ``F.layer_norm``.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels, self.eps = channels, eps

        # Identity affine transform at initialisation.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Return ``x`` normalised along dimension 1, with the same shape."""
        channels_last = x.transpose(1, -1)
        normalised = F.layer_norm(
            channels_last,
            (self.channels,),
            weight=self.gamma,
            bias=self.beta,
            eps=self.eps,
        )
        return normalised.transpose(1, -1)
25
+
26
+
27
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Add the two inputs and gate the sum: tanh(first part) * sigmoid(rest).

    The sum ``input_a + input_b`` is split along dimension 1 at index
    ``n_channels[0]``; the leading part goes through ``tanh``, the remainder
    through ``sigmoid``, and the two are multiplied element-wise.
    """
    split_at = n_channels[0]
    summed = input_a + input_b
    return torch.tanh(summed[:, :split_at, :]) * torch.sigmoid(summed[:, split_at:, :])
35
+
36
+
37
class Encoder(nn.Module):
    """Stack of self-attention and feed-forward layers with optional conditioning.

    Each of the ``n_layers`` layers applies ``MultiHeadAttention`` and ``FFN``,
    each followed by dropout, a residual connection and ``LayerNorm``. When a
    non-zero ``gin_channels`` keyword is supplied, a linear projection of the
    conditioning input ``g`` is added to the hidden state right before layer
    ``cond_layer_idx``.

    Args:
        hidden_channels: Channel count of the input and of every layer.
        filter_channels: Inner channel count passed to ``FFN``.
        n_heads: Number of attention heads.
        n_layers: Number of layers.
        kernel_size: Kernel size passed to ``FFN``.
        p_dropout: Dropout probability.
        window_size: Window size passed to ``MultiHeadAttention``.
        isflow: Accepted for interface compatibility; not used here.
        **kwargs: ``gin_channels`` enables conditioning when non-zero;
            ``cond_layer_idx`` (default 2) selects the layer in front of which
            the conditioning is injected and must be less than ``n_layers``.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        # n_layers is never reached by the loop in forward(), so this default
        # disables conditioning.
        self.cond_layer_idx = self.n_layers
        if "gin_channels" in kwargs:
            self.gin_channels = kwargs["gin_channels"]
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                # vits2 says 3rd block, so idx is 2 by default
                self.cond_layer_idx = kwargs.get("cond_layer_idx", 2)
                # The first positional argument of debug() is the format
                # string; passing the two values directly made the record
                # fail to format whenever DEBUG logging was enabled.
                logger.debug(
                    "gin_channels=%s, cond_layer_idx=%s",
                    self.gin_channels,
                    self.cond_layer_idx,
                )
                assert (
                    self.cond_layer_idx < self.n_layers
                ), "cond_layer_idx should be less than n_layers"
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        """Encode ``x`` under ``x_mask``, optionally conditioned on ``g``.

        Args:
            x: Input sequence. # assumes (batch, hidden_channels, time) - TODO confirm
            x_mask: Mask multiplied onto ``x`` and combined into the attention mask.
            g: Optional conditioning input; its dimensions 1 and 2 are swapped
                before and after the linear projection.

        Returns:
            The masked output of the last layer.
        """
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            if i == self.cond_layer_idx and g is not None:
                # Project the conditioning to hidden_channels and add it once,
                # just before this layer.
                g = self.spk_emb_linear(g.transpose(1, 2))
                g = g.transpose(1, 2)
                x = x + g
                x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
121
+
122
+
123
class Decoder(nn.Module):
    """Stack of self-attention, encoder-decoder attention and feed-forward layers.

    Each layer runs three sub-blocks in order - self-attention, attention over
    the encoder output and an ``FFN`` built with ``causal=True`` - each
    followed by dropout, a residual connection and ``LayerNorm``.

    Args:
        hidden_channels: Channel count of the input and of every layer.
        filter_channels: Inner channel count passed to ``FFN``.
        n_heads: Number of attention heads.
        n_layers: Number of layers.
        kernel_size: Kernel size passed to ``FFN``.
        p_dropout: Dropout probability.
        proximal_bias: Forwarded to the self-attention layers.
        proximal_init: Forwarded to the self-attention layers.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for _ in range(self.n_layers):
            # Sub-modules are built in the order they run so that parameter
            # initialisation draws random numbers in the same sequence.
            self_attention = MultiHeadAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                proximal_bias=proximal_bias,
                proximal_init=proximal_init,
            )
            self.self_attn_layers.append(self_attention)
            self.norm_layers_0.append(LayerNorm(hidden_channels))

            cross_attention = MultiHeadAttention(
                hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
            )
            self.encdec_attn_layers.append(cross_attention)
            self.norm_layers_1.append(LayerNorm(hidden_channels))

            feed_forward = FFN(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
                causal=True,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """Decode ``x`` while attending to the encoder output ``h``.

        Args:
            x: Decoder input.
            x_mask: Mask for ``x``.
            h: Encoder output.
            h_mask: Mask for ``h``.

        Returns:
            The masked output of the last layer.
        """
        # Self-attention mask comes from commons.subsequent_mask, moved to
        # the device and dtype of the input.
        self_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        cross_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        layer_stack = zip(
            self.self_attn_layers,
            self.norm_layers_0,
            self.encdec_attn_layers,
            self.norm_layers_1,
            self.ffn_layers,
            self.norm_layers_2,
        )
        for self_attn, norm_0, cross_attn, norm_1, ffn, norm_2 in layer_stack:
            x = norm_0(x + self.drop(self_attn(x, x, self_mask)))
            x = norm_1(x + self.drop(cross_attn(x, h, cross_mask)))
            x = norm_2(x + self.drop(ffn(x, x_mask)))
        return x * x_mask
207
+
208
+
209
+ class MultiHeadAttention(nn.Module):
210
+ def __init__(
211
+ self,
212
+ channels,
213
+ out_channels,
214
+ n_heads,
215
+ p_dropout=0.0,
216
+ window_size=None,
217
+ heads_share=True,
218
+ block_length=None,
219
+ proximal_bias=False,
220
+ proximal_init=False,
221
+ ):
222
+ super().__init__()
223
+ assert channels % n_heads == 0
224
+
225
+ self.channels = channels
226
+ self.out_channels = out_channels
227
+ self.n_heads = n_heads
228
+ self.p_dropout = p_dropout
229
+ self.window_size = window_size
230
+ self.heads_share = heads_share
231
+ self.block_length = block_length
232
+ self.proximal_bias = proximal_bias
233
+ self.proximal_init = proximal_init
234
+ self.attn = None
235
+
236
+ self.k_channels = channels // n_heads
237
+ self.conv_q = nn.Conv1d(channels, channels, 1)
238
+ self.conv_k = nn.Conv1d(channels, channels, 1)
239
+ self.conv_v = nn.Conv1d(channels, channels, 1)
240
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
241
+ self.drop = nn.Dropout(p_dropout)
242
+
243
+ if window_size is not None:
244
+ n_heads_rel = 1 if heads_share else n_heads
245
+ rel_stddev = self.k_channels**-0.5
246
+ self.emb_rel_k = nn.Parameter(
247
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
248
+ * rel_stddev
249
+ )
250
+ self.emb_rel_v = nn.Parameter(
251
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
252
+ * rel_stddev
253
+ )
254
+
255
+ nn.init.xavier_uniform_(self.conv_q.weight)
256
+ nn.init.xavier_uniform_(self.conv_k.weight)
257
+ nn.init.xavier_uniform_(self.conv_v.weight)
258
+ if proximal_init:
259
+ with torch.no_grad():
260
+ self.conv_k.weight.copy_(self.conv_q.weight)
261
+ self.conv_k.bias.copy_(self.conv_q.bias)
262
+
263
+ def forward(self, x, c, attn_mask=None):
264
+ q = self.conv_q(x)
265
+ k = self.conv_k(c)
266
+ v = self.conv_v(c)
267
+
268
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
269
+
270
+ x = self.conv_o(x)
271
+ return x
272
+
273
def attention(self, query, key, value, mask=None):
    """Run multi-head scaled dot-product attention.

    The inputs are split into ``self.n_heads`` heads of ``self.k_channels``
    channels.  Relative-position terms (``self.window_size``), a proximal
    bias (``self.proximal_bias``) and a band mask (``self.block_length``)
    are added when the corresponding module attribute enables them.

    Args:
        query: tensor viewed as [b, n_heads * k_channels, t_t].
        key: tensor of size [b, d, t_s].
        value: tensor with the same layout as ``key``.
        mask: optional tensor; wherever ``mask == 0`` the score is set to
            -1e4 before the softmax.

    Returns:
        ``(output, p_attn)`` with ``output`` of size [b, d, t_t] and
        ``p_attn`` of size [b, n_h, t_t, t_s].
    """
    b, d, t_s = key.size()
    t_t = query.size(2)
    n_h, d_k = self.n_heads, self.k_channels

    def split_heads(x, steps):
        # [b, d, t] -> [b, n_h, t, d_k]
        return x.view(b, n_h, d_k, steps).transpose(2, 3)

    query = split_heads(query, t_t)
    key = split_heads(key, t_s)
    value = split_heads(value, t_s)

    scaled_query = query / math.sqrt(d_k)
    scores = torch.matmul(scaled_query, key.transpose(-2, -1))

    if self.window_size is not None:
        assert (
            t_s == t_t
        ), "Relative attention is only available for self-attention."
        rel_keys = self._get_relative_embeddings(self.emb_rel_k, t_s)
        rel_logits = self._matmul_with_relative_keys(scaled_query, rel_keys)
        scores = scores + self._relative_position_to_absolute_position(rel_logits)

    if self.proximal_bias:
        assert t_s == t_t, "Proximal bias is only available for self-attention."
        proximal = self._attention_bias_proximal(t_s)
        scores = scores + proximal.to(device=scores.device, dtype=scores.dtype)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e4)
        if self.block_length is not None:
            assert (
                t_s == t_t
            ), "Local attention is only available for self-attention."
            # Keep only a band of +/- block_length around the diagonal.
            band = torch.ones_like(scores)
            band = band.triu(-self.block_length).tril(self.block_length)
            scores = scores.masked_fill(band == 0, -1e4)

    p_attn = self.drop(F.softmax(scores, dim=-1))  # [b, n_h, t_t, t_s]
    output = torch.matmul(p_attn, value)

    if self.window_size is not None:
        rel_weights = self._absolute_position_to_relative_position(p_attn)
        rel_values = self._get_relative_embeddings(self.emb_rel_v, t_s)
        output = output + self._matmul_with_relative_values(rel_weights, rel_values)

    # [b, n_h, t_t, d_k] -> [b, d, t_t]
    output = output.transpose(2, 3).contiguous().view(b, d, t_t)
    return output, p_attn
323
+
324
+ def _matmul_with_relative_values(self, x, y):
325
+ """
326
+ x: [b, h, l, m]
327
+ y: [h or 1, m, d]
328
+ ret: [b, h, l, d]
329
+ """
330
+ ret = torch.matmul(x, y.unsqueeze(0))
331
+ return ret
332
+
333
+ def _matmul_with_relative_keys(self, x, y):
334
+ """
335
+ x: [b, h, l, d]
336
+ y: [h or 1, m, d]
337
+ ret: [b, h, l, m]
338
+ """
339
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
340
+ return ret
341
+
342
+ def _get_relative_embeddings(self, relative_embeddings, length):
343
+ 2 * self.window_size + 1
344
+ # Pad first before slice to avoid using cond ops.
345
+ pad_length = max(length - (self.window_size + 1), 0)
346
+ slice_start_position = max((self.window_size + 1) - length, 0)
347
+ slice_end_position = slice_start_position + 2 * length - 1
348
+ if pad_length > 0:
349
+ padded_relative_embeddings = F.pad(
350
+ relative_embeddings,
351
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
352
+ )
353
+ else:
354
+ padded_relative_embeddings = relative_embeddings
355
+ used_relative_embeddings = padded_relative_embeddings[
356
+ :, slice_start_position:slice_end_position
357
+ ]
358
+ return used_relative_embeddings
359
+
360
+ def _relative_position_to_absolute_position(self, x):
361
+ """
362
+ x: [b, h, l, 2*l-1]
363
+ ret: [b, h, l, l]
364
+ """
365
+ batch, heads, length, _ = x.size()
366
+ # Concat columns of pad to shift from relative to absolute indexing.
367
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
368
+
369
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
370
+ x_flat = x.view([batch, heads, length * 2 * length])
371
+ x_flat = F.pad(
372
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
373
+ )
374
+
375
+ # Reshape and slice out the padded elements.
376
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
377
+ :, :, :length, length - 1 :
378
+ ]
379
+ return x_final
380
+
381
+ def _absolute_position_to_relative_position(self, x):
382
+ """
383
+ x: [b, h, l, l]
384
+ ret: [b, h, l, 2*l-1]
385
+ """
386
+ batch, heads, length, _ = x.size()
387
+ # pad along column
388
+ x = F.pad(
389
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
390
+ )
391
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
392
+ # add 0's in the beginning that will skew the elements after reshape
393
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
394
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
395
+ return x_final
396
+
397
+ def _attention_bias_proximal(self, length):
398
+ """Bias for self-attention to encourage attention to close positions.
399
+ Args:
400
+ length: an integer scalar.
401
+ Returns:
402
+ a Tensor with shape [1, 1, length, length]
403
+ """
404
+ r = torch.arange(length, dtype=torch.float32)
405
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
406
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
407
+
408
+
409
class FFN(nn.Module):
    """Two-layer 1-D convolutional feed-forward block with masking.

    Each convolution is preceded by explicit zero padding: "same" padding
    by default, left-only padding when ``causal`` is true.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Padding strategy is chosen once at construction time.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """Apply conv -> activation -> dropout -> conv, masking around each conv."""
        hidden = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            # Sigmoid approximation of GELU.
            hidden = hidden * torch.sigmoid(1.702 * hidden)
        else:
            hidden = torch.relu(hidden)
        hidden = self.drop(hidden)
        out = self.conv_2(self.padding(hidden * x_mask))
        return out * x_mask

    def _causal_padding(self, x):
        """Zero-pad the last axis on the left by ``kernel_size - 1``."""
        if self.kernel_size == 1:
            return x
        left = self.kernel_size - 1
        return F.pad(x, [left, 0, 0, 0, 0, 0])

    def _same_padding(self, x):
        """Zero-pad the last axis so the convolution keeps its length."""
        if self.kernel_size == 1:
            return x
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        return F.pad(x, [left, right, 0, 0, 0, 0])
bert/bert_models.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "deberta-v2-large-japanese-char-wwm": {
3
+ "repo_id": "ku-nlp/deberta-v2-large-japanese-char-wwm",
4
+ "files": ["pytorch_model.bin"]
5
+ },
6
+ "chinese-roberta-wwm-ext-large": {
7
+ "repo_id": "hfl/chinese-roberta-wwm-ext-large",
8
+ "files": ["pytorch_model.bin"]
9
+ },
10
+ "deberta-v3-large": {
11
+ "repo_id": "microsoft/deberta-v3-large",
12
+ "files": ["spm.model", "pytorch_model.bin"]
13
+ }
14
+ }
commons.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise ``m.weight`` from N(mean, std) when ``m`` is a conv layer.

    A module counts as a conv layer when its class name contains "Conv";
    every other module is left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
10
+
11
+
12
def get_padding(kernel_size, dilation=1):
    """Return half of the dilated kernel's extra width, truncated to int."""
    extra_width = dilation * (kernel_size - 1)
    return int(extra_width / 2)
14
+
15
+
16
def convert_pad_shape(pad_shape):
    """Flatten per-axis ``[before, after]`` pairs into ``F.pad`` order.

    ``F.pad`` expects the last axis first, so the pairs are emitted in
    reverse axis order while each pair keeps its own order.
    """
    flattened = []
    for pair in pad_shape[::-1]:
        flattened.extend(pair)
    return flattened
20
+
21
+
22
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``."""
    result = [item]
    for element in lst:
        result.append(element)
        result.append(item)
    return result
26
+
27
+
28
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q) for Gaussians given as means and log standard deviations.

    The result is element-wise; no reduction is applied.
    """
    variance_and_mean_term = (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    kl = logs_q - logs_p - 0.5
    kl += variance_and_mean_term
    return kl
35
+
36
+
37
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    # Squeeze the uniform draw into [1e-5, 0.99999) so neither log sees 0.
    uniform = torch.rand(shape) * 0.99998 + 0.00001
    inner = -torch.log(uniform)
    return -torch.log(inner)
41
+
42
+
43
def rand_gumbel_like(x):
    """Return Gumbel noise with the size, dtype and device of ``x``."""
    return rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
46
+
47
+
48
def slice_segments(x, ids_str, segment_size=4):
    """Cut one ``segment_size``-long window per batch item along the last axis.

    Args:
        x: 3-D tensor indexed as [batch, channel, time].
        ids_str: per-item start index; a negative start wraps around, taking
            the tail of the item followed by its head.
        segment_size: window length.

    Returns:
        Tensor shaped like ``x[:, :, :segment_size]``.
    """
    segments = torch.zeros_like(x[:, :, :segment_size])
    for row, sample in enumerate(x):
        start = ids_str[row]
        stop = start + segment_size
        if start < 0:
            tail = sample[:, x.size(2) + start :]
            head = sample[:, :stop]
            segments[row] = torch.cat([tail, head], dim=1)
        else:
            segments[row] = sample[:, start:stop]
    return segments
61
+
62
+
63
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut a randomly positioned window from every batch item.

    Args:
        x: tensor of size [b, d, t].
        x_lengths: usable length per item; defaults to the full length ``t``.
        segment_size: window length.

    Returns:
        ``(segments, ids_str)`` where ``ids_str`` holds the sampled starts.
    """
    batch, _, total = x.size()
    lengths = total if x_lengths is None else x_lengths
    max_start = lengths - segment_size + 1
    uniform = torch.rand([batch]).to(device=x.device)
    ids_str = (uniform * max_start).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
71
+
72
+
73
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Build a sinusoidal timing signal.

    Half of the channels carry sines and half cosines over geometrically
    spaced timescales between ``min_timescale`` and ``max_timescale``; an
    odd channel count gets one trailing zero channel.

    Args:
        length: number of time steps.
        channels: number of output channels.
        min_timescale: smallest timescale.
        max_timescale: largest timescale.

    Returns:
        float tensor of shape [1, channels, length].
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # With a single timescale (channels of 2 or 3) the divisor would be 0;
    # clamp it to 1 so small channel counts do not raise ZeroDivisionError.
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(
        num_timescales - 1, 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # Append one zero row when ``channels`` is odd.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    return signal.view(1, channels, length)
87
+
88
+
89
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add the sinusoidal timing signal to ``x`` (size [b, channels, length])."""
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return x + timing
93
+
94
+
95
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the sinusoidal timing signal to ``x`` along ``axis``."""
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return torch.cat([x, timing], axis)
99
+
100
+
101
def subsequent_mask(length):
    """Return a [1, 1, length, length] lower-triangular mask of ones."""
    lower = torch.ones(length, length).tril()
    return lower.unsqueeze(0).unsqueeze(0)
104
+
105
+
106
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation: tanh of the first channels times sigmoid of the rest.

    ``n_channels[0]`` is the split point along axis 1 of ``input_a + input_b``.
    """
    split = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :split, :])
    sigmoid_part = torch.sigmoid(summed[:, split:, :])
    return tanh_part * sigmoid_part
114
+
115
+
116
def convert_pad_shape(pad_shape):
    """Flatten per-axis ``[before, after]`` pairs, last axis first, for ``F.pad``."""
    flat = []
    for pair in pad_shape:
        # Prepending each pair reverses the axis order but not the pair itself.
        flat[:0] = pair
    return flat
120
+
121
+
122
def shift_1d(x):
    """Shift a 3-D tensor one step right along its last axis, zero-filling the front."""
    # One zero on the left of the last axis, then drop the final step.
    padded = F.pad(x, [1, 0, 0, 0, 0, 0])
    return padded[:, :, :-1]
125
+
126
+
127
def sequence_mask(length, max_length=None):
    """Return a boolean mask, True where position < length, shaped [len(length), max_length].

    ``max_length`` defaults to the largest entry of ``length``.
    """
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(limit, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
132
+
133
+
134
def generate_path(duration, mask):
    """Expand per-token durations into an alignment path.

    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]

    Returns a tensor shaped like ``mask``.
    """
    b, _, t_y, t_x = mask.shape
    cumulative = torch.cumsum(duration, -1).view(b * t_x)

    # Row i is true for every output frame before the end of token i.
    upto = sequence_mask(cumulative, t_y).to(mask.dtype).view(b, t_x, t_y)
    # Subtracting the previous token's row leaves only the frames of token i.
    previous = F.pad(upto, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = upto - previous
    return path.unsqueeze(1).transpose(2, 3) * mask
149
+
150
+
151
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients element-wise and return their pre-clipping total norm.

    Args:
        parameters: a tensor or an iterable of tensors; entries without a
            gradient are ignored.
        clip_value: gradients are clamped to ``[-clip_value, clip_value]``;
            ``None`` disables clamping.
        norm_type: order of the norm that is reported.

    Returns:
        The ``norm_type``-norm over all gradients, computed before clamping.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grad = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    accumulated = 0
    for param in with_grad:
        grad = param.grad.data
        accumulated += grad.norm(norm_type).item() ** norm_type
        if clip_value is not None:
            grad.clamp_(min=-clip_value, max=clip_value)
    return accumulated ** (1.0 / norm_type)
config.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @Desc: 全局配置文件读取
3
+ """
4
+ import argparse
5
+ import yaml
6
+ from typing import Dict, List
7
+ import os
8
+ import shutil
9
+ import sys
10
+
11
+
12
class Resample_config:
    """Audio resampling settings."""

    def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100):
        self.sampling_rate: int = sampling_rate  # target sampling rate
        self.in_dir: str = in_dir  # directory holding the audio to process
        self.out_dir: str = out_dir  # directory receiving the resampled audio

    @classmethod
    def from_dict(cls, dataset_path: str, data: Dict[str, any]):
        """Build an instance from a config mapping.

        ``in_dir`` and ``out_dir`` are joined onto ``dataset_path``.  The
        paths are not validated here; the original note says resample.py
        handles that.  ``data`` itself is left unmodified.
        """
        resolved = dict(data)  # copy so the caller's mapping is not mutated
        resolved["in_dir"] = os.path.join(dataset_path, data["in_dir"])
        resolved["out_dir"] = os.path.join(dataset_path, data["out_dir"])
        return cls(**resolved)
29
+
30
+
31
class Preprocess_text_config:
    """Text preprocessing settings."""

    def __init__(
        self,
        transcription_path: str,
        cleaned_path: str,
        train_path: str,
        val_path: str,
        config_path: str,
        val_per_spk: int = 5,
        max_val_total: int = 10000,
        clean: bool = True,
    ):
        # Raw transcription file; each line is
        # {wav_path}|{speaker_name}|{language}|{text}.
        self.transcription_path: str = transcription_path
        # Cleaned transcription file; None means "generate it next to the
        # raw transcription".
        self.cleaned_path: str = cleaned_path
        self.train_path: str = train_path  # training list path
        self.val_path: str = val_path  # validation list path
        self.config_path: str = config_path  # model config file path
        self.val_per_spk: int = val_per_spk  # validation items per speaker
        # Cap on validation items; the surplus goes to the training set.
        self.max_val_total: int = max_val_total
        self.clean: bool = clean  # whether to run text cleaning

    @classmethod
    def from_dict(cls, dataset_path: str, data: Dict[str, any]):
        """Build an instance from a config mapping.

        Every path is joined onto ``dataset_path``.  A ``cleaned_path`` that
        is missing, empty or ``None`` becomes ``None``.  ``data`` itself is
        left unmodified.
        """
        resolved = dict(data)  # copy so the caller's mapping is not mutated
        for key in ("transcription_path", "train_path", "val_path", "config_path"):
            resolved[key] = os.path.join(dataset_path, data[key])

        cleaned_path = data.get("cleaned_path")
        resolved["cleaned_path"] = (
            os.path.join(dataset_path, cleaned_path) if cleaned_path else None
        )
        return cls(**resolved)
70
+
71
+
72
class Bert_gen_config:
    """bert_gen settings."""

    def __init__(
        self,
        config_path: str,
        num_processes: int = 2,
        device: str = "cuda",
        use_multi_device: bool = False,
    ):
        self.config_path = config_path
        self.num_processes = num_processes
        self.device = device
        self.use_multi_device = use_multi_device

    @classmethod
    def from_dict(cls, dataset_path: str, data: Dict[str, any]):
        """Build an instance, joining ``config_path`` onto ``dataset_path``.

        ``data`` itself is left unmodified.
        """
        resolved = dict(data)  # copy so the caller's mapping is not mutated
        resolved["config_path"] = os.path.join(dataset_path, data["config_path"])
        return cls(**resolved)
92
+
93
+
94
class Emo_gen_config:
    """emo_gen settings."""

    def __init__(
        self,
        config_path: str,
        num_processes: int = 2,
        device: str = "cuda",
    ):
        self.config_path = config_path
        self.num_processes = num_processes
        self.device = device

    @classmethod
    def from_dict(cls, dataset_path: str, data: Dict[str, any]):
        """Build an instance, joining ``config_path`` onto ``dataset_path``.

        ``data`` itself is left unmodified.
        """
        resolved = dict(data)  # copy so the caller's mapping is not mutated
        resolved["config_path"] = os.path.join(dataset_path, data["config_path"])
        return cls(**resolved)
112
+
113
+
114
class Train_ms_config:
    """Training settings."""

    def __init__(
        self,
        config_path: str,
        env: Dict[str, any],
        base: Dict[str, any],
        model: str,
        num_workers: int,
        spec_cache: bool,
        keep_ckpts: int,
    ):
        self.env = env  # environment variables to load
        self.base = base  # base (pretrained) model settings
        # Directory for trained models; relative to dataset_path, not to the
        # project root.
        self.model = model
        self.config_path = config_path  # model config file path
        self.num_workers = num_workers  # number of workers
        self.spec_cache = spec_cache  # whether the spec cache is enabled
        self.keep_ckpts = keep_ckpts  # number of checkpoints to keep

    @classmethod
    def from_dict(cls, dataset_path: str, data: Dict[str, any]):
        """Build an instance, joining ``config_path`` onto ``dataset_path``.

        ``model`` is passed through unchanged and stays relative to
        ``dataset_path``.  ``data`` itself is left unmodified.
        """
        resolved = dict(data)  # copy so the caller's mapping is not mutated
        resolved["config_path"] = os.path.join(dataset_path, data["config_path"])
        return cls(**resolved)
141
+
142
+
143
class Webui_config:
    """webui settings."""

    def __init__(
        self,
        device: str,
        model: str,
        config_path: str,
        language_identification_library: str,
        port: int = 7860,
        share: bool = False,
        debug: bool = False,
    ):
        self.device: str = device  # inference device
        self.model: str = model  # model path
        self.config_path: str = config_path  # model config file path
        self.port: int = port  # port number
        self.share: bool = share  # whether to deploy publicly
        self.debug: bool = debug  # whether debug mode is enabled
        # Language identification library.
        self.language_identification_library: str = language_identification_library

    @classmethod
    def from_dict(cls, dataset_path: str, data: Dict[str, any]):
        """Build an instance, joining ``config_path`` and ``model`` onto ``dataset_path``.

        ``data`` itself is left unmodified.
        """
        resolved = dict(data)  # copy so the caller's mapping is not mutated
        resolved["config_path"] = os.path.join(dataset_path, data["config_path"])
        resolved["model"] = os.path.join(dataset_path, data["model"])
        return cls(**resolved)
171
+
172
+
173
class Server_config:
    """API server settings."""

    def __init__(
        self,
        models: List[Dict[str, any]],
        port: int = 5000,
        device: str = "cuda",
    ):
        # Settings of every model that should be loaded.
        self.models: List[Dict[str, any]] = models
        # Port number.
        self.port: int = port
        # Default device for models.
        self.device: str = device

    @classmethod
    def from_dict(cls, data: Dict[str, any]):
        """Build an instance using the mapping's entries as keyword arguments."""
        return cls(**data)
184
+
185
+
186
class Translate_config:
    """Translation API settings."""

    def __init__(self, app_key: str, secret_key: str):
        self.app_key, self.secret_key = app_key, secret_key

    @classmethod
    def from_dict(cls, data: Dict[str, any]):
        """Build an instance using the mapping's entries as keyword arguments."""
        return cls(**data)
196
+
197
+
198
class Config:
    """Load the global YAML configuration and expose one object per section.

    When ``config_path`` does not exist but ``default_config.yml`` does, the
    default file is copied to ``config_path``, a notice is printed and the
    process exits with status 0 so the user can edit the new file first.
    """

    def __init__(self, config_path: str):
        if not os.path.isfile(config_path) and os.path.isfile("default_config.yml"):
            shutil.copy(src="default_config.yml", dst=config_path)
            print(
                f"已根据默认配置文件default_config.yml生成配置文件{config_path}。请按该配置文件的说明进行配置后重新运行。"
            )
            print("如无特殊需求,请勿修改default_config.yml或备份该文件。")
            sys.exit(0)

        with open(file=config_path, mode="r", encoding="utf-8") as file:
            yaml_config: Dict[str, any] = yaml.safe_load(file.read())

        # Top-level scalar settings.
        dataset_path: str = yaml_config["dataset_path"]
        openi_token: str = yaml_config["openi_token"]
        self.dataset_path: str = dataset_path
        self.mirror: str = yaml_config["mirror"]
        self.openi_token: str = openi_token

        # Sections whose paths are resolved against dataset_path.
        self.resample_config: Resample_config = Resample_config.from_dict(
            dataset_path, yaml_config["resample"]
        )
        self.preprocess_text_config: Preprocess_text_config = (
            Preprocess_text_config.from_dict(
                dataset_path, yaml_config["preprocess_text"]
            )
        )
        self.bert_gen_config: Bert_gen_config = Bert_gen_config.from_dict(
            dataset_path, yaml_config["bert_gen"]
        )
        self.train_ms_config: Train_ms_config = Train_ms_config.from_dict(
            dataset_path, yaml_config["train_ms"]
        )
        self.webui_config: Webui_config = Webui_config.from_dict(
            dataset_path, yaml_config["webui"]
        )

        # Sections that are passed through as-is.
        self.server_config: Server_config = Server_config.from_dict(
            yaml_config["server"]
        )
        self.translate_config: Translate_config = Translate_config.from_dict(
            yaml_config["translate"]
        )
237
+
238
+
239
# Module-level bootstrap: importing this module parses the command line and
# builds the shared ``config`` object.
parser = argparse.ArgumentParser()
# The option is named --yml_config to avoid clashing with the older
# config.json option.
parser.add_argument("-y", "--yml_config", type=str, default="config.yml")
# parse_known_args tolerates options that belong to the importing script.
args, _ = parser.parse_known_args()
config = Config(args.yml_config)
config.yml ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 全局配置
2
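+ # To use several config files at the same time (e.g. two GPUs training two datasets), pass the file with the -y/--yml_config command-line option; if omitted it defaults to ./config.yml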
3
+
4
+ # 拟提供通用路径配置,统一存放数据,避免数据放得很乱
5
+ # 每个数据集与其对应的模型存放至统一路径下,后续所有的路径配置均为相对于datasetPath的路径
6
+ # 不填或者填空则路径为相对于项目根目录的路径
7
+ dataset_path: "Data/"
8
+
9
+ # 模型镜像源,默认huggingface,使用openi镜像源需指定openi_token
10
+ mirror: ""
11
+ openi_token: "" # openi token
12
+
13
+ # resample 音频重采样配置
14
+ # 注意, “:” 后需要加空格
15
+ resample:
16
+ # 目标重采样率
17
+ sampling_rate: 44100
18
+ # 音频文件输入路径,重采样会将该路径下所有.wav音频文件重采样
19
+ # 请填入相对于datasetPath的相对路径
20
+ in_dir: "audios/raw" # 相对于根目录的路径为 /datasetPath/in_dir
21
+ # 音频文件重采样后输出路径
22
+ out_dir: "audios/wavs"
23
+
24
+
25
+ # preprocess_text 数据集预处理相关配置
26
+ # 注意, “:” 后需要加空格
27
+ preprocess_text:
28
+ # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
29
+ transcription_path: "filelists/haruka.list"
30
+ # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
31
+ cleaned_path: ""
32
+ # 训练集路径
33
+ train_path: "filelists/train.list"
34
+ # 验证集路径
35
+ val_path: "filelists/val.list"
36
+ # 配置文件路径
37
+ config_path: "haruka.json"
38
+ # 每个speaker的验证集条数
39
+ val_per_spk: 40
40
+ # 验证集最大条数,多于的会被截断并放到训练集中
41
+ max_val_total: 200
42
+ # 是否进行数据清洗
43
+ clean: true
44
+
45
+
46
+ # bert_gen 相关配置
47
+ # 注意, “:” 后需要加空格
48
+ bert_gen:
49
+ # 训练数据集配置文件路径
50
+ config_path: "haruka.json"
51
+ # 并行数
52
+ num_processes: 2
53
+ # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
54
+ # 该选项同时决定了get_bert_feature的默认设备
55
+ device: "cuda"
56
+ # 使用多卡推理
57
+ use_multi_device: false
58
+
59
+ # emo_gen 相关配置
60
+ # 注意, “:” 后需要加空格
61
+ emo_gen:
62
+ # 训练数据集配置文件路径
63
+ config_path: "haruka.json"
64
+ # 并行数
65
+ num_processes: 2
66
+ # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
67
+ device: "cuda"
68
+
69
+ # train 训练配置
70
+ # 注意, “:” 后需要加空格
71
+ train_ms:
72
+ env:
73
+ MASTER_ADDR: "localhost"
74
+ MASTER_PORT: 10086
75
+ WORLD_SIZE: 1
76
+ LOCAL_RANK: 0
77
+ RANK: 0
78
+ # 可以填写任意名的环境变量
79
+ # THE_ENV_VAR_YOU_NEED_TO_USE: "1234567"
80
+ # 底模设置
81
+ base:
82
+ use_base_model: false
83
+ repo_id: "Stardust_minus/Bert-VITS2"
84
+ model_image: "Bert-VITS2_2.1-Emo底模" # openi网页的模型名
85
+ # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
86
+ model: "models"
87
+ # 配置文件路径
88
+ config_path: "haruka.json"
89
+ # 训练使用的worker,不建议超过CPU核心数
90
+ num_workers: 16
91
+ # 关闭此项可以节约接近50%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。
92
+ spec_cache: True
93
+ # 保存的检查点数量,多于此数目的权重会被删除来节省空间。
94
+ keep_ckpts: 8
95
+
96
+
97
+ # webui webui配置
98
+ # 注意, “:” 后需要加空格
99
+ webui:
100
+ # 推理设备
101
+ device: "cuda"
102
+ # 模型路径
103
+ model: "models/G_82000.pth"
104
+ # 配置文件路径
105
+ config_path: "configs/haruka.json"
106
+ # 端口号
107
+ port: 7860
108
+ # 是否公开部署,对外网开放
109
+ share: false
110
+ # 是否开启debug模式
111
+ debug: false
112
+ # 语种识别库,可选langid, fastlid
113
+ language_identification_library: "langid"
114
+
115
+
116
+ # server api配置
117
+ # 注意, “:” 后需要加空格
118
+ # 注意,本配置下的所有配置均为相对于根目录的路径
119
+ server:
120
+ # 端口号
121
+ port: 5000
122
+ # 模型默认使用设备:但是当前并没有实现这个配置。
123
+ device: "cuda"
124
+ # 需要加载的所有模型的配置
125
+ # 注意,所有模型都必须正确配置model与config的路径,空路径会导致加载错误。
126
+ models:
127
+ - # 模型的路径
128
+ model: ""
129
+ # 模型config.json的路径
130
+ config: ""
131
+ # 模型使用设备,若填写则会覆盖默认配置
132
+ device: "cuda"
133
+ # 模型默认使用的语言
134
+ language: "ZH"
135
+ # 模型人物默认参数
136
+ # 不必填写所有人物,不填的使用默认值
137
+ # 暂时不用填写,当前尚未实现按人区分配置
138
+ speakers:
139
+ - speaker: "科比"
140
+ sdp_ratio: 0.2
141
+ noise_scale: 0.6
142
+ noise_scale_w: 0.8
143
+ length_scale: 1
144
+ - speaker: "五条悟"
145
+ sdp_ratio: 0.3
146
+ noise_scale: 0.7
147
+ noise_scale_w: 0.8
148
+ length_scale: 0.5
149
+ - speaker: "安倍晋三"
150
+ sdp_ratio: 0.2
151
+ noise_scale: 0.6
152
+ noise_scale_w: 0.8
153
+ length_scale: 1.2
154
+ - # 模型的路径
155
+ model: ""
156
+ # 模型config.json的路径
157
+ config: ""
158
+ # 模型使用设备,若填写则会覆盖默认配置
159
+ device: "cpu"
160
+ # 模型默认使用的语言
161
+ language: "JP"
162
+ # 模型人物默认参数
163
+ # 不必填写所有人物,不填的使用默认值
164
+ speakers: [ ] # 也可以不填
165
+
166
+
167
+ # 百度翻译开放平台 api配置
168
+ # api接入文档 https://api.fanyi.baidu.com/doc/21
169
+ # 请不要在github等网站公开分享你的app id 与 key
170
+ translate:
171
+ # 你的APPID
172
+ "app_key": ""
173
+ # 你的密钥
174
+ "secret_key": ""
emo_gen.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.data import DataLoader, Dataset
10
+ from tqdm import tqdm
11
+ from transformers import Wav2Vec2Processor
12
+ from transformers.models.wav2vec2.modeling_wav2vec2 import (
13
+ Wav2Vec2Model,
14
+ Wav2Vec2PreTrainedModel,
15
+ )
16
+
17
+ import utils
18
+ from config import config
19
+
20
+
21
class RegressionHead(nn.Module):
    r"""Classification head."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """Apply dropout -> dense -> tanh -> dropout -> output projection."""
        hidden = self.dense(self.dropout(features))
        hidden = self.dropout(torch.tanh(hidden))
        return self.out_proj(hidden)
40
+
41
+
42
class EmotionModel(Wav2Vec2PreTrainedModel):
    r"""Speech emotion classifier."""

    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = RegressionHead(config)
        self.init_weights()

    def forward(self, input_values):
        """Return ``(pooled_hidden_states, logits)`` for a batch of inputs.

        The first wav2vec2 output is mean-pooled over axis 1 before it is
        fed to the classifier.
        """
        encoder_output = self.wav2vec2(input_values)[0]
        pooled = torch.mean(encoder_output, dim=1)
        return pooled, self.classifier(pooled)
63
+
64
+
65
class AudioDataset(Dataset):
    """Dataset that loads wav files and returns processor input values."""

    def __init__(self, list_of_wav_files, sr, processor):
        self.list_of_wav_files = list_of_wav_files
        self.processor = processor
        self.sr = sr

    def __len__(self):
        return len(self.list_of_wav_files)

    def __getitem__(self, idx):
        """Load file ``idx`` at ``self.sr`` and return its input values as a tensor."""
        path = self.list_of_wav_files[idx]
        waveform, _ = librosa.load(path, sr=self.sr)
        features = self.processor(waveform, sampling_rate=self.sr)
        return torch.from_numpy(features["input_values"][0])
81
+
82
+
83
def process_func(
    x: np.ndarray,
    sampling_rate: int,
    model: EmotionModel,
    processor: Wav2Vec2Processor,
    device: str,
    embeddings: bool = False,
) -> np.ndarray:
    r"""Predict emotions or extract embeddings from raw audio signal.

    Returns the model's first output (pooled hidden states) when
    ``embeddings`` is true, otherwise its second output (logits), as a
    NumPy array.
    """
    model = model.to(device)
    features = processor(x, sampling_rate=sampling_rate)["input_values"][0]
    batch = torch.from_numpy(features).unsqueeze(0).to(device)

    # Run through the model without tracking gradients.
    with torch.no_grad():
        outputs = model(batch)
    selected = outputs[0] if embeddings else outputs[1]

    # Convert to numpy.
    return selected.detach().cpu().numpy()
105
+
106
+
107
def get_emo(path):
    """Return the emotion embedding of the audio file at ``path``.

    The audio is loaded at 16 kHz and run through the emotion model on
    ``config.bert_gen_config.device``.

    The module-level ``model`` and ``processor`` are normally created by the
    ``__main__`` block.  When this module is imported instead (for example
    by infer.py) they do not exist yet, so they are loaded here on first use
    and kept as module globals for later calls.
    """
    global model, processor

    device = config.bert_gen_config.device
    if "model" not in globals() or "processor" not in globals():
        model_dir = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
        repo_id = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
        if not Path(model_dir).joinpath("pytorch_model.bin").exists():
            utils.download_emo_models(config.mirror, model_dir, repo_id)
        processor = Wav2Vec2Processor.from_pretrained(model_dir)
        model = EmotionModel.from_pretrained(model_dir).to(device)

    # ``sr`` must be passed by keyword: recent librosa releases reject it
    # as a positional argument.
    wav, sr = librosa.load(path, sr=16000)
    return process_func(
        np.expand_dims(wav, 0).astype(np.float64),
        sr,
        model,
        processor,
        device,
        embeddings=True,
    ).squeeze(0)
118
+
119
+
120
if __name__ == "__main__":
    # Command-line entry point: write a "<name>.emo.npy" embedding next to
    # every wav listed in the training and validation file lists.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, default=config.bert_gen_config.config_path
    )
    parser.add_argument(
        "--num_processes", type=int, default=config.bert_gen_config.num_processes
    )
    args, _ = parser.parse_known_args()
    config_path = args.config
    hps = utils.get_hparams_from_file(config_path)

    device = config.bert_gen_config.device

    # Fetch the emotion model on first use.
    model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    if not Path(model_name).joinpath("pytorch_model.bin").exists():
        utils.download_emo_models(config.mirror, model_name, REPO_ID)

    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = EmotionModel.from_pretrained(model_name).to(device)

    lines = []
    for list_file in (hps.data.training_files, hps.data.validation_files):
        with open(list_file, encoding="utf-8") as f:
            lines.extend(f.readlines())

    wavnames = [line.split("|")[0] for line in lines]
    dataset = AudioDataset(wavnames, 16000, processor)
    # batch_size=1 and shuffle=False keep batch i aligned with wavnames[i].
    data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=16)

    with torch.no_grad():
        progress = tqdm(data_loader, total=len(data_loader))
        for wavname, data in zip(wavnames, progress):
            emo_path = wavname.replace(".wav", ".emo.npy")
            if os.path.exists(emo_path):
                continue
            emb = model(data.to(device))[0].detach().cpu().numpy()
            np.save(emo_path, emb)

    print("Emo vec 生成完毕!")
infer.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 版本管理、兼容推理及模型加载实现。
3
+ 版本说明:
4
+ 1. 版本号与github的release版本号对应,使用哪个release版本训练的模型即对应其版本号
5
+ 2. 请在模型的config.json中显式声明版本号,添加一个字段"version" : "你的版本号"
6
+ 特殊版本说明:
7
+ 1.1.1-fix: 1.1.1版本训练的模型,但是在推理时使用dev的日语修复
8
+ 1.1.1-dev: dev开发
9
+ 2.1:当前版本
10
+ """
11
+ import torch
12
+ import commons
13
+ from text import cleaned_text_to_sequence, get_bert
14
+ from emo_gen import get_emo
15
+ from text.cleaner import clean_text
16
+ import utils
17
+
18
+ from models import SynthesizerTrn
19
+ from text.symbols import symbols
20
+
21
+ # 当前版本信息
22
+ latest_version = "2.1"
23
+
24
+
25
+
26
def get_net_g(model_path: str, version: str, device: str, hps):
    """Build the generator for ``version`` and load its checkpoint.

    Args:
        model_path: checkpoint path passed to ``utils.load_checkpoint``.
        version: model version string; only ``latest_version`` is supported.
        device: device the network is moved to.
        hps: hyper-parameter object providing ``data``, ``train`` and ``model``.

    Returns:
        The ``SynthesizerTrn`` instance in eval mode with weights loaded.

    Raises:
        NotImplementedError: if ``version`` is not ``latest_version``.
            The old-version branch used to fall through without creating a
            network and then failed with an UnboundLocalError.
    """
    if version != latest_version:
        raise NotImplementedError(
            f"Model version {version!r} is not supported; "
            f"only version {latest_version!r} can be loaded."
        )

    # Current-version generator.
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).to(device)
    _ = net_g.eval()
    _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
    return net_g
41
+
42
+
43
def get_text(text, reference_audio, emotion, language_str, hps, device):
    """Turn raw text into the tensors the current-version model consumes.

    Args:
        text: input text.
        reference_audio: audio path used to derive the emotion embedding;
            when falsy, ``emotion`` is used instead.
        emotion: value wrapped in a one-element tensor when no reference
            audio is given.
        language_str: "ZH", "JP" or "EN".
        hps: hyper-parameters; ``hps.data.add_blank`` enables blank insertion.
        device: device passed to ``get_bert``.

    Returns:
        ``(bert, ja_bert, en_bert, emo, phone, tone, language)``; the BERT
        features of the two languages not in use are zero tensors of shape
        [1024, len(phone)].

    Raises:
        ValueError: if ``language_str`` is not one of the supported codes.
    """
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        # Each phone now has a blank after it, plus one leading blank.
        word2ph[:] = [count * 2 for count in word2ph]
        word2ph[0] += 1
    bert_ori = get_bert(norm_text, word2ph, language_str, device)
    del word2ph
    assert bert_ori.shape[-1] == len(phone), phone

    if language_str not in ("ZH", "JP", "EN"):
        raise ValueError("language_str should be ZH, JP or EN")
    n_phones = len(phone)
    bert = bert_ori if language_str == "ZH" else torch.zeros(1024, n_phones)
    ja_bert = bert_ori if language_str == "JP" else torch.zeros(1024, n_phones)
    en_bert = bert_ori if language_str == "EN" else torch.zeros(1024, n_phones)

    if reference_audio:
        emo = torch.from_numpy(get_emo(reference_audio))
    else:
        emo = torch.Tensor([emotion])

    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, ja_bert, en_bert, emo, phone, tone, language
88
+
89
+
90
def infer(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    reference_audio=None,
    emotion=None,
    skip_start=False,
    skip_end=False,
):
    """Synthesize audio for a single-language text.

    Args:
        text: Text to synthesize.
        sdp_ratio, noise_scale, noise_scale_w, length_scale: Forwarded
            unchanged to ``net_g.infer``.
        sid: Speaker name, looked up in ``hps.data.spk2id``.
        language: Language code forwarded to ``get_text``.
        hps: Hyper-parameter object.
        net_g: Synthesizer exposing ``infer``.
        device: Device the input tensors are moved to.
        reference_audio: Optional emotion reference, forwarded to ``get_text``.
        emotion: Optional emotion value, forwarded to ``get_text``.
        skip_start: Drop the first position of the phone-aligned inputs.
        skip_end: Drop the last position of the phone-aligned inputs.

    Returns:
        The synthesized waveform as a float NumPy array.
    """
    version = getattr(hps, "version", latest_version)
    if version != latest_version:
        # No version-specific inference path is implemented here; checkpoints
        # of other versions fall through to the current implementation.
        pass

    # Inference for the current version.
    bert, ja_bert, en_bert, emo, phones, tones, lang_ids = get_text(
        text, reference_audio, emotion, language, hps, device
    )

    # One slice expresses both optional trims along the phone axis.
    window = slice(1 if skip_start else None, -1 if skip_end else None)
    phones = phones[window]
    tones = tones[window]
    lang_ids = lang_ids[window]
    bert = bert[:, window]
    ja_bert = ja_bert[:, window]
    en_bert = en_bert[:, window]

    with torch.no_grad():

        def batched(tensor):
            # Move to the target device and add a leading batch axis.
            return tensor.to(device).unsqueeze(0)

        x_tst = batched(phones)
        tones = batched(tones)
        lang_ids = batched(lang_ids)
        bert = batched(bert)
        ja_bert = batched(ja_bert)
        en_bert = batched(en_bert)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        emo = batched(emo)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)

        outputs = net_g.infer(
            x_tst,
            x_tst_lengths,
            speakers,
            tones,
            lang_ids,
            bert,
            ja_bert,
            en_bert,
            emo,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )
        audio = outputs[0][0, 0].data.cpu().float().numpy()

        # Drop every tensor reference before releasing cached GPU memory.
        del outputs
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio
163
+
164
+
165
def infer_multilang(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    reference_audio=None,
    emotion=None,
    skip_start=False,
    skip_end=False,
):
    """Synthesize one utterance from several text segments of mixed languages.

    Each ``(text[i], language[i])`` pair is converted with ``get_text``; the
    per-segment phone-aligned tensors are trimmed at the inner segment
    boundaries, concatenated, and passed to ``net_g.infer`` as one sequence.

    Args:
        text: Sequence of text segments.
        sdp_ratio, noise_scale, noise_scale_w, length_scale: Forwarded
            unchanged to ``net_g.infer``.
        sid: Speaker name, looked up in ``hps.data.spk2id``.
        language: Sequence of language codes, one per text segment.
        hps: Hyper-parameter object.
        net_g: Synthesizer exposing ``infer``.
        device: Device the input tensors are moved to.
        reference_audio: Optional emotion reference, forwarded to ``get_text``.
        emotion: Optional emotion value, forwarded to ``get_text``.
        skip_start: Also drop the first position of the first segment.
        skip_end: Also drop the last position of the last segment.

    Returns:
        The synthesized waveform as a float NumPy array.

    Raises:
        ValueError: If ``text`` is empty or ``text`` and ``language`` differ
            in length.
    """
    if len(text) != len(language):
        raise ValueError(
            "text and language must have the same length, "
            f"got {len(text)} and {len(language)}"
        )
    if len(text) == 0:
        raise ValueError("text must contain at least one segment")

    last_idx = len(text) - 1
    bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
    emo = None
    for idx, (txt, lang) in enumerate(zip(text, language)):
        # Per-segment flags are kept in their own names: reassigning the
        # skip_start / skip_end parameters here made the caller's skip_end
        # value unreachable by the time the last segment was processed.
        trim_start = idx != 0 or skip_start
        trim_end = idx != last_idx or skip_end
        (
            temp_bert,
            temp_ja_bert,
            temp_en_bert,
            temp_emo,
            temp_phones,
            temp_tones,
            temp_lang_ids,
        ) = get_text(txt, reference_audio, emotion, lang, hps, device)
        # The emotion input depends only on reference_audio / emotion, not on
        # the segment, and is not aligned with the phones: keep the first one
        # untouched instead of slicing and concatenating it.
        if emo is None:
            emo = temp_emo
        if trim_start:
            temp_bert = temp_bert[:, 1:]
            temp_ja_bert = temp_ja_bert[:, 1:]
            temp_en_bert = temp_en_bert[:, 1:]
            temp_phones = temp_phones[1:]
            temp_tones = temp_tones[1:]
            temp_lang_ids = temp_lang_ids[1:]
        if trim_end:
            temp_bert = temp_bert[:, :-1]
            temp_ja_bert = temp_ja_bert[:, :-1]
            temp_en_bert = temp_en_bert[:, :-1]
            temp_phones = temp_phones[:-1]
            temp_tones = temp_tones[:-1]
            temp_lang_ids = temp_lang_ids[:-1]
        bert.append(temp_bert)
        ja_bert.append(temp_ja_bert)
        en_bert.append(temp_en_bert)
        phones.append(temp_phones)
        tones.append(temp_tones)
        lang_ids.append(temp_lang_ids)

    # BERT features are concatenated along their last (phone) axis, the id
    # sequences along their only axis.
    bert = torch.cat(bert, dim=1)
    ja_bert = torch.cat(ja_bert, dim=1)
    en_bert = torch.cat(en_bert, dim=1)
    phones = torch.cat(phones, dim=0)
    tones = torch.cat(tones, dim=0)
    lang_ids = torch.cat(lang_ids, dim=0)

    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        ja_bert = ja_bert.to(device).unsqueeze(0)
        en_bert = en_bert.to(device).unsqueeze(0)
        emo = emo.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                en_bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        # Drop every tensor reference before releasing cached GPU memory.
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio
models.py ADDED
@@ -0,0 +1,1044 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import modules
8
+ import attentions
9
+ import monotonic_align
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+ from vector_quantize_pytorch import VectorQuantize
14
+
15
+ from commons import init_weights, get_padding
16
+ from text import symbols, num_tones, num_languages
17
+
18
+
19
class DurationDiscriminator(nn.Module):  # vits2
    """Scores durations given (detached) encoder features.

    ``forward`` returns one probability tensor for the reference durations
    and one for the predicted durations.
    """

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        same_pad = kernel_size // 2

        self.drop = nn.Dropout(p_dropout)
        # Feature trunk shared by both duration inputs.
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        # Lifts the single-channel duration to the feature width.
        self.dur_proj = nn.Conv1d(1, filter_channels, 1)

        # Head applied to the concatenated [features, duration] channels.
        self.pre_out_conv_1 = nn.Conv1d(
            2 * filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
        self.pre_out_conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

        self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())

    def forward_probability(self, x, x_mask, dur, g=None):
        """Return the per-position probability for one duration tensor.

        ``g`` is accepted for signature compatibility and is not used.
        """
        h = torch.cat([x, self.dur_proj(dur)], dim=1)
        head = (
            (self.pre_out_conv_1, self.pre_out_norm_1),
            (self.pre_out_conv_2, self.pre_out_norm_2),
        )
        for conv, norm in head:
            h = self.drop(norm(torch.relu(conv(h * x_mask))))
        h = (h * x_mask).transpose(1, 2)
        return self.output_layer(h)

    def forward(self, x, x_mask, dur_r, dur_hat, g=None):
        """Return ``[prob(dur_r), prob(dur_hat)]``; ``x`` and ``g`` are detached."""
        h = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            h = h + self.cond(g)
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            h = self.drop(norm(torch.relu(conv(h * x_mask))))

        return [self.forward_probability(h, x_mask, dur, g) for dur in (dur_r, dur_hat)]
92
+
93
+
94
class TransformerCouplingBlock(nn.Module):
    """Stack of ``n_flows`` transformer coupling layers, each followed by a flip.

    When ``share_parameter`` is true a single ``attentions.FFT`` is created
    and handed to every coupling layer through ``wn_sharing_parameter``.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
        share_parameter=False,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()

        shared_fft = None
        if share_parameter:
            shared_fft = attentions.FFT(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=self.gin_channels,
            )
        self.wn = shared_fft

        for _ in range(n_flows):
            coupling = modules.TransformerCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                n_layers,
                n_heads,
                p_dropout,
                filter_channels,
                mean_only=True,
                wn_sharing_parameter=self.wn,
                gin_channels=self.gin_channels,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run the flows forward, or in reversed order when ``reverse`` is set."""
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x
        for flow in self.flows:
            # The second value returned by each flow is discarded.
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x
158
+
159
+
160
class StochasticDurationPredictor(nn.Module):
    """Flow-based duration predictor.

    With ``reverse=False`` the forward pass returns a per-batch loss term
    (``nll + logq``) for the given durations ``w``; with ``reverse=True`` it
    samples noise and returns ``logw``.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = self._flow_stack(n_flows, filter_channels, kernel_size)

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        # The posterior stack always uses four ConvFlow/Flip pairs.
        self.post_flows = self._flow_stack(4, filter_channels, kernel_size)

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    @staticmethod
    def _flow_stack(n_pairs, filter_channels, kernel_size):
        """Build ``ElementwiseAffine`` followed by ``n_pairs`` ConvFlow/Flip pairs."""
        stack = nn.ModuleList()
        stack.append(modules.ElementwiseAffine(2))
        for _ in range(n_pairs):
            stack.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            stack.append(modules.Flip())
        return stack

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        """Return the loss term (``reverse=False``) or sampled ``logw`` (``reverse=True``)."""
        x = self.pre(torch.detach(x))
        if g is not None:
            x = x + self.cond(torch.detach(g))
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if reverse:
            inverse_chain = list(reversed(self.flows))
            # remove a useless vflow (the second-to-last entry of the chain)
            inverse_chain = inverse_chain[:-2] + inverse_chain[-1:]
            noise = torch.randn(x.size(0), 2, x.size(2)).to(
                device=x.device, dtype=x.dtype
            )
            z = noise * noise_scale
            for flow in inverse_chain:
                z = flow(z, x_mask, g=x, reverse=reverse)
            logw, _ = torch.split(z, [1, 1], 1)
            return logw

        assert w is not None
        log_two_pi = math.log(2 * math.pi)

        # Posterior branch conditioned on the durations.
        logdet_tot_q = 0
        h_w = self.post_pre(w)
        h_w = self.post_convs(h_w, x_mask)
        h_w = self.post_proj(h_w) * x_mask
        e_q = (
            torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
            * x_mask
        )
        z_q = e_q
        for flow in self.post_flows:
            z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
            logdet_tot_q += logdet_q
        z_u, z1 = torch.split(z_q, [1, 1], 1)
        u = torch.sigmoid(z_u) * x_mask
        z0 = (w - u) * x_mask
        logdet_tot_q += torch.sum(
            (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
        )
        logq = torch.sum(-0.5 * (log_two_pi + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q

        # Main flow chain on [log-transformed z0, z1].
        logdet_tot = 0
        z0, logdet = self.log_flow(z0, x_mask)
        logdet_tot += logdet
        z = torch.cat([z0, z1], 1)
        for flow in self.flows:
            z, logdet = flow(z, x_mask, g=x, reverse=reverse)
            logdet_tot = logdet_tot + logdet
        nll = torch.sum(0.5 * (log_two_pi + (z**2)) * x_mask, [1, 2]) - logdet_tot
        return nll + logq  # [b]
269
+
270
+
271
class DurationPredictor(nn.Module):
    """Two masked conv/ReLU/norm/dropout stages followed by a 1-channel projection."""

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        same_pad = kernel_size // 2

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Return the masked single-channel prediction; ``x`` and ``g`` are detached."""
        h = torch.detach(x)
        if g is not None:
            h = h + self.cond(torch.detach(g))
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            h = self.drop(norm(torch.relu(conv(h * x_mask))))
        return self.proj(h * x_mask) * x_mask
312
+
313
+
314
class TextEncoder(nn.Module):
    """Text encoder that sums symbol, tone, language, BERT and emotion embeddings.

    ``forward`` returns the encoder output, the projected statistics split
    into ``m`` and ``logs``, the sequence mask and the emotion commitment
    loss.
    """

    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_speakers,
        gin_channels=0,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        emb_std = hidden_channels**-0.5
        # The symbol table size comes from ``symbols``, not from ``n_vocab``.
        self.emb = nn.Embedding(len(symbols), hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, emb_std)
        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
        nn.init.normal_(self.tone_emb.weight, 0.0, emb_std)
        self.language_emb = nn.Embedding(num_languages, hidden_channels)
        nn.init.normal_(self.language_emb.weight, 0.0, emb_std)

        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)

        self.emo_proj = nn.Linear(1024, 1024)
        # NOTE(review): this is a plain Python list holding the SAME
        # VectorQuantize instance n_speakers times. Being a list, it is not
        # registered as a submodule, so it is absent from state_dict() and is
        # not moved by .to(device) -- hence the explicit .cpu() calls in
        # forward(). Confirm this is intended before relying on per-speaker
        # codebooks.
        shared_quantizer = VectorQuantize(
            dim=1024,
            codebook_size=10,
            decay=0.8,
            commitment_weight=1.0,
            learnable_codebook=True,
            ema_update=False,
        )
        self.emo_quantizer = [shared_quantizer] * n_speakers
        self.emo_q_proj = nn.Linear(1024, hidden_channels)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(
        self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=None
    ):
        """Encode the inputs; returns ``(x, m, logs, x_mask, emo_commit_loss)``.

        ``emo`` is treated as a 1024-wide feature when its last dimension is
        1024, otherwise as codebook indices for speaker ``sid[0]``.
        """
        sid = sid.cpu()
        bert_emb = self.bert_proj(bert).transpose(1, 2)
        ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
        en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)

        if emo.size(-1) == 1024:
            projected = self.emo_proj(emo.unsqueeze(1))
            emo_commit_loss = torch.zeros(1)
            quantized_rows = []
            # Quantize each batch item on the CPU with its speaker's quantizer.
            for i in range(projected.size(0)):
                quantizer = self.emo_quantizer[sid[i]]
                row, _, row_commit_loss = quantizer(projected[i].unsqueeze(0).cpu())
                emo_commit_loss += row_commit_loss
                quantized_rows.append(row)
            emo_emb = torch.cat(quantized_rows, dim=0).to(projected.device)
            emo_commit_loss = emo_commit_loss.to(emo_emb.device)
        else:
            codes = emo.to(torch.int).cpu()
            emo_emb = (
                self.emo_quantizer[sid[0]]
                .get_output_from_indices(codes)
                .unsqueeze(0)
                .to(emo.device)
            )
            emo_commit_loss = torch.zeros(1)

        h = (
            self.emb(x)
            + self.tone_emb(tone)
            + self.language_emb(language)
            + bert_emb
            + ja_bert_emb
            + en_bert_emb
            + self.emo_q_proj(emo_emb)
        )
        h = h * math.sqrt(self.hidden_channels)  # [b, t, h]
        h = torch.transpose(h, 1, -1)  # [b, h, t]
        x_mask = commons.sequence_mask(x_lengths, h.size(2)).unsqueeze(1).to(h.dtype)

        encoded = self.encoder(h * x_mask, x_mask, g=g)
        stats = self.proj(encoded) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return encoded, m, logs, x_mask, emo_commit_loss
419
+
420
+
421
class ResidualCouplingBlock(nn.Module):
    """Stack of ``n_flows`` residual coupling layers, each followed by a flip."""

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run the flows forward, or in reversed order when ``reverse`` is set."""
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x
        for flow in self.flows:
            # The second value returned by each flow is discarded.
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x
464
+
465
+
466
class PosteriorEncoder(nn.Module):
    """Encodes its input into ``m``/``logs`` statistics and a sampled latent ``z``."""

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        # Twice the output width: one half for ``m``, one half for ``logs``.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(z, m, logs, x_mask)`` with ``z = m + noise * exp(logs)``, masked."""
        x_mask = commons.sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)
        hidden = self.pre(x) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.proj(hidden) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        noise = torch.randn_like(m)
        z = (m + noise * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
506
+
507
+
508
class Generator(torch.nn.Module):
    """Upsampling decoder: transposed-conv stages, each followed by averaged resblocks."""

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        # ``resblock == "1"`` selects ResBlock1, anything else ResBlock2.
        block_cls = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        # Each upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for stage, (rate, kernel) in enumerate(
            zip(upsample_rates, upsample_kernel_sizes)
        ):
            transposed = ConvTranspose1d(
                upsample_initial_channel // (2**stage),
                upsample_initial_channel // (2 ** (stage + 1)),
                kernel,
                rate,
                padding=(kernel - rate) // 2,
            )
            self.ups.append(weight_norm(transposed))

        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(block_cls(ch, k, d))

        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Decode ``x`` (optionally conditioned on ``g``) into a tanh-bounded signal."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            base = i * self.num_kernels
            acc = None
            for j in range(self.num_kernels):
                branch = self.resblocks[base + j](x)
                if acc is None:
                    acc = branch
                else:
                    acc += branch
            # Average the parallel resblock outputs of this stage.
            x = acc / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalization from the upsampling layers and resblocks."""
        print("Removing weight norm...")
        for up in self.ups:
            # Resolves to the module-level torch helper, not this method.
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
583
+
584
+
585
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the signal into rows of ``period`` samples
    and applies a stack of 2-D convolutions along the folded axis."""

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        # Only the literal ``False`` selects weight_norm.
        norm_f = spectral_norm if use_spectral_norm is not False else weight_norm
        pad = (get_padding(kernel_size, 1), 0)

        widths = [1, 32, 128, 512, 1024]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(
                norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1), padding=pad))
            )
        # The final 1024 -> 1024 layer uses stride 1.
        layers.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=pad)))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened score, feature maps)``."""
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:  # pad first
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
662
+
663
+
664
class DiscriminatorS(torch.nn.Module):
    """Discriminator built from a stack of (grouped) 1-D convolutions."""

    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        # Only the literal ``False`` selects weight_norm.
        norm_f = spectral_norm if use_spectral_norm is not False else weight_norm
        # (in_channels, out_channels, kernel_size, stride, groups, padding)
        layer_specs = (
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        )
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(c_in, c_out, k, s, groups=groups, padding=pad))
                for c_in, c_out, k, s, groups, pad in layer_specs
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(flattened score, feature maps)``."""
        fmap = []

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
692
+
693
+
694
class MultiPeriodDiscriminator(torch.nn.Module):
    """One ``DiscriminatorS`` plus a ``DiscriminatorP`` for each period in 2, 3, 5, 7, 11."""

    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        for period in (2, 3, 5, 7, 11):
            members.append(DiscriminatorP(period, use_spectral_norm=use_spectral_norm))
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Run every discriminator on ``y`` and ``y_hat``.

        Returns four lists, one entry per discriminator: scores for ``y``,
        scores for ``y_hat``, feature maps for ``y``, feature maps for
        ``y_hat``.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            score_r, feats_r = disc(y)
            score_g, feats_g = disc(y_hat)
            y_d_rs.append(score_r)
            y_d_gs.append(score_g)
            fmap_rs.append(feats_r)
            fmap_gs.append(feats_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
719
+
720
+
721
class ReferenceEncoder(nn.Module):
    """
    Six strided 2-D convolutions, a GRU and a linear projection.

    inputs --- reshaped in ``forward`` to [N, 1, -1, spec_channels]
    outputs --- [N, gin_channels]
    """

    def __init__(self, spec_channels, gin_channels=0):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        channel_chain = [1] + ref_enc_filters

        self.convs = nn.ModuleList()
        for c_in, c_out in zip(channel_chain[:-1], channel_chain[1:]):
            conv = nn.Conv2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1),
            )
            self.convs.append(weight_norm(conv))

        # Size of the spec_channels axis after the K strided convolutions.
        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)

    def forward(self, inputs, mask=None):
        """Encode ``inputs`` to ``[N, gin_channels]``; ``mask`` is not used."""
        batch = inputs.size(0)
        feats = inputs.view(batch, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        for conv in self.convs:
            feats = F.relu(conv(feats))  # ends at [N, 128, Ty//2^K, n_mels//2^K]

        feats = feats.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        steps = feats.size(1)
        batch = feats.size(0)
        feats = feats.contiguous().view(batch, steps, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        _, final_state = self.gru(feats)  # final_state --- [1, N, 128]

        return self.proj(final_state.squeeze(0))

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        """Return the length left after ``n_convs`` strided convolutions of length ``L``."""
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
778
+
779
+
780
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    Wires together the text/prior encoder (``enc_p``), the posterior encoder
    (``enc_q``), the flow, the two duration predictors (``sdp`` and ``dp``)
    and the waveform decoder (``dec``).

    The conditioning vector ``g`` comes from a speaker embedding table when
    ``n_speakers >= 1`` and from a reference encoder applied to ``y``
    otherwise.

    Optional keyword arguments read from ``**kwargs``:
        use_spk_conditioned_encoder (default ``True``)
        use_noise_scaled_mas (default ``False``)
        mas_noise_scale_initial (default ``0.01``)
        noise_scale_delta (default ``2e-6``)
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=256,
        gin_channels=256,
        use_sdp=True,
        n_flow_layer=4,
        n_layers_trans_flow=4,
        flow_share_parameter=False,
        use_transformer_flow=True,
        **kwargs
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_layers_trans_flow = n_layers_trans_flow
        self.use_spk_conditioned_encoder = kwargs.get(
            "use_spk_conditioned_encoder", True
        )
        self.use_sdp = use_sdp
        self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
        self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
        self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
        self.current_mas_noise_scale = self.mas_noise_scale_initial
        if self.use_spk_conditioned_encoder and gin_channels > 0:
            self.enc_gin_channels = gin_channels
        else:
            # Previously left unset, which made the TextEncoder construction
            # below fail with AttributeError. 0 is used as "no conditioning".
            # NOTE(review): assumes TextEncoder treats gin_channels == 0 as
            # disabled, like the other modules in this project - confirm.
            self.enc_gin_channels = 0
        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            self.n_speakers,
            gin_channels=self.enc_gin_channels,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        if use_transformer_flow:
            self.flow = TransformerCouplingBlock(
                inter_channels,
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers_trans_flow,
                5,
                p_dropout,
                n_flow_layer,
                gin_channels=gin_channels,
                share_parameter=flow_share_parameter,
            )
        else:
            self.flow = ResidualCouplingBlock(
                inter_channels,
                hidden_channels,
                5,
                1,
                n_flow_layer,
                gin_channels=gin_channels,
            )
        self.sdp = StochasticDurationPredictor(
            hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
        )
        self.dp = DurationPredictor(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )

        if n_speakers >= 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
        else:
            self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)

    def forward(
        self,
        x,
        x_lengths,
        y,
        y_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        en_bert,
        emo=None,
    ):
        """Run the training pass.

        Returns a tuple
        ``(o, l_length, attn, ids_slice, x_mask, y_mask,
        (z, z_p, m_p, logs_p, m_q, logs_q), (x, logw, logw_), loss_commit)``
        where ``o`` is the decoder output for a random slice of ``z`` and
        ``l_length`` is the sum of both duration-predictor losses.
        """
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
        x, m_p, logs_p, x_mask, loss_commit = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        # The alignment search is not differentiated through.
        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (m_p * s_p_sq_r)
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
            if self.use_noise_scaled_mas:
                # Perturb the scores with Gaussian noise scaled by their std.
                epsilon = (
                    torch.std(neg_cent)
                    * torch.randn_like(neg_cent)
                    * self.current_mas_noise_scale
                )
                neg_cent = neg_cent + epsilon

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
                .unsqueeze(1)
                .detach()
            )

        # Per-token durations implied by the alignment.
        w = attn.sum(2)

        l_length_sdp = self.sdp(x, x_mask, w, g=g)
        l_length_sdp = l_length_sdp / torch.sum(x_mask)

        logw_ = torch.log(w + 1e-6) * x_mask
        logw = self.dp(x, x_mask, g=g)
        l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
            x_mask
        )  # for averaging

        l_length = l_length_dp + l_length_sdp

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return (
            o,
            l_length,
            attn,
            ids_slice,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            (x, logw, logw_),
            loss_commit,
        )

    def infer(
        self,
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        en_bert,
        emo=None,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
        max_len=None,
        sdp_ratio=0,
        y=None,
    ):
        """Synthesize from text features.

        ``sdp_ratio`` blends the two duration predictors (0 uses only ``dp``,
        1 only ``sdp``), ``length_scale`` scales the predicted durations,
        ``noise_scale`` scales the noise added to the prior sample and
        ``noise_scale_w`` is passed to the stochastic duration predictor.
        ``y`` is only needed when the model has no speaker embedding table.

        Returns ``(o, attn, y_mask, (z, z_p, m_p, logs_p))``.
        """
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
        x, m_p, logs_p, x_mask, _ = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
        )
        logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
            sdp_ratio
        ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        # Clamp so every item has an output length of at least 1.
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)
modules.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from torch.nn import Conv1d
7
+ from torch.nn.utils import weight_norm, remove_weight_norm
8
+
9
+ import commons
10
+ from commons import init_weights, get_padding
11
+ from transforms import piecewise_rational_quadratic_transform
12
+ from attentions import Encoder
13
+
14
+ LRELU_SLOPE = 0.1
15
+
16
+
17
class LayerNorm(nn.Module):
    """Layer normalisation applied over dimension 1 of the input."""

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Learnable scale and shift, initialised to the identity transform.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # F.layer_norm normalises the trailing dimension, so swap dim 1 to
        # the end, normalise, and swap it back.
        channels_last = x.transpose(1, -1)
        normalised = F.layer_norm(
            channels_last, (self.channels,), self.gamma, self.beta, self.eps
        )
        return normalised.transpose(1, -1)
30
+
31
+
32
class ConvReluNorm(nn.Module):
    """Stack of Conv1d -> LayerNorm -> ReLU -> Dropout layers with a residual.

    The output is ``(x + proj(stack(x))) * x_mask``. ``proj`` is
    zero-initialised, so a freshly constructed module returns ``x * x_mask``.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # The message used to say "larger than 0" although the condition
        # rejects n_layers == 1; it now states the bound actually enforced.
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        # First layer maps in_channels -> hidden_channels.
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        # Remaining layers keep the hidden width.
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            # Mask before every convolution so padded positions stay zero.
            x = conv(x * x_mask)
            x = norm(x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
82
+
83
+
84
class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution

    Each layer is a depthwise dilated convolution followed by a 1x1
    convolution, both followed by LayerNorm and GELU, and is added back to
    its input. Layer ``i`` uses dilation ``kernel_size ** i``.
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for layer_idx in range(n_layers):
            rate = kernel_size**layer_idx
            depthwise = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                groups=channels,
                dilation=rate,
                # Keeps the sequence length unchanged for odd kernel sizes.
                padding=(kernel_size * rate - rate) // 2,
            )
            self.convs_sep.append(depthwise)
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        layers = zip(self.convs_sep, self.norms_1, self.convs_1x1, self.norms_2)
        for depthwise, norm_a, pointwise, norm_b in layers:
            branch = F.gelu(norm_a(depthwise(x * x_mask)))
            branch = F.gelu(norm_b(pointwise(branch)))
            x = x + self.drop(branch)
        return x * x_mask
131
+
132
+
133
class WN(torch.nn.Module):
    """Stack of weight-normalised, dilated, gated 1-D convolutions.

    Layer ``i`` uses dilation ``dilation_rate ** i``. When ``gin_channels``
    is non-zero, a 1x1 convolution projects the conditioning input ``g`` to
    ``2 * hidden_channels`` channels per layer, and each layer consumes its
    own slice. Every layer but the last emits a residual half (added to the
    running input) and a skip half (accumulated into the output); the last
    layer emits only the skip part.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        # Was ``(kernel_size,)``: a stray trailing comma stored a 1-tuple.
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One projection produces the conditioning for all layers at once.
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Return the masked sum of the skip outputs of all layers."""
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                # Slice out this layer's share of the projected conditioning.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution in the stack."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
219
+
220
+
221
class ResBlock1(torch.nn.Module):
    """Residual block of three (dilated conv, plain conv) pairs.

    The first convolution of each pair uses ``dilation[0..2]``; the second
    always uses dilation 1. All convolutions are weight-normalised.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()

        def make_conv(rate):
            # Length-preserving, weight-normalised conv with the given dilation.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        self.convs1 = nn.ModuleList(
            [make_conv(dilation[0]), make_conv(dilation[1]), make_conv(dilation[2])]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1), make_conv(1), make_conv(1)])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        use_mask = x_mask is not None
        for dilated, plain in zip(self.convs1, self.convs2):
            hidden = F.leaky_relu(x, LRELU_SLOPE)
            hidden = dilated(hidden * x_mask if use_mask else hidden)
            hidden = F.leaky_relu(hidden, LRELU_SLOPE)
            hidden = plain(hidden * x_mask if use_mask else hidden)
            x = hidden + x
        return x * x_mask if use_mask else x

    def remove_weight_norm(self):
        """Strip weight normalisation from all convolutions of the block."""
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
316
+
317
+
318
class ResBlock2(torch.nn.Module):
    """Residual block of two weight-normalised dilated convolutions."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()

        def make_conv(rate):
            # Length-preserving, weight-normalised conv with the given dilation.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        self.convs = nn.ModuleList([make_conv(dilation[0]), make_conv(dilation[1])])
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        use_mask = x_mask is not None
        for conv in self.convs:
            hidden = F.leaky_relu(x, LRELU_SLOPE)
            hidden = conv(hidden * x_mask if use_mask else hidden)
            x = hidden + x
        return x * x_mask if use_mask else x

    def remove_weight_norm(self):
        """Strip weight normalisation from both convolutions."""
        for conv in self.convs:
            remove_weight_norm(conv)
361
+
362
+
363
class Log(nn.Module):
    """Flow step taking the element-wise log (forward) or exp (reverse)."""

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            return torch.exp(x) * x_mask
        # Clamp so the logarithm never sees values below 1e-5.
        y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
        return y, torch.sum(-y, [1, 2])
372
+
373
+
374
class Flip(nn.Module):
    """Flow step that reverses the order of dimension 1."""

    def forward(self, x, *args, reverse=False, **kwargs):
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        # A permutation contributes a zero log-determinant per batch item.
        logdet = torch.zeros(flipped.size(0)).to(
            dtype=flipped.dtype, device=flipped.device
        )
        return flipped, logdet
382
+
383
+
384
class ElementwiseAffine(nn.Module):
    """Flow step applying a learned per-channel shift and log-scale."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        # Zero-initialised, so the step starts out as the identity.
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            return (x - self.m) * torch.exp(-self.logs) * x_mask
        transformed = (self.m + torch.exp(self.logs) * x) * x_mask
        logdet = torch.sum(self.logs * x_mask, [1, 2])
        return transformed, logdet
400
+
401
+
402
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer whose statistics come from a ``WN`` stack.

    The first half of the channels passes through unchanged and is used to
    predict a shift (and, unless ``mean_only``, a log-scale) for the second
    half. ``post`` is zero-initialised, so the layer starts as the identity.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )
        # One output group for the mean, plus one for the log-scale unless
        # mean_only is set.
        stats_channels = self.half_channels * (2 - mean_only)
        self.post = nn.Conv1d(hidden_channels, stats_channels, 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        hidden = self.enc(self.pre(x0) * x_mask, x_mask, g=g)
        stats = self.post(hidden) * x_mask
        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)

        if reverse:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            return torch.cat([x0, x1], 1)

        x1 = m + x1 * torch.exp(logs) * x_mask
        return torch.cat([x0, x1], 1), torch.sum(logs, [1, 2])
457
+
458
+
459
class ConvFlow(nn.Module):
    """Coupling layer that transforms the second half of the channels with a
    piecewise rational-quadratic spline parameterised from the first half."""

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        n_layers,
        num_bins=10,
        tail_bound=5.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        # num_bins widths + num_bins heights + (num_bins - 1) derivatives
        # for every transformed channel.
        params_per_channel = num_bins * 3 - 1
        self.proj = nn.Conv1d(
            filter_channels, self.half_channels * params_per_channel, 1
        )
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        params = self.proj(self.convs(self.pre(x0), x_mask, g=g)) * x_mask

        batch, half, steps = x0.shape
        # [b, c x ?, t] -> [b, c, t, ?]
        params = params.reshape(batch, half, -1, steps).permute(0, 1, 3, 2)

        bins = self.num_bins
        scale = math.sqrt(self.filter_channels)
        unnormalized_widths = params[..., :bins] / scale
        unnormalized_heights = params[..., bins : 2 * bins] / scale
        unnormalized_derivatives = params[..., 2 * bins :]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if reverse:
            return x
        return x, logdet
517
+
518
+
519
class TransformerCouplingLayer(nn.Module):
    """Affine coupling layer whose statistics come from an attention encoder.

    The first half of the channels passes through unchanged and is used to
    predict a shift (and, unless ``mean_only``, a log-scale) for the second
    half. If ``wn_sharing_parameter`` is given, it is used as the encoder
    instead of building a new ``Encoder``. ``post`` is zero-initialised, so
    the layer starts as the identity.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        n_layers,
        n_heads,
        p_dropout=0,
        filter_channels=0,
        mean_only=False,
        wn_sharing_parameter=None,
        gin_channels=0,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = (
            Encoder(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=gin_channels,
            )
            if wn_sharing_parameter is None
            else wn_sharing_parameter
        )
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the coupling transform.

        Returns ``(x, logdet)`` in the forward direction and only ``x`` when
        ``reverse`` is true.
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        # An unreachable spline-transform block used to follow these returns;
        # it referenced undefined names and has been removed.
        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x
monotonic_align/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numpy import zeros, int32, float32
2
+ from torch import from_numpy
3
+
4
+ from .core import maximum_path_jit
5
+
6
+
7
def maximum_path(neg_cent, mask):
    """Run the numba alignment search and return the path as a tensor.

    ``neg_cent`` is copied to a float32 NumPy array on the CPU, the valid
    extents along each axis are read from ``mask`` sums, and
    ``maximum_path_jit`` fills an int32 path array in place. The result is
    returned with the device and dtype of ``neg_cent``.
    """
    out_device, out_dtype = neg_cent.device, neg_cent.dtype
    scores = neg_cent.data.cpu().numpy().astype(float32)
    path = zeros(scores.shape, dtype=int32)

    # Per-item valid lengths along the two alignment axes.
    t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
    t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
    maximum_path_jit(path, scores, t_t_max, t_s_max)
    return from_numpy(path).to(device=out_device, dtype=out_dtype)
monotonic_align/core.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numba
2
+
3
+
4
@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],
        numba.float32[:, :, ::1],
        numba.int32[::1],
        numba.int32[::1],
    ),
    nopython=True,
    nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
    """Fill ``paths`` in place with the best monotonic path per batch item.

    ``values`` is overwritten with cumulative scores during the forward
    pass; the backward pass then writes 1 into ``paths`` along the chosen
    path. ``t_ys`` and ``t_xs`` give the valid extent of each item.
    """
    n_items = paths.shape[0]
    max_neg_val = -1e9
    for item in range(int(n_items)):
        path = paths[item]
        value = values[item]
        t_y = t_ys[item]
        t_x = t_xs[item]

        v_prev = v_cur = 0.0

        # Forward pass: accumulate the best score reaching each (y, x).
        for y in range(t_y):
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                # Staying on the same x is impossible on the diagonal.
                if x == y:
                    v_cur = max_neg_val
                else:
                    v_cur = value[y - 1, x]
                # Advancing from x - 1 is only free at the origin.
                if x != 0:
                    v_prev = value[y - 1, x - 1]
                elif y == 0:
                    v_prev = 0.0
                else:
                    v_prev = max_neg_val
                value[y, x] += max(v_prev, v_cur)

        # Backward pass: trace the path from the last valid column.
        index = t_x - 1
        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1
re_matching.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
def extract_language_and_text_updated(speaker, dialogue):
    """Split a dialogue into ``(LANGUAGE, text)`` pairs and append the speaker.

    Every ``<lang>text`` segment in ``dialogue`` yields a tuple holding the
    upper-cased language tag and the stripped text. The speaker name, with
    its first and last characters (the brackets) removed, is appended as the
    final element of the returned list.
    """
    tagged_segment = re.compile(r"<(\S+?)>([^<]+)", re.DOTALL)
    result = []
    for found in tagged_segment.finditer(dialogue):
        # Trim surrounding whitespace from the text of each segment.
        result.append((found.group(1).upper(), found.group(2).strip()))
    result.append(speaker[1:-1])
    return result
13
+
14
+
15
def validate_text(input_text):
    """Check that ``input_text`` follows the ``[speaker]<lang>text`` format.

    Returns a ``(is_valid, message)`` tuple.
    """
    # A speaker in brackets followed by one or more <lang>text segments.
    pattern_speaker = r"(\[\S+?\])((?:\s*<\S+?>[^<\[\]]+?)+)"

    # re.DOTALL is kept from the original; the pattern itself contains no ".".
    matches = re.findall(pattern_speaker, input_text, re.DOTALL)

    # Nothing in the input matched the speaker format.
    if not matches:
        return (
            False,
            "Error: No valid speaker format detected. Please check your input.",
        )

    # Validate the dialogue that belongs to each matched speaker.
    for speaker, dialogue in matches:
        parsed = extract_language_and_text_updated(speaker, dialogue)
        # The helper always appends the speaker as the last element, so the
        # old "is the list empty" test could never fire. Only the entries
        # before it are (language, text) pairs.
        if not parsed[:-1]:
            return (
                False,
                "Error: Invalid format detected in dialogue content. Please check your input.",
            )

    return True, "Input is valid."
39
+
40
+
41
def text_matching(text: str) -> list:
    """Split ``text`` into per-speaker entries and parse each dialogue.

    Each ``[speaker]`` tag and the text up to the next tag (or the end of
    the input) is passed to ``extract_language_and_text_updated``. The list
    of results is printed and returned.
    """
    speaker_pattern = r"(\[\S+?\])(.+?)(?=\[\S+?\]|$)"
    result = [
        extract_language_and_text_updated(speaker, dialogue)
        for speaker, dialogue in re.findall(speaker_pattern, text, re.DOTALL)
    ]
    print(result)
    return result
49
+
50
+
51
def cut_para(text):
    """Split ``text`` on newlines and return the non-blank, stripped lines."""
    paragraphs = []
    for piece in re.split("[\n]", text):  # split into paragraphs
        trimmed = piece.strip()
        if trimmed:  # drop empty strings
            paragraphs.append(trimmed)
    return paragraphs
57
+
58
+
59
def cut_sent(para):
    """Split a paragraph into sentences and return them as a list.

    A line break is inserted after sentence-ending punctuation, a run of six
    dots, or a doubled ellipsis character, unless a closing quote follows.
    When a closing quote does follow an end mark, the break goes after the
    quote. The paragraph is then split on those line breaks.

    The patterns are raw strings: the previous plain literals contained
    invalid escape sequences (``\\.``, ``\\?``), which Python warns about.
    """
    para = re.sub(r"([。!;?\?])([^”’])", r"\1\n\2", para)  # single-character sentence enders
    para = re.sub(r"(\.{6})([^”’])", r"\1\n\2", para)  # English ellipsis
    para = re.sub(r"(\…{2})([^”’])", r"\1\n\2", para)  # Chinese ellipsis
    # An end mark followed by a closing quote: break after the quote.
    para = re.sub(r"([。!?\?][”’])([^,。!?\?])", r"\1\n\2", para)
    para = para.rstrip()  # drop any trailing newline at the end of the paragraph
    return para.split("\n")
66
+
67
+
68
if __name__ == "__main__":
    # Sample with a speaker that has no dialogue, followed by two speakers
    # with tagged multi-language dialogue.
    text = """
[说话人1]
[说话人2]<zh>你好吗?<jp>元気ですか?<jp>こんにちは,世界。<zh>你好吗?
[说话人3]<zh>谢谢。<jp>どういたしまして。
"""
    # Test input for the matching and validation functions.
    test_text = """
[说话人1]<zh>你好,こんにちは!<jp>こんにちは,世界。
[说话人2]<zh>你好吗?
"""
    for sample in (text, test_text):
        text_matching(sample)
    res = validate_text(test_text)
    print(res)
requirements.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librosa==0.9.1
2
+ matplotlib
3
+ numpy
4
+ numba
5
+ phonemizer
6
+ scipy
7
+ tensorboard
8
+ Unidecode
9
+ amfm_decompy
10
+ jieba
11
+ transformers
12
+ pypinyin
13
+ cn2an
14
+ gradio
15
+ av
16
+ mecab-python3
17
+ loguru
18
+ unidic-lite
19
+ cmudict
20
+ fugashi
21
+ num2words
22
+ PyYAML
23
+ requests
24
+ pyopenjtalk; sys_platform == 'linux'
25
+ openjtalk; sys_platform != 'linux'
26
+ jaconv
27
+ psutil
28
+ GPUtil
29
+ vector_quantize_pytorch
30
+ g2p_en
31
+ sentencepiece
32
+ pykakasi
33
+ langid
text/__init__.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from text.symbols import *
2
+
3
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
4
+
5
+
6
def cleaned_text_to_sequence(cleaned_text, tones, language):
    """Convert cleaned symbols, tones and a language code into ID sequences.
    Args:
      cleaned_text: iterable of symbols, each present in ``symbols``
      tones: iterable of tone numbers, shifted by the language's tone offset
      language: key into ``language_tone_start_map`` and ``language_id_map``
    Returns:
      Tuple ``(phones, tones, lang_ids)`` of three lists: symbol IDs,
      offset tone IDs, and the language ID repeated once per symbol
    """
    phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
    tone_offset = language_tone_start_map[language]
    shifted_tones = [tone + tone_offset for tone in tones]
    lang_ids = [language_id_map[language]] * len(phones)
    return phones, shifted_tones, lang_ids
19
+
20
+
21
def get_bert(norm_text, word2ph, language, device):
    """Return BERT features for ``norm_text`` from the extractor of ``language``.

    ``language`` must be one of ``"ZH"``, ``"EN"`` or ``"JP"``; any other
    value raises ``KeyError``. The extractor modules are imported inside the
    function rather than at module import time.
    """
    from .chinese_bert import get_bert_feature as zh_bert
    from .english_bert_mock import get_bert_feature as en_bert
    from .japanese_bert import get_bert_feature as jp_bert

    extractors = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
    extract = extractors[language]
    return extract(norm_text, word2ph, device)
29
+
30
+
31
def check_bert_models():
    """Make sure every model listed in ``./bert/bert_models.json`` is on disk.

    When the configured mirror is ``openi``, logs in to it first (using
    ``config.openi_token`` if one is set). Each manifest entry is then passed
    to ``_check_bert``.
    """
    import json
    from pathlib import Path

    from config import config
    from .bert_utils import _check_bert

    if config.mirror.lower() == "openi":
        import openi

        kwargs = {"token": config.openi_token} if config.openi_token else {}
        openi.login(**kwargs)

    # JSON is UTF-8 by specification; do not rely on the platform default.
    with open("./bert/bert_models.json", "r", encoding="utf-8") as fp:
        models = json.load(fp)
        for k, v in models.items():
            local_path = Path("./bert").joinpath(k)
            _check_bert(v["repo_id"], v["files"], local_path)
49
+
50
+
51
+ #check_bert_models()
text/bert_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from huggingface_hub import hf_hub_download
4
+
5
+ from config import config
6
+
7
+
8
+ MIRROR: str = config.mirror
9
+
10
+
11
def _check_bert(repo_id, files, local_path):
    """Download any of ``files`` that is missing under ``local_path``.

    With the ``openi`` mirror, the model named by the last path component of
    ``repo_id`` is fetched into ``./bert``; otherwise the single missing file
    is fetched from the Hugging Face Hub into ``local_path``.
    """
    target_dir = Path(local_path)
    for file in files:
        if target_dir.joinpath(file).exists():
            continue
        if MIRROR.lower() == "openi":
            import openi

            openi.model.download_model(
                "Stardust_minus/Bert-VITS2", repo_id.split("/")[-1], "./bert"
            )
        else:
            hf_hub_download(
                repo_id, file, local_dir=local_path, local_dir_use_symlinks=False
            )
text/chinese.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ import cn2an
5
+ from pypinyin import lazy_pinyin, Style
6
+
7
+ from text.symbols import punctuation
8
+ from text.tone_sandhi import ToneSandhi
9
+
10
current_file_path = os.path.dirname(__file__)
# Lookup table built from the tab-separated file next to this module:
# first column -> key, second column -> value (the value is later split on
# spaces into individual phone symbols). The file is closed once it is read.
with open(
    os.path.join(current_file_path, "opencpop-strict.txt"), encoding="utf-8"
) as _pinyin_table:
    pinyin_to_symbol_map = {
        line.split("\t")[0]: line.strip().split("\t")[1] for line in _pinyin_table
    }
15
+
16
+ import jieba.posseg as psg
17
+
18
+
19
+ rep_map = {
20
+ ":": ",",
21
+ ";": ",",
22
+ ",": ",",
23
+ "。": ".",
24
+ "!": "!",
25
+ "?": "?",
26
+ "\n": ".",
27
+ "·": ",",
28
+ "、": ",",
29
+ "...": "…",
30
+ "$": ".",
31
+ "“": "'",
32
+ "”": "'",
33
+ '"': "'",
34
+ "‘": "'",
35
+ "’": "'",
36
+ "(": "'",
37
+ ")": "'",
38
+ "(": "'",
39
+ ")": "'",
40
+ "《": "'",
41
+ "》": "'",
42
+ "【": "'",
43
+ "】": "'",
44
+ "[": "'",
45
+ "]": "'",
46
+ "—": "-",
47
+ "~": "-",
48
+ "~": "-",
49
+ "「": "'",
50
+ "」": "'",
51
+ }
52
+
53
+ tone_modifier = ToneSandhi()
54
+
55
+
56
def replace_punctuation(text):
    """Map punctuation through ``rep_map`` and strip every other foreign character.

    Two character substitutions ("嗯" -> "恩", "呣" -> "母") are applied first.
    After the ``rep_map`` substitution, anything that is neither in the range
    U+4E00..U+9FA5 nor one of the symbols in ``punctuation`` is removed.
    """
    text = text.replace("嗯", "恩").replace("呣", "母")
    mapping_rx = re.compile("|".join(re.escape(key) for key in rep_map))
    mapped = mapping_rx.sub(lambda m: rep_map[m.group()], text)

    disallowed = r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+"
    return re.sub(disallowed, "", mapped)
67
+
68
+
69
def g2p(text):
    """Convert normalized text into ``(phones, tones, word2ph)``.

    The text is split after every punctuation symbol, the non-blank pieces are
    converted by ``_g2p``, and all three result lists are padded with one
    boundary entry at each end (``"_"``, tone ``0``, count ``1``).
    """
    boundary = r"(?<=[{0}])\s*".format("".join(punctuation))
    sentences = [part for part in re.split(boundary, text) if part.strip() != ""]
    phones, tones, word2ph = _g2p(sentences)
    assert sum(word2ph) == len(phones)
    assert len(word2ph) == len(text)  # Sometimes it will crash,you can add a try-catch.
    return ["_"] + phones + ["_"], [0] + tones + [0], [1] + word2ph + [1]
79
+
80
+
81
def _get_initials_finals(word):
    """Return parallel lists of pinyin initials and tone-numbered finals for ``word``."""
    raw_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
    raw_finals = lazy_pinyin(
        word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
    )
    # Pair the two sequences first so both results have the same length.
    paired = list(zip(raw_initials, raw_finals))
    return [initial for initial, _ in paired], [final for _, final in paired]
92
+
93
+
94
def _g2p(segments):
    """Convert text segments into flat phone, tone and per-character count lists.

    Returns ``(phones_list, tones_list, word2ph)``: the phone symbols, one
    integer tone per phone, and for every initial/final pair the number of
    phones it produced.
    """
    phones_list = []
    tones_list = []
    word2ph = []
    for seg in segments:
        # Replace all English words in the sentence
        seg = re.sub("[a-zA-Z]+", "", seg)
        seg_cut = psg.lcut(seg)
        initials = []
        finals = []
        seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
        for word, pos in seg_cut:
            if pos == "eng":
                continue
            sub_initials, sub_finals = _get_initials_finals(word)
            sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
            initials.append(sub_initials)
            finals.append(sub_finals)

            # assert len(sub_initials) == len(sub_finals) == len(word)
        # Flatten the per-word lists into one list per segment.
        initials = sum(initials, [])
        finals = sum(finals, [])
        #
        for c, v in zip(initials, finals):
            raw_pinyin = c + v
            # NOTE: post process for pypinyin outputs
            # we discriminate i, ii and iii
            if c == v:
                # Initial and final are identical: asserted to be a punctuation symbol.
                assert c in punctuation
                phone = [c]
                tone = "0"
                word2ph.append(1)
            else:
                # The last character of the final is the tone digit.
                v_without_tone = v[:-1]
                tone = v[-1]

                pinyin = c + v_without_tone
                assert tone in "12345"

                if c:
                    # Syllable has an initial (original note: multi-syllable)
                    v_rep_map = {
                        "uei": "ui",
                        "iou": "iu",
                        "uen": "un",
                    }
                    if v_without_tone in v_rep_map.keys():
                        pinyin = c + v_rep_map[v_without_tone]
                else:
                    # Syllable has no initial (original note: single syllable)
                    pinyin_rep_map = {
                        "ing": "ying",
                        "i": "yi",
                        "in": "yin",
                        "u": "wu",
                    }
                    if pinyin in pinyin_rep_map.keys():
                        pinyin = pinyin_rep_map[pinyin]
                    else:
                        # Otherwise only the leading letter is rewritten.
                        single_rep_map = {
                            "v": "yu",
                            "e": "e",
                            "i": "y",
                            "u": "w",
                        }
                        if pinyin[0] in single_rep_map.keys():
                            pinyin = single_rep_map[pinyin[0]] + pinyin[1:]

                assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
                phone = pinyin_to_symbol_map[pinyin].split(" ")
                word2ph.append(len(phone))

            phones_list += phone
            # Every phone of the syllable carries the same tone.
            tones_list += [int(tone)] * len(phone)
    return phones_list, tones_list, word2ph
169
+
170
+
171
def text_normalize(text):
    """Spell out numbers via ``cn2an.an2cn`` and then normalize punctuation."""
    # Each integer or decimal literal is converted in one left-to-right pass.
    spelled = re.sub(
        r"\d+(?:\.?\d+)?", lambda m: cn2an.an2cn(m.group()), text
    )
    return replace_punctuation(spelled)
177
+
178
+
179
def get_bert_feature(text, word2ph):
    """Delegate to ``text.chinese_bert.get_bert_feature`` (imported on call)."""
    from text.chinese_bert import get_bert_feature as _extract

    return _extract(text, word2ph)
183
+
184
+
185
+ if __name__ == "__main__":
186
+ from text.chinese_bert import get_bert_feature
187
+
188
+ text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
189
+ text = text_normalize(text)
190
+ print(text)
191
+ phones, tones, word2ph = g2p(text)
192
+ bert = get_bert_feature(text, word2ph)
193
+
194
+ print(phones, tones, word2ph, bert.shape)
195
+
196
+
197
+ # # 示例用法
198
+ # text = "这是一个示例文本:,你好!这是一个测试...."
199
+ # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
text/chinese_bert.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
5
+
6
+ from config import config
7
+
8
+ #LOCAL_PATH = "./bert/chinese-roberta-wwm-ext-large"
9
+ LOCAL_PATH = 'hfl/chinese-roberta-wwm-ext-large'
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
12
+
13
+ models = dict()
14
+
15
+
16
def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
    """Return phone-level features for ``text`` as a transposed 2-D tensor.

    One model per device is cached in ``models``. Row ``i`` of the selected
    hidden state is repeated ``word2ph[i]`` times; ``word2ph`` must have
    ``len(text) + 2`` entries.
    """
    prefer_mps = (
        sys.platform == "darwin"
        and torch.backends.mps.is_available()
        and device == "cpu"
    )
    if prefer_mps:
        device = "mps"
    if not device:
        device = "cuda"
    if device not in models:
        models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
    model = models[device]

    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for key in inputs:
            inputs[key] = inputs[key].to(device)
        outputs = model(**inputs, output_hidden_states=True)
        # Third-from-last hidden state, first batch item, moved to the CPU.
        hidden = torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()

    assert len(word2ph) == len(text) + 2
    per_phone = [hidden[i].repeat(count, 1) for i, count in enumerate(word2ph)]
    return torch.cat(per_phone, dim=0).T
44
+
45
+
46
if __name__ == "__main__":
    # Demo of the repeat-and-concatenate expansion on random data.
    word_level_feature = torch.rand(38, 1024)  # 38 words, a 1024-dim feature each
    word2phone = [
        1,
        2,
        1,
        2,
        2,
        1,
        2,
        2,
        1,
        2,
        2,
        1,
        2,
        2,
        2,
        2,
        2,
        1,
        1,
        2,
        2,
        1,
        2,
        2,
        2,
        2,
        1,
        2,
        2,
        2,
        2,
        2,
        1,
        2,
        2,
        2,
        2,
        1,
    ]

    # Total number of frames (the 38 counts above sum to 65)
    total_frames = sum(word2phone)
    print(word_level_feature.shape)
    print(word2phone)
    phone_level_feature = []
    for i in range(len(word2phone)):
        print(word_level_feature[i].shape)

        # Repeat each word's feature word2phone[i] times
        repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
        phone_level_feature.append(repeat_feature)

    phone_level_feature = torch.cat(phone_level_feature, dim=0)
    print(phone_level_feature.shape)  # torch.Size([65, 1024])
text/cleaner.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from text import chinese, japanese, english, cleaned_text_to_sequence
2
+
3
+
4
+ language_module_map = {"ZH": chinese, "JP": japanese, "EN": english}
5
+
6
+
7
def clean_text(text, language):
    """Normalize ``text`` and run g2p using the module registered for ``language``.

    Returns ``(norm_text, phones, tones, word2ph)``.
    """
    module = language_module_map[language]
    normalized = module.text_normalize(text)
    phones, tones, word2ph = module.g2p(normalized)
    return normalized, phones, tones, word2ph
12
+
13
+
14
def clean_text_bert(text, language):
    """Normalize ``text``, run g2p and compute the matching BERT feature.

    Returns ``(phones, tones, bert)``.
    """
    module = language_module_map[language]
    normalized = module.text_normalize(text)
    phones, tones, word2ph = module.g2p(normalized)
    return phones, tones, module.get_bert_feature(normalized, word2ph)
20
+
21
+
22
def text_to_sequence(text, language):
    """Clean ``text`` and pass its phones and tones to ``cleaned_text_to_sequence``."""
    _, phones, tones, _ = clean_text(text, language)
    return cleaned_text_to_sequence(phones, tones, language)
25
+
26
+
27
+ if __name__ == "__main__":
28
+ pass
text/cmudict.rep ADDED
The diff for this file is too large to render. See raw diff
 
text/cmudict_cache.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9b21b20325471934ba92f2e4a5976989e7d920caa32e7a286eacb027d197949
3
+ size 6212655
text/english.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import os
3
+ import re
4
+ from g2p_en import G2p
5
+ from transformers import DebertaV2Tokenizer
6
+
7
+ from text import symbols
8
+
9
+ current_file_path = os.path.dirname(__file__)
10
+ CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
11
+ CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
12
+ _g2p = G2p()
13
+ #LOCAL_PATH = "./bert/deberta-v3-large"
14
+ LOCAL_PATH = 'microsoft/deberta-v3-large'
15
+ tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
16
+
17
+ arpa = {
18
+ "AH0",
19
+ "S",
20
+ "AH1",
21
+ "EY2",
22
+ "AE2",
23
+ "EH0",
24
+ "OW2",
25
+ "UH0",
26
+ "NG",
27
+ "B",
28
+ "G",
29
+ "AY0",
30
+ "M",
31
+ "AA0",
32
+ "F",
33
+ "AO0",
34
+ "ER2",
35
+ "UH1",
36
+ "IY1",
37
+ "AH2",
38
+ "DH",
39
+ "IY0",
40
+ "EY1",
41
+ "IH0",
42
+ "K",
43
+ "N",
44
+ "W",
45
+ "IY2",
46
+ "T",
47
+ "AA1",
48
+ "ER1",
49
+ "EH2",
50
+ "OY0",
51
+ "UH2",
52
+ "UW1",
53
+ "Z",
54
+ "AW2",
55
+ "AW1",
56
+ "V",
57
+ "UW2",
58
+ "AA2",
59
+ "ER",
60
+ "AW0",
61
+ "UW0",
62
+ "R",
63
+ "OW1",
64
+ "EH1",
65
+ "ZH",
66
+ "AE0",
67
+ "IH2",
68
+ "IH",
69
+ "Y",
70
+ "JH",
71
+ "P",
72
+ "AY1",
73
+ "EY0",
74
+ "OY2",
75
+ "TH",
76
+ "HH",
77
+ "D",
78
+ "ER0",
79
+ "CH",
80
+ "AO1",
81
+ "AE1",
82
+ "AO2",
83
+ "OY1",
84
+ "AY2",
85
+ "IH1",
86
+ "OW0",
87
+ "L",
88
+ "SH",
89
+ }
90
+
91
+
92
def post_replace_ph(ph):
    """Normalize one phone symbol; return ``"UNK"`` if the result is not in ``symbols``."""
    rep_map = {
        ":": ",",
        ";": ",",
        ",": ",",
        "。": ".",
        "!": "!",
        "?": "?",
        "\n": ".",
        "·": ",",
        "、": ",",
        "…": "...",
        "···": "...",
        "・・・": "...",
        "v": "V",
    }
    ph = rep_map.get(ph, ph)
    return ph if ph in symbols else "UNK"
115
+
116
+
117
+ rep_map = {
118
+ ":": ",",
119
+ ";": ",",
120
+ ",": ",",
121
+ "。": ".",
122
+ "!": "!",
123
+ "?": "?",
124
+ "\n": ".",
125
+ ".": ".",
126
+ "…": "...",
127
+ "···": "...",
128
+ "・・・": "...",
129
+ "·": ",",
130
+ "・": ",",
131
+ "、": ",",
132
+ "$": ".",
133
+ "“": "'",
134
+ "”": "'",
135
+ '"': "'",
136
+ "‘": "'",
137
+ "’": "'",
138
+ "(": "'",
139
+ ")": "'",
140
+ "(": "'",
141
+ ")": "'",
142
+ "《": "'",
143
+ "》": "'",
144
+ "【": "'",
145
+ "】": "'",
146
+ "[": "'",
147
+ "]": "'",
148
+ "—": "-",
149
+ "−": "-",
150
+ "~": "-",
151
+ "~": "-",
152
+ "「": "'",
153
+ "」": "'",
154
+ }
155
+
156
+
157
def replace_punctuation(text):
    """Replace every occurrence of a ``rep_map`` key in ``text`` with its mapped value."""
    mapping_rx = re.compile("|".join(re.escape(key) for key in rep_map))
    return mapping_rx.sub(lambda m: rep_map[m.group()], text)
171
+
172
+
173
def read_dict():
    """Parse the dictionary file at ``CMU_DICT_PATH``.

    Returns a dict mapping each word to a list of syllables, where every
    syllable is a list of phone strings. Lines before ``start_line`` are
    skipped.
    """
    g2p_dict = {}
    # 1-based number of the first line that holds an entry.
    start_line = 49
    with open(CMU_DICT_PATH) as f:
        line = f.readline()
        line_index = 1
        while line:
            if line_index >= start_line:
                line = line.strip()
                # NOTE(review): the separator literals in this function may have
                # had runs of spaces collapsed when this listing was captured;
                # confirm them against the on-disk file and the dictionary format.
                word_split = line.split(" ")
                word = word_split[0]

                # Syllables are separated by " - ", phones inside one by a space.
                syllable_split = word_split[1].split(" - ")
                g2p_dict[word] = []
                for syllable in syllable_split:
                    phone_split = syllable.split(" ")
                    g2p_dict[word].append(phone_split)

            line_index = line_index + 1
            line = f.readline()

    return g2p_dict
195
+
196
+
197
def cache_dict(g2p_dict, file_path):
    """Pickle ``g2p_dict`` to ``file_path``, overwriting any existing file."""
    with open(file_path, "wb") as cache_out:
        pickle.dump(g2p_dict, cache_out)
200
+
201
+
202
def get_dict():
    """Return the pronunciation dictionary, using the pickle cache when it exists.

    If ``CACHE_PATH`` is missing the dictionary is parsed with ``read_dict``
    and written to the cache before being returned.
    """
    if not os.path.exists(CACHE_PATH):
        g2p_dict = read_dict()
        cache_dict(g2p_dict, CACHE_PATH)
        return g2p_dict

    # NOTE: unpickling runs code embedded in the file; the cache must be trusted.
    with open(CACHE_PATH, "rb") as cached:
        return pickle.load(cached)
211
+
212
+
213
+ eng_dict = get_dict()
214
+
215
+
216
def refine_ph(phn):
    """Split a phone into ``(lower-cased phone, tone)``.

    A trailing digit is removed and becomes the tone, shifted by one; a phone
    without a trailing digit gets tone ``0``.
    """
    if re.search(r"\d$", phn) is None:
        return phn.lower(), 0
    stress = int(phn[-1])
    return phn[:-1].lower(), stress + 1
222
+
223
+
224
def refine_syllables(syllables):
    """Flatten ``syllables`` into parallel ``(phonemes, tones)`` lists via ``refine_ph``."""
    phonemes, tones = [], []
    for syllable in syllables:
        for raw_phone in syllable:
            phone, tone = refine_ph(raw_phone)
            phonemes.append(phone)
            tones.append(tone)
    return phonemes, tones
234
+
235
+
236
+ import re
237
+ import inflect
238
+
239
+ _inflect = inflect.engine()
240
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
241
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
242
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
243
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
244
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
245
+ _number_re = re.compile(r"[0-9]+")
246
+
247
+ # List of (regular expression, replacement) pairs for abbreviations:
248
+ _abbreviations = [
249
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
250
+ for x in [
251
+ ("mrs", "misess"),
252
+ ("mr", "mister"),
253
+ ("dr", "doctor"),
254
+ ("st", "saint"),
255
+ ("co", "company"),
256
+ ("jr", "junior"),
257
+ ("maj", "major"),
258
+ ("gen", "general"),
259
+ ("drs", "doctors"),
260
+ ("rev", "reverend"),
261
+ ("lt", "lieutenant"),
262
+ ("hon", "honorable"),
263
+ ("sgt", "sergeant"),
264
+ ("capt", "captain"),
265
+ ("esq", "esquire"),
266
+ ("ltd", "limited"),
267
+ ("col", "colonel"),
268
+ ("ft", "fort"),
269
+ ]
270
+ ]
271
+
272
+
273
+ # List of (ipa, lazy ipa) pairs:
274
+ _lazy_ipa = [
275
+ (re.compile("%s" % x[0]), x[1])
276
+ for x in [
277
+ ("r", "ɹ"),
278
+ ("æ", "e"),
279
+ ("ɑ", "a"),
280
+ ("ɔ", "o"),
281
+ ("ð", "z"),
282
+ ("θ", "s"),
283
+ ("ɛ", "e"),
284
+ ("ɪ", "i"),
285
+ ("ʊ", "u"),
286
+ ("ʒ", "ʥ"),
287
+ ("ʤ", "ʥ"),
288
+ ("ˈ", "↓"),
289
+ ]
290
+ ]
291
+
292
+ # List of (ipa, lazy ipa2) pairs:
293
+ _lazy_ipa2 = [
294
+ (re.compile("%s" % x[0]), x[1])
295
+ for x in [
296
+ ("r", "ɹ"),
297
+ ("ð", "z"),
298
+ ("θ", "s"),
299
+ ("ʒ", "ʑ"),
300
+ ("ʤ", "dʑ"),
301
+ ("ˈ", "↓"),
302
+ ]
303
+ ]
304
+
305
+ # List of (ipa, ipa2) pairs
306
+ _ipa_to_ipa2 = [
307
+ (re.compile("%s" % x[0]), x[1]) for x in [("r", "ɹ"), ("ʤ", "dʒ"), ("ʧ", "tʃ")]
308
+ ]
309
+
310
+
311
def _expand_dollars(m):
    """Rewrite a dollar amount (group 1 of the match) as "<n> dollars, <n> cents".

    Args:
        m: regex match whose group 1 holds the amount without the "$" sign.

    Returns:
        The amount with its units spelled out. Amounts with more than one
        "." are returned unchanged with " dollars" appended, and a zero
        amount yields "zero dollars".
    """
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"  # Unexpected format
    dollars = int(parts[0]) if parts[0] else 0
    cents = 0
    if len(parts) > 1 and parts[1]:
        # A single fractional digit means tenths of a dollar: "$3.5" is
        # 50 cents, not 5. Longer fractions are read as before.
        cents = int(parts[1].ljust(2, "0"))
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return "%s %s" % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s" % (cents, cent_unit)
    else:
        return "zero dollars"
330
+
331
+
332
def _remove_commas(m):
    """Return group 1 of the match with every comma removed."""
    return "".join(m.group(1).split(","))
334
+
335
+
336
def _expand_ordinal(m):
    """Spell out the matched ordinal (the whole match) using the ``inflect`` engine."""
    ordinal = m.group(0)
    return _inflect.number_to_words(ordinal)
338
+
339
+
340
def _expand_number(m):
    """Spell out the matched integer; 1001..2999 get dedicated handling."""
    num = int(m.group(0))
    if not (1000 < num < 3000):
        return _inflect.number_to_words(num, andword="")
    if num == 2000:
        return "two thousand"
    if 2000 < num < 2010:
        return "two thousand " + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + " hundred"
    # Remaining values are read in two-digit groups.
    grouped = _inflect.number_to_words(num, andword="", zero="oh", group=2)
    return grouped.replace(", ", " ")
355
+
356
+
357
def _expand_decimal_point(m):
    """Return group 1 of the match with each "." replaced by " point "."""
    return " point ".join(m.group(1).split("."))
359
+
360
+
361
def normalize_numbers(text):
    """Apply the number rewrites to ``text`` in a fixed order.

    Order: strip thousands separators, pounds, dollars, decimals, ordinals,
    then remaining plain integers.
    """
    steps = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r"\1 pounds"),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in steps:
        text = re.sub(pattern, replacement, text)
    return text
369
+
370
+
371
def text_normalize(text):
    """Spell out numbers, normalize punctuation, and space out glued punctuation."""
    mapped = replace_punctuation(normalize_numbers(text))
    # Insert a space when a word character directly follows one of , ; . ? !
    return re.sub(r"([,;.\?\!])([\w])", r"\1 \2", mapped)
376
+
377
+
378
def distribute_phone(n_phone, n_word):
    """Spread ``n_phone`` phones over ``n_word`` slots as evenly as possible.

    Each phone goes to the first slot that currently holds the fewest, so
    earlier slots receive the surplus.
    """
    counts = [0] * n_word
    for _ in range(n_phone):
        target = counts.index(min(counts))
        counts[target] += 1
    return counts
385
+
386
+
387
def sep_text(text):
    """Split ``text`` on punctuation, whitespace and "+", keeping non-blank pieces.

    Separators are kept as their own items unless they are whitespace.
    """
    pieces = re.split(r"([,;.\?\!\s+])", text)
    return [piece for piece in pieces if piece.strip() != ""]
391
+
392
+
393
def g2p(text):
    """Convert normalized English text into ``(phones, tones, word2ph)``.

    Words found in ``eng_dict`` (upper-cased lookup) use the dictionary
    pronunciation; all others go through the ``_g2p`` model. ``word2ph``
    spreads each word's phones over that word's tokenizer tokens. All three
    lists are padded with one boundary entry at each end.
    """
    words = sep_text(text)
    tokens = [tokenizer.tokenize(word) for word in words]

    phones = []
    tones = []
    for word in words:
        key = word.upper()
        if key in eng_dict:
            phns, tns = refine_syllables(eng_dict[key])
        else:
            phns, tns = [], []
            for ph in _g2p(word):
                if ph == " ":
                    continue
                if ph in arpa:
                    ph, tn = refine_ph(ph)
                else:
                    # Not an ARPAbet phone: keep it as is, with tone 0.
                    tn = 0
                phns.append(ph)
                tns.append(tn)
        phones.append([post_replace_ph(ph) for ph in phns])
        tones.append(tns)

    word2ph = []
    for token, phoneme in zip(tokens, phones):
        word2ph += distribute_phone(len(phoneme), len(token))

    phones = ["_"] + [ph for group in phones for ph in group] + ["_"]
    tones = [0] + [tn for group in tones for tn in group] + [0]
    word2ph = [1] + word2ph + [1]
    assert len(phones) == len(tones), text
    assert len(phones) == sum(word2ph), text

    return phones, tones, word2ph
437
+
438
+
439
def get_bert_feature(text, word2ph):
    """Delegate to ``text.english_bert_mock.get_bert_feature`` (imported on call)."""
    from text.english_bert_mock import get_bert_feature as _extract

    return _extract(text, word2ph)
443
+
444
+
445
+ if __name__ == "__main__":
446
+ # print(get_dict())
447
+ # print(eng_word_to_phoneme("hello"))
448
+ print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder."))
449
+ # all_phones = set()
450
+ # for k, syllables in eng_dict.items():
451
+ # for group in syllables:
452
+ # for ph in group:
453
+ # all_phones.add(ph)
454
+ # print(all_phones)
text/english_bert_mock.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import DebertaV2Model, DebertaV2Tokenizer
5
+
6
+ from config import config
7
+
8
+
9
# LOCAL_PATH = "./bert/deberta-v3-large"
# Keep this checkpoint in step with the tokenizer that text/english.py uses to
# build word2ph ("microsoft/deberta-v3-large"): get_bert_feature asserts that
# the token count here equals len(word2ph).
LOCAL_PATH = "microsoft/deberta-v3-large"
11
+
12
+ tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
13
+
14
+ models = dict()
15
+
16
+
17
def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
    """Return phone-level features for ``text`` as a transposed 2-D tensor.

    One model per device is cached in ``models``. Row ``i`` of the selected
    hidden state is repeated ``word2ph[i]`` times; ``word2ph`` must have one
    entry per row of that hidden state.
    """
    prefer_mps = (
        sys.platform == "darwin"
        and torch.backends.mps.is_available()
        and device == "cpu"
    )
    if prefer_mps:
        device = "mps"
    if not device:
        device = "cuda"
    if device not in models:
        models[device] = DebertaV2Model.from_pretrained(LOCAL_PATH).to(device)
    model = models[device]

    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for key in inputs:
            inputs[key] = inputs[key].to(device)
        outputs = model(**inputs, output_hidden_states=True)
        # Third-from-last hidden state, first batch item, moved to the CPU.
        hidden = torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()

    assert len(word2ph) == hidden.shape[0], (text, hidden.shape[0], len(word2ph))
    per_phone = [hidden[i].repeat(count, 1) for i, count in enumerate(word2ph)]
    return torch.cat(per_phone, dim=0).T
text/japanese.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Convert Japanese text to phonemes which is
2
+ # compatible with Julius https://github.com/julius-speech/segmentation-kit
3
+ import re
4
+ import unicodedata
5
+
6
+ from transformers import AutoTokenizer
7
+
8
+ from text import punctuation, symbols
9
+
10
+ from num2words import num2words
11
+
12
+ import pyopenjtalk
13
+ import jaconv
14
+
15
+
16
def kata2phoneme(text: str) -> list:
    """Convert katakana text to a list of phonemes.

    Leading long-vowel marks ("ー") are emitted as-is. Leading characters
    matched by ``_MARKS`` are emitted one at a time. The first remaining
    stretch of text is converted with ``pyopenjtalk.g2p`` (lower-cased, with
    "cl" rewritten to "q") and ends the conversion.
    """
    text = text.strip()
    if text == "ー":
        return ["ー"]
    elif text.startswith("ー"):
        return ["ー"] + kata2phoneme(text[1:])
    res = []
    while text:
        if re.match(_MARKS, text):
            # Emit only the matched leading character; appending the whole
            # remaining text would duplicate the characters that follow it.
            res.append(text[0])
            text = text[1:]
            continue
        if text.startswith("ー"):
            # A long-vowel mark reached here (after a mark) has no preceding
            # phoneme to lengthen, so it is dropped.
            text = text[1:]
            continue
        res += pyopenjtalk.g2p(text).lower().replace("cl", "q").split(" ")
        break
    return res
39
+
40
+
41
def hira2kata(text: str) -> str:
    """Return ``text`` converted by ``jaconv.hira2kata``."""
    converted = jaconv.hira2kata(text)
    return converted
43
+
44
+
45
+ _SYMBOL_TOKENS = set(list("・、。?!"))
46
+ _NO_YOMI_TOKENS = set(list("「」『』―()[][]"))
47
+ _MARKS = re.compile(
48
+ r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
49
+ )
50
+
51
+
52
def text2kata(text: str) -> str:
    """Convert ``text`` to a single katakana string using the OpenJTalk frontend.

    For every parsed part the pronunciation (``"pron"``) is used when it is
    non-empty. If that pronunciation starts with a character matched by
    ``_MARKS``, the punctuation-normalized surface form is used instead, or
    "," when that form is not a known punctuation symbol. Parts without a
    pronunciation fall back to the surface form, except tokens listed in
    ``_NO_YOMI_TOKENS``, which are dropped.
    """
    parsed = pyopenjtalk.run_frontend(text)

    res = []
    for parts in parsed:
        word, yomi = replace_punctuation(parts["string"]), parts["pron"].replace(
            "’", ""
        )
        if yomi:
            if re.match(_MARKS, yomi):
                if len(word) > 1:
                    # Several symbols in one part: normalize and emit each one.
                    # (No separate word list exists in this function, so only
                    # ``res`` is extended.)
                    word = [replace_punctuation(i) for i in list(word)]
                    yomi = word
                    res += yomi
                    continue
                elif word not in rep_map.keys() and word not in rep_map.values():
                    word = ","
                yomi = word
            res.append(yomi)
        else:
            if word in _SYMBOL_TOKENS:
                res.append(word)
            elif word in ("っ", "ッ"):
                res.append("ッ")
            elif word in _NO_YOMI_TOKENS:
                pass
            else:
                res.append(word)
    return hira2kata("".join(res))
82
+
83
+
84
def text2sep_kata(text: str) -> (list, list, list):
    """Split ``text`` into words with their katakana readings and accent data.

    Returns ``(sep, kata, accents)``: the surface words, the reading of each
    emitted item passed through ``hira2kata``, and the result of
    ``get_accent`` for the parsed text.
    """
    parsed = pyopenjtalk.run_frontend(text)

    res = []
    sep = []
    for parts in parsed:
        word, yomi = replace_punctuation(parts["string"]), parts["pron"].replace(
            "’", ""
        )
        if yomi:
            if re.match(_MARKS, yomi):
                # The pronunciation starts with a mark: use the surface form.
                if len(word) > 1:
                    # Several symbols in one part: emit each one separately.
                    word = [replace_punctuation(i) for i in list(word)]
                    yomi = word
                    res += yomi
                    sep += word
                    continue
                elif word not in rep_map.keys() and word not in rep_map.values():
                    # Not a known punctuation symbol: replace it with a comma.
                    word = ","
                yomi = word
            res.append(yomi)
        else:
            if word in _SYMBOL_TOKENS:
                res.append(word)
            elif word in ("っ", "ッ"):
                res.append("ッ")
            elif word in _NO_YOMI_TOKENS:
                pass
            else:
                res.append(word)
        sep.append(word)
    return sep, [hira2kata(i) for i in res], get_accent(parsed)
116
+
117
+
118
def get_accent(parsed):
    """Return ``(phoneme, accent)`` pairs derived from OpenJTalk labels.

    "sil" and "pau" labels are skipped. The accent is ``-1`` (falling), ``1``
    (rising) or ``0``, decided from the A1/A2 label fields of the current
    label and the A2 field of the next one.
    """
    labels = pyopenjtalk.make_label(parsed)
    silences = ["sil", "pau"]

    phonemes = []
    accents = []
    for n, label in enumerate(labels):
        phoneme = re.search(r"\-([^\+]*)\+", label).group(1)
        if phoneme in silences:
            continue
        phonemes.append(phoneme.replace("cl", "q").lower())

        a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
        a2 = int(re.search(r"\+(\d+)\+", label).group(1))
        following = labels[n + 1]
        if re.search(r"\-([^\+]*)\+", following).group(1) in silences:
            a2_next = -1
        else:
            a2_next = int(re.search(r"\+(\d+)\+", following).group(1))

        if a1 == 0 and a2_next == a2 + 1:
            accent = -1  # Falling
        elif a2 == 1 and a2_next == 2:
            accent = 1  # Rising
        else:
            accent = 0
        accents.append(accent)
    return list(zip(phonemes, accents))
144
+
145
+
146
+ _ALPHASYMBOL_YOMI = {
147
+ "#": "シャープ",
148
+ "%": "パーセント",
149
+ "&": "アンド",
150
+ "+": "プラス",
151
+ "-": "マイナス",
152
+ ":": "コロン",
153
+ ";": "セミコロン",
154
+ "<": "小なり",
155
+ "=": "イコール",
156
+ ">": "大なり",
157
+ "@": "アット",
158
+ "a": "エー",
159
+ "b": "ビー",
160
+ "c": "シー",
161
+ "d": "ディー",
162
+ "e": "イー",
163
+ "f": "エフ",
164
+ "g": "ジー",
165
+ "h": "エイチ",
166
+ "i": "アイ",
167
+ "j": "ジェー",
168
+ "k": "ケー",
169
+ "l": "エル",
170
+ "m": "エム",
171
+ "n": "エヌ",
172
+ "o": "オー",
173
+ "p": "ピー",
174
+ "q": "キュー",
175
+ "r": "アール",
176
+ "s": "エス",
177
+ "t": "ティー",
178
+ "u": "ユー",
179
+ "v": "ブイ",
180
+ "w": "ダブリュー",
181
+ "x": "エックス",
182
+ "y": "ワイ",
183
+ "z": "ゼット",
184
+ "α": "アルファ",
185
+ "β": "ベータ",
186
+ "γ": "ガンマ",
187
+ "δ": "デルタ",
188
+ "ε": "イプシロン",
189
+ "ζ": "ゼータ",
190
+ "η": "イータ",
191
+ "θ": "シータ",
192
+ "ι": "イオタ",
193
+ "κ": "カッパ",
194
+ "λ": "ラムダ",
195
+ "μ": "ミュー",
196
+ "ν": "ニュー",
197
+ "ξ": "クサイ",
198
+ "ο": "オミクロン",
199
+ "π": "パイ",
200
+ "ρ": "ロー",
201
+ "σ": "シグマ",
202
+ "τ": "タウ",
203
+ "υ": "ウプシロン",
204
+ "φ": "ファイ",
205
+ "χ": "カイ",
206
+ "ψ": "プサイ",
207
+ "ω": "オメガ",
208
+ }
209
+
210
+
211
+ _NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+")
212
+ _CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"}
213
+ _CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])")
214
+ _NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?")
215
+
216
+
217
def japanese_convert_numbers_to_words(text: str) -> str:
    """Rewrite digit sequences in ``text`` as Japanese words.

    Thousands separators are removed, a currency symbol in front of a number
    becomes a currency word after it, and the remaining numbers are spelled
    out with ``num2words``.
    """

    def _strip_separators(m):
        return m[0].replace(",", "")

    def _currency_to_suffix(m):
        return m[2] + _CURRENCY_MAP.get(m[1], m[1])

    def _spell_out(m):
        return num2words(m[0], lang="ja")

    res = _NUMBER_WITH_SEPARATOR_RX.sub(_strip_separators, text)
    res = _CURRENCY_RX.sub(_currency_to_suffix, res)
    return _NUMBER_RX.sub(_spell_out, res)
222
+
223
+
224
def japanese_convert_alpha_symbols_to_words(text: str) -> str:
    """Replace each lower-cased character that has an entry in ``_ALPHASYMBOL_YOMI``."""
    lowered = text.lower()
    return "".join(_ALPHASYMBOL_YOMI.get(ch, ch) for ch in lowered)
226
+
227
+
228
def japanese_text_to_phonemes(text: str) -> str:
    """Convert Japanese text to phonemes.

    Steps: NFKC normalization, number spelling, katakana conversion, then
    ``kata2phoneme``.
    """
    normalized = unicodedata.normalize("NFKC", text)
    spelled = japanese_convert_numbers_to_words(normalized)
    return kata2phoneme(text2kata(spelled))
236
+
237
+
238
def is_japanese_character(char):
    """Return True when ``char`` falls in one of the Japanese script ranges below."""
    # Unicode ranges of the Japanese writing system
    japanese_ranges = (
        (0x3040, 0x309F),  # Hiragana
        (0x30A0, 0x30FF),  # Katakana
        (0x4E00, 0x9FFF),  # Kanji (CJK Unified Ideographs)
        (0x3400, 0x4DBF),  # Kanji extension A
        (0x20000, 0x2A6DF),  # Kanji extension B
        # Further kanji extension ranges can be added here if needed
    )
    code_point = ord(char)
    return any(low <= code_point <= high for low, high in japanese_ranges)
258
+
259
+
260
+ rep_map = {
261
+ ":": ",",
262
+ ";": ",",
263
+ ",": ",",
264
+ "。": ".",
265
+ "!": "!",
266
+ "?": "?",
267
+ "\n": ".",
268
+ ".": ".",
269
+ "…": "...",
270
+ "···": "...",
271
+ "・・・": "...",
272
+ "·": ",",
273
+ "・": ",",
274
+ "、": ",",
275
+ "$": ".",
276
+ "“": "'",
277
+ "”": "'",
278
+ '"': "'",
279
+ "‘": "'",
280
+ "’": "'",
281
+ "(": "'",
282
+ ")": "'",
283
+ "(": "'",
284
+ ")": "'",
285
+ "《": "'",
286
+ "》": "'",
287
+ "【": "'",
288
+ "】": "'",
289
+ "[": "'",
290
+ "]": "'",
291
+ "—": "-",
292
+ "−": "-",
293
+ "~": "-",
294
+ "~": "-",
295
+ "「": "'",
296
+ "」": "'",
297
+ }
298
+
299
+
300
def replace_punctuation(text):
    """Map punctuation through ``rep_map`` and strip every other foreign character.

    After the substitution, any character outside the kana/kanji ranges listed
    in the pattern and outside ``punctuation`` is removed.
    """
    mapping_rx = re.compile("|".join(re.escape(key) for key in rep_map))
    mapped = mapping_rx.sub(lambda m: rep_map[m.group()], text)

    disallowed = (
        r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
        + "".join(punctuation)
        + r"]+"
    )
    return re.sub(disallowed, "", mapped)
314
+
315
+
316
def text_normalize(text):
    """NFKC-normalize ``text``, spell out numbers and normalize punctuation.

    The final step removes every occurrence of the "゙" character.
    """
    normalized = unicodedata.normalize("NFKC", text)
    spelled = japanese_convert_numbers_to_words(normalized)
    cleaned = replace_punctuation(spelled)
    return cleaned.replace("゙", "")
323
+
324
+
325
def distribute_phone(n_phone, n_word):
    """Spread ``n_phone`` phones over ``n_word`` slots as evenly as possible.

    Each phone goes to the first slot that currently holds the fewest, so
    earlier slots receive the surplus.
    """
    counts = [0] * n_word
    for _ in range(n_phone):
        target = counts.index(min(counts))
        counts[target] += 1
    return counts
332
+
333
+
334
def handle_long(sep_phonemes):
    """Replace long-vowel marks ("ー") in place and return the same list.

    A mark at the start of a group takes the last phoneme of the previous
    group; a mark elsewhere takes the last character of the phoneme before it.
    """
    for idx, group in enumerate(sep_phonemes):
        if group[0] == "ー":
            group[0] = sep_phonemes[idx - 1][-1]
        if "ー" in group:
            for pos, phoneme in enumerate(group):
                if phoneme == "ー":
                    group[pos] = group[pos - 1][-1]
    return sep_phonemes
343
+
344
+
345
+ #tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese-char-wwm")
346
+ tokenizer = AutoTokenizer.from_pretrained("ku-nlp/deberta-v2-large-japanese-char-wwm")
347
+
348
def align_tones(phones, tones):
    """Turn ``(phoneme, accent)`` pairs into one 0/1 tone per phoneme.

    ``tones`` is consumed from the front (the caller's list is mutated) as its
    phonemes are matched against each group in ``phones``. Within a group the
    accents are accumulated, shifted right by one position, and raised by one
    if the group contains ``-1``. The flattened result must contain only 0
    and 1.
    """
    flat = []
    for group in phones:
        marks = [0] * len(group)
        for pos, phoneme in enumerate(group):
            if not tones:
                break
            head = tones[0]
            if phoneme != head[0]:
                continue
            marks[pos] = head[1]
            if pos > 0:
                marks[pos] += marks[pos - 1]
            tones.pop(0)
        # Shift right by one: prepend a zero and drop the last value.
        marks = ([0] + marks)[:-1]
        if -1 in marks:
            marks = [value + 1 for value in marks]
        flat.extend(marks)
    assert not any([i < 0 for i in flat]) and not any([i > 1 for i in flat])
    return flat
368
+
369
+
370
def rearrange_tones(tones, phones):
    """Recode a tone sequence into the values 0-3 based on changes between neighbours.

    Each position is compared with the previous tone: equal -> 1 (0 if the
    phone is punctuation), higher -> 2, lower -> the previous position
    becomes 3 and the current one 1.
    """
    res = [0] * len(tones)
    for i in range(len(tones)):
        if i == 0:
            # NOTE(review): this tests tones[i] against punctuation while the
            # branch below tests phones[i] — confirm which one is intended.
            if tones[i] not in punctuation:
                res[i] = 1
        elif tones[i] == prev:
            if phones[i] in punctuation:
                res[i] = 0
            else:
                res[i] = 1
        elif tones[i] > prev:
            res[i] = 2
        elif tones[i] < prev:
            # A drop is marked on the previous position as well.
            res[i - 1] = 3
            res[i] = 1
        # Remember this tone for the next iteration's comparison.
        prev = tones[i]
    return res
388
+
389
+
390
def g2p(norm_text):
    """Convert normalized Japanese text into ``(phones, tones, word2ph)``.

    ``word2ph`` spreads each word's phonemes over that word's tokenizer tokens
    (punctuation counts as a single token). All three lists are padded with
    one boundary entry at each end.
    """
    sep_text, sep_kata, acc = text2sep_kata(norm_text)
    sep_tokenized = [
        [word] if word in punctuation else tokenizer.tokenize(word)
        for word in sep_text
    ]

    sep_phonemes = handle_long([kata2phoneme(kata) for kata in sep_kata])
    # Sanity check: words MeCab does not recognise propagate all the way here
    # and blow up. So far this seems to happen only for extremely rare words.
    for group in sep_phonemes:
        for phoneme in group:
            assert phoneme in symbols, (sep_text, sep_kata, sep_phonemes)
    tones = align_tones(sep_phonemes, acc)

    word2ph = []
    for token, phoneme in zip(sep_tokenized, sep_phonemes):
        word2ph += distribute_phone(len(phoneme), len(token))

    phones = ["_"] + [ph for group in sep_phonemes for ph in group] + ["_"]
    tones = [0] + tones + [0]
    word2ph = [1] + word2ph + [1]
    assert len(phones) == len(tones)
    return phones, tones, word2ph
419
+
420
+
421
+ if __name__ == "__main__":
422
+ tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese")
423
+ text = "hello,こんにちは、世界ー!……"
424
+ from text.japanese_bert import get_bert_feature
425
+
426
+ text = text_normalize(text)
427
+ print(text)
428
+
429
+ phones, tones, word2ph = g2p(text)
430
+ bert = get_bert_feature(text, word2ph)
431
+
432
+ print(phones, tones, word2ph, bert.shape)
text/japanese_bert.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
5
+
6
+ from config import config
7
+ from text.japanese import text2sep_kata
8
+
9
# Alternative: load the model from a local checkout instead of by model id.
#LOCAL_PATH = "./bert/deberta-v2-large-japanese-char-wwm"
LOCAL_PATH = 'ku-nlp/deberta-v2-large-japanese-char-wwm'

# Tokenizer is created once at import time and shared by all calls.
tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)

# Cache of loaded models, keyed by device string (filled by get_bert_feature).
models = dict()
15
+
16
+
17
def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
    """Compute phone-level BERT features for a Japanese text.

    The text is first rebuilt from the first element returned by
    ``text2sep_kata``; a hidden state of the masked-LM model is then repeated
    per token according to ``word2ph``.

    Args:
        text: input text.
        word2ph: number of phones per token; must contain exactly
            ``len(text) + 2`` entries (after the text has been rebuilt).
        device: device string; ``"cpu"`` is switched to ``"mps"`` on macOS when
            MPS is available, and a falsy value falls back to ``"cuda"``.

    Returns:
        The transposed concatenation of the repeated per-token features.

    Raises:
        AssertionError: if ``word2ph`` does not match the text length.
    """
    text = "".join(text2sep_kata(text)[0])

    use_mps = (
        sys.platform == "darwin"
        and torch.backends.mps.is_available()
        and device == "cpu"
    )
    if use_mps:
        device = "mps"
    if not device:
        device = "cuda"

    # Load the model lazily, once per device.
    if device not in models:
        models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)

    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        for key in inputs:
            inputs[key] = inputs[key].to(device)
        outputs = models[device](**inputs, output_hidden_states=True)
        # Take the third-from-last hidden state of the first batch item.
        res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0].cpu()

    assert len(word2ph) == len(text) + 2

    repeated = [res[idx].repeat(count, 1) for idx, count in enumerate(word2ph)]
    return torch.cat(repeated, dim=0).T
text/opencpop-strict.txt ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ a AA a
2
+ ai AA ai
3
+ an AA an
4
+ ang AA ang
5
+ ao AA ao
6
+ ba b a
7
+ bai b ai
8
+ ban b an
9
+ bang b ang
10
+ bao b ao
11
+ bei b ei
12
+ ben b en
13
+ beng b eng
14
+ bi b i
15
+ bian b ian
16
+ biao b iao
17
+ bie b ie
18
+ bin b in
19
+ bing b ing
20
+ bo b o
21
+ bu b u
22
+ ca c a
23
+ cai c ai
24
+ can c an
25
+ cang c ang
26
+ cao c ao
27
+ ce c e
28
+ cei c ei
29
+ cen c en
30
+ ceng c eng
31
+ cha ch a
32
+ chai ch ai
33
+ chan ch an
34
+ chang ch ang
35
+ chao ch ao
36
+ che ch e
37
+ chen ch en
38
+ cheng ch eng
39
+ chi ch ir
40
+ chong ch ong
41
+ chou ch ou
42
+ chu ch u
43
+ chua ch ua
44
+ chuai ch uai
45
+ chuan ch uan
46
+ chuang ch uang
47
+ chui ch ui
48
+ chun ch un
49
+ chuo ch uo
50
+ ci c i0
51
+ cong c ong
52
+ cou c ou
53
+ cu c u
54
+ cuan c uan
55
+ cui c ui
56
+ cun c un
57
+ cuo c uo
58
+ da d a
59
+ dai d ai
60
+ dan d an
61
+ dang d ang
62
+ dao d ao
63
+ de d e
64
+ dei d ei
65
+ den d en
66
+ deng d eng
67
+ di d i
68
+ dia d ia
69
+ dian d ian
70
+ diao d iao
71
+ die d ie
72
+ ding d ing
73
+ diu d iu
74
+ dong d ong
75
+ dou d ou
76
+ du d u
77
+ duan d uan
78
+ dui d ui
79
+ dun d un
80
+ duo d uo
81
+ e EE e
82
+ ei EE ei
83
+ en EE en
84
+ eng EE eng
85
+ er EE er
86
+ fa f a
87
+ fan f an
88
+ fang f ang
89
+ fei f ei
90
+ fen f en
91
+ feng f eng
92
+ fo f o
93
+ fou f ou
94
+ fu f u
95
+ ga g a
96
+ gai g ai
97
+ gan g an
98
+ gang g ang
99
+ gao g ao
100
+ ge g e
101
+ gei g ei
102
+ gen g en
103
+ geng g eng
104
+ gong g ong
105
+ gou g ou
106
+ gu g u
107
+ gua g ua
108
+ guai g uai
109
+ guan g uan
110
+ guang g uang
111
+ gui g ui
112
+ gun g un
113
+ guo g uo
114
+ ha h a
115
+ hai h ai
116
+ han h an
117
+ hang h ang
118
+ hao h ao
119
+ he h e
120
+ hei h ei
121
+ hen h en
122
+ heng h eng
123
+ hong h ong
124
+ hou h ou
125
+ hu h u
126
+ hua h ua
127
+ huai h uai
128
+ huan h uan
129
+ huang h uang
130
+ hui h ui
131
+ hun h un
132
+ huo h uo
133
+ ji j i
134
+ jia j ia
135
+ jian j ian
136
+ jiang j iang
137
+ jiao j iao
138
+ jie j ie
139
+ jin j in
140
+ jing j ing
141
+ jiong j iong
142
+ jiu j iu
143
+ ju j v
144
+ jv j v
145
+ juan j van
146
+ jvan j van
147
+ jue j ve
148
+ jve j ve
149
+ jun j vn
150
+ jvn j vn
151
+ ka k a
152
+ kai k ai
153
+ kan k an
154
+ kang k ang
155
+ kao k ao
156
+ ke k e
157
+ kei k ei
158
+ ken k en
159
+ keng k eng
160
+ kong k ong
161
+ kou k ou
162
+ ku k u
163
+ kua k ua
164
+ kuai k uai
165
+ kuan k uan
166
+ kuang k uang
167
+ kui k ui
168
+ kun k un
169
+ kuo k uo
170
+ la l a
171
+ lai l ai
172
+ lan l an
173
+ lang l ang
174
+ lao l ao
175
+ le l e
176
+ lei l ei
177
+ leng l eng
178
+ li l i
179
+ lia l ia
180
+ lian l ian
181
+ liang l iang
182
+ liao l iao
183
+ lie l ie
184
+ lin l in
185
+ ling l ing
186
+ liu l iu
187
+ lo l o
188
+ long l ong
189
+ lou l ou
190
+ lu l u
191
+ luan l uan
192
+ lun l un
193
+ luo l uo
194
+ lv l v
195
+ lve l ve
196
+ ma m a
197
+ mai m ai
198
+ man m an
199
+ mang m ang
200
+ mao m ao
201
+ me m e
202
+ mei m ei
203
+ men m en
204
+ meng m eng
205
+ mi m i
206
+ mian m ian
207
+ miao m iao
208
+ mie m ie
209
+ min m in
210
+ ming m ing
211
+ miu m iu
212
+ mo m o
213
+ mou m ou
214
+ mu m u
215
+ na n a
216
+ nai n ai
217
+ nan n an
218
+ nang n ang
219
+ nao n ao
220
+ ne n e
221
+ nei n ei
222
+ nen n en
223
+ neng n eng
224
+ ni n i
225
+ nian n ian
226
+ niang n iang
227
+ niao n iao
228
+ nie n ie
229
+ nin n in
230
+ ning n ing
231
+ niu n iu
232
+ nong n ong
233
+ nou n ou
234
+ nu n u
235
+ nuan n uan
236
+ nun n un
237
+ nuo n uo
238
+ nv n v
239
+ nve n ve
240
+ o OO o
241
+ ou OO ou
242
+ pa p a
243
+ pai p ai
244
+ pan p an
245
+ pang p ang
246
+ pao p ao
247
+ pei p ei
248
+ pen p en
249
+ peng p eng
250
+ pi p i
251
+ pian p ian
252
+ piao p iao
253
+ pie p ie
254
+ pin p in
255
+ ping p ing
256
+ po p o
257
+ pou p ou
258
+ pu p u
259
+ qi q i
260
+ qia q ia
261
+ qian q ian
262
+ qiang q iang
263
+ qiao q iao
264
+ qie q ie
265
+ qin q in
266
+ qing q ing
267
+ qiong q iong
268
+ qiu q iu
269
+ qu q v
270
+ qv q v
271
+ quan q van
272
+ qvan q van
273
+ que q ve
274
+ qve q ve
275
+ qun q vn
276
+ qvn q vn
277
+ ran r an
278
+ rang r ang
279
+ rao r ao
280
+ re r e
281
+ ren r en
282
+ reng r eng
283
+ ri r ir
284
+ rong r ong
285
+ rou r ou
286
+ ru r u
287
+ rua r ua
288
+ ruan r uan
289
+ rui r ui
290
+ run r un
291
+ ruo r uo
292
+ sa s a
293
+ sai s ai
294
+ san s an
295
+ sang s ang
296
+ sao s ao
297
+ se s e
298
+ sen s en
299
+ seng s eng
300
+ sha sh a
301
+ shai sh ai
302
+ shan sh an
303
+ shang sh ang
304
+ shao sh ao
305
+ she sh e
306
+ shei sh ei
307
+ shen sh en
308
+ sheng sh eng
309
+ shi sh ir
310
+ shou sh ou
311
+ shu sh u
312
+ shua sh ua
313
+ shuai sh uai
314
+ shuan sh uan
315
+ shuang sh uang
316
+ shui sh ui
317
+ shun sh un
318
+ shuo sh uo
319
+ si s i0
320
+ song s ong
321
+ sou s ou
322
+ su s u
323
+ suan s uan
324
+ sui s ui
325
+ sun s un
326
+ suo s uo
327
+ ta t a
328
+ tai t ai
329
+ tan t an
330
+ tang t ang
331
+ tao t ao
332
+ te t e
333
+ tei t ei
334
+ teng t eng
335
+ ti t i
336
+ tian t ian
337
+ tiao t iao
338
+ tie t ie
339
+ ting t ing
340
+ tong t ong
341
+ tou t ou
342
+ tu t u
343
+ tuan t uan
344
+ tui t ui
345
+ tun t un
346
+ tuo t uo
347
+ wa w a
348
+ wai w ai
349
+ wan w an
350
+ wang w ang
351
+ wei w ei
352
+ wen w en
353
+ weng w eng
354
+ wo w o
355
+ wu w u
356
+ xi x i
357
+ xia x ia
358
+ xian x ian
359
+ xiang x iang
360
+ xiao x iao
361
+ xie x ie
362
+ xin x in
363
+ xing x ing
364
+ xiong x iong
365
+ xiu x iu
366
+ xu x v
367
+ xv x v
368
+ xuan x van
369
+ xvan x van
370
+ xue x ve
371
+ xve x ve
372
+ xun x vn
373
+ xvn x vn
374
+ ya y a
375
+ yan y En
376
+ yang y ang
377
+ yao y ao
378
+ ye y E
379
+ yi y i
380
+ yin y in
381
+ ying y ing
382
+ yo y o
383
+ yong y ong
384
+ you y ou
385
+ yu y v
386
+ yv y v
387
+ yuan y van
388
+ yvan y van
389
+ yue y ve
390
+ yve y ve
391
+ yun y vn
392
+ yvn y vn
393
+ za z a
394
+ zai z ai
395
+ zan z an
396
+ zang z ang
397
+ zao z ao
398
+ ze z e
399
+ zei z ei
400
+ zen z en
401
+ zeng z eng
402
+ zha zh a
403
+ zhai zh ai
404
+ zhan zh an
405
+ zhang zh ang
406
+ zhao zh ao
407
+ zhe zh e
408
+ zhei zh ei
409
+ zhen zh en
410
+ zheng zh eng
411
+ zhi zh ir
412
+ zhong zh ong
413
+ zhou zh ou
414
+ zhu zh u
415
+ zhua zh ua
416
+ zhuai zh uai
417
+ zhuan zh uan
418
+ zhuang zh uang
419
+ zhui zh ui
420
+ zhun zh un
421
+ zhuo zh uo
422
+ zi z i0
423
+ zong z ong
424
+ zou z ou
425
+ zu z u
426
+ zuan z uan
427
+ zui z ui
428
+ zun z un
429
+ zuo z uo
text/symbols.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Symbol inventory shared by all supported languages.
punctuation = ["!", "?", "…", ",", ".", "'", "-"]
pu_symbols = punctuation + ["SP", "UNK"]
pad = "_"

# chinese
zh_symbols = (
    "E En a ai an ang ao b c ch d e ei en eng er f g h i i0 ia ian iang iao "
    "ie in ing iong ir iu j k l m n o ong ou p q r s sh t u ua uai uan uang "
    "ui un uo v van ve vn w x y z zh AA EE OO"
).split()
num_zh_tones = 6

# japanese
ja_symbols = (
    "N a a: b by ch d dy e e: f g gy h hy i i: j k ky m my n ny o o: p py "
    "q r ry s sh t ts ty u u: w y z zy"
).split()
num_ja_tones = 2

# English
en_symbols = (
    "aa ae ah ao aw ay b ch d dh eh er ey f g hh ih iy jh k l m n ng ow oy "
    "p r s sh t th uh uw V w y z zh"
).split()
num_en_tones = 4

# combine all symbols: pad first, then the sorted de-duplicated phones,
# then punctuation / special symbols last
normal_symbols = sorted({*zh_symbols, *ja_symbols, *en_symbols})
symbols = [pad, *normal_symbols, *pu_symbols]
sil_phonemes_ids = [symbols.index(sym) for sym in pu_symbols]

# combine all tones
num_tones = num_zh_tones + num_ja_tones + num_en_tones

# language maps
language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
num_languages = len(language_id_map)

# Offset of each language's tone ids inside the combined tone space.
language_tone_start_map = {
    "ZH": 0,
    "JP": num_zh_tones,
    "EN": num_zh_tones + num_ja_tones,
}

if __name__ == "__main__":
    # Show which symbols are shared between Chinese and English.
    a = set(zh_symbols)
    b = set(en_symbols)
    print(sorted(a & b))
text/tone_sandhi.py ADDED
@@ -0,0 +1,769 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+ from typing import Tuple
16
+
17
+ import jieba
18
+ from pypinyin import lazy_pinyin
19
+ from pypinyin import Style
20
+
21
+
22
class ToneSandhi:
    """Rule-based Mandarin tone sandhi on pinyin finals.

    Finals are strings whose last character is the tone digit (e.g. ``"ia1"``).
    The ``_merge_*`` helpers regroup a segmented sentence (``(word, pos)``
    pairs) so that the sandhi rules in ``modified_tone`` see the right units.
    """

    def __init__(self):
        # Words whose last syllable is conventionally read with neutral tone.
        self.must_neural_tone_words = {
            "麻烦", "麻利", "鸳鸯", "高粱", "骨头", "骆驼", "马虎", "首饰", "馒头", "馄饨",
            "风筝", "难为", "队伍", "阔气", "闺女", "门道", "锄头", "铺盖", "铃铛", "铁匠",
            "钥匙", "里脊", "里头", "部分", "那么", "道士", "造化", "迷糊", "连累", "这么",
            "这个", "运气", "过去", "软和", "转悠", "踏实", "跳蚤", "跟头", "趔趄", "财主",
            "豆腐", "讲究", "记性", "记号", "认识", "规矩", "见识", "裁缝", "补丁", "衣裳",
            "衣服", "衙门", "街坊", "行李", "行当", "蛤蟆", "蘑菇", "薄荷", "葫芦", "葡萄",
            "萝卜", "荸荠", "苗条", "苗头", "苍蝇", "芝麻", "舒服", "舒坦", "舌头", "自在",
            "膏药", "脾气", "脑袋", "脊梁", "能耐", "胳膊", "胭脂", "胡萝", "胡琴", "胡同",
            "聪明", "耽误", "耽搁", "耷拉", "耳朵", "老爷", "老实", "老婆", "老头", "老太",
            "翻腾", "罗嗦", "罐头", "编辑", "结实", "红火", "累赘", "糨糊", "糊涂", "精神",
            "粮食", "簸箕", "篱笆", "算计", "算盘", "答应", "笤帚", "笑语", "笑话", "窟窿",
            "窝囊", "窗户", "稳当", "稀罕", "称呼", "秧歌", "秀气", "秀才", "福气", "祖宗",
            "砚台", "码头", "石榴", "石头", "石匠", "知识", "眼睛", "眯缝", "眨巴", "眉毛",
            "相声", "盘算", "白净", "痢疾", "痛快", "疟疾", "疙瘩", "疏忽", "畜生", "生意",
            "甘蔗", "琵琶", "琢磨", "琉璃", "玻璃", "玫瑰", "玄乎", "狐狸", "状元", "特务",
            "牲口", "牙碜", "牌楼", "爽快", "爱人", "热闹", "烧饼", "烟筒", "烂糊", "点心",
            "炊帚", "灯笼", "火候", "漂亮", "滑溜", "溜达", "温和", "清楚", "消息", "浪头",
            "活泼", "比方", "正经", "欺负", "模糊", "槟榔", "棺材", "棒槌", "棉花", "核桃",
            "栅栏", "柴火", "架势", "枕头", "枇杷", "机灵", "本事", "木头", "木匠", "朋友",
            "月饼", "月亮", "暖和", "明白", "时候", "新鲜", "故事", "收拾", "收成", "提防",
            "挖苦", "挑剔", "指甲", "指头", "拾掇", "拳头", "拨弄", "招牌", "招呼", "抬举",
            "护士", "折腾", "扫帚", "打量", "打算", "打点", "打扮", "打听", "打发", "扎实",
            "扁担", "戒指", "懒得", "意识", "意思", "情形", "悟性", "怪物", "思量", "怎么",
            "念头", "念叨", "快活", "忙活", "志气", "心思", "得罪", "张罗", "弟兄", "开通",
            "应酬", "庄稼", "干事", "帮手", "帐篷", "希罕", "师父", "师傅", "巴结", "巴掌",
            "差事", "工夫", "岁数", "屁股", "尾巴", "少爷", "小气", "小伙", "将就", "对头",
            "对付", "寡妇", "家伙", "客气", "实在", "官司", "学问", "学生", "字号", "嫁妆",
            "媳妇", "媒人", "婆家", "娘家", "委屈", "姑娘", "姐夫", "妯娌", "妥当", "妖精",
            "奴才", "女婿", "头发", "太阳", "大爷", "大方", "大意", "大夫", "多少", "多么",
            "外甥", "壮实", "地道", "地方", "在乎", "困难", "嘴巴", "嘱咐", "嘟囔", "嘀咕",
            "喜欢", "喇嘛", "喇叭", "商量", "唾沫", "哑巴", "哈欠", "哆嗦", "咳嗽", "和尚",
            "告诉", "告示", "含糊", "吓唬", "后头", "名字", "名堂", "合同", "吆喝", "叫唤",
            "口袋", "厚道", "厉害", "千斤", "包袱", "包涵", "匀称", "勤快", "动静", "动弹",
            "功夫", "力气", "前头", "刺猬", "刺激", "别扭", "利落", "利索", "利害", "分析",
            "出息", "凑合", "凉快", "冷战", "冤枉", "冒失", "养活", "关系", "先生", "兄弟",
            "便宜", "使唤", "佩服", "作坊", "体面", "位置", "似的", "伙计", "休息", "什么",
            "人家", "亲戚", "亲家", "交情", "云彩", "事情", "买卖", "主意", "丫头", "丧气",
            "两口", "东西", "东家", "世故", "不由", "不在", "下水", "下巴", "上头", "上司",
            "丈夫", "丈人", "一辈", "那个", "菩萨", "父亲", "母亲", "咕噜", "邋遢", "费用",
            "冤家", "甜头", "介绍", "荒唐", "大人", "泥鳅", "幸福", "熟悉", "计划", "扑腾",
            "蜡烛", "姥爷", "照顾", "喉咙", "吉他", "弄堂", "蚂蚱", "凤凰", "拖沓", "寒碜",
            "糟蹋", "倒腾", "报复", "逻辑", "盘缠", "喽啰", "牢骚", "咖喱", "扫把", "惦记",
        }
        # Words that look like neutral-tone candidates but must keep their tone.
        self.must_not_neural_tone_words = {
            "男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子",
            "人人", "虎虎",
        }
        # Punctuation after which "一" keeps its first tone.
        self.punc = ":,;。?!“”‘’':,;.?!"

    # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
    # e.g.
    # word: "家里"
    # pos: "s"
    # finals: ['ia1', 'i3']
    def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
        """Apply neutral-tone (tone 5) rules to ``finals`` of ``word``.

        ``finals`` is modified in place for most rules; the returned list is
        the one callers must use.
        """
        # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
        for j, item in enumerate(word):
            if (
                j - 1 >= 0
                and item == word[j - 1]
                and pos[0] in {"n", "v", "a"}
                and word not in self.must_not_neural_tone_words
            ):
                finals[j] = finals[j][:-1] + "5"
        ge_idx = word.find("个")
        if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
            finals[-1] = finals[-1][:-1] + "5"
        elif len(word) >= 1 and word[-1] in "的地得":
            finals[-1] = finals[-1][:-1] + "5"
        # e.g. 走了, 看着, 去过
        # elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
        #     finals[-1] = finals[-1][:-1] + "5"
        elif (
            len(word) > 1
            and word[-1] in "们子"
            and pos in {"r", "n"}
            and word not in self.must_not_neural_tone_words
        ):
            finals[-1] = finals[-1][:-1] + "5"
        # e.g. 桌上, 地下, 家里
        elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
            finals[-1] = finals[-1][:-1] + "5"
        # e.g. 上来, 下去
        elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
            finals[-1] = finals[-1][:-1] + "5"
        # "个" used as a measure word
        elif (
            ge_idx >= 1
            and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
        ) or word == "个":
            finals[ge_idx] = finals[ge_idx][:-1] + "5"
        else:
            if (
                word in self.must_neural_tone_words
                or word[-2:] in self.must_neural_tone_words
            ):
                finals[-1] = finals[-1][:-1] + "5"

        # Re-check the lexicon on each of the two sub-words of the word.
        word_list = self._split_word(word)
        head_len = len(word_list[0])
        finals_list = [finals[:head_len], finals[head_len:]]
        for i, sub_word in enumerate(word_list):
            # conventional neural in Chinese
            if (
                sub_word in self.must_neural_tone_words
                or sub_word[-2:] in self.must_neural_tone_words
            ):
                finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
        finals = sum(finals_list, [])
        return finals

    def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
        """Apply the tone rules for "不" and return ``finals`` (edited in place)."""
        # e.g. 看不懂
        if len(word) == 3 and word[1] == "不":
            finals[1] = finals[1][:-1] + "5"
        else:
            for i, char in enumerate(word):
                # "不" before tone4 should be bu2, e.g. 不怕
                if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4":
                    finals[i] = finals[i][:-1] + "2"
        return finals

    def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
        """Apply the tone rules for "一" and return ``finals`` (edited in place)."""
        # "一" in number sequences, e.g. 一零零, 二一零
        if word.find("一") != -1 and all(
            item.isnumeric() for item in word if item != "一"
        ):
            return finals
        # "一" between reduplication words should be yi5, e.g. 看一看
        elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
            finals[1] = finals[1][:-1] + "5"
        # when "一" is ordinal word, it should be yi1
        elif word.startswith("第一"):
            finals[1] = finals[1][:-1] + "1"
        else:
            for i, char in enumerate(word):
                if char == "一" and i + 1 < len(word):
                    # "一" before tone4 should be yi2, e.g. 一段
                    if finals[i + 1][-1] == "4":
                        finals[i] = finals[i][:-1] + "2"
                    # "一" before non-tone4 should be yi4, e.g. 一天
                    else:
                        # if "一" is followed by punctuation it keeps tone 1
                        if word[i + 1] not in self.punc:
                            finals[i] = finals[i][:-1] + "4"
        return finals

    def _split_word(self, word: str) -> List[str]:
        """Split ``word`` into two parts around its shortest jieba sub-word."""
        word_list = jieba.cut_for_search(word)
        word_list = sorted(word_list, key=len, reverse=False)
        first_subword = word_list[0]
        first_begin_idx = word.find(first_subword)
        if first_begin_idx == 0:
            second_subword = word[len(first_subword) :]
            new_word_list = [first_subword, second_subword]
        else:
            second_subword = word[: -len(first_subword)]
            new_word_list = [second_subword, first_subword]
        return new_word_list

    def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
        """Apply third-tone sandhi (3 + 3 -> 2 + 3) for words of length 2-4."""
        if len(word) == 2 and self._all_tone_three(finals):
            finals[0] = finals[0][:-1] + "2"
        elif len(word) == 3:
            word_list = self._split_word(word)
            if self._all_tone_three(finals):
                # disyllabic + monosyllabic, e.g. 蒙古/包
                if len(word_list[0]) == 2:
                    finals[0] = finals[0][:-1] + "2"
                    finals[1] = finals[1][:-1] + "2"
                # monosyllabic + disyllabic, e.g. 纸/老虎
                elif len(word_list[0]) == 1:
                    finals[1] = finals[1][:-1] + "2"
            else:
                head_len = len(word_list[0])
                finals_list = [finals[:head_len], finals[head_len:]]
                if len(finals_list) == 2:
                    for i, sub in enumerate(finals_list):
                        # e.g. 所有/人
                        if self._all_tone_three(sub) and len(sub) == 2:
                            finals_list[i][0] = finals_list[i][0][:-1] + "2"
                        # e.g. 好/喜欢
                        elif (
                            i == 1
                            and not self._all_tone_three(sub)
                            and finals_list[i][0][-1] == "3"
                            and finals_list[0][-1][-1] == "3"
                        ):
                            finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
                    finals = sum(finals_list, [])
        # split idiom into two words whose length is 2
        elif len(word) == 4:
            finals_list = [finals[:2], finals[2:]]
            finals = []
            for sub in finals_list:
                if self._all_tone_three(sub):
                    sub[0] = sub[0][:-1] + "2"
                finals += sub

        return finals

    def _all_tone_three(self, finals: List[str]) -> bool:
        """Return True when every final ends with tone digit "3".

        An empty final string counts as "not tone three" instead of raising.
        """
        return all(x.endswith("3") for x in finals)

    # merge "不" and the word behind it
    # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
    def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Attach a standalone "不" to the word that follows it."""
        new_seg = []
        last_word = ""
        for word, pos in seg:
            if last_word == "不":
                word = last_word + word
            if word != "不":
                new_seg.append((word, pos))
            last_word = word[:]
        # A trailing "不" has nothing to attach to; keep it as an adverb.
        if last_word == "不":
            new_seg.append((last_word, "d"))
            last_word = ""
        return new_seg

    # function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
    # function 2: merge single "一" and the word behind it
    # if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
    # e.g.
    # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
    # output seg: [['听一听', 'v']]
    def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Merge "一" with its neighbours; returns a list of ``[word, pos]``.

        The merge of "X 一 X" only happens when both X are verbs, and it always
        targets the most recently emitted entry, so several such patterns in
        one sentence are handled and no character is dropped or duplicated.
        """
        merged = []
        # function 1
        idx = 0
        total = len(seg)
        while idx < total:
            word, pos = seg[idx]
            if (
                word == "一"
                and merged
                and idx + 1 < total
                and merged[-1][0] == seg[idx + 1][0]
                and merged[-1][1] == "v"
                and seg[idx + 1][1] == "v"
            ):
                merged[-1][0] = merged[-1][0] + "一" + seg[idx + 1][0]
                # Skip both "一" and the repeated verb that was just absorbed.
                idx += 2
                continue
            merged.append([word, pos])
            idx += 1

        new_seg = []
        # function 2
        for word, pos in merged:
            if new_seg and new_seg[-1][0] == "一":
                new_seg[-1][0] = new_seg[-1][0] + word
            else:
                new_seg.append([word, pos])
        return new_seg

    # the first and the second words are all_tone_three
    def _merge_continuous_three_tones(
        self, seg: List[Tuple[str, str]]
    ) -> List[Tuple[str, str]]:
        """Merge adjacent words that both consist only of third tones."""
        new_seg = []
        sub_finals_list = [
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
            for (word, pos) in seg
        ]
        assert len(sub_finals_list) == len(seg)
        merge_last = [False] * len(seg)
        for i, (word, pos) in enumerate(seg):
            if (
                i - 1 >= 0
                and self._all_tone_three(sub_finals_list[i - 1])
                and self._all_tone_three(sub_finals_list[i])
                and not merge_last[i - 1]
            ):
                # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
                if (
                    not self._is_reduplication(seg[i - 1][0])
                    and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
                ):
                    new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
                    merge_last[i] = True
                else:
                    new_seg.append([word, pos])
            else:
                new_seg.append([word, pos])

        return new_seg

    def _is_reduplication(self, word: str) -> bool:
        """Return True for a two-character word made of the same character."""
        return len(word) == 2 and word[0] == word[1]

    # the last char of first word and the first char of second word is tone_three
    def _merge_continuous_three_tones_2(
        self, seg: List[Tuple[str, str]]
    ) -> List[Tuple[str, str]]:
        """Merge adjacent words whose touching syllables are both third tone."""
        new_seg = []
        sub_finals_list = [
            lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
            for (word, pos) in seg
        ]
        assert len(sub_finals_list) == len(seg)
        merge_last = [False] * len(seg)
        for i, (word, pos) in enumerate(seg):
            if (
                i - 1 >= 0
                and sub_finals_list[i - 1][-1][-1] == "3"
                and sub_finals_list[i][0][-1] == "3"
                and not merge_last[i - 1]
            ):
                # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
                if (
                    not self._is_reduplication(seg[i - 1][0])
                    and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
                ):
                    new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
                    merge_last[i] = True
                else:
                    new_seg.append([word, pos])
            else:
                new_seg.append([word, pos])
        return new_seg

    def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Attach "儿" to the preceding word unless that word is "#"."""
        new_seg = []
        for i, (word, pos) in enumerate(seg):
            if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#":
                new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
            else:
                new_seg.append([word, pos])
        return new_seg

    def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Join a word with the previous entry when both are identical."""
        new_seg = []
        for i, (word, pos) in enumerate(seg):
            if new_seg and word == new_seg[-1][0]:
                new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
            else:
                new_seg.append([word, pos])
        return new_seg

    def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
        """Run all merge passes, in order, before tone modification."""
        seg = self._merge_bu(seg)
        try:
            seg = self._merge_yi(seg)
        except Exception:
            # Best effort: keep the un-merged segmentation rather than abort.
            print("_merge_yi failed")
        seg = self._merge_reduplication(seg)
        seg = self._merge_continuous_three_tones(seg)
        seg = self._merge_continuous_three_tones_2(seg)
        seg = self._merge_er(seg)
        return seg

    def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
        """Apply the 不 / 一 / neutral-tone / third-tone rules, in that order."""
        finals = self._bu_sandhi(word, finals)
        finals = self._yi_sandhi(word, finals)
        finals = self._neural_sandhi(word, pos, finals)
        finals = self._three_sandhi(word, finals)
        return finals
tools/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ 工具包
3
+ """
tools/classify_language.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from config import config
4
+
5
# Name of the language-identification backend, taken from the web UI config.
LANGUAGE_IDENTIFICATION_LIBRARY = config.webui_config.language_identification_library

# Lower-cased backend name used for dispatch in classify_language().
module = LANGUAGE_IDENTIFICATION_LIBRARY.lower()

# Language codes that may be passed on to the langid backend.
langid_languages = (
    "af am an ar as az be bg bn br bs ca cs cy da de dz el en eo es et eu "
    "fa fi fo fr ga gl gu he hi hr ht hu hy id is it ja jv ka kk km kn ko "
    "ku ky la lb lo lt lv mg mk ml mn mr ms mt nb ne nl nn no oc or pa pl "
    "ps pt qu ro ru rw se si sk sl sq sr sv sw ta te th tl tr ug uk ur vi "
    "vo wa xh zh zu"
).split()
108
+
109
+
110
def classify_language(text: str, target_languages: list = None) -> str:
    """Identify the language of ``text`` with the configured backend.

    Args:
        text: text to classify.
        target_languages: optional list of candidate language codes; codes the
            backend does not support are dropped before being handed to it.

    Returns:
        The first element of the backend's classification result.

    Raises:
        ValueError: if the configured backend name is not recognised.
    """
    restrict = target_languages is not None

    if module in ("fastlid", "fasttext"):
        from fastlid import fastlid, supported_langs

        classifier = fastlid
        if restrict:
            target_languages = [
                code for code in target_languages if code in supported_langs
            ]
            fastlid.set_languages = target_languages
    elif module == "langid":
        import langid

        classifier = langid.classify
        if restrict:
            target_languages = [
                code for code in target_languages if code in langid_languages
            ]
            langid.set_languages(target_languages)
    else:
        raise ValueError(f"Wrong module {module}")

    return classifier(text)[0]
135
+
136
+
137
def classify_zh_ja(text: str) -> str:
    """Return ``"ja"`` if ``text`` contains any kana, otherwise ``"zh"``.

    Hiragana (U+3040-U+309F) and katakana (U+30A0-U+30FF) mark the text as
    Japanese; text without kana (including empty text) is reported as Chinese.
    """

    def _is_kana(char: str) -> bool:
        code = ord(char)
        return 0x3040 <= code <= 0x309F or 0x30A0 <= code <= 0x30FF

    # A hanzi directly followed by kana is covered too: the kana itself is
    # enough to decide.
    return "ja" if any(_is_kana(char) for char in text) else "zh"
156
+
157
+
158
def split_alpha_nonalpha(text):
    """Split ``text`` at every boundary between ASCII letters and CJK/kana.

    The split points are zero-width, so no characters are removed; the pieces
    concatenate back to the original text.
    """
    # Boundary: (hanzi or kana) followed by a letter, or a letter followed by
    # (hanzi or kana).
    boundary = r"(?:(?<=[\u4e00-\u9fff])|(?<=[\u3040-\u30FF]))(?=[a-zA-Z])|(?<=[a-zA-Z])(?:(?=[\u4e00-\u9fff])|(?=[\u3040-\u30FF]))"
    return re.split(boundary, text)
163
+
164
+
165
if __name__ == "__main__":
    # Manual check with a Chinese sample.
    text = "这是一个测试文本"
    print(classify_language(text))
    print(classify_zh_ja(text))  # "zh"

    # Manual check with a Japanese sample.
    text = "これはテストテキストです"
    print(classify_language(text))
    print(classify_zh_ja(text))  # "ja"
tools/log.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ logger封装
3
+ """
4
+ from loguru import logger
5
+ import sys
6
+
7
+
8
# Custom record layout: timestamp, level padded to 9 columns, source
# location, then the message.
log_format = (
    "<g>{time:MM-DD HH:mm:ss}</g> "
    "<lvl>{level:<9}</lvl>"
    "| {file}:{line} | {message}"
)

# Remove every previously registered handler, then send all records to
# standard output using the layout above.
logger.remove()
logger.add(sys.stdout, format=log_format, backtrace=True, diagnose=True)
tools/sentence.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import regex as re
4
+
5
+ from tools.classify_language import classify_language, split_alpha_nonalpha
6
+
7
+
8
def check_is_none(item) -> bool:
    """Return True for ``None``, whitespace-only strings and empty text.

    Any other value yields False. Non-string values are judged by their
    ``str()`` form, so only values that stringify to ``""`` count as empty.
    """
    if item is None:
        return True
    if isinstance(item, str) and str(item).isspace():
        return True
    return str(item) == ""
15
+
16
+
17
def markup_language(text: str, target_languages: list = None) -> str:
    """Wrap each same-language run of ``text`` in ``[LANG]`` tags.

    The text is split on punctuation, every non-empty piece is classified,
    and tags are inserted in place: ``[LANG]`` before the first piece,
    ``[PREV][LANG]`` wherever the detected language changes, and a closing
    ``[LANG]`` at the very end.

    Args:
        text: The text to annotate.
        target_languages: Optional list of candidate language codes passed
            on to ``classify_language``. ``None`` means no restriction.

    Returns:
        The annotated text.
    """
    pattern = (
        r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`"
        # The original second fragment contained an unescaped double quote,
        # which ended the literal early and turned the rest of the line into
        # a comment. The fullwidth/halfwidth punctuation is now spelled with
        # escapes so it cannot be mangled again.
        # NOTE(review): set reconstructed from the commented-out tail
        # (fullwidth ASCII punctuation U+FF01-FF0F, FF1A-FF20, FF3B-FF40,
        # FF5B-FF64) -- confirm against the upstream file.
        r"\!?。\uFF01-\uFF0F\uFF1A-\uFF20\uFF3B-\uFF40\uFF5B-\uFF64⦅⦆「」、〃》"
        r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+"
    )
    sentences = re.split(pattern, text)

    pre_lang = ""
    p = 0  # offset in ``text`` up to which tags have already been inserted

    # ``target_languages`` defaults to None, which sorted() cannot handle.
    if target_languages is not None:
        sorted_target_languages = sorted(target_languages)
        if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]:
            # For these combinations also split where CJK text meets Latin
            # letters without any punctuation in between.
            new_sentences = []
            for sentence in sentences:
                new_sentences.extend(split_alpha_nonalpha(sentence))
            sentences = new_sentences

    for sentence in sentences:
        if check_is_none(sentence):
            continue

        lang = classify_language(sentence, target_languages)

        if pre_lang == "":
            # First classified piece: open the first tag.
            text = text[:p] + text[p:].replace(
                sentence, f"[{lang.upper()}]{sentence}", 1
            )
            p += len(f"[{lang.upper()}]")
        elif pre_lang != lang:
            # Language switch: close the previous run and open the new one.
            text = text[:p] + text[p:].replace(
                sentence, f"[{pre_lang.upper()}][{lang.upper()}]{sentence}", 1
            )
            p += len(f"[{pre_lang.upper()}][{lang.upper()}]")
        pre_lang = lang
        # Advance past this piece so later replacements only look forward.
        p += text[p:].index(sentence) + len(sentence)
    text += f"[{pre_lang.upper()}]"

    return text
56
+
57
+
58
def split_by_language(text: str, target_languages: list = None) -> list:
    """Cut ``text`` into consecutive chunks that share one detected language.

    The text is split on punctuation, every non-empty piece is classified,
    and a new chunk starts wherever the detected language changes. Chunks are
    slices of the original text, so punctuation between pieces is kept.

    Args:
        text: The text to segment.
        target_languages: Optional list of candidate language codes passed
            on to ``classify_language``. ``None`` means no restriction.

    Returns:
        A list of ``(chunk, language_code)`` tuples in text order. The last
        tuple always covers the remainder of the text.
    """
    pattern = (
        r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`"
        # The original second fragment contained an unescaped double quote,
        # which ended the literal early and turned the rest of the line into
        # a comment. The fullwidth/halfwidth punctuation is now spelled with
        # escapes so it cannot be mangled again.
        # NOTE(review): set reconstructed from the commented-out tail
        # (fullwidth ASCII punctuation U+FF01-FF0F, FF1A-FF20, FF3B-FF40,
        # FF5B-FF64) -- confirm against the upstream file.
        r"\!?\。\uFF01-\uFF0F\uFF1A-\uFF20\uFF3B-\uFF40\uFF5B-\uFF64⦅⦆「」、〃》"
        r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+"
    )
    sentences = re.split(pattern, text)

    pre_lang = ""
    start = 0  # start offset of the chunk currently being built
    end = 0  # offset just past the last piece that was processed
    sentences_list = []

    # ``target_languages`` defaults to None, which sorted() cannot handle.
    if target_languages is not None:
        sorted_target_languages = sorted(target_languages)
        if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]:
            # For these combinations also split where CJK text meets Latin
            # letters without any punctuation in between.
            new_sentences = []
            for sentence in sentences:
                new_sentences.extend(split_alpha_nonalpha(sentence))
            sentences = new_sentences

    for sentence in sentences:
        if check_is_none(sentence):
            continue

        lang = classify_language(sentence, target_languages)

        # Move ``end`` to where this piece begins in the original text.
        end += text[end:].index(sentence)
        if pre_lang != "" and pre_lang != lang:
            # Language changed: emit everything collected so far.
            sentences_list.append((text[start:end], pre_lang))
            start = end
        end += len(sentence)
        pre_lang = lang
    sentences_list.append((text[start:], pre_lang))

    return sentences_list
93
+
94
+
95
def sentence_split(text: str, max: int) -> list:
    """Split ``text`` at punctuation into chunks of at least ``max`` characters.

    Pieces (each with its trailing punctuation run) are accumulated until
    the running length reaches ``max``; the accumulated span is then emitted
    with surrounding whitespace stripped. Whatever is left after the last
    emitted chunk is appended unchanged.

    Args:
        text: The text to split.
        max: Minimum accumulated length that triggers a cut.

    Returns:
        The list of chunks; empty for empty text.
    """
    # The original class listed "?", ",", ";" and ":" twice, which suggests
    # the second copies were fullwidth forms lost in transit; they are
    # restored here as escapes.
    # NOTE(review): confirm the fullwidth set against the upstream file.
    pattern = r"[!(),—+\-.:;?。、\uFF0C\uFF1A\uFF1B\uFF1F]+"
    sentences = re.split(pattern, text)
    separators = re.findall(pattern, text)

    sentences_list, count, p = [], 0, 0

    # Walk the pieces together with the punctuation run that follows each.
    # There is one more piece than separators; zip stops at the shorter list.
    for sentence, separator in zip(sentences, separators):
        count += len(sentence) + len(separator)
        if count >= max:
            sentences_list.append(text[p : p + count].strip())
            p += count
            count = 0

    # Append whatever text remains after the last cut.
    if p < len(text):
        sentences_list.append(text[p:])

    return sentences_list
115
+
116
+
117
def sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None):
    """Split ``text`` into chunks and tag each chunk with its language.

    Args:
        text: The text to process.
        max: Chunk length passed to ``sentence_split``; a value <= 0 keeps
            the text as a single chunk.
        lang: ``"auto"`` to detect languages with ``markup_language``,
            ``"mix"`` to return the text untouched (no splitting, no tags),
            or a language code to wrap every chunk in.
        speaker_lang: Optional list of languages the speaker supports. When
            it holds exactly one language and ``lang`` is a different
            explicit code, that single language replaces ``lang``.

    Returns:
        The list of tagged chunks.
    """
    # A speaker that supports a single language overrides a conflicting
    # explicit request.
    if speaker_lang is not None and len(speaker_lang) == 1:
        explicit = lang.upper() not in ["AUTO", "MIX"]
        if explicit and lang.lower() != speaker_lang[0]:
            logging.debug(
                f'lang "{lang}" is not in speaker_lang {speaker_lang},automatically set lang={speaker_lang[0]}'
            )
            lang = speaker_lang[0]

    mode = lang.upper()

    def tag(chunk):
        # Detect per chunk in auto mode, otherwise wrap in the fixed code.
        if mode == "AUTO":
            return markup_language(chunk, speaker_lang)
        return f"[{mode}]{chunk}[{mode}]"

    if mode == "MIX":
        sentences_list = [text]
    elif max <= 0:
        sentences_list = [tag(text)]
    else:
        sentences_list = [
            tag(chunk)
            for chunk in sentence_split(text, max)
            if not check_is_none(chunk)
        ]

    for tagged in sentences_list:
        logging.debug(tagged)

    return sentences_list
150
+
151
+
152
if __name__ == "__main__":
    # Manual smoke test on a Chinese-only passage.
    text = "这几天心里颇不宁静。今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。月亮渐渐地升高了,墙外马路上孩子们的欢笑,已经听不见了;妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。我悄悄地披了大衫,带上门出去。"
    demos = (
        lambda: markup_language(text, target_languages=None),
        lambda: sentence_split(text, max=50),
        lambda: sentence_split_and_markup(
            text, max=50, lang="auto", speaker_lang=None
        ),
    )
    # Run and print one demo at a time so earlier output survives a later
    # failure.
    for run_demo in demos:
        print(run_demo())

    # Manual smoke test on mixed Chinese / Japanese / English input.
    text = "你好,这是一段用来测试自动标注的文本。こんにちは,これは自動ラベリングのテスト用テキストです.Hello, this is a piece of text to test autotagging.你好!今天我们要介绍VITS项目,其重点是使用了GAN Duration predictor和transformer flow,并且接入了Bert模型来提升韵律。Bert embedding会在稍后介绍。"
    print(split_by_language(text, ["zh", "ja", "en"]))
tools/translate.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 翻译api
3
+ """
4
+ from config import config
5
+
6
+ import random
7
+ import hashlib
8
+ import requests
9
+
10
+
11
def translate(Sentence: str, to_Language: str = "jp", from_Language: str = ""):
    """Translate ``Sentence`` line by line through the Baidu translate API.

    :param Sentence: text to translate; it is processed one line at a time
    :param from_Language: source language code; empty means ``"auto"``
    :param to_Language: target language code
    :return: the translated lines joined with newlines. Empty lines are kept
        as they are. A line whose response carries no translation is left
        out of the result. If any request raises, the original ``Sentence``
        is returned unchanged. If the API credentials are not configured, a
        fixed notice string is returned instead.

    Common language codes: Chinese ``zh``, English ``en``, Japanese ``jp``.
    """
    appid = config.translate_config.app_key
    key = config.translate_config.secret_key
    if appid == "" or key == "":
        return "请开发者在config.yml中配置app_key与secret_key"

    url = "https://fanyi-api.baidu.com/api/trans/vip/translate"
    headers = {"Content-Type": "application/x-www-form-urlencoded"}
    source_language = from_Language if from_Language != "" else "auto"

    outTexts = []
    for line in Sentence.splitlines():
        if line == "":
            outTexts.append(line)
            continue

        # Request signature: MD5 of appid + query + salt + secret key.
        # Reference: https://api.fanyi.baidu.com/product/113
        salt = str(random.randint(1, 100000))
        raw_signature = appid + line + salt + key
        signature = hashlib.md5(raw_signature.encode("utf-8")).hexdigest()
        payload = {
            "q": line,
            "from": source_language,
            "to": to_Language,
            "appid": appid,
            "salt": salt,
            "sign": signature,
        }

        # Send the request; any failure aborts and returns the input as is.
        try:
            reply = requests.post(
                url=url, data=payload, headers=headers, timeout=3
            ).json()
            if "trans_result" in reply.keys():
                first_result = reply["trans_result"][0]
                if "dst" in first_result.keys():
                    outTexts.append(first_result["dst"])
        except Exception:
            return Sentence

    return "\n".join(outTexts)
transforms.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
# Default lower bounds used as the ``min_bin_width`` / ``min_bin_height`` /
# ``min_derivative`` defaults of the spline functions in this module.
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Apply a rational-quadratic spline transform to ``inputs``.

    Dispatches to ``rational_quadratic_spline`` when ``tails`` is None and to
    ``unconstrained_rational_quadratic_spline`` (forwarding ``tails`` and
    ``tail_bound``) otherwise.

    Returns:
        The ``(outputs, logabsdet)`` pair produced by the selected spline.
    """
    shared_kwargs = dict(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    if tails is None:
        outputs, logabsdet = rational_quadratic_spline(**shared_kwargs)
    else:
        outputs, logabsdet = unconstrained_rational_quadratic_spline(
            tails=tails, tail_bound=tail_bound, **shared_kwargs
        )
    return outputs, logabsdet
43
+
44
+
45
def searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for each input, the index of the bin that contains it.

    Args:
        bin_locations: Bin edges along the last dimension.
        inputs: Values to locate; compared against the edges after gaining a
            trailing dimension.
        eps: Amount added to the last edge so that an input equal to it still
            falls into the final bin.

    Returns:
        An integer tensor of bin indices (number of edges <= input, minus 1).
    """
    # Work on a copy: the previous implementation nudged the caller's tensor
    # in place, so the last edge drifted by ``eps`` on every call.
    edges = bin_locations.clone()
    edges[..., -1] += eps
    return torch.sum(inputs[..., None] >= edges, dim=-1) - 1
48
+
49
+
50
def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Rational-quadratic spline on ``[-tail_bound, tail_bound]`` with tails.

    Inputs inside the interval go through ``rational_quadratic_spline``;
    inputs outside it are passed through unchanged with a log-determinant of
    zero (the only supported tail type, ``"linear"``).

    Returns:
        ``(outputs, logabsdet)``, both shaped like ``inputs``.

    Raises:
        RuntimeError: If ``tails`` is anything other than ``"linear"``.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        # Add one boundary derivative on each side and pin both to a
        # constant; softplus(constant) + min_derivative equals 1.
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    # With no input inside the interval the inner spline would be handed
    # empty tensors and fail on its min/max domain check, so skip it.
    if inside_interval_mask.any():
        (
            outputs[inside_interval_mask],
            logabsdet[inside_interval_mask],
        ) = rational_quadratic_spline(
            inputs=inputs[inside_interval_mask],
            unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
            unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
            unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
            inverse=inverse,
            left=-tail_bound,
            right=tail_bound,
            bottom=-tail_bound,
            top=tail_bound,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative,
        )

    return outputs, logabsdet
98
+
99
+
100
def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Evaluate a monotonic rational-quadratic spline or its inverse.

    Bin widths and heights come from softmax-normalised parameters with the
    given minimum sizes, knot derivatives from a softplus with a minimum.
    The spline maps ``[left, right]`` onto ``[bottom, top]``.

    Returns:
        ``(outputs, logabsdet)``. In inverse mode the log-determinant is
        negated.

    Raises:
        ValueError: If an input lies outside ``[left, right]`` or a minimum
            bin size is too large for the number of bins.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    def knots_and_sizes(unnormalized, min_size, low, high):
        # Normalise to bin sizes, accumulate into knot positions scaled to
        # [low, high], pin both end knots exactly, then re-derive the sizes.
        sizes = F.softmax(unnormalized, dim=-1)
        sizes = min_size + (1 - min_size * num_bins) * sizes
        knots = torch.cumsum(sizes, dim=-1)
        knots = F.pad(knots, pad=(1, 0), mode="constant", value=0.0)
        knots = (high - low) * knots + low
        knots[..., 0] = low
        knots[..., -1] = high
        return knots, knots[..., 1:] - knots[..., :-1]

    cumwidths, widths = knots_and_sizes(
        unnormalized_widths, min_bin_width, left, right
    )
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
    cumheights, heights = knots_and_sizes(
        unnormalized_heights, min_bin_height, bottom, top
    )

    # The inverse looks inputs up along the output axis, the forward pass
    # along the input axis.
    lookup_knots = cumheights if inverse else cumwidths
    bin_idx = searchsorted(lookup_knots, inputs)[..., None]

    def per_input(per_bin):
        # Select each input's own bin entry.
        return per_bin.gather(-1, bin_idx)[..., 0]

    input_cumwidths = per_input(cumwidths)
    input_bin_widths = per_input(widths)

    input_cumheights = per_input(cumheights)
    input_delta = per_input(heights / widths)

    input_derivatives = per_input(derivatives)
    input_derivatives_plus_one = per_input(derivatives[..., 1:])

    input_heights = per_input(heights)

    # Term shared by the quadratic coefficients and the denominator.
    curvature = input_derivatives + input_derivatives_plus_one - 2 * input_delta

    if inverse:
        # Solve a * theta^2 + b * theta + c = 0 for the in-bin position.
        height_offset = inputs - input_cumheights
        a = height_offset * curvature + input_heights * (
            input_delta - input_derivatives
        )
        b = input_heights * input_derivatives - height_offset * curvature
        c = -input_delta * height_offset

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        theta = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = theta * input_bin_widths + input_cumwidths
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths

    theta_one_minus_theta = theta * (1 - theta)
    denominator = input_delta + (curvature * theta_one_minus_theta)

    if not inverse:
        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

    derivative_numerator = input_delta.pow(2) * (
        input_derivatives_plus_one * theta.pow(2)
        + 2 * input_delta * theta_one_minus_theta
        + input_derivatives * (1 - theta).pow(2)
    )
    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

    if inverse:
        return outputs, -logabsdet
    return outputs, logabsdet
utils.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import argparse
4
+ import logging
5
+ import json
6
+ import shutil
7
+ import subprocess
8
+ import numpy as np
9
+ from huggingface_hub import hf_hub_download
10
+ from scipy.io.wavfile import read
11
+ import torch
12
+ import re
13
+
14
# Set to True by the plotting helpers once they have switched matplotlib to
# the "Agg" backend, so the switch happens only once.
MATPLOTLIB_FLAG = False

# Module-level logger; get_logger() rebinds this name to a per-model logger.
logger = logging.getLogger(__name__)
17
+
18
+
19
def download_emo_models(mirror, repo_id, model_name):
    """Download the emotion model weights from the selected mirror.

    Args:
        mirror: ``"openi"`` to fetch through the ``openi`` package; any other
            value downloads from the Hugging Face Hub.
        repo_id: Hub repository id; for openi only its last path component is
            used as the model name.
        model_name: Local directory for the Hub download.
    """
    if mirror != "openi":
        hf_hub_download(
            repo_id,
            "pytorch_model.bin",
            local_dir=model_name,
            local_dir_use_symlinks=False,
        )
        return

    # Imported lazily so the dependency is only needed for this mirror.
    import openi

    openi_model_name = repo_id.split("/")[-1]
    openi.model.download_model(
        "Stardust_minus/Bert-VITS2",
        openi_model_name,
        "./emotional",
    )
35
+
36
+
37
def download_checkpoint(
    dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi"
):
    """Download base checkpoints into ``dir_path`` unless some already exist.

    Args:
        dir_path: Target directory for the checkpoint files.
        repo_config: Mapping with ``"repo_id"`` and, for the openi mirror,
            ``"model_image"``.
        token: Optional openi login token.
        regex: Glob pattern; if any file in ``dir_path`` matches it, the
            download is skipped.
        mirror: ``"openi"`` (case-insensitive) or anything else for the
            Hugging Face Hub.
    """
    repo_id = repo_config["repo_id"]
    if glob.glob(os.path.join(dir_path, regex)):
        print("Use existed model, skip downloading.")
        return

    if mirror.lower() != "openi":
        for file in ["DUR_0.pth", "D_0.pth", "G_0.pth"]:
            hf_hub_download(
                repo_id, file, local_dir=dir_path, local_dir_use_symlinks=False
            )
        return

    # Imported lazily so the dependency is only needed for this mirror.
    import openi

    login_kwargs = {"token": token} if token else {}
    openi.login(**login_kwargs)

    model_image = repo_config["model_image"]
    openi.model.download_model(repo_id, model_image, dir_path)

    # The download lands in a sub-directory named after the model image:
    # move the checkpoints up one level and remove that directory.
    image_dir = os.path.join(dir_path, model_image)
    for checkpoint_file in glob.glob(os.path.join(image_dir, "*.pth")):
        shutil.move(checkpoint_file, dir_path)
    shutil.rmtree(image_dir)
63
+
64
+
65
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
    """Load model (and optionally optimizer) state from a checkpoint file.

    Parameters missing from the checkpoint, or stored with a different
    shape, keep the model's current values. Missing ``ja_bert_proj``
    parameters are zero-filled for compatibility with older checkpoints.

    Args:
        checkpoint_path: Path to the checkpoint; must be an existing file.
        model: The model to fill, optionally wrapped with a ``module``
            attribute.
        optimizer: Optimizer to restore, or None.
        skip_optimizer: When True, the optimizer state is never restored.

    Returns:
        ``(model, optimizer, learning_rate, iteration)``.
    """
    assert os.path.isfile(checkpoint_path)
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if (
        optimizer is not None
        and not skip_optimizer
        and checkpoint_dict["optimizer"] is not None
    ):
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    # The former ``elif optimizer is None and not skip_optimizer`` branch
    # called ``optimizer.state_dict()`` on None and could only raise
    # AttributeError. Without an optimizer there is nothing to restore.

    saved_state_dict = checkpoint_dict["model"]
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():
        in_checkpoint = k in saved_state_dict
        if in_checkpoint and saved_state_dict[k].shape == v.shape:
            new_state_dict[k] = saved_state_dict[k]
            continue

        # Fall back to the model's own value for this parameter.
        if "ja_bert_proj" in k:
            # For upgrading from the old version
            v = torch.zeros_like(v)
            logger.warning(
                f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
            )
        elif in_checkpoint:
            logger.error(
                f"{k} has shape {tuple(saved_state_dict[k].shape)} in the checkpoint, expected {tuple(v.shape)}"
            )
        else:
            logger.error(f"{k} is not in the checkpoint")

        new_state_dict[k] = v

    target.load_state_dict(new_state_dict, strict=False)

    logger.info(
        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
    )

    return model, optimizer, learning_rate, iteration
121
+
122
+
123
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write model and optimizer state to ``checkpoint_path``.

    The file holds a dict with the keys ``"model"``, ``"iteration"``,
    ``"optimizer"`` and ``"learning_rate"``. A model wrapped with a
    ``module`` attribute is unwrapped before its state is taken.
    """
    logger.info(
        "Saving model and optimizer state at iteration {} to {}".format(
            iteration, checkpoint_path
        )
    )
    unwrapped = model.module if hasattr(model, "module") else model
    checkpoint = {
        "model": unwrapped.state_dict(),
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(checkpoint, checkpoint_path)
142
+
143
+
144
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Send a batch of summaries to ``writer`` at ``global_step``.

    Args:
        writer: Object providing ``add_scalar``, ``add_histogram``,
            ``add_image`` and ``add_audio``.
        global_step: Step value attached to every entry.
        scalars: Optional mapping of tag to scalar value.
        histograms: Optional mapping of tag to histogram data.
        images: Optional mapping of tag to image; logged with
            ``dataformats="HWC"``.
        audios: Optional mapping of tag to audio data.
        audio_sampling_rate: Sampling rate passed along with each audio.
    """
    # None replaces the former ``{}`` defaults so that no mutable object is
    # shared between calls; omitting an argument still logs nothing for it.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
161
+
162
+
163
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the matching checkpoint path with the largest number.

    Files in ``dir_path`` matching the glob ``regex`` are ordered by the
    integer formed from all digits in their full path; the last one is
    returned. Raises IndexError when nothing matches.
    """

    def digits_as_int(path):
        return int("".join(ch for ch in path if ch.isdigit()))

    candidates = sorted(glob.glob(os.path.join(dir_path, regex)), key=digits_as_int)
    return candidates[-1]
168
+
169
+
170
def plot_spectrogram_to_numpy(spectrogram):
    """Render ``spectrogram`` with matplotlib and return it as an RGB array.

    Args:
        spectrogram: 2-D array-like accepted by ``Axes.imshow``.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        # Select the off-screen backend once and quieten matplotlib logging.
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring's binary mode is deprecated; frombuffer reads the same
    # bytes. copy() keeps the result writable, as it was before.
    # NOTE(review): tostring_rgb is reportedly deprecated in recent
    # matplotlib releases -- confirm against the pinned version.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
    return data
194
+
195
+
196
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix with matplotlib and return it as RGB.

    Args:
        alignment: 2-D array providing ``transpose()``; plotted transposed.
        info: Optional text appended below the x-axis label.

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        # Select the off-screen backend once and quieten matplotlib logging.
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    # np.fromstring's binary mode is deprecated; frombuffer reads the same
    # bytes. copy() keeps the result writable, as it was before.
    # NOTE(review): tostring_rgb is reportedly deprecated in recent
    # matplotlib releases -- confirm against the pinned version.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
    return data
225
+
226
+
227
def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(samples, sampling_rate)``.

    The samples are converted to float32 without rescaling and wrapped in a
    ``torch.FloatTensor``.
    """
    sampling_rate, data = read(full_path)
    samples = data.astype(np.float32)
    return torch.FloatTensor(samples), sampling_rate
230
+
231
+
232
def load_filepaths_and_text(filename, split="|"):
    """Read a UTF-8 text file and split every line on ``split``.

    Each line is stripped of surrounding whitespace first. Returns a list
    with one list of fields per line.
    """
    filepaths_and_text = []
    with open(filename, encoding="utf-8") as f:
        for line in f:
            fields = line.strip().split(split)
            filepaths_and_text.append(fields)
    return filepaths_and_text
236
+
237
+
238
def get_hparams(init=True):
    """Parse command-line arguments and load the hyper-parameters.

    Command-line options: ``-c/--config`` (JSON config path, default
    ``./configs/base.json``) and ``-m/--model`` (required model name). The
    model directory is ``./logs/<model>`` and is created when missing.

    Args:
        init: When True, the config file is read and copied to
            ``<model_dir>/config.json``. When False, the previously saved
            copy is read instead.

    Returns:
        An ``HParams`` built from the JSON config, with ``model_dir`` set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="./configs/base.json",
        help="JSON file for configuration",
    )
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name")

    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)

    os.makedirs(model_dir, exist_ok=True)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        with open(config_path, "r", encoding="utf-8") as f:
            data = f.read()
        with open(config_save_path, "w", encoding="utf-8") as f:
            f.write(data)
    else:
        # The keyword was misspelled ``vencoding``, which made open() raise
        # TypeError whenever init was False.
        with open(config_save_path, "r", encoding="utf-8") as f:
            data = f.read()
    config = json.loads(data)
    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams
269
+
270
+
271
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
    """Freeing up space by deleting saved ckpts

    Arguments:
    path_to_models    --  Path to the model directory
    n_ckpts_to_keep   --  Number of ckpts to keep, excluding G_0.pth and D_0.pth;
                          0 (or less) deletes every eligible ckpt
    sort_by_time      --  True -> chronologically delete ckpts
                          False -> lexicographically delete ckpts
    """
    ckpts_files = [
        f
        for f in os.listdir(path_to_models)
        if os.path.isfile(os.path.join(path_to_models, f))
    ]

    # Compiled once instead of on every sort-key call.
    step_pattern = re.compile("._(\\d+)\\.pth")

    def name_key(_f):
        return int(step_pattern.match(_f).group(1))

    def time_key(_f):
        return os.path.getmtime(os.path.join(path_to_models, _f))

    sort_key = time_key if sort_by_time else name_key
    n_keep = max(n_ckpts_to_keep, 0)

    def stale(prefix):
        # Oldest-first list of deletable ckpts for one prefix, minus the
        # newest ``n_keep``. The count is computed explicitly because the
        # slice ``[:-0]`` is empty and would delete nothing.
        # NOTE(review): the "D" prefix also matches "DUR_*" files -- confirm
        # that grouping them together is intended.
        ordered = sorted(
            [
                f
                for f in ckpts_files
                if f.startswith(prefix) and not f.endswith("_0.pth")
            ],
            key=sort_key,
        )
        return ordered[: max(len(ordered) - n_keep, 0)]

    for fn in stale("G") + stale("D"):
        ckpt_path = os.path.join(path_to_models, fn)
        os.remove(ckpt_path)
        logger.info(f".. Free up space by deleting ckpt {ckpt_path}")
314
+
315
+
316
def get_hparams_from_dir(model_dir):
    """Load ``<model_dir>/config.json`` into an ``HParams``.

    The returned object also carries ``model_dir`` as an attribute.
    """
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r", encoding="utf-8") as config_file:
        config = json.loads(config_file.read())

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams
325
+
326
+
327
def get_hparams_from_file(config_path):
    """Load the JSON file at ``config_path`` into an ``HParams``."""
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = json.loads(config_file.read())

    return HParams(**config)
335
+
336
+
337
def check_git_hash(model_dir):
    """Compare the current git commit with the one stored in ``model_dir``.

    If the directory containing this file has no ``.git`` entry, a warning
    is logged and nothing else happens. Otherwise the hash reported by
    ``git rev-parse HEAD`` is compared with ``<model_dir>/githash``: a
    mismatch is logged as a warning, and a missing file is created with the
    current hash.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        # Context managers close the file handles deterministically.
        with open(path) as hash_file:
            saved_hash = hash_file.read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        with open(path, "w") as hash_file:
            hash_file.write(cur_hash)
360
+
361
+
362
def get_logger(model_dir, filename="train.log"):
    """Return a DEBUG-level logger that writes to ``<model_dir>/<filename>``.

    The logger is named after the last path component of ``model_dir`` and
    also replaces this module's global ``logger``. ``model_dir`` is created
    when missing.

    Args:
        model_dir: Directory that receives the log file.
        filename: Name of the log file inside ``model_dir``.

    Returns:
        The configured ``logging.Logger``.
    """
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    os.makedirs(model_dir, exist_ok=True)
    log_path = os.path.abspath(os.path.join(model_dir, filename))

    # logging.getLogger returns the same object for the same name, so a
    # repeated call used to stack another handler on the same file and
    # every record was written once per call.
    already_attached = any(
        isinstance(handler, logging.FileHandler)
        and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter(
            "%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s"
        )
        h = logging.FileHandler(log_path)
        h.setLevel(logging.DEBUG)
        h.setFormatter(formatter)
        logger.addHandler(h)
    return logger
375
+
376
+
377
class HParams:
    """Attribute-style container for (nested) hyper-parameters.

    Keyword arguments become attributes; dict values are converted to nested
    ``HParams`` recursively. Entries can be read and written both as
    attributes and with ``[]``, and the object supports ``len``, ``in``,
    ``keys()``, ``items()``, ``values()`` and ``**`` unpacking.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance also converts dict subclasses such as OrderedDict.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        """Return the parameter names."""
        return self.__dict__.keys()

    def items(self):
        """Return ``(name, value)`` pairs."""
        return self.__dict__.items()

    def values(self):
        """Return the parameter values."""
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
407
+
408
+
409
def load_model(model_path, config_path):
    """Build a ``SynthesizerTrn`` on CPU from a config and load its weights.

    Args:
        model_path: Checkpoint file passed to ``load_checkpoint``.
        config_path: JSON config read with ``get_hparams_from_file``.

    Returns:
        The network in eval mode with the checkpoint weights loaded.
    """
    # NOTE(review): SynthesizerTrn is not imported in this file's visible
    # import block -- confirm where it is expected to come from.
    hps = get_hparams_from_file(config_path)
    # Fixed vocabulary size; the original left a "len(symbols)" hint here.
    n_vocab = 108
    spec_channels = hps.data.filter_length // 2 + 1
    segment_frames = hps.train.segment_size // hps.data.hop_length
    net = SynthesizerTrn(
        n_vocab,
        spec_channels,
        segment_frames,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).to("cpu")
    net.eval()
    # No optimizer is involved, so its state is skipped.
    load_checkpoint(model_path, net, None, skip_optimizer=True)
    return net
422
+
423
+
424
def mix_model(
    network1, network2, output_path, voice_ratio=(0.5, 0.5), tone_ratio=(0.5, 0.5)
):
    """Blend the weights of two networks and save them as a checkpoint.

    Parameters present in both networks are combined as
    ``w1 * ratio[0] + w2 * ratio[1]``, using ``tone_ratio`` for keys that
    contain ``"enc_p"`` and ``voice_ratio`` for all others. Parameters found
    in only one network are kept as they are. The result is saved to
    ``output_path`` with iteration 0, no optimizer state and learning rate 0.
    """
    # Whether to unwrap ``.module`` is decided from network1 and applied to
    # both networks.
    wrapped = hasattr(network1, "module")
    state_dict1 = (network1.module if wrapped else network1).state_dict()
    state_dict2 = (network2.module if wrapped else network2).state_dict()

    for k in state_dict1.keys():
        if k not in state_dict2.keys():
            continue
        ratio = tone_ratio if "enc_p" in k else voice_ratio
        state_dict1[k] = (
            state_dict1[k].clone() * ratio[0] + state_dict2[k].clone() * ratio[1]
        )

    # Carry over parameters that only the second network has.
    for k, weight in state_dict2.items():
        if k not in state_dict1.keys():
            state_dict1[k] = weight.clone()

    mixed_checkpoint = {
        "model": state_dict1,
        "iteration": 0,
        "optimizer": None,
        "learning_rate": 0,
    }
    torch.save(mixed_checkpoint, output_path)
453
+
454
+
455
def get_steps(model_path):
    """Return the last run of digits in ``model_path`` as a string.

    Returns None when the path contains no digits.
    """
    last_match = None
    for last_match in re.finditer(r"\d+", model_path):
        pass
    if last_match is None:
        return None
    return last_match.group()