litagin committed on
Commit
cda1b98
1 Parent(s): a89499b

Upload 78 files

Files changed (49)
  1. README.md +2 -2
  2. app.py +216 -330
  3. attentions.py +1 -1
  4. common/constants.py +20 -0
  5. common/log.py +16 -0
  6. common/stdout_wrapper.py +34 -0
  7. common/subprocess_utils.py +32 -0
  8. common/tts_model.py +238 -0
  9. config.py +269 -254
  10. config.yml +51 -58
  11. configs/config.json +71 -0
  12. configs/configs_jp_extra.json +78 -0
  13. configs/paths.yml +8 -0
  14. infer.py +306 -263
  15. model_assets/jvnv-F1-jp/config.json +92 -0
  16. model_assets/jvnv-F1-jp/jvnv-F1-jp_e182_s16000.safetensors +3 -0
  17. model_assets/jvnv-F1-jp/style_vectors.npy +3 -0
  18. model_assets/jvnv-F2-jp/config.json +92 -0
  19. model_assets/jvnv-F2-jp/jvnv-F2_e166_s20000.safetensors +3 -0
  20. model_assets/jvnv-F2-jp/style_vectors.npy +3 -0
  21. model_assets/jvnv-M1-jp/config.json +92 -0
  22. model_assets/jvnv-M1-jp/jvnv-M1-jp_e158_s14000.safetensors +3 -0
  23. model_assets/jvnv-M1-jp/style_vectors.npy +3 -0
  24. model_assets/jvnv-M2-jp/config.json +92 -0
  25. model_assets/jvnv-M2-jp/jvnv-M2-jp_e159_s17000.safetensors +3 -0
  26. model_assets/jvnv-M2-jp/style_vectors.npy +3 -0
  27. models_jp_extra.py +1071 -0
  28. monotonic_align/__init__.py +16 -16
  29. monotonic_align/core.py +46 -46
  30. requirements.txt +1 -4
  31. text/__init__.py +32 -0
  32. text/chinese.py +199 -0
  33. text/chinese_bert.py +121 -0
  34. text/cleaner.py +31 -0
  35. text/cmudict.rep +0 -0
  36. text/cmudict_cache.pickle +3 -0
  37. text/english.py +495 -0
  38. text/english_bert_mock.py +63 -0
  39. text/japanese.py +585 -0
  40. text/japanese_bert.py +67 -0
  41. text/japanese_mora_list.py +232 -0
  42. text/opencpop-strict.txt +429 -0
  43. text/symbols.py +187 -0
  44. text/tone_sandhi.py +773 -0
  45. tools/__init__.py +3 -0
  46. tools/classify_language.py +197 -0
  47. tools/sentence.py +173 -0
  48. tools/translate.py +61 -0
  49. utils.py +6 -5
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Style Bert VITS2 JVNV
3
  emoji: 😡😊😱😫
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.12.0
8
  app_file: app.py
9
  pinned: false
10
  license: agpl-3.0
 
1
  ---
2
+ title: Style-Bert-VITS2 JVNV
3
  emoji: 😡😊😱😫
4
  colorFrom: blue
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.16.0
8
  app_file: app.py
9
  pinned: false
10
  license: agpl-3.0
app.py CHANGED
@@ -1,209 +1,41 @@
1
  import argparse
2
  import datetime
 
3
  import os
4
  import sys
5
- import warnings
6
 
7
  import gradio as gr
8
- import numpy as np
9
  import torch
10
- from gradio.processing_utils import convert_to_16_bit_wav
11
-
12
- import utils
13
- from config import config
14
- from infer import get_net_g, infer
15
- from tools.log import logger
 
 
16
 
17
  is_hf_spaces = os.getenv("SYSTEM") == "spaces"
18
  limit = 100
19
 
 
 
 
 
 
20
 
21
- class Model:
22
- def __init__(self, model_path, config_path, style_vec_path, device):
23
- self.model_path = model_path
24
- self.config_path = config_path
25
- self.device = device
26
- self.style_vec_path = style_vec_path
27
- self.load()
28
-
29
- def load(self):
30
- self.hps = utils.get_hparams_from_file(self.config_path)
31
- self.spk2id = self.hps.data.spk2id
32
- self.num_styles = self.hps.data.num_styles
33
- if hasattr(self.hps.data, "style2id"):
34
- self.style2id = self.hps.data.style2id
35
- else:
36
- self.style2id = {str(i): i for i in range(self.num_styles)}
37
-
38
- self.style_vectors = np.load(self.style_vec_path)
39
- self.net_g = None
40
-
41
- def load_net_g(self):
42
- self.net_g = get_net_g(
43
- model_path=self.model_path,
44
- version=self.hps.version,
45
- device=self.device,
46
- hps=self.hps,
47
- )
48
-
49
- def get_style_vector(self, style_id, weight=1.0):
50
- mean = self.style_vectors[0]
51
- style_vec = self.style_vectors[style_id]
52
- style_vec = mean + (style_vec - mean) * weight
53
- return style_vec
54
-
55
- def get_style_vector_from_audio(self, audio_path, weight=1.0):
56
- from style_gen import extract_style_vector
57
-
58
- xvec = extract_style_vector(audio_path)
59
- mean = self.style_vectors[0]
60
- xvec = mean + (xvec - mean) * weight
61
- return xvec
62
-
63
- def infer(
64
- self,
65
- text,
66
- language="JP",
67
- sid=0,
68
- reference_audio_path=None,
69
- sdp_ratio=0.2,
70
- noise=0.6,
71
- noisew=0.8,
72
- length=1.0,
73
- line_split=True,
74
- split_interval=0.2,
75
- style_text="",
76
- style_weight=0.7,
77
- use_style_text=False,
78
- style="0",
79
- emotion_weight=1.0,
80
- ):
81
- if reference_audio_path == "":
82
- reference_audio_path = None
83
- if style_text == "" or not use_style_text:
84
- style_text = None
85
-
86
- if self.net_g is None:
87
- self.load_net_g()
88
- if reference_audio_path is None:
89
- style_id = self.style2id[style]
90
- style_vector = self.get_style_vector(style_id, emotion_weight)
91
- else:
92
- style_vector = self.get_style_vector_from_audio(
93
- reference_audio_path, emotion_weight
94
- )
95
- if not line_split:
96
- with torch.no_grad():
97
- audio = infer(
98
- text=text,
99
- sdp_ratio=sdp_ratio,
100
- noise_scale=noise,
101
- noise_scale_w=noisew,
102
- length_scale=length,
103
- sid=sid,
104
- language=language,
105
- hps=self.hps,
106
- net_g=self.net_g,
107
- device=self.device,
108
- style_text=style_text,
109
- style_weight=style_weight,
110
- style_vec=style_vector,
111
- )
112
- else:
113
- texts = text.split("\n")
114
- texts = [t for t in texts if t != ""]
115
- audios = []
116
- with torch.no_grad():
117
- for i, t in enumerate(texts):
118
- audios.append(
119
- infer(
120
- text=t,
121
- sdp_ratio=sdp_ratio,
122
- noise_scale=noise,
123
- noise_scale_w=noisew,
124
- length_scale=length,
125
- sid=sid,
126
- language=language,
127
- hps=self.hps,
128
- net_g=self.net_g,
129
- device=self.device,
130
- style_text=style_text,
131
- style_weight=style_weight,
132
- style_vec=style_vector,
133
- )
134
- )
135
- if i != len(texts) - 1:
136
- audios.append(np.zeros(int(44100 * split_interval)))
137
- audio = np.concatenate(audios)
138
- with warnings.catch_warnings():
139
- warnings.simplefilter("ignore")
140
- audio = convert_to_16_bit_wav(audio)
141
- return (self.hps.data.sampling_rate, audio)
142
-
143
-
144
- class ModelHolder:
145
- def __init__(self, root_dir, device):
146
- self.root_dir = root_dir
147
- self.device = device
148
- self.model_files_dict = {}
149
- self.current_model = None
150
- self.model_names = []
151
- self.models = []
152
- self.refresh()
153
-
154
- def refresh(self):
155
- self.model_files_dict = {}
156
- self.model_names = []
157
- self.current_model = None
158
- model_dirs = [
159
- d
160
- for d in os.listdir(self.root_dir)
161
- if os.path.isdir(os.path.join(self.root_dir, d))
162
- ]
163
- for model_name in model_dirs:
164
- model_dir = os.path.join(self.root_dir, model_name)
165
- model_files = [
166
- os.path.join(model_dir, f)
167
- for f in os.listdir(model_dir)
168
- if f.endswith(".pth") or f.endswith(".pt") or f.endswith(".safetensors")
169
- ]
170
- if len(model_files) == 0:
171
- logger.info(
172
- f"No model files found in {self.root_dir}/{model_name}, so skip it"
173
- )
174
- self.model_files_dict[model_name] = model_files
175
- self.model_names.append(model_name)
176
-
177
- def load_model(self, model_name, model_path):
178
- if model_name not in self.model_files_dict:
179
- raise Exception(f"モデル名{model_name}は存在しません")
180
- if model_path not in self.model_files_dict[model_name]:
181
- raise Exception(f"pthファイル{model_path}は存在しません")
182
- self.current_model = Model(
183
- model_path=model_path,
184
- config_path=os.path.join(self.root_dir, model_name, "config.json"),
185
- style_vec_path=os.path.join(self.root_dir, model_name, "style_vectors.npy"),
186
- device=self.device,
187
- )
188
- styles = list(self.current_model.style2id.keys())
189
- return (
190
- gr.Dropdown(choices=styles, value=styles[0]),
191
- gr.update(interactive=True, value="音声合成"),
192
- )
193
-
194
- def update_model_files_dropdown(self, model_name):
195
- model_files = self.model_files_dict[model_name]
196
- return gr.Dropdown(choices=model_files, value=model_files[0])
197
-
198
- def update_model_names_dropdown(self):
199
- self.refresh()
200
- initial_model_name = self.model_names[0]
201
- initial_model_files = self.model_files_dict[initial_model_name]
202
- return (
203
- gr.Dropdown(choices=self.model_names, value=initial_model_name),
204
- gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]),
205
- gr.update(interactive=False), # For tts_button
206
- )
207
 
208
 
209
  def tts_fn(
@@ -216,105 +48,100 @@ def tts_fn(
216
  length_scale,
217
  line_split,
218
  split_interval,
219
- style_text,
 
 
 
220
  style_weight,
221
- use_style_text,
222
- emotion,
223
- emotion_weight,
224
  ):
225
- logger.info(f"Start TTS with {language}:\n{text}")
226
- logger.info(f"Model: {model_holder.current_model.model_path}")
227
- logger.info(f"SDP: {sdp_ratio}, Noise: {noise_scale}, Noise_W: {noise_scale_w}, Length: {length_scale}")
228
- logger.info(f"Style text enabled: {use_style_text}, Style text: {style_text}, Style weight: {style_weight}")
229
- logger.info(f"Style: {emotion}, Style weight: {emotion_weight}")
230
-
231
  if is_hf_spaces and len(text) > limit:
232
- logger.error(f"文字数が{limit}文字を超えています")
233
- # raise Exception(f"文字数が{limit}文字を超えています")
234
- return f"文字数が{limit}文字を超えています", (44100, "")
235
-
 
 
236
  assert model_holder.current_model is not None
237
 
 
238
  start_time = datetime.datetime.now()
239
 
240
- sr, audio = model_holder.current_model.infer(
241
- text=text,
242
- language=language,
243
- reference_audio_path=reference_audio_path,
244
- sdp_ratio=sdp_ratio,
245
- noise=noise_scale,
246
- noisew=noise_scale_w,
247
- length=length_scale,
248
- line_split=line_split,
249
- split_interval=split_interval,
250
- style_text=style_text,
251
- style_weight=style_weight,
252
- use_style_text=use_style_text,
253
- style=emotion,
254
- emotion_weight=emotion_weight,
255
- )
 
 
256
 
257
  end_time = datetime.datetime.now()
258
  duration = (end_time - start_time).total_seconds()
259
- logger.info(f"End TTS, duration: {duration} seconds")
260
- return f"Success, time: {duration} seconds.", (sr, audio)
261
 
 
262
 
263
- initial_text = "こんにちは、初めまして。あなたの名前はなんていうの?"
264
 
265
- example_local = [
266
- [initial_text, "JP"],
267
- [
268
- """あなたがそんなこと言うなんて、私はとっても嬉しい。
269
- あなたがそんなこと言うなんて、私はとっても怒ってる。
270
- あなたがそんなこと言うなんて、私はとっても驚いてる。
271
- あなたがそんなこと言うなんて、私はとっても辛い。""",
272
- "JP",
273
- ],
274
- [ # ChatGPTに考えてもらった告白セリフ
275
- """私、ずっと前からあなたのことを見てきました。あなたの笑顔、優しさ、強さに、心惹かれていたんです。
276
- 友達として過ごす中で、あなたのことがだんだんと特別な存在になっていくのがわかりました。
277
- えっと、私、あなたのことが好きです!もしよければ、私と付き合ってくれませんか?""",
278
- "JP",
279
- ],
280
- [ # 夏目漱石『吾輩は猫である』
281
- """吾輩は猫である。名前はまだ無い。
282
- どこで生れたかとんと見当がつかぬ。なんでも薄暗いじめじめした所でニャーニャー泣いていた事だけは記憶している。
283
- 吾輩はここで始めて人間というものを見た。しかもあとで聞くと、それは書生という、人間中で一番獰悪な種族であったそうだ。
284
- この書生というのは時々我々を捕まえて煮て食うという話である。""",
285
- "JP",
286
- ],
287
- [ # 梶井基次郎『桜の樹の下には』
288
- """桜の樹の下には屍体が埋まっている!これは信じていいことなんだよ。
289
- 何故って、桜の花があんなにも見事に咲くなんて信じられないことじゃないか。俺はあの美しさが信じられないので、このにさんにち不安だった。
290
- しかしいま、やっとわかるときが来た。桜の樹の下には屍体が埋まっている。これは信じていいことだ。""",
291
- "JP",
292
- ],
293
- [ # ChatGPTと考えた、感情を表すセリフ
294
- """やったー!テストで満点取れた!私とっても嬉しいな!
295
- どうして私の意見を無視するの?許せない!ムカつく!あんたなんか死ねばいいのに。
296
- あはははっ!この漫画めっちゃ笑える、見てよこれ、ふふふ、あはは。
297
- あなたがいなくなって、私は一人になっちゃって、泣いちゃいそうなほど悲しい。""",
298
- "JP",
299
- ],
300
- [ # 上の丁寧語バージョン
301
- """やりました!テストで満点取れましたよ!私とっても嬉しいです!
302
- どうして私の意見を無視するんですか?許せません!ムカつきます!あんたなんか死んでください。
303
- あはははっ!この漫画めっちゃ笑えます、見てくださいこれ、ふふふ、あはは。
304
- あなたがいなくなって、私は一人になっちゃって、泣いちゃいそうなほど悲しいです。""",
305
- "JP",
306
- ],
307
- [ # ChatGPTに考えてもらった音声合成の説明文章
308
- """音声合成は、機械学習を活用して、テキストから人の声を再現する技術です。この技術は、言語の構造を解析し、それに基づいて音声を生成します。
309
- この分野の最新の研究成果を使うと、より自然で表現豊かな音声の生成が可能である。深層学習の応用により、感情やアクセントを含む声質の微妙な変化も再現することが出来る。""",
310
- "JP",
311
- ],
312
- [
313
- "Speech synthesis is the artificial production of human speech. A computer system used for this purpose is called a speech synthesizer, and can be implemented in software or hardware products.",
314
- "EN",
315
- ],
316
- ["语音合成是人工制造人类语音。用于此目的的计算机系统称为语音合成器,可以通过软件或硬件产品实现。", "ZH"],
317
- ]
318
 
319
  example_hf_spaces = [
320
  [initial_text, "JP"],
@@ -322,30 +149,30 @@ example_hf_spaces = [
322
  ["吾輩は猫である。名前はまだ無い。", "JP"],
323
  ["桜の樹の下には屍体が埋まっている!これは信じていいことなんだよ。", "JP"],
324
  ["やったー!テストで満点取れたよ!私とっても嬉しいな!", "JP"],
325
- ["どうして私の意見を無視するの?許せない!ムカつく!あんたなんか死ねばいいのに。", "JP"],
 
 
 
326
  ["あはははっ!この漫画めっちゃ笑える、見てよこれ、ふふふ、あはは。", "JP"],
327
- ["あなたがいなくなって、私は一人になっちゃって、泣いちゃいそうなほど悲しい。", "JP"],
328
- ["深層学習の応用により、感情やアクセントを含む声質の微妙な変化も再現されている。", "JP"],
329
  [
330
- "Speech synthesis is the artificial production of human speech.",
331
- "EN",
 
 
 
 
332
  ],
333
- ["语音合成是人工制造人类语音。用于此目的的计算机系统称为语音合成器,可以通过软件或硬件产品实现。", "ZH"],
334
  ]
335
-
336
  initial_md = """
337
  # Style-Bert-VITS2 JVNVコーパスデモ
338
-
339
  怒り・悲しみ・喜び等の感情スタイルを強弱付きで制御できる、[Style-Bert-VITS2](https://github.com/litagin02/Style-Bert-VITS2)のデモです。
340
-
341
  入力上限文字数は100文字までにしています。
342
-
343
  このデモでは[jvnvのモデル](https://huggingface.co/litagin/style_bert_vits2_jvnv)を使っており、[JVNVコーパス(言語音声と非言語音声を持つ日本語感情音声コーパス)](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus)で学習されたモデルです。
344
  """
345
 
346
- style_md = """
347
  - プリセットまたは音声ファイルから読み上げの声音・感情・スタイルのようなものを制御できます。
348
- - デフォルトのNeutralでも、十分に読み上げる文に応じた感情で感情豊かに読み上げられます。このスタイル制御は、それを重み付きで上書きするような感じです。
349
  - 強さを大きくしすぎると発音が変になったり声にならなかったりと崩壊することがあります。
350
  - どのくらいに強さがいいかはモデルやスタイルによって異なるようです。
351
  - 音声ファイルを入力する場合は、学習データと似た声音の話者(特に同じ性別)でないとよい効果が出ないかもしれません。
@@ -371,7 +198,22 @@ if __name__ == "__main__":
371
  parser = argparse.ArgumentParser()
372
  parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
373
  parser.add_argument(
374
- "--dir", "-d", type=str, help="Model directory", default=config.out_dir
 
 
375
  )
376
  args = parser.parse_args()
377
  model_dir = args.dir
@@ -383,12 +225,11 @@ if __name__ == "__main__":
383
 
384
  model_holder = ModelHolder(model_dir, device)
385
 
386
- languages = ["JP", "EN", "ZH"]
387
- examples = example_hf_spaces if is_hf_spaces else example_local
388
-
389
  model_names = model_holder.model_names
390
  if len(model_names) == 0:
391
- logger.error(f"モデルが見つかりませんでした。{model_dir}にモデルを置いてください。")
 
 
392
  sys.exit(1)
393
  initial_id = 0
394
  initial_pth_files = model_holder.model_files_dict[model_names[initial_id]]
@@ -409,51 +250,87 @@ if __name__ == "__main__":
409
  choices=initial_pth_files,
410
  value=initial_pth_files[0],
411
  )
412
- refresh_button = gr.Button("更新", scale=1, visible=not is_hf_spaces)
413
  load_button = gr.Button("ロード", scale=1, variant="primary")
414
  text_input = gr.TextArea(label="テキスト", value=initial_text)
415
 
416
- line_split = gr.Checkbox(label="改行で分けて生成", value=True)
 
 
417
  split_interval = gr.Slider(
418
  minimum=0.0,
419
  maximum=2,
420
- value=0.5,
421
  step=0.1,
422
- label="分けた場合に挟む無音の長さ(秒)",
 
 
423
  )
424
- language = gr.Dropdown(choices=languages, value="JP", label="Language")
 
425
  with gr.Accordion(label="詳細設定", open=False):
426
  sdp_ratio = gr.Slider(
427
- minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
 
 
 
 
428
  )
429
  noise_scale = gr.Slider(
430
- minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
 
 
 
 
431
  )
432
  noise_scale_w = gr.Slider(
433
- minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise_W"
 
 
 
 
434
  )
435
  length_scale = gr.Slider(
436
- minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
 
 
437
  )
438
- use_style_text = gr.Checkbox(label="Style textを使う", value=False)
439
- style_text = gr.Textbox(
440
- label="Style text",
441
  placeholder="どうして私の意見を無視するの?許せない、ムカつく!死ねばいいのに。",
442
  info="このテキストの読み上げと似た声音・感情になりやすくなります。ただ抑揚やテンポ等が犠牲になる傾向があります。",
443
  visible=False,
444
  )
445
- style_text_weight = gr.Slider(
446
  minimum=0,
447
  maximum=1,
448
- value=0.7,
449
  step=0.1,
450
- label="Style textの強さ",
451
  visible=False,
452
  )
453
- use_style_text.change(
454
  lambda x: (gr.Textbox(visible=x), gr.Slider(visible=x)),
455
- inputs=[use_style_text],
456
- outputs=[style_text, style_text_weight],
457
  )
458
  with gr.Column():
459
  with gr.Accordion("スタイルについて詳細", open=False):
@@ -464,25 +341,29 @@ if __name__ == "__main__":
464
  value="プリセットから選ぶ",
465
  )
466
  style = gr.Dropdown(
467
- label="スタイル(Neutralが平均スタイル)",
468
  choices=["モデルをロードしてください"],
469
  value="モデルをロードしてください",
470
  )
471
  style_weight = gr.Slider(
472
  minimum=0,
473
  maximum=50,
474
- value=5,
475
  step=0.1,
476
  label="スタイルの強さ",
477
  )
478
- ref_audio_path = gr.Audio(label="参照音声", type="filepath", visible=False)
 
 
479
  tts_button = gr.Button(
480
- "音声合成(モデルをロードしてください)", variant="primary", interactive=False
 
 
481
  )
482
  text_output = gr.Textbox(label="情報")
483
  audio_output = gr.Audio(label="結果")
484
- with gr.Accordion("テキスト例", open=False):
485
- gr.Examples(examples, inputs=[text_input, language])
486
 
487
  tts_button.click(
488
  tts_fn,
@@ -496,17 +377,20 @@ if __name__ == "__main__":
496
  length_scale,
497
  line_split,
498
  split_interval,
499
- style_text,
500
- style_text_weight,
501
- use_style_text,
502
  style,
503
  style_weight,
 
 
 
504
  ],
505
- outputs=[text_output, audio_output],
506
  )
507
 
508
  model_name.change(
509
- model_holder.update_model_files_dropdown,
510
  inputs=[model_name],
511
  outputs=[model_path],
512
  )
@@ -514,14 +398,14 @@ if __name__ == "__main__":
514
  model_path.change(make_non_interactive, outputs=[tts_button])
515
 
516
  refresh_button.click(
517
- model_holder.update_model_names_dropdown,
518
  outputs=[model_name, model_path, tts_button],
519
  )
520
 
521
  load_button.click(
522
- model_holder.load_model,
523
  inputs=[model_name, model_path],
524
- outputs=[style, tts_button],
525
  )
526
 
527
  style_mode.change(
@@ -530,4 +414,6 @@ if __name__ == "__main__":
530
  outputs=[style, ref_audio_path],
531
  )
532
 
533
- app.launch(inbrowser=True)
 
 
 
1
  import argparse
2
  import datetime
3
+ import json
4
  import os
5
  import sys
6
+ from typing import Optional
7
 
8
  import gradio as gr
 
9
  import torch
10
+ import yaml
11
+
12
+ from common.constants import (
13
+ DEFAULT_ASSIST_TEXT_WEIGHT,
14
+ DEFAULT_LENGTH,
15
+ DEFAULT_LINE_SPLIT,
16
+ DEFAULT_NOISE,
17
+ DEFAULT_NOISEW,
18
+ DEFAULT_SDP_RATIO,
19
+ DEFAULT_SPLIT_INTERVAL,
20
+ DEFAULT_STYLE,
21
+ DEFAULT_STYLE_WEIGHT,
22
+ Languages,
23
+ )
24
+ from common.log import logger
25
+ from common.tts_model import ModelHolder
26
+ from infer import InvalidToneError
27
+ from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize
28
 
29
  is_hf_spaces = os.getenv("SYSTEM") == "spaces"
30
  limit = 100
31
 
32
+ # Get path settings
33
+ with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f:
34
+ path_config: dict[str, str] = yaml.safe_load(f.read())
35
+ # dataset_root = path_config["dataset_root"]
36
+ assets_root = path_config["assets_root"]
37
 
38
+ languages = [l.value for l in Languages]
 
 
39
 
40
 
41
  def tts_fn(
 
48
  length_scale,
49
  line_split,
50
  split_interval,
51
+ assist_text,
52
+ assist_text_weight,
53
+ use_assist_text,
54
+ style,
55
  style_weight,
56
+ kata_tone_json_str,
57
+ use_tone,
58
+ speaker,
59
  ):
 
 
 
 
 
 
60
  if is_hf_spaces and len(text) > limit:
61
+ logger.error(f"Text is too long: {len(text)}")
62
+ return (
63
+ f"Error: 文字数が多すぎます({limit}文字以下にしてください)",
64
+ None,
65
+ kata_tone_json_str,
66
+ )
67
  assert model_holder.current_model is not None
68
 
69
+ wrong_tone_message = ""
70
+ kata_tone: Optional[list[tuple[str, int]]] = None
71
+ if use_tone and kata_tone_json_str != "":
72
+ if language != "JP":
73
+ logger.warning("Only Japanese is supported for tone generation.")
74
+ wrong_tone_message = "アクセント指定は現在日本語のみ対応しています。"
75
+ if line_split:
76
+ logger.warning("Tone generation is not supported for line split.")
77
+ wrong_tone_message = (
78
+ "アクセント指定は改行で分けて生成を使わない場合のみ対応しています。"
79
+ )
80
+ try:
81
+ kata_tone = []
82
+ json_data = json.loads(kata_tone_json_str)
83
+ # tupleを使うように変換
84
+ for kana, tone in json_data:
85
+ assert isinstance(kana, str) and tone in (0, 1), f"{kana}, {tone}"
86
+ kata_tone.append((kana, tone))
87
+ except Exception as e:
88
+ logger.warning(f"Error occurred when parsing kana_tone_json: {e}")
89
+ wrong_tone_message = f"アクセント指定が不正です: {e}"
90
+ kata_tone = None
91
+
92
+ # toneは実際に音声合成に代入される際のみnot Noneになる
93
+ tone: Optional[list[int]] = None
94
+ if kata_tone is not None:
95
+ phone_tone = kata_tone2phone_tone(kata_tone)
96
+ tone = [t for _, t in phone_tone]
97
+
98
+ speaker_id = model_holder.current_model.spk2id[speaker]
99
+
100
  start_time = datetime.datetime.now()
101
 
102
+ try:
103
+ sr, audio = model_holder.current_model.infer(
104
+ text=text,
105
+ language=language,
106
+ reference_audio_path=reference_audio_path,
107
+ sdp_ratio=sdp_ratio,
108
+ noise=noise_scale,
109
+ noisew=noise_scale_w,
110
+ length=length_scale,
111
+ line_split=line_split,
112
+ split_interval=split_interval,
113
+ assist_text=assist_text,
114
+ assist_text_weight=assist_text_weight,
115
+ use_assist_text=use_assist_text,
116
+ style=style,
117
+ style_weight=style_weight,
118
+ given_tone=tone,
119
+ sid=speaker_id,
120
+ )
121
+ except InvalidToneError as e:
122
+ logger.error(f"Tone error: {e}")
123
+ return f"Error: アクセント指定が不正です:\n{e}", None, kata_tone_json_str
124
+ except ValueError as e:
125
+ logger.error(f"Value error: {e}")
126
+ return f"Error: {e}", None, kata_tone_json_str
127
 
128
  end_time = datetime.datetime.now()
129
  duration = (end_time - start_time).total_seconds()
 
 
130
 
131
+ if tone is None and language == "JP":
132
+ # アクセント指定に使えるようにアクセント情報を返す
133
+ norm_text = text_normalize(text)
134
+ kata_tone = g2kata_tone(norm_text)
135
+ kata_tone_json_str = json.dumps(kata_tone, ensure_ascii=False)
136
+ elif tone is None:
137
+ kata_tone_json_str = ""
138
+ message = f"Success, time: {duration} seconds."
139
+ if wrong_tone_message != "":
140
+ message = wrong_tone_message + "\n" + message
141
+ return message, (sr, audio), kata_tone_json_str
142
 
 
143
 
144
+ initial_text = "こんにちは、初めまして。あなたの名前はなんていうの?"
 
 
145
 
146
  example_hf_spaces = [
147
  [initial_text, "JP"],
 
149
  ["吾輩は猫である。名前はまだ無い。", "JP"],
150
  ["桜の樹の下には屍体が埋まっている!これは信じていいことなんだよ。", "JP"],
151
  ["やったー!テストで満点取れたよ!私とっても嬉しいな!", "JP"],
152
+ [
153
+ "どうして私の意見を無視するの?許せない!ムカつく!あんたなんか死ねばいいのに。",
154
+ "JP",
155
+ ],
156
  ["あはははっ!この漫画めっちゃ笑える、見てよこれ、ふふふ、あはは。", "JP"],
 
 
157
  [
158
+ "あなたがいなくなって、私は一人になっちゃって、泣いちゃいそうなほど悲しい。",
159
+ "JP",
160
+ ],
161
+ [
162
+ "深層学習の応用により、感情やアクセントを含む声質の微妙な変化も再現されている。",
163
+ "JP",
164
  ],
 
165
  ]
 
166
  initial_md = """
167
  # Style-Bert-VITS2 JVNVコーパスデモ
 
168
  怒り・悲しみ・喜び等の感情スタイルを強弱付きで制御できる、[Style-Bert-VITS2](https://github.com/litagin02/Style-Bert-VITS2)のデモです。
 
169
  入力上限文字数は100文字までにしています。
 
170
  このデモでは[jvnvのモデル](https://huggingface.co/litagin/style_bert_vits2_jvnv)を使っており、[JVNVコーパス(言語音声と非言語音声を持つ日本語感情音声コーパス)](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus)で学習されたモデルです。
171
  """
172
 
173
+ style_md = f"""
174
  - プリセットまたは音声ファイルから読み上げの声音・感情・スタイルのようなものを制御できます。
175
+ - デフォルトの{DEFAULT_STYLE}でも、十分に読み上げる文に応じた感情で感情豊かに読み上げられます。このスタイル制御は、それを重み付きで上書きするような感じです。
176
  - 強さを大きくしすぎると発音が変になったり声にならなかったりと崩壊することがあります。
177
  - どのくらいに強さがいいかはモデルやスタイルによって異なるようです。
178
  - 音声ファイルを入力する場合は、学習データと似た声音の話者(特に同じ性別)でないとよい効果が出ないかもしれません。
 
198
  parser = argparse.ArgumentParser()
199
  parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
200
  parser.add_argument(
201
+ "--dir", "-d", type=str, help="Model directory", default=assets_root
202
+ )
203
+ parser.add_argument(
204
+ "--share", action="store_true", help="Share this app publicly", default=False
205
+ )
206
+ parser.add_argument(
207
+ "--server-name",
208
+ type=str,
209
+ default=None,
210
+ help="Server name for Gradio app",
211
+ )
212
+ parser.add_argument(
213
+ "--no-autolaunch",
214
+ action="store_true",
215
+ default=False,
216
+ help="Do not launch app automatically",
217
  )
218
  args = parser.parse_args()
219
  model_dir = args.dir
 
225
 
226
  model_holder = ModelHolder(model_dir, device)
227
 
 
 
 
228
  model_names = model_holder.model_names
229
  if len(model_names) == 0:
230
+ logger.error(
231
+ f"モデルが見つかりませんでした。{model_dir}にモデルを置いてください。"
232
+ )
233
  sys.exit(1)
234
  initial_id = 0
235
  initial_pth_files = model_holder.model_files_dict[model_names[initial_id]]
 
250
  choices=initial_pth_files,
251
  value=initial_pth_files[0],
252
  )
253
+ refresh_button = gr.Button("更新", scale=1, visible=True)
254
  load_button = gr.Button("ロード", scale=1, variant="primary")
255
  text_input = gr.TextArea(label="テキスト", value=initial_text)
256
 
257
+ line_split = gr.Checkbox(
258
+ label="改行で分けて生成", value=DEFAULT_LINE_SPLIT
259
+ )
260
  split_interval = gr.Slider(
261
  minimum=0.0,
262
  maximum=2,
263
+ value=DEFAULT_SPLIT_INTERVAL,
264
  step=0.1,
265
+ label="改行ごとに挟む無音の長さ(秒)",
266
+ )
267
+ line_split.change(
268
+ lambda x: (gr.Slider(visible=x)),
269
+ inputs=[line_split],
270
+ outputs=[split_interval],
271
+ )
272
+ tone = gr.Textbox(
273
+ label="アクセント調整(数値は 0=低 か1=高 のみ)",
274
+ info="改行で分けない場合のみ使えます。万能ではありません。",
275
+ )
276
+ use_tone = gr.Checkbox(label="アクセント調整を使う", value=False)
277
+ use_tone.change(
278
+ lambda x: (gr.Checkbox(value=False) if x else gr.Checkbox()),
279
+ inputs=[use_tone],
280
+ outputs=[line_split],
281
  )
282
+ language = gr.Dropdown(choices=["JP"], value="JP", label="Language")
283
+ speaker = gr.Dropdown(label="話者")
284
  with gr.Accordion(label="詳細設定", open=False):
285
  sdp_ratio = gr.Slider(
286
+ minimum=0,
287
+ maximum=1,
288
+ value=DEFAULT_SDP_RATIO,
289
+ step=0.1,
290
+ label="SDP Ratio",
291
  )
292
  noise_scale = gr.Slider(
293
+ minimum=0.1,
294
+ maximum=2,
295
+ value=DEFAULT_NOISE,
296
+ step=0.1,
297
+ label="Noise",
298
  )
299
  noise_scale_w = gr.Slider(
300
+ minimum=0.1,
301
+ maximum=2,
302
+ value=DEFAULT_NOISEW,
303
+ step=0.1,
304
+ label="Noise_W",
305
  )
306
  length_scale = gr.Slider(
307
+ minimum=0.1,
308
+ maximum=2,
309
+ value=DEFAULT_LENGTH,
310
+ step=0.1,
311
+ label="Length",
312
+ )
313
+ use_assist_text = gr.Checkbox(
314
+ label="Assist textを使う", value=False
315
  )
316
+ assist_text = gr.Textbox(
317
+ label="Assist text",
 
318
  placeholder="どうして私の意見を無視するの?許せない、ムカつく!死ねばいいのに。",
319
  info="このテキストの読み上げと似た声音・感情になりやすくなります。ただ抑揚やテンポ等が犠牲になる傾向があります。",
320
  visible=False,
321
  )
322
+ assist_text_weight = gr.Slider(
323
  minimum=0,
324
  maximum=1,
325
+ value=DEFAULT_ASSIST_TEXT_WEIGHT,
326
  step=0.1,
327
+ label="Assist textの強さ",
328
  visible=False,
329
  )
330
+ use_assist_text.change(
331
  lambda x: (gr.Textbox(visible=x), gr.Slider(visible=x)),
332
+ inputs=[use_assist_text],
333
+ outputs=[assist_text, assist_text_weight],
334
  )
335
  with gr.Column():
336
  with gr.Accordion("スタイルについて詳細", open=False):
 
341
  value="プリセットから選ぶ",
342
  )
343
  style = gr.Dropdown(
344
+ label=f"スタイル({DEFAULT_STYLE}が平均スタイル)",
345
  choices=["モデルをロードしてください"],
346
  value="モデルをロードしてください",
347
  )
348
  style_weight = gr.Slider(
349
  minimum=0,
350
  maximum=50,
351
+ value=DEFAULT_STYLE_WEIGHT,
352
  step=0.1,
353
  label="スタイルの強さ",
354
  )
355
+ ref_audio_path = gr.Audio(
356
+ label="参照音声", type="filepath", visible=False
357
+ )
358
  tts_button = gr.Button(
359
+ "音声合成(モデルをロードしてください)",
360
+ variant="primary",
361
+ interactive=False,
362
  )
363
  text_output = gr.Textbox(label="情報")
364
  audio_output = gr.Audio(label="結果")
365
+ with gr.Accordion("テキスト例", open=True):
366
+ gr.Examples(example_hf_spaces, inputs=[text_input, language])
367
 
368
  tts_button.click(
369
  tts_fn,
 
377
  length_scale,
378
  line_split,
379
  split_interval,
380
+ assist_text,
381
+ assist_text_weight,
382
+ use_assist_text,
383
  style,
384
  style_weight,
385
+ tone,
386
+ use_tone,
387
+ speaker,
388
  ],
389
+ outputs=[text_output, audio_output, tone],
390
  )
391
 
392
  model_name.change(
393
+ model_holder.update_model_files_gr,
394
  inputs=[model_name],
395
  outputs=[model_path],
396
  )
 
398
  model_path.change(make_non_interactive, outputs=[tts_button])
399
 
400
  refresh_button.click(
401
+ model_holder.update_model_names_gr,
402
  outputs=[model_name, model_path, tts_button],
403
  )
404
 
405
  load_button.click(
406
+ model_holder.load_model_gr,
407
  inputs=[model_name, model_path],
408
+ outputs=[style, tts_button, speaker],
409
  )
410
 
411
  style_mode.change(
 
414
  outputs=[style, ref_audio_path],
415
  )
416
 
417
+ app.launch(
418
+ inbrowser=not args.no_autolaunch, share=args.share, server_name=args.server_name
419
+ )
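For reference, a minimal sketch of the accent ("tone") round-trip that the new tts_fn performs, using the text/japanese.py helpers imported at the top of app.py; the sample text and the exact pair values are illustrative.

```python
# Sketch of the accent data flow in the new app.py (illustrative values).
import json

from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize

# After a JP synthesis without an explicit accent, tts_fn returns this JSON so the
# user can edit it in the "アクセント調整" textbox.
norm_text = text_normalize("こんにちは")
kata_tone = g2kata_tone(norm_text)  # list of (kana, tone) pairs, tone is 0 (low) or 1 (high)
kata_tone_json_str = json.dumps(kata_tone, ensure_ascii=False)

# When "アクセント調整を使う" is checked, tts_fn parses the JSON back, validates that
# each tone is 0 or 1, and passes only the per-phone tones to Model.infer(given_tone=...).
parsed = [(kana, tone) for kana, tone in json.loads(kata_tone_json_str)]
phone_tone = kata_tone2phone_tone(parsed)
tone = [t for _, t in phone_tone]
```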
attentions.py CHANGED
@@ -4,7 +4,7 @@ from torch import nn
4
  from torch.nn import functional as F
5
 
6
  import commons
7
- from tools.log import logger as logging
8
 
9
 
10
  class LayerNorm(nn.Module):
 
4
  from torch.nn import functional as F
5
 
6
  import commons
7
+ from common.log import logger as logging
8
 
9
 
10
  class LayerNorm(nn.Module):
common/constants.py ADDED
@@ -0,0 +1,20 @@
 
 
1
+ import enum
2
+
3
+ DEFAULT_STYLE: str = "Neutral"
4
+ DEFAULT_STYLE_WEIGHT: float = 5.0
5
+
6
+
7
+ class Languages(str, enum.Enum):
8
+ JP = "JP"
9
+ EN = "EN"
10
+ ZH = "ZH"
11
+
12
+
13
+ DEFAULT_SDP_RATIO: float = 0.2
14
+ DEFAULT_NOISE: float = 0.6
15
+ DEFAULT_NOISEW: float = 0.8
16
+ DEFAULT_LENGTH: float = 1.0
17
+ DEFAULT_LINE_SPLIT: bool = True
18
+ DEFAULT_SPLIT_INTERVAL: float = 0.5
19
+ DEFAULT_ASSIST_TEXT_WEIGHT: float = 0.7
20
+ DEFAULT_ASSIST_TEXT_WEIGHT: float = 1.0  # NOTE: duplicate name; this assignment overrides the 0.7 value above
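These defaults and the Languages enum are what the rest of this commit imports for its UI defaults; a minimal usage sketch, mirroring what app.py does:

```python
# Sketch: how the new constants module is consumed elsewhere in this commit.
from common.constants import DEFAULT_SDP_RATIO, DEFAULT_STYLE, Languages

languages = [l.value for l in Languages]  # ["JP", "EN", "ZH"], as built in app.py
print(languages, DEFAULT_STYLE, DEFAULT_SDP_RATIO)  # values used as Gradio defaults
```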
common/log.py ADDED
@@ -0,0 +1,16 @@
 
 
1
+ """
2
+ logger封装
3
+ """
4
+ from loguru import logger
5
+
6
+ from .stdout_wrapper import SAFE_STDOUT
7
+
8
+ # 移除所有默认的处理器
9
+ logger.remove()
10
+
11
+ # 自定义格式并添加到标准输出
12
+ log_format = (
13
+ "<g>{time:MM-DD HH:mm:ss}</g> |<lvl>{level:^8}</lvl>| {file}:{line} | {message}"
14
+ )
15
+
16
+ logger.add(SAFE_STDOUT, format=log_format, backtrace=True, diagnose=True)
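This pre-configured loguru logger is what the other files in this diff (attentions.py, config.py, common/tts_model.py) now import instead of tools.log; a one-line usage sketch:

```python
# Sketch: modules import the shared, pre-configured logger.
from common.log import logger

logger.info("goes to SAFE_STDOUT with the format defined in common/log.py")
```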
common/stdout_wrapper.py ADDED
@@ -0,0 +1,34 @@
 
 
1
+ import sys
2
+ import tempfile
3
+
4
+
5
+ class StdoutWrapper:
6
+ def __init__(self):
7
+ self.temp_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
8
+ self.original_stdout = sys.stdout
9
+
10
+ def write(self, message: str):
11
+ self.temp_file.write(message)
12
+ self.temp_file.flush()
13
+ print(message, end="", file=self.original_stdout)
14
+
15
+ def flush(self):
16
+ self.temp_file.flush()
17
+
18
+ def read(self):
19
+ self.temp_file.seek(0)
20
+ return self.temp_file.read()
21
+
22
+ def close(self):
23
+ self.temp_file.close()
24
+
25
+ def fileno(self):
26
+ return self.temp_file.fileno()
27
+
28
+
29
+ try:
30
+ import google.colab
31
+
32
+ SAFE_STDOUT = StdoutWrapper()
33
+ except ImportError:
34
+ SAFE_STDOUT = sys.stdout
common/subprocess_utils.py ADDED
@@ -0,0 +1,32 @@
 
 
1
+ import subprocess
2
+ import sys
3
+
4
+ from .log import logger
5
+ from .stdout_wrapper import SAFE_STDOUT
6
+
7
+ python = sys.executable
8
+
9
+
10
+ def run_script_with_log(cmd: list[str], ignore_warning=False) -> tuple[bool, str]:
11
+ logger.info(f"Running: {' '.join(cmd)}")
12
+ result = subprocess.run(
13
+ [python] + cmd,
14
+ stdout=SAFE_STDOUT, # type: ignore
15
+ stderr=subprocess.PIPE,
16
+ text=True,
17
+ )
18
+ if result.returncode != 0:
19
+ logger.error(f"Error: {' '.join(cmd)}\n{result.stderr}")
20
+ return False, result.stderr
21
+ elif result.stderr and not ignore_warning:
22
+ logger.warning(f"Warning: {' '.join(cmd)}\n{result.stderr}")
23
+ return True, result.stderr
24
+ logger.success(f"Success: {' '.join(cmd)}")
25
+ return True, ""
26
+
27
+
28
+ def second_elem_of(original_function):
29
+ def inner_function(*args, **kwargs):
30
+ return original_function(*args, **kwargs)[1]
31
+
32
+ return inner_function
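A hedged usage sketch of run_script_with_log; the script name and arguments below are hypothetical, the helper only prepends the current Python executable and returns (success, captured stderr):

```python
# Hypothetical invocation of the new helper (script name/args are placeholders).
from common.subprocess_utils import run_script_with_log

ok, stderr_text = run_script_with_log(["preprocess_text.py", "--config", "config.yml"])
if not ok:
    print("script failed:", stderr_text)
```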
common/tts_model.py ADDED
@@ -0,0 +1,238 @@
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import torch
4
+ import os
5
+ import warnings
6
+ from gradio.processing_utils import convert_to_16_bit_wav
7
+ from typing import Dict, List, Optional, Union
8
+
9
+ import utils
10
+ from infer import get_net_g, infer
11
+ from models import SynthesizerTrn
12
+ from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra
13
+
14
+ from .log import logger
15
+ from .constants import (
16
+ DEFAULT_ASSIST_TEXT_WEIGHT,
17
+ DEFAULT_LENGTH,
18
+ DEFAULT_LINE_SPLIT,
19
+ DEFAULT_NOISE,
20
+ DEFAULT_NOISEW,
21
+ DEFAULT_SDP_RATIO,
22
+ DEFAULT_SPLIT_INTERVAL,
23
+ DEFAULT_STYLE,
24
+ DEFAULT_STYLE_WEIGHT,
25
+ )
26
+
27
+
28
+ class Model:
29
+ def __init__(
30
+ self, model_path: str, config_path: str, style_vec_path: str, device: str
31
+ ):
32
+ self.model_path: str = model_path
33
+ self.config_path: str = config_path
34
+ self.device: str = device
35
+ self.style_vec_path: str = style_vec_path
36
+ self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path)
37
+ self.spk2id: Dict[str, int] = self.hps.data.spk2id
38
+ self.id2spk: Dict[int, str] = {v: k for k, v in self.spk2id.items()}
39
+
40
+ self.num_styles: int = self.hps.data.num_styles
41
+ if hasattr(self.hps.data, "style2id"):
42
+ self.style2id: Dict[str, int] = self.hps.data.style2id
43
+ else:
44
+ self.style2id: Dict[str, int] = {str(i): i for i in range(self.num_styles)}
45
+ if len(self.style2id) != self.num_styles:
46
+ raise ValueError(
47
+ f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})"
48
+ )
49
+
50
+ self.style_vectors: np.ndarray = np.load(self.style_vec_path)
51
+ if self.style_vectors.shape[0] != self.num_styles:
52
+ raise ValueError(
53
+ f"The number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})"
54
+ )
55
+
56
+ self.net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None
57
+
58
+ def load_net_g(self):
59
+ self.net_g = get_net_g(
60
+ model_path=self.model_path,
61
+ version=self.hps.version,
62
+ device=self.device,
63
+ hps=self.hps,
64
+ )
65
+
66
+ def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray:
67
+ mean = self.style_vectors[0]
68
+ style_vec = self.style_vectors[style_id]
69
+ style_vec = mean + (style_vec - mean) * weight
70
+ return style_vec
71
+
72
+ def get_style_vector_from_audio(
73
+ self, audio_path: str, weight: float = 1.0
74
+ ) -> np.ndarray:
75
+ from style_gen import get_style_vector
76
+
77
+ xvec = get_style_vector(audio_path)
78
+ mean = self.style_vectors[0]
79
+ xvec = mean + (xvec - mean) * weight
80
+ return xvec
81
+
82
+ def infer(
83
+ self,
84
+ text: str,
85
+ language: str = "JP",
86
+ sid: int = 0,
87
+ reference_audio_path: Optional[str] = None,
88
+ sdp_ratio: float = DEFAULT_SDP_RATIO,
89
+ noise: float = DEFAULT_NOISE,
90
+ noisew: float = DEFAULT_NOISEW,
91
+ length: float = DEFAULT_LENGTH,
92
+ line_split: bool = DEFAULT_LINE_SPLIT,
93
+ split_interval: float = DEFAULT_SPLIT_INTERVAL,
94
+ assist_text: Optional[str] = None,
95
+ assist_text_weight: float = DEFAULT_ASSIST_TEXT_WEIGHT,
96
+ use_assist_text: bool = False,
97
+ style: str = DEFAULT_STYLE,
98
+ style_weight: float = DEFAULT_STYLE_WEIGHT,
99
+ given_tone: Optional[list[int]] = None,
100
+ ) -> tuple[int, np.ndarray]:
101
+ logger.info(f"Start generating audio data from text:\n{text}")
102
+ if language != "JP" and self.hps.version.endswith("JP-Extra"):
103
+ raise ValueError(
104
+ "The model is trained with JP-Extra, but the language is not JP"
105
+ )
106
+ if reference_audio_path == "":
107
+ reference_audio_path = None
108
+ if assist_text == "" or not use_assist_text:
109
+ assist_text = None
110
+
111
+ if self.net_g is None:
112
+ self.load_net_g()
113
+ if reference_audio_path is None:
114
+ style_id = self.style2id[style]
115
+ style_vector = self.get_style_vector(style_id, style_weight)
116
+ else:
117
+ style_vector = self.get_style_vector_from_audio(
118
+ reference_audio_path, style_weight
119
+ )
120
+ if not line_split:
121
+ with torch.no_grad():
122
+ audio = infer(
123
+ text=text,
124
+ sdp_ratio=sdp_ratio,
125
+ noise_scale=noise,
126
+ noise_scale_w=noisew,
127
+ length_scale=length,
128
+ sid=sid,
129
+ language=language,
130
+ hps=self.hps,
131
+ net_g=self.net_g,
132
+ device=self.device,
133
+ assist_text=assist_text,
134
+ assist_text_weight=assist_text_weight,
135
+ style_vec=style_vector,
136
+ given_tone=given_tone,
137
+ )
138
+ else:
139
+ texts = text.split("\n")
140
+ texts = [t for t in texts if t != ""]
141
+ audios = []
142
+ with torch.no_grad():
143
+ for i, t in enumerate(texts):
144
+ audios.append(
145
+ infer(
146
+ text=t,
147
+ sdp_ratio=sdp_ratio,
148
+ noise_scale=noise,
149
+ noise_scale_w=noisew,
150
+ length_scale=length,
151
+ sid=sid,
152
+ language=language,
153
+ hps=self.hps,
154
+ net_g=self.net_g,
155
+ device=self.device,
156
+ assist_text=assist_text,
157
+ assist_text_weight=assist_text_weight,
158
+ style_vec=style_vector,
159
+ )
160
+ )
161
+ if i != len(texts) - 1:
162
+ audios.append(np.zeros(int(44100 * split_interval)))
163
+ audio = np.concatenate(audios)
164
+ with warnings.catch_warnings():
165
+ warnings.simplefilter("ignore")
166
+ audio = convert_to_16_bit_wav(audio)
167
+ logger.info("Audio data generated successfully")
168
+ return (self.hps.data.sampling_rate, audio)
169
+
170
+
171
+ class ModelHolder:
172
+ def __init__(self, root_dir: str, device: str):
173
+ self.root_dir: str = root_dir
174
+ self.device: str = device
175
+ self.model_files_dict: Dict[str, List[str]] = {}
176
+ self.current_model: Optional[Model] = None
177
+ self.model_names: List[str] = []
178
+ self.models: List[Model] = []
179
+ self.refresh()
180
+
181
+ def refresh(self):
182
+ self.model_files_dict = {}
183
+ self.model_names = []
184
+ self.current_model = None
185
+ model_dirs = [
186
+ d
187
+ for d in os.listdir(self.root_dir)
188
+ if os.path.isdir(os.path.join(self.root_dir, d))
189
+ ]
190
+ for model_name in model_dirs:
191
+ model_dir = os.path.join(self.root_dir, model_name)
192
+ model_files = [
193
+ os.path.join(model_dir, f)
194
+ for f in os.listdir(model_dir)
195
+ if f.endswith(".pth") or f.endswith(".pt") or f.endswith(".safetensors")
196
+ ]
197
+ if len(model_files) == 0:
198
+ logger.warning(
199
+ f"No model files found in {self.root_dir}/{model_name}, so skip it"
200
+ )
201
+ continue
202
+ self.model_files_dict[model_name] = model_files
203
+ self.model_names.append(model_name)
204
+
205
+ def load_model_gr(
206
+ self, model_name: str, model_path: str
207
+ ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]:
208
+ if model_name not in self.model_files_dict:
209
+ raise ValueError(f"Model `{model_name}` is not found")
210
+ if model_path not in self.model_files_dict[model_name]:
211
+ raise ValueError(f"Model file `{model_path}` is not found")
212
+ self.current_model = Model(
213
+ model_path=model_path,
214
+ config_path=os.path.join(self.root_dir, model_name, "config.json"),
215
+ style_vec_path=os.path.join(self.root_dir, model_name, "style_vectors.npy"),
216
+ device=self.device,
217
+ )
218
+ speakers = list(self.current_model.spk2id.keys())
219
+ styles = list(self.current_model.style2id.keys())
220
+ return (
221
+ gr.Dropdown(choices=styles, value=styles[0]),
222
+ gr.Button(interactive=True, value="音声合成"),
223
+ gr.Dropdown(choices=speakers, value=speakers[0]),
224
+ )
225
+
226
+ def update_model_files_gr(self, model_name: str) -> gr.Dropdown:
227
+ model_files = self.model_files_dict[model_name]
228
+ return gr.Dropdown(choices=model_files, value=model_files[0])
229
+
230
+ def update_model_names_gr(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]:
231
+ self.refresh()
232
+ initial_model_name = self.model_names[0]
233
+ initial_model_files = self.model_files_dict[initial_model_name]
234
+ return (
235
+ gr.Dropdown(choices=self.model_names, value=initial_model_name),
236
+ gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]),
237
+ gr.Button(interactive=False), # For tts_button
238
+ )
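For orientation, a minimal sketch of driving the new Model class outside Gradio, pointed at one of the jvnv model assets uploaded in this commit; the CPU device and the use of soundfile for writing the result are assumptions, not part of this diff.

```python
# Sketch (assumptions: device="cpu", soundfile available for writing the output).
import soundfile as sf

from common.tts_model import Model

model = Model(
    model_path="model_assets/jvnv-F1-jp/jvnv-F1-jp_e182_s16000.safetensors",
    config_path="model_assets/jvnv-F1-jp/config.json",
    style_vec_path="model_assets/jvnv-F1-jp/style_vectors.npy",
    device="cpu",
)
# infer() lazily loads net_g on first call and returns (sampling_rate, int16 waveform).
sr, audio = model.infer(text="こんにちは、初めまして。")
sf.write("output.wav", audio, sr)
```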
config.py CHANGED
@@ -1,254 +1,269 @@
1
- """
2
- @Desc: 全局配置文件读取
3
- """
4
- import argparse
5
- import yaml
6
- from typing import Dict, List
7
- import os
8
- import shutil
9
- import sys
10
-
11
-
12
- class Resample_config:
13
- """重采样配置"""
14
-
15
- def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100):
16
- self.sampling_rate: int = sampling_rate # 目标采样率
17
- self.in_dir: str = in_dir # 待处理音频目录路径
18
- self.out_dir: str = out_dir # 重采样输出路径
19
-
20
- @classmethod
21
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
22
- """从字典中生成实例"""
23
-
24
- # 不检查路径是否有效,此逻辑在resample.py中处理
25
- data["in_dir"] = os.path.join(dataset_path, data["in_dir"])
26
- data["out_dir"] = os.path.join(dataset_path, data["out_dir"])
27
-
28
- return cls(**data)
29
-
30
-
31
- class Preprocess_text_config:
32
- """数据预处理配置"""
33
-
34
- def __init__(
35
- self,
36
- transcription_path: str,
37
- cleaned_path: str,
38
- train_path: str,
39
- val_path: str,
40
- config_path: str,
41
- val_per_lang: int = 5,
42
- max_val_total: int = 10000,
43
- clean: bool = True,
44
- ):
45
- self.transcription_path: str = transcription_path # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
46
- self.cleaned_path: str = cleaned_path # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
47
- self.train_path: str = train_path # 训练集路径,可以不填。不填则将在原始文本目录生成
48
- self.val_path: str = val_path # 验证集路径,可以不填。不填则将在原始文本目录生成
49
- self.config_path: str = config_path # 配置文件路径
50
- self.val_per_lang: int = val_per_lang # 每个speaker的验证集条数
51
- self.max_val_total: int = max_val_total # 验证集最大条数,多于的会被截断并放到训练集中
52
- self.clean: bool = clean # 是否进行数据清洗
53
-
54
- @classmethod
55
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
56
- """从字典中生成实例"""
57
-
58
- data["transcription_path"] = os.path.join(
59
- dataset_path, data["transcription_path"]
60
- )
61
- if data["cleaned_path"] == "" or data["cleaned_path"] is None:
62
- data["cleaned_path"] = None
63
- else:
64
- data["cleaned_path"] = os.path.join(dataset_path, data["cleaned_path"])
65
- data["train_path"] = os.path.join(dataset_path, data["train_path"])
66
- data["val_path"] = os.path.join(dataset_path, data["val_path"])
67
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
68
-
69
- return cls(**data)
70
-
71
-
72
- class Bert_gen_config:
73
- """bert_gen 配置"""
74
-
75
- def __init__(
76
- self,
77
- config_path: str,
78
- num_processes: int = 2,
79
- device: str = "cuda",
80
- use_multi_device: bool = False,
81
- ):
82
- self.config_path = config_path
83
- self.num_processes = num_processes
84
- self.device = device
85
- self.use_multi_device = use_multi_device
86
-
87
- @classmethod
88
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
89
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
90
-
91
- return cls(**data)
92
-
93
-
94
- class Style_gen_config:
95
- """style_gen 配置"""
96
-
97
- def __init__(
98
- self,
99
- config_path: str,
100
- num_processes: int = 2,
101
- device: str = "cuda",
102
- ):
103
- self.config_path = config_path
104
- self.num_processes = num_processes
105
- self.device = device
106
-
107
- @classmethod
108
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
109
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
110
-
111
- return cls(**data)
112
-
113
-
114
- class Train_ms_config:
115
- """训练配置"""
116
-
117
- def __init__(
118
- self,
119
- config_path: str,
120
- env: Dict[str, any],
121
- # base: Dict[str, any],
122
- model: str,
123
- num_workers: int,
124
- spec_cache: bool,
125
- keep_ckpts: int,
126
- ):
127
- self.env = env # 需要加载的环境变量
128
- # self.base = base # 底模配置
129
- self.model = model # 训练模型存储目录,该路径为相对于dataset_path的路径,而非项目根目录
130
- self.config_path = config_path # 配置文件路径
131
- self.num_workers = num_workers # worker数量
132
- self.spec_cache = spec_cache # 是否启用spec缓存
133
- self.keep_ckpts = keep_ckpts # ckpt数量
134
-
135
- @classmethod
136
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
137
- # data["model"] = os.path.join(dataset_path, data["model"])
138
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
139
-
140
- return cls(**data)
141
-
142
-
143
- class Webui_config:
144
- """webui 配置"""
145
-
146
- def __init__(
147
- self,
148
- device: str,
149
- model: str,
150
- config_path: str,
151
- language_identification_library: str,
152
- port: int = 7860,
153
- share: bool = False,
154
- debug: bool = False,
155
- ):
156
- self.device: str = device
157
- self.model: str = model # 端口号
158
- self.config_path: str = config_path # 是否公开部署,对外网开放
159
- self.port: int = port # 是否开启debug模式
160
- self.share: bool = share # 模型路径
161
- self.debug: bool = debug # 配置文件路径
162
- self.language_identification_library: str = (
163
- language_identification_library # 语种识别库
164
- )
165
-
166
- @classmethod
167
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
168
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
169
- data["model"] = os.path.join(dataset_path, data["model"])
170
- return cls(**data)
171
-
172
-
173
- class Server_config:
174
- def __init__(
175
- self, models: List[Dict[str, any]], port: int = 5000, device: str = "cuda"
176
- ):
177
- self.models: List[Dict[str, any]] = models # 需要加载的所有模型的配置
178
- self.port: int = port # 端口号
179
- self.device: str = device # 模型默认使用设备
180
-
181
- @classmethod
182
- def from_dict(cls, data: Dict[str, any]):
183
- return cls(**data)
184
-
185
-
186
- class Translate_config:
187
- """翻译api配置"""
188
-
189
- def __init__(self, app_key: str, secret_key: str):
190
- self.app_key = app_key
191
- self.secret_key = secret_key
192
-
193
- @classmethod
194
- def from_dict(cls, data: Dict[str, any]):
195
- return cls(**data)
196
-
197
-
198
- class Config:
199
- def __init__(self, config_path: str):
200
- if not os.path.isfile(config_path) and os.path.isfile("default_config.yml"):
201
- shutil.copy(src="default_config.yml", dst=config_path)
202
- print(
203
- f"A configuration file {config_path} has been generated based on the default configuration file default_config.yml."
204
- )
205
- print(
206
- "If you have no special needs, please do not modify default_config.yml."
207
- )
208
- # sys.exit(0)
209
- with open(file=config_path, mode="r", encoding="utf-8") as file:
210
- yaml_config: Dict[str, any] = yaml.safe_load(file.read())
211
- model_name: str = yaml_config["model_name"]
212
- self.model_name: str = model_name
213
- if "dataset_path" in yaml_config:
214
- dataset_path = yaml_config["dataset_path"]
215
- else:
216
- dataset_path = f"Data/{model_name}"
217
- self.out_dir = yaml_config["out_dir"]
218
- # openi_token: str = yaml_config["openi_token"]
219
- self.dataset_path: str = dataset_path
220
- # self.mirror: str = yaml_config["mirror"]
221
- # self.openi_token: str = openi_token
222
- self.resample_config: Resample_config = Resample_config.from_dict(
223
- dataset_path, yaml_config["resample"]
224
- )
225
- self.preprocess_text_config: Preprocess_text_config = (
226
- Preprocess_text_config.from_dict(
227
- dataset_path, yaml_config["preprocess_text"]
228
- )
229
- )
230
- self.bert_gen_config: Bert_gen_config = Bert_gen_config.from_dict(
231
- dataset_path, yaml_config["bert_gen"]
232
- )
233
- self.style_gen_config: Style_gen_config = Style_gen_config.from_dict(
234
- dataset_path, yaml_config["style_gen"]
235
- )
236
- self.train_ms_config: Train_ms_config = Train_ms_config.from_dict(
237
- dataset_path, yaml_config["train_ms"]
238
- )
239
- self.webui_config: Webui_config = Webui_config.from_dict(
240
- dataset_path, yaml_config["webui"]
241
- )
242
- self.server_config: Server_config = Server_config.from_dict(
243
- yaml_config["server"]
244
- )
245
- # self.translate_config: Translate_config = Translate_config.from_dict(
246
- # yaml_config["translate"]
247
- # )
248
-
249
-
250
- parser = argparse.ArgumentParser()
251
- # 为避免与以前的config.json起冲突,将其更名如下
252
- parser.add_argument("-y", "--yml_config", type=str, default="config.yml")
253
- args, _ = parser.parse_known_args()
254
- config = Config(args.yml_config)
 
 
1
+ """
2
+ @Desc: 全局配置文件读取
3
+ """
4
+ import argparse
5
+ import os
6
+ import shutil
7
+ from typing import Dict, List
8
+
9
+ import yaml
10
+
11
+ from common.log import logger
12
+
13
+
14
+ class Resample_config:
15
+ """重采样配置"""
16
+
17
+ def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100):
18
+ self.sampling_rate: int = sampling_rate # 目标采样率
19
+ self.in_dir: str = in_dir # 待处理音频目录路径
20
+ self.out_dir: str = out_dir # 重采样输出路径
21
+
22
+ @classmethod
23
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
24
+ """从字典中生成实例"""
25
+
26
+ # 不检查路径是否有效,此逻辑在resample.py中处理
27
+ data["in_dir"] = os.path.join(dataset_path, data["in_dir"])
28
+ data["out_dir"] = os.path.join(dataset_path, data["out_dir"])
29
+
30
+ return cls(**data)
31
+
32
+
33
+ class Preprocess_text_config:
34
+ """数据预处理配置"""
35
+
36
+ def __init__(
37
+ self,
38
+ transcription_path: str,
39
+ cleaned_path: str,
40
+ train_path: str,
41
+ val_path: str,
42
+ config_path: str,
43
+ val_per_lang: int = 5,
44
+ max_val_total: int = 10000,
45
+ clean: bool = True,
46
+ ):
47
+ self.transcription_path: str = transcription_path # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
48
+ self.cleaned_path: str = cleaned_path # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
49
+ self.train_path: str = train_path # 训练集路径,可以不填。不填则将在原始文本目录生成
50
+ self.val_path: str = val_path # 验证集路径,可以不填。不填则将在原始文本目录生成
51
+ self.config_path: str = config_path # 配置文件路径
52
+ self.val_per_lang: int = val_per_lang # 每个speaker的验证集条数
53
+ self.max_val_total: int = max_val_total # 验证集最大条数,多于的会被截断并放到训练集中
54
+ self.clean: bool = clean # 是否进行数据清洗
55
+
56
+ @classmethod
57
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
58
+ """从字典中生成实例"""
59
+
60
+ data["transcription_path"] = os.path.join(
61
+ dataset_path, data["transcription_path"]
62
+ )
63
+ if data["cleaned_path"] == "" or data["cleaned_path"] is None:
64
+ data["cleaned_path"] = None
65
+ else:
66
+ data["cleaned_path"] = os.path.join(dataset_path, data["cleaned_path"])
67
+ data["train_path"] = os.path.join(dataset_path, data["train_path"])
68
+ data["val_path"] = os.path.join(dataset_path, data["val_path"])
69
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
70
+
71
+ return cls(**data)
72
+
73
+
74
+ class Bert_gen_config:
75
+ """bert_gen 配置"""
76
+
77
+ def __init__(
78
+ self,
79
+ config_path: str,
80
+ num_processes: int = 2,
81
+ device: str = "cuda",
82
+ use_multi_device: bool = False,
83
+ ):
84
+ self.config_path = config_path
85
+ self.num_processes = num_processes
86
+ self.device = device
87
+ self.use_multi_device = use_multi_device
88
+
89
+ @classmethod
90
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
91
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
92
+
93
+ return cls(**data)
94
+
95
+
96
+ class Style_gen_config:
97
+ """style_gen 配置"""
98
+
99
+ def __init__(
100
+ self,
101
+ config_path: str,
102
+ num_processes: int = 4,
103
+ device: str = "cuda",
104
+ ):
105
+ self.config_path = config_path
106
+ self.num_processes = num_processes
107
+ self.device = device
108
+
109
+ @classmethod
110
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
111
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
112
+
113
+ return cls(**data)
114
+
115
+
116
+ class Train_ms_config:
117
+ """训练配置"""
118
+
119
+ def __init__(
120
+ self,
121
+ config_path: str,
122
+ env: Dict[str, any],
123
+ # base: Dict[str, any],
124
+ model_dir: str,
125
+ num_workers: int,
126
+ spec_cache: bool,
127
+ keep_ckpts: int,
128
+ ):
129
+ self.env = env # 需要加载的环境变量
130
+ # self.base = base # 底模配置
131
+ self.model_dir = model_dir # 训练模型存储目录,该路径为相对于dataset_path的路径,而非项目根目录
132
+ self.config_path = config_path # 配置文件路径
133
+ self.num_workers = num_workers # worker数量
134
+ self.spec_cache = spec_cache # 是否启用spec缓存
135
+ self.keep_ckpts = keep_ckpts # ckpt数量
136
+
137
+ @classmethod
138
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
139
+ # data["model"] = os.path.join(dataset_path, data["model"])
140
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
141
+
142
+ return cls(**data)
143
+
144
+
145
+ class Webui_config:
146
+ """webui 配置"""
147
+
148
+ def __init__(
149
+ self,
150
+ device: str,
151
+ model: str,
152
+ config_path: str,
153
+ language_identification_library: str,
154
+ port: int = 7860,
155
+ share: bool = False,
156
+ debug: bool = False,
157
+ ):
158
+ self.device: str = device
159
+ self.model: str = model # 端口号
160
+ self.config_path: str = config_path # 是否公开部署,对外网开放
161
+ self.port: int = port # 是否开启debug模式
162
+ self.share: bool = share # 模型路径
163
+ self.debug: bool = debug # 配置文件路径
164
+ self.language_identification_library: str = (
165
+ language_identification_library # 语种识别库
166
+ )
167
+
168
+ @classmethod
169
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
170
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
171
+ data["model"] = os.path.join(dataset_path, data["model"])
172
+ return cls(**data)
173
+
174
+
175
+ class Server_config:
176
+ def __init__(
177
+ self,
178
+ port: int = 5000,
179
+ device: str = "cuda",
180
+ limit: int = 100,
181
+ language: str = "JP",
182
+ origins: List[str] = None,
183
+ ):
184
+ self.port: int = port
185
+ self.device: str = device
186
+ self.language: str = language
187
+ self.limit: int = limit
188
+ self.origins: List[str] = origins
189
+
190
+ @classmethod
191
+ def from_dict(cls, data: Dict[str, any]):
192
+ return cls(**data)
193
+
194
+
195
+ class Translate_config:
196
+ """翻译api配置"""
197
+
198
+ def __init__(self, app_key: str, secret_key: str):
199
+ self.app_key = app_key
200
+ self.secret_key = secret_key
201
+
202
+ @classmethod
203
+ def from_dict(cls, data: Dict[str, any]):
204
+ return cls(**data)
205
+
206
+
207
+ class Config:
208
+ def __init__(self, config_path: str, path_config: dict[str, str]):
209
+ if not os.path.isfile(config_path) and os.path.isfile("default_config.yml"):
210
+ shutil.copy(src="default_config.yml", dst=config_path)
211
+ logger.info(
212
+ f"A configuration file {config_path} has been generated based on the default configuration file default_config.yml."
213
+ )
214
+ logger.info(
215
+ "If you have no special needs, please do not modify default_config.yml."
216
+ )
217
+ # sys.exit(0)
218
+ with open(file=config_path, mode="r", encoding="utf-8") as file:
219
+ yaml_config: Dict[str, any] = yaml.safe_load(file.read())
220
+ model_name: str = yaml_config["model_name"]
221
+ self.model_name: str = model_name
222
+ if "dataset_path" in yaml_config:
223
+ dataset_path = yaml_config["dataset_path"]
224
+ else:
225
+ dataset_path = os.path.join(path_config["dataset_root"], model_name)
226
+ self.dataset_path: str = dataset_path
227
+ self.assets_root: str = path_config["assets_root"]
228
+ self.out_dir = os.path.join(self.assets_root, model_name)
229
+ self.resample_config: Resample_config = Resample_config.from_dict(
230
+ dataset_path, yaml_config["resample"]
231
+ )
232
+ self.preprocess_text_config: Preprocess_text_config = (
233
+ Preprocess_text_config.from_dict(
234
+ dataset_path, yaml_config["preprocess_text"]
235
+ )
236
+ )
237
+ self.bert_gen_config: Bert_gen_config = Bert_gen_config.from_dict(
238
+ dataset_path, yaml_config["bert_gen"]
239
+ )
240
+ self.style_gen_config: Style_gen_config = Style_gen_config.from_dict(
241
+ dataset_path, yaml_config["style_gen"]
242
+ )
243
+ self.train_ms_config: Train_ms_config = Train_ms_config.from_dict(
244
+ dataset_path, yaml_config["train_ms"]
245
+ )
246
+ self.webui_config: Webui_config = Webui_config.from_dict(
247
+ dataset_path, yaml_config["webui"]
248
+ )
249
+ self.server_config: Server_config = Server_config.from_dict(
250
+ yaml_config["server"]
251
+ )
252
+ # self.translate_config: Translate_config = Translate_config.from_dict(
253
+ # yaml_config["translate"]
254
+ # )
255
+
256
+
257
+ with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f:
258
+ path_config: dict[str, str] = yaml.safe_load(f.read())
259
+ # Should contain the following keys:
260
+ # - dataset_root: the root directory of the dataset, default to "Data"
261
+ # - assets_root: the root directory of the assets, default to "model_assets"
262
+
263
+
264
+ try:
265
+ config = Config("config.yml", path_config)
266
+ except (TypeError, KeyError):
267
+ logger.warning("Old config.yml found. Replace it with default_config.yml.")
268
+ shutil.copy(src="default_config.yml", dst="config.yml")
269
+ config = Config("config.yml", path_config)
config.yml CHANGED
@@ -1,58 +1,51 @@
1
- bert_gen:
2
- config_path: config.json
3
- device: cuda
4
- num_processes: 4
5
- use_multi_device: false
6
- dataset_path: Data\jvnv-M2
7
- model_name: jvnv-M2
8
- out_dir: model_assets
9
- preprocess_text:
10
- clean: true
11
- cleaned_path: ''
12
- config_path: config.json
13
- max_val_total: 12
14
- train_path: filelists/train.list
15
- transcription_path: filelists/text.list
16
- val_path: filelists/val.list
17
- val_per_lang: 4
18
- resample:
19
- in_dir: audios/raw
20
- out_dir: audios/wavs
21
- sampling_rate: 44100
22
- server:
23
- device: cuda
24
- models:
25
- - config: ''
26
- device: cuda
27
- language: ZH
28
- model: ''
29
- - config: ''
30
- device: cpu
31
- language: JP
32
- model: ''
33
- speakers: []
34
- port: 5000
35
- style_gen:
36
- config_path: config.json
37
- device: cuda
38
- num_processes: 4
39
- train_ms:
40
- config_path: config.json
41
- env:
42
- LOCAL_RANK: 0
43
- MASTER_ADDR: localhost
44
- MASTER_PORT: 10086
45
- RANK: 0
46
- WORLD_SIZE: 1
47
- keep_ckpts: 1
48
- model: models
49
- num_workers: 16
50
- spec_cache: true
51
- webui:
52
- config_path: config.json
53
- debug: false
54
- device: cuda
55
- language_identification_library: langid
56
- model: models/G_8000.pth
57
- port: 7860
58
- share: false
 
1
+ bert_gen:
2
+ config_path: config.json
3
+ device: cuda
4
+ num_processes: 2
5
+ use_multi_device: false
6
+ dataset_path: Data\model_name
7
+ model_name: model_name
8
+ preprocess_text:
9
+ clean: true
10
+ cleaned_path: ''
11
+ config_path: config.json
12
+ max_val_total: 12
13
+ train_path: train.list
14
+ transcription_path: esd.list
15
+ val_path: val.list
16
+ val_per_lang: 4
17
+ resample:
18
+ in_dir: raw
19
+ out_dir: wavs
20
+ sampling_rate: 44100
21
+ server:
22
+ device: cuda
23
+ language: JP
24
+ limit: 100
25
+ origins:
26
+ - '*'
27
+ port: 5000
28
+ style_gen:
29
+ config_path: config.json
30
+ device: cuda
31
+ num_processes: 4
32
+ train_ms:
33
+ config_path: config.json
34
+ env:
35
+ LOCAL_RANK: 0
36
+ MASTER_ADDR: localhost
37
+ MASTER_PORT: 10086
38
+ RANK: 0
39
+ WORLD_SIZE: 1
40
+ keep_ckpts: 1
41
+ model_dir: models
42
+ num_workers: 16
43
+ spec_cache: true
44
+ webui:
45
+ config_path: config.json
46
+ debug: false
47
+ device: cuda
48
+ language_identification_library: langid
49
+ model: models/G_8000.pth
50
+ port: 7860
51
+ share: false
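For reference, a small sanity-check sketch (not part of the commit) that loads the default config.yml above with PyYAML and verifies that every top-level section the Config class reads is present; the section names are taken from the file itself:

import yaml

EXPECTED_SECTIONS = {
    "model_name", "bert_gen", "preprocess_text", "resample",
    "server", "style_gen", "train_ms", "webui",
}

with open("config.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Report any expected section that is missing from the file.
missing = EXPECTED_SECTIONS - set(cfg)
print("missing sections:", missing or "none")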
 
 
 
 
 
 
 
configs/config.json ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name": "your_model_name",
3
+ "train": {
4
+ "log_interval": 200,
5
+ "eval_interval": 1000,
6
+ "seed": 42,
7
+ "epochs": 1000,
8
+ "learning_rate": 0.0002,
9
+ "betas": [0.8, 0.99],
10
+ "eps": 1e-9,
11
+ "batch_size": 4,
12
+ "bf16_run": true,
13
+ "lr_decay": 0.99995,
14
+ "segment_size": 16384,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "skip_optimizer": false,
20
+ "freeze_ZH_bert": false,
21
+ "freeze_JP_bert": false,
22
+ "freeze_EN_bert": false,
23
+ "freeze_style": false
24
+ },
25
+ "data": {
26
+ "training_files": "Data/your_model_name/filelists/train.list",
27
+ "validation_files": "Data/your_model_name/filelists/val.list",
28
+ "max_wav_value": 32768.0,
29
+ "sampling_rate": 44100,
30
+ "filter_length": 2048,
31
+ "hop_length": 512,
32
+ "win_length": 2048,
33
+ "n_mel_channels": 128,
34
+ "mel_fmin": 0.0,
35
+ "mel_fmax": null,
36
+ "add_blank": true,
37
+ "n_speakers": 1,
38
+ "cleaned_text": true,
39
+ "num_styles": 1,
40
+ "style2id": {
41
+ "Neutral": 0
42
+ }
43
+ },
44
+ "model": {
45
+ "use_spk_conditioned_encoder": true,
46
+ "use_noise_scaled_mas": true,
47
+ "use_mel_posterior_encoder": false,
48
+ "use_duration_discriminator": true,
49
+ "inter_channels": 192,
50
+ "hidden_channels": 192,
51
+ "filter_channels": 768,
52
+ "n_heads": 2,
53
+ "n_layers": 6,
54
+ "kernel_size": 3,
55
+ "p_dropout": 0.1,
56
+ "resblock": "1",
57
+ "resblock_kernel_sizes": [3, 7, 11],
58
+ "resblock_dilation_sizes": [
59
+ [1, 3, 5],
60
+ [1, 3, 5],
61
+ [1, 3, 5]
62
+ ],
63
+ "upsample_rates": [8, 8, 2, 2, 2],
64
+ "upsample_initial_channel": 512,
65
+ "upsample_kernel_sizes": [16, 16, 8, 2, 2],
66
+ "n_layers_q": 3,
67
+ "use_spectral_norm": false,
68
+ "gin_channels": 256
69
+ },
70
+ "version": "2.0.1"
71
+ }
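A small consistency-check sketch (not part of the commit): in this config the generator's total upsampling factor should match hop_length, since the decoder turns each spectrogram frame back into hop_length waveform samples.

import json
import math

with open("configs/config.json", encoding="utf-8") as f:
    cfg = json.load(f)

hop = cfg["data"]["hop_length"]                   # 512
up = math.prod(cfg["model"]["upsample_rates"])    # 8 * 8 * 2 * 2 * 2 = 512
print(hop == up)                                  # True
print(cfg["data"]["sampling_rate"] / hop)         # about 86.1 frames per second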
configs/configs_jp_extra.json ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 42,
6
+ "epochs": 1000,
7
+ "learning_rate": 0.0001,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 24,
11
+ "bf16_run": false,
12
+ "fp16_run": false,
13
+ "lr_decay": 0.99996,
14
+ "segment_size": 16384,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "c_commit": 100,
20
+ "skip_optimizer": true,
21
+ "freeze_ZH_bert": false,
22
+ "freeze_JP_bert": false,
23
+ "freeze_EN_bert": false,
24
+ "freeze_emo": false,
25
+ "freeze_style": false
26
+ },
27
+ "data": {
28
+ "use_jp_extra": true,
29
+ "training_files": "filelists/train.list",
30
+ "validation_files": "filelists/val.list",
31
+ "max_wav_value": 32768.0,
32
+ "sampling_rate": 44100,
33
+ "filter_length": 2048,
34
+ "hop_length": 512,
35
+ "win_length": 2048,
36
+ "n_mel_channels": 128,
37
+ "mel_fmin": 0.0,
38
+ "mel_fmax": null,
39
+ "add_blank": true,
40
+ "n_speakers": 512,
41
+ "cleaned_text": true
42
+ },
43
+ "model": {
44
+ "use_spk_conditioned_encoder": true,
45
+ "use_noise_scaled_mas": true,
46
+ "use_mel_posterior_encoder": false,
47
+ "use_duration_discriminator": false,
48
+ "use_wavlm_discriminator": true,
49
+ "inter_channels": 192,
50
+ "hidden_channels": 192,
51
+ "filter_channels": 768,
52
+ "n_heads": 2,
53
+ "n_layers": 6,
54
+ "kernel_size": 3,
55
+ "p_dropout": 0.1,
56
+ "resblock": "1",
57
+ "resblock_kernel_sizes": [3, 7, 11],
58
+ "resblock_dilation_sizes": [
59
+ [1, 3, 5],
60
+ [1, 3, 5],
61
+ [1, 3, 5]
62
+ ],
63
+ "upsample_rates": [8, 8, 2, 2, 2],
64
+ "upsample_initial_channel": 512,
65
+ "upsample_kernel_sizes": [16, 16, 8, 2, 2],
66
+ "n_layers_q": 3,
67
+ "use_spectral_norm": false,
68
+ "gin_channels": 512,
69
+ "slm": {
70
+ "model": "./slm/wavlm-base-plus",
71
+ "sr": 16000,
72
+ "hidden": 768,
73
+ "nlayers": 13,
74
+ "initial_channel": 64
75
+ }
76
+ },
77
+ "version": "2.0.1-JP-Extra"
78
+ }
configs/paths.yml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Root directory of the training dataset.
2
+ # The training dataset of {model_name} should be placed in {dataset_root}/{model_name}.
3
+ dataset_root: Data
4
+
5
+ # Root directory of the model assets (for inference).
6
+ # In training, the model assets will be saved to {assets_root}/{model_name},
7
+ # and in inference, we load all the models from {assets_root}.
8
+ assets_root: model_assets
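An illustrative sketch (not part of the commit) of how these two roots are combined with a model name, following the comments above: training data is looked up under {dataset_root}/{model_name}, while inference loads models from {assets_root}. The model name used here is one of the bundled models in model_assets/.

import os
import yaml

with open(os.path.join("configs", "paths.yml"), encoding="utf-8") as f:
    paths = yaml.safe_load(f)

model_name = "jvnv-F1-jp"
dataset_dir = os.path.join(paths["dataset_root"], model_name)  # Data/jvnv-F1-jp
assets_dir = os.path.join(paths["assets_root"], model_name)    # model_assets/jvnv-F1-jp
print(dataset_dir, assets_dir)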
infer.py CHANGED
@@ -1,263 +1,306 @@
1
- import torch
2
-
3
- import commons
4
- import utils
5
- from models import SynthesizerTrn
6
- from text import cleaned_text_to_sequence, get_bert
7
- from text.cleaner import clean_text
8
- from text.symbols import symbols
9
-
10
- # latest_version = "1.0"
11
-
12
-
13
- def get_net_g(model_path: str, version: str, device: str, hps):
14
- net_g = SynthesizerTrn(
15
- len(symbols),
16
- hps.data.filter_length // 2 + 1,
17
- hps.train.segment_size // hps.data.hop_length,
18
- n_speakers=hps.data.n_speakers,
19
- **hps.model,
20
- ).to(device)
21
- net_g.state_dict()
22
- _ = net_g.eval()
23
- if model_path.endswith(".pth") or model_path.endswith(".pt"):
24
- _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
25
- elif model_path.endswith(".safetensors"):
26
- _ = utils.load_safetensors(model_path, net_g, device)
27
- else:
28
- raise ValueError(f"Unknown model format: {model_path}")
29
- return net_g
30
-
31
-
32
- def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
33
- # Implement the current version of get_text here
34
- norm_text, phone, tone, word2ph = clean_text(text, language_str)
35
- phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
36
-
37
- if hps.data.add_blank:
38
- phone = commons.intersperse(phone, 0)
39
- tone = commons.intersperse(tone, 0)
40
- language = commons.intersperse(language, 0)
41
- for i in range(len(word2ph)):
42
- word2ph[i] = word2ph[i] * 2
43
- word2ph[0] += 1
44
- bert_ori = get_bert(
45
- norm_text, word2ph, language_str, device, style_text, style_weight
46
- )
47
- del word2ph
48
- assert bert_ori.shape[-1] == len(phone), phone
49
-
50
- if language_str == "ZH":
51
- bert = bert_ori
52
- ja_bert = torch.zeros(1024, len(phone))
53
- en_bert = torch.zeros(1024, len(phone))
54
- elif language_str == "JP":
55
- bert = torch.zeros(1024, len(phone))
56
- ja_bert = bert_ori
57
- en_bert = torch.zeros(1024, len(phone))
58
- elif language_str == "EN":
59
- bert = torch.zeros(1024, len(phone))
60
- ja_bert = torch.zeros(1024, len(phone))
61
- en_bert = bert_ori
62
- else:
63
- raise ValueError("language_str should be ZH, JP or EN")
64
-
65
- assert bert.shape[-1] == len(
66
- phone
67
- ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
68
-
69
- phone = torch.LongTensor(phone)
70
- tone = torch.LongTensor(tone)
71
- language = torch.LongTensor(language)
72
- return bert, ja_bert, en_bert, phone, tone, language
73
-
74
-
75
- def infer(
76
- text,
77
- style_vec,
78
- sdp_ratio,
79
- noise_scale,
80
- noise_scale_w,
81
- length_scale,
82
- sid: int, # In the original Bert-VITS2, its speaker_name: str, but here it's id
83
- language,
84
- hps,
85
- net_g,
86
- device,
87
- skip_start=False,
88
- skip_end=False,
89
- style_text=None,
90
- style_weight=0.7,
91
- ):
92
- bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
93
- text,
94
- language,
95
- hps,
96
- device,
97
- style_text=style_text,
98
- style_weight=style_weight,
99
- )
100
- if skip_start:
101
- phones = phones[3:]
102
- tones = tones[3:]
103
- lang_ids = lang_ids[3:]
104
- bert = bert[:, 3:]
105
- ja_bert = ja_bert[:, 3:]
106
- en_bert = en_bert[:, 3:]
107
- if skip_end:
108
- phones = phones[:-2]
109
- tones = tones[:-2]
110
- lang_ids = lang_ids[:-2]
111
- bert = bert[:, :-2]
112
- ja_bert = ja_bert[:, :-2]
113
- en_bert = en_bert[:, :-2]
114
- with torch.no_grad():
115
- x_tst = phones.to(device).unsqueeze(0)
116
- tones = tones.to(device).unsqueeze(0)
117
- lang_ids = lang_ids.to(device).unsqueeze(0)
118
- bert = bert.to(device).unsqueeze(0)
119
- ja_bert = ja_bert.to(device).unsqueeze(0)
120
- en_bert = en_bert.to(device).unsqueeze(0)
121
- x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
122
- style_vec = torch.from_numpy(style_vec).to(device).unsqueeze(0)
123
- del phones
124
- sid_tensor = torch.LongTensor([sid]).to(device)
125
- audio = (
126
- net_g.infer(
127
- x_tst,
128
- x_tst_lengths,
129
- sid_tensor,
130
- tones,
131
- lang_ids,
132
- bert,
133
- ja_bert,
134
- en_bert,
135
- style_vec=style_vec,
136
- sdp_ratio=sdp_ratio,
137
- noise_scale=noise_scale,
138
- noise_scale_w=noise_scale_w,
139
- length_scale=length_scale,
140
- )[0][0, 0]
141
- .data.cpu()
142
- .float()
143
- .numpy()
144
- )
145
- del (
146
- x_tst,
147
- tones,
148
- lang_ids,
149
- bert,
150
- x_tst_lengths,
151
- sid_tensor,
152
- ja_bert,
153
- en_bert,
154
- style_vec,
155
- ) # , emo
156
- if torch.cuda.is_available():
157
- torch.cuda.empty_cache()
158
- return audio
159
-
160
-
161
- def infer_multilang(
162
- text,
163
- style_vec,
164
- sdp_ratio,
165
- noise_scale,
166
- noise_scale_w,
167
- length_scale,
168
- sid,
169
- language,
170
- hps,
171
- net_g,
172
- device,
173
- skip_start=False,
174
- skip_end=False,
175
- ):
176
- bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
177
- # emo = get_emo_(reference_audio, emotion, sid)
178
- # if isinstance(reference_audio, np.ndarray):
179
- # emo = get_clap_audio_feature(reference_audio, device)
180
- # else:
181
- # emo = get_clap_text_feature(emotion, device)
182
- # emo = torch.squeeze(emo, dim=1)
183
- for idx, (txt, lang) in enumerate(zip(text, language)):
184
- _skip_start = (idx != 0) or (skip_start and idx == 0)
185
- _skip_end = (idx != len(language) - 1) or skip_end
186
- (
187
- temp_bert,
188
- temp_ja_bert,
189
- temp_en_bert,
190
- temp_phones,
191
- temp_tones,
192
- temp_lang_ids,
193
- ) = get_text(txt, lang, hps, device)
194
- if _skip_start:
195
- temp_bert = temp_bert[:, 3:]
196
- temp_ja_bert = temp_ja_bert[:, 3:]
197
- temp_en_bert = temp_en_bert[:, 3:]
198
- temp_phones = temp_phones[3:]
199
- temp_tones = temp_tones[3:]
200
- temp_lang_ids = temp_lang_ids[3:]
201
- if _skip_end:
202
- temp_bert = temp_bert[:, :-2]
203
- temp_ja_bert = temp_ja_bert[:, :-2]
204
- temp_en_bert = temp_en_bert[:, :-2]
205
- temp_phones = temp_phones[:-2]
206
- temp_tones = temp_tones[:-2]
207
- temp_lang_ids = temp_lang_ids[:-2]
208
- bert.append(temp_bert)
209
- ja_bert.append(temp_ja_bert)
210
- en_bert.append(temp_en_bert)
211
- phones.append(temp_phones)
212
- tones.append(temp_tones)
213
- lang_ids.append(temp_lang_ids)
214
- bert = torch.concatenate(bert, dim=1)
215
- ja_bert = torch.concatenate(ja_bert, dim=1)
216
- en_bert = torch.concatenate(en_bert, dim=1)
217
- phones = torch.concatenate(phones, dim=0)
218
- tones = torch.concatenate(tones, dim=0)
219
- lang_ids = torch.concatenate(lang_ids, dim=0)
220
- with torch.no_grad():
221
- x_tst = phones.to(device).unsqueeze(0)
222
- tones = tones.to(device).unsqueeze(0)
223
- lang_ids = lang_ids.to(device).unsqueeze(0)
224
- bert = bert.to(device).unsqueeze(0)
225
- ja_bert = ja_bert.to(device).unsqueeze(0)
226
- en_bert = en_bert.to(device).unsqueeze(0)
227
- # emo = emo.to(device).unsqueeze(0)
228
- x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
229
- del phones
230
- speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
231
- audio = (
232
- net_g.infer(
233
- x_tst,
234
- x_tst_lengths,
235
- speakers,
236
- tones,
237
- lang_ids,
238
- bert,
239
- ja_bert,
240
- en_bert,
241
- style_vec=style_vec,
242
- sdp_ratio=sdp_ratio,
243
- noise_scale=noise_scale,
244
- noise_scale_w=noise_scale_w,
245
- length_scale=length_scale,
246
- )[0][0, 0]
247
- .data.cpu()
248
- .float()
249
- .numpy()
250
- )
251
- del (
252
- x_tst,
253
- tones,
254
- lang_ids,
255
- bert,
256
- x_tst_lengths,
257
- speakers,
258
- ja_bert,
259
- en_bert,
260
- ) # , emo
261
- if torch.cuda.is_available():
262
- torch.cuda.empty_cache()
263
- return audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import commons
4
+ import utils
5
+ from models import SynthesizerTrn
6
+ from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra
7
+ from text import cleaned_text_to_sequence, get_bert
8
+ from text.cleaner import clean_text
9
+ from text.symbols import symbols
10
+ from common.log import logger
11
+
12
+
13
+ class InvalidToneError(ValueError):
14
+ pass
15
+
16
+
17
+ def get_net_g(model_path: str, version: str, device: str, hps):
18
+ if version.endswith("JP-Extra"):
19
+ logger.info("Using JP-Extra model")
20
+ net_g = SynthesizerTrnJPExtra(
21
+ len(symbols),
22
+ hps.data.filter_length // 2 + 1,
23
+ hps.train.segment_size // hps.data.hop_length,
24
+ n_speakers=hps.data.n_speakers,
25
+ **hps.model,
26
+ ).to(device)
27
+ else:
28
+ logger.info("Using normal model")
29
+ net_g = SynthesizerTrn(
30
+ len(symbols),
31
+ hps.data.filter_length // 2 + 1,
32
+ hps.train.segment_size // hps.data.hop_length,
33
+ n_speakers=hps.data.n_speakers,
34
+ **hps.model,
35
+ ).to(device)
36
+ net_g.state_dict()
37
+ _ = net_g.eval()
38
+ if model_path.endswith(".pth") or model_path.endswith(".pt"):
39
+ _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
40
+ elif model_path.endswith(".safetensors"):
41
+ _ = utils.load_safetensors(model_path, net_g, True)
42
+ else:
43
+ raise ValueError(f"Unknown model format: {model_path}")
44
+ return net_g
45
+
46
+
47
+ def get_text(
48
+ text,
49
+ language_str,
50
+ hps,
51
+ device,
52
+ assist_text=None,
53
+ assist_text_weight=0.7,
54
+ given_tone=None,
55
+ ):
56
+ use_jp_extra = hps.version.endswith("JP-Extra")
57
+ norm_text, phone, tone, word2ph = clean_text(text, language_str, use_jp_extra)
58
+ if given_tone is not None:
59
+ if len(given_tone) != len(phone):
60
+ raise InvalidToneError(
61
+ f"Length of given_tone ({len(given_tone)}) != length of phone ({len(phone)})"
62
+ )
63
+ tone = given_tone
64
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
65
+
66
+ if hps.data.add_blank:
67
+ phone = commons.intersperse(phone, 0)
68
+ tone = commons.intersperse(tone, 0)
69
+ language = commons.intersperse(language, 0)
70
+ for i in range(len(word2ph)):
71
+ word2ph[i] = word2ph[i] * 2
72
+ word2ph[0] += 1
73
+ bert_ori = get_bert(
74
+ norm_text, word2ph, language_str, device, assist_text, assist_text_weight
75
+ )
76
+ del word2ph
77
+ assert bert_ori.shape[-1] == len(phone), phone
78
+
79
+ if language_str == "ZH":
80
+ bert = bert_ori
81
+ ja_bert = torch.zeros(1024, len(phone))
82
+ en_bert = torch.zeros(1024, len(phone))
83
+ elif language_str == "JP":
84
+ bert = torch.zeros(1024, len(phone))
85
+ ja_bert = bert_ori
86
+ en_bert = torch.zeros(1024, len(phone))
87
+ elif language_str == "EN":
88
+ bert = torch.zeros(1024, len(phone))
89
+ ja_bert = torch.zeros(1024, len(phone))
90
+ en_bert = bert_ori
91
+ else:
92
+ raise ValueError("language_str should be ZH, JP or EN")
93
+
94
+ assert bert.shape[-1] == len(
95
+ phone
96
+ ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
97
+
98
+ phone = torch.LongTensor(phone)
99
+ tone = torch.LongTensor(tone)
100
+ language = torch.LongTensor(language)
101
+ return bert, ja_bert, en_bert, phone, tone, language
102
+
103
+
104
+ def infer(
105
+ text,
106
+ style_vec,
107
+ sdp_ratio,
108
+ noise_scale,
109
+ noise_scale_w,
110
+ length_scale,
111
+ sid: int,  # In the original Bert-VITS2 it's speaker_name: str, but here it's an id
112
+ language,
113
+ hps,
114
+ net_g,
115
+ device,
116
+ skip_start=False,
117
+ skip_end=False,
118
+ assist_text=None,
119
+ assist_text_weight=0.7,
120
+ given_tone=None,
121
+ ):
122
+ is_jp_extra = hps.version.endswith("JP-Extra")
123
+ bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
124
+ text,
125
+ language,
126
+ hps,
127
+ device,
128
+ assist_text=assist_text,
129
+ assist_text_weight=assist_text_weight,
130
+ given_tone=given_tone,
131
+ )
132
+ if skip_start:
133
+ phones = phones[3:]
134
+ tones = tones[3:]
135
+ lang_ids = lang_ids[3:]
136
+ bert = bert[:, 3:]
137
+ ja_bert = ja_bert[:, 3:]
138
+ en_bert = en_bert[:, 3:]
139
+ if skip_end:
140
+ phones = phones[:-2]
141
+ tones = tones[:-2]
142
+ lang_ids = lang_ids[:-2]
143
+ bert = bert[:, :-2]
144
+ ja_bert = ja_bert[:, :-2]
145
+ en_bert = en_bert[:, :-2]
146
+ with torch.no_grad():
147
+ x_tst = phones.to(device).unsqueeze(0)
148
+ tones = tones.to(device).unsqueeze(0)
149
+ lang_ids = lang_ids.to(device).unsqueeze(0)
150
+ bert = bert.to(device).unsqueeze(0)
151
+ ja_bert = ja_bert.to(device).unsqueeze(0)
152
+ en_bert = en_bert.to(device).unsqueeze(0)
153
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
154
+ style_vec = torch.from_numpy(style_vec).to(device).unsqueeze(0)
155
+ del phones
156
+ sid_tensor = torch.LongTensor([sid]).to(device)
157
+ if is_jp_extra:
158
+ output = net_g.infer(
159
+ x_tst,
160
+ x_tst_lengths,
161
+ sid_tensor,
162
+ tones,
163
+ lang_ids,
164
+ ja_bert,
165
+ style_vec=style_vec,
166
+ sdp_ratio=sdp_ratio,
167
+ noise_scale=noise_scale,
168
+ noise_scale_w=noise_scale_w,
169
+ length_scale=length_scale,
170
+ )
171
+ else:
172
+ output = net_g.infer(
173
+ x_tst,
174
+ x_tst_lengths,
175
+ sid_tensor,
176
+ tones,
177
+ lang_ids,
178
+ bert,
179
+ ja_bert,
180
+ en_bert,
181
+ style_vec=style_vec,
182
+ sdp_ratio=sdp_ratio,
183
+ noise_scale=noise_scale,
184
+ noise_scale_w=noise_scale_w,
185
+ length_scale=length_scale,
186
+ )
187
+ audio = output[0][0, 0].data.cpu().float().numpy()
188
+ del (
189
+ x_tst,
190
+ tones,
191
+ lang_ids,
192
+ bert,
193
+ x_tst_lengths,
194
+ sid_tensor,
195
+ ja_bert,
196
+ en_bert,
197
+ style_vec,
198
+ ) # , emo
199
+ if torch.cuda.is_available():
200
+ torch.cuda.empty_cache()
201
+ return audio
202
+
203
+
204
+ def infer_multilang(
205
+ text,
206
+ style_vec,
207
+ sdp_ratio,
208
+ noise_scale,
209
+ noise_scale_w,
210
+ length_scale,
211
+ sid,
212
+ language,
213
+ hps,
214
+ net_g,
215
+ device,
216
+ skip_start=False,
217
+ skip_end=False,
218
+ ):
219
+ bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
220
+ # emo = get_emo_(reference_audio, emotion, sid)
221
+ # if isinstance(reference_audio, np.ndarray):
222
+ # emo = get_clap_audio_feature(reference_audio, device)
223
+ # else:
224
+ # emo = get_clap_text_feature(emotion, device)
225
+ # emo = torch.squeeze(emo, dim=1)
226
+ for idx, (txt, lang) in enumerate(zip(text, language)):
227
+ _skip_start = (idx != 0) or (skip_start and idx == 0)
228
+ _skip_end = (idx != len(language) - 1) or skip_end
229
+ (
230
+ temp_bert,
231
+ temp_ja_bert,
232
+ temp_en_bert,
233
+ temp_phones,
234
+ temp_tones,
235
+ temp_lang_ids,
236
+ ) = get_text(txt, lang, hps, device)
237
+ if _skip_start:
238
+ temp_bert = temp_bert[:, 3:]
239
+ temp_ja_bert = temp_ja_bert[:, 3:]
240
+ temp_en_bert = temp_en_bert[:, 3:]
241
+ temp_phones = temp_phones[3:]
242
+ temp_tones = temp_tones[3:]
243
+ temp_lang_ids = temp_lang_ids[3:]
244
+ if _skip_end:
245
+ temp_bert = temp_bert[:, :-2]
246
+ temp_ja_bert = temp_ja_bert[:, :-2]
247
+ temp_en_bert = temp_en_bert[:, :-2]
248
+ temp_phones = temp_phones[:-2]
249
+ temp_tones = temp_tones[:-2]
250
+ temp_lang_ids = temp_lang_ids[:-2]
251
+ bert.append(temp_bert)
252
+ ja_bert.append(temp_ja_bert)
253
+ en_bert.append(temp_en_bert)
254
+ phones.append(temp_phones)
255
+ tones.append(temp_tones)
256
+ lang_ids.append(temp_lang_ids)
257
+ bert = torch.concatenate(bert, dim=1)
258
+ ja_bert = torch.concatenate(ja_bert, dim=1)
259
+ en_bert = torch.concatenate(en_bert, dim=1)
260
+ phones = torch.concatenate(phones, dim=0)
261
+ tones = torch.concatenate(tones, dim=0)
262
+ lang_ids = torch.concatenate(lang_ids, dim=0)
263
+ with torch.no_grad():
264
+ x_tst = phones.to(device).unsqueeze(0)
265
+ tones = tones.to(device).unsqueeze(0)
266
+ lang_ids = lang_ids.to(device).unsqueeze(0)
267
+ bert = bert.to(device).unsqueeze(0)
268
+ ja_bert = ja_bert.to(device).unsqueeze(0)
269
+ en_bert = en_bert.to(device).unsqueeze(0)
270
+ # emo = emo.to(device).unsqueeze(0)
271
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
272
+ del phones
273
+ speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
274
+ audio = (
275
+ net_g.infer(
276
+ x_tst,
277
+ x_tst_lengths,
278
+ speakers,
279
+ tones,
280
+ lang_ids,
281
+ bert,
282
+ ja_bert,
283
+ en_bert,
284
+ style_vec=style_vec,
285
+ sdp_ratio=sdp_ratio,
286
+ noise_scale=noise_scale,
287
+ noise_scale_w=noise_scale_w,
288
+ length_scale=length_scale,
289
+ )[0][0, 0]
290
+ .data.cpu()
291
+ .float()
292
+ .numpy()
293
+ )
294
+ del (
295
+ x_tst,
296
+ tones,
297
+ lang_ids,
298
+ bert,
299
+ x_tst_lengths,
300
+ speakers,
301
+ ja_bert,
302
+ en_bert,
303
+ ) # , emo
304
+ if torch.cuda.is_available():
305
+ torch.cuda.empty_cache()
306
+ return audio
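A minimal end-to-end inference sketch (not part of the commit) using one of the bundled JP-Extra models under model_assets/. It assumes utils.get_hparams_from_file parses the model's config.json into the hps object that get_net_g and infer expect, and the sampler settings (sdp_ratio and friends) are just typical values rather than prescribed ones.

import numpy as np

import utils
from infer import get_net_g, infer

device = "cpu"
config_path = "model_assets/jvnv-F1-jp/config.json"
model_path = "model_assets/jvnv-F1-jp/jvnv-F1-jp_e182_s16000.safetensors"

hps = utils.get_hparams_from_file(config_path)           # assumed helper in utils.py
net_g = get_net_g(model_path, hps.version, device, hps)  # picks the JP-Extra model class

# Row 0 of style_vectors.npy corresponds to the "Neutral" style in config.json.
style_vec = np.load("model_assets/jvnv-F1-jp/style_vectors.npy")[0]

audio = infer(
    text="こんにちは",
    style_vec=style_vec,
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid=0,            # "jvnv-F1-jp" -> 0 in spk2id
    language="JP",
    hps=hps,
    net_g=net_g,
    device=device,
)
print(audio.shape, hps.data.sampling_rate)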
model_assets/jvnv-F1-jp/config.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 42,
6
+ "epochs": 300,
7
+ "learning_rate": 0.0001,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 4,
11
+ "bf16_run": true,
12
+ "fp16_run": false,
13
+ "lr_decay": 0.99996,
14
+ "segment_size": 16384,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "c_commit": 100,
20
+ "skip_optimizer": true,
21
+ "freeze_ZH_bert": false,
22
+ "freeze_JP_bert": false,
23
+ "freeze_EN_bert": false,
24
+ "freeze_emo": false,
25
+ "freeze_style": false
26
+ },
27
+ "data": {
28
+ "use_jp_extra": true,
29
+ "training_files": "Data/jvnv-F1-jp/train.list",
30
+ "validation_files": "Data/jvnv-F1-jp/val.list",
31
+ "max_wav_value": 32768.0,
32
+ "sampling_rate": 44100,
33
+ "filter_length": 2048,
34
+ "hop_length": 512,
35
+ "win_length": 2048,
36
+ "n_mel_channels": 128,
37
+ "mel_fmin": 0.0,
38
+ "mel_fmax": null,
39
+ "add_blank": true,
40
+ "n_speakers": 1,
41
+ "cleaned_text": true,
42
+ "spk2id": {
43
+ "jvnv-F1-jp": 0
44
+ },
45
+ "num_styles": 7,
46
+ "style2id": {
47
+ "Neutral": 0,
48
+ "Angry": 1,
49
+ "Disgust": 2,
50
+ "Fear": 3,
51
+ "Happy": 4,
52
+ "Sad": 5,
53
+ "Surprise": 6
54
+ }
55
+ },
56
+ "model": {
57
+ "use_spk_conditioned_encoder": true,
58
+ "use_noise_scaled_mas": true,
59
+ "use_mel_posterior_encoder": false,
60
+ "use_duration_discriminator": false,
61
+ "use_wavlm_discriminator": true,
62
+ "inter_channels": 192,
63
+ "hidden_channels": 192,
64
+ "filter_channels": 768,
65
+ "n_heads": 2,
66
+ "n_layers": 6,
67
+ "kernel_size": 3,
68
+ "p_dropout": 0.1,
69
+ "resblock": "1",
70
+ "resblock_kernel_sizes": [3, 7, 11],
71
+ "resblock_dilation_sizes": [
72
+ [1, 3, 5],
73
+ [1, 3, 5],
74
+ [1, 3, 5]
75
+ ],
76
+ "upsample_rates": [8, 8, 2, 2, 2],
77
+ "upsample_initial_channel": 512,
78
+ "upsample_kernel_sizes": [16, 16, 8, 2, 2],
79
+ "n_layers_q": 3,
80
+ "use_spectral_norm": false,
81
+ "gin_channels": 512,
82
+ "slm": {
83
+ "model": "./slm/wavlm-base-plus",
84
+ "sr": 16000,
85
+ "hidden": 768,
86
+ "nlayers": 13,
87
+ "initial_channel": 64
88
+ }
89
+ },
90
+ "version": "2.0-JP-Extra",
91
+ "model_name": "jvnv-F1-jp"
92
+ }
model_assets/jvnv-F1-jp/jvnv-F1-jp_e182_s16000.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd324b2dd04b4a3384e0dbf4a268fd8a9bbedcfe80608fbd2a3aaaa44e474abe
3
+ size 251150980
model_assets/jvnv-F1-jp/style_vectors.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1f959bb45ed0922efc31ff24e9147253814f42cb1d2d1e2bb10391a9df368489
3
+ size 7296
model_assets/jvnv-F2-jp/config.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 42,
6
+ "epochs": 300,
7
+ "learning_rate": 0.0001,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 4,
11
+ "bf16_run": false,
12
+ "fp16_run": false,
13
+ "lr_decay": 0.99996,
14
+ "segment_size": 16384,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "c_commit": 100,
20
+ "skip_optimizer": true,
21
+ "freeze_ZH_bert": false,
22
+ "freeze_JP_bert": false,
23
+ "freeze_EN_bert": false,
24
+ "freeze_emo": false,
25
+ "freeze_style": false
26
+ },
27
+ "data": {
28
+ "use_jp_extra": true,
29
+ "training_files": "/content/drive/MyDrive/Style-Bert-VITS2/Data/jvnv-F2/train.list",
30
+ "validation_files": "/content/drive/MyDrive/Style-Bert-VITS2/Data/jvnv-F2/val.list",
31
+ "max_wav_value": 32768.0,
32
+ "sampling_rate": 44100,
33
+ "filter_length": 2048,
34
+ "hop_length": 512,
35
+ "win_length": 2048,
36
+ "n_mel_channels": 128,
37
+ "mel_fmin": 0.0,
38
+ "mel_fmax": null,
39
+ "add_blank": true,
40
+ "n_speakers": 1,
41
+ "cleaned_text": true,
42
+ "spk2id": {
43
+ "jvnv-F2-jp": 0
44
+ },
45
+ "num_styles": 7,
46
+ "style2id": {
47
+ "Neutral": 0,
48
+ "Angry": 1,
49
+ "Disgust": 2,
50
+ "Fear": 3,
51
+ "Happy": 4,
52
+ "Sad": 5,
53
+ "Surprise": 6
54
+ }
55
+ },
56
+ "model": {
57
+ "use_spk_conditioned_encoder": true,
58
+ "use_noise_scaled_mas": true,
59
+ "use_mel_posterior_encoder": false,
60
+ "use_duration_discriminator": false,
61
+ "use_wavlm_discriminator": true,
62
+ "inter_channels": 192,
63
+ "hidden_channels": 192,
64
+ "filter_channels": 768,
65
+ "n_heads": 2,
66
+ "n_layers": 6,
67
+ "kernel_size": 3,
68
+ "p_dropout": 0.1,
69
+ "resblock": "1",
70
+ "resblock_kernel_sizes": [3, 7, 11],
71
+ "resblock_dilation_sizes": [
72
+ [1, 3, 5],
73
+ [1, 3, 5],
74
+ [1, 3, 5]
75
+ ],
76
+ "upsample_rates": [8, 8, 2, 2, 2],
77
+ "upsample_initial_channel": 512,
78
+ "upsample_kernel_sizes": [16, 16, 8, 2, 2],
79
+ "n_layers_q": 3,
80
+ "use_spectral_norm": false,
81
+ "gin_channels": 512,
82
+ "slm": {
83
+ "model": "./slm/wavlm-base-plus",
84
+ "sr": 16000,
85
+ "hidden": 768,
86
+ "nlayers": 13,
87
+ "initial_channel": 64
88
+ }
89
+ },
90
+ "version": "2.0-JP-Extra",
91
+ "model_name": "jvnv-F2-jp"
92
+ }
model_assets/jvnv-F2-jp/jvnv-F2_e166_s20000.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6289a6f30bb9795744815b9da764a3c8198b18652d9fddef82fff1e14f0e784
3
+ size 251150980
model_assets/jvnv-F2-jp/style_vectors.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:900f8cde3a336d12193fec7b7d8e6c5dc77b3a5d719a9be3f8598389cd88e643
3
+ size 7296
model_assets/jvnv-M1-jp/config.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 42,
6
+ "epochs": 300,
7
+ "learning_rate": 0.0001,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 4,
11
+ "bf16_run": true,
12
+ "fp16_run": false,
13
+ "lr_decay": 0.99996,
14
+ "segment_size": 16384,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "c_commit": 100,
20
+ "skip_optimizer": true,
21
+ "freeze_ZH_bert": false,
22
+ "freeze_JP_bert": false,
23
+ "freeze_EN_bert": false,
24
+ "freeze_emo": false,
25
+ "freeze_style": false
26
+ },
27
+ "data": {
28
+ "use_jp_extra": true,
29
+ "training_files": "Data/jvnv-M1-jp/train.list",
30
+ "validation_files": "Data/jvnv-M1-jp/val.list",
31
+ "max_wav_value": 32768.0,
32
+ "sampling_rate": 44100,
33
+ "filter_length": 2048,
34
+ "hop_length": 512,
35
+ "win_length": 2048,
36
+ "n_mel_channels": 128,
37
+ "mel_fmin": 0.0,
38
+ "mel_fmax": null,
39
+ "add_blank": true,
40
+ "n_speakers": 1,
41
+ "cleaned_text": true,
42
+ "spk2id": {
43
+ "jvnv-M1-jp": 0
44
+ },
45
+ "num_styles": 7,
46
+ "style2id": {
47
+ "Neutral": 0,
48
+ "Angry": 1,
49
+ "Disgust": 2,
50
+ "Fear": 3,
51
+ "Happy": 4,
52
+ "Sad": 5,
53
+ "Surprise": 6
54
+ }
55
+ },
56
+ "model": {
57
+ "use_spk_conditioned_encoder": true,
58
+ "use_noise_scaled_mas": true,
59
+ "use_mel_posterior_encoder": false,
60
+ "use_duration_discriminator": false,
61
+ "use_wavlm_discriminator": true,
62
+ "inter_channels": 192,
63
+ "hidden_channels": 192,
64
+ "filter_channels": 768,
65
+ "n_heads": 2,
66
+ "n_layers": 6,
67
+ "kernel_size": 3,
68
+ "p_dropout": 0.1,
69
+ "resblock": "1",
70
+ "resblock_kernel_sizes": [3, 7, 11],
71
+ "resblock_dilation_sizes": [
72
+ [1, 3, 5],
73
+ [1, 3, 5],
74
+ [1, 3, 5]
75
+ ],
76
+ "upsample_rates": [8, 8, 2, 2, 2],
77
+ "upsample_initial_channel": 512,
78
+ "upsample_kernel_sizes": [16, 16, 8, 2, 2],
79
+ "n_layers_q": 3,
80
+ "use_spectral_norm": false,
81
+ "gin_channels": 512,
82
+ "slm": {
83
+ "model": "./slm/wavlm-base-plus",
84
+ "sr": 16000,
85
+ "hidden": 768,
86
+ "nlayers": 13,
87
+ "initial_channel": 64
88
+ }
89
+ },
90
+ "version": "2.0-JP-Extra",
91
+ "model_name": "jvnv-M1-jp"
92
+ }
model_assets/jvnv-M1-jp/jvnv-M1-jp_e158_s14000.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d86765f1fe08dbba74cd06283e96b6941b3f232329fabbba9c30e6edc27887a
3
+ size 251150980
model_assets/jvnv-M1-jp/style_vectors.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a925435e8c1c9efc8fc8e90e690655ab9a7bae00a790892e13e936510d04f05
3
+ size 7296
model_assets/jvnv-M2-jp/config.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 42,
6
+ "epochs": 300,
7
+ "learning_rate": 0.0001,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 4,
11
+ "bf16_run": true,
12
+ "fp16_run": false,
13
+ "lr_decay": 0.99996,
14
+ "segment_size": 16384,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "c_commit": 100,
20
+ "skip_optimizer": true,
21
+ "freeze_ZH_bert": false,
22
+ "freeze_JP_bert": false,
23
+ "freeze_EN_bert": false,
24
+ "freeze_emo": false,
25
+ "freeze_style": false
26
+ },
27
+ "data": {
28
+ "use_jp_extra": true,
29
+ "training_files": "Data/jvnv-M2-jp/train.list",
30
+ "validation_files": "Data/jvnv-M2-jp/val.list",
31
+ "max_wav_value": 32768.0,
32
+ "sampling_rate": 44100,
33
+ "filter_length": 2048,
34
+ "hop_length": 512,
35
+ "win_length": 2048,
36
+ "n_mel_channels": 128,
37
+ "mel_fmin": 0.0,
38
+ "mel_fmax": null,
39
+ "add_blank": true,
40
+ "n_speakers": 1,
41
+ "cleaned_text": true,
42
+ "spk2id": {
43
+ "jvnv-M2-jp": 0
44
+ },
45
+ "num_styles": 7,
46
+ "style2id": {
47
+ "Neutral": 0,
48
+ "Angry": 1,
49
+ "Disgust": 2,
50
+ "Fear": 3,
51
+ "Happy": 4,
52
+ "Sad": 5,
53
+ "Surprise": 6
54
+ }
55
+ },
56
+ "model": {
57
+ "use_spk_conditioned_encoder": true,
58
+ "use_noise_scaled_mas": true,
59
+ "use_mel_posterior_encoder": false,
60
+ "use_duration_discriminator": false,
61
+ "use_wavlm_discriminator": true,
62
+ "inter_channels": 192,
63
+ "hidden_channels": 192,
64
+ "filter_channels": 768,
65
+ "n_heads": 2,
66
+ "n_layers": 6,
67
+ "kernel_size": 3,
68
+ "p_dropout": 0.1,
69
+ "resblock": "1",
70
+ "resblock_kernel_sizes": [3, 7, 11],
71
+ "resblock_dilation_sizes": [
72
+ [1, 3, 5],
73
+ [1, 3, 5],
74
+ [1, 3, 5]
75
+ ],
76
+ "upsample_rates": [8, 8, 2, 2, 2],
77
+ "upsample_initial_channel": 512,
78
+ "upsample_kernel_sizes": [16, 16, 8, 2, 2],
79
+ "n_layers_q": 3,
80
+ "use_spectral_norm": false,
81
+ "gin_channels": 512,
82
+ "slm": {
83
+ "model": "./slm/wavlm-base-plus",
84
+ "sr": 16000,
85
+ "hidden": 768,
86
+ "nlayers": 13,
87
+ "initial_channel": 64
88
+ }
89
+ },
90
+ "version": "2.0-JP-Extra",
91
+ "model_name": "jvnv-M2-jp"
92
+ }
model_assets/jvnv-M2-jp/jvnv-M2-jp_e159_s17000.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8245f39438076d36a3befd8aefb15c38830cef326c1f7c9d9c8e64b647645402
3
+ size 251150980
model_assets/jvnv-M2-jp/style_vectors.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c965bb63fa4a759d41a8a4a3649333125d6497ae8a705d81b7d5c5bd2854797c
3
+ size 7296
models_jp_extra.py ADDED
@@ -0,0 +1,1071 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import modules
8
+ import attentions
9
+ import monotonic_align
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+
14
+ from commons import init_weights, get_padding
15
+ from text import symbols, num_tones, num_languages
16
+
17
+
18
+ class DurationDiscriminator(nn.Module): # vits2
19
+ def __init__(
20
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
21
+ ):
22
+ super().__init__()
23
+
24
+ self.in_channels = in_channels
25
+ self.filter_channels = filter_channels
26
+ self.kernel_size = kernel_size
27
+ self.p_dropout = p_dropout
28
+ self.gin_channels = gin_channels
29
+
30
+ self.drop = nn.Dropout(p_dropout)
31
+ self.conv_1 = nn.Conv1d(
32
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
33
+ )
34
+ self.norm_1 = modules.LayerNorm(filter_channels)
35
+ self.conv_2 = nn.Conv1d(
36
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
37
+ )
38
+ self.norm_2 = modules.LayerNorm(filter_channels)
39
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
40
+
41
+ self.LSTM = nn.LSTM(
42
+ 2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
43
+ )
44
+
45
+ if gin_channels != 0:
46
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
47
+
48
+ self.output_layer = nn.Sequential(
49
+ nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
50
+ )
51
+
52
+ def forward_probability(self, x, dur):
53
+ dur = self.dur_proj(dur)
54
+ x = torch.cat([x, dur], dim=1)
55
+ x = x.transpose(1, 2)
56
+ x, _ = self.LSTM(x)
57
+ output_prob = self.output_layer(x)
58
+ return output_prob
59
+
60
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
61
+ x = torch.detach(x)
62
+ if g is not None:
63
+ g = torch.detach(g)
64
+ x = x + self.cond(g)
65
+ x = self.conv_1(x * x_mask)
66
+ x = torch.relu(x)
67
+ x = self.norm_1(x)
68
+ x = self.drop(x)
69
+ x = self.conv_2(x * x_mask)
70
+ x = torch.relu(x)
71
+ x = self.norm_2(x)
72
+ x = self.drop(x)
73
+
74
+ output_probs = []
75
+ for dur in [dur_r, dur_hat]:
76
+ output_prob = self.forward_probability(x, dur)
77
+ output_probs.append(output_prob)
78
+
79
+ return output_probs
80
+
81
+
82
+ class TransformerCouplingBlock(nn.Module):
83
+ def __init__(
84
+ self,
85
+ channels,
86
+ hidden_channels,
87
+ filter_channels,
88
+ n_heads,
89
+ n_layers,
90
+ kernel_size,
91
+ p_dropout,
92
+ n_flows=4,
93
+ gin_channels=0,
94
+ share_parameter=False,
95
+ ):
96
+ super().__init__()
97
+ self.channels = channels
98
+ self.hidden_channels = hidden_channels
99
+ self.kernel_size = kernel_size
100
+ self.n_layers = n_layers
101
+ self.n_flows = n_flows
102
+ self.gin_channels = gin_channels
103
+
104
+ self.flows = nn.ModuleList()
105
+
106
+ self.wn = (
107
+ attentions.FFT(
108
+ hidden_channels,
109
+ filter_channels,
110
+ n_heads,
111
+ n_layers,
112
+ kernel_size,
113
+ p_dropout,
114
+ isflow=True,
115
+ gin_channels=self.gin_channels,
116
+ )
117
+ if share_parameter
118
+ else None
119
+ )
120
+
121
+ for i in range(n_flows):
122
+ self.flows.append(
123
+ modules.TransformerCouplingLayer(
124
+ channels,
125
+ hidden_channels,
126
+ kernel_size,
127
+ n_layers,
128
+ n_heads,
129
+ p_dropout,
130
+ filter_channels,
131
+ mean_only=True,
132
+ wn_sharing_parameter=self.wn,
133
+ gin_channels=self.gin_channels,
134
+ )
135
+ )
136
+ self.flows.append(modules.Flip())
137
+
138
+ def forward(self, x, x_mask, g=None, reverse=False):
139
+ if not reverse:
140
+ for flow in self.flows:
141
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
142
+ else:
143
+ for flow in reversed(self.flows):
144
+ x = flow(x, x_mask, g=g, reverse=reverse)
145
+ return x
146
+
147
+
148
+ class StochasticDurationPredictor(nn.Module):
149
+ def __init__(
150
+ self,
151
+ in_channels,
152
+ filter_channels,
153
+ kernel_size,
154
+ p_dropout,
155
+ n_flows=4,
156
+ gin_channels=0,
157
+ ):
158
+ super().__init__()
159
+ filter_channels = in_channels # it needs to be removed from future version.
160
+ self.in_channels = in_channels
161
+ self.filter_channels = filter_channels
162
+ self.kernel_size = kernel_size
163
+ self.p_dropout = p_dropout
164
+ self.n_flows = n_flows
165
+ self.gin_channels = gin_channels
166
+
167
+ self.log_flow = modules.Log()
168
+ self.flows = nn.ModuleList()
169
+ self.flows.append(modules.ElementwiseAffine(2))
170
+ for i in range(n_flows):
171
+ self.flows.append(
172
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
173
+ )
174
+ self.flows.append(modules.Flip())
175
+
176
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
177
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
178
+ self.post_convs = modules.DDSConv(
179
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
180
+ )
181
+ self.post_flows = nn.ModuleList()
182
+ self.post_flows.append(modules.ElementwiseAffine(2))
183
+ for i in range(4):
184
+ self.post_flows.append(
185
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
186
+ )
187
+ self.post_flows.append(modules.Flip())
188
+
189
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
190
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
191
+ self.convs = modules.DDSConv(
192
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
193
+ )
194
+ if gin_channels != 0:
195
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
196
+
197
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
198
+ x = torch.detach(x)
199
+ x = self.pre(x)
200
+ if g is not None:
201
+ g = torch.detach(g)
202
+ x = x + self.cond(g)
203
+ x = self.convs(x, x_mask)
204
+ x = self.proj(x) * x_mask
205
+
206
+ if not reverse:
207
+ flows = self.flows
208
+ assert w is not None
209
+
210
+ logdet_tot_q = 0
211
+ h_w = self.post_pre(w)
212
+ h_w = self.post_convs(h_w, x_mask)
213
+ h_w = self.post_proj(h_w) * x_mask
214
+ e_q = (
215
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
216
+ * x_mask
217
+ )
218
+ z_q = e_q
219
+ for flow in self.post_flows:
220
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
221
+ logdet_tot_q += logdet_q
222
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
223
+ u = torch.sigmoid(z_u) * x_mask
224
+ z0 = (w - u) * x_mask
225
+ logdet_tot_q += torch.sum(
226
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
227
+ )
228
+ logq = (
229
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
230
+ - logdet_tot_q
231
+ )
232
+
233
+ logdet_tot = 0
234
+ z0, logdet = self.log_flow(z0, x_mask)
235
+ logdet_tot += logdet
236
+ z = torch.cat([z0, z1], 1)
237
+ for flow in flows:
238
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
239
+ logdet_tot = logdet_tot + logdet
240
+ nll = (
241
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
242
+ - logdet_tot
243
+ )
244
+ return nll + logq # [b]
245
+ else:
246
+ flows = list(reversed(self.flows))
247
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
248
+ z = (
249
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
250
+ * noise_scale
251
+ )
252
+ for flow in flows:
253
+ z = flow(z, x_mask, g=x, reverse=reverse)
254
+ z0, z1 = torch.split(z, [1, 1], 1)
255
+ logw = z0
256
+ return logw
257
+
258
+
259
+ class DurationPredictor(nn.Module):
260
+ def __init__(
261
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
262
+ ):
263
+ super().__init__()
264
+
265
+ self.in_channels = in_channels
266
+ self.filter_channels = filter_channels
267
+ self.kernel_size = kernel_size
268
+ self.p_dropout = p_dropout
269
+ self.gin_channels = gin_channels
270
+
271
+ self.drop = nn.Dropout(p_dropout)
272
+ self.conv_1 = nn.Conv1d(
273
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
274
+ )
275
+ self.norm_1 = modules.LayerNorm(filter_channels)
276
+ self.conv_2 = nn.Conv1d(
277
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
278
+ )
279
+ self.norm_2 = modules.LayerNorm(filter_channels)
280
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
281
+
282
+ if gin_channels != 0:
283
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
284
+
285
+ def forward(self, x, x_mask, g=None):
286
+ x = torch.detach(x)
287
+ if g is not None:
288
+ g = torch.detach(g)
289
+ x = x + self.cond(g)
290
+ x = self.conv_1(x * x_mask)
291
+ x = torch.relu(x)
292
+ x = self.norm_1(x)
293
+ x = self.drop(x)
294
+ x = self.conv_2(x * x_mask)
295
+ x = torch.relu(x)
296
+ x = self.norm_2(x)
297
+ x = self.drop(x)
298
+ x = self.proj(x * x_mask)
299
+ return x * x_mask
300
+
301
+
302
+ class Bottleneck(nn.Sequential):
303
+ def __init__(self, in_dim, hidden_dim):
304
+ c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
305
+ c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
306
+ super().__init__(*[c_fc1, c_fc2])
307
+
308
+
309
+ class Block(nn.Module):
310
+ def __init__(self, in_dim, hidden_dim) -> None:
311
+ super().__init__()
312
+ self.norm = nn.LayerNorm(in_dim)
313
+ self.mlp = MLP(in_dim, hidden_dim)
314
+
315
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
316
+ x = x + self.mlp(self.norm(x))
317
+ return x
318
+
319
+
320
+ class MLP(nn.Module):
321
+ def __init__(self, in_dim, hidden_dim):
322
+ super().__init__()
323
+ self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
324
+ self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
325
+ self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
326
+
327
+ def forward(self, x: torch.Tensor):
328
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
329
+ x = self.c_proj(x)
330
+ return x
331
+
332
+
333
+ class TextEncoder(nn.Module):
334
+ def __init__(
335
+ self,
336
+ n_vocab,
337
+ out_channels,
338
+ hidden_channels,
339
+ filter_channels,
340
+ n_heads,
341
+ n_layers,
342
+ kernel_size,
343
+ p_dropout,
344
+ gin_channels=0,
345
+ ):
346
+ super().__init__()
347
+ self.n_vocab = n_vocab
348
+ self.out_channels = out_channels
349
+ self.hidden_channels = hidden_channels
350
+ self.filter_channels = filter_channels
351
+ self.n_heads = n_heads
352
+ self.n_layers = n_layers
353
+ self.kernel_size = kernel_size
354
+ self.p_dropout = p_dropout
355
+ self.gin_channels = gin_channels
356
+ self.emb = nn.Embedding(len(symbols), hidden_channels)
357
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
358
+ self.tone_emb = nn.Embedding(num_tones, hidden_channels)
359
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
360
+ self.language_emb = nn.Embedding(num_languages, hidden_channels)
361
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
362
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
363
+
364
+ # Remove emo_vq since it's not working well.
365
+ self.style_proj = nn.Linear(256, hidden_channels)
366
+
367
+ self.encoder = attentions.Encoder(
368
+ hidden_channels,
369
+ filter_channels,
370
+ n_heads,
371
+ n_layers,
372
+ kernel_size,
373
+ p_dropout,
374
+ gin_channels=self.gin_channels,
375
+ )
376
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
377
+
378
+ def forward(self, x, x_lengths, tone, language, bert, style_vec, g=None):
379
+ bert_emb = self.bert_proj(bert).transpose(1, 2)
380
+ style_emb = self.style_proj(style_vec.unsqueeze(1))
381
+ x = (
382
+ self.emb(x)
383
+ + self.tone_emb(tone)
384
+ + self.language_emb(language)
385
+ + bert_emb
386
+ + style_emb
387
+ ) * math.sqrt(
388
+ self.hidden_channels
389
+ ) # [b, t, h]
390
+ x = torch.transpose(x, 1, -1) # [b, h, t]
391
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
392
+ x.dtype
393
+ )
394
+
395
+ x = self.encoder(x * x_mask, x_mask, g=g)
396
+ stats = self.proj(x) * x_mask
397
+
398
+ m, logs = torch.split(stats, self.out_channels, dim=1)
399
+ return x, m, logs, x_mask
400
+
401
+
402
+ class ResidualCouplingBlock(nn.Module):
403
+ def __init__(
404
+ self,
405
+ channels,
406
+ hidden_channels,
407
+ kernel_size,
408
+ dilation_rate,
409
+ n_layers,
410
+ n_flows=4,
411
+ gin_channels=0,
412
+ ):
413
+ super().__init__()
414
+ self.channels = channels
415
+ self.hidden_channels = hidden_channels
416
+ self.kernel_size = kernel_size
417
+ self.dilation_rate = dilation_rate
418
+ self.n_layers = n_layers
419
+ self.n_flows = n_flows
420
+ self.gin_channels = gin_channels
421
+
422
+ self.flows = nn.ModuleList()
423
+ for i in range(n_flows):
424
+ self.flows.append(
425
+ modules.ResidualCouplingLayer(
426
+ channels,
427
+ hidden_channels,
428
+ kernel_size,
429
+ dilation_rate,
430
+ n_layers,
431
+ gin_channels=gin_channels,
432
+ mean_only=True,
433
+ )
434
+ )
435
+ self.flows.append(modules.Flip())
436
+
437
+ def forward(self, x, x_mask, g=None, reverse=False):
438
+ if not reverse:
439
+ for flow in self.flows:
440
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
441
+ else:
442
+ for flow in reversed(self.flows):
443
+ x = flow(x, x_mask, g=g, reverse=reverse)
444
+ return x
445
+
446
+
447
+ class PosteriorEncoder(nn.Module):
448
+ def __init__(
449
+ self,
450
+ in_channels,
451
+ out_channels,
452
+ hidden_channels,
453
+ kernel_size,
454
+ dilation_rate,
455
+ n_layers,
456
+ gin_channels=0,
457
+ ):
458
+ super().__init__()
459
+ self.in_channels = in_channels
460
+ self.out_channels = out_channels
461
+ self.hidden_channels = hidden_channels
462
+ self.kernel_size = kernel_size
463
+ self.dilation_rate = dilation_rate
464
+ self.n_layers = n_layers
465
+ self.gin_channels = gin_channels
466
+
467
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
468
+ self.enc = modules.WN(
469
+ hidden_channels,
470
+ kernel_size,
471
+ dilation_rate,
472
+ n_layers,
473
+ gin_channels=gin_channels,
474
+ )
475
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
476
+
477
+ def forward(self, x, x_lengths, g=None):
478
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
479
+ x.dtype
480
+ )
481
+ x = self.pre(x) * x_mask
482
+ x = self.enc(x, x_mask, g=g)
483
+ stats = self.proj(x) * x_mask
484
+ m, logs = torch.split(stats, self.out_channels, dim=1)
485
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
486
+ return z, m, logs, x_mask
487
+
488
+
489
+ class Generator(torch.nn.Module):
490
+ def __init__(
491
+ self,
492
+ initial_channel,
493
+ resblock,
494
+ resblock_kernel_sizes,
495
+ resblock_dilation_sizes,
496
+ upsample_rates,
497
+ upsample_initial_channel,
498
+ upsample_kernel_sizes,
499
+ gin_channels=0,
500
+ ):
501
+ super(Generator, self).__init__()
502
+ self.num_kernels = len(resblock_kernel_sizes)
503
+ self.num_upsamples = len(upsample_rates)
504
+ self.conv_pre = Conv1d(
505
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
506
+ )
507
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
508
+
509
+ self.ups = nn.ModuleList()
510
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
511
+ self.ups.append(
512
+ weight_norm(
513
+ ConvTranspose1d(
514
+ upsample_initial_channel // (2**i),
515
+ upsample_initial_channel // (2 ** (i + 1)),
516
+ k,
517
+ u,
518
+ padding=(k - u) // 2,
519
+ )
520
+ )
521
+ )
522
+
523
+ self.resblocks = nn.ModuleList()
524
+ for i in range(len(self.ups)):
525
+ ch = upsample_initial_channel // (2 ** (i + 1))
526
+ for j, (k, d) in enumerate(
527
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
528
+ ):
529
+ self.resblocks.append(resblock(ch, k, d))
530
+
531
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
532
+ self.ups.apply(init_weights)
533
+
534
+ if gin_channels != 0:
535
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
536
+
537
+ def forward(self, x, g=None):
538
+ x = self.conv_pre(x)
539
+ if g is not None:
540
+ x = x + self.cond(g)
541
+
542
+ for i in range(self.num_upsamples):
543
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
544
+ x = self.ups[i](x)
545
+ xs = None
546
+ for j in range(self.num_kernels):
547
+ if xs is None:
548
+ xs = self.resblocks[i * self.num_kernels + j](x)
549
+ else:
550
+ xs += self.resblocks[i * self.num_kernels + j](x)
551
+ x = xs / self.num_kernels
552
+ x = F.leaky_relu(x)
553
+ x = self.conv_post(x)
554
+ x = torch.tanh(x)
555
+
556
+ return x
557
+
558
+ def remove_weight_norm(self):
559
+ print("Removing weight norm...")
560
+ for layer in self.ups:
561
+ remove_weight_norm(layer)
562
+ for layer in self.resblocks:
563
+ layer.remove_weight_norm()
564
+
565
+
566
+ class DiscriminatorP(torch.nn.Module):
567
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
568
+ super(DiscriminatorP, self).__init__()
569
+ self.period = period
570
+ self.use_spectral_norm = use_spectral_norm
571
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
572
+ self.convs = nn.ModuleList(
573
+ [
574
+ norm_f(
575
+ Conv2d(
576
+ 1,
577
+ 32,
578
+ (kernel_size, 1),
579
+ (stride, 1),
580
+ padding=(get_padding(kernel_size, 1), 0),
581
+ )
582
+ ),
583
+ norm_f(
584
+ Conv2d(
585
+ 32,
586
+ 128,
587
+ (kernel_size, 1),
588
+ (stride, 1),
589
+ padding=(get_padding(kernel_size, 1), 0),
590
+ )
591
+ ),
592
+ norm_f(
593
+ Conv2d(
594
+ 128,
595
+ 512,
596
+ (kernel_size, 1),
597
+ (stride, 1),
598
+ padding=(get_padding(kernel_size, 1), 0),
599
+ )
600
+ ),
601
+ norm_f(
602
+ Conv2d(
603
+ 512,
604
+ 1024,
605
+ (kernel_size, 1),
606
+ (stride, 1),
607
+ padding=(get_padding(kernel_size, 1), 0),
608
+ )
609
+ ),
610
+ norm_f(
611
+ Conv2d(
612
+ 1024,
613
+ 1024,
614
+ (kernel_size, 1),
615
+ 1,
616
+ padding=(get_padding(kernel_size, 1), 0),
617
+ )
618
+ ),
619
+ ]
620
+ )
621
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
622
+
623
+ def forward(self, x):
624
+ fmap = []
625
+
626
+ # 1d to 2d
627
+ b, c, t = x.shape
628
+ if t % self.period != 0: # pad first
629
+ n_pad = self.period - (t % self.period)
630
+ x = F.pad(x, (0, n_pad), "reflect")
631
+ t = t + n_pad
632
+ x = x.view(b, c, t // self.period, self.period)
633
+
634
+ for layer in self.convs:
635
+ x = layer(x)
636
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
637
+ fmap.append(x)
638
+ x = self.conv_post(x)
639
+ fmap.append(x)
640
+ x = torch.flatten(x, 1, -1)
641
+
642
+ return x, fmap
643
+
644
+
645
+ class DiscriminatorS(torch.nn.Module):
646
+ def __init__(self, use_spectral_norm=False):
647
+ super(DiscriminatorS, self).__init__()
648
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
649
+ self.convs = nn.ModuleList(
650
+ [
651
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
652
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
653
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
654
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
655
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
656
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
657
+ ]
658
+ )
659
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
660
+
661
+ def forward(self, x):
662
+ fmap = []
663
+
664
+ for layer in self.convs:
665
+ x = layer(x)
666
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
667
+ fmap.append(x)
668
+ x = self.conv_post(x)
669
+ fmap.append(x)
670
+ x = torch.flatten(x, 1, -1)
671
+
672
+ return x, fmap
673
+
674
+
675
+ class MultiPeriodDiscriminator(torch.nn.Module):
676
+ def __init__(self, use_spectral_norm=False):
677
+ super(MultiPeriodDiscriminator, self).__init__()
678
+ periods = [2, 3, 5, 7, 11]
679
+
680
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
681
+ discs = discs + [
682
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
683
+ ]
684
+ self.discriminators = nn.ModuleList(discs)
685
+
686
+ def forward(self, y, y_hat):
687
+ y_d_rs = []
688
+ y_d_gs = []
689
+ fmap_rs = []
690
+ fmap_gs = []
691
+ for i, d in enumerate(self.discriminators):
692
+ y_d_r, fmap_r = d(y)
693
+ y_d_g, fmap_g = d(y_hat)
694
+ y_d_rs.append(y_d_r)
695
+ y_d_gs.append(y_d_g)
696
+ fmap_rs.append(fmap_r)
697
+ fmap_gs.append(fmap_g)
698
+
699
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
700
+
701
+
702
+ class WavLMDiscriminator(nn.Module):
703
+ """docstring for Discriminator."""
704
+
705
+ def __init__(
706
+ self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
707
+ ):
708
+ super(WavLMDiscriminator, self).__init__()
709
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
710
+ self.pre = norm_f(
711
+ Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
712
+ )
713
+
714
+ self.convs = nn.ModuleList(
715
+ [
716
+ norm_f(
717
+ nn.Conv1d(
718
+ initial_channel, initial_channel * 2, kernel_size=5, padding=2
719
+ )
720
+ ),
721
+ norm_f(
722
+ nn.Conv1d(
723
+ initial_channel * 2,
724
+ initial_channel * 4,
725
+ kernel_size=5,
726
+ padding=2,
727
+ )
728
+ ),
729
+ norm_f(
730
+ nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
731
+ ),
732
+ ]
733
+ )
734
+
735
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
736
+
737
+ def forward(self, x):
738
+ x = self.pre(x)
739
+
740
+ fmap = []
741
+ for l in self.convs:
742
+ x = l(x)
743
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
744
+ fmap.append(x)
745
+ x = self.conv_post(x)
746
+ x = torch.flatten(x, 1, -1)
747
+
748
+ return x
749
+
750
+
751
+ class ReferenceEncoder(nn.Module):
752
+ """
753
+ inputs --- [N, Ty/r, n_mels*r] mels
754
+ outputs --- [N, ref_enc_gru_size]
755
+ """
756
+
757
+ def __init__(self, spec_channels, gin_channels=0):
758
+ super().__init__()
759
+ self.spec_channels = spec_channels
760
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
761
+ K = len(ref_enc_filters)
762
+ filters = [1] + ref_enc_filters
763
+ convs = [
764
+ weight_norm(
765
+ nn.Conv2d(
766
+ in_channels=filters[i],
767
+ out_channels=filters[i + 1],
768
+ kernel_size=(3, 3),
769
+ stride=(2, 2),
770
+ padding=(1, 1),
771
+ )
772
+ )
773
+ for i in range(K)
774
+ ]
775
+ self.convs = nn.ModuleList(convs)
776
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
777
+
778
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
779
+ self.gru = nn.GRU(
780
+ input_size=ref_enc_filters[-1] * out_channels,
781
+ hidden_size=256 // 2,
782
+ batch_first=True,
783
+ )
784
+ self.proj = nn.Linear(128, gin_channels)
785
+
786
+ def forward(self, inputs, mask=None):
787
+ N = inputs.size(0)
788
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
789
+ for conv in self.convs:
790
+ out = conv(out)
791
+ # out = wn(out)
792
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
793
+
794
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
795
+ T = out.size(1)
796
+ N = out.size(0)
797
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
798
+
799
+ self.gru.flatten_parameters()
800
+ memory, out = self.gru(out) # out --- [1, N, 128]
801
+
802
+ return self.proj(out.squeeze(0))
803
+
804
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
805
+ for i in range(n_convs):
806
+ L = (L - kernel_size + 2 * pad) // stride + 1
807
+ return L
808
+
809
+
810
+ class SynthesizerTrn(nn.Module):
811
+ """
812
+ Synthesizer for Training
813
+ """
814
+
815
+ def __init__(
816
+ self,
817
+ n_vocab,
818
+ spec_channels,
819
+ segment_size,
820
+ inter_channels,
821
+ hidden_channels,
822
+ filter_channels,
823
+ n_heads,
824
+ n_layers,
825
+ kernel_size,
826
+ p_dropout,
827
+ resblock,
828
+ resblock_kernel_sizes,
829
+ resblock_dilation_sizes,
830
+ upsample_rates,
831
+ upsample_initial_channel,
832
+ upsample_kernel_sizes,
833
+ n_speakers=256,
834
+ gin_channels=256,
835
+ use_sdp=True,
836
+ n_flow_layer=4,
837
+ n_layers_trans_flow=6,
838
+ flow_share_parameter=False,
839
+ use_transformer_flow=True,
840
+ **kwargs
841
+ ):
842
+ super().__init__()
843
+ self.n_vocab = n_vocab
844
+ self.spec_channels = spec_channels
845
+ self.inter_channels = inter_channels
846
+ self.hidden_channels = hidden_channels
847
+ self.filter_channels = filter_channels
848
+ self.n_heads = n_heads
849
+ self.n_layers = n_layers
850
+ self.kernel_size = kernel_size
851
+ self.p_dropout = p_dropout
852
+ self.resblock = resblock
853
+ self.resblock_kernel_sizes = resblock_kernel_sizes
854
+ self.resblock_dilation_sizes = resblock_dilation_sizes
855
+ self.upsample_rates = upsample_rates
856
+ self.upsample_initial_channel = upsample_initial_channel
857
+ self.upsample_kernel_sizes = upsample_kernel_sizes
858
+ self.segment_size = segment_size
859
+ self.n_speakers = n_speakers
860
+ self.gin_channels = gin_channels
861
+ self.n_layers_trans_flow = n_layers_trans_flow
862
+ self.use_spk_conditioned_encoder = kwargs.get(
863
+ "use_spk_conditioned_encoder", True
864
+ )
865
+ self.use_sdp = use_sdp
866
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
867
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
868
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
869
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
870
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
871
+ self.enc_gin_channels = gin_channels
872
+ self.enc_p = TextEncoder(
873
+ n_vocab,
874
+ inter_channels,
875
+ hidden_channels,
876
+ filter_channels,
877
+ n_heads,
878
+ n_layers,
879
+ kernel_size,
880
+ p_dropout,
881
+ gin_channels=self.enc_gin_channels,
882
+ )
883
+ self.dec = Generator(
884
+ inter_channels,
885
+ resblock,
886
+ resblock_kernel_sizes,
887
+ resblock_dilation_sizes,
888
+ upsample_rates,
889
+ upsample_initial_channel,
890
+ upsample_kernel_sizes,
891
+ gin_channels=gin_channels,
892
+ )
893
+ self.enc_q = PosteriorEncoder(
894
+ spec_channels,
895
+ inter_channels,
896
+ hidden_channels,
897
+ 5,
898
+ 1,
899
+ 16,
900
+ gin_channels=gin_channels,
901
+ )
902
+ if use_transformer_flow:
903
+ self.flow = TransformerCouplingBlock(
904
+ inter_channels,
905
+ hidden_channels,
906
+ filter_channels,
907
+ n_heads,
908
+ n_layers_trans_flow,
909
+ 5,
910
+ p_dropout,
911
+ n_flow_layer,
912
+ gin_channels=gin_channels,
913
+ share_parameter=flow_share_parameter,
914
+ )
915
+ else:
916
+ self.flow = ResidualCouplingBlock(
917
+ inter_channels,
918
+ hidden_channels,
919
+ 5,
920
+ 1,
921
+ n_flow_layer,
922
+ gin_channels=gin_channels,
923
+ )
924
+ self.sdp = StochasticDurationPredictor(
925
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
926
+ )
927
+ self.dp = DurationPredictor(
928
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
929
+ )
930
+
931
+ if n_speakers >= 1:
932
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
933
+ else:
934
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
935
+
936
+ def forward(
937
+ self,
938
+ x,
939
+ x_lengths,
940
+ y,
941
+ y_lengths,
942
+ sid,
943
+ tone,
944
+ language,
945
+ bert,
946
+ style_vec,
947
+ ):
948
+ if self.n_speakers > 0:
949
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
950
+ else:
951
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
952
+ x, m_p, logs_p, x_mask = self.enc_p(
953
+ x, x_lengths, tone, language, bert, style_vec, g=g
954
+ )
955
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
956
+ z_p = self.flow(z, y_mask, g=g)
957
+
958
+ with torch.no_grad():
959
+ # negative cross-entropy
960
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
961
+ neg_cent1 = torch.sum(
962
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
963
+ ) # [b, 1, t_s]
964
+ neg_cent2 = torch.matmul(
965
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
966
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
967
+ neg_cent3 = torch.matmul(
968
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
969
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
970
+ neg_cent4 = torch.sum(
971
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
972
+ ) # [b, 1, t_s]
973
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
974
+ if self.use_noise_scaled_mas:
975
+ epsilon = (
976
+ torch.std(neg_cent)
977
+ * torch.randn_like(neg_cent)
978
+ * self.current_mas_noise_scale
979
+ )
980
+ neg_cent = neg_cent + epsilon
981
+
982
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
983
+ attn = (
984
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
985
+ .unsqueeze(1)
986
+ .detach()
987
+ )
988
+
989
+ w = attn.sum(2)
990
+
991
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
992
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
993
+
994
+ logw_ = torch.log(w + 1e-6) * x_mask
995
+ logw = self.dp(x, x_mask, g=g)
996
+ # logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
997
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
998
+ x_mask
999
+ ) # for averaging
1000
+ # l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
1001
+
1002
+ l_length = l_length_dp + l_length_sdp
1003
+
1004
+ # expand prior
1005
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
1006
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
1007
+
1008
+ z_slice, ids_slice = commons.rand_slice_segments(
1009
+ z, y_lengths, self.segment_size
1010
+ )
1011
+ o = self.dec(z_slice, g=g)
1012
+ return (
1013
+ o,
1014
+ l_length,
1015
+ attn,
1016
+ ids_slice,
1017
+ x_mask,
1018
+ y_mask,
1019
+ (z, z_p, m_p, logs_p, m_q, logs_q),
1020
+ (x, logw, logw_), # , logw_sdp),
1021
+ g,
1022
+ )
1023
+
1024
+ def infer(
1025
+ self,
1026
+ x,
1027
+ x_lengths,
1028
+ sid,
1029
+ tone,
1030
+ language,
1031
+ bert,
1032
+ style_vec,
1033
+ noise_scale=0.667,
1034
+ length_scale=1,
1035
+ noise_scale_w=0.8,
1036
+ max_len=None,
1037
+ sdp_ratio=0,
1038
+ y=None,
1039
+ ):
1040
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
1041
+ # g = self.gst(y)
1042
+ if self.n_speakers > 0:
1043
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1044
+ else:
1045
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1046
+ x, m_p, logs_p, x_mask = self.enc_p(
1047
+ x, x_lengths, tone, language, bert, style_vec, g=g
1048
+ )
1049
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1050
+ sdp_ratio
1051
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
1052
+ w = torch.exp(logw) * x_mask * length_scale
1053
+ w_ceil = torch.ceil(w)
1054
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1055
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1056
+ x_mask.dtype
1057
+ )
1058
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1059
+ attn = commons.generate_path(w_ceil, attn_mask)
1060
+
1061
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1062
+ 1, 2
1063
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1064
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1065
+ 1, 2
1066
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1067
+
1068
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1069
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1070
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
1071
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
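For reference, the duration handling in `SynthesizerTrn.infer` above is a plain interpolation between the stochastic (`sdp`) and deterministic (`dp`) duration predictors, controlled by `sdp_ratio`. A minimal sketch of that blend with made-up per-phoneme log-durations (illustrative only, not part of the committed files):

import torch

# Hypothetical per-phoneme log-durations from the two predictors
logw_sdp = torch.tensor([0.2, 0.5, 0.1])  # stochastic duration predictor (sdp)
logw_dp = torch.tensor([0.3, 0.4, 0.2])   # deterministic duration predictor (dp)
sdp_ratio, length_scale = 0.2, 1.0

# Same blend as in infer(): logw = sdp(...) * sdp_ratio + dp(...) * (1 - sdp_ratio)
logw = logw_sdp * sdp_ratio + logw_dp * (1 - sdp_ratio)
w_ceil = torch.ceil(torch.exp(logw) * length_scale)  # integer frame count per phoneme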
monotonic_align/__init__.py CHANGED
@@ -1,16 +1,16 @@
1
- from numpy import zeros, int32, float32
2
- from torch import from_numpy
3
-
4
- from .core import maximum_path_jit
5
-
6
-
7
- def maximum_path(neg_cent, mask):
8
- device = neg_cent.device
9
- dtype = neg_cent.dtype
10
- neg_cent = neg_cent.data.cpu().numpy().astype(float32)
11
- path = zeros(neg_cent.shape, dtype=int32)
12
-
13
- t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
14
- t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
15
- maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
16
- return from_numpy(path).to(device=device, dtype=dtype)
 
1
+ from numpy import zeros, int32, float32
2
+ from torch import from_numpy
3
+
4
+ from .core import maximum_path_jit
5
+
6
+
7
+ def maximum_path(neg_cent, mask):
8
+ device = neg_cent.device
9
+ dtype = neg_cent.dtype
10
+ neg_cent = neg_cent.data.cpu().numpy().astype(float32)
11
+ path = zeros(neg_cent.shape, dtype=int32)
12
+
13
+ t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
14
+ t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
15
+ maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
16
+ return from_numpy(path).to(device=device, dtype=dtype)
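As a quick illustration (not taken from the committed files), `maximum_path` is the monotonic alignment search called from `SynthesizerTrn.forward` above: it receives alignment scores and a mask of shape [batch, T_spec, T_text] and returns a 0/1 path of the same shape. A toy sketch, assuming the numba dependency from requirements.txt is installed:

import torch
from monotonic_align import maximum_path

neg_cent = torch.randn(1, 4, 2)  # toy scores: 4 spectrogram frames vs. 2 text tokens
mask = torch.ones(1, 4, 2)       # all positions valid
path = maximum_path(neg_cent, mask)
print(path.shape)  # torch.Size([1, 4, 2]); each text token gets a contiguous run of frames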
monotonic_align/core.py CHANGED
@@ -1,46 +1,46 @@
1
- import numba
2
-
3
-
4
- @numba.jit(
5
- numba.void(
6
- numba.int32[:, :, ::1],
7
- numba.float32[:, :, ::1],
8
- numba.int32[::1],
9
- numba.int32[::1],
10
- ),
11
- nopython=True,
12
- nogil=True,
13
- )
14
- def maximum_path_jit(paths, values, t_ys, t_xs):
15
- b = paths.shape[0]
16
- max_neg_val = -1e9
17
- for i in range(int(b)):
18
- path = paths[i]
19
- value = values[i]
20
- t_y = t_ys[i]
21
- t_x = t_xs[i]
22
-
23
- v_prev = v_cur = 0.0
24
- index = t_x - 1
25
-
26
- for y in range(t_y):
27
- for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
28
- if x == y:
29
- v_cur = max_neg_val
30
- else:
31
- v_cur = value[y - 1, x]
32
- if x == 0:
33
- if y == 0:
34
- v_prev = 0.0
35
- else:
36
- v_prev = max_neg_val
37
- else:
38
- v_prev = value[y - 1, x - 1]
39
- value[y, x] += max(v_prev, v_cur)
40
-
41
- for y in range(t_y - 1, -1, -1):
42
- path[y, index] = 1
43
- if index != 0 and (
44
- index == y or value[y - 1, index] < value[y - 1, index - 1]
45
- ):
46
- index = index - 1
 
1
+ import numba
2
+
3
+
4
+ @numba.jit(
5
+ numba.void(
6
+ numba.int32[:, :, ::1],
7
+ numba.float32[:, :, ::1],
8
+ numba.int32[::1],
9
+ numba.int32[::1],
10
+ ),
11
+ nopython=True,
12
+ nogil=True,
13
+ )
14
+ def maximum_path_jit(paths, values, t_ys, t_xs):
15
+ b = paths.shape[0]
16
+ max_neg_val = -1e9
17
+ for i in range(int(b)):
18
+ path = paths[i]
19
+ value = values[i]
20
+ t_y = t_ys[i]
21
+ t_x = t_xs[i]
22
+
23
+ v_prev = v_cur = 0.0
24
+ index = t_x - 1
25
+
26
+ for y in range(t_y):
27
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
28
+ if x == y:
29
+ v_cur = max_neg_val
30
+ else:
31
+ v_cur = value[y - 1, x]
32
+ if x == 0:
33
+ if y == 0:
34
+ v_prev = 0.0
35
+ else:
36
+ v_prev = max_neg_val
37
+ else:
38
+ v_prev = value[y - 1, x - 1]
39
+ value[y, x] += max(v_prev, v_cur)
40
+
41
+ for y in range(t_y - 1, -1, -1):
42
+ path[y, index] = 1
43
+ if index != 0 and (
44
+ index == y or value[y - 1, index] < value[y - 1, index - 1]
45
+ ):
46
+ index = index - 1
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
  cmudict
2
  cn2an
3
- faster-whisper>=0.10.0
4
  g2p_en
5
  GPUtil
6
  gradio
@@ -15,14 +14,12 @@ num2words
15
  numba
16
  numpy
17
  psutil
18
- pyannote.audio>=3.1.0
19
  pyopenjtalk-prebuilt
20
  pypinyin
21
  PyYAML
22
  requests
23
- sentencepiece
24
  safetensors
25
  scipy
26
  tensorboard
27
- torch
28
  transformers
 
1
  cmudict
2
  cn2an
 
3
  g2p_en
4
  GPUtil
5
  gradio
 
14
  numba
15
  numpy
16
  psutil
 
17
  pyopenjtalk-prebuilt
18
  pypinyin
19
  PyYAML
20
  requests
 
21
  safetensors
22
  scipy
23
  tensorboard
24
+ torch>=2.1,<2.2 # For users without GPU or colab
25
  transformers
text/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ from text.symbols import *
2
+
3
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
4
+
5
+
6
+ def cleaned_text_to_sequence(cleaned_text, tones, language):
7
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
8
+ Args:
9
+ text: string to convert to a sequence
10
+ Returns:
11
+ List of integers corresponding to the symbols in the text
12
+ """
13
+ phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
14
+ tone_start = language_tone_start_map[language]
15
+ tones = [i + tone_start for i in tones]
16
+ lang_id = language_id_map[language]
17
+ lang_ids = [lang_id for i in phones]
18
+ return phones, tones, lang_ids
19
+
20
+
21
+ def get_bert(
22
+ norm_text, word2ph, language, device, assist_text=None, assist_text_weight=0.7
23
+ ):
24
+ from .chinese_bert import get_bert_feature as zh_bert
25
+ from .english_bert_mock import get_bert_feature as en_bert
26
+ from .japanese_bert import get_bert_feature as jp_bert
27
+
28
+ lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
29
+ bert = lang_bert_func_map[language](
30
+ norm_text, word2ph, device, assist_text, assist_text_weight
31
+ )
32
+ return bert
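As a hedged usage sketch (not part of the commit), `get_bert` pairs with the cleaners defined later in this commit; it assumes the local BERT weights under ./bert/ referenced throughout the repo are present and that pyopenjtalk can read the input:

from text import get_bert
from text.cleaner import clean_text

norm_text, phones, tones, word2ph = clean_text("こんにちは。", "JP")
bert = get_bert(norm_text, word2ph, "JP", device="cpu")
# Based on the per-language implementations in this commit, bert carries one feature
# column per phoneme: bert.shape[1] == len(phones) == sum(word2ph)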
text/chinese.py ADDED
@@ -0,0 +1,199 @@
1
+ import os
2
+ import re
3
+
4
+ import cn2an
5
+ from pypinyin import lazy_pinyin, Style
6
+
7
+ from text.symbols import punctuation
8
+ from text.tone_sandhi import ToneSandhi
9
+
10
+ current_file_path = os.path.dirname(__file__)
11
+ pinyin_to_symbol_map = {
12
+ line.split("\t")[0]: line.strip().split("\t")[1]
13
+ for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
14
+ }
15
+
16
+ import jieba.posseg as psg
17
+
18
+
19
+ rep_map = {
20
+ ":": ",",
21
+ ";": ",",
22
+ ",": ",",
23
+ "。": ".",
24
+ "!": "!",
25
+ "?": "?",
26
+ "\n": ".",
27
+ "·": ",",
28
+ "、": ",",
29
+ "...": "…",
30
+ "$": ".",
31
+ "“": "'",
32
+ "”": "'",
33
+ '"': "'",
34
+ "‘": "'",
35
+ "’": "'",
36
+ "(": "'",
37
+ ")": "'",
38
+ "(": "'",
39
+ ")": "'",
40
+ "《": "'",
41
+ "》": "'",
42
+ "【": "'",
43
+ "】": "'",
44
+ "[": "'",
45
+ "]": "'",
46
+ "—": "-",
47
+ "~": "-",
48
+ "~": "-",
49
+ "「": "'",
50
+ "」": "'",
51
+ }
52
+
53
+ tone_modifier = ToneSandhi()
54
+
55
+
56
+ def replace_punctuation(text):
57
+ text = text.replace("嗯", "恩").replace("呣", "母")
58
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
59
+
60
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
61
+
62
+ replaced_text = re.sub(
63
+ r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
64
+ )
65
+
66
+ return replaced_text
67
+
68
+
69
+ def g2p(text):
70
+ pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
71
+ sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
72
+ phones, tones, word2ph = _g2p(sentences)
73
+ assert sum(word2ph) == len(phones)
74
+ assert len(word2ph) == len(text) # Sometimes this will fail; you can wrap it in a try-except.
75
+ phones = ["_"] + phones + ["_"]
76
+ tones = [0] + tones + [0]
77
+ word2ph = [1] + word2ph + [1]
78
+ return phones, tones, word2ph
79
+
80
+
81
+ def _get_initials_finals(word):
82
+ initials = []
83
+ finals = []
84
+ orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
85
+ orig_finals = lazy_pinyin(
86
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
87
+ )
88
+ for c, v in zip(orig_initials, orig_finals):
89
+ initials.append(c)
90
+ finals.append(v)
91
+ return initials, finals
92
+
93
+
94
+ def _g2p(segments):
95
+ phones_list = []
96
+ tones_list = []
97
+ word2ph = []
98
+ for seg in segments:
99
+ # Replace all English words in the sentence
100
+ seg = re.sub("[a-zA-Z]+", "", seg)
101
+ seg_cut = psg.lcut(seg)
102
+ initials = []
103
+ finals = []
104
+ seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
105
+ for word, pos in seg_cut:
106
+ if pos == "eng":
107
+ continue
108
+ sub_initials, sub_finals = _get_initials_finals(word)
109
+ sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
110
+ initials.append(sub_initials)
111
+ finals.append(sub_finals)
112
+
113
+ # assert len(sub_initials) == len(sub_finals) == len(word)
114
+ initials = sum(initials, [])
115
+ finals = sum(finals, [])
116
+ #
117
+ for c, v in zip(initials, finals):
118
+ raw_pinyin = c + v
119
+ # NOTE: post process for pypinyin outputs
120
+ # we discriminate i, ii and iii
121
+ if c == v:
122
+ assert c in punctuation
123
+ phone = [c]
124
+ tone = "0"
125
+ word2ph.append(1)
126
+ else:
127
+ v_without_tone = v[:-1]
128
+ tone = v[-1]
129
+
130
+ pinyin = c + v_without_tone
131
+ assert tone in "12345"
132
+
133
+ if c:
134
+ # 多音节
135
+ v_rep_map = {
136
+ "uei": "ui",
137
+ "iou": "iu",
138
+ "uen": "un",
139
+ }
140
+ if v_without_tone in v_rep_map.keys():
141
+ pinyin = c + v_rep_map[v_without_tone]
142
+ else:
143
+ # 单音节
144
+ pinyin_rep_map = {
145
+ "ing": "ying",
146
+ "i": "yi",
147
+ "in": "yin",
148
+ "u": "wu",
149
+ }
150
+ if pinyin in pinyin_rep_map.keys():
151
+ pinyin = pinyin_rep_map[pinyin]
152
+ else:
153
+ single_rep_map = {
154
+ "v": "yu",
155
+ "e": "e",
156
+ "i": "y",
157
+ "u": "w",
158
+ }
159
+ if pinyin[0] in single_rep_map.keys():
160
+ pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
161
+
162
+ assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
163
+ phone = pinyin_to_symbol_map[pinyin].split(" ")
164
+ word2ph.append(len(phone))
165
+
166
+ phones_list += phone
167
+ tones_list += [int(tone)] * len(phone)
168
+ return phones_list, tones_list, word2ph
169
+
170
+
171
+ def text_normalize(text):
172
+ numbers = re.findall(r"\d+(?:\.?\d+)?", text)
173
+ for number in numbers:
174
+ text = text.replace(number, cn2an.an2cn(number), 1)
175
+ text = replace_punctuation(text)
176
+ return text
177
+
178
+
179
+ def get_bert_feature(text, word2ph):
180
+ from text import chinese_bert
181
+
182
+ return chinese_bert.get_bert_feature(text, word2ph)
183
+
184
+
185
+ if __name__ == "__main__":
186
+ from text.chinese_bert import get_bert_feature
187
+
188
+ text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
189
+ text = text_normalize(text)
190
+ print(text)
191
+ phones, tones, word2ph = g2p(text)
192
+ bert = get_bert_feature(text, word2ph)
193
+
194
+ print(phones, tones, word2ph, bert.shape)
195
+
196
+
197
+ # # 示例用法
198
+ # text = "这是一个示例文本:,你好!这是一个测试...."
199
+ # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
text/chinese_bert.py ADDED
@@ -0,0 +1,121 @@
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
5
+
6
+ from config import config
7
+
8
+ LOCAL_PATH = "./bert/chinese-roberta-wwm-ext-large"
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
11
+
12
+ models = dict()
13
+
14
+
15
+ def get_bert_feature(
16
+ text,
17
+ word2ph,
18
+ device=config.bert_gen_config.device,
19
+ assist_text=None,
20
+ assist_text_weight=0.7,
21
+ ):
22
+ if (
23
+ sys.platform == "darwin"
24
+ and torch.backends.mps.is_available()
25
+ and device == "cpu"
26
+ ):
27
+ device = "mps"
28
+ if not device:
29
+ device = "cuda"
30
+ if device == "cuda" and not torch.cuda.is_available():
31
+ device = "cpu"
32
+ if device not in models.keys():
33
+ models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
34
+ with torch.no_grad():
35
+ inputs = tokenizer(text, return_tensors="pt")
36
+ for i in inputs:
37
+ inputs[i] = inputs[i].to(device)
38
+ res = models[device](**inputs, output_hidden_states=True)
39
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
40
+ if assist_text:
41
+ style_inputs = tokenizer(assist_text, return_tensors="pt")
42
+ for i in style_inputs:
43
+ style_inputs[i] = style_inputs[i].to(device)
44
+ style_res = models[device](**style_inputs, output_hidden_states=True)
45
+ style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
46
+ style_res_mean = style_res.mean(0)
47
+ assert len(word2ph) == len(text) + 2
48
+ word2phone = word2ph
49
+ phone_level_feature = []
50
+ for i in range(len(word2phone)):
51
+ if assist_text:
52
+ repeat_feature = (
53
+ res[i].repeat(word2phone[i], 1) * (1 - assist_text_weight)
54
+ + style_res_mean.repeat(word2phone[i], 1) * assist_text_weight
55
+ )
56
+ else:
57
+ repeat_feature = res[i].repeat(word2phone[i], 1)
58
+ phone_level_feature.append(repeat_feature)
59
+
60
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
61
+
62
+ return phone_level_feature.T
63
+
64
+
65
+ if __name__ == "__main__":
66
+ word_level_feature = torch.rand(38, 1024) # one 1024-dimensional feature per word
67
+ word2phone = [
68
+ 1,
69
+ 2,
70
+ 1,
71
+ 2,
72
+ 2,
73
+ 1,
74
+ 2,
75
+ 2,
76
+ 1,
77
+ 2,
78
+ 2,
79
+ 1,
80
+ 2,
81
+ 2,
82
+ 2,
83
+ 2,
84
+ 2,
85
+ 1,
86
+ 1,
87
+ 2,
88
+ 2,
89
+ 1,
90
+ 2,
91
+ 2,
92
+ 2,
93
+ 2,
94
+ 1,
95
+ 2,
96
+ 2,
97
+ 2,
98
+ 2,
99
+ 2,
100
+ 1,
101
+ 2,
102
+ 2,
103
+ 2,
104
+ 2,
105
+ 1,
106
+ ]
107
+
108
+ # Compute the total number of frames
109
+ total_frames = sum(word2phone)
110
+ print(word_level_feature.shape)
111
+ print(word2phone)
112
+ phone_level_feature = []
113
+ for i in range(len(word2phone)):
114
+ print(word_level_feature[i].shape)
115
+
116
+ # Repeat each word's feature word2phone[i] times
117
+ repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
118
+ phone_level_feature.append(repeat_feature)
119
+
120
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
121
+ print(phone_level_feature.shape) # torch.Size([36, 1024])
text/cleaner.py ADDED
@@ -0,0 +1,31 @@
1
+ from text import chinese, japanese, english, cleaned_text_to_sequence
2
+
3
+
4
+ language_module_map = {"ZH": chinese, "JP": japanese, "EN": english}
5
+
6
+
7
+ def clean_text(text, language, use_jp_extra=True):
8
+ language_module = language_module_map[language]
9
+ norm_text = language_module.text_normalize(text)
10
+ if language == "JP":
11
+ phones, tones, word2ph = language_module.g2p(norm_text, use_jp_extra)
12
+ else:
13
+ phones, tones, word2ph = language_module.g2p(norm_text)
14
+ return norm_text, phones, tones, word2ph
15
+
16
+
17
+ def clean_text_bert(text, language):
18
+ language_module = language_module_map[language]
19
+ norm_text = language_module.text_normalize(text)
20
+ phones, tones, word2ph = language_module.g2p(norm_text)
21
+ bert = language_module.get_bert_feature(norm_text, word2ph)
22
+ return phones, tones, bert
23
+
24
+
25
+ def text_to_sequence(text, language):
26
+ norm_text, phones, tones, word2ph = clean_text(text, language)
27
+ return cleaned_text_to_sequence(phones, tones, language)
28
+
29
+
30
+ if __name__ == "__main__":
31
+ pass
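A minimal sketch of how `clean_text` feeds `cleaned_text_to_sequence` (illustrative, not part of the committed files; it assumes the "JP" entries in text/symbols.py from this commit and a working pyopenjtalk install):

from text import cleaned_text_to_sequence
from text.cleaner import clean_text

norm_text, phones, tones, word2ph = clean_text("こんにちは、世界!", "JP", use_jp_extra=True)
phone_ids, tone_ids, lang_ids = cleaned_text_to_sequence(phones, tones, "JP")
# phone_ids, tone_ids and lang_ids are equally long lists; word2ph records how many
# phonemes were assigned to each input character.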
text/cmudict.rep ADDED
The diff for this file is too large to render. See raw diff
 
text/cmudict_cache.pickle ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9b21b20325471934ba92f2e4a5976989e7d920caa32e7a286eacb027d197949
3
+ size 6212655
text/english.py ADDED
@@ -0,0 +1,495 @@
1
+ import pickle
2
+ import os
3
+ import re
4
+ from g2p_en import G2p
5
+ from transformers import DebertaV2Tokenizer
6
+
7
+ from text import symbols
8
+ from text.symbols import punctuation
9
+
10
+ current_file_path = os.path.dirname(__file__)
11
+ CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
12
+ CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
13
+ _g2p = G2p()
14
+ LOCAL_PATH = "./bert/deberta-v3-large"
15
+ tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
16
+
17
+ arpa = {
18
+ "AH0",
19
+ "S",
20
+ "AH1",
21
+ "EY2",
22
+ "AE2",
23
+ "EH0",
24
+ "OW2",
25
+ "UH0",
26
+ "NG",
27
+ "B",
28
+ "G",
29
+ "AY0",
30
+ "M",
31
+ "AA0",
32
+ "F",
33
+ "AO0",
34
+ "ER2",
35
+ "UH1",
36
+ "IY1",
37
+ "AH2",
38
+ "DH",
39
+ "IY0",
40
+ "EY1",
41
+ "IH0",
42
+ "K",
43
+ "N",
44
+ "W",
45
+ "IY2",
46
+ "T",
47
+ "AA1",
48
+ "ER1",
49
+ "EH2",
50
+ "OY0",
51
+ "UH2",
52
+ "UW1",
53
+ "Z",
54
+ "AW2",
55
+ "AW1",
56
+ "V",
57
+ "UW2",
58
+ "AA2",
59
+ "ER",
60
+ "AW0",
61
+ "UW0",
62
+ "R",
63
+ "OW1",
64
+ "EH1",
65
+ "ZH",
66
+ "AE0",
67
+ "IH2",
68
+ "IH",
69
+ "Y",
70
+ "JH",
71
+ "P",
72
+ "AY1",
73
+ "EY0",
74
+ "OY2",
75
+ "TH",
76
+ "HH",
77
+ "D",
78
+ "ER0",
79
+ "CH",
80
+ "AO1",
81
+ "AE1",
82
+ "AO2",
83
+ "OY1",
84
+ "AY2",
85
+ "IH1",
86
+ "OW0",
87
+ "L",
88
+ "SH",
89
+ }
90
+
91
+
92
+ def post_replace_ph(ph):
93
+ rep_map = {
94
+ ":": ",",
95
+ ";": ",",
96
+ ",": ",",
97
+ "。": ".",
98
+ "!": "!",
99
+ "?": "?",
100
+ "\n": ".",
101
+ "·": ",",
102
+ "、": ",",
103
+ "…": "...",
104
+ "···": "...",
105
+ "・・・": "...",
106
+ "v": "V",
107
+ }
108
+ if ph in rep_map.keys():
109
+ ph = rep_map[ph]
110
+ if ph in symbols:
111
+ return ph
112
+ if ph not in symbols:
113
+ ph = "UNK"
114
+ return ph
115
+
116
+
117
+ rep_map = {
118
+ ":": ",",
119
+ ";": ",",
120
+ ",": ",",
121
+ "。": ".",
122
+ "!": "!",
123
+ "?": "?",
124
+ "\n": ".",
125
+ ".": ".",
126
+ "…": "...",
127
+ "···": "...",
128
+ "・・・": "...",
129
+ "·": ",",
130
+ "・": ",",
131
+ "、": ",",
132
+ "$": ".",
133
+ "“": "'",
134
+ "”": "'",
135
+ '"': "'",
136
+ "‘": "'",
137
+ "’": "'",
138
+ "(": "'",
139
+ ")": "'",
140
+ "(": "'",
141
+ ")": "'",
142
+ "《": "'",
143
+ "》": "'",
144
+ "【": "'",
145
+ "】": "'",
146
+ "[": "'",
147
+ "]": "'",
148
+ "—": "-",
149
+ "−": "-",
150
+ "~": "-",
151
+ "~": "-",
152
+ "「": "'",
153
+ "」": "'",
154
+ }
155
+
156
+
157
+ def replace_punctuation(text):
158
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
159
+
160
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
161
+
162
+ # replaced_text = re.sub(
163
+ # r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
164
+ # + "".join(punctuation)
165
+ # + r"]+",
166
+ # "",
167
+ # replaced_text,
168
+ # )
169
+
170
+ return replaced_text
171
+
172
+
173
+ def read_dict():
174
+ g2p_dict = {}
175
+ start_line = 49
176
+ with open(CMU_DICT_PATH) as f:
177
+ line = f.readline()
178
+ line_index = 1
179
+ while line:
180
+ if line_index >= start_line:
181
+ line = line.strip()
182
+ word_split = line.split(" ")
183
+ word = word_split[0]
184
+
185
+ syllable_split = word_split[1].split(" - ")
186
+ g2p_dict[word] = []
187
+ for syllable in syllable_split:
188
+ phone_split = syllable.split(" ")
189
+ g2p_dict[word].append(phone_split)
190
+
191
+ line_index = line_index + 1
192
+ line = f.readline()
193
+
194
+ return g2p_dict
195
+
196
+
197
+ def cache_dict(g2p_dict, file_path):
198
+ with open(file_path, "wb") as pickle_file:
199
+ pickle.dump(g2p_dict, pickle_file)
200
+
201
+
202
+ def get_dict():
203
+ if os.path.exists(CACHE_PATH):
204
+ with open(CACHE_PATH, "rb") as pickle_file:
205
+ g2p_dict = pickle.load(pickle_file)
206
+ else:
207
+ g2p_dict = read_dict()
208
+ cache_dict(g2p_dict, CACHE_PATH)
209
+
210
+ return g2p_dict
211
+
212
+
213
+ eng_dict = get_dict()
214
+
215
+
216
+ def refine_ph(phn):
217
+ tone = 0
218
+ if re.search(r"\d$", phn):
219
+ tone = int(phn[-1]) + 1
220
+ phn = phn[:-1]
221
+ else:
222
+ tone = 3
223
+ return phn.lower(), tone
224
+
225
+
226
+ def refine_syllables(syllables):
227
+ tones = []
228
+ phonemes = []
229
+ for phn_list in syllables:
230
+ for i in range(len(phn_list)):
231
+ phn = phn_list[i]
232
+ phn, tone = refine_ph(phn)
233
+ phonemes.append(phn)
234
+ tones.append(tone)
235
+ return phonemes, tones
236
+
237
+
238
+ import re
239
+ import inflect
240
+
241
+ _inflect = inflect.engine()
242
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
243
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
244
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
245
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
246
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
247
+ _number_re = re.compile(r"[0-9]+")
248
+
249
+ # List of (regular expression, replacement) pairs for abbreviations:
250
+ _abbreviations = [
251
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
252
+ for x in [
253
+ ("mrs", "misess"),
254
+ ("mr", "mister"),
255
+ ("dr", "doctor"),
256
+ ("st", "saint"),
257
+ ("co", "company"),
258
+ ("jr", "junior"),
259
+ ("maj", "major"),
260
+ ("gen", "general"),
261
+ ("drs", "doctors"),
262
+ ("rev", "reverend"),
263
+ ("lt", "lieutenant"),
264
+ ("hon", "honorable"),
265
+ ("sgt", "sergeant"),
266
+ ("capt", "captain"),
267
+ ("esq", "esquire"),
268
+ ("ltd", "limited"),
269
+ ("col", "colonel"),
270
+ ("ft", "fort"),
271
+ ]
272
+ ]
273
+
274
+
275
+ # List of (ipa, lazy ipa) pairs:
276
+ _lazy_ipa = [
277
+ (re.compile("%s" % x[0]), x[1])
278
+ for x in [
279
+ ("r", "ɹ"),
280
+ ("æ", "e"),
281
+ ("ɑ", "a"),
282
+ ("ɔ", "o"),
283
+ ("ð", "z"),
284
+ ("θ", "s"),
285
+ ("ɛ", "e"),
286
+ ("ɪ", "i"),
287
+ ("ʊ", "u"),
288
+ ("ʒ", "ʥ"),
289
+ ("ʤ", "ʥ"),
290
+ ("ˈ", "↓"),
291
+ ]
292
+ ]
293
+
294
+ # List of (ipa, lazy ipa2) pairs:
295
+ _lazy_ipa2 = [
296
+ (re.compile("%s" % x[0]), x[1])
297
+ for x in [
298
+ ("r", "ɹ"),
299
+ ("ð", "z"),
300
+ ("θ", "s"),
301
+ ("ʒ", "ʑ"),
302
+ ("ʤ", "dʑ"),
303
+ ("ˈ", "↓"),
304
+ ]
305
+ ]
306
+
307
+ # List of (ipa, ipa2) pairs
308
+ _ipa_to_ipa2 = [
309
+ (re.compile("%s" % x[0]), x[1]) for x in [("r", "ɹ"), ("ʤ", "dʒ"), ("ʧ", "tʃ")]
310
+ ]
311
+
312
+
313
+ def _expand_dollars(m):
314
+ match = m.group(1)
315
+ parts = match.split(".")
316
+ if len(parts) > 2:
317
+ return match + " dollars" # Unexpected format
318
+ dollars = int(parts[0]) if parts[0] else 0
319
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
320
+ if dollars and cents:
321
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
322
+ cent_unit = "cent" if cents == 1 else "cents"
323
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
324
+ elif dollars:
325
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
326
+ return "%s %s" % (dollars, dollar_unit)
327
+ elif cents:
328
+ cent_unit = "cent" if cents == 1 else "cents"
329
+ return "%s %s" % (cents, cent_unit)
330
+ else:
331
+ return "zero dollars"
332
+
333
+
334
+ def _remove_commas(m):
335
+ return m.group(1).replace(",", "")
336
+
337
+
338
+ def _expand_ordinal(m):
339
+ return _inflect.number_to_words(m.group(0))
340
+
341
+
342
+ def _expand_number(m):
343
+ num = int(m.group(0))
344
+ if num > 1000 and num < 3000:
345
+ if num == 2000:
346
+ return "two thousand"
347
+ elif num > 2000 and num < 2010:
348
+ return "two thousand " + _inflect.number_to_words(num % 100)
349
+ elif num % 100 == 0:
350
+ return _inflect.number_to_words(num // 100) + " hundred"
351
+ else:
352
+ return _inflect.number_to_words(
353
+ num, andword="", zero="oh", group=2
354
+ ).replace(", ", " ")
355
+ else:
356
+ return _inflect.number_to_words(num, andword="")
357
+
358
+
359
+ def _expand_decimal_point(m):
360
+ return m.group(1).replace(".", " point ")
361
+
362
+
363
+ def normalize_numbers(text):
364
+ text = re.sub(_comma_number_re, _remove_commas, text)
365
+ text = re.sub(_pounds_re, r"\1 pounds", text)
366
+ text = re.sub(_dollars_re, _expand_dollars, text)
367
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
368
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
369
+ text = re.sub(_number_re, _expand_number, text)
370
+ return text
371
+
372
+
373
+ def text_normalize(text):
374
+ text = normalize_numbers(text)
375
+ text = replace_punctuation(text)
376
+ text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text)
377
+ return text
378
+
379
+
380
+ def distribute_phone(n_phone, n_word):
381
+ phones_per_word = [0] * n_word
382
+ for task in range(n_phone):
383
+ min_tasks = min(phones_per_word)
384
+ min_index = phones_per_word.index(min_tasks)
385
+ phones_per_word[min_index] += 1
386
+ return phones_per_word
387
+
388
+
389
+ def sep_text(text):
390
+ words = re.split(r"([,;.\?\!\s+])", text)
391
+ words = [word for word in words if word.strip() != ""]
392
+ return words
393
+
394
+
395
+ def text_to_words(text):
396
+ tokens = tokenizer.tokenize(text)
397
+ words = []
398
+ for idx, t in enumerate(tokens):
399
+ if t.startswith("▁"):
400
+ words.append([t[1:]])
401
+ else:
402
+ if t in punctuation:
403
+ if idx == len(tokens) - 1:
404
+ words.append([f"{t}"])
405
+ else:
406
+ if (
407
+ not tokens[idx + 1].startswith("▁")
408
+ and tokens[idx + 1] not in punctuation
409
+ ):
410
+ if idx == 0:
411
+ words.append([])
412
+ words[-1].append(f"{t}")
413
+ else:
414
+ words.append([f"{t}"])
415
+ else:
416
+ if idx == 0:
417
+ words.append([])
418
+ words[-1].append(f"{t}")
419
+ return words
420
+
421
+
422
+ def g2p(text):
423
+ phones = []
424
+ tones = []
425
+ phone_len = []
426
+ # words = sep_text(text)
427
+ # tokens = [tokenizer.tokenize(i) for i in words]
428
+ words = text_to_words(text)
429
+
430
+ for word in words:
431
+ temp_phones, temp_tones = [], []
432
+ if len(word) > 1:
433
+ if "'" in word:
434
+ word = ["".join(word)]
435
+ for w in word:
436
+ if w in punctuation:
437
+ temp_phones.append(w)
438
+ temp_tones.append(0)
439
+ continue
440
+ if w.upper() in eng_dict:
441
+ phns, tns = refine_syllables(eng_dict[w.upper()])
442
+ temp_phones += [post_replace_ph(i) for i in phns]
443
+ temp_tones += tns
444
+ # w2ph.append(len(phns))
445
+ else:
446
+ phone_list = list(filter(lambda p: p != " ", _g2p(w)))
447
+ phns = []
448
+ tns = []
449
+ for ph in phone_list:
450
+ if ph in arpa:
451
+ ph, tn = refine_ph(ph)
452
+ phns.append(ph)
453
+ tns.append(tn)
454
+ else:
455
+ phns.append(ph)
456
+ tns.append(0)
457
+ temp_phones += [post_replace_ph(i) for i in phns]
458
+ temp_tones += tns
459
+ phones += temp_phones
460
+ tones += temp_tones
461
+ phone_len.append(len(temp_phones))
462
+ # phones = [post_replace_ph(i) for i in phones]
463
+
464
+ word2ph = []
465
+ for token, pl in zip(words, phone_len):
466
+ word_len = len(token)
467
+
468
+ aaa = distribute_phone(pl, word_len)
469
+ word2ph += aaa
470
+
471
+ phones = ["_"] + phones + ["_"]
472
+ tones = [0] + tones + [0]
473
+ word2ph = [1] + word2ph + [1]
474
+ assert len(phones) == len(tones), text
475
+ assert len(phones) == sum(word2ph), text
476
+
477
+ return phones, tones, word2ph
478
+
479
+
480
+ def get_bert_feature(text, word2ph):
481
+ from text import english_bert_mock
482
+
483
+ return english_bert_mock.get_bert_feature(text, word2ph)
484
+
485
+
486
+ if __name__ == "__main__":
487
+ # print(get_dict())
488
+ # print(eng_word_to_phoneme("hello"))
489
+ print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder."))
490
+ # all_phones = set()
491
+ # for k, syllables in eng_dict.items():
492
+ # for group in syllables:
493
+ # for ph in group:
494
+ # all_phones.add(ph)
495
+ # print(all_phones)
text/english_bert_mock.py ADDED
@@ -0,0 +1,63 @@
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import DebertaV2Model, DebertaV2Tokenizer
5
+
6
+ from config import config
7
+
8
+
9
+ LOCAL_PATH = "./bert/deberta-v3-large"
10
+
11
+ tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
12
+
13
+ models = dict()
14
+
15
+
16
+ def get_bert_feature(
17
+ text,
18
+ word2ph,
19
+ device=config.bert_gen_config.device,
20
+ assist_text=None,
21
+ assist_text_weight=0.7,
22
+ ):
23
+ if (
24
+ sys.platform == "darwin"
25
+ and torch.backends.mps.is_available()
26
+ and device == "cpu"
27
+ ):
28
+ device = "mps"
29
+ if not device:
30
+ device = "cuda"
31
+ if device == "cuda" and not torch.cuda.is_available():
32
+ device = "cpu"
33
+ if device not in models.keys():
34
+ models[device] = DebertaV2Model.from_pretrained(LOCAL_PATH).to(device)
35
+ with torch.no_grad():
36
+ inputs = tokenizer(text, return_tensors="pt")
37
+ for i in inputs:
38
+ inputs[i] = inputs[i].to(device)
39
+ res = models[device](**inputs, output_hidden_states=True)
40
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
41
+ if assist_text:
42
+ style_inputs = tokenizer(assist_text, return_tensors="pt")
43
+ for i in style_inputs:
44
+ style_inputs[i] = style_inputs[i].to(device)
45
+ style_res = models[device](**style_inputs, output_hidden_states=True)
46
+ style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
47
+ style_res_mean = style_res.mean(0)
48
+ assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph))
49
+ word2phone = word2ph
50
+ phone_level_feature = []
51
+ for i in range(len(word2phone)):
52
+ if assist_text:
53
+ repeat_feature = (
54
+ res[i].repeat(word2phone[i], 1) * (1 - assist_text_weight)
55
+ + style_res_mean.repeat(word2phone[i], 1) * assist_text_weight
56
+ )
57
+ else:
58
+ repeat_feature = res[i].repeat(word2phone[i], 1)
59
+ phone_level_feature.append(repeat_feature)
60
+
61
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
62
+
63
+ return phone_level_feature.T
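For reference (a sketch only, not from the committed files), the assist_text branch above mixes each token's own BERT feature with the mean feature of the assist text before repeating it per phoneme; with made-up tensors:

import torch

res_i = torch.randn(1024)           # hypothetical feature of token i
style_res_mean = torch.randn(1024)  # hypothetical mean feature of the assist text
assist_text_weight = 0.7
n_phones = 3                        # phonemes assigned to token i (word2ph[i])

# Same weighting as get_bert_feature() above when assist_text is given
repeat_feature = (
    res_i.repeat(n_phones, 1) * (1 - assist_text_weight)
    + style_res_mean.repeat(n_phones, 1) * assist_text_weight
)
print(repeat_feature.shape)  # torch.Size([3, 1024])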
text/japanese.py ADDED
@@ -0,0 +1,585 @@
1
+ # Convert Japanese text to phonemes which is
2
+ # compatible with Julius https://github.com/julius-speech/segmentation-kit
3
+ import re
4
+ import unicodedata
5
+
6
+ import pyopenjtalk
7
+ from num2words import num2words
8
+ from transformers import AutoTokenizer
9
+
10
+ from common.log import logger
11
+ from text import punctuation
12
+ from text.japanese_mora_list import (
13
+ mora_kata_to_mora_phonemes,
14
+ mora_phonemes_to_mora_kata,
15
+ )
16
+
17
+ # 子音の集合
18
+ COSONANTS = set(
19
+ [
20
+ cosonant
21
+ for cosonant, _ in mora_kata_to_mora_phonemes.values()
22
+ if cosonant is not None
23
+ ]
24
+ )
25
+
26
+ # 母音の集合、便宜上「ん」を含める
27
+ VOWELS = {"a", "i", "u", "e", "o", "N"}
28
+
29
+
30
+ # 正規化で記号を変換するための辞書
31
+ rep_map = {
32
+ ":": ",",
33
+ ";": ",",
34
+ ",": ",",
35
+ "。": ".",
36
+ "!": "!",
37
+ "?": "?",
38
+ "\n": ".",
39
+ ".": ".",
40
+ "…": "...",
41
+ "···": "...",
42
+ "・・・": "...",
43
+ "·": ",",
44
+ "・": ",",
45
+ "、": ",",
46
+ "$": ".",
47
+ "“": "'",
48
+ "”": "'",
49
+ '"': "'",
50
+ "‘": "'",
51
+ "’": "'",
52
+ "(": "'",
53
+ ")": "'",
54
+ "(": "'",
55
+ ")": "'",
56
+ "《": "'",
57
+ "》": "'",
58
+ "【": "'",
59
+ "】": "'",
60
+ "[": "'",
61
+ "]": "'",
62
+ "—": "-",
63
+ "−": "-",
64
+ # "~": "-", # これは長音記号「ー」として扱うよう変更
65
+ # "~": "-", # これも長音記号「ー」として扱うよう変更
66
+ "「": "'",
67
+ "」": "'",
68
+ }
69
+
70
+
71
+ def text_normalize(text):
72
+ """
73
+ 日本語のテキストを正規化する。
74
+ 結果は、ちょうど次の文字のみからなる:
75
+ - ひらがな
76
+ - カタカナ(全角長音記号「ー」が入る!)
77
+ - 漢字
78
+ - 半角アルファベット(大文字と小文字)
79
+ - ギリシャ文字
80
+ - `.` (句点`。`や`…`の一部や改行等)
81
+ - `,` (読点`、`や`:`等)
82
+ - `?` (疑問符`?`)
83
+ - `!` (感嘆符`!`)
84
+ - `'` (`「`や`」`等)
85
+ - `-` (`―`(ダッシュ、長音記号ではない)や`-`等)
86
+
87
+ 注意点:
88
+ - 三点リーダー`…`は`...`に変換される(`なるほど…。` → `なるほど....`)
89
+ - 数字は漢字に変換される(`1,100円` → `千百円`、`52.34` → `五十二点三四`)
90
+ - 読点や疑問符等の位置・個数等は保持される(`??あ、、!!!` → `??あ,,!!!`)
91
+ """
92
+ res = unicodedata.normalize("NFKC", text) # ここでアルファベットは半角になる
93
+ res = japanese_convert_numbers_to_words(res) # 「100円」→「百円」等
94
+ # 「~」と「~」も長音記号として扱う
95
+ res = res.replace("~", "ー")
96
+ res = res.replace("~", "ー")
97
+
98
+ res = replace_punctuation(res) # 句読点等正規化、読めない文字を削除
99
+
100
+ # 結合文字の濁点・半濁点を削除
101
+ # 通常の「ば」等はそのままのこされる、「あ゛」は上で「あ゙」になりここで「あ」になる
102
+ res = res.replace("\u3099", "") # 結合文字の濁点を削除、る゙ → る
103
+ res = res.replace("\u309A", "") # 結合文字の半濁点を削除、な゚ → な
104
+ return res
105
+
106
+
107
+ def replace_punctuation(text: str) -> str:
108
+ """句読点等を「.」「,」「!」「?」「'」「-」に正規化し、OpenJTalkで読みが取得できるもののみ残す:
109
+ 漢字・平仮名・カタカナ、アルファベット、ギリシャ文字
110
+ """
111
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
112
+
113
+ # 句読点を辞書で置換
114
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
115
+
116
+ replaced_text = re.sub(
117
+ # ↓ ひらがな、カタカナ、漢字
118
+ r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
119
+ # ↓ 半角アルファベット(大文字と小文字)
120
+ + r"\u0041-\u005A\u0061-\u007A"
121
+ # ↓ 全角アルファベット(大文字と小文字)
122
+ + r"\uFF21-\uFF3A\uFF41-\uFF5A"
123
+ # ↓ ギリシャ文字
124
+ + r"\u0370-\u03FF\u1F00-\u1FFF"
125
+ # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている
126
+ + "".join(punctuation) + r"]+",
127
+ # 上述以外の文字を削除
128
+ "",
129
+ replaced_text,
130
+ )
131
+
132
+ return replaced_text
133
+
134
+
135
+ _NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+")
136
+ _CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"}
137
+ _CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])")
138
+ _NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?")
139
+
140
+
141
+ def japanese_convert_numbers_to_words(text: str) -> str:
142
+ res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text)
143
+ res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res)
144
+ res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res)
145
+ return res
146
+
147
+
148
+ def g2p(
149
+ norm_text: str, use_jp_extra: bool = True
150
+ ) -> tuple[list[str], list[int], list[int]]:
151
+ """
152
+ 他で使われるメインの関数。`text_normalize()`で正規化された`norm_text`を受け取り、
153
+ - phones: 音素のリスト(ただし`!`や`,`や`.`等punctuationが含まれうる)
154
+ - tones: アクセントのリスト、0(低)と1(高)からなり、phonesと同じ長さ
155
+ - word2ph: 元のテキストの各文字に音素が何個割り当てられるかを表すリスト
156
+ のタプルを返す。
157
+ ただし`phones`と`tones`の最初と終わりに`_`が入り、応じて`word2ph`の最初と最後に1が追加される。
158
+ use_jp_extra: Falseの場合、「ん」の音素を「N」ではなく「n」とする。
159
+ """
160
+ # pyopenjtalkのフルコンテキストラベルを使ってアクセントを取り出すと、punctuationの位置が消えてしまい情報が失われてしまう:
161
+ # 「こんにちは、世界。」と「こんにちは!世界。」と「こんにちは!!!???世界……。」は全て同じになる。
162
+ # よって、まずpunctuation無しの音素とアクセントのリストを作り、
163
+ # それとは別にpyopenjtalk.run_frontend()で得られる音素リスト(こちらはpunctuationが保持される)を使い、
164
+ # アクセント割当をしなおすことによってpunctuationを含めた音素とアクセントのリストを作る。
165
+
166
+ # punctuationがすべて消えた、音素とアクセントのタプルのリスト(「ん」は「N」)
167
+ phone_tone_list_wo_punct = g2phone_tone_wo_punct(norm_text)
168
+
169
+ # sep_text: 単語単位の単語のリスト
170
+ # sep_kata: 単語単位の単語のカタカナ読みのリスト
171
+ sep_text, sep_kata = text2sep_kata(norm_text)
172
+
173
+ # sep_phonemes: 各単語ごとの音素のリストのリスト
174
+ sep_phonemes = handle_long([kata2phoneme_list(i) for i in sep_kata])
175
+
176
+ # phone_w_punct: sep_phonemesを結合した、punctuationを元のまま保持した音素列
177
+ phone_w_punct: list[str] = []
178
+ for i in sep_phonemes:
179
+ phone_w_punct += i
180
+
181
+ # punctuation無しのアクセント情報を使って、punctuationを含めたアクセント情報を作る
182
+ phone_tone_list = align_tones(phone_w_punct, phone_tone_list_wo_punct)
183
+ # logger.debug(f"phone_tone_list:\n{phone_tone_list}")
184
+ # word2phは厳密な解答は不可能なので(「今日」「眼鏡」等の熟字訓が存在)、
185
+ # Bert-VITS2では、単語単位の分割を使って、単語の文字ごとにだいたい均等に音素を分配する
186
+
187
+ # sep_textから、各単語を1文字1文字分割して、文字のリスト(のリスト)を作る
188
+ sep_tokenized: list[list[str]] = []
189
+ for i in sep_text:
190
+ if i not in punctuation:
191
+ sep_tokenized.append(
192
+ tokenizer.tokenize(i)
193
+ ) # ここでおそらく`i`が文字単位に分割される
194
+ else:
195
+ sep_tokenized.append([i])
196
+
197
+ # 各単語について、音素の数と文字の数を比較して、均等っぽく分配する
198
+ word2ph = []
199
+ for token, phoneme in zip(sep_tokenized, sep_phonemes):
200
+ phone_len = len(phoneme)
201
+ word_len = len(token)
202
+ word2ph += distribute_phone(phone_len, word_len)
203
+
204
+ # 最初と最後に`_`記号を追加、アクセントは0(低)、word2phもそれに合わせて追加
205
+ phone_tone_list = [("_", 0)] + phone_tone_list + [("_", 0)]
206
+ word2ph = [1] + word2ph + [1]
207
+
208
+ phones = [phone for phone, _ in phone_tone_list]
209
+ tones = [tone for _, tone in phone_tone_list]
210
+
211
+ assert len(phones) == sum(word2ph), f"{len(phones)} != {sum(word2ph)}"
212
+
213
+ # use_jp_extraでない場合は「N」を「n」に変換
214
+ if not use_jp_extra:
215
+ phones = [phone if phone != "N" else "n" for phone in phones]
216
+
217
+ return phones, tones, word2ph
218
+
219
+
220
+ def g2kata_tone(norm_text: str) -> list[tuple[str, int]]:
221
+ phones, tones, _ = g2p(norm_text, use_jp_extra=True)
222
+ return phone_tone2kata_tone(list(zip(phones, tones)))
223
+
224
+
225
+ def phone_tone2kata_tone(phone_tone: list[tuple[str, int]]) -> list[tuple[str, int]]:
226
+ """phone_toneをのphone部分をカタカナに変換する。ただし最初と最後の("_", 0)は無視"""
227
+ phone_tone = phone_tone[1:] # 最初の("_", 0)を無視
228
+ phones = [phone for phone, _ in phone_tone]
229
+ tones = [tone for _, tone in phone_tone]
230
+ result: list[tuple[str, int]] = []
231
+ current_mora = ""
232
+ for phone, next_phone, tone, next_tone in zip(phones, phones[1:], tones, tones[1:]):
233
+ # zipの関係で最後の("_", 0)は無視されている
234
+ if phone in punctuation:
235
+ result.append((phone, tone))
236
+ continue
237
+ if phone in COSONANTS: # n以外の子音の場合
238
+ assert current_mora == "", f"Unexpected {phone} after {current_mora}"
239
+ assert tone == next_tone, f"Unexpected {phone} tone {tone} != {next_tone}"
240
+ current_mora = phone
241
+ else:
242
+ # phoneが母音もしくは「N」
243
+ current_mora += phone
244
+ result.append((mora_phonemes_to_mora_kata[current_mora], tone))
245
+ current_mora = ""
246
+ return result
247
+
248
+
249
+ def kata_tone2phone_tone(kata_tone: list[tuple[str, int]]) -> list[tuple[str, int]]:
250
+ """`phone_tone2kata_tone()`の逆。"""
251
+ result: list[tuple[str, int]] = [("_", 0)]
252
+ for mora, tone in kata_tone:
253
+ if mora in punctuation:
254
+ result.append((mora, tone))
255
+ else:
256
+ cosonant, vowel = mora_kata_to_mora_phonemes[mora]
257
+ if cosonant is None:
258
+ result.append((vowel, tone))
259
+ else:
260
+ result.append((cosonant, tone))
261
+ result.append((vowel, tone))
262
+ result.append(("_", 0))
263
+ return result
264
+
265
+
266
+ def g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]:
267
+ """
268
+ テキストに対して、音素とアクセント(0か1)のペアのリストを返す。
269
+ ただし「!」「.」「?」等の非音素記号(punctuation)は全て消える(ポーズ記号も残さない)。
270
+ 非音素記号を含める処理は`align_tones()`で行われる。
271
+ また「っ」は「q」に、「ん」は「N」に変換される。
272
+ 例: "こんにちは、世界ー。。元気?!" →
273
+ [('k', 0), ('o', 0), ('N', 1), ('n', 1), ('i', 1), ('ch', 1), ('i', 1), ('w', 1), ('a', 1), ('s', 1), ('e', 1), ('k', 0), ('a', 0), ('i', 0), ('i', 0), ('g', 1), ('e', 1), ('N', 0), ('k', 0), ('i', 0)]
274
+ """
275
+ prosodies = pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True)
276
+ # logger.debug(f"prosodies: {prosodies}")
277
+ result: list[tuple[str, int]] = []
278
+ current_phrase: list[tuple[str, int]] = []
279
+ current_tone = 0
280
+ for i, letter in enumerate(prosodies):
281
+ # 特殊記号の処理
282
+
283
+ # 文頭記号、無視する
284
+ if letter == "^":
285
+ assert i == 0, "Unexpected ^"
286
+ # アクセント句の終わりに来る記号
287
+ elif letter in ("$", "?", "_", "#"):
288
+ # 保持しているフレーズを、アクセント数値を0-1に修正し結果に追加
289
+ result.extend(fix_phone_tone(current_phrase))
290
+ # 末尾に来る終了記号、無視(文中の疑問文は`_`になる)
291
+ if letter in ("$", "?"):
292
+ assert i == len(prosodies) - 1, f"Unexpected {letter}"
293
+ # あとは"_"(ポーズ)と"#"(アクセント句の境界)のみ
294
+ # これらは残さず、次のアクセント句に備える。
295
+ current_phrase = []
296
+ # 0を基準点にしてそこから上昇・下降する(負の場合は上の`fix_phone_tone`で直る)
297
+ current_tone = 0
298
+ # アクセント上昇記号
299
+ elif letter == "[":
300
+ current_tone = current_tone + 1
301
+ # アクセント下降記号
302
+ elif letter == "]":
303
+ current_tone = current_tone - 1
304
+ # それ以外は通常の音素
305
+ else:
306
+ if letter == "cl": # 「っ」の処理
307
+ letter = "q"
308
+ # elif letter == "N": # 「ん」の処理
309
+ # letter = "n"
310
+ current_phrase.append((letter, current_tone))
311
+ return result
312
+
313
+
314
+ def text2sep_kata(norm_text: str) -> tuple[list[str], list[str]]:
315
+ """
316
+ `text_normalize`で正規化済みの`norm_text`を受け取り、それを単語分割し、
317
+ 分割された単語リストとその読み(カタカナor記号1文字)のリストのタプルを返す。
318
+ 単語分割結果は、`g2p()`の`word2ph`で1文字あたりに割り振る音素記号の数を決めるために使う。
319
+ 例:
320
+ `私はそう思う!って感じ?` →
321
+ ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"]
322
+ """
323
+ # parsed: OpenJTalkの解析結果
324
+ parsed = pyopenjtalk.run_frontend(norm_text)
325
+ sep_text: list[str] = []
326
+ sep_kata: list[str] = []
327
+ for parts in parsed:
328
+ # word: 実際の単語の文字列
329
+ # yomi: その読み、但し無声化サインの`’`は除去
330
+ word, yomi = replace_punctuation(parts["string"]), parts["pron"].replace(
331
+ "’", ""
332
+ )
333
+ """
334
+ ここで`yomi`の取りうる値は以下の通りのはず。
335
+ - `word`が通常単語 → 通常の読み(カタカナ)
336
+ (カタカナからなり、長音記号も含みうる、`アー` 等)
337
+ - `word`が`ー` から始まる → `ーラー` や `ーーー` など
338
+ - `word`が句読点や空白等 → `、`
339
+ - `word`が`?` → `?`(全角になる)
340
+ 他にも`word`が読めないキリル文字アラビア文字等が来ると`、`になるが、正規化でこの場合は起きないはず。
341
+ また元のコードでは`yomi`が空白の場合の処理があったが、これは起きないはず。
342
+ 処理すべきは`yomi`が`、`の場合のみのはず。
343
+ """
344
+ assert yomi != "", f"Empty yomi: {word}"
345
+ if yomi == "、":
346
+ # wordは正規化されているので、`.`, `,`, `!`, `'`, `-`, `--` のいずれか
347
+ if word not in (
348
+ ".",
349
+ ",",
350
+ "!",
351
+ "'",
352
+ "-",
353
+ "--",
354
+ ):
355
+ # ここはpyopenjtalkが読めない文字等のときに起こる
356
+ raise ValueError(f"Cannot read: {word} in:\n{norm_text}")
357
+ # yomiは元の記号のままに変更
358
+ yomi = word
359
+ elif yomi == "?":
360
+ assert word == "?", f"yomi `?` comes from: {word}"
361
+ yomi = "?"
362
+ sep_text.append(word)
363
+ sep_kata.append(yomi)
364
+ return sep_text, sep_kata
365
+
366
+
367
+ # ESPnetの実装から引用、変更点無し。「ん」は「N」なことに注意。
368
+ # https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
369
+ def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> list[str]:
370
+ """Extract phoneme + prosoody symbol sequence from input full-context labels.
371
+
372
+ The algorithm is based on `Prosodic features control by symbols as input of
373
+ sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks.
374
+
375
+ Args:
376
+ text (str): Input text.
377
+ drop_unvoiced_vowels (bool): whether to drop unvoiced vowels.
378
+
379
+ Returns:
380
+ List[str]: List of phoneme + prosody symbols.
381
+
382
+ Examples:
383
+ >>> from espnet2.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody
384
+ >>> pyopenjtalk_g2p_prosody("こんにちは。")
385
+ ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']
386
+
387
+ .. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic
388
+ modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104
389
+
390
+ """
391
+ labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text))
392
+ N = len(labels)
393
+
394
+ phones = []
395
+ for n in range(N):
396
+ lab_curr = labels[n]
397
+
398
+ # current phoneme
399
+ p3 = re.search(r"\-(.*?)\+", lab_curr).group(1)
400
+ # deal unvoiced vowels as normal vowels
401
+ if drop_unvoiced_vowels and p3 in "AEIOU":
402
+ p3 = p3.lower()
403
+
404
+ # deal with sil at the beginning and the end of text
405
+ if p3 == "sil":
406
+ assert n == 0 or n == N - 1
407
+ if n == 0:
408
+ phones.append("^")
409
+ elif n == N - 1:
410
+ # check question form or not
411
+ e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr)
412
+ if e3 == 0:
413
+ phones.append("$")
414
+ elif e3 == 1:
415
+ phones.append("?")
416
+ continue
417
+ elif p3 == "pau":
418
+ phones.append("_")
419
+ continue
420
+ else:
421
+ phones.append(p3)
422
+
423
+ # accent type and position info (forward or backward)
424
+ a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr)
425
+ a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr)
426
+ a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr)
427
+
428
+ # number of mora in accent phrase
429
+ f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr)
430
+
431
+ a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1])
432
+ # accent phrase border
433
+ if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl":
434
+ phones.append("#")
435
+ # pitch falling
436
+ elif a1 == 0 and a2_next == a2 + 1 and a2 != f1:
437
+ phones.append("]")
438
+ # pitch rising
439
+ elif a2 == 1 and a2_next == 2:
440
+ phones.append("[")
441
+
442
+ return phones
443
+
444
+
445
+ def _numeric_feature_by_regex(regex, s):
446
+ match = re.search(regex, s)
447
+ if match is None:
448
+ return -50
449
+ return int(match.group(1))
450
+
451
+
452
+ def fix_phone_tone(phone_tone_list: list[tuple[str, int]]) -> list[tuple[str, int]]:
453
+ """
454
+ `phone_tone_list`のtone(アクセントの値)を0か1の範囲に修正する。
455
+ 例: [(a, 0), (i, -1), (u, -1)] → [(a, 1), (i, 0), (u, 0)]
456
+ """
457
+ tone_values = set(tone for _, tone in phone_tone_list)
458
+ if len(tone_values) == 1:
459
+ assert tone_values == {0}, tone_values
460
+ return phone_tone_list
461
+ elif len(tone_values) == 2:
462
+ if tone_values == {0, 1}:
463
+ return phone_tone_list
464
+ elif tone_values == {-1, 0}:
465
+ return [
466
+ (letter, 0 if tone == -1 else 1) for letter, tone in phone_tone_list
467
+ ]
468
+ else:
469
+ raise ValueError(f"Unexpected tone values: {tone_values}")
470
+ else:
471
+ raise ValueError(f"Unexpected tone values: {tone_values}")
472
+
473
+
474
+ def distribute_phone(n_phone: int, n_word: int) -> list[int]:
475
+ """
476
+ 左から右に1ずつ振り分け、次にまた左から右に1ずつ増やし、というふうに、
477
+ 音素の数`n_phone`を単語の数`n_word`に分配する。
478
+ """
479
+ phones_per_word = [0] * n_word
480
+ for _ in range(n_phone):
481
+ min_tasks = min(phones_per_word)
482
+ min_index = phones_per_word.index(min_tasks)
483
+ phones_per_word[min_index] += 1
484
+ return phones_per_word
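distribute_phone always increments the word that currently holds the fewest phonemes, scanning from the left; a minimal illustration of the resulting split:

    distribute_phone(5, 3)  # -> [2, 2, 1]
    distribute_phone(4, 4)  # -> [1, 1, 1, 1]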
485
+
486
+
487
+ def handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]:
488
+ for i in range(len(sep_phonemes)):
489
+ if sep_phonemes[i][0] == "ー":
490
+ sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
491
+ if "ー" in sep_phonemes[i]:
492
+ for j in range(len(sep_phonemes[i])):
493
+ if sep_phonemes[i][j] == "ー":
494
+ sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
495
+ return sep_phonemes
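handle_long resolves a long-vowel mark by copying the last phoneme that precedes it, whether that phoneme sits in the previous word or earlier in the same word; an illustrative call:

    handle_long([["k", "o"], ["ー", "ー"]])  # -> [["k", "o"], ["o", "o"]]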
496
+
497
+
498
+ tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese-char-wwm")
499
+
500
+
501
+ def align_tones(
502
+ phones_with_punct: list[str], phone_tone_list: list[tuple[str, int]]
503
+ ) -> list[tuple[str, int]]:
504
+ """
505
+ 例:
506
+ …私は、、そう思う。
507
+ phones_with_punct:
508
+ [".", ".", ".", "w", "a", "t", "a", "sh", "i", "w", "a", ",", ",", "s", "o", "o", "o", "m", "o", "u", "."]
509
+ phone_tone_list:
510
+ [("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), ("_", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0))]
511
+ Return:
512
+ [(".", 0), (".", 0), (".", 0), ("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), (",", 0), (",", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0), (".", 0)]
513
+ """
514
+ result: list[tuple[str, int]] = []
515
+ tone_index = 0
516
+ for phone in phones_with_punct:
517
+ if tone_index >= len(phone_tone_list):
518
+ # 余ったpunctuationがある場合 → (punctuation, 0)を追加
519
+ result.append((phone, 0))
520
+ elif phone == phone_tone_list[tone_index][0]:
521
+ # phone_tone_listの現在の音素と一致する場合 → toneをそこから取得、(phone, tone)を追加
522
+ result.append((phone, phone_tone_list[tone_index][1]))
523
+ # 探すindexを1つ進める
524
+ tone_index += 1
525
+ elif phone in punctuation:
526
+ # phoneがpunctuationの場合 → (phone, 0)を追加
527
+ result.append((phone, 0))
528
+ else:
529
+ logger.debug(f"phones: {phones_with_punct}")
530
+ logger.debug(f"phone_tone_list: {phone_tone_list}")
531
+ logger.debug(f"result: {result}")
532
+ logger.debug(f"tone_index: {tone_index}")
533
+ logger.debug(f"phone: {phone}")
534
+ raise ValueError(f"Unexpected phone: {phone}")
535
+ return result
536
+
537
+
538
+ def kata2phoneme_list(text: str) -> list[str]:
539
+ """
540
+ 原則カタカナの`text`を受け取り、それをそのままいじらずに音素記号のリストに変換。
541
+ 注意点:
542
+ - punctuationが来た場合(punctuationが1文字の場合がありうる)、処理せず1文字のリストを返す
543
+ - 冒頭に続く「ー」はそのまま「ー」のままにする(`handle_long()`で処理される)
544
+ - 文中の「ー」は前の音素記号の最後の音素記号に変換される。
545
+ 例:
546
+ `ーーソーナノカーー` → ["ー", "ー", "s", "o", "o", "n", "a", "n", "o", "k", "a", "a", "a"]
547
+ `?` → ["?"]
548
+ """
549
+ if text in punctuation:
550
+ return [text]
551
+ elif text == "--":
552
+ return ["-", "-"]
553
+ # `text`がカタカナ(`ー`含む)のみからなるかどうかをチェック
554
+ if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None:
555
+ raise ValueError(f"Input must be katakana only: {text}")
556
+ sorted_keys = sorted(mora_kata_to_mora_phonemes.keys(), key=len, reverse=True)
557
+ pattern = "|".join(map(re.escape, sorted_keys))
558
+
559
+ def mora2phonemes(mora: str) -> str:
560
+ cosonant, vowel = mora_kata_to_mora_phonemes[mora]
561
+ if cosonant is None:
562
+ return f" {vowel}"
563
+ return f" {cosonant} {vowel}"
564
+
565
+ spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text)
566
+
567
+ # 長音記号「ー」の処理
568
+ long_pattern = r"(\w)(ー*)"
569
+ long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2))
570
+ spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes)
571
+ return spaced_phonemes.strip().split(" ")
572
+
573
+
574
+ if __name__ == "__main__":
575
+ tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese")
576
+ text = "hello,こんにちは、世界ー!……"
577
+ from text.japanese_bert import get_bert_feature
578
+
579
+ text = text_normalize(text)
580
+ print(text)
581
+
582
+ phones, tones, word2ph = g2p(text)
583
+ bert = get_bert_feature(text, word2ph)
584
+
585
+ print(phones, tones, word2ph, bert.shape)
text/japanese_bert.py ADDED
@@ -0,0 +1,67 @@
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
5
+
6
+ from config import config
7
+ from text.japanese import text2sep_kata
8
+
9
+ LOCAL_PATH = "./bert/deberta-v2-large-japanese-char-wwm"
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
12
+
13
+ models = dict()
14
+
15
+
16
+ def get_bert_feature(
17
+ text,
18
+ word2ph,
19
+ device=config.bert_gen_config.device,
20
+ assist_text=None,
21
+ assist_text_weight=0.7,
22
+ ):
23
+ text = "".join(text2sep_kata(text)[0])
24
+ if assist_text:
25
+ assist_text = "".join(text2sep_kata(assist_text)[0])
26
+ if (
27
+ sys.platform == "darwin"
28
+ and torch.backends.mps.is_available()
29
+ and device == "cpu"
30
+ ):
31
+ device = "mps"
32
+ if not device:
33
+ device = "cuda"
34
+ if device == "cuda" and not torch.cuda.is_available():
35
+ device = "cpu"
36
+ if device not in models.keys():
37
+ models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
38
+ with torch.no_grad():
39
+ inputs = tokenizer(text, return_tensors="pt")
40
+ for i in inputs:
41
+ inputs[i] = inputs[i].to(device)
42
+ res = models[device](**inputs, output_hidden_states=True)
43
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
44
+ if assist_text:
45
+ style_inputs = tokenizer(assist_text, return_tensors="pt")
46
+ for i in style_inputs:
47
+ style_inputs[i] = style_inputs[i].to(device)
48
+ style_res = models[device](**style_inputs, output_hidden_states=True)
49
+ style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
50
+ style_res_mean = style_res.mean(0)
51
+
52
+ assert len(word2ph) == len(text) + 2, text
53
+ word2phone = word2ph
54
+ phone_level_feature = []
55
+ for i in range(len(word2phone)):
56
+ if assist_text:
57
+ repeat_feature = (
58
+ res[i].repeat(word2phone[i], 1) * (1 - assist_text_weight)
59
+ + style_res_mean.repeat(word2phone[i], 1) * assist_text_weight
60
+ )
61
+ else:
62
+ repeat_feature = res[i].repeat(word2phone[i], 1)
63
+ phone_level_feature.append(repeat_feature)
64
+
65
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
66
+
67
+ return phone_level_feature.T
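A minimal usage sketch for this helper (assumes the local DeBERTa weights referenced above are present; the output shape follows from the repeat/concat logic):

    from text.japanese import g2p, text_normalize
    from text.japanese_bert import get_bert_feature

    text = text_normalize("こんにちは")
    phones, tones, word2ph = g2p(text)
    bert = get_bert_feature(text, word2ph, device="cpu")
    # bert: (hidden_size, sum(word2ph)) -- one column of BERT features per phoneme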
text/japanese_mora_list.py ADDED
@@ -0,0 +1,232 @@
1
+ """
2
+ VOICEVOXのソースコードからお借りして最低限に改造したコード。
3
+ https://github.com/VOICEVOX/voicevox_engine/blob/master/voicevox_engine/tts_pipeline/mora_list.py
4
+ """
5
+
6
+ """
7
+ 以下のモーラ対応表はOpenJTalkのソースコードから取得し、
8
+ カタカナ表記とモーラが一対一対応するように改造した。
9
+ ライセンス表記:
10
+ -----------------------------------------------------------------
11
+ The Japanese TTS System "Open JTalk"
12
+ developed by HTS Working Group
13
+ http://open-jtalk.sourceforge.net/
14
+ -----------------------------------------------------------------
15
+
16
+ Copyright (c) 2008-2014 Nagoya Institute of Technology
17
+ Department of Computer Science
18
+
19
+ All rights reserved.
20
+
21
+ Redistribution and use in source and binary forms, with or
22
+ without modification, are permitted provided that the following
23
+ conditions are met:
24
+
25
+ - Redistributions of source code must retain the above copyright
26
+ notice, this list of conditions and the following disclaimer.
27
+ - Redistributions in binary form must reproduce the above
28
+ copyright notice, this list of conditions and the following
29
+ disclaimer in the documentation and/or other materials provided
30
+ with the distribution.
31
+ - Neither the name of the HTS working group nor the names of its
32
+ contributors may be used to endorse or promote products derived
33
+ from this software without specific prior written permission.
34
+
35
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
36
+ CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
37
+ INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
38
+ MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
39
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
40
+ BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
41
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
42
+ TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
43
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
44
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
45
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
46
+ OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
47
+ POSSIBILITY OF SUCH DAMAGE.
48
+ """
49
+ from typing import Optional
50
+
51
+ # (カタカナ, 子音, 母音)の順。子音がない場合はNoneを入れる。
52
+ # 但し「ン」と「ッ」は母音のみという扱いで、「ン」は「N」、「ッ」は「q」とする。
53
+ # (元々「ッ」は「cl」)
54
+ # また「デェ = dy e」はpyopenjtalkの出力(de e)と合わないため削除
55
+ _mora_list_minimum: list[tuple[str, Optional[str], str]] = [
56
+ ("ヴォ", "v", "o"),
57
+ ("ヴェ", "v", "e"),
58
+ ("ヴィ", "v", "i"),
59
+ ("ヴァ", "v", "a"),
60
+ ("ヴ", "v", "u"),
61
+ ("ン", None, "N"),
62
+ ("ワ", "w", "a"),
63
+ ("ロ", "r", "o"),
64
+ ("レ", "r", "e"),
65
+ ("ル", "r", "u"),
66
+ ("リョ", "ry", "o"),
67
+ ("リュ", "ry", "u"),
68
+ ("リャ", "ry", "a"),
69
+ ("リェ", "ry", "e"),
70
+ ("リ", "r", "i"),
71
+ ("ラ", "r", "a"),
72
+ ("ヨ", "y", "o"),
73
+ ("ユ", "y", "u"),
74
+ ("ヤ", "y", "a"),
75
+ ("モ", "m", "o"),
76
+ ("メ", "m", "e"),
77
+ ("ム", "m", "u"),
78
+ ("ミョ", "my", "o"),
79
+ ("ミュ", "my", "u"),
80
+ ("ミャ", "my", "a"),
81
+ ("ミェ", "my", "e"),
82
+ ("ミ", "m", "i"),
83
+ ("マ", "m", "a"),
84
+ ("ポ", "p", "o"),
85
+ ("ボ", "b", "o"),
86
+ ("ホ", "h", "o"),
87
+ ("ペ", "p", "e"),
88
+ ("ベ", "b", "e"),
89
+ ("ヘ", "h", "e"),
90
+ ("プ", "p", "u"),
91
+ ("ブ", "b", "u"),
92
+ ("フォ", "f", "o"),
93
+ ("フェ", "f", "e"),
94
+ ("フィ", "f", "i"),
95
+ ("ファ", "f", "a"),
96
+ ("フ", "f", "u"),
97
+ ("ピョ", "py", "o"),
98
+ ("ピュ", "py", "u"),
99
+ ("ピャ", "py", "a"),
100
+ ("ピェ", "py", "e"),
101
+ ("ピ", "p", "i"),
102
+ ("ビョ", "by", "o"),
103
+ ("ビュ", "by", "u"),
104
+ ("ビャ", "by", "a"),
105
+ ("ビェ", "by", "e"),
106
+ ("ビ", "b", "i"),
107
+ ("ヒョ", "hy", "o"),
108
+ ("ヒュ", "hy", "u"),
109
+ ("ヒャ", "hy", "a"),
110
+ ("ヒェ", "hy", "e"),
111
+ ("ヒ", "h", "i"),
112
+ ("パ", "p", "a"),
113
+ ("バ", "b", "a"),
114
+ ("ハ", "h", "a"),
115
+ ("ノ", "n", "o"),
116
+ ("ネ", "n", "e"),
117
+ ("ヌ", "n", "u"),
118
+ ("ニョ", "ny", "o"),
119
+ ("ニュ", "ny", "u"),
120
+ ("ニャ", "ny", "a"),
121
+ ("ニェ", "ny", "e"),
122
+ ("ニ", "n", "i"),
123
+ ("ナ", "n", "a"),
124
+ ("ドゥ", "d", "u"),
125
+ ("ド", "d", "o"),
126
+ ("トゥ", "t", "u"),
127
+ ("ト", "t", "o"),
128
+ ("デョ", "dy", "o"),
129
+ ("デュ", "dy", "u"),
130
+ ("デャ", "dy", "a"),
131
+ # ("デェ", "dy", "e"),
132
+ ("ディ", "d", "i"),
133
+ ("デ", "d", "e"),
134
+ ("テョ", "ty", "o"),
135
+ ("テュ", "ty", "u"),
136
+ ("テャ", "ty", "a"),
137
+ ("ティ", "t", "i"),
138
+ ("テ", "t", "e"),
139
+ ("ツォ", "ts", "o"),
140
+ ("ツェ", "ts", "e"),
141
+ ("ツィ", "ts", "i"),
142
+ ("ツァ", "ts", "a"),
143
+ ("ツ", "ts", "u"),
144
+ ("ッ", None, "q"), # 「cl」から「q」に変更
145
+ ("チョ", "ch", "o"),
146
+ ("チュ", "ch", "u"),
147
+ ("チャ", "ch", "a"),
148
+ ("チェ", "ch", "e"),
149
+ ("チ", "ch", "i"),
150
+ ("ダ", "d", "a"),
151
+ ("タ", "t", "a"),
152
+ ("ゾ", "z", "o"),
153
+ ("ソ", "s", "o"),
154
+ ("ゼ", "z", "e"),
155
+ ("セ", "s", "e"),
156
+ ("ズィ", "z", "i"),
157
+ ("ズ", "z", "u"),
158
+ ("スィ", "s", "i"),
159
+ ("ス", "s", "u"),
160
+ ("ジョ", "j", "o"),
161
+ ("ジュ", "j", "u"),
162
+ ("ジャ", "j", "a"),
163
+ ("ジェ", "j", "e"),
164
+ ("ジ", "j", "i"),
165
+ ("ショ", "sh", "o"),
166
+ ("シュ", "sh", "u"),
167
+ ("シャ", "sh", "a"),
168
+ ("シェ", "sh", "e"),
169
+ ("シ", "sh", "i"),
170
+ ("ザ", "z", "a"),
171
+ ("サ", "s", "a"),
172
+ ("ゴ", "g", "o"),
173
+ ("コ", "k", "o"),
174
+ ("ゲ", "g", "e"),
175
+ ("ケ", "k", "e"),
176
+ ("グヮ", "gw", "a"),
177
+ ("グ", "g", "u"),
178
+ ("クヮ", "kw", "a"),
179
+ ("ク", "k", "u"),
180
+ ("ギョ", "gy", "o"),
181
+ ("ギュ", "gy", "u"),
182
+ ("ギャ", "gy", "a"),
183
+ ("ギェ", "gy", "e"),
184
+ ("ギ", "g", "i"),
185
+ ("キョ", "ky", "o"),
186
+ ("キュ", "ky", "u"),
187
+ ("キャ", "ky", "a"),
188
+ ("キェ", "ky", "e"),
189
+ ("キ", "k", "i"),
190
+ ("ガ", "g", "a"),
191
+ ("カ", "k", "a"),
192
+ ("オ", None, "o"),
193
+ ("エ", None, "e"),
194
+ ("ウォ", "w", "o"),
195
+ ("ウェ", "w", "e"),
196
+ ("ウィ", "w", "i"),
197
+ ("ウ", None, "u"),
198
+ ("イェ", "y", "e"),
199
+ ("イ", None, "i"),
200
+ ("ア", None, "a"),
201
+ ]
202
+ _mora_list_additional: list[tuple[str, Optional[str], str]] = [
203
+ ("ヴョ", "by", "o"),
204
+ ("ヴュ", "by", "u"),
205
+ ("ヴャ", "by", "a"),
206
+ ("ヲ", None, "o"),
207
+ ("ヱ", None, "e"),
208
+ ("ヰ", None, "i"),
209
+ ("ヮ", "w", "a"),
210
+ ("ョ", "y", "o"),
211
+ ("ュ", "y", "u"),
212
+ ("ヅ", "z", "u"),
213
+ ("ヂ", "j", "i"),
214
+ ("ヶ", "k", "e"),
215
+ ("ャ", "y", "a"),
216
+ ("ォ", None, "o"),
217
+ ("ェ", None, "e"),
218
+ ("ゥ", None, "u"),
219
+ ("ィ", None, "i"),
220
+ ("ァ", None, "a"),
221
+ ]
222
+
223
+ # 例: "vo" -> "ヴォ", "a" -> "ア"
224
+ mora_phonemes_to_mora_kata: dict[str, str] = {
225
+ (consonant or "") + vowel: kana for [kana, consonant, vowel] in _mora_list_minimum
226
+ }
227
+
228
+ # 例: "ヴォ" -> ("v", "o"), "ア" -> (None, "a")
229
+ mora_kata_to_mora_phonemes: dict[str, tuple[Optional[str], str]] = {
230
+ kana: (consonant, vowel)
231
+ for [kana, consonant, vowel] in _mora_list_minimum + _mora_list_additional
232
+ }
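A quick illustration of the two lookup tables defined above:

    mora_kata_to_mora_phonemes["キャ"]  # -> ("ky", "a")
    mora_kata_to_mora_phonemes["ン"]    # -> (None, "N")
    mora_phonemes_to_mora_kata["kya"]   # -> "キャ"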
text/opencpop-strict.txt ADDED
@@ -0,0 +1,429 @@
1
+ a AA a
2
+ ai AA ai
3
+ an AA an
4
+ ang AA ang
5
+ ao AA ao
6
+ ba b a
7
+ bai b ai
8
+ ban b an
9
+ bang b ang
10
+ bao b ao
11
+ bei b ei
12
+ ben b en
13
+ beng b eng
14
+ bi b i
15
+ bian b ian
16
+ biao b iao
17
+ bie b ie
18
+ bin b in
19
+ bing b ing
20
+ bo b o
21
+ bu b u
22
+ ca c a
23
+ cai c ai
24
+ can c an
25
+ cang c ang
26
+ cao c ao
27
+ ce c e
28
+ cei c ei
29
+ cen c en
30
+ ceng c eng
31
+ cha ch a
32
+ chai ch ai
33
+ chan ch an
34
+ chang ch ang
35
+ chao ch ao
36
+ che ch e
37
+ chen ch en
38
+ cheng ch eng
39
+ chi ch ir
40
+ chong ch ong
41
+ chou ch ou
42
+ chu ch u
43
+ chua ch ua
44
+ chuai ch uai
45
+ chuan ch uan
46
+ chuang ch uang
47
+ chui ch ui
48
+ chun ch un
49
+ chuo ch uo
50
+ ci c i0
51
+ cong c ong
52
+ cou c ou
53
+ cu c u
54
+ cuan c uan
55
+ cui c ui
56
+ cun c un
57
+ cuo c uo
58
+ da d a
59
+ dai d ai
60
+ dan d an
61
+ dang d ang
62
+ dao d ao
63
+ de d e
64
+ dei d ei
65
+ den d en
66
+ deng d eng
67
+ di d i
68
+ dia d ia
69
+ dian d ian
70
+ diao d iao
71
+ die d ie
72
+ ding d ing
73
+ diu d iu
74
+ dong d ong
75
+ dou d ou
76
+ du d u
77
+ duan d uan
78
+ dui d ui
79
+ dun d un
80
+ duo d uo
81
+ e EE e
82
+ ei EE ei
83
+ en EE en
84
+ eng EE eng
85
+ er EE er
86
+ fa f a
87
+ fan f an
88
+ fang f ang
89
+ fei f ei
90
+ fen f en
91
+ feng f eng
92
+ fo f o
93
+ fou f ou
94
+ fu f u
95
+ ga g a
96
+ gai g ai
97
+ gan g an
98
+ gang g ang
99
+ gao g ao
100
+ ge g e
101
+ gei g ei
102
+ gen g en
103
+ geng g eng
104
+ gong g ong
105
+ gou g ou
106
+ gu g u
107
+ gua g ua
108
+ guai g uai
109
+ guan g uan
110
+ guang g uang
111
+ gui g ui
112
+ gun g un
113
+ guo g uo
114
+ ha h a
115
+ hai h ai
116
+ han h an
117
+ hang h ang
118
+ hao h ao
119
+ he h e
120
+ hei h ei
121
+ hen h en
122
+ heng h eng
123
+ hong h ong
124
+ hou h ou
125
+ hu h u
126
+ hua h ua
127
+ huai h uai
128
+ huan h uan
129
+ huang h uang
130
+ hui h ui
131
+ hun h un
132
+ huo h uo
133
+ ji j i
134
+ jia j ia
135
+ jian j ian
136
+ jiang j iang
137
+ jiao j iao
138
+ jie j ie
139
+ jin j in
140
+ jing j ing
141
+ jiong j iong
142
+ jiu j iu
143
+ ju j v
144
+ jv j v
145
+ juan j van
146
+ jvan j van
147
+ jue j ve
148
+ jve j ve
149
+ jun j vn
150
+ jvn j vn
151
+ ka k a
152
+ kai k ai
153
+ kan k an
154
+ kang k ang
155
+ kao k ao
156
+ ke k e
157
+ kei k ei
158
+ ken k en
159
+ keng k eng
160
+ kong k ong
161
+ kou k ou
162
+ ku k u
163
+ kua k ua
164
+ kuai k uai
165
+ kuan k uan
166
+ kuang k uang
167
+ kui k ui
168
+ kun k un
169
+ kuo k uo
170
+ la l a
171
+ lai l ai
172
+ lan l an
173
+ lang l ang
174
+ lao l ao
175
+ le l e
176
+ lei l ei
177
+ leng l eng
178
+ li l i
179
+ lia l ia
180
+ lian l ian
181
+ liang l iang
182
+ liao l iao
183
+ lie l ie
184
+ lin l in
185
+ ling l ing
186
+ liu l iu
187
+ lo l o
188
+ long l ong
189
+ lou l ou
190
+ lu l u
191
+ luan l uan
192
+ lun l un
193
+ luo l uo
194
+ lv l v
195
+ lve l ve
196
+ ma m a
197
+ mai m ai
198
+ man m an
199
+ mang m ang
200
+ mao m ao
201
+ me m e
202
+ mei m ei
203
+ men m en
204
+ meng m eng
205
+ mi m i
206
+ mian m ian
207
+ miao m iao
208
+ mie m ie
209
+ min m in
210
+ ming m ing
211
+ miu m iu
212
+ mo m o
213
+ mou m ou
214
+ mu m u
215
+ na n a
216
+ nai n ai
217
+ nan n an
218
+ nang n ang
219
+ nao n ao
220
+ ne n e
221
+ nei n ei
222
+ nen n en
223
+ neng n eng
224
+ ni n i
225
+ nian n ian
226
+ niang n iang
227
+ niao n iao
228
+ nie n ie
229
+ nin n in
230
+ ning n ing
231
+ niu n iu
232
+ nong n ong
233
+ nou n ou
234
+ nu n u
235
+ nuan n uan
236
+ nun n un
237
+ nuo n uo
238
+ nv n v
239
+ nve n ve
240
+ o OO o
241
+ ou OO ou
242
+ pa p a
243
+ pai p ai
244
+ pan p an
245
+ pang p ang
246
+ pao p ao
247
+ pei p ei
248
+ pen p en
249
+ peng p eng
250
+ pi p i
251
+ pian p ian
252
+ piao p iao
253
+ pie p ie
254
+ pin p in
255
+ ping p ing
256
+ po p o
257
+ pou p ou
258
+ pu p u
259
+ qi q i
260
+ qia q ia
261
+ qian q ian
262
+ qiang q iang
263
+ qiao q iao
264
+ qie q ie
265
+ qin q in
266
+ qing q ing
267
+ qiong q iong
268
+ qiu q iu
269
+ qu q v
270
+ qv q v
271
+ quan q van
272
+ qvan q van
273
+ que q ve
274
+ qve q ve
275
+ qun q vn
276
+ qvn q vn
277
+ ran r an
278
+ rang r ang
279
+ rao r ao
280
+ re r e
281
+ ren r en
282
+ reng r eng
283
+ ri r ir
284
+ rong r ong
285
+ rou r ou
286
+ ru r u
287
+ rua r ua
288
+ ruan r uan
289
+ rui r ui
290
+ run r un
291
+ ruo r uo
292
+ sa s a
293
+ sai s ai
294
+ san s an
295
+ sang s ang
296
+ sao s ao
297
+ se s e
298
+ sen s en
299
+ seng s eng
300
+ sha sh a
301
+ shai sh ai
302
+ shan sh an
303
+ shang sh ang
304
+ shao sh ao
305
+ she sh e
306
+ shei sh ei
307
+ shen sh en
308
+ sheng sh eng
309
+ shi sh ir
310
+ shou sh ou
311
+ shu sh u
312
+ shua sh ua
313
+ shuai sh uai
314
+ shuan sh uan
315
+ shuang sh uang
316
+ shui sh ui
317
+ shun sh un
318
+ shuo sh uo
319
+ si s i0
320
+ song s ong
321
+ sou s ou
322
+ su s u
323
+ suan s uan
324
+ sui s ui
325
+ sun s un
326
+ suo s uo
327
+ ta t a
328
+ tai t ai
329
+ tan t an
330
+ tang t ang
331
+ tao t ao
332
+ te t e
333
+ tei t ei
334
+ teng t eng
335
+ ti t i
336
+ tian t ian
337
+ tiao t iao
338
+ tie t ie
339
+ ting t ing
340
+ tong t ong
341
+ tou t ou
342
+ tu t u
343
+ tuan t uan
344
+ tui t ui
345
+ tun t un
346
+ tuo t uo
347
+ wa w a
348
+ wai w ai
349
+ wan w an
350
+ wang w ang
351
+ wei w ei
352
+ wen w en
353
+ weng w eng
354
+ wo w o
355
+ wu w u
356
+ xi x i
357
+ xia x ia
358
+ xian x ian
359
+ xiang x iang
360
+ xiao x iao
361
+ xie x ie
362
+ xin x in
363
+ xing x ing
364
+ xiong x iong
365
+ xiu x iu
366
+ xu x v
367
+ xv x v
368
+ xuan x van
369
+ xvan x van
370
+ xue x ve
371
+ xve x ve
372
+ xun x vn
373
+ xvn x vn
374
+ ya y a
375
+ yan y En
376
+ yang y ang
377
+ yao y ao
378
+ ye y E
379
+ yi y i
380
+ yin y in
381
+ ying y ing
382
+ yo y o
383
+ yong y ong
384
+ you y ou
385
+ yu y v
386
+ yv y v
387
+ yuan y van
388
+ yvan y van
389
+ yue y ve
390
+ yve y ve
391
+ yun y vn
392
+ yvn y vn
393
+ za z a
394
+ zai z ai
395
+ zan z an
396
+ zang z ang
397
+ zao z ao
398
+ ze z e
399
+ zei z ei
400
+ zen z en
401
+ zeng z eng
402
+ zha zh a
403
+ zhai zh ai
404
+ zhan zh an
405
+ zhang zh ang
406
+ zhao zh ao
407
+ zhe zh e
408
+ zhei zh ei
409
+ zhen zh en
410
+ zheng zh eng
411
+ zhi zh ir
412
+ zhong zh ong
413
+ zhou zh ou
414
+ zhu zh u
415
+ zhua zh ua
416
+ zhuai zh uai
417
+ zhuan zh uan
418
+ zhuang zh uang
419
+ zhui zh ui
420
+ zhun zh un
421
+ zhuo zh uo
422
+ zi z i0
423
+ zong z ong
424
+ zou z ou
425
+ zu z u
426
+ zuan z uan
427
+ zui z ui
428
+ zun z un
429
+ zuo z uo
text/symbols.py ADDED
@@ -0,0 +1,187 @@
1
+ punctuation = ["!", "?", "…", ",", ".", "'", "-"]
2
+ pu_symbols = punctuation + ["SP", "UNK"]
3
+ pad = "_"
4
+
5
+ # chinese
6
+ zh_symbols = [
7
+ "E",
8
+ "En",
9
+ "a",
10
+ "ai",
11
+ "an",
12
+ "ang",
13
+ "ao",
14
+ "b",
15
+ "c",
16
+ "ch",
17
+ "d",
18
+ "e",
19
+ "ei",
20
+ "en",
21
+ "eng",
22
+ "er",
23
+ "f",
24
+ "g",
25
+ "h",
26
+ "i",
27
+ "i0",
28
+ "ia",
29
+ "ian",
30
+ "iang",
31
+ "iao",
32
+ "ie",
33
+ "in",
34
+ "ing",
35
+ "iong",
36
+ "ir",
37
+ "iu",
38
+ "j",
39
+ "k",
40
+ "l",
41
+ "m",
42
+ "n",
43
+ "o",
44
+ "ong",
45
+ "ou",
46
+ "p",
47
+ "q",
48
+ "r",
49
+ "s",
50
+ "sh",
51
+ "t",
52
+ "u",
53
+ "ua",
54
+ "uai",
55
+ "uan",
56
+ "uang",
57
+ "ui",
58
+ "un",
59
+ "uo",
60
+ "v",
61
+ "van",
62
+ "ve",
63
+ "vn",
64
+ "w",
65
+ "x",
66
+ "y",
67
+ "z",
68
+ "zh",
69
+ "AA",
70
+ "EE",
71
+ "OO",
72
+ ]
73
+ num_zh_tones = 6
74
+
75
+ # japanese
76
+ ja_symbols = [
77
+ "N",
78
+ "a",
79
+ "a:",
80
+ "b",
81
+ "by",
82
+ "ch",
83
+ "d",
84
+ "dy",
85
+ "e",
86
+ "e:",
87
+ "f",
88
+ "g",
89
+ "gy",
90
+ "h",
91
+ "hy",
92
+ "i",
93
+ "i:",
94
+ "j",
95
+ "k",
96
+ "ky",
97
+ "m",
98
+ "my",
99
+ "n",
100
+ "ny",
101
+ "o",
102
+ "o:",
103
+ "p",
104
+ "py",
105
+ "q",
106
+ "r",
107
+ "ry",
108
+ "s",
109
+ "sh",
110
+ "t",
111
+ "ts",
112
+ "ty",
113
+ "u",
114
+ "u:",
115
+ "w",
116
+ "y",
117
+ "z",
118
+ "zy",
119
+ ]
120
+ num_ja_tones = 2
121
+
122
+ # English
123
+ en_symbols = [
124
+ "aa",
125
+ "ae",
126
+ "ah",
127
+ "ao",
128
+ "aw",
129
+ "ay",
130
+ "b",
131
+ "ch",
132
+ "d",
133
+ "dh",
134
+ "eh",
135
+ "er",
136
+ "ey",
137
+ "f",
138
+ "g",
139
+ "hh",
140
+ "ih",
141
+ "iy",
142
+ "jh",
143
+ "k",
144
+ "l",
145
+ "m",
146
+ "n",
147
+ "ng",
148
+ "ow",
149
+ "oy",
150
+ "p",
151
+ "r",
152
+ "s",
153
+ "sh",
154
+ "t",
155
+ "th",
156
+ "uh",
157
+ "uw",
158
+ "V",
159
+ "w",
160
+ "y",
161
+ "z",
162
+ "zh",
163
+ ]
164
+ num_en_tones = 4
165
+
166
+ # combine all symbols
167
+ normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
168
+ symbols = [pad] + normal_symbols + pu_symbols
169
+ sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
170
+
171
+ # combine all tones
172
+ num_tones = num_zh_tones + num_ja_tones + num_en_tones
173
+
174
+ # language maps
175
+ language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
176
+ num_languages = len(language_id_map.keys())
177
+
178
+ language_tone_start_map = {
179
+ "ZH": 0,
180
+ "JP": num_zh_tones,
181
+ "EN": num_zh_tones + num_ja_tones,
182
+ }
183
+
184
+ if __name__ == "__main__":
185
+ a = set(zh_symbols)
186
+ b = set(en_symbols)
187
+ print(sorted(a & b))
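Tone indices from the three languages share one table, offset per language; an illustrative lookup (values follow from the counts above):

    language_id_map["JP"]              # -> 1
    language_tone_start_map["JP"] + 1  # -> 7, the global index of Japanese tone 1
    symbols.index("_")                 # -> 0, since the pad symbol is placed first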
text/tone_sandhi.py ADDED
@@ -0,0 +1,773 @@
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+ from typing import Tuple
16
+
17
+ import jieba
18
+ from pypinyin import lazy_pinyin
19
+ from pypinyin import Style
20
+
21
+
22
+ class ToneSandhi:
23
+ def __init__(self):
24
+ self.must_neural_tone_words = {
25
+ "麻烦",
26
+ "麻利",
27
+ "鸳鸯",
28
+ "高粱",
29
+ "骨头",
30
+ "骆驼",
31
+ "马虎",
32
+ "首饰",
33
+ "馒头",
34
+ "馄饨",
35
+ "风筝",
36
+ "难为",
37
+ "队伍",
38
+ "阔气",
39
+ "闺女",
40
+ "门道",
41
+ "锄头",
42
+ "铺盖",
43
+ "铃铛",
44
+ "铁匠",
45
+ "钥匙",
46
+ "里脊",
47
+ "里头",
48
+ "部分",
49
+ "那么",
50
+ "道士",
51
+ "造化",
52
+ "迷糊",
53
+ "连累",
54
+ "这么",
55
+ "这个",
56
+ "运气",
57
+ "过去",
58
+ "软和",
59
+ "转悠",
60
+ "踏实",
61
+ "跳蚤",
62
+ "跟头",
63
+ "趔趄",
64
+ "财主",
65
+ "豆腐",
66
+ "讲究",
67
+ "记性",
68
+ "记号",
69
+ "认识",
70
+ "规矩",
71
+ "见识",
72
+ "裁缝",
73
+ "补丁",
74
+ "衣裳",
75
+ "衣服",
76
+ "衙门",
77
+ "街坊",
78
+ "行李",
79
+ "行当",
80
+ "蛤蟆",
81
+ "蘑菇",
82
+ "薄荷",
83
+ "葫芦",
84
+ "葡萄",
85
+ "萝卜",
86
+ "荸荠",
87
+ "苗条",
88
+ "苗头",
89
+ "苍蝇",
90
+ "芝麻",
91
+ "舒服",
92
+ "舒坦",
93
+ "舌头",
94
+ "自在",
95
+ "膏药",
96
+ "脾气",
97
+ "脑袋",
98
+ "脊梁",
99
+ "能耐",
100
+ "胳膊",
101
+ "胭脂",
102
+ "胡萝",
103
+ "胡琴",
104
+ "胡同",
105
+ "聪明",
106
+ "耽误",
107
+ "耽搁",
108
+ "耷拉",
109
+ "耳朵",
110
+ "老爷",
111
+ "老实",
112
+ "老婆",
113
+ "老头",
114
+ "老太",
115
+ "翻腾",
116
+ "罗嗦",
117
+ "罐头",
118
+ "编辑",
119
+ "结实",
120
+ "红火",
121
+ "累赘",
122
+ "糨糊",
123
+ "糊涂",
124
+ "精神",
125
+ "粮食",
126
+ "簸箕",
127
+ "篱笆",
128
+ "算计",
129
+ "算盘",
130
+ "答应",
131
+ "笤帚",
132
+ "笑语",
133
+ "笑话",
134
+ "窟窿",
135
+ "窝囊",
136
+ "窗户",
137
+ "稳当",
138
+ "稀罕",
139
+ "称呼",
140
+ "秧歌",
141
+ "秀气",
142
+ "秀才",
143
+ "福气",
144
+ "祖宗",
145
+ "砚台",
146
+ "码头",
147
+ "石榴",
148
+ "石头",
149
+ "石匠",
150
+ "知识",
151
+ "眼睛",
152
+ "眯缝",
153
+ "眨巴",
154
+ "眉毛",
155
+ "相声",
156
+ "盘算",
157
+ "白净",
158
+ "痢疾",
159
+ "痛快",
160
+ "疟疾",
161
+ "疙瘩",
162
+ "疏忽",
163
+ "畜生",
164
+ "生意",
165
+ "甘蔗",
166
+ "琵琶",
167
+ "琢磨",
168
+ "琉璃",
169
+ "玻璃",
170
+ "玫瑰",
171
+ "玄乎",
172
+ "狐狸",
173
+ "状元",
174
+ "特务",
175
+ "牲口",
176
+ "牙碜",
177
+ "牌楼",
178
+ "爽快",
179
+ "爱人",
180
+ "热闹",
181
+ "烧饼",
182
+ "烟筒",
183
+ "烂糊",
184
+ "点心",
185
+ "炊帚",
186
+ "灯笼",
187
+ "火候",
188
+ "漂亮",
189
+ "滑溜",
190
+ "溜达",
191
+ "温和",
192
+ "清楚",
193
+ "消息",
194
+ "浪头",
195
+ "活泼",
196
+ "比方",
197
+ "正经",
198
+ "欺负",
199
+ "模糊",
200
+ "槟榔",
201
+ "棺材",
202
+ "棒槌",
203
+ "棉花",
204
+ "核桃",
205
+ "栅栏",
206
+ "柴火",
207
+ "架势",
208
+ "枕头",
209
+ "枇杷",
210
+ "机灵",
211
+ "本事",
212
+ "木头",
213
+ "木匠",
214
+ "朋友",
215
+ "月饼",
216
+ "月亮",
217
+ "暖和",
218
+ "明白",
219
+ "时候",
220
+ "新鲜",
221
+ "故事",
222
+ "收拾",
223
+ "收成",
224
+ "提防",
225
+ "挖苦",
226
+ "挑剔",
227
+ "指甲",
228
+ "指头",
229
+ "拾掇",
230
+ "拳头",
231
+ "拨弄",
232
+ "招牌",
233
+ "招呼",
234
+ "抬举",
235
+ "护士",
236
+ "折腾",
237
+ "扫帚",
238
+ "打量",
239
+ "打算",
240
+ "打点",
241
+ "打扮",
242
+ "打听",
243
+ "打发",
244
+ "扎实",
245
+ "扁担",
246
+ "戒指",
247
+ "懒得",
248
+ "意识",
249
+ "意思",
250
+ "情形",
251
+ "悟性",
252
+ "怪物",
253
+ "思量",
254
+ "怎么",
255
+ "念头",
256
+ "念叨",
257
+ "快活",
258
+ "忙活",
259
+ "志气",
260
+ "心思",
261
+ "得罪",
262
+ "张罗",
263
+ "弟兄",
264
+ "开通",
265
+ "应酬",
266
+ "庄稼",
267
+ "干事",
268
+ "帮手",
269
+ "帐篷",
270
+ "希罕",
271
+ "师父",
272
+ "师傅",
273
+ "巴结",
274
+ "巴掌",
275
+ "差事",
276
+ "工夫",
277
+ "岁数",
278
+ "屁股",
279
+ "尾巴",
280
+ "少爷",
281
+ "小气",
282
+ "小伙",
283
+ "将就",
284
+ "对头",
285
+ "对付",
286
+ "寡妇",
287
+ "家伙",
288
+ "客气",
289
+ "实在",
290
+ "官司",
291
+ "学问",
292
+ "学生",
293
+ "字号",
294
+ "嫁妆",
295
+ "媳妇",
296
+ "媒人",
297
+ "婆家",
298
+ "娘家",
299
+ "委屈",
300
+ "姑娘",
301
+ "姐夫",
302
+ "妯娌",
303
+ "妥当",
304
+ "妖精",
305
+ "奴才",
306
+ "女婿",
307
+ "头发",
308
+ "太阳",
309
+ "大爷",
310
+ "大方",
311
+ "大意",
312
+ "大夫",
313
+ "多少",
314
+ "多么",
315
+ "外甥",
316
+ "壮实",
317
+ "地道",
318
+ "地方",
319
+ "在乎",
320
+ "困难",
321
+ "嘴巴",
322
+ "嘱咐",
323
+ "嘟囔",
324
+ "嘀咕",
325
+ "喜欢",
326
+ "喇嘛",
327
+ "喇叭",
328
+ "商量",
329
+ "唾沫",
330
+ "哑巴",
331
+ "哈欠",
332
+ "哆嗦",
333
+ "咳嗽",
334
+ "和尚",
335
+ "告诉",
336
+ "告示",
337
+ "含糊",
338
+ "吓唬",
339
+ "后头",
340
+ "名字",
341
+ "名堂",
342
+ "合同",
343
+ "吆喝",
344
+ "叫唤",
345
+ "口袋",
346
+ "厚道",
347
+ "厉害",
348
+ "千斤",
349
+ "包袱",
350
+ "包涵",
351
+ "匀称",
352
+ "勤快",
353
+ "动静",
354
+ "动弹",
355
+ "功夫",
356
+ "力气",
357
+ "前头",
358
+ "刺猬",
359
+ "刺激",
360
+ "别扭",
361
+ "利落",
362
+ "利索",
363
+ "利害",
364
+ "分析",
365
+ "出息",
366
+ "凑合",
367
+ "凉快",
368
+ "冷战",
369
+ "冤枉",
370
+ "冒失",
371
+ "养活",
372
+ "关系",
373
+ "先生",
374
+ "兄弟",
375
+ "便宜",
376
+ "使唤",
377
+ "佩服",
378
+ "作坊",
379
+ "体面",
380
+ "位置",
381
+ "似的",
382
+ "伙计",
383
+ "休息",
384
+ "什么",
385
+ "人家",
386
+ "亲戚",
387
+ "亲家",
388
+ "交情",
389
+ "云彩",
390
+ "事情",
391
+ "买卖",
392
+ "主意",
393
+ "丫头",
394
+ "丧气",
395
+ "两口",
396
+ "东西",
397
+ "东家",
398
+ "世故",
399
+ "不由",
400
+ "不在",
401
+ "下水",
402
+ "下巴",
403
+ "上头",
404
+ "上司",
405
+ "丈夫",
406
+ "丈人",
407
+ "一辈",
408
+ "那个",
409
+ "菩萨",
410
+ "父亲",
411
+ "母亲",
412
+ "咕噜",
413
+ "邋遢",
414
+ "费用",
415
+ "冤家",
416
+ "甜头",
417
+ "介绍",
418
+ "荒唐",
419
+ "大人",
420
+ "泥鳅",
421
+ "幸福",
422
+ "熟悉",
423
+ "计划",
424
+ "扑腾",
425
+ "蜡烛",
426
+ "姥爷",
427
+ "照顾",
428
+ "喉咙",
429
+ "吉他",
430
+ "弄堂",
431
+ "蚂蚱",
432
+ "凤凰",
433
+ "拖沓",
434
+ "寒碜",
435
+ "糟蹋",
436
+ "倒腾",
437
+ "报复",
438
+ "逻辑",
439
+ "盘缠",
440
+ "喽啰",
441
+ "牢骚",
442
+ "咖喱",
443
+ "扫把",
444
+ "惦记",
445
+ }
446
+ self.must_not_neural_tone_words = {
447
+ "男子",
448
+ "女子",
449
+ "分子",
450
+ "原子",
451
+ "量子",
452
+ "莲子",
453
+ "石子",
454
+ "瓜子",
455
+ "电子",
456
+ "人人",
457
+ "虎虎",
458
+ }
459
+ self.punc = ":,;。?!“”‘’':,;.?!"
460
+
461
+ # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
462
+ # e.g.
463
+ # word: "家里"
464
+ # pos: "s"
465
+ # finals: ['ia1', 'i3']
466
+ def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
467
+ # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
468
+ for j, item in enumerate(word):
469
+ if (
470
+ j - 1 >= 0
471
+ and item == word[j - 1]
472
+ and pos[0] in {"n", "v", "a"}
473
+ and word not in self.must_not_neural_tone_words
474
+ ):
475
+ finals[j] = finals[j][:-1] + "5"
476
+ ge_idx = word.find("个")
477
+ if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
478
+ finals[-1] = finals[-1][:-1] + "5"
479
+ elif len(word) >= 1 and word[-1] in "的地得":
480
+ finals[-1] = finals[-1][:-1] + "5"
481
+ # e.g. 走了, 看着, 去过
482
+ # elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
483
+ # finals[-1] = finals[-1][:-1] + "5"
484
+ elif (
485
+ len(word) > 1
486
+ and word[-1] in "们子"
487
+ and pos in {"r", "n"}
488
+ and word not in self.must_not_neural_tone_words
489
+ ):
490
+ finals[-1] = finals[-1][:-1] + "5"
491
+ # e.g. 桌上, 地下, 家里
492
+ elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
493
+ finals[-1] = finals[-1][:-1] + "5"
494
+ # e.g. 上来, 下去
495
+ elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
496
+ finals[-1] = finals[-1][:-1] + "5"
497
+ # 个做量词
498
+ elif (
499
+ ge_idx >= 1
500
+ and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
501
+ ) or word == "个":
502
+ finals[ge_idx] = finals[ge_idx][:-1] + "5"
503
+ else:
504
+ if (
505
+ word in self.must_neural_tone_words
506
+ or word[-2:] in self.must_neural_tone_words
507
+ ):
508
+ finals[-1] = finals[-1][:-1] + "5"
509
+
510
+ word_list = self._split_word(word)
511
+ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
512
+ for i, word in enumerate(word_list):
513
+ # conventional neural in Chinese
514
+ if (
515
+ word in self.must_neural_tone_words
516
+ or word[-2:] in self.must_neural_tone_words
517
+ ):
518
+ finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
519
+ finals = sum(finals_list, [])
520
+ return finals
521
+
522
+ def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
523
+ # e.g. 看不懂
524
+ if len(word) == 3 and word[1] == "不":
525
+ finals[1] = finals[1][:-1] + "5"
526
+ else:
527
+ for i, char in enumerate(word):
528
+ # "不" before tone4 should be bu2, e.g. 不怕
529
+ if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4":
530
+ finals[i] = finals[i][:-1] + "2"
531
+ return finals
532
+
533
+ def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
534
+ # "一" in number sequences, e.g. 一零零, 二一零
535
+ if word.find("一") != -1 and all(
536
+ [item.isnumeric() for item in word if item != "一"]
537
+ ):
538
+ return finals
539
+ # "一" between reduplication words should be yi5, e.g. 看一看
540
+ elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
541
+ finals[1] = finals[1][:-1] + "5"
542
+ # when "一" is ordinal word, it should be yi1
543
+ elif word.startswith("第一"):
544
+ finals[1] = finals[1][:-1] + "1"
545
+ else:
546
+ for i, char in enumerate(word):
547
+ if char == "一" and i + 1 < len(word):
548
+ # "一" before tone4 should be yi2, e.g. 一段
549
+ if finals[i + 1][-1] == "4":
550
+ finals[i] = finals[i][:-1] + "2"
551
+ # "一" before non-tone4 should be yi4, e.g. 一天
552
+ else:
553
+ # "一" 后面如果是标点,还读一声
554
+ if word[i + 1] not in self.punc:
555
+ finals[i] = finals[i][:-1] + "4"
556
+ return finals
557
+
558
+ def _split_word(self, word: str) -> List[str]:
559
+ word_list = jieba.cut_for_search(word)
560
+ word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
561
+ first_subword = word_list[0]
562
+ first_begin_idx = word.find(first_subword)
563
+ if first_begin_idx == 0:
564
+ second_subword = word[len(first_subword) :]
565
+ new_word_list = [first_subword, second_subword]
566
+ else:
567
+ second_subword = word[: -len(first_subword)]
568
+ new_word_list = [second_subword, first_subword]
569
+ return new_word_list
570
+
571
+ def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
572
+ if len(word) == 2 and self._all_tone_three(finals):
573
+ finals[0] = finals[0][:-1] + "2"
574
+ elif len(word) == 3:
575
+ word_list = self._split_word(word)
576
+ if self._all_tone_three(finals):
577
+ # disyllabic + monosyllabic, e.g. 蒙古/包
578
+ if len(word_list[0]) == 2:
579
+ finals[0] = finals[0][:-1] + "2"
580
+ finals[1] = finals[1][:-1] + "2"
581
+ # monosyllabic + disyllabic, e.g. 纸/老虎
582
+ elif len(word_list[0]) == 1:
583
+ finals[1] = finals[1][:-1] + "2"
584
+ else:
585
+ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
586
+ if len(finals_list) == 2:
587
+ for i, sub in enumerate(finals_list):
588
+ # e.g. 所有/人
589
+ if self._all_tone_three(sub) and len(sub) == 2:
590
+ finals_list[i][0] = finals_list[i][0][:-1] + "2"
591
+ # e.g. 好/喜欢
592
+ elif (
593
+ i == 1
594
+ and not self._all_tone_three(sub)
595
+ and finals_list[i][0][-1] == "3"
596
+ and finals_list[0][-1][-1] == "3"
597
+ ):
598
+ finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
599
+ finals = sum(finals_list, [])
600
+ # split idiom into two words who's length is 2
601
+ elif len(word) == 4:
602
+ finals_list = [finals[:2], finals[2:]]
603
+ finals = []
604
+ for sub in finals_list:
605
+ if self._all_tone_three(sub):
606
+ sub[0] = sub[0][:-1] + "2"
607
+ finals += sub
608
+
609
+ return finals
610
+
611
+ def _all_tone_three(self, finals: List[str]) -> bool:
612
+ return all(x[-1] == "3" for x in finals)
613
+
614
+ # merge "不" and the word behind it
615
+ # if don't merge, "不" sometimes appears alone according to jieba, which may occur sandhi error
616
+ def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
617
+ new_seg = []
618
+ last_word = ""
619
+ for word, pos in seg:
620
+ if last_word == "不":
621
+ word = last_word + word
622
+ if word != "不":
623
+ new_seg.append((word, pos))
624
+ last_word = word[:]
625
+ if last_word == "不":
626
+ new_seg.append((last_word, "d"))
627
+ last_word = ""
628
+ return new_seg
629
+
630
+ # function 1: merge "一" and reduplication words in it's left and right, e.g. "听","一","听" ->"听一听"
631
+ # function 2: merge single "一" and the word behind it
632
+ # if don't merge, "一" sometimes appears alone according to jieba, which may occur sandhi error
633
+ # e.g.
634
+ # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
635
+ # output seg: [['听一听', 'v']]
636
+ def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
637
+ new_seg = [] * len(seg)
638
+ # function 1
639
+ i = 0
640
+ while i < len(seg):
641
+ word, pos = seg[i]
642
+ if (
643
+ i - 1 >= 0
644
+ and word == "一"
645
+ and i + 1 < len(seg)
646
+ and seg[i - 1][0] == seg[i + 1][0]
647
+ and seg[i - 1][1] == "v"
648
+ ):
649
+ new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
650
+ i += 2
651
+ else:
652
+ if (
653
+ i - 2 >= 0
654
+ and seg[i - 1][0] == "一"
655
+ and seg[i - 2][0] == word
656
+ and pos == "v"
657
+ ):
658
+ continue
659
+ else:
660
+ new_seg.append([word, pos])
661
+ i += 1
662
+ seg = [i for i in new_seg if len(i) > 0]
663
+ new_seg = []
664
+ # function 2
665
+ for i, (word, pos) in enumerate(seg):
666
+ if new_seg and new_seg[-1][0] == "一":
667
+ new_seg[-1][0] = new_seg[-1][0] + word
668
+ else:
669
+ new_seg.append([word, pos])
670
+ return new_seg
671
+
672
+ # the first and the second words are all_tone_three
673
+ def _merge_continuous_three_tones(
674
+ self, seg: List[Tuple[str, str]]
675
+ ) -> List[Tuple[str, str]]:
676
+ new_seg = []
677
+ sub_finals_list = [
678
+ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
679
+ for (word, pos) in seg
680
+ ]
681
+ assert len(sub_finals_list) == len(seg)
682
+ merge_last = [False] * len(seg)
683
+ for i, (word, pos) in enumerate(seg):
684
+ if (
685
+ i - 1 >= 0
686
+ and self._all_tone_three(sub_finals_list[i - 1])
687
+ and self._all_tone_three(sub_finals_list[i])
688
+ and not merge_last[i - 1]
689
+ ):
690
+ # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
691
+ if (
692
+ not self._is_reduplication(seg[i - 1][0])
693
+ and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
694
+ ):
695
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
696
+ merge_last[i] = True
697
+ else:
698
+ new_seg.append([word, pos])
699
+ else:
700
+ new_seg.append([word, pos])
701
+
702
+ return new_seg
703
+
704
+ def _is_reduplication(self, word: str) -> bool:
705
+ return len(word) == 2 and word[0] == word[1]
706
+
707
+ # the last char of first word and the first char of second word is tone_three
708
+ def _merge_continuous_three_tones_2(
709
+ self, seg: List[Tuple[str, str]]
710
+ ) -> List[Tuple[str, str]]:
711
+ new_seg = []
712
+ sub_finals_list = [
713
+ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
714
+ for (word, pos) in seg
715
+ ]
716
+ assert len(sub_finals_list) == len(seg)
717
+ merge_last = [False] * len(seg)
718
+ for i, (word, pos) in enumerate(seg):
719
+ if (
720
+ i - 1 >= 0
721
+ and sub_finals_list[i - 1][-1][-1] == "3"
722
+ and sub_finals_list[i][0][-1] == "3"
723
+ and not merge_last[i - 1]
724
+ ):
725
+ # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
726
+ if (
727
+ not self._is_reduplication(seg[i - 1][0])
728
+ and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
729
+ ):
730
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
731
+ merge_last[i] = True
732
+ else:
733
+ new_seg.append([word, pos])
734
+ else:
735
+ new_seg.append([word, pos])
736
+ return new_seg
737
+
738
+ def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
739
+ new_seg = []
740
+ for i, (word, pos) in enumerate(seg):
741
+ if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#":
742
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
743
+ else:
744
+ new_seg.append([word, pos])
745
+ return new_seg
746
+
747
+ def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
748
+ new_seg = []
749
+ for i, (word, pos) in enumerate(seg):
750
+ if new_seg and word == new_seg[-1][0]:
751
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
752
+ else:
753
+ new_seg.append([word, pos])
754
+ return new_seg
755
+
756
+ def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
757
+ seg = self._merge_bu(seg)
758
+ try:
759
+ seg = self._merge_yi(seg)
760
+ except:
761
+ print("_merge_yi failed")
762
+ seg = self._merge_reduplication(seg)
763
+ seg = self._merge_continuous_three_tones(seg)
764
+ seg = self._merge_continuous_three_tones_2(seg)
765
+ seg = self._merge_er(seg)
766
+ return seg
767
+
768
+ def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
769
+ finals = self._bu_sandhi(word, finals)
770
+ finals = self._yi_sandhi(word, finals)
771
+ finals = self._neural_sandhi(word, pos, finals)
772
+ finals = self._three_sandhi(word, finals)
773
+ return finals
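modified_tone applies the four passes in a fixed order (不, 一, neutral tone, third tone); two illustrative calls with pypinyin-style TONE3 finals, whose results follow from the rules above:

    sandhi = ToneSandhi()
    sandhi.modified_tone("不怕", "d", ["u4", "a4"])   # -> ["u2", "a4"]   (不 before tone 4 -> bu2)
    sandhi.modified_tone("你好", "l", ["i3", "ao3"])  # -> ["i2", "ao3"]  (third-tone pair: first syllable shifts to tone 2)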
tools/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """
2
+ 工具包
3
+ """
tools/classify_language.py ADDED
@@ -0,0 +1,197 @@
1
+ import regex as re
2
+
3
+ try:
4
+ from config import config
5
+
6
+ LANGUAGE_IDENTIFICATION_LIBRARY = (
7
+ config.webui_config.language_identification_library
8
+ )
9
+ except:
10
+ LANGUAGE_IDENTIFICATION_LIBRARY = "langid"
11
+
12
+ module = LANGUAGE_IDENTIFICATION_LIBRARY.lower()
13
+
14
+ langid_languages = [
15
+ "af",
16
+ "am",
17
+ "an",
18
+ "ar",
19
+ "as",
20
+ "az",
21
+ "be",
22
+ "bg",
23
+ "bn",
24
+ "br",
25
+ "bs",
26
+ "ca",
27
+ "cs",
28
+ "cy",
29
+ "da",
30
+ "de",
31
+ "dz",
32
+ "el",
33
+ "en",
34
+ "eo",
35
+ "es",
36
+ "et",
37
+ "eu",
38
+ "fa",
39
+ "fi",
40
+ "fo",
41
+ "fr",
42
+ "ga",
43
+ "gl",
44
+ "gu",
45
+ "he",
46
+ "hi",
47
+ "hr",
48
+ "ht",
49
+ "hu",
50
+ "hy",
51
+ "id",
52
+ "is",
53
+ "it",
54
+ "ja",
55
+ "jv",
56
+ "ka",
57
+ "kk",
58
+ "km",
59
+ "kn",
60
+ "ko",
61
+ "ku",
62
+ "ky",
63
+ "la",
64
+ "lb",
65
+ "lo",
66
+ "lt",
67
+ "lv",
68
+ "mg",
69
+ "mk",
70
+ "ml",
71
+ "mn",
72
+ "mr",
73
+ "ms",
74
+ "mt",
75
+ "nb",
76
+ "ne",
77
+ "nl",
78
+ "nn",
79
+ "no",
80
+ "oc",
81
+ "or",
82
+ "pa",
83
+ "pl",
84
+ "ps",
85
+ "pt",
86
+ "qu",
87
+ "ro",
88
+ "ru",
89
+ "rw",
90
+ "se",
91
+ "si",
92
+ "sk",
93
+ "sl",
94
+ "sq",
95
+ "sr",
96
+ "sv",
97
+ "sw",
98
+ "ta",
99
+ "te",
100
+ "th",
101
+ "tl",
102
+ "tr",
103
+ "ug",
104
+ "uk",
105
+ "ur",
106
+ "vi",
107
+ "vo",
108
+ "wa",
109
+ "xh",
110
+ "zh",
111
+ "zu",
112
+ ]
113
+
114
+
115
+ def classify_language(text: str, target_languages: list = None) -> str:
116
+ if module == "fastlid" or module == "fasttext":
117
+ from fastlid import fastlid, supported_langs
118
+
119
+ classifier = fastlid
120
+ if target_languages != None:
121
+ target_languages = [
122
+ lang for lang in target_languages if lang in supported_langs
123
+ ]
124
+ fastlid.set_languages = target_languages
125
+ elif module == "langid":
126
+ import langid
127
+
128
+ classifier = langid.classify
129
+ if target_languages != None:
130
+ target_languages = [
131
+ lang for lang in target_languages if lang in langid_languages
132
+ ]
133
+ langid.set_languages(target_languages)
134
+ else:
135
+ raise ValueError(f"Wrong module {module}")
136
+
137
+ lang = classifier(text)[0]
138
+
139
+ return lang
140
+
141
+
142
+ def classify_zh_ja(text: str) -> str:
143
+ for idx, char in enumerate(text):
144
+ unicode_val = ord(char)
145
+
146
+ # 检测日语字符
147
+ if 0x3040 <= unicode_val <= 0x309F or 0x30A0 <= unicode_val <= 0x30FF:
148
+ return "ja"
149
+
150
+ # 检测汉字字符
151
+ if 0x4E00 <= unicode_val <= 0x9FFF:
152
+ # 检查周围的字符
153
+ next_char = text[idx + 1] if idx + 1 < len(text) else None
154
+
155
+ if next_char and (
156
+ 0x3040 <= ord(next_char) <= 0x309F or 0x30A0 <= ord(next_char) <= 0x30FF
157
+ ):
158
+ return "ja"
159
+
160
+ return "zh"
161
+
162
+
163
+ def split_alpha_nonalpha(text, mode=1):
164
+ if mode == 1:
165
+ pattern = r"(?<=[\u4e00-\u9fff\u3040-\u30FF\d\s])(?=[\p{Latin}])|(?<=[\p{Latin}\s])(?=[\u4e00-\u9fff\u3040-\u30FF\d])"
166
+ elif mode == 2:
167
+ pattern = r"(?<=[\u4e00-\u9fff\u3040-\u30FF\s])(?=[\p{Latin}\d])|(?<=[\p{Latin}\d\s])(?=[\u4e00-\u9fff\u3040-\u30FF])"
168
+ else:
169
+ raise ValueError("Invalid mode. Supported modes are 1 and 2.")
170
+
171
+ return re.split(pattern, text)
172
+
173
+
174
+ if __name__ == "__main__":
175
+ text = "这是一个测试文本"
176
+ print(classify_language(text))
177
+ print(classify_zh_ja(text)) # "zh"
178
+
179
+ text = "これはテストテキストです"
180
+ print(classify_language(text))
181
+ print(classify_zh_ja(text)) # "ja"
182
+
183
+ text = "vits和Bert-VITS2是tts模型。花费3days.花费3天。Take 3 days"
184
+
185
+ print(split_alpha_nonalpha(text, mode=1))
186
+ # output: ['vits', '和', 'Bert-VITS', '2是', 'tts', '模型。花费3', 'days.花费3天。Take 3 days']
187
+
188
+ print(split_alpha_nonalpha(text, mode=2))
189
+ # output: ['vits', '和', 'Bert-VITS2', '是', 'tts', '模型。花费', '3days.花费', '3', '天。Take 3 days']
190
+
191
+ text = "vits 和 Bert-VITS2 是 tts 模型。花费3days.花费3天。Take 3 days"
192
+ print(split_alpha_nonalpha(text, mode=1))
193
+ # output: ['vits ', '和 ', 'Bert-VITS', '2 ', '是 ', 'tts ', '模型。花费3', 'days.花费3天。Take ', '3 ', 'days']
194
+
195
+ text = "vits 和 Bert-VITS2 是 tts 模型。花费3days.花费3天。Take 3 days"
196
+ print(split_alpha_nonalpha(text, mode=2))
197
+ # output: ['vits ', '和 ', 'Bert-VITS2 ', '是 ', 'tts ', '模型。花费', '3days.花费', '3', '天。Take ', '3 ', 'days']
tools/sentence.py ADDED
@@ -0,0 +1,173 @@
1
+ import logging
2
+
3
+ import regex as re
4
+
5
+ from tools.classify_language import classify_language, split_alpha_nonalpha
6
+
7
+
8
+ def check_is_none(item) -> bool:
9
+ """none -> True, not none -> False"""
10
+ return (
11
+ item is None
12
+ or (isinstance(item, str) and str(item).isspace())
13
+ or str(item) == ""
14
+ )
15
+
16
+
17
+ def markup_language(text: str, target_languages: list = None) -> str:
18
+ pattern = (
19
+ r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`"
20
+ r"\!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」"
21
+ r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+"
22
+ )
23
+ sentences = re.split(pattern, text)
24
+
25
+ pre_lang = ""
26
+ p = 0
27
+
28
+ if target_languages is not None:
29
+ sorted_target_languages = sorted(target_languages)
30
+ if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]:
31
+ new_sentences = []
32
+ for sentence in sentences:
33
+ new_sentences.extend(split_alpha_nonalpha(sentence))
34
+ sentences = new_sentences
35
+
36
+ for sentence in sentences:
37
+ if check_is_none(sentence):
38
+ continue
39
+
40
+ lang = classify_language(sentence, target_languages)
41
+
42
+ if pre_lang == "":
43
+ text = text[:p] + text[p:].replace(
44
+ sentence, f"[{lang.upper()}]{sentence}", 1
45
+ )
46
+ p += len(f"[{lang.upper()}]")
47
+ elif pre_lang != lang:
48
+ text = text[:p] + text[p:].replace(
49
+ sentence, f"[{pre_lang.upper()}][{lang.upper()}]{sentence}", 1
50
+ )
51
+ p += len(f"[{pre_lang.upper()}][{lang.upper()}]")
52
+ pre_lang = lang
53
+        p += text[p:].index(sentence) + len(sentence)
+    text += f"[{pre_lang.upper()}]"
+
+    return text
+
+
+def split_by_language(text: str, target_languages: list = None) -> list:
+    pattern = (
+        r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`"
+        r"\!?\。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」"
+        r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+"
+    )
+    sentences = re.split(pattern, text)
+
+    pre_lang = ""
+    start = 0
+    end = 0
+    sentences_list = []
+
+    if target_languages is not None:
+        sorted_target_languages = sorted(target_languages)
+        if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]:
+            new_sentences = []
+            for sentence in sentences:
+                new_sentences.extend(split_alpha_nonalpha(sentence))
+            sentences = new_sentences
+
+    for sentence in sentences:
+        if check_is_none(sentence):
+            continue
+
+        lang = classify_language(sentence, target_languages)
+
+        end += text[end:].index(sentence)
+        if pre_lang != "" and pre_lang != lang:
+            sentences_list.append((text[start:end], pre_lang))
+            start = end
+        end += len(sentence)
+        pre_lang = lang
+    sentences_list.append((text[start:], pre_lang))
+
+    return sentences_list
+
+
+def sentence_split(text: str, max: int) -> list:
+    pattern = r"[!(),—+\-.:;??。,、;:]+"
+    sentences = re.split(pattern, text)
+    discarded_chars = re.findall(pattern, text)
+
+    sentences_list, count, p = [], 0, 0
+
+    # Walk through the delimiters that the text was split on
+    for i, discarded_char in enumerate(discarded_chars):
+        count += len(sentences[i]) + len(discarded_char)
+        if count >= max:
+            sentences_list.append(text[p : p + count].strip())
+            p += count
+            count = 0
+
+    # Append whatever text remains at the end
+    if p < len(text):
+        sentences_list.append(text[p:])
+
+    return sentences_list
+
+
+def sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None):
+    # If this speaker only supports a single language
+    if speaker_lang is not None and len(speaker_lang) == 1:
+        if lang.upper() not in ["AUTO", "MIX"] and lang.lower() != speaker_lang[0]:
+            logging.debug(
+                f'lang "{lang}" is not in speaker_lang {speaker_lang}, automatically set lang={speaker_lang[0]}'
+            )
+            lang = speaker_lang[0]
+
+    sentences_list = []
+    if lang.upper() != "MIX":
+        if max <= 0:
+            sentences_list.append(
+                markup_language(text, speaker_lang)
+                if lang.upper() == "AUTO"
+                else f"[{lang.upper()}]{text}[{lang.upper()}]"
+            )
+        else:
+            for i in sentence_split(text, max):
+                if check_is_none(i):
+                    continue
+                sentences_list.append(
+                    markup_language(i, speaker_lang)
+                    if lang.upper() == "AUTO"
+                    else f"[{lang.upper()}]{i}[{lang.upper()}]"
+                )
+    else:
+        sentences_list.append(text)
+
+    for i in sentences_list:
+        logging.debug(i)
+
+    return sentences_list
+
+
+if __name__ == "__main__":
+    text = "这几天心里颇不宁静。今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。月亮渐渐地升高了,墙外马路上孩子们的欢笑,已经听不见了;妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。我悄悄地披了大衫,带上门出去。"
+    print(markup_language(text, target_languages=None))
+    print(sentence_split(text, max=50))
+    print(sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None))
+
+    text = "你好,这是一段用来测试自动标注的文本。こんにちは,これは自動ラベリングのテスト用テキストです.Hello, this is a piece of text to test autotagging.你好!今天我们要介绍VITS项目,其重点是使用了GAN Duration predictor和transformer flow,并且接入了Bert模型来提升韵律。Bert embedding会在稍后介绍。"
+    print(split_by_language(text, ["zh", "ja", "en"]))
+
+    text = "vits和Bert-VITS2是tts模型。花费3days.花费3天。Take 3 days"
+
+    print(split_by_language(text, ["zh", "ja", "en"]))
+    # output: [('vits', 'en'), ('和', 'ja'), ('Bert-VITS', 'en'), ('2是', 'zh'), ('tts', 'en'), ('模型。花费3', 'zh'), ('days.', 'en'), ('花费3天。', 'zh'), ('Take 3 days', 'en')]
+
+    print(split_by_language(text, ["zh", "en"]))
+    # output: [('vits', 'en'), ('和', 'zh'), ('Bert-VITS', 'en'), ('2是', 'zh'), ('tts', 'en'), ('模型。花费3', 'zh'), ('days.', 'en'), ('花费3天。', 'zh'), ('Take 3 days', 'en')]
+
+    text = "vits 和 Bert-VITS2 是 tts 模型。花费 3 days. 花费 3天。Take 3 days"
+    print(split_by_language(text, ["zh", "en"]))
+    # output: [('vits ', 'en'), ('和 ', 'zh'), ('Bert-VITS2 ', 'en'), ('是 ', 'zh'), ('tts ', 'en'), ('模型。花费 ', 'zh'), ('3 days. ', 'en'), ('花费 3天。', 'zh'), ('Take 3 days', 'en')]
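A minimal sketch of how markup_language tags mixed-language input; the call and its output are inferred from the code above rather than captured from a run:

    from tools.sentence import markup_language
    # each detected-language run is wrapped in [LANG]...[LANG] tags
    print(markup_language("你好。Hello.", ["zh", "en"]))
    # -> [ZH]你好。[ZH][EN]Hello.[EN]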
tools/translate.py ADDED
@@ -0,0 +1,61 @@
+"""
+Translation API (Baidu Fanyi)
+"""
+from config import config
+
+import random
+import hashlib
+import requests
+
+
+def translate(Sentence: str, to_Language: str = "jp", from_Language: str = ""):
+    """
+    :param Sentence: the sentence to translate
+    :param from_Language: language of the input sentence
+    :param to_Language: target language
+    :return: the translated sentence; the original sentence is returned if the request fails
+
+    Common language codes: Chinese zh, English en, Japanese jp
+    """
+    appid = config.translate_config.app_key
+    key = config.translate_config.secret_key
+    if appid == "" or key == "":
+        return "Please set app_key and secret_key in config.yml"
+    url = "https://fanyi-api.baidu.com/api/trans/vip/translate"
+    texts = Sentence.splitlines()
+    outTexts = []
+    for t in texts:
+        if t != "":
+            # Signature calculation, see https://api.fanyi.baidu.com/product/113
+            salt = str(random.randint(1, 100000))
+            signString = appid + t + salt + key
+            hs = hashlib.md5()
+            hs.update(signString.encode("utf-8"))
+            signString = hs.hexdigest()
+            if from_Language == "":
+                from_Language = "auto"
+            headers = {"Content-Type": "application/x-www-form-urlencoded"}
+            payload = {
+                "q": t,
+                "from": from_Language,
+                "to": to_Language,
+                "appid": appid,
+                "salt": salt,
+                "sign": signString,
+            }
+            # Send the request
+            try:
+                response = requests.post(
+                    url=url, data=payload, headers=headers, timeout=3
+                )
+                response = response.json()
+                if "trans_result" in response.keys():
+                    result = response["trans_result"][0]
+                    if "dst" in result.keys():
+                        dst = result["dst"]
+                        outTexts.append(dst)
+            except Exception:
+                return Sentence
+        else:
+            outTexts.append(t)
+    return "\n".join(outTexts)
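A minimal usage sketch for the function above, assuming app_key and secret_key are filled in under translate_config in config.yml (language codes follow the Baidu convention noted in the docstring):

    from tools.translate import translate

    print(translate("你好,世界", to_Language="jp"))  # Chinese -> Japanese
    print(translate("こんにちは", to_Language="en", from_Language="jp"))  # Japanese -> English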
utils.py CHANGED
@@ -13,7 +13,7 @@ from safetensors import safe_open
 from safetensors.torch import save_file
 from scipy.io.wavfile import read
 
-from tools.log import logger
+from common.log import logger
 
 MATPLOTLIB_FLAG = False
 
@@ -189,10 +189,11 @@ def summarize(
 
 
 def is_resuming(dir_path):
+    # The JP-Extra version has no DUR checkpoints (and adds WD, etc.), so resumability is judged from G only
     g_list = glob.glob(os.path.join(dir_path, "G_*.pth"))
-    d_list = glob.glob(os.path.join(dir_path, "D_*.pth"))
-    dur_list = glob.glob(os.path.join(dir_path, "DUR_*.pth"))
-    return len(g_list) > 0 and len(d_list) > 0 and len(dur_list) > 0
+    # d_list = glob.glob(os.path.join(dir_path, "D_*.pth"))
+    # dur_list = glob.glob(os.path.join(dir_path, "DUR_*.pth"))
+    return len(g_list) > 0
 
 
 def latest_checkpoint_path(dir_path, regex="G_*.pth"):
@@ -348,7 +349,7 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
     ]
 
     def del_info(fn):
-        return logger.info(f".. Free up space by deleting ckpt {fn}")
+        return logger.info(f"Free up space by deleting ckpt {fn}")
 
     def del_routine(x):
         return [os.remove(x), del_info(x)]
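A small sketch of what the relaxed is_resuming check means in practice; the directory name is hypothetical. A JP-Extra run whose folder holds only G_*.pth (and WD_*.pth) checkpoints is now treated as resumable, whereas the old check also required D_*.pth and DUR_*.pth:

    import utils
    # logs/MyModel/ contains G_1000.pth and WD_1000.pth but no D_*.pth or DUR_*.pth
    print(utils.is_resuming("logs/MyModel"))  # True with the new check, False with the old one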