ElesisSiegherts
committed on
Commit 1b9cb8c • 1 Parent(s): 8900345
Upload 7 files
Browse files
- config.yml +176 -0
- data_utils.py +410 -0
- default_config.yml +176 -0
- emo_gen.py +155 -0
- export_onnx.py +56 -0
- get_emo.py +26 -0
- infer.py +341 -0
config.yml
ADDED
@@ -0,0 +1,176 @@
# Global configuration
# To use several config files at the same time (e.g. two GPUs training two datasets), specify the config file via an environment variable; if unset, ./config.yml is used by default

# A common path prefix is provided so that data is stored in one place instead of being scattered around
# Each dataset and its corresponding models live under a single path; all path settings below are relative to dataset_path
# If left unset or empty, paths are relative to the project root
dataset_path: "Data/tamura"

# Model mirror; defaults to huggingface. To use the openi mirror, openi_token must be set
mirror: ""
openi_token: ""  # openi token

# resample: audio resampling configuration
# Note: a space is required after ":"
resample:
  # Target sampling rate
  sampling_rate: 44100
  # Input directory; every .wav file under this path will be resampled
  # Fill in a path relative to dataset_path
  in_dir: "audios/raw"  # the path relative to the project root is /dataset_path/in_dir
  # Output directory for resampled audio files
  out_dir: "audios/wavs"


# preprocess_text: dataset preprocessing configuration
# Note: a space is required after ":"
preprocess_text:
  # Path to the raw transcription file; each line should be {wav_path}|{speaker_name}|{language}|{text}
  transcription_path: "filelists/text.list"
  # Path for the cleaned text; may be left empty, in which case it is generated in the raw text's directory
  cleaned_path: ""
  # Training set path
  train_path: "filelists/train.list"
  # Validation set path
  val_path: "filelists/val.list"
  # Config file path
  config_path: "config.json"
  # Number of validation samples per speaker
  val_per_spk: 4
  # Maximum total number of validation samples; the excess is cut off and moved to the training set
  max_val_total: 8
  # Whether to run data cleaning
  clean: true


# bert_gen configuration
# Note: a space is required after ":"
bert_gen:
  # Path to the training dataset's config file
  config_path: "config.json"
  # Number of parallel processes
  num_processes: 2
  # Device: "cuda" for GPU inference, "cpu" for CPU inference
  # This option also sets the default device for get_bert_feature
  device: "cuda"
  # Use multi-GPU inference
  use_multi_device: false

# emo_gen configuration
# Note: a space is required after ":"
emo_gen:
  # Path to the training dataset's config file
  config_path: "config.json"
  # Number of parallel processes
  num_processes: 2
  # Device: "cuda" for GPU inference, "cpu" for CPU inference
  device: "cuda"

# train_ms: training configuration
# Note: a space is required after ":"
train_ms:
  env:
    MASTER_ADDR: "localhost"
    MASTER_PORT: 10086
    WORLD_SIZE: 1
    LOCAL_RANK: 0
    RANK: 0
    # Environment variables with arbitrary names may be added here
    # THE_ENV_VAR_YOU_NEED_TO_USE: "1234567"
  # Base (pretrained) model settings
  base:
    use_base_model: false
    repo_id: "Stardust_minus/Bert-VITS2"
    model_image: "Bert-VITS2_2.1-Emo底模"  # model name shown on the openi website
  # Directory where trained models are stored; unlike older versions, which kept them under logs/model_name, they now all live under Data/<your dataset>/models
  model: "models"
  # Config file path
  config_path: "config.json"
  # Number of training workers; not recommended to exceed the number of CPU cores
  num_workers: 16
  # Disabling this saves close to 50% of disk space, but may slow down training and increase CPU usage.
  spec_cache: True
  # Number of checkpoints to keep; weights beyond this count are deleted to save space.
  keep_ckpts: 8


# webui configuration
# Note: a space is required after ":"
webui:
  # Inference device
  device: "cuda"
  # Model path
  model: "models/G_2750.pth"
  # Config file path
  config_path: "config.json"
  # Port number
  port: 7860
  # Whether to deploy publicly (expose to the internet)
  share: false
  # Whether to enable debug mode
  debug: false
  # Language identification library; options: langid, fastlid
  language_identification_library: "langid"


# server API configuration
# Note: a space is required after ":"
# Note: all paths in this section are relative to the project root
server:
  # Port number
  port: 5000
  # Default device for models (this option is not implemented yet)
  device: "cuda"
  # Configuration for all models to load; several models may be listed, or none, in which case models can be loaded manually once the web page is up
  # To load no models: delete the two default model entries and assign [ ] (an empty list) to models, the same syntax as the speakers of model 2, i.e. models: [ ]
  # Note: every model must have valid model and config paths; empty paths cause loading errors.
  # You can also leave models empty and fill them in manually after the page has loaded.
  models:
    - # Model path
      model: ""
      # Path to the model's config.json
      config: ""
      # Device for this model; overrides the default if set
      device: "cuda"
      # Default language for this model
      language: "ZH"
      # Default per-speaker parameters
      # Not every speaker needs to be listed; unlisted speakers use the defaults
      # Can be left unfilled for now; per-speaker configuration is not implemented yet
      speakers:
        - speaker: "科比"
          sdp_ratio: 0.2
          noise_scale: 0.6
          noise_scale_w: 0.8
          length_scale: 1
        - speaker: "五条悟"
          sdp_ratio: 0.3
          noise_scale: 0.7
          noise_scale_w: 0.8
          length_scale: 0.5
        - speaker: "安倍晋三"
          sdp_ratio: 0.2
          noise_scale: 0.6
          noise_scale_w: 0.8
          length_scale: 1.2
    - # Model path
      model: ""
      # Path to the model's config.json
      config: ""
      # Device for this model; overrides the default if set
      device: "cpu"
      # Default language for this model
      language: "JP"
      # Default per-speaker parameters
      # Not every speaker needs to be listed; unlisted speakers use the defaults
      speakers: [ ]  # may also be left empty


# Baidu Translate Open Platform API configuration
# API documentation: https://api.fanyi.baidu.com/doc/21
# Do not share your app id and key publicly on GitHub or similar sites
translate:
  # Your APPID
  "app_key": ""
  # Your secret key
  "secret_key": ""
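
The other files in this commit read this file through the project's config module (they do `from config import config`), which is not part of this upload. As a minimal sketch of the layout above, assuming only PyYAML and the keys shown here (the environment-variable name is a placeholder, since the real one lives in the project's config.py):

# Minimal sketch: reading config.yml with PyYAML; "BERT_VITS2_CONFIG" is a hypothetical variable name.
import os
import yaml

config_file = os.environ.get("BERT_VITS2_CONFIG", "config.yml")
with open(config_file, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

dataset_path = cfg["dataset_path"]                                      # "Data/tamura"
resample_in = os.path.join(dataset_path, cfg["resample"]["in_dir"])     # Data/tamura/audios/raw
resample_out = os.path.join(dataset_path, cfg["resample"]["out_dir"])   # Data/tamura/audios/wavs
print(cfg["resample"]["sampling_rate"], resample_in, resample_out)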
data_utils.py
ADDED
@@ -0,0 +1,410 @@
import os
import random
import torch
import torch.utils.data
from tqdm import tqdm
import numpy as np
from tools.log import logger
import commons
from mel_processing import spectrogram_torch, mel_spectrogram_torch
from utils import load_wav_to_torch, load_filepaths_and_text
from text import cleaned_text_to_sequence
from config import config

"""Multi speaker version"""


class TextAudioSpeakerLoader(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.
    """

    def __init__(self, audiopaths_sid_text, hparams):
        self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.sampling_rate = hparams.sampling_rate
        self.spk_map = hparams.spk2id
        self.hparams = hparams

        self.use_mel_spec_posterior = getattr(
            hparams, "use_mel_posterior_encoder", False
        )
        if self.use_mel_spec_posterior:
            self.n_mel_channels = getattr(hparams, "n_mel_channels", 80)

        self.cleaned_text = getattr(hparams, "cleaned_text", False)

        self.add_blank = hparams.add_blank
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 384)

        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)
        self._filter()

    def _filter(self):
        """
        Filter text & store spec lengths
        """
        # Store spectrogram lengths for Bucketing
        # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
        # spec_length = wav_length // hop_length

        audiopaths_sid_text_new = []
        lengths = []
        skipped = 0
        logger.info("Init dataset...")
        for _id, spk, language, text, phones, tone, word2ph in tqdm(
            self.audiopaths_sid_text
        ):
            audiopath = f"{_id}"
            if self.min_text_len <= len(phones) and len(phones) <= self.max_text_len:
                phones = phones.split(" ")
                tone = [int(i) for i in tone.split(" ")]
                word2ph = [int(i) for i in word2ph.split(" ")]
                audiopaths_sid_text_new.append(
                    [audiopath, spk, language, text, phones, tone, word2ph]
                )
                lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
            else:
                skipped += 1
        logger.info(
            "skipped: "
            + str(skipped)
            + ", total: "
            + str(len(self.audiopaths_sid_text))
        )
        self.audiopaths_sid_text = audiopaths_sid_text_new
        self.lengths = lengths

    def get_audio_text_speaker_pair(self, audiopath_sid_text):
        # separate filename, speaker_id and text
        audiopath, sid, language, text, phones, tone, word2ph = audiopath_sid_text

        bert, ja_bert, en_bert, phones, tone, language = self.get_text(
            text, word2ph, phones, tone, language, audiopath
        )

        spec, wav = self.get_audio(audiopath)
        sid = torch.LongTensor([int(self.spk_map[sid])])
        emo = torch.FloatTensor(np.load(audiopath.replace(".wav", ".emo.npy")))
        return (phones, spec, wav, sid, tone, language, bert, ja_bert, en_bert, emo)

    def get_audio(self, filename):
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "{} {} SR doesn't match target {} SR".format(
                    filename, sampling_rate, self.sampling_rate
                )
            )
        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)
        spec_filename = filename.replace(".wav", ".spec.pt")
        if self.use_mel_spec_posterior:
            spec_filename = spec_filename.replace(".spec.pt", ".mel.pt")
        try:
            spec = torch.load(spec_filename)
        except:
            if self.use_mel_spec_posterior:
                spec = mel_spectrogram_torch(
                    audio_norm,
                    self.filter_length,
                    self.n_mel_channels,
                    self.sampling_rate,
                    self.hop_length,
                    self.win_length,
                    self.hparams.mel_fmin,
                    self.hparams.mel_fmax,
                    center=False,
                )
            else:
                spec = spectrogram_torch(
                    audio_norm,
                    self.filter_length,
                    self.sampling_rate,
                    self.hop_length,
                    self.win_length,
                    center=False,
                )
            spec = torch.squeeze(spec, 0)
            if config.train_ms_config.spec_cache:
                torch.save(spec, spec_filename)
        return spec, audio_norm

    def get_text(self, text, word2ph, phone, tone, language_str, wav_path):
        phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
        if self.add_blank:
            phone = commons.intersperse(phone, 0)
            tone = commons.intersperse(tone, 0)
            language = commons.intersperse(language, 0)
            for i in range(len(word2ph)):
                word2ph[i] = word2ph[i] * 2
            word2ph[0] += 1
        bert_path = wav_path.replace(".wav", ".bert.pt")
        try:
            bert_ori = torch.load(bert_path)
            assert bert_ori.shape[-1] == len(phone)
        except Exception as e:
            logger.warning("Bert load Failed")
            logger.warning(e)

        if language_str == "ZH":
            bert = bert_ori
            ja_bert = torch.zeros(1024, len(phone))
            en_bert = torch.zeros(1024, len(phone))
        elif language_str == "JP":
            bert = torch.zeros(1024, len(phone))
            ja_bert = bert_ori
            en_bert = torch.zeros(1024, len(phone))
        elif language_str == "EN":
            bert = torch.zeros(1024, len(phone))
            ja_bert = torch.zeros(1024, len(phone))
            en_bert = bert_ori
        phone = torch.LongTensor(phone)
        tone = torch.LongTensor(tone)
        language = torch.LongTensor(language)
        return bert, ja_bert, en_bert, phone, tone, language

    def get_sid(self, sid):
        sid = torch.LongTensor([int(sid)])
        return sid

    def __getitem__(self, index):
        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])

    def __len__(self):
        return len(self.audiopaths_sid_text)


class TextAudioSpeakerCollate:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collate's training batch from normalized text, audio and speaker identities
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
        )

        max_text_len = max([len(x[0]) for x in batch])
        max_spec_len = max([x[1].size(1) for x in batch])
        max_wav_len = max([x[2].size(1) for x in batch])

        text_lengths = torch.LongTensor(len(batch))
        spec_lengths = torch.LongTensor(len(batch))
        wav_lengths = torch.LongTensor(len(batch))
        sid = torch.LongTensor(len(batch))

        text_padded = torch.LongTensor(len(batch), max_text_len)
        tone_padded = torch.LongTensor(len(batch), max_text_len)
        language_padded = torch.LongTensor(len(batch), max_text_len)
        bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
        ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
        en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
        emo = torch.FloatTensor(len(batch), 1024)

        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
        wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
        text_padded.zero_()
        tone_padded.zero_()
        language_padded.zero_()
        spec_padded.zero_()
        wav_padded.zero_()
        bert_padded.zero_()
        ja_bert_padded.zero_()
        en_bert_padded.zero_()
        emo.zero_()

        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]

            text = row[0]
            text_padded[i, : text.size(0)] = text
            text_lengths[i] = text.size(0)

            spec = row[1]
            spec_padded[i, :, : spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            wav = row[2]
            wav_padded[i, :, : wav.size(1)] = wav
            wav_lengths[i] = wav.size(1)

            sid[i] = row[3]

            tone = row[4]
            tone_padded[i, : tone.size(0)] = tone

            language = row[5]
            language_padded[i, : language.size(0)] = language

            bert = row[6]
            bert_padded[i, :, : bert.size(1)] = bert

            ja_bert = row[7]
            ja_bert_padded[i, :, : ja_bert.size(1)] = ja_bert

            en_bert = row[8]
            en_bert_padded[i, :, : en_bert.size(1)] = en_bert

            emo[i, :] = row[9]

        return (
            text_padded,
            text_lengths,
            spec_padded,
            spec_lengths,
            wav_padded,
            wav_lengths,
            sid,
            tone_padded,
            language_padded,
            bert_padded,
            ja_bert_padded,
            en_bert_padded,
            emo,
        )


class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        boundaries,
        num_replicas=None,
        rank=None,
        shuffle=True,
    ):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        try:
            for i in range(len(buckets) - 1, 0, -1):
                if len(buckets[i]) == 0:
                    buckets.pop(i)
                    self.boundaries.pop(i + 1)
            assert all(len(bucket) > 0 for bucket in buckets)
        # When one bucket is not traversed
        except Exception as e:
            print("Bucket warning ", e)
            for i in range(len(buckets) - 1, -1, -1):
                if len(buckets[i]) == 0:
                    buckets.pop(i)
                    self.boundaries.pop(i + 1)

        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (
                total_batch_size - (len_bucket % total_batch_size)
            ) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            if len_bucket == 0:
                continue
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = (
                ids_bucket
                + ids_bucket * (rem // len_bucket)
                + ids_bucket[: (rem % len_bucket)]
            )

            # subsample
            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [
                    bucket[idx]
                    for idx in ids_bucket[
                        j * self.batch_size : (j + 1) * self.batch_size
                    ]
                ]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size
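
The three classes above are the full data pipeline: the loader returns per-utterance tensors (including the .emo.npy emotion embedding), the bucket sampler groups utterances of similar length, and the collate function zero-pads a batch. A minimal, illustrative sketch of how a VITS-style training script (the project's train_ms.py, not included in this commit) would typically wire them together; the config path, batch size and boundaries below are example values:

# Illustrative sketch only, assuming a dataset laid out as in config.yml above.
from torch.utils.data import DataLoader

import utils  # the project's utils module, also imported by data_utils.py
from data_utils import (
    TextAudioSpeakerLoader,
    TextAudioSpeakerCollate,
    DistributedBucketSampler,
)

hps = utils.get_hparams_from_file("Data/tamura/config.json")  # example path

train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
train_sampler = DistributedBucketSampler(
    train_dataset,
    batch_size=4,  # example value
    boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
    num_replicas=1,
    rank=0,
    shuffle=True,
)
train_loader = DataLoader(
    train_dataset,
    num_workers=4,
    collate_fn=TextAudioSpeakerCollate(),
    batch_sampler=train_sampler,
)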
default_config.yml
ADDED
@@ -0,0 +1,176 @@
# Global configuration
# To use several config files at the same time (e.g. two GPUs training two datasets), specify the config file via an environment variable; if unset, ./config.yml is used by default

# A common path prefix is provided so that data is stored in one place instead of being scattered around
# Each dataset and its corresponding models live under a single path; all path settings below are relative to dataset_path
# If left unset or empty, paths are relative to the project root
dataset_path: "Data/"

# Model mirror; defaults to huggingface. To use the openi mirror, openi_token must be set
mirror: ""
openi_token: ""  # openi token

# resample: audio resampling configuration
# Note: a space is required after ":"
resample:
  # Target sampling rate
  sampling_rate: 44100
  # Input directory; every .wav file under this path will be resampled
  # Fill in a path relative to dataset_path
  in_dir: "audios/raw"  # the path relative to the project root is /dataset_path/in_dir
  # Output directory for resampled audio files
  out_dir: "audios/wavs"


# preprocess_text: dataset preprocessing configuration
# Note: a space is required after ":"
preprocess_text:
  # Path to the raw transcription file; each line should be {wav_path}|{speaker_name}|{language}|{text}
  transcription_path: "filelists/你的数据集文本.list"
  # Path for the cleaned text; may be left empty, in which case it is generated in the raw text's directory
  cleaned_path: ""
  # Training set path
  train_path: "filelists/train.list"
  # Validation set path
  val_path: "filelists/val.list"
  # Config file path
  config_path: "config.json"
  # Number of validation samples per speaker
  val_per_spk: 4
  # Maximum total number of validation samples; the excess is cut off and moved to the training set
  max_val_total: 8
  # Whether to run data cleaning
  clean: true


# bert_gen configuration
# Note: a space is required after ":"
bert_gen:
  # Path to the training dataset's config file
  config_path: "config.json"
  # Number of parallel processes
  num_processes: 2
  # Device: "cuda" for GPU inference, "cpu" for CPU inference
  # This option also sets the default device for get_bert_feature
  device: "cuda"
  # Use multi-GPU inference
  use_multi_device: false

# emo_gen configuration
# Note: a space is required after ":"
emo_gen:
  # Path to the training dataset's config file
  config_path: "config.json"
  # Number of parallel processes
  num_processes: 2
  # Device: "cuda" for GPU inference, "cpu" for CPU inference
  device: "cuda"

# train_ms: training configuration
# Note: a space is required after ":"
train_ms:
  env:
    MASTER_ADDR: "localhost"
    MASTER_PORT: 10086
    WORLD_SIZE: 1
    LOCAL_RANK: 0
    RANK: 0
    # Environment variables with arbitrary names may be added here
    # THE_ENV_VAR_YOU_NEED_TO_USE: "1234567"
  # Base (pretrained) model settings
  base:
    use_base_model: false
    repo_id: "Stardust_minus/Bert-VITS2"
    model_image: "Bert-VITS2_2.1-Emo底模"  # model name shown on the openi website
  # Directory where trained models are stored; unlike older versions, which kept them under logs/model_name, they now all live under Data/<your dataset>/models
  model: "models"
  # Config file path
  config_path: "config.json"
  # Number of training workers; not recommended to exceed the number of CPU cores
  num_workers: 16
  # Disabling this saves close to 50% of disk space, but may slow down training and increase CPU usage.
  spec_cache: True
  # Number of checkpoints to keep; weights beyond this count are deleted to save space.
  keep_ckpts: 8


# webui configuration
# Note: a space is required after ":"
webui:
  # Inference device
  device: "cuda"
  # Model path
  model: "genshin/models/G_8000.pth"
  # Config file path
  config_path: "config.json"
  # Port number
  port: 7860
  # Whether to deploy publicly (expose to the internet)
  share: false
  # Whether to enable debug mode
  debug: false
  # Language identification library; options: langid, fastlid
  language_identification_library: "langid"


# server API configuration
# Note: a space is required after ":"
# Note: all paths in this section are relative to the project root
server:
  # Port number
  port: 5000
  # Default device for models (this option is not implemented yet)
  device: "cuda"
  # Configuration for all models to load; several models may be listed, or none, in which case models can be loaded manually once the web page is up
  # To load no models: delete the two default model entries and assign [ ] (an empty list) to models, the same syntax as the speakers of model 2, i.e. models: [ ]
  # Note: every model must have valid model and config paths; empty paths cause loading errors.
  # You can also leave models empty and fill them in manually after the page has loaded.
  models:
    - # Model path
      model: ""
      # Path to the model's config.json
      config: ""
      # Device for this model; overrides the default if set
      device: "cuda"
      # Default language for this model
      language: "ZH"
      # Default per-speaker parameters
      # Not every speaker needs to be listed; unlisted speakers use the defaults
      # Can be left unfilled for now; per-speaker configuration is not implemented yet
      speakers:
        - speaker: "科比"
          sdp_ratio: 0.2
          noise_scale: 0.6
          noise_scale_w: 0.8
          length_scale: 1
        - speaker: "五条悟"
          sdp_ratio: 0.3
          noise_scale: 0.7
          noise_scale_w: 0.8
          length_scale: 0.5
        - speaker: "安倍晋三"
          sdp_ratio: 0.2
          noise_scale: 0.6
          noise_scale_w: 0.8
          length_scale: 1.2
    - # Model path
      model: ""
      # Path to the model's config.json
      config: ""
      # Device for this model; overrides the default if set
      device: "cpu"
      # Default language for this model
      language: "JP"
      # Default per-speaker parameters
      # Not every speaker needs to be listed; unlisted speakers use the defaults
      speakers: [ ]  # may also be left empty


# Baidu Translate Open Platform API configuration
# API documentation: https://api.fanyi.baidu.com/doc/21
# Do not share your app id and key publicly on GitHub or similar sites
translate:
  # Your APPID
  "app_key": ""
  # Your secret key
  "secret_key": ""
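
default_config.yml is the untouched template of the config.yml shown above; only dataset_path, transcription_path and the webui model path differ. The header comment mentions choosing a config file per run via an environment variable. A hedged sketch of that pattern follows; the copy-on-missing behaviour and the variable name "BERT_VITS2_CONFIG" are assumptions for illustration, not something this commit defines:

# Sketch of the "one config file per run" pattern described in the header comment.
import os
import shutil
import yaml

if not os.path.exists("config.yml"):
    shutil.copy("default_config.yml", "config.yml")  # start from the template (assumed convention)

path = os.environ.get("BERT_VITS2_CONFIG", "config.yml")  # hypothetical variable name
with open(path, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
print("using", path, "with dataset_path =", cfg["dataset_path"])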
emo_gen.py
ADDED
@@ -0,0 +1,155 @@
import argparse
import os
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
)

import utils
from config import config


class RegressionHead(nn.Module):
    r"""Classification head."""

    def __init__(self, config):
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)

        return x


class EmotionModel(Wav2Vec2PreTrainedModel):
    r"""Speech emotion classifier."""

    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = RegressionHead(config)
        self.init_weights()

    def forward(
        self,
        input_values,
    ):
        outputs = self.wav2vec2(input_values)
        hidden_states = outputs[0]
        hidden_states = torch.mean(hidden_states, dim=1)
        logits = self.classifier(hidden_states)

        return hidden_states, logits


class AudioDataset(Dataset):
    def __init__(self, list_of_wav_files, sr, processor):
        self.list_of_wav_files = list_of_wav_files
        self.processor = processor
        self.sr = sr

    def __len__(self):
        return len(self.list_of_wav_files)

    def __getitem__(self, idx):
        wav_file = self.list_of_wav_files[idx]
        audio_data, _ = librosa.load(wav_file, sr=self.sr)
        processed_data = self.processor(audio_data, sampling_rate=self.sr)[
            "input_values"
        ][0]
        return torch.from_numpy(processed_data)


def process_func(
    x: np.ndarray,
    sampling_rate: int,
    model: EmotionModel,
    processor: Wav2Vec2Processor,
    device: str,
    embeddings: bool = False,
) -> np.ndarray:
    r"""Predict emotions or extract embeddings from raw audio signal."""
    model = model.to(device)
    y = processor(x, sampling_rate=sampling_rate)
    y = y["input_values"][0]
    y = torch.from_numpy(y).unsqueeze(0).to(device)

    # run through model
    with torch.no_grad():
        y = model(y)[0 if embeddings else 1]

    # convert to numpy
    y = y.detach().cpu().numpy()

    return y


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, default=config.bert_gen_config.config_path
    )
    parser.add_argument(
        "--num_processes", type=int, default=config.bert_gen_config.num_processes
    )
    args, _ = parser.parse_known_args()
    config_path = args.config
    hps = utils.get_hparams_from_file(config_path)

    device = config.bert_gen_config.device

    model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    if not Path(model_name).joinpath("pytorch_model.bin").exists():
        utils.download_emo_models(config.mirror, REPO_ID, model_name)

    processor = Wav2Vec2Processor.from_pretrained(model_name)
    model = EmotionModel.from_pretrained(model_name).to(device)

    lines = []
    with open(hps.data.training_files, encoding="utf-8") as f:
        lines.extend(f.readlines())

    with open(hps.data.validation_files, encoding="utf-8") as f:
        lines.extend(f.readlines())

    wavnames = [line.split("|")[0] for line in lines]
    dataset = AudioDataset(wavnames, 16000, processor)
    data_loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=min(args.num_processes, os.cpu_count() - 1),
    )

    with torch.no_grad():
        for i, data in tqdm(enumerate(data_loader), total=len(data_loader)):
            wavname = wavnames[i]
            emo_path = wavname.replace(".wav", ".emo.npy")
            if os.path.exists(emo_path):
                continue
            emb = model(data.to(device))[0].detach().cpu().numpy()
            np.save(emo_path, emb)

    print("Emo vec generation finished!")
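
As the __main__ block shows, the script reads the training and validation file lists from the dataset's config.json and writes one .emo.npy next to each .wav, skipping files that already have one; these are exactly the files TextAudioSpeakerLoader loads via audiopath.replace(".wav", ".emo.npy"). An illustrative check that the run produced an embedding for every listed wav (the list-file paths are examples following config.yml above):

# Example invocation:  python emo_gen.py -c Data/tamura/config.json --num_processes 2
import numpy as np

list_files = ["Data/tamura/filelists/train.list", "Data/tamura/filelists/val.list"]  # example paths
for list_file in list_files:
    with open(list_file, encoding="utf-8") as f:
        for line in f:
            wav_path = line.split("|")[0]
            emo = np.load(wav_path.replace(".wav", ".emo.npy"))
            # pooled wav2vec2-large hidden state: 1024 dimensions
            assert emo.squeeze().shape == (1024,), (wav_path, emo.shape)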
export_onnx.py
ADDED
@@ -0,0 +1,56 @@
from models_onnx import SynthesizerTrn
import utils
from text.symbols import symbols
import os
import json


def export_onnx(export_path, model_path, config_path):
    hps = utils.get_hparams_from_file(config_path)
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    _ = net_g.eval()
    _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
    net_g.export_onnx(export_path)

    spklist = []
    for key in hps.data.spk2id.keys():
        spklist.append(key)

    MoeVSConf = {
        "Folder": f"{export_path}",
        "Name": f"{export_path}",
        "Type": "BertVits",
        "Symbol": symbols,
        "Cleaner": "",
        "Rate": hps.data.sampling_rate,
        "CharaMix": True,
        "Characters": spklist,
        "LanguageMap": {"ZH": [0, 0], "JP": [1, 6], "EN": [2, 8]},
        "Dict": "BasicDict",
        "BertPath": [
            "chinese-roberta-wwm-ext-large",
            "deberta-v2-large-japanese",
            "bert-base-japanese-v3",
        ],
    }

    with open(f"onnx/{export_path}.json", "w") as MoeVsConfFile:
        json.dump(MoeVSConf, MoeVsConfFile, indent=4)


if __name__ == "__main__":
    print(symbols)
    export_path = "HimenoSena"
    model_path = "G_53000.pth"
    config_path = "config.json"
    if not os.path.exists("onnx"):
        os.makedirs("onnx")
    if not os.path.exists(f"onnx/{export_path}"):
        os.makedirs(f"onnx/{export_path}")
    export_onnx(export_path, model_path, config_path)
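
The __main__ block hard-codes export_path, model_path and config_path. A hedged sketch of calling export_onnx for a model trained under the Data/<dataset>/models layout from config.yml above; the names and paths below are examples, not part of this commit:

# Illustrative call; adjust the paths to your own dataset.
import os
from export_onnx import export_onnx

export_path = "Tamura"                        # also used as the onnx/<name> folder and JSON name
model_path = "Data/tamura/models/G_2750.pth"  # matches the webui.model entry in config.yml
config_path = "Data/tamura/config.json"

os.makedirs(f"onnx/{export_path}", exist_ok=True)
export_onnx(export_path, model_path, config_path)
# writes onnx/<export_path>.json (the MoeVS config); the ONNX graphs themselves
# are written by net_g.export_onnx, defined in models_onnx (not shown in this commit)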
get_emo.py
ADDED
@@ -0,0 +1,26 @@
from emo_gen import EmotionModel, process_func

import librosa
import numpy as np
import torch
from transformers import Wav2Vec2Processor

from config import config

model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = EmotionModel.from_pretrained(model_name).to(device)


def get_emo(path):
    # sr must be passed as a keyword argument in recent librosa versions
    wav, sr = librosa.load(path, sr=16000)
    device = config.bert_gen_config.device
    return process_func(
        np.expand_dims(wav, 0).astype(np.float64),
        sr,
        model,
        processor,
        device,
        embeddings=True,
    ).squeeze(0)
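
get_emo is the inference-time counterpart of emo_gen.py: it loads a single reference clip at 16 kHz and returns the pooled wav2vec2 hidden state that infer.py uses as the emotion embedding. A short usage sketch (the path is an example):

# Illustrative use of get_emo on a single reference clip.
import numpy as np
from get_emo import get_emo

emb = get_emo("Data/tamura/audios/wavs/reference.wav")  # example path
print(np.asarray(emb).shape)  # expected: a 1024-dim vector after .squeeze(0)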
infer.py
ADDED
@@ -0,0 +1,341 @@
"""
Version management, backward-compatible inference and model loading.
Version notes:
1. Version numbers match the GitHub release numbers: a model carries the version of the release it was trained with.
2. Please declare the version explicitly in the model's config.json by adding a field "version": "your version number".
Special versions:
1.1.1-fix: models trained with 1.1.1, but inferred with the Japanese fix from dev
1.1.1-dev: dev builds
2.1: the current version
"""
import torch
import commons
from text import cleaned_text_to_sequence, get_bert
from get_emo import get_emo
from text.cleaner import clean_text
import utils

from models import SynthesizerTrn
from text.symbols import symbols
from oldVersion.V200.models import SynthesizerTrn as V200SynthesizerTrn
from oldVersion.V200.text import symbols as V200symbols
from oldVersion.V111.models import SynthesizerTrn as V111SynthesizerTrn
from oldVersion.V111.text import symbols as V111symbols
from oldVersion.V110.models import SynthesizerTrn as V110SynthesizerTrn
from oldVersion.V110.text import symbols as V110symbols
from oldVersion.V101.models import SynthesizerTrn as V101SynthesizerTrn
from oldVersion.V101.text import symbols as V101symbols

from oldVersion import V111, V110, V101, V200

# Current version info
latest_version = "2.1"

# Version compatibility maps
SynthesizerTrnMap = {
    "2.0.2-fix": V200SynthesizerTrn,
    "2.0.1": V200SynthesizerTrn,
    "2.0": V200SynthesizerTrn,
    "1.1.1-fix": V111SynthesizerTrn,
    "1.1.1": V111SynthesizerTrn,
    "1.1": V110SynthesizerTrn,
    "1.1.0": V110SynthesizerTrn,
    "1.0.1": V101SynthesizerTrn,
    "1.0": V101SynthesizerTrn,
    "1.0.0": V101SynthesizerTrn,
}

symbolsMap = {
    "2.0.2-fix": V200symbols,
    "2.0.1": V200symbols,
    "2.0": V200symbols,
    "1.1.1-fix": V111symbols,
    "1.1.1": V111symbols,
    "1.1": V110symbols,
    "1.1.0": V110symbols,
    "1.0.1": V101symbols,
    "1.0": V101symbols,
    "1.0.0": V101symbols,
}


def get_net_g(model_path: str, version: str, device: str, hps):
    if version != latest_version:
        net_g = SynthesizerTrnMap[version](
            len(symbolsMap[version]),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model,
        ).to(device)
    else:
        # net_g for a model of the current version
        net_g = SynthesizerTrn(
            len(symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model,
        ).to(device)
    _ = net_g.eval()
    _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
    return net_g


def get_text(text, language_str, hps, device):
    # get_text for the current version is implemented here
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)

    if hps.data.add_blank:
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        for i in range(len(word2ph)):
            word2ph[i] = word2ph[i] * 2
        word2ph[0] += 1
    bert_ori = get_bert(norm_text, word2ph, language_str, device)
    del word2ph
    assert bert_ori.shape[-1] == len(phone), phone

    if language_str == "ZH":
        bert = bert_ori
        ja_bert = torch.zeros(1024, len(phone))
        en_bert = torch.zeros(1024, len(phone))
    elif language_str == "JP":
        bert = torch.zeros(1024, len(phone))
        ja_bert = bert_ori
        en_bert = torch.zeros(1024, len(phone))
    elif language_str == "EN":
        bert = torch.zeros(1024, len(phone))
        ja_bert = torch.zeros(1024, len(phone))
        en_bert = bert_ori
    else:
        raise ValueError("language_str should be ZH, JP or EN")

    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, ja_bert, en_bert, phone, tone, language


def get_emo_(reference_audio, emotion):
    emo = (
        torch.from_numpy(get_emo(reference_audio))
        if reference_audio
        else torch.Tensor([emotion])
    )
    return emo


def infer(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    reference_audio=None,
    emotion=None,
    skip_start=False,
    skip_end=False,
):
    # Versions supporting Chinese, Japanese and English
    inferMap_V2 = {
        "2.0.2-fix": V200.infer,
        "2.0.1": V200.infer,
        "2.0": V200.infer,
        "1.1.1-fix": V111.infer_fix,
        "1.1.1": V111.infer,
        "1.1": V110.infer,
        "1.1.0": V110.infer,
    }
    # Chinese-only versions
    # In testing, no case was found where models of the two versions could not be used interchangeably
    inferMap_V1 = {
        "1.0.1": V101.infer,
        "1.0": V101.infer,
        "1.0.0": V101.infer,
    }
    version = hps.version if hasattr(hps, "version") else latest_version
    # Not the current version: choose the matching infer by version number
    if version != latest_version:
        if version in inferMap_V2.keys():
            return inferMap_V2[version](
                text,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                sid,
                language,
                hps,
                net_g,
                device,
            )
        if version in inferMap_V1.keys():
            return inferMap_V1[version](
                text,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                sid,
                hps,
                net_g,
                device,
            )
    # Inference for the current version is implemented here
    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
        text, language, hps, device
    )
    emo = get_emo_(reference_audio, emotion)
    if skip_start:
        phones = phones[1:]
        tones = tones[1:]
        lang_ids = lang_ids[1:]
        bert = bert[:, 1:]
        ja_bert = ja_bert[:, 1:]
        en_bert = en_bert[:, 1:]
    if skip_end:
        phones = phones[:-1]
        tones = tones[:-1]
        lang_ids = lang_ids[:-1]
        bert = bert[:, :-1]
        ja_bert = ja_bert[:, :-1]
        en_bert = en_bert[:, :-1]
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        ja_bert = ja_bert.to(device).unsqueeze(0)
        en_bert = en_bert.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        emo = emo.to(device).unsqueeze(0)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                en_bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio


def infer_multilang(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    language,
    hps,
    net_g,
    device,
    reference_audio=None,
    emotion=None,
    skip_start=False,
    skip_end=False,
):
    bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
    emo = get_emo_(reference_audio, emotion)
    for idx, (txt, lang) in enumerate(zip(text, language)):
        skip_start = (idx != 0) or (skip_start and idx == 0)
        skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
        (
            temp_bert,
            temp_ja_bert,
            temp_en_bert,
            temp_phones,
            temp_tones,
            temp_lang_ids,
        ) = get_text(txt, lang, hps, device)
        if skip_start:
            temp_bert = temp_bert[:, 1:]
            temp_ja_bert = temp_ja_bert[:, 1:]
            temp_en_bert = temp_en_bert[:, 1:]
            temp_phones = temp_phones[1:]
            temp_tones = temp_tones[1:]
            temp_lang_ids = temp_lang_ids[1:]
        if skip_end:
            temp_bert = temp_bert[:, :-1]
            temp_ja_bert = temp_ja_bert[:, :-1]
            temp_en_bert = temp_en_bert[:, :-1]
            temp_phones = temp_phones[:-1]
            temp_tones = temp_tones[:-1]
            temp_lang_ids = temp_lang_ids[:-1]
        bert.append(temp_bert)
        ja_bert.append(temp_ja_bert)
        en_bert.append(temp_en_bert)
        phones.append(temp_phones)
        tones.append(temp_tones)
        lang_ids.append(temp_lang_ids)
    bert = torch.concatenate(bert, dim=1)
    ja_bert = torch.concatenate(ja_bert, dim=1)
    en_bert = torch.concatenate(en_bert, dim=1)
    phones = torch.concatenate(phones, dim=0)
    tones = torch.concatenate(tones, dim=0)
    lang_ids = torch.concatenate(lang_ids, dim=0)
    with torch.no_grad():
        x_tst = phones.to(device).unsqueeze(0)
        tones = tones.to(device).unsqueeze(0)
        lang_ids = lang_ids.to(device).unsqueeze(0)
        bert = bert.to(device).unsqueeze(0)
        ja_bert = ja_bert.to(device).unsqueeze(0)
        en_bert = en_bert.to(device).unsqueeze(0)
        emo = emo.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
        audio = (
            net_g.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones,
                lang_ids,
                bert,
                ja_bert,
                en_bert,
                emo,
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
        del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return audio
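
Putting the pieces together, a hedged end-to-end sketch of using get_net_g and infer from this file: hps comes from the dataset's config.json, the checkpoint path matches the webui.model entry in config.yml above, and the speaker name, text and reference-audio path are examples only. The sdp_ratio/noise values mirror the speaker defaults listed in config.yml.

# End-to-end sketch; paths, speaker name and text are placeholders, not part of this commit.
import utils
from infer import get_net_g, infer, latest_version

device = "cuda"
hps = utils.get_hparams_from_file("Data/tamura/config.json")
version = getattr(hps, "version", latest_version)
net_g = get_net_g("Data/tamura/models/G_2750.pth", version, device, hps)

audio = infer(
    "こんにちは、テストです。",
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="tamura",        # must be a key of hps.data.spk2id
    language="JP",
    hps=hps,
    net_g=net_g,
    device=device,
    reference_audio="Data/tamura/audios/wavs/reference.wav",  # or pass emotion=... without a reference
)
# audio is a float numpy array at hps.data.sampling_rate; write it out with e.g. soundfile.write(...)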