Mahiruoshi committed on
Commit db74c3c
1 Parent(s): 078ba7e

Upload 30 files

Files changed (19)
  1. bert_gen.py +25 -18
  2. config.yml +3 -3
  3. data_utils.py +7 -24
  4. export_onnx.py +3 -1
  5. hiyoriUI.py +725 -0
  6. infer.py +90 -35
  7. losses.py +95 -0
  8. models.py +66 -67
  9. modules.py +1 -1
  10. onnx_infer.py +60 -0
  11. re_matching.py +0 -1
  12. resample.py +10 -6
  13. resample_legacy.py +71 -0
  14. server.py +733 -103
  15. test.py +36 -0
  16. train_ms.py +176 -63
  17. utils.py +5 -1
  18. webui.py +211 -174
  19. webui_preprocess.py +10 -21
bert_gen.py CHANGED
@@ -1,17 +1,16 @@
-import argparse
-from multiprocessing import Pool, cpu_count
-
 import torch
-import torch.multiprocessing as mp
-from tqdm import tqdm
-
+from multiprocessing import Pool
 import commons
 import utils
+from tqdm import tqdm
+from text import check_bert_models, cleaned_text_to_sequence, get_bert
+import argparse
+import torch.multiprocessing as mp
 from config import config
-from text import cleaned_text_to_sequence, get_bert
 
 
-def process_line(line):
+def process_line(x):
+    line, add_blank = x
     device = config.bert_gen_config.device
     if config.bert_gen_config.use_multi_device:
         rank = mp.current_process()._identity
@@ -28,18 +27,19 @@ def process_line(line):
     word2ph = [i for i in word2ph]
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
 
-    phone = commons.intersperse(phone, 0)
-    tone = commons.intersperse(tone, 0)
-    language = commons.intersperse(language, 0)
-    for i in range(len(word2ph)):
-        word2ph[i] = word2ph[i] * 2
-    word2ph[0] += 1
+    if add_blank:
+        phone = commons.intersperse(phone, 0)
+        tone = commons.intersperse(tone, 0)
+        language = commons.intersperse(language, 0)
+        for i in range(len(word2ph)):
+            word2ph[i] = word2ph[i] * 2
+        word2ph[0] += 1
 
     bert_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".bert.pt")
 
     try:
         bert = torch.load(bert_path)
-        assert bert.shape[-1] == len(phone)
+        assert bert.shape[0] == 2048
     except Exception:
         bert = get_bert(text, word2ph, language_str, device)
         assert bert.shape[-1] == len(phone)
@@ -59,16 +59,23 @@ if __name__ == "__main__":
     args, _ = parser.parse_known_args()
     config_path = args.config
     hps = utils.get_hparams_from_file(config_path)
+    check_bert_models()
     lines = []
     with open(hps.data.training_files, encoding="utf-8") as f:
         lines.extend(f.readlines())
 
     with open(hps.data.validation_files, encoding="utf-8") as f:
         lines.extend(f.readlines())
+    add_blank = [hps.data.add_blank] * len(lines)
+
     if len(lines) != 0:
-        num_processes = min(args.num_processes, cpu_count())
+        num_processes = args.num_processes
         with Pool(processes=num_processes) as pool:
-            for _ in tqdm(pool.imap_unordered(process_line, lines), total=len(lines)):
-                pass
+            for _ in tqdm(
+                pool.imap_unordered(process_line, zip(lines, add_blank)),
+                total=len(lines),
+            ):
+                # 这里是缩进的代码块,表示循环体
+                pass  # 使用pass语句作为占位符
 
     print(f"bert生成完毕!, 共有{len(lines)}个bert.pt生成!")
config.yml CHANGED
@@ -83,7 +83,7 @@ train_ms:
   base:
     use_base_model: false
     repo_id: "Stardust_minus/Bert-VITS2"
-    model_image: "Bert-VITS2_2.2-Clap底模"  # openi网页的模型名
+    model_image: "Bert-VITS2_2.3底模"  # openi网页的模型名
   # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
   model: "models"
   # 配置文件路径
@@ -172,6 +172,6 @@ server:
 # 请不要在github等网站公开分享你的app id 与 key
 translate:
   # 你的APPID
-  "app_key": ""
+  "app_key": "20231117001883321"
   # 你的密钥
-  "secret_key": ""
+  "secret_key": "lMQbvZHeJveDceLof2wf"
data_utils.py CHANGED
@@ -3,7 +3,6 @@ import random
 import torch
 import torch.utils.data
 from tqdm import tqdm
-import numpy as np
 from tools.log import logger
 import commons
 from mel_processing import spectrogram_torch, mel_spectrogram_torch
@@ -44,10 +43,6 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         self.min_text_len = getattr(hparams, "min_text_len", 1)
         self.max_text_len = getattr(hparams, "max_text_len", 384)
 
-        self.empty_emo = torch.squeeze(
-            torch.load("empty_emo.npy", map_location="cpu"), dim=1
-        )
-
         random.seed(1234)
         random.shuffle(self.audiopaths_sid_text)
         self._filter()
@@ -98,14 +93,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         spec, wav = self.get_audio(audiopath)
         sid = torch.LongTensor([int(self.spk_map[sid])])
 
-        if np.random.rand() > 0.1:
-            emo = torch.squeeze(
-                torch.load(audiopath.replace(".wav", ".emo.npy"), map_location="cpu"),
-                dim=1,
-            )
-        else:
-            emo = self.empty_emo
-        return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert, emo)
+        return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert)
 
     def get_audio(self, filename):
         audio, sampling_rate = load_wav_to_torch(filename)
@@ -168,15 +156,15 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
 
         if language_str == "ZH":
             bert = bert_ori
-            ja_bert = torch.rand(1024, len(phone))
-            en_bert = torch.rand(1024, len(phone))
+            ja_bert = torch.randn(1024, len(phone))
+            en_bert = torch.randn(1024, len(phone))
         elif language_str == "JP":
-            bert = torch.rand(1024, len(phone))
+            bert = torch.randn(1024, len(phone))
             ja_bert = bert_ori
-            en_bert = torch.rand(1024, len(phone))
+            en_bert = torch.randn(1024, len(phone))
         elif language_str == "EN":
-            bert = torch.rand(1024, len(phone))
-            ja_bert = torch.rand(1024, len(phone))
+            bert = torch.randn(1024, len(phone))
+            ja_bert = torch.randn(1024, len(phone))
             en_bert = bert_ori
         phone = torch.LongTensor(phone)
         tone = torch.LongTensor(tone)
@@ -226,7 +214,6 @@ class TextAudioSpeakerCollate:
         bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
         ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
         en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
-        emo = torch.FloatTensor(len(batch), 512)
 
         spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
@@ -238,7 +225,6 @@ class TextAudioSpeakerCollate:
         bert_padded.zero_()
         ja_bert_padded.zero_()
         en_bert_padded.zero_()
-        emo.zero_()
 
         for i in range(len(ids_sorted_decreasing)):
             row = batch[ids_sorted_decreasing[i]]
@@ -272,8 +258,6 @@ class TextAudioSpeakerCollate:
             en_bert = row[8]
             en_bert_padded[i, :, : en_bert.size(1)] = en_bert
 
-            emo[i, :] = row[9]
-
         return (
             text_padded,
             text_lengths,
@@ -287,7 +271,6 @@ class TextAudioSpeakerCollate:
             bert_padded,
             ja_bert_padded,
             en_bert_padded,
-            emo,
         )
 
 
export_onnx.py CHANGED
@@ -5,8 +5,10 @@ if __name__ == "__main__":
     export_path = "BertVits2.2PT"
     model_path = "model\\G_0.pth"
     config_path = "model\\config.json"
+    novq = False
+    dev = False
     if not os.path.exists("onnx"):
         os.makedirs("onnx")
     if not os.path.exists(f"onnx/{export_path}"):
         os.makedirs(f"onnx/{export_path}")
-    export_onnx(export_path, model_path, config_path)
+    export_onnx(export_path, model_path, config_path, novq, dev)
hiyoriUI.py ADDED
@@ -0,0 +1,725 @@
1
+ """
2
+ api服务,网页后端 多版本多模型 fastapi实现
3
+ 原 server_fastapi
4
+ """
5
+ import logging
6
+ import gc
7
+ import random
8
+ import librosa
9
+ import gradio
10
+ import numpy as np
11
+ import utils
12
+ from fastapi import FastAPI, Query, Request, File, UploadFile, Form
13
+ from fastapi.responses import Response, FileResponse
14
+ from fastapi.staticfiles import StaticFiles
15
+ from io import BytesIO
16
+ from scipy.io import wavfile
17
+ import uvicorn
18
+ import torch
19
+ import webbrowser
20
+ import psutil
21
+ import GPUtil
22
+ from typing import Dict, Optional, List, Set, Union, Tuple
23
+ import os
24
+ from tools.log import logger
25
+ from urllib.parse import unquote
26
+
27
+ from infer import infer, get_net_g, latest_version
28
+ import tools.translate as trans
29
+ from tools.sentence import split_by_language
30
+ from re_matching import cut_sent
31
+
32
+
33
+ from config import config
34
+
35
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
36
+
37
+
38
+ class Model:
39
+ """模型封装类"""
40
+
41
+ def __init__(self, config_path: str, model_path: str, device: str, language: str):
42
+ self.config_path: str = os.path.normpath(config_path)
43
+ self.model_path: str = os.path.normpath(model_path)
44
+ self.device: str = device
45
+ self.language: str = language
46
+ self.hps = utils.get_hparams_from_file(config_path)
47
+ self.spk2id: Dict[str, int] = self.hps.data.spk2id # spk - id 映射字典
48
+ self.id2spk: Dict[int, str] = dict() # id - spk 映射字典
49
+ for speaker, speaker_id in self.hps.data.spk2id.items():
50
+ self.id2spk[speaker_id] = speaker
51
+ self.version: str = (
52
+ self.hps.version if hasattr(self.hps, "version") else latest_version
53
+ )
54
+ self.net_g = get_net_g(
55
+ model_path=model_path,
56
+ version=self.version,
57
+ device=device,
58
+ hps=self.hps,
59
+ )
60
+
61
+ def to_dict(self) -> Dict[str, any]:
62
+ return {
63
+ "config_path": self.config_path,
64
+ "model_path": self.model_path,
65
+ "device": self.device,
66
+ "language": self.language,
67
+ "spk2id": self.spk2id,
68
+ "id2spk": self.id2spk,
69
+ "version": self.version,
70
+ }
71
+
72
+
73
+ class Models:
74
+ def __init__(self):
75
+ self.models: Dict[int, Model] = dict()
76
+ self.num = 0
77
+ # spkInfo[角色名][模型id] = 角色id
78
+ self.spk_info: Dict[str, Dict[int, int]] = dict()
79
+ self.path2ids: Dict[str, Set[int]] = dict() # 路径指向的model的id
80
+
81
+ def init_model(
82
+ self, config_path: str, model_path: str, device: str, language: str
83
+ ) -> int:
84
+ """
85
+ 初始化并添加一个模型
86
+
87
+ :param config_path: 模型config.json路径
88
+ :param model_path: 模型路径
89
+ :param device: 模型推理使用设备
90
+ :param language: 模型推理默认语言
91
+ """
92
+ # 若文件不存在则不进行加载
93
+ if not os.path.isfile(model_path):
94
+ if model_path != "":
95
+ logger.warning(f"模型文件{model_path} 不存在,不进行初始化")
96
+ return self.num
97
+ if not os.path.isfile(config_path):
98
+ if config_path != "":
99
+ logger.warning(f"配置文件{config_path} 不存在,不进行初始化")
100
+ return self.num
101
+
102
+ # 若路径中的模型已存在,则不添加模型,若不存在,则进行初始化。
103
+ model_path = os.path.realpath(model_path)
104
+ if model_path not in self.path2ids.keys():
105
+ self.path2ids[model_path] = {self.num}
106
+ self.models[self.num] = Model(
107
+ config_path=config_path,
108
+ model_path=model_path,
109
+ device=device,
110
+ language=language,
111
+ )
112
+ logger.success(f"添加模型{model_path},使用配置文件{os.path.realpath(config_path)}")
113
+ else:
114
+ # 获取一个指向id
115
+ m_id = next(iter(self.path2ids[model_path]))
116
+ self.models[self.num] = self.models[m_id]
117
+ self.path2ids[model_path].add(self.num)
118
+ logger.success("模型已存在,添加模型引用。")
119
+ # 添加角色信息
120
+ for speaker, speaker_id in self.models[self.num].spk2id.items():
121
+ if speaker not in self.spk_info.keys():
122
+ self.spk_info[speaker] = {self.num: speaker_id}
123
+ else:
124
+ self.spk_info[speaker][self.num] = speaker_id
125
+ # 修改计数
126
+ self.num += 1
127
+ return self.num - 1
128
+
129
+ def del_model(self, index: int) -> Optional[int]:
130
+ """删除对应序号的模型,若不存在则返回None"""
131
+ if index not in self.models.keys():
132
+ return None
133
+ # 删除角色信息
134
+ for speaker, speaker_id in self.models[index].spk2id.items():
135
+ self.spk_info[speaker].pop(index)
136
+ if len(self.spk_info[speaker]) == 0:
137
+ # 若对应角色的所有模型都被删除,则清除该角色信息
138
+ self.spk_info.pop(speaker)
139
+ # 删除路径信息
140
+ model_path = os.path.realpath(self.models[index].model_path)
141
+ self.path2ids[model_path].remove(index)
142
+ if len(self.path2ids[model_path]) == 0:
143
+ self.path2ids.pop(model_path)
144
+ logger.success(f"删除模型{model_path}, id = {index}")
145
+ else:
146
+ logger.success(f"删除模型引用{model_path}, id = {index}")
147
+ # 删除模型
148
+ self.models.pop(index)
149
+ gc.collect()
150
+ if torch.cuda.is_available():
151
+ torch.cuda.empty_cache()
152
+ return index
153
+
154
+ def get_models(self):
155
+ """获取所有模型"""
156
+ return self.models
157
+
158
+
159
+ if __name__ == "__main__":
160
+ app = FastAPI()
161
+ app.logger = logger
162
+ # 挂载静态文件
163
+ logger.info("开始挂载网页页面")
164
+ StaticDir: str = "./Web"
165
+ if not os.path.isdir(StaticDir):
166
+ logger.warning(
167
+ "缺少网页资源,无法开启网页页面,如有需要请在 https://github.com/jiangyuxiaoxiao/Bert-VITS2-UI 或者Bert-VITS对应版本的release页面下载"
168
+ )
169
+ else:
170
+ dirs = [fir.name for fir in os.scandir(StaticDir) if fir.is_dir()]
171
+ files = [fir.name for fir in os.scandir(StaticDir) if fir.is_dir()]
172
+ for dirName in dirs:
173
+ app.mount(
174
+ f"/{dirName}",
175
+ StaticFiles(directory=f"./{StaticDir}/{dirName}"),
176
+ name=dirName,
177
+ )
178
+ loaded_models = Models()
179
+ # 加载模型
180
+ logger.info("开始加载模型")
181
+ models_info = config.server_config.models
182
+ for model_info in models_info:
183
+ loaded_models.init_model(
184
+ config_path=model_info["config"],
185
+ model_path=model_info["model"],
186
+ device=model_info["device"],
187
+ language=model_info["language"],
188
+ )
189
+
190
+ @app.get("/")
191
+ async def index():
192
+ return FileResponse("./Web/index.html")
193
+
194
+ async def _voice(
195
+ text: str,
196
+ model_id: int,
197
+ speaker_name: str,
198
+ speaker_id: int,
199
+ sdp_ratio: float,
200
+ noise: float,
201
+ noisew: float,
202
+ length: float,
203
+ language: str,
204
+ auto_translate: bool,
205
+ auto_split: bool,
206
+ emotion: Optional[Union[int, str]] = None,
207
+ reference_audio=None,
208
+ style_text: Optional[str] = None,
209
+ style_weight: float = 0.7,
210
+ ) -> Union[Response, Dict[str, any]]:
211
+ """TTS实现函数"""
212
+
213
+ # 检查
214
+ # 检查模型是否存在
215
+ if model_id not in loaded_models.models.keys():
216
+ logger.error(f"/voice 请求错误:模型model_id={model_id}未加载")
217
+ return {"status": 10, "detail": f"模型model_id={model_id}未加载"}
218
+ # 检查是否提供speaker
219
+ if speaker_name is None and speaker_id is None:
220
+ logger.error("/voice 请求错误:推理请求未提供speaker_name或speaker_id")
221
+ return {"status": 11, "detail": "请提供speaker_name或speaker_id"}
222
+ elif speaker_name is None:
223
+ # 检查speaker_id是否存在
224
+ if speaker_id not in loaded_models.models[model_id].id2spk.keys():
225
+ logger.error(f"/voice 请求错误:角色speaker_id={speaker_id}不存在")
226
+ return {"status": 12, "detail": f"角色speaker_id={speaker_id}不存在"}
227
+ speaker_name = loaded_models.models[model_id].id2spk[speaker_id]
228
+ # 检查speaker_name是否存在
229
+ if speaker_name not in loaded_models.models[model_id].spk2id.keys():
230
+ logger.error(f"/voice 请求错误:角色speaker_name={speaker_name}不存在")
231
+ return {"status": 13, "detail": f"角色speaker_name={speaker_name}不存在"}
232
+ # 未传入则使用默认语言
233
+ if language is None:
234
+ language = loaded_models.models[model_id].language
235
+ # 翻译会破坏mix结构,auto也会变得无意义。不要在这两个模式下使用
236
+ if auto_translate:
237
+ if language == "auto" or language == "mix":
238
+ logger.error(
239
+ f"/voice 请求错误:请勿同时使用language = {language}与auto_translate模式"
240
+ )
241
+ return {
242
+ "status": 20,
243
+ "detail": f"请勿同时使用language = {language}与auto_translate模式",
244
+ }
245
+ text = trans.translate(Sentence=text, to_Language=language.lower())
246
+ if reference_audio is not None:
247
+ ref_audio = BytesIO(await reference_audio.read())
248
+ # 2.2 适配
249
+ if loaded_models.models[model_id].version == "2.2":
250
+ ref_audio, _ = librosa.load(ref_audio, 48000)
251
+ else:
252
+ ref_audio = reference_audio
253
+
254
+ # 改动:增加使用 || 对文本进行主动切分
255
+ # 切分优先级: || → auto/mix → auto_split
256
+ text2 = text.replace("\n", "").lstrip()
257
+ texts: List[str] = text2.split("||")
258
+
259
+ # 对于mix和auto的说明:出于版本兼容性的考虑,暂时无法使用multilang的方式进行推理
260
+ if language == "MIX":
261
+ text_language_speakers: List[Tuple[str, str, str]] = []
262
+ for _text in texts:
263
+ speaker_pieces = _text.split("[") # 按说话人分割多块
264
+ for speaker_piece in speaker_pieces:
265
+ if speaker_piece == "":
266
+ continue
267
+ speaker_piece2 = speaker_piece.split("]")
268
+ if len(speaker_piece2) != 2:
269
+ return {
270
+ "status": 21,
271
+ "detail": "MIX语法错误",
272
+ }
273
+ speaker = speaker_piece2[0].strip()
274
+ lang_pieces = speaker_piece2[1].split("<")
275
+ for lang_piece in lang_pieces:
276
+ if lang_piece == "":
277
+ continue
278
+ lang_piece2 = lang_piece.split(">")
279
+ if len(lang_piece2) != 2:
280
+ return {
281
+ "status": 21,
282
+ "detail": "MIX语法错误",
283
+ }
284
+ lang = lang_piece2[0].strip()
285
+ if lang.upper() not in ["ZH", "EN", "JP"]:
286
+ return {
287
+ "status": 21,
288
+ "detail": "MIX语法错误",
289
+ }
290
+ t = lang_piece2[1]
291
+ text_language_speakers.append((t, lang.upper(), speaker))
292
+
293
+ elif language == "AUTO":
294
+ text_language_speakers: List[Tuple[str, str, str]] = [
295
+ (final_text, language.upper().replace("JA", "JP"), speaker_name)
296
+ for sub_list in [
297
+ split_by_language(_text, target_languages=["zh", "ja", "en"])
298
+ for _text in texts
299
+ if _text != ""
300
+ ]
301
+ for final_text, language in sub_list
302
+ if final_text != ""
303
+ ]
304
+ else:
305
+ text_language_speakers: List[Tuple[str, str, str]] = [
306
+ (_text, language, speaker_name) for _text in texts if _text != ""
307
+ ]
308
+
309
+ if auto_split:
310
+ text_language_speakers: List[Tuple[str, str, str]] = [
311
+ (final_text, lang, speaker)
312
+ for _text, lang, speaker in text_language_speakers
313
+ for final_text in cut_sent(_text)
314
+ ]
315
+
316
+ audios = []
317
+ with torch.no_grad():
318
+ for _text, lang, speaker in text_language_speakers:
319
+ audios.append(
320
+ infer(
321
+ text=_text,
322
+ sdp_ratio=sdp_ratio,
323
+ noise_scale=noise,
324
+ noise_scale_w=noisew,
325
+ length_scale=length,
326
+ sid=speaker,
327
+ language=lang,
328
+ hps=loaded_models.models[model_id].hps,
329
+ net_g=loaded_models.models[model_id].net_g,
330
+ device=loaded_models.models[model_id].device,
331
+ emotion=emotion,
332
+ reference_audio=ref_audio,
333
+ style_text=style_text,
334
+ style_weight=style_weight,
335
+ )
336
+ )
337
+ # audios.append(np.zeros(int(44100 * 0.2)))
338
+ # audios.pop()
339
+ audio = np.concatenate(audios)
340
+ audio = gradio.processing_utils.convert_to_16_bit_wav(audio)
341
+ with BytesIO() as wavContent:
342
+ wavfile.write(
343
+ wavContent, loaded_models.models[model_id].hps.data.sampling_rate, audio
344
+ )
345
+ response = Response(content=wavContent.getvalue(), media_type="audio/wav")
346
+ return response
347
+
348
+ @app.post("/voice")
349
+ async def voice(
350
+ request: Request, # fastapi自动注入
351
+ text: str = Form(...),
352
+ model_id: int = Query(..., description="模型ID"), # 模型序号
353
+ speaker_name: str = Query(
354
+ None, description="说话人名"
355
+ ), # speaker_name与 speaker_id二者选其一
356
+ speaker_id: int = Query(None, description="说话人id,与speaker_name二选一"),
357
+ sdp_ratio: float = Query(0.2, description="SDP/DP混合比"),
358
+ noise: float = Query(0.2, description="感情"),
359
+ noisew: float = Query(0.9, description="音素长度"),
360
+ length: float = Query(1, description="语速"),
361
+ language: str = Query(None, description="语言"), # 若不指定使用语言则使用默认值
362
+ auto_translate: bool = Query(False, description="自动翻译"),
363
+ auto_split: bool = Query(False, description="自动切分"),
364
+ emotion: Optional[Union[int, str]] = Query(None, description="emo"),
365
+ reference_audio: UploadFile = File(None),
366
+ style_text: Optional[str] = Form(None, description="风格文本"),
367
+ style_weight: float = Query(0.7, description="风格权重"),
368
+ ):
369
+ """语音接口,若需要上传参考音频请仅使用post请求"""
370
+ logger.info(
371
+ f"{request.client.host}:{request.client.port}/voice { unquote(str(request.query_params) )} text={text}"
372
+ )
373
+ return await _voice(
374
+ text=text,
375
+ model_id=model_id,
376
+ speaker_name=speaker_name,
377
+ speaker_id=speaker_id,
378
+ sdp_ratio=sdp_ratio,
379
+ noise=noise,
380
+ noisew=noisew,
381
+ length=length,
382
+ language=language,
383
+ auto_translate=auto_translate,
384
+ auto_split=auto_split,
385
+ emotion=emotion,
386
+ reference_audio=reference_audio,
387
+ style_text=style_text,
388
+ style_weight=style_weight,
389
+ )
390
+
391
+ @app.get("/voice")
392
+ async def voice(
393
+ request: Request, # fastapi自动注入
394
+ text: str = Query(..., description="输入文字"),
395
+ model_id: int = Query(..., description="模型ID"), # 模型序号
396
+ speaker_name: str = Query(
397
+ None, description="说话人名"
398
+ ), # speaker_name与 speaker_id二者选其一
399
+ speaker_id: int = Query(None, description="说话人id,与speaker_name二选一"),
400
+ sdp_ratio: float = Query(0.2, description="SDP/DP混合比"),
401
+ noise: float = Query(0.2, description="感情"),
402
+ noisew: float = Query(0.9, description="音素长度"),
403
+ length: float = Query(1, description="语速"),
404
+ language: str = Query(None, description="语言"), # 若不指定使用语言则使用默认值
405
+ auto_translate: bool = Query(False, description="自动翻译"),
406
+ auto_split: bool = Query(False, description="自动切分"),
407
+ emotion: Optional[Union[int, str]] = Query(None, description="emo"),
408
+ style_text: Optional[str] = Query(None, description="风格文本"),
409
+ style_weight: float = Query(0.7, description="风格权重"),
410
+ ):
411
+ """语音接口,不建议使用"""
412
+ logger.info(
413
+ f"{request.client.host}:{request.client.port}/voice { unquote(str(request.query_params) )}"
414
+ )
415
+ return await _voice(
416
+ text=text,
417
+ model_id=model_id,
418
+ speaker_name=speaker_name,
419
+ speaker_id=speaker_id,
420
+ sdp_ratio=sdp_ratio,
421
+ noise=noise,
422
+ noisew=noisew,
423
+ length=length,
424
+ language=language,
425
+ auto_translate=auto_translate,
426
+ auto_split=auto_split,
427
+ emotion=emotion,
428
+ style_text=style_text,
429
+ style_weight=style_weight,
430
+ )
431
+
432
+ @app.get("/models/info")
433
+ def get_loaded_models_info(request: Request):
434
+ """获取已加载模型信息"""
435
+
436
+ result: Dict[str, Dict] = dict()
437
+ for key, model in loaded_models.models.items():
438
+ result[str(key)] = model.to_dict()
439
+ return result
440
+
441
+ @app.get("/models/delete")
442
+ def delete_model(
443
+ request: Request, model_id: int = Query(..., description="删除模型id")
444
+ ):
445
+ """删除指定模型"""
446
+ logger.info(
447
+ f"{request.client.host}:{request.client.port}/models/delete { unquote(str(request.query_params) )}"
448
+ )
449
+ result = loaded_models.del_model(model_id)
450
+ if result is None:
451
+ logger.error(f"/models/delete 模型删除错误:模型{model_id}不存在,删除失败")
452
+ return {"status": 14, "detail": f"模型{model_id}不存在,删除失败"}
453
+
454
+ return {"status": 0, "detail": "删除成功"}
455
+
456
+ @app.get("/models/add")
457
+ def add_model(
458
+ request: Request,
459
+ model_path: str = Query(..., description="添加模型路径"),
460
+ config_path: str = Query(
461
+ None, description="添加模型配置文件路径,不填则使用./config.json或../config.json"
462
+ ),
463
+ device: str = Query("cuda", description="推理使用设备"),
464
+ language: str = Query("ZH", description="模型默认语言"),
465
+ ):
466
+ """添加指定模型:允许重复添加相同路径模型,且不重复占用内存"""
467
+ logger.info(
468
+ f"{request.client.host}:{request.client.port}/models/add { unquote(str(request.query_params) )}"
469
+ )
470
+ if config_path is None:
471
+ model_dir = os.path.dirname(model_path)
472
+ if os.path.isfile(os.path.join(model_dir, "config.json")):
473
+ config_path = os.path.join(model_dir, "config.json")
474
+ elif os.path.isfile(os.path.join(model_dir, "../config.json")):
475
+ config_path = os.path.join(model_dir, "../config.json")
476
+ else:
477
+ logger.error("/models/add 模型添加失败:未在模型所在目录以及上级目录找到config.json文件")
478
+ return {
479
+ "status": 15,
480
+ "detail": "查询未传入配置文件路径,同时默认路径./与../中不存在配置文件config.json。",
481
+ }
482
+ try:
483
+ model_id = loaded_models.init_model(
484
+ config_path=config_path,
485
+ model_path=model_path,
486
+ device=device,
487
+ language=language,
488
+ )
489
+ except Exception:
490
+ logging.exception("模型加载出错")
491
+ return {
492
+ "status": 16,
493
+ "detail": "模型加载出错,详细查看日志",
494
+ }
495
+ return {
496
+ "status": 0,
497
+ "detail": "模型添加成功",
498
+ "Data": {
499
+ "model_id": model_id,
500
+ "model_info": loaded_models.models[model_id].to_dict(),
501
+ },
502
+ }
503
+
504
+ def _get_all_models(root_dir: str = "Data", only_unloaded: bool = False):
505
+ """从root_dir搜索获取所有可用模型"""
506
+ result: Dict[str, List[str]] = dict()
507
+ files = os.listdir(root_dir) + ["."]
508
+ for file in files:
509
+ if os.path.isdir(os.path.join(root_dir, file)):
510
+ sub_dir = os.path.join(root_dir, file)
511
+ # 搜索 "sub_dir" 、 "sub_dir/models" 两个路径
512
+ result[file] = list()
513
+ sub_files = os.listdir(sub_dir)
514
+ model_files = []
515
+ for sub_file in sub_files:
516
+ relpath = os.path.realpath(os.path.join(sub_dir, sub_file))
517
+ if only_unloaded and relpath in loaded_models.path2ids.keys():
518
+ continue
519
+ if sub_file.endswith(".pth") and sub_file.startswith("G_"):
520
+ if os.path.isfile(relpath):
521
+ model_files.append(sub_file)
522
+ # 对模型文件按步数排序
523
+ model_files = sorted(
524
+ model_files,
525
+ key=lambda pth: int(pth.lstrip("G_").rstrip(".pth"))
526
+ if pth.lstrip("G_").rstrip(".pth").isdigit()
527
+ else 10**10,
528
+ )
529
+ result[file] = model_files
530
+ models_dir = os.path.join(sub_dir, "models")
531
+ model_files = []
532
+ if os.path.isdir(models_dir):
533
+ sub_files = os.listdir(models_dir)
534
+ for sub_file in sub_files:
535
+ relpath = os.path.realpath(os.path.join(models_dir, sub_file))
536
+ if only_unloaded and relpath in loaded_models.path2ids.keys():
537
+ continue
538
+ if sub_file.endswith(".pth") and sub_file.startswith("G_"):
539
+ if os.path.isfile(os.path.join(models_dir, sub_file)):
540
+ model_files.append(f"models/{sub_file}")
541
+ # 对模型文件按步数排序
542
+ model_files = sorted(
543
+ model_files,
544
+ key=lambda pth: int(pth.lstrip("models/G_").rstrip(".pth"))
545
+ if pth.lstrip("models/G_").rstrip(".pth").isdigit()
546
+ else 10**10,
547
+ )
548
+ result[file] += model_files
549
+ if len(result[file]) == 0:
550
+ result.pop(file)
551
+
552
+ return result
553
+
554
+ @app.get("/models/get_unloaded")
555
+ def get_unloaded_models_info(
556
+ request: Request, root_dir: str = Query("Data", description="搜索根目录")
557
+ ):
558
+ """获取未加载模型"""
559
+ logger.info(
560
+ f"{request.client.host}:{request.client.port}/models/get_unloaded { unquote(str(request.query_params) )}"
561
+ )
562
+ return _get_all_models(root_dir, only_unloaded=True)
563
+
564
+ @app.get("/models/get_local")
565
+ def get_local_models_info(
566
+ request: Request, root_dir: str = Query("Data", description="搜索根目录")
567
+ ):
568
+ """获取全部本地模型"""
569
+ logger.info(
570
+ f"{request.client.host}:{request.client.port}/models/get_local { unquote(str(request.query_params) )}"
571
+ )
572
+ return _get_all_models(root_dir, only_unloaded=False)
573
+
574
+ @app.get("/status")
575
+ def get_status():
576
+ """获取电脑运行状态"""
577
+ cpu_percent = psutil.cpu_percent(interval=1)
578
+ memory_info = psutil.virtual_memory()
579
+ memory_total = memory_info.total
580
+ memory_available = memory_info.available
581
+ memory_used = memory_info.used
582
+ memory_percent = memory_info.percent
583
+ gpuInfo = []
584
+ devices = ["cpu"]
585
+ for i in range(torch.cuda.device_count()):
586
+ devices.append(f"cuda:{i}")
587
+ gpus = GPUtil.getGPUs()
588
+ for gpu in gpus:
589
+ gpuInfo.append(
590
+ {
591
+ "gpu_id": gpu.id,
592
+ "gpu_load": gpu.load,
593
+ "gpu_memory": {
594
+ "total": gpu.memoryTotal,
595
+ "used": gpu.memoryUsed,
596
+ "free": gpu.memoryFree,
597
+ },
598
+ }
599
+ )
600
+ return {
601
+ "devices": devices,
602
+ "cpu_percent": cpu_percent,
603
+ "memory_total": memory_total,
604
+ "memory_available": memory_available,
605
+ "memory_used": memory_used,
606
+ "memory_percent": memory_percent,
607
+ "gpu": gpuInfo,
608
+ }
609
+
610
+ @app.get("/tools/translate")
611
+ def translate(
612
+ request: Request,
613
+ texts: str = Query(..., description="待翻译文本"),
614
+ to_language: str = Query(..., description="翻译目标语言"),
615
+ ):
616
+ """翻译"""
617
+ logger.info(
618
+ f"{request.client.host}:{request.client.port}/tools/translate { unquote(str(request.query_params) )}"
619
+ )
620
+ return {"texts": trans.translate(Sentence=texts, to_Language=to_language)}
621
+
622
+ all_examples: Dict[str, Dict[str, List]] = dict() # 存放示例
623
+
624
+ @app.get("/tools/random_example")
625
+ def random_example(
626
+ request: Request,
627
+ language: str = Query(None, description="指定语言,未指定则随机返回"),
628
+ root_dir: str = Query("Data", description="搜索根目录"),
629
+ ):
630
+ """
631
+ 获取一个随机音频+文本,用于对比,音频会从本地目录随机选择。
632
+ """
633
+ logger.info(
634
+ f"{request.client.host}:{request.client.port}/tools/random_example { unquote(str(request.query_params) )}"
635
+ )
636
+ global all_examples
637
+ # 数据初始化
638
+ if root_dir not in all_examples.keys():
639
+ all_examples[root_dir] = {"ZH": [], "JP": [], "EN": []}
640
+
641
+ examples = all_examples[root_dir]
642
+
643
+ # 从项目Data目录中搜索train/val.list
644
+ for root, directories, _files in os.walk(root_dir):
645
+ for file in _files:
646
+ if file in ["train.list", "val.list"]:
647
+ with open(
648
+ os.path.join(root, file), mode="r", encoding="utf-8"
649
+ ) as f:
650
+ lines = f.readlines()
651
+ for line in lines:
652
+ data = line.split("|")
653
+ if len(data) != 7:
654
+ continue
655
+ # 音频存在 且语言为ZH/EN/JP
656
+ if os.path.isfile(data[0]) and data[2] in [
657
+ "ZH",
658
+ "JP",
659
+ "EN",
660
+ ]:
661
+ examples[data[2]].append(
662
+ {
663
+ "text": data[3],
664
+ "audio": data[0],
665
+ "speaker": data[1],
666
+ }
667
+ )
668
+
669
+ examples = all_examples[root_dir]
670
+ if language is None:
671
+ if len(examples["ZH"]) + len(examples["JP"]) + len(examples["EN"]) == 0:
672
+ return {"status": 17, "detail": "没有加载任何示例数据"}
673
+ else:
674
+ # 随机选一个
675
+ rand_num = random.randint(
676
+ 0,
677
+ len(examples["ZH"]) + len(examples["JP"]) + len(examples["EN"]) - 1,
678
+ )
679
+ # ZH
680
+ if rand_num < len(examples["ZH"]):
681
+ return {"status": 0, "Data": examples["ZH"][rand_num]}
682
+ # JP
683
+ if rand_num < len(examples["ZH"]) + len(examples["JP"]):
684
+ return {
685
+ "status": 0,
686
+ "Data": examples["JP"][rand_num - len(examples["ZH"])],
687
+ }
688
+ # EN
689
+ return {
690
+ "status": 0,
691
+ "Data": examples["EN"][
692
+ rand_num - len(examples["ZH"]) - len(examples["JP"])
693
+ ],
694
+ }
695
+
696
+ else:
697
+ if len(examples[language]) == 0:
698
+ return {"status": 17, "detail": f"没有加载任何{language}数据"}
699
+ return {
700
+ "status": 0,
701
+ "Data": examples[language][
702
+ random.randint(0, len(examples[language]) - 1)
703
+ ],
704
+ }
705
+
706
+ @app.get("/tools/get_audio")
707
+ def get_audio(request: Request, path: str = Query(..., description="本地音频路径")):
708
+ logger.info(
709
+ f"{request.client.host}:{request.client.port}/tools/get_audio { unquote(str(request.query_params) )}"
710
+ )
711
+ if not os.path.isfile(path):
712
+ logger.error(f"/tools/get_audio 获取音频错误:指定音频{path}不存在")
713
+ return {"status": 18, "detail": "指定音频不存在"}
714
+ if not path.lower().endswith(".wav"):
715
+ logger.error(f"/tools/get_audio 获取音频错误:音频{path}非wav文件")
716
+ return {"status": 19, "detail": "非wav格式文件"}
717
+ return FileResponse(path=path)
718
+
719
+ logger.warning("本地服务,请勿将服务端口暴露于外网")
720
+ logger.info(f"api文档地址 http://127.0.0.1:{config.server_config.port}/docs")
721
+ if os.path.isdir(StaticDir):
722
+ webbrowser.open(f"http://127.0.0.1:{config.server_config.port}")
723
+ uvicorn.run(
724
+ app, port=config.server_config.port, host="0.0.0.0", log_level="warning"
725
+ )
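A hedged client sketch for the GET /voice route defined above. The host, port, model_id and speaker name are assumptions; the query parameter names come from the route signature in the diff.

# Sketch only: query the running hiyoriUI.py server and save the returned wav.
import requests  # assumes the requests package is installed

resp = requests.get(
    "http://127.0.0.1:5000/voice",       # port actually comes from config.server_config.port
    params={
        "text": "你好,世界",
        "model_id": 0,                    # id as reported by /models/info
        "speaker_name": "YourSpeaker",    # placeholder; must exist in the model's spk2id
        "sdp_ratio": 0.2,
        "noise": 0.2,
        "noisew": 0.9,
        "length": 1.0,
        "language": "ZH",
        "auto_split": True,
    },
)
with open("out.wav", "wb") as f:
    f.write(resp.content)                 # the route returns audio/wav bytes on success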
infer.py CHANGED
@@ -5,19 +5,22 @@
5
  2. 请在模型的config.json中显示声明版本号,添加一个字段"version" : "你的版本号"
6
  特殊版本说明:
7
  1.1.1-fix: 1.1.1版本训练的模型,但是在推理时使用dev的日语修复
8
- 2.2:当前版本
9
  """
10
  import torch
11
  import commons
12
  from text import cleaned_text_to_sequence, get_bert
13
- from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
 
 
14
  from text.cleaner import clean_text
15
  import utils
16
- import numpy as np
17
 
18
  from models import SynthesizerTrn
19
  from text.symbols import symbols
20
 
 
 
21
  from oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
22
  from oldVersion.V210.text import symbols as V210symbols
23
  from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
@@ -29,13 +32,14 @@ from oldVersion.V110.text import symbols as V110symbols
29
  from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
30
  from oldVersion.V101.text import symbols as V101symbols
31
 
32
- from oldVersion import V111, V110, V101, V200, V210
33
 
34
  # 当前版本信息
35
- latest_version = "2.2"
36
 
37
  # 版本兼容
38
  SynthesizerTrnMap = {
 
39
  "2.1": V210SynthesizerTrn,
40
  "2.0.2-fix": V200SynthesizerTrn,
41
  "2.0.1": V200SynthesizerTrn,
@@ -50,6 +54,7 @@ SynthesizerTrnMap = {
50
  }
51
 
52
  symbolsMap = {
 
53
  "2.1": V210symbols,
54
  "2.0.2-fix": V200symbols,
55
  "2.0.1": V200symbols,
@@ -98,7 +103,8 @@ def get_net_g(model_path: str, version: str, device: str, hps):
98
  return net_g
99
 
100
 
101
- def get_text(text, language_str, hps, device):
 
102
  # 在此处实现当前版本的get_text
103
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
104
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
@@ -110,21 +116,23 @@ def get_text(text, language_str, hps, device):
110
  for i in range(len(word2ph)):
111
  word2ph[i] = word2ph[i] * 2
112
  word2ph[0] += 1
113
- bert_ori = get_bert(norm_text, word2ph, language_str, device)
 
 
114
  del word2ph
115
  assert bert_ori.shape[-1] == len(phone), phone
116
 
117
  if language_str == "ZH":
118
  bert = bert_ori
119
- ja_bert = torch.rand(1024, len(phone))
120
- en_bert = torch.rand(1024, len(phone))
121
  elif language_str == "JP":
122
- bert = torch.rand(1024, len(phone))
123
  ja_bert = bert_ori
124
- en_bert = torch.rand(1024, len(phone))
125
  elif language_str == "EN":
126
- bert = torch.rand(1024, len(phone))
127
- ja_bert = torch.rand(1024, len(phone))
128
  en_bert = bert_ori
129
  else:
130
  raise ValueError("language_str should be ZH, JP or EN")
@@ -141,7 +149,7 @@ def get_text(text, language_str, hps, device):
141
 
142
  def infer(
143
  text,
144
- emotion,
145
  sdp_ratio,
146
  noise_scale,
147
  noise_scale_w,
@@ -154,8 +162,13 @@ def infer(
154
  reference_audio=None,
155
  skip_start=False,
156
  skip_end=False,
 
 
157
  ):
158
  # 2.2版本参数位置变了
 
 
 
159
  # 2.1 参数新增 emotion reference_audio skip_start skip_end
160
  inferMap_V3 = {
161
  "2.1": V210.infer,
@@ -180,6 +193,25 @@ def infer(
180
  version = hps.version if hasattr(hps, "version") else latest_version
181
  # 非当前版本,根据版本号选择合适的infer
182
  if version != latest_version:
 
 
183
  if version in inferMap_V3.keys():
184
  return inferMap_V3[version](
185
  text,
@@ -196,6 +228,8 @@ def infer(
196
  emotion,
197
  skip_start,
198
  skip_end,
 
 
199
  )
200
  if version in inferMap_V2.keys():
201
  return inferMap_V2[version](
@@ -224,14 +258,19 @@ def infer(
224
  )
225
  # 在此处实现当前版本的推理
226
  # emo = get_emo_(reference_audio, emotion, sid)
227
- if isinstance(reference_audio, np.ndarray):
228
- emo = get_clap_audio_feature(reference_audio, device)
229
- else:
230
- emo = get_clap_text_feature(emotion, device)
231
- emo = torch.squeeze(emo, dim=1)
232
 
233
  bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
234
- text, language, hps, device
 
 
 
 
 
235
  )
236
  if skip_start:
237
  phones = phones[3:]
@@ -255,7 +294,7 @@ def infer(
255
  ja_bert = ja_bert.to(device).unsqueeze(0)
256
  en_bert = en_bert.to(device).unsqueeze(0)
257
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
258
- emo = emo.to(device).unsqueeze(0)
259
  del phones
260
  speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
261
  audio = (
@@ -268,7 +307,6 @@ def infer(
268
  bert,
269
  ja_bert,
270
  en_bert,
271
- emo,
272
  sdp_ratio=sdp_ratio,
273
  noise_scale=noise_scale,
274
  noise_scale_w=noise_scale_w,
@@ -278,7 +316,16 @@ def infer(
278
  .float()
279
  .numpy()
280
  )
281
- del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
 
 
 
 
 
 
 
 
 
282
  if torch.cuda.is_available():
283
  torch.cuda.empty_cache()
284
  return audio
@@ -302,14 +349,14 @@ def infer_multilang(
302
  ):
303
  bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
304
  # emo = get_emo_(reference_audio, emotion, sid)
305
- if isinstance(reference_audio, np.ndarray):
306
- emo = get_clap_audio_feature(reference_audio, device)
307
- else:
308
- emo = get_clap_text_feature(emotion, device)
309
- emo = torch.squeeze(emo, dim=1)
310
  for idx, (txt, lang) in enumerate(zip(text, language)):
311
- skip_start = (idx != 0) or (skip_start and idx == 0)
312
- skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
313
  (
314
  temp_bert,
315
  temp_ja_bert,
@@ -318,14 +365,14 @@ def infer_multilang(
318
  temp_tones,
319
  temp_lang_ids,
320
  ) = get_text(txt, lang, hps, device)
321
- if skip_start:
322
  temp_bert = temp_bert[:, 3:]
323
  temp_ja_bert = temp_ja_bert[:, 3:]
324
  temp_en_bert = temp_en_bert[:, 3:]
325
  temp_phones = temp_phones[3:]
326
  temp_tones = temp_tones[3:]
327
  temp_lang_ids = temp_lang_ids[3:]
328
- if skip_end:
329
  temp_bert = temp_bert[:, :-2]
330
  temp_ja_bert = temp_ja_bert[:, :-2]
331
  temp_en_bert = temp_en_bert[:, :-2]
@@ -351,7 +398,7 @@ def infer_multilang(
351
  bert = bert.to(device).unsqueeze(0)
352
  ja_bert = ja_bert.to(device).unsqueeze(0)
353
  en_bert = en_bert.to(device).unsqueeze(0)
354
- emo = emo.to(device).unsqueeze(0)
355
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
356
  del phones
357
  speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
@@ -365,7 +412,6 @@ def infer_multilang(
365
  bert,
366
  ja_bert,
367
  en_bert,
368
- emo,
369
  sdp_ratio=sdp_ratio,
370
  noise_scale=noise_scale,
371
  noise_scale_w=noise_scale_w,
@@ -375,7 +421,16 @@ def infer_multilang(
375
  .float()
376
  .numpy()
377
  )
378
- del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
 
 
 
 
 
 
 
 
 
379
  if torch.cuda.is_available():
380
  torch.cuda.empty_cache()
381
  return audio
 
5
  2. 请在模型的config.json中显示声明版本号,添加一个字段"version" : "你的版本号"
6
  特殊版本说明:
7
  1.1.1-fix: 1.1.1版本训练的模型,但是在推理时使用dev的日语修复
8
+ 2.3:当前版本
9
  """
10
  import torch
11
  import commons
12
  from text import cleaned_text_to_sequence, get_bert
13
+
14
+ # from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
15
+ from typing import Union
16
  from text.cleaner import clean_text
17
  import utils
 
18
 
19
  from models import SynthesizerTrn
20
  from text.symbols import symbols
21
 
22
+ from oldVersion.V220.models import SynthesizerTrn as V220SynthesizerTrn
23
+ from oldVersion.V220.text import symbols as V220symbols
24
  from oldVersion.V210.models import SynthesizerTrn as V210SynthesizerTrn
25
  from oldVersion.V210.text import symbols as V210symbols
26
  from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
 
32
  from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
33
  from oldVersion.V101.text import symbols as V101symbols
34
 
35
+ from oldVersion import V111, V110, V101, V200, V210, V220
36
 
37
  # 当前版本信息
38
+ latest_version = "2.3"
39
 
40
  # 版本兼容
41
  SynthesizerTrnMap = {
42
+ "2.2": V220SynthesizerTrn,
43
  "2.1": V210SynthesizerTrn,
44
  "2.0.2-fix": V200SynthesizerTrn,
45
  "2.0.1": V200SynthesizerTrn,
 
54
  }
55
 
56
  symbolsMap = {
57
+ "2.2": V220symbols,
58
  "2.1": V210symbols,
59
  "2.0.2-fix": V200symbols,
60
  "2.0.1": V200symbols,
 
103
  return net_g
104
 
105
 
106
+ def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
107
+ style_text = None if style_text == "" else style_text
108
  # 在此处实现当前版本的get_text
109
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
110
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
 
116
  for i in range(len(word2ph)):
117
  word2ph[i] = word2ph[i] * 2
118
  word2ph[0] += 1
119
+ bert_ori = get_bert(
120
+ norm_text, word2ph, language_str, device, style_text, style_weight
121
+ )
122
  del word2ph
123
  assert bert_ori.shape[-1] == len(phone), phone
124
 
125
  if language_str == "ZH":
126
  bert = bert_ori
127
+ ja_bert = torch.randn(1024, len(phone))
128
+ en_bert = torch.randn(1024, len(phone))
129
  elif language_str == "JP":
130
+ bert = torch.randn(1024, len(phone))
131
  ja_bert = bert_ori
132
+ en_bert = torch.randn(1024, len(phone))
133
  elif language_str == "EN":
134
+ bert = torch.randn(1024, len(phone))
135
+ ja_bert = torch.randn(1024, len(phone))
136
  en_bert = bert_ori
137
  else:
138
  raise ValueError("language_str should be ZH, JP or EN")
 
149
 
150
  def infer(
151
  text,
152
+ emotion: Union[int, str],
153
  sdp_ratio,
154
  noise_scale,
155
  noise_scale_w,
 
162
  reference_audio=None,
163
  skip_start=False,
164
  skip_end=False,
165
+ style_text=None,
166
+ style_weight=0.7,
167
  ):
168
  # 2.2版本参数位置变了
169
+ inferMap_V4 = {
170
+ "2.2": V220.infer,
171
+ }
172
  # 2.1 参数新增 emotion reference_audio skip_start skip_end
173
  inferMap_V3 = {
174
  "2.1": V210.infer,
 
193
  version = hps.version if hasattr(hps, "version") else latest_version
194
  # 非当前版本,根据版本号选择合适的infer
195
  if version != latest_version:
196
+ if version in inferMap_V4.keys():
197
+ return inferMap_V4[version](
198
+ text,
199
+ emotion,
200
+ sdp_ratio,
201
+ noise_scale,
202
+ noise_scale_w,
203
+ length_scale,
204
+ sid,
205
+ language,
206
+ hps,
207
+ net_g,
208
+ device,
209
+ reference_audio,
210
+ skip_start,
211
+ skip_end,
212
+ style_text,
213
+ style_weight,
214
+ )
215
  if version in inferMap_V3.keys():
216
  return inferMap_V3[version](
217
  text,
 
228
  emotion,
229
  skip_start,
230
  skip_end,
231
+ style_text,
232
+ style_weight,
233
  )
234
  if version in inferMap_V2.keys():
235
  return inferMap_V2[version](
 
258
  )
259
  # 在此处实现当前版本的推理
260
  # emo = get_emo_(reference_audio, emotion, sid)
261
+ # if isinstance(reference_audio, np.ndarray):
262
+ # emo = get_clap_audio_feature(reference_audio, device)
263
+ # else:
264
+ # emo = get_clap_text_feature(emotion, device)
265
+ # emo = torch.squeeze(emo, dim=1)
266
 
267
  bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
268
+ text,
269
+ language,
270
+ hps,
271
+ device,
272
+ style_text=style_text,
273
+ style_weight=style_weight,
274
  )
275
  if skip_start:
276
  phones = phones[3:]
 
294
  ja_bert = ja_bert.to(device).unsqueeze(0)
295
  en_bert = en_bert.to(device).unsqueeze(0)
296
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
297
+ # emo = emo.to(device).unsqueeze(0)
298
  del phones
299
  speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
300
  audio = (
 
307
  bert,
308
  ja_bert,
309
  en_bert,
 
310
  sdp_ratio=sdp_ratio,
311
  noise_scale=noise_scale,
312
  noise_scale_w=noise_scale_w,
 
316
  .float()
317
  .numpy()
318
  )
319
+ del (
320
+ x_tst,
321
+ tones,
322
+ lang_ids,
323
+ bert,
324
+ x_tst_lengths,
325
+ speakers,
326
+ ja_bert,
327
+ en_bert,
328
+ ) # , emo
329
  if torch.cuda.is_available():
330
  torch.cuda.empty_cache()
331
  return audio
 
349
  ):
350
  bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
351
  # emo = get_emo_(reference_audio, emotion, sid)
352
+ # if isinstance(reference_audio, np.ndarray):
353
+ # emo = get_clap_audio_feature(reference_audio, device)
354
+ # else:
355
+ # emo = get_clap_text_feature(emotion, device)
356
+ # emo = torch.squeeze(emo, dim=1)
357
  for idx, (txt, lang) in enumerate(zip(text, language)):
358
+ _skip_start = (idx != 0) or (skip_start and idx == 0)
359
+ _skip_end = (idx != len(language) - 1) or skip_end
360
  (
361
  temp_bert,
362
  temp_ja_bert,
 
365
  temp_tones,
366
  temp_lang_ids,
367
  ) = get_text(txt, lang, hps, device)
368
+ if _skip_start:
369
  temp_bert = temp_bert[:, 3:]
370
  temp_ja_bert = temp_ja_bert[:, 3:]
371
  temp_en_bert = temp_en_bert[:, 3:]
372
  temp_phones = temp_phones[3:]
373
  temp_tones = temp_tones[3:]
374
  temp_lang_ids = temp_lang_ids[3:]
375
+ if _skip_end:
376
  temp_bert = temp_bert[:, :-2]
377
  temp_ja_bert = temp_ja_bert[:, :-2]
378
  temp_en_bert = temp_en_bert[:, :-2]
 
398
  bert = bert.to(device).unsqueeze(0)
399
  ja_bert = ja_bert.to(device).unsqueeze(0)
400
  en_bert = en_bert.to(device).unsqueeze(0)
401
+ # emo = emo.to(device).unsqueeze(0)
402
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
403
  del phones
404
  speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
 
412
  bert,
413
  ja_bert,
414
  en_bert,
 
415
  sdp_ratio=sdp_ratio,
416
  noise_scale=noise_scale,
417
  noise_scale_w=noise_scale_w,
 
421
  .float()
422
  .numpy()
423
  )
424
+ del (
425
+ x_tst,
426
+ tones,
427
+ lang_ids,
428
+ bert,
429
+ x_tst_lengths,
430
+ speakers,
431
+ ja_bert,
432
+ en_bert,
433
+ ) # , emo
434
  if torch.cuda.is_available():
435
  torch.cuda.empty_cache()
436
  return audio
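A hedged sketch of a direct call into the 2.3 infer() path with the new style_text/style_weight arguments; the config/checkpoint paths and the speaker name are placeholders, and the loading pattern mirrors hiyoriUI.py.

# Sketch only: load a model and synthesize one utterance with a style prompt.
import utils
from infer import get_net_g, infer, latest_version

hps = utils.get_hparams_from_file("Data/mymodel/config.json")           # placeholder path
version = getattr(hps, "version", latest_version)
net_g = get_net_g("Data/mymodel/models/G_8000.pth", version, "cuda", hps)

audio = infer(
    "今天天气真好。",
    emotion=None,              # emotion/reference_audio are ignored on the 2.3 path
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="YourSpeaker",         # placeholder; must be a key of hps.data.spk2id
    language="ZH",
    hps=hps,
    net_g=net_g,
    device="cuda",
    style_text="开心地说",      # new in 2.3: optional BERT style prompt
    style_weight=0.7,          # new in 2.3: blend weight for the style prompt
)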
losses.py CHANGED
@@ -1,4 +1,6 @@
1
  import torch
 
 
2
 
3
 
4
  def feature_loss(fmap_r, fmap_g):
@@ -56,3 +58,96 @@ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
56
  kl = torch.sum(kl * z_mask)
57
  l = kl / torch.sum(z_mask)
58
  return l
 
 
1
  import torch
2
+ import torchaudio
3
+ from transformers import AutoModel
4
 
5
 
6
  def feature_loss(fmap_r, fmap_g):
 
58
  kl = torch.sum(kl * z_mask)
59
  l = kl / torch.sum(z_mask)
60
  return l
61
+
62
+
63
+ class WavLMLoss(torch.nn.Module):
64
+ def __init__(self, model, wd, model_sr, slm_sr=16000):
65
+ super(WavLMLoss, self).__init__()
66
+ self.wavlm = AutoModel.from_pretrained(model)
67
+ self.wd = wd
68
+ self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
69
+ self.wavlm.eval()
70
+ for param in self.wavlm.parameters():
71
+ param.requires_grad = False
72
+
73
+ def forward(self, wav, y_rec):
74
+ with torch.no_grad():
75
+ wav_16 = self.resample(wav)
76
+ wav_embeddings = self.wavlm(
77
+ input_values=wav_16, output_hidden_states=True
78
+ ).hidden_states
79
+ y_rec_16 = self.resample(y_rec)
80
+ y_rec_embeddings = self.wavlm(
81
+ input_values=y_rec_16.squeeze(), output_hidden_states=True
82
+ ).hidden_states
83
+
84
+ floss = 0
85
+ for er, eg in zip(wav_embeddings, y_rec_embeddings):
86
+ floss += torch.mean(torch.abs(er - eg))
87
+
88
+ return floss.mean()
89
+
90
+ def generator(self, y_rec):
91
+ y_rec_16 = self.resample(y_rec)
92
+ y_rec_embeddings = self.wavlm(
93
+ input_values=y_rec_16, output_hidden_states=True
94
+ ).hidden_states
95
+ y_rec_embeddings = (
96
+ torch.stack(y_rec_embeddings, dim=1)
97
+ .transpose(-1, -2)
98
+ .flatten(start_dim=1, end_dim=2)
99
+ )
100
+ y_df_hat_g = self.wd(y_rec_embeddings)
101
+ loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
102
+
103
+ return loss_gen
104
+
105
+ def discriminator(self, wav, y_rec):
106
+ with torch.no_grad():
107
+ wav_16 = self.resample(wav)
108
+ wav_embeddings = self.wavlm(
109
+ input_values=wav_16, output_hidden_states=True
110
+ ).hidden_states
111
+ y_rec_16 = self.resample(y_rec)
112
+ y_rec_embeddings = self.wavlm(
113
+ input_values=y_rec_16, output_hidden_states=True
114
+ ).hidden_states
115
+
116
+ y_embeddings = (
117
+ torch.stack(wav_embeddings, dim=1)
118
+ .transpose(-1, -2)
119
+ .flatten(start_dim=1, end_dim=2)
120
+ )
121
+ y_rec_embeddings = (
122
+ torch.stack(y_rec_embeddings, dim=1)
123
+ .transpose(-1, -2)
124
+ .flatten(start_dim=1, end_dim=2)
125
+ )
126
+
127
+ y_d_rs = self.wd(y_embeddings)
128
+ y_d_gs = self.wd(y_rec_embeddings)
129
+
130
+ y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
131
+
132
+ r_loss = torch.mean((1 - y_df_hat_r) ** 2)
133
+ g_loss = torch.mean((y_df_hat_g) ** 2)
134
+
135
+ loss_disc_f = r_loss + g_loss
136
+
137
+ return loss_disc_f.mean()
138
+
139
+ def discriminator_forward(self, wav):
140
+ with torch.no_grad():
141
+ wav_16 = self.resample(wav)
142
+ wav_embeddings = self.wavlm(
143
+ input_values=wav_16, output_hidden_states=True
144
+ ).hidden_states
145
+ y_embeddings = (
146
+ torch.stack(wav_embeddings, dim=1)
147
+ .transpose(-1, -2)
148
+ .flatten(start_dim=1, end_dim=2)
149
+ )
150
+
151
+ y_d_rs = self.wd(y_embeddings)
152
+
153
+ return y_d_rs
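A hedged sketch of how the new WavLMLoss might be exercised in isolation. The checkpoint name, sample rate, and the Conv1d stand-in for the SLM discriminator wd are assumptions; in the repository wd is expected to be the discriminator constructed during training, which is not shown in this file.

# Sketch only: instantiate WavLMLoss and compute its terms on random audio.
import torch
from losses import WavLMLoss

# wavlm-base-plus exposes 13 hidden states of width 768, so the stacked
# features have 13 * 768 channels; a 1x1 Conv1d is a minimal stand-in for wd.
wd = torch.nn.Conv1d(13 * 768, 1, kernel_size=1)
wl = WavLMLoss("microsoft/wavlm-base-plus", wd, model_sr=44100, slm_sr=16000)

wav = torch.randn(2, 44100)        # ground-truth audio at the model sample rate
y_rec = torch.randn(2, 1, 44100)   # generator output (batch, 1, samples)
loss_slm = wl(wav, y_rec)                  # L1 gap between WavLM hidden states
loss_gen = wl.generator(y_rec.squeeze(1))  # adversarial term on the SLM features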
models.py CHANGED
@@ -14,8 +14,6 @@ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14
  from commons import init_weights, get_padding
15
  from text import symbols, num_tones, num_languages
16
 
17
- from vector_quantize_pytorch import VectorQuantize
18
-
19
 
20
  class DurationDiscriminator(nn.Module): # vits2
21
  def __init__(
@@ -40,33 +38,22 @@ class DurationDiscriminator(nn.Module): # vits2
40
  self.norm_2 = modules.LayerNorm(filter_channels)
41
  self.dur_proj = nn.Conv1d(1, filter_channels, 1)
42
 
43
- self.pre_out_conv_1 = nn.Conv1d(
44
- 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
45
- )
46
- self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
47
- self.pre_out_conv_2 = nn.Conv1d(
48
- filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
49
  )
50
- self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
51
 
52
  if gin_channels != 0:
53
  self.cond = nn.Conv1d(gin_channels, in_channels, 1)
54
 
55
- self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
 
 
56
 
57
- def forward_probability(self, x, x_mask, dur, g=None):
58
  dur = self.dur_proj(dur)
59
  x = torch.cat([x, dur], dim=1)
60
- x = self.pre_out_conv_1(x * x_mask)
61
- x = torch.relu(x)
62
- x = self.pre_out_norm_1(x)
63
- x = self.drop(x)
64
- x = self.pre_out_conv_2(x * x_mask)
65
- x = torch.relu(x)
66
- x = self.pre_out_norm_2(x)
67
- x = self.drop(x)
68
- x = x * x_mask
69
  x = x.transpose(1, 2)
 
70
  output_prob = self.output_layer(x)
71
  return output_prob
72
 
@@ -86,7 +73,7 @@ class DurationDiscriminator(nn.Module): # vits2
86
 
87
  output_probs = []
88
  for dur in [dur_r, dur_hat]:
89
- output_prob = self.forward_probability(x, x_mask, dur, g)
90
  output_probs.append(output_prob)
91
 
92
  return output_probs
@@ -354,7 +341,6 @@ class TextEncoder(nn.Module):
354
  n_layers,
355
  kernel_size,
356
  p_dropout,
357
- n_speakers,
358
  gin_channels=0,
359
  ):
360
  super().__init__()
@@ -376,31 +362,6 @@ class TextEncoder(nn.Module):
376
  self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
377
  self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
378
  self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
379
- # self.emo_proj = nn.Linear(512, hidden_channels)
380
- self.in_feature_net = nn.Sequential(
381
- # input is assumed to an already normalized embedding
382
- nn.Linear(512, 1028, bias=False),
383
- nn.GELU(),
384
- nn.LayerNorm(1028),
385
- *[Block(1028, 512) for _ in range(1)],
386
- nn.Linear(1028, 512, bias=False),
387
- # normalize before passing to VQ?
388
- # nn.GELU(),
389
- # nn.LayerNorm(512),
390
- )
391
- self.emo_vq = VectorQuantize(
392
- dim=512,
393
- codebook_size=64,
394
- codebook_dim=32,
395
- commitment_weight=0.1,
396
- decay=0.85,
397
- heads=32,
398
- kmeans_iters=20,
399
- separate_codebook_per_head=True,
400
- stochastic_sample_codes=True,
401
- threshold_ema_dead_code=2,
402
- )
403
- self.out_feature_net = nn.Linear(512, hidden_channels)
404
 
405
  self.encoder = attentions.Encoder(
406
  hidden_channels,
@@ -413,18 +374,10 @@ class TextEncoder(nn.Module):
413
  )
414
  self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
415
 
416
- def forward(
417
- self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=None
418
- ):
419
- sid = sid.cpu()
420
  bert_emb = self.bert_proj(bert).transpose(1, 2)
421
  ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
422
  en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
423
- emo_emb = self.in_feature_net(emo)
424
- emo_emb, _, loss_commit = self.emo_vq(emo_emb.unsqueeze(1))
425
- loss_commit = loss_commit.mean()
426
- emo_emb = self.out_feature_net(emo_emb)
427
- # emo_emb = self.emo_proj(emo.unsqueeze(1))
428
  x = (
429
  self.emb(x)
430
  + self.tone_emb(tone)
@@ -432,7 +385,6 @@ class TextEncoder(nn.Module):
432
  + bert_emb
433
  + ja_bert_emb
434
  + en_bert_emb
435
- + emo_emb
436
  ) * math.sqrt(
437
  self.hidden_channels
438
  ) # [b, t, h]
@@ -445,7 +397,7 @@ class TextEncoder(nn.Module):
445
  stats = self.proj(x) * x_mask
446
 
447
  m, logs = torch.split(stats, self.out_channels, dim=1)
448
- return x, m, logs, x_mask, loss_commit
449
 
450
 
451
  class ResidualCouplingBlock(nn.Module):
@@ -748,6 +700,55 @@ class MultiPeriodDiscriminator(torch.nn.Module):
748
  return y_d_rs, y_d_gs, fmap_rs, fmap_gs
749
 
750
 
 
 
 
 
751
  class ReferenceEncoder(nn.Module):
752
  """
753
  inputs --- [N, Ty/r, n_mels*r] mels
@@ -878,7 +879,6 @@ class SynthesizerTrn(nn.Module):
878
  n_layers,
879
  kernel_size,
880
  p_dropout,
881
- self.n_speakers,
882
  gin_channels=self.enc_gin_channels,
883
  )
884
  self.dec = Generator(
@@ -946,14 +946,13 @@ class SynthesizerTrn(nn.Module):
946
  bert,
947
  ja_bert,
948
  en_bert,
949
- emo=None,
950
  ):
951
  if self.n_speakers > 0:
952
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
953
  else:
954
  g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
955
- x, m_p, logs_p, x_mask, loss_commit = self.enc_p(
956
- x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
957
  )
958
  z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
959
  z_p = self.flow(z, y_mask, g=g)
@@ -996,9 +995,11 @@ class SynthesizerTrn(nn.Module):
996
 
997
  logw_ = torch.log(w + 1e-6) * x_mask
998
  logw = self.dp(x, x_mask, g=g)
 
999
  l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
1000
  x_mask
1001
  ) # for averaging
 
1002
 
1003
  l_length = l_length_dp + l_length_sdp
1004
 
@@ -1018,9 +1019,8 @@ class SynthesizerTrn(nn.Module):
1018
  x_mask,
1019
  y_mask,
1020
  (z, z_p, m_p, logs_p, m_q, logs_q),
1021
- (x, logw, logw_),
1022
  g,
1023
- loss_commit,
1024
  )
1025
 
1026
  def infer(
@@ -1033,7 +1033,6 @@ class SynthesizerTrn(nn.Module):
1033
  bert,
1034
  ja_bert,
1035
  en_bert,
1036
- emo=None,
1037
  noise_scale=0.667,
1038
  length_scale=1,
1039
  noise_scale_w=0.8,
@@ -1047,8 +1046,8 @@ class SynthesizerTrn(nn.Module):
1047
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1048
  else:
1049
  g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1050
- x, m_p, logs_p, x_mask, _ = self.enc_p(
1051
- x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
1052
  )
1053
  logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1054
  sdp_ratio
 
14
  from commons import init_weights, get_padding
15
  from text import symbols, num_tones, num_languages
16
 
 
 
17
 
18
  class DurationDiscriminator(nn.Module): # vits2
19
  def __init__(
 
38
  self.norm_2 = modules.LayerNorm(filter_channels)
39
  self.dur_proj = nn.Conv1d(1, filter_channels, 1)
40
 
41
+ self.LSTM = nn.LSTM(
42
+ 2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
 
 
 
 
43
  )
 
44
 
45
  if gin_channels != 0:
46
  self.cond = nn.Conv1d(gin_channels, in_channels, 1)
47
 
48
+ self.output_layer = nn.Sequential(
49
+ nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
50
+ )
51
 
52
+ def forward_probability(self, x, dur):
53
  dur = self.dur_proj(dur)
54
  x = torch.cat([x, dur], dim=1)
 
 
 
 
55
  x = x.transpose(1, 2)
56
+ x, _ = self.LSTM(x)
57
  output_prob = self.output_layer(x)
58
  return output_prob
59
 
 
73
 
74
  output_probs = []
75
  for dur in [dur_r, dur_hat]:
76
+ output_prob = self.forward_probability(x, dur)
77
  output_probs.append(output_prob)
78
 
79
  return output_probs
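The rewritten duration discriminator drops the conv/LayerNorm pre-out stack in favour of a bidirectional LSTM feeding a Linear+Sigmoid head. A minimal shape check of the new forward_probability path (a standalone sketch with made-up sizes, not part of the repo):

import torch
import torch.nn as nn

B, C, T = 2, 192, 37                          # batch, filter_channels, frames (example values)
x = torch.randn(B, C, T)                      # encoder hidden states
dur = torch.randn(B, 1, T)                    # ground-truth or predicted durations

dur_proj = nn.Conv1d(1, C, 1)                 # mirrors self.dur_proj
lstm = nn.LSTM(2 * C, C, batch_first=True, bidirectional=True)   # mirrors self.LSTM
head = nn.Sequential(nn.Linear(2 * C, 1), nn.Sigmoid())          # mirrors self.output_layer

h = torch.cat([x, dur_proj(dur)], dim=1)      # [B, 2C, T]
h = h.transpose(1, 2)                         # [B, T, 2C], batch_first layout
h, _ = lstm(h)                                # [B, T, 2C] (bidirectional doubles the hidden size)
print(head(h).shape)                          # torch.Size([2, 37, 1]): per-frame real/fake probability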
 
341
  n_layers,
342
  kernel_size,
343
  p_dropout,
 
344
  gin_channels=0,
345
  ):
346
  super().__init__()
 
362
  self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
363
  self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
364
  self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
 
 
 
 
365
 
366
  self.encoder = attentions.Encoder(
367
  hidden_channels,
 
374
  )
375
  self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
376
 
377
+ def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
 
 
 
378
  bert_emb = self.bert_proj(bert).transpose(1, 2)
379
  ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
380
  en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
 
 
 
 
 
381
  x = (
382
  self.emb(x)
383
  + self.tone_emb(tone)
 
385
  + bert_emb
386
  + ja_bert_emb
387
  + en_bert_emb
 
388
  ) * math.sqrt(
389
  self.hidden_channels
390
  ) # [b, t, h]
 
397
  stats = self.proj(x) * x_mask
398
 
399
  m, logs = torch.split(stats, self.out_channels, dim=1)
400
+ return x, m, logs, x_mask
401
 
402
 
403
  class ResidualCouplingBlock(nn.Module):
 
700
  return y_d_rs, y_d_gs, fmap_rs, fmap_gs
701
 
702
 
703
+ class WavLMDiscriminator(nn.Module):
704
+ """docstring for Discriminator."""
705
+
706
+ def __init__(
707
+ self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
708
+ ):
709
+ super(WavLMDiscriminator, self).__init__()
710
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
711
+ self.pre = norm_f(
712
+ Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
713
+ )
714
+
715
+ self.convs = nn.ModuleList(
716
+ [
717
+ norm_f(
718
+ nn.Conv1d(
719
+ initial_channel, initial_channel * 2, kernel_size=5, padding=2
720
+ )
721
+ ),
722
+ norm_f(
723
+ nn.Conv1d(
724
+ initial_channel * 2,
725
+ initial_channel * 4,
726
+ kernel_size=5,
727
+ padding=2,
728
+ )
729
+ ),
730
+ norm_f(
731
+ nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
732
+ ),
733
+ ]
734
+ )
735
+
736
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
737
+
738
+ def forward(self, x):
739
+ x = self.pre(x)
740
+
741
+ fmap = []
742
+ for l in self.convs:
743
+ x = l(x)
744
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
745
+ fmap.append(x)
746
+ x = self.conv_post(x)
747
+ x = torch.flatten(x, 1, -1)
748
+
749
+ return x
750
+
751
+
752
  class ReferenceEncoder(nn.Module):
753
  """
754
  inputs --- [N, Ty/r, n_mels*r] mels
 
879
  n_layers,
880
  kernel_size,
881
  p_dropout,
 
882
  gin_channels=self.enc_gin_channels,
883
  )
884
  self.dec = Generator(
 
946
  bert,
947
  ja_bert,
948
  en_bert,
 
949
  ):
950
  if self.n_speakers > 0:
951
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
952
  else:
953
  g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
954
+ x, m_p, logs_p, x_mask = self.enc_p(
955
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
956
  )
957
  z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
958
  z_p = self.flow(z, y_mask, g=g)
 
995
 
996
  logw_ = torch.log(w + 1e-6) * x_mask
997
  logw = self.dp(x, x_mask, g=g)
998
+ logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
999
  l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
1000
  x_mask
1001
  ) # for averaging
1002
+ l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
1003
 
1004
  l_length = l_length_dp + l_length_sdp
1005
 
 
1019
  x_mask,
1020
  y_mask,
1021
  (z, z_p, m_p, logs_p, m_q, logs_q),
1022
+ (x, logw, logw_, logw_sdp),
1023
  g,
 
1024
  )
1025
 
1026
  def infer(
 
1033
  bert,
1034
  ja_bert,
1035
  en_bert,
 
1036
  noise_scale=0.667,
1037
  length_scale=1,
1038
  noise_scale_w=0.8,
 
1046
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1047
  else:
1048
  g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1049
+ x, m_p, logs_p, x_mask = self.enc_p(
1050
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
1051
  )
1052
  logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1053
  sdp_ratio
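The hunk at the top of this diff (y_embeddings ... .transpose(-1, -2).flatten(start_dim=1, end_dim=2) followed by self.wd(y_embeddings)) implies that the new WavLMDiscriminator consumes a stack of SLM hidden states flattened to [B, slm_layers * slm_hidden, T]. A rough sketch of driving it with the default sizes (13 layers of 768-dim features, matching WavLM-base; random tensors stand in for real WavLM outputs):

import torch
from models import WavLMDiscriminator   # class added in this commit

B, layers, hidden, T = 2, 13, 768, 200
wd = WavLMDiscriminator(slm_hidden=hidden, slm_layers=layers, initial_channel=64)

# Pretend these are WavLM hidden states: `layers` tensors of shape [B, T, hidden]
hidden_states = [torch.randn(B, T, hidden) for _ in range(layers)]

# [B, layers, T, hidden] -> [B, layers, hidden, T] -> [B, layers*hidden, T]
y_emb = torch.stack(hidden_states, dim=1).transpose(-1, -2).flatten(1, 2)

scores = wd(y_emb)                       # the convs keep stride 1, so the flattened output is [B, T]
print(scores.shape)                      # torch.Size([2, 200])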
modules.py CHANGED
@@ -83,7 +83,7 @@ class ConvReluNorm(nn.Module):
83
 
84
  class DDSConv(nn.Module):
85
  """
86
- Dialted and Depth-Separable Convolution
87
  """
88
 
89
  def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
 
83
 
84
  class DDSConv(nn.Module):
85
  """
86
+ Dilated and Depth-Separable Convolution
87
  """
88
 
89
  def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
onnx_infer.py ADDED
@@ -0,0 +1,60 @@
 
 
1
+ from onnx_modules.V220_OnnxInference import OnnxInferenceSession
2
+ import numpy as np
3
+
4
+ Session = OnnxInferenceSession(
5
+ {
6
+ "enc": "onnx/BertVits2.2PT/BertVits2.2PT_enc_p.onnx",
7
+ "emb_g": "onnx/BertVits2.2PT/BertVits2.2PT_emb.onnx",
8
+ "dp": "onnx/BertVits2.2PT/BertVits2.2PT_dp.onnx",
9
+ "sdp": "onnx/BertVits2.2PT/BertVits2.2PT_sdp.onnx",
10
+ "flow": "onnx/BertVits2.2PT/BertVits2.2PT_flow.onnx",
11
+ "dec": "onnx/BertVits2.2PT/BertVits2.2PT_dec.onnx",
12
+ },
13
+ Providers=["CPUExecutionProvider"],
14
+ )
15
+
16
+ # The inputs here are the same as in the original pipeline; just call .numpy() on each preprocessed tensor before passing it in
17
+ x = np.array(
18
+ [
19
+ 0,
20
+ 97,
21
+ 0,
22
+ 8,
23
+ 0,
24
+ 78,
25
+ 0,
26
+ 8,
27
+ 0,
28
+ 76,
29
+ 0,
30
+ 37,
31
+ 0,
32
+ 40,
33
+ 0,
34
+ 97,
35
+ 0,
36
+ 8,
37
+ 0,
38
+ 23,
39
+ 0,
40
+ 8,
41
+ 0,
42
+ 74,
43
+ 0,
44
+ 26,
45
+ 0,
46
+ 104,
47
+ 0,
48
+ ]
49
+ )
50
+ tone = np.zeros_like(x)
51
+ language = np.zeros_like(x)
52
+ sid = np.array([0])
53
+ bert = np.random.randn(x.shape[0], 1024)
54
+ ja_bert = np.random.randn(x.shape[0], 1024)
55
+ en_bert = np.random.randn(x.shape[0], 1024)
56
+ emo = np.random.randn(512, 1)
57
+
58
+ audio = Session(x, tone, language, bert, ja_bert, en_bert, emo, sid)
59
+
60
+ print(audio)
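A possible follow-up to the new onnx_infer.py script, appended after the print(audio) call: writing the result to disk. This is a sketch that assumes the session returns a float waveform at the model's 44.1 kHz rate; the exact array layout may need an extra squeeze depending on the exported graph.

import numpy as np
from scipy.io.wavfile import write

wav = np.asarray(audio, dtype=np.float32).squeeze()            # drop batch/channel axes if present
wav_int16 = (np.clip(wav, -1.0, 1.0) * 32767).astype(np.int16) # float [-1, 1] -> 16-bit PCM
write("onnx_output.wav", 44100, wav_int16)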
re_matching.py CHANGED
@@ -44,7 +44,6 @@ def text_matching(text: str) -> list:
44
  result = []
45
  for speaker, dialogue in matches:
46
  result.append(extract_language_and_text_updated(speaker, dialogue))
47
- print(result)
48
  return result
49
 
50
 
 
44
  result = []
45
  for speaker, dialogue in matches:
46
  result.append(extract_language_and_text_updated(speaker, dialogue))
 
47
  return result
48
 
49
 
resample.py CHANGED
@@ -10,11 +10,11 @@ from config import config
10
 
11
 
12
  def process(item):
13
- wav_name, args = item
14
- wav_path = os.path.join(args.in_dir, wav_name)
15
  if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
16
  wav, sr = librosa.load(wav_path, sr=args.sr)
17
- soundfile.write(os.path.join(args.out_dir, wav_name), wav, sr)
18
 
19
 
20
  if __name__ == "__main__":
@@ -54,11 +54,15 @@ if __name__ == "__main__":
54
  tasks = []
55
 
56
  for dirpath, _, filenames in os.walk(args.in_dir):
57
- if not os.path.isdir(args.out_dir):
58
- os.makedirs(args.out_dir, exist_ok=True)
 
 
 
59
  for filename in filenames:
60
  if filename.lower().endswith(".wav"):
61
- tasks.append((filename, args))
 
62
 
63
  for _ in tqdm(
64
  pool.imap_unordered(process, tasks),
 
10
 
11
 
12
  def process(item):
13
+ spkdir, wav_name, args = item
14
+ wav_path = os.path.join(args.in_dir, spkdir, wav_name)
15
  if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
16
  wav, sr = librosa.load(wav_path, sr=args.sr)
17
+ soundfile.write(os.path.join(args.out_dir, spkdir, wav_name), wav, sr)
18
 
19
 
20
  if __name__ == "__main__":
 
54
  tasks = []
55
 
56
  for dirpath, _, filenames in os.walk(args.in_dir):
57
+ # per-speaker subdirectory
58
+ spk_dir = os.path.relpath(dirpath, args.in_dir)
59
+ spk_dir_out = os.path.join(args.out_dir, spk_dir)
60
+ if not os.path.isdir(spk_dir_out):
61
+ os.makedirs(spk_dir_out, exist_ok=True)
62
  for filename in filenames:
63
  if filename.lower().endswith(".wav"):
64
+ twople = (spk_dir, filename, args)
65
+ tasks.append(twople)
66
 
67
  for _ in tqdm(
68
  pool.imap_unordered(process, tasks),
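The resample.py change above switches from a flat file list to per-speaker subdirectories: each task is now (spk_dir, wav_name, args) and the output tree mirrors the input tree. A quick sketch of what the path handling produces (hypothetical directory names; same os.path calls as the diff, assuming the --in_dir/--out_dir flags are unchanged from resample_legacy.py):

import os

in_dir, out_dir = "raw", "wavs"                        # e.g. raw/香澄/0001.wav -> wavs/香澄/0001.wav
for dirpath, _, filenames in os.walk(in_dir):
    spk_dir = os.path.relpath(dirpath, in_dir)         # "香澄" (or "." at the top level)
    os.makedirs(os.path.join(out_dir, spk_dir), exist_ok=True)
    for filename in filenames:
        if filename.lower().endswith(".wav"):
            src = os.path.join(in_dir, spk_dir, filename)
            dst = os.path.join(out_dir, spk_dir, filename)
            print(src, "->", dst)                      # process() loads src with librosa and writes dst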
resample_legacy.py ADDED
@@ -0,0 +1,71 @@
 
 
1
+ import os
2
+ import argparse
3
+ import librosa
4
+ from multiprocessing import Pool, cpu_count
5
+
6
+ import soundfile
7
+ from tqdm import tqdm
8
+
9
+ from config import config
10
+
11
+
12
+ def process(item):
13
+ wav_name, args = item
14
+ wav_path = os.path.join(args.in_dir, wav_name)
15
+ if os.path.exists(wav_path) and wav_path.lower().endswith(".wav"):
16
+ wav, sr = librosa.load(wav_path, sr=args.sr)
17
+ soundfile.write(os.path.join(args.out_dir, wav_name), wav, sr)
18
+
19
+
20
+ if __name__ == "__main__":
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument(
23
+ "--sr",
24
+ type=int,
25
+ default=config.resample_config.sampling_rate,
26
+ help="sampling rate",
27
+ )
28
+ parser.add_argument(
29
+ "--in_dir",
30
+ type=str,
31
+ default=config.resample_config.in_dir,
32
+ help="path to source dir",
33
+ )
34
+ parser.add_argument(
35
+ "--out_dir",
36
+ type=str,
37
+ default=config.resample_config.out_dir,
38
+ help="path to target dir",
39
+ )
40
+ parser.add_argument(
41
+ "--processes",
42
+ type=int,
43
+ default=0,
44
+ help="cpu_processes",
45
+ )
46
+ args, _ = parser.parse_known_args()
47
+ # autodl's CPU-only mode reports 46 CPUs
48
+ if args.processes == 0:
49
+ processes = cpu_count() - 2 if cpu_count() > 4 else 1
50
+ else:
51
+ processes = args.processes
52
+ pool = Pool(processes=processes)
53
+
54
+ tasks = []
55
+
56
+ for dirpath, _, filenames in os.walk(args.in_dir):
57
+ if not os.path.isdir(args.out_dir):
58
+ os.makedirs(args.out_dir, exist_ok=True)
59
+ for filename in filenames:
60
+ if filename.lower().endswith(".wav"):
61
+ tasks.append((filename, args))
62
+
63
+ for _ in tqdm(
64
+ pool.imap_unordered(process, tasks),
65
+ ):
66
+ pass
67
+
68
+ pool.close()
69
+ pool.join()
70
+
71
+ print("音频重采样完毕!")
server.py CHANGED
@@ -3,10 +3,8 @@ import os
3
  from pathlib import Path
4
 
5
  import logging
6
- import re_matching
7
  import uuid
8
- from flask import Flask, request, jsonify, render_template_string
9
- from flask_cors import CORS
10
 
11
  logging.getLogger("numba").setLevel(logging.WARNING)
12
  logging.getLogger("markdown_it").setLevel(logging.WARNING)
@@ -18,6 +16,8 @@ logging.basicConfig(
18
  )
19
 
20
  logger = logging.getLogger(__name__)
 
 
21
  import librosa
22
  import numpy as np
23
  import torch
@@ -26,24 +26,44 @@ from torch.utils.data import Dataset
26
  from torch.utils.data import DataLoader, Dataset
27
  from tqdm import tqdm
28
 
 
 
29
  import utils
30
  from config import config
31
- import requests
32
  import torch
33
  import commons
34
  from text import cleaned_text_to_sequence, get_bert
35
- from clap_wrapper import get_clap_audio_feature, get_clap_text_feature
36
-
37
  from text.cleaner import clean_text
38
  import utils
39
 
40
  from models import SynthesizerTrn
41
  from text.symbols import symbols
42
  import sys
 
43
 
44
- from scipy.io.wavfile import write
 
45
 
 
 
 
 
46
  net_g = None
 
47
  device = (
48
  "cuda:0"
49
  if torch.cuda.is_available()
@@ -54,7 +74,375 @@ device = (
54
  )
55
  )
56
 
57
- #device = 'cpu'
 
 
 
 
58
 
59
  def get_net_g(model_path: str, device: str, hps):
60
  net_g = SynthesizerTrn(
@@ -68,11 +456,11 @@ def get_net_g(model_path: str, device: str, hps):
68
  _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
69
  return net_g
70
 
71
-
72
- def get_text(text, language_str, hps, device):
73
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
74
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
75
- #print(text)
76
  if hps.data.add_blank:
77
  phone = commons.intersperse(phone, 0)
78
  tone = commons.intersperse(tone, 0)
@@ -80,18 +468,24 @@ def get_text(text, language_str, hps, device):
80
  for i in range(len(word2ph)):
81
  word2ph[i] = word2ph[i] * 2
82
  word2ph[0] += 1
83
- bert_ori = get_bert(norm_text, word2ph, language_str, device)
 
 
84
  del word2ph
85
  assert bert_ori.shape[-1] == len(phone), phone
86
 
87
  if language_str == "ZH":
88
  bert = bert_ori
89
- ja_bert = torch.zeros(1024, len(phone))
90
- en_bert = torch.zeros(1024, len(phone))
91
  elif language_str == "JP":
92
- bert = torch.zeros(1024, len(phone))
93
  ja_bert = bert_ori
94
- en_bert = torch.zeros(1024, len(phone))
 
 
 
 
95
  else:
96
  raise ValueError("language_str should be ZH, JP or EN")
97
 
@@ -111,19 +505,47 @@ def infer(
111
  noise_scale_w,
112
  length_scale,
113
  sid,
114
- reference_audio=None,
115
- emotion='Happy',
 
 
 
 
116
  ):
117
-
118
- language= 'JP' if is_japanese(text) else 'ZH'
119
- if isinstance(reference_audio, np.ndarray):
120
- emo = get_clap_audio_feature(reference_audio, device)
121
- else:
122
- emo = get_clap_text_feature(emotion, device)
123
- emo = torch.squeeze(emo, dim=1)
 
 
 
 
 
124
  bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
125
- text, language, hps, device
 
 
 
 
 
126
  )
 
 
 
 
127
  with torch.no_grad():
128
  x_tst = phones.to(device).unsqueeze(0)
129
  tones = tones.to(device).unsqueeze(0)
@@ -132,7 +554,7 @@ def infer(
132
  ja_bert = ja_bert.to(device).unsqueeze(0)
133
  en_bert = en_bert.to(device).unsqueeze(0)
134
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
135
- emo = emo.to(device).unsqueeze(0)
136
  del phones
137
  speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
138
  audio = (
@@ -145,7 +567,6 @@ def infer(
145
  bert,
146
  ja_bert,
147
  en_bert,
148
- emo,
149
  sdp_ratio=sdp_ratio,
150
  noise_scale=noise_scale,
151
  noise_scale_w=noise_scale_w,
@@ -155,79 +576,292 @@ def infer(
155
  .float()
156
  .numpy()
157
  )
158
- del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
 
 
 
159
  if torch.cuda.is_available():
160
  torch.cuda.empty_cache()
161
- unique_filename = f"temp{uuid.uuid4()}.wav"
162
- write(unique_filename, 44100, audio)
163
- return unique_filename
164
-
165
- def is_japanese(string):
166
- for ch in string:
167
- if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
168
- return True
169
- return False
170
 
171
  def loadmodel(model):
172
- try:
173
- _ = net_g.eval()
174
- _ = utils.load_checkpoint(model, net_g, None, skip_optimizer=True)
175
- return "success"
176
- except:
177
- return "error"
178
-
179
- def send_audio_to_server(audio_path,text):
180
- url="http://127.0.0.1:3000/response"
181
- files = {'file': open(audio_path, 'rb')}
182
- data = {'text': text}
183
- try:
184
- response = requests.post(url, files=files,data=data)
185
- return response.status_code, response.text
186
- except Exception as e:
187
- return 500, str(e)
 
 
 
 
188
 
189
- app = Flask(__name__)
190
- CORS(app)
191
- @app.route('/')
 
 
 
 
192
 
193
  def tts():
194
- global last_text, last_model
195
- speaker = request.args.get('speaker')
196
- sdp_ratio = float(request.args.get('sdp_ratio', 0.2))
197
- noise_scale = float(request.args.get('noise_scale', 0.6))
198
- noise_scale_w = float(request.args.get('noise_scale_w', 0.8))
199
- length_scale = float(request.args.get('length_scale', 1))
200
- emotion = request.args.get('emotion', 'happy')
201
- text = request.args.get('text')
202
- is_chat = request.args.get('is_chat', 'false').lower() == 'true'
203
- model = request.args.get('model',modelPaths[-1])
204
-
205
- if not speaker or not text:
206
- return render_template_string("""
207
- <!DOCTYPE html>
208
- <html>
209
- <head>
210
- <title>TTS API Documentation</title>
211
- </head>
212
- <body>
213
- <iframe src="http://love.soyorin.top" style="width:100%; height:100vh; border:none;"></iframe>
214
- </body>
215
- </html>
216
- """)
217
-
218
- if model != last_model:
219
- unique_filename = loadmodel(model)
220
- last_model = model
221
- if is_chat and text == last_text:
222
- # Generate 1 second of silence and return
223
- unique_filename = 'blank.wav'
224
- silence = np.zeros(44100, dtype=np.int16)
225
- write(unique_filename , 44100, silence)
226
- else:
227
- last_text = text
228
- unique_filename = infer(text, sdp_ratio=sdp_ratio, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale,sid = speaker, reference_audio=None, emotion=emotion)
229
- status_code, response_text = send_audio_to_server(unique_filename,text)
230
- print(f"Response from server: {response_text} (Status code: {status_code})")
 
 
 
 
231
  with open(unique_filename ,'rb') as bit:
232
  wav_bytes = bit.read()
233
  os.remove(unique_filename)
@@ -238,17 +872,13 @@ def tts():
238
 
239
 
240
  if __name__ == "__main__":
241
- languages = [ "Auto", "ZH", "JP"]
242
- modelPaths = []
243
- for dirpath, dirnames, filenames in os.walk("Data/BangDreamV22/models/"):
244
- for filename in filenames:
245
- modelPaths.append(os.path.join(dirpath, filename))
246
- hps = utils.get_hparams_from_file('Data/BangDreamV22/configs/config.json')
247
  net_g = get_net_g(
248
  model_path=modelPaths[-1], device=device, hps=hps
249
  )
250
  speaker_ids = hps.data.spk2id
251
  speakers = list(speaker_ids.keys())
252
- last_text = ""
253
- last_model = modelPaths[-1]
254
- app.run(host="0.0.0.0", port=5000)
 
3
  from pathlib import Path
4
 
5
  import logging
 
6
  import uuid
7
+ import re_matching
 
8
 
9
  logging.getLogger("numba").setLevel(logging.WARNING)
10
  logging.getLogger("markdown_it").setLevel(logging.WARNING)
 
16
  )
17
 
18
  logger = logging.getLogger(__name__)
19
+ import shutil
20
+ from scipy.io.wavfile import write
21
  import librosa
22
  import numpy as np
23
  import torch
 
26
  from torch.utils.data import DataLoader, Dataset
27
  from tqdm import tqdm
28
 
29
+ import gradio as gr
30
+
31
  import utils
32
  from config import config
33
+
34
  import torch
35
  import commons
36
  from text import cleaned_text_to_sequence, get_bert
 
 
37
  from text.cleaner import clean_text
38
  import utils
39
 
40
  from models import SynthesizerTrn
41
  from text.symbols import symbols
42
  import sys
43
+ import re
44
 
45
+ import random
46
+ import hashlib
47
 
48
+ from fugashi import Tagger
49
+ import jaconv
50
+ import unidic
51
+ import subprocess
52
+
53
+ import requests
54
+
55
+ from ebooklib import epub
56
+ import PyPDF2
57
+ from PyPDF2 import PdfReader
58
+ from bs4 import BeautifulSoup
59
+ import jieba
60
+ import romajitable
61
+
62
+ from flask import Flask, request, jsonify, render_template_string, send_file
63
+ from flask_cors import CORS
64
+ from scipy.io.wavfile import write
65
  net_g = None
66
+
67
  device = (
68
  "cuda:0"
69
  if torch.cuda.is_available()
 
74
  )
75
  )
76
 
77
+ #device = "cpu"
78
+ BandList = {
79
+ "PoppinParty":["香澄","有咲","たえ","りみ","沙綾"],
80
+ "Afterglow":["蘭","モカ","ひまり","巴","つぐみ"],
81
+ "HelloHappyWorld":["こころ","美咲","薫","花音","はぐみ"],
82
+ "PastelPalettes":["彩","日菜","千聖","イヴ","麻弥"],
83
+ "Roselia":["友希那","紗夜","リサ","燐子","あこ"],
84
+ "RaiseASuilen":["レイヤ","ロック","ますき","チュチュ","パレオ"],
85
+ "Morfonica":["ましろ","瑠唯","つくし","七深","透子"],
86
+ "MyGo":["燈","愛音","そよ","立希","楽奈"],
87
+ "AveMujica":["祥子","睦","海鈴","にゃむ","初華"],
88
+ "圣翔音乐学园":["華戀","光","香子","雙葉","真晝","純那","克洛迪娜","真矢","奈奈"],
89
+ "凛明馆女子学校":["珠緒","壘","文","悠悠子","一愛"],
90
+ "弗隆提亚艺术学校":["艾露","艾露露","菈樂菲","司","靜羽"],
91
+ "西克菲尔特音乐学院":["晶","未知留","八千代","栞","美帆"]
92
+ }
93
+
94
+ webBase = 'https://mahiruoshi-bangdream-bert-vits2.hf.space/'
95
+
96
+ port = 8080
97
+
98
+ languages = [ "Auto", "ZH", "JP"]
99
+ modelPaths = []
100
+ modes = ['pyopenjtalk-V2.3-Katakana','fugashi-V2.3-Katakana','pyopenjtalk-V2.3-Katakana-Katakana','fugashi-V2.3-Katakana-Katakana','onnx-V2.3']
101
+ sentence_modes = ['sentence','paragraph']
102
+ for dirpath, dirnames, filenames in os.walk('Data/BangDream/models/'):
103
+ for filename in filenames:
104
+ modelPaths.append(os.path.join(dirpath, filename))
105
+ hps = utils.get_hparams_from_file('Data/BangDream/config.json')
106
+
107
+ def translate(Sentence: str, to_Language: str = "jp", from_Language: str = ""):
108
+ """
109
+ :param Sentence: 待翻译语句
110
+ :param from_Language: 待翻译语句语言
111
+ :param to_Language: 目标语言
112
+ :return: 翻译后语句 出错时返回None
113
+
114
+ 常见语言代码:中文 zh 英语 en 日语 jp
115
+ """
116
+ appid = "20231117001883321"
117
+ key = "lMQbvZHeJveDceLof2wf"
118
+ if appid == "" or key == "":
119
+ return "请开发者在config.yml中配置app_key与secret_key"
120
+ url = "https://fanyi-api.baidu.com/api/trans/vip/translate"
121
+ texts = Sentence.splitlines()
122
+ outTexts = []
123
+ for t in texts:
124
+ if t != "":
125
+ # Signature calculation, see https://api.fanyi.baidu.com/product/113
126
+ salt = str(random.randint(1, 100000))
127
+ signString = appid + t + salt + key
128
+ hs = hashlib.md5()
129
+ hs.update(signString.encode("utf-8"))
130
+ signString = hs.hexdigest()
131
+ if from_Language == "":
132
+ from_Language = "auto"
133
+ headers = {"Content-Type": "application/x-www-form-urlencoded"}
134
+ payload = {
135
+ "q": t,
136
+ "from": from_Language,
137
+ "to": to_Language,
138
+ "appid": appid,
139
+ "salt": salt,
140
+ "sign": signString,
141
+ }
142
+ # Send the request
143
+ try:
144
+ response = requests.post(
145
+ url=url, data=payload, headers=headers, timeout=3
146
+ )
147
+ response = response.json()
148
+ if "trans_result" in response.keys():
149
+ result = response["trans_result"][0]
150
+ if "dst" in result.keys():
151
+ dst = result["dst"]
152
+ outTexts.append(dst)
153
+ except Exception:
154
+ return Sentence
155
+ else:
156
+ outTexts.append(t)
157
+ return "\n".join(outTexts)
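A usage note for the translate() helper above (a sketch; it only returns real translations once valid Baidu appid/key values are configured, otherwise it will not produce a useful result):

# Chinese -> Japanese, source language auto-detected
print(translate("今天天气不错", to_Language="jp"))
# Japanese -> Chinese, source language given explicitly
print(translate("こんにちは、世界", to_Language="zh", from_Language="jp"))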
158
+
159
+ # Text-cleaning utilities
160
+ def is_japanese(string):
161
+ for ch in string:
162
+ if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
163
+ return True
164
+ return False
165
+
166
+ def is_chinese(string):
167
+ for ch in string:
168
+ if '\u4e00' <= ch <= '\u9fff':
169
+ return True
170
+ return False
171
+
172
+ def is_single_language(sentence):
173
+ # Check whether the sentence contains only one language
174
+ contains_chinese = re.search(r'[\u4e00-\u9fff]', sentence) is not None
175
+ contains_japanese = re.search(r'[\u3040-\u30ff\u31f0-\u31ff]', sentence) is not None
176
+ contains_english = re.search(r'[a-zA-Z]', sentence) is not None
177
+ language_count = sum([contains_chinese, contains_japanese, contains_english])
178
+ return language_count == 1
179
+
180
+ def merge_scattered_parts(sentences):
181
+ """合并零散的部分到相邻的句子中,并确保单一语言性"""
182
+ merged_sentences = []
183
+ buffer_sentence = ""
184
+
185
+ for sentence in sentences:
186
+ # 检查是否是单一语言或者太短(可能是标点或单个词)
187
+ if is_single_language(sentence) and len(sentence) > 1:
188
+ # 如果缓冲区有内容,先将缓冲区的内容添加到列表
189
+ if buffer_sentence:
190
+ merged_sentences.append(buffer_sentence)
191
+ buffer_sentence = ""
192
+ merged_sentences.append(sentence)
193
+ else:
194
+ # 如果是零散的部分,将其添加到缓冲区
195
+ buffer_sentence += sentence
196
+
197
+ # 确保最后的缓冲区内容被添加
198
+ if buffer_sentence:
199
+ merged_sentences.append(buffer_sentence)
200
+
201
+ return merged_sentences
202
+
203
+ def is_only_punctuation(s):
204
+ """Check whether a string contains only punctuation"""
205
+ # Common Chinese, Japanese and English punctuation marks
206
+ punctuation_pattern = re.compile(r'^[\s。*;,:“”()、!?《》\u3000\.,;:"\'?!()]+$')
207
+ return punctuation_pattern.match(s) is not None
208
+
209
+ def split_mixed_language(sentence):
210
+ # 分割混合语言句子
211
+ # 逐字符检查,分割不同语言部分
212
+ sub_sentences = []
213
+ current_language = None
214
+ current_part = ""
215
+
216
+ for char in sentence:
217
+ if re.match(r'[\u4e00-\u9fff]', char): # Chinese character
218
+ if current_language != 'chinese':
219
+ if current_part:
220
+ sub_sentences.append(current_part)
221
+ current_part = char
222
+ current_language = 'chinese'
223
+ else:
224
+ current_part += char
225
+ elif re.match(r'[\u3040-\u30ff\u31f0-\u31ff]', char): # Japanese character
226
+ if current_language != 'japanese':
227
+ if current_part:
228
+ sub_sentences.append(current_part)
229
+ current_part = char
230
+ current_language = 'japanese'
231
+ else:
232
+ current_part += char
233
+ elif re.match(r'[a-zA-Z]', char): # English character
234
+ if current_language != 'english':
235
+ if current_part:
236
+ sub_sentences.append(current_part)
237
+ current_part = char
238
+ current_language = 'english'
239
+ else:
240
+ current_part += char
241
+ else:
242
+ current_part += char # For punctuation and other characters
243
+
244
+ if current_part:
245
+ sub_sentences.append(current_part)
246
+
247
+ return sub_sentences
248
+
249
+ def replace_quotes(text):
250
+ # 替换中文、日文引号为英文引号
251
+ text = re.sub(r'[“”‘’『』「」()()]', '"', text)
252
+ return text
253
+
254
+ def remove_numeric_annotations(text):
255
+ # 定义用于匹配数字注释的正则表达式
256
+ # 包括 “”、【】和〔〕包裹的数字
257
+ pattern = r'“\d+”|【\d+】|〔\d+〕'
258
+ # 使用正则表达式替换掉这些注释
259
+ cleaned_text = re.sub(pattern, '', text)
260
+ return cleaned_text
261
+
262
+ def merge_adjacent_japanese(sentences):
263
+ """合并相邻且都只包含日语的句子"""
264
+ merged_sentences = []
265
+ i = 0
266
+ while i < len(sentences):
267
+ current_sentence = sentences[i]
268
+ if i + 1 < len(sentences) and is_japanese(current_sentence) and is_japanese(sentences[i + 1]):
269
+ # 当前句子和下一句都是日语,合并它们
270
+ while i + 1 < len(sentences) and is_japanese(sentences[i + 1]):
271
+ current_sentence += sentences[i + 1]
272
+ i += 1
273
+ merged_sentences.append(current_sentence)
274
+ i += 1
275
+ return merged_sentences
276
+
277
+ def extrac(text):
278
+ text = replace_quotes(remove_numeric_annotations(text)) # 替换引号
279
+ text = re.sub("<[^>]*>", "", text) # 移除 HTML 标签
280
+ # 使用换行符和标点符号进行初步分割
281
+ preliminary_sentences = re.split(r'([\n。;!?\.\?!])', text)
282
+ final_sentences = []
283
+
284
+ preliminary_sentences = re.split(r'([\n。;!?\.\?!])', text)
285
+
286
+ for piece in preliminary_sentences:
287
+ if is_single_language(piece):
288
+ final_sentences.append(piece)
289
+ else:
290
+ sub_sentences = split_mixed_language(piece)
291
+ final_sentences.extend(sub_sentences)
292
+
293
+ # 处理长句子,使用jieba进行分词
294
+ split_sentences = []
295
+ for sentence in final_sentences:
296
+ split_sentences.extend(split_long_sentences(sentence))
297
+
298
+ # 合并相邻的日语句子
299
+ merged_japanese_sentences = merge_adjacent_japanese(split_sentences)
300
+
301
+ # 剔除只包含标点符号的元素
302
+ clean_sentences = [s for s in merged_japanese_sentences if not is_only_punctuation(s)]
303
+
304
+ # 移除空字符串并去除多余引号
305
+ return [s.replace('"','').strip() for s in clean_sentences if s]
306
+
307
+
308
+
309
+ # 移除空字符串
310
+
311
+ def is_mixed_language(sentence):
312
+ contains_chinese = re.search(r'[\u4e00-\u9fff]', sentence) is not None
313
+ contains_japanese = re.search(r'[\u3040-\u30ff\u31f0-\u31ff]', sentence) is not None
314
+ contains_english = re.search(r'[a-zA-Z]', sentence) is not None
315
+ languages_count = sum([contains_chinese, contains_japanese, contains_english])
316
+ return languages_count > 1
317
+
318
+ def split_mixed_language(sentence):
319
+ # 分割混合语言句子
320
+ sub_sentences = re.split(r'(?<=[。!?\.\?!])(?=")|(?<=")(?=[\u4e00-\u9fff\u3040-\u30ff\u31f0-\u31ff]|[a-zA-Z])', sentence)
321
+ return [s.strip() for s in sub_sentences if s.strip()]
322
+
323
+ def seconds_to_ass_time(seconds):
324
+ """Convert seconds to the ASS timestamp format"""
325
+ hours = int(seconds / 3600)
326
+ minutes = int((seconds % 3600) / 60)
327
+ seconds = int(seconds) % 60
328
+ milliseconds = int((seconds - int(seconds)) * 1000)
329
+ return "{:01d}:{:02d}:{:02d}.{:02d}".format(hours, minutes, seconds, int(milliseconds / 10))
330
+
331
+ def extract_text_from_epub(file_path):
332
+ book = epub.read_epub(file_path)
333
+ content = []
334
+ for item in book.items:
335
+ if isinstance(item, epub.EpubHtml):
336
+ soup = BeautifulSoup(item.content, 'html.parser')
337
+ content.append(soup.get_text())
338
+ return '\n'.join(content)
339
+
340
+ def extract_text_from_pdf(file_path):
341
+ with open(file_path, 'rb') as file:
342
+ reader = PdfReader(file)
343
+ content = [page.extract_text() for page in reader.pages]
344
+ return '\n'.join(content)
345
+
346
+ def remove_annotations(text):
347
+ # 移除方括号、尖括号和中文方括号中的内容
348
+ text = re.sub(r'\[.*?\]', '', text)
349
+ text = re.sub(r'\<.*?\>', '', text)
350
+ text = re.sub(r'&#8203;``【oaicite:1】``&#8203;', '', text)
351
+ return text
352
+
353
+ def extract_text_from_file(inputFile):
354
+ file_extension = os.path.splitext(inputFile)[1].lower()
355
+ if file_extension == ".epub":
356
+ return extract_text_from_epub(inputFile)
357
+ elif file_extension == ".pdf":
358
+ return extract_text_from_pdf(inputFile)
359
+ elif file_extension == ".txt":
360
+ with open(inputFile, 'r', encoding='utf-8') as f:
361
+ return f.read()
362
+ else:
363
+ raise ValueError(f"Unsupported file format: {file_extension}")
364
+
365
+ def split_by_punctuation(sentence):
366
+ """按照中文次级标点符号分割句子"""
367
+ # 常见的中文次级分隔符号:逗号、分号等
368
+ parts = re.split(r'([,,;;])', sentence)
369
+ # 将标点符号与前面的词语合并,避免单独标点符号成为一个部分
370
+ merged_parts = []
371
+ for part in parts:
372
+ if part and not part in ',,;;':
373
+ merged_parts.append(part)
374
+ elif merged_parts:
375
+ merged_parts[-1] += part
376
+ return merged_parts
377
+
378
+ def split_long_sentences(sentence, max_length=30):
379
+ """如果中文句子太长,先按标点分割,必要时使用jieba进行分词并分割"""
380
+ if len(sentence) > max_length and is_chinese(sentence):
381
+ # 首先尝试按照次级标点符号分割
382
+ preliminary_parts = split_by_punctuation(sentence)
383
+ new_sentences = []
384
+
385
+ for part in preliminary_parts:
386
+ # 如果部分仍然太长,使用jieba进行分词
387
+ if len(part) > max_length:
388
+ words = jieba.lcut(part)
389
+ current_sentence = ""
390
+ for word in words:
391
+ if len(current_sentence) + len(word) > max_length:
392
+ new_sentences.append(current_sentence)
393
+ current_sentence = word
394
+ else:
395
+ current_sentence += word
396
+ if current_sentence:
397
+ new_sentences.append(current_sentence)
398
+ else:
399
+ new_sentences.append(part)
400
+
401
+ return new_sentences
402
+ return [sentence] # 如果句子不长或不是中文,直接返回
403
+
404
+ def extract_and_convert(text):
405
+
406
+ # 使用正则表达式找出所有英文单词
407
+ english_parts = re.findall(r'\b[A-Za-z]+\b', text) # \b为单词边界标识
408
+
409
+ # 对每个英文单词进行片假名转换
410
+ kana_parts = ['\n{}\n'.format(romajitable.to_kana(word).katakana) for word in english_parts]
411
+
412
+ # 替换原文本中的英文部分
413
+ for eng, kana in zip(english_parts, kana_parts):
414
+ text = text.replace(eng, kana, 1) # 限制每次只替换一个实例
415
+
416
+ return text
417
+ # Inference helpers
418
+ def download_unidic():
419
+ try:
420
+ Tagger()
421
+ print("Tagger launch successfully.")
422
+ except Exception as e:
423
+ print("UNIDIC dictionary not found, downloading...")
424
+ subprocess.run([sys.executable, "-m", "unidic", "download"])
425
+ print("Download completed.")
426
+
427
+ def kanji_to_hiragana(text):
428
+ global tagger
429
+ output = ""
430
+
431
+ # 更新正则表达式以更准确地区分文本和标点符号
432
+ segments = re.findall(r'[一-龥ぁ-んァ-ン\w]+|[^\一-龥ぁ-んァ-ン\w\s]', text, re.UNICODE)
433
+
434
+ for segment in segments:
435
+ if re.match(r'[一-龥ぁ-んァ-ン\w]+', segment):
436
+ # 如果是单词或汉字,转换为平假名
437
+ for word in tagger(segment):
438
+ kana = word.feature.kana or word.surface
439
+ hiragana = jaconv.kata2hira(kana) # 将片假名转换为平假名
440
+ output += hiragana
441
+ else:
442
+ # 如果是标点符号,保持不变
443
+ output += segment
444
+
445
+ return output
446
 
447
  def get_net_g(model_path: str, device: str, hps):
448
  net_g = SynthesizerTrn(
 
456
  _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
457
  return net_g
458
 
459
+ def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
460
+ style_text = None if style_text == "" else style_text
461
  norm_text, phone, tone, word2ph = clean_text(text, language_str)
462
  phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
463
+
464
  if hps.data.add_blank:
465
  phone = commons.intersperse(phone, 0)
466
  tone = commons.intersperse(tone, 0)
 
468
  for i in range(len(word2ph)):
469
  word2ph[i] = word2ph[i] * 2
470
  word2ph[0] += 1
471
+ bert_ori = get_bert(
472
+ norm_text, word2ph, language_str, device, style_text, style_weight
473
+ )
474
  del word2ph
475
  assert bert_ori.shape[-1] == len(phone), phone
476
 
477
  if language_str == "ZH":
478
  bert = bert_ori
479
+ ja_bert = torch.randn(1024, len(phone))
480
+ en_bert = torch.randn(1024, len(phone))
481
  elif language_str == "JP":
482
+ bert = torch.randn(1024, len(phone))
483
  ja_bert = bert_ori
484
+ en_bert = torch.randn(1024, len(phone))
485
+ elif language_str == "EN":
486
+ bert = torch.randn(1024, len(phone))
487
+ ja_bert = torch.randn(1024, len(phone))
488
+ en_bert = bert_ori
489
  else:
490
  raise ValueError("language_str should be ZH, JP or EN")
491
 
 
505
  noise_scale_w,
506
  length_scale,
507
  sid,
508
+ style_text=None,
509
+ style_weight=0.7,
510
+ language = "Auto",
511
+ mode = 'pyopenjtalk-V2.3-Katakana',
512
+ skip_start=False,
513
+ skip_end=False,
514
  ):
515
+ if style_text == None:
516
+ style_text = ""
517
+ style_weight=0,
518
+ if mode == 'fugashi-V2.3-Katakana':
519
+ text = kanji_to_hiragana(text) if is_japanese(text) else text
520
+ if language == "JP":
521
+ text = translate(text,"jp")
522
+ if language == "ZH":
523
+ text = translate(text,"zh")
524
+ if language == "Auto":
525
+ language= 'JP' if is_japanese(text) else 'ZH'
526
+ #print(f'{text}:{sdp_ratio}:{noise_scale}:{noise_scale_w}:{length_scale}:{length_scale}:{sid}:{language}:{mode}:{skip_start}:{skip_end}')
527
  bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
528
+ text,
529
+ language,
530
+ hps,
531
+ device,
532
+ style_text=style_text,
533
+ style_weight=style_weight,
534
  )
535
+ if skip_start:
536
+ phones = phones[3:]
537
+ tones = tones[3:]
538
+ lang_ids = lang_ids[3:]
539
+ bert = bert[:, 3:]
540
+ ja_bert = ja_bert[:, 3:]
541
+ en_bert = en_bert[:, 3:]
542
+ if skip_end:
543
+ phones = phones[:-2]
544
+ tones = tones[:-2]
545
+ lang_ids = lang_ids[:-2]
546
+ bert = bert[:, :-2]
547
+ ja_bert = ja_bert[:, :-2]
548
+ en_bert = en_bert[:, :-2]
549
  with torch.no_grad():
550
  x_tst = phones.to(device).unsqueeze(0)
551
  tones = tones.to(device).unsqueeze(0)
 
554
  ja_bert = ja_bert.to(device).unsqueeze(0)
555
  en_bert = en_bert.to(device).unsqueeze(0)
556
  x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
557
+ # emo = emo.to(device).unsqueeze(0)
558
  del phones
559
  speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
560
  audio = (
 
567
  bert,
568
  ja_bert,
569
  en_bert,
 
570
  sdp_ratio=sdp_ratio,
571
  noise_scale=noise_scale,
572
  noise_scale_w=noise_scale_w,
 
576
  .float()
577
  .numpy()
578
  )
579
+ del (
580
+ x_tst,
581
+ tones,
582
+ lang_ids,
583
+ bert,
584
+ x_tst_lengths,
585
+ speakers,
586
+ ja_bert,
587
+ en_bert,
588
+ ) # , emo
589
  if torch.cuda.is_available():
590
  torch.cuda.empty_cache()
591
+ print("Success.")
592
+ return audio
 
 
 
 
 
 
 
593
 
594
  def loadmodel(model):
595
+ _ = net_g.eval()
596
+ _ = utils.load_checkpoint(model, net_g, None, skip_optimizer=True)
597
+ return "success"
598
+
599
+ def generate_audio_and_srt_for_group(
600
+ group,
601
+ outputPath,
602
+ group_index,
603
+ sampling_rate,
604
+ speaker,
605
+ sdp_ratio,
606
+ noise_scale,
607
+ noise_scale_w,
608
+ length_scale,
609
+ speakerList,
610
+ silenceTime,
611
+ language,
612
+ mode,
613
+ skip_start,
614
+ skip_end,
615
+ style_text,
616
+ style_weight,
617
+ ):
618
+ audio_fin = []
619
+ ass_entries = []
620
+ start_time = 0
621
+ #speaker = random.choice(cara_list)
622
+ ass_header = """[Script Info]
623
+ ; 我没意见
624
+ Title: Audiobook
625
+ ScriptType: v4.00+
626
+ WrapStyle: 0
627
+ PlayResX: 640
628
+ PlayResY: 360
629
+ ScaledBorderAndShadow: yes
630
+ [V4+ Styles]
631
+ Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
632
+ Style: Default,Arial,20,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,1,1,2,10,10,10,1
633
+ [Events]
634
+ Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
635
+ """
636
+
637
+ for sentence in group:
638
+
639
+ if len(sentence) > 1:
640
+ FakeSpeaker = sentence.split("|")[0]
641
+ print(FakeSpeaker)
642
+ SpeakersList = re.split('\n', speakerList)
643
+ if FakeSpeaker in list(hps.data.spk2id.keys()):
644
+ speaker = FakeSpeaker
645
+ for i in SpeakersList:
646
+ if FakeSpeaker == i.split("|")[1]:
647
+ speaker = i.split("|")[0]
648
+ if sentence != '\n':
649
+ text = (remove_annotations(sentence.split("|")[-1]).replace(" ","")+"。").replace(",。","。")
650
+ if mode == 'pyopenjtalk-V2.3-Katakana' or mode == 'fugashi-V2.3-Katakana':
651
+ #print(f'{text}:{sdp_ratio}:{noise_scale}:{noise_scale_w}:{length_scale}:{length_scale}:{speaker}:{language}:{mode}:{skip_start}:{skip_end}')
652
+ audio = infer(
653
+ text,
654
+ sdp_ratio,
655
+ noise_scale,
656
+ noise_scale_w,
657
+ length_scale,
658
+ speaker,
659
+ style_text,
660
+ style_weight,
661
+ language,
662
+ mode,
663
+ skip_start,
664
+ skip_end,
665
+ )
666
+ silence_frames = int(silenceTime * 44010) if is_chinese(sentence) else int(silenceTime * 44010)
667
+ silence_data = np.zeros((silence_frames,), dtype=audio.dtype)
668
+ audio_fin.append(audio)
669
+ audio_fin.append(silence_data)
670
+ duration = len(audio) / sampling_rate
671
+ print(duration)
672
+ end_time = start_time + duration + silenceTime
673
+ ass_entries.append("Dialogue: 0,{},{},".format(seconds_to_ass_time(start_time), seconds_to_ass_time(end_time)) + "Default,,0,0,0,,{}".format(sentence.replace("|",":")))
674
+ start_time = end_time
675
+
676
+ wav_filename = os.path.join(outputPath, f'audiobook_part_{group_index}.wav')
677
+ ass_filename = os.path.join(outputPath, f'audiobook_part_{group_index}.ass')
678
+ write(wav_filename, sampling_rate, gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_fin)))
679
 
680
+ with open(ass_filename, 'w', encoding='utf-8') as f:
681
+ f.write(ass_header + '\n'.join(ass_entries))
682
+ return (hps.data.sampling_rate, gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_fin)))
683
+
684
+ def generate_audio(
685
+ inputFile,
686
+ groupsize,
687
+ filepath,
688
+ silenceTime,
689
+ speakerList,
690
+ text,
691
+ sdp_ratio,
692
+ noise_scale,
693
+ noise_scale_w,
694
+ length_scale,
695
+ sid,
696
+ style_text=None,
697
+ style_weight=0.7,
698
+ language = "Auto",
699
+ mode = 'pyopenjtalk-V2.3-Katakana',
700
+ sentence_mode = 'sentence',
701
+ skip_start=False,
702
+ skip_end=False,
703
+ ):
704
+ if mode == 'pyopenjtalk-V2.3-Katakana' or mode == 'fugashi-V2.3-Katakana':
705
+ if sentence_mode == 'sentence':
706
+ audio = infer(
707
+ text,
708
+ sdp_ratio,
709
+ noise_scale,
710
+ noise_scale_w,
711
+ length_scale,
712
+ sid,
713
+ style_text,
714
+ style_weight,
715
+ language,
716
+ mode,
717
+ skip_start,
718
+ skip_end,
719
+ )
720
+ return (hps.data.sampling_rate,gr.processing_utils.convert_to_16_bit_wav(audio))
721
+ if sentence_mode == 'paragraph':
722
+ GROUP_SIZE = groupsize
723
+ directory_path = filepath if torch.cuda.is_available() else "books"
724
+ if os.path.exists(directory_path):
725
+ shutil.rmtree(directory_path)
726
+ os.makedirs(directory_path)
727
+ if inputFile:
728
+ text = extract_text_from_file(inputFile.name)
729
+ if language == 'Auto':
730
+ sentences = extrac(extract_and_convert(text))
731
+ else:
732
+ sentences = extrac(text)
733
+ for i in range(0, len(sentences), GROUP_SIZE):
734
+ group = sentences[i:i+GROUP_SIZE]
735
+ if speakerList == "":
736
+ speakerList = "无"
737
+ result = generate_audio_and_srt_for_group(
738
+ group,
739
+ directory_path,
740
+ i//GROUP_SIZE + 1,
741
+ 44100,
742
+ sid,
743
+ sdp_ratio,
744
+ noise_scale,
745
+ noise_scale_w,
746
+ length_scale,
747
+ speakerList,
748
+ silenceTime,
749
+ language,
750
+ mode,
751
+ skip_start,
752
+ skip_end,
753
+ style_text,
754
+ style_weight,
755
+ )
756
+ if not torch.cuda.is_available():
757
+ return result
758
+ return result
759
+
760
+ Flaskapp = Flask(__name__)
761
+ CORS(Flaskapp)
762
+ @Flaskapp.route('/', methods=['GET', 'POST'])
763
 
764
  def tts():
765
+ if request.method == 'POST':
766
+ input = request.json
767
+ inputFile = None
768
+ filepath = input['filepath']
769
+ groupSize = input['groupSize']
770
+ text = input['text']
771
+ sdp_ratio = input['sdp_ratio']
772
+ noise_scale = input['noise_scale']
773
+ noise_scale_w = input['noise_scale_w']
774
+ length_scale = input['length_scale']
775
+ sid = input['speaker']
776
+ style_text = input['style_text']
777
+ style_weight = input['style_weight']
778
+ language = input['language']
779
+ mode = input['mode']
780
+ sentence_mode = input['sentence_mode']
781
+ skip_start = input['skip_start']
782
+ skip_end = input['skip_end']
783
+ speakerList = input['speakerList']
784
+ silenceTime = input['silenceTime']
785
+ samplerate, audio = generate_audio(
786
+ inputFile,
787
+ groupSize,
788
+ filepath,
789
+ silenceTime,
790
+ speakerList,
791
+ text,
792
+ sdp_ratio,
793
+ noise_scale,
794
+ noise_scale_w,
795
+ length_scale,
796
+ sid,
797
+ style_text,
798
+ style_weight,
799
+ language,
800
+ mode,
801
+ sentence_mode,
802
+ skip_start,
803
+ skip_end,
804
+ )
805
+ unique_filename = f"temp{uuid.uuid4()}.wav"
806
+ write(unique_filename, samplerate, audio)
807
+ with open(unique_filename ,'rb') as bit:
808
+ wav_bytes = bit.read()
809
+ os.remove(unique_filename)
810
+ headers = {
811
+ 'Content-Type': 'audio/wav',
812
+ 'Text': unique_filename .encode('utf-8')}
813
+ return wav_bytes, 200, headers
814
+ groupSize = request.args.get('groupSize', default = 50, type = int)
815
+ text = request.args.get('text', default = '', type = str)
816
+ sdp_ratio = request.args.get('sdp_ratio', default = 0.5, type = float)
817
+ noise_scale = request.args.get('noise_scale', default = 0.6, type = float)
818
+ noise_scale_w = request.args.get('noise_scale_w', default = 0.667, type = float)
819
+ length_scale = request.args.get('length_scale', default = 1, type = float)
820
+ sid = request.args.get('speaker', default = '八千代', type = str)
821
+ style_text = request.args.get('style_text', default = '', type = str)
822
+ style_weight = request.args.get('style_weight', default = 0.7, type = float)
823
+ language = request.args.get('language', default = 'Auto', type = str)
824
+ mode = request.args.get('mode', default = 'pyopenjtalk-V2.3-Katakana', type = str)
825
+ sentence_mode = request.args.get('sentence_mode', default = 'sentence', type = str)
826
+ skip_start = request.args.get('skip_start', default = False, type = bool)
827
+ skip_end = request.args.get('skip_end', default = False, type = bool)
828
+ speakerList = request.args.get('speakerList', default = '', type = str)
829
+ silenceTime = request.args.get('silenceTime', default = 0.1, type = float)
830
+ inputFile = None
831
+ if not sid or not text:
832
+ return render_template_string(f"""
833
+ <!DOCTYPE html>
834
+ <html>
835
+ <head>
836
+ <title>TTS API Documentation</title>
837
+ </head>
838
+ <body>
839
+ <iframe src={webBase} style="width:100%; height:100vh; border:none;"></iframe>
840
+ </body>
841
+ </html>
842
+ """)
843
+ samplerate, audio = generate_audio(
844
+ inputFile,
845
+ groupSize,
846
+ None,
847
+ silenceTime,
848
+ speakerList,
849
+ text,
850
+ sdp_ratio,
851
+ noise_scale,
852
+ noise_scale_w,
853
+ length_scale,
854
+ sid,
855
+ style_text,
856
+ style_weight,
857
+ language,
858
+ mode,
859
+ sentence_mode,
860
+ skip_start,
861
+ skip_end,
862
+ )
863
+ unique_filename = f"temp{uuid.uuid4()}.wav"
864
+ write(unique_filename, samplerate, audio)
865
  with open(unique_filename ,'rb') as bit:
866
  wav_bytes = bit.read()
867
  os.remove(unique_filename)
 
872
 
873
 
874
  if __name__ == "__main__":
875
+ download_unidic()
876
+ tagger = Tagger()
 
 
 
 
877
  net_g = get_net_g(
878
  model_path=modelPaths[-1], device=device, hps=hps
879
  )
880
  speaker_ids = hps.data.spk2id
881
  speakers = list(speaker_ids.keys())
882
+
883
+ print("推理页面已开启!")
884
+ Flaskapp.run(host="0.0.0.0", port=8080,debug=True)
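Once the Flask app above is running, the GET route can be exercised with a plain HTTP client. A minimal sketch using requests (parameter names follow the request.args handling in the diff; the speaker must exist in the loaded model's spk2id, and the server must be reachable on port 8080):

import requests

params = {
    "text": "今日はいい天気ですね。",
    "speaker": "八千代",            # default speaker in the route above
    "sdp_ratio": 0.5,
    "noise_scale": 0.6,
    "noise_scale_w": 0.667,
    "length_scale": 1,
    "language": "Auto",
}
resp = requests.get("http://127.0.0.1:8080/", params=params, timeout=300)
with open("tts_output.wav", "wb") as f:
    f.write(resp.content)            # the route returns raw WAV bytes (Content-Type: audio/wav)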
test.py ADDED
@@ -0,0 +1,36 @@
 
 
1
+ import re
2
+ from fugashi import Tagger
3
+ import jaconv
4
+
5
+ def kanji_to_hiragana(text):
6
+ tagger = Tagger()
7
+ output = ""
8
+
9
+ # 更新正则表达式以更准确地区分文本和标点符号
10
+ segments = re.findall(r'[一-龥ぁ-んァ-ン\w]+|[^\一-龥ぁ-んァ-ン\w\s]', text, re.UNICODE)
11
+
12
+ for segment in segments:
13
+ if re.match(r'[一-龥ぁ-んァ-ン\w]+', segment):
14
+ # 如果是单词或汉字,转换为平假名
15
+ for word in tagger(segment):
16
+ kana = word.feature.kana or word.surface
17
+ hiragana = jaconv.kata2hira(kana) # 将片假名转换为平假名
18
+ output += hiragana
19
+ else:
20
+ # 如果是标点符号,保持不变
21
+ output += segment
22
+
23
+ return output
24
+
25
+
26
+ text = "私は学生です。"
27
+ tagger = Tagger()
28
+
29
+ for word in tagger(text):
30
+ print(word.surface, word.feature.pos1)
31
+
32
+
33
+ # 示例文本
34
+ text = "業火とはね、どんな人でも彼女が築いた悪業は、いつの日か、彼女を少しも残さず焼き払うことになる……"
35
+ converted_text = kanji_to_hiragana(text)
36
+ print(converted_text)
train_ms.py CHANGED
@@ -13,7 +13,6 @@ import logging
13
  from config import config
14
  import argparse
15
  import datetime
16
- import gc
17
 
18
  logging.getLogger("numba").setLevel(logging.WARNING)
19
  import commons
@@ -27,14 +26,21 @@ from models import (
27
  SynthesizerTrn,
28
  MultiPeriodDiscriminator,
29
  DurationDiscriminator,
 
 
 
 
 
 
 
 
30
  )
31
- from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
32
  from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
33
  from text.symbols import symbols
34
 
35
  torch.backends.cuda.matmul.allow_tf32 = True
36
  torch.backends.cudnn.allow_tf32 = (
37
- True # If encontered training problem,please try to disable TF32.
38
  )
39
  torch.set_float32_matmul_precision("medium")
40
  torch.backends.cuda.sdp_kernel("flash")
@@ -42,7 +48,6 @@ torch.backends.cuda.enable_flash_sdp(True)
42
  torch.backends.cuda.enable_mem_efficient_sdp(
43
  True
44
  ) # Not available if torch version is lower than 2.0
45
- torch.backends.cuda.enable_math_sdp(True)
46
  global_step = 0
47
 
48
 
@@ -97,7 +102,7 @@ def run():
97
  args = parser.parse_args()
98
  model_dir = os.path.join(args.model, config.train_ms_config.model)
99
  if not os.path.exists(model_dir):
100
- os.makedirs(model_dir)
101
  hps = utils.get_hparams_from_file(args.config)
102
  hps.model_dir = model_dir
103
  # 比较路径是否相同
@@ -173,6 +178,8 @@ def run():
173
  0.1,
174
  gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
175
  ).cuda(local_rank)
 
 
176
  if (
177
  "use_spk_conditioned_encoder" in hps.model.keys()
178
  and hps.model.use_spk_conditioned_encoder is True
@@ -210,6 +217,9 @@ def run():
210
  param.requires_grad = False
211
 
212
  net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
 
 
 
213
  optim_g = torch.optim.AdamW(
214
  filter(lambda p: p.requires_grad, net_g.parameters()),
215
  hps.train.learning_rate,
@@ -222,6 +232,12 @@ def run():
222
  betas=hps.train.betas,
223
  eps=hps.train.eps,
224
  )
 
 
 
 
 
 
225
  if net_dur_disc is not None:
226
  optim_dur_disc = torch.optim.AdamW(
227
  net_dur_disc.parameters(),
@@ -233,12 +249,11 @@ def run():
233
  optim_dur_disc = None
234
  net_g = DDP(net_g, device_ids=[local_rank], bucket_cap_mb=512)
235
  net_d = DDP(net_d, device_ids=[local_rank], bucket_cap_mb=512)
236
- dur_resume_lr = None
237
  if net_dur_disc is not None:
238
  net_dur_disc = DDP(
239
  net_dur_disc,
240
  device_ids=[local_rank],
241
- find_unused_parameters=True,
242
  bucket_cap_mb=512,
243
  )
244
 
@@ -250,9 +265,10 @@ def run():
250
  token=config.openi_token,
251
  mirror=config.mirror,
252
  )
253
-
254
- try:
255
- if net_dur_disc is not None:
 
256
  _, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
257
  utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
258
  net_dur_disc,
@@ -261,28 +277,32 @@ def run():
261
  if "skip_optimizer" in hps.train
262
  else True,
263
  )
264
- _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
265
- utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
266
- net_g,
267
- optim_g,
268
- skip_optimizer=hps.train.skip_optimizer
269
- if "skip_optimizer" in hps.train
270
- else True,
271
- )
272
- _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
273
- utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
274
- net_d,
275
- optim_d,
276
- skip_optimizer=hps.train.skip_optimizer
277
- if "skip_optimizer" in hps.train
278
- else True,
279
- )
280
- if not optim_g.param_groups[0].get("initial_lr"):
281
- optim_g.param_groups[0]["initial_lr"] = g_resume_lr
282
- if not optim_d.param_groups[0].get("initial_lr"):
283
- optim_d.param_groups[0]["initial_lr"] = d_resume_lr
284
  if not optim_dur_disc.param_groups[0].get("initial_lr"):
285
  optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
 
 
 
 
286
 
287
  epoch_str = max(epoch_str, 1)
288
  # global_step = (epoch_str - 1) * len(train_loader)
@@ -297,21 +317,43 @@ def run():
297
  epoch_str = 1
298
  global_step = 0
299
 
 
 
 
 
300
  scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
301
  optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
302
  )
303
  scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
304
  optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
305
  )
 
 
 
306
  if net_dur_disc is not None:
307
- if not optim_dur_disc.param_groups[0].get("initial_lr"):
308
- optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
309
  scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
310
  optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
311
  )
312
  else:
313
  scheduler_dur_disc = None
314
- scaler = GradScaler(enabled=hps.train.fp16_run)
 
 
 
 
 
 
 
315
 
316
  for epoch in range(epoch_str, hps.train.epochs + 1):
317
  if rank == 0:
@@ -320,9 +362,9 @@ def run():
320
  local_rank,
321
  epoch,
322
  hps,
323
- [net_g, net_d, net_dur_disc],
324
- [optim_g, optim_d, optim_dur_disc],
325
- [scheduler_g, scheduler_d, scheduler_dur_disc],
326
  scaler,
327
  [train_loader, eval_loader],
328
  logger,
@@ -334,9 +376,9 @@ def run():
334
  local_rank,
335
  epoch,
336
  hps,
337
- [net_g, net_d, net_dur_disc],
338
- [optim_g, optim_d, optim_dur_disc],
339
- [scheduler_g, scheduler_d, scheduler_dur_disc],
340
  scaler,
341
  [train_loader, None],
342
  None,
@@ -344,6 +386,7 @@ def run():
344
  )
345
  scheduler_g.step()
346
  scheduler_d.step()
 
347
  if net_dur_disc is not None:
348
  scheduler_dur_disc.step()
349
 
@@ -361,9 +404,9 @@ def train_and_evaluate(
361
  logger,
362
  writers,
363
  ):
364
- net_g, net_d, net_dur_disc = nets
365
- optim_g, optim_d, optim_dur_disc = optims
366
- scheduler_g, scheduler_d, scheduler_dur_disc = schedulers
367
  train_loader, eval_loader = loaders
368
  if writers is not None:
369
  writer, writer_eval = writers
@@ -373,6 +416,7 @@ def train_and_evaluate(
373
 
374
  net_g.train()
375
  net_d.train()
 
376
  if net_dur_disc is not None:
377
  net_dur_disc.train()
378
  for batch_idx, (
@@ -388,7 +432,6 @@ def train_and_evaluate(
388
  bert,
389
  ja_bert,
390
  en_bert,
391
- emo,
392
  ) in enumerate(tqdm(train_loader)):
393
  if net_g.module.use_noise_scaled_mas:
394
  current_mas_noise_scale = (
@@ -411,9 +454,8 @@ def train_and_evaluate(
411
  bert = bert.cuda(local_rank, non_blocking=True)
412
  ja_bert = ja_bert.cuda(local_rank, non_blocking=True)
413
  en_bert = en_bert.cuda(local_rank, non_blocking=True)
414
- emo = emo.cuda(local_rank, non_blocking=True)
415
 
416
- with autocast(enabled=hps.train.fp16_run):
417
  (
418
  y_hat,
419
  l_length,
@@ -422,9 +464,8 @@ def train_and_evaluate(
422
  x_mask,
423
  z_mask,
424
  (z, z_p, m_p, logs_p, m_q, logs_q),
425
- (hidden_x, logw, logw_),
426
  g,
427
- loss_commit,
428
  ) = net_g(
429
  x,
430
  x_lengths,
@@ -436,7 +477,6 @@ def train_and_evaluate(
436
  bert,
437
  ja_bert,
438
  en_bert,
439
- emo,
440
  )
441
  mel = spec_to_mel_torch(
442
  spec,
@@ -450,7 +490,7 @@ def train_and_evaluate(
450
  mel, ids_slice, hps.train.segment_size // hps.data.hop_length
451
  )
452
  y_hat_mel = mel_spectrogram_torch(
453
- y_hat.squeeze(1),
454
  hps.data.filter_length,
455
  hps.data.n_mel_channels,
456
  hps.data.sampling_rate,
@@ -466,7 +506,7 @@ def train_and_evaluate(
466
 
467
  # Discriminator
468
  y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
469
- with autocast(enabled=False):
470
  loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
471
  y_d_hat_r, y_d_hat_g
472
  )
@@ -475,11 +515,20 @@ def train_and_evaluate(
475
  y_dur_hat_r, y_dur_hat_g = net_dur_disc(
476
  hidden_x.detach(),
477
  x_mask.detach(),
 
478
  logw.detach(),
 
 
 
 
 
479
  logw_.detach(),
 
480
  g.detach(),
481
  )
482
- with autocast(enabled=False):
 
 
483
  # TODO: I think need to mean using the mask, but for now, just mean all
484
  (
485
  loss_dur_disc,
@@ -490,31 +539,60 @@ def train_and_evaluate(
490
  optim_dur_disc.zero_grad()
491
  scaler.scale(loss_dur_disc_all).backward()
492
  scaler.unscale_(optim_dur_disc)
493
- commons.clip_grad_value_(net_dur_disc.parameters(), None)
 
494
  scaler.step(optim_dur_disc)
495
 
496
  optim_d.zero_grad()
497
  scaler.scale(loss_disc_all).backward()
498
  scaler.unscale_(optim_d)
 
 
499
  grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
500
  scaler.step(optim_d)
501
 
502
- with autocast(enabled=hps.train.fp16_run):
 
503
  # Generator
504
  y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
505
  if net_dur_disc is not None:
506
- y_dur_hat_r, y_dur_hat_g = net_dur_disc(
507
- hidden_x, x_mask, logw, logw_, g
508
- )
509
- with autocast(enabled=False):
510
  loss_dur = torch.sum(l_length.float())
511
  loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
512
  loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
513
 
514
  loss_fm = feature_loss(fmap_r, fmap_g)
515
  loss_gen, losses_gen = generator_loss(y_d_hat_g)
 
 
 
 
516
  loss_gen_all = (
517
- loss_gen + loss_fm + loss_mel + loss_dur + loss_kl + loss_commit
 
518
  )
519
  if net_dur_disc is not None:
520
  loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
@@ -522,6 +600,8 @@ def train_and_evaluate(
522
  optim_g.zero_grad()
523
  scaler.scale(loss_gen_all).backward()
524
  scaler.unscale_(optim_g)
 
 
525
  grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
526
  scaler.step(optim_g)
527
  scaler.update()
@@ -540,9 +620,12 @@ def train_and_evaluate(
540
  scalar_dict = {
541
  "loss/g/total": loss_gen_all,
542
  "loss/d/total": loss_disc_all,
 
543
  "learning_rate": lr,
544
  "grad_norm_d": grad_norm_d,
545
  "grad_norm_g": grad_norm_g,
 
 
546
  }
547
  scalar_dict.update(
548
  {
@@ -550,6 +633,8 @@ def train_and_evaluate(
550
  "loss/g/mel": loss_mel,
551
  "loss/g/dur": loss_dur,
552
  "loss/g/kl": loss_kl,
 
 
553
  }
554
  )
555
  scalar_dict.update(
@@ -562,6 +647,30 @@ def train_and_evaluate(
562
  {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
563
  )
564
 
 
565
  image_dict = {
566
  "slice/mel_org": utils.plot_spectrogram_to_numpy(
567
  y_mel[0].data.cpu().numpy()
@@ -599,6 +708,13 @@ def train_and_evaluate(
599
  epoch,
600
  os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
601
 )
 
 
602
  if net_dur_disc is not None:
603
  utils.save_checkpoint(
604
  net_dur_disc,
@@ -617,8 +733,8 @@ def train_and_evaluate(
617
 
618
  global_step += 1
619
 
620
- gc.collect()
621
- torch.cuda.empty_cache()
622
  if rank == 0:
623
  logger.info("====> Epoch: {}".format(epoch))
624
 
@@ -642,7 +758,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
642
  bert,
643
  ja_bert,
644
  en_bert,
645
- emo,
646
  ) in enumerate(eval_loader):
647
  x, x_lengths = x.cuda(), x_lengths.cuda()
648
  spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
@@ -653,7 +768,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
653
  en_bert = en_bert.cuda()
654
  tone = tone.cuda()
655
  language = language.cuda()
656
- emo = emo.cuda()
657
  for use_sdp in [True, False]:
658
  y_hat, attn, mask, *_ = generator.module.infer(
659
  x,
@@ -664,7 +778,6 @@ def evaluate(hps, generator, eval_loader, writer_eval):
664
  bert,
665
  ja_bert,
666
  en_bert,
667
- emo,
668
  y=spec,
669
  max_len=1000,
670
  sdp_ratio=0.0 if not use_sdp else 1.0,
 
13
  from config import config
14
  import argparse
15
  import datetime
 
16
 
17
  logging.getLogger("numba").setLevel(logging.WARNING)
18
  import commons
 
26
  SynthesizerTrn,
27
  MultiPeriodDiscriminator,
28
  DurationDiscriminator,
29
+ WavLMDiscriminator,
30
+ )
31
+ from losses import (
32
+ generator_loss,
33
+ discriminator_loss,
34
+ feature_loss,
35
+ kl_loss,
36
+ WavLMLoss,
37
  )
 
38
  from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
39
  from text.symbols import symbols
40
 
41
  torch.backends.cuda.matmul.allow_tf32 = True
42
  torch.backends.cudnn.allow_tf32 = (
43
+ True  # If you encounter training problems, try disabling TF32.
44
  )
45
  torch.set_float32_matmul_precision("medium")
46
  torch.backends.cuda.sdp_kernel("flash")
 
48
  torch.backends.cuda.enable_mem_efficient_sdp(
49
  True
50
  ) # Not available if torch version is lower than 2.0
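A note on the precision flags above: enabling TF32 for matmuls and cuDNN trades a little precision for speed on Ampere or newer GPUs, and the comment suggests turning it off if training misbehaves. A standalone sketch of the relevant switches (not repo code):

    import torch

    def set_tf32(enabled: bool = True) -> None:
        # TF32 tensor-core math for matmuls and cuDNN convolutions (Ampere or newer).
        torch.backends.cuda.matmul.allow_tf32 = enabled
        torch.backends.cudnn.allow_tf32 = enabled
        # "high"/"medium" permit TF32 for float32 matmuls, "highest" keeps strict FP32.
        torch.set_float32_matmul_precision("medium" if enabled else "highest")

    set_tf32(True)    # faster, slightly lower precision
    # set_tf32(False) # fall back to strict FP32 if loss curves look unstable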
 
51
  global_step = 0
52
 
53
 
 
102
  args = parser.parse_args()
103
  model_dir = os.path.join(args.model, config.train_ms_config.model)
104
  if not os.path.exists(model_dir):
105
+ os.makedirs(model_dir, exist_ok=True)
106
  hps = utils.get_hparams_from_file(args.config)
107
  hps.model_dir = model_dir
108
  # 比较路径是否相同
 
178
  0.1,
179
  gin_channels=hps.model.gin_channels if hps.data.n_speakers != 0 else 0,
180
  ).cuda(local_rank)
181
+ else:
182
+ net_dur_disc = None
183
  if (
184
  "use_spk_conditioned_encoder" in hps.model.keys()
185
  and hps.model.use_spk_conditioned_encoder is True
 
217
  param.requires_grad = False
218
 
219
  net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
220
+ net_wd = WavLMDiscriminator(
221
+ hps.model.slm.hidden, hps.model.slm.nlayers, hps.model.slm.initial_channel
222
+ ).cuda(local_rank)
223
  optim_g = torch.optim.AdamW(
224
  filter(lambda p: p.requires_grad, net_g.parameters()),
225
  hps.train.learning_rate,
 
232
  betas=hps.train.betas,
233
  eps=hps.train.eps,
234
  )
235
+ optim_wd = torch.optim.AdamW(
236
+ net_wd.parameters(),
237
+ hps.train.learning_rate,
238
+ betas=hps.train.betas,
239
+ eps=hps.train.eps,
240
+ )
241
  if net_dur_disc is not None:
242
  optim_dur_disc = torch.optim.AdamW(
243
  net_dur_disc.parameters(),
 
249
  optim_dur_disc = None
250
  net_g = DDP(net_g, device_ids=[local_rank], bucket_cap_mb=512)
251
  net_d = DDP(net_d, device_ids=[local_rank], bucket_cap_mb=512)
252
+ net_wd = DDP(net_wd, device_ids=[local_rank], bucket_cap_mb=512)
253
  if net_dur_disc is not None:
254
  net_dur_disc = DDP(
255
  net_dur_disc,
256
  device_ids=[local_rank],
 
257
  bucket_cap_mb=512,
258
  )
259
 
 
265
  token=config.openi_token,
266
  mirror=config.mirror,
267
  )
268
+ dur_resume_lr = hps.train.learning_rate
269
+ wd_resume_lr = hps.train.learning_rate
270
+ if net_dur_disc is not None:
271
+ try:
272
  _, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
273
  utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
274
  net_dur_disc,
 
277
  if "skip_optimizer" in hps.train
278
  else True,
279
 )
 
 
 
280
  if not optim_dur_disc.param_groups[0].get("initial_lr"):
281
  optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
282
+ except:
283
+ print("Initialize dur_disc")
284
+
285
+ try:
286
+ _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
287
+ utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
288
+ net_g,
289
+ optim_g,
290
+ skip_optimizer=hps.train.skip_optimizer
291
+ if "skip_optimizer" in hps.train
292
+ else True,
293
+ )
294
+ _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
295
+ utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
296
+ net_d,
297
+ optim_d,
298
+ skip_optimizer=hps.train.skip_optimizer
299
+ if "skip_optimizer" in hps.train
300
+ else True,
301
+ )
302
+ if not optim_g.param_groups[0].get("initial_lr"):
303
+ optim_g.param_groups[0]["initial_lr"] = g_resume_lr
304
+ if not optim_d.param_groups[0].get("initial_lr"):
305
+ optim_d.param_groups[0]["initial_lr"] = d_resume_lr
306
 
307
  epoch_str = max(epoch_str, 1)
308
  # global_step = (epoch_str - 1) * len(train_loader)
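The resume logic above wraps each checkpoint family (`DUR_*.pth`, `G_*.pth`/`D_*.pth`, and below `WD_*.pth`) in its own try/except, so a missing or incompatible file simply falls back to fresh initialization with `epoch_str = 1`. A hedged sketch of that pattern, with `load_checkpoint` and `latest_checkpoint_path` passed in as stand-ins for the helpers in utils.py:

    def try_resume(pattern, net, optim, model_dir, load_checkpoint, latest_checkpoint_path):
        """Return (resume_lr, epoch), or (None, 1) when no usable checkpoint exists."""
        try:
            _, _, resume_lr, epoch = load_checkpoint(
                latest_checkpoint_path(model_dir, pattern), net, optim
            )
            return resume_lr, max(epoch, 1)
        except Exception as e:  # missing file, shape mismatch, etc.
            print(f"No checkpoint for {pattern}, starting fresh ({e})")
            return None, 1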
 
317
  epoch_str = 1
318
  global_step = 0
319
 
320
+ try:
321
+ _, optim_wd, wd_resume_lr, epoch_str = utils.load_checkpoint(
322
+ utils.latest_checkpoint_path(hps.model_dir, "WD_*.pth"),
323
+ net_wd,
324
+ optim_wd,
325
+ skip_optimizer=hps.train.skip_optimizer
326
+ if "skip_optimizer" in hps.train
327
+ else True,
328
+ )
329
+ if not optim_wd.param_groups[0].get("initial_lr"):
330
+ optim_wd.param_groups[0]["initial_lr"] = wd_resume_lr
331
+ except Exception as e:
332
+ print(e)
333
+
334
  scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
335
  optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
336
  )
337
  scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
338
  optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
339
  )
340
+ scheduler_wd = torch.optim.lr_scheduler.ExponentialLR(
341
+ optim_wd, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
342
+ )
343
  if net_dur_disc is not None:
 
 
344
  scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
345
  optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
346
  )
347
  else:
348
  scheduler_dur_disc = None
349
+ scaler = GradScaler(enabled=hps.train.bf16_run)
350
+
351
+ wl = WavLMLoss(
352
+ hps.model.slm.model,
353
+ net_wd,
354
+ hps.data.sampling_rate,
355
+ hps.model.slm.sr,
356
+ ).to(local_rank)
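Two things are worth noting in this block: the `GradScaler` is now keyed to `bf16_run`, and a `WavLMLoss` module is built from `hps.model.slm` (a frozen speech LM used as a perceptual and adversarial critic). bfloat16 shares FP32's exponent range, so dynamic loss scaling is largely a formality here and the scaler mostly passes gradients through. A minimal sketch of a bf16 autocast step (generic model and optimizer, not repo code):

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(16, 1).cuda()
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scaler = GradScaler(enabled=True)  # harmless with bf16, required for fp16

    x, y = torch.randn(4, 16).cuda(), torch.randn(4, 1).cuda()
    with autocast(enabled=True, dtype=torch.bfloat16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optim)             # so gradient clipping sees true magnitudes
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=200)
    scaler.step(optim)
    scaler.update()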
357
 
358
  for epoch in range(epoch_str, hps.train.epochs + 1):
359
  if rank == 0:
 
362
  local_rank,
363
  epoch,
364
  hps,
365
+ [net_g, net_d, net_dur_disc, net_wd, wl],
366
+ [optim_g, optim_d, optim_dur_disc, optim_wd],
367
+ [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
368
  scaler,
369
  [train_loader, eval_loader],
370
  logger,
 
376
  local_rank,
377
  epoch,
378
  hps,
379
+ [net_g, net_d, net_dur_disc, net_wd, wl],
380
+ [optim_g, optim_d, optim_dur_disc, optim_wd],
381
+ [scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd],
382
  scaler,
383
  [train_loader, None],
384
  None,
 
386
  )
387
  scheduler_g.step()
388
  scheduler_d.step()
389
+ scheduler_wd.step()
390
  if net_dur_disc is not None:
391
  scheduler_dur_disc.step()
392
 
 
404
  logger,
405
  writers,
406
  ):
407
+ net_g, net_d, net_dur_disc, net_wd, wl = nets
408
+ optim_g, optim_d, optim_dur_disc, optim_wd = optims
409
+ scheduler_g, scheduler_d, scheduler_dur_disc, scheduler_wd = schedulers
410
  train_loader, eval_loader = loaders
411
  if writers is not None:
412
  writer, writer_eval = writers
 
416
 
417
  net_g.train()
418
  net_d.train()
419
+ net_wd.train()
420
  if net_dur_disc is not None:
421
  net_dur_disc.train()
422
  for batch_idx, (
 
432
  bert,
433
  ja_bert,
434
  en_bert,
 
435
  ) in enumerate(tqdm(train_loader)):
436
  if net_g.module.use_noise_scaled_mas:
437
  current_mas_noise_scale = (
 
454
  bert = bert.cuda(local_rank, non_blocking=True)
455
  ja_bert = ja_bert.cuda(local_rank, non_blocking=True)
456
  en_bert = en_bert.cuda(local_rank, non_blocking=True)
 
457
 
458
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
459
  (
460
  y_hat,
461
  l_length,
 
464
  x_mask,
465
  z_mask,
466
  (z, z_p, m_p, logs_p, m_q, logs_q),
467
+ (hidden_x, logw, logw_, logw_sdp),
468
  g,
 
469
  ) = net_g(
470
  x,
471
  x_lengths,
 
477
  bert,
478
  ja_bert,
479
  en_bert,
 
480
  )
481
  mel = spec_to_mel_torch(
482
  spec,
 
490
  mel, ids_slice, hps.train.segment_size // hps.data.hop_length
491
  )
492
  y_hat_mel = mel_spectrogram_torch(
493
+ y_hat.squeeze(1).float(),
494
  hps.data.filter_length,
495
  hps.data.n_mel_channels,
496
  hps.data.sampling_rate,
 
506
 
507
  # Discriminator
508
  y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach())
509
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
510
  loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
511
  y_d_hat_r, y_d_hat_g
512
  )
 
515
  y_dur_hat_r, y_dur_hat_g = net_dur_disc(
516
  hidden_x.detach(),
517
  x_mask.detach(),
518
+ logw_.detach(),
519
  logw.detach(),
520
+ g.detach(),
521
+ )
522
+ y_dur_hat_r_sdp, y_dur_hat_g_sdp = net_dur_disc(
523
+ hidden_x.detach(),
524
+ x_mask.detach(),
525
  logw_.detach(),
526
+ logw_sdp.detach(),
527
  g.detach(),
528
  )
529
+ y_dur_hat_r = y_dur_hat_r + y_dur_hat_r_sdp
530
+ y_dur_hat_g = y_dur_hat_g + y_dur_hat_g_sdp
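With this change both duration heads are adversarially supervised: the plain duration predictor output (`logw`) and the stochastic predictor output (`logw_sdp`) are each scored against the reference durations (`logw_`), and the two sets of discriminator outputs are pooled before `discriminator_loss`. A generic, hedged sketch of scoring two predictions with one critic (the critic and shapes are placeholders, not the repo's DurationDiscriminator):

    import torch

    critic = torch.nn.Linear(10, 1)                  # stand-in for the duration discriminator
    ref, pred_dp, pred_sdp = (torch.randn(4, 10) for _ in range(3))

    score_real = critic(ref)
    scores_fake = [critic(pred_dp), critic(pred_sdp)]
    # Least-squares GAN style: real -> 1, both fakes -> 0.
    d_loss = torch.mean((score_real - 1) ** 2) + sum(torch.mean(s ** 2) for s in scores_fake)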
531
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
532
  # TODO: I think need to mean using the mask, but for now, just mean all
533
  (
534
  loss_dur_disc,
 
539
  optim_dur_disc.zero_grad()
540
  scaler.scale(loss_dur_disc_all).backward()
541
  scaler.unscale_(optim_dur_disc)
542
+ # torch.nn.utils.clip_grad_norm_(
543
+ # parameters=net_dur_disc.parameters(), max_norm=100
544
+ # )
545
+ grad_norm_dur = commons.clip_grad_value_(
546
+ net_dur_disc.parameters(), None
547
+ )
548
  scaler.step(optim_dur_disc)
549
 
550
  optim_d.zero_grad()
551
  scaler.scale(loss_disc_all).backward()
552
  scaler.unscale_(optim_d)
553
+ if getattr(hps.train, "bf16_run", False):
554
+ torch.nn.utils.clip_grad_norm_(parameters=net_d.parameters(), max_norm=200)
555
  grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
556
  scaler.step(optim_d)
557
 
558
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
559
+ loss_slm = wl.discriminator(
560
+ y.detach().squeeze(), y_hat.detach().squeeze()
561
+ ).mean()
562
+
563
+ optim_wd.zero_grad()
564
+ scaler.scale(loss_slm).backward()
565
+ scaler.unscale_(optim_wd)
566
+ # torch.nn.utils.clip_grad_norm_(parameters=net_wd.parameters(), max_norm=200)
567
+ grad_norm_wd = commons.clip_grad_value_(net_wd.parameters(), None)
568
+ scaler.step(optim_wd)
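This is the new SLM (WavLM) adversarial branch: `wl.discriminator(real, fake)` trains `net_wd` to separate real from generated audio in WavLM feature space, while the generator later receives `wl(y, y_hat)` and `wl.generator(y_hat)` as extra loss terms. The repo's `WavLMLoss` lives in losses.py; as a rough, hedged sketch of the underlying idea only (frozen WavLM features compared between real and generated audio, resampled to the SLM's rate):

    import torch
    import torchaudio
    from transformers import WavLMModel

    class SlmFeatureLoss(torch.nn.Module):
        """Sketch only; the actual WavLMLoss also drives the WavLM discriminator."""

        def __init__(self, model_name="microsoft/wavlm-base-plus", in_sr=44100, slm_sr=16000):
            super().__init__()
            self.wavlm = WavLMModel.from_pretrained(model_name).eval()
            for p in self.wavlm.parameters():
                p.requires_grad = False
            self.resample = torchaudio.transforms.Resample(in_sr, slm_sr)

        def forward(self, wav_real, wav_fake):
            with torch.no_grad():
                h_real = self.wavlm(self.resample(wav_real), output_hidden_states=True).hidden_states
            h_fake = self.wavlm(self.resample(wav_fake), output_hidden_states=True).hidden_states
            return sum(torch.nn.functional.l1_loss(f, r) for f, r in zip(h_fake, h_real)) / len(h_real)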
569
+
570
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
571
  # Generator
572
  y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat)
573
  if net_dur_disc is not None:
574
+ _, y_dur_hat_g = net_dur_disc(hidden_x, x_mask, logw_, logw, g)
575
+ _, y_dur_hat_g_sdp = net_dur_disc(hidden_x, x_mask, logw_, logw_sdp, g)
576
+ y_dur_hat_g = y_dur_hat_g + y_dur_hat_g_sdp
577
+ with autocast(enabled=hps.train.bf16_run, dtype=torch.bfloat16):
578
  loss_dur = torch.sum(l_length.float())
579
  loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
580
  loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
581
 
582
  loss_fm = feature_loss(fmap_r, fmap_g)
583
  loss_gen, losses_gen = generator_loss(y_d_hat_g)
584
+
585
+ loss_lm = wl(y.detach().squeeze(), y_hat.squeeze()).mean()
586
+ loss_lm_gen = wl.generator(y_hat.squeeze())
587
+
588
  loss_gen_all = (
589
+ loss_gen
590
+ + loss_fm
591
+ + loss_mel
592
+ + loss_dur
593
+ + loss_kl
594
+ + loss_lm
595
+ + loss_lm_gen
596
  )
597
  if net_dur_disc is not None:
598
  loss_dur_gen, losses_dur_gen = generator_loss(y_dur_hat_g)
 
600
  optim_g.zero_grad()
601
  scaler.scale(loss_gen_all).backward()
602
  scaler.unscale_(optim_g)
603
+ if getattr(hps.train, "bf16_run", False):
604
+ torch.nn.utils.clip_grad_norm_(parameters=net_g.parameters(), max_norm=500)
605
  grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
606
  scaler.step(optim_g)
607
  scaler.update()
 
620
  scalar_dict = {
621
  "loss/g/total": loss_gen_all,
622
  "loss/d/total": loss_disc_all,
623
+ "loss/wd/total": loss_slm,
624
  "learning_rate": lr,
625
  "grad_norm_d": grad_norm_d,
626
  "grad_norm_g": grad_norm_g,
627
+ "grad_norm_dur": grad_norm_dur,
628
+ "grad_norm_wd": grad_norm_wd,
629
  }
630
  scalar_dict.update(
631
  {
 
633
  "loss/g/mel": loss_mel,
634
  "loss/g/dur": loss_dur,
635
  "loss/g/kl": loss_kl,
636
+ "loss/g/lm": loss_lm,
637
+ "loss/g/lm_gen": loss_lm_gen,
638
  }
639
  )
640
  scalar_dict.update(
 
647
  {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
648
  )
649
 
650
+ if net_dur_disc is not None:
651
+ scalar_dict.update({"loss/dur_disc/total": loss_dur_disc_all})
652
+
653
+ scalar_dict.update(
654
+ {
655
+ "loss/dur_disc_g/{}".format(i): v
656
+ for i, v in enumerate(losses_dur_disc_g)
657
+ }
658
+ )
659
+ scalar_dict.update(
660
+ {
661
+ "loss/dur_disc_r/{}".format(i): v
662
+ for i, v in enumerate(losses_dur_disc_r)
663
+ }
664
+ )
665
+
666
+ scalar_dict.update({"loss/g/dur_gen": loss_dur_gen})
667
+ scalar_dict.update(
668
+ {
669
+ "loss/g/dur_gen_{}".format(i): v
670
+ for i, v in enumerate(losses_dur_gen)
671
+ }
672
+ )
673
+
674
  image_dict = {
675
  "slice/mel_org": utils.plot_spectrogram_to_numpy(
676
  y_mel[0].data.cpu().numpy()
 
708
  epoch,
709
  os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
710
  )
711
+ utils.save_checkpoint(
712
+ net_wd,
713
+ optim_wd,
714
+ hps.train.learning_rate,
715
+ epoch,
716
+ os.path.join(hps.model_dir, "WD_{}.pth".format(global_step)),
717
+ )
718
  if net_dur_disc is not None:
719
  utils.save_checkpoint(
720
  net_dur_disc,
 
733
 
734
  global_step += 1
735
 
736
+ # gc.collect()
737
+ # torch.cuda.empty_cache()
738
  if rank == 0:
739
  logger.info("====> Epoch: {}".format(epoch))
740
 
 
758
  bert,
759
  ja_bert,
760
  en_bert,
 
761
  ) in enumerate(eval_loader):
762
  x, x_lengths = x.cuda(), x_lengths.cuda()
763
  spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
 
768
  en_bert = en_bert.cuda()
769
  tone = tone.cuda()
770
  language = language.cuda()
 
771
  for use_sdp in [True, False]:
772
  y_hat, attn, mask, *_ = generator.module.infer(
773
  x,
 
778
  bert,
779
  ja_bert,
780
  en_bert,
 
781
  y=spec,
782
  max_len=1000,
783
  sdp_ratio=0.0 if not use_sdp else 1.0,
utils.py CHANGED
@@ -301,7 +301,11 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
301
 
302
  to_del = [
303
  os.path.join(path_to_models, fn)
304
- for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
 
 
 
 
305
  ]
306
 
307
  def del_info(fn):
 
301
 
302
  to_del = [
303
  os.path.join(path_to_models, fn)
304
+ for fn in (
305
+ x_sorted("G")[:-n_ckpts_to_keep]
306
+ + x_sorted("D")[:-n_ckpts_to_keep]
307
+ + x_sorted("WD")[:-n_ckpts_to_keep]
308
+ )
309
  ]
310
 
311
  def del_info(fn):
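`clean_checkpoints` now also prunes the `WD_*` (WavLM discriminator) weights alongside `G_*` and `D_*`. A hedged, generic version of the keep-newest-N idea, independent of the repo's `x_sorted` helper (which is assumed to return checkpoints oldest-first):

    import glob
    import os
    import re

    def keep_newest(path_to_models, prefixes=("G", "D", "WD"), n_keep=2):
        """Delete all but the n_keep highest-step checkpoints for each prefix."""
        def step(fn):
            return int(re.search(r"_(\d+)\.pth$", fn).group(1))
        for prefix in prefixes:
            ckpts = sorted(glob.glob(os.path.join(path_to_models, f"{prefix}_*.pth")), key=step)
            for fn in ckpts[:-n_keep]:
                os.remove(fn)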
webui.py CHANGED
@@ -1,4 +1,5 @@
1
  # flake8: noqa: E402
 
2
  import os
3
  import logging
4
  import re_matching
@@ -32,6 +33,14 @@ if device == "mps":
32
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
33
 
34
 
 
35
  def generate_audio(
36
  slices,
37
  sdp_ratio,
@@ -42,15 +51,20 @@ def generate_audio(
42
  language,
43
  reference_audio,
44
  emotion,
 
 
45
  skip_start=False,
46
  skip_end=False,
47
  ):
48
  audio_list = []
49
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
 
 
 
50
  with torch.no_grad():
51
  for idx, piece in enumerate(slices):
52
- skip_start = (idx != 0) and skip_start
53
- skip_end = (idx != len(slices) - 1) and skip_end
54
  audio = infer(
55
  piece,
56
  reference_audio=reference_audio,
@@ -66,10 +80,11 @@ def generate_audio(
66
  device=device,
67
  skip_start=skip_start,
68
  skip_end=skip_end,
 
 
69
  )
70
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
71
  audio_list.append(audio16bit)
72
- # audio_list.append(silence) # 将静音添加到列表中
73
  return audio_list
74
 
75
 
@@ -88,10 +103,13 @@ def generate_audio_multilang(
88
  ):
89
  audio_list = []
90
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
 
 
 
91
  with torch.no_grad():
92
  for idx, piece in enumerate(slices):
93
- skip_start = (idx != 0) and skip_start
94
- skip_end = (idx != len(slices) - 1) and skip_end
95
  audio = infer_multilang(
96
  piece,
97
  reference_audio=reference_audio,
@@ -110,7 +128,6 @@ def generate_audio_multilang(
110
  )
111
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
112
  audio_list.append(audio16bit)
113
- # audio_list.append(silence) # 将静音添加到列表中
114
  return audio_list
115
 
116
 
@@ -127,63 +144,50 @@ def tts_split(
127
  interval_between_sent,
128
  reference_audio,
129
  emotion,
 
 
130
  ):
131
- if language == "mix":
132
- return ("invalid", None)
133
  while text.find("\n\n") != -1:
134
  text = text.replace("\n\n", "\n")
 
135
  para_list = re_matching.cut_para(text)
 
136
  audio_list = []
137
- if not cut_by_sent:
138
- for idx, p in enumerate(para_list):
139
- skip_start = idx != 0
140
- skip_end = idx != len(para_list) - 1
141
- audio = infer(
142
  p,
143
- reference_audio=reference_audio,
144
- emotion=emotion,
145
- sdp_ratio=sdp_ratio,
146
- noise_scale=noise_scale,
147
- noise_scale_w=noise_scale_w,
148
- length_scale=length_scale,
149
- sid=speaker,
150
- language=language,
151
- hps=hps,
152
- net_g=net_g,
153
- device=device,
154
- skip_start=skip_start,
155
- skip_end=skip_end,
156
  )
157
- audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
158
- audio_list.append(audio16bit)
159
  silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
160
  audio_list.append(silence)
161
- else:
162
- for idx, p in enumerate(para_list):
163
- skip_start = idx != 0
164
- skip_end = idx != len(para_list) - 1
165
  audio_list_sent = []
166
  sent_list = re_matching.cut_sent(p)
167
- for idx, s in enumerate(sent_list):
168
- skip_start = (idx != 0) and skip_start
169
- skip_end = (idx != len(sent_list) - 1) and skip_end
170
- audio = infer(
171
  s,
172
- reference_audio=reference_audio,
173
- emotion=emotion,
174
- sdp_ratio=sdp_ratio,
175
- noise_scale=noise_scale,
176
- noise_scale_w=noise_scale_w,
177
- length_scale=length_scale,
178
- sid=speaker,
179
- language=language,
180
- hps=hps,
181
- net_g=net_g,
182
- device=device,
183
- skip_start=skip_start,
184
- skip_end=skip_end,
185
  )
186
- audio_list_sent.append(audio)
187
  silence = np.zeros((int)(44100 * interval_between_sent))
188
  audio_list_sent.append(silence)
189
  if (interval_between_para - interval_between_sent) > 0:
@@ -196,10 +200,49 @@ def tts_split(
196
  ) # 对完整句子做音量归一
197
  audio_list.append(audio16bit)
198
  audio_concat = np.concatenate(audio_list)
199
- return ("Success", (44100, audio_concat))
200
 
201
 
202
- def tts_fn(
 
 
 
203
  text: str,
204
  speaker,
205
  sdp_ratio,
@@ -209,15 +252,9 @@ def tts_fn(
209
  language,
210
  reference_audio,
211
  emotion,
212
- prompt_mode,
 
213
  ):
214
- if prompt_mode == "Audio prompt":
215
- if reference_audio == None:
216
- return ("Invalid audio prompt", None)
217
- else:
218
- reference_audio = load_audio(reference_audio)[1]
219
- else:
220
- reference_audio = None
221
  audio_list = []
222
  if language == "mix":
223
  bool_valid, str_valid = re_matching.validate_text(text)
@@ -226,120 +263,40 @@ def tts_fn(
226
  hps.data.sampling_rate,
227
  np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
228
  )
229
- result = []
230
  for slice in re_matching.text_matching(text):
231
- _speaker = slice.pop()
232
- temp_contant = []
233
- temp_lang = []
234
- for lang, content in slice:
235
- if "|" in content:
236
- temp = []
237
- temp_ = []
238
- for i in content.split("|"):
239
- if i != "":
240
- temp.append([i])
241
- temp_.append([lang])
242
- else:
243
- temp.append([])
244
- temp_.append([])
245
- temp_contant += temp
246
- temp_lang += temp_
247
- else:
248
- if len(temp_contant) == 0:
249
- temp_contant.append([])
250
- temp_lang.append([])
251
- temp_contant[-1].append(content)
252
- temp_lang[-1].append(lang)
253
- for i, j in zip(temp_lang, temp_contant):
254
- result.append([*zip(i, j), _speaker])
255
- for i, one in enumerate(result):
256
- skip_start = i != 0
257
- skip_end = i != len(result) - 1
258
- _speaker = one.pop()
259
- idx = 0
260
- while idx < len(one):
261
- text_to_generate = []
262
- lang_to_generate = []
263
- while True:
264
- lang, content = one[idx]
265
- temp_text = [content]
266
- if len(text_to_generate) > 0:
267
- text_to_generate[-1] += [temp_text.pop(0)]
268
- lang_to_generate[-1] += [lang]
269
- if len(temp_text) > 0:
270
- text_to_generate += [[i] for i in temp_text]
271
- lang_to_generate += [[lang]] * len(temp_text)
272
- if idx + 1 < len(one):
273
- idx += 1
274
- else:
275
- break
276
- skip_start = (idx != 0) and skip_start
277
- skip_end = (idx != len(one) - 1) and skip_end
278
- print(text_to_generate, lang_to_generate)
279
- audio_list.extend(
280
- generate_audio_multilang(
281
- text_to_generate,
282
- sdp_ratio,
283
- noise_scale,
284
- noise_scale_w,
285
- length_scale,
286
- _speaker,
287
- lang_to_generate,
288
- reference_audio,
289
- emotion,
290
- skip_start,
291
- skip_end,
292
- )
293
  )
294
- idx += 1
295
  elif language.lower() == "auto":
296
- for idx, slice in enumerate(text.split("|")):
297
- if slice == "":
298
- continue
299
- skip_start = idx != 0
300
- skip_end = idx != len(text.split("|")) - 1
301
- sentences_list = split_by_language(
302
- slice, target_languages=["zh", "ja", "en"]
 
 
 
 
 
 
303
  )
304
- idx = 0
305
- while idx < len(sentences_list):
306
- text_to_generate = []
307
- lang_to_generate = []
308
- while True:
309
- content, lang = sentences_list[idx]
310
- temp_text = [content]
311
- lang = lang.upper()
312
- if lang == "JA":
313
- lang = "JP"
314
- if len(text_to_generate) > 0:
315
- text_to_generate[-1] += [temp_text.pop(0)]
316
- lang_to_generate[-1] += [lang]
317
- if len(temp_text) > 0:
318
- text_to_generate += [[i] for i in temp_text]
319
- lang_to_generate += [[lang]] * len(temp_text)
320
- if idx + 1 < len(sentences_list):
321
- idx += 1
322
- else:
323
- break
324
- skip_start = (idx != 0) and skip_start
325
- skip_end = (idx != len(sentences_list) - 1) and skip_end
326
- print(text_to_generate, lang_to_generate)
327
- audio_list.extend(
328
- generate_audio_multilang(
329
- text_to_generate,
330
- sdp_ratio,
331
- noise_scale,
332
- noise_scale_w,
333
- length_scale,
334
- speaker,
335
- lang_to_generate,
336
- reference_audio,
337
- emotion,
338
- skip_start,
339
- skip_end,
340
- )
341
- )
342
- idx += 1
343
  else:
344
  audio_list.extend(
345
  generate_audio(
@@ -352,13 +309,65 @@ def tts_fn(
352
  language,
353
  reference_audio,
354
  emotion,
 
 
355
  )
356
 )
 
 
 
 
357
 
358
  audio_concat = np.concatenate(audio_list)
359
  return "Success", (hps.data.sampling_rate, audio_concat)
360
 
361
 
 
 
 
 
 
 
 
 
 
 
362
  def load_audio(path):
363
  audio, sr = librosa.load(path, 48000)
364
  # audio = librosa.resample(audio, 44100, 48000)
@@ -408,34 +417,37 @@ if __name__ == "__main__":
408
  )
409
  trans = gr.Button("中翻日", variant="primary")
410
  slicer = gr.Button("快速切分", variant="primary")
 
411
  speaker = gr.Dropdown(
412
  choices=speakers, value=speakers[0], label="Speaker"
413
  )
414
  _ = gr.Markdown(
415
- value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n"
 
416
  )
417
  prompt_mode = gr.Radio(
418
  ["Text prompt", "Audio prompt"],
419
  label="Prompt Mode",
420
  value="Text prompt",
 
421
  )
422
  text_prompt = gr.Textbox(
423
  label="Text prompt",
424
  placeholder="用文字描述生成风格。如:Happy",
425
  value="Happy",
426
- visible=True,
427
  )
428
  audio_prompt = gr.Audio(
429
  label="Audio prompt", type="filepath", visible=False
430
  )
431
  sdp_ratio = gr.Slider(
432
- minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
433
  )
434
  noise_scale = gr.Slider(
435
  minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
436
  )
437
  noise_scale_w = gr.Slider(
438
- minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise_W"
439
  )
440
  length_scale = gr.Slider(
441
  minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
@@ -445,6 +457,21 @@ if __name__ == "__main__":
445
  )
446
  btn = gr.Button("生成音频!", variant="primary")
447
 with gr.Column():
 
 
 
448
  with gr.Row():
449
  with gr.Column():
450
  interval_between_sent = gr.Slider(
@@ -487,6 +514,8 @@ if __name__ == "__main__":
487
  audio_prompt,
488
  text_prompt,
489
  prompt_mode,
 
 
490
  ],
491
  outputs=[text_output, audio_output],
492
  )
@@ -511,6 +540,8 @@ if __name__ == "__main__":
511
  interval_between_sent,
512
  audio_prompt,
513
  text_prompt,
 
 
514
  ],
515
  outputs=[text_output, audio_output],
516
  )
@@ -527,6 +558,12 @@ if __name__ == "__main__":
527
  outputs=[audio_prompt],
528
  )
529
 
 
 
 
 
 
 
530
  print("推理页面已开启!")
531
  webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
532
  app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
 
1
  # flake8: noqa: E402
2
+ import gc
3
  import os
4
  import logging
5
  import re_matching
 
33
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
34
 
35
 
36
+ def free_up_memory():
37
+ # Prior inference run might have large variables not cleaned up due to exception during the run.
38
+ # Free up as much memory as possible to allow this run to be successful.
39
+ gc.collect()
40
+ if torch.cuda.is_available():
41
+ torch.cuda.empty_cache()
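Usage note: `free_up_memory()` is called at the top of each generation pass so that tensors left over from a previous, possibly failed run are collected before new inference starts. If you also target Apple Silicon (`mps`), a hedged extension, reusing the `gc`/`torch` imports already present in webui.py, could clear that cache as well:

    def free_up_memory_mps_aware():
        # Sketch: same idea as free_up_memory(), plus the MPS cache on newer PyTorch.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        elif torch.backends.mps.is_available() and hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()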
42
+
43
+
44
  def generate_audio(
45
  slices,
46
  sdp_ratio,
 
51
  language,
52
  reference_audio,
53
  emotion,
54
+ style_text,
55
+ style_weight,
56
  skip_start=False,
57
  skip_end=False,
58
  ):
59
  audio_list = []
60
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
61
+
62
+ free_up_memory()
63
+
64
  with torch.no_grad():
65
  for idx, piece in enumerate(slices):
66
+ skip_start = idx != 0
67
+ skip_end = idx != len(slices) - 1
68
  audio = infer(
69
  piece,
70
  reference_audio=reference_audio,
 
80
  device=device,
81
  skip_start=skip_start,
82
  skip_end=skip_end,
83
+ style_text=style_text,
84
+ style_weight=style_weight,
85
  )
86
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
87
  audio_list.append(audio16bit)
 
88
  return audio_list
89
 
90
 
 
103
  ):
104
  audio_list = []
105
  # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
106
+
107
+ free_up_memory()
108
+
109
  with torch.no_grad():
110
  for idx, piece in enumerate(slices):
111
+ skip_start = idx != 0
112
+ skip_end = idx != len(slices) - 1
113
  audio = infer_multilang(
114
  piece,
115
  reference_audio=reference_audio,
 
128
  )
129
  audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
130
  audio_list.append(audio16bit)
 
131
  return audio_list
132
 
133
 
 
144
  interval_between_sent,
145
  reference_audio,
146
  emotion,
147
+ style_text,
148
+ style_weight,
149
  ):
 
 
150
  while text.find("\n\n") != -1:
151
  text = text.replace("\n\n", "\n")
152
+ text = text.replace("|", "")
153
  para_list = re_matching.cut_para(text)
154
+ para_list = [p for p in para_list if p != ""]
155
  audio_list = []
156
+ for p in para_list:
157
+ if not cut_by_sent:
158
+ audio_list += process_text(
 
 
159
  p,
160
+ speaker,
161
+ sdp_ratio,
162
+ noise_scale,
163
+ noise_scale_w,
164
+ length_scale,
165
+ language,
166
+ reference_audio,
167
+ emotion,
168
+ style_text,
169
+ style_weight,
 
 
 
170
  )
 
 
171
  silence = np.zeros((int)(44100 * interval_between_para), dtype=np.int16)
172
  audio_list.append(silence)
173
+ else:
 
 
 
174
  audio_list_sent = []
175
  sent_list = re_matching.cut_sent(p)
176
+ sent_list = [s for s in sent_list if s != ""]
177
+ for s in sent_list:
178
+ audio_list_sent += process_text(
 
179
  s,
180
+ speaker,
181
+ sdp_ratio,
182
+ noise_scale,
183
+ noise_scale_w,
184
+ length_scale,
185
+ language,
186
+ reference_audio,
187
+ emotion,
188
+ style_text,
189
+ style_weight,
 
 
 
190
  )
 
191
  silence = np.zeros((int)(44100 * interval_between_sent))
192
  audio_list_sent.append(silence)
193
  if (interval_between_para - interval_between_sent) > 0:
 
200
  ) # 对完整句子做音量归一
201
  audio_list.append(audio16bit)
202
  audio_concat = np.concatenate(audio_list)
203
+ return ("Success", (hps.data.sampling_rate, audio_concat))
204
 
205
 
206
+ def process_mix(slice):
207
+ _speaker = slice.pop()
208
+ _text, _lang = [], []
209
+ for lang, content in slice:
210
+ content = content.split("|")
211
+ content = [part for part in content if part != ""]
212
+ if len(content) == 0:
213
+ continue
214
+ if len(_text) == 0:
215
+ _text = [[part] for part in content]
216
+ _lang = [[lang] for part in content]
217
+ else:
218
+ _text[-1].append(content[0])
219
+ _lang[-1].append(lang)
220
+ if len(content) > 1:
221
+ _text += [[part] for part in content[1:]]
222
+ _lang += [[lang] for part in content[1:]]
223
+ return _text, _lang, _speaker
224
+
225
+
226
+ def process_auto(text):
227
+ _text, _lang = [], []
228
+ for slice in text.split("|"):
229
+ if slice == "":
230
+ continue
231
+ temp_text, temp_lang = [], []
232
+ sentences_list = split_by_language(slice, target_languages=["zh", "ja", "en"])
233
+ for sentence, lang in sentences_list:
234
+ if sentence == "":
235
+ continue
236
+ temp_text.append(sentence)
237
+ if lang == "ja":
238
+ lang = "jp"
239
+ temp_lang.append(lang.upper())
240
+ _text.append(temp_text)
241
+ _lang.append(temp_lang)
242
+ return _text, _lang
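`process_auto` turns a plain multilingual string into parallel nested lists of sentence chunks and language tags (with `ja` mapped to `JP`), one outer element per `|`-separated slice. Roughly, assuming `split_by_language` segments by script:

    _text, _lang = process_auto("你好,世界 hello world|こんにちは")
    print(_text)  # e.g. [["你好,世界 ", "hello world"], ["こんにちは"]]  (exact splits depend on split_by_language)
    print(_lang)  # e.g. [["ZH", "EN"], ["JP"]]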
243
+
244
+
245
+ def process_text(
246
  text: str,
247
  speaker,
248
  sdp_ratio,
 
252
  language,
253
  reference_audio,
254
  emotion,
255
+ style_text=None,
256
+ style_weight=0,
257
  ):
 
 
 
 
 
 
 
258
  audio_list = []
259
  if language == "mix":
260
  bool_valid, str_valid = re_matching.validate_text(text)
 
263
  hps.data.sampling_rate,
264
  np.concatenate([np.zeros(hps.data.sampling_rate // 2)]),
265
  )
 
266
  for slice in re_matching.text_matching(text):
267
+ _text, _lang, _speaker = process_mix(slice)
268
+ if _speaker is None:
269
+ continue
270
+ print(f"Text: {_text}\nLang: {_lang}")
271
+ audio_list.extend(
272
+ generate_audio_multilang(
273
+ _text,
274
+ sdp_ratio,
275
+ noise_scale,
276
+ noise_scale_w,
277
+ length_scale,
278
+ _speaker,
279
+ _lang,
280
+ reference_audio,
281
+ emotion,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  )
283
+ )
284
  elif language.lower() == "auto":
285
+ _text, _lang = process_auto(text)
286
+ print(f"Text: {_text}\nLang: {_lang}")
287
+ audio_list.extend(
288
+ generate_audio_multilang(
289
+ _text,
290
+ sdp_ratio,
291
+ noise_scale,
292
+ noise_scale_w,
293
+ length_scale,
294
+ speaker,
295
+ _lang,
296
+ reference_audio,
297
+ emotion,
298
  )
299
+ )
 
 
 
 
300
  else:
301
  audio_list.extend(
302
  generate_audio(
 
309
  language,
310
  reference_audio,
311
  emotion,
312
+ style_text,
313
+ style_weight,
314
  )
315
  )
316
+ return audio_list
317
+
318
+
319
+ def tts_fn(
320
+ text: str,
321
+ speaker,
322
+ sdp_ratio,
323
+ noise_scale,
324
+ noise_scale_w,
325
+ length_scale,
326
+ language,
327
+ reference_audio,
328
+ emotion,
329
+ prompt_mode,
330
+ style_text=None,
331
+ style_weight=0,
332
+ ):
333
+ if style_text == "":
334
+ style_text = None
335
+ if prompt_mode == "Audio prompt":
336
+ if reference_audio == None:
337
+ return ("Invalid audio prompt", None)
338
+ else:
339
+ reference_audio = load_audio(reference_audio)[1]
340
+ else:
341
+ reference_audio = None
342
+
343
+ audio_list = process_text(
344
+ text,
345
+ speaker,
346
+ sdp_ratio,
347
+ noise_scale,
348
+ noise_scale_w,
349
+ length_scale,
350
+ language,
351
+ reference_audio,
352
+ emotion,
353
+ style_text,
354
+ style_weight,
355
+ )
356
 
357
  audio_concat = np.concatenate(audio_list)
358
  return "Success", (hps.data.sampling_rate, audio_concat)
359
 
360
 
361
+ def format_utils(text, speaker):
362
+ _text, _lang = process_auto(text)
363
+ res = f"[{speaker}]"
364
+ for lang_s, content_s in zip(_lang, _text):
365
+ for lang, content in zip(lang_s, content_s):
366
+ res += f"<{lang.lower()}>{content}"
367
+ res += "|"
368
+ return "mix", res[:-1]
369
+
370
+
371
  def load_audio(path):
372
  audio, sr = librosa.load(path, 48000)
373
  # audio = librosa.resample(audio, 44100, 48000)
 
417
  )
418
  trans = gr.Button("中翻日", variant="primary")
419
  slicer = gr.Button("快速切分", variant="primary")
420
+ formatter = gr.Button("检测语言,并整理为 MIX 格式", variant="primary")
421
  speaker = gr.Dropdown(
422
  choices=speakers, value=speakers[0], label="Speaker"
423
  )
424
  _ = gr.Markdown(
425
+ value="提示模式(Prompt mode):可选文字提示或音频提示,用于生成文字或音频指定风格的声音。\n",
426
+ visible=False,
427
  )
428
  prompt_mode = gr.Radio(
429
  ["Text prompt", "Audio prompt"],
430
  label="Prompt Mode",
431
  value="Text prompt",
432
+ visible=False,
433
  )
434
  text_prompt = gr.Textbox(
435
  label="Text prompt",
436
  placeholder="用文字描述生成风格。如:Happy",
437
  value="Happy",
438
+ visible=False,
439
  )
440
  audio_prompt = gr.Audio(
441
  label="Audio prompt", type="filepath", visible=False
442
  )
443
  sdp_ratio = gr.Slider(
444
+ minimum=0, maximum=1, value=0.5, step=0.1, label="SDP Ratio"
445
  )
446
  noise_scale = gr.Slider(
447
  minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
448
  )
449
  noise_scale_w = gr.Slider(
450
+ minimum=0.1, maximum=2, value=0.9, step=0.1, label="Noise_W"
451
  )
452
  length_scale = gr.Slider(
453
  minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
 
457
  )
458
  btn = gr.Button("生成音频!", variant="primary")
459
  with gr.Column():
460
+ with gr.Accordion("融合文本语义", open=False):
461
+ gr.Markdown(
462
+ value="使用辅助文本的语意来辅助生成对话(语言保持与主文本相同)\n\n"
463
+ "**注意**:不要使用**指令式文本**(如:开心),要使用**带有强烈情感的文本**(如:我好快乐!!!)\n\n"
464
+ "效果较不明确,留空即为不使用该功能"
465
+ )
466
+ style_text = gr.Textbox(label="辅助文本")
467
+ style_weight = gr.Slider(
468
+ minimum=0,
469
+ maximum=1,
470
+ value=0.7,
471
+ step=0.1,
472
+ label="Weight",
473
+ info="主文本和辅助文本的bert混合比率,0表示仅主文本,1表示仅辅助文本",
474
+ )
475
  with gr.Row():
476
  with gr.Column():
477
  interval_between_sent = gr.Slider(
 
514
  audio_prompt,
515
  text_prompt,
516
  prompt_mode,
517
+ style_text,
518
+ style_weight,
519
  ],
520
  outputs=[text_output, audio_output],
521
  )
 
540
  interval_between_sent,
541
  audio_prompt,
542
  text_prompt,
543
+ style_text,
544
+ style_weight,
545
  ],
546
  outputs=[text_output, audio_output],
547
  )
 
558
  outputs=[audio_prompt],
559
  )
560
 
561
+ formatter.click(
562
+ format_utils,
563
+ inputs=[text, speaker],
564
+ outputs=[language, text],
565
+ )
566
+
567
  print("推理页面已开启!")
568
  webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
569
  app.launch(share=config.webui_config.share, server_port=config.webui_config.port)
webui_preprocess.py CHANGED
@@ -19,9 +19,9 @@ def generate_config(data_dir, batch_size):
19
  assert data_dir != "", "数据集名称不能为空"
20
  start_path, _, train_path, val_path, config_path = get_path(data_dir)
21
  if os.path.isfile(config_path):
22
- config = json.load(open(config_path))
23
  else:
24
- config = json.load(open("configs/config.json"))
25
  config["data"]["training_files"] = train_path
26
  config["data"]["validation_files"] = val_path
27
  config["train"]["batch_size"] = batch_size
@@ -44,7 +44,7 @@ def resample(data_dir):
44
  in_dir = os.path.join(start_path, "raw")
45
  out_dir = os.path.join(start_path, "wavs")
46
  subprocess.run(
47
- f"python resample.py "
48
  f"--sr 44100 "
49
  f"--in_dir {in_dir} "
50
  f"--out_dir {out_dir} ",
@@ -60,7 +60,9 @@ def preprocess_text(data_dir):
60
  with open(lbl_path, "w", encoding="utf-8") as f:
61
  for line in lines:
62
  path, spk, language, text = line.strip().split("|")
63
- path = os.path.join(start_path, "wavs", os.path.basename(path))
 
 
64
  f.writelines(f"{path}|{spk}|{language}|{text}\n")
65
  subprocess.run(
66
  f"python preprocess_text.py "
@@ -83,16 +85,6 @@ def bert_gen(data_dir):
83
  return "BERT 特征文件生成完成"
84
 
85
 
86
- def clap_gen(data_dir):
87
- assert data_dir != "", "数据集名称不能为空"
88
- _, _, _, _, config_path = get_path(data_dir)
89
- subprocess.run(
90
- f"python clap_gen.py " f"--config {config_path}",
91
- shell=True,
92
- )
93
- return "CLAP 特征文件生成完成"
94
-
95
-
96
  if __name__ == "__main__":
97
  with gr.Blocks() as app:
98
  with gr.Row():
@@ -100,13 +92,13 @@ if __name__ == "__main__":
100
  _ = gr.Markdown(
101
  value="# Bert-VITS2 数据预处理\n"
102
  "## 预先准备:\n"
103
- "下载 BERT 和 CLAP 模型:\n"
104
  "- [中文 RoBERTa](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)\n"
105
  "- [日文 DeBERTa](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm)\n"
106
  "- [英文 DeBERTa](https://huggingface.co/microsoft/deberta-v3-large)\n"
107
- "- [CLAP](https://huggingface.co/laion/clap-htsat-fused)\n"
108
  "\n"
109
- "将 BERT 模型放置到 `bert` 文件夹下,CLAP 模型放置到 `emotional` 文件夹下,覆盖同名文件夹。\n"
110
  "\n"
111
  "数据准备:\n"
112
  "将数据放置在 data 文件夹下,按照如下结构组织:\n"
@@ -156,12 +148,10 @@ if __name__ == "__main__":
156
  preprocess_text_btn = gr.Button(value="执行", variant="primary")
157
  _ = gr.Markdown(value="## 第四步:生成 BERT 特征文件")
158
  bert_gen_btn = gr.Button(value="执行", variant="primary")
159
- _ = gr.Markdown(value="## 第五步:生成 CLAP 特征文件")
160
- clap_gen_btn = gr.Button(value="执行", variant="primary")
161
  _ = gr.Markdown(
162
  value="## 训练模型及部署:\n"
163
  "修改根目录下的 `config.yml` 中 `dataset_path` 一项为 `data/{你的数据集名称}`\n"
164
- "- 训练:将[预训练模型文件](https://openi.pcl.ac.cn/Stardust_minus/Bert-VITS2/modelmanage/show_model)(`D_0.pth`、`DUR_0.pth` 和 `G_0.pth`)放到 `data/{你的数据集名称}/models` 文件夹下,执行 `torchrun --nproc_per_node=1 train_ms.py` 命令(多卡运行可参考 `run_MnodesAndMgpus.sh` 中的命令。\n"
165
  "- 部署:修改根目录下的 `config.yml` 中 `webui` 下 `model` 一项为 `models/{权重文件名}.pth` (如 G_10000.pth),然后执行 `python webui.py`"
166
  )
167
 
@@ -171,7 +161,6 @@ if __name__ == "__main__":
171
  resample_btn.click(resample, inputs=[data_dir], outputs=[info])
172
  preprocess_text_btn.click(preprocess_text, inputs=[data_dir], outputs=[info])
173
  bert_gen_btn.click(bert_gen, inputs=[data_dir], outputs=[info])
174
- clap_gen_btn.click(clap_gen, inputs=[data_dir], outputs=[info])
175
 
176
  webbrowser.open("http://127.0.0.1:7860")
177
  app.launch(share=False, server_port=7860)
 
19
  assert data_dir != "", "数据集名称不能为空"
20
  start_path, _, train_path, val_path, config_path = get_path(data_dir)
21
  if os.path.isfile(config_path):
22
+ config = json.load(open(config_path, "r", encoding="utf-8"))
23
  else:
24
+ config = json.load(open("configs/config.json", "r", encoding="utf-8"))
25
  config["data"]["training_files"] = train_path
26
  config["data"]["validation_files"] = val_path
27
  config["train"]["batch_size"] = batch_size
 
44
  in_dir = os.path.join(start_path, "raw")
45
  out_dir = os.path.join(start_path, "wavs")
46
  subprocess.run(
47
+ f"python resample_legacy.py "
48
  f"--sr 44100 "
49
  f"--in_dir {in_dir} "
50
  f"--out_dir {out_dir} ",
 
60
  with open(lbl_path, "w", encoding="utf-8") as f:
61
  for line in lines:
62
  path, spk, language, text = line.strip().split("|")
63
+ path = os.path.join(start_path, "wavs", os.path.basename(path)).replace(
64
+ "\\", "/"
65
+ )
66
  f.writelines(f"{path}|{spk}|{language}|{text}\n")
67
  subprocess.run(
68
  f"python preprocess_text.py "
 
85
  return "BERT 特征文件生成完成"
86
 
87
 
 
 
88
  if __name__ == "__main__":
89
  with gr.Blocks() as app:
90
  with gr.Row():
 
92
  _ = gr.Markdown(
93
  value="# Bert-VITS2 数据预处理\n"
94
  "## 预先准备:\n"
95
+ "下载 BERT 和 WavLM 模型:\n"
96
  "- [中文 RoBERTa](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large)\n"
97
  "- [日文 DeBERTa](https://huggingface.co/ku-nlp/deberta-v2-large-japanese-char-wwm)\n"
98
  "- [英文 DeBERTa](https://huggingface.co/microsoft/deberta-v3-large)\n"
99
+ "- [WavLM](https://huggingface.co/microsoft/wavlm-base-plus)\n"
100
  "\n"
101
+ "将 BERT 模型放置到 `bert` 文件夹下,WavLM 模型放置到 `slm` 文件夹下,覆盖同名文件夹。\n"
102
  "\n"
103
  "数据准备:\n"
104
  "将数据放置在 data 文件夹下,按照如下结构组织:\n"
 
148
  preprocess_text_btn = gr.Button(value="执行", variant="primary")
149
  _ = gr.Markdown(value="## 第四步:生成 BERT 特征文件")
150
  bert_gen_btn = gr.Button(value="执行", variant="primary")
 
 
151
  _ = gr.Markdown(
152
  value="## 训练模型及部署:\n"
153
  "修改根目录下的 `config.yml` 中 `dataset_path` 一项为 `data/{你的数据集名称}`\n"
154
+ "- 训练:将[预训练模型文件](https://openi.pcl.ac.cn/Stardust_minus/Bert-VITS2/modelmanage/show_model)(`D_0.pth`、`DUR_0.pth`、`WD_0.pth` 和 `G_0.pth`)放到 `data/{你的数据集名称}/models` 文件夹下,执行 `torchrun --nproc_per_node=1 train_ms.py` 命令(多卡运行可参考 `run_MnodesAndMgpus.sh` 中的命令。\n"
155
  "- 部署:修改根目录下的 `config.yml` 中 `webui` 下 `model` 一项为 `models/{权重文件名}.pth` (如 G_10000.pth),然后执行 `python webui.py`"
156
  )
157
 
 
161
  resample_btn.click(resample, inputs=[data_dir], outputs=[info])
162
  preprocess_text_btn.click(preprocess_text, inputs=[data_dir], outputs=[info])
163
  bert_gen_btn.click(bert_gen, inputs=[data_dir], outputs=[info])
 
164
 
165
  webbrowser.open("http://127.0.0.1:7860")
166
  app.launch(share=False, server_port=7860)