update

- Dockerfile +7 -9
- Dockerfile_GPU +5 -7
- README_zh.md +2 -2
- api_test.py +2 -2
- app.py +50 -50
- bert_vits2/bert_vits2.py +31 -6
- bert_vits2/models.py +23 -21
- bert_vits2/text/__init__.py +1 -9
- bert_vits2/text/bert_handler.py +33 -0
- bert_vits2/text/cleaner.py +19 -5
- bert_vits2/utils.py +20 -4
- contants.py +7 -0
- docker-compose-gpu.yaml +4 -2
- docker-compose.yaml +2 -1
- logger.py +1 -3
- utils/classify_language.py +14 -3
- utils/download.py +5 -1
- utils/lang_dict.py +10 -1
- utils/load_model.py +49 -36
- vits-simple-api-installer-latest.sh +19 -4
- vits/text/japanese.py +18 -2
- vits/vits.py +28 -17
- voice.py +65 -42
Dockerfile
CHANGED
@@ -1,12 +1,14 @@
-FROM
+FROM artrajz/pytorch:1.13.1-cpu-py3.10.11-ubuntu22.04
 
 RUN mkdir -p /app
 WORKDIR /app
 
 ENV DEBIAN_FRONTEND=noninteractive
 
+
 RUN apt-get update && \
-    apt-get install -yq build-essential espeak-ng cmake wget && \
+    apt-get install -yq build-essential espeak-ng cmake wget ca-certificates tzdata&& \
+    update-ca-certificates && \
     apt-get clean && \
     apt-get purge -y --auto-remove -o APT::AutoRemove::RecommendsImportant=false && \
     rm -rf /var/lib/apt/lists/*
@@ -16,20 +18,16 @@ RUN wget https://github.com/jemalloc/jemalloc/releases/download/5.3.0/jemalloc-5
     tar -xvf jemalloc-5.3.0.tar.bz2 && \
     cd jemalloc-5.3.0 && \
     ./configure && \
-    make && \
+    make -j$(nproc) && \
     make install && \
     cd .. && \
     rm -rf jemalloc-5.3.0* && \
-    ldconfig
+    ldconfig
 
 ENV LD_PRELOAD=/usr/local/lib/libjemalloc.so
 
-RUN pip install torch --index-url https://download.pytorch.org/whl/cpu --no-cache-dir
-
 COPY requirements.txt /app/
-RUN pip install --
-    pip install pyopenjtalk==0.3.2 -i https://pypi.artrajz.cn/simple --no-cache-dir && \
-    pip install gunicorn --no-cache-dir && \
+RUN pip install gunicorn --no-cache-dir && \
     pip install -r requirements.txt --no-cache-dir&& \
     rm -rf /root/.cache/pip/*
Dockerfile_GPU
CHANGED
@@ -1,4 +1,4 @@
-FROM
+FROM artrajz/pytorch:1.13.1-cu117-py3.10.11-ubuntu22.04
 
 RUN mkdir -p /app
 WORKDIR /app
@@ -6,7 +6,7 @@ WORKDIR /app
 ENV DEBIAN_FRONTEND=noninteractive
 
 RUN apt-get update && \
-    apt-get install -yq build-essential espeak-ng cmake wget ca-certificates && \
+    apt-get install -yq build-essential espeak-ng cmake wget ca-certificates tzdata&& \
     update-ca-certificates && \
     apt-get clean && \
     apt-get purge -y --auto-remove -o APT::AutoRemove::RecommendsImportant=false && \
@@ -18,18 +18,16 @@ RUN wget https://github.com/jemalloc/jemalloc/releases/download/5.3.0/jemalloc-5
     tar -xvf jemalloc-5.3.0.tar.bz2 && \
     cd jemalloc-5.3.0 && \
     ./configure && \
-    make && \
+    make -j$(nproc) && \
     make install && \
     cd .. && \
     rm -rf jemalloc-5.3.0* && \
-    ldconfig
+    ldconfig
 
 ENV LD_PRELOAD=/usr/local/lib/libjemalloc.so
 
 COPY requirements.txt /app/
-RUN pip install --
-    pip install pyopenjtalk==0.3.2 fasttext -i https://pypi.artrajz.cn/simple --no-cache-dir && \
-    pip install gunicorn --no-cache-dir && \
+RUN pip install gunicorn --no-cache-dir && \
     pip install -r requirements.txt --no-cache-dir&& \
     rm -rf /root/.cache/pip/*
README_zh.md
CHANGED
@@ -21,7 +21,7 @@
 
 - [x] VITS text-to-speech and voice conversion
 - [x] HuBert-soft VITS models
-- [x] W2V2 VITS / emotional-vits dimensional emotion models
+- [x] W2V2 VITS / [emotional-vits](https://github.com/innnky/emotional-vits) dimensional emotion models
 - [x] [vits_chinese](https://github.com/PlayVoice/vits_chinese)
 - [x] [Bert-VITS2](https://github.com/Stardust-minus/Bert-VITS2)
 - [x] Loading multiple models
@@ -341,7 +341,7 @@ pip install pyopenjtalk -i https://pypi.artrajz.cn/simple
 | Text to synthesize | text | true | | str | The text to be turned into speech. |
 | Speaker ID | id | false | From `config.py` | int | The speaker ID. |
 | Audio format | format | false | From `config.py` | str | Supports wav, ogg, silk, mp3, flac |
-| Text language | lang | false | From `config.py` | str |
+| Text language | lang | false | From `config.py` | str | auto is the automatic language detection mode and the default; it currently only detects the language of the whole text and cannot split it per sentence. The other available languages are zh and ja. |
 | Voice length / speed | length | false | From `config.py` | float | Adjusts the voice length, i.e. the speaking speed; the larger the value, the slower the speech. |
 | Noise | noise | false | From `config.py` | float | Sample noise, controls the randomness of the synthesis. |
 | SDP noise | noisew | false | From `config.py` | float | Stochastic duration predictor noise, controls phoneme duration. |
api_test.py
CHANGED
@@ -396,9 +396,9 @@ if __name__ == '__main__':
     ssml = """
     <speak lang="zh" format="mp3" length="1.2">
         <voice id="92" >这几天心里颇不宁静。</voice>
-        <voice id="
+        <voice id="0" model_type="Bert-VITS2">今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。</voice>
         <voice id="142">月亮渐渐地升高了,墙外马路上孩子们的欢笑,已经听不见了;</voice>
-        <voice id="
+        <voice id="0" model_type="Bert-VITS2">妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。</voice>
         <voice id="120">我悄悄地披了大衫,带上门出去。</voice><break time="2s"/>
         <voice id="121">沿着荷塘,是一条曲折的小煤屑路。</voice>
         <voice id="122">这是一条幽僻的路;白天也少人走,夜晚更加寂寞。</voice>
app.py
CHANGED
@@ -1,6 +1,8 @@
 import os
 import time
 import uuid
+
+from contants import ModelType
 from logger import logger
 from flask import Flask, request, send_file, jsonify, make_response, render_template
 from werkzeug.utils import secure_filename
@@ -85,28 +87,29 @@ def voice_vits_api():
         max = int(request_data.get("max", app.config.get("MAX", 50)))
         use_streaming = request_data.get('streaming', False, type=bool)
     except Exception as e:
-        logger.error(f"[VITS] {e}")
+        logger.error(f"[{ModelType.VITS.value}] {e}")
         return make_response("parameter error", 400)
 
-    logger.info(
+    logger.info(
+        f"[{ModelType.VITS.value}] id:{id} format:{format} lang:{lang} length:{length} noise:{noise} noisew:{noisew}")
+    logger.info(f"[{ModelType.VITS.value}] len:{len(text)} text:{text}")
 
     if check_is_none(text):
-        logger.info(f"[VITS] text is empty")
+        logger.info(f"[{ModelType.VITS.value}] text is empty")
         return make_response(jsonify({"status": "error", "message": "text is empty"}), 400)
 
     if check_is_none(id):
-        logger.info(f"[VITS] speaker id is empty")
+        logger.info(f"[{ModelType.VITS.value}] speaker id is empty")
         return make_response(jsonify({"status": "error", "message": "speaker id is empty"}), 400)
 
     if id < 0 or id >= tts.vits_speakers_count:
-        logger.info(f"[VITS] speaker id {id} does not exist")
+        logger.info(f"[{ModelType.VITS.value}] speaker id {id} does not exist")
         return make_response(jsonify({"status": "error", "message": f"id {id} does not exist"}), 400)
 
     # Verify that the model supports the requested language
-    speaker_lang = tts.voice_speakers[
+    speaker_lang = tts.voice_speakers[ModelType.VITS.value][id].get('lang')
     if lang not in ["auto", "mix"] and len(speaker_lang) != 1 and lang not in speaker_lang:
-        logger.info(f"[VITS] lang \"{lang}\" is not in {speaker_lang}")
+        logger.info(f"[{ModelType.VITS.value}] lang \"{lang}\" is not in {speaker_lang}")
        return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)
 
     # If LANGUAGE_AUTOMATIC_DETECT is set in the config file, force speaker_lang to LANGUAGE_AUTOMATIC_DETECT
@@ -139,10 +142,10 @@ def voice_vits_api():
     t1 = time.time()
     audio = tts.vits_infer(task)
     t2 = time.time()
-    logger.info(f"[VITS] finish in {(t2 - t1):.2f}s")
+    logger.info(f"[{ModelType.VITS.value}] finish in {(t2 - t1):.2f}s")
 
     if app.config.get("SAVE_AUDIO", False):
-        logger.debug(f"[VITS] {fname}")
+        logger.debug(f"[{ModelType.VITS.value}] {fname}")
         path = os.path.join(app.config.get('CACHE_PATH'), fname)
         save_audio(audio.getvalue(), path)
 
@@ -162,20 +165,21 @@ def voice_hubert_api():
         noisew = float(request.form.get("noisew", app.config.get("NOISEW", 0.8)))
         use_streaming = request.form.get('streaming', False, type=bool)
     except Exception as e:
-        logger.error(f"[
+        logger.error(f"[{ModelType.HUBERT_VITS.value}] {e}")
         return make_response("parameter error", 400)
 
-    logger.info(
+    logger.info(
+        f"[{ModelType.HUBERT_VITS.value}] id:{id} format:{format} length:{length} noise:{noise} noisew:{noisew}")
 
     fname = secure_filename(str(uuid.uuid1()) + "." + voice.filename.split(".")[1])
     voice.save(os.path.join(app.config['UPLOAD_FOLDER'], fname))
 
     if check_is_none(id):
-        logger.info(f"[
+        logger.info(f"[{ModelType.HUBERT_VITS.value}] speaker id is empty")
         return make_response(jsonify({"status": "error", "message": "speaker id is empty"}), 400)
 
     if id < 0 or id >= tts.hubert_speakers_count:
-        logger.info(f"[
+        logger.info(f"[{ModelType.HUBERT_VITS.value}] speaker id {id} does not exist")
         return make_response(jsonify({"status": "error", "message": f"id {id} does not exist"}), 400)
 
     file_type = f"audio/{format}"
@@ -189,10 +193,10 @@ def voice_hubert_api():
     t1 = time.time()
     audio = tts.hubert_vits_infer(task)
     t2 = time.time()
-    logger.info(f"[
+    logger.info(f"[{ModelType.HUBERT_VITS.value}] finish in {(t2 - t1):.2f}s")
 
     if app.config.get("SAVE_AUDIO", False):
-        logger.debug(f"[
+        logger.debug(f"[{ModelType.HUBERT_VITS.value}] {fname}")
         path = os.path.join(app.config.get('CACHE_PATH'), fname)
         save_audio(audio.getvalue(), path)
 
@@ -230,29 +234,29 @@ def voice_w2v2_api():
         emotion = int(request_data.get("emotion", app.config.get("EMOTION", 0)))
         use_streaming = request_data.get('streaming', False, type=bool)
     except Exception as e:
-        logger.error(f"[
+        logger.error(f"[{ModelType.W2V2_VITS.value}] {e}")
         return make_response(f"parameter error", 400)
 
-    logger.info(f"[
+    logger.info(f"[{ModelType.W2V2_VITS.value}] id:{id} format:{format} lang:{lang} "
                 f"length:{length} noise:{noise} noisew:{noisew} emotion:{emotion}")
-    logger.info(f"[
+    logger.info(f"[{ModelType.W2V2_VITS.value}] len:{len(text)} text:{text}")
 
     if check_is_none(text):
-        logger.info(f"[
+        logger.info(f"[{ModelType.W2V2_VITS.value}] text is empty")
         return make_response(jsonify({"status": "error", "message": "text is empty"}), 400)
 
     if check_is_none(id):
-        logger.info(f"[
+        logger.info(f"[{ModelType.W2V2_VITS.value}] speaker id is empty")
         return make_response(jsonify({"status": "error", "message": "speaker id is empty"}), 400)
 
     if id < 0 or id >= tts.w2v2_speakers_count:
-        logger.info(f"[
+        logger.info(f"[{ModelType.W2V2_VITS.value}] speaker id {id} does not exist")
         return make_response(jsonify({"status": "error", "message": f"id {id} does not exist"}), 400)
 
     # Verify that the model supports the requested language
-    speaker_lang = tts.voice_speakers[
+    speaker_lang = tts.voice_speakers[ModelType.W2V2_VITS.value][id].get('lang')
     if lang not in ["auto", "mix"] and len(speaker_lang) != 1 and lang not in speaker_lang:
-        logger.info(f"[
+        logger.info(f"[{ModelType.W2V2_VITS.value}] lang \"{lang}\" is not in {speaker_lang}")
         return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)
 
     # If LANGUAGE_AUTOMATIC_DETECT is set in the config file, force speaker_lang to LANGUAGE_AUTOMATIC_DETECT
@@ -279,10 +283,10 @@ def voice_w2v2_api():
     t1 = time.time()
     audio = tts.w2v2_vits_infer(task)
     t2 = time.time()
-    logger.info(f"[
+    logger.info(f"[{ModelType.W2V2_VITS.value}] finish in {(t2 - t1):.2f}s")
 
     if app.config.get("SAVE_AUDIO", False):
-        logger.debug(f"[
+        logger.debug(f"[{ModelType.W2V2_VITS.value}] {fname}")
         path = os.path.join(app.config.get('CACHE_PATH'), fname)
         save_audio(audio.getvalue(), path)
 
@@ -362,7 +366,7 @@ def ssml_api():
     file_type = f"audio/{format}"
 
     t1 = time.time()
-    audio = tts.
+    audio = tts.process_ssml_infer_task(voice_tasks, format)
     t2 = time.time()
     logger.info(f"[ssml] finish in {(t2 - t1):.2f}s")
 
@@ -423,29 +427,29 @@ def voice_bert_vits2_api():
         sdp_ratio = float(request_data.get("sdp_ratio", app.config.get("SDP_RATIO", 0.2)))
         max = int(request_data.get("max", app.config.get("MAX", 50)))
     except Exception as e:
-        logger.error(f"[
+        logger.error(f"[{ModelType.BERT_VITS2.value}] {e}")
         return make_response("parameter error", 400)
 
     logger.info(
-        f"[
+        f"[{ModelType.BERT_VITS2.value}] id:{id} format:{format} lang:{lang} length:{length} noise:{noise} noisew:{noisew} sdp_ratio:{sdp_ratio}")
-    logger.info(f"[
+    logger.info(f"[{ModelType.BERT_VITS2.value}] len:{len(text)} text:{text}")
 
     if check_is_none(text):
-        logger.info(f"[
+        logger.info(f"[{ModelType.BERT_VITS2.value}] text is empty")
         return make_response(jsonify({"status": "error", "message": "text is empty"}), 400)
 
     if check_is_none(id):
-        logger.info(f"[
+        logger.info(f"[{ModelType.BERT_VITS2.value}] speaker id is empty")
         return make_response(jsonify({"status": "error", "message": "speaker id is empty"}), 400)
 
     if id < 0 or id >= tts.bert_vits2_speakers_count:
-        logger.info(f"[
+        logger.info(f"[{ModelType.BERT_VITS2.value}] speaker id {id} does not exist")
         return make_response(jsonify({"status": "error", "message": f"id {id} does not exist"}), 400)
 
     # Verify that the model supports the requested language
-    speaker_lang = tts.voice_speakers[
+    speaker_lang = tts.voice_speakers[ModelType.BERT_VITS2.value][id].get('lang')
     if lang not in ["auto", "mix"] and len(speaker_lang) != 1 and lang not in speaker_lang:
-        logger.info(f"[
+        logger.info(f"[{ModelType.BERT_VITS2.value}] lang \"{lang}\" is not in {speaker_lang}")
         return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)
 
     # If LANGUAGE_AUTOMATIC_DETECT is set in the config file, force speaker_lang to LANGUAGE_AUTOMATIC_DETECT
@@ -468,10 +472,10 @@ def voice_bert_vits2_api():
     t1 = time.time()
     audio = tts.bert_vits2_infer(task)
     t2 = time.time()
-    logger.info(f"[
+    logger.info(f"[{ModelType.BERT_VITS2.value}] finish in {(t2 - t1):.2f}s")
 
     if app.config.get("SAVE_AUDIO", False):
-        logger.debug(f"[
+        logger.debug(f"[{ModelType.BERT_VITS2.value}] {fname}")
         path = os.path.join(app.config.get('CACHE_PATH'), fname)
         save_audio(audio.getvalue(), path)
 
@@ -490,18 +494,18 @@ def check():
         else:
             request_data = request.form
 
+        model_type_str = request_data.get("model_type", request_data.get("model")).upper()
         id = int(request_data.get("id"))
     except Exception as e:
         logger.info(f"[check] {e}")
         return make_response(jsonify({"status": "error", "message": "parameter error"}), 400)
 
-    if check_is_none(
-        logger.info(f"[check] model {
+    if check_is_none(model_type_str):
+        logger.info(f"[check] model {model_type_str} is empty")
         return make_response(jsonify({"status": "error", "message": "model is empty"}), 400)
 
-    if
-        res = make_response(jsonify({"status": "error", "message": f"model {
+    if model_type_str not in ModelType._value2member_map_:
+        res = make_response(jsonify({"status": "error", "message": f"model {model_type_str} does not exist"}))
         res.status = 404
         logger.info(f"[check] speaker id {id} error")
         return res
@@ -510,16 +514,12 @@ def check():
         logger.info(f"[check] speaker id is empty")
         return make_response(jsonify({"status": "error", "message": "speaker id is empty"}), 400)
 
-    elif model.upper() == "HUBERT":
-        speaker_list = tts.voice_speakers["HUBERT-VITS"]
-    elif model.upper() == "W2V2":
-        speaker_list = tts.voice_speakers["W2V2-VITS"]
+    model_type = ModelType(model_type_str)
+    speaker_list = tts.voice_speakers[model_type.value]
 
     if len(speaker_list) == 0:
-        logger.info(f"[check] {
-        return make_response(jsonify({"status": "error", "message": f"{
+        logger.info(f"[check] {model_type_str} not loaded")
+        return make_response(jsonify({"status": "error", "message": f"{model_type_str} not loaded"}), 400)
 
     if id < 0 or id >= len(speaker_list):
         logger.info(f"[check] speaker id {id} does not exist")
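A small client script is enough to exercise the reworked endpoints. The parameter names (text, id, lang, format, sdp_ratio, model_type) come from the hunks above; the base URL and both route paths are assumptions for illustration, since the @app.route decorators are outside the changed hunks. A minimal sketch:

# Sketch: calling the Bert-VITS2 and check endpoints shown above.
# NOTE: BASE_URL and the two route paths are assumptions, not taken from this diff.
import requests

BASE_URL = "http://127.0.0.1:23456"  # assumed host/port (matches the docker-compose port mapping)

# Bert-VITS2 synthesis; parameters mirror voice_bert_vits2_api()
resp = requests.get(f"{BASE_URL}/voice/bert-vits2",  # assumed path
                    params={"text": "今晚的月色真美。", "id": 0, "lang": "auto",
                            "format": "wav", "sdp_ratio": 0.2})
with open("out.wav", "wb") as f:
    f.write(resp.content)

# Speaker check; "model_type" falls back to the legacy "model" field server-side
resp = requests.post(f"{BASE_URL}/voice/check",      # assumed path
                     data={"model_type": "BERT-VITS2", "id": 0})
print(resp.json())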
bert_vits2/bert_vits2.py
CHANGED
@@ -6,19 +6,39 @@ from bert_vits2 import utils as bert_vits2_utils
 from bert_vits2.models import SynthesizerTrn
 from bert_vits2.text import *
 from bert_vits2.text.cleaner import clean_text
+from bert_vits2.utils import process_legacy_versions
+from contants import ModelType
 from utils import classify_language, get_hparams_from_file, lang_dict
 from utils.sentence import sentence_split_and_markup, cut
 
 
 class Bert_VITS2:
     def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
-        self.hps_ms = get_hparams_from_file(config)
+        self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
         self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
         self.speakers = [item[0] for item in
                          sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])]
+        self.symbols = symbols
 
-        self.legacy = getattr(self.hps_ms.data, 'legacy', False)
-        self.symbols = symbols_legacy if self.legacy else symbols
+        # Compatible with legacy versions
+        self.version = process_legacy_versions(self.hps_ms)
+
+        if self.version in ["1.0", "1.0.0", "1.0.1"]:
+            self.symbols = symbols_legacy
+            self.hps_ms.model.n_layers_trans_flow = 3
+
+        elif self.version in ["1.1.0-transition"]:
+            self.hps_ms.model.n_layers_trans_flow = 3
+
+        elif self.version in ["1.1", "1.1.0", "1.1.1"]:
+            self.hps_ms.model.n_layers_trans_flow = 6
+
+        key = f"{ModelType.BERT_VITS2.value}_v{self.version}" if self.version else ModelType.BERT_VITS2.value
+        self.lang = lang_dict.get(key, ["unknown"])
+        self.bert_handler = BertHandler(self.lang)
+
         self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
 
         self.net_g = SynthesizerTrn(
@@ -33,11 +53,15 @@ class Bert_VITS2:
         self.load_model(model)
 
     def load_model(self, model):
-        bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True)
+        bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, version=self.version)
 
     def get_speakers(self):
         return self.speakers
 
+    @property
+    def sampling_rate(self):
+        return self.hps_ms.data.sampling_rate
+
     def get_text(self, text, language_str, hps):
         norm_text, phone, tone, word2ph = clean_text(text, language_str)
         phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, self._symbol_to_id)
@@ -49,7 +73,7 @@ class Bert_VITS2:
         for i in range(len(word2ph)):
             word2ph[i] = word2ph[i] * 2
         word2ph[0] += 1
-        bert = get_bert(norm_text, word2ph, language_str)
+        bert = self.bert_handler.get_bert(norm_text, word2ph, language_str)
         del word2ph
         assert bert.shape[-1] == len(phone), phone
 
@@ -98,7 +122,8 @@ class Bert_VITS2:
         max = voice.get("max", 50)
         # sentence_list = sentence_split_and_markup(text, max, "ZH", ["zh"])
         if lang == "auto":
-            lang = classify_language(text, target_languages=
+            lang = classify_language(text, target_languages=self.lang)
+
         sentence_list = cut(text, max)
         audios = []
         for sentence in sentence_list:
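The version branches in __init__ amount to a small compatibility table. A minimal sketch of that mapping (legacy symbol set and n_layers_trans_flow per version string), with the values taken from the branches above; it can help when filling in the version field of an older checkpoint's config.json:

# Sketch: the legacy-version compatibility table applied in Bert_VITS2.__init__ above.
LEGACY_SETTINGS = {
    # version:           (use legacy symbols, n_layers_trans_flow)
    "1.0":               (True, 3),
    "1.0.0":             (True, 3),
    "1.0.1":             (True, 3),
    "1.1.0-transition":  (False, 3),
    "1.1":               (False, 6),
    "1.1.0":             (False, 6),
    "1.1.1":             (False, 6),
}

def legacy_settings(version):
    # None (a current model) keeps the new symbols and the default n_layers_trans_flow=6
    return LEGACY_SETTINGS.get(version, (False, 6))

print(legacy_settings("1.0.1"))   # (True, 3)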
bert_vits2/models.py
CHANGED
@@ -26,9 +26,11 @@ class DurationDiscriminator(nn.Module):  # vits2
 
         self.drop = nn.Dropout(p_dropout)
         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-        self.conv_2 = nn.Conv1d(
+        self.norm_1 = modules.LayerNorm(filter_channels)
+        self.conv_2 = nn.Conv1d(
+            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.norm_2 = modules.LayerNorm(filter_channels)
         self.dur_proj = nn.Conv1d(1, filter_channels, 1)
 
         self.pre_out_conv_1 = nn.Conv1d(2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
@@ -36,8 +38,8 @@ class DurationDiscriminator(nn.Module):  # vits2
         self.pre_out_conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
         self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
 
+        if gin_channels != 0:
+            self.cond = nn.Conv1d(gin_channels, in_channels, 1)
 
         self.output_layer = nn.Sequential(
             nn.Linear(filter_channels, 1),
@@ -48,13 +50,13 @@ class DurationDiscriminator(nn.Module):  # vits2
         dur = self.dur_proj(dur)
         x = torch.cat([x, dur], dim=1)
         x = self.pre_out_conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.pre_out_norm_1(x)
+        x = self.drop(x)
         x = self.pre_out_conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.pre_out_norm_2(x)
+        x = self.drop(x)
         x = x * x_mask
         x = x.transpose(1, 2)
         output_prob = self.output_layer(x)
@@ -62,17 +64,17 @@ class DurationDiscriminator(nn.Module):  # vits2
 
     def forward(self, x, x_mask, dur_r, dur_hat, g=None):
         x = torch.detach(x)
+        if g is not None:
+            g = torch.detach(g)
+            x = x + self.cond(g)
         x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.drop(x)
         x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.drop(x)
 
         output_probs = []
         for dur in [dur_r, dur_hat]:
@@ -590,7 +592,7 @@ class SynthesizerTrn(nn.Module):
                  gin_channels=256,
                  use_sdp=True,
                  n_flow_layer=4,
-                 n_layers_trans_flow=
+                 n_layers_trans_flow=6,
                  flow_share_parameter=False,
                  use_transformer_flow=True,
                  **kwargs):
bert_vits2/text/__init__.py
CHANGED
@@ -1,7 +1,5 @@
 from bert_vits2.text.symbols import *
-from .chinese_bert import get_bert_feature as zh_bert
-from .english_bert_mock import get_bert_feature as en_bert
-from .japanese_bert import get_bert_feature as ja_bert
+from bert_vits2.text.bert_handler import BertHandler
 
 
 def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id):
@@ -17,9 +15,3 @@ def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id):
     lang_id = language_id_map[language]
     lang_ids = [lang_id for i in phones]
     return phones, tones, lang_ids
-
-
-def get_bert(norm_text, word2ph, language):
-    lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert}
-    bert = lang_bert_func_map[language](norm_text, word2ph)
-    return bert
bert_vits2/text/bert_handler.py
ADDED
@@ -0,0 +1,33 @@
import importlib


class BertHandler:
    _bert_functions = {}

    BERT_IMPORT_MAP = {
        "zh": "bert_vits2.text.chinese_bert.get_bert_feature",
        "en": "bert_vits2.text.english_bert_mock.get_bert_feature",
        "ja": "bert_vits2.text.japanese_bert.get_bert_feature",
    }

    def __init__(self, languages):
        for lang in languages:
            if lang not in BertHandler._bert_functions:
                self.load_bert_function(lang)

    def load_bert_function(self, language):
        if language not in BertHandler.BERT_IMPORT_MAP:
            raise ValueError(f"Unsupported language: {language}")

        module_path, function_name = BertHandler.BERT_IMPORT_MAP[language].rsplit('.', 1)
        module = importlib.import_module(module_path, package=__package__)
        bert_function = getattr(module, function_name)

        BertHandler._bert_functions[language] = bert_function

    def get_bert(self, norm_text, word2ph, language):
        if language not in BertHandler._bert_functions:
            raise ValueError(f"BERT for {language} has not been initialized. Please initialize first.")

        bert_func = BertHandler._bert_functions[language]
        return bert_func(norm_text, word2ph)
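BertHandler is a lazy, cached importlib lookup keyed by language, so a BERT model is only imported the first time its language is requested. The same pattern, sketched with standard-library targets so it runs without the project's BERT modules (the mapping below is illustrative, not the project's):

# Sketch of the lazy-import pattern used by BertHandler, with stdlib targets.
import importlib

IMPORT_MAP = {          # illustrative map; BertHandler maps language codes to BERT feature functions
    "json": "json.dumps",
    "b64":  "base64.b64encode",
}
_cache = {}

def get_function(key):
    if key not in _cache:
        module_path, func_name = IMPORT_MAP[key].rsplit(".", 1)
        _cache[key] = getattr(importlib.import_module(module_path), func_name)
    return _cache[key]

print(get_function("json")({"ok": True}))   # the json module is imported only on first use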
bert_vits2/text/cleaner.py
CHANGED
@@ -1,20 +1,34 @@
+import importlib
+from bert_vits2.text import cleaned_text_to_sequence
 
 language_module_map = {
-    'zh': chinese,
-    'ja': japanese
+    'zh': "bert_vits2.text.chinese",
+    'ja': "bert_vits2.text.japanese"
 }
 
+_loaded_modules = {}
+
+
+def get_language_module(language):
+    if language not in _loaded_modules:
+        module_path = language_module_map.get(language)
+        if not module_path:
+            raise ValueError(f"Unsupported language: {language}")
+
+        _loaded_modules[language] = importlib.import_module(module_path)
+
+    return _loaded_modules[language]
+
 
 def clean_text(text, language):
-    language_module =
+    language_module = get_language_module(language)
     norm_text = language_module.text_normalize(text)
     phones, tones, word2ph = language_module.g2p(norm_text)
     return norm_text, phones, tones, word2ph
 
 
 def clean_text_bert(text, language):
-    language_module =
+    language_module = get_language_module(language)
     norm_text = language_module.text_normalize(text)
     phones, tones, word2ph = language_module.g2p(norm_text)
     bert = language_module.get_bert_feature(norm_text, word2ph)
bert_vits2/utils.py
CHANGED
@@ -9,7 +9,7 @@ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 logger = logging
 
 
-def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
+def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, version=None):
     assert os.path.isfile(checkpoint_path)
     checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
     iteration = checkpoint_dict['iteration']
@@ -36,12 +36,19 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
             new_state_dict[k] = saved_state_dict[k]
             assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
         except:
-            #
+            # Handle legacy model versions and provide appropriate warnings
             if "ja_bert_proj" in k:
                 v = torch.zeros_like(v)
-                logger.warning(
-                    f"If you
+                if version is None:
+                    logger.error(f"{k} is not in the checkpoint")
+                    logger.warning(
+                        f"If you're using an older version of the model, consider adding the \"version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.0.1\"")
+            elif "flow.flows.0.enc.attn_layers.3" in k:
+                logger.error(f"{k} is not in the checkpoint")
+                logger.warning(
+                    f"If you're using a transitional version, please add the \"version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.")
+            else:
+                logger.error(f"{k} is not in the checkpoint")
 
             new_state_dict[k] = v
         if hasattr(model, 'module'):
@@ -52,3 +59,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
     logger.info("Loaded checkpoint '{}' (iteration {})".format(
         checkpoint_path, iteration))
     return model, optimizer, learning_rate, iteration
+
+
+def process_legacy_versions(hps):
+    version = getattr(hps, "version", getattr(hps.data, "version", None))
+    if version:
+        prefix = version[0].lower()
+        if prefix == "v":
+            version = version[1:]
+    return version
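process_legacy_versions simply reads an optional version field from the hyperparameters (falling back to data.version) and strips a leading "v". A minimal, self-contained demo, with types.SimpleNamespace standing in for the project's HParams:

# Sketch: what process_legacy_versions() above returns for a typical legacy config.
from types import SimpleNamespace

def process_legacy_versions(hps):              # copied from the hunk above
    version = getattr(hps, "version", getattr(hps.data, "version", None))
    if version:
        prefix = version[0].lower()
        if prefix == "v":
            version = version[1:]
    return version

hps = SimpleNamespace(data=SimpleNamespace(version="v1.0.1"))
print(process_legacy_versions(hps))            # -> "1.0.1"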
contants.py
ADDED
@@ -0,0 +1,7 @@
from enum import Enum

class ModelType(Enum):
    VITS = "VITS"
    HUBERT_VITS = "HUBERT-VITS"
    W2V2_VITS = "W2V2-VITS"
    BERT_VITS2 = "BERT-VITS2"
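The enum gives every model family one canonical string, and the reworked check() endpoint validates requests against it. A short sketch of the two lookups used there (assumes contants.py is on the import path):

# Sketch: the two ModelType lookups used by app.py's check() endpoint.
from contants import ModelType

name = "bert-vits2".upper()                      # normalised user input
if name in ModelType._value2member_map_:         # membership test by value, as in check()
    model_type = ModelType(name)                 # value -> enum member
    print(model_type, model_type.value)          # ModelType.BERT_VITS2 BERT-VITS2
else:
    print(f"model {name} does not exist")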
docker-compose-gpu.yaml
CHANGED
@@ -8,14 +8,16 @@ services:
       - 23456:23456
     environment:
       LANG: 'C.UTF-8'
+      TZ: Asia/Shanghai # timezone
+    command: ['python', 'app.py'] # gunicorn multi-process cannot be used yet
     volumes:
       - ./Model:/app/Model # mount the model folder
       - ./config.py:/app/config.py # mount the config file
       - ./logs:/app/logs # logging logs
       - ./gunicorn_config.py:/app/gunicorn_config.py # gunicorn configuration
       - ./vits/bert:/app/vits/bert # vits_chinese
-      - ./bert_vits2/bert/chinese-roberta-wwm-ext-large:/app/bert_vits2/bert/chinese-roberta-wwm-ext-large # Bert-vits2
+      - ./bert_vits2/bert/chinese-roberta-wwm-ext-large:/app/bert_vits2/bert/chinese-roberta-wwm-ext-large # Bert-vits2 chinese
+      - ./bert_vits2/bert/bert-base-japanese-v3:/app/bert_vits2/bert/bert-base-japanese-v3 # Bert-vits2 japanese
       - ./pyopenjtalk/open_jtalk_dic_utf_8-1.11:/usr/local/lib/python3.10/site-packages/pyopenjtalk/open_jtalk_dic_utf_8-1.11 #pyopentjalk
     deploy:
       resources:
docker-compose.yaml
CHANGED
@@ -14,5 +14,6 @@ services:
       - ./logs:/app/logs # logging logs
       - ./gunicorn_config.py:/app/gunicorn_config.py # gunicorn configuration
       - ./vits/bert:/app/vits/bert # vits_chinese
-      - ./bert_vits2/bert/chinese-roberta-wwm-ext-large:/app/bert_vits2/bert/chinese-roberta-wwm-ext-large # Bert-vits2
+      - ./bert_vits2/bert/chinese-roberta-wwm-ext-large:/app/bert_vits2/bert/chinese-roberta-wwm-ext-large # Bert-vits2 chinese
+      - ./bert_vits2/bert/bert-base-japanese-v3:/app/bert_vits2/bert/bert-base-japanese-v3 # Bert-vits2 japanese
       - ./pyopenjtalk/open_jtalk_dic_utf_8-1.11:/usr/local/lib/python3.10/site-packages/pyopenjtalk/open_jtalk_dic_utf_8-1.11 #pyopentjalk
logger.py
CHANGED
@@ -22,10 +22,8 @@ handler = TimedRotatingFileHandler(log_file, when="midnight", interval=1, backup
 handler.suffix = "%Y-%m-%d.log"
 formatter = logging.Formatter('%(levelname)s:%(name)s %(message)s')
 handler.setFormatter(formatter)
-logger.addHandler(handler)
 
-logging.getLogger(
-logging.getLogger("apscheduler.scheduler").addHandler(handler)
+logging.getLogger().addHandler(handler)
 
 
 # Custom function to handle uncaught exceptions
utils/classify_language.py
CHANGED
@@ -2,16 +2,27 @@ from config import LANGUAGE_IDENTIFICATION_LIBRARY
 
 module = LANGUAGE_IDENTIFICATION_LIBRARY.lower()
 
+langid_languages = ["af", "am", "an", "ar", "as", "az", "be", "bg", "bn", "br", "bs", "ca", "cs", "cy", "da", "de", "dz", "el",
+                    "en", "eo", "es", "et", "eu", "fa", "fi", "fo", "fr", "ga", "gl", "gu", "he", "hi", "hr", "ht", "hu", "hy",
+                    "id", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "ku", "ky", "la", "lb", "lo", "lt", "lv", "mg",
+                    "mk", "ml", "mn", "mr", "ms", "mt", "nb", "ne", "nl", "nn", "no", "oc", "or", "pa", "pl", "ps", "pt", "qu",
+                    "ro", "ru", "rw", "se", "si", "sk", "sl", "sq", "sr", "sv", "sw", "ta", "te", "th", "tl", "tr", "ug", "uk",
+                    "ur", "vi", "vo", "wa", "xh", "zh", "zu"]
+
 
 def classify_language(text: str, target_languages: list = None) -> str:
     if module == "fastlid" or module == "fasttext":
-        from fastlid import fastlid
+        from fastlid import fastlid, supported_langs
         classifier = fastlid
-        if target_languages != None:
+        if target_languages != None:
+            target_languages = [lang for lang in target_languages if lang in supported_langs]
+            fastlid.set_languages = target_languages
     elif module == "langid":
         import langid
         classifier = langid.classify
-        if target_languages != None:
+        if target_languages != None:
+            target_languages = [lang for lang in target_languages if lang in langid_languages]
+            langid.set_languages(target_languages)
     else:
         raise ValueError(f"Wrong LANGUAGE_IDENTIFICATION_LIBRARY in config.py")
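With the new whitelisting, language auto-detection can be constrained to the languages a model actually supports. A minimal sketch of the langid branch (assumes the langid package is installed; the fastlid branch works analogously):

# Sketch: restricting langid to a model's supported languages, as classify_language() now does.
import langid

target_languages = ["zh", "ja"]          # e.g. a Bert-VITS2 v1.1.x model
langid.set_languages(target_languages)   # detection is now limited to these codes

lang, score = langid.classify("今晚在院子里坐着乘凉")
print(lang)                              # expected: "zh"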
utils/download.py
CHANGED
@@ -1,3 +1,4 @@
+import logging
 import os
 import hashlib
 import tarfile
@@ -18,6 +19,7 @@ class TqdmUpTo(tqdm):
 
 
 def download_file(url, dest_path):
+    logging.info(f"Downloading: {url}")
    with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=url.split('/')[-1]) as t:
        urllib.request.urlretrieve(url, dest_path, reporthook=t.update_to)
 
@@ -37,6 +39,8 @@ def extract_file(file_path, destination=None):
     if destination is None:
         destination = Path(file_path).parent
 
+    logging.info(f"Extracting to {destination}")
+
     if file_path.endswith('.zip'):
         with zipfile.ZipFile(file_path, 'r') as zip_ref:
             zip_ref.extractall(destination)
@@ -50,7 +54,7 @@ def extract_file(file_path, destination=None):
         with SevenZipFile(file_path, mode='r') as z:
             z.extractall(destination)
     else:
+        logging.error(f"Unsupported compression format for file {file_path}")
 
 
 def download_and_verify(urls, target_path, expected_md5=None, extract_destination=None):
utils/lang_dict.py
CHANGED
@@ -1,3 +1,5 @@
+from contants import ModelType
+
 lang_dict = {
     "english_cleaners": ["en"],
     "english_cleaners2": ["en"],
@@ -18,5 +20,12 @@ lang_dict = {
                                  "ZR", "PH", "TX", "JS", "HN", "LP", "XS", "FY", "RA", "CX", "SM", "TT", "WZ", "SC",
                                  "YB"],
     "bert_chinese_cleaners": ["zh"],
+    ModelType.BERT_VITS2.value: ["zh", "ja"],
+    f"{ModelType.BERT_VITS2.value}_v1.0": ["zh"],
+    f"{ModelType.BERT_VITS2.value}_v1.0.0": ["zh"],
+    f"{ModelType.BERT_VITS2.value}_v1.0.1": ["zh"],
+    f"{ModelType.BERT_VITS2.value}_v1.1": ["zh", "ja"],
+    f"{ModelType.BERT_VITS2.value}_v1.1.0": ["zh", "ja"],
+    f"{ModelType.BERT_VITS2.value}_v1.1.0-transition": ["zh", "ja"],
+    f"{ModelType.BERT_VITS2.value}_v1.1.1": ["zh", "ja"],
 }
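The new keys combine the model type with the detected legacy version, so a versioned Bert-VITS2 model resolves to the right language set. A sketch of the lookup as performed in Bert_VITS2.__init__ and merge_models (assumes the project is on the import path):

# Sketch: how the versioned lang_dict key is assembled and resolved.
from contants import ModelType
from utils.lang_dict import lang_dict

version = "1.0.1"                     # as returned by process_legacy_versions()
key = f"{ModelType.BERT_VITS2.value}_v{version}" if version else ModelType.BERT_VITS2.value
print(key, lang_dict.get(key, ["unknown"]))   # BERT-VITS2_v1.0.1 ['zh']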
utils/load_model.py
CHANGED
@@ -3,30 +3,33 @@ import json
|
|
3 |
import logging
|
4 |
import config
|
5 |
import numpy as np
|
6 |
-
|
|
|
|
|
7 |
from vits import VITS
|
8 |
from voice import TTS
|
9 |
from config import DEVICE as device
|
10 |
from utils.lang_dict import lang_dict
|
|
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
symbols = model_config.get("symbols", None)
|
17 |
-
emotion_embedding = model_config.get("data").get("emotion_embedding", False)
|
18 |
-
|
19 |
-
if "use_spk_conditioned_encoder" in model_config.get("model"):
|
20 |
-
model_type = "bert_vits2"
|
21 |
return model_type
|
22 |
|
23 |
if symbols != None:
|
24 |
if not emotion_embedding:
|
25 |
-
mode_type =
|
26 |
else:
|
27 |
-
mode_type =
|
28 |
else:
|
29 |
-
mode_type =
|
30 |
|
31 |
return mode_type
|
32 |
|
@@ -72,16 +75,19 @@ def load_npy(emotion_reference_npy):
|
|
72 |
|
73 |
def parse_models(model_list):
|
74 |
categorized_models = {
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
}
|
80 |
|
81 |
for model_info in model_list:
|
82 |
config_path = model_info[1]
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
85 |
if model_type in categorized_models:
|
86 |
categorized_models[model_type].append(model_info)
|
87 |
|
@@ -93,23 +99,26 @@ def merge_models(model_list, model_class, model_type, additional_arg=None):
|
|
93 |
speakers = []
|
94 |
new_id = 0
|
95 |
|
96 |
-
for obj_id, (model_path, config_path) in enumerate(model_list):
|
97 |
obj_args = {
|
98 |
"model": model_path,
|
99 |
-
"config":
|
100 |
"model_type": model_type,
|
101 |
"device": device
|
102 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
if additional_arg:
|
104 |
obj_args.update(additional_arg)
|
105 |
|
106 |
 import logging
 import config
 import numpy as np
+
+import utils
+from utils.data_utils import check_is_none, HParams
 from vits import VITS
 from voice import TTS
 from config import DEVICE as device
 from utils.lang_dict import lang_dict
+from contants import ModelType


+def recognition_model_type(hps: HParams) -> str:
+    # model_config = json.load(model_config_json)
+    symbols = getattr(hps, "symbols", None)
+    # symbols = model_config.get("symbols", None)
+    emotion_embedding = getattr(hps.data, "emotion_embedding", False)

+    if "use_spk_conditioned_encoder" in hps.model:
+        model_type = ModelType.BERT_VITS2
         return model_type

     if symbols != None:
         if not emotion_embedding:
+            mode_type = ModelType.VITS
         else:
+            mode_type = ModelType.W2V2_VITS
     else:
+        mode_type = ModelType.HUBERT_VITS

     return mode_type


 def parse_models(model_list):
     categorized_models = {
+        ModelType.VITS: [],
+        ModelType.HUBERT_VITS: [],
+        ModelType.W2V2_VITS: [],
+        ModelType.BERT_VITS2: []
     }

     for model_info in model_list:
         config_path = model_info[1]
+        hps = utils.get_hparams_from_file(config_path)
+        model_info.append(hps)
+        model_type = recognition_model_type(hps)
+        # with open(config_path, 'r', encoding='utf-8') as model_config:
+        #     model_type = recognition_model_type(model_config)
         if model_type in categorized_models:
             categorized_models[model_type].append(model_info)


     speakers = []
     new_id = 0

+    for obj_id, (model_path, config_path, hps) in enumerate(model_list):
         obj_args = {
             "model": model_path,
+            "config": hps,
             "model_type": model_type,
             "device": device
         }
+
+        if model_type == ModelType.BERT_VITS2:
+            from bert_vits2.utils import process_legacy_versions
+            legacy_versions = process_legacy_versions(hps)
+            key = f"{model_type.value}_v{legacy_versions}" if legacy_versions else model_type.value
+        else:
+            key = getattr(hps.data, "text_cleaners", ["none"])[0]
+
         if additional_arg:
             obj_args.update(additional_arg)

         obj = model_class(**obj_args)

-        if model_type == "bert_vits2":
-            key = model_type
-        else:
-            key = obj.get_cleaner()
-
         lang = lang_dict.get(key, ["unknown"])

         for real_id, name in enumerate(obj.get_speakers()):
@@ -124,11 +133,11 @@ def load_model(model_list) -> TTS:
     categorized_models = parse_models(model_list)

     # Handle VITS
-    vits_objs, vits_speakers = merge_models(categorized_models[
+    vits_objs, vits_speakers = merge_models(categorized_models[ModelType.VITS], VITS, ModelType.VITS)

     # Handle HUBERT-VITS
     hubert_vits_objs, hubert_vits_speakers = [], []
-    if len(categorized_models[
+    if len(categorized_models[ModelType.HUBERT_VITS]) != 0:
         if getattr(config, "HUBERT_SOFT_MODEL", None) is None or check_is_none(config.HUBERT_SOFT_MODEL):
             raise ValueError(f"Please configure HUBERT_SOFT_MODEL path in config.py")
         try:
@@ -137,13 +146,13 @@ def load_model(model_list) -> TTS:
         except Exception as e:
             raise ValueError(f"Load HUBERT_SOFT_MODEL failed {e}")

-        hubert_vits_objs, hubert_vits_speakers = merge_models(categorized_models[
+        hubert_vits_objs, hubert_vits_speakers = merge_models(categorized_models[ModelType.HUBERT_VITS], VITS, ModelType.HUBERT_VITS,
                                                               additional_arg={"additional_model": hubert})

     # Handle W2V2-VITS
     w2v2_vits_objs, w2v2_vits_speakers = [], []
     w2v2_emotion_count = 0
-    if len(categorized_models[
+    if len(categorized_models[ModelType.W2V2_VITS]) != 0:
         if getattr(config, "DIMENSIONAL_EMOTION_NPY", None) is None or check_is_none(
                 config.DIMENSIONAL_EMOTION_NPY):
             raise ValueError(f"Please configure DIMENSIONAL_EMOTION_NPY path in config.py")
@@ -153,20 +162,24 @@ def load_model(model_list) -> TTS:
             emotion_reference = None
             raise ValueError(f"Load DIMENSIONAL_EMOTION_NPY failed {e}")

-        w2v2_vits_objs, w2v2_vits_speakers = merge_models(categorized_models[
+        w2v2_vits_objs, w2v2_vits_speakers = merge_models(categorized_models[ModelType.W2V2_VITS], VITS, ModelType.W2V2_VITS,
                                                           additional_arg={"additional_model": emotion_reference})
         w2v2_emotion_count = len(emotion_reference) if emotion_reference is not None else 0

     # Handle BERT-VITS2
     bert_vits2_objs, bert_vits2_speakers = [], []
-    if len(categorized_models[
+    if len(categorized_models[ModelType.BERT_VITS2]) != 0:
         from bert_vits2 import Bert_VITS2
-        bert_vits2_objs, bert_vits2_speakers = merge_models(categorized_models[
-    voice_obj = {
+        bert_vits2_objs, bert_vits2_speakers = merge_models(categorized_models[ModelType.BERT_VITS2], Bert_VITS2, ModelType.BERT_VITS2)
+
+    voice_obj = {ModelType.VITS: vits_objs,
+                 ModelType.HUBERT_VITS: hubert_vits_objs,
+                 ModelType.W2V2_VITS: w2v2_vits_objs,
+                 ModelType.BERT_VITS2: bert_vits2_objs}
+    voice_speakers = {ModelType.VITS.value: vits_speakers,
+                      ModelType.HUBERT_VITS.value: hubert_vits_speakers,
+                      ModelType.W2V2_VITS.value: w2v2_vits_speakers,
+                      ModelType.BERT_VITS2.value: bert_vits2_speakers}

     tts = TTS(voice_obj, voice_speakers, device=device, w2v2_emotion_count=w2v2_emotion_count)
     return tts
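Note on the enum used throughout the hunks above: the ModelType referenced here comes from contants.py, which is not reproduced in this section. A minimal sketch of an enum that would satisfy the usage visible in the diff — members as dict keys, .value as string keys and log labels, and construction from an upper-cased SSML attribute — could look like the following; the member values are assumptions inferred from that usage, not a copy of contants.py:

from enum import Enum

class ModelType(Enum):
    # Values are assumed; only their use via .value as keys/labels is visible in this diff.
    VITS = "VITS"
    HUBERT_VITS = "HUBERT-VITS"
    W2V2_VITS = "W2V2-VITS"
    BERT_VITS2 = "BERT-VITS2"

With members like these, categorized_models[ModelType.VITS] and voice_speakers[ModelType.VITS.value] stay distinct but consistent, and ModelType("W2V2-VITS") round-trips from the "w2v2-vits" attribute that voice.py upper-cases.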
vits-simple-api-installer-latest.sh
CHANGED
@@ -28,9 +28,9 @@ EN_MESSAGES(
     ["DOWNLOAD_VITS_CHINESE"]="Do you want to download the bert model for vits_chinese? Enter 1 for yes, 2 for no."
     ["MUST_DOWNLOAD_VITS_CHINESE"]="Using vits_chinese requires downloading these models, which will take up about 410MB."
     ["DOWNLOAD_BERT_VITS2"]="Do you want to download chinese-roberta-wwm-ext-large? Enter 1 for yes or 2 for no"
-    ["MUST_DOWNLOAD_BERT_VITS2"]="To use Bert-VITS2, you must download these models, which will take up about
+    ["MUST_DOWNLOAD_BERT_VITS2"]="To use Bert-VITS2, you must download these models, which will take up about 1.63GB."
     ["DOWNLOADED"]="File is downloaded correctly."
-    ["CORRUPTED"]="
+    ["CORRUPTED"]="The file may not have been downloaded, or the download might be incomplete, and it could also be corrupted."
     ["INSTALL_COMPLETE"]="The upgrade or installation has been completed."
     ["CONFIG_DIR"]="The configuration file directory is"
     ["IMPORT_NOTICE"]="If the vits model is not imported, it cannot be used. Import the model in the configuration file directory."
@@ -59,9 +59,9 @@ ZH_MESSAGES(
     ["DOWNLOAD_VITS_CHINESE"]="是否要下载vits_chinese的bert模型?输入1表示是,2表示否。"
     ["MUST_DOWNLOAD_VITS_CHINESE"]="使用vits_chinese必须下载这些模型,将占用大约410MB。"
     ["DOWNLOAD_BERT_VITS2"]="是否要下载chinese-roberta-wwm-ext-large?输入1表示是,2表示否。"
-    ["MUST_DOWNLOAD_BERT_VITS2"]="使用Bert-VITS2必须下载这些模型,将占用大约
+    ["MUST_DOWNLOAD_BERT_VITS2"]="使用Bert-VITS2必须下载这些模型,将占用大约1.63GB。"
     ["DOWNLOADED"]="文件已正确下载。"
-    ["CORRUPTED"]="
+    ["CORRUPTED"]="文件可能未下载,或下载不完整,也有可能已损坏。"
     ["INSTALL_COMPLETE"]="更新或安装已完成。"
     ["CONFIG_DIR"]="配置文件目录是"
     ["IMPORT_NOTICE"]="如果vits模型没有被导入,它是无法使用的。请在配置文件目录中导入模型。"
@@ -240,6 +240,21 @@ if [ "$choice_download_bert_vits2" -eq 1 ]; then
         download_with_fallback bert_vits2/bert/chinese-roberta-wwm-ext-large/pytorch_model.bin \
             "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin"
     fi
+
+    mkdir -p bert_vits2/bert/bert-base-japanese-v3
+
+    EXPECTED_MD5="6d0f8f3503dae04df0711b6175ef0c8e"
+    FILE_PATH="bert_vits2/bert/bert-base-japanese-v3/pytorch_model.bin"
+    echo -e "${MESSAGES["VERIFYING"]}$FILE_PATH"
+    ACTUAL_MD5=$(md5sum $FILE_PATH | awk '{print $1}')
+
+    if [ "$EXPECTED_MD5" == "$ACTUAL_MD5" ]; then
+        echo "${MESSAGES["DOWNLOADED"]}"
+    else
+        echo ${MESSAGES["CORRUPTED"]}
+        download_with_fallback bert_vits2/bert/bert-base-japanese-v3/pytorch_model.bin \
+            "https://huggingface.co/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin"
+    fi

 fi
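The installer block above decides whether to re-download bert-base-japanese-v3 by comparing md5sum output against a hard-coded digest. For readers who prefer to check a downloaded checkpoint from Python, a rough equivalent of that shell check is sketched below; the helper name md5_of is made up here, while the path and digest are the ones from the script:

import hashlib
import os

def md5_of(path, chunk_size=1 << 20):
    # Hash the file in chunks so multi-GB checkpoints do not need to fit in memory.
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

expected = "6d0f8f3503dae04df0711b6175ef0c8e"
path = "bert_vits2/bert/bert-base-japanese-v3/pytorch_model.bin"
if not os.path.exists(path) or md5_of(path) != expected:
    print("missing or corrupted; re-download the file")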
vits/text/japanese.py
CHANGED
@@ -1,7 +1,23 @@
+import os
 import re
 from unidecode import unidecode
 import pyopenjtalk

+from config import ABS_PATH
+from utils.download import download_and_verify
+
+URLS = [
+    "https://github.com/r9y9/open_jtalk/releases/download/v1.11.1/open_jtalk_dic_utf_8-1.11.tar.gz",
+    "https://ghproxy.com/https://github.com/r9y9/open_jtalk/releases/download/v1.11.1/open_jtalk_dic_utf_8-1.11.tar.gz",
+]
+install_path = os.path.dirname(pyopenjtalk.__file__)
+dict_path = os.path.join(install_path, "open_jtalk_dic_utf_8-1.11", "char.bin")
+TARGET_PATH = os.path.join(ABS_PATH, "open_jtalk_dic_utf_8-1.11.tar.gz")
+EXTRACT_DESTINATION = os.path.join(install_path, "")
+EXPECTED_MD5 = None
+
+if not os.path.exists(dict_path):
+    success, message = download_and_verify(URLS, TARGET_PATH, EXPECTED_MD5, EXTRACT_DESTINATION)

 # Regular expression matching Japanese without punctuation marks:
 _japanese_characters = re.compile(
@@ -127,7 +143,7 @@ def get_real_hatsuon(text):
 def japanese_to_ipa(text):
     text = japanese_to_romaji_with_accent(text).replace('...', '…')
     text = re.sub(
-        r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
+        r'([aiueo])\1+', lambda x: x.group(0)[0] + 'ː' * (len(x.group(0)) - 1), text)
     text = get_real_sokuon(text)
     text = get_real_hatsuon(text)
     for regex, replacement in _romaji_to_ipa:
@@ -148,6 +164,6 @@ def japanese_to_ipa3(text):
     text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
         'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
     text = re.sub(
-        r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
+        r'([aiɯeo])\1+', lambda x: x.group(0)[0] + 'ː' * (len(x.group(0)) - 1), text)
     text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
     return text
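japanese.py now fetches the OpenJTalk dictionary on first import when char.bin is missing, delegating to utils.download.download_and_verify (touched in this commit but not shown in this section). Judging only from the call site above, that helper is expected to try each mirror in order, optionally verify an MD5 (None skips the check), unpack the tarball next to the pyopenjtalk installation, and return a (success, message) pair. The standalone sketch below illustrates those steps with the standard library; it is an assumption about the behaviour rather than the project's implementation, and the name fetch_and_extract is invented:

import hashlib
import tarfile
import urllib.request

def fetch_and_extract(urls, target_path, expected_md5=None, extract_destination=None):
    last_error = None
    for url in urls:  # e.g. the GitHub release first, then the ghproxy mirror
        try:
            urllib.request.urlretrieve(url, target_path)
        except OSError as e:
            last_error = e
            continue
        if expected_md5 is not None:
            with open(target_path, "rb") as f:
                if hashlib.md5(f.read()).hexdigest() != expected_md5:
                    last_error = ValueError("MD5 mismatch")
                    continue
        if extract_destination is not None and target_path.endswith(".tar.gz"):
            with tarfile.open(target_path, "r:gz") as tar:
                tar.extractall(extract_destination)
        return True, f"downloaded from {url}"
    return False, f"download failed: {last_error}"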
vits/vits.py
CHANGED
@@ -4,6 +4,8 @@ import numpy as np
 import torch
 from torch import no_grad, LongTensor, inference_mode, FloatTensor
 import utils
+from contants import ModelType
+from utils import get_hparams_from_file, lang_dict
 from utils.sentence import sentence_split_and_markup
 from vits import commons
 from vits.mel_processing import spectrogram_torch
@@ -11,11 +13,10 @@ from vits.text import text_to_sequence
 from vits.models import SynthesizerTrn


-
 class VITS:
-    def __init__(self, model, config, additional_model=None, model_type=None, device=torch.device("cpu")
+    def __init__(self, model, config, additional_model=None, model_type=None, device=torch.device("cpu"), **kwargs):
         self.model_type = model_type
-        self.hps_ms =
+        self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
         self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
         self.n_symbols = len(getattr(self.hps_ms, 'symbols', []))
         self.speakers = getattr(self.hps_ms, 'speakers', ['0'])
@@ -38,15 +39,18 @@ class VITS:
         _ = self.net_g_ms.eval()
         self.device = device

+        key = getattr(self.hps_ms.data, "text_cleaners", ["none"])[0]
+        self.lang = lang_dict.get(key, ["unknown"])
+
         # load model
         self.load_model(model, additional_model)

     def load_model(self, model, additional_model=None):
         utils.load_checkpoint(model, self.net_g_ms)
         self.net_g_ms.to(self.device)
-        if self.model_type ==
+        if self.model_type == ModelType.HUBERT_VITS:
             self.hubert = additional_model
-        elif self.model_type ==
+        elif self.model_type == ModelType.W2V2_VITS:
             self.emotion_reference = additional_model

     def get_cleaned_text(self, text, hps, cleaned=False):
@@ -71,13 +75,17 @@ class VITS:
     def get_speakers(self, escape=False):
         return self.speakers

+    @property
+    def sampling_rate(self):
+        return self.hps_ms.data.sampling_rate
+
     def infer(self, params):
         with no_grad():
             x_tst = params.get("stn_tst").unsqueeze(0).to(self.device)
             x_tst_lengths = LongTensor([params.get("stn_tst").size(0)]).to(self.device)
             x_tst_prosody = torch.FloatTensor(params.get("char_embeds")).unsqueeze(0).to(
                 self.device) if self.bert_embedding else None
-            sid = params.get("sid").to(self.device)
+            sid = params.get("sid").to(self.device)
             emotion = params.get("emotion").to(self.device) if self.emotion_embedding else None

             audio = self.net_g_ms.infer(x=x_tst,
@@ -97,15 +105,14 @@ class VITS:
                         emotion=None, cleaned=False, f0_scale=1):
         emo = None
         char_embeds = None
-        if self.model_type !=
+        if self.model_type != ModelType.HUBERT_VITS:
             if self.bert_embedding:
                 stn_tst, char_embeds = self.get_cleaned_text(text, self.hps_ms, cleaned=cleaned)
-                sid = None
             else:
                 stn_tst = self.get_cleaned_text(text, self.hps_ms, cleaned=cleaned)
+            sid = LongTensor([speaker_id])

-        if self.model_type ==
+        if self.model_type == ModelType.W2V2_VITS:
             # if emotion_reference.endswith('.npy'):
             #     emotion = np.load(emotion_reference)
             #     emotion = FloatTensor(emotion).unsqueeze(0)
@@ -121,7 +128,7 @@ class VITS:
             emo = torch.FloatTensor(self.emotion_reference[emotion]).unsqueeze(0)


-        elif self.model_type ==
+        elif self.model_type == ModelType.HUBERT_VITS:
             if self.use_f0:
                 audio, sampling_rate = librosa.load(audio_path, sr=self.hps_ms.data.sampling_rate, mono=True)
                 audio16000 = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
@@ -165,35 +172,39 @@ class VITS:
         if text is not None: text = re.sub(r'\s+', ' ', text).strip()

         tasks = []
-        if self.model_type ==
+        if self.model_type == ModelType.VITS:
             sentence_list = sentence_split_and_markup(text, max, lang, speaker_lang)
             for sentence in sentence_list:
                 params = self.get_infer_param(text=sentence, speaker_id=speaker_id, length_scale=length,
                                               noise_scale=noise, noise_scale_w=noisew)
                 tasks.append(params)

-        elif self.model_type ==
+        elif self.model_type == ModelType.HUBERT_VITS:
             params = self.get_infer_param(speaker_id=speaker_id, length_scale=length, noise_scale=noise,
                                           noise_scale_w=noisew, audio_path=audio_path)
             tasks.append(params)

-        elif self.model_type ==
+        elif self.model_type == ModelType.W2V2_VITS:
             sentence_list = sentence_split_and_markup(text, max, lang, speaker_lang)
             for sentence in sentence_list:
                 params = self.get_infer_param(text=sentence, speaker_id=speaker_id, length_scale=length,
                                               noise_scale=noise, noise_scale_w=noisew, emotion=emotion)
                 tasks.append(params)
+        else:
+            raise ValueError(f"Unsupported model type: {self.model_type}")

         return tasks

     def get_audio(self, voice, auto_break=False):
         tasks = self.get_tasks(voice)
         # 停顿0.75s,避免语音分段合成再拼接后的连接突兀
-        brk = np.zeros(int(0.75 *
+        brk = np.zeros(int(0.75 * self.sampling_rate), dtype=np.int16)

         audios = []
+        num_tasks = len(tasks)
+
+        for i, task in enumerate(tasks):
+            if auto_break and i < num_tasks - 1:
                 chunk = np.concatenate((self.infer(task), brk), axis=0)
             else:
                 chunk = self.infer(task)
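Two behavioural details in the VITS changes above are worth spelling out: config may now be either a path to a config.json or an already-parsed HParams object (load_model passes the parsed hps straight in), and get_audio inserts the 0.75 s pause only between sentences, not after the last chunk. A small illustrative call, using placeholder file names rather than paths from this repository:

import numpy as np
from contants import ModelType
from vits import VITS

# "G_latest.pth" and "config.json" are placeholders for a real checkpoint and config file.
vits = VITS(model="G_latest.pth", config="config.json", model_type=ModelType.VITS)

# For a 22050 Hz model, the inter-sentence pause built by get_audio(..., auto_break=True)
# is int(0.75 * 22050) = 16537 zero samples.
brk = np.zeros(int(0.75 * vits.sampling_rate), dtype=np.int16)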
voice.py
CHANGED
@@ -8,8 +8,9 @@ import config
 import soundfile as sf
 from io import BytesIO
 from graiax import silkcoder
-import utils
 from logger import logger
+from contants import ModelType
+from scipy.signal import resample_poly


 # torch.set_num_threads(1) # 设置torch线程为1
@@ -21,11 +22,11 @@ class TTS:
         self._voice_speakers = voice_speakers
         self._strength_dict = {"x-weak": 0.25, "weak": 0.5, "Medium": 0.75, "Strong": 1, "x-strong": 1.25}
         self._speakers_count = sum([len(self._voice_speakers[i]) for i in self._voice_speakers])
-        self._vits_speakers_count = len(self._voice_speakers[
-        self._hubert_speakers_count = len(self._voice_speakers[
-        self._w2v2_speakers_count = len(self._voice_speakers[
+        self._vits_speakers_count = len(self._voice_speakers[ModelType.VITS.value])
+        self._hubert_speakers_count = len(self._voice_speakers[ModelType.HUBERT_VITS.value])
+        self._w2v2_speakers_count = len(self._voice_speakers[ModelType.W2V2_VITS.value])
         self._w2v2_emotion_count = kwargs.get("w2v2_emotion_count", 0)
-        self._bert_vits2_speakers_count = len(self._voice_speakers[
+        self._bert_vits2_speakers_count = len(self._voice_speakers[ModelType.BERT_VITS2.value])
         self.dem = None

         # Initialization information
@@ -42,11 +43,11 @@ class TTS:
         except Exception as e:
             self.logger.warning(f"Load DIMENSIONAL_EMOTION_MODEL failed {e}")

-        if self._vits_speakers_count != 0: self.logger.info(f"[VITS] {self._vits_speakers_count} speakers")
-        if self._hubert_speakers_count != 0: self.logger.info(f"[
-        if self._w2v2_speakers_count != 0: self.logger.info(f"[
+        if self._vits_speakers_count != 0: self.logger.info(f"[{ModelType.VITS.value}] {self._vits_speakers_count} speakers")
+        if self._hubert_speakers_count != 0: self.logger.info(f"[{ModelType.HUBERT_VITS.value}] {self._hubert_speakers_count} speakers")
+        if self._w2v2_speakers_count != 0: self.logger.info(f"[{ModelType.W2V2_VITS.value}] {self._w2v2_speakers_count} speakers")
         if self._bert_vits2_speakers_count != 0: self.logger.info(
-            f"[
+            f"[{ModelType.BERT_VITS2.value}] {self._bert_vits2_speakers_count} speakers")
         self.logger.info(f"{self._speakers_count} speakers in total.")
         if self._speakers_count == 0:
             self.logger.warning(f"No model was loaded.")
@@ -124,6 +125,15 @@ class TTS:
                 break
             yield chunk

+    def resample_audio(self, audio, orig_sr, target_sr):
+        if orig_sr == target_sr:
+            return audio
+
+        gcd = np.gcd(orig_sr, target_sr)
+        audio = resample_poly(audio, target_sr // gcd, orig_sr // gcd)
+
+        return audio
+
     def parse_ssml(self, ssml):
         root = ET.fromstring(ssml)
         format = root.attrib.get("format", "wav")
@@ -140,9 +150,11 @@ class TTS:
                 noisew = float(element.attrib.get("noisew", root.attrib.get("noisew", config.NOISEW)))
                 max = int(element.attrib.get("max", root.attrib.get("max", "0")))
                 # 不填写默认就是vits
+                model_type = element.attrib.get("model_type", root.attrib.get("model_type", "vits"))
                 # w2v2-vits/emotion-vits才有emotion
                 emotion = int(element.attrib.get("emotion", root.attrib.get("emotion", 0)))
+                # Bert-VITS2的参数
+                sdp_ratio = int(element.attrib.get("sdp_ratio", root.attrib.get("sdp_ratio", config.SDP_RATIO)))

                 voice_element = ET.tostring(element, encoding='unicode')

@@ -177,8 +189,9 @@ class TTS:
                         "noise": noise,
                         "noisew": noisew,
                         "max": max,
-                        "emotion": emotion
+                        "model_type": model_type,
+                        "emotion": emotion,
+                        "sdp_ratio": sdp_ratio
                     })

                 # 分段末尾停顿0.75s
@@ -197,39 +210,49 @@ class TTS:

         return voice_tasks, format

-    def
+    def process_ssml_infer_task(self, tasks, format):
         audios = []
+        sampling_rates = []
+        last_sampling_rate = 22050
         for task in tasks:
             if task.get("break"):
                 audios.append(np.zeros(int(task.get("break") * 22050), dtype=np.int16))
+                sampling_rates.append(last_sampling_rate)
             else:
-                if
-                    raise ValueError(f"Unsupported model: {task.get('
+                model_type_str = task.get("model_type").upper()
+                if model_type_str not in [ModelType.VITS.value, ModelType.W2V2_VITS.value, ModelType.BERT_VITS2.value]:
+                    raise ValueError(f"Unsupported model type: {task.get('model_type')}")
+                model_type = ModelType(model_type_str)
+                voice_obj = self._voice_obj[model_type][task.get("id")][1]
+                real_id = self._voice_obj[model_type][task.get("id")][0]
+                task["id"] = real_id
+                sampling_rates.append(voice_obj.sampling_rate)
+                last_sampling_rate = voice_obj.sampling_rate
                 audio = voice_obj.get_audio(task)
                 audios.append(audio)
+        # 得到最高的采样率
+        target_sr = max(sampling_rates)
+        # 所有音频要与最高采样率保持一致
+        resampled_audios = [self.resample_audio(audio, sr, target_sr) for audio, sr in zip(audios, sampling_rates)]
+        audio = np.concatenate(resampled_audios, axis=0)
+        encoded_audio = self.encode(target_sr, audio, format)
         return encoded_audio

     def vits_infer(self, task):
         format = task.get("format", "wav")
-        voice_obj = self._voice_obj[
-        real_id = self._voice_obj[
+        voice_obj = self._voice_obj[ModelType.VITS][task.get("id")][1]
+        real_id = self._voice_obj[ModelType.VITS][task.get("id")][0]
         task["id"] = real_id  # Change to real id
-        sampling_rate = voice_obj.
+        sampling_rate = voice_obj.sampling_rate
         audio = voice_obj.get_audio(task, auto_break=True)
         encoded_audio = self.encode(sampling_rate, audio, format)
         return encoded_audio

     def stream_vits_infer(self, task, fname=None):
         format = task.get("format", "wav")
-        voice_obj = self._voice_obj[
-        task["id"] = self._voice_obj[
-        sampling_rate = voice_obj.
+        voice_obj = self._voice_obj[ModelType.VITS][task.get("id")][1]
+        task["id"] = self._voice_obj[ModelType.VITS][task.get("id")][0]
+        sampling_rate = voice_obj.sampling_rate
         genertator = voice_obj.get_stream_audio(task, auto_break=True)
         # audio = BytesIO()
         for chunk in genertator:
@@ -244,18 +267,18 @@ class TTS:

     def hubert_vits_infer(self, task):
         format = task.get("format", "wav")
-        voice_obj = self._voice_obj[
-        task["id"] = self._voice_obj[
-        sampling_rate = voice_obj.
+        voice_obj = self._voice_obj[ModelType.HUBERT_VITS][task.get("id")][1]
+        task["id"] = self._voice_obj[ModelType.HUBERT_VITS][task.get("id")][0]
+        sampling_rate = voice_obj.sampling_rate
         audio = voice_obj.get_audio(task)
         encoded_audio = self.encode(sampling_rate, audio, format)
         return encoded_audio

     def w2v2_vits_infer(self, task):
         format = task.get("format", "wav")
-        voice_obj = self._voice_obj[
-        task["id"] = self._voice_obj[
-        sampling_rate = voice_obj.
+        voice_obj = self._voice_obj[ModelType.W2V2_VITS][task.get("id")][1]
+        task["id"] = self._voice_obj[ModelType.W2V2_VITS][task.get("id")][0]
+        sampling_rate = voice_obj.sampling_rate
         audio = voice_obj.get_audio(task, auto_break=True)
         encoded_audio = self.encode(sampling_rate, audio, format)
         return encoded_audio
@@ -265,17 +288,17 @@ class TTS:
         target_id = task.get("target_id")
         format = task.get("format")

-        original_id_obj = int(self._voice_obj[
-        target_id_obj = int(self._voice_obj[
+        original_id_obj = int(self._voice_obj[ModelType.VITS][original_id][2])
+        target_id_obj = int(self._voice_obj[ModelType.VITS][target_id][2])

         if original_id_obj != target_id_obj:
             raise ValueError(f"speakers are in diffrent VITS Model")

-        task["original_id"] = int(self._voice_obj[
-        task["target_id"] = int(self._voice_obj[
+        task["original_id"] = int(self._voice_obj[ModelType.VITS][original_id][0])
+        task["target_id"] = int(self._voice_obj[ModelType.VITS][target_id][0])

-        voice_obj = self._voice_obj[
-        sampling_rate = voice_obj.
+        voice_obj = self._voice_obj[ModelType.VITS][original_id][1]
+        sampling_rate = voice_obj.sampling_rate

         audio = voice_obj.voice_conversion(task)
         encoded_audio = self.encode(sampling_rate, audio, format)
@@ -294,9 +317,9 @@ class TTS:

     def bert_vits2_infer(self, task):
         format = task.get("format", "wav")
-        voice_obj = self._voice_obj[
-        task["id"] = self._voice_obj[
-        sampling_rate = voice_obj.
+        voice_obj = self._voice_obj[ModelType.BERT_VITS2][task.get("id")][1]
+        task["id"] = self._voice_obj[ModelType.BERT_VITS2][task.get("id")][0]
+        sampling_rate = voice_obj.sampling_rate
         audio = voice_obj.get_audio(task, auto_break=True)
         encoded_audio = self.encode(sampling_rate, audio, format)
         return encoded_audio
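resample_audio above brings every segment of an SSML request to the highest sampling rate in the batch before concatenation, reducing the up/down factors by their gcd so resample_poly works with small integer ratios. For example, mixing a 22050 Hz voice with a 44100 Hz voice gives gcd 22050, i.e. upsample by 2 and downsample by 1. A self-contained illustration with a synthetic signal (not output from any of the models above):

import numpy as np
from scipy.signal import resample_poly

orig_sr, target_sr = 22050, 44100
t = np.linspace(0, 1, orig_sr, endpoint=False)
audio = np.sin(2 * np.pi * 440 * t)          # one second of a 440 Hz tone at 22050 Hz

gcd = int(np.gcd(orig_sr, target_sr))        # 22050
up, down = target_sr // gcd, orig_sr // gcd  # 2, 1
resampled = resample_poly(audio, up, down)
print(len(audio), len(resampled))            # 22050 44100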