Artrajz committed
Commit b0f5083
1 Parent(s): 24e3342
Dockerfile CHANGED
@@ -1,12 +1,14 @@
- FROM python:3.10.11-slim-bullseye
+ FROM artrajz/pytorch:1.13.1-cpu-py3.10.11-ubuntu22.04

  RUN mkdir -p /app
  WORKDIR /app

  ENV DEBIAN_FRONTEND=noninteractive

+
  RUN apt-get update && \
-     apt-get install -yq build-essential espeak-ng cmake wget && \
+     apt-get install -yq build-essential espeak-ng cmake wget ca-certificates tzdata&& \
+     update-ca-certificates && \
      apt-get clean && \
      apt-get purge -y --auto-remove -o APT::AutoRemove::RecommendsImportant=false && \
      rm -rf /var/lib/apt/lists/*
@@ -16,20 +18,16 @@ RUN wget https://github.com/jemalloc/jemalloc/releases/download/5.3.0/jemalloc-5
      tar -xvf jemalloc-5.3.0.tar.bz2 && \
      cd jemalloc-5.3.0 && \
      ./configure && \
-     make && \
+     make -j$(nproc) && \
      make install && \
      cd .. && \
      rm -rf jemalloc-5.3.0* && \
-     ldconfig
+     ldconfig

  ENV LD_PRELOAD=/usr/local/lib/libjemalloc.so

- RUN pip install torch --index-url https://download.pytorch.org/whl/cpu --no-cache-dir
-
  COPY requirements.txt /app/
- RUN pip install --upgrade pip && \
-     pip install pyopenjtalk==0.3.2 -i https://pypi.artrajz.cn/simple --no-cache-dir && \
-     pip install gunicorn --no-cache-dir && \
+ RUN pip install gunicorn --no-cache-dir && \
      pip install -r requirements.txt --no-cache-dir&& \
      rm -rf /root/.cache/pip/*
Dockerfile_GPU CHANGED
@@ -1,4 +1,4 @@
- FROM cnstark/pytorch:2.0.1-py3.10.11-cuda11.8.0-ubuntu22.04
+ FROM artrajz/pytorch:1.13.1-cu117-py3.10.11-ubuntu22.04

  RUN mkdir -p /app
  WORKDIR /app
@@ -6,7 +6,7 @@ WORKDIR /app
  ENV DEBIAN_FRONTEND=noninteractive

  RUN apt-get update && \
-     apt-get install -yq build-essential espeak-ng cmake wget ca-certificates && \
+     apt-get install -yq build-essential espeak-ng cmake wget ca-certificates tzdata&& \
      update-ca-certificates && \
      apt-get clean && \
      apt-get purge -y --auto-remove -o APT::AutoRemove::RecommendsImportant=false && \
@@ -18,18 +18,16 @@ RUN wget https://github.com/jemalloc/jemalloc/releases/download/5.3.0/jemalloc-5
      tar -xvf jemalloc-5.3.0.tar.bz2 && \
      cd jemalloc-5.3.0 && \
      ./configure && \
-     make && \
+     make -j$(nproc) && \
      make install && \
      cd .. && \
      rm -rf jemalloc-5.3.0* && \
-     ldconfig
+     ldconfig

  ENV LD_PRELOAD=/usr/local/lib/libjemalloc.so

  COPY requirements.txt /app/
- RUN pip install --upgrade pip && \
-     pip install pyopenjtalk==0.3.2 fasttext -i https://pypi.artrajz.cn/simple --no-cache-dir && \
-     pip install gunicorn --no-cache-dir && \
+ RUN pip install gunicorn --no-cache-dir && \
      pip install -r requirements.txt --no-cache-dir&& \
      rm -rf /root/.cache/pip/*
README_zh.md CHANGED
@@ -21,7 +21,7 @@

  - [x] VITS语音合成,语音转换
  - [x] HuBert-soft VITS模型
- - [x] W2V2 VITS / emotional-vits维度情感模型
+ - [x] W2V2 VITS / [emotional-vits](https://github.com/innnky/emotional-vits)维度情感模型
  - [x] [vits_chinese](https://github.com/PlayVoice/vits_chinese)
  - [x] [Bert-VITS2](https://github.com/Stardust-minus/Bert-VITS2)
  - [x] 加载多模型
@@ -341,7 +341,7 @@ pip install pyopenjtalk -i https://pypi.artrajz.cn/simple
  | 合成文本 | text | true | | str | 需要合成语音的文本。 |
  | 角色id | id | false | 从`config.py`中获取 | int | 即说话人id。 |
  | 音频格式 | format | false | 从`config.py`中获取 | str | 支持wav,ogg,silk,mp3,flac |
- | 文本语言 | lang | false | 从`config.py`中获取 | str | 目前只有中文。 |
+ | 文本语言 | lang | false | 从`config.py`中获取 | str | auto为自动识别语言模式,也是默认模式,但目前只支持识别整段文本的语言,无法细分到每个句子。其余可选语言zh和ja。 |
  | 语音长度/语速 | length | false | 从`config.py`中获取 | float | 调节语音长度,相当于调节语速,该数值越大语速越慢。 |
  | 噪声 | noise | false | 从`config.py`中获取 | float | 样本噪声,控制合成的随机性。 |
  | sdp噪声 | noisew | false | 从`config.py`中获取 | float | 随机时长预测器噪声,控制音素发音长度。 |
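A rough illustration of the `lang` parameter described in the table above. This is only a sketch: the `/voice/vits` route is an assumption about this project's usual API and is not stated in this diff; port 23456 is taken from docker-compose.yaml. Adjust both to your deployment.

```python
import requests

# Hypothetical request; host, port and route are assumptions, the field names come from the table above.
resp = requests.post(
    "http://127.0.0.1:23456/voice/vits",
    data={
        "text": "你好,こんにちは",  # mixed Chinese/Japanese input
        "id": 0,
        "format": "wav",
        "lang": "auto",  # let the server detect the language of the whole text
        "length": 1.0,
    },
)
with open("output.wav", "wb") as f:
    f.write(resp.content)
```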
api_test.py CHANGED
@@ -396,9 +396,9 @@ if __name__ == '__main__':
      ssml = """
      <speak lang="zh" format="mp3" length="1.2">
          <voice id="92" >这几天心里颇不宁静。</voice>
-         <voice id="125">今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。</voice>
+         <voice id="0" model_type="Bert-VITS2">今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。</voice>
          <voice id="142">月亮渐渐地升高了,墙外马路上孩子们的欢笑,已经听不见了;</voice>
-         <voice id="98">妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。</voice>
+         <voice id="0" model_type="Bert-VITS2">妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。</voice>
          <voice id="120">我悄悄地披了大衫,带上门出去。</voice><break time="2s"/>
          <voice id="121">沿着荷塘,是一条曲折的小煤屑路。</voice>
          <voice id="122">这是一条幽僻的路;白天也少人走,夜晚更加寂寞。</voice>
app.py CHANGED
@@ -1,6 +1,8 @@
  import os
  import time
  import uuid
+
+ from contants import ModelType
  from logger import logger
  from flask import Flask, request, send_file, jsonify, make_response, render_template
  from werkzeug.utils import secure_filename
@@ -85,28 +87,29 @@ def voice_vits_api():
          max = int(request_data.get("max", app.config.get("MAX", 50)))
          use_streaming = request_data.get('streaming', False, type=bool)
      except Exception as e:
-         logger.error(f"[VITS] {e}")
+         logger.error(f"[{ModelType.VITS.value}] {e}")
          return make_response("parameter error", 400)

-     logger.info(f"[VITS] id:{id} format:{format} lang:{lang} length:{length} noise:{noise} noisew:{noisew}")
-     logger.info(f"[VITS] len:{len(text)} text:{text}")
+     logger.info(
+         f"[{ModelType.VITS.value}] id:{id} format:{format} lang:{lang} length:{length} noise:{noise} noisew:{noisew}")
+     logger.info(f"[{ModelType.VITS.value}] len:{len(text)} text:{text}")

      if check_is_none(text):
-         logger.info(f"[VITS] text is empty")
+         logger.info(f"[{ModelType.VITS.value}] text is empty")
          return make_response(jsonify({"status": "error", "message": "text is empty"}), 400)

      if check_is_none(id):
-         logger.info(f"[VITS] speaker id is empty")
+         logger.info(f"[{ModelType.VITS.value}] speaker id is empty")
          return make_response(jsonify({"status": "error", "message": "speaker id is empty"}), 400)

      if id < 0 or id >= tts.vits_speakers_count:
-         logger.info(f"[VITS] speaker id {id} does not exist")
+         logger.info(f"[{ModelType.VITS.value}] speaker id {id} does not exist")
          return make_response(jsonify({"status": "error", "message": f"id {id} does not exist"}), 400)

      # 校验模型是否支持输入的语言
-     speaker_lang = tts.voice_speakers["VITS"][id].get('lang')
+     speaker_lang = tts.voice_speakers[ModelType.VITS.value][id].get('lang')
      if lang not in ["auto", "mix"] and len(speaker_lang) != 1 and lang not in speaker_lang:
-         logger.info(f"[VITS] lang \"{lang}\" is not in {speaker_lang}")
+         logger.info(f"[{ModelType.VITS.value}] lang \"{lang}\" is not in {speaker_lang}")
          return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)

      # 如果配置文件中设置了LANGUAGE_AUTOMATIC_DETECT则强制将speaker_lang设置为LANGUAGE_AUTOMATIC_DETECT
@@ -139,10 +142,10 @@ def voice_vits_api():
      t1 = time.time()
      audio = tts.vits_infer(task)
      t2 = time.time()
-     logger.info(f"[VITS] finish in {(t2 - t1):.2f}s")
+     logger.info(f"[{ModelType.VITS.value}] finish in {(t2 - t1):.2f}s")

      if app.config.get("SAVE_AUDIO", False):
-         logger.debug(f"[VITS] {fname}")
+         logger.debug(f"[{ModelType.VITS.value}] {fname}")
          path = os.path.join(app.config.get('CACHE_PATH'), fname)
          save_audio(audio.getvalue(), path)

@@ -162,20 +165,21 @@ def voice_hubert_api():
          noisew = float(request.form.get("noisew", app.config.get("NOISEW", 0.8)))
          use_streaming = request.form.get('streaming', False, type=bool)
      except Exception as e:
-         logger.error(f"[hubert] {e}")
+         logger.error(f"[{ModelType.HUBERT_VITS.value}] {e}")
          return make_response("parameter error", 400)

-     logger.info(f"[hubert] id:{id} format:{format} length:{length} noise:{noise} noisew:{noisew}")
+     logger.info(
+         f"[{ModelType.HUBERT_VITS.value}] id:{id} format:{format} length:{length} noise:{noise} noisew:{noisew}")

      fname = secure_filename(str(uuid.uuid1()) + "." + voice.filename.split(".")[1])
      voice.save(os.path.join(app.config['UPLOAD_FOLDER'], fname))

      if check_is_none(id):
-         logger.info(f"[hubert] speaker id is empty")
+         logger.info(f"[{ModelType.HUBERT_VITS.value}] speaker id is empty")
          return make_response(jsonify({"status": "error", "message": "speaker id is empty"}), 400)

      if id < 0 or id >= tts.hubert_speakers_count:
-         logger.info(f"[hubert] speaker id {id} does not exist")
+         logger.info(f"[{ModelType.HUBERT_VITS.value}] speaker id {id} does not exist")
          return make_response(jsonify({"status": "error", "message": f"id {id} does not exist"}), 400)

      file_type = f"audio/{format}"
@@ -189,10 +193,10 @@ def voice_hubert_api():
      t1 = time.time()
      audio = tts.hubert_vits_infer(task)
      t2 = time.time()
-     logger.info(f"[hubert] finish in {(t2 - t1):.2f}s")
+     logger.info(f"[{ModelType.HUBERT_VITS.value}] finish in {(t2 - t1):.2f}s")

      if app.config.get("SAVE_AUDIO", False):
-         logger.debug(f"[hubert] {fname}")
+         logger.debug(f"[{ModelType.HUBERT_VITS.value}] {fname}")
          path = os.path.join(app.config.get('CACHE_PATH'), fname)
          save_audio(audio.getvalue(), path)

@@ -230,29 +234,29 @@ def voice_w2v2_api():
          emotion = int(request_data.get("emotion", app.config.get("EMOTION", 0)))
          use_streaming = request_data.get('streaming', False, type=bool)
      except Exception as e:
-         logger.error(f"[w2v2] {e}")
+         logger.error(f"[{ModelType.W2V2_VITS.value}] {e}")
          return make_response(f"parameter error", 400)

-     logger.info(f"[w2v2] id:{id} format:{format} lang:{lang} "
+     logger.info(f"[{ModelType.W2V2_VITS.value}] id:{id} format:{format} lang:{lang} "
                  f"length:{length} noise:{noise} noisew:{noisew} emotion:{emotion}")
-     logger.info(f"[w2v2] len:{len(text)} text:{text}")
+     logger.info(f"[{ModelType.W2V2_VITS.value}] len:{len(text)} text:{text}")

      if check_is_none(text):
-         logger.info(f"[w2v2] text is empty")
+         logger.info(f"[{ModelType.W2V2_VITS.value}] text is empty")
          return make_response(jsonify({"status": "error", "message": "text is empty"}), 400)

      if check_is_none(id):
-         logger.info(f"[w2v2] speaker id is empty")
+         logger.info(f"[{ModelType.W2V2_VITS.value}] speaker id is empty")
          return make_response(jsonify({"status": "error", "message": "speaker id is empty"}), 400)

      if id < 0 or id >= tts.w2v2_speakers_count:
-         logger.info(f"[w2v2] speaker id {id} does not exist")
+         logger.info(f"[{ModelType.W2V2_VITS.value}] speaker id {id} does not exist")
          return make_response(jsonify({"status": "error", "message": f"id {id} does not exist"}), 400)

      # 校验模型是否支持输入的语言
-     speaker_lang = tts.voice_speakers["W2V2-VITS"][id].get('lang')
+     speaker_lang = tts.voice_speakers[ModelType.W2V2_VITS.value][id].get('lang')
      if lang not in ["auto", "mix"] and len(speaker_lang) != 1 and lang not in speaker_lang:
-         logger.info(f"[w2v2] lang \"{lang}\" is not in {speaker_lang}")
+         logger.info(f"[{ModelType.W2V2_VITS.value}] lang \"{lang}\" is not in {speaker_lang}")
          return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)

      # 如果配置文件中设置了LANGUAGE_AUTOMATIC_DETECT则强制将speaker_lang设置为LANGUAGE_AUTOMATIC_DETECT
@@ -279,10 +283,10 @@ def voice_w2v2_api():
      t1 = time.time()
      audio = tts.w2v2_vits_infer(task)
      t2 = time.time()
-     logger.info(f"[w2v2] finish in {(t2 - t1):.2f}s")
+     logger.info(f"[{ModelType.W2V2_VITS.value}] finish in {(t2 - t1):.2f}s")

      if app.config.get("SAVE_AUDIO", False):
-         logger.debug(f"[w2v2] {fname}")
+         logger.debug(f"[{ModelType.W2V2_VITS.value}] {fname}")
          path = os.path.join(app.config.get('CACHE_PATH'), fname)
          save_audio(audio.getvalue(), path)

@@ -362,7 +366,7 @@ def ssml_api():
      file_type = f"audio/{format}"

      t1 = time.time()
-     audio = tts.create_ssml_infer_task(voice_tasks, format)
+     audio = tts.process_ssml_infer_task(voice_tasks, format)
      t2 = time.time()
      logger.info(f"[ssml] finish in {(t2 - t1):.2f}s")

@@ -423,29 +427,29 @@ def voice_bert_vits2_api():
          sdp_ratio = float(request_data.get("sdp_ratio", app.config.get("SDP_RATIO", 0.2)))
          max = int(request_data.get("max", app.config.get("MAX", 50)))
      except Exception as e:
-         logger.error(f"[Bert-VITS2] {e}")
+         logger.error(f"[{ModelType.BERT_VITS2.value}] {e}")
          return make_response("parameter error", 400)

      logger.info(
-         f"[Bert-VITS2] id:{id} format:{format} lang:{lang} length:{length} noise:{noise} noisew:{noisew} sdp_ratio:{sdp_ratio}")
-     logger.info(f"[Bert-VITS2] len:{len(text)} text:{text}")
+         f"[{ModelType.BERT_VITS2.value}] id:{id} format:{format} lang:{lang} length:{length} noise:{noise} noisew:{noisew} sdp_ratio:{sdp_ratio}")
+     logger.info(f"[{ModelType.BERT_VITS2.value}] len:{len(text)} text:{text}")

      if check_is_none(text):
-         logger.info(f"[Bert-VITS2] text is empty")
+         logger.info(f"[{ModelType.BERT_VITS2.value}] text is empty")
          return make_response(jsonify({"status": "error", "message": "text is empty"}), 400)

      if check_is_none(id):
-         logger.info(f"[Bert-VITS2] speaker id is empty")
+         logger.info(f"[{ModelType.BERT_VITS2.value}] speaker id is empty")
          return make_response(jsonify({"status": "error", "message": "speaker id is empty"}), 400)

      if id < 0 or id >= tts.bert_vits2_speakers_count:
-         logger.info(f"[Bert-VITS2] speaker id {id} does not exist")
+         logger.info(f"[{ModelType.BERT_VITS2.value}] speaker id {id} does not exist")
          return make_response(jsonify({"status": "error", "message": f"id {id} does not exist"}), 400)

      # 校验模型是否支持输入的语言
-     speaker_lang = tts.voice_speakers["BERT-VITS2"][id].get('lang')
+     speaker_lang = tts.voice_speakers[ModelType.BERT_VITS2.value][id].get('lang')
      if lang not in ["auto", "mix"] and len(speaker_lang) != 1 and lang not in speaker_lang:
-         logger.info(f"[Bert-VITS2] lang \"{lang}\" is not in {speaker_lang}")
+         logger.info(f"[{ModelType.BERT_VITS2.value}] lang \"{lang}\" is not in {speaker_lang}")
          return make_response(jsonify({"status": "error", "message": f"lang '{lang}' is not in {speaker_lang}"}), 400)

      # 如果配置文件中设置了LANGUAGE_AUTOMATIC_DETECT则强制将speaker_lang设置为LANGUAGE_AUTOMATIC_DETECT
@@ -468,10 +472,10 @@ def voice_bert_vits2_api():
      t1 = time.time()
      audio = tts.bert_vits2_infer(task)
      t2 = time.time()
-     logger.info(f"[Bert-VITS2] finish in {(t2 - t1):.2f}s")
+     logger.info(f"[{ModelType.BERT_VITS2.value}] finish in {(t2 - t1):.2f}s")

      if app.config.get("SAVE_AUDIO", False):
-         logger.debug(f"[Bert-VITS2] {fname}")
+         logger.debug(f"[{ModelType.BERT_VITS2.value}] {fname}")
          path = os.path.join(app.config.get('CACHE_PATH'), fname)
          save_audio(audio.getvalue(), path)

@@ -490,18 +494,18 @@ def check():
          else:
              request_data = request.form

-         model = request_data.get("model")
+         model_type_str = request_data.get("model_type", request_data.get("model")).upper()
          id = int(request_data.get("id"))
      except Exception as e:
          logger.info(f"[check] {e}")
          return make_response(jsonify({"status": "error", "message": "parameter error"}), 400)

-     if check_is_none(model):
-         logger.info(f"[check] model {model} is empty")
+     if check_is_none(model_type_str):
+         logger.info(f"[check] model {model_type_str} is empty")
          return make_response(jsonify({"status": "error", "message": "model is empty"}), 400)

-     if model.upper() not in ("VITS", "HUBERT", "W2V2"):
-         res = make_response(jsonify({"status": "error", "message": f"model {model} does not exist"}))
+     if model_type_str not in ModelType._value2member_map_:
+         res = make_response(jsonify({"status": "error", "message": f"model {model_type_str} does not exist"}))
          res.status = 404
          logger.info(f"[check] speaker id {id} error")
          return res
@@ -510,16 +514,12 @@ def check():
          logger.info(f"[check] speaker id is empty")
          return make_response(jsonify({"status": "error", "message": "speaker id is empty"}), 400)

-     if model.upper() == "VITS":
-         speaker_list = tts.voice_speakers["VITS"]
-     elif model.upper() == "HUBERT":
-         speaker_list = tts.voice_speakers["HUBERT-VITS"]
-     elif model.upper() == "W2V2":
-         speaker_list = tts.voice_speakers["W2V2-VITS"]
+     model_type = ModelType(model_type_str)
+     speaker_list = tts.voice_speakers[model_type.value]

      if len(speaker_list) == 0:
-         logger.info(f"[check] {model} not loaded")
-         return make_response(jsonify({"status": "error", "message": f"{model} not loaded"}), 400)
+         logger.info(f"[check] {model_type_str} not loaded")
+         return make_response(jsonify({"status": "error", "message": f"{model_type_str} not loaded"}), 400)

      if id < 0 or id >= len(speaker_list):
          logger.info(f"[check] speaker id {id} does not exist")
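The reworked check() route above accepts either a model_type field or the older model field and validates it against the new ModelType enum. A minimal, self-contained sketch of that lookup (the request_data dict stands in for the Flask request; ModelType is re-declared here only so the snippet runs on its own):

```python
from enum import Enum


class ModelType(Enum):  # mirrors contants.py in this commit
    VITS = "VITS"
    HUBERT_VITS = "HUBERT-VITS"
    W2V2_VITS = "W2V2-VITS"
    BERT_VITS2 = "BERT-VITS2"


def resolve_model_type(request_data: dict) -> ModelType:
    # Prefer "model_type", fall back to the legacy "model" field, as check() does.
    model_type_str = request_data.get("model_type", request_data.get("model", "")).upper()
    # _value2member_map_ maps raw values such as "BERT-VITS2" to enum members.
    if model_type_str not in ModelType._value2member_map_:
        raise ValueError(f"model {model_type_str} does not exist")
    return ModelType(model_type_str)


print(resolve_model_type({"model": "vits"}))             # ModelType.VITS
print(resolve_model_type({"model_type": "BERT-VITS2"}))  # ModelType.BERT_VITS2
```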
bert_vits2/bert_vits2.py CHANGED
@@ -6,19 +6,39 @@ from bert_vits2 import utils as bert_vits2_utils
  from bert_vits2.models import SynthesizerTrn
  from bert_vits2.text import *
  from bert_vits2.text.cleaner import clean_text
+ from bert_vits2.utils import process_legacy_versions
+ from contants import ModelType
  from utils import classify_language, get_hparams_from_file, lang_dict
  from utils.sentence import sentence_split_and_markup, cut


  class Bert_VITS2:
      def __init__(self, model, config, device=torch.device("cpu"), **kwargs):
-         self.hps_ms = get_hparams_from_file(config)
+         self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
          self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
          self.speakers = [item[0] for item in
                          sorted(list(getattr(self.hps_ms.data, 'spk2id', {'0': 0}).items()), key=lambda x: x[1])]
+         self.symbols = symbols
+
+         # Compatible with legacy versions
+         self.version = process_legacy_versions(self.hps_ms)
+
+         if self.version in ["1.0", "1.0.0", "1.0.1"]:
+             self.symbols = symbols_legacy
+             self.hps_ms.model.n_layers_trans_flow = 3
+
+         elif self.version in ["1.1.0-transition"]:
+             self.hps_ms.model.n_layers_trans_flow = 3
+
+         elif self.version in ["1.1", "1.1.0", "1.1.1"]:
+             self.hps_ms.model.n_layers_trans_flow = 6
+
+         key = f"{ModelType.BERT_VITS2.value}_v{self.version}" if self.version else ModelType.BERT_VITS2.value
+         self.lang = lang_dict.get(key, ["unknown"])
+         self.bert_handler = BertHandler(self.lang)

-         self.legacy = getattr(self.hps_ms.data, 'legacy', False)
-         self.symbols = symbols_legacy if self.legacy else symbols
          self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}

          self.net_g = SynthesizerTrn(
@@ -33,11 +53,15 @@
          self.load_model(model)

      def load_model(self, model):
-         bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True)
+         bert_vits2_utils.load_checkpoint(model, self.net_g, None, skip_optimizer=True, version=self.version)

      def get_speakers(self):
          return self.speakers

+     @property
+     def sampling_rate(self):
+         return self.hps_ms.data.sampling_rate
+
      def get_text(self, text, language_str, hps):
          norm_text, phone, tone, word2ph = clean_text(text, language_str)
          phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, self._symbol_to_id)
@@ -49,7 +73,7 @@
          for i in range(len(word2ph)):
              word2ph[i] = word2ph[i] * 2
          word2ph[0] += 1
-         bert = get_bert(norm_text, word2ph, language_str)
+         bert = self.bert_handler.get_bert(norm_text, word2ph, language_str)
          del word2ph
          assert bert.shape[-1] == len(phone), phone

@@ -98,7 +122,8 @@
          max = voice.get("max", 50)
          # sentence_list = sentence_split_and_markup(text, max, "ZH", ["zh"])
          if lang == "auto":
-             lang = classify_language(text, target_languages=lang_dict["bert_vits2"])
+             lang = classify_language(text, target_languages=self.lang)
+
          sentence_list = cut(text, max)
          audios = []
          for sentence in sentence_list:
bert_vits2/models.py CHANGED
@@ -26,9 +26,11 @@ class DurationDiscriminator(nn.Module): # vits2

          self.drop = nn.Dropout(p_dropout)
          self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-         # self.norm_1 = modules.LayerNorm(filter_channels)
-         self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
-         # self.norm_2 = modules.LayerNorm(filter_channels)
+         self.norm_1 = modules.LayerNorm(filter_channels)
+         self.conv_2 = nn.Conv1d(
+             filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
+         )
+         self.norm_2 = modules.LayerNorm(filter_channels)
          self.dur_proj = nn.Conv1d(1, filter_channels, 1)

          self.pre_out_conv_1 = nn.Conv1d(2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
@@ -36,8 +38,8 @@ class DurationDiscriminator(nn.Module): # vits2
          self.pre_out_conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
          self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

-         # if gin_channels != 0:
-         #     self.cond = nn.Conv1d(gin_channels, in_channels, 1)
+         if gin_channels != 0:
+             self.cond = nn.Conv1d(gin_channels, in_channels, 1)

          self.output_layer = nn.Sequential(
              nn.Linear(filter_channels, 1),
@@ -48,13 +50,13 @@ class DurationDiscriminator(nn.Module): # vits2
          dur = self.dur_proj(dur)
          x = torch.cat([x, dur], dim=1)
          x = self.pre_out_conv_1(x * x_mask)
-         # x = torch.relu(x)
-         # x = self.pre_out_norm_1(x)
-         # x = self.drop(x)
+         x = torch.relu(x)
+         x = self.pre_out_norm_1(x)
+         x = self.drop(x)
          x = self.pre_out_conv_2(x * x_mask)
-         # x = torch.relu(x)
-         # x = self.pre_out_norm_2(x)
-         # x = self.drop(x)
+         x = torch.relu(x)
+         x = self.pre_out_norm_2(x)
+         x = self.drop(x)
          x = x * x_mask
          x = x.transpose(1, 2)
          output_prob = self.output_layer(x)
@@ -62,17 +64,17 @@ class DurationDiscriminator(nn.Module): # vits2

      def forward(self, x, x_mask, dur_r, dur_hat, g=None):
          x = torch.detach(x)
-         # if g is not None:
-         #     g = torch.detach(g)
-         #     x = x + self.cond(g)
+         if g is not None:
+             g = torch.detach(g)
+             x = x + self.cond(g)
          x = self.conv_1(x * x_mask)
-         # x = torch.relu(x)
-         # x = self.norm_1(x)
-         # x = self.drop(x)
+         x = torch.relu(x)
+         x = self.norm_1(x)
+         x = self.drop(x)
          x = self.conv_2(x * x_mask)
-         # x = torch.relu(x)
-         # x = self.norm_2(x)
-         # x = self.drop(x)
+         x = torch.relu(x)
+         x = self.norm_2(x)
+         x = self.drop(x)

          output_probs = []
          for dur in [dur_r, dur_hat]:
@@ -590,7 +592,7 @@ class SynthesizerTrn(nn.Module):
                   gin_channels=256,
                   use_sdp=True,
                   n_flow_layer=4,
-                  n_layers_trans_flow=3,
+                  n_layers_trans_flow=6,
                   flow_share_parameter=False,
                   use_transformer_flow=True,
                   **kwargs):
bert_vits2/text/__init__.py CHANGED
@@ -1,7 +1,5 @@
  from bert_vits2.text.symbols import *
- from .chinese_bert import get_bert_feature as zh_bert
- from .english_bert_mock import get_bert_feature as en_bert
- from .japanese_bert import get_bert_feature as ja_bert
+ from bert_vits2.text.bert_handler import BertHandler


  def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id):
@@ -17,9 +15,3 @@ def cleaned_text_to_sequence(cleaned_text, tones, language, _symbol_to_id):
      lang_id = language_id_map[language]
      lang_ids = [lang_id for i in phones]
      return phones, tones, lang_ids
-
-
- def get_bert(norm_text, word2ph, language):
-     lang_bert_func_map = {"zh": zh_bert, "en": en_bert, "ja": ja_bert}
-     bert = lang_bert_func_map[language](norm_text, word2ph)
-     return bert
bert_vits2/text/bert_handler.py ADDED
@@ -0,0 +1,33 @@
+ import importlib
+
+
+ class BertHandler:
+     _bert_functions = {}
+
+     BERT_IMPORT_MAP = {
+         "zh": "bert_vits2.text.chinese_bert.get_bert_feature",
+         "en": "bert_vits2.text.english_bert_mock.get_bert_feature",
+         "ja": "bert_vits2.text.japanese_bert.get_bert_feature",
+     }
+
+     def __init__(self, languages):
+         for lang in languages:
+             if lang not in BertHandler._bert_functions:
+                 self.load_bert_function(lang)
+
+     def load_bert_function(self, language):
+         if language not in BertHandler.BERT_IMPORT_MAP:
+             raise ValueError(f"Unsupported language: {language}")
+
+         module_path, function_name = BertHandler.BERT_IMPORT_MAP[language].rsplit('.', 1)
+         module = importlib.import_module(module_path, package=__package__)
+         bert_function = getattr(module, function_name)
+
+         BertHandler._bert_functions[language] = bert_function
+
+     def get_bert(self, norm_text, word2ph, language):
+         if language not in BertHandler._bert_functions:
+             raise ValueError(f"BERT for {language} has not been initialized. Please initialize first.")
+
+         bert_func = BertHandler._bert_functions[language]
+         return bert_func(norm_text, word2ph)
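BertHandler defers importing each language's BERT front end until that language is actually requested, so unused (and heavy) models never get loaded. A self-contained sketch of the same lazy-import pattern, using standard-library functions as stand-ins for the real get_bert_feature implementations:

```python
import importlib

# Dotted paths are resolved to callables only on first use, like BERT_IMPORT_MAP above.
_IMPORT_MAP = {
    "dumps": "json.dumps",  # stand-ins for the per-language bert functions
    "sqrt": "math.sqrt",
}
_loaded = {}


def get_function(name):
    if name not in _IMPORT_MAP:
        raise ValueError(f"Unsupported name: {name}")
    if name not in _loaded:
        module_path, attr = _IMPORT_MAP[name].rsplit(".", 1)
        module = importlib.import_module(module_path)  # the import happens here, once
        _loaded[name] = getattr(module, attr)
    return _loaded[name]


print(get_function("sqrt")(2.0))        # math is imported only at this point
print(get_function("dumps")({"a": 1}))  # json likewise
```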
bert_vits2/text/cleaner.py CHANGED
@@ -1,20 +1,34 @@
- from bert_vits2.text import chinese, japanese, cleaned_text_to_sequence
+ import importlib
+ from bert_vits2.text import cleaned_text_to_sequence

  language_module_map = {
-     'zh': chinese,
-     'ja': japanese
+     'zh': "bert_vits2.text.chinese",
+     'ja': "bert_vits2.text.japanese"
  }

+ _loaded_modules = {}
+
+
+ def get_language_module(language):
+     if language not in _loaded_modules:
+         module_path = language_module_map.get(language)
+         if not module_path:
+             raise ValueError(f"Unsupported language: {language}")
+
+         _loaded_modules[language] = importlib.import_module(module_path)
+
+     return _loaded_modules[language]
+

  def clean_text(text, language):
-     language_module = language_module_map[language]
+     language_module = get_language_module(language)
      norm_text = language_module.text_normalize(text)
      phones, tones, word2ph = language_module.g2p(norm_text)
      return norm_text, phones, tones, word2ph


  def clean_text_bert(text, language):
-     language_module = language_module_map[language]
+     language_module = get_language_module(language)
      norm_text = language_module.text_normalize(text)
      phones, tones, word2ph = language_module.g2p(norm_text)
      bert = language_module.get_bert_feature(norm_text, word2ph)
bert_vits2/utils.py CHANGED
@@ -9,7 +9,7 @@ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
  logger = logging


- def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
+ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False, version=None):
      assert os.path.isfile(checkpoint_path)
      checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
      iteration = checkpoint_dict['iteration']
@@ -36,12 +36,19 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
              new_state_dict[k] = saved_state_dict[k]
              assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
          except:
-             # For upgrading from the old version
+             # Handle legacy model versions and provide appropriate warnings
              if "ja_bert_proj" in k:
                  v = torch.zeros_like(v)
+                 if version is None:
+                     logger.error(f"{k} is not in the checkpoint")
+                     logger.warning(
+                         f"If you're using an older version of the model, consider adding the \"version\" parameter to the model's config.json under the \"data\" section. For instance: \"legacy_version\": \"1.0.1\"")
+             elif "flow.flows.0.enc.attn_layers.3" in k:
+                 logger.error(f"{k} is not in the checkpoint")
                  logger.warning(
-                     f"If you are using an older version of the model, you should add the parameter \"legacy\":true to the data of the model's config.json")
-             logger.error(f"{k} is not in the checkpoint")
+                     f"If you're using a transitional version, please add the \"version\": \"1.1.0-transition\" parameter within the \"data\" section of the model's config.json.")
+             else:
+                 logger.error(f"{k} is not in the checkpoint")
              new_state_dict[k] = v
      if hasattr(model, 'module'):
@@ -52,3 +59,12 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
      logger.info("Loaded checkpoint '{}' (iteration {})".format(
          checkpoint_path, iteration))
      return model, optimizer, learning_rate, iteration
+
+
+ def process_legacy_versions(hps):
+     version = getattr(hps, "version", getattr(hps.data, "version", None))
+     if version:
+         prefix = version[0].lower()
+         if prefix == "v":
+             version = version[1:]
+     return version
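process_legacy_versions above only reads a version field from the hyperparameters and strips a leading "v"/"V". A small standalone check of that normalization (the sample strings are illustrative, not an exhaustive list of supported versions):

```python
def normalize_version(version):
    # Mirrors process_legacy_versions: "v1.0.1" and "1.0.1" both become "1.0.1"; None stays None.
    if version and version[0].lower() == "v":
        version = version[1:]
    return version


assert normalize_version("v1.0.1") == "1.0.1"
assert normalize_version("1.1.0-transition") == "1.1.0-transition"
assert normalize_version(None) is None
```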
contants.py ADDED
@@ -0,0 +1,7 @@
+ from enum import Enum
+
+ class ModelType(Enum):
+     VITS = "VITS"
+     HUBERT_VITS = "HUBERT-VITS"
+     W2V2_VITS = "W2V2-VITS"
+     BERT_VITS2 = "BERT-VITS2"
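A quick note on how this enum is used in the rest of the commit: voice_obj in utils/load_model.py is keyed by the enum members themselves, while lang_dict and voice_speakers are keyed by the string .value. A toy sketch of both lookups (the dictionaries are illustrative stand-ins, not the project's real data):

```python
from enum import Enum


class ModelType(Enum):
    VITS = "VITS"
    BERT_VITS2 = "BERT-VITS2"


speakers_by_value = {ModelType.BERT_VITS2.value: ["speaker_a"]}  # string keys, e.g. "BERT-VITS2"
objs_by_member = {ModelType.BERT_VITS2: [object()]}              # enum-member keys

model_type = ModelType("BERT-VITS2")        # recover a member from its raw value
print(speakers_by_value[model_type.value])  # ['speaker_a']
print(len(objs_by_member[model_type]))      # 1
```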
docker-compose-gpu.yaml CHANGED
@@ -8,14 +8,16 @@ services:
        - 23456:23456
      environment:
        LANG: 'C.UTF-8'
-       #TZ: Asia/Shanghai #timezone
+       TZ: Asia/Shanghai #timezone
+     command: ['python', 'app.py'] # 目前用不了gunicorn多进程
      volumes:
        - ./Model:/app/Model # 挂载模型文件夹
        - ./config.py:/app/config.py # 挂载配置文件
        - ./logs:/app/logs # logging logs
        - ./gunicorn_config.py:/app/gunicorn_config.py # gunicorn configuration
        - ./vits/bert:/app/vits/bert # vits_chinese
-       - ./bert_vits2/bert/chinese-roberta-wwm-ext-large:/app/bert_vits2/bert/chinese-roberta-wwm-ext-large # Bert-vits2
+       - ./bert_vits2/bert/chinese-roberta-wwm-ext-large:/app/bert_vits2/bert/chinese-roberta-wwm-ext-large # Bert-vits2 chinese
+       - ./bert_vits2/bert/bert-base-japanese-v3:/app/bert_vits2/bert/bert-base-japanese-v3 # Bert-vits2 japanese
        - ./pyopenjtalk/open_jtalk_dic_utf_8-1.11:/usr/local/lib/python3.10/site-packages/pyopenjtalk/open_jtalk_dic_utf_8-1.11 #pyopentjalk
      deploy:
        resources:
docker-compose.yaml CHANGED
@@ -14,5 +14,6 @@ services:
        - ./logs:/app/logs # logging logs
        - ./gunicorn_config.py:/app/gunicorn_config.py # gunicorn configuration
        - ./vits/bert:/app/vits/bert # vits_chinese
-       - ./bert_vits2/bert/chinese-roberta-wwm-ext-large:/app/bert_vits2/bert/chinese-roberta-wwm-ext-large # Bert-vits2
+       - ./bert_vits2/bert/chinese-roberta-wwm-ext-large:/app/bert_vits2/bert/chinese-roberta-wwm-ext-large # Bert-vits2 chinese
+       - ./bert_vits2/bert/bert-base-japanese-v3:/app/bert_vits2/bert/bert-base-japanese-v3 # Bert-vits2 japanese
        - ./pyopenjtalk/open_jtalk_dic_utf_8-1.11:/usr/local/lib/python3.10/site-packages/pyopenjtalk/open_jtalk_dic_utf_8-1.11 #pyopentjalk
logger.py CHANGED
@@ -22,10 +22,8 @@ handler = TimedRotatingFileHandler(log_file, when="midnight", interval=1, backup
  handler.suffix = "%Y-%m-%d.log"
  formatter = logging.Formatter('%(levelname)s:%(name)s %(message)s')
  handler.setFormatter(formatter)
- logger.addHandler(handler)

- logging.getLogger("werkzeug").addHandler(handler)
- logging.getLogger("apscheduler.scheduler").addHandler(handler)
+ logging.getLogger().addHandler(handler)


  # Custom function to handle uncaught exceptions
utils/classify_language.py CHANGED
@@ -2,16 +2,27 @@ from config import LANGUAGE_IDENTIFICATION_LIBRARY

  module = LANGUAGE_IDENTIFICATION_LIBRARY.lower()

+ langid_languages = ["af", "am", "an", "ar", "as", "az", "be", "bg", "bn", "br", "bs", "ca", "cs", "cy", "da", "de", "dz", "el",
+                     "en", "eo", "es", "et", "eu", "fa", "fi", "fo", "fr", "ga", "gl", "gu", "he", "hi", "hr", "ht", "hu", "hy",
+                     "id", "is", "it", "ja", "jv", "ka", "kk", "km", "kn", "ko", "ku", "ky", "la", "lb", "lo", "lt", "lv", "mg",
+                     "mk", "ml", "mn", "mr", "ms", "mt", "nb", "ne", "nl", "nn", "no", "oc", "or", "pa", "pl", "ps", "pt", "qu",
+                     "ro", "ru", "rw", "se", "si", "sk", "sl", "sq", "sr", "sv", "sw", "ta", "te", "th", "tl", "tr", "ug", "uk",
+                     "ur", "vi", "vo", "wa", "xh", "zh", "zu"]
+

  def classify_language(text: str, target_languages: list = None) -> str:
      if module == "fastlid" or module == "fasttext":
-         from fastlid import fastlid
+         from fastlid import fastlid, supported_langs
          classifier = fastlid
-         if target_languages != None: fastlid.set_languages = target_languages
+         if target_languages != None:
+             target_languages = [lang for lang in target_languages if lang in supported_langs]
+             fastlid.set_languages = target_languages
      elif module == "langid":
          import langid
          classifier = langid.classify
-         if target_languages != None: langid.set_languages(target_languages)
+         if target_languages != None:
+             target_languages = [lang for lang in target_languages if lang in langid_languages]
+             langid.set_languages(target_languages)
      else:
          raise ValueError(f"Wrong LANGUAGE_IDENTIFICATION_LIBRARY in config.py")
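The change above narrows the candidate languages before calling the identifier, so a model that only supports zh/ja can never be classified as, say, en. A rough usage sketch with langid (assuming the langid package is installed; the printed result depends on the classifier, so treat it as illustrative):

```python
import langid

langid_supported = {"zh", "ja", "en", "ko"}  # small illustrative subset of langid's languages
target_languages = ["zh", "ja", "xx"]        # what the current model supports ("xx" is bogus)

# Keep only targets the classifier actually knows about, as classify_language now does.
target_languages = [lang for lang in target_languages if lang in langid_supported]
langid.set_languages(target_languages)

print(langid.classify("こんにちは、今日はいい天気ですね。")[0])  # expected: "ja"
```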
28
 
utils/download.py CHANGED
@@ -1,3 +1,4 @@
+ import logging
  import os
  import hashlib
  import tarfile
@@ -18,6 +19,7 @@ class TqdmUpTo(tqdm):


  def download_file(url, dest_path):
+     logging.info(f"Downloading: {url}")
      with TqdmUpTo(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=url.split('/')[-1]) as t:
          urllib.request.urlretrieve(url, dest_path, reporthook=t.update_to)

@@ -37,6 +39,8 @@ def extract_file(file_path, destination=None):
      if destination is None:
          destination = Path(file_path).parent

+     logging.info(f"Extracting to {destination}")
+
      if file_path.endswith('.zip'):
          with zipfile.ZipFile(file_path, 'r') as zip_ref:
              zip_ref.extractall(destination)
@@ -50,7 +54,7 @@ def extract_file(file_path, destination=None):
          with SevenZipFile(file_path, mode='r') as z:
              z.extractall(destination)
      else:
-         print(f"Unsupported compression format for file {file_path}")
+         logging.error(f"Unsupported compression format for file {file_path}")


  def download_and_verify(urls, target_path, expected_md5=None, extract_destination=None):
utils/lang_dict.py CHANGED
@@ -1,3 +1,5 @@
+ from contants import ModelType
+
  lang_dict = {
      "english_cleaners": ["en"],
      "english_cleaners2": ["en"],
@@ -18,5 +20,12 @@ lang_dict = {
                            "ZR", "PH", "TX", "JS", "HN", "LP", "XS", "FY", "RA", "CX", "SM", "TT", "WZ", "SC",
                            "YB"],
      "bert_chinese_cleaners": ["zh"],
-     "bert_vits2": ["zh", "ja"]
+     ModelType.BERT_VITS2.value: ["zh", "ja"],
+     f"{ModelType.BERT_VITS2.value}_v1.0": ["zh"],
+     f"{ModelType.BERT_VITS2.value}_v1.0.0": ["zh"],
+     f"{ModelType.BERT_VITS2.value}_v1.0.1": ["zh"],
+     f"{ModelType.BERT_VITS2.value}_v1.1": ["zh", "ja"],
+     f"{ModelType.BERT_VITS2.value}_v1.1.0": ["zh", "ja"],
+     f"{ModelType.BERT_VITS2.value}_v1.1.0-transition": ["zh", "ja"],
+     f"{ModelType.BERT_VITS2.value}_v1.1.1": ["zh", "ja"],
  }
utils/load_model.py CHANGED
@@ -3,30 +3,33 @@ import json
3
  import logging
4
  import config
5
  import numpy as np
6
- from utils.data_utils import check_is_none
 
 
7
  from vits import VITS
8
  from voice import TTS
9
  from config import DEVICE as device
10
  from utils.lang_dict import lang_dict
 
11
 
12
 
 
 
 
 
 
13
 
14
- def recognition_model_type(model_config_json: json) -> str:
15
- model_config = json.load(model_config_json)
16
- symbols = model_config.get("symbols", None)
17
- emotion_embedding = model_config.get("data").get("emotion_embedding", False)
18
-
19
- if "use_spk_conditioned_encoder" in model_config.get("model"):
20
- model_type = "bert_vits2"
21
  return model_type
22
 
23
  if symbols != None:
24
  if not emotion_embedding:
25
- mode_type = "vits"
26
  else:
27
- mode_type = "w2v2"
28
  else:
29
- mode_type = "hubert"
30
 
31
  return mode_type
32
 
@@ -72,16 +75,19 @@ def load_npy(emotion_reference_npy):
72
 
73
  def parse_models(model_list):
74
  categorized_models = {
75
- "vits": [],
76
- "hubert": [],
77
- "w2v2": [],
78
- "bert_vits2": []
79
  }
80
 
81
  for model_info in model_list:
82
  config_path = model_info[1]
83
- with open(config_path, 'r', encoding='utf-8') as model_config:
84
- model_type = recognition_model_type(model_config)
 
 
 
85
  if model_type in categorized_models:
86
  categorized_models[model_type].append(model_info)
87
 
@@ -93,23 +99,26 @@ def merge_models(model_list, model_class, model_type, additional_arg=None):
93
  speakers = []
94
  new_id = 0
95
 
96
- for obj_id, (model_path, config_path) in enumerate(model_list):
97
  obj_args = {
98
  "model": model_path,
99
- "config": config_path,
100
  "model_type": model_type,
101
  "device": device
102
  }
 
 
 
 
 
 
 
 
103
  if additional_arg:
104
  obj_args.update(additional_arg)
105
 
106
  obj = model_class(**obj_args)
107
 
108
- if model_type == "bert_vits2":
109
- key = model_type
110
- else:
111
- key = obj.get_cleaner()
112
-
113
  lang = lang_dict.get(key, ["unknown"])
114
 
115
  for real_id, name in enumerate(obj.get_speakers()):
@@ -124,11 +133,11 @@ def load_model(model_list) -> TTS:
124
  categorized_models = parse_models(model_list)
125
 
126
  # Handle VITS
127
- vits_objs, vits_speakers = merge_models(categorized_models["vits"], VITS, "vits")
128
 
129
  # Handle HUBERT-VITS
130
  hubert_vits_objs, hubert_vits_speakers = [], []
131
- if len(categorized_models["hubert"]) != 0:
132
  if getattr(config, "HUBERT_SOFT_MODEL", None) is None or check_is_none(config.HUBERT_SOFT_MODEL):
133
  raise ValueError(f"Please configure HUBERT_SOFT_MODEL path in config.py")
134
  try:
@@ -137,13 +146,13 @@ def load_model(model_list) -> TTS:
137
  except Exception as e:
138
  raise ValueError(f"Load HUBERT_SOFT_MODEL failed {e}")
139
 
140
- hubert_vits_objs, hubert_vits_speakers = merge_models(categorized_models["hubert"], VITS, "hubert",
141
  additional_arg={"additional_model": hubert})
142
 
143
  # Handle W2V2-VITS
144
  w2v2_vits_objs, w2v2_vits_speakers = [], []
145
  w2v2_emotion_count = 0
146
- if len(categorized_models["w2v2"]) != 0:
147
  if getattr(config, "DIMENSIONAL_EMOTION_NPY", None) is None or check_is_none(
148
  config.DIMENSIONAL_EMOTION_NPY):
149
  raise ValueError(f"Please configure DIMENSIONAL_EMOTION_NPY path in config.py")
@@ -153,20 +162,24 @@ def load_model(model_list) -> TTS:
153
  emotion_reference = None
154
  raise ValueError(f"Load DIMENSIONAL_EMOTION_NPY failed {e}")
155
 
156
- w2v2_vits_objs, w2v2_vits_speakers = merge_models(categorized_models["w2v2"], VITS, "w2v2",
157
  additional_arg={"additional_model": emotion_reference})
158
  w2v2_emotion_count = len(emotion_reference) if emotion_reference is not None else 0
159
 
160
  # Handle BERT-VITS2
161
  bert_vits2_objs, bert_vits2_speakers = [], []
162
- if len(categorized_models["bert_vits2"]) != 0:
163
  from bert_vits2 import Bert_VITS2
164
- bert_vits2_objs, bert_vits2_speakers = merge_models(categorized_models["bert_vits2"], Bert_VITS2, "bert_vits2")
165
-
166
- voice_obj = {"VITS": vits_objs, "HUBERT-VITS": hubert_vits_objs, "W2V2-VITS": w2v2_vits_objs,
167
- "BERT-VITS2": bert_vits2_objs}
168
- voice_speakers = {"VITS": vits_speakers, "HUBERT-VITS": hubert_vits_speakers, "W2V2-VITS": w2v2_vits_speakers,
169
- "BERT-VITS2": bert_vits2_speakers}
 
 
 
 
170
 
171
  tts = TTS(voice_obj, voice_speakers, device=device, w2v2_emotion_count=w2v2_emotion_count)
172
  return tts
 
3
  import logging
4
  import config
5
  import numpy as np
6
+
7
+ import utils
8
+ from utils.data_utils import check_is_none, HParams
9
  from vits import VITS
10
  from voice import TTS
11
  from config import DEVICE as device
12
  from utils.lang_dict import lang_dict
13
+ from contants import ModelType
14
 
15
 
16
+ def recognition_model_type(hps: HParams) -> str:
17
+ # model_config = json.load(model_config_json)
18
+ symbols = getattr(hps, "symbols", None)
19
+ # symbols = model_config.get("symbols", None)
20
+ emotion_embedding = getattr(hps.data, "emotion_embedding", False)
21
 
22
+ if "use_spk_conditioned_encoder" in hps.model:
23
+ model_type = ModelType.BERT_VITS2
 
 
 
 
 
24
  return model_type
25
 
26
  if symbols != None:
27
  if not emotion_embedding:
28
+ mode_type = ModelType.VITS
29
  else:
30
+ mode_type = ModelType.W2V2_VITS
31
  else:
32
+ mode_type = ModelType.HUBERT_VITS
33
 
34
  return mode_type
35
 
 
75
 
76
  def parse_models(model_list):
77
  categorized_models = {
78
+ ModelType.VITS: [],
79
+ ModelType.HUBERT_VITS: [],
80
+ ModelType.W2V2_VITS: [],
81
+ ModelType.BERT_VITS2: []
82
  }
83
 
84
  for model_info in model_list:
85
  config_path = model_info[1]
86
+ hps = utils.get_hparams_from_file(config_path)
87
+ model_info.append(hps)
88
+ model_type = recognition_model_type(hps)
89
+ # with open(config_path, 'r', encoding='utf-8') as model_config:
90
+ # model_type = recognition_model_type(model_config)
91
  if model_type in categorized_models:
92
  categorized_models[model_type].append(model_info)
93
 
 
99
  speakers = []
100
  new_id = 0
101
 
102
+ for obj_id, (model_path, config_path, hps) in enumerate(model_list):
103
  obj_args = {
104
  "model": model_path,
105
+ "config": hps,
106
  "model_type": model_type,
107
  "device": device
108
  }
109
+
110
+ if model_type == ModelType.BERT_VITS2:
111
+ from bert_vits2.utils import process_legacy_versions
112
+ legacy_versions = process_legacy_versions(hps)
113
+ key = f"{model_type.value}_v{legacy_versions}" if legacy_versions else model_type.value
114
+ else:
115
+ key = getattr(hps.data, "text_cleaners", ["none"])[0]
116
+
117
  if additional_arg:
118
  obj_args.update(additional_arg)
119
 
120
  obj = model_class(**obj_args)
121
 
122
  lang = lang_dict.get(key, ["unknown"])
123
 
124
  for real_id, name in enumerate(obj.get_speakers()):
 
133
  categorized_models = parse_models(model_list)
134
 
135
  # Handle VITS
136
+ vits_objs, vits_speakers = merge_models(categorized_models[ModelType.VITS], VITS, ModelType.VITS)
137
 
138
  # Handle HUBERT-VITS
139
  hubert_vits_objs, hubert_vits_speakers = [], []
140
+ if len(categorized_models[ModelType.HUBERT_VITS]) != 0:
141
  if getattr(config, "HUBERT_SOFT_MODEL", None) is None or check_is_none(config.HUBERT_SOFT_MODEL):
142
  raise ValueError(f"Please configure HUBERT_SOFT_MODEL path in config.py")
143
  try:
 
146
  except Exception as e:
147
  raise ValueError(f"Load HUBERT_SOFT_MODEL failed {e}")
148
 
149
+ hubert_vits_objs, hubert_vits_speakers = merge_models(categorized_models[ModelType.HUBERT_VITS], VITS, ModelType.HUBERT_VITS,
150
  additional_arg={"additional_model": hubert})
151
 
152
  # Handle W2V2-VITS
153
  w2v2_vits_objs, w2v2_vits_speakers = [], []
154
  w2v2_emotion_count = 0
155
+ if len(categorized_models[ModelType.W2V2_VITS]) != 0:
156
  if getattr(config, "DIMENSIONAL_EMOTION_NPY", None) is None or check_is_none(
157
  config.DIMENSIONAL_EMOTION_NPY):
158
  raise ValueError(f"Please configure DIMENSIONAL_EMOTION_NPY path in config.py")
 
162
  emotion_reference = None
163
  raise ValueError(f"Load DIMENSIONAL_EMOTION_NPY failed {e}")
164
 
165
+ w2v2_vits_objs, w2v2_vits_speakers = merge_models(categorized_models[ModelType.W2V2_VITS], VITS, ModelType.W2V2_VITS,
166
  additional_arg={"additional_model": emotion_reference})
167
  w2v2_emotion_count = len(emotion_reference) if emotion_reference is not None else 0
168
 
169
  # Handle BERT-VITS2
170
  bert_vits2_objs, bert_vits2_speakers = [], []
171
+ if len(categorized_models[ModelType.BERT_VITS2]) != 0:
172
  from bert_vits2 import Bert_VITS2
173
+ bert_vits2_objs, bert_vits2_speakers = merge_models(categorized_models[ModelType.BERT_VITS2], Bert_VITS2, ModelType.BERT_VITS2)
174
+
175
+ voice_obj = {ModelType.VITS: vits_objs,
176
+ ModelType.HUBERT_VITS: hubert_vits_objs,
177
+ ModelType.W2V2_VITS: w2v2_vits_objs,
178
+ ModelType.BERT_VITS2: bert_vits2_objs}
179
+ voice_speakers = {ModelType.VITS.value: vits_speakers,
180
+ ModelType.HUBERT_VITS.value: hubert_vits_speakers,
181
+ ModelType.W2V2_VITS.value: w2v2_vits_speakers,
182
+ ModelType.BERT_VITS2.value: bert_vits2_speakers}
183
 
184
  tts = TTS(voice_obj, voice_speakers, device=device, w2v2_emotion_count=w2v2_emotion_count)
185
  return tts
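
Editor's note: the new recognition_model_type logic above decides the model family from just three hparams fields: model.use_spk_conditioned_encoder marks Bert-VITS2, a top-level symbols list marks plain VITS (or W2V2-VITS when data.emotion_embedding is set), and the absence of symbols marks HUBERT-VITS. A minimal standalone sketch of the same decision tree, using plain dicts instead of HParams (the sample configs are hypothetical, not taken from any real model):

# Sketch only: mirrors the detection order used in recognition_model_type above.
def detect_model_type(cfg: dict) -> str:
    if "use_spk_conditioned_encoder" in cfg.get("model", {}):
        return "BERT-VITS2"
    if cfg.get("symbols") is not None:
        return "W2V2-VITS" if cfg.get("data", {}).get("emotion_embedding", False) else "VITS"
    return "HUBERT-VITS"

assert detect_model_type({"model": {"use_spk_conditioned_encoder": True}}) == "BERT-VITS2"
assert detect_model_type({"model": {}, "data": {}, "symbols": ["_", "a"]}) == "VITS"
assert detect_model_type({"model": {}, "data": {"emotion_embedding": True}, "symbols": ["_"]}) == "W2V2-VITS"
assert detect_model_type({"model": {}, "data": {}}) == "HUBERT-VITS"
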
vits-simple-api-installer-latest.sh CHANGED
@@ -28,9 +28,9 @@ EN_MESSAGES=(
28
  ["DOWNLOAD_VITS_CHINESE"]="Do you want to download the bert model for vits_chinese? Enter 1 for yes, 2 for no."
29
  ["MUST_DOWNLOAD_VITS_CHINESE"]="Using vits_chinese requires downloading these models, which will take up about 410MB."
30
  ["DOWNLOAD_BERT_VITS2"]="Do you want to download chinese-roberta-wwm-ext-large? Enter 1 for yes or 2 for no"
31
- ["MUST_DOWNLOAD_BERT_VITS2"]="To use Bert-VITS2, you must download these models, which will take up about 3.64GB."
32
  ["DOWNLOADED"]="File is downloaded correctly."
33
- ["CORRUPTED"]="File is corrupted or incomplete."
34
  ["INSTALL_COMPLETE"]="The upgrade or installation has been completed."
35
  ["CONFIG_DIR"]="The configuration file directory is"
36
  ["IMPORT_NOTICE"]="If the vits model is not imported, it cannot be used. Import the model in the configuration file directory."
@@ -59,9 +59,9 @@ ZH_MESSAGES=(
59
  ["DOWNLOAD_VITS_CHINESE"]="是否要下载vits_chinese的bert模型?输入1表示是,2表示否。"
60
  ["MUST_DOWNLOAD_VITS_CHINESE"]="使用vits_chinese必须下载这些模型,将占用大约410MB。"
61
  ["DOWNLOAD_BERT_VITS2"]="是否要下载chinese-roberta-wwm-ext-large?输入1表示是,2表示否。"
62
- ["MUST_DOWNLOAD_BERT_VITS2"]="使用Bert-VITS2必须下载这些模型,将占用大约3.64GB。"
63
  ["DOWNLOADED"]="文件已正确下载。"
64
- ["CORRUPTED"]="文件已损坏或不完整。"
65
  ["INSTALL_COMPLETE"]="更新或安装已完成。"
66
  ["CONFIG_DIR"]="配置文件目录是"
67
  ["IMPORT_NOTICE"]="如果vits模型没有被导入,它是无法使用的。请在配置文件目录中导入模型。"
@@ -240,6 +240,21 @@ if [ "$choice_download_bert_vits2" -eq 1 ]; then
240
  download_with_fallback bert_vits2/bert/chinese-roberta-wwm-ext-large/pytorch_model.bin \
241
  "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin"
242
  fi
243
 
244
  fi
245
 
 
28
  ["DOWNLOAD_VITS_CHINESE"]="Do you want to download the bert model for vits_chinese? Enter 1 for yes, 2 for no."
29
  ["MUST_DOWNLOAD_VITS_CHINESE"]="Using vits_chinese requires downloading these models, which will take up about 410MB."
30
  ["DOWNLOAD_BERT_VITS2"]="Do you want to download chinese-roberta-wwm-ext-large? Enter 1 for yes or 2 for no"
31
+ ["MUST_DOWNLOAD_BERT_VITS2"]="To use Bert-VITS2, you must download these models, which will take up about 1.63GB."
32
  ["DOWNLOADED"]="File is downloaded correctly."
33
+ ["CORRUPTED"]="The file may not have been downloaded, or the download might be incomplete, and it could also be corrupted."
34
  ["INSTALL_COMPLETE"]="The upgrade or installation has been completed."
35
  ["CONFIG_DIR"]="The configuration file directory is"
36
  ["IMPORT_NOTICE"]="If the vits model is not imported, it cannot be used. Import the model in the configuration file directory."
 
59
  ["DOWNLOAD_VITS_CHINESE"]="是否要下载vits_chinese的bert模型?输入1表示是,2表示否。"
60
  ["MUST_DOWNLOAD_VITS_CHINESE"]="使用vits_chinese必须下载这些模型,将占用大约410MB。"
61
  ["DOWNLOAD_BERT_VITS2"]="是否要下载chinese-roberta-wwm-ext-large?输入1表示是,2表示否。"
62
+ ["MUST_DOWNLOAD_BERT_VITS2"]="使用Bert-VITS2必须下载这些模型,将占用大约1.63GB。"
63
  ["DOWNLOADED"]="文件已正确下载。"
64
+ ["CORRUPTED"]="文件可能未下载,或下载不完整,也有可能已损坏。"
65
  ["INSTALL_COMPLETE"]="更新或安装已完成。"
66
  ["CONFIG_DIR"]="配置文件目录是"
67
  ["IMPORT_NOTICE"]="如果vits模型没有被导入,它是无法使用的。请在配置文件目录中导入模型。"
 
240
  download_with_fallback bert_vits2/bert/chinese-roberta-wwm-ext-large/pytorch_model.bin \
241
  "https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/resolve/main/pytorch_model.bin"
242
  fi
243
+
244
+ mkdir -p bert_vits2/bert/bert-base-japanese-v3
245
+
246
+ EXPECTED_MD5="6d0f8f3503dae04df0711b6175ef0c8e"
247
+ FILE_PATH="bert_vits2/bert/bert-base-japanese-v3/pytorch_model.bin"
248
+ echo -e "${MESSAGES["VERIFYING"]}$FILE_PATH"
249
+ ACTUAL_MD5=$(md5sum $FILE_PATH | awk '{print $1}')
250
+
251
+ if [ "$EXPECTED_MD5" == "$ACTUAL_MD5" ]; then
252
+ echo "${MESSAGES["DOWNLOADED"]}"
253
+ else
254
+ echo ${MESSAGES["CORRUPTED"]}
255
+ download_with_fallback bert_vits2/bert/bert-base-japanese-v3/pytorch_model.bin \
256
+ "https://huggingface.co/cl-tohoku/bert-base-japanese-v3/resolve/main/pytorch_model.bin"
257
+ fi
258
 
259
  fi
260
 
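
Editor's note: the installer addition above follows a verify-then-download pattern for bert-base-japanese-v3: compute the file's MD5, compare it with the expected digest, and fall back to download_with_fallback on mismatch. A rough Python equivalent of the check (the path and digest below are placeholders, not the real values):

import hashlib
import os

def needs_download(path: str, expected_md5: str) -> bool:
    # True when the file is missing or its MD5 does not match the expected digest.
    if not os.path.exists(path):
        return True
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            md5.update(chunk)
    return md5.hexdigest() != expected_md5

# Hypothetical usage mirroring the shell logic above:
# if needs_download("bert_vits2/bert/.../pytorch_model.bin", "6d0f8f35..."):
#     download_with_fallback(...)
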
vits/text/japanese.py CHANGED
@@ -1,7 +1,23 @@
 
1
  import re
2
  from unidecode import unidecode
3
  import pyopenjtalk
4
 
5
 
6
  # Regular expression matching Japanese without punctuation marks:
7
  _japanese_characters = re.compile(
@@ -127,7 +143,7 @@ def get_real_hatsuon(text):
127
  def japanese_to_ipa(text):
128
  text = japanese_to_romaji_with_accent(text).replace('...', '…')
129
  text = re.sub(
130
- r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
131
  text = get_real_sokuon(text)
132
  text = get_real_hatsuon(text)
133
  for regex, replacement in _romaji_to_ipa:
@@ -148,6 +164,6 @@ def japanese_to_ipa3(text):
148
  text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
149
  'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
150
  text = re.sub(
151
- r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
152
  text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
153
  return text
 
1
+ import os
2
  import re
3
  from unidecode import unidecode
4
  import pyopenjtalk
5
 
6
+ from config import ABS_PATH
7
+ from utils.download import download_and_verify
8
+
9
+ URLS = [
10
+ "https://github.com/r9y9/open_jtalk/releases/download/v1.11.1/open_jtalk_dic_utf_8-1.11.tar.gz",
11
+ "https://ghproxy.com/https://github.com/r9y9/open_jtalk/releases/download/v1.11.1/open_jtalk_dic_utf_8-1.11.tar.gz",
12
+ ]
13
+ install_path = os.path.dirname(pyopenjtalk.__file__)
14
+ dict_path = os.path.join(install_path, "open_jtalk_dic_utf_8-1.11", "char.bin")
15
+ TARGET_PATH = os.path.join(ABS_PATH, "open_jtalk_dic_utf_8-1.11.tar.gz")
16
+ EXTRACT_DESTINATION = os.path.join(install_path, "")
17
+ EXPECTED_MD5 = None
18
+
19
+ if not os.path.exists(dict_path):
20
+ success, message = download_and_verify(URLS, TARGET_PATH, EXPECTED_MD5, EXTRACT_DESTINATION)
21
 
22
  # Regular expression matching Japanese without punctuation marks:
23
  _japanese_characters = re.compile(
 
143
  def japanese_to_ipa(text):
144
  text = japanese_to_romaji_with_accent(text).replace('...', '…')
145
  text = re.sub(
146
+ r'([aiueo])\1+', lambda x: x.group(0)[0] + 'ː' * (len(x.group(0)) - 1), text)
147
  text = get_real_sokuon(text)
148
  text = get_real_hatsuon(text)
149
  for regex, replacement in _romaji_to_ipa:
 
164
  text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
165
  'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
166
  text = re.sub(
167
+ r'([aiɯeo])\1+', lambda x: x.group(0)[0] + 'ː' * (len(x.group(0)) - 1), text)
168
  text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
169
  return text
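
Editor's note: the import-time block added above checks whether the pyopenjtalk dictionary exists and, if not, calls utils.download.download_and_verify with a list of mirror URLs, a target archive path, an optional MD5, and an extraction destination. That helper is not shown in this diff; the sketch below only illustrates the mirror-fallback-plus-extract behaviour implied by the call site (all names and details here are assumptions):

import os
import tarfile
import urllib.request

def download_and_verify_sketch(urls, target_path, expected_md5=None, extract_destination=None):
    # Try each mirror in order and stop at the first successful download.
    for url in urls:
        try:
            urllib.request.urlretrieve(url, target_path)
            break
        except Exception:
            continue
    else:
        return False, "all mirrors failed"
    # An MD5 check would go here when expected_md5 is provided (see the installer sketch above).
    if extract_destination and tarfile.is_tarfile(target_path):
        with tarfile.open(target_path) as tar:
            tar.extractall(extract_destination)
    return True, "ok"
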
vits/vits.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
  import torch
5
  from torch import no_grad, LongTensor, inference_mode, FloatTensor
6
  import utils
7
  from utils.sentence import sentence_split_and_markup
8
  from vits import commons
9
  from vits.mel_processing import spectrogram_torch
@@ -11,11 +13,10 @@ from vits.text import text_to_sequence
11
  from vits.models import SynthesizerTrn
12
 
13
 
14
-
15
  class VITS:
16
- def __init__(self, model, config, additional_model=None, model_type=None, device=torch.device("cpu"),**kwargs):
17
  self.model_type = model_type
18
- self.hps_ms = utils.get_hparams_from_file(config)
19
  self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
20
  self.n_symbols = len(getattr(self.hps_ms, 'symbols', []))
21
  self.speakers = getattr(self.hps_ms, 'speakers', ['0'])
@@ -38,15 +39,18 @@ class VITS:
38
  _ = self.net_g_ms.eval()
39
  self.device = device
40
 
41
  # load model
42
  self.load_model(model, additional_model)
43
 
44
  def load_model(self, model, additional_model=None):
45
  utils.load_checkpoint(model, self.net_g_ms)
46
  self.net_g_ms.to(self.device)
47
- if self.model_type == "hubert":
48
  self.hubert = additional_model
49
- elif self.model_type == "w2v2":
50
  self.emotion_reference = additional_model
51
 
52
  def get_cleaned_text(self, text, hps, cleaned=False):
@@ -71,13 +75,17 @@ class VITS:
71
  def get_speakers(self, escape=False):
72
  return self.speakers
73
 
74
  def infer(self, params):
75
  with no_grad():
76
  x_tst = params.get("stn_tst").unsqueeze(0).to(self.device)
77
  x_tst_lengths = LongTensor([params.get("stn_tst").size(0)]).to(self.device)
78
  x_tst_prosody = torch.FloatTensor(params.get("char_embeds")).unsqueeze(0).to(
79
  self.device) if self.bert_embedding else None
80
- sid = params.get("sid").to(self.device) if not self.bert_embedding else None
81
  emotion = params.get("emotion").to(self.device) if self.emotion_embedding else None
82
 
83
  audio = self.net_g_ms.infer(x=x_tst,
@@ -97,15 +105,14 @@ class VITS:
97
  emotion=None, cleaned=False, f0_scale=1):
98
  emo = None
99
  char_embeds = None
100
- if self.model_type != "hubert":
101
  if self.bert_embedding:
102
  stn_tst, char_embeds = self.get_cleaned_text(text, self.hps_ms, cleaned=cleaned)
103
- sid = None
104
  else:
105
  stn_tst = self.get_cleaned_text(text, self.hps_ms, cleaned=cleaned)
106
- sid = LongTensor([speaker_id])
107
 
108
- if self.model_type == "w2v2":
109
  # if emotion_reference.endswith('.npy'):
110
  # emotion = np.load(emotion_reference)
111
  # emotion = FloatTensor(emotion).unsqueeze(0)
@@ -121,7 +128,7 @@ class VITS:
121
  emo = torch.FloatTensor(self.emotion_reference[emotion]).unsqueeze(0)
122
 
123
 
124
- elif self.model_type == "hubert":
125
  if self.use_f0:
126
  audio, sampling_rate = librosa.load(audio_path, sr=self.hps_ms.data.sampling_rate, mono=True)
127
  audio16000 = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
@@ -165,35 +172,39 @@ class VITS:
165
  if text is not None: text = re.sub(r'\s+', ' ', text).strip()
166
 
167
  tasks = []
168
- if self.model_type == "vits":
169
  sentence_list = sentence_split_and_markup(text, max, lang, speaker_lang)
170
  for sentence in sentence_list:
171
  params = self.get_infer_param(text=sentence, speaker_id=speaker_id, length_scale=length,
172
  noise_scale=noise, noise_scale_w=noisew)
173
  tasks.append(params)
174
 
175
- elif self.model_type == "hubert":
176
  params = self.get_infer_param(speaker_id=speaker_id, length_scale=length, noise_scale=noise,
177
  noise_scale_w=noisew, audio_path=audio_path)
178
  tasks.append(params)
179
 
180
- elif self.model_type == "w2v2":
181
  sentence_list = sentence_split_and_markup(text, max, lang, speaker_lang)
182
  for sentence in sentence_list:
183
  params = self.get_infer_param(text=sentence, speaker_id=speaker_id, length_scale=length,
184
  noise_scale=noise, noise_scale_w=noisew, emotion=emotion)
185
  tasks.append(params)
186
 
187
  return tasks
188
 
189
  def get_audio(self, voice, auto_break=False):
190
  tasks = self.get_tasks(voice)
191
  # 停顿0.75s,避免语音分段合成再拼接后的连接突兀
192
- brk = np.zeros(int(0.75 * 22050), dtype=np.int16)
193
 
194
  audios = []
195
- for task in tasks:
196
- if auto_break:
197
  chunk = np.concatenate((self.infer(task), brk), axis=0)
198
  else:
199
  chunk = self.infer(task)
 
4
  import torch
5
  from torch import no_grad, LongTensor, inference_mode, FloatTensor
6
  import utils
7
+ from contants import ModelType
8
+ from utils import get_hparams_from_file, lang_dict
9
  from utils.sentence import sentence_split_and_markup
10
  from vits import commons
11
  from vits.mel_processing import spectrogram_torch
 
13
  from vits.models import SynthesizerTrn
14
 
15
 
 
16
  class VITS:
17
+ def __init__(self, model, config, additional_model=None, model_type=None, device=torch.device("cpu"), **kwargs):
18
  self.model_type = model_type
19
+ self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
20
  self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
21
  self.n_symbols = len(getattr(self.hps_ms, 'symbols', []))
22
  self.speakers = getattr(self.hps_ms, 'speakers', ['0'])
 
39
  _ = self.net_g_ms.eval()
40
  self.device = device
41
 
42
+ key = getattr(self.hps_ms.data, "text_cleaners", ["none"])[0]
43
+ self.lang = lang_dict.get(key, ["unknown"])
44
+
45
  # load model
46
  self.load_model(model, additional_model)
47
 
48
  def load_model(self, model, additional_model=None):
49
  utils.load_checkpoint(model, self.net_g_ms)
50
  self.net_g_ms.to(self.device)
51
+ if self.model_type == ModelType.HUBERT_VITS:
52
  self.hubert = additional_model
53
+ elif self.model_type == ModelType.W2V2_VITS:
54
  self.emotion_reference = additional_model
55
 
56
  def get_cleaned_text(self, text, hps, cleaned=False):
 
75
  def get_speakers(self, escape=False):
76
  return self.speakers
77
 
78
+ @property
79
+ def sampling_rate(self):
80
+ return self.hps_ms.data.sampling_rate
81
+
82
  def infer(self, params):
83
  with no_grad():
84
  x_tst = params.get("stn_tst").unsqueeze(0).to(self.device)
85
  x_tst_lengths = LongTensor([params.get("stn_tst").size(0)]).to(self.device)
86
  x_tst_prosody = torch.FloatTensor(params.get("char_embeds")).unsqueeze(0).to(
87
  self.device) if self.bert_embedding else None
88
+ sid = params.get("sid").to(self.device)
89
  emotion = params.get("emotion").to(self.device) if self.emotion_embedding else None
90
 
91
  audio = self.net_g_ms.infer(x=x_tst,
 
105
  emotion=None, cleaned=False, f0_scale=1):
106
  emo = None
107
  char_embeds = None
108
+ if self.model_type != ModelType.HUBERT_VITS:
109
  if self.bert_embedding:
110
  stn_tst, char_embeds = self.get_cleaned_text(text, self.hps_ms, cleaned=cleaned)
 
111
  else:
112
  stn_tst = self.get_cleaned_text(text, self.hps_ms, cleaned=cleaned)
113
+ sid = LongTensor([speaker_id])
114
 
115
+ if self.model_type == ModelType.W2V2_VITS:
116
  # if emotion_reference.endswith('.npy'):
117
  # emotion = np.load(emotion_reference)
118
  # emotion = FloatTensor(emotion).unsqueeze(0)
 
128
  emo = torch.FloatTensor(self.emotion_reference[emotion]).unsqueeze(0)
129
 
130
 
131
+ elif self.model_type == ModelType.HUBERT_VITS:
132
  if self.use_f0:
133
  audio, sampling_rate = librosa.load(audio_path, sr=self.hps_ms.data.sampling_rate, mono=True)
134
  audio16000 = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
 
172
  if text is not None: text = re.sub(r'\s+', ' ', text).strip()
173
 
174
  tasks = []
175
+ if self.model_type == ModelType.VITS:
176
  sentence_list = sentence_split_and_markup(text, max, lang, speaker_lang)
177
  for sentence in sentence_list:
178
  params = self.get_infer_param(text=sentence, speaker_id=speaker_id, length_scale=length,
179
  noise_scale=noise, noise_scale_w=noisew)
180
  tasks.append(params)
181
 
182
+ elif self.model_type == ModelType.HUBERT_VITS:
183
  params = self.get_infer_param(speaker_id=speaker_id, length_scale=length, noise_scale=noise,
184
  noise_scale_w=noisew, audio_path=audio_path)
185
  tasks.append(params)
186
 
187
+ elif self.model_type == ModelType.W2V2_VITS:
188
  sentence_list = sentence_split_and_markup(text, max, lang, speaker_lang)
189
  for sentence in sentence_list:
190
  params = self.get_infer_param(text=sentence, speaker_id=speaker_id, length_scale=length,
191
  noise_scale=noise, noise_scale_w=noisew, emotion=emotion)
192
  tasks.append(params)
193
+ else:
194
+ raise ValueError(f"Unsupported model type: {self.model_type}")
195
 
196
  return tasks
197
 
198
  def get_audio(self, voice, auto_break=False):
199
  tasks = self.get_tasks(voice)
200
  # Pause 0.75s to avoid abrupt joins when synthesized segments are concatenated
201
+ brk = np.zeros(int(0.75 * self.sampling_rate), dtype=np.int16)
202
 
203
  audios = []
204
+ num_tasks = len(tasks)
205
+
206
+ for i, task in enumerate(tasks):
207
+ if auto_break and i < num_tasks - 1:
208
  chunk = np.concatenate((self.infer(task), brk), axis=0)
209
  else:
210
  chunk = self.infer(task)
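
Editor's note: two behavioural changes in VITS.get_audio are easy to miss in the diff above: the inter-segment pause is now sized from the model's own sampling rate rather than a hard-coded 22050 Hz, and no pause is appended after the final segment. A condensed sketch of that concatenation rule (inference is stubbed out with precomputed chunks):

import numpy as np

def concat_with_breaks(chunks, sampling_rate, pause_s=0.75, auto_break=True):
    # Append a pause after every chunk except the last one.
    brk = np.zeros(int(pause_s * sampling_rate), dtype=np.int16)
    out = []
    for i, chunk in enumerate(chunks):
        out.append(chunk)
        if auto_break and i < len(chunks) - 1:
            out.append(brk)
    return np.concatenate(out, axis=0)

# Three 1-second chunks at 22050 Hz -> two 0.75 s pauses in between.
chunks = [np.ones(22050, dtype=np.int16)] * 3
assert len(concat_with_breaks(chunks, 22050)) == 3 * 22050 + 2 * int(0.75 * 22050)
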
voice.py CHANGED
@@ -8,8 +8,9 @@ import config
8
  import soundfile as sf
9
  from io import BytesIO
10
  from graiax import silkcoder
11
- import utils
12
  from logger import logger
 
 
13
 
14
 
15
  # torch.set_num_threads(1) # limit torch to a single thread
@@ -21,11 +22,11 @@ class TTS:
21
  self._voice_speakers = voice_speakers
22
  self._strength_dict = {"x-weak": 0.25, "weak": 0.5, "Medium": 0.75, "Strong": 1, "x-strong": 1.25}
23
  self._speakers_count = sum([len(self._voice_speakers[i]) for i in self._voice_speakers])
24
- self._vits_speakers_count = len(self._voice_speakers["VITS"])
25
- self._hubert_speakers_count = len(self._voice_speakers["HUBERT-VITS"])
26
- self._w2v2_speakers_count = len(self._voice_speakers["W2V2-VITS"])
27
  self._w2v2_emotion_count = kwargs.get("w2v2_emotion_count", 0)
28
- self._bert_vits2_speakers_count = len(self._voice_speakers["BERT-VITS2"])
29
  self.dem = None
30
 
31
  # Initialization information
@@ -42,11 +43,11 @@ class TTS:
42
  except Exception as e:
43
  self.logger.warning(f"Load DIMENSIONAL_EMOTION_MODEL failed {e}")
44
 
45
- if self._vits_speakers_count != 0: self.logger.info(f"[VITS] {self._vits_speakers_count} speakers")
46
- if self._hubert_speakers_count != 0: self.logger.info(f"[hubert] {self._hubert_speakers_count} speakers")
47
- if self._w2v2_speakers_count != 0: self.logger.info(f"[w2v2] {self._w2v2_speakers_count} speakers")
48
  if self._bert_vits2_speakers_count != 0: self.logger.info(
49
- f"[Bert-VITS2] {self._bert_vits2_speakers_count} speakers")
50
  self.logger.info(f"{self._speakers_count} speakers in total.")
51
  if self._speakers_count == 0:
52
  self.logger.warning(f"No model was loaded.")
@@ -124,6 +125,15 @@ class TTS:
124
  break
125
  yield chunk
126
 
 
127
  def parse_ssml(self, ssml):
128
  root = ET.fromstring(ssml)
129
  format = root.attrib.get("format", "wav")
@@ -140,9 +150,11 @@ class TTS:
140
  noisew = float(element.attrib.get("noisew", root.attrib.get("noisew", config.NOISEW)))
141
  max = int(element.attrib.get("max", root.attrib.get("max", "0")))
142
  # 不填写默认就是vits
143
- model = element.attrib.get("model", root.attrib.get("model", "vits"))
144
  # w2v2-vits/emotion-vits才有emotion
145
  emotion = int(element.attrib.get("emotion", root.attrib.get("emotion", 0)))
146
 
147
  voice_element = ET.tostring(element, encoding='unicode')
148
 
@@ -177,8 +189,9 @@ class TTS:
177
  "noise": noise,
178
  "noisew": noisew,
179
  "max": max,
180
- "model": model,
181
- "emotion": emotion
182
  })
183
 
184
  # Pause 0.75s at the end of each segment
@@ -197,39 +210,49 @@ class TTS:
197
 
198
  return voice_tasks, format
199
 
200
- def create_ssml_infer_task(self, tasks, format):
201
  audios = []
202
  for task in tasks:
203
  if task.get("break"):
204
  audios.append(np.zeros(int(task.get("break") * 22050), dtype=np.int16))
205
  else:
206
- model = task.get("model").upper()
207
- if model != "VITS" and model != "W2V2-VITS" and model != "EMOTION-VITS":
208
- raise ValueError(f"Unsupported model: {task.get('model')}")
209
- voice_obj = self._voice_obj[model][task.get("id")][1]
210
- task["id"] = self._voice_obj[model][task.get("id")][0]
 
  audio = voice_obj.get_audio(task)
212
  audios.append(audio)
213
-
214
- audio = np.concatenate(audios, axis=0)
215
- encoded_audio = self.encode(voice_obj.hps_ms.data.sampling_rate, audio, format)
 
  return encoded_audio
217
 
218
  def vits_infer(self, task):
219
  format = task.get("format", "wav")
220
- voice_obj = self._voice_obj["VITS"][task.get("id")][1]
221
- real_id = self._voice_obj["VITS"][task.get("id")][0]
222
  task["id"] = real_id # Change to real id
223
- sampling_rate = voice_obj.hps_ms.data.sampling_rate
224
  audio = voice_obj.get_audio(task, auto_break=True)
225
  encoded_audio = self.encode(sampling_rate, audio, format)
226
  return encoded_audio
227
 
228
  def stream_vits_infer(self, task, fname=None):
229
  format = task.get("format", "wav")
230
- voice_obj = self._voice_obj["VITS"][task.get("id")][1]
231
- task["id"] = self._voice_obj["VITS"][task.get("id")][0]
232
- sampling_rate = voice_obj.hps_ms.data.sampling_rate
233
  genertator = voice_obj.get_stream_audio(task, auto_break=True)
234
  # audio = BytesIO()
235
  for chunk in genertator:
@@ -244,18 +267,18 @@ class TTS:
244
 
245
  def hubert_vits_infer(self, task):
246
  format = task.get("format", "wav")
247
- voice_obj = self._voice_obj["HUBERT-VITS"][task.get("id")][1]
248
- task["id"] = self._voice_obj["HUBERT-VITS"][task.get("id")][0]
249
- sampling_rate = voice_obj.hps_ms.data.sampling_rate
250
  audio = voice_obj.get_audio(task)
251
  encoded_audio = self.encode(sampling_rate, audio, format)
252
  return encoded_audio
253
 
254
  def w2v2_vits_infer(self, task):
255
  format = task.get("format", "wav")
256
- voice_obj = self._voice_obj["W2V2-VITS"][task.get("id")][1]
257
- task["id"] = self._voice_obj["W2V2-VITS"][task.get("id")][0]
258
- sampling_rate = voice_obj.hps_ms.data.sampling_rate
259
  audio = voice_obj.get_audio(task, auto_break=True)
260
  encoded_audio = self.encode(sampling_rate, audio, format)
261
  return encoded_audio
@@ -265,17 +288,17 @@ class TTS:
265
  target_id = task.get("target_id")
266
  format = task.get("format")
267
 
268
- original_id_obj = int(self._voice_obj["VITS"][original_id][2])
269
- target_id_obj = int(self._voice_obj["VITS"][target_id][2])
270
 
271
  if original_id_obj != target_id_obj:
272
  raise ValueError(f"speakers are in diffrent VITS Model")
273
 
274
- task["original_id"] = int(self._voice_obj["VITS"][original_id][0])
275
- task["target_id"] = int(self._voice_obj["VITS"][target_id][0])
276
 
277
- voice_obj = self._voice_obj["VITS"][original_id][1]
278
- sampling_rate = voice_obj.hps_ms.data.sampling_rate
279
 
280
  audio = voice_obj.voice_conversion(task)
281
  encoded_audio = self.encode(sampling_rate, audio, format)
@@ -294,9 +317,9 @@ class TTS:
294
 
295
  def bert_vits2_infer(self, task):
296
  format = task.get("format", "wav")
297
- voice_obj = self._voice_obj["BERT-VITS2"][task.get("id")][1]
298
- task["id"] = self._voice_obj["BERT-VITS2"][task.get("id")][0]
299
- sampling_rate = voice_obj.hps_ms.data.sampling_rate
300
  audio = voice_obj.get_audio(task, auto_break=True)
301
  encoded_audio = self.encode(sampling_rate, audio, format)
302
  return encoded_audio
 
8
  import soundfile as sf
9
  from io import BytesIO
10
  from graiax import silkcoder
 
11
  from logger import logger
12
+ from contants import ModelType
13
+ from scipy.signal import resample_poly
14
 
15
 
16
  # torch.set_num_threads(1) # limit torch to a single thread
 
22
  self._voice_speakers = voice_speakers
23
  self._strength_dict = {"x-weak": 0.25, "weak": 0.5, "Medium": 0.75, "Strong": 1, "x-strong": 1.25}
24
  self._speakers_count = sum([len(self._voice_speakers[i]) for i in self._voice_speakers])
25
+ self._vits_speakers_count = len(self._voice_speakers[ModelType.VITS.value])
26
+ self._hubert_speakers_count = len(self._voice_speakers[ModelType.HUBERT_VITS.value])
27
+ self._w2v2_speakers_count = len(self._voice_speakers[ModelType.W2V2_VITS.value])
28
  self._w2v2_emotion_count = kwargs.get("w2v2_emotion_count", 0)
29
+ self._bert_vits2_speakers_count = len(self._voice_speakers[ModelType.BERT_VITS2.value])
30
  self.dem = None
31
 
32
  # Initialization information
 
43
  except Exception as e:
44
  self.logger.warning(f"Load DIMENSIONAL_EMOTION_MODEL failed {e}")
45
 
46
+ if self._vits_speakers_count != 0: self.logger.info(f"[{ModelType.VITS.value}] {self._vits_speakers_count} speakers")
47
+ if self._hubert_speakers_count != 0: self.logger.info(f"[{ModelType.HUBERT_VITS.value}] {self._hubert_speakers_count} speakers")
48
+ if self._w2v2_speakers_count != 0: self.logger.info(f"[{ModelType.W2V2_VITS.value}] {self._w2v2_speakers_count} speakers")
49
  if self._bert_vits2_speakers_count != 0: self.logger.info(
50
+ f"[{ModelType.BERT_VITS2.value}] {self._bert_vits2_speakers_count} speakers")
51
  self.logger.info(f"{self._speakers_count} speakers in total.")
52
  if self._speakers_count == 0:
53
  self.logger.warning(f"No model was loaded.")
 
125
  break
126
  yield chunk
127
 
128
+ def resample_audio(self, audio, orig_sr, target_sr):
129
+ if orig_sr == target_sr:
130
+ return audio
131
+
132
+ gcd = np.gcd(orig_sr, target_sr)
133
+ audio = resample_poly(audio, target_sr // gcd, orig_sr // gcd)
134
+
135
+ return audio
136
+
137
  def parse_ssml(self, ssml):
138
  root = ET.fromstring(ssml)
139
  format = root.attrib.get("format", "wav")
 
150
  noisew = float(element.attrib.get("noisew", root.attrib.get("noisew", config.NOISEW)))
151
  max = int(element.attrib.get("max", root.attrib.get("max", "0")))
152
  # 不填写默认就是vits
153
+ model_type = element.attrib.get("model_type", root.attrib.get("model_type", "vits"))
154
  # w2v2-vits/emotion-vits才有emotion
155
  emotion = int(element.attrib.get("emotion", root.attrib.get("emotion", 0)))
156
+ # Bert-VITS2的参数
157
+ sdp_ratio = int(element.attrib.get("sdp_ratio", root.attrib.get("sdp_ratio", config.SDP_RATIO)))
158
 
159
  voice_element = ET.tostring(element, encoding='unicode')
160
 
 
189
  "noise": noise,
190
  "noisew": noisew,
191
  "max": max,
192
+ "model_type": model_type,
193
+ "emotion": emotion,
194
+ "sdp_ratio": sdp_ratio
195
  })
196
 
197
  # 分段末尾停顿0.75s
 
210
 
211
  return voice_tasks, format
212
 
213
+ def process_ssml_infer_task(self, tasks, format):
214
  audios = []
215
+ sampling_rates = []
216
+ last_sampling_rate = 22050
217
  for task in tasks:
218
  if task.get("break"):
219
  audios.append(np.zeros(int(task.get("break") * 22050), dtype=np.int16))
220
+ sampling_rates.append(last_sampling_rate)
221
  else:
222
+ model_type_str = task.get("model_type").upper()
223
+ if model_type_str not in [ModelType.VITS.value, ModelType.W2V2_VITS.value, ModelType.BERT_VITS2.value]:
224
+ raise ValueError(f"Unsupported model type: {task.get('model_type')}")
225
+ model_type = ModelType(model_type_str)
226
+ voice_obj = self._voice_obj[model_type][task.get("id")][1]
227
+ real_id = self._voice_obj[model_type][task.get("id")][0]
228
+ task["id"] = real_id
229
+ sampling_rates.append(voice_obj.sampling_rate)
230
+ last_sampling_rate = voice_obj.sampling_rate
231
  audio = voice_obj.get_audio(task)
232
  audios.append(audio)
233
+ # Get the highest sampling rate
234
+ target_sr = max(sampling_rates)
235
+ # Resample all audio to match the highest sampling rate
236
+ resampled_audios = [self.resample_audio(audio, sr, target_sr) for audio, sr in zip(audios, sampling_rates)]
237
+ audio = np.concatenate(resampled_audios, axis=0)
238
+ encoded_audio = self.encode(target_sr, audio, format)
239
  return encoded_audio
240
 
241
  def vits_infer(self, task):
242
  format = task.get("format", "wav")
243
+ voice_obj = self._voice_obj[ModelType.VITS][task.get("id")][1]
244
+ real_id = self._voice_obj[ModelType.VITS][task.get("id")][0]
245
  task["id"] = real_id # Change to real id
246
+ sampling_rate = voice_obj.sampling_rate
247
  audio = voice_obj.get_audio(task, auto_break=True)
248
  encoded_audio = self.encode(sampling_rate, audio, format)
249
  return encoded_audio
250
 
251
  def stream_vits_infer(self, task, fname=None):
252
  format = task.get("format", "wav")
253
+ voice_obj = self._voice_obj[ModelType.VITS][task.get("id")][1]
254
+ task["id"] = self._voice_obj[ModelType.VITS][task.get("id")][0]
255
+ sampling_rate = voice_obj.sampling_rate
256
  genertator = voice_obj.get_stream_audio(task, auto_break=True)
257
  # audio = BytesIO()
258
  for chunk in genertator:
 
267
 
268
  def hubert_vits_infer(self, task):
269
  format = task.get("format", "wav")
270
+ voice_obj = self._voice_obj[ModelType.HUBERT_VITS][task.get("id")][1]
271
+ task["id"] = self._voice_obj[ModelType.HUBERT_VITS][task.get("id")][0]
272
+ sampling_rate = voice_obj.sampling_rate
273
  audio = voice_obj.get_audio(task)
274
  encoded_audio = self.encode(sampling_rate, audio, format)
275
  return encoded_audio
276
 
277
  def w2v2_vits_infer(self, task):
278
  format = task.get("format", "wav")
279
+ voice_obj = self._voice_obj[ModelType.W2V2_VITS][task.get("id")][1]
280
+ task["id"] = self._voice_obj[ModelType.W2V2_VITS][task.get("id")][0]
281
+ sampling_rate = voice_obj.sampling_rate
282
  audio = voice_obj.get_audio(task, auto_break=True)
283
  encoded_audio = self.encode(sampling_rate, audio, format)
284
  return encoded_audio
 
288
  target_id = task.get("target_id")
289
  format = task.get("format")
290
 
291
+ original_id_obj = int(self._voice_obj[ModelType.VITS][original_id][2])
292
+ target_id_obj = int(self._voice_obj[ModelType.VITS][target_id][2])
293
 
294
  if original_id_obj != target_id_obj:
295
  raise ValueError(f"speakers are in diffrent VITS Model")
296
 
297
+ task["original_id"] = int(self._voice_obj[ModelType.VITS][original_id][0])
298
+ task["target_id"] = int(self._voice_obj[ModelType.VITS][target_id][0])
299
 
300
+ voice_obj = self._voice_obj[ModelType.VITS][original_id][1]
301
+ sampling_rate = voice_obj.sampling_rate
302
 
303
  audio = voice_obj.voice_conversion(task)
304
  encoded_audio = self.encode(sampling_rate, audio, format)
 
317
 
318
  def bert_vits2_infer(self, task):
319
  format = task.get("format", "wav")
320
+ voice_obj = self._voice_obj[ModelType.BERT_VITS2][task.get("id")][1]
321
+ task["id"] = self._voice_obj[ModelType.BERT_VITS2][task.get("id")][0]
322
+ sampling_rate = voice_obj.sampling_rate
323
  audio = voice_obj.get_audio(task, auto_break=True)
324
  encoded_audio = self.encode(sampling_rate, audio, format)
325
  return encoded_audio
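
Editor's note: because process_ssml_infer_task may now mix models with different sampling rates, every segment is resampled to the highest rate before concatenation, using scipy.signal.resample_poly with the up/down factors reduced by their GCD. A small standalone check of that resampling step (the rates below are just illustrative):

import numpy as np
from scipy.signal import resample_poly

def resample_audio(audio, orig_sr, target_sr):
    # Same idea as TTS.resample_audio above: rational resampling by the reduced ratio.
    if orig_sr == target_sr:
        return audio
    gcd = np.gcd(orig_sr, target_sr)
    return resample_poly(audio, target_sr // gcd, orig_sr // gcd)

one_second = np.zeros(22050, dtype=np.float32)
assert resample_audio(one_second, 22050, 44100).shape[0] == 44100  # length scales with the rate ratio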