Mahiruoshi committed
Commit 5220ea7
1 Parent(s): f4cadb2

Upload 86 files

Files changed (47)
  1. bert_gen.py +24 -17
  2. clap_gen.py +1 -1
  3. config.yml +17 -17
  4. data_utils.py +7 -23
  5. export_onnx.py +6 -4
  6. infer.py +53 -100
  7. losses.py +95 -0
  8. models.py +66 -65
  9. onnx_infer.py +68 -0
  10. re_matching.py +0 -1
  11. requirements.txt +2 -3
  12. resample.py +10 -6
  13. resample_legacy.py +71 -0
  14. server.py +214 -47
  15. server_fastapi.py +39 -1
  16. slm/wavlm-base-plus/.gitattributes +27 -0
  17. slm/wavlm-base-plus/README.md +65 -0
  18. slm/wavlm-base-plus/config.json +99 -0
  19. slm/wavlm-base-plus/preprocessor_config.json +9 -0
  20. slm/wavlm-base-plus/pytorch_model.bin +3 -0
  21. text/__init__.py +4 -2
  22. text/__pycache__/__init__.cpython-311.pyc +0 -0
  23. text/__pycache__/bert_utils.cpython-311.pyc +0 -0
  24. text/__pycache__/chinese.cpython-311.pyc +0 -0
  25. text/__pycache__/chinese_bert.cpython-311.pyc +0 -0
  26. text/__pycache__/cleaner.cpython-311.pyc +0 -0
  27. text/__pycache__/english.cpython-311.pyc +0 -0
  28. text/__pycache__/english_bert_mock.cpython-311.pyc +0 -0
  29. text/__pycache__/japanese.cpython-311.pyc +0 -0
  30. text/__pycache__/japanese_bert.cpython-311.pyc +0 -0
  31. text/__pycache__/symbols.cpython-311.pyc +0 -0
  32. text/__pycache__/tone_sandhi.cpython-311.pyc +0 -0
  33. text/chinese_bert.py +21 -3
  34. text/cleaner.py +2 -2
  35. text/english.py +71 -29
  36. text/english_bert_mock.py +21 -2
  37. text/japanese_bert.py +23 -2
  38. text/tone_sandhi.py +7 -3
  39. tools/__pycache__/__init__.cpython-311.pyc +0 -0
  40. tools/__pycache__/classify_language.cpython-311.pyc +0 -0
  41. tools/__pycache__/log.cpython-311.pyc +0 -0
  42. tools/__pycache__/sentence.cpython-311.pyc +0 -0
  43. tools/__pycache__/translate.cpython-311.pyc +0 -0
  44. train_ms.py +172 -58
  45. utils.py +5 -1
  46. webui.py +194 -174
  47. webui_preprocess.py +10 -21
bert_gen.py CHANGED
@@ -1,17 +1,16 @@
-import argparse
-from multiprocessing import Pool, cpu_count
-
 import torch
-import torch.multiprocessing as mp
-from tqdm import tqdm
-
+from multiprocessing import Pool
 import commons
 import utils
+from tqdm import tqdm
+from text import check_bert_models, cleaned_text_to_sequence, get_bert
+import argparse
+import torch.multiprocessing as mp
 from config import config
-from text import cleaned_text_to_sequence, get_bert


-def process_line(line):
+def process_line(x):
+    line, add_blank = x
     device = config.bert_gen_config.device
     if config.bert_gen_config.use_multi_device:
         rank = mp.current_process()._identity
@@ -28,12 +27,13 @@ def process_line(line):
     word2ph = [i for i in word2ph]
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

-    phone = commons.intersperse(phone, 0)
-    tone = commons.intersperse(tone, 0)
-    language = commons.intersperse(language, 0)
-    for i in range(len(word2ph)):
-        word2ph[i] = word2ph[i] * 2
-    word2ph[0] += 1
+    if add_blank:
+        phone = commons.intersperse(phone, 0)
+        tone = commons.intersperse(tone, 0)
+        language = commons.intersperse(language, 0)
+        for i in range(len(word2ph)):
+            word2ph[i] = word2ph[i] * 2
+        word2ph[0] += 1

     bert_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".bert.pt")
@@ -59,16 +59,23 @@ if __name__ == "__main__":
     args, _ = parser.parse_known_args()
     config_path = args.config
    hps = utils.get_hparams_from_file(config_path)
+    check_bert_models()
     lines = []
     with open(hps.data.training_files, encoding="utf-8") as f:
         lines.extend(f.readlines())

     with open(hps.data.validation_files, encoding="utf-8") as f:
         lines.extend(f.readlines())
+    add_blank = [hps.data.add_blank] * len(lines)
+
     if len(lines) != 0:
-        num_processes = min(args.num_processes, cpu_count())
+        num_processes = args.num_processes
         with Pool(processes=num_processes) as pool:
-            for _ in tqdm(pool.imap_unordered(process_line, lines), total=len(lines)):
-                pass
+            for _ in tqdm(
+                pool.imap_unordered(process_line, zip(lines, add_blank)),
+                total=len(lines),
+            ):
+                # 这里是缩进的代码块,表示循环体
+                pass  # 使用pass语句作为占位符

     print(f"bert生成完毕!, 共有{len(lines)}个bert.pt生成!")
clap_gen.py CHANGED
@@ -27,7 +27,7 @@ def process_line(line):
     device = torch.device("cpu")
     wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|")

-    clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.npy")
+    clap_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".emo.pt")
     if os.path.isfile(clap_path):
         return
config.yml CHANGED
@@ -4,7 +4,7 @@
 # 拟提供通用路径配置,统一存放数据,避免数据放得很乱
 # 每个数据集与其对应的模型存放至统一路径下,后续所有的路径配置均为相对于datasetPath的路径
 # 不填或者填空则路径为相对于项目根目录的路径
-dataset_path: "Data/"
+dataset_path: "Data/V23"

 # 模型镜像源,默认huggingface,使用openi镜像源需指定openi_token
 mirror: ""
@@ -17,16 +17,16 @@ resample:
   sampling_rate: 44100
   # 音频文件输入路径,重采样会将该路径下所有.wav音频文件重采样
   # 请填入相对于datasetPath的相对路径
-  in_dir: "audios/raw" # 相对于根目录的路径为 /datasetPath/in_dir
+  in_dir: "" # 相对于根目录的路径为 /datasetPath/in_dir
   # 音频文件重采样后输出路径
-  out_dir: "audios/wavs"
+  out_dir: ""


 # preprocess_text 数据集预处理相关配置
 # 注意, “:” 后需要加空格
 preprocess_text:
   # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
-  transcription_path: "filelists/你的数据集文本.list"
+  transcription_path: "filelists/whole.list"
   # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
   cleaned_path: ""
   # 训练集路径
@@ -34,11 +34,11 @@ preprocess_text:
   # 验证集路径
   val_path: "filelists/val.list"
   # 配置文件路径
-  config_path: "config.json"
+  config_path: "configs/config.json"
   # 每个语言的验证集条数
   val_per_lang: 4
   # 验证集最大条数,多于的会被截断并放到训练集中
-  max_val_total: 12
+  max_val_total: 800
   # 是否进行数据清洗
   clean: true

@@ -47,7 +47,7 @@ preprocess_text:
 # 注意, “:” 后需要加空格
 bert_gen:
   # 训练数据集配置文件路径
-  config_path: "config.json"
+  config_path: "configs/config.json"
   # 并行数
   num_processes: 4
   # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
@@ -60,9 +60,9 @@ bert_gen:
 # 注意, “:” 后需要加空格
 emo_gen:
   # 训练数据集配置文件路径
-  config_path: "config.json"
+  config_path: "configs/config.json"
   # 并行数
-  num_processes: 4
+  num_processes: 16
   # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
   device: "cuda"
   # 使用多卡推理
@@ -81,15 +81,15 @@ train_ms:
     # THE_ENV_VAR_YOU_NEED_TO_USE: "1234567"
   # 底模设置
   base:
-    use_base_model: false
+    use_base_model: True
     repo_id: "Stardust_minus/Bert-VITS2"
-    model_image: "Bert-VITS2_2.2-Clap底模" # openi网页的模型名
+    model_image: "Bert-VITS2_2.3底模" # openi网页的模型名
   # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
   model: "models"
   # 配置文件路径
-  config_path: "config.json"
+  config_path: "configs/config.json"
   # 训练使用的worker,不建议超过CPU核心数
-  num_workers: 16
+  num_workers: 22
   # 关闭此项可以节约接近50%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。
   spec_cache: True
   # 保存的检查点数量,多于此数目的权重会被删除来节省空间。
@@ -102,9 +102,9 @@ webui:
   # 推理设备
   device: "cuda"
   # 模型路径
-  model: "models/G_8000.pth"
+  model: "models/G_408000.pth"
   # 配置文件路径
-  config_path: "config.json"
+  config_path: "configs/config.json"
   # 端口号
   port: 7860
   # 是否公开部署,对外网开放
@@ -172,6 +172,6 @@ server:
 # 请不要在github等网站公开分享你的app id 与 key
 translate:
   # 你的APPID
-  "app_key": ""
+  "app_key": "20231117001883321"
   # 你的密钥
-  "secret_key": ""
+  "secret_key": "lMQbvZHeJveDceLof2wf"
data_utils.py CHANGED
@@ -44,10 +44,6 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         self.min_text_len = getattr(hparams, "min_text_len", 1)
         self.max_text_len = getattr(hparams, "max_text_len", 384)

-        self.empty_emo = torch.squeeze(
-            torch.load("empty_emo.npy", map_location="cpu"), dim=1
-        )
-
         random.seed(1234)
         random.shuffle(self.audiopaths_sid_text)
         self._filter()
@@ -98,14 +94,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         spec, wav = self.get_audio(audiopath)
         sid = torch.LongTensor([int(self.spk_map[sid])])

-        if np.random.rand() > 0.1:
-            emo = torch.squeeze(
-                torch.load(audiopath.replace(".wav", ".emo.npy"), map_location="cpu"),
-                dim=1,
-            )
-        else:
-            emo = self.empty_emo
-        return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert, emo)
+        return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert)

     def get_audio(self, filename):
         audio, sampling_rate = load_wav_to_torch(filename)
@@ -168,15 +157,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):

         if language_str == "ZH":
             bert = bert_ori
-            ja_bert = torch.rand(1024, len(phone))
-            en_bert = torch.rand(1024, len(phone))
+            ja_bert = torch.randn(1024, len(phone))
+            en_bert = torch.randn(1024, len(phone))
         elif language_str == "JP":
-            bert = torch.rand(1024, len(phone))
+            bert = torch.randn(1024, len(phone))
             ja_bert = bert_ori
-            en_bert = torch.rand(1024, len(phone))
+            en_bert = torch.randn(1024, len(phone))
         elif language_str == "EN":
-            bert = torch.rand(1024, len(phone))
-            ja_bert = torch.rand(1024, len(phone))
+            bert = torch.randn(1024, len(phone))
+            ja_bert = torch.randn(1024, len(phone))
             en_bert = bert_ori
         phone = torch.LongTensor(phone)
         tone = torch.LongTensor(tone)
@@ -226,7 +215,6 @@ class TextAudioSpeakerCollate:
         bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
         ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
         en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
-        emo = torch.FloatTensor(len(batch), 512)

         spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
@@ -238,7 +226,6 @@ class TextAudioSpeakerCollate:
         bert_padded.zero_()
         ja_bert_padded.zero_()
         en_bert_padded.zero_()
-        emo.zero_()

         for i in range(len(ids_sorted_decreasing)):
             row = batch[ids_sorted_decreasing[i]]
@@ -272,8 +259,6 @@ class TextAudioSpeakerCollate:
             en_bert = row[8]
             en_bert_padded[i, :, : en_bert.size(1)] = en_bert

-            emo[i, :] = row[9]
-
         return (
             text_padded,
             text_lengths,
@@ -287,7 +272,6 @@ class TextAudioSpeakerCollate:
             bert_padded,
             ja_bert_padded,
             en_bert_padded,
-            emo,
         )
export_onnx.py CHANGED
@@ -2,11 +2,13 @@ from onnx_modules import export_onnx
 import os

 if __name__ == "__main__":
-    export_path = "BertVits2.2PT"
-    model_path = "model\\G_0.pth"
-    config_path = "model\\config.json"
+    export_path = "BangDreamApi"
+    model_path = "Data/V23/models/G_621000.pth"
+    config_path = "Data/V23/configs/config.json"
+    novq = False
+    dev = False
     if not os.path.exists("onnx"):
         os.makedirs("onnx")
     if not os.path.exists(f"onnx/{export_path}"):
         os.makedirs(f"onnx/{export_path}")
-    export_onnx(export_path, model_path, config_path)
+    export_onnx(export_path, model_path, config_path, novq, dev)
infer.py CHANGED
@@ -10,7 +10,8 @@
 import torch
 import commons
 from text import cleaned_text_to_sequence, get_bert
-from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
+
+# from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
 from text.cleaner import clean_text
 import utils
 import numpy as np
@@ -32,7 +33,7 @@ from oldVersion.V101.text import symbols as V101symbols
 from oldVersion import V111, V110, V101, V200, V210

 # 当前版本信息
-latest_version = "2.2"
+latest_version = "2.3"

 # 版本兼容
 SynthesizerTrnMap = {
@@ -98,7 +99,8 @@ def get_net_g(model_path: str, version: str, device: str, hps):
     return net_g


-def get_text(text, language_str, hps, device):
+def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
+    style_text = None if style_text == "" else style_text
     # 在此处实现当前版本的get_text
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
@@ -110,21 +112,23 @@ def get_text(text, language_str, hps, device):
     for i in range(len(word2ph)):
         word2ph[i] = word2ph[i] * 2
     word2ph[0] += 1
-    bert_ori = get_bert(norm_text, word2ph, language_str, device)
+    bert_ori = get_bert(
+        norm_text, word2ph, language_str, device, style_text, style_weight
+    )
     del word2ph
     assert bert_ori.shape[-1] == len(phone), phone

     if language_str == "ZH":
         bert = bert_ori
-        ja_bert = torch.rand(1024, len(phone))
-        en_bert = torch.rand(1024, len(phone))
+        ja_bert = torch.randn(1024, len(phone))
+        en_bert = torch.randn(1024, len(phone))
     elif language_str == "JP":
-        bert = torch.rand(1024, len(phone))
+        bert = torch.randn(1024, len(phone))
         ja_bert = bert_ori
-        en_bert = torch.rand(1024, len(phone))
+        en_bert = torch.randn(1024, len(phone))
     elif language_str == "EN":
-        bert = torch.rand(1024, len(phone))
-        ja_bert = torch.rand(1024, len(phone))
+        bert = torch.randn(1024, len(phone))
+        ja_bert = torch.randn(1024, len(phone))
         en_bert = bert_ori
     else:
         raise ValueError("language_str should be ZH, JP or EN")
@@ -154,84 +158,17 @@ def infer(
     reference_audio=None,
     skip_start=False,
     skip_end=False,
+    style_text=None,
+    style_weight=0.7,
 ):
-    # 2.2版本参数位置变了
-    # 2.1 参数新增 emotion reference_audio skip_start skip_end
-    inferMap_V3 = {
-        "2.1": V210.infer,
-    }
-    # 支持中日英三语版本
-    inferMap_V2 = {
-        "2.0.2-fix": V200.infer,
-        "2.0.1": V200.infer,
-        "2.0": V200.infer,
-        "1.1.1-fix": V111.infer_fix,
-        "1.1.1": V111.infer,
-        "1.1": V110.infer,
-        "1.1.0": V110.infer,
-    }
-    # 仅支持中文版本
-    # 在测试中,并未发现两个版本的模型不能互相通用
-    inferMap_V1 = {
-        "1.0.1": V101.infer,
-        "1.0": V101.infer,
-        "1.0.0": V101.infer,
-    }
-    version = hps.version if hasattr(hps, "version") else latest_version
-    # 非当前版本,根据版本号选择合适的infer
-    if version != latest_version:
-        if version in inferMap_V3.keys():
-            return inferMap_V3[version](
-                text, sdp_ratio, noise_scale, noise_scale_w, length_scale,
-                sid, language, hps, net_g, device,
-                reference_audio, emotion, skip_start, skip_end,
-            )
-        if version in inferMap_V2.keys():
-            return inferMap_V2[version](
-                text, sdp_ratio, noise_scale, noise_scale_w, length_scale,
-                sid, language, hps, net_g, device,
-            )
-        if version in inferMap_V1.keys():
-            return inferMap_V1[version](
-                text, sdp_ratio, noise_scale, noise_scale_w, length_scale,
-                sid, hps, net_g, device,
-            )
-    # 在此处实现当前版本的推理
-    # emo = get_emo_(reference_audio, emotion, sid)
-    if isinstance(reference_audio, np.ndarray):
-        emo = get_clap_audio_feature(reference_audio, device)
-    else:
-        emo = get_clap_text_feature(emotion, device)
-    emo = torch.squeeze(emo, dim=1)

     bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
-        text, language, hps, device
+        text,
+        language,
+        hps,
+        device,
+        style_text=style_text,
+        style_weight=style_weight,
     )
     if skip_start:
         phones = phones[3:]
@@ -255,7 +192,7 @@ def infer(
         ja_bert = ja_bert.to(device).unsqueeze(0)
         en_bert = en_bert.to(device).unsqueeze(0)
         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-        emo = emo.to(device).unsqueeze(0)
+        # emo = emo.to(device).unsqueeze(0)
         del phones
         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
         audio = (
@@ -268,7 +205,6 @@ def infer(
                 bert,
                 ja_bert,
                 en_bert,
-                emo,
                 sdp_ratio=sdp_ratio,
                 noise_scale=noise_scale,
                 noise_scale_w=noise_scale_w,
@@ -278,7 +214,16 @@ def infer(
             .float()
             .numpy()
         )
-        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
+        del (
+            x_tst,
+            tones,
+            lang_ids,
+            bert,
+            x_tst_lengths,
+            speakers,
+            ja_bert,
+            en_bert,
+        )  # , emo
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         return audio
@@ -302,14 +247,14 @@ def infer_multilang(
 ):
     bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
     # emo = get_emo_(reference_audio, emotion, sid)
-    if isinstance(reference_audio, np.ndarray):
-        emo = get_clap_audio_feature(reference_audio, device)
-    else:
-        emo = get_clap_text_feature(emotion, device)
-    emo = torch.squeeze(emo, dim=1)
+    # if isinstance(reference_audio, np.ndarray):
+    #     emo = get_clap_audio_feature(reference_audio, device)
+    # else:
+    #     emo = get_clap_text_feature(emotion, device)
+    # emo = torch.squeeze(emo, dim=1)
     for idx, (txt, lang) in enumerate(zip(text, language)):
-        skip_start = (idx != 0) or (skip_start and idx == 0)
-        skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
+        _skip_start = (idx != 0) or (skip_start and idx == 0)
+        _skip_end = (idx != len(language) - 1) or skip_end
         (
             temp_bert,
             temp_ja_bert,
@@ -318,14 +263,14 @@ def infer_multilang(
             temp_tones,
             temp_lang_ids,
         ) = get_text(txt, lang, hps, device)
-        if skip_start:
+        if _skip_start:
             temp_bert = temp_bert[:, 3:]
             temp_ja_bert = temp_ja_bert[:, 3:]
             temp_en_bert = temp_en_bert[:, 3:]
             temp_phones = temp_phones[3:]
             temp_tones = temp_tones[3:]
             temp_lang_ids = temp_lang_ids[3:]
-        if skip_end:
+        if _skip_end:
             temp_bert = temp_bert[:, :-2]
             temp_ja_bert = temp_ja_bert[:, :-2]
             temp_en_bert = temp_en_bert[:, :-2]
@@ -351,7 +296,7 @@ def infer_multilang(
         bert = bert.to(device).unsqueeze(0)
         ja_bert = ja_bert.to(device).unsqueeze(0)
         en_bert = en_bert.to(device).unsqueeze(0)
-        emo = emo.to(device).unsqueeze(0)
+        # emo = emo.to(device).unsqueeze(0)
         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
         del phones
         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
@@ -365,7 +310,6 @@ def infer_multilang(
             bert,
             ja_bert,
             en_bert,
-            emo,
             sdp_ratio=sdp_ratio,
             noise_scale=noise_scale,
             noise_scale_w=noise_scale_w,
@@ -375,7 +319,16 @@ def infer_multilang(
             .float()
             .numpy()
         )
-        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
+        del (
+            x_tst,
+            tones,
+            lang_ids,
+            bert,
+            x_tst_lengths,
+            speakers,
+            ja_bert,
+            en_bert,
+        )  # , emo
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         return audio
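For callers, the practical change is the two new keyword arguments replacing the CLAP emotion inputs. A hedged sketch of an invocation follows; the names of the leading parameters (language, hps, net_g, device) are assumed from the 2.2-era call sites shown above, and the model path and speaker are illustrative values taken from other files in this commit:

import torch
import utils
from infer import get_net_g, infer

hps = utils.get_hparams_from_file("Data/V23/configs/config.json")
device = "cuda" if torch.cuda.is_available() else "cpu"
net_g = get_net_g("Data/V23/models/G_621000.pth", "2.3", device, hps)

audio = infer(
    "こんにちは",
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="燈",                 # must exist in hps.data.spk2id
    language="JP",
    hps=hps,
    net_g=net_g,
    device=device,
    style_text="嬉しい!",    # optional prompt whose BERT features are blended in
    style_weight=0.7,         # 0.7 is the default shown in the diff
)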
losses.py CHANGED
@@ -1,4 +1,6 @@
 import torch
+import torchaudio
+from transformers import AutoModel


 def feature_loss(fmap_r, fmap_g):
@@ -56,3 +58,96 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
     kl = torch.sum(kl * z_mask)
     l = kl / torch.sum(z_mask)
     return l
+
+
+class WavLMLoss(torch.nn.Module):
+    def __init__(self, model, wd, model_sr, slm_sr=16000):
+        super(WavLMLoss, self).__init__()
+        self.wavlm = AutoModel.from_pretrained(model)
+        self.wd = wd
+        self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
+        self.wavlm.eval()
+        for param in self.wavlm.parameters():
+            param.requires_grad = False
+
+    def forward(self, wav, y_rec):
+        with torch.no_grad():
+            wav_16 = self.resample(wav)
+            wav_embeddings = self.wavlm(
+                input_values=wav_16, output_hidden_states=True
+            ).hidden_states
+        y_rec_16 = self.resample(y_rec)
+        y_rec_embeddings = self.wavlm(
+            input_values=y_rec_16.squeeze(), output_hidden_states=True
+        ).hidden_states
+
+        floss = 0
+        for er, eg in zip(wav_embeddings, y_rec_embeddings):
+            floss += torch.mean(torch.abs(er - eg))
+
+        return floss.mean()
+
+    def generator(self, y_rec):
+        y_rec_16 = self.resample(y_rec)
+        y_rec_embeddings = self.wavlm(
+            input_values=y_rec_16, output_hidden_states=True
+        ).hidden_states
+        y_rec_embeddings = (
+            torch.stack(y_rec_embeddings, dim=1)
+            .transpose(-1, -2)
+            .flatten(start_dim=1, end_dim=2)
+        )
+        y_df_hat_g = self.wd(y_rec_embeddings)
+        loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
+
+        return loss_gen
+
+    def discriminator(self, wav, y_rec):
+        with torch.no_grad():
+            wav_16 = self.resample(wav)
+            wav_embeddings = self.wavlm(
+                input_values=wav_16, output_hidden_states=True
+            ).hidden_states
+            y_rec_16 = self.resample(y_rec)
+            y_rec_embeddings = self.wavlm(
+                input_values=y_rec_16, output_hidden_states=True
+            ).hidden_states
+
+            y_embeddings = (
+                torch.stack(wav_embeddings, dim=1)
+                .transpose(-1, -2)
+                .flatten(start_dim=1, end_dim=2)
+            )
+            y_rec_embeddings = (
+                torch.stack(y_rec_embeddings, dim=1)
+                .transpose(-1, -2)
+                .flatten(start_dim=1, end_dim=2)
+            )
+
+        y_d_rs = self.wd(y_embeddings)
+        y_d_gs = self.wd(y_rec_embeddings)
+
+        y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
+
+        r_loss = torch.mean((1 - y_df_hat_r) ** 2)
+        g_loss = torch.mean((y_df_hat_g) ** 2)
+
+        loss_disc_f = r_loss + g_loss
+
+        return loss_disc_f.mean()
+
+    def discriminator_forward(self, wav):
+        with torch.no_grad():
+            wav_16 = self.resample(wav)
+            wav_embeddings = self.wavlm(
+                input_values=wav_16, output_hidden_states=True
+            ).hidden_states
+            y_embeddings = (
+                torch.stack(wav_embeddings, dim=1)
+                .transpose(-1, -2)
+                .flatten(start_dim=1, end_dim=2)
+            )
+
+        y_d_rs = self.wd(y_embeddings)
+
+        return y_d_rs
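A rough sketch of how these WavLM-based terms are typically wired into a GAN-style training step. Assumptions: the local slm/wavlm-base-plus checkpoint added in this commit, the WavLMDiscriminator from models.py as the wd network, 44.1 kHz model audio, and random tensors standing in for real/generated waveforms; optimizer steps are omitted:

import torch
from losses import WavLMLoss
from models import WavLMDiscriminator

wd = WavLMDiscriminator()  # discriminator over stacked WavLM hidden states
wl = WavLMLoss("./slm/wavlm-base-plus", wd, model_sr=44100, slm_sr=16000)

wav = torch.randn(2, 44100)                          # ground-truth audio (placeholder)
y_hat = torch.randn(2, 44100, requires_grad=True)    # generator output (placeholder)

loss_slm = wl(wav, y_hat)                            # feature-matching term for the generator
loss_lm_gen = wl.generator(y_hat)                    # adversarial term for the generator
loss_slm_disc = wl.discriminator(wav, y_hat.detach())  # term for the wd update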
models.py CHANGED
@@ -40,33 +40,22 @@ class DurationDiscriminator(nn.Module):  # vits2
         self.norm_2 = modules.LayerNorm(filter_channels)
         self.dur_proj = nn.Conv1d(1, filter_channels, 1)

-        self.pre_out_conv_1 = nn.Conv1d(
-            2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        self.LSTM = nn.LSTM(
+            2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
         )
-        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
-        self.pre_out_conv_2 = nn.Conv1d(
-            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
-        )
-        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

         if gin_channels != 0:
             self.cond = nn.Conv1d(gin_channels, in_channels, 1)

-        self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
+        self.output_layer = nn.Sequential(
+            nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
+        )

-    def forward_probability(self, x, x_mask, dur, g=None):
+    def forward_probability(self, x, dur):
         dur = self.dur_proj(dur)
         x = torch.cat([x, dur], dim=1)
-        x = self.pre_out_conv_1(x * x_mask)
-        x = torch.relu(x)
-        x = self.pre_out_norm_1(x)
-        x = self.drop(x)
-        x = self.pre_out_conv_2(x * x_mask)
-        x = torch.relu(x)
-        x = self.pre_out_norm_2(x)
-        x = self.drop(x)
-        x = x * x_mask
         x = x.transpose(1, 2)
+        x, _ = self.LSTM(x)
         output_prob = self.output_layer(x)
         return output_prob

@@ -86,7 +75,7 @@ class DurationDiscriminator(nn.Module):  # vits2
         output_probs = []
         for dur in [dur_r, dur_hat]:
-            output_prob = self.forward_probability(x, x_mask, dur, g)
+            output_prob = self.forward_probability(x, dur)
             output_probs.append(output_prob)

         return output_probs
@@ -354,7 +343,6 @@ class TextEncoder(nn.Module):
         n_layers,
         kernel_size,
         p_dropout,
-        n_speakers,
         gin_channels=0,
     ):
         super().__init__()
@@ -376,31 +364,6 @@ class TextEncoder(nn.Module):
         self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
         self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
         self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
-        # self.emo_proj = nn.Linear(512, hidden_channels)
-        self.in_feature_net = nn.Sequential(
-            # input is assumed to an already normalized embedding
-            nn.Linear(512, 1028, bias=False),
-            nn.GELU(),
-            nn.LayerNorm(1028),
-            *[Block(1028, 512) for _ in range(1)],
-            nn.Linear(1028, 512, bias=False),
-            # normalize before passing to VQ?
-            # nn.GELU(),
-            # nn.LayerNorm(512),
-        )
-        self.emo_vq = VectorQuantize(
-            dim=512,
-            codebook_size=64,
-            codebook_dim=32,
-            commitment_weight=0.1,
-            decay=0.85,
-            heads=32,
-            kmeans_iters=20,
-            separate_codebook_per_head=True,
-            stochastic_sample_codes=True,
-            threshold_ema_dead_code=2,
-        )
-        self.out_feature_net = nn.Linear(512, hidden_channels)

         self.encoder = attentions.Encoder(
             hidden_channels,
@@ -413,18 +376,10 @@ class TextEncoder(nn.Module):
         )
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

-    def forward(
-        self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=None
-    ):
-        sid = sid.cpu()
+    def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
         bert_emb = self.bert_proj(bert).transpose(1, 2)
         ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
         en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
-        emo_emb = self.in_feature_net(emo)
-        emo_emb, _, loss_commit = self.emo_vq(emo_emb.unsqueeze(1))
-        loss_commit = loss_commit.mean()
-        emo_emb = self.out_feature_net(emo_emb)
-        # emo_emb = self.emo_proj(emo.unsqueeze(1))
         x = (
             self.emb(x)
             + self.tone_emb(tone)
@@ -432,7 +387,6 @@ class TextEncoder(nn.Module):
             + bert_emb
             + ja_bert_emb
             + en_bert_emb
-            + emo_emb
         ) * math.sqrt(
             self.hidden_channels
         )  # [b, t, h]
@@ -445,7 +399,7 @@ class TextEncoder(nn.Module):
         stats = self.proj(x) * x_mask

         m, logs = torch.split(stats, self.out_channels, dim=1)
-        return x, m, logs, x_mask, loss_commit
+        return x, m, logs, x_mask


 class ResidualCouplingBlock(nn.Module):
@@ -748,6 +702,55 @@ class MultiPeriodDiscriminator(torch.nn.Module):
         return y_d_rs, y_d_gs, fmap_rs, fmap_gs


+class WavLMDiscriminator(nn.Module):
+    """docstring for Discriminator."""
+
+    def __init__(
+        self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
+    ):
+        super(WavLMDiscriminator, self).__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.pre = norm_f(
+            Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
+        )
+
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv1d(
+                        initial_channel, initial_channel * 2, kernel_size=5, padding=2
+                    )
+                ),
+                norm_f(
+                    nn.Conv1d(
+                        initial_channel * 2,
+                        initial_channel * 4,
+                        kernel_size=5,
+                        padding=2,
+                    )
+                ),
+                norm_f(
+                    nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
+                ),
+            ]
+        )
+
+        self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        x = self.pre(x)
+
+        fmap = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, modules.LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x
+
+
 class ReferenceEncoder(nn.Module):
     """
     inputs --- [N, Ty/r, n_mels*r] mels
@@ -878,7 +881,6 @@ class SynthesizerTrn(nn.Module):
             n_layers,
             kernel_size,
             p_dropout,
-            self.n_speakers,
             gin_channels=self.enc_gin_channels,
         )
         self.dec = Generator(
@@ -946,14 +948,13 @@ class SynthesizerTrn(nn.Module):
         bert,
         ja_bert,
         en_bert,
-        emo=None,
     ):
         if self.n_speakers > 0:
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
         else:
             g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
-        x, m_p, logs_p, x_mask, loss_commit = self.enc_p(
-            x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
+        x, m_p, logs_p, x_mask = self.enc_p(
+            x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
         )
         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
         z_p = self.flow(z, y_mask, g=g)
@@ -996,9 +997,11 @@ class SynthesizerTrn(nn.Module):
         logw_ = torch.log(w + 1e-6) * x_mask
         logw = self.dp(x, x_mask, g=g)
+        logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
         l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
             x_mask
         )  # for averaging
+        l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)

         l_length = l_length_dp + l_length_sdp
@@ -1018,9 +1021,8 @@ class SynthesizerTrn(nn.Module):
             x_mask,
             y_mask,
             (z, z_p, m_p, logs_p, m_q, logs_q),
-            (x, logw, logw_),
+            (x, logw, logw_, logw_sdp),
             g,
-            loss_commit,
         )

     def infer(
@@ -1033,7 +1035,6 @@ class SynthesizerTrn(nn.Module):
         bert,
         ja_bert,
         en_bert,
-        emo=None,
         noise_scale=0.667,
         length_scale=1,
         noise_scale_w=0.8,
@@ -1047,8 +1048,8 @@ class SynthesizerTrn(nn.Module):
             g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
         else:
             g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
-        x, m_p, logs_p, x_mask, _ = self.enc_p(
-            x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
+        x, m_p, logs_p, x_mask = self.enc_p(
+            x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
        )
         logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
             sdp_ratio
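The new WavLMDiscriminator expects the stacked, flattened WavLM hidden states that losses.py produces: by default 13 layers of 768 channels concatenated along the channel axis. A quick shape sanity check, as a sketch with random input standing in for real embeddings:

import torch
from models import WavLMDiscriminator

wd = WavLMDiscriminator(slm_hidden=768, slm_layers=13, initial_channel=64)
frames = 50                               # number of 16 kHz WavLM frames
feats = torch.randn(2, 768 * 13, frames)  # (batch, slm_hidden * slm_layers, T)
scores = wd(feats)                        # one score per frame, flattened over time
print(scores.shape)                       # -> torch.Size([2, 50])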
onnx_infer.py ADDED
@@ -0,0 +1,68 @@
+from onnx_modules.V220_OnnxInference import OnnxInferenceSession
+import numpy as np
+
+Session = OnnxInferenceSession(
+    {
+        "enc": "onnx/BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
+        "emb_g": "onnx/BertVits2.2PT/BertVits2.2PT_emb.onnx",
+        "dp": "onnx/BertVits2.2PT/BertVits2.2PT_dp.onnx",
+        "sdp": "onnx/BertVits2.2PT/BertVits2.2PT_sdp.onnx",
+        "flow": "onnx/BertVits2.2PT/BertVits2.2PT_flow.onnx",
+        "dec": "onnx/BertVits2.2PT/BertVits2.2PT_dec.onnx",
+    },
+    Providers=["CPUExecutionProvider"],
+)
+
+# 这里的输入和原版是一样的,只需要在原版预处理结果出来之后加上.numpy()即可
+x = np.array(
+    [0, 97, 0, 8, 0, 78, 0, 8, 0, 76, 0, 37, 0, 40, 0, 97, 0,
+     8, 0, 23, 0, 8, 0, 74, 0, 26, 0, 104, 0]
+)
+tone = np.zeros_like(x)
+language = np.zeros_like(x)
+sid = np.array([0])
+bert = np.random.randn(x.shape[0], 1024)
+ja_bert = np.random.randn(x.shape[0], 1024)
+en_bert = np.random.randn(x.shape[0], 1024)
+emo = np.random.randn(512, 1)
+
+audio = Session(x, tone, language, bert, ja_bert, en_bert, emo, sid)
+
+print(audio)
re_matching.py CHANGED
@@ -44,7 +44,6 @@ def text_matching(text: str) -> list:
     result = []
     for speaker, dialogue in matches:
         result.append(extract_language_and_text_updated(speaker, dialogue))
-    print(result)
     return result
requirements.txt CHANGED
@@ -11,7 +11,7 @@ jieba
 transformers
 pypinyin
 cn2an
-gradio==3.38.0
+gradio==3.50.2
 av
 mecab-python3
 loguru
@@ -21,8 +21,7 @@ fugashi
 num2words
 PyYAML
 requests
-pyopenjtalk; sys_platform == 'linux'
-openjtalk; sys_platform != 'linux'
+pyopenjtalk-prebuilt
 jaconv
 psutil
 GPUtil
resample.py CHANGED
@@ -10,11 +10,11 @@ from config import config


 def process(item):
-    wav_name, args = item
-    wav_path = os.path.join(args.in_dir, wav_name)
+    spkdir, wav_name, args = item
+    wav_path = os.path.join(args.in_dir, spkdir, wav_name)
     if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
         wav, sr = librosa.load(wav_path, sr=args.sr)
-        soundfile.write(os.path.join(args.out_dir, wav_name), wav, sr)
+        soundfile.write(os.path.join(args.out_dir, spkdir, wav_name), wav, sr)


 if __name__ == "__main__":
@@ -54,11 +54,15 @@ if __name__ == "__main__":
     tasks = []

     for dirpath, _, filenames in os.walk(args.in_dir):
-        if not os.path.isdir(args.out_dir):
-            os.makedirs(args.out_dir, exist_ok=True)
+        # 子级目录
+        spk_dir = os.path.relpath(dirpath, args.in_dir)
+        spk_dir_out = os.path.join(args.out_dir, spk_dir)
+        if not os.path.isdir(spk_dir_out):
+            os.makedirs(spk_dir_out, exist_ok=True)
         for filename in filenames:
             if filename.lower().endswith(".wav"):
-                tasks.append((filename, args))
+                twople = (spk_dir, filename, args)
+                tasks.append(twople)

     for _ in tqdm(
         pool.imap_unordered(process, tasks),
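resample.py now mirrors per-speaker subdirectories: a file under in_dir/<speaker>/ is written to the matching out_dir/<speaker>/ path. A small sketch of just that mapping logic in isolation (the paths below are illustrative):

import os

in_dir, out_dir = "Data/V23/raw", "Data/V23/wavs"    # illustrative paths
for dirpath, _, filenames in os.walk(in_dir):
    spk_dir = os.path.relpath(dirpath, in_dir)       # e.g. "燈", or "." at the top level
    os.makedirs(os.path.join(out_dir, spk_dir), exist_ok=True)
    for name in filenames:
        if name.lower().endswith(".wav"):
            src = os.path.join(in_dir, spk_dir, name)
            dst = os.path.join(out_dir, spk_dir, name)
            print(src, "->", dst)                    # resample.py loads src and writes dst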
resample_legacy.py ADDED
@@ -0,0 +1,71 @@
+import os
+import argparse
+import librosa
+from multiprocessing import Pool, cpu_count
+
+import soundfile
+from tqdm import tqdm
+
+from config import config
+
+
+def process(item):
+    wav_name, args = item
+    wav_path = os.path.join(args.in_dir, wav_name)
+    if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
+        wav, sr = librosa.load(wav_path, sr=args.sr)
+        soundfile.write(os.path.join(args.out_dir, wav_name), wav, sr)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--sr",
+        type=int,
+        default=config.resample_config.sampling_rate,
+        help="sampling rate",
+    )
+    parser.add_argument(
+        "--in_dir",
+        type=str,
+        default=config.resample_config.in_dir,
+        help="path to source dir",
+    )
+    parser.add_argument(
+        "--out_dir",
+        type=str,
+        default=config.resample_config.out_dir,
+        help="path to target dir",
+    )
+    parser.add_argument(
+        "--processes",
+        type=int,
+        default=0,
+        help="cpu_processes",
+    )
+    args, _ = parser.parse_known_args()
+    # autodl 无卡模式会识别出46个cpu
+    if args.processes == 0:
+        processes = cpu_count() - 2 if cpu_count() > 4 else 1
+    else:
+        processes = args.processes
+    pool = Pool(processes=processes)
+
+    tasks = []
+
+    for dirpath, _, filenames in os.walk(args.in_dir):
+        if not os.path.isdir(args.out_dir):
+            os.makedirs(args.out_dir, exist_ok=True)
+        for filename in filenames:
+            if filename.lower().endswith(".wav"):
+                tasks.append((filename, args))
+
+    for _ in tqdm(
+        pool.imap_unordered(process, tasks),
+    ):
+        pass
+
+    pool.close()
+    pool.join()
+
+    print("音频重采样完毕!")
server.py CHANGED
@@ -4,9 +4,6 @@ from pathlib import Path

 import logging
 import re_matching
-import uuid
-from flask import Flask, request, jsonify, render_template_string
-from flask_cors import CORS

 logging.getLogger("numba").setLevel(logging.WARNING)
 logging.getLogger("markdown_it").setLevel(logging.WARNING)
@@ -18,6 +15,7 @@ logging.basicConfig(
 )

 logger = logging.getLogger(__name__)
+
 import librosa
 import numpy as np
 import torch
@@ -25,25 +23,31 @@ import torch.nn as nn
 from torch.utils.data import Dataset
 from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
+from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
+
+import uuid
+from flask import Flask, request, jsonify, render_template_string
+from flask_cors import CORS
+
+import gradio as gr

 import utils
 from config import config
-import requests
+
 import torch
 import commons
 from text import cleaned_text_to_sequence, get_bert
-from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
-
 from text.cleaner import clean_text
 import utils

 from models import SynthesizerTrn
 from text.symbols import symbols
 import sys
-
 from scipy.io.wavfile import write
+from threading import Thread

 net_g = None
+
 device = (
     "cuda:0"
     if torch.cuda.is_available()
@@ -54,7 +58,22 @@ device = (
     )
 )

-#device = 'cpu'
+#device = "cpu"
+BandList = {
+    "PoppinParty": ["香澄", "有咲", "たえ", "りみ", "沙綾"],
+    "Afterglow": ["蘭", "モカ", "ひまり", "巴", "つぐみ"],
+    "HelloHappyWorld": ["こころ", "美咲", "薫", "花音", "はぐみ"],
+    "PastelPalettes": ["彩", "日菜", "千聖", "イヴ", "麻弥"],
+    "Roselia": ["友希那", "紗夜", "リサ", "燐子", "あこ"],
+    "RaiseASuilen": ["レイヤ", "ロック", "ますき", "チュチュ", "パレオ"],
+    "Morfonica": ["ましろ", "瑠唯", "つくし", "七深", "透子"],
+    "MyGo": ["燈", "愛音", "そよ", "立希", "楽奈"],
+    "AveMujica": ["祥子", "睦", "海鈴", "にゃむ", "初華"],
+    "圣翔音乐学园": ["華戀", "光", "香子", "雙葉", "真晝", "純那", "克洛迪娜", "真矢", "奈奈"],
+    "凛明馆女子学校": ["珠緒", "壘", "文", "悠悠子", "一愛"],
+    "弗隆提亚艺术学校": ["艾露", "艾露露", "菈樂菲", "司", "靜羽"],
+    "西克菲尔特音乐学院": ["晶", "未知留", "八千代", "栞", "美帆"],
+}

 def get_net_g(model_path: str, device: str, hps):
     net_g = SynthesizerTrn(
@@ -68,11 +87,11 @@ def get_net_g(model_path: str, device: str, hps):
     _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
     return net_g

-
-def get_text(text, language_str, hps, device):
+def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
+    style_text = None if style_text == "" else style_text
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
-    #print(text)
+
     if hps.data.add_blank:
         phone = commons.intersperse(phone, 0)
         tone = commons.intersperse(tone, 0)
@@ -80,18 +99,24 @@ def get_text(text, language_str, hps, device):
     for i in range(len(word2ph)):
         word2ph[i] = word2ph[i] * 2
     word2ph[0] += 1
-    bert_ori = get_bert(norm_text, word2ph, language_str, device)
+    bert_ori = get_bert(
+        norm_text, word2ph, language_str, device, style_text, style_weight
+    )
     del word2ph
     assert bert_ori.shape[-1] == len(phone), phone

     if language_str == "ZH":
         bert = bert_ori
-        ja_bert = torch.zeros(1024, len(phone))
-        en_bert = torch.zeros(1024, len(phone))
+        ja_bert = torch.randn(1024, len(phone))
+        en_bert = torch.randn(1024, len(phone))
     elif language_str == "JP":
-        bert = torch.zeros(1024, len(phone))
+        bert = torch.randn(1024, len(phone))
         ja_bert = bert_ori
-        en_bert = torch.zeros(1024, len(phone))
+        en_bert = torch.randn(1024, len(phone))
+    elif language_str == "EN":
+        bert = torch.randn(1024, len(phone))
+        ja_bert = torch.randn(1024, len(phone))
+        en_bert = bert_ori
     else:
         raise ValueError("language_str should be ZH, JP or EN")
@@ -104,6 +129,7 @@ def get_text(text, language_str, hps, device):
     language = torch.LongTensor(language)
     return bert, ja_bert, en_bert, phone, tone, language

+
 def infer(
     text,
     sdp_ratio,
@@ -111,18 +137,18 @@ def infer(
     noise_scale_w,
     length_scale,
     sid,
-    reference_audio=None,
-    emotion='Happy',
+    style_text=None,
+    style_weight=0.7,
 ):

     language= 'JP' if is_japanese(text) else 'ZH'
-    if isinstance(reference_audio, np.ndarray):
-        emo = get_clap_audio_feature(reference_audio, device)
-    else:
-        emo = get_clap_text_feature(emotion, device)
-    emo = torch.squeeze(emo, dim=1)
     bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
-        text, language, hps, device
+        text,
+        language,
+        hps,
+        device,
+        style_text=style_text,
+        style_weight=style_weight,
     )
     with torch.no_grad():
         x_tst = phones.to(device).unsqueeze(0)
@@ -132,7 +158,7 @@ def infer(
         ja_bert = ja_bert.to(device).unsqueeze(0)
         en_bert = en_bert.to(device).unsqueeze(0)
         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
-        emo = emo.to(device).unsqueeze(0)
+        # emo = emo.to(device).unsqueeze(0)
         del phones
         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
         audio = (
@@ -145,7 +171,6 @@ def infer(
             bert,
             ja_bert,
             en_bert,
-            emo,
             sdp_ratio=sdp_ratio,
             noise_scale=noise_scale,
             noise_scale_w=noise_scale_w,
@@ -155,7 +180,80 @@ def infer(
             .float()
             .numpy()
         )
-        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
+        del (
+            x_tst,
+            tones,
+            lang_ids,
+            bert,
+            x_tst_lengths,
+            speakers,
+            ja_bert,
+            en_bert,
+        )  # , emo
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        return (hps.data.sampling_rate, gr.processing_utils.convert_to_16_bit_wav(audio))
+
+def inferAPI(
+    text,
+    sdp_ratio,
+    noise_scale,
+    noise_scale_w,
+    length_scale,
+    sid,
+    style_text=None,
+    style_weight=0.7,
+):
+
+    language= 'JP' if is_japanese(text) else 'ZH'
+    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
+        text,
+        language,
+        hps,
+        device,
+        style_text=style_text,
+        style_weight=style_weight,
+    )
+    with torch.no_grad():
+        x_tst = phones.to(device).unsqueeze(0)
+        tones = tones.to(device).unsqueeze(0)
+        lang_ids = lang_ids.to(device).unsqueeze(0)
+        bert = bert.to(device).unsqueeze(0)
+        ja_bert = ja_bert.to(device).unsqueeze(0)
+        en_bert = en_bert.to(device).unsqueeze(0)
+        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+        # emo = emo.to(device).unsqueeze(0)
+        del phones
+        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
+        audio = (
+            net_g.infer(
+                x_tst,
+                x_tst_lengths,
+                speakers,
+                tones,
+                lang_ids,
+                bert,
+                ja_bert,
+                en_bert,
+                sdp_ratio=sdp_ratio,
+                noise_scale=noise_scale,
+                noise_scale_w=noise_scale_w,
+                length_scale=length_scale,
+            )[0][0, 0]
+            .data.cpu()
+            .float()
+            .numpy()
+        )
+        del (
+            x_tst,
+            tones,
@@ -176,19 +274,11 @@ def loadmodel(model):
     except:
         return "error"

-def send_audio_to_server(audio_path,text):
-    url="http://127.0.0.1:3000/response"
-    files = {'file': open(audio_path, 'rb')}
-    data = {'text': text}
-    try:
-        response = requests.post(url, files=files,data=data)
-        return response.status_code, response.text
-    except Exception as e:
-        return 500, str(e)

-app = Flask(__name__)
-CORS(app)
-@app.route('/')

 def tts():
     global last_text, last_model
@@ -197,7 +287,8 @@ def tts():
     noise_scale = float(request.args.get('noise_scale', 0.6))
     noise_scale_w = float(request.args.get('noise_scale_w', 0.8))
     length_scale = float(request.args.get('length_scale', 1))
-    emotion = request.args.get('emotion', 'happy')
     text = request.args.get('text')
    is_chat = request.args.get('is_chat', 'false').lower() == 'true'
     model = request.args.get('model',modelPaths[-1])
@@ -210,7 +301,7 @@ def tts():
         <title>TTS API Documentation</title>
         </head>
         <body>
-        <iframe src="http://love.soyorin.top" style="width:100%; height:100vh; border:none;"></iframe>
         </body>
         </html>
         """)
@@ -225,9 +316,7 @@ def tts():
         write(unique_filename , 44100, silence)
     else:
         last_text = text
-        unique_filename = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale,sid = speaker, reference_audio=None, emotion=emotion)
-        status_code, response_text = send_audio_to_server(unique_filename,text)
-        print(f"Response from server: {response_text} (Status code: {status_code})")
     with open(unique_filename ,'rb') as bit:
         wav_bytes = bit.read()
     os.remove(unique_filename)
@@ -236,14 +325,16 @@ def tts():
        'Text': unique_filename .encode('utf-8')}
     return wav_bytes, 200, headers


 if __name__ == "__main__":
     languages = [ "Auto", "ZH", "JP"]
     modelPaths = []
-    for dirpath, dirnames, filenames in os.walk("Data/BangDreamV22/models/"):
         for filename in filenames:
             modelPaths.append(os.path.join(dirpath, filename))
-    hps = utils.get_hparams_from_file('Data/BangDreamV22/configs/config.json')
     net_g = get_net_g(
         model_path=modelPaths[-1], device=device, hps=hps
     )
@@ -251,4 +342,80 @@ if __name__ == "__main__":
     speakers = list(speaker_ids.keys())
     last_text = ""
     last_model = modelPaths[-1]
-    app.run(host="0.0.0.0", port=5000)
250
+ lang_ids,
251
+ bert,
252
+ x_tst_lengths,
253
+ speakers,
254
+ ja_bert,
255
+ en_bert,
256
+ ) # , emo
257
  if torch.cuda.is_available():
258
  torch.cuda.empty_cache()
259
  unique_filename = f"temp{uuid.uuid4()}.wav"
 
274
  except:
275
  return "error"
276
 
277
+ Flaskapp = Flask(__name__)
278
+ CORS(Flaskapp)
279
+ @Flaskapp.route('/')
 
 
 
 
 
 
280
 
281
+ @Flaskapp.route('/')
 
 
282
 
283
  def tts():
284
  global last_text, last_model
 
287
  noise_scale = float(request.args.get('noise_scale', 0.6))
288
  noise_scale_w = float(request.args.get('noise_scale_w', 0.8))
289
  length_scale = float(request.args.get('length_scale', 1))
290
+ style_weight = float(request.args.get('style_weight', 0.7))
291
+ style_text = request.args.get('style_text', 'happy')
292
  text = request.args.get('text')
293
  is_chat = request.args.get('is_chat', 'false').lower() == 'true'
294
  model = request.args.get('model',modelPaths[-1])
 
301
  <title>TTS API Documentation</title>
302
  </head>
303
  <body>
304
+ <iframe src="http://127.0.0.1:7860" style="width:100%; height:100vh; border:none;"></iframe>
305
  </body>
306
  </html>
307
  """)
 
316
  write(unique_filename , 44100, silence)
317
  else:
318
  last_text = text
319
+ unique_filename = inferAPI(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale,sid = speaker, style_text=style_text, style_weight=style_weight)
 
 
320
  with open(unique_filename ,'rb') as bit:
321
  wav_bytes = bit.read()
322
  os.remove(unique_filename)
 
325
  'Text': unique_filename .encode('utf-8')}
326
  return wav_bytes, 200, headers
327
 
328
+ def gradio_interface():
329
+ return app.launch(share=True)
330
 
331
  if __name__ == "__main__":
332
  languages = [ "Auto", "ZH", "JP"]
333
  modelPaths = []
334
+ for dirpath, dirnames, filenames in os.walk('Data/V23/models/'):
335
  for filename in filenames:
336
  modelPaths.append(os.path.join(dirpath, filename))
337
+ hps = utils.get_hparams_from_file('Data/V23/configs/config.json')
338
  net_g = get_net_g(
339
  model_path=modelPaths[-1], device=device, hps=hps
340
  )
 
342
  speakers = list(speaker_ids.keys())
343
  last_text = ""
344
  last_model = modelPaths[-1]
345
+ with gr.Blocks() as app:
346
+ for band in BandList:
347
+ with gr.TabItem(band):
348
+ for name in BandList[band]:
349
+ with gr.TabItem(name):
350
+ with gr.Row():
351
+ with gr.Column():
352
+ with gr.Row():
353
+ gr.Markdown(
354
+ '<div align="center">'
355
+ f'<img style="width:auto;height:400px;" src="https://mahiruoshi-bangdream-bert-vits2.hf.space/file/image/{name}.png">'
356
+ '</div>'
357
+ )
358
+ length_scale = gr.Slider(
359
+ minimum=0.1, maximum=2, value=1, step=0.01, label="语速调节"
360
+ )
361
+ with gr.Accordion(label="参数设定", open=False):
362
+ sdp_ratio = gr.Slider(
363
+ minimum=0, maximum=1, value=0.5, step=0.01, label="SDP/DP混合比"
364
+ )
365
+ noise_scale = gr.Slider(
366
+ minimum=0.1, maximum=2, value=0.6, step=0.01, label="感情调节"
367
+ )
368
+ noise_scale_w = gr.Slider(
369
+ minimum=0.1, maximum=2, value=0.667, step=0.01, label="音素长度"
370
+ )
371
+ speaker = gr.Dropdown(
372
+ choices=speakers, value=name, label="说话人"
373
+ )
374
+ with gr.Accordion(label="切换模型", open=False):
375
+ modelstrs = gr.Dropdown(label = "模型", choices = modelPaths, value = modelPaths[0], type = "value")
376
+ btnMod = gr.Button("载入模型")
377
+ statusa = gr.TextArea()
378
+ btnMod.click(loadmodel, inputs=[modelstrs], outputs = [statusa])
379
+ with gr.Column():
380
+ text = gr.TextArea(
381
+ label="输入纯日语或者中文",
382
+ placeholder="输入纯日语或者中文",
383
+ value="为什么要演奏春日影!",
384
+ )
385
+ style_text = gr.Textbox(label="辅助文本")
386
+ style_weight = gr.Slider(
387
+ minimum=0,
388
+ maximum=1,
389
+ value=0.7,
390
+ step=0.1,
391
+ label="Weight",
392
+ info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
393
+ )
394
+ btn = gr.Button("点击生成", variant="primary")
395
+ audio_output = gr.Audio(label="Output Audio")
396
+ '''
397
+ btntran = gr.Button("快速中翻日")
398
+ translateResult = gr.TextArea("从这复制翻译后的文本")
399
+ btntran.click(translate, inputs=[text], outputs = [translateResult])
400
+ '''
401
+ btn.click(
402
+ infer,
403
+ inputs=[
404
+ text,
405
+ sdp_ratio,
406
+ noise_scale,
407
+ noise_scale_w,
408
+ length_scale,
409
+ speaker,
410
+ style_text,
411
+ style_weight,
412
+ ],
413
+ outputs=[audio_output],
414
+ )
415
+
416
+ api_thread = Thread(target=Flaskapp.run, args=("0.0.0.0", 5000))
417
+ gradio_thread = Thread(target=gradio_interface)
418
+ gradio_thread.start()
419
+ print("推理页面已开启!")
420
+ api_thread.start()
421
+ print("api页面已开启!运行在5000端口")
server_fastapi.py CHANGED
@@ -5,6 +5,7 @@ import logging
5
  import gc
6
  import random
7
 
 
8
  import gradio
9
  import numpy as np
10
  import utils
@@ -203,28 +204,48 @@ if __name__ == "__main__":
203
  auto_split: bool,
204
  emotion: Optional[Union[int, str]] = None,
205
  reference_audio=None,
 
 
206
  ) -> Union[Response, Dict[str, any]]:
207
  """TTS实现函数"""
208
  # 检查模型是否存在
209
  if model_id not in loaded_models.models.keys():
 
210
  return {"status": 10, "detail": f"模型model_id={model_id}未加载"}
211
  # 检查是否提供speaker
212
  if speaker_name is None and speaker_id is None:
 
213
  return {"status": 11, "detail": "请提供speaker_name或speaker_id"}
214
  elif speaker_name is None:
215
  # 检查speaker_id是否存在
216
  if speaker_id not in loaded_models.models[model_id].id2spk.keys():
 
217
  return {"status": 12, "detail": f"角色speaker_id={speaker_id}不存在"}
218
  speaker_name = loaded_models.models[model_id].id2spk[speaker_id]
219
  # 检查speaker_name是否存在
220
  if speaker_name not in loaded_models.models[model_id].spk2id.keys():
 
221
  return {"status": 13, "detail": f"角色speaker_name={speaker_name}不存在"}
 
222
  if language is None:
223
  language = loaded_models.models[model_id].language
 
224
  if auto_translate:
 
 
 
 
 
 
 
 
225
  text = trans.translate(Sentence=text, to_Language=language.lower())
226
  if reference_audio is not None:
227
  ref_audio = BytesIO(await reference_audio.read())
 
 
 
 
228
  else:
229
  ref_audio = reference_audio
230
  if not auto_split:
@@ -242,6 +263,8 @@ if __name__ == "__main__":
242
  device=loaded_models.models[model_id].device,
243
  emotion=emotion,
244
  reference_audio=ref_audio,
 
 
245
  )
246
  audio = gradio.processing_utils.convert_to_16_bit_wav(audio)
247
  else:
@@ -263,6 +286,8 @@ if __name__ == "__main__":
263
  device=loaded_models.models[model_id].device,
264
  emotion=emotion,
265
  reference_audio=ref_audio,
 
 
266
  )
267
  )
268
  audios.append(np.zeros(int(44100 * 0.2)))
@@ -293,6 +318,8 @@ if __name__ == "__main__":
293
  auto_split: bool = Query(False, description="自动切分"),
294
  emotion: Optional[Union[int, str]] = Query(None, description="emo"),
295
  reference_audio: UploadFile = File(None),
 
 
296
  ):
297
  """语音接口,若需要上传参考音频请仅使用post请求"""
298
  logger.info(
@@ -312,6 +339,8 @@ if __name__ == "__main__":
312
  auto_split=auto_split,
313
  emotion=emotion,
314
  reference_audio=reference_audio,
 
 
315
  )
316
 
317
  @app.get("/voice")
@@ -331,6 +360,8 @@ if __name__ == "__main__":
331
  auto_translate: bool = Query(False, description="自动翻译"),
332
  auto_split: bool = Query(False, description="自动切分"),
333
  emotion: Optional[Union[int, str]] = Query(None, description="emo"),
 
 
334
  ):
335
  """语音接口"""
336
  logger.info(
@@ -349,6 +380,8 @@ if __name__ == "__main__":
349
  auto_translate=auto_translate,
350
  auto_split=auto_split,
351
  emotion=emotion,
 
 
352
  )
353
 
354
  @app.get("/models/info")
@@ -370,7 +403,9 @@ if __name__ == "__main__":
370
  )
371
  result = loaded_models.del_model(model_id)
372
  if result is None:
 
373
  return {"status": 14, "detail": f"模型{model_id}不存在,删除失败"}
 
374
  return {"status": 0, "detail": "删除成功"}
375
 
376
  @app.get("/models/add")
@@ -394,6 +429,7 @@ if __name__ == "__main__":
394
  elif os.path.isfile(os.path.join(model_dir, "../config.json")):
395
  config_path = os.path.join(model_dir, "../config.json")
396
  else:
 
397
  return {
398
  "status": 15,
399
  "detail": "查询未传入配置文件路径,同时默认路径./与../中不存在配置文件config.json。",
@@ -628,8 +664,10 @@ if __name__ == "__main__":
628
  f"{request.client.host}:{request.client.port}/tools/get_audio { unquote(str(request.query_params) )}"
629
  )
630
  if not os.path.isfile(path):
 
631
  return {"status": 18, "detail": "指定音频不存在"}
632
- if not path.endswith(".wav"):
 
633
  return {"status": 19, "detail": "非wav格式文件"}
634
  return FileResponse(path=path)
635
 
 
5
  import gc
6
  import random
7
 
8
+ import librosa
9
  import gradio
10
  import numpy as np
11
  import utils
 
204
  auto_split: bool,
205
  emotion: Optional[Union[int, str]] = None,
206
  reference_audio=None,
207
+ style_text: Optional[str] = None,
208
+ style_weight: float = 0.7,
209
  ) -> Union[Response, Dict[str, any]]:
210
  """TTS实现函数"""
211
  # 检查模型是否存在
212
  if model_id not in loaded_models.models.keys():
213
+ logger.error(f"/voice 请求错误:模型model_id={model_id}未加载")
214
  return {"status": 10, "detail": f"模型model_id={model_id}未加载"}
215
  # 检查是否提供speaker
216
  if speaker_name is None and speaker_id is None:
217
+ logger.error("/voice 请求错误:推理请求未提供speaker_name或speaker_id")
218
  return {"status": 11, "detail": "请提供speaker_name或speaker_id"}
219
  elif speaker_name is None:
220
  # 检查speaker_id是否存在
221
  if speaker_id not in loaded_models.models[model_id].id2spk.keys():
222
+ logger.error(f"/voice 请求错误:角色speaker_id={speaker_id}不存在")
223
  return {"status": 12, "detail": f"角色speaker_id={speaker_id}不存在"}
224
  speaker_name = loaded_models.models[model_id].id2spk[speaker_id]
225
  # 检查speaker_name是否存在
226
  if speaker_name not in loaded_models.models[model_id].spk2id.keys():
227
+ logger.error(f"/voice 请求错误:角色speaker_name={speaker_name}不存在")
228
  return {"status": 13, "detail": f"角色speaker_name={speaker_name}不存在"}
229
+ # 未传入则使用默认语言
230
  if language is None:
231
  language = loaded_models.models[model_id].language
232
+ # 翻译会破坏mix结构,auto也会变得无意义。不要在这两个模式下使用
233
  if auto_translate:
234
+ if language == "auto" or language == "mix":
235
+ logger.error(
236
+ f"/voice 请求错误:请勿同时使用language = {language}与auto_translate模式"
237
+ )
238
+ return {
239
+ "status": 20,
240
+ "detail": f"请勿同时使用language = {language}与auto_translate模式",
241
+ }
242
  text = trans.translate(Sentence=text, to_Language=language.lower())
243
  if reference_audio is not None:
244
  ref_audio = BytesIO(await reference_audio.read())
245
+ # 2.2 适配
246
+ if loaded_models.models[model_id].version == "2.2":
247
+ ref_audio, _ = librosa.load(ref_audio, 48000)
248
+
249
  else:
250
  ref_audio = reference_audio
251
  if not auto_split:
 
263
  device=loaded_models.models[model_id].device,
264
  emotion=emotion,
265
  reference_audio=ref_audio,
266
+ style_text=style_text,
267
+ style_weight=style_weight,
268
  )
269
  audio = gradio.processing_utils.convert_to_16_bit_wav(audio)
270
  else:
 
286
  device=loaded_models.models[model_id].device,
287
  emotion=emotion,
288
  reference_audio=ref_audio,
289
+ style_text=style_text,
290
+ style_weight=style_weight,
291
  )
292
  )
293
  audios.append(np.zeros(int(44100 * 0.2)))
 
318
  auto_split: bool = Query(False, description="自动切分"),
319
  emotion: Optional[Union[int, str]] = Query(None, description="emo"),
320
  reference_audio: UploadFile = File(None),
321
+ style_text: Optional[str] = Form(None, description="风格文本"),
322
+ style_weight: float = Query(0.7, description="风格权重"),
323
  ):
324
  """语音接口,若需要上传参考音频请仅使用post请求"""
325
  logger.info(
 
339
  auto_split=auto_split,
340
  emotion=emotion,
341
  reference_audio=reference_audio,
342
+ style_text=style_text,
343
+ style_weight=style_weight,
344
  )
345
 
346
  @app.get("/voice")
 
360
  auto_translate: bool = Query(False, description="自动翻译"),
361
  auto_split: bool = Query(False, description="自动切分"),
362
  emotion: Optional[Union[int, str]] = Query(None, description="emo"),
363
+ style_text: Optional[str] = Query(None, description="风格文本"),
364
+ style_weight: float = Query(0.7, description="风格权重"),
365
  ):
366
  """语音接口"""
367
  logger.info(
 
380
  auto_translate=auto_translate,
381
  auto_split=auto_split,
382
  emotion=emotion,
383
+ style_text=style_text,
384
+ style_weight=style_weight,
385
  )
386
 
387
  @app.get("/models/info")
 
403
  )
404
  result = loaded_models.del_model(model_id)
405
  if result is None:
406
+ logger.error(f"/models/delete 模型删除错误:模型{model_id}不存在,删除失败")
407
  return {"status": 14, "detail": f"模型{model_id}不存在,删除失败"}
408
+
409
  return {"status": 0, "detail": "删除成功"}
410
 
411
  @app.get("/models/add")
 
429
  elif os.path.isfile(os.path.join(model_dir, "../config.json")):
430
  config_path = os.path.join(model_dir, "../config.json")
431
  else:
432
+ logger.error("/models/add 模型添加失败:未在模型所在目录以及上级目录找到config.json文件")
433
  return {
434
  "status": 15,
435
  "detail": "查询未传入配置文件路径,同时默认路径./与../中不存在配置文件config.json。",
 
664
  f"{request.client.host}:{request.client.port}/tools/get_audio { unquote(str(request.query_params) )}"
665
  )
666
  if not os.path.isfile(path):
667
+ logger.error(f"/tools/get_audio 获取音频错误:指定音频{path}不存在")
668
  return {"status": 18, "detail": "指定音频不存在"}
669
+ if not path.lower().endswith(".wav"):
670
+ logger.error(f"/tools/get_audio 获取音频错误:音频{path}非wav文件")
671
  return {"status": 19, "detail": "非wav格式文件"}
672
  return FileResponse(path=path)
673
 
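For reference, the new style parameters are exposed as plain query parameters on the existing /voice GET endpoint. A minimal client call is sketched below; the host, port, model_id, and speaker name are assumptions, only the parameter names come from the diff:

# Hypothetical client call; host, port, model_id and speaker_name are assumptions.
import requests

params = {
    "model_id": 0,
    "speaker_name": "香澄",
    "text": "今日もいい天気ですね。",
    "language": "JP",
    "style_text": "嬉しい!",   # new in this commit
    "style_weight": 0.7,        # new in this commit
}
resp = requests.get("http://127.0.0.1:5000/voice", params=params)
with open("out.wav", "wb") as f:
    f.write(resp.content)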
slm/wavlm-base-plus/.gitattributes ADDED
@@ -0,0 +1,27 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
slm/wavlm-base-plus/README.md ADDED
@@ -0,0 +1,65 @@
+ ---
+ language:
+ - en
+ datasets:
+ tags:
+ - speech
+ inference: false
+ ---
+
+ # WavLM-Base-Plus
+
+ [Microsoft's WavLM](https://github.com/microsoft/unilm/tree/master/wavlm)
+
+ The base model pretrained on 16kHz sampled speech audio. When using the model, make sure that your speech input is also sampled at 16kHz.
+
+ **Note**: This model does not have a tokenizer as it was pretrained on audio alone. In order to use this model **speech recognition**, a tokenizer should be created and the model should be fine-tuned on labeled text data. Check out [this blog](https://huggingface.co/blog/fine-tune-wav2vec2-english) for more in-detail explanation of how to fine-tune the model.
+
+ The model was pre-trained on:
+
+ - 60,000 hours of [Libri-Light](https://arxiv.org/abs/1912.07875)
+ - 10,000 hours of [GigaSpeech](https://arxiv.org/abs/2106.06909)
+ - 24,000 hours of [VoxPopuli](https://arxiv.org/abs/2101.00390)
+
+ [Paper: WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900)
+
+ Authors: Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei
+
+ **Abstract**
+ *Self-supervised learning (SSL) achieves great success in speech recognition, while limited exploration has been attempted for other speech processing tasks. As speech signal contains multi-faceted information including speaker identity, paralinguistics, spoken content, etc., learning universal representations for all speech tasks is challenging. In this paper, we propose a new pre-trained model, WavLM, to solve full-stack downstream speech tasks. WavLM is built based on the HuBERT framework, with an emphasis on both spoken content modeling and speaker identity preservation. We first equip the Transformer structure with gated relative position bias to improve its capability on recognition tasks. For better speaker discrimination, we propose an utterance mixing training strategy, where additional overlapped utterances are created unsupervisely and incorporated during model training. Lastly, we scale up the training dataset from 60k hours to 94k hours. WavLM Large achieves state-of-the-art performance on the SUPERB benchmark, and brings significant improvements for various speech processing tasks on their representative benchmarks.*
+
+ The original model can be found under https://github.com/microsoft/unilm/tree/master/wavlm.
+
+ # Usage
+
+ This is an English pre-trained speech model that has to be fine-tuned on a downstream task like speech recognition or audio classification before it can be
+ used in inference. The model was pre-trained in English and should therefore perform well only in English. The model has been shown to work well on the [SUPERB benchmark](https://superbbenchmark.org/).
+
+ **Note**: The model was pre-trained on phonemes rather than characters. This means that one should make sure that the input text is converted to a sequence
+ of phonemes before fine-tuning.
+
+ ## Speech Recognition
+
+ To fine-tune the model for speech recognition, see [the official speech recognition example](https://github.com/huggingface/transformers/tree/master/examples/pytorch/speech-recognition).
+
+ ## Speech Classification
+
+ To fine-tune the model for speech classification, see [the official audio classification example](https://github.com/huggingface/transformers/tree/master/examples/pytorch/audio-classification).
+
+ ## Speaker Verification
+
+ TODO
+
+ ## Speaker Diarization
+
+ TODO
+
+ # Contribution
+
+ The model was contributed by [cywang](https://huggingface.co/cywang) and [patrickvonplaten](https://huggingface.co/patrickvonplaten).
+
+ # License
+
+ The official license can be found [here](https://github.com/microsoft/UniSpeech/blob/main/LICENSE)
+
+ ![design](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/wavlm.png)
slm/wavlm-base-plus/config.json ADDED
@@ -0,0 +1,99 @@
1
+ {
2
+ "_name_or_path": "wavlm-base-plus",
3
+ "activation_dropout": 0.0,
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ "WavLMModel"
10
+ ],
11
+ "attention_dropout": 0.1,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 256,
14
+ "codevector_dim": 256,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": false,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "sum",
45
+ "ctc_zero_infinity": false,
46
+ "diversity_loss_weight": 0.1,
47
+ "do_stable_layer_norm": false,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_norm": "group",
51
+ "feat_proj_dropout": 0.1,
52
+ "feat_quantizer_dropout": 0.0,
53
+ "final_dropout": 0.0,
54
+ "freeze_feat_extract_train": true,
55
+ "hidden_act": "gelu",
56
+ "hidden_dropout": 0.1,
57
+ "hidden_size": 768,
58
+ "initializer_range": 0.02,
59
+ "intermediate_size": 3072,
60
+ "layer_norm_eps": 1e-05,
61
+ "layerdrop": 0.05,
62
+ "mask_channel_length": 10,
63
+ "mask_channel_min_space": 1,
64
+ "mask_channel_other": 0.0,
65
+ "mask_channel_prob": 0.0,
66
+ "mask_channel_selection": "static",
67
+ "mask_feature_length": 10,
68
+ "mask_feature_min_masks": 0,
69
+ "mask_feature_prob": 0.0,
70
+ "mask_time_length": 10,
71
+ "mask_time_min_masks": 2,
72
+ "mask_time_min_space": 1,
73
+ "mask_time_other": 0.0,
74
+ "mask_time_prob": 0.05,
75
+ "mask_time_selection": "static",
76
+ "model_type": "wavlm",
77
+ "no_mask_channel_overlap": false,
78
+ "no_mask_time_overlap": false,
79
+ "num_adapter_layers": 3,
80
+ "num_attention_heads": 12,
81
+ "num_buckets": 320,
82
+ "num_codevector_groups": 2,
83
+ "num_codevectors_per_group": 320,
84
+ "num_conv_pos_embedding_groups": 16,
85
+ "num_conv_pos_embeddings": 128,
86
+ "num_ctc_classes": 80,
87
+ "num_feat_extract_layers": 7,
88
+ "num_hidden_layers": 12,
89
+ "num_negatives": 100,
90
+ "output_hidden_size": 768,
91
+ "pad_token_id": 0,
92
+ "proj_codevector_dim": 256,
93
+ "replace_prob": 0.5,
94
+ "torch_dtype": "float32",
95
+ "transformers_version": "4.13.0.dev0",
96
+ "use_weighted_layer_sum": false,
97
+ "vocab_size": 32,
98
+ "tokenizer_class": "Wav2Vec2CTCTokenizer"
99
+ }
slm/wavlm-base-plus/preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "do_normalize": false,
+   "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+   "feature_size": 1,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "return_attention_mask": true,
+   "sampling_rate": 16000
+ }
slm/wavlm-base-plus/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3bb273a6ace99408b50cfc81afdbb7ef2de02da2eab0234e18db608ce692fe51
+ size 377617425
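The four files above bundle the WavLM-Base-Plus checkpoint locally so training does not need to fetch it at runtime. A small sanity-check sketch for loading it with transformers; the 16 kHz input requirement and the 768-dim hidden size come from the configs above, everything else is illustrative:

# Sketch: load the bundled checkpoint and run one second of silence through it.
import torch
from transformers import Wav2Vec2FeatureExtractor, WavLMModel

MODEL_DIR = "./slm/wavlm-base-plus"  # path added in this commit

feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_DIR)
model = WavLMModel.from_pretrained(MODEL_DIR)

wav = torch.zeros(16000)  # the model expects 16 kHz audio
inputs = feature_extractor(wav.numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state
print(hidden.shape)  # (1, n_frames, 768)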
text/__init__.py CHANGED
@@ -18,13 +18,15 @@ def cleaned_text_to_sequence(cleaned_text, tones, language):
     return phones, tones, lang_ids


-def get_bert(norm_text, word2ph, language, device):
+def get_bert(norm_text, word2ph, language, device, style_text=None, style_weight=0.7):
     from .chinese_bert import get_bert_feature as zh_bert
     from .english_bert_mock import get_bert_feature as en_bert
     from .japanese_bert import get_bert_feature as jp_bert

     lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
-    bert = lang_bert_func_map[language](norm_text, word2ph, device)
+    bert = lang_bert_func_map[language](
+        norm_text, word2ph, device, style_text, style_weight
+    )
     return bert
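Because the two new arguments default to None and 0.7, existing callers of get_bert keep working unchanged. A small usage sketch follows; the sample sentence and device are placeholders, and the call shape mirrors get_text in the updated server code:

# Sketch of the extended get_bert call; sample text and device are placeholders.
from text import get_bert
from text.cleaner import clean_text

norm_text, phones, tones, word2ph = clean_text("今日はいい天気ですね。", "JP")
bert = get_bert(
    norm_text,
    word2ph,
    "JP",
    device="cpu",
    style_text="嬉しい!",  # optional auxiliary prompt
    style_weight=0.7,       # 0 = main text only, 1 = style text only
)
print(bert.shape)  # (1024, number of phones)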
text/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/__init__.cpython-311.pyc and b/text/__pycache__/__init__.cpython-311.pyc differ
 
text/__pycache__/bert_utils.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/bert_utils.cpython-311.pyc and b/text/__pycache__/bert_utils.cpython-311.pyc differ
 
text/__pycache__/chinese.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/chinese.cpython-311.pyc and b/text/__pycache__/chinese.cpython-311.pyc differ
 
text/__pycache__/chinese_bert.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/chinese_bert.cpython-311.pyc and b/text/__pycache__/chinese_bert.cpython-311.pyc differ
 
text/__pycache__/cleaner.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/cleaner.cpython-311.pyc and b/text/__pycache__/cleaner.cpython-311.pyc differ
 
text/__pycache__/english.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/english.cpython-311.pyc and b/text/__pycache__/english.cpython-311.pyc differ
 
text/__pycache__/english_bert_mock.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/english_bert_mock.cpython-311.pyc and b/text/__pycache__/english_bert_mock.cpython-311.pyc differ
 
text/__pycache__/japanese.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/japanese.cpython-311.pyc and b/text/__pycache__/japanese.cpython-311.pyc differ
 
text/__pycache__/japanese_bert.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/japanese_bert.cpython-311.pyc and b/text/__pycache__/japanese_bert.cpython-311.pyc differ
 
text/__pycache__/symbols.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/symbols.cpython-311.pyc and b/text/__pycache__/symbols.cpython-311.pyc differ
 
text/__pycache__/tone_sandhi.cpython-311.pyc CHANGED
Binary files a/text/__pycache__/tone_sandhi.cpython-311.pyc and b/text/__pycache__/tone_sandhi.cpython-311.pyc differ
 
text/chinese_bert.py CHANGED
@@ -12,7 +12,13 @@ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
 models = dict()


-def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
+def get_bert_feature(
+    text,
+    word2ph,
+    device=config.bert_gen_config.device,
+    style_text=None,
+    style_weight=0.7,
+):
     if (
         sys.platform == "darwin"
         and torch.backends.mps.is_available()
@@ -29,12 +35,24 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
            inputs[i] = inputs[i].to(device)
        res = models[device](**inputs, output_hidden_states=True)
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
-
+        if style_text:
+            style_inputs = tokenizer(style_text, return_tensors="pt")
+            for i in style_inputs:
+                style_inputs[i] = style_inputs[i].to(device)
+            style_res = models[device](**style_inputs, output_hidden_states=True)
+            style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
+            style_res_mean = style_res.mean(0)
    assert len(word2ph) == len(text) + 2
    word2phone = word2ph
    phone_level_feature = []
    for i in range(len(word2phone)):
-        repeat_feature = res[i].repeat(word2phone[i], 1)
+        if style_text:
+            repeat_feature = (
+                res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+                + style_res_mean.repeat(word2phone[i], 1) * style_weight
+            )
+        else:
+            repeat_feature = res[i].repeat(word2phone[i], 1)
        phone_level_feature.append(repeat_feature)

    phone_level_feature = torch.cat(phone_level_feature, dim=0)
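The style blending above is a per-character linear interpolation between the main-text BERT feature and the mean feature of the style text, repeated over that character's phones. A self-contained restatement of just that step (tensor sizes follow the 1024-dim features used here):

# Minimal restatement of the blending step added above; illustrative only.
import torch


def mix_phone_feature(res_i, style_mean, n_phones, style_weight):
    # style_weight = 0 reproduces the old behaviour, 1 uses only the style text.
    return (
        res_i.repeat(n_phones, 1) * (1 - style_weight)
        + style_mean.repeat(n_phones, 1) * style_weight
    )


out = mix_phone_feature(torch.randn(1024), torch.randn(1024), 3, 0.7)
print(out.shape)  # torch.Size([3, 1024])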
text/cleaner.py CHANGED
@@ -1,7 +1,7 @@
-from text import chinese, japanese, cleaned_text_to_sequence
+from text import chinese, japanese, english, cleaned_text_to_sequence


-language_module_map = {"ZH": chinese, "JP": japanese}
+language_module_map = {"ZH": chinese, "JP": japanese, "EN": english}


 def clean_text(text, language):
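With english registered in language_module_map, clean_text now accepts "EN" as well. A quick illustrative call; the exact phonemes depend on the bundled CMU dictionary and the g2p_en fallback:

# Illustrative: English now routes through text.english instead of raising a KeyError.
from text.cleaner import clean_text

norm_text, phones, tones, word2ph = clean_text("Hello world!", "EN")
print(phones[:6])                    # leading "_" plus the first phonemes
print(sum(word2ph) == len(phones))   # True by construction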
text/english.py CHANGED
@@ -5,6 +5,7 @@ from g2p_en import G2p
 from transformers import DebertaV2Tokenizer

 from text import symbols
+from text.symbols import punctuation

 current_file_path = os.path.dirname(__file__)
 CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
@@ -217,6 +218,8 @@ def refine_ph(phn):
     if re.search(r"\d$", phn):
         tone = int(phn[-1]) + 1
         phn = phn[:-1]
+    else:
+        tone = 3
     return phn.lower(), tone


@@ -389,45 +392,84 @@ def sep_text(text):
     return words


+def text_to_words(text):
+    tokens = tokenizer.tokenize(text)
+    words = []
+    for idx, t in enumerate(tokens):
+        if t.startswith("▁"):
+            words.append([t[1:]])
+        else:
+            if t in punctuation:
+                if idx == len(tokens) - 1:
+                    words.append([f"{t}"])
+                else:
+                    if (
+                        not tokens[idx + 1].startswith("▁")
+                        and tokens[idx + 1] not in punctuation
+                    ):
+                        if idx == 0:
+                            words.append([])
+                        words[-1].append(f"{t}")
+                    else:
+                        words.append([f"{t}"])
+            else:
+                if idx == 0:
+                    words.append([])
+                words[-1].append(f"{t}")
+    return words
+
+
 def g2p(text):
     phones = []
     tones = []
-    # word2ph = []
-    words = sep_text(text)
-    tokens = [tokenizer.tokenize(i) for i in words]
+    phone_len = []
+    # words = sep_text(text)
+    # tokens = [tokenizer.tokenize(i) for i in words]
+    words = text_to_words(text)
+
     for word in words:
-        if word.upper() in eng_dict:
-            phns, tns = refine_syllables(eng_dict[word.upper()])
-            phones.append([post_replace_ph(i) for i in phns])
-            tones.append(tns)
-            # word2ph.append(len(phns))
-        else:
-            phone_list = list(filter(lambda p: p != " ", _g2p(word)))
-            phns = []
-            tns = []
-            for ph in phone_list:
-                if ph in arpa:
-                    ph, tn = refine_ph(ph)
-                    phns.append(ph)
-                    tns.append(tn)
-                else:
-                    phns.append(ph)
-                    tns.append(0)
-            phones.append([post_replace_ph(i) for i in phns])
-            tones.append(tns)
-            # word2ph.append(len(phns))
-    # phones = [post_replace_ph(i) for i in phones]
+        temp_phones, temp_tones = [], []
+        if len(word) > 1:
+            if "'" in word:
+                word = ["".join(word)]
+        for w in word:
+            if w in punctuation:
+                temp_phones.append(w)
+                temp_tones.append(0)
+                continue
+            if w.upper() in eng_dict:
+                phns, tns = refine_syllables(eng_dict[w.upper()])
+                temp_phones += [post_replace_ph(i) for i in phns]
+                temp_tones += tns
+                # w2ph.append(len(phns))
+            else:
+                phone_list = list(filter(lambda p: p != " ", _g2p(w)))
+                phns = []
+                tns = []
+                for ph in phone_list:
+                    if ph in arpa:
+                        ph, tn = refine_ph(ph)
+                        phns.append(ph)
+                        tns.append(tn)
+                    else:
+                        phns.append(ph)
+                        tns.append(0)
+                temp_phones += [post_replace_ph(i) for i in phns]
+                temp_tones += tns
+        phones += temp_phones
+        tones += temp_tones
+        phone_len.append(len(temp_phones))
+    # phones = [post_replace_ph(i) for i in phones]

     word2ph = []
-    for token, phoneme in zip(tokens, phones):
-        phone_len = len(phoneme)
+    for token, pl in zip(words, phone_len):
         word_len = len(token)

-        aaa = distribute_phone(phone_len, word_len)
+        aaa = distribute_phone(pl, word_len)
         word2ph += aaa

-    phones = ["_"] + [j for i in phones for j in i] + ["_"]
-    tones = [0] + [j for i in tones for j in i] + [0]
+    phones = ["_"] + phones + ["_"]
+    tones = [0] + tones + [0]
     word2ph = [1] + word2ph + [1]
     assert len(phones) == len(tones), text
     assert len(phones) == sum(word2ph), text
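text_to_words regroups the DeBERTa SentencePiece tokens (the "▁"-prefixed pieces) into word-level chunks so word2ph can distribute phones per word, keeping apostrophes and mid-word punctuation attached to the preceding word. A rough illustration; the exact sub-token splits depend on the tokenizer files shipped with the repo:

# Rough illustration only; real output depends on the bundled DeBERTa vocabulary.
from text.english import text_to_words

print(text_to_words("It's ten o'clock."))
# expected shape of the result: a list of per-word token groups, roughly
# [['It', "'", 's'], ['ten'], ['o', "'", 'clock'], ['.']]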
text/english_bert_mock.py CHANGED
@@ -13,7 +13,13 @@ tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
 models = dict()


-def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
+def get_bert_feature(
+    text,
+    word2ph,
+    device=config.bert_gen_config.device,
+    style_text=None,
+    style_weight=0.7,
+):
     if (
         sys.platform == "darwin"
         and torch.backends.mps.is_available()
@@ -30,11 +36,24 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
            inputs[i] = inputs[i].to(device)
        res = models[device](**inputs, output_hidden_states=True)
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
+        if style_text:
+            style_inputs = tokenizer(style_text, return_tensors="pt")
+            for i in style_inputs:
+                style_inputs[i] = style_inputs[i].to(device)
+            style_res = models[device](**style_inputs, output_hidden_states=True)
+            style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
+            style_res_mean = style_res.mean(0)
    assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph))
    word2phone = word2ph
    phone_level_feature = []
    for i in range(len(word2phone)):
-        repeat_feature = res[i].repeat(word2phone[i], 1)
+        if style_text:
+            repeat_feature = (
+                res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+                + style_res_mean.repeat(word2phone[i], 1) * style_weight
+            )
+        else:
+            repeat_feature = res[i].repeat(word2phone[i], 1)
        phone_level_feature.append(repeat_feature)

    phone_level_feature = torch.cat(phone_level_feature, dim=0)
text/japanese_bert.py CHANGED
@@ -13,8 +13,16 @@ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
 models = dict()


-def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
+def get_bert_feature(
+    text,
+    word2ph,
+    device=config.bert_gen_config.device,
+    style_text=None,
+    style_weight=0.7,
+):
     text = "".join(text2sep_kata(text)[0])
+    if style_text:
+        style_text = "".join(text2sep_kata(style_text)[0])
     if (
         sys.platform == "darwin"
         and torch.backends.mps.is_available()
@@ -31,12 +39,25 @@ def get_bert_feature(text, word2ph, device=config.bert_gen_config.device):
            inputs[i] = inputs[i].to(device)
        res = models[device](**inputs, output_hidden_states=True)
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
+        if style_text:
+            style_inputs = tokenizer(style_text, return_tensors="pt")
+            for i in style_inputs:
+                style_inputs[i] = style_inputs[i].to(device)
+            style_res = models[device](**style_inputs, output_hidden_states=True)
+            style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
+            style_res_mean = style_res.mean(0)

    assert len(word2ph) == len(text) + 2
    word2phone = word2ph
    phone_level_feature = []
    for i in range(len(word2phone)):
-        repeat_feature = res[i].repeat(word2phone[i], 1)
+        if style_text:
+            repeat_feature = (
+                res[i].repeat(word2phone[i], 1) * (1 - style_weight)
+                + style_res_mean.repeat(word2phone[i], 1) * style_weight
+            )
+        else:
+            repeat_feature = res[i].repeat(word2phone[i], 1)
        phone_level_feature.append(repeat_feature)

    phone_level_feature = torch.cat(phone_level_feature, dim=0)
text/tone_sandhi.py CHANGED
@@ -634,9 +634,11 @@ class ToneSandhi:
     # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
     # output seg: [['听一听', 'v']]
     def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
-        new_seg = []
+        new_seg = [] * len(seg)
         # function 1
-        for i, (word, pos) in enumerate(seg):
+        i = 0
+        while i < len(seg):
+            word, pos = seg[i]
             if (
                 i - 1 >= 0
                 and word == "一"
@@ -645,6 +647,7 @@ class ToneSandhi:
                 and seg[i - 1][1] == "v"
             ):
                 new_seg[i - 1][0] = new_seg[i - 1][0] + "一" + new_seg[i - 1][0]
+                i += 2
             else:
                 if (
                     i - 2 >= 0
@@ -655,7 +658,8 @@ class ToneSandhi:
                     continue
                 else:
                     new_seg.append([word, pos])
-        seg = new_seg
+                i += 1
+        seg = [i for i in new_seg if len(i) > 0]
         new_seg = []
         # function 2
         for i, (word, pos) in enumerate(seg):
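The reworked _merge_yi walks the segment list with an explicit index so the verb + 一 + verb pattern can consume two items at once and the leftover empty entries are filtered out afterwards. The comment above the method spells out the intended behaviour, which can be checked directly:

# Checks the documented example from the comment above the method.
from text.tone_sandhi import ToneSandhi

sandhi = ToneSandhi()
print(sandhi._merge_yi([("听", "v"), ("一", "m"), ("听", "v")]))
# expected: [['听一听', 'v']]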
tools/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/tools/__pycache__/__init__.cpython-311.pyc and b/tools/__pycache__/__init__.cpython-311.pyc differ
 
tools/__pycache__/classify_language.cpython-311.pyc CHANGED
Binary files a/tools/__pycache__/classify_language.cpython-311.pyc and b/tools/__pycache__/classify_language.cpython-311.pyc differ
 
tools/__pycache__/log.cpython-311.pyc ADDED
Binary file (547 Bytes). View file
 
tools/__pycache__/sentence.cpython-311.pyc CHANGED
Binary files a/tools/__pycache__/sentence.cpython-311.pyc and b/tools/__pycache__/sentence.cpython-311.pyc differ
 
tools/__pycache__/translate.cpython-311.pyc CHANGED
Binary files a/tools/__pycache__/translate.cpython-311.pyc and b/tools/__pycache__/translate.cpython-311.pyc differ
 
train_ms.py CHANGED
@@ -27,8 +27,15 @@ from models import (
27
  SynthesizerTrn,
28
  MultiPeriodDiscriminator,
29
  DurationDiscriminator,
 
 
 
 
 
 
 
 
30
  )
31
- from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
32
  from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
33
  from text.symbols import symbols
34
 
@@ -42,7 +49,6 @@ torch.backends.cuda.enable_flash_sdp(True)
42
  torch.backends.cuda.enable_mem_efficient_sdp(
43
  True
44
  ) # Not available if torch version is lower than 2.0
45
- torch.backends.cuda.enable_math_sdp(True)
46
  global_step = 0
47
 
48
 
@@ -173,6 +179,8 @@ def run():
173
  0.1,
174
  gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
175
  ).cuda(local_rank)
 
 
176
  if (
177
  "use_spk_conditioned_encoder" in hps.model.keys()
178
  and hps.model.use_spk_conditioned_encoder is True
@@ -210,6 +218,9 @@ def run():
210
  param.requires_grad = False
211
 
212
  net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
 
 
 
213
  optim_g = torch.optim.AdamW(
214
  filter(lambda p: p.requires_grad, net_g.parameters()),
215
  hps.train.learning_rate,
@@ -222,6 +233,12 @@ def run():
222
  betas=hps.train.betas,
223
  eps=hps.train.eps,
224
  )
 
 
 
 
 
 
225
  if net_dur_disc is not None:
226
  optim_dur_disc = torch.optim.AdamW(
227
  net_dur_disc.parameters(),
@@ -233,12 +250,11 @@ def run():
233
  optim_dur_disc = None
234
  net_g = DDP(net_g, device_ids=[local_rank], bucket_cap_mb=512)
235
  net_d = DDP(net_d, device_ids=[local_rank], bucket_cap_mb=512)
236
- dur_resume_lr = None
237
  if net_dur_disc is not None:
238
  net_dur_disc = DDP(
239
  net_dur_disc,
240
  device_ids=[local_rank],
241
- find_unused_parameters=True,
242
  bucket_cap_mb=512,
243
  )
244
 
@@ -250,9 +266,10 @@ def run():
250
  token=config.openi_token,
251
  mirror=config.mirror,
252
  )
253
-
254
- try:
255
- if net_dur_disc is not None:
 
256
  _, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
257
  utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
258
  net_dur_disc,
@@ -261,28 +278,32 @@ def run():
261
  if "skip_optimizer" in hps.train
262
  else True,
263
  )
264
- _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
265
- utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
266
- net_g,
267
- optim_g,
268
- skip_optimizer=hps.train.skip_optimizer
269
- if "skip_optimizer" in hps.train
270
- else True,
271
- )
272
- _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
273
- utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
274
- net_d,
275
- optim_d,
276
- skip_optimizer=hps.train.skip_optimizer
277
- if "skip_optimizer" in hps.train
278
- else True,
279
- )
280
- if not optim_g.param_groups[0].get("initial_lr"):
281
- optim_g.param_groups[0]["initial_lr"] = g_resume_lr
282
- if not optim_d.param_groups[0].get("initial_lr"):
283
- optim_d.param_groups[0]["initial_lr"] = d_resume_lr
284
  if not optim_dur_disc.param_groups[0].get("initial_lr"):
285
  optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
  epoch_str = max(epoch_str, 1)
288
  # global_step = (epoch_str - 1) * len(train_loader)
@@ -297,21 +318,43 @@ def run():
297
  epoch_str = 1
298
  global_step = 0
299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
301
  optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
302
  )
303
  scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
304
  optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
305
  )
 
 
 
306
  if net_dur_disc is not None:
307
- if not optim_dur_disc.param_groups[0].get("initial_lr"):
308
- optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
309
  scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
310
  optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
311
  )
312
  else:
313
  scheduler_dur_disc = None
314
- scaler = GradScaler(enabled=hps.train.fp16_run)
 
 
 
 
 
 
 
315
 
316
  for epoch in range(epoch_str, hps.train.epochs + 1):
317
  if rank == 0:
@@ -320,9 +363,9 @@ def run():
320
  local_rank,
321
  epoch,
322
  hps,
323
- [net_g, net_d, net_dur_disc],
324
- [optim_g, optim_d, optim_dur_disc],
325
- [scheduler_g, scheduler_d, scheduler_dur_disc],
326
  scaler,
327
  [train_loader, eval_loader],
328
  logger,
@@ -334,9 +377,9 @@ def run():
334
  local_rank,
335
  epoch,
336
  hps,
337
- [net_g, net_d, net_dur_disc],
338
- [optim_g, optim_d, optim_dur_disc],
339
- [scheduler_g, scheduler_d, scheduler_dur_disc],
340
  scaler,
341
  [train_loader, None],
342
  None,
@@ -344,6 +387,7 @@ def run():
344
  )
345
  scheduler_g.step()
346
  scheduler_d.step()
 
347
  if net_dur_disc is not None:
348
  scheduler_dur_disc.step()
349
 
@@ -361,9 +405,9 @@ def train_and_evaluate(
361
  logger,
362
  writers,
363
  ):
364
- net_g, net_d, net_dur_disc = nets
365
- optim_g, optim_d, optim_dur_disc = optims
366
- scheduler_g, scheduler_d, scheduler_dur_disc = schedulers
367
  train_loader, eval_loader = loaders
368
  if writers is not None:
369
  writer, writer_eval = writers
@@ -373,6 +417,7 @@ def train_and_evaluate(
373
 
374
  net_g.train()
375
  net_d.train()
 
376
  if net_dur_disc is not None:
377
  net_dur_disc.train()
378
  for batch_idx, (
@@ -388,7 +433,6 @@ def train_and_evaluate(
388
  bert,
389
  ja_bert,
390
  en_bert,
391
- emo,
392
  ) in enumerate(tqdm(train_loader)):
393
  if net_g.module.use_noise_scaled_mas:
394
  current_mas_noise_scale = (
@@ -411,9 +455,8 @@ def train_and_evaluate(
411
  bert = bert.cuda(local_rank, non_blocking=True)
412
  ja_bert = ja_bert.cuda(local_rank, non_blocking=True)
413
  en_bert = en_bert.cuda(local_rank, non_blocking=True)
414
- emo = emo.cuda(local_rank, non_blocking=True)
415
 
416
- with autocast(enabled=hps.train.fp16_run):
417
  (
418
  y_hat,
419
  l_length,
@@ -422,9 +465,8 @@ def train_and_evaluate(
422
  x_mask,
423
  z_mask,
424
  (z, z_p, m_p, logs_p, m_q, logs_q),
425
- (hidden_x, logw, logw_),
426
  g,
427
- loss_commit,
428
  ) = net_g(
429
  x,
430
  x_lengths,
@@ -436,7 +478,6 @@ def train_and_evaluate(
436
  bert,
437
  ja_bert,
438
  en_bert,
439
- emo,
440
  )
441
  mel = spec_to_mel_torch(
442
  spec,
@@ -450,7 +491,7 @@ def train_and_evaluate(
450
  mel, ids_slice, hps.train.segment_size // hps.data.hop_length
451
  )
452
  y_hat_mel = mel_spectrogram_torch(
453
- y_hat.squeeze(1),
454
  hps.data.filter_length,
455
  hps.data.n_mel_channels,
456
  hps.data.sampling_rate,
@@ -466,7 +507,7 @@ def train_and_evaluate(
466
 
467
  # Discriminator
468
  y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
469
- with autocast(enabled=False):
470
  loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
471
  y_d_hat_r, y_d_hat_g
472
  )
@@ -475,11 +516,20 @@ def train_and_evaluate(
475
  y_dur_hat_r, y_dur_hat_g = net_dur_disc(
476
  hidden_x.detach(),
477
  x_mask.detach(),
 
478
  logw.detach(),
 
 
 
 
 
479
  logw_.detach(),
 
480
  g.detach(),
481
  )
482
- with autocast(enabled=False):
 
 
483
  # TODO: I think need to mean using the mask, but for now, just mean all
484
  (
485
  loss_dur_disc,
@@ -490,31 +540,60 @@ def train_and_evaluate(
490
  optim_dur_disc.zero_grad()
491
  scaler.scale(loss_dur_disc_all).backward()
492
  scaler.unscale_(optim_dur_disc)
493
- commons.clip_grad_value_(net_dur_disc.parameters(), None)
 
 
 
 
 
494
  scaler.step(optim_dur_disc)
495
 
496
  optim_d.zero_grad()
497
  scaler.scale(loss_disc_all).backward()
498
  scaler.unscale_(optim_d)
 
 
499
  grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
500
  scaler.step(optim_d)
501
 
502
- with autocast(enabled=hps.train.fp16_run):
 
 
 
 
 
 
 
 
 
 
 
 
503
  # Generator
504
  y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
505
  if net_dur_disc is not None:
506
- y_dur_hat_r, y_dur_hat_g = net_dur_disc(
507
- hidden_x, x_mask, logw, logw_, g
508
- )
509
- with autocast(enabled=False):
510
  loss_dur = torch.sum(l_length.float())
511
  loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
512
  loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
513
 
514
  loss_fm = feature_loss(fmap_r, fmap_g)
515
  loss_gen, losses_gen = generator_loss(y_d_hat_g)
 
 
 
 
516
  loss_gen_all = (
517
- loss_gen + loss_fm + loss_mel + loss_dur + loss_kl + loss_commit
 
 
 
 
 
 
518
  )
519
  if net_dur_disc is not None:
520
  loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
@@ -522,6 +601,8 @@ def train_and_evaluate(
522
  optim_g.zero_grad()
523
  scaler.scale(loss_gen_all).backward()
524
  scaler.unscale_(optim_g)
 
 
525
  grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
526
  scaler.step(optim_g)
527
  scaler.update()
@@ -540,9 +621,12 @@ def train_and_evaluate(
540
  scalar_dict = {
541
  "loss/g/total": loss_gen_all,
542
  "loss/d/total": loss_disc_all,
 
543
  "learning_rate": lr,
544
  "grad_norm_d": grad_norm_d,
545
  "grad_norm_g": grad_norm_g,
 
 
546
  }
547
  scalar_dict.update(
548
  {
@@ -550,6 +634,8 @@ def train_and_evaluate(
550
  "loss/g/mel": loss_mel,
551
  "loss/g/dur": loss_dur,
552
  "loss/g/kl": loss_kl,
 
 
553
  }
554
  )
555
  scalar_dict.update(
@@ -562,6 +648,30 @@ def train_and_evaluate(
562
  {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
563
  )
564
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
  image_dict = {
566
  "slice/mel_org": utils.plot_spectrogram_to_numpy(
567
  y_mel[0].data.cpu().numpy()
@@ -599,6 +709,13 @@ def train_and_evaluate(
599
  epoch,
600
  os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
601
  )
 
 
 
 
 
 
 
602
  if net_dur_disc is not None:
603
  utils.save_checkpoint(
604
  net_dur_disc,
@@ -642,7 +759,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
642
  bert,
643
  ja_bert,
644
  en_bert,
645
- emo,
646
  ) in enumerate(eval_loader):
647
  x, x_lengths = x.cuda(), x_lengths.cuda()
648
  spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
@@ -653,7 +769,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
653
  en_bert = en_bert.cuda()
654
  tone = tone.cuda()
655
  language = language.cuda()
656
- emo = emo.cuda()
657
  for use_sdp in [True, False]:
658
  y_hat, attn, mask, *_ = generator.module.infer(
659
  x,
@@ -664,7 +779,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
664
  bert,
665
  ja_bert,
666
  en_bert,
667
- emo,
668
  y=spec,
669
  max_len=1000,
670
  sdp_ratio=0.0 if not use_sdp else 1.0,
 
27
  SynthesizerTrn,
28
  MultiPeriodDiscriminator,
29
  DurationDiscriminator,
30
+ WavLMDiscriminator,
31
+ )
32
+ from losses import (
33
+ generator_loss,
34
+ discriminator_loss,
35
+ feature_loss,
36
+ kl_loss,
37
+ WavLMLoss,
38
  )
 
39
  from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
40
  from text.symbols import symbols
41
 
 
49
  torch.backends.cuda.enable_mem_efficient_sdp(
50
  True
51
  ) # Not available if torch version is lower than 2.0
 
52
  global_step = 0
53
 
54
 
 
179
  0.1,
180
  gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
181
  ).cuda(local_rank)
182
+ else:
183
+ net_dur_disc = None
184
  if (
185
  "use_spk_conditioned_encoder" in hps.model.keys()
186
  and hps.model.use_spk_conditioned_encoder is True
 
218
  param.requires_grad = False
219
 
220
  net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
221
+ net_wd = WavLMDiscriminator(
222
+ hps.model.slm.hidden, hps.model.slm.nlayers, hps.model.slm.initial_channel
223
+ ).cuda(local_rank)
224
  optim_g = torch.optim.AdamW(
225
  filter(lambda p: p.requires_grad, net_g.parameters()),
226
  hps.train.learning_rate,
 
233
  betas=hps.train.betas,
234
  eps=hps.train.eps,
235
  )
236
+ optim_wd = torch.optim.AdamW(
237
+ net_wd.parameters(),
238
+ hps.train.learning_rate,
239
+ betas=hps.train.betas,
240
+ eps=hps.train.eps,
241
+ )
242
  if net_dur_disc is not None:
243
  optim_dur_disc = torch.optim.AdamW(
244
  net_dur_disc.parameters(),
 
250
  optim_dur_disc = None
251
  net_g = DDP(net_g, device_ids=[local_rank], bucket_cap_mb=512)
252
  net_d = DDP(net_d, device_ids=[local_rank], bucket_cap_mb=512)
253
+ net_wd = DDP(net_wd, device_ids=[local_rank], bucket_cap_mb=512)
254
  if net_dur_disc is not None:
255
  net_dur_disc = DDP(
256
  net_dur_disc,
257
  device_ids=[local_rank],
 
258
  bucket_cap_mb=512,
259
  )
260
 
 
266
  token=config.openi_token,
267
  mirror=config.mirror,
268
  )
269
+ dur_resume_lr = hps.train.learning_rate
270
+ wd_resume_lr = hps.train.learning_rate
271
+ if net_dur_disc is not None:
272
+ try:
273
  _, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
274
  utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
275
  net_dur_disc,
 
278
  if "skip_optimizer" in hps.train
279
  else True,
280
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
  if not optim_dur_disc.param_groups[0].get("initial_lr"):
  optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
+ except:
+ print("Initialize dur_disc")
+
+ try:
+ _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
+ utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
+ net_g,
+ optim_g,
+ skip_optimizer=hps.train.skip_optimizer
+ if "skip_optimizer" in hps.train
+ else True,
+ )
+ _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
+ utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
+ net_d,
+ optim_d,
+ skip_optimizer=hps.train.skip_optimizer
+ if "skip_optimizer" in hps.train
+ else True,
+ )
+ if not optim_g.param_groups[0].get("initial_lr"):
+ optim_g.param_groups[0]["initial_lr"] = g_resume_lr
+ if not optim_d.param_groups[0].get("initial_lr"):
+ optim_d.param_groups[0]["initial_lr"] = d_resume_lr

  epoch_str = max(epoch_str, 1)
  # global_step = (epoch_str - 1) * len(train_loader)

  epoch_str = 1
  global_step = 0

+ try:
+ _, optim_wd, wd_resume_lr, epoch_str = utils.load_checkpoint(
+ utils.latest_checkpoint_path(hps.model_dir, "WD_*.pth"),
+ net_wd,
+ optim_wd,
+ skip_optimizer=hps.train.skip_optimizer
+ if "skip_optimizer" in hps.train
+ else True,
+ )
+ if not optim_wd.param_groups[0].get("initial_lr"):
+ optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr
+ except Exception as e:
+ print(e)
+
  scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
  optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
  )
  scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
  optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
  )
+ scheduler_wd = torch.optim.lr_scheduler.ExponentialLR(
+ optim_wd, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
+ )
  if net_dur_disc is not None:

  scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
  optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
  )
  else:
  scheduler_dur_disc = None
+ scaler = GradScaler(enabled=hps.train.bf16_run)
+
+ wl = WavLMLoss(
+ hps.model.slm.model,
+ net_wd,
+ hps.data.sampling_rate,
+ hps.model.slm.sr,
+ ).to(local_rank)
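The GradScaler and the autocast blocks added throughout the training step all key off hps.train.bf16_run. Here is a minimal, self-contained sketch of that pattern, with a plain nn.Linear as a stand-in model, CPU autocast, and the scaler disabled so it runs anywhere (the script above keeps the scaler enabled and runs under CUDA):

import torch

bf16_run = True
model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), 1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled here: every scaler call becomes a no-op

x = torch.randn(4, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=bf16_run):
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()
scaler.unscale_(opt)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=500)  # mirrors the clipping added for bf16 runs
scaler.step(opt)
scaler.update()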
 
  for epoch in range(epoch_str, hps.train.epochs + 1):
  if rank == 0:

  local_rank,
  epoch,
  hps,
+ [net_g, net_d, net_dur_disc, net_wd, wl],
+ [optim_g, optim_d, optim_dur_disc, optim_wd],
+ [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
  scaler,
  [train_loader, eval_loader],
  logger,

  local_rank,
  epoch,
  hps,
+ [net_g, net_d, net_dur_disc, net_wd, wl],
+ [optim_g, optim_d, optim_dur_disc, optim_wd],
+ [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
  scaler,
  [train_loader, None],
  None,

  )
  scheduler_g.step()
  scheduler_d.step()
+ scheduler_wd.step()
  if net_dur_disc is not None:
  scheduler_dur_disc.step()

  logger,
  writers,
  ):
+ net_g, net_d, net_dur_disc, net_wd, wl = nets
+ optim_g, optim_d, optim_dur_disc, optim_wd = optims
+ scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd = schedulers
  train_loader, eval_loader = loaders
  if writers is not None:
  writer, writer_eval = writers

  net_g.train()
  net_d.train()
+ net_wd.train()
  if net_dur_disc is not None:
  net_dur_disc.train()
  for batch_idx, (

  bert,
  ja_bert,
  en_bert,

  ) in enumerate(tqdm(train_loader)):
  if net_g.module.use_noise_scaled_mas:
  current_mas_noise_scale = (

  bert = bert.cuda(local_rank, non_blocking=True)
  ja_bert = ja_bert.cuda(local_rank, non_blocking=True)
  en_bert = en_bert.cuda(local_rank, non_blocking=True)

+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
  (
  y_hat,
  l_length,

  x_mask,
  z_mask,
  (z, z_p, m_p, logs_p, m_q, logs_q),
+ (hidden_x, logw, logw_, logw_sdp),
  g,

  ) = net_g(
  x,
  x_lengths,

  bert,
  ja_bert,
  en_bert,

  )
  mel = spec_to_mel_torch(
  spec,

  mel, ids_slice, hps.train.segment_size // hps.data.hop_length
  )
  y_hat_mel = mel_spectrogram_torch(
+ y_hat.squeeze(1).float(),
  hps.data.filter_length,
  hps.data.n_mel_channels,
  hps.data.sampling_rate,

  # Discriminator
  y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
  loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
  y_d_hat_r, y_d_hat_g
  )

  y_dur_hat_r, y_dur_hat_g = net_dur_disc(
  hidden_x.detach(),
  x_mask.detach(),
+ logw_.detach(),
  logw.detach(),
+ g.detach(),
+ )
+ y_dur_hat_r_sdp, y_dur_hat_g_sdp = net_dur_disc(
+ hidden_x.detach(),
+ x_mask.detach(),
  logw_.detach(),
+ logw_sdp.detach(),
  g.detach(),
  )
+ y_dur_hat_r = y_dur_hat_r + y_dur_hat_r_sdp
+ y_dur_hat_g = y_dur_hat_g + y_dur_hat_g_sdp
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
  # TODO: I think need to mean using the mask, but for now, just mean all
  (
  loss_dur_disc,

  optim_dur_disc.zero_grad()
  scaler.scale(loss_dur_disc_all).backward()
  scaler.unscale_(optim_dur_disc)
+ # torch.nn.utils.clip_grad_norm_(
+ # parameters=net_dur_disc.parameters(), max_norm=100
+ # )
+ grad_norm_dur = commons.clip_grad_value_(
+ net_dur_disc.parameters(), None
+ )
  scaler.step(optim_dur_disc)

  optim_d.zero_grad()
  scaler.scale(loss_disc_all).backward()
  scaler.unscale_(optim_d)
+ if getattr(hps.train, "bf16_run", False):
+ torch.nn.utils.clip_grad_norm_(parameters=net_d.parameters(), max_norm=200)
  grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
  scaler.step(optim_d)
 
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
+ loss_slm = wl.discriminator(
+ y.detach().squeeze(), y_hat.detach().squeeze()
+ ).mean()
+
+ optim_wd.zero_grad()
+ scaler.scale(loss_slm).backward()
+ scaler.unscale_(optim_wd)
+ # torch.nn.utils.clip_grad_norm_(parameters=net_wd.parameters(), max_norm=200)
+ grad_norm_wd = commons.clip_grad_value_(net_wd.parameters(), None)
+ scaler.step(optim_wd)
+
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
  # Generator
  y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
  if net_dur_disc is not None:
+ _, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw, g)
+ _, y_dur_hat_g_sdp = net_dur_disc(hidden_x, x_mask, logw_, logw_sdp, g)
+ y_dur_hat_g = y_dur_hat_g + y_dur_hat_g_sdp
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
  loss_dur = torch.sum(l_length.float())
  loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
  loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl

  loss_fm = feature_loss(fmap_r, fmap_g)
  loss_gen, losses_gen = generator_loss(y_d_hat_g)
+
+ loss_lm = wl(y.detach().squeeze(), y_hat.squeeze()).mean()
+ loss_lm_gen = wl.generator(y_hat.squeeze())
+
  loss_gen_all = (
+ loss_gen
+ + loss_fm
+ + loss_mel
+ + loss_dur
+ + loss_kl
+ + loss_lm
+ + loss_lm_gen
  )
  if net_dur_disc is not None:
  loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)

  optim_g.zero_grad()
  scaler.scale(loss_gen_all).backward()
  scaler.unscale_(optim_g)
+ if getattr(hps.train, "bf16_run", False):
+ torch.nn.utils.clip_grad_norm_(parameters=net_g.parameters(), max_norm=500)
  grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
  scaler.step(optim_g)
  scaler.update()

  scalar_dict = {
  "loss/g/total": loss_gen_all,
  "loss/d/total": loss_disc_all,
+ "loss/wd/total": loss_slm,
  "learning_rate": lr,
  "grad_norm_d": grad_norm_d,
  "grad_norm_g": grad_norm_g,
+ "grad_norm_dur": grad_norm_dur,
+ "grad_norm_wd": grad_norm_wd,
  }
  scalar_dict.update(
  {

  "loss/g/mel": loss_mel,
  "loss/g/dur": loss_dur,
  "loss/g/kl": loss_kl,
+ "loss/g/lm": loss_lm,
+ "loss/g/lm_gen": loss_lm_gen,
  }
  )
  scalar_dict.update(

  {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
  )
 
+ if net_dur_disc is not None:
+ scalar_dict.update({"loss/dur_disc/total": loss_dur_disc_all})
+
+ scalar_dict.update(
+ {
+ "loss/dur_disc_g/{}".format(i): v
+ for i, v in enumerate(losses_dur_disc_g)
+ }
+ )
+ scalar_dict.update(
+ {
+ "loss/dur_disc_r/{}".format(i): v
+ for i, v in enumerate(losses_dur_disc_r)
+ }
+ )
+
+ scalar_dict.update({"loss/g/dur_gen": loss_dur_gen})
+ scalar_dict.update(
+ {
+ "loss/g/dur_gen_{}".format(i): v
+ for i, v in enumerate(losses_dur_gen)
+ }
+ )
+
  image_dict = {
  "slice/mel_org": utils.plot_spectrogram_to_numpy(
  y_mel[0].data.cpu().numpy()

  epoch,
  os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
  )
+ utils.save_checkpoint(
+ net_wd,
+ optim_wd,
+ hps.train.learning_rate,
+ epoch,
+ os.path.join(hps.model_dir, "WD_{}.pth".format(global_step)),
+ )
  if net_dur_disc is not None:
  utils.save_checkpoint(
  net_dur_disc,

  bert,
  ja_bert,
  en_bert,

  ) in enumerate(eval_loader):
  x, x_lengths = x.cuda(), x_lengths.cuda()
  spec, spec_lengths = spec.cuda(), spec_lengths.cuda()

  en_bert = en_bert.cuda()
  tone = tone.cuda()
  language = language.cuda()

  for use_sdp in [True, False]:
  y_hat, attn, mask, *_ = generator.module.infer(
  x,

  bert,
  ja_bert,
  en_bert,

  y=spec,
  max_len=1000,
  sdp_ratio=0.0 if not use_sdp else 1.0,
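Taken together, the changes to train_ms.py bolt a third adversarial game onto the usual VITS objectives: net_wd scores WavLM-derived features of real versus generated audio (trained with loss_slm), while the generator receives the matching loss_lm and loss_lm_gen terms on top of its existing losses. Below is a toy, self-contained sketch of that two-step update; the linear layers and least-squares losses are stand-ins, not the project's WavLMLoss / WavLMDiscriminator modules.

import torch
import torch.nn as nn
import torch.nn.functional as F

net_g = nn.Linear(16, 16)   # stand-in generator
net_wd = nn.Linear(16, 1)   # stand-in SLM (WavLM) discriminator head
optim_g = torch.optim.AdamW(net_g.parameters(), 2e-4)
optim_wd = torch.optim.AdamW(net_wd.parameters(), 2e-4)

y = torch.randn(8, 16)              # "real" audio features
y_hat = net_g(torch.randn(8, 16))   # "generated" audio features

# 1) discriminator step: push real scores toward 1 and fake scores toward 0
loss_slm = ((net_wd(y.detach()) - 1) ** 2).mean() + (net_wd(y_hat.detach()) ** 2).mean()
optim_wd.zero_grad()
loss_slm.backward()
optim_wd.step()

# 2) generator step: the adversarial SLM term is simply added to the other generator losses
loss_lm_gen = ((net_wd(y_hat) - 1) ** 2).mean()
loss_gen_all = F.l1_loss(y_hat, y) + loss_lm_gen
optim_g.zero_grad()
loss_gen_all.backward()
optim_g.step()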
utils.py CHANGED
@@ -301,7 +301,11 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim

  to_del = [
  os.path.join(path_to_models, fn)
- for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
+ for fn in (
+ x_sorted("G")[:-n_ckpts_to_keep]
+ + x_sorted("D")[:-n_ckpts_to_keep]
+ + x_sorted("WD")[:-n_ckpts_to_keep]
+ )
  ]

  def del_info(fn):
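For context, clean_checkpoints keeps only the newest n_ckpts_to_keep files per prefix and deletes the rest, and WD_*.pth is now pruned alongside G_*.pth and D_*.pth. A rough standalone approximation that sorts by the step number embedded in the filename rather than by the project's x_sorted helper:

import glob
import os
import re


def stale_checkpoints(path_to_models, prefix, n_ckpts_to_keep=2):
    """Return the oldest '<prefix>_<step>.pth' files, keeping the newest n."""
    files = glob.glob(os.path.join(path_to_models, f"{prefix}_*.pth"))
    files.sort(key=lambda fn: int(re.search(r"_(\d+)\.pth$", fn).group(1)))
    return files[:-n_ckpts_to_keep]


to_del = [
    fn
    for prefix in ("G", "D", "WD")
    for fn in stale_checkpoints("logs/44k", prefix)
]
print(to_del)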
webui.py CHANGED
@@ -42,6 +42,8 @@ def generate_audio(
42
  language,
43
  reference_audio,
44
  emotion,
 
 
45
  skip_start=False,
46
  skip_end=False,
47
  ):
@@ -49,8 +51,8 @@ def generate_audio(
49
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
50
  with torch.no_grad():
51
  for idx, piece in enumerate(slices):
52
- skip_start = (idx != 0) and skip_start
53
- skip_end = (idx != len(slices) - 1) and skip_end
54
  audio = infer(
55
  piece,
56
  reference_audio=reference_audio,
@@ -66,10 +68,11 @@ def generate_audio(
66
  device=device,
67
  skip_start=skip_start,
68
  skip_end=skip_end,
 
 
69
  )
70
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
71
  audio_list.append(audio16bit)
72
- # audio_list.append(silence) # 将静音添加到列表中
73
  return audio_list
74
 
75
 
@@ -90,8 +93,8 @@ def generate_audio_multilang(
90
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
91
  with torch.no_grad():
92
  for idx, piece in enumerate(slices):
93
- skip_start = (idx != 0) and skip_start
94
- skip_end = (idx != len(slices) - 1) and skip_end
95
  audio = infer_multilang(
96
  piece,
97
  reference_audio=reference_audio,
@@ -110,7 +113,6 @@ def generate_audio_multilang(
110
  )
111
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
112
  audio_list.append(audio16bit)
113
- # audio_list.append(silence) # 将静音添加到列表中
114
  return audio_list
115
 
116
 
@@ -127,63 +129,50 @@ def tts_split(
127
  interval_between_sent,
128
  reference_audio,
129
  emotion,
 
 
130
  ):
131
- if language == "mix":
132
- return ("invalid", None)
133
  while text.find("\n\n") != -1:
134
  text = text.replace("\n\n", "\n")
 
135
  para_list = re_matching.cut_para(text)
 
136
  audio_list = []
137
- if not cut_by_sent:
138
- for idx, p in enumerate(para_list):
139
- skip_start = idx != 0
140
- skip_end = idx != len(para_list) - 1
141
- audio = infer(
142
  p,
143
- reference_audio=reference_audio,
144
- emotion=emotion,
145
- sdp_ratio=sdp_ratio,
146
- noise_scale=noise_scale,
147
- noise_scale_w=noise_scale_w,
148
- length_scale=length_scale,
149
- sid=speaker,
150
- language=language,
151
- hps=hps,
152
- net_g=net_g,
153
- device=device,
154
- skip_start=skip_start,
155
- skip_end=skip_end,
156
  )
157
- audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
158
- audio_list.append(audio16bit)
159
  silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
160
  audio_list.append(silence)
161
- else:
162
- for idx, p in enumerate(para_list):
163
- skip_start = idx != 0
164
- skip_end = idx != len(para_list) - 1
165
  audio_list_sent = []
166
  sent_list = re_matching.cut_sent(p)
167
- for idx, s in enumerate(sent_list):
168
- skip_start = (idx != 0) and skip_start
169
- skip_end = (idx != len(sent_list) - 1) and skip_end
170
- audio = infer(
171
  s,
172
- reference_audio=reference_audio,
173
- emotion=emotion,
174
- sdp_ratio=sdp_ratio,
175
- noise_scale=noise_scale,
176
- noise_scale_w=noise_scale_w,
177
- length_scale=length_scale,
178
- sid=speaker,
179
- language=language,
180
- hps=hps,
181
- net_g=net_g,
182
- device=device,
183
- skip_start=skip_start,
184
- skip_end=skip_end,
185
  )
186
- audio_list_sent.append(audio)
187
  silence = np.zeros((int)(44100 * interval_between_sent))
188
  audio_list_sent.append(silence)
189
  if (interval_between_para - interval_between_sent) > 0:
@@ -196,10 +185,47 @@ def tts_split(
196
  ) # 对完整句子做音量归一
197
  audio_list.append(audio16bit)
198
  audio_concat = np.concatenate(audio_list)
199
- return ("Success", (44100, audio_concat))
200
 
201
 
202
- def tts_fn(
203
  text: str,
204
  speaker,
205
  sdp_ratio,
@@ -209,15 +235,9 @@ def tts_fn(
209
  language,
210
  reference_audio,
211
  emotion,
212
- prompt_mode,
 
213
  ):
214
- if prompt_mode == "Audio prompt":
215
- if reference_audio == None:
216
- return ("Invalid audio prompt", None)
217
- else:
218
- reference_audio = load_audio(reference_audio)[1]
219
- else:
220
- reference_audio = None
221
  audio_list = []
222
  if language == "mix":
223
  bool_valid, str_valid = re_matching.validate_text(text)
@@ -226,120 +246,40 @@ def tts_fn(
226
  hps.data.sampling_rate,
227
  np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
228
  )
229
- result = []
230
  for slice in re_matching.text_matching(text):
231
- _speaker = slice.pop()
232
- temp_contant = []
233
- temp_lang = []
234
- for lang, content in slice:
235
- if "|" in content:
236
- temp = []
237
- temp_ = []
238
- for i in content.split("|"):
239
- if i != "":
240
- temp.append([i])
241
- temp_.append([lang])
242
- else:
243
- temp.append([])
244
- temp_.append([])
245
- temp_contant += temp
246
- temp_lang += temp_
247
- else:
248
- if len(temp_contant) == 0:
249
- temp_contant.append([])
250
- temp_lang.append([])
251
- temp_contant[-1].append(content)
252
- temp_lang[-1].append(lang)
253
- for i, j in zip(temp_lang, temp_contant):
254
- result.append([*zip(i, j), _speaker])
255
- for i, one in enumerate(result):
256
- skip_start = i != 0
257
- skip_end = i != len(result) - 1
258
- _speaker = one.pop()
259
- idx = 0
260
- while idx < len(one):
261
- text_to_generate = []
262
- lang_to_generate = []
263
- while True:
264
- lang, content = one[idx]
265
- temp_text = [content]
266
- if len(text_to_generate) > 0:
267
- text_to_generate[-1] += [temp_text.pop(0)]
268
- lang_to_generate[-1] += [lang]
269
- if len(temp_text) > 0:
270
- text_to_generate += [[i] for i in temp_text]
271
- lang_to_generate += [[lang]] * len(temp_text)
272
- if idx + 1 < len(one):
273
- idx += 1
274
- else:
275
- break
276
- skip_start = (idx != 0) and skip_start
277
- skip_end = (idx != len(one) - 1) and skip_end
278
- print(text_to_generate, lang_to_generate)
279
- audio_list.extend(
280
- generate_audio_multilang(
281
- text_to_generate,
282
- sdp_ratio,
283
- noise_scale,
284
- noise_scale_w,
285
- length_scale,
286
- _speaker,
287
- lang_to_generate,
288
- reference_audio,
289
- emotion,
290
- skip_start,
291
- skip_end,
292
- )
293
  )
294
- idx += 1
295
  elif language.lower() == "auto":
296
- for idx, slice in enumerate(text.split("|")):
297
- if slice == "":
298
- continue
299
- skip_start = idx != 0
300
- skip_end = idx != len(text.split("|")) - 1
301
- sentences_list = split_by_language(
302
- slice, target_languages=["zh", "ja", "en"]
 
 
 
 
 
 
303
  )
304
- idx = 0
305
- while idx < len(sentences_list):
306
- text_to_generate = []
307
- lang_to_generate = []
308
- while True:
309
- content, lang = sentences_list[idx]
310
- temp_text = [content]
311
- lang = lang.upper()
312
- if lang == "JA":
313
- lang = "JP"
314
- if len(text_to_generate) > 0:
315
- text_to_generate[-1] += [temp_text.pop(0)]
316
- lang_to_generate[-1] += [lang]
317
- if len(temp_text) > 0:
318
- text_to_generate += [[i] for i in temp_text]
319
- lang_to_generate += [[lang]] * len(temp_text)
320
- if idx + 1 < len(sentences_list):
321
- idx += 1
322
- else:
323
- break
324
- skip_start = (idx != 0) and skip_start
325
- skip_end = (idx != len(sentences_list) - 1) and skip_end
326
- print(text_to_generate, lang_to_generate)
327
- audio_list.extend(
328
- generate_audio_multilang(
329
- text_to_generate,
330
- sdp_ratio,
331
- noise_scale,
332
- noise_scale_w,
333
- length_scale,
334
- speaker,
335
- lang_to_generate,
336
- reference_audio,
337
- emotion,
338
- skip_start,
339
- skip_end,
340
- )
341
- )
342
- idx += 1
343
  else:
344
  audio_list.extend(
345
  generate_audio(
@@ -352,13 +292,65 @@ def tts_fn(
352
  language,
353
  reference_audio,
354
  emotion,
 
 
355
  )
356
  )
 
357
 
358
  audio_concat = np.concatenate(audio_list)
359
  return "Success", (hps.data.sampling_rate, audio_concat)
360
 
361
 
 
 
 
 
 
 
 
 
 
 
362
  def load_audio(path):
363
  audio, sr = librosa.load(path, 48000)
364
  # audio = librosa.resample(audio, 44100, 48000)
@@ -408,34 +400,37 @@ if __name__ == "__main__":
408
  )
409
  trans = gr.Button("中翻日", variant="primary")
410
  slicer = gr.Button("快速切分", variant="primary")
 
411
  speaker = gr.Dropdown(
412
  choices=speakers, value=speakers[0], label="Speaker"
413
  )
414
  _ = gr.Markdown(
415
- value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n"
 
416
  )
417
  prompt_mode = gr.Radio(
418
  ["Text prompt", "Audio prompt"],
419
  label="Prompt Mode",
420
  value="Text prompt",
 
421
  )
422
  text_prompt = gr.Textbox(
423
  label="Text prompt",
424
  placeholder="用文字描述生成风格。如:Happy",
425
  value="Happy",
426
- visible=True,
427
  )
428
  audio_prompt = gr.Audio(
429
  label="Audio prompt", type="filepath", visible=False
430
  )
431
  sdp_ratio = gr.Slider(
432
- minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
433
  )
434
  noise_scale = gr.Slider(
435
  minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
436
  )
437
  noise_scale_w = gr.Slider(
438
- minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise_W"
439
  )
440
  length_scale = gr.Slider(
441
  minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
@@ -445,6 +440,21 @@ if __name__ == "__main__":
445
  )
446
  btn = gr.Button("生成音频!", variant="primary")
447
  with gr.Column():
 
448
  with gr.Row():
449
  with gr.Column():
450
  interval_between_sent = gr.Slider(
@@ -487,6 +497,8 @@ if __name__ == "__main__":
487
  audio_prompt,
488
  text_prompt,
489
  prompt_mode,
 
 
490
  ],
491
  outputs=[text_output, audio_output],
492
  )
@@ -511,6 +523,8 @@ if __name__ == "__main__":
511
  interval_between_sent,
512
  audio_prompt,
513
  text_prompt,
 
 
514
  ],
515
  outputs=[text_output, audio_output],
516
  )
@@ -527,6 +541,12 @@ if __name__ == "__main__":
527
  outputs=[audio_prompt],
528
  )
529
 
 
 
 
 
 
 
530
  print("推理页面已开启!")
531
  webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
532
  app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
 
42
  language,
43
  reference_audio,
44
  emotion,
45
+ style_text,
46
+ style_weight,
47
  skip_start=False,
48
  skip_end=False,
49
  ):
 
51
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
52
  with torch.no_grad():
53
  for idx, piece in enumerate(slices):
54
+ skip_start = idx != 0
55
+ skip_end = idx != len(slices) - 1
56
  audio = infer(
57
  piece,
58
  reference_audio=reference_audio,
 
68
  device=device,
69
  skip_start=skip_start,
70
  skip_end=skip_end,
71
+ style_text=style_text,
72
+ style_weight=style_weight,
73
  )
74
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
75
  audio_list.append(audio16bit)
 
76
  return audio_list
77
 
78
 
 
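The style_text / style_weight pair added to generate_audio above is forwarded into infer(). Judging from the slider description later in this diff (the weight is described as the BERT mixing ratio between the main and auxiliary text, 0 = main only, 1 = auxiliary only), the intended effect is a linear blend of the two texts' BERT features; the snippet below only illustrates that reading, with random tensors in place of real BERT features.

import torch


def blend_bert(bert_main, bert_style, style_weight: float):
    # 0.0 keeps only the main text's features, 1.0 only the auxiliary text's.
    return bert_main * (1 - style_weight) + bert_style * style_weight


main = torch.randn(1024, 12)   # fake BERT features for the main text
style = torch.randn(1024, 12)  # fake BERT features for the style prompt
print(blend_bert(main, style, 0.7).shape)  # torch.Size([1024, 12])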
93
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
94
  with torch.no_grad():
95
  for idx, piece in enumerate(slices):
96
+ skip_start = idx != 0
97
+ skip_end = idx != len(slices) - 1
98
  audio = infer_multilang(
99
  piece,
100
  reference_audio=reference_audio,
 
113
  )
114
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
115
  audio_list.append(audio16bit)
 
116
  return audio_list
117
 
118
 
 
129
  interval_between_sent,
130
  reference_audio,
131
  emotion,
132
+ style_text,
133
+ style_weight,
134
  ):
 
 
135
  while text.find("\n\n") != -1:
136
  text = text.replace("\n\n", "\n")
137
+ text = text.replace("|", "")
138
  para_list = re_matching.cut_para(text)
139
+ para_list = [p for p in para_list if p != ""]
140
  audio_list = []
141
+ for p in para_list:
142
+ if not cut_by_sent:
143
+ audio_list += process_text(
 
 
144
  p,
145
+ speaker,
146
+ sdp_ratio,
147
+ noise_scale,
148
+ noise_scale_w,
149
+ length_scale,
150
+ language,
151
+ reference_audio,
152
+ emotion,
153
+ style_text,
154
+ style_weight,
 
 
 
155
  )
 
 
156
  silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
157
  audio_list.append(silence)
158
+ else:
 
 
 
159
  audio_list_sent = []
160
  sent_list = re_matching.cut_sent(p)
161
+ sent_list = [s for s in sent_list if s != ""]
162
+ for s in sent_list:
163
+ audio_list_sent += process_text(
 
164
  s,
165
+ speaker,
166
+ sdp_ratio,
167
+ noise_scale,
168
+ noise_scale_w,
169
+ length_scale,
170
+ language,
171
+ reference_audio,
172
+ emotion,
173
+ style_text,
174
+ style_weight,
 
 
 
175
  )
 
176
  silence = np.zeros((int)(44100 * interval_between_sent))
177
  audio_list_sent.append(silence)
178
  if (interval_between_para - interval_between_sent) > 0:
 
185
  ) # 对完整句子做音量归一
186
  audio_list.append(audio16bit)
187
  audio_concat = np.concatenate(audio_list)
188
+ return ("Success", (hps.data.sampling_rate, audio_concat))
189
 
190
 
191
+ def process_mix(slice):
192
+ _speaker = slice.pop()
193
+ _text, _lang = [], []
194
+ for lang, content in slice:
195
+ content = content.split("|")
196
+ content = [part for part in content if part != ""]
197
+ if len(content) == 0:
198
+ continue
199
+ if len(_text) == 0:
200
+ _text = [[part] for part in content]
201
+ _lang = [[lang] for part in content]
202
+ else:
203
+ _text[-1].append(content[0])
204
+ _lang[-1].append(lang)
205
+ if len(content) > 1:
206
+ _text += [[part] for part in content[1:]]
207
+ _lang += [[lang] for part in content[1:]]
208
+ return _text, _lang, _speaker
209
+
210
+
211
+ def process_auto(text):
212
+ _text, _lang = [], []
213
+ for slice in text.split("|"):
214
+ if slice == "":
215
+ continue
216
+ temp_text, temp_lang = [], []
217
+ sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"])
218
+ for sentence, lang in sentences_list:
219
+ if sentence == "":
220
+ continue
221
+ temp_text.append(sentence)
222
+ temp_lang.append(lang.upper())
223
+ _text.append(temp_text)
224
+ _lang.append(temp_lang)
225
+ return _text, _lang
226
+
227
+
228
+ def process_text(
229
  text: str,
230
  speaker,
231
  sdp_ratio,
 
235
  language,
236
  reference_audio,
237
  emotion,
238
+ style_text=None,
239
+ style_weight=0,
240
  ):
 
 
 
 
 
 
 
241
  audio_list = []
242
  if language == "mix":
243
  bool_valid, str_valid = re_matching.validate_text(text)
 
246
  hps.data.sampling_rate,
247
  np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
248
  )
 
249
  for slice in re_matching.text_matching(text):
250
+ _text, _lang, _speaker = process_mix(slice)
251
+ if _speaker is None:
252
+ continue
253
+ print(f"Text: {_text}\nLang: {_lang}")
254
+ audio_list.extend(
255
+ generate_audio_multilang(
256
+ _text,
257
+ sdp_ratio,
258
+ noise_scale,
259
+ noise_scale_w,
260
+ length_scale,
261
+ _speaker,
262
+ _lang,
263
+ reference_audio,
264
+ emotion,
 
265
  )
266
+ )
267
  elif language.lower() == "auto":
268
+ _text, _lang = process_auto(text)
269
+ print(f"Text: {_text}\nLang: {_lang}")
270
+ audio_list.extend(
271
+ generate_audio_multilang(
272
+ _text,
273
+ sdp_ratio,
274
+ noise_scale,
275
+ noise_scale_w,
276
+ length_scale,
277
+ speaker,
278
+ _lang,
279
+ reference_audio,
280
+ emotion,
281
  )
282
+ )
 
283
  else:
284
  audio_list.extend(
285
  generate_audio(
 
292
  language,
293
  reference_audio,
294
  emotion,
295
+ style_text,
296
+ style_weight,
297
  )
298
  )
299
+ return audio_list
300
+
301
+
302
+ def tts_fn(
303
+ text: str,
304
+ speaker,
305
+ sdp_ratio,
306
+ noise_scale,
307
+ noise_scale_w,
308
+ length_scale,
309
+ language,
310
+ reference_audio,
311
+ emotion,
312
+ prompt_mode,
313
+ style_text=None,
314
+ style_weight=0,
315
+ ):
316
+ if style_text == "":
317
+ style_text = None
318
+ if prompt_mode == "Audio prompt":
319
+ if reference_audio == None:
320
+ return ("Invalid audio prompt", None)
321
+ else:
322
+ reference_audio = load_audio(reference_audio)[1]
323
+ else:
324
+ reference_audio = None
325
+
326
+ audio_list = process_text(
327
+ text,
328
+ speaker,
329
+ sdp_ratio,
330
+ noise_scale,
331
+ noise_scale_w,
332
+ length_scale,
333
+ language,
334
+ reference_audio,
335
+ emotion,
336
+ style_text,
337
+ style_weight,
338
+ )
339
 
340
  audio_concat = np.concatenate(audio_list)
341
  return "Success", (hps.data.sampling_rate, audio_concat)
342
 
343
 
344
+ def format_utils(text, speaker):
345
+ _text, _lang = process_auto(text)
346
+ res = f"[{speaker}]"
347
+ for lang_s, content_s in zip(_lang, _text):
348
+ for lang, content in zip(lang_s, content_s):
349
+ res += f"<{lang.lower()}>{content}"
350
+ res += "|"
351
+ return "mix", res[:-1]
352
+
353
+
354
  def load_audio(path):
355
  audio, sr = librosa.load(path, 48000)
356
  # audio = librosa.resample(audio, 44100, 48000)
 
400
  )
401
  trans = gr.Button("中翻日", variant="primary")
402
  slicer = gr.Button("快速切分", variant="primary")
403
+ formatter = gr.Button("检测语言,并整理为 MIX 格式", variant="primary")
404
  speaker = gr.Dropdown(
405
  choices=speakers, value=speakers[0], label="Speaker"
406
  )
407
  _ = gr.Markdown(
408
+ value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n",
409
+ visible=False,
410
  )
411
  prompt_mode = gr.Radio(
412
  ["Text prompt", "Audio prompt"],
413
  label="Prompt Mode",
414
  value="Text prompt",
415
+ visible=False,
416
  )
417
  text_prompt = gr.Textbox(
418
  label="Text prompt",
419
  placeholder="用文字描述生成风格。如:Happy",
420
  value="Happy",
421
+ visible=False,
422
  )
423
  audio_prompt = gr.Audio(
424
  label="Audio prompt", type="filepath", visible=False
425
  )
426
  sdp_ratio = gr.Slider(
427
+ minimum=0, maximum=1, value=0.5, step=0.1, label="SDP Ratio"
428
  )
429
  noise_scale = gr.Slider(
430
  minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
431
  )
432
  noise_scale_w = gr.Slider(
433
+ minimum=0.1, maximum=2, value=0.9, step=0.1, label="Noise_W"
434
  )
435
  length_scale = gr.Slider(
436
  minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
 
440
  )
441
  btn = gr.Button("生成音频!", variant="primary")
442
  with gr.Column():
443
+ with gr.Accordion("融合文本语义", open=False):
444
+ gr.Markdown(
445
+ value="使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n"
446
+ "**注意**:不要使用**指令式文本**(如:开心),要使用**带有强烈情感的文本**(如:我好快乐!!!)\n\n"
447
+ "效果较不明确,留空即为不使用该功能"
448
+ )
449
+ style_text = gr.Textbox(label="辅助文本")
450
+ style_weight = gr.Slider(
451
+ minimum=0,
452
+ maximum=1,
453
+ value=0.7,
454
+ step=0.1,
455
+ label="Weight",
456
+ info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
457
+ )
458
  with gr.Row():
459
  with gr.Column():
460
  interval_between_sent = gr.Slider(
 
497
  audio_prompt,
498
  text_prompt,
499
  prompt_mode,
500
+ style_text,
501
+ style_weight,
502
  ],
503
  outputs=[text_output, audio_output],
504
  )
 
523
  interval_between_sent,
524
  audio_prompt,
525
  text_prompt,
526
+ style_text,
527
+ style_weight,
528
  ],
529
  outputs=[text_output, audio_output],
530
  )
 
541
  outputs=[audio_prompt],
542
  )
543
 
544
+ formatter.click(
545
+ format_utils,
546
+ inputs=[text, speaker],
547
+ outputs=[language, text],
548
+ )
549
+
550
  print("推理页面已开启!")
551
  webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
552
  app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
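The new formatter button goes through process_auto and format_utils to rewrite free text into the webui's MIX syntax, i.e. [speaker]<lang>text ... with | between slices. A stripped-down version of that string assembly, taking pre-tagged sentences instead of calling tools.split_by_language:

def format_mix(tagged_slices, speaker):
    # tagged_slices: one list of (text, lang) pairs per "|"-separated slice
    parts = [
        "".join(f"<{lang.lower()}>{text}" for text, lang in slice_pairs)
        for slice_pairs in tagged_slices
    ]
    return "mix", f"[{speaker}]" + "|".join(parts)


print(format_mix([[("你好", "ZH"), ("hello", "EN")], [("テスト", "JP")]], "TestSpeaker"))
# ('mix', '[TestSpeaker]<zh>你好<en>hello|<jp>テスト')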
webui_preprocess.py CHANGED
@@ -19,9 +19,9 @@ def generate_config(data_dir, batch_size):
  assert data_dir != "", "数据集名称不能为空"
  start_path, _, train_path, val_path, config_path = get_path(data_dir)
  if os.path.isfile(config_path):
- config = json.load(open(config_path))
+ config = json.load(open(config_path, "r", encoding="utf-8"))
  else:
- config = json.load(open("configs/config.json"))
+ config = json.load(open("configs/config.json", "r", encoding="utf-8"))
  config["data"]["training_files"] = train_path
  config["data"]["validation_files"] = val_path
  config["train"]["batch_size"] = batch_size
@@ -44,7 +44,7 @@ def resample(data_dir):
  in_dir = os.path.join(start_path, "raw")
  out_dir = os.path.join(start_path, "wavs")
  subprocess.run(
- f"python resample.py "
+ f"python resample_legacy.py "
  f"--sr 44100 "
  f"--in_dir {in_dir} "
  f"--out_dir {out_dir} ",
@@ -60,7 +60,9 @@ def preprocess_text(data_dir):
  with open(lbl_path, "w", encoding="utf-8") as f:
  for line in lines:
  path, spk, language, text = line.strip().split("|")
- path = os.path.join(start_path, "wavs", os.path.basename(path))
+ path = os.path.join(start_path, "wavs", os.path.basename(path)).replace(
+ "\\", "/"
+ )
  f.writelines(f"{path}|{spk}|{language}|{text}\n")
  subprocess.run(
  f"python preprocess_text.py "
@@ -83,16 +85,6 @@ def bert_gen(data_dir):
  return "BERT 特征文件生成完成"


- def clap_gen(data_dir):
- assert data_dir != "", "数据集名称不能为空"
- _, _, _, _, config_path = get_path(data_dir)
- subprocess.run(
- f"python clap_gen.py " f"--config {config_path}",
- shell=True,
- )
- return "CLAP 特征文件生成完成"
-
-
  if __name__ == "__main__":
  with gr.Blocks() as app:
  with gr.Row():
@@ -100,13 +92,13 @@ if __name__ == "__main__":
  _ = gr.Markdown(
  value="# Bert-VITS2 数据预处理\n"
  "## 预先准备:\n"
- "下载 BERT 和 CLAP 模型:\n"
+ "下载 BERT 和 WavLM 模型:\n"
  "- [中文 RoBERTa](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)\n"
  "- [日文 DeBERTa](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm)\n"
  "- [英文 DeBERTa](https://huggingface.co/microsoft/deberta-v3-large)\n"
- "- [CLAP](https://huggingface.co/laion/clap-htsat-fused)\n"
+ "- [WavLM](https://huggingface.co/microsoft/wavlm-base-plus)\n"
  "\n"
- "将 BERT 模型放置到 `bert` 文件夹下,CLAP 模型放置到 `emotional` 文件夹下,覆盖同名文件夹。\n"
+ "将 BERT 模型放置到 `bert` 文件夹下,WavLM 模型放置到 `slm` 文件夹下,覆盖同名文件夹。\n"
  "\n"
  "数据准备:\n"
  "将数据放置在 data 文件夹下,按照如下结构组织:\n"
@@ -156,12 +148,10 @@ if __name__ == "__main__":
  preprocess_text_btn = gr.Button(value="执行", variant="primary")
  _ = gr.Markdown(value="## 第四步:生成 BERT 特征文件")
  bert_gen_btn = gr.Button(value="执行", variant="primary")
- _ = gr.Markdown(value="## 第五步:生成 CLAP 特征文件")
- clap_gen_btn = gr.Button(value="执行", variant="primary")
  _ = gr.Markdown(
  value="## 训练模型及部署:\n"
  "修改根目录下的 `config.yml` 中 `dataset_path` 一项为 `data/{你的数据集名称}`\n"
- "- 训练:将[预训练模型文件](https://openi.pcl.ac.cn/Stardust_minus/Bert-VITS2/modelmanage/show_model)(`D_0.pth`、`DUR_0.pth` 和 `G_0.pth`)放到 `data/{你的数据集名称}/models` 文件夹下,执行 `torchrun --nproc_per_node=1 train_ms.py` 命令(多卡运行可参考 `run_MnodesAndMgpus.sh` 中的命令。\n"
+ "- 训练:将[预训练模型文件](https://openi.pcl.ac.cn/Stardust_minus/Bert-VITS2/modelmanage/show_model)(`D_0.pth`、`DUR_0.pth`、`WD_0.pth` 和 `G_0.pth`)放到 `data/{你的数据集名称}/models` 文件夹下,执行 `torchrun --nproc_per_node=1 train_ms.py` 命令(多卡运行可参考 `run_MnodesAndMgpus.sh` 中的命令。\n"
  "- 部署:修改根目录下的 `config.yml` 中 `webui` 下 `model` 一项为 `models/{权重文件名}.pth` (如 G_10000.pth),然后执行 `python webui.py`"
  )
@@ -171,7 +161,6 @@ if __name__ == "__main__":
  resample_btn.click(resample, inputs=[data_dir], outputs=[info])
  preprocess_text_btn.click(preprocess_text, inputs=[data_dir], outputs=[info])
  bert_gen_btn.click(bert_gen, inputs=[data_dir], outputs=[info])
- clap_gen_btn.click(clap_gen, inputs=[data_dir], outputs=[info])

  webbrowser.open("http://127.0.0.1:7860")
  app.launch(share=False, server_port=7860)
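generate_config above now reads the JSON template with an explicit UTF-8 encoding before patching in the dataset-specific fields. A sketch of that read-patch-write cycle follows; the paths are placeholders for what get_path derives from the dataset directory, not values taken from the repo, and it assumes you run it from the repository root.

import json

src = "configs/config.json"                      # template shipped with the repo
dst = "data/my_dataset/configs/config.json"      # placeholder output path

with open(src, "r", encoding="utf-8") as f:
    config = json.load(f)

config["data"]["training_files"] = "data/my_dataset/filelists/train.list"   # placeholder
config["data"]["validation_files"] = "data/my_dataset/filelists/val.list"   # placeholder
config["train"]["batch_size"] = 4

with open(dst, "w", encoding="utf-8") as f:
    json.dump(config, f, ensure_ascii=False, indent=2)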