ElesisSiegherts committed on
Commit
1b9cb8c
1 Parent(s): 8900345

Upload 7 files

Files changed (7)
  1. config.yml +176 -0
  2. data_utils.py +410 -0
  3. default_config.yml +176 -0
  4. emo_gen.py +155 -0
  5. export_onnx.py +56 -0
  6. get_emo.py +26 -0
  7. infer.py +341 -0
config.yml ADDED
@@ -0,0 +1,176 @@
+ # Global configuration
+ # To use multiple config files at the same time (for example, two GPUs training on two datasets), specify the config file via an environment variable; if none is given, ./config.yml is used by default
+
+ # A common path configuration is provided so data can be stored in one place instead of being scattered
+ # Each dataset and its corresponding models live under a single path; all path settings below are relative to dataset_path
+ # If left blank or omitted, paths are relative to the project root directory
+ dataset_path: "Data/tamura"
+
+ # Model mirror source, huggingface by default; to use the openi mirror, openi_token must be set
+ mirror: ""
+ openi_token: "" # openi token
+
+ # resample: audio resampling configuration
+ # Note: a space is required after ":"
+ resample:
+   # Target sampling rate
+   sampling_rate: 44100
+   # Input directory; every .wav file under this path will be resampled
+   # Fill in a path relative to dataset_path
+   in_dir: "audios/raw" # relative to the project root this is /dataset_path/in_dir
+   # Output directory for the resampled audio files
+   out_dir: "audios/wavs"
+
+
+ # preprocess_text: dataset preprocessing configuration
+ # Note: a space is required after ":"
+ preprocess_text:
+   # Path to the raw transcription file; each line should be {wav_path}|{speaker_name}|{language}|{text}.
+   transcription_path: "filelists/text.list"
+   # Path for the cleaned text; may be left empty, in which case it is generated next to the raw text
+   cleaned_path: ""
+   # Training set path
+   train_path: "filelists/train.list"
+   # Validation set path
+   val_path: "filelists/val.list"
+   # Config file path
+   config_path: "config.json"
+   # Number of validation samples per speaker
+   val_per_spk: 4
+   # Maximum number of validation samples; any excess is cut off and moved to the training set
+   max_val_total: 8
+   # Whether to run data cleaning
+   clean: true
+
+
+ # bert_gen configuration
+ # Note: a space is required after ":"
+ bert_gen:
+   # Path to the training dataset config file
+   config_path: "config.json"
+   # Number of parallel processes
+   num_processes: 2
+   # Device: "cuda" for GPU inference, "cpu" for CPU inference
+   # This option also sets the default device for get_bert_feature
+   device: "cuda"
+   # Use multi-GPU inference
+   use_multi_device: false
+
+ # emo_gen configuration
+ # Note: a space is required after ":"
+ emo_gen:
+   # Path to the training dataset config file
+   config_path: "config.json"
+   # Number of parallel processes
+   num_processes: 2
+   # Device: "cuda" for GPU inference, "cpu" for CPU inference
+   device: "cuda"
+
+ # train_ms: training configuration
+ # Note: a space is required after ":"
+ train_ms:
+   env:
+     MASTER_ADDR: "localhost"
+     MASTER_PORT: 10086
+     WORLD_SIZE: 1
+     LOCAL_RANK: 0
+     RANK: 0
+     # Environment variables with arbitrary names can be added here
+     # THE_ENV_VAR_YOU_NEED_TO_USE: "1234567"
+   # Base (pretrained) model settings
+   base:
+     use_base_model: false
+     repo_id: "Stardust_minus/Bert-VITS2"
+     model_image: "Bert-VITS2_2.1-Emo底模" # model name on the openi page
+   # Model output directory: unlike older versions, which kept everything under logs/model_name, models are now stored under Data/<your dataset>/models
+   model: "models"
+   # Config file path
+   config_path: "config.json"
+   # Number of training workers; not recommended to exceed the number of CPU cores
+   num_workers: 16
+   # Disabling this saves close to 50% disk space, but may slow down training and increase CPU usage.
+   spec_cache: True
+   # Number of checkpoints to keep; older checkpoints beyond this count are deleted to save space.
+   keep_ckpts: 8
+
+
+ # webui configuration
+ # Note: a space is required after ":"
+ webui:
+   # Inference device
+   device: "cuda"
+   # Model path
+   model: "models/G_2750.pth"
+   # Config file path
+   config_path: "config.json"
+   # Port number
+   port: 7860
+   # Whether to deploy publicly (exposed to the external network)
+   share: false
+   # Whether to enable debug mode
+   debug: false
+   # Language identification library; options: langid, fastlid
+   language_identification_library: "langid"
+
+
+ # server API configuration
+ # Note: a space is required after ":"
+ # Note: all paths in this section are relative to the project root
+ server:
+   # Port number
+   port: 5000
+   # Default device for models; this setting is not implemented yet.
+   device: "cuda"
+   # Configuration of all models to load; multiple models may be listed, or none at all (models can be loaded manually once the page is up)
+   # To load no models: delete the two default model entries and set models to [ ], i.e. an empty list, using the same syntax as the speakers field of model 2, i.e. models: [ ]
+   # Note: every model must have valid model and config paths; empty paths will cause loading errors.
+   # It is also fine to leave models empty and fill them in manually after the page has loaded.
+   models:
+     - # Model path
+       model: ""
+       # Path to the model's config.json
+       config: ""
+       # Device for this model; if set, it overrides the default
+       device: "cuda"
+       # Default language for this model
+       language: "ZH"
+       # Default per-speaker parameters
+       # Not every speaker has to be listed; unlisted ones use the default values
+       # Can be left unfilled for now; per-speaker configuration is not implemented yet
+       speakers:
+         - speaker: "科比"
+           sdp_ratio: 0.2
+           noise_scale: 0.6
+           noise_scale_w: 0.8
+           length_scale: 1
+         - speaker: "五条悟"
+           sdp_ratio: 0.3
+           noise_scale: 0.7
+           noise_scale_w: 0.8
+           length_scale: 0.5
+         - speaker: "安倍晋三"
+           sdp_ratio: 0.2
+           noise_scale: 0.6
+           noise_scale_w: 0.8
+           length_scale: 1.2
+     - # Model path
+       model: ""
+       # Path to the model's config.json
+       config: ""
+       # Device for this model; if set, it overrides the default
+       device: "cpu"
+       # Default language for this model
+       language: "JP"
+       # Default per-speaker parameters
+       # Not every speaker has to be listed; unlisted ones use the default values
+       speakers: [ ] # may also be left empty
+
+
+ # Baidu Translate open platform API configuration
+ # API documentation: https://api.fanyi.baidu.com/doc/21
+ # Do not publicly share your app id and key on GitHub or similar sites
+ translate:
+   # Your APPID
+   "app_key": ""
+   # Your secret key
+   "secret_key": ""
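The header comments above describe how configuration is resolved: a config file can be selected through an environment variable, falling back to ./config.yml, and most paths are then taken relative to dataset_path. Below is a minimal sketch of that lookup; the variable name BERT_VITS2_CONFIG and the load_yaml_config helper are illustrative assumptions, not the repository's actual config module.

```python
# Minimal sketch of env-var driven config selection (assumed names, not the repo's config.py).
import os

import yaml  # PyYAML


def load_yaml_config(env_var="BERT_VITS2_CONFIG", default="./config.yml"):
    """Read the YAML config named by an environment variable, else the default path."""
    path = os.environ.get(env_var, default)  # the variable name here is an assumption
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


if __name__ == "__main__":
    cfg = load_yaml_config()
    # Resolve a path relative to dataset_path, as the comments above describe.
    in_dir = os.path.join(cfg.get("dataset_path", ""), cfg["resample"]["in_dir"])
    print(in_dir)  # e.g. Data/tamura/audios/raw
```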
data_utils.py ADDED
@@ -0,0 +1,410 @@
+ import os
+ import random
+ import torch
+ import torch.utils.data
+ from tqdm import tqdm
+ import numpy as np
+ from tools.log import logger
+ import commons
+ from mel_processing import spectrogram_torch, mel_spectrogram_torch
+ from utils import load_wav_to_torch, load_filepaths_and_text
+ from text import cleaned_text_to_sequence
+ from config import config
+
+ """Multi speaker version"""
+
+
+ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
+     """
+     1) loads audio, speaker_id, text pairs
+     2) normalizes text and converts them to sequences of integers
+     3) computes spectrograms from audio files.
+     """
+
+     def __init__(self, audiopaths_sid_text, hparams):
+         self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
+         self.max_wav_value = hparams.max_wav_value
+         self.sampling_rate = hparams.sampling_rate
+         self.filter_length = hparams.filter_length
+         self.hop_length = hparams.hop_length
+         self.win_length = hparams.win_length
+         self.sampling_rate = hparams.sampling_rate
+         self.spk_map = hparams.spk2id
+         self.hparams = hparams
+
+         self.use_mel_spec_posterior = getattr(
+             hparams, "use_mel_posterior_encoder", False
+         )
+         if self.use_mel_spec_posterior:
+             self.n_mel_channels = getattr(hparams, "n_mel_channels", 80)
+
+         self.cleaned_text = getattr(hparams, "cleaned_text", False)
+
+         self.add_blank = hparams.add_blank
+         self.min_text_len = getattr(hparams, "min_text_len", 1)
+         self.max_text_len = getattr(hparams, "max_text_len", 384)
+
+         random.seed(1234)
+         random.shuffle(self.audiopaths_sid_text)
+         self._filter()
+
+     def _filter(self):
+         """
+         Filter text & store spec lengths
+         """
+         # Store spectrogram lengths for Bucketing
+         # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
+         # spec_length = wav_length // hop_length
+
+         audiopaths_sid_text_new = []
+         lengths = []
+         skipped = 0
+         logger.info("Init dataset...")
+         for _id, spk, language, text, phones, tone, word2ph in tqdm(
+             self.audiopaths_sid_text
+         ):
+             audiopath = f"{_id}"
+             if self.min_text_len <= len(phones) and len(phones) <= self.max_text_len:
+                 phones = phones.split(" ")
+                 tone = [int(i) for i in tone.split(" ")]
+                 word2ph = [int(i) for i in word2ph.split(" ")]
+                 audiopaths_sid_text_new.append(
+                     [audiopath, spk, language, text, phones, tone, word2ph]
+                 )
+                 lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
+             else:
+                 skipped += 1
+         logger.info(
+             "skipped: "
+             + str(skipped)
+             + ", total: "
+             + str(len(self.audiopaths_sid_text))
+         )
+         self.audiopaths_sid_text = audiopaths_sid_text_new
+         self.lengths = lengths
+
+     def get_audio_text_speaker_pair(self, audiopath_sid_text):
+         # separate filename, speaker_id and text
+         audiopath, sid, language, text, phones, tone, word2ph = audiopath_sid_text
+
+         bert, ja_bert, en_bert, phones, tone, language = self.get_text(
+             text, word2ph, phones, tone, language, audiopath
+         )
+
+         spec, wav = self.get_audio(audiopath)
+         sid = torch.LongTensor([int(self.spk_map[sid])])
+         emo = torch.FloatTensor(np.load(audiopath.replace(".wav", ".emo.npy")))
+         return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert, emo)
+
+     def get_audio(self, filename):
+         audio, sampling_rate = load_wav_to_torch(filename)
+         if sampling_rate != self.sampling_rate:
+             raise ValueError(
+                 "{} {} SR doesn't match target {} SR".format(
+                     filename, sampling_rate, self.sampling_rate
+                 )
+             )
+         audio_norm = audio / self.max_wav_value
+         audio_norm = audio_norm.unsqueeze(0)
+         spec_filename = filename.replace(".wav", ".spec.pt")
+         if self.use_mel_spec_posterior:
+             spec_filename = spec_filename.replace(".spec.pt", ".mel.pt")
+         try:
+             spec = torch.load(spec_filename)
+         except:
+             if self.use_mel_spec_posterior:
+                 spec = mel_spectrogram_torch(
+                     audio_norm,
+                     self.filter_length,
+                     self.n_mel_channels,
+                     self.sampling_rate,
+                     self.hop_length,
+                     self.win_length,
+                     self.hparams.mel_fmin,
+                     self.hparams.mel_fmax,
+                     center=False,
+                 )
+             else:
+                 spec = spectrogram_torch(
+                     audio_norm,
+                     self.filter_length,
+                     self.sampling_rate,
+                     self.hop_length,
+                     self.win_length,
+                     center=False,
+                 )
+             spec = torch.squeeze(spec, 0)
+             if config.train_ms_config.spec_cache:
+                 torch.save(spec, spec_filename)
+         return spec, audio_norm
+
+     def get_text(self, text, word2ph, phone, tone, language_str, wav_path):
+         phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
+         if self.add_blank:
+             phone = commons.intersperse(phone, 0)
+             tone = commons.intersperse(tone, 0)
+             language = commons.intersperse(language, 0)
+             for i in range(len(word2ph)):
+                 word2ph[i] = word2ph[i] * 2
+             word2ph[0] += 1
+         bert_path = wav_path.replace(".wav", ".bert.pt")
+         try:
+             bert_ori = torch.load(bert_path)
+             assert bert_ori.shape[-1] == len(phone)
+         except Exception as e:
+             logger.warning("Bert load Failed")
+             logger.warning(e)
+
+         if language_str == "ZH":
+             bert = bert_ori
+             ja_bert = torch.zeros(1024, len(phone))
+             en_bert = torch.zeros(1024, len(phone))
+         elif language_str == "JP":
+             bert = torch.zeros(1024, len(phone))
+             ja_bert = bert_ori
+             en_bert = torch.zeros(1024, len(phone))
+         elif language_str == "EN":
+             bert = torch.zeros(1024, len(phone))
+             ja_bert = torch.zeros(1024, len(phone))
+             en_bert = bert_ori
+         phone = torch.LongTensor(phone)
+         tone = torch.LongTensor(tone)
+         language = torch.LongTensor(language)
+         return bert, ja_bert, en_bert, phone, tone, language
+
+     def get_sid(self, sid):
+         sid = torch.LongTensor([int(sid)])
+         return sid
+
+     def __getitem__(self, index):
+         return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
+
+     def __len__(self):
+         return len(self.audiopaths_sid_text)
+
+
+ class TextAudioSpeakerCollate:
+     """Zero-pads model inputs and targets"""
+
+     def __init__(self, return_ids=False):
+         self.return_ids = return_ids
+
+     def __call__(self, batch):
+         """Collate's training batch from normalized text, audio and speaker identities
+         PARAMS
+         ------
+         batch: [text_normalized, spec_normalized, wav_normalized, sid]
+         """
+         # Right zero-pad all one-hot text sequences to max input length
+         _, ids_sorted_decreasing = torch.sort(
+             torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
+         )
+
+         max_text_len = max([len(x[0]) for x in batch])
+         max_spec_len = max([x[1].size(1) for x in batch])
+         max_wav_len = max([x[2].size(1) for x in batch])
+
+         text_lengths = torch.LongTensor(len(batch))
+         spec_lengths = torch.LongTensor(len(batch))
+         wav_lengths = torch.LongTensor(len(batch))
+         sid = torch.LongTensor(len(batch))
+
+         text_padded = torch.LongTensor(len(batch), max_text_len)
+         tone_padded = torch.LongTensor(len(batch), max_text_len)
+         language_padded = torch.LongTensor(len(batch), max_text_len)
+         bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
+         ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
+         en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
+         emo = torch.FloatTensor(len(batch), 1024)
+
+         spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
+         wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
+         text_padded.zero_()
+         tone_padded.zero_()
+         language_padded.zero_()
+         spec_padded.zero_()
+         wav_padded.zero_()
+         bert_padded.zero_()
+         ja_bert_padded.zero_()
+         en_bert_padded.zero_()
+         emo.zero_()
+
+         for i in range(len(ids_sorted_decreasing)):
+             row = batch[ids_sorted_decreasing[i]]
+
+             text = row[0]
+             text_padded[i, : text.size(0)] = text
+             text_lengths[i] = text.size(0)
+
+             spec = row[1]
+             spec_padded[i, :, : spec.size(1)] = spec
+             spec_lengths[i] = spec.size(1)
+
+             wav = row[2]
+             wav_padded[i, :, : wav.size(1)] = wav
+             wav_lengths[i] = wav.size(1)
+
+             sid[i] = row[3]
+
+             tone = row[4]
+             tone_padded[i, : tone.size(0)] = tone
+
+             language = row[5]
+             language_padded[i, : language.size(0)] = language
+
+             bert = row[6]
+             bert_padded[i, :, : bert.size(1)] = bert
+
+             ja_bert = row[7]
+             ja_bert_padded[i, :, : ja_bert.size(1)] = ja_bert
+
+             en_bert = row[8]
+             en_bert_padded[i, :, : en_bert.size(1)] = en_bert
+
+             emo[i, :] = row[9]
+
+         return (
+             text_padded,
+             text_lengths,
+             spec_padded,
+             spec_lengths,
+             wav_padded,
+             wav_lengths,
+             sid,
+             tone_padded,
+             language_padded,
+             bert_padded,
+             ja_bert_padded,
+             en_bert_padded,
+             emo,
+         )
+
+
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
+     """
+     Maintain similar input lengths in a batch.
+     Length groups are specified by boundaries.
+     Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.
+
+     It removes samples which are not included in the boundaries.
+     Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
+     """
+
+     def __init__(
+         self,
+         dataset,
+         batch_size,
+         boundaries,
+         num_replicas=None,
+         rank=None,
+         shuffle=True,
+     ):
+         super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+         self.lengths = dataset.lengths
+         self.batch_size = batch_size
+         self.boundaries = boundaries
+
+         self.buckets, self.num_samples_per_bucket = self._create_buckets()
+         self.total_size = sum(self.num_samples_per_bucket)
+         self.num_samples = self.total_size // self.num_replicas
+
+     def _create_buckets(self):
+         buckets = [[] for _ in range(len(self.boundaries) - 1)]
+         for i in range(len(self.lengths)):
+             length = self.lengths[i]
+             idx_bucket = self._bisect(length)
+             if idx_bucket != -1:
+                 buckets[idx_bucket].append(i)
+
+         try:
+             for i in range(len(buckets) - 1, 0, -1):
+                 if len(buckets[i]) == 0:
+                     buckets.pop(i)
+                     self.boundaries.pop(i + 1)
+             assert all(len(bucket) > 0 for bucket in buckets)
+         # When one bucket is not traversed
+         except Exception as e:
+             print("Bucket warning ", e)
+             for i in range(len(buckets) - 1, -1, -1):
+                 if len(buckets[i]) == 0:
+                     buckets.pop(i)
+                     self.boundaries.pop(i + 1)
+
+         num_samples_per_bucket = []
+         for i in range(len(buckets)):
+             len_bucket = len(buckets[i])
+             total_batch_size = self.num_replicas * self.batch_size
+             rem = (
+                 total_batch_size - (len_bucket % total_batch_size)
+             ) % total_batch_size
+             num_samples_per_bucket.append(len_bucket + rem)
+         return buckets, num_samples_per_bucket
+
+     def __iter__(self):
+         # deterministically shuffle based on epoch
+         g = torch.Generator()
+         g.manual_seed(self.epoch)
+
+         indices = []
+         if self.shuffle:
+             for bucket in self.buckets:
+                 indices.append(torch.randperm(len(bucket), generator=g).tolist())
+         else:
+             for bucket in self.buckets:
+                 indices.append(list(range(len(bucket))))
+
+         batches = []
+         for i in range(len(self.buckets)):
+             bucket = self.buckets[i]
+             len_bucket = len(bucket)
+             if len_bucket == 0:
+                 continue
+             ids_bucket = indices[i]
+             num_samples_bucket = self.num_samples_per_bucket[i]
+
+             # add extra samples to make it evenly divisible
+             rem = num_samples_bucket - len_bucket
+             ids_bucket = (
+                 ids_bucket
+                 + ids_bucket * (rem // len_bucket)
+                 + ids_bucket[: (rem % len_bucket)]
+             )
+
+             # subsample
+             ids_bucket = ids_bucket[self.rank :: self.num_replicas]
+
+             # batching
+             for j in range(len(ids_bucket) // self.batch_size):
+                 batch = [
+                     bucket[idx]
+                     for idx in ids_bucket[
+                         j * self.batch_size : (j + 1) * self.batch_size
+                     ]
+                 ]
+                 batches.append(batch)
+
+         if self.shuffle:
+             batch_ids = torch.randperm(len(batches), generator=g).tolist()
+             batches = [batches[i] for i in batch_ids]
+         self.batches = batches
+
+         assert len(self.batches) * self.batch_size == self.num_samples
+         return iter(self.batches)
+
+     def _bisect(self, x, lo=0, hi=None):
+         if hi is None:
+             hi = len(self.boundaries) - 1
+
+         if hi > lo:
+             mid = (hi + lo) // 2
+             if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
+                 return mid
+             elif x <= self.boundaries[mid]:
+                 return self._bisect(x, lo, mid)
+             else:
+                 return self._bisect(x, mid + 1, hi)
+         else:
+             return -1
+
+     def __len__(self):
+         return self.num_samples // self.batch_size
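The DistributedBucketSampler docstring explains that samples are grouped into length buckets bounded by `boundaries`, so each batch contains spectrograms of similar length. The sketch below shows how the three classes in this file could be wired into a DataLoader; the boundary values, batch size, and single-process num_replicas/rank are placeholder assumptions, not the project's actual training script.

```python
# Sketch: wiring TextAudioSpeakerLoader, DistributedBucketSampler and the collate function.
from torch.utils.data import DataLoader

import utils
from data_utils import (
    DistributedBucketSampler,
    TextAudioSpeakerCollate,
    TextAudioSpeakerLoader,
)

hps = utils.get_hparams_from_file("config.json")
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
train_sampler = DistributedBucketSampler(
    train_dataset,
    batch_size=4,  # placeholder value
    boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],  # spec-length buckets
    num_replicas=1,  # single process here; DDP supplies rank/num_replicas in real training
    rank=0,
    shuffle=True,
)
train_loader = DataLoader(
    train_dataset,
    num_workers=2,
    collate_fn=TextAudioSpeakerCollate(),
    batch_sampler=train_sampler,  # yields index lists, one bucket-homogeneous batch each
)
```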
default_config.yml ADDED
@@ -0,0 +1,176 @@
+ # Global configuration
+ # To use multiple config files at the same time (for example, two GPUs training on two datasets), specify the config file via an environment variable; if none is given, ./config.yml is used by default
+
+ # A common path configuration is provided so data can be stored in one place instead of being scattered
+ # Each dataset and its corresponding models live under a single path; all path settings below are relative to dataset_path
+ # If left blank or omitted, paths are relative to the project root directory
+ dataset_path: "Data/"
+
+ # Model mirror source, huggingface by default; to use the openi mirror, openi_token must be set
+ mirror: ""
+ openi_token: "" # openi token
+
+ # resample: audio resampling configuration
+ # Note: a space is required after ":"
+ resample:
+   # Target sampling rate
+   sampling_rate: 44100
+   # Input directory; every .wav file under this path will be resampled
+   # Fill in a path relative to dataset_path
+   in_dir: "audios/raw" # relative to the project root this is /dataset_path/in_dir
+   # Output directory for the resampled audio files
+   out_dir: "audios/wavs"
+
+
+ # preprocess_text: dataset preprocessing configuration
+ # Note: a space is required after ":"
+ preprocess_text:
+   # Path to the raw transcription file; each line should be {wav_path}|{speaker_name}|{language}|{text}.
+   transcription_path: "filelists/你的数据集文本.list"
+   # Path for the cleaned text; may be left empty, in which case it is generated next to the raw text
+   cleaned_path: ""
+   # Training set path
+   train_path: "filelists/train.list"
+   # Validation set path
+   val_path: "filelists/val.list"
+   # Config file path
+   config_path: "config.json"
+   # Number of validation samples per speaker
+   val_per_spk: 4
+   # Maximum number of validation samples; any excess is cut off and moved to the training set
+   max_val_total: 8
+   # Whether to run data cleaning
+   clean: true
+
+
+ # bert_gen configuration
+ # Note: a space is required after ":"
+ bert_gen:
+   # Path to the training dataset config file
+   config_path: "config.json"
+   # Number of parallel processes
+   num_processes: 2
+   # Device: "cuda" for GPU inference, "cpu" for CPU inference
+   # This option also sets the default device for get_bert_feature
+   device: "cuda"
+   # Use multi-GPU inference
+   use_multi_device: false
+
+ # emo_gen configuration
+ # Note: a space is required after ":"
+ emo_gen:
+   # Path to the training dataset config file
+   config_path: "config.json"
+   # Number of parallel processes
+   num_processes: 2
+   # Device: "cuda" for GPU inference, "cpu" for CPU inference
+   device: "cuda"
+
+ # train_ms: training configuration
+ # Note: a space is required after ":"
+ train_ms:
+   env:
+     MASTER_ADDR: "localhost"
+     MASTER_PORT: 10086
+     WORLD_SIZE: 1
+     LOCAL_RANK: 0
+     RANK: 0
+     # Environment variables with arbitrary names can be added here
+     # THE_ENV_VAR_YOU_NEED_TO_USE: "1234567"
+   # Base (pretrained) model settings
+   base:
+     use_base_model: false
+     repo_id: "Stardust_minus/Bert-VITS2"
+     model_image: "Bert-VITS2_2.1-Emo底模" # model name on the openi page
+   # Model output directory: unlike older versions, which kept everything under logs/model_name, models are now stored under Data/<your dataset>/models
+   model: "models"
+   # Config file path
+   config_path: "config.json"
+   # Number of training workers; not recommended to exceed the number of CPU cores
+   num_workers: 16
+   # Disabling this saves close to 50% disk space, but may slow down training and increase CPU usage.
+   spec_cache: True
+   # Number of checkpoints to keep; older checkpoints beyond this count are deleted to save space.
+   keep_ckpts: 8
+
+
+ # webui configuration
+ # Note: a space is required after ":"
+ webui:
+   # Inference device
+   device: "cuda"
+   # Model path
+   model: "genshin/models/G_8000.pth"
+   # Config file path
+   config_path: "config.json"
+   # Port number
+   port: 7860
+   # Whether to deploy publicly (exposed to the external network)
+   share: false
+   # Whether to enable debug mode
+   debug: false
+   # Language identification library; options: langid, fastlid
+   language_identification_library: "langid"
+
+
+ # server API configuration
+ # Note: a space is required after ":"
+ # Note: all paths in this section are relative to the project root
+ server:
+   # Port number
+   port: 5000
+   # Default device for models; this setting is not implemented yet.
+   device: "cuda"
+   # Configuration of all models to load; multiple models may be listed, or none at all (models can be loaded manually once the page is up)
+   # To load no models: delete the two default model entries and set models to [ ], i.e. an empty list, using the same syntax as the speakers field of model 2, i.e. models: [ ]
+   # Note: every model must have valid model and config paths; empty paths will cause loading errors.
+   # It is also fine to leave models empty and fill them in manually after the page has loaded.
+   models:
+     - # Model path
+       model: ""
+       # Path to the model's config.json
+       config: ""
+       # Device for this model; if set, it overrides the default
+       device: "cuda"
+       # Default language for this model
+       language: "ZH"
+       # Default per-speaker parameters
+       # Not every speaker has to be listed; unlisted ones use the default values
+       # Can be left unfilled for now; per-speaker configuration is not implemented yet
+       speakers:
+         - speaker: "科比"
+           sdp_ratio: 0.2
+           noise_scale: 0.6
+           noise_scale_w: 0.8
+           length_scale: 1
+         - speaker: "五条悟"
+           sdp_ratio: 0.3
+           noise_scale: 0.7
+           noise_scale_w: 0.8
+           length_scale: 0.5
+         - speaker: "安倍晋三"
+           sdp_ratio: 0.2
+           noise_scale: 0.6
+           noise_scale_w: 0.8
+           length_scale: 1.2
+     - # Model path
+       model: ""
+       # Path to the model's config.json
+       config: ""
+       # Device for this model; if set, it overrides the default
+       device: "cpu"
+       # Default language for this model
+       language: "JP"
+       # Default per-speaker parameters
+       # Not every speaker has to be listed; unlisted ones use the default values
+       speakers: [ ] # may also be left empty
+
+
+ # Baidu Translate open platform API configuration
+ # API documentation: https://api.fanyi.baidu.com/doc/21
+ # Do not publicly share your app id and key on GitHub or similar sites
+ translate:
+   # Your APPID
+   "app_key": ""
+   # Your secret key
+   "secret_key": ""
emo_gen.py ADDED
@@ -0,0 +1,155 @@
+ import argparse
+ import os
+ from pathlib import Path
+
+ import librosa
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm import tqdm
+ from transformers import Wav2Vec2Processor
+ from transformers.models.wav2vec2.modeling_wav2vec2 import (
+     Wav2Vec2Model,
+     Wav2Vec2PreTrainedModel,
+ )
+
+ import utils
+ from config import config
+
+
+ class RegressionHead(nn.Module):
+     r"""Classification head."""
+
+     def __init__(self, config):
+         super().__init__()
+
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.final_dropout)
+         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+     def forward(self, features, **kwargs):
+         x = features
+         x = self.dropout(x)
+         x = self.dense(x)
+         x = torch.tanh(x)
+         x = self.dropout(x)
+         x = self.out_proj(x)
+
+         return x
+
+
+ class EmotionModel(Wav2Vec2PreTrainedModel):
+     r"""Speech emotion classifier."""
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.config = config
+         self.wav2vec2 = Wav2Vec2Model(config)
+         self.classifier = RegressionHead(config)
+         self.init_weights()
+
+     def forward(
+         self,
+         input_values,
+     ):
+         outputs = self.wav2vec2(input_values)
+         hidden_states = outputs[0]
+         hidden_states = torch.mean(hidden_states, dim=1)
+         logits = self.classifier(hidden_states)
+
+         return hidden_states, logits
+
+
+ class AudioDataset(Dataset):
+     def __init__(self, list_of_wav_files, sr, processor):
+         self.list_of_wav_files = list_of_wav_files
+         self.processor = processor
+         self.sr = sr
+
+     def __len__(self):
+         return len(self.list_of_wav_files)
+
+     def __getitem__(self, idx):
+         wav_file = self.list_of_wav_files[idx]
+         audio_data, _ = librosa.load(wav_file, sr=self.sr)
+         processed_data = self.processor(audio_data, sampling_rate=self.sr)[
+             "input_values"
+         ][0]
+         return torch.from_numpy(processed_data)
+
+
+ def process_func(
+     x: np.ndarray,
+     sampling_rate: int,
+     model: EmotionModel,
+     processor: Wav2Vec2Processor,
+     device: str,
+     embeddings: bool = False,
+ ) -> np.ndarray:
+     r"""Predict emotions or extract embeddings from raw audio signal."""
+     model = model.to(device)
+     y = processor(x, sampling_rate=sampling_rate)
+     y = y["input_values"][0]
+     y = torch.from_numpy(y).unsqueeze(0).to(device)
+
+     # run through model
+     with torch.no_grad():
+         y = model(y)[0 if embeddings else 1]
+
+     # convert to numpy
+     y = y.detach().cpu().numpy()
+
+     return y
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "-c", "--config", type=str, default=config.bert_gen_config.config_path
+     )
+     parser.add_argument(
+         "--num_processes", type=int, default=config.bert_gen_config.num_processes
+     )
+     args, _ = parser.parse_known_args()
+     config_path = args.config
+     hps = utils.get_hparams_from_file(config_path)
+
+     device = config.bert_gen_config.device
+
+     model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
+     REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
+     if not Path(model_name).joinpath("pytorch_model.bin").exists():
+         utils.download_emo_models(config.mirror, REPO_ID, model_name)
+
+     processor = Wav2Vec2Processor.from_pretrained(model_name)
+     model = EmotionModel.from_pretrained(model_name).to(device)
+
+     lines = []
+     with open(hps.data.training_files, encoding="utf-8") as f:
+         lines.extend(f.readlines())
+
+     with open(hps.data.validation_files, encoding="utf-8") as f:
+         lines.extend(f.readlines())
+
+     wavnames = [line.split("|")[0] for line in lines]
+     dataset = AudioDataset(wavnames, 16000, processor)
+     data_loader = DataLoader(
+         dataset,
+         batch_size=1,
+         shuffle=False,
+         num_workers=min(args.num_processes, os.cpu_count() - 1),
+     )
+
+     with torch.no_grad():
+         for i, data in tqdm(enumerate(data_loader), total=len(data_loader)):
+             wavname = wavnames[i]
+             emo_path = wavname.replace(".wav", ".emo.npy")
+             if os.path.exists(emo_path):
+                 continue
+             emb = model(data.to(device))[0].detach().cpu().numpy()
+             np.save(emo_path, emb)
+
+     print("Emo vec generation finished!")
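process_func above either returns the pooled wav2vec2 hidden state (embeddings=True) or the regression-head logits. Below is a hedged usage sketch for a single file; the model directory matches the one hard-coded in the script, while the wav path is illustrative. get_emo.py, later in this commit, wraps the same call.

```python
# Usage sketch: extract one emotion embedding with process_func (wav path is illustrative).
import librosa
import numpy as np
from transformers import Wav2Vec2Processor

from emo_gen import EmotionModel, process_func

model_dir = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
processor = Wav2Vec2Processor.from_pretrained(model_dir)
model = EmotionModel.from_pretrained(model_dir)

wav, sr = librosa.load("example.wav", sr=16000)
emb = process_func(
    np.expand_dims(wav, 0), sr, model, processor, device="cpu", embeddings=True
)
print(emb.shape)  # pooled hidden-state embedding; the same vector emo_gen.py saves as <name>.emo.npy
```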
export_onnx.py ADDED
@@ -0,0 +1,56 @@
+ from models_onnx import SynthesizerTrn
+ import utils
+ from text.symbols import symbols
+ import os
+ import json
+
+
+ def export_onnx(export_path, model_path, config_path):
+     hps = utils.get_hparams_from_file(config_path)
+     net_g = SynthesizerTrn(
+         len(symbols),
+         hps.data.filter_length // 2 + 1,
+         hps.train.segment_size // hps.data.hop_length,
+         n_speakers=hps.data.n_speakers,
+         **hps.model,
+     )
+     _ = net_g.eval()
+     _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
+     net_g.export_onnx(export_path)
+
+     spklist = []
+     for key in hps.data.spk2id.keys():
+         spklist.append(key)
+
+     MoeVSConf = {
+         "Folder": f"{export_path}",
+         "Name": f"{export_path}",
+         "Type": "BertVits",
+         "Symbol": symbols,
+         "Cleaner": "",
+         "Rate": hps.data.sampling_rate,
+         "CharaMix": True,
+         "Characters": spklist,
+         "LanguageMap": {"ZH": [0, 0], "JP": [1, 6], "EN": [2, 8]},
+         "Dict": "BasicDict",
+         "BertPath": [
+             "chinese-roberta-wwm-ext-large",
+             "deberta-v2-large-japanese",
+             "bert-base-japanese-v3",
+         ],
+     }
+
+     with open(f"onnx/{export_path}.json", "w") as MoeVsConfFile:
+         json.dump(MoeVSConf, MoeVsConfFile, indent=4)
+
+
+ if __name__ == "__main__":
+     print(symbols)
+     export_path = "HimenoSena"
+     model_path = "G_53000.pth"
+     config_path = "config.json"
+     if not os.path.exists("onnx"):
+         os.makedirs("onnx")
+     if not os.path.exists(f"onnx/{export_path}"):
+         os.makedirs(f"onnx/{export_path}")
+     export_onnx(export_path, model_path, config_path)
get_emo.py ADDED
@@ -0,0 +1,26 @@
+ from emo_gen import EmotionModel, process_func
+
+ import librosa
+ import numpy as np
+ import torch
+ from transformers import Wav2Vec2Processor
+
+ from config import config
+
+ model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
+ model = EmotionModel.from_pretrained(model_name).to(device)
+
+
+ def get_emo(path):
+     wav, sr = librosa.load(path, sr=16000)
+     device = config.bert_gen_config.device
+     return process_func(
+         np.expand_dims(wav, 0).astype(np.float64),
+         sr,
+         model,
+         processor,
+         device,
+         embeddings=True,
+     ).squeeze(0)
infer.py ADDED
@@ -0,0 +1,341 @@
+ """
+ Version management, backward-compatible inference, and model loading.
+ Version notes:
+ 1. Version numbers match the GitHub release versions; a model trained with a given release carries that release's version number
+ 2. Please declare the version explicitly in the model's config.json by adding a field "version": "your version number"
+ Special versions:
+ 1.1.1-fix: models trained with 1.1.1, but inferred with the dev Japanese fix
+ 1.1.1-dev: dev development version
+ 2.1: current version
+ """
+ import torch
+ import commons
+ from text import cleaned_text_to_sequence, get_bert
+ from get_emo import get_emo
+ from text.cleaner import clean_text
+ import utils
+
+ from models import SynthesizerTrn
+ from text.symbols import symbols
+ from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
+ from oldVersion.V200.text import symbols as V200symbols
+ from oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn
+ from oldVersion.V111.text import symbols as V111symbols
+ from oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn
+ from oldVersion.V110.text import symbols as V110symbols
+ from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
+ from oldVersion.V101.text import symbols as V101symbols
+
+ from oldVersion import V111, V110, V101, V200
+
+ # Current version information
+ latest_version = "2.1"
+
+ # Version compatibility
+ SynthesizerTrnMap = {
+     "2.0.2-fix": V200SynthesizerTrn,
+     "2.0.1": V200SynthesizerTrn,
+     "2.0": V200SynthesizerTrn,
+     "1.1.1-fix": V111SynthesizerTrn,
+     "1.1.1": V111SynthesizerTrn,
+     "1.1": V110SynthesizerTrn,
+     "1.1.0": V110SynthesizerTrn,
+     "1.0.1": V101SynthesizerTrn,
+     "1.0": V101SynthesizerTrn,
+     "1.0.0": V101SynthesizerTrn,
+ }
+
+ symbolsMap = {
+     "2.0.2-fix": V200symbols,
+     "2.0.1": V200symbols,
+     "2.0": V200symbols,
+     "1.1.1-fix": V111symbols,
+     "1.1.1": V111symbols,
+     "1.1": V110symbols,
+     "1.1.0": V110symbols,
+     "1.0.1": V101symbols,
+     "1.0": V101symbols,
+     "1.0.0": V101symbols,
+ }
+
+
+ def get_net_g(model_path: str, version: str, device: str, hps):
+     if version != latest_version:
+         net_g = SynthesizerTrnMap[version](
+             len(symbolsMap[version]),
+             hps.data.filter_length // 2 + 1,
+             hps.train.segment_size // hps.data.hop_length,
+             n_speakers=hps.data.n_speakers,
+             **hps.model,
+         ).to(device)
+     else:
+         # net_g for the current version
+         net_g = SynthesizerTrn(
+             len(symbols),
+             hps.data.filter_length // 2 + 1,
+             hps.train.segment_size // hps.data.hop_length,
+             n_speakers=hps.data.n_speakers,
+             **hps.model,
+         ).to(device)
+     _ = net_g.eval()
+     _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
+     return net_g
+
+
+ def get_text(text, language_str, hps, device):
+     # get_text for the current version is implemented here
+     norm_text, phone, tone, word2ph = clean_text(text, language_str)
+     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
+
+     if hps.data.add_blank:
+         phone = commons.intersperse(phone, 0)
+         tone = commons.intersperse(tone, 0)
+         language = commons.intersperse(language, 0)
+         for i in range(len(word2ph)):
+             word2ph[i] = word2ph[i] * 2
+         word2ph[0] += 1
+     bert_ori = get_bert(norm_text, word2ph, language_str, device)
+     del word2ph
+     assert bert_ori.shape[-1] == len(phone), phone
+
+     if language_str == "ZH":
+         bert = bert_ori
+         ja_bert = torch.zeros(1024, len(phone))
+         en_bert = torch.zeros(1024, len(phone))
+     elif language_str == "JP":
+         bert = torch.zeros(1024, len(phone))
+         ja_bert = bert_ori
+         en_bert = torch.zeros(1024, len(phone))
+     elif language_str == "EN":
+         bert = torch.zeros(1024, len(phone))
+         ja_bert = torch.zeros(1024, len(phone))
+         en_bert = bert_ori
+     else:
+         raise ValueError("language_str should be ZH, JP or EN")
+
+     assert bert.shape[-1] == len(
+         phone
+     ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
+
+     phone = torch.LongTensor(phone)
+     tone = torch.LongTensor(tone)
+     language = torch.LongTensor(language)
+     return bert, ja_bert, en_bert, phone, tone, language
+
+
+ def get_emo_(reference_audio, emotion):
+     emo = (
+         torch.from_numpy(get_emo(reference_audio))
+         if reference_audio
+         else torch.Tensor([emotion])
+     )
+     return emo
+
+
+ def infer(
+     text,
+     sdp_ratio,
+     noise_scale,
+     noise_scale_w,
+     length_scale,
+     sid,
+     language,
+     hps,
+     net_g,
+     device,
+     reference_audio=None,
+     emotion=None,
+     skip_start=False,
+     skip_end=False,
+ ):
+     # Versions supporting Chinese, Japanese and English
+     inferMap_V2 = {
+         "2.0.2-fix": V200.infer,
+         "2.0.1": V200.infer,
+         "2.0": V200.infer,
+         "1.1.1-fix": V111.infer_fix,
+         "1.1.1": V111.infer,
+         "1.1": V110.infer,
+         "1.1.0": V110.infer,
+     }
+     # Chinese-only versions
+     # In testing, no incompatibility between models of the two versions was observed
+     inferMap_V1 = {
+         "1.0.1": V101.infer,
+         "1.0": V101.infer,
+         "1.0.0": V101.infer,
+     }
+     version = hps.version if hasattr(hps, "version") else latest_version
+     # Not the current version: select the appropriate infer by version number
+     if version != latest_version:
+         if version in inferMap_V2.keys():
+             return inferMap_V2[version](
+                 text,
+                 sdp_ratio,
+                 noise_scale,
+                 noise_scale_w,
+                 length_scale,
+                 sid,
+                 language,
+                 hps,
+                 net_g,
+                 device,
+             )
+         if version in inferMap_V1.keys():
+             return inferMap_V1[version](
+                 text,
+                 sdp_ratio,
+                 noise_scale,
+                 noise_scale_w,
+                 length_scale,
+                 sid,
+                 hps,
+                 net_g,
+                 device,
+             )
+     # Inference for the current version is implemented here
+     bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
+         text, language, hps, device
+     )
+     emo = get_emo_(reference_audio, emotion)
+     if skip_start:
+         phones = phones[1:]
+         tones = tones[1:]
+         lang_ids = lang_ids[1:]
+         bert = bert[:, 1:]
+         ja_bert = ja_bert[:, 1:]
+         en_bert = en_bert[:, 1:]
+     if skip_end:
+         phones = phones[:-1]
+         tones = tones[:-1]
+         lang_ids = lang_ids[:-1]
+         bert = bert[:, :-1]
+         ja_bert = ja_bert[:, :-1]
+         en_bert = en_bert[:, :-1]
+     with torch.no_grad():
+         x_tst = phones.to(device).unsqueeze(0)
+         tones = tones.to(device).unsqueeze(0)
+         lang_ids = lang_ids.to(device).unsqueeze(0)
+         bert = bert.to(device).unsqueeze(0)
+         ja_bert = ja_bert.to(device).unsqueeze(0)
+         en_bert = en_bert.to(device).unsqueeze(0)
+         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+         emo = emo.to(device).unsqueeze(0)
+         del phones
+         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
+         audio = (
+             net_g.infer(
+                 x_tst,
+                 x_tst_lengths,
+                 speakers,
+                 tones,
+                 lang_ids,
+                 bert,
+                 ja_bert,
+                 en_bert,
+                 emo,
+                 sdp_ratio=sdp_ratio,
+                 noise_scale=noise_scale,
+                 noise_scale_w=noise_scale_w,
+                 length_scale=length_scale,
+             )[0][0, 0]
+             .data.cpu()
+             .float()
+             .numpy()
+         )
+         del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         return audio
+
+
+ def infer_multilang(
+     text,
+     sdp_ratio,
+     noise_scale,
+     noise_scale_w,
+     length_scale,
+     sid,
+     language,
+     hps,
+     net_g,
+     device,
+     reference_audio=None,
+     emotion=None,
+     skip_start=False,
+     skip_end=False,
+ ):
+     bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
+     emo = get_emo_(reference_audio, emotion)
+     for idx, (txt, lang) in enumerate(zip(text, language)):
+         skip_start = (idx != 0) or (skip_start and idx == 0)
+         skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
+         (
+             temp_bert,
+             temp_ja_bert,
+             temp_en_bert,
+             temp_phones,
+             temp_tones,
+             temp_lang_ids,
+         ) = get_text(txt, lang, hps, device)
+         if skip_start:
+             temp_bert = temp_bert[:, 1:]
+             temp_ja_bert = temp_ja_bert[:, 1:]
+             temp_en_bert = temp_en_bert[:, 1:]
+             temp_phones = temp_phones[1:]
+             temp_tones = temp_tones[1:]
+             temp_lang_ids = temp_lang_ids[1:]
+         if skip_end:
+             temp_bert = temp_bert[:, :-1]
+             temp_ja_bert = temp_ja_bert[:, :-1]
+             temp_en_bert = temp_en_bert[:, :-1]
+             temp_phones = temp_phones[:-1]
+             temp_tones = temp_tones[:-1]
+             temp_lang_ids = temp_lang_ids[:-1]
+         bert.append(temp_bert)
+         ja_bert.append(temp_ja_bert)
+         en_bert.append(temp_en_bert)
+         phones.append(temp_phones)
+         tones.append(temp_tones)
+         lang_ids.append(temp_lang_ids)
+     bert = torch.concatenate(bert, dim=1)
+     ja_bert = torch.concatenate(ja_bert, dim=1)
+     en_bert = torch.concatenate(en_bert, dim=1)
+     phones = torch.concatenate(phones, dim=0)
+     tones = torch.concatenate(tones, dim=0)
+     lang_ids = torch.concatenate(lang_ids, dim=0)
+     with torch.no_grad():
+         x_tst = phones.to(device).unsqueeze(0)
+         tones = tones.to(device).unsqueeze(0)
+         lang_ids = lang_ids.to(device).unsqueeze(0)
+         bert = bert.to(device).unsqueeze(0)
+         ja_bert = ja_bert.to(device).unsqueeze(0)
+         en_bert = en_bert.to(device).unsqueeze(0)
+         emo = emo.to(device).unsqueeze(0)
+         x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+         del phones
+         speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
+         audio = (
+             net_g.infer(
+                 x_tst,
+                 x_tst_lengths,
+                 speakers,
+                 tones,
+                 lang_ids,
+                 bert,
+                 ja_bert,
+                 en_bert,
+                 emo,
+                 sdp_ratio=sdp_ratio,
+                 noise_scale=noise_scale,
+                 noise_scale_w=noise_scale_w,
+                 length_scale=length_scale,
+             )[0][0, 0]
+             .data.cpu()
+             .float()
+             .numpy()
+         )
+         del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         return audio
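For reference, a hedged end-to-end sketch of using this module: load hyperparameters, build net_g with get_net_g, and synthesize one sentence with infer. The dataset paths, speaker name, reference audio, and parameter values below are placeholders modeled on the config above, not guaranteed to exist.

```python
# Usage sketch for infer.py (paths, speaker name and parameter values are placeholders).
import torch

import utils
from infer import get_net_g, infer, latest_version

device = "cuda" if torch.cuda.is_available() else "cpu"
hps = utils.get_hparams_from_file("Data/tamura/config.json")
version = hps.version if hasattr(hps, "version") else latest_version
net_g = get_net_g("Data/tamura/models/G_2750.pth", version, device, hps)

audio = infer(
    "こんにちは。",
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="tamura",  # assumption: must be a key of hps.data.spk2id
    language="JP",
    hps=hps,
    net_g=net_g,
    device=device,
    reference_audio="Data/tamura/audios/wavs/sample.wav",  # emotion reference (placeholder path)
)
# "audio" is a float32 numpy array at hps.data.sampling_rate
```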