XzJosh committed on
Commit ae80214
1 Parent(s): 4bf797e

Upload 201 files

This view is limited to 50 files because it contains too many changes. See the raw diff.
Files changed (50)
  1. .gitattributes +12 -11
  2. .gitignore +1 -24
  3. .pre-commit-config.yaml +2 -2
  4. Data/Azusa/config.json +108 -0
  5. Data/Azusa/models/G_11300.pth +3 -0
  6. app.py +196 -177
  7. bert_gen.py +24 -17
  8. clap_gen.py +1 -1
  9. configs/config.json +771 -767
  10. data_utils.py +7 -23
  11. default_config.yml +3 -3
  12. export_onnx.py +6 -4
  13. for_deploy/infer.py +386 -0
  14. for_deploy/infer_utils.py +111 -0
  15. for_deploy/webui.py +556 -0
  16. infer.py +118 -88
  17. losses.py +95 -0
  18. models.py +66 -65
  19. oldVersion/V210/__init__.py +9 -4
  20. oldVersion/V210/models.py +1 -1
  21. oldVersion/V210/text/__init__.py +4 -2
  22. oldVersion/V210/text/chinese_bert.py +21 -2
  23. oldVersion/V210/text/english_bert_mock.py +21 -2
  24. oldVersion/V210/text/japanese_bert.py +23 -2
  25. onnx_infer.py +68 -0
  26. onnx_modules/V200/__init__.py +4 -0
  27. onnx_modules/V200_OnnxInference/__init__.py +126 -0
  28. onnx_modules/V210/__init__.py +4 -0
  29. onnx_modules/V210/models_onnx.py +1 -1
  30. onnx_modules/V210_OnnxInference/__init__.py +129 -0
  31. onnx_modules/V220/__init__.py +4 -0
  32. onnx_modules/V220/attentions_onnx.py +378 -0
  33. onnx_modules/V220/models_onnx.py +1076 -0
  34. onnx_modules/V220/text/__init__.py +1 -0
  35. onnx_modules/V220/text/symbols.py +187 -0
  36. onnx_modules/V220_OnnxInference/__init__.py +128 -0
  37. onnx_modules/V220_novq_dev/__init__.py +4 -0
  38. onnx_modules/V220_novq_dev/attentions_onnx.py +378 -0
  39. onnx_modules/V220_novq_dev/models_onnx.py +1048 -0
  40. onnx_modules/V220_novq_dev/text/__init__.py +1 -0
  41. onnx_modules/V220_novq_dev/text/symbols.py +187 -0
  42. onnx_modules/V230/__init__.py +4 -0
  43. onnx_modules/V230/attentions_onnx.py +378 -0
  44. onnx_modules/V230/models_onnx.py +1061 -0
  45. onnx_modules/V230/text/__init__.py +1 -0
  46. onnx_modules/V230/text/symbols.py +187 -0
  47. onnx_modules/V230_OnnxInference/__init__.py +126 -0
  48. onnx_modules/__init__.py +12 -4
  49. re_matching.py +0 -1
  50. requirements.txt +2 -3
.gitattributes CHANGED
@@ -1,35 +1,36 @@
  *.7z filter=lfs diff=lfs merge=lfs -text
  *.arrow filter=lfs diff=lfs merge=lfs -text
  *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
  *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
  *.model filter=lfs diff=lfs merge=lfs -text
  *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
  *.onnx filter=lfs diff=lfs merge=lfs -text
  *.ot filter=lfs diff=lfs merge=lfs -text
  *.parquet filter=lfs diff=lfs merge=lfs -text
  *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.db* filter=lfs diff=lfs merge=lfs -text
+ *.ark* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ oldVersion/V200/text/cmudict_cache.pickle filter=lfs diff=lfs merge=lfs -text
+ oldVersion/V210/text/cmudict_cache.pickle filter=lfs diff=lfs merge=lfs -text
+ text/cmudict_cache.pickle filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -159,27 +159,4 @@ cython_debug/
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
  #.idea/

- .DS_Store
- /models
- /logs
-
- filelists/*
- !/filelists/esd.list
- data/*
- /*.yml
- !/default_config.yml
- /Web/
- /emotional/*/*.bin
- /bert/*/*.bin
- /bert/*/*.h5
- /bert/*/*.model
- /bert/*/*.safetensors
- /bert/*/*.msgpack
- asr_transcript.py
- extract_list.py
- dataset
- /Data
- Model
- raw/
- logs/
- Data/*
+ .DS_Store
.pre-commit-config.yaml CHANGED
@@ -7,13 +7,13 @@ repos:
  - id: trailing-whitespace

  - repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.1.7
+ rev: v0.1.8
  hooks:
  - id: ruff
  args: [ --fix ]

  - repo: https://github.com/psf/black
- rev: 23.11.0
+ rev: 23.12.0
  hooks:
  - id: black
Data/Azusa/config.json ADDED
@@ -0,0 +1,108 @@
+ {
+   "train": {
+     "log_interval": 100,
+     "eval_interval": 100,
+     "seed": 42,
+     "epochs": 1000,
+     "learning_rate": 0.0001,
+     "betas": [
+       0.8,
+       0.99
+     ],
+     "eps": 1e-09,
+     "batch_size": 12,
+     "bf16_run": false,
+     "lr_decay": 0.99995,
+     "segment_size": 16384,
+     "init_lr_ratio": 1,
+     "warmup_epochs": 0,
+     "c_mel": 45,
+     "c_kl": 1.0,
+     "c_commit": 100,
+     "skip_optimizer": true,
+     "freeze_ZH_bert": false,
+     "freeze_JP_bert": false,
+     "freeze_EN_bert": false,
+     "freeze_emo": false
+   },
+   "data": {
+     "training_files": "Data/Azusa/filelists/train.list",
+     "validation_files": "Data/Azusa/filelists/val.list",
+     "max_wav_value": 32768.0,
+     "sampling_rate": 44100,
+     "filter_length": 2048,
+     "hop_length": 512,
+     "win_length": 2048,
+     "n_mel_channels": 128,
+     "mel_fmin": 0.0,
+     "mel_fmax": null,
+     "add_blank": true,
+     "n_speakers": 1,
+     "cleaned_text": true,
+     "spk2id": {
+       "Azusa": 0
+     }
+   },
+   "model": {
+     "use_spk_conditioned_encoder": true,
+     "use_noise_scaled_mas": true,
+     "use_mel_posterior_encoder": false,
+     "use_duration_discriminator": true,
+     "inter_channels": 192,
+     "hidden_channels": 192,
+     "filter_channels": 768,
+     "n_heads": 2,
+     "n_layers": 6,
+     "kernel_size": 3,
+     "p_dropout": 0.1,
+     "resblock": "1",
+     "resblock_kernel_sizes": [
+       3,
+       7,
+       11
+     ],
+     "resblock_dilation_sizes": [
+       [
+         1,
+         3,
+         5
+       ],
+       [
+         1,
+         3,
+         5
+       ],
+       [
+         1,
+         3,
+         5
+       ]
+     ],
+     "upsample_rates": [
+       8,
+       8,
+       2,
+       2,
+       2
+     ],
+     "upsample_initial_channel": 512,
+     "upsample_kernel_sizes": [
+       16,
+       16,
+       8,
+       2,
+       2
+     ],
+     "n_layers_q": 3,
+     "use_spectral_norm": false,
+     "gin_channels": 512,
+     "slm": {
+       "model": "./slm/wavlm-base-plus",
+       "sr": 16000,
+       "hidden": 768,
+       "nlayers": 13,
+       "initial_channel": 64
+     }
+   },
+   "version": "2.3"
+ }
Data/Azusa/models/G_11300.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0708043b54ab21eb8ec1b600982ea7b105bcded370a9207281e043c64e195dc3
+ size 728379830
app.py CHANGED
@@ -16,6 +16,10 @@ logging.basicConfig(
16
  logger = logging.getLogger(__name__)
17
 
18
  import torch
 
 
 
 
19
  import utils
20
  from infer import infer, latest_version, get_net_g, infer_multilang
21
  import gradio as gr
@@ -42,6 +46,8 @@ def generate_audio(
42
  language,
43
  reference_audio,
44
  emotion,
 
 
45
  skip_start=False,
46
  skip_end=False,
47
  ):
@@ -49,8 +55,8 @@ def generate_audio(
49
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
50
  with torch.no_grad():
51
  for idx, piece in enumerate(slices):
52
- skip_start = (idx != 0) and skip_start
53
- skip_end = (idx != len(slices) - 1) and skip_end
54
  audio = infer(
55
  piece,
56
  reference_audio=reference_audio,
@@ -66,10 +72,11 @@ def generate_audio(
66
  device=device,
67
  skip_start=skip_start,
68
  skip_end=skip_end,
 
 
69
  )
70
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
71
  audio_list.append(audio16bit)
72
- # audio_list.append(silence) # 将静音添加到列表中
73
  return audio_list
74
 
75
 
@@ -90,8 +97,8 @@ def generate_audio_multilang(
90
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
91
  with torch.no_grad():
92
  for idx, piece in enumerate(slices):
93
- skip_start = (idx != 0) and skip_start
94
- skip_end = (idx != len(slices) - 1) and skip_end
95
  audio = infer_multilang(
96
  piece,
97
  reference_audio=reference_audio,
@@ -110,7 +117,6 @@ def generate_audio_multilang(
110
  )
111
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
112
  audio_list.append(audio16bit)
113
- # audio_list.append(silence) # 将静音添加到列表中
114
  return audio_list
115
 
116
 
@@ -127,63 +133,50 @@ def tts_split(
127
  interval_between_sent,
128
  reference_audio,
129
  emotion,
 
 
130
  ):
131
- if language == "mix":
132
- return ("invalid", None)
133
  while text.find("\n\n") != -1:
134
  text = text.replace("\n\n", "\n")
 
135
  para_list = re_matching.cut_para(text)
 
136
  audio_list = []
137
- if not cut_by_sent:
138
- for idx, p in enumerate(para_list):
139
- skip_start = idx != 0
140
- skip_end = idx != len(para_list) - 1
141
- audio = infer(
142
  p,
143
- reference_audio=reference_audio,
144
- emotion=emotion,
145
- sdp_ratio=sdp_ratio,
146
- noise_scale=noise_scale,
147
- noise_scale_w=noise_scale_w,
148
- length_scale=length_scale,
149
- sid=speaker,
150
- language=language,
151
- hps=hps,
152
- net_g=net_g,
153
- device=device,
154
- skip_start=skip_start,
155
- skip_end=skip_end,
156
  )
157
- audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
158
- audio_list.append(audio16bit)
159
  silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
160
  audio_list.append(silence)
161
- else:
162
- for idx, p in enumerate(para_list):
163
- skip_start = idx != 0
164
- skip_end = idx != len(para_list) - 1
165
  audio_list_sent = []
166
  sent_list = re_matching.cut_sent(p)
167
- for idx, s in enumerate(sent_list):
168
- skip_start = (idx != 0) and skip_start
169
- skip_end = (idx != len(sent_list) - 1) and skip_end
170
- audio = infer(
171
  s,
172
- reference_audio=reference_audio,
173
- emotion=emotion,
174
- sdp_ratio=sdp_ratio,
175
- noise_scale=noise_scale,
176
- noise_scale_w=noise_scale_w,
177
- length_scale=length_scale,
178
- sid=speaker,
179
- language=language,
180
- hps=hps,
181
- net_g=net_g,
182
- device=device,
183
- skip_start=skip_start,
184
- skip_end=skip_end,
185
  )
186
- audio_list_sent.append(audio)
187
  silence = np.zeros((int)(44100 * interval_between_sent))
188
  audio_list_sent.append(silence)
189
  if (interval_between_para - interval_between_sent) > 0:
@@ -196,10 +189,47 @@ def tts_split(
196
  ) # 对完整句子做音量归一
197
  audio_list.append(audio16bit)
198
  audio_concat = np.concatenate(audio_list)
199
- return ("Success", (44100, audio_concat))
200
 
201
 
202
- def tts_fn(
 
 
 
 
 
 
 
 
 
 
 
203
  text: str,
204
  speaker,
205
  sdp_ratio,
@@ -209,15 +239,9 @@ def tts_fn(
209
  language,
210
  reference_audio,
211
  emotion,
212
- prompt_mode,
 
213
  ):
214
- if prompt_mode == "Audio prompt":
215
- if reference_audio == None:
216
- return ("Invalid audio prompt", None)
217
- else:
218
- reference_audio = load_audio(reference_audio)[1]
219
- else:
220
- reference_audio = None
221
  audio_list = []
222
  if language == "mix":
223
  bool_valid, str_valid = re_matching.validate_text(text)
@@ -226,120 +250,40 @@ def tts_fn(
226
  hps.data.sampling_rate,
227
  np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
228
  )
229
- result = []
230
  for slice in re_matching.text_matching(text):
231
- _speaker = slice.pop()
232
- temp_contant = []
233
- temp_lang = []
234
- for lang, content in slice:
235
- if "|" in content:
236
- temp = []
237
- temp_ = []
238
- for i in content.split("|"):
239
- if i != "":
240
- temp.append([i])
241
- temp_.append([lang])
242
- else:
243
- temp.append([])
244
- temp_.append([])
245
- temp_contant += temp
246
- temp_lang += temp_
247
- else:
248
- if len(temp_contant) == 0:
249
- temp_contant.append([])
250
- temp_lang.append([])
251
- temp_contant[-1].append(content)
252
- temp_lang[-1].append(lang)
253
- for i, j in zip(temp_lang, temp_contant):
254
- result.append([*zip(i, j), _speaker])
255
- for i, one in enumerate(result):
256
- skip_start = i != 0
257
- skip_end = i != len(result) - 1
258
- _speaker = one.pop()
259
- idx = 0
260
- while idx < len(one):
261
- text_to_generate = []
262
- lang_to_generate = []
263
- while True:
264
- lang, content = one[idx]
265
- temp_text = [content]
266
- if len(text_to_generate) > 0:
267
- text_to_generate[-1] += [temp_text.pop(0)]
268
- lang_to_generate[-1] += [lang]
269
- if len(temp_text) > 0:
270
- text_to_generate += [[i] for i in temp_text]
271
- lang_to_generate += [[lang]] * len(temp_text)
272
- if idx + 1 < len(one):
273
- idx += 1
274
- else:
275
- break
276
- skip_start = (idx != 0) and skip_start
277
- skip_end = (idx != len(one) - 1) and skip_end
278
- print(text_to_generate, lang_to_generate)
279
- audio_list.extend(
280
- generate_audio_multilang(
281
- text_to_generate,
282
- sdp_ratio,
283
- noise_scale,
284
- noise_scale_w,
285
- length_scale,
286
- _speaker,
287
- lang_to_generate,
288
- reference_audio,
289
- emotion,
290
- skip_start,
291
- skip_end,
292
- )
293
  )
294
- idx += 1
295
  elif language.lower() == "auto":
296
- for idx, slice in enumerate(text.split("|")):
297
- if slice == "":
298
- continue
299
- skip_start = idx != 0
300
- skip_end = idx != len(text.split("|")) - 1
301
- sentences_list = split_by_language(
302
- slice, target_languages=["zh", "ja", "en"]
 
 
 
 
 
 
303
  )
304
- idx = 0
305
- while idx < len(sentences_list):
306
- text_to_generate = []
307
- lang_to_generate = []
308
- while True:
309
- content, lang = sentences_list[idx]
310
- temp_text = [content]
311
- lang = lang.upper()
312
- if lang == "JA":
313
- lang = "JP"
314
- if len(text_to_generate) > 0:
315
- text_to_generate[-1] += [temp_text.pop(0)]
316
- lang_to_generate[-1] += [lang]
317
- if len(temp_text) > 0:
318
- text_to_generate += [[i] for i in temp_text]
319
- lang_to_generate += [[lang]] * len(temp_text)
320
- if idx + 1 < len(sentences_list):
321
- idx += 1
322
- else:
323
- break
324
- skip_start = (idx != 0) and skip_start
325
- skip_end = (idx != len(sentences_list) - 1) and skip_end
326
- print(text_to_generate, lang_to_generate)
327
- audio_list.extend(
328
- generate_audio_multilang(
329
- text_to_generate,
330
- sdp_ratio,
331
- noise_scale,
332
- noise_scale_w,
333
- length_scale,
334
- speaker,
335
- lang_to_generate,
336
- reference_audio,
337
- emotion,
338
- skip_start,
339
- skip_end,
340
- )
341
- )
342
- idx += 1
343
  else:
344
  audio_list.extend(
345
  generate_audio(
@@ -352,13 +296,65 @@ def tts_fn(
352
  language,
353
  reference_audio,
354
  emotion,
 
 
355
  )
356
  )
 
 
 
 
 
 
 
 
 
 
 
357
 
358
  audio_concat = np.concatenate(audio_list)
359
  return "Success", (hps.data.sampling_rate, audio_concat)
360
 
361
 
 
 
 
 
 
 
 
 
 
 
362
  def load_audio(path):
363
  audio, sr = librosa.load(path, 48000)
364
  # audio = librosa.resample(audio, 44100, 48000)
@@ -394,10 +390,10 @@ if __name__ == "__main__":
394
  with gr.Blocks() as app:
395
  with gr.Row():
396
  with gr.Column():
397
- gr.Markdown(value="""
398
- 【AI星瞳2.2】在线语音合成(Bert-Vits2 2.2中日英)\n
399
  作者:Xz乔希 https://space.bilibili.com/5859321\n
400
- 声音归属:星瞳_Official https://space.bilibili.com/401315430\n
401
  【AI合集】https://www.modelscope.cn/studios/xzjosh/Bert-VITS2\n
402
  Bert-VITS2项目:https://github.com/Stardust-minus/Bert-VITS2\n
403
  使用本模型请严格遵守法律法规!\n
@@ -414,27 +410,31 @@ if __name__ == "__main__":
414
  另外,所有的语言选项都可以用'|'分割长段实现分句生成。
415
  """,
416
  )
 
417
  speaker = gr.Dropdown(
418
  choices=speakers, value=speakers[0], label="Speaker"
419
  )
420
  _ = gr.Markdown(
421
- value="提示模式(Prompt mode):使用首字母大写英文单词或上传音频提示,用于生成指定风格的声音。\n"
 
422
  )
423
  prompt_mode = gr.Radio(
424
  ["Text prompt", "Audio prompt"],
425
  label="Prompt Mode",
426
  value="Text prompt",
 
427
  )
428
  text_prompt = gr.Textbox(
429
  label="Text prompt",
430
- placeholder="单词描述生成风格。如:Happy",
431
- visible=True,
 
432
  )
433
  audio_prompt = gr.Audio(
434
  label="Audio prompt", type="filepath", visible=False
435
  )
436
  sdp_ratio = gr.Slider(
437
- minimum=0, maximum=1, value=0.2, step=0.01, label="SDP Ratio"
438
  )
439
  noise_scale = gr.Slider(
440
  minimum=0.1, maximum=2, value=0.5, step=0.01, label="Noise"
@@ -450,6 +450,21 @@ if __name__ == "__main__":
450
  )
451
  btn = gr.Button("点击生成", variant="primary")
452
  with gr.Column():
 
 
 
 
 
 
 
 
453
  with gr.Row():
454
  with gr.Column():
455
  interval_between_sent = gr.Slider(
@@ -492,6 +507,8 @@ if __name__ == "__main__":
492
  audio_prompt,
493
  text_prompt,
494
  prompt_mode,
 
 
495
  ],
496
  outputs=[text_output, audio_output],
497
  )
@@ -510,6 +527,8 @@ if __name__ == "__main__":
510
  interval_between_sent,
511
  audio_prompt,
512
  text_prompt,
 
 
513
  ],
514
  outputs=[text_output, audio_output],
515
  )
 
16
  logger = logging.getLogger(__name__)
17
 
18
  import torch
19
+ import ssl
20
+ ssl._create_default_https_context = ssl._create_unverified_context
21
+ import nltk
22
+ nltk.download('cmudict')
23
  import utils
24
  from infer import infer, latest_version, get_net_g, infer_multilang
25
  import gradio as gr
 
46
  language,
47
  reference_audio,
48
  emotion,
49
+ style_text,
50
+ style_weight,
51
  skip_start=False,
52
  skip_end=False,
53
  ):
 
55
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
56
  with torch.no_grad():
57
  for idx, piece in enumerate(slices):
58
+ skip_start = idx != 0
59
+ skip_end = idx != len(slices) - 1
60
  audio = infer(
61
  piece,
62
  reference_audio=reference_audio,
 
72
  device=device,
73
  skip_start=skip_start,
74
  skip_end=skip_end,
75
+ style_text=style_text,
76
+ style_weight=style_weight,
77
  )
78
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
79
  audio_list.append(audio16bit)
 
80
  return audio_list
81
 
82
 
 
97
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
98
  with torch.no_grad():
99
  for idx, piece in enumerate(slices):
100
+ skip_start = idx != 0
101
+ skip_end = idx != len(slices) - 1
102
  audio = infer_multilang(
103
  piece,
104
  reference_audio=reference_audio,
 
117
  )
118
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
119
  audio_list.append(audio16bit)
 
120
  return audio_list
121
 
122
 
 
133
  interval_between_sent,
134
  reference_audio,
135
  emotion,
136
+ style_text,
137
+ style_weight,
138
  ):
 
 
139
  while text.find("\n\n") != -1:
140
  text = text.replace("\n\n", "\n")
141
+ text = text.replace("|", "")
142
  para_list = re_matching.cut_para(text)
143
+ para_list = [p for p in para_list if p != ""]
144
  audio_list = []
145
+ for p in para_list:
146
+ if not cut_by_sent:
147
+ audio_list += process_text(
 
 
148
  p,
149
+ speaker,
150
+ sdp_ratio,
151
+ noise_scale,
152
+ noise_scale_w,
153
+ length_scale,
154
+ language,
155
+ reference_audio,
156
+ emotion,
157
+ style_text,
158
+ style_weight,
 
 
 
159
  )
 
 
160
  silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
161
  audio_list.append(silence)
162
+ else:
 
 
 
163
  audio_list_sent = []
164
  sent_list = re_matching.cut_sent(p)
165
+ sent_list = [s for s in sent_list if s != ""]
166
+ for s in sent_list:
167
+ audio_list_sent += process_text(
 
168
  s,
169
+ speaker,
170
+ sdp_ratio,
171
+ noise_scale,
172
+ noise_scale_w,
173
+ length_scale,
174
+ language,
175
+ reference_audio,
176
+ emotion,
177
+ style_text,
178
+ style_weight,
 
 
 
179
  )
 
180
  silence = np.zeros((int)(44100 * interval_between_sent))
181
  audio_list_sent.append(silence)
182
  if (interval_between_para - interval_between_sent) > 0:
 
189
  ) # 对完整句子做音量归一
190
  audio_list.append(audio16bit)
191
  audio_concat = np.concatenate(audio_list)
192
+ return ("Success", (hps.data.sampling_rate, audio_concat))
193
 
194
 
195
+ def process_mix(slice):
196
+ _speaker = slice.pop()
197
+ _text, _lang = [], []
198
+ for lang, content in slice:
199
+ content = content.split("|")
200
+ content = [part for part in content if part != ""]
201
+ if len(content) == 0:
202
+ continue
203
+ if len(_text) == 0:
204
+ _text = [[part] for part in content]
205
+ _lang = [[lang] for part in content]
206
+ else:
207
+ _text[-1].append(content[0])
208
+ _lang[-1].append(lang)
209
+ if len(content) > 1:
210
+ _text += [[part] for part in content[1:]]
211
+ _lang += [[lang] for part in content[1:]]
212
+ return _text, _lang, _speaker
213
+
214
+
215
+ def process_auto(text):
216
+ _text, _lang = [], []
217
+ for slice in text.split("|"):
218
+ if slice == "":
219
+ continue
220
+ temp_text, temp_lang = [], []
221
+ sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"])
222
+ for sentence, lang in sentences_list:
223
+ if sentence == "":
224
+ continue
225
+ temp_text.append(sentence)
226
+ temp_lang.append(lang.upper())
227
+ _text.append(temp_text)
228
+ _lang.append(temp_lang)
229
+ return _text, _lang
230
+
231
+
232
+ def process_text(
233
  text: str,
234
  speaker,
235
  sdp_ratio,
 
239
  language,
240
  reference_audio,
241
  emotion,
242
+ style_text=None,
243
+ style_weight=0,
244
  ):
 
 
 
 
 
 
 
245
  audio_list = []
246
  if language == "mix":
247
  bool_valid, str_valid = re_matching.validate_text(text)
 
250
  hps.data.sampling_rate,
251
  np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
252
  )
 
253
  for slice in re_matching.text_matching(text):
254
+ _text, _lang, _speaker = process_mix(slice)
255
+ if _speaker is None:
256
+ continue
257
+ print(f"Text: {_text}\nLang: {_lang}")
258
+ audio_list.extend(
259
+ generate_audio_multilang(
260
+ _text,
261
+ sdp_ratio,
262
+ noise_scale,
263
+ noise_scale_w,
264
+ length_scale,
265
+ _speaker,
266
+ _lang,
267
+ reference_audio,
268
+ emotion,
 
 
 
 
 
 
 
 
 
 
 
 
269
  )
270
+ )
271
  elif language.lower() == "auto":
272
+ _text, _lang = process_auto(text)
273
+ print(f"Text: {_text}\nLang: {_lang}")
274
+ audio_list.extend(
275
+ generate_audio_multilang(
276
+ _text,
277
+ sdp_ratio,
278
+ noise_scale,
279
+ noise_scale_w,
280
+ length_scale,
281
+ speaker,
282
+ _lang,
283
+ reference_audio,
284
+ emotion,
285
  )
286
+ )
 
 
 
 
 
 
 
 
 
 
 
287
  else:
288
  audio_list.extend(
289
  generate_audio(
 
296
  language,
297
  reference_audio,
298
  emotion,
299
+ style_text,
300
+ style_weight,
301
  )
302
  )
303
+ return audio_list
304
+
305
+
306
+ def tts_fn(
307
+ text: str,
308
+ speaker,
309
+ sdp_ratio,
310
+ noise_scale,
311
+ noise_scale_w,
312
+ length_scale,
313
+ language,
314
+ reference_audio,
315
+ emotion,
316
+ prompt_mode,
317
+ style_text=None,
318
+ style_weight=0,
319
+ ):
320
+ if style_text == "":
321
+ style_text = None
322
+ if prompt_mode == "Audio prompt":
323
+ if reference_audio == None:
324
+ return ("Invalid audio prompt", None)
325
+ else:
326
+ reference_audio = load_audio(reference_audio)[1]
327
+ else:
328
+ reference_audio = None
329
+
330
+ audio_list = process_text(
331
+ text,
332
+ speaker,
333
+ sdp_ratio,
334
+ noise_scale,
335
+ noise_scale_w,
336
+ length_scale,
337
+ language,
338
+ reference_audio,
339
+ emotion,
340
+ style_text,
341
+ style_weight,
342
+ )
343
 
344
  audio_concat = np.concatenate(audio_list)
345
  return "Success", (hps.data.sampling_rate, audio_concat)
346
 
347
 
348
+ def format_utils(text, speaker):
349
+ _text, _lang = process_auto(text)
350
+ res = f"[{speaker}]"
351
+ for lang_s, content_s in zip(_lang, _text):
352
+ for lang, content in zip(lang_s, content_s):
353
+ res += f"<{lang.lower()}>{content}"
354
+ res += "|"
355
+ return "mix", res[:-1]
356
+
357
+
358
  def load_audio(path):
359
  audio, sr = librosa.load(path, 48000)
360
  # audio = librosa.resample(audio, 44100, 48000)
 
390
  with gr.Blocks() as app:
391
  with gr.Row():
392
  with gr.Column():
393
+ gr.Markdown(value="""
394
+ 【AI阿梓】在线语音合成(Bert-Vits2 2.3中日英)\n
395
  作者:Xz乔希 https://space.bilibili.com/5859321\n
396
+ 声音归属:阿梓从小就很可爱 https://space.bilibili.com/7706705\n
397
  【AI合集】https://www.modelscope.cn/studios/xzjosh/Bert-VITS2\n
398
  Bert-VITS2项目:https://github.com/Stardust-minus/Bert-VITS2\n
399
  使用本模型请严格遵守法律法规!\n
 
410
  另外,所有的语言选项都可以用'|'分割长段实现分句生成。
411
  """,
412
  )
413
+ formatter = gr.Button("检测语言,并整理为 MIX 格式", variant="primary")
414
  speaker = gr.Dropdown(
415
  choices=speakers, value=speakers[0], label="Speaker"
416
  )
417
  _ = gr.Markdown(
418
+ value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n",
419
+ visible=False,
420
  )
421
  prompt_mode = gr.Radio(
422
  ["Text prompt", "Audio prompt"],
423
  label="Prompt Mode",
424
  value="Text prompt",
425
+ visible=False,
426
  )
427
  text_prompt = gr.Textbox(
428
  label="Text prompt",
429
+ placeholder="用文字描述生成风格。如:Happy",
430
+ value="Happy",
431
+ visible=False,
432
  )
433
  audio_prompt = gr.Audio(
434
  label="Audio prompt", type="filepath", visible=False
435
  )
436
  sdp_ratio = gr.Slider(
437
+ minimum=0, maximum=1, value=0.5, step=0.01, label="SDP Ratio"
438
  )
439
  noise_scale = gr.Slider(
440
  minimum=0.1, maximum=2, value=0.5, step=0.01, label="Noise"
 
450
  )
451
  btn = gr.Button("点击生成", variant="primary")
452
  with gr.Column():
453
+ with gr.Accordion("融合文本语义(实验功能)", open=False):
454
+ gr.Markdown(
455
+ value="使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n"
456
+ "**注意**:请使用**带有强烈情感的文本**(如:我好快乐!)\n\n"
457
+ "效果较不明确,留空即为不使用该功能"
458
+ )
459
+ style_text = gr.Textbox(label="辅助文本")
460
+ style_weight = gr.Slider(
461
+ minimum=0,
462
+ maximum=1,
463
+ value=0.7,
464
+ step=0.1,
465
+ label="Weight",
466
+ info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
467
+ )
468
  with gr.Row():
469
  with gr.Column():
470
  interval_between_sent = gr.Slider(
 
507
  audio_prompt,
508
  text_prompt,
509
  prompt_mode,
510
+ style_text,
511
+ style_weight,
512
  ],
513
  outputs=[text_output, audio_output],
514
  )
 
527
  interval_between_sent,
528
  audio_prompt,
529
  text_prompt,
530
+ style_text,
531
+ style_weight,
532
  ],
533
  outputs=[text_output, audio_output],
534
  )
bert_gen.py CHANGED
@@ -1,17 +1,16 @@
1
- import argparse
2
- from multiprocessing import Pool, cpu_count
3
-
4
  import torch
5
- import torch.multiprocessing as mp
6
- from tqdm import tqdm
7
-
8
  import commons
9
  import utils
 
 
 
 
10
  from config import config
11
- from text import cleaned_text_to_sequence, get_bert
12
 
13
 
14
- def process_line(line):
 
15
  device = config.bert_gen_config.device
16
  if config.bert_gen_config.use_multi_device:
17
  rank = mp.current_process()._identity
@@ -28,12 +27,13 @@ def process_line(line):
28
  word2ph = [i for i in word2ph]
29
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
30
 
31
- phone = commons.intersperse(phone, 0)
32
- tone = commons.intersperse(tone, 0)
33
- language = commons.intersperse(language, 0)
34
- for i in range(len(word2ph)):
35
- word2ph[i] = word2ph[i] * 2
36
- word2ph[0] += 1
 
37
 
38
  bert_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".bert.pt")
39
 
@@ -59,16 +59,23 @@ if __name__ == "__main__":
59
  args, _ = parser.parse_known_args()
60
  config_path = args.config
61
  hps = utils.get_hparams_from_file(config_path)
 
62
  lines = []
63
  with open(hps.data.training_files, encoding="utf-8") as f:
64
  lines.extend(f.readlines())
65
 
66
  with open(hps.data.validation_files, encoding="utf-8") as f:
67
  lines.extend(f.readlines())
 
 
68
  if len(lines) != 0:
69
- num_processes = min(args.num_processes, cpu_count())
70
  with Pool(processes=num_processes) as pool:
71
- for _ in tqdm(pool.imap_unordered(process_line, lines), total=len(lines)):
72
- pass
 
 
 
 
73
 
74
  print(f"bert生成完毕!, 共有{len(lines)}个bert.pt生成!")
 
 
 
 
1
  import torch
2
+ from multiprocessing import Pool
 
 
3
  import commons
4
  import utils
5
+ from tqdm import tqdm
6
+ from text import check_bert_models, cleaned_text_to_sequence, get_bert
7
+ import argparse
8
+ import torch.multiprocessing as mp
9
  from config import config
 
10
 
11
 
12
+ def process_line(x):
13
+ line, add_blank = x
14
  device = config.bert_gen_config.device
15
  if config.bert_gen_config.use_multi_device:
16
  rank = mp.current_process()._identity
 
27
  word2ph = [i for i in word2ph]
28
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
29
 
30
+ if add_blank:
31
+ phone = commons.intersperse(phone, 0)
32
+ tone = commons.intersperse(tone, 0)
33
+ language = commons.intersperse(language, 0)
34
+ for i in range(len(word2ph)):
35
+ word2ph[i] = word2ph[i] * 2
36
+ word2ph[0] += 1
37
 
38
  bert_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".bert.pt")
39
 
 
59
  args, _ = parser.parse_known_args()
60
  config_path = args.config
61
  hps = utils.get_hparams_from_file(config_path)
62
+ check_bert_models()
63
  lines = []
64
  with open(hps.data.training_files, encoding="utf-8") as f:
65
  lines.extend(f.readlines())
66
 
67
  with open(hps.data.validation_files, encoding="utf-8") as f:
68
  lines.extend(f.readlines())
69
+ add_blank = [hps.data.add_blank] * len(lines)
70
+
71
  if len(lines) != 0:
72
+ num_processes = args.num_processes
73
  with Pool(processes=num_processes) as pool:
74
+ for _ in tqdm(
75
+ pool.imap_unordered(process_line, zip(lines, add_blank)),
76
+ total=len(lines),
77
+ ):
78
+ # 这里是缩进的代码块,表示循环体
79
+ pass # 使用pass语句作为占位符
80
 
81
  print(f"bert生成完毕!, 共有{len(lines)}个bert.pt生成!")
clap_gen.py CHANGED
@@ -27,7 +27,7 @@ def process_line(line):
  device = torch.device("cpu")
  wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|")

- clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.npy")
+ clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.pt")
  if os.path.isfile(clap_path):
      return
configs/config.json CHANGED
@@ -10,18 +10,20 @@
10
  0.99
11
  ],
12
  "eps": 1e-09,
13
- "batch_size": 12,
14
- "fp16_run": false,
15
  "lr_decay": 0.99995,
16
  "segment_size": 16384,
17
  "init_lr_ratio": 1,
18
  "warmup_epochs": 0,
19
  "c_mel": 45,
20
  "c_kl": 1.0,
 
21
  "skip_optimizer": true,
22
  "freeze_ZH_bert": false,
23
  "freeze_JP_bert": false,
24
- "freeze_EN_bert": false
 
25
  },
26
  "data": {
27
  "training_files": "filelists/train.list",
@@ -35,7 +37,7 @@
35
  "mel_fmin": 0.0,
36
  "mel_fmax": null,
37
  "add_blank": true,
38
- "n_speakers": 896,
39
  "cleaned_text": true,
40
  "spk2id": {
41
  "派蒙_ZH": 0,
@@ -119,203 +121,203 @@
119
  "伊迪娅_ZH": 78,
120
  "留云借风真君_ZH": 79,
121
  "绮良良_ZH": 80,
122
- "七七_ZH": 81,
123
- "式大将_ZH": 82,
124
- "瑶瑶_ZH": 83,
125
- "奥兹_ZH": 84,
126
- "菲米尼_ZH": 85,
127
- "米卡_ZH": 86,
128
- "哲平_ZH": 87,
129
- "大肉丸_ZH": 88,
130
- "托克_ZH": 89,
131
- "蒂玛乌斯_ZH": 90,
132
- "昆钧_ZH": 91,
133
- "欧菲妮_ZH": 92,
134
- "塞琉斯_ZH": 93,
135
- "仆人_ZH": 94,
136
- "迈勒斯_ZH": 95,
137
- "希格雯_ZH": 96,
138
- "阿守_ZH": 97,
139
- "拉赫曼_ZH": 98,
140
- "杜拉夫_ZH": 99,
141
- "伊利亚斯_ZH": 100,
142
- "阿晃_ZH": 101,
143
- "旁白_ZH": 102,
144
- "爱德琳_ZH": 103,
145
- "埃洛伊_ZH": 104,
146
- "德沃沙克_ZH": 105,
147
- "玛乔丽_ZH": 106,
148
- "塞塔蕾_ZH": 107,
149
- "柊千里_ZH": 108,
150
- "海芭夏_ZH": 109,
151
- "九条镰治_ZH": 110,
152
- "阿娜耶_ZH": 111,
153
- "笼钓瓶一心_ZH": 112,
154
- "回声海螺_ZH": 113,
155
- "劳维克_ZH": 114,
156
- "元太_ZH": 115,
157
- "阿扎尔_ZH": 116,
158
- "查尔斯_ZH": 117,
159
- "阿洛瓦_ZH": 118,
160
- "埃勒曼_ZH": 119,
161
- "纳比尔_ZH": 120,
162
- "莎拉_ZH": 121,
163
- "康纳_ZH": 122,
164
- "博来_ZH": 123,
165
- "玛塞勒_ZH": 124,
166
- "阿祇_ZH": 125,
167
- "博士_ZH": 126,
168
- "玛格丽特_ZH": 127,
169
- "迪尔菲_ZH": 128,
170
- "宛烟_ZH": 129,
171
- "羽生田千鹤_ZH": 130,
172
- "海妮耶_ZH": 131,
173
- "旅行者_ZH": 132,
174
- "霍夫曼_ZH": 133,
175
- "佐西摩斯_ZH": 134,
176
- "鹿野奈奈_ZH": 135,
177
- "舒伯特_ZH": 136,
178
- "天叔_ZH": 137,
179
- "艾莉丝_ZH": 138,
180
- "龙二_ZH": 139,
181
- "莺儿_ZH": 140,
182
- "嘉良_ZH": 141,
183
- "一心传名刀_ZH": 142,
184
- "费迪南德_ZH": 143,
185
- "珊瑚_ZH": 144,
186
- "言笑_ZH": 145,
187
- "久利须_ZH": 146,
188
- "嘉玛_ZH": 147,
189
- "艾文_ZH": 148,
190
- "克洛琳德_ZH": 149,
191
- "丹吉尔_ZH": 150,
192
- "女士_ZH": 151,
193
- "白老先生_ZH": 152,
194
- "天目十五_ZH": 153,
195
- "老孟_ZH": 154,
196
- "巴达维_ZH": 155,
197
- "长生_ZH": 156,
198
- "吴船长_ZH": 157,
199
- "拉齐_ZH": 158,
200
- "艾伯特_ZH": 159,
201
- "松浦_ZH": 160,
202
- "埃泽_ZH": 161,
203
- "阿圆_ZH": 162,
204
- "莫塞伊思_ZH": 163,
205
- "阿拉夫_ZH": 164,
206
- "杜吉耶_ZH": 165,
207
- "石头_ZH": 166,
208
- "百闻_ZH": 167,
209
- "波洛_ZH": 168,
210
- "斯坦利_ZH": 169,
211
- "博易_ZH": 170,
212
- "迈蒙_ZH": 171,
213
- "掇星攫辰天君_ZH": 172,
214
- "毗伽尔_ZH": 173,
215
- "芙卡洛斯_ZH": 174,
216
- "恶龙_ZH": 175,
217
- "恕筠_ZH": 176,
218
- "知易_ZH": 177,
219
- "克列门特_ZH": 178,
220
- "大慈树王_ZH": 179,
221
- "西拉杰_ZH": 180,
222
- "上杉_ZH": 181,
223
- "阿尔卡米_ZH": 182,
224
- "纯水精灵_ZH": 183,
225
- "常九爷_ZH": 184,
226
- "沙扎曼_ZH": 185,
227
- "田铁嘴_ZH": 186,
228
- "克罗索_ZH": 187,
229
- "阿巴图伊_ZH": 188,
230
- "_ZH": 189,
231
  "阿佩普_ZH": 190,
232
  "埃尔欣根_ZH": 191,
233
  "萨赫哈蒂_ZH": 192,
234
  "塔杰·拉德卡尼_ZH": 193,
235
  "安西_ZH": 194,
236
- "埃舍尔_ZH": 195,
237
- "萨齐因_ZH": 196,
238
- "派蒙_JP": 197,
239
- "纳西妲_JP": 198,
240
- "凯亚_JP": 199,
241
- "阿贝多_JP": 200,
242
- "温迪_JP": 201,
243
- "枫原万叶_JP": 202,
244
- "钟离_JP": 203,
245
- "荒泷一斗_JP": 204,
246
- "八重神子_JP": 205,
247
- "艾尔海森_JP": 206,
248
- "提纳里_JP": 207,
249
- "迪希雅_JP": 208,
250
- "卡维_JP": 209,
251
- "宵宫_JP": 210,
252
- "那维莱特_JP": 211,
253
- "莱依拉_JP": 212,
254
- "赛诺_JP": 213,
255
- "莫娜_JP": 214,
256
- "诺艾尔_JP": 215,
257
- "托马_JP": 216,
258
- "凝光_JP": 217,
259
- "林尼_JP": 218,
260
- "北斗_JP": 219,
261
- "柯莱_JP": 220,
262
- "神里绫华_JP": 221,
263
- "可莉_JP": 222,
264
- "芭芭拉_JP": 223,
265
- "雷电将军_JP": 224,
266
- "娜维娅_JP": 225,
267
- "芙宁娜_JP": 226,
268
- "珊瑚宫心海_JP": 227,
269
- "鹿野院平藏_JP": 228,
270
- "迪奥娜_JP": 229,
271
- "_JP": 230,
272
- "五郎_JP": 231,
273
- "班尼特_JP": 232,
274
- "达达利亚_JP": 233,
275
- "安柏_JP": 234,
276
- "莱欧斯利_JP": 235,
277
- "夜兰_JP": 236,
278
- "妮露_JP": 237,
279
- "辛焱_JP": 238,
280
- "丽莎_JP": 239,
281
- "珐露珊_JP": 240,
282
- "_JP": 241,
283
- "香菱_JP": 242,
284
- "迪卢克_JP": 243,
285
- "砂糖_JP": 244,
286
- "烟绯_JP": 245,
287
- "早柚_JP": 246,
288
- "云堇_JP": 247,
289
- "刻晴_JP": 248,
290
- "重云_JP": 249,
291
- "优菈_JP": 250,
292
- "胡桃_JP": 251,
293
- "流浪者_JP": 252,
294
- "久岐忍_JP": 253,
295
- "神里绫人_JP": 254,
296
- "甘雨_JP": 255,
297
- "戴因斯雷布_JP": 256,
298
- "菲谢尔_JP": 257,
299
- "白术_JP": 258,
300
- "行秋_JP": 259,
301
- "九条裟罗_JP": 260,
302
- "夏洛蒂_JP": 261,
303
- "雷泽_JP": 262,
304
- "申鹤_JP": 263,
305
- "_JP": 264,
306
- "_JP": 265,
307
- "迪娜泽黛_JP": 266,
308
- "凯瑟琳_JP": 267,
309
- "多莉_JP": 268,
310
- "坎蒂丝_JP": 269,
311
- "琳妮特_JP": 270,
312
- "萍姥姥_JP": 271,
313
- "罗莎莉亚_JP": 272,
314
- "埃德_JP": 273,
315
- "爱贝尔_JP": 274,
316
- "伊迪娅_JP": 275,
317
- "留云借风真君_JP": 276,
318
- "绮良良_JP": 277,
319
  "七七_JP": 278,
320
  "式大将_JP": 279,
321
  "瑶瑶_JP": 280,
@@ -323,576 +325,571 @@
323
  "菲米尼_JP": 282,
324
  "米卡_JP": 283,
325
  "哲平_JP": 284,
326
- "大肉丸_JP": 285,
327
- "托克_JP": 286,
328
- "蒂玛乌斯_JP": 287,
329
- "昆钧_JP": 288,
330
- "欧菲妮_JP": 289,
331
- "塞琉斯_JP": 290,
332
- "仆人_JP": 291,
333
- "迈勒斯_JP": 292,
334
- "希格雯_JP": 293,
335
- "阿守_JP": 294,
336
- "拉赫曼_JP": 295,
337
- "杜拉夫_JP": 296,
338
- "伊利亚斯_JP": 297,
339
- "阿晃_JP": 298,
340
- "旁白_JP": 299,
341
- "爱德琳_JP": 300,
342
- "埃洛伊_JP": 301,
343
- "德沃沙克_JP": 302,
344
- "玛乔丽_JP": 303,
345
- "塞塔蕾_JP": 304,
346
- "柊千里_JP": 305,
347
- "海芭夏_JP": 306,
348
- "九条镰治_JP": 307,
349
- "阿娜耶_JP": 308,
350
- "笼钓瓶一心_JP": 309,
351
- "回声海螺_JP": 310,
352
- "劳维克_JP": 311,
353
- "元太_JP": 312,
354
- "阿扎尔_JP": 313,
355
- "查尔斯_JP": 314,
356
- "阿洛瓦_JP": 315,
357
- "埃勒曼_JP": 316,
358
- "纳比尔_JP": 317,
359
- "莎拉_JP": 318,
360
- "康纳_JP": 319,
361
- "博来_JP": 320,
362
- "玛塞勒_JP": 321,
363
- "阿祇_JP": 322,
364
- "博士_JP": 323,
365
- "迪尔菲_JP": 324,
366
- "玛格丽特_JP": 325,
367
- "宛烟_JP": 326,
368
- "羽生田千鹤_JP": 327,
369
- "海妮耶_JP": 328,
370
- "霍夫曼_JP": 329,
371
- "旅行者_JP": 330,
372
- "佐西摩斯_JP": 331,
373
- "舒伯特_JP": 332,
374
- "鹿野奈奈_JP": 333,
375
- "天叔_JP": 334,
376
- "龙二_JP": 335,
377
- "艾莉丝_JP": 336,
378
- "莺儿_JP": 337,
379
- "嘉良_JP": 338,
380
- "珊瑚_JP": 339,
381
- "言笑_JP": 340,
382
- "一心传名刀_JP": 341,
383
- "费迪南德_JP": 342,
384
- "久利须_JP": 343,
385
- "嘉玛_JP": 344,
386
- "艾文_JP": 345,
387
- "克洛琳德_JP": 346,
388
- "丹吉尔_JP": 347,
389
- "天目十五_JP": 348,
390
- "女士_JP": 349,
391
- "老孟_JP": 350,
392
- "白老先生_JP": 351,
393
- "舍利夫_JP": 352,
394
- "巴达维_JP": 353,
395
- "拉齐_JP": 354,
396
- "长生_JP": 355,
397
- "吴船长_JP": 356,
398
- "艾伯特_JP": 357,
399
- "松浦_JP": 358,
400
- "埃泽_JP": 359,
401
- "阿圆_JP": 360,
402
- "阿拉夫_JP": 361,
403
- "莫塞伊思_JP": 362,
404
- "石头_JP": 363,
405
- "百闻_JP": 364,
406
- "杜吉耶_JP": 365,
407
- "波洛_JP": 366,
408
- "掇星攫辰天君_JP": 367,
409
- "迈蒙_JP": 368,
410
- "博易_JP": 369,
411
- "诗筠_JP": 370,
412
- "斯坦利_JP": 371,
413
- "毗伽尔_JP": 372,
414
- "芙卡洛斯_JP": 373,
415
- "恶龙_JP": 374,
416
- "小仓澪_JP": 375,
417
- "恕筠_JP": 376,
418
- "知易_JP": 377,
419
- "克列门特_JP": 378,
420
- "大慈树王_JP": 379,
421
- "望雅_JP": 380,
422
- "黑田_JP": 381,
423
- "卡莉娜_JP": 382,
424
- "马姆杜_JP": 383,
425
- "科林斯_JP": 384,
426
- "上杉_JP": 385,
427
- "西拉杰_JP": 386,
428
- "菲尔戈黛特_JP": 387,
429
- "一平_JP": 388,
430
- "纯水精灵_JP": 389,
431
- "阿尔卡米_JP": 390,
432
- "老戴_JP": 391,
433
- "谢赫祖拜尔_JP": 392,
434
- "沙扎曼_JP": 393,
435
- "田铁嘴_JP": 394,
436
- "小野寺_JP": 395,
437
- "百识_JP": 396,
438
- "克罗索_JP": 397,
439
- "莱斯格_JP": 398,
440
- "芷巧_JP": 399,
441
- "加藤洋平_JP": 400,
442
- "阿巴图伊_JP": 401,
443
- "埃尔欣根_JP": 402,
444
- "斯嘉莉_JP": 403,
445
- "阿佩普_JP": 404,
446
- "巫女_JP": 405,
447
- "卡布斯_JP": 406,
448
- "洛伦佐_JP": 407,
449
- "萨赫哈蒂_JP": 408,
450
- "娜德瓦_JP": 409,
451
- "塞德娜_JP": 410,
452
- "塔杰·拉德卡尼_JP": 411,
453
- "绘星_JP": 412,
454
- "泽田_JP": 413,
455
- "安西_JP": 414,
456
- "拉���德_JP": 415,
457
- "亚卡巴_JP": 416,
458
- "有乐斋_JP": 417,
459
- "莱昂_JP": 418,
460
- "尤苏波夫_JP": 419,
461
- "夏妮_JP": 420,
462
- "埃舍尔_JP": 421,
463
- "萨齐因_JP": 422,
464
- "古山_JP": 423,
465
- "自称渊上之物_JP": 424,
466
- "丹羽_JP": 425,
467
- "塞萨尔的日记_JP": 426,
468
- "派蒙_EN": 427,
469
- "纳西妲_EN": 428,
470
- "凯亚_EN": 429,
471
- "阿贝多_EN": 430,
472
- "温迪_EN": 431,
473
- "枫原万叶_EN": 432,
474
- "钟离_EN": 433,
475
- "荒泷一斗_EN": 434,
476
- "八重神子_EN": 435,
477
- "艾尔海森_EN": 436,
478
- "提纳里_EN": 437,
479
- "迪希雅_EN": 438,
480
- "卡维_EN": 439,
481
- "宵宫_EN": 440,
482
- "莱依拉_EN": 441,
483
- "那维莱特_EN": 442,
484
- "赛诺_EN": 443,
485
- "莫娜_EN": 444,
486
- "诺艾尔_EN": 445,
487
- "托马_EN": 446,
488
- "凝光_EN": 447,
489
- "林尼_EN": 448,
490
- "北斗_EN": 449,
491
- "柯莱_EN": 450,
492
- "神里绫华_EN": 451,
493
- "可莉_EN": 452,
494
- "芭芭拉_EN": 453,
495
- "雷电将军_EN": 454,
496
- "娜维娅_EN": 455,
497
- "芙宁娜_EN": 456,
498
- "珊瑚宫心海_EN": 457,
499
- "鹿野院平藏_EN": 458,
500
- "迪奥娜_EN": 459,
501
- "五郎_EN": 460,
502
- "_EN": 461,
503
- "班尼特_EN": 462,
504
- "达达利亚_EN": 463,
505
- "安柏_EN": 464,
506
- "莱欧斯利_EN": 465,
507
- "夜兰_EN": 466,
508
- "妮露_EN": 467,
509
- "辛焱_EN": 468,
510
- "珐露珊_EN": 469,
511
- "丽莎_EN": 470,
512
- "_EN": 471,
513
- "香菱_EN": 472,
514
- "迪卢克_EN": 473,
515
- "砂糖_EN": 474,
516
- "烟绯_EN": 475,
517
- "早柚_EN": 476,
518
- "云堇_EN": 477,
519
- "刻晴_EN": 478,
520
- "重云_EN": 479,
521
- "优菈_EN": 480,
522
- "胡桃_EN": 481,
523
- "流浪者_EN": 482,
524
- "久岐忍_EN": 483,
525
- "神里绫人_EN": 484,
526
- "甘雨_EN": 485,
527
- "戴因斯雷布_EN": 486,
528
- "菲谢尔_EN": 487,
529
- "白术_EN": 488,
530
- "行秋_EN": 489,
531
- "九条裟罗_EN": 490,
532
- "夏洛蒂_EN": 491,
533
- "雷泽_EN": 492,
534
- "申鹤_EN": 493,
535
- "_EN": 494,
536
- "_EN": 495,
537
- "迪娜泽黛_EN": 496,
538
- "凯瑟琳_EN": 497,
539
- "多莉_EN": 498,
540
- "坎蒂丝_EN": 499,
541
- "琳妮特_EN": 500,
542
- "萍姥姥_EN": 501,
543
- "罗莎莉亚_EN": 502,
544
- "埃德_EN": 503,
545
- "爱贝尔_EN": 504,
546
- "伊迪娅_EN": 505,
547
- "留云借风真君_EN": 506,
548
- "绮良良_EN": 507,
549
- "七七_EN": 508,
550
- "式大将_EN": 509,
551
- "瑶瑶_EN": 510,
552
- "奥兹_EN": 511,
553
- "菲米尼_EN": 512,
554
- "米卡_EN": 513,
555
- "哲平_EN": 514,
556
- "大肉丸_EN": 515,
557
- "托克_EN": 516,
558
- "蒂玛乌斯_EN": 517,
559
- "昆钧_EN": 518,
560
- "欧菲妮_EN": 519,
561
- "塞琉斯_EN": 520,
562
- "仆人_EN": 521,
563
- "迈勒斯_EN": 522,
564
- "希格雯_EN": 523,
565
- "阿守_EN": 524,
566
- "拉赫曼_EN": 525,
567
- "杜拉夫_EN": 526,
568
- "伊利亚斯_EN": 527,
569
- "阿晃_EN": 528,
570
- "旁白_EN": 529,
571
- "爱德琳_EN": 530,
572
- "埃洛伊_EN": 531,
573
- "德沃沙克_EN": 532,
574
- "玛乔丽_EN": 533,
575
- "塞塔蕾_EN": 534,
576
- "柊千里_EN": 535,
577
- "海芭夏_EN": 536,
578
- "九条镰治_EN": 537,
579
- "阿娜耶_EN": 538,
580
- "笼钓瓶一心_EN": 539,
581
- "回声海螺_EN": 540,
582
- "劳维克_EN": 541,
583
- "元太_EN": 542,
584
- "阿扎尔_EN": 543,
585
- "查尔斯_EN": 544,
586
- "阿洛瓦_EN": 545,
587
- "埃勒曼_EN": 546,
588
- "纳比尔_EN": 547,
589
- "莎拉_EN": 548,
590
- "康纳_EN": 549,
591
- "博来_EN": 550,
592
- "玛塞勒_EN": 551,
593
- "阿祇_EN": 552,
594
- "博士_EN": 553,
595
- "迪尔菲_EN": 554,
596
- "宛烟_EN": 555,
597
- "玛格丽特_EN": 556,
598
- "羽生田千鹤_EN": 557,
599
- "海妮耶_EN": 558,
600
- "霍夫曼_EN": 559,
601
- "旅行者_EN": 560,
602
- "佐西摩斯_EN": 561,
603
- "鹿野奈奈_EN": 562,
604
- "舒伯特_EN": 563,
605
- "天叔_EN": 564,
606
- "艾莉丝_EN": 565,
607
- "龙二_EN": 566,
608
- "莺儿_EN": 567,
609
- "嘉良_EN": 568,
610
- "珊瑚_EN": 569,
611
- "费迪南德_EN": 570,
612
- "言笑_EN": 571,
613
- "一心传名刀_EN": 572,
614
- "久利须_EN": 573,
615
- "嘉玛_EN": 574,
616
- "艾文_EN": 575,
617
- "克洛琳德_EN": 576,
618
- "丹吉尔_EN": 577,
619
- "女士_EN": 578,
620
- "天目十五_EN": 579,
621
- "老孟_EN": 580,
622
- "白老先生_EN": 581,
623
- "舍利夫_EN": 582,
624
- "巴达维_EN": 583,
625
- "拉齐_EN": 584,
626
- "长生_EN": 585,
627
- "吴船长_EN": 586,
628
- "艾伯特_EN": 587,
629
- "松浦_EN": 588,
630
- "埃泽_EN": 589,
631
- "阿圆_EN": 590,
632
- "阿拉夫_EN": 591,
633
- "莫塞伊思_EN": 592,
634
- "石头_EN": 593,
635
- "百闻_EN": 594,
636
- "杜吉耶_EN": 595,
637
- "波洛_EN": 596,
638
- "斯坦利_EN": 597,
639
- "掇星攫辰天君_EN": 598,
640
- "迈蒙_EN": 599,
641
- "博易_EN": 600,
642
- "诗筠_EN": 601,
643
- "毗伽尔_EN": 602,
644
- "慧心_EN": 603,
645
- "芙卡洛斯_EN": 604,
646
- "恶龙_EN": 605,
647
- "小仓澪_EN": 606,
648
- "恕筠_EN": 607,
649
- "知易_EN": 608,
650
- "克列门特_EN": 609,
651
- "大慈树王_EN": 610,
652
- "维多利亚_EN": 611,
653
- "黑田_EN": 612,
654
- "马姆杜_EN": 613,
655
- "科林斯_EN": 614,
656
- "上杉_EN": 615,
657
- "西拉杰_EN": 616,
658
- "宁禄_EN": 617,
659
- "纯水精灵_EN": 618,
660
- "常九爷_EN": 619,
661
- "阿尔卡米_EN": 620,
662
- "沙扎曼_EN": 621,
663
- "田铁嘴_EN": 622,
664
- "加萨尼_EN": 623,
665
- "克罗索_EN": 624,
666
- "星稀_EN": 625,
667
- "莱斯格_EN": 626,
668
- "阿巴图伊_EN": 627,
669
- "_EN": 628,
670
- "德田_EN": 629,
671
- "埃尔欣根_EN": 630,
672
- "阿佩普_EN": 631,
673
- "萨赫哈蒂_EN": 632,
674
- "洛伦佐_EN": 633,
675
- "塔杰·拉德卡尼_EN": 634,
676
- "泽田_EN": 635,
677
- "安西_EN": 636,
678
- "理水叠山真君_EN": 637,
679
  "埃舍尔_EN": 638,
680
- "萨齐因_EN": 639,
681
- "古田_EN": 640,
682
- "三月七_ZH": 641,
683
- "丹恒_ZH": 642,
684
- "希儿_ZH": 643,
685
- "娜塔莎_ZH": 644,
686
- "希露瓦_ZH": 645,
687
- "瓦尔特_ZH": 646,
688
- "佩拉_ZH": 647,
689
- "布洛妮娅_ZH": 648,
690
- "虎克_ZH": 649,
691
- "素裳_ZH": 650,
692
- "克拉拉_ZH": 651,
693
- "符玄_ZH": 652,
694
- "白露_ZH": 653,
695
- "杰帕德_ZH": 654,
696
- "景元_ZH": 655,
697
- "藿藿_ZH": 656,
698
- "姬子_ZH": 657,
699
- "_ZH": 658,
700
- "_ZH": 659,
701
- "卡芙卡_ZH": 660,
702
- "桂乃芬_ZH": 661,
703
- "艾丝妲_ZH": 662,
704
- "玲可_ZH": 663,
705
- "彦卿_ZH": 664,
706
- "托帕_ZH": 665,
707
- "驭空_ZH": 666,
708
- "浮烟_ZH": 667,
709
- "停云_ZH": 668,
710
- "镜流_ZH": 669,
711
- "罗刹_ZH": 670,
712
- "卢卡_ZH": 671,
713
- "史瓦罗_ZH": 672,
714
- "黑塔_ZH": 673,
715
- "桑博_ZH": 674,
716
- "伦纳德_ZH": 675,
717
- "明曦_ZH": 676,
718
- "银狼_ZH": 677,
719
- "帕姆_ZH": 678,
720
- "青雀_ZH": 679,
721
- "乔瓦尼_ZH": 680,
722
- "公输师傅_ZH": 681,
723
- "晴霓_ZH": 682,
724
- "螺丝咕姆_ZH": 683,
725
- "阿兰_ZH": 684,
726
- "奥列格_ZH": 685,
727
- "丹枢_ZH": 686,
728
- "尾巴_ZH": 687,
729
- "寒鸦_ZH": 688,
730
- "雪衣_ZH": 689,
731
- "可可利亚_ZH": 690,
732
- "青镞_ZH": 691,
733
- "半夏_ZH": 692,
734
- "银枝_ZH": 693,
735
- "大毫_ZH": 694,
736
- "霄翰_ZH": 695,
737
- "信使_ZH": 696,
738
- "费斯曼_ZH": 697,
739
- "绿芙蓉_ZH": 698,
740
- "dev_成男_ZH": 699,
741
- "金人会长_ZH": 700,
742
- "维利特_ZH": 701,
743
- "维尔德_ZH": 702,
744
- "斯科特_ZH": 703,
745
- "卡波特_ZH": 704,
746
- "刃_ZH": 705,
747
- "岩明_ZH": 706,
748
- "浣溪_ZH": 707,
749
- "三月七_JP": 708,
750
- "丹恒_JP": 709,
751
- "希儿_JP": 710,
752
- "娜塔莎_JP": 711,
753
- "希露瓦_JP": 712,
754
- "瓦尔特_JP": 713,
755
- "佩拉_JP": 714,
756
- "布洛妮娅_JP": 715,
757
- "虎克_JP": 716,
758
- "素裳_JP": 717,
759
- "克拉拉_JP": 718,
760
- "符玄_JP": 719,
761
- "白露_JP": 720,
762
- "杰帕德_JP": 721,
763
- "景元_JP": 722,
764
- "藿藿_JP": 723,
765
- "姬子_JP": 724,
766
- "卡芙卡_JP": 725,
767
- "_JP": 726,
768
- "_JP": 727,
769
- "桂乃芬_JP": 728,
770
- "艾丝妲_JP": 729,
771
- "彦卿_JP": 730,
772
- "玲可_JP": 731,
773
- "托帕_JP": 732,
774
- "驭空_JP": 733,
775
- "浮烟_JP": 734,
776
- "停云_JP": 735,
777
- "镜流_JP": 736,
778
- "罗刹_JP": 737,
779
- "卢卡_JP": 738,
780
- "史瓦罗_JP": 739,
781
- "黑塔_JP": 740,
782
- "桑博_JP": 741,
783
- "伦纳德_JP": 742,
784
- "明曦_JP": 743,
785
- "银狼_JP": 744,
786
- "帕姆_JP": 745,
787
- "青雀_JP": 746,
788
- "乔瓦尼_JP": 747,
789
- "公输师傅_JP": 748,
790
- "晴霓_JP": 749,
791
- "螺丝咕姆_JP": 750,
792
- "阿兰_JP": 751,
793
- "奥列格_JP": 752,
794
- "丹枢_JP": 753,
795
- "尾巴_JP": 754,
796
- "寒鸦_JP": 755,
797
- "雪衣_JP": 756,
798
- "可可利亚_JP": 757,
799
- "青镞_JP": 758,
800
- "半夏_JP": 759,
801
- "银枝_JP": 760,
802
- "大毫_JP": 761,
803
- "霄翰_JP": 762,
804
- "信使_JP": 763,
805
- "费斯曼_JP": 764,
806
- "绿芙蓉_JP": 765,
807
- "dev_成男_JP": 766,
808
- "金人会长_JP": 767,
809
- "维利特_JP": 768,
810
- "维尔德_JP": 769,
811
- "斯科特_JP": 770,
812
- "_JP": 771,
813
- "卡波特_JP": 772,
814
- "岩明_JP": 773,
815
- "浣溪_JP": 774,
816
- "净砚_JP": 775,
817
- "紫月季_JP": 776,
818
- "歌蒂_JP": 777,
819
- "奇怪的云骑_JP": 778,
820
- "幻胧_JP": 779,
821
- "斯薇塔_JP": 780,
822
- "隐书_JP": 781,
823
- "三月七_EN": 782,
824
- "丹恒_EN": 783,
825
- "希儿_EN": 784,
826
- "娜塔莎_EN": 785,
827
- "希露瓦_EN": 786,
828
- "瓦尔特_EN": 787,
829
- "佩拉_EN": 788,
830
- "布洛妮娅_EN": 789,
831
- "虎克_EN": 790,
832
- "素裳_EN": 791,
833
- "克拉拉_EN": 792,
834
- "符玄_EN": 793,
835
- "白露_EN": 794,
836
- "杰帕德_EN": 795,
837
- "景元_EN": 796,
838
- "藿藿_EN": 797,
839
- "姬子_EN": 798,
840
- "卡芙卡_EN": 799,
841
- "_EN": 800,
842
- "_EN": 801,
843
- "桂乃芬_EN": 802,
844
- "艾丝妲_EN": 803,
845
- "彦卿_EN": 804,
846
- "玲可_EN": 805,
847
- "托帕_EN": 806,
848
- "驭空_EN": 807,
849
- "浮烟_EN": 808,
850
- "停云_EN": 809,
851
- "镜流_EN": 810,
852
- "罗刹_EN": 811,
853
- "卢卡_EN": 812,
854
- "史瓦罗_EN": 813,
855
- "黑塔_EN": 814,
856
- "桑博_EN": 815,
857
- "伦纳德_EN": 816,
858
- "明曦_EN": 817,
859
- "银狼_EN": 818,
860
- "帕姆_EN": 819,
861
- "青雀_EN": 820,
862
- "乔瓦尼_EN": 821,
863
- "公输师傅_EN": 822,
864
- "晴霓_EN": 823,
865
- "螺丝咕姆_EN": 824,
866
- "阿兰_EN": 825,
867
- "奥列格_EN": 826,
868
- "丹枢_EN": 827,
869
- "尾巴_EN": 828,
870
- "寒鸦_EN": 829,
871
- "雪衣_EN": 830,
872
- "可可利亚_EN": 831,
873
- "青镞_EN": 832,
874
- "半夏_EN": 833,
875
- "银枝_EN": 834,
876
- "大毫_EN": 835,
877
- "霄翰_EN": 836,
878
- "信使_EN": 837,
879
- "费斯曼_EN": 838,
880
- "绿芙蓉_EN": 839,
881
- "dev_成男_EN": 840,
882
- "金人会长_EN": 841,
883
- "维利特_EN": 842,
884
- "维尔德_EN": 843,
885
- "_EN": 844,
886
- "卡波特_EN": 845,
887
- "岩明_EN": 846,
888
- "浣溪_EN": 847,
889
- "紫月季_EN": 848,
890
- "幻胧_EN": 849,
891
- "女声_EN": 850,
892
- "陆景和": 851,
893
- "莫弈": 852,
894
- "左然": 853,
895
- "夏彦": 854
896
  }
897
  },
898
  "model": {
@@ -947,7 +944,14 @@
947
  ],
948
  "n_layers_q": 3,
949
  "use_spectral_norm": false,
950
- "gin_channels": 256
 
 
 
 
 
 
 
951
  },
952
- "version": "2.2"
953
  }
 
10
  0.99
11
  ],
12
  "eps": 1e-09,
13
+ "batch_size": 16,
14
+ "bf16_run": false,
15
  "lr_decay": 0.99995,
16
  "segment_size": 16384,
17
  "init_lr_ratio": 1,
18
  "warmup_epochs": 0,
19
  "c_mel": 45,
20
  "c_kl": 1.0,
21
+ "c_commit": 100,
22
  "skip_optimizer": true,
23
  "freeze_ZH_bert": false,
24
  "freeze_JP_bert": false,
25
+ "freeze_EN_bert": false,
26
+ "freeze_emo": false
27
  },
28
  "data": {
29
  "training_files": "filelists/train.list",
 
37
  "mel_fmin": 0.0,
38
  "mel_fmax": null,
39
  "add_blank": true,
40
+ "n_speakers": 850,
41
  "cleaned_text": true,
42
  "spk2id": {
43
  "派蒙_ZH": 0,
 
121
  "伊迪娅_ZH": 78,
122
  "留云借风真君_ZH": 79,
123
  "绮良良_ZH": 80,
124
+ "陌生人_ZH": 81,
125
+ "七七_ZH": 82,
126
+ "式大将_ZH": 83,
127
+ "瑶瑶_ZH": 84,
128
+ "奥兹_ZH": 85,
129
+ "菲米尼_ZH": 86,
130
+ "米卡_ZH": 87,
131
+ "哲平_ZH": 88,
132
+ "浮游水蕈兽·元素生命_ZH": 89,
133
+ "大肉丸_ZH": 90,
134
+ "托克_ZH": 91,
135
+ "蒂玛乌斯_ZH": 92,
136
+ "昆钧_ZH": 93,
137
+ "欧菲妮_ZH": 94,
138
+ "塞琉斯_ZH": 95,
139
+ "仆人_ZH": 96,
140
+ "迈勒斯_ZH": 97,
141
+ "希格雯_ZH": 98,
142
+ "阿守_ZH": 99,
143
+ "拉赫曼_ZH": 100,
144
+ "杜拉夫_ZH": 101,
145
+ "伊利亚斯_ZH": 102,
146
+ "阿晃_ZH": 103,
147
+ "旁白_ZH": 104,
148
+ "爱德琳_ZH": 105,
149
+ "埃洛伊_ZH": 106,
150
+ "德沃沙克_ZH": 107,
151
+ "玛乔丽_ZH": 108,
152
+ "塞塔蕾_ZH": 109,
153
+ "柊千里_ZH": 110,
154
+ "海芭夏_ZH": 111,
155
+ "九条镰治_ZH": 112,
156
+ "阿娜耶_ZH": 113,
157
+ "笼钓瓶一心_ZH": 114,
158
+ "回声海螺_ZH": 115,
159
+ "劳维克_ZH": 116,
160
+ "元太_ZH": 117,
161
+ "阿扎尔_ZH": 118,
162
+ "查尔斯_ZH": 119,
163
+ "阿洛瓦_ZH": 120,
164
+ "埃勒曼_ZH": 121,
165
+ "纳比尔_ZH": 122,
166
+ "莎拉_ZH": 123,
167
+ "康纳_ZH": 124,
168
+ "博来_ZH": 125,
169
+ "玛塞勒_ZH": 126,
170
+ "阿祇_ZH": 127,
171
+ "博士_ZH": 128,
172
+ "玛格丽特_ZH": 129,
173
+ "迪尔菲_ZH": 130,
174
+ "宛烟_ZH": 131,
175
+ "羽生田千鹤_ZH": 132,
176
+ "海妮耶_ZH": 133,
177
+ "旅行者_ZH": 134,
178
+ "霍夫曼_ZH": 135,
179
+ "佐西摩斯_ZH": 136,
180
+ "鹿野奈奈_ZH": 137,
181
+ "舒伯特_ZH": 138,
182
+ "天叔_ZH": 139,
183
+ "艾莉丝_ZH": 140,
184
+ "龙二_ZH": 141,
185
+ "莺儿_ZH": 142,
186
+ "嘉良_ZH": 143,
187
+ "一心传名刀_ZH": 144,
188
+ "珊瑚_ZH": 145,
189
+ "言笑_ZH": 146,
190
+ "久利须_ZH": 147,
191
+ "嘉玛_ZH": 148,
192
+ "艾文_ZH": 149,
193
+ "克洛琳德_ZH": 150,
194
+ "丹吉尔_ZH": 151,
195
+ "女士_ZH": 152,
196
+ "白老先生_ZH": 153,
197
+ "天目十五_ZH": 154,
198
+ "老孟_ZH": 155,
199
+ "巴达维_ZH": 156,
200
+ "长生_ZH": 157,
201
+ "吴船长_ZH": 158,
202
+ "拉齐_ZH": 159,
203
+ "艾伯特_ZH": 160,
204
+ "松浦_ZH": 161,
205
+ "埃泽_ZH": 162,
206
+ "阿圆_ZH": 163,
207
+ "莫塞伊思_ZH": 164,
208
+ "阿拉夫_ZH": 165,
209
+ "杜吉耶_ZH": 166,
210
+ "石头_ZH": 167,
211
+ "百闻_ZH": 168,
212
+ "波洛_ZH": 169,
213
+ "斯坦利_ZH": 170,
214
+ "博易_ZH": 171,
215
+ "迈蒙_ZH": 172,
216
+ "掇星攫辰天君_ZH": 173,
217
+ "毗伽尔_ZH": 174,
218
+ "芙卡洛斯_ZH": 175,
219
+ "恶龙_ZH": 176,
220
+ "恕筠_ZH": 177,
221
+ "知易_ZH": 178,
222
+ "克列门特_ZH": 179,
223
+ "大慈树王_ZH": 180,
224
+ "西拉杰_ZH": 181,
225
+ "上杉_ZH": 182,
226
+ "阿尔卡米_ZH": 183,
227
+ "纯水精灵_ZH": 184,
228
+ "常九爷_ZH": 185,
229
+ "沙扎曼_ZH": 186,
230
+ "田铁嘴_ZH": 187,
231
+ "克罗索_ZH": 188,
232
+ "阿巴图伊_ZH": 189,
233
  "阿佩普_ZH": 190,
234
  "埃尔欣根_ZH": 191,
235
  "萨赫哈蒂_ZH": 192,
236
  "塔杰·拉德卡尼_ZH": 193,
237
  "安西_ZH": 194,
238
+ "陆行岩本真蕈·元素生命_ZH": 195,
239
+ "派蒙_JP": 196,
240
+ "纳西妲_JP": 197,
241
+ "凯亚_JP": 198,
242
+ "阿贝多_JP": 199,
243
+ "温迪_JP": 200,
244
+ "枫原万叶_JP": 201,
245
+ "钟离_JP": 202,
246
+ "荒泷一斗_JP": 203,
247
+ "八重神子_JP": 204,
248
+ "艾尔海森_JP": 205,
249
+ "提纳里_JP": 206,
250
+ "迪希雅_JP": 207,
251
+ "卡维_JP": 208,
252
+ "宵宫_JP": 209,
253
+ "那维莱特_JP": 210,
254
+ "莱依拉_JP": 211,
255
+ "赛诺_JP": 212,
256
+ "莫娜_JP": 213,
257
+ "诺艾尔_JP": 214,
258
+ "托马_JP": 215,
259
+ "凝光_JP": 216,
260
+ "林尼_JP": 217,
261
+ "北斗_JP": 218,
262
+ "柯莱_JP": 219,
263
+ "神里绫华_JP": 220,
264
+ "可莉_JP": 221,
265
+ "芭芭拉_JP": 222,
266
+ "雷电将军_JP": 223,
267
+ "娜维娅_JP": 224,
268
+ "芙宁娜_JP": 225,
269
+ "珊瑚宫心海_JP": 226,
270
+ "鹿野院平藏_JP": 227,
271
+ "迪奥娜_JP": 228,
272
+ "_JP": 229,
273
+ "五郎_JP": 230,
274
+ "班尼特_JP": 231,
275
+ "达达利亚_JP": 232,
276
+ "安柏_JP": 233,
277
+ "莱欧斯利_JP": 234,
278
+ "夜兰_JP": 235,
279
+ "妮露_JP": 236,
280
+ "辛焱_JP": 237,
281
+ "丽莎_JP": 238,
282
+ "珐露珊_JP": 239,
283
+ "_JP": 240,
284
+ "香菱_JP": 241,
285
+ "迪卢克_JP": 242,
286
+ "砂糖_JP": 243,
287
+ "烟绯_JP": 244,
288
+ "早柚_JP": 245,
289
+ "云堇_JP": 246,
290
+ "刻晴_JP": 247,
291
+ "重云_JP": 248,
292
+ "优菈_JP": 249,
293
+ "胡桃_JP": 250,
294
+ "流浪者_JP": 251,
295
+ "久岐忍_JP": 252,
296
+ "神里绫人_JP": 253,
297
+ "甘雨_JP": 254,
298
+ "戴因斯雷布_JP": 255,
299
+ "菲谢尔_JP": 256,
300
+ "白术_JP": 257,
301
+ "行秋_JP": 258,
302
+ "九条裟罗_JP": 259,
303
+ "夏洛蒂_JP": 260,
304
+ "雷泽_JP": 261,
305
+ "申鹤_JP": 262,
306
+ "_JP": 263,
307
+ "_JP": 264,
308
+ "迪娜泽黛_JP": 265,
309
+ "凯瑟琳_JP": 266,
310
+ "多莉_JP": 267,
311
+ "坎蒂丝_JP": 268,
312
+ "琳妮特_JP": 269,
313
+ "萍姥姥_JP": 270,
314
+ "罗莎莉亚_JP": 271,
315
+ "埃德_JP": 272,
316
+ "爱贝尔_JP": 273,
317
+ "伊迪娅_JP": 274,
318
+ "留云借风真君_JP": 275,
319
+ "绮良良_JP": 276,
320
+ "陌生人_JP": 277,
321
  "七七_JP": 278,
322
  "式大将_JP": 279,
323
  "瑶瑶_JP": 280,
 
325
  "菲米尼_JP": 282,
326
  "米卡_JP": 283,
327
  "哲平_JP": 284,
328
+ "浮游水蕈兽·元素生命_JP": 285,
329
+ "大肉丸_JP": 286,
330
+ "托克_JP": 287,
331
+ "蒂玛乌斯_JP": 288,
332
+ "昆钧_JP": 289,
333
+ "欧菲妮_JP": 290,
334
+ "塞琉斯_JP": 291,
335
+ "仆人_JP": 292,
336
+ "迈勒斯_JP": 293,
337
+ "希格雯_JP": 294,
338
+ "阿守_JP": 295,
339
+ "拉赫曼_JP": 296,
340
+ "杜拉夫_JP": 297,
341
+ "伊利亚斯_JP": 298,
342
+ "阿晃_JP": 299,
343
+ "旁白_JP": 300,
344
+ "爱德琳_JP": 301,
345
+ "埃洛伊_JP": 302,
346
+ "德沃沙克_JP": 303,
347
+ "玛乔丽_JP": 304,
348
+ "塞塔蕾_JP": 305,
349
+ "柊千里_JP": 306,
350
+ "海芭夏_JP": 307,
351
+ "九条镰治_JP": 308,
352
+ "阿娜耶_JP": 309,
353
+ "笼钓瓶一心_JP": 310,
354
+ "回声海螺_JP": 311,
355
+ "劳维克_JP": 312,
356
+ "元太_JP": 313,
357
+ "阿扎尔_JP": 314,
358
+ "查尔斯_JP": 315,
359
+ "阿洛瓦_JP": 316,
360
+ "埃勒曼_JP": 317,
361
+ "纳比尔_JP": 318,
362
+ "莎拉_JP": 319,
363
+ "康纳_JP": 320,
364
+ "博来_JP": 321,
365
+ "玛塞勒_JP": 322,
366
+ "阿祇_JP": 323,
367
+ "博士_JP": 324,
368
+ "迪尔菲_JP": 325,
369
+ "玛格丽特_JP": 326,
370
+ "宛烟_JP": 327,
371
+ "羽生田千鹤_JP": 328,
372
+ "海妮耶_JP": 329,
373
+ "霍夫曼_JP": 330,
374
+ "旅行者_JP": 331,
375
+ "佐西摩斯_JP": 332,
376
+ "舒伯特_JP": 333,
377
+ "鹿野奈奈_JP": 334,
378
+ "天叔_JP": 335,
379
+ "龙二_JP": 336,
380
+ "艾莉丝_JP": 337,
381
+ "莺儿_JP": 338,
382
+ "嘉良_JP": 339,
383
+ "珊瑚_JP": 340,
384
+ "言笑_JP": 341,
385
+ "一心传名刀_JP": 342,
386
+ "费迪南德_JP": 343,
387
+ "久利须_JP": 344,
388
+ "嘉玛_JP": 345,
389
+ "艾文_JP": 346,
390
+ "克洛琳德_JP": 347,
391
+ "丹吉尔_JP": 348,
392
+ "天目十五_JP": 349,
393
+ "女士_JP": 350,
394
+ "老孟_JP": 351,
395
+ "白老先生_JP": 352,
396
+ "舍利夫_JP": 353,
397
+ "巴达维_JP": 354,
398
+ "拉齐_JP": 355,
399
+ "长生_JP": 356,
400
+ "吴船长_JP": 357,
401
+ "艾伯特_JP": 358,
402
+ "松浦_JP": 359,
403
+ "埃泽_JP": 360,
404
+ "阿圆_JP": 361,
405
+ "阿拉夫_JP": 362,
406
+ "莫塞伊思_JP": 363,
407
+ "石头_JP": 364,
408
+ "百闻_JP": 365,
409
+ "杜吉耶_JP": 366,
410
+ "波洛_JP": 367,
411
+ "掇星攫辰天君_JP": 368,
412
+ "迈蒙_JP": 369,
413
+ "博易_JP": 370,
414
+ "诗筠_JP": 371,
415
+ "斯坦利_JP": 372,
416
+ "毗伽尔_JP": 373,
417
+ "芙卡洛斯_JP": 374,
418
+ "恶龙_JP": 375,
419
+ "小仓澪_JP": 376,
420
+ "恕筠_JP": 377,
421
+ "知易_JP": 378,
422
+ "克列门特_JP": 379,
423
+ "大慈树王_JP": 380,
424
+ "望雅_JP": 381,
425
+ "黑田_JP": 382,
426
+ "卡莉娜_JP": 383,
427
+ "马姆杜_JP": 384,
428
+ "科林斯_JP": 385,
429
+ "上杉_JP": 386,
430
+ "西拉杰_JP": 387,
431
+ "菲尔戈黛特_JP": 388,
432
+ "一平_JP": 389,
433
+ "纯水精灵_JP": 390,
434
+ "阿尔卡米_JP": 391,
435
+ "老戴_JP": 392,
436
+ "谢赫祖拜尔_JP": 393,
437
+ "沙扎曼_JP": 394,
438
+ "田铁嘴_JP": 395,
439
+ "小野寺_JP": 396,
440
+ "百识_JP": 397,
441
+ "克罗索_JP": 398,
442
+ "莱斯格_JP": 399,
443
+ "芷巧_JP": 400,
444
+ "加藤洋平_JP": 401,
445
+ "阿巴图伊_JP": 402,
446
+ "埃尔欣根_JP": 403,
447
+ "斯嘉莉_JP": 404,
448
+ "阿佩普_JP": 405,
449
+ "巫女_JP": 406,
450
+ "卡布斯_JP": 407,
451
+ "洛伦佐_JP": 408,
452
+ "萨赫哈蒂_JP": 409,
453
+ "娜德瓦_JP": 410,
454
+ "塞德娜_JP": 411,
455
+ "塔杰·拉德卡尼_JP": 412,
456
+ "绘星_JP": 413,
457
+ "泽田_JP": 414,
458
+ "安西_JP": 415,
459
+ "拉伊德_JP": 416,
460
+ "亚卡巴_JP": 417,
461
+ "有乐斋_JP": 418,
462
+ "莱昂_JP": 419,
463
+ "尤苏波夫_JP": 420,
464
+ "夏妮_JP": 421,
465
+ "埃舍尔_JP": 422,
466
+ "萨齐因_JP": 423,
467
+ "古山_JP": 424,
468
+ "自称渊上之物_JP": 425,
469
+ "丹羽_JP": 426,
470
+ "塞萨尔的日记_JP": 427,
471
+ "派蒙_EN": 428,
472
+ "纳西妲_EN": 429,
473
+ "凯亚_EN": 430,
474
+ "阿贝多_EN": 431,
475
+ "温迪_EN": 432,
476
+ "枫原万叶_EN": 433,
477
+ "钟离_EN": 434,
478
+ "荒泷一斗_EN": 435,
479
+ "八重神子_EN": 436,
480
+ "艾尔海森_EN": 437,
481
+ "提纳里_EN": 438,
482
+ "迪希雅_EN": 439,
483
+ "卡维_EN": 440,
484
+ "宵宫_EN": 441,
485
+ "莱依拉_EN": 442,
486
+ "那维莱特_EN": 443,
487
+ "赛诺_EN": 444,
488
+ "莫娜_EN": 445,
489
+ "诺艾尔_EN": 446,
490
+ "托马_EN": 447,
491
+ "凝光_EN": 448,
492
+ "林尼_EN": 449,
493
+ "北斗_EN": 450,
494
+ "柯莱_EN": 451,
495
+ "神里绫华_EN": 452,
496
+ "可莉_EN": 453,
497
+ "芭芭拉_EN": 454,
498
+ "雷电将军_EN": 455,
499
+ "娜维娅_EN": 456,
500
+ "芙宁娜_EN": 457,
501
+ "珊瑚宫心海_EN": 458,
502
+ "鹿野院平藏_EN": 459,
503
+ "迪奥娜_EN": 460,
504
+ "五郎_EN": 461,
505
+ "_EN": 462,
506
+ "班尼特_EN": 463,
507
+ "达达利亚_EN": 464,
508
+ "安柏_EN": 465,
509
+ "莱欧斯利_EN": 466,
510
+ "夜兰_EN": 467,
511
+ "妮露_EN": 468,
512
+ "辛焱_EN": 469,
513
+ "珐露珊_EN": 470,
514
+ "丽莎_EN": 471,
515
+ "_EN": 472,
516
+ "香菱_EN": 473,
517
+ "迪卢克_EN": 474,
518
+ "砂糖_EN": 475,
519
+ "烟绯_EN": 476,
520
+ "早柚_EN": 477,
521
+ "云堇_EN": 478,
522
+ "刻晴_EN": 479,
523
+ "重云_EN": 480,
524
+ "优菈_EN": 481,
525
+ "胡桃_EN": 482,
526
+ "流浪者_EN": 483,
527
+ "久岐忍_EN": 484,
528
+ "神里绫人_EN": 485,
529
+ "甘雨_EN": 486,
530
+ "戴因斯雷布_EN": 487,
531
+ "菲谢尔_EN": 488,
532
+ "白术_EN": 489,
533
+ "行秋_EN": 490,
534
+ "九条裟罗_EN": 491,
535
+ "夏洛蒂_EN": 492,
536
+ "雷泽_EN": 493,
537
+ "申鹤_EN": 494,
538
+ "_EN": 495,
539
+ "_EN": 496,
540
+ "迪娜泽黛_EN": 497,
541
+ "凯瑟琳_EN": 498,
542
+ "多莉_EN": 499,
543
+ "坎蒂丝_EN": 500,
544
+ "琳妮特_EN": 501,
545
+ "萍姥姥_EN": 502,
546
+ "罗莎莉亚_EN": 503,
547
+ "埃德_EN": 504,
548
+ "爱贝尔_EN": 505,
549
+ "伊迪娅_EN": 506,
550
+ "留云借风真君_EN": 507,
551
+ "绮良良_EN": 508,
552
+ "陌生人_EN": 509,
553
+ "七七_EN": 510,
554
+ "式大将_EN": 511,
555
+ "瑶瑶_EN": 512,
556
+ "奥兹_EN": 513,
557
+ "菲米尼_EN": 514,
558
+ "米卡_EN": 515,
559
+ "哲平_EN": 516,
560
+ "浮游水蕈兽·元素生命_EN": 517,
561
+ "大肉丸_EN": 518,
562
+ "托克_EN": 519,
563
+ "蒂玛乌斯_EN": 520,
564
+ "昆钧_EN": 521,
565
+ "欧菲妮_EN": 522,
566
+ "塞琉斯_EN": 523,
567
+ "仆人_EN": 524,
568
+ "迈勒斯_EN": 525,
569
+ "希格雯_EN": 526,
570
+ "阿守_EN": 527,
571
+ "拉赫曼_EN": 528,
572
+ "杜拉夫_EN": 529,
573
+ "伊利亚斯_EN": 530,
574
+ "阿晃_EN": 531,
575
+ "旁白_EN": 532,
576
+ "爱德琳_EN": 533,
577
+ "埃洛伊_EN": 534,
578
+ "德沃沙克_EN": 535,
579
+ "玛乔丽_EN": 536,
580
+ "塞塔蕾_EN": 537,
581
+ "柊千里_EN": 538,
582
+ "海芭夏_EN": 539,
583
+ "九条镰治_EN": 540,
584
+ "阿娜耶_EN": 541,
585
+ "笼钓瓶一心_EN": 542,
586
+ "回声海螺_EN": 543,
587
+ "劳维克_EN": 544,
588
+ "元太_EN": 545,
589
+ "阿扎尔_EN": 546,
590
+ "查尔斯_EN": 547,
591
+ "阿洛瓦_EN": 548,
592
+ "埃勒曼_EN": 549,
593
+ "纳比尔_EN": 550,
594
+ "莎拉_EN": 551,
595
+ "康纳_EN": 552,
596
+ "博来_EN": 553,
597
+ "玛塞勒_EN": 554,
598
+ "阿祇_EN": 555,
599
+ "博士_EN": 556,
600
+ "迪尔菲_EN": 557,
601
+ "宛烟_EN": 558,
602
+ "玛格丽特_EN": 559,
603
+ "羽生田千鹤_EN": 560,
604
+ "海妮耶_EN": 561,
605
+ "霍夫曼_EN": 562,
606
+ "旅行者_EN": 563,
607
+ "佐西摩斯_EN": 564,
608
+ "鹿野奈奈_EN": 565,
609
+ "舒伯特_EN": 566,
610
+ "天叔_EN": 567,
611
+ "艾莉丝_EN": 568,
612
+ "龙二_EN": 569,
613
+ "莺儿_EN": 570,
614
+ "嘉良_EN": 571,
615
+ "珊瑚_EN": 572,
616
+ "费迪南德_EN": 573,
617
+ "言笑_EN": 574,
618
+ "一心传名刀_EN": 575,
619
+ "久利须_EN": 576,
620
+ "嘉玛_EN": 577,
621
+ "艾文_EN": 578,
622
+ "克洛琳德_EN": 579,
623
+ "丹吉尔_EN": 580,
624
+ "女士_EN": 581,
625
+ "天目十五_EN": 582,
626
+ "老孟_EN": 583,
627
+ "白老先生_EN": 584,
628
+ "舍利夫_EN": 585,
629
+ "巴达维_EN": 586,
630
+ "拉齐_EN": 587,
631
+ "长生_EN": 588,
632
+ "吴船长_EN": 589,
633
+ "艾伯特_EN": 590,
634
+ "松浦_EN": 591,
635
+ "埃泽_EN": 592,
636
+ "阿圆_EN": 593,
637
+ "阿拉夫_EN": 594,
638
+ "莫塞伊思_EN": 595,
639
+ "石头_EN": 596,
640
+ "百闻_EN": 597,
641
+ "杜吉耶_EN": 598,
642
+ "波洛_EN": 599,
643
+ "斯坦利_EN": 600,
644
+ "掇星攫辰天君_EN": 601,
645
+ "迈蒙_EN": 602,
646
+ "博易_EN": 603,
647
+ "诗筠_EN": 604,
648
+ "毗伽尔_EN": 605,
649
+ "慧心_EN": 606,
650
+ "芙卡洛斯_EN": 607,
651
+ "恶龙_EN": 608,
652
+ "小仓澪_EN": 609,
653
+ "恕筠_EN": 610,
654
+ "知易_EN": 611,
655
+ "克列门特_EN": 612,
656
+ "大慈树王_EN": 613,
657
+ "维多利亚_EN": 614,
658
+ "黑田_EN": 615,
659
+ "马姆杜_EN": 616,
660
+ "科林斯_EN": 617,
661
+ "上杉_EN": 618,
662
+ "西拉杰_EN": 619,
663
+ "宁禄_EN": 620,
664
+ "纯水精灵_EN": 621,
665
+ "常九爷_EN": 622,
666
+ "阿尔卡米_EN": 623,
667
+ "沙扎曼_EN": 624,
668
+ "田铁嘴_EN": 625,
669
+ "加萨尼_EN": 626,
670
+ "克罗索_EN": 627,
671
+ "星稀_EN": 628,
672
+ "莱斯格_EN": 629,
673
+ "阿巴图伊_EN": 630,
674
+ "埃尔欣根_EN": 631,
675
+ "阿佩普_EN": 632,
676
+ "萨赫哈蒂_EN": 633,
677
+ "洛伦佐_EN": 634,
678
+ "塔杰·拉德卡尼_EN": 635,
679
+ "泽田_EN": 636,
680
+ "安西_EN": 637,
681
  "埃舍尔_EN": 638,
682
+ "三月七_ZH": 639,
683
+ "丹恒_ZH": 640,
684
+ "希儿_ZH": 641,
685
+ "娜塔莎_ZH": 642,
686
+ "希露瓦_ZH": 643,
687
+ "瓦尔特_ZH": 644,
688
+ "佩拉_ZH": 645,
689
+ "布洛妮娅_ZH": 646,
690
+ "虎克_ZH": 647,
691
+ "素裳_ZH": 648,
692
+ "克拉拉_ZH": 649,
693
+ "符玄_ZH": 650,
694
+ "白露_ZH": 651,
695
+ "杰帕德_ZH": 652,
696
+ "景元_ZH": 653,
697
+ "藿藿_ZH": 654,
698
+ "姬子_ZH": 655,
699
+ "_ZH": 656,
700
+ "_ZH": 657,
701
+ "卡芙卡_ZH": 658,
702
+ "桂乃芬_ZH": 659,
703
+ "艾丝妲_ZH": 660,
704
+ "玲可_ZH": 661,
705
+ "彦卿_ZH": 662,
706
+ "托帕_ZH": 663,
707
+ "驭空_ZH": 664,
708
+ "浮烟_ZH": 665,
709
+ "停云_ZH": 666,
710
+ "镜流_ZH": 667,
711
+ "罗刹_ZH": 668,
712
+ "卢卡_ZH": 669,
713
+ "史瓦罗_ZH": 670,
714
+ "黑塔_ZH": 671,
715
+ "桑博_ZH": 672,
716
+ "伦纳德_ZH": 673,
717
+ "明曦_ZH": 674,
718
+ "银狼_ZH": 675,
719
+ "帕姆_ZH": 676,
720
+ "青雀_ZH": 677,
721
+ "乔瓦尼_ZH": 678,
722
+ "公输师傅_ZH": 679,
723
+ "晴霓_ZH": 680,
724
+ "螺丝咕姆_ZH": 681,
725
+ "阿兰_ZH": 682,
726
+ "奥列格_ZH": 683,
727
+ "丹枢_ZH": 684,
728
+ "尾巴_ZH": 685,
729
+ "寒鸦_ZH": 686,
730
+ "雪衣_ZH": 687,
731
+ "可可利亚_ZH": 688,
732
+ "青镞_ZH": 689,
733
+ "半夏_ZH": 690,
734
+ "银枝_ZH": 691,
735
+ "大毫_ZH": 692,
736
+ "霄翰_ZH": 693,
737
+ "信使_ZH": 694,
738
+ "费斯曼_ZH": 695,
739
+ "绿芙蓉_ZH": 696,
740
+ "金人会长_ZH": 697,
741
+ "维利特_ZH": 698,
742
+ "维尔德_ZH": 699,
743
+ "斯科特_ZH": 700,
744
+ "卡波特_ZH": 701,
745
+ "_ZH": 702,
746
+ "岩明_ZH": 703,
747
+ "浣溪_ZH": 704,
748
+ "三月七_JP": 705,
749
+ "丹恒_JP": 706,
750
+ "希儿_JP": 707,
751
+ "娜塔莎_JP": 708,
752
+ "希露瓦_JP": 709,
753
+ "瓦尔特_JP": 710,
754
+ "佩拉_JP": 711,
755
+ "布洛妮娅_JP": 712,
756
+ "虎克_JP": 713,
757
+ "素裳_JP": 714,
758
+ "克拉拉_JP": 715,
759
+ "符玄_JP": 716,
760
+ "白露_JP": 717,
761
+ "杰帕德_JP": 718,
762
+ "景元_JP": 719,
763
+ "藿藿_JP": 720,
764
+ "姬子_JP": 721,
765
+ "卡芙卡_JP": 722,
766
+ "_JP": 723,
767
+ "_JP": 724,
768
+ "桂乃芬_JP": 725,
769
+ "艾丝妲_JP": 726,
770
+ "彦卿_JP": 727,
771
+ "玲可_JP": 728,
772
+ "托帕_JP": 729,
773
+ "驭空_JP": 730,
774
+ "浮烟_JP": 731,
775
+ "停云_JP": 732,
776
+ "镜流_JP": 733,
777
+ "罗刹_JP": 734,
778
+ "卢卡_JP": 735,
779
+ "史瓦罗_JP": 736,
780
+ "黑塔_JP": 737,
781
+ "桑博_JP": 738,
782
+ "伦纳德_JP": 739,
783
+ "明曦_JP": 740,
784
+ "银狼_JP": 741,
785
+ "帕姆_JP": 742,
786
+ "青雀_JP": 743,
787
+ "乔瓦尼_JP": 744,
788
+ "公输师傅_JP": 745,
789
+ "晴霓_JP": 746,
790
+ "螺丝咕姆_JP": 747,
791
+ "阿兰_JP": 748,
792
+ "奥列格_JP": 749,
793
+ "丹枢_JP": 750,
794
+ "尾巴_JP": 751,
795
+ "寒鸦_JP": 752,
796
+ "雪衣_JP": 753,
797
+ "可可利亚_JP": 754,
798
+ "青镞_JP": 755,
799
+ "半夏_JP": 756,
800
+ "银枝_JP": 757,
801
+ "大毫_JP": 758,
802
+ "霄翰_JP": 759,
803
+ "信使_JP": 760,
804
+ "费斯曼_JP": 761,
805
+ "绿芙蓉_JP": 762,
806
+ "金人会长_JP": 763,
807
+ "维利特_JP": 764,
808
+ "维尔德_JP": 765,
809
+ "斯科特_JP": 766,
810
+ "_JP": 767,
811
+ "卡波特_JP": 768,
812
+ "岩明_JP": 769,
813
+ "浣溪_JP": 770,
814
+ "净砚_JP": 771,
815
+ "紫月季_JP": 772,
816
+ "歌蒂_JP": 773,
817
+ "奇怪的云骑_JP": 774,
818
+ "幻胧_JP": 775,
819
+ "斯薇塔_JP": 776,
820
+ "隐书_JP": 777,
821
+ "三月七_EN": 778,
822
+ "丹恒_EN": 779,
823
+ "希儿_EN": 780,
824
+ "娜塔莎_EN": 781,
825
+ "希露瓦_EN": 782,
826
+ "瓦尔特_EN": 783,
827
+ "佩拉_EN": 784,
828
+ "布洛妮娅_EN": 785,
829
+ "虎克_EN": 786,
830
+ "素裳_EN": 787,
831
+ "克拉拉_EN": 788,
832
+ "符玄_EN": 789,
833
+ "白露_EN": 790,
834
+ "杰帕德_EN": 791,
835
+ "景元_EN": 792,
836
+ "藿藿_EN": 793,
837
+ "姬子_EN": 794,
838
+ "卡芙卡_EN": 795,
839
+ "_EN": 796,
840
+ "_EN": 797,
841
+ "桂乃芬_EN": 798,
842
+ "艾丝妲_EN": 799,
843
+ "彦卿_EN": 800,
844
+ "玲可_EN": 801,
845
+ "托帕_EN": 802,
846
+ "驭空_EN": 803,
847
+ "浮烟_EN": 804,
848
+ "停云_EN": 805,
849
+ "镜流_EN": 806,
850
+ "罗刹_EN": 807,
851
+ "卢卡_EN": 808,
852
+ "史瓦罗_EN": 809,
853
+ "黑塔_EN": 810,
854
+ "桑博_EN": 811,
855
+ "伦纳德_EN": 812,
856
+ "明曦_EN": 813,
857
+ "银狼_EN": 814,
858
+ "帕姆_EN": 815,
859
+ "青雀_EN": 816,
860
+ "乔瓦尼_EN": 817,
861
+ "公输师傅_EN": 818,
862
+ "晴霓_EN": 819,
863
+ "螺丝咕姆_EN": 820,
864
+ "阿兰_EN": 821,
865
+ "奥列格_EN": 822,
866
+ "丹枢_EN": 823,
867
+ "尾巴_EN": 824,
868
+ "寒鸦_EN": 825,
869
+ "雪衣_EN": 826,
870
+ "可可利亚_EN": 827,
871
+ "青镞_EN": 828,
872
+ "半夏_EN": 829,
873
+ "银枝_EN": 830,
874
+ "大毫_EN": 831,
875
+ "霄翰_EN": 832,
876
+ "信使_EN": 833,
877
+ "费斯曼_EN": 834,
878
+ "绿芙蓉_EN": 835,
879
+ "金人会长_EN": 836,
880
+ "维利特_EN": 837,
881
+ "维尔德_EN": 838,
882
+ "_EN": 839,
883
+ "卡波特_EN": 840,
884
+ "岩明_EN": 841,
885
+ "浣溪_EN": 842,
886
+ "紫月季_EN": 843,
887
+ "幻胧_EN": 844,
888
+ "女声_EN": 845,
889
+ "陆景和": 846,
890
+ "莫弈": 847,
891
+ "左然": 848,
892
+ "夏彦": 849
893
  }
894
  },
895
  "model": {
 
944
  ],
945
  "n_layers_q": 3,
946
  "use_spectral_norm": false,
947
+ "gin_channels": 512,
948
+ "slm": {
949
+ "model": "./slm/wavlm-base-plus",
950
+ "sr": 16000,
951
+ "hidden": 768,
952
+ "nlayers": 13,
953
+ "initial_channel": 64
954
+ }
955
  },
956
+ "version": "2.3"
957
  }
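The new `slm` block points the 2.3 training code at a local WavLM checkpoint (wavlm-base-plus: 16 kHz input, 768-dim hidden states, 13 layers) used as an auxiliary speech feature extractor, and `gin_channels` is set to 512. Below is a minimal sketch of how such a block could be consumed, assuming the Hugging Face `transformers` WavLM implementation and the values shown above; it is not code from this commit.

import torch
from transformers import WavLMModel

# Values copied from the "slm" block above; the checkpoint path is assumed to exist locally.
slm_cfg = {"model": "./slm/wavlm-base-plus", "sr": 16000, "hidden": 768, "nlayers": 13}

wavlm = WavLMModel.from_pretrained(slm_cfg["model"]).eval()

with torch.no_grad():
    wav = torch.randn(1, slm_cfg["sr"])          # one second of dummy 16 kHz audio
    out = wavlm(wav, output_hidden_states=True)  # hidden_states: 13 tensors of shape [1, frames, 768]
    feats = torch.stack(out.hidden_states, dim=1)
    print(feats.shape)                           # torch.Size([1, 13, frames, 768])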
data_utils.py CHANGED
@@ -44,10 +44,6 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
44
  self.min_text_len = getattr(hparams, "min_text_len", 1)
45
  self.max_text_len = getattr(hparams, "max_text_len", 384)
46
 
47
- self.empty_emo = torch.squeeze(
48
- torch.load("empty_emo.npy", map_location="cpu"), dim=1
49
- )
50
-
51
  random.seed(1234)
52
  random.shuffle(self.audiopaths_sid_text)
53
  self._filter()
@@ -98,14 +94,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
98
  spec, wav = self.get_audio(audiopath)
99
  sid = torch.LongTensor([int(self.spk_map[sid])])
100
 
101
- if np.random.rand() > 0.1:
102
- emo = torch.squeeze(
103
- torch.load(audiopath.replace(".wav", ".emo.npy"), map_location="cpu"),
104
- dim=1,
105
- )
106
- else:
107
- emo = self.empty_emo
108
- return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert, emo)
109
 
110
  def get_audio(self, filename):
111
  audio, sampling_rate = load_wav_to_torch(filename)
@@ -168,15 +157,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
168
 
169
  if language_str == "ZH":
170
  bert = bert_ori
171
- ja_bert = torch.rand(1024, len(phone))
172
- en_bert = torch.rand(1024, len(phone))
173
  elif language_str == "JP":
174
- bert = torch.rand(1024, len(phone))
175
  ja_bert = bert_ori
176
- en_bert = torch.rand(1024, len(phone))
177
  elif language_str == "EN":
178
- bert = torch.rand(1024, len(phone))
179
- ja_bert = torch.rand(1024, len(phone))
180
  en_bert = bert_ori
181
  phone = torch.LongTensor(phone)
182
  tone = torch.LongTensor(tone)
@@ -226,7 +215,6 @@ class TextAudioSpeakerCollate:
226
  bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
227
  ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
228
  en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
229
- emo = torch.FloatTensor(len(batch), 512)
230
 
231
  spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
232
  wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
@@ -238,7 +226,6 @@ class TextAudioSpeakerCollate:
238
  bert_padded.zero_()
239
  ja_bert_padded.zero_()
240
  en_bert_padded.zero_()
241
- emo.zero_()
242
 
243
  for i in range(len(ids_sorted_decreasing)):
244
  row = batch[ids_sorted_decreasing[i]]
@@ -272,8 +259,6 @@ class TextAudioSpeakerCollate:
272
  en_bert = row[8]
273
  en_bert_padded[i, :, : en_bert.size(1)] = en_bert
274
 
275
- emo[i, :] = row[9]
276
-
277
  return (
278
  text_padded,
279
  text_lengths,
@@ -287,7 +272,6 @@ class TextAudioSpeakerCollate:
287
  bert_padded,
288
  ja_bert_padded,
289
  en_bert_padded,
290
- emo,
291
  )
292
 
293
 
 
44
  self.min_text_len = getattr(hparams, "min_text_len", 1)
45
  self.max_text_len = getattr(hparams, "max_text_len", 384)
46
 
 
 
 
 
47
  random.seed(1234)
48
  random.shuffle(self.audiopaths_sid_text)
49
  self._filter()
 
94
  spec, wav = self.get_audio(audiopath)
95
  sid = torch.LongTensor([int(self.spk_map[sid])])
96
 
97
+ return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert)
 
 
 
 
 
 
 
98
 
99
  def get_audio(self, filename):
100
  audio, sampling_rate = load_wav_to_torch(filename)
 
157
 
158
  if language_str == "ZH":
159
  bert = bert_ori
160
+ ja_bert = torch.randn(1024, len(phone))
161
+ en_bert = torch.randn(1024, len(phone))
162
  elif language_str == "JP":
163
+ bert = torch.randn(1024, len(phone))
164
  ja_bert = bert_ori
165
+ en_bert = torch.randn(1024, len(phone))
166
  elif language_str == "EN":
167
+ bert = torch.randn(1024, len(phone))
168
+ ja_bert = torch.randn(1024, len(phone))
169
  en_bert = bert_ori
170
  phone = torch.LongTensor(phone)
171
  tone = torch.LongTensor(tone)
 
215
  bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
216
  ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
217
  en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
 
218
 
219
  spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
220
  wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
 
226
  bert_padded.zero_()
227
  ja_bert_padded.zero_()
228
  en_bert_padded.zero_()
 
229
 
230
  for i in range(len(ids_sorted_decreasing)):
231
  row = batch[ids_sorted_decreasing[i]]
 
259
  en_bert = row[8]
260
  en_bert_padded[i, :, : en_bert.size(1)] = en_bert
261
 
 
 
262
  return (
263
  text_padded,
264
  text_lengths,
 
272
  bert_padded,
273
  ja_bert_padded,
274
  en_bert_padded,
 
275
  )
276
 
277
 
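Two things change in data_utils.py above: the emotion embedding (`emo`) is dropped from the dataset item and the collate function, and the placeholder BERT features for the two languages that do not match the sample now use `torch.randn` (standard normal) instead of `torch.rand` (uniform on [0, 1)). A small self-contained illustration of that second change, not taken from the commit (the phone count 12 is arbitrary):

import torch

n_phones = 12  # arbitrary length for the demo

uniform_placeholder = torch.rand(1024, n_phones)   # old behaviour: values in [0, 1), mean about 0.5
normal_placeholder = torch.randn(1024, n_phones)   # new behaviour: zero-mean, unit-variance noise

print(uniform_placeholder.mean().item(), uniform_placeholder.std().item())  # roughly 0.5 and 0.29
print(normal_placeholder.mean().item(), normal_placeholder.std().item())    # roughly 0.0 and 1.0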
default_config.yml CHANGED
@@ -83,11 +83,11 @@ train_ms:
83
  base:
84
  use_base_model: false
85
  repo_id: "Stardust_minus/Bert-VITS2"
86
- model_image: "Bert-VITS2_2.1-Emo底模" # openi网页的模型名
87
  # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
88
  model: "models"
89
  # 配置文件路径
90
- config_path: "configs/config.json"
91
  # 训练使用的worker,不建议超过CPU核心数
92
  num_workers: 16
93
  # 关闭此项可以节约接近50%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。
@@ -104,7 +104,7 @@ webui:
104
  # 模型路径
105
  model: "models/G_8000.pth"
106
  # 配置文件路径
107
- config_path: "configs/config.json"
108
  # 端口号
109
  port: 7860
110
  # 是否公开部署,对外网开放
 
83
  base:
84
  use_base_model: false
85
  repo_id: "Stardust_minus/Bert-VITS2"
86
+ model_image: "Bert-VITS2_2.3底模" # openi网页的模型名
87
  # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
88
  model: "models"
89
  # 配置文件路径
90
+ config_path: "config.json"
91
  # 训练使用的worker,不建议超过CPU核心数
92
  num_workers: 16
93
  # 关闭此项可以节约接近50%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。
 
104
  # 模型路径
105
  model: "models/G_8000.pth"
106
  # 配置文件路径
107
+ config_path: "config.json"
108
  # 端口号
109
  port: 7860
110
  # 是否公开部署,对外网开放
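Both `config_path` entries change from "configs/config.json" to plain "config.json". Together with the comment above about datasets now living under Data/&lt;your dataset&gt;/, this suggests relative config paths are resolved against the dataset directory rather than the repository root; the actual join happens in the project's config loader, which is not shown in this diff, so the snippet below is only a hypothetical illustration of that resolution.

import os

dataset_path = "Data/your_dataset"   # placeholder dataset root
config_path = "config.json"          # the new default from the YAML above

# Hypothetical resolution: absolute paths are kept, relative ones are joined to the dataset root.
resolved = config_path if os.path.isabs(config_path) else os.path.join(dataset_path, config_path)
print(resolved)                      # Data/your_dataset/config.json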
export_onnx.py CHANGED
@@ -2,11 +2,13 @@ from onnx_modules import export_onnx
2
  import os
3
 
4
  if __name__ == "__main__":
5
- export_path = "MyModel"
6
- model_path = "S:\\VSGIT\\bert-vits2\\G_178000.pth"
7
- config_path = "S:\\VSGIT\\bert-vits2\\config.json"
 
 
8
  if not os.path.exists("onnx"):
9
  os.makedirs("onnx")
10
  if not os.path.exists(f"onnx/{export_path}"):
11
  os.makedirs(f"onnx/{export_path}")
12
- export_onnx(export_path, model_path, config_path)
 
2
  import os
3
 
4
  if __name__ == "__main__":
5
+ export_path = "BertVits2.2PT"
6
+ model_path = "model\\G_0.pth"
7
+ config_path = "model\\config.json"
8
+ novq = False
9
+ dev = False
10
  if not os.path.exists("onnx"):
11
  os.makedirs("onnx")
12
  if not os.path.exists(f"onnx/{export_path}"):
13
  os.makedirs(f"onnx/{export_path}")
14
+ export_onnx(export_path, model_path, config_path, novq, dev)
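export_onnx() now takes two extra booleans, `novq` and `dev`, which (judging only by the parameter names) select the no-VQ and dev model variants at export time. A hedged usage sketch; the export name is a placeholder and the model/config paths simply mirror the defaults shown above:

from onnx_modules import export_onnx
import os

export_path = "MyExportedModel"    # placeholder output folder created under onnx/
model_path = "model\\G_0.pth"      # same default as in the updated script
config_path = "model\\config.json"
novq = False                       # set True to export the no-VQ variant
dev = False                        # set True to export the dev graph

os.makedirs(f"onnx/{export_path}", exist_ok=True)
export_onnx(export_path, model_path, config_path, novq, dev)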
for_deploy/infer.py ADDED
@@ -0,0 +1,386 @@
1
+ """
2
+ 版本管理、兼容推理及模型加载实现。
3
+ 版本说明:
4
+ 1. 版本号与github的release版本号对应,使用哪个release版本训练的模型即对应其版本号
5
+ 2. 请在模型的config.json中显示声明版本号,添加一个字段"version" : "你的版本号"
6
+ 特殊版本说明:
7
+ 1.1.1-fix: 1.1.1版本训练的模型,但是在推理时使用dev的日语修复
8
+ 2.2:当前版本
9
+ """
10
+ import torch
11
+ import commons
12
+ from text import cleaned_text_to_sequence, get_bert
13
+ from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
14
+ from text.cleaner import clean_text
15
+ import utils
16
+ import numpy as np
17
+
18
+ from models import SynthesizerTrn
19
+ from text.symbols import symbols
20
+
21
+ from oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
22
+ from oldVersion.V210.text import symbols as V210symbols
23
+ from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
24
+ from oldVersion.V200.text import symbols as V200symbols
25
+ from oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn
26
+ from oldVersion.V111.text import symbols as V111symbols
27
+ from oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn
28
+ from oldVersion.V110.text import symbols as V110symbols
29
+ from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
30
+ from oldVersion.V101.text import symbols as V101symbols
31
+
32
+ from oldVersion import V111, V110, V101, V200, V210
33
+
34
+ # 当前版本信息
35
+ latest_version = "2.2"
36
+
37
+ # 版本兼容
38
+ SynthesizerTrnMap = {
39
+ "2.1": V210SynthesizerTrn,
40
+ "2.0.2-fix": V200SynthesizerTrn,
41
+ "2.0.1": V200SynthesizerTrn,
42
+ "2.0": V200SynthesizerTrn,
43
+ "1.1.1-fix": V111SynthesizerTrn,
44
+ "1.1.1": V111SynthesizerTrn,
45
+ "1.1": V110SynthesizerTrn,
46
+ "1.1.0": V110SynthesizerTrn,
47
+ "1.0.1": V101SynthesizerTrn,
48
+ "1.0": V101SynthesizerTrn,
49
+ "1.0.0": V101SynthesizerTrn,
50
+ }
51
+
52
+ symbolsMap = {
53
+ "2.1": V210symbols,
54
+ "2.0.2-fix": V200symbols,
55
+ "2.0.1": V200symbols,
56
+ "2.0": V200symbols,
57
+ "1.1.1-fix": V111symbols,
58
+ "1.1.1": V111symbols,
59
+ "1.1": V110symbols,
60
+ "1.1.0": V110symbols,
61
+ "1.0.1": V101symbols,
62
+ "1.0": V101symbols,
63
+ "1.0.0": V101symbols,
64
+ }
65
+
66
+
67
+ # def get_emo_(reference_audio, emotion, sid):
68
+ # emo = (
69
+ # torch.from_numpy(get_emo(reference_audio))
70
+ # if reference_audio and emotion == -1
71
+ # else torch.FloatTensor(
72
+ # np.load(f"emo_clustering/{sid}/cluster_center_{emotion}.npy")
73
+ # )
74
+ # )
75
+ # return emo
76
+
77
+
78
+ def get_net_g(model_path: str, version: str, device: str, hps):
79
+ if version != latest_version:
80
+ net_g = SynthesizerTrnMap[version](
81
+ len(symbolsMap[version]),
82
+ hps.data.filter_length // 2 + 1,
83
+ hps.train.segment_size // hps.data.hop_length,
84
+ n_speakers=hps.data.n_speakers,
85
+ **hps.model,
86
+ ).to(device)
87
+ else:
88
+ # 当前版本模型 net_g
89
+ net_g = SynthesizerTrn(
90
+ len(symbols),
91
+ hps.data.filter_length // 2 + 1,
92
+ hps.train.segment_size // hps.data.hop_length,
93
+ n_speakers=hps.data.n_speakers,
94
+ **hps.model,
95
+ ).to(device)
96
+ _ = net_g.eval()
97
+ _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
98
+ return net_g
99
+
100
+
101
+ def get_text(text, language_str, bert, hps, device):
102
+ # 在此处实现当前版本的get_text
103
+ norm_text, phone, tone, word2ph = clean_text(text, language_str)
104
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
105
+
106
+ if hps.data.add_blank:
107
+ phone = commons.intersperse(phone, 0)
108
+ tone = commons.intersperse(tone, 0)
109
+ language = commons.intersperse(language, 0)
110
+ for i in range(len(word2ph)):
111
+ word2ph[i] = word2ph[i] * 2
112
+ word2ph[0] += 1
113
+ # bert_ori = get_bert(norm_text, word2ph, language_str, device)
114
+ bert_ori = bert[language_str].get_bert_feature(norm_text, word2ph, device)
115
+ del word2ph
116
+ assert bert_ori.shape[-1] == len(phone), phone
117
+
118
+ if language_str == "ZH":
119
+ bert = bert_ori
120
+ ja_bert = torch.randn(1024, len(phone))
121
+ en_bert = torch.randn(1024, len(phone))
122
+ elif language_str == "JP":
123
+ bert = torch.randn(1024, len(phone))
124
+ ja_bert = bert_ori
125
+ en_bert = torch.randn(1024, len(phone))
126
+ elif language_str == "EN":
127
+ bert = torch.randn(1024, len(phone))
128
+ ja_bert = torch.randn(1024, len(phone))
129
+ en_bert = bert_ori
130
+ else:
131
+ raise ValueError("language_str should be ZH, JP or EN")
132
+
133
+ assert bert.shape[-1] == len(
134
+ phone
135
+ ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
136
+
137
+ phone = torch.LongTensor(phone)
138
+ tone = torch.LongTensor(tone)
139
+ language = torch.LongTensor(language)
140
+ return bert, ja_bert, en_bert, phone, tone, language
141
+
142
+
143
+ def infer(
144
+ text,
145
+ emotion,
146
+ sdp_ratio,
147
+ noise_scale,
148
+ noise_scale_w,
149
+ length_scale,
150
+ sid,
151
+ language,
152
+ hps,
153
+ net_g,
154
+ device,
155
+ bert=None,
156
+ clap=None,
157
+ reference_audio=None,
158
+ skip_start=False,
159
+ skip_end=False,
160
+ ):
161
+ # 2.2版本参数位置变了
162
+ # 2.1 参数新增 emotion reference_audio skip_start skip_end
163
+ inferMap_V3 = {
164
+ "2.1": V210.infer,
165
+ }
166
+ # 支持中日英三语版本
167
+ inferMap_V2 = {
168
+ "2.0.2-fix": V200.infer,
169
+ "2.0.1": V200.infer,
170
+ "2.0": V200.infer,
171
+ "1.1.1-fix": V111.infer_fix,
172
+ "1.1.1": V111.infer,
173
+ "1.1": V110.infer,
174
+ "1.1.0": V110.infer,
175
+ }
176
+ # 仅支持中文版本
177
+ # 在测试中,并未发现两个版本的模型不能互相通用
178
+ inferMap_V1 = {
179
+ "1.0.1": V101.infer,
180
+ "1.0": V101.infer,
181
+ "1.0.0": V101.infer,
182
+ }
183
+ version = hps.version if hasattr(hps, "version") else latest_version
184
+ # 非当前版本,根据版本号选择合适的infer
185
+ if version != latest_version:
186
+ if version in inferMap_V3.keys():
187
+ return inferMap_V3[version](
188
+ text,
189
+ sdp_ratio,
190
+ noise_scale,
191
+ noise_scale_w,
192
+ length_scale,
193
+ sid,
194
+ language,
195
+ hps,
196
+ net_g,
197
+ device,
198
+ reference_audio,
199
+ emotion,
200
+ skip_start,
201
+ skip_end,
202
+ )
203
+ if version in inferMap_V2.keys():
204
+ return inferMap_V2[version](
205
+ text,
206
+ sdp_ratio,
207
+ noise_scale,
208
+ noise_scale_w,
209
+ length_scale,
210
+ sid,
211
+ language,
212
+ hps,
213
+ net_g,
214
+ device,
215
+ )
216
+ if version in inferMap_V1.keys():
217
+ return inferMap_V1[version](
218
+ text,
219
+ sdp_ratio,
220
+ noise_scale,
221
+ noise_scale_w,
222
+ length_scale,
223
+ sid,
224
+ hps,
225
+ net_g,
226
+ device,
227
+ )
228
+ # 在此处实现当前版本的推理
229
+ # emo = get_emo_(reference_audio, emotion, sid)
230
+ if isinstance(reference_audio, np.ndarray):
231
+ emo = clap.get_clap_audio_feature(reference_audio, device)
232
+ else:
233
+ emo = clap.get_clap_text_feature(emotion, device)
234
+ emo = torch.squeeze(emo, dim=1)
235
+
236
+ bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
237
+ text, language, bert, hps, device
238
+ )
239
+ if skip_start:
240
+ phones = phones[3:]
241
+ tones = tones[3:]
242
+ lang_ids = lang_ids[3:]
243
+ bert = bert[:, 3:]
244
+ ja_bert = ja_bert[:, 3:]
245
+ en_bert = en_bert[:, 3:]
246
+ if skip_end:
247
+ phones = phones[:-2]
248
+ tones = tones[:-2]
249
+ lang_ids = lang_ids[:-2]
250
+ bert = bert[:, :-2]
251
+ ja_bert = ja_bert[:, :-2]
252
+ en_bert = en_bert[:, :-2]
253
+ with torch.no_grad():
254
+ x_tst = phones.to(device).unsqueeze(0)
255
+ tones = tones.to(device).unsqueeze(0)
256
+ lang_ids = lang_ids.to(device).unsqueeze(0)
257
+ bert = bert.to(device).unsqueeze(0)
258
+ ja_bert = ja_bert.to(device).unsqueeze(0)
259
+ en_bert = en_bert.to(device).unsqueeze(0)
260
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
261
+ emo = emo.to(device).unsqueeze(0)
262
+ del phones
263
+ speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
264
+ audio = (
265
+ net_g.infer(
266
+ x_tst,
267
+ x_tst_lengths,
268
+ speakers,
269
+ tones,
270
+ lang_ids,
271
+ bert,
272
+ ja_bert,
273
+ en_bert,
274
+ emo,
275
+ sdp_ratio=sdp_ratio,
276
+ noise_scale=noise_scale,
277
+ noise_scale_w=noise_scale_w,
278
+ length_scale=length_scale,
279
+ )[0][0, 0]
280
+ .data.cpu()
281
+ .float()
282
+ .numpy()
283
+ )
284
+ del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
285
+ if torch.cuda.is_available():
286
+ torch.cuda.empty_cache()
287
+ return audio
288
+
289
+
290
+ def infer_multilang(
291
+ text,
292
+ sdp_ratio,
293
+ noise_scale,
294
+ noise_scale_w,
295
+ length_scale,
296
+ sid,
297
+ language,
298
+ hps,
299
+ net_g,
300
+ device,
301
+ bert=None,
302
+ clap=None,
303
+ reference_audio=None,
304
+ emotion=None,
305
+ skip_start=False,
306
+ skip_end=False,
307
+ ):
308
+ bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
309
+ # emo = get_emo_(reference_audio, emotion, sid)
310
+ if isinstance(reference_audio, np.ndarray):
311
+ emo = clap.get_clap_audio_feature(reference_audio, device)
312
+ else:
313
+ emo = clap.get_clap_text_feature(emotion, device)
314
+ emo = torch.squeeze(emo, dim=1)
315
+ for idx, (txt, lang) in enumerate(zip(text, language)):
316
+ skip_start = (idx != 0) or (skip_start and idx == 0)
317
+ skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
318
+ (
319
+ temp_bert,
320
+ temp_ja_bert,
321
+ temp_en_bert,
322
+ temp_phones,
323
+ temp_tones,
324
+ temp_lang_ids,
325
+ ) = get_text(txt, lang, bert, hps, device)
326
+ if skip_start:
327
+ temp_bert = temp_bert[:, 3:]
328
+ temp_ja_bert = temp_ja_bert[:, 3:]
329
+ temp_en_bert = temp_en_bert[:, 3:]
330
+ temp_phones = temp_phones[3:]
331
+ temp_tones = temp_tones[3:]
332
+ temp_lang_ids = temp_lang_ids[3:]
333
+ if skip_end:
334
+ temp_bert = temp_bert[:, :-2]
335
+ temp_ja_bert = temp_ja_bert[:, :-2]
336
+ temp_en_bert = temp_en_bert[:, :-2]
337
+ temp_phones = temp_phones[:-2]
338
+ temp_tones = temp_tones[:-2]
339
+ temp_lang_ids = temp_lang_ids[:-2]
340
+ bert.append(temp_bert)
341
+ ja_bert.append(temp_ja_bert)
342
+ en_bert.append(temp_en_bert)
343
+ phones.append(temp_phones)
344
+ tones.append(temp_tones)
345
+ lang_ids.append(temp_lang_ids)
346
+ bert = torch.concatenate(bert, dim=1)
347
+ ja_bert = torch.concatenate(ja_bert, dim=1)
348
+ en_bert = torch.concatenate(en_bert, dim=1)
349
+ phones = torch.concatenate(phones, dim=0)
350
+ tones = torch.concatenate(tones, dim=0)
351
+ lang_ids = torch.concatenate(lang_ids, dim=0)
352
+ with torch.no_grad():
353
+ x_tst = phones.to(device).unsqueeze(0)
354
+ tones = tones.to(device).unsqueeze(0)
355
+ lang_ids = lang_ids.to(device).unsqueeze(0)
356
+ bert = bert.to(device).unsqueeze(0)
357
+ ja_bert = ja_bert.to(device).unsqueeze(0)
358
+ en_bert = en_bert.to(device).unsqueeze(0)
359
+ emo = emo.to(device).unsqueeze(0)
360
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
361
+ del phones
362
+ speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
363
+ audio = (
364
+ net_g.infer(
365
+ x_tst,
366
+ x_tst_lengths,
367
+ speakers,
368
+ tones,
369
+ lang_ids,
370
+ bert,
371
+ ja_bert,
372
+ en_bert,
373
+ emo,
374
+ sdp_ratio=sdp_ratio,
375
+ noise_scale=noise_scale,
376
+ noise_scale_w=noise_scale_w,
377
+ length_scale=length_scale,
378
+ )[0][0, 0]
379
+ .data.cpu()
380
+ .float()
381
+ .numpy()
382
+ )
383
+ del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
384
+ if torch.cuda.is_available():
385
+ torch.cuda.empty_cache()
386
+ return audio
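Both infer() and infer_multilang() above trim three leading and two trailing symbols when `skip_start` / `skip_end` are set, so consecutive segments can be concatenated without duplicating the boundary tokens (the padding symbol plus the surrounding blanks added by the cleaner and `commons.intersperse`) at each join. A toy illustration of that slicing, with dummy phone IDs rather than real symbol indices:

import torch

phones = torch.tensor([0, 1, 0, 7, 0, 9, 0, 2, 0])  # interspersed with blank tokens (0)
middle_segment = phones[3:-2]                        # effect of skip_start=True and skip_end=True
print(middle_segment)                                # tensor([7, 0, 9, 0])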
for_deploy/infer_utils.py ADDED
@@ -0,0 +1,111 @@
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import (
5
+ AutoModelForMaskedLM,
6
+ AutoTokenizer,
7
+ DebertaV2Model,
8
+ DebertaV2Tokenizer,
9
+ ClapModel,
10
+ ClapProcessor,
11
+ )
12
+
13
+ from config import config
14
+ from text.japanese import text2sep_kata
15
+
16
+
17
+ class BertFeature:
18
+ def __init__(self, model_path, language="ZH"):
19
+ self.model_path = model_path
20
+ self.language = language
21
+ self.tokenizer = None
22
+ self.model = None
23
+ self.device = None
24
+
25
+ self._prepare()
26
+
27
+ def _get_device(self, device=config.bert_gen_config.device):
28
+ if (
29
+ sys.platform == "darwin"
30
+ and torch.backends.mps.is_available()
31
+ and device == "cpu"
32
+ ):
33
+ device = "mps"
34
+ if not device:
35
+ device = "cuda"
36
+ return device
37
+
38
+ def _prepare(self):
39
+ self.device = self._get_device()
40
+
41
+ if self.language == "EN":
42
+ self.tokenizer = DebertaV2Tokenizer.from_pretrained(self.model_path)
43
+ self.model = DebertaV2Model.from_pretrained(self.model_path).to(self.device)
44
+ else:
45
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
46
+ self.model = AutoModelForMaskedLM.from_pretrained(self.model_path).to(
47
+ self.device
48
+ )
49
+ self.model.eval()
50
+
51
+ def get_bert_feature(self, text, word2ph):
52
+ if self.language == "JP":
53
+ text = "".join(text2sep_kata(text)[0])
54
+ with torch.no_grad():
55
+ inputs = self.tokenizer(text, return_tensors="pt")
56
+ for i in inputs:
57
+ inputs[i] = inputs[i].to(self.device)
58
+ res = self.model(**inputs, output_hidden_states=True)
59
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
60
+
61
+ word2phone = word2ph
62
+ phone_level_feature = []
63
+ for i in range(len(word2phone)):
64
+ repeat_feature = res[i].repeat(word2phone[i], 1)
65
+ phone_level_feature.append(repeat_feature)
66
+
67
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
68
+
69
+ return phone_level_feature.T
70
+
71
+
72
+ class ClapFeature:
73
+ def __init__(self, model_path):
74
+ self.model_path = model_path
75
+ self.processor = None
76
+ self.model = None
77
+ self.device = None
78
+
79
+ self._prepare()
80
+
81
+ def _get_device(self, device=config.bert_gen_config.device):
82
+ if (
83
+ sys.platform == "darwin"
84
+ and torch.backends.mps.is_available()
85
+ and device == "cpu"
86
+ ):
87
+ device = "mps"
88
+ if not device:
89
+ device = "cuda"
90
+ return device
91
+
92
+ def _prepare(self):
93
+ self.device = self._get_device()
94
+
95
+ self.processor = ClapProcessor.from_pretrained(self.model_path)
96
+ self.model = ClapModel.from_pretrained(self.model_path).to(self.device)
97
+ self.model.eval()
98
+
99
+ def get_clap_audio_feature(self, audio_data):
100
+ with torch.no_grad():
101
+ inputs = self.processor(
102
+ audios=audio_data, return_tensors="pt", sampling_rate=48000
103
+ ).to(self.device)
104
+ emb = self.model.get_audio_features(**inputs)
105
+ return emb.T
106
+
107
+ def get_clap_text_feature(self, text):
108
+ with torch.no_grad():
109
+ inputs = self.processor(text=text, return_tensors="pt").to(self.device)
110
+ emb = self.model.get_text_features(**inputs)
111
+ return emb.T
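A short usage sketch for the two wrappers above, assuming it runs next to this file so the plain `infer_utils` import used by for_deploy/webui.py resolves. The model directories are the ones webui.py points at; `word2ph` here is a fabricated alignment purely for illustration (in real use it comes from the text cleaner and must cover every phone):

from infer_utils import BertFeature, ClapFeature

zh_bert = BertFeature("./bert/chinese-roberta-wwm-ext-large", language="ZH")
clap = ClapFeature("./emotional/clap-htsat-fused")

text = "你好"
word2ph = [1, 1]                        # fabricated: one phone per character
bert_feat = zh_bert.get_bert_feature(text, word2ph)
print(bert_feat.shape)                  # torch.Size([1024, 2]) -> [hidden, sum(word2ph)]

emo = clap.get_clap_text_feature("Happy")
print(emo.shape)                        # torch.Size([512, 1]) -> transposed CLAP text embedding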
for_deploy/webui.py ADDED
@@ -0,0 +1,556 @@
1
+ # flake8: noqa: E402
2
+ import os
3
+ import logging
4
+ import re_matching
5
+ from tools.sentence import split_by_language
6
+
7
+ logging.getLogger("numba").setLevel(logging.WARNING)
8
+ logging.getLogger("markdown_it").setLevel(logging.WARNING)
9
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
10
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
11
+
12
+ logging.basicConfig(
13
+ level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
14
+ )
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ import torch
19
+ import utils
20
+ from infer import infer, latest_version, get_net_g, infer_multilang
21
+ import gradio as gr
22
+ import webbrowser
23
+ import numpy as np
24
+ from config import config
25
+ from tools.translate import translate
26
+ import librosa
27
+ from infer_utils import BertFeature, ClapFeature
28
+
29
+
30
+ net_g = None
31
+
32
+ device = config.webui_config.device
33
+ if device == "mps":
34
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
35
+
36
+ os.environ["OMP_NUM_THREADS"] = "1"
37
+ os.environ["MKL_NUM_THREADS"] = "1"
38
+
39
+ bert_feature_map = {
40
+ "ZH": BertFeature(
41
+ "./bert/chinese-roberta-wwm-ext-large",
42
+ language="ZH",
43
+ ),
44
+ "JP": BertFeature(
45
+ "./bert/deberta-v2-large-japanese-char-wwm",
46
+ language="JP",
47
+ ),
48
+ "EN": BertFeature(
49
+ "./bert/deberta-v3-large",
50
+ language="EN",
51
+ ),
52
+ }
53
+
54
+ clap_feature = ClapFeature("./emotional/clap-htsat-fused")
55
+
56
+
57
+ def generate_audio(
58
+ slices,
59
+ sdp_ratio,
60
+ noise_scale,
61
+ noise_scale_w,
62
+ length_scale,
63
+ speaker,
64
+ language,
65
+ reference_audio,
66
+ emotion,
67
+ skip_start=False,
68
+ skip_end=False,
69
+ ):
70
+ audio_list = []
71
+ # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
72
+ with torch.no_grad():
73
+ for idx, piece in enumerate(slices):
74
+ skip_start = (idx != 0) and skip_start
75
+ skip_end = (idx != len(slices) - 1) and skip_end
76
+ audio = infer(
77
+ piece,
78
+ reference_audio=reference_audio,
79
+ emotion=emotion,
80
+ sdp_ratio=sdp_ratio,
81
+ noise_scale=noise_scale,
82
+ noise_scale_w=noise_scale_w,
83
+ length_scale=length_scale,
84
+ sid=speaker,
85
+ language=language,
86
+ hps=hps,
87
+ net_g=net_g,
88
+ device=device,
89
+ skip_start=skip_start,
90
+ skip_end=skip_end,
91
+ bert=bert_feature_map,
92
+ clap=clap_feature,
93
+ )
94
+ audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
95
+ audio_list.append(audio16bit)
96
+ # audio_list.append(silence) # 将静音添加到列表中
97
+ return audio_list
98
+
99
+
100
+ def generate_audio_multilang(
101
+ slices,
102
+ sdp_ratio,
103
+ noise_scale,
104
+ noise_scale_w,
105
+ length_scale,
106
+ speaker,
107
+ language,
108
+ reference_audio,
109
+ emotion,
110
+ skip_start=False,
111
+ skip_end=False,
112
+ ):
113
+ audio_list = []
114
+ # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
115
+ with torch.no_grad():
116
+ for idx, piece in enumerate(slices):
117
+ skip_start = (idx != 0) and skip_start
118
+ skip_end = (idx != len(slices) - 1) and skip_end
119
+ audio = infer_multilang(
120
+ piece,
121
+ reference_audio=reference_audio,
122
+ emotion=emotion,
123
+ sdp_ratio=sdp_ratio,
124
+ noise_scale=noise_scale,
125
+ noise_scale_w=noise_scale_w,
126
+ length_scale=length_scale,
127
+ sid=speaker,
128
+ language=language[idx],
129
+ hps=hps,
130
+ net_g=net_g,
131
+ device=device,
132
+ skip_start=skip_start,
133
+ skip_end=skip_end,
134
+ )
135
+ audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
136
+ audio_list.append(audio16bit)
137
+ # audio_list.append(silence) # 将静音添加到列表中
138
+ return audio_list
139
+
140
+
141
+ def tts_split(
142
+ text: str,
143
+ speaker,
144
+ sdp_ratio,
145
+ noise_scale,
146
+ noise_scale_w,
147
+ length_scale,
148
+ language,
149
+ cut_by_sent,
150
+ interval_between_para,
151
+ interval_between_sent,
152
+ reference_audio,
153
+ emotion,
154
+ ):
155
+ if language == "mix":
156
+ return ("invalid", None)
157
+ while text.find("\n\n") != -1:
158
+ text = text.replace("\n\n", "\n")
159
+ para_list = re_matching.cut_para(text)
160
+ audio_list = []
161
+ if not cut_by_sent:
162
+ for idx, p in enumerate(para_list):
163
+ skip_start = idx != 0
164
+ skip_end = idx != len(para_list) - 1
165
+ audio = infer(
166
+ p,
167
+ reference_audio=reference_audio,
168
+ emotion=emotion,
169
+ sdp_ratio=sdp_ratio,
170
+ noise_scale=noise_scale,
171
+ noise_scale_w=noise_scale_w,
172
+ length_scale=length_scale,
173
+ sid=speaker,
174
+ language=language,
175
+ hps=hps,
176
+ net_g=net_g,
177
+ device=device,
178
+ skip_start=skip_start,
179
+ skip_end=skip_end,
180
+ )
181
+ audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
182
+ audio_list.append(audio16bit)
183
+ silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
184
+ audio_list.append(silence)
185
+ else:
186
+ for idx, p in enumerate(para_list):
187
+ skip_start = idx != 0
188
+ skip_end = idx != len(para_list) - 1
189
+ audio_list_sent = []
190
+ sent_list = re_matching.cut_sent(p)
191
+ for idx, s in enumerate(sent_list):
192
+ skip_start = (idx != 0) and skip_start
193
+ skip_end = (idx != len(sent_list) - 1) and skip_end
194
+ audio = infer(
195
+ s,
196
+ reference_audio=reference_audio,
197
+ emotion=emotion,
198
+ sdp_ratio=sdp_ratio,
199
+ noise_scale=noise_scale,
200
+ noise_scale_w=noise_scale_w,
201
+ length_scale=length_scale,
202
+ sid=speaker,
203
+ language=language,
204
+ hps=hps,
205
+ net_g=net_g,
206
+ device=device,
207
+ skip_start=skip_start,
208
+ skip_end=skip_end,
209
+ )
210
+ audio_list_sent.append(audio)
211
+ silence = np.zeros((int)(44100 * interval_between_sent))
212
+ audio_list_sent.append(silence)
213
+ if (interval_between_para - interval_between_sent) > 0:
214
+ silence = np.zeros(
215
+ (int)(44100 * (interval_between_para - interval_between_sent))
216
+ )
217
+ audio_list_sent.append(silence)
218
+ audio16bit = gr.processing_utils.convert_to_16_bit_wav(
219
+ np.concatenate(audio_list_sent)
220
+ ) # 对完整句子做音量归一
221
+ audio_list.append(audio16bit)
222
+ audio_concat = np.concatenate(audio_list)
223
+ return ("Success", (44100, audio_concat))
224
+
225
+
226
+ def tts_fn(
227
+ text: str,
228
+ speaker,
229
+ sdp_ratio,
230
+ noise_scale,
231
+ noise_scale_w,
232
+ length_scale,
233
+ language,
234
+ reference_audio,
235
+ emotion,
236
+ prompt_mode,
237
+ ):
238
+ if prompt_mode == "Audio prompt":
239
+ if reference_audio == None:
240
+ return ("Invalid audio prompt", None)
241
+ else:
242
+ reference_audio = load_audio(reference_audio)[1]
243
+ else:
244
+ reference_audio = None
245
+ audio_list = []
246
+ if language == "mix":
247
+ bool_valid, str_valid = re_matching.validate_text(text)
248
+ if not bool_valid:
249
+ return str_valid, (
250
+ hps.data.sampling_rate,
251
+ np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
252
+ )
253
+ result = []
254
+ for slice in re_matching.text_matching(text):
255
+ _speaker = slice.pop()
256
+ temp_contant = []
257
+ temp_lang = []
258
+ for lang, content in slice:
259
+ if "|" in content:
260
+ temp = []
261
+ temp_ = []
262
+ for i in content.split("|"):
263
+ if i != "":
264
+ temp.append([i])
265
+ temp_.append([lang])
266
+ else:
267
+ temp.append([])
268
+ temp_.append([])
269
+ temp_contant += temp
270
+ temp_lang += temp_
271
+ else:
272
+ if len(temp_contant) == 0:
273
+ temp_contant.append([])
274
+ temp_lang.append([])
275
+ temp_contant[-1].append(content)
276
+ temp_lang[-1].append(lang)
277
+ for i, j in zip(temp_lang, temp_contant):
278
+ result.append([*zip(i, j), _speaker])
279
+ for i, one in enumerate(result):
280
+ skip_start = i != 0
281
+ skip_end = i != len(result) - 1
282
+ _speaker = one.pop()
283
+ idx = 0
284
+ while idx < len(one):
285
+ text_to_generate = []
286
+ lang_to_generate = []
287
+ while True:
288
+ lang, content = one[idx]
289
+ temp_text = [content]
290
+ if len(text_to_generate) > 0:
291
+ text_to_generate[-1] += [temp_text.pop(0)]
292
+ lang_to_generate[-1] += [lang]
293
+ if len(temp_text) > 0:
294
+ text_to_generate += [[i] for i in temp_text]
295
+ lang_to_generate += [[lang]] * len(temp_text)
296
+ if idx + 1 < len(one):
297
+ idx += 1
298
+ else:
299
+ break
300
+ skip_start = (idx != 0) and skip_start
301
+ skip_end = (idx != len(one) - 1) and skip_end
302
+ print(text_to_generate, lang_to_generate)
303
+ audio_list.extend(
304
+ generate_audio_multilang(
305
+ text_to_generate,
306
+ sdp_ratio,
307
+ noise_scale,
308
+ noise_scale_w,
309
+ length_scale,
310
+ _speaker,
311
+ lang_to_generate,
312
+ reference_audio,
313
+ emotion,
314
+ skip_start,
315
+ skip_end,
316
+ )
317
+ )
318
+ idx += 1
319
+ elif language.lower() == "auto":
320
+ for idx, slice in enumerate(text.split("|")):
321
+ if slice == "":
322
+ continue
323
+ skip_start = idx != 0
324
+ skip_end = idx != len(text.split("|")) - 1
325
+ sentences_list = split_by_language(
326
+ slice, target_languages=["zh", "ja", "en"]
327
+ )
328
+ idx = 0
329
+ while idx < len(sentences_list):
330
+ text_to_generate = []
331
+ lang_to_generate = []
332
+ while True:
333
+ content, lang = sentences_list[idx]
334
+ temp_text = [content]
335
+ lang = lang.upper()
336
+ if lang == "JA":
337
+ lang = "JP"
338
+ if len(text_to_generate) > 0:
339
+ text_to_generate[-1] += [temp_text.pop(0)]
340
+ lang_to_generate[-1] += [lang]
341
+ if len(temp_text) > 0:
342
+ text_to_generate += [[i] for i in temp_text]
343
+ lang_to_generate += [[lang]] * len(temp_text)
344
+ if idx + 1 < len(sentences_list):
345
+ idx += 1
346
+ else:
347
+ break
348
+ skip_start = (idx != 0) and skip_start
349
+ skip_end = (idx != len(sentences_list) - 1) and skip_end
350
+ print(text_to_generate, lang_to_generate)
351
+ audio_list.extend(
352
+ generate_audio_multilang(
353
+ text_to_generate,
354
+ sdp_ratio,
355
+ noise_scale,
356
+ noise_scale_w,
357
+ length_scale,
358
+ speaker,
359
+ lang_to_generate,
360
+ reference_audio,
361
+ emotion,
362
+ skip_start,
363
+ skip_end,
364
+ )
365
+ )
366
+ idx += 1
367
+ else:
368
+ audio_list.extend(
369
+ generate_audio(
370
+ text.split("|"),
371
+ sdp_ratio,
372
+ noise_scale,
373
+ noise_scale_w,
374
+ length_scale,
375
+ speaker,
376
+ language,
377
+ reference_audio,
378
+ emotion,
379
+ )
380
+ )
381
+
382
+ audio_concat = np.concatenate(audio_list)
383
+ return "Success", (hps.data.sampling_rate, audio_concat)
384
+
385
+
386
+ def load_audio(path):
387
+ audio, sr = librosa.load(path, 48000)
388
+ # audio = librosa.resample(audio, 44100, 48000)
389
+ return sr, audio
390
+
391
+
392
+ def gr_util(item):
393
+ if item == "Text prompt":
394
+ return {"visible": True, "__type__": "update"}, {
395
+ "visible": False,
396
+ "__type__": "update",
397
+ }
398
+ else:
399
+ return {"visible": False, "__type__": "update"}, {
400
+ "visible": True,
401
+ "__type__": "update",
402
+ }
403
+
404
+
405
+ if __name__ == "__main__":
406
+ if config.webui_config.debug:
407
+ logger.info("Enable DEBUG-LEVEL log")
408
+ logging.basicConfig(level=logging.DEBUG)
409
+ hps = utils.get_hparams_from_file(config.webui_config.config_path)
410
+ # 若config.json中未指定版本则默认为最新版本
411
+ version = hps.version if hasattr(hps, "version") else latest_version
412
+ net_g = get_net_g(
413
+ model_path=config.webui_config.model, version=version, device=device, hps=hps
414
+ )
415
+ speaker_ids = hps.data.spk2id
416
+ speakers = list(speaker_ids.keys())
417
+ languages = ["ZH", "JP", "EN", "mix", "auto"]
418
+ with gr.Blocks() as app:
419
+ with gr.Row():
420
+ with gr.Column():
421
+ text = gr.TextArea(
422
+ label="输入文本内容",
423
+ placeholder="""
424
+ 如果你选择语言为\'mix\',必须按照格式输入,否则报错:
425
+ 格式举例(zh是中文,jp是日语,不区分大小写;说话人举例:gongzi):
426
+ [说话人1]<zh>你好,こんにちは! <jp>こんにちは,世界。
427
+ [说话人2]<zh>你好吗?<jp>元気ですか?
428
+ [说话人3]<zh>谢谢。<jp>どういたしまして。
429
+ ...
430
+ 另外,所有的语言选项都可以用'|'分割长段实现分句生成。
431
+ """,
432
+ )
433
+ trans = gr.Button("中翻日", variant="primary")
434
+ slicer = gr.Button("快速切分", variant="primary")
435
+ speaker = gr.Dropdown(
436
+ choices=speakers, value=speakers[0], label="Speaker"
437
+ )
438
+ _ = gr.Markdown(
439
+ value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n"
440
+ )
441
+ prompt_mode = gr.Radio(
442
+ ["Text prompt", "Audio prompt"],
443
+ label="Prompt Mode",
444
+ value="Text prompt",
445
+ )
446
+ text_prompt = gr.Textbox(
447
+ label="Text prompt",
448
+ placeholder="用文字描述生成风格。如:Happy",
449
+ value="Happy",
450
+ visible=True,
451
+ )
452
+ audio_prompt = gr.Audio(
453
+ label="Audio prompt", type="filepath", visible=False
454
+ )
455
+ sdp_ratio = gr.Slider(
456
+ minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
457
+ )
458
+ noise_scale = gr.Slider(
459
+ minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
460
+ )
461
+ noise_scale_w = gr.Slider(
462
+ minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise_W"
463
+ )
464
+ length_scale = gr.Slider(
465
+ minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
466
+ )
467
+ language = gr.Dropdown(
468
+ choices=languages, value=languages[0], label="Language"
469
+ )
470
+ btn = gr.Button("生成音频!", variant="primary")
471
+ with gr.Column():
472
+ with gr.Row():
473
+ with gr.Column():
474
+ interval_between_sent = gr.Slider(
475
+ minimum=0,
476
+ maximum=5,
477
+ value=0.2,
478
+ step=0.1,
479
+ label="句间停顿(秒),勾选按句切分才生效",
480
+ )
481
+ interval_between_para = gr.Slider(
482
+ minimum=0,
483
+ maximum=10,
484
+ value=1,
485
+ step=0.1,
486
+ label="段间停顿(秒),需要大于句间停顿才有效",
487
+ )
488
+ opt_cut_by_sent = gr.Checkbox(
489
+ label="按句切分 在按段落切分的基础上再按句子切分文本"
490
+ )
491
+ slicer = gr.Button("切分生成", variant="primary")
492
+ text_output = gr.Textbox(label="状态信息")
493
+ audio_output = gr.Audio(label="输出音频")
494
+ # explain_image = gr.Image(
495
+ # label="参数解释信息",
496
+ # show_label=True,
497
+ # show_share_button=False,
498
+ # show_download_button=False,
499
+ # value=os.path.abspath("./img/参数说明.png"),
500
+ # )
501
+ btn.click(
502
+ tts_fn,
503
+ inputs=[
504
+ text,
505
+ speaker,
506
+ sdp_ratio,
507
+ noise_scale,
508
+ noise_scale_w,
509
+ length_scale,
510
+ language,
511
+ audio_prompt,
512
+ text_prompt,
513
+ prompt_mode,
514
+ ],
515
+ outputs=[text_output, audio_output],
516
+ )
517
+
518
+ trans.click(
519
+ translate,
520
+ inputs=[text],
521
+ outputs=[text],
522
+ )
523
+ slicer.click(
524
+ tts_split,
525
+ inputs=[
526
+ text,
527
+ speaker,
528
+ sdp_ratio,
529
+ noise_scale,
530
+ noise_scale_w,
531
+ length_scale,
532
+ language,
533
+ opt_cut_by_sent,
534
+ interval_between_para,
535
+ interval_between_sent,
536
+ audio_prompt,
537
+ text_prompt,
538
+ ],
539
+ outputs=[text_output, audio_output],
540
+ )
541
+
542
+ prompt_mode.change(
543
+ lambda x: gr_util(x),
544
+ inputs=[prompt_mode],
545
+ outputs=[text_prompt, audio_prompt],
546
+ )
547
+
548
+ audio_prompt.upload(
549
+ lambda x: load_audio(x),
550
+ inputs=[audio_prompt],
551
+ outputs=[audio_prompt],
552
+ )
553
+
554
+ print("推理页面已开启!")
555
+ webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
556
+ app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
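One portability note on the load_audio helper above: it calls `librosa.load(path, 48000)` with the sample rate as a positional argument. In librosa 0.10 and later, `sr` is keyword-only, so on newer installations the call would need the form below (a suggested adjustment, not part of this commit):

import librosa

def load_audio(path):
    # Resample the reference/prompt audio to 48 kHz, matching the sampling_rate the CLAP processor expects.
    audio, sr = librosa.load(path, sr=48000)
    return sr, audio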
infer.py CHANGED
@@ -10,7 +10,8 @@
10
  import torch
11
  import commons
12
  from text import cleaned_text_to_sequence, get_bert
13
- from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
 
14
  from text.cleaner import clean_text
15
  import utils
16
  import numpy as np
@@ -20,47 +21,47 @@ from text.symbols import symbols
20
 
21
  # from oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
22
  # from oldVersion.V210.text import symbols as V210symbols
23
- from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
24
- from oldVersion.V200.text import symbols as V200symbols
25
- from oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn
26
- from oldVersion.V111.text import symbols as V111symbols
27
- from oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn
28
- from oldVersion.V110.text import symbols as V110symbols
29
- from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
30
- from oldVersion.V101.text import symbols as V101symbols
31
 
32
- from oldVersion import V111, V110, V101, V200 # , V210
33
 
34
  # 当前版本信息
35
- latest_version = "2.2"
36
 
37
  # 版本兼容
38
  SynthesizerTrnMap = {
39
  # "2.1": V210SynthesizerTrn,
40
- "2.0.2-fix": V200SynthesizerTrn,
41
- "2.0.1": V200SynthesizerTrn,
42
- "2.0": V200SynthesizerTrn,
43
- "1.1.1-fix": V111SynthesizerTrn,
44
- "1.1.1": V111SynthesizerTrn,
45
- "1.1": V110SynthesizerTrn,
46
- "1.1.0": V110SynthesizerTrn,
47
- "1.0.1": V101SynthesizerTrn,
48
- "1.0": V101SynthesizerTrn,
49
- "1.0.0": V101SynthesizerTrn,
50
  }
51
 
52
  symbolsMap = {
53
  # "2.1": V210symbols,
54
- "2.0.2-fix": V200symbols,
55
- "2.0.1": V200symbols,
56
- "2.0": V200symbols,
57
- "1.1.1-fix": V111symbols,
58
- "1.1.1": V111symbols,
59
- "1.1": V110symbols,
60
- "1.1.0": V110symbols,
61
- "1.0.1": V101symbols,
62
- "1.0": V101symbols,
63
- "1.0.0": V101symbols,
64
  }
65
 
66
 
@@ -98,7 +99,8 @@ def get_net_g(model_path: str, version: str, device: str, hps):
98
  return net_g
99
 
100
 
101
- def get_text(text, language_str, hps, device):
 
102
  # 在此处实现当前版本的get_text
103
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
104
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
@@ -110,21 +112,23 @@ def get_text(text, language_str, hps, device):
110
  for i in range(len(word2ph)):
111
  word2ph[i] = word2ph[i] * 2
112
  word2ph[0] += 1
113
- bert_ori = get_bert(norm_text, word2ph, language_str, device)
 
 
114
  del word2ph
115
  assert bert_ori.shape[-1] == len(phone), phone
116
 
117
  if language_str == "ZH":
118
  bert = bert_ori
119
- ja_bert = torch.zeros(1024, len(phone))
120
- en_bert = torch.zeros(1024, len(phone))
121
  elif language_str == "JP":
122
- bert = torch.zeros(1024, len(phone))
123
  ja_bert = bert_ori
124
- en_bert = torch.zeros(1024, len(phone))
125
  elif language_str == "EN":
126
- bert = torch.zeros(1024, len(phone))
127
- ja_bert = torch.zeros(1024, len(phone))
128
  en_bert = bert_ori
129
  else:
130
  raise ValueError("language_str should be ZH, JP or EN")
@@ -154,49 +158,54 @@ def infer(
154
  reference_audio=None,
155
  skip_start=False,
156
  skip_end=False,
 
 
157
  ):
158
  # 2.2版本参数位置变了
159
  # 2.1 参数新增 emotion reference_audio skip_start skip_end
160
  # inferMap_V3 = {
161
  # "2.1": V210.infer,
162
- # }
163
  # 支持中日英三语版本
164
  inferMap_V2 = {
165
- "2.0.2-fix": V200.infer,
166
- "2.0.1": V200.infer,
167
- "2.0": V200.infer,
168
- "1.1.1-fix": V111.infer_fix,
169
- "1.1.1": V111.infer,
170
- "1.1": V110.infer,
171
- "1.1.0": V110.infer,
172
  }
173
  # 仅支持中文版本
174
  # 在测试中,并未发现两个版本的模型不能互相通用
175
  inferMap_V1 = {
176
- "1.0.1": V101.infer,
177
- "1.0": V101.infer,
178
- "1.0.0": V101.infer,
179
  }
180
  version = hps.version if hasattr(hps, "version") else latest_version
181
  # 非当前版本,根据版本号选择合适的infer
182
  if version != latest_version:
183
- # if version in inferMap_V3.keys():
184
- # return inferMap_V3[version](
185
- # text,
186
- # sdp_ratio,
187
- # noise_scale,
188
- # noise_scale_w,
189
- # length_scale,
190
- # sid,
191
- # language,
192
- # hps,
193
- # net_g,
194
- # device,
195
- # reference_audio,
196
- # emotion,
197
- # skip_start,
198
- # skip_end,
199
- # )
 
 
 
200
  if version in inferMap_V2.keys():
201
  return inferMap_V2[version](
202
  text,
@@ -224,14 +233,19 @@ def infer(
224
  )
225
  # 在此处实现当前版本的推理
226
  # emo = get_emo_(reference_audio, emotion, sid)
227
- if isinstance(reference_audio, np.ndarray):
228
- emo = get_clap_audio_feature(reference_audio, device)
229
- else:
230
- emo = get_clap_text_feature(emotion, device)
231
- emo = torch.squeeze(emo, dim=1)
232
 
233
  bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
234
- text, language, hps, device
 
 
 
 
 
235
  )
236
  if skip_start:
237
  phones = phones[3:]
@@ -255,7 +269,7 @@ def infer(
255
  ja_bert = ja_bert.to(device).unsqueeze(0)
256
  en_bert = en_bert.to(device).unsqueeze(0)
257
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
258
- emo = emo.to(device).unsqueeze(0)
259
  del phones
260
  speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
261
  audio = (
@@ -268,7 +282,6 @@ def infer(
268
  bert,
269
  ja_bert,
270
  en_bert,
271
- emo,
272
  sdp_ratio=sdp_ratio,
273
  noise_scale=noise_scale,
274
  noise_scale_w=noise_scale_w,
@@ -278,7 +291,16 @@ def infer(
278
  .float()
279
  .numpy()
280
  )
281
- del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
 
 
 
 
 
 
 
 
 
282
  if torch.cuda.is_available():
283
  torch.cuda.empty_cache()
284
  return audio
@@ -302,14 +324,14 @@ def infer_multilang(
302
  ):
303
  bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
304
  # emo = get_emo_(reference_audio, emotion, sid)
305
- if isinstance(reference_audio, np.ndarray):
306
- emo = get_clap_audio_feature(reference_audio, device)
307
- else:
308
- emo = get_clap_text_feature(emotion, device)
309
- emo = torch.squeeze(emo, dim=1)
310
  for idx, (txt, lang) in enumerate(zip(text, language)):
311
- skip_start = (idx != 0) or (skip_start and idx == 0)
312
- skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
313
  (
314
  temp_bert,
315
  temp_ja_bert,
@@ -318,14 +340,14 @@ def infer_multilang(
318
  temp_tones,
319
  temp_lang_ids,
320
  ) = get_text(txt, lang, hps, device)
321
- if skip_start:
322
  temp_bert = temp_bert[:, 3:]
323
  temp_ja_bert = temp_ja_bert[:, 3:]
324
  temp_en_bert = temp_en_bert[:, 3:]
325
  temp_phones = temp_phones[3:]
326
  temp_tones = temp_tones[3:]
327
  temp_lang_ids = temp_lang_ids[3:]
328
- if skip_end:
329
  temp_bert = temp_bert[:, :-2]
330
  temp_ja_bert = temp_ja_bert[:, :-2]
331
  temp_en_bert = temp_en_bert[:, :-2]
@@ -351,7 +373,7 @@ def infer_multilang(
351
  bert = bert.to(device).unsqueeze(0)
352
  ja_bert = ja_bert.to(device).unsqueeze(0)
353
  en_bert = en_bert.to(device).unsqueeze(0)
354
- emo = emo.to(device).unsqueeze(0)
355
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
356
  del phones
357
  speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
@@ -365,7 +387,6 @@ def infer_multilang(
365
  bert,
366
  ja_bert,
367
  en_bert,
368
- emo,
369
  sdp_ratio=sdp_ratio,
370
  noise_scale=noise_scale,
371
  noise_scale_w=noise_scale_w,
@@ -375,7 +396,16 @@ def infer_multilang(
375
  .float()
376
  .numpy()
377
  )
378
- del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
 
 
 
 
 
 
 
 
 
379
  if torch.cuda.is_available():
380
  torch.cuda.empty_cache()
381
  return audio
 
10
  import torch
11
  import commons
12
  from text import cleaned_text_to_sequence, get_bert
13
+
14
+ # from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
15
  from text.cleaner import clean_text
16
  import utils
17
  import numpy as np
 
21
 
22
  # from oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
23
  # from oldVersion.V210.text import symbols as V210symbols
24
+ # from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
25
+ # from oldVersion.V200.text import symbols as V200symbols
26
+ # from oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn
27
+ # from oldVersion.V111.text import symbols as V111symbols
28
+ # from oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn
29
+ # from oldVersion.V110.text import symbols as V110symbols
30
+ # from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
31
+ # from oldVersion.V101.text import symbols as V101symbols
32
 
33
+ # from oldVersion import V111, V110, V101, V200, V210
34
 
35
  # 当前版本信息
36
+ latest_version = "2.3"
37
 
38
  # 版本兼容
39
  SynthesizerTrnMap = {
40
  # "2.1": V210SynthesizerTrn,
41
+ # "2.0.2-fix": V200SynthesizerTrn,
42
+ # "2.0.1": V200SynthesizerTrn,
43
+ # "2.0": V200SynthesizerTrn,
44
+ # "1.1.1-fix": V111SynthesizerTrn,
45
+ # "1.1.1": V111SynthesizerTrn,
46
+ # "1.1": V110SynthesizerTrn,
47
+ # "1.1.0": V110SynthesizerTrn,
48
+ # "1.0.1": V101SynthesizerTrn,
49
+ # "1.0": V101SynthesizerTrn,
50
+ # "1.0.0": V101SynthesizerTrn,
51
  }
52
 
53
  symbolsMap = {
54
  # "2.1": V210symbols,
55
+ # "2.0.2-fix": V200symbols,
56
+ # "2.0.1": V200symbols,
57
+ # "2.0": V200symbols,
58
+ # "1.1.1-fix": V111symbols,
59
+ # "1.1.1": V111symbols,
60
+ # "1.1": V110symbols,
61
+ # "1.1.0": V110symbols,
62
+ # "1.0.1": V101symbols,
63
+ # "1.0": V101symbols,
64
+ # "1.0.0": V101symbols,
65
  }
66
 
67
 
 
99
  return net_g
100
 
101
 
102
+ def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
103
+ style_text = None if style_text == "" else style_text
104
  # 在此处实现当前版本的get_text
105
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
106
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
 
112
  for i in range(len(word2ph)):
113
  word2ph[i] = word2ph[i] * 2
114
  word2ph[0] += 1
115
+ bert_ori = get_bert(
116
+ norm_text, word2ph, language_str, device, style_text, style_weight
117
+ )
118
  del word2ph
119
  assert bert_ori.shape[-1] == len(phone), phone
120
 
121
  if language_str == "ZH":
122
  bert = bert_ori
123
+ ja_bert = torch.randn(1024, len(phone))
124
+ en_bert = torch.randn(1024, len(phone))
125
  elif language_str == "JP":
126
+ bert = torch.randn(1024, len(phone))
127
  ja_bert = bert_ori
128
+ en_bert = torch.randn(1024, len(phone))
129
  elif language_str == "EN":
130
+ bert = torch.randn(1024, len(phone))
131
+ ja_bert = torch.randn(1024, len(phone))
132
  en_bert = bert_ori
133
  else:
134
  raise ValueError("language_str should be ZH, JP or EN")
 
158
  reference_audio=None,
159
  skip_start=False,
160
  skip_end=False,
161
+ style_text=None,
162
+ style_weight=0.7,
163
  ):
164
  # 2.2版本参数位置变了
165
  # 2.1 参数新增 emotion reference_audio skip_start skip_end
166
  # inferMap_V3 = {
167
  # "2.1": V210.infer,
168
+ # }
169
  # 支持中日英三语版本
170
  inferMap_V2 = {
171
+ # "2.0.2-fix": V200.infer,
172
+ # "2.0.1": V200.infer,
173
+ # "2.0": V200.infer,
174
+ # "1.1.1-fix": V111.infer_fix,
175
+ # "1.1.1": V111.infer,
176
+ # "1.1": V110.infer,
177
+ # "1.1.0": V110.infer,
178
  }
179
  # 仅支持中文版本
180
  # 在测试中,并未发现两个版本的模型不能互相通用
181
  inferMap_V1 = {
182
+ # "1.0.1": V101.infer,
183
+ # "1.0": V101.infer,
184
+ # "1.0.0": V101.infer,
185
  }
186
  version = hps.version if hasattr(hps, "version") else latest_version
187
  # 非当前版本,根据版本号选择合适的infer
188
  if version != latest_version:
189
+ if version in inferMap_V3.keys():
190
+ emotion = 0
191
+ return inferMap_V3[version](
192
+ text,
193
+ sdp_ratio,
194
+ noise_scale,
195
+ noise_scale_w,
196
+ length_scale,
197
+ sid,
198
+ language,
199
+ hps,
200
+ net_g,
201
+ device,
202
+ reference_audio,
203
+ emotion,
204
+ skip_start,
205
+ skip_end,
206
+ style_text,
207
+ style_weight,
208
+ )
209
  if version in inferMap_V2.keys():
210
  return inferMap_V2[version](
211
  text,
 
233
  )
234
  # 在此处实现当前版本的推理
235
  # emo = get_emo_(reference_audio, emotion, sid)
236
+ # if isinstance(reference_audio, np.ndarray):
237
+ # emo = get_clap_audio_feature(reference_audio, device)
238
+ # else:
239
+ # emo = get_clap_text_feature(emotion, device)
240
+ # emo = torch.squeeze(emo, dim=1)
241
 
242
  bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
243
+ text,
244
+ language,
245
+ hps,
246
+ device,
247
+ style_text=style_text,
248
+ style_weight=style_weight,
249
  )
250
  if skip_start:
251
  phones = phones[3:]
 
269
  ja_bert = ja_bert.to(device).unsqueeze(0)
270
  en_bert = en_bert.to(device).unsqueeze(0)
271
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
272
+ # emo = emo.to(device).unsqueeze(0)
273
  del phones
274
  speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
275
  audio = (
 
282
  bert,
283
  ja_bert,
284
  en_bert,
 
285
  sdp_ratio=sdp_ratio,
286
  noise_scale=noise_scale,
287
  noise_scale_w=noise_scale_w,
 
291
  .float()
292
  .numpy()
293
  )
294
+ del (
295
+ x_tst,
296
+ tones,
297
+ lang_ids,
298
+ bert,
299
+ x_tst_lengths,
300
+ speakers,
301
+ ja_bert,
302
+ en_bert,
303
+ ) # , emo
304
  if torch.cuda.is_available():
305
  torch.cuda.empty_cache()
306
  return audio
 
324
  ):
325
  bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
326
  # emo = get_emo_(reference_audio, emotion, sid)
327
+ # if isinstance(reference_audio, np.ndarray):
328
+ # emo = get_clap_audio_feature(reference_audio, device)
329
+ # else:
330
+ # emo = get_clap_text_feature(emotion, device)
331
+ # emo = torch.squeeze(emo, dim=1)
332
  for idx, (txt, lang) in enumerate(zip(text, language)):
333
+ _skip_start = (idx != 0) or (skip_start and idx == 0)
334
+ _skip_end = (idx != len(language) - 1) or skip_end
335
  (
336
  temp_bert,
337
  temp_ja_bert,
 
340
  temp_tones,
341
  temp_lang_ids,
342
  ) = get_text(txt, lang, hps, device)
343
+ if _skip_start:
344
  temp_bert = temp_bert[:, 3:]
345
  temp_ja_bert = temp_ja_bert[:, 3:]
346
  temp_en_bert = temp_en_bert[:, 3:]
347
  temp_phones = temp_phones[3:]
348
  temp_tones = temp_tones[3:]
349
  temp_lang_ids = temp_lang_ids[3:]
350
+ if _skip_end:
351
  temp_bert = temp_bert[:, :-2]
352
  temp_ja_bert = temp_ja_bert[:, :-2]
353
  temp_en_bert = temp_en_bert[:, :-2]
 
373
  bert = bert.to(device).unsqueeze(0)
374
  ja_bert = ja_bert.to(device).unsqueeze(0)
375
  en_bert = en_bert.to(device).unsqueeze(0)
376
+ # emo = emo.to(device).unsqueeze(0)
377
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
378
  del phones
379
  speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
 
387
  bert,
388
  ja_bert,
389
  en_bert,
 
390
  sdp_ratio=sdp_ratio,
391
  noise_scale=noise_scale,
392
  noise_scale_w=noise_scale_w,
 
396
  .float()
397
  .numpy()
398
  )
399
+ del (
400
+ x_tst,
401
+ tones,
402
+ lang_ids,
403
+ bert,
404
+ x_tst_lengths,
405
+ speakers,
406
+ ja_bert,
407
+ en_bert,
408
+ ) # , emo
409
  if torch.cuda.is_available():
410
  torch.cuda.empty_cache()
411
  return audio
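
For readers following the infer.py changes above: the emotion/CLAP inputs are gone in 2.3, and two new keyword arguments, style_text and style_weight, are threaded through get_text(). A minimal usage sketch, using keyword arguments only since the full parameter order is not visible in these hunks; it assumes get_net_g() earlier in infer.py keeps its (model_path, version, device, hps) signature and that utils.get_hparams_from_file is available, and the paths/speaker are examples to be pointed at this Space's bundled files:

    import torch
    import utils
    from infer import get_net_g, infer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    hps = utils.get_hparams_from_file("Data/Azusa/config.json")          # example path
    net_g = get_net_g("Data/Azusa/models/G_11300.pth", "2.3", device, hps)

    audio = infer(
        text="你好,世界",
        emotion=None,                 # emotion / reference_audio are no longer used in 2.3
        sdp_ratio=0.2,
        noise_scale=0.6,
        noise_scale_w=0.8,
        length_scale=1.0,
        sid=list(hps.data.spk2id.keys())[0],   # any speaker name present in spk2id
        language="ZH",
        hps=hps,
        net_g=net_g,
        device=device,
        style_text="今天心情特别好",   # optional prompt, blended into the BERT features
        style_weight=0.7,
    )
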
losses.py CHANGED
@@ -1,4 +1,6 @@
1
  import torch
 
 
2
 
3
 
4
  def feature_loss(fmap_r, fmap_g):
@@ -56,3 +58,96 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
56
  kl = torch.sum(kl * z_mask)
57
  l = kl / torch.sum(z_mask)
58
  return l
1
  import torch
2
+ import torchaudio
3
+ from transformers import AutoModel
4
 
5
 
6
  def feature_loss(fmap_r, fmap_g):
 
58
  kl = torch.sum(kl * z_mask)
59
  l = kl / torch.sum(z_mask)
60
  return l
61
+
62
+
63
+ class WavLMLoss(torch.nn.Module):
64
+ def __init__(self, model, wd, model_sr, slm_sr=16000):
65
+ super(WavLMLoss, self).__init__()
66
+ self.wavlm = AutoModel.from_pretrained(model)
67
+ self.wd = wd
68
+ self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
69
+ self.wavlm.eval()
70
+ for param in self.wavlm.parameters():
71
+ param.requires_grad = False
72
+
73
+ def forward(self, wav, y_rec):
74
+ with torch.no_grad():
75
+ wav_16 = self.resample(wav)
76
+ wav_embeddings = self.wavlm(
77
+ input_values=wav_16, output_hidden_states=True
78
+ ).hidden_states
79
+ y_rec_16 = self.resample(y_rec)
80
+ y_rec_embeddings = self.wavlm(
81
+ input_values=y_rec_16.squeeze(), output_hidden_states=True
82
+ ).hidden_states
83
+
84
+ floss = 0
85
+ for er, eg in zip(wav_embeddings, y_rec_embeddings):
86
+ floss += torch.mean(torch.abs(er - eg))
87
+
88
+ return floss.mean()
89
+
90
+ def generator(self, y_rec):
91
+ y_rec_16 = self.resample(y_rec)
92
+ y_rec_embeddings = self.wavlm(
93
+ input_values=y_rec_16, output_hidden_states=True
94
+ ).hidden_states
95
+ y_rec_embeddings = (
96
+ torch.stack(y_rec_embeddings, dim=1)
97
+ .transpose(-1, -2)
98
+ .flatten(start_dim=1, end_dim=2)
99
+ )
100
+ y_df_hat_g = self.wd(y_rec_embeddings)
101
+ loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
102
+
103
+ return loss_gen
104
+
105
+ def discriminator(self, wav, y_rec):
106
+ with torch.no_grad():
107
+ wav_16 = self.resample(wav)
108
+ wav_embeddings = self.wavlm(
109
+ input_values=wav_16, output_hidden_states=True
110
+ ).hidden_states
111
+ y_rec_16 = self.resample(y_rec)
112
+ y_rec_embeddings = self.wavlm(
113
+ input_values=y_rec_16, output_hidden_states=True
114
+ ).hidden_states
115
+
116
+ y_embeddings = (
117
+ torch.stack(wav_embeddings, dim=1)
118
+ .transpose(-1, -2)
119
+ .flatten(start_dim=1, end_dim=2)
120
+ )
121
+ y_rec_embeddings = (
122
+ torch.stack(y_rec_embeddings, dim=1)
123
+ .transpose(-1, -2)
124
+ .flatten(start_dim=1, end_dim=2)
125
+ )
126
+
127
+ y_d_rs = self.wd(y_embeddings)
128
+ y_d_gs = self.wd(y_rec_embeddings)
129
+
130
+ y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
131
+
132
+ r_loss = torch.mean((1 - y_df_hat_r) ** 2)
133
+ g_loss = torch.mean((y_df_hat_g) ** 2)
134
+
135
+ loss_disc_f = r_loss + g_loss
136
+
137
+ return loss_disc_f.mean()
138
+
139
+ def discriminator_forward(self, wav):
140
+ with torch.no_grad():
141
+ wav_16 = self.resample(wav)
142
+ wav_embeddings = self.wavlm(
143
+ input_values=wav_16, output_hidden_states=True
144
+ ).hidden_states
145
+ y_embeddings = (
146
+ torch.stack(wav_embeddings, dim=1)
147
+ .transpose(-1, -2)
148
+ .flatten(start_dim=1, end_dim=2)
149
+ )
150
+
151
+ y_d_rs = self.wd(y_embeddings)
152
+
153
+ return y_d_rs
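
The new WavLMLoss bundles the SLM-based losses used during training: calling the module is a WavLM feature-matching term, while generator()/discriminator() drive the WavLMDiscriminator added in models.py below. A wiring sketch only, not the training loop; the checkpoint id is just an example of a WavLM model loadable by AutoModel, and both waveforms are assumed to be [B, T] tensors at the model sample rate:

    import torch
    from losses import WavLMLoss
    from models import WavLMDiscriminator

    wd = WavLMDiscriminator()                      # slm_hidden=768, slm_layers=13 by default
    wl = WavLMLoss("microsoft/wavlm-base-plus",    # example checkpoint id
                   wd, model_sr=44100)             # audio is resampled internally to 16 kHz

    wav = torch.randn(2, 44100)                    # reference audio, [B, T]
    y_hat = torch.randn(2, 44100)                  # generated audio, [B, T]

    loss_slm_d = wl.discriminator(wav, y_hat.detach())  # term for updating wd
    loss_slm_g = wl.generator(y_hat)                     # adversarial term for the generator
    loss_slm_fm = wl(wav, y_hat)                         # WavLM feature-matching term
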
models.py CHANGED
@@ -40,33 +40,22 @@ class DurationDiscriminator(nn.Module): # vits2
40
  self.norm_2 = modules.LayerNorm(filter_channels)
41
  self.dur_proj = nn.Conv1d(1, filter_channels, 1)
42
 
43
- self.pre_out_conv_1 = nn.Conv1d(
44
- 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
45
  )
46
- self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
47
- self.pre_out_conv_2 = nn.Conv1d(
48
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
49
- )
50
- self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
51
 
52
  if gin_channels != 0:
53
  self.cond = nn.Conv1d(gin_channels, in_channels, 1)
54
 
55
- self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
 
 
56
 
57
- def forward_probability(self, x, x_mask, dur, g=None):
58
  dur = self.dur_proj(dur)
59
  x = torch.cat([x, dur], dim=1)
60
- x = self.pre_out_conv_1(x * x_mask)
61
- x = torch.relu(x)
62
- x = self.pre_out_norm_1(x)
63
- x = self.drop(x)
64
- x = self.pre_out_conv_2(x * x_mask)
65
- x = torch.relu(x)
66
- x = self.pre_out_norm_2(x)
67
- x = self.drop(x)
68
- x = x * x_mask
69
  x = x.transpose(1, 2)
 
70
  output_prob = self.output_layer(x)
71
  return output_prob
72
 
@@ -86,7 +75,7 @@ class DurationDiscriminator(nn.Module): # vits2
86
 
87
  output_probs = []
88
  for dur in [dur_r, dur_hat]:
89
- output_prob = self.forward_probability(x, x_mask, dur, g)
90
  output_probs.append(output_prob)
91
 
92
  return output_probs
@@ -354,7 +343,6 @@ class TextEncoder(nn.Module):
354
  n_layers,
355
  kernel_size,
356
  p_dropout,
357
- n_speakers,
358
  gin_channels=0,
359
  ):
360
  super().__init__()
@@ -376,31 +364,6 @@ class TextEncoder(nn.Module):
376
  self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
377
  self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
378
  self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
379
- # self.emo_proj = nn.Linear(512, hidden_channels)
380
- self.in_feature_net = nn.Sequential(
381
- # input is assumed to an already normalized embedding
382
- nn.Linear(512, 1028, bias=False),
383
- nn.GELU(),
384
- nn.LayerNorm(1028),
385
- *[Block(1028, 512) for _ in range(1)],
386
- nn.Linear(1028, 512, bias=False),
387
- # normalize before passing to VQ?
388
- # nn.GELU(),
389
- # nn.LayerNorm(512),
390
- )
391
- self.emo_vq = VectorQuantize(
392
- dim=512,
393
- codebook_size=64,
394
- codebook_dim=32,
395
- commitment_weight=0.1,
396
- decay=0.85,
397
- heads=32,
398
- kmeans_iters=20,
399
- separate_codebook_per_head=True,
400
- stochastic_sample_codes=True,
401
- threshold_ema_dead_code=2,
402
- )
403
- self.out_feature_net = nn.Linear(512, hidden_channels)
404
 
405
  self.encoder = attentions.Encoder(
406
  hidden_channels,
@@ -413,18 +376,10 @@ class TextEncoder(nn.Module):
413
  )
414
  self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
415
 
416
- def forward(
417
- self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=None
418
- ):
419
- sid = sid.cpu()
420
  bert_emb = self.bert_proj(bert).transpose(1, 2)
421
  ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
422
  en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
423
- emo_emb = self.in_feature_net(emo)
424
- emo_emb, _, loss_commit = self.emo_vq(emo_emb.unsqueeze(1))
425
- loss_commit = loss_commit.mean()
426
- emo_emb = self.out_feature_net(emo_emb)
427
- # emo_emb = self.emo_proj(emo.unsqueeze(1))
428
  x = (
429
  self.emb(x)
430
  + self.tone_emb(tone)
@@ -432,7 +387,6 @@ class TextEncoder(nn.Module):
432
  + bert_emb
433
  + ja_bert_emb
434
  + en_bert_emb
435
- + emo_emb
436
  ) * math.sqrt(
437
  self.hidden_channels
438
  ) # [b, t, h]
@@ -445,7 +399,7 @@ class TextEncoder(nn.Module):
445
  stats = self.proj(x) * x_mask
446
 
447
  m, logs = torch.split(stats, self.out_channels, dim=1)
448
- return x, m, logs, x_mask, loss_commit
449
 
450
 
451
  class ResidualCouplingBlock(nn.Module):
@@ -748,6 +702,55 @@ class MultiPeriodDiscriminator(torch.nn.Module):
748
  return y_d_rs, y_d_gs, fmap_rs, fmap_gs
749
 
750
 
751
  class ReferenceEncoder(nn.Module):
752
  """
753
  inputs --- [N, Ty/r, n_mels*r] mels
@@ -878,7 +881,6 @@ class SynthesizerTrn(nn.Module):
878
  n_layers,
879
  kernel_size,
880
  p_dropout,
881
- self.n_speakers,
882
  gin_channels=self.enc_gin_channels,
883
  )
884
  self.dec = Generator(
@@ -946,14 +948,13 @@ class SynthesizerTrn(nn.Module):
946
  bert,
947
  ja_bert,
948
  en_bert,
949
- emo=None,
950
  ):
951
  if self.n_speakers > 0:
952
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
953
  else:
954
  g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
955
- x, m_p, logs_p, x_mask, loss_commit = self.enc_p(
956
- x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
957
  )
958
  z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
959
  z_p = self.flow(z, y_mask, g=g)
@@ -996,9 +997,11 @@ class SynthesizerTrn(nn.Module):
996
 
997
  logw_ = torch.log(w + 1e-6) * x_mask
998
  logw = self.dp(x, x_mask, g=g)
 
999
  l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
1000
  x_mask
1001
  ) # for averaging
 
1002
 
1003
  l_length = l_length_dp + l_length_sdp
1004
 
@@ -1018,9 +1021,8 @@ class SynthesizerTrn(nn.Module):
1018
  x_mask,
1019
  y_mask,
1020
  (z, z_p, m_p, logs_p, m_q, logs_q),
1021
- (x, logw, logw_),
1022
  g,
1023
- loss_commit,
1024
  )
1025
 
1026
  def infer(
@@ -1033,7 +1035,6 @@ class SynthesizerTrn(nn.Module):
1033
  bert,
1034
  ja_bert,
1035
  en_bert,
1036
- emo=None,
1037
  noise_scale=0.667,
1038
  length_scale=1,
1039
  noise_scale_w=0.8,
@@ -1047,8 +1048,8 @@ class SynthesizerTrn(nn.Module):
1047
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1048
  else:
1049
  g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1050
- x, m_p, logs_p, x_mask, _ = self.enc_p(
1051
- x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
1052
  )
1053
  logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1054
  sdp_ratio
 
40
  self.norm_2 = modules.LayerNorm(filter_channels)
41
  self.dur_proj = nn.Conv1d(1, filter_channels, 1)
42
 
43
+ self.LSTM = nn.LSTM(
44
+ 2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
45
  )
 
 
 
 
 
46
 
47
  if gin_channels != 0:
48
  self.cond = nn.Conv1d(gin_channels, in_channels, 1)
49
 
50
+ self.output_layer = nn.Sequential(
51
+ nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
52
+ )
53
 
54
+ def forward_probability(self, x, dur):
55
  dur = self.dur_proj(dur)
56
  x = torch.cat([x, dur], dim=1)
 
 
 
 
 
 
 
 
 
57
  x = x.transpose(1, 2)
58
+ x, _ = self.LSTM(x)
59
  output_prob = self.output_layer(x)
60
  return output_prob
61
 
 
75
 
76
  output_probs = []
77
  for dur in [dur_r, dur_hat]:
78
+ output_prob = self.forward_probability(x, dur)
79
  output_probs.append(output_prob)
80
 
81
  return output_probs
 
343
  n_layers,
344
  kernel_size,
345
  p_dropout,
 
346
  gin_channels=0,
347
  ):
348
  super().__init__()
 
364
  self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
365
  self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
366
  self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
  self.encoder = attentions.Encoder(
369
  hidden_channels,
 
376
  )
377
  self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
378
 
379
+ def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
 
 
 
380
  bert_emb = self.bert_proj(bert).transpose(1, 2)
381
  ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
382
  en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
 
 
 
 
 
383
  x = (
384
  self.emb(x)
385
  + self.tone_emb(tone)
 
387
  + bert_emb
388
  + ja_bert_emb
389
  + en_bert_emb
 
390
  ) * math.sqrt(
391
  self.hidden_channels
392
  ) # [b, t, h]
 
399
  stats = self.proj(x) * x_mask
400
 
401
  m, logs = torch.split(stats, self.out_channels, dim=1)
402
+ return x, m, logs, x_mask
403
 
404
 
405
  class ResidualCouplingBlock(nn.Module):
 
702
  return y_d_rs, y_d_gs, fmap_rs, fmap_gs
703
 
704
 
705
+ class WavLMDiscriminator(nn.Module):
706
+ """docstring for Discriminator."""
707
+
708
+ def __init__(
709
+ self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
710
+ ):
711
+ super(WavLMDiscriminator, self).__init__()
712
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
713
+ self.pre = norm_f(
714
+ Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
715
+ )
716
+
717
+ self.convs = nn.ModuleList(
718
+ [
719
+ norm_f(
720
+ nn.Conv1d(
721
+ initial_channel, initial_channel * 2, kernel_size=5, padding=2
722
+ )
723
+ ),
724
+ norm_f(
725
+ nn.Conv1d(
726
+ initial_channel * 2,
727
+ initial_channel * 4,
728
+ kernel_size=5,
729
+ padding=2,
730
+ )
731
+ ),
732
+ norm_f(
733
+ nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
734
+ ),
735
+ ]
736
+ )
737
+
738
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
739
+
740
+ def forward(self, x):
741
+ x = self.pre(x)
742
+
743
+ fmap = []
744
+ for l in self.convs:
745
+ x = l(x)
746
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
747
+ fmap.append(x)
748
+ x = self.conv_post(x)
749
+ x = torch.flatten(x, 1, -1)
750
+
751
+ return x
752
+
753
+
754
  class ReferenceEncoder(nn.Module):
755
  """
756
  inputs --- [N, Ty/r, n_mels*r] mels
 
881
  n_layers,
882
  kernel_size,
883
  p_dropout,
 
884
  gin_channels=self.enc_gin_channels,
885
  )
886
  self.dec = Generator(
 
948
  bert,
949
  ja_bert,
950
  en_bert,
 
951
  ):
952
  if self.n_speakers > 0:
953
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
954
  else:
955
  g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
956
+ x, m_p, logs_p, x_mask = self.enc_p(
957
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
958
  )
959
  z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
960
  z_p = self.flow(z, y_mask, g=g)
 
997
 
998
  logw_ = torch.log(w + 1e-6) * x_mask
999
  logw = self.dp(x, x_mask, g=g)
1000
+ logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
1001
  l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
1002
  x_mask
1003
  ) # for averaging
1004
+ l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
1005
 
1006
  l_length = l_length_dp + l_length_sdp
1007
 
 
1021
  x_mask,
1022
  y_mask,
1023
  (z, z_p, m_p, logs_p, m_q, logs_q),
1024
+ (x, logw, logw_, logw_sdp),
1025
  g,
 
1026
  )
1027
 
1028
  def infer(
 
1035
  bert,
1036
  ja_bert,
1037
  en_bert,
 
1038
  noise_scale=0.667,
1039
  length_scale=1,
1040
  noise_scale_w=0.8,
 
1048
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1049
  else:
1050
  g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1051
+ x, m_p, logs_p, x_mask = self.enc_p(
1052
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
1053
  )
1054
  logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1055
  sdp_ratio
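
The models.py changes above drop the emotion/VQ path from TextEncoder, swap the DurationDiscriminator's conv stack for a bidirectional LSTM, and add the WavLMDiscriminator consumed by WavLMLoss. A quick shape check for the latter, with a random tensor standing in for the stacked WavLM hidden states that WavLMLoss builds as [B, slm_layers * slm_hidden, T] (the repo must be on the import path):

    import torch
    from models import WavLMDiscriminator

    wd = WavLMDiscriminator(slm_hidden=768, slm_layers=13, initial_channel=64)
    x = torch.randn(2, 13 * 768, 50)   # [B, slm_layers * slm_hidden, T]
    print(wd(x).shape)                 # torch.Size([2, 50]) -- one score per frame
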
oldVersion/V210/__init__.py CHANGED
@@ -5,10 +5,9 @@ import torch
5
  import commons
6
  from .text import cleaned_text_to_sequence, get_bert
7
  from .text.cleaner import clean_text
8
- from .emo_gen import get_emo
9
 
10
 
11
- def get_text(text, language_str, hps, device):
12
  # 在此处实现当前版本的get_text
13
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
14
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
@@ -20,7 +19,9 @@ def get_text(text, language_str, hps, device):
20
  for i in range(len(word2ph)):
21
  word2ph[i] = word2ph[i] * 2
22
  word2ph[0] += 1
23
- bert_ori = get_bert(norm_text, word2ph, language_str, device)
 
 
24
  del word2ph
25
  assert bert_ori.shape[-1] == len(phone), phone
26
 
@@ -50,6 +51,8 @@ def get_text(text, language_str, hps, device):
50
 
51
 
52
  def get_emo_(reference_audio, emotion):
 
 
53
  emo = (
54
  torch.from_numpy(get_emo(reference_audio))
55
  if reference_audio
@@ -73,9 +76,11 @@ def infer(
73
  emotion=None,
74
  skip_start=False,
75
  skip_end=False,
 
 
76
  ):
77
  bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
78
- text, language, hps, device
79
  )
80
  emo = get_emo_(reference_audio, emotion)
81
  if skip_start:
 
5
  import commons
6
  from .text import cleaned_text_to_sequence, get_bert
7
  from .text.cleaner import clean_text
 
8
 
9
 
10
+ def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
11
  # 在此处实现当前版本的get_text
12
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
13
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
 
19
  for i in range(len(word2ph)):
20
  word2ph[i] = word2ph[i] * 2
21
  word2ph[0] += 1
22
+ bert_ori = get_bert(
23
+ norm_text, word2ph, language_str, device, style_text, style_weight
24
+ )
25
  del word2ph
26
  assert bert_ori.shape[-1] == len(phone), phone
27
 
 
51
 
52
 
53
  def get_emo_(reference_audio, emotion):
54
+ from .emo_gen import get_emo
55
+
56
  emo = (
57
  torch.from_numpy(get_emo(reference_audio))
58
  if reference_audio
 
76
  emotion=None,
77
  skip_start=False,
78
  skip_end=False,
79
+ style_text=None,
80
+ style_weight=0.7,
81
  ):
82
  bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
83
+ text, language, hps, device, style_text, style_weight
84
  )
85
  emo = get_emo_(reference_audio, emotion)
86
  if skip_start:
oldVersion/V210/models.py CHANGED
@@ -13,7 +13,7 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
  from vector_quantize_pytorch import VectorQuantize
14
 
15
  from commons import init_weights, get_padding
16
- from text import symbols, num_tones, num_languages
17
 
18
 
19
  class DurationDiscriminator(nn.Module): # vits2
 
13
  from vector_quantize_pytorch import VectorQuantize
14
 
15
  from commons import init_weights, get_padding
16
+ from .text import symbols, num_tones, num_languages
17
 
18
 
19
  class DurationDiscriminator(nn.Module): # vits2
oldVersion/V210/text/__init__.py CHANGED
@@ -18,13 +18,15 @@ def cleaned_text_to_sequence(cleaned_text, tones, language):
18
  return phones, tones, lang_ids
19
 
20
 
21
- def get_bert(norm_text, word2ph, language, device):
22
  from .chinese_bert import get_bert_feature as zh_bert
23
  from .english_bert_mock import get_bert_feature as en_bert
24
  from .japanese_bert import get_bert_feature as jp_bert
25
 
26
  lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
27
- bert = lang_bert_func_map[language](norm_text, word2ph, device)
 
 
28
  return bert
29
 
30
 
 
18
  return phones, tones, lang_ids
19
 
20
 
21
+ def get_bert(norm_text, word2ph, language, device, style_text, style_weight):
22
  from .chinese_bert import get_bert_feature as zh_bert
23
  from .english_bert_mock import get_bert_feature as en_bert
24
  from .japanese_bert import get_bert_feature as jp_bert
25
 
26
  lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
27
+ bert = lang_bert_func_map[language](
28
+ norm_text, word2ph, device, style_text, style_weight
29
+ )
30
  return bert
31
 
32
 
oldVersion/V210/text/chinese_bert.py CHANGED
@@ -12,7 +12,13 @@ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
12
  models = dict()
13
 
14
 
15
- def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
 
 
 
 
 
 
16
  if (
17
  sys.platform == "darwin"
18
  and torch.backends.mps.is_available()
@@ -29,12 +35,25 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
29
  inputs[i] = inputs[i].to(device)
30
  res = models[device](**inputs, output_hidden_states=True)
31
  res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
 
 
 
 
 
 
 
32
 
33
  assert len(word2ph) == len(text) + 2
34
  word2phone = word2ph
35
  phone_level_feature = []
36
  for i in range(len(word2phone)):
37
- repeat_feature = res[i].repeat(word2phone[i], 1)
 
 
 
 
 
 
38
  phone_level_feature.append(repeat_feature)
39
 
40
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
 
12
  models = dict()
13
 
14
 
15
+ def get_bert_feature(
16
+ text,
17
+ word2ph,
18
+ device=config.bert_gen_config.device,
19
+ style_text=None,
20
+ style_weight=0.7,
21
+ ):
22
  if (
23
  sys.platform == "darwin"
24
  and torch.backends.mps.is_available()
 
35
  inputs[i] = inputs[i].to(device)
36
  res = models[device](**inputs, output_hidden_states=True)
37
  res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
38
+ if style_text:
39
+ style_inputs = tokenizer(style_text, return_tensors="pt")
40
+ for i in style_inputs:
41
+ style_inputs[i] = style_inputs[i].to(device)
42
+ style_res = models[device](**style_inputs, output_hidden_states=True)
43
+ style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
44
+ style_res_mean = style_res.mean(0)
45
 
46
  assert len(word2ph) == len(text) + 2
47
  word2phone = word2ph
48
  phone_level_feature = []
49
  for i in range(len(word2phone)):
50
+ if style_text:
51
+ repeat_feature = (
52
+ res[i].repeat(word2phone[i], 1) * (1 - style_weight)
53
+ + style_res_mean.repeat(word2phone[i], 1) * style_weight
54
+ )
55
+ else:
56
+ repeat_feature = res[i].repeat(word2phone[i], 1)
57
  phone_level_feature.append(repeat_feature)
58
 
59
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
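
The style-mixing rule added to chinese_bert.py above is repeated in english_bert_mock.py and japanese_bert.py below: when a style_text prompt is given, each phone-level frame becomes a weighted blend of the token's own BERT feature and the mean feature of the prompt. Isolated as a small numeric sketch with random tensors in place of real BERT outputs:

    import torch

    res_i = torch.randn(1024)           # feature of one token of the input text
    style_res_mean = torch.randn(1024)  # mean feature of the style_text prompt
    style_weight = 0.7
    n_phones = 3                        # word2phone[i] for this token

    repeat_feature = (
        res_i.repeat(n_phones, 1) * (1 - style_weight)
        + style_res_mean.repeat(n_phones, 1) * style_weight
    )
    print(repeat_feature.shape)         # torch.Size([3, 1024])
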
oldVersion/V210/text/english_bert_mock.py CHANGED
@@ -13,7 +13,13 @@ tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
13
  models = dict()
14
 
15
 
16
- def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
 
 
 
 
 
 
17
  if (
18
  sys.platform == "darwin"
19
  and torch.backends.mps.is_available()
@@ -30,11 +36,24 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
30
  inputs[i] = inputs[i].to(device)
31
  res = models[device](**inputs, output_hidden_states=True)
32
  res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
 
 
 
 
 
 
 
33
  assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph))
34
  word2phone = word2ph
35
  phone_level_feature = []
36
  for i in range(len(word2phone)):
37
- repeat_feature = res[i].repeat(word2phone[i], 1)
 
 
 
 
 
 
38
  phone_level_feature.append(repeat_feature)
39
 
40
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
 
13
  models = dict()
14
 
15
 
16
+ def get_bert_feature(
17
+ text,
18
+ word2ph,
19
+ device=config.bert_gen_config.device,
20
+ style_text=None,
21
+ style_weight=0.7,
22
+ ):
23
  if (
24
  sys.platform == "darwin"
25
  and torch.backends.mps.is_available()
 
36
  inputs[i] = inputs[i].to(device)
37
  res = models[device](**inputs, output_hidden_states=True)
38
  res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
39
+ if style_text:
40
+ style_inputs = tokenizer(style_text, return_tensors="pt")
41
+ for i in style_inputs:
42
+ style_inputs[i] = style_inputs[i].to(device)
43
+ style_res = models[device](**style_inputs, output_hidden_states=True)
44
+ style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
45
+ style_res_mean = style_res.mean(0)
46
  assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph))
47
  word2phone = word2ph
48
  phone_level_feature = []
49
  for i in range(len(word2phone)):
50
+ if style_text:
51
+ repeat_feature = (
52
+ res[i].repeat(word2phone[i], 1) * (1 - style_weight)
53
+ + style_res_mean.repeat(word2phone[i], 1) * style_weight
54
+ )
55
+ else:
56
+ repeat_feature = res[i].repeat(word2phone[i], 1)
57
  phone_level_feature.append(repeat_feature)
58
 
59
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
oldVersion/V210/text/japanese_bert.py CHANGED
@@ -13,8 +13,16 @@ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
13
  models = dict()
14
 
15
 
16
- def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
 
 
 
 
 
 
17
  text = "".join(text2sep_kata(text)[0])
 
 
18
  if (
19
  sys.platform == "darwin"
20
  and torch.backends.mps.is_available()
@@ -31,12 +39,25 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
31
  inputs[i] = inputs[i].to(device)
32
  res = models[device](**inputs, output_hidden_states=True)
33
  res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
 
 
 
 
 
 
 
34
 
35
  assert len(word2ph) == len(text) + 2
36
  word2phone = word2ph
37
  phone_level_feature = []
38
  for i in range(len(word2phone)):
39
- repeat_feature = res[i].repeat(word2phone[i], 1)
 
 
 
 
 
 
40
  phone_level_feature.append(repeat_feature)
41
 
42
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
 
13
  models = dict()
14
 
15
 
16
+ def get_bert_feature(
17
+ text,
18
+ word2ph,
19
+ device=config.bert_gen_config.device,
20
+ style_text=None,
21
+ style_weight=0.7,
22
+ ):
23
  text = "".join(text2sep_kata(text)[0])
24
+ if style_text:
25
+ style_text = "".join(text2sep_kata(style_text)[0])
26
  if (
27
  sys.platform == "darwin"
28
  and torch.backends.mps.is_available()
 
39
  inputs[i] = inputs[i].to(device)
40
  res = models[device](**inputs, output_hidden_states=True)
41
  res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
42
+ if style_text:
43
+ style_inputs = tokenizer(style_text, return_tensors="pt")
44
+ for i in style_inputs:
45
+ style_inputs[i] = style_inputs[i].to(device)
46
+ style_res = models[device](**style_inputs, output_hidden_states=True)
47
+ style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
48
+ style_res_mean = style_res.mean(0)
49
 
50
  assert len(word2ph) == len(text) + 2
51
  word2phone = word2ph
52
  phone_level_feature = []
53
  for i in range(len(word2phone)):
54
+ if style_text:
55
+ repeat_feature = (
56
+ res[i].repeat(word2phone[i], 1) * (1 - style_weight)
57
+ + style_res_mean.repeat(word2phone[i], 1) * style_weight
58
+ )
59
+ else:
60
+ repeat_feature = res[i].repeat(word2phone[i], 1)
61
  phone_level_feature.append(repeat_feature)
62
 
63
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
onnx_infer.py ADDED
@@ -0,0 +1,68 @@
1
+ from onnx_modules.V220_OnnxInference import OnnxInferenceSession
2
+ import numpy as np
3
+ Session = OnnxInferenceSession(
4
+ {
5
+ "enc" : "onnx/BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
6
+ "emb_g" : "onnx/BertVits2.2PT/BertVits2.2PT_emb.onnx",
7
+ "dp" : "onnx/BertVits2.2PT/BertVits2.2PT_dp.onnx",
8
+ "sdp" : "onnx/BertVits2.2PT/BertVits2.2PT_sdp.onnx",
9
+ "flow" : "onnx/BertVits2.2PT/BertVits2.2PT_flow.onnx",
10
+ "dec" : "onnx/BertVits2.2PT/BertVits2.2PT_dec.onnx"
11
+ },
12
+ Providers = ["CPUExecutionProvider"]
13
+ )
14
+
15
+ #这里的输入和原版是一样的,只需要在原版预处理结果出来之后加上.numpy()即可
16
+ x = np.array(
17
+ [
18
+ 0,
19
+ 97,
20
+ 0,
21
+ 8,
22
+ 0,
23
+ 78,
24
+ 0,
25
+ 8,
26
+ 0,
27
+ 76,
28
+ 0,
29
+ 37,
30
+ 0,
31
+ 40,
32
+ 0,
33
+ 97,
34
+ 0,
35
+ 8,
36
+ 0,
37
+ 23,
38
+ 0,
39
+ 8,
40
+ 0,
41
+ 74,
42
+ 0,
43
+ 26,
44
+ 0,
45
+ 104,
46
+ 0,
47
+ ]
48
+ )
49
+ tone = np.zeros_like(x)
50
+ language = np.zeros_like(x)
51
+ sid = np.array([0])
52
+ bert = np.random.randn(x.shape[0], 1024)
53
+ ja_bert = np.random.randn(x.shape[0], 1024)
54
+ en_bert = np.random.randn(x.shape[0], 1024)
55
+ emo = np.random.randn(512, 1)
56
+
57
+ audio = Session(
58
+ x,
59
+ tone,
60
+ language,
61
+ bert,
62
+ ja_bert,
63
+ en_bert,
64
+ emo,
65
+ sid
66
+ )
67
+
68
+ print(audio)
onnx_modules/V200/__init__.py CHANGED
@@ -0,0 +1,4 @@
1
+ from .text.symbols import symbols
2
+ from .models_onnx import SynthesizerTrn
3
+
4
+ __all__ = ["symbols", "SynthesizerTrn"]
onnx_modules/V200_OnnxInference/__init__.py ADDED
@@ -0,0 +1,126 @@
1
+ import numpy as np
2
+ import onnxruntime as ort
3
+
4
+
5
+ def convert_pad_shape(pad_shape):
6
+ layer = pad_shape[::-1]
7
+ pad_shape = [item for sublist in layer for item in sublist]
8
+ return pad_shape
9
+
10
+
11
+ def sequence_mask(length, max_length=None):
12
+ if max_length is None:
13
+ max_length = length.max()
14
+ x = np.arange(max_length, dtype=length.dtype)
15
+ return np.expand_dims(x, 0) < np.expand_dims(length, 1)
16
+
17
+
18
+ def generate_path(duration, mask):
19
+ """
20
+ duration: [b, 1, t_x]
21
+ mask: [b, 1, t_y, t_x]
22
+ """
23
+
24
+ b, _, t_y, t_x = mask.shape
25
+ cum_duration = np.cumsum(duration, -1)
26
+
27
+ cum_duration_flat = cum_duration.reshape(b * t_x)
28
+ path = sequence_mask(cum_duration_flat, t_y)
29
+ path = path.reshape(b, t_x, t_y)
30
+ path = path ^ np.pad(path, ((0, 0), (1, 0), (0, 0)))[:, :-1]
31
+ path = np.expand_dims(path, 1).transpose(0, 1, 3, 2)
32
+ return path
33
+
34
+
35
+ class OnnxInferenceSession:
36
+ def __init__(self, path, Providers=["CPUExecutionProvider"]):
37
+ self.enc = ort.InferenceSession(path["enc"], providers=Providers)
38
+ self.emb_g = ort.InferenceSession(path["emb_g"], providers=Providers)
39
+ self.dp = ort.InferenceSession(path["dp"], providers=Providers)
40
+ self.sdp = ort.InferenceSession(path["sdp"], providers=Providers)
41
+ self.flow = ort.InferenceSession(path["flow"], providers=Providers)
42
+ self.dec = ort.InferenceSession(path["dec"], providers=Providers)
43
+
44
+ def __call__(
45
+ self,
46
+ seq,
47
+ tone,
48
+ language,
49
+ bert_zh,
50
+ bert_jp,
51
+ bert_en,
52
+ sid,
53
+ seed=114514,
54
+ seq_noise_scale=0.8,
55
+ sdp_noise_scale=0.6,
56
+ length_scale=1.0,
57
+ sdp_ratio=0.0,
58
+ ):
59
+ if seq.ndim == 1:
60
+ seq = np.expand_dims(seq, 0)
61
+ if tone.ndim == 1:
62
+ tone = np.expand_dims(tone, 0)
63
+ if language.ndim == 1:
64
+ language = np.expand_dims(language, 0)
65
+ assert(seq.ndim == 2,tone.ndim == 2,language.ndim == 2)
66
+ g = self.emb_g.run(
67
+ None,
68
+ {
69
+ "sid": sid.astype(np.int64),
70
+ },
71
+ )[0]
72
+ g = np.expand_dims(g, -1)
73
+ enc_rtn = self.enc.run(
74
+ None,
75
+ {
76
+ "x": seq.astype(np.int64),
77
+ "t": tone.astype(np.int64),
78
+ "language": language.astype(np.int64),
79
+ "bert_0": bert_zh.astype(np.float32),
80
+ "bert_1": bert_jp.astype(np.float32),
81
+ "bert_2": bert_en.astype(np.float32),
82
+ "g": g.astype(np.float32),
83
+ },
84
+ )
85
+ x, m_p, logs_p, x_mask = enc_rtn[0], enc_rtn[1], enc_rtn[2], enc_rtn[3]
86
+ np.random.seed(seed)
87
+ zinput = np.random.randn(x.shape[0], 2, x.shape[2]) * sdp_noise_scale
88
+ logw = self.sdp.run(
89
+ None, {"x": x, "x_mask": x_mask, "zin": zinput.astype(np.float32), "g": g}
90
+ )[0] * (sdp_ratio) + self.dp.run(None, {"x": x, "x_mask": x_mask, "g": g})[
91
+ 0
92
+ ] * (
93
+ 1 - sdp_ratio
94
+ )
95
+ w = np.exp(logw) * x_mask * length_scale
96
+ w_ceil = np.ceil(w)
97
+ y_lengths = np.clip(np.sum(w_ceil, (1, 2)), a_min=1.0, a_max=100000).astype(
98
+ np.int64
99
+ )
100
+ y_mask = np.expand_dims(sequence_mask(y_lengths, None), 1)
101
+ attn_mask = np.expand_dims(x_mask, 2) * np.expand_dims(y_mask, -1)
102
+ attn = generate_path(w_ceil, attn_mask)
103
+ m_p = np.matmul(attn.squeeze(1), m_p.transpose(0, 2, 1)).transpose(
104
+ 0, 2, 1
105
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
106
+ logs_p = np.matmul(attn.squeeze(1), logs_p.transpose(0, 2, 1)).transpose(
107
+ 0, 2, 1
108
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
109
+
110
+ z_p = (
111
+ m_p
112
+ + np.random.randn(m_p.shape[0], m_p.shape[1], m_p.shape[2])
113
+ * np.exp(logs_p)
114
+ * seq_noise_scale
115
+ )
116
+
117
+ z = self.flow.run(
118
+ None,
119
+ {
120
+ "z_p": z_p.astype(np.float32),
121
+ "y_mask": y_mask.astype(np.float32),
122
+ "g": g,
123
+ },
124
+ )[0]
125
+
126
+ return self.dec.run(None, {"z_in": z.astype(np.float32), "g": g})[0]
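
generate_path() above is the NumPy port of the monotonic-alignment expansion used at inference time: per-token durations are turned into a hard attention matrix over output frames. A tiny worked example; importing the helper assumes the repo root is on sys.path and the package's dependencies (onnxruntime, torch) are installed, since the package __init__ pulls them in:

    import numpy as np
    from onnx_modules.V200_OnnxInference import generate_path

    duration = np.array([[[2, 1, 3]]], dtype=np.float32)  # [b=1, 1, t_x=3] frames per token
    mask = np.ones((1, 1, 6, 3), dtype=np.float32)        # [b, 1, t_y=6, t_x=3]
    attn = generate_path(duration, mask)
    print(attn[0, 0].astype(int))
    # [[1 0 0]
    #  [1 0 0]
    #  [0 1 0]
    #  [0 0 1]
    #  [0 0 1]
    #  [0 0 1]]
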
onnx_modules/V210/__init__.py CHANGED
@@ -0,0 +1,4 @@
1
+ from .text.symbols import symbols
2
+ from .models_onnx import SynthesizerTrn
3
+
4
+ __all__ = ["symbols", "SynthesizerTrn"]
onnx_modules/V210/models_onnx.py CHANGED
@@ -942,7 +942,7 @@ class SynthesizerTrn(nn.Module):
942
 
943
  torch.onnx.export(
944
  self.enc_p,
945
- (x, x_lengths, tone, language, bert, ja_bert, en_bert, g, sid + 1, sid + 2),
946
  f"onnx/{path}/{path}_enc_p.onnx",
947
  input_names=[
948
  "x",
 
942
 
943
  torch.onnx.export(
944
  self.enc_p,
945
+ (x, x_lengths, tone, language, bert, ja_bert, en_bert, g, sid, sid),
946
  f"onnx/{path}/{path}_enc_p.onnx",
947
  input_names=[
948
  "x",
onnx_modules/V210_OnnxInference/__init__.py ADDED
@@ -0,0 +1,129 @@
1
+ import numpy as np
2
+ import onnxruntime as ort
3
+
4
+
5
+ def convert_pad_shape(pad_shape):
6
+ layer = pad_shape[::-1]
7
+ pad_shape = [item for sublist in layer for item in sublist]
8
+ return pad_shape
9
+
10
+
11
+ def sequence_mask(length, max_length=None):
12
+ if max_length is None:
13
+ max_length = length.max()
14
+ x = np.arange(max_length, dtype=length.dtype)
15
+ return np.expand_dims(x, 0) < np.expand_dims(length, 1)
16
+
17
+
18
+ def generate_path(duration, mask):
19
+ """
20
+ duration: [b, 1, t_x]
21
+ mask: [b, 1, t_y, t_x]
22
+ """
23
+
24
+ b, _, t_y, t_x = mask.shape
25
+ cum_duration = np.cumsum(duration, -1)
26
+
27
+ cum_duration_flat = cum_duration.reshape(b * t_x)
28
+ path = sequence_mask(cum_duration_flat, t_y)
29
+ path = path.reshape(b, t_x, t_y)
30
+ path = path ^ np.pad(path, ((0, 0), (1, 0), (0, 0)))[:, :-1]
31
+ path = np.expand_dims(path, 1).transpose(0, 1, 3, 2)
32
+ return path
33
+
34
+
35
+ class OnnxInferenceSession:
36
+ def __init__(self, path, Providers=["CPUExecutionProvider"]):
37
+ self.enc = ort.InferenceSession(path["enc"], providers=Providers)
38
+ self.emb_g = ort.InferenceSession(path["emb_g"], providers=Providers)
39
+ self.dp = ort.InferenceSession(path["dp"], providers=Providers)
40
+ self.sdp = ort.InferenceSession(path["sdp"], providers=Providers)
41
+ self.flow = ort.InferenceSession(path["flow"], providers=Providers)
42
+ self.dec = ort.InferenceSession(path["dec"], providers=Providers)
43
+
44
+ def __call__(
45
+ self,
46
+ seq,
47
+ tone,
48
+ language,
49
+ bert_zh,
50
+ bert_jp,
51
+ bert_en,
52
+ vqidx,
53
+ sid,
54
+ seed=114514,
55
+ seq_noise_scale=0.8,
56
+ sdp_noise_scale=0.6,
57
+ length_scale=1.0,
58
+ sdp_ratio=0.0,
59
+ ):
60
+ if seq.ndim == 1:
61
+ seq = np.expand_dims(seq, 0)
62
+ if tone.ndim == 1:
63
+ tone = np.expand_dims(tone, 0)
64
+ if language.ndim == 1:
65
+ language = np.expand_dims(language, 0)
66
+ assert(seq.ndim == 2,tone.ndim == 2,language.ndim == 2)
67
+ g = self.emb_g.run(
68
+ None,
69
+ {
70
+ "sid": sid.astype(np.int64),
71
+ },
72
+ )[0]
73
+ g = np.expand_dims(g, -1)
74
+ enc_rtn = self.enc.run(
75
+ None,
76
+ {
77
+ "x": seq.astype(np.int64),
78
+ "t": tone.astype(np.int64),
79
+ "language": language.astype(np.int64),
80
+ "bert_0": bert_zh.astype(np.float32),
81
+ "bert_1": bert_jp.astype(np.float32),
82
+ "bert_2": bert_en.astype(np.float32),
83
+ "g": g.astype(np.float32),
84
+ "vqidx": vqidx.astype(np.int64),
85
+ "sid": sid.astype(np.int64)
86
+ },
87
+ )
88
+ x, m_p, logs_p, x_mask = enc_rtn[0], enc_rtn[1], enc_rtn[2], enc_rtn[3]
89
+ np.random.seed(seed)
90
+ zinput = np.random.randn(x.shape[0], 2, x.shape[2]) * sdp_noise_scale
91
+ logw = self.sdp.run(
92
+ None, {"x": x, "x_mask": x_mask, "zin": zinput.astype(np.float32), "g": g}
93
+ )[0] * (sdp_ratio) + self.dp.run(None, {"x": x, "x_mask": x_mask, "g": g})[
94
+ 0
95
+ ] * (
96
+ 1 - sdp_ratio
97
+ )
98
+ w = np.exp(logw) * x_mask * length_scale
99
+ w_ceil = np.ceil(w)
100
+ y_lengths = np.clip(np.sum(w_ceil, (1, 2)), a_min=1.0, a_max=100000).astype(
101
+ np.int64
102
+ )
103
+ y_mask = np.expand_dims(sequence_mask(y_lengths, None), 1)
104
+ attn_mask = np.expand_dims(x_mask, 2) * np.expand_dims(y_mask, -1)
105
+ attn = generate_path(w_ceil, attn_mask)
106
+ m_p = np.matmul(attn.squeeze(1), m_p.transpose(0, 2, 1)).transpose(
107
+ 0, 2, 1
108
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
109
+ logs_p = np.matmul(attn.squeeze(1), logs_p.transpose(0, 2, 1)).transpose(
110
+ 0, 2, 1
111
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
112
+
113
+ z_p = (
114
+ m_p
115
+ + np.random.randn(m_p.shape[0], m_p.shape[1], m_p.shape[2])
116
+ * np.exp(logs_p)
117
+ * seq_noise_scale
118
+ )
119
+
120
+ z = self.flow.run(
121
+ None,
122
+ {
123
+ "z_p": z_p.astype(np.float32),
124
+ "y_mask": y_mask.astype(np.float32),
125
+ "g": g,
126
+ },
127
+ )[0]
128
+
129
+ return self.dec.run(None, {"z_in": z.astype(np.float32), "g": g})[0]
onnx_modules/V220/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .text.symbols import symbols
2
+ from .models_onnx import SynthesizerTrn
3
+
4
+ __all__ = ["symbols", "SynthesizerTrn"]
onnx_modules/V220/attentions_onnx.py ADDED
@@ -0,0 +1,378 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class LayerNorm(nn.Module):
13
+ def __init__(self, channels, eps=1e-5):
14
+ super().__init__()
15
+ self.channels = channels
16
+ self.eps = eps
17
+
18
+ self.gamma = nn.Parameter(torch.ones(channels))
19
+ self.beta = nn.Parameter(torch.zeros(channels))
20
+
21
+ def forward(self, x):
22
+ x = x.transpose(1, -1)
23
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
24
+ return x.transpose(1, -1)
25
+
26
+
27
+ @torch.jit.script
28
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
29
+ n_channels_int = n_channels[0]
30
+ in_act = input_a + input_b
31
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
32
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
33
+ acts = t_act * s_act
34
+ return acts
35
+
36
+
37
+ class Encoder(nn.Module):
38
+ def __init__(
39
+ self,
40
+ hidden_channels,
41
+ filter_channels,
42
+ n_heads,
43
+ n_layers,
44
+ kernel_size=1,
45
+ p_dropout=0.0,
46
+ window_size=4,
47
+ isflow=True,
48
+ **kwargs
49
+ ):
50
+ super().__init__()
51
+ self.hidden_channels = hidden_channels
52
+ self.filter_channels = filter_channels
53
+ self.n_heads = n_heads
54
+ self.n_layers = n_layers
55
+ self.kernel_size = kernel_size
56
+ self.p_dropout = p_dropout
57
+ self.window_size = window_size
58
+ # if isflow:
59
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
60
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
61
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
62
+ # self.gin_channels = 256
63
+ self.cond_layer_idx = self.n_layers
64
+ if "gin_channels" in kwargs:
65
+ self.gin_channels = kwargs["gin_channels"]
66
+ if self.gin_channels != 0:
67
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
68
+ # vits2 says 3rd block, so idx is 2 by default
69
+ self.cond_layer_idx = (
70
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
71
+ )
72
+ logging.debug(self.gin_channels, self.cond_layer_idx)
73
+ assert (
74
+ self.cond_layer_idx < self.n_layers
75
+ ), "cond_layer_idx should be less than n_layers"
76
+ self.drop = nn.Dropout(p_dropout)
77
+ self.attn_layers = nn.ModuleList()
78
+ self.norm_layers_1 = nn.ModuleList()
79
+ self.ffn_layers = nn.ModuleList()
80
+ self.norm_layers_2 = nn.ModuleList()
81
+ for i in range(self.n_layers):
82
+ self.attn_layers.append(
83
+ MultiHeadAttention(
84
+ hidden_channels,
85
+ hidden_channels,
86
+ n_heads,
87
+ p_dropout=p_dropout,
88
+ window_size=window_size,
89
+ )
90
+ )
91
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
92
+ self.ffn_layers.append(
93
+ FFN(
94
+ hidden_channels,
95
+ hidden_channels,
96
+ filter_channels,
97
+ kernel_size,
98
+ p_dropout=p_dropout,
99
+ )
100
+ )
101
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
102
+
103
+ def forward(self, x, x_mask, g=None):
104
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
105
+ x = x * x_mask
106
+ for i in range(self.n_layers):
107
+ if i == self.cond_layer_idx and g is not None:
108
+ g = self.spk_emb_linear(g.transpose(1, 2))
109
+ g = g.transpose(1, 2)
110
+ x = x + g
111
+ x = x * x_mask
112
+ y = self.attn_layers[i](x, x, attn_mask)
113
+ y = self.drop(y)
114
+ x = self.norm_layers_1[i](x + y)
115
+
116
+ y = self.ffn_layers[i](x, x_mask)
117
+ y = self.drop(y)
118
+ x = self.norm_layers_2[i](x + y)
119
+ x = x * x_mask
120
+ return x
121
+
122
+
123
+ class MultiHeadAttention(nn.Module):
124
+ def __init__(
125
+ self,
126
+ channels,
127
+ out_channels,
128
+ n_heads,
129
+ p_dropout=0.0,
130
+ window_size=None,
131
+ heads_share=True,
132
+ block_length=None,
133
+ proximal_bias=False,
134
+ proximal_init=False,
135
+ ):
136
+ super().__init__()
137
+ assert channels % n_heads == 0
138
+
139
+ self.channels = channels
140
+ self.out_channels = out_channels
141
+ self.n_heads = n_heads
142
+ self.p_dropout = p_dropout
143
+ self.window_size = window_size
144
+ self.heads_share = heads_share
145
+ self.block_length = block_length
146
+ self.proximal_bias = proximal_bias
147
+ self.proximal_init = proximal_init
148
+ self.attn = None
149
+
150
+ self.k_channels = channels // n_heads
151
+ self.conv_q = nn.Conv1d(channels, channels, 1)
152
+ self.conv_k = nn.Conv1d(channels, channels, 1)
153
+ self.conv_v = nn.Conv1d(channels, channels, 1)
154
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
155
+ self.drop = nn.Dropout(p_dropout)
156
+
157
+ if window_size is not None:
158
+ n_heads_rel = 1 if heads_share else n_heads
159
+ rel_stddev = self.k_channels**-0.5
160
+ self.emb_rel_k = nn.Parameter(
161
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
162
+ * rel_stddev
163
+ )
164
+ self.emb_rel_v = nn.Parameter(
165
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
166
+ * rel_stddev
167
+ )
168
+
169
+ nn.init.xavier_uniform_(self.conv_q.weight)
170
+ nn.init.xavier_uniform_(self.conv_k.weight)
171
+ nn.init.xavier_uniform_(self.conv_v.weight)
172
+ if proximal_init:
173
+ with torch.no_grad():
174
+ self.conv_k.weight.copy_(self.conv_q.weight)
175
+ self.conv_k.bias.copy_(self.conv_q.bias)
176
+
177
+ def forward(self, x, c, attn_mask=None):
178
+ q = self.conv_q(x)
179
+ k = self.conv_k(c)
180
+ v = self.conv_v(c)
181
+
182
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
183
+
184
+ x = self.conv_o(x)
185
+ return x
186
+
187
+ def attention(self, query, key, value, mask=None):
188
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
189
+ b, d, t_s, t_t = (*key.size(), query.size(2))
190
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
191
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
192
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
193
+
194
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
195
+ if self.window_size is not None:
196
+ assert (
197
+ t_s == t_t
198
+ ), "Relative attention is only available for self-attention."
199
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
200
+ rel_logits = self._matmul_with_relative_keys(
201
+ query / math.sqrt(self.k_channels), key_relative_embeddings
202
+ )
203
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
204
+ scores = scores + scores_local
205
+ if self.proximal_bias:
206
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
207
+ scores = scores + self._attention_bias_proximal(t_s).to(
208
+ device=scores.device, dtype=scores.dtype
209
+ )
210
+ if mask is not None:
211
+ scores = scores.masked_fill(mask == 0, -1e4)
212
+ if self.block_length is not None:
213
+ assert (
214
+ t_s == t_t
215
+ ), "Local attention is only available for self-attention."
216
+ block_mask = (
217
+ torch.ones_like(scores)
218
+ .triu(-self.block_length)
219
+ .tril(self.block_length)
220
+ )
221
+ scores = scores.masked_fill(block_mask == 0, -1e4)
222
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
223
+ p_attn = self.drop(p_attn)
224
+ output = torch.matmul(p_attn, value)
225
+ if self.window_size is not None:
226
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
227
+ value_relative_embeddings = self._get_relative_embeddings(
228
+ self.emb_rel_v, t_s
229
+ )
230
+ output = output + self._matmul_with_relative_values(
231
+ relative_weights, value_relative_embeddings
232
+ )
233
+ output = (
234
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
235
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
236
+ return output, p_attn
237
+
238
+ def _matmul_with_relative_values(self, x, y):
239
+ """
240
+ x: [b, h, l, m]
241
+ y: [h or 1, m, d]
242
+ ret: [b, h, l, d]
243
+ """
244
+ ret = torch.matmul(x, y.unsqueeze(0))
245
+ return ret
246
+
247
+ def _matmul_with_relative_keys(self, x, y):
248
+ """
249
+ x: [b, h, l, d]
250
+ y: [h or 1, m, d]
251
+ ret: [b, h, l, m]
252
+ """
253
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
254
+ return ret
255
+
256
+ def _get_relative_embeddings(self, relative_embeddings, length):
257
+ max_relative_position = 2 * self.window_size + 1
258
+ # Pad first before slice to avoid using cond ops.
259
+ pad_length = max(length - (self.window_size + 1), 0)
260
+ slice_start_position = max((self.window_size + 1) - length, 0)
261
+ slice_end_position = slice_start_position + 2 * length - 1
262
+ if pad_length > 0:
263
+ padded_relative_embeddings = F.pad(
264
+ relative_embeddings,
265
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
266
+ )
267
+ else:
268
+ padded_relative_embeddings = relative_embeddings
269
+ used_relative_embeddings = padded_relative_embeddings[
270
+ :, slice_start_position:slice_end_position
271
+ ]
272
+ return used_relative_embeddings
273
+
274
+ def _relative_position_to_absolute_position(self, x):
275
+ """
276
+ x: [b, h, l, 2*l-1]
277
+ ret: [b, h, l, l]
278
+ """
279
+ batch, heads, length, _ = x.size()
280
+ # Concat columns of pad to shift from relative to absolute indexing.
281
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
282
+
283
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
284
+ x_flat = x.view([batch, heads, length * 2 * length])
285
+ x_flat = F.pad(
286
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
287
+ )
288
+
289
+ # Reshape and slice out the padded elements.
290
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
291
+ :, :, :length, length - 1 :
292
+ ]
293
+ return x_final
294
+
295
+ def _absolute_position_to_relative_position(self, x):
296
+ """
297
+ x: [b, h, l, l]
298
+ ret: [b, h, l, 2*l-1]
299
+ """
300
+ batch, heads, length, _ = x.size()
301
+ # padd along column
302
+ x = F.pad(
303
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
304
+ )
305
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
306
+ # add 0's in the beginning that will skew the elements after reshape
307
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
308
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
309
+ return x_final
310
+
311
+ def _attention_bias_proximal(self, length):
312
+ """Bias for self-attention to encourage attention to close positions.
313
+ Args:
314
+ length: an integer scalar.
315
+ Returns:
316
+ a Tensor with shape [1, 1, length, length]
317
+ """
318
+ r = torch.arange(length, dtype=torch.float32)
319
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
320
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
321
+
322
+
323
+ class FFN(nn.Module):
324
+ def __init__(
325
+ self,
326
+ in_channels,
327
+ out_channels,
328
+ filter_channels,
329
+ kernel_size,
330
+ p_dropout=0.0,
331
+ activation=None,
332
+ causal=False,
333
+ ):
334
+ super().__init__()
335
+ self.in_channels = in_channels
336
+ self.out_channels = out_channels
337
+ self.filter_channels = filter_channels
338
+ self.kernel_size = kernel_size
339
+ self.p_dropout = p_dropout
340
+ self.activation = activation
341
+ self.causal = causal
342
+
343
+ if causal:
344
+ self.padding = self._causal_padding
345
+ else:
346
+ self.padding = self._same_padding
347
+
348
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
349
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
350
+ self.drop = nn.Dropout(p_dropout)
351
+
352
+ def forward(self, x, x_mask):
353
+ x = self.conv_1(self.padding(x * x_mask))
354
+ if self.activation == "gelu":
355
+ x = x * torch.sigmoid(1.702 * x)
356
+ else:
357
+ x = torch.relu(x)
358
+ x = self.drop(x)
359
+ x = self.conv_2(self.padding(x * x_mask))
360
+ return x * x_mask
361
+
362
+ def _causal_padding(self, x):
363
+ if self.kernel_size == 1:
364
+ return x
365
+ pad_l = self.kernel_size - 1
366
+ pad_r = 0
367
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
368
+ x = F.pad(x, commons.convert_pad_shape(padding))
369
+ return x
370
+
371
+ def _same_padding(self, x):
372
+ if self.kernel_size == 1:
373
+ return x
374
+ pad_l = (self.kernel_size - 1) // 2
375
+ pad_r = self.kernel_size // 2
376
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
377
+ x = F.pad(x, commons.convert_pad_shape(padding))
378
+ return x
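
A shape sketch for the ONNX-export Encoder defined above (the relative-position Transformer with the speaker embedding injected at cond_layer_idx). All sizes are illustrative, and the import assumes the repo's commons.py and the rest of the onnx_modules package are importable:

    import torch
    from onnx_modules.V220.attentions_onnx import Encoder

    enc = Encoder(
        hidden_channels=192, filter_channels=768,
        n_heads=2, n_layers=6, gin_channels=256,
    )
    x = torch.randn(1, 192, 40)        # [b, hidden, t] phoneme-level features
    x_mask = torch.ones(1, 1, 40)
    g = torch.randn(1, 256, 1)         # speaker embedding
    print(enc(x, x_mask, g=g).shape)   # torch.Size([1, 192, 40])
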
onnx_modules/V220/models_onnx.py ADDED
@@ -0,0 +1,1076 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import modules
8
+ from . import attentions_onnx
9
+ from vector_quantize_pytorch import VectorQuantize
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+ from commons import init_weights, get_padding
14
+ from .text import symbols, num_tones, num_languages
15
+
16
+
17
+ class DurationDiscriminator(nn.Module): # vits2
18
+ def __init__(
19
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
20
+ ):
21
+ super().__init__()
22
+
23
+ self.in_channels = in_channels
24
+ self.filter_channels = filter_channels
25
+ self.kernel_size = kernel_size
26
+ self.p_dropout = p_dropout
27
+ self.gin_channels = gin_channels
28
+
29
+ self.drop = nn.Dropout(p_dropout)
30
+ self.conv_1 = nn.Conv1d(
31
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
32
+ )
33
+ self.norm_1 = modules.LayerNorm(filter_channels)
34
+ self.conv_2 = nn.Conv1d(
35
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
36
+ )
37
+ self.norm_2 = modules.LayerNorm(filter_channels)
38
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
39
+
40
+ self.pre_out_conv_1 = nn.Conv1d(
41
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
42
+ )
43
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
44
+ self.pre_out_conv_2 = nn.Conv1d(
45
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
46
+ )
47
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
48
+
49
+ if gin_channels != 0:
50
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
51
+
52
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
53
+
54
+ def forward_probability(self, x, x_mask, dur, g=None):
55
+ dur = self.dur_proj(dur)
56
+ x = torch.cat([x, dur], dim=1)
57
+ x = self.pre_out_conv_1(x * x_mask)
58
+ x = torch.relu(x)
59
+ x = self.pre_out_norm_1(x)
60
+ x = self.drop(x)
61
+ x = self.pre_out_conv_2(x * x_mask)
62
+ x = torch.relu(x)
63
+ x = self.pre_out_norm_2(x)
64
+ x = self.drop(x)
65
+ x = x * x_mask
66
+ x = x.transpose(1, 2)
67
+ output_prob = self.output_layer(x)
68
+ return output_prob
69
+
70
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
71
+ x = torch.detach(x)
72
+ if g is not None:
73
+ g = torch.detach(g)
74
+ x = x + self.cond(g)
75
+ x = self.conv_1(x * x_mask)
76
+ x = torch.relu(x)
77
+ x = self.norm_1(x)
78
+ x = self.drop(x)
79
+ x = self.conv_2(x * x_mask)
80
+ x = torch.relu(x)
81
+ x = self.norm_2(x)
82
+ x = self.drop(x)
83
+
84
+ output_probs = []
85
+ for dur in [dur_r, dur_hat]:
86
+ output_prob = self.forward_probability(x, x_mask, dur, g)
87
+ output_probs.append(output_prob)
88
+
89
+ return output_probs
90
+
91
+
92
+ class TransformerCouplingBlock(nn.Module):
93
+ def __init__(
94
+ self,
95
+ channels,
96
+ hidden_channels,
97
+ filter_channels,
98
+ n_heads,
99
+ n_layers,
100
+ kernel_size,
101
+ p_dropout,
102
+ n_flows=4,
103
+ gin_channels=0,
104
+ share_parameter=False,
105
+ ):
106
+ super().__init__()
107
+ self.channels = channels
108
+ self.hidden_channels = hidden_channels
109
+ self.kernel_size = kernel_size
110
+ self.n_layers = n_layers
111
+ self.n_flows = n_flows
112
+ self.gin_channels = gin_channels
113
+
114
+ self.flows = nn.ModuleList()
115
+
116
+ self.wn = (
117
+ attentions_onnx.FFT(
118
+ hidden_channels,
119
+ filter_channels,
120
+ n_heads,
121
+ n_layers,
122
+ kernel_size,
123
+ p_dropout,
124
+ isflow=True,
125
+ gin_channels=self.gin_channels,
126
+ )
127
+ if share_parameter
128
+ else None
129
+ )
130
+
131
+ for i in range(n_flows):
132
+ self.flows.append(
133
+ modules.TransformerCouplingLayer(
134
+ channels,
135
+ hidden_channels,
136
+ kernel_size,
137
+ n_layers,
138
+ n_heads,
139
+ p_dropout,
140
+ filter_channels,
141
+ mean_only=True,
142
+ wn_sharing_parameter=self.wn,
143
+ gin_channels=self.gin_channels,
144
+ )
145
+ )
146
+ self.flows.append(modules.Flip())
147
+
148
+ def forward(self, x, x_mask, g=None, reverse=True):
149
+ if not reverse:
150
+ for flow in self.flows:
151
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
152
+ else:
153
+ for flow in reversed(self.flows):
154
+ x = flow(x, x_mask, g=g, reverse=reverse)
155
+ return x
156
+
157
+
158
+ class StochasticDurationPredictor(nn.Module):
159
+ def __init__(
160
+ self,
161
+ in_channels,
162
+ filter_channels,
163
+ kernel_size,
164
+ p_dropout,
165
+ n_flows=4,
166
+ gin_channels=0,
167
+ ):
168
+ super().__init__()
169
+ filter_channels = in_channels  # this override should be removed in a future version.
170
+ self.in_channels = in_channels
171
+ self.filter_channels = filter_channels
172
+ self.kernel_size = kernel_size
173
+ self.p_dropout = p_dropout
174
+ self.n_flows = n_flows
175
+ self.gin_channels = gin_channels
176
+
177
+ self.log_flow = modules.Log()
178
+ self.flows = nn.ModuleList()
179
+ self.flows.append(modules.ElementwiseAffine(2))
180
+ for i in range(n_flows):
181
+ self.flows.append(
182
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
183
+ )
184
+ self.flows.append(modules.Flip())
185
+
186
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
187
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
188
+ self.post_convs = modules.DDSConv(
189
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
190
+ )
191
+ self.post_flows = nn.ModuleList()
192
+ self.post_flows.append(modules.ElementwiseAffine(2))
193
+ for i in range(4):
194
+ self.post_flows.append(
195
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
196
+ )
197
+ self.post_flows.append(modules.Flip())
198
+
199
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
200
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
201
+ self.convs = modules.DDSConv(
202
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
203
+ )
204
+ if gin_channels != 0:
205
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
206
+
207
+ def forward(self, x, x_mask, z, g=None):
208
+ x = torch.detach(x)
209
+ x = self.pre(x)
210
+ if g is not None:
211
+ g = torch.detach(g)
212
+ x = x + self.cond(g)
213
+ x = self.convs(x, x_mask)
214
+ x = self.proj(x) * x_mask
215
+
216
+ flows = list(reversed(self.flows))
217
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
218
+ for flow in flows:
219
+ z = flow(z, x_mask, g=x, reverse=True)
220
+ z0, z1 = torch.split(z, [1, 1], 1)
221
+ logw = z0
222
+ return logw
223
+
224
+
225
+ class DurationPredictor(nn.Module):
226
+ def __init__(
227
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
228
+ ):
229
+ super().__init__()
230
+
231
+ self.in_channels = in_channels
232
+ self.filter_channels = filter_channels
233
+ self.kernel_size = kernel_size
234
+ self.p_dropout = p_dropout
235
+ self.gin_channels = gin_channels
236
+
237
+ self.drop = nn.Dropout(p_dropout)
238
+ self.conv_1 = nn.Conv1d(
239
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
240
+ )
241
+ self.norm_1 = modules.LayerNorm(filter_channels)
242
+ self.conv_2 = nn.Conv1d(
243
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
244
+ )
245
+ self.norm_2 = modules.LayerNorm(filter_channels)
246
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
247
+
248
+ if gin_channels != 0:
249
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
250
+
251
+ def forward(self, x, x_mask, g=None):
252
+ x = torch.detach(x)
253
+ if g is not None:
254
+ g = torch.detach(g)
255
+ x = x + self.cond(g)
256
+ x = self.conv_1(x * x_mask)
257
+ x = torch.relu(x)
258
+ x = self.norm_1(x)
259
+ x = self.drop(x)
260
+ x = self.conv_2(x * x_mask)
261
+ x = torch.relu(x)
262
+ x = self.norm_2(x)
263
+ x = self.drop(x)
264
+ x = self.proj(x * x_mask)
265
+ return x * x_mask
266
+
267
+
268
+ class Bottleneck(nn.Sequential):
269
+ def __init__(self, in_dim, hidden_dim):
270
+ c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
271
+ c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
272
+ super().__init__(*[c_fc1, c_fc2])
273
+
274
+
275
+ class Block(nn.Module):
276
+ def __init__(self, in_dim, hidden_dim) -> None:
277
+ super().__init__()
278
+ self.norm = nn.LayerNorm(in_dim)
279
+ self.mlp = MLP(in_dim, hidden_dim)
280
+
281
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
282
+ x = x + self.mlp(self.norm(x))
283
+ return x
284
+
285
+
286
+ class MLP(nn.Module):
287
+ def __init__(self, in_dim, hidden_dim):
288
+ super().__init__()
289
+ self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
290
+ self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
291
+ self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
292
+
293
+ def forward(self, x: torch.Tensor):
294
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
295
+ x = self.c_proj(x)
296
+ return x
297
+
298
+
299
+ class TextEncoder(nn.Module):
300
+ def __init__(
301
+ self,
302
+ n_vocab,
303
+ out_channels,
304
+ hidden_channels,
305
+ filter_channels,
306
+ n_heads,
307
+ n_layers,
308
+ kernel_size,
309
+ p_dropout,
310
+ n_speakers,
311
+ gin_channels=0,
312
+ ):
313
+ super().__init__()
314
+ self.n_vocab = n_vocab
315
+ self.out_channels = out_channels
316
+ self.hidden_channels = hidden_channels
317
+ self.filter_channels = filter_channels
318
+ self.n_heads = n_heads
319
+ self.n_layers = n_layers
320
+ self.kernel_size = kernel_size
321
+ self.p_dropout = p_dropout
322
+ self.gin_channels = gin_channels
323
+ self.emb = nn.Embedding(len(symbols), hidden_channels)
324
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
325
+ self.tone_emb = nn.Embedding(num_tones, hidden_channels)
326
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
327
+ self.language_emb = nn.Embedding(num_languages, hidden_channels)
328
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
329
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
330
+ self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
331
+ self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
332
+ # self.emo_proj = nn.Linear(1024, 1024)
333
+ # self.emo_quantizer = nn.ModuleList()
334
+ # for i in range(0, n_speakers):
335
+ # self.emo_quantizer.append(
336
+ # VectorQuantize(
337
+ # dim=1024,
338
+ # codebook_size=10,
339
+ # decay=0.8,
340
+ # commitment_weight=1.0,
341
+ # learnable_codebook=True,
342
+ # ema_update=False,
343
+ # )
344
+ # )
345
+ # self.emo_q_proj = nn.Linear(1024, hidden_channels)
346
+ self.n_speakers = n_speakers
347
+ self.in_feature_net = nn.Sequential(
348
+ # input is assumed to be an already normalized embedding
349
+ nn.Linear(512, 1028, bias=False),
350
+ nn.GELU(),
351
+ nn.LayerNorm(1028),
352
+ *[Block(1028, 512) for _ in range(1)],
353
+ nn.Linear(1028, 512, bias=False),
354
+ # normalize before passing to VQ?
355
+ # nn.GELU(),
356
+ # nn.LayerNorm(512),
357
+ )
358
+ self.emo_vq = VectorQuantize(
359
+ dim=512,
360
+ codebook_size=64,
361
+ codebook_dim=32,
362
+ commitment_weight=0.1,
363
+ decay=0.85,
364
+ heads=32,
365
+ kmeans_iters=20,
366
+ separate_codebook_per_head=True,
367
+ stochastic_sample_codes=True,
368
+ threshold_ema_dead_code=2,
369
+ )
370
+ self.out_feature_net = nn.Linear(512, hidden_channels)
371
+
372
+ self.encoder = attentions_onnx.Encoder(
373
+ hidden_channels,
374
+ filter_channels,
375
+ n_heads,
376
+ n_layers,
377
+ kernel_size,
378
+ p_dropout,
379
+ gin_channels=self.gin_channels,
380
+ )
381
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
382
+
383
+ def forward(
384
+ self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g=None
385
+ ):
386
+ x_mask = torch.ones_like(x).unsqueeze(0)
387
+ bert_emb = self.bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
388
+ ja_bert_emb = self.ja_bert_proj(ja_bert.transpose(0, 1).unsqueeze(0)).transpose(
389
+ 1, 2
390
+ )
391
+ en_bert_emb = self.en_bert_proj(en_bert.transpose(0, 1).unsqueeze(0)).transpose(
392
+ 1, 2
393
+ )
394
+ emo_emb = self.in_feature_net(emo.transpose(0, 1))
395
+ emo_emb, _, _ = self.emo_vq(emo_emb.unsqueeze(1))
396
+
397
+ emo_emb = self.out_feature_net(emo_emb)
398
+
399
+ x = (
400
+ self.emb(x)
401
+ + self.tone_emb(tone)
402
+ + self.language_emb(language)
403
+ + bert_emb
404
+ + ja_bert_emb
405
+ + en_bert_emb
406
+ + emo_emb
407
+ ) * math.sqrt(
408
+ self.hidden_channels
409
+ ) # [b, t, h]
410
+ x = torch.transpose(x, 1, -1) # [b, h, t]
411
+ x_mask = x_mask.to(x.dtype)
412
+
413
+ x = self.encoder(x * x_mask, x_mask, g=g)
414
+ stats = self.proj(x) * x_mask
415
+
416
+ m, logs = torch.split(stats, self.out_channels, dim=1)
417
+ return x, m, logs, x_mask
418
+
419
+
420
+ class ResidualCouplingBlock(nn.Module):
421
+ def __init__(
422
+ self,
423
+ channels,
424
+ hidden_channels,
425
+ kernel_size,
426
+ dilation_rate,
427
+ n_layers,
428
+ n_flows=4,
429
+ gin_channels=0,
430
+ ):
431
+ super().__init__()
432
+ self.channels = channels
433
+ self.hidden_channels = hidden_channels
434
+ self.kernel_size = kernel_size
435
+ self.dilation_rate = dilation_rate
436
+ self.n_layers = n_layers
437
+ self.n_flows = n_flows
438
+ self.gin_channels = gin_channels
439
+
440
+ self.flows = nn.ModuleList()
441
+ for i in range(n_flows):
442
+ self.flows.append(
443
+ modules.ResidualCouplingLayer(
444
+ channels,
445
+ hidden_channels,
446
+ kernel_size,
447
+ dilation_rate,
448
+ n_layers,
449
+ gin_channels=gin_channels,
450
+ mean_only=True,
451
+ )
452
+ )
453
+ self.flows.append(modules.Flip())
454
+
455
+ def forward(self, x, x_mask, g=None, reverse=True):
456
+ if not reverse:
457
+ for flow in self.flows:
458
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
459
+ else:
460
+ for flow in reversed(self.flows):
461
+ x = flow(x, x_mask, g=g, reverse=reverse)
462
+ return x
463
+
464
+
465
+ class PosteriorEncoder(nn.Module):
466
+ def __init__(
467
+ self,
468
+ in_channels,
469
+ out_channels,
470
+ hidden_channels,
471
+ kernel_size,
472
+ dilation_rate,
473
+ n_layers,
474
+ gin_channels=0,
475
+ ):
476
+ super().__init__()
477
+ self.in_channels = in_channels
478
+ self.out_channels = out_channels
479
+ self.hidden_channels = hidden_channels
480
+ self.kernel_size = kernel_size
481
+ self.dilation_rate = dilation_rate
482
+ self.n_layers = n_layers
483
+ self.gin_channels = gin_channels
484
+
485
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
486
+ self.enc = modules.WN(
487
+ hidden_channels,
488
+ kernel_size,
489
+ dilation_rate,
490
+ n_layers,
491
+ gin_channels=gin_channels,
492
+ )
493
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
494
+
495
+ def forward(self, x, x_lengths, g=None):
496
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
497
+ x.dtype
498
+ )
499
+ x = self.pre(x) * x_mask
500
+ x = self.enc(x, x_mask, g=g)
501
+ stats = self.proj(x) * x_mask
502
+ m, logs = torch.split(stats, self.out_channels, dim=1)
503
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
504
+ return z, m, logs, x_mask
505
+
506
+
507
+ class Generator(torch.nn.Module):
508
+ def __init__(
509
+ self,
510
+ initial_channel,
511
+ resblock,
512
+ resblock_kernel_sizes,
513
+ resblock_dilation_sizes,
514
+ upsample_rates,
515
+ upsample_initial_channel,
516
+ upsample_kernel_sizes,
517
+ gin_channels=0,
518
+ ):
519
+ super(Generator, self).__init__()
520
+ self.num_kernels = len(resblock_kernel_sizes)
521
+ self.num_upsamples = len(upsample_rates)
522
+ self.conv_pre = Conv1d(
523
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
524
+ )
525
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
526
+
527
+ self.ups = nn.ModuleList()
528
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
529
+ self.ups.append(
530
+ weight_norm(
531
+ ConvTranspose1d(
532
+ upsample_initial_channel // (2**i),
533
+ upsample_initial_channel // (2 ** (i + 1)),
534
+ k,
535
+ u,
536
+ padding=(k - u) // 2,
537
+ )
538
+ )
539
+ )
540
+
541
+ self.resblocks = nn.ModuleList()
542
+ for i in range(len(self.ups)):
543
+ ch = upsample_initial_channel // (2 ** (i + 1))
544
+ for j, (k, d) in enumerate(
545
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
546
+ ):
547
+ self.resblocks.append(resblock(ch, k, d))
548
+
549
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
550
+ self.ups.apply(init_weights)
551
+
552
+ if gin_channels != 0:
553
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
554
+
555
+ def forward(self, x, g=None):
556
+ x = self.conv_pre(x)
557
+ if g is not None:
558
+ x = x + self.cond(g)
559
+
560
+ for i in range(self.num_upsamples):
561
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
562
+ x = self.ups[i](x)
563
+ xs = None
564
+ for j in range(self.num_kernels):
565
+ if xs is None:
566
+ xs = self.resblocks[i * self.num_kernels + j](x)
567
+ else:
568
+ xs += self.resblocks[i * self.num_kernels + j](x)
569
+ x = xs / self.num_kernels
570
+ x = F.leaky_relu(x)
571
+ x = self.conv_post(x)
572
+ x = torch.tanh(x)
573
+
574
+ return x
575
+
576
+ def remove_weight_norm(self):
577
+ print("Removing weight norm...")
578
+ for layer in self.ups:
579
+ remove_weight_norm(layer)
580
+ for layer in self.resblocks:
581
+ layer.remove_weight_norm()
582
+
583
+
584
+ class DiscriminatorP(torch.nn.Module):
585
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
586
+ super(DiscriminatorP, self).__init__()
587
+ self.period = period
588
+ self.use_spectral_norm = use_spectral_norm
589
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
590
+ self.convs = nn.ModuleList(
591
+ [
592
+ norm_f(
593
+ Conv2d(
594
+ 1,
595
+ 32,
596
+ (kernel_size, 1),
597
+ (stride, 1),
598
+ padding=(get_padding(kernel_size, 1), 0),
599
+ )
600
+ ),
601
+ norm_f(
602
+ Conv2d(
603
+ 32,
604
+ 128,
605
+ (kernel_size, 1),
606
+ (stride, 1),
607
+ padding=(get_padding(kernel_size, 1), 0),
608
+ )
609
+ ),
610
+ norm_f(
611
+ Conv2d(
612
+ 128,
613
+ 512,
614
+ (kernel_size, 1),
615
+ (stride, 1),
616
+ padding=(get_padding(kernel_size, 1), 0),
617
+ )
618
+ ),
619
+ norm_f(
620
+ Conv2d(
621
+ 512,
622
+ 1024,
623
+ (kernel_size, 1),
624
+ (stride, 1),
625
+ padding=(get_padding(kernel_size, 1), 0),
626
+ )
627
+ ),
628
+ norm_f(
629
+ Conv2d(
630
+ 1024,
631
+ 1024,
632
+ (kernel_size, 1),
633
+ 1,
634
+ padding=(get_padding(kernel_size, 1), 0),
635
+ )
636
+ ),
637
+ ]
638
+ )
639
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
640
+
641
+ def forward(self, x):
642
+ fmap = []
643
+
644
+ # 1d to 2d
645
+ b, c, t = x.shape
646
+ if t % self.period != 0: # pad first
647
+ n_pad = self.period - (t % self.period)
648
+ x = F.pad(x, (0, n_pad), "reflect")
649
+ t = t + n_pad
650
+ x = x.view(b, c, t // self.period, self.period)
651
+
652
+ for layer in self.convs:
653
+ x = layer(x)
654
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
655
+ fmap.append(x)
656
+ x = self.conv_post(x)
657
+ fmap.append(x)
658
+ x = torch.flatten(x, 1, -1)
659
+
660
+ return x, fmap
661
+
662
+
663
+ class DiscriminatorS(torch.nn.Module):
664
+ def __init__(self, use_spectral_norm=False):
665
+ super(DiscriminatorS, self).__init__()
666
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
667
+ self.convs = nn.ModuleList(
668
+ [
669
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
670
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
671
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
672
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
673
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
674
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
675
+ ]
676
+ )
677
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
678
+
679
+ def forward(self, x):
680
+ fmap = []
681
+
682
+ for layer in self.convs:
683
+ x = layer(x)
684
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
685
+ fmap.append(x)
686
+ x = self.conv_post(x)
687
+ fmap.append(x)
688
+ x = torch.flatten(x, 1, -1)
689
+
690
+ return x, fmap
691
+
692
+
693
+ class MultiPeriodDiscriminator(torch.nn.Module):
694
+ def __init__(self, use_spectral_norm=False):
695
+ super(MultiPeriodDiscriminator, self).__init__()
696
+ periods = [2, 3, 5, 7, 11]
697
+
698
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
699
+ discs = discs + [
700
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
701
+ ]
702
+ self.discriminators = nn.ModuleList(discs)
703
+
704
+ def forward(self, y, y_hat):
705
+ y_d_rs = []
706
+ y_d_gs = []
707
+ fmap_rs = []
708
+ fmap_gs = []
709
+ for i, d in enumerate(self.discriminators):
710
+ y_d_r, fmap_r = d(y)
711
+ y_d_g, fmap_g = d(y_hat)
712
+ y_d_rs.append(y_d_r)
713
+ y_d_gs.append(y_d_g)
714
+ fmap_rs.append(fmap_r)
715
+ fmap_gs.append(fmap_g)
716
+
717
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
718
+
719
+
720
+ class ReferenceEncoder(nn.Module):
721
+ """
722
+ inputs --- [N, Ty/r, n_mels*r] mels
723
+ outputs --- [N, ref_enc_gru_size]
724
+ """
725
+
726
+ def __init__(self, spec_channels, gin_channels=0):
727
+ super().__init__()
728
+ self.spec_channels = spec_channels
729
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
730
+ K = len(ref_enc_filters)
731
+ filters = [1] + ref_enc_filters
732
+ convs = [
733
+ weight_norm(
734
+ nn.Conv2d(
735
+ in_channels=filters[i],
736
+ out_channels=filters[i + 1],
737
+ kernel_size=(3, 3),
738
+ stride=(2, 2),
739
+ padding=(1, 1),
740
+ )
741
+ )
742
+ for i in range(K)
743
+ ]
744
+ self.convs = nn.ModuleList(convs)
745
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
746
+
747
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
748
+ self.gru = nn.GRU(
749
+ input_size=ref_enc_filters[-1] * out_channels,
750
+ hidden_size=256 // 2,
751
+ batch_first=True,
752
+ )
753
+ self.proj = nn.Linear(128, gin_channels)
754
+
755
+ def forward(self, inputs, mask=None):
756
+ N = inputs.size(0)
757
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
758
+ for conv in self.convs:
759
+ out = conv(out)
760
+ # out = wn(out)
761
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
762
+
763
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
764
+ T = out.size(1)
765
+ N = out.size(0)
766
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
767
+
768
+ self.gru.flatten_parameters()
769
+ memory, out = self.gru(out) # out --- [1, N, 128]
770
+
771
+ return self.proj(out.squeeze(0))
772
+
773
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
774
+ for i in range(n_convs):
775
+ L = (L - kernel_size + 2 * pad) // stride + 1
776
+ return L
777
+
778
+
779
+ class SynthesizerTrn(nn.Module):
780
+ """
781
+ Synthesizer for Training
782
+ """
783
+
784
+ def __init__(
785
+ self,
786
+ n_vocab,
787
+ spec_channels,
788
+ segment_size,
789
+ inter_channels,
790
+ hidden_channels,
791
+ filter_channels,
792
+ n_heads,
793
+ n_layers,
794
+ kernel_size,
795
+ p_dropout,
796
+ resblock,
797
+ resblock_kernel_sizes,
798
+ resblock_dilation_sizes,
799
+ upsample_rates,
800
+ upsample_initial_channel,
801
+ upsample_kernel_sizes,
802
+ n_speakers=256,
803
+ gin_channels=256,
804
+ use_sdp=True,
805
+ n_flow_layer=4,
806
+ n_layers_trans_flow=4,
807
+ flow_share_parameter=False,
808
+ use_transformer_flow=True,
809
+ **kwargs,
810
+ ):
811
+ super().__init__()
812
+ self.n_vocab = n_vocab
813
+ self.spec_channels = spec_channels
814
+ self.inter_channels = inter_channels
815
+ self.hidden_channels = hidden_channels
816
+ self.filter_channels = filter_channels
817
+ self.n_heads = n_heads
818
+ self.n_layers = n_layers
819
+ self.kernel_size = kernel_size
820
+ self.p_dropout = p_dropout
821
+ self.resblock = resblock
822
+ self.resblock_kernel_sizes = resblock_kernel_sizes
823
+ self.resblock_dilation_sizes = resblock_dilation_sizes
824
+ self.upsample_rates = upsample_rates
825
+ self.upsample_initial_channel = upsample_initial_channel
826
+ self.upsample_kernel_sizes = upsample_kernel_sizes
827
+ self.segment_size = segment_size
828
+ self.n_speakers = n_speakers
829
+ self.gin_channels = gin_channels
830
+ self.n_layers_trans_flow = n_layers_trans_flow
831
+ self.use_spk_conditioned_encoder = kwargs.get(
832
+ "use_spk_conditioned_encoder", True
833
+ )
834
+ self.use_sdp = use_sdp
835
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
836
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
837
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
838
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
839
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
840
+ self.enc_gin_channels = gin_channels
841
+ self.enc_p = TextEncoder(
842
+ n_vocab,
843
+ inter_channels,
844
+ hidden_channels,
845
+ filter_channels,
846
+ n_heads,
847
+ n_layers,
848
+ kernel_size,
849
+ p_dropout,
850
+ self.n_speakers,
851
+ gin_channels=self.enc_gin_channels,
852
+ )
853
+ self.dec = Generator(
854
+ inter_channels,
855
+ resblock,
856
+ resblock_kernel_sizes,
857
+ resblock_dilation_sizes,
858
+ upsample_rates,
859
+ upsample_initial_channel,
860
+ upsample_kernel_sizes,
861
+ gin_channels=gin_channels,
862
+ )
863
+ self.enc_q = PosteriorEncoder(
864
+ spec_channels,
865
+ inter_channels,
866
+ hidden_channels,
867
+ 5,
868
+ 1,
869
+ 16,
870
+ gin_channels=gin_channels,
871
+ )
872
+ if use_transformer_flow:
873
+ self.flow = TransformerCouplingBlock(
874
+ inter_channels,
875
+ hidden_channels,
876
+ filter_channels,
877
+ n_heads,
878
+ n_layers_trans_flow,
879
+ 5,
880
+ p_dropout,
881
+ n_flow_layer,
882
+ gin_channels=gin_channels,
883
+ share_parameter=flow_share_parameter,
884
+ )
885
+ else:
886
+ self.flow = ResidualCouplingBlock(
887
+ inter_channels,
888
+ hidden_channels,
889
+ 5,
890
+ 1,
891
+ n_flow_layer,
892
+ gin_channels=gin_channels,
893
+ )
894
+ self.sdp = StochasticDurationPredictor(
895
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
896
+ )
897
+ self.dp = DurationPredictor(
898
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
899
+ )
900
+
901
+ if n_speakers >= 1:
902
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
903
+ else:
904
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
905
+
906
+ def export_onnx(
907
+ self,
908
+ path,
909
+ max_len=None,
910
+ sdp_ratio=0,
911
+ y=None,
912
+ ):
913
+ noise_scale = 0.667
914
+ length_scale = 1
915
+ noise_scale_w = 0.8
916
+ x = (
917
+ torch.LongTensor(
918
+ [
919
+ 0,
920
+ 97,
921
+ 0,
922
+ 8,
923
+ 0,
924
+ 78,
925
+ 0,
926
+ 8,
927
+ 0,
928
+ 76,
929
+ 0,
930
+ 37,
931
+ 0,
932
+ 40,
933
+ 0,
934
+ 97,
935
+ 0,
936
+ 8,
937
+ 0,
938
+ 23,
939
+ 0,
940
+ 8,
941
+ 0,
942
+ 74,
943
+ 0,
944
+ 26,
945
+ 0,
946
+ 104,
947
+ 0,
948
+ ]
949
+ )
950
+ .unsqueeze(0)
951
+ .cpu()
952
+ )
953
+ tone = torch.zeros_like(x).cpu()
954
+ language = torch.zeros_like(x).cpu()
955
+ x_lengths = torch.LongTensor([x.shape[1]]).cpu()
956
+ sid = torch.LongTensor([0]).cpu()
957
+ bert = torch.randn(size=(x.shape[1], 1024)).cpu()
958
+ ja_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
959
+ en_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
960
+
961
+ if self.n_speakers > 0:
962
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
963
+ torch.onnx.export(
964
+ self.emb_g,
965
+ (sid),
966
+ f"onnx/{path}/{path}_emb.onnx",
967
+ input_names=["sid"],
968
+ output_names=["g"],
969
+ verbose=True,
970
+ )
971
+ else:
972
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
973
+
974
+ emo = torch.randn(512, 1)
975
+
976
+ torch.onnx.export(
977
+ self.enc_p,
978
+ (x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g),
979
+ f"onnx/{path}/{path}_enc_p.onnx",
980
+ input_names=[
981
+ "x",
982
+ "x_lengths",
983
+ "t",
984
+ "language",
985
+ "bert_0",
986
+ "bert_1",
987
+ "bert_2",
988
+ "emo",
989
+ "g",
990
+ ],
991
+ output_names=["xout", "m_p", "logs_p", "x_mask"],
992
+ dynamic_axes={
993
+ "x": [0, 1],
994
+ "t": [0, 1],
995
+ "language": [0, 1],
996
+ "bert_0": [0],
997
+ "bert_1": [0],
998
+ "bert_2": [0],
999
+ "xout": [0, 2],
1000
+ "m_p": [0, 2],
1001
+ "logs_p": [0, 2],
1002
+ "x_mask": [0, 2],
1003
+ },
1004
+ verbose=True,
1005
+ opset_version=16,
1006
+ )
1007
+
1008
+ x, m_p, logs_p, x_mask = self.enc_p(
1009
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g
1010
+ )
1011
+
1012
+ zinput = (
1013
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
1014
+ * noise_scale_w
1015
+ )
1016
+ torch.onnx.export(
1017
+ self.sdp,
1018
+ (x, x_mask, zinput, g),
1019
+ f"onnx/{path}/{path}_sdp.onnx",
1020
+ input_names=["x", "x_mask", "zin", "g"],
1021
+ output_names=["logw"],
1022
+ dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
1023
+ verbose=True,
1024
+ )
1025
+ torch.onnx.export(
1026
+ self.dp,
1027
+ (x, x_mask, g),
1028
+ f"onnx/{path}/{path}_dp.onnx",
1029
+ input_names=["x", "x_mask", "g"],
1030
+ output_names=["logw"],
1031
+ dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
1032
+ verbose=True,
1033
+ )
1034
+ logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
1035
+ x, x_mask, g=g
1036
+ ) * (1 - sdp_ratio)
1037
+ w = torch.exp(logw) * x_mask * length_scale
1038
+ w_ceil = torch.ceil(w)
1039
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1040
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1041
+ x_mask.dtype
1042
+ )
1043
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1044
+ attn = commons.generate_path(w_ceil, attn_mask)
1045
+
1046
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1047
+ 1, 2
1048
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1049
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1050
+ 1, 2
1051
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1052
+
1053
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1054
+ torch.onnx.export(
1055
+ self.flow,
1056
+ (z_p, y_mask, g),
1057
+ f"onnx/{path}/{path}_flow.onnx",
1058
+ input_names=["z_p", "y_mask", "g"],
1059
+ output_names=["z"],
1060
+ dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
1061
+ verbose=True,
1062
+ )
1063
+
1064
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1065
+ z_in = (z * y_mask)[:, :, :max_len]
1066
+
1067
+ torch.onnx.export(
1068
+ self.dec,
1069
+ (z_in, g),
1070
+ f"onnx/{path}/{path}_dec.onnx",
1071
+ input_names=["z_in", "g"],
1072
+ output_names=["o"],
1073
+ dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
1074
+ verbose=True,
1075
+ )
1076
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
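For context, a hedged driver sketch for export_onnx above (not part of the diff): the hyperparameter values are illustrative assumptions in the usual Bert-VITS2 range rather than values read from this repo's config.json, "BertVits" is a placeholder export name, and a real export would load trained weights into the model first. The script assumes it is run from the repo root so that the top-level commons/modules imports in models_onnx resolve.

import os
from onnx_modules.V220.models_onnx import SynthesizerTrn
from onnx_modules.V220.text.symbols import symbols

model = SynthesizerTrn(
    n_vocab=len(symbols),
    spec_channels=1025,            # illustrative: filter_length // 2 + 1
    segment_size=32,
    inter_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 8, 2, 2],
    n_speakers=1,
    gin_channels=256,
)
# model.load_state_dict(...) would go here for a real export; random weights still trace.
model.eval()
name = "BertVits"                  # placeholder export name
os.makedirs(f"onnx/{name}", exist_ok=True)
model.export_onnx(name)            # writes {name}_emb/_enc_p/_sdp/_dp/_flow/_dec.onnx under onnx/BertVits/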
onnx_modules/V220/text/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .symbols import *
onnx_modules/V220/text/symbols.py ADDED
@@ -0,0 +1,187 @@
1
+ punctuation = ["!", "?", "…", ",", ".", "'", "-"]
2
+ pu_symbols = punctuation + ["SP", "UNK"]
3
+ pad = "_"
4
+
5
+ # chinese
6
+ zh_symbols = [
7
+ "E",
8
+ "En",
9
+ "a",
10
+ "ai",
11
+ "an",
12
+ "ang",
13
+ "ao",
14
+ "b",
15
+ "c",
16
+ "ch",
17
+ "d",
18
+ "e",
19
+ "ei",
20
+ "en",
21
+ "eng",
22
+ "er",
23
+ "f",
24
+ "g",
25
+ "h",
26
+ "i",
27
+ "i0",
28
+ "ia",
29
+ "ian",
30
+ "iang",
31
+ "iao",
32
+ "ie",
33
+ "in",
34
+ "ing",
35
+ "iong",
36
+ "ir",
37
+ "iu",
38
+ "j",
39
+ "k",
40
+ "l",
41
+ "m",
42
+ "n",
43
+ "o",
44
+ "ong",
45
+ "ou",
46
+ "p",
47
+ "q",
48
+ "r",
49
+ "s",
50
+ "sh",
51
+ "t",
52
+ "u",
53
+ "ua",
54
+ "uai",
55
+ "uan",
56
+ "uang",
57
+ "ui",
58
+ "un",
59
+ "uo",
60
+ "v",
61
+ "van",
62
+ "ve",
63
+ "vn",
64
+ "w",
65
+ "x",
66
+ "y",
67
+ "z",
68
+ "zh",
69
+ "AA",
70
+ "EE",
71
+ "OO",
72
+ ]
73
+ num_zh_tones = 6
74
+
75
+ # japanese
76
+ ja_symbols = [
77
+ "N",
78
+ "a",
79
+ "a:",
80
+ "b",
81
+ "by",
82
+ "ch",
83
+ "d",
84
+ "dy",
85
+ "e",
86
+ "e:",
87
+ "f",
88
+ "g",
89
+ "gy",
90
+ "h",
91
+ "hy",
92
+ "i",
93
+ "i:",
94
+ "j",
95
+ "k",
96
+ "ky",
97
+ "m",
98
+ "my",
99
+ "n",
100
+ "ny",
101
+ "o",
102
+ "o:",
103
+ "p",
104
+ "py",
105
+ "q",
106
+ "r",
107
+ "ry",
108
+ "s",
109
+ "sh",
110
+ "t",
111
+ "ts",
112
+ "ty",
113
+ "u",
114
+ "u:",
115
+ "w",
116
+ "y",
117
+ "z",
118
+ "zy",
119
+ ]
120
+ num_ja_tones = 2
121
+
122
+ # English
123
+ en_symbols = [
124
+ "aa",
125
+ "ae",
126
+ "ah",
127
+ "ao",
128
+ "aw",
129
+ "ay",
130
+ "b",
131
+ "ch",
132
+ "d",
133
+ "dh",
134
+ "eh",
135
+ "er",
136
+ "ey",
137
+ "f",
138
+ "g",
139
+ "hh",
140
+ "ih",
141
+ "iy",
142
+ "jh",
143
+ "k",
144
+ "l",
145
+ "m",
146
+ "n",
147
+ "ng",
148
+ "ow",
149
+ "oy",
150
+ "p",
151
+ "r",
152
+ "s",
153
+ "sh",
154
+ "t",
155
+ "th",
156
+ "uh",
157
+ "uw",
158
+ "V",
159
+ "w",
160
+ "y",
161
+ "z",
162
+ "zh",
163
+ ]
164
+ num_en_tones = 4
165
+
166
+ # combine all symbols
167
+ normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
168
+ symbols = [pad] + normal_symbols + pu_symbols
169
+ sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
170
+
171
+ # combine all tones
172
+ num_tones = num_zh_tones + num_ja_tones + num_en_tones
173
+
174
+ # language maps
175
+ language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
176
+ num_languages = len(language_id_map.keys())
177
+
178
+ language_tone_start_map = {
179
+ "ZH": 0,
180
+ "JP": num_zh_tones,
181
+ "EN": num_zh_tones + num_ja_tones,
182
+ }
183
+
184
+ if __name__ == "__main__":
185
+ a = set(zh_symbols)
186
+ b = set(en_symbols)
187
+ print(sorted(a & b))
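As an aside, a small hedged sketch (not from the diff) of how the tables above are typically combined to turn cleaned phones, tones, and a language tag into model inputs; the helper name to_sequence is hypothetical, not a function defined in this repo:

from onnx_modules.V220.text.symbols import (
    symbols,
    language_id_map,
    language_tone_start_map,
)

_symbol_to_id = {s: i for i, s in enumerate(symbols)}

def to_sequence(phones, tones, language):
    # map phones to symbol ids, shift tones into the language's tone range,
    # and repeat the language id once per phone
    phone_ids = [_symbol_to_id[p] for p in phones]
    tone_ids = [t + language_tone_start_map[language] for t in tones]
    lang_ids = [language_id_map[language]] * len(phone_ids)
    return phone_ids, tone_ids, lang_ids

print(to_sequence(["sh", "i0"], [4, 4], "ZH"))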
onnx_modules/V220_OnnxInference/__init__.py ADDED
@@ -0,0 +1,128 @@
1
+ import numpy as np
2
+ import onnxruntime as ort
3
+
4
+
5
+ def convert_pad_shape(pad_shape):
6
+ layer = pad_shape[::-1]
7
+ pad_shape = [item for sublist in layer for item in sublist]
8
+ return pad_shape
9
+
10
+
11
+ def sequence_mask(length, max_length=None):
12
+ if max_length is None:
13
+ max_length = length.max()
14
+ x = np.arange(max_length, dtype=length.dtype)
15
+ return np.expand_dims(x, 0) < np.expand_dims(length, 1)
16
+
17
+
18
+ def generate_path(duration, mask):
19
+ """
20
+ duration: [b, 1, t_x]
21
+ mask: [b, 1, t_y, t_x]
22
+ """
23
+
24
+ b, _, t_y, t_x = mask.shape
25
+ cum_duration = np.cumsum(duration, -1)
26
+
27
+ cum_duration_flat = cum_duration.reshape(b * t_x)
28
+ path = sequence_mask(cum_duration_flat, t_y)
29
+ path = path.reshape(b, t_x, t_y)
30
+ path = path ^ np.pad(path, ((0, 0), (1, 0), (0, 0)))[:, :-1]
31
+ path = np.expand_dims(path, 1).transpose(0, 1, 3, 2)
32
+ return path
33
+
34
+
35
+ class OnnxInferenceSession:
36
+ def __init__(self, path, Providers=["CPUExecutionProvider"]):
37
+ self.enc = ort.InferenceSession(path["enc"], providers=Providers)
38
+ self.emb_g = ort.InferenceSession(path["emb_g"], providers=Providers)
39
+ self.dp = ort.InferenceSession(path["dp"], providers=Providers)
40
+ self.sdp = ort.InferenceSession(path["sdp"], providers=Providers)
41
+ self.flow = ort.InferenceSession(path["flow"], providers=Providers)
42
+ self.dec = ort.InferenceSession(path["dec"], providers=Providers)
43
+
44
+ def __call__(
45
+ self,
46
+ seq,
47
+ tone,
48
+ language,
49
+ bert_zh,
50
+ bert_jp,
51
+ bert_en,
52
+ emo,
53
+ sid,
54
+ seed=114514,
55
+ seq_noise_scale=0.8,
56
+ sdp_noise_scale=0.6,
57
+ length_scale=1.0,
58
+ sdp_ratio=0.0,
59
+ ):
60
+ if seq.ndim == 1:
61
+ seq = np.expand_dims(seq, 0)
62
+ if tone.ndim == 1:
63
+ tone = np.expand_dims(tone, 0)
64
+ if language.ndim == 1:
65
+ language = np.expand_dims(language, 0)
66
+ assert seq.ndim == 2 and tone.ndim == 2 and language.ndim == 2  # check each rank; asserting a tuple is always truthy
67
+ g = self.emb_g.run(
68
+ None,
69
+ {
70
+ "sid": sid.astype(np.int64),
71
+ },
72
+ )[0]
73
+ g = np.expand_dims(g, -1)
74
+ enc_rtn = self.enc.run(
75
+ None,
76
+ {
77
+ "x": seq.astype(np.int64),
78
+ "t": tone.astype(np.int64),
79
+ "language": language.astype(np.int64),
80
+ "bert_0": bert_zh.astype(np.float32),
81
+ "bert_1": bert_jp.astype(np.float32),
82
+ "bert_2": bert_en.astype(np.float32),
83
+ "emo": emo.astype(np.float32),
84
+ "g": g.astype(np.float32),
85
+ },
86
+ )
87
+ x, m_p, logs_p, x_mask = enc_rtn[0], enc_rtn[1], enc_rtn[2], enc_rtn[3]
88
+ np.random.seed(seed)
89
+ zinput = np.random.randn(x.shape[0], 2, x.shape[2]) * sdp_noise_scale
90
+ logw = self.sdp.run(
91
+ None, {"x": x, "x_mask": x_mask, "zin": zinput.astype(np.float32), "g": g}
92
+ )[0] * (sdp_ratio) + self.dp.run(None, {"x": x, "x_mask": x_mask, "g": g})[
93
+ 0
94
+ ] * (
95
+ 1 - sdp_ratio
96
+ )
97
+ w = np.exp(logw) * x_mask * length_scale
98
+ w_ceil = np.ceil(w)
99
+ y_lengths = np.clip(np.sum(w_ceil, (1, 2)), a_min=1.0, a_max=100000).astype(
100
+ np.int64
101
+ )
102
+ y_mask = np.expand_dims(sequence_mask(y_lengths, None), 1)
103
+ attn_mask = np.expand_dims(x_mask, 2) * np.expand_dims(y_mask, -1)
104
+ attn = generate_path(w_ceil, attn_mask)
105
+ m_p = np.matmul(attn.squeeze(1), m_p.transpose(0, 2, 1)).transpose(
106
+ 0, 2, 1
107
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
108
+ logs_p = np.matmul(attn.squeeze(1), logs_p.transpose(0, 2, 1)).transpose(
109
+ 0, 2, 1
110
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
111
+
112
+ z_p = (
113
+ m_p
114
+ + np.random.randn(m_p.shape[0], m_p.shape[1], m_p.shape[2])
115
+ * np.exp(logs_p)
116
+ * seq_noise_scale
117
+ )
118
+
119
+ z = self.flow.run(
120
+ None,
121
+ {
122
+ "z_p": z_p.astype(np.float32),
123
+ "y_mask": y_mask.astype(np.float32),
124
+ "g": g,
125
+ },
126
+ )[0]
127
+
128
+ return self.dec.run(None, {"z_in": z.astype(np.float32), "g": g})[0]
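A minimal usage sketch for OnnxInferenceSession above (an assumption-laden example, not repo code): the model name "BertVits" and the onnx/<name>/ layout follow the f"onnx/{path}/{path}_*.onnx" naming used by export_onnx, and the zero-filled inputs only exercise the graph shapes; real inputs come from the text/BERT/CLAP front ends.

import numpy as np
from onnx_modules.V220_OnnxInference import OnnxInferenceSession

session = OnnxInferenceSession(
    {
        "enc": "onnx/BertVits/BertVits_enc_p.onnx",
        "emb_g": "onnx/BertVits/BertVits_emb.onnx",
        "dp": "onnx/BertVits/BertVits_dp.onnx",
        "sdp": "onnx/BertVits/BertVits_sdp.onnx",
        "flow": "onnx/BertVits/BertVits_flow.onnx",
        "dec": "onnx/BertVits/BertVits_dec.onnx",
    },
    Providers=["CPUExecutionProvider"],
)

T = 32                                           # phoneme sequence length
seq = np.zeros(T, dtype=np.int64)                # phone ids
tone = np.zeros(T, dtype=np.int64)
language = np.zeros(T, dtype=np.int64)
bert_zh = np.zeros((T, 1024), dtype=np.float32)  # per-phone BERT features
bert_jp = np.zeros((T, 1024), dtype=np.float32)
bert_en = np.zeros((T, 1024), dtype=np.float32)
emo = np.zeros((512, 1), dtype=np.float32)       # emotion embedding, shaped as in export_onnx
sid = np.array([0], dtype=np.int64)              # speaker id

audio = session(seq, tone, language, bert_zh, bert_jp, bert_en, emo, sid)
print(audio.shape)                               # [1, 1, num_samples]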
onnx_modules/V220_novq_dev/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .text.symbols import symbols
2
+ from .models_onnx import SynthesizerTrn
3
+
4
+ __all__ = ["symbols", "SynthesizerTrn"]
onnx_modules/V220_novq_dev/attentions_onnx.py ADDED
@@ -0,0 +1,378 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class LayerNorm(nn.Module):
13
+ def __init__(self, channels, eps=1e-5):
14
+ super().__init__()
15
+ self.channels = channels
16
+ self.eps = eps
17
+
18
+ self.gamma = nn.Parameter(torch.ones(channels))
19
+ self.beta = nn.Parameter(torch.zeros(channels))
20
+
21
+ def forward(self, x):
22
+ x = x.transpose(1, -1)
23
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
24
+ return x.transpose(1, -1)
25
+
26
+
27
+ @torch.jit.script
28
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
29
+ n_channels_int = n_channels[0]
30
+ in_act = input_a + input_b
31
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
32
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
33
+ acts = t_act * s_act
34
+ return acts
35
+
36
+
37
+ class Encoder(nn.Module):
38
+ def __init__(
39
+ self,
40
+ hidden_channels,
41
+ filter_channels,
42
+ n_heads,
43
+ n_layers,
44
+ kernel_size=1,
45
+ p_dropout=0.0,
46
+ window_size=4,
47
+ isflow=True,
48
+ **kwargs
49
+ ):
50
+ super().__init__()
51
+ self.hidden_channels = hidden_channels
52
+ self.filter_channels = filter_channels
53
+ self.n_heads = n_heads
54
+ self.n_layers = n_layers
55
+ self.kernel_size = kernel_size
56
+ self.p_dropout = p_dropout
57
+ self.window_size = window_size
58
+ # if isflow:
59
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
60
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
61
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
62
+ # self.gin_channels = 256
63
+ self.cond_layer_idx = self.n_layers
64
+ if "gin_channels" in kwargs:
65
+ self.gin_channels = kwargs["gin_channels"]
66
+ if self.gin_channels != 0:
67
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
68
+ # vits2 says 3rd block, so idx is 2 by default
69
+ self.cond_layer_idx = (
70
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
71
+ )
72
+ logger.debug("gin_channels: %s, cond_layer_idx: %s", self.gin_channels, self.cond_layer_idx)
73
+ assert (
74
+ self.cond_layer_idx < self.n_layers
75
+ ), "cond_layer_idx should be less than n_layers"
76
+ self.drop = nn.Dropout(p_dropout)
77
+ self.attn_layers = nn.ModuleList()
78
+ self.norm_layers_1 = nn.ModuleList()
79
+ self.ffn_layers = nn.ModuleList()
80
+ self.norm_layers_2 = nn.ModuleList()
81
+ for i in range(self.n_layers):
82
+ self.attn_layers.append(
83
+ MultiHeadAttention(
84
+ hidden_channels,
85
+ hidden_channels,
86
+ n_heads,
87
+ p_dropout=p_dropout,
88
+ window_size=window_size,
89
+ )
90
+ )
91
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
92
+ self.ffn_layers.append(
93
+ FFN(
94
+ hidden_channels,
95
+ hidden_channels,
96
+ filter_channels,
97
+ kernel_size,
98
+ p_dropout=p_dropout,
99
+ )
100
+ )
101
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
102
+
103
+ def forward(self, x, x_mask, g=None):
104
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
105
+ x = x * x_mask
106
+ for i in range(self.n_layers):
107
+ if i == self.cond_layer_idx and g is not None:
108
+ g = self.spk_emb_linear(g.transpose(1, 2))
109
+ g = g.transpose(1, 2)
110
+ x = x + g
111
+ x = x * x_mask
112
+ y = self.attn_layers[i](x, x, attn_mask)
113
+ y = self.drop(y)
114
+ x = self.norm_layers_1[i](x + y)
115
+
116
+ y = self.ffn_layers[i](x, x_mask)
117
+ y = self.drop(y)
118
+ x = self.norm_layers_2[i](x + y)
119
+ x = x * x_mask
120
+ return x
121
+
122
+
123
+ class MultiHeadAttention(nn.Module):
124
+ def __init__(
125
+ self,
126
+ channels,
127
+ out_channels,
128
+ n_heads,
129
+ p_dropout=0.0,
130
+ window_size=None,
131
+ heads_share=True,
132
+ block_length=None,
133
+ proximal_bias=False,
134
+ proximal_init=False,
135
+ ):
136
+ super().__init__()
137
+ assert channels % n_heads == 0
138
+
139
+ self.channels = channels
140
+ self.out_channels = out_channels
141
+ self.n_heads = n_heads
142
+ self.p_dropout = p_dropout
143
+ self.window_size = window_size
144
+ self.heads_share = heads_share
145
+ self.block_length = block_length
146
+ self.proximal_bias = proximal_bias
147
+ self.proximal_init = proximal_init
148
+ self.attn = None
149
+
150
+ self.k_channels = channels // n_heads
151
+ self.conv_q = nn.Conv1d(channels, channels, 1)
152
+ self.conv_k = nn.Conv1d(channels, channels, 1)
153
+ self.conv_v = nn.Conv1d(channels, channels, 1)
154
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
155
+ self.drop = nn.Dropout(p_dropout)
156
+
157
+ if window_size is not None:
158
+ n_heads_rel = 1 if heads_share else n_heads
159
+ rel_stddev = self.k_channels**-0.5
160
+ self.emb_rel_k = nn.Parameter(
161
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
162
+ * rel_stddev
163
+ )
164
+ self.emb_rel_v = nn.Parameter(
165
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
166
+ * rel_stddev
167
+ )
168
+
169
+ nn.init.xavier_uniform_(self.conv_q.weight)
170
+ nn.init.xavier_uniform_(self.conv_k.weight)
171
+ nn.init.xavier_uniform_(self.conv_v.weight)
172
+ if proximal_init:
173
+ with torch.no_grad():
174
+ self.conv_k.weight.copy_(self.conv_q.weight)
175
+ self.conv_k.bias.copy_(self.conv_q.bias)
176
+
177
+ def forward(self, x, c, attn_mask=None):
178
+ q = self.conv_q(x)
179
+ k = self.conv_k(c)
180
+ v = self.conv_v(c)
181
+
182
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
183
+
184
+ x = self.conv_o(x)
185
+ return x
186
+
187
+ def attention(self, query, key, value, mask=None):
188
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
189
+ b, d, t_s, t_t = (*key.size(), query.size(2))
190
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
191
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
192
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
193
+
194
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
195
+ if self.window_size is not None:
196
+ assert (
197
+ t_s == t_t
198
+ ), "Relative attention is only available for self-attention."
199
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
200
+ rel_logits = self._matmul_with_relative_keys(
201
+ query / math.sqrt(self.k_channels), key_relative_embeddings
202
+ )
203
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
204
+ scores = scores + scores_local
205
+ if self.proximal_bias:
206
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
207
+ scores = scores + self._attention_bias_proximal(t_s).to(
208
+ device=scores.device, dtype=scores.dtype
209
+ )
210
+ if mask is not None:
211
+ scores = scores.masked_fill(mask == 0, -1e4)
212
+ if self.block_length is not None:
213
+ assert (
214
+ t_s == t_t
215
+ ), "Local attention is only available for self-attention."
216
+ block_mask = (
217
+ torch.ones_like(scores)
218
+ .triu(-self.block_length)
219
+ .tril(self.block_length)
220
+ )
221
+ scores = scores.masked_fill(block_mask == 0, -1e4)
222
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
223
+ p_attn = self.drop(p_attn)
224
+ output = torch.matmul(p_attn, value)
225
+ if self.window_size is not None:
226
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
227
+ value_relative_embeddings = self._get_relative_embeddings(
228
+ self.emb_rel_v, t_s
229
+ )
230
+ output = output + self._matmul_with_relative_values(
231
+ relative_weights, value_relative_embeddings
232
+ )
233
+ output = (
234
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
235
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
236
+ return output, p_attn
237
+
238
+ def _matmul_with_relative_values(self, x, y):
239
+ """
240
+ x: [b, h, l, m]
241
+ y: [h or 1, m, d]
242
+ ret: [b, h, l, d]
243
+ """
244
+ ret = torch.matmul(x, y.unsqueeze(0))
245
+ return ret
246
+
247
+ def _matmul_with_relative_keys(self, x, y):
248
+ """
249
+ x: [b, h, l, d]
250
+ y: [h or 1, m, d]
251
+ ret: [b, h, l, m]
252
+ """
253
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
254
+ return ret
255
+
256
+ def _get_relative_embeddings(self, relative_embeddings, length):
257
+ max_relative_position = 2 * self.window_size + 1
258
+ # Pad first before slice to avoid using cond ops.
259
+ pad_length = max(length - (self.window_size + 1), 0)
260
+ slice_start_position = max((self.window_size + 1) - length, 0)
261
+ slice_end_position = slice_start_position + 2 * length - 1
262
+ if pad_length > 0:
263
+ padded_relative_embeddings = F.pad(
264
+ relative_embeddings,
265
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
266
+ )
267
+ else:
268
+ padded_relative_embeddings = relative_embeddings
269
+ used_relative_embeddings = padded_relative_embeddings[
270
+ :, slice_start_position:slice_end_position
271
+ ]
272
+ return used_relative_embeddings
273
+
274
+ def _relative_position_to_absolute_position(self, x):
275
+ """
276
+ x: [b, h, l, 2*l-1]
277
+ ret: [b, h, l, l]
278
+ """
279
+ batch, heads, length, _ = x.size()
280
+ # Concat columns of pad to shift from relative to absolute indexing.
281
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
282
+
283
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
284
+ x_flat = x.view([batch, heads, length * 2 * length])
285
+ x_flat = F.pad(
286
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
287
+ )
288
+
289
+ # Reshape and slice out the padded elements.
290
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
291
+ :, :, :length, length - 1 :
292
+ ]
293
+ return x_final
294
+
295
+ def _absolute_position_to_relative_position(self, x):
296
+ """
297
+ x: [b, h, l, l]
298
+ ret: [b, h, l, 2*l-1]
299
+ """
300
+ batch, heads, length, _ = x.size()
301
+ # pad along the column dimension
302
+ x = F.pad(
303
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
304
+ )
305
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
306
+ # add zeros at the beginning so the subsequent reshape skews the elements
307
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
308
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
309
+ return x_final
310
+
311
+ def _attention_bias_proximal(self, length):
312
+ """Bias for self-attention to encourage attention to close positions.
313
+ Args:
314
+ length: an integer scalar.
315
+ Returns:
316
+ a Tensor with shape [1, 1, length, length]
317
+ """
318
+ r = torch.arange(length, dtype=torch.float32)
319
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
320
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
321
+
322
+
323
+ class FFN(nn.Module):
324
+ def __init__(
325
+ self,
326
+ in_channels,
327
+ out_channels,
328
+ filter_channels,
329
+ kernel_size,
330
+ p_dropout=0.0,
331
+ activation=None,
332
+ causal=False,
333
+ ):
334
+ super().__init__()
335
+ self.in_channels = in_channels
336
+ self.out_channels = out_channels
337
+ self.filter_channels = filter_channels
338
+ self.kernel_size = kernel_size
339
+ self.p_dropout = p_dropout
340
+ self.activation = activation
341
+ self.causal = causal
342
+
343
+ if causal:
344
+ self.padding = self._causal_padding
345
+ else:
346
+ self.padding = self._same_padding
347
+
348
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
349
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
350
+ self.drop = nn.Dropout(p_dropout)
351
+
352
+ def forward(self, x, x_mask):
353
+ x = self.conv_1(self.padding(x * x_mask))
354
+ if self.activation == "gelu":
355
+ x = x * torch.sigmoid(1.702 * x)
356
+ else:
357
+ x = torch.relu(x)
358
+ x = self.drop(x)
359
+ x = self.conv_2(self.padding(x * x_mask))
360
+ return x * x_mask
361
+
362
+ def _causal_padding(self, x):
363
+ if self.kernel_size == 1:
364
+ return x
365
+ pad_l = self.kernel_size - 1
366
+ pad_r = 0
367
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
368
+ x = F.pad(x, commons.convert_pad_shape(padding))
369
+ return x
370
+
371
+ def _same_padding(self, x):
372
+ if self.kernel_size == 1:
373
+ return x
374
+ pad_l = (self.kernel_size - 1) // 2
375
+ pad_r = self.kernel_size // 2
376
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
377
+ x = F.pad(x, commons.convert_pad_shape(padding))
378
+ return x
onnx_modules/V220_novq_dev/models_onnx.py ADDED
@@ -0,0 +1,1048 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import modules
8
+ from . import attentions_onnx
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from commons import init_weights, get_padding
13
+ from .text import symbols, num_tones, num_languages
14
+
15
+
16
+ class DurationDiscriminator(nn.Module): # vits2
17
+ def __init__(
18
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
19
+ ):
20
+ super().__init__()
21
+
22
+ self.in_channels = in_channels
23
+ self.filter_channels = filter_channels
24
+ self.kernel_size = kernel_size
25
+ self.p_dropout = p_dropout
26
+ self.gin_channels = gin_channels
27
+
28
+ self.drop = nn.Dropout(p_dropout)
29
+ self.conv_1 = nn.Conv1d(
30
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
31
+ )
32
+ self.norm_1 = modules.LayerNorm(filter_channels)
33
+ self.conv_2 = nn.Conv1d(
34
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
35
+ )
36
+ self.norm_2 = modules.LayerNorm(filter_channels)
37
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
38
+
39
+ self.pre_out_conv_1 = nn.Conv1d(
40
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
41
+ )
42
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
43
+ self.pre_out_conv_2 = nn.Conv1d(
44
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
45
+ )
46
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
47
+
48
+ if gin_channels != 0:
49
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
50
+
51
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
52
+
53
+ def forward_probability(self, x, x_mask, dur, g=None):
54
+ dur = self.dur_proj(dur)
55
+ x = torch.cat([x, dur], dim=1)
56
+ x = self.pre_out_conv_1(x * x_mask)
57
+ x = torch.relu(x)
58
+ x = self.pre_out_norm_1(x)
59
+ x = self.drop(x)
60
+ x = self.pre_out_conv_2(x * x_mask)
61
+ x = torch.relu(x)
62
+ x = self.pre_out_norm_2(x)
63
+ x = self.drop(x)
64
+ x = x * x_mask
65
+ x = x.transpose(1, 2)
66
+ output_prob = self.output_layer(x)
67
+ return output_prob
68
+
69
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
70
+ x = torch.detach(x)
71
+ if g is not None:
72
+ g = torch.detach(g)
73
+ x = x + self.cond(g)
74
+ x = self.conv_1(x * x_mask)
75
+ x = torch.relu(x)
76
+ x = self.norm_1(x)
77
+ x = self.drop(x)
78
+ x = self.conv_2(x * x_mask)
79
+ x = torch.relu(x)
80
+ x = self.norm_2(x)
81
+ x = self.drop(x)
82
+
83
+ output_probs = []
84
+ for dur in [dur_r, dur_hat]:
85
+ output_prob = self.forward_probability(x, x_mask, dur, g)
86
+ output_probs.append(output_prob)
87
+
88
+ return output_probs
89
+
90
+
91
+ class TransformerCouplingBlock(nn.Module):
92
+ def __init__(
93
+ self,
94
+ channels,
95
+ hidden_channels,
96
+ filter_channels,
97
+ n_heads,
98
+ n_layers,
99
+ kernel_size,
100
+ p_dropout,
101
+ n_flows=4,
102
+ gin_channels=0,
103
+ share_parameter=False,
104
+ ):
105
+ super().__init__()
106
+ self.channels = channels
107
+ self.hidden_channels = hidden_channels
108
+ self.kernel_size = kernel_size
109
+ self.n_layers = n_layers
110
+ self.n_flows = n_flows
111
+ self.gin_channels = gin_channels
112
+
113
+ self.flows = nn.ModuleList()
114
+
115
+ self.wn = (
116
+ attentions_onnx.FFT(
117
+ hidden_channels,
118
+ filter_channels,
119
+ n_heads,
120
+ n_layers,
121
+ kernel_size,
122
+ p_dropout,
123
+ isflow=True,
124
+ gin_channels=self.gin_channels,
125
+ )
126
+ if share_parameter
127
+ else None
128
+ )
129
+
130
+ for i in range(n_flows):
131
+ self.flows.append(
132
+ modules.TransformerCouplingLayer(
133
+ channels,
134
+ hidden_channels,
135
+ kernel_size,
136
+ n_layers,
137
+ n_heads,
138
+ p_dropout,
139
+ filter_channels,
140
+ mean_only=True,
141
+ wn_sharing_parameter=self.wn,
142
+ gin_channels=self.gin_channels,
143
+ )
144
+ )
145
+ self.flows.append(modules.Flip())
146
+
147
+ def forward(self, x, x_mask, g=None, reverse=True):
148
+ if not reverse:
149
+ for flow in self.flows:
150
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
151
+ else:
152
+ for flow in reversed(self.flows):
153
+ x = flow(x, x_mask, g=g, reverse=reverse)
154
+ return x
155
+
156
+
157
+ class StochasticDurationPredictor(nn.Module):
158
+ def __init__(
159
+ self,
160
+ in_channels,
161
+ filter_channels,
162
+ kernel_size,
163
+ p_dropout,
164
+ n_flows=4,
165
+ gin_channels=0,
166
+ ):
167
+ super().__init__()
168
+ filter_channels = in_channels  # TODO: this override should be removed in a future version.
169
+ self.in_channels = in_channels
170
+ self.filter_channels = filter_channels
171
+ self.kernel_size = kernel_size
172
+ self.p_dropout = p_dropout
173
+ self.n_flows = n_flows
174
+ self.gin_channels = gin_channels
175
+
176
+ self.log_flow = modules.Log()
177
+ self.flows = nn.ModuleList()
178
+ self.flows.append(modules.ElementwiseAffine(2))
179
+ for i in range(n_flows):
180
+ self.flows.append(
181
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
182
+ )
183
+ self.flows.append(modules.Flip())
184
+
185
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
186
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
187
+ self.post_convs = modules.DDSConv(
188
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
189
+ )
190
+ self.post_flows = nn.ModuleList()
191
+ self.post_flows.append(modules.ElementwiseAffine(2))
192
+ for i in range(4):
193
+ self.post_flows.append(
194
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
195
+ )
196
+ self.post_flows.append(modules.Flip())
197
+
198
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
199
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
200
+ self.convs = modules.DDSConv(
201
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
202
+ )
203
+ if gin_channels != 0:
204
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
205
+
206
+ def forward(self, x, x_mask, z, g=None):
207
+ x = torch.detach(x)
208
+ x = self.pre(x)
209
+ if g is not None:
210
+ g = torch.detach(g)
211
+ x = x + self.cond(g)
212
+ x = self.convs(x, x_mask)
213
+ x = self.proj(x) * x_mask
214
+
215
+ flows = list(reversed(self.flows))
216
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
217
+ for flow in flows:
218
+ z = flow(z, x_mask, g=x, reverse=True)
219
+ z0, z1 = torch.split(z, [1, 1], 1)
220
+ logw = z0
221
+ return logw
222
+
223
+
224
+ class DurationPredictor(nn.Module):
225
+ def __init__(
226
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
227
+ ):
228
+ super().__init__()
229
+
230
+ self.in_channels = in_channels
231
+ self.filter_channels = filter_channels
232
+ self.kernel_size = kernel_size
233
+ self.p_dropout = p_dropout
234
+ self.gin_channels = gin_channels
235
+
236
+ self.drop = nn.Dropout(p_dropout)
237
+ self.conv_1 = nn.Conv1d(
238
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
239
+ )
240
+ self.norm_1 = modules.LayerNorm(filter_channels)
241
+ self.conv_2 = nn.Conv1d(
242
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
243
+ )
244
+ self.norm_2 = modules.LayerNorm(filter_channels)
245
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
246
+
247
+ if gin_channels != 0:
248
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
249
+
250
+ def forward(self, x, x_mask, g=None):
251
+ x = torch.detach(x)
252
+ if g is not None:
253
+ g = torch.detach(g)
254
+ x = x + self.cond(g)
255
+ x = self.conv_1(x * x_mask)
256
+ x = torch.relu(x)
257
+ x = self.norm_1(x)
258
+ x = self.drop(x)
259
+ x = self.conv_2(x * x_mask)
260
+ x = torch.relu(x)
261
+ x = self.norm_2(x)
262
+ x = self.drop(x)
263
+ x = self.proj(x * x_mask)
264
+ return x * x_mask
265
+
266
+
267
+ class Bottleneck(nn.Sequential):
268
+ def __init__(self, in_dim, hidden_dim):
269
+ c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
270
+ c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
271
+ super().__init__(*[c_fc1, c_fc2])
272
+
273
+
274
+ class Block(nn.Module):
275
+ def __init__(self, in_dim, hidden_dim) -> None:
276
+ super().__init__()
277
+ self.norm = nn.LayerNorm(in_dim)
278
+ self.mlp = MLP(in_dim, hidden_dim)
279
+
280
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
281
+ x = x + self.mlp(self.norm(x))
282
+ return x
283
+
284
+
285
+ class MLP(nn.Module):
286
+ def __init__(self, in_dim, hidden_dim):
287
+ super().__init__()
288
+ self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
289
+ self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
290
+ self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
291
+
292
+ def forward(self, x: torch.Tensor):
293
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
294
+ x = self.c_proj(x)
295
+ return x
296
+
297
+
298
+ class TextEncoder(nn.Module):
299
+ def __init__(
300
+ self,
301
+ n_vocab,
302
+ out_channels,
303
+ hidden_channels,
304
+ filter_channels,
305
+ n_heads,
306
+ n_layers,
307
+ kernel_size,
308
+ p_dropout,
309
+ n_speakers,
310
+ gin_channels=0,
311
+ ):
312
+ super().__init__()
313
+ self.n_vocab = n_vocab
314
+ self.out_channels = out_channels
315
+ self.hidden_channels = hidden_channels
316
+ self.filter_channels = filter_channels
317
+ self.n_heads = n_heads
318
+ self.n_layers = n_layers
319
+ self.kernel_size = kernel_size
320
+ self.p_dropout = p_dropout
321
+ self.gin_channels = gin_channels
322
+ self.emb = nn.Embedding(len(symbols), hidden_channels)
323
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
324
+ self.tone_emb = nn.Embedding(num_tones, hidden_channels)
325
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
326
+ self.language_emb = nn.Embedding(num_languages, hidden_channels)
327
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
328
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
329
+ self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
330
+ self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
331
+ # self.emo_proj = nn.Linear(1024, 1024)
332
+ # self.emo_quantizer = nn.ModuleList()
333
+ # for i in range(0, n_speakers):
334
+ # self.emo_quantizer.append(
335
+ # VectorQuantize(
336
+ # dim=1024,
337
+ # codebook_size=10,
338
+ # decay=0.8,
339
+ # commitment_weight=1.0,
340
+ # learnable_codebook=True,
341
+ # ema_update=False,
342
+ # )
343
+ # )
344
+ # self.emo_q_proj = nn.Linear(1024, hidden_channels)
345
+ self.n_speakers = n_speakers
346
+ self.emo_proj = nn.Linear(512, hidden_channels)
347
+
348
+ self.encoder = attentions_onnx.Encoder(
349
+ hidden_channels,
350
+ filter_channels,
351
+ n_heads,
352
+ n_layers,
353
+ kernel_size,
354
+ p_dropout,
355
+ gin_channels=self.gin_channels,
356
+ )
357
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
358
+
359
+ def forward(
360
+ self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g=None
361
+ ):
362
+ x_mask = torch.ones_like(x).unsqueeze(0)
363
+ bert_emb = self.bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
364
+ ja_bert_emb = self.ja_bert_proj(ja_bert.transpose(0, 1).unsqueeze(0)).transpose(
365
+ 1, 2
366
+ )
367
+ en_bert_emb = self.en_bert_proj(en_bert.transpose(0, 1).unsqueeze(0)).transpose(
368
+ 1, 2
369
+ )
370
+
371
+ x = (
372
+ self.emb(x)
373
+ + self.tone_emb(tone)
374
+ + self.language_emb(language)
375
+ + bert_emb
376
+ + ja_bert_emb
377
+ + en_bert_emb
378
+ + self.emo_proj(emo)
379
+ ) * math.sqrt(
380
+ self.hidden_channels
381
+ ) # [b, t, h]
382
+ x = torch.transpose(x, 1, -1) # [b, h, t]
383
+ x_mask = x_mask.to(x.dtype)
384
+
385
+ x = self.encoder(x * x_mask, x_mask, g=g)
386
+ stats = self.proj(x) * x_mask
387
+
388
+ m, logs = torch.split(stats, self.out_channels, dim=1)
389
+ return x, m, logs, x_mask
390
+
391
+
392
+ class ResidualCouplingBlock(nn.Module):
393
+ def __init__(
394
+ self,
395
+ channels,
396
+ hidden_channels,
397
+ kernel_size,
398
+ dilation_rate,
399
+ n_layers,
400
+ n_flows=4,
401
+ gin_channels=0,
402
+ ):
403
+ super().__init__()
404
+ self.channels = channels
405
+ self.hidden_channels = hidden_channels
406
+ self.kernel_size = kernel_size
407
+ self.dilation_rate = dilation_rate
408
+ self.n_layers = n_layers
409
+ self.n_flows = n_flows
410
+ self.gin_channels = gin_channels
411
+
412
+ self.flows = nn.ModuleList()
413
+ for i in range(n_flows):
414
+ self.flows.append(
415
+ modules.ResidualCouplingLayer(
416
+ channels,
417
+ hidden_channels,
418
+ kernel_size,
419
+ dilation_rate,
420
+ n_layers,
421
+ gin_channels=gin_channels,
422
+ mean_only=True,
423
+ )
424
+ )
425
+ self.flows.append(modules.Flip())
426
+
427
+ def forward(self, x, x_mask, g=None, reverse=True):
428
+ if not reverse:
429
+ for flow in self.flows:
430
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
431
+ else:
432
+ for flow in reversed(self.flows):
433
+ x = flow(x, x_mask, g=g, reverse=reverse)
434
+ return x
435
+
436
+
437
+ class PosteriorEncoder(nn.Module):
438
+ def __init__(
439
+ self,
440
+ in_channels,
441
+ out_channels,
442
+ hidden_channels,
443
+ kernel_size,
444
+ dilation_rate,
445
+ n_layers,
446
+ gin_channels=0,
447
+ ):
448
+ super().__init__()
449
+ self.in_channels = in_channels
450
+ self.out_channels = out_channels
451
+ self.hidden_channels = hidden_channels
452
+ self.kernel_size = kernel_size
453
+ self.dilation_rate = dilation_rate
454
+ self.n_layers = n_layers
455
+ self.gin_channels = gin_channels
456
+
457
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
458
+ self.enc = modules.WN(
459
+ hidden_channels,
460
+ kernel_size,
461
+ dilation_rate,
462
+ n_layers,
463
+ gin_channels=gin_channels,
464
+ )
465
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
466
+
467
+ def forward(self, x, x_lengths, g=None):
468
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
469
+ x.dtype
470
+ )
471
+ x = self.pre(x) * x_mask
472
+ x = self.enc(x, x_mask, g=g)
473
+ stats = self.proj(x) * x_mask
474
+ m, logs = torch.split(stats, self.out_channels, dim=1)
475
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
476
+ return z, m, logs, x_mask
477
+
478
+
479
+ class Generator(torch.nn.Module):
480
+ def __init__(
481
+ self,
482
+ initial_channel,
483
+ resblock,
484
+ resblock_kernel_sizes,
485
+ resblock_dilation_sizes,
486
+ upsample_rates,
487
+ upsample_initial_channel,
488
+ upsample_kernel_sizes,
489
+ gin_channels=0,
490
+ ):
491
+ super(Generator, self).__init__()
492
+ self.num_kernels = len(resblock_kernel_sizes)
493
+ self.num_upsamples = len(upsample_rates)
494
+ self.conv_pre = Conv1d(
495
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
496
+ )
497
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
498
+
499
+ self.ups = nn.ModuleList()
500
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
501
+ self.ups.append(
502
+ weight_norm(
503
+ ConvTranspose1d(
504
+ upsample_initial_channel // (2**i),
505
+ upsample_initial_channel // (2 ** (i + 1)),
506
+ k,
507
+ u,
508
+ padding=(k - u) // 2,
509
+ )
510
+ )
511
+ )
512
+
513
+ self.resblocks = nn.ModuleList()
514
+ for i in range(len(self.ups)):
515
+ ch = upsample_initial_channel // (2 ** (i + 1))
516
+ for j, (k, d) in enumerate(
517
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
518
+ ):
519
+ self.resblocks.append(resblock(ch, k, d))
520
+
521
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
522
+ self.ups.apply(init_weights)
523
+
524
+ if gin_channels != 0:
525
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
526
+
527
+ def forward(self, x, g=None):
528
+ x = self.conv_pre(x)
529
+ if g is not None:
530
+ x = x + self.cond(g)
531
+
532
+ for i in range(self.num_upsamples):
533
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
534
+ x = self.ups[i](x)
535
+ xs = None
536
+ for j in range(self.num_kernels):
537
+ if xs is None:
538
+ xs = self.resblocks[i * self.num_kernels + j](x)
539
+ else:
540
+ xs += self.resblocks[i * self.num_kernels + j](x)
541
+ x = xs / self.num_kernels
542
+ x = F.leaky_relu(x)
543
+ x = self.conv_post(x)
544
+ x = torch.tanh(x)
545
+
546
+ return x
547
+
548
+ def remove_weight_norm(self):
549
+ print("Removing weight norm...")
550
+ for layer in self.ups:
551
+ remove_weight_norm(layer)
552
+ for layer in self.resblocks:
553
+ layer.remove_weight_norm()
554
+
555
+
556
+ class DiscriminatorP(torch.nn.Module):
557
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
558
+ super(DiscriminatorP, self).__init__()
559
+ self.period = period
560
+ self.use_spectral_norm = use_spectral_norm
561
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
562
+ self.convs = nn.ModuleList(
563
+ [
564
+ norm_f(
565
+ Conv2d(
566
+ 1,
567
+ 32,
568
+ (kernel_size, 1),
569
+ (stride, 1),
570
+ padding=(get_padding(kernel_size, 1), 0),
571
+ )
572
+ ),
573
+ norm_f(
574
+ Conv2d(
575
+ 32,
576
+ 128,
577
+ (kernel_size, 1),
578
+ (stride, 1),
579
+ padding=(get_padding(kernel_size, 1), 0),
580
+ )
581
+ ),
582
+ norm_f(
583
+ Conv2d(
584
+ 128,
585
+ 512,
586
+ (kernel_size, 1),
587
+ (stride, 1),
588
+ padding=(get_padding(kernel_size, 1), 0),
589
+ )
590
+ ),
591
+ norm_f(
592
+ Conv2d(
593
+ 512,
594
+ 1024,
595
+ (kernel_size, 1),
596
+ (stride, 1),
597
+ padding=(get_padding(kernel_size, 1), 0),
598
+ )
599
+ ),
600
+ norm_f(
601
+ Conv2d(
602
+ 1024,
603
+ 1024,
604
+ (kernel_size, 1),
605
+ 1,
606
+ padding=(get_padding(kernel_size, 1), 0),
607
+ )
608
+ ),
609
+ ]
610
+ )
611
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
612
+
613
+ def forward(self, x):
614
+ fmap = []
615
+
616
+ # 1d to 2d
617
+ b, c, t = x.shape
618
+ if t % self.period != 0: # pad first
619
+ n_pad = self.period - (t % self.period)
620
+ x = F.pad(x, (0, n_pad), "reflect")
621
+ t = t + n_pad
622
+ x = x.view(b, c, t // self.period, self.period)
623
+
624
+ for layer in self.convs:
625
+ x = layer(x)
626
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
627
+ fmap.append(x)
628
+ x = self.conv_post(x)
629
+ fmap.append(x)
630
+ x = torch.flatten(x, 1, -1)
631
+
632
+ return x, fmap
633
+
634
+
635
+ class DiscriminatorS(torch.nn.Module):
636
+ def __init__(self, use_spectral_norm=False):
637
+ super(DiscriminatorS, self).__init__()
638
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
639
+ self.convs = nn.ModuleList(
640
+ [
641
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
642
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
643
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
644
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
645
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
646
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
647
+ ]
648
+ )
649
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
650
+
651
+ def forward(self, x):
652
+ fmap = []
653
+
654
+ for layer in self.convs:
655
+ x = layer(x)
656
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
657
+ fmap.append(x)
658
+ x = self.conv_post(x)
659
+ fmap.append(x)
660
+ x = torch.flatten(x, 1, -1)
661
+
662
+ return x, fmap
663
+
664
+
665
+ class MultiPeriodDiscriminator(torch.nn.Module):
666
+ def __init__(self, use_spectral_norm=False):
667
+ super(MultiPeriodDiscriminator, self).__init__()
668
+ periods = [2, 3, 5, 7, 11]
669
+
670
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
671
+ discs = discs + [
672
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
673
+ ]
674
+ self.discriminators = nn.ModuleList(discs)
675
+
676
+ def forward(self, y, y_hat):
677
+ y_d_rs = []
678
+ y_d_gs = []
679
+ fmap_rs = []
680
+ fmap_gs = []
681
+ for i, d in enumerate(self.discriminators):
682
+ y_d_r, fmap_r = d(y)
683
+ y_d_g, fmap_g = d(y_hat)
684
+ y_d_rs.append(y_d_r)
685
+ y_d_gs.append(y_d_g)
686
+ fmap_rs.append(fmap_r)
687
+ fmap_gs.append(fmap_g)
688
+
689
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
690
+
691
+
692
+ class ReferenceEncoder(nn.Module):
693
+ """
694
+ inputs --- [N, Ty/r, n_mels*r] mels
695
+ outputs --- [N, ref_enc_gru_size]
696
+ """
697
+
698
+ def __init__(self, spec_channels, gin_channels=0):
699
+ super().__init__()
700
+ self.spec_channels = spec_channels
701
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
702
+ K = len(ref_enc_filters)
703
+ filters = [1] + ref_enc_filters
704
+ convs = [
705
+ weight_norm(
706
+ nn.Conv2d(
707
+ in_channels=filters[i],
708
+ out_channels=filters[i + 1],
709
+ kernel_size=(3, 3),
710
+ stride=(2, 2),
711
+ padding=(1, 1),
712
+ )
713
+ )
714
+ for i in range(K)
715
+ ]
716
+ self.convs = nn.ModuleList(convs)
717
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
718
+
719
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
720
+ self.gru = nn.GRU(
721
+ input_size=ref_enc_filters[-1] * out_channels,
722
+ hidden_size=256 // 2,
723
+ batch_first=True,
724
+ )
725
+ self.proj = nn.Linear(128, gin_channels)
726
+
727
+ def forward(self, inputs, mask=None):
728
+ N = inputs.size(0)
729
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
730
+ for conv in self.convs:
731
+ out = conv(out)
732
+ # out = wn(out)
733
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
734
+
735
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
736
+ T = out.size(1)
737
+ N = out.size(0)
738
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
739
+
740
+ self.gru.flatten_parameters()
741
+ memory, out = self.gru(out) # out --- [1, N, 128]
742
+
743
+ return self.proj(out.squeeze(0))
744
+
745
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
746
+ for i in range(n_convs):
747
+ L = (L - kernel_size + 2 * pad) // stride + 1
748
+ return L
749
+
750
+
751
+ class SynthesizerTrn(nn.Module):
752
+ """
753
+ Synthesizer for Training
754
+ """
755
+
756
+ def __init__(
757
+ self,
758
+ n_vocab,
759
+ spec_channels,
760
+ segment_size,
761
+ inter_channels,
762
+ hidden_channels,
763
+ filter_channels,
764
+ n_heads,
765
+ n_layers,
766
+ kernel_size,
767
+ p_dropout,
768
+ resblock,
769
+ resblock_kernel_sizes,
770
+ resblock_dilation_sizes,
771
+ upsample_rates,
772
+ upsample_initial_channel,
773
+ upsample_kernel_sizes,
774
+ n_speakers=256,
775
+ gin_channels=256,
776
+ use_sdp=True,
777
+ n_flow_layer=4,
778
+ n_layers_trans_flow=4,
779
+ flow_share_parameter=False,
780
+ use_transformer_flow=True,
781
+ **kwargs,
782
+ ):
783
+ super().__init__()
784
+ self.n_vocab = n_vocab
785
+ self.spec_channels = spec_channels
786
+ self.inter_channels = inter_channels
787
+ self.hidden_channels = hidden_channels
788
+ self.filter_channels = filter_channels
789
+ self.n_heads = n_heads
790
+ self.n_layers = n_layers
791
+ self.kernel_size = kernel_size
792
+ self.p_dropout = p_dropout
793
+ self.resblock = resblock
794
+ self.resblock_kernel_sizes = resblock_kernel_sizes
795
+ self.resblock_dilation_sizes = resblock_dilation_sizes
796
+ self.upsample_rates = upsample_rates
797
+ self.upsample_initial_channel = upsample_initial_channel
798
+ self.upsample_kernel_sizes = upsample_kernel_sizes
799
+ self.segment_size = segment_size
800
+ self.n_speakers = n_speakers
801
+ self.gin_channels = gin_channels
802
+ self.n_layers_trans_flow = n_layers_trans_flow
803
+ self.use_spk_conditioned_encoder = kwargs.get(
804
+ "use_spk_conditioned_encoder", True
805
+ )
806
+ self.use_sdp = use_sdp
807
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
808
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
809
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
810
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
811
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
812
+ self.enc_gin_channels = gin_channels
813
+ self.enc_p = TextEncoder(
814
+ n_vocab,
815
+ inter_channels,
816
+ hidden_channels,
817
+ filter_channels,
818
+ n_heads,
819
+ n_layers,
820
+ kernel_size,
821
+ p_dropout,
822
+ self.n_speakers,
823
+ gin_channels=self.enc_gin_channels,
824
+ )
825
+ self.dec = Generator(
826
+ inter_channels,
827
+ resblock,
828
+ resblock_kernel_sizes,
829
+ resblock_dilation_sizes,
830
+ upsample_rates,
831
+ upsample_initial_channel,
832
+ upsample_kernel_sizes,
833
+ gin_channels=gin_channels,
834
+ )
835
+ self.enc_q = PosteriorEncoder(
836
+ spec_channels,
837
+ inter_channels,
838
+ hidden_channels,
839
+ 5,
840
+ 1,
841
+ 16,
842
+ gin_channels=gin_channels,
843
+ )
844
+ if use_transformer_flow:
845
+ self.flow = TransformerCouplingBlock(
846
+ inter_channels,
847
+ hidden_channels,
848
+ filter_channels,
849
+ n_heads,
850
+ n_layers_trans_flow,
851
+ 5,
852
+ p_dropout,
853
+ n_flow_layer,
854
+ gin_channels=gin_channels,
855
+ share_parameter=flow_share_parameter,
856
+ )
857
+ else:
858
+ self.flow = ResidualCouplingBlock(
859
+ inter_channels,
860
+ hidden_channels,
861
+ 5,
862
+ 1,
863
+ n_flow_layer,
864
+ gin_channels=gin_channels,
865
+ )
866
+ self.sdp = StochasticDurationPredictor(
867
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
868
+ )
869
+ self.dp = DurationPredictor(
870
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
871
+ )
872
+
873
+ if n_speakers >= 1:
874
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
875
+ else:
876
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
877
+
878
+ def export_onnx(
879
+ self,
880
+ path,
881
+ max_len=None,
882
+ sdp_ratio=0,
883
+ y=None,
884
+ ):
885
+ noise_scale = 0.667
886
+ length_scale = 1
887
+ noise_scale_w = 0.8
888
+ x = (
889
+ torch.LongTensor(
890
+ [
891
+ 0,
892
+ 97,
893
+ 0,
894
+ 8,
895
+ 0,
896
+ 78,
897
+ 0,
898
+ 8,
899
+ 0,
900
+ 76,
901
+ 0,
902
+ 37,
903
+ 0,
904
+ 40,
905
+ 0,
906
+ 97,
907
+ 0,
908
+ 8,
909
+ 0,
910
+ 23,
911
+ 0,
912
+ 8,
913
+ 0,
914
+ 74,
915
+ 0,
916
+ 26,
917
+ 0,
918
+ 104,
919
+ 0,
920
+ ]
921
+ )
922
+ .unsqueeze(0)
923
+ .cpu()
924
+ )
925
+ tone = torch.zeros_like(x).cpu()
926
+ language = torch.zeros_like(x).cpu()
927
+ x_lengths = torch.LongTensor([x.shape[1]]).cpu()
928
+ sid = torch.LongTensor([0]).cpu()
929
+ bert = torch.randn(size=(x.shape[1], 1024)).cpu()
930
+ ja_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
931
+ en_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
932
+
933
+ if self.n_speakers > 0:
934
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
935
+ torch.onnx.export(
936
+ self.emb_g,
937
+ (sid),
938
+ f"onnx/{path}/{path}_emb.onnx",
939
+ input_names=["sid"],
940
+ output_names=["g"],
941
+ verbose=True,
942
+ )
943
+ else:
944
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
945
+
946
+ emo = torch.randn(512, 1)
947
+
948
+ torch.onnx.export(
949
+ self.enc_p,
950
+ (x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g),
951
+ f"onnx/{path}/{path}_enc_p.onnx",
952
+ input_names=[
953
+ "x",
954
+ "x_lengths",
955
+ "t",
956
+ "language",
957
+ "bert_0",
958
+ "bert_1",
959
+ "bert_2",
960
+ "emo",
961
+ "g",
962
+ ],
963
+ output_names=["xout", "m_p", "logs_p", "x_mask"],
964
+ dynamic_axes={
965
+ "x": [0, 1],
966
+ "t": [0, 1],
967
+ "language": [0, 1],
968
+ "bert_0": [0],
969
+ "bert_1": [0],
970
+ "bert_2": [0],
971
+ "xout": [0, 2],
972
+ "m_p": [0, 2],
973
+ "logs_p": [0, 2],
974
+ "x_mask": [0, 2],
975
+ },
976
+ verbose=True,
977
+ opset_version=16,
978
+ )
979
+
980
+ x, m_p, logs_p, x_mask = self.enc_p(
981
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, g
982
+ )
983
+
984
+ zinput = (
985
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
986
+ * noise_scale_w
987
+ )
988
+ torch.onnx.export(
989
+ self.sdp,
990
+ (x, x_mask, zinput, g),
991
+ f"onnx/{path}/{path}_sdp.onnx",
992
+ input_names=["x", "x_mask", "zin", "g"],
993
+ output_names=["logw"],
994
+ dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
995
+ verbose=True,
996
+ )
997
+ torch.onnx.export(
998
+ self.dp,
999
+ (x, x_mask, g),
1000
+ f"onnx/{path}/{path}_dp.onnx",
1001
+ input_names=["x", "x_mask", "g"],
1002
+ output_names=["logw"],
1003
+ dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
1004
+ verbose=True,
1005
+ )
1006
+ logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
1007
+ x, x_mask, g=g
1008
+ ) * (1 - sdp_ratio)
1009
+ w = torch.exp(logw) * x_mask * length_scale
1010
+ w_ceil = torch.ceil(w)
1011
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1012
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1013
+ x_mask.dtype
1014
+ )
1015
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1016
+ attn = commons.generate_path(w_ceil, attn_mask)
1017
+
1018
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1019
+ 1, 2
1020
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1021
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1022
+ 1, 2
1023
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1024
+
1025
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1026
+ torch.onnx.export(
1027
+ self.flow,
1028
+ (z_p, y_mask, g),
1029
+ f"onnx/{path}/{path}_flow.onnx",
1030
+ input_names=["z_p", "y_mask", "g"],
1031
+ output_names=["z"],
1032
+ dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
1033
+ verbose=True,
1034
+ )
1035
+
1036
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1037
+ z_in = (z * y_mask)[:, :, :max_len]
1038
+
1039
+ torch.onnx.export(
1040
+ self.dec,
1041
+ (z_in, g),
1042
+ f"onnx/{path}/{path}_dec.onnx",
1043
+ input_names=["z_in", "g"],
1044
+ output_names=["o"],
1045
+ dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
1046
+ verbose=True,
1047
+ )
1048
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
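Note: export_onnx above splits the synthesizer into separate graphs (emb, enc_p, sdp, dp, flow, dec) instead of one monolithic model. Below is a minimal, assumption-laden sketch of how those graphs could be chained back together with onnxruntime; the file names follow the f"onnx/{path}/{path}_*.onnx" pattern used above, while the export name, the dummy shapes and the random inputs are illustrative placeholders (a real caller would feed phoneme ids, tones and BERT features from the text front end).

import numpy as np
import onnxruntime as ort

path = "ExampleModel"  # assumed export name; matches the `path` argument of export_onnx
emb = ort.InferenceSession(f"onnx/{path}/{path}_emb.onnx")
enc_p = ort.InferenceSession(f"onnx/{path}/{path}_enc_p.onnx")
sdp = ort.InferenceSession(f"onnx/{path}/{path}_sdp.onnx")
dp = ort.InferenceSession(f"onnx/{path}/{path}_dp.onnx")

T = 29  # number of phoneme ids; dummy inputs mirror the shapes used in export_onnx above
x = np.zeros((1, T), dtype=np.int64)
tone = np.zeros_like(x)
language = np.zeros_like(x)
x_lengths = np.array([T], dtype=np.int64)
bert = np.random.randn(T, 1024).astype(np.float32)
emo = np.random.randn(512, 1).astype(np.float32)

(g,) = emb.run(["g"], {"sid": np.array([0], dtype=np.int64)})
g = np.expand_dims(g, -1)  # emb_g returns [b, gin]; the unsqueeze(-1) happens outside the graph

xout, m_p, logs_p, x_mask = enc_p.run(
    None,
    {"x": x, "x_lengths": x_lengths, "t": tone, "language": language,
     "bert_0": bert, "bert_1": bert, "bert_2": bert, "emo": emo, "g": g},
)

sdp_ratio, length_scale = 0.0, 1.0
zin = (np.random.randn(1, 2, xout.shape[2]) * 0.8).astype(np.float32)
(logw_s,) = sdp.run(None, {"x": xout, "x_mask": x_mask, "zin": zin, "g": g})
(logw_d,) = dp.run(None, {"x": xout, "x_mask": x_mask, "g": g})
logw = logw_s * sdp_ratio + logw_d * (1 - sdp_ratio)
w_ceil = np.ceil(np.exp(logw) * x_mask * length_scale)  # predicted frames per phoneme

# Expanding m_p/logs_p along the alignment path, then running the flow and the decoder
# graphs, follows the same order as the torch code in export_onnx above.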
onnx_modules/V220_novq_dev/text/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .symbols import *
onnx_modules/V220_novq_dev/text/symbols.py ADDED
@@ -0,0 +1,187 @@
1
+ punctuation = ["!", "?", "…", ",", ".", "'", "-"]
2
+ pu_symbols = punctuation + ["SP", "UNK"]
3
+ pad = "_"
4
+
5
+ # chinese
6
+ zh_symbols = [
7
+ "E",
8
+ "En",
9
+ "a",
10
+ "ai",
11
+ "an",
12
+ "ang",
13
+ "ao",
14
+ "b",
15
+ "c",
16
+ "ch",
17
+ "d",
18
+ "e",
19
+ "ei",
20
+ "en",
21
+ "eng",
22
+ "er",
23
+ "f",
24
+ "g",
25
+ "h",
26
+ "i",
27
+ "i0",
28
+ "ia",
29
+ "ian",
30
+ "iang",
31
+ "iao",
32
+ "ie",
33
+ "in",
34
+ "ing",
35
+ "iong",
36
+ "ir",
37
+ "iu",
38
+ "j",
39
+ "k",
40
+ "l",
41
+ "m",
42
+ "n",
43
+ "o",
44
+ "ong",
45
+ "ou",
46
+ "p",
47
+ "q",
48
+ "r",
49
+ "s",
50
+ "sh",
51
+ "t",
52
+ "u",
53
+ "ua",
54
+ "uai",
55
+ "uan",
56
+ "uang",
57
+ "ui",
58
+ "un",
59
+ "uo",
60
+ "v",
61
+ "van",
62
+ "ve",
63
+ "vn",
64
+ "w",
65
+ "x",
66
+ "y",
67
+ "z",
68
+ "zh",
69
+ "AA",
70
+ "EE",
71
+ "OO",
72
+ ]
73
+ num_zh_tones = 6
74
+
75
+ # japanese
76
+ ja_symbols = [
77
+ "N",
78
+ "a",
79
+ "a:",
80
+ "b",
81
+ "by",
82
+ "ch",
83
+ "d",
84
+ "dy",
85
+ "e",
86
+ "e:",
87
+ "f",
88
+ "g",
89
+ "gy",
90
+ "h",
91
+ "hy",
92
+ "i",
93
+ "i:",
94
+ "j",
95
+ "k",
96
+ "ky",
97
+ "m",
98
+ "my",
99
+ "n",
100
+ "ny",
101
+ "o",
102
+ "o:",
103
+ "p",
104
+ "py",
105
+ "q",
106
+ "r",
107
+ "ry",
108
+ "s",
109
+ "sh",
110
+ "t",
111
+ "ts",
112
+ "ty",
113
+ "u",
114
+ "u:",
115
+ "w",
116
+ "y",
117
+ "z",
118
+ "zy",
119
+ ]
120
+ num_ja_tones = 2
121
+
122
+ # English
123
+ en_symbols = [
124
+ "aa",
125
+ "ae",
126
+ "ah",
127
+ "ao",
128
+ "aw",
129
+ "ay",
130
+ "b",
131
+ "ch",
132
+ "d",
133
+ "dh",
134
+ "eh",
135
+ "er",
136
+ "ey",
137
+ "f",
138
+ "g",
139
+ "hh",
140
+ "ih",
141
+ "iy",
142
+ "jh",
143
+ "k",
144
+ "l",
145
+ "m",
146
+ "n",
147
+ "ng",
148
+ "ow",
149
+ "oy",
150
+ "p",
151
+ "r",
152
+ "s",
153
+ "sh",
154
+ "t",
155
+ "th",
156
+ "uh",
157
+ "uw",
158
+ "V",
159
+ "w",
160
+ "y",
161
+ "z",
162
+ "zh",
163
+ ]
164
+ num_en_tones = 4
165
+
166
+ # combine all symbols
167
+ normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
168
+ symbols = [pad] + normal_symbols + pu_symbols
169
+ sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
170
+
171
+ # combine all tones
172
+ num_tones = num_zh_tones + num_ja_tones + num_en_tones
173
+
174
+ # language maps
175
+ language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
176
+ num_languages = len(language_id_map.keys())
177
+
178
+ language_tone_start_map = {
179
+ "ZH": 0,
180
+ "JP": num_zh_tones,
181
+ "EN": num_zh_tones + num_ja_tones,
182
+ }
183
+
184
+ if __name__ == "__main__":
185
+ a = set(zh_symbols)
186
+ b = set(en_symbols)
187
+ print(sorted(a & b))
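For reference, the tables above are what the text front end indexes into when it turns cleaned phonemes into model inputs. The helper below is an illustrative sketch only (the function name is hypothetical and not defined in this file): each phone is mapped to its symbol id, tones are shifted by the language's tone offset, and every position carries the language id.

def phones_to_ids(phones, tones, language):
    # hypothetical helper: map one language's phones/tones to the id spaces defined above
    symbol_ids = [symbols.index(p) for p in phones]
    tone_ids = [t + language_tone_start_map[language] for t in tones]
    lang_ids = [language_id_map[language]] * len(phones)
    return symbol_ids, tone_ids, lang_ids

# e.g. phones_to_ids(["n", "i", "h", "ao"], [2, 2, 3, 3], "ZH")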
onnx_modules/V230/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .text.symbols import symbols
2
+ from .models_onnx import SynthesizerTrn
3
+
4
+ __all__ = ["symbols", "SynthesizerTrn"]
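A rough sketch of how a version package like this one is typically consumed by an export driver is shown below. It is an assumption-laden illustration, not the repository's export_onnx.py: the config and checkpoint paths are placeholders, and the VITS-style helpers utils.get_hparams_from_file / utils.load_checkpoint as well as the constructor and export_onnx signatures are assumed to behave as in the other version packages.

import utils  # assumed VITS-style helpers available in this repository
from onnx_modules.V230 import SynthesizerTrn, symbols

hps = utils.get_hparams_from_file("path/to/config.json")  # placeholder path
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model,
)
utils.load_checkpoint("path/to/G_latest.pth", net_g, None, skip_optimizer=True)
net_g.cpu().eval()
net_g.export_onnx("ExportName")  # writes the per-component graphs under onnx/ExportName/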
onnx_modules/V230/attentions_onnx.py ADDED
@@ -0,0 +1,378 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class LayerNorm(nn.Module):
13
+ def __init__(self, channels, eps=1e-5):
14
+ super().__init__()
15
+ self.channels = channels
16
+ self.eps = eps
17
+
18
+ self.gamma = nn.Parameter(torch.ones(channels))
19
+ self.beta = nn.Parameter(torch.zeros(channels))
20
+
21
+ def forward(self, x):
22
+ x = x.transpose(1, -1)
23
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
24
+ return x.transpose(1, -1)
25
+
26
+
27
+ @torch.jit.script
28
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
29
+ n_channels_int = n_channels[0]
30
+ in_act = input_a + input_b
31
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
32
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
33
+ acts = t_act * s_act
34
+ return acts
35
+
36
+
37
+ class Encoder(nn.Module):
38
+ def __init__(
39
+ self,
40
+ hidden_channels,
41
+ filter_channels,
42
+ n_heads,
43
+ n_layers,
44
+ kernel_size=1,
45
+ p_dropout=0.0,
46
+ window_size=4,
47
+ isflow=True,
48
+ **kwargs
49
+ ):
50
+ super().__init__()
51
+ self.hidden_channels = hidden_channels
52
+ self.filter_channels = filter_channels
53
+ self.n_heads = n_heads
54
+ self.n_layers = n_layers
55
+ self.kernel_size = kernel_size
56
+ self.p_dropout = p_dropout
57
+ self.window_size = window_size
58
+ # if isflow:
59
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
60
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
61
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
62
+ # self.gin_channels = 256
63
+ self.cond_layer_idx = self.n_layers
64
+ if "gin_channels" in kwargs:
65
+ self.gin_channels = kwargs["gin_channels"]
66
+ if self.gin_channels != 0:
67
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
68
+ # vits2 says 3rd block, so idx is 2 by default
69
+ self.cond_layer_idx = (
70
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
71
+ )
72
+ logger.debug("gin_channels: %s, cond_layer_idx: %s", self.gin_channels, self.cond_layer_idx)
73
+ assert (
74
+ self.cond_layer_idx < self.n_layers
75
+ ), "cond_layer_idx should be less than n_layers"
76
+ self.drop = nn.Dropout(p_dropout)
77
+ self.attn_layers = nn.ModuleList()
78
+ self.norm_layers_1 = nn.ModuleList()
79
+ self.ffn_layers = nn.ModuleList()
80
+ self.norm_layers_2 = nn.ModuleList()
81
+ for i in range(self.n_layers):
82
+ self.attn_layers.append(
83
+ MultiHeadAttention(
84
+ hidden_channels,
85
+ hidden_channels,
86
+ n_heads,
87
+ p_dropout=p_dropout,
88
+ window_size=window_size,
89
+ )
90
+ )
91
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
92
+ self.ffn_layers.append(
93
+ FFN(
94
+ hidden_channels,
95
+ hidden_channels,
96
+ filter_channels,
97
+ kernel_size,
98
+ p_dropout=p_dropout,
99
+ )
100
+ )
101
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
102
+
103
+ def forward(self, x, x_mask, g=None):
104
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
105
+ x = x * x_mask
106
+ for i in range(self.n_layers):
107
+ if i == self.cond_layer_idx and g is not None:
108
+ g = self.spk_emb_linear(g.transpose(1, 2))
109
+ g = g.transpose(1, 2)
110
+ x = x + g
111
+ x = x * x_mask
112
+ y = self.attn_layers[i](x, x, attn_mask)
113
+ y = self.drop(y)
114
+ x = self.norm_layers_1[i](x + y)
115
+
116
+ y = self.ffn_layers[i](x, x_mask)
117
+ y = self.drop(y)
118
+ x = self.norm_layers_2[i](x + y)
119
+ x = x * x_mask
120
+ return x
121
+
122
+
123
+ class MultiHeadAttention(nn.Module):
124
+ def __init__(
125
+ self,
126
+ channels,
127
+ out_channels,
128
+ n_heads,
129
+ p_dropout=0.0,
130
+ window_size=None,
131
+ heads_share=True,
132
+ block_length=None,
133
+ proximal_bias=False,
134
+ proximal_init=False,
135
+ ):
136
+ super().__init__()
137
+ assert channels % n_heads == 0
138
+
139
+ self.channels = channels
140
+ self.out_channels = out_channels
141
+ self.n_heads = n_heads
142
+ self.p_dropout = p_dropout
143
+ self.window_size = window_size
144
+ self.heads_share = heads_share
145
+ self.block_length = block_length
146
+ self.proximal_bias = proximal_bias
147
+ self.proximal_init = proximal_init
148
+ self.attn = None
149
+
150
+ self.k_channels = channels // n_heads
151
+ self.conv_q = nn.Conv1d(channels, channels, 1)
152
+ self.conv_k = nn.Conv1d(channels, channels, 1)
153
+ self.conv_v = nn.Conv1d(channels, channels, 1)
154
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
155
+ self.drop = nn.Dropout(p_dropout)
156
+
157
+ if window_size is not None:
158
+ n_heads_rel = 1 if heads_share else n_heads
159
+ rel_stddev = self.k_channels**-0.5
160
+ self.emb_rel_k = nn.Parameter(
161
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
162
+ * rel_stddev
163
+ )
164
+ self.emb_rel_v = nn.Parameter(
165
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
166
+ * rel_stddev
167
+ )
168
+
169
+ nn.init.xavier_uniform_(self.conv_q.weight)
170
+ nn.init.xavier_uniform_(self.conv_k.weight)
171
+ nn.init.xavier_uniform_(self.conv_v.weight)
172
+ if proximal_init:
173
+ with torch.no_grad():
174
+ self.conv_k.weight.copy_(self.conv_q.weight)
175
+ self.conv_k.bias.copy_(self.conv_q.bias)
176
+
177
+ def forward(self, x, c, attn_mask=None):
178
+ q = self.conv_q(x)
179
+ k = self.conv_k(c)
180
+ v = self.conv_v(c)
181
+
182
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
183
+
184
+ x = self.conv_o(x)
185
+ return x
186
+
187
+ def attention(self, query, key, value, mask=None):
188
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
189
+ b, d, t_s, t_t = (*key.size(), query.size(2))
190
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
191
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
192
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
193
+
194
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
195
+ if self.window_size is not None:
196
+ assert (
197
+ t_s == t_t
198
+ ), "Relative attention is only available for self-attention."
199
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
200
+ rel_logits = self._matmul_with_relative_keys(
201
+ query / math.sqrt(self.k_channels), key_relative_embeddings
202
+ )
203
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
204
+ scores = scores + scores_local
205
+ if self.proximal_bias:
206
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
207
+ scores = scores + self._attention_bias_proximal(t_s).to(
208
+ device=scores.device, dtype=scores.dtype
209
+ )
210
+ if mask is not None:
211
+ scores = scores.masked_fill(mask == 0, -1e4)
212
+ if self.block_length is not None:
213
+ assert (
214
+ t_s == t_t
215
+ ), "Local attention is only available for self-attention."
216
+ block_mask = (
217
+ torch.ones_like(scores)
218
+ .triu(-self.block_length)
219
+ .tril(self.block_length)
220
+ )
221
+ scores = scores.masked_fill(block_mask == 0, -1e4)
222
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
223
+ p_attn = self.drop(p_attn)
224
+ output = torch.matmul(p_attn, value)
225
+ if self.window_size is not None:
226
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
227
+ value_relative_embeddings = self._get_relative_embeddings(
228
+ self.emb_rel_v, t_s
229
+ )
230
+ output = output + self._matmul_with_relative_values(
231
+ relative_weights, value_relative_embeddings
232
+ )
233
+ output = (
234
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
235
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
236
+ return output, p_attn
237
+
238
+ def _matmul_with_relative_values(self, x, y):
239
+ """
240
+ x: [b, h, l, m]
241
+ y: [h or 1, m, d]
242
+ ret: [b, h, l, d]
243
+ """
244
+ ret = torch.matmul(x, y.unsqueeze(0))
245
+ return ret
246
+
247
+ def _matmul_with_relative_keys(self, x, y):
248
+ """
249
+ x: [b, h, l, d]
250
+ y: [h or 1, m, d]
251
+ ret: [b, h, l, m]
252
+ """
253
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
254
+ return ret
255
+
256
+ def _get_relative_embeddings(self, relative_embeddings, length):
257
+ max_relative_position = 2 * self.window_size + 1
258
+ # Pad first before slice to avoid using cond ops.
259
+ pad_length = max(length - (self.window_size + 1), 0)
260
+ slice_start_position = max((self.window_size + 1) - length, 0)
261
+ slice_end_position = slice_start_position + 2 * length - 1
262
+ if pad_length > 0:
263
+ padded_relative_embeddings = F.pad(
264
+ relative_embeddings,
265
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
266
+ )
267
+ else:
268
+ padded_relative_embeddings = relative_embeddings
269
+ used_relative_embeddings = padded_relative_embeddings[
270
+ :, slice_start_position:slice_end_position
271
+ ]
272
+ return used_relative_embeddings
273
+
274
+ def _relative_position_to_absolute_position(self, x):
275
+ """
276
+ x: [b, h, l, 2*l-1]
277
+ ret: [b, h, l, l]
278
+ """
279
+ batch, heads, length, _ = x.size()
280
+ # Concat columns of pad to shift from relative to absolute indexing.
281
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
282
+
283
+ # Concat extra elements so as to add up to shape (len+1, 2*len-1).
284
+ x_flat = x.view([batch, heads, length * 2 * length])
285
+ x_flat = F.pad(
286
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
287
+ )
288
+
289
+ # Reshape and slice out the padded elements.
290
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
291
+ :, :, :length, length - 1 :
292
+ ]
293
+ return x_final
294
+
295
+ def _absolute_position_to_relative_position(self, x):
296
+ """
297
+ x: [b, h, l, l]
298
+ ret: [b, h, l, 2*l-1]
299
+ """
300
+ batch, heads, length, _ = x.size()
301
+ # pad along the column dimension
302
+ x = F.pad(
303
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
304
+ )
305
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
306
+ # add 0's in the beginning that will skew the elements after reshape
307
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
308
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
309
+ return x_final
310
+
311
+ def _attention_bias_proximal(self, length):
312
+ """Bias for self-attention to encourage attention to close positions.
313
+ Args:
314
+ length: an integer scalar.
315
+ Returns:
316
+ a Tensor with shape [1, 1, length, length]
317
+ """
318
+ r = torch.arange(length, dtype=torch.float32)
319
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
320
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
321
+
322
+
323
+ class FFN(nn.Module):
324
+ def __init__(
325
+ self,
326
+ in_channels,
327
+ out_channels,
328
+ filter_channels,
329
+ kernel_size,
330
+ p_dropout=0.0,
331
+ activation=None,
332
+ causal=False,
333
+ ):
334
+ super().__init__()
335
+ self.in_channels = in_channels
336
+ self.out_channels = out_channels
337
+ self.filter_channels = filter_channels
338
+ self.kernel_size = kernel_size
339
+ self.p_dropout = p_dropout
340
+ self.activation = activation
341
+ self.causal = causal
342
+
343
+ if causal:
344
+ self.padding = self._causal_padding
345
+ else:
346
+ self.padding = self._same_padding
347
+
348
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
349
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
350
+ self.drop = nn.Dropout(p_dropout)
351
+
352
+ def forward(self, x, x_mask):
353
+ x = self.conv_1(self.padding(x * x_mask))
354
+ if self.activation == "gelu":
355
+ x = x * torch.sigmoid(1.702 * x)
356
+ else:
357
+ x = torch.relu(x)
358
+ x = self.drop(x)
359
+ x = self.conv_2(self.padding(x * x_mask))
360
+ return x * x_mask
361
+
362
+ def _causal_padding(self, x):
363
+ if self.kernel_size == 1:
364
+ return x
365
+ pad_l = self.kernel_size - 1
366
+ pad_r = 0
367
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
368
+ x = F.pad(x, commons.convert_pad_shape(padding))
369
+ return x
370
+
371
+ def _same_padding(self, x):
372
+ if self.kernel_size == 1:
373
+ return x
374
+ pad_l = (self.kernel_size - 1) // 2
375
+ pad_r = self.kernel_size // 2
376
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
377
+ x = F.pad(x, commons.convert_pad_shape(padding))
378
+ return x
onnx_modules/V230/models_onnx.py ADDED
@@ -0,0 +1,1061 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import modules
8
+ from . import attentions_onnx
9
+
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+
14
+ from commons import init_weights, get_padding
15
+ from .text import symbols, num_tones, num_languages
16
+
17
+
18
+
19
+
20
+ class DurationDiscriminator(nn.Module): # vits2
21
+ def __init__(
22
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
23
+ ):
24
+ super().__init__()
25
+
26
+ self.in_channels = in_channels
27
+ self.filter_channels = filter_channels
28
+ self.kernel_size = kernel_size
29
+ self.p_dropout = p_dropout
30
+ self.gin_channels = gin_channels
31
+
32
+ self.drop = nn.Dropout(p_dropout)
33
+ self.conv_1 = nn.Conv1d(
34
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
35
+ )
36
+ self.norm_1 = modules.LayerNorm(filter_channels)
37
+ self.conv_2 = nn.Conv1d(
38
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
39
+ )
40
+ self.norm_2 = modules.LayerNorm(filter_channels)
41
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
42
+
43
+ self.LSTM = nn.LSTM(
44
+ 2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
45
+ )
46
+
47
+ if gin_channels != 0:
48
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
49
+
50
+ self.output_layer = nn.Sequential(
51
+ nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
52
+ )
53
+
54
+ def forward_probability(self, x, dur):
55
+ dur = self.dur_proj(dur)
56
+ x = torch.cat([x, dur], dim=1)
57
+ x = x.transpose(1, 2)
58
+ x, _ = self.LSTM(x)
59
+ output_prob = self.output_layer(x)
60
+ return output_prob
61
+
62
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
63
+ x = torch.detach(x)
64
+ if g is not None:
65
+ g = torch.detach(g)
66
+ x = x + self.cond(g)
67
+ x = self.conv_1(x * x_mask)
68
+ x = torch.relu(x)
69
+ x = self.norm_1(x)
70
+ x = self.drop(x)
71
+ x = self.conv_2(x * x_mask)
72
+ x = torch.relu(x)
73
+ x = self.norm_2(x)
74
+ x = self.drop(x)
75
+
76
+ output_probs = []
77
+ for dur in [dur_r, dur_hat]:
78
+ output_prob = self.forward_probability(x, dur)
79
+ output_probs.append(output_prob)
80
+
81
+ return output_probs
82
+
83
+
84
+ class TransformerCouplingBlock(nn.Module):
85
+ def __init__(
86
+ self,
87
+ channels,
88
+ hidden_channels,
89
+ filter_channels,
90
+ n_heads,
91
+ n_layers,
92
+ kernel_size,
93
+ p_dropout,
94
+ n_flows=4,
95
+ gin_channels=0,
96
+ share_parameter=False,
97
+ ):
98
+ super().__init__()
99
+ self.channels = channels
100
+ self.hidden_channels = hidden_channels
101
+ self.kernel_size = kernel_size
102
+ self.n_layers = n_layers
103
+ self.n_flows = n_flows
104
+ self.gin_channels = gin_channels
105
+
106
+ self.flows = nn.ModuleList()
107
+
108
+ self.wn = (
109
+ attentions_onnx.FFT(
110
+ hidden_channels,
111
+ filter_channels,
112
+ n_heads,
113
+ n_layers,
114
+ kernel_size,
115
+ p_dropout,
116
+ isflow=True,
117
+ gin_channels=self.gin_channels,
118
+ )
119
+ if share_parameter
120
+ else None
121
+ )
122
+
123
+ for i in range(n_flows):
124
+ self.flows.append(
125
+ modules.TransformerCouplingLayer(
126
+ channels,
127
+ hidden_channels,
128
+ kernel_size,
129
+ n_layers,
130
+ n_heads,
131
+ p_dropout,
132
+ filter_channels,
133
+ mean_only=True,
134
+ wn_sharing_parameter=self.wn,
135
+ gin_channels=self.gin_channels,
136
+ )
137
+ )
138
+ self.flows.append(modules.Flip())
139
+
140
+ def forward(self, x, x_mask, g=None, reverse=True):
141
+ if not reverse:
142
+ for flow in self.flows:
143
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
144
+ else:
145
+ for flow in reversed(self.flows):
146
+ x = flow(x, x_mask, g=g, reverse=reverse)
147
+ return x
148
+
149
+
150
+ class StochasticDurationPredictor(nn.Module):
151
+ def __init__(
152
+ self,
153
+ in_channels,
154
+ filter_channels,
155
+ kernel_size,
156
+ p_dropout,
157
+ n_flows=4,
158
+ gin_channels=0,
159
+ ):
160
+ super().__init__()
161
+ filter_channels = in_channels  # TODO: this override should be removed in a future version.
162
+ self.in_channels = in_channels
163
+ self.filter_channels = filter_channels
164
+ self.kernel_size = kernel_size
165
+ self.p_dropout = p_dropout
166
+ self.n_flows = n_flows
167
+ self.gin_channels = gin_channels
168
+
169
+ self.log_flow = modules.Log()
170
+ self.flows = nn.ModuleList()
171
+ self.flows.append(modules.ElementwiseAffine(2))
172
+ for i in range(n_flows):
173
+ self.flows.append(
174
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
175
+ )
176
+ self.flows.append(modules.Flip())
177
+
178
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
179
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
180
+ self.post_convs = modules.DDSConv(
181
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
182
+ )
183
+ self.post_flows = nn.ModuleList()
184
+ self.post_flows.append(modules.ElementwiseAffine(2))
185
+ for i in range(4):
186
+ self.post_flows.append(
187
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
188
+ )
189
+ self.post_flows.append(modules.Flip())
190
+
191
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
192
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
193
+ self.convs = modules.DDSConv(
194
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
195
+ )
196
+ if gin_channels != 0:
197
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
198
+
199
+ def forward(self, x, x_mask, z, g=None):
200
+ x = torch.detach(x)
201
+ x = self.pre(x)
202
+ if g is not None:
203
+ g = torch.detach(g)
204
+ x = x + self.cond(g)
205
+ x = self.convs(x, x_mask)
206
+ x = self.proj(x) * x_mask
207
+
208
+ flows = list(reversed(self.flows))
209
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
210
+ for flow in flows:
211
+ z = flow(z, x_mask, g=x, reverse=True)
212
+ z0, z1 = torch.split(z, [1, 1], 1)
213
+ logw = z0
214
+ return logw
215
+
216
+
217
+ class DurationPredictor(nn.Module):
218
+ def __init__(
219
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
220
+ ):
221
+ super().__init__()
222
+
223
+ self.in_channels = in_channels
224
+ self.filter_channels = filter_channels
225
+ self.kernel_size = kernel_size
226
+ self.p_dropout = p_dropout
227
+ self.gin_channels = gin_channels
228
+
229
+ self.drop = nn.Dropout(p_dropout)
230
+ self.conv_1 = nn.Conv1d(
231
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
232
+ )
233
+ self.norm_1 = modules.LayerNorm(filter_channels)
234
+ self.conv_2 = nn.Conv1d(
235
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
236
+ )
237
+ self.norm_2 = modules.LayerNorm(filter_channels)
238
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
239
+
240
+ if gin_channels != 0:
241
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
242
+
243
+ def forward(self, x, x_mask, g=None):
244
+ x = torch.detach(x)
245
+ if g is not None:
246
+ g = torch.detach(g)
247
+ x = x + self.cond(g)
248
+ x = self.conv_1(x * x_mask)
249
+ x = torch.relu(x)
250
+ x = self.norm_1(x)
251
+ x = self.drop(x)
252
+ x = self.conv_2(x * x_mask)
253
+ x = torch.relu(x)
254
+ x = self.norm_2(x)
255
+ x = self.drop(x)
256
+ x = self.proj(x * x_mask)
257
+ return x * x_mask
258
+
259
+
260
+ class Bottleneck(nn.Sequential):
261
+ def __init__(self, in_dim, hidden_dim):
262
+ c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
263
+ c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
264
+ super().__init__(*[c_fc1, c_fc2])
265
+
266
+
267
+ class Block(nn.Module):
268
+ def __init__(self, in_dim, hidden_dim) -> None:
269
+ super().__init__()
270
+ self.norm = nn.LayerNorm(in_dim)
271
+ self.mlp = MLP(in_dim, hidden_dim)
272
+
273
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
274
+ x = x + self.mlp(self.norm(x))
275
+ return x
276
+
277
+
278
+ class MLP(nn.Module):
279
+ def __init__(self, in_dim, hidden_dim):
280
+ super().__init__()
281
+ self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
282
+ self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
283
+ self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
284
+
285
+ def forward(self, x: torch.Tensor):
286
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
287
+ x = self.c_proj(x)
288
+ return x
289
+
290
+
291
+class TextEncoder(nn.Module):
+    def __init__(
+        self, n_vocab, out_channels, hidden_channels, filter_channels,
+        n_heads, n_layers, kernel_size, p_dropout, gin_channels=0,
+    ):
+        super().__init__()
+        self.n_vocab = n_vocab
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.gin_channels = gin_channels
+        self.emb = nn.Embedding(len(symbols), hidden_channels)
+        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
+        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
+        nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
+        self.language_emb = nn.Embedding(num_languages, hidden_channels)
+        nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
+        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
+        self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
+        self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
+
+        self.encoder = attentions_onnx.Encoder(
+            hidden_channels, filter_channels, n_heads, n_layers,
+            kernel_size, p_dropout, gin_channels=self.gin_channels,
+        )
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
+        x_mask = torch.ones_like(x).unsqueeze(0)
+        bert_emb = self.bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
+        ja_bert_emb = self.ja_bert_proj(ja_bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
+        en_bert_emb = self.en_bert_proj(en_bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
+        x = (
+            self.emb(x)
+            + self.tone_emb(tone)
+            + self.language_emb(language)
+            + bert_emb
+            + ja_bert_emb
+            + en_bert_emb
+        ) * math.sqrt(
+            self.hidden_channels
+        )  # [b, t, h]
+        x = torch.transpose(x, 1, -1)  # [b, h, t]
+        x_mask = x_mask.to(x.dtype)
+
+        x = self.encoder(x * x_mask, x_mask, g=g)
+        stats = self.proj(x) * x_mask
+
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        return x, m, logs, x_mask
+
+
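For orientation: the encoder input is the sum of a symbol embedding, a tone embedding, a language embedding and three 1024-dim BERT feature streams projected to hidden_channels, all scaled by sqrt(hidden_channels). A toy shape check (hidden size 8, sequence length 6 and a single BERT stream are arbitrary simplifications):

import math
import torch
import torch.nn as nn

hidden, T = 8, 6
emb = nn.Embedding(100, hidden)        # stands in for the symbol/tone/language tables
bert_proj = nn.Conv1d(1024, hidden, 1)

ids = torch.randint(0, 100, (1, T))                                      # [b, t]
bert = torch.randn(T, 1024)                                              # [t, 1024]
bert_emb = bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)  # [1, t, hidden]
x = (emb(ids) + bert_emb) * math.sqrt(hidden)                            # [1, t, hidden]
print(x.shape)  # torch.Size([1, 6, 8])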
+class ResidualCouplingBlock(nn.Module):
+    def __init__(
+        self, channels, hidden_channels, kernel_size, dilation_rate,
+        n_layers, n_flows=4, gin_channels=0,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.n_flows = n_flows
+        self.gin_channels = gin_channels
+
+        self.flows = nn.ModuleList()
+        for i in range(n_flows):
+            self.flows.append(
+                modules.ResidualCouplingLayer(
+                    channels, hidden_channels, kernel_size, dilation_rate,
+                    n_layers, gin_channels=gin_channels, mean_only=True,
+                )
+            )
+            self.flows.append(modules.Flip())
+
+    def forward(self, x, x_mask, g=None, reverse=True):
+        if not reverse:
+            for flow in self.flows:
+                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+        else:
+            for flow in reversed(self.flows):
+                x = flow(x, x_mask, g=g, reverse=reverse)
+        return x
+
+
+class PosteriorEncoder(nn.Module):
+    def __init__(
+        self, in_channels, out_channels, hidden_channels, kernel_size,
+        dilation_rate, n_layers, gin_channels=0,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.gin_channels = gin_channels
+
+        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.enc = modules.WN(
+            hidden_channels, kernel_size, dilation_rate, n_layers,
+            gin_channels=gin_channels,
+        )
+        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
+
+    def forward(self, x, x_lengths, g=None):
+        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
+            x.dtype
+        )
+        x = self.pre(x) * x_mask
+        x = self.enc(x, x_mask, g=g)
+        stats = self.proj(x) * x_mask
+        m, logs = torch.split(stats, self.out_channels, dim=1)
+        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+        return z, m, logs, x_mask
+
+
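Both encoders end with the same pattern: a 1x1 convolution produces 2 * out_channels statistics that are split into a mean m and log-std logs, and a latent is sampled as z = m + eps * exp(logs). A standalone sketch of that reparameterisation with toy sizes:

import torch

out_channels, T = 4, 10
stats = torch.randn(1, out_channels * 2, T)    # what proj(x) * x_mask would return
m, logs = torch.split(stats, out_channels, dim=1)
z = m + torch.randn_like(m) * torch.exp(logs)  # reparameterised sample
print(m.shape, logs.shape, z.shape)            # all torch.Size([1, 4, 10])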
+class Generator(torch.nn.Module):
+    def __init__(
+        self, initial_channel, resblock, resblock_kernel_sizes,
+        resblock_dilation_sizes, upsample_rates, upsample_initial_channel,
+        upsample_kernel_sizes, gin_channels=0,
+    ):
+        super(Generator, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = Conv1d(
+            initial_channel, upsample_initial_channel, 7, 1, padding=3
+        )
+        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k, u, padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(resblock_kernel_sizes, resblock_dilation_sizes)
+            ):
+                self.resblocks.append(resblock(ch, k, d))
+
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.ups.apply(init_weights)
+
+        if gin_channels != 0:
+            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+    def forward(self, x, g=None):
+        x = self.conv_pre(x)
+        if g is not None:
+            x = x + self.cond(g)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print("Removing weight norm...")
+        for layer in self.ups:
+            remove_weight_norm(layer)
+        for layer in self.resblocks:
+            layer.remove_weight_norm()
+
+
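Each ConvTranspose1d stage stretches time by its rate, so the output waveform length is the number of latent frames times the product of upsample_rates, and that product should equal the STFT hop length. A quick check, assuming the common [8, 8, 2, 2, 2] rates as an example (take the real values from the model config, not from this file):

import math

upsample_rates = [8, 8, 2, 2, 2]        # example values only
hop_length = math.prod(upsample_rates)  # 512 samples per latent frame
frames = 200                            # e.g. z_in.shape[2]
print(hop_length, frames * hop_length)  # 512 102400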
+class DiscriminatorP(torch.nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
+                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
+            ]
+        )
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for layer in self.convs:
+            x = layer(x)
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
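The period discriminator folds the 1-D waveform into a 2-D grid so each column holds samples spaced `period` apart; the signal is reflect-padded first so its length divides evenly. A standalone sketch of just that reshape:

import torch
import torch.nn.functional as F

period = 3
x = torch.randn(1, 1, 10)                 # [b, c, t], t not divisible by period
n_pad = period - (x.shape[-1] % period)   # 2 samples of reflect padding
x = F.pad(x, (0, n_pad), "reflect")
x2d = x.view(1, 1, x.shape[-1] // period, period)
print(x2d.shape)                          # torch.Size([1, 1, 4, 3])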
+class DiscriminatorS(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+
+        for layer in self.convs:
+            x = layer(x)
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(MultiPeriodDiscriminator, self).__init__()
+        periods = [2, 3, 5, 7, 11]
+
+        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+        discs = discs + [
+            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
+        ]
+        self.discriminators = nn.ModuleList(discs)
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
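The lists returned by MultiPeriodDiscriminator are consumed by least-squares GAN losses during training; a hedged sketch of the usual VITS-style discriminator loss over those lists (a re-implementation for illustration, not necessarily identical to this repo's losses.py):

import torch

def discriminator_loss(real_outs, fake_outs):
    # LSGAN: real outputs pushed towards 1, generated outputs towards 0
    loss = 0.0
    for d_real, d_fake in zip(real_outs, fake_outs):
        loss = loss + torch.mean((1 - d_real) ** 2) + torch.mean(d_fake**2)
    return loss

# y_d_rs / y_d_gs as returned by MultiPeriodDiscriminator(y, y_hat)
real_outs = [torch.rand(1, 20) for _ in range(6)]  # dummy stand-ins
fake_outs = [torch.rand(1, 20) for _ in range(6)]
print(discriminator_loss(real_outs, fake_outs))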
+class WavLMDiscriminator(nn.Module):
+    """Discriminator operating on stacked WavLM (SLM) hidden states."""
+
+    def __init__(
+        self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
+    ):
+        super(WavLMDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
+        self.pre = norm_f(
+            Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
+        )
+
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(initial_channel, initial_channel * 2, kernel_size=5, padding=2)),
+                norm_f(nn.Conv1d(initial_channel * 2, initial_channel * 4, kernel_size=5, padding=2)),
+                norm_f(nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)),
+            ]
+        )
+
+        self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        x = self.pre(x)
+
+        fmap = []
+        for layer in self.convs:
+            x = layer(x)
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x
+
+
+class ReferenceEncoder(nn.Module):
+    """
+    inputs --- [N, Ty/r, n_mels*r]  mels
+    outputs --- [N, ref_enc_gru_size]
+    """
+
+    def __init__(self, spec_channels, gin_channels=0):
+        super().__init__()
+        self.spec_channels = spec_channels
+        ref_enc_filters = [32, 32, 64, 64, 128, 128]
+        K = len(ref_enc_filters)
+        filters = [1] + ref_enc_filters
+        convs = [
+            weight_norm(
+                nn.Conv2d(
+                    in_channels=filters[i],
+                    out_channels=filters[i + 1],
+                    kernel_size=(3, 3),
+                    stride=(2, 2),
+                    padding=(1, 1),
+                )
+            )
+            for i in range(K)
+        ]
+        self.convs = nn.ModuleList(convs)
+        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])  # noqa: E501
+
+        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
+        self.gru = nn.GRU(
+            input_size=ref_enc_filters[-1] * out_channels,
+            hidden_size=256 // 2,
+            batch_first=True,
+        )
+        self.proj = nn.Linear(128, gin_channels)
+
+    def forward(self, inputs, mask=None):
+        N = inputs.size(0)
+        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
+        for conv in self.convs:
+            out = conv(out)
+            # out = wn(out)
+            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]
+
+        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
+        T = out.size(1)
+        N = out.size(0)
+        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]
+
+        self.gru.flatten_parameters()
+        memory, out = self.gru(out)  # out --- [1, N, 128]
+
+        return self.proj(out.squeeze(0))
+
+    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
+        for i in range(n_convs):
+            L = (L - kernel_size + 2 * pad) // stride + 1
+        return L
+
+
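calculate_channels just replays the Conv2d output-size formula over the K stride-2 layers to find how many frequency bins survive, which fixes the GRU input size. For example, with 1025 spectrogram bins (filter_length 2048 // 2 + 1, a typical setting rather than a value taken from this file) and six layers:

def calculate_channels(L, kernel_size=3, stride=2, pad=1, n_convs=6):
    for _ in range(n_convs):
        L = (L - kernel_size + 2 * pad) // stride + 1
    return L

print(calculate_channels(1025))  # 17 -> GRU input size 128 * 17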
+class SynthesizerTrn(nn.Module):
+    """
+    Synthesizer for Training
+    """
+
+    def __init__(
+        self, n_vocab, spec_channels, segment_size, inter_channels, hidden_channels,
+        filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock,
+        resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
+        upsample_initial_channel, upsample_kernel_sizes, n_speakers=256,
+        gin_channels=256, use_sdp=True, n_flow_layer=4, n_layers_trans_flow=4,
+        flow_share_parameter=False, use_transformer_flow=True, **kwargs
+    ):
+        super().__init__()
+        self.n_vocab = n_vocab
+        self.spec_channels = spec_channels
+        self.inter_channels = inter_channels
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.resblock = resblock
+        self.resblock_kernel_sizes = resblock_kernel_sizes
+        self.resblock_dilation_sizes = resblock_dilation_sizes
+        self.upsample_rates = upsample_rates
+        self.upsample_initial_channel = upsample_initial_channel
+        self.upsample_kernel_sizes = upsample_kernel_sizes
+        self.segment_size = segment_size
+        self.n_speakers = n_speakers
+        self.gin_channels = gin_channels
+        self.n_layers_trans_flow = n_layers_trans_flow
+        self.use_spk_conditioned_encoder = kwargs.get(
+            "use_spk_conditioned_encoder", True
+        )
+        self.use_sdp = use_sdp
+        self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
+        self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
+        self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
+        self.current_mas_noise_scale = self.mas_noise_scale_initial
+        if self.use_spk_conditioned_encoder and gin_channels > 0:
+            self.enc_gin_channels = gin_channels
+        self.enc_p = TextEncoder(
+            n_vocab, inter_channels, hidden_channels, filter_channels,
+            n_heads, n_layers, kernel_size, p_dropout,
+            gin_channels=self.enc_gin_channels,
+        )
+        self.dec = Generator(
+            inter_channels, resblock, resblock_kernel_sizes,
+            resblock_dilation_sizes, upsample_rates, upsample_initial_channel,
+            upsample_kernel_sizes, gin_channels=gin_channels,
+        )
+        self.enc_q = PosteriorEncoder(
+            spec_channels, inter_channels, hidden_channels, 5, 1, 16,
+            gin_channels=gin_channels,
+        )
+        if use_transformer_flow:
+            self.flow = TransformerCouplingBlock(
+                inter_channels, hidden_channels, filter_channels, n_heads,
+                n_layers_trans_flow, 5, p_dropout, n_flow_layer,
+                gin_channels=gin_channels, share_parameter=flow_share_parameter,
+            )
+        else:
+            self.flow = ResidualCouplingBlock(
+                inter_channels, hidden_channels, 5, 1, n_flow_layer,
+                gin_channels=gin_channels,
+            )
+        self.sdp = StochasticDurationPredictor(
+            hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
+        )
+        self.dp = DurationPredictor(
+            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
+        )
+
+        if n_speakers >= 1:
+            self.emb_g = nn.Embedding(n_speakers, gin_channels)
+        else:
+            self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
+
+    def export_onnx(self, path, max_len=None, sdp_ratio=0, y=None):
+        noise_scale = 0.667
+        length_scale = 1
+        noise_scale_w = 0.8
+        x = (
+            torch.LongTensor(
+                [0, 97, 0, 8, 0, 78, 0, 8, 0, 76, 0, 37, 0, 40, 0, 97, 0, 8,
+                 0, 23, 0, 8, 0, 74, 0, 26, 0, 104, 0]
+            )
+            .unsqueeze(0)
+            .cpu()
+        )
+        tone = torch.zeros_like(x).cpu()
+        language = torch.zeros_like(x).cpu()
+        x_lengths = torch.LongTensor([x.shape[1]]).cpu()
+        sid = torch.LongTensor([0]).cpu()
+        bert = torch.randn(size=(x.shape[1], 1024)).cpu()
+        ja_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
+        en_bert = torch.randn(size=(x.shape[1], 1024)).cpu()
+
+        if self.n_speakers > 0:
+            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
+            torch.onnx.export(
+                self.emb_g,
+                (sid),
+                f"onnx/{path}/{path}_emb.onnx",
+                input_names=["sid"],
+                output_names=["g"],
+                verbose=True,
+            )
+        else:
+            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
+
+        torch.onnx.export(
+            self.enc_p,
+            (x, x_lengths, tone, language, bert, ja_bert, en_bert, g),
+            f"onnx/{path}/{path}_enc_p.onnx",
+            input_names=["x", "x_lengths", "t", "language", "bert_0", "bert_1", "bert_2", "g"],
+            output_names=["xout", "m_p", "logs_p", "x_mask"],
+            dynamic_axes={
+                "x": [0, 1],
+                "t": [0, 1],
+                "language": [0, 1],
+                "bert_0": [0],
+                "bert_1": [0],
+                "bert_2": [0],
+                "xout": [0, 2],
+                "m_p": [0, 2],
+                "logs_p": [0, 2],
+                "x_mask": [0, 2],
+            },
+            verbose=True,
+            opset_version=16,
+        )
+
+        x, m_p, logs_p, x_mask = self.enc_p(
+            x, x_lengths, tone, language, bert, ja_bert, en_bert, g
+        )
+
+        zinput = (
+            torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
+            * noise_scale_w
+        )
+        torch.onnx.export(
+            self.sdp,
+            (x, x_mask, zinput, g),
+            f"onnx/{path}/{path}_sdp.onnx",
+            input_names=["x", "x_mask", "zin", "g"],
+            output_names=["logw"],
+            dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
+            verbose=True,
+        )
+        torch.onnx.export(
+            self.dp,
+            (x, x_mask, g),
+            f"onnx/{path}/{path}_dp.onnx",
+            input_names=["x", "x_mask", "g"],
+            output_names=["logw"],
+            dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
+            verbose=True,
+        )
+        logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
+            x, x_mask, g=g
+        ) * (1 - sdp_ratio)
+        w = torch.exp(logw) * x_mask * length_scale
+        w_ceil = torch.ceil(w)
+        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
+            x_mask.dtype
+        )
+        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
+        attn = commons.generate_path(w_ceil, attn_mask)
+
+        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
+            1, 2
+        )  # [b, t', t], [b, t, d] -> [b, d, t']
+        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
+            1, 2
+        )  # [b, t', t], [b, t, d] -> [b, d, t']
+
+        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
+        torch.onnx.export(
+            self.flow,
+            (z_p, y_mask, g),
+            f"onnx/{path}/{path}_flow.onnx",
+            input_names=["z_p", "y_mask", "g"],
+            output_names=["z"],
+            dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
+            verbose=True,
+        )
+
+        z = self.flow(z_p, y_mask, g=g, reverse=True)
+        z_in = (z * y_mask)[:, :, :max_len]
+
+        torch.onnx.export(
+            self.dec,
+            (z_in, g),
+            f"onnx/{path}/{path}_dec.onnx",
+            input_names=["z_in", "g"],
+            output_names=["o"],
+            dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
+            verbose=True,
+        )
+        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
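export_onnx traces the model into separate graphs (speaker embedding, text encoder, both duration predictors, flow and decoder) under onnx/{path}/. A hedged usage sketch, run from the repository root; the checkpoint and config paths are placeholders, the load_checkpoint keyword is assumed, and the constructor call mirrors what onnx_modules/__init__.py does:

import os
import torch
import utils
from onnx_modules.V230 import SynthesizerTrn, symbols

hps = utils.get_hparams_from_file("Data/MyModel/config.json")  # placeholder path
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model,
)
utils.load_checkpoint("Data/MyModel/models/G_10000.pth", net_g, None, skip_optimizer=True)
net_g.cpu().eval()
os.makedirs("onnx/MyModel", exist_ok=True)
with torch.no_grad():
    net_g.export_onnx("MyModel")  # writes onnx/MyModel/MyModel_{emb,enc_p,sdp,dp,flow,dec}.onnx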
onnx_modules/V230/text/__init__.py ADDED
@@ -0,0 +1 @@
+from .symbols import *
onnx_modules/V230/text/symbols.py ADDED
@@ -0,0 +1,187 @@
+punctuation = ["!", "?", "…", ",", ".", "'", "-"]
+pu_symbols = punctuation + ["SP", "UNK"]
+pad = "_"
+
+# chinese
+zh_symbols = [
+    "E",
+    "En",
+    "a",
+    "ai",
+    "an",
+    "ang",
+    "ao",
+    "b",
+    "c",
+    "ch",
+    "d",
+    "e",
+    "ei",
+    "en",
+    "eng",
+    "er",
+    "f",
+    "g",
+    "h",
+    "i",
+    "i0",
+    "ia",
+    "ian",
+    "iang",
+    "iao",
+    "ie",
+    "in",
+    "ing",
+    "iong",
+    "ir",
+    "iu",
+    "j",
+    "k",
+    "l",
+    "m",
+    "n",
+    "o",
+    "ong",
+    "ou",
+    "p",
+    "q",
+    "r",
+    "s",
+    "sh",
+    "t",
+    "u",
+    "ua",
+    "uai",
+    "uan",
+    "uang",
+    "ui",
+    "un",
+    "uo",
+    "v",
+    "van",
+    "ve",
+    "vn",
+    "w",
+    "x",
+    "y",
+    "z",
+    "zh",
+    "AA",
+    "EE",
+    "OO",
+]
+num_zh_tones = 6
+
+# japanese
+ja_symbols = [
+    "N",
+    "a",
+    "a:",
+    "b",
+    "by",
+    "ch",
+    "d",
+    "dy",
+    "e",
+    "e:",
+    "f",
+    "g",
+    "gy",
+    "h",
+    "hy",
+    "i",
+    "i:",
+    "j",
+    "k",
+    "ky",
+    "m",
+    "my",
+    "n",
+    "ny",
+    "o",
+    "o:",
+    "p",
+    "py",
+    "q",
+    "r",
+    "ry",
+    "s",
+    "sh",
+    "t",
+    "ts",
+    "ty",
+    "u",
+    "u:",
+    "w",
+    "y",
+    "z",
+    "zy",
+]
+num_ja_tones = 2
+
+# English
+en_symbols = [
+    "aa",
+    "ae",
+    "ah",
+    "ao",
+    "aw",
+    "ay",
+    "b",
+    "ch",
+    "d",
+    "dh",
+    "eh",
+    "er",
+    "ey",
+    "f",
+    "g",
+    "hh",
+    "ih",
+    "iy",
+    "jh",
+    "k",
+    "l",
+    "m",
+    "n",
+    "ng",
+    "ow",
+    "oy",
+    "p",
+    "r",
+    "s",
+    "sh",
+    "t",
+    "th",
+    "uh",
+    "uw",
+    "V",
+    "w",
+    "y",
+    "z",
+    "zh",
+]
+num_en_tones = 4
+
+# combine all symbols
+normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
+symbols = [pad] + normal_symbols + pu_symbols
+sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
+
+# combine all tones
+num_tones = num_zh_tones + num_ja_tones + num_en_tones
+
+# language maps
+language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
+num_languages = len(language_id_map.keys())
+
+language_tone_start_map = {
+    "ZH": 0,
+    "JP": num_zh_tones,
+    "EN": num_zh_tones + num_ja_tones,
+}
+
+if __name__ == "__main__":
+    a = set(zh_symbols)
+    b = set(en_symbols)
+    print(sorted(a & b))
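text/__init__.py simply re-exports these tables; converting cleaned phones, tones and a language tag into the integer ids the ONNX text encoder expects follows directly from them. A small sketch mirroring the usual cleaned_text_to_sequence logic (a re-implementation for illustration, not the repo function itself):

from onnx_modules.V230.text.symbols import symbols, language_id_map, language_tone_start_map

_symbol_to_id = {s: i for i, s in enumerate(symbols)}

def to_model_ids(phones, tones, language):
    phone_ids = [_symbol_to_id[p] for p in phones]
    tone_ids = [t + language_tone_start_map[language] for t in tones]
    lang_ids = [language_id_map[language]] * len(phones)
    return phone_ids, tone_ids, lang_ids

print(to_model_ids(["_", "g", "ou", "_"], [0, 3, 3, 0], "ZH"))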
onnx_modules/V230_OnnxInference/__init__.py ADDED
@@ -0,0 +1,126 @@
+import numpy as np
+import onnxruntime as ort
+
+
+def convert_pad_shape(pad_shape):
+    layer = pad_shape[::-1]
+    pad_shape = [item for sublist in layer for item in sublist]
+    return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = np.arange(max_length, dtype=length.dtype)
+    return np.expand_dims(x, 0) < np.expand_dims(length, 1)
+
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = np.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.reshape(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y)
+    path = path.reshape(b, t_x, t_y)
+    path = path ^ np.pad(path, ((0, 0), (1, 0), (0, 0)))[:, :-1]
+    path = np.expand_dims(path, 1).transpose(0, 1, 3, 2)
+    return path
+
+
+class OnnxInferenceSession:
+    def __init__(self, path, Providers=["CPUExecutionProvider"]):
+        self.enc = ort.InferenceSession(path["enc"], providers=Providers)
+        self.emb_g = ort.InferenceSession(path["emb_g"], providers=Providers)
+        self.dp = ort.InferenceSession(path["dp"], providers=Providers)
+        self.sdp = ort.InferenceSession(path["sdp"], providers=Providers)
+        self.flow = ort.InferenceSession(path["flow"], providers=Providers)
+        self.dec = ort.InferenceSession(path["dec"], providers=Providers)
+
+    def __call__(
+        self,
+        seq,
+        tone,
+        language,
+        bert_zh,
+        bert_jp,
+        bert_en,
+        sid,
+        seed=114514,
+        seq_noise_scale=0.8,
+        sdp_noise_scale=0.6,
+        length_scale=1.0,
+        sdp_ratio=0.0,
+    ):
+        if seq.ndim == 1:
+            seq = np.expand_dims(seq, 0)
+        if tone.ndim == 1:
+            tone = np.expand_dims(tone, 0)
+        if language.ndim == 1:
+            language = np.expand_dims(language, 0)
+        # all three inputs must be batched [1, T] arrays at this point
+        assert seq.ndim == 2 and tone.ndim == 2 and language.ndim == 2
+        g = self.emb_g.run(
+            None,
+            {
+                "sid": sid.astype(np.int64),
+            },
+        )[0]
+        g = np.expand_dims(g, -1)
+        enc_rtn = self.enc.run(
+            None,
+            {
+                "x": seq.astype(np.int64),
+                "t": tone.astype(np.int64),
+                "language": language.astype(np.int64),
+                "bert_0": bert_zh.astype(np.float32),
+                "bert_1": bert_jp.astype(np.float32),
+                "bert_2": bert_en.astype(np.float32),
+                "g": g.astype(np.float32),
+            },
+        )
+        x, m_p, logs_p, x_mask = enc_rtn[0], enc_rtn[1], enc_rtn[2], enc_rtn[3]
+        np.random.seed(seed)
+        zinput = np.random.randn(x.shape[0], 2, x.shape[2]) * sdp_noise_scale
+        logw = self.sdp.run(
+            None, {"x": x, "x_mask": x_mask, "zin": zinput.astype(np.float32), "g": g}
+        )[0] * (sdp_ratio) + self.dp.run(None, {"x": x, "x_mask": x_mask, "g": g})[
+            0
+        ] * (
+            1 - sdp_ratio
+        )
+        w = np.exp(logw) * x_mask * length_scale
+        w_ceil = np.ceil(w)
+        y_lengths = np.clip(np.sum(w_ceil, (1, 2)), a_min=1.0, a_max=100000).astype(
+            np.int64
+        )
+        y_mask = np.expand_dims(sequence_mask(y_lengths, None), 1)
+        attn_mask = np.expand_dims(x_mask, 2) * np.expand_dims(y_mask, -1)
+        attn = generate_path(w_ceil, attn_mask)
+        m_p = np.matmul(attn.squeeze(1), m_p.transpose(0, 2, 1)).transpose(
+            0, 2, 1
+        )  # [b, t', t], [b, t, d] -> [b, d, t']
+        logs_p = np.matmul(attn.squeeze(1), logs_p.transpose(0, 2, 1)).transpose(
+            0, 2, 1
+        )  # [b, t', t], [b, t, d] -> [b, d, t']
+
+        z_p = (
+            m_p
+            + np.random.randn(m_p.shape[0], m_p.shape[1], m_p.shape[2])
+            * np.exp(logs_p)
+            * seq_noise_scale
+        )
+
+        z = self.flow.run(
+            None,
+            {
+                "z_p": z_p.astype(np.float32),
+                "y_mask": y_mask.astype(np.float32),
+                "g": g,
+            },
+        )[0]
+
+        return self.dec.run(None, {"z_in": z.astype(np.float32), "g": g})[0]
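A usage sketch of the session above, assuming the graphs were exported into onnx/MyModel/ (file names are placeholders; the zeroed ids below stand in for real output of the text front-end, and each BERT array must match the phoneme sequence length):

import numpy as np
from onnx_modules.V230_OnnxInference import OnnxInferenceSession

session = OnnxInferenceSession(
    {
        "enc": "onnx/MyModel/MyModel_enc_p.onnx",
        "emb_g": "onnx/MyModel/MyModel_emb.onnx",
        "dp": "onnx/MyModel/MyModel_dp.onnx",
        "sdp": "onnx/MyModel/MyModel_sdp.onnx",
        "flow": "onnx/MyModel/MyModel_flow.onnx",
        "dec": "onnx/MyModel/MyModel_dec.onnx",
    },
    Providers=["CPUExecutionProvider"],
)

T = 29
seq = np.zeros((1, T), dtype=np.int64)       # phone ids
tone = np.zeros((1, T), dtype=np.int64)
language = np.zeros((1, T), dtype=np.int64)
bert = np.random.randn(T, 1024).astype(np.float32)
sid = np.array([0], dtype=np.int64)

audio = session(seq, tone, language, bert, bert, bert, sid)
print(audio.shape)                           # [1, 1, n_samples]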
onnx_modules/__init__.py CHANGED
@@ -1,14 +1,21 @@
-from utils import get_hparams_from_file, load_checkpoint
+from utils import get_hparams_from_file, load_checkpoint
 import json


-def export_onnx(export_path, model_path, config_path):
+def export_onnx(export_path, model_path, config_path, novq, dev):
     hps = get_hparams_from_file(config_path)
     version = hps.version[0:3]
-    if version == "2.0":
+    if version == "2.0" or (version == "2.1" and novq):
         from .V200 import SynthesizerTrn, symbols
-    elif version == "2.1":
+    elif version == "2.1" and (not novq):
        from .V210 import SynthesizerTrn, symbols
+    elif version == "2.2":
+        if novq and dev:
+            from .V220_novq_dev import SynthesizerTrn, symbols
+        else:
+            from .V220 import SynthesizerTrn, symbols
+    elif version == "2.3":
+        from .V230 import SynthesizerTrn, symbols
     net_g = SynthesizerTrn(
         len(symbols),
         hps.data.filter_length // 2 + 1,
@@ -41,6 +48,7 @@ def export_onnx(export_path, model_path, config_path):
             "deberta-v2-large-japanese",
             "bert-base-japanese-v3",
         ],
+        "Clap": "clap-htsat-fused",
     }

     with open(f"onnx/{export_path}.json", "w") as MoeVsConfFile:
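With the new signature the caller chooses the architecture variant explicitly: novq selects the VQ-free 2.1/2.2 graphs and dev picks the 2.2 development branch. A hedged sketch of calling it directly (all paths are placeholders):

from onnx_modules import export_onnx

export_onnx(
    export_path="MyModelV23",                      # output name under onnx/
    model_path="Data/MyModel/models/G_10000.pth",  # placeholder checkpoint
    config_path="Data/MyModel/config.json",        # placeholder config
    novq=False,
    dev=False,
)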
re_matching.py CHANGED
@@ -44,7 +44,6 @@ def text_matching(text: str) -> list:
     result = []
     for speaker, dialogue in matches:
         result.append(extract_language_and_text_updated(speaker, dialogue))
-    print(result)
     return result


requirements.txt CHANGED
@@ -11,7 +11,7 @@ jieba
 transformers
 pypinyin
 cn2an
-gradio==3.38.0
+gradio==3.50.2
 av
 mecab-python3
 loguru
@@ -21,8 +21,7 @@ fugashi
 num2words
 PyYAML
 requests
-pyopenjtalk; sys_platform == 'linux'
-openjtalk; sys_platform != 'linux'
+pyopenjtalk-prebuilt
 jaconv
 psutil
 GPUtil