Artrajz committed
Commit 14e19a5
1 Parent(s): f86cbba
.gitignore CHANGED
@@ -3,3 +3,7 @@
 /logs/
 /cache/
 /upload/
+ /vits/text/chinese_dialect_lexicons/
+ /bert_vits2/bert/chinese-roberta-wwm-ext-large/pytorch_model.bin
+ /bert_vits2/bert/bert-base-japanese-v3/pytorch_model.bin
+ /vits/bert/prosody_model.pt
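The newly ignored paths are pretrained weights that are fetched separately rather than committed. A hedged sketch of how the two BERT checkpoints could be pulled with huggingface_hub; the upstream repo ids (hfl/chinese-roberta-wwm-ext-large, cl-tohoku/bert-base-japanese-v3) are assumptions inferred from the directory names, and the prosody model and dialect lexicons come from their own projects:

```python
# Hypothetical download helper; the repo ids are assumptions, not specified by this commit.
from huggingface_hub import hf_hub_download

hf_hub_download(repo_id="hfl/chinese-roberta-wwm-ext-large", filename="pytorch_model.bin",
                local_dir="bert_vits2/bert/chinese-roberta-wwm-ext-large")
hf_hub_download(repo_id="cl-tohoku/bert-base-japanese-v3", filename="pytorch_model.bin",
                local_dir="bert_vits2/bert/bert-base-japanese-v3")
```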
README_zh.md CHANGED
@@ -5,7 +5,7 @@
 <br/>
 <p>
 <img src="https://img.shields.io/github/license/Artrajz/vits-simple-api">
- <img src="https://img.shields.io/badge/python-3.9%7C3.10-green">
+ <img src="https://img.shields.io/badge/python-3.10-green">
 <a href="https://hub.docker.com/r/artrajz/vits-simple-api">
 <img src="https://img.shields.io/docker/pulls/artrajz/vits-simple-api"></a>
 </p>
@@ -16,6 +16,7 @@
 
 
 
+
 # Feature
 
 - [x] VITS speech synthesis and voice conversion
@@ -48,7 +49,7 @@ https://user-images.githubusercontent.com/73542220/237995061-c1f25b4e-dd86-438a-
 
 # Deployment
 
- ## Docker deployment
+ ## Docker deployment (recommended for Linux)
 
 ### Image pull script
 
@@ -56,7 +57,7 @@ https://user-images.githubusercontent.com/73542220/237995061-c1f25b4e-dd86-438a-
 bash -c "$(wget -O- https://raw.githubusercontent.com/Artrajz/vits-simple-api/main/vits-simple-api-installer-latest.sh)"
 ```
 
- - Platforms currently supported by the docker image: `linux/amd64,linux/arm64`
+ - Platforms currently supported by the docker image: `linux/amd64,linux/arm64` (the arm64 image is CPU-only)
 - After the pull completes, a VITS model needs to be imported before the API can be used; follow the steps below to import one.
 
 ### Download VITS model
@@ -137,17 +138,15 @@ DIMENSIONAL_EMOTION_MODEL = ABS_PATH + "/Model/model.yaml"
 
 ### Download Python dependencies
 
- A Python virtual environment is recommended; Python version >= 3.9
+ A Python virtual environment is recommended
 
 `pip install -r requirements.txt`
 
 fasttext may fail to install on Windows; it can be installed with the following command (see the [wheels download page](https://www.lfd.uci.edu/~gohlke/pythonlibs/#fasttext))
 
 ```
- #python3.10 win_amd64
+ # python3.10 win_amd64
 pip install https://github.com/Artrajz/archived/raw/main/fasttext/fasttext-0.9.2-cp310-cp310-win_amd64.whl
- #python3.9 win_amd64
- pip install https://github.com/Artrajz/archived/raw/main/fasttext/fasttext-0.9.2-cp39-cp39-win_amd64.whl
 ```
 
 ### Download VITS model
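After installing the wheel, a quick sanity check that the package imports cleanly (nothing project-specific):

```python
# Minimal check that the fasttext wheel installed correctly on Windows.
import fasttext
print("fasttext imported from", fasttext.__file__)
```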
@@ -222,28 +221,24 @@ nvidia-smi
 
 ### Install the GPU version of PyTorch
 
- Install the PyTorch build for CUDA 11.7 with this command
+ Install the PyTorch build for CUDA 11.7 with this command; 1.13.1+cu117 is recommended, as other versions may have memory-stability issues.
 
 ```
- pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
+ pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
 ```
 
- The command for other versions can be found on the [official website](https://pytorch.org/get-started/locally/)
-
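Once installed, the CUDA build can be verified with a generic check:

```python
import torch

print(torch.__version__)          # expected to report something like 1.13.1+cu117
print(torch.cuda.is_available())  # True when the GPU build and driver are working
```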
 ## Linux
 
 The installation process is similar, but I don't have a matching environment so I haven't been able to test it
 
- # Openjtalk installation issues
+ # Dependency installation issues
 
- If you are on an arm64 platform, installation may run into problems because PyPI has no arm64 whl; you can install with the whl I built
+ Since pypi.org has no whl for pyopenjtalk, it normally has to be installed from source, which can be troublesome for some people, so you can also install it with the whl I built.
 
 ```
- pip install openjtalk==0.3.0.dev2 --index-url https://pypi.artrajz.cn/simple
+ pip install pyopenjtalk -i https://pypi.artrajz.cn/simple
 ```
 
- Alternatively, build a whl yourself by following this [tutorial](https://artrajz.cn/index.php/archives/167/)
-
 # API
 
 ## GET
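As a quick reference, GET usage follows the same parameters that api_test.py sends; the routes below mirror that script, so treat this as a sketch rather than the authoritative list (the sections further down document the full parameter set):

```python
import requests

base = "http://127.0.0.1:23456"

# Speaker mapping (api_test.py calls the same endpoint).
print(requests.get(f"{base}/voice/speakers").json())

# Simple synthesis via query parameters.
res = requests.get(f"{base}/voice/vits",
                   params={"text": "你好", "id": 0, "format": "wav", "lang": "auto"})
with open("output.wav", "wb") as f:
    f.write(res.content)
```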
@@ -274,190 +269,9 @@ pip install openjtalk==0.3.0.dev2 --index-url https://pypi.artrajz.cn/simple
 
 ## POST
 
277
- - python
278
-
279
- ```python
280
- import re
281
- import requests
282
- import os
283
- import random
284
- import string
285
- from requests_toolbelt.multipart.encoder import MultipartEncoder
286
-
287
- abs_path = os.path.dirname(__file__)
288
- base = "http://127.0.0.1:23456"
289
-
290
-
291
- # 映射表
292
- def voice_speakers():
293
- url = f"{base}/voice/speakers"
294
-
295
- res = requests.post(url=url)
296
- json = res.json()
297
- for i in json:
298
- print(i)
299
- for j in json[i]:
300
- print(j)
301
- return json
302
-
303
-
304
- # 语音合成 voice vits
305
- def voice_vits(text, id=0, format="wav", lang="auto", length=1, noise=0.667, noisew=0.8, max=50):
306
- fields = {
307
- "text": text,
308
- "id": str(id),
309
- "format": format,
310
- "lang": lang,
311
- "length": str(length),
312
- "noise": str(noise),
313
- "noisew": str(noisew),
314
- "max": str(max)
315
- }
316
- boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
317
-
318
- m = MultipartEncoder(fields=fields, boundary=boundary)
319
- headers = {"Content-Type": m.content_type}
320
- url = f"{base}/voice"
321
-
322
- res = requests.post(url=url, data=m, headers=headers)
323
- fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
324
- path = f"{abs_path}/{fname}"
325
-
326
- with open(path, "wb") as f:
327
- f.write(res.content)
328
- print(path)
329
- return path
330
-
331
-
332
- # 语音转换 hubert-vits
333
- def voice_hubert_vits(upload_path, id, format="wav", length=1, noise=0.667, noisew=0.8):
334
- upload_name = os.path.basename(upload_path)
335
- upload_type = f'audio/{upload_name.split(".")[1]}' # wav,ogg
336
-
337
- with open(upload_path, 'rb') as upload_file:
338
- fields = {
339
- "upload": (upload_name, upload_file, upload_type),
340
- "id": str(id),
341
- "format": format,
342
- "length": str(length),
343
- "noise": str(noise),
344
- "noisew": str(noisew),
345
- }
346
- boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
347
-
348
- m = MultipartEncoder(fields=fields, boundary=boundary)
349
- headers = {"Content-Type": m.content_type}
350
- url = f"{base}/voice/hubert-vits"
351
-
352
- res = requests.post(url=url, data=m, headers=headers)
353
- fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
354
- path = f"{abs_path}/{fname}"
355
-
356
- with open(path, "wb") as f:
357
- f.write(res.content)
358
- print(path)
359
- return path
360
-
361
-
362
- # 维度情感模型 w2v2-vits
363
- def voice_w2v2_vits(text, id=0, format="wav", lang="auto", length=1, noise=0.667, noisew=0.8, max=50, emotion=0):
364
- fields = {
365
- "text": text,
366
- "id": str(id),
367
- "format": format,
368
- "lang": lang,
369
- "length": str(length),
370
- "noise": str(noise),
371
- "noisew": str(noisew),
372
- "max": str(max),
373
- "emotion": str(emotion)
374
- }
375
- boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
376
-
377
- m = MultipartEncoder(fields=fields, boundary=boundary)
378
- headers = {"Content-Type": m.content_type}
379
- url = f"{base}/voice/w2v2-vits"
380
-
381
- res = requests.post(url=url, data=m, headers=headers)
382
- fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
383
- path = f"{abs_path}/{fname}"
384
-
385
- with open(path, "wb") as f:
386
- f.write(res.content)
387
- print(path)
388
- return path
389
-
390
-
391
- # 语音转换 同VITS模型内角色之间的音色转换
392
- def voice_conversion(upload_path, original_id, target_id):
393
- upload_name = os.path.basename(upload_path)
394
- upload_type = f'audio/{upload_name.split(".")[1]}' # wav,ogg
395
-
396
- with open(upload_path, 'rb') as upload_file:
397
- fields = {
398
- "upload": (upload_name, upload_file, upload_type),
399
- "original_id": str(original_id),
400
- "target_id": str(target_id),
401
- }
402
- boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
403
- m = MultipartEncoder(fields=fields, boundary=boundary)
404
-
405
- headers = {"Content-Type": m.content_type}
406
- url = f"{base}/voice/conversion"
407
-
408
- res = requests.post(url=url, data=m, headers=headers)
409
-
410
- fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
411
- path = f"{abs_path}/{fname}"
412
-
413
- with open(path, "wb") as f:
414
- f.write(res.content)
415
- print(path)
416
- return path
417
-
418
-
419
- def voice_ssml(ssml):
420
- fields = {
421
- "ssml": ssml,
422
- }
423
- boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
424
-
425
- m = MultipartEncoder(fields=fields, boundary=boundary)
426
- headers = {"Content-Type": m.content_type}
427
- url = f"{base}/voice/ssml"
428
-
429
- res = requests.post(url=url, data=m, headers=headers)
430
- fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
431
- path = f"{abs_path}/{fname}"
432
-
433
- with open(path, "wb") as f:
434
- f.write(res.content)
435
- print(path)
436
- return path
437
-
438
- def voice_dimensional_emotion(upload_path):
439
- upload_name = os.path.basename(upload_path)
440
- upload_type = f'audio/{upload_name.split(".")[1]}' # wav,ogg
441
-
442
- with open(upload_path, 'rb') as upload_file:
443
- fields = {
444
- "upload": (upload_name, upload_file, upload_type),
445
- }
446
- boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
447
-
448
- m = MultipartEncoder(fields=fields, boundary=boundary)
449
- headers = {"Content-Type": m.content_type}
450
- url = f"{base}/voice/dimension-emotion"
451
-
452
- res = requests.post(url=url, data=m, headers=headers)
453
- fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
454
- path = f"{abs_path}/{fname}"
455
-
456
- with open(path, "wb") as f:
457
- f.write(res.content)
458
- print(path)
459
- return path
460
- ```
+ - See `api_test.py`
+
+
 
 ## API KEY
 
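A minimal POST sketch distilled from api_test.py (multipart form data, as that script uses):

```python
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder

base = "http://127.0.0.1:23456"
fields = {"text": "你好", "id": "0", "format": "wav", "lang": "auto"}
m = MultipartEncoder(fields=fields)
res = requests.post(f"{base}/voice/vits", data=m, headers={"Content-Type": m.content_type})
with open("output.wav", "wb") as f:
    f.write(res.content)
```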
@@ -615,4 +429,5 @@ def voice_dimensional_emotion(upload_path):
 - emotional-vits:https://github.com/innnky/emotional-vits
 - vits-uma-genshin-honkai:https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai
 - vits_chinese:https://github.com/PlayVoice/vits_chinese
+ - Bert_VITS2:https://github.com/fishaudio/Bert-VITS2
 
api_test.py ADDED
@@ -0,0 +1,429 @@
1
+ import json
2
+ import re
3
+ import requests
4
+ import os
5
+ import time
6
+ import random
7
+ import string
8
+ from requests_toolbelt.multipart.encoder import MultipartEncoder
9
+
10
+ absolute_path = os.path.dirname(__file__)
11
+ base_url = "http://127.0.0.1:23456"
12
+
13
+
14
+ # 映射表
15
+ def voice_speakers():
16
+ url = f"{base_url}/voice/speakers"
17
+
18
+ res = requests.post(url=url)
19
+ json = res.json()
20
+ for i in json:
21
+ print(i)
22
+ for j in json[i]:
23
+ print(j)
24
+ return json
25
+
26
+
27
+ # 语音合成 voice vits
28
+ def voice_vits(text, id=0, format="wav", lang="auto", length=1, noise=0.667, noisew=0.8, max=50, save_audio=True,
29
+ save_path=None):
30
+ fields = {
31
+ "text": text,
32
+ "id": str(id),
33
+ "format": format,
34
+ "lang": lang,
35
+ "length": str(length),
36
+ "noise": str(noise),
37
+ "noisew": str(noisew),
38
+ "max": str(max)
39
+ }
40
+ boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
41
+
42
+ m = MultipartEncoder(fields=fields, boundary=boundary)
43
+ headers = {"Content-Type": m.content_type}
44
+ url = f"{base_url}/voice/vits"
45
+
46
+ res = requests.post(url=url, data=m, headers=headers)
47
+ fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
48
+ if save_path is not None:
49
+ path = os.path.join(save_path, fname)
50
+ else:
51
+ path = os.path.join(absolute_path, fname)
52
+ if save_audio:
53
+ with open(path, "wb") as f:
54
+ f.write(res.content)
55
+ print(path)
56
+ return path
57
+ return None
58
+
59
+
60
+ def voice_vits_streaming(text, id=0, format="wav", lang="auto", length=1, noise=0.667, noisew=0.8, max=50,
61
+ save_audio=True, save_path=None):
62
+ fields = {
63
+ "text": text,
64
+ "id": str(id),
65
+ "format": format,
66
+ "lang": lang,
67
+ "length": str(length),
68
+ "noise": str(noise),
69
+ "noisew": str(noisew),
70
+ "max": str(max),
71
+ "streaming": 'True'
72
+ }
73
+ boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
74
+
75
+ m = MultipartEncoder(fields=fields, boundary=boundary)
76
+ headers = {"Content-Type": m.content_type}
77
+ url = f"{base_url}/voice"
78
+
79
+ res = requests.post(url=url, data=m, headers=headers)
80
+ fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
81
+ if save_path is not None:
82
+ path = os.path.join(save_path, fname)
83
+ else:
84
+ path = os.path.join(absolute_path, fname)
85
+ if save_audio:
86
+ with open(path, "wb") as f:
87
+ f.write(res.content)
88
+ print(path)
89
+ return path
90
+ return None
91
+
92
+
93
+ def voice_vits_streaming(text, id=0, format="wav", lang="auto", length=1, noise=0.667, noisew=0.8, max=50,
94
+ save_path=None):
95
+ fields = {
96
+ "text": text,
97
+ "id": str(id),
98
+ "format": format,
99
+ "lang": lang,
100
+ "length": str(length),
101
+ "noise": str(noise),
102
+ "noisew": str(noisew),
103
+ "max": str(max),
104
+ "streaming": 'True'
105
+ }
106
+ boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
107
+
108
+ m = MultipartEncoder(fields=fields, boundary=boundary)
109
+ headers = {"Content-Type": m.content_type}
110
+ url = f"{base_url}/voice"
111
+
112
+ res = requests.post(url=url, data=m, headers=headers, stream=True)
113
+ fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
114
+ if save_path is not None:
115
+ path = os.path.join(save_path, fname)
116
+ else:
117
+ path = os.path.join(absolute_path, fname)
118
+ audio = res.content
119
+
120
+ def get_file_size_from_bytes(byte_data):
121
+ file_size_offset = 4
122
+ file_size_length = 4
123
+
124
+ try:
125
+ file_size_bytes = byte_data[file_size_offset:file_size_offset + file_size_length]
126
+ file_size = int.from_bytes(file_size_bytes, byteorder='little')
127
+ return file_size + 8
128
+ except IndexError:
129
+ return None
130
+
131
+ audio = None
132
+ p = 0
133
+ audio_size = None
134
+ audios = []
135
+
136
+ for chunk in res.iter_content(chunk_size=1024):
137
+ if audio is None:
138
+ audio = chunk
139
+ else:
140
+ audio += chunk
141
+
142
+ p += len(chunk)
143
+ if audio_size is not None:
144
+ if p >= audio_size:
145
+ p = p - audio_size
146
+ audios.append(audio[:audio_size])
147
+ audio = audio[audio_size:]
148
+ audio_size = get_file_size_from_bytes(audio)
149
+ else:
150
+ audio_size = get_file_size_from_bytes(audio)
151
+ for i, audio in enumerate(audios):
152
+ with open(f"{path[:-4]}-{i}.wav", "wb") as f:
153
+ f.write(audio)
154
+
155
+ print(f"{path[:-4]}-{i}.wav")
156
+ return path
157
+
158
+
159
+ # 语音转换 hubert-vits
160
+ def voice_hubert_vits(upload_path, id, format="wav", length=1, noise=0.667, noisew=0.8, save_audio=True,
161
+ save_path=None):
162
+ upload_name = os.path.basename(upload_path)
163
+ upload_type = f'audio/{upload_name.split(".")[1]}' # wav,ogg
164
+
165
+ with open(upload_path, 'rb') as upload_file:
166
+ fields = {
167
+ "upload": (upload_name, upload_file, upload_type),
168
+ "id": str(id),
169
+ "format": format,
170
+ "length": str(length),
171
+ "noise": str(noise),
172
+ "noisew": str(noisew),
173
+ }
174
+ boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
175
+
176
+ m = MultipartEncoder(fields=fields, boundary=boundary)
177
+ headers = {"Content-Type": m.content_type}
178
+ url = f"{base_url}/voice/hubert-vits"
179
+
180
+ res = requests.post(url=url, data=m, headers=headers)
181
+ fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
182
+ if save_path is not None:
183
+ path = os.path.join(save_path, fname)
184
+ else:
185
+ path = os.path.join(absolute_path, fname)
186
+ if save_audio:
187
+ with open(path, "wb") as f:
188
+ f.write(res.content)
189
+ print(path)
190
+ return path
191
+ return None
192
+
193
+
194
+ # 维度情感模型 w2v2-vits
195
+ def voice_w2v2_vits(text, id=0, format="wav", lang="auto", length=1, noise=0.667, noisew=0.8, max=50, emotion=0,
196
+ save_audio=True, save_path=None):
197
+ fields = {
198
+ "text": text,
199
+ "id": str(id),
200
+ "format": format,
201
+ "lang": lang,
202
+ "length": str(length),
203
+ "noise": str(noise),
204
+ "noisew": str(noisew),
205
+ "max": str(max),
206
+ "emotion": str(emotion)
207
+ }
208
+ boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
209
+
210
+ m = MultipartEncoder(fields=fields, boundary=boundary)
211
+ headers = {"Content-Type": m.content_type}
212
+ url = f"{base_url}/voice/w2v2-vits"
213
+
214
+ res = requests.post(url=url, data=m, headers=headers)
215
+ fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
216
+ if save_path is not None:
217
+ path = os.path.join(save_path, fname)
218
+ else:
219
+ path = os.path.join(absolute_path, fname)
220
+ if save_audio:
221
+ with open(path, "wb") as f:
222
+ f.write(res.content)
223
+ print(path)
224
+ return path
225
+ return None
226
+
227
+
228
+ # 语音转换 同VITS模型内角色之间的音色转换
229
+ def voice_conversion(upload_path, original_id, target_id, save_audio=True, save_path=None):
230
+ upload_name = os.path.basename(upload_path)
231
+ upload_type = f'audio/{upload_name.split(".")[1]}' # wav,ogg
232
+
233
+ with open(upload_path, 'rb') as upload_file:
234
+ fields = {
235
+ "upload": (upload_name, upload_file, upload_type),
236
+ "original_id": str(original_id),
237
+ "target_id": str(target_id),
238
+ }
239
+ boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
240
+ m = MultipartEncoder(fields=fields, boundary=boundary)
241
+
242
+ headers = {"Content-Type": m.content_type}
243
+ url = f"{base_url}/voice/conversion"
244
+
245
+ res = requests.post(url=url, data=m, headers=headers)
246
+
247
+ fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
248
+ if save_path is not None:
249
+ path = os.path.join(save_path, fname)
250
+ else:
251
+ path = os.path.join(absolute_path, fname)
252
+
253
+ if save_audio:
254
+ with open(path, "wb") as f:
255
+ f.write(res.content)
256
+ print(path)
257
+ return path
258
+ return None
259
+
260
+
261
+ def voice_ssml(ssml, save_audio=True, save_path=None):
262
+ fields = {
263
+ "ssml": ssml,
264
+ }
265
+ boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
266
+
267
+ m = MultipartEncoder(fields=fields, boundary=boundary)
268
+ headers = {"Content-Type": m.content_type}
269
+ url = f"{base_url}/voice/ssml"
270
+
271
+ res = requests.post(url=url, data=m, headers=headers)
272
+ fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
273
+ if save_path is not None:
274
+ path = os.path.join(save_path, fname)
275
+ else:
276
+ path = os.path.join(absolute_path, fname)
277
+
278
+ if save_audio:
279
+ with open(path, "wb") as f:
280
+ f.write(res.content)
281
+ print(path)
282
+ return path
283
+ return None
284
+
285
+
286
+ def voice_dimensional_emotion(upload_path, save_audio=True,
287
+ save_path=None):
288
+ upload_name = os.path.basename(upload_path)
289
+ upload_type = f'audio/{upload_name.split(".")[1]}' # wav,ogg
290
+
291
+ with open(upload_path, 'rb') as upload_file:
292
+ fields = {
293
+ "upload": (upload_name, upload_file, upload_type),
294
+ }
295
+ boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
296
+
297
+ m = MultipartEncoder(fields=fields, boundary=boundary)
298
+ headers = {"Content-Type": m.content_type}
299
+ url = f"{base_url}/voice/dimension-emotion"
300
+
301
+ res = requests.post(url=url, data=m, headers=headers)
302
+ fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
303
+ if save_path is not None:
304
+ path = os.path.join(save_path, fname)
305
+ else:
306
+ path = os.path.join(absolute_path, fname)
307
+ if save_audio:
308
+ with open(path, "wb") as f:
309
+ f.write(res.content)
310
+ print(path)
311
+ return path
312
+ return None
313
+
314
+
315
+ def vits_json(text, id=0, format="wav", lang="auto", length=1, noise=0.667, noisew=0.8, max=50,
316
+ save_path=None):
317
+ fields = {
318
+ "text": text,
319
+ "id": str(id),
320
+ "format": format,
321
+ "lang": lang,
322
+ "length": str(length),
323
+ "noise": str(noise),
324
+ "noisew": str(noisew),
325
+ "max": str(max)
326
+ }
327
+ f = json.dumps(fields)
328
+ url = f"{base_url}/voice"
329
+ header = {"Content-Type": 'application/json'}
330
+ res = requests.post(url=url, data=f, headers=header)
331
+
332
+ fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
333
+ if save_path is not None:
334
+ path = os.path.join(save_path, fname)
335
+ else:
336
+ path = os.path.join(absolute_path, fname)
337
+
338
+ with open(path, "wb") as f:
339
+ f.write(res.content)
340
+ print(path)
341
+ return path
342
+
343
+
344
+ # Bert_vits2
345
+ def voice_bert_vits2(text, id=0, format="wav", lang="auto", length=1, noise=0.667, noisew=0.8, max=50, sdp_ratio=0.2,
346
+ save_audio=True, save_path=None):
347
+ fields = {
348
+ "text": text,
349
+ "id": str(id),
350
+ "format": format,
351
+ "lang": lang,
352
+ "length": str(length),
353
+ "noise": str(noise),
354
+ "noisew": str(noisew),
355
+ "max": str(max),
356
+ "sdp_ratio": str(sdp_ratio)
357
+ }
358
+ boundary = '----VoiceConversionFormBoundary' + ''.join(random.sample(string.ascii_letters + string.digits, 16))
359
+
360
+ m = MultipartEncoder(fields=fields, boundary=boundary)
361
+ headers = {"Content-Type": m.content_type}
362
+ url = f"{base_url}/voice/bert-vits2"
363
+
364
+ res = requests.post(url=url, data=m, headers=headers)
365
+ fname = re.findall("filename=(.+)", res.headers["Content-Disposition"])[0]
366
+ if save_path is not None:
367
+ path = os.path.join(save_path, fname)
368
+ else:
369
+ path = os.path.join(absolute_path, fname)
370
+ if save_audio:
371
+ with open(path, "wb") as f:
372
+ f.write(res.content)
373
+ print(path)
374
+ return path
375
+ return None
376
+
377
+
378
+ def test_interface(text):
379
+ error_num = 0
380
+ for i in range(100):
381
+ try:
382
+ time.sleep(1)
383
+ t1 = time.time()
384
+ voice_vits(text, format="wav", lang="zh", save_audio=False)
385
+ t2 = time.time()
386
+ print(f"{i}:len:{len(text)}耗时:{t2 - t1}")
387
+ except Exception as e:
388
+ error_num += 1
389
+ print(e)
390
+ print(f"error_num={error_num}")
391
+
392
+
393
+ if __name__ == '__main__':
394
+ text = "你好,こんにちは"
395
+
396
+ ssml = """
397
+ <speak lang="zh" format="mp3" length="1.2">
398
+ <voice id="92" >这几天心里颇不宁静。</voice>
399
+ <voice id="125">今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。</voice>
400
+ <voice id="142">月亮渐渐地升高了,墙外马路上孩子们的欢笑,已经听不见了;</voice>
401
+ <voice id="98">妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。</voice>
402
+ <voice id="120">我悄悄地披了大衫,带上门出去。</voice><break time="2s"/>
403
+ <voice id="121">沿着荷塘,是一条曲折的小煤屑路。</voice>
404
+ <voice id="122">这是一条幽僻的路;白天也少人走,夜晚更加寂寞。</voice>
405
+ <voice id="123">荷塘四面,长着许多树,蓊蓊郁郁的。</voice>
406
+ <voice id="124">路的一旁,是些杨柳,和一些不知道名字的树。</voice>
407
+ <voice id="125">没有月光的晚上,这路上阴森森的,有些怕人。</voice>
408
+ <voice id="126">今晚却很好,虽然月光也还是淡淡的。</voice><break time="2s"/>
409
+ <voice id="127">路上只我一个人,背着手踱着。</voice>
410
+ <voice id="128">这一片天地好像是我的;我也像超出了平常的自己,到了另一个世界里。</voice>
411
+ <voice id="129">我爱热闹,也爱冷静;<break strength="x-weak"/>爱群居,也爱独处。</voice>
412
+ <voice id="130">像今晚上,一个人在这苍茫的月下,什么都可以想,什么都可以不想,便觉是个自由的人。</voice>
413
+ <voice id="131">白天里一定要做的事,一定要说的话,现在都可不理。</voice>
414
+ <voice id="132">这是独处的妙处,我且受用这无边的荷香月色好了。</voice>
415
+ </speak>
416
+ """
417
+
418
+ from config import CACHE_PATH
419
+
420
+ path = voice_vits(text, save_path=CACHE_PATH)
421
+ voice_vits_streaming(text, save_path=CACHE_PATH)
422
+ voice_w2v2_vits(text, save_path=CACHE_PATH)
423
+ voice_conversion(path, 1, 3, save_path=CACHE_PATH)
424
+ voice_hubert_vits(path, 0, save_path=CACHE_PATH)
425
+ voice_dimensional_emotion(path, save_path=CACHE_PATH)
426
+ voice_ssml(ssml, save_path=CACHE_PATH)
427
+ voice_bert_vits2("你好",lang="zh", save_path=CACHE_PATH)
428
+ voice_bert_vits2("こんにちは", lang="ja", save_path=CACHE_PATH)
429
+ # os.system(path)
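A note on the streaming helper above: `get_file_size_from_bytes()` reads bytes 4-8 of the buffered data, which in a RIFF/WAV stream hold the little-endian ChunkSize field (total file size minus 8), so adding 8 recovers the size of each WAV segment and lets the loop split concatenated segments apart. The same idea as a standalone sketch:

```python
import struct

def riff_total_size(header: bytes) -> int:
    """Total WAV file size implied by a RIFF header: the ChunkSize field plus 8 bytes."""
    # Bytes 0-3: b"RIFF", bytes 4-7: little-endian ChunkSize, bytes 8-11: b"WAVE".
    chunk_size = struct.unpack_from("<I", header, 4)[0]
    return chunk_size + 8
```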
app.py CHANGED
@@ -6,8 +6,8 @@ from flask import Flask, request, send_file, jsonify, make_response, render_temp
6
  from werkzeug.utils import secure_filename
7
  from flask_apscheduler import APScheduler
8
  from functools import wraps
9
- from utils.utils import clean_folder, check_is_none
10
- from utils.merge import merge_model
11
  from io import BytesIO
12
 
13
  app = Flask(__name__)
@@ -25,7 +25,7 @@ for path in (app.config['LOGS_PATH'], app.config['UPLOAD_FOLDER'], app.config['C
25
  logger.error(f"Unable to create directory {path}: {str(e)}")
26
 
27
  # load model
28
- tts = merge_model(app.config["MODEL_LIST"])
29
 
30
 
31
  def require_api_key(func):
@@ -67,30 +67,23 @@ def voice_speakers_api():
67
  def voice_vits_api():
68
  try:
69
  if request.method == "GET":
70
- text = request.args.get("text", "")
71
- id = int(request.args.get("id", app.config.get("ID", 0)))
72
- format = request.args.get("format", app.config.get("FORMAT", "wav"))
73
- lang = request.args.get("lang", app.config.get("LANG", "auto"))
74
- length = float(request.args.get("length", app.config.get("LENGTH", 1)))
75
- noise = float(request.args.get("noise", app.config.get("NOISE", 0.667)))
76
- noisew = float(request.args.get("noisew", app.config.get("NOISEW", 0.8)))
77
- max = int(request.args.get("max", app.config.get("MAX", 50)))
78
- use_streaming = request.args.get('streaming', False, type=bool)
79
  elif request.method == "POST":
80
  content_type = request.headers.get('Content-Type')
81
  if content_type == 'application/json':
82
- data = request.get_json()
83
  else:
84
- data = request.form
85
- text = data.get("text", "")
86
- id = int(data.get("id", app.config.get("ID", 0)))
87
- format = data.get("format", app.config.get("FORMAT", "wav"))
88
- lang = data.get("lang", app.config.get("LANG", "auto"))
89
- length = float(data.get("length", app.config.get("LENGTH", 1)))
90
- noise = float(data.get("noise", app.config.get("NOISE", 0.667)))
91
- noisew = float(data.get("noisew", app.config.get("NOISEW", 0.8)))
92
- max = int(data.get("max", app.config.get("MAX", 50)))
93
- use_streaming = request.form.get('streaming', False, type=bool)
 
94
  except Exception as e:
95
  logger.error(f"[VITS] {e}")
96
  return make_response("parameter error", 400)
@@ -112,7 +105,7 @@ def voice_vits_api():
112
 
113
  # 校验模型是否支持输入的语言
114
  speaker_lang = tts.voice_speakers["VITS"][id].get('lang')
115
- if lang.upper() != "AUTO" and lang.upper() != "MIX" and len(speaker_lang) != 1 and lang not in speaker_lang:
116
  logger.info(f"[VITS] lang \"{lang}\" is not in {speaker_lang}")
117
  return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)
118
 
@@ -136,20 +129,23 @@ def voice_vits_api():
136
  "lang": lang,
137
  "speaker_lang": speaker_lang}
138
 
139
- if app.config.get("SAVE_AUDIO", False):
140
- logger.debug(f"[VITS] {fname}")
141
-
142
  if use_streaming:
143
- audio = tts.stream_vits_infer(task, fname)
144
  response = make_response(audio)
145
  response.headers['Content-Disposition'] = f'attachment; filename={fname}'
146
  response.headers['Content-Type'] = file_type
147
  return response
148
  else:
149
  t1 = time.time()
150
- audio = tts.vits_infer(task, fname)
151
  t2 = time.time()
152
  logger.info(f"[VITS] finish in {(t2 - t1):.2f}s")
 
 
 
 
 
 
153
  return send_file(path_or_file=audio, mimetype=file_type, download_name=fname)
154
 
155
 
@@ -191,11 +187,15 @@ def voice_hubert_api():
191
  "audio_path": os.path.join(app.config['UPLOAD_FOLDER'], fname)}
192
 
193
  t1 = time.time()
194
- audio = tts.hubert_vits_infer(task, fname)
195
  t2 = time.time()
 
 
196
  if app.config.get("SAVE_AUDIO", False):
197
  logger.debug(f"[hubert] {fname}")
198
- logger.info(f"[hubert] finish in {(t2 - t1):.2f}s")
 
 
199
  if use_streaming:
200
  audio = tts.generate_audio_chunks(audio)
201
  response = make_response(audio)
@@ -211,32 +211,24 @@ def voice_hubert_api():
211
  def voice_w2v2_api():
212
  try:
213
  if request.method == "GET":
214
- text = request.args.get("text", "")
215
- id = int(request.args.get("id", app.config.get("ID", 0)))
216
- format = request.args.get("format", app.config.get("FORMAT", "wav"))
217
- lang = request.args.get("lang", app.config.get("LANG", "auto"))
218
- length = float(request.args.get("length", app.config.get("LENGTH", 1)))
219
- noise = float(request.args.get("noise", app.config.get("NOISE", 0.667)))
220
- noisew = float(request.args.get("noisew", app.config.get("NOISEW", 0.8)))
221
- max = int(request.args.get("max", app.config.get("MAX", 50)))
222
- emotion = int(request.args.get("emotion", app.config.get("EMOTION", 0)))
223
- use_streaming = request.args.get('streaming', False, type=bool)
224
  elif request.method == "POST":
225
  content_type = request.headers.get('Content-Type')
226
  if content_type == 'application/json':
227
- data = request.get_json()
228
  else:
229
- data = request.form
230
- text = data.get("text", "")
231
- id = int(data.get("id", app.config.get("ID", 0)))
232
- format = data.get("format", app.config.get("FORMAT", "wav"))
233
- lang = data.get("lang", app.config.get("LANG", "auto"))
234
- length = float(data.get("length"))
235
- noise = float(data.get("noise", app.config.get("NOISE", 0.667)))
236
- noisew = float(data.get("noisew", app.config.get("NOISEW", 0.8)))
237
- max = int(data.get("max", app.config.get("MAX", 50)))
238
- emotion = int(data.get("emotion", app.config.get("EMOTION", 0)))
239
- use_streaming = request.form.get('streaming', False, type=bool)
 
240
  except Exception as e:
241
  logger.error(f"[w2v2] {e}")
242
  return make_response(f"parameter error", 400)
@@ -259,7 +251,7 @@ def voice_w2v2_api():
259
 
260
  # 校验模型是否支持输入的语言
261
  speaker_lang = tts.voice_speakers["W2V2-VITS"][id].get('lang')
262
- if lang.upper() != "AUTO" and lang.upper() != "MIX" and len(speaker_lang) != 1 and lang not in speaker_lang:
263
  logger.info(f"[w2v2] lang \"{lang}\" is not in {speaker_lang}")
264
  return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)
265
 
@@ -285,10 +277,15 @@ def voice_w2v2_api():
285
  "speaker_lang": speaker_lang}
286
 
287
  t1 = time.time()
288
- audio = tts.w2v2_vits_infer(task, fname)
289
  t2 = time.time()
 
 
290
  if app.config.get("SAVE_AUDIO", False):
291
- logger.debug(f"[W2V2] {fname}")
 
 
 
292
  if use_streaming:
293
  audio = tts.generate_audio_chunks(audio)
294
  response = make_response(audio)
@@ -296,7 +293,6 @@ def voice_w2v2_api():
296
  response.headers['Content-Type'] = file_type
297
  return response
298
  else:
299
- logger.info(f"[w2v2] finish in {(t2 - t1):.2f}s")
300
  return send_file(path_or_file=audio, mimetype=file_type, download_name=fname)
301
 
302
 
@@ -326,11 +322,15 @@ def vits_voice_conversion_api():
326
  "format": format}
327
 
328
  t1 = time.time()
329
- audio = tts.vits_voice_conversion(task, fname)
330
  t2 = time.time()
 
 
331
  if app.config.get("SAVE_AUDIO", False):
332
  logger.debug(f"[Voice conversion] {fname}")
333
- logger.info(f"[Voice conversion] finish in {(t2 - t1):.2f}s")
 
 
334
  if use_streaming:
335
  audio = tts.generate_audio_chunks(audio)
336
  response = make_response(audio)
@@ -343,14 +343,15 @@ def vits_voice_conversion_api():
343
 
344
  @app.route('/voice/ssml', methods=["POST"])
345
  @require_api_key
346
- def ssml():
347
  try:
348
  content_type = request.headers.get('Content-Type')
349
  if content_type == 'application/json':
350
- data = request.get_json()
351
  else:
352
- data = request.form
353
- ssml = data.get("ssml")
 
354
  except Exception as e:
355
  logger.info(f"[ssml] {e}")
356
  return make_response(jsonify({"status": "error", "message": f"parameter error"}), 400)
@@ -361,11 +362,14 @@ def ssml():
361
  file_type = f"audio/{format}"
362
 
363
  t1 = time.time()
364
- audio = tts.create_ssml_infer_task(voice_tasks, format, fname)
365
  t2 = time.time()
 
 
366
  if app.config.get("SAVE_AUDIO", False):
367
  logger.debug(f"[ssml] {fname}")
368
- logger.info(f"[ssml] finish in {(t2 - t1):.2f}s")
 
369
 
370
  return send_file(path_or_file=audio, mimetype=file_type, download_name=fname)
371
 
@@ -385,15 +389,15 @@ def dimensional_emotion():
385
 
386
  file_type = "application/octet-stream; charset=ascii"
387
  fname = os.path.splitext(audio.filename)[0] + ".npy"
388
- audio = tts.get_dimensional_emotion_npy(content)
389
  if use_streaming:
390
- audio = tts.generate_audio_chunks(audio)
391
- response = make_response(audio)
392
  response.headers['Content-Disposition'] = f'attachment; filename={fname}'
393
  response.headers['Content-Type'] = file_type
394
  return response
395
  else:
396
- return send_file(path_or_file=audio, mimetype=file_type, download_name=fname)
397
 
398
 
399
  @app.route('/voice/bert-vits2', methods=["GET", "POST"])
@@ -401,37 +405,29 @@ def dimensional_emotion():
401
  def voice_bert_vits2_api():
402
  try:
403
  if request.method == "GET":
404
- text = request.args.get("text", "")
405
- id = int(request.args.get("id", app.config.get("ID", 0)))
406
- format = request.args.get("format", app.config.get("FORMAT", "wav"))
407
- # lang = request.args.get("lang", app.config.get("LANG", "auto"))
408
- lang = "ZH"
409
- length = float(request.args.get("length", app.config.get("LENGTH", 1)))
410
- noise = float(request.args.get("noise", app.config.get("NOISE", 0.5)))
411
- noisew = float(request.args.get("noisew", app.config.get("NOISEW", 0.6)))
412
- sdp_ratio = float(request.args.get("sdp_ratio", 0.2))
413
- max = int(request.args.get("max", app.config.get("MAX", 50)))
414
  elif request.method == "POST":
415
  content_type = request.headers.get('Content-Type')
416
  if content_type == 'application/json':
417
- data = request.get_json()
418
  else:
419
- data = request.form
420
- text = data.get("text", "")
421
- id = int(data.get("id", app.config.get("ID", 0)))
422
- format = data.get("format", app.config.get("FORMAT", "wav"))
423
- # lang = data.get("lang", app.config.get("LANG", "auto"))
424
- lang = "ZH"
425
- length = float(data.get("length", app.config.get("LENGTH", 1)))
426
- noise = float(data.get("noise", app.config.get("NOISE", 0.667)))
427
- noisew = float(data.get("noisew", app.config.get("NOISEW", 0.8)))
428
- sdp_ratio = float(data.get("noisew", app.config.get("SDP_RATIO", 0.2)))
429
- max = int(data.get("max", app.config.get("MAX", 50)))
430
  except Exception as e:
431
  logger.error(f"[Bert-VITS2] {e}")
432
  return make_response("parameter error", 400)
433
 
434
- logger.info(f"[Bert-VITS2] id:{id} format:{format} lang:{lang} length:{length} noise:{noise} noisew:{noisew} sdp_ratio:{sdp_ratio}")
 
435
  logger.info(f"[Bert-VITS2] len:{len(text)} text:{text}")
436
 
437
  if check_is_none(text):
@@ -448,7 +444,7 @@ def voice_bert_vits2_api():
448
 
449
  # 校验模型是否支持输入的语言
450
  speaker_lang = tts.voice_speakers["BERT-VITS2"][id].get('lang')
451
- if lang.upper() != "AUTO" and lang.upper() != "MIX" and len(speaker_lang) != 1 and lang not in speaker_lang:
452
  logger.info(f"[Bert-VITS2] lang \"{lang}\" is not in {speaker_lang}")
453
  return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)
454
 
@@ -469,14 +465,16 @@ def voice_bert_vits2_api():
469
  "lang": lang,
470
  "speaker_lang": speaker_lang}
471
 
472
- if app.config.get("SAVE_AUDIO", False):
473
- logger.debug(f"[Bert-VITS2] {fname}")
474
-
475
-
476
  t1 = time.time()
477
- audio = tts.bert_vits2_infer(task, fname)
478
  t2 = time.time()
479
  logger.info(f"[Bert-VITS2] finish in {(t2 - t1):.2f}s")
 
 
 
 
 
 
480
  return send_file(path_or_file=audio, mimetype=file_type, download_name=fname)
481
 
482
 
@@ -484,16 +482,16 @@ def voice_bert_vits2_api():
484
  def check():
485
  try:
486
  if request.method == "GET":
487
- model = request.args.get("model")
488
- id = int(request.args.get("id"))
489
  elif request.method == "POST":
490
  content_type = request.headers.get('Content-Type')
491
  if content_type == 'application/json':
492
- data = request.get_json()
493
  else:
494
- data = request.form
495
- model = data.get("model")
496
- id = int(data.get("id"))
 
497
  except Exception as e:
498
  logger.info(f"[check] {e}")
499
  return make_response(jsonify({"status": "error", "message": "parameter error"}), 400)
 
6
  from werkzeug.utils import secure_filename
7
  from flask_apscheduler import APScheduler
8
  from functools import wraps
9
+ from utils.data_utils import save_audio, clean_folder, check_is_none
10
+ from utils.load_model import load_model
11
  from io import BytesIO
12
 
13
  app = Flask(__name__)
 
25
  logger.error(f"Unable to create directory {path}: {str(e)}")
26
 
27
  # load model
28
+ tts = load_model(app.config["MODEL_LIST"])
29
 
30
 
31
  def require_api_key(func):
 
67
  def voice_vits_api():
68
  try:
69
  if request.method == "GET":
70
+ request_data = request.args
 
 
 
 
 
 
 
 
71
  elif request.method == "POST":
72
  content_type = request.headers.get('Content-Type')
73
  if content_type == 'application/json':
74
+ request_data = request.get_json()
75
  else:
76
+ request_data = request.form
77
+
78
+ text = request_data.get("text", "")
79
+ id = int(request_data.get("id", app.config.get("ID", 0)))
80
+ format = request_data.get("format", app.config.get("FORMAT", "wav"))
81
+ lang = request_data.get("lang", app.config.get("LANG", "auto")).lower()
82
+ length = float(request_data.get("length", app.config.get("LENGTH", 1)))
83
+ noise = float(request_data.get("noise", app.config.get("NOISE", 0.667)))
84
+ noisew = float(request_data.get("noisew", app.config.get("NOISEW", 0.8)))
85
+ max = int(request_data.get("max", app.config.get("MAX", 50)))
86
+ use_streaming = request_data.get('streaming', False, type=bool)
87
  except Exception as e:
88
  logger.error(f"[VITS] {e}")
89
  return make_response("parameter error", 400)
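One detail worth flagging in the unified parameter handling: `request.args` and `request.form` are Werkzeug MultiDicts whose `.get()` accepts `type=`, but `request.get_json()` returns a plain dict whose `.get()` takes no `type` argument, and even on form data `type=bool` treats any non-empty string (including "false") as True. A hedged sketch of a coercion helper that sidesteps both issues; `to_bool` is hypothetical, not part of this commit:

```python
def to_bool(value, default=False):
    """Interpret truthy strings/values coming from query string, form, or JSON payloads."""
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("1", "true", "yes", "on")

# e.g. use_streaming = to_bool(request_data.get("streaming"), False)
```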
 
105
 
106
  # 校验模型是否支持输入的语言
107
  speaker_lang = tts.voice_speakers["VITS"][id].get('lang')
108
+ if lang not in ["auto", "mix"] and len(speaker_lang) != 1 and lang not in speaker_lang:
109
  logger.info(f"[VITS] lang \"{lang}\" is not in {speaker_lang}")
110
  return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)
111
 
 
129
  "lang": lang,
130
  "speaker_lang": speaker_lang}
131
 
 
 
 
132
  if use_streaming:
133
+ audio = tts.stream_vits_infer(task)
134
  response = make_response(audio)
135
  response.headers['Content-Disposition'] = f'attachment; filename={fname}'
136
  response.headers['Content-Type'] = file_type
137
  return response
138
  else:
139
  t1 = time.time()
140
+ audio = tts.vits_infer(task)
141
  t2 = time.time()
142
  logger.info(f"[VITS] finish in {(t2 - t1):.2f}s")
143
+
144
+ if app.config.get("SAVE_AUDIO", False):
145
+ logger.debug(f"[VITS] {fname}")
146
+ path = os.path.join(app.config.get('CACHE_PATH'), fname)
147
+ save_audio(audio.getvalue(), path)
148
+
149
  return send_file(path_or_file=audio, mimetype=file_type, download_name=fname)
150
 
151
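The refactor moves caching out of the inference call: when SAVE_AUDIO is enabled, the handler itself now writes the in-memory buffer to CACHE_PATH via save_audio(). The real utils.data_utils.save_audio is not shown in this diff, so the following is only an assumed equivalent of what it does:

```python
import os

def save_audio(audio_bytes: bytes, path: str) -> None:
    """Assumed behaviour of utils.data_utils.save_audio: persist raw audio bytes to path."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        f.write(audio_bytes)
```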
 
 
187
  "audio_path": os.path.join(app.config['UPLOAD_FOLDER'], fname)}
188
 
189
  t1 = time.time()
190
+ audio = tts.hubert_vits_infer(task)
191
  t2 = time.time()
192
+ logger.info(f"[hubert] finish in {(t2 - t1):.2f}s")
193
+
194
  if app.config.get("SAVE_AUDIO", False):
195
  logger.debug(f"[hubert] {fname}")
196
+ path = os.path.join(app.config.get('CACHE_PATH'), fname)
197
+ save_audio(audio.getvalue(), path)
198
+
199
  if use_streaming:
200
  audio = tts.generate_audio_chunks(audio)
201
  response = make_response(audio)
 
211
  def voice_w2v2_api():
212
  try:
213
  if request.method == "GET":
214
+ request_data = request.args
 
 
 
 
 
 
 
 
 
215
  elif request.method == "POST":
216
  content_type = request.headers.get('Content-Type')
217
  if content_type == 'application/json':
218
+ request_data = request.get_json()
219
  else:
220
+ request_data = request.form
221
+
222
+ text = request_data.get("text", "")
223
+ id = int(request_data.get("id", app.config.get("ID", 0)))
224
+ format = request_data.get("format", app.config.get("FORMAT", "wav"))
225
+ lang = request_data.get("lang", app.config.get("LANG", "auto")).lower()
226
+ length = float(request_data.get("length", app.config.get("LENGTH", 1)))
227
+ noise = float(request_data.get("noise", app.config.get("NOISE", 0.667)))
228
+ noisew = float(request_data.get("noisew", app.config.get("NOISEW", 0.8)))
229
+ max = int(request_data.get("max", app.config.get("MAX", 50)))
230
+ emotion = int(request_data.get("emotion", app.config.get("EMOTION", 0)))
231
+ use_streaming = request_data.get('streaming', False, type=bool)
232
  except Exception as e:
233
  logger.error(f"[w2v2] {e}")
234
  return make_response(f"parameter error", 400)
 
251
 
252
  # 校验模型是否支持输入的语言
253
  speaker_lang = tts.voice_speakers["W2V2-VITS"][id].get('lang')
254
+ if lang not in ["auto", "mix"] and len(speaker_lang) != 1 and lang not in speaker_lang:
255
  logger.info(f"[w2v2] lang \"{lang}\" is not in {speaker_lang}")
256
  return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)
257
 
 
277
  "speaker_lang": speaker_lang}
278
 
279
  t1 = time.time()
280
+ audio = tts.w2v2_vits_infer(task)
281
  t2 = time.time()
282
+ logger.info(f"[w2v2] finish in {(t2 - t1):.2f}s")
283
+
284
  if app.config.get("SAVE_AUDIO", False):
285
+ logger.debug(f"[w2v2] {fname}")
286
+ path = os.path.join(app.config.get('CACHE_PATH'), fname)
287
+ save_audio(audio.getvalue(), path)
288
+
289
  if use_streaming:
290
  audio = tts.generate_audio_chunks(audio)
291
  response = make_response(audio)
 
293
  response.headers['Content-Type'] = file_type
294
  return response
295
  else:
 
296
  return send_file(path_or_file=audio, mimetype=file_type, download_name=fname)
297
 
298
 
 
322
  "format": format}
323
 
324
  t1 = time.time()
325
+ audio = tts.vits_voice_conversion(task)
326
  t2 = time.time()
327
+ logger.info(f"[Voice conversion] finish in {(t2 - t1):.2f}s")
328
+
329
  if app.config.get("SAVE_AUDIO", False):
330
  logger.debug(f"[Voice conversion] {fname}")
331
+ path = os.path.join(app.config.get('CACHE_PATH'), fname)
332
+ save_audio(audio.getvalue(), path)
333
+
334
  if use_streaming:
335
  audio = tts.generate_audio_chunks(audio)
336
  response = make_response(audio)
 
343
 
344
  @app.route('/voice/ssml', methods=["POST"])
345
  @require_api_key
346
+ def ssml_api():
347
  try:
348
  content_type = request.headers.get('Content-Type')
349
  if content_type == 'application/json':
350
+ request_data = request.get_json()
351
  else:
352
+ request_data = request.form
353
+
354
+ ssml = request_data.get("ssml")
355
  except Exception as e:
356
  logger.info(f"[ssml] {e}")
357
  return make_response(jsonify({"status": "error", "message": f"parameter error"}), 400)
 
362
  file_type = f"audio/{format}"
363
 
364
  t1 = time.time()
365
+ audio = tts.create_ssml_infer_task(voice_tasks, format)
366
  t2 = time.time()
367
+ logger.info(f"[ssml] finish in {(t2 - t1):.2f}s")
368
+
369
  if app.config.get("SAVE_AUDIO", False):
370
  logger.debug(f"[ssml] {fname}")
371
+ path = os.path.join(app.config.get('CACHE_PATH'), fname)
372
+ save_audio(audio.getvalue(), path)
373
 
374
  return send_file(path_or_file=audio, mimetype=file_type, download_name=fname)
375
 
 
389
 
390
  file_type = "application/octet-stream; charset=ascii"
391
  fname = os.path.splitext(audio.filename)[0] + ".npy"
392
+ emotion_npy = tts.get_dimensional_emotion_npy(content)
393
  if use_streaming:
394
+ emotion_npy = tts.generate_audio_chunks(emotion_npy)
395
+ response = make_response(emotion_npy)
396
  response.headers['Content-Disposition'] = f'attachment; filename={fname}'
397
  response.headers['Content-Type'] = file_type
398
  return response
399
  else:
400
+ return send_file(path_or_file=emotion_npy, mimetype=file_type, download_name=fname)
401
 
402
 
403
  @app.route('/voice/bert-vits2', methods=["GET", "POST"])
 
405
  def voice_bert_vits2_api():
406
  try:
407
  if request.method == "GET":
408
+ request_data = request.args
 
 
 
 
 
 
 
 
 
409
  elif request.method == "POST":
410
  content_type = request.headers.get('Content-Type')
411
  if content_type == 'application/json':
412
+ request_data = request.get_json()
413
  else:
414
+ request_data = request.form
415
+
416
+ text = request_data.get("text", "")
417
+ id = int(request_data.get("id", app.config.get("ID", 0)))
418
+ format = request_data.get("format", app.config.get("FORMAT", "wav"))
419
+ lang = request_data.get("lang", "auto").lower()
420
+ length = float(request_data.get("length", app.config.get("LENGTH", 1)))
421
+ noise = float(request_data.get("noise", app.config.get("NOISE", 0.667)))
422
+ noisew = float(request_data.get("noisew", app.config.get("NOISEW", 0.8)))
423
+ sdp_ratio = float(request_data.get("sdp_ratio", app.config.get("SDP_RATIO", 0.2)))
424
+ max = int(request_data.get("max", app.config.get("MAX", 50)))
425
  except Exception as e:
426
  logger.error(f"[Bert-VITS2] {e}")
427
  return make_response("parameter error", 400)
428
 
429
+ logger.info(
430
+ f"[Bert-VITS2] id:{id} format:{format} lang:{lang} length:{length} noise:{noise} noisew:{noisew} sdp_ratio:{sdp_ratio}")
431
  logger.info(f"[Bert-VITS2] len:{len(text)} text:{text}")
432
 
433
  if check_is_none(text):
 
444
 
445
  # 校验模型是否支持输入的语言
446
  speaker_lang = tts.voice_speakers["BERT-VITS2"][id].get('lang')
447
+ if lang not in ["auto", "mix"] and len(speaker_lang) != 1 and lang not in speaker_lang:
448
  logger.info(f"[Bert-VITS2] lang \"{lang}\" is not in {speaker_lang}")
449
  return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)
450
 
 
465
  "lang": lang,
466
  "speaker_lang": speaker_lang}
467
 
 
 
 
 
468
  t1 = time.time()
469
+ audio = tts.bert_vits2_infer(task)
470
  t2 = time.time()
471
  logger.info(f"[Bert-VITS2] finish in {(t2 - t1):.2f}s")
472
+
473
+ if app.config.get("SAVE_AUDIO", False):
474
+ logger.debug(f"[Bert-VITS2] {fname}")
475
+ path = os.path.join(app.config.get('CACHE_PATH'), fname)
476
+ save_audio(audio.getvalue(), path)
477
+
478
  return send_file(path_or_file=audio, mimetype=file_type, download_name=fname)
479
 
480
 
 
482
  def check():
483
  try:
484
  if request.method == "GET":
485
+ request_data = request.args
 
486
  elif request.method == "POST":
487
  content_type = request.headers.get('Content-Type')
488
  if content_type == 'application/json':
489
+ request_data = request.get_json()
490
  else:
491
+ request_data = request.form
492
+
493
+ model = request_data.get("model")
494
+ id = int(request_data.get("id"))
495
  except Exception as e:
496
  logger.info(f"[check] {e}")
497
  return make_response(jsonify({"status": "error", "message": "parameter error"}), 400)
bert_vits2/attentions.py CHANGED
@@ -42,11 +42,11 @@ class Encoder(nn.Module):
42
  self.kernel_size = kernel_size
43
  self.p_dropout = p_dropout
44
  self.window_size = window_size
45
- if isflow:
46
- cond_layer = torch.nn.Conv1d(256, 2 * hidden_channels * n_layers, 1)
47
- self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
48
- self.cond_layer = weight_norm(cond_layer, name='weight')
49
- self.gin_channels = 256
50
  self.cond_layer_idx = self.n_layers
51
  if 'gin_channels' in kwargs:
52
  self.gin_channels = kwargs['gin_channels']
 
42
  self.kernel_size = kernel_size
43
  self.p_dropout = p_dropout
44
  self.window_size = window_size
45
+ # if isflow:
46
+ # cond_layer = torch.nn.Conv1d(256, 2 * hidden_channels * n_layers, 1)
47
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
48
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
49
+ # self.gin_channels = 256
50
  self.cond_layer_idx = self.n_layers
51
  if 'gin_channels' in kwargs:
52
  self.gin_channels = kwargs['gin_channels']
bert_vits2/bert/bert-base-japanese-v3/README.md ADDED
@@ -0,0 +1,53 @@
1
+ ---
2
+ license: apache-2.0
3
+ datasets:
4
+ - cc100
5
+ - wikipedia
6
+ language:
7
+ - ja
8
+ widget:
9
+ - text: 東北大学で[MASK]の研究をしています。
10
+ ---
11
+
12
+ # BERT base Japanese (unidic-lite with whole word masking, CC-100 and jawiki-20230102)
13
+
14
+ This is a [BERT](https://github.com/google-research/bert) model pretrained on texts in the Japanese language.
15
+
16
+ This version of the model processes input texts with word-level tokenization based on the Unidic 2.1.2 dictionary (available in [unidic-lite](https://pypi.org/project/unidic-lite/) package), followed by the WordPiece subword tokenization.
17
+ Additionally, the model is trained with the whole word masking enabled for the masked language modeling (MLM) objective.
18
+
19
+ The codes for the pretraining are available at [cl-tohoku/bert-japanese](https://github.com/cl-tohoku/bert-japanese/).
20
+
21
+ ## Model architecture
22
+
23
+ The model architecture is the same as the original BERT base model; 12 layers, 768 dimensions of hidden states, and 12 attention heads.
24
+
25
+ ## Training Data
26
+
27
+ The model is trained on the Japanese portion of [CC-100 dataset](https://data.statmt.org/cc-100/) and the Japanese version of Wikipedia.
28
+ For Wikipedia, we generated a text corpus from the [Wikipedia Cirrussearch dump file](https://dumps.wikimedia.org/other/cirrussearch/) as of January 2, 2023.
29
+ The corpus files generated from CC-100 and Wikipedia are 74.3GB and 4.9GB in size and consist of approximately 392M and 34M sentences, respectively.
30
+
31
+ For the purpose of splitting texts into sentences, we used [fugashi](https://github.com/polm/fugashi) with [mecab-ipadic-NEologd](https://github.com/neologd/mecab-ipadic-neologd) dictionary (v0.0.7).
32
+
33
+ ## Tokenization
34
+
35
+ The texts are first tokenized by MeCab with the Unidic 2.1.2 dictionary and then split into subwords by the WordPiece algorithm.
36
+ The vocabulary size is 32768.
37
+
38
+ We used [fugashi](https://github.com/polm/fugashi) and [unidic-lite](https://github.com/polm/unidic-lite) packages for the tokenization.
39
+
40
+ ## Training
41
+
42
+ We trained the model first on the CC-100 corpus for 1M steps and then on the Wikipedia corpus for another 1M steps.
43
+ For training of the MLM (masked language modeling) objective, we introduced whole word masking in which all of the subword tokens corresponding to a single word (tokenized by MeCab) are masked at once.
44
+
45
+ For training of each model, we used a v3-8 instance of Cloud TPUs provided by [TPU Research Cloud](https://sites.research.google/trc/about/).
46
+
47
+ ## Licenses
48
+
49
+ The pretrained models are distributed under the Apache License 2.0.
50
+
51
+ ## Acknowledgments
52
+
53
+ This model is trained with Cloud TPUs provided by [TPU Research Cloud](https://sites.research.google/trc/about/) program.
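Loading this tokenizer requires fugashi and unidic-lite for the MeCab word segmentation described above. A hedged sketch, assuming the weights and tokenizer files have been placed under bert_vits2/bert/bert-base-japanese-v3 as the new .gitignore entries suggest:

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Local checkpoint directory used by this repo (requires fugashi + unidic-lite installed).
path = "bert_vits2/bert/bert-base-japanese-v3"
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForMaskedLM.from_pretrained(path)
print(tokenizer.tokenize("東北大学で自然言語処理の研究をしています。"))
```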
bert_vits2/bert/bert-base-japanese-v3/config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "architectures": [
3
+ "BertForPreTraining"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 32768
19
+ }
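The hidden size here (768) matches what the updated bert_vits2/bert_vits2.py assumes for the Japanese BERT features, where ja_bert is zero-filled with 768 rows while the Chinese features use 1024. A quick consistency check along those lines:

```python
import json

with open("bert_vits2/bert/bert-base-japanese-v3/config.json") as f:
    cfg = json.load(f)

# bert_vits2.py allocates ja_bert as torch.zeros(768, len(phone)), so this must be a base-size BERT.
assert cfg["hidden_size"] == 768
assert cfg["vocab_size"] == 32768
```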
bert_vits2/bert/bert-base-japanese-v3/vocab.txt ADDED
The diff for this file is too large to render.
 
bert_vits2/bert_vits2.py CHANGED
@@ -1,41 +1,46 @@
1
- import re
2
-
3
  import numpy as np
4
  import torch
5
 
6
- from bert_vits2 import utils, commons
 
7
  from bert_vits2.models import SynthesizerTrn
8
- from bert_vits2.text import symbols, cleaned_text_to_sequence, get_bert
9
  from bert_vits2.text.cleaner import clean_text
10
- from utils.nlp import sentence_split, cut
 
11
 
12
 
13
  class Bert_VITS2:
14
- def __init__(self, model, config, device=torch.device("cpu")):
15
- self.hps_ms = utils.get_hparams_from_file(config)
16
  self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
17
  self.speakers = [item[0] for item in
18
  sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])]
 
 
 
 
 
19
  self.net_g = SynthesizerTrn(
20
- len(symbols),
21
  self.hps_ms.data.filter_length // 2 + 1,
22
  self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
23
  n_speakers=self.hps_ms.data.n_speakers,
 
24
  **self.hps_ms.model).to(device)
25
  _ = self.net_g.eval()
26
  self.device = device
27
  self.load_model(model)
28
 
29
  def load_model(self, model):
30
- utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True)
31
 
32
  def get_speakers(self):
33
  return self.speakers
34
 
35
  def get_text(self, text, language_str, hps):
36
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
37
- # print([f"{p}{t}" for p, t in zip(phone, tone)])
38
- phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
39
 
40
  if hps.data.add_blank:
41
  phone = commons.intersperse(phone, 0)
@@ -45,25 +50,37 @@ class Bert_VITS2:
45
  word2ph[i] = word2ph[i] * 2
46
  word2ph[0] += 1
47
  bert = get_bert(norm_text, word2ph, language_str)
 
 
48
 
49
- assert bert.shape[-1] == len(phone)
50
-
 
 
 
 
 
 
 
 
 
 
51
  phone = torch.LongTensor(phone)
52
  tone = torch.LongTensor(tone)
53
  language = torch.LongTensor(language)
 
54
 
55
- return bert, phone, tone, language
56
-
57
- def infer(self, text, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
58
- bert, phones, tones, lang_ids = self.get_text(text, "ZH", self.hps_ms)
59
  with torch.no_grad():
60
  x_tst = phones.to(self.device).unsqueeze(0)
61
  tones = tones.to(self.device).unsqueeze(0)
62
  lang_ids = lang_ids.to(self.device).unsqueeze(0)
63
  bert = bert.to(self.device).unsqueeze(0)
 
64
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device)
65
  speakers = torch.LongTensor([int(sid)]).to(self.device)
66
- audio = self.net_g.infer(x_tst, x_tst_lengths, speakers, tones, lang_ids, bert, sdp_ratio=sdp_ratio
67
  , noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[
68
  0][0, 0].data.cpu().float().numpy()
69
 
@@ -72,17 +89,20 @@ class Bert_VITS2:
72
 
73
  def get_audio(self, voice, auto_break=False):
74
  text = voice.get("text", None)
 
75
  sdp_ratio = voice.get("sdp_ratio", 0.2)
76
  noise_scale = voice.get("noise", 0.5)
77
  noise_scale_w = voice.get("noisew", 0.6)
78
  length_scale = voice.get("length", 1)
79
  sid = voice.get("id", 0)
80
  max = voice.get("max", 50)
81
- # sentence_list = sentence_split(text, max, "ZH", ["zh"])
 
 
82
  sentence_list = cut(text, max)
83
  audios = []
84
  for sentence in sentence_list:
85
- audio = self.infer(sentence, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid)
86
  audios.append(audio)
87
  audio = np.concatenate(audios)
88
  return audio
 
 
 
1
  import numpy as np
2
  import torch
3
 
4
+ from bert_vits2 import commons
5
+ from bert_vits2 import utils as bert_vits2_utils
6
  from bert_vits2.models import SynthesizerTrn
7
+ from bert_vits2.text import *
8
  from bert_vits2.text.cleaner import clean_text
9
+ from utils import classify_language, get_hparams_from_file, lang_dict
10
+ from utils.sentence import sentence_split_and_markup, cut
11
 
12
 
13
  class Bert_VITS2:
14
+ def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
15
+ self.hps_ms = get_hparams_from_file(config)
16
  self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
17
  self.speakers = [item[0] for item in
18
  sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])]
19
+
20
+ self.legacy = getattr(self.hps_ms.data, 'legacy', False)
21
+ self.symbols = symbols_legacy if self.legacy else symbols
22
+ self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
23
+
24
  self.net_g = SynthesizerTrn(
25
+ len(self.symbols),
26
  self.hps_ms.data.filter_length // 2 + 1,
27
  self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
28
  n_speakers=self.hps_ms.data.n_speakers,
29
+ symbols=self.symbols,
30
  **self.hps_ms.model).to(device)
31
  _ = self.net_g.eval()
32
  self.device = device
33
  self.load_model(model)
34
 
35
  def load_model(self, model):
36
+ bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True)
37
 
38
  def get_speakers(self):
39
  return self.speakers
40
 
41
  def get_text(self, text, language_str, hps):
42
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
43
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, self._symbol_to_id)
 
44
 
45
  if hps.data.add_blank:
46
  phone = commons.intersperse(phone, 0)
 
50
  word2ph[i] = word2ph[i] * 2
51
  word2ph[0] += 1
52
  bert = get_bert(norm_text, word2ph, language_str)
53
+ del word2ph
54
+ assert bert.shape[-1] == len(phone), phone
55
 
56
+ if language_str == "zh":
57
+ bert = bert
58
+ ja_bert = torch.zeros(768, len(phone))
59
+ elif language_str == "ja":
60
+ ja_bert = bert
61
+ bert = torch.zeros(1024, len(phone))
62
+ else:
63
+ bert = torch.zeros(1024, len(phone))
64
+ ja_bert = torch.zeros(768, len(phone))
65
+ assert bert.shape[-1] == len(
66
+ phone
67
+ ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
68
  phone = torch.LongTensor(phone)
69
  tone = torch.LongTensor(tone)
70
  language = torch.LongTensor(language)
71
+ return bert, ja_bert, phone, tone, language
72
 
73
+ def infer(self, text, lang, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid):
74
+ bert, ja_bert, phones, tones, lang_ids = self.get_text(text, lang, self.hps_ms)
 
 
75
  with torch.no_grad():
76
  x_tst = phones.to(self.device).unsqueeze(0)
77
  tones = tones.to(self.device).unsqueeze(0)
78
  lang_ids = lang_ids.to(self.device).unsqueeze(0)
79
  bert = bert.to(self.device).unsqueeze(0)
80
+ ja_bert = ja_bert.to(self.device).unsqueeze(0)
81
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(self.device)
82
  speakers = torch.LongTensor([int(sid)]).to(self.device)
83
+ audio = self.net_g.infer(x_tst, x_tst_lengths, speakers, tones, lang_ids, bert, ja_bert, sdp_ratio=sdp_ratio
84
  , noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[
85
  0][0, 0].data.cpu().float().numpy()
86
 
 
89
 
90
  def get_audio(self, voice, auto_break=False):
91
  text = voice.get("text", None)
92
+ lang = voice.get("lang", "auto")
93
  sdp_ratio = voice.get("sdp_ratio", 0.2)
94
  noise_scale = voice.get("noise", 0.5)
95
  noise_scale_w = voice.get("noisew", 0.6)
96
  length_scale = voice.get("length", 1)
97
  sid = voice.get("id", 0)
98
  max = voice.get("max", 50)
99
+ # sentence_list = sentence_split_and_markup(text, max, "ZH", ["zh"])
100
+ if lang == "auto":
101
+ lang = classify_language(text, target_languages=lang_dict["bert_vits2"])
102
  sentence_list = cut(text, max)
103
  audios = []
104
  for sentence in sentence_list:
105
+ audio = self.infer(sentence, lang, sdp_ratio, noise_scale, noise_scale_w, length_scale, sid)
106
  audios.append(audio)
107
  audio = np.concatenate(audios)
108
  return audio
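
For reference, a minimal usage sketch of the revised `Bert_VITS2` wrapper above; the checkpoint and config paths are placeholders, and importing the module assumes the BERT weights it references are already downloaded.

```python
import torch

from bert_vits2.bert_vits2 import Bert_VITS2

# Hypothetical paths: point these at a real Bert-VITS2 checkpoint and its config.json.
model = Bert_VITS2(model="Model/bert_vits2/G_latest.pth",
                   config="Model/bert_vits2/config.json",
                   device=torch.device("cpu"))

# get_audio() takes the same keys the API layer passes through;
# lang="auto" lets classify_language() pick zh or ja per request.
audio = model.get_audio({"text": "你好,世界!", "lang": "auto", "id": 0,
                         "sdp_ratio": 0.2, "noise": 0.5, "noisew": 0.6,
                         "length": 1, "max": 50})
print(audio.shape, model.get_speakers())
```
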
bert_vits2/models.py CHANGED
@@ -11,7 +11,7 @@ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
 
13
  from bert_vits2.commons import init_weights, get_padding
14
- from bert_vits2.text import symbols, num_tones, num_languages
15
 
16
 
17
  class DurationDiscriminator(nn.Module): # vits2
@@ -254,7 +254,8 @@ class TextEncoder(nn.Module):
254
  n_layers,
255
  kernel_size,
256
  p_dropout,
257
- gin_channels=0):
 
258
  super().__init__()
259
  self.n_vocab = n_vocab
260
  self.out_channels = out_channels
@@ -272,6 +273,7 @@ class TextEncoder(nn.Module):
272
  self.language_emb = nn.Embedding(num_languages, hidden_channels)
273
  nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels ** -0.5)
274
  self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
 
275
 
276
  self.encoder = attentions.Encoder(
277
  hidden_channels,
@@ -283,9 +285,10 @@ class TextEncoder(nn.Module):
283
  gin_channels=self.gin_channels)
284
  self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
285
 
286
- def forward(self, x, x_lengths, tone, language, bert, g=None):
287
- x = (self.emb(x) + self.tone_emb(tone) + self.language_emb(language) + self.bert_proj(bert).transpose(1,
288
- 2)) * math.sqrt(
 
289
  self.hidden_channels) # [b, t, h]
290
  x = torch.transpose(x, 1, -1) # [b, h, t]
291
  x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
@@ -620,6 +623,7 @@ class SynthesizerTrn(nn.Module):
620
  self.current_mas_noise_scale = self.mas_noise_scale_initial
621
  if self.use_spk_conditioned_encoder and gin_channels > 0:
622
  self.enc_gin_channels = gin_channels
 
623
  self.enc_p = TextEncoder(n_vocab,
624
  inter_channels,
625
  hidden_channels,
@@ -628,7 +632,9 @@ class SynthesizerTrn(nn.Module):
628
  n_layers,
629
  kernel_size,
630
  p_dropout,
631
- gin_channels=self.enc_gin_channels)
 
 
632
  self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
633
  upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
634
  self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
@@ -643,12 +649,13 @@ class SynthesizerTrn(nn.Module):
643
  self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
644
  self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
645
 
646
- if n_speakers >= 1:
647
- self.emb_g = nn.Embedding(n_speakers, gin_channels)
648
  else:
649
  self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
650
 
651
- def infer(self, x, x_lengths, sid, tone, language, bert, noise_scale=.667, length_scale=1, noise_scale_w=0.8,
 
652
  max_len=None, sdp_ratio=0, y=None):
653
  # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
654
  # g = self.gst(y)
@@ -656,7 +663,7 @@ class SynthesizerTrn(nn.Module):
656
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
657
  else:
658
  g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
659
- x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert, g=g)
660
  logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (sdp_ratio) + self.dp(x, x_mask,
661
  g=g) * (
662
  1 - sdp_ratio)
 
11
  from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
 
13
  from bert_vits2.commons import init_weights, get_padding
14
+ from bert_vits2.text import num_tones, num_languages
15
 
16
 
17
  class DurationDiscriminator(nn.Module): # vits2
 
254
  n_layers,
255
  kernel_size,
256
  p_dropout,
257
+ gin_channels=0,
258
+ symbols=None):
259
  super().__init__()
260
  self.n_vocab = n_vocab
261
  self.out_channels = out_channels
 
273
  self.language_emb = nn.Embedding(num_languages, hidden_channels)
274
  nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels ** -0.5)
275
  self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
276
+ self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)
277
 
278
  self.encoder = attentions.Encoder(
279
  hidden_channels,
 
285
  gin_channels=self.gin_channels)
286
  self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
287
 
288
+ def forward(self, x, x_lengths, tone, language, bert, ja_bert, g=None):
289
+ bert_emb = self.bert_proj(bert).transpose(1, 2)
290
+ ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
291
+ x = (self.emb(x) + self.tone_emb(tone) + self.language_emb(language) + bert_emb + ja_bert_emb) * math.sqrt(
292
  self.hidden_channels) # [b, t, h]
293
  x = torch.transpose(x, 1, -1) # [b, h, t]
294
  x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
 
623
  self.current_mas_noise_scale = self.mas_noise_scale_initial
624
  if self.use_spk_conditioned_encoder and gin_channels > 0:
625
  self.enc_gin_channels = gin_channels
626
+ symbols = kwargs.get("symbols")
627
  self.enc_p = TextEncoder(n_vocab,
628
  inter_channels,
629
  hidden_channels,
 
632
  n_layers,
633
  kernel_size,
634
  p_dropout,
635
+ gin_channels=self.enc_gin_channels,
636
+ symbols=symbols,
637
+ )
638
  self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
639
  upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
640
  self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
 
649
  self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
650
  self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
651
 
652
+ if self.n_speakers > 0:
653
+ self.emb_g = nn.Embedding(self.n_speakers, gin_channels)
654
  else:
655
  self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
656
 
657
+ def infer(self, x, x_lengths, sid, tone, language, bert, ja_bert, noise_scale=.667, length_scale=1,
658
+ noise_scale_w=0.8,
659
  max_len=None, sdp_ratio=0, y=None):
660
  # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
661
  # g = self.gst(y)
 
663
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
664
  else:
665
  g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
666
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert, ja_bert, g=g)
667
  logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (sdp_ratio) + self.dp(x, x_mask,
668
  g=g) * (
669
  1 - sdp_ratio)
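
The text encoder now consumes two BERT streams: a 1024-dim `bert` tensor from chinese-roberta-wwm-ext-large and a 768-dim `ja_bert` tensor from bert-base-japanese-v3, with the unused one zero-filled. A rough shape sketch, assuming `hidden_channels=192` as in a typical config:

```python
import torch

T = 37  # example phoneme count, blanks included

bert = torch.rand(1024, T)      # Chinese RoBERTa features for a zh input
ja_bert = torch.zeros(768, T)   # Japanese slot zero-filled for a zh input

hidden_channels = 192  # assumed value; read it from hps.model in practice
bert_proj = torch.nn.Conv1d(1024, hidden_channels, 1)
ja_bert_proj = torch.nn.Conv1d(768, hidden_channels, 1)

# Mirrors TextEncoder.forward: both projections are summed with the symbol,
# tone, and language embeddings before the attention encoder runs.
fused = (bert_proj(bert.unsqueeze(0)).transpose(1, 2)
         + ja_bert_proj(ja_bert.unsqueeze(0)).transpose(1, 2))
print(fused.shape)  # torch.Size([1, 37, 192])
```
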
bert_vits2/text/__init__.py CHANGED
@@ -1,17 +1,16 @@
1
  from bert_vits2.text.symbols import *
2
  from .chinese_bert import get_bert_feature as zh_bert
3
  from .english_bert_mock import get_bert_feature as en_bert
 
4
 
5
- _symbol_to_id = {s: i for i, s in enumerate(symbols)}
6
 
7
-
8
- def cleaned_text_to_sequence(cleaned_text, tones, language):
9
- '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
10
- Args:
11
- text: string to convert to a sequence
12
- Returns:
13
- List of integers corresponding to the symbols in the text
14
- '''
15
  phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
16
  tone_start = language_tone_start_map[language]
17
  tones = [i + tone_start for i in tones]
@@ -21,9 +20,6 @@ def cleaned_text_to_sequence(cleaned_text, tones, language):
21
 
22
 
23
  def get_bert(norm_text, word2ph, language):
24
- lang_bert_func_map = {
25
- 'ZH': zh_bert,
26
- 'EN': en_bert
27
- }
28
  bert = lang_bert_func_map[language](norm_text, word2ph)
29
  return bert
 
1
  from bert_vits2.text.symbols import *
2
  from .chinese_bert import get_bert_feature as zh_bert
3
  from .english_bert_mock import get_bert_feature as en_bert
4
+ from .japanese_bert import get_bert_feature as ja_bert
5
 
 
6
 
7
+ def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id):
8
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
9
+ Args:
10
+ text: string to convert to a sequence
11
+ Returns:
12
+ List of integers corresponding to the symbols in the text
13
+ """
 
14
  phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
15
  tone_start = language_tone_start_map[language]
16
  tones = [i + tone_start for i in tones]
 
20
 
21
 
22
  def get_bert(norm_text, word2ph, language):
23
+ lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert}
 
 
 
24
  bert = lang_bert_func_map[language](norm_text, word2ph)
25
  return bert
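
Because the symbol table is now passed in explicitly (so the legacy Japanese symbol set can be selected per model), callers build the mapping themselves. A minimal sketch with a toy Chinese fragment (importing the package also loads the BERT models):

```python
from bert_vits2.text import cleaned_text_to_sequence
from bert_vits2.text.symbols import symbols, language_tone_start_map

_symbol_to_id = {s: i for i, s in enumerate(symbols)}

# Toy cleaned input: four phonemes with tone 3, language "zh".
phones, tones, lang_ids = cleaned_text_to_sequence(["n", "i", "h", "ao"],
                                                   [3, 3, 3, 3], "zh", _symbol_to_id)
print(phones, tones, lang_ids)  # tones are offset by language_tone_start_map["zh"]
```
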
bert_vits2/text/chinese.py CHANGED
@@ -12,6 +12,8 @@ pinyin_to_symbol_map = {line.split("\t")[0]: line.strip().split("\t")[1] for lin
12
  open(os.path.join(current_file_path, 'opencpop-strict.txt')).readlines()}
13
 
14
  import jieba.posseg as psg
 
 
15
 
16
  rep_map = {
17
  ':': ',',
 
12
  open(os.path.join(current_file_path, 'opencpop-strict.txt')).readlines()}
13
 
14
  import jieba.posseg as psg
15
+ from jieba import lcut
16
+ lcut("预加载")
17
 
18
  rep_map = {
19
  ':': ',',
bert_vits2/text/chinese_bert.py CHANGED
@@ -1,9 +1,20 @@
 
 
1
  import config
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForMaskedLM
4
  from logger import logger
 
 
 
 
 
 
 
 
5
 
6
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
7
 
8
  try:
9
  logger.info("Loading chinese-roberta-wwm-ext-large...")
@@ -13,10 +24,10 @@ try:
13
  logger.info("Loading finished.")
14
  except Exception as e:
15
  logger.error(e)
16
- logger.error(f"Please download model from hfl/chinese-roberta-wwm-ext-large.")
17
 
18
 
19
- def get_bert_feature(text, word2ph):
20
  with torch.no_grad():
21
  inputs = tokenizer(text, return_tensors='pt')
22
  for i in inputs:
@@ -37,7 +48,6 @@ def get_bert_feature(text, word2ph):
37
 
38
 
39
  if __name__ == '__main__':
40
- # feature = get_bert_feature('你好,我是说的道理。')
41
  import torch
42
 
43
  word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
 
1
+ import os
2
+
3
  import config
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForMaskedLM
6
  from logger import logger
7
+ from utils.download import download_and_verify
8
+ from config import DEVICE as device
9
+
10
+ URLS = [
11
+ "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin",
12
+ ]
13
+ TARGET_PATH = os.path.join(config.ABS_PATH, "bert_vits2/bert/chinese-roberta-wwm-ext-large/pytorch_model.bin")
14
+ EXPECTED_MD5 = None
15
 
16
+ if not os.path.exists(TARGET_PATH):
17
+ success, message = download_and_verify(URLS, TARGET_PATH, EXPECTED_MD5)
18
 
19
  try:
20
  logger.info("Loading chinese-roberta-wwm-ext-large...")
 
24
  logger.info("Loading finished.")
25
  except Exception as e:
26
  logger.error(e)
27
+ logger.error(f"Please download pytorch_model.bin from hfl/chinese-roberta-wwm-ext-large.")
28
 
29
 
30
+ def get_bert_feature(text, word2ph, device=config.DEVICE):
31
  with torch.no_grad():
32
  inputs = tokenizer(text, return_tensors='pt')
33
  for i in inputs:
 
48
 
49
 
50
  if __name__ == '__main__':
 
51
  import torch
52
 
53
  word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
bert_vits2/text/cleaner.py CHANGED
@@ -1,7 +1,8 @@
1
- from bert_vits2.text import chinese, cleaned_text_to_sequence
2
 
3
  language_module_map = {
4
- 'ZH': chinese
 
5
  }
6
 
7
 
 
1
+ from bert_vits2.text import chinese, japanese, cleaned_text_to_sequence
2
 
3
  language_module_map = {
4
+ 'zh': chinese,
5
+ 'ja': japanese
6
  }
7
 
8
 
bert_vits2/text/japanese.py CHANGED
@@ -1,104 +1,585 @@
1
- # modified from https://github.com/CjangCjengh/vits/blob/main/text/japanese.py
 
2
  import re
3
- import sys
4
-
5
- import pyopenjtalk
6
-
7
- from bert_vits2.text import symbols
8
-
9
- # Regular expression matching Japanese without punctuation marks:
10
- _japanese_characters = re.compile(
11
- r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
12
-
13
- # Regular expression matching non-Japanese characters or punctuation marks:
14
- _japanese_marks = re.compile(
15
- r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
16
-
17
- # List of (symbol, Japanese) pairs for marks:
18
- _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
19
- ('%', 'パーセント')
20
- ]]
21
-
22
- # List of (consonant, sokuon) pairs:
23
- _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
24
- (r'Q([↑↓]*[kg])', r'k#\1'),
25
- (r'Q([↑↓]*[tdjʧ])', r't#\1'),
26
- (r'Q([↑↓]*[sʃ])', r's\1'),
27
- (r'Q([↑↓]*[pb])', r'p#\1')
28
- ]]
29
-
30
- # List of (consonant, hatsuon) pairs:
31
- _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
32
- (r'N([↑↓]*[pbm])', r'm\1'),
33
- (r'N([↑↓]*[ʧʥj])', r'n^\1'),
34
- (r'N([↑↓]*[tdn])', r'n\1'),
35
- (r'N([↑↓]*[kg])', r'ŋ\1')
36
- ]]
37
-
38
-
39
- def post_replace_ph(ph):
40
- rep_map = {
41
- ':': ',',
42
- ';': ',',
43
- ',': ',',
44
- '。': '.',
45
- '!': '!',
46
- '?': '?',
47
- '\n': '.',
48
- "·": ",",
49
- '、': ",",
50
- '...': '…',
51
- 'v': "V"
52
- }
53
- if ph in rep_map.keys():
54
- ph = rep_map[ph]
55
- if ph in symbols:
56
- return ph
57
- if ph not in symbols:
58
- ph = 'UNK'
59
- return ph
60
-
61
-
62
- def symbols_to_japanese(text):
63
- for regex, replacement in _symbols_to_japanese:
64
- text = re.sub(regex, replacement, text)
65
- return text
66
-
67
-
68
- def preprocess_jap(text):
69
- '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
70
- text = symbols_to_japanese(text)
71
- sentences = re.split(_japanese_marks, text)
72
- marks = re.findall(_japanese_marks, text)
73
- text = []
74
- for i, sentence in enumerate(sentences):
75
- if re.match(_japanese_characters, sentence):
76
- p = pyopenjtalk.g2p(sentence)
77
- text += p.split(" ")
78
-
79
- if i < len(marks):
80
- text += [marks[i].replace(' ', '')]
81
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  def text_normalize(text):
85
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  def g2p(norm_text):
89
- phones = preprocess_jap(norm_text)
90
- phones = [post_replace_ph(i) for i in phones]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
 
 
92
  tones = [0 for i in phones]
93
- word2ph = [1 for i in phones]
94
  return phones, tones, word2ph
95
 
96
 
97
- if __name__ == '__main__':
98
- for line in open("../../../Downloads/transcript_utf8.txt").readlines():
99
- text = line.split(":")[1]
100
- phones, tones, word2ph = g2p(text)
101
- for p in phones:
102
- if p == "z":
103
- print(text, phones)
104
- sys.exit(0)
 
 
 
 
 
 
1
+ # Convert Japanese text to phonemes which is
2
+ # compatible with Julius https://github.com/julius-speech/segmentation-kit
3
  import re
4
+ import unicodedata
5
+
6
+ from transformers import AutoTokenizer
7
+
8
+ from bert_vits2.text.symbols import *
9
+ from bert_vits2.text.japanese_bert import tokenizer
10
+
11
+ try:
12
+ import MeCab
13
+ except ImportError as e:
14
+ raise ImportError("Japanese requires mecab-python3 and unidic-lite.") from e
15
+ from num2words import num2words
16
+
17
+ _CONVRULES = [
18
+ # Conversion of 2 letters
19
+ "アァ/ a a",
20
+ "イィ/ i i",
21
+ "イェ/ i e",
22
+ "イャ/ y a",
23
+ "ウゥ/ u:",
24
+ "エェ/ e e",
25
+ "オォ/ o:",
26
+ "カァ/ k a:",
27
+ "キィ/ k i:",
28
+ "クゥ/ k u:",
29
+ "クャ/ ky a",
30
+ "クュ/ ky u",
31
+ "クョ/ ky o",
32
+ "ケェ/ k e:",
33
+ "コォ/ k o:",
34
+ "ガァ/ g a:",
35
+ "ギィ/ g i:",
36
+ "グゥ/ g u:",
37
+ "グャ/ gy a",
38
+ "グュ/ gy u",
39
+ "グョ/ gy o",
40
+ "ゲェ/ g e:",
41
+ "ゴォ/ g o:",
42
+ "サァ/ s a:",
43
+ "シィ/ sh i:",
44
+ "スゥ/ s u:",
45
+ "スャ/ sh a",
46
+ "スュ/ sh u",
47
+ "スョ/ sh o",
48
+ "セェ/ s e:",
49
+ "ソォ/ s o:",
50
+ "ザァ/ z a:",
51
+ "ジィ/ j i:",
52
+ "ズゥ/ z u:",
53
+ "ズャ/ zy a",
54
+ "ズュ/ zy u",
55
+ "ズョ/ zy o",
56
+ "ゼェ/ z e:",
57
+ "ゾォ/ z o:",
58
+ "タァ/ t a:",
59
+ "チィ/ ch i:",
60
+ "ツァ/ ts a",
61
+ "ツィ/ ts i",
62
+ "ツゥ/ ts u:",
63
+ "ツャ/ ch a",
64
+ "ツュ/ ch u",
65
+ "ツョ/ ch o",
66
+ "ツェ/ ts e",
67
+ "ツォ/ ts o",
68
+ "テェ/ t e:",
69
+ "トォ/ t o:",
70
+ "ダァ/ d a:",
71
+ "ヂィ/ j i:",
72
+ "ヅゥ/ d u:",
73
+ "ヅャ/ zy a",
74
+ "ヅュ/ zy u",
75
+ "ヅョ/ zy o",
76
+ "デェ/ d e:",
77
+ "ドォ/ d o:",
78
+ "ナァ/ n a:",
79
+ "ニィ/ n i:",
80
+ "ヌゥ/ n u:",
81
+ "ヌャ/ ny a",
82
+ "ヌュ/ ny u",
83
+ "ヌョ/ ny o",
84
+ "ネェ/ n e:",
85
+ "ノォ/ n o:",
86
+ "ハァ/ h a:",
87
+ "ヒィ/ h i:",
88
+ "フゥ/ f u:",
89
+ "フャ/ hy a",
90
+ "フュ/ hy u",
91
+ "フョ/ hy o",
92
+ "ヘェ/ h e:",
93
+ "ホォ/ h o:",
94
+ "バァ/ b a:",
95
+ "ビィ/ b i:",
96
+ "ブゥ/ b u:",
97
+ "フャ/ hy a",
98
+ "ブュ/ by u",
99
+ "フョ/ hy o",
100
+ "ベェ/ b e:",
101
+ "ボォ/ b o:",
102
+ "パァ/ p a:",
103
+ "ピィ/ p i:",
104
+ "プゥ/ p u:",
105
+ "プャ/ py a",
106
+ "プュ/ py u",
107
+ "プョ/ py o",
108
+ "ペェ/ p e:",
109
+ "ポォ/ p o:",
110
+ "マァ/ m a:",
111
+ "ミィ/ m i:",
112
+ "ムゥ/ m u:",
113
+ "ムャ/ my a",
114
+ "ムュ/ my u",
115
+ "ムョ/ my o",
116
+ "メェ/ m e:",
117
+ "モォ/ m o:",
118
+ "ヤァ/ y a:",
119
+ "ユゥ/ y u:",
120
+ "ユャ/ y a:",
121
+ "ユュ/ y u:",
122
+ "ユョ/ y o:",
123
+ "ヨォ/ y o:",
124
+ "ラ���/ r a:",
125
+ "リィ/ r i:",
126
+ "ルゥ/ r u:",
127
+ "ルャ/ ry a",
128
+ "ルュ/ ry u",
129
+ "ルョ/ ry o",
130
+ "レェ/ r e:",
131
+ "ロォ/ r o:",
132
+ "ワァ/ w a:",
133
+ "ヲォ/ o:",
134
+ "ディ/ d i",
135
+ "デェ/ d e:",
136
+ "デャ/ dy a",
137
+ "デュ/ dy u",
138
+ "デョ/ dy o",
139
+ "ティ/ t i",
140
+ "テェ/ t e:",
141
+ "テャ/ ty a",
142
+ "テュ/ ty u",
143
+ "テョ/ ty o",
144
+ "スィ/ s i",
145
+ "ズァ/ z u a",
146
+ "ズィ/ z i",
147
+ "ズゥ/ z u",
148
+ "ズャ/ zy a",
149
+ "ズュ/ zy u",
150
+ "ズョ/ zy o",
151
+ "ズェ/ z e",
152
+ "ズォ/ z o",
153
+ "キャ/ ky a",
154
+ "キュ/ ky u",
155
+ "キョ/ ky o",
156
+ "シャ/ sh a",
157
+ "シュ/ sh u",
158
+ "シェ/ sh e",
159
+ "ショ/ sh o",
160
+ "チャ/ ch a",
161
+ "チュ/ ch u",
162
+ "チェ/ ch e",
163
+ "チョ/ ch o",
164
+ "トゥ/ t u",
165
+ "トャ/ ty a",
166
+ "トュ/ ty u",
167
+ "トョ/ ty o",
168
+ "ドァ/ d o a",
169
+ "ドゥ/ d u",
170
+ "ドャ/ dy a",
171
+ "ドュ/ dy u",
172
+ "ドョ/ dy o",
173
+ "ドォ/ d o:",
174
+ "ニャ/ ny a",
175
+ "ニュ/ ny u",
176
+ "ニョ/ ny o",
177
+ "ヒャ/ hy a",
178
+ "ヒュ/ hy u",
179
+ "ヒョ/ hy o",
180
+ "ミャ/ my a",
181
+ "ミュ/ my u",
182
+ "ミョ/ my o",
183
+ "リャ/ ry a",
184
+ "リュ/ ry u",
185
+ "リョ/ ry o",
186
+ "ギャ/ gy a",
187
+ "ギュ/ gy u",
188
+ "ギョ/ gy o",
189
+ "ヂェ/ j e",
190
+ "ヂャ/ j a",
191
+ "ヂュ/ j u",
192
+ "ヂョ/ j o",
193
+ "ジェ/ j e",
194
+ "ジャ/ j a",
195
+ "ジュ/ j u",
196
+ "ジョ/ j o",
197
+ "ビャ/ by a",
198
+ "ビュ/ by u",
199
+ "ビョ/ by o",
200
+ "ピャ/ py a",
201
+ "ピュ/ py u",
202
+ "ピョ/ py o",
203
+ "ウァ/ u a",
204
+ "ウィ/ w i",
205
+ "ウェ/ w e",
206
+ "ウォ/ w o",
207
+ "ファ/ f a",
208
+ "フィ/ f i",
209
+ "フゥ/ f u",
210
+ "フャ/ hy a",
211
+ "フュ/ hy u",
212
+ "フョ/ hy o",
213
+ "フェ/ f e",
214
+ "フォ/ f o",
215
+ "ヴァ/ b a",
216
+ "ヴィ/ b i",
217
+ "ヴェ/ b e",
218
+ "ヴォ/ b o",
219
+ "ヴュ/ by u",
220
+ # Conversion of 1 letter
221
+ "ア/ a",
222
+ "イ/ i",
223
+ "ウ/ u",
224
+ "エ/ e",
225
+ "オ/ o",
226
+ "カ/ k a",
227
+ "キ/ k i",
228
+ "ク/ k u",
229
+ "ケ/ k e",
230
+ "コ/ k o",
231
+ "サ/ s a",
232
+ "シ/ sh i",
233
+ "ス/ s u",
234
+ "セ/ s e",
235
+ "ソ/ s o",
236
+ "タ/ t a",
237
+ "チ/ ch i",
238
+ "ツ/ ts u",
239
+ "テ/ t e",
240
+ "ト/ t o",
241
+ "ナ/ n a",
242
+ "ニ/ n i",
243
+ "ヌ/ n u",
244
+ "ネ/ n e",
245
+ "ノ/ n o",
246
+ "ハ/ h a",
247
+ "ヒ/ h i",
248
+ "フ/ f u",
249
+ "ヘ/ h e",
250
+ "ホ/ h o",
251
+ "マ/ m a",
252
+ "ミ/ m i",
253
+ "ム/ m u",
254
+ "メ/ m e",
255
+ "モ/ m o",
256
+ "ラ/ r a",
257
+ "リ/ r i",
258
+ "ル/ r u",
259
+ "レ/ r e",
260
+ "ロ/ r o",
261
+ "ガ/ g a",
262
+ "ギ/ g i",
263
+ "グ/ g u",
264
+ "ゲ/ g e",
265
+ "ゴ/ g o",
266
+ "ザ/ z a",
267
+ "ジ/ j i",
268
+ "ズ/ z u",
269
+ "ゼ/ z e",
270
+ "ゾ/ z o",
271
+ "ダ/ d a",
272
+ "ヂ/ j i",
273
+ "ヅ/ z u",
274
+ "デ/ d e",
275
+ "ド/ d o",
276
+ "バ/ b a",
277
+ "ビ/ b i",
278
+ "ブ/ b u",
279
+ "ベ/ b e",
280
+ "ボ/ b o",
281
+ "パ/ p a",
282
+ "ピ/ p i",
283
+ "プ/ p u",
284
+ "ペ/ p e",
285
+ "ポ/ p o",
286
+ "ヤ/ y a",
287
+ "ユ/ y u",
288
+ "ヨ/ y o",
289
+ "ワ/ w a",
290
+ "ヰ/ i",
291
+ "ヱ/ e",
292
+ "ヲ/ o",
293
+ "ン/ N",
294
+ "ッ/ q",
295
+ "ヴ/ b u",
296
+ "ー/:",
297
+ # Try converting broken text
298
+ "ァ/ a",
299
+ "ィ/ i",
300
+ "ゥ/ u",
301
+ "ェ/ e",
302
+ "ォ/ o",
303
+ "ヮ/ w a",
304
+ "ォ/ o",
305
+ # Symbols
306
+ "、/ ,",
307
+ "。/ .",
308
+ "!/ !",
309
+ "?/ ?",
310
+ "・/ ,",
311
+ ]
312
+
313
+ _COLON_RX = re.compile(":+")
314
+ _REJECT_RX = re.compile("[^ a-zA-Z:,.?]")
315
+
316
+
317
+ def _makerulemap():
318
+ l = [tuple(x.split("/")) for x in _CONVRULES]
319
+ return tuple({k: v for k, v in l if len(k) == i} for i in (1, 2))
320
+
321
+
322
+ _RULEMAP1, _RULEMAP2 = _makerulemap()
323
+
324
+
325
+ def kata2phoneme(text: str) -> str:
326
+ """Convert katakana text to phonemes."""
327
+ text = text.strip()
328
+ res = []
329
+ while text:
330
+ if len(text) >= 2:
331
+ x = _RULEMAP2.get(text[:2])
332
+ if x is not None:
333
+ text = text[2:]
334
+ res += x.split(" ")[1:]
335
+ continue
336
+ x = _RULEMAP1.get(text[0])
337
+ if x is not None:
338
+ text = text[1:]
339
+ res += x.split(" ")[1:]
340
+ continue
341
+ res.append(text[0])
342
+ text = text[1:]
343
+ # res = _COLON_RX.sub(":", res)
344
+ return res
345
+
346
+
347
+ _KATAKANA = "".join(chr(ch) for ch in range(ord("ァ"), ord("ン") + 1))
348
+ _HIRAGANA = "".join(chr(ch) for ch in range(ord("ぁ"), ord("ん") + 1))
349
+ _HIRA2KATATRANS = str.maketrans(_HIRAGANA, _KATAKANA)
350
+
351
+
352
+ def hira2kata(text: str) -> str:
353
+ text = text.translate(_HIRA2KATATRANS)
354
+ return text.replace("う゛", "ヴ")
355
+
356
+
357
+ _SYMBOL_TOKENS = set(list("・、。?!"))
358
+ _NO_YOMI_TOKENS = set(list("「」『』―()[][]"))
359
+ _TAGGER = MeCab.Tagger()
360
+
361
+
362
+ def text2kata(text: str) -> str:
363
+ parsed = _TAGGER.parse(text)
364
+ res = []
365
+ for line in parsed.split("\n"):
366
+ if line == "EOS":
367
+ break
368
+ parts = line.split("\t")
369
+
370
+ word, yomi = parts[0], parts[1]
371
+ if yomi:
372
+ res.append(yomi)
373
+ else:
374
+ if word in _SYMBOL_TOKENS:
375
+ res.append(word)
376
+ elif word in ("っ", "ッ"):
377
+ res.append("ッ")
378
+ elif word in _NO_YOMI_TOKENS:
379
+ pass
380
+ else:
381
+ res.append(word)
382
+ return hira2kata("".join(res))
383
+
384
+
385
+ _ALPHASYMBOL_YOMI = {
386
+ "#": "シャープ",
387
+ "%": "パーセント",
388
+ "&": "アンド",
389
+ "+": "プラス",
390
+ "-": "マイナス",
391
+ ":": "コロン",
392
+ ";": "セミコロン",
393
+ "<": "小なり",
394
+ "=": "イコール",
395
+ ">": "大なり",
396
+ "@": "アット",
397
+ "a": "エー",
398
+ "b": "ビー",
399
+ "c": "シー",
400
+ "d": "ディー",
401
+ "e": "イー",
402
+ "f": "エフ",
403
+ "g": "ジー",
404
+ "h": "エイチ",
405
+ "i": "アイ",
406
+ "j": "ジェー",
407
+ "k": "ケー",
408
+ "l": "エル",
409
+ "m": "エム",
410
+ "n": "エヌ",
411
+ "o": "オー",
412
+ "p": "ピー",
413
+ "q": "キュー",
414
+ "r": "アール",
415
+ "s": "エス",
416
+ "t": "ティー",
417
+ "u": "ユー",
418
+ "v": "ブイ",
419
+ "w": "ダブリュー",
420
+ "x": "エックス",
421
+ "y": "ワイ",
422
+ "z": "ゼット",
423
+ "α": "アルファ",
424
+ "β": "ベータ",
425
+ "γ": "ガンマ",
426
+ "δ": "デルタ",
427
+ "ε": "イプシロン",
428
+ "ζ": "ゼータ",
429
+ "η": "イータ",
430
+ "θ": "シータ",
431
+ "ι": "イオタ",
432
+ "κ": "カッパ",
433
+ "λ": "ラムダ",
434
+ "μ": "ミュー",
435
+ "ν": "ニュー",
436
+ "ξ": "クサイ",
437
+ "ο": "オミクロン",
438
+ "π": "パイ",
439
+ "ρ": "ロー",
440
+ "σ": "シグマ",
441
+ "τ": "タウ",
442
+ "υ": "ウプシロン",
443
+ "φ": "ファイ",
444
+ "χ": "カイ",
445
+ "ψ": "プサイ",
446
+ "ω": "オメガ",
447
+ }
448
+
449
+ _NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+")
450
+ _CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"}
451
+ _CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])")
452
+ _NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?")
453
+
454
+
455
+ def japanese_convert_numbers_to_words(text: str) -> str:
456
+ res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text)
457
+ res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res)
458
+ res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res)
459
+ return res
460
+
461
+
462
+ def japanese_convert_alpha_symbols_to_words(text: str) -> str:
463
+ return "".join([_ALPHASYMBOL_YOMI.get(ch, ch) for ch in text.lower()])
464
+
465
+
466
+ def japanese_text_to_phonemes(text: str) -> str:
467
+ """Convert Japanese text to phonemes."""
468
+ res = unicodedata.normalize("NFKC", text)
469
+ res = japanese_convert_numbers_to_words(res)
470
+ # res = japanese_convert_alpha_symbols_to_words(res)
471
+ res = text2kata(res)
472
+ res = kata2phoneme(res)
473
+ return res
474
+
475
+
476
+ def is_japanese_character(char):
477
+ # 定义日语文字系统的 Unicode 范围
478
+ japanese_ranges = [
479
+ (0x3040, 0x309F), # 平假名
480
+ (0x30A0, 0x30FF), # 片假名
481
+ (0x4E00, 0x9FFF), # 汉字 (CJK Unified Ideographs)
482
+ (0x3400, 0x4DBF), # 汉字扩展 A
483
+ (0x20000, 0x2A6DF), # 汉字扩展 B
484
+ # 可以根据需要添加其他汉字扩展范围
485
+ ]
486
+
487
+ # 将字符的 Unicode 编码转换为整数
488
+ char_code = ord(char)
489
+
490
+ # 检查字符是否在任何一个日语范围内
491
+ for start, end in japanese_ranges:
492
+ if start <= char_code <= end:
493
+ return True
494
+
495
+ return False
496
+
497
+
498
+ rep_map = {
499
+ ":": ",",
500
+ ";": ",",
501
+ ",": ",",
502
+ "。": ".",
503
+ "!": "!",
504
+ "?": "?",
505
+ "\n": ".",
506
+ "·": ",",
507
+ "、": ",",
508
+ "...": "…",
509
+ }
510
+
511
+
512
+ def replace_punctuation(text):
513
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
514
+
515
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
516
+
517
+ replaced_text = re.sub(
518
+ r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF"
519
+ + "".join(punctuation)
520
+ + r"]+",
521
+ "",
522
+ replaced_text,
523
+ )
524
+
525
+ return replaced_text
526
 
527
 
528
  def text_normalize(text):
529
+ res = unicodedata.normalize("NFKC", text)
530
+ res = japanese_convert_numbers_to_words(res)
531
+ # res = "".join([i for i in res if is_japanese_character(i)])
532
+ res = replace_punctuation(res)
533
+ return res
534
+
535
+
536
+ def distribute_phone(n_phone, n_word):
537
+ phones_per_word = [0] * n_word
538
+ for task in range(n_phone):
539
+ min_tasks = min(phones_per_word)
540
+ min_index = phones_per_word.index(min_tasks)
541
+ phones_per_word[min_index] += 1
542
+ return phones_per_word
543
 
544
 
545
  def g2p(norm_text):
546
+ tokenized = tokenizer.tokenize(norm_text)
547
+ phs = []
548
+ ph_groups = []
549
+ for t in tokenized:
550
+ if not t.startswith("#"):
551
+ ph_groups.append([t])
552
+ else:
553
+ ph_groups[-1].append(t.replace("#", ""))
554
+ word2ph = []
555
+ for group in ph_groups:
556
+ phonemes = kata2phoneme(text2kata("".join(group)))
557
+ # phonemes = [i for i in phonemes if i in symbols]
558
+ for i in phonemes:
559
+ assert i in symbols, (i, group, norm_text, tokenized)
560
+ phone_len = len(phonemes)
561
+ word_len = len(group)
562
+
563
+ aaa = distribute_phone(phone_len, word_len)
564
+ word2ph += aaa
565
 
566
+ phs += phonemes
567
+ phones = ["_"] + phs + ["_"]
568
  tones = [0 for i in phones]
569
+ word2ph = [1] + word2ph + [1]
570
  return phones, tones, word2ph
571
 
572
 
573
+ if __name__ == "__main__":
574
+ from config import ABS_PATH
575
+
576
+ tokenizer = AutoTokenizer.from_pretrained(ABS_PATH + "/bert_vits2/bert/bert-base-japanese-v3")
577
+ text = "hello,こんにちは、世界!……"
578
+ from bert_vits2.text.japanese_bert import get_bert_feature
579
+
580
+ text = text_normalize(text)
581
+ print(text)
582
+ phones, tones, word2ph = g2p(text)
583
+ bert = get_bert_feature(text, word2ph)
584
+
585
+ print(phones, tones, word2ph, bert.shape)
bert_vits2/text/japanese_bert.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
5
+
6
+ import config
7
+ from logger import logger
8
+ from utils.download import download_and_verify
9
+ from config import DEVICE as device
10
+
11
+ URLS = [
12
+ "https://huggingface.co/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin",
13
+ ]
14
+ TARGET_PATH = os.path.join(config.ABS_PATH, "bert_vits2/bert/bert-base-japanese-v3/pytorch_model.bin")
15
+ EXPECTED_MD5 = None
16
+
17
+ if not os.path.exists(TARGET_PATH):
18
+ success, message = download_and_verify(URLS, TARGET_PATH, EXPECTED_MD5)
19
+
20
+ try:
21
+ logger.info("Loading bert-base-japanese-v3...")
22
+ tokenizer = AutoTokenizer.from_pretrained(config.ABS_PATH + "/bert_vits2/bert/bert-base-japanese-v3")
23
+ model = AutoModelForMaskedLM.from_pretrained(config.ABS_PATH + "/bert_vits2/bert/bert-base-japanese-v3").to(
24
+ device)
25
+ logger.info("Loading finished.")
26
+ except Exception as e:
27
+ logger.error(e)
28
+ logger.error(f"Please download pytorch_model.bin from cl-tohoku/bert-base-japanese-v3.")
29
+
30
+
31
+ def get_bert_feature(text, word2ph, device=config.DEVICE):
32
+ with torch.no_grad():
33
+ inputs = tokenizer(text, return_tensors="pt")
34
+ for i in inputs:
35
+ inputs[i] = inputs[i].to(device)
36
+ res = model(**inputs, output_hidden_states=True)
37
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
38
+ assert inputs["input_ids"].shape[-1] == len(word2ph)
39
+ word2phone = word2ph
40
+ phone_level_feature = []
41
+ for i in range(len(word2phone)):
42
+ repeat_feature = res[i].repeat(word2phone[i], 1)
43
+ phone_level_feature.append(repeat_feature)
44
+
45
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
46
+
47
+ return phone_level_feature.T
bert_vits2/text/symbols.py CHANGED
@@ -1,52 +1,198 @@
1
- punctuation = ['!', '?', '', ",", ".", "'", '-']
2
  pu_symbols = punctuation + ["SP", "UNK"]
3
- pad = '_'
4
 
5
  # chinese
6
- zh_symbols = ['E', 'En', 'a', 'ai', 'an', 'ang', 'ao', 'b', 'c', 'ch', 'd', 'e', 'ei', 'en', 'eng', 'er', 'f', 'g', 'h',
7
- 'i', 'i0', 'ia', 'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'ir', 'iu', 'j', 'k', 'l', 'm', 'n',
8
- 'o',
9
- 'ong',
10
- 'ou', 'p', 'q', 'r', 's', 'sh', 't', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've',
11
- 'vn',
12
- 'w', 'x', 'y', 'z', 'zh',
13
- "AA", "EE", "OO"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  num_zh_tones = 6
15
 
16
  # japanese
17
- ja_symbols = ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd', 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 'i', 'j', 'k', 'ky',
18
- 'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh', 't', 'ts', 'u', 'V', 'w', 'y', 'z']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  num_ja_tones = 1
20
 
21
  # English
22
- en_symbols = ['aa', 'ae', 'ah', 'ao', 'aw', 'ay', 'b', 'ch', 'd', 'dh', 'eh', 'er', 'ey', 'f', 'g', 'hh', 'ih', 'iy',
23
- 'jh', 'k', 'l', 'm', 'n', 'ng', 'ow', 'oy', 'p', 'r', 's',
24
- 'sh', 't', 'th', 'uh', 'uw', 'V', 'w', 'y', 'z', 'zh']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  num_en_tones = 4
26
 
27
- # combine all symbols
28
  normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
29
  symbols = [pad] + normal_symbols + pu_symbols
30
  sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
31
 
 
 
 
 
 
32
  # combine all tones
33
  num_tones = num_zh_tones + num_ja_tones + num_en_tones
34
 
35
  # language maps
36
- language_id_map = {
37
- 'ZH': 0,
38
- "JA": 1,
39
- "EN": 2
40
- }
41
  num_languages = len(language_id_map.keys())
42
 
43
  language_tone_start_map = {
44
- 'ZH': 0,
45
- "JA": num_zh_tones,
46
- "EN": num_zh_tones + num_ja_tones
47
  }
48
 
49
- if __name__ == '__main__':
50
- a = set(zh_symbols)
51
- b = set(en_symbols)
52
- print(sorted(a & b))
 
 
 
 
 
1
+ punctuation = ["!", "?", "", ",", ".", "'", "-"]
2
  pu_symbols = punctuation + ["SP", "UNK"]
3
+ pad = "_"
4
 
5
  # chinese
6
+ zh_symbols = [
7
+ "E",
8
+ "En",
9
+ "a",
10
+ "ai",
11
+ "an",
12
+ "ang",
13
+ "ao",
14
+ "b",
15
+ "c",
16
+ "ch",
17
+ "d",
18
+ "e",
19
+ "ei",
20
+ "en",
21
+ "eng",
22
+ "er",
23
+ "f",
24
+ "g",
25
+ "h",
26
+ "i",
27
+ "i0",
28
+ "ia",
29
+ "ian",
30
+ "iang",
31
+ "iao",
32
+ "ie",
33
+ "in",
34
+ "ing",
35
+ "iong",
36
+ "ir",
37
+ "iu",
38
+ "j",
39
+ "k",
40
+ "l",
41
+ "m",
42
+ "n",
43
+ "o",
44
+ "ong",
45
+ "ou",
46
+ "p",
47
+ "q",
48
+ "r",
49
+ "s",
50
+ "sh",
51
+ "t",
52
+ "u",
53
+ "ua",
54
+ "uai",
55
+ "uan",
56
+ "uang",
57
+ "ui",
58
+ "un",
59
+ "uo",
60
+ "v",
61
+ "van",
62
+ "ve",
63
+ "vn",
64
+ "w",
65
+ "x",
66
+ "y",
67
+ "z",
68
+ "zh",
69
+ "AA",
70
+ "EE",
71
+ "OO",
72
+ ]
73
  num_zh_tones = 6
74
 
75
  # japanese
76
+ ja_symbols_legacy = ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd', 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 'i', 'j',
77
+ 'k', 'ky',
78
+ 'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh', 't', 'ts', 'u', 'V', 'w', 'y', 'z']
79
+ ja_symbols = [
80
+ "N",
81
+ "a",
82
+ "a:",
83
+ "b",
84
+ "by",
85
+ "ch",
86
+ "d",
87
+ "dy",
88
+ "e",
89
+ "e:",
90
+ "f",
91
+ "g",
92
+ "gy",
93
+ "h",
94
+ "hy",
95
+ "i",
96
+ "i:",
97
+ "j",
98
+ "k",
99
+ "ky",
100
+ "m",
101
+ "my",
102
+ "n",
103
+ "ny",
104
+ "o",
105
+ "o:",
106
+ "p",
107
+ "py",
108
+ "q",
109
+ "r",
110
+ "ry",
111
+ "s",
112
+ "sh",
113
+ "t",
114
+ "ts",
115
+ "ty",
116
+ "u",
117
+ "u:",
118
+ "w",
119
+ "y",
120
+ "z",
121
+ "zy",
122
+ ]
123
  num_ja_tones = 1
124
 
125
  # English
126
+ en_symbols = [
127
+ "aa",
128
+ "ae",
129
+ "ah",
130
+ "ao",
131
+ "aw",
132
+ "ay",
133
+ "b",
134
+ "ch",
135
+ "d",
136
+ "dh",
137
+ "eh",
138
+ "er",
139
+ "ey",
140
+ "f",
141
+ "g",
142
+ "hh",
143
+ "ih",
144
+ "iy",
145
+ "jh",
146
+ "k",
147
+ "l",
148
+ "m",
149
+ "n",
150
+ "ng",
151
+ "ow",
152
+ "oy",
153
+ "p",
154
+ "r",
155
+ "s",
156
+ "sh",
157
+ "t",
158
+ "th",
159
+ "uh",
160
+ "uw",
161
+ "V",
162
+ "w",
163
+ "y",
164
+ "z",
165
+ "zh",
166
+ ]
167
  num_en_tones = 4
168
 
 
169
  normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
170
  symbols = [pad] + normal_symbols + pu_symbols
171
  sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
172
 
173
+ # legacy
174
+ normal_symbols_legacy = sorted(set(zh_symbols + ja_symbols_legacy + en_symbols))
175
+ symbols_legacy = [pad] + normal_symbols_legacy + pu_symbols
176
+ sil_phonemes_ids_legacy = [symbols_legacy.index(i) for i in pu_symbols]
177
+
178
  # combine all tones
179
  num_tones = num_zh_tones + num_ja_tones + num_en_tones
180
 
181
  # language maps
182
+ language_id_map = {"zh": 0, "ja": 1, "en": 2}
 
 
 
 
183
  num_languages = len(language_id_map.keys())
184
 
185
  language_tone_start_map = {
186
+ "zh": 0,
187
+ "ja": num_zh_tones,
188
+ "en": num_zh_tones + num_ja_tones,
189
  }
190
 
191
+ if __name__ == "__main__":
192
+ zh = set(zh_symbols)
193
+ en = set(en_symbols)
194
+ ja = set(ja_symbols)
195
+ print(zh)
196
+ print(en)
197
+ print(ja)
198
+ print(sorted(zh & en))
bert_vits2/text/tone_sandhi.py CHANGED
@@ -19,51 +19,442 @@ from pypinyin import lazy_pinyin
19
  from pypinyin import Style
20
 
21
 
22
- class ToneSandhi():
23
  def __init__(self):
24
  self.must_neural_tone_words = {
25
- '麻烦', '麻利', '鸳鸯', '高粱', '骨头', '骆驼', '马虎', '首饰', '馒头', '馄饨', '风筝',
26
- '难为', '队伍', '阔气', '闺女', '门道', '锄头', '铺盖', '铃铛', '铁匠', '钥匙', '里脊',
27
- '里头', '部分', '那么', '道士', '造化', '迷糊', '连累', '这么', '这个', '运气', '过去',
28
- '软和', '转悠', '踏实', '跳蚤', '跟头', '趔趄', '财主', '豆腐', '讲究', '记性', '记号',
29
- '认识', '规矩', '见识', '裁缝', '补丁', '衣裳', '衣服', '衙门', '街坊', '行李', '行当',
30
- '蛤蟆', '蘑菇', '薄荷', '葫芦', '葡萄', '萝卜', '荸荠', '苗条', '苗头', '苍蝇', '芝麻',
31
- '舒服', '舒坦', '舌头', '自在', '膏药', '脾气', '脑袋', '脊梁', '能耐', '胳膊', '胭脂',
32
- '胡萝', '胡琴', '胡同', '聪明', '耽误', '耽搁', '耷拉', '耳朵', '老爷', '老实', '老婆',
33
- '老头', '老太', '翻腾', '罗嗦', '罐头', '编辑', '结实', '红火', '累赘', '糨糊', '糊涂',
34
- '精神', '粮食', '簸箕', '篱笆', '算计', '算盘', '答应', '笤帚', '笑语', '笑话', '窟窿',
35
- '窝囊', '窗户', '稳当', '稀罕', '称呼', '秧歌', '秀气', '秀才', '福气', '祖宗', '砚台',
36
- '码头', '石榴', '石头', '石匠', '知识', '眼睛', '眯缝', '眨巴', '眉毛', '相声', '盘算',
37
- '白净', '痢疾', '痛快', '疟疾', '疙瘩', '疏忽', '畜生', '生意', '甘蔗', '琵琶', '琢磨',
38
- '琉璃', '玻璃', '玫瑰', '玄乎', '狐狸', '状元', '特务', '牲口', '牙碜', '牌楼', '爽快',
39
- '爱人', '热闹', '烧饼', '烟筒', '烂糊', '点心', '炊帚', '灯笼', '火候', '漂亮', '滑溜',
40
- '溜达', '温和', '清楚', '消息', '浪头', '活泼', '比方', '正经', '欺负', '模糊', '槟榔',
41
- '棺材', '棒槌', '棉花', '核桃', '栅栏', '柴火', '架势', '枕头', '枇杷', '机灵', '本事',
42
- '木头', '木匠', '朋友', '月饼', '月亮', '暖和', '明白', '时候', '新鲜', '故事', '收拾',
43
- '收成', '提防', '挖苦', '挑剔', '指甲', '指头', '拾掇', '拳头', '拨弄', '招牌', '招呼',
44
- '抬举', '护士', '折腾', '扫帚', '打量', '打算', '打点', '打扮', '打听', '打发', '扎实',
45
- '扁担', '戒指', '懒得', '意识', '意思', '情形', '悟性', '怪物', '思量', '怎么', '念头',
46
- '念叨', '快活', '忙活', '志气', '心思', '得罪', '张罗', '弟兄', '开通', '应酬', '庄稼',
47
- '干事', '帮手', '帐篷', '希罕', '师父', '师傅', '巴结', '巴掌', '差事', '工夫', '岁数',
48
- '屁股', '尾巴', '少爷', '小气', '小伙', '将就', '对头', '对付', '寡妇', '家伙', '客气',
49
- '实在', '官司', '学问', '学生', '字号', '嫁妆', '媳妇', '媒人', '婆家', '娘家', '委屈',
50
- '姑娘', '姐夫', '妯娌', '妥当', '妖精', '奴才', '女婿', '头发', '太阳', '大爷', '大方',
51
- '大意', '大夫', '多少', '多么', '外甥', '壮实', '地道', '地方', '在乎', '困难', '嘴巴',
52
- '嘱咐', '嘟囔', '嘀咕', '喜欢', '喇嘛', '喇叭', '商量', '唾沫', '哑巴', '哈欠', '哆嗦',
53
- '咳嗽', '和尚', '告诉', '告示', '含糊', '吓唬', '后头', '名字', '名堂', '合同', '吆喝',
54
- '叫唤', '口袋', '厚道', '厉害', '千斤', '包袱', '包涵', '匀称', '勤快', '动静', '动弹',
55
- '功夫', '力气', '前头', '刺猬', '刺激', '别扭', '利落', '利索', '利害', '分析', '出息',
56
- '凑合', '凉快', '冷战', '冤枉', '冒失', '养活', '关系', '先生', '兄弟', '便宜', '使唤',
57
- '佩服', '作坊', '体面', '位置', '似的', '伙计', '休息', '什么', '人家', '亲戚', '亲家',
58
- '交情', '云彩', '事情', '买卖', '主意', '丫头', '丧气', '两口', '东西', '东家', '世故',
59
- '不由', '不在', '下水', '下巴', '上头', '上司', '丈夫', '丈人', '一辈', '那个', '菩萨',
60
- '父亲', '母亲', '咕噜', '邋遢', '费用', '冤家', '甜头', '介绍', '荒唐', '大人', '泥鳅',
61
- '幸福', '熟悉', '计划', '扑腾', '蜡烛', '姥爷', '照顾', '喉咙', '吉他', '弄堂', '蚂蚱',
62
- '凤凰', '拖沓', '寒碜', '糟蹋', '倒腾', '报复', '逻辑', '盘缠', '喽啰', '牢骚', '咖喱',
63
- '扫把', '惦记'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  }
65
  self.must_not_neural_tone_words = {
66
- "男子", "女子", "分子", "原子", "量子", "莲子", "石子", "瓜子", "电子", "人人", "虎虎"
 
 
 
 
 
 
 
 
 
 
67
  }
68
  self.punc = ":,;。?!“”‘’':,;.?!"
69
 
@@ -72,14 +463,15 @@ class ToneSandhi():
72
  # word: "家里"
73
  # pos: "s"
74
  # finals: ['ia1', 'i3']
75
- def _neural_sandhi(self, word: str, pos: str,
76
- finals: List[str]) -> List[str]:
77
-
78
  # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
79
  for j, item in enumerate(word):
80
- if j - 1 >= 0 and item == word[j - 1] and pos[0] in {
81
- "n", "v", "a"
82
- } and word not in self.must_not_neural_tone_words:
 
 
 
83
  finals[j] = finals[j][:-1] + "5"
84
  ge_idx = word.find("个")
85
  if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
@@ -89,9 +481,12 @@ class ToneSandhi():
89
  # e.g. 走了, 看着, 去过
90
  # elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
91
  # finals[-1] = finals[-1][:-1] + "5"
92
- elif len(word) > 1 and word[-1] in "们子" and pos in {
93
- "r", "n"
94
- } and word not in self.must_not_neural_tone_words:
 
 
 
95
  finals[-1] = finals[-1][:-1] + "5"
96
  # e.g. 桌上, 地下, 家里
97
  elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
@@ -100,21 +495,26 @@ class ToneSandhi():
100
  elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
101
  finals[-1] = finals[-1][:-1] + "5"
102
  # 个做量词
103
- elif (ge_idx >= 1 and
104
- (word[ge_idx - 1].isnumeric() or
105
- word[ge_idx - 1] in "几有两半多各整每做是")) or word == '个':
 
106
  finals[ge_idx] = finals[ge_idx][:-1] + "5"
107
  else:
108
- if word in self.must_neural_tone_words or word[
109
- -2:] in self.must_neural_tone_words:
 
 
110
  finals[-1] = finals[-1][:-1] + "5"
111
 
112
  word_list = self._split_word(word)
113
- finals_list = [finals[:len(word_list[0])], finals[len(word_list[0]):]]
114
  for i, word in enumerate(word_list):
115
  # conventional neural in Chinese
116
- if word in self.must_neural_tone_words or word[
117
- -2:] in self.must_neural_tone_words:
 
 
118
  finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
119
  finals = sum(finals_list, [])
120
  return finals
@@ -126,17 +526,17 @@ class ToneSandhi():
126
  else:
127
  for i, char in enumerate(word):
128
  # "不" before tone4 should be bu2, e.g. 不怕
129
- if char == "不" and i + 1 < len(word) and finals[i +
130
- 1][-1] == "4":
131
  finals[i] = finals[i][:-1] + "2"
132
  return finals
133
 
134
  def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
135
  # "一" in number sequences, e.g. 一零零, 二一零
136
  if word.find("一") != -1 and all(
137
- [item.isnumeric() for item in word if item != "一"]):
 
138
  return finals
139
- # "一" between reduplication words shold be yi5, e.g. 看一看
140
  elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
141
  finals[1] = finals[1][:-1] + "5"
142
  # when "一" is ordinal word, it should be yi1
@@ -161,10 +561,10 @@ class ToneSandhi():
161
  first_subword = word_list[0]
162
  first_begin_idx = word.find(first_subword)
163
  if first_begin_idx == 0:
164
- second_subword = word[len(first_subword):]
165
  new_word_list = [first_subword, second_subword]
166
  else:
167
- second_subword = word[:-len(first_subword)]
168
  new_word_list = [second_subword, first_subword]
169
  return new_word_list
170
 
@@ -182,18 +582,19 @@ class ToneSandhi():
182
  elif len(word_list[0]) == 1:
183
  finals[1] = finals[1][:-1] + "2"
184
  else:
185
- finals_list = [
186
- finals[:len(word_list[0])], finals[len(word_list[0]):]
187
- ]
188
  if len(finals_list) == 2:
189
  for i, sub in enumerate(finals_list):
190
  # e.g. 所有/人
191
  if self._all_tone_three(sub) and len(sub) == 2:
192
  finals_list[i][0] = finals_list[i][0][:-1] + "2"
193
  # e.g. 好/喜欢
194
- elif i == 1 and not self._all_tone_three(sub) and finals_list[i][0][-1] == "3" and \
195
- finals_list[0][-1][-1] == "3":
196
-
 
 
 
197
  finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
198
  finals = sum(finals_list, [])
199
  # split idiom into two words who's length is 2
@@ -222,7 +623,7 @@ class ToneSandhi():
222
  new_seg.append((word, pos))
223
  last_word = word[:]
224
  if last_word == "不":
225
- new_seg.append((last_word, 'd'))
226
  last_word = ""
227
  return new_seg
228
 
@@ -236,12 +637,21 @@ class ToneSandhi():
236
  new_seg = []
237
  # function 1
238
  for i, (word, pos) in enumerate(seg):
239
- if i - 1 >= 0 and word == "一" and i + 1 < len(seg) and seg[i - 1][
240
- 0] == seg[i + 1][0] and seg[i - 1][1] == "v":
 
 
 
 
 
241
  new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
242
  else:
243
- if i - 2 >= 0 and seg[i - 1][0] == "一" and seg[i - 2][
244
- 0] == word and pos == "v":
 
 
 
 
245
  continue
246
  else:
247
  new_seg.append([word, pos])
@@ -257,22 +667,27 @@ class ToneSandhi():
257
 
258
  # the first and the second words are all_tone_three
259
  def _merge_continuous_three_tones(
260
- self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
 
261
  new_seg = []
262
  sub_finals_list = [
263
- lazy_pinyin(
264
- word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
265
  for (word, pos) in seg
266
  ]
267
  assert len(sub_finals_list) == len(seg)
268
  merge_last = [False] * len(seg)
269
  for i, (word, pos) in enumerate(seg):
270
- if i - 1 >= 0 and self._all_tone_three(
271
- sub_finals_list[i - 1]) and self._all_tone_three(
272
- sub_finals_list[i]) and not merge_last[i - 1]:
 
 
 
273
  # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
274
- if not self._is_reduplication(seg[i - 1][0]) and len(
275
- seg[i - 1][0]) + len(seg[i][0]) <= 3:
 
 
276
  new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
277
  merge_last[i] = True
278
  else:
@@ -287,21 +702,27 @@ class ToneSandhi():
287
 
288
  # the last char of first word and the first char of second word is tone_three
289
  def _merge_continuous_three_tones_2(
290
- self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
 
291
  new_seg = []
292
  sub_finals_list = [
293
- lazy_pinyin(
294
- word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
295
  for (word, pos) in seg
296
  ]
297
  assert len(sub_finals_list) == len(seg)
298
  merge_last = [False] * len(seg)
299
  for i, (word, pos) in enumerate(seg):
300
- if i - 1 >= 0 and sub_finals_list[i - 1][-1][-1] == "3" and sub_finals_list[i][0][-1] == "3" and not \
301
- merge_last[i - 1]:
 
 
 
 
302
  # if the last word is reduplication, not merge, because reduplication need to be _neural_sandhi
303
- if not self._is_reduplication(seg[i - 1][0]) and len(
304
- seg[i - 1][0]) + len(seg[i][0]) <= 3:
 
 
305
  new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
306
  merge_last[i] = True
307
  else:
@@ -319,8 +740,7 @@ class ToneSandhi():
319
  new_seg.append([word, pos])
320
  return new_seg
321
 
322
- def _merge_reduplication(
323
- self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
324
  new_seg = []
325
  for i, (word, pos) in enumerate(seg):
326
  if new_seg and word == new_seg[-1][0]:
@@ -329,8 +749,7 @@ class ToneSandhi():
329
  new_seg.append([word, pos])
330
  return new_seg
331
 
332
- def pre_merge_for_modify(
333
- self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
334
  seg = self._merge_bu(seg)
335
  try:
336
  seg = self._merge_yi(seg)
@@ -342,8 +761,7 @@ class ToneSandhi():
342
  seg = self._merge_er(seg)
343
  return seg
344
 
345
- def modified_tone(self, word: str, pos: str,
346
- finals: List[str]) -> List[str]:
347
  finals = self._bu_sandhi(word, finals)
348
  finals = self._yi_sandhi(word, finals)
349
  finals = self._neural_sandhi(word, pos, finals)
 
19
  from pypinyin import Style
20
 
21
 
22
+ class ToneSandhi:
23
  def __init__(self):
24
  self.must_neural_tone_words = {
25
+ "麻烦",
26
+ "麻利",
27
+ "鸳鸯",
28
+ "高粱",
29
+ "骨头",
30
+ "骆驼",
31
+ "马虎",
32
+ "首饰",
33
+ "馒头",
34
+ "���饨",
35
+ "风筝",
36
+ "难为",
37
+ "队伍",
38
+ "阔气",
39
+ "闺女",
40
+ "门道",
41
+ "锄头",
42
+ "铺盖",
43
+ "铃铛",
44
+ "铁匠",
45
+ "钥匙",
46
+ "里脊",
47
+ "里头",
48
+ "部分",
49
+ "那么",
50
+ "道士",
51
+ "造化",
52
+ "迷糊",
53
+ "连累",
54
+ "这么",
55
+ "这个",
56
+ "运气",
57
+ "过去",
58
+ "软和",
59
+ "转悠",
60
+ "踏实",
61
+ "跳蚤",
62
+ "跟头",
63
+ "趔趄",
64
+ "财主",
65
+ "豆腐",
66
+ "讲究",
67
+ "记性",
68
+ "记号",
69
+ "认识",
70
+ "规矩",
71
+ "见识",
72
+ "裁缝",
73
+ "补丁",
74
+ "衣裳",
75
+ "衣服",
76
+ "衙门",
77
+ "街坊",
78
+ "行李",
79
+ "行当",
80
+ "蛤蟆",
81
+ "蘑菇",
82
+ "薄荷",
83
+ "葫芦",
84
+ "葡萄",
85
+ "萝卜",
86
+ "荸荠",
87
+ "苗条",
88
+ "苗头",
89
+ "苍蝇",
90
+ "芝麻",
91
+ "舒服",
92
+ "舒坦",
93
+ "舌头",
94
+ "自在",
95
+ "膏药",
96
+ "脾气",
97
+ "脑袋",
98
+ "脊梁",
99
+ "能耐",
100
+ "胳膊",
101
+ "胭脂",
102
+ "胡萝",
103
+ "胡琴",
104
+ "胡同",
105
+ "聪明",
106
+ "耽误",
107
+ "耽搁",
108
+ "耷拉",
109
+ "耳朵",
110
+ "老爷",
111
+ "老实",
112
+ "老婆",
113
+ "老头",
114
+ "老太",
115
+ "翻腾",
116
+ "罗嗦",
117
+ "罐头",
118
+ "编辑",
119
+ "结实",
120
+ "红火",
121
+ "累赘",
122
+ "糨糊",
123
+ "糊涂",
124
+ "精神",
125
+ "粮食",
126
+ "簸箕",
127
+ "篱笆",
128
+ "算计",
129
+ "算盘",
130
+ "答应",
131
+ "笤帚",
132
+ "笑语",
133
+ "笑话",
134
+ "窟窿",
135
+ "窝囊",
136
+ "窗户",
137
+ "稳当",
138
+ "稀罕",
139
+ "称呼",
140
+ "秧歌",
141
+ "秀气",
142
+ "秀才",
143
+ "福气",
144
+ "祖宗",
145
+ "砚台",
146
+ "码头",
147
+ "石榴",
148
+ "石头",
149
+ "石匠",
150
+ "知识",
151
+ "眼睛",
152
+ "眯缝",
153
+ "眨巴",
154
+ "眉毛",
155
+ "相声",
156
+ "盘算",
157
+ "白净",
158
+ "痢疾",
159
+ "痛快",
160
+ "疟疾",
161
+ "疙瘩",
162
+ "疏忽",
163
+ "畜生",
164
+ "生意",
165
+ "甘蔗",
166
+ "琵琶",
167
+ "琢磨",
168
+ "琉璃",
169
+ "玻璃",
170
+ "玫瑰",
171
+ "玄乎",
172
+ "狐狸",
173
+ "状元",
174
+ "特务",
175
+ "牲口",
176
+ "牙碜",
177
+ "牌楼",
178
+ "爽快",
179
+ "爱人",
180
+ "热闹",
181
+ "烧饼",
182
+ "烟筒",
183
+ "烂糊",
184
+ "点心",
185
+ "炊帚",
186
+ "灯笼",
187
+ "火候",
188
+ "漂亮",
189
+ "滑溜",
190
+ "溜达",
191
+ "温和",
192
+ "清楚",
193
+ "消息",
194
+ "浪头",
195
+ "活泼",
196
+ "比方",
197
+ "正经",
198
+ "欺负",
199
+ "模糊",
200
+ "槟榔",
201
+ "棺材",
202
+ "棒槌",
203
+ "棉花",
204
+ "核桃",
205
+ "栅栏",
206
+ "柴火",
207
+ "架势",
208
+ "枕头",
209
+ "枇杷",
210
+ "机灵",
211
+ "本事",
212
+ "木头",
213
+ "木匠",
214
+ "朋友",
215
+ "月饼",
216
+ "月亮",
217
+ "暖和",
218
+ "明白",
219
+ "时候",
220
+ "新鲜",
221
+ "故事",
222
+ "收拾",
223
+ "收成",
224
+ "提防",
225
+ "挖苦",
226
+ "挑剔",
227
+ "指甲",
228
+ "指头",
229
+ "拾掇",
230
+ "拳头",
231
+ "拨弄",
232
+ "招牌",
233
+ "招呼",
234
+ "抬举",
235
+ "护士",
236
+ "折腾",
237
+ "扫帚",
238
+ "打量",
239
+ "打算",
240
+ "打点",
241
+ "打扮",
242
+ "打听",
243
+ "打发",
244
+ "扎实",
245
+ "扁担",
246
+ "戒指",
247
+ "懒得",
248
+ "意识",
249
+ "意思",
250
+ "情形",
251
+ "悟性",
252
+ "怪物",
253
+ "思量",
254
+ "怎么",
255
+ "念头",
256
+ "念叨",
257
+ "快活",
258
+ "忙活",
259
+ "志气",
260
+ "心思",
261
+ "得罪",
262
+ "张罗",
263
+ "弟兄",
264
+ "开通",
265
+ "应酬",
266
+ "庄稼",
267
+ "干事",
268
+ "帮手",
269
+ "帐篷",
270
+ "希罕",
271
+ "师父",
272
+ "师傅",
273
+ "巴结",
274
+ "巴掌",
275
+ "差事",
276
+ "工夫",
277
+ "岁数",
278
+ "屁股",
279
+ "尾巴",
280
+ "少爷",
281
+ "小气",
282
+ "小伙",
283
+ "将就",
284
+ "对头",
285
+ "对付",
286
+ "寡妇",
287
+ "家伙",
288
+ "客气",
289
+ "实在",
290
+ "官司",
291
+ "学问",
292
+ "学生",
293
+ "字号",
294
+ "嫁妆",
295
+ "媳妇",
296
+ "媒人",
297
+ "婆家",
298
+ "娘家",
299
+ "委屈",
300
+ "姑娘",
301
+ "姐夫",
302
+ "妯娌",
303
+ "妥当",
304
+ "妖精",
305
+ "奴才",
306
+ "女婿",
307
+ "头发",
308
+ "太阳",
309
+ "大爷",
310
+ "大方",
311
+ "大意",
312
+ "大夫",
313
+ "多少",
314
+ "多么",
315
+ "外甥",
316
+ "壮实",
317
+ "地道",
318
+ "地方",
319
+ "在乎",
320
+ "困难",
321
+ "嘴巴",
322
+ "嘱咐",
323
+ "嘟囔",
324
+ "嘀咕",
325
+ "喜欢",
326
+ "喇嘛",
327
+ "喇叭",
328
+ "商量",
329
+ "唾沫",
330
+ "哑巴",
331
+ "哈欠",
332
+ "哆嗦",
333
+ "咳嗽",
334
+ "和尚",
335
+ "告诉",
336
+ "告示",
337
+ "含糊",
338
+ "吓唬",
339
+ "后头",
340
+ "名字",
341
+ "名堂",
342
+ "合同",
343
+ "吆喝",
344
+ "叫唤",
345
+ "口袋",
346
+ "厚道",
347
+ "厉害",
348
+ "千斤",
349
+ "包袱",
350
+ "包涵",
351
+ "匀称",
352
+ "勤快",
353
+ "动静",
354
+ "动弹",
355
+ "功夫",
356
+ "力气",
357
+ "前头",
358
+ "刺猬",
359
+ "刺激",
360
+ "别扭",
361
+ "利落",
362
+ "利索",
363
+ "利害",
364
+ "分析",
365
+ "出息",
366
+ "凑合",
367
+ "凉快",
368
+ "冷战",
369
+ "冤枉",
370
+ "冒失",
371
+ "养活",
372
+ "关系",
373
+ "先生",
374
+ "兄弟",
375
+ "便宜",
376
+ "使唤",
377
+ "佩服",
378
+ "作坊",
379
+ "体面",
380
+ "位置",
381
+ "似的",
382
+ "伙计",
383
+ "休息",
384
+ "什么",
385
+ "人家",
386
+ "亲戚",
387
+ "亲家",
388
+ "交情",
389
+ "云彩",
390
+ "事情",
391
+ "买卖",
392
+ "主意",
393
+ "丫头",
394
+ "丧气",
395
+ "两口",
396
+ "东西",
397
+ "东家",
398
+ "世故",
399
+ "不由",
400
+ "不在",
401
+ "下水",
402
+ "下巴",
403
+ "上头",
404
+ "上司",
405
+ "丈夫",
406
+ "丈人",
407
+ "一辈",
408
+ "那个",
409
+ "菩萨",
410
+ "父亲",
411
+ "母亲",
412
+ "咕噜",
413
+ "邋遢",
414
+ "费用",
415
+ "冤家",
416
+ "甜头",
417
+ "介绍",
418
+ "荒唐",
419
+ "大人",
420
+ "泥鳅",
421
+ "幸福",
422
+ "熟悉",
423
+ "计划",
424
+ "扑腾",
425
+ "蜡烛",
426
+ "姥爷",
427
+ "照顾",
428
+ "喉咙",
429
+ "吉他",
430
+ "弄堂",
431
+ "蚂蚱",
432
+ "凤凰",
433
+ "拖沓",
434
+ "寒碜",
435
+ "糟蹋",
436
+ "倒腾",
437
+ "报复",
438
+ "逻辑",
439
+ "盘缠",
440
+ "喽啰",
441
+ "牢骚",
442
+ "咖喱",
443
+ "扫把",
444
+ "惦记",
445
  }
446
  self.must_not_neural_tone_words = {
447
+ "男子",
448
+ "女子",
449
+ "分子",
450
+ "原子",
451
+ "量子",
452
+ "莲子",
453
+ "石子",
454
+ "瓜子",
455
+ "电子",
456
+ "人人",
457
+ "虎虎",
458
  }
459
  self.punc = ":,;。?!“”‘’':,;.?!"
460
 
 
463
  # word: "家里"
464
  # pos: "s"
465
  # finals: ['ia1', 'i3']
466
+ def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
 
 
467
  # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
468
  for j, item in enumerate(word):
469
+ if (
470
+ j - 1 >= 0
471
+ and item == word[j - 1]
472
+ and pos[0] in {"n", "v", "a"}
473
+ and word not in self.must_not_neural_tone_words
474
+ ):
475
  finals[j] = finals[j][:-1] + "5"
476
  ge_idx = word.find("个")
477
  if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
 
481
  # e.g. 走了, 看着, 去过
482
  # elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
483
  # finals[-1] = finals[-1][:-1] + "5"
484
+ elif (
485
+ len(word) > 1
486
+ and word[-1] in "们子"
487
+ and pos in {"r", "n"}
488
+ and word not in self.must_not_neural_tone_words
489
+ ):
490
  finals[-1] = finals[-1][:-1] + "5"
491
  # e.g. 桌上, 地下, 家里
492
  elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
 
495
  elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
496
  finals[-1] = finals[-1][:-1] + "5"
497
  # 个 used as a measure word
498
+ elif (
499
+ ge_idx >= 1
500
+ and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
501
+ ) or word == "个":
502
  finals[ge_idx] = finals[ge_idx][:-1] + "5"
503
  else:
504
+ if (
505
+ word in self.must_neural_tone_words
506
+ or word[-2:] in self.must_neural_tone_words
507
+ ):
508
  finals[-1] = finals[-1][:-1] + "5"
509
 
510
  word_list = self._split_word(word)
511
+ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
512
  for i, word in enumerate(word_list):
513
  # conventional neural in Chinese
514
+ if (
515
+ word in self.must_neural_tone_words
516
+ or word[-2:] in self.must_neural_tone_words
517
+ ):
518
  finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
519
  finals = sum(finals_list, [])
520
  return finals
 
526
  else:
527
  for i, char in enumerate(word):
528
  # "不" before tone4 should be bu2, e.g. 不怕
529
+ if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4":
 
530
  finals[i] = finals[i][:-1] + "2"
531
  return finals
532
 
533
  def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
534
  # "一" in number sequences, e.g. 一零零, 二一零
535
  if word.find("一") != -1 and all(
536
+ [item.isnumeric() for item in word if item != "一"]
537
+ ):
538
  return finals
539
+ # "一" between reduplication words should be yi5, e.g. 看一看
540
  elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
541
  finals[1] = finals[1][:-1] + "5"
542
  # when "一" is ordinal word, it should be yi1
 
561
  first_subword = word_list[0]
562
  first_begin_idx = word.find(first_subword)
563
  if first_begin_idx == 0:
564
+ second_subword = word[len(first_subword) :]
565
  new_word_list = [first_subword, second_subword]
566
  else:
567
+ second_subword = word[: -len(first_subword)]
568
  new_word_list = [second_subword, first_subword]
569
  return new_word_list
570
 
 
582
  elif len(word_list[0]) == 1:
583
  finals[1] = finals[1][:-1] + "2"
584
  else:
585
+ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
 
 
586
  if len(finals_list) == 2:
587
  for i, sub in enumerate(finals_list):
588
  # e.g. 所有/人
589
  if self._all_tone_three(sub) and len(sub) == 2:
590
  finals_list[i][0] = finals_list[i][0][:-1] + "2"
591
  # e.g. 好/喜欢
592
+ elif (
593
+ i == 1
594
+ and not self._all_tone_three(sub)
595
+ and finals_list[i][0][-1] == "3"
596
+ and finals_list[0][-1][-1] == "3"
597
+ ):
598
  finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
599
  finals = sum(finals_list, [])
600
  # split idiom into two words whose length is 2
 
623
  new_seg.append((word, pos))
624
  last_word = word[:]
625
  if last_word == "不":
626
+ new_seg.append((last_word, "d"))
627
  last_word = ""
628
  return new_seg
629
 
 
637
  new_seg = []
638
  # function 1
639
  for i, (word, pos) in enumerate(seg):
640
+ if (
641
+ i - 1 >= 0
642
+ and word == "一"
643
+ and i + 1 < len(seg)
644
+ and seg[i - 1][0] == seg[i + 1][0]
645
+ and seg[i - 1][1] == "v"
646
+ ):
647
  new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
648
  else:
649
+ if (
650
+ i - 2 >= 0
651
+ and seg[i - 1][0] == "一"
652
+ and seg[i - 2][0] == word
653
+ and pos == "v"
654
+ ):
655
  continue
656
  else:
657
  new_seg.append([word, pos])
 
667
 
668
  # the first and the second words are all_tone_three
669
  def _merge_continuous_three_tones(
670
+ self, seg: List[Tuple[str, str]]
671
+ ) -> List[Tuple[str, str]]:
672
  new_seg = []
673
  sub_finals_list = [
674
+ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
 
675
  for (word, pos) in seg
676
  ]
677
  assert len(sub_finals_list) == len(seg)
678
  merge_last = [False] * len(seg)
679
  for i, (word, pos) in enumerate(seg):
680
+ if (
681
+ i - 1 >= 0
682
+ and self._all_tone_three(sub_finals_list[i - 1])
683
+ and self._all_tone_three(sub_finals_list[i])
684
+ and not merge_last[i - 1]
685
+ ):
686
  # if the last word is a reduplication, do not merge, because reduplications need _neural_sandhi
687
+ if (
688
+ not self._is_reduplication(seg[i - 1][0])
689
+ and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
690
+ ):
691
  new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
692
  merge_last[i] = True
693
  else:
 
702
 
703
  # the last char of first word and the first char of second word is tone_three
704
  def _merge_continuous_three_tones_2(
705
+ self, seg: List[Tuple[str, str]]
706
+ ) -> List[Tuple[str, str]]:
707
  new_seg = []
708
  sub_finals_list = [
709
+ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
 
710
  for (word, pos) in seg
711
  ]
712
  assert len(sub_finals_list) == len(seg)
713
  merge_last = [False] * len(seg)
714
  for i, (word, pos) in enumerate(seg):
715
+ if (
716
+ i - 1 >= 0
717
+ and sub_finals_list[i - 1][-1][-1] == "3"
718
+ and sub_finals_list[i][0][-1] == "3"
719
+ and not merge_last[i - 1]
720
+ ):
721
  # if the last word is a reduplication, do not merge, because reduplications need _neural_sandhi
722
+ if (
723
+ not self._is_reduplication(seg[i - 1][0])
724
+ and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
725
+ ):
726
  new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
727
  merge_last[i] = True
728
  else:
 
740
  new_seg.append([word, pos])
741
  return new_seg
742
 
743
+ def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
 
744
  new_seg = []
745
  for i, (word, pos) in enumerate(seg):
746
  if new_seg and word == new_seg[-1][0]:
 
749
  new_seg.append([word, pos])
750
  return new_seg
751
 
752
+ def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
 
753
  seg = self._merge_bu(seg)
754
  try:
755
  seg = self._merge_yi(seg)
 
761
  seg = self._merge_er(seg)
762
  return seg
763
 
764
+ def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
 
765
  finals = self._bu_sandhi(word, finals)
766
  finals = self._yi_sandhi(word, finals)
767
  finals = self._neural_sandhi(word, pos, finals)
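
The sandhi methods above operate on pypinyin-style finals whose last character is a tone digit, and they apply the neutral tone by rewriting that digit to "5". A minimal standalone sketch of the lexicon branch, using a tiny excerpt of `must_neural_tone_words` (illustration only, not part of this commit):

```python
# Finals are written in pypinyin FINALS_TONE3 style, e.g. "ie4".
must_neural_tone_words = {"戒指", "庄稼", "点心"}  # tiny excerpt of the lexicon above

def apply_neutral_tone(word: str, finals: list) -> list:
    """Rewrite the last final's tone digit to '5' (neutral) when the word,
    or its last two characters, appear in the neutral-tone lexicon."""
    if word in must_neural_tone_words or word[-2:] in must_neural_tone_words:
        finals[-1] = finals[-1][:-1] + "5"
    return finals

print(apply_neutral_tone("戒指", ["ie4", "i3"]))  # ['ie4', 'i5']
```
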
bert_vits2/utils.py CHANGED
@@ -1,12 +1,6 @@
1
  import os
2
- import glob
3
  import sys
4
- import argparse
5
  import logging
6
- import json
7
- import subprocess
8
- import numpy as np
9
- from scipy.io.wavfile import read
10
  import torch
11
 
12
  MATPLOTLIB_FLAG = False
@@ -42,7 +36,13 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
42
  new_state_dict[k] = saved_state_dict[k]
43
  assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
44
  except:
45
- print("error, %s is not in the checkpoint" % k)
 
 
 
 
 
 
46
  new_state_dict[k] = v
47
  if hasattr(model, 'module'):
48
  model.module.load_state_dict(new_state_dict, strict=False)
@@ -52,241 +52,3 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
52
  logger.info("Loaded checkpoint '{}' (iteration {})".format(
53
  checkpoint_path, iteration))
54
  return model, optimizer, learning_rate, iteration
55
-
56
-
57
- def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
58
- logger.info("Saving model and optimizer state at iteration {} to {}".format(
59
- iteration, checkpoint_path))
60
- if hasattr(model, 'module'):
61
- state_dict = model.module.state_dict()
62
- else:
63
- state_dict = model.state_dict()
64
- torch.save({'model': state_dict,
65
- 'iteration': iteration,
66
- 'optimizer': optimizer.state_dict(),
67
- 'learning_rate': learning_rate}, checkpoint_path)
68
-
69
-
70
- def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
71
- for k, v in scalars.items():
72
- writer.add_scalar(k, v, global_step)
73
- for k, v in histograms.items():
74
- writer.add_histogram(k, v, global_step)
75
- for k, v in images.items():
76
- writer.add_image(k, v, global_step, dataformats='HWC')
77
- for k, v in audios.items():
78
- writer.add_audio(k, v, global_step, audio_sampling_rate)
79
-
80
-
81
- def latest_checkpoint_path(dir_path, regex="G_*.pth"):
82
- f_list = glob.glob(os.path.join(dir_path, regex))
83
- f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
84
- x = f_list[-1]
85
- print(x)
86
- return x
87
-
88
-
89
- def plot_spectrogram_to_numpy(spectrogram):
90
- global MATPLOTLIB_FLAG
91
- if not MATPLOTLIB_FLAG:
92
- import matplotlib
93
- matplotlib.use("Agg")
94
- MATPLOTLIB_FLAG = True
95
- mpl_logger = logging.getLogger('matplotlib')
96
- mpl_logger.setLevel(logging.WARNING)
97
- import matplotlib.pylab as plt
98
- import numpy as np
99
-
100
- fig, ax = plt.subplots(figsize=(10, 2))
101
- im = ax.imshow(spectrogram, aspect="auto", origin="lower",
102
- interpolation='none')
103
- plt.colorbar(im, ax=ax)
104
- plt.xlabel("Frames")
105
- plt.ylabel("Channels")
106
- plt.tight_layout()
107
-
108
- fig.canvas.draw()
109
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
110
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
111
- plt.close()
112
- return data
113
-
114
-
115
- def plot_alignment_to_numpy(alignment, info=None):
116
- global MATPLOTLIB_FLAG
117
- if not MATPLOTLIB_FLAG:
118
- import matplotlib
119
- matplotlib.use("Agg")
120
- MATPLOTLIB_FLAG = True
121
- mpl_logger = logging.getLogger('matplotlib')
122
- mpl_logger.setLevel(logging.WARNING)
123
- import matplotlib.pylab as plt
124
- import numpy as np
125
-
126
- fig, ax = plt.subplots(figsize=(6, 4))
127
- im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
128
- interpolation='none')
129
- fig.colorbar(im, ax=ax)
130
- xlabel = 'Decoder timestep'
131
- if info is not None:
132
- xlabel += '\n\n' + info
133
- plt.xlabel(xlabel)
134
- plt.ylabel('Encoder timestep')
135
- plt.tight_layout()
136
-
137
- fig.canvas.draw()
138
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
139
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
140
- plt.close()
141
- return data
142
-
143
-
144
- def load_wav_to_torch(full_path):
145
- sampling_rate, data = read(full_path)
146
- return torch.FloatTensor(data.astype(np.float32)), sampling_rate
147
-
148
-
149
- def load_filepaths_and_text(filename, split="|"):
150
- with open(filename, encoding='utf-8') as f:
151
- filepaths_and_text = [line.strip().split(split) for line in f]
152
- return filepaths_and_text
153
-
154
-
155
- def get_hparams(init=True):
156
- parser = argparse.ArgumentParser()
157
- parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
158
- help='JSON file for configuration')
159
- parser.add_argument('-m', '--model', type=str, required=True,
160
- help='Model name')
161
-
162
- args = parser.parse_args()
163
- model_dir = os.path.join("./logs", args.model)
164
-
165
- if not os.path.exists(model_dir):
166
- os.makedirs(model_dir)
167
-
168
- config_path = args.config
169
- config_save_path = os.path.join(model_dir, "config.json")
170
- if init:
171
- with open(config_path, "r") as f:
172
- data = f.read()
173
- with open(config_save_path, "w") as f:
174
- f.write(data)
175
- else:
176
- with open(config_save_path, "r") as f:
177
- data = f.read()
178
- config = json.loads(data)
179
-
180
- hparams = HParams(**config)
181
- hparams.model_dir = model_dir
182
- return hparams
183
-
184
-
185
- def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True):
186
- """Freeing up space by deleting saved ckpts
187
-
188
- Arguments:
189
- path_to_models -- Path to the model directory
190
- n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
191
- sort_by_time -- True -> chronologically delete ckpts
192
- False -> lexicographically delete ckpts
193
- """
194
- import re
195
- ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
196
- name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
197
- time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
198
- sort_key = time_key if sort_by_time else name_key
199
- x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')],
200
- key=sort_key)
201
- to_del = [os.path.join(path_to_models, fn) for fn in
202
- (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
203
- del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
204
- del_routine = lambda x: [os.remove(x), del_info(x)]
205
- rs = [del_routine(fn) for fn in to_del]
206
-
207
-
208
- def get_hparams_from_dir(model_dir):
209
- config_save_path = os.path.join(model_dir, "config.json")
210
- with open(config_save_path, "r", encoding='utf-8') as f:
211
- data = f.read()
212
- config = json.loads(data)
213
-
214
- hparams = HParams(**config)
215
- hparams.model_dir = model_dir
216
- return hparams
217
-
218
-
219
- def get_hparams_from_file(config_path):
220
- with open(config_path, "r", encoding='utf-8') as f:
221
- data = f.read()
222
- config = json.loads(data)
223
-
224
- hparams = HParams(**config)
225
- return hparams
226
-
227
-
228
- def check_git_hash(model_dir):
229
- source_dir = os.path.dirname(os.path.realpath(__file__))
230
- if not os.path.exists(os.path.join(source_dir, ".git")):
231
- logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
232
- source_dir
233
- ))
234
- return
235
-
236
- cur_hash = subprocess.getoutput("git rev-parse HEAD")
237
-
238
- path = os.path.join(model_dir, "githash")
239
- if os.path.exists(path):
240
- saved_hash = open(path).read()
241
- if saved_hash != cur_hash:
242
- logger.warn("git hash values are different. {}(saved) != {}(current)".format(
243
- saved_hash[:8], cur_hash[:8]))
244
- else:
245
- open(path, "w").write(cur_hash)
246
-
247
-
248
- def get_logger(model_dir, filename="train.log"):
249
- global logger
250
- logger = logging.getLogger(os.path.basename(model_dir))
251
- logger.setLevel(logging.DEBUG)
252
-
253
- formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
254
- if not os.path.exists(model_dir):
255
- os.makedirs(model_dir)
256
- h = logging.FileHandler(os.path.join(model_dir, filename))
257
- h.setLevel(logging.DEBUG)
258
- h.setFormatter(formatter)
259
- logger.addHandler(h)
260
- return logger
261
-
262
-
263
- class HParams():
264
- def __init__(self, **kwargs):
265
- for k, v in kwargs.items():
266
- if type(v) == dict:
267
- v = HParams(**v)
268
- self[k] = v
269
-
270
- def keys(self):
271
- return self.__dict__.keys()
272
-
273
- def items(self):
274
- return self.__dict__.items()
275
-
276
- def values(self):
277
- return self.__dict__.values()
278
-
279
- def __len__(self):
280
- return len(self.__dict__)
281
-
282
- def __getitem__(self, key):
283
- return getattr(self, key)
284
-
285
- def __setitem__(self, key, value):
286
- return setattr(self, key, value)
287
-
288
- def __contains__(self, key):
289
- return key in self.__dict__
290
-
291
- def __repr__(self):
292
- return self.__dict__.__repr__()
 
1
  import os
 
2
  import sys
 
3
  import logging
 
 
 
 
4
  import torch
5
 
6
  MATPLOTLIB_FLAG = False
 
36
  new_state_dict[k] = saved_state_dict[k]
37
  assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
38
  except:
39
+ # For upgrading from the old version
40
+ if "ja_bert_proj" in k:
41
+ v = torch.zeros_like(v)
42
+ logger.warning(
43
+ f"If you are using an older version of the model, you should add the parameter \"legacy\":true to the data of the model's config.json")
44
+ logger.error(f"{k} is not in the checkpoint")
45
+
46
  new_state_dict[k] = v
47
  if hasattr(model, 'module'):
48
  model.module.load_state_dict(new_state_dict, strict=False)
 
52
  logger.info("Loaded checkpoint '{}' (iteration {})".format(
53
  checkpoint_path, iteration))
54
  return model, optimizer, learning_rate, iteration
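
The new `except` branch zero-fills `ja_bert_proj` weights that are missing from older Bert-VITS2 checkpoints and asks the user to mark such models as legacy. A hedged sketch of that config.json edit, following the warning text above (the helper name and path are illustrative, not part of this commit):

```python
# Hypothetical helper: add "legacy": true under "data" in an old model's
# config.json, as the warning above suggests.
import json

def mark_legacy(config_path: str) -> None:
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    cfg.setdefault("data", {})["legacy"] = True
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(cfg, f, ensure_ascii=False, indent=2)

# mark_legacy("Model/my_bert_vits2/config.json")  # placeholder path
```
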
docker-compose-gpu.yaml CHANGED
@@ -8,7 +8,7 @@ services:
8
  - 23456:23456
9
  environment:
10
  LANG: 'C.UTF-8'
11
- TZ: Asia/Shanghai #timezone
12
  volumes:
13
  - ./Model:/app/Model # mount the model folder
14
  - ./config.py:/app/config.py # mount the config file
@@ -17,5 +17,10 @@ services:
17
  - ./vits/bert:/app/vits/bert # vits_chinese
18
  - ./bert_vits2/bert/chinese-roberta-wwm-ext-large:/app/bert_vits2/bert/chinese-roberta-wwm-ext-large # Bert-vits2
19
  - ./pyopenjtalk/open_jtalk_dic_utf_8-1.11:/usr/local/lib/python3.10/site-packages/pyopenjtalk/open_jtalk_dic_utf_8-1.11 #pyopentjalk
20
- devices:
21
- - "/dev/nvidia0:/dev/nvidia0"
 
 
 
 
 
 
8
  - 23456:23456
9
  environment:
10
  LANG: 'C.UTF-8'
11
+ #TZ: Asia/Shanghai #timezone
12
  volumes:
13
  - ./Model:/app/Model # mount the model folder
14
  - ./config.py:/app/config.py # mount the config file
 
17
  - ./vits/bert:/app/vits/bert # vits_chinese
18
  - ./bert_vits2/bert/chinese-roberta-wwm-ext-large:/app/bert_vits2/bert/chinese-roberta-wwm-ext-large # Bert-vits2
19
  - ./pyopenjtalk/open_jtalk_dic_utf_8-1.11:/usr/local/lib/python3.10/site-packages/pyopenjtalk/open_jtalk_dic_utf_8-1.11 #pyopentjalk
20
+ deploy:
21
+ resources:
22
+ reservations:
23
+ devices:
24
+ - driver: nvidia
25
+ #device_ids: ['0', '3']
26
+ capabilities: [gpu]
requirements.txt CHANGED
@@ -1,16 +1,19 @@
1
- numba
2
- librosa
3
  numpy==1.23.3
4
  scipy
5
- MarkupSafe==2.1.2
6
- torch
 
 
 
 
 
 
7
  unidecode
8
- six==1.16.0
9
- openjtalk==0.3.0.dev2
10
  jamo
11
  pypinyin
12
  jieba
13
- protobuf
14
  cn2an
15
  inflect
16
  eng_to_ipa
@@ -18,14 +21,30 @@ ko_pron
18
  indic_transliteration
19
  num_thai
20
  opencc
21
- audonnx
22
- flask==2.2.3
23
- soundfile==0.12.1
24
- graiax-silkcoder[libsndfile]
25
- flask_apscheduler
26
  fasttext
27
  fastlid
28
  langid
29
  phonemizer==3.2.1
30
  transformers
31
- pydantic==2.3.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core scientific and computation libraries
 
2
  numpy==1.23.3
3
  scipy
4
+
5
+ # Audio processing
6
+ librosa
7
+ soundfile==0.12.1
8
+ numba
9
+ graiax-silkcoder[libsndfile]
10
+
11
+ # Natural Language Processing and Text Conversion
12
  unidecode
13
+ pyopenjtalk==0.3.2
 
14
  jamo
15
  pypinyin
16
  jieba
 
17
  cn2an
18
  inflect
19
  eng_to_ipa
 
21
  indic_transliteration
22
  num_thai
23
  opencc
 
 
 
 
 
24
  fasttext
25
  fastlid
26
  langid
27
  phonemizer==3.2.1
28
  transformers
29
+ num2words
30
+ mecab-python3
31
+ unidic-lite
32
+
33
+ # Machine Learning and Deep Learning
34
+ torch
35
+ audonnx
36
+
37
+ # Web and API services
38
+ flask==2.2.3
39
+ flask_apscheduler
40
+ pydantic==2.3.0
41
+ werkzeug==2.3.6
42
+
43
+ # Compression and Decompression
44
+ py7zr
45
+
46
+ # Other utilities and dependencies
47
+ MarkupSafe==2.1.2
48
+ six==1.16.0
49
+ protobuf
50
+ tqdm
templates/index.html CHANGED
@@ -532,4 +532,4 @@
532
  });
533
  </script>
534
  </body>
535
- </html>
 
532
  });
533
  </script>
534
  </body>
535
+ </html>
utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from utils.classify_language import classify_language
2
+ from utils.data_utils import get_hparams_from_file, load_checkpoint, load_audio_to_torch
3
+ from utils.lang_dict import lang_dict
utils/classify_language.py ADDED
@@ -0,0 +1,49 @@
1
+ from config import LANGUAGE_IDENTIFICATION_LIBRARY
2
+
3
+ module = LANGUAGE_IDENTIFICATION_LIBRARY.lower()
4
+
5
+
6
+ def classify_language(text: str, target_languages: list = None) -> str:
7
+ if module == "fastlid" or module == "fasttext":
8
+ from fastlid import fastlid
9
+ classifier = fastlid
10
+ if target_languages != None: fastlid.set_languages = target_languages
11
+ elif module == "langid":
12
+ import langid
13
+ classifier = langid.classify
14
+ if target_languages != None: langid.set_languages(target_languages)
15
+ else:
16
+ raise ValueError(f"Wrong LANGUAGE_IDENTIFICATION_LIBRARY in config.py")
17
+
18
+ lang = classifier(text)[0]
19
+
20
+ return lang
21
+
22
+
23
+ def classify_zh_ja(text: str) -> str:
24
+ for idx, char in enumerate(text):
25
+ unicode_val = ord(char)
26
+
27
+ # detect Japanese characters
28
+ if 0x3040 <= unicode_val <= 0x309F or 0x30A0 <= unicode_val <= 0x30FF:
29
+ return "ja"
30
+
31
+ # detect Chinese (Han) characters
32
+ if 0x4E00 <= unicode_val <= 0x9FFF:
33
+ # check the neighboring character
34
+ next_char = text[idx + 1] if idx + 1 < len(text) else None
35
+
36
+ if next_char and (0x3040 <= ord(next_char) <= 0x309F or 0x30A0 <= ord(next_char) <= 0x30FF):
37
+ return "ja"
38
+
39
+ return "zh"
40
+
41
+
42
+ if __name__ == "__main__":
43
+ text = "这是一个测试文本"
44
+ print(classify_language(text))
45
+ print(classify_zh_ja(text)) # "zh"
46
+
47
+ text = "これはテストテキストです"
48
+ print(classify_language(text))
49
+ print(classify_zh_ja(text)) # "ja"
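
`classify_language` defers to the library selected by `LANGUAGE_IDENTIFICATION_LIBRARY` in config.py and can be restricted to a candidate set. A small usage sketch in addition to the `__main__` demo above (results depend on the configured classifier):

```python
from utils.classify_language import classify_language, classify_zh_ja

# Restrict detection to Chinese/Japanese, mirroring how mixed-language cleaners use it.
print(classify_language("こんにちは、世界", target_languages=["zh", "ja"]))  # expected: "ja"
print(classify_zh_ja("银河铁道之夜"))  # "zh": only Han characters, no kana nearby
```
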
utils/data_utils.py ADDED
@@ -0,0 +1,97 @@
1
+ import os
2
+ import logging
3
+
4
+
5
+ class HParams():
6
+ def __init__(self, **kwargs):
7
+ for k, v in kwargs.items():
8
+ if type(v) == dict:
9
+ v = HParams(**v)
10
+ self[k] = v
11
+
12
+ def keys(self):
13
+ return self.__dict__.keys()
14
+
15
+ def items(self):
16
+ return self.__dict__.items()
17
+
18
+ def values(self):
19
+ return self.__dict__.values()
20
+
21
+ def __len__(self):
22
+ return len(self.__dict__)
23
+
24
+ def __getitem__(self, key):
25
+ return getattr(self, key)
26
+
27
+ def __setitem__(self, key, value):
28
+ return setattr(self, key, value)
29
+
30
+ def __contains__(self, key):
31
+ return key in self.__dict__
32
+
33
+ def __repr__(self):
34
+ return self.__dict__.__repr__()
35
+
36
+
37
+ def load_checkpoint(checkpoint_path, model):
38
+ from torch import load
39
+ checkpoint_dict = load(checkpoint_path, map_location='cpu')
40
+ iteration = checkpoint_dict.get('iteration', None)
41
+ saved_state_dict = checkpoint_dict['model']
42
+ if hasattr(model, 'module'):
43
+ state_dict = model.module.state_dict()
44
+ else:
45
+ state_dict = model.state_dict()
46
+ new_state_dict = {}
47
+ for k, v in state_dict.items():
48
+ try:
49
+ new_state_dict[k] = saved_state_dict[k]
50
+ except:
51
+ logging.info(f"{k} is not in the checkpoint")
52
+ new_state_dict[k] = v
53
+ if hasattr(model, 'module'):
54
+ model.module.load_state_dict(new_state_dict)
55
+ else:
56
+ model.load_state_dict(new_state_dict)
57
+ if iteration:
58
+ logging.info(f"Loaded checkpoint '{checkpoint_path}' (iteration {iteration})")
59
+ else:
60
+ logging.info(f"Loaded checkpoint '{checkpoint_path}'")
61
+ return
62
+
63
+
64
+ def get_hparams_from_file(config_path):
65
+ from json import loads
66
+ with open(config_path, 'r', encoding='utf-8') as f:
67
+ data = f.read()
68
+ config = loads(data)
69
+
70
+ hparams = HParams(**config)
71
+ return hparams
72
+
73
+
74
+ def load_audio_to_torch(full_path, target_sampling_rate):
75
+ import librosa
76
+ from torch import FloatTensor
77
+ from numpy import float32
78
+ audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True)
79
+ return FloatTensor(audio.astype(float32))
80
+
81
+
82
+ def clean_folder(folder_path):
83
+ for filename in os.listdir(folder_path):
84
+ file_path = os.path.join(folder_path, filename)
85
+ # If it is a file, delete it; skip directories.
86
+ if os.path.isfile(file_path):
87
+ os.remove(file_path)
88
+
89
+
90
+ def check_is_none(item) -> bool:
91
+ # none -> True, not none -> False
92
+ return item is None or (isinstance(item, str) and str(item).isspace()) or str(item) == ""
93
+
94
+
95
+ def save_audio(audio, path):
96
+ with open(path, "wb") as f:
97
+ f.write(audio)
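
`HParams` mirrors config.json as nested attribute-style objects, which is what the rest of the code relies on (e.g. `hps_ms.data.sampling_rate` in vits.py and voice.py further down). A small sketch; the path is a placeholder:

```python
from utils.data_utils import get_hparams_from_file

hps = get_hparams_from_file("Model/example/config.json")  # placeholder path
print(hps.data.sampling_rate)        # attribute access on nested keys
print(hps["data"]["sampling_rate"])  # dict-style access works as well
```
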
utils/download.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import hashlib
3
+ import tarfile
4
+ import urllib.request
5
+ import zipfile
6
+
7
+ from tqdm import tqdm
8
+ from pathlib import Path
9
+ from logger import logger
10
+ from py7zr import SevenZipFile
11
+
12
+
13
+ class TqdmUpTo(tqdm):
14
+ def update_to(self, b=1, bsize=1, tsize=None):
15
+ if tsize is not None:
16
+ self.total = tsize
17
+ self.update(b * bsize - self.n)
18
+
19
+
20
+ def download_file(url, dest_path):
21
+ with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=url.split('/')[-1]) as t:
22
+ urllib.request.urlretrieve(url, dest_path, reporthook=t.update_to)
23
+
24
+
25
+ def verify_md5(file_path, expected_md5):
26
+ md5 = hashlib.md5(file_path.read_bytes()).hexdigest()
27
+ if md5 != expected_md5:
28
+ return False, f"MD5 mismatch: {md5} != {expected_md5}"
29
+ return True, ""
30
+
31
+
32
+ def extract_file(file_path, destination=None):
33
+ """
34
+ Extract a compressed file based on its extension.
35
+ If destination is not specified, it will be extracted to its parent directory.
36
+ """
37
+ if destination is None:
38
+ destination = Path(file_path).parent
39
+
40
+ if file_path.endswith('.zip'):
41
+ with zipfile.ZipFile(file_path, 'r') as zip_ref:
42
+ zip_ref.extractall(destination)
43
+ elif file_path.endswith('.tar.gz'):
44
+ with tarfile.open(file_path, 'r:gz') as tar_ref:
45
+ tar_ref.extractall(destination)
46
+ elif file_path.endswith('.tar.bz2'):
47
+ with tarfile.open(file_path, 'r:bz2') as tar_ref:
48
+ tar_ref.extractall(destination)
49
+ elif file_path.endswith('.7z'):
50
+ with SevenZipFile(file_path, mode='r') as z:
51
+ z.extractall(destination)
52
+ else:
53
+ print(f"Unsupported compression format for file {file_path}")
54
+
55
+
56
+ def download_and_verify(urls, target_path, expected_md5=None, extract_destination=None):
57
+ for url in urls:
58
+ try:
59
+ download_file(url, target_path)
60
+ break
61
+ except Exception as error:
62
+ logger.error(f"downloading from URL {url}: {error}")
63
+
64
+ else: # This else is tied to the for loop, and executes if no download is successful
65
+ return False, "Error downloading from all provided URLs."
66
+
67
+ if expected_md5 is not None:
68
+ success, message = verify_md5(Path(target_path), expected_md5)
69
+ if not success:
70
+ os.remove(target_path)
71
+ return False, message
72
+
73
+ # If it's a compressed file, extract it
74
+ if target_path.endswith(('.zip', '.tar.gz', '.tar.bz2', '.7z')):
75
+ extract_file(target_path, extract_destination)
76
+ os.remove(target_path)
77
+
78
+ return True, "File downloaded, verified, and extracted successfully!"
79
+
80
+
81
+ if __name__ == "__main__":
82
+ URLS = [
83
+ "YOUR_PRIMARY_URL_HERE",
84
+ "YOUR_FIRST_BACKUP_URL_HERE",
85
+ # ... you can add more backup URLs as needed
86
+ ]
87
+ TARGET_PATH = ""
88
+ EXPECTED_MD5 = ""
89
+ EXTRACT_DESTINATION = ""
90
+
91
+ success, message = download_and_verify(URLS, TARGET_PATH, EXPECTED_MD5, EXTRACT_DESTINATION)
92
+ print(message)
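
`download_and_verify` tries each URL in order, optionally checks an MD5, and extracts archives. The sketch below mirrors the call made in vits/bert/__init__.py later in this diff; the target path is relative to the working directory:

```python
from utils.download import download_and_verify

URLS = [
    "https://huggingface.co/spaces/maxmax20160403/vits_chinese/resolve/main/bert/prosody_model.pt",
]
# .pt is not an archive, so the file is downloaded and kept as-is; no MD5 check here.
success, message = download_and_verify(URLS, "vits/bert/prosody_model.pt", expected_md5=None)
print(success, message)
```
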
utils/lang_dict.py ADDED
@@ -0,0 +1,22 @@
1
+ lang_dict = {
2
+ "english_cleaners": ["en"],
3
+ "english_cleaners2": ["en"],
4
+ "japanese_cleaners": ["ja"],
5
+ "japanese_cleaners2": ["ja"],
6
+ "korean_cleaners": ["ko"],
7
+ "chinese_cleaners": ["zh"],
8
+ "zh_ja_mixture_cleaners": ["zh", "ja"],
9
+ "sanskrit_cleaners": ["sa"],
10
+ "cjks_cleaners": ["zh", "ja", "ko", "sa"],
11
+ "cjke_cleaners": ["zh", "ja", "ko", "en"],
12
+ "cjke_cleaners2": ["zh", "ja", "ko", "en"],
13
+ "cje_cleaners": ["zh", "ja", "en"],
14
+ "cje_cleaners2": ["zh", "ja", "en"],
15
+ "thai_cleaners": ["th"],
16
+ "shanghainese_cleaners": ["sh"],
17
+ "chinese_dialect_cleaners": ["zh", "ja", "sh", "gd", "en", "SZ", "WX", "CZ", "HZ", "SX", "NB", "JJ", "YX", "JD",
18
+ "ZR", "PH", "TX", "JS", "HN", "LP", "XS", "FY", "RA", "CX", "SM", "TT", "WZ", "SC",
19
+ "YB"],
20
+ "bert_chinese_cleaners": ["zh"],
21
+ "bert_vits2": ["zh", "ja"]
22
+ }
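
`lang_dict` maps a model's text cleaner name (or the special key "bert_vits2") to the languages exposed through the API; utils/load_model.py below falls back to ["unknown"] for unlisted cleaners:

```python
from utils.lang_dict import lang_dict

print(lang_dict.get("cjke_cleaners2", ["unknown"]))    # ['zh', 'ja', 'ko', 'en']
print(lang_dict.get("unlisted_cleaner", ["unknown"]))  # ['unknown'] fallback
```
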
utils/load_model.py ADDED
@@ -0,0 +1,172 @@
1
+ import os
2
+ import json
3
+ import logging
4
+ import config
5
+ import numpy as np
6
+ from utils.data_utils import check_is_none
7
+ from vits import VITS
8
+ from voice import TTS
9
+ from config import DEVICE as device
10
+ from utils.lang_dict import lang_dict
11
+
12
+
13
+
14
+ def recognition_model_type(model_config_json: json) -> str:
15
+ model_config = json.load(model_config_json)
16
+ symbols = model_config.get("symbols", None)
17
+ emotion_embedding = model_config.get("data").get("emotion_embedding", False)
18
+
19
+ if "use_spk_conditioned_encoder" in model_config.get("model"):
20
+ model_type = "bert_vits2"
21
+ return model_type
22
+
23
+ if symbols != None:
24
+ if not emotion_embedding:
25
+ mode_type = "vits"
26
+ else:
27
+ mode_type = "w2v2"
28
+ else:
29
+ mode_type = "hubert"
30
+
31
+ return mode_type
32
+
33
+
34
+ def load_npy(emotion_reference_npy):
35
+ if isinstance(emotion_reference_npy, list):
36
+ # check that emotion_reference_npy ends with .npy
37
+ for i in emotion_reference_npy:
38
+ model_extention = os.path.splitext(i)[1]
39
+ if model_extention != ".npy":
40
+ raise ValueError(f"Unsupported model type: {model_extention}")
41
+
42
+ # merge npy files
43
+ emotion_reference = np.empty((0, 1024))
44
+ for i in emotion_reference_npy:
45
+ tmp = np.load(i).reshape(-1, 1024)
46
+ emotion_reference = np.append(emotion_reference, tmp, axis=0)
47
+
48
+ elif os.path.isdir(emotion_reference_npy):
49
+ emotion_reference = np.empty((0, 1024))
50
+ for root, dirs, files in os.walk(emotion_reference_npy):
51
+ for file_name in files:
52
+ # check that emotion_reference_npy ends with .npy
53
+ model_extention = os.path.splitext(file_name)[1]
54
+ if model_extention != ".npy":
55
+ continue
56
+ file_path = os.path.join(root, file_name)
57
+
58
+ # merge npy files
59
+ tmp = np.load(file_path).reshape(-1, 1024)
60
+ emotion_reference = np.append(emotion_reference, tmp, axis=0)
61
+
62
+ elif os.path.isfile(emotion_reference_npy):
63
+ # check that emotion_reference_npy ends with .npy
64
+ model_extention = os.path.splitext(emotion_reference_npy)[1]
65
+ if model_extention != ".npy":
66
+ raise ValueError(f"Unsupported model type: {model_extention}")
67
+
68
+ emotion_reference = np.load(emotion_reference_npy)
69
+ logging.info(f"Loaded emotional dimension npy range:{len(emotion_reference)}")
70
+ return emotion_reference
71
+
72
+
73
+ def parse_models(model_list):
74
+ categorized_models = {
75
+ "vits": [],
76
+ "hubert": [],
77
+ "w2v2": [],
78
+ "bert_vits2": []
79
+ }
80
+
81
+ for model_info in model_list:
82
+ config_path = model_info[1]
83
+ with open(config_path, 'r', encoding='utf-8') as model_config:
84
+ model_type = recognition_model_type(model_config)
85
+ if model_type in categorized_models:
86
+ categorized_models[model_type].append(model_info)
87
+
88
+ return categorized_models
89
+
90
+
91
+ def merge_models(model_list, model_class, model_type, additional_arg=None):
92
+ id_mapping_objs = []
93
+ speakers = []
94
+ new_id = 0
95
+
96
+ for obj_id, (model_path, config_path) in enumerate(model_list):
97
+ obj_args = {
98
+ "model": model_path,
99
+ "config": config_path,
100
+ "model_type": model_type,
101
+ "device": device
102
+ }
103
+ if additional_arg:
104
+ obj_args.update(additional_arg)
105
+
106
+ obj = model_class(**obj_args)
107
+
108
+ if model_type == "bert_vits2":
109
+ key = model_type
110
+ else:
111
+ key = obj.get_cleaner()
112
+
113
+ lang = lang_dict.get(key, ["unknown"])
114
+
115
+ for real_id, name in enumerate(obj.get_speakers()):
116
+ id_mapping_objs.append([real_id, obj, obj_id])
117
+ speakers.append({"id": new_id, "name": name, "lang": lang})
118
+ new_id += 1
119
+
120
+ return id_mapping_objs, speakers
121
+
122
+
123
+ def load_model(model_list) -> TTS:
124
+ categorized_models = parse_models(model_list)
125
+
126
+ # Handle VITS
127
+ vits_objs, vits_speakers = merge_models(categorized_models["vits"], VITS, "vits")
128
+
129
+ # Handle HUBERT-VITS
130
+ hubert_vits_objs, hubert_vits_speakers = [], []
131
+ if len(categorized_models["hubert"]) != 0:
132
+ if getattr(config, "HUBERT_SOFT_MODEL", None) is None or check_is_none(config.HUBERT_SOFT_MODEL):
133
+ raise ValueError(f"Please configure HUBERT_SOFT_MODEL path in config.py")
134
+ try:
135
+ from vits.hubert_model import hubert_soft
136
+ hubert = hubert_soft(config.HUBERT_SOFT_MODEL)
137
+ except Exception as e:
138
+ raise ValueError(f"Load HUBERT_SOFT_MODEL failed {e}")
139
+
140
+ hubert_vits_objs, hubert_vits_speakers = merge_models(categorized_models["hubert"], VITS, "hubert",
141
+ additional_arg={"additional_model": hubert})
142
+
143
+ # Handle W2V2-VITS
144
+ w2v2_vits_objs, w2v2_vits_speakers = [], []
145
+ w2v2_emotion_count = 0
146
+ if len(categorized_models["w2v2"]) != 0:
147
+ if getattr(config, "DIMENSIONAL_EMOTION_NPY", None) is None or check_is_none(
148
+ config.DIMENSIONAL_EMOTION_NPY):
149
+ raise ValueError(f"Please configure DIMENSIONAL_EMOTION_NPY path in config.py")
150
+ try:
151
+ emotion_reference = load_npy(config.DIMENSIONAL_EMOTION_NPY)
152
+ except Exception as e:
153
+ emotion_reference = None
154
+ raise ValueError(f"Load DIMENSIONAL_EMOTION_NPY failed {e}")
155
+
156
+ w2v2_vits_objs, w2v2_vits_speakers = merge_models(categorized_models["w2v2"], VITS, "w2v2",
157
+ additional_arg={"additional_model": emotion_reference})
158
+ w2v2_emotion_count = len(emotion_reference) if emotion_reference is not None else 0
159
+
160
+ # Handle BERT-VITS2
161
+ bert_vits2_objs, bert_vits2_speakers = [], []
162
+ if len(categorized_models["bert_vits2"]) != 0:
163
+ from bert_vits2 import Bert_VITS2
164
+ bert_vits2_objs, bert_vits2_speakers = merge_models(categorized_models["bert_vits2"], Bert_VITS2, "bert_vits2")
165
+
166
+ voice_obj = {"VITS": vits_objs, "HUBERT-VITS": hubert_vits_objs, "W2V2-VITS": w2v2_vits_objs,
167
+ "BERT-VITS2": bert_vits2_objs}
168
+ voice_speakers = {"VITS": vits_speakers, "HUBERT-VITS": hubert_vits_speakers, "W2V2-VITS": w2v2_vits_speakers,
169
+ "BERT-VITS2": bert_vits2_speakers}
170
+
171
+ tts = TTS(voice_obj, voice_speakers, device=device, w2v2_emotion_count=w2v2_emotion_count)
172
+ return tts
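
`load_model` expects a list of `[model_path, config_path]` pairs and sorts each model into VITS / HuBERT-VITS / W2V2-VITS / Bert-VITS2 by inspecting its config. A sketch with placeholder paths:

```python
from utils.load_model import load_model

# Each entry is [model_path, config_path]; the paths below are placeholders.
model_list = [
    ["Model/speaker_a/G_latest.pth", "Model/speaker_a/config.json"],
    ["Model/bert_vits2_demo/G_latest.pth", "Model/bert_vits2_demo/config.json"],
]
tts = load_model(model_list)  # returns a TTS object wrapping every loaded model
```
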
utils/sentence.py ADDED
@@ -0,0 +1,91 @@
1
+ import regex as re
2
+
3
+ from logger import logger
4
+ from utils.data_utils import check_is_none
5
+ from utils.classify_language import classify_language
6
+
7
+
8
+ def markup_language_type(text: str, target_languages: list = None) -> str:
9
+ pattern = r'[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`' \
10
+ r'\!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」' \
11
+ r'『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+'
12
+ sentences = re.split(pattern, text)
13
+
14
+ pre_lang = ""
15
+ p = 0
16
+
17
+ for sentence in sentences:
18
+
19
+ if check_is_none(sentence): continue
20
+
21
+ lang = classify_language(sentence, target_languages)
22
+
23
+ if pre_lang == "":
24
+ text = text[:p] + text[p:].replace(sentence, f"[{lang.upper()}]{sentence}", 1)
25
+ p += len(f"[{lang.upper()}]")
26
+ elif pre_lang != lang:
27
+ text = text[:p] + text[p:].replace(sentence, f"[{pre_lang.upper()}][{lang.upper()}]{sentence}", 1)
28
+ p += len(f"[{pre_lang.upper()}][{lang.upper()}]")
29
+ pre_lang = lang
30
+ p += text[p:].index(sentence) + len(sentence)
31
+ text += f"[{pre_lang.upper()}]"
32
+
33
+ return text
34
+
35
+
36
+ def cut(text: str, max: int) -> list:
37
+ pattern = r'[!(),—+\-.:;??。,、;:]+'
38
+ sentences = re.split(pattern, text)
39
+ discarded_chars = re.findall(pattern, text)
40
+
41
+ sentence_list, count, p = [], 0, 0
42
+
43
+ # iterate over the punctuation marks the text was split on
44
+ for i, discarded_char in enumerate(discarded_chars):
45
+ count += len(sentences[i]) + len(discarded_char)
46
+ if count >= max:
47
+ sentence_list.append(text[p:p + count].strip())
48
+ p += count
49
+ count = 0
50
+
51
+ # append whatever text remains at the end
52
+ if p < len(text):
53
+ sentence_list.append(text[p:])
54
+
55
+ return sentence_list
56
+
57
+
58
+ def sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None):
59
+ # if this speaker only supports one language
60
+ if speaker_lang is not None and len(speaker_lang) == 1:
61
+ if lang.upper() not in ["AUTO", "MIX"] and lang.lower() != speaker_lang[0]:
62
+ logger.debug(
63
+ f"lang \"{lang}\" is not in speaker_lang {speaker_lang},automatically set lang={speaker_lang[0]}")
64
+ lang = speaker_lang[0]
65
+
66
+ sentence_list = []
67
+ if lang.upper() != "MIX":
68
+ if max <= 0:
69
+ sentence_list.append(
70
+ markup_language_type(text,
71
+ speaker_lang) if lang.upper() == "AUTO" else f"[{lang.upper()}]{text}[{lang.upper()}]")
72
+ else:
73
+ for i in cut(text, max):
74
+ if check_is_none(i): continue
75
+ sentence_list.append(
76
+ markup_language_type(i,
77
+ speaker_lang) if lang.upper() == "AUTO" else f"[{lang.upper()}]{i}[{lang.upper()}]")
78
+ else:
79
+ sentence_list.append(text)
80
+
81
+ for i in sentence_list:
82
+ logger.debug(i)
83
+
84
+ return sentence_list
85
+
86
+
87
+ if __name__ == '__main__':
88
+ text = "这几天心里颇不宁静。今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。月亮渐渐地升高了,墙外马路上孩子们的欢笑,已经听不见了;妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。我悄悄地披了大衫,带上门出去。"
89
+ print(markup_language_type(text, target_languages=None))
90
+ print(cut(text, max=50))
91
+ print(sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None))
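
`sentence_split_and_markup` cuts long input on punctuation and, with `lang="auto"`, wraps each detected language run in `[ZH]`/`[JA]`-style tags for the mixed-language cleaners. A small sketch; the exact tagging depends on the configured language classifier:

```python
from utils.sentence import sentence_split_and_markup

chunks = sentence_split_and_markup("你好,世界。こんにちは。", max=50, lang="auto")
for chunk in chunks:
    print(chunk)  # e.g. segments tagged like [ZH]...[ZH][JA]...[JA]
```
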
vits-simple-api-installer-latest.sh CHANGED
@@ -228,19 +228,6 @@ read -p "${MESSAGES["ENTER_CHOICE"]}" choice_download_bert_vits2
228
  if [ "$choice_download_bert_vits2" -eq 1 ]; then
229
  mkdir -p bert_vits2/bert/chinese-roberta-wwm-ext-large
230
 
231
- EXPECTED_MD5="78ef42421495cb23372bef9d069a75f3"
232
- FILE_PATH="bert_vits2/bert/chinese-roberta-wwm-ext-large/flax_model.msgpack"
233
- echo -e "${MESSAGES["VERIFYING"]}$FILE_PATH"
234
- ACTUAL_MD5=$(md5sum $FILE_PATH | awk '{print $1}')
235
-
236
- if [ "$EXPECTED_MD5" == "$ACTUAL_MD5" ]; then
237
- echo "${MESSAGES["DOWNLOADED"]}"
238
- else
239
- echo "${MESSAGES["CORRUPTED"]}"
240
- download_with_fallback bert_vits2/bert/chinese-roberta-wwm-ext-large/flax_model.msgpack \
241
- "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/flax_model.msgpack"
242
- fi
243
-
244
  EXPECTED_MD5="15d7435868fef1bd4222ff7820149a2a"
245
  FILE_PATH="bert_vits2/bert/chinese-roberta-wwm-ext-large/pytorch_model.bin"
246
  echo -e "${MESSAGES["VERIFYING"]}$FILE_PATH"
@@ -254,19 +241,6 @@ if [ "$choice_download_bert_vits2" -eq 1 ]; then
254
  "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin"
255
  fi
256
 
257
- EXPECTED_MD5="d15991416bd4a86fa127c70d3c0f4779"
258
- FILE_PATH="bert_vits2/bert/chinese-roberta-wwm-ext-large/tf_model.h5"
259
- echo -e "${MESSAGES["VERIFYING"]}$FILE_PATH"
260
- ACTUAL_MD5=$(md5sum $FILE_PATH | awk '{print $1}')
261
-
262
- if [ "$EXPECTED_MD5" == "$ACTUAL_MD5" ]; then
263
- echo "${MESSAGES["DOWNLOADED"]}"
264
- else
265
- echo "${MESSAGES["CORRUPTED"]}"
266
- download_with_fallback bert_vits2/bert/chinese-roberta-wwm-ext-large/tf_model.h5 \
267
- "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/tf_model.h5"
268
- fi
269
-
270
  fi
271
 
272
  if [ "$choice_gpu" -eq 2 ]; then
 
228
  if [ "$choice_download_bert_vits2" -eq 1 ]; then
229
  mkdir -p bert_vits2/bert/chinese-roberta-wwm-ext-large
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  EXPECTED_MD5="15d7435868fef1bd4222ff7820149a2a"
232
  FILE_PATH="bert_vits2/bert/chinese-roberta-wwm-ext-large/pytorch_model.bin"
233
  echo -e "${MESSAGES["VERIFYING"]}$FILE_PATH"
 
241
  "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin"
242
  fi
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  fi
245
 
246
  if [ "$choice_gpu" -eq 2 ]; then
vits/bert/__init__.py CHANGED
@@ -1,2 +1,15 @@
1
  """ from https://github.com/PlayVoice/vits_chinese """
2
- from .ProsodyModel import TTSProsody
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """ from https://github.com/PlayVoice/vits_chinese """
2
+ import os
3
+
4
+ import config
5
+ from utils.download import download_and_verify
6
+ from .ProsodyModel import TTSProsody
7
+
8
+ URLS = [
9
+ "https://huggingface.co/spaces/maxmax20160403/vits_chinese/resolve/main/bert/prosody_model.pt",
10
+ ]
11
+ TARGET_PATH = os.path.join(config.ABS_PATH, "vits/bert/prosody_model.pt")
12
+ EXPECTED_MD5 = None
13
+
14
+ if not os.path.exists(TARGET_PATH):
15
+ success, message = download_and_verify(URLS, TARGET_PATH, EXPECTED_MD5)
vits/text/cantonese.py CHANGED
@@ -1,9 +1,23 @@
 
1
  import re
2
  import cn2an
3
  import opencc
4
  import config
 
5
 
6
- converter = opencc.OpenCC(config.ABS_PATH + '/vits/text/chinese_dialect_lexicons/jyutjyu_2')
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # List of (Latin alphabet, ipa) pairs:
9
  _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
 
1
+ import os.path
2
  import re
3
  import cn2an
4
  import opencc
5
  import config
6
+ from utils.download import download_and_verify
7
 
8
+ URLS = [
9
+ "https://github.com/CjangCjengh/chinese-dialect-lexicons/releases/download/v1.0.3/chinese_dialects.7z",
10
+ "https://ghproxy.com/https://github.com/CjangCjengh/chinese-dialect-lexicons/releases/download/v1.0.3/chinese_dialects.7z",
11
+ ]
12
+ TARGET_PATH = os.path.join(config.ABS_PATH, "vits/text/chinese_dialects.7z")
13
+ EXTRACT_DESTINATION = os.path.join(config.ABS_PATH, "vits/text/chinese_dialect_lexicons/")
14
+ EXPECTED_MD5 = None
15
+ OPENCC_FILE_PATH = os.path.join(config.ABS_PATH, "vits/text/chinese_dialect_lexicons/jyutjyu.json")
16
+
17
+ if not os.path.exists(OPENCC_FILE_PATH):
18
+ success, message = download_and_verify(URLS, TARGET_PATH, EXPECTED_MD5, EXTRACT_DESTINATION)
19
+
20
+ converter = opencc.OpenCC(OPENCC_FILE_PATH)
21
 
22
  # List of (Latin alphabet, ipa) pairs:
23
  _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
vits/text/shanghainese.py CHANGED
@@ -1,9 +1,23 @@
 
1
  import re
2
  import cn2an
3
  import opencc
4
  import config
 
5
 
6
- converter = opencc.OpenCC(config.ABS_PATH + '/vits/text/chinese_dialect_lexicons/zaonhe')
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # List of (Latin alphabet, ipa) pairs:
9
  _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
 
1
+ import os
2
  import re
3
  import cn2an
4
  import opencc
5
  import config
6
+ from utils.download import download_and_verify
7
 
8
+ URLS = [
9
+ "https://github.com/CjangCjengh/chinese-dialect-lexicons/releases/download/v1.0.3/chinese_dialects.7z",
10
+ "https://ghproxy.com/https://github.com/CjangCjengh/chinese-dialect-lexicons/releases/download/v1.0.3/chinese_dialects.7z",
11
+ ]
12
+ TARGET_PATH = os.path.join(config.ABS_PATH, "vits/text/chinese_dialects.7z")
13
+ EXTRACT_DESTINATION = os.path.join(config.ABS_PATH, "vits/text/chinese_dialect_lexicons/")
14
+ EXPECTED_MD5 = None
15
+ OPENCC_FILE_PATH = os.path.join(config.ABS_PATH, "vits/text/chinese_dialect_lexicons/zaonhe.json")
16
+
17
+ if not os.path.exists(OPENCC_FILE_PATH):
18
+ success, message = download_and_verify(URLS, TARGET_PATH, EXPECTED_MD5, EXTRACT_DESTINATION)
19
+
20
+ converter = opencc.OpenCC(OPENCC_FILE_PATH)
21
 
22
  # List of (Latin alphabet, ipa) pairs:
23
  _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
vits/text/vits_pinyin.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import pypinyin
2
  from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
3
  from pypinyin.converter import DefaultConverter
@@ -41,12 +42,12 @@ class VITS_PinYin:
41
  def chinese_to_phonemes(self, text):
42
  # Considered replacing the original pypinyin with g2pw's chinese bert, but testing showed it ran too slowly.
43
  # Replace standard Chinese punctuation with the single symbols in the bert vocabulary so bert works as intended.
44
- text = text.replace("——", "...")\
45
- .replace("—", "...")\
46
- .replace("……", "...")\
47
- .replace("…", "...")\
48
- .replace('“', '"')\
49
- .replace('”', '"')\
50
  .replace("\n", "")
51
  tokens = self.prosody.char_model.tokenizer.tokenize(text)
52
  text = ''.join(tokens)
 
1
+ """ from https://github.com/PlayVoice/vits_chinese """
2
  import pypinyin
3
  from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin
4
  from pypinyin.converter import DefaultConverter
 
42
  def chinese_to_phonemes(self, text):
43
  # Considered replacing the original pypinyin with g2pw's chinese bert, but testing showed it ran too slowly.
44
  # Replace standard Chinese punctuation with the single symbols in the bert vocabulary so bert works as intended.
45
+ text = text.replace("——", "...") \
46
+ .replace("—", "...") \
47
+ .replace("……", "...") \
48
+ .replace("…", "...") \
49
+ .replace('“', '"') \
50
+ .replace('”', '"') \
51
  .replace("\n", "")
52
  tokens = self.prosody.char_model.tokenizer.tokenize(text)
53
  text = ''.join(tokens)
vits/vits.py CHANGED
@@ -1,18 +1,19 @@
1
  import librosa
2
- from vits import commons
3
  import re
4
  import numpy as np
5
  import torch
6
  from torch import no_grad, LongTensor, inference_mode, FloatTensor
7
- from utils.nlp import sentence_split
 
 
8
  from vits.mel_processing import spectrogram_torch
9
  from vits.text import text_to_sequence
10
  from vits.models import SynthesizerTrn
11
- from utils import utils
12
 
13
 
14
  class VITS:
15
- def __init__(self, model, config, model_=None, model_type=None, device=torch.device("cpu")):
16
  self.model_type = model_type
17
  self.hps_ms = utils.get_hparams_from_file(config)
18
  self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
@@ -38,15 +39,15 @@ class VITS:
38
  self.device = device
39
 
40
  # load model
41
- self.load_model(model, model_)
42
 
43
- def load_model(self, model, model_=None):
44
  utils.load_checkpoint(model, self.net_g_ms)
45
  self.net_g_ms.to(self.device)
46
  if self.model_type == "hubert":
47
- self.hubert = model_
48
  elif self.model_type == "w2v2":
49
- self.emotion_reference = model_
50
 
51
  def get_cleaned_text(self, text, hps, cleaned=False):
52
  if cleaned:
@@ -165,7 +166,7 @@ class VITS:
165
 
166
  tasks = []
167
  if self.model_type == "vits":
168
- sentence_list = sentence_split(text, max, lang, speaker_lang)
169
  for sentence in sentence_list:
170
  params = self.get_infer_param(text=sentence, speaker_id=speaker_id, length_scale=length,
171
  noise_scale=noise, noise_scale_w=noisew)
@@ -177,7 +178,7 @@ class VITS:
177
  tasks.append(params)
178
 
179
  elif self.model_type == "w2v2":
180
- sentence_list = sentence_split(text, max, lang, speaker_lang)
181
  for sentence in sentence_list:
182
  params = self.get_infer_param(text=sentence, speaker_id=speaker_id, length_scale=length,
183
  noise_scale=noise, noise_scale_w=noisew, emotion=emotion)
 
1
  import librosa
 
2
  import re
3
  import numpy as np
4
  import torch
5
  from torch import no_grad, LongTensor, inference_mode, FloatTensor
6
+ import utils
7
+ from utils.sentence import sentence_split_and_markup
8
+ from vits import commons
9
  from vits.mel_processing import spectrogram_torch
10
  from vits.text import text_to_sequence
11
  from vits.models import SynthesizerTrn
12
+
13
 
14
 
15
  class VITS:
16
+ def __init__(self, model, config, additional_model=None, model_type=None, device=torch.device("cpu"),**kwargs):
17
  self.model_type = model_type
18
  self.hps_ms = utils.get_hparams_from_file(config)
19
  self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
 
39
  self.device = device
40
 
41
  # load model
42
+ self.load_model(model, additional_model)
43
 
44
+ def load_model(self, model, additional_model=None):
45
  utils.load_checkpoint(model, self.net_g_ms)
46
  self.net_g_ms.to(self.device)
47
  if self.model_type == "hubert":
48
+ self.hubert = additional_model
49
  elif self.model_type == "w2v2":
50
+ self.emotion_reference = additional_model
51
 
52
  def get_cleaned_text(self, text, hps, cleaned=False):
53
  if cleaned:
 
166
 
167
  tasks = []
168
  if self.model_type == "vits":
169
+ sentence_list = sentence_split_and_markup(text, max, lang, speaker_lang)
170
  for sentence in sentence_list:
171
  params = self.get_infer_param(text=sentence, speaker_id=speaker_id, length_scale=length,
172
  noise_scale=noise, noise_scale_w=noisew)
 
178
  tasks.append(params)
179
 
180
  elif self.model_type == "w2v2":
181
+ sentence_list = sentence_split_and_markup(text, max, lang, speaker_lang)
182
  for sentence in sentence_list:
183
  params = self.get_infer_param(text=sentence, speaker_id=speaker_id, length_scale=length,
184
  noise_scale=noise, noise_scale_w=noisew, emotion=emotion)
voice.py CHANGED
@@ -8,7 +8,7 @@ import config
8
  import soundfile as sf
9
  from io import BytesIO
10
  from graiax import silkcoder
11
- from utils import utils
12
  from logger import logger
13
 
14
 
@@ -16,7 +16,7 @@ from logger import logger
16
 
17
 
18
  class TTS:
19
- def __init__(self, voice_obj, voice_speakers, w2v2_emotion_count=0, device=torch.device("cpu")):
20
  self._voice_obj = voice_obj
21
  self._voice_speakers = voice_speakers
22
  self._strength_dict = {"x-weak": 0.25, "weak": 0.5, "Medium": 0.75, "Strong": 1, "x-strong": 1.25}
@@ -24,14 +24,14 @@ class TTS:
24
  self._vits_speakers_count = len(self._voice_speakers["VITS"])
25
  self._hubert_speakers_count = len(self._voice_speakers["HUBERT-VITS"])
26
  self._w2v2_speakers_count = len(self._voice_speakers["W2V2-VITS"])
27
- self._w2v2_emotion_count = w2v2_emotion_count
28
  self._bert_vits2_speakers_count = len(self._voice_speakers["BERT-VITS2"])
29
  self.dem = None
30
 
31
  # Initialization information
32
  self.logger = logger
33
  self.logger.info(f"torch:{torch.__version__} cuda_available:{torch.cuda.is_available()}")
34
- self.logger.info(f'device:{device} device.type:{device.type}')
35
 
36
  if getattr(config, "DIMENSIONAL_EMOTION_MODEL", None) != None:
37
  try:
@@ -45,7 +45,8 @@ class TTS:
45
  if self._vits_speakers_count != 0: self.logger.info(f"[VITS] {self._vits_speakers_count} speakers")
46
  if self._hubert_speakers_count != 0: self.logger.info(f"[hubert] {self._hubert_speakers_count} speakers")
47
  if self._w2v2_speakers_count != 0: self.logger.info(f"[w2v2] {self._w2v2_speakers_count} speakers")
48
- if self._bert_vits2_speakers_count != 0: self.logger.info(f"[Bert-VITS2] {self._bert_vits2_speakers_count} speakers")
 
49
  self.logger.info(f"{self._speakers_count} speakers in total.")
50
  if self._speakers_count == 0:
51
  self.logger.warning(f"No model was loaded.")
@@ -196,84 +197,73 @@ class TTS:
196
 
197
  return voice_tasks, format
198
 
199
- def create_ssml_infer_task(self, voice_tasks, format, fname):
200
  audios = []
201
- for voice in voice_tasks:
202
- if voice.get("break"):
203
- audios.append(np.zeros(int(voice.get("break") * 22050), dtype=np.int16))
204
  else:
205
- model = voice.get("model").upper()
206
  if model != "VITS" and model != "W2V2-VITS" and model != "EMOTION-VITS":
207
- raise ValueError(f"Unsupported model: {voice.get('model')}")
208
- voice_obj = self._voice_obj[model][voice.get("id")][1]
209
- voice["id"] = self._voice_obj[model][voice.get("id")][0]
210
- audio = voice_obj.get_audio(voice)
211
  audios.append(audio)
212
 
213
  audio = np.concatenate(audios, axis=0)
214
  encoded_audio = self.encode(voice_obj.hps_ms.data.sampling_rate, audio, format)
215
- if getattr(config, "SAVE_AUDIO", False):
216
- path = f"{config.CACHE_PATH}/{fname}"
217
- utils.save_audio(encoded_audio.getvalue(), path)
218
  return encoded_audio
219
 
220
- def vits_infer(self, voice, fname):
221
- format = voice.get("format", "wav")
222
- voice_obj = self._voice_obj["VITS"][voice.get("id")][1]
223
- voice["id"] = self._voice_obj["VITS"][voice.get("id")][0]
 
224
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
225
- audio = voice_obj.get_audio(voice, auto_break=True)
226
  encoded_audio = self.encode(sampling_rate, audio, format)
227
- if getattr(config, "SAVE_AUDIO", False):
228
- path = f"{config.CACHE_PATH}/{fname}"
229
- utils.save_audio(encoded_audio.getvalue(), path)
230
  return encoded_audio
231
 
232
- def stream_vits_infer(self, voice, fname):
233
- format = voice.get("format", "wav")
234
- voice_obj = self._voice_obj["VITS"][voice.get("id")][1]
235
- voice["id"] = self._voice_obj["VITS"][voice.get("id")][0]
236
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
237
- genertator = voice_obj.get_stream_audio(voice, auto_break=True)
238
- audio = BytesIO()
239
  for chunk in genertator:
240
  encoded_audio = self.encode(sampling_rate, chunk, format)
241
  for encoded_audio_chunk in self.generate_audio_chunks(encoded_audio):
242
  yield encoded_audio_chunk
243
- if getattr(config, "SAVE_AUDIO", False):
244
- audio.write(encoded_audio.getvalue())
245
- if getattr(config, "SAVE_AUDIO", False):
246
- path = f"{config.CACHE_PATH}/{fname}"
247
- utils.save_audio(audio.getvalue(), path)
248
-
249
- def hubert_vits_infer(self, voice, fname):
250
- format = voice.get("format", "wav")
251
- voice_obj = self._voice_obj["HUBERT-VITS"][voice.get("id")][1]
252
- voice["id"] = self._voice_obj["HUBERT-VITS"][voice.get("id")][0]
253
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
254
- audio = voice_obj.get_audio(voice)
255
  encoded_audio = self.encode(sampling_rate, audio, format)
256
- if getattr(config, "SAVE_AUDIO", False):
257
- path = f"{config.CACHE_PATH}/{fname}"
258
- utils.save_audio(encoded_audio.getvalue(), path)
259
  return encoded_audio
260
 
261
- def w2v2_vits_infer(self, voice, fname):
262
- format = voice.get("format", "wav")
263
- voice_obj = self._voice_obj["W2V2-VITS"][voice.get("id")][1]
264
- voice["id"] = self._voice_obj["W2V2-VITS"][voice.get("id")][0]
265
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
266
- audio = voice_obj.get_audio(voice, auto_break=True)
267
  encoded_audio = self.encode(sampling_rate, audio, format)
268
- if getattr(config, "SAVE_AUDIO", False):
269
- path = f"{config.CACHE_PATH}/{fname}"
270
- utils.save_audio(encoded_audio.getvalue(), path)
271
  return encoded_audio
272
 
273
- def vits_voice_conversion(self, voice, fname):
274
- original_id = voice.get("original_id")
275
- target_id = voice.get("target_id")
276
- format = voice.get("format")
277
 
278
  original_id_obj = int(self._voice_obj["VITS"][original_id][2])
279
  target_id_obj = int(self._voice_obj["VITS"][target_id][2])
@@ -281,17 +271,14 @@ class TTS:
281
  if original_id_obj != target_id_obj:
282
  raise ValueError(f"speakers are in different VITS Model")
283
 
284
- voice["original_id"] = int(self._voice_obj["VITS"][original_id][0])
285
- voice["target_id"] = int(self._voice_obj["VITS"][target_id][0])
286
 
287
  voice_obj = self._voice_obj["VITS"][original_id][1]
288
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
289
 
290
- audio = voice_obj.voice_conversion(voice)
291
  encoded_audio = self.encode(sampling_rate, audio, format)
292
- if getattr(config, "SAVE_AUDIO", False):
293
- path = f"{config.CACHE_PATH}/{fname}"
294
- utils.save_audio(encoded_audio.getvalue(), path)
295
  return encoded_audio
296
 
297
  def get_dimensional_emotion_npy(self, audio):
@@ -305,14 +292,11 @@ class TTS:
305
 
306
  return emotion_npy
307
 
308
- def bert_vits2_infer(self, voice, fname):
309
- format = voice.get("format", "wav")
310
- voice_obj = self._voice_obj["BERT-VITS2"][voice.get("id")][1]
311
- voice["id"] = self._voice_obj["BERT-VITS2"][voice.get("id")][0]
312
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
313
- audio = voice_obj.get_audio(voice, auto_break=True)
314
  encoded_audio = self.encode(sampling_rate, audio, format)
315
- if getattr(config, "SAVE_AUDIO", False):
316
- path = f"{config.CACHE_PATH}/{fname}"
317
- utils.save_audio(encoded_audio.getvalue(), path)
318
  return encoded_audio
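The hunks above drop the per-method `SAVE_AUDIO` handling. A minimal caller-side sketch of how that behaviour could be kept outside the infer methods (`tts`, `task`, and `fname` are illustrative names; `config.SAVE_AUDIO`, `config.CACHE_PATH`, and `utils.save_audio` are taken from the removed lines):

```python
# Hypothetical caller-side saving, assuming the SAVE_AUDIO blocks removed above
# are now the caller's responsibility rather than TTS.*_infer's.
encoded_audio = tts.vits_infer(task)                 # BytesIO of encoded audio
if getattr(config, "SAVE_AUDIO", False):
    path = f"{config.CACHE_PATH}/{fname}"            # fname chosen by the caller
    utils.save_audio(encoded_audio.getvalue(), path)
```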
 
8
  import soundfile as sf
9
  from io import BytesIO
10
  from graiax import silkcoder
11
+ import utils
12
  from logger import logger
13
 
14
 
 
16
 
17
 
18
  class TTS:
19
+ def __init__(self, voice_obj, voice_speakers, **kwargs):
20
  self._voice_obj = voice_obj
21
  self._voice_speakers = voice_speakers
22
  self._strength_dict = {"x-weak": 0.25, "weak": 0.5, "Medium": 0.75, "Strong": 1, "x-strong": 1.25}
 
24
  self._vits_speakers_count = len(self._voice_speakers["VITS"])
25
  self._hubert_speakers_count = len(self._voice_speakers["HUBERT-VITS"])
26
  self._w2v2_speakers_count = len(self._voice_speakers["W2V2-VITS"])
27
+ self._w2v2_emotion_count = kwargs.get("w2v2_emotion_count", 0)
28
  self._bert_vits2_speakers_count = len(self._voice_speakers["BERT-VITS2"])
29
  self.dem = None
30
 
31
  # Initialization information
32
  self.logger = logger
33
  self.logger.info(f"torch:{torch.__version__} cuda_available:{torch.cuda.is_available()}")
34
+ self.logger.info(f'device:{kwargs.get("device")} device.type:{kwargs.get("device").type}')
35
 
36
  if getattr(config, "DIMENSIONAL_EMOTION_MODEL", None) != None:
37
  try:
 
45
  if self._vits_speakers_count != 0: self.logger.info(f"[VITS] {self._vits_speakers_count} speakers")
46
  if self._hubert_speakers_count != 0: self.logger.info(f"[hubert] {self._hubert_speakers_count} speakers")
47
  if self._w2v2_speakers_count != 0: self.logger.info(f"[w2v2] {self._w2v2_speakers_count} speakers")
48
+ if self._bert_vits2_speakers_count != 0: self.logger.info(
49
+ f"[Bert-VITS2] {self._bert_vits2_speakers_count} speakers")
50
  self.logger.info(f"{self._speakers_count} speakers in total.")
51
  if self._speakers_count == 0:
52
  self.logger.warning(f"No model was loaded.")
 
197
 
198
  return voice_tasks, format
199
 
200
+ def create_ssml_infer_task(self, tasks, format):
201
  audios = []
202
+ for task in tasks:
203
+ if task.get("break"):
204
+ audios.append(np.zeros(int(task.get("break") * 22050), dtype=np.int16))
205
  else:
206
+ model = task.get("model").upper()
207
  if model != "VITS" and model != "W2V2-VITS" and model != "EMOTION-VITS":
208
+ raise ValueError(f"Unsupported model: {task.get('model')}")
209
+ voice_obj = self._voice_obj[model][task.get("id")][1]
210
+ task["id"] = self._voice_obj[model][task.get("id")][0]
211
+ audio = voice_obj.get_audio(task)
212
  audios.append(audio)
213
 
214
  audio = np.concatenate(audios, axis=0)
215
  encoded_audio = self.encode(voice_obj.hps_ms.data.sampling_rate, audio, format)
216
  return encoded_audio
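A hedged example of the task list this method consumes, using only the keys read in the loop above ("break", "model", "id"); the "text" key is an assumption about what the underlying `get_audio` call expects:

```python
# Hypothetical SSML task list: a VITS utterance, half a second of silence
# (zeros at the hard-coded 22050 Hz rate), then a W2V2-VITS utterance.
tasks = [
    {"model": "VITS", "id": 0, "text": "first sentence"},
    {"break": 0.5},
    {"model": "W2V2-VITS", "id": 0, "text": "second sentence"},
]
encoded_audio = tts.create_ssml_infer_task(tasks, "wav")
```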
217
 
218
+ def vits_infer(self, task):
219
+ format = task.get("format", "wav")
220
+ voice_obj = self._voice_obj["VITS"][task.get("id")][1]
221
+ real_id = self._voice_obj["VITS"][task.get("id")][0]
222
+ task["id"] = real_id # Change to real id
223
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
224
+ audio = voice_obj.get_audio(task, auto_break=True)
225
  encoded_audio = self.encode(sampling_rate, audio, format)
226
  return encoded_audio
227
 
228
+ def stream_vits_infer(self, task, fname=None):
229
+ format = task.get("format", "wav")
230
+ voice_obj = self._voice_obj["VITS"][task.get("id")][1]
231
+ task["id"] = self._voice_obj["VITS"][task.get("id")][0]
232
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
233
+ generator = voice_obj.get_stream_audio(task, auto_break=True)
234
+ # audio = BytesIO()
235
  for chunk in generator:
236
  encoded_audio = self.encode(sampling_rate, chunk, format)
237
  for encoded_audio_chunk in self.generate_audio_chunks(encoded_audio):
238
  yield encoded_audio_chunk
239
+ # if getattr(config, "SAVE_AUDIO", False):
240
+ # audio.write(encoded_audio.getvalue())
241
+ # if getattr(config, "SAVE_AUDIO", False):
242
+ # path = f"{config.CACHE_PATH}/{fname}"
243
+ # utils.save_audio(audio.getvalue(), path)
244
+
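Since `stream_vits_infer` now only yields encoded chunks, a minimal consumer sketch, assuming a Flask-style streaming response (the route wiring below is illustrative, not part of this commit):

```python
from flask import Response

# Hypothetical wiring: pass the generator straight to a chunked HTTP response.
def stream_route(task):
    return Response(tts.stream_vits_infer(task), mimetype="audio/wav")
```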
245
+ def hubert_vits_infer(self, task):
246
+ format = task.get("format", "wav")
247
+ voice_obj = self._voice_obj["HUBERT-VITS"][task.get("id")][1]
248
+ task["id"] = self._voice_obj["HUBERT-VITS"][task.get("id")][0]
249
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
250
+ audio = voice_obj.get_audio(task)
251
  encoded_audio = self.encode(sampling_rate, audio, format)
252
  return encoded_audio
253
 
254
+ def w2v2_vits_infer(self, task):
255
+ format = task.get("format", "wav")
256
+ voice_obj = self._voice_obj["W2V2-VITS"][task.get("id")][1]
257
+ task["id"] = self._voice_obj["W2V2-VITS"][task.get("id")][0]
258
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
259
+ audio = voice_obj.get_audio(task, auto_break=True)
260
  encoded_audio = self.encode(sampling_rate, audio, format)
261
  return encoded_audio
262
 
263
+ def vits_voice_conversion(self, task):
264
+ original_id = task.get("original_id")
265
+ target_id = task.get("target_id")
266
+ format = task.get("format")
267
 
268
  original_id_obj = int(self._voice_obj["VITS"][original_id][2])
269
  target_id_obj = int(self._voice_obj["VITS"][target_id][2])
 
271
  if original_id_obj != target_id_obj:
272
  raise ValueError(f"speakers are in diffrent VITS Model")
273
 
274
+ task["original_id"] = int(self._voice_obj["VITS"][original_id][0])
275
+ task["target_id"] = int(self._voice_obj["VITS"][target_id][0])
276
 
277
  voice_obj = self._voice_obj["VITS"][original_id][1]
278
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
279
 
280
+ audio = voice_obj.voice_conversion(task)
281
  encoded_audio = self.encode(sampling_rate, audio, format)
282
  return encoded_audio
283
 
284
  def get_dimensional_emotion_npy(self, audio):
 
292
 
293
  return emotion_npy
294
 
295
+ def bert_vits2_infer(self, task):
296
+ format = task.get("format", "wav")
297
+ voice_obj = self._voice_obj["BERT-VITS2"][task.get("id")][1]
298
+ task["id"] = self._voice_obj["BERT-VITS2"][task.get("id")][0]
299
  sampling_rate = voice_obj.hps_ms.data.sampling_rate
300
+ audio = voice_obj.get_audio(task, auto_break=True)
301
  encoded_audio = self.encode(sampling_rate, audio, format)
302
  return encoded_audio
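Taken together, the refactor switches every infer method to a single task dict; a hedged end-to-end sketch ("text" is an assumed key, while "id", "format", "original_id", and "target_id" are the keys read above):

```python
# Hypothetical calls against the refactored task-dict API.
task = {"text": "你好", "id": 0, "format": "wav"}
audio = tts.vits_infer(task)                      # BytesIO of encoded audio
bert_audio = tts.bert_vits2_infer(dict(task))     # same task-dict shape

# Voice conversion reads original_id/target_id/format here; the source audio is
# carried in the task in whatever form voice_conversion expects (not shown above).
converted = tts.vits_voice_conversion({"original_id": 0, "target_id": 1, "format": "wav"})
```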